Prevent double-close for SFTP client
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -264,3 +265,18 @@ func TruncateDuration(d time.Duration) time.Duration {
|
|||||||
func MD5Hash(b []byte) string {
|
func MD5Hash(b []byte) string {
|
||||||
return fmt.Sprintf("%x", md5.Sum(b))
|
return fmt.Sprintf("%x", md5.Sum(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnceCloser returns a closer that will only ignore duplicate closes.
|
||||||
|
func OnceCloser(c io.Closer) io.Closer {
|
||||||
|
return &onceCloser{Closer: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
type onceCloser struct {
|
||||||
|
sync.Once
|
||||||
|
io.Closer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *onceCloser) Close() (err error) {
|
||||||
|
c.Once.Do(func() { err = c.Closer.Close() })
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/benbjohnson/litestream/internal"
|
"github.com/benbjohnson/litestream/internal"
|
||||||
|
"github.com/benbjohnson/litestream/mock"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseSnapshotPath(t *testing.T) {
|
func TestParseSnapshotPath(t *testing.T) {
|
||||||
@@ -100,3 +101,27 @@ func TestTruncateDuration(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOnceCloser(t *testing.T) {
|
||||||
|
var closed bool
|
||||||
|
var rc = &mock.ReadCloser{
|
||||||
|
CloseFunc: func() error {
|
||||||
|
if closed {
|
||||||
|
t.Fatal("already closed")
|
||||||
|
}
|
||||||
|
closed = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
oc := internal.OnceCloser(rc)
|
||||||
|
if err := oc.Close(); err != nil {
|
||||||
|
t.Fatalf("first close: %s", err)
|
||||||
|
} else if err := oc.Close(); err != nil {
|
||||||
|
t.Fatalf("second close: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !closed {
|
||||||
|
t.Fatal("expected close")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -270,12 +270,13 @@ func (c *ReplicaClient) WriteSnapshot(ctx context.Context, generation string, in
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return info, fmt.Errorf("cannot open snapshot file for writing: %w", err)
|
return info, fmt.Errorf("cannot open snapshot file for writing: %w", err)
|
||||||
}
|
}
|
||||||
defer f.Close()
|
closer := internal.OnceCloser(f)
|
||||||
|
defer closer.Close()
|
||||||
|
|
||||||
n, err := io.Copy(f, rd)
|
n, err := io.Copy(f, rd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return info, err
|
return info, err
|
||||||
} else if err := f.Close(); err != nil {
|
} else if err := closer.Close(); err != nil {
|
||||||
return info, err
|
return info, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -391,12 +392,13 @@ func (c *ReplicaClient) WriteWALSegment(ctx context.Context, pos litestream.Pos,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return info, fmt.Errorf("cannot open snapshot file for writing: %w", err)
|
return info, fmt.Errorf("cannot open snapshot file for writing: %w", err)
|
||||||
}
|
}
|
||||||
defer f.Close()
|
closer := internal.OnceCloser(f)
|
||||||
|
defer closer.Close()
|
||||||
|
|
||||||
n, err := io.Copy(f, rd)
|
n, err := io.Copy(f, rd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return info, err
|
return info, err
|
||||||
} else if err := f.Close(); err != nil {
|
} else if err := closer.Close(); err != nil {
|
||||||
return info, err
|
return info, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user