From 44662022fa3a02541fd1589d3440676741333cfb Mon Sep 17 00:00:00 2001 From: Ben Johnson Date: Sun, 3 Apr 2022 09:18:54 -0600 Subject: [PATCH] Allow read replication recovery from last position --- db.go | 31 ++++- http/client.go | 19 +-- http/server.go | 119 +++++++++++------- integration/cmd_test.go | 92 ++++++++++++++ .../replicate/http-recovery/litestream.0.yml | 5 + .../replicate/http-recovery/litestream.1.yml | 5 + litestream.go | 74 ++++++++++- litestream_test.go | 72 +++++++++++ 8 files changed, 361 insertions(+), 56 deletions(-) create mode 100644 integration/testdata/replicate/http-recovery/litestream.0.yml create mode 100644 integration/testdata/replicate/http-recovery/litestream.1.yml diff --git a/db.go b/db.go index c1ac7da..b244698 100644 --- a/db.go +++ b/db.go @@ -914,6 +914,11 @@ func (db *DB) createGeneration(ctx context.Context) (string, error) { // Sync copies pending data from the WAL to the shadow WAL. func (db *DB) Sync(ctx context.Context) error { + if db.StreamClient != nil { + db.Logger.Printf("using upstream client, skipping sync") + return nil + } + const retryN = 5 for i := 0; i < retryN; i++ { @@ -1417,6 +1422,20 @@ func (db *DB) writeWALSegment(ctx context.Context, pos Pos, rd io.Reader) error return nil } +// readPositionFile reads the position from the position file. +func (db *DB) readPositionFile() (Pos, error) { + buf, err := os.ReadFile(db.PositionPath()) + if os.IsNotExist(err) { + return Pos{}, nil + } else if err != nil { + return Pos{}, err + } + + // Treat invalid format as a non-existent file so we return an empty position. + pos, _ := ParsePos(strings.TrimSpace(string(buf))) + return pos, nil +} + // writePositionFile writes pos as the current position. 
func (db *DB) writePositionFile(pos Pos) error { return internal.WriteFile(db.PositionPath(), []byte(pos.String()+"\n"), db.fileMode, db.uid, db.gid) @@ -1675,18 +1694,20 @@ func (db *DB) monitorUpstream(ctx context.Context) error { // stream initializes the local database and continuously streams new upstream data. func (db *DB) stream(ctx context.Context) error { + pos, err := db.readPositionFile() + if err != nil { + return fmt.Errorf("read position file: %w", err) + } + // Continuously stream and apply records from client. - sr, err := db.StreamClient.Stream(ctx) + sr, err := db.StreamClient.Stream(ctx, pos) if err != nil { return fmt.Errorf("stream connect: %w", err) } defer sr.Close() - // TODO: Determine page size of upstream database before creating local. - const pageSize = 4096 - // Initialize the database and create it if it doesn't exist. - if err := db.initReplica(pageSize); err != nil { + if err := db.initReplica(sr.PageSize()); err != nil { return fmt.Errorf("init replica: %w", err) } diff --git a/http/client.go b/http/client.go index 5c5ae75..2975ec6 100644 --- a/http/client.go +++ b/http/client.go @@ -33,7 +33,7 @@ func NewClient(rawurl, path string) *Client { } // Stream returns a snapshot and continuous stream of WAL updates. -func (c *Client) Stream(ctx context.Context) (litestream.StreamReader, error) { +func (c *Client) Stream(ctx context.Context, pos litestream.Pos) (litestream.StreamReader, error) { u, err := url.Parse(c.URL) if err != nil { return nil, fmt.Errorf("invalid client URL: %w", err) @@ -43,14 +43,19 @@ func (c *Client) Stream(ctx context.Context) (litestream.StreamReader, error) { return nil, fmt.Errorf("URL host required") } + // Add path & position to query path. + q := url.Values{"path": []string{c.Path}} + if !pos.IsZero() { + q.Set("generation", pos.Generation) + q.Set("index", litestream.FormatIndex(pos.Index)) + } + // Strip off everything but the scheme & host. 
*u = url.URL{ - Scheme: u.Scheme, - Host: u.Host, - Path: "/stream", - RawQuery: (url.Values{ - "path": []string{c.Path}, - }).Encode(), + Scheme: u.Scheme, + Host: u.Host, + Path: "/stream", + RawQuery: q.Encode(), } req, err := http.NewRequest("GET", u.String(), nil) diff --git a/http/server.go b/http/server.go index dc22656..dbc0a61 100644 --- a/http/server.go +++ b/http/server.go @@ -128,13 +128,25 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) handleGetStream(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() - - // TODO: Listen for all databases matching query criteria. path := q.Get("path") if path == "" { s.writeError(w, r, "Database name required", http.StatusBadRequest) return } + + // Parse current client position, if available. + var pos litestream.Pos + if generation, index := q.Get("generation"), q.Get("index"); generation != "" && index != "" { + pos.Generation = generation + + var err error + if pos.Index, err = litestream.ParseIndex(index); err != nil { + s.writeError(w, r, "Invalid index query parameter", http.StatusBadRequest) + return + } + } + + // Fetch database instance from the primary server. db := s.server.DB(path) if db == nil { s.writeError(w, r, "Database not found", http.StatusNotFound) @@ -144,70 +156,91 @@ func (s *Server) handleGetStream(w http.ResponseWriter, r *http.Request) { // Set the page size in the header. w.Header().Set("Litestream-page-size", strconv.Itoa(db.PageSize())) - // TODO: Restart stream from a previous position, if specified. - // Determine starting position. - pos := db.Pos() - if pos.Generation == "" { + dbPos := db.Pos() + if dbPos.Generation == "" { s.writeError(w, r, "No generation available", http.StatusServiceUnavailable) return } - pos.Offset = 0 + dbPos.Offset = 0 - s.Logger.Printf("stream connected @ %s", pos) - defer s.Logger.Printf("stream disconnected") + // Use database position if generation has changed. 
+ var snapshotRequired bool + if pos.Generation != dbPos.Generation { + s.Logger.Printf("stream generation mismatch, using primary position: client.pos=%s", pos) + pos, snapshotRequired = dbPos, true + } // Obtain iterator before snapshot so we don't miss any WAL segments. - itr, err := db.WALSegments(r.Context(), pos.Generation) + fitr, err := db.WALSegments(r.Context(), pos.Generation) if err != nil { s.writeError(w, r, fmt.Sprintf("Cannot obtain WAL iterator: %s", err), http.StatusInternalServerError) return } - defer itr.Close() + defer fitr.Close() - // Write snapshot to response body. - if err := db.WithFile(func(f *os.File) error { - fi, err := f.Stat() - if err != nil { - return err - } + bitr := litestream.NewBufferedWALSegmentIterator(fitr) - // Write snapshot header with current position & size. - hdr := litestream.StreamRecordHeader{ - Type: litestream.StreamRecordTypeSnapshot, - Generation: pos.Generation, - Index: pos.Index, - Size: fi.Size(), - } - if buf, err := hdr.MarshalBinary(); err != nil { - return fmt.Errorf("marshal snapshot stream record header: %w", err) - } else if _, err := w.Write(buf); err != nil { - return fmt.Errorf("write snapshot stream record header: %w", err) - } - - if _, err := io.CopyN(w, f, fi.Size()); err != nil { - return fmt.Errorf("copy snapshot: %w", err) - } - - return nil - }); err != nil { - s.writeError(w, r, err.Error(), http.StatusInternalServerError) + // Peek at first position to see if client is too old. 
+ if info, ok := bitr.Peek(); !ok { + s.writeError(w, r, "cannot peek WAL iterator, no segments available", http.StatusInternalServerError) return + } else if cmp, err := litestream.ComparePos(pos, info.Pos()); err != nil { + s.writeError(w, r, fmt.Sprintf("cannot compare pos: %s", err), http.StatusInternalServerError) + return + } else if cmp == -1 { + s.Logger.Printf("stream position no longer available, using using primary position: client.pos=%s", pos) + pos, snapshotRequired = dbPos, true } - // Flush after snapshot has been written. - w.(http.Flusher).Flush() + s.Logger.Printf("stream connected: pos=%s snapshot=%v", pos, snapshotRequired) + defer s.Logger.Printf("stream disconnected") + + // Write snapshot to response body. + if snapshotRequired { + if err := db.WithFile(func(f *os.File) error { + fi, err := f.Stat() + if err != nil { + return err + } + + // Write snapshot header with current position & size. + hdr := litestream.StreamRecordHeader{ + Type: litestream.StreamRecordTypeSnapshot, + Generation: pos.Generation, + Index: pos.Index, + Size: fi.Size(), + } + if buf, err := hdr.MarshalBinary(); err != nil { + return fmt.Errorf("marshal snapshot stream record header: %w", err) + } else if _, err := w.Write(buf); err != nil { + return fmt.Errorf("write snapshot stream record header: %w", err) + } + + if _, err := io.CopyN(w, f, fi.Size()); err != nil { + return fmt.Errorf("copy snapshot: %w", err) + } + + return nil + }); err != nil { + s.writeError(w, r, err.Error(), http.StatusInternalServerError) + return + } + + // Flush after snapshot has been written. + w.(http.Flusher).Flush() + } for { // Wait for notification of new entries. select { case <-r.Context().Done(): return - case <-itr.NotifyCh(): + case <-fitr.NotifyCh(): } - for itr.Next() { - info := itr.WALSegment() + for bitr.Next() { + info := bitr.WALSegment() // Skip any segments before our initial position. 
if cmp, err := litestream.ComparePos(info.Pos(), pos); err != nil { @@ -256,7 +289,7 @@ func (s *Server) handleGetStream(w http.ResponseWriter, r *http.Request) { // Flush after WAL segment has been written. w.(http.Flusher).Flush() } - if itr.Err() != nil { + if bitr.Err() != nil { s.Logger.Printf("wal iterator error: %s", err) return } diff --git a/integration/cmd_test.go b/integration/cmd_test.go index 3cab314..96b8558 100644 --- a/integration/cmd_test.go +++ b/integration/cmd_test.go @@ -454,6 +454,98 @@ func TestCmd_Replicate_HTTP(t *testing.T) { killLitestreamCmd(t, cmd0, stdout0) } +// Ensure a database can recover when disconnected from HTTP. +func TestCmd_Replicate_HTTP_Recovery(t *testing.T) { + ctx := context.Background() + testDir, tempDir := filepath.Join("testdata", "replicate", "http-recovery"), t.TempDir() + if err := os.Mkdir(filepath.Join(tempDir, "0"), 0777); err != nil { + t.Fatal(err) + } else if err := os.Mkdir(filepath.Join(tempDir, "1"), 0777); err != nil { + t.Fatal(err) + } + + env0 := []string{"LITESTREAM_TEMPDIR=" + tempDir} + env1 := []string{"LITESTREAM_TEMPDIR=" + tempDir, "LITESTREAM_UPSTREAM_URL=http://localhost:10002"} + + cmd0, stdout0, _ := commandContext(ctx, env0, "replicate", "-config", filepath.Join(testDir, "litestream.0.yml")) + if err := cmd0.Start(); err != nil { + t.Fatal(err) + } + cmd1, stdout1, _ := commandContext(ctx, env1, "replicate", "-config", filepath.Join(testDir, "litestream.1.yml")) + if err := cmd1.Start(); err != nil { + t.Fatal(err) + } + + db0, err := sql.Open("sqlite3", filepath.Join(tempDir, "0", "db")) + if err != nil { + t.Fatal(err) + } else if _, err := db0.ExecContext(ctx, `PRAGMA journal_mode = wal`); err != nil { + t.Fatal(err) + } else if _, err := db0.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil { + t.Fatal(err) + } + defer db0.Close() + + var index int + insertAndWait := func() { + index++ + t.Logf("[exec] INSERT INTO t (id) VALUES (%d)", index) + if _, err := 
db0.ExecContext(ctx, `INSERT INTO t (id) VALUES (?)`, index); err != nil { + t.Fatal(err) + } + time.Sleep(100 * time.Millisecond) + } + + // Execute writes periodically. + for i := 0; i < 50; i++ { + insertAndWait() + } + + // Kill the replica. + t.Logf("Killing replica...") + killLitestreamCmd(t, cmd1, stdout1) + t.Logf("Replica killed") + + // Keep writing. + for i := 0; i < 25; i++ { + insertAndWait() + } + + // Restart replica. + t.Logf("Restarting replica...") + cmd1, stdout1, _ = commandContext(ctx, env1, "replicate", "-config", filepath.Join(testDir, "litestream.1.yml")) + if err := cmd1.Start(); err != nil { + t.Fatal(err) + } + t.Logf("Replica restarted") + + // Continue writing... + for i := 0; i < 25; i++ { + insertAndWait() + } + + // Wait for replica to catch up. + time.Sleep(1 * time.Second) + + // Verify count in replica table. + db1, err := sql.Open("sqlite3", filepath.Join(tempDir, "1", "db")) + if err != nil { + t.Fatal(err) + } + defer db1.Close() + + var n int + if err := db1.QueryRowContext(ctx, `SELECT COUNT(*) FROM t`).Scan(&n); err != nil { + t.Fatal(err) + } else if got, want := n, 100; got != want { + t.Fatalf("replica count=%d, want %d", got, want) + } + + // Stop & wait for Litestream command. + killLitestreamCmd(t, cmd1, stdout1) // kill + killLitestreamCmd(t, cmd0, stdout0) +} + // commandContext returns a "litestream" command with stdout/stderr buffers. func commandContext(ctx context.Context, env []string, arg ...string) (cmd *exec.Cmd, stdout, stderr *internal.LockingBuffer) { cmd = exec.CommandContext(ctx, "litestream", arg...) 
diff --git a/integration/testdata/replicate/http-recovery/litestream.0.yml b/integration/testdata/replicate/http-recovery/litestream.0.yml new file mode 100644 index 0000000..41c7b1b --- /dev/null +++ b/integration/testdata/replicate/http-recovery/litestream.0.yml @@ -0,0 +1,5 @@ +addr: :10002 + +dbs: + - path: $LITESTREAM_TEMPDIR/0/db + max-checkpoint-page-count: 10 diff --git a/integration/testdata/replicate/http-recovery/litestream.1.yml b/integration/testdata/replicate/http-recovery/litestream.1.yml new file mode 100644 index 0000000..e973570 --- /dev/null +++ b/integration/testdata/replicate/http-recovery/litestream.1.yml @@ -0,0 +1,5 @@ +dbs: + - path: $LITESTREAM_TEMPDIR/1/db + upstream: + url: "$LITESTREAM_UPSTREAM_URL" + path: "$LITESTREAM_TEMPDIR/0/db" diff --git a/litestream.go b/litestream.go index a367b33..0e9509c 100644 --- a/litestream.go +++ b/litestream.go @@ -10,6 +10,7 @@ import ( "math" "os" "path/filepath" + "regexp" "strconv" "strings" "time" @@ -191,6 +192,49 @@ func (itr *WALSegmentInfoSliceIterator) WALSegment() WALSegmentInfo { return itr.a[0] } +type BufferedWALSegmentIterator struct { + itr WALSegmentIterator + buffered bool +} + +// NewBufferedWALSegmentIterator returns a new instance of BufferedWALSegmentIterator. +func NewBufferedWALSegmentIterator(itr WALSegmentIterator) *BufferedWALSegmentIterator { + return &BufferedWALSegmentIterator{itr: itr} +} + +// Close closes the underlying iterator. +func (itr *BufferedWALSegmentIterator) Close() error { + return itr.itr.Close() +} + +// Peek returns the next segment without moving the iterator forward. +func (itr *BufferedWALSegmentIterator) Peek() (info WALSegmentInfo, ok bool) { + if !itr.Next() { + return WALSegmentInfo{}, false + } + itr.buffered = true + return itr.itr.WALSegment(), true +} + +// Next returns the next segment. If buffer is full, this call is a no-op. 
+func (itr *BufferedWALSegmentIterator) Next() bool { + if itr.buffered { + itr.buffered = false + return true + } + return itr.itr.Next() +} + +// Returns an error that occurred during iteration. +func (itr *BufferedWALSegmentIterator) Err() error { + return itr.itr.Err() +} + +// Returns metadata for the currently positioned WAL segment file. +func (itr *BufferedWALSegmentIterator) WALSegment() WALSegmentInfo { + return itr.itr.WALSegment() +} + // SnapshotInfo represents file information about a snapshot. type SnapshotInfo struct { Generation string @@ -302,6 +346,32 @@ type Pos struct { Offset int64 // offset within wal file } +// ParsePos parses a position generated by Pos.String(). +func ParsePos(s string) (Pos, error) { + a := posRegex.FindStringSubmatch(s) + if a == nil { + return Pos{}, fmt.Errorf("invalid pos: %q", s) + } + + index, err := ParseIndex(a[2]) + if err != nil { + return Pos{}, err + } + + offset, err := ParseOffset(a[3]) + if err != nil { + return Pos{}, err + } + + return Pos{ + Generation: a[1], + Index: index, + Offset: offset, + }, nil +} + +var posRegex = regexp.MustCompile(`^(\w+)/(\w+):(\w+)$`) + // String returns a string representation. func (p Pos) String() string { if p.IsZero() { @@ -524,7 +594,9 @@ func (hdr *StreamRecordHeader) UnmarshalBinary(data []byte) error { // StreamClient represents a client for streaming changes to a replica DB. type StreamClient interface { - Stream(ctx context.Context) (StreamReader, error) + // Stream returns a reader which contains an optional snapshot followed + // by a series of WAL segments. This stream begins from the given position. + Stream(ctx context.Context, pos Pos) (StreamReader, error) } // StreamReader represents a reader that streams snapshot and WAL records. 
diff --git a/litestream_test.go b/litestream_test.go index 2be6ba2..860b658 100644 --- a/litestream_test.go +++ b/litestream_test.go @@ -52,6 +52,78 @@ func TestFindMinSnapshotByGeneration(t *testing.T) { } } +func TestBufferedWALSegmentIterator(t *testing.T) { + t.Run("OK", func(t *testing.T) { + a := []litestream.WALSegmentInfo{{Index: 1}, {Index: 2}} + itr := litestream.NewBufferedWALSegmentIterator(litestream.NewWALSegmentInfoSliceIterator(a)) + + if info, ok := itr.Peek(); !ok { + t.Fatal("expected info") + } else if got, want := info.Index, 1; got != want { + t.Fatalf("index=%d, want %d", got, want) + } + + if !itr.Next() { + t.Fatal("expected next") + } else if got, want := itr.WALSegment().Index, 1; got != want { + t.Fatalf("index=%d, want %d", got, want) + } + + if !itr.Next() { + t.Fatal("expected next") + } else if got, want := itr.WALSegment().Index, 2; got != want { + t.Fatalf("index=%d, want %d", got, want) + } + + if itr.Next() { + t.Fatal("expected eof") + } + }) + + t.Run("Empty", func(t *testing.T) { + itr := litestream.NewBufferedWALSegmentIterator(litestream.NewWALSegmentInfoSliceIterator(nil)) + + if info, ok := itr.Peek(); ok { + t.Fatal("expected eof") + } else if got, want := info.Index, 0; got != want { + t.Fatalf("index=%d, want %d", got, want) + } + }) +} + +func TestParsePos(t *testing.T) { + t.Run("OK", func(t *testing.T) { + if pos, err := litestream.ParsePos("29cf4bced74e92ab/00000000000003e8:00000000000007d0"); err != nil { + t.Fatal(err) + } else if got, want := pos.Generation, "29cf4bced74e92ab"; got != want { + t.Fatalf("generation=%s, want %s", got, want) + } else if got, want := pos.Index, 1000; got != want { + t.Fatalf("index=%v, want %v", got, want) + } else if got, want := pos.Offset, 2000; got != int64(want) { + t.Fatalf("offset=%v, want %v", got, want) + } + }) + + t.Run("ErrMismatch", func(t *testing.T) { + _, err := litestream.ParsePos("29cf4bced74e92ab-00000000000003e8-00000000000007d0") + if err == nil || err.Error() != 
`invalid pos: "29cf4bced74e92ab-00000000000003e8-00000000000007d0"` { + t.Fatal(err) + } + }) + t.Run("ErrInvalidIndex", func(t *testing.T) { + _, err := litestream.ParsePos("29cf4bced74e92ab/0000000000000xxx:00000000000007d0") + if err == nil || err.Error() != `cannot parse index: "0000000000000xxx"` { + t.Fatal(err) + } + }) + t.Run("ErrInvalidIndex", func(t *testing.T) { + _, err := litestream.ParsePos("29cf4bced74e92ab/00000000000003e8:0000000000000xxx") + if err == nil || err.Error() != `cannot parse offset: "0000000000000xxx"` { + t.Fatal(err) + } + }) +} + func decodeHexString(tb testing.TB, s string) []byte { tb.Helper()