diff --git a/db.go b/db.go index ad9b3eb..ebddae4 100644 --- a/db.go +++ b/db.go @@ -1,8 +1,10 @@ package litestream import ( + "bytes" "context" "database/sql" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -12,13 +14,12 @@ import ( "math/rand" "os" "path/filepath" + "strconv" "strings" "sync" "time" ) -var ErrNoGeneration = errors.New("litestream: no generation") - const ( MetaDirSuffix = "-litestream" @@ -42,6 +43,9 @@ type DB struct { rtx *sql.Tx // long running read transaction pageSize int // page size, in bytes + byteOrder binary.ByteOrder // determined by WAL header magic + salt0, salt1 uint32 // read from WAL header + ctx context.Context cancel func() wg sync.WaitGroup @@ -244,11 +248,11 @@ func (db *DB) releaseReadLock() error { } // CurrentGeneration returns the name of the generation saved to the "generation" -// file in the meta data directory. Returns ErrNoGeneration if none exists. +// file in the meta data directory. Returns empty string if none exists. func (db *DB) CurrentGeneration() (string, error) { buf, err := ioutil.ReadFile(db.GenerationNamePath()) if os.IsNotExist(err) { - return "", ErrNoGeneration + return "", nil } else if err != nil { return "", err } @@ -257,7 +261,7 @@ func (db *DB) CurrentGeneration() (string, error) { generation := strings.TrimSpace(string(buf)) if len(generation) != GenerationNameLen { - return "", ErrNoGeneration + return "", nil } return generation, nil } @@ -277,8 +281,8 @@ func (db *DB) createGeneration() (string, error) { return "", err } - // Copy to shadow WAL. - if err := db.copyInitialWAL(generation); err != nil { + // Initialize shadow WAL with copy of header. + if err := db.initShadowWALFile(db.ShadowWALPath(generation, 0)); err != nil { return "", fmt.Errorf("copy initial wal: %w", err) } @@ -300,42 +304,12 @@ func (db *DB) createGeneration() (string, error) { return generation, nil } -// copyInitialWAL copies the full WAL file to the initial shadow WAL path. -func (db *DB) copyInitialWAL(generation string) error { - shadowWALPath := db.ShadowWALPath(generation, 0) - if err := os.MkdirAll(filepath.Dir(shadowWALPath), 0700); err != nil { - return err - } - - // Open the initial shadow WAL file for writing. - w, err := os.OpenFile(shadowWALPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - return err - } - defer w.Close() - - // Open the database's WAL file for reading. - r, err := os.Open(db.WALPath()) - if err != nil { - return err - } - defer r.Close() - - // Copy & sync. - if _, err := io.Copy(w, r); err != nil { - return err - } else if err := w.Sync(); err != nil { - return err - } else if err := w.Close(); err != nil { - return err - } - return nil -} - // Sync copies pending data from the WAL to the shadow WAL. func (db *DB) Sync() (err error) { // TODO: Lock DB while syncing? + // TODO: Force "-wal" file if it doesn't exist. + // Start a transaction. This will be promoted immediately after. tx, err := db.db.Begin() if err != nil { @@ -362,18 +336,29 @@ func (db *DB) Sync() (err error) { return fmt.Errorf("disable autocheckpoint: %w", err) } - // Look up existing generation or start a new one. - generation, err := db.CurrentGeneration() - if err == ErrNoGeneration { - if generation, err = db.createGeneration(); err != nil { + // Verify our last sync matches the current state of the WAL. + // This ensures that we have an existing generation & that the last sync + // position of the real WAL hasn't been overwritten by another process. + // + // If we are unable to verify the WAL state then we start a new generation. + info, err := db.verifyWAL() + if err != nil { + return fmt.Errorf("cannot verify wal state: %w", err) + } else if info.reason != "" { + // Start new generation & notify user via log message. + if info.generation, err = db.createGeneration(); err != nil { return fmt.Errorf("create generation: %w", err) } - } else if err != nil { - return fmt.Errorf("cannot find current generation: %w", err) + log.Printf("%s: new generation %q, %s", db.path, info.generation, info.reason) + + // Clear shadow wal info. + info.shadowWALPath, info.shadowWALSize = "", 0 + info.restart = false + info.reason = "" } // Synchronize real WAL with current shadow WAL. - newWALSize, err := db.syncWAL(generation) + newWALSize, err := db.syncWAL(info) if err != nil { return fmt.Errorf("sync wal: %w", err) } @@ -402,6 +387,245 @@ func (db *DB) Sync() (err error) { return nil } +// verifyWAL ensures the current shadow WAL state matches where it left off from +// the real WAL. Returns generation & WAL sync information. If info.reason is +// not blank, verification failed and a new generation should be started. +func (db *DB) verifyWAL() (info syncInfo, err error) { + // Look up existing generation. + generation, err := db.CurrentGeneration() + if err != nil { + return info, fmt.Errorf("cannot find current generation: %w", err) + } else if generation == "" { + info.reason = "no generation" + return info, nil + } + info.generation = generation + + // Determine total bytes of real WAL. + fi, err := os.Stat(db.WALPath()) + if err != nil { + return info, err + } + info.walSize = fi.Size() + + // Open shadow WAL to copy append to. + info.shadowWALPath, err = db.CurrentShadowWALPath(info.generation) + if info.shadowWALPath == "" { + info.reason = "no shadow wal" + return info, nil + } else if err != nil { + return info, fmt.Errorf("cannot determine shadow WAL: %w", err) + } + + // Determine shadow WAL current size. + fi, err = os.Stat(info.shadowWALPath) + if err != nil { + return info, err + } + info.shadowWALSize = fi.Size() + + // TODO: Truncate shadow WAL if there is a partial page. + + // If shadow WAL is larger than real WAL then the WAL has been truncated + // so we cannot determine our last state. + if info.shadowWALSize > info.walSize { + info.reason = "wal truncated by another process" + return info, nil + } + + // Compare WAL headers. If mismatched then the real WAL has been restarted. + // We'll need to start a new shadow WAL during the sync. + if hdr0, err := readWALHeader(db.WALPath()); err != nil { + return info, fmt.Errorf("cannot read wal header: %w", err) + } else if hdr1, err := readWALHeader(info.shadowWALPath); err != nil { + return info, fmt.Errorf("cannot read shadow wal header: %w", err) + } else { + info.restart = !bytes.Equal(hdr0, hdr1) + } + + // Verify last page synced still matches. + offset := info.shadowWALSize - int64(db.pageSize+WALFrameHeaderSize) + if buf0, err := readFileAt(db.WALPath(), offset, int64(db.pageSize+WALFrameHeaderSize)); err != nil { + return info, fmt.Errorf("cannot read last synced wal page: %w", err) + } else if buf1, err := readFileAt(info.shadowWALPath, offset, int64(db.pageSize+WALFrameHeaderSize)); err != nil { + return info, fmt.Errorf("cannot read last synced shadow wal page: %w", err) + } else if !bytes.Equal(buf0, buf1) { + info.reason = "wal overwritten by another process" + return info, nil + } + + return info, nil +} + +type syncInfo struct { + generation string // generation name + walSize int64 // size of real WAL file + shadowWALPath string // name of last shadow WAL file + shadowWALSize int64 // size of last shadow WAL file + restart bool // if true, real WAL header does not match shadow WAL + reason string // if non-blank, reason for sync failure +} + +// syncWAL copies pending bytes from the real WAL to the shadow WAL. +func (db *DB) syncWAL(info syncInfo) (newSize int64, err error) { + // Copy WAL starting from end of shadow WAL. Exit if no new shadow WAL needed. + newSize, err = db.copyToShadowWAL(info.shadowWALPath) + if err != nil { + return newSize, fmt.Errorf("cannot copy to shadow wal: %w", err) + } else if !info.restart { + return newSize, nil // If no restart required, exit. + } + + // Parse index of current shadow WAL file. + dir, base := filepath.Split(info.shadowWALPath) + index, err := ParseWALFilename(base) + if err != nil { + return 0, fmt.Errorf("cannot parse shadow wal filename: %s", base) + } + + // Start a new shadow WAL file with next index. + newShadowWALPath := filepath.Join(dir, FormatWALFilename(index+1)) + if err := db.initShadowWALFile(newShadowWALPath); err != nil { + return 0, fmt.Errorf("cannot init shadow wal file: name=%s err=%w", newShadowWALPath, err) + } + + // Copy rest of valid WAL to new shadow WAL. + newSize, err = db.copyToShadowWAL(newShadowWALPath) + if err != nil { + return 0, fmt.Errorf("cannot copy to new shadow wal: %w", err) + } + return newSize, nil +} + +func (db *DB) initShadowWALFile(filename string) error { + hdr, err := readWALHeader(filename) + if err != nil { + return fmt.Errorf("read header: %w", err) + } + + // Determine byte order for checksumming from header magic. + magic := binary.BigEndian.Uint32(hdr[0:]) + switch magic { + case 0x377f0682: + db.byteOrder = binary.LittleEndian + case 0x377f0683: + db.byteOrder = binary.BigEndian + default: + return fmt.Errorf("invalid wal header magic: %x", magic) + } + + // Read header salt. + db.salt0 = binary.BigEndian.Uint32(hdr[16:]) + db.salt1 = binary.BigEndian.Uint32(hdr[20:]) + + // Verify checksum. + s0 := binary.BigEndian.Uint32(hdr[24:]) + s1 := binary.BigEndian.Uint32(hdr[28:]) + if v0, v1 := Checksum(db.byteOrder, 0, 0, hdr[:24]); v0 != s0 || v1 != s1 { + return fmt.Errorf("invalid header checksum: (%x,%x) != (%x,%x)", v0, v1, s0, s1) + } + + // Write header to new WAL shadow file. + return ioutil.WriteFile(filename, hdr, 0600) +} + +func (db *DB) copyToShadowWAL(filename string) (newSize int64, err error) { + r, err := os.Open(db.WALPath()) + if err != nil { + return 0, err + } + defer r.Close() + + w, err := os.OpenFile(filename, os.O_RDWR, 0600) + if err != nil { + return 0, err + } + defer w.Close() + + fi, err := w.Stat() + if err != nil { + return 0, err + } + + // Read previous checksum. + chksum0, chksum1, err := readLastChecksumFrom(w, db.pageSize) + if err != nil { + return 0, fmt.Errorf("last checksum: %w", err) + } + + // Seek to correct position on both files. + if _, err := r.Seek(fi.Size(), io.SeekStart); err != nil { + return 0, fmt.Errorf("wal seek: %w", err) + } else if _, err := w.Seek(fi.Size(), io.SeekStart); err != nil { + return 0, fmt.Errorf("shadow wal seek: %w", err) + } + + // TODO: Optimize to use bufio on reader & writer to minimize syscalls. + + // Loop over each page, verify checksum, & copy to writer. + newSize = fi.Size() + buf := make([]byte, db.pageSize+WALFrameHeaderSize) + for { + // Read next page from WAL file. + if _, err := io.ReadFull(r, buf); err == io.EOF || err == io.ErrUnexpectedEOF { + break // end of file or partial page + } else if err != nil { + return newSize, fmt.Errorf("read wal: %w", err) + } + + // Read frame salt & compare to header salt. Stop reading on mismatch. + salt0 := binary.BigEndian.Uint32(buf[8:]) + salt1 := binary.BigEndian.Uint32(buf[12:]) + if salt0 != db.salt0 || salt1 != db.salt1 { + break + } + + // Verify checksum of page is valid. + fchksum0 := binary.BigEndian.Uint32(buf[16:]) + fchksum1 := binary.BigEndian.Uint32(buf[20:]) + chksum0, chksum1 = Checksum(db.byteOrder, chksum0, chksum1, buf[:8]) // frame header + chksum0, chksum1 = Checksum(db.byteOrder, chksum0, chksum1, buf[24:]) // frame data + if chksum0 != fchksum0 || chksum1 != fchksum1 { + return newSize, fmt.Errorf("checksum mismatch: offset=%d (%x,%x) != (%x,%x)", newSize, chksum0, chksum1, fchksum0, fchksum1) + } + + // Add page to the new size of the shadow WAL. + newSize += int64(len(buf)) + } + + // Sync & close writer. + if err := w.Sync(); err != nil { + return newSize, err + } else if err := w.Close(); err != nil { + return newSize, err + } + + return newSize, nil +} + +const WALHeaderChecksumOffset = 24 +const WALFrameHeaderChecksumOffset = 16 + +func readLastChecksumFrom(f *os.File, pageSize int) (uint32, uint32, error) { + // Determine the byte offset of the checksum for the header (if no pages + // exist) or for the last page (if at least one page exists). + offset := int64(WALHeaderChecksumOffset) + if fi, err := f.Stat(); err != nil { + return 0, 0, err + } else if fi.Size() > WALHeaderSize { + offset = fi.Size() - int64(pageSize) - WALFrameHeaderSize + WALFrameHeaderChecksumOffset + } + + // Read big endian checksum. + b := make([]byte, 8) + if n, err := f.ReadAt(b, offset); err != nil { + return 0, 0, err + } else if n != len(b) { + return 0, 0, io.ErrUnexpectedEOF + } + return binary.BigEndian.Uint32(b[0:]), binary.BigEndian.Uint32(b[4:]), nil +} + // checkpoint performs a checkpoint on the WAL file. func (db *DB) checkpoint(force bool) error { // Ensure the read lock has been removed before issuing a checkpoint. @@ -433,72 +657,6 @@ func (db *DB) checkpoint(force bool) error { return nil } -// syncWAL copies pending bytes from the real WAL to the shadow WAL. -func (db *DB) syncWAL(generation string) (newSize int64, err error) { - // Determine total bytes of real WAL. - fi, err := os.Stat(db.WALPath()) - if err != nil { - return 0, err - } - walSize := fi.Size() - - // Open shadow WAL to copy append to. - shadowWALPath, err := db.CurrentShadowWALPath(generation) - if err != nil { - return 0, fmt.Errorf("cannot determine shadow WAL: %w", err) - } - - // TODO: Compare WAL headers. - - // Determine shadow WAL current size. - fi, err = os.Stat(shadowWALPath) - if err != nil { - return 0, err - } - shadowWALSize := fi.Size() - - // Ensure we have pending bytes to write. - // TODO: Verify pending bytes is divisble by (pageSize+headerSize)? - pendingN := walSize - shadowWALSize - if pendingN < 0 { - panic("shadow wal larger than real wal") // TODO: Handle gracefully - } else if pendingN == 0 { - return shadowWALSize, nil // wals match, exit - } - - // TODO: Verify last page copied matches. - - // Open handles for the shadow WAL & real WAL. - w, err := os.OpenFile(shadowWALPath, os.O_RDWR, 0600) - if err != nil { - return 0, err - } - defer w.Close() - - r, err := os.Open(db.WALPath()) - if err != nil { - return 0, err - } - defer r.Close() - - // Seek to the correct position for each file. - if _, err := r.Seek(shadowWALSize, io.SeekStart); err != nil { - return 0, fmt.Errorf("wal seek: %w", err) - } else if _, err := w.Seek(shadowWALSize, io.SeekStart); err != nil { - return 0, fmt.Errorf("shadow wal seek: %w", err) - } - - // Copy and sync. - if _, err := io.CopyN(w, r, pendingN); err != nil { - return 0, fmt.Errorf("copy shadow wal error: %w", err) - } else if err := w.Sync(); err != nil { - return 0, fmt.Errorf("shadow wal sync: %w", err) - } else if err := w.Close(); err != nil { - return 0, fmt.Errorf("shadow wal close: %w", err) - } - return walSize, nil -} - // monitor runs in a separate goroutine and monitors the database & WAL. func (db *DB) monitor() { ticker := time.NewTicker(db.MonitorInterval) @@ -543,3 +701,46 @@ func rollback(tx *sql.Tx) error { } return nil } + +// readWALHeader returns the header read from a WAL file. +func readWALHeader(filename string) ([]byte, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + + buf := make([]byte, WALHeaderSize) + n, err := io.ReadFull(f, buf) + return buf[:n], err +} + +// readFileAt reads a slice from a file. +func readFileAt(filename string, offset, n int64) ([]byte, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + + buf := make([]byte, n) + if n, err := f.ReadAt(buf, offset); err != nil { + return buf[:n], err + } else if n < len(buf) { + return buf[:n], io.ErrUnexpectedEOF + } + return buf, nil +} + +func ParseWALFilename(name string) (index int, err error) { + v, err := strconv.ParseInt(strings.TrimSuffix(name, WALExt), 16, 64) + if err != nil { + return 0, fmt.Errorf("invalid wal filename: %q", name) + } + return int(v), nil +} + +func FormatWALFilename(index int) string { + assert(index >= 0, "wal index must be non-negative") + return fmt.Sprintf("%016d%s", index, WALExt) +} diff --git a/litestream.go b/litestream.go index 7a9bda4..d2c180e 100644 --- a/litestream.go +++ b/litestream.go @@ -1,12 +1,25 @@ package litestream import ( + "encoding/binary" "encoding/hex" "strings" _ "github.com/mattn/go-sqlite3" ) +// Checksum computes a running SQLite checksum over a byte slice. +func Checksum(bo binary.ByteOrder, s0, s1 uint32, b []byte) (uint32, uint32) { + assert(len(b)%8 == 0, "misaligned checksum byte slice") + + // Iterate over 8-byte units and compute checksum. + for i := 0; i < len(b); i += 8 { + s0 += bo.Uint32(b[i:]) + s1 + s1 += bo.Uint32(b[i+4:]) + s0 + } + return s0, s1 +} + // HexDump returns hexdump output but with duplicate lines removed. func HexDump(b []byte) string { const prefixN = len("00000000")