Add salt & checksum checking

This commit is contained in:
Ben Johnson
2020-12-21 16:59:15 -07:00
parent f4819efbeb
commit 2bbe5d91bf
2 changed files with 326 additions and 112 deletions

425
db.go
View File

@@ -1,8 +1,10 @@
package litestream
import (
"bytes"
"context"
"database/sql"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
@@ -12,13 +14,12 @@ import (
"math/rand"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
)
var ErrNoGeneration = errors.New("litestream: no generation")
const (
MetaDirSuffix = "-litestream"
@@ -42,6 +43,9 @@ type DB struct {
rtx *sql.Tx // long running read transaction
pageSize int // page size, in bytes
byteOrder binary.ByteOrder // determined by WAL header magic
salt0, salt1 uint32 // read from WAL header
ctx context.Context
cancel func()
wg sync.WaitGroup
@@ -244,11 +248,11 @@ func (db *DB) releaseReadLock() error {
}
// CurrentGeneration returns the name of the generation saved to the "generation"
// file in the meta data directory. Returns ErrNoGeneration if none exists.
// file in the meta data directory. Returns empty string if none exists.
func (db *DB) CurrentGeneration() (string, error) {
buf, err := ioutil.ReadFile(db.GenerationNamePath())
if os.IsNotExist(err) {
return "", ErrNoGeneration
return "", nil
} else if err != nil {
return "", err
}
@@ -257,7 +261,7 @@ func (db *DB) CurrentGeneration() (string, error) {
generation := strings.TrimSpace(string(buf))
if len(generation) != GenerationNameLen {
return "", ErrNoGeneration
return "", nil
}
return generation, nil
}
@@ -277,8 +281,8 @@ func (db *DB) createGeneration() (string, error) {
return "", err
}
// Copy to shadow WAL.
if err := db.copyInitialWAL(generation); err != nil {
// Initialize shadow WAL with copy of header.
if err := db.initShadowWALFile(db.ShadowWALPath(generation, 0)); err != nil {
return "", fmt.Errorf("copy initial wal: %w", err)
}
@@ -300,42 +304,12 @@ func (db *DB) createGeneration() (string, error) {
return generation, nil
}
// copyInitialWAL duplicates the database's entire WAL file into the shadow
// WAL file at index zero of the given generation. The destination directory
// is created if needed and any existing destination file is truncated.
func (db *DB) copyInitialWAL(generation string) error {
	dst := db.ShadowWALPath(generation, 0)

	// Make sure the parent directory of the shadow WAL exists.
	if err := os.MkdirAll(filepath.Dir(dst), 0700); err != nil {
		return err
	}

	// Create (or truncate) the destination shadow WAL file.
	out, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	defer out.Close()

	// Open the real WAL as the copy source.
	in, err := os.Open(db.WALPath())
	if err != nil {
		return err
	}
	defer in.Close()

	// Stream the contents, flush them to disk, and close explicitly so that
	// a close failure is reported to the caller.
	if _, err := io.Copy(out, in); err != nil {
		return err
	}
	if err := out.Sync(); err != nil {
		return err
	}
	return out.Close()
}
// Sync copies pending data from the WAL to the shadow WAL.
func (db *DB) Sync() (err error) {
// TODO: Lock DB while syncing?
// TODO: Force "-wal" file if it doesn't exist.
// Start a transaction. This will be promoted immediately after.
tx, err := db.db.Begin()
if err != nil {
@@ -362,18 +336,29 @@ func (db *DB) Sync() (err error) {
return fmt.Errorf("disable autocheckpoint: %w", err)
}
// Look up existing generation or start a new one.
generation, err := db.CurrentGeneration()
if err == ErrNoGeneration {
if generation, err = db.createGeneration(); err != nil {
// Verify our last sync matches the current state of the WAL.
// This ensures that we have an existing generation & that the last sync
// position of the real WAL hasn't been overwritten by another process.
//
// If we are unable to verify the WAL state then we start a new generation.
info, err := db.verifyWAL()
if err != nil {
return fmt.Errorf("cannot verify wal state: %w", err)
} else if info.reason != "" {
// Start new generation & notify user via log message.
if info.generation, err = db.createGeneration(); err != nil {
return fmt.Errorf("create generation: %w", err)
}
} else if err != nil {
return fmt.Errorf("cannot find current generation: %w", err)
log.Printf("%s: new generation %q, %s", db.path, info.generation, info.reason)
// Clear shadow wal info.
info.shadowWALPath, info.shadowWALSize = "", 0
info.restart = false
info.reason = ""
}
// Synchronize real WAL with current shadow WAL.
newWALSize, err := db.syncWAL(generation)
newWALSize, err := db.syncWAL(info)
if err != nil {
return fmt.Errorf("sync wal: %w", err)
}
@@ -402,6 +387,245 @@ func (db *DB) Sync() (err error) {
return nil
}
// verifyWAL ensures the current shadow WAL state matches where it left off from
// the real WAL. Returns generation & WAL sync information. If info.reason is
// not blank, verification failed and a new generation should be started.
//
// A non-nil error is returned only when the state could not be inspected
// (e.g. a file could not be read); a detected mismatch is reported through
// info.reason with a nil error. info.restart is set when the real WAL header
// differs from the shadow WAL header.
func (db *DB) verifyWAL() (info syncInfo, err error) {
	// Look up existing generation.
	generation, err := db.CurrentGeneration()
	if err != nil {
		return info, fmt.Errorf("cannot find current generation: %w", err)
	} else if generation == "" {
		info.reason = "no generation"
		return info, nil
	}
	info.generation = generation

	// Determine total bytes of real WAL.
	fi, err := os.Stat(db.WALPath())
	if err != nil {
		return info, err
	}
	info.walSize = fi.Size()

	// Find the shadow WAL that new frames are appended to.
	// NOTE(review): an empty path is treated as "no shadow wal" even when err
	// is non-nil -- confirm this matches CurrentShadowWALPath's contract.
	info.shadowWALPath, err = db.CurrentShadowWALPath(info.generation)
	if info.shadowWALPath == "" {
		info.reason = "no shadow wal"
		return info, nil
	} else if err != nil {
		return info, fmt.Errorf("cannot determine shadow WAL: %w", err)
	}

	// Determine shadow WAL current size.
	fi, err = os.Stat(info.shadowWALPath)
	if err != nil {
		return info, err
	}
	info.shadowWALSize = fi.Size()

	// TODO: Truncate shadow WAL if there is a partial page.

	// If shadow WAL is larger than real WAL then the WAL has been truncated
	// so we cannot determine our last state.
	if info.shadowWALSize > info.walSize {
		info.reason = "wal truncated by another process"
		return info, nil
	}

	// Compare WAL headers. If mismatched then the real WAL has been restarted.
	// We'll need to start a new shadow WAL during the sync.
	if hdr0, err := readWALHeader(db.WALPath()); err != nil {
		return info, fmt.Errorf("cannot read wal header: %w", err)
	} else if hdr1, err := readWALHeader(info.shadowWALPath); err != nil {
		return info, fmt.Errorf("cannot read shadow wal header: %w", err)
	} else {
		info.restart = !bytes.Equal(hdr0, hdr1)
	}

	// Verify last frame synced still matches. When the shadow WAL does not
	// yet hold a complete frame after the header there is nothing to compare,
	// and the computed offset would fall inside the header or go negative
	// (which makes ReadAt fail), so the header comparison above is the only
	// check that applies.
	frameSize := int64(db.pageSize + WALFrameHeaderSize)
	offset := info.shadowWALSize - frameSize
	if offset < WALHeaderSize {
		return info, nil
	}

	if buf0, err := readFileAt(db.WALPath(), offset, frameSize); err != nil {
		return info, fmt.Errorf("cannot read last synced wal page: %w", err)
	} else if buf1, err := readFileAt(info.shadowWALPath, offset, frameSize); err != nil {
		return info, fmt.Errorf("cannot read last synced shadow wal page: %w", err)
	} else if !bytes.Equal(buf0, buf1) {
		info.reason = "wal overwritten by another process"
		return info, nil
	}

	return info, nil
}
// syncInfo carries the state gathered while verifying the WAL so that the
// subsequent sync knows where to resume.
type syncInfo struct {
	// generation is the generation name.
	generation string

	// walSize is the size of the real WAL file, in bytes.
	walSize int64

	// shadowWALPath is the name of the last shadow WAL file.
	shadowWALPath string

	// shadowWALSize is the size of the last shadow WAL file, in bytes.
	shadowWALSize int64

	// restart is true when the real WAL header does not match the shadow WAL.
	restart bool

	// reason, when non-blank, describes why the sync could not be verified.
	reason string
}
// syncWAL copies pending bytes from the real WAL to the shadow WAL described
// by info and returns the resulting size of the shadow WAL that was written
// last. If info.restart is set, the remaining frames are first copied to the
// current shadow WAL and then a new shadow WAL file with the next index is
// started and filled from the real WAL.
func (db *DB) syncWAL(info syncInfo) (newSize int64, err error) {
	// Sync clears info.shadowWALPath after starting a new generation. That
	// generation's shadow WAL is initialized at index 0 by createGeneration,
	// so fall back to it instead of trying to open an empty path.
	shadowWALPath := info.shadowWALPath
	if shadowWALPath == "" {
		shadowWALPath = db.ShadowWALPath(info.generation, 0)
	}

	// Copy WAL starting from end of shadow WAL. Exit if no new shadow WAL needed.
	newSize, err = db.copyToShadowWAL(shadowWALPath)
	if err != nil {
		return newSize, fmt.Errorf("cannot copy to shadow wal: %w", err)
	} else if !info.restart {
		return newSize, nil // If no restart required, exit.
	}

	// Parse index of current shadow WAL file.
	dir, base := filepath.Split(shadowWALPath)
	index, err := ParseWALFilename(base)
	if err != nil {
		return 0, fmt.Errorf("cannot parse shadow wal filename: %w", err)
	}

	// Start a new shadow WAL file with next index.
	newShadowWALPath := filepath.Join(dir, FormatWALFilename(index+1))
	if err := db.initShadowWALFile(newShadowWALPath); err != nil {
		return 0, fmt.Errorf("cannot init shadow wal file: name=%s err=%w", newShadowWALPath, err)
	}

	// Copy rest of valid WAL to new shadow WAL.
	newSize, err = db.copyToShadowWAL(newShadowWALPath)
	if err != nil {
		return 0, fmt.Errorf("cannot copy to new shadow wal: %w", err)
	}
	return newSize, nil
}
// initShadowWALFile creates a new shadow WAL file at filename that contains a
// copy of the real WAL's header. The header's magic number and checksum are
// validated first; on success the byte order and salt values from the header
// are stored on db for use when verifying frames. Returns an error if the
// header cannot be read, is invalid, or the file cannot be written.
func (db *DB) initShadowWALFile(filename string) error {
	// The header comes from the real WAL; the shadow file at filename is the
	// file being created and has nothing to read yet.
	hdr, err := readWALHeader(db.WALPath())
	if err != nil {
		return fmt.Errorf("read header: %w", err)
	}

	// Determine byte order for checksumming from header magic.
	var byteOrder binary.ByteOrder
	magic := binary.BigEndian.Uint32(hdr[0:])
	switch magic {
	case 0x377f0682:
		byteOrder = binary.LittleEndian
	case 0x377f0683:
		byteOrder = binary.BigEndian
	default:
		return fmt.Errorf("invalid wal header magic: %x", magic)
	}

	// Verify checksum over the first 24 bytes of the header.
	s0 := binary.BigEndian.Uint32(hdr[24:])
	s1 := binary.BigEndian.Uint32(hdr[28:])
	if v0, v1 := Checksum(byteOrder, 0, 0, hdr[:24]); v0 != s0 || v1 != s1 {
		return fmt.Errorf("invalid header checksum: (%x,%x) != (%x,%x)", v0, v1, s0, s1)
	}

	// Only record byte order & salt once the header is known to be valid so a
	// bad header does not clobber previously stored state.
	db.byteOrder = byteOrder
	db.salt0 = binary.BigEndian.Uint32(hdr[16:])
	db.salt1 = binary.BigEndian.Uint32(hdr[20:])

	// Ensure the parent directory exists before writing the new file.
	if err := os.MkdirAll(filepath.Dir(filename), 0700); err != nil {
		return err
	}

	// Write header to new WAL shadow file.
	return ioutil.WriteFile(filename, hdr, 0600)
}
// copyToShadowWAL appends frames from the real WAL to the existing shadow WAL
// file at filename, starting at the shadow WAL's current size. Each frame's
// salt is compared against the salt stored on db and its running checksum is
// verified before the frame is written. Copying stops at end of file, at a
// partial frame, or at the first frame whose salt does not match. Returns the
// size of the shadow WAL after the copy; a checksum mismatch is returned as
// an error.
func (db *DB) copyToShadowWAL(filename string) (newSize int64, err error) {
	r, err := os.Open(db.WALPath())
	if err != nil {
		return 0, err
	}
	defer r.Close()

	w, err := os.OpenFile(filename, os.O_RDWR, 0600)
	if err != nil {
		return 0, err
	}
	defer w.Close()

	fi, err := w.Stat()
	if err != nil {
		return 0, err
	}

	// Read previous checksum; frame checksums are cumulative so it seeds the
	// verification of the next frame.
	chksum0, chksum1, err := readLastChecksumFrom(w, db.pageSize)
	if err != nil {
		return 0, fmt.Errorf("last checksum: %w", err)
	}

	// Seek to correct position on both files.
	if _, err := r.Seek(fi.Size(), io.SeekStart); err != nil {
		return 0, fmt.Errorf("wal seek: %w", err)
	} else if _, err := w.Seek(fi.Size(), io.SeekStart); err != nil {
		return 0, fmt.Errorf("shadow wal seek: %w", err)
	}

	// TODO: Optimize to use bufio on reader & writer to minimize syscalls.

	// Loop over each page, verify checksum, & copy to writer.
	newSize = fi.Size()
	buf := make([]byte, db.pageSize+WALFrameHeaderSize)
	for {
		// Read next page from WAL file.
		if _, err := io.ReadFull(r, buf); err == io.EOF || err == io.ErrUnexpectedEOF {
			break // end of file or partial page
		} else if err != nil {
			return newSize, fmt.Errorf("read wal: %w", err)
		}

		// Read frame salt & compare to header salt. Stop reading on mismatch.
		salt0 := binary.BigEndian.Uint32(buf[8:])
		salt1 := binary.BigEndian.Uint32(buf[12:])
		if salt0 != db.salt0 || salt1 != db.salt1 {
			break
		}

		// Verify checksum of page is valid.
		fchksum0 := binary.BigEndian.Uint32(buf[16:])
		fchksum1 := binary.BigEndian.Uint32(buf[20:])
		chksum0, chksum1 = Checksum(db.byteOrder, chksum0, chksum1, buf[:8])  // frame header
		chksum0, chksum1 = Checksum(db.byteOrder, chksum0, chksum1, buf[24:]) // frame data
		if chksum0 != fchksum0 || chksum1 != fchksum1 {
			return newSize, fmt.Errorf("checksum mismatch: offset=%d (%x,%x) != (%x,%x)", newSize, chksum0, chksum1, fchksum0, fchksum1)
		}

		// Append the verified frame to the shadow WAL. Without this write the
		// reported size would advance while the file itself never grows.
		if _, err := w.Write(buf); err != nil {
			return newSize, fmt.Errorf("write shadow wal: %w", err)
		}

		// Add page to the new size of the shadow WAL.
		newSize += int64(len(buf))
	}

	// Sync & close writer. The explicit Close surfaces close errors; the
	// deferred Close above then has nothing left to do.
	if err := w.Sync(); err != nil {
		return newSize, err
	} else if err := w.Close(); err != nil {
		return newSize, err
	}
	return newSize, nil
}
// Byte offsets at which the 8-byte checksum is read when locating the last
// checksum written to a WAL file.
const (
	// WALHeaderChecksumOffset is the checksum's offset from the start of the
	// WAL file, used when the file holds no frames beyond the header.
	WALHeaderChecksumOffset = 24

	// WALFrameHeaderChecksumOffset is the checksum's offset from the start of
	// a frame (frame header followed by page data).
	WALFrameHeaderChecksumOffset = 16
)
// readLastChecksumFrom returns the two 32-bit checksum values most recently
// stored in the WAL file f. When the file is no larger than the WAL header the
// header's checksum is returned; otherwise the checksum of the final frame is
// returned, located using pageSize. The values are decoded as big endian.
func readLastChecksumFrom(f *os.File, pageSize int) (uint32, uint32, error) {
	fi, err := f.Stat()
	if err != nil {
		return 0, 0, err
	}

	// Default to the header checksum; switch to the last frame's checksum
	// once at least some data follows the header.
	pos := int64(WALHeaderChecksumOffset)
	if size := fi.Size(); size > WALHeaderSize {
		pos = size - int64(pageSize) - WALFrameHeaderSize + WALFrameHeaderChecksumOffset
	}

	// Read both 4-byte checksum words in a single call.
	var raw [8]byte
	n, err := f.ReadAt(raw[:], pos)
	if err != nil {
		return 0, 0, err
	}
	if n != len(raw) {
		return 0, 0, io.ErrUnexpectedEOF
	}

	first := binary.BigEndian.Uint32(raw[:4])
	second := binary.BigEndian.Uint32(raw[4:])
	return first, second, nil
}
// checkpoint performs a checkpoint on the WAL file.
func (db *DB) checkpoint(force bool) error {
// Ensure the read lock has been removed before issuing a checkpoint.
@@ -433,72 +657,6 @@ func (db *DB) checkpoint(force bool) error {
return nil
}
// syncWAL copies pending bytes from the real WAL to the current shadow WAL of
// the given generation. It returns the size of the shadow WAL after the copy:
// the unchanged shadow WAL size when nothing is pending, otherwise the size of
// the real WAL. An error is returned if the shadow WAL is larger than the real
// WAL or if any file operation fails.
func (db *DB) syncWAL(generation string) (newSize int64, err error) {
	// Determine total bytes of real WAL.
	fi, err := os.Stat(db.WALPath())
	if err != nil {
		return 0, err
	}
	walSize := fi.Size()

	// Find the shadow WAL to append to.
	shadowWALPath, err := db.CurrentShadowWALPath(generation)
	if err != nil {
		return 0, fmt.Errorf("cannot determine shadow WAL: %w", err)
	}

	// TODO: Compare WAL headers.

	// Determine shadow WAL current size.
	fi, err = os.Stat(shadowWALPath)
	if err != nil {
		return 0, err
	}
	shadowWALSize := fi.Size()

	// Ensure we have pending bytes to write.
	// TODO: Verify pending bytes is divisble by (pageSize+headerSize)?
	pendingN := walSize - shadowWALSize
	if pendingN < 0 {
		// Report the inconsistency to the caller rather than crashing the
		// whole process from inside library code.
		return 0, fmt.Errorf("shadow wal larger than real wal: shadow=%d wal=%d", shadowWALSize, walSize)
	} else if pendingN == 0 {
		return shadowWALSize, nil // wals match, exit
	}

	// TODO: Verify last page copied matches.

	// Open handles for the shadow WAL & real WAL.
	w, err := os.OpenFile(shadowWALPath, os.O_RDWR, 0600)
	if err != nil {
		return 0, err
	}
	defer w.Close()

	r, err := os.Open(db.WALPath())
	if err != nil {
		return 0, err
	}
	defer r.Close()

	// Seek to the correct position for each file.
	if _, err := r.Seek(shadowWALSize, io.SeekStart); err != nil {
		return 0, fmt.Errorf("wal seek: %w", err)
	} else if _, err := w.Seek(shadowWALSize, io.SeekStart); err != nil {
		return 0, fmt.Errorf("shadow wal seek: %w", err)
	}

	// Copy and sync. The explicit Close surfaces close errors; the deferred
	// Close above then has nothing left to do.
	if _, err := io.CopyN(w, r, pendingN); err != nil {
		return 0, fmt.Errorf("copy shadow wal error: %w", err)
	} else if err := w.Sync(); err != nil {
		return 0, fmt.Errorf("shadow wal sync: %w", err)
	} else if err := w.Close(); err != nil {
		return 0, fmt.Errorf("shadow wal close: %w", err)
	}
	return walSize, nil
}
// monitor runs in a separate goroutine and monitors the database & WAL.
func (db *DB) monitor() {
ticker := time.NewTicker(db.MonitorInterval)
@@ -543,3 +701,46 @@ func rollback(tx *sql.Tx) error {
}
return nil
}
// readWALHeader opens filename and reads the first WALHeaderSize bytes from
// it. If the file cannot be opened a nil slice is returned with the error; if
// fewer bytes are available, the bytes that were read are returned together
// with the read error.
func readWALHeader(filename string) ([]byte, error) {
	f, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	hdr := make([]byte, WALHeaderSize)
	read, err := io.ReadFull(f, hdr)
	if err != nil {
		// Hand back whatever was read alongside the failure.
		return hdr[:read], err
	}
	return hdr, nil
}
// readFileAt opens filename and reads n bytes starting at the given byte
// offset. If the file cannot be opened a nil slice is returned with the
// error. If the read fails or comes up short, the bytes that were read are
// returned together with a non-nil error.
func readFileAt(filename string, offset, n int64) ([]byte, error) {
	f, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	data := make([]byte, n)
	read, err := f.ReadAt(data, offset)
	switch {
	case err != nil:
		return data[:read], err
	case read < len(data):
		return data[:read], io.ErrUnexpectedEOF
	}
	return data, nil
}
// ParseWALFilename returns the index encoded in a WAL filename. The WALExt
// suffix, if present, is stripped and the remainder is parsed as a
// hexadecimal number. An error naming the file is returned when the remainder
// is not valid hexadecimal.
func ParseWALFilename(name string) (index int, err error) {
	hexIndex := strings.TrimSuffix(name, WALExt)

	parsed, parseErr := strconv.ParseInt(hexIndex, 16, 64)
	if parseErr != nil {
		return 0, fmt.Errorf("invalid wal filename: %q", name)
	}
	return int(parsed), nil
}
// FormatWALFilename returns the WAL filename for index: the index as a
// zero-padded, 16-digit hexadecimal number followed by WALExt. Hexadecimal is
// used so the result round-trips through ParseWALFilename, which parses the
// name in base 16. The index must be non-negative.
func FormatWALFilename(index int) string {
	assert(index >= 0, "wal index must be non-negative")
	return fmt.Sprintf("%016x%s", index, WALExt)
}