Add salt & checksum checking
This commit is contained in:
425
db.go
425
db.go
@@ -1,8 +1,10 @@
|
||||
package litestream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -12,13 +14,12 @@ import (
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrNoGeneration = errors.New("litestream: no generation")
|
||||
|
||||
const (
|
||||
MetaDirSuffix = "-litestream"
|
||||
|
||||
@@ -42,6 +43,9 @@ type DB struct {
|
||||
rtx *sql.Tx // long running read transaction
|
||||
pageSize int // page size, in bytes
|
||||
|
||||
byteOrder binary.ByteOrder // determined by WAL header magic
|
||||
salt0, salt1 uint32 // read from WAL header
|
||||
|
||||
ctx context.Context
|
||||
cancel func()
|
||||
wg sync.WaitGroup
|
||||
@@ -244,11 +248,11 @@ func (db *DB) releaseReadLock() error {
|
||||
}
|
||||
|
||||
// CurrentGeneration returns the name of the generation saved to the "generation"
|
||||
// file in the meta data directory. Returns ErrNoGeneration if none exists.
|
||||
// file in the meta data directory. Returns empty string if none exists.
|
||||
func (db *DB) CurrentGeneration() (string, error) {
|
||||
buf, err := ioutil.ReadFile(db.GenerationNamePath())
|
||||
if os.IsNotExist(err) {
|
||||
return "", ErrNoGeneration
|
||||
return "", nil
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -257,7 +261,7 @@ func (db *DB) CurrentGeneration() (string, error) {
|
||||
|
||||
generation := strings.TrimSpace(string(buf))
|
||||
if len(generation) != GenerationNameLen {
|
||||
return "", ErrNoGeneration
|
||||
return "", nil
|
||||
}
|
||||
return generation, nil
|
||||
}
|
||||
@@ -277,8 +281,8 @@ func (db *DB) createGeneration() (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Copy to shadow WAL.
|
||||
if err := db.copyInitialWAL(generation); err != nil {
|
||||
// Initialize shadow WAL with copy of header.
|
||||
if err := db.initShadowWALFile(db.ShadowWALPath(generation, 0)); err != nil {
|
||||
return "", fmt.Errorf("copy initial wal: %w", err)
|
||||
}
|
||||
|
||||
@@ -300,42 +304,12 @@ func (db *DB) createGeneration() (string, error) {
|
||||
return generation, nil
|
||||
}
|
||||
|
||||
// copyInitialWAL copies the full WAL file to the initial shadow WAL path.
|
||||
func (db *DB) copyInitialWAL(generation string) error {
|
||||
shadowWALPath := db.ShadowWALPath(generation, 0)
|
||||
if err := os.MkdirAll(filepath.Dir(shadowWALPath), 0700); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Open the initial shadow WAL file for writing.
|
||||
w, err := os.OpenFile(shadowWALPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
// Open the database's WAL file for reading.
|
||||
r, err := os.Open(db.WALPath())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
// Copy & sync.
|
||||
if _, err := io.Copy(w, r); err != nil {
|
||||
return err
|
||||
} else if err := w.Sync(); err != nil {
|
||||
return err
|
||||
} else if err := w.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sync copies pending data from the WAL to the shadow WAL.
|
||||
func (db *DB) Sync() (err error) {
|
||||
// TODO: Lock DB while syncing?
|
||||
|
||||
// TODO: Force "-wal" file if it doesn't exist.
|
||||
|
||||
// Start a transaction. This will be promoted immediately after.
|
||||
tx, err := db.db.Begin()
|
||||
if err != nil {
|
||||
@@ -362,18 +336,29 @@ func (db *DB) Sync() (err error) {
|
||||
return fmt.Errorf("disable autocheckpoint: %w", err)
|
||||
}
|
||||
|
||||
// Look up existing generation or start a new one.
|
||||
generation, err := db.CurrentGeneration()
|
||||
if err == ErrNoGeneration {
|
||||
if generation, err = db.createGeneration(); err != nil {
|
||||
// Verify our last sync matches the current state of the WAL.
|
||||
// This ensures that we have an existing generation & that the last sync
|
||||
// position of the real WAL hasn't been overwritten by another process.
|
||||
//
|
||||
// If we are unable to verify the WAL state then we start a new generation.
|
||||
info, err := db.verifyWAL()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot verify wal state: %w", err)
|
||||
} else if info.reason != "" {
|
||||
// Start new generation & notify user via log message.
|
||||
if info.generation, err = db.createGeneration(); err != nil {
|
||||
return fmt.Errorf("create generation: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("cannot find current generation: %w", err)
|
||||
log.Printf("%s: new generation %q, %s", db.path, info.generation, info.reason)
|
||||
|
||||
// Clear shadow wal info.
|
||||
info.shadowWALPath, info.shadowWALSize = "", 0
|
||||
info.restart = false
|
||||
info.reason = ""
|
||||
}
|
||||
|
||||
// Synchronize real WAL with current shadow WAL.
|
||||
newWALSize, err := db.syncWAL(generation)
|
||||
newWALSize, err := db.syncWAL(info)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sync wal: %w", err)
|
||||
}
|
||||
@@ -402,6 +387,245 @@ func (db *DB) Sync() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyWAL ensures the current shadow WAL state matches where it left off from
|
||||
// the real WAL. Returns generation & WAL sync information. If info.reason is
|
||||
// not blank, verification failed and a new generation should be started.
|
||||
func (db *DB) verifyWAL() (info syncInfo, err error) {
|
||||
// Look up existing generation.
|
||||
generation, err := db.CurrentGeneration()
|
||||
if err != nil {
|
||||
return info, fmt.Errorf("cannot find current generation: %w", err)
|
||||
} else if generation == "" {
|
||||
info.reason = "no generation"
|
||||
return info, nil
|
||||
}
|
||||
info.generation = generation
|
||||
|
||||
// Determine total bytes of real WAL.
|
||||
fi, err := os.Stat(db.WALPath())
|
||||
if err != nil {
|
||||
return info, err
|
||||
}
|
||||
info.walSize = fi.Size()
|
||||
|
||||
// Open shadow WAL to copy append to.
|
||||
info.shadowWALPath, err = db.CurrentShadowWALPath(info.generation)
|
||||
if info.shadowWALPath == "" {
|
||||
info.reason = "no shadow wal"
|
||||
return info, nil
|
||||
} else if err != nil {
|
||||
return info, fmt.Errorf("cannot determine shadow WAL: %w", err)
|
||||
}
|
||||
|
||||
// Determine shadow WAL current size.
|
||||
fi, err = os.Stat(info.shadowWALPath)
|
||||
if err != nil {
|
||||
return info, err
|
||||
}
|
||||
info.shadowWALSize = fi.Size()
|
||||
|
||||
// TODO: Truncate shadow WAL if there is a partial page.
|
||||
|
||||
// If shadow WAL is larger than real WAL then the WAL has been truncated
|
||||
// so we cannot determine our last state.
|
||||
if info.shadowWALSize > info.walSize {
|
||||
info.reason = "wal truncated by another process"
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// Compare WAL headers. If mismatched then the real WAL has been restarted.
|
||||
// We'll need to start a new shadow WAL during the sync.
|
||||
if hdr0, err := readWALHeader(db.WALPath()); err != nil {
|
||||
return info, fmt.Errorf("cannot read wal header: %w", err)
|
||||
} else if hdr1, err := readWALHeader(info.shadowWALPath); err != nil {
|
||||
return info, fmt.Errorf("cannot read shadow wal header: %w", err)
|
||||
} else {
|
||||
info.restart = !bytes.Equal(hdr0, hdr1)
|
||||
}
|
||||
|
||||
// Verify last page synced still matches.
|
||||
offset := info.shadowWALSize - int64(db.pageSize+WALFrameHeaderSize)
|
||||
if buf0, err := readFileAt(db.WALPath(), offset, int64(db.pageSize+WALFrameHeaderSize)); err != nil {
|
||||
return info, fmt.Errorf("cannot read last synced wal page: %w", err)
|
||||
} else if buf1, err := readFileAt(info.shadowWALPath, offset, int64(db.pageSize+WALFrameHeaderSize)); err != nil {
|
||||
return info, fmt.Errorf("cannot read last synced shadow wal page: %w", err)
|
||||
} else if !bytes.Equal(buf0, buf1) {
|
||||
info.reason = "wal overwritten by another process"
|
||||
return info, nil
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// syncInfo carries the state gathered by verifyWAL that syncWAL needs in order
// to continue copying from the real WAL into the shadow WAL.
type syncInfo struct {
	// generation is the generation name.
	generation string

	// walSize is the size of the real WAL file.
	walSize int64

	// shadowWALPath is the name of the last shadow WAL file.
	shadowWALPath string

	// shadowWALSize is the size of the last shadow WAL file.
	shadowWALSize int64

	// restart is true if the real WAL header does not match the shadow WAL.
	restart bool

	// reason, if non-blank, is the reason for sync failure.
	reason string
}
|
||||
|
||||
// syncWAL copies pending bytes from the real WAL to the shadow WAL.
|
||||
func (db *DB) syncWAL(info syncInfo) (newSize int64, err error) {
|
||||
// Copy WAL starting from end of shadow WAL. Exit if no new shadow WAL needed.
|
||||
newSize, err = db.copyToShadowWAL(info.shadowWALPath)
|
||||
if err != nil {
|
||||
return newSize, fmt.Errorf("cannot copy to shadow wal: %w", err)
|
||||
} else if !info.restart {
|
||||
return newSize, nil // If no restart required, exit.
|
||||
}
|
||||
|
||||
// Parse index of current shadow WAL file.
|
||||
dir, base := filepath.Split(info.shadowWALPath)
|
||||
index, err := ParseWALFilename(base)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("cannot parse shadow wal filename: %s", base)
|
||||
}
|
||||
|
||||
// Start a new shadow WAL file with next index.
|
||||
newShadowWALPath := filepath.Join(dir, FormatWALFilename(index+1))
|
||||
if err := db.initShadowWALFile(newShadowWALPath); err != nil {
|
||||
return 0, fmt.Errorf("cannot init shadow wal file: name=%s err=%w", newShadowWALPath, err)
|
||||
}
|
||||
|
||||
// Copy rest of valid WAL to new shadow WAL.
|
||||
newSize, err = db.copyToShadowWAL(newShadowWALPath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("cannot copy to new shadow wal: %w", err)
|
||||
}
|
||||
return newSize, nil
|
||||
}
|
||||
|
||||
func (db *DB) initShadowWALFile(filename string) error {
|
||||
hdr, err := readWALHeader(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read header: %w", err)
|
||||
}
|
||||
|
||||
// Determine byte order for checksumming from header magic.
|
||||
magic := binary.BigEndian.Uint32(hdr[0:])
|
||||
switch magic {
|
||||
case 0x377f0682:
|
||||
db.byteOrder = binary.LittleEndian
|
||||
case 0x377f0683:
|
||||
db.byteOrder = binary.BigEndian
|
||||
default:
|
||||
return fmt.Errorf("invalid wal header magic: %x", magic)
|
||||
}
|
||||
|
||||
// Read header salt.
|
||||
db.salt0 = binary.BigEndian.Uint32(hdr[16:])
|
||||
db.salt1 = binary.BigEndian.Uint32(hdr[20:])
|
||||
|
||||
// Verify checksum.
|
||||
s0 := binary.BigEndian.Uint32(hdr[24:])
|
||||
s1 := binary.BigEndian.Uint32(hdr[28:])
|
||||
if v0, v1 := Checksum(db.byteOrder, 0, 0, hdr[:24]); v0 != s0 || v1 != s1 {
|
||||
return fmt.Errorf("invalid header checksum: (%x,%x) != (%x,%x)", v0, v1, s0, s1)
|
||||
}
|
||||
|
||||
// Write header to new WAL shadow file.
|
||||
return ioutil.WriteFile(filename, hdr, 0600)
|
||||
}
|
||||
|
||||
func (db *DB) copyToShadowWAL(filename string) (newSize int64, err error) {
|
||||
r, err := os.Open(db.WALPath())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
w, err := os.OpenFile(filename, os.O_RDWR, 0600)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
fi, err := w.Stat()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Read previous checksum.
|
||||
chksum0, chksum1, err := readLastChecksumFrom(w, db.pageSize)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("last checksum: %w", err)
|
||||
}
|
||||
|
||||
// Seek to correct position on both files.
|
||||
if _, err := r.Seek(fi.Size(), io.SeekStart); err != nil {
|
||||
return 0, fmt.Errorf("wal seek: %w", err)
|
||||
} else if _, err := w.Seek(fi.Size(), io.SeekStart); err != nil {
|
||||
return 0, fmt.Errorf("shadow wal seek: %w", err)
|
||||
}
|
||||
|
||||
// TODO: Optimize to use bufio on reader & writer to minimize syscalls.
|
||||
|
||||
// Loop over each page, verify checksum, & copy to writer.
|
||||
newSize = fi.Size()
|
||||
buf := make([]byte, db.pageSize+WALFrameHeaderSize)
|
||||
for {
|
||||
// Read next page from WAL file.
|
||||
if _, err := io.ReadFull(r, buf); err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
break // end of file or partial page
|
||||
} else if err != nil {
|
||||
return newSize, fmt.Errorf("read wal: %w", err)
|
||||
}
|
||||
|
||||
// Read frame salt & compare to header salt. Stop reading on mismatch.
|
||||
salt0 := binary.BigEndian.Uint32(buf[8:])
|
||||
salt1 := binary.BigEndian.Uint32(buf[12:])
|
||||
if salt0 != db.salt0 || salt1 != db.salt1 {
|
||||
break
|
||||
}
|
||||
|
||||
// Verify checksum of page is valid.
|
||||
fchksum0 := binary.BigEndian.Uint32(buf[16:])
|
||||
fchksum1 := binary.BigEndian.Uint32(buf[20:])
|
||||
chksum0, chksum1 = Checksum(db.byteOrder, chksum0, chksum1, buf[:8]) // frame header
|
||||
chksum0, chksum1 = Checksum(db.byteOrder, chksum0, chksum1, buf[24:]) // frame data
|
||||
if chksum0 != fchksum0 || chksum1 != fchksum1 {
|
||||
return newSize, fmt.Errorf("checksum mismatch: offset=%d (%x,%x) != (%x,%x)", newSize, chksum0, chksum1, fchksum0, fchksum1)
|
||||
}
|
||||
|
||||
// Add page to the new size of the shadow WAL.
|
||||
newSize += int64(len(buf))
|
||||
}
|
||||
|
||||
// Sync & close writer.
|
||||
if err := w.Sync(); err != nil {
|
||||
return newSize, err
|
||||
} else if err := w.Close(); err != nil {
|
||||
return newSize, err
|
||||
}
|
||||
|
||||
return newSize, nil
|
||||
}
|
||||
|
||||
// Byte offsets at which readLastChecksumFrom reads the 8-byte checksum.
const (
	// WALHeaderChecksumOffset is the checksum offset within the WAL header.
	WALHeaderChecksumOffset = 24

	// WALFrameHeaderChecksumOffset is the checksum offset within a frame header.
	WALFrameHeaderChecksumOffset = 16
)
|
||||
|
||||
func readLastChecksumFrom(f *os.File, pageSize int) (uint32, uint32, error) {
|
||||
// Determine the byte offset of the checksum for the header (if no pages
|
||||
// exist) or for the last page (if at least one page exists).
|
||||
offset := int64(WALHeaderChecksumOffset)
|
||||
if fi, err := f.Stat(); err != nil {
|
||||
return 0, 0, err
|
||||
} else if fi.Size() > WALHeaderSize {
|
||||
offset = fi.Size() - int64(pageSize) - WALFrameHeaderSize + WALFrameHeaderChecksumOffset
|
||||
}
|
||||
|
||||
// Read big endian checksum.
|
||||
b := make([]byte, 8)
|
||||
if n, err := f.ReadAt(b, offset); err != nil {
|
||||
return 0, 0, err
|
||||
} else if n != len(b) {
|
||||
return 0, 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
return binary.BigEndian.Uint32(b[0:]), binary.BigEndian.Uint32(b[4:]), nil
|
||||
}
|
||||
|
||||
// checkpoint performs a checkpoint on the WAL file.
|
||||
func (db *DB) checkpoint(force bool) error {
|
||||
// Ensure the read lock has been removed before issuing a checkpoint.
|
||||
@@ -433,72 +657,6 @@ func (db *DB) checkpoint(force bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncWAL copies pending bytes from the real WAL to the shadow WAL.
|
||||
func (db *DB) syncWAL(generation string) (newSize int64, err error) {
|
||||
// Determine total bytes of real WAL.
|
||||
fi, err := os.Stat(db.WALPath())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
walSize := fi.Size()
|
||||
|
||||
// Open shadow WAL to copy append to.
|
||||
shadowWALPath, err := db.CurrentShadowWALPath(generation)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("cannot determine shadow WAL: %w", err)
|
||||
}
|
||||
|
||||
// TODO: Compare WAL headers.
|
||||
|
||||
// Determine shadow WAL current size.
|
||||
fi, err = os.Stat(shadowWALPath)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
shadowWALSize := fi.Size()
|
||||
|
||||
// Ensure we have pending bytes to write.
|
||||
// TODO: Verify pending bytes is divisble by (pageSize+headerSize)?
|
||||
pendingN := walSize - shadowWALSize
|
||||
if pendingN < 0 {
|
||||
panic("shadow wal larger than real wal") // TODO: Handle gracefully
|
||||
} else if pendingN == 0 {
|
||||
return shadowWALSize, nil // wals match, exit
|
||||
}
|
||||
|
||||
// TODO: Verify last page copied matches.
|
||||
|
||||
// Open handles for the shadow WAL & real WAL.
|
||||
w, err := os.OpenFile(shadowWALPath, os.O_RDWR, 0600)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
r, err := os.Open(db.WALPath())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
// Seek to the correct position for each file.
|
||||
if _, err := r.Seek(shadowWALSize, io.SeekStart); err != nil {
|
||||
return 0, fmt.Errorf("wal seek: %w", err)
|
||||
} else if _, err := w.Seek(shadowWALSize, io.SeekStart); err != nil {
|
||||
return 0, fmt.Errorf("shadow wal seek: %w", err)
|
||||
}
|
||||
|
||||
// Copy and sync.
|
||||
if _, err := io.CopyN(w, r, pendingN); err != nil {
|
||||
return 0, fmt.Errorf("copy shadow wal error: %w", err)
|
||||
} else if err := w.Sync(); err != nil {
|
||||
return 0, fmt.Errorf("shadow wal sync: %w", err)
|
||||
} else if err := w.Close(); err != nil {
|
||||
return 0, fmt.Errorf("shadow wal close: %w", err)
|
||||
}
|
||||
return walSize, nil
|
||||
}
|
||||
|
||||
// monitor runs in a separate goroutine and monitors the database & WAL.
|
||||
func (db *DB) monitor() {
|
||||
ticker := time.NewTicker(db.MonitorInterval)
|
||||
@@ -543,3 +701,46 @@ func rollback(tx *sql.Tx) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readWALHeader returns the header read from a WAL file.
|
||||
func readWALHeader(filename string) ([]byte, error) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
buf := make([]byte, WALHeaderSize)
|
||||
n, err := io.ReadFull(f, buf)
|
||||
return buf[:n], err
|
||||
}
|
||||
|
||||
// readFileAt opens filename and reads n bytes starting at offset. If fewer
// than n bytes could be read, the bytes obtained are returned along with the
// read error, or io.ErrUnexpectedEOF when the read reported no error.
func readFileAt(filename string, offset, n int64) ([]byte, error) {
	file, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	data := make([]byte, n)
	read, err := file.ReadAt(data, offset)
	switch {
	case err != nil:
		return data[:read], err
	case read < len(data):
		return data[:read], io.ErrUnexpectedEOF
	default:
		return data, nil
	}
}
|
||||
|
||||
func ParseWALFilename(name string) (index int, err error) {
|
||||
v, err := strconv.ParseInt(strings.TrimSuffix(name, WALExt), 16, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid wal filename: %q", name)
|
||||
}
|
||||
return int(v), nil
|
||||
}
|
||||
|
||||
func FormatWALFilename(index int) string {
|
||||
assert(index >= 0, "wal index must be non-negative")
|
||||
return fmt.Sprintf("%016d%s", index, WALExt)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,25 @@
|
||||
package litestream
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// Checksum computes a running SQLite checksum over a byte slice.
|
||||
func Checksum(bo binary.ByteOrder, s0, s1 uint32, b []byte) (uint32, uint32) {
|
||||
assert(len(b)%8 == 0, "misaligned checksum byte slice")
|
||||
|
||||
// Iterate over 8-byte units and compute checksum.
|
||||
for i := 0; i < len(b); i += 8 {
|
||||
s0 += bo.Uint32(b[i:]) + s1
|
||||
s1 += bo.Uint32(b[i+4:]) + s0
|
||||
}
|
||||
return s0, s1
|
||||
}
|
||||
|
||||
// HexDump returns hexdump output but with duplicate lines removed.
|
||||
func HexDump(b []byte) string {
|
||||
const prefixN = len("00000000")
|
||||
|
||||
Reference in New Issue
Block a user