Refactor Restore()
This commit refactors out the complexity of downloading ordered WAL files in parallel to a type called `WALDownloader`. This makes it easier to test the restore separately from the download.
This commit is contained in:
335
wal_downloader.go
Normal file
335
wal_downloader.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package litestream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/benbjohnson/litestream/internal"
|
||||
"github.com/pierrec/lz4/v4"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// WALDownloader represents a parallel downloader of WAL files from a replica client.
|
||||
//
|
||||
// It works on a per-index level so WAL files are always downloaded in their
|
||||
// entirety and are not segmented. WAL files are downloaded from minIndex to
|
||||
// maxIndex, inclusively, and are written to a path prefix. WAL files are named
|
||||
// with the prefix and suffixed with the WAL index. It is the responsibility of
|
||||
// the caller to clean up these WAL files.
|
||||
//
|
||||
// The purpose of the parallization is that RTT & WAL apply time can consume
|
||||
// much of the restore time so it's useful to download multiple WAL files in
|
||||
// the background to minimize the latency. While some WAL indexes may be
|
||||
// downloaded out of order, the WALDownloader ensures that Next() always
|
||||
// returns the WAL files sequentially.
|
||||
type WALDownloader struct {
|
||||
ctx context.Context // context used for early close/cancellation
|
||||
cancel func()
|
||||
|
||||
client ReplicaClient // client to read WAL segments with
|
||||
generation string // generation to download WAL files from
|
||||
minIndex int // starting WAL index (inclusive)
|
||||
maxIndex int // ending WAL index (inclusive)
|
||||
prefix string // output file prefix
|
||||
|
||||
err error // error occuring during init, propagated to Next()
|
||||
n int // number of WAL files returned by Next()
|
||||
|
||||
// Concurrency coordination
|
||||
mu sync.Mutex // used to serialize sending of next WAL index
|
||||
cond *sync.Cond // used with mu above
|
||||
g *errgroup.Group // manages worker goroutines for downloading
|
||||
input chan walDownloadInput // holds ordered WAL indices w/ offsets
|
||||
output chan walDownloadOutput // always sends next sequential WAL; used by Next()
|
||||
nextIndex int // tracks next WAL index to send to output channel
|
||||
|
||||
// File info used for downloaded WAL files.
|
||||
Mode os.FileMode
|
||||
Uid, Gid int
|
||||
|
||||
// Number of downloads occurring in parallel.
|
||||
Parallelism int
|
||||
}
|
||||
|
||||
// NewWALDownloader returns a new instance of WALDownloader.
|
||||
func NewWALDownloader(client ReplicaClient, prefix string, generation string, minIndex, maxIndex int) *WALDownloader {
|
||||
d := &WALDownloader{
|
||||
client: client,
|
||||
prefix: prefix,
|
||||
generation: generation,
|
||||
minIndex: minIndex,
|
||||
maxIndex: maxIndex,
|
||||
|
||||
Mode: 0600,
|
||||
Parallelism: 1,
|
||||
}
|
||||
|
||||
d.ctx, d.cancel = context.WithCancel(context.Background())
|
||||
d.cond = sync.NewCond(&d.mu)
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
// Close cancels all downloads and returns any error that has occurred.
|
||||
func (d *WALDownloader) Close() (err error) {
|
||||
if d.err != nil {
|
||||
err = d.err
|
||||
}
|
||||
|
||||
d.cancel()
|
||||
|
||||
if d.g != nil {
|
||||
if e := d.g.Wait(); err != nil && e != context.Canceled {
|
||||
err = e
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// init initializes the downloader on the first invocation only. It generates
|
||||
// the input channel with all WAL indices & offsets needed, it initializes
|
||||
// the output channel that Next() waits on, and starts the worker goroutines
|
||||
// that begin downloading WAL files in the background.
|
||||
func (d *WALDownloader) init(ctx context.Context) error {
|
||||
if d.input != nil {
|
||||
return nil // already initialized
|
||||
} else if d.minIndex < 0 {
|
||||
return fmt.Errorf("minimum index required")
|
||||
} else if d.maxIndex < 0 {
|
||||
return fmt.Errorf("maximum index required")
|
||||
} else if d.maxIndex < d.minIndex {
|
||||
return fmt.Errorf("minimum index cannot be larger than maximum index")
|
||||
} else if d.Parallelism < 1 {
|
||||
return fmt.Errorf("parallelism must be at least one")
|
||||
}
|
||||
|
||||
// Populate input channel with indices & offsets.
|
||||
if err := d.initInputCh(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
d.nextIndex = d.minIndex
|
||||
|
||||
// Generate output channel that Next() pulls from.
|
||||
d.output = make(chan walDownloadOutput)
|
||||
|
||||
// Spawn worker goroutines to download WALs.
|
||||
d.g, d.ctx = errgroup.WithContext(d.ctx)
|
||||
for i := 0; i < d.Parallelism; i++ {
|
||||
d.g.Go(func() error { return d.downloader(d.ctx) })
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initInputCh populates the input channel with each WAL index between minIndex
|
||||
// and maxIndex. It also includes all offsets needed with the index.
|
||||
func (d *WALDownloader) initInputCh(ctx context.Context) error {
|
||||
itr, err := d.client.WALSegments(ctx, d.generation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wal segments: %w", err)
|
||||
}
|
||||
defer func() { _ = itr.Close() }()
|
||||
|
||||
d.input = make(chan walDownloadInput, d.maxIndex-d.minIndex+1)
|
||||
defer close(d.input)
|
||||
|
||||
index := d.minIndex - 1
|
||||
var offsets []int64
|
||||
for itr.Next() {
|
||||
info := itr.WALSegment()
|
||||
|
||||
// Restrict segments to within our index range.
|
||||
if info.Index < d.minIndex {
|
||||
continue // haven't reached minimum index, skip
|
||||
} else if info.Index > d.maxIndex {
|
||||
break // after max index, stop
|
||||
}
|
||||
|
||||
// Flush index & offsets when index changes.
|
||||
if info.Index != index {
|
||||
if info.Index != index+1 { // must be sequential
|
||||
return &WALNotFoundError{Generation: d.generation, Index: index + 1}
|
||||
}
|
||||
|
||||
if len(offsets) > 0 {
|
||||
d.input <- walDownloadInput{index: index, offsets: offsets}
|
||||
offsets = make([]int64, 0)
|
||||
}
|
||||
|
||||
index = info.Index
|
||||
}
|
||||
|
||||
// Append to the end of the WAL file.
|
||||
offsets = append(offsets, info.Offset)
|
||||
}
|
||||
|
||||
// Ensure we read to the last index.
|
||||
if index != d.maxIndex {
|
||||
return &WALNotFoundError{Generation: d.generation, Index: index + 1}
|
||||
}
|
||||
|
||||
// Flush if we have remaining offsets.
|
||||
if len(offsets) > 0 {
|
||||
d.input <- walDownloadInput{index: index, offsets: offsets}
|
||||
}
|
||||
|
||||
return itr.Close()
|
||||
}
|
||||
|
||||
// N returns the number of WAL files returned by Next().
|
||||
func (d *WALDownloader) N() int { return d.n }
|
||||
|
||||
// Next returns the index & local file path of the next downloaded WAL file.
|
||||
func (d *WALDownloader) Next(ctx context.Context) (int, string, error) {
|
||||
if d.err != nil {
|
||||
return 0, "", d.err
|
||||
} else if d.err = d.init(ctx); d.err != nil {
|
||||
return 0, "", d.err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return 0, "", ctx.Err()
|
||||
case <-d.ctx.Done():
|
||||
return 0, "", d.ctx.Err()
|
||||
case v, ok := <-d.output:
|
||||
if !ok {
|
||||
return 0, "", io.EOF
|
||||
}
|
||||
|
||||
d.n++
|
||||
return v.index, v.path, v.err
|
||||
}
|
||||
}
|
||||
|
||||
// downloader runs in a separate goroutine and downloads the next input index.
|
||||
func (d *WALDownloader) downloader(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
d.cond.Broadcast()
|
||||
return ctx.Err()
|
||||
|
||||
case input, ok := <-d.input:
|
||||
if !ok {
|
||||
return nil // no more input
|
||||
}
|
||||
|
||||
// Wait until next index equals input index and then send file to
|
||||
// output to ensure sorted order.
|
||||
if err := func() error {
|
||||
walPath, err := d.downloadWAL(ctx, input.index, input.offsets)
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
// Notify other downloader goroutines when we escape this
|
||||
// anonymous function.
|
||||
defer d.cond.Broadcast()
|
||||
|
||||
// Keep looping until our index matches the next index to send.
|
||||
for d.nextIndex != input.index {
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
return ctxErr
|
||||
}
|
||||
d.cond.Wait()
|
||||
}
|
||||
|
||||
// Still under lock, wait until Next() requests next index.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
|
||||
case d.output <- walDownloadOutput{
|
||||
index: input.index,
|
||||
path: walPath,
|
||||
err: err,
|
||||
}:
|
||||
// At the last index, close out output channel to notify
|
||||
// the Next() method to return io.EOF.
|
||||
if d.nextIndex == d.maxIndex {
|
||||
close(d.output)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update next expected index now that our send is successful.
|
||||
d.nextIndex++
|
||||
}
|
||||
|
||||
return err
|
||||
}(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// downloadWAL sequentially downloads all the segments for WAL index from the
|
||||
// replica client and appends them to a single on-disk file. Returns the name
|
||||
// of the on-disk file on success.
|
||||
func (d *WALDownloader) downloadWAL(ctx context.Context, index int, offsets []int64) (string, error) {
|
||||
// Open handle to destination WAL path.
|
||||
walPath := fmt.Sprintf("%s-%08x-wal", d.prefix, index)
|
||||
f, err := internal.CreateFile(walPath, d.Mode, d.Uid, d.Gid)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Open readers for every segment in the WAL file, in order.
|
||||
var written int64
|
||||
for _, offset := range offsets {
|
||||
if err := func() error {
|
||||
// Ensure next offset is our current position in the file.
|
||||
if written != offset {
|
||||
return fmt.Errorf("missing WAL offset: generation=%s index=%08x offset=%08x", d.generation, index, written)
|
||||
}
|
||||
|
||||
rd, err := d.client.WALSegmentReader(ctx, Pos{Generation: d.generation, Index: index, Offset: offset})
|
||||
if err != nil {
|
||||
return fmt.Errorf("read WAL segment: %w", err)
|
||||
}
|
||||
defer rd.Close()
|
||||
|
||||
n, err := io.Copy(f, lz4.NewReader(rd))
|
||||
if err != nil {
|
||||
return fmt.Errorf("copy WAL segment: %w", err)
|
||||
}
|
||||
written += n
|
||||
|
||||
return nil
|
||||
}(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return walPath, nil
|
||||
}
|
||||
|
||||
// walDownloadInput is one unit of work for a downloader goroutine: a WAL index
// together with the offsets of the segments to fetch for it.
type walDownloadInput struct {
	index   int     // WAL index to download
	offsets []int64 // segment offsets for the index, in the order fetched
}
|
||||
|
||||
// walDownloadOutput is the result a downloader goroutine passes to Next() for
// one WAL index.
type walDownloadOutput struct {
	path  string // local path of the downloaded WAL file
	index int    // WAL index the result belongs to
	err   error  // download error for this index, if any
}
|
||||
|
||||
// WALNotFoundError is returned by WALDownloader when a WAL index it needs is
// not found.
type WALNotFoundError struct {
	Generation string // generation that was being read
	Index      int    // WAL index that is missing
}

// Error returns the error string, naming the generation and the missing index
// as zero-padded hex.
func (e *WALNotFoundError) Error() string {
	const format = "wal not found: generation=%s index=%08x"
	return fmt.Sprintf(format, e.Generation, e.Index)
}
|
||||
Reference in New Issue
Block a user