diff --git a/client.go b/client.go index 8f6e0486..11bb0bc2 100644 --- a/client.go +++ b/client.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "io" + "math" "os" "path" "sync" @@ -52,30 +53,6 @@ func MaxPacketChecked(size int) ClientOption { } } -// UseFstat sets whether to use Fstat or Stat when File.WriteTo is called -// (usually when copying files). -// Some servers limit the amount of open files and calling Stat after opening -// the file will throw an error From the server. Setting this flag will call -// Fstat instead of Stat which is suppose to be called on an open file handle. -// -// It has been found that that with IBM Sterling SFTP servers which have -// "extractability" level set to 1 which means only 1 file can be opened at -// any given time. -// -// If the server you are working with still has an issue with both Stat and -// Fstat calls you can always open a file and read it until the end. -// -// Another reason to read the file until its end and Fstat doesn't work is -// that in some servers, reading a full file will automatically delete the -// file as some of these mainframes map the file to a message in a queue. -// Once the file has been read it will get deleted. -func UseFstat(value bool) ClientOption { - return func(c *Client) error { - c.useFstat = value - return nil - } -} - // MaxPacketUnchecked sets the maximum size of the payload, measured in bytes. // It accepts sizes larger than the 32768 bytes all servers should support. // Only use a setting higher than 32768 if your application always connects to @@ -120,6 +97,65 @@ func MaxConcurrentRequestsPerFile(n int) ClientOption { } } +// UseConcurrentWrites allows the Client to perform concurrent Writes. +// +// Using concurrency while doing writes, requires special consideration. +// A write to a later offset in a file after an error, +// could end up with a file length longer than what was successfully written. +// +// When using this option, if you receive an error during `io.Copy` or `io.WriteTo`, +// you may need to `Truncate` the target Writer to avoid “holes” in the data written. +func UseConcurrentWrites(value bool) ClientOption { + return func(c *Client) error { + c.useConcurrentWrites = value + return nil + } +} + +// UseFstat sets whether to use Fstat or Stat when File.WriteTo is called +// (usually when copying files). +// Some servers limit the amount of open files and calling Stat after opening +// the file will throw an error From the server. Setting this flag will call +// Fstat instead of Stat which is suppose to be called on an open file handle. +// +// It has been found that that with IBM Sterling SFTP servers which have +// "extractability" level set to 1 which means only 1 file can be opened at +// any given time. +// +// If the server you are working with still has an issue with both Stat and +// Fstat calls you can always open a file and read it until the end. +// +// Another reason to read the file until its end and Fstat doesn't work is +// that in some servers, reading a full file will automatically delete the +// file as some of these mainframes map the file to a message in a queue. +// Once the file has been read it will get deleted. +func UseFstat(value bool) ClientOption { + return func(c *Client) error { + c.useFstat = value + return nil + } +} + +// Client represents an SFTP session on a *ssh.ClientConn SSH connection. +// Multiple Clients can be active on a single SSH connection, and a Client +// may be called concurrently from multiple Goroutines. 
+// +// Client implements the github.com/kr/fs.FileSystem interface. +type Client struct { + clientConn + + ext map[string]string // Extensions (name -> data). + + maxPacket int // max packet size read or written. + maxConcurrentRequests int + nextid uint32 + + // write concurrency is… error prone. + // Default behavior should be to not use it. + useConcurrentWrites bool + useFstat bool +} + // NewClient creates a new SFTP client on conn, using zero or more option // functions. func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) { @@ -161,10 +197,14 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie maxPacket: 1 << 15, maxConcurrentRequests: 64, } - if err := sftp.applyOptions(opts...); err != nil { - wr.Close() - return nil, err + + for _, opt := range opts { + if err := opt(sftp); err != nil { + wr.Close() + return nil, err + } } + if err := sftp.sendInit(); err != nil { wr.Close() return nil, err @@ -173,25 +213,11 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie wr.Close() return nil, err } + sftp.clientConn.wg.Add(1) go sftp.loop() - return sftp, nil -} -// Client represents an SFTP session on a *ssh.ClientConn SSH connection. -// Multiple Clients can be active on a single SSH connection, and a Client -// may be called concurrently from multiple Goroutines. -// -// Client implements the github.com/kr/fs.FileSystem interface. -type Client struct { - clientConn - - ext map[string]string // Extensions (name -> data). - - maxPacket int // max packet size read or written. - maxConcurrentRequests int - nextid uint32 - useFstat bool + return sftp, nil } // Create creates the named file mode 0666 (before umask), truncating it if it @@ -209,7 +235,7 @@ func (c *Client) Create(path string) (*File, error) { const sftpProtocolVersion = 3 // http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 func (c *Client) sendInit() error { - return c.clientConn.conn.sendPacket(sshFxInitPacket{ + return c.clientConn.conn.sendPacket(&sshFxInitPacket{ Version: sftpProtocolVersion, // http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 }) } @@ -271,7 +297,7 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) { var done = false for !done { id := c.nextID() - typ, data, err1 := c.sendPacket(sshFxpReaddirPacket{ + typ, data, err1 := c.sendPacket(nil, &sshFxpReaddirPacket{ ID: id, Handle: handle, }) @@ -314,7 +340,7 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) { func (c *Client) opendir(path string) (string, error) { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpOpendirPacket{ + typ, data, err := c.sendPacket(nil, &sshFxpOpendirPacket{ ID: id, Path: path, }) @@ -340,7 +366,7 @@ func (c *Client) opendir(path string) (string, error) { // If 'p' is a symbolic link, the returned FileInfo structure describes the referent file. func (c *Client) Stat(p string) (os.FileInfo, error) { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpStatPacket{ + typ, data, err := c.sendPacket(nil, &sshFxpStatPacket{ ID: id, Path: p, }) @@ -366,7 +392,7 @@ func (c *Client) Stat(p string) (os.FileInfo, error) { // If 'p' is a symbolic link, the returned FileInfo structure describes the symbolic link. 
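To make the UseConcurrentWrites guidance above concrete, here is a minimal usage sketch (not part of this patch): it enables the option, uploads with io.Copy, and on failure truncates the remote file back to the File's current offset, which the ReadFrom changes later in this diff leave at the last safely written position. The helper name, paths, and the Seek-based recovery are illustrative assumptions.

```go
package main

import (
	"io"
	"os"

	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
)

// uploadWithTruncate copies src to remotePath using concurrent writes and,
// if the copy fails part way through, truncates the remote file back to a
// safe length so no "holes" are left behind.
func uploadWithTruncate(conn *ssh.Client, remotePath string, src *os.File) error {
	client, err := sftp.NewClient(conn, sftp.UseConcurrentWrites(true))
	if err != nil {
		return err
	}
	defer client.Close()

	dst, err := client.Create(remotePath)
	if err != nil {
		return err
	}
	defer dst.Close()

	if _, err := io.Copy(dst, src); err != nil {
		// Assumption: the File offset is left at the last safe position,
		// as the ReadFrom comments describe; cut the file back to it.
		if off, serr := dst.Seek(0, io.SeekCurrent); serr == nil {
			dst.Truncate(off)
		}
		return err
	}
	return nil
}
```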
func (c *Client) Lstat(p string) (os.FileInfo, error) { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpLstatPacket{ + typ, data, err := c.sendPacket(nil, &sshFxpLstatPacket{ ID: id, Path: p, }) @@ -391,7 +417,7 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) { // ReadLink reads the target of a symbolic link. func (c *Client) ReadLink(p string) (string, error) { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpReadlinkPacket{ + typ, data, err := c.sendPacket(nil, &sshFxpReadlinkPacket{ ID: id, Path: p, }) @@ -420,7 +446,7 @@ func (c *Client) ReadLink(p string) (string, error) { // Link creates a hard link at 'newname', pointing at the same inode as 'oldname' func (c *Client) Link(oldname, newname string) error { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpHardlinkPacket{ + typ, data, err := c.sendPacket(nil, &sshFxpHardlinkPacket{ ID: id, Oldpath: oldname, Newpath: newname, @@ -439,7 +465,7 @@ func (c *Client) Link(oldname, newname string) error { // Symlink creates a symbolic link at 'newname', pointing at target 'oldname' func (c *Client) Symlink(oldname, newname string) error { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpSymlinkPacket{ + typ, data, err := c.sendPacket(nil, &sshFxpSymlinkPacket{ ID: id, Linkpath: newname, Targetpath: oldname, @@ -457,7 +483,7 @@ func (c *Client) Symlink(oldname, newname string) error { func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpFsetstatPacket{ + typ, data, err := c.sendPacket(nil, &sshFxpFsetstatPacket{ ID: id, Handle: handle, Flags: flags, @@ -477,7 +503,7 @@ func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error // setstat is a convience wrapper to allow for changing of various parts of the file descriptor. func (c *Client) setstat(path string, flags uint32, attrs interface{}) error { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpSetstatPacket{ + typ, data, err := c.sendPacket(nil, &sshFxpSetstatPacket{ ID: id, Path: path, Flags: flags, @@ -543,7 +569,7 @@ func (c *Client) OpenFile(path string, f int) (*File, error) { func (c *Client) open(path string, pflags uint32) (*File, error) { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpOpenPacket{ + typ, data, err := c.sendPacket(nil, &sshFxpOpenPacket{ ID: id, Path: path, Pflags: pflags, @@ -571,7 +597,7 @@ func (c *Client) open(path string, pflags uint32) (*File, error) { // immediately after this request has been sent. 
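As a quick illustration of the generalized open call shown above (OpenFile translates standard os package flags into SFTP pflags), a hedged sketch assuming the same imports as the earlier example; the helper name and arguments are assumptions.

```go
// overwriteRemote opens (creating or truncating) a remote file for writing,
// using os package flags which OpenFile maps to SFTP open flags.
func overwriteRemote(client *sftp.Client, path string, data []byte) error {
	f, err := client.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC)
	if err != nil {
		return err
	}
	defer f.Close()

	_, err = f.Write(data)
	return err
}
```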
func (c *Client) close(handle string) error { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpClosePacket{ + typ, data, err := c.sendPacket(nil, &sshFxpClosePacket{ ID: id, Handle: handle, }) @@ -588,7 +614,7 @@ func (c *Client) close(handle string) error { func (c *Client) fstat(handle string) (*FileStat, error) { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpFstatPacket{ + typ, data, err := c.sendPacket(nil, &sshFxpFstatPacket{ ID: id, Handle: handle, }) @@ -617,7 +643,7 @@ func (c *Client) fstat(handle string) (*FileStat, error) { func (c *Client) StatVFS(path string) (*StatVFS, error) { // send the StatVFS packet to the server id := c.nextID() - typ, data, err := c.sendPacket(sshFxpStatvfsPacket{ + typ, data, err := c.sendPacket(nil, &sshFxpStatvfsPacket{ ID: id, Path: path, }) @@ -672,7 +698,7 @@ func (c *Client) Remove(path string) error { func (c *Client) removeFile(path string) error { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpRemovePacket{ + typ, data, err := c.sendPacket(nil, &sshFxpRemovePacket{ ID: id, Filename: path, }) @@ -690,7 +716,7 @@ func (c *Client) removeFile(path string) error { // RemoveDirectory removes a directory path. func (c *Client) RemoveDirectory(path string) error { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpRmdirPacket{ + typ, data, err := c.sendPacket(nil, &sshFxpRmdirPacket{ ID: id, Path: path, }) @@ -708,7 +734,7 @@ func (c *Client) RemoveDirectory(path string) error { // Rename renames a file. func (c *Client) Rename(oldname, newname string) error { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpRenamePacket{ + typ, data, err := c.sendPacket(nil, &sshFxpRenamePacket{ ID: id, Oldpath: oldname, Newpath: newname, @@ -728,7 +754,7 @@ func (c *Client) Rename(oldname, newname string) error { // which will replace newname if it already exists. func (c *Client) PosixRename(oldname, newname string) error { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpPosixRenamePacket{ + typ, data, err := c.sendPacket(nil, &sshFxpPosixRenamePacket{ ID: id, Oldpath: oldname, Newpath: newname, @@ -746,7 +772,7 @@ func (c *Client) PosixRename(oldname, newname string) error { func (c *Client) realpath(path string) (string, error) { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpRealpathPacket{ + typ, data, err := c.sendPacket(nil, &sshFxpRealpathPacket{ ID: id, Path: path, }) @@ -783,7 +809,7 @@ func (c *Client) Getwd() (string, error) { // parent folder does not exist (the method cannot create complete paths). func (c *Client) Mkdir(path string) error { id := c.nextID() - typ, data, err := c.sendPacket(sshFxpMkdirPacket{ + typ, data, err := c.sendPacket(nil, &sshFxpMkdirPacket{ ID: id, Path: path, }) @@ -846,17 +872,6 @@ func (c *Client) MkdirAll(path string) error { return nil } -// applyOptions applies options functions to the Client. -// If an error is encountered, option processing ceases. -func (c *Client) applyOptions(opts ...ClientOption) error { - for _, f := range opts { - if err := f(c); err != nil { - return err - } - } - return nil -} - // File represents a remote file. type File struct { c *Client @@ -864,7 +879,7 @@ type File struct { handle string mu sync.Mutex - offset uint64 // current offset within remote file + offset int64 // current offset within remote file } // Close closes the File, rendering it unusable for I/O. 
It returns an @@ -891,113 +906,199 @@ func (f *File) Read(b []byte) (int, error) { f.mu.Lock() defer f.mu.Unlock() - r, err := f.ReadAt(b, int64(f.offset)) - f.offset += uint64(r) - return r, err + n, err := f.ReadAt(b, f.offset) + f.offset += int64(n) + return n, err } -// ReadAt reads up to len(b) byte from the File at a given offset `off`. It returns -// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics, -// so the file offset is not altered during the read. -func (f *File) ReadAt(b []byte, off int64) (n int, err error) { - // Split the read into multiple maxPacket sized concurrent reads - // bounded by maxConcurrentRequests. This allows reads with a suitably - // large buffer to transfer data at a much faster rate due to - // overlapping round trip times. - inFlight := 0 - desiredInFlight := 1 - offset := uint64(off) - // maxConcurrentRequests buffer to deal with broadcastErr() floods - // also must have a buffer of max value of (desiredInFlight - inFlight) - ch := make(chan result, f.c.maxConcurrentRequests+1) - type inflightRead struct { - b []byte - offset uint64 - } - reqs := map[uint32]inflightRead{} - type offsetErr struct { - offset uint64 - err error - } - var firstErr offsetErr - - sendReq := func(b []byte, offset uint64) { - reqID := f.c.nextID() - f.c.dispatchRequest(ch, sshFxpReadPacket{ - ID: reqID, +// readChunkAt attempts to read the whole entire length of the buffer from the file starting at the offset. +// It will continue progressively reading into the buffer until it fills the whole buffer, or an error occurs. +func (f *File) readChunkAt(ch chan result, b []byte, off int64) (n int, err error) { + for err == nil && n < len(b) { + id := f.c.nextID() + typ, data, err := f.c.sendPacket(ch, &sshFxpReadPacket{ + ID: id, Handle: f.handle, - Offset: offset, - Len: uint32(len(b)), + Offset: uint64(off) + uint64(n), + Len: uint32(len(b) - n), }) - inFlight++ - reqs[reqID] = inflightRead{b: b, offset: offset} - } - - var read int - for len(b) > 0 || inFlight > 0 { - for inFlight < desiredInFlight && len(b) > 0 && firstErr.err == nil { - l := min(len(b), f.c.maxPacket) - rb := b[:l] - sendReq(rb, offset) - offset += uint64(l) - b = b[l:] + if err != nil { + return n, err } - if inFlight == 0 { - break - } - res := <-ch - inFlight-- - if res.err != nil { - firstErr = offsetErr{offset: 0, err: res.err} - continue - } - reqID, data := unmarshalUint32(res.data) - req, ok := reqs[reqID] - if !ok { - firstErr = offsetErr{offset: 0, err: errors.Errorf("sid: %v not found", reqID)} - continue - } - delete(reqs, reqID) - switch res.typ { + switch typ { case sshFxpStatus: - if firstErr.err == nil || req.offset < firstErr.offset { - firstErr = offsetErr{ - offset: req.offset, - err: normaliseError(unmarshalStatus(reqID, res.data)), - } - } + return n, normaliseError(unmarshalStatus(id, data)) + case sshFxpData: + sid, data := unmarshalUint32(data) + if id != sid { + return n, &unexpectedIDErr{id, sid} + } + l, data := unmarshalUint32(data) - n := copy(req.b, data[:l]) - read += n - if n < len(req.b) { - sendReq(req.b[l:], req.offset+uint64(l)) + n += copy(b[n:], data[:l]) + + default: + return n, unimplementedPacketErr(typ) + } + } + + return +} + +// ReadAt reads up to len(b) byte from the File at a given offset `off`. It returns +// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics, +// so the file offset is not altered during the read. 
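Because the rewritten ReadAt below keeps io.ReaderAt semantics (the File offset is untouched, and a buffer larger than the max packet size is split into concurrent chunk reads), a *File composes directly with the standard library. A small illustrative sketch, with the helper name and range as assumptions:

```go
// readRange copies length bytes starting at off from an open remote file
// into w, without disturbing the File's own offset.
func readRange(f *sftp.File, w io.Writer, off, length int64) (int64, error) {
	// io.NewSectionReader needs only an io.ReaderAt, which *sftp.File provides.
	return io.Copy(w, io.NewSectionReader(f, off, length))
}
```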
+func (f *File) ReadAt(b []byte, off int64) (int, error) { + if len(b) <= f.c.maxPacket { + // This should be able to be serviced with 1/2 requests. + // So, just do it directly. + return f.readChunkAt(nil, b, off) + } + + // Split the read into multiple maxPacket-sized concurrent reads bounded by maxConcurrentRequests. + // This allows writes with a suitably large buffer to transfer data at a much faster rate + // by overlapping round trip times. + + cancel := make(chan struct{}) + + type work struct { + b []byte + off int64 + } + workCh := make(chan work) + + // Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets. + go func() { + defer close(workCh) + + b := b + offset := off + chunkSize := f.c.maxPacket + + for len(b) > 0 { + rb := b + if len(rb) > chunkSize { + rb = rb[:chunkSize] + } + + select { + case workCh <- work{rb, offset}: + case <-cancel: + return } - if desiredInFlight < f.c.maxConcurrentRequests { - desiredInFlight++ + + offset += int64(len(rb)) + b = b[len(rb):] + } + }() + + type rErr struct { + off int64 + err error + } + errCh := make(chan rErr) + + concurrency := len(b)/f.c.maxPacket + 1 + if concurrency > f.c.maxConcurrentRequests { + concurrency = f.c.maxConcurrentRequests + } + + var wg sync.WaitGroup + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + // Map_i: each worker gets work, and then performs the Read into its buffer from its respective offset. + go func() { + defer wg.Done() + + ch := make(chan result, 1) // reusable channel per mapper. + + for packet := range workCh { + n, err := f.readChunkAt(ch, packet.b, packet.off) + if err != nil { + // return the offset as the start + how much we read before the error. + errCh <- rErr{packet.off + int64(n), err} + return + } } + }() + } + + // Wait for long tail, before closing results. + go func() { + wg.Wait() + close(errCh) + }() + + // Reduce: collect all the results into a relevant return: the earliest offset to return an error. + firstErr := rErr{math.MaxInt64, nil} + for rErr := range errCh { + if rErr.off <= firstErr.off { + firstErr = rErr + } + + select { + case <-cancel: default: - firstErr = offsetErr{offset: 0, err: unimplementedPacketErr(res.typ)} + // stop any more work from being distributed. (Just in case.) + close(cancel) } } - // If the error is anything other than EOF, then there - // may be gaps in the data copied to the buffer so it's - // best to return 0 so the caller can't make any - // incorrect assumptions about the state of the buffer. - if firstErr.err != nil && firstErr.err != io.EOF { - read = 0 + + if firstErr.err != nil { + // firstErr.err != nil if and only if firstErr.off > our starting offset. + return int(firstErr.off - off), firstErr.err + } + + // As per spec for io.ReaderAt, we return nil error if and only if we read everything. + return len(b), nil +} + +// writeToSequential implements WriteTo, but works sequentially with no parallelism. +func (f *File) writeToSequential(w io.Writer) (written int64, err error) { + b := make([]byte, f.c.maxPacket) + ch := make(chan result, 1) // reusable channel + + for { + n, err := f.readChunkAt(ch, b, f.offset) + if n < 0 { + panic("sftp.File: returned negative count from readChunkAt") + } + + if n > 0 { + f.offset += int64(n) + + m, err2 := w.Write(b[:n]) + written += int64(m) + + if err == nil { + err = err2 + } + } + + if err != nil { + if err == io.EOF { + return written, nil // return nil explicitly. 
+ } + + return written, err + } } - return read, firstErr.err } -// WriteTo writes the file to w. The return value is the number of bytes -// written. Any error encountered during the write is also returned. +// WriteTo writes the file to the given Writer. +// The return value is the number of bytes written. +// Any error encountered during the write is also returned. // -// This method is preferred over calling Read multiple times to -// maximise throughput for transferring the entire file (especially -// over high latency links). -func (f *File) WriteTo(w io.Writer) (int64, error) { +// This method is preferred over calling Read multiple times +// to maximise throughput for transferring the entire file, +// especially over high latency links. +func (f *File) WriteTo(w io.Writer) (written int64, err error) { + f.mu.Lock() + defer f.mu.Unlock() + + // For concurrency, we want to guess how many concurrent workers we should use. var fileSize uint64 if f.c.useFstat { fileStat, err := f.c.fstat(f.handle) @@ -1013,130 +1114,144 @@ func (f *File) WriteTo(w io.Writer) (int64, error) { fileSize = uint64(fi.Size()) } - inFlight := 0 - desiredInFlight := 1 - offset := f.offset - writeOffset := offset - // see comment on same line in Read() above - ch := make(chan result, f.c.maxConcurrentRequests+1) - type inflightRead struct { - b []byte - offset uint64 - } - reqs := map[uint32]inflightRead{} - pendingWrites := map[uint64][]byte{} - type offsetErr struct { - offset uint64 - err error - } - var firstErr offsetErr - - sendReq := func(b []byte, offset uint64) { - reqID := f.c.nextID() - f.c.dispatchRequest(ch, sshFxpReadPacket{ - ID: reqID, - Handle: f.handle, - Offset: offset, - Len: uint32(len(b)), - }) - inFlight++ - reqs[reqID] = inflightRead{b: b, offset: offset} - } - - var copied int64 - for firstErr.err == nil || inFlight > 0 { - if firstErr.err == nil { - for inFlight+len(pendingWrites) < desiredInFlight { - b := make([]byte, f.c.maxPacket) - sendReq(b, offset) - offset += uint64(f.c.maxPacket) - if offset > fileSize { - desiredInFlight = 1 - } + if fileSize <= uint64(f.c.maxPacket) { + // We should be able to handle this in one Read. + return f.writeToSequential(w) + } + + concurrency := int(fileSize/uint64(f.c.maxPacket) + 1) // a bad guess, but better than no guess + if concurrency > f.c.maxConcurrentRequests { + concurrency = f.c.maxConcurrentRequests + } + + cancel := make(chan struct{}) + var wg sync.WaitGroup + defer func() { + // Once the writing Reduce phase has ended, all the feed work needs to unconditionally stop. + close(cancel) + + // We want to wait until all outstanding goroutines with an `f` or `f.c` reference have completed. + // Just to be sure we don’t orphan any goroutines any hanging references. + wg.Wait() + }() + + type writeWork struct { + b []byte + n int + off int64 + err error + + next chan writeWork + } + writeCh := make(chan writeWork) + + type readWork struct { + off int64 + cur, next chan writeWork + } + readCh := make(chan readWork) + + // Slice: hand out chunks of work on demand, with a `cur` and `next` channel built-in for sequencing. 
+ go func() { + defer close(readCh) + + off := f.offset + chunkSize := int64(f.c.maxPacket) + + cur := writeCh + for { + next := make(chan writeWork) + readWork := readWork{ + off: off, + cur: cur, + next: next, } - } - if inFlight == 0 { - if firstErr.err == nil && len(pendingWrites) > 0 { - return copied, ErrInternalInconsistency + select { + case readCh <- readWork: + case <-cancel: + return } - break - } - res := <-ch - inFlight-- - if res.err != nil { - firstErr = offsetErr{offset: 0, err: res.err} - continue - } - reqID, data := unmarshalUint32(res.data) - req, ok := reqs[reqID] - if !ok { - firstErr = offsetErr{offset: 0, err: errors.Errorf("sid: %v not found", reqID)} - continue + + off += chunkSize + cur = next } - delete(reqs, reqID) - switch res.typ { - case sshFxpStatus: - if firstErr.err == nil || req.offset < firstErr.offset { - firstErr = offsetErr{offset: req.offset, err: normaliseError(unmarshalStatus(reqID, res.data))} - } - case sshFxpData: - l, data := unmarshalUint32(data) - if req.offset == writeOffset { - nbytes, err := w.Write(data) - copied += int64(nbytes) - if err != nil { - // We will never receive another DATA with offset==writeOffset, so - // the loop will drain inFlight and then exit. - firstErr = offsetErr{offset: req.offset + uint64(nbytes), err: err} - break + }() + + pool := sync.Pool{ + New: func() interface{} { + return make([]byte, f.c.maxPacket) + }, + } + + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + // Map_i: each worker gets readWork, and does the Read into a buffer at the given offset. + go func() { + defer wg.Done() + + ch := make(chan result, 1) // reusable channel + + for readWork := range readCh { + b := pool.Get().([]byte) + + n, err := f.readChunkAt(ch, b, readWork.off) + if n < 0 { + panic("sftp.File: returned negative count from readChunkAt") } - if nbytes < int(l) { - firstErr = offsetErr{offset: req.offset + uint64(nbytes), err: io.ErrShortWrite} - break + + writeWork := writeWork{ + b: b, + n: n, + off: readWork.off, + err: err, + + next: readWork.next, } - switch { - case offset > fileSize: - desiredInFlight = 1 - case desiredInFlight < f.c.maxConcurrentRequests: - desiredInFlight++ + + select { + case readWork.cur <- writeWork: + case <-cancel: + return } - writeOffset += uint64(nbytes) - for { - pendingData, ok := pendingWrites[writeOffset] - if !ok { - break - } - // Give go a chance to free the memory. - delete(pendingWrites, writeOffset) - nbytes, err := w.Write(pendingData) - // Do not move writeOffset on error so subsequent iterations won't trigger - // any writes. - if err != nil { - firstErr = offsetErr{offset: writeOffset + uint64(nbytes), err: err} - break - } - if nbytes < len(pendingData) { - firstErr = offsetErr{offset: writeOffset + uint64(nbytes), err: io.ErrShortWrite} - break - } - writeOffset += uint64(nbytes) + + if err != nil { + return } - } else { - // Don't write the data yet because - // this response came in out of order - // and we need to wait for responses - // for earlier segments of the file. - pendingWrites[req.offset] = data } - default: - firstErr = offsetErr{offset: 0, err: unimplementedPacketErr(res.typ)} - } + }() } - if firstErr.err != io.EOF { - return copied, firstErr.err + + // Reduce: serialize the results from the reads into sequential writes. + cur := writeCh + for { + packet, ok := <-cur + if !ok { + return written, nil + } + + // Because writes are serialized, this will always be the last successfully read byte. 
+ f.offset = packet.off + int64(packet.n) + + if packet.n > 0 { + n, err := w.Write(packet.b[:packet.n]) + written += int64(n) + if err != nil { + return written, err + } + } + + if packet.err != nil { + if packet.err == io.EOF { + return written, nil + } + + return written, packet.err + } + + pool.Put(packet.b) + cur = packet.next } - return copied, nil } // Stat returns the FileInfo structure describing file. If there is an @@ -1158,157 +1273,390 @@ func (f *File) Stat() (os.FileInfo, error) { // than calling Write multiple times. io.Copy will do this // automatically. func (f *File) Write(b []byte) (int, error) { + f.mu.Lock() + defer f.mu.Unlock() + + n, err := f.WriteAt(b, f.offset) + f.offset += int64(n) + return n, err +} + +func (f *File) writeChunkAt(ch chan result, b []byte, off int64) (int, error) { + typ, data, err := f.c.sendPacket(ch, &sshFxpWritePacket{ + ID: f.c.nextID(), + Handle: f.handle, + Offset: uint64(off), + Length: uint32(len(b)), + Data: b, + }) + if err != nil { + return 0, err + } + + switch typ { + case sshFxpStatus: + id, _ := unmarshalUint32(data) + err := normaliseError(unmarshalStatus(id, data)) + if err != nil { + return 0, err + } + + default: + return 0, unimplementedPacketErr(typ) + } + + return len(b), nil +} + +// writeAtConcurrent implements WriterAt, but works concurrently rather than sequentially. +func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) { // Split the write into multiple maxPacket sized concurrent writes // bounded by maxConcurrentRequests. This allows writes with a suitably // large buffer to transfer data at a much faster rate due to // overlapping round trip times. - inFlight := 0 - desiredInFlight := 1 - offset := f.offset - // see comment on same line in Read() above - ch := make(chan result, f.c.maxConcurrentRequests+1) - var firstErr error - written := len(b) - for len(b) > 0 || inFlight > 0 { - for inFlight < desiredInFlight && len(b) > 0 && firstErr == nil { - l := min(len(b), f.c.maxPacket) - rb := b[:l] - f.c.dispatchRequest(ch, sshFxpWritePacket{ - ID: f.c.nextID(), - Handle: f.handle, - Offset: offset, - Length: uint32(len(rb)), - Data: rb, - }) - inFlight++ - offset += uint64(l) - b = b[l:] + + cancel := make(chan struct{}) + + type work struct { + b []byte + off int64 + } + workCh := make(chan work) + + // Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets. + go func() { + defer close(workCh) + + var read int + chunkSize := f.c.maxPacket + + for read < len(b) { + wb := b[read:] + if len(wb) > chunkSize { + wb = wb[:chunkSize] + } + + select { + case workCh <- work{wb, off + int64(read)}: + case <-cancel: + return + } + + read += len(wb) } + }() - if inFlight == 0 { - break + type wErr struct { + off int64 + err error + } + errCh := make(chan wErr) + + concurrency := len(b)/f.c.maxPacket + 1 + if concurrency > f.c.maxConcurrentRequests { + concurrency = f.c.maxConcurrentRequests + } + + var wg sync.WaitGroup + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + // Map_i: each worker gets work, and does the Write from each buffer to its respective offset. + go func() { + defer wg.Done() + + ch := make(chan result, 1) // reusable channel per mapper. + + for packet := range workCh { + n, err := f.writeChunkAt(ch, packet.b, packet.off) + if err != nil { + // return the offset as the start + how much we wrote before the error. + errCh <- wErr{packet.off + int64(n), err} + } + } + }() + } + + // Wait for long tail, before closing results. 
+ go func() { + wg.Wait() + close(errCh) + }() + + // Reduce: collect all the results into a relevant return: the earliest offset to return an error. + firstErr := wErr{math.MaxInt64, nil} + for wErr := range errCh { + if wErr.off <= firstErr.off { + firstErr = wErr } - res := <-ch - inFlight-- - if res.err != nil { - firstErr = res.err - continue + + select { + case <-cancel: + default: + // stop any more work from being distributed. (Just in case.) + close(cancel) } - switch res.typ { - case sshFxpStatus: - id, _ := unmarshalUint32(res.data) - err := normaliseError(unmarshalStatus(id, res.data)) - if err != nil && firstErr == nil { - firstErr = err - break + } + + if firstErr.err != nil { + // firstErr.err != nil if and only if firstErr.off >= our starting offset. + return int(firstErr.off - off), firstErr.err + } + + return len(b), nil +} + +// WriteAt writess up to len(b) byte to the File at a given offset `off`. It returns +// the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics, +// so the file offset is not altered during the write. +func (f *File) WriteAt(b []byte, off int64) (written int, err error) { + if len(b) <= f.c.maxPacket { + // We can do this in one write. + return f.writeChunkAt(nil, b, off) + } + + if f.c.useConcurrentWrites { + return f.writeAtConcurrent(b, off) + } + + ch := make(chan result, 1) // reusable channel + + chunkSize := f.c.maxPacket + + for written < len(b) { + wb := b[written:] + if len(wb) > chunkSize { + wb = wb[:chunkSize] + } + + n, err := f.writeChunkAt(ch, wb, off+int64(written)) + if n > 0 { + written += n + } + + if err != nil { + return written, err + } + } + + return len(b), nil +} + +// readFromConcurrent implements ReaderFrom, but works concurrently rather than sequentially. +func (f *File) readFromConcurrent(r io.Reader, remain int64) (read int64, err error) { + // Split the write into multiple maxPacket sized concurrent writes. + // This allows writes with a suitably large reader + // to transfer data at a much faster rate due to overlapping round trip times. + + cancel := make(chan struct{}) + + type work struct { + b []byte + n int + off int64 + } + workCh := make(chan work) + + type rwErr struct { + off int64 + err error + } + errCh := make(chan rwErr) + + pool := sync.Pool{ + New: func() interface{} { + return make([]byte, f.c.maxPacket) + }, + } + + // Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets. + go func() { + defer close(workCh) + + off := f.offset + + for { + b := pool.Get().([]byte) + + n, err := r.Read(b) + if n > 0 { + read += int64(n) + + select { + case workCh <- work{b, n, off}: + // We need the pool.Put(b) to put the whole slice, not just trunced. + case <-cancel: + return + } + + off += int64(n) } - if desiredInFlight < f.c.maxConcurrentRequests { - desiredInFlight++ + + if err != nil { + if err != io.EOF { + errCh <- rwErr{off, err} + } + return } + } + }() + + concurrency := int(remain/int64(f.c.maxPacket) + 1) // a bad guess, but better than no guess + if concurrency > f.c.maxConcurrentRequests { + concurrency = f.c.maxConcurrentRequests + } + + var wg sync.WaitGroup + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + // Map_i: each worker gets work, and does the Write from each buffer to its respective offset. + go func() { + defer wg.Done() + + ch := make(chan result, 1) // reusable channel per mapper. 
+ + for packet := range workCh { + n, err := f.writeChunkAt(ch, packet.b[:packet.n], packet.off) + if err != nil { + // return the offset as the start + how much we wrote before the error. + errCh <- rwErr{packet.off + int64(n), err} + } + pool.Put(packet.b) + } + }() + } + + // Wait for long tail, before closing results. + go func() { + wg.Wait() + close(errCh) + }() + + // Reduce: Collect all the results into a relevant return: the earliest offset to return an error. + firstErr := rwErr{math.MaxInt64, nil} + for rwErr := range errCh { + if rwErr.off <= firstErr.off { + firstErr = rwErr + } + + select { + case <-cancel: default: - firstErr = unimplementedPacketErr(res.typ) + // stop any more work from being distributed. + close(cancel) } } - // If error is non-nil, then there may be gaps in the data written to - // the file so it's best to return 0 so the caller can't make any - // incorrect assumptions about the state of the file. - if firstErr != nil { - written = 0 + + if firstErr.err != nil { + // firstErr.err != nil if and only if firstErr.off is a valid offset. + // + // firstErr.off will then be the lesser of: + // * the offset of the first error from writing, + // * the last successfully read offset. + // + // This could be less than the last succesfully written offset, + // which is the whole reason for the UseConcurrentWrites() ClientOption. + // + // Callers are responsible for truncating any SFTP files to a safe length. + f.offset = firstErr.off + + // ReadFrom is defined to return the read bytes, regardless of any writer errors. + return read, firstErr.err } - f.offset += uint64(written) - return written, firstErr + + f.offset += read + return read, nil } // ReadFrom reads data from r until EOF and writes it to the file. The return // value is the number of bytes read. Any error except io.EOF encountered // during the read is also returned. // -// This method is preferred over calling Write multiple times to -// maximise throughput for transferring the entire file (especially -// over high latency links). +// This method is preferred over calling Write multiple times +// to maximise throughput for transferring the entire file, +// especially over high-latency links. func (f *File) ReadFrom(r io.Reader) (int64, error) { - inFlight := 0 - desiredInFlight := 1 - offset := f.offset - // see comment on same line in Read() above - ch := make(chan result, f.c.maxConcurrentRequests+1) - var firstErr error - read := int64(0) - b := make([]byte, f.c.maxPacket) - for inFlight > 0 || firstErr == nil { - for inFlight < desiredInFlight && firstErr == nil { - n, err := r.Read(b) - if err != nil { - firstErr = err - } - f.c.dispatchRequest(ch, sshFxpWritePacket{ - ID: f.c.nextID(), - Handle: f.handle, - Offset: offset, - Length: uint32(n), - Data: b[:n], - }) - inFlight++ - offset += uint64(n) - read += int64(n) + f.mu.Lock() + defer f.mu.Unlock() + + if f.c.useConcurrentWrites { + var remain int64 + switch r := r.(type) { + case interface{ Len() int }: + remain = int64(r.Len()) + + case *io.LimitedReader: + remain = r.N + + case *os.File: + // For files, always presume max concurrency. + remain = math.MaxInt64 } - if inFlight == 0 { - break + if remain > int64(f.c.maxPacket) { + // Only use concurrency, if it would be at least two read/writes. 
+ return f.readFromConcurrent(r, remain) } - res := <-ch - inFlight-- - if res.err != nil { - firstErr = res.err - continue + } + + ch := make(chan result, 1) // reusable channel + + b := make([]byte, f.c.maxPacket) + + var read int64 + for { + n, err := r.Read(b) + if n < 0 { + panic("sftp.File: reader returned negative count from Read") } - switch res.typ { - case sshFxpStatus: - id, _ := unmarshalUint32(res.data) - err := normaliseError(unmarshalStatus(id, res.data)) - if err != nil && firstErr == nil { - firstErr = err - break + + if n > 0 { + read += int64(n) + + m, err2 := f.writeChunkAt(ch, b[:n], f.offset) + f.offset += int64(m) + + if err == nil { + err = err2 } - if desiredInFlight < f.c.maxConcurrentRequests { - desiredInFlight++ + } + + if err != nil { + if err == io.EOF { + return read, nil // return nil explicitly. } - default: - firstErr = unimplementedPacketErr(res.typ) + + return read, err } } - if firstErr == io.EOF { - firstErr = nil - } - // If error is non-nil, then there may be gaps in the data written to - // the file so it's best to return 0 so the caller can't make any - // incorrect assumptions about the state of the file. - if firstErr != nil { - read = 0 - } - f.offset += uint64(read) - return read, firstErr } // Seek implements io.Seeker by setting the client offset for the next Read or // Write. It returns the next offset read. Seeking before or after the end of // the file is undefined. Seeking relative to the end calls Stat. func (f *File) Seek(offset int64, whence int) (int64, error) { + f.mu.Lock() + defer f.mu.Unlock() + switch whence { case io.SeekStart: - f.offset = uint64(offset) case io.SeekCurrent: - f.offset = uint64(int64(f.offset) + offset) + offset += f.offset case io.SeekEnd: fi, err := f.Stat() if err != nil { - return int64(f.offset), err + return f.offset, err } - f.offset = uint64(fi.Size() + offset) + offset += fi.Size() default: - return int64(f.offset), unimplementedSeekWhence(whence) + return f.offset, unimplementedSeekWhence(whence) } - return int64(f.offset), nil + + if offset < 0 { + return f.offset, os.ErrInvalid + } + + f.offset = offset + return f.offset, nil } // Chown changes the uid/gid of the current file. @@ -1326,7 +1674,7 @@ func (f *File) Chmod(mode os.FileMode) error { // Sync requires the server to support the fsync@openssh.com extension. func (f *File) Sync() error { id := f.c.nextID() - typ, data, err := f.c.sendPacket(sshFxpFsyncPacket{ + typ, data, err := f.c.sendPacket(nil, &sshFxpFsyncPacket{ ID: id, Handle: f.handle, }) diff --git a/client_integration_test.go b/client_integration_test.go index caf15cbe..fcb41c3b 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -21,6 +21,7 @@ import ( "regexp" "sort" "strconv" + "sync" "testing" "testing/quick" "time" @@ -48,41 +49,71 @@ type delayedWrite struct { // underlying writer will panic so this should only be used over reliable // connections. 
type delayedWriter struct { - w io.WriteCloser - ch chan delayedWrite closed chan struct{} + + mu sync.Mutex + ch chan delayedWrite + closing chan struct{} } func newDelayedWriter(w io.WriteCloser, delay time.Duration) io.WriteCloser { - ch := make(chan delayedWrite, 128) - closed := make(chan struct{}) + dw := &delayedWriter{ + ch: make(chan delayedWrite, 128), + closed: make(chan struct{}), + closing: make(chan struct{}), + } + go func() { - for writeMsg := range ch { + defer close(dw.closed) + defer w.Close() + + for writeMsg := range dw.ch { time.Sleep(time.Until(writeMsg.t.Add(delay))) + n, err := w.Write(writeMsg.b) if err != nil { panic("write error") } + if n < len(writeMsg.b) { panic("showrt write") } } - w.Close() - close(closed) }() - return delayedWriter{w: w, ch: ch, closed: closed} + + return dw } -func (w delayedWriter) Write(b []byte) (int, error) { - bcopy := make([]byte, len(b)) - copy(bcopy, b) - w.ch <- delayedWrite{t: time.Now(), b: bcopy} +func (dw *delayedWriter) Write(b []byte) (int, error) { + dw.mu.Lock() + defer dw.mu.Unlock() + + write := delayedWrite{ + t: time.Now(), + b: append([]byte(nil), b...), + } + + select { + case <-dw.closing: + return 0, errors.New("delayedWriter is closing") + case dw.ch <- write: + } + return len(b), nil } -func (w delayedWriter) Close() error { - close(w.ch) - <-w.closed +func (dw *delayedWriter) Close() error { + dw.mu.Lock() + defer dw.mu.Unlock() + + select { + case <-dw.closing: + default: + close(dw.ch) + close(dw.closing) + } + + <-dw.closed return nil } @@ -100,24 +131,37 @@ func netPipe(t testing.TB) (io.ReadWriteCloser, io.ReadWriteCloser) { t.Fatal(err) } + closeListener := make(chan struct{}, 1) + closeListener <- struct{}{} + ch := make(chan result, 1) go func() { conn, err := l.Accept() ch <- result{conn, err} - err = l.Close() - if err != nil { - t.Error(err) + + if _, ok := <-closeListener; ok { + err = l.Close() + if err != nil { + t.Error(err) + } + close(closeListener) } }() + c1, err := net.Dial("tcp", l.Addr().String()) if err != nil { - l.Close() // might cause another in the listening goroutine, but too bad + if _, ok := <-closeListener; ok { + l.Close() + close(closeListener) + } t.Fatal(err) } + r := <-ch if r.error != nil { t.Fatal(err) } + return c1, r.Conn } @@ -135,12 +179,12 @@ func testClientGoSvr(t testing.TB, readonly bool, delay time.Duration) (*Client, } go server.Serve() - var ctx io.WriteCloser = c2 + var wr io.WriteCloser = c2 if delay > NODELAY { - ctx = newDelayedWriter(ctx, delay) + wr = newDelayedWriter(wr, delay) } - client, err := NewClientPipe(c2, ctx) + client, err := NewClientPipe(c2, wr) if err != nil { t.Fatal(err) } @@ -164,18 +208,23 @@ func testClient(t testing.TB, readonly bool, delay time.Duration) (*Client, *exe if !readonly { cmd = exec.Command(*testSftp, "-e", "-l", debuglevel) // log to stderr } + cmd.Stderr = os.Stdout + pw, err := cmd.StdinPipe() if err != nil { t.Fatal(err) } + if delay > NODELAY { pw = newDelayedWriter(pw, delay) } + pr, err := cmd.StdoutPipe() if err != nil { t.Fatal(err) } + if err := cmd.Start(); err != nil { t.Skipf("could not start sftp-server process: %v", err) } @@ -480,6 +529,7 @@ func TestClientFileName(t *testing.T) { if err != nil { t.Fatal(err) } + defer f2.Close() if got, want := f2.Name(), f.Name(); got != want { t.Fatalf("Name: got %q want %q", want, got) @@ -506,6 +556,7 @@ func TestClientFileStat(t *testing.T) { if err != nil { t.Fatal(err) } + defer f2.Close() got, err := f2.Stat() if err != nil { @@ -1325,12 +1376,9 @@ func 
TestClientReadFromDeadlock(t *testing.T) { clientWriteDeadlock(t, 1, func(f *File) { b := make([]byte, 32768*4) content := bytes.NewReader(b) - n, err := f.ReadFrom(content) - if n != 0 { - t.Fatal("Write should return 0", n) - } + _, err := f.ReadFrom(content) if err != errFakeNet { - t.Fatal("Didn't recieve correct error", err) + t.Fatal("Didn't recieve correct error:", err) } }) } @@ -1339,12 +1387,10 @@ func TestClientReadFromDeadlock(t *testing.T) { func TestClientWriteDeadlock(t *testing.T) { clientWriteDeadlock(t, 1, func(f *File) { b := make([]byte, 32768*4) - n, err := f.Write(b) - if n != 0 { - t.Fatal("Write should return 0", n) - } + + _, err := f.Write(b) if err != errFakeNet { - t.Fatal("Didn't recieve correct error", err) + t.Fatal("Didn't recieve correct error:", err) } }) } @@ -1394,12 +1440,10 @@ func clientWriteDeadlock(t *testing.T, N int, badfunc func(*File)) { func TestClientReadDeadlock(t *testing.T) { clientReadDeadlock(t, 1, func(f *File) { b := make([]byte, 32768*4) - n, err := f.Read(b) - if n != 0 { - t.Fatal("Write should return 0", n) - } + + _, err := f.Read(b) if err != errFakeNet { - t.Fatal("Didn't recieve correct error", err) + t.Fatal("Didn't recieve correct error:", err) } }) } @@ -1407,13 +1451,12 @@ func TestClientReadDeadlock(t *testing.T) { func TestClientWriteToDeadlock(t *testing.T) { clientReadDeadlock(t, 2, func(f *File) { b := make([]byte, 32768*4) + buf := bytes.NewBuffer(b) - n, err := f.WriteTo(buf) - if n != 32768 { - t.Fatal("Write should return 0", n) - } + + _, err := f.WriteTo(buf) if err != errFakeNet { - t.Fatal("Didn't recieve correct error", err) + t.Fatal("Didn't recieve correct error:", err) } }) } @@ -1433,14 +1476,16 @@ func clientReadDeadlock(t *testing.T, N int, badfunc func(*File)) { defer os.RemoveAll(d) f := path.Join(d, "writeTest") + w, err := sftp.Create(f) if err != nil { t.Fatal(err) } + defer w.Close() + // write the data for the read tests b := make([]byte, 32768*4) w.Write(b) - defer w.Close() // open new copy of file for read tests r, err := sftp.Open(f) @@ -1459,6 +1504,7 @@ func clientReadDeadlock(t *testing.T, N int, badfunc func(*File)) { } return sendPacket(w, m) } + sftp.clientConn.conn.sendPacketTest = sendPacketTest defer func() { sftp.clientConn.conn.sendPacketTest = nil @@ -1909,26 +1955,29 @@ func TestServerRoughDisconnect3(t *testing.T) { if *testServerImpl { t.Skipf("skipping with -testserver") } + sftp, cmd := testClient(t, READWRITE, NODELAY) defer cmd.Wait() defer sftp.Close() - rf, err := sftp.OpenFile("/dev/null", os.O_RDWR) + dest, err := sftp.OpenFile("/dev/null", os.O_RDWR) if err != nil { t.Fatal(err) } - defer rf.Close() - lf, err := os.Open("/dev/zero") + defer dest.Close() + + src, err := os.Open("/dev/zero") if err != nil { t.Fatal(err) } - defer lf.Close() + defer src.Close() + go func() { time.Sleep(10 * time.Millisecond) cmd.Process.Kill() }() - _, err = io.Copy(rf, lf) + _, err = io.Copy(dest, src) assert.Error(t, err) } @@ -1942,31 +1991,34 @@ func TestServerRoughDisconnect4(t *testing.T) { defer cmd.Wait() defer sftp.Close() - rf, err := sftp.OpenFile("/dev/null", os.O_RDWR) + dest, err := sftp.OpenFile("/dev/null", os.O_RDWR) if err != nil { t.Fatal(err) } - defer rf.Close() - lf, err := os.Open("/dev/zero") + defer dest.Close() + + src, err := os.Open("/dev/zero") if err != nil { t.Fatal(err) } - defer lf.Close() + defer src.Close() + go func() { time.Sleep(10 * time.Millisecond) cmd.Process.Kill() }() + b := make([]byte, 32768*200) - lf.Read(b) + src.Read(b) for { - _, err = 
rf.Write(b) + _, err = dest.Write(b) if err != nil { assert.NotEqual(t, io.EOF, err) break } } - _, err = io.Copy(rf, lf) + _, err = io.Copy(dest, src) assert.Error(t, err) } @@ -2024,7 +2076,7 @@ func benchmarkRead(b *testing.B, bufsize int, delay time.Duration) { // open sftp client sftp, cmd := testClient(b, READONLY, delay) defer cmd.Wait() - // defer sftp.Close() + defer sftp.Close() buf := make([]byte, bufsize) @@ -2038,7 +2090,6 @@ func benchmarkRead(b *testing.B, bufsize int, delay time.Duration) { if err != nil { b.Fatal(err) } - defer f2.Close() for offset < size { n, err := io.ReadFull(f2, buf) @@ -2053,6 +2104,8 @@ func benchmarkRead(b *testing.B, bufsize int, delay time.Duration) { offset += n } + + f2.Close() } } @@ -2102,7 +2155,7 @@ func benchmarkWrite(b *testing.B, bufsize int, delay time.Duration) { // open sftp client sftp, cmd := testClient(b, false, delay) defer cmd.Wait() - // defer sftp.Close() + defer sftp.Close() data := make([]byte, size) @@ -2116,13 +2169,12 @@ func benchmarkWrite(b *testing.B, bufsize int, delay time.Duration) { if err != nil { b.Fatal(err) } - defer os.Remove(f.Name()) + defer os.Remove(f.Name()) // actually queue up a series of removes for these files f2, err := sftp.Create(f.Name()) if err != nil { b.Fatal(err) } - defer f2.Close() for offset < size { n, err := f2.Write(data[offset:min(len(data), offset+bufsize)]) @@ -2198,7 +2250,7 @@ func benchmarkReadFrom(b *testing.B, bufsize int, delay time.Duration) { // open sftp client sftp, cmd := testClient(b, false, delay) defer cmd.Wait() - // defer sftp.Close() + defer sftp.Close() data := make([]byte, size) @@ -2274,6 +2326,87 @@ func BenchmarkReadFrom4MiBDelay150Msec(b *testing.B) { benchmarkReadFrom(b, 4*1024*1024, 150*time.Millisecond) } +func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) { + size := 10*1024*1024 + 123 // ~10MiB + + // open sftp client + sftp, cmd := testClient(b, false, delay) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest-benchwriteto") + if err != nil { + b.Fatal(err) + } + defer os.Remove(f.Name()) + + data := make([]byte, size) + + f.Write(data) + f.Close() + + b.ResetTimer() + b.SetBytes(int64(size)) + + buf := new(bytes.Buffer) + + for i := 0; i < b.N; i++ { + buf.Reset() + + f2, err := sftp.Open(f.Name()) + if err != nil { + b.Fatal(err) + } + + f2.WriteTo(buf) + f2.Close() + + if buf.Len() != size { + b.Fatalf("wrote buffer size: want %d, got %d", size, buf.Len()) + } + } +} + +func BenchmarkWriteTo1k(b *testing.B) { + benchmarkWriteTo(b, 1*1024, NODELAY) +} + +func BenchmarkWriteTo16k(b *testing.B) { + benchmarkWriteTo(b, 16*1024, NODELAY) +} + +func BenchmarkWriteTo32k(b *testing.B) { + benchmarkWriteTo(b, 32*1024, NODELAY) +} + +func BenchmarkWriteTo128k(b *testing.B) { + benchmarkWriteTo(b, 128*1024, NODELAY) +} + +func BenchmarkWriteTo512k(b *testing.B) { + benchmarkWriteTo(b, 512*1024, NODELAY) +} + +func BenchmarkWriteTo1MiB(b *testing.B) { + benchmarkWriteTo(b, 1024*1024, NODELAY) +} + +func BenchmarkWriteTo4MiB(b *testing.B) { + benchmarkWriteTo(b, 4*1024*1024, NODELAY) +} + +func BenchmarkWriteTo4MiBDelay10Msec(b *testing.B) { + benchmarkWriteTo(b, 4*1024*1024, 10*time.Millisecond) +} + +func BenchmarkWriteTo4MiBDelay50Msec(b *testing.B) { + benchmarkWriteTo(b, 4*1024*1024, 50*time.Millisecond) +} + +func BenchmarkWriteTo4MiBDelay150Msec(b *testing.B) { + benchmarkWriteTo(b, 4*1024*1024, 150*time.Millisecond) +} + func benchmarkCopyDown(b *testing.B, fileSize int64, delay time.Duration) { 
skipIfWindows(b) // Create a temp file and fill it with zero's. @@ -2300,7 +2433,7 @@ func benchmarkCopyDown(b *testing.B, fileSize int64, delay time.Duration) { sftp, cmd := testClient(b, READONLY, delay) defer cmd.Wait() - // defer sftp.Close() + defer sftp.Close() b.ResetTimer() b.SetBytes(fileSize) @@ -2374,7 +2507,7 @@ func benchmarkCopyUp(b *testing.B, fileSize int64, delay time.Duration) { sftp, cmd := testClient(b, false, delay) defer cmd.Wait() - // defer sftp.Close() + defer sftp.Close() b.ResetTimer() b.SetBytes(fileSize) diff --git a/conn.go b/conn.go index fe07f433..952a2be4 100644 --- a/conn.go +++ b/conn.go @@ -43,7 +43,8 @@ func (c *conn) Close() error { type clientConn struct { conn - wg sync.WaitGroup + wg sync.WaitGroup + sync.Mutex // protects inflight inflight map[uint32]chan<- result // outstanding requests @@ -76,29 +77,53 @@ func (c *clientConn) loop() { // recv continuously reads from the server and forwards responses to the // appropriate channel. func (c *clientConn) recv() error { - defer func() { - c.conn.Close() - }() + defer c.conn.Close() + for { typ, data, err := c.recvPacket(0) if err != nil { return err } sid, _ := unmarshalUint32(data) - c.Lock() - ch, ok := c.inflight[sid] - delete(c.inflight, sid) - c.Unlock() + + ch, ok := c.getChannel(sid) if !ok { // This is an unexpected occurrence. Send the error // back to all listeners so that they terminate // gracefully. - return errors.Errorf("sid: %v not fond", sid) + return errors.Errorf("sid not found: %v", sid) } + ch <- result{typ: typ, data: data} } } +func (c *clientConn) putChannel(ch chan<- result, sid uint32) bool { + c.Lock() + defer c.Unlock() + + select { + case <-c.closed: + // already closed with broadcastErr, return error on chan. + ch <- result{err: ErrSSHFxConnectionLost} + return false + default: + } + + c.inflight[sid] = ch + return true +} + +func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) { + c.Lock() + defer c.Unlock() + + ch, ok := c.inflight[sid] + delete(c.inflight, sid) + + return ch, ok +} + // result captures the result of receiving the a packet from the server type result struct { typ byte @@ -111,37 +136,48 @@ type idmarshaler interface { encoding.BinaryMarshaler } -func (c *clientConn) sendPacket(p idmarshaler) (byte, []byte, error) { - ch := make(chan result, 2) +func (c *clientConn) sendPacket(ch chan result, p idmarshaler) (byte, []byte, error) { + if cap(ch) < 1 { + ch = make(chan result, 1) + } + c.dispatchRequest(ch, p) s := <-ch return s.typ, s.data, s.err } +// dispatchRequest should ideally only be called by race-detection tests outside of this file, +// where you have to ensure two packets are in flight sequentially after each other. func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) { - c.Lock() - c.inflight[p.id()] = ch - c.Unlock() + sid := p.id() + + if !c.putChannel(ch, sid) { + // already closed. + return + } + if err := c.conn.sendPacket(p); err != nil { - c.Lock() - delete(c.inflight, p.id()) - c.Unlock() - ch <- result{err: err} + if ch, ok := c.getChannel(sid); ok { + ch <- result{err: err} + } } } // broadcastErr sends an error to all goroutines waiting for a response. 
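The broadcastErr change in conn.go resolves every in-flight request with ErrSSHFxConnectionLost rather than a one-off string error, so callers can detect a dropped transport and rebuild the session. A hedged sketch of that check (the helper name and the decision to surface a boolean are assumptions; plain equality works as well as errors.Is here):

```go
// download reports whether the transfer failed because the underlying SSH
// transport dropped, in which case every pending request is resolved with
// ErrSSHFxConnectionLost and the client must be rebuilt before retrying.
func download(client *sftp.Client, remote string, dst io.Writer) (connectionLost bool, err error) {
	src, err := client.Open(remote)
	if err != nil {
		return errors.Is(err, sftp.ErrSSHFxConnectionLost), err
	}
	defer src.Close()

	_, err = io.Copy(dst, src)
	return errors.Is(err, sftp.ErrSSHFxConnectionLost), err
}
```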
func (c *clientConn) broadcastErr(err error) { c.Lock() - listeners := make([]chan<- result, 0, len(c.inflight)) - for _, ch := range c.inflight { - listeners = append(listeners, ch) - } - c.Unlock() - bcastRes := result{err: errors.New("unexpected server disconnect")} - for _, ch := range listeners { + defer c.Unlock() + + bcastRes := result{err: ErrSSHFxConnectionLost} + for sid, ch := range c.inflight { ch <- bcastRes + + // Replace the chan in inflight, + // we have hijacked this chan, + // and this guarantees always-only-once sending. + c.inflight[sid] = make(chan<- result, 1) } + c.err = err close(c.closed) } @@ -150,6 +186,6 @@ type serverConn struct { conn } -func (s *serverConn) sendError(p ider, err error) error { - return s.sendPacket(statusFromError(p, err)) +func (s *serverConn) sendError(id uint32, err error) error { + return s.sendPacket(statusFromError(id, err)) } diff --git a/packet-manager_test.go b/packet-manager_test.go index 652910b6..e43a4cfc 100644 --- a/packet-manager_test.go +++ b/packet-manager_test.go @@ -31,7 +31,7 @@ func fake(rid, order uint32) fakepacket { } func (fakepacket) MarshalBinary() ([]byte, error) { - return []byte{}, nil + return make([]byte, 4), nil } func (fakepacket) UnmarshalBinary([]byte) error { diff --git a/packet-typing.go b/packet-typing.go index addd0e32..da5c2bc6 100644 --- a/packet-typing.go +++ b/packet-typing.go @@ -34,51 +34,51 @@ type notReadOnly interface { //// define types by adding methods // hasPath -func (p sshFxpLstatPacket) getPath() string { return p.Path } -func (p sshFxpStatPacket) getPath() string { return p.Path } -func (p sshFxpRmdirPacket) getPath() string { return p.Path } -func (p sshFxpReadlinkPacket) getPath() string { return p.Path } -func (p sshFxpRealpathPacket) getPath() string { return p.Path } -func (p sshFxpMkdirPacket) getPath() string { return p.Path } -func (p sshFxpSetstatPacket) getPath() string { return p.Path } -func (p sshFxpStatvfsPacket) getPath() string { return p.Path } -func (p sshFxpRemovePacket) getPath() string { return p.Filename } -func (p sshFxpRenamePacket) getPath() string { return p.Oldpath } -func (p sshFxpSymlinkPacket) getPath() string { return p.Targetpath } -func (p sshFxpOpendirPacket) getPath() string { return p.Path } -func (p sshFxpOpenPacket) getPath() string { return p.Path } +func (p *sshFxpLstatPacket) getPath() string { return p.Path } +func (p *sshFxpStatPacket) getPath() string { return p.Path } +func (p *sshFxpRmdirPacket) getPath() string { return p.Path } +func (p *sshFxpReadlinkPacket) getPath() string { return p.Path } +func (p *sshFxpRealpathPacket) getPath() string { return p.Path } +func (p *sshFxpMkdirPacket) getPath() string { return p.Path } +func (p *sshFxpSetstatPacket) getPath() string { return p.Path } +func (p *sshFxpStatvfsPacket) getPath() string { return p.Path } +func (p *sshFxpRemovePacket) getPath() string { return p.Filename } +func (p *sshFxpRenamePacket) getPath() string { return p.Oldpath } +func (p *sshFxpSymlinkPacket) getPath() string { return p.Targetpath } +func (p *sshFxpOpendirPacket) getPath() string { return p.Path } +func (p *sshFxpOpenPacket) getPath() string { return p.Path } -func (p sshFxpExtendedPacketPosixRename) getPath() string { return p.Oldpath } -func (p sshFxpExtendedPacketHardlink) getPath() string { return p.Oldpath } +func (p *sshFxpExtendedPacketPosixRename) getPath() string { return p.Oldpath } +func (p *sshFxpExtendedPacketHardlink) getPath() string { return p.Oldpath } // getHandle -func (p sshFxpFstatPacket) 
getHandle() string { return p.Handle } -func (p sshFxpFsetstatPacket) getHandle() string { return p.Handle } -func (p sshFxpReadPacket) getHandle() string { return p.Handle } -func (p sshFxpWritePacket) getHandle() string { return p.Handle } -func (p sshFxpReaddirPacket) getHandle() string { return p.Handle } -func (p sshFxpClosePacket) getHandle() string { return p.Handle } +func (p *sshFxpFstatPacket) getHandle() string { return p.Handle } +func (p *sshFxpFsetstatPacket) getHandle() string { return p.Handle } +func (p *sshFxpReadPacket) getHandle() string { return p.Handle } +func (p *sshFxpWritePacket) getHandle() string { return p.Handle } +func (p *sshFxpReaddirPacket) getHandle() string { return p.Handle } +func (p *sshFxpClosePacket) getHandle() string { return p.Handle } // notReadOnly -func (p sshFxpWritePacket) notReadOnly() {} -func (p sshFxpSetstatPacket) notReadOnly() {} -func (p sshFxpFsetstatPacket) notReadOnly() {} -func (p sshFxpRemovePacket) notReadOnly() {} -func (p sshFxpMkdirPacket) notReadOnly() {} -func (p sshFxpRmdirPacket) notReadOnly() {} -func (p sshFxpRenamePacket) notReadOnly() {} -func (p sshFxpSymlinkPacket) notReadOnly() {} -func (p sshFxpExtendedPacketPosixRename) notReadOnly() {} -func (p sshFxpExtendedPacketHardlink) notReadOnly() {} +func (p *sshFxpWritePacket) notReadOnly() {} +func (p *sshFxpSetstatPacket) notReadOnly() {} +func (p *sshFxpFsetstatPacket) notReadOnly() {} +func (p *sshFxpRemovePacket) notReadOnly() {} +func (p *sshFxpMkdirPacket) notReadOnly() {} +func (p *sshFxpRmdirPacket) notReadOnly() {} +func (p *sshFxpRenamePacket) notReadOnly() {} +func (p *sshFxpSymlinkPacket) notReadOnly() {} +func (p *sshFxpExtendedPacketPosixRename) notReadOnly() {} +func (p *sshFxpExtendedPacketHardlink) notReadOnly() {} // some packets with ID are missing id() -func (p sshFxpDataPacket) id() uint32 { return p.ID } -func (p sshFxpStatusPacket) id() uint32 { return p.ID } -func (p sshFxpStatResponse) id() uint32 { return p.ID } -func (p sshFxpNamePacket) id() uint32 { return p.ID } -func (p sshFxpHandlePacket) id() uint32 { return p.ID } -func (p StatVFS) id() uint32 { return p.ID } -func (p sshFxVersionPacket) id() uint32 { return 0 } +func (p *sshFxpDataPacket) id() uint32 { return p.ID } +func (p *sshFxpStatusPacket) id() uint32 { return p.ID } +func (p *sshFxpStatResponse) id() uint32 { return p.ID } +func (p *sshFxpNamePacket) id() uint32 { return p.ID } +func (p *sshFxpHandlePacket) id() uint32 { return p.ID } +func (p *StatVFS) id() uint32 { return p.ID } +func (p *sshFxVersionPacket) id() uint32 { return 0 } // take raw incoming packet data and build packet objects func makePacket(p rxPacket) (requestPacket, error) { diff --git a/packet.go b/packet.go index 71f0675f..4a686355 100644 --- a/packet.go +++ b/packet.go @@ -57,12 +57,12 @@ func marshal(b []byte, v interface{}) []byte { switch d := reflect.ValueOf(v); d.Kind() { case reflect.Struct: for i, n := 0, d.NumField(); i < n; i++ { - b = append(marshal(b, d.Field(i).Interface())) + b = marshal(b, d.Field(i).Interface()) } return b case reflect.Slice: for i, n := 0, d.Len(); i < n; i++ { - b = append(marshal(b, d.Index(i).Interface())) + b = marshal(b, d.Index(i).Interface()) } return b default: @@ -116,26 +116,45 @@ func unmarshalStringSafe(b []byte) (string, []byte, error) { return string(b[:n]), b[n:], nil } +type packetMarshaler interface { + marshalPacket() (header, payload []byte, err error) +} + +func marshalPacket(m encoding.BinaryMarshaler) (header, payload []byte, err error) { + if m, ok := 
m.(packetMarshaler); ok { + return m.marshalPacket() + } + + header, err = m.MarshalBinary() + return +} + // sendPacket marshals p according to RFC 4234. func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { - bb, err := m.MarshalBinary() + header, payload, err := marshalPacket(m) if err != nil { return errors.Errorf("binary marshaller failed: %v", err) } + + length := len(header) + len(payload) - 4 // subtract the uint32(length) from the start if debugDumpTxPacketBytes { - debug("send packet: %s %d bytes %x", fxp(bb[0]), len(bb), bb[1:]) + debug("send packet: %s %d bytes %x%x", fxp(header[4]), length, header[5:], payload) } else if debugDumpTxPacket { - debug("send packet: %s %d bytes", fxp(bb[0]), len(bb)) + debug("send packet: %s %d bytes", fxp(header[4]), length) } - // Slide packet down 4 bytes to make room for length header. - packet := append(bb, make([]byte, 4)...) // optimistically assume bb has capacity - copy(packet[4:], bb) - binary.BigEndian.PutUint32(packet[:4], uint32(len(bb))) - _, err = w.Write(packet) - if err != nil { + binary.BigEndian.PutUint32(header[:4], uint32(length)) + + if _, err := w.Write(header); err != nil { return errors.Errorf("failed to send packet: %v", err) } + + if len(payload) > 0 { + if _, err := w.Write(payload); err != nil { + return errors.Errorf("failed to send packet payload: %v", err) + } + } + return nil } @@ -199,19 +218,21 @@ type sshFxInitPacket struct { Extensions []extensionPair } -func (p sshFxInitPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 // byte + uint32 +func (p *sshFxInitPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 // uint32(length) + byte(type) + uint32(version) for _, e := range p.Extensions { l += 4 + len(e.Name) + 4 + len(e.Data) } - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpInit) b = marshalUint32(b, p.Version) + for _, e := range p.Extensions { b = marshalString(b, e.Name) b = marshalString(b, e.Data) } + return b, nil } @@ -240,30 +261,33 @@ type sshExtensionPair struct { Name, Data string } -func (p sshFxVersionPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 // byte + uint32 +func (p *sshFxVersionPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 // uint32(length) + byte(type) + uint32(version) for _, e := range p.Extensions { l += 4 + len(e.Name) + 4 + len(e.Data) } - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpVersion) b = marshalUint32(b, p.Version) + for _, e := range p.Extensions { b = marshalString(b, e.Name) b = marshalString(b, e.Data) } + return b, nil } -func marshalIDString(packetType byte, id uint32, str string) ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 +func marshalIDStringPacket(packetType byte, id uint32, str string) ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(str) - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, packetType) b = marshalUint32(b, id) b = marshalString(b, str) + return b, nil } @@ -282,10 +306,10 @@ type sshFxpReaddirPacket struct { Handle string } -func (p sshFxpReaddirPacket) id() uint32 { return p.ID } +func (p *sshFxpReaddirPacket) id() uint32 { return p.ID } -func (p sshFxpReaddirPacket) MarshalBinary() ([]byte, error) { - return marshalIDString(sshFxpReaddir, p.ID, p.Handle) +func (p *sshFxpReaddirPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpReaddir, p.ID, p.Handle) } func (p *sshFxpReaddirPacket) UnmarshalBinary(b []byte) error { @@ -297,10 +321,10 @@ type sshFxpOpendirPacket 
struct { Path string } -func (p sshFxpOpendirPacket) id() uint32 { return p.ID } +func (p *sshFxpOpendirPacket) id() uint32 { return p.ID } -func (p sshFxpOpendirPacket) MarshalBinary() ([]byte, error) { - return marshalIDString(sshFxpOpendir, p.ID, p.Path) +func (p *sshFxpOpendirPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpOpendir, p.ID, p.Path) } func (p *sshFxpOpendirPacket) UnmarshalBinary(b []byte) error { @@ -312,10 +336,10 @@ type sshFxpLstatPacket struct { Path string } -func (p sshFxpLstatPacket) id() uint32 { return p.ID } +func (p *sshFxpLstatPacket) id() uint32 { return p.ID } -func (p sshFxpLstatPacket) MarshalBinary() ([]byte, error) { - return marshalIDString(sshFxpLstat, p.ID, p.Path) +func (p *sshFxpLstatPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpLstat, p.ID, p.Path) } func (p *sshFxpLstatPacket) UnmarshalBinary(b []byte) error { @@ -327,10 +351,10 @@ type sshFxpStatPacket struct { Path string } -func (p sshFxpStatPacket) id() uint32 { return p.ID } +func (p *sshFxpStatPacket) id() uint32 { return p.ID } -func (p sshFxpStatPacket) MarshalBinary() ([]byte, error) { - return marshalIDString(sshFxpStat, p.ID, p.Path) +func (p *sshFxpStatPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpStat, p.ID, p.Path) } func (p *sshFxpStatPacket) UnmarshalBinary(b []byte) error { @@ -342,10 +366,10 @@ type sshFxpFstatPacket struct { Handle string } -func (p sshFxpFstatPacket) id() uint32 { return p.ID } +func (p *sshFxpFstatPacket) id() uint32 { return p.ID } -func (p sshFxpFstatPacket) MarshalBinary() ([]byte, error) { - return marshalIDString(sshFxpFstat, p.ID, p.Handle) +func (p *sshFxpFstatPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpFstat, p.ID, p.Handle) } func (p *sshFxpFstatPacket) UnmarshalBinary(b []byte) error { @@ -357,10 +381,10 @@ type sshFxpClosePacket struct { Handle string } -func (p sshFxpClosePacket) id() uint32 { return p.ID } +func (p *sshFxpClosePacket) id() uint32 { return p.ID } -func (p sshFxpClosePacket) MarshalBinary() ([]byte, error) { - return marshalIDString(sshFxpClose, p.ID, p.Handle) +func (p *sshFxpClosePacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpClose, p.ID, p.Handle) } func (p *sshFxpClosePacket) UnmarshalBinary(b []byte) error { @@ -372,10 +396,10 @@ type sshFxpRemovePacket struct { Filename string } -func (p sshFxpRemovePacket) id() uint32 { return p.ID } +func (p *sshFxpRemovePacket) id() uint32 { return p.ID } -func (p sshFxpRemovePacket) MarshalBinary() ([]byte, error) { - return marshalIDString(sshFxpRemove, p.ID, p.Filename) +func (p *sshFxpRemovePacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpRemove, p.ID, p.Filename) } func (p *sshFxpRemovePacket) UnmarshalBinary(b []byte) error { @@ -387,10 +411,10 @@ type sshFxpRmdirPacket struct { Path string } -func (p sshFxpRmdirPacket) id() uint32 { return p.ID } +func (p *sshFxpRmdirPacket) id() uint32 { return p.ID } -func (p sshFxpRmdirPacket) MarshalBinary() ([]byte, error) { - return marshalIDString(sshFxpRmdir, p.ID, p.Path) +func (p *sshFxpRmdirPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpRmdir, p.ID, p.Path) } func (p *sshFxpRmdirPacket) UnmarshalBinary(b []byte) error { @@ -403,18 +427,19 @@ type sshFxpSymlinkPacket struct { Linkpath string } -func (p sshFxpSymlinkPacket) id() uint32 { return p.ID } +func (p *sshFxpSymlinkPacket) id() uint32 { return p.ID } -func (p 
sshFxpSymlinkPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 +func (p *sshFxpSymlinkPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Targetpath) + 4 + len(p.Linkpath) - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpSymlink) b = marshalUint32(b, p.ID) b = marshalString(b, p.Targetpath) b = marshalString(b, p.Linkpath) + return b, nil } @@ -436,21 +461,22 @@ type sshFxpHardlinkPacket struct { Newpath string } -func (p sshFxpHardlinkPacket) id() uint32 { return p.ID } +func (p *sshFxpHardlinkPacket) id() uint32 { return p.ID } -func (p sshFxpHardlinkPacket) MarshalBinary() ([]byte, error) { +func (p *sshFxpHardlinkPacket) MarshalBinary() ([]byte, error) { const ext = "hardlink@openssh.com" - l := 1 + 4 + // type(byte) + uint32 + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(ext) + 4 + len(p.Oldpath) + 4 + len(p.Newpath) - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpExtended) b = marshalUint32(b, p.ID) b = marshalString(b, ext) b = marshalString(b, p.Oldpath) b = marshalString(b, p.Newpath) + return b, nil } @@ -459,10 +485,10 @@ type sshFxpReadlinkPacket struct { Path string } -func (p sshFxpReadlinkPacket) id() uint32 { return p.ID } +func (p *sshFxpReadlinkPacket) id() uint32 { return p.ID } -func (p sshFxpReadlinkPacket) MarshalBinary() ([]byte, error) { - return marshalIDString(sshFxpReadlink, p.ID, p.Path) +func (p *sshFxpReadlinkPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpReadlink, p.ID, p.Path) } func (p *sshFxpReadlinkPacket) UnmarshalBinary(b []byte) error { @@ -474,10 +500,10 @@ type sshFxpRealpathPacket struct { Path string } -func (p sshFxpRealpathPacket) id() uint32 { return p.ID } +func (p *sshFxpRealpathPacket) id() uint32 { return p.ID } -func (p sshFxpRealpathPacket) MarshalBinary() ([]byte, error) { - return marshalIDString(sshFxpRealpath, p.ID, p.Path) +func (p *sshFxpRealpathPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpRealpath, p.ID, p.Path) } func (p *sshFxpRealpathPacket) UnmarshalBinary(b []byte) error { @@ -490,8 +516,8 @@ type sshFxpNameAttr struct { Attrs []interface{} } -func (p sshFxpNameAttr) MarshalBinary() ([]byte, error) { - b := []byte{} +func (p *sshFxpNameAttr) MarshalBinary() ([]byte, error) { + var b []byte b = marshalString(b, p.Name) b = marshalString(b, p.LongName) for _, attr := range p.Attrs { @@ -502,23 +528,34 @@ func (p sshFxpNameAttr) MarshalBinary() ([]byte, error) { type sshFxpNamePacket struct { ID uint32 - NameAttrs []sshFxpNameAttr + NameAttrs []*sshFxpNameAttr } -func (p sshFxpNamePacket) MarshalBinary() ([]byte, error) { - b := []byte{} +func (p *sshFxpNamePacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + + b := make([]byte, 4, l) b = append(b, sshFxpName) b = marshalUint32(b, p.ID) b = marshalUint32(b, uint32(len(p.NameAttrs))) + + var payload []byte for _, na := range p.NameAttrs { ab, err := na.MarshalBinary() if err != nil { - return nil, err + return nil, nil, err } - b = append(b, ab...) + payload = append(payload, ab...) 
} - return b, nil + + return b, payload, nil +} + +func (p *sshFxpNamePacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err } type sshFxpOpenPacket struct { @@ -528,19 +565,20 @@ type sshFxpOpenPacket struct { Flags uint32 // ignored } -func (p sshFxpOpenPacket) id() uint32 { return p.ID } +func (p *sshFxpOpenPacket) id() uint32 { return p.ID } -func (p sshFxpOpenPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + +func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Path) + 4 + 4 - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpOpen) b = marshalUint32(b, p.ID) b = marshalString(b, p.Path) b = marshalUint32(b, p.Pflags) b = marshalUint32(b, p.Flags) + return b, nil } @@ -565,19 +603,20 @@ type sshFxpReadPacket struct { Handle string } -func (p sshFxpReadPacket) id() uint32 { return p.ID } +func (p *sshFxpReadPacket) id() uint32 { return p.ID } -func (p sshFxpReadPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 +func (p *sshFxpReadPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Handle) + 8 + 4 // uint64 + uint32 - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpRead) b = marshalUint32(b, p.ID) b = marshalString(b, p.Handle) b = marshalUint64(b, p.Offset) b = marshalUint32(b, p.Len) + return b, nil } @@ -595,16 +634,19 @@ func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error { return nil } +// We need to allocate bigger slices with extra capacity to avoid a re-allocation in sshFxpDataPacket.MarshalBinary. +// So, we need: uint32(length) + byte(type) + uint32(id) + uint32(data_length) +const dataHeaderLen = 4 + 1 + 4 + 4 + func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte { dataLen := clamp(p.Len, maxTxPacket) if alloc != nil { // GetPage returns a slice with capacity = maxMsgLength this is enough to avoid new allocations in - // sshFxpDataPacket.MarshalBinary and sendPacket + // sshFxpDataPacket.MarshalBinary return alloc.GetPage(orderID)[:dataLen] } - // we allocate a slice with a bigger capacity so we avoid a new allocation in sshFxpDataPacket.MarshalBinary - // and in sendPacket, we need 9 bytes in MarshalBinary and 4 bytes in sendPacket.
- return make([]byte, dataLen, dataLen+9+4) + // allocate with extra space for the header + return make([]byte, dataLen, dataLen+dataHeaderLen) } type sshFxpRenamePacket struct { @@ -613,18 +655,19 @@ type sshFxpRenamePacket struct { Newpath string } -func (p sshFxpRenamePacket) id() uint32 { return p.ID } +func (p *sshFxpRenamePacket) id() uint32 { return p.ID } -func (p sshFxpRenamePacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 +func (p *sshFxpRenamePacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Oldpath) + 4 + len(p.Newpath) - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpRename) b = marshalUint32(b, p.ID) b = marshalString(b, p.Oldpath) b = marshalString(b, p.Newpath) + return b, nil } @@ -646,21 +689,22 @@ type sshFxpPosixRenamePacket struct { Newpath string } -func (p sshFxpPosixRenamePacket) id() uint32 { return p.ID } +func (p *sshFxpPosixRenamePacket) id() uint32 { return p.ID } -func (p sshFxpPosixRenamePacket) MarshalBinary() ([]byte, error) { +func (p *sshFxpPosixRenamePacket) MarshalBinary() ([]byte, error) { const ext = "posix-rename@openssh.com" - l := 1 + 4 + // type(byte) + uint32 + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(ext) + 4 + len(p.Oldpath) + 4 + len(p.Newpath) - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpExtended) b = marshalUint32(b, p.ID) b = marshalString(b, ext) b = marshalString(b, p.Oldpath) b = marshalString(b, p.Newpath) + return b, nil } @@ -672,22 +716,27 @@ type sshFxpWritePacket struct { Data []byte } -func (p sshFxpWritePacket) id() uint32 { return p.ID } +func (p *sshFxpWritePacket) id() uint32 { return p.ID } -func (p sshFxpWritePacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 +func (p *sshFxpWritePacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Handle) + - 8 + 4 + // uint64 + uint32 - len(p.Data) + 8 + // uint64 + 4 - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpWrite) b = marshalUint32(b, p.ID) b = marshalString(b, p.Handle) b = marshalUint64(b, p.Offset) b = marshalUint32(b, p.Length) - b = append(b, p.Data...) 
- return b, nil + + return b, p.Data, nil +} + +func (p *sshFxpWritePacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err } func (p *sshFxpWritePacket) UnmarshalBinary(b []byte) error { @@ -714,18 +763,19 @@ type sshFxpMkdirPacket struct { Path string } -func (p sshFxpMkdirPacket) id() uint32 { return p.ID } +func (p *sshFxpMkdirPacket) id() uint32 { return p.ID } -func (p sshFxpMkdirPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 +func (p *sshFxpMkdirPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Path) + 4 // uint32 - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpMkdir) b = marshalUint32(b, p.ID) b = marshalString(b, p.Path) b = marshalUint32(b, p.Flags) + return b, nil } @@ -755,35 +805,49 @@ type sshFxpFsetstatPacket struct { Attrs interface{} } -func (p sshFxpSetstatPacket) id() uint32 { return p.ID } -func (p sshFxpFsetstatPacket) id() uint32 { return p.ID } +func (p *sshFxpSetstatPacket) id() uint32 { return p.ID } +func (p *sshFxpFsetstatPacket) id() uint32 { return p.ID } -func (p sshFxpSetstatPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 +func (p *sshFxpSetstatPacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Path) + - 4 // uint32 + uint64 + 4 // uint32 - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpSetstat) b = marshalUint32(b, p.ID) b = marshalString(b, p.Path) b = marshalUint32(b, p.Flags) - b = marshal(b, p.Attrs) - return b, nil + + payload := marshal(nil, p.Attrs) + + return b, payload, nil } -func (p sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 +func (p *sshFxpSetstatPacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err +} + +func (p *sshFxpFsetstatPacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Handle) + - 4 // uint32 + uint64 + 4 // uint32 - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpFsetstat) b = marshalUint32(b, p.ID) b = marshalString(b, p.Handle) b = marshalUint32(b, p.Flags) - b = marshal(b, p.Attrs) - return b, nil + + payload := marshal(nil, p.Attrs) + + return b, payload, nil +} + +func (p *sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err } func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error { @@ -817,10 +881,15 @@ type sshFxpHandlePacket struct { Handle string } -func (p sshFxpHandlePacket) MarshalBinary() ([]byte, error) { - b := []byte{sshFxpHandle} +func (p *sshFxpHandlePacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Handle) + + b := make([]byte, 4, l) + b = append(b, sshFxpHandle) b = marshalUint32(b, p.ID) b = marshalString(b, p.Handle) + return b, nil } @@ -829,10 +898,17 @@ type sshFxpStatusPacket struct { StatusError } -func (p sshFxpStatusPacket) MarshalBinary() ([]byte, error) { - b := []byte{sshFxpStatus} +func (p *sshFxpStatusPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + + 4 + len(p.StatusError.msg) + + 4 + len(p.StatusError.lang) + + b := make([]byte, 4, l) + b = append(b, 
sshFxpStatus) b = marshalUint32(b, p.ID) b = marshalStatus(b, p.StatusError) + return b, nil } @@ -842,14 +918,30 @@ type sshFxpDataPacket struct { Data []byte } +func (p *sshFxpDataPacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + + b := make([]byte, 4, l) + b = append(b, sshFxpData) + b = marshalUint32(b, p.ID) + b = marshalUint32(b, p.Length) + + return b, p.Data, nil +} + // MarshalBinary encodes the receiver into a binary form and returns the result. // To avoid a new allocation the Data slice must have a capacity >= Length + 9 -func (p sshFxpDataPacket) MarshalBinary() ([]byte, error) { - b := append(p.Data, make([]byte, 9)...) - copy(b[9:], p.Data[:p.Length]) - b[0] = sshFxpData - binary.BigEndian.PutUint32(b[1:5], p.ID) - binary.BigEndian.PutUint32(b[5:9], p.Length) +// +// This is hand-coded rather than just append(header, payload...), +// in order to try and reuse the r.Data backing store in the packet. +func (p *sshFxpDataPacket) MarshalBinary() ([]byte, error) { + b := append(p.Data, make([]byte, dataHeaderLen)...) + copy(b[dataHeaderLen:], p.Data[:p.Length]) + // b[0:4] will be overwritten with the length in sendPacket + b[4] = sshFxpData + binary.BigEndian.PutUint32(b[5:9], p.ID) + binary.BigEndian.PutUint32(b[9:13], p.Length) return b, nil } @@ -872,18 +964,20 @@ type sshFxpStatvfsPacket struct { Path string } -func (p sshFxpStatvfsPacket) id() uint32 { return p.ID } +func (p *sshFxpStatvfsPacket) id() uint32 { return p.ID } -func (p sshFxpStatvfsPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 - len(p.Path) + - len("statvfs@openssh.com") +func (p *sshFxpStatvfsPacket) MarshalBinary() ([]byte, error) { + const ext = "statvfs@openssh.com" + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(ext) + + 4 + len(p.Path) - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpExtended) b = marshalUint32(b, p.ID) - b = marshalString(b, "statvfs@openssh.com") + b = marshalString(b, ext) b = marshalString(b, p.Path) + return b, nil } @@ -913,12 +1007,19 @@ func (p *StatVFS) FreeSpace() uint64 { return p.Frsize * p.Bfree } -// MarshalBinary converts to ssh_FXP_EXTENDED_REPLY packet binary format -func (p *StatVFS) MarshalBinary() ([]byte, error) { +// marshalPacket converts to ssh_FXP_EXTENDED_REPLY packet binary format +func (p *StatVFS) marshalPacket() ([]byte, []byte, error) { + header := []byte{0, 0, 0, 0, sshFxpExtendedReply} + var buf bytes.Buffer - buf.Write([]byte{sshFxpExtendedReply}) err := binary.Write(&buf, binary.BigEndian, p) - return buf.Bytes(), err + + return header, buf.Bytes(), err +} + +func (p *StatVFS) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err } type sshFxpFsyncPacket struct { @@ -926,18 +1027,20 @@ type sshFxpFsyncPacket struct { Handle string } -func (p sshFxpFsyncPacket) id() uint32 { return p.ID } +func (p *sshFxpFsyncPacket) id() uint32 { return p.ID } -func (p sshFxpFsyncPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + // type (byte) + ID (uint32) - 4 + len("fsync@openssh.com") + +func (p *sshFxpFsyncPacket) MarshalBinary() ([]byte, error) { + const ext = "fsync@openssh.com" + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(ext) + 4 + len(p.Handle) - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpExtended) b = marshalUint32(b, p.ID) - b = marshalString(b, "fsync@openssh.com") + b = 
marshalString(b, ext) b = marshalString(b, p.Handle) + return b, nil } @@ -950,17 +1053,17 @@ type sshFxpExtendedPacket struct { } } -func (p sshFxpExtendedPacket) id() uint32 { return p.ID } -func (p sshFxpExtendedPacket) readonly() bool { +func (p *sshFxpExtendedPacket) id() uint32 { return p.ID } +func (p *sshFxpExtendedPacket) readonly() bool { if p.SpecificPacket == nil { return true } return p.SpecificPacket.readonly() } -func (p sshFxpExtendedPacket) respond(svr *Server) responsePacket { +func (p *sshFxpExtendedPacket) respond(svr *Server) responsePacket { if p.SpecificPacket == nil { - return statusFromError(p, nil) + return statusFromError(p.ID, nil) } return p.SpecificPacket.respond(svr) } @@ -995,8 +1098,8 @@ type sshFxpExtendedPacketStatVFS struct { Path string } -func (p sshFxpExtendedPacketStatVFS) id() uint32 { return p.ID } -func (p sshFxpExtendedPacketStatVFS) readonly() bool { return true } +func (p *sshFxpExtendedPacketStatVFS) id() uint32 { return p.ID } +func (p *sshFxpExtendedPacketStatVFS) readonly() bool { return true } func (p *sshFxpExtendedPacketStatVFS) UnmarshalBinary(b []byte) error { var err error if p.ID, b, err = unmarshalUint32Safe(b); err != nil { @@ -1016,8 +1119,8 @@ type sshFxpExtendedPacketPosixRename struct { Newpath string } -func (p sshFxpExtendedPacketPosixRename) id() uint32 { return p.ID } -func (p sshFxpExtendedPacketPosixRename) readonly() bool { return false } +func (p *sshFxpExtendedPacketPosixRename) id() uint32 { return p.ID } +func (p *sshFxpExtendedPacketPosixRename) readonly() bool { return false } func (p *sshFxpExtendedPacketPosixRename) UnmarshalBinary(b []byte) error { var err error if p.ID, b, err = unmarshalUint32Safe(b); err != nil { @@ -1032,9 +1135,9 @@ func (p *sshFxpExtendedPacketPosixRename) UnmarshalBinary(b []byte) error { return nil } -func (p sshFxpExtendedPacketPosixRename) respond(s *Server) responsePacket { +func (p *sshFxpExtendedPacketPosixRename) respond(s *Server) responsePacket { err := os.Rename(p.Oldpath, p.Newpath) - return statusFromError(p, err) + return statusFromError(p.ID, err) } type sshFxpExtendedPacketHardlink struct { @@ -1045,8 +1148,8 @@ type sshFxpExtendedPacketHardlink struct { } // https://github.com/openssh/openssh-portable/blob/master/PROTOCOL -func (p sshFxpExtendedPacketHardlink) id() uint32 { return p.ID } -func (p sshFxpExtendedPacketHardlink) readonly() bool { return true } +func (p *sshFxpExtendedPacketHardlink) id() uint32 { return p.ID } +func (p *sshFxpExtendedPacketHardlink) readonly() bool { return true } func (p *sshFxpExtendedPacketHardlink) UnmarshalBinary(b []byte) error { var err error if p.ID, b, err = unmarshalUint32Safe(b); err != nil { @@ -1061,7 +1164,7 @@ func (p *sshFxpExtendedPacketHardlink) UnmarshalBinary(b []byte) error { return nil } -func (p sshFxpExtendedPacketHardlink) respond(s *Server) responsePacket { +func (p *sshFxpExtendedPacketHardlink) respond(s *Server) responsePacket { err := os.Link(p.Oldpath, p.Newpath) - return statusFromError(p, err) + return statusFromError(p.ID, err) } diff --git a/packet_test.go b/packet_test.go index 8b16be6e..976f66fc 100644 --- a/packet_test.go +++ b/packet_test.go @@ -142,20 +142,20 @@ var sendPacketTests = []struct { p encoding.BinaryMarshaler want []byte }{ - {sshFxInitPacket{ + {&sshFxInitPacket{ Version: 3, Extensions: []extensionPair{ {"posix-rename@openssh.com", "1"}, }, }, []byte{0x0, 0x0, 0x0, 0x26, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 
0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}}, - {sshFxpOpenPacket{ + {&sshFxpOpenPacket{ ID: 1, Path: "/foo", Pflags: flags(os.O_RDONLY), }, []byte{0x0, 0x0, 0x0, 0x15, 0x3, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}}, - {sshFxpWritePacket{ + {&sshFxpWritePacket{ ID: 124, Handle: "foo", Offset: 13, @@ -163,7 +163,7 @@ var sendPacketTests = []struct { Data: []byte("bar"), }, []byte{0x0, 0x0, 0x0, 0x1b, 0x6, 0x0, 0x0, 0x0, 0x7c, 0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd, 0x0, 0x0, 0x0, 0x3, 0x62, 0x61, 0x72}}, - {sshFxpSetstatPacket{ + {&sshFxpSetstatPacket{ ID: 31, Path: "/bar", Flags: flags(os.O_WRONLY), @@ -195,7 +195,7 @@ var recvPacketTests = []struct { want uint8 rest []byte }{ - {sp(sshFxInitPacket{ + {sp(&sshFxInitPacket{ Version: 3, Extensions: []extensionPair{ {"posix-rename@openssh.com", "1"}, @@ -299,7 +299,7 @@ func TestSSHFxpOpenPackethasPflags(t *testing.T) { func BenchmarkMarshalInit(b *testing.B) { for i := 0; i < b.N; i++ { - sp(sshFxInitPacket{ + sp(&sshFxInitPacket{ Version: 3, Extensions: []extensionPair{ {"posix-rename@openssh.com", "1"}, @@ -310,7 +310,7 @@ func BenchmarkMarshalInit(b *testing.B) { func BenchmarkMarshalOpen(b *testing.B) { for i := 0; i < b.N; i++ { - sp(sshFxpOpenPacket{ + sp(&sshFxpOpenPacket{ ID: 1, Path: "/home/test/some/random/path", Pflags: flags(os.O_RDONLY), @@ -321,7 +321,7 @@ func BenchmarkMarshalOpen(b *testing.B) { func BenchmarkMarshalWriteWorstCase(b *testing.B) { data := make([]byte, 32*1024) for i := 0; i < b.N; i++ { - sp(sshFxpWritePacket{ + sp(&sshFxpWritePacket{ ID: 1, Handle: "someopaquehandle", Offset: 0, @@ -334,7 +334,7 @@ func BenchmarkMarshalWriteWorstCase(b *testing.B) { func BenchmarkMarshalWrite1k(b *testing.B) { data := make([]byte, 1024) for i := 0; i < b.N; i++ { - sp(sshFxpWritePacket{ + sp(&sshFxpWritePacket{ ID: 1, Handle: "someopaquehandle", Offset: 0, diff --git a/request-server.go b/request-server.go index 72ee3b86..fe94ae25 100644 --- a/request-server.go +++ b/request-server.go @@ -193,10 +193,10 @@ func (rs *RequestServer) packetWorker( var rpkt responsePacket switch pkt := pkt.requestPacket.(type) { case *sshFxInitPacket: - rpkt = sshFxVersionPacket{Version: sftpProtocolVersion, Extensions: sftpExtensions} + rpkt = &sshFxVersionPacket{Version: sftpProtocolVersion, Extensions: sftpExtensions} case *sshFxpClosePacket: handle := pkt.getHandle() - rpkt = statusFromError(pkt, rs.closeRequest(handle)) + rpkt = statusFromError(pkt.ID, rs.closeRequest(handle)) case *sshFxpRealpathPacket: rpkt = cleanPacketPath(pkt) case *sshFxpOpendirPacket: @@ -219,7 +219,7 @@ func (rs *RequestServer) packetWorker( handle := pkt.getHandle() request, ok := rs.getRequest(handle) if !ok { - rpkt = statusFromError(pkt, EBADF) + rpkt = statusFromError(pkt.ID, EBADF) } else { request = NewRequest("Stat", request.Filepath) rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) @@ -228,7 +228,7 @@ func (rs *RequestServer) packetWorker( handle := pkt.getHandle() request, ok := rs.getRequest(handle) if !ok { - rpkt = statusFromError(pkt, EBADF) + rpkt = statusFromError(pkt.ID, EBADF) } else { request = NewRequest("Setstat", request.Filepath) rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) @@ -244,7 +244,7 @@ func (rs *RequestServer) packetWorker( handle := pkt.getHandle() request, ok := rs.getRequest(handle) if !ok { - rpkt = statusFromError(pkt, EBADF) + rpkt = 
statusFromError(pkt.id(), EBADF) } else { rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) } @@ -253,7 +253,7 @@ func (rs *RequestServer) packetWorker( rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) request.close() default: - rpkt = statusFromError(pkt, ErrSSHFxOpUnsupported) + rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported) } rs.pktMgr.readyPacket( @@ -267,11 +267,13 @@ func cleanPacketPath(pkt *sshFxpRealpathPacket) responsePacket { path := cleanPath(pkt.getPath()) return &sshFxpNamePacket{ ID: pkt.id(), - NameAttrs: []sshFxpNameAttr{{ - Name: path, - LongName: path, - Attrs: emptyFileStat, - }}, + NameAttrs: []*sshFxpNameAttr{ + &sshFxpNameAttr{ + Name: path, + LongName: path, + Attrs: emptyFileStat, + }, + }, } } diff --git a/request-server_test.go b/request-server_test.go index af606783..3d2502a6 100644 --- a/request-server_test.go +++ b/request-server_test.go @@ -39,11 +39,14 @@ const sock = "/tmp/rstest.sock" func clientRequestServerPair(t *testing.T) *csPair { skipIfWindows(t) skipIfPlan9(t) - ready := make(chan bool) + + ready := make(chan struct{}) + canReturn := make(chan struct{}) os.Remove(sock) // either this or signal handling pair := &csPair{ svrResult: make(chan error, 1), } + var server *RequestServer go func() { l, err := net.Listen("unix", sock) @@ -51,26 +54,37 @@ func clientRequestServerPair(t *testing.T) *csPair { // neither assert nor t.Fatal reliably exit before Accept errors panic(err) } - ready <- true + + close(ready) + fd, err := l.Accept() require.NoError(t, err) + handlers := InMemHandler() var options []RequestServerOption if *testAllocator { options = append(options, WithRSAllocator()) } + server = NewRequestServer(fd, handlers, options...) + close(canReturn) + err = server.Serve() pair.svrResult <- err }() + <-ready defer os.Remove(sock) + c, err := net.Dial("unix", sock) require.NoError(t, err) + client, err := NewClientPipe(c, c) if err != nil { - t.Fatalf("%+v\n", err) + t.Fatalf("unexpected error: %+v", err) } + + <-canReturn pair.svr = server pair.cli = client return pair @@ -147,11 +161,12 @@ func TestRequestCacheState(t *testing.T) { func putTestFile(cli *Client, path, content string) (int, error) { w, err := cli.Create(path) - if err == nil { - defer w.Close() - return w.Write([]byte(content)) + if err != nil { + return 0, err } - return 0, err + defer w.Close() + + return w.Write([]byte(content)) } func getTestFile(cli *Client, path string) ([]byte, error) { diff --git a/request.go b/request.go index 086d18d1..44651027 100644 --- a/request.go +++ b/request.go @@ -219,7 +219,7 @@ func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, o case "Stat", "Lstat", "Readlink": return filestat(handlers.FileList, r, pkt) default: - return statusFromError(pkt, + return statusFromError(pkt.id(), errors.Errorf("unexpected method: %s", r.Method)) } } @@ -228,6 +228,8 @@ func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, o func (r *Request) open(h Handlers, pkt requestPacket) responsePacket { flags := r.Pflags() + id := pkt.id() + switch { case flags.Write, flags.Append, flags.Creat, flags.Trunc: if flags.Read { @@ -235,36 +237,37 @@ func (r *Request) open(h Handlers, pkt requestPacket) responsePacket { r.Method = "Open" rw, err := openFileWriter.OpenFile(r) if err != nil { - return statusFromError(pkt, err) + return statusFromError(id, err) } r.state.writerReaderAt = rw - return &sshFxpHandlePacket{ID: pkt.id(), Handle: r.handle} + return &sshFxpHandlePacket{ID: id, 
Handle: r.handle} } } r.Method = "Put" wr, err := h.FilePut.Filewrite(r) if err != nil { - return statusFromError(pkt, err) + return statusFromError(id, err) } r.state.writerAt = wr case flags.Read: r.Method = "Get" rd, err := h.FileGet.Fileread(r) if err != nil { - return statusFromError(pkt, err) + return statusFromError(id, err) } r.state.readerAt = rd default: - return statusFromError(pkt, errors.New("bad file flags")) + return statusFromError(id, errors.New("bad file flags")) } - return &sshFxpHandlePacket{ID: pkt.id(), Handle: r.handle} + return &sshFxpHandlePacket{ID: id, Handle: r.handle} } + func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket { r.Method = "List" la, err := h.FileList.Filelist(r) if err != nil { - return statusFromError(pkt, wrapPathError(r.Filepath, err)) + return statusFromError(pkt.id(), wrapPathError(r.Filepath, err)) } r.state.listerAt = la return &sshFxpHandlePacket{ID: pkt.id(), Handle: r.handle} @@ -276,14 +279,14 @@ func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orde reader := r.state.readerAt r.state.RUnlock() if reader == nil { - return statusFromError(pkt, errors.New("unexpected read packet")) + return statusFromError(pkt.id(), errors.New("unexpected read packet")) } data, offset, _ := packetData(pkt, alloc, orderID) n, err := reader.ReadAt(data, offset) // only return EOF error if no data left to read if err != nil && (err != io.EOF || n == 0) { - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) } return &sshFxpDataPacket{ ID: pkt.id(), @@ -298,12 +301,12 @@ func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orde writer := r.state.writerAt r.state.RUnlock() if writer == nil { - return statusFromError(pkt, errors.New("unexpected write packet")) + return statusFromError(pkt.id(), errors.New("unexpected write packet")) } data, offset, _ := packetData(pkt, alloc, orderID) _, err := writer.WriteAt(data, offset) - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) } // wrap OpenFileWriter handler @@ -312,15 +315,15 @@ func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, o writerReader := r.state.writerReaderAt r.state.RUnlock() if writerReader == nil { - return statusFromError(pkt, errors.New("unexpected write and read packet")) + return statusFromError(pkt.id(), errors.New("unexpected write and read packet")) } - switch pkt.(type) { + switch p := pkt.(type) { case *sshFxpReadPacket: - data, offset, _ := packetData(pkt, alloc, orderID) + data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset) n, err := writerReader.ReadAt(data, offset) // only return EOF error if no data left to read if err != nil && (err != io.EOF || n == 0) { - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) } return &sshFxpDataPacket{ ID: pkt.id(), @@ -328,11 +331,11 @@ func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, o Data: data[:n], } case *sshFxpWritePacket: - data, offset, _ := packetData(pkt, alloc, orderID) + data, offset := p.Data, int64(p.Offset) _, err := writerReader.WriteAt(data, offset) - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) default: - return statusFromError(pkt, errors.New("unexpected packet type for read or write")) + return statusFromError(pkt.id(), errors.New("unexpected packet type for read or write")) } } @@ -340,13 +343,9 @@ func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, o func packetData(p 
requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) { switch p := p.(type) { case *sshFxpReadPacket: - length = p.Len - offset = int64(p.Offset) - data = p.getDataSlice(alloc, orderID) + return p.getDataSlice(alloc, orderID), int64(p.Offset), p.Len case *sshFxpWritePacket: - data = p.Data - length = p.Length - offset = int64(p.Offset) + return p.Data, int64(p.Offset), p.Length } return } @@ -362,30 +361,30 @@ func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket { if r.Method == "PosixRename" { if posixRenamer, ok := h.(PosixRenameFileCmder); ok { err := posixRenamer.PosixRename(r) - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) } // PosixRenameFileCmder not implemented handle this request as a Rename r.Method = "Rename" err := h.Filecmd(r) - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) } if r.Method == "StatVFS" { if statVFSCmdr, ok := h.(StatVFSFileCmder); ok { stat, err := statVFSCmdr.StatVFS(r) if err != nil { - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) } stat.ID = pkt.id() return stat } - return statusFromError(pkt, ErrSSHFxOpUnsupported) + return statusFromError(pkt.id(), ErrSSHFxOpUnsupported) } err := h.Filecmd(r) - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) } // wrap FileLister handler @@ -393,7 +392,7 @@ func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket { var err error lister := r.getLister() if lister == nil { - return statusFromError(pkt, errors.New("unexpected dir packet")) + return statusFromError(pkt.id(), errors.New("unexpected dir packet")) } offset := r.lsNext() @@ -406,16 +405,16 @@ func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket { switch r.Method { case "List": if err != nil && err != io.EOF { - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) } if err == io.EOF && n == 0 { - return statusFromError(pkt, io.EOF) + return statusFromError(pkt.id(), io.EOF) } dirname := filepath.ToSlash(path.Base(r.Filepath)) ret := &sshFxpNamePacket{ID: pkt.id()} for _, fi := range finfo { - ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{ + ret.NameAttrs = append(ret.NameAttrs, &sshFxpNameAttr{ Name: fi.Name(), LongName: runLs(dirname, fi), Attrs: []interface{}{fi}, @@ -424,7 +423,7 @@ func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket { return ret default: err = errors.Errorf("unexpected method: %s", r.Method) - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) } } @@ -444,7 +443,7 @@ func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket { lister, err = h.Filelist(r) } if err != nil { - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) } finfo := make([]os.FileInfo, 1) n, err := lister.ListAt(finfo, 0) @@ -453,12 +452,12 @@ func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket { switch r.Method { case "Stat", "Lstat": if err != nil && err != io.EOF { - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) } if n == 0 { err = &os.PathError{Op: strings.ToLower(r.Method), Path: r.Filepath, Err: syscall.ENOENT} - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) } return &sshFxpStatResponse{ ID: pkt.id(), @@ -466,25 +465,27 @@ func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket { } case "Readlink": if err != nil && err != io.EOF { - return statusFromError(pkt, err) + return 
statusFromError(pkt.id(), err) } if n == 0 { err = &os.PathError{Op: "readlink", Path: r.Filepath, Err: syscall.ENOENT} - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) } filename := finfo[0].Name() return &sshFxpNamePacket{ ID: pkt.id(), - NameAttrs: []sshFxpNameAttr{{ - Name: filename, - LongName: filename, - Attrs: emptyFileStat, - }}, + NameAttrs: []*sshFxpNameAttr{ + &sshFxpNameAttr{ + Name: filename, + LongName: filename, + Attrs: emptyFileStat, + }, + }, } default: err = errors.Errorf("unexpected method: %s", r.Method) - return statusFromError(pkt, err) + return statusFromError(pkt.id(), err) } } diff --git a/request_test.go b/request_test.go index 9f1ed661..d3f7db06 100644 --- a/request_test.go +++ b/request_test.go @@ -118,11 +118,11 @@ func (h *Handlers) returnError(err error) { } func getStatusMsg(p interface{}) string { - pkt := p.(sshFxpStatusPacket) + pkt := p.(*sshFxpStatusPacket) return pkt.StatusError.msg } func checkOkStatus(t *testing.T, p interface{}) { - pkt := p.(sshFxpStatusPacket) + pkt := p.(*sshFxpStatusPacket) assert.Equal(t, pkt.StatusError.Code, uint32(sshFxOk), "sshFxpStatusPacket not OK\n", pkt.StatusError.msg) } @@ -166,7 +166,7 @@ func TestRequestCustomError(t *testing.T) { cmdErr := errors.New("stat not supported") handlers.returnError(cmdErr) rpkt := request.call(handlers, pkt, nil, 0) - assert.Equal(t, rpkt, statusFromError(rpkt, cmdErr)) + assert.Equal(t, rpkt, statusFromError(pkt.myid, cmdErr)) } // XXX can't just set method to Get, need to use Open to setup Get/Put @@ -194,7 +194,7 @@ func TestRequestCmdr(t *testing.T) { handlers.returnError(errTest) rpkt = request.call(handlers, pkt, nil, 0) - assert.Equal(t, rpkt, statusFromError(rpkt, errTest)) + assert.Equal(t, rpkt, statusFromError(pkt.myid, errTest)) } func TestRequestInfoStat(t *testing.T) { @@ -227,7 +227,7 @@ func TestRequestInfoReadlink(t *testing.T) { rpkt := request.call(handlers, pkt, nil, 0) npkt, ok := rpkt.(*sshFxpNamePacket) if assert.True(t, ok) { - assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0]) + assert.IsType(t, &sshFxpNameAttr{}, npkt.NameAttrs[0]) assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go") } } diff --git a/server.go b/server.go index f595518a..cfebed73 100644 --- a/server.go +++ b/server.go @@ -152,7 +152,7 @@ func (svr *Server) sftpServerWorker(pktChan chan orderedRequest) error { // return permission denied if !readonly && svr.readOnly { svr.pktMgr.readyPacket( - svr.pktMgr.newOrderedResponse(statusFromError(pkt, syscall.EPERM), pkt.orderID()), + svr.pktMgr.newOrderedResponse(statusFromError(pkt.id(), syscall.EPERM), pkt.orderID()), ) continue } @@ -169,29 +169,29 @@ func handlePacket(s *Server, p orderedRequest) error { orderID := p.orderID() switch p := p.requestPacket.(type) { case *sshFxInitPacket: - rpkt = sshFxVersionPacket{ + rpkt = &sshFxVersionPacket{ Version: sftpProtocolVersion, Extensions: sftpExtensions, } case *sshFxpStatPacket: // stat the requested file info, err := os.Stat(p.Path) - rpkt = sshFxpStatResponse{ + rpkt = &sshFxpStatResponse{ ID: p.ID, info: info, } if err != nil { - rpkt = statusFromError(p, err) + rpkt = statusFromError(p.ID, err) } case *sshFxpLstatPacket: // stat the requested file info, err := os.Lstat(p.Path) - rpkt = sshFxpStatResponse{ + rpkt = &sshFxpStatResponse{ ID: p.ID, info: info, } if err != nil { - rpkt = statusFromError(p, err) + rpkt = statusFromError(p.ID, err) } case *sshFxpFstatPacket: f, ok := s.getHandle(p.Handle) @@ -199,71 +199,75 @@ func handlePacket(s *Server, p 
orderedRequest) error { var info os.FileInfo if ok { info, err = f.Stat() - rpkt = sshFxpStatResponse{ + rpkt = &sshFxpStatResponse{ ID: p.ID, info: info, } } if err != nil { - rpkt = statusFromError(p, err) + rpkt = statusFromError(p.ID, err) } case *sshFxpMkdirPacket: // TODO FIXME: ignore flags field err := os.Mkdir(p.Path, 0755) - rpkt = statusFromError(p, err) + rpkt = statusFromError(p.ID, err) case *sshFxpRmdirPacket: err := os.Remove(p.Path) - rpkt = statusFromError(p, err) + rpkt = statusFromError(p.ID, err) case *sshFxpRemovePacket: err := os.Remove(p.Filename) - rpkt = statusFromError(p, err) + rpkt = statusFromError(p.ID, err) case *sshFxpRenamePacket: err := os.Rename(p.Oldpath, p.Newpath) - rpkt = statusFromError(p, err) + rpkt = statusFromError(p.ID, err) case *sshFxpSymlinkPacket: err := os.Symlink(p.Targetpath, p.Linkpath) - rpkt = statusFromError(p, err) + rpkt = statusFromError(p.ID, err) case *sshFxpClosePacket: - rpkt = statusFromError(p, s.closeHandle(p.Handle)) + rpkt = statusFromError(p.ID, s.closeHandle(p.Handle)) case *sshFxpReadlinkPacket: f, err := os.Readlink(p.Path) - rpkt = sshFxpNamePacket{ + rpkt = &sshFxpNamePacket{ ID: p.ID, - NameAttrs: []sshFxpNameAttr{{ - Name: f, - LongName: f, - Attrs: emptyFileStat, - }}, + NameAttrs: []*sshFxpNameAttr{ + &sshFxpNameAttr{ + Name: f, + LongName: f, + Attrs: emptyFileStat, + }, + }, } if err != nil { - rpkt = statusFromError(p, err) + rpkt = statusFromError(p.ID, err) } case *sshFxpRealpathPacket: f, err := filepath.Abs(p.Path) f = cleanPath(f) - rpkt = sshFxpNamePacket{ + rpkt = &sshFxpNamePacket{ ID: p.ID, - NameAttrs: []sshFxpNameAttr{{ - Name: f, - LongName: f, - Attrs: emptyFileStat, - }}, + NameAttrs: []*sshFxpNameAttr{ + &sshFxpNameAttr{ + Name: f, + LongName: f, + Attrs: emptyFileStat, + }, + }, } if err != nil { - rpkt = statusFromError(p, err) + rpkt = statusFromError(p.ID, err) } case *sshFxpOpendirPacket: if stat, err := os.Stat(p.Path); err != nil { - rpkt = statusFromError(p, err) + rpkt = statusFromError(p.ID, err) } else if !stat.IsDir() { - rpkt = statusFromError(p, &os.PathError{ + rpkt = statusFromError(p.ID, &os.PathError{ Path: p.Path, Err: syscall.ENOTDIR}) } else { - rpkt = sshFxpOpenPacket{ + rpkt = (&sshFxpOpenPacket{ ID: p.ID, Path: p.Path, Pflags: sshFxfRead, - }.respond(s) + }).respond(s) } case *sshFxpReadPacket: var err error = EBADF @@ -275,7 +279,7 @@ func handlePacket(s *Server, p orderedRequest) error { if _err != nil && (_err != io.EOF || n == 0) { err = _err } - rpkt = sshFxpDataPacket{ + rpkt = &sshFxpDataPacket{ ID: p.ID, Length: uint32(n), Data: data[:n], @@ -283,7 +287,7 @@ func handlePacket(s *Server, p orderedRequest) error { } } if err != nil { - rpkt = statusFromError(p, err) + rpkt = statusFromError(p.ID, err) } case *sshFxpWritePacket: @@ -292,10 +296,10 @@ func handlePacket(s *Server, p orderedRequest) error { if ok { _, err = f.WriteAt(p.Data, int64(p.Offset)) } - rpkt = statusFromError(p, err) + rpkt = statusFromError(p.ID, err) case *sshFxpExtendedPacket: if p.SpecificPacket == nil { - rpkt = statusFromError(p, ErrSSHFxOpUnsupported) + rpkt = statusFromError(p.ID, ErrSSHFxOpUnsupported) } else { rpkt = p.respond(s) } @@ -375,27 +379,38 @@ type ider interface { } // The init packet has no ID, so we just return a zero-value ID -func (p sshFxInitPacket) id() uint32 { return 0 } +func (p *sshFxInitPacket) id() uint32 { return 0 } type sshFxpStatResponse struct { ID uint32 info os.FileInfo } -func (p sshFxpStatResponse) MarshalBinary() ([]byte, error) { - b := 
[]byte{sshFxpAttrs} +func (p *sshFxpStatResponse) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 // uint32(length) + byte(type) + uint32(id) + + b := make([]byte, 4, l) + b = append(b, sshFxpAttrs) b = marshalUint32(b, p.ID) - b = marshalFileInfo(b, p.info) - return b, nil + + var payload []byte + payload = marshalFileInfo(payload, p.info) + + return b, payload, nil +} + +func (p *sshFxpStatResponse) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err } var emptyFileStat = []interface{}{uint32(0)} -func (p sshFxpOpenPacket) readonly() bool { +func (p *sshFxpOpenPacket) readonly() bool { return !p.hasPflags(sshFxfWrite) } -func (p sshFxpOpenPacket) hasPflags(flags ...uint32) bool { +func (p *sshFxpOpenPacket) hasPflags(flags ...uint32) bool { for _, f := range flags { if p.Pflags&f == 0 { return false @@ -404,7 +419,7 @@ func (p sshFxpOpenPacket) hasPflags(flags ...uint32) bool { return true } -func (p sshFxpOpenPacket) respond(svr *Server) responsePacket { +func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket { var osFlags int if p.hasPflags(sshFxfRead, sshFxfWrite) { osFlags |= os.O_RDWR @@ -414,7 +429,7 @@ func (p sshFxpOpenPacket) respond(svr *Server) responsePacket { osFlags |= os.O_RDONLY } else { // how are they opening? - return statusFromError(p, syscall.EINVAL) + return statusFromError(p.ID, syscall.EINVAL) } // Don't use O_APPEND flag as it conflicts with WriteAt. @@ -432,28 +447,28 @@ func (p sshFxpOpenPacket) respond(svr *Server) responsePacket { f, err := os.OpenFile(p.Path, osFlags, 0644) if err != nil { - return statusFromError(p, err) + return statusFromError(p.ID, err) } handle := svr.nextHandle(f) - return sshFxpHandlePacket{ID: p.id(), Handle: handle} + return &sshFxpHandlePacket{ID: p.ID, Handle: handle} } -func (p sshFxpReaddirPacket) respond(svr *Server) responsePacket { +func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket { f, ok := svr.getHandle(p.Handle) if !ok { - return statusFromError(p, EBADF) + return statusFromError(p.ID, EBADF) } dirname := f.Name() dirents, err := f.Readdir(128) if err != nil { - return statusFromError(p, err) + return statusFromError(p.ID, err) } - ret := sshFxpNamePacket{ID: p.ID} + ret := &sshFxpNamePacket{ID: p.ID} for _, dirent := range dirents { - ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{ + ret.NameAttrs = append(ret.NameAttrs, &sshFxpNameAttr{ Name: dirent.Name(), LongName: runLs(dirname, dirent), Attrs: []interface{}{dirent}, @@ -462,7 +477,7 @@ func (p sshFxpReaddirPacket) respond(svr *Server) responsePacket { return ret } -func (p sshFxpSetstatPacket) respond(svr *Server) responsePacket { +func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket { // additional unmarshalling is required for each possibility here b := p.Attrs.([]byte) var err error @@ -501,13 +516,13 @@ func (p sshFxpSetstatPacket) respond(svr *Server) responsePacket { } } - return statusFromError(p, err) + return statusFromError(p.ID, err) } -func (p sshFxpFsetstatPacket) respond(svr *Server) responsePacket { +func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket { f, ok := svr.getHandle(p.Handle) if !ok { - return statusFromError(p, EBADF) + return statusFromError(p.ID, EBADF) } // additional unmarshalling is required for each possibility here @@ -548,12 +563,12 @@ func (p sshFxpFsetstatPacket) respond(svr *Server) responsePacket { } } - return statusFromError(p, err) + return statusFromError(p.ID, err) } -func statusFromError(p 
ider, err error) sshFxpStatusPacket { - ret := sshFxpStatusPacket{ - ID: p.id(), +func statusFromError(id uint32, err error) *sshFxpStatusPacket { + ret := &sshFxpStatusPacket{ + ID: id, StatusError: StatusError{ // sshFXOk = 0 // sshFXEOF = 1 diff --git a/server_statvfs_impl.go b/server_statvfs_impl.go index 9dc793c8..2d467d1e 100644 --- a/server_statvfs_impl.go +++ b/server_statvfs_impl.go @@ -9,21 +9,20 @@ import ( "syscall" ) -func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket { +func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket { retPkt, err := getStatVFSForPath(p.Path) if err != nil { - return statusFromError(p, err) + return statusFromError(p.ID, err) } - retPkt.ID = p.ID return retPkt } func getStatVFSForPath(name string) (*StatVFS, error) { - stat := &syscall.Statfs_t{} - if err := syscall.Statfs(name, stat); err != nil { + var stat syscall.Statfs_t + if err := syscall.Statfs(name, &stat); err != nil { return nil, err } - return statvfsFromStatfst(stat) + return statvfsFromStatfst(&stat) } diff --git a/server_statvfs_plan9.go b/server_statvfs_plan9.go index 5a293237..e71a27d3 100644 --- a/server_statvfs_plan9.go +++ b/server_statvfs_plan9.go @@ -4,8 +4,8 @@ import ( "syscall" ) -func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket { - return statusFromError(p, syscall.EPLAN9) +func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket { + return statusFromError(p.ID, syscall.EPLAN9) } func getStatVFSForPath(name string) (*StatVFS, error) { diff --git a/server_statvfs_stubs.go b/server_statvfs_stubs.go index c1bb104a..fbf49068 100644 --- a/server_statvfs_stubs.go +++ b/server_statvfs_stubs.go @@ -6,8 +6,8 @@ import ( "syscall" ) -func (p sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket { - return statusFromError(p, syscall.ENOTSUP) +func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket { + return statusFromError(p.ID, syscall.ENOTSUP) } func getStatVFSForPath(name string) (*StatVFS, error) { diff --git a/server_test.go b/server_test.go index f93ba3f7..ddfdd225 100644 --- a/server_test.go +++ b/server_test.go @@ -193,15 +193,16 @@ type sshFxpTestBadExtendedPacket struct { func (p sshFxpTestBadExtendedPacket) id() uint32 { return p.ID } func (p sshFxpTestBadExtendedPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + 4 + // type(byte) + uint32 + uint32 - len(p.Extension) + - len(p.Data) + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Extension) + + 4 + len(p.Data) - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, sshFxpExtended) b = marshalUint32(b, p.ID) b = marshalString(b, p.Extension) b = marshalString(b, p.Data) + return b, nil } @@ -222,7 +223,7 @@ func TestInvalidExtendedPacket(t *testing.T) { defer server.Close() badPacket := sshFxpTestBadExtendedPacket{client.nextID(), "thisDoesn'tExist", "foobar"} - typ, data, err := client.clientConn.sendPacket(badPacket) + typ, data, err := client.clientConn.sendPacket(nil, badPacket) if err != nil { t.Fatalf("unexpected error from sendPacket: %s", err) } @@ -280,10 +281,10 @@ func TestConcurrentRequests(t *testing.T) { func TestStatusFromError(t *testing.T) { type test struct { err error - pkt sshFxpStatusPacket + pkt *sshFxpStatusPacket } - tpkt := func(id, code uint32) sshFxpStatusPacket { - return sshFxpStatusPacket{ + tpkt := func(id, code uint32) *sshFxpStatusPacket { + return &sshFxpStatusPacket{ ID: id, StatusError: StatusError{Code: code}, } @@ -300,7 +301,7 @@ func 
TestStatusFromError(t *testing.T) { } for _, tc := range testCases { tc.pkt.StatusError.msg = tc.err.Error() - assert.Equal(t, tc.pkt, statusFromError(tc.pkt, tc.err)) + assert.Equal(t, tc.pkt, statusFromError(tc.pkt.ID, tc.err)) } } @@ -326,17 +327,17 @@ func TestOpenStatRace(t *testing.T) { pflags := flags(os.O_RDWR | os.O_CREATE | os.O_TRUNC) ch := make(chan result, 3) id1 := client.nextID() - client.dispatchRequest(ch, sshFxpOpenPacket{ + client.dispatchRequest(ch, &sshFxpOpenPacket{ ID: id1, Path: tmppath, Pflags: pflags, }) id2 := client.nextID() - client.dispatchRequest(ch, sshFxpLstatPacket{ + client.dispatchRequest(ch, &sshFxpLstatPacket{ ID: id2, Path: tmppath, }) - testreply := func(id uint32, ch chan result) { + testreply := func(id uint32) { r := <-ch switch r.typ { case sshFxpAttrs, sshFxpHandle: // ignore @@ -344,11 +345,11 @@ func TestOpenStatRace(t *testing.T) { err := normaliseError(unmarshalStatus(id, r.data)) assert.NoError(t, err, "race hit, stat before open") default: - assert.Fail(t, "Unexpected type") + t.Fatal("unexpected type:", r.typ) } } - testreply(id1, ch) - testreply(id2, ch) + testreply(id1) + testreply(id2) os.Remove(tmppath) checkServerAllocator(t, server) } @@ -369,8 +370,8 @@ func TestStatNonExistent(t *testing.T) { } func TestServerWithBrokenClient(t *testing.T) { - validInit := sp(sshFxInitPacket{Version: 3}) - brokenOpen := sp(sshFxpOpenPacket{Path: "foo"}) + validInit := sp(&sshFxInitPacket{Version: 3}) + brokenOpen := sp(&sshFxpOpenPacket{Path: "foo"}) brokenOpen = brokenOpen[:len(brokenOpen)-2] for _, clientInput := range [][]byte{ diff --git a/sftp_test.go b/sftp_test.go index 174a457c..4c1b080c 100644 --- a/sftp_test.go +++ b/sftp_test.go @@ -10,7 +10,6 @@ import ( ) func TestErrFxCode(t *testing.T) { - ider := sshFxpStatusPacket{ID: 1} table := []struct { err error fx fxerr @@ -22,7 +21,7 @@ func TestErrFxCode(t *testing.T) { {err: io.EOF, fx: ErrSSHFxEOF}, } for _, tt := range table { - statusErr := statusFromError(ider, tt.err).StatusError + statusErr := statusFromError(1, tt.err).StatusError assert.Equal(t, statusErr.FxCode(), tt.fx) } }
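Reviewer note (not part of the patch): below is a minimal, self-contained Go sketch of the header/payload framing convention this change adopts. A packet's marshalPacket reserves 4 leading header bytes for the uint32 length prefix and keeps bulk data in a separate payload slice; sendPacket then fills in the length as len(header)+len(payload)-4 and writes both slices, so large Data fields are never copied into the marshal buffer. The names examplePacket and sendExample are illustrative only and do not appear in the library.

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// marshalUint32 appends v to b in big-endian order, mirroring the style of
// the helpers used in packet.go.
func marshalUint32(b []byte, v uint32) []byte {
	return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
}

// examplePacket is a hypothetical packet type used only to illustrate the
// header/payload split; it is not part of the library.
type examplePacket struct {
	Type byte
	ID   uint32
	Data []byte
}

// marshalPacket returns a header whose first 4 bytes are reserved for the
// length prefix, and keeps the bulk data as a separate payload slice so it
// is not copied into the header buffer.
func (p *examplePacket) marshalPacket() (header, payload []byte, err error) {
	b := make([]byte, 4, 4+1+4) // uint32(length) + byte(type) + uint32(id)
	b = append(b, p.Type)
	b = marshalUint32(b, p.ID)
	return b, p.Data, nil
}

// sendExample mirrors the patched sendPacket framing: the length field does
// not count its own 4 bytes, so it is len(header)+len(payload)-4.
func sendExample(w *bytes.Buffer, p *examplePacket) error {
	header, payload, err := p.marshalPacket()
	if err != nil {
		return err
	}
	binary.BigEndian.PutUint32(header[:4], uint32(len(header)+len(payload)-4))
	if _, err := w.Write(header); err != nil {
		return err
	}
	_, err = w.Write(payload)
	return err
}

func main() {
	var buf bytes.Buffer
	if err := sendExample(&buf, &examplePacket{Type: 103, ID: 1, Data: []byte("hello")}); err != nil {
		panic(err)
	}
	// Prints: 00 00 00 0a 67 00 00 00 01 68 65 6c 6c 6f
	fmt.Printf("% x\n", buf.Bytes())
}

Under this convention, the separate payload return is what lets sshFxpWritePacket, sshFxpDataPacket, and the setstat packets hand their Data or attribute bytes straight to the writer instead of appending them into the header slice first.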