diff --git a/internal/internal.go b/internal/internal.go index c671379..95d0f78 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -8,6 +8,7 @@ import ( "os" "regexp" "strconv" + "sync" "syscall" "time" @@ -264,3 +265,18 @@ func TruncateDuration(d time.Duration) time.Duration { func MD5Hash(b []byte) string { return fmt.Sprintf("%x", md5.Sum(b)) } + +// OnceCloser returns a closer that will only ignore duplicate closes. +func OnceCloser(c io.Closer) io.Closer { + return &onceCloser{Closer: c} +} + +type onceCloser struct { + sync.Once + io.Closer +} + +func (c *onceCloser) Close() (err error) { + c.Once.Do(func() { err = c.Closer.Close() }) + return err +} diff --git a/internal/internal_test.go b/internal/internal_test.go index 9d2c49b..11d0f6d 100644 --- a/internal/internal_test.go +++ b/internal/internal_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/benbjohnson/litestream/internal" + "github.com/benbjohnson/litestream/mock" ) func TestParseSnapshotPath(t *testing.T) { @@ -100,3 +101,27 @@ func TestTruncateDuration(t *testing.T) { }) } } + +func TestOnceCloser(t *testing.T) { + var closed bool + var rc = &mock.ReadCloser{ + CloseFunc: func() error { + if closed { + t.Fatal("already closed") + } + closed = true + return nil + }, + } + + oc := internal.OnceCloser(rc) + if err := oc.Close(); err != nil { + t.Fatalf("first close: %s", err) + } else if err := oc.Close(); err != nil { + t.Fatalf("second close: %s", err) + } + + if !closed { + t.Fatal("expected close") + } +} diff --git a/sftp/replica_client.go b/sftp/replica_client.go index 8566ab0..269e8ca 100644 --- a/sftp/replica_client.go +++ b/sftp/replica_client.go @@ -270,12 +270,13 @@ func (c *ReplicaClient) WriteSnapshot(ctx context.Context, generation string, in if err != nil { return info, fmt.Errorf("cannot open snapshot file for writing: %w", err) } - defer f.Close() + closer := internal.OnceCloser(f) + defer closer.Close() n, err := io.Copy(f, rd) if err != nil { return info, err - } else if err := f.Close(); err != nil { + } else if err := closer.Close(); err != nil { return info, err } @@ -391,12 +392,13 @@ func (c *ReplicaClient) WriteWALSegment(ctx context.Context, pos litestream.Pos, if err != nil { return info, fmt.Errorf("cannot open snapshot file for writing: %w", err) } - defer f.Close() + closer := internal.OnceCloser(f) + defer closer.Close() n, err := io.Copy(f, rd) if err != nil { return info, err - } else if err := f.Close(); err != nil { + } else if err := closer.Close(); err != nil { return info, err }