Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework client to prevent after-Close usage, and support perm at Open #574

Merged
merged 12 commits into from
Feb 12, 2024
19 changes: 17 additions & 2 deletions attrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func (fi *fileInfo) Name() string { return fi.name }
func (fi *fileInfo) Size() int64 { return int64(fi.stat.Size) }

// Mode returns file mode bits.
func (fi *fileInfo) Mode() os.FileMode { return toFileMode(fi.stat.Mode) }
func (fi *fileInfo) Mode() os.FileMode { return fi.stat.FileMode() }

// ModTime returns the last modification time of the file.
func (fi *fileInfo) ModTime() time.Time { return time.Unix(int64(fi.stat.Mtime), 0) }
func (fi *fileInfo) ModTime() time.Time { return fi.stat.ModTime() }

// IsDir returns true if the file is a directory.
func (fi *fileInfo) IsDir() bool { return fi.Mode().IsDir() }
Expand All @@ -56,6 +56,21 @@ type FileStat struct {
Extended []StatExtended
}

// ModTime returns the Mtime SFTP file attribute converted to a time.Time
func (fs *FileStat) ModTime() time.Time {
return time.Unix(int64(fs.Mtime), 0)
}

// AccessTime returns the Atime SFTP file attribute converted to a time.Time
func (fs *FileStat) AccessTime() time.Time {
return time.Unix(int64(fs.Atime), 0)
}

// FileMode returns the Mode SFTP file attribute converted to an os.FileMode
func (fs *FileStat) FileMode() os.FileMode {
return toFileMode(fs.Mode)
}
Comment on lines +69 to +72
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was located in a different file, which is a little weird, since it operates on FileStat.


// StatExtended contains additional, extended information for a FileStat.
type StatExtended struct {
ExtType string
Expand Down
178 changes: 145 additions & 33 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
// read/write at the same time. For those services you will need to use
// `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`.
func (c *Client) Create(path string) (*File, error) {
return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
return c.open(path, toPflags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
drakkan marked this conversation as resolved.
Show resolved Hide resolved
}

const sftpProtocolVersion = 3 // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt
Expand Down Expand Up @@ -363,7 +363,10 @@ func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, e
filename, data = unmarshalString(data)
_, data = unmarshalString(data) // discard longname
var attr *FileStat
attr, data = unmarshalAttrs(data)
attr, data, err = unmarshalAttrs(data)
if err != nil {
return nil, err
}
if filename == "." || filename == ".." {
continue
}
Expand Down Expand Up @@ -434,8 +437,8 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) {
if sid != id {
return nil, &unexpectedIDErr{id, sid}
}
attr, _ := unmarshalAttrs(data)
return fileInfoFromStat(attr, path.Base(p)), nil
attr, _, err := unmarshalAttrs(data)
return fileInfoFromStat(attr, path.Base(p)), err
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should prefer returning either a valid left or valid right, not both.

case sshFxpStatus:
return nil, normaliseError(unmarshalStatus(id, data))
default:
Expand Down Expand Up @@ -510,7 +513,7 @@ func (c *Client) Symlink(oldname, newname string) error {
}
}

func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error {
func (c *Client) fsetstat(handle string, flags uint32, attrs interface{}) error {
drakkan marked this conversation as resolved.
Show resolved Hide resolved
id := c.nextID()
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{
ID: id,
Expand Down Expand Up @@ -590,14 +593,14 @@ func (c *Client) Truncate(path string, size int64) error {
// returned file can be used for reading; the associated file descriptor
// has mode O_RDONLY.
func (c *Client) Open(path string) (*File, error) {
return c.open(path, flags(os.O_RDONLY))
return c.open(path, toPflags(os.O_RDONLY))
}

// OpenFile is the generalized open call; most users will use Open or
// Create instead. It opens the named file with specified flag (O_RDONLY
// etc.). If successful, methods on the returned File can be used for I/O.
func (c *Client) OpenFile(path string, f int) (*File, error) {
return c.open(path, flags(f))
return c.open(path, toPflags(f))
}

func (c *Client) open(path string, pflags uint32) (*File, error) {
Expand Down Expand Up @@ -660,8 +663,8 @@ func (c *Client) stat(path string) (*FileStat, error) {
if sid != id {
return nil, &unexpectedIDErr{id, sid}
}
attr, _ := unmarshalAttrs(data)
return attr, nil
attr, _, err := unmarshalAttrs(data)
return attr, err
case sshFxpStatus:
return nil, normaliseError(unmarshalStatus(id, data))
default:
Expand All @@ -684,8 +687,8 @@ func (c *Client) fstat(handle string) (*FileStat, error) {
if sid != id {
return nil, &unexpectedIDErr{id, sid}
}
attr, _ := unmarshalAttrs(data)
return attr, nil
attr, _, err := unmarshalAttrs(data)
return attr, err
case sshFxpStatus:
return nil, normaliseError(unmarshalStatus(id, data))
default:
Expand Down Expand Up @@ -974,18 +977,32 @@ func (c *Client) RemoveAll(path string) error {

// File represents a remote file.
type File struct {
c *Client
path string
handle string
c *Client
path string

mu sync.Mutex
mu sync.RWMutex
handle string
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

handle is now protected by mutex.

offset int64 // current offset within remote file
}

// Close closes the File, rendering it unusable for I/O. It returns an
// error, if any.
func (f *File) Close() error {
return f.c.close(f.handle)
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return os.ErrClosed
}

// When `openssh-portable/sftp-server.c` is doing `handle_close`,
// it will unconditionally mark the handle as unused,
// so we need to also unconditionally mark this handle as invalid.

handle := f.handle
f.handle = ""
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the meat of the fix for double closes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The design principle here is that the handle on the server side may be reused after a Close, and thus we should synchronize around invalidating the handle in our struct in order to indicate this closed status.

This way, one cannot “accidentally” also use-after-close something like Read or Write as well.


return f.c.close(handle)
}

// Name returns the name of the file as presented to Open or Create.
Expand All @@ -1006,7 +1023,11 @@ func (f *File) Read(b []byte) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()

n, err := f.ReadAt(b, f.offset)
if f.handle == "" {
return 0, os.ErrClosed
}

n, err := f.readAt(b, f.offset)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function has to change names, so that ReadAt can read-lock, while this function needs to hold the write-lock.

So, readat needs to be called with a lock, but it doesn’t matter which lock.

f.offset += int64(n)
return n, err
}
Expand Down Expand Up @@ -1071,6 +1092,17 @@ func (f *File) readAtSequential(b []byte, off int64) (read int, err error) {
// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics,
// so the file offset is not altered during the read.
func (f *File) ReadAt(b []byte, off int64) (int, error) {
f.mu.RLock()
defer f.mu.RUnlock()
Comment on lines +1097 to +1098
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mutex is converted to a RWMutex so that we can distinguish between “critical only one request is ever pending at a time” from “must be allowed to run with essentially arbitrary parallelism”

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, so this is why we use f.mu.Lock() in Read/Write and f.mu.RLock() in ReadAt/WriteAt even though they both call f.readAt/f.writeAt

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is why.


if f.handle == "" {
return 0, os.ErrClosed
}

return f.readAt(b, off)
}

func (f *File) readAt(b []byte, off int64) (int, error) {
if len(b) <= f.c.maxPacket {
// This should be able to be serviced with 1/2 requests.
// So, just do it directly.
Expand Down Expand Up @@ -1267,6 +1299,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return 0, os.ErrClosed
}

if f.c.disableConcurrentReads {
return f.writeToSequential(w)
}
Expand Down Expand Up @@ -1459,6 +1495,17 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
// Stat returns the FileInfo structure describing file. If there is an
// error.
func (f *File) Stat() (os.FileInfo, error) {
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return nil, os.ErrClosed
}

return f.stat()
}

func (f *File) stat() (os.FileInfo, error) {
drakkan marked this conversation as resolved.
Show resolved Hide resolved
fs, err := f.c.fstat(f.handle)
if err != nil {
return nil, err
Expand All @@ -1478,7 +1525,11 @@ func (f *File) Write(b []byte) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()

n, err := f.WriteAt(b, f.offset)
if f.handle == "" {
return 0, os.ErrClosed
}

n, err := f.writeAt(b, f.offset)
f.offset += int64(n)
return n, err
}
Expand Down Expand Up @@ -1636,6 +1687,17 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) {
// the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics,
// so the file offset is not altered during the write.
func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return 0, os.ErrClosed
}

return f.writeAt(b, off)
}

func (f *File) writeAt(b []byte, off int64) (written int, err error) {
if len(b) <= f.c.maxPacket {
// We can do this in one write.
return f.writeChunkAt(nil, b, off)
Expand Down Expand Up @@ -1675,6 +1737,17 @@ func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
//
// Otherwise, the given concurrency will be capped by the Client's max concurrency.
func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return 0, os.ErrClosed
}

return f.readFromWithConcurrency(r, concurrency)
}

func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
drakkan marked this conversation as resolved.
Show resolved Hide resolved
// Split the write into multiple maxPacket sized concurrent writes.
// This allows writes with a suitably large reader
// to transfer data at a much faster rate due to overlapping round trip times.
Expand Down Expand Up @@ -1824,6 +1897,10 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return 0, os.ErrClosed
}

if f.c.useConcurrentWrites {
var remain int64
switch r := r.(type) {
Expand All @@ -1845,7 +1922,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {

if remain < 0 {
// We can strongly assert that we want default max concurrency here.
return f.ReadFromWithConcurrency(r, f.c.maxConcurrentRequests)
return f.readFromWithConcurrency(r, f.c.maxConcurrentRequests)
}

if remain > int64(f.c.maxPacket) {
Expand All @@ -1860,7 +1937,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
concurrency64 = int64(f.c.maxConcurrentRequests)
}

return f.ReadFromWithConcurrency(r, int(concurrency64))
return f.readFromWithConcurrency(r, int(concurrency64))
}
}

Expand Down Expand Up @@ -1903,12 +1980,16 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return 0, os.ErrClosed
}

switch whence {
case io.SeekStart:
case io.SeekCurrent:
offset += f.offset
case io.SeekEnd:
fi, err := f.Stat()
fi, err := f.stat()
if err != nil {
return f.offset, err
}
Expand All @@ -1927,20 +2008,60 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {

// Chown changes the uid/gid of the current file.
func (f *File) Chown(uid, gid int) error {
return f.c.Chown(f.path, uid, gid)
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return os.ErrClosed
}

return f.c.fsetstat(f.handle, sshFileXferAttrUIDGID, &FileStat{
UID: uint32(uid),
GID: uint32(gid),
})
drakkan marked this conversation as resolved.
Show resolved Hide resolved
}

// Chmod changes the permissions of the current file.
//
// See Client.Chmod for details.
func (f *File) Chmod(mode os.FileMode) error {
return f.c.setfstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode))
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return os.ErrClosed
}

return f.c.fsetstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode))
}

// Truncate sets the size of the current file. Although it may be safely assumed
// that if the size is less than its current size it will be truncated to fit,
// the SFTP protocol does not specify what behavior the server should do when setting
// size greater than the current size.
// We send a SSH_FXP_FSETSTAT here since we have a file handle
func (f *File) Truncate(size int64) error {
drakkan marked this conversation as resolved.
Show resolved Hide resolved
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return os.ErrClosed
}

return f.c.fsetstat(f.handle, sshFileXferAttrSize, uint64(size))
}

// Sync requests a flush of the contents of a File to stable storage.
//
// Sync requires the server to support the [email protected] extension.
func (f *File) Sync() error {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return os.ErrClosed
}

id := f.c.nextID()
typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{
ID: id,
Expand All @@ -1957,15 +2078,6 @@ func (f *File) Sync() error {
}
}

// Truncate sets the size of the current file. Although it may be safely assumed
// that if the size is less than its current size it will be truncated to fit,
// the SFTP protocol does not specify what behavior the server should do when setting
// size greater than the current size.
// We send a SSH_FXP_FSETSTAT here since we have a file handle
func (f *File) Truncate(size int64) error {
return f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size))
}

// normaliseError normalises an error into a more standard form that can be
// checked against stdlib errors like io.EOF or os.ErrNotExist.
func normaliseError(err error) error {
Expand All @@ -1990,7 +2102,7 @@ func normaliseError(err error) error {

// flags converts the flags passed to OpenFile into ssh flags.
// Unsupported flags are ignored.
func flags(f int) uint32 {
func toPflags(f int) uint32 {
drakkan marked this conversation as resolved.
Show resolved Hide resolved
var out uint32
switch f & os.O_WRONLY {
case os.O_WRONLY:
Expand Down
Loading
Loading