Hold read lock while snapshotting.

This commit is contained in:
Ben Johnson
2020-12-24 16:32:48 -07:00
parent 9fa526f2c3
commit e8e4bffdb2

View File

@@ -1,6 +1,7 @@
package litestream package litestream
import ( import (
"compress/gzip"
"context" "context"
"fmt" "fmt"
"io" "io"
@@ -8,6 +9,7 @@ import (
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"sort"
"strings" "strings"
"sync" "sync"
) )
@@ -17,7 +19,6 @@ import (
type Replicator interface { type Replicator interface {
Name() string Name() string
Type() string Type() string
Snapshotting() bool
Start(ctx context.Context) Start(ctx context.Context)
Stop() Stop()
} }
@@ -30,9 +31,8 @@ type FileReplicator struct {
name string // replicator name, optional name string // replicator name, optional
dst string // destination path dst string // destination path
mu sync.RWMutex // mu sync.RWMutex
wg sync.WaitGroup wg sync.WaitGroup
snapshotting bool // if true, currently copying database
ctx context.Context ctx context.Context
cancel func() cancel func()
@@ -63,7 +63,7 @@ func (r *FileReplicator) Type() string {
// SnapshotPath returns the path to a snapshot file. // SnapshotPath returns the path to a snapshot file.
func (r *FileReplicator) SnapshotPath(generation string, index int) string { func (r *FileReplicator) SnapshotPath(generation string, index int) string {
return filepath.Join(r.dst, "generations", generation, "snapshots", fmt.Sprintf("%016x.snapshot", index)) return filepath.Join(r.dst, "generations", generation, "snapshots", fmt.Sprintf("%016x.snapshot.gz", index))
} }
// WALPath returns the path to a WAL file. // WALPath returns the path to a WAL file.
@@ -71,24 +71,11 @@ func (r *FileReplicator) WALPath(generation string, index int) string {
return filepath.Join(r.dst, "generations", generation, "wal", fmt.Sprintf("%016x.wal", index)) return filepath.Join(r.dst, "generations", generation, "wal", fmt.Sprintf("%016x.wal", index))
} }
// Snapshotting returns true if replicator is current snapshotting.
func (r *FileReplicator) Snapshotting() bool {
r.mu.RLock()
defer r.mu.RLock()
return r.snapshotting
}
// Start starts replication for a given generation. // Start starts replication for a given generation.
func (r *FileReplicator) Start(ctx context.Context) { func (r *FileReplicator) Start(ctx context.Context) {
// Stop previous replication. // Stop previous replication.
r.Stop() r.Stop()
r.mu.Lock()
defer r.mu.Unlock()
// Set snapshotting state.
r.snapshotting = true
// Wrap context with cancelation. // Wrap context with cancelation.
ctx, r.cancel = context.WithCancel(ctx) ctx, r.cancel = context.WithCancel(ctx)
@@ -146,6 +133,14 @@ func (r *FileReplicator) monitor(ctx context.Context) {
log.Printf("%s(%s): sync error: %s", r.db.Path(), r.Name(), err) log.Printf("%s(%s): sync error: %s", r.db.Path(), r.Name(), err)
continue continue
} }
// Gzip any old WAL files.
if pos.Generation != "" {
if err := r.compress(ctx, pos.Generation); err != nil {
log.Printf("%s(%s): compress error: %s", r.db.Path(), r.Name(), err)
continue
}
}
} }
} }
@@ -172,11 +167,14 @@ func (r *FileReplicator) pos() (pos Pos, err error) {
index := -1 index := -1
for _, fi := range fis { for _, fi := range fis {
if !strings.HasSuffix(fi.Name(), WALExt) { name := fi.Name()
name = strings.TrimSuffix(name, ".gz")
if !strings.HasSuffix(name, WALExt) {
continue continue
} }
if v, err := ParseWALFilename(filepath.Base(fi.Name())); err != nil { if v, err := ParseWALFilename(filepath.Base(name)); err != nil {
continue // invalid wal filename continue // invalid wal filename
} else if index == -1 || v > index { } else if index == -1 || v > index {
index = v index = v
@@ -199,17 +197,15 @@ func (r *FileReplicator) pos() (pos Pos, err error) {
// snapshot copies the entire database to the replica path. // snapshot copies the entire database to the replica path.
func (r *FileReplicator) snapshot(ctx context.Context, generation string, index int) error { func (r *FileReplicator) snapshot(ctx context.Context, generation string, index int) error {
// Mark replicator as snapshotting to prevent checkpoints by the DB. // Acquire a read lock on the database during snapshot to prevent checkpoints.
r.mu.Lock() tx, err := r.db.db.Begin()
r.snapshotting = true if err != nil {
r.mu.Unlock() return err
} else if _, err := tx.ExecContext(ctx, `SELECT COUNT(1) FROM _litestream_seq;`); err != nil {
// Ensure we release the snapshot flag when we leave the function. tx.Rollback()
defer func() { return err
r.mu.Lock() }
r.snapshotting = false defer tx.Rollback()
r.mu.Unlock()
}()
// Ignore if we already have a snapshot for the given WAL index. // Ignore if we already have a snapshot for the given WAL index.
snapshotPath := r.SnapshotPath(generation, index) snapshotPath := r.SnapshotPath(generation, index)
@@ -217,33 +213,11 @@ func (r *FileReplicator) snapshot(ctx context.Context, generation string, index
return nil return nil
} }
rd, err := os.Open(r.db.Path())
if err != nil {
return err
}
defer rd.Close()
if err := os.MkdirAll(filepath.Dir(snapshotPath), 0700); err != nil { if err := os.MkdirAll(filepath.Dir(snapshotPath), 0700); err != nil {
return err return err
} }
w, err := os.Create(snapshotPath + ".tmp") return compressFile(r.db.Path(), snapshotPath)
if err != nil {
return err
}
defer w.Close()
if _, err := io.Copy(w, rd); err != nil {
return err
} else if err := w.Sync(); err != nil {
return err
} else if err := w.Close(); err != nil {
return err
} else if err := os.Rename(snapshotPath+".tmp", snapshotPath); err != nil {
return err
}
return nil
} }
// snapshotN returns the number of snapshots for a generation. // snapshotN returns the number of snapshots for a generation.
@@ -257,7 +231,10 @@ func (r *FileReplicator) snapshotN(generation string) (int, error) {
var n int var n int
for _, fi := range fis { for _, fi := range fis {
if strings.HasSuffix(fi.Name(), SnapshotExt) { name := fi.Name()
name = strings.TrimSuffix(name, ".gz")
if strings.HasSuffix(name, SnapshotExt) {
n++ n++
} }
} }
@@ -320,3 +297,68 @@ func (r *FileReplicator) syncNext(ctx context.Context, pos Pos) (_ Pos, err erro
// Return ending position of the reader. // Return ending position of the reader.
return rd.Pos(), nil return rd.Pos(), nil
} }
// compress gzips every uncompressed WAL file of the given generation except
// the last one in sorted order, which is treated as the active WAL and left
// untouched.
//
// Files are processed from oldest to newest. Each one is compressed to
// "<filename>.gz" and the uncompressed original is removed only after the
// compressed copy has been written successfully.
//
// It returns ctx.Err() if the context is canceled between files, or the
// first error encountered while globbing, compressing, or removing a file.
func (r *FileReplicator) compress(ctx context.Context, generation string) error {
	dir := filepath.Join(r.dst, "generations", generation, "wal")
	filenames, err := filepath.Glob(filepath.Join(dir, "*.wal"))
	if err != nil {
		return err
	}
	if len(filenames) <= 1 {
		return nil // no uncompressed wal files or only one active file
	}

	// Ensure filenames are sorted & remove the last (active) WAL.
	sort.Strings(filenames)
	filenames = filenames[:len(filenames)-1]

	// Compress each file from oldest to newest.
	for _, filename := range filenames {
		// Stop early on cancellation. Report the context's error rather than
		// the (nil) glob error so the caller can tell the work was cut short.
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		if err := compressFile(filename, filename+".gz"); err != nil {
			return err
		}
		if err := os.Remove(filename); err != nil {
			return err
		}
	}

	return nil
}
// compressFile writes a gzip-compressed copy of the file at src to dst.
//
// The compressed data is first written to a temporary file at dst+".tmp",
// synced to disk, and then renamed to dst so that a partially written file
// never appears at dst. The src file is neither modified nor removed, and no
// extension is added to dst; both are the caller's responsibility.
//
// If any step after creating the temporary file fails, the temporary file is
// closed and removed before the error is returned.
func compressFile(src, dst string) (err error) {
	r, err := os.Open(src)
	if err != nil {
		return err
	}
	defer r.Close()

	tmp := dst + ".tmp"
	w, err := os.Create(tmp)
	if err != nil {
		return err
	}

	// On any failure below, don't leave the temporary file behind. Closing an
	// already-closed file only yields an error, which is ignored here since
	// the original failure is what gets reported.
	defer func() {
		if err != nil {
			_ = w.Close()
			_ = os.Remove(tmp)
		}
	}()

	// Copy & compress file contents to temporary file.
	gz := gzip.NewWriter(w)
	if _, err = io.Copy(gz, r); err != nil {
		return err
	}

	// The gzip writer must be closed before the file so the trailer is
	// flushed; the file is then synced so the rename publishes durable data.
	if err = gz.Close(); err != nil {
		return err
	}
	if err = w.Sync(); err != nil {
		return err
	}
	if err = w.Close(); err != nil {
		return err
	}

	// Move compressed file to final location.
	return os.Rename(tmp, dst)
}