diff --git a/s3/s3.go b/s3/s3.go
index 90754a2..9366af4 100644
--- a/s3/s3.go
+++ b/s3/s3.go
@@ -20,13 +20,15 @@ import (
 	"github.com/aws/aws-sdk-go/service/s3/s3manager"
 	"github.com/benbjohnson/litestream"
 	"github.com/benbjohnson/litestream/internal"
-	"github.com/davecgh/go-spew/spew"
 )
 
 const (
 	DefaultRetentionInterval = 1 * time.Hour
 )
 
+// MaxKeys is the maximum number of keys S3 accepts per DeleteObjects batch.
+const MaxKeys = 1000
+
 var _ litestream.Replica = (*Replica)(nil)
 
 // Replica is a replica that replicates a DB to an S3 bucket.
@@ -417,8 +419,6 @@ func (r *Replica) retainer(ctx context.Context) {
 // CalcPos returns the position for the replica for the current generation.
 // Returns a zero value if there is no active generation.
 func (r *Replica) CalcPos(generation string) (pos litestream.Pos, err error) {
-	println("dbg/calcpos", generation)
-
 	if err := r.Init(context.Background()); err != nil {
 		return pos, err
 	}
@@ -429,7 +429,6 @@ func (r *Replica) CalcPos(generation string) (pos litestream.Pos, err error) {
 	if pos.Index, err = r.MaxSnapshotIndex(generation); err != nil {
 		return litestream.Pos{}, err
 	}
-	println("dbg/calcpos.snapshotindex", pos.Index)
 
 	index := -1
 	var offset int64
@@ -438,9 +437,6 @@ func (r *Replica) CalcPos(generation string) (pos litestream.Pos, err error) {
 		Prefix:    aws.String(r.WALDir(generation) + "/"),
 		Delimiter: aws.String("/"),
 	}, func(page *s3.ListObjectsOutput, lastPage bool) bool {
-		println("dbg/calcpos.page")
-		spew.Dump(page)
-
 		for _, obj := range page.Contents {
 			key := path.Base(*obj.Key)
 
@@ -548,7 +544,6 @@ func (r *Replica) Init(ctx context.Context) (err error) {
 }
 
 func (r *Replica) Sync(ctx context.Context) (err error) {
-	println("dbg/s3.sync")
 	if err := r.Init(ctx); err != nil {
 		return err
 	}
@@ -622,7 +617,6 @@ func (r *Replica) syncWAL(ctx context.Context) (err error) {
 		r.WALDir(rd.Pos().Generation),
 		litestream.FormatWALPathWithOffsetSize(rd.Pos().Index, rd.Pos().Offset, int64(len(b)))+".gz",
 	)
-	println("dbg/syncwal", walPath)
 
 	if _, err := r.uploader.Upload(&s3manager.UploadInput{
 		Bucket: aws.String(r.Bucket),
@@ -641,72 +635,90 @@ func (r *Replica) syncWAL(ctx context.Context) (err error) {
 }
 
 // SnapshotReader returns a reader for snapshot data at the given generation/index.
-// Returns os.ErrNotExist if no matching index is found.
 func (r *Replica) SnapshotReader(ctx context.Context, generation string, index int) (io.ReadCloser, error) {
-	dir := r.SnapshotDir(generation)
-	fis, err := ioutil.ReadDir(dir)
+	if err := r.Init(ctx); err != nil {
+		return nil, err
+	}
+
+	// Fetch the snapshot object; the body streams rather than buffering.
+	out, err := r.s3.GetObjectWithContext(ctx, &s3.GetObjectInput{
+		Bucket: aws.String(r.Bucket),
+		Key:    aws.String(r.SnapshotPath(generation, index)),
+	})
 	if err != nil {
 		return nil, err
 	}
 
-	for _, fi := range fis {
-		// Parse index from snapshot filename. Skip if no match.
-		idx, ext, err := litestream.ParseSnapshotPath(fi.Name())
-		if err != nil || index != idx {
-			continue
-		}
-
-		// Open & return the file handle if uncompressed.
-		f, err := os.Open(path.Join(dir, fi.Name()))
-		if err != nil {
-			return nil, err
-		} else if ext == ".snapshot" {
-			return f, nil // not compressed, return as-is.
-		}
-		// assert(ext == ".snapshot.gz", "invalid snapshot extension")
-
-		// If compressed, wrap in a gzip reader and return with wrapper to
-		// ensure that the underlying file is closed.
-		r, err := gzip.NewReader(f)
-		if err != nil {
-			f.Close()
-			return nil, err
-		}
-		return internal.NewReadCloser(r, f), nil
+	// Decompress the snapshot file.
+	gr, err := gzip.NewReader(out.Body)
+	if err != nil {
+		out.Body.Close()
+		return nil, err
 	}
 
-	return nil, os.ErrNotExist
+	return internal.NewReadCloser(gr, out.Body), nil
 }
 
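Since the S3 replica only uploads gzipped snapshots, the rewritten SnapshotReader always streams through a gzip reader. A minimal sketch of how a restore path might consume it follows; the `snapshotReaderFunc` type and `restoreSnapshotToFile` helper are hypothetical, and only the method signature comes from this patch.

```go
// A hypothetical restore helper built on the method above. snapshotReaderFunc
// mirrors (*Replica).SnapshotReader's signature so the sketch stays
// self-contained; nothing here exists in the patch itself.
package restore

import (
	"context"
	"io"
	"os"
)

type snapshotReaderFunc func(ctx context.Context, generation string, index int) (io.ReadCloser, error)

// restoreSnapshotToFile copies one decompressed snapshot to dst.
func restoreSnapshotToFile(ctx context.Context, read snapshotReaderFunc, generation string, index int, dst string) error {
	rc, err := read(ctx, generation, index)
	if err != nil {
		return err
	}
	defer rc.Close() // closes both the gzip reader and the S3 response body

	f, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer f.Close()

	// The reader already yields decompressed bytes, so a plain copy suffices.
	_, err = io.Copy(f, rc)
	return err
}
```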
 // WALReader returns a reader for WAL data at the given index.
 // Returns os.ErrNotExist if no matching index is found.
 func (r *Replica) WALReader(ctx context.Context, generation string, index int) (io.ReadCloser, error) {
-	panic("TODO")
-	/*
-		filename := r.WALPath(generation, index)
+	if err := r.Init(ctx); err != nil {
+		return nil, err
+	}
 
-		// Attempt to read uncompressed file first.
-		f, err := os.Open(filename)
-		if err == nil {
-			return f, nil // file exist, return
-		} else if err != nil && !os.IsNotExist(err) {
-			return nil, err
+	// Collect all segment files for this WAL index.
+	var keys []string
+	var offset int64
+	var innerErr error
+	if err := r.s3.ListObjectsPages(&s3.ListObjectsInput{
+		Bucket: aws.String(r.Bucket),
+		Prefix: aws.String(path.Join(r.WALDir(generation), fmt.Sprintf("%016x_", index))),
+	}, func(page *s3.ListObjectsOutput, lastPage bool) bool {
+		for _, obj := range page.Contents {
+			// Read the offset & size from the filename. We need to check this
+			// against a running offset to ensure there are no gaps.
+			_, off, sz, _, err := litestream.ParseWALPath(path.Base(*obj.Key))
+			if err != nil {
+				continue
+			} else if off != offset {
+				innerErr = fmt.Errorf("out of sequence wal segments: %s/%016x", generation, index)
+				return false
+			}
+
+			keys = append(keys, *obj.Key)
+			offset += sz
 		}
+		return true
+	}); err != nil {
+		return nil, err
+	} else if innerErr != nil {
+		return nil, innerErr
+	}
 
-		// Otherwise read the compressed file. Return error if file doesn't exist.
-		f, err = os.Open(filename + ".gz")
+	// Honor the contract documented above: no segments means no such index.
+	if len(keys) == 0 {
+		return nil, os.ErrNotExist
+	}
+
+	// Open each file and concatenate into a multi-reader.
+	var mrc multiReadCloser
+	for _, key := range keys {
+		// Fetch the segment object; the body streams rather than buffering.
+		out, err := r.s3.GetObjectWithContext(ctx, &s3.GetObjectInput{
+			Bucket: aws.String(r.Bucket),
+			Key:    aws.String(key),
+		})
 		if err != nil {
+			mrc.Close()
 			return nil, err
 		}
 
-		// If compressed, wrap in a gzip reader and return with wrapper to
-		// ensure that the underlying file is closed.
-		rd, err := gzip.NewReader(f)
+		// Decompress the WAL segment.
+		gr, err := gzip.NewReader(out.Body)
 		if err != nil {
-			f.Close()
+			out.Body.Close()
+			mrc.Close()
 			return nil, err
 		}
-		return internal.NewReadCloser(rd, f), nil
-	*/
+
+		mrc.readers = append(mrc.readers, internal.NewReadCloser(gr, out.Body))
+	}
+
+	return &mrc, nil
 }
 
 // EnforceRetention forces a new snapshot once the retention interval has passed.
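The contiguity check is the load-bearing part of WALReader: each segment filename encodes (index, offset, size), and any hole between one segment's end and the next segment's start means the WAL cannot be safely reassembled. Here is a standalone sketch of that invariant, with the parse result simplified to a struct; the names are illustrative, and `litestream.ParseWALPath` is the real parser.

```go
// Standalone sketch of the gap check performed in WALReader's page callback.
package main

import "fmt"

// segment stands in for the (offset, size) pair parsed from a WAL key.
type segment struct{ offset, size int64 }

// verifyContiguous fails if the segments, in key order, do not form one
// gap-free byte range starting at offset zero.
func verifyContiguous(segs []segment) error {
	var next int64
	for _, s := range segs {
		if s.offset != next {
			return fmt.Errorf("out of sequence wal segments: want offset %d, got %d", next, s.offset)
		}
		next += s.size
	}
	return nil
}

func main() {
	ok := []segment{{0, 4096}, {4096, 512}, {4608, 4096}}
	bad := []segment{{0, 4096}, {8192, 512}} // hole after the first segment
	fmt.Println(verifyContiguous(ok))  // <nil>
	fmt.Println(verifyContiguous(bad)) // out of sequence wal segments: ...
}
```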
@@ -752,44 +764,63 @@ func (r *Replica) EnforceRetention(ctx context.Context) (err error) {
 		// Delete generations if it has no snapshots being retained.
 		if snapshot == nil {
 			log.Printf("%s(%s): generation %q has no retained snapshots, deleting", r.db.Path(), r.Name(), generation)
-			if err := os.RemoveAll(r.GenerationDir(generation)); err != nil {
+			if err := r.deleteGenerationBefore(ctx, generation, -1); err != nil {
 				return fmt.Errorf("cannot delete generation %q dir: %w", generation, err)
 			}
 			continue
 		}
 
 		// Otherwise delete all snapshots & WAL files before a lowest retained index.
-		if err := r.deleteGenerationSnapshotsBefore(ctx, generation, snapshot.Index); err != nil {
-			return fmt.Errorf("cannot delete generation %q snapshots before index %d: %w", generation, snapshot.Index, err)
-		} else if err := r.deleteGenerationWALBefore(ctx, generation, snapshot.Index); err != nil {
-			return fmt.Errorf("cannot delete generation %q wal before index %d: %w", generation, snapshot.Index, err)
+		if err := r.deleteGenerationBefore(ctx, generation, snapshot.Index); err != nil {
+			return fmt.Errorf("cannot delete generation %q files before index %d: %w", generation, snapshot.Index, err)
 		}
 	}
 
 	return nil
 }
 
-// deleteGenerationSnapshotsBefore deletes snapshot before a given index.
-func (r *Replica) deleteGenerationSnapshotsBefore(ctx context.Context, generation string, index int) (err error) {
-	dir := r.SnapshotDir(generation)
+// deleteGenerationBefore deletes all snapshots & WAL files before the given
+// index within a generation. An index of -1 deletes the entire generation.
+func (r *Replica) deleteGenerationBefore(ctx context.Context, generation string, index int) (err error) {
+	// Collect all files for the generation.
+	var objIDs []*s3.ObjectIdentifier
+	if err := r.s3.ListObjectsPages(&s3.ListObjectsInput{
+		Bucket: aws.String(r.Bucket),
+		Prefix: aws.String(r.GenerationDir(generation)),
+	}, func(page *s3.ListObjectsOutput, lastPage bool) bool {
+		for _, obj := range page.Contents {
+			// Skip files at or after the retained index; index -1 deletes all.
+			if index != -1 {
+				if idx, _, err := litestream.ParseSnapshotPath(path.Base(*obj.Key)); err == nil && idx >= index {
+					continue
+				} else if idx, _, _, _, err := litestream.ParseWALPath(path.Base(*obj.Key)); err == nil && idx >= index {
+					continue
+				}
+			}
 
-	fis, err := ioutil.ReadDir(dir)
-	if os.IsNotExist(err) {
-		return nil
-	} else if err != nil {
+			objIDs = append(objIDs, &s3.ObjectIdentifier{Key: obj.Key})
+		}
+		return true
+	}); err != nil {
 		return err
 	}
 
-	for _, fi := range fis {
-		idx, _, err := litestream.ParseSnapshotPath(fi.Name())
-		if err != nil {
-			continue
-		} else if idx >= index {
-			continue
+	// Delete the files in batches of at most MaxKeys per request.
+	for i := 0; i < len(objIDs); i += MaxKeys {
+		j := i + MaxKeys
+		if j > len(objIDs) {
+			j = len(objIDs)
 		}
 
-		log.Printf("%s(%s): generation %q snapshot no longer retained, deleting %s", r.db.Path(), r.Name(), generation, fi.Name())
-		if err := os.Remove(path.Join(dir, fi.Name())); err != nil {
+		for _, objID := range objIDs[i:j] {
+			log.Printf("%s(%s): generation %q file no longer retained, deleting %s", r.db.Path(), r.Name(), generation, path.Base(*objID.Key))
+		}
+
+		if _, err := r.s3.DeleteObjectsWithContext(ctx, &s3.DeleteObjectsInput{
+			Bucket: aws.String(r.Bucket),
+			Delete: &s3.Delete{
+				Objects: objIDs[i:j],
+				Quiet:   aws.Bool(true),
+			},
+		}); err != nil {
 			return err
 		}
 	}
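The i/j loop above is the usual slice-batching idiom: S3's DeleteObjects API rejects requests with more than 1,000 keys, which is what MaxKeys encodes. A standalone illustration of the chunking arithmetic follows; the key format is made up, and only the MaxKeys constant matches the patch.

```go
// Standalone illustration of batching keys into MaxKeys-sized requests.
package main

import "fmt"

const MaxKeys = 1000

// chunk splits keys into runs of at most MaxKeys, mirroring the i/j loop
// in deleteGenerationBefore.
func chunk(keys []string) [][]string {
	var batches [][]string
	for i := 0; i < len(keys); i += MaxKeys {
		j := i + MaxKeys
		if j > len(keys) {
			j = len(keys)
		}
		batches = append(batches, keys[i:j])
	}
	return batches
}

func main() {
	keys := make([]string, 2500)
	for i := range keys {
		keys[i] = fmt.Sprintf("generations/0000/wal/%016x.gz", i)
	}
	for _, b := range chunk(keys) {
		fmt.Println(len(b)) // 1000, 1000, 500
	}
}
```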
@@ -797,30 +828,36 @@ func (r *Replica) deleteGenerationSnapshotsBefore(ctx context.Context, generatio
 	return nil
 }
 
-// deleteGenerationWALBefore deletes WAL files before a given index.
-func (r *Replica) deleteGenerationWALBefore(ctx context.Context, generation string, index int) (err error) {
-	dir := r.WALDir(generation)
-
-	fis, err := ioutil.ReadDir(dir)
-	if os.IsNotExist(err) {
-		return nil
-	} else if err != nil {
-		return err
-	}
-
-	for _, fi := range fis {
-		idx, _, _, _, err := litestream.ParseWALPath(fi.Name())
-		if err != nil {
-			continue
-		} else if idx >= index {
-			continue
-		}
-
-		log.Printf("%s(%s): generation %q wal no longer retained, deleting %s", r.db.Path(), r.Name(), generation, fi.Name())
-		if err := os.Remove(path.Join(dir, fi.Name())); err != nil {
-			return err
-		}
-	}
-
-	return nil
+// multiReadCloser concatenates multiple io.ReadClosers into one, closing
+// each underlying reader as it is exhausted.
+type multiReadCloser struct {
+	readers []io.ReadCloser
+}
+
+// Read reads from the first remaining reader. Readers are closed and popped
+// as they hit io.EOF; io.EOF is only surfaced once all readers are drained.
+func (mr *multiReadCloser) Read(p []byte) (n int, err error) {
+	for len(mr.readers) > 0 {
+		n, err = mr.readers[0].Read(p)
+		if err == io.EOF {
+			if e := mr.readers[0].Close(); e != nil {
+				return n, e
+			}
+			mr.readers[0] = nil // release for GC
+			mr.readers = mr.readers[1:]
+		}
+
+		if n > 0 || err != io.EOF {
+			// Suppress EOF if more readers remain so the caller keeps reading.
+			if err == io.EOF && len(mr.readers) > 0 {
+				err = nil
+			}
+			return
+		}
+	}
+	return 0, io.EOF
+}
+
+// Close closes any remaining readers, returning the first error encountered.
+func (mr *multiReadCloser) Close() (err error) {
+	for _, r := range mr.readers {
+		if e := r.Close(); e != nil && err == nil {
+			err = e
+		}
+	}
+	return err
 }
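A sketch of the new type's contract as an example test, using in-memory readers as stand-ins for the S3 response bodies: bytes from each child reader appear in order, and io.EOF only surfaces at the very end. It assumes Go 1.16+ for io.NopCloser and io.ReadAll, and would live in this package since multiReadCloser is unexported.

```go
package s3

import (
	"fmt"
	"io"
	"strings"
)

func Example_multiReadCloser() {
	mrc := &multiReadCloser{readers: []io.ReadCloser{
		io.NopCloser(strings.NewReader("wal segment 0|")),
		io.NopCloser(strings.NewReader("wal segment 1|")),
		io.NopCloser(strings.NewReader("wal segment 2")),
	}}
	defer mrc.Close() // no-op here: Read closes each reader as it drains

	b, err := io.ReadAll(mrc)
	fmt.Println(string(b), err)
	// Output: wal segment 0|wal segment 1|wal segment 2 <nil>
}
```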