diff --git a/db.go b/db.go index ebddae4..aa1cbe2 100644 --- a/db.go +++ b/db.go @@ -14,21 +14,11 @@ import ( "math/rand" "os" "path/filepath" - "strconv" "strings" "sync" "time" ) -const ( - MetaDirSuffix = "-litestream" - - WALDirName = "wal" - WALExt = ".wal" - - GenerationNameLen = 16 -) - // Default DB settings. const ( DefaultMonitorInterval = 1 * time.Second @@ -152,9 +142,16 @@ func (db *DB) Open() (err error) { return fmt.Errorf("enable wal: %w", err) } + // Create a table to force writes to the WAL when empty. + // There should only ever be one row with id=1. + if _, err := db.db.Exec(`CREATE TABLE IF NOT EXISTS _litestream_seq (id INTEGER PRIMARY KEY, seq INTEGER);`); err != nil { + return fmt.Errorf("create _litestream_seq table: %w", err) + } + // Create a lock table to force write locks during sync. + // The sync write transaction always rolls back so no data should be in this table. if _, err := db.db.Exec(`CREATE TABLE IF NOT EXISTS _litestream_lock (id INTEGER);`); err != nil { - return fmt.Errorf("enable wal: %w", err) + return fmt.Errorf("create _litestream_lock table: %w", err) } // Start a long-running read transaction to prevent other transactions @@ -175,12 +172,43 @@ func (db *DB) Open() (err error) { return err } + // Clean up previous generations. + if err := db.clean(); err != nil { + return fmt.Errorf("clean: %w", err) + } + db.wg.Add(1) go func() { defer db.wg.Done(); db.monitor() }() return nil } +// clean removes old generations. +func (db *DB) clean() error { + generation, err := db.CurrentGeneration() + if err != nil { + return err + } + + dir := filepath.Join(db.MetaPath(), "generations") + fis, err := ioutil.ReadDir(dir) + if err != nil { + return err + } + for _, fi := range fis { + // Skip the current generation. + if filepath.Base(fi.Name()) == generation { + continue + } + + // Delete all other generations. + if err := os.RemoveAll(filepath.Join(dir, fi.Name())); err != nil { + return err + } + } + return nil +} + // Close releases the read lock & closes the database. This method should only // be called by tests as it causes the underlying database to be checkpointed. func (db *DB) Close() (err error) { @@ -283,7 +311,7 @@ func (db *DB) createGeneration() (string, error) { // Initialize shadow WAL with copy of header. if err := db.initShadowWALFile(db.ShadowWALPath(generation, 0)); err != nil { - return "", fmt.Errorf("copy initial wal: %w", err) + return "", fmt.Errorf("initialize shadow wal: %w", err) } // Atomically write generation name as current generation. @@ -310,6 +338,11 @@ func (db *DB) Sync() (err error) { // TODO: Force "-wal" file if it doesn't exist. + // Ensure WAL has at least one frame in it. + if err := db.ensureWALExists(); err != nil { + return fmt.Errorf("ensure wal exists: %w", err) + } + // Start a transaction. This will be promoted immediately after. tx, err := db.db.Begin() if err != nil { @@ -387,6 +420,18 @@ func (db *DB) Sync() (err error) { return nil } +// ensureWALExists checks that the real WAL exists and has a header. +func (db *DB) ensureWALExists() (err error) { + // Exit early if WAL header exists. + if fi, err := os.Stat(db.WALPath()); err == nil && fi.Size() >= WALHeaderSize { + return nil + } + + // Otherwise create transaction that updates the internal litestream table. + _, err = db.db.Exec(`INSERT INTO _litestream_seq (id, seq) VALUES (1, 1) ON CONFLICT (id) DO UPDATE SET seq = seq + 1`) + return err +} + // verifyWAL ensures the current shadow WAL state matches where it left off from // the real WAL. Returns generation & WAL sync information. If info.reason is // not blank, verification failed and a new generation should be started. @@ -498,7 +543,7 @@ func (db *DB) syncWAL(info syncInfo) (newSize int64, err error) { } func (db *DB) initShadowWALFile(filename string) error { - hdr, err := readWALHeader(filename) + hdr, err := readWALHeader(db.WALPath()) if err != nil { return fmt.Errorf("read header: %w", err) } @@ -526,6 +571,9 @@ func (db *DB) initShadowWALFile(filename string) error { } // Write header to new WAL shadow file. + if err := os.MkdirAll(filepath.Dir(filename), 0700); err != nil { + return err + } return ioutil.WriteFile(filename, hdr, 0600) } @@ -680,67 +728,3 @@ func (db *DB) monitor() { } } } - -const ( - // WALHeaderSize is the size of the WAL header, in bytes. - WALHeaderSize = 32 - - // WALFrameHeaderSize is the size of the WAL frame header, in bytes. - WALFrameHeaderSize = 24 -) - -// calcWALSize returns the size of the WAL, in bytes, for a given number of pages. -func calcWALSize(pageSize int, n int) int64 { - return int64(WALHeaderSize + ((WALFrameHeaderSize + pageSize) * n)) -} - -// rollback rolls back tx. Ignores already-rolled-back errors. -func rollback(tx *sql.Tx) error { - if err := tx.Rollback(); err != nil && !strings.Contains(err.Error(), `transaction has already been committed or rolled back`) { - return err - } - return nil -} - -// readWALHeader returns the header read from a WAL file. -func readWALHeader(filename string) ([]byte, error) { - f, err := os.Open(filename) - if err != nil { - return nil, err - } - defer f.Close() - - buf := make([]byte, WALHeaderSize) - n, err := io.ReadFull(f, buf) - return buf[:n], err -} - -// readFileAt reads a slice from a file. -func readFileAt(filename string, offset, n int64) ([]byte, error) { - f, err := os.Open(filename) - if err != nil { - return nil, err - } - defer f.Close() - - buf := make([]byte, n) - if n, err := f.ReadAt(buf, offset); err != nil { - return buf[:n], err - } else if n < len(buf) { - return buf[:n], io.ErrUnexpectedEOF - } - return buf, nil -} - -func ParseWALFilename(name string) (index int, err error) { - v, err := strconv.ParseInt(strings.TrimSuffix(name, WALExt), 16, 64) - if err != nil { - return 0, fmt.Errorf("invalid wal filename: %q", name) - } - return int(v), nil -} - -func FormatWALFilename(index int) string { - assert(index >= 0, "wal index must be non-negative") - return fmt.Sprintf("%016d%s", index, WALExt) -} diff --git a/litestream.go b/litestream.go index d2c180e..06b541f 100644 --- a/litestream.go +++ b/litestream.go @@ -1,13 +1,27 @@ package litestream import ( + "database/sql" "encoding/binary" "encoding/hex" + "fmt" + "io" + "os" + "strconv" "strings" _ "github.com/mattn/go-sqlite3" ) +const ( + MetaDirSuffix = "-litestream" + + WALDirName = "wal" + WALExt = ".wal" + + GenerationNameLen = 16 +) + // Checksum computes a running SQLite checksum over a byte slice. func Checksum(bo binary.ByteOrder, s0, s1 uint32, b []byte) (uint32, uint32) { assert(len(b)%8 == 0, "misaligned checksum byte slice") @@ -20,6 +34,70 @@ func Checksum(bo binary.ByteOrder, s0, s1 uint32, b []byte) (uint32, uint32) { return s0, s1 } +const ( + // WALHeaderSize is the size of the WAL header, in bytes. + WALHeaderSize = 32 + + // WALFrameHeaderSize is the size of the WAL frame header, in bytes. + WALFrameHeaderSize = 24 +) + +// calcWALSize returns the size of the WAL, in bytes, for a given number of pages. +func calcWALSize(pageSize int, n int) int64 { + return int64(WALHeaderSize + ((WALFrameHeaderSize + pageSize) * n)) +} + +// rollback rolls back tx. Ignores already-rolled-back errors. +func rollback(tx *sql.Tx) error { + if err := tx.Rollback(); err != nil && !strings.Contains(err.Error(), `transaction has already been committed or rolled back`) { + return err + } + return nil +} + +// readWALHeader returns the header read from a WAL file. +func readWALHeader(filename string) ([]byte, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + + buf := make([]byte, WALHeaderSize) + n, err := io.ReadFull(f, buf) + return buf[:n], err +} + +// readFileAt reads a slice from a file. +func readFileAt(filename string, offset, n int64) ([]byte, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + + buf := make([]byte, n) + if n, err := f.ReadAt(buf, offset); err != nil { + return buf[:n], err + } else if n < len(buf) { + return buf[:n], io.ErrUnexpectedEOF + } + return buf, nil +} + +func ParseWALFilename(name string) (index int, err error) { + v, err := strconv.ParseInt(strings.TrimSuffix(name, WALExt), 16, 64) + if err != nil { + return 0, fmt.Errorf("invalid wal filename: %q", name) + } + return int(v), nil +} + +func FormatWALFilename(index int) string { + assert(index >= 0, "wal index must be non-negative") + return fmt.Sprintf("%016d%s", index, WALExt) +} + // HexDump returns hexdump output but with duplicate lines removed. func HexDump(b []byte) string { const prefixN = len("00000000")