diff --git a/db.go b/db.go index 160cd47..b1f89a3 100644 --- a/db.go +++ b/db.go @@ -17,6 +17,7 @@ import ( "sort" "strings" "sync" + "syscall" "time" "github.com/benbjohnson/litestream/internal" @@ -1747,23 +1748,34 @@ func (db *DB) stream(ctx context.Context) error { // streamSnapshot reads the snapshot into the WAL and applies it to the main database. func (db *DB) streamSnapshot(ctx context.Context, hdr *StreamRecordHeader, r io.Reader) error { - // Truncate WAL file. - if _, err := db.db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil { - return fmt.Errorf("truncate: %w", err) - } - // Determine total page count. pageN := int(hdr.Size / int64(db.pageSize)) - ww := NewWALWriter(db.WALPath(), db.fileMode, db.pageSize) - if err := ww.Open(); err != nil { - return fmt.Errorf("open wal writer: %w", err) + // Open database file. + f, err := os.OpenFile(db.path, os.O_RDWR, 0666) + if err != nil { + return fmt.Errorf("open db file: %w", err) } - defer func() { _ = ww.Close() }() + defer f.Close() - if err := ww.WriteHeader(); err != nil { - return fmt.Errorf("write wal header: %w", err) + // Open shm file for locking. + shmFile, err := os.OpenFile(db.SHMPath(), os.O_RDWR, 0666) + if err != nil { + return fmt.Errorf("open shm file: %w", err) } + defer shmFile.Close() + + // Obtain WAL checkpoint lock. + if err := setLkw(shmFile, syscall.F_WRLCK, WAL_CKPT_LOCK_OFFSET, 1); err != nil { + return fmt.Errorf("cannot obtain wal checkpoint lock: %w", err) + } + defer func() { _ = setLkw(shmFile, syscall.F_UNLCK, WAL_CKPT_LOCK_OFFSET, 1) }() + + // Obtain WAL write lock. + if err := setLkw(shmFile, syscall.F_WRLCK, WAL_WRITE_LOCK_OFFSET, 1); err != nil { + return fmt.Errorf("cannot obtain wal write lock: %w", err) + } + defer func() { _ = setLkw(shmFile, syscall.F_UNLCK, WAL_WRITE_LOCK_OFFSET, 1) }() // Iterate over pages buf := make([]byte, db.pageSize) @@ -1775,26 +1787,18 @@ func (db *DB) streamSnapshot(ctx context.Context, hdr *StreamRecordHeader, r io. return fmt.Errorf("read snapshot page %d: %w", pgno, err) } - // Issue a commit flag when the last page is reached. - var commit uint32 + // Copy page to database file. + offset := int64(pgno-1) * int64(db.pageSize) + if _, err := f.WriteAt(buf, offset); err != nil { + return fmt.Errorf("copy to db: pgno=%d err=%w", pgno, err) + } + + // Truncate database to final size. if pgno == uint32(pageN) { - commit = uint32(pageN) + if err := f.Truncate(int64(pageN) * int64(db.pageSize)); err != nil { + return fmt.Errorf("truncate db: commit=%d err=%w", pageN, err) + } } - - // Write page into WAL frame. - if err := ww.WriteFrame(pgno, commit, buf); err != nil { - return fmt.Errorf("write wal frame: %w", err) - } - } - - // Close WAL file writer. - if err := ww.Close(); err != nil { - return fmt.Errorf("close wal writer: %w", err) - } - - // Invalidate WAL index. - if err := invalidateSHMFile(db.path); err != nil { - return fmt.Errorf("invalidate shm file: %w", err) } // Write position to file so other processes can read it. @@ -1819,44 +1823,63 @@ func (db *DB) streamWALSegment(ctx context.Context, hdr *StreamRecordHeader, r i } } - ww := NewWALWriter(db.WALPath(), db.fileMode, db.pageSize) - if err := ww.Open(); err != nil { - return fmt.Errorf("open wal writer: %w", err) + // Open database file. + f, err := os.OpenFile(db.path, os.O_RDWR, 0666) + if err != nil { + return fmt.Errorf("open db file: %w", err) } - defer func() { _ = ww.Close() }() + defer f.Close() - if err := ww.WriteHeader(); err != nil { - return fmt.Errorf("write wal header: %w", err) + // Open shm file for locking. + shmFile, err := os.OpenFile(db.SHMPath(), os.O_RDWR, 0666) + if err != nil { + return fmt.Errorf("open shm file: %w", err) } + defer shmFile.Close() + + // Obtain WAL checkpoint lock. + if err := setLkw(shmFile, syscall.F_WRLCK, WAL_CKPT_LOCK_OFFSET, 1); err != nil { + return fmt.Errorf("cannot obtain wal checkpoint lock: %w", err) + } + defer func() { _ = setLkw(shmFile, syscall.F_UNLCK, WAL_CKPT_LOCK_OFFSET, 1) }() + + // Obtain WAL write lock. + if err := setLkw(shmFile, syscall.F_WRLCK, WAL_WRITE_LOCK_OFFSET, 1); err != nil { + return fmt.Errorf("cannot obtain wal write lock: %w", err) + } + defer func() { _ = setLkw(shmFile, syscall.F_UNLCK, WAL_WRITE_LOCK_OFFSET, 1) }() // Iterate over incoming WAL pages. buf := make([]byte, WALFrameHeaderSize+db.pageSize) for i := 0; ; i++ { // Read snapshot page into a buffer. - if _, err := io.ReadFull(zr, buf); err == io.EOF { + if n, err := io.ReadFull(zr, buf); err == io.EOF { break } else if err != nil { - return fmt.Errorf("read wal frame %d: %w", i, err) + return fmt.Errorf("read wal frame: i=%d n=%d err=%w", i, n, err) } // Read page number & commit field. pgno := binary.BigEndian.Uint32(buf[0:]) commit := binary.BigEndian.Uint32(buf[4:]) - // Write page into WAL frame. - if err := ww.WriteFrame(pgno, commit, buf[WALFrameHeaderSize:]); err != nil { - return fmt.Errorf("write wal frame: %w", err) + // Copy page to database file. + offset := int64(pgno-1) * int64(db.pageSize) + if _, err := f.WriteAt(buf[WALFrameHeaderSize:], offset); err != nil { + return fmt.Errorf("copy to db: pgno=%d err=%w", pgno, err) + } + + // Truncate database, if commit specified. + if commit != 0 { + if err := f.Truncate(int64(commit) * int64(db.pageSize)); err != nil { + return fmt.Errorf("truncate db: commit=%d err=%w", commit, err) + } } } - // Close WAL file writer. - if err := ww.Close(); err != nil { - return fmt.Errorf("close wal writer: %w", err) - } - - // Invalidate WAL index. - if err := invalidateSHMFile(db.path); err != nil { - return fmt.Errorf("invalidate shm file: %w", err) + // Close database file writer. + if err := f.Close(); err != nil { + return fmt.Errorf("close db writer: %w", err) } // Write position to file so other processes can read it. @@ -2016,51 +2039,21 @@ func logPrefixPath(path string) string { return path } -// invalidateSHMFile clears the iVersion field of the -shm file in order that -// the next transaction will rebuild it. -func invalidateSHMFile(dbPath string) error { - db, err := sql.Open("sqlite3", dbPath) - if err != nil { - return fmt.Errorf("reopen db: %w", err) - } - defer func() { _ = db.Close() }() - - if _, err := db.Exec(`PRAGMA wal_checkpoint(PASSIVE)`); err != nil { - return fmt.Errorf("passive checkpoint: %w", err) - } - - f, err := os.OpenFile(dbPath+"-shm", os.O_RDWR, 0666) - if err != nil { - return fmt.Errorf("open shm index: %w", err) - } - defer f.Close() - - buf := make([]byte, WALIndexHeaderSize) - if _, err := io.ReadFull(f, buf); err != nil { - return fmt.Errorf("read shm index: %w", err) - } - - // Invalidate "isInit" fields. - buf[12], buf[60] = 0, 0 - - // Rewrite header. - if _, err := f.Seek(0, io.SeekStart); err != nil { - return fmt.Errorf("seek shm index: %w", err) - } else if _, err := f.Write(buf); err != nil { - return fmt.Errorf("overwrite shm index: %w", err) - } else if err := f.Close(); err != nil { - return fmt.Errorf("close shm index: %w", err) - } - - // Truncate WAL file again. - var row [3]int - if err := db.QueryRow(`PRAGMA wal_checkpoint(TRUNCATE)`).Scan(&row[0], &row[1], &row[2]); err != nil { - return fmt.Errorf("truncate: %w", err) - } - - return nil -} - // A marker error to indicate that a restart checkpoint could not verify // continuity between WAL indices and a new generation should be started. var errRestartGeneration = errors.New("restart generation") + +const ( + WAL_WRITE_LOCK_OFFSET = 120 + WAL_CKPT_LOCK_OFFSET = 121 +) + +// setLkw is a helper function for calling fcntl for file locking. +func setLkw(f *os.File, typ int16, start, len int64) error { + return syscall.FcntlFlock(f.Fd(), syscall.F_SETLKW, &syscall.Flock_t{ + Start: start, + Len: len, + Type: typ, + Whence: io.SeekStart, + }) +} diff --git a/http/client.go b/http/client.go index 2975ec6..bf3e842 100644 --- a/http/client.go +++ b/http/client.go @@ -48,6 +48,7 @@ func (c *Client) Stream(ctx context.Context, pos litestream.Pos) (litestream.Str if !pos.IsZero() { q.Set("generation", pos.Generation) q.Set("index", litestream.FormatIndex(pos.Index)) + q.Set("offset", litestream.FormatOffset(pos.Offset)) } // Strip off everything but the scheme & host. diff --git a/http/server.go b/http/server.go index 702b2d6..f8899a2 100644 --- a/http/server.go +++ b/http/server.go @@ -134,16 +134,31 @@ func (s *Server) handleGetStream(w http.ResponseWriter, r *http.Request) { return } + generationStr := q.Get("generation") + indexStr := q.Get("index") + offsetStr := q.Get("offset") + // Parse current client position, if available. var pos litestream.Pos - if generation, index := q.Get("generation"), q.Get("index"); generation != "" && index != "" { - pos.Generation = generation + if generationStr != "" && indexStr != "" && offsetStr != "" { - var err error - if pos.Index, err = litestream.ParseIndex(index); err != nil { + index, err := litestream.ParseIndex(indexStr) + if err != nil { s.writeError(w, r, "Invalid index query parameter", http.StatusBadRequest) return } + + offset, err := litestream.ParseOffset(offsetStr) + if err != nil { + s.writeError(w, r, "Invalid offset query parameter", http.StatusBadRequest) + return + } + + pos = litestream.Pos{ + Generation: generationStr, + Index: index, + Offset: offset, + } } // Fetch database instance from the primary server. @@ -162,7 +177,6 @@ func (s *Server) handleGetStream(w http.ResponseWriter, r *http.Request) { s.writeError(w, r, "No generation available", http.StatusServiceUnavailable) return } - dbPos.Offset = 0 // Use database position if generation has changed. var snapshotRequired bool @@ -170,6 +184,7 @@ func (s *Server) handleGetStream(w http.ResponseWriter, r *http.Request) { s.Logger.Printf("stream generation mismatch, using primary position: client.pos=%s", pos) pos, snapshotRequired = dbPos, true } + pos.Offset = 0 // Obtain iterator before snapshot so we don't miss any WAL segments. fitr, err := db.WALSegments(r.Context(), pos.Generation) diff --git a/integration/cmd_test.go b/integration/cmd_test.go index 62ee336..eb13f7e 100644 --- a/integration/cmd_test.go +++ b/integration/cmd_test.go @@ -481,7 +481,7 @@ func TestCmd_Replicate_HTTP_PartialRecovery(t *testing.T) { t.Fatal(err) } else if _, err := db0.ExecContext(ctx, `PRAGMA journal_mode = wal`); err != nil { t.Fatal(err) - } else if _, err := db0.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil { + } else if _, err := db0.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY, DATA TEXT)`); err != nil { t.Fatal(err) } defer db0.Close() @@ -489,8 +489,8 @@ func TestCmd_Replicate_HTTP_PartialRecovery(t *testing.T) { var index int insertAndWait := func() { index++ - t.Logf("[exec] INSERT INTO t (id) VALUES (%d)", index) - if _, err := db0.ExecContext(ctx, `INSERT INTO t (id) VALUES (?)`, index); err != nil { + t.Logf("[exec] INSERT INTO t (id, data) VALUES (%d, '...')", index) + if _, err := db0.ExecContext(ctx, `INSERT INTO t (id, data) VALUES (?, ?)`, index, strings.Repeat("x", 512)); err != nil { t.Fatal(err) } time.Sleep(100 * time.Millisecond)