From 531e19ed6faea423162978e12afaaebbe0bb087e Mon Sep 17 00:00:00 2001 From: Ben Johnson Date: Sun, 12 Dec 2021 10:25:20 -0700 Subject: [PATCH] Refactor checksum calculation; improve test coverage --- db.go | 95 +++++++++++++++++++----------------- db_test.go | 63 ++++++++++++++++++++++++ testdata/read-wal-fields/ok | Bin 0 -> 12392 bytes 3 files changed, 114 insertions(+), 44 deletions(-) create mode 100644 testdata/read-wal-fields/ok diff --git a/db.go b/db.go index 2a421ad..8c224cc 100644 --- a/db.go +++ b/db.go @@ -300,58 +300,20 @@ func (db *DB) invalidateChecksum(ctx context.Context) error { assert(!db.pos.IsZero(), "position required to invalidate checksum") // Read entire WAL from combined segments. - walReader, err := db.WALReader(ctx, db.pos.Generation, db.pos.Index) + rc, err := db.WALReader(ctx, db.pos.Generation, db.pos.Index) if err != nil { return fmt.Errorf("cannot read last wal: %w", err) } - defer walReader.Close() + defer func() { _ = rc.Close() }() // Ensure we don't read past our position. - r := &io.LimitedReader{R: walReader, N: db.pos.Offset} + r := &io.LimitedReader{R: rc, N: db.pos.Offset} - // Read header. - hdr := make([]byte, WALHeaderSize) - if _, err := io.ReadFull(r, hdr); err != nil { - return fmt.Errorf("read shadow wal header: %w", err) - } - - // Read byte order. - byteOrder, err := headerByteOrder(hdr) + // Determine cache values from the current WAL file. + db.salt0, db.salt1, db.chksum0, db.chksum1, db.byteOrder, db.frame, err = ReadWALFields(r, db.pageSize) if err != nil { - return err + return fmt.Errorf("calc checksum: %w", err) } - - // Save salt & checksum to cache, although checksum may be overridden later. - db.salt0 = binary.BigEndian.Uint32(hdr[16:]) - db.salt1 = binary.BigEndian.Uint32(hdr[20:]) - db.chksum0 = binary.BigEndian.Uint32(hdr[24:]) - db.chksum1 = binary.BigEndian.Uint32(hdr[28:]) - db.byteOrder = byteOrder - - // Iterate over each page in the WAL and save the checksum. - frame := make([]byte, db.pageSize+WALFrameHeaderSize) - var hasFrame bool - for { - // Read next page from WAL file. - if _, err := io.ReadFull(r, frame); err == io.EOF { - break // end of WAL file - } else if err != nil { - return fmt.Errorf("read wal: %w", err) - } - - // Save frame checksum to cache. - hasFrame = true - db.chksum0 = binary.BigEndian.Uint32(frame[16:]) - db.chksum1 = binary.BigEndian.Uint32(frame[20:]) - } - - // Save last frame to cache. - if hasFrame { - db.frame = frame - } else { - db.frame = nil - } - return nil } @@ -1739,6 +1701,51 @@ func NewRestoreOptions() RestoreOptions { } } +// ReadWALFields iterates over the header & frames in the WAL data in r. +// Returns salt, checksum, byte order & the last frame. WAL data must start +// from the beginning of the WAL header and must end on either the WAL header +// or at the end of a WAL frame. +func ReadWALFields(r io.Reader, pageSize int) (salt0, salt1, chksum0, chksum1 uint32, byteOrder binary.ByteOrder, frame []byte, err error) { + // Read header. + hdr := make([]byte, WALHeaderSize) + if _, err := io.ReadFull(r, hdr); err != nil { + return 0, 0, 0, 0, nil, nil, fmt.Errorf("short wal header: %w", err) + } + + // Save salt, initial checksum, & byte order. + salt0 = binary.BigEndian.Uint32(hdr[16:]) + salt1 = binary.BigEndian.Uint32(hdr[20:]) + chksum0 = binary.BigEndian.Uint32(hdr[24:]) + chksum1 = binary.BigEndian.Uint32(hdr[28:]) + if byteOrder, err = headerByteOrder(hdr); err != nil { + return 0, 0, 0, 0, nil, nil, err + } + + // Iterate over each page in the WAL and save the checksum. + frame = make([]byte, pageSize+WALFrameHeaderSize) + var hasFrame bool + for { + // Read next page from WAL file. + if n, err := io.ReadFull(r, frame); err == io.EOF { + break // end of WAL file + } else if err != nil { + return 0, 0, 0, 0, nil, nil, fmt.Errorf("short wal frame (n=%d): %w", n, err) + } + + // Update checksum on each successful frame. + hasFrame = true + chksum0 = binary.BigEndian.Uint32(frame[16:]) + chksum1 = binary.BigEndian.Uint32(frame[20:]) + } + + // Clear frame if none were successfully read. + if !hasFrame { + frame = nil + } + + return salt0, salt1, chksum0, chksum1, byteOrder, frame, nil +} + // Database metrics. var ( dbSizeGaugeVec = promauto.NewGaugeVec(prometheus.GaugeOpts{ diff --git a/db_test.go b/db_test.go index 220f7e6..5c3f51c 100644 --- a/db_test.go +++ b/db_test.go @@ -1,8 +1,10 @@ package litestream_test import ( + "bytes" "context" "database/sql" + "encoding/binary" "os" "path/filepath" "strings" @@ -473,6 +475,67 @@ func TestDB_Sync(t *testing.T) { }) } +func TestReadWALFields(t *testing.T) { + b, err := os.ReadFile("testdata/read-wal-fields/ok") + if err != nil { + t.Fatal(err) + } + + t.Run("OK", func(t *testing.T) { + if salt0, salt1, chksum0, chksum1, byteOrder, frame, err := litestream.ReadWALFields(bytes.NewReader(b), 4096); err != nil { + t.Fatal(err) + } else if got, want := salt0, uint32(0x4F7598FD); got != want { + t.Fatalf("salt0=%x, want %x", got, want) + } else if got, want := salt1, uint32(0x875FFD5B); got != want { + t.Fatalf("salt1=%x, want %x", got, want) + } else if got, want := chksum0, uint32(0x2081CAF7); got != want { + t.Fatalf("chksum0=%x, want %x", got, want) + } else if got, want := chksum1, uint32(0x31093CD3); got != want { + t.Fatalf("chksum1=%x, want %x", got, want) + } else if got, want := byteOrder, binary.LittleEndian; got != want { + t.Fatalf("chksum1=%x, want %x", got, want) + } else if !bytes.Equal(frame, b[8272:]) { + t.Fatal("last frame mismatch") + } + }) + + t.Run("HeaderOnly", func(t *testing.T) { + if salt0, salt1, chksum0, chksum1, byteOrder, frame, err := litestream.ReadWALFields(bytes.NewReader(b[:32]), 4096); err != nil { + t.Fatal(err) + } else if got, want := salt0, uint32(0x4F7598FD); got != want { + t.Fatalf("salt0=%x, want %x", got, want) + } else if got, want := salt1, uint32(0x875FFD5B); got != want { + t.Fatalf("salt1=%x, want %x", got, want) + } else if got, want := chksum0, uint32(0xD27F7862); got != want { + t.Fatalf("chksum0=%x, want %x", got, want) + } else if got, want := chksum1, uint32(0xE664AF8E); got != want { + t.Fatalf("chksum1=%x, want %x", got, want) + } else if got, want := byteOrder, binary.LittleEndian; got != want { + t.Fatalf("chksum1=%x, want %x", got, want) + } else if frame != nil { + t.Fatal("expected no frame") + } + }) + + t.Run("ErrShortHeader", func(t *testing.T) { + if _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader([]byte{}), 4096); err == nil || err.Error() != `short wal header: EOF` { + t.Fatal(err) + } + }) + + t.Run("ErrBadMagic", func(t *testing.T) { + if _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader(make([]byte, 32)), 4096); err == nil || err.Error() != `invalid wal header magic: 0` { + t.Fatal(err) + } + }) + + t.Run("ErrShortFrame", func(t *testing.T) { + if _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader(b[:100]), 4096); err == nil || err.Error() != `short wal frame (n=68): unexpected EOF` { + t.Fatal(err) + } + }) +} + // MustOpenDBs returns a new instance of a DB & associated SQL DB. func MustOpenDBs(tb testing.TB) (*litestream.DB, *sql.DB) { tb.Helper() diff --git a/testdata/read-wal-fields/ok b/testdata/read-wal-fields/ok new file mode 100644 index 0000000000000000000000000000000000000000..e019bfe23076f72cea2fe23aa339a85beaf925b2 GIT binary patch literal 12392 zcmeI%F-yZh6u|Mjqf|PSxalIi(IPDr1s4aaWDp8s?I1Xm7ApvfASZMyb#Zfcau)48gcH8L(l^?AuwW@p^-#wdlOx zZZ*Cy-k;o*UYzfaSGGPogO}5j@|A7MxOdI$^l3hPekvYKFY$T{ z_x}xzD(pzsJ<-8Xxr0Ce0R#|0009ILKmY**5I_Kd#05mB*MdM7ZM|{bwALI}!csMq zYU!XHij~b)&V4f8=g<2EM;mXiXCH&e@7ZUqgv6_2K?D#$009ILKmY**5I_I{1Q1vn wfj@l#>GXbznZ2*+zx;^`CQDN%g&=?c0tg_000IagfB*srAb>!E0!Ew6FEVR6MgRZ+ literal 0 HcmV?d00001