From 57a02a8628f333e74d5670eec718ac89f0ea222c Mon Sep 17 00:00:00 2001
From: Ben Johnson
Date: Wed, 13 Jan 2021 10:14:54 -0700
Subject: [PATCH] S3 replica

---
 cmd/litestream/main.go |  36 +-
 db.go                  |  13 +-
 go.mod                 |   1 +
 go.sum                 |   2 +
 litestream.go          |  24 +-
 replica.go             |  18 +-
 s3/s3.go               | 874 +++++++++++++++++++++++++++++++++++++++++
 7 files changed, 940 insertions(+), 28 deletions(-)
 create mode 100644 s3/s3.go

diff --git a/cmd/litestream/main.go b/cmd/litestream/main.go
index 520e778..6a6f6a8 100644
--- a/cmd/litestream/main.go
+++ b/cmd/litestream/main.go
@@ -13,6 +13,7 @@ import (
 	"time"

 	"github.com/benbjohnson/litestream"
+	"github.com/benbjohnson/litestream/s3"
 	"gopkg.in/yaml.v2"
 )

@@ -160,6 +161,11 @@ type ReplicaConfig struct {
 	Name      string        `yaml:"name"` // name of replica, optional.
 	Path      string        `yaml:"path"`
 	Retention time.Duration `yaml:"retention"`
+
+	// S3 settings
+	AccessKeyID     string `yaml:"access-key-id"`
+	SecretAccessKey string `yaml:"secret-access-key"`
+	Region          string `yaml:"region"`
+	Bucket          string `yaml:"bucket"`
 }

 func registerConfigFlag(fs *flag.FlagSet, p *string) {
@@ -188,6 +194,8 @@ func newReplicaFromConfig(db *litestream.DB, config *ReplicaConfig) (litestream.
 	switch config.Type {
 	case "", "file":
 		return newFileReplicaFromConfig(db, config)
+	case "s3":
+		return newS3ReplicaFromConfig(db, config)
 	default:
 		return nil, fmt.Errorf("unknown replica type in config: %q", config.Type)
 	}
@@ -196,11 +204,37 @@ func newReplicaFromConfig(db *litestream.DB, config *ReplicaConfig) (litestream.
 // newFileReplicaFromConfig returns a new instance of FileReplica build from config.
 func newFileReplicaFromConfig(db *litestream.DB, config *ReplicaConfig) (*litestream.FileReplica, error) {
 	if config.Path == "" {
-		return nil, fmt.Errorf("file replica path require for db %q", db.Path())
+		return nil, fmt.Errorf("%s: file replica path required", db.Path())
 	}
+
 	r := litestream.NewFileReplica(db, config.Name, config.Path)
 	if v := config.Retention; v > 0 {
 		r.RetentionInterval = v
 	}
 	return r, nil
 }
+
+// newS3ReplicaFromConfig returns a new instance of s3.Replica built from config.
+func newS3ReplicaFromConfig(db *litestream.DB, config *ReplicaConfig) (*s3.Replica, error) {
+	if config.AccessKeyID == "" {
+		return nil, fmt.Errorf("%s: s3 access key id required", db.Path())
+	} else if config.SecretAccessKey == "" {
+		return nil, fmt.Errorf("%s: s3 secret access key required", db.Path())
+	} else if config.Region == "" {
+		return nil, fmt.Errorf("%s: s3 region required", db.Path())
+	} else if config.Bucket == "" {
+		return nil, fmt.Errorf("%s: s3 bucket required", db.Path())
+	}
+
+	r := s3.NewReplica(db, config.Name)
+	r.AccessKeyID = config.AccessKeyID
+	r.SecretAccessKey = config.SecretAccessKey
+	r.Region = config.Region
+	r.Bucket = config.Bucket
+	r.Path = config.Path
+
+	if v := config.Retention; v > 0 {
+		r.RetentionInterval = v
+	}
+	return r, nil
+}
diff --git a/db.go b/db.go
index 7d982bb..20ab26b 100644
--- a/db.go
+++ b/db.go
@@ -117,6 +117,11 @@ func NewDB(path string) *DB {
 	return db
 }

+// SQLDB returns a reference to the underlying sql.DB connection.
+func (db *DB) SQLDB() *sql.DB {
+	return db.db
+}
+
 // Path returns the path to the database.
func (db *DB) Path() string { return db.path @@ -181,7 +186,7 @@ func (db *DB) CurrentShadowWALIndex(generation string) (index int, size int64, e if !strings.HasSuffix(fi.Name(), WALExt) { continue } - if v, _, _, err := ParseWALPath(fi.Name()); err != nil { + if v, _, _, _, err := ParseWALPath(fi.Name()); err != nil { continue // invalid wal filename } else if v > index { index = v @@ -478,7 +483,7 @@ func (db *DB) cleanWAL() error { return err } for _, fi := range fis { - if idx, _, _, err := ParseWALPath(fi.Name()); err != nil || idx >= min { + if idx, _, _, _, err := ParseWALPath(fi.Name()); err != nil || idx >= min { continue } if err := os.Remove(filepath.Join(dir, fi.Name())); err != nil { @@ -858,7 +863,7 @@ func (db *DB) syncWAL(info syncInfo) (newSize int64, err error) { // Parse index of current shadow WAL file. dir, base := filepath.Split(info.shadowWALPath) - index, _, _, err := ParseWALPath(base) + index, _, _, _, err := ParseWALPath(base) if err != nil { return 0, fmt.Errorf("cannot parse shadow wal filename: %s", base) } @@ -1217,7 +1222,7 @@ func (db *DB) checkpointAndInit(info syncInfo, mode string) error { } // Parse index of current shadow WAL file. - index, _, _, err := ParseWALPath(info.shadowWALPath) + index, _, _, _, err := ParseWALPath(info.shadowWALPath) if err != nil { return fmt.Errorf("cannot parse shadow wal filename: %s", info.shadowWALPath) } diff --git a/go.mod b/go.mod index 1862f3e..9afca13 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/benbjohnson/litestream go 1.15 require ( + github.com/aws/aws-sdk-go v1.27.0 github.com/mattn/go-sqlite3 v1.14.5 github.com/prometheus/client_golang v1.9.0 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index 808e811..b4cfcf4 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,7 @@ github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmV github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A= github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU= +github.com/aws/aws-sdk-go v1.27.0 h1:0xphMHGMLBrPMfxR2AmVjZKcMEESEgWF8Kru94BNByk= github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= @@ -124,6 +125,7 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= +github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= diff --git a/litestream.go b/litestream.go index 1e0537f..7e4ef25 100644 --- a/litestream.go +++ b/litestream.go @@ -52,8 +52,8 @@ type 
SnapshotInfo struct { CreatedAt time.Time } -// filterSnapshotsAfter returns all snapshots that were created on or after t. -func filterSnapshotsAfter(a []*SnapshotInfo, t time.Time) []*SnapshotInfo { +// FilterSnapshotsAfter returns all snapshots that were created on or after t. +func FilterSnapshotsAfter(a []*SnapshotInfo, t time.Time) []*SnapshotInfo { other := make([]*SnapshotInfo, 0, len(a)) for _, snapshot := range a { if !snapshot.CreatedAt.Before(t) { @@ -63,8 +63,8 @@ func filterSnapshotsAfter(a []*SnapshotInfo, t time.Time) []*SnapshotInfo { return other } -// findMinSnapshotByGeneration finds the snapshot with the lowest index in a generation. -func findMinSnapshotByGeneration(a []*SnapshotInfo, generation string) *SnapshotInfo { +// FindMinSnapshotByGeneration finds the snapshot with the lowest index in a generation. +func FindMinSnapshotByGeneration(a []*SnapshotInfo, generation string) *SnapshotInfo { var min *SnapshotInfo for _, snapshot := range a { if snapshot.Generation != generation { @@ -229,17 +229,18 @@ func IsWALPath(s string) bool { // ParseWALPath returns the index & offset for the WAL file. // Returns an error if the path is not a valid snapshot path. -func ParseWALPath(s string) (index int, offset int64, ext string, err error) { +func ParseWALPath(s string) (index int, offset, sz int64, ext string, err error) { s = filepath.Base(s) a := walPathRegex.FindStringSubmatch(s) if a == nil { - return 0, 0, "", fmt.Errorf("invalid wal path: %s", s) + return 0, 0, 0, "", fmt.Errorf("invalid wal path: %s", s) } i64, _ := strconv.ParseUint(a[1], 16, 64) off64, _ := strconv.ParseUint(a[2], 16, 64) - return int(i64), int64(off64), a[3], nil + sz64, _ := strconv.ParseUint(a[3], 16, 64) + return int(i64), int64(off64), int64(sz64), a[4], nil } // FormatWALPath formats a WAL filename with a given index. @@ -248,14 +249,15 @@ func FormatWALPath(index int) string { return fmt.Sprintf("%016x%s", index, WALExt) } -// FormatWALPathWithOffset formats a WAL filename with a given index & offset. -func FormatWALPathWithOffset(index int, offset int64) string { +// FormatWALPathWithOffsetSize formats a WAL filename with a given index, offset & size. +func FormatWALPathWithOffsetSize(index int, offset, sz int64) string { assert(index >= 0, "wal index must be non-negative") assert(offset >= 0, "wal offset must be non-negative") - return fmt.Sprintf("%016x_%016x%s", index, offset, WALExt) + assert(sz >= 0, "wal size must be non-negative") + return fmt.Sprintf("%016x_%016x_%016x%s", index, offset, sz, WALExt) } -var walPathRegex = regexp.MustCompile(`^([0-9a-f]{16})(?:_([0-9a-f]{16}))?(.wal(?:.gz)?)$`) +var walPathRegex = regexp.MustCompile(`^([0-9a-f]{16})(?:_([0-9a-f]{16})_([0-9a-f]{16}))?(.wal(?:.gz)?)$`) // isHexChar returns true if ch is a lowercase hex character. func isHexChar(ch rune) bool { diff --git a/replica.go b/replica.go index df3ca99..b85f1b5 100644 --- a/replica.go +++ b/replica.go @@ -333,7 +333,7 @@ func (r *FileReplica) WALs(ctx context.Context) ([]*WALInfo, error) { // Iterate over each WAL file. 
 	for _, fi := range fis {
-		index, offset, _, err := ParseWALPath(fi.Name())
+		index, offset, _, _, err := ParseWALPath(fi.Name())
 		if err != nil {
 			continue
 		}
@@ -446,7 +446,7 @@ func (r *FileReplica) CalcPos(generation string) (pos Pos, err error) {

 	index := -1
 	for _, fi := range fis {
-		if idx, _, _, err := ParseWALPath(fi.Name()); err != nil {
+		if idx, _, _, _, err := ParseWALPath(fi.Name()); err != nil {
 			continue // invalid wal filename
 		} else if index == -1 || idx > index {
 			index = idx
@@ -679,7 +679,7 @@ func (r *FileReplica) WALIndexAt(ctx context.Context, generation string, maxInde

 	for _, fi := range fis {
 		// Read index from snapshot filename.
-		idx, _, _, err := ParseWALPath(fi.Name())
+		idx, _, _, _, err := ParseWALPath(fi.Name())
 		if err != nil {
 			continue // not a snapshot, skip
 		} else if !timestamp.IsZero() && fi.ModTime().After(timestamp) {
@@ -783,7 +783,7 @@ func (r *FileReplica) EnforceRetention(ctx context.Context) (err error) {
 	if err != nil {
 		return fmt.Errorf("cannot obtain snapshot list: %w", err)
 	}
-	snapshots = filterSnapshotsAfter(snapshots, time.Now().Add(-r.RetentionInterval))
+	snapshots = FilterSnapshotsAfter(snapshots, time.Now().Add(-r.RetentionInterval))

 	// If no retained snapshots exist, create a new snapshot.
 	if len(snapshots) == 0 {
@@ -801,7 +801,7 @@
 	}
 	for _, generation := range generations {
 		// Find earliest retained snapshot for this generation.
-		snapshot := findMinSnapshotByGeneration(snapshots, generation)
+		snapshot := FindMinSnapshotByGeneration(snapshots, generation)

 		// Delete generations if it has no snapshots being retained.
 		if snapshot == nil {
@@ -863,7 +863,7 @@ func (r *FileReplica) deleteGenerationWALBefore(ctx context.Context, generation
 	}

 	for _, fi := range fis {
-		idx, _, _, err := ParseWALPath(fi.Name())
+		idx, _, _, _, err := ParseWALPath(fi.Name())
 		if err != nil {
 			continue
 		} else if idx >= index {
@@ -910,9 +910,3 @@ func compressFile(src, dst string, uid, gid int) error {
 	// Move compressed file to final location.
 	return os.Rename(dst+".tmp", dst)
 }
-
-// walDirMask is a mask used to group 64K wal files into a directory.
-const (
-	walDirFileN = 0x10000
-	walDirMask  = uint64(0xFFFFFFFFFFFFFFFF ^ (walDirFileN - 1))
-)
diff --git a/s3/s3.go b/s3/s3.go
new file mode 100644
index 0000000..d9e8abb
--- /dev/null
+++ b/s3/s3.go
@@ -0,0 +1,874 @@
+package s3
+
+import (
+	"compress/gzip"
+	"context"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"log"
+	"math"
+	"os"
+	"path"
+	"sync"
+	"time"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/credentials"
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/aws/aws-sdk-go/service/s3/s3manager"
+	"github.com/benbjohnson/litestream"
+)
+
+const (
+	DefaultRetentionInterval = 1 * time.Hour
+)
+
+var _ litestream.Replica = (*Replica)(nil)
+
+// Replica is a replica that replicates a DB to an S3 bucket.
+type Replica struct {
+	db       *litestream.DB // source database
+	name     string         // replica name, optional
+	s3       *s3.S3         // s3 service
+	uploader *s3manager.Uploader
+
+	mu  sync.RWMutex
+	pos litestream.Pos // last position
+
+	wg     sync.WaitGroup
+	ctx    context.Context
+	cancel func()
+
+	// AWS authentication keys.
+	AccessKeyID     string
+	SecretAccessKey string
+
+	// S3 bucket information
+	Region string
+	Bucket string
+	Path   string
+
+	// Time to keep snapshots and related WAL files.
+	// Database is snapshotted after interval and older WAL files are discarded.
+	RetentionInterval time.Duration
+
+	// If true, replica monitors database for changes automatically.
+	// Set to false if replica is being used synchronously (such as in tests).
+	MonitorEnabled bool
+}
+
+// NewReplica returns a new instance of Replica.
+func NewReplica(db *litestream.DB, name string) *Replica {
+	return &Replica{
+		db:     db,
+		name:   name,
+		cancel: func() {},
+
+		RetentionInterval: DefaultRetentionInterval,
+		MonitorEnabled:    true,
+	}
+}
+
+// Name returns the name of the replica. Returns the type if no name set.
+func (r *Replica) Name() string {
+	if r.name != "" {
+		return r.name
+	}
+	return r.Type()
+}
+
+// Type returns the type of replica.
+func (r *Replica) Type() string {
+	return "s3"
+}
+
+// LastPos returns the last successfully replicated position.
+func (r *Replica) LastPos() litestream.Pos {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+	return r.pos
+}
+
+// GenerationDir returns the path to a generation's root directory.
+func (r *Replica) GenerationDir(generation string) string {
+	return path.Join("/", r.Path, "generations", generation)
+}
+
+// SnapshotDir returns the path to a generation's snapshot directory.
+func (r *Replica) SnapshotDir(generation string) string {
+	return path.Join(r.GenerationDir(generation), "snapshots")
+}
+
+// SnapshotPath returns the path to a snapshot file.
+func (r *Replica) SnapshotPath(generation string, index int) string {
+	return path.Join(r.SnapshotDir(generation), fmt.Sprintf("%016x.snapshot.gz", index))
+}
+
+// MaxSnapshotIndex returns the highest index for the snapshots.
+func (r *Replica) MaxSnapshotIndex(generation string) (int, error) {
+	snapshots, err := r.Snapshots(context.Background())
+	if err != nil {
+		return 0, err
+	}
+
+	index := -1
+	for _, snapshot := range snapshots {
+		if snapshot.Generation != generation {
+			continue
+		} else if index == -1 || snapshot.Index > index {
+			index = snapshot.Index
+		}
+	}
+	if index == -1 {
+		return 0, fmt.Errorf("no snapshots found")
+	}
+	return index, nil
+}
+
+// WALDir returns the path to a generation's WAL directory.
+func (r *Replica) WALDir(generation string) string {
+	return path.Join(r.GenerationDir(generation), "wal")
+}
+
+// WALPath returns the path to a WAL file.
+func (r *Replica) WALPath(generation string, index int) string {
+	return path.Join(r.WALDir(generation), fmt.Sprintf("%016x.wal", index))
+}
+
+// Generations returns a list of available generation names.
+func (r *Replica) Generations(ctx context.Context) ([]string, error) {
+	if err := r.Init(ctx); err != nil {
+		return nil, err
+	}
+
+	var generations []string
+	if err := r.s3.ListObjectsPages(&s3.ListObjectsInput{
+		Bucket:    aws.String(r.Bucket),
+		Prefix:    aws.String(path.Join("/", r.Path, "generations")),
+		Delimiter: aws.String("/"),
+	}, func(page *s3.ListObjectsOutput, lastPage bool) bool {
+		for _, obj := range page.Contents {
+			key := path.Base(*obj.Key)
+			if !litestream.IsGenerationName(key) {
+				continue
+			}
+			generations = append(generations, key)
+		}
+		return true
+	}); err != nil {
+		return nil, err
+	}
+
+	return generations, nil
+}
+
+// GenerationStats returns stats for a generation.
+func (r *Replica) GenerationStats(ctx context.Context, generation string) (stats litestream.GenerationStats, err error) {
+	if err := r.Init(ctx); err != nil {
+		return stats, err
+	}
+
+	// Determine stats for all snapshots.
+	n, min, max, err := r.snapshotStats(generation)
+	if err != nil {
+		return stats, err
+	}
+	stats.SnapshotN = n
+	stats.CreatedAt, stats.UpdatedAt = min, max
+
+	// Update stats if we have WAL files.
+	n, min, max, err = r.walStats(generation)
+	if err != nil {
+		return stats, err
+	} else if n == 0 {
+		return stats, nil
+	}
+
+	stats.WALN = n
+	if stats.CreatedAt.IsZero() || min.Before(stats.CreatedAt) {
+		stats.CreatedAt = min
+	}
+	if stats.UpdatedAt.IsZero() || max.After(stats.UpdatedAt) {
+		stats.UpdatedAt = max
+	}
+	return stats, nil
+}
+
+func (r *Replica) snapshotStats(generation string) (n int, min, max time.Time, err error) {
+	if err := r.s3.ListObjectsPages(&s3.ListObjectsInput{
+		Bucket:    aws.String(r.Bucket),
+		Prefix:    aws.String(r.SnapshotDir(generation)),
+		Delimiter: aws.String("/"),
+	}, func(page *s3.ListObjectsOutput, lastPage bool) bool {
+		for _, obj := range page.Contents {
+			if !litestream.IsSnapshotPath(*obj.Key) {
+				continue
+			}
+			modTime := obj.LastModified.UTC()
+
+			n++
+			if min.IsZero() || modTime.Before(min) {
+				min = modTime
+			}
+			if max.IsZero() || modTime.After(max) {
+				max = modTime
+			}
+		}
+		return true
+	}); err != nil {
+		return n, min, max, err
+	}
+	return n, min, max, nil
+}
+
+func (r *Replica) walStats(generation string) (n int, min, max time.Time, err error) {
+	if err := r.s3.ListObjectsPages(&s3.ListObjectsInput{
+		Bucket:    aws.String(r.Bucket),
+		Prefix:    aws.String(r.WALDir(generation)),
+		Delimiter: aws.String("/"),
+	}, func(page *s3.ListObjectsOutput, lastPage bool) bool {
+		for _, obj := range page.Contents {
+			if !litestream.IsWALPath(*obj.Key) {
+				continue
+			}
+			modTime := obj.LastModified.UTC()
+
+			n++
+			if min.IsZero() || modTime.Before(min) {
+				min = modTime
+			}
+			if max.IsZero() || modTime.After(max) {
+				max = modTime
+			}
+		}
+		return true
+	}); err != nil {
+		return n, min, max, err
+	}
+	return n, min, max, nil
+}
+
+// Snapshots returns a list of available snapshots in the replica.
+func (r *Replica) Snapshots(ctx context.Context) ([]*litestream.SnapshotInfo, error) {
+	if err := r.Init(ctx); err != nil {
+		return nil, err
+	}
+
+	generations, err := r.Generations(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	var infos []*litestream.SnapshotInfo
+	for _, generation := range generations {
+		if err := r.s3.ListObjectsPages(&s3.ListObjectsInput{
+			Bucket:    aws.String(r.Bucket),
+			Prefix:    aws.String(r.SnapshotDir(generation)),
+			Delimiter: aws.String("/"),
+		}, func(page *s3.ListObjectsOutput, lastPage bool) bool {
+			for _, obj := range page.Contents {
+				key := path.Base(*obj.Key)
+				index, _, err := litestream.ParseSnapshotPath(key)
+				if err != nil {
+					continue
+				}
+
+				infos = append(infos, &litestream.SnapshotInfo{
+					Name:       key,
+					Replica:    r.Name(),
+					Generation: generation,
+					Index:      index,
+					Size:       *obj.Size,
+					CreatedAt:  obj.LastModified.UTC(),
+				})
+			}
+			return true
+		}); err != nil {
+			return nil, err
+		}
+	}

+	return infos, nil
+}
+
+// WALs returns a list of available WAL files in the replica.
+func (r *Replica) WALs(ctx context.Context) ([]*litestream.WALInfo, error) { + if err := r.Init(ctx); err != nil { + return nil, err + } + + generations, err := r.Generations(ctx) + if err != nil { + return nil, err + } + + var infos []*litestream.WALInfo + for _, generation := range generations { + var prev *litestream.WALInfo + if err := r.s3.ListObjectsPages(&s3.ListObjectsInput{ + Bucket: aws.String(r.Bucket), + Prefix: aws.String(r.WALDir(generation)), + Delimiter: aws.String("/"), + }, func(page *s3.ListObjectsOutput, lastPage bool) bool { + for _, obj := range page.Contents { + key := path.Base(*obj.Key) + + index, offset, _, _, err := litestream.ParseWALPath(key) + if err != nil { + continue + } + + // Update previous record if generation & index match. + if prev != nil && prev.Index == index { + prev.Size += *obj.Size + prev.CreatedAt = obj.LastModified.UTC() + continue + } + + // Append new WAL record and keep reference to append additional + // size for segmented WAL files. + prev = &litestream.WALInfo{ + Name: key, + Replica: r.Name(), + Generation: generation, + Index: index, + Offset: offset, + Size: *obj.Size, + CreatedAt: obj.LastModified.UTC(), + } + infos = append(infos, prev) + } + return true + }); err != nil { + return nil, err + } + } + + return infos, nil +} + +// Start starts replication for a given generation. +func (r *Replica) Start(ctx context.Context) { + // Ignore if replica is being used sychronously. + if !r.MonitorEnabled { + return + } + + // Stop previous replication. + r.Stop() + + // Wrap context with cancelation. + ctx, r.cancel = context.WithCancel(ctx) + + // Start goroutines to manage replica data. + r.wg.Add(2) + go func() { defer r.wg.Done(); r.monitor(ctx) }() + go func() { defer r.wg.Done(); r.retainer(ctx) }() +} + +// Stop cancels any outstanding replication and blocks until finished. +func (r *Replica) Stop() { + r.cancel() + r.wg.Wait() +} + +// monitor runs in a separate goroutine and continuously replicates the DB. +func (r *Replica) monitor(ctx context.Context) { + // Continuously check for new data to replicate. + ch := make(chan struct{}) + close(ch) + var notify <-chan struct{} = ch + + for { + select { + case <-ctx.Done(): + return + case <-notify: + } + + // Fetch new notify channel before replicating data. + notify = r.db.Notify() + + // Synchronize the shadow wal into the replication directory. + if err := r.Sync(ctx); err != nil { + log.Printf("%s(%s): sync error: %s", r.db.Path(), r.Name(), err) + continue + } + } +} + +// retainer runs in a separate goroutine and handles retention. +func (r *Replica) retainer(ctx context.Context) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := r.EnforceRetention(ctx); err != nil { + log.Printf("%s(%s): retain error: %s", r.db.Path(), r.Name(), err) + continue + } + } + } +} + +// CalcPos returns the position for the replica for the current generation. +// Returns a zero value if there is no active generation. +func (r *Replica) CalcPos(generation string) (pos litestream.Pos, err error) { + if err := r.Init(context.Background()); err != nil { + return pos, err + } + + pos.Generation = generation + + // Find maximum snapshot index. 
+	if pos.Index, err = r.MaxSnapshotIndex(generation); err != nil {
+		return litestream.Pos{}, err
+	}
+
+	index := -1
+	var offset int64
+	if err := r.s3.ListObjectsPages(&s3.ListObjectsInput{
+		Bucket:    aws.String(r.Bucket),
+		Prefix:    aws.String(r.WALDir(generation)),
+		Delimiter: aws.String("/"),
+	}, func(page *s3.ListObjectsOutput, lastPage bool) bool {
+		for _, obj := range page.Contents {
+			key := path.Base(*obj.Key)
+
+			idx, off, sz, _, err := litestream.ParseWALPath(key)
+			if err != nil {
+				continue // invalid wal filename
+			}
+
+			if index == -1 || idx > index {
+				index, offset = idx, off+sz // start tracking new wal at end of this segment
+			} else if idx == index {
+				offset += sz // append additional size to wal
+			}
+		}
+		return true
+	}); err != nil {
+		return litestream.Pos{}, err
+	}
+	if index == -1 {
+		return pos, nil // no wal files
+	}
+	pos.Index = index
+	pos.Offset = offset
+
+	return pos, nil
+}
+
+// snapshot copies the entire database to the replica path.
+func (r *Replica) snapshot(ctx context.Context, generation string, index int) error {
+	// Acquire a read lock on the database during snapshot to prevent checkpoints.
+	tx, err := r.db.SQLDB().Begin()
+	if err != nil {
+		return err
+	} else if _, err := tx.ExecContext(ctx, `SELECT COUNT(1) FROM _litestream_seq;`); err != nil {
+		tx.Rollback()
+		return err
+	}
+	defer tx.Rollback()
+
+	// Open database file handle.
+	f, err := os.Open(r.db.Path())
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+
+	// Compress the database into a pipe; close the gzip stream & pipe when done
+	// so the uploader sees EOF and the upload can complete.
+	pr, pw := io.Pipe()
+	gw := gzip.NewWriter(pw)
+	go func() {
+		if _, err := io.Copy(gw, f); err != nil {
+			_ = pw.CloseWithError(err)
+			return
+		}
+		_ = pw.CloseWithError(gw.Close())
+	}()
+
+	snapshotPath := r.SnapshotPath(generation, index)
+
+	if _, err := r.uploader.Upload(&s3manager.UploadInput{
+		Bucket: aws.String(r.Bucket),
+		Key:    aws.String(snapshotPath),
+		Body:   pr,
+	}); err != nil {
+		return err
+	}
+	return nil
+}
+
+// snapshotN returns the number of snapshots for a generation.
+func (r *Replica) snapshotN(generation string) (int, error) {
+	snapshots, err := r.Snapshots(context.Background())
+	if err != nil {
+		return 0, err
+	}
+
+	var n int
+	for _, snapshot := range snapshots {
+		if snapshot.Generation == generation {
+			n++
+		}
+	}
+	return n, nil
+}
+
+// Init initializes the connection to S3. No-op if already initialized.
+func (r *Replica) Init(ctx context.Context) (err error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if r.s3 != nil {
+		return nil
+	}
+
+	// Use static credentials from the replica, if provided.
+	config := &aws.Config{Region: aws.String(r.Region)}
+	if r.AccessKeyID != "" || r.SecretAccessKey != "" {
+		config.Credentials = credentials.NewStaticCredentials(r.AccessKeyID, r.SecretAccessKey, "")
+	}
+
+	sess, err := session.NewSession(config)
+	if err != nil {
+		return fmt.Errorf("cannot create aws session: %w", err)
+	}
+	r.s3 = s3.New(sess)
+	r.uploader = s3manager.NewUploader(sess)
+	return nil
+}
+
+func (r *Replica) Sync(ctx context.Context) (err error) {
+	if err := r.Init(ctx); err != nil {
+		return err
+	}
+
+	// Find current position of database.
+	dpos, err := r.db.Pos()
+	if err != nil {
+		return fmt.Errorf("cannot determine current generation: %w", err)
+	} else if dpos.IsZero() {
+		return fmt.Errorf("no generation, waiting for data")
+	}
+	generation := dpos.Generation
+
+	// Create snapshot if no snapshots exist for generation.
+	if n, err := r.snapshotN(generation); err != nil {
+		return err
+	} else if n == 0 {
+		if err := r.snapshot(ctx, generation, dpos.Index); err != nil {
+			return err
+		}
+	}
+
+	// Determine position, if necessary.
+	if r.LastPos().IsZero() {
+		pos, err := r.CalcPos(generation)
+		if err != nil {
+			return fmt.Errorf("cannot determine replica position: %s", err)
+		}
+
+		r.mu.Lock()
+		r.pos = pos
+		r.mu.Unlock()
+	}
+
+	// Read all WAL files since the last position.
+ for { + if err = r.syncWAL(ctx); err == io.EOF { + break + } else if err != nil { + return err + } + } + + return nil +} + +func (r *Replica) syncWAL(ctx context.Context) (err error) { + rd, err := r.db.ShadowWALReader(r.LastPos()) + if err == io.EOF { + return err + } else if err != nil { + return fmt.Errorf("wal reader: %w", err) + } + defer rd.Close() + + // Ensure parent directory exists for WAL file. + filename := r.WALPath(rd.Pos().Generation, rd.Pos().Index) + if err := os.MkdirAll(path.Dir(filename), 0700); err != nil { + return err + } + + // Create a temporary file to write into so we don't have partial writes. + w, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0600) + if err != nil { + return err + } + defer w.Close() + + // Seek, copy & sync WAL contents. + if _, err := w.Seek(rd.Pos().Offset, io.SeekStart); err != nil { + return err + } else if _, err := io.Copy(w, rd); err != nil { + return err + } else if err := w.Sync(); err != nil { + return err + } else if err := w.Close(); err != nil { + return err + } + + // Save last replicated position. + r.mu.Lock() + r.pos = rd.Pos() + r.mu.Unlock() + + return nil +} + +// SnapsotIndexAt returns the highest index for a snapshot within a generation +// that occurs before timestamp. If timestamp is zero, returns the latest snapshot. +func (r *Replica) SnapshotIndexAt(ctx context.Context, generation string, timestamp time.Time) (int, error) { + fis, err := ioutil.ReadDir(r.SnapshotDir(generation)) + if os.IsNotExist(err) { + return 0, litestream.ErrNoSnapshots + } else if err != nil { + return 0, err + } + + index := -1 + var max time.Time + for _, fi := range fis { + // Read index from snapshot filename. + idx, _, err := litestream.ParseSnapshotPath(fi.Name()) + if err != nil { + continue // not a snapshot, skip + } else if !timestamp.IsZero() && fi.ModTime().After(timestamp) { + continue // after timestamp, skip + } + + // Use snapshot if it newer. + if max.IsZero() || fi.ModTime().After(max) { + index, max = idx, fi.ModTime() + } + } + + if index == -1 { + return 0, litestream.ErrNoSnapshots + } + return index, nil +} + +// Returns the highest index for a WAL file that occurs before maxIndex & timestamp. +// If timestamp is zero, returns the highest WAL index. +func (r *Replica) WALIndexAt(ctx context.Context, generation string, maxIndex int, timestamp time.Time) (int, error) { + var index int + fis, err := ioutil.ReadDir(r.WALDir(generation)) + if os.IsNotExist(err) { + return 0, nil + } else if err != nil { + return 0, err + } + + for _, fi := range fis { + // Read index from snapshot filename. + idx, _, _, _, err := litestream.ParseWALPath(fi.Name()) + if err != nil { + continue // not a snapshot, skip + } else if !timestamp.IsZero() && fi.ModTime().After(timestamp) { + continue // after timestamp, skip + } else if idx > maxIndex { + continue // after timestamp, skip + } else if idx < index { + continue // earlier index, skip + } + + index = idx + } + + // If max index is specified but not found, return an error. + if maxIndex != math.MaxInt64 && index != maxIndex { + return index, fmt.Errorf("unable to locate index %d in generation %q, highest index was %d", maxIndex, generation, index) + } + + return index, nil +} + +// SnapshotReader returns a reader for snapshot data at the given generation/index. +// Returns os.ErrNotExist if no matching index is found. 
+func (r *Replica) SnapshotReader(ctx context.Context, generation string, index int) (io.ReadCloser, error) { + dir := r.SnapshotDir(generation) + fis, err := ioutil.ReadDir(dir) + if err != nil { + return nil, err + } + + for _, fi := range fis { + // Parse index from snapshot filename. Skip if no match. + idx, ext, err := litestream.ParseSnapshotPath(fi.Name()) + if err != nil || index != idx { + continue + } + + // Open & return the file handle if uncompressed. + f, err := os.Open(path.Join(dir, fi.Name())) + if err != nil { + return nil, err + } else if ext == ".snapshot" { + return f, nil // not compressed, return as-is. + } + // assert(ext == ".snapshot.gz", "invalid snapshot extension") + + // If compressed, wrap in a gzip reader and return with wrapper to + // ensure that the underlying file is closed. + r, err := gzip.NewReader(f) + if err != nil { + f.Close() + return nil, err + } + return &gzipReadCloser{r: r, closer: f}, nil + } + return nil, os.ErrNotExist +} + +// WALReader returns a reader for WAL data at the given index. +// Returns os.ErrNotExist if no matching index is found. +func (r *Replica) WALReader(ctx context.Context, generation string, index int) (io.ReadCloser, error) { + filename := r.WALPath(generation, index) + + // Attempt to read uncompressed file first. + f, err := os.Open(filename) + if err == nil { + return f, nil // file exist, return + } else if err != nil && !os.IsNotExist(err) { + return nil, err + } + + // Otherwise read the compressed file. Return error if file doesn't exist. + f, err = os.Open(filename + ".gz") + if err != nil { + return nil, err + } + + // If compressed, wrap in a gzip reader and return with wrapper to + // ensure that the underlying file is closed. + rd, err := gzip.NewReader(f) + if err != nil { + f.Close() + return nil, err + } + return &gzipReadCloser{r: rd, closer: f}, nil +} + +// EnforceRetention forces a new snapshot once the retention interval has passed. +// Older snapshots and WAL files are then removed. +func (r *Replica) EnforceRetention(ctx context.Context) (err error) { + if err := r.Init(ctx); err != nil { + return err + } + + // Find current position of database. + pos, err := r.db.Pos() + if err != nil { + return fmt.Errorf("cannot determine current generation: %w", err) + } else if pos.IsZero() { + return fmt.Errorf("no generation, waiting for data") + } + + // Obtain list of snapshots that are within the retention period. + snapshots, err := r.Snapshots(ctx) + if err != nil { + return fmt.Errorf("cannot obtain snapshot list: %w", err) + } + snapshots = litestream.FilterSnapshotsAfter(snapshots, time.Now().Add(-r.RetentionInterval)) + + // If no retained snapshots exist, create a new snapshot. + if len(snapshots) == 0 { + log.Printf("%s(%s): snapshots exceeds retention, creating new snapshot", r.db.Path(), r.Name()) + if err := r.snapshot(ctx, pos.Generation, pos.Index); err != nil { + return fmt.Errorf("cannot snapshot: %w", err) + } + snapshots = append(snapshots, &litestream.SnapshotInfo{Generation: pos.Generation, Index: pos.Index}) + } + + // Loop over generations and delete unretained snapshots & WAL files. + generations, err := r.Generations(ctx) + if err != nil { + return fmt.Errorf("cannot obtain generations: %w", err) + } + for _, generation := range generations { + // Find earliest retained snapshot for this generation. + snapshot := litestream.FindMinSnapshotByGeneration(snapshots, generation) + + // Delete generations if it has no snapshots being retained. 
+ if snapshot == nil { + log.Printf("%s(%s): generation %q has no retained snapshots, deleting", r.db.Path(), r.Name(), generation) + if err := os.RemoveAll(r.GenerationDir(generation)); err != nil { + return fmt.Errorf("cannot delete generation %q dir: %w", generation, err) + } + continue + } + + // Otherwise delete all snapshots & WAL files before a lowest retained index. + if err := r.deleteGenerationSnapshotsBefore(ctx, generation, snapshot.Index); err != nil { + return fmt.Errorf("cannot delete generation %q snapshots before index %d: %w", generation, snapshot.Index, err) + } else if err := r.deleteGenerationWALBefore(ctx, generation, snapshot.Index); err != nil { + return fmt.Errorf("cannot delete generation %q wal before index %d: %w", generation, snapshot.Index, err) + } + } + + return nil +} + +// deleteGenerationSnapshotsBefore deletes snapshot before a given index. +func (r *Replica) deleteGenerationSnapshotsBefore(ctx context.Context, generation string, index int) (err error) { + dir := r.SnapshotDir(generation) + + fis, err := ioutil.ReadDir(dir) + if os.IsNotExist(err) { + return nil + } else if err != nil { + return err + } + + for _, fi := range fis { + idx, _, err := litestream.ParseSnapshotPath(fi.Name()) + if err != nil { + continue + } else if idx >= index { + continue + } + + log.Printf("%s(%s): generation %q snapshot no longer retained, deleting %s", r.db.Path(), r.Name(), generation, fi.Name()) + if err := os.Remove(path.Join(dir, fi.Name())); err != nil { + return err + } + } + + return nil +} + +// deleteGenerationWALBefore deletes WAL files before a given index. +func (r *Replica) deleteGenerationWALBefore(ctx context.Context, generation string, index int) (err error) { + dir := r.WALDir(generation) + + fis, err := ioutil.ReadDir(dir) + if os.IsNotExist(err) { + return nil + } else if err != nil { + return err + } + + for _, fi := range fis { + idx, _, _, _, err := litestream.ParseWALPath(fi.Name()) + if err != nil { + continue + } else if idx >= index { + continue + } + + log.Printf("%s(%s): generation %q wal no longer retained, deleting %s", r.db.Path(), r.Name(), generation, fi.Name()) + if err := os.Remove(path.Join(dir, fi.Name())); err != nil { + return err + } + } + + return nil +}
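
The WAL naming change in litestream.go widens shadow WAL segment names from index_offset to index_offset_size, each component a 16-digit lowercase hex value, which is what lets the S3 replica's CalcPos recover its position by summing segment sizes from object names alone. A minimal round-trip sketch using only the exported helpers added in this patch; the numeric values are arbitrary examples:

	package main

	import (
		"fmt"
		"log"

		"github.com/benbjohnson/litestream"
	)

	func main() {
		// Arbitrary example values: WAL index 5, segment offset 4096, segment size 1024.
		name := litestream.FormatWALPathWithOffsetSize(5, 4096, 1024)
		fmt.Println(name) // 0000000000000005_0000000000001000_0000000000000400.wal

		// ParseWALPath now returns the size component as well as index, offset & extension.
		index, offset, size, ext, err := litestream.ParseWALPath(name)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println(index, offset, size, ext) // 5 4096 1024 .wal
	}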
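For completeness, a sketch of wiring the new replica by hand, mirroring what newS3ReplicaFromConfig does from the YAML config. The database path, region, bucket, and credentials below are placeholders, and the database is assumed to be opened and tracked elsewhere, as cmd/litestream already does for file replicas:

	package main

	import (
		"context"

		"github.com/benbjohnson/litestream"
		"github.com/benbjohnson/litestream/s3"
	)

	func main() {
		db := litestream.NewDB("/var/lib/app/db") // placeholder database path
		r := s3.NewReplica(db, "s3")
		r.AccessKeyID = "AKIAEXAMPLE"          // placeholder credential
		r.SecretAccessKey = "EXAMPLESECRETKEY" // placeholder credential
		r.Region = "us-east-1"                 // placeholder region
		r.Bucket = "example-bucket"            // placeholder bucket
		r.Path = "db"                          // key prefix inside the bucket

		// Start launches the background monitor & retention goroutines;
		// Stop cancels them and waits for them to finish.
		r.Start(context.Background())
		defer r.Stop()
	}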