Allow read replication recovery from last position

Ben Johnson
2022-04-03 09:18:54 -06:00
parent 2c3e28c786
commit 44662022fa
8 changed files with 361 additions and 56 deletions

db.go
View File

@@ -914,6 +914,11 @@ func (db *DB) createGeneration(ctx context.Context) (string, error) {
 // Sync copies pending data from the WAL to the shadow WAL.
 func (db *DB) Sync(ctx context.Context) error {
+    if db.StreamClient != nil {
+        db.Logger.Printf("using upstream client, skipping sync")
+        return nil
+    }
+
     const retryN = 5
     for i := 0; i < retryN; i++ {
@@ -1417,6 +1422,20 @@ func (db *DB) writeWALSegment(ctx context.Context, pos Pos, rd io.Reader) error
     return nil
 }
 
+// readPositionFile reads the position from the position file.
+func (db *DB) readPositionFile() (Pos, error) {
+    buf, err := os.ReadFile(db.PositionPath())
+    if os.IsNotExist(err) {
+        return Pos{}, nil
+    } else if err != nil {
+        return Pos{}, err
+    }
+
+    // Treat invalid format as a non-existent file so we return an empty position.
+    pos, _ := ParsePos(strings.TrimSpace(string(buf)))
+    return pos, nil
+}
+
 // writePositionFile writes pos as the current position.
 func (db *DB) writePositionFile(pos Pos) error {
     return internal.WriteFile(db.PositionPath(), []byte(pos.String()+"\n"), db.fileMode, db.uid, db.gid)
@@ -1675,18 +1694,20 @@ func (db *DB) monitorUpstream(ctx context.Context) error {
 // stream initializes the local database and continuously streams new upstream data.
 func (db *DB) stream(ctx context.Context) error {
+    pos, err := db.readPositionFile()
+    if err != nil {
+        return fmt.Errorf("read position file: %w", err)
+    }
+
     // Continuously stream and apply records from client.
-    sr, err := db.StreamClient.Stream(ctx)
+    sr, err := db.StreamClient.Stream(ctx, pos)
     if err != nil {
         return fmt.Errorf("stream connect: %w", err)
     }
     defer sr.Close()
 
-    // TODO: Determine page size of upstream database before creating local.
-    const pageSize = 4096
-
     // Initialize the database and create it if it doesn't exist.
-    if err := db.initReplica(pageSize); err != nil {
+    if err := db.initReplica(sr.PageSize()); err != nil {
         return fmt.Errorf("init replica: %w", err)
     }
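For context, the position file read above holds a single line in the generation/index:offset form produced by Pos.String(), which is the same format the new ParsePos and TestParsePos below exercise. A minimal standalone sketch of that format, using hypothetical generation, index, and offset values (zero-padded hex assumed from the test fixtures):

package main

import (
    "fmt"
    "regexp"
    "strconv"
)

// Mirrors the pattern litestream's ParsePos matches: generation/index:offset.
var posRegex = regexp.MustCompile(`^(\w+)/(\w+):(\w+)$`)

func main() {
    // Hypothetical position: generation 29cf4bced74e92ab, index 1000, offset 2000.
    s := fmt.Sprintf("%s/%016x:%016x", "29cf4bced74e92ab", 1000, 2000)
    fmt.Println(s) // 29cf4bced74e92ab/00000000000003e8:00000000000007d0

    // Parse it back, roughly as readPositionFile -> ParsePos would.
    m := posRegex.FindStringSubmatch(s)
    index, _ := strconv.ParseInt(m[2], 16, 64)
    offset, _ := strconv.ParseInt(m[3], 16, 64)
    fmt.Println(m[1], index, offset) // 29cf4bced74e92ab 1000 2000
}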

View File

@@ -33,7 +33,7 @@ func NewClient(rawurl, path string) *Client {
 }
 
 // Stream returns a snapshot and continuous stream of WAL updates.
-func (c *Client) Stream(ctx context.Context) (litestream.StreamReader, error) {
+func (c *Client) Stream(ctx context.Context, pos litestream.Pos) (litestream.StreamReader, error) {
     u, err := url.Parse(c.URL)
     if err != nil {
         return nil, fmt.Errorf("invalid client URL: %w", err)
@@ -43,14 +43,19 @@ func (c *Client) Stream(ctx context.Context) (litestream.StreamReader, error) {
return nil, fmt.Errorf("URL host required") return nil, fmt.Errorf("URL host required")
} }
// Add path & position to query path.
q := url.Values{"path": []string{c.Path}}
if !pos.IsZero() {
q.Set("generation", pos.Generation)
q.Set("index", litestream.FormatIndex(pos.Index))
}
// Strip off everything but the scheme & host. // Strip off everything but the scheme & host.
*u = url.URL{ *u = url.URL{
Scheme: u.Scheme, Scheme: u.Scheme,
Host: u.Host, Host: u.Host,
Path: "/stream", Path: "/stream",
RawQuery: (url.Values{ RawQuery: q.Encode(),
"path": []string{c.Path},
}).Encode(),
} }
req, err := http.NewRequest("GET", u.String(), nil) req, err := http.NewRequest("GET", u.String(), nil)
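With the change above, a reconnecting replica advertises its saved position in the query string. Assuming an upstream at localhost:10002 and a database path of /data/db (hypothetical values), the request URL comes out along these lines; a small standalone sketch using net/url:

package main

import (
    "fmt"
    "net/url"
)

func main() {
    // Same construction as Client.Stream: path is always set; generation and
    // index are added only when the client has a previously saved position.
    q := url.Values{"path": []string{"/data/db"}}
    q.Set("generation", "29cf4bced74e92ab")
    q.Set("index", fmt.Sprintf("%016x", 3)) // FormatIndex assumed to emit zero-padded hex

    u := url.URL{Scheme: "http", Host: "localhost:10002", Path: "/stream", RawQuery: q.Encode()}
    fmt.Println(u.String())
    // http://localhost:10002/stream?generation=29cf4bced74e92ab&index=0000000000000003&path=%2Fdata%2Fdb
}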

View File

@@ -128,13 +128,25 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) {
 func (s *Server) handleGetStream(w http.ResponseWriter, r *http.Request) {
     q := r.URL.Query()
 
-    // TODO: Listen for all databases matching query criteria.
     path := q.Get("path")
     if path == "" {
         s.writeError(w, r, "Database name required", http.StatusBadRequest)
         return
     }
 
+    // Parse current client position, if available.
+    var pos litestream.Pos
+    if generation, index := q.Get("generation"), q.Get("index"); generation != "" && index != "" {
+        pos.Generation = generation
+
+        var err error
+        if pos.Index, err = litestream.ParseIndex(index); err != nil {
+            s.writeError(w, r, "Invalid index query parameter", http.StatusBadRequest)
+            return
+        }
+    }
+
+    // Fetch database instance from the primary server.
     db := s.server.DB(path)
     if db == nil {
         s.writeError(w, r, "Database not found", http.StatusNotFound)
@@ -144,70 +156,91 @@ func (s *Server) handleGetStream(w http.ResponseWriter, r *http.Request) {
     // Set the page size in the header.
     w.Header().Set("Litestream-page-size", strconv.Itoa(db.PageSize()))
 
-    // TODO: Restart stream from a previous position, if specified.
-
     // Determine starting position.
-    pos := db.Pos()
-    if pos.Generation == "" {
+    dbPos := db.Pos()
+    if dbPos.Generation == "" {
         s.writeError(w, r, "No generation available", http.StatusServiceUnavailable)
         return
     }
-    pos.Offset = 0
+    dbPos.Offset = 0
 
-    s.Logger.Printf("stream connected @ %s", pos)
-    defer s.Logger.Printf("stream disconnected")
+    // Use database position if generation has changed.
+    var snapshotRequired bool
+    if pos.Generation != dbPos.Generation {
+        s.Logger.Printf("stream generation mismatch, using primary position: client.pos=%s", pos)
+        pos, snapshotRequired = dbPos, true
+    }
 
     // Obtain iterator before snapshot so we don't miss any WAL segments.
-    itr, err := db.WALSegments(r.Context(), pos.Generation)
+    fitr, err := db.WALSegments(r.Context(), pos.Generation)
     if err != nil {
         s.writeError(w, r, fmt.Sprintf("Cannot obtain WAL iterator: %s", err), http.StatusInternalServerError)
         return
     }
-    defer itr.Close()
+    defer fitr.Close()
+
+    bitr := litestream.NewBufferedWALSegmentIterator(fitr)
 
-    // Write snapshot to response body.
-    if err := db.WithFile(func(f *os.File) error {
-        fi, err := f.Stat()
-        if err != nil {
-            return err
-        }
-
-        // Write snapshot header with current position & size.
-        hdr := litestream.StreamRecordHeader{
-            Type:       litestream.StreamRecordTypeSnapshot,
-            Generation: pos.Generation,
-            Index:      pos.Index,
-            Size:       fi.Size(),
-        }
-        if buf, err := hdr.MarshalBinary(); err != nil {
-            return fmt.Errorf("marshal snapshot stream record header: %w", err)
-        } else if _, err := w.Write(buf); err != nil {
-            return fmt.Errorf("write snapshot stream record header: %w", err)
-        }
-        if _, err := io.CopyN(w, f, fi.Size()); err != nil {
-            return fmt.Errorf("copy snapshot: %w", err)
-        }
-        return nil
-    }); err != nil {
-        s.writeError(w, r, err.Error(), http.StatusInternalServerError)
+    // Peek at first position to see if client is too old.
+    if info, ok := bitr.Peek(); !ok {
+        s.writeError(w, r, "cannot peek WAL iterator, no segments available", http.StatusInternalServerError)
         return
+    } else if cmp, err := litestream.ComparePos(pos, info.Pos()); err != nil {
+        s.writeError(w, r, fmt.Sprintf("cannot compare pos: %s", err), http.StatusInternalServerError)
+        return
+    } else if cmp == -1 {
+        s.Logger.Printf("stream position no longer available, using using primary position: client.pos=%s", pos)
+        pos, snapshotRequired = dbPos, true
     }
 
-    // Flush after snapshot has been written.
-    w.(http.Flusher).Flush()
+    s.Logger.Printf("stream connected: pos=%s snapshot=%v", pos, snapshotRequired)
+    defer s.Logger.Printf("stream disconnected")
+
+    // Write snapshot to response body.
+    if snapshotRequired {
+        if err := db.WithFile(func(f *os.File) error {
+            fi, err := f.Stat()
+            if err != nil {
+                return err
+            }
+
+            // Write snapshot header with current position & size.
+            hdr := litestream.StreamRecordHeader{
+                Type:       litestream.StreamRecordTypeSnapshot,
+                Generation: pos.Generation,
+                Index:      pos.Index,
+                Size:       fi.Size(),
+            }
+            if buf, err := hdr.MarshalBinary(); err != nil {
+                return fmt.Errorf("marshal snapshot stream record header: %w", err)
+            } else if _, err := w.Write(buf); err != nil {
+                return fmt.Errorf("write snapshot stream record header: %w", err)
+            }
+            if _, err := io.CopyN(w, f, fi.Size()); err != nil {
+                return fmt.Errorf("copy snapshot: %w", err)
+            }
+            return nil
+        }); err != nil {
+            s.writeError(w, r, err.Error(), http.StatusInternalServerError)
+            return
+        }
+
+        // Flush after snapshot has been written.
+        w.(http.Flusher).Flush()
+    }
 
     for {
         // Wait for notification of new entries.
         select {
         case <-r.Context().Done():
             return
-        case <-itr.NotifyCh():
+        case <-fitr.NotifyCh():
         }
 
-        for itr.Next() {
-            info := itr.WALSegment()
+        for bitr.Next() {
+            info := bitr.WALSegment()
 
             // Skip any segments before our initial position.
             if cmp, err := litestream.ComparePos(info.Pos(), pos); err != nil {
@@ -256,7 +289,7 @@ func (s *Server) handleGetStream(w http.ResponseWriter, r *http.Request) {
             // Flush after WAL segment has been written.
             w.(http.Flusher).Flush()
         }
-        if itr.Err() != nil {
+        if bitr.Err() != nil {
             s.Logger.Printf("wal iterator error: %s", err)
             return
         }
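Taken together, the handler now sends a full snapshot only when the client cannot resume: either its generation differs from the primary's, or its saved position precedes the oldest WAL segment the primary still has. A simplified standalone sketch of that decision (plain string/int fields here, not the actual litestream.Pos or ComparePos signatures):

package main

import "fmt"

type pos struct {
    generation string
    index      int
}

// snapshotRequired reports whether the primary must send a full snapshot
// before streaming WAL segments to a reconnecting client.
func snapshotRequired(client, primary, firstAvailable pos) bool {
    if client.generation != primary.generation {
        return true // client is on an old generation
    }
    if client.index < firstAvailable.index {
        return true // client's position was already pruned on the primary
    }
    return false // client can resume from its saved position
}

func main() {
    primary := pos{"29cf4bced74e92ab", 10}
    first := pos{"29cf4bced74e92ab", 4}
    fmt.Println(snapshotRequired(pos{"29cf4bced74e92ab", 7}, primary, first)) // false: resume
    fmt.Println(snapshotRequired(pos{"deadbeef00000000", 7}, primary, first)) // true: generation changed
    fmt.Println(snapshotRequired(pos{"29cf4bced74e92ab", 2}, primary, first)) // true: too far behind
}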

View File

@@ -454,6 +454,98 @@ func TestCmd_Replicate_HTTP(t *testing.T) {
     killLitestreamCmd(t, cmd0, stdout0)
 }
 
+// Ensure a database can recover when disconnected from HTTP.
+func TestCmd_Replicate_HTTP_Recovery(t *testing.T) {
+    ctx := context.Background()
+
+    testDir, tempDir := filepath.Join("testdata", "replicate", "http-recovery"), t.TempDir()
+    if err := os.Mkdir(filepath.Join(tempDir, "0"), 0777); err != nil {
+        t.Fatal(err)
+    } else if err := os.Mkdir(filepath.Join(tempDir, "1"), 0777); err != nil {
+        t.Fatal(err)
+    }
+
+    env0 := []string{"LITESTREAM_TEMPDIR=" + tempDir}
+    env1 := []string{"LITESTREAM_TEMPDIR=" + tempDir, "LITESTREAM_UPSTREAM_URL=http://localhost:10002"}
+
+    cmd0, stdout0, _ := commandContext(ctx, env0, "replicate", "-config", filepath.Join(testDir, "litestream.0.yml"))
+    if err := cmd0.Start(); err != nil {
+        t.Fatal(err)
+    }
+
+    cmd1, stdout1, _ := commandContext(ctx, env1, "replicate", "-config", filepath.Join(testDir, "litestream.1.yml"))
+    if err := cmd1.Start(); err != nil {
+        t.Fatal(err)
+    }
+
+    db0, err := sql.Open("sqlite3", filepath.Join(tempDir, "0", "db"))
+    if err != nil {
+        t.Fatal(err)
+    } else if _, err := db0.ExecContext(ctx, `PRAGMA journal_mode = wal`); err != nil {
+        t.Fatal(err)
+    } else if _, err := db0.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil {
+        t.Fatal(err)
+    }
+    defer db0.Close()
+
+    var index int
+    insertAndWait := func() {
+        index++
+        t.Logf("[exec] INSERT INTO t (id) VALUES (%d)", index)
+        if _, err := db0.ExecContext(ctx, `INSERT INTO t (id) VALUES (?)`, index); err != nil {
+            t.Fatal(err)
+        }
+        time.Sleep(100 * time.Millisecond)
+    }
+
+    // Execute writes periodically.
+    for i := 0; i < 50; i++ {
+        insertAndWait()
+    }
+
+    // Kill the replica.
+    t.Logf("Killing replica...")
+    killLitestreamCmd(t, cmd1, stdout1)
+    t.Logf("Replica killed")
+
+    // Keep writing.
+    for i := 0; i < 25; i++ {
+        insertAndWait()
+    }
+
+    // Restart replica.
+    t.Logf("Restarting replica...")
+    cmd1, stdout1, _ = commandContext(ctx, env1, "replicate", "-config", filepath.Join(testDir, "litestream.1.yml"))
+    if err := cmd1.Start(); err != nil {
+        t.Fatal(err)
+    }
+    t.Logf("Replica restarted")
+
+    // Continue writing...
+    for i := 0; i < 25; i++ {
+        insertAndWait()
+    }
+
+    // Wait for replica to catch up.
+    time.Sleep(1 * time.Second)
+
+    // Verify count in replica table.
+    db1, err := sql.Open("sqlite3", filepath.Join(tempDir, "1", "db"))
+    if err != nil {
+        t.Fatal(err)
+    }
+    defer db1.Close()
+
+    var n int
+    if err := db1.QueryRowContext(ctx, `SELECT COUNT(*) FROM t`).Scan(&n); err != nil {
+        t.Fatal(err)
+    } else if got, want := n, 100; got != want {
+        t.Fatalf("replica count=%d, want %d", got, want)
+    }
+
+    // Stop & wait for Litestream command.
+    killLitestreamCmd(t, cmd1, stdout1) // kill
+    killLitestreamCmd(t, cmd0, stdout0)
+}
+
 // commandContext returns a "litestream" command with stdout/stderr buffers.
 func commandContext(ctx context.Context, env []string, arg ...string) (cmd *exec.Cmd, stdout, stderr *internal.LockingBuffer) {
     cmd = exec.CommandContext(ctx, "litestream", arg...)

View File

@@ -0,0 +1,5 @@
+addr: :10002
+
+dbs:
+  - path: $LITESTREAM_TEMPDIR/0/db
+    max-checkpoint-page-count: 10

View File

@@ -0,0 +1,5 @@
+dbs:
+  - path: $LITESTREAM_TEMPDIR/1/db
+    upstream:
+      url: "$LITESTREAM_UPSTREAM_URL"
+      path: "$LITESTREAM_TEMPDIR/0/db"

View File

@@ -10,6 +10,7 @@ import (
"math" "math"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -191,6 +192,49 @@ func (itr *WALSegmentInfoSliceIterator) WALSegment() WALSegmentInfo {
     return itr.a[0]
 }
 
+type BufferedWALSegmentIterator struct {
+    itr      WALSegmentIterator
+    buffered bool
+}
+
+// NewBufferedWALSegmentIterator returns a new instance of BufferedWALSegmentIterator.
+func NewBufferedWALSegmentIterator(itr WALSegmentIterator) *BufferedWALSegmentIterator {
+    return &BufferedWALSegmentIterator{itr: itr}
+}
+
+// Close closes the underlying iterator.
+func (itr *BufferedWALSegmentIterator) Close() error {
+    return itr.itr.Close()
+}
+
+// Peek returns the next segment without moving the iterator forward.
+func (itr *BufferedWALSegmentIterator) Peek() (info WALSegmentInfo, ok bool) {
+    if !itr.Next() {
+        return WALSegmentInfo{}, false
+    }
+    itr.buffered = true
+    return itr.itr.WALSegment(), true
+}
+
+// Next returns the next segment. If buffer is full, this call is a no-op.
+func (itr *BufferedWALSegmentIterator) Next() bool {
+    if itr.buffered {
+        itr.buffered = false
+        return true
+    }
+    return itr.itr.Next()
+}
+
+// Returns an error that occurred during iteration.
+func (itr *BufferedWALSegmentIterator) Err() error {
+    return itr.itr.Err()
+}
+
+// Returns metadata for the currently positioned WAL segment file.
+func (itr *BufferedWALSegmentIterator) WALSegment() WALSegmentInfo {
+    return itr.itr.WALSegment()
+}
+
 // SnapshotInfo represents file information about a snapshot.
 type SnapshotInfo struct {
     Generation string
@@ -302,6 +346,32 @@ type Pos struct {
     Offset     int64 // offset within wal file
 }
 
+// ParsePos parses a position generated by Pos.String().
+func ParsePos(s string) (Pos, error) {
+    a := posRegex.FindStringSubmatch(s)
+    if a == nil {
+        return Pos{}, fmt.Errorf("invalid pos: %q", s)
+    }
+
+    index, err := ParseIndex(a[2])
+    if err != nil {
+        return Pos{}, err
+    }
+
+    offset, err := ParseOffset(a[3])
+    if err != nil {
+        return Pos{}, err
+    }
+
+    return Pos{
+        Generation: a[1],
+        Index:      index,
+        Offset:     offset,
+    }, nil
+}
+
+var posRegex = regexp.MustCompile(`^(\w+)/(\w+):(\w+)$`)
+
 // String returns a string representation.
 func (p Pos) String() string {
     if p.IsZero() {
@@ -524,7 +594,9 @@ func (hdr *StreamRecordHeader) UnmarshalBinary(data []byte) error {
 // StreamClient represents a client for streaming changes to a replica DB.
 type StreamClient interface {
-    Stream(ctx context.Context) (StreamReader, error)
+    // Stream returns a reader which contains and optional snapshot followed
+    // by a series of WAL segments. This stream begins from the given position.
+    Stream(ctx context.Context, pos Pos) (StreamReader, error)
 }
 
 // StreamReader represents a reader that streams snapshot and WAL records.

View File

@@ -52,6 +52,78 @@ func TestFindMinSnapshotByGeneration(t *testing.T) {
     }
 }
 
+func TestBufferedWALSegmentIterator(t *testing.T) {
+    t.Run("OK", func(t *testing.T) {
+        a := []litestream.WALSegmentInfo{{Index: 1}, {Index: 2}}
+        itr := litestream.NewBufferedWALSegmentIterator(litestream.NewWALSegmentInfoSliceIterator(a))
+
+        if info, ok := itr.Peek(); !ok {
+            t.Fatal("expected info")
+        } else if got, want := info.Index, 1; got != want {
+            t.Fatalf("index=%d, want %d", got, want)
+        }
+
+        if !itr.Next() {
+            t.Fatal("expected next")
+        } else if got, want := itr.WALSegment().Index, 1; got != want {
+            t.Fatalf("index=%d, want %d", got, want)
+        }
+        if !itr.Next() {
+            t.Fatal("expected next")
+        } else if got, want := itr.WALSegment().Index, 2; got != want {
+            t.Fatalf("index=%d, want %d", got, want)
+        }
+        if itr.Next() {
+            t.Fatal("expected eof")
+        }
+    })
+
+    t.Run("Empty", func(t *testing.T) {
+        itr := litestream.NewBufferedWALSegmentIterator(litestream.NewWALSegmentInfoSliceIterator(nil))
+        if info, ok := itr.Peek(); ok {
+            t.Fatal("expected eof")
+        } else if got, want := info.Index, 0; got != want {
+            t.Fatalf("index=%d, want %d", got, want)
+        }
+    })
+}
+
+func TestParsePos(t *testing.T) {
+    t.Run("OK", func(t *testing.T) {
+        if pos, err := litestream.ParsePos("29cf4bced74e92ab/00000000000003e8:00000000000007d0"); err != nil {
+            t.Fatal(err)
+        } else if got, want := pos.Generation, "29cf4bced74e92ab"; got != want {
+            t.Fatalf("generation=%s, want %s", got, want)
+        } else if got, want := pos.Index, 1000; got != want {
+            t.Fatalf("index=%v, want %v", got, want)
+        } else if got, want := pos.Offset, 2000; got != int64(want) {
+            t.Fatalf("offset=%v, want %v", got, want)
+        }
+    })
+
+    t.Run("ErrMismatch", func(t *testing.T) {
+        _, err := litestream.ParsePos("29cf4bced74e92ab-00000000000003e8-00000000000007d0")
+        if err == nil || err.Error() != `invalid pos: "29cf4bced74e92ab-00000000000003e8-00000000000007d0"` {
+            t.Fatal(err)
+        }
+    })
+
+    t.Run("ErrInvalidIndex", func(t *testing.T) {
+        _, err := litestream.ParsePos("29cf4bced74e92ab/0000000000000xxx:00000000000007d0")
+        if err == nil || err.Error() != `cannot parse index: "0000000000000xxx"` {
+            t.Fatal(err)
+        }
+    })
+
+    t.Run("ErrInvalidIndex", func(t *testing.T) {
+        _, err := litestream.ParsePos("29cf4bced74e92ab/00000000000003e8:0000000000000xxx")
+        if err == nil || err.Error() != `cannot parse offset: "0000000000000xxx"` {
+            t.Fatal(err)
+        }
+    })
+}
+
 func decodeHexString(tb testing.TB, s string) []byte {
     tb.Helper()