Refactor Restore()

This commit refactors out the complexity of downloading ordered WAL
files in parallel to a type called `WALDownloader`. This makes it
easier to test the restore separately from the download.
This commit is contained in:
Ben Johnson
2022-01-04 14:47:11 -07:00
parent 531e19ed6f
commit 3f0ec9fa9f
130 changed files with 2943 additions and 1254 deletions

View File

@@ -30,31 +30,31 @@ jobs:
LITESTREAM_SFTP_KEY: ${{secrets.LITESTREAM_SFTP_KEY}} LITESTREAM_SFTP_KEY: ${{secrets.LITESTREAM_SFTP_KEY}}
- name: Run unit tests - name: Run unit tests
run: go test -v ./... run: make testdata && go test -v ./...
- name: Run aws s3 tests - name: Run aws s3 tests
run: go test -v -run=TestReplicaClient . -integration s3 run: go test -v -run=TestReplicaClient ./integration -replica-type s3
env: env:
LITESTREAM_S3_ACCESS_KEY_ID: ${{ secrets.LITESTREAM_S3_ACCESS_KEY_ID }} LITESTREAM_S3_ACCESS_KEY_ID: ${{ secrets.LITESTREAM_S3_ACCESS_KEY_ID }}
LITESTREAM_S3_SECRET_ACCESS_KEY: ${{ secrets.LITESTREAM_S3_SECRET_ACCESS_KEY }} LITESTREAM_S3_SECRET_ACCESS_KEY: ${{ secrets.LITESTREAM_S3_SECRET_ACCESS_KEY }}
LITESTREAM_S3_REGION: ${{ secrets.LITESTREAM_S3_REGION }} LITESTREAM_S3_REGION: us-east-1
LITESTREAM_S3_BUCKET: ${{ secrets.LITESTREAM_S3_BUCKET }} LITESTREAM_S3_BUCKET: integration.litestream.io
- name: Run google cloud storage (gcs) tests - name: Run google cloud storage (gcs) tests
run: go test -v -run=TestReplicaClient . -integration gcs run: go test -v -run=TestReplicaClient ./integration -replica-type gcs
env: env:
GOOGLE_APPLICATION_CREDENTIALS: /opt/gcp.json GOOGLE_APPLICATION_CREDENTIALS: /opt/gcp.json
LITESTREAM_GCS_BUCKET: ${{ secrets.LITESTREAM_GCS_BUCKET }} LITESTREAM_GCS_BUCKET: integration.litestream.io
- name: Run azure blob storage (abs) tests - name: Run azure blob storage (abs) tests
run: go test -v -run=TestReplicaClient . -integration abs run: go test -v -run=TestReplicaClient ./integration -replica-type abs
env: env:
LITESTREAM_ABS_ACCOUNT_NAME: ${{ secrets.LITESTREAM_ABS_ACCOUNT_NAME }} LITESTREAM_ABS_ACCOUNT_NAME: ${{ secrets.LITESTREAM_ABS_ACCOUNT_NAME }}
LITESTREAM_ABS_ACCOUNT_KEY: ${{ secrets.LITESTREAM_ABS_ACCOUNT_KEY }} LITESTREAM_ABS_ACCOUNT_KEY: ${{ secrets.LITESTREAM_ABS_ACCOUNT_KEY }}
LITESTREAM_ABS_BUCKET: ${{ secrets.LITESTREAM_ABS_BUCKET }} LITESTREAM_ABS_BUCKET: integration
- name: Run sftp tests - name: Run sftp tests
run: go test -v -run=TestReplicaClient . -integration sftp run: go test -v -run=TestReplicaClient ./integration -replica-type sftp
env: env:
LITESTREAM_SFTP_HOST: ${{ secrets.LITESTREAM_SFTP_HOST }} LITESTREAM_SFTP_HOST: ${{ secrets.LITESTREAM_SFTP_HOST }}
LITESTREAM_SFTP_USER: ${{ secrets.LITESTREAM_SFTP_USER }} LITESTREAM_SFTP_USER: ${{ secrets.LITESTREAM_SFTP_USER }}

View File

@@ -1,4 +1,9 @@
default: .PHONY: default
default: testdata
.PHONY: testdata
testdata:
make -C testdata
docker: docker:
docker build -t litestream . docker build -t litestream .

View File

@@ -10,12 +10,15 @@ import (
) )
// DatabasesCommand is a command for listing managed databases. // DatabasesCommand is a command for listing managed databases.
type DatabasesCommand struct{} type DatabasesCommand struct {
configPath string
noExpandEnv bool
}
// Run executes the command. // Run executes the command.
func (c *DatabasesCommand) Run(ctx context.Context, args []string) (err error) { func (c *DatabasesCommand) Run(ctx context.Context, args []string) (err error) {
fs := flag.NewFlagSet("litestream-databases", flag.ContinueOnError) fs := flag.NewFlagSet("litestream-databases", flag.ContinueOnError)
configPath, noExpandEnv := registerConfigFlag(fs) registerConfigFlag(fs, &c.configPath, &c.noExpandEnv)
fs.Usage = c.Usage fs.Usage = c.Usage
if err := fs.Parse(args); err != nil { if err := fs.Parse(args); err != nil {
return err return err
@@ -24,10 +27,10 @@ func (c *DatabasesCommand) Run(ctx context.Context, args []string) (err error) {
} }
// Load configuration. // Load configuration.
if *configPath == "" { if c.configPath == "" {
*configPath = DefaultConfigPath() c.configPath = DefaultConfigPath()
} }
config, err := ReadConfigFile(*configPath, !*noExpandEnv) config, err := ReadConfigFile(c.configPath, !c.noExpandEnv)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -13,12 +13,15 @@ import (
) )
// GenerationsCommand represents a command to list all generations for a database. // GenerationsCommand represents a command to list all generations for a database.
type GenerationsCommand struct{} type GenerationsCommand struct {
configPath string
noExpandEnv bool
}
// Run executes the command. // Run executes the command.
func (c *GenerationsCommand) Run(ctx context.Context, args []string) (err error) { func (c *GenerationsCommand) Run(ctx context.Context, args []string) (err error) {
fs := flag.NewFlagSet("litestream-generations", flag.ContinueOnError) fs := flag.NewFlagSet("litestream-generations", flag.ContinueOnError)
configPath, noExpandEnv := registerConfigFlag(fs) registerConfigFlag(fs, &c.configPath, &c.noExpandEnv)
replicaName := fs.String("replica", "", "replica name") replicaName := fs.String("replica", "", "replica name")
fs.Usage = c.Usage fs.Usage = c.Usage
if err := fs.Parse(args); err != nil { if err := fs.Parse(args); err != nil {
@@ -33,19 +36,19 @@ func (c *GenerationsCommand) Run(ctx context.Context, args []string) (err error)
var r *litestream.Replica var r *litestream.Replica
dbUpdatedAt := time.Now() dbUpdatedAt := time.Now()
if isURL(fs.Arg(0)) { if isURL(fs.Arg(0)) {
if *configPath != "" { if c.configPath != "" {
return fmt.Errorf("cannot specify a replica URL and the -config flag") return fmt.Errorf("cannot specify a replica URL and the -config flag")
} }
if r, err = NewReplicaFromConfig(&ReplicaConfig{URL: fs.Arg(0)}, nil); err != nil { if r, err = NewReplicaFromConfig(&ReplicaConfig{URL: fs.Arg(0)}, nil); err != nil {
return err return err
} }
} else { } else {
if *configPath == "" { if c.configPath == "" {
*configPath = DefaultConfigPath() c.configPath = DefaultConfigPath()
} }
// Load configuration. // Load configuration.
config, err := ReadConfigFile(*configPath, !*noExpandEnv) config, err := ReadConfigFile(c.configPath, !c.noExpandEnv)
if err != nil { if err != nil {
return err return err
} }
@@ -93,7 +96,7 @@ func (c *GenerationsCommand) Run(ctx context.Context, args []string) (err error)
// Iterate over each generation for the replica. // Iterate over each generation for the replica.
for _, generation := range generations { for _, generation := range generations {
createdAt, updatedAt, err := r.GenerationTimeBounds(ctx, generation) createdAt, updatedAt, err := litestream.GenerationTimeBounds(ctx, r.Client, generation)
if err != nil { if err != nil {
log.Printf("%s: cannot determine generation time bounds: %s", r.Name(), err) log.Printf("%s: cannot determine generation time bounds: %s", r.Name(), err)
continue continue

View File

@@ -20,7 +20,6 @@ import (
"github.com/benbjohnson/litestream" "github.com/benbjohnson/litestream"
"github.com/benbjohnson/litestream/abs" "github.com/benbjohnson/litestream/abs"
"github.com/benbjohnson/litestream/file"
"github.com/benbjohnson/litestream/gcs" "github.com/benbjohnson/litestream/gcs"
"github.com/benbjohnson/litestream/s3" "github.com/benbjohnson/litestream/s3"
"github.com/benbjohnson/litestream/sftp" "github.com/benbjohnson/litestream/sftp"
@@ -126,7 +125,7 @@ func (m *Main) Run(ctx context.Context, args []string) (err error) {
return err return err
case "restore": case "restore":
return (&RestoreCommand{}).Run(ctx, args) return NewRestoreCommand().Run(ctx, args)
case "snapshots": case "snapshots":
return (&SnapshotsCommand{}).Run(ctx, args) return (&SnapshotsCommand{}).Run(ctx, args)
case "version": case "version":
@@ -383,8 +382,8 @@ func NewReplicaFromConfig(c *ReplicaConfig, db *litestream.DB) (_ *litestream.Re
return r, nil return r, nil
} }
// newFileReplicaClientFromConfig returns a new instance of file.ReplicaClient built from config. // newFileReplicaClientFromConfig returns a new instance of FileReplicaClient built from config.
func newFileReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *file.ReplicaClient, err error) { func newFileReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_ *litestream.FileReplicaClient, err error) {
// Ensure URL & path are not both specified. // Ensure URL & path are not both specified.
if c.URL != "" && c.Path != "" { if c.URL != "" && c.Path != "" {
return nil, fmt.Errorf("cannot specify url & path for file replica") return nil, fmt.Errorf("cannot specify url & path for file replica")
@@ -409,9 +408,7 @@ func newFileReplicaClientFromConfig(c *ReplicaConfig, r *litestream.Replica) (_
} }
// Instantiate replica and apply time fields, if set. // Instantiate replica and apply time fields, if set.
client := file.NewReplicaClient(path) return litestream.NewFileReplicaClient(path), nil
client.Replica = r
return client, nil
} }
// newS3ReplicaClientFromConfig returns a new instance of s3.ReplicaClient built from config. // newS3ReplicaClientFromConfig returns a new instance of s3.ReplicaClient built from config.
@@ -669,9 +666,9 @@ func DefaultConfigPath() string {
return defaultConfigPath return defaultConfigPath
} }
func registerConfigFlag(fs *flag.FlagSet) (configPath *string, noExpandEnv *bool) { func registerConfigFlag(fs *flag.FlagSet, configPath *string, noExpandEnv *bool) {
return fs.String("config", "", "config path"), fs.StringVar(configPath, "config", "", "config path")
fs.Bool("no-expand-env", false, "do not expand env vars in config") fs.BoolVar(noExpandEnv, "no-expand-env", false, "do not expand env vars in config")
} }
// expand returns an absolute path for s. // expand returns an absolute path for s.

View File

@@ -9,7 +9,6 @@ import (
"github.com/benbjohnson/litestream" "github.com/benbjohnson/litestream"
main "github.com/benbjohnson/litestream/cmd/litestream" main "github.com/benbjohnson/litestream/cmd/litestream"
"github.com/benbjohnson/litestream/file"
"github.com/benbjohnson/litestream/gcs" "github.com/benbjohnson/litestream/gcs"
"github.com/benbjohnson/litestream/s3" "github.com/benbjohnson/litestream/s3"
) )
@@ -103,7 +102,7 @@ func TestNewFileReplicaFromConfig(t *testing.T) {
r, err := main.NewReplicaFromConfig(&main.ReplicaConfig{Path: "/foo"}, nil) r, err := main.NewReplicaFromConfig(&main.ReplicaConfig{Path: "/foo"}, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} else if client, ok := r.Client.(*file.ReplicaClient); !ok { } else if client, ok := r.Client.(*litestream.FileReplicaClient); !ok {
t.Fatal("unexpected replica type") t.Fatal("unexpected replica type")
} else if got, want := client.Path(), "/foo"; got != want { } else if got, want := client.Path(), "/foo"; got != want {
t.Fatalf("Path=%s, want %s", got, want) t.Fatalf("Path=%s, want %s", got, want)

View File

@@ -13,7 +13,6 @@ import (
"github.com/benbjohnson/litestream" "github.com/benbjohnson/litestream"
"github.com/benbjohnson/litestream/abs" "github.com/benbjohnson/litestream/abs"
"github.com/benbjohnson/litestream/file"
"github.com/benbjohnson/litestream/gcs" "github.com/benbjohnson/litestream/gcs"
"github.com/benbjohnson/litestream/s3" "github.com/benbjohnson/litestream/s3"
"github.com/benbjohnson/litestream/sftp" "github.com/benbjohnson/litestream/sftp"
@@ -23,6 +22,9 @@ import (
// ReplicateCommand represents a command that continuously replicates SQLite databases. // ReplicateCommand represents a command that continuously replicates SQLite databases.
type ReplicateCommand struct { type ReplicateCommand struct {
configPath string
noExpandEnv bool
cmd *exec.Cmd // subcommand cmd *exec.Cmd // subcommand
execCh chan error // subcommand error channel execCh chan error // subcommand error channel
@@ -42,7 +44,7 @@ func NewReplicateCommand() *ReplicateCommand {
func (c *ReplicateCommand) ParseFlags(ctx context.Context, args []string) (err error) { func (c *ReplicateCommand) ParseFlags(ctx context.Context, args []string) (err error) {
fs := flag.NewFlagSet("litestream-replicate", flag.ContinueOnError) fs := flag.NewFlagSet("litestream-replicate", flag.ContinueOnError)
execFlag := fs.String("exec", "", "execute subcommand") execFlag := fs.String("exec", "", "execute subcommand")
configPath, noExpandEnv := registerConfigFlag(fs) registerConfigFlag(fs, &c.configPath, &c.noExpandEnv)
fs.Usage = c.Usage fs.Usage = c.Usage
if err := fs.Parse(args); err != nil { if err := fs.Parse(args); err != nil {
return err return err
@@ -52,7 +54,7 @@ func (c *ReplicateCommand) ParseFlags(ctx context.Context, args []string) (err e
if fs.NArg() == 1 { if fs.NArg() == 1 {
return fmt.Errorf("must specify at least one replica URL for %s", fs.Arg(0)) return fmt.Errorf("must specify at least one replica URL for %s", fs.Arg(0))
} else if fs.NArg() > 1 { } else if fs.NArg() > 1 {
if *configPath != "" { if c.configPath != "" {
return fmt.Errorf("cannot specify a replica URL and the -config flag") return fmt.Errorf("cannot specify a replica URL and the -config flag")
} }
@@ -66,10 +68,10 @@ func (c *ReplicateCommand) ParseFlags(ctx context.Context, args []string) (err e
} }
c.Config.DBs = []*DBConfig{dbConfig} c.Config.DBs = []*DBConfig{dbConfig}
} else { } else {
if *configPath == "" { if c.configPath == "" {
*configPath = DefaultConfigPath() c.configPath = DefaultConfigPath()
} }
if c.Config, err = ReadConfigFile(*configPath, !*noExpandEnv); err != nil { if c.Config, err = ReadConfigFile(c.configPath, !c.noExpandEnv); err != nil {
return err return err
} }
} }
@@ -110,7 +112,7 @@ func (c *ReplicateCommand) Run(ctx context.Context) (err error) {
log.Printf("initialized db: %s", db.Path()) log.Printf("initialized db: %s", db.Path())
for _, r := range db.Replicas { for _, r := range db.Replicas {
switch client := r.Client.(type) { switch client := r.Client.(type) {
case *file.ReplicaClient: case *litestream.FileReplicaClient:
log.Printf("replicating to: name=%q type=%q path=%q", r.Name(), client.Type(), client.Path()) log.Printf("replicating to: name=%q type=%q path=%q", r.Name(), client.Type(), client.Path())
case *s3.ReplicaClient: case *s3.ReplicaClient:
log.Printf("replicating to: name=%q type=%q bucket=%q path=%q region=%q endpoint=%q sync-interval=%s", r.Name(), client.Type(), client.Bucket, client.Path, client.Region, client.Endpoint, r.SyncInterval) log.Printf("replicating to: name=%q type=%q bucket=%q path=%q region=%q endpoint=%q sync-interval=%s", r.Name(), client.Type(), client.Bucket, client.Path, client.Region, client.Endpoint, r.SyncInterval)

View File

@@ -7,31 +7,46 @@ import (
"fmt" "fmt"
"log" "log"
"os" "os"
"path/filepath"
"strconv" "strconv"
"time"
"github.com/benbjohnson/litestream" "github.com/benbjohnson/litestream"
) )
// RestoreCommand represents a command to restore a database from a backup. // RestoreCommand represents a command to restore a database from a backup.
type RestoreCommand struct{} type RestoreCommand struct {
snapshotIndex int // index of snapshot to start from
// CLI options
configPath string // path to config file
noExpandEnv bool // if true, do not expand env variables in config
outputPath string // path to restore database to
replicaName string // optional, name of replica to restore from
generation string // optional, generation to restore
targetIndex int // optional, last WAL index to replay
ifDBNotExists bool // if true, skips restore if output path already exists
ifReplicaExists bool // if true, skips if no backups exist
opt litestream.RestoreOptions
}
func NewRestoreCommand() *RestoreCommand {
return &RestoreCommand{
targetIndex: -1,
opt: litestream.NewRestoreOptions(),
}
}
// Run executes the command. // Run executes the command.
func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) { func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) {
opt := litestream.NewRestoreOptions()
opt.Verbose = true
fs := flag.NewFlagSet("litestream-restore", flag.ContinueOnError) fs := flag.NewFlagSet("litestream-restore", flag.ContinueOnError)
configPath, noExpandEnv := registerConfigFlag(fs) registerConfigFlag(fs, &c.configPath, &c.noExpandEnv)
fs.StringVar(&opt.OutputPath, "o", "", "output path") fs.StringVar(&c.outputPath, "o", "", "output path")
fs.StringVar(&opt.ReplicaName, "replica", "", "replica name") fs.StringVar(&c.replicaName, "replica", "", "replica name")
fs.StringVar(&opt.Generation, "generation", "", "generation name") fs.StringVar(&c.generation, "generation", "", "generation name")
fs.Var((*indexVar)(&opt.Index), "index", "wal index") fs.Var((*indexVar)(&c.targetIndex), "index", "wal index")
fs.IntVar(&opt.Parallelism, "parallelism", opt.Parallelism, "parallelism") fs.IntVar(&c.opt.Parallelism, "parallelism", c.opt.Parallelism, "parallelism")
ifDBNotExists := fs.Bool("if-db-not-exists", false, "") fs.BoolVar(&c.ifDBNotExists, "if-db-not-exists", false, "")
ifReplicaExists := fs.Bool("if-replica-exists", false, "") fs.BoolVar(&c.ifReplicaExists, "if-replica-exists", false, "")
timestampStr := fs.String("timestamp", "", "timestamp")
verbose := fs.Bool("v", false, "verbose output")
fs.Usage = c.Usage fs.Usage = c.Usage
if err := fs.Parse(args); err != nil { if err := fs.Parse(args); err != nil {
return err return err
@@ -40,83 +55,100 @@ func (c *RestoreCommand) Run(ctx context.Context, args []string) (err error) {
} else if fs.NArg() > 1 { } else if fs.NArg() > 1 {
return fmt.Errorf("too many arguments") return fmt.Errorf("too many arguments")
} }
arg := fs.Arg(0)
// Parse timestamp, if specified. // Ensure a generation is specified if target index is specified.
if *timestampStr != "" { if c.targetIndex != -1 && c.generation == "" {
if opt.Timestamp, err = time.Parse(time.RFC3339, *timestampStr); err != nil { return fmt.Errorf("must specify -generation when using -index flag")
return errors.New("invalid -timestamp, must specify in ISO 8601 format (e.g. 2000-01-01T00:00:00Z)")
}
} }
// Instantiate logger if verbose output is enabled. // Default to original database path if output path not specified.
if *verbose { if !isURL(arg) && c.outputPath == "" {
opt.Logger = log.New(os.Stderr, "", log.LstdFlags|log.Lmicroseconds) c.outputPath = arg
} }
// Determine replica & generation to restore from. // Exit successfully if the output file already exists and flag is set.
var r *litestream.Replica if _, err := os.Stat(c.outputPath); !os.IsNotExist(err) && c.ifDBNotExists {
if isURL(fs.Arg(0)) {
if *configPath != "" {
return fmt.Errorf("cannot specify a replica URL and the -config flag")
}
if r, err = c.loadFromURL(ctx, fs.Arg(0), *ifDBNotExists, &opt); err == errSkipDBExists {
fmt.Println("database already exists, skipping") fmt.Println("database already exists, skipping")
return nil return nil
} else if err != nil {
return err
}
} else {
if *configPath == "" {
*configPath = DefaultConfigPath()
}
if r, err = c.loadFromConfig(ctx, fs.Arg(0), *configPath, !*noExpandEnv, *ifDBNotExists, &opt); err == errSkipDBExists {
fmt.Println("database already exists, skipping")
return nil
} else if err != nil {
return err
}
} }
// Create parent directory if it doesn't already exist.
if err := os.MkdirAll(filepath.Dir(c.outputPath), 0700); err != nil {
return fmt.Errorf("cannot create parent directory: %w", err)
}
// Build replica from either a URL or config.
r, err := c.loadReplica(ctx, arg)
if err != nil {
return err
}
// Determine latest generation if one is not specified.
if c.generation == "" {
if c.generation, err = litestream.FindLatestGeneration(ctx, r.Client); err == litestream.ErrNoGeneration {
// Return an error if no matching targets found. // Return an error if no matching targets found.
// If optional flag set, return success. Useful for automated recovery. // If optional flag set, return success. Useful for automated recovery.
if opt.Generation == "" { if c.ifReplicaExists {
if *ifReplicaExists {
fmt.Println("no matching backups found") fmt.Println("no matching backups found")
return nil return nil
} }
return fmt.Errorf("no matching backups found") return fmt.Errorf("no matching backups found")
} else if err != nil {
return fmt.Errorf("cannot determine latest generation: %w", err)
}
} }
return r.Restore(ctx, opt) // Determine the maximum available index for the generation if one is not specified.
if c.targetIndex == -1 {
if c.targetIndex, err = litestream.FindMaxIndexByGeneration(ctx, r.Client, c.generation); err != nil {
return fmt.Errorf("cannot determine latest index in generation %q: %w", c.generation, err)
}
} }
// loadFromURL creates a replica & updates the restore options from a replica URL. // Find lastest snapshot that occurs before the index.
func (c *RestoreCommand) loadFromURL(ctx context.Context, replicaURL string, ifDBNotExists bool, opt *litestream.RestoreOptions) (*litestream.Replica, error) { // TODO: Optionally allow -snapshot-index
if opt.OutputPath == "" { if c.snapshotIndex, err = litestream.FindSnapshotForIndex(ctx, r.Client, c.generation, c.targetIndex); err != nil {
return fmt.Errorf("cannot find snapshot index: %w", err)
}
c.opt.Logger = log.New(os.Stderr, "", log.LstdFlags|log.Lmicroseconds)
return litestream.Restore(ctx, r.Client, c.outputPath, c.generation, c.snapshotIndex, c.targetIndex, c.opt)
}
func (c *RestoreCommand) loadReplica(ctx context.Context, arg string) (*litestream.Replica, error) {
if isURL(arg) {
return c.loadReplicaFromURL(ctx, arg)
}
return c.loadReplicaFromConfig(ctx, arg)
}
// loadReplicaFromURL creates a replica & updates the restore options from a replica URL.
func (c *RestoreCommand) loadReplicaFromURL(ctx context.Context, replicaURL string) (*litestream.Replica, error) {
if c.configPath != "" {
return nil, fmt.Errorf("cannot specify a replica URL and the -config flag")
} else if c.replicaName != "" {
return nil, fmt.Errorf("cannot specify a replica URL and the -replica flag")
} else if c.outputPath == "" {
return nil, fmt.Errorf("output path required") return nil, fmt.Errorf("output path required")
} }
// Exit successfully if the output file already exists.
if _, err := os.Stat(opt.OutputPath); !os.IsNotExist(err) && ifDBNotExists {
return nil, errSkipDBExists
}
syncInterval := litestream.DefaultSyncInterval syncInterval := litestream.DefaultSyncInterval
r, err := NewReplicaFromConfig(&ReplicaConfig{ return NewReplicaFromConfig(&ReplicaConfig{
URL: replicaURL, URL: replicaURL,
SyncInterval: &syncInterval, SyncInterval: &syncInterval,
}, nil) }, nil)
if err != nil {
return nil, err
}
opt.Generation, _, err = r.CalcRestoreTarget(ctx, *opt)
return r, err
} }
// loadFromConfig returns a replica & updates the restore options from a DB reference. // loadReplicaFromConfig returns replicas based on the specific config path.
func (c *RestoreCommand) loadFromConfig(ctx context.Context, dbPath, configPath string, expandEnv, ifDBNotExists bool, opt *litestream.RestoreOptions) (*litestream.Replica, error) { func (c *RestoreCommand) loadReplicaFromConfig(ctx context.Context, dbPath string) (*litestream.Replica, error) {
if c.configPath == "" {
c.configPath = DefaultConfigPath()
}
// Load configuration. // Load configuration.
config, err := ReadConfigFile(configPath, expandEnv) config, err := ReadConfigFile(c.configPath, !c.noExpandEnv)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -132,25 +164,34 @@ func (c *RestoreCommand) loadFromConfig(ctx context.Context, dbPath, configPath
db, err := NewDBFromConfig(dbConfig) db, err := NewDBFromConfig(dbConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} else if len(db.Replicas) == 0 {
return nil, fmt.Errorf("database has no replicas: %s", dbPath)
} }
// Restore into original database path if not specified. // Filter by replica name if specified.
if opt.OutputPath == "" { if c.replicaName != "" {
opt.OutputPath = dbPath r := db.Replica(c.replicaName)
if r == nil {
return nil, fmt.Errorf("replica %q not found", c.replicaName)
}
return r, nil
} }
// Exit successfully if the output file already exists. // Choose only replica if only one available and no name is specified.
if _, err := os.Stat(opt.OutputPath); !os.IsNotExist(err) && ifDBNotExists { if len(db.Replicas) == 1 {
return nil, errSkipDBExists return db.Replicas[0], nil
} }
// Determine the appropriate replica & generation to restore from, // A replica must be specified when restoring a specific generation with multiple replicas.
r, generation, err := db.CalcRestoreTarget(ctx, *opt) if c.generation != "" {
return nil, fmt.Errorf("must specify -replica when restoring from a specific generation")
}
// Determine latest replica to restore from.
r, err := litestream.LatestReplica(ctx, db.Replicas)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("cannot determine latest replica: %w", err)
} }
opt.Generation = generation
return r, nil return r, nil
} }
@@ -186,10 +227,6 @@ Arguments:
Restore up to a specific hex-encoded WAL index (inclusive). Restore up to a specific hex-encoded WAL index (inclusive).
Defaults to use the highest available index. Defaults to use the highest available index.
-timestamp TIMESTAMP
Restore to a specific point-in-time.
Defaults to use the latest available backup.
-o PATH -o PATH
Output path of the restored database. Output path of the restored database.
Defaults to original DB path. Defaults to original DB path.
@@ -213,9 +250,6 @@ Examples:
# Restore latest replica for database to original location. # Restore latest replica for database to original location.
$ litestream restore /path/to/db $ litestream restore /path/to/db
# Restore replica for database to a given point in time.
$ litestream restore -timestamp 2020-01-01T00:00:00Z /path/to/db
# Restore latest replica for database to new /tmp directory # Restore latest replica for database to new /tmp directory
$ litestream restore -o /tmp/db /path/to/db $ litestream restore -o /tmp/db /path/to/db

View File

@@ -14,12 +14,15 @@ import (
) )
// SnapshotsCommand represents a command to list snapshots for a command. // SnapshotsCommand represents a command to list snapshots for a command.
type SnapshotsCommand struct{} type SnapshotsCommand struct {
configPath string
noExpandEnv bool
}
// Run executes the command. // Run executes the command.
func (c *SnapshotsCommand) Run(ctx context.Context, args []string) (err error) { func (c *SnapshotsCommand) Run(ctx context.Context, args []string) (err error) {
fs := flag.NewFlagSet("litestream-snapshots", flag.ContinueOnError) fs := flag.NewFlagSet("litestream-snapshots", flag.ContinueOnError)
configPath, noExpandEnv := registerConfigFlag(fs) registerConfigFlag(fs, &c.configPath, &c.noExpandEnv)
replicaName := fs.String("replica", "", "replica name") replicaName := fs.String("replica", "", "replica name")
fs.Usage = c.Usage fs.Usage = c.Usage
if err := fs.Parse(args); err != nil { if err := fs.Parse(args); err != nil {
@@ -33,19 +36,19 @@ func (c *SnapshotsCommand) Run(ctx context.Context, args []string) (err error) {
var db *litestream.DB var db *litestream.DB
var r *litestream.Replica var r *litestream.Replica
if isURL(fs.Arg(0)) { if isURL(fs.Arg(0)) {
if *configPath != "" { if c.configPath != "" {
return fmt.Errorf("cannot specify a replica URL and the -config flag") return fmt.Errorf("cannot specify a replica URL and the -config flag")
} }
if r, err = NewReplicaFromConfig(&ReplicaConfig{URL: fs.Arg(0)}, nil); err != nil { if r, err = NewReplicaFromConfig(&ReplicaConfig{URL: fs.Arg(0)}, nil); err != nil {
return err return err
} }
} else { } else {
if *configPath == "" { if c.configPath == "" {
*configPath = DefaultConfigPath() c.configPath = DefaultConfigPath()
} }
// Load configuration. // Load configuration.
config, err := ReadConfigFile(*configPath, !*noExpandEnv) config, err := ReadConfigFile(c.configPath, !c.noExpandEnv)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -13,12 +13,15 @@ import (
) )
// WALCommand represents a command to list WAL files for a database. // WALCommand represents a command to list WAL files for a database.
type WALCommand struct{} type WALCommand struct {
configPath string
noExpandEnv bool
}
// Run executes the command. // Run executes the command.
func (c *WALCommand) Run(ctx context.Context, args []string) (err error) { func (c *WALCommand) Run(ctx context.Context, args []string) (err error) {
fs := flag.NewFlagSet("litestream-wal", flag.ContinueOnError) fs := flag.NewFlagSet("litestream-wal", flag.ContinueOnError)
configPath, noExpandEnv := registerConfigFlag(fs) registerConfigFlag(fs, &c.configPath, &c.noExpandEnv)
replicaName := fs.String("replica", "", "replica name") replicaName := fs.String("replica", "", "replica name")
generation := fs.String("generation", "", "generation name") generation := fs.String("generation", "", "generation name")
fs.Usage = c.Usage fs.Usage = c.Usage
@@ -33,19 +36,19 @@ func (c *WALCommand) Run(ctx context.Context, args []string) (err error) {
var db *litestream.DB var db *litestream.DB
var r *litestream.Replica var r *litestream.Replica
if isURL(fs.Arg(0)) { if isURL(fs.Arg(0)) {
if *configPath != "" { if c.configPath != "" {
return fmt.Errorf("cannot specify a replica URL and the -config flag") return fmt.Errorf("cannot specify a replica URL and the -config flag")
} }
if r, err = NewReplicaFromConfig(&ReplicaConfig{URL: fs.Arg(0)}, nil); err != nil { if r, err = NewReplicaFromConfig(&ReplicaConfig{URL: fs.Arg(0)}, nil); err != nil {
return err return err
} }
} else { } else {
if *configPath == "" { if c.configPath == "" {
*configPath = DefaultConfigPath() c.configPath = DefaultConfigPath()
} }
// Load configuration. // Load configuration.
config, err := ReadConfigFile(*configPath, !*noExpandEnv) config, err := ReadConfigFile(c.configPath, !c.noExpandEnv)
if err != nil { if err != nil {
return err return err
} }

118
db.go
View File

@@ -12,7 +12,6 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
"math"
"math/rand" "math/rand"
"os" "os"
"path/filepath" "path/filepath"
@@ -62,8 +61,9 @@ type DB struct {
chksum0, chksum1 uint32 chksum0, chksum1 uint32
byteOrder binary.ByteOrder byteOrder binary.ByteOrder
fileInfo os.FileInfo // db info cached during init fileMode os.FileMode // db mode cached during init
dirInfo os.FileInfo // parent dir info cached during init dirMode os.FileMode // parent dir mode cached during init
uid, gid int // db user & group id cached during init
ctx context.Context ctx context.Context
cancel func() cancel func()
@@ -180,16 +180,6 @@ func (db *DB) ShadowWALDir(generation string) string {
return filepath.Join(db.GenerationPath(generation), "wal") return filepath.Join(db.GenerationPath(generation), "wal")
} }
// FileInfo returns the cached file stats for the database file when it was initialized.
func (db *DB) FileInfo() os.FileInfo {
return db.fileInfo
}
// DirInfo returns the cached file stats for the parent directory of the database file when it was initialized.
func (db *DB) DirInfo() os.FileInfo {
return db.dirInfo
}
// Replica returns a replica by name. // Replica returns a replica by name.
func (db *DB) Replica(name string) *Replica { func (db *DB) Replica(name string) *Replica {
for _, r := range db.Replicas { for _, r := range db.Replicas {
@@ -505,13 +495,14 @@ func (db *DB) init() (err error) {
} else if err != nil { } else if err != nil {
return err return err
} }
db.fileInfo = fi db.fileMode = fi.Mode()
db.uid, db.gid = internal.Fileinfo(fi)
// Obtain permissions for parent directory. // Obtain permissions for parent directory.
if fi, err = os.Stat(filepath.Dir(db.path)); err != nil { if fi, err = os.Stat(filepath.Dir(db.path)); err != nil {
return err return err
} }
db.dirInfo = fi db.dirMode = fi.Mode()
dsn := db.path dsn := db.path
dsn += fmt.Sprintf("?_busy_timeout=%d", BusyTimeout.Milliseconds()) dsn += fmt.Sprintf("?_busy_timeout=%d", BusyTimeout.Milliseconds())
@@ -577,7 +568,7 @@ func (db *DB) init() (err error) {
} }
// Ensure meta directory structure exists. // Ensure meta directory structure exists.
if err := internal.MkdirAll(db.MetaPath(), db.dirInfo); err != nil { if err := internal.MkdirAll(db.MetaPath(), db.dirMode, db.uid, db.gid); err != nil {
return err return err
} }
@@ -785,7 +776,7 @@ func (db *DB) createGeneration(ctx context.Context) (string, error) {
// Generate new directory. // Generate new directory.
dir := filepath.Join(db.MetaPath(), "generations", generation) dir := filepath.Join(db.MetaPath(), "generations", generation)
if err := internal.MkdirAll(dir, db.dirInfo); err != nil { if err := internal.MkdirAll(dir, db.dirMode, db.uid, db.gid); err != nil {
return "", err return "", err
} }
@@ -796,15 +787,10 @@ func (db *DB) createGeneration(ctx context.Context) (string, error) {
// Atomically write generation name as current generation. // Atomically write generation name as current generation.
generationNamePath := db.GenerationNamePath() generationNamePath := db.GenerationNamePath()
mode := os.FileMode(0600) if err := os.WriteFile(generationNamePath+".tmp", []byte(generation+"\n"), db.fileMode); err != nil {
if db.fileInfo != nil {
mode = db.fileInfo.Mode()
}
if err := os.WriteFile(generationNamePath+".tmp", []byte(generation+"\n"), mode); err != nil {
return "", fmt.Errorf("write generation temp file: %w", err) return "", fmt.Errorf("write generation temp file: %w", err)
} }
uid, gid := internal.Fileinfo(db.fileInfo) _ = os.Chown(generationNamePath+".tmp", db.uid, db.gid)
_ = os.Chown(generationNamePath+".tmp", uid, gid)
if err := os.Rename(generationNamePath+".tmp", generationNamePath); err != nil { if err := os.Rename(generationNamePath+".tmp", generationNamePath); err != nil {
return "", fmt.Errorf("rename generation file: %w", err) return "", fmt.Errorf("rename generation file: %w", err)
} }
@@ -1086,7 +1072,7 @@ func (db *DB) copyToShadowWAL(ctx context.Context) error {
tempFilename := filepath.Join(db.ShadowWALDir(pos.Generation), FormatIndex(pos.Index), FormatOffset(pos.Offset)+".wal.tmp") tempFilename := filepath.Join(db.ShadowWALDir(pos.Generation), FormatIndex(pos.Index), FormatOffset(pos.Offset)+".wal.tmp")
defer os.Remove(tempFilename) defer os.Remove(tempFilename)
f, err := internal.CreateFile(tempFilename, db.fileInfo) f, err := internal.CreateFile(tempFilename, db.fileMode, db.uid, db.gid)
if err != nil { if err != nil {
return err return err
} }
@@ -1214,12 +1200,12 @@ func (db *DB) writeWALSegment(ctx context.Context, pos Pos, rd io.Reader) error
filename := filepath.Join(db.ShadowWALDir(pos.Generation), FormatIndex(pos.Index), FormatOffset(pos.Offset)+".wal.lz4") filename := filepath.Join(db.ShadowWALDir(pos.Generation), FormatIndex(pos.Index), FormatOffset(pos.Offset)+".wal.lz4")
// Ensure parent directory exists. // Ensure parent directory exists.
if err := internal.MkdirAll(filepath.Dir(filename), db.dirInfo); err != nil { if err := internal.MkdirAll(filepath.Dir(filename), db.dirMode, db.uid, db.gid); err != nil {
return err return err
} }
// Write WAL segment to temporary file next to destination path. // Write WAL segment to temporary file next to destination path.
f, err := internal.CreateFile(filename+".tmp", db.fileInfo) f, err := internal.CreateFile(filename+".tmp", db.fileMode, db.uid, db.gid)
if err != nil { if err != nil {
return err return err
} }
@@ -1542,39 +1528,10 @@ func (db *DB) monitor() {
} }
} }
// CalcRestoreTarget returns a replica & generation to restore from based on opt criteria. // ApplyWAL performs a truncating checkpoint on the given database.
func (db *DB) CalcRestoreTarget(ctx context.Context, opt RestoreOptions) (*Replica, string, error) { func ApplyWAL(ctx context.Context, dbPath, walPath string) error {
var target struct {
replica *Replica
generation string
updatedAt time.Time
}
for _, r := range db.Replicas {
// Skip replica if it does not match filter.
if opt.ReplicaName != "" && r.Name() != opt.ReplicaName {
continue
}
generation, updatedAt, err := r.CalcRestoreTarget(ctx, opt)
if err != nil {
return nil, "", err
}
// Use the latest replica if we have multiple candidates.
if !updatedAt.After(target.updatedAt) {
continue
}
target.replica, target.generation, target.updatedAt = r, generation, updatedAt
}
return target.replica, target.generation, nil
}
// applyWAL performs a truncating checkpoint on the given database.
func applyWAL(ctx context.Context, index int, dbPath string) error {
// Copy WAL file from it's staging path to the correct "-wal" location. // Copy WAL file from it's staging path to the correct "-wal" location.
if err := os.Rename(fmt.Sprintf("%s-%08x-wal", dbPath, index), dbPath+"-wal"); err != nil { if err := os.Rename(walPath, dbPath+"-wal"); err != nil {
return err return err
} }
@@ -1583,7 +1540,7 @@ func applyWAL(ctx context.Context, index int, dbPath string) error {
if err != nil { if err != nil {
return err return err
} }
defer d.Close() defer func() { _ = d.Close() }()
var row [3]int var row [3]int
if err := d.QueryRow(`PRAGMA wal_checkpoint(TRUNCATE);`).Scan(&row[0], &row[1], &row[2]); err != nil { if err := d.QueryRow(`PRAGMA wal_checkpoint(TRUNCATE);`).Scan(&row[0], &row[1], &row[2]); err != nil {
@@ -1660,47 +1617,6 @@ func formatWALPath(index int) string {
var walPathRegex = regexp.MustCompile(`^([0-9a-f]{8})\.wal$`) var walPathRegex = regexp.MustCompile(`^([0-9a-f]{8})\.wal$`)
// DefaultRestoreParallelism is the default parallelism when downloading WAL files.
const DefaultRestoreParallelism = 8
// RestoreOptions represents options for DB.Restore().
type RestoreOptions struct {
// Target path to restore into.
// If blank, the original DB path is used.
OutputPath string
// Specific replica to restore from.
// If blank, all replicas are considered.
ReplicaName string
// Specific generation to restore from.
// If blank, all generations considered.
Generation string
// Specific index to restore from.
// Set to math.MaxInt32 to ignore index.
Index int
// Point-in-time to restore database.
// If zero, database restore to most recent state available.
Timestamp time.Time
// Specifies how many WAL files are downloaded in parallel during restore.
Parallelism int
// Logging settings.
Logger *log.Logger
Verbose bool
}
// NewRestoreOptions returns a new instance of RestoreOptions with defaults.
func NewRestoreOptions() RestoreOptions {
return RestoreOptions{
Index: math.MaxInt32,
Parallelism: DefaultRestoreParallelism,
}
}
// ReadWALFields iterates over the header & frames in the WAL data in r. // ReadWALFields iterates over the header & frames in the WAL data in r.
// Returns salt, checksum, byte order & the last frame. WAL data must start // Returns salt, checksum, byte order & the last frame. WAL data must start
// from the beginning of the WAL header and must end on either the WAL header // from the beginning of the WAL header and must end on either the WAL header

View File

@@ -1,4 +1,4 @@
package file package litestream
import ( import (
"context" "context"
@@ -10,49 +10,46 @@ import (
"sort" "sort"
"strings" "strings"
"github.com/benbjohnson/litestream"
"github.com/benbjohnson/litestream/internal" "github.com/benbjohnson/litestream/internal"
) )
// ReplicaClientType is the client type for this package. // FileReplicaClientType is the client type for file replica clients.
const ReplicaClientType = "file" const FileReplicaClientType = "file"
var _ litestream.ReplicaClient = (*ReplicaClient)(nil) var _ ReplicaClient = (*FileReplicaClient)(nil)
// ReplicaClient is a client for writing snapshots & WAL segments to disk. // FileReplicaClient is a client for writing snapshots & WAL segments to disk.
type ReplicaClient struct { type FileReplicaClient struct {
path string // destination path path string // destination path
Replica *litestream.Replica // File info
FileMode os.FileMode
DirMode os.FileMode
Uid, Gid int
} }
// NewReplicaClient returns a new instance of ReplicaClient. // NewFileReplicaClient returns a new instance of FileReplicaClient.
func NewReplicaClient(path string) *ReplicaClient { func NewFileReplicaClient(path string) *FileReplicaClient {
return &ReplicaClient{ return &FileReplicaClient{
path: path, path: path,
}
}
// db returns the database, if available. FileMode: 0600,
func (c *ReplicaClient) db() *litestream.DB { DirMode: 0700,
if c.Replica == nil {
return nil
} }
return c.Replica.DB()
} }
// Type returns "file" as the client type. // Type returns "file" as the client type.
func (c *ReplicaClient) Type() string { func (c *FileReplicaClient) Type() string {
return ReplicaClientType return FileReplicaClientType
} }
// Path returns the destination path to replicate the database to. // Path returns the destination path to replicate the database to.
func (c *ReplicaClient) Path() string { func (c *FileReplicaClient) Path() string {
return c.path return c.path
} }
// GenerationsDir returns the path to a generation root directory. // GenerationsDir returns the path to a generation root directory.
func (c *ReplicaClient) GenerationsDir() (string, error) { func (c *FileReplicaClient) GenerationsDir() (string, error) {
if c.path == "" { if c.path == "" {
return "", fmt.Errorf("file replica path required") return "", fmt.Errorf("file replica path required")
} }
@@ -60,7 +57,7 @@ func (c *ReplicaClient) GenerationsDir() (string, error) {
} }
// GenerationDir returns the path to a generation's root directory. // GenerationDir returns the path to a generation's root directory.
func (c *ReplicaClient) GenerationDir(generation string) (string, error) { func (c *FileReplicaClient) GenerationDir(generation string) (string, error) {
dir, err := c.GenerationsDir() dir, err := c.GenerationsDir()
if err != nil { if err != nil {
return "", err return "", err
@@ -71,7 +68,7 @@ func (c *ReplicaClient) GenerationDir(generation string) (string, error) {
} }
// SnapshotsDir returns the path to a generation's snapshot directory. // SnapshotsDir returns the path to a generation's snapshot directory.
func (c *ReplicaClient) SnapshotsDir(generation string) (string, error) { func (c *FileReplicaClient) SnapshotsDir(generation string) (string, error) {
dir, err := c.GenerationDir(generation) dir, err := c.GenerationDir(generation)
if err != nil { if err != nil {
return "", err return "", err
@@ -80,16 +77,16 @@ func (c *ReplicaClient) SnapshotsDir(generation string) (string, error) {
} }
// SnapshotPath returns the path to an uncompressed snapshot file. // SnapshotPath returns the path to an uncompressed snapshot file.
func (c *ReplicaClient) SnapshotPath(generation string, index int) (string, error) { func (c *FileReplicaClient) SnapshotPath(generation string, index int) (string, error) {
dir, err := c.SnapshotsDir(generation) dir, err := c.SnapshotsDir(generation)
if err != nil { if err != nil {
return "", err return "", err
} }
return filepath.Join(dir, litestream.FormatIndex(index)+".snapshot.lz4"), nil return filepath.Join(dir, FormatIndex(index)+".snapshot.lz4"), nil
} }
// WALDir returns the path to a generation's WAL directory // WALDir returns the path to a generation's WAL directory
func (c *ReplicaClient) WALDir(generation string) (string, error) { func (c *FileReplicaClient) WALDir(generation string) (string, error) {
dir, err := c.GenerationDir(generation) dir, err := c.GenerationDir(generation)
if err != nil { if err != nil {
return "", err return "", err
@@ -98,16 +95,16 @@ func (c *ReplicaClient) WALDir(generation string) (string, error) {
} }
// WALSegmentPath returns the path to a WAL segment file. // WALSegmentPath returns the path to a WAL segment file.
func (c *ReplicaClient) WALSegmentPath(generation string, index int, offset int64) (string, error) { func (c *FileReplicaClient) WALSegmentPath(generation string, index int, offset int64) (string, error) {
dir, err := c.WALDir(generation) dir, err := c.WALDir(generation)
if err != nil { if err != nil {
return "", err return "", err
} }
return filepath.Join(dir, litestream.FormatIndex(index), fmt.Sprintf("%08x.wal.lz4", offset)), nil return filepath.Join(dir, FormatIndex(index), fmt.Sprintf("%08x.wal.lz4", offset)), nil
} }
// Generations returns a list of available generation names. // Generations returns a list of available generation names.
func (c *ReplicaClient) Generations(ctx context.Context) ([]string, error) { func (c *FileReplicaClient) Generations(ctx context.Context) ([]string, error) {
root, err := c.GenerationsDir() root, err := c.GenerationsDir()
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot determine generations path: %w", err) return nil, fmt.Errorf("cannot determine generations path: %w", err)
@@ -122,7 +119,7 @@ func (c *ReplicaClient) Generations(ctx context.Context) ([]string, error) {
var generations []string var generations []string
for _, fi := range fis { for _, fi := range fis {
if !litestream.IsGenerationName(fi.Name()) { if !IsGenerationName(fi.Name()) {
continue continue
} else if !fi.IsDir() { } else if !fi.IsDir() {
continue continue
@@ -133,7 +130,7 @@ func (c *ReplicaClient) Generations(ctx context.Context) ([]string, error) {
} }
// DeleteGeneration deletes all snapshots & WAL segments within a generation. // DeleteGeneration deletes all snapshots & WAL segments within a generation.
func (c *ReplicaClient) DeleteGeneration(ctx context.Context, generation string) error { func (c *FileReplicaClient) DeleteGeneration(ctx context.Context, generation string) error {
dir, err := c.GenerationDir(generation) dir, err := c.GenerationDir(generation)
if err != nil { if err != nil {
return fmt.Errorf("cannot determine generation path: %w", err) return fmt.Errorf("cannot determine generation path: %w", err)
@@ -146,7 +143,7 @@ func (c *ReplicaClient) DeleteGeneration(ctx context.Context, generation string)
} }
// Snapshots returns an iterator over all available snapshots for a generation. // Snapshots returns an iterator over all available snapshots for a generation.
func (c *ReplicaClient) Snapshots(ctx context.Context, generation string) (litestream.SnapshotIterator, error) { func (c *FileReplicaClient) Snapshots(ctx context.Context, generation string) (SnapshotIterator, error) {
dir, err := c.SnapshotsDir(generation) dir, err := c.SnapshotsDir(generation)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -154,7 +151,7 @@ func (c *ReplicaClient) Snapshots(ctx context.Context, generation string) (lites
f, err := os.Open(dir) f, err := os.Open(dir)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return litestream.NewSnapshotInfoSliceIterator(nil), nil return NewSnapshotInfoSliceIterator(nil), nil
} else if err != nil { } else if err != nil {
return nil, err return nil, err
} }
@@ -166,7 +163,7 @@ func (c *ReplicaClient) Snapshots(ctx context.Context, generation string) (lites
} }
// Iterate over every file and convert to metadata. // Iterate over every file and convert to metadata.
infos := make([]litestream.SnapshotInfo, 0, len(fis)) infos := make([]SnapshotInfo, 0, len(fis))
for _, fi := range fis { for _, fi := range fis {
// Parse index from filename. // Parse index from filename.
index, err := internal.ParseSnapshotPath(filepath.Base(fi.Name())) index, err := internal.ParseSnapshotPath(filepath.Base(fi.Name()))
@@ -174,7 +171,7 @@ func (c *ReplicaClient) Snapshots(ctx context.Context, generation string) (lites
continue continue
} }
infos = append(infos, litestream.SnapshotInfo{ infos = append(infos, SnapshotInfo{
Generation: generation, Generation: generation,
Index: index, Index: index,
Size: fi.Size(), Size: fi.Size(),
@@ -182,30 +179,25 @@ func (c *ReplicaClient) Snapshots(ctx context.Context, generation string) (lites
}) })
} }
sort.Sort(litestream.SnapshotInfoSlice(infos)) sort.Sort(SnapshotInfoSlice(infos))
return litestream.NewSnapshotInfoSliceIterator(infos), nil return NewSnapshotInfoSliceIterator(infos), nil
} }
// WriteSnapshot writes LZ4 compressed data from rd into a file on disk. // WriteSnapshot writes LZ4 compressed data from rd into a file on disk.
func (c *ReplicaClient) WriteSnapshot(ctx context.Context, generation string, index int, rd io.Reader) (info litestream.SnapshotInfo, err error) { func (c *FileReplicaClient) WriteSnapshot(ctx context.Context, generation string, index int, rd io.Reader) (info SnapshotInfo, err error) {
filename, err := c.SnapshotPath(generation, index) filename, err := c.SnapshotPath(generation, index)
if err != nil { if err != nil {
return info, err return info, err
} }
var fileInfo, dirInfo os.FileInfo
if db := c.db(); db != nil {
fileInfo, dirInfo = db.FileInfo(), db.DirInfo()
}
// Ensure parent directory exists. // Ensure parent directory exists.
if err := internal.MkdirAll(filepath.Dir(filename), dirInfo); err != nil { if err := internal.MkdirAll(filepath.Dir(filename), c.DirMode, c.Uid, c.Gid); err != nil {
return info, err return info, err
} }
// Write snapshot to temporary file next to destination path. // Write snapshot to temporary file next to destination path.
f, err := internal.CreateFile(filename+".tmp", fileInfo) f, err := internal.CreateFile(filename+".tmp", c.FileMode, c.Uid, c.Gid)
if err != nil { if err != nil {
return info, err return info, err
} }
@@ -224,7 +216,7 @@ func (c *ReplicaClient) WriteSnapshot(ctx context.Context, generation string, in
if err != nil { if err != nil {
return info, err return info, err
} }
info = litestream.SnapshotInfo{ info = SnapshotInfo{
Generation: generation, Generation: generation,
Index: index, Index: index,
Size: fi.Size(), Size: fi.Size(),
@@ -241,7 +233,7 @@ func (c *ReplicaClient) WriteSnapshot(ctx context.Context, generation string, in
// SnapshotReader returns a reader for snapshot data at the given generation/index. // SnapshotReader returns a reader for snapshot data at the given generation/index.
// Returns os.ErrNotExist if no matching index is found. // Returns os.ErrNotExist if no matching index is found.
func (c *ReplicaClient) SnapshotReader(ctx context.Context, generation string, index int) (io.ReadCloser, error) { func (c *FileReplicaClient) SnapshotReader(ctx context.Context, generation string, index int) (io.ReadCloser, error) {
filename, err := c.SnapshotPath(generation, index) filename, err := c.SnapshotPath(generation, index)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -250,7 +242,7 @@ func (c *ReplicaClient) SnapshotReader(ctx context.Context, generation string, i
} }
// DeleteSnapshot deletes a snapshot with the given generation & index. // DeleteSnapshot deletes a snapshot with the given generation & index.
func (c *ReplicaClient) DeleteSnapshot(ctx context.Context, generation string, index int) error { func (c *FileReplicaClient) DeleteSnapshot(ctx context.Context, generation string, index int) error {
filename, err := c.SnapshotPath(generation, index) filename, err := c.SnapshotPath(generation, index)
if err != nil { if err != nil {
return fmt.Errorf("cannot determine snapshot path: %w", err) return fmt.Errorf("cannot determine snapshot path: %w", err)
@@ -262,7 +254,7 @@ func (c *ReplicaClient) DeleteSnapshot(ctx context.Context, generation string, i
} }
// WALSegments returns an iterator over all available WAL files for a generation. // WALSegments returns an iterator over all available WAL files for a generation.
func (c *ReplicaClient) WALSegments(ctx context.Context, generation string) (litestream.WALSegmentIterator, error) { func (c *FileReplicaClient) WALSegments(ctx context.Context, generation string) (WALSegmentIterator, error) {
dir, err := c.WALDir(generation) dir, err := c.WALDir(generation)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -270,7 +262,7 @@ func (c *ReplicaClient) WALSegments(ctx context.Context, generation string) (lit
f, err := os.Open(dir) f, err := os.Open(dir)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return litestream.NewWALSegmentInfoSliceIterator(nil), nil return NewWALSegmentInfoSliceIterator(nil), nil
} else if err != nil { } else if err != nil {
return nil, err return nil, err
} }
@@ -284,7 +276,7 @@ func (c *ReplicaClient) WALSegments(ctx context.Context, generation string) (lit
// Iterate over every file and convert to metadata. // Iterate over every file and convert to metadata.
indexes := make([]int, 0, len(fis)) indexes := make([]int, 0, len(fis))
for _, fi := range fis { for _, fi := range fis {
index, err := litestream.ParseIndex(fi.Name()) index, err := ParseIndex(fi.Name())
if err != nil || !fi.IsDir() { if err != nil || !fi.IsDir() {
continue continue
} }
@@ -293,28 +285,23 @@ func (c *ReplicaClient) WALSegments(ctx context.Context, generation string) (lit
sort.Ints(indexes) sort.Ints(indexes)
return newWALSegmentIterator(dir, generation, indexes), nil return newFileWALSegmentIterator(dir, generation, indexes), nil
} }
// WriteWALSegment writes LZ4 compressed data from rd into a file on disk. // WriteWALSegment writes LZ4 compressed data from rd into a file on disk.
func (c *ReplicaClient) WriteWALSegment(ctx context.Context, pos litestream.Pos, rd io.Reader) (info litestream.WALSegmentInfo, err error) { func (c *FileReplicaClient) WriteWALSegment(ctx context.Context, pos Pos, rd io.Reader) (info WALSegmentInfo, err error) {
filename, err := c.WALSegmentPath(pos.Generation, pos.Index, pos.Offset) filename, err := c.WALSegmentPath(pos.Generation, pos.Index, pos.Offset)
if err != nil { if err != nil {
return info, err return info, err
} }
var fileInfo, dirInfo os.FileInfo
if db := c.db(); db != nil {
fileInfo, dirInfo = db.FileInfo(), db.DirInfo()
}
// Ensure parent directory exists. // Ensure parent directory exists.
if err := internal.MkdirAll(filepath.Dir(filename), dirInfo); err != nil { if err := internal.MkdirAll(filepath.Dir(filename), c.DirMode, c.Uid, c.Gid); err != nil {
return info, err return info, err
} }
// Write WAL segment to temporary file next to destination path. // Write WAL segment to temporary file next to destination path.
f, err := internal.CreateFile(filename+".tmp", fileInfo) f, err := internal.CreateFile(filename+".tmp", c.FileMode, c.Uid, c.Gid)
if err != nil { if err != nil {
return info, err return info, err
} }
@@ -333,7 +320,7 @@ func (c *ReplicaClient) WriteWALSegment(ctx context.Context, pos litestream.Pos,
if err != nil { if err != nil {
return info, err return info, err
} }
info = litestream.WALSegmentInfo{ info = WALSegmentInfo{
Generation: pos.Generation, Generation: pos.Generation,
Index: pos.Index, Index: pos.Index,
Offset: pos.Offset, Offset: pos.Offset,
@@ -351,7 +338,7 @@ func (c *ReplicaClient) WriteWALSegment(ctx context.Context, pos litestream.Pos,
// WALSegmentReader returns a reader for a section of WAL data at the given position. // WALSegmentReader returns a reader for a section of WAL data at the given position.
// Returns os.ErrNotExist if no matching index/offset is found. // Returns os.ErrNotExist if no matching index/offset is found.
func (c *ReplicaClient) WALSegmentReader(ctx context.Context, pos litestream.Pos) (io.ReadCloser, error) { func (c *FileReplicaClient) WALSegmentReader(ctx context.Context, pos Pos) (io.ReadCloser, error) {
filename, err := c.WALSegmentPath(pos.Generation, pos.Index, pos.Offset) filename, err := c.WALSegmentPath(pos.Generation, pos.Index, pos.Offset)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -360,7 +347,7 @@ func (c *ReplicaClient) WALSegmentReader(ctx context.Context, pos litestream.Pos
} }
// DeleteWALSegments deletes WAL segments at the given positions. // DeleteWALSegments deletes WAL segments at the given positions.
func (c *ReplicaClient) DeleteWALSegments(ctx context.Context, a []litestream.Pos) error { func (c *FileReplicaClient) DeleteWALSegments(ctx context.Context, a []Pos) error {
for _, pos := range a { for _, pos := range a {
filename, err := c.WALSegmentPath(pos.Generation, pos.Index, pos.Offset) filename, err := c.WALSegmentPath(pos.Generation, pos.Index, pos.Offset)
if err != nil { if err != nil {
@@ -373,28 +360,28 @@ func (c *ReplicaClient) DeleteWALSegments(ctx context.Context, a []litestream.Po
return nil return nil
} }
type walSegmentIterator struct { type fileWalSegmentIterator struct {
dir string dir string
generation string generation string
indexes []int indexes []int
infos []litestream.WALSegmentInfo infos []WALSegmentInfo
err error err error
} }
func newWALSegmentIterator(dir, generation string, indexes []int) *walSegmentIterator { func newFileWALSegmentIterator(dir, generation string, indexes []int) *fileWalSegmentIterator {
return &walSegmentIterator{ return &fileWalSegmentIterator{
dir: dir, dir: dir,
generation: generation, generation: generation,
indexes: indexes, indexes: indexes,
} }
} }
func (itr *walSegmentIterator) Close() (err error) { func (itr *fileWalSegmentIterator) Close() (err error) {
return itr.err return itr.err
} }
func (itr *walSegmentIterator) Next() bool { func (itr *fileWalSegmentIterator) Next() bool {
// Exit if an error has already occurred. // Exit if an error has already occurred.
if itr.err != nil { if itr.err != nil {
return false return false
@@ -416,7 +403,7 @@ func (itr *walSegmentIterator) Next() bool {
// Read segments into a cache for the current index. // Read segments into a cache for the current index.
index := itr.indexes[0] index := itr.indexes[0]
itr.indexes = itr.indexes[1:] itr.indexes = itr.indexes[1:]
f, err := os.Open(filepath.Join(itr.dir, litestream.FormatIndex(index))) f, err := os.Open(filepath.Join(itr.dir, FormatIndex(index)))
if err != nil { if err != nil {
itr.err = err itr.err = err
return false return false
@@ -438,12 +425,12 @@ func (itr *walSegmentIterator) Next() bool {
continue continue
} }
offset, err := litestream.ParseOffset(strings.TrimSuffix(filename, ".wal.lz4")) offset, err := ParseOffset(strings.TrimSuffix(filename, ".wal.lz4"))
if err != nil { if err != nil {
continue continue
} }
itr.infos = append(itr.infos, litestream.WALSegmentInfo{ itr.infos = append(itr.infos, WALSegmentInfo{
Generation: itr.generation, Generation: itr.generation,
Index: index, Index: index,
Offset: offset, Offset: offset,
@@ -453,7 +440,7 @@ func (itr *walSegmentIterator) Next() bool {
} }
// Ensure segments are sorted within index. // Ensure segments are sorted within index.
sort.Sort(litestream.WALSegmentInfoSlice(itr.infos)) sort.Sort(WALSegmentInfoSlice(itr.infos))
if len(itr.infos) > 0 { if len(itr.infos) > 0 {
return true return true
@@ -461,11 +448,11 @@ func (itr *walSegmentIterator) Next() bool {
} }
} }
func (itr *walSegmentIterator) Err() error { return itr.err } func (itr *fileWalSegmentIterator) Err() error { return itr.err }
func (itr *walSegmentIterator) WALSegment() litestream.WALSegmentInfo { func (itr *fileWalSegmentIterator) WALSegment() WALSegmentInfo {
if len(itr.infos) == 0 { if len(itr.infos) == 0 {
return litestream.WALSegmentInfo{} return WALSegmentInfo{}
} }
return itr.infos[0] return itr.infos[0]
} }

View File

@@ -1,34 +1,34 @@
package file_test package litestream_test
import ( import (
"testing" "testing"
"github.com/benbjohnson/litestream/file" "github.com/benbjohnson/litestream"
) )
func TestReplicaClient_Path(t *testing.T) { func TestReplicaClient_Path(t *testing.T) {
c := file.NewReplicaClient("/foo/bar") c := litestream.NewFileReplicaClient("/foo/bar")
if got, want := c.Path(), "/foo/bar"; got != want { if got, want := c.Path(), "/foo/bar"; got != want {
t.Fatalf("Path()=%v, want %v", got, want) t.Fatalf("Path()=%v, want %v", got, want)
} }
} }
func TestReplicaClient_Type(t *testing.T) { func TestReplicaClient_Type(t *testing.T) {
if got, want := file.NewReplicaClient("").Type(), "file"; got != want { if got, want := litestream.NewFileReplicaClient("").Type(), "file"; got != want {
t.Fatalf("Type()=%v, want %v", got, want) t.Fatalf("Type()=%v, want %v", got, want)
} }
} }
func TestReplicaClient_GenerationsDir(t *testing.T) { func TestReplicaClient_GenerationsDir(t *testing.T) {
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
if got, err := file.NewReplicaClient("/foo").GenerationsDir(); err != nil { if got, err := litestream.NewFileReplicaClient("/foo").GenerationsDir(); err != nil {
t.Fatal(err) t.Fatal(err)
} else if want := "/foo/generations"; got != want { } else if want := "/foo/generations"; got != want {
t.Fatalf("GenerationsDir()=%v, want %v", got, want) t.Fatalf("GenerationsDir()=%v, want %v", got, want)
} }
}) })
t.Run("ErrNoPath", func(t *testing.T) { t.Run("ErrNoPath", func(t *testing.T) {
if _, err := file.NewReplicaClient("").GenerationsDir(); err == nil || err.Error() != `file replica path required` { if _, err := litestream.NewFileReplicaClient("").GenerationsDir(); err == nil || err.Error() != `file replica path required` {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
}) })
@@ -36,19 +36,19 @@ func TestReplicaClient_GenerationsDir(t *testing.T) {
func TestReplicaClient_GenerationDir(t *testing.T) { func TestReplicaClient_GenerationDir(t *testing.T) {
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
if got, err := file.NewReplicaClient("/foo").GenerationDir("0123456701234567"); err != nil { if got, err := litestream.NewFileReplicaClient("/foo").GenerationDir("0123456701234567"); err != nil {
t.Fatal(err) t.Fatal(err)
} else if want := "/foo/generations/0123456701234567"; got != want { } else if want := "/foo/generations/0123456701234567"; got != want {
t.Fatalf("GenerationDir()=%v, want %v", got, want) t.Fatalf("GenerationDir()=%v, want %v", got, want)
} }
}) })
t.Run("ErrNoPath", func(t *testing.T) { t.Run("ErrNoPath", func(t *testing.T) {
if _, err := file.NewReplicaClient("").GenerationDir("0123456701234567"); err == nil || err.Error() != `file replica path required` { if _, err := litestream.NewFileReplicaClient("").GenerationDir("0123456701234567"); err == nil || err.Error() != `file replica path required` {
t.Fatalf("expected error: %v", err) t.Fatalf("expected error: %v", err)
} }
}) })
t.Run("ErrNoGeneration", func(t *testing.T) { t.Run("ErrNoGeneration", func(t *testing.T) {
if _, err := file.NewReplicaClient("/foo").GenerationDir(""); err == nil || err.Error() != `generation required` { if _, err := litestream.NewFileReplicaClient("/foo").GenerationDir(""); err == nil || err.Error() != `generation required` {
t.Fatalf("expected error: %v", err) t.Fatalf("expected error: %v", err)
} }
}) })
@@ -56,19 +56,19 @@ func TestReplicaClient_GenerationDir(t *testing.T) {
func TestReplicaClient_SnapshotsDir(t *testing.T) { func TestReplicaClient_SnapshotsDir(t *testing.T) {
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
if got, err := file.NewReplicaClient("/foo").SnapshotsDir("0123456701234567"); err != nil { if got, err := litestream.NewFileReplicaClient("/foo").SnapshotsDir("0123456701234567"); err != nil {
t.Fatal(err) t.Fatal(err)
} else if want := "/foo/generations/0123456701234567/snapshots"; got != want { } else if want := "/foo/generations/0123456701234567/snapshots"; got != want {
t.Fatalf("SnapshotsDir()=%v, want %v", got, want) t.Fatalf("SnapshotsDir()=%v, want %v", got, want)
} }
}) })
t.Run("ErrNoPath", func(t *testing.T) { t.Run("ErrNoPath", func(t *testing.T) {
if _, err := file.NewReplicaClient("").SnapshotsDir("0123456701234567"); err == nil || err.Error() != `file replica path required` { if _, err := litestream.NewFileReplicaClient("").SnapshotsDir("0123456701234567"); err == nil || err.Error() != `file replica path required` {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
}) })
t.Run("ErrNoGeneration", func(t *testing.T) { t.Run("ErrNoGeneration", func(t *testing.T) {
if _, err := file.NewReplicaClient("/foo").SnapshotsDir(""); err == nil || err.Error() != `generation required` { if _, err := litestream.NewFileReplicaClient("/foo").SnapshotsDir(""); err == nil || err.Error() != `generation required` {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
}) })
@@ -76,19 +76,19 @@ func TestReplicaClient_SnapshotsDir(t *testing.T) {
func TestReplicaClient_SnapshotPath(t *testing.T) { func TestReplicaClient_SnapshotPath(t *testing.T) {
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
if got, err := file.NewReplicaClient("/foo").SnapshotPath("0123456701234567", 1000); err != nil { if got, err := litestream.NewFileReplicaClient("/foo").SnapshotPath("0123456701234567", 1000); err != nil {
t.Fatal(err) t.Fatal(err)
} else if want := "/foo/generations/0123456701234567/snapshots/000003e8.snapshot.lz4"; got != want { } else if want := "/foo/generations/0123456701234567/snapshots/000003e8.snapshot.lz4"; got != want {
t.Fatalf("SnapshotPath()=%v, want %v", got, want) t.Fatalf("SnapshotPath()=%v, want %v", got, want)
} }
}) })
t.Run("ErrNoPath", func(t *testing.T) { t.Run("ErrNoPath", func(t *testing.T) {
if _, err := file.NewReplicaClient("").SnapshotPath("0123456701234567", 1000); err == nil || err.Error() != `file replica path required` { if _, err := litestream.NewFileReplicaClient("").SnapshotPath("0123456701234567", 1000); err == nil || err.Error() != `file replica path required` {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
}) })
t.Run("ErrNoGeneration", func(t *testing.T) { t.Run("ErrNoGeneration", func(t *testing.T) {
if _, err := file.NewReplicaClient("/foo").SnapshotPath("", 1000); err == nil || err.Error() != `generation required` { if _, err := litestream.NewFileReplicaClient("/foo").SnapshotPath("", 1000); err == nil || err.Error() != `generation required` {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
}) })
@@ -96,19 +96,19 @@ func TestReplicaClient_SnapshotPath(t *testing.T) {
func TestReplicaClient_WALDir(t *testing.T) { func TestReplicaClient_WALDir(t *testing.T) {
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
if got, err := file.NewReplicaClient("/foo").WALDir("0123456701234567"); err != nil { if got, err := litestream.NewFileReplicaClient("/foo").WALDir("0123456701234567"); err != nil {
t.Fatal(err) t.Fatal(err)
} else if want := "/foo/generations/0123456701234567/wal"; got != want { } else if want := "/foo/generations/0123456701234567/wal"; got != want {
t.Fatalf("WALDir()=%v, want %v", got, want) t.Fatalf("WALDir()=%v, want %v", got, want)
} }
}) })
t.Run("ErrNoPath", func(t *testing.T) { t.Run("ErrNoPath", func(t *testing.T) {
if _, err := file.NewReplicaClient("").WALDir("0123456701234567"); err == nil || err.Error() != `file replica path required` { if _, err := litestream.NewFileReplicaClient("").WALDir("0123456701234567"); err == nil || err.Error() != `file replica path required` {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
}) })
t.Run("ErrNoGeneration", func(t *testing.T) { t.Run("ErrNoGeneration", func(t *testing.T) {
if _, err := file.NewReplicaClient("/foo").WALDir(""); err == nil || err.Error() != `generation required` { if _, err := litestream.NewFileReplicaClient("/foo").WALDir(""); err == nil || err.Error() != `generation required` {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
}) })
@@ -116,19 +116,19 @@ func TestReplicaClient_WALDir(t *testing.T) {
func TestReplicaClient_WALSegmentPath(t *testing.T) { func TestReplicaClient_WALSegmentPath(t *testing.T) {
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
if got, err := file.NewReplicaClient("/foo").WALSegmentPath("0123456701234567", 1000, 1001); err != nil { if got, err := litestream.NewFileReplicaClient("/foo").WALSegmentPath("0123456701234567", 1000, 1001); err != nil {
t.Fatal(err) t.Fatal(err)
} else if want := "/foo/generations/0123456701234567/wal/000003e8/000003e9.wal.lz4"; got != want { } else if want := "/foo/generations/0123456701234567/wal/000003e8/000003e9.wal.lz4"; got != want {
t.Fatalf("WALPath()=%v, want %v", got, want) t.Fatalf("WALPath()=%v, want %v", got, want)
} }
}) })
t.Run("ErrNoPath", func(t *testing.T) { t.Run("ErrNoPath", func(t *testing.T) {
if _, err := file.NewReplicaClient("").WALSegmentPath("0123456701234567", 1000, 0); err == nil || err.Error() != `file replica path required` { if _, err := litestream.NewFileReplicaClient("").WALSegmentPath("0123456701234567", 1000, 0); err == nil || err.Error() != `file replica path required` {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
}) })
t.Run("ErrNoGeneration", func(t *testing.T) { t.Run("ErrNoGeneration", func(t *testing.T) {
if _, err := file.NewReplicaClient("/foo").WALSegmentPath("", 1000, 0); err == nil || err.Error() != `generation required` { if _, err := litestream.NewFileReplicaClient("/foo").WALSegmentPath("", 1000, 0); err == nil || err.Error() != `generation required` {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
}) })

View File

@@ -0,0 +1,566 @@
package integration_test
import (
"context"
"flag"
"fmt"
"io/ioutil"
"math/rand"
"os"
"path"
"reflect"
"sort"
"strings"
"testing"
"time"
"github.com/benbjohnson/litestream"
"github.com/benbjohnson/litestream/abs"
"github.com/benbjohnson/litestream/gcs"
"github.com/benbjohnson/litestream/s3"
"github.com/benbjohnson/litestream/sftp"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
var (
// Enables integration tests.
replicaType = flag.String("replica-type", "file", "")
)
// S3 settings
var (
// Replica client settings
s3AccessKeyID = flag.String("s3-access-key-id", os.Getenv("LITESTREAM_S3_ACCESS_KEY_ID"), "")
s3SecretAccessKey = flag.String("s3-secret-access-key", os.Getenv("LITESTREAM_S3_SECRET_ACCESS_KEY"), "")
s3Region = flag.String("s3-region", os.Getenv("LITESTREAM_S3_REGION"), "")
s3Bucket = flag.String("s3-bucket", os.Getenv("LITESTREAM_S3_BUCKET"), "")
s3Path = flag.String("s3-path", os.Getenv("LITESTREAM_S3_PATH"), "")
s3Endpoint = flag.String("s3-endpoint", os.Getenv("LITESTREAM_S3_ENDPOINT"), "")
s3ForcePathStyle = flag.Bool("s3-force-path-style", os.Getenv("LITESTREAM_S3_FORCE_PATH_STYLE") == "true", "")
s3SkipVerify = flag.Bool("s3-skip-verify", os.Getenv("LITESTREAM_S3_SKIP_VERIFY") == "true", "")
)
// Google cloud storage settings
var (
gcsBucket = flag.String("gcs-bucket", os.Getenv("LITESTREAM_GCS_BUCKET"), "")
gcsPath = flag.String("gcs-path", os.Getenv("LITESTREAM_GCS_PATH"), "")
)
// Azure blob storage settings
var (
absAccountName = flag.String("abs-account-name", os.Getenv("LITESTREAM_ABS_ACCOUNT_NAME"), "")
absAccountKey = flag.String("abs-account-key", os.Getenv("LITESTREAM_ABS_ACCOUNT_KEY"), "")
absBucket = flag.String("abs-bucket", os.Getenv("LITESTREAM_ABS_BUCKET"), "")
absPath = flag.String("abs-path", os.Getenv("LITESTREAM_ABS_PATH"), "")
)
// SFTP settings
var (
sftpHost = flag.String("sftp-host", os.Getenv("LITESTREAM_SFTP_HOST"), "")
sftpUser = flag.String("sftp-user", os.Getenv("LITESTREAM_SFTP_USER"), "")
sftpPassword = flag.String("sftp-password", os.Getenv("LITESTREAM_SFTP_PASSWORD"), "")
sftpKeyPath = flag.String("sftp-key-path", os.Getenv("LITESTREAM_SFTP_KEY_PATH"), "")
sftpPath = flag.String("sftp-path", os.Getenv("LITESTREAM_SFTP_PATH"), "")
)
func TestReplicaClient_Generations(t *testing.T) {
RunWithReplicaClient(t, "OK", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
// Write snapshots.
if _, err := c.WriteSnapshot(context.Background(), "5efbd8d042012dca", 0, strings.NewReader(`foo`)); err != nil {
t.Fatal(err)
} else if _, err := c.WriteSnapshot(context.Background(), "b16ddcf5c697540f", 0, strings.NewReader(`bar`)); err != nil {
t.Fatal(err)
} else if _, err := c.WriteSnapshot(context.Background(), "155fe292f8333c72", 0, strings.NewReader(`baz`)); err != nil {
t.Fatal(err)
}
// Verify returned generations.
if got, err := c.Generations(context.Background()); err != nil {
t.Fatal(err)
} else if want := []string{"155fe292f8333c72", "5efbd8d042012dca", "b16ddcf5c697540f"}; !reflect.DeepEqual(got, want) {
t.Fatalf("Generations()=%v, want %v", got, want)
}
})
RunWithReplicaClient(t, "NoGenerationsDir", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
if generations, err := c.Generations(context.Background()); err != nil {
t.Fatal(err)
} else if got, want := len(generations), 0; got != want {
t.Fatalf("len(Generations())=%v, want %v", got, want)
}
})
}
func TestReplicaClient_Snapshots(t *testing.T) {
RunWithReplicaClient(t, "OK", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
// Write snapshots.
if _, err := c.WriteSnapshot(context.Background(), "5efbd8d042012dca", 1, strings.NewReader(``)); err != nil {
t.Fatal(err)
} else if _, err := c.WriteSnapshot(context.Background(), "b16ddcf5c697540f", 5, strings.NewReader(`x`)); err != nil {
t.Fatal(err)
} else if _, err := c.WriteSnapshot(context.Background(), "b16ddcf5c697540f", 10, strings.NewReader(`xyz`)); err != nil {
t.Fatal(err)
}
// Fetch all snapshots by generation.
itr, err := c.Snapshots(context.Background(), "b16ddcf5c697540f")
if err != nil {
t.Fatal(err)
}
defer itr.Close()
// Read all snapshots into a slice so they can be sorted.
a, err := litestream.SliceSnapshotIterator(itr)
if err != nil {
t.Fatal(err)
} else if got, want := len(a), 2; got != want {
t.Fatalf("len=%v, want %v", got, want)
}
sort.Sort(litestream.SnapshotInfoSlice(a))
// Verify first snapshot metadata.
if got, want := a[0].Generation, "b16ddcf5c697540f"; got != want {
t.Fatalf("Generation=%v, want %v", got, want)
} else if got, want := a[0].Index, 5; got != want {
t.Fatalf("Index=%v, want %v", got, want)
} else if got, want := a[0].Size, int64(1); got != want {
t.Fatalf("Size=%v, want %v", got, want)
} else if a[0].CreatedAt.IsZero() {
t.Fatalf("expected CreatedAt")
}
// Verify second snapshot metadata.
if got, want := a[1].Generation, "b16ddcf5c697540f"; got != want {
t.Fatalf("Generation=%v, want %v", got, want)
} else if got, want := a[1].Index, 0xA; got != want {
t.Fatalf("Index=%v, want %v", got, want)
} else if got, want := a[1].Size, int64(3); got != want {
t.Fatalf("Size=%v, want %v", got, want)
} else if a[1].CreatedAt.IsZero() {
t.Fatalf("expected CreatedAt")
}
// Ensure close is clean.
if err := itr.Close(); err != nil {
t.Fatal(err)
}
})
RunWithReplicaClient(t, "NoGenerationDir", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
itr, err := c.Snapshots(context.Background(), "5efbd8d042012dca")
if err != nil {
t.Fatal(err)
}
defer itr.Close()
if itr.Next() {
t.Fatal("expected no snapshots")
}
})
RunWithReplicaClient(t, "ErrNoGeneration", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
itr, err := c.Snapshots(context.Background(), "")
if err == nil {
err = itr.Close()
}
if err == nil || err.Error() != `generation required` {
t.Fatalf("unexpected error: %v", err)
}
})
}
func TestReplicaClient_WriteSnapshot(t *testing.T) {
RunWithReplicaClient(t, "OK", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
if _, err := c.WriteSnapshot(context.Background(), "b16ddcf5c697540f", 1000, strings.NewReader(`foobar`)); err != nil {
t.Fatal(err)
}
if r, err := c.SnapshotReader(context.Background(), "b16ddcf5c697540f", 1000); err != nil {
t.Fatal(err)
} else if buf, err := ioutil.ReadAll(r); err != nil {
t.Fatal(err)
} else if err := r.Close(); err != nil {
t.Fatal(err)
} else if got, want := string(buf), `foobar`; got != want {
t.Fatalf("data=%q, want %q", got, want)
}
})
RunWithReplicaClient(t, "ErrNoGeneration", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
if _, err := c.WriteSnapshot(context.Background(), "", 0, nil); err == nil || err.Error() != `generation required` {
t.Fatalf("unexpected error: %v", err)
}
})
}
func TestReplicaClient_SnapshotReader(t *testing.T) {
RunWithReplicaClient(t, "OK", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
if _, err := c.WriteSnapshot(context.Background(), "5efbd8d042012dca", 10, strings.NewReader(`foo`)); err != nil {
t.Fatal(err)
}
r, err := c.SnapshotReader(context.Background(), "5efbd8d042012dca", 10)
if err != nil {
t.Fatal(err)
}
defer r.Close()
if buf, err := ioutil.ReadAll(r); err != nil {
t.Fatal(err)
} else if got, want := string(buf), "foo"; got != want {
t.Fatalf("ReadAll=%v, want %v", got, want)
}
})
RunWithReplicaClient(t, "ErrNotFound", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
if _, err := c.SnapshotReader(context.Background(), "5efbd8d042012dca", 1); !os.IsNotExist(err) {
t.Fatalf("expected not exist, got %#v", err)
}
})
RunWithReplicaClient(t, "ErrNoGeneration", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
if _, err := c.SnapshotReader(context.Background(), "", 1); err == nil || err.Error() != `generation required` {
t.Fatalf("unexpected error: %v", err)
}
})
}
func TestReplicaClient_WALSegments(t *testing.T) {
RunWithReplicaClient(t, "OK", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
if _, err := c.WriteWALSegment(context.Background(), litestream.Pos{Generation: "5efbd8d042012dca", Index: 1, Offset: 0}, strings.NewReader(``)); err != nil {
t.Fatal(err)
}
if _, err := c.WriteWALSegment(context.Background(), litestream.Pos{Generation: "b16ddcf5c697540f", Index: 2, Offset: 0}, strings.NewReader(`12345`)); err != nil {
t.Fatal(err)
} else if _, err := c.WriteWALSegment(context.Background(), litestream.Pos{Generation: "b16ddcf5c697540f", Index: 2, Offset: 5}, strings.NewReader(`67`)); err != nil {
t.Fatal(err)
} else if _, err := c.WriteWALSegment(context.Background(), litestream.Pos{Generation: "b16ddcf5c697540f", Index: 3, Offset: 0}, strings.NewReader(`xyz`)); err != nil {
t.Fatal(err)
}
itr, err := c.WALSegments(context.Background(), "b16ddcf5c697540f")
if err != nil {
t.Fatal(err)
}
defer itr.Close()
// Read all WAL segment files into a slice so they can be sorted.
a, err := litestream.SliceWALSegmentIterator(itr)
if err != nil {
t.Fatal(err)
} else if got, want := len(a), 3; got != want {
t.Fatalf("len=%v, want %v", got, want)
}
sort.Sort(litestream.WALSegmentInfoSlice(a))
// Verify first WAL segment metadata.
if got, want := a[0].Generation, "b16ddcf5c697540f"; got != want {
t.Fatalf("Generation=%v, want %v", got, want)
} else if got, want := a[0].Index, 2; got != want {
t.Fatalf("Index=%v, want %v", got, want)
} else if got, want := a[0].Offset, int64(0); got != want {
t.Fatalf("Offset=%v, want %v", got, want)
} else if got, want := a[0].Size, int64(5); got != want {
t.Fatalf("Size=%v, want %v", got, want)
} else if a[0].CreatedAt.IsZero() {
t.Fatalf("expected CreatedAt")
}
// Verify first WAL segment metadata.
if got, want := a[1].Generation, "b16ddcf5c697540f"; got != want {
t.Fatalf("Generation=%v, want %v", got, want)
} else if got, want := a[1].Index, 2; got != want {
t.Fatalf("Index=%v, want %v", got, want)
} else if got, want := a[1].Offset, int64(5); got != want {
t.Fatalf("Offset=%v, want %v", got, want)
} else if got, want := a[1].Size, int64(2); got != want {
t.Fatalf("Size=%v, want %v", got, want)
} else if a[1].CreatedAt.IsZero() {
t.Fatalf("expected CreatedAt")
}
// Verify third WAL segment metadata.
if got, want := a[2].Generation, "b16ddcf5c697540f"; got != want {
t.Fatalf("Generation=%v, want %v", got, want)
} else if got, want := a[2].Index, 3; got != want {
t.Fatalf("Index=%v, want %v", got, want)
} else if got, want := a[2].Offset, int64(0); got != want {
t.Fatalf("Offset=%v, want %v", got, want)
} else if got, want := a[2].Size, int64(3); got != want {
t.Fatalf("Size=%v, want %v", got, want)
} else if a[1].CreatedAt.IsZero() {
t.Fatalf("expected CreatedAt")
}
// Ensure close is clean.
if err := itr.Close(); err != nil {
t.Fatal(err)
}
})
RunWithReplicaClient(t, "NoGenerationDir", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
itr, err := c.WALSegments(context.Background(), "5efbd8d042012dca")
if err != nil {
t.Fatal(err)
}
defer itr.Close()
if itr.Next() {
t.Fatal("expected no wal files")
}
})
RunWithReplicaClient(t, "NoWALs", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
if _, err := c.WriteSnapshot(context.Background(), "5efbd8d042012dca", 0, strings.NewReader(`foo`)); err != nil {
t.Fatal(err)
}
itr, err := c.WALSegments(context.Background(), "5efbd8d042012dca")
if err != nil {
t.Fatal(err)
}
defer itr.Close()
if itr.Next() {
t.Fatal("expected no wal files")
}
})
RunWithReplicaClient(t, "ErrNoGeneration", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
itr, err := c.WALSegments(context.Background(), "")
if err == nil {
err = itr.Close()
}
if err == nil || err.Error() != `generation required` {
t.Fatalf("unexpected error: %v", err)
}
})
}
func TestReplicaClient_WriteWALSegment(t *testing.T) {
RunWithReplicaClient(t, "OK", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
if _, err := c.WriteWALSegment(context.Background(), litestream.Pos{Generation: "b16ddcf5c697540f", Index: 1000, Offset: 2000}, strings.NewReader(`foobar`)); err != nil {
t.Fatal(err)
}
if r, err := c.WALSegmentReader(context.Background(), litestream.Pos{Generation: "b16ddcf5c697540f", Index: 1000, Offset: 2000}); err != nil {
t.Fatal(err)
} else if buf, err := ioutil.ReadAll(r); err != nil {
t.Fatal(err)
} else if err := r.Close(); err != nil {
t.Fatal(err)
} else if got, want := string(buf), `foobar`; got != want {
t.Fatalf("data=%q, want %q", got, want)
}
})
RunWithReplicaClient(t, "ErrNoGeneration", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
if _, err := c.WriteWALSegment(context.Background(), litestream.Pos{Generation: "", Index: 0, Offset: 0}, nil); err == nil || err.Error() != `generation required` {
t.Fatalf("unexpected error: %v", err)
}
})
}
func TestReplicaClient_WALSegmentReader(t *testing.T) {
RunWithReplicaClient(t, "OK", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
if _, err := c.WriteWALSegment(context.Background(), litestream.Pos{Generation: "5efbd8d042012dca", Index: 10, Offset: 5}, strings.NewReader(`foobar`)); err != nil {
t.Fatal(err)
}
r, err := c.WALSegmentReader(context.Background(), litestream.Pos{Generation: "5efbd8d042012dca", Index: 10, Offset: 5})
if err != nil {
t.Fatal(err)
}
defer r.Close()
if buf, err := ioutil.ReadAll(r); err != nil {
t.Fatal(err)
} else if got, want := string(buf), "foobar"; got != want {
t.Fatalf("ReadAll=%v, want %v", got, want)
}
})
RunWithReplicaClient(t, "ErrNotFound", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
if _, err := c.WALSegmentReader(context.Background(), litestream.Pos{Generation: "5efbd8d042012dca", Index: 1, Offset: 0}); !os.IsNotExist(err) {
t.Fatalf("expected not exist, got %#v", err)
}
})
}
func TestReplicaClient_DeleteWALSegments(t *testing.T) {
RunWithReplicaClient(t, "OK", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
if _, err := c.WriteWALSegment(context.Background(), litestream.Pos{Generation: "b16ddcf5c697540f", Index: 1, Offset: 2}, strings.NewReader(`foo`)); err != nil {
t.Fatal(err)
} else if _, err := c.WriteWALSegment(context.Background(), litestream.Pos{Generation: "5efbd8d042012dca", Index: 3, Offset: 4}, strings.NewReader(`bar`)); err != nil {
t.Fatal(err)
}
if err := c.DeleteWALSegments(context.Background(), []litestream.Pos{
{Generation: "b16ddcf5c697540f", Index: 1, Offset: 2},
{Generation: "5efbd8d042012dca", Index: 3, Offset: 4},
}); err != nil {
t.Fatal(err)
}
if _, err := c.WALSegmentReader(context.Background(), litestream.Pos{Generation: "b16ddcf5c697540f", Index: 1, Offset: 2}); !os.IsNotExist(err) {
t.Fatalf("expected not exist, got %#v", err)
} else if _, err := c.WALSegmentReader(context.Background(), litestream.Pos{Generation: "5efbd8d042012dca", Index: 3, Offset: 4}); !os.IsNotExist(err) {
t.Fatalf("expected not exist, got %#v", err)
}
})
RunWithReplicaClient(t, "ErrNoGeneration", func(t *testing.T, c litestream.ReplicaClient) {
t.Parallel()
if err := c.DeleteWALSegments(context.Background(), []litestream.Pos{{}}); err == nil || err.Error() != `generation required` {
t.Fatalf("unexpected error: %v", err)
}
})
}
// RunWithReplicaClient executes fn with each replica specified by the -replica-type flag
func RunWithReplicaClient(t *testing.T, name string, fn func(*testing.T, litestream.ReplicaClient)) {
t.Run(name, func(t *testing.T) {
for _, typ := range strings.Split(*replicaType, ",") {
t.Run(typ, func(t *testing.T) {
c := NewReplicaClient(t, typ)
defer MustDeleteAll(t, c)
fn(t, c)
})
}
})
}
// NewReplicaClient returns a new client for integration testing by type name.
func NewReplicaClient(tb testing.TB, typ string) litestream.ReplicaClient {
tb.Helper()
switch typ {
case litestream.FileReplicaClientType:
return litestream.NewFileReplicaClient(tb.TempDir())
case s3.ReplicaClientType:
return NewS3ReplicaClient(tb)
case gcs.ReplicaClientType:
return NewGCSReplicaClient(tb)
case abs.ReplicaClientType:
return NewABSReplicaClient(tb)
case sftp.ReplicaClientType:
return NewSFTPReplicaClient(tb)
default:
tb.Fatalf("invalid replica client type: %q", typ)
return nil
}
}
// NewS3ReplicaClient returns a new client for integration testing.
func NewS3ReplicaClient(tb testing.TB) *s3.ReplicaClient {
tb.Helper()
c := s3.NewReplicaClient()
c.AccessKeyID = *s3AccessKeyID
c.SecretAccessKey = *s3SecretAccessKey
c.Region = *s3Region
c.Bucket = *s3Bucket
c.Path = path.Join(*s3Path, fmt.Sprintf("%016x", rand.Uint64()))
c.Endpoint = *s3Endpoint
c.ForcePathStyle = *s3ForcePathStyle
c.SkipVerify = *s3SkipVerify
return c
}
// NewGCSReplicaClient returns a new client for integration testing.
func NewGCSReplicaClient(tb testing.TB) *gcs.ReplicaClient {
tb.Helper()
c := gcs.NewReplicaClient()
c.Bucket = *gcsBucket
c.Path = path.Join(*gcsPath, fmt.Sprintf("%016x", rand.Uint64()))
return c
}
// NewABSReplicaClient returns a new client for integration testing.
func NewABSReplicaClient(tb testing.TB) *abs.ReplicaClient {
tb.Helper()
c := abs.NewReplicaClient()
c.AccountName = *absAccountName
c.AccountKey = *absAccountKey
c.Bucket = *absBucket
c.Path = path.Join(*absPath, fmt.Sprintf("%016x", rand.Uint64()))
return c
}
// NewSFTPReplicaClient returns a new client for integration testing.
func NewSFTPReplicaClient(tb testing.TB) *sftp.ReplicaClient {
tb.Helper()
c := sftp.NewReplicaClient()
c.Host = *sftpHost
c.User = *sftpUser
c.Password = *sftpPassword
c.KeyPath = *sftpKeyPath
c.Path = path.Join(*sftpPath, fmt.Sprintf("%016x", rand.Uint64()))
return c
}
// MustDeleteAll deletes all objects under the client's path.
func MustDeleteAll(tb testing.TB, c litestream.ReplicaClient) {
tb.Helper()
generations, err := c.Generations(context.Background())
if err != nil {
tb.Fatalf("cannot list generations for deletion: %s", err)
}
for _, generation := range generations {
if err := c.DeleteGeneration(context.Background(), generation); err != nil {
tb.Fatalf("cannot delete generation: %s", err)
}
}
switch c := c.(type) {
case *sftp.ReplicaClient:
if err := c.Cleanup(context.Background()); err != nil {
tb.Fatalf("cannot cleanup sftp: %s", err)
}
}
}

View File

@@ -94,27 +94,19 @@ func (r *ReadCounter) Read(p []byte) (int, error) {
func (r *ReadCounter) N() int64 { return r.n } func (r *ReadCounter) N() int64 { return r.n }
// CreateFile creates the file and matches the mode & uid/gid of fi. // CreateFile creates the file and matches the mode & uid/gid of fi.
func CreateFile(filename string, fi os.FileInfo) (*os.File, error) { func CreateFile(filename string, mode os.FileMode, uid, gid int) (*os.File, error) {
mode := os.FileMode(0600)
if fi != nil {
mode = fi.Mode()
}
f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode) f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
uid, gid := Fileinfo(fi)
_ = f.Chown(uid, gid) _ = f.Chown(uid, gid)
return f, nil return f, nil
} }
// MkdirAll is a copy of os.MkdirAll() except that it attempts to set the // MkdirAll is a copy of os.MkdirAll() except that it attempts to set the
// mode/uid/gid to match fi for each created directory. // mode/uid/gid to match fi for each created directory.
func MkdirAll(path string, fi os.FileInfo) error { func MkdirAll(path string, mode os.FileMode, uid, gid int) error {
uid, gid := Fileinfo(fi)
// Fast path: if we can tell whether path is a directory or file, stop with success or error. // Fast path: if we can tell whether path is a directory or file, stop with success or error.
dir, err := os.Stat(path) dir, err := os.Stat(path)
if err == nil { if err == nil {
@@ -137,17 +129,13 @@ func MkdirAll(path string, fi os.FileInfo) error {
if j > 1 { if j > 1 {
// Create parent. // Create parent.
err = MkdirAll(fixRootDirectory(path[:j-1]), fi) err = MkdirAll(fixRootDirectory(path[:j-1]), mode, uid, gid)
if err != nil { if err != nil {
return err return err
} }
} }
// Parent now exists; invoke Mkdir and use its result. // Parent now exists; invoke Mkdir and use its result.
mode := os.FileMode(0700)
if fi != nil {
mode = fi.Mode()
}
err = os.Mkdir(path, mode) err = os.Mkdir(path, mode)
if err != nil { if err != nil {
// Handle arguments like "foo/." by // Handle arguments like "foo/." by

View File

@@ -37,6 +37,7 @@ const (
var ( var (
ErrNoGeneration = errors.New("no generation available") ErrNoGeneration = errors.New("no generation available")
ErrNoSnapshots = errors.New("no snapshots available") ErrNoSnapshots = errors.New("no snapshots available")
ErrNoWALSegments = errors.New("no wal segments available")
ErrChecksumMismatch = errors.New("invalid replica, checksum mismatch") ErrChecksumMismatch = errors.New("invalid replica, checksum mismatch")
) )
@@ -440,6 +441,20 @@ func ParseOffset(s string) (int64, error) {
return v, nil return v, nil
} }
// removeDBFiles deletes the database and related files (journal, shm, wal).
func removeDBFiles(filename string) error {
if err := os.Remove(filename); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("cannot delete database %q: %w", filename, err)
} else if err := os.Remove(filename + "-journal"); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("cannot delete journal for %q: %w", filename, err)
} else if err := os.Remove(filename + "-shm"); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("cannot delete shared memory for %q: %w", filename, err)
} else if err := os.Remove(filename + "-wal"); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("cannot delete wal for %q: %w", filename, err)
}
return nil
}
// isHexChar returns true if ch is a lowercase hex character. // isHexChar returns true if ch is a lowercase hex character.
func isHexChar(ch rune) bool { func isHexChar(ch rune) bool {
return (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f') return (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f')

File diff suppressed because one or more lines are too long

14
mock/read_closer.go Normal file
View File

@@ -0,0 +1,14 @@
package mock
type ReadCloser struct {
CloseFunc func() error
ReadFunc func([]byte) (int, error)
}
func (r *ReadCloser) Close() error {
return r.CloseFunc()
}
func (r *ReadCloser) Read(b []byte) (int, error) {
return r.ReadFunc(b)
}

28
mock/snapshot_iterator.go Normal file
View File

@@ -0,0 +1,28 @@
package mock
import (
"github.com/benbjohnson/litestream"
)
type SnapshotIterator struct {
CloseFunc func() error
NextFunc func() bool
ErrFunc func() error
SnapshotFunc func() litestream.SnapshotInfo
}
func (itr *SnapshotIterator) Close() error {
return itr.CloseFunc()
}
func (itr *SnapshotIterator) Next() bool {
return itr.NextFunc()
}
func (itr *SnapshotIterator) Err() error {
return itr.ErrFunc()
}
func (itr *SnapshotIterator) Snapshot() litestream.SnapshotInfo {
return itr.SnapshotFunc()
}

View File

@@ -0,0 +1,28 @@
package mock
import (
"github.com/benbjohnson/litestream"
)
type WALSegmentIterator struct {
CloseFunc func() error
NextFunc func() bool
ErrFunc func() error
WALSegmentFunc func() litestream.WALSegmentInfo
}
func (itr *WALSegmentIterator) Close() error {
return itr.CloseFunc()
}
func (itr *WALSegmentIterator) Next() bool {
return itr.NextFunc()
}
func (itr *WALSegmentIterator) Err() error {
return itr.ErrFunc()
}
func (itr *WALSegmentIterator) WALSegment() litestream.WALSegmentInfo {
return itr.WALSegmentFunc()
}

View File

@@ -7,14 +7,12 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
"math"
"os" "os"
"path/filepath" "path/filepath"
"sort" "sort"
"sync" "sync"
"time" "time"
"github.com/benbjohnson/litestream/internal"
"github.com/pierrec/lz4/v4" "github.com/pierrec/lz4/v4"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
@@ -144,6 +142,15 @@ func (r *Replica) Stop(hard bool) (err error) {
return err return err
} }
// logPrefix returns the prefix used when logging from the replica.
// This includes the replica name as well as the database path, if available.
func (r *Replica) logPrefix() string {
if db := r.DB(); db != nil {
return fmt.Sprintf("%s(%s): ", db.Path(), r.Name())
}
return r.Name() + ": "
}
// Sync copies new WAL frames from the shadow WAL to the replica client. // Sync copies new WAL frames from the shadow WAL to the replica client.
func (r *Replica) Sync(ctx context.Context) (err error) { func (r *Replica) Sync(ctx context.Context) (err error) {
// Clear last position if if an error occurs during sync. // Clear last position if if an error occurs during sync.
@@ -766,14 +773,18 @@ func (r *Replica) Validate(ctx context.Context) error {
return fmt.Errorf("cannot wait for replica: %w", err) return fmt.Errorf("cannot wait for replica: %w", err)
} }
// Find lastest snapshot that occurs before the index.
snapshotIndex, err := FindSnapshotForIndex(ctx, r.Client, pos.Generation, pos.Index-1)
if err != nil {
return fmt.Errorf("cannot find snapshot index: %w", err)
}
restorePath := filepath.Join(tmpdir, "replica") restorePath := filepath.Join(tmpdir, "replica")
if err := r.Restore(ctx, RestoreOptions{ opt := RestoreOptions{
OutputPath: restorePath,
ReplicaName: r.Name(),
Generation: pos.Generation,
Index: pos.Index - 1,
Logger: log.New(os.Stderr, "", 0), Logger: log.New(os.Stderr, "", 0),
}); err != nil { LogPrefix: r.logPrefix(),
}
if err := Restore(ctx, r.Client, restorePath, pos.Generation, snapshotIndex, pos.Index-1, opt); err != nil {
return fmt.Errorf("cannot restore: %w", err) return fmt.Errorf("cannot restore: %w", err)
} }
@@ -883,295 +894,6 @@ func (r *Replica) GenerationCreatedAt(ctx context.Context, generation string) (t
return min, itr.Close() return min, itr.Close()
} }
// GenerationTimeBounds returns the creation time & last updated time of a generation.
// Returns zero time if no snapshots or WAL segments exist.
func (r *Replica) GenerationTimeBounds(ctx context.Context, generation string) (createdAt, updatedAt time.Time, err error) {
// Iterate over snapshots.
sitr, err := r.Client.Snapshots(ctx, generation)
if err != nil {
return createdAt, updatedAt, err
}
defer sitr.Close()
for sitr.Next() {
info := sitr.Snapshot()
if createdAt.IsZero() || info.CreatedAt.Before(createdAt) {
createdAt = info.CreatedAt
}
if updatedAt.IsZero() || info.CreatedAt.After(updatedAt) {
updatedAt = info.CreatedAt
}
}
if err := sitr.Close(); err != nil {
return createdAt, updatedAt, err
}
// Iterate over WAL segments.
witr, err := r.Client.WALSegments(ctx, generation)
if err != nil {
return createdAt, updatedAt, err
}
defer witr.Close()
for witr.Next() {
info := witr.WALSegment()
if createdAt.IsZero() || info.CreatedAt.Before(createdAt) {
createdAt = info.CreatedAt
}
if updatedAt.IsZero() || info.CreatedAt.After(updatedAt) {
updatedAt = info.CreatedAt
}
}
if err := witr.Close(); err != nil {
return createdAt, updatedAt, err
}
return createdAt, updatedAt, nil
}
// CalcRestoreTarget returns a generation to restore from.
func (r *Replica) CalcRestoreTarget(ctx context.Context, opt RestoreOptions) (generation string, updatedAt time.Time, err error) {
var target struct {
generation string
updatedAt time.Time
}
generations, err := r.Client.Generations(ctx)
if err != nil {
return "", time.Time{}, fmt.Errorf("cannot fetch generations: %w", err)
}
// Search generations for one that contains the requested timestamp.
for _, generation := range generations {
// Skip generation if it does not match filter.
if opt.Generation != "" && generation != opt.Generation {
continue
}
// Determine the time bounds for the generation.
createdAt, updatedAt, err := r.GenerationTimeBounds(ctx, generation)
if err != nil {
return "", time.Time{}, fmt.Errorf("generation created at: %w", err)
}
// Skip if it does not contain timestamp.
if !opt.Timestamp.IsZero() {
if opt.Timestamp.Before(createdAt) || opt.Timestamp.After(updatedAt) {
continue
}
}
// Use the latest replica if we have multiple candidates.
if !updatedAt.After(target.updatedAt) {
continue
}
target.generation = generation
target.updatedAt = updatedAt
}
return target.generation, target.updatedAt, nil
}
// Replica restores the database from a replica based on the options given.
// This method will restore into opt.OutputPath, if specified, or into the
// DB's original database path. It can optionally restore from a specific
// replica or generation or it will automatically choose the best one. Finally,
// a timestamp can be specified to restore the database to a specific
// point-in-time.
func (r *Replica) Restore(ctx context.Context, opt RestoreOptions) (err error) {
// Validate options.
if opt.OutputPath == "" {
if r.db.path == "" {
return fmt.Errorf("output path required")
}
opt.OutputPath = r.db.path
} else if opt.Generation == "" && opt.Index != math.MaxInt32 {
return fmt.Errorf("must specify generation when restoring to index")
} else if opt.Index != math.MaxInt32 && !opt.Timestamp.IsZero() {
return fmt.Errorf("cannot specify index & timestamp to restore")
}
// Ensure logger exists.
logger := opt.Logger
if logger == nil {
logger = log.New(ioutil.Discard, "", 0)
}
logPrefix := r.Name()
if db := r.DB(); db != nil {
logPrefix = fmt.Sprintf("%s(%s)", db.Path(), r.Name())
}
// Ensure output path does not already exist.
if _, err := os.Stat(opt.OutputPath); err == nil {
return fmt.Errorf("cannot restore, output path already exists: %s", opt.OutputPath)
} else if err != nil && !os.IsNotExist(err) {
return err
}
// Find lastest snapshot that occurs before timestamp or index.
var minWALIndex int
if opt.Index < math.MaxInt32 {
if minWALIndex, err = r.SnapshotIndexByIndex(ctx, opt.Generation, opt.Index); err != nil {
return fmt.Errorf("cannot find snapshot index: %w", err)
}
} else {
if minWALIndex, err = r.SnapshotIndexAt(ctx, opt.Generation, opt.Timestamp); err != nil {
return fmt.Errorf("cannot find snapshot index by timestamp: %w", err)
}
}
// Compute list of offsets for each WAL index.
walSegmentMap, err := r.walSegmentMap(ctx, opt.Generation, opt.Index, opt.Timestamp)
if err != nil {
return fmt.Errorf("cannot find max wal index for restore: %w", err)
}
// Find the maximum WAL index that occurs before timestamp.
maxWALIndex := -1
for index := range walSegmentMap {
if index > maxWALIndex {
maxWALIndex = index
}
}
// Ensure that we found the specific index, if one was specified.
if opt.Index != math.MaxInt32 && opt.Index != opt.Index {
return fmt.Errorf("unable to locate index %d in generation %q, highest index was %d", opt.Index, opt.Generation, maxWALIndex)
}
// If no WAL files were found, mark this as a snapshot-only restore.
snapshotOnly := maxWALIndex == -1
// Initialize starting position.
pos := Pos{Generation: opt.Generation, Index: minWALIndex}
tmpPath := opt.OutputPath + ".tmp"
// Copy snapshot to output path.
logger.Printf("%s: restoring snapshot %s/%08x to %s", logPrefix, opt.Generation, minWALIndex, tmpPath)
if err := r.restoreSnapshot(ctx, pos.Generation, pos.Index, tmpPath); err != nil {
return fmt.Errorf("cannot restore snapshot: %w", err)
}
// If no WAL files available, move snapshot to final path & exit early.
if snapshotOnly {
logger.Printf("%s: snapshot only, finalizing database", logPrefix)
return os.Rename(tmpPath, opt.OutputPath)
}
// Begin processing WAL files.
logger.Printf("%s: restoring wal files: generation=%s index=[%08x,%08x]", logPrefix, opt.Generation, minWALIndex, maxWALIndex)
// Fill input channel with all WAL indexes to be loaded in order.
// Verify every index has at least one offset.
ch := make(chan int, maxWALIndex-minWALIndex+1)
for index := minWALIndex; index <= maxWALIndex; index++ {
if len(walSegmentMap[index]) == 0 {
return fmt.Errorf("missing WAL index: %s/%08x", opt.Generation, index)
}
ch <- index
}
close(ch)
// Track load state for each WAL.
var mu sync.Mutex
cond := sync.NewCond(&mu)
walStates := make([]walRestoreState, maxWALIndex-minWALIndex+1)
parallelism := opt.Parallelism
if parallelism < 1 {
parallelism = 1
}
// Download WAL files to disk in parallel.
g, ctx := errgroup.WithContext(ctx)
for i := 0; i < parallelism; i++ {
g.Go(func() error {
for {
select {
case <-ctx.Done():
cond.Broadcast()
return err
case index, ok := <-ch:
if !ok {
cond.Broadcast()
return nil
}
startTime := time.Now()
err := r.downloadWAL(ctx, opt.Generation, index, walSegmentMap[index], tmpPath)
if err != nil {
err = fmt.Errorf("cannot download wal %s/%08x: %w", opt.Generation, index, err)
}
// Mark index as ready-to-apply and notify applying code.
mu.Lock()
walStates[index-minWALIndex] = walRestoreState{ready: true, err: err}
mu.Unlock()
cond.Broadcast()
// Returning the error here will cancel the other goroutines.
if err != nil {
return err
}
logger.Printf("%s: downloaded wal %s/%08x elapsed=%s",
logPrefix, opt.Generation, index,
time.Since(startTime).String(),
)
}
}
})
}
// Apply WAL files in order as they are ready.
for index := minWALIndex; index <= maxWALIndex; index++ {
// Wait until next WAL file is ready to apply.
mu.Lock()
for !walStates[index-minWALIndex].ready {
if err := ctx.Err(); err != nil {
return err
}
cond.Wait()
}
if err := walStates[index-minWALIndex].err; err != nil {
return err
}
mu.Unlock()
// Apply WAL to database file.
startTime := time.Now()
if err = applyWAL(ctx, index, tmpPath); err != nil {
return fmt.Errorf("cannot apply wal: %w", err)
}
logger.Printf("%s: applied wal %s/%08x elapsed=%s",
logPrefix, opt.Generation, index,
time.Since(startTime).String(),
)
}
// Ensure all goroutines finish. All errors should have been handled during
// the processing of WAL files but this ensures that all processing is done.
if err := g.Wait(); err != nil {
return err
}
// Copy file to final location.
logger.Printf("%s: renaming database from temporary location", logPrefix)
if err := os.Rename(tmpPath, opt.OutputPath); err != nil {
return err
}
return nil
}
type walRestoreState struct {
ready bool
err error
}
// SnapshotIndexAt returns the highest index for a snapshot within a generation // SnapshotIndexAt returns the highest index for a snapshot within a generation
// that occurs before timestamp. If timestamp is zero, returns the latest snapshot. // that occurs before timestamp. If timestamp is zero, returns the latest snapshot.
func (r *Replica) SnapshotIndexAt(ctx context.Context, generation string, timestamp time.Time) (int, error) { func (r *Replica) SnapshotIndexAt(ctx context.Context, generation string, timestamp time.Time) (int, error) {
@@ -1202,137 +924,19 @@ func (r *Replica) SnapshotIndexAt(ctx context.Context, generation string, timest
return snapshotIndex, nil return snapshotIndex, nil
} }
// SnapshotIndexbyIndex returns the highest index for a snapshot within a generation // LatestReplica returns the most recently updated replica.
// that occurs before a given index. If index is MaxInt32, returns the latest snapshot. func LatestReplica(ctx context.Context, replicas []*Replica) (*Replica, error) {
func (r *Replica) SnapshotIndexByIndex(ctx context.Context, generation string, index int) (int, error) { var t time.Time
itr, err := r.Client.Snapshots(ctx, generation) var r *Replica
if err != nil { for i := range replicas {
return 0, err _, max, err := ReplicaClientTimeBounds(ctx, replicas[i].Client)
}
defer itr.Close()
snapshotIndex := -1
for itr.Next() {
snapshot := itr.Snapshot()
if index < math.MaxInt32 && snapshot.Index > index {
continue // after index, skip
}
// Use snapshot if it newer.
if snapshotIndex == -1 || snapshotIndex >= snapshotIndex {
snapshotIndex = snapshot.Index
}
}
if err := itr.Close(); err != nil {
return 0, err
} else if snapshotIndex == -1 {
return 0, ErrNoSnapshots
}
return snapshotIndex, nil
}
// walSegmentMap returns a map of WAL indices to their segments.
// Filters by a max timestamp or a max index.
func (r *Replica) walSegmentMap(ctx context.Context, generation string, maxIndex int, maxTimestamp time.Time) (map[int][]int64, error) {
itr, err := r.Client.WALSegments(ctx, generation)
if err != nil { if err != nil {
return nil, err return nil, err
} else if r == nil || max.After(t) {
r, t = replicas[i], max
} }
defer itr.Close()
m := make(map[int][]int64)
for itr.Next() {
info := itr.WALSegment()
// Exit if we go past the max timestamp or index.
if !maxTimestamp.IsZero() && info.CreatedAt.After(maxTimestamp) {
break // after max timestamp, skip
} else if info.Index > maxIndex {
break // after max index, skip
} }
return r, nil
// Verify offsets are added in order.
offsets := m[info.Index]
if len(offsets) == 0 && info.Offset != 0 {
return nil, fmt.Errorf("missing initial wal segment: generation=%s index=%08x offset=%d", generation, info.Index, info.Offset)
} else if len(offsets) > 0 && offsets[len(offsets)-1] >= info.Offset {
return nil, fmt.Errorf("wal segments out of order: generation=%s index=%08x offsets=(%d,%d)", generation, info.Index, offsets[len(offsets)-1], info.Offset)
}
// Append to the end of the WAL file.
m[info.Index] = append(offsets, info.Offset)
}
return m, itr.Close()
}
// restoreSnapshot copies a snapshot from the replica to a file.
func (r *Replica) restoreSnapshot(ctx context.Context, generation string, index int, filename string) error {
// Determine the user/group & mode based on the DB, if available.
var fileInfo, dirInfo os.FileInfo
if db := r.DB(); db != nil {
fileInfo, dirInfo = db.fileInfo, db.dirInfo
}
if err := internal.MkdirAll(filepath.Dir(filename), dirInfo); err != nil {
return err
}
f, err := internal.CreateFile(filename, fileInfo)
if err != nil {
return err
}
defer f.Close()
rd, err := r.Client.SnapshotReader(ctx, generation, index)
if err != nil {
return err
}
defer rd.Close()
if _, err := io.Copy(f, lz4.NewReader(rd)); err != nil {
return err
} else if err := f.Sync(); err != nil {
return err
}
return f.Close()
}
// downloadWAL copies a WAL file from the replica to a local copy next to the DB.
// The WAL is later applied by applyWAL(). This function can be run in parallel
// to download multiple WAL files simultaneously.
func (r *Replica) downloadWAL(ctx context.Context, generation string, index int, offsets []int64, dbPath string) (err error) {
// Determine the user/group & mode based on the DB, if available.
var fileInfo os.FileInfo
if db := r.DB(); db != nil {
fileInfo = db.fileInfo
}
// Open readers for every segment in the WAL file, in order.
var readers []io.Reader
for _, offset := range offsets {
rd, err := r.Client.WALSegmentReader(ctx, Pos{Generation: generation, Index: index, Offset: offset})
if err != nil {
return err
}
defer rd.Close()
readers = append(readers, lz4.NewReader(rd))
}
// Open handle to destination WAL path.
f, err := internal.CreateFile(fmt.Sprintf("%s-%08x-wal", dbPath, index), fileInfo)
if err != nil {
return err
}
defer f.Close()
// Combine segments together and copy WAL to target path.
if _, err := io.Copy(f, io.MultiReader(readers...)); err != nil {
return err
} else if err := f.Close(); err != nil {
return err
}
return nil
} }
// Replica metrics. // Replica metrics.

View File

@@ -2,9 +2,19 @@ package litestream
import ( import (
"context" "context"
"fmt"
"io" "io"
"log"
"os"
"time"
"github.com/benbjohnson/litestream/internal"
"github.com/pierrec/lz4/v4"
) )
// DefaultRestoreParallelism is the default parallelism when downloading WAL files.
const DefaultRestoreParallelism = 8
// ReplicaClient represents client to connect to a Replica. // ReplicaClient represents client to connect to a Replica.
type ReplicaClient interface { type ReplicaClient interface {
// Returns the type of client. // Returns the type of client.
@@ -46,3 +56,382 @@ type ReplicaClient interface {
// WAL segment does not exist. // WAL segment does not exist.
WALSegmentReader(ctx context.Context, pos Pos) (io.ReadCloser, error) WALSegmentReader(ctx context.Context, pos Pos) (io.ReadCloser, error)
} }
// FindSnapshotForIndex returns the highest index for a snapshot within a
// generation that occurs before a given index.
func FindSnapshotForIndex(ctx context.Context, client ReplicaClient, generation string, index int) (int, error) {
itr, err := client.Snapshots(ctx, generation)
if err != nil {
return 0, fmt.Errorf("snapshots: %w", err)
}
defer itr.Close()
// Iterate over all snapshots to find the closest to our given index.
snapshotIndex := -1
var n int
for ; itr.Next(); n++ {
info := itr.Snapshot()
if info.Index > index {
continue // after given index, skip
}
// Use snapshot if it's more recent.
if info.Index >= snapshotIndex {
snapshotIndex = info.Index
}
}
if err := itr.Close(); err != nil {
return 0, fmt.Errorf("snapshot iteration: %w", err)
}
// Ensure we find at least one snapshot and that it's before the given index.
if n == 0 {
return 0, ErrNoSnapshots
} else if snapshotIndex == -1 {
return 0, fmt.Errorf("no snapshots available at or before index %08x", index)
}
return snapshotIndex, nil
}
// GenerationTimeBounds returns the creation time & last updated time of a generation.
// Returns ErrNoSnapshots if no data exists for the generation.
func GenerationTimeBounds(ctx context.Context, client ReplicaClient, generation string) (createdAt, updatedAt time.Time, err error) {
// Determine bounds for snapshots only first.
// This will return ErrNoSnapshots if no snapshots exist.
if createdAt, updatedAt, err = SnapshotTimeBounds(ctx, client, generation); err != nil {
return createdAt, updatedAt, err
}
// Update ending time bounds if WAL segments exist after the last snapshot.
_, max, err := WALTimeBounds(ctx, client, generation)
if err != nil && err != ErrNoWALSegments {
return createdAt, updatedAt, err
} else if max.After(updatedAt) {
updatedAt = max
}
return createdAt, updatedAt, nil
}
// SnapshotTimeBounds returns the minimum and maximum snapshot timestamps within a generation.
// Returns ErrNoSnapshots if no data exists for the generation.
func SnapshotTimeBounds(ctx context.Context, client ReplicaClient, generation string) (min, max time.Time, err error) {
itr, err := client.Snapshots(ctx, generation)
if err != nil {
return min, max, fmt.Errorf("snapshots: %w", err)
}
defer itr.Close()
// Iterate over all snapshots to find the oldest and newest.
var n int
for ; itr.Next(); n++ {
info := itr.Snapshot()
if min.IsZero() || info.CreatedAt.Before(min) {
min = info.CreatedAt
}
if max.IsZero() || info.CreatedAt.After(max) {
max = info.CreatedAt
}
}
if err := itr.Close(); err != nil {
return min, max, fmt.Errorf("snapshot iteration: %w", err)
}
// Return error if no snapshots exist.
if n == 0 {
return min, max, ErrNoSnapshots
}
return min, max, nil
}
// WALTimeBounds returns the minimum and maximum snapshot timestamps.
// Returns ErrNoWALSegments if no data exists for the generation.
func WALTimeBounds(ctx context.Context, client ReplicaClient, generation string) (min, max time.Time, err error) {
itr, err := client.WALSegments(ctx, generation)
if err != nil {
return min, max, fmt.Errorf("wal segments: %w", err)
}
defer itr.Close()
// Iterate over all WAL segments to find oldest and newest.
var n int
for ; itr.Next(); n++ {
info := itr.WALSegment()
if min.IsZero() || info.CreatedAt.Before(min) {
min = info.CreatedAt
}
if max.IsZero() || info.CreatedAt.After(max) {
max = info.CreatedAt
}
}
if err := itr.Close(); err != nil {
return min, max, fmt.Errorf("wal segment iteration: %w", err)
}
if n == 0 {
return min, max, ErrNoWALSegments
}
return min, max, nil
}
// FindLatestGeneration returns the most recent generation for a client.
func FindLatestGeneration(ctx context.Context, client ReplicaClient) (generation string, err error) {
generations, err := client.Generations(ctx)
if err != nil {
return "", fmt.Errorf("generations: %w", err)
}
// Search generations for one latest updated.
var maxTime time.Time
for i := range generations {
// Determine the latest update for the generation.
_, updatedAt, err := GenerationTimeBounds(ctx, client, generations[i])
if err != nil {
return "", fmt.Errorf("generation time bounds: %w", err)
}
// Use the latest replica if we have multiple candidates.
if updatedAt.After(maxTime) {
maxTime = updatedAt
generation = generations[i]
}
}
if generation == "" {
return "", ErrNoGeneration
}
return generation, nil
}
// ReplicaClientTimeBounds returns time range covered by a replica client
// across all generations. It scans the time range of all generations and
// computes the lower and upper bounds of them.
func ReplicaClientTimeBounds(ctx context.Context, client ReplicaClient) (min, max time.Time, err error) {
generations, err := client.Generations(ctx)
if err != nil {
return min, max, fmt.Errorf("generations: %w", err)
} else if len(generations) == 0 {
return min, max, ErrNoGeneration
}
// Iterate over generations to determine outer bounds.
for i := range generations {
// Determine the time range for the generation.
createdAt, updatedAt, err := GenerationTimeBounds(ctx, client, generations[i])
if err != nil {
return min, max, fmt.Errorf("generation time bounds: %w", err)
}
// Update time bounds.
if min.IsZero() || createdAt.Before(min) {
min = createdAt
}
if max.IsZero() || updatedAt.After(max) {
max = updatedAt
}
}
return min, max, nil
}
// FindMaxIndexByGeneration returns the last index within a generation.
// Returns ErrNoSnapshots if no index exists on the replica for the generation.
func FindMaxIndexByGeneration(ctx context.Context, client ReplicaClient, generation string) (index int, err error) {
// Determine the highest available snapshot index. Returns an error if no
// snapshot are available as WALs are not useful without snapshots.
snapshotIndex, err := FindMaxSnapshotIndexByGeneration(ctx, client, generation)
if err == ErrNoSnapshots {
return index, err
} else if err != nil {
return index, fmt.Errorf("max snapshot index: %w", err)
}
// Determine the highest available WAL index.
walIndex, err := FindMaxWALIndexByGeneration(ctx, client, generation)
if err != nil && err != ErrNoWALSegments {
return index, fmt.Errorf("max wal index: %w", err)
}
// Use snapshot index if it's after the last WAL index.
if snapshotIndex > walIndex {
return snapshotIndex, nil
}
return walIndex, nil
}
// FindMaxSnapshotIndexByGeneration returns the last snapshot index within a generation.
// Returns ErrNoSnapshots if no snapshots exist for the generation on the replica.
func FindMaxSnapshotIndexByGeneration(ctx context.Context, client ReplicaClient, generation string) (index int, err error) {
itr, err := client.Snapshots(ctx, generation)
if err != nil {
return 0, fmt.Errorf("snapshots: %w", err)
}
defer func() { _ = itr.Close() }()
// Iterate over snapshots to find the highest index.
var n int
for ; itr.Next(); n++ {
if info := itr.Snapshot(); info.Index > index {
index = info.Index
}
}
if err := itr.Close(); err != nil {
return 0, fmt.Errorf("snapshot iteration: %w", err)
}
// Return an error if no snapshots were found.
if n == 0 {
return 0, ErrNoSnapshots
}
return index, nil
}
// FindMaxWALIndexByGeneration returns the last WAL index within a generation.
// Returns ErrNoWALSegments if no segments exist for the generation on the replica.
func FindMaxWALIndexByGeneration(ctx context.Context, client ReplicaClient, generation string) (index int, err error) {
itr, err := client.WALSegments(ctx, generation)
if err != nil {
return 0, fmt.Errorf("wal segments: %w", err)
}
defer func() { _ = itr.Close() }()
// Iterate over WAL segments to find the highest index.
var n int
for ; itr.Next(); n++ {
if info := itr.WALSegment(); info.Index > index {
index = info.Index
}
}
if err := itr.Close(); err != nil {
return 0, fmt.Errorf("wal segment iteration: %w", err)
}
// Return an error if no WAL segments were found.
if n == 0 {
return 0, ErrNoWALSegments
}
return index, nil
}
// Restore restores the database to the given index on a generation.
func Restore(ctx context.Context, client ReplicaClient, filename, generation string, snapshotIndex, targetIndex int, opt RestoreOptions) (err error) {
// Validate options.
if filename == "" {
return fmt.Errorf("restore path required")
} else if generation == "" {
return fmt.Errorf("generation required")
} else if snapshotIndex < 0 {
return fmt.Errorf("snapshot index required")
} else if targetIndex < 0 {
return fmt.Errorf("target index required")
}
// Require a default level of parallelism.
if opt.Parallelism < 1 {
opt.Parallelism = DefaultRestoreParallelism
}
// Ensure logger exists.
logger := opt.Logger
if logger == nil {
logger = log.New(io.Discard, "", 0)
}
// Ensure output path does not already exist.
// If doesn't exist, also remove the journal, shm, & wal if left behind.
if _, err := os.Stat(filename); err == nil {
return fmt.Errorf("cannot restore, output path already exists: %s", filename)
} else if err != nil && !os.IsNotExist(err) {
return err
} else if err := removeDBFiles(filename); err != nil {
return err
}
// Copy snapshot to output path.
tmpPath := filename + ".tmp"
logger.Printf("%srestoring snapshot %s/%08x to %s", opt.LogPrefix, generation, snapshotIndex, tmpPath)
if err := RestoreSnapshot(ctx, client, tmpPath, generation, snapshotIndex, opt.Mode, opt.Uid, opt.Gid); err != nil {
return fmt.Errorf("cannot restore snapshot: %w", err)
}
// Download & apply all WAL files between the snapshot & the target index.
d := NewWALDownloader(client, tmpPath, generation, snapshotIndex, targetIndex)
d.Parallelism = opt.Parallelism
d.Mode = opt.Mode
d.Uid, d.Gid = opt.Uid, opt.Gid
for {
// Read next WAL file from downloader.
walIndex, walPath, err := d.Next(ctx)
if err == io.EOF {
break
}
// If we are only reading a single index, a WAL file may not be found.
if _, ok := err.(*WALNotFoundError); ok && snapshotIndex == targetIndex {
logger.Printf("%sno wal files found, snapshot only", opt.LogPrefix)
break
} else if err != nil {
return fmt.Errorf("cannot download WAL: %w", err)
}
// Apply WAL file.
startTime := time.Now()
if err = ApplyWAL(ctx, tmpPath, walPath); err != nil {
return fmt.Errorf("cannot apply wal: %w", err)
}
logger.Printf("%sapplied wal %s/%08x elapsed=%s", opt.LogPrefix, generation, walIndex, time.Since(startTime).String())
}
// Copy file to final location.
logger.Printf("%srenaming database from temporary location", opt.LogPrefix)
if err := os.Rename(tmpPath, filename); err != nil {
return err
}
return nil
}
// RestoreOptions represents options for DB.Restore().
type RestoreOptions struct {
// File info used for restored snapshot & WAL files.
Mode os.FileMode
Uid, Gid int
// Specifies how many WAL files are downloaded in parallel during restore.
Parallelism int
// Logging settings.
Logger *log.Logger
LogPrefix string
}
// NewRestoreOptions returns a new instance of RestoreOptions with defaults.
func NewRestoreOptions() RestoreOptions {
return RestoreOptions{
Mode: 0600,
Parallelism: DefaultRestoreParallelism,
}
}
// RestoreSnapshot copies a snapshot from the replica client to a file.
func RestoreSnapshot(ctx context.Context, client ReplicaClient, filename, generation string, index int, mode os.FileMode, uid, gid int) error {
f, err := internal.CreateFile(filename, mode, uid, gid)
if err != nil {
return err
}
defer f.Close()
rd, err := client.SnapshotReader(ctx, generation, index)
if err != nil {
return err
}
defer rd.Close()
if _, err := io.Copy(f, lz4.NewReader(rd)); err != nil {
return err
} else if err := f.Sync(); err != nil {
return err
}
return f.Close()
}

File diff suppressed because it is too large Load Diff

View File

@@ -8,7 +8,6 @@ import (
"testing" "testing"
"github.com/benbjohnson/litestream" "github.com/benbjohnson/litestream"
"github.com/benbjohnson/litestream/file"
"github.com/benbjohnson/litestream/mock" "github.com/benbjohnson/litestream/mock"
"github.com/pierrec/lz4/v4" "github.com/pierrec/lz4/v4"
) )
@@ -45,9 +44,9 @@ func TestReplica_Sync(t *testing.T) {
// Fetch current database position. // Fetch current database position.
dpos := db.Pos() dpos := db.Pos()
c := file.NewReplicaClient(t.TempDir()) c := litestream.NewFileReplicaClient(t.TempDir())
r := litestream.NewReplica(db, "") r := litestream.NewReplica(db, "")
c.Replica, r.Client = r, c r.Client = c
if err := r.Sync(context.Background()); err != nil { if err := r.Sync(context.Background()); err != nil {
t.Fatal(err) t.Fatal(err)
@@ -81,7 +80,7 @@ func TestReplica_Snapshot(t *testing.T) {
db, sqldb := MustOpenDBs(t) db, sqldb := MustOpenDBs(t)
defer MustCloseDBs(t, db, sqldb) defer MustCloseDBs(t, db, sqldb)
c := file.NewReplicaClient(t.TempDir()) c := litestream.NewFileReplicaClient(t.TempDir())
r := litestream.NewReplica(db, "") r := litestream.NewReplica(db, "")
r.Client = c r.Client = c

8
testdata/Makefile vendored Normal file
View File

@@ -0,0 +1,8 @@
.PHONY: default
default:
make -C find-latest-generation/ok
make -C generation-time-bounds/ok
make -C generation-time-bounds/snapshots-only
make -C replica-client-time-bounds/ok
make -C snapshot-time-bounds/ok
make -C wal-time-bounds/ok

View File

@@ -0,0 +1,7 @@
.PHONY: default
default:
TZ=UTC touch -t 200001010000 generations/0000000000000000/snapshots/00000000.snapshot.lz4
TZ=UTC touch -t 200001020000 generations/0000000000000000/snapshots/00000001.snapshot.lz4
TZ=UTC touch -t 200001010000 generations/0000000000000001/snapshots/00000000.snapshot.lz4
TZ=UTC touch -t 200001030000 generations/0000000000000001/snapshots/00000001.snapshot.lz4
TZ=UTC touch -t 200001010000 generations/0000000000000002/snapshots/00000000.snapshot.lz4

View File

@@ -0,0 +1,8 @@
.PHONY: default
default:
TZ=UTC touch -t 200001010000 generations/0000000000000000/snapshots/00000000.snapshot.lz4
TZ=UTC touch -t 200001020000 generations/0000000000000000/snapshots/00000001.snapshot.lz4
TZ=UTC touch -t 200001010000 generations/0000000000000000/wal/00000000/00000000.wal.lz4
TZ=UTC touch -t 200001020000 generations/0000000000000000/wal/00000000/00000001.wal.lz4
TZ=UTC touch -t 200001030000 generations/0000000000000000/wal/00000001/00000000.wal.lz4

View File

@@ -0,0 +1,5 @@
.PHONY: default
default:
TZ=UTC touch -t 200001010000 generations/0000000000000000/snapshots/00000000.snapshot.lz4
TZ=UTC touch -t 200001020000 generations/0000000000000000/snapshots/00000001.snapshot.lz4

View File

@@ -0,0 +1,6 @@
.PHONY: default
default:
TZ=UTC touch -t 200001020000 generations/0000000000000000/snapshots/00000000.snapshot.lz4
TZ=UTC touch -t 200001010000 generations/0000000000000001/snapshots/00000000.snapshot.lz4
TZ=UTC touch -t 200001030000 generations/0000000000000001/snapshots/00000001.snapshot.lz4
TZ=UTC touch -t 200001010000 generations/0000000000000002/snapshots/00000000.snapshot.lz4

View File

@@ -0,0 +1,6 @@
.PHONY: default
default:
TZ=UTC touch -t 200001020000 generations/0000000000000000/snapshots/00000000.snapshot.lz4
TZ=UTC touch -t 200001010000 generations/0000000000000001/snapshots/00000000.snapshot.lz4
TZ=UTC touch -t 200001030000 generations/0000000000000001/snapshots/00000001.snapshot.lz4
TZ=UTC touch -t 200001010000 generations/0000000000000002/snapshots/00000000.snapshot.lz4

Binary file not shown.

36
testdata/restore/bad-permissions/README vendored Normal file
View File

@@ -0,0 +1,36 @@
To reproduce this testdata, run sqlite3 and execute:
PRAGMA journal_mode = WAL;
CREATE TABLE t (x);
INSERT INTO t (x) VALUES (1);
INSERT INTO t (x) VALUES (2);
sl3 split -o generations/0000000000000000/wal/00000000 db-wal
cp db generations/0000000000000000/snapshots/00000000.snapshot
lz4 -c --rm generations/0000000000000000/snapshots/00000000.snapshot
Then execute:
PRAGMA wal_checkpoint(TRUNCATE);
INSERT INTO t (x) VALUES (3);
sl3 split -o generations/0000000000000000/wal/00000001 db-wal
Then execute:
PRAGMA wal_checkpoint(TRUNCATE);
INSERT INTO t (x) VALUES (4);
INSERT INTO t (x) VALUES (5);
sl3 split -o generations/0000000000000000/wal/00000002 db-wal
Finally, obtain the final snapshot:
PRAGMA wal_checkpoint(TRUNCATE);
cp db 00000002.db
rm db*

BIN
testdata/restore/ok/00000002.db vendored Normal file

Binary file not shown.

36
testdata/restore/ok/README vendored Normal file
View File

@@ -0,0 +1,36 @@
To reproduce this testdata, run sqlite3 and execute:
PRAGMA journal_mode = WAL;
CREATE TABLE t (x);
INSERT INTO t (x) VALUES (1);
INSERT INTO t (x) VALUES (2);
sl3 split -o generations/0000000000000000/wal/00000000 db-wal
cp db generations/0000000000000000/snapshots/00000000.snapshot
lz4 -c --rm generations/0000000000000000/snapshots/00000000.snapshot
Then execute:
PRAGMA wal_checkpoint(TRUNCATE);
INSERT INTO t (x) VALUES (3);
sl3 split -o generations/0000000000000000/wal/00000001 db-wal
Then execute:
PRAGMA wal_checkpoint(TRUNCATE);
INSERT INTO t (x) VALUES (4);
INSERT INTO t (x) VALUES (5);
sl3 split -o generations/0000000000000000/wal/00000002 db-wal
Finally, obtain the final snapshot:
PRAGMA wal_checkpoint(TRUNCATE);
cp db 00000002.db
rm db*

Binary file not shown.

View File

@@ -0,0 +1,6 @@
.PHONY: default
default:
TZ=UTC touch -t 200001010000 generations/0000000000000000/snapshots/00000000.snapshot.lz4
TZ=UTC touch -t 200001020000 generations/0000000000000000/snapshots/00000001.snapshot.lz4
TZ=UTC touch -t 200001030000 generations/0000000000000000/snapshots/00000002.snapshot.lz4

Some files were not shown because too many files have changed in this diff Show More