diff --git a/db.go b/db.go index f62ee0d..d7a7d6e 100644 --- a/db.go +++ b/db.go @@ -1454,7 +1454,7 @@ func RestoreReplica(ctx context.Context, r Replica, opt RestoreOptions) (err err if err != nil { return fmt.Errorf("cannot find max wal index for restore: %w", err) } - logger.Printf("%s: starting restore: generation %s, index %08x-%08x", logPrefix, opt.Generation, minWALIndex, maxWALIndex) + snapshotOnly := maxWALIndex == -1 // Initialize starting position. pos := Pos{Generation: opt.Generation, Index: minWALIndex} @@ -1466,6 +1466,15 @@ func RestoreReplica(ctx context.Context, r Replica, opt RestoreOptions) (err err return fmt.Errorf("cannot restore snapshot: %w", err) } + // If no WAL files available, move snapshot to final path & exit early. + if snapshotOnly { + logger.Printf("%s: snapshot only, finalizing database", logPrefix) + return os.Rename(tmpPath, opt.OutputPath) + } + + // Begin processing WAL files. + logger.Printf("%s: restoring wal files: generation=%s index=[%08x,%08x]", logPrefix, opt.Generation, minWALIndex, maxWALIndex) + // Fill input channel with all WAL indexes to be loaded in order. ch := make(chan int, maxWALIndex-minWALIndex+1) for index := minWALIndex; index <= maxWALIndex; index++ { @@ -1476,7 +1485,7 @@ func RestoreReplica(ctx context.Context, r Replica, opt RestoreOptions) (err err // Track load state for each WAL. var mu sync.Mutex cond := sync.NewCond(&mu) - ready := make([]bool, maxWALIndex-minWALIndex+1) + walStates := make([]walRestoreState, maxWALIndex-minWALIndex+1) parallelism := opt.Parallelism if parallelism < 1 { @@ -1499,20 +1508,23 @@ func RestoreReplica(ctx context.Context, r Replica, opt RestoreOptions) (err err } startTime := time.Now() - if err = downloadWAL(ctx, r, opt.Generation, index, tmpPath); os.IsNotExist(err) && index == minWALIndex && index == maxWALIndex { - logger.Printf("%s: no wal available, snapshot only", logPrefix) - continue // snapshot file only, ignore error - } else if err != nil { - cond.Broadcast() - return fmt.Errorf("cannot download wal %s/%08x: %w", opt.Generation, index, err) + + err := downloadWAL(ctx, r, opt.Generation, index, tmpPath) + if err != nil { + err = fmt.Errorf("cannot download wal %s/%08x: %w", opt.Generation, index, err) } // Mark index as ready-to-apply and notify applying code. mu.Lock() - ready[index-minWALIndex] = true + walStates[index-minWALIndex] = walRestoreState{ready: true, err: err} mu.Unlock() cond.Broadcast() + // Returning the error here will cancel the other goroutines. + if err != nil { + return err + } + logger.Printf("%s: downloaded wal %s/%08x elapsed=%s", logPrefix, opt.Generation, index, time.Since(startTime).String(), @@ -1526,12 +1538,15 @@ func RestoreReplica(ctx context.Context, r Replica, opt RestoreOptions) (err err for index := minWALIndex; index <= maxWALIndex; index++ { // Wait until next WAL file is ready to apply. mu.Lock() - for !ready[index-minWALIndex] { + for !walStates[index-minWALIndex].ready { if err := ctx.Err(); err != nil { return err } cond.Wait() } + if err := walStates[index-minWALIndex].err; err != nil { + return err + } mu.Unlock() // Apply WAL to database file. @@ -1545,6 +1560,12 @@ func RestoreReplica(ctx context.Context, r Replica, opt RestoreOptions) (err err ) } + // Ensure all goroutines finish. All errors should have been handled during + // the processing of WAL files but this ensures that all processing is done. + if err := g.Wait(); err != nil { + return err + } + // Copy file to final location. logger.Printf("%s: renaming database from temporary location", logPrefix) if err := os.Rename(tmpPath, opt.OutputPath); err != nil { @@ -1554,6 +1575,11 @@ func RestoreReplica(ctx context.Context, r Replica, opt RestoreOptions) (err err return nil } +type walRestoreState struct { + ready bool + err error +} + // CalcRestoreTarget returns a replica & generation to restore from based on opt criteria. func (db *DB) CalcRestoreTarget(ctx context.Context, opt RestoreOptions) (Replica, string, error) { var target struct { diff --git a/replica.go b/replica.go index 0435315..74015e8 100644 --- a/replica.go +++ b/replica.go @@ -1072,7 +1072,9 @@ func SnapshotIndexAt(ctx context.Context, r Replica, generation string, timestam snapshotIndex := -1 var max time.Time for _, snapshot := range snapshots { - if !timestamp.IsZero() && snapshot.CreatedAt.After(timestamp) { + if snapshot.Generation != generation { + continue // generation mismatch, skip + } else if !timestamp.IsZero() && snapshot.CreatedAt.After(timestamp) { continue // after timestamp, skip } @@ -1116,15 +1118,16 @@ func SnapshotIndexByIndex(ctx context.Context, r Replica, generation string, ind return snapshotIndex, nil } -// WALIndexAt returns the highest index for a WAL file that occurs before maxIndex & timestamp. -// If timestamp is zero, returns the highest WAL index. +// WALIndexAt returns the highest index for a WAL file that occurs before +// maxIndex & timestamp. If timestamp is zero, returns the highest WAL index. +// Returns -1 if no WAL found and MaxInt32 specified. func WALIndexAt(ctx context.Context, r Replica, generation string, maxIndex int, timestamp time.Time) (int, error) { wals, err := r.WALs(ctx) if err != nil { return 0, err } - var index int + index := -1 for _, wal := range wals { if wal.Generation != generation { continue