From 84d08f547a282f16145657c6af994d346c7d5ed3 Mon Sep 17 00:00:00 2001 From: Ben Johnson Date: Fri, 14 Jan 2022 15:31:04 -0700 Subject: [PATCH] Add end-to-end replication/restore testing --- .github/workflows/test.yml | 20 +- cmd/litestream/databases_test.go | 4 +- cmd/litestream/generations.go | 4 +- cmd/litestream/generations_test.go | 8 +- cmd/litestream/main.go | 66 +-- cmd/litestream/main_test.go | 10 +- cmd/litestream/main_windows.go | 2 +- cmd/litestream/replicate.go | 4 +- cmd/litestream/restore.go | 8 +- cmd/litestream/restore_test.go | 8 +- cmd/litestream/snapshots_test.go | 6 +- cmd/litestream/wal.go | 4 +- cmd/litestream/wal_test.go | 6 +- db.go | 25 +- db_test.go | 10 +- integration/cmd_test.go | 411 ++++++++++++++++++ .../replicate/high-load/litestream.yml | 7 + .../replicate/long-running/litestream.yml | 4 + .../testdata/replicate/ok/litestream.yml | 7 + .../litestream.yml | 4 + .../resume-with-new-generation/litestream.yml | 4 + .../testdata/replicate/resume/litestream.yml | 7 + internal/locking_buffer.go | 145 ++++++ internal/testingutil/testingutil.go | 25 +- litestream.go | 10 +- replica.go | 52 ++- replica_test.go | 11 +- 27 files changed, 755 insertions(+), 117 deletions(-) create mode 100644 integration/cmd_test.go create mode 100644 integration/testdata/replicate/high-load/litestream.yml create mode 100644 integration/testdata/replicate/long-running/litestream.yml create mode 100644 integration/testdata/replicate/ok/litestream.yml create mode 100644 integration/testdata/replicate/resume-with-current-generation/litestream.yml create mode 100644 integration/testdata/replicate/resume-with-new-generation/litestream.yml create mode 100644 integration/testdata/replicate/resume/litestream.yml create mode 100644 internal/locking_buffer.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index aabaa6e..cd533d3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,6 +29,9 @@ jobs: env: LITESTREAM_SFTP_KEY: ${{secrets.LITESTREAM_SFTP_KEY}} + - name: Build binary + run: go install ./cmd/litestream + - name: Run unit tests run: make testdata && go test -v ./... 
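The new "Build binary" step installs litestream into the runner's Go bin directory because the integration tests added later in this patch exec the binary by name via PATH. As a fail-fast guard, the integration package could verify the binary up front; this TestMain is a sketch and not part of the patch:

```go
package integration_test

import (
	"fmt"
	"os"
	"os/exec"
	"testing"
)

// TestMain fails fast if the litestream binary is not resolvable via
// PATH before any integration test shells out to it. (Hypothetical
// guard; the patch instead prints a note from init.)
func TestMain(m *testing.M) {
	if _, err := exec.LookPath("litestream"); err != nil {
		fmt.Fprintln(os.Stderr, "litestream binary not found in PATH; run `go install ./cmd/litestream` first")
		os.Exit(1)
	}
	os.Exit(m.Run())
}
```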
@@ -53,10 +56,13 @@ jobs: LITESTREAM_ABS_ACCOUNT_KEY: ${{ secrets.LITESTREAM_ABS_ACCOUNT_KEY }} LITESTREAM_ABS_BUCKET: integration - - name: Run sftp tests - run: go test -v -run=TestReplicaClient ./integration -replica-type sftp - env: - LITESTREAM_SFTP_HOST: ${{ secrets.LITESTREAM_SFTP_HOST }} - LITESTREAM_SFTP_USER: ${{ secrets.LITESTREAM_SFTP_USER }} - LITESTREAM_SFTP_KEY_PATH: /opt/id_ed25519 - LITESTREAM_SFTP_PATH: ${{ secrets.LITESTREAM_SFTP_PATH }} + #- name: Run sftp tests + # run: go test -v -run=TestReplicaClient ./integration -replica-type sftp + # env: + # LITESTREAM_SFTP_HOST: ${{ secrets.LITESTREAM_SFTP_HOST }} + # LITESTREAM_SFTP_USER: ${{ secrets.LITESTREAM_SFTP_USER }} + # LITESTREAM_SFTP_KEY_PATH: /opt/id_ed25519 + # LITESTREAM_SFTP_PATH: ${{ secrets.LITESTREAM_SFTP_PATH }} + + - name: Run long-running test + run: go test -v ./integration -long-running-duration 1m diff --git a/cmd/litestream/databases_test.go b/cmd/litestream/databases_test.go index 9499dc6..25aef5e 100644 --- a/cmd/litestream/databases_test.go +++ b/cmd/litestream/databases_test.go @@ -16,7 +16,7 @@ func TestDatabasesCommand(t *testing.T) { m, _, stdout, _ := newMain() if err := m.Run(context.Background(), []string{"databases", "-config", filepath.Join(testDir, "litestream.yml")}); err != nil { t.Fatal(err) - } else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) } }) @@ -26,7 +26,7 @@ func TestDatabasesCommand(t *testing.T) { m, _, stdout, _ := newMain() if err := m.Run(context.Background(), []string{"databases", "-config", filepath.Join(testDir, "litestream.yml")}); err != nil { t.Fatal(err) - } else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) } }) diff --git a/cmd/litestream/generations.go b/cmd/litestream/generations.go index da74099..5d237a2 100644 --- a/cmd/litestream/generations.go +++ b/cmd/litestream/generations.go @@ -74,7 +74,7 @@ func (c *GenerationsCommand) Run(ctx context.Context, args []string) (ret error) fmt.Fprintln(w, "name\tgeneration\tlag\tstart\tend") for _, r := range replicas { - generations, err := r.Client.Generations(ctx) + generations, err := r.Client().Generations(ctx) if err != nil { fmt.Fprintf(c.stderr, "%s: cannot list generations: %s", r.Name(), err) ret = errExit // signal error return without printing message @@ -83,7 +83,7 @@ func (c *GenerationsCommand) Run(ctx context.Context, args []string) (ret error) // Iterate over each generation for the replica. 
for _, generation := range generations { - createdAt, updatedAt, err := litestream.GenerationTimeBounds(ctx, r.Client, generation) + createdAt, updatedAt, err := litestream.GenerationTimeBounds(ctx, r.Client(), generation) if err != nil { fmt.Fprintf(c.stderr, "%s: cannot determine generation time bounds: %s", r.Name(), err) ret = errExit // signal error return without printing message diff --git a/cmd/litestream/generations_test.go b/cmd/litestream/generations_test.go index 097bd35..1da23e4 100644 --- a/cmd/litestream/generations_test.go +++ b/cmd/litestream/generations_test.go @@ -18,7 +18,7 @@ func TestGenerationsCommand(t *testing.T) { m, _, stdout, _ := newMain() if err := m.Run(context.Background(), []string{"generations", "-config", filepath.Join(testDir, "litestream.yml"), filepath.Join(testDir, "db")}); err != nil { t.Fatal(err) - } else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) } }) @@ -30,7 +30,7 @@ func TestGenerationsCommand(t *testing.T) { m, _, stdout, _ := newMain() if err := m.Run(context.Background(), []string{"generations", "-config", filepath.Join(testDir, "litestream.yml"), "-replica", "replica1", filepath.Join(testDir, "db")}); err != nil { t.Fatal(err) - } else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) } }) @@ -43,7 +43,7 @@ func TestGenerationsCommand(t *testing.T) { m, _, stdout, _ := newMain() if err := m.Run(context.Background(), []string{"generations", replicaURL}); err != nil { t.Fatal(err) - } else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) } }) @@ -55,7 +55,7 @@ func TestGenerationsCommand(t *testing.T) { m, _, stdout, _ := newMain() if err := m.Run(context.Background(), []string{"generations", "-config", filepath.Join(testDir, "litestream.yml"), filepath.Join(testDir, "db")}); err != nil { t.Fatal(err) - } else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) } }) diff --git a/cmd/litestream/main.go b/cmd/litestream/main.go index 176ec99..db491c4 100644 --- a/cmd/litestream/main.go +++ b/cmd/litestream/main.go @@ -38,6 +38,7 @@ var errExit = errors.New("exit") func main() { log.SetFlags(0) + log.SetOutput(os.Stdout) m := NewMain(os.Stdin, os.Stdout, os.Stderr) if err := m.Run(context.Background(), os.Args[1:]); err == flag.ErrHelp || err == errExit { @@ -354,8 +355,35 @@ func NewReplicaFromConfig(c *ReplicaConfig, db *litestream.DB) (_ *litestream.Re return nil, fmt.Errorf("replica path cannot be a url, please use the 'url' field instead: %s", c.Path) } + // Build and set client on replica. 
+ var client litestream.ReplicaClient + switch typ := c.ReplicaType(); typ { + case "file": + if client, err = newFileReplicaClientFromConfig(c); err != nil { + return nil, err + } + case "s3": + if client, err = newS3ReplicaClientFromConfig(c); err != nil { + return nil, err + } + case "gcs": + if client, err = newGCSReplicaClientFromConfig(c); err != nil { + return nil, err + } + case "abs": + if client, err = newABSReplicaClientFromConfig(c); err != nil { + return nil, err + } + case "sftp": + if client, err = newSFTPReplicaClientFromConfig(c); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unknown replica type in config: %q", typ) + } + // Build replica. - r := litestream.NewReplica(db, c.Name) + r := litestream.NewReplica(db, c.Name, client) if v := c.Retention; v != nil { r.Retention = *v } @@ -372,37 +400,11 @@ func NewReplicaFromConfig(c *ReplicaConfig, db *litestream.DB) (_ *litestream.Re r.ValidationInterval = *v } - // Build and set client on replica. - switch typ := c.ReplicaType(); typ { - case "file": - if r.Client, err = newFileReplicaClientFromConfig(c, r); err != nil { - return nil, err - } - case "s3": - if r.Client, err = newS3ReplicaClientFromConfig(c, r); err != nil { - return nil, err - } - case "gcs": - if r.Client, err = newGCSReplicaClientFromConfig(c, r); err != nil { - return nil, err - } - case "abs": - if r.Client, err = newABSReplicaClientFromConfig(c, r); err != nil { - return nil, err - } - case "sftp": - if r.Client, err = newSFTPReplicaClientFromConfig(c, r); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("unknown replica type in config: %q", typ) - } - return r, nil } // newFileReplicaClientFromConfig returns a new instance of FileReplicaClient built from config. -func newFileReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *litestream.FileReplicaClient, err error) { +func newFileReplicaClientFromConfig(c *ReplicaConfig) (_ *litestream.FileReplicaClient, err error) { // Ensure URL & path are not both specified. if c.URL != "" && c.Path != "" { return nil, fmt.Errorf("cannot specify url & path for file replica") @@ -431,7 +433,7 @@ func newFileReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ } // newS3ReplicaClientFromConfig returns a new instance of s3.ReplicaClient built from config. -func newS3ReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *s3.ReplicaClient, err error) { +func newS3ReplicaClientFromConfig(c *ReplicaConfig) (_ *s3.ReplicaClient, err error) { // Ensure URL & constituent parts are not both specified. if c.URL != "" && c.Path != "" { return nil, fmt.Errorf("cannot specify url & path for s3 replica") @@ -494,7 +496,7 @@ func newS3ReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *s } // newGCSReplicaClientFromConfig returns a new instance of gcs.ReplicaClient built from config. -func newGCSReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *gcs.ReplicaClient, err error) { +func newGCSReplicaClientFromConfig(c *ReplicaConfig) (_ *gcs.ReplicaClient, err error) { // Ensure URL & constituent parts are not both specified. if c.URL != "" && c.Path != "" { return nil, fmt.Errorf("cannot specify url & path for gcs replica") @@ -533,7 +535,7 @@ func newGCSReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ * } // newABSReplicaClientFromConfig returns a new instance of abs.ReplicaClient built from config. 
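The refactor above inverts the old initialization order: the ReplicaClient is now built first and passed to NewReplica, rather than assigned to an exported Client field after construction. For callers, the migration looks roughly like this sketch (the name and path are hypothetical; the pattern matches the updated call sites in replica_test.go):

```go
package example

import "github.com/benbjohnson/litestream"

// newFileReplica shows the post-refactor construction order: the client
// is built first and passed to NewReplica instead of being assigned to
// the removed exported Client field.
func newFileReplica(db *litestream.DB) *litestream.Replica {
	client := litestream.NewFileReplicaClient("/mnt/replica")
	return litestream.NewReplica(db, "file-replica", client)
}
```

The abs and sftp constructors below drop the now-unused *litestream.Replica parameter in the same way.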
-func newABSReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *abs.ReplicaClient, err error) { +func newABSReplicaClientFromConfig(c *ReplicaConfig) (_ *abs.ReplicaClient, err error) { // Ensure URL & constituent parts are not both specified. if c.URL != "" && c.Path != "" { return nil, fmt.Errorf("cannot specify url & path for abs replica") @@ -576,7 +578,7 @@ func newABSReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ * } // newSFTPReplicaClientFromConfig returns a new instance of sftp.ReplicaClient built from config. -func newSFTPReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *sftp.ReplicaClient, err error) { +func newSFTPReplicaClientFromConfig(c *ReplicaConfig) (_ *sftp.ReplicaClient, err error) { // Ensure URL & constituent parts are not both specified. if c.URL != "" && c.Path != "" { return nil, fmt.Errorf("cannot specify url & path for sftp replica") diff --git a/cmd/litestream/main_test.go b/cmd/litestream/main_test.go index f3e9fb1..d3d0af8 100644 --- a/cmd/litestream/main_test.go +++ b/cmd/litestream/main_test.go @@ -104,7 +104,7 @@ func TestNewFileReplicaFromConfig(t *testing.T) { r, err := main.NewReplicaFromConfig(&main.ReplicaConfig{Path: "/foo"}, nil) if err != nil { t.Fatal(err) - } else if client, ok := r.Client.(*litestream.FileReplicaClient); !ok { + } else if client, ok := r.Client().(*litestream.FileReplicaClient); !ok { t.Fatal("unexpected replica type") } else if got, want := client.Path(), "/foo"; got != want { t.Fatalf("Path=%s, want %s", got, want) @@ -116,7 +116,7 @@ func TestNewS3ReplicaFromConfig(t *testing.T) { r, err := main.NewReplicaFromConfig(&main.ReplicaConfig{URL: "s3://foo/bar"}, nil) if err != nil { t.Fatal(err) - } else if client, ok := r.Client.(*s3.ReplicaClient); !ok { + } else if client, ok := r.Client().(*s3.ReplicaClient); !ok { t.Fatal("unexpected replica type") } else if got, want := client.Bucket, "foo"; got != want { t.Fatalf("Bucket=%s, want %s", got, want) @@ -135,7 +135,7 @@ func TestNewS3ReplicaFromConfig(t *testing.T) { r, err := main.NewReplicaFromConfig(&main.ReplicaConfig{URL: "s3://foo.localhost:9000/bar"}, nil) if err != nil { t.Fatal(err) - } else if client, ok := r.Client.(*s3.ReplicaClient); !ok { + } else if client, ok := r.Client().(*s3.ReplicaClient); !ok { t.Fatal("unexpected replica type") } else if got, want := client.Bucket, "foo"; got != want { t.Fatalf("Bucket=%s, want %s", got, want) @@ -154,7 +154,7 @@ func TestNewS3ReplicaFromConfig(t *testing.T) { r, err := main.NewReplicaFromConfig(&main.ReplicaConfig{URL: "s3://foo.s3.us-west-000.backblazeb2.com/bar"}, nil) if err != nil { t.Fatal(err) - } else if client, ok := r.Client.(*s3.ReplicaClient); !ok { + } else if client, ok := r.Client().(*s3.ReplicaClient); !ok { t.Fatal("unexpected replica type") } else if got, want := client.Bucket, "foo"; got != want { t.Fatalf("Bucket=%s, want %s", got, want) @@ -174,7 +174,7 @@ func TestNewGCSReplicaFromConfig(t *testing.T) { r, err := main.NewReplicaFromConfig(&main.ReplicaConfig{URL: "gcs://foo/bar"}, nil) if err != nil { t.Fatal(err) - } else if client, ok := r.Client.(*gcs.ReplicaClient); !ok { + } else if client, ok := r.Client().(*gcs.ReplicaClient); !ok { t.Fatal("unexpected replica type") } else if got, want := client.Bucket, "foo"; got != want { t.Fatalf("Bucket=%s, want %s", got, want) diff --git a/cmd/litestream/main_windows.go b/cmd/litestream/main_windows.go index d437c1b..e6276eb 100644 --- a/cmd/litestream/main_windows.go +++ 
b/cmd/litestream/main_windows.go @@ -37,7 +37,7 @@ func runWindowsService(ctx context.Context) error { // Set eventlog as log writer while running. log.SetOutput((*eventlogWriter)(elog)) - defer log.SetOutput(os.Stderr) + defer log.SetOutput(os.Stdout) log.Print("Litestream service starting") diff --git a/cmd/litestream/replicate.go b/cmd/litestream/replicate.go index 5d32db7..7f9a9aa 100644 --- a/cmd/litestream/replicate.go +++ b/cmd/litestream/replicate.go @@ -121,7 +121,7 @@ func (c *ReplicateCommand) Run(ctx context.Context) (err error) { for _, db := range c.DBs { log.Printf("initialized db: %s", db.Path()) for _, r := range db.Replicas { - switch client := r.Client.(type) { + switch client := r.Client().(type) { case *litestream.FileReplicaClient: log.Printf("replicating to: name=%q type=%q path=%q", r.Name(), client.Type(), client.Path()) case *s3.ReplicaClient: @@ -173,6 +173,8 @@ func (c *ReplicateCommand) Run(ctx context.Context) (err error) { go func() { c.execCh <- c.cmd.Wait() }() } + log.Printf("litestream initialization complete") + return nil } diff --git a/cmd/litestream/restore.go b/cmd/litestream/restore.go index 1a0f5fd..d2a5d1d 100644 --- a/cmd/litestream/restore.go +++ b/cmd/litestream/restore.go @@ -104,7 +104,7 @@ func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) { // Determine latest generation if one is not specified. if c.generation == "" { - if c.generation, err = litestream.FindLatestGeneration(ctx, r.Client); err == litestream.ErrNoGeneration { + if c.generation, err = litestream.FindLatestGeneration(ctx, r.Client()); err == litestream.ErrNoGeneration { // Return an error if no matching targets found. // If optional flag set, return success. Useful for automated recovery. if c.ifReplicaExists { @@ -119,14 +119,14 @@ func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) { // Determine the maximum available index for the generation if one is not specified. if c.targetIndex == -1 { - if c.targetIndex, err = litestream.FindMaxIndexByGeneration(ctx, r.Client, c.generation); err != nil { + if c.targetIndex, err = litestream.FindMaxIndexByGeneration(ctx, r.Client(), c.generation); err != nil { return fmt.Errorf("cannot determine latest index in generation %q: %w", c.generation, err) } } // Find lastest snapshot that occurs before the index. 
// TODO: Optionally allow -snapshot-index - if c.snapshotIndex, err = litestream.FindSnapshotForIndex(ctx, r.Client, c.generation, c.targetIndex); err != nil { + if c.snapshotIndex, err = litestream.FindSnapshotForIndex(ctx, r.Client(), c.generation, c.targetIndex); err != nil { return fmt.Errorf("cannot find snapshot index: %w", err) } @@ -137,7 +137,7 @@ func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) { c.opt.Logger = log.New(c.stdout, "", log.LstdFlags|log.Lmicroseconds) - return litestream.Restore(ctx, r.Client, c.outputPath, c.generation, c.snapshotIndex, c.targetIndex, c.opt) + return litestream.Restore(ctx, r.Client(), c.outputPath, c.generation, c.snapshotIndex, c.targetIndex, c.opt) } func (c *RestoreCommand) loadReplica(ctx context.Context, config Config, arg string) (*litestream.Replica, error) { diff --git a/cmd/litestream/restore_test.go b/cmd/litestream/restore_test.go index 4d0770c..9469b5a 100644 --- a/cmd/litestream/restore_test.go +++ b/cmd/litestream/restore_test.go @@ -120,7 +120,7 @@ func TestRestoreCommand(t *testing.T) { err := m.Run(context.Background(), []string{"restore", "-config", filepath.Join(testDir, "litestream.yml"), "-if-db-not-exists", filepath.Join(testDir, "db")}) if err != nil { t.Fatal(err) - } else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) } }) @@ -133,7 +133,7 @@ func TestRestoreCommand(t *testing.T) { err := m.Run(context.Background(), []string{"restore", "-config", filepath.Join(testDir, "litestream.yml"), "-if-replica-exists", filepath.Join(testDir, "db")}) if err != nil { t.Fatal(err) - } else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) } }) @@ -147,9 +147,9 @@ func TestRestoreCommand(t *testing.T) { err := m.Run(context.Background(), []string{"restore", "-config", filepath.Join(testDir, "litestream.yml"), "-o", filepath.Join(tempDir, "db"), filepath.Join(testDir, "db")}) if err == nil || err.Error() != `no matching backups found` { t.Fatalf("unexpected error: %s", err) - } else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) - } else if got, want := stderr.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stderr"))); got != want { + } else if got, want := stderr.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stderr"))); got != want { t.Fatalf("stderr=%q, want %q", got, want) } }) diff --git a/cmd/litestream/snapshots_test.go b/cmd/litestream/snapshots_test.go index 3dc9288..f845cdc 100644 --- a/cmd/litestream/snapshots_test.go +++ b/cmd/litestream/snapshots_test.go @@ -18,7 +18,7 @@ func TestSnapshotsCommand(t *testing.T) { m, _, stdout, _ := newMain() if err := m.Run(context.Background(), []string{"snapshots", "-config", filepath.Join(testDir, "litestream.yml"), filepath.Join(testDir, "db")}); err != nil { t.Fatal(err) - } else if got, want := stdout.String(), 
string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) } }) @@ -30,7 +30,7 @@ func TestSnapshotsCommand(t *testing.T) { m, _, stdout, _ := newMain() if err := m.Run(context.Background(), []string{"snapshots", "-config", filepath.Join(testDir, "litestream.yml"), "-replica", "replica1", filepath.Join(testDir, "db")}); err != nil { t.Fatal(err) - } else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) } }) @@ -43,7 +43,7 @@ func TestSnapshotsCommand(t *testing.T) { m, _, stdout, _ := newMain() if err := m.Run(context.Background(), []string{"snapshots", replicaURL}); err != nil { t.Fatal(err) - } else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) } }) diff --git a/cmd/litestream/wal.go b/cmd/litestream/wal.go index fa28107..1124c03 100644 --- a/cmd/litestream/wal.go +++ b/cmd/litestream/wal.go @@ -69,7 +69,7 @@ func (c *WALCommand) Run(ctx context.Context, args []string) (ret error) { if c.generation != "" { generations = []string{c.generation} } else { - if generations, err = r.Client.Generations(ctx); err != nil { + if generations, err = r.Client().Generations(ctx); err != nil { log.Printf("%s: cannot determine generations: %s", r.Name(), err) ret = errExit // signal error return without printing message continue @@ -78,7 +78,7 @@ func (c *WALCommand) Run(ctx context.Context, args []string) (ret error) { for _, generation := range generations { if err := func() error { - itr, err := r.Client.WALSegments(ctx, generation) + itr, err := r.Client().WALSegments(ctx, generation) if err != nil { return err } diff --git a/cmd/litestream/wal_test.go b/cmd/litestream/wal_test.go index f313e01..6fbe0b0 100644 --- a/cmd/litestream/wal_test.go +++ b/cmd/litestream/wal_test.go @@ -18,7 +18,7 @@ func TestWALCommand(t *testing.T) { m, _, stdout, _ := newMain() if err := m.Run(context.Background(), []string{"wal", "-config", filepath.Join(testDir, "litestream.yml"), filepath.Join(testDir, "db")}); err != nil { t.Fatal(err) - } else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) } }) @@ -30,7 +30,7 @@ func TestWALCommand(t *testing.T) { m, _, stdout, _ := newMain() if err := m.Run(context.Background(), []string{"wal", "-config", filepath.Join(testDir, "litestream.yml"), "-replica", "replica1", filepath.Join(testDir, "db")}); err != nil { t.Fatal(err) - } else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) } }) @@ -43,7 +43,7 @@ func TestWALCommand(t *testing.T) { m, _, stdout, _ := newMain() 
if err := m.Run(context.Background(), []string{"wal", replicaURL}); err != nil { t.Fatal(err) - } else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want { + } else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want { t.Fatalf("stdout=%q, want %q", got, want) } }) diff --git a/db.go b/db.go index f56bdc5..ed5a97f 100644 --- a/db.go +++ b/db.go @@ -121,7 +121,7 @@ func NewDB(path string) *DB { CheckpointInterval: DefaultCheckpointInterval, MonitorInterval: DefaultMonitorInterval, - Logger: log.New(LogWriter, fmt.Sprintf("%s: ", path), LogFlags), + Logger: log.New(LogWriter, fmt.Sprintf("%s: ", logPrefixPath(path)), LogFlags), } db.dbSizeGauge = dbSizeGaugeVec.WithLabelValues(db.path) @@ -300,7 +300,7 @@ func (db *DB) invalidateChecksum(ctx context.Context) error { r := &io.LimitedReader{R: rc, N: db.pos.Offset} // Determine cache values from the current WAL file. - db.salt0, db.salt1, db.chksum0, db.chksum1, db.byteOrder, db.frame, err = ReadWALFields(r, db.pageSize) + db.salt0, db.salt1, db.chksum0, db.chksum1, db.byteOrder, db.hdr, db.frame, err = ReadWALFields(r, db.pageSize) if err != nil { return fmt.Errorf("calc checksum: %w", err) } @@ -1621,11 +1621,11 @@ var walPathRegex = regexp.MustCompile(`^([0-9a-f]{8})\.wal$`) // Returns salt, checksum, byte order & the last frame. WAL data must start // from the beginning of the WAL header and must end on either the WAL header // or at the end of a WAL frame. -func ReadWALFields(r io.Reader, pageSize int) (salt0, salt1, chksum0, chksum1 uint32, byteOrder binary.ByteOrder, frame []byte, err error) { +func ReadWALFields(r io.Reader, pageSize int) (salt0, salt1, chksum0, chksum1 uint32, byteOrder binary.ByteOrder, hdr, frame []byte, err error) { // Read header. - hdr := make([]byte, WALHeaderSize) + hdr = make([]byte, WALHeaderSize) if _, err := io.ReadFull(r, hdr); err != nil { - return 0, 0, 0, 0, nil, nil, fmt.Errorf("short wal header: %w", err) + return 0, 0, 0, 0, nil, nil, nil, fmt.Errorf("short wal header: %w", err) } // Save salt, initial checksum, & byte order. @@ -1634,7 +1634,7 @@ func ReadWALFields(r io.Reader, pageSize int) (salt0, salt1, chksum0, chksum1 ui chksum0 = binary.BigEndian.Uint32(hdr[24:]) chksum1 = binary.BigEndian.Uint32(hdr[28:]) if byteOrder, err = headerByteOrder(hdr); err != nil { - return 0, 0, 0, 0, nil, nil, err + return 0, 0, 0, 0, nil, nil, nil, err } // Iterate over each page in the WAL and save the checksum. @@ -1645,7 +1645,7 @@ func ReadWALFields(r io.Reader, pageSize int) (salt0, salt1, chksum0, chksum1 ui if n, err := io.ReadFull(r, frame); err == io.EOF { break // end of WAL file } else if err != nil { - return 0, 0, 0, 0, nil, nil, fmt.Errorf("short wal frame (n=%d): %w", n, err) + return 0, 0, 0, 0, nil, nil, nil, fmt.Errorf("short wal frame (n=%d): %w", n, err) } // Update checksum on each successful frame. @@ -1659,7 +1659,7 @@ func ReadWALFields(r io.Reader, pageSize int) (salt0, salt1, chksum0, chksum1 ui frame = nil } - return salt0, salt1, chksum0, chksum1, byteOrder, frame, nil + return salt0, salt1, chksum0, chksum1, byteOrder, hdr, frame, nil } // Database metrics. @@ -1731,3 +1731,12 @@ func headerByteOrder(hdr []byte) (binary.ByteOrder, error) { return nil, fmt.Errorf("invalid wal header magic: %x", magic) } } + +// logPrefixPath returns the path to be used for logging. +// The path is reduced to its base if it appears to be a temporary test path. 
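This keeps log prefixes stable across test runs even though t.TempDir() embeds a unique per-run directory name containing the test name. The helper follows immediately below; a hypothetical test inside package litestream illustrates the intended behavior (both paths are invented for illustration):

```go
func TestLogPrefixPath(t *testing.T) {
	// Temp paths created by the integration tests contain "TestCmd",
	// so only the base name is used as the log prefix.
	if got, want := logPrefixPath("/tmp/TestCmd_Replicate_OK_1234/db"), "db"; got != want {
		t.Fatalf("logPrefixPath()=%q, want %q", got, want)
	}
	// Ordinary production paths are logged in full.
	if got, want := logPrefixPath("/var/lib/app/db"), "/var/lib/app/db"; got != want {
		t.Fatalf("logPrefixPath()=%q, want %q", got, want)
	}
}
```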
+func logPrefixPath(path string) string { + if strings.Contains(path, "TestCmd") { + return filepath.Base(path) + } + return path +} diff --git a/db_test.go b/db_test.go index 5c3f51c..fe74225 100644 --- a/db_test.go +++ b/db_test.go @@ -482,7 +482,7 @@ func TestReadWALFields(t *testing.T) { } t.Run("OK", func(t *testing.T) { - if salt0, salt1, chksum0, chksum1, byteOrder, frame, err := litestream.ReadWALFields(bytes.NewReader(b), 4096); err != nil { + if salt0, salt1, chksum0, chksum1, byteOrder, _, frame, err := litestream.ReadWALFields(bytes.NewReader(b), 4096); err != nil { t.Fatal(err) } else if got, want := salt0, uint32(0x4F7598FD); got != want { t.Fatalf("salt0=%x, want %x", got, want) @@ -500,7 +500,7 @@ func TestReadWALFields(t *testing.T) { }) t.Run("HeaderOnly", func(t *testing.T) { - if salt0, salt1, chksum0, chksum1, byteOrder, frame, err := litestream.ReadWALFields(bytes.NewReader(b[:32]), 4096); err != nil { + if salt0, salt1, chksum0, chksum1, byteOrder, _, frame, err := litestream.ReadWALFields(bytes.NewReader(b[:32]), 4096); err != nil { t.Fatal(err) } else if got, want := salt0, uint32(0x4F7598FD); got != want { t.Fatalf("salt0=%x, want %x", got, want) @@ -518,19 +518,19 @@ func TestReadWALFields(t *testing.T) { }) t.Run("ErrShortHeader", func(t *testing.T) { - if _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader([]byte{}), 4096); err == nil || err.Error() != `short wal header: EOF` { + if _, _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader([]byte{}), 4096); err == nil || err.Error() != `short wal header: EOF` { t.Fatal(err) } }) t.Run("ErrBadMagic", func(t *testing.T) { - if _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader(make([]byte, 32)), 4096); err == nil || err.Error() != `invalid wal header magic: 0` { + if _, _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader(make([]byte, 32)), 4096); err == nil || err.Error() != `invalid wal header magic: 0` { t.Fatal(err) } }) t.Run("ErrShortFrame", func(t *testing.T) { - if _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader(b[:100]), 4096); err == nil || err.Error() != `short wal frame (n=68): unexpected EOF` { + if _, _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader(b[:100]), 4096); err == nil || err.Error() != `short wal frame (n=68): unexpected EOF` { t.Fatal(err) } }) diff --git a/integration/cmd_test.go b/integration/cmd_test.go new file mode 100644 index 0000000..ab9c1c0 --- /dev/null +++ b/integration/cmd_test.go @@ -0,0 +1,411 @@ +package integration_test + +import ( + "bytes" + "context" + "database/sql" + "flag" + "fmt" + "io" + "math/rand" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/benbjohnson/litestream/internal" + "github.com/benbjohnson/litestream/internal/testingutil" + _ "github.com/mattn/go-sqlite3" +) + +var longRunningDuration = flag.Duration("long-running-duration", 0, "") + +func init() { + fmt.Fprintln(os.Stderr, "# ") + fmt.Fprintln(os.Stderr, "# NOTE: Build litestream to your PATH before running integration tests") + fmt.Fprintln(os.Stderr, "#") + fmt.Fprintln(os.Stderr, "") +} + +// Ensure the default configuration works with light database load. 
+func TestCmd_Replicate_OK(t *testing.T) { + ctx := context.Background() + testDir, tempDir := filepath.Join("testdata", "replicate", "ok"), t.TempDir() + env := []string{"LITESTREAM_TEMPDIR=" + tempDir} + + cmd, stdout, _ := commandContext(ctx, env, "replicate", "-config", filepath.Join(testDir, "litestream.yml")) + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + db, err := sql.Open("sqlite3", filepath.Join(tempDir, "db")) + if err != nil { + t.Fatal(err) + } else if _, err := db.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil { + t.Fatal(err) + } + defer db.Close() + + // Execute writes periodically. + for i := 0; i < 100; i++ { + t.Logf("[exec] INSERT INTO t (id) VALUES (%d)", i) + if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (?)`, i); err != nil { + t.Fatal(err) + } + time.Sleep(10 * time.Millisecond) + } + + // Stop & wait for Litestream command. + killLitestreamCmd(t, cmd, stdout) + + // Ensure signal and shutdown are logged. + if s := stdout.String(); !strings.Contains(s, `signal received, litestream shutting down`) { + t.Fatal("missing log output for signal received") + } else if s := stdout.String(); !strings.Contains(s, `litestream shut down`) { + t.Fatal("missing log output for shut down") + } + + // Checkpoint & verify original SQLite database. + if _, err := db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil { + t.Fatal(err) + } + restoreAndVerify(t, ctx, env, filepath.Join(testDir, "litestream.yml"), filepath.Join(tempDir, "db")) +} + +// Ensure that stopping and restarting Litestream before an application-induced +// checkpoint will cause Litestream to continue replicating using the same generation. +func TestCmd_Replicate_ResumeWithCurrentGeneration(t *testing.T) { + ctx := context.Background() + testDir, tempDir := filepath.Join("testdata", "replicate", "resume-with-current-generation"), t.TempDir() + env := []string{"LITESTREAM_TEMPDIR=" + tempDir} + + cmd, stdout, _ := commandContext(ctx, env, "replicate", "-config", filepath.Join(testDir, "litestream.yml")) + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + t.Log("writing to database during replication") + + db, err := sql.Open("sqlite3", filepath.Join(tempDir, "db")) + if err != nil { + t.Fatal(err) + } else if _, err := db.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil { + t.Fatal(err) + } + defer db.Close() + + // Execute a few writes to populate the WAL. + if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (1)`); err != nil { + t.Fatal(err) + } else if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (2)`); err != nil { + t.Fatal(err) + } + + // Wait for replication to occur & shutdown. + waitForLogMessage(t, stdout, `wal segment written`) + killLitestreamCmd(t, cmd, stdout) + t.Log("replication shutdown, continuing database writes") + + // Execute a few more writes while replication is stopped. + if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (3)`); err != nil { + t.Fatal(err) + } else if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (4)`); err != nil { + t.Fatal(err) + } + + t.Log("restarting replication") + + cmd, stdout, _ = commandContext(ctx, env, "replicate", "-config", filepath.Join(testDir, "litestream.yml")) + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + waitForLogMessage(t, stdout, `wal segment written`) + killLitestreamCmd(t, cmd, stdout) + + t.Log("replication shutdown again") + + // Litestream should resume replication from the previous generation. 
+	if s := stdout.String(); strings.Contains(s, "no generation exists") {
+		t.Fatal("expected existing generation to resume; started new generation instead")
+	}
+
+	// Checkpoint & verify original SQLite database.
+	if _, err := db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil {
+		t.Fatal(err)
+	}
+	restoreAndVerify(t, ctx, env, filepath.Join(testDir, "litestream.yml"), filepath.Join(tempDir, "db"))
+}
+
+// Ensure that restarting Litestream after a full checkpoint has occurred will
+// cause it to begin a new generation.
+func TestCmd_Replicate_ResumeWithNewGeneration(t *testing.T) {
+	ctx := context.Background()
+	testDir, tempDir := filepath.Join("testdata", "replicate", "resume-with-new-generation"), t.TempDir()
+	env := []string{"LITESTREAM_TEMPDIR=" + tempDir}
+
+	cmd, stdout, _ := commandContext(ctx, env, "replicate", "-config", filepath.Join(testDir, "litestream.yml"))
+	if err := cmd.Start(); err != nil {
+		t.Fatal(err)
+	}
+
+	t.Log("writing to database during replication")
+
+	db, err := sql.Open("sqlite3", filepath.Join(tempDir, "db"))
+	if err != nil {
+		t.Fatal(err)
+	} else if _, err := db.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	// Execute a few writes to populate the WAL.
+	if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (1)`); err != nil {
+		t.Fatal(err)
+	} else if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (2)`); err != nil {
+		t.Fatal(err)
+	}
+
+	// Wait for replication to occur & shutdown.
+	waitForLogMessage(t, stdout, `wal segment written`)
+	killLitestreamCmd(t, cmd, stdout)
+	t.Log("replication shutdown, continuing database writes")
+
+	// Execute a few more writes while replication is stopped.
+	if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (3)`); err != nil {
+		t.Fatal(err)
+	} else if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (4)`); err != nil {
+		t.Fatal(err)
+	}
+
+	t.Log("issuing checkpoint")
+
+	// Issue a checkpoint to restart WAL.
+	if _, err := db.ExecContext(ctx, `PRAGMA wal_checkpoint(RESTART)`); err != nil {
+		t.Fatal(err)
+	} else if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (5)`); err != nil {
+		t.Fatal(err)
+	}
+
+	t.Log("restarting replication")
+
+	cmd, stdout, _ = commandContext(ctx, env, "replicate", "-config", filepath.Join(testDir, "litestream.yml"))
+	if err := cmd.Start(); err != nil {
+		t.Fatal(err)
+	}
+
+	waitForLogMessage(t, stdout, `wal segment written`)
+	killLitestreamCmd(t, cmd, stdout)
+
+	t.Log("replication shutdown again")
+
+	// Litestream should start a new generation after the full checkpoint.
+	if s := stdout.String(); !strings.Contains(s, "no generation exists") {
+		t.Fatal("expected new generation to start; continued existing generation instead")
+	}
+
+	// Checkpoint & verify original SQLite database.
+	if _, err := db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil {
+		t.Fatal(err)
+	}
+	restoreAndVerify(t, ctx, env, filepath.Join(testDir, "litestream.yml"), filepath.Join(tempDir, "db"))
+}
+
+// Ensure the default configuration works with heavy write load.
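The high-load test defined next disables SQLite's built-in autocheckpointing so that only Litestream performs checkpoints, exercising the `max-checkpoint-page-count: 20` setting in its config file. Condensed, the setup pragmas it issues are (fragment; db, ctx, and t are as in that test):

```go
// WAL journaling, relaxed durability, and autocheckpoint off so
// Litestream alone is responsible for checkpointing the WAL.
for _, pragma := range []string{
	`PRAGMA journal_mode = WAL`,
	`PRAGMA synchronous = NORMAL`,
	`PRAGMA wal_autocheckpoint = 0`,
} {
	if _, err := db.ExecContext(ctx, pragma); err != nil {
		t.Fatal(err)
	}
}
```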
+func TestCmd_Replicate_HighLoad(t *testing.T) { + if testing.Short() { + t.Skip("short mode enabled, skipping") + } + + const writeDuration = 30 * time.Second + + ctx := context.Background() + testDir, tempDir := filepath.Join("testdata", "replicate", "high-load"), t.TempDir() + env := []string{"LITESTREAM_TEMPDIR=" + tempDir} + + cmd, stdout, _ := commandContext(ctx, env, "replicate", "-config", filepath.Join(testDir, "litestream.yml")) + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + db, err := sql.Open("sqlite3", filepath.Join(tempDir, "db")) + if err != nil { + t.Fatal(err) + } else if _, err := db.ExecContext(ctx, `PRAGMA journal_mode = WAL`); err != nil { + t.Fatal(err) + } else if _, err := db.ExecContext(ctx, `PRAGMA synchronous = NORMAL`); err != nil { + t.Fatal(err) + } else if _, err := db.ExecContext(ctx, `PRAGMA wal_autocheckpoint = 0`); err != nil { + t.Fatal(err) + } else if _, err := db.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil { + t.Fatal(err) + } + defer db.Close() + + // Execute writes as fast as possible for a period of time. + timer := time.NewTimer(writeDuration) + defer timer.Stop() + + t.Logf("executing writes for %s", writeDuration) + +LOOP: + for i := 0; ; i++ { + select { + case <-timer.C: + break LOOP + default: + if i%1000 == 0 { + t.Logf("[exec] INSERT INTO t (id) VALUES (%d)", i) + } + if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (?)`, i); err != nil { + t.Fatal(err) + } + } + } + + t.Logf("writes complete, shutting down") + + // Stop & wait for Litestream command. + time.Sleep(5 * time.Second) + killLitestreamCmd(t, cmd, stdout) + + // Checkpoint & verify original SQLite database. + if _, err := db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil { + t.Fatal(err) + } + restoreAndVerify(t, ctx, env, filepath.Join(testDir, "litestream.yml"), filepath.Join(tempDir, "db")) +} + +// Ensure replication works for an extended period. +func TestCmd_Replicate_LongRunning(t *testing.T) { + if *longRunningDuration == 0 { + t.Skip("long running test duration not specified, skipping") + } + + ctx := context.Background() + testDir, tempDir := filepath.Join("testdata", "replicate", "long-running"), t.TempDir() + env := []string{"LITESTREAM_TEMPDIR=" + tempDir} + + cmd, stdout, _ := commandContext(ctx, env, "replicate", "-config", filepath.Join(testDir, "litestream.yml")) + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + db, err := sql.Open("sqlite3", filepath.Join(tempDir, "db")) + if err != nil { + t.Fatal(err) + } else if _, err := db.ExecContext(ctx, `PRAGMA journal_mode = WAL`); err != nil { + t.Fatal(err) + } else if _, err := db.ExecContext(ctx, `PRAGMA synchronous = NORMAL`); err != nil { + t.Fatal(err) + } else if _, err := db.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil { + t.Fatal(err) + } + defer db.Close() + + // Execute writes as fast as possible for a period of time. + timer := time.NewTimer(*longRunningDuration) + defer timer.Stop() + + t.Logf("executing writes for %s", longRunningDuration) + +LOOP: + for i := 0; ; i++ { + select { + case <-timer.C: + break LOOP + default: + t.Logf("[exec] INSERT INTO t (id) VALUES (%d)", i) + if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (?)`, i); err != nil { + t.Fatal(err) + } + + time.Sleep(time.Duration(rand.Intn(int(time.Second)))) + } + } + + t.Logf("writes complete, shutting down") + + // Stop & wait for Litestream command. 
+	killLitestreamCmd(t, cmd, stdout)
+
+	// Checkpoint & verify original SQLite database.
+	if _, err := db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil {
+		t.Fatal(err)
+	}
+	restoreAndVerify(t, ctx, env, filepath.Join(testDir, "litestream.yml"), filepath.Join(tempDir, "db"))
+}
+
+// commandContext returns a "litestream" command with stdout/stderr buffers.
+func commandContext(ctx context.Context, env []string, arg ...string) (cmd *exec.Cmd, stdout, stderr *internal.LockingBuffer) {
+	cmd = exec.CommandContext(ctx, "litestream", arg...)
+	cmd.Env = env
+	var outBuf, errBuf internal.LockingBuffer
+
+	// Split stdout/stderr to terminal if verbose flag set.
+	cmd.Stdout, cmd.Stderr = &outBuf, &errBuf
+	if testing.Verbose() {
+		cmd.Stdout = io.MultiWriter(&outBuf, os.Stdout)
+		cmd.Stderr = io.MultiWriter(&errBuf, os.Stderr)
+	}
+
+	return cmd, &outBuf, &errBuf
+}
+
+// waitForLogMessage continuously checks b for a message and returns when it occurs.
+func waitForLogMessage(tb testing.TB, b *internal.LockingBuffer, msg string) {
+	timer := time.NewTimer(30 * time.Second)
+	defer timer.Stop()
+	ticker := time.NewTicker(10 * time.Millisecond)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-timer.C:
+			tb.Fatal("timed out waiting for log message")
+		case <-ticker.C:
+			if strings.Contains(b.String(), msg) {
+				return
+			}
+		}
+	}
+}
+
+// killLitestreamCmd interrupts the process and waits for a clean shutdown.
+func killLitestreamCmd(tb testing.TB, cmd *exec.Cmd, stdout *internal.LockingBuffer) {
+	if err := cmd.Process.Signal(os.Interrupt); err != nil {
+		tb.Fatal(err)
+	} else if err := cmd.Wait(); err != nil {
+		tb.Fatal(err)
+	}
+}
+
+// restoreAndVerify executes a "restore" and compares the result byte-for-byte with the original database.
+func restoreAndVerify(tb testing.TB, ctx context.Context, env []string, configPath, dbPath string) {
+	restorePath := filepath.Join(tb.TempDir(), "db")
+
+	// Restore database.
+	cmd, _, _ := commandContext(ctx, env, "restore", "-config", configPath, "-o", restorePath, dbPath)
+	if err := cmd.Run(); err != nil {
+		tb.Fatalf("error running 'restore' command: %s", err)
+	}
+
+	// Compare original database & restored database.
+	buf0 := testingutil.ReadFile(tb, dbPath)
+	buf1 := testingutil.ReadFile(tb, restorePath)
+	if bytes.Equal(buf0, buf1) {
+		return // ok, exit
+	}
+
+	// On mismatch, copy out original & restored DBs.
+ dir, err := os.MkdirTemp("", "litestream-*") + if err != nil { + tb.Fatal(err) + } + testingutil.CopyFile(tb, dbPath, filepath.Join(dir, "original.db")) + testingutil.CopyFile(tb, restorePath, filepath.Join(dir, "restored.db")) + + tb.Fatalf("database mismatch; databases copied to %s", dir) +} diff --git a/integration/testdata/replicate/high-load/litestream.yml b/integration/testdata/replicate/high-load/litestream.yml new file mode 100644 index 0000000..26fb119 --- /dev/null +++ b/integration/testdata/replicate/high-load/litestream.yml @@ -0,0 +1,7 @@ +dbs: + - path: $LITESTREAM_TEMPDIR/db + replicas: + - path: $LITESTREAM_TEMPDIR/replica + + monitor-interval: 100ms + max-checkpoint-page-count: 20 diff --git a/integration/testdata/replicate/long-running/litestream.yml b/integration/testdata/replicate/long-running/litestream.yml new file mode 100644 index 0000000..b7d0e0e --- /dev/null +++ b/integration/testdata/replicate/long-running/litestream.yml @@ -0,0 +1,4 @@ +dbs: + - path: $LITESTREAM_TEMPDIR/db + replicas: + - path: $LITESTREAM_TEMPDIR/replica diff --git a/integration/testdata/replicate/ok/litestream.yml b/integration/testdata/replicate/ok/litestream.yml new file mode 100644 index 0000000..26fb119 --- /dev/null +++ b/integration/testdata/replicate/ok/litestream.yml @@ -0,0 +1,7 @@ +dbs: + - path: $LITESTREAM_TEMPDIR/db + replicas: + - path: $LITESTREAM_TEMPDIR/replica + + monitor-interval: 100ms + max-checkpoint-page-count: 20 diff --git a/integration/testdata/replicate/resume-with-current-generation/litestream.yml b/integration/testdata/replicate/resume-with-current-generation/litestream.yml new file mode 100644 index 0000000..b7d0e0e --- /dev/null +++ b/integration/testdata/replicate/resume-with-current-generation/litestream.yml @@ -0,0 +1,4 @@ +dbs: + - path: $LITESTREAM_TEMPDIR/db + replicas: + - path: $LITESTREAM_TEMPDIR/replica diff --git a/integration/testdata/replicate/resume-with-new-generation/litestream.yml b/integration/testdata/replicate/resume-with-new-generation/litestream.yml new file mode 100644 index 0000000..b7d0e0e --- /dev/null +++ b/integration/testdata/replicate/resume-with-new-generation/litestream.yml @@ -0,0 +1,4 @@ +dbs: + - path: $LITESTREAM_TEMPDIR/db + replicas: + - path: $LITESTREAM_TEMPDIR/replica diff --git a/integration/testdata/replicate/resume/litestream.yml b/integration/testdata/replicate/resume/litestream.yml new file mode 100644 index 0000000..0bdd84e --- /dev/null +++ b/integration/testdata/replicate/resume/litestream.yml @@ -0,0 +1,7 @@ +dbs: + - path: $LITESTREAM_TEMPDIR/db + replicas: + - path: $LITESTREAM_TEMPDIR/replica + + monitor-interval: 100ms + max-checkpoint-page-count: 10 diff --git a/internal/locking_buffer.go b/internal/locking_buffer.go new file mode 100644 index 0000000..5a95df9 --- /dev/null +++ b/internal/locking_buffer.go @@ -0,0 +1,145 @@ +package internal + +import ( + "bytes" + "io" + "sync" +) + +// LockingBuffer wraps a bytes.Buffer with a mutex. 
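The type defined next exists because the harness above writes to the buffer from the child process's stdout pipe while waitForLogMessage polls String() from the test goroutine; a bare bytes.Buffer is not safe for that. A standalone usage sketch of the same pattern:

```go
package main

import (
	"fmt"
	"strings"
	"time"

	"github.com/benbjohnson/litestream/internal"
)

func main() {
	var buf internal.LockingBuffer

	// Writer goroutine, standing in for the exec'd process's stdout.
	go func() {
		for i := 0; ; i++ {
			fmt.Fprintf(&buf, "line %d\n", i)
			time.Sleep(10 * time.Millisecond)
		}
	}()

	// Reader side polls safely, as waitForLogMessage does.
	for !strings.Contains(buf.String(), "line 5") {
		time.Sleep(10 * time.Millisecond)
	}
	fmt.Println("saw line 5")
}
```

The full method set, each call guarded by the mutex, follows.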
+type LockingBuffer struct { + mu sync.Mutex + b bytes.Buffer +} + +func (b *LockingBuffer) Bytes() []byte { + b.mu.Lock() + defer b.mu.Unlock() + buf := b.b.Bytes() + other := make([]byte, len(buf)) + copy(other, buf) + return other +} + +func (b *LockingBuffer) Cap() int { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.Cap() +} + +func (b *LockingBuffer) Grow(n int) { + b.mu.Lock() + defer b.mu.Unlock() + b.b.Grow(n) +} + +func (b *LockingBuffer) Len() int { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.Len() +} + +func (b *LockingBuffer) Next(n int) []byte { + b.mu.Lock() + defer b.mu.Unlock() + buf := b.b.Next(n) + other := make([]byte, len(buf)) + copy(other, buf) + return other +} + +func (b *LockingBuffer) Read(p []byte) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.Read(p) +} + +func (b *LockingBuffer) ReadByte() (byte, error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.ReadByte() +} + +func (b *LockingBuffer) ReadBytes(delim byte) (line []byte, err error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.ReadBytes(delim) +} + +func (b *LockingBuffer) ReadFrom(r io.Reader) (n int64, err error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.ReadFrom(r) +} + +func (b *LockingBuffer) ReadRune() (r rune, size int, err error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.ReadRune() +} + +func (b *LockingBuffer) ReadString(delim byte) (line string, err error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.ReadString(delim) +} + +func (b *LockingBuffer) Reset() { + b.mu.Lock() + defer b.mu.Unlock() + b.b.Reset() +} + +func (b *LockingBuffer) String() string { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.String() +} + +func (b *LockingBuffer) Truncate(n int) { + b.mu.Lock() + defer b.mu.Unlock() + b.b.Truncate(n) +} + +func (b *LockingBuffer) UnreadByte() error { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.UnreadByte() +} + +func (b *LockingBuffer) UnreadRune() error { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.UnreadRune() +} + +func (b *LockingBuffer) Write(p []byte) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.Write(p) +} + +func (b *LockingBuffer) WriteByte(c byte) error { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.WriteByte(c) +} + +func (b *LockingBuffer) WriteRune(r rune) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.WriteRune(r) +} + +func (b *LockingBuffer) WriteString(s string) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.WriteString(s) +} + +func (b *LockingBuffer) WriteTo(w io.Writer) (n int64, err error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.b.WriteTo(w) +} diff --git a/internal/testingutil/testingutil.go b/internal/testingutil/testingutil.go index 99bc47f..22636f2 100644 --- a/internal/testingutil/testingutil.go +++ b/internal/testingutil/testingutil.go @@ -1,12 +1,13 @@ package testingutil import ( + "io" "os" "testing" ) -// MustReadFile reads all data from filename. Fail on error. -func MustReadFile(tb testing.TB, filename string) []byte { +// ReadFile reads all data from filename. Fail on error. +func ReadFile(tb testing.TB, filename string) []byte { tb.Helper() b, err := os.ReadFile(filename) if err != nil { @@ -15,6 +16,26 @@ func MustReadFile(tb testing.TB, filename string) []byte { return b } +// CopyFile copies all data from src to dst. Fail on error. 
+func CopyFile(tb testing.TB, src, dst string) { + tb.Helper() + r, err := os.Open(src) + if err != nil { + tb.Fatal(err) + } + defer r.Close() + + w, err := os.Create(dst) + if err != nil { + tb.Fatal(err) + } + defer w.Close() + + if _, err := io.Copy(w, r); err != nil { + tb.Fatal(err) + } +} + // Getpwd returns the working directory. Fail on error. func Getwd(tb testing.TB) string { tb.Helper() diff --git a/litestream.go b/litestream.go index e962f14..98f94a8 100644 --- a/litestream.go +++ b/litestream.go @@ -1,8 +1,10 @@ package litestream import ( + "crypto/md5" "database/sql" "encoding/binary" + "encoding/hex" "errors" "fmt" "io" @@ -43,7 +45,7 @@ var ( var ( // LogWriter is the destination writer for all logging. - LogWriter = os.Stderr + LogWriter = os.Stdout // LogFlags are the flags passed to log.New(). LogFlags = 0 @@ -460,6 +462,12 @@ func isHexChar(ch rune) bool { return (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f') } +// md5hash returns a hex-encoded MD5 hash of b. +func md5hash(b []byte) string { + sum := md5.Sum(b) + return hex.EncodeToString(sum[:]) +} + // Tracef is used for low-level tracing. var Tracef = func(format string, a ...interface{}) {} diff --git a/replica.go b/replica.go index 67e9d14..8547e61 100644 --- a/replica.go +++ b/replica.go @@ -43,7 +43,7 @@ type Replica struct { cancel func() // Client used to connect to the remote replica. - Client ReplicaClient + client ReplicaClient // Time between syncs with the shadow WAL. SyncInterval time.Duration @@ -68,10 +68,11 @@ type Replica struct { Logger *log.Logger } -func NewReplica(db *DB, name string) *Replica { +func NewReplica(db *DB, name string, client ReplicaClient) *Replica { r := &Replica{ db: db, name: name, + client: client, cancel: func() {}, SyncInterval: DefaultSyncInterval, @@ -82,7 +83,7 @@ func NewReplica(db *DB, name string) *Replica { prefix := fmt.Sprintf("%s: ", r.Name()) if db != nil { - prefix = fmt.Sprintf("%s(%s): ", db.Path(), r.Name()) + prefix = fmt.Sprintf("%s(%s): ", logPrefixPath(db.Path()), r.Name()) } r.Logger = log.New(LogWriter, prefix, LogFlags) @@ -91,8 +92,8 @@ func NewReplica(db *DB, name string) *Replica { // Name returns the name of the replica. func (r *Replica) Name() string { - if r.name == "" && r.Client != nil { - return r.Client.Type() + if r.name == "" && r.client != nil { + return r.client.Type() } return r.name } @@ -100,6 +101,9 @@ func (r *Replica) Name() string { // DB returns a reference to the database the replica is attached to, if any. func (r *Replica) DB() *DB { return r.db } +// Client returns the client the replica was initialized with. +func (r *Replica) Client() ReplicaClient { return r.client } + // Starts replicating in a background goroutine. func (r *Replica) Start(ctx context.Context) error { // Ignore if replica is being used sychronously. @@ -265,7 +269,7 @@ func (r *Replica) writeIndexSegments(ctx context.Context, segments []WALSegmentI // Copy through pipe into client from the starting position. var g errgroup.Group g.Go(func() error { - _, err := r.Client.WriteWALSegment(ctx, initialPos, pr) + _, err := r.client.WriteWALSegment(ctx, initialPos, pr) return err }) @@ -332,7 +336,7 @@ func (r *Replica) writeIndexSegments(ctx context.Context, segments []WALSegmentI // snapshotN returns the number of snapshots for a generation. 
func (r *Replica) snapshotN(generation string) (int, error) { - itr, err := r.Client.Snapshots(context.Background(), generation) + itr, err := r.client.Snapshots(context.Background(), generation) if err != nil { return 0, err } @@ -364,7 +368,7 @@ func (r *Replica) calcPos(ctx context.Context, generation string) (pos Pos, err } // Read segment to determine size to add to offset. - rd, err := r.Client.WALSegmentReader(ctx, segment.Pos()) + rd, err := r.client.WALSegmentReader(ctx, segment.Pos()) if err != nil { return pos, fmt.Errorf("wal segment reader: %w", err) } @@ -385,7 +389,7 @@ func (r *Replica) calcPos(ctx context.Context, generation string) (pos Pos, err // maxSnapshot returns the last snapshot in a generation. func (r *Replica) maxSnapshot(ctx context.Context, generation string) (*SnapshotInfo, error) { - itr, err := r.Client.Snapshots(ctx, generation) + itr, err := r.client.Snapshots(ctx, generation) if err != nil { return nil, err } @@ -402,7 +406,7 @@ func (r *Replica) maxSnapshot(ctx context.Context, generation string) (*Snapshot // maxWALSegment returns the highest WAL segment in a generation. func (r *Replica) maxWALSegment(ctx context.Context, generation string) (*WALSegmentInfo, error) { - itr, err := r.Client.WALSegments(ctx, generation) + itr, err := r.client.WALSegments(ctx, generation) if err != nil { return nil, err } @@ -427,7 +431,7 @@ func (r *Replica) Pos() Pos { // Snapshots returns a list of all snapshots across all generations. func (r *Replica) Snapshots(ctx context.Context) ([]SnapshotInfo, error) { - generations, err := r.Client.Generations(ctx) + generations, err := r.client.Generations(ctx) if err != nil { return nil, fmt.Errorf("cannot fetch generations: %w", err) } @@ -435,7 +439,7 @@ func (r *Replica) Snapshots(ctx context.Context) ([]SnapshotInfo, error) { var a []SnapshotInfo for _, generation := range generations { if err := func() error { - itr, err := r.Client.Snapshots(ctx, generation) + itr, err := r.client.Snapshots(ctx, generation) if err != nil { return err } @@ -518,7 +522,7 @@ func (r *Replica) Snapshot(ctx context.Context) (info SnapshotInfo, err error) { }) // Delegate write to client & wait for writer goroutine to finish. - if info, err = r.Client.WriteSnapshot(ctx, pos.Generation, pos.Index, pr); err != nil { + if info, err = r.client.WriteSnapshot(ctx, pos.Generation, pos.Index, pr); err != nil { return info, err } else if err := g.Wait(); err != nil { return info, err @@ -549,7 +553,7 @@ func (r *Replica) EnforceRetention(ctx context.Context) (err error) { } // Loop over generations and delete unretained snapshots & WAL files. - generations, err := r.Client.Generations(ctx) + generations, err := r.client.Generations(ctx) if err != nil { return fmt.Errorf("generations: %w", err) } @@ -559,7 +563,7 @@ func (r *Replica) EnforceRetention(ctx context.Context) (err error) { // Delete entire generation if no snapshots are being retained. 
if snapshot == nil { - if err := r.Client.DeleteGeneration(ctx, generation); err != nil { + if err := r.client.DeleteGeneration(ctx, generation); err != nil { return fmt.Errorf("delete generation: %w", err) } continue @@ -577,7 +581,7 @@ func (r *Replica) EnforceRetention(ctx context.Context) (err error) { } func (r *Replica) deleteSnapshotsBeforeIndex(ctx context.Context, generation string, index int) error { - itr, err := r.Client.Snapshots(ctx, generation) + itr, err := r.client.Snapshots(ctx, generation) if err != nil { return fmt.Errorf("fetch snapshots: %w", err) } @@ -589,7 +593,7 @@ func (r *Replica) deleteSnapshotsBeforeIndex(ctx context.Context, generation str continue } - if err := r.Client.DeleteSnapshot(ctx, info.Generation, info.Index); err != nil { + if err := r.client.DeleteSnapshot(ctx, info.Generation, info.Index); err != nil { return fmt.Errorf("delete snapshot %s/%08x: %w", info.Generation, info.Index, err) } r.Logger.Printf("snapshot deleted %s/%08x", generation, index) @@ -599,7 +603,7 @@ func (r *Replica) deleteSnapshotsBeforeIndex(ctx context.Context, generation str } func (r *Replica) deleteWALSegmentsBeforeIndex(ctx context.Context, generation string, index int) error { - itr, err := r.Client.WALSegments(ctx, generation) + itr, err := r.client.WALSegments(ctx, generation) if err != nil { return fmt.Errorf("fetch wal segments: %w", err) } @@ -621,7 +625,7 @@ func (r *Replica) deleteWALSegmentsBeforeIndex(ctx context.Context, generation s return nil } - if err := r.Client.DeleteWALSegments(ctx, a); err != nil { + if err := r.client.DeleteWALSegments(ctx, a); err != nil { return fmt.Errorf("delete wal segments: %w", err) } @@ -774,7 +778,7 @@ func (r *Replica) Validate(ctx context.Context) error { } // Find lastest snapshot that occurs before the index. - snapshotIndex, err := FindSnapshotForIndex(ctx, r.Client, pos.Generation, pos.Index-1) + snapshotIndex, err := FindSnapshotForIndex(ctx, r.client, pos.Generation, pos.Index-1) if err != nil { return fmt.Errorf("cannot find snapshot index: %w", err) } @@ -784,7 +788,7 @@ func (r *Replica) Validate(ctx context.Context) error { Logger: log.New(os.Stderr, "", 0), LogPrefix: r.logPrefix(), } - if err := Restore(ctx, r.Client, restorePath, pos.Generation, snapshotIndex, pos.Index-1, opt); err != nil { + if err := Restore(ctx, r.client, restorePath, pos.Generation, snapshotIndex, pos.Index-1, opt); err != nil { return fmt.Errorf("cannot restore: %w", err) } @@ -880,7 +884,7 @@ func (r *Replica) waitForReplica(ctx context.Context, pos Pos) error { func (r *Replica) GenerationCreatedAt(ctx context.Context, generation string) (time.Time, error) { var min time.Time - itr, err := r.Client.Snapshots(ctx, generation) + itr, err := r.client.Snapshots(ctx, generation) if err != nil { return min, err } @@ -897,7 +901,7 @@ func (r *Replica) GenerationCreatedAt(ctx context.Context, generation string) (t // SnapshotIndexAt returns the highest index for a snapshot within a generation // that occurs before timestamp. If timestamp is zero, returns the latest snapshot. 
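That contract (a zero timestamp means "latest") is what the restore path relies on when no point-in-time flag is given. A sketch of a point-in-time lookup (fragment; r and ctx in scope, and the generation string is hypothetical):

```go
// Newest snapshot as of one hour ago; a zero time.Time would select
// the latest snapshot in the generation instead.
index, err := r.SnapshotIndexAt(ctx, "b16ddcf5c697540f", time.Now().Add(-time.Hour))
if err != nil {
	return err
}
fmt.Printf("snapshot index: %08x\n", index)
```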
func (r *Replica) SnapshotIndexAt(ctx context.Context, generation string, timestamp time.Time) (int, error) { - itr, err := r.Client.Snapshots(ctx, generation) + itr, err := r.client.Snapshots(ctx, generation) if err != nil { return 0, err } @@ -929,7 +933,7 @@ func LatestReplica(ctx context.Context, replicas []*Replica) (*Replica, error) { var t time.Time var r *Replica for i := range replicas { - _, max, err := ReplicaClientTimeBounds(ctx, replicas[i].Client) + _, max, err := ReplicaClientTimeBounds(ctx, replicas[i].client) if err != nil { return nil, err } else if r == nil || max.After(t) { diff --git a/replica_test.go b/replica_test.go index a0220bb..97455ac 100644 --- a/replica_test.go +++ b/replica_test.go @@ -14,13 +14,12 @@ import ( func TestReplica_Name(t *testing.T) { t.Run("WithName", func(t *testing.T) { - if got, want := litestream.NewReplica(nil, "NAME").Name(), "NAME"; got != want { + if got, want := litestream.NewReplica(nil, "NAME", nil).Name(), "NAME"; got != want { t.Fatalf("Name()=%v, want %v", got, want) } }) t.Run("WithoutName", func(t *testing.T) { - r := litestream.NewReplica(nil, "") - r.Client = &mock.ReplicaClient{} + r := litestream.NewReplica(nil, "", &mock.ReplicaClient{}) if got, want := r.Name(), "mock"; got != want { t.Fatalf("Name()=%v, want %v", got, want) } @@ -45,8 +44,7 @@ func TestReplica_Sync(t *testing.T) { dpos := db.Pos() c := litestream.NewFileReplicaClient(t.TempDir()) - r := litestream.NewReplica(db, "") - r.Client = c + r := litestream.NewReplica(db, "", c) if err := r.Sync(context.Background()); err != nil { t.Fatal(err) @@ -81,8 +79,7 @@ func TestReplica_Snapshot(t *testing.T) { defer MustCloseDBs(t, db, sqldb) c := litestream.NewFileReplicaClient(t.TempDir()) - r := litestream.NewReplica(db, "") - r.Client = c + r := litestream.NewReplica(db, "", c) // Execute a query to force a write to the WAL. if _, err := sqldb.Exec(`CREATE TABLE foo (bar TEXT);`); err != nil {
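Taken as a whole, the patch's verification strategy is byte equality between the live database and a restored copy. That comparison is only valid because each test first issues `PRAGMA wal_checkpoint(TRUNCATE)`, which folds every WAL frame back into the main database file and truncates the WAL, so the two files should match exactly. Condensed, the closing sequence every replicate test runs (fragment; paths as in the tests above):

```go
// Fold the WAL into the main database file, then compare the live
// database byte-for-byte against a freshly restored copy.
if _, err := db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil {
	t.Fatal(err)
}
restoreAndVerify(t, ctx, env, filepath.Join(testDir, "litestream.yml"), filepath.Join(tempDir, "db"))
```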