Refactor Restore()

This commit refactors out the complexity of downloading ordered WAL
files in parallel to a type called `WALDownloader`. This makes it
easier to test the restore separately from the download.
Ben Johnson
2022-01-04 14:47:11 -07:00
parent 531e19ed6f
commit 3f0ec9fa9f
130 changed files with 2943 additions and 1254 deletions
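For orientation, here is a minimal sketch of how a caller might drive a restore after this refactor. It is grounded only in the call-site changes visible in the Validate() hunk below (the package-level FindSnapshotForIndex helper, the new package-level Restore function, and the Logger/LogPrefix fields of RestoreOptions). The WALDownloader type itself does not appear in this excerpt, and the restoreToIndex wrapper, its parameters, and the "restore: " log prefix are illustrative assumptions, not part of the commit.

package example

import (
	"context"
	"log"
	"os"

	"github.com/benbjohnson/litestream"
)

// restoreToIndex is a hypothetical wrapper showing the new call shape used by
// Replica.Validate(): find the snapshot covering the target index, then call
// the package-level Restore() with an explicit client, path, and index range.
func restoreToIndex(ctx context.Context, client litestream.ReplicaClient, outputPath, generation string, targetIndex int) error {
	// Locate the latest snapshot at or before the target WAL index.
	snapshotIndex, err := litestream.FindSnapshotForIndex(ctx, client, generation, targetIndex)
	if err != nil {
		return err
	}

	// Restore the snapshot plus WAL files up through targetIndex into outputPath.
	opt := litestream.RestoreOptions{
		Logger:    log.New(os.Stderr, "", 0),
		LogPrefix: "restore: ", // assumed prefix; Validate() uses r.logPrefix()
	}
	return litestream.Restore(ctx, client, outputPath, generation, snapshotIndex, targetIndex, opt)
}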


@@ -7,14 +7,12 @@ import (
"io"
"io/ioutil"
"log"
"math"
"os"
"path/filepath"
"sort"
"sync"
"time"
"github.com/benbjohnson/litestream/internal"
"github.com/pierrec/lz4/v4"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
@@ -144,6 +142,15 @@ func (r *Replica) Stop(hard bool) (err error) {
return err
}
// logPrefix returns the prefix used when logging from the replica.
// This includes the replica name as well as the database path, if available.
func (r *Replica) logPrefix() string {
if db := r.DB(); db != nil {
return fmt.Sprintf("%s(%s): ", db.Path(), r.Name())
}
return r.Name() + ": "
}
// Sync copies new WAL frames from the shadow WAL to the replica client.
func (r *Replica) Sync(ctx context.Context) (err error) {
// Clear last position if an error occurs during sync.
@@ -766,14 +773,18 @@ func (r *Replica) Validate(ctx context.Context) error {
return fmt.Errorf("cannot wait for replica: %w", err)
}
// Find latest snapshot that occurs before the index.
snapshotIndex, err := FindSnapshotForIndex(ctx, r.Client, pos.Generation, pos.Index-1)
if err != nil {
return fmt.Errorf("cannot find snapshot index: %w", err)
}
restorePath := filepath.Join(tmpdir, "replica")
if err := r.Restore(ctx, RestoreOptions{
OutputPath: restorePath,
ReplicaName: r.Name(),
Generation: pos.Generation,
Index: pos.Index - 1,
Logger: log.New(os.Stderr, "", 0),
}); err != nil {
opt := RestoreOptions{
Logger: log.New(os.Stderr, "", 0),
LogPrefix: r.logPrefix(),
}
if err := Restore(ctx, r.Client, restorePath, pos.Generation, snapshotIndex, pos.Index-1, opt); err != nil {
return fmt.Errorf("cannot restore: %w", err)
}
@@ -883,295 +894,6 @@ func (r *Replica) GenerationCreatedAt(ctx context.Context, generation string) (t
return min, itr.Close()
}
// GenerationTimeBounds returns the creation time & last updated time of a generation.
// Returns zero time if no snapshots or WAL segments exist.
func (r *Replica) GenerationTimeBounds(ctx context.Context, generation string) (createdAt, updatedAt time.Time, err error) {
// Iterate over snapshots.
sitr, err := r.Client.Snapshots(ctx, generation)
if err != nil {
return createdAt, updatedAt, err
}
defer sitr.Close()
for sitr.Next() {
info := sitr.Snapshot()
if createdAt.IsZero() || info.CreatedAt.Before(createdAt) {
createdAt = info.CreatedAt
}
if updatedAt.IsZero() || info.CreatedAt.After(updatedAt) {
updatedAt = info.CreatedAt
}
}
if err := sitr.Close(); err != nil {
return createdAt, updatedAt, err
}
// Iterate over WAL segments.
witr, err := r.Client.WALSegments(ctx, generation)
if err != nil {
return createdAt, updatedAt, err
}
defer witr.Close()
for witr.Next() {
info := witr.WALSegment()
if createdAt.IsZero() || info.CreatedAt.Before(createdAt) {
createdAt = info.CreatedAt
}
if updatedAt.IsZero() || info.CreatedAt.After(updatedAt) {
updatedAt = info.CreatedAt
}
}
if err := witr.Close(); err != nil {
return createdAt, updatedAt, err
}
return createdAt, updatedAt, nil
}
// CalcRestoreTarget returns a generation to restore from.
func (r *Replica) CalcRestoreTarget(ctx context.Context, opt RestoreOptions) (generation string, updatedAt time.Time, err error) {
var target struct {
generation string
updatedAt time.Time
}
generations, err := r.Client.Generations(ctx)
if err != nil {
return "", time.Time{}, fmt.Errorf("cannot fetch generations: %w", err)
}
// Search generations for one that contains the requested timestamp.
for _, generation := range generations {
// Skip generation if it does not match filter.
if opt.Generation != "" && generation != opt.Generation {
continue
}
// Determine the time bounds for the generation.
createdAt, updatedAt, err := r.GenerationTimeBounds(ctx, generation)
if err != nil {
return "", time.Time{}, fmt.Errorf("generation created at: %w", err)
}
// Skip if it does not contain timestamp.
if !opt.Timestamp.IsZero() {
if opt.Timestamp.Before(createdAt) || opt.Timestamp.After(updatedAt) {
continue
}
}
// Use the latest replica if we have multiple candidates.
if !updatedAt.After(target.updatedAt) {
continue
}
target.generation = generation
target.updatedAt = updatedAt
}
return target.generation, target.updatedAt, nil
}
// Restore restores the database from a replica based on the options given.
// This method will restore into opt.OutputPath, if specified, or into the
// DB's original database path. It can optionally restore from a specific
// replica or generation or it will automatically choose the best one. Finally,
// a timestamp can be specified to restore the database to a specific
// point-in-time.
func (r *Replica) Restore(ctx context.Context, opt RestoreOptions) (err error) {
// Validate options.
if opt.OutputPath == "" {
if r.db.path == "" {
return fmt.Errorf("output path required")
}
opt.OutputPath = r.db.path
} else if opt.Generation == "" && opt.Index != math.MaxInt32 {
return fmt.Errorf("must specify generation when restoring to index")
} else if opt.Index != math.MaxInt32 && !opt.Timestamp.IsZero() {
return fmt.Errorf("cannot specify index & timestamp to restore")
}
// Ensure logger exists.
logger := opt.Logger
if logger == nil {
logger = log.New(ioutil.Discard, "", 0)
}
logPrefix := r.Name()
if db := r.DB(); db != nil {
logPrefix = fmt.Sprintf("%s(%s)", db.Path(), r.Name())
}
// Ensure output path does not already exist.
if _, err := os.Stat(opt.OutputPath); err == nil {
return fmt.Errorf("cannot restore, output path already exists: %s", opt.OutputPath)
} else if err != nil && !os.IsNotExist(err) {
return err
}
// Find latest snapshot that occurs before timestamp or index.
var minWALIndex int
if opt.Index < math.MaxInt32 {
if minWALIndex, err = r.SnapshotIndexByIndex(ctx, opt.Generation, opt.Index); err != nil {
return fmt.Errorf("cannot find snapshot index: %w", err)
}
} else {
if minWALIndex, err = r.SnapshotIndexAt(ctx, opt.Generation, opt.Timestamp); err != nil {
return fmt.Errorf("cannot find snapshot index by timestamp: %w", err)
}
}
// Compute list of offsets for each WAL index.
walSegmentMap, err := r.walSegmentMap(ctx, opt.Generation, opt.Index, opt.Timestamp)
if err != nil {
return fmt.Errorf("cannot find max wal index for restore: %w", err)
}
// Find the maximum WAL index that occurs before timestamp.
maxWALIndex := -1
for index := range walSegmentMap {
if index > maxWALIndex {
maxWALIndex = index
}
}
// Ensure that we found the specific index, if one was specified.
if opt.Index != math.MaxInt32 && opt.Index != maxWALIndex {
return fmt.Errorf("unable to locate index %d in generation %q, highest index was %d", opt.Index, opt.Generation, maxWALIndex)
}
// If no WAL files were found, mark this as a snapshot-only restore.
snapshotOnly := maxWALIndex == -1
// Initialize starting position.
pos := Pos{Generation: opt.Generation, Index: minWALIndex}
tmpPath := opt.OutputPath + ".tmp"
// Copy snapshot to output path.
logger.Printf("%s: restoring snapshot %s/%08x to %s", logPrefix, opt.Generation, minWALIndex, tmpPath)
if err := r.restoreSnapshot(ctx, pos.Generation, pos.Index, tmpPath); err != nil {
return fmt.Errorf("cannot restore snapshot: %w", err)
}
// If no WAL files available, move snapshot to final path & exit early.
if snapshotOnly {
logger.Printf("%s: snapshot only, finalizing database", logPrefix)
return os.Rename(tmpPath, opt.OutputPath)
}
// Begin processing WAL files.
logger.Printf("%s: restoring wal files: generation=%s index=[%08x,%08x]", logPrefix, opt.Generation, minWALIndex, maxWALIndex)
// Fill input channel with all WAL indexes to be loaded in order.
// Verify every index has at least one offset.
ch := make(chan int, maxWALIndex-minWALIndex+1)
for index := minWALIndex; index <= maxWALIndex; index++ {
if len(walSegmentMap[index]) == 0 {
return fmt.Errorf("missing WAL index: %s/%08x", opt.Generation, index)
}
ch <- index
}
close(ch)
// Track load state for each WAL.
var mu sync.Mutex
cond := sync.NewCond(&mu)
walStates := make([]walRestoreState, maxWALIndex-minWALIndex+1)
parallelism := opt.Parallelism
if parallelism < 1 {
parallelism = 1
}
// Download WAL files to disk in parallel.
g, ctx := errgroup.WithContext(ctx)
for i := 0; i < parallelism; i++ {
g.Go(func() error {
for {
select {
case <-ctx.Done():
cond.Broadcast()
return err
case index, ok := <-ch:
if !ok {
cond.Broadcast()
return nil
}
startTime := time.Now()
err := r.downloadWAL(ctx, opt.Generation, index, walSegmentMap[index], tmpPath)
if err != nil {
err = fmt.Errorf("cannot download wal %s/%08x: %w", opt.Generation, index, err)
}
// Mark index as ready-to-apply and notify applying code.
mu.Lock()
walStates[index-minWALIndex] = walRestoreState{ready: true, err: err}
mu.Unlock()
cond.Broadcast()
// Returning the error here will cancel the other goroutines.
if err != nil {
return err
}
logger.Printf("%s: downloaded wal %s/%08x elapsed=%s",
logPrefix, opt.Generation, index,
time.Since(startTime).String(),
)
}
}
})
}
// Apply WAL files in order as they are ready.
for index := minWALIndex; index <= maxWALIndex; index++ {
// Wait until next WAL file is ready to apply.
mu.Lock()
for !walStates[index-minWALIndex].ready {
if err := ctx.Err(); err != nil {
return err
}
cond.Wait()
}
if err := walStates[index-minWALIndex].err; err != nil {
return err
}
mu.Unlock()
// Apply WAL to database file.
startTime := time.Now()
if err = applyWAL(ctx, index, tmpPath); err != nil {
return fmt.Errorf("cannot apply wal: %w", err)
}
logger.Printf("%s: applied wal %s/%08x elapsed=%s",
logPrefix, opt.Generation, index,
time.Since(startTime).String(),
)
}
// Ensure all goroutines finish. All errors should have been handled during
// the processing of WAL files but this ensures that all processing is done.
if err := g.Wait(); err != nil {
return err
}
// Copy file to final location.
logger.Printf("%s: renaming database from temporary location", logPrefix)
if err := os.Rename(tmpPath, opt.OutputPath); err != nil {
return err
}
return nil
}
type walRestoreState struct {
ready bool
err error
}
// SnapshotIndexAt returns the highest index for a snapshot within a generation
// that occurs before timestamp. If timestamp is zero, returns the latest snapshot.
func (r *Replica) SnapshotIndexAt(ctx context.Context, generation string, timestamp time.Time) (int, error) {
@@ -1202,137 +924,19 @@ func (r *Replica) SnapshotIndexAt(ctx context.Context, generation string, timest
return snapshotIndex, nil
}
// SnapshotIndexByIndex returns the highest index for a snapshot within a generation
// that occurs before a given index. If index is MaxInt32, returns the latest snapshot.
func (r *Replica) SnapshotIndexByIndex(ctx context.Context, generation string, index int) (int, error) {
itr, err := r.Client.Snapshots(ctx, generation)
if err != nil {
return 0, err
}
defer itr.Close()
snapshotIndex := -1
for itr.Next() {
snapshot := itr.Snapshot()
if index < math.MaxInt32 && snapshot.Index > index {
continue // after index, skip
}
// Use snapshot if it is newer.
if snapshotIndex == -1 || snapshot.Index >= snapshotIndex {
snapshotIndex = snapshot.Index
}
}
if err := itr.Close(); err != nil {
return 0, err
} else if snapshotIndex == -1 {
return 0, ErrNoSnapshots
}
return snapshotIndex, nil
}
// walSegmentMap returns a map of WAL indices to their segments.
// Filters by a max timestamp or a max index.
func (r *Replica) walSegmentMap(ctx context.Context, generation string, maxIndex int, maxTimestamp time.Time) (map[int][]int64, error) {
itr, err := r.Client.WALSegments(ctx, generation)
if err != nil {
return nil, err
}
defer itr.Close()
m := make(map[int][]int64)
for itr.Next() {
info := itr.WALSegment()
// Exit if we go past the max timestamp or index.
if !maxTimestamp.IsZero() && info.CreatedAt.After(maxTimestamp) {
break // after max timestamp, skip
} else if info.Index > maxIndex {
break // after max index, skip
}
// Verify offsets are added in order.
offsets := m[info.Index]
if len(offsets) == 0 && info.Offset != 0 {
return nil, fmt.Errorf("missing initial wal segment: generation=%s index=%08x offset=%d", generation, info.Index, info.Offset)
} else if len(offsets) > 0 && offsets[len(offsets)-1] >= info.Offset {
return nil, fmt.Errorf("wal segments out of order: generation=%s index=%08x offsets=(%d,%d)", generation, info.Index, offsets[len(offsets)-1], info.Offset)
}
// Append to the end of the WAL file.
m[info.Index] = append(offsets, info.Offset)
}
return m, itr.Close()
}
// restoreSnapshot copies a snapshot from the replica to a file.
func (r *Replica) restoreSnapshot(ctx context.Context, generation string, index int, filename string) error {
// Determine the user/group & mode based on the DB, if available.
var fileInfo, dirInfo os.FileInfo
if db := r.DB(); db != nil {
fileInfo, dirInfo = db.fileInfo, db.dirInfo
}
if err := internal.MkdirAll(filepath.Dir(filename), dirInfo); err != nil {
return err
}
f, err := internal.CreateFile(filename, fileInfo)
if err != nil {
return err
}
defer f.Close()
rd, err := r.Client.SnapshotReader(ctx, generation, index)
if err != nil {
return err
}
defer rd.Close()
if _, err := io.Copy(f, lz4.NewReader(rd)); err != nil {
return err
} else if err := f.Sync(); err != nil {
return err
}
return f.Close()
}
// downloadWAL copies a WAL file from the replica to a local copy next to the DB.
// The WAL is later applied by applyWAL(). This function can be run in parallel
// to download multiple WAL files simultaneously.
func (r *Replica) downloadWAL(ctx context.Context, generation string, index int, offsets []int64, dbPath string) (err error) {
// Determine the user/group & mode based on the DB, if available.
var fileInfo os.FileInfo
if db := r.DB(); db != nil {
fileInfo = db.fileInfo
}
// Open readers for every segment in the WAL file, in order.
var readers []io.Reader
for _, offset := range offsets {
rd, err := r.Client.WALSegmentReader(ctx, Pos{Generation: generation, Index: index, Offset: offset})
if err != nil {
return err
}
defer rd.Close()
readers = append(readers, lz4.NewReader(rd))
}
// Open handle to destination WAL path.
f, err := internal.CreateFile(fmt.Sprintf("%s-%08x-wal", dbPath, index), fileInfo)
if err != nil {
return err
}
defer f.Close()
// Combine segments together and copy WAL to target path.
if _, err := io.Copy(f, io.MultiReader(readers...)); err != nil {
return err
} else if err := f.Close(); err != nil {
return err
}
return nil
}
// LatestReplica returns the most recently updated replica.
func LatestReplica(ctx context.Context, replicas []*Replica) (*Replica, error) {
var t time.Time
var r *Replica
for i := range replicas {
_, max, err := ReplicaClientTimeBounds(ctx, replicas[i].Client)
if err != nil {
return nil, err
} else if r == nil || max.After(t) {
r, t = replicas[i], max
}
}
return r, nil
}
// Replica metrics.