package litestream

import (
	"context"
	"fmt"
	"hash/crc64"
	"io"
	"io/ioutil"
	"log"
	"math"
	"os"
	"path/filepath"
	"sort"
	"sync"
	"time"

	"github.com/benbjohnson/litestream/internal"
	"github.com/pierrec/lz4/v4"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"golang.org/x/sync/errgroup"
)

// Default replica settings.
const (
	DefaultSyncInterval           = 1 * time.Second
	DefaultRetention              = 24 * time.Hour
	DefaultRetentionCheckInterval = 1 * time.Hour
)

// Replica connects a database to a replication destination via a ReplicaClient.
// The replica manages periodic synchronization and maintains the current
// replica position.
type Replica struct {
	db   *DB
	name string

	mu  sync.RWMutex
	pos Pos // current replicated position

	muf sync.Mutex
	f   *os.File // long-running file descriptor to avoid non-OFD lock issues

	wg     sync.WaitGroup
	cancel func()

	// Client used to connect to the remote replica.
	Client ReplicaClient

	// Time between syncs with the shadow WAL.
	SyncInterval time.Duration

	// Frequency to create new snapshots.
	SnapshotInterval time.Duration

	// Time to keep snapshots and related WAL files.
	// Database is snapshotted after interval, if needed, and older WAL files are discarded.
	Retention time.Duration

	// Time between checks for retention.
	RetentionCheckInterval time.Duration

	// Time between validation checks.
	ValidationInterval time.Duration

	// If true, replica monitors database for changes automatically.
	// Set to false if replica is being used synchronously (such as in tests).
	MonitorEnabled bool

	Logger *log.Logger
}

func NewReplica(db *DB, name string) *Replica {
	r := &Replica{
		db:     db,
		name:   name,
		cancel: func() {},

		SyncInterval:           DefaultSyncInterval,
		Retention:              DefaultRetention,
		RetentionCheckInterval: DefaultRetentionCheckInterval,
		MonitorEnabled:         true,
	}

	prefix := fmt.Sprintf("%s: ", r.Name())
	if db != nil {
		prefix = fmt.Sprintf("%s(%s): ", db.Path(), r.Name())
	}
	r.Logger = log.New(LogWriter, prefix, LogFlags)

	return r
}

// Name returns the name of the replica.
func (r *Replica) Name() string {
	if r.name == "" && r.Client != nil {
		return r.Client.Type()
	}
	return r.name
}

// DB returns a reference to the database the replica is attached to, if any.
func (r *Replica) DB() *DB { return r.db }

// Start begins replicating in a background goroutine.
func (r *Replica) Start(ctx context.Context) error {
	// Ignore if replica is being used synchronously.
	if !r.MonitorEnabled {
		return nil
	}

	// Stop previous replication.
	r.Stop(false)

	// Wrap context with cancellation.
	ctx, r.cancel = context.WithCancel(ctx)

	// Start goroutines to replicate data.
	r.wg.Add(4)
	go func() { defer r.wg.Done(); r.monitor(ctx) }()
	go func() { defer r.wg.Done(); r.retainer(ctx) }()
	go func() { defer r.wg.Done(); r.snapshotter(ctx) }()
	go func() { defer r.wg.Done(); r.validator(ctx) }()

	return nil
}

// Stop cancels any outstanding replication and blocks until finished.
//
// Performing a hard stop will close the DB file descriptor, which could
// release per-process locks. Hard stops should only be performed when
// stopping the entire process.
func (r *Replica) Stop(hard bool) (err error) {
	r.cancel()
	r.wg.Wait()

	r.muf.Lock()
	defer r.muf.Unlock()
	if hard && r.f != nil {
		if e := r.f.Close(); e != nil && err == nil {
			err = e
		}
	}
	return err
}
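// As a rough usage sketch (hypothetical wiring; exact setup varies by
// application and is not part of this file), a replica is created, attached
// to a database, and started like so:
//
//	db := NewDB("/var/lib/app/app.db")
//	r := NewReplica(db, "s3")
//	r.Client = client // any ReplicaClient implementation
//	db.Replicas = append(db.Replicas, r)
//
//	if err := r.Start(ctx); err != nil {
//		// handle error
//	}
//	defer r.Stop(false) // hard stop (true) only when exiting the process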
// Sync copies new WAL frames from the shadow WAL to the replica client.
func (r *Replica) Sync(ctx context.Context) (err error) {
	// Clear last position if an error occurs during sync.
	defer func() {
		if err != nil {
			r.mu.Lock()
			r.pos = Pos{}
			r.mu.Unlock()
		}
	}()

	// Find current position of database.
	dpos := r.db.Pos()
	if dpos.IsZero() {
		return fmt.Errorf("no generation, waiting for data")
	}
	generation := dpos.Generation

	// Create snapshot if no snapshots exist for generation.
	snapshotN, err := r.snapshotN(generation)
	if err != nil {
		return err
	} else if snapshotN == 0 {
		if info, err := r.Snapshot(ctx); err != nil {
			return err
		} else if info.Generation != generation {
			return fmt.Errorf("generation changed during snapshot, exiting sync")
		}
		snapshotN = 1
	}
	replicaSnapshotTotalGaugeVec.WithLabelValues(r.db.Path(), r.Name()).Set(float64(snapshotN))

	// Determine position, if necessary.
	if r.Pos().Generation != generation {
		pos, err := r.calcPos(ctx, generation)
		if err != nil {
			return fmt.Errorf("cannot determine replica position: %s", err)
		}

		r.mu.Lock()
		r.pos = pos
		r.mu.Unlock()
	}

	// Read all WAL files since the last position.
	if err = r.syncWAL(ctx); err != nil {
		return err
	}

	return nil
}

func (r *Replica) syncWAL(ctx context.Context) (err error) {
	pos := r.Pos()
	itr, err := r.db.WALSegments(ctx, pos.Generation)
	if err != nil {
		return err
	}
	defer itr.Close()

	// Group segments by index.
	var segments [][]WALSegmentInfo
	for itr.Next() {
		info := itr.WALSegment()
		if cmp, err := ComparePos(pos, info.Pos()); err != nil {
			return fmt.Errorf("compare pos: %w", err)
		} else if cmp == 1 {
			continue // already processed, skip
		}

		// Start a new chunk if index has changed.
		if len(segments) == 0 || segments[len(segments)-1][0].Index != info.Index {
			segments = append(segments, []WALSegmentInfo{info})
			continue
		}

		// Add segment to the end of the current index, if matching.
		segments[len(segments)-1] = append(segments[len(segments)-1], info)
	}

	// Write out segments to replica by index so they can be combined.
	for i := range segments {
		if err := r.writeIndexSegments(ctx, segments[i]); err != nil {
			return fmt.Errorf("write index segments: index=%d err=%w", segments[i][0].Index, err)
		}
	}

	return nil
}
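// To illustrate the grouping above (hypothetical positions; seg(index, offset)
// is a shorthand, not a real helper): segments at index/offset pairs (0,0),
// (0,4096), and (1,0) are grouped so each inner slice covers one WAL index:
//
//	[][]WALSegmentInfo{
//		{seg(0, 0), seg(0, 4096)}, // WAL index 0, contiguous offsets
//		{seg(1, 0)},               // WAL index 1
//	}
//
// writeIndexSegments can then stream each index to the client as a single
// compressed object.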
func (r *Replica) writeIndexSegments(ctx context.Context, segments []WALSegmentInfo) (err error) {
	assert(len(segments) > 0, "segments required for replication")

	// First segment position must be equal to last replica position or
	// the start of the next index.
	if pos := r.Pos(); pos != segments[0].Pos() {
		nextIndexPos := pos.Truncate()
		nextIndexPos.Index++
		if nextIndexPos != segments[0].Pos() {
			return fmt.Errorf("replica skipped position: replica=%s initial=%s", pos, segments[0].Pos())
		}
	}

	pos := segments[0].Pos()
	initialPos := pos

	// Copy shadow WAL to the client via an io.Pipe().
	pr, pw := io.Pipe()
	defer func() { _ = pw.CloseWithError(err) }()

	// Copy through pipe into client from the starting position.
	var g errgroup.Group
	g.Go(func() error {
		_, err := r.Client.WriteWALSegment(ctx, initialPos, pr)
		return err
	})

	// Wrap writer to LZ4 compress.
	zw := lz4.NewWriter(pw)

	// Write each segment out to the replica.
	for _, info := range segments {
		if err := func() error {
			// Ensure segments are in order and no bytes are skipped.
			if pos != info.Pos() {
				return fmt.Errorf("non-contiguous segment: expected=%s current=%s", pos, info.Pos())
			}

			rc, err := r.db.WALSegmentReader(ctx, info.Pos())
			if err != nil {
				return err
			}
			defer rc.Close()

			n, err := io.Copy(zw, lz4.NewReader(rc))
			if err != nil {
				return err
			} else if err := rc.Close(); err != nil {
				return err
			}

			// Track last position written.
			pos = info.Pos()
			pos.Offset += n

			return nil
		}(); err != nil {
			return fmt.Errorf("wal segment: pos=%s err=%w", info.Pos(), err)
		}
	}

	// Flush LZ4 writer, close pipe, and wait for write to finish.
	if err := zw.Close(); err != nil {
		return err
	} else if err := pw.Close(); err != nil {
		return err
	} else if err := g.Wait(); err != nil {
		return err
	}

	// Save last replicated position.
	r.mu.Lock()
	r.pos = pos
	r.mu.Unlock()

	// Track total WAL bytes written to replica client.
	replicaWALBytesCounterVec.WithLabelValues(r.db.Path(), r.Name()).Add(float64(pos.Offset - initialPos.Offset))
	replicaWALIndexGaugeVec.WithLabelValues(r.db.Path(), r.Name()).Set(float64(pos.Index))
	replicaWALOffsetGaugeVec.WithLabelValues(r.db.Path(), r.Name()).Set(float64(pos.Offset))

	r.Logger.Printf("wal segment written: %s sz=%d", initialPos, pos.Offset-initialPos.Offset)

	return nil
}

// snapshotN returns the number of snapshots for a generation.
func (r *Replica) snapshotN(generation string) (int, error) {
	itr, err := r.Client.Snapshots(context.Background(), generation)
	if err != nil {
		return 0, err
	}
	defer itr.Close()

	var n int
	for itr.Next() {
		n++
	}
	return n, itr.Close()
}

// calcPos returns the last position for the given generation.
func (r *Replica) calcPos(ctx context.Context, generation string) (pos Pos, err error) {
	// Fetch last snapshot. Return error if no snapshots exist.
	snapshot, err := r.maxSnapshot(ctx, generation)
	if err != nil {
		return pos, fmt.Errorf("max snapshot: %w", err)
	} else if snapshot == nil {
		return pos, fmt.Errorf("no snapshot available: generation=%s", generation)
	}

	// Determine last WAL segment available. Use snapshot if none exist.
	segment, err := r.maxWALSegment(ctx, generation)
	if err != nil {
		return pos, fmt.Errorf("max wal segment: %w", err)
	} else if segment == nil {
		return Pos{Generation: snapshot.Generation, Index: snapshot.Index}, nil
	}

	// Read segment to determine size to add to offset.
	rd, err := r.Client.WALSegmentReader(ctx, segment.Pos())
	if err != nil {
		return pos, fmt.Errorf("wal segment reader: %w", err)
	}
	defer rd.Close()

	n, err := io.Copy(ioutil.Discard, lz4.NewReader(rd))
	if err != nil {
		return pos, err
	}

	// Return the position at the end of the last WAL segment.
	return Pos{
		Generation: segment.Generation,
		Index:      segment.Index,
		Offset:     segment.Offset + n,
	}, nil
}

// maxSnapshot returns the last snapshot in a generation.
func (r *Replica) maxSnapshot(ctx context.Context, generation string) (*SnapshotInfo, error) {
	itr, err := r.Client.Snapshots(ctx, generation)
	if err != nil {
		return nil, err
	}
	defer itr.Close()

	var max *SnapshotInfo
	for itr.Next() {
		if info := itr.Snapshot(); max == nil || info.Index > max.Index {
			max = &info
		}
	}
	return max, itr.Close()
}

// maxWALSegment returns the highest WAL segment in a generation.
func (r *Replica) maxWALSegment(ctx context.Context, generation string) (*WALSegmentInfo, error) {
	itr, err := r.Client.WALSegments(ctx, generation)
	if err != nil {
		return nil, err
	}
	defer itr.Close()

	var max *WALSegmentInfo
	for itr.Next() {
		if info := itr.WALSegment(); max == nil || info.Index > max.Index || (info.Index == max.Index && info.Offset > max.Offset) {
			max = &info
		}
	}
	return max, itr.Close()
}

// Pos returns the current replicated position.
// Returns a zero value if the current position cannot be determined.
func (r *Replica) Pos() Pos {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return r.pos
}
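// As a minimal sketch (hypothetical helper, not part of this file), the byte
// lag of a replica within the current WAL file can be derived by comparing
// the database position against Pos:
//
//	func lagBytes(db *DB, r *Replica) int64 {
//		dpos, rpos := db.Pos(), r.Pos()
//		if dpos.Generation != rpos.Generation || dpos.Index != rpos.Index {
//			return -1 // lag spans generations or WAL files; offsets alone are not comparable
//		}
//		return dpos.Offset - rpos.Offset
//	}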
// Snapshots returns a list of all snapshots across all generations.
func (r *Replica) Snapshots(ctx context.Context) ([]SnapshotInfo, error) {
	generations, err := r.Client.Generations(ctx)
	if err != nil {
		return nil, fmt.Errorf("cannot fetch generations: %w", err)
	}

	var a []SnapshotInfo
	for _, generation := range generations {
		if err := func() error {
			itr, err := r.Client.Snapshots(ctx, generation)
			if err != nil {
				return err
			}
			defer itr.Close()

			other, err := SliceSnapshotIterator(itr)
			if err != nil {
				return err
			}
			a = append(a, other...)

			return itr.Close()
		}(); err != nil {
			return a, err
		}
	}

	sort.Sort(SnapshotInfoSlice(a))

	return a, nil
}

// Snapshot copies the entire database to the replica path.
func (r *Replica) Snapshot(ctx context.Context) (info SnapshotInfo, err error) {
	if r.db == nil || r.db.db == nil {
		return info, fmt.Errorf("no database available")
	}

	r.muf.Lock()
	defer r.muf.Unlock()

	// Issue a passive checkpoint to flush any pages to disk before snapshotting.
	if _, err := r.db.db.ExecContext(ctx, `PRAGMA wal_checkpoint(PASSIVE);`); err != nil {
		return info, fmt.Errorf("pre-snapshot checkpoint: %w", err)
	}

	// Acquire a read lock on the database during snapshot to prevent checkpoints.
	tx, err := r.db.db.Begin()
	if err != nil {
		return info, err
	} else if _, err := tx.ExecContext(ctx, `SELECT COUNT(1) FROM _litestream_seq;`); err != nil {
		_ = tx.Rollback()
		return info, err
	}
	defer func() { _ = tx.Rollback() }()

	// Obtain current position.
	pos := r.db.Pos()
	if pos.IsZero() {
		return info, ErrNoGeneration
	}

	// Open db file descriptor, if not already open, & position at beginning.
	if r.f == nil {
		if r.f, err = os.Open(r.db.Path()); err != nil {
			return info, err
		}
	}
	if _, err := r.f.Seek(0, io.SeekStart); err != nil {
		return info, err
	}

	// Use a pipe to convert the LZ4 writer to a reader.
	pr, pw := io.Pipe()

	// Copy the database file to the LZ4 writer in a separate goroutine.
	var g errgroup.Group
	g.Go(func() error {
		zw := lz4.NewWriter(pw)
		defer zw.Close()

		if _, err := io.Copy(zw, r.f); err != nil {
			pw.CloseWithError(err)
			return err
		} else if err := zw.Close(); err != nil {
			pw.CloseWithError(err)
			return err
		}
		return pw.Close()
	})

	// Delegate write to client & wait for writer goroutine to finish.
	if info, err = r.Client.WriteSnapshot(ctx, pos.Generation, pos.Index, pr); err != nil {
		return info, err
	} else if err := g.Wait(); err != nil {
		return info, err
	}

	r.Logger.Printf("snapshot written %s/%08x", pos.Generation, pos.Index)

	return info, nil
}
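// The pipe-based streaming in Snapshot follows a general Go pattern for
// turning a writer-oriented codec into a reader. A minimal standalone sketch
// (src is any io.Reader; error handling condensed):
//
//	pr, pw := io.Pipe()
//	go func() {
//		zw := lz4.NewWriter(pw)
//		_, err := io.Copy(zw, src)
//		if err == nil {
//			err = zw.Close() // flush compressed frames before closing the pipe
//		}
//		pw.CloseWithError(err) // a nil error closes the pipe cleanly
//	}()
//	// pr now streams the LZ4-compressed contents of src to any consumer.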
// EnforceRetention forces a new snapshot once the retention interval has passed.
// Older snapshots and WAL files are then removed.
func (r *Replica) EnforceRetention(ctx context.Context) (err error) {
	// Obtain list of snapshots that are within the retention period.
	snapshots, err := r.Snapshots(ctx)
	if err != nil {
		return fmt.Errorf("snapshots: %w", err)
	}
	retained := FilterSnapshotsAfter(snapshots, time.Now().Add(-r.Retention))

	// If no retained snapshots exist, create a new snapshot.
	if len(retained) == 0 {
		snapshot, err := r.Snapshot(ctx)
		if err != nil {
			return fmt.Errorf("snapshot: %w", err)
		}
		retained = append(retained, snapshot)
	}

	// Loop over generations and delete unretained snapshots & WAL files.
	generations, err := r.Client.Generations(ctx)
	if err != nil {
		return fmt.Errorf("generations: %w", err)
	}
	for _, generation := range generations {
		// Find earliest retained snapshot for this generation.
		snapshot := FindMinSnapshotByGeneration(retained, generation)

		// Delete entire generation if no snapshots are being retained.
		if snapshot == nil {
			if err := r.Client.DeleteGeneration(ctx, generation); err != nil {
				return fmt.Errorf("delete generation: %w", err)
			}
			continue
		}

		// Otherwise remove all earlier snapshots & WAL segments.
		if err := r.deleteSnapshotsBeforeIndex(ctx, generation, snapshot.Index); err != nil {
			return fmt.Errorf("delete snapshots before index: %w", err)
		} else if err := r.deleteWALSegmentsBeforeIndex(ctx, generation, snapshot.Index); err != nil {
			return fmt.Errorf("delete wal segments before index: %w", err)
		}
	}

	return nil
}

func (r *Replica) deleteSnapshotsBeforeIndex(ctx context.Context, generation string, index int) error {
	itr, err := r.Client.Snapshots(ctx, generation)
	if err != nil {
		return fmt.Errorf("fetch snapshots: %w", err)
	}
	defer itr.Close()

	for itr.Next() {
		info := itr.Snapshot()
		if info.Index >= index {
			continue
		}

		if err := r.Client.DeleteSnapshot(ctx, info.Generation, info.Index); err != nil {
			return fmt.Errorf("delete snapshot %s/%08x: %w", info.Generation, info.Index, err)
		}
		r.Logger.Printf("snapshot deleted %s/%08x", info.Generation, info.Index)
	}

	return itr.Close()
}

func (r *Replica) deleteWALSegmentsBeforeIndex(ctx context.Context, generation string, index int) error {
	itr, err := r.Client.WALSegments(ctx, generation)
	if err != nil {
		return fmt.Errorf("fetch wal segments: %w", err)
	}
	defer itr.Close()

	var a []Pos
	for itr.Next() {
		info := itr.WALSegment()
		if info.Index >= index {
			continue
		}
		a = append(a, info.Pos())
	}
	if err := itr.Close(); err != nil {
		return err
	}

	if len(a) == 0 {
		return nil
	}

	if err := r.Client.DeleteWALSegments(ctx, a); err != nil {
		return fmt.Errorf("delete wal segments: %w", err)
	}
	for _, pos := range a {
		r.Logger.Printf("wal segment deleted: %s", pos)
	}

	return nil
}

// monitor runs in a separate goroutine and continuously replicates the DB.
func (r *Replica) monitor(ctx context.Context) {
	ticker := time.NewTicker(r.SyncInterval)
	defer ticker.Stop()

	// Continuously check for new data to replicate.
	ch := make(chan struct{})
	close(ch)
	var notify <-chan struct{} = ch

	for initial := true; ; initial = false {
		// Enforce a minimum time between synchronization.
		if !initial {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
			}
		}

		// Wait for changes to the database.
		select {
		case <-ctx.Done():
			return
		case <-notify:
		}

		// Fetch new notify channel before replicating data.
		notify = r.db.Notify()

		// Synchronize the shadow WAL into the replication directory.
		if err := r.Sync(ctx); err != nil {
			r.Logger.Printf("monitor error: %s", err)
			continue
		}
	}
}

// retainer runs in a separate goroutine and handles retention.
func (r *Replica) retainer(ctx context.Context) {
	// Disable retention enforcement if retention period is non-positive.
	if r.Retention <= 0 {
		return
	}

	// Ensure check interval is not longer than retention period.
	checkInterval := r.RetentionCheckInterval
	if checkInterval > r.Retention {
		checkInterval = r.Retention
	}

	ticker := time.NewTicker(checkInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := r.EnforceRetention(ctx); err != nil {
				r.Logger.Printf("retainer error: %s", err)
				continue
			}
		}
	}
}
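// monitor's first iteration relies on a pre-closed channel so the initial
// sync fires immediately without a special case. The trick in isolation:
//
//	ch := make(chan struct{})
//	close(ch)
//	var notify <-chan struct{} = ch
//	<-notify // receiving from a closed channel never blocks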
// snapshotter runs in a separate goroutine and handles snapshotting.
func (r *Replica) snapshotter(ctx context.Context) {
	if r.SnapshotInterval <= 0 {
		return
	}

	ticker := time.NewTicker(r.SnapshotInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if _, err := r.Snapshot(ctx); err != nil && err != ErrNoGeneration {
				r.Logger.Printf("snapshotter error: %s", err)
				continue
			}
		}
	}
}

// validator runs in a separate goroutine and handles periodic validation.
func (r *Replica) validator(ctx context.Context) {
	// Initialize counters since validation occurs infrequently.
	for _, status := range []string{"ok", "error"} {
		replicaValidationTotalCounterVec.WithLabelValues(r.db.Path(), r.Name(), status).Add(0)
	}

	// Exit validation if interval is not set.
	if r.ValidationInterval <= 0 {
		return
	}

	ticker := time.NewTicker(r.ValidationInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := r.Validate(ctx); err != nil {
				r.Logger.Printf("validation error: %s", err)
				continue
			}
		}
	}
}

// Validate restores the most recent data from a replica and validates
// that the resulting database matches the current database.
func (r *Replica) Validate(ctx context.Context) error {
	db := r.DB()

	// Restore replica to a temporary directory.
	tmpdir, err := ioutil.TempDir("", "*-litestream")
	if err != nil {
		return err
	}
	defer os.RemoveAll(tmpdir)

	// Compute checksum of primary database under lock. This prevents a
	// sync from occurring and the database will not be written.
	chksum0, pos, err := db.CRC64(ctx)
	if err != nil {
		return fmt.Errorf("cannot compute checksum: %w", err)
	}

	// Wait until replica catches up to position.
	if err := r.waitForReplica(ctx, pos); err != nil {
		return fmt.Errorf("cannot wait for replica: %w", err)
	}

	restorePath := filepath.Join(tmpdir, "replica")
	if err := r.Restore(ctx, RestoreOptions{
		OutputPath:  restorePath,
		ReplicaName: r.Name(),
		Generation:  pos.Generation,
		Index:       pos.Index - 1,
		Logger:      log.New(os.Stderr, "", 0),
	}); err != nil {
		return fmt.Errorf("cannot restore: %w", err)
	}

	// Open file handle for restored database.
	// NOTE: This open is ok as the restored database is not managed by litestream.
	f, err := os.Open(restorePath)
	if err != nil {
		return err
	}
	defer f.Close()

	// Read entire file into checksum.
	h := crc64.New(crc64.MakeTable(crc64.ISO))
	if _, err := io.Copy(h, f); err != nil {
		return err
	}
	chksum1 := h.Sum64()

	status := "ok"
	mismatch := chksum0 != chksum1
	if mismatch {
		status = "mismatch"
	}
	r.Logger.Printf("validator: status=%s db=%016x replica=%016x pos=%s", status, chksum0, chksum1, pos)

	// Validate checksums match.
	if mismatch {
		replicaValidationTotalCounterVec.WithLabelValues(r.db.Path(), r.Name(), "error").Inc()
		return ErrChecksumMismatch
	}
	replicaValidationTotalCounterVec.WithLabelValues(r.db.Path(), r.Name(), "ok").Inc()

	if err := os.RemoveAll(tmpdir); err != nil {
		return fmt.Errorf("cannot remove temporary validation directory: %w", err)
	}
	return nil
}
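// Validate compares CRC-64 (ISO polynomial) checksums of the primary database
// and the restored copy. A minimal sketch of checksumming a file the same way
// (path is hypothetical; error handling condensed):
//
//	h := crc64.New(crc64.MakeTable(crc64.ISO))
//	f, err := os.Open("/path/to/db")
//	if err == nil {
//		defer f.Close()
//		_, err = io.Copy(h, f)
//	}
//	sum := h.Sum64() // compare against the other database's checksum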
// waitForReplica blocks until replica reaches at least the given position.
func (r *Replica) waitForReplica(ctx context.Context, pos Pos) error {
	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()

	timer := time.NewTimer(10 * time.Second)
	defer timer.Stop()

	once := make(chan struct{}, 1)
	once <- struct{}{}

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-timer.C:
			return fmt.Errorf("replica wait exceeded timeout")
		case <-ticker.C:
		case <-once: // immediate on first check
		}

		// Obtain current position of replica, check if past target position.
		curr := r.Pos()
		if curr.IsZero() {
			r.Logger.Printf("validator: no replica position available")
			continue
		}

		// Exit if the generation has changed while waiting as there will be
		// no further progress on the old generation.
		if curr.Generation != pos.Generation {
			return fmt.Errorf("generation changed")
		}

		ready := true
		if curr.Index < pos.Index {
			ready = false
		} else if curr.Index == pos.Index && curr.Offset < pos.Offset {
			ready = false
		}

		// If not ready, restart loop.
		if !ready {
			continue
		}

		// Current position at or after target position.
		return nil
	}
}

// GenerationCreatedAt returns the earliest creation time of any snapshot.
// Returns zero time if no snapshots exist.
func (r *Replica) GenerationCreatedAt(ctx context.Context, generation string) (time.Time, error) {
	var min time.Time

	itr, err := r.Client.Snapshots(ctx, generation)
	if err != nil {
		return min, err
	}
	defer itr.Close()

	for itr.Next() {
		if info := itr.Snapshot(); min.IsZero() || info.CreatedAt.Before(min) {
			min = info.CreatedAt
		}
	}
	return min, itr.Close()
}

// GenerationTimeBounds returns the creation time & last updated time of a generation.
// Returns zero time if no snapshots or WAL segments exist.
func (r *Replica) GenerationTimeBounds(ctx context.Context, generation string) (createdAt, updatedAt time.Time, err error) {
	// Iterate over snapshots.
	sitr, err := r.Client.Snapshots(ctx, generation)
	if err != nil {
		return createdAt, updatedAt, err
	}
	defer sitr.Close()

	for sitr.Next() {
		info := sitr.Snapshot()
		if createdAt.IsZero() || info.CreatedAt.Before(createdAt) {
			createdAt = info.CreatedAt
		}
		if updatedAt.IsZero() || info.CreatedAt.After(updatedAt) {
			updatedAt = info.CreatedAt
		}
	}
	if err := sitr.Close(); err != nil {
		return createdAt, updatedAt, err
	}

	// Iterate over WAL segments.
	witr, err := r.Client.WALSegments(ctx, generation)
	if err != nil {
		return createdAt, updatedAt, err
	}
	defer witr.Close()

	for witr.Next() {
		info := witr.WALSegment()
		if createdAt.IsZero() || info.CreatedAt.Before(createdAt) {
			createdAt = info.CreatedAt
		}
		if updatedAt.IsZero() || info.CreatedAt.After(updatedAt) {
			updatedAt = info.CreatedAt
		}
	}
	if err := witr.Close(); err != nil {
		return createdAt, updatedAt, err
	}

	return createdAt, updatedAt, nil
}

// CalcRestoreTarget returns a generation to restore from.
func (r *Replica) CalcRestoreTarget(ctx context.Context, opt RestoreOptions) (generation string, updatedAt time.Time, err error) {
	var target struct {
		generation string
		updatedAt  time.Time
	}

	generations, err := r.Client.Generations(ctx)
	if err != nil {
		return "", time.Time{}, fmt.Errorf("cannot fetch generations: %w", err)
	}

	// Search generations for one that contains the requested timestamp.
	for _, generation := range generations {
		// Skip generation if it does not match filter.
		if opt.Generation != "" && generation != opt.Generation {
			continue
		}

		// Determine the time bounds for the generation.
		createdAt, updatedAt, err := r.GenerationTimeBounds(ctx, generation)
		if err != nil {
			return "", time.Time{}, fmt.Errorf("generation time bounds: %w", err)
		}

		// Skip if it does not contain timestamp.
		if !opt.Timestamp.IsZero() {
			if opt.Timestamp.Before(createdAt) || opt.Timestamp.After(updatedAt) {
				continue
			}
		}

		// Use the most recently updated generation if we have multiple candidates.
		if !updatedAt.After(target.updatedAt) {
			continue
		}

		target.generation = generation
		target.updatedAt = updatedAt
	}

	return target.generation, target.updatedAt, nil
}
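// A typical point-in-time restore, as a rough sketch (values are hypothetical;
// Index is set to math.MaxInt32, the sentinel Restore treats as "no explicit
// index"):
//
//	err := r.Restore(ctx, RestoreOptions{
//		OutputPath: "/tmp/app.db.restored",
//		Timestamp:  time.Now().Add(-1 * time.Hour),
//		Index:      math.MaxInt32,
//	})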
// Restore restores the database from a replica based on the options given.
// This method will restore into opt.OutputPath, if specified, or into the
// DB's original database path. It can optionally restore from a specific
// replica or generation or it will automatically choose the best one. Finally,
// a timestamp can be specified to restore the database to a specific
// point-in-time.
func (r *Replica) Restore(ctx context.Context, opt RestoreOptions) (err error) {
	// Validate options.
	if opt.OutputPath == "" {
		if r.db.path == "" {
			return fmt.Errorf("output path required")
		}
		opt.OutputPath = r.db.path
	}
	if opt.Generation == "" && opt.Index != math.MaxInt32 {
		return fmt.Errorf("must specify generation when restoring to index")
	} else if opt.Index != math.MaxInt32 && !opt.Timestamp.IsZero() {
		return fmt.Errorf("cannot specify index & timestamp to restore")
	}

	// Ensure logger exists.
	logger := opt.Logger
	if logger == nil {
		logger = log.New(ioutil.Discard, "", 0)
	}

	logPrefix := r.Name()
	if db := r.DB(); db != nil {
		logPrefix = fmt.Sprintf("%s(%s)", db.Path(), r.Name())
	}

	// Ensure output path does not already exist.
	if _, err := os.Stat(opt.OutputPath); err == nil {
		return fmt.Errorf("cannot restore, output path already exists: %s", opt.OutputPath)
	} else if err != nil && !os.IsNotExist(err) {
		return err
	}

	// Find latest snapshot that occurs before timestamp or index.
	var minWALIndex int
	if opt.Index < math.MaxInt32 {
		if minWALIndex, err = r.SnapshotIndexByIndex(ctx, opt.Generation, opt.Index); err != nil {
			return fmt.Errorf("cannot find snapshot index: %w", err)
		}
	} else {
		if minWALIndex, err = r.SnapshotIndexAt(ctx, opt.Generation, opt.Timestamp); err != nil {
			return fmt.Errorf("cannot find snapshot index by timestamp: %w", err)
		}
	}

	// Compute list of offsets for each WAL index.
	walSegmentMap, err := r.walSegmentMap(ctx, opt.Generation, opt.Index, opt.Timestamp)
	if err != nil {
		return fmt.Errorf("cannot find max wal index for restore: %w", err)
	}

	// Find the maximum WAL index that occurs before timestamp.
	maxWALIndex := -1
	for index := range walSegmentMap {
		if index > maxWALIndex {
			maxWALIndex = index
		}
	}

	// Ensure that we found the specific index, if one was specified.
	if opt.Index != math.MaxInt32 && opt.Index != maxWALIndex {
		return fmt.Errorf("unable to locate index %d in generation %q, highest index was %d", opt.Index, opt.Generation, maxWALIndex)
	}

	// If no WAL files were found, mark this as a snapshot-only restore.
	snapshotOnly := maxWALIndex == -1

	// Initialize starting position.
	pos := Pos{Generation: opt.Generation, Index: minWALIndex}
	tmpPath := opt.OutputPath + ".tmp"

	// Copy snapshot to output path.
	logger.Printf("%s: restoring snapshot %s/%08x to %s", logPrefix, opt.Generation, minWALIndex, tmpPath)
	if err := r.restoreSnapshot(ctx, pos.Generation, pos.Index, tmpPath); err != nil {
		return fmt.Errorf("cannot restore snapshot: %w", err)
	}

	// If no WAL files available, move snapshot to final path & exit early.
	if snapshotOnly {
		logger.Printf("%s: snapshot only, finalizing database", logPrefix)
		return os.Rename(tmpPath, opt.OutputPath)
	}

	// Begin processing WAL files.
	logger.Printf("%s: restoring wal files: generation=%s index=[%08x,%08x]", logPrefix, opt.Generation, minWALIndex, maxWALIndex)
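	// What follows is a bounded worker pool: a buffered channel feeds WAL
	// indexes to opt.Parallelism download goroutines, which may finish out of
	// order; walStates plus a sync.Cond lets the single applier loop below
	// consume results strictly in index order.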
	// Fill input channel with all WAL indexes to be loaded in order.
	// Verify every index has at least one offset.
	ch := make(chan int, maxWALIndex-minWALIndex+1)
	for index := minWALIndex; index <= maxWALIndex; index++ {
		if len(walSegmentMap[index]) == 0 {
			return fmt.Errorf("missing WAL index: %s/%08x", opt.Generation, index)
		}
		ch <- index
	}
	close(ch)

	// Track load state for each WAL.
	var mu sync.Mutex
	cond := sync.NewCond(&mu)
	walStates := make([]walRestoreState, maxWALIndex-minWALIndex+1)

	parallelism := opt.Parallelism
	if parallelism < 1 {
		parallelism = 1
	}

	// Download WAL files to disk in parallel.
	g, ctx := errgroup.WithContext(ctx)
	for i := 0; i < parallelism; i++ {
		g.Go(func() error {
			for {
				select {
				case <-ctx.Done():
					cond.Broadcast()
					return ctx.Err()
				case index, ok := <-ch:
					if !ok {
						cond.Broadcast()
						return nil
					}

					startTime := time.Now()

					err := r.downloadWAL(ctx, opt.Generation, index, walSegmentMap[index], tmpPath)
					if err != nil {
						err = fmt.Errorf("cannot download wal %s/%08x: %w", opt.Generation, index, err)
					}

					// Mark index as ready-to-apply and notify applying code.
					mu.Lock()
					walStates[index-minWALIndex] = walRestoreState{ready: true, err: err}
					mu.Unlock()
					cond.Broadcast()

					// Returning the error here will cancel the other goroutines.
					if err != nil {
						return err
					}

					logger.Printf("%s: downloaded wal %s/%08x elapsed=%s", logPrefix, opt.Generation, index,
						time.Since(startTime).String(),
					)
				}
			}
		})
	}

	// Apply WAL files in order as they are ready.
	for index := minWALIndex; index <= maxWALIndex; index++ {
		// Wait until next WAL file is ready to apply.
		mu.Lock()
		for !walStates[index-minWALIndex].ready {
			if err := ctx.Err(); err != nil {
				return err
			}
			cond.Wait()
		}
		if err := walStates[index-minWALIndex].err; err != nil {
			return err
		}
		mu.Unlock()

		// Apply WAL to database file.
		startTime := time.Now()
		if err = applyWAL(ctx, index, tmpPath); err != nil {
			return fmt.Errorf("cannot apply wal: %w", err)
		}
		logger.Printf("%s: applied wal %s/%08x elapsed=%s", logPrefix, opt.Generation, index,
			time.Since(startTime).String(),
		)
	}

	// Ensure all goroutines finish. All errors should have been handled during
	// the processing of WAL files but this ensures that all processing is done.
	if err := g.Wait(); err != nil {
		return err
	}

	// Copy file to final location.
	logger.Printf("%s: renaming database from temporary location", logPrefix)
	if err := os.Rename(tmpPath, opt.OutputPath); err != nil {
		return err
	}

	return nil
}

type walRestoreState struct {
	ready bool
	err   error
}

// SnapshotIndexAt returns the highest index for a snapshot within a generation
// that occurs before timestamp. If timestamp is zero, returns the latest snapshot.
func (r *Replica) SnapshotIndexAt(ctx context.Context, generation string, timestamp time.Time) (int, error) {
	itr, err := r.Client.Snapshots(ctx, generation)
	if err != nil {
		return 0, err
	}
	defer itr.Close()

	snapshotIndex := -1
	var max time.Time
	for itr.Next() {
		snapshot := itr.Snapshot()
		if !timestamp.IsZero() && snapshot.CreatedAt.After(timestamp) {
			continue // after timestamp, skip
		}

		// Use snapshot if it is newer.
		if max.IsZero() || snapshot.CreatedAt.After(max) {
			snapshotIndex, max = snapshot.Index, snapshot.CreatedAt
		}
	}
	if err := itr.Close(); err != nil {
		return 0, err
	} else if snapshotIndex == -1 {
		return 0, ErrNoSnapshots
	}
	return snapshotIndex, nil
}
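// As a worked example for SnapshotIndexByIndex below (hypothetical data): with
// snapshots at indexes {2, 7, 9}, SnapshotIndexByIndex(ctx, generation, 8)
// returns 7, the highest snapshot at or before the requested index, so a
// restore to index 8 replays only WAL files 7 and 8 on top of that snapshot.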
// SnapshotIndexByIndex returns the highest index for a snapshot within a generation
// that occurs before a given index. If index is MaxInt32, returns the latest snapshot.
func (r *Replica) SnapshotIndexByIndex(ctx context.Context, generation string, index int) (int, error) {
	itr, err := r.Client.Snapshots(ctx, generation)
	if err != nil {
		return 0, err
	}
	defer itr.Close()

	snapshotIndex := -1
	for itr.Next() {
		snapshot := itr.Snapshot()
		if index < math.MaxInt32 && snapshot.Index > index {
			continue // after index, skip
		}

		// Use snapshot if it is newer.
		if snapshotIndex == -1 || snapshot.Index > snapshotIndex {
			snapshotIndex = snapshot.Index
		}
	}
	if err := itr.Close(); err != nil {
		return 0, err
	} else if snapshotIndex == -1 {
		return 0, ErrNoSnapshots
	}
	return snapshotIndex, nil
}

// walSegmentMap returns a map of WAL indices to their segments.
// Filters by a max timestamp or a max index.
func (r *Replica) walSegmentMap(ctx context.Context, generation string, maxIndex int, maxTimestamp time.Time) (map[int][]int64, error) {
	itr, err := r.Client.WALSegments(ctx, generation)
	if err != nil {
		return nil, err
	}
	defer itr.Close()

	m := make(map[int][]int64)
	for itr.Next() {
		info := itr.WALSegment()

		// Exit if we go past the max timestamp or index.
		if !maxTimestamp.IsZero() && info.CreatedAt.After(maxTimestamp) {
			break // after max timestamp, skip
		} else if info.Index > maxIndex {
			break // after max index, skip
		}

		// Verify offsets are added in order.
		offsets := m[info.Index]
		if len(offsets) == 0 && info.Offset != 0 {
			return nil, fmt.Errorf("missing initial wal segment: generation=%s index=%08x offset=%d", generation, info.Index, info.Offset)
		} else if len(offsets) > 0 && offsets[len(offsets)-1] >= info.Offset {
			return nil, fmt.Errorf("wal segments out of order: generation=%s index=%08x offsets=(%d,%d)", generation, info.Index, offsets[len(offsets)-1], info.Offset)
		}

		// Append to the end of the WAL file.
		m[info.Index] = append(offsets, info.Offset)
	}
	return m, itr.Close()
}

// restoreSnapshot copies a snapshot from the replica to a file.
func (r *Replica) restoreSnapshot(ctx context.Context, generation string, index int, filename string) error {
	// Determine the user/group & mode based on the DB, if available.
	var fileInfo, dirInfo os.FileInfo
	if db := r.DB(); db != nil {
		fileInfo, dirInfo = db.fileInfo, db.dirInfo
	}

	if err := internal.MkdirAll(filepath.Dir(filename), dirInfo); err != nil {
		return err
	}

	f, err := internal.CreateFile(filename, fileInfo)
	if err != nil {
		return err
	}
	defer f.Close()

	rd, err := r.Client.SnapshotReader(ctx, generation, index)
	if err != nil {
		return err
	}
	defer rd.Close()

	if _, err := io.Copy(f, lz4.NewReader(rd)); err != nil {
		return err
	} else if err := f.Sync(); err != nil {
		return err
	}
	return f.Close()
}

// downloadWAL copies a WAL file from the replica to a local copy next to the DB.
// The WAL is later applied by applyWAL(). This function can be run in parallel
// to download multiple WAL files simultaneously.
func (r *Replica) downloadWAL(ctx context.Context, generation string, index int, offsets []int64, dbPath string) (err error) {
	// Determine the user/group & mode based on the DB, if available.
	var fileInfo os.FileInfo
	if db := r.DB(); db != nil {
		fileInfo = db.fileInfo
	}

	// Open readers for every segment in the WAL file, in order.
	var readers []io.Reader
	for _, offset := range offsets {
		rd, err := r.Client.WALSegmentReader(ctx, Pos{Generation: generation, Index: index, Offset: offset})
		if err != nil {
			return err
		}
		defer rd.Close()
		readers = append(readers, lz4.NewReader(rd))
	}
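	// NOTE: Each reader above decompresses one segment; chaining them with
	// io.MultiReader below reconstitutes the full WAL file byte-for-byte in
	// offset order.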
	// Open handle to destination WAL path.
	f, err := internal.CreateFile(fmt.Sprintf("%s-%08x-wal", dbPath, index), fileInfo)
	if err != nil {
		return err
	}
	defer f.Close()

	// Combine segments together and copy WAL to target path.
	if _, err := io.Copy(f, io.MultiReader(readers...)); err != nil {
		return err
	} else if err := f.Close(); err != nil {
		return err
	}
	return nil
}

// Replica metrics.
var (
	replicaSnapshotTotalGaugeVec = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: "litestream",
		Subsystem: "replica",
		Name:      "snapshot_total",
		Help:      "The current number of snapshots",
	}, []string{"db", "name"})

	replicaWALBytesCounterVec = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "litestream",
		Subsystem: "replica",
		Name:      "wal_bytes",
		Help:      "The number of WAL bytes written",
	}, []string{"db", "name"})

	replicaWALIndexGaugeVec = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: "litestream",
		Subsystem: "replica",
		Name:      "wal_index",
		Help:      "The current WAL index",
	}, []string{"db", "name"})

	replicaWALOffsetGaugeVec = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: "litestream",
		Subsystem: "replica",
		Name:      "wal_offset",
		Help:      "The current WAL offset",
	}, []string{"db", "name"})

	replicaValidationTotalCounterVec = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "litestream",
		Subsystem: "replica",
		Name:      "validation_total",
		Help:      "The number of validations performed",
	}, []string{"db", "name", "status"})
)
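// These collectors are registered globally via promauto. A minimal sketch of
// exposing them over HTTP (an assumption about the host application's wiring;
// it is not done in this file):
//
//	import "github.com/prometheus/client_golang/prometheus/promhttp"
//
//	http.Handle("/metrics", promhttp.Handler())
//	log.Fatal(http.ListenAndServe(":9090", nil))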