diff --git a/cmd/litestream/main.go b/cmd/litestream/main.go index 3cf5a0e..73b9fd8 100644 --- a/cmd/litestream/main.go +++ b/cmd/litestream/main.go @@ -23,7 +23,6 @@ import ( "github.com/benbjohnson/litestream" "github.com/benbjohnson/litestream/abs" "github.com/benbjohnson/litestream/gs" - "github.com/benbjohnson/litestream/http" "github.com/benbjohnson/litestream/s3" "github.com/benbjohnson/litestream/sftp" _ "github.com/mattn/go-sqlite3" @@ -284,7 +283,6 @@ func readConfigFile(filename string, expandEnv bool) (_ Config, err error) { // DBConfig represents the configuration for a single database. type DBConfig struct { Path string `yaml:"path"` - Upstream UpstreamConfig `yaml:"upstream"` MonitorDelayInterval *time.Duration `yaml:"monitor-delay-interval"` CheckpointInterval *time.Duration `yaml:"checkpoint-interval"` MinCheckpointPageN *int `yaml:"min-checkpoint-page-count"` @@ -308,16 +306,6 @@ func NewDBFromConfigWithPath(dbc *DBConfig, path string) (*litestream.DB, error) // Initialize database with given path. db := litestream.NewDB(path) - // Attach upstream HTTP client if specified. - if upstreamURL := dbc.Upstream.URL; upstreamURL != "" { - // Use local database path if upstream path is not specified. - upstreamPath := dbc.Upstream.Path - if upstreamPath == "" { - upstreamPath = db.Path() - } - db.StreamClient = http.NewClient(upstreamURL, upstreamPath) - } - // Override default database settings if specified in configuration. if dbc.MonitorDelayInterval != nil { db.MonitorDelayInterval = *dbc.MonitorDelayInterval @@ -347,11 +335,6 @@ func NewDBFromConfigWithPath(dbc *DBConfig, path string) (*litestream.DB, error) return db, nil } -type UpstreamConfig struct { - URL string `yaml:"url"` - Path string `yaml:"path"` -} - // ReplicaConfig represents the configuration for a single replica in a database. type ReplicaConfig struct { Type string `yaml:"type"` // "file", "s3" diff --git a/db.go b/db.go index 160cd47..baeeada 100644 --- a/db.go +++ b/db.go @@ -54,9 +54,6 @@ type DB struct { pageSize int // page size, in bytes notifyCh chan struct{} // notifies DB of changes - // Iterators used to stream new WAL changes to replicas - itrs map[*FileWALSegmentIterator]struct{} - // Cached salt & checksum from current shadow header. hdr []byte frame []byte @@ -85,11 +82,6 @@ type DB struct { checkpointErrorNCounterVec *prometheus.CounterVec checkpointSecondsCounterVec *prometheus.CounterVec - // Client used to receive live, upstream changes. If specified, then - // DB should be used as read-only as local changes will conflict with - // upstream changes. - StreamClient StreamClient - // Minimum threshold of WAL size, in pages, before a passive checkpoint. // A passive checkpoint will attempt a checkpoint but fail if there are // active transactions occurring at the same time. @@ -130,8 +122,6 @@ func NewDB(path string) *DB { path: path, notifyCh: make(chan struct{}, 1), - itrs: make(map[*FileWALSegmentIterator]struct{}), - MinCheckpointPageN: DefaultMinCheckpointPageN, MaxCheckpointPageN: DefaultMaxCheckpointPageN, ShadowRetentionN: DefaultShadowRetentionN, @@ -275,7 +265,7 @@ func (db *DB) invalidatePos(ctx context.Context) error { } // Iterate over all segments to find the last one. - itr, err := db.walSegments(context.Background(), generation, false) + itr, err := db.walSegments(context.Background(), generation) if err != nil { return err } @@ -422,9 +412,7 @@ func (db *DB) Open() (err error) { return fmt.Errorf("cannot remove tmp files: %w", err) } - // If an upstream client is specified, then we should simply stream changes - // into the database. If it is not specified, then we should monitor the - // database for local changes and replicate them out. + // Continually monitor local changes in a separate goroutine. db.g.Go(func() error { return db.monitor(db.ctx) }) return nil @@ -466,14 +454,6 @@ func (db *DB) Close() (err error) { } } - // Remove all iterators. - db.mu.Lock() - for itr := range db.itrs { - itr.SetErr(ErrDBClosed) - delete(db.itrs, itr) - } - db.mu.Unlock() - // Release the read lock to allow other applications to handle checkpointing. if db.rtx != nil { if e := db.releaseReadLock(); e != nil && err == nil { @@ -645,74 +625,6 @@ func (db *DB) init() (err error) { return nil } -// initReplica initializes a new database file as a replica of an upstream database. -func (db *DB) initReplica(pageSize int) (err error) { - // Exit if already initialized. - if db.db != nil { - return nil - } - - // Obtain permissions for parent directory. - fi, err := os.Stat(filepath.Dir(db.path)) - if err != nil { - return err - } - db.dirMode = fi.Mode() - - dsn := db.path - dsn += fmt.Sprintf("?_busy_timeout=%d", BusyTimeout.Milliseconds()) - - // Connect to SQLite database. Use the driver registered with a hook to - // prevent WAL files from being removed. - if db.db, err = sql.Open("litestream-sqlite3", dsn); err != nil { - return err - } - - // Initialize database file if it doesn't exist. It doesn't matter what we - // store in it as it will be erased by the replication. We just need to - // ensure a WAL file is created and there is at least a page in the database. - if _, err := os.Stat(db.path); os.IsNotExist(err) { - if _, err := db.db.ExecContext(db.ctx, fmt.Sprintf(`PRAGMA page_size = %d`, pageSize)); err != nil { - return fmt.Errorf("set page size: %w", err) - } - - var mode string - if err := db.db.QueryRow(`PRAGMA journal_mode = wal`).Scan(&mode); err != nil { - return err - } else if mode != "wal" { - return fmt.Errorf("enable wal failed, mode=%q", mode) - } - - if _, err := db.db.ExecContext(db.ctx, `CREATE TABLE IF NOT EXISTS _litestream (id INTEGER)`); err != nil { - return fmt.Errorf("create _litestream table: %w", err) - } else if _, err := db.db.ExecContext(db.ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil { - return fmt.Errorf("create _litestream table: %w", err) - } - } - - // Obtain file info once we know the database exists. - fi, err = os.Stat(db.path) - if err != nil { - return fmt.Errorf("init file stat: %w", err) - } - db.fileMode = fi.Mode() - db.uid, db.gid = internal.Fileinfo(fi) - - // Verify page size matches. - if err := db.db.QueryRowContext(db.ctx, `PRAGMA page_size;`).Scan(&db.pageSize); err != nil { - return fmt.Errorf("read page size: %w", err) - } else if db.pageSize != pageSize { - return fmt.Errorf("page size mismatch: %d <> %d", db.pageSize, pageSize) - } - - // Ensure meta directory structure exists. - if err := internal.MkdirAll(db.MetaPath(), db.dirMode, db.uid, db.gid); err != nil { - return err - } - - return nil -} - func (db *DB) clearGeneration(ctx context.Context) error { if err := os.Remove(db.GenerationNamePath()); err != nil && !os.IsNotExist(err) { return err @@ -927,11 +839,6 @@ func (db *DB) createGeneration(ctx context.Context) (string, error) { // Sync copies pending data from the WAL to the shadow WAL. func (db *DB) Sync(ctx context.Context) error { - if db.StreamClient != nil { - db.Logger.Printf("using upstream client, skipping sync") - return nil - } - const retryN = 5 for i := 0; i < retryN; i++ { @@ -1386,8 +1293,7 @@ func (db *DB) writeWALSegment(ctx context.Context, pos Pos, rd io.Reader) error } defer f.Close() - n, err := io.Copy(f, rd) - if err != nil { + if _, err := io.Copy(f, rd); err != nil { return err } else if err := f.Sync(); err != nil { return err @@ -1405,50 +1311,9 @@ func (db *DB) writeWALSegment(ctx context.Context, pos Pos, rd io.Reader) error return fmt.Errorf("write position file: %w", err) } - // Generate - info := WALSegmentInfo{ - Generation: pos.Generation, - Index: pos.Index, - Offset: pos.Offset, - Size: n, - CreatedAt: time.Now(), - } - - // Notify all managed segment iterators. - for itr := range db.itrs { - // Notify iterators of generation change. - if itr.Generation() != pos.Generation { - itr.SetErr(ErrGenerationChanged) - delete(db.itrs, itr) - continue - } - - // Attempt to append segment to end of iterator. - // On error, mark it on the iterator and remove from future notifications. - if err := itr.Append(info); err != nil { - itr.SetErr(fmt.Errorf("cannot append wal segment: %w", err)) - delete(db.itrs, itr) - continue - } - } - return nil } -// readPositionFile reads the position from the position file. -func (db *DB) readPositionFile() (Pos, error) { - buf, err := os.ReadFile(db.PositionPath()) - if os.IsNotExist(err) { - return Pos{}, nil - } else if err != nil { - return Pos{}, err - } - - // Treat invalid format as a non-existent file so we return an empty position. - pos, _ := ParsePos(strings.TrimSpace(string(buf))) - return pos, nil -} - // writePositionFile writes pos as the current position. func (db *DB) writePositionFile(pos Pos) error { return internal.WriteFile(db.PositionPath(), []byte(pos.String()+"\n"), db.fileMode, db.uid, db.gid) @@ -1458,10 +1323,10 @@ func (db *DB) writePositionFile(pos Pos) error { func (db *DB) WALSegments(ctx context.Context, generation string) (*FileWALSegmentIterator, error) { db.mu.Lock() defer db.mu.Unlock() - return db.walSegments(ctx, generation, true) + return db.walSegments(ctx, generation) } -func (db *DB) walSegments(ctx context.Context, generation string, managed bool) (*FileWALSegmentIterator, error) { +func (db *DB) walSegments(ctx context.Context, generation string) (*FileWALSegmentIterator, error) { ents, err := os.ReadDir(db.ShadowWALDir(generation)) if os.IsNotExist(err) { return NewFileWALSegmentIterator(db.ShadowWALDir(generation), generation, nil), nil @@ -1481,27 +1346,7 @@ func (db *DB) walSegments(ctx context.Context, generation string, managed bool) sort.Ints(indexes) - itr := NewFileWALSegmentIterator(db.ShadowWALDir(generation), generation, indexes) - - // Managed iterators will have new segments pushed to them. - if managed { - itr.closeFunc = func() error { - return db.CloseWALSegmentIterator(itr) - } - - db.itrs[itr] = struct{}{} - } - - return itr, nil -} - -// CloseWALSegmentIterator removes itr from the list of managed iterators. -func (db *DB) CloseWALSegmentIterator(itr *FileWALSegmentIterator) error { - db.mu.Lock() - defer db.mu.Unlock() - - delete(db.itrs, itr) - return nil + return NewFileWALSegmentIterator(db.ShadowWALDir(generation), generation, indexes), nil } // SQLite WAL constants @@ -1645,15 +1490,8 @@ func (db *DB) execCheckpoint(mode string) (err error) { return nil } -func (db *DB) monitor(ctx context.Context) error { - if db.StreamClient != nil { - return db.monitorUpstream(ctx) - } - return db.monitorLocal(ctx) -} - // monitor runs in a separate goroutine and monitors the local database & WAL. -func (db *DB) monitorLocal(ctx context.Context) error { +func (db *DB) monitor(ctx context.Context) error { var timer *time.Timer if db.MonitorDelayInterval > 0 { timer = time.NewTimer(db.MonitorDelayInterval) @@ -1686,189 +1524,6 @@ func (db *DB) monitorLocal(ctx context.Context) error { } } -// monitorUpstream runs in a separate goroutine and streams data into the local DB. -func (db *DB) monitorUpstream(ctx context.Context) error { - for { - if err := db.stream(ctx); err != nil { - if ctx.Err() != nil { - return nil - } - db.Logger.Printf("stream error, retrying: %s", err) - } - - // Delay before retrying stream. - select { - case <-ctx.Done(): - return nil - case <-time.After(1 * time.Second): - } - } -} - -// stream initializes the local database and continuously streams new upstream data. -func (db *DB) stream(ctx context.Context) error { - pos, err := db.readPositionFile() - if err != nil { - return fmt.Errorf("read position file: %w", err) - } - - // Continuously stream and apply records from client. - sr, err := db.StreamClient.Stream(ctx, pos) - if err != nil { - return fmt.Errorf("stream connect: %w", err) - } - defer sr.Close() - - // Initialize the database and create it if it doesn't exist. - if err := db.initReplica(sr.PageSize()); err != nil { - return fmt.Errorf("init replica: %w", err) - } - - for { - hdr, err := sr.Next() - if err != nil { - return err - } - - switch hdr.Type { - case StreamRecordTypeSnapshot: - if err := db.streamSnapshot(ctx, hdr, sr); err != nil { - return fmt.Errorf("snapshot: %w", err) - } - case StreamRecordTypeWALSegment: - if err := db.streamWALSegment(ctx, hdr, sr); err != nil { - return fmt.Errorf("wal segment: %w", err) - } - default: - return fmt.Errorf("invalid stream record type: 0x%02x", hdr.Type) - } - } -} - -// streamSnapshot reads the snapshot into the WAL and applies it to the main database. -func (db *DB) streamSnapshot(ctx context.Context, hdr *StreamRecordHeader, r io.Reader) error { - // Truncate WAL file. - if _, err := db.db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil { - return fmt.Errorf("truncate: %w", err) - } - - // Determine total page count. - pageN := int(hdr.Size / int64(db.pageSize)) - - ww := NewWALWriter(db.WALPath(), db.fileMode, db.pageSize) - if err := ww.Open(); err != nil { - return fmt.Errorf("open wal writer: %w", err) - } - defer func() { _ = ww.Close() }() - - if err := ww.WriteHeader(); err != nil { - return fmt.Errorf("write wal header: %w", err) - } - - // Iterate over pages - buf := make([]byte, db.pageSize) - for pgno := uint32(1); ; pgno++ { - // Read snapshot page into a buffer. - if _, err := io.ReadFull(r, buf); err == io.EOF { - break - } else if err != nil { - return fmt.Errorf("read snapshot page %d: %w", pgno, err) - } - - // Issue a commit flag when the last page is reached. - var commit uint32 - if pgno == uint32(pageN) { - commit = uint32(pageN) - } - - // Write page into WAL frame. - if err := ww.WriteFrame(pgno, commit, buf); err != nil { - return fmt.Errorf("write wal frame: %w", err) - } - } - - // Close WAL file writer. - if err := ww.Close(); err != nil { - return fmt.Errorf("close wal writer: %w", err) - } - - // Invalidate WAL index. - if err := invalidateSHMFile(db.path); err != nil { - return fmt.Errorf("invalidate shm file: %w", err) - } - - // Write position to file so other processes can read it. - if err := db.writePositionFile(hdr.Pos()); err != nil { - return fmt.Errorf("write position file: %w", err) - } - - db.Logger.Printf("snapshot applied") - - return nil -} - -// streamWALSegment rewrites a WAL segment into the local WAL and applies it to the main database. -func (db *DB) streamWALSegment(ctx context.Context, hdr *StreamRecordHeader, r io.Reader) error { - // Decompress incoming segment - zr := lz4.NewReader(r) - - // Drop WAL header if starting from offset zero. - if hdr.Offset == 0 { - if _, err := io.CopyN(io.Discard, zr, WALHeaderSize); err != nil { - return fmt.Errorf("read wal header: %w", err) - } - } - - ww := NewWALWriter(db.WALPath(), db.fileMode, db.pageSize) - if err := ww.Open(); err != nil { - return fmt.Errorf("open wal writer: %w", err) - } - defer func() { _ = ww.Close() }() - - if err := ww.WriteHeader(); err != nil { - return fmt.Errorf("write wal header: %w", err) - } - - // Iterate over incoming WAL pages. - buf := make([]byte, WALFrameHeaderSize+db.pageSize) - for i := 0; ; i++ { - // Read snapshot page into a buffer. - if _, err := io.ReadFull(zr, buf); err == io.EOF { - break - } else if err != nil { - return fmt.Errorf("read wal frame %d: %w", i, err) - } - - // Read page number & commit field. - pgno := binary.BigEndian.Uint32(buf[0:]) - commit := binary.BigEndian.Uint32(buf[4:]) - - // Write page into WAL frame. - if err := ww.WriteFrame(pgno, commit, buf[WALFrameHeaderSize:]); err != nil { - return fmt.Errorf("write wal frame: %w", err) - } - } - - // Close WAL file writer. - if err := ww.Close(); err != nil { - return fmt.Errorf("close wal writer: %w", err) - } - - // Invalidate WAL index. - if err := invalidateSHMFile(db.path); err != nil { - return fmt.Errorf("invalidate shm file: %w", err) - } - - // Write position to file so other processes can read it. - if err := db.writePositionFile(hdr.Pos()); err != nil { - return fmt.Errorf("write position file: %w", err) - } - - db.Logger.Printf("wal segment applied: %s", hdr.Pos().String()) - - return nil -} - // ApplyWAL performs a truncating checkpoint on the given database. func ApplyWAL(ctx context.Context, dbPath, walPath string) error { // Copy WAL file from it's staging path to the correct "-wal" location. @@ -2016,51 +1671,6 @@ func logPrefixPath(path string) string { return path } -// invalidateSHMFile clears the iVersion field of the -shm file in order that -// the next transaction will rebuild it. -func invalidateSHMFile(dbPath string) error { - db, err := sql.Open("sqlite3", dbPath) - if err != nil { - return fmt.Errorf("reopen db: %w", err) - } - defer func() { _ = db.Close() }() - - if _, err := db.Exec(`PRAGMA wal_checkpoint(PASSIVE)`); err != nil { - return fmt.Errorf("passive checkpoint: %w", err) - } - - f, err := os.OpenFile(dbPath+"-shm", os.O_RDWR, 0666) - if err != nil { - return fmt.Errorf("open shm index: %w", err) - } - defer f.Close() - - buf := make([]byte, WALIndexHeaderSize) - if _, err := io.ReadFull(f, buf); err != nil { - return fmt.Errorf("read shm index: %w", err) - } - - // Invalidate "isInit" fields. - buf[12], buf[60] = 0, 0 - - // Rewrite header. - if _, err := f.Seek(0, io.SeekStart); err != nil { - return fmt.Errorf("seek shm index: %w", err) - } else if _, err := f.Write(buf); err != nil { - return fmt.Errorf("overwrite shm index: %w", err) - } else if err := f.Close(); err != nil { - return fmt.Errorf("close shm index: %w", err) - } - - // Truncate WAL file again. - var row [3]int - if err := db.QueryRow(`PRAGMA wal_checkpoint(TRUNCATE)`).Scan(&row[0], &row[1], &row[2]); err != nil { - return fmt.Errorf("truncate: %w", err) - } - - return nil -} - // A marker error to indicate that a restart checkpoint could not verify // continuity between WAL indices and a new generation should be started. var errRestartGeneration = errors.New("restart generation") diff --git a/file_replica_client.go b/file_replica_client.go index dc323a3..fd97995 100644 --- a/file_replica_client.go +++ b/file_replica_client.go @@ -362,9 +362,8 @@ func (c *FileReplicaClient) DeleteWALSegments(ctx context.Context, a []Pos) erro } type FileWALSegmentIterator struct { - mu sync.Mutex - notifyCh chan struct{} - closeFunc func() error + mu sync.Mutex + notifyCh chan struct{} dir string generation string @@ -386,12 +385,6 @@ func NewFileWALSegmentIterator(dir, generation string, indexes []int) *FileWALSe } func (itr *FileWALSegmentIterator) Close() (err error) { - if itr.closeFunc != nil { - if e := itr.closeFunc(); e != nil && err == nil { - err = e - } - } - if e := itr.Err(); e != nil && err == nil { err = e } diff --git a/http/client.go b/http/client.go deleted file mode 100644 index 2975ec6..0000000 --- a/http/client.go +++ /dev/null @@ -1,140 +0,0 @@ -package http - -import ( - "context" - "fmt" - "io" - "net/http" - "net/url" - "strconv" - - "github.com/benbjohnson/litestream" -) - -// Client represents an client for a streaming Litestream HTTP server. -type Client struct { - // Upstream endpoint - URL string - - // Path of database on upstream server. - Path string - - // Underlying HTTP client - HTTPClient *http.Client -} - -// NewClient returns an instance of Client. -func NewClient(rawurl, path string) *Client { - return &Client{ - URL: rawurl, - Path: path, - HTTPClient: http.DefaultClient, - } -} - -// Stream returns a snapshot and continuous stream of WAL updates. -func (c *Client) Stream(ctx context.Context, pos litestream.Pos) (litestream.StreamReader, error) { - u, err := url.Parse(c.URL) - if err != nil { - return nil, fmt.Errorf("invalid client URL: %w", err) - } else if u.Scheme != "http" && u.Scheme != "https" { - return nil, fmt.Errorf("invalid URL scheme") - } else if u.Host == "" { - return nil, fmt.Errorf("URL host required") - } - - // Add path & position to query path. - q := url.Values{"path": []string{c.Path}} - if !pos.IsZero() { - q.Set("generation", pos.Generation) - q.Set("index", litestream.FormatIndex(pos.Index)) - } - - // Strip off everything but the scheme & host. - *u = url.URL{ - Scheme: u.Scheme, - Host: u.Host, - Path: "/stream", - RawQuery: q.Encode(), - } - - req, err := http.NewRequest("GET", u.String(), nil) - if err != nil { - return nil, err - } - req = req.WithContext(ctx) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, err - } else if resp.StatusCode != http.StatusOK { - resp.Body.Close() - return nil, fmt.Errorf("invalid response: code=%d", resp.StatusCode) - } - - pageSize, _ := strconv.Atoi(resp.Header.Get("Litestream-page-size")) - if pageSize <= 0 { - resp.Body.Close() - return nil, fmt.Errorf("stream page size unavailable") - } - - return &StreamReader{ - pageSize: pageSize, - rc: resp.Body, - lr: io.LimitedReader{R: resp.Body}, - }, nil -} - -// StreamReader represents an optional snapshot followed by a continuous stream -// of WAL updates. It is used to implement live read replication from a single -// primary Litestream server to one or more remote Litestream replicas. -type StreamReader struct { - pageSize int - rc io.ReadCloser - lr io.LimitedReader -} - -// Close closes the underlying reader. -func (r *StreamReader) Close() (err error) { - if e := r.rc.Close(); err == nil { - err = e - } - return err -} - -// PageSize returns the page size on the remote database. -func (r *StreamReader) PageSize() int { return r.pageSize } - -// Read reads bytes of the current payload into p. Only valid after a successful -// call to Next(). On io.EOF, call Next() again to begin reading next record. -func (r *StreamReader) Read(p []byte) (n int, err error) { - return r.lr.Read(p) -} - -// Next returns the next available record. This call will block until a record -// is available. After calling Next(), read the payload from the reader using -// Read() until io.EOF is reached. -func (r *StreamReader) Next() (*litestream.StreamRecordHeader, error) { - // If bytes remain on the current file, discard. - if r.lr.N > 0 { - if _, err := io.Copy(io.Discard, &r.lr); err != nil { - return nil, err - } - } - - // Read record header. - buf := make([]byte, litestream.StreamRecordHeaderSize) - if _, err := io.ReadFull(r.rc, buf); err != nil { - return nil, fmt.Errorf("http.StreamReader.Next(): %w", err) - } - - var hdr litestream.StreamRecordHeader - if err := hdr.UnmarshalBinary(buf); err != nil { - return nil, err - } - - // Update remaining bytes on file reader. - r.lr.N = hdr.Size - - return &hdr, nil -} diff --git a/http/server.go b/http/server.go index 702b2d6..910158a 100644 --- a/http/server.go +++ b/http/server.go @@ -2,13 +2,11 @@ package http import ( "fmt" - "io" "log" "net" "net/http" httppprof "net/http/pprof" "os" - "strconv" "strings" "github.com/benbjohnson/litestream" @@ -113,190 +111,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/metrics": s.promHandler.ServeHTTP(w, r) - - case "/stream": - switch r.Method { - case http.MethodGet: - s.handleGetStream(w, r) - default: - s.writeError(w, r, "Method not allowed", http.StatusMethodNotAllowed) - } default: http.NotFound(w, r) } } - -func (s *Server) handleGetStream(w http.ResponseWriter, r *http.Request) { - q := r.URL.Query() - path := q.Get("path") - if path == "" { - s.writeError(w, r, "Database name required", http.StatusBadRequest) - return - } - - // Parse current client position, if available. - var pos litestream.Pos - if generation, index := q.Get("generation"), q.Get("index"); generation != "" && index != "" { - pos.Generation = generation - - var err error - if pos.Index, err = litestream.ParseIndex(index); err != nil { - s.writeError(w, r, "Invalid index query parameter", http.StatusBadRequest) - return - } - } - - // Fetch database instance from the primary server. - db := s.server.DB(path) - if db == nil { - s.writeError(w, r, "Database not found", http.StatusNotFound) - return - } - - // Set the page size in the header. - w.Header().Set("Litestream-page-size", strconv.Itoa(db.PageSize())) - - // Determine starting position. - dbPos := db.Pos() - if dbPos.Generation == "" { - s.writeError(w, r, "No generation available", http.StatusServiceUnavailable) - return - } - dbPos.Offset = 0 - - // Use database position if generation has changed. - var snapshotRequired bool - if pos.Generation != dbPos.Generation { - s.Logger.Printf("stream generation mismatch, using primary position: client.pos=%s", pos) - pos, snapshotRequired = dbPos, true - } - - // Obtain iterator before snapshot so we don't miss any WAL segments. - fitr, err := db.WALSegments(r.Context(), pos.Generation) - if err != nil { - s.writeError(w, r, fmt.Sprintf("Cannot obtain WAL iterator: %s", err), http.StatusInternalServerError) - return - } - defer fitr.Close() - - bitr := litestream.NewBufferedWALSegmentIterator(fitr) - - // Peek at first position to see if client is too old. - if info, ok := bitr.Peek(); !ok { - s.writeError(w, r, "cannot peek WAL iterator, no segments available", http.StatusInternalServerError) - return - } else if cmp, err := litestream.ComparePos(pos, info.Pos()); err != nil { - s.writeError(w, r, fmt.Sprintf("cannot compare pos: %s", err), http.StatusInternalServerError) - return - } else if cmp == -1 { - s.Logger.Printf("stream position no longer available, using using primary position: client.pos=%s", pos) - pos, snapshotRequired = dbPos, true - } - - s.Logger.Printf("stream connected: pos=%s snapshot=%v", pos, snapshotRequired) - defer s.Logger.Printf("stream disconnected") - - // Write snapshot to response body. - if snapshotRequired { - if err := db.WithFile(func(f *os.File) error { - fi, err := f.Stat() - if err != nil { - return err - } - - // Write snapshot header with current position & size. - hdr := litestream.StreamRecordHeader{ - Type: litestream.StreamRecordTypeSnapshot, - Generation: pos.Generation, - Index: pos.Index, - Size: fi.Size(), - } - if buf, err := hdr.MarshalBinary(); err != nil { - return fmt.Errorf("marshal snapshot stream record header: %w", err) - } else if _, err := w.Write(buf); err != nil { - return fmt.Errorf("write snapshot stream record header: %w", err) - } - - if _, err := io.CopyN(w, f, fi.Size()); err != nil { - return fmt.Errorf("copy snapshot: %w", err) - } - - return nil - }); err != nil { - s.writeError(w, r, err.Error(), http.StatusInternalServerError) - return - } - - // Flush after snapshot has been written. - w.(http.Flusher).Flush() - } - - for { - // Wait for notification of new entries. - select { - case <-r.Context().Done(): - return - case <-fitr.NotifyCh(): - } - - for bitr.Next() { - info := bitr.WALSegment() - - // Skip any segments before our initial position. - if cmp, err := litestream.ComparePos(info.Pos(), pos); err != nil { - s.Logger.Printf("pos compare: %s", err) - return - } else if cmp == -1 { - continue - } - - hdr := litestream.StreamRecordHeader{ - Type: litestream.StreamRecordTypeWALSegment, - Flags: 0, - Generation: info.Generation, - Index: info.Index, - Offset: info.Offset, - Size: info.Size, - } - - // Write record header. - data, err := hdr.MarshalBinary() - if err != nil { - s.Logger.Printf("marshal WAL segment stream record header: %s", err) - return - } else if _, err := w.Write(data); err != nil { - s.Logger.Printf("write WAL segment stream record header: %s", err) - return - } - - // Copy WAL segment data to writer. - if err := func() error { - rd, err := db.WALSegmentReader(r.Context(), info.Pos()) - if err != nil { - return fmt.Errorf("cannot fetch wal segment reader: %w", err) - } - defer rd.Close() - - if _, err := io.CopyN(w, rd, hdr.Size); err != nil { - return fmt.Errorf("cannot copy wal segment: %w", err) - } - return nil - }(); err != nil { - log.Print(err) - return - } - - // Flush after WAL segment has been written. - w.(http.Flusher).Flush() - } - if err := bitr.Err(); err != nil { - s.Logger.Printf("wal iterator error: %s", err) - return - } - } -} - -func (s *Server) writeError(w http.ResponseWriter, r *http.Request, err string, code int) { - s.Logger.Printf("error: %s", err) - http.Error(w, err, code) -} diff --git a/integration/cmd_test.go b/integration/cmd_test.go index 62ee336..3abc628 100644 --- a/integration/cmd_test.go +++ b/integration/cmd_test.go @@ -391,254 +391,6 @@ LOOP: restoreAndVerify(t, ctx, env, filepath.Join(testDir, "litestream.yml"), filepath.Join(tempDir, "db")) } -// Ensure a database can be replicated over HTTP. -func TestCmd_Replicate_HTTP(t *testing.T) { - ctx := context.Background() - testDir, tempDir := filepath.Join("testdata", "replicate", "http"), t.TempDir() - if err := os.Mkdir(filepath.Join(tempDir, "0"), 0777); err != nil { - t.Fatal(err) - } else if err := os.Mkdir(filepath.Join(tempDir, "1"), 0777); err != nil { - t.Fatal(err) - } - - env0 := []string{"LITESTREAM_TEMPDIR=" + tempDir} - env1 := []string{"LITESTREAM_TEMPDIR=" + tempDir, "LITESTREAM_UPSTREAM_URL=http://localhost:10001"} - - cmd0, stdout0, _ := commandContext(ctx, env0, "replicate", "-config", filepath.Join(testDir, "litestream.0.yml")) - if err := cmd0.Start(); err != nil { - t.Fatal(err) - } - cmd1, stdout1, _ := commandContext(ctx, env1, "replicate", "-config", filepath.Join(testDir, "litestream.1.yml")) - if err := cmd1.Start(); err != nil { - t.Fatal(err) - } - - db0, err := sql.Open("sqlite3", filepath.Join(tempDir, "0", "db")) - if err != nil { - t.Fatal(err) - } else if _, err := db0.ExecContext(ctx, `PRAGMA journal_mode = wal`); err != nil { - t.Fatal(err) - } else if _, err := db0.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil { - t.Fatal(err) - } - defer db0.Close() - - // Execute writes periodically. - for i := 0; i < 100; i++ { - t.Logf("[exec] INSERT INTO t (id) VALUES (%d)", i) - if _, err := db0.ExecContext(ctx, `INSERT INTO t (id) VALUES (?)`, i); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) - } - - // Wait for replica to catch up. - time.Sleep(1 * time.Second) - - // Verify count in replica table. - db1, err := sql.Open("sqlite3", filepath.Join(tempDir, "1", "db")) - if err != nil { - t.Fatal(err) - } - defer db1.Close() - - var n int - if err := db1.QueryRowContext(ctx, `SELECT COUNT(*) FROM t`).Scan(&n); err != nil { - t.Fatal(err) - } else if got, want := n, 100; got != want { - t.Fatalf("replica count=%d, want %d", got, want) - } - - // Stop & wait for Litestream command. - killLitestreamCmd(t, cmd1, stdout1) // kill - killLitestreamCmd(t, cmd0, stdout0) -} - -// Ensure a database can recover when disconnected from HTTP. -func TestCmd_Replicate_HTTP_PartialRecovery(t *testing.T) { - ctx := context.Background() - testDir, tempDir := filepath.Join("testdata", "replicate", "http-partial-recovery"), t.TempDir() - if err := os.Mkdir(filepath.Join(tempDir, "0"), 0777); err != nil { - t.Fatal(err) - } else if err := os.Mkdir(filepath.Join(tempDir, "1"), 0777); err != nil { - t.Fatal(err) - } - - env0 := []string{"LITESTREAM_TEMPDIR=" + tempDir} - env1 := []string{"LITESTREAM_TEMPDIR=" + tempDir, "LITESTREAM_UPSTREAM_URL=http://localhost:10002"} - - cmd0, stdout0, _ := commandContext(ctx, env0, "replicate", "-config", filepath.Join(testDir, "litestream.0.yml")) - if err := cmd0.Start(); err != nil { - t.Fatal(err) - } - cmd1, stdout1, _ := commandContext(ctx, env1, "replicate", "-config", filepath.Join(testDir, "litestream.1.yml")) - if err := cmd1.Start(); err != nil { - t.Fatal(err) - } - - db0, err := sql.Open("sqlite3", filepath.Join(tempDir, "0", "db")) - if err != nil { - t.Fatal(err) - } else if _, err := db0.ExecContext(ctx, `PRAGMA journal_mode = wal`); err != nil { - t.Fatal(err) - } else if _, err := db0.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil { - t.Fatal(err) - } - defer db0.Close() - - var index int - insertAndWait := func() { - index++ - t.Logf("[exec] INSERT INTO t (id) VALUES (%d)", index) - if _, err := db0.ExecContext(ctx, `INSERT INTO t (id) VALUES (?)`, index); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) - } - - // Execute writes periodically. - for i := 0; i < 50; i++ { - insertAndWait() - } - - // Kill the replica. - t.Logf("Killing replica...") - killLitestreamCmd(t, cmd1, stdout1) - t.Logf("Replica killed") - - // Keep writing. - for i := 0; i < 25; i++ { - insertAndWait() - } - - // Restart replica. - t.Logf("Restarting replica...") - cmd1, stdout1, _ = commandContext(ctx, env1, "replicate", "-config", filepath.Join(testDir, "litestream.1.yml")) - if err := cmd1.Start(); err != nil { - t.Fatal(err) - } - t.Logf("Replica restarted") - - // Continue writing... - for i := 0; i < 25; i++ { - insertAndWait() - } - - // Wait for replica to catch up. - time.Sleep(1 * time.Second) - - // Verify count in replica table. - db1, err := sql.Open("sqlite3", filepath.Join(tempDir, "1", "db")) - if err != nil { - t.Fatal(err) - } - defer db1.Close() - - var n int - if err := db1.QueryRowContext(ctx, `SELECT COUNT(*) FROM t`).Scan(&n); err != nil { - t.Fatal(err) - } else if got, want := n, 100; got != want { - t.Fatalf("replica count=%d, want %d", got, want) - } - - // Stop & wait for Litestream command. - killLitestreamCmd(t, cmd1, stdout1) // kill - killLitestreamCmd(t, cmd0, stdout0) -} - -// Ensure a database can recover when disconnected from HTTP but when last index -// is no longer available. -func TestCmd_Replicate_HTTP_FullRecovery(t *testing.T) { - ctx := context.Background() - testDir, tempDir := filepath.Join("testdata", "replicate", "http-full-recovery"), t.TempDir() - if err := os.Mkdir(filepath.Join(tempDir, "0"), 0777); err != nil { - t.Fatal(err) - } else if err := os.Mkdir(filepath.Join(tempDir, "1"), 0777); err != nil { - t.Fatal(err) - } - - env0 := []string{"LITESTREAM_TEMPDIR=" + tempDir} - env1 := []string{"LITESTREAM_TEMPDIR=" + tempDir, "LITESTREAM_UPSTREAM_URL=http://localhost:10002"} - - cmd0, stdout0, _ := commandContext(ctx, env0, "replicate", "-config", filepath.Join(testDir, "litestream.0.yml")) - if err := cmd0.Start(); err != nil { - t.Fatal(err) - } - cmd1, stdout1, _ := commandContext(ctx, env1, "replicate", "-config", filepath.Join(testDir, "litestream.1.yml")) - if err := cmd1.Start(); err != nil { - t.Fatal(err) - } - - db0, err := sql.Open("sqlite3", filepath.Join(tempDir, "0", "db")) - if err != nil { - t.Fatal(err) - } else if _, err := db0.ExecContext(ctx, `PRAGMA journal_mode = wal`); err != nil { - t.Fatal(err) - } else if _, err := db0.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil { - t.Fatal(err) - } - defer db0.Close() - - var index int - insertAndWait := func() { - index++ - t.Logf("[exec] INSERT INTO t (id) VALUES (%d)", index) - if _, err := db0.ExecContext(ctx, `INSERT INTO t (id) VALUES (?)`, index); err != nil { - t.Fatal(err) - } - time.Sleep(100 * time.Millisecond) - } - - // Execute writes periodically. - for i := 0; i < 50; i++ { - insertAndWait() - } - - // Kill the replica. - t.Logf("Killing replica...") - killLitestreamCmd(t, cmd1, stdout1) - t.Logf("Replica killed") - - // Keep writing. - for i := 0; i < 25; i++ { - insertAndWait() - } - - // Restart replica. - t.Logf("Restarting replica...") - cmd1, stdout1, _ = commandContext(ctx, env1, "replicate", "-config", filepath.Join(testDir, "litestream.1.yml")) - if err := cmd1.Start(); err != nil { - t.Fatal(err) - } - t.Logf("Replica restarted") - - // Continue writing... - for i := 0; i < 25; i++ { - insertAndWait() - } - - // Wait for replica to catch up. - time.Sleep(1 * time.Second) - - // Verify count in replica table. - db1, err := sql.Open("sqlite3", filepath.Join(tempDir, "1", "db")) - if err != nil { - t.Fatal(err) - } - defer db1.Close() - - var n int - if err := db1.QueryRowContext(ctx, `SELECT COUNT(*) FROM t`).Scan(&n); err != nil { - t.Fatal(err) - } else if got, want := n, 100; got != want { - t.Fatalf("replica count=%d, want %d", got, want) - } - - // Stop & wait for Litestream command. - killLitestreamCmd(t, cmd1, stdout1) // kill - killLitestreamCmd(t, cmd0, stdout0) -} - // commandContext returns a "litestream" command with stdout/stderr buffers. func commandContext(ctx context.Context, env []string, arg ...string) (cmd *exec.Cmd, stdout, stderr *internal.LockingBuffer) { cmd = exec.CommandContext(ctx, "litestream", arg...) diff --git a/litestream.go b/litestream.go index b6ca7f0..6cf0b9b 100644 --- a/litestream.go +++ b/litestream.go @@ -1,7 +1,6 @@ package litestream import ( - "context" "database/sql" "encoding/binary" "errors" @@ -536,76 +535,6 @@ func ParseOffset(s string) (int64, error) { return int64(v), nil } -const ( - StreamRecordTypeSnapshot = 1 - StreamRecordTypeWALSegment = 2 -) - -const StreamRecordHeaderSize = 0 + - 4 + 4 + // type, flags - 8 + 8 + 8 + 8 // generation, index, offset, size - -type StreamRecordHeader struct { - Type int - Flags int - Generation string - Index int - Offset int64 - Size int64 -} - -func (hdr *StreamRecordHeader) Pos() Pos { - return Pos{ - Generation: hdr.Generation, - Index: hdr.Index, - Offset: hdr.Offset, - } -} - -func (hdr *StreamRecordHeader) MarshalBinary() ([]byte, error) { - generation, err := strconv.ParseUint(hdr.Generation, 16, 64) - if err != nil { - return nil, fmt.Errorf("invalid generation: %q", generation) - } - - data := make([]byte, StreamRecordHeaderSize) - binary.BigEndian.PutUint32(data[0:4], uint32(hdr.Type)) - binary.BigEndian.PutUint32(data[4:8], uint32(hdr.Flags)) - binary.BigEndian.PutUint64(data[8:16], generation) - binary.BigEndian.PutUint64(data[16:24], uint64(hdr.Index)) - binary.BigEndian.PutUint64(data[24:32], uint64(hdr.Offset)) - binary.BigEndian.PutUint64(data[32:40], uint64(hdr.Size)) - return data, nil -} - -// UnmarshalBinary from data into hdr. -func (hdr *StreamRecordHeader) UnmarshalBinary(data []byte) error { - if len(data) < StreamRecordHeaderSize { - return io.ErrUnexpectedEOF - } - hdr.Type = int(binary.BigEndian.Uint32(data[0:4])) - hdr.Flags = int(binary.BigEndian.Uint32(data[4:8])) - hdr.Generation = fmt.Sprintf("%016x", binary.BigEndian.Uint64(data[8:16])) - hdr.Index = int(binary.BigEndian.Uint64(data[16:24])) - hdr.Offset = int64(binary.BigEndian.Uint64(data[24:32])) - hdr.Size = int64(binary.BigEndian.Uint64(data[32:40])) - return nil -} - -// StreamClient represents a client for streaming changes to a replica DB. -type StreamClient interface { - // Stream returns a reader which contains and optional snapshot followed - // by a series of WAL segments. This stream begins from the given position. - Stream(ctx context.Context, pos Pos) (StreamReader, error) -} - -// StreamReader represents a reader that streams snapshot and WAL records. -type StreamReader interface { - io.ReadCloser - PageSize() int - Next() (*StreamRecordHeader, error) -} - // removeDBFiles deletes the database and related files (journal, shm, wal). func removeDBFiles(filename string) error { if err := os.Remove(filename); err != nil && !os.IsNotExist(err) {