Refactor replica system

This commit is contained in:
Ben Johnson
2021-05-20 18:37:54 -06:00
parent 8685e9f2d1
commit fb80bc10ae
24 changed files with 4338 additions and 2978 deletions

746
s3/replica_client.go Normal file
View File

@@ -0,0 +1,746 @@
package s3
import (
"context"
"crypto/tls"
"fmt"
"io"
"net/http"
"os"
"path"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/benbjohnson/litestream"
"github.com/benbjohnson/litestream/internal"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"golang.org/x/sync/errgroup"
)
// MaxKeys is the number of keys S3 can operate on per batch.
const MaxKeys = 1000
// DefaultRegion is the region used if one is not specified.
const DefaultRegion = "us-east-1"
var _ litestream.ReplicaClient = (*ReplicaClient)(nil)
// ReplicaClient is a client for writing snapshots & WAL segments to disk.
type ReplicaClient struct {
mu sync.Mutex
s3 *s3.S3 // s3 service
uploader *s3manager.Uploader
// AWS authentication keys.
AccessKeyID string
SecretAccessKey string
// S3 bucket information
Region string
Bucket string
Path string
Endpoint string
ForcePathStyle bool
SkipVerify bool
}
// NewReplicaClient returns a new instance of ReplicaClient.
func NewReplicaClient() *ReplicaClient {
return &ReplicaClient{}
}
// Type returns "s3" as the client type.
func (c *ReplicaClient) Type() string {
return "s3"
}
// GenerationsDir returns the path to a generation root directory.
func (c *ReplicaClient) GenerationsDir() string {
return path.Join(c.Path, "generations")
}
// GenerationDir returns the path to a generation's root directory.
func (c *ReplicaClient) GenerationDir(generation string) (string, error) {
dir := c.GenerationsDir()
if generation == "" {
return "", fmt.Errorf("generation required")
}
return path.Join(dir, generation), nil
}
// SnapshotsDir returns the path to a generation's snapshot directory.
func (c *ReplicaClient) SnapshotsDir(generation string) (string, error) {
dir, err := c.GenerationDir(generation)
if err != nil {
return "", err
}
return path.Join(dir, "snapshots"), nil
}
// SnapshotPath returns the path to an uncompressed snapshot file.
func (c *ReplicaClient) SnapshotPath(generation string, index int) (string, error) {
dir, err := c.SnapshotsDir(generation)
if err != nil {
return "", err
}
return path.Join(dir, litestream.FormatSnapshotPath(index)), nil
}
// WALDir returns the path to a generation's WAL directory
func (c *ReplicaClient) WALDir(generation string) (string, error) {
dir, err := c.GenerationDir(generation)
if err != nil {
return "", err
}
return path.Join(dir, "wal"), nil
}
// WALSegmentPath returns the path to a WAL segment file.
func (c *ReplicaClient) WALSegmentPath(generation string, index int, offset int64) (string, error) {
dir, err := c.WALDir(generation)
if err != nil {
return "", err
}
return path.Join(dir, litestream.FormatWALSegmentPath(index, offset)), nil
}
// Init initializes the connection to S3. No-op if already initialized.
func (c *ReplicaClient) Init(ctx context.Context) (err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.s3 != nil {
return nil
}
// Look up region if not specified and no endpoint is used.
// Endpoints are typically used for non-S3 object stores and do not
// necessarily require a region.
region := c.Region
if region == "" {
if c.Endpoint == "" {
if region, err = c.findBucketRegion(ctx, c.Bucket); err != nil {
return fmt.Errorf("cannot lookup bucket region: %w", err)
}
} else {
region = DefaultRegion // default for non-S3 object stores
}
}
// Create new AWS session.
config := c.config()
if region != "" {
config.Region = aws.String(region)
}
sess, err := session.NewSession(config)
if err != nil {
return fmt.Errorf("cannot create aws session: %w", err)
}
c.s3 = s3.New(sess)
c.uploader = s3manager.NewUploader(sess)
return nil
}
// config returns the AWS configuration. Uses the default credential chain
// unless a key/secret are explicitly set.
func (c *ReplicaClient) config() *aws.Config {
config := defaults.Get().Config
if c.AccessKeyID != "" || c.SecretAccessKey != "" {
config.Credentials = credentials.NewStaticCredentials(c.AccessKeyID, c.SecretAccessKey, "")
}
if c.Endpoint != "" {
config.Endpoint = aws.String(c.Endpoint)
}
if c.ForcePathStyle {
config.S3ForcePathStyle = aws.Bool(c.ForcePathStyle)
}
if c.SkipVerify {
config.HTTPClient = &http.Client{Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}}
}
return config
}
func (c *ReplicaClient) findBucketRegion(ctx context.Context, bucket string) (string, error) {
// Connect to US standard region to fetch info.
config := c.config()
config.Region = aws.String(DefaultRegion)
sess, err := session.NewSession(config)
if err != nil {
return "", err
}
// Fetch bucket location, if possible. Must be bucket owner.
// This call can return a nil location which means it's in us-east-1.
if out, err := s3.New(sess).GetBucketLocation(&s3.GetBucketLocationInput{
Bucket: aws.String(bucket),
}); err != nil {
return "", err
} else if out.LocationConstraint != nil {
return *out.LocationConstraint, nil
}
return DefaultRegion, nil
}
// Generations returns a list of available generation names.
func (c *ReplicaClient) Generations(ctx context.Context) ([]string, error) {
if err := c.Init(ctx); err != nil {
return nil, err
}
var generations []string
if err := c.s3.ListObjectsPagesWithContext(ctx, &s3.ListObjectsInput{
Bucket: aws.String(c.Bucket),
Prefix: aws.String(c.GenerationsDir() + "/"),
Delimiter: aws.String("/"),
}, func(page *s3.ListObjectsOutput, lastPage bool) bool {
operationTotalCounterVec.WithLabelValues("LIST").Inc()
for _, prefix := range page.CommonPrefixes {
name := path.Base(*prefix.Prefix)
if !litestream.IsGenerationName(name) {
continue
}
generations = append(generations, name)
}
return true
}); err != nil {
return nil, err
}
return generations, nil
}
// DeleteGeneration deletes all snapshots & WAL segments within a generation.
func (c *ReplicaClient) DeleteGeneration(ctx context.Context, generation string) error {
if err := c.Init(ctx); err != nil {
return err
}
dir, err := c.GenerationDir(generation)
if err != nil {
return fmt.Errorf("cannot determine generation directory path: %w", err)
}
// Collect all files for the generation.
var objIDs []*s3.ObjectIdentifier
if err := c.s3.ListObjectsPagesWithContext(ctx, &s3.ListObjectsInput{
Bucket: aws.String(c.Bucket),
Prefix: aws.String(dir),
}, func(page *s3.ListObjectsOutput, lastPage bool) bool {
operationTotalCounterVec.WithLabelValues("LIST").Inc()
for _, obj := range page.Contents {
objIDs = append(objIDs, &s3.ObjectIdentifier{Key: obj.Key})
}
return true
}); err != nil {
return err
}
// Delete all files in batches.
for len(objIDs) > 0 {
n := MaxKeys
if len(objIDs) < n {
n = len(objIDs)
}
if _, err := c.s3.DeleteObjectsWithContext(ctx, &s3.DeleteObjectsInput{
Bucket: aws.String(c.Bucket),
Delete: &s3.Delete{Objects: objIDs[:n], Quiet: aws.Bool(true)},
}); err != nil {
return err
}
operationTotalCounterVec.WithLabelValues("DELETE").Inc()
objIDs = objIDs[n:]
}
// log.Printf("%s(%s): retainer: deleting generation: %s", r.db.Path(), r.Name(), generation)
return nil
}
// Snapshots returns an iterator over all available snapshots for a generation.
func (c *ReplicaClient) Snapshots(ctx context.Context, generation string) (litestream.SnapshotIterator, error) {
if err := c.Init(ctx); err != nil {
return nil, err
}
return newSnapshotIterator(ctx, c, generation), nil
}
// WriteSnapshot writes LZ4 compressed data from rd into a file on disk.
func (c *ReplicaClient) WriteSnapshot(ctx context.Context, generation string, index int, rd io.Reader) (info litestream.SnapshotInfo, err error) {
if err := c.Init(ctx); err != nil {
return info, err
}
key, err := c.SnapshotPath(generation, index)
if err != nil {
return info, fmt.Errorf("cannot determine snapshot path: %w", err)
}
startTime := time.Now()
rc := internal.NewReadCounter(rd)
if _, err := c.uploader.UploadWithContext(ctx, &s3manager.UploadInput{
Bucket: aws.String(c.Bucket),
Key: aws.String(key),
Body: rc,
}); err != nil {
return info, err
}
operationTotalCounterVec.WithLabelValues("PUT").Inc()
operationBytesCounterVec.WithLabelValues("PUT").Add(float64(rc.N()))
// log.Printf("%s(%s): snapshot: creating %s/%08x t=%s", r.db.Path(), r.Name(), generation, index, time.Since(startTime).Truncate(time.Millisecond))
return litestream.SnapshotInfo{
Generation: generation,
Index: index,
Size: rc.N(),
CreatedAt: startTime.UTC(),
}, nil
}
// SnapshotReader returns a reader for snapshot data at the given generation/index.
func (c *ReplicaClient) SnapshotReader(ctx context.Context, generation string, index int) (io.ReadCloser, error) {
if err := c.Init(ctx); err != nil {
return nil, err
}
key, err := c.SnapshotPath(generation, index)
if err != nil {
return nil, fmt.Errorf("cannot determine snapshot path: %w", err)
}
out, err := c.s3.GetObjectWithContext(ctx, &s3.GetObjectInput{
Bucket: aws.String(c.Bucket),
Key: aws.String(key),
})
if isNotExists(err) {
return nil, os.ErrNotExist
} else if err != nil {
return nil, err
}
operationTotalCounterVec.WithLabelValues("GET").Inc()
operationBytesCounterVec.WithLabelValues("GET").Add(float64(*out.ContentLength))
return out.Body, nil
}
// DeleteSnapshot deletes a snapshot with the given generation & index.
func (c *ReplicaClient) DeleteSnapshot(ctx context.Context, generation string, index int) error {
if err := c.Init(ctx); err != nil {
return err
}
key, err := c.SnapshotPath(generation, index)
if err != nil {
return fmt.Errorf("cannot determine snapshot path: %w", err)
}
if _, err := c.s3.DeleteObjectsWithContext(ctx, &s3.DeleteObjectsInput{
Bucket: aws.String(c.Bucket),
Delete: &s3.Delete{Objects: []*s3.ObjectIdentifier{{Key: &key}}, Quiet: aws.Bool(true)},
}); err != nil {
return err
}
operationTotalCounterVec.WithLabelValues("DELETE").Inc()
return nil
}
// WALSegments returns an iterator over all available WAL files for a generation.
func (c *ReplicaClient) WALSegments(ctx context.Context, generation string) (litestream.WALSegmentIterator, error) {
if err := c.Init(ctx); err != nil {
return nil, err
}
return newWALSegmentIterator(ctx, c, generation), nil
}
// WriteWALSegment writes LZ4 compressed data from rd into a file on disk.
func (c *ReplicaClient) WriteWALSegment(ctx context.Context, pos litestream.Pos, rd io.Reader) (info litestream.WALSegmentInfo, err error) {
if err := c.Init(ctx); err != nil {
return info, err
}
key, err := c.WALSegmentPath(pos.Generation, pos.Index, pos.Offset)
if err != nil {
return info, fmt.Errorf("cannot determine wal segment path: %w", err)
}
startTime := time.Now()
rc := internal.NewReadCounter(rd)
if _, err := c.uploader.UploadWithContext(ctx, &s3manager.UploadInput{
Bucket: aws.String(c.Bucket),
Key: aws.String(key),
Body: rc,
}); err != nil {
return info, err
}
operationTotalCounterVec.WithLabelValues("PUT").Inc()
operationBytesCounterVec.WithLabelValues("PUT").Add(float64(rc.N()))
return litestream.WALSegmentInfo{
Generation: pos.Generation,
Index: pos.Index,
Offset: pos.Offset,
Size: rc.N(),
CreatedAt: startTime.UTC(),
}, nil
}
// WALSegmentReader returns a reader for a section of WAL data at the given index.
// Returns os.ErrNotExist if no matching index/offset is found.
func (c *ReplicaClient) WALSegmentReader(ctx context.Context, pos litestream.Pos) (io.ReadCloser, error) {
if err := c.Init(ctx); err != nil {
return nil, err
}
key, err := c.WALSegmentPath(pos.Generation, pos.Index, pos.Offset)
if err != nil {
return nil, fmt.Errorf("cannot determine wal segment path: %w", err)
}
out, err := c.s3.GetObjectWithContext(ctx, &s3.GetObjectInput{
Bucket: aws.String(c.Bucket),
Key: aws.String(key),
})
if isNotExists(err) {
return nil, os.ErrNotExist
} else if err != nil {
return nil, err
}
operationTotalCounterVec.WithLabelValues("GET").Inc()
operationBytesCounterVec.WithLabelValues("GET").Add(float64(*out.ContentLength))
return out.Body, nil
}
// DeleteWALSegments deletes WAL segments with at the given positions.
func (c *ReplicaClient) DeleteWALSegments(ctx context.Context, a []litestream.Pos) error {
if err := c.Init(ctx); err != nil {
return err
}
objIDs := make([]*s3.ObjectIdentifier, MaxKeys)
for len(a) > 0 {
n := MaxKeys
if len(a) < n {
n = len(a)
}
// Generate a batch of object IDs for deleting the WAL segments.
for i, pos := range a[:n] {
key, err := c.WALSegmentPath(pos.Generation, pos.Index, pos.Offset)
if err != nil {
return fmt.Errorf("cannot determine wal segment path: %w", err)
}
objIDs[i] = &s3.ObjectIdentifier{Key: &key}
}
// Delete S3 objects in bulk.
if _, err := c.s3.DeleteObjectsWithContext(ctx, &s3.DeleteObjectsInput{
Bucket: aws.String(c.Bucket),
Delete: &s3.Delete{Objects: objIDs[:n], Quiet: aws.Bool(true)},
}); err != nil {
return err
}
operationTotalCounterVec.WithLabelValues("DELETE").Inc()
a = a[n:]
}
return nil
}
// DeleteAll deletes everything on the remote path. Mainly used for testing.
func (c *ReplicaClient) DeleteAll(ctx context.Context) error {
if err := c.Init(ctx); err != nil {
return err
}
prefix := c.Path
if prefix != "" {
prefix += "/"
}
// Collect all files for the generation.
var objIDs []*s3.ObjectIdentifier
if err := c.s3.ListObjectsPagesWithContext(ctx, &s3.ListObjectsInput{
Bucket: aws.String(c.Bucket),
Prefix: aws.String(prefix),
}, func(page *s3.ListObjectsOutput, lastPage bool) bool {
operationTotalCounterVec.WithLabelValues("LIST").Inc()
for _, obj := range page.Contents {
objIDs = append(objIDs, &s3.ObjectIdentifier{Key: obj.Key})
}
return true
}); err != nil {
return err
}
// Delete all files in batches.
for len(objIDs) > 0 {
n := MaxKeys
if len(objIDs) < n {
n = len(objIDs)
}
if _, err := c.s3.DeleteObjectsWithContext(ctx, &s3.DeleteObjectsInput{
Bucket: aws.String(c.Bucket),
Delete: &s3.Delete{Objects: objIDs[:n], Quiet: aws.Bool(true)},
}); err != nil {
return err
}
operationTotalCounterVec.WithLabelValues("DELETE").Inc()
objIDs = objIDs[n:]
}
return nil
}
type snapshotIterator struct {
client *ReplicaClient
generation string
ch chan litestream.SnapshotInfo
g errgroup.Group
ctx context.Context
cancel func()
info litestream.SnapshotInfo
err error
}
func newSnapshotIterator(ctx context.Context, client *ReplicaClient, generation string) *snapshotIterator {
itr := &snapshotIterator{
client: client,
generation: generation,
ch: make(chan litestream.SnapshotInfo),
}
itr.ctx, itr.cancel = context.WithCancel(ctx)
itr.g.Go(itr.fetch)
return itr
}
// fetch runs in a separate goroutine to fetch pages of objects and stream them to a channel.
func (itr *snapshotIterator) fetch() error {
defer close(itr.ch)
dir, err := itr.client.SnapshotsDir(itr.generation)
if err != nil {
return fmt.Errorf("cannot determine snapshot directory path: %w", err)
}
return itr.client.s3.ListObjectsPagesWithContext(itr.ctx, &s3.ListObjectsInput{
Bucket: aws.String(itr.client.Bucket),
Prefix: aws.String(dir + "/"),
Delimiter: aws.String("/"),
}, func(page *s3.ListObjectsOutput, lastPage bool) bool {
operationTotalCounterVec.WithLabelValues("LIST").Inc()
for _, obj := range page.Contents {
key := path.Base(*obj.Key)
index, err := litestream.ParseSnapshotPath(key)
if err != nil {
continue
}
info := litestream.SnapshotInfo{
Generation: itr.generation,
Index: index,
Size: *obj.Size,
CreatedAt: obj.LastModified.UTC(),
}
select {
case <-itr.ctx.Done():
case itr.ch <- info:
}
}
return true
})
}
func (itr *snapshotIterator) Close() (err error) {
err = itr.err
// Cancel context and wait for error group to finish.
itr.cancel()
if e := itr.g.Wait(); e != nil && err == nil {
err = e
}
return err
}
func (itr *snapshotIterator) Next() bool {
// Exit if an error has already occurred.
if itr.err != nil {
return false
}
// Return false if context was canceled or if there are no more snapshots.
// Otherwise fetch the next snapshot and store it on the iterator.
select {
case <-itr.ctx.Done():
return false
case info, ok := <-itr.ch:
if !ok {
return false
}
itr.info = info
return true
}
}
func (itr *snapshotIterator) Err() error { return itr.err }
func (itr *snapshotIterator) Snapshot() litestream.SnapshotInfo {
return itr.info
}
type walSegmentIterator struct {
client *ReplicaClient
generation string
ch chan litestream.WALSegmentInfo
g errgroup.Group
ctx context.Context
cancel func()
info litestream.WALSegmentInfo
err error
}
func newWALSegmentIterator(ctx context.Context, client *ReplicaClient, generation string) *walSegmentIterator {
itr := &walSegmentIterator{
client: client,
generation: generation,
ch: make(chan litestream.WALSegmentInfo),
}
itr.ctx, itr.cancel = context.WithCancel(ctx)
itr.g.Go(itr.fetch)
return itr
}
// fetch runs in a separate goroutine to fetch pages of objects and stream them to a channel.
func (itr *walSegmentIterator) fetch() error {
defer close(itr.ch)
dir, err := itr.client.WALDir(itr.generation)
if err != nil {
return fmt.Errorf("cannot determine wal directory path: %w", err)
}
return itr.client.s3.ListObjectsPagesWithContext(itr.ctx, &s3.ListObjectsInput{
Bucket: aws.String(itr.client.Bucket),
Prefix: aws.String(dir + "/"),
Delimiter: aws.String("/"),
}, func(page *s3.ListObjectsOutput, lastPage bool) bool {
operationTotalCounterVec.WithLabelValues("LIST").Inc()
for _, obj := range page.Contents {
key := path.Base(*obj.Key)
index, offset, err := litestream.ParseWALSegmentPath(key)
if err != nil {
continue
}
info := litestream.WALSegmentInfo{
Generation: itr.generation,
Index: index,
Offset: offset,
Size: *obj.Size,
CreatedAt: obj.LastModified.UTC(),
}
select {
case <-itr.ctx.Done():
return false
case itr.ch <- info:
}
}
return true
})
}
func (itr *walSegmentIterator) Close() (err error) {
err = itr.err
// Cancel context and wait for error group to finish.
itr.cancel()
if e := itr.g.Wait(); e != nil && err == nil {
err = e
}
return err
}
func (itr *walSegmentIterator) Next() bool {
// Exit if an error has already occurred.
if itr.err != nil {
return false
}
// Return false if context was canceled or if there are no more segments.
// Otherwise fetch the next segment and store it on the iterator.
select {
case <-itr.ctx.Done():
return false
case info, ok := <-itr.ch:
if !ok {
return false
}
itr.info = info
return true
}
}
func (itr *walSegmentIterator) Err() error { return itr.err }
func (itr *walSegmentIterator) WALSegment() litestream.WALSegmentInfo {
return itr.info
}
func isNotExists(err error) bool {
switch err := err.(type) {
case awserr.Error:
return err.Code() == `NoSuchKey`
default:
return false
}
}
// S3 metrics.
var (
operationTotalCounterVec = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "litestream_s3_operation_total",
Help: "The number of S3 operations performed",
}, []string{"type"})
operationBytesCounterVec = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "litestream_s3_operation_bytes",
Help: "The number of bytes used by S3 operations",
}, []string{"type"})
)