Prevent double-close for SFTP client
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -264,3 +265,18 @@ func TruncateDuration(d time.Duration) time.Duration {
|
||||
func MD5Hash(b []byte) string {
|
||||
return fmt.Sprintf("%x", md5.Sum(b))
|
||||
}
|
||||
|
||||
// OnceCloser returns a closer that will only ignore duplicate closes.
|
||||
func OnceCloser(c io.Closer) io.Closer {
|
||||
return &onceCloser{Closer: c}
|
||||
}
|
||||
|
||||
type onceCloser struct {
|
||||
sync.Once
|
||||
io.Closer
|
||||
}
|
||||
|
||||
func (c *onceCloser) Close() (err error) {
|
||||
c.Once.Do(func() { err = c.Closer.Close() })
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/benbjohnson/litestream/internal"
|
||||
"github.com/benbjohnson/litestream/mock"
|
||||
)
|
||||
|
||||
func TestParseSnapshotPath(t *testing.T) {
|
||||
@@ -100,3 +101,27 @@ func TestTruncateDuration(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnceCloser(t *testing.T) {
|
||||
var closed bool
|
||||
var rc = &mock.ReadCloser{
|
||||
CloseFunc: func() error {
|
||||
if closed {
|
||||
t.Fatal("already closed")
|
||||
}
|
||||
closed = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
oc := internal.OnceCloser(rc)
|
||||
if err := oc.Close(); err != nil {
|
||||
t.Fatalf("first close: %s", err)
|
||||
} else if err := oc.Close(); err != nil {
|
||||
t.Fatalf("second close: %s", err)
|
||||
}
|
||||
|
||||
if !closed {
|
||||
t.Fatal("expected close")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,12 +270,13 @@ func (c *ReplicaClient) WriteSnapshot(ctx context.Context, generation string, in
|
||||
if err != nil {
|
||||
return info, fmt.Errorf("cannot open snapshot file for writing: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
closer := internal.OnceCloser(f)
|
||||
defer closer.Close()
|
||||
|
||||
n, err := io.Copy(f, rd)
|
||||
if err != nil {
|
||||
return info, err
|
||||
} else if err := f.Close(); err != nil {
|
||||
} else if err := closer.Close(); err != nil {
|
||||
return info, err
|
||||
}
|
||||
|
||||
@@ -391,12 +392,13 @@ func (c *ReplicaClient) WriteWALSegment(ctx context.Context, pos litestream.Pos,
|
||||
if err != nil {
|
||||
return info, fmt.Errorf("cannot open snapshot file for writing: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
closer := internal.OnceCloser(f)
|
||||
defer closer.Close()
|
||||
|
||||
n, err := io.Copy(f, rd)
|
||||
if err != nil {
|
||||
return info, err
|
||||
} else if err := f.Close(); err != nil {
|
||||
} else if err := closer.Close(); err != nil {
|
||||
return info, err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user