Skip to content

Commit

Permalink
New method Client.Extensions to list server extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
greatroar committed Oct 22, 2020
1 parent fcaa492 commit 6269895
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
25 changes: 24 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
inflight: make(map[uint32]chan<- result),
closed: make(chan struct{}),
},

ext: make(map[string]string),

maxPacket: 1 << 15,
maxConcurrentRequests: 64,
}
Expand Down Expand Up @@ -183,6 +186,8 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
type Client struct {
clientConn

ext map[string]string // Extensions (name -> data).

maxPacket int // max packet size read or written.
maxConcurrentRequests int
nextid uint32
Expand Down Expand Up @@ -223,14 +228,32 @@ func (c *Client) recvVersion() error {
return &unexpectedPacketErr{sshFxpVersion, typ}
}

version, _ := unmarshalUint32(data)
version, data := unmarshalUint32(data)
if version != sftpProtocolVersion {
return &unexpectedVersionErr{sftpProtocolVersion, version}
}

for len(data) > 0 {
var ext extensionPair
ext, data, err = unmarshalExtensionPair(data)
if err != nil {
return err
}
c.ext[ext.Name] = ext.Data
}

return nil
}

// HasExtension checks whether the server supports a named extension.
//
// The second return value is the extension data reported by the server
// (typically a version number).
func (c *Client) HasExtension(name string) (ok bool, data string) {
data, ok = c.ext[name]
return
}

// Walk returns a new Walker rooted at root.
func (c *Client) Walk(root string) *fs.Walker {
return fs.WalkFS(root, c)
Expand Down
5 changes: 5 additions & 0 deletions client_integration_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package sftp
import (
"syscall"
"testing"

"github.com/stretchr/testify/require"
)

func TestClientStatVFS(t *testing.T) {
Expand All @@ -13,6 +15,9 @@ func TestClientStatVFS(t *testing.T) {
defer cmd.Wait()
defer sftp.Close()

ok, _ := sftp.HasExtension("[email protected]")
require.True(t, ok, "server doesn't list statvfs extension")

vfs, err := sftp.StatVFS("/")
if err != nil {
t.Fatal(err)
Expand Down

0 comments on commit 6269895

Please sign in to comment.