From bcdb553267ebfe74a54f2db075178fb34235ae33 Mon Sep 17 00:00:00 2001 From: Ben Johnson Date: Mon, 11 Jan 2021 09:39:08 -0700 Subject: [PATCH] Use database owner/group --- db.go | 17 +++++++------ litestream.go | 58 +++++++++++++++++++++++++++++++++++++++++++ litestream_unix.go | 18 ++++++++++++++ litestream_windows.go | 23 +++++++++++++++++ replica.go | 12 ++++----- 5 files changed, 115 insertions(+), 13 deletions(-) create mode 100644 litestream_unix.go create mode 100644 litestream_windows.go diff --git a/db.go b/db.go index b903201..09ed776 100644 --- a/db.go +++ b/db.go @@ -41,6 +41,7 @@ type DB struct { rtx *sql.Tx // long running read transaction pageSize int // page size, in bytes notify chan struct{} // closes on WAL change + uid, gid int // db user/group obtained on init ctx context.Context cancel func() @@ -342,11 +343,13 @@ func (db *DB) init() (err error) { } // Exit if no database file exists. - if _, err := os.Stat(db.path); os.IsNotExist(err) { + fi, err := os.Stat(db.path) + if os.IsNotExist(err) { return nil } else if err != nil { return err } + db.uid, db.gid = fileinfo(fi) // Connect to SQLite database & enable WAL. if db.db, err = sql.Open("sqlite3", db.path); err != nil { @@ -386,7 +389,7 @@ func (db *DB) init() (err error) { } // Ensure meta directory structure exists. - if err := os.MkdirAll(db.MetaPath(), 0700); err != nil { + if err := mkdirAll(db.MetaPath(), 0700, db.uid, db.gid); err != nil { return err } @@ -569,7 +572,7 @@ func (db *DB) createGeneration() (string, error) { // Generate new directory. dir := filepath.Join(db.MetaPath(), "generations", generation) - if err := os.MkdirAll(dir, 0700); err != nil { + if err := mkdirAll(dir, 0700, db.uid, db.gid); err != nil { return "", err } @@ -888,7 +891,7 @@ func (db *DB) initShadowWALFile(filename string) error { } // Write header to new WAL shadow file. - if err := os.MkdirAll(filepath.Dir(filename), 0700); err != nil { + if err := mkdirAll(filepath.Dir(filename), 0700, db.uid, db.gid); err != nil { return err } return ioutil.WriteFile(filename, hdr, 0600) @@ -1383,11 +1386,11 @@ func (db *DB) restoreTarget(ctx context.Context, opt RestoreOptions, logger *log // restoreSnapshot copies a snapshot from the replica to a file. func (db *DB) restoreSnapshot(ctx context.Context, r Replica, generation string, index int, filename string) error { - if err := os.MkdirAll(filepath.Dir(filename), 0700); err != nil { + if err := mkdirAll(filepath.Dir(filename), 0700, db.uid, db.gid); err != nil { return err } - f, err := os.Create(filename) + f, err := createFile(filename, db.uid, db.gid) if err != nil { return err } @@ -1419,7 +1422,7 @@ func (db *DB) restoreWAL(ctx context.Context, r Replica, generation string, inde defer rd.Close() // Open handle to destination WAL path. - f, err := os.Create(dbPath + "-wal") + f, err := createFile(dbPath+"-wal", db.uid, db.gid) if err != nil { return err } diff --git a/litestream.go b/litestream.go index 828b94f..270e04b 100644 --- a/litestream.go +++ b/litestream.go @@ -12,6 +12,7 @@ import ( "regexp" "strconv" "strings" + "syscall" "time" _ "github.com/mattn/go-sqlite3" @@ -255,6 +256,63 @@ func (r *gzipReadCloser) Close() error { return r.closer.Close() } +// createFile creates the file and attempts to set the UID/GID. +func createFile(filename string, uid, gid int) (*os.File, error) { + f, err := os.Create(filename) + if err != nil { + return nil, err + } + _ = f.Chown(uid, gid) + return f, nil +} + +// mkdirAll is a copy of os.MkdirAll() except that it attempts to set the +// uid/gid for each created directory. +func mkdirAll(path string, perm os.FileMode, uid, gid int) error { + // Fast path: if we can tell whether path is a directory or file, stop with success or error. + dir, err := os.Stat(path) + if err == nil { + if dir.IsDir() { + return nil + } + return &os.PathError{Op: "mkdir", Path: path, Err: syscall.ENOTDIR} + } + + // Slow path: make sure parent exists and then call Mkdir for path. + i := len(path) + for i > 0 && os.IsPathSeparator(path[i-1]) { // Skip trailing path separator. + i-- + } + + j := i + for j > 0 && !os.IsPathSeparator(path[j-1]) { // Scan backward over element. + j-- + } + + if j > 1 { + // Create parent. + err = mkdirAll(fixRootDirectory(path[:j-1]), perm, uid, gid) + if err != nil { + return err + } + } + + // Parent now exists; invoke Mkdir and use its result. + err = os.Mkdir(path, perm) + if err != nil { + // Handle arguments like "foo/." by + // double-checking that directory doesn't exist. + dir, err1 := os.Lstat(path) + if err1 == nil && dir.IsDir() { + _ = os.Chown(path, uid, gid) + return nil + } + return err + } + _ = os.Chown(path, uid, gid) + return nil +} + func assert(condition bool, message string) { if !condition { panic("assertion failed: " + message) diff --git a/litestream_unix.go b/litestream_unix.go new file mode 100644 index 0000000..7ec7618 --- /dev/null +++ b/litestream_unix.go @@ -0,0 +1,18 @@ +// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris + +package litestream + +import ( + "os" + "syscall" +) + +// fileinfo returns syscall fields from a FileInfo object. +func fileinfo(fi os.FileInfo) (uid, gid int) { + stat := fi.Sys().(*syscall.Stat_t) + return int(stat.Uid), int(stat.Gid) +} + +func fixRootDirectory(p string) string { + return p +} diff --git a/litestream_windows.go b/litestream_windows.go new file mode 100644 index 0000000..ba745b5 --- /dev/null +++ b/litestream_windows.go @@ -0,0 +1,23 @@ +// +build windows + +package litestream + +import ( + "os" + "syscall" +) + +// fileinfo returns syscall fields from a FileInfo object. +func fileinfo(fi os.FileInfo) (uid, gid int) { + return -1, -1 +} + +// fixRootDirectory is copied from the standard library for use with mkdirAll() +func fixRootDirectory(p string) string { + if len(p) == len(`\\?\c:`) { + if IsPathSeparator(p[0]) && IsPathSeparator(p[1]) && p[2] == '?' && IsPathSeparator(p[3]) && p[5] == ':' { + return p + `\` + } + } + return p +} diff --git a/replica.go b/replica.go index a09bb11..3d3e4a9 100644 --- a/replica.go +++ b/replica.go @@ -531,11 +531,11 @@ func (r *FileReplica) snapshot(ctx context.Context, generation string, index int return nil } - if err := os.MkdirAll(filepath.Dir(snapshotPath), 0700); err != nil { + if err := mkdirAll(filepath.Dir(snapshotPath), 0700, r.db.uid, r.db.gid); err != nil { return err } - return compressFile(r.db.Path(), snapshotPath) + return compressFile(r.db.Path(), snapshotPath, r.db.uid, r.db.gid) } // snapshotN returns the number of snapshots for a generation. @@ -617,7 +617,7 @@ func (r *FileReplica) syncWAL(ctx context.Context) (err error) { // Ensure parent directory exists for WAL file. filename := r.WALPath(rd.Pos().Generation, rd.Pos().Index) - if err := os.MkdirAll(filepath.Dir(filename), 0700); err != nil { + if err := mkdirAll(filepath.Dir(filename), 0700, r.db.uid, r.db.gid); err != nil { return err } @@ -669,7 +669,7 @@ func (r *FileReplica) compress(ctx context.Context, generation string) error { } dst := filename + ".gz" - if err := compressFile(filename, dst); err != nil { + if err := compressFile(filename, dst, r.db.uid, r.db.gid); err != nil { return err } else if err := os.Remove(filename); err != nil { return err @@ -824,14 +824,14 @@ func (r *FileReplica) WALReader(ctx context.Context, generation string, index in } // compressFile compresses a file and replaces it with a new file with a .gz extension. -func compressFile(src, dst string) error { +func compressFile(src, dst string, uid, gid int) error { r, err := os.Open(src) if err != nil { return err } defer r.Close() - w, err := os.Create(dst + ".tmp") + w, err := createFile(dst+".tmp", uid, gid) if err != nil { return err }