Fix read replication stream restart position

This commit is contained in:
Ben Johnson
2022-06-15 14:04:37 -06:00
parent 46597ab22f
commit c53a09c124
4 changed files with 110 additions and 101 deletions

177
db.go
View File

@@ -17,6 +17,7 @@ import (
"sort" "sort"
"strings" "strings"
"sync" "sync"
"syscall"
"time" "time"
"github.com/benbjohnson/litestream/internal" "github.com/benbjohnson/litestream/internal"
@@ -1747,23 +1748,34 @@ func (db *DB) stream(ctx context.Context) error {
// streamSnapshot reads the snapshot into the WAL and applies it to the main database. // streamSnapshot reads the snapshot into the WAL and applies it to the main database.
func (db *DB) streamSnapshot(ctx context.Context, hdr *StreamRecordHeader, r io.Reader) error { func (db *DB) streamSnapshot(ctx context.Context, hdr *StreamRecordHeader, r io.Reader) error {
// Truncate WAL file.
if _, err := db.db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil {
return fmt.Errorf("truncate: %w", err)
}
// Determine total page count. // Determine total page count.
pageN := int(hdr.Size / int64(db.pageSize)) pageN := int(hdr.Size / int64(db.pageSize))
ww := NewWALWriter(db.WALPath(), db.fileMode, db.pageSize) // Open database file.
if err := ww.Open(); err != nil { f, err := os.OpenFile(db.path, os.O_RDWR, 0666)
return fmt.Errorf("open wal writer: %w", err) if err != nil {
return fmt.Errorf("open db file: %w", err)
} }
defer func() { _ = ww.Close() }() defer f.Close()
if err := ww.WriteHeader(); err != nil { // Open shm file for locking.
return fmt.Errorf("write wal header: %w", err) shmFile, err := os.OpenFile(db.SHMPath(), os.O_RDWR, 0666)
if err != nil {
return fmt.Errorf("open shm file: %w", err)
} }
defer shmFile.Close()
// Obtain WAL checkpoint lock.
if err := setLkw(shmFile, syscall.F_WRLCK, WAL_CKPT_LOCK_OFFSET, 1); err != nil {
return fmt.Errorf("cannot obtain wal checkpoint lock: %w", err)
}
defer func() { _ = setLkw(shmFile, syscall.F_UNLCK, WAL_CKPT_LOCK_OFFSET, 1) }()
// Obtain WAL write lock.
if err := setLkw(shmFile, syscall.F_WRLCK, WAL_WRITE_LOCK_OFFSET, 1); err != nil {
return fmt.Errorf("cannot obtain wal write lock: %w", err)
}
defer func() { _ = setLkw(shmFile, syscall.F_UNLCK, WAL_WRITE_LOCK_OFFSET, 1) }()
// Iterate over pages // Iterate over pages
buf := make([]byte, db.pageSize) buf := make([]byte, db.pageSize)
@@ -1775,26 +1787,18 @@ func (db *DB) streamSnapshot(ctx context.Context, hdr *StreamRecordHeader, r io.
return fmt.Errorf("read snapshot page %d: %w", pgno, err) return fmt.Errorf("read snapshot page %d: %w", pgno, err)
} }
// Issue a commit flag when the last page is reached. // Copy page to database file.
var commit uint32 offset := int64(pgno-1) * int64(db.pageSize)
if _, err := f.WriteAt(buf, offset); err != nil {
return fmt.Errorf("copy to db: pgno=%d err=%w", pgno, err)
}
// Truncate database to final size.
if pgno == uint32(pageN) { if pgno == uint32(pageN) {
commit = uint32(pageN) if err := f.Truncate(int64(pageN) * int64(db.pageSize)); err != nil {
} return fmt.Errorf("truncate db: commit=%d err=%w", pageN, err)
// Write page into WAL frame.
if err := ww.WriteFrame(pgno, commit, buf); err != nil {
return fmt.Errorf("write wal frame: %w", err)
} }
} }
// Close WAL file writer.
if err := ww.Close(); err != nil {
return fmt.Errorf("close wal writer: %w", err)
}
// Invalidate WAL index.
if err := invalidateSHMFile(db.path); err != nil {
return fmt.Errorf("invalidate shm file: %w", err)
} }
// Write position to file so other processes can read it. // Write position to file so other processes can read it.
@@ -1819,44 +1823,63 @@ func (db *DB) streamWALSegment(ctx context.Context, hdr *StreamRecordHeader, r i
} }
} }
ww := NewWALWriter(db.WALPath(), db.fileMode, db.pageSize) // Open database file.
if err := ww.Open(); err != nil { f, err := os.OpenFile(db.path, os.O_RDWR, 0666)
return fmt.Errorf("open wal writer: %w", err) if err != nil {
return fmt.Errorf("open db file: %w", err)
} }
defer func() { _ = ww.Close() }() defer f.Close()
if err := ww.WriteHeader(); err != nil { // Open shm file for locking.
return fmt.Errorf("write wal header: %w", err) shmFile, err := os.OpenFile(db.SHMPath(), os.O_RDWR, 0666)
if err != nil {
return fmt.Errorf("open shm file: %w", err)
} }
defer shmFile.Close()
// Obtain WAL checkpoint lock.
if err := setLkw(shmFile, syscall.F_WRLCK, WAL_CKPT_LOCK_OFFSET, 1); err != nil {
return fmt.Errorf("cannot obtain wal checkpoint lock: %w", err)
}
defer func() { _ = setLkw(shmFile, syscall.F_UNLCK, WAL_CKPT_LOCK_OFFSET, 1) }()
// Obtain WAL write lock.
if err := setLkw(shmFile, syscall.F_WRLCK, WAL_WRITE_LOCK_OFFSET, 1); err != nil {
return fmt.Errorf("cannot obtain wal write lock: %w", err)
}
defer func() { _ = setLkw(shmFile, syscall.F_UNLCK, WAL_WRITE_LOCK_OFFSET, 1) }()
// Iterate over incoming WAL pages. // Iterate over incoming WAL pages.
buf := make([]byte, WALFrameHeaderSize+db.pageSize) buf := make([]byte, WALFrameHeaderSize+db.pageSize)
for i := 0; ; i++ { for i := 0; ; i++ {
// Read WAL frame (header + page) into a buffer. // Read WAL frame (header + page) into a buffer.
if _, err := io.ReadFull(zr, buf); err == io.EOF { if n, err := io.ReadFull(zr, buf); err == io.EOF {
break break
} else if err != nil { } else if err != nil {
return fmt.Errorf("read wal frame %d: %w", i, err) return fmt.Errorf("read wal frame: i=%d n=%d err=%w", i, n, err)
} }
// Read page number & commit field. // Read page number & commit field.
pgno := binary.BigEndian.Uint32(buf[0:]) pgno := binary.BigEndian.Uint32(buf[0:])
commit := binary.BigEndian.Uint32(buf[4:]) commit := binary.BigEndian.Uint32(buf[4:])
// Write page into WAL frame. // Copy page to database file.
if err := ww.WriteFrame(pgno, commit, buf[WALFrameHeaderSize:]); err != nil { offset := int64(pgno-1) * int64(db.pageSize)
return fmt.Errorf("write wal frame: %w", err) if _, err := f.WriteAt(buf[WALFrameHeaderSize:], offset); err != nil {
return fmt.Errorf("copy to db: pgno=%d err=%w", pgno, err)
}
// Truncate database, if commit specified.
if commit != 0 {
if err := f.Truncate(int64(commit) * int64(db.pageSize)); err != nil {
return fmt.Errorf("truncate db: commit=%d err=%w", commit, err)
}
} }
} }
// Close WAL file writer. // Close database file writer.
if err := ww.Close(); err != nil { if err := f.Close(); err != nil {
return fmt.Errorf("close wal writer: %w", err) return fmt.Errorf("close db writer: %w", err)
}
// Invalidate WAL index.
if err := invalidateSHMFile(db.path); err != nil {
return fmt.Errorf("invalidate shm file: %w", err)
} }
// Write position to file so other processes can read it. // Write position to file so other processes can read it.
@@ -2016,51 +2039,21 @@ func logPrefixPath(path string) string {
return path return path
} }
// invalidateSHMFile zeroes bytes 12 and 60 of the WAL index header stored in
// the "-shm" file next to dbPath (the "isInit" fields) so that the index is
// treated as uninitialized and rebuilt.
//
// The sequence is: open a fresh SQLite connection to dbPath, run a PASSIVE
// checkpoint, rewrite the first WALIndexHeaderSize bytes of the -shm file with
// the two bytes cleared, and finally run a TRUNCATE checkpoint on the same
// connection. Any failure is returned wrapped with the step that failed.
func invalidateSHMFile(dbPath string) error {
	conn, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		return fmt.Errorf("reopen db: %w", err)
	}
	defer func() { _ = conn.Close() }()

	// Checkpoint passively before touching the index header.
	if _, err := conn.Exec(`PRAGMA wal_checkpoint(PASSIVE)`); err != nil {
		return fmt.Errorf("passive checkpoint: %w", err)
	}

	shm, err := os.OpenFile(dbPath+"-shm", os.O_RDWR, 0666)
	if err != nil {
		return fmt.Errorf("open shm index: %w", err)
	}
	defer shm.Close()

	// Load the current header so only the targeted bytes change.
	hdr := make([]byte, WALIndexHeaderSize)
	if _, err := io.ReadFull(shm, hdr); err != nil {
		return fmt.Errorf("read shm index: %w", err)
	}

	// Invalidate "isInit" fields.
	hdr[12], hdr[60] = 0, 0

	// Write the modified header back over the original at offset zero.
	if _, err := shm.Seek(0, io.SeekStart); err != nil {
		return fmt.Errorf("seek shm index: %w", err)
	}
	if _, err := shm.Write(hdr); err != nil {
		return fmt.Errorf("overwrite shm index: %w", err)
	}
	if err := shm.Close(); err != nil {
		return fmt.Errorf("close shm index: %w", err)
	}

	// Truncate WAL file again; the three result columns are scanned and discarded.
	var result [3]int
	if err := conn.QueryRow(`PRAGMA wal_checkpoint(TRUNCATE)`).Scan(&result[0], &result[1], &result[2]); err != nil {
		return fmt.Errorf("truncate: %w", err)
	}
	return nil
}
// A marker error to indicate that a restart checkpoint could not verify // A marker error to indicate that a restart checkpoint could not verify
// continuity between WAL indices and a new generation should be started. // continuity between WAL indices and a new generation should be started.
var errRestartGeneration = errors.New("restart generation") var errRestartGeneration = errors.New("restart generation")
// Byte offsets within the -shm file that are passed to setLkw as the start of
// a one-byte lock range.
// NOTE(review): presumably these mirror the lock bytes SQLite itself uses in
// the WAL index — confirm against the SQLite WAL-index format documentation.
const (
// WAL_WRITE_LOCK_OFFSET is the offset locked to obtain the WAL write lock.
WAL_WRITE_LOCK_OFFSET = 120
// WAL_CKPT_LOCK_OFFSET is the offset locked to obtain the WAL checkpoint lock.
WAL_CKPT_LOCK_OFFSET  = 121
)
// setLkw issues an fcntl F_SETLKW request against f for the byte range that
// begins at start (measured from the beginning of the file) and spans len
// bytes. typ selects the lock operation, e.g. syscall.F_WRLCK to take a write
// lock or syscall.F_UNLCK to release one. The error from the underlying
// syscall is returned unchanged.
func setLkw(f *os.File, typ int16, start, len int64) error {
	lock := syscall.Flock_t{
		Type:   typ,
		Whence: io.SeekStart,
		Start:  start,
		Len:    len,
	}
	return syscall.FcntlFlock(f.Fd(), syscall.F_SETLKW, &lock)
}

View File

@@ -48,6 +48,7 @@ func (c *Client) Stream(ctx context.Context, pos litestream.Pos) (litestream.Str
if !pos.IsZero() { if !pos.IsZero() {
q.Set("generation", pos.Generation) q.Set("generation", pos.Generation)
q.Set("index", litestream.FormatIndex(pos.Index)) q.Set("index", litestream.FormatIndex(pos.Index))
q.Set("offset", litestream.FormatOffset(pos.Offset))
} }
// Strip off everything but the scheme & host. // Strip off everything but the scheme & host.

View File

@@ -134,16 +134,31 @@ func (s *Server) handleGetStream(w http.ResponseWriter, r *http.Request) {
return return
} }
generationStr := q.Get("generation")
indexStr := q.Get("index")
offsetStr := q.Get("offset")
// Parse current client position, if available. // Parse current client position, if available.
var pos litestream.Pos var pos litestream.Pos
if generation, index := q.Get("generation"), q.Get("index"); generation != "" && index != "" { if generationStr != "" && indexStr != "" && offsetStr != "" {
pos.Generation = generation
var err error index, err := litestream.ParseIndex(indexStr)
if pos.Index, err = litestream.ParseIndex(index); err != nil { if err != nil {
s.writeError(w, r, "Invalid index query parameter", http.StatusBadRequest) s.writeError(w, r, "Invalid index query parameter", http.StatusBadRequest)
return return
} }
offset, err := litestream.ParseOffset(offsetStr)
if err != nil {
s.writeError(w, r, "Invalid offset query parameter", http.StatusBadRequest)
return
}
pos = litestream.Pos{
Generation: generationStr,
Index: index,
Offset: offset,
}
} }
// Fetch database instance from the primary server. // Fetch database instance from the primary server.
@@ -162,7 +177,6 @@ func (s *Server) handleGetStream(w http.ResponseWriter, r *http.Request) {
s.writeError(w, r, "No generation available", http.StatusServiceUnavailable) s.writeError(w, r, "No generation available", http.StatusServiceUnavailable)
return return
} }
dbPos.Offset = 0
// Use database position if generation has changed. // Use database position if generation has changed.
var snapshotRequired bool var snapshotRequired bool
@@ -170,6 +184,7 @@ func (s *Server) handleGetStream(w http.ResponseWriter, r *http.Request) {
s.Logger.Printf("stream generation mismatch, using primary position: client.pos=%s", pos) s.Logger.Printf("stream generation mismatch, using primary position: client.pos=%s", pos)
pos, snapshotRequired = dbPos, true pos, snapshotRequired = dbPos, true
} }
pos.Offset = 0
// Obtain iterator before snapshot so we don't miss any WAL segments. // Obtain iterator before snapshot so we don't miss any WAL segments.
fitr, err := db.WALSegments(r.Context(), pos.Generation) fitr, err := db.WALSegments(r.Context(), pos.Generation)

View File

@@ -481,7 +481,7 @@ func TestCmd_Replicate_HTTP_PartialRecovery(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} else if _, err := db0.ExecContext(ctx, `PRAGMA journal_mode = wal`); err != nil { } else if _, err := db0.ExecContext(ctx, `PRAGMA journal_mode = wal`); err != nil {
t.Fatal(err) t.Fatal(err)
} else if _, err := db0.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil { } else if _, err := db0.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY, DATA TEXT)`); err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer db0.Close() defer db0.Close()
@@ -489,8 +489,8 @@ func TestCmd_Replicate_HTTP_PartialRecovery(t *testing.T) {
var index int var index int
insertAndWait := func() { insertAndWait := func() {
index++ index++
t.Logf("[exec] INSERT INTO t (id) VALUES (%d)", index) t.Logf("[exec] INSERT INTO t (id, data) VALUES (%d, '...')", index)
if _, err := db0.ExecContext(ctx, `INSERT INTO t (id) VALUES (?)`, index); err != nil { if _, err := db0.ExecContext(ctx, `INSERT INTO t (id, data) VALUES (?, ?)`, index, strings.Repeat("x", 512)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)