diff --git a/db.go b/db.go index 29d7114..0bc516c 100644 --- a/db.go +++ b/db.go @@ -291,18 +291,50 @@ func (db *DB) Open() (err error) { // Close releases the read lock & closes the database. This method should only // be called by tests as it causes the underlying database to be checkpointed. func (db *DB) Close() (err error) { - // Ensure replicas all stop replicating. - for _, r := range db.Replicas { - r.Stop(true) + return db.close(false) +} + +// SoftClose closes everything but the underlying db connection. This method +// is available because the binary needs to avoid closing the database on exit +// to prevent autocheckpointing. +func (db *DB) SoftClose() (err error) { + return db.close(true) +} + +func (db *DB) close(soft bool) (err error) { + db.cancel() + db.wg.Wait() + + // Start a new context for shutdown since we canceled the DB context. + ctx := context.Background() + + // Perform a final db sync, if initialized. + if db.db != nil { + if e := db.Sync(ctx); e != nil && err == nil { + err = e + } } + // Ensure replicas perform a final sync and stop replicating. + for _, r := range db.Replicas { + if db.db != nil { + if e := r.Sync(ctx); e != nil && err == nil { + err = e + } + } + r.Stop(!soft) + } + + // Release the read lock to allow other applications to handle checkpointing. if db.rtx != nil { if e := db.releaseReadLock(); e != nil && err == nil { err = e } } - if db.db != nil { + // Only perform full close if this is not a soft close. + // This closes the underlying database connection which can clean up the WAL. + if !soft && db.db != nil { if e := db.db.Close(); e != nil && err == nil { err = e } @@ -597,26 +629,6 @@ func (db *DB) cleanWAL() error { return nil } -// SoftClose closes everything but the underlying db connection. This method -// is available because the binary needs to avoid closing the database on exit -// to prevent autocheckpointing. -func (db *DB) SoftClose() (err error) { - db.cancel() - db.wg.Wait() - - // Ensure replicas all stop replicating. - for _, r := range db.Replicas { - r.Stop(false) - } - - if db.rtx != nil { - if e := db.releaseReadLock(); e != nil && err == nil { - err = e - } - } - return err -} - // acquireReadLock begins a read transaction on the database to prevent checkpointing. func (db *DB) acquireReadLock() error { if db.rtx != nil { @@ -711,7 +723,7 @@ func (db *DB) createGeneration() (string, error) { } // Sync copies pending data from the WAL to the shadow WAL. -func (db *DB) Sync() (err error) { +func (db *DB) Sync(ctx context.Context) (err error) { db.mu.Lock() defer db.mu.Unlock() @@ -755,7 +767,7 @@ func (db *DB) Sync() (err error) { // insert will never actually occur because our tx will be rolled back, // however, it will ensure our tx grabs the write lock. Unfortunately, // we can't call "BEGIN IMMEDIATE" as we are already in a transaction. - if _, err := tx.ExecContext(db.ctx, `INSERT INTO _litestream_lock (id) VALUES (1);`); err != nil { + if _, err := tx.ExecContext(ctx, `INSERT INTO _litestream_lock (id) VALUES (1);`); err != nil { return fmt.Errorf("_litestream_lock: %w", err) } @@ -814,7 +826,7 @@ func (db *DB) Sync() (err error) { if checkpoint { changed = true - if err := db.checkpointAndInit(info.generation, checkpointMode); err != nil { + if err := db.checkpointAndInit(ctx, info.generation, checkpointMode); err != nil { return fmt.Errorf("checkpoint: mode=%v err=%w", checkpointMode, err) } } @@ -1325,7 +1337,7 @@ func (db *DB) checkpoint(mode string) (err error) { // checkpointAndInit performs a checkpoint on the WAL file and initializes a // new shadow WAL file. -func (db *DB) checkpointAndInit(generation, mode string) error { +func (db *DB) checkpointAndInit(ctx context.Context, generation, mode string) error { shadowWALPath, err := db.CurrentShadowWALPath(generation) if err != nil { return err @@ -1368,7 +1380,7 @@ func (db *DB) checkpointAndInit(generation, mode string) error { // insert will never actually occur because our tx will be rolled back, // however, it will ensure our tx grabs the write lock. Unfortunately, // we can't call "BEGIN IMMEDIATE" as we are already in a transaction. - if _, err := tx.ExecContext(db.ctx, `INSERT INTO _litestream_lock (id) VALUES (1);`); err != nil { + if _, err := tx.ExecContext(ctx, `INSERT INTO _litestream_lock (id) VALUES (1);`); err != nil { return fmt.Errorf("_litestream_lock: %w", err) } @@ -1410,7 +1422,7 @@ func (db *DB) monitor() { } // Sync the database to the shadow WAL. - if err := db.Sync(); err != nil && !errors.Is(err, context.Canceled) { + if err := db.Sync(db.ctx); err != nil && !errors.Is(err, context.Canceled) { log.Printf("%s: sync error: %s", db.path, err) } } @@ -1666,7 +1678,7 @@ func restoreWAL(ctx context.Context, r Replica, generation string, index int, db // unable to checkpoint during this time. // // If dst is set, the database file is copied to that location before checksum. -func (db *DB) CRC64() (uint64, Pos, error) { +func (db *DB) CRC64(ctx context.Context) (uint64, Pos, error) { db.mu.Lock() defer db.mu.Unlock() @@ -1684,7 +1696,7 @@ func (db *DB) CRC64() (uint64, Pos, error) { } // Force a RESTART checkpoint to ensure the database is at the start of the WAL. - if err := db.checkpointAndInit(generation, CheckpointModeRestart); err != nil { + if err := db.checkpointAndInit(ctx, generation, CheckpointModeRestart); err != nil { return 0, Pos{}, err } diff --git a/db_test.go b/db_test.go index 7227cfb..a302067 100644 --- a/db_test.go +++ b/db_test.go @@ -1,10 +1,12 @@ package litestream_test import ( + "context" "database/sql" "io/ioutil" "os" "path/filepath" + "strings" "testing" "time" @@ -118,7 +120,7 @@ func TestDB_CRC64(t *testing.T) { t.Run("ErrNotExist", func(t *testing.T) { db := MustOpenDB(t) defer MustCloseDB(t, db) - if _, _, err := db.CRC64(); !os.IsNotExist(err) { + if _, _, err := db.CRC64(context.Background()); !os.IsNotExist(err) { t.Fatalf("unexpected error: %#v", err) } }) @@ -127,11 +129,11 @@ func TestDB_CRC64(t *testing.T) { db, sqldb := MustOpenDBs(t) defer MustCloseDBs(t, db, sqldb) - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } - chksum0, _, err := db.CRC64() + chksum0, _, err := db.CRC64(context.Background()) if err != nil { t.Fatal(err) } @@ -139,7 +141,7 @@ func TestDB_CRC64(t *testing.T) { // Issue change that is applied to the WAL. Checksum should not change. if _, err := sqldb.Exec(`CREATE TABLE t (id INT);`); err != nil { t.Fatal(err) - } else if chksum1, _, err := db.CRC64(); err != nil { + } else if chksum1, _, err := db.CRC64(context.Background()); err != nil { t.Fatal(err) } else if chksum0 == chksum1 { t.Fatal("expected different checksum event after WAL change") @@ -150,7 +152,7 @@ func TestDB_CRC64(t *testing.T) { t.Fatal(err) } - if chksum2, _, err := db.CRC64(); err != nil { + if chksum2, _, err := db.CRC64(context.Background()); err != nil { t.Fatal(err) } else if chksum0 == chksum2 { t.Fatal("expected different checksums after checkpoint") @@ -164,7 +166,7 @@ func TestDB_Sync(t *testing.T) { t.Run("NoDB", func(t *testing.T) { db := MustOpenDB(t) defer MustCloseDB(t, db) - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } }) @@ -174,7 +176,7 @@ func TestDB_Sync(t *testing.T) { db, sqldb := MustOpenDBs(t) defer MustCloseDBs(t, db, sqldb) - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -212,7 +214,7 @@ func TestDB_Sync(t *testing.T) { } // Perform initial sync & grab initial position. - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -227,7 +229,7 @@ func TestDB_Sync(t *testing.T) { } // Sync to ensure position moves forward one page. - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } else if pos1, err := db.Pos(); err != nil { t.Fatal(err) @@ -246,7 +248,7 @@ func TestDB_Sync(t *testing.T) { defer MustCloseDBs(t, db, sqldb) // Issue initial sync and truncate WAL. - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -275,7 +277,7 @@ func TestDB_Sync(t *testing.T) { defer MustCloseDB(t, db) // Re-sync and ensure new generation has been created. - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -298,7 +300,7 @@ func TestDB_Sync(t *testing.T) { } // Issue initial sync and truncate WAL. - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -334,7 +336,7 @@ func TestDB_Sync(t *testing.T) { defer MustCloseDB(t, db) // Re-sync and ensure new generation has been created. - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -354,7 +356,7 @@ func TestDB_Sync(t *testing.T) { // Execute a query to force a write to the WAL and then sync. if _, err := sqldb.Exec(`CREATE TABLE foo (bar TEXT);`); err != nil { t.Fatal(err) - } else if err := db.Sync(); err != nil { + } else if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -378,7 +380,7 @@ func TestDB_Sync(t *testing.T) { // Reopen managed database & ensure sync will still work. db = MustOpenDBAt(t, db.Path()) defer MustCloseDB(t, db) - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -398,7 +400,7 @@ func TestDB_Sync(t *testing.T) { // Execute a query to force a write to the WAL and then sync. if _, err := sqldb.Exec(`CREATE TABLE foo (bar TEXT);`); err != nil { t.Fatal(err) - } else if err := db.Sync(); err != nil { + } else if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -417,7 +419,7 @@ func TestDB_Sync(t *testing.T) { // Reopen managed database & ensure sync will still work. db = MustOpenDBAt(t, db.Path()) defer MustCloseDB(t, db) - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -437,7 +439,7 @@ func TestDB_Sync(t *testing.T) { // Execute a query to force a write to the WAL and then sync. if _, err := sqldb.Exec(`CREATE TABLE foo (bar TEXT);`); err != nil { t.Fatal(err) - } else if err := db.Sync(); err != nil { + } else if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -462,7 +464,7 @@ func TestDB_Sync(t *testing.T) { // Reopen managed database & ensure sync will still work. db = MustOpenDBAt(t, db.Path()) defer MustCloseDB(t, db) - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -489,7 +491,7 @@ func TestDB_Sync(t *testing.T) { // Execute a query to force a write to the WAL and then sync. if _, err := sqldb.Exec(`CREATE TABLE foo (bar TEXT);`); err != nil { t.Fatal(err) - } else if err := db.Sync(); err != nil { + } else if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -508,7 +510,7 @@ func TestDB_Sync(t *testing.T) { // Reopen managed database & ensure sync will still work. db = MustOpenDBAt(t, db.Path()) defer MustCloseDB(t, db) - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -532,7 +534,7 @@ func TestDB_Sync(t *testing.T) { // Execute a query to force a write to the WAL and then sync. if _, err := sqldb.Exec(`CREATE TABLE foo (bar TEXT);`); err != nil { t.Fatal(err) - } else if err := db.Sync(); err != nil { + } else if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -544,7 +546,7 @@ func TestDB_Sync(t *testing.T) { } // Sync to shadow WAL. - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -564,7 +566,7 @@ func TestDB_Sync(t *testing.T) { // Execute a query to force a write to the WAL and then sync. if _, err := sqldb.Exec(`CREATE TABLE foo (bar TEXT);`); err != nil { t.Fatal(err) - } else if err := db.Sync(); err != nil { + } else if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -574,7 +576,7 @@ func TestDB_Sync(t *testing.T) { // Write to WAL & sync. if _, err := sqldb.Exec(`INSERT INTO foo (bar) VALUES ('baz');`); err != nil { t.Fatal(err) - } else if err := db.Sync(); err != nil { + } else if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } @@ -589,12 +591,14 @@ func TestDB_Sync(t *testing.T) { // MustOpenDBs returns a new instance of a DB & associated SQL DB. func MustOpenDBs(tb testing.TB) (*litestream.DB, *sql.DB) { + tb.Helper() db := MustOpenDB(tb) return db, MustOpenSQLDB(tb, db.Path()) } // MustCloseDBs closes db & sqldb and removes the parent directory. func MustCloseDBs(tb testing.TB, db *litestream.DB, sqldb *sql.DB) { + tb.Helper() MustCloseDB(tb, db) MustCloseSQLDB(tb, sqldb) } @@ -619,7 +623,7 @@ func MustOpenDBAt(tb testing.TB, path string) *litestream.DB { // MustCloseDB closes db and removes its parent directory. func MustCloseDB(tb testing.TB, db *litestream.DB) { tb.Helper() - if err := db.Close(); err != nil { + if err := db.Close(); err != nil && !strings.Contains(err.Error(), `database is closed`) { tb.Fatal(err) } else if err := os.RemoveAll(filepath.Dir(db.Path())); err != nil { tb.Fatal(err) diff --git a/replica.go b/replica.go index d7369c5..bc24b5a 100644 --- a/replica.go +++ b/replica.go @@ -37,6 +37,9 @@ type Replica interface { // Stops all replication processing. Blocks until processing stopped. Stop(hard bool) error + // Performs a backup of outstanding WAL frames to the replica. + Sync(ctx context.Context) error + // Returns the last replication position. LastPos() Pos @@ -1164,7 +1167,7 @@ func ValidateReplica(ctx context.Context, r Replica) error { // Compute checksum of primary database under lock. This prevents a // sync from occurring and the database will not be written. - chksum0, pos, err := db.CRC64() + chksum0, pos, err := db.CRC64(ctx) if err != nil { return fmt.Errorf("cannot compute checksum: %w", err) } diff --git a/replica_test.go b/replica_test.go index 101b4cf..a70459f 100644 --- a/replica_test.go +++ b/replica_test.go @@ -15,7 +15,7 @@ func TestFileReplica_Sync(t *testing.T) { r := NewTestFileReplica(t, db) // Sync database & then sync replica. - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } else if err := r.Sync(context.Background()); err != nil { t.Fatal(err) @@ -47,7 +47,7 @@ func TestFileReplica_Sync(t *testing.T) { // Sync periodically. if i%100 == 0 || i == n-1 { - if err := db.Sync(); err != nil { + if err := db.Sync(context.Background()); err != nil { t.Fatal(err) } else if err := r.Sync(context.Background()); err != nil { t.Fatal(err)