Refactor checksum calculation; improve test coverage

This commit is contained in:
Ben Johnson
2021-12-12 10:25:20 -07:00
parent ba6e13b5d0
commit 531e19ed6f
3 changed files with 114 additions and 44 deletions

95
db.go
View File

@@ -300,58 +300,20 @@ func (db *DB) invalidateChecksum(ctx context.Context) error {
assert(!db.pos.IsZero(), "position required to invalidate checksum")
// Read entire WAL from combined segments.
walReader, err := db.WALReader(ctx, db.pos.Generation, db.pos.Index)
rc, err := db.WALReader(ctx, db.pos.Generation, db.pos.Index)
if err != nil {
return fmt.Errorf("cannot read last wal: %w", err)
}
defer walReader.Close()
defer func() { _ = rc.Close() }()
// Ensure we don't read past our position.
r := &io.LimitedReader{R: walReader, N: db.pos.Offset}
r := &io.LimitedReader{R: rc, N: db.pos.Offset}
// Read header.
hdr := make([]byte, WALHeaderSize)
if _, err := io.ReadFull(r, hdr); err != nil {
return fmt.Errorf("read shadow wal header: %w", err)
}
// Read byte order.
byteOrder, err := headerByteOrder(hdr)
// Determine cache values from the current WAL file.
db.salt0, db.salt1, db.chksum0, db.chksum1, db.byteOrder, db.frame, err = ReadWALFields(r, db.pageSize)
if err != nil {
return err
return fmt.Errorf("calc checksum: %w", err)
}
// Save salt & checksum to cache, although checksum may be overridden later.
db.salt0 = binary.BigEndian.Uint32(hdr[16:])
db.salt1 = binary.BigEndian.Uint32(hdr[20:])
db.chksum0 = binary.BigEndian.Uint32(hdr[24:])
db.chksum1 = binary.BigEndian.Uint32(hdr[28:])
db.byteOrder = byteOrder
// Iterate over each page in the WAL and save the checksum.
frame := make([]byte, db.pageSize+WALFrameHeaderSize)
var hasFrame bool
for {
// Read next page from WAL file.
if _, err := io.ReadFull(r, frame); err == io.EOF {
break // end of WAL file
} else if err != nil {
return fmt.Errorf("read wal: %w", err)
}
// Save frame checksum to cache.
hasFrame = true
db.chksum0 = binary.BigEndian.Uint32(frame[16:])
db.chksum1 = binary.BigEndian.Uint32(frame[20:])
}
// Save last frame to cache.
if hasFrame {
db.frame = frame
} else {
db.frame = nil
}
return nil
}
@@ -1739,6 +1701,51 @@ func NewRestoreOptions() RestoreOptions {
}
}
// ReadWALFields iterates over the header & frames in the WAL data in r.
// Returns salt, checksum, byte order & the last frame. WAL data must start
// from the beginning of the WAL header and must end on either the WAL header
// or at the end of a WAL frame.
func ReadWALFields(r io.Reader, pageSize int) (salt0, salt1, chksum0, chksum1 uint32, byteOrder binary.ByteOrder, frame []byte, err error) {
// Read header.
hdr := make([]byte, WALHeaderSize)
if _, err := io.ReadFull(r, hdr); err != nil {
return 0, 0, 0, 0, nil, nil, fmt.Errorf("short wal header: %w", err)
}
// Save salt, initial checksum, & byte order.
salt0 = binary.BigEndian.Uint32(hdr[16:])
salt1 = binary.BigEndian.Uint32(hdr[20:])
chksum0 = binary.BigEndian.Uint32(hdr[24:])
chksum1 = binary.BigEndian.Uint32(hdr[28:])
if byteOrder, err = headerByteOrder(hdr); err != nil {
return 0, 0, 0, 0, nil, nil, err
}
// Iterate over each page in the WAL and save the checksum.
frame = make([]byte, pageSize+WALFrameHeaderSize)
var hasFrame bool
for {
// Read next page from WAL file.
if n, err := io.ReadFull(r, frame); err == io.EOF {
break // end of WAL file
} else if err != nil {
return 0, 0, 0, 0, nil, nil, fmt.Errorf("short wal frame (n=%d): %w", n, err)
}
// Update checksum on each successful frame.
hasFrame = true
chksum0 = binary.BigEndian.Uint32(frame[16:])
chksum1 = binary.BigEndian.Uint32(frame[20:])
}
// Clear frame if none were successfully read.
if !hasFrame {
frame = nil
}
return salt0, salt1, chksum0, chksum1, byteOrder, frame, nil
}
// Database metrics.
var (
dbSizeGaugeVec = promauto.NewGaugeVec(prometheus.GaugeOpts{

View File

@@ -1,8 +1,10 @@
package litestream_test
import (
"bytes"
"context"
"database/sql"
"encoding/binary"
"os"
"path/filepath"
"strings"
@@ -473,6 +475,67 @@ func TestDB_Sync(t *testing.T) {
})
}
func TestReadWALFields(t *testing.T) {
b, err := os.ReadFile("testdata/read-wal-fields/ok")
if err != nil {
t.Fatal(err)
}
t.Run("OK", func(t *testing.T) {
if salt0, salt1, chksum0, chksum1, byteOrder, frame, err := litestream.ReadWALFields(bytes.NewReader(b), 4096); err != nil {
t.Fatal(err)
} else if got, want := salt0, uint32(0x4F7598FD); got != want {
t.Fatalf("salt0=%x, want %x", got, want)
} else if got, want := salt1, uint32(0x875FFD5B); got != want {
t.Fatalf("salt1=%x, want %x", got, want)
} else if got, want := chksum0, uint32(0x2081CAF7); got != want {
t.Fatalf("chksum0=%x, want %x", got, want)
} else if got, want := chksum1, uint32(0x31093CD3); got != want {
t.Fatalf("chksum1=%x, want %x", got, want)
} else if got, want := byteOrder, binary.LittleEndian; got != want {
t.Fatalf("chksum1=%x, want %x", got, want)
} else if !bytes.Equal(frame, b[8272:]) {
t.Fatal("last frame mismatch")
}
})
t.Run("HeaderOnly", func(t *testing.T) {
if salt0, salt1, chksum0, chksum1, byteOrder, frame, err := litestream.ReadWALFields(bytes.NewReader(b[:32]), 4096); err != nil {
t.Fatal(err)
} else if got, want := salt0, uint32(0x4F7598FD); got != want {
t.Fatalf("salt0=%x, want %x", got, want)
} else if got, want := salt1, uint32(0x875FFD5B); got != want {
t.Fatalf("salt1=%x, want %x", got, want)
} else if got, want := chksum0, uint32(0xD27F7862); got != want {
t.Fatalf("chksum0=%x, want %x", got, want)
} else if got, want := chksum1, uint32(0xE664AF8E); got != want {
t.Fatalf("chksum1=%x, want %x", got, want)
} else if got, want := byteOrder, binary.LittleEndian; got != want {
t.Fatalf("chksum1=%x, want %x", got, want)
} else if frame != nil {
t.Fatal("expected no frame")
}
})
t.Run("ErrShortHeader", func(t *testing.T) {
if _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader([]byte{}), 4096); err == nil || err.Error() != `short wal header: EOF` {
t.Fatal(err)
}
})
t.Run("ErrBadMagic", func(t *testing.T) {
if _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader(make([]byte, 32)), 4096); err == nil || err.Error() != `invalid wal header magic: 0` {
t.Fatal(err)
}
})
t.Run("ErrShortFrame", func(t *testing.T) {
if _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader(b[:100]), 4096); err == nil || err.Error() != `short wal frame (n=68): unexpected EOF` {
t.Fatal(err)
}
})
}
// MustOpenDBs returns a new instance of a DB & associated SQL DB.
func MustOpenDBs(tb testing.TB) (*litestream.DB, *sql.DB) {
tb.Helper()

BIN
testdata/read-wal-fields/ok vendored Normal file

Binary file not shown.