-
Notifications
You must be signed in to change notification settings - Fork 380
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rework client to prevent after-Close usage, and support perm at Open #574
Changes from 3 commits
d1903fb
f3501dc
e21cd94
3df3035
4cd7ff4
6c7c0da
e808920
ba3d6ab
72aa403
3ce4d4e
5d66cde
159d286
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -257,7 +257,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie | |
// read/write at the same time. For those services you will need to use | ||
// `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`. | ||
func (c *Client) Create(path string) (*File, error) { | ||
return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) | ||
return c.open(path, toPflags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) | ||
drakkan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
const sftpProtocolVersion = 3 // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt | ||
|
@@ -363,7 +363,10 @@ func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, e | |
filename, data = unmarshalString(data) | ||
_, data = unmarshalString(data) // discard longname | ||
var attr *FileStat | ||
attr, data = unmarshalAttrs(data) | ||
attr, data, err = unmarshalAttrs(data) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if filename == "." || filename == ".." { | ||
continue | ||
} | ||
|
@@ -434,8 +437,8 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) { | |
if sid != id { | ||
return nil, &unexpectedIDErr{id, sid} | ||
} | ||
attr, _ := unmarshalAttrs(data) | ||
return fileInfoFromStat(attr, path.Base(p)), nil | ||
attr, _, err := unmarshalAttrs(data) | ||
return fileInfoFromStat(attr, path.Base(p)), err | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should prefer returning either a valid left or valid right, not both. |
||
case sshFxpStatus: | ||
return nil, normaliseError(unmarshalStatus(id, data)) | ||
default: | ||
|
@@ -510,7 +513,7 @@ func (c *Client) Symlink(oldname, newname string) error { | |
} | ||
} | ||
|
||
func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error { | ||
func (c *Client) fsetstat(handle string, flags uint32, attrs interface{}) error { | ||
drakkan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
id := c.nextID() | ||
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{ | ||
ID: id, | ||
|
@@ -590,14 +593,14 @@ func (c *Client) Truncate(path string, size int64) error { | |
// returned file can be used for reading; the associated file descriptor | ||
// has mode O_RDONLY. | ||
func (c *Client) Open(path string) (*File, error) { | ||
return c.open(path, flags(os.O_RDONLY)) | ||
return c.open(path, toPflags(os.O_RDONLY)) | ||
} | ||
|
||
// OpenFile is the generalized open call; most users will use Open or | ||
// Create instead. It opens the named file with specified flag (O_RDONLY | ||
// etc.). If successful, methods on the returned File can be used for I/O. | ||
func (c *Client) OpenFile(path string, f int) (*File, error) { | ||
return c.open(path, flags(f)) | ||
return c.open(path, toPflags(f)) | ||
} | ||
|
||
func (c *Client) open(path string, pflags uint32) (*File, error) { | ||
|
@@ -660,8 +663,8 @@ func (c *Client) stat(path string) (*FileStat, error) { | |
if sid != id { | ||
return nil, &unexpectedIDErr{id, sid} | ||
} | ||
attr, _ := unmarshalAttrs(data) | ||
return attr, nil | ||
attr, _, err := unmarshalAttrs(data) | ||
return attr, err | ||
case sshFxpStatus: | ||
return nil, normaliseError(unmarshalStatus(id, data)) | ||
default: | ||
|
@@ -684,8 +687,8 @@ func (c *Client) fstat(handle string) (*FileStat, error) { | |
if sid != id { | ||
return nil, &unexpectedIDErr{id, sid} | ||
} | ||
attr, _ := unmarshalAttrs(data) | ||
return attr, nil | ||
attr, _, err := unmarshalAttrs(data) | ||
return attr, err | ||
case sshFxpStatus: | ||
return nil, normaliseError(unmarshalStatus(id, data)) | ||
default: | ||
|
@@ -974,18 +977,32 @@ func (c *Client) RemoveAll(path string) error { | |
|
||
// File represents a remote file. | ||
type File struct { | ||
c *Client | ||
path string | ||
handle string | ||
c *Client | ||
path string | ||
|
||
mu sync.Mutex | ||
mu sync.RWMutex | ||
handle string | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
offset int64 // current offset within remote file | ||
} | ||
|
||
// Close closes the File, rendering it unusable for I/O. It returns an | ||
// error, if any. | ||
func (f *File) Close() error { | ||
return f.c.close(f.handle) | ||
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
if f.handle == "" { | ||
return os.ErrClosed | ||
} | ||
|
||
// When `openssh-portable/sftp-server.c` is doing `handle_close`, | ||
// it will unconditionally mark the handle as unused, | ||
// so we need to also unconditionally mark this handle as invalid. | ||
|
||
handle := f.handle | ||
f.handle = "" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the meat of the fix for double closes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The design principle here is that the handle on the server side may be reused after a This way, one cannot “accidentally” also use-after-close something like |
||
|
||
return f.c.close(handle) | ||
} | ||
|
||
// Name returns the name of the file as presented to Open or Create. | ||
|
@@ -1006,7 +1023,11 @@ func (f *File) Read(b []byte) (int, error) { | |
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
n, err := f.ReadAt(b, f.offset) | ||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
n, err := f.readAt(b, f.offset) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function has to change names, so that So, |
||
f.offset += int64(n) | ||
return n, err | ||
} | ||
|
@@ -1071,6 +1092,17 @@ func (f *File) readAtSequential(b []byte, off int64) (read int, err error) { | |
// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics, | ||
// so the file offset is not altered during the read. | ||
func (f *File) ReadAt(b []byte, off int64) (int, error) { | ||
f.mu.RLock() | ||
defer f.mu.RUnlock() | ||
Comment on lines
+1097
to
+1098
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mutex is converted to a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, so this is why we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this is why. |
||
|
||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
return f.readAt(b, off) | ||
} | ||
|
||
func (f *File) readAt(b []byte, off int64) (int, error) { | ||
if len(b) <= f.c.maxPacket { | ||
// This should be able to be serviced with 1/2 requests. | ||
// So, just do it directly. | ||
|
@@ -1267,6 +1299,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { | |
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
if f.c.disableConcurrentReads { | ||
return f.writeToSequential(w) | ||
} | ||
|
@@ -1459,6 +1495,17 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { | |
// Stat returns the FileInfo structure describing file. If there is an | ||
// error. | ||
func (f *File) Stat() (os.FileInfo, error) { | ||
f.mu.RLock() | ||
defer f.mu.RUnlock() | ||
|
||
if f.handle == "" { | ||
return nil, os.ErrClosed | ||
} | ||
|
||
return f.stat() | ||
} | ||
|
||
func (f *File) stat() (os.FileInfo, error) { | ||
drakkan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
fs, err := f.c.fstat(f.handle) | ||
if err != nil { | ||
return nil, err | ||
|
@@ -1478,7 +1525,11 @@ func (f *File) Write(b []byte) (int, error) { | |
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
n, err := f.WriteAt(b, f.offset) | ||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
n, err := f.writeAt(b, f.offset) | ||
f.offset += int64(n) | ||
return n, err | ||
} | ||
|
@@ -1636,6 +1687,17 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) { | |
// the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics, | ||
// so the file offset is not altered during the write. | ||
func (f *File) WriteAt(b []byte, off int64) (written int, err error) { | ||
f.mu.RLock() | ||
defer f.mu.RUnlock() | ||
|
||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
return f.writeAt(b, off) | ||
} | ||
|
||
func (f *File) writeAt(b []byte, off int64) (written int, err error) { | ||
if len(b) <= f.c.maxPacket { | ||
// We can do this in one write. | ||
return f.writeChunkAt(nil, b, off) | ||
|
@@ -1675,6 +1737,17 @@ func (f *File) WriteAt(b []byte, off int64) (written int, err error) { | |
// | ||
// Otherwise, the given concurrency will be capped by the Client's max concurrency. | ||
func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) { | ||
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
return f.readFromWithConcurrency(r, concurrency) | ||
} | ||
|
||
func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) { | ||
drakkan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Split the write into multiple maxPacket sized concurrent writes. | ||
// This allows writes with a suitably large reader | ||
// to transfer data at a much faster rate due to overlapping round trip times. | ||
|
@@ -1824,6 +1897,10 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { | |
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
if f.c.useConcurrentWrites { | ||
var remain int64 | ||
switch r := r.(type) { | ||
|
@@ -1845,7 +1922,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { | |
|
||
if remain < 0 { | ||
// We can strongly assert that we want default max concurrency here. | ||
return f.ReadFromWithConcurrency(r, f.c.maxConcurrentRequests) | ||
return f.readFromWithConcurrency(r, f.c.maxConcurrentRequests) | ||
} | ||
|
||
if remain > int64(f.c.maxPacket) { | ||
|
@@ -1860,7 +1937,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { | |
concurrency64 = int64(f.c.maxConcurrentRequests) | ||
} | ||
|
||
return f.ReadFromWithConcurrency(r, int(concurrency64)) | ||
return f.readFromWithConcurrency(r, int(concurrency64)) | ||
} | ||
} | ||
|
||
|
@@ -1903,12 +1980,16 @@ func (f *File) Seek(offset int64, whence int) (int64, error) { | |
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
if f.handle == "" { | ||
return 0, os.ErrClosed | ||
} | ||
|
||
switch whence { | ||
case io.SeekStart: | ||
case io.SeekCurrent: | ||
offset += f.offset | ||
case io.SeekEnd: | ||
fi, err := f.Stat() | ||
fi, err := f.stat() | ||
if err != nil { | ||
return f.offset, err | ||
} | ||
|
@@ -1927,20 +2008,60 @@ func (f *File) Seek(offset int64, whence int) (int64, error) { | |
|
||
// Chown changes the uid/gid of the current file. | ||
func (f *File) Chown(uid, gid int) error { | ||
return f.c.Chown(f.path, uid, gid) | ||
f.mu.RLock() | ||
defer f.mu.RUnlock() | ||
|
||
if f.handle == "" { | ||
return os.ErrClosed | ||
} | ||
|
||
return f.c.fsetstat(f.handle, sshFileXferAttrUIDGID, &FileStat{ | ||
UID: uint32(uid), | ||
GID: uint32(gid), | ||
}) | ||
drakkan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
// Chmod changes the permissions of the current file. | ||
// | ||
// See Client.Chmod for details. | ||
func (f *File) Chmod(mode os.FileMode) error { | ||
return f.c.setfstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode)) | ||
f.mu.RLock() | ||
defer f.mu.RUnlock() | ||
|
||
if f.handle == "" { | ||
return os.ErrClosed | ||
} | ||
|
||
return f.c.fsetstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode)) | ||
} | ||
|
||
// Truncate sets the size of the current file. Although it may be safely assumed | ||
// that if the size is less than its current size it will be truncated to fit, | ||
// the SFTP protocol does not specify what behavior the server should do when setting | ||
// size greater than the current size. | ||
// We send a SSH_FXP_FSETSTAT here since we have a file handle | ||
func (f *File) Truncate(size int64) error { | ||
drakkan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
f.mu.RLock() | ||
defer f.mu.RUnlock() | ||
|
||
if f.handle == "" { | ||
return os.ErrClosed | ||
} | ||
|
||
return f.c.fsetstat(f.handle, sshFileXferAttrSize, uint64(size)) | ||
} | ||
|
||
// Sync requests a flush of the contents of a File to stable storage. | ||
// | ||
// Sync requires the server to support the [email protected] extension. | ||
func (f *File) Sync() error { | ||
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
if f.handle == "" { | ||
return os.ErrClosed | ||
} | ||
|
||
id := f.c.nextID() | ||
typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{ | ||
ID: id, | ||
|
@@ -1957,15 +2078,6 @@ func (f *File) Sync() error { | |
} | ||
} | ||
|
||
// Truncate sets the size of the current file. Although it may be safely assumed | ||
// that if the size is less than its current size it will be truncated to fit, | ||
// the SFTP protocol does not specify what behavior the server should do when setting | ||
// size greater than the current size. | ||
// We send a SSH_FXP_FSETSTAT here since we have a file handle | ||
func (f *File) Truncate(size int64) error { | ||
return f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size)) | ||
} | ||
|
||
// normaliseError normalises an error into a more standard form that can be | ||
// checked against stdlib errors like io.EOF or os.ErrNotExist. | ||
func normaliseError(err error) error { | ||
|
@@ -1990,7 +2102,7 @@ func normaliseError(err error) error { | |
|
||
// flags converts the flags passed to OpenFile into ssh flags. | ||
// Unsupported flags are ignored. | ||
func flags(f int) uint32 { | ||
func toPflags(f int) uint32 { | ||
drakkan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
var out uint32 | ||
switch f & os.O_WRONLY { | ||
case os.O_WRONLY: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was located in a different file, which is a little weird, since it operates on
FileStat
.