// UpstreamConfig represents the configuration for the upstream database that
// a local database receives changes from over HTTP.
type UpstreamConfig struct {
	// URL of the upstream server. An upstream client is attached to the
	// database only when this is non-empty.
	URL string `yaml:"url"`

	// Path of the database on the upstream server. Required when URL is set.
	Path string `yaml:"path"`
}
type ReplicaConfig struct { Type string `yaml:"type"` // "file", "s3" diff --git a/cmd/litestream/replicate.go b/cmd/litestream/replicate.go index e0ae7bd..284fda7 100644 --- a/cmd/litestream/replicate.go +++ b/cmd/litestream/replicate.go @@ -6,19 +6,16 @@ import ( "fmt" "io" "log" - "net" - "net/http" - _ "net/http/pprof" "os" "os/exec" "github.com/benbjohnson/litestream" "github.com/benbjohnson/litestream/abs" "github.com/benbjohnson/litestream/gcs" + "github.com/benbjohnson/litestream/http" "github.com/benbjohnson/litestream/s3" "github.com/benbjohnson/litestream/sftp" "github.com/mattn/go-shellwords" - "github.com/prometheus/client_golang/prometheus/promhttp" ) // ReplicateCommand represents a command that continuously replicates SQLite databases. @@ -35,7 +32,8 @@ type ReplicateCommand struct { Config Config - server *litestream.Server + server *litestream.Server + httpServer *http.Server } // NewReplicateCommand returns a new instance of ReplicateCommand. @@ -143,22 +141,12 @@ func (c *ReplicateCommand) Run(ctx context.Context) (err error) { } } - // Serve metrics over HTTP if enabled. + // Serve HTTP if enabled. if c.Config.Addr != "" { - hostport := c.Config.Addr - if host, port, _ := net.SplitHostPort(c.Config.Addr); port == "" { - return fmt.Errorf("must specify port for bind address: %q", c.Config.Addr) - } else if host == "" { - hostport = net.JoinHostPort("localhost", port) + c.httpServer = http.NewServer(c.server, c.Config.Addr) + if err := c.httpServer.Open(); err != nil { + return fmt.Errorf("cannot start http server: %w", err) } - - log.Printf("serving metrics on http://%s/metrics", hostport) - go func() { - http.Handle("/metrics", promhttp.Handler()) - if err := http.ListenAndServe(c.Config.Addr, nil); err != nil { - log.Printf("cannot start metrics server: %s", err) - } - }() } // Parse exec commands args & start subprocess. 
@@ -183,10 +171,17 @@ func (c *ReplicateCommand) Run(ctx context.Context) (err error) { return nil } -// Close closes all open databases. +// Close closes the HTTP server & all open databases. func (c *ReplicateCommand) Close() (err error) { - if e := c.server.Close(); e != nil && err == nil { - err = e + if c.httpServer != nil { + if e := c.httpServer.Close(); e != nil && err == nil { + err = e + } + } + if c.server != nil { + if e := c.server.Close(); e != nil && err == nil { + err = e + } } return err } diff --git a/db.go b/db.go index ff81c4e..06957e3 100644 --- a/db.go +++ b/db.go @@ -23,6 +23,7 @@ import ( "github.com/pierrec/lz4/v4" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "golang.org/x/sync/errgroup" ) // Default DB settings. @@ -68,7 +69,7 @@ type DB struct { ctx context.Context cancel func() - wg sync.WaitGroup + g errgroup.Group // Metrics dbSizeGauge prometheus.Gauge @@ -83,6 +84,11 @@ type DB struct { checkpointErrorNCounterVec *prometheus.CounterVec checkpointSecondsCounterVec *prometheus.CounterVec + // Client used to receive live, upstream changes. If specified, then + // DB should be used as read-only as local changes will conflict with + // upstream changes. + StreamClient StreamClient + // Minimum threshold of WAL size, in pages, before a passive checkpoint. // A passive checkpoint will attempt a checkpoint but fail if there are // active transactions occurring at the same time. @@ -161,6 +167,11 @@ func (db *DB) WALPath() string { return db.path + "-wal" } +// SHMPath returns the path to the database's shared memory file. +func (db *DB) SHMPath() string { + return db.path + "-shm" +} + // MetaPath returns the path to the database metadata. 
func (db *DB) MetaPath() string { dir, file := filepath.Split(db.path) @@ -179,6 +190,12 @@ func (db *DB) GenerationPath(generation string) string { return filepath.Join(db.MetaPath(), "generations", generation) } +// PositionPath returns the path of the file that stores the current position. +// This file is only used to communicate state to external processes. +func (db *DB) PositionPath() string { + return filepath.Join(db.MetaPath(), "position") +} + // ShadowWALDir returns the path of the shadow wal directory. // Panics if generation is blank. func (db *DB) ShadowWALDir(generation string) string { @@ -399,9 +416,10 @@ func (db *DB) Open() (err error) { return fmt.Errorf("cannot remove tmp files: %w", err) } - // Start monitoring SQLite database in a separate goroutine. - db.wg.Add(1) - go func() { defer db.wg.Done(); db.monitor() }() + // If an upstream client is specified, then we should simply stream changes + // into the database. If it is not specified, then we should monitor the + // database for local changes and replicate them out. + db.g.Go(func() error { return db.monitor(db.ctx) }) return nil } @@ -410,7 +428,9 @@ func (db *DB) Open() (err error) { // and closes the database. func (db *DB) Close() (err error) { db.cancel() - db.wg.Wait() + if e := db.g.Wait(); e != nil && err == nil { + err = e + } // Start a new context for shutdown since we canceled the DB context. ctx := context.Background() @@ -484,8 +504,8 @@ func (db *DB) UpdatedAt() (time.Time, error) { return t, nil } -// init initializes the connection to the database. -// Skipped if already initialized or if the database file does not exist. +// init initializes the connection to the database. Skipped if already +// initialized or if the database file does not exist. func (db *DB) init() (err error) { // Exit if already initialized. if db.db != nil { @@ -493,17 +513,15 @@ func (db *DB) init() (err error) { } // Exit if no database file exists. 
- fi, err := os.Stat(db.path) - if os.IsNotExist(err) { + if _, err := os.Stat(db.path); os.IsNotExist(err) { return nil } else if err != nil { return err } - db.fileMode = fi.Mode() - db.uid, db.gid = internal.Fileinfo(fi) // Obtain permissions for parent directory. - if fi, err = os.Stat(filepath.Dir(db.path)); err != nil { + fi, err := os.Stat(filepath.Dir(db.path)) + if err != nil { return err } db.dirMode = fi.Mode() @@ -517,22 +535,6 @@ func (db *DB) init() (err error) { return err } - // Open long-running database file descriptor. Required for non-OFD locks. - if db.f, err = os.Open(db.path); err != nil { - return fmt.Errorf("open db file descriptor: %w", err) - } - - // Ensure database is closed if init fails. - // Initialization can retry on next sync. - defer func() { - if err != nil { - _ = db.releaseReadLock() - db.db.Close() - db.f.Close() - db.db, db.f = nil, nil - } - }() - // Enable WAL and ensure it is set. New mode should be returned on success: // https://www.sqlite.org/pragma.html#pragma_journal_mode var mode string @@ -559,6 +561,30 @@ func (db *DB) init() (err error) { return fmt.Errorf("create _litestream_lock table: %w", err) } + // Open long-running database file descriptor. Required for non-OFD locks. + if db.f, err = os.Open(db.path); err != nil { + return fmt.Errorf("open db file descriptor: %w", err) + } + + // Ensure database is closed if init fails. + // Initialization can retry on next sync. + defer func() { + if err != nil { + _ = db.releaseReadLock() + db.db.Close() + db.f.Close() + db.db, db.f = nil, nil + } + }() + + // Obtain file info once we know the database exists. + fi, err = os.Stat(db.path) + if err != nil { + return fmt.Errorf("init file stat: %w", err) + } + db.fileMode = fi.Mode() + db.uid, db.gid = internal.Fileinfo(fi) + // Start a long-running read transaction to prevent other transactions // from checkpointing. 
if err := db.acquireReadLock(); err != nil { @@ -603,6 +629,76 @@ func (db *DB) init() (err error) { return nil } +// initReplica initializes a new database file as a replica of an upstream database. +func (db *DB) initReplica(pageSize int) (err error) { + // Exit if already initialized. + if db.db != nil { + return nil + } + + // Obtain permissions for parent directory. + fi, err := os.Stat(filepath.Dir(db.path)) + if err != nil { + return err + } + db.dirMode = fi.Mode() + + dsn := db.path + dsn += fmt.Sprintf("?_busy_timeout=%d", BusyTimeout.Milliseconds()) + + // Connect to SQLite database. Use the driver registered with a hook to + // prevent WAL files from being removed. + if db.db, err = sql.Open("litestream-sqlite3", dsn); err != nil { + return err + } + + // Initialize database file if it doesn't exist. It doesn't matter what we + // store in it as it will be erased by the replication. We just need to + // ensure a WAL file is created and there is at least a page in the database. + if _, err := os.Stat(db.path); os.IsNotExist(err) { + if _, err := db.db.ExecContext(db.ctx, fmt.Sprintf(`PRAGMA page_size = %d`, pageSize)); err != nil { + return fmt.Errorf("set page size: %w", err) + } + + var mode string + if err := db.db.QueryRow(`PRAGMA journal_mode = wal`).Scan(&mode); err != nil { + return err + } else if mode != "wal" { + return fmt.Errorf("enable wal failed, mode=%q", mode) + } + + // TODO: Set page size. + + if _, err := db.db.ExecContext(db.ctx, `CREATE TABLE IF NOT EXISTS _litestream (id INTEGER)`); err != nil { + return fmt.Errorf("create _litestream table: %w", err) + } else if _, err := db.db.ExecContext(db.ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil { + return fmt.Errorf("create _litestream table: %w", err) + } + } + + // Obtain file info once we know the database exists. 
+ fi, err = os.Stat(db.path) + if err != nil { + return fmt.Errorf("init file stat: %w", err) + } + db.fileMode = fi.Mode() + db.uid, db.gid = internal.Fileinfo(fi) + + // Verify page size matches. + if err := db.db.QueryRowContext(db.ctx, `PRAGMA page_size;`).Scan(&db.pageSize); err != nil { + return fmt.Errorf("read page size: %w", err) + } else if db.pageSize != pageSize { + return fmt.Errorf("page size mismatch: %d <> %d", db.pageSize, pageSize) + } + + // Ensure meta directory structure exists. + if err := internal.MkdirAll(db.MetaPath(), db.dirMode, db.uid, db.gid); err != nil { + return err + } + + return nil +} + func (db *DB) clearGeneration(ctx context.Context) error { if err := os.Remove(db.GenerationNamePath()); err != nil && !os.IsNotExist(err) { return err @@ -1278,6 +1374,11 @@ func (db *DB) writeWALSegment(ctx context.Context, pos Pos, rd io.Reader) error return err } + // Write position to file so other processes can read it. + if err := db.writePositionFile(pos); err != nil { + return fmt.Errorf("write position file: %w", err) + } + // Generate info := WALSegmentInfo{ Generation: pos.Generation, @@ -1308,6 +1409,11 @@ func (db *DB) writeWALSegment(ctx context.Context, pos Pos, rd io.Reader) error return nil } +// writePositionFile writes pos as the current position. +func (db *DB) writePositionFile(pos Pos) error { + return internal.WriteFile(db.PositionPath(), []byte(pos.String()+"\n"), db.fileMode, db.uid, db.gid) +} + // WALSegments returns an iterator over all available WAL files for a generation. func (db *DB) WALSegments(ctx context.Context, generation string) (*FileWALSegmentIterator, error) { db.mu.Lock() @@ -1499,20 +1605,26 @@ func (db *DB) execCheckpoint(mode string) (err error) { return nil } -// monitor runs in a separate goroutine and monitors the database & WAL. 
-func (db *DB) monitor() { - var timer *time.Timer +func (db *DB) monitor(ctx context.Context) error { + if db.StreamClient != nil { + return db.monitorUpstream(ctx) + } + return db.monitorLocal(ctx) +} +// monitor runs in a separate goroutine and monitors the local database & WAL. +func (db *DB) monitorLocal(ctx context.Context) error { + var timer *time.Timer if db.MonitorDelayInterval > 0 { - timer := time.NewTimer(db.MonitorDelayInterval) + timer = time.NewTimer(db.MonitorDelayInterval) defer timer.Stop() } for { // Wait for a file change notification from the file system. select { - case <-db.ctx.Done(): - return + case <-ctx.Done(): + return nil case <-db.notifyCh: } @@ -1528,12 +1640,193 @@ func (db *DB) monitor() { default: } - if err := db.Sync(db.ctx); err != nil && !errors.Is(err, context.Canceled) { + if err := db.Sync(ctx); err != nil && !errors.Is(err, context.Canceled) { db.Logger.Printf("sync error: %s", err) } } } +// monitorUpstream runs in a separate goroutine and streams data into the local DB. +func (db *DB) monitorUpstream(ctx context.Context) error { + for { + if err := db.stream(ctx); err != nil { + if ctx.Err() != nil { + return nil + } + db.Logger.Printf("stream error, retrying: %s", err) + } + + // Delay before retrying stream. + select { + case <-ctx.Done(): + return nil + case <-time.After(1 * time.Second): + } + } +} + +// stream initializes the local database and continuously streams new upstream data. +func (db *DB) stream(ctx context.Context) error { + // Continuously stream and apply records from client. + sr, err := db.StreamClient.Stream(ctx) + if err != nil { + return fmt.Errorf("stream connect: %w", err) + } + defer sr.Close() + + // TODO: Determine page size of upstream database before creating local. + const pageSize = 4096 + + // Initialize the database and create it if it doesn't exist. 
+ if err := db.initReplica(pageSize); err != nil { + return fmt.Errorf("init replica: %w", err) + } + + for { + hdr, err := sr.Next() + if err != nil { + return err + } + + switch hdr.Type { + case StreamRecordTypeSnapshot: + if err := db.streamSnapshot(ctx, hdr, sr); err != nil { + return fmt.Errorf("snapshot: %w", err) + } + case StreamRecordTypeWALSegment: + if err := db.streamWALSegment(ctx, hdr, sr); err != nil { + return fmt.Errorf("wal segment: %w", err) + } + default: + return fmt.Errorf("invalid stream record type: 0x%02x", hdr.Type) + } + } +} + +// streamSnapshot reads the snapshot into the WAL and applies it to the main database. +func (db *DB) streamSnapshot(ctx context.Context, hdr *StreamRecordHeader, r io.Reader) error { + // Truncate WAL file. + if _, err := db.db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil { + return fmt.Errorf("truncate: %w", err) + } + + // Determine total page count. + pageN := int(hdr.Size / int64(db.pageSize)) + + ww := NewWALWriter(db.WALPath(), db.fileMode, db.pageSize) + if err := ww.Open(); err != nil { + return fmt.Errorf("open wal writer: %w", err) + } + defer func() { _ = ww.Close() }() + + if err := ww.WriteHeader(); err != nil { + return fmt.Errorf("write wal header: %w", err) + } + + // Iterate over pages + buf := make([]byte, db.pageSize) + for pgno := uint32(1); ; pgno++ { + // Read snapshot page into a buffer. + if _, err := io.ReadFull(r, buf); err == io.EOF { + break + } else if err != nil { + return fmt.Errorf("read snapshot page %d: %w", pgno, err) + } + + // Issue a commit flag when the last page is reached. + var commit uint32 + if pgno == uint32(pageN) { + commit = uint32(pageN) + } + + // Write page into WAL frame. + if err := ww.WriteFrame(pgno, commit, buf); err != nil { + return fmt.Errorf("write wal frame: %w", err) + } + } + + // Close WAL file writer. + if err := ww.Close(); err != nil { + return fmt.Errorf("close wal writer: %w", err) + } + + // Invalidate WAL index. 
+ if err := invalidateSHMFile(db.path); err != nil { + return fmt.Errorf("invalidate shm file: %w", err) + } + + // Write position to file so other processes can read it. + if err := db.writePositionFile(hdr.Pos()); err != nil { + return fmt.Errorf("write position file: %w", err) + } + + db.Logger.Printf("snapshot applied") + + return nil +} + +// streamWALSegment rewrites a WAL segment into the local WAL and applies it to the main database. +func (db *DB) streamWALSegment(ctx context.Context, hdr *StreamRecordHeader, r io.Reader) error { + // Decompress incoming segment + zr := lz4.NewReader(r) + + // Drop WAL header if starting from offset zero. + if hdr.Offset == 0 { + if _, err := io.CopyN(io.Discard, zr, WALHeaderSize); err != nil { + return fmt.Errorf("read wal header: %w", err) + } + } + + ww := NewWALWriter(db.WALPath(), db.fileMode, db.pageSize) + if err := ww.Open(); err != nil { + return fmt.Errorf("open wal writer: %w", err) + } + defer func() { _ = ww.Close() }() + + if err := ww.WriteHeader(); err != nil { + return fmt.Errorf("write wal header: %w", err) + } + + // Iterate over incoming WAL pages. + buf := make([]byte, WALFrameHeaderSize+db.pageSize) + for i := 0; ; i++ { + // Read snapshot page into a buffer. + if _, err := io.ReadFull(zr, buf); err == io.EOF { + break + } else if err != nil { + return fmt.Errorf("read wal frame %d: %w", i, err) + } + + // Read page number & commit field. + pgno := binary.BigEndian.Uint32(buf[0:]) + commit := binary.BigEndian.Uint32(buf[4:]) + + // Write page into WAL frame. + if err := ww.WriteFrame(pgno, commit, buf[WALFrameHeaderSize:]); err != nil { + return fmt.Errorf("write wal frame: %w", err) + } + } + + // Close WAL file writer. + if err := ww.Close(); err != nil { + return fmt.Errorf("close wal writer: %w", err) + } + + // Invalidate WAL index. 
// invalidateSHMFile clears the "isInit" bytes in the header of the database's
// -shm file in order that the next transaction will rebuild it.
//
// It opens its own connection to dbPath, runs a passive checkpoint, rewrites
// the first WALIndexHeaderSize bytes of dbPath+"-shm" with bytes 12 and 60
// zeroed, and then runs a truncating checkpoint on the same connection.
// NOTE(review): offsets 12 and 60 are presumably the isInit field of the two
// copies of the WAL index header; confirm against the SQLite WAL-index format.
func invalidateSHMFile(dbPath string) error {
	db, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		return fmt.Errorf("reopen db: %w", err)
	}
	defer func() { _ = db.Close() }()

	// Checkpoint first; any failure aborts before the -shm file is touched.
	if _, err := db.Exec(`PRAGMA wal_checkpoint(PASSIVE)`); err != nil {
		return fmt.Errorf("passive checkpoint: %w", err)
	}

	// The -shm file must already exist; it is opened without O_CREATE.
	f, err := os.OpenFile(dbPath+"-shm", os.O_RDWR, 0666)
	if err != nil {
		return fmt.Errorf("open shm index: %w", err)
	}
	// Safety net for early returns; the success path closes f explicitly below.
	defer f.Close()

	buf := make([]byte, WALIndexHeaderSize)
	if _, err := io.ReadFull(f, buf); err != nil {
		return fmt.Errorf("read shm index: %w", err)
	}

	// Invalidate "isInit" fields.
	buf[12], buf[60] = 0, 0

	// Rewrite header.
	if _, err := f.Seek(0, io.SeekStart); err != nil {
		return fmt.Errorf("seek shm index: %w", err)
	} else if _, err := f.Write(buf); err != nil {
		return fmt.Errorf("overwrite shm index: %w", err)
	} else if err := f.Close(); err != nil {
		return fmt.Errorf("close shm index: %w", err)
	}

	// Truncate WAL file again.
	// The three scanned result columns are read but not inspected.
	var row [3]int
	if err := db.QueryRow(`PRAGMA wal_checkpoint(TRUNCATE)`).Scan(&row[0], &row[1], &row[2]); err != nil {
		return fmt.Errorf("truncate: %w", err)
	}

	return nil
}
+func (db *DB) WithFile(fn func(f *os.File) error) error { + f, err := os.Open(db.path) + if err != nil { + return err + } + defer f.Close() + + return fn(f) +} diff --git a/http/http.go b/http/http.go new file mode 100644 index 0000000..fc35f0f --- /dev/null +++ b/http/http.go @@ -0,0 +1,373 @@ +package http + +import ( + "context" + "fmt" + "io" + "log" + "net" + "net/http" + httppprof "net/http/pprof" + "net/url" + "os" + "strings" + + "github.com/benbjohnson/litestream" + "github.com/prometheus/client_golang/prometheus/promhttp" + "golang.org/x/sync/errgroup" +) + +// Server represents an HTTP API server for Litestream. +type Server struct { + ln net.Listener + closed bool + + httpServer *http.Server + promHandler http.Handler + + addr string + server *litestream.Server + + g errgroup.Group + + Logger *log.Logger +} + +func NewServer(server *litestream.Server, addr string) *Server { + s := &Server{ + addr: addr, + server: server, + Logger: log.New(os.Stderr, "http: ", litestream.LogFlags), + } + + s.promHandler = promhttp.Handler() + s.httpServer = &http.Server{ + Handler: http.HandlerFunc(s.serveHTTP), + } + return s +} + +func (s *Server) Open() (err error) { + if s.ln, err = net.Listen("tcp", s.addr); err != nil { + return err + } + + s.g.Go(func() error { + if err := s.httpServer.Serve(s.ln); err != nil && !s.closed { + return err + } + return nil + }) + + return nil +} + +func (s *Server) Close() (err error) { + s.closed = true + + if s.ln != nil { + if e := s.ln.Close(); e != nil && err == nil { + err = e + } + } + + if e := s.g.Wait(); e != nil && err == nil { + err = e + } + return err +} + +// Port returns the port the listener is running on. +func (s *Server) Port() int { + if s.ln == nil { + return 0 + } + return s.ln.Addr().(*net.TCPAddr).Port +} + +// URL returns the full base URL for the running server. 
+func (s *Server) URL() string { + host, _, _ := net.SplitHostPort(s.addr) + if host == "" { + host = "localhost" + } + return fmt.Sprintf("http://%s", net.JoinHostPort(host, fmt.Sprint(s.Port()))) +} + +func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/debug/pprof") { + switch r.URL.Path { + case "/debug/pprof/cmdline": + httppprof.Cmdline(w, r) + case "/debug/pprof/profile": + httppprof.Profile(w, r) + case "/debug/pprof/symbol": + httppprof.Symbol(w, r) + case "/debug/pprof/trace": + httppprof.Trace(w, r) + default: + httppprof.Index(w, r) + } + return + } + + switch r.URL.Path { + case "/metrics": + s.promHandler.ServeHTTP(w, r) + + case "/stream": + switch r.Method { + case http.MethodGet: + s.handleGetStream(w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + default: + http.NotFound(w, r) + } +} + +func (s *Server) handleGetStream(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + + // TODO: Listen for all databases matching query criteria. + path := q.Get("path") + if path == "" { + http.Error(w, "Database name required", http.StatusBadRequest) + return + } + db := s.server.DB(path) + if db == nil { + http.Error(w, "Database not found", http.StatusNotFound) + return + } + + // TODO: Restart stream from a previous position, if specified. + + // Determine starting position. + pos := db.Pos() + if pos.Generation == "" { + http.Error(w, "No generation available", http.StatusServiceUnavailable) + return + } + pos.Offset = 0 + + s.Logger.Printf("stream connected @ %s", pos) + defer s.Logger.Printf("stream disconnected") + + // Obtain iterator before snapshot so we don't miss any WAL segments. + itr, err := db.WALSegments(r.Context(), pos.Generation) + if err != nil { + http.Error(w, fmt.Sprintf("Cannot obtain WAL iterator: %s", err), http.StatusInternalServerError) + return + } + defer itr.Close() + + // Write snapshot to response body. 
+ if err := db.WithFile(func(f *os.File) error { + fi, err := f.Stat() + if err != nil { + return err + } + + // Write snapshot header with current position & size. + hdr := litestream.StreamRecordHeader{ + Type: litestream.StreamRecordTypeSnapshot, + Generation: pos.Generation, + Index: pos.Index, + Size: fi.Size(), + } + if buf, err := hdr.MarshalBinary(); err != nil { + return fmt.Errorf("marshal snapshot stream record header: %w", err) + } else if _, err := w.Write(buf); err != nil { + return fmt.Errorf("write snapshot stream record header: %w", err) + } + + if _, err := io.CopyN(w, f, fi.Size()); err != nil { + return fmt.Errorf("copy snapshot: %w", err) + } + + return nil + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Flush after snapshot has been written. + w.(http.Flusher).Flush() + + for { + // Wait for notification of new entries. + select { + case <-r.Context().Done(): + return + case <-itr.NotifyCh(): + } + + for itr.Next() { + info := itr.WALSegment() + + // Skip any segments before our initial position. + if cmp, err := litestream.ComparePos(info.Pos(), pos); err != nil { + s.Logger.Printf("pos compare: %s", err) + return + } else if cmp == -1 { + continue + } + + hdr := litestream.StreamRecordHeader{ + Type: litestream.StreamRecordTypeWALSegment, + Flags: 0, + Generation: info.Generation, + Index: info.Index, + Offset: info.Offset, + Size: info.Size, + } + + // Write record header. + data, err := hdr.MarshalBinary() + if err != nil { + s.Logger.Printf("marshal WAL segment stream record header: %s", err) + return + } else if _, err := w.Write(data); err != nil { + s.Logger.Printf("write WAL segment stream record header: %s", err) + return + } + + // Copy WAL segment data to writer. 
+ if err := func() error { + rd, err := db.WALSegmentReader(r.Context(), info.Pos()) + if err != nil { + return fmt.Errorf("cannot fetch wal segment reader: %w", err) + } + defer rd.Close() + + if _, err := io.CopyN(w, rd, hdr.Size); err != nil { + return fmt.Errorf("cannot copy wal segment: %w", err) + } + return nil + }(); err != nil { + log.Print(err) + return + } + + // Flush after WAL segment has been written. + w.(http.Flusher).Flush() + } + if itr.Err() != nil { + s.Logger.Printf("wal iterator error: %s", err) + return + } + } +} + +type Client struct { + // Upstream endpoint + URL string + + // Path of database on upstream server. + Path string + + // Underlying HTTP client + HTTPClient *http.Client +} + +func NewClient(rawurl, path string) *Client { + return &Client{ + URL: rawurl, + Path: path, + HTTPClient: http.DefaultClient, + } +} + +func (c *Client) Stream(ctx context.Context) (litestream.StreamReader, error) { + u, err := url.Parse(c.URL) + if err != nil { + return nil, fmt.Errorf("invalid client URL: %w", err) + } else if u.Scheme != "http" && u.Scheme != "https" { + return nil, fmt.Errorf("invalid URL scheme") + } else if u.Host == "" { + return nil, fmt.Errorf("URL host required") + } + + // Strip off everything but the scheme & host. 
+ *u = url.URL{ + Scheme: u.Scheme, + Host: u.Host, + Path: "/stream", + RawQuery: (url.Values{ + "path": []string{c.Path}, + }).Encode(), + } + + req, err := http.NewRequest("GET", u.String(), nil) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, err + } else if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, fmt.Errorf("invalid response: code=%d", resp.StatusCode) + } + + return &StreamReader{ + body: resp.Body, + file: io.LimitedReader{R: resp.Body}, + }, nil +} + +type StreamReader struct { + body io.ReadCloser + file io.LimitedReader + err error +} + +func (r *StreamReader) Close() error { + if e := r.body.Close(); e != nil && r.err == nil { + r.err = e + } + return r.err +} + +func (r *StreamReader) Read(p []byte) (int, error) { + if r.err != nil { + return 0, r.err + } else if r.file.R == nil { + return 0, io.EOF + } + return r.file.Read(p) +} + +func (r *StreamReader) Next() (*litestream.StreamRecordHeader, error) { + if r.err != nil { + return nil, r.err + } + + // If bytes remain on the current file, discard. + if r.file.N > 0 { + if _, r.err = io.Copy(io.Discard, &r.file); r.err != nil { + return nil, r.err + } + } + + // Read record header. + buf := make([]byte, litestream.StreamRecordHeaderSize) + if _, err := io.ReadFull(r.body, buf); err != nil { + r.err = fmt.Errorf("http.StreamReader.Next(): %w", err) + return nil, r.err + } + + var hdr litestream.StreamRecordHeader + if r.err = hdr.UnmarshalBinary(buf); r.err != nil { + return nil, r.err + } + + // Update remaining bytes on file reader. 
+ r.file.N = hdr.Size + + return &hdr, nil +} diff --git a/integration/cmd_test.go b/integration/cmd_test.go index a70ad53..3cab314 100644 --- a/integration/cmd_test.go +++ b/integration/cmd_test.go @@ -391,6 +391,69 @@ LOOP: restoreAndVerify(t, ctx, env, filepath.Join(testDir, "litestream.yml"), filepath.Join(tempDir, "db")) } +// Ensure a database can be replicated over HTTP. +func TestCmd_Replicate_HTTP(t *testing.T) { + ctx := context.Background() + testDir, tempDir := filepath.Join("testdata", "replicate", "http"), t.TempDir() + if err := os.Mkdir(filepath.Join(tempDir, "0"), 0777); err != nil { + t.Fatal(err) + } else if err := os.Mkdir(filepath.Join(tempDir, "1"), 0777); err != nil { + t.Fatal(err) + } + + env0 := []string{"LITESTREAM_TEMPDIR=" + tempDir} + env1 := []string{"LITESTREAM_TEMPDIR=" + tempDir, "LITESTREAM_UPSTREAM_URL=http://localhost:10001"} + + cmd0, stdout0, _ := commandContext(ctx, env0, "replicate", "-config", filepath.Join(testDir, "litestream.0.yml")) + if err := cmd0.Start(); err != nil { + t.Fatal(err) + } + cmd1, stdout1, _ := commandContext(ctx, env1, "replicate", "-config", filepath.Join(testDir, "litestream.1.yml")) + if err := cmd1.Start(); err != nil { + t.Fatal(err) + } + + db0, err := sql.Open("sqlite3", filepath.Join(tempDir, "0", "db")) + if err != nil { + t.Fatal(err) + } else if _, err := db0.ExecContext(ctx, `PRAGMA journal_mode = wal`); err != nil { + t.Fatal(err) + } else if _, err := db0.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil { + t.Fatal(err) + } + defer db0.Close() + + // Execute writes periodically. + for i := 0; i < 100; i++ { + t.Logf("[exec] INSERT INTO t (id) VALUES (%d)", i) + if _, err := db0.ExecContext(ctx, `INSERT INTO t (id) VALUES (?)`, i); err != nil { + t.Fatal(err) + } + time.Sleep(100 * time.Millisecond) + } + + // Wait for replica to catch up. + time.Sleep(1 * time.Second) + + // Verify count in replica table. 
+ db1, err := sql.Open("sqlite3", filepath.Join(tempDir, "1", "db")) + if err != nil { + t.Fatal(err) + } + defer db1.Close() + + var n int + if err := db1.QueryRowContext(ctx, `SELECT COUNT(*) FROM t`).Scan(&n); err != nil { + t.Fatal(err) + } else if got, want := n, 100; got != want { + t.Fatalf("replica count=%d, want %d", got, want) + } + + // Stop & wait for Litestream command. + killLitestreamCmd(t, cmd1, stdout1) // kill + killLitestreamCmd(t, cmd0, stdout0) +} + // commandContext returns a "litestream" command with stdout/stderr buffers. func commandContext(ctx context.Context, env []string, arg ...string) (cmd *exec.Cmd, stdout, stderr *internal.LockingBuffer) { cmd = exec.CommandContext(ctx, "litestream", arg...) @@ -428,6 +491,7 @@ func waitForLogMessage(tb testing.TB, b *internal.LockingBuffer, msg string) { // killLitestreamCmd interrupts the process and waits for a clean shutdown. func killLitestreamCmd(tb testing.TB, cmd *exec.Cmd, stdout *internal.LockingBuffer) { + tb.Helper() if err := cmd.Process.Signal(os.Interrupt); err != nil { tb.Fatal("kill litestream: signal:", err) } else if err := cmd.Wait(); err != nil { diff --git a/integration/testdata/replicate/http/litestream.0.yml b/integration/testdata/replicate/http/litestream.0.yml new file mode 100644 index 0000000..e30e651 --- /dev/null +++ b/integration/testdata/replicate/http/litestream.0.yml @@ -0,0 +1,5 @@ +addr: :10001 + +dbs: + - path: $LITESTREAM_TEMPDIR/0/db + max-checkpoint-page-count: 20 diff --git a/integration/testdata/replicate/http/litestream.1.yml b/integration/testdata/replicate/http/litestream.1.yml new file mode 100644 index 0000000..e973570 --- /dev/null +++ b/integration/testdata/replicate/http/litestream.1.yml @@ -0,0 +1,5 @@ +dbs: + - path: $LITESTREAM_TEMPDIR/1/db + upstream: + url: "$LITESTREAM_UPSTREAM_URL" + path: "$LITESTREAM_TEMPDIR/0/db" diff --git a/internal/internal.go b/internal/internal.go index 0c70d4d..c671379 100644 --- a/internal/internal.go +++ 
// WriteFile writes data to the file called name, creating it with perm if it
// does not exist and truncating it if it does. Ownership is changed to
// uid/gid on a best-effort basis. A write error takes precedence over a close
// error.
func WriteFile(name string, data []byte, perm os.FileMode, uid, gid int) error {
	const flags = os.O_WRONLY | os.O_CREATE | os.O_TRUNC

	f, err := os.OpenFile(name, flags, perm)
	if err != nil {
		return err
	}

	// Failure to change ownership is deliberately ignored.
	_ = f.Chown(uid, gid)

	_, writeErr := f.Write(data)
	closeErr := f.Close()
	if writeErr != nil {
		return writeErr
	}
	return closeErr
}
@@ -462,6 +466,73 @@ func ParseOffset(s string) (int64, error) { return int64(v), nil } +const ( + StreamRecordTypeSnapshot = 1 + StreamRecordTypeWALSegment = 2 +) + +const StreamRecordHeaderSize = 0 + + 4 + 4 + // type, flags + 8 + 8 + 8 + 8 // generation, index, offset, size + +type StreamRecordHeader struct { + Type int + Flags int + Generation string + Index int + Offset int64 + Size int64 +} + +func (hdr *StreamRecordHeader) Pos() Pos { + return Pos{ + Generation: hdr.Generation, + Index: hdr.Index, + Offset: hdr.Offset, + } +} + +func (hdr *StreamRecordHeader) MarshalBinary() ([]byte, error) { + generation, err := strconv.ParseUint(hdr.Generation, 16, 64) + if err != nil { + return nil, fmt.Errorf("invalid generation: %q", generation) + } + + data := make([]byte, StreamRecordHeaderSize) + binary.BigEndian.PutUint32(data[0:4], uint32(hdr.Type)) + binary.BigEndian.PutUint32(data[4:8], uint32(hdr.Flags)) + binary.BigEndian.PutUint64(data[8:16], generation) + binary.BigEndian.PutUint64(data[16:24], uint64(hdr.Index)) + binary.BigEndian.PutUint64(data[24:32], uint64(hdr.Offset)) + binary.BigEndian.PutUint64(data[32:40], uint64(hdr.Size)) + return data, nil +} + +// UnmarshalBinary from data into hdr. +func (hdr *StreamRecordHeader) UnmarshalBinary(data []byte) error { + if len(data) < StreamRecordHeaderSize { + return io.ErrUnexpectedEOF + } + hdr.Type = int(binary.BigEndian.Uint32(data[0:4])) + hdr.Flags = int(binary.BigEndian.Uint32(data[4:8])) + hdr.Generation = fmt.Sprintf("%16x", binary.BigEndian.Uint64(data[8:16])) + hdr.Index = int(binary.BigEndian.Uint64(data[16:24])) + hdr.Offset = int64(binary.BigEndian.Uint64(data[24:32])) + hdr.Size = int64(binary.BigEndian.Uint64(data[32:40])) + return nil +} + +// StreamClient represents a client for streaming changes to a replica DB. +type StreamClient interface { + Stream(ctx context.Context) (StreamReader, error) +} + +// StreamReader represents a reader that streams snapshot and WAL records. 
+type StreamReader interface { + io.ReadCloser + Next() (*StreamRecordHeader, error) +} + // removeDBFiles deletes the database and related files (journal, shm, wal). func removeDBFiles(filename string) error { if err := os.Remove(filename); err != nil && !os.IsNotExist(err) { diff --git a/testdata/wal-writer/live/README.md b/testdata/wal-writer/live/README.md new file mode 100644 index 0000000..2387431 --- /dev/null +++ b/testdata/wal-writer/live/README.md @@ -0,0 +1,19 @@ +WAL Writer Live +================= + +This test is to ensure we can copy a WAL file into place with a live DB and +trigger a checkpoint into the main DB file. + +To reproduce the data files: + +```sh +$ sqlite3 db + +sqlite> PRAGMA journal_mode = 'wal'; +sqlite> CREATE TABLE t (x); +sqlite> PRAGMA wal_checkpoint(TRUNCATE); +sqlite> INSERT INTO t (x) VALUES (1); + +sqlite> CTRL-\ +``` + diff --git a/testdata/wal-writer/live/db b/testdata/wal-writer/live/db new file mode 100644 index 0000000..6a63447 Binary files /dev/null and b/testdata/wal-writer/live/db differ diff --git a/testdata/wal-writer/live/db-shm b/testdata/wal-writer/live/db-shm new file mode 100644 index 0000000..1d5fdd8 Binary files /dev/null and b/testdata/wal-writer/live/db-shm differ diff --git a/testdata/wal-writer/live/db-wal b/testdata/wal-writer/live/db-wal new file mode 100644 index 0000000..43300fc Binary files /dev/null and b/testdata/wal-writer/live/db-wal differ diff --git a/testdata/wal-writer/static/README.md b/testdata/wal-writer/static/README.md new file mode 100644 index 0000000..99ffadb --- /dev/null +++ b/testdata/wal-writer/static/README.md @@ -0,0 +1,26 @@ +WAL Writer Static +================= + +This test is to ensure that WALWriter will generate the same WAL file as +the `sqlite3` command line. 
+ +To reproduce the data file: + +```sh +$ sqlite3 db + +sqlite> PRAGMA journal_mode = 'wal'; + +sqlite> CREATE TABLE t (x); + +sqlite> INSERT INTO t (x) VALUES (1); + +sqlite> CTRL-\ +``` + +then remove the db & shm files: + +```sh +$ rm db db-shm +``` + diff --git a/testdata/wal-writer/static/db-wal b/testdata/wal-writer/static/db-wal new file mode 100644 index 0000000..5cac19e Binary files /dev/null and b/testdata/wal-writer/static/db-wal differ diff --git a/wal_writer.go b/wal_writer.go new file mode 100644 index 0000000..0cd90c5 --- /dev/null +++ b/wal_writer.go @@ -0,0 +1,103 @@ +package litestream + +import ( + "encoding/binary" + "fmt" + "os" +) + +// WALWriter represents a writer to a SQLite WAL file. +type WALWriter struct { + path string + mode os.FileMode + pageSize int + + f *os.File // WAL file handle + buf []byte // frame buffer + + chksum0, chksum1 uint32 // ongoing checksum + + Salt0, Salt1 uint32 +} + +// NewWALWriter returns a new instance of WALWriter. +func NewWALWriter(path string, mode os.FileMode, pageSize int) *WALWriter { + return &WALWriter{ + path: path, + mode: mode, + pageSize: pageSize, + + buf: make([]byte, WALFrameHeaderSize+pageSize), + } +} + +// Open opens the file handle to the WAL file. +func (w *WALWriter) Open() (err error) { + w.f, err = os.OpenFile(w.path, os.O_WRONLY|os.O_TRUNC, w.mode) + return err +} + +// Close closes the file handle to the WAL file. +func (w *WALWriter) Close() error { + if w.f == nil { + return nil + } + return w.f.Close() +} + +// WriteHeader writes the WAL header to the beginning of the file. +func (w *WALWriter) WriteHeader() error { + // Build WAL header byte slice. Page size and checksum set afterward. 
+ hdr := []byte{ + 0x37, 0x7f, 0x06, 0x82, // magic (little-endian) + 0x00, 0x2d, 0xe2, 0x18, // file format version (3007000) + 0x00, 0x00, 0x00, 0x00, // page size + 0x00, 0x00, 0x00, 0x00, // checkpoint sequence number + 0x00, 0x00, 0x00, 0x00, // salt + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, // checksum + 0x00, 0x00, 0x00, 0x00, + } + + // Set page size on header + binary.BigEndian.PutUint32(hdr[8:], uint32(w.pageSize)) + + // Set salt + binary.BigEndian.PutUint32(hdr[16:], w.Salt0) + binary.BigEndian.PutUint32(hdr[20:], w.Salt1) + + // Compute header checksum. + w.chksum0, w.chksum1 = Checksum(binary.LittleEndian, w.chksum0, w.chksum1, hdr[:24]) + binary.BigEndian.PutUint32(hdr[24:], w.chksum0) + binary.BigEndian.PutUint32(hdr[28:], w.chksum1) + + // Write header to WAL. + _, err := w.f.Write(hdr) + return err +} + +func (w *WALWriter) WriteFrame(pgno, commit uint32, data []byte) error { + // Ensure data matches page size. + if len(data) != w.pageSize { + return fmt.Errorf("data size %d must match page size %d", len(data), w.pageSize) + } + + // Write frame header. + binary.BigEndian.PutUint32(w.buf[0:], pgno) // page number + binary.BigEndian.PutUint32(w.buf[4:], commit) // commit record (page count) + binary.BigEndian.PutUint32(w.buf[8:], w.Salt0) // salt + binary.BigEndian.PutUint32(w.buf[12:], w.Salt1) + + // Copy data to frame. + copy(w.buf[WALFrameHeaderSize:], data) + + // Compute checksum for frame. 
+ w.chksum0, w.chksum1 = Checksum(binary.LittleEndian, w.chksum0, w.chksum1, w.buf[:8]) + w.chksum0, w.chksum1 = Checksum(binary.LittleEndian, w.chksum0, w.chksum1, w.buf[24:]) + binary.BigEndian.PutUint32(w.buf[16:], w.chksum0) + binary.BigEndian.PutUint32(w.buf[20:], w.chksum1) + + // Write to local WAL + _, err := w.f.Write(w.buf) + return err +} diff --git a/wal_writer_test.go b/wal_writer_test.go new file mode 100644 index 0000000..8a38085 --- /dev/null +++ b/wal_writer_test.go @@ -0,0 +1,116 @@ +package litestream_test + +import ( + "bytes" + "database/sql" + "encoding/binary" + "io" + "os" + "path/filepath" + "testing" + + "github.com/benbjohnson/litestream" + "github.com/benbjohnson/litestream/internal/testingutil" + _ "github.com/mattn/go-sqlite3" +) + +func TestWALWriter_Static(t *testing.T) { + testDir := filepath.Join("testdata", "wal-writer", "static") + tempDir := t.TempDir() + + // Read in WAL file generated by sqlite3 + buf, err := os.ReadFile(filepath.Join(testDir, "db-wal")) + if err != nil { + t.Fatal(err) + } + + // Create new WAL file. + if err := os.WriteFile(filepath.Join(tempDir, "db-wal"), nil, 0666); err != nil { + t.Fatal(err) + } + + w := litestream.NewWALWriter(filepath.Join(tempDir, "db-wal"), 0666, 4096) + w.Salt0 = binary.BigEndian.Uint32(buf[16:]) + w.Salt1 = binary.BigEndian.Uint32(buf[20:]) + + if err := w.Open(); err != nil { + t.Fatal(err) + } else if err := w.WriteHeader(); err != nil { + t.Fatal(err) + } + + for b := buf[litestream.WALHeaderSize:]; len(b) > 0; b = b[litestream.WALFrameHeaderSize+4096:] { + pgno := binary.BigEndian.Uint32(b[0:]) + commit := binary.BigEndian.Uint32(b[4:]) + if err := w.WriteFrame(pgno, commit, b[litestream.WALFrameHeaderSize:][:4096]); err != nil { + t.Fatal(err) + } + } + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + // Read generated WAL and compare with original. 
+ if buf2, err := os.ReadFile(filepath.Join(tempDir, "db-wal")); err != nil { + t.Fatal(err) + } else if !bytes.Equal(buf, buf2) { + t.Fatal("wal file mismatch") + } +} + +func TestWALWriter_Live(t *testing.T) { + testDir := filepath.Join("testdata", "wal-writer", "live") + tempDir := t.TempDir() + + // Copy DB file into temporary dir. + testingutil.CopyFile(t, filepath.Join(testDir, "db"), filepath.Join(tempDir, "db")) + + // Open database. + db, err := sql.Open("sqlite3", filepath.Join(tempDir, "db")) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + // Verify that table is empty. + var n int + if err := db.QueryRow(`SELECT COUNT(*) FROM t`).Scan(&n); err != nil { + t.Fatal(err) + } else if got, want := n, 0; got != want { + t.Fatalf("init: n=%d, want %d", got, want) + } + + // Copy WAL file into place. + testingutil.CopyFile(t, filepath.Join(testDir, "db-wal"), filepath.Join(tempDir, "db-wal")) + + // Invalidate both copies of the WAL index headers. + f, err := os.OpenFile(filepath.Join(tempDir, "db-shm"), os.O_RDWR, 0666) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + // Read index header. + idx := make([]byte, 136) + if _, err := io.ReadFull(f, idx); err != nil { + t.Fatal(err) + } + + // Invalidate "isInit" flags + idx[12], idx[48+12] = 0, 0 + + // Write header back into index. + if _, err := f.Seek(0, io.SeekStart); err != nil { + t.Fatal(err) + } else if _, err := f.Write(idx); err != nil { + t.Fatal(err) + } + + // Verify that table now has one row. + if err := db.QueryRow(`SELECT COUNT(*) FROM t`).Scan(&n); err != nil { + t.Fatal(err) + } else if got, want := n, 1; got != want { + t.Fatalf("post-wal: n=%d, want %d", got, want) + } +}