Add end-to-end replication/restore testing

Ben Johnson
2022-01-14 15:31:04 -07:00
parent f308e0b154
commit 84d08f547a
27 changed files with 755 additions and 117 deletions

View File

@@ -29,6 +29,9 @@ jobs:
env:
LITESTREAM_SFTP_KEY: ${{secrets.LITESTREAM_SFTP_KEY}}
- name: Build binary
run: go install ./cmd/litestream
- name: Run unit tests
run: make testdata && go test -v ./...
@@ -53,10 +56,13 @@ jobs:
LITESTREAM_ABS_ACCOUNT_KEY: ${{ secrets.LITESTREAM_ABS_ACCOUNT_KEY }}
LITESTREAM_ABS_BUCKET: integration
- name: Run sftp tests
run: go test -v -run=TestReplicaClient ./integration -replica-type sftp
env:
LITESTREAM_SFTP_HOST: ${{ secrets.LITESTREAM_SFTP_HOST }}
LITESTREAM_SFTP_USER: ${{ secrets.LITESTREAM_SFTP_USER }}
LITESTREAM_SFTP_KEY_PATH: /opt/id_ed25519
LITESTREAM_SFTP_PATH: ${{ secrets.LITESTREAM_SFTP_PATH }}
#- name: Run sftp tests
# run: go test -v -run=TestReplicaClient ./integration -replica-type sftp
# env:
# LITESTREAM_SFTP_HOST: ${{ secrets.LITESTREAM_SFTP_HOST }}
# LITESTREAM_SFTP_USER: ${{ secrets.LITESTREAM_SFTP_USER }}
# LITESTREAM_SFTP_KEY_PATH: /opt/id_ed25519
# LITESTREAM_SFTP_PATH: ${{ secrets.LITESTREAM_SFTP_PATH }}
- name: Run long-running test
run: go test -v ./integration -long-running-duration 1m

View File

@@ -16,7 +16,7 @@ func TestDatabasesCommand(t *testing.T) {
m, _, stdout, _ := newMain()
if err := m.Run(context.Background(), []string{"databases", "-config", filepath.Join(testDir, "litestream.yml")}); err != nil {
t.Fatal(err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
}
})
@@ -26,7 +26,7 @@ func TestDatabasesCommand(t *testing.T) {
m, _, stdout, _ := newMain()
if err := m.Run(context.Background(), []string{"databases", "-config", filepath.Join(testDir, "litestream.yml")}); err != nil {
t.Fatal(err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
}
})

View File

@@ -74,7 +74,7 @@ func (c *GenerationsCommand) Run(ctx context.Context, args []string) (ret error)
fmt.Fprintln(w, "name\tgeneration\tlag\tstart\tend")
for _, r := range replicas {
generations, err := r.Client.Generations(ctx)
generations, err := r.Client().Generations(ctx)
if err != nil {
fmt.Fprintf(c.stderr, "%s: cannot list generations: %s", r.Name(), err)
ret = errExit // signal error return without printing message
@@ -83,7 +83,7 @@ func (c *GenerationsCommand) Run(ctx context.Context, args []string) (ret error)
// Iterate over each generation for the replica.
for _, generation := range generations {
createdAt, updatedAt, err := litestream.GenerationTimeBounds(ctx, r.Client, generation)
createdAt, updatedAt, err := litestream.GenerationTimeBounds(ctx, r.Client(), generation)
if err != nil {
fmt.Fprintf(c.stderr, "%s: cannot determine generation time bounds: %s", r.Name(), err)
ret = errExit // signal error return without printing message

View File

@@ -18,7 +18,7 @@ func TestGenerationsCommand(t *testing.T) {
m, _, stdout, _ := newMain()
if err := m.Run(context.Background(), []string{"generations", "-config", filepath.Join(testDir, "litestream.yml"), filepath.Join(testDir, "db")}); err != nil {
t.Fatal(err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
}
})
@@ -30,7 +30,7 @@ func TestGenerationsCommand(t *testing.T) {
m, _, stdout, _ := newMain()
if err := m.Run(context.Background(), []string{"generations", "-config", filepath.Join(testDir, "litestream.yml"), "-replica", "replica1", filepath.Join(testDir, "db")}); err != nil {
t.Fatal(err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
}
})
@@ -43,7 +43,7 @@ func TestGenerationsCommand(t *testing.T) {
m, _, stdout, _ := newMain()
if err := m.Run(context.Background(), []string{"generations", replicaURL}); err != nil {
t.Fatal(err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
}
})
@@ -55,7 +55,7 @@ func TestGenerationsCommand(t *testing.T) {
m, _, stdout, _ := newMain()
if err := m.Run(context.Background(), []string{"generations", "-config", filepath.Join(testDir, "litestream.yml"), filepath.Join(testDir, "db")}); err != nil {
t.Fatal(err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
}
})

View File

@@ -38,6 +38,7 @@ var errExit = errors.New("exit")
func main() {
log.SetFlags(0)
log.SetOutput(os.Stdout)
m := NewMain(os.Stdin, os.Stdout, os.Stderr)
if err := m.Run(context.Background(), os.Args[1:]); err == flag.ErrHelp || err == errExit {
@@ -354,8 +355,35 @@ func NewReplicaFromConfig(c *ReplicaConfig, db *litestream.DB) (_ *litestream.Re
return nil, fmt.Errorf("replica path cannot be a url, please use the 'url' field instead: %s", c.Path)
}
// Build and set client on replica.
var client litestream.ReplicaClient
switch typ := c.ReplicaType(); typ {
case "file":
if client, err = newFileReplicaClientFromConfig(c); err != nil {
return nil, err
}
case "s3":
if client, err = newS3ReplicaClientFromConfig(c); err != nil {
return nil, err
}
case "gcs":
if client, err = newGCSReplicaClientFromConfig(c); err != nil {
return nil, err
}
case "abs":
if client, err = newABSReplicaClientFromConfig(c); err != nil {
return nil, err
}
case "sftp":
if client, err = newSFTPReplicaClientFromConfig(c); err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("unknown replica type in config: %q", typ)
}
// Build replica.
r := litestream.NewReplica(db, c.Name)
r := litestream.NewReplica(db, c.Name, client)
if v := c.Retention; v != nil {
r.Retention = *v
}
@@ -372,37 +400,11 @@ func NewReplicaFromConfig(c *ReplicaConfig, db *litestream.DB) (_ *litestream.Re
r.ValidationInterval = *v
}
// Build and set client on replica.
switch typ := c.ReplicaType(); typ {
case "file":
if r.Client, err = newFileReplicaClientFromConfig(c, r); err != nil {
return nil, err
}
case "s3":
if r.Client, err = newS3ReplicaClientFromConfig(c, r); err != nil {
return nil, err
}
case "gcs":
if r.Client, err = newGCSReplicaClientFromConfig(c, r); err != nil {
return nil, err
}
case "abs":
if r.Client, err = newABSReplicaClientFromConfig(c, r); err != nil {
return nil, err
}
case "sftp":
if r.Client, err = newSFTPReplicaClientFromConfig(c, r); err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("unknown replica type in config: %q", typ)
}
return r, nil
}
// newFileReplicaClientFromConfig returns a new instance of FileReplicaClient built from config.
func newFileReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *litestream.FileReplicaClient, err error) {
func newFileReplicaClientFromConfig(c *ReplicaConfig) (_ *litestream.FileReplicaClient, err error) {
// Ensure URL & path are not both specified.
if c.URL != "" && c.Path != "" {
return nil, fmt.Errorf("cannot specify url & path for file replica")
@@ -431,7 +433,7 @@ func newFileReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_
}
// newS3ReplicaClientFromConfig returns a new instance of s3.ReplicaClient built from config.
func newS3ReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *s3.ReplicaClient, err error) {
func newS3ReplicaClientFromConfig(c *ReplicaConfig) (_ *s3.ReplicaClient, err error) {
// Ensure URL & constituent parts are not both specified.
if c.URL != "" && c.Path != "" {
return nil, fmt.Errorf("cannot specify url & path for s3 replica")
@@ -494,7 +496,7 @@ func newS3ReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *s
}
// newGCSReplicaClientFromConfig returns a new instance of gcs.ReplicaClient built from config.
func newGCSReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *gcs.ReplicaClient, err error) {
func newGCSReplicaClientFromConfig(c *ReplicaConfig) (_ *gcs.ReplicaClient, err error) {
// Ensure URL & constituent parts are not both specified.
if c.URL != "" && c.Path != "" {
return nil, fmt.Errorf("cannot specify url & path for gcs replica")
@@ -533,7 +535,7 @@ func newGCSReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *
}
// newABSReplicaClientFromConfig returns a new instance of abs.ReplicaClient built from config.
func newABSReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *abs.ReplicaClient, err error) {
func newABSReplicaClientFromConfig(c *ReplicaConfig) (_ *abs.ReplicaClient, err error) {
// Ensure URL & constituent parts are not both specified.
if c.URL != "" && c.Path != "" {
return nil, fmt.Errorf("cannot specify url & path for abs replica")
@@ -576,7 +578,7 @@ func newABSReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *
}
// newSFTPReplicaClientFromConfig returns a new instance of sftp.ReplicaClient built from config.
func newSFTPReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *sftp.ReplicaClient, err error) {
func newSFTPReplicaClientFromConfig(c *ReplicaConfig) (_ *sftp.ReplicaClient, err error) {
// Ensure URL & constituent parts are not both specified.
if c.URL != "" && c.Path != "" {
return nil, fmt.Errorf("cannot specify url & path for sftp replica")

View File

@@ -104,7 +104,7 @@ func TestNewFileReplicaFromConfig(t *testing.T) {
r, err := main.NewReplicaFromConfig(&main.ReplicaConfig{Path: "/foo"}, nil)
if err != nil {
t.Fatal(err)
} else if client, ok := r.Client.(*litestream.FileReplicaClient); !ok {
} else if client, ok := r.Client().(*litestream.FileReplicaClient); !ok {
t.Fatal("unexpected replica type")
} else if got, want := client.Path(), "/foo"; got != want {
t.Fatalf("Path=%s, want %s", got, want)
@@ -116,7 +116,7 @@ func TestNewS3ReplicaFromConfig(t *testing.T) {
r, err := main.NewReplicaFromConfig(&main.ReplicaConfig{URL: "s3://foo/bar"}, nil)
if err != nil {
t.Fatal(err)
} else if client, ok := r.Client.(*s3.ReplicaClient); !ok {
} else if client, ok := r.Client().(*s3.ReplicaClient); !ok {
t.Fatal("unexpected replica type")
} else if got, want := client.Bucket, "foo"; got != want {
t.Fatalf("Bucket=%s, want %s", got, want)
@@ -135,7 +135,7 @@ func TestNewS3ReplicaFromConfig(t *testing.T) {
r, err := main.NewReplicaFromConfig(&main.ReplicaConfig{URL: "s3://foo.localhost:9000/bar"}, nil)
if err != nil {
t.Fatal(err)
} else if client, ok := r.Client.(*s3.ReplicaClient); !ok {
} else if client, ok := r.Client().(*s3.ReplicaClient); !ok {
t.Fatal("unexpected replica type")
} else if got, want := client.Bucket, "foo"; got != want {
t.Fatalf("Bucket=%s, want %s", got, want)
@@ -154,7 +154,7 @@ func TestNewS3ReplicaFromConfig(t *testing.T) {
r, err := main.NewReplicaFromConfig(&main.ReplicaConfig{URL: "s3://foo.s3.us-west-000.backblazeb2.com/bar"}, nil)
if err != nil {
t.Fatal(err)
} else if client, ok := r.Client.(*s3.ReplicaClient); !ok {
} else if client, ok := r.Client().(*s3.ReplicaClient); !ok {
t.Fatal("unexpected replica type")
} else if got, want := client.Bucket, "foo"; got != want {
t.Fatalf("Bucket=%s, want %s", got, want)
@@ -174,7 +174,7 @@ func TestNewGCSReplicaFromConfig(t *testing.T) {
r, err := main.NewReplicaFromConfig(&main.ReplicaConfig{URL: "gcs://foo/bar"}, nil)
if err != nil {
t.Fatal(err)
} else if client, ok := r.Client.(*gcs.ReplicaClient); !ok {
} else if client, ok := r.Client().(*gcs.ReplicaClient); !ok {
t.Fatal("unexpected replica type")
} else if got, want := client.Bucket, "foo"; got != want {
t.Fatalf("Bucket=%s, want %s", got, want)

View File

@@ -37,7 +37,7 @@ func runWindowsService(ctx context.Context) error {
// Set eventlog as log writer while running.
log.SetOutput((*eventlogWriter)(elog))
defer log.SetOutput(os.Stderr)
defer log.SetOutput(os.Stdout)
log.Print("Litestream service starting")

View File

@@ -121,7 +121,7 @@ func (c *ReplicateCommand) Run(ctx context.Context) (err error) {
for _, db := range c.DBs {
log.Printf("initialized db: %s", db.Path())
for _, r := range db.Replicas {
switch client := r.Client.(type) {
switch client := r.Client().(type) {
case *litestream.FileReplicaClient:
log.Printf("replicating to: name=%q type=%q path=%q", r.Name(), client.Type(), client.Path())
case *s3.ReplicaClient:
@@ -173,6 +173,8 @@ func (c *ReplicateCommand) Run(ctx context.Context) (err error) {
go func() { c.execCh <- c.cmd.Wait() }()
}
log.Printf("litestream initialization complete")
return nil
}

View File

@@ -104,7 +104,7 @@ func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) {
// Determine latest generation if one is not specified.
if c.generation == "" {
if c.generation, err = litestream.FindLatestGeneration(ctx, r.Client); err == litestream.ErrNoGeneration {
if c.generation, err = litestream.FindLatestGeneration(ctx, r.Client()); err == litestream.ErrNoGeneration {
// Return an error if no matching targets found.
// If optional flag set, return success. Useful for automated recovery.
if c.ifReplicaExists {
@@ -119,14 +119,14 @@ func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) {
// Determine the maximum available index for the generation if one is not specified.
if c.targetIndex == -1 {
if c.targetIndex, err = litestream.FindMaxIndexByGeneration(ctx, r.Client, c.generation); err != nil {
if c.targetIndex, err = litestream.FindMaxIndexByGeneration(ctx, r.Client(), c.generation); err != nil {
return fmt.Errorf("cannot determine latest index in generation %q: %w", c.generation, err)
}
}
// Find latest snapshot that occurs before the index.
// TODO: Optionally allow -snapshot-index
if c.snapshotIndex, err = litestream.FindSnapshotForIndex(ctx, r.Client, c.generation, c.targetIndex); err != nil {
if c.snapshotIndex, err = litestream.FindSnapshotForIndex(ctx, r.Client(), c.generation, c.targetIndex); err != nil {
return fmt.Errorf("cannot find snapshot index: %w", err)
}
@@ -137,7 +137,7 @@ func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) {
c.opt.Logger = log.New(c.stdout, "", log.LstdFlags|log.Lmicroseconds)
return litestream.Restore(ctx, r.Client, c.outputPath, c.generation, c.snapshotIndex, c.targetIndex, c.opt)
return litestream.Restore(ctx, r.Client(), c.outputPath, c.generation, c.snapshotIndex, c.targetIndex, c.opt)
}
func (c *RestoreCommand) loadReplica(ctx context.Context, config Config, arg string) (*litestream.Replica, error) {

View File

@@ -120,7 +120,7 @@ func TestRestoreCommand(t *testing.T) {
err := m.Run(context.Background(), []string{"restore", "-config", filepath.Join(testDir, "litestream.yml"), "-if-db-not-exists", filepath.Join(testDir, "db")})
if err != nil {
t.Fatal(err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
}
})
@@ -133,7 +133,7 @@ func TestRestoreCommand(t *testing.T) {
err := m.Run(context.Background(), []string{"restore", "-config", filepath.Join(testDir, "litestream.yml"), "-if-replica-exists", filepath.Join(testDir, "db")})
if err != nil {
t.Fatal(err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
}
})
@@ -147,9 +147,9 @@ func TestRestoreCommand(t *testing.T) {
err := m.Run(context.Background(), []string{"restore", "-config", filepath.Join(testDir, "litestream.yml"), "-o", filepath.Join(tempDir, "db"), filepath.Join(testDir, "db")})
if err == nil || err.Error() != `no matching backups found` {
t.Fatalf("unexpected error: %s", err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
} else if got, want := stderr.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stderr"))); got != want {
} else if got, want := stderr.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stderr"))); got != want {
t.Fatalf("stderr=%q, want %q", got, want)
}
})

View File

@@ -18,7 +18,7 @@ func TestSnapshotsCommand(t *testing.T) {
m, _, stdout, _ := newMain()
if err := m.Run(context.Background(), []string{"snapshots", "-config", filepath.Join(testDir, "litestream.yml"), filepath.Join(testDir, "db")}); err != nil {
t.Fatal(err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
}
})
@@ -30,7 +30,7 @@ func TestSnapshotsCommand(t *testing.T) {
m, _, stdout, _ := newMain()
if err := m.Run(context.Background(), []string{"snapshots", "-config", filepath.Join(testDir, "litestream.yml"), "-replica", "replica1", filepath.Join(testDir, "db")}); err != nil {
t.Fatal(err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
}
})
@@ -43,7 +43,7 @@ func TestSnapshotsCommand(t *testing.T) {
m, _, stdout, _ := newMain()
if err := m.Run(context.Background(), []string{"snapshots", replicaURL}); err != nil {
t.Fatal(err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
}
})

View File

@@ -69,7 +69,7 @@ func (c *WALCommand) Run(ctx context.Context, args []string) (ret error) {
if c.generation != "" {
generations = []string{c.generation}
} else {
if generations, err = r.Client.Generations(ctx); err != nil {
if generations, err = r.Client().Generations(ctx); err != nil {
log.Printf("%s: cannot determine generations: %s", r.Name(), err)
ret = errExit // signal error return without printing message
continue
@@ -78,7 +78,7 @@ func (c *WALCommand) Run(ctx context.Context, args []string) (ret error) {
for _, generation := range generations {
if err := func() error {
itr, err := r.Client.WALSegments(ctx, generation)
itr, err := r.Client().WALSegments(ctx, generation)
if err != nil {
return err
}

View File

@@ -18,7 +18,7 @@ func TestWALCommand(t *testing.T) {
m, _, stdout, _ := newMain()
if err := m.Run(context.Background(), []string{"wal", "-config", filepath.Join(testDir, "litestream.yml"), filepath.Join(testDir, "db")}); err != nil {
t.Fatal(err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
}
})
@@ -30,7 +30,7 @@ func TestWALCommand(t *testing.T) {
m, _, stdout, _ := newMain()
if err := m.Run(context.Background(), []string{"wal", "-config", filepath.Join(testDir, "litestream.yml"), "-replica", "replica1", filepath.Join(testDir, "db")}); err != nil {
t.Fatal(err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
}
})
@@ -43,7 +43,7 @@ func TestWALCommand(t *testing.T) {
m, _, stdout, _ := newMain()
if err := m.Run(context.Background(), []string{"wal", replicaURL}); err != nil {
t.Fatal(err)
} else if got, want := stdout.String(), string(testingutil.MustReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
} else if got, want := stdout.String(), string(testingutil.ReadFile(t, filepath.Join(testDir, "stdout"))); got != want {
t.Fatalf("stdout=%q, want %q", got, want)
}
})

db.go (25 changed lines)
View File

@@ -121,7 +121,7 @@ func NewDB(path string) *DB {
CheckpointInterval: DefaultCheckpointInterval,
MonitorInterval: DefaultMonitorInterval,
Logger: log.New(LogWriter, fmt.Sprintf("%s: ", path), LogFlags),
Logger: log.New(LogWriter, fmt.Sprintf("%s: ", logPrefixPath(path)), LogFlags),
}
db.dbSizeGauge = dbSizeGaugeVec.WithLabelValues(db.path)
@@ -300,7 +300,7 @@ func (db *DB) invalidateChecksum(ctx context.Context) error {
r := &io.LimitedReader{R: rc, N: db.pos.Offset}
// Determine cache values from the current WAL file.
db.salt0, db.salt1, db.chksum0, db.chksum1, db.byteOrder, db.frame, err = ReadWALFields(r, db.pageSize)
db.salt0, db.salt1, db.chksum0, db.chksum1, db.byteOrder, db.hdr, db.frame, err = ReadWALFields(r, db.pageSize)
if err != nil {
return fmt.Errorf("calc checksum: %w", err)
}
@@ -1621,11 +1621,11 @@ var walPathRegex = regexp.MustCompile(`^([0-9a-f]{8})\.wal$`)
// Returns salt, checksum, byte order, the raw header, & the last frame. WAL data must start
// from the beginning of the WAL header and must end on either the WAL header
// or at the end of a WAL frame.
func ReadWALFields(r io.Reader, pageSize int) (salt0, salt1, chksum0, chksum1 uint32, byteOrder binary.ByteOrder, frame []byte, err error) {
func ReadWALFields(r io.Reader, pageSize int) (salt0, salt1, chksum0, chksum1 uint32, byteOrder binary.ByteOrder, hdr, frame []byte, err error) {
// Read header.
hdr := make([]byte, WALHeaderSize)
hdr = make([]byte, WALHeaderSize)
if _, err := io.ReadFull(r, hdr); err != nil {
return 0, 0, 0, 0, nil, nil, fmt.Errorf("short wal header: %w", err)
return 0, 0, 0, 0, nil, nil, nil, fmt.Errorf("short wal header: %w", err)
}
// Save salt, initial checksum, & byte order.
@@ -1634,7 +1634,7 @@ func ReadWALFields(r io.Reader, pageSize int) (salt0, salt1, chksum0, chksum1 ui
chksum0 = binary.BigEndian.Uint32(hdr[24:])
chksum1 = binary.BigEndian.Uint32(hdr[28:])
if byteOrder, err = headerByteOrder(hdr); err != nil {
return 0, 0, 0, 0, nil, nil, err
return 0, 0, 0, 0, nil, nil, nil, err
}
// Iterate over each page in the WAL and save the checksum.
@@ -1645,7 +1645,7 @@ func ReadWALFields(r io.Reader, pageSize int) (salt0, salt1, chksum0, chksum1 ui
if n, err := io.ReadFull(r, frame); err == io.EOF {
break // end of WAL file
} else if err != nil {
return 0, 0, 0, 0, nil, nil, fmt.Errorf("short wal frame (n=%d): %w", n, err)
return 0, 0, 0, 0, nil, nil, nil, fmt.Errorf("short wal frame (n=%d): %w", n, err)
}
// Update checksum on each successful frame.
@@ -1659,7 +1659,7 @@ func ReadWALFields(r io.Reader, pageSize int) (salt0, salt1, chksum0, chksum1 ui
frame = nil
}
return salt0, salt1, chksum0, chksum1, byteOrder, frame, nil
return salt0, salt1, chksum0, chksum1, byteOrder, hdr, frame, nil
}
// Database metrics.
@@ -1731,3 +1731,12 @@ func headerByteOrder(hdr []byte) (binary.ByteOrder, error) {
return nil, fmt.Errorf("invalid wal header magic: %x", magic)
}
}
// logPrefixPath returns the path to be used for logging.
// The path is reduced to its base if it appears to be a temporary test path.
func logPrefixPath(path string) string {
if strings.Contains(path, "TestCmd") {
return filepath.Base(path)
}
return path
}
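ReadWALFields now returns the raw WAL header alongside the salts, checksums, byte order, and last frame, which lets callers such as invalidateChecksum cache it (db.hdr above). A minimal sketch of the widened eight-value call shape, assuming a WAL file on disk and a 4096-byte page size (the path is hypothetical):

package example

import (
	"fmt"
	"os"

	"github.com/benbjohnson/litestream"
)

func readWAL(path string) error {
	f, err := os.Open(path) // e.g. "db-wal"
	if err != nil {
		return err
	}
	defer f.Close()

	// Existing callers gain one assignment target for hdr; the tests
	// below simply discard it with a blank identifier.
	salt0, salt1, chksum0, chksum1, byteOrder, hdr, frame, err := litestream.ReadWALFields(f, 4096)
	if err != nil {
		return err
	}
	fmt.Println(salt0, salt1, chksum0, chksum1, byteOrder, len(hdr), len(frame))
	return nil
}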

View File

@@ -482,7 +482,7 @@ func TestReadWALFields(t *testing.T) {
}
t.Run("OK", func(t *testing.T) {
if salt0, salt1, chksum0, chksum1, byteOrder, frame, err := litestream.ReadWALFields(bytes.NewReader(b), 4096); err != nil {
if salt0, salt1, chksum0, chksum1, byteOrder, _, frame, err := litestream.ReadWALFields(bytes.NewReader(b), 4096); err != nil {
t.Fatal(err)
} else if got, want := salt0, uint32(0x4F7598FD); got != want {
t.Fatalf("salt0=%x, want %x", got, want)
@@ -500,7 +500,7 @@ func TestReadWALFields(t *testing.T) {
})
t.Run("HeaderOnly", func(t *testing.T) {
if salt0, salt1, chksum0, chksum1, byteOrder, frame, err := litestream.ReadWALFields(bytes.NewReader(b[:32]), 4096); err != nil {
if salt0, salt1, chksum0, chksum1, byteOrder, _, frame, err := litestream.ReadWALFields(bytes.NewReader(b[:32]), 4096); err != nil {
t.Fatal(err)
} else if got, want := salt0, uint32(0x4F7598FD); got != want {
t.Fatalf("salt0=%x, want %x", got, want)
@@ -518,19 +518,19 @@ func TestReadWALFields(t *testing.T) {
})
t.Run("ErrShortHeader", func(t *testing.T) {
if _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader([]byte{}), 4096); err == nil || err.Error() != `short wal header: EOF` {
if _, _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader([]byte{}), 4096); err == nil || err.Error() != `short wal header: EOF` {
t.Fatal(err)
}
})
t.Run("ErrBadMagic", func(t *testing.T) {
if _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader(make([]byte, 32)), 4096); err == nil || err.Error() != `invalid wal header magic: 0` {
if _, _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader(make([]byte, 32)), 4096); err == nil || err.Error() != `invalid wal header magic: 0` {
t.Fatal(err)
}
})
t.Run("ErrShortFrame", func(t *testing.T) {
if _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader(b[:100]), 4096); err == nil || err.Error() != `short wal frame (n=68): unexpected EOF` {
if _, _, _, _, _, _, _, err := litestream.ReadWALFields(bytes.NewReader(b[:100]), 4096); err == nil || err.Error() != `short wal frame (n=68): unexpected EOF` {
t.Fatal(err)
}
})

integration/cmd_test.go (new file, 411 lines)
View File

@@ -0,0 +1,411 @@
package integration_test
import (
"bytes"
"context"
"database/sql"
"flag"
"fmt"
"io"
"math/rand"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"time"
"github.com/benbjohnson/litestream/internal"
"github.com/benbjohnson/litestream/internal/testingutil"
_ "github.com/mattn/go-sqlite3"
)
var longRunningDuration = flag.Duration("long-running-duration", 0, "")
func init() {
fmt.Fprintln(os.Stderr, "#")
fmt.Fprintln(os.Stderr, "# NOTE: Build the litestream binary into your PATH before running integration tests")
fmt.Fprintln(os.Stderr, "#")
fmt.Fprintln(os.Stderr, "")
}
// Ensure the default configuration works with light database load.
func TestCmd_Replicate_OK(t *testing.T) {
ctx := context.Background()
testDir, tempDir := filepath.Join("testdata", "replicate", "ok"), t.TempDir()
env := []string{"LITESTREAM_TEMPDIR=" + tempDir}
cmd, stdout, _ := commandContext(ctx, env, "replicate", "-config", filepath.Join(testDir, "litestream.yml"))
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
db, err := sql.Open("sqlite3", filepath.Join(tempDir, "db"))
if err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil {
t.Fatal(err)
}
defer db.Close()
// Execute writes periodically.
for i := 0; i < 100; i++ {
t.Logf("[exec] INSERT INTO t (id) VALUES (%d)", i)
if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (?)`, i); err != nil {
t.Fatal(err)
}
time.Sleep(10 * time.Millisecond)
}
// Stop & wait for Litestream command.
killLitestreamCmd(t, cmd, stdout)
// Ensure signal and shutdown are logged.
if s := stdout.String(); !strings.Contains(s, `signal received, litestream shutting down`) {
t.Fatal("missing log output for signal received")
} else if s := stdout.String(); !strings.Contains(s, `litestream shut down`) {
t.Fatal("missing log output for shut down")
}
// Checkpoint & verify original SQLite database.
if _, err := db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil {
t.Fatal(err)
}
restoreAndVerify(t, ctx, env, filepath.Join(testDir, "litestream.yml"), filepath.Join(tempDir, "db"))
}
// Ensure that stopping and restarting Litestream before an application-induced
// checkpoint will cause Litestream to continue replicating using the same generation.
func TestCmd_Replicate_ResumeWithCurrentGeneration(t *testing.T) {
ctx := context.Background()
testDir, tempDir := filepath.Join("testdata", "replicate", "resume-with-current-generation"), t.TempDir()
env := []string{"LITESTREAM_TEMPDIR=" + tempDir}
cmd, stdout, _ := commandContext(ctx, env, "replicate", "-config", filepath.Join(testDir, "litestream.yml"))
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
t.Log("writing to database during replication")
db, err := sql.Open("sqlite3", filepath.Join(tempDir, "db"))
if err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil {
t.Fatal(err)
}
defer db.Close()
// Execute a few writes to populate the WAL.
if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (1)`); err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (2)`); err != nil {
t.Fatal(err)
}
// Wait for replication to occur & shutdown.
waitForLogMessage(t, stdout, `wal segment written`)
killLitestreamCmd(t, cmd, stdout)
t.Log("replication shutdown, continuing database writes")
// Execute a few more writes while replication is stopped.
if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (3)`); err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (4)`); err != nil {
t.Fatal(err)
}
t.Log("restarting replication")
cmd, stdout, _ = commandContext(ctx, env, "replicate", "-config", filepath.Join(testDir, "litestream.yml"))
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
waitForLogMessage(t, stdout, `wal segment written`)
killLitestreamCmd(t, cmd, stdout)
t.Log("replication shutdown again")
// Litestream should resume replication from the previous generation.
if s := stdout.String(); strings.Contains(s, "no generation exists") {
t.Fatal("expected existing generation to resume; started new generation instead")
}
// Checkpoint & verify original SQLite database.
if _, err := db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil {
t.Fatal(err)
}
restoreAndVerify(t, ctx, env, filepath.Join(testDir, "litestream.yml"), filepath.Join(tempDir, "db"))
}
// Ensure that restarting Litestream after a full checkpoint has occurred will
// cause it to begin a new generation.
func TestCmd_Replicate_ResumeWithNewGeneration(t *testing.T) {
ctx := context.Background()
testDir, tempDir := filepath.Join("testdata", "replicate", "resume-with-new-generation"), t.TempDir()
env := []string{"LITESTREAM_TEMPDIR=" + tempDir}
cmd, stdout, _ := commandContext(ctx, env, "replicate", "-config", filepath.Join(testDir, "litestream.yml"))
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
t.Log("writing to database during replication")
db, err := sql.Open("sqlite3", filepath.Join(tempDir, "db"))
if err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil {
t.Fatal(err)
}
defer db.Close()
// Execute a few writes to populate the WAL.
if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (1)`); err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (2)`); err != nil {
t.Fatal(err)
}
// Wait for replication to occur & shutdown.
waitForLogMessage(t, stdout, `wal segment written`)
killLitestreamCmd(t, cmd, stdout)
t.Log("replication shutdown, continuing database writes")
// Execute a few more writes while replication is stopped.
if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (3)`); err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (4)`); err != nil {
t.Fatal(err)
}
t.Log("issuing checkpoint")
// Issue a checkpoint to restart WAL.
if _, err := db.ExecContext(ctx, `PRAGMA wal_checkpoint(RESTART)`); err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (5)`); err != nil {
t.Fatal(err)
}
t.Log("restarting replication")
cmd, stdout, _ = commandContext(ctx, env, "replicate", "-config", filepath.Join(testDir, "litestream.yml"))
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
waitForLogMessage(t, stdout, `wal segment written`)
killLitestreamCmd(t, cmd, stdout)
t.Log("replication shutdown again")
// Litestream should start a new generation after the full checkpoint.
if s := stdout.String(); !strings.Contains(s, "no generation exists") {
t.Fatal("expected new generation to start; continued existing generation instead")
}
// Checkpoint & verify original SQLite database.
if _, err := db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil {
t.Fatal(err)
}
restoreAndVerify(t, ctx, env, filepath.Join(testDir, "litestream.yml"), filepath.Join(tempDir, "db"))
}
// Ensure the default configuration works with heavy write load.
func TestCmd_Replicate_HighLoad(t *testing.T) {
if testing.Short() {
t.Skip("short mode enabled, skipping")
}
const writeDuration = 30 * time.Second
ctx := context.Background()
testDir, tempDir := filepath.Join("testdata", "replicate", "high-load"), t.TempDir()
env := []string{"LITESTREAM_TEMPDIR=" + tempDir}
cmd, stdout, _ := commandContext(ctx, env, "replicate", "-config", filepath.Join(testDir, "litestream.yml"))
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
db, err := sql.Open("sqlite3", filepath.Join(tempDir, "db"))
if err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `PRAGMA journal_mode = WAL`); err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `PRAGMA synchronous = NORMAL`); err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `PRAGMA wal_autocheckpoint = 0`); err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil {
t.Fatal(err)
}
defer db.Close()
// Execute writes as fast as possible for a period of time.
timer := time.NewTimer(writeDuration)
defer timer.Stop()
t.Logf("executing writes for %s", writeDuration)
LOOP:
for i := 0; ; i++ {
select {
case <-timer.C:
break LOOP
default:
if i%1000 == 0 {
t.Logf("[exec] INSERT INTO t (id) VALUES (%d)", i)
}
if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (?)`, i); err != nil {
t.Fatal(err)
}
}
}
t.Logf("writes complete, shutting down")
// Stop & wait for Litestream command.
time.Sleep(5 * time.Second)
killLitestreamCmd(t, cmd, stdout)
// Checkpoint & verify original SQLite database.
if _, err := db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil {
t.Fatal(err)
}
restoreAndVerify(t, ctx, env, filepath.Join(testDir, "litestream.yml"), filepath.Join(tempDir, "db"))
}
// Ensure replication works for an extended period.
func TestCmd_Replicate_LongRunning(t *testing.T) {
if *longRunningDuration == 0 {
t.Skip("long running test duration not specified, skipping")
}
ctx := context.Background()
testDir, tempDir := filepath.Join("testdata", "replicate", "long-running"), t.TempDir()
env := []string{"LITESTREAM_TEMPDIR=" + tempDir}
cmd, stdout, _ := commandContext(ctx, env, "replicate", "-config", filepath.Join(testDir, "litestream.yml"))
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
db, err := sql.Open("sqlite3", filepath.Join(tempDir, "db"))
if err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `PRAGMA journal_mode = WAL`); err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `PRAGMA synchronous = NORMAL`); err != nil {
t.Fatal(err)
} else if _, err := db.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil {
t.Fatal(err)
}
defer db.Close()
// Execute writes as fast as possible for a period of time.
timer := time.NewTimer(*longRunningDuration)
defer timer.Stop()
t.Logf("executing writes for %s", longRunningDuration)
LOOP:
for i := 0; ; i++ {
select {
case <-timer.C:
break LOOP
default:
t.Logf("[exec] INSERT INTO t (id) VALUES (%d)", i)
if _, err := db.ExecContext(ctx, `INSERT INTO t (id) VALUES (?)`, i); err != nil {
t.Fatal(err)
}
time.Sleep(time.Duration(rand.Intn(int(time.Second))))
}
}
t.Logf("writes complete, shutting down")
// Stop & wait for Litestream command.
killLitestreamCmd(t, cmd, stdout)
// Checkpoint & verify original SQLite database.
if _, err := db.ExecContext(ctx, `PRAGMA wal_checkpoint(TRUNCATE)`); err != nil {
t.Fatal(err)
}
restoreAndVerify(t, ctx, env, filepath.Join(testDir, "litestream.yml"), filepath.Join(tempDir, "db"))
}
// commandContext returns a "litestream" command with stdout/stderr buffers.
func commandContext(ctx context.Context, env []string, arg ...string) (cmd *exec.Cmd, stdout, stderr *internal.LockingBuffer) {
cmd = exec.CommandContext(ctx, "litestream", arg...)
cmd.Env = env
var outBuf, errBuf internal.LockingBuffer
// Tee stdout/stderr to the terminal as well when the verbose flag is set.
cmd.Stdout, cmd.Stderr = &outBuf, &errBuf
if testing.Verbose() {
cmd.Stdout = io.MultiWriter(&outBuf, os.Stdout)
cmd.Stderr = io.MultiWriter(&errBuf, os.Stderr)
}
return cmd, &outBuf, &errBuf
}
// waitForLogMessage continuously checks b for a message and returns when it occurs.
func waitForLogMessage(tb testing.TB, b *internal.LockingBuffer, msg string) {
timer := time.NewTimer(30 * time.Second)
defer timer.Stop()
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-timer.C:
tb.Fatal("timed out waiting for cmd initialization")
case <-ticker.C:
if strings.Contains(b.String(), msg) {
return
}
}
}
}
// killLitestreamCmd interrupts the process and waits for a clean shutdown.
func killLitestreamCmd(tb testing.TB, cmd *exec.Cmd, stdout *internal.LockingBuffer) {
if err := cmd.Process.Signal(os.Interrupt); err != nil {
tb.Fatal(err)
} else if err := cmd.Wait(); err != nil {
tb.Fatal(err)
}
}
// restoreAndVerify executes a "restore" and compares the result byte-for-byte with the original database.
func restoreAndVerify(tb testing.TB, ctx context.Context, env []string, configPath, dbPath string) {
restorePath := filepath.Join(tb.TempDir(), "db")
// Restore database.
cmd, _, _ := commandContext(ctx, env, "restore", "-config", configPath, "-o", restorePath, dbPath)
if err := cmd.Run(); err != nil {
tb.Fatalf("error running 'restore' command: %s", err)
}
// Compare original database & restored database.
buf0 := testingutil.ReadFile(tb, dbPath)
buf1 := testingutil.ReadFile(tb, restorePath)
if bytes.Equal(buf0, buf1) {
return // ok, exit
}
// On mismatch, copy out original & restored DBs.
dir, err := os.MkdirTemp("", "litestream-*")
if err != nil {
tb.Fatal(err)
}
testingutil.CopyFile(tb, dbPath, filepath.Join(dir, "original.db"))
testingutil.CopyFile(tb, restorePath, filepath.Join(dir, "restored.db"))
tb.Fatalf("database mismatch; databases copied to %s", dir)
}

View File

@@ -0,0 +1,7 @@
dbs:
- path: $LITESTREAM_TEMPDIR/db
replicas:
- path: $LITESTREAM_TEMPDIR/replica
monitor-interval: 100ms
max-checkpoint-page-count: 20

View File

@@ -0,0 +1,4 @@
dbs:
- path: $LITESTREAM_TEMPDIR/db
replicas:
- path: $LITESTREAM_TEMPDIR/replica

View File

@@ -0,0 +1,7 @@
dbs:
- path: $LITESTREAM_TEMPDIR/db
replicas:
- path: $LITESTREAM_TEMPDIR/replica
monitor-interval: 100ms
max-checkpoint-page-count: 20

View File

@@ -0,0 +1,4 @@
dbs:
- path: $LITESTREAM_TEMPDIR/db
replicas:
- path: $LITESTREAM_TEMPDIR/replica

View File

@@ -0,0 +1,4 @@
dbs:
- path: $LITESTREAM_TEMPDIR/db
replicas:
- path: $LITESTREAM_TEMPDIR/replica

View File

@@ -0,0 +1,7 @@
dbs:
- path: $LITESTREAM_TEMPDIR/db
replicas:
- path: $LITESTREAM_TEMPDIR/replica
monitor-interval: 100ms
max-checkpoint-page-count: 10
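These fixture configs depend on environment-variable expansion in config paths: each test sets LITESTREAM_TEMPDIR in the child process environment so a single litestream.yml can target a fresh per-test directory. A sketch of the substitution the tests rely on, assuming shell-style $VAR expansion along the lines of os.ExpandEnv (the directory is hypothetical):

package example

import (
	"fmt"
	"os"
)

func main() {
	// The integration tests pass this through the command's environment.
	os.Setenv("LITESTREAM_TEMPDIR", "/tmp/TestCmd_Replicate_OK")
	fmt.Println(os.ExpandEnv("$LITESTREAM_TEMPDIR/db")) // /tmp/TestCmd_Replicate_OK/db
}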

internal/locking_buffer.go (new file, 145 lines)
View File

@@ -0,0 +1,145 @@
package internal
import (
"bytes"
"io"
"sync"
)
// LockingBuffer wraps a bytes.Buffer with a mutex.
type LockingBuffer struct {
mu sync.Mutex
b bytes.Buffer
}
func (b *LockingBuffer) Bytes() []byte {
b.mu.Lock()
defer b.mu.Unlock()
buf := b.b.Bytes()
other := make([]byte, len(buf))
copy(other, buf)
return other
}
func (b *LockingBuffer) Cap() int {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.Cap()
}
func (b *LockingBuffer) Grow(n int) {
b.mu.Lock()
defer b.mu.Unlock()
b.b.Grow(n)
}
func (b *LockingBuffer) Len() int {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.Len()
}
func (b *LockingBuffer) Next(n int) []byte {
b.mu.Lock()
defer b.mu.Unlock()
buf := b.b.Next(n)
other := make([]byte, len(buf))
copy(other, buf)
return other
}
func (b *LockingBuffer) Read(p []byte) (n int, err error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.Read(p)
}
func (b *LockingBuffer) ReadByte() (byte, error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.ReadByte()
}
func (b *LockingBuffer) ReadBytes(delim byte) (line []byte, err error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.ReadBytes(delim)
}
func (b *LockingBuffer) ReadFrom(r io.Reader) (n int64, err error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.ReadFrom(r)
}
func (b *LockingBuffer) ReadRune() (r rune, size int, err error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.ReadRune()
}
func (b *LockingBuffer) ReadString(delim byte) (line string, err error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.ReadString(delim)
}
func (b *LockingBuffer) Reset() {
b.mu.Lock()
defer b.mu.Unlock()
b.b.Reset()
}
func (b *LockingBuffer) String() string {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.String()
}
func (b *LockingBuffer) Truncate(n int) {
b.mu.Lock()
defer b.mu.Unlock()
b.b.Truncate(n)
}
func (b *LockingBuffer) UnreadByte() error {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.UnreadByte()
}
func (b *LockingBuffer) UnreadRune() error {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.UnreadRune()
}
func (b *LockingBuffer) Write(p []byte) (n int, err error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.Write(p)
}
func (b *LockingBuffer) WriteByte(c byte) error {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.WriteByte(c)
}
func (b *LockingBuffer) WriteRune(r rune) (n int, err error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.WriteRune(r)
}
func (b *LockingBuffer) WriteString(s string) (n int, err error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.WriteString(s)
}
func (b *LockingBuffer) WriteTo(w io.Writer) (n int64, err error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.b.WriteTo(w)
}
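LockingBuffer exists because the integration tests write to the buffer from the child process's stdout pipe while the test goroutine concurrently polls String(); a bare bytes.Buffer would race. A minimal usage sketch mirroring commandContext and waitForLogMessage above (the command and message are hypothetical):

package example

import (
	"os/exec"
	"strings"
	"time"

	"github.com/benbjohnson/litestream/internal"
)

func waitForOutput(cmd *exec.Cmd, msg string) bool {
	var buf internal.LockingBuffer
	cmd.Stdout = &buf // exec's copier goroutine writes here...
	if err := cmd.Start(); err != nil {
		return false
	}
	defer cmd.Process.Kill()

	deadline := time.Now().Add(30 * time.Second)
	for time.Now().Before(deadline) {
		if strings.Contains(buf.String(), msg) { // ...while the test reads here
			return true
		}
		time.Sleep(10 * time.Millisecond)
	}
	return false
}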

View File

@@ -1,12 +1,13 @@
package testingutil
import (
"io"
"os"
"testing"
)
// MustReadFile reads all data from filename. Fail on error.
func MustReadFile(tb testing.TB, filename string) []byte {
// ReadFile reads all data from filename. Fail on error.
func ReadFile(tb testing.TB, filename string) []byte {
tb.Helper()
b, err := os.ReadFile(filename)
if err != nil {
@@ -15,6 +16,26 @@ func MustReadFile(tb testing.TB, filename string) []byte {
return b
}
// CopyFile copies all data from src to dst. Fail on error.
func CopyFile(tb testing.TB, src, dst string) {
tb.Helper()
r, err := os.Open(src)
if err != nil {
tb.Fatal(err)
}
defer r.Close()
w, err := os.Create(dst)
if err != nil {
tb.Fatal(err)
}
defer w.Close()
if _, err := io.Copy(w, r); err != nil {
tb.Fatal(err)
}
}
// Getwd returns the working directory. Fail on error.
func Getwd(tb testing.TB) string {
tb.Helper()

View File

@@ -1,8 +1,10 @@
package litestream
import (
"crypto/md5"
"database/sql"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
@@ -43,7 +45,7 @@ var (
var (
// LogWriter is the destination writer for all logging.
LogWriter = os.Stderr
LogWriter = os.Stdout
// LogFlags are the flags passed to log.New().
LogFlags = 0
@@ -460,6 +462,12 @@ func isHexChar(ch rune) bool {
return (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f')
}
// md5hash returns a hex-encoded MD5 hash of b.
func md5hash(b []byte) string {
sum := md5.Sum(b)
return hex.EncodeToString(sum[:])
}
// Tracef is used for low-level tracing.
var Tracef = func(format string, a ...interface{}) {}

View File

@@ -43,7 +43,7 @@ type Replica struct {
cancel func()
// Client used to connect to the remote replica.
Client ReplicaClient
client ReplicaClient
// Time between syncs with the shadow WAL.
SyncInterval time.Duration
@@ -68,10 +68,11 @@ type Replica struct {
Logger *log.Logger
}
func NewReplica(db *DB, name string) *Replica {
func NewReplica(db *DB, name string, client ReplicaClient) *Replica {
r := &Replica{
db: db,
name: name,
client: client,
cancel: func() {},
SyncInterval: DefaultSyncInterval,
@@ -82,7 +83,7 @@ func NewReplica(db *DB, name string) *Replica {
prefix := fmt.Sprintf("%s: ", r.Name())
if db != nil {
prefix = fmt.Sprintf("%s(%s): ", db.Path(), r.Name())
prefix = fmt.Sprintf("%s(%s): ", logPrefixPath(db.Path()), r.Name())
}
r.Logger = log.New(LogWriter, prefix, LogFlags)
@@ -91,8 +92,8 @@ func NewReplica(db *DB, name string) *Replica {
// Name returns the name of the replica.
func (r *Replica) Name() string {
if r.name == "" && r.Client != nil {
return r.Client.Type()
if r.name == "" && r.client != nil {
return r.client.Type()
}
return r.name
}
@@ -100,6 +101,9 @@ func (r *Replica) Name() string {
// DB returns a reference to the database the replica is attached to, if any.
func (r *Replica) DB() *DB { return r.db }
// Client returns the client the replica was initialized with.
func (r *Replica) Client() ReplicaClient { return r.client }
// Starts replicating in a background goroutine.
func (r *Replica) Start(ctx context.Context) error {
// Ignore if replica is being used synchronously.
@@ -265,7 +269,7 @@ func (r *Replica) writeIndexSegments(ctx context.Context, segments []WALSegmentI
// Copy through pipe into client from the starting position.
var g errgroup.Group
g.Go(func() error {
_, err := r.Client.WriteWALSegment(ctx, initialPos, pr)
_, err := r.client.WriteWALSegment(ctx, initialPos, pr)
return err
})
@@ -332,7 +336,7 @@ func (r *Replica) writeIndexSegments(ctx context.Context, segments []WALSegmentI
// snapshotN returns the number of snapshots for a generation.
func (r *Replica) snapshotN(generation string) (int, error) {
itr, err := r.Client.Snapshots(context.Background(), generation)
itr, err := r.client.Snapshots(context.Background(), generation)
if err != nil {
return 0, err
}
@@ -364,7 +368,7 @@ func (r *Replica) calcPos(ctx context.Context, generation string) (pos Pos, err
}
// Read segment to determine size to add to offset.
rd, err := r.Client.WALSegmentReader(ctx, segment.Pos())
rd, err := r.client.WALSegmentReader(ctx, segment.Pos())
if err != nil {
return pos, fmt.Errorf("wal segment reader: %w", err)
}
@@ -385,7 +389,7 @@ func (r *Replica) calcPos(ctx context.Context, generation string) (pos Pos, err
// maxSnapshot returns the last snapshot in a generation.
func (r *Replica) maxSnapshot(ctx context.Context, generation string) (*SnapshotInfo, error) {
itr, err := r.Client.Snapshots(ctx, generation)
itr, err := r.client.Snapshots(ctx, generation)
if err != nil {
return nil, err
}
@@ -402,7 +406,7 @@ func (r *Replica) maxSnapshot(ctx context.Context, generation string) (*Snapshot
// maxWALSegment returns the highest WAL segment in a generation.
func (r *Replica) maxWALSegment(ctx context.Context, generation string) (*WALSegmentInfo, error) {
itr, err := r.Client.WALSegments(ctx, generation)
itr, err := r.client.WALSegments(ctx, generation)
if err != nil {
return nil, err
}
@@ -427,7 +431,7 @@ func (r *Replica) Pos() Pos {
// Snapshots returns a list of all snapshots across all generations.
func (r *Replica) Snapshots(ctx context.Context) ([]SnapshotInfo, error) {
generations, err := r.Client.Generations(ctx)
generations, err := r.client.Generations(ctx)
if err != nil {
return nil, fmt.Errorf("cannot fetch generations: %w", err)
}
@@ -435,7 +439,7 @@ func (r *Replica) Snapshots(ctx context.Context) ([]SnapshotInfo, error) {
var a []SnapshotInfo
for _, generation := range generations {
if err := func() error {
itr, err := r.Client.Snapshots(ctx, generation)
itr, err := r.client.Snapshots(ctx, generation)
if err != nil {
return err
}
@@ -518,7 +522,7 @@ func (r *Replica) Snapshot(ctx context.Context) (info SnapshotInfo, err error) {
})
// Delegate write to client & wait for writer goroutine to finish.
if info, err = r.Client.WriteSnapshot(ctx, pos.Generation, pos.Index, pr); err != nil {
if info, err = r.client.WriteSnapshot(ctx, pos.Generation, pos.Index, pr); err != nil {
return info, err
} else if err := g.Wait(); err != nil {
return info, err
@@ -549,7 +553,7 @@ func (r *Replica) EnforceRetention(ctx context.Context) (err error) {
}
// Loop over generations and delete unretained snapshots & WAL files.
generations, err := r.Client.Generations(ctx)
generations, err := r.client.Generations(ctx)
if err != nil {
return fmt.Errorf("generations: %w", err)
}
@@ -559,7 +563,7 @@ func (r *Replica) EnforceRetention(ctx context.Context) (err error) {
// Delete entire generation if no snapshots are being retained.
if snapshot == nil {
if err := r.Client.DeleteGeneration(ctx, generation); err != nil {
if err := r.client.DeleteGeneration(ctx, generation); err != nil {
return fmt.Errorf("delete generation: %w", err)
}
continue
@@ -577,7 +581,7 @@ func (r *Replica) EnforceRetention(ctx context.Context) (err error) {
}
func (r *Replica) deleteSnapshotsBeforeIndex(ctx context.Context, generation string, index int) error {
itr, err := r.Client.Snapshots(ctx, generation)
itr, err := r.client.Snapshots(ctx, generation)
if err != nil {
return fmt.Errorf("fetch snapshots: %w", err)
}
@@ -589,7 +593,7 @@ func (r *Replica) deleteSnapshotsBeforeIndex(ctx context.Context, generation str
continue
}
if err := r.Client.DeleteSnapshot(ctx, info.Generation, info.Index); err != nil {
if err := r.client.DeleteSnapshot(ctx, info.Generation, info.Index); err != nil {
return fmt.Errorf("delete snapshot %s/%08x: %w", info.Generation, info.Index, err)
}
r.Logger.Printf("snapshot deleted %s/%08x", generation, index)
@@ -599,7 +603,7 @@ func (r *Replica) deleteSnapshotsBeforeIndex(ctx context.Context, generation str
}
func (r *Replica) deleteWALSegmentsBeforeIndex(ctx context.Context, generation string, index int) error {
itr, err := r.Client.WALSegments(ctx, generation)
itr, err := r.client.WALSegments(ctx, generation)
if err != nil {
return fmt.Errorf("fetch wal segments: %w", err)
}
@@ -621,7 +625,7 @@ func (r *Replica) deleteWALSegmentsBeforeIndex(ctx context.Context, generation s
return nil
}
if err := r.Client.DeleteWALSegments(ctx, a); err != nil {
if err := r.client.DeleteWALSegments(ctx, a); err != nil {
return fmt.Errorf("delete wal segments: %w", err)
}
@@ -774,7 +778,7 @@ func (r *Replica) Validate(ctx context.Context) error {
}
// Find latest snapshot that occurs before the index.
snapshotIndex, err := FindSnapshotForIndex(ctx, r.Client, pos.Generation, pos.Index-1)
snapshotIndex, err := FindSnapshotForIndex(ctx, r.client, pos.Generation, pos.Index-1)
if err != nil {
return fmt.Errorf("cannot find snapshot index: %w", err)
}
@@ -784,7 +788,7 @@ func (r *Replica) Validate(ctx context.Context) error {
Logger: log.New(os.Stderr, "", 0),
LogPrefix: r.logPrefix(),
}
if err := Restore(ctx, r.Client, restorePath, pos.Generation, snapshotIndex, pos.Index-1, opt); err != nil {
if err := Restore(ctx, r.client, restorePath, pos.Generation, snapshotIndex, pos.Index-1, opt); err != nil {
return fmt.Errorf("cannot restore: %w", err)
}
@@ -880,7 +884,7 @@ func (r *Replica) waitForReplica(ctx context.Context, pos Pos) error {
func (r *Replica) GenerationCreatedAt(ctx context.Context, generation string) (time.Time, error) {
var min time.Time
itr, err := r.Client.Snapshots(ctx, generation)
itr, err := r.client.Snapshots(ctx, generation)
if err != nil {
return min, err
}
@@ -897,7 +901,7 @@ func (r *Replica) GenerationCreatedAt(ctx context.Context, generation string) (t
// SnapshotIndexAt returns the highest index for a snapshot within a generation
// that occurs before timestamp. If timestamp is zero, returns the latest snapshot.
func (r *Replica) SnapshotIndexAt(ctx context.Context, generation string, timestamp time.Time) (int, error) {
itr, err := r.Client.Snapshots(ctx, generation)
itr, err := r.client.Snapshots(ctx, generation)
if err != nil {
return 0, err
}
@@ -929,7 +933,7 @@ func LatestReplica(ctx context.Context, replicas []*Replica) (*Replica, error) {
var t time.Time
var r *Replica
for i := range replicas {
_, max, err := ReplicaClientTimeBounds(ctx, replicas[i].Client)
_, max, err := ReplicaClientTimeBounds(ctx, replicas[i].client)
if err != nil {
return nil, err
} else if r == nil || max.After(t) {

View File

@@ -14,13 +14,12 @@ import (
func TestReplica_Name(t *testing.T) {
t.Run("WithName", func(t *testing.T) {
if got, want := litestream.NewReplica(nil, "NAME").Name(), "NAME"; got != want {
if got, want := litestream.NewReplica(nil, "NAME", nil).Name(), "NAME"; got != want {
t.Fatalf("Name()=%v, want %v", got, want)
}
})
t.Run("WithoutName", func(t *testing.T) {
r := litestream.NewReplica(nil, "")
r.Client = &mock.ReplicaClient{}
r := litestream.NewReplica(nil, "", &mock.ReplicaClient{})
if got, want := r.Name(), "mock"; got != want {
t.Fatalf("Name()=%v, want %v", got, want)
}
@@ -45,8 +44,7 @@ func TestReplica_Sync(t *testing.T) {
dpos := db.Pos()
c := litestream.NewFileReplicaClient(t.TempDir())
r := litestream.NewReplica(db, "")
r.Client = c
r := litestream.NewReplica(db, "", c)
if err := r.Sync(context.Background()); err != nil {
t.Fatal(err)
@@ -81,8 +79,7 @@ func TestReplica_Snapshot(t *testing.T) {
defer MustCloseDBs(t, db, sqldb)
c := litestream.NewFileReplicaClient(t.TempDir())
r := litestream.NewReplica(db, "")
r.Client = c
r := litestream.NewReplica(db, "", c)
// Execute a query to force a write to the WAL.
if _, err := sqldb.Exec(`CREATE TABLE foo (bar TEXT);`); err != nil {