Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for working directory in Server #528

Merged
merged 12 commits into from
Oct 18, 2022
4 changes: 2 additions & 2 deletions packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,7 @@ func (p *sshFxpExtendedPacketPosixRename) UnmarshalBinary(b []byte) error {
}

func (p *sshFxpExtendedPacketPosixRename) respond(s *Server) responsePacket {
err := os.Rename(toLocalPath(p.Oldpath), toLocalPath(p.Newpath))
err := os.Rename(s.toLocalPath(p.Oldpath), s.toLocalPath(p.Newpath))
return statusFromError(p.ID, err)
}

Expand Down Expand Up @@ -1276,6 +1276,6 @@ func (p *sshFxpExtendedPacketHardlink) UnmarshalBinary(b []byte) error {
}

func (p *sshFxpExtendedPacketHardlink) respond(s *Server) responsePacket {
err := os.Link(toLocalPath(p.Oldpath), toLocalPath(p.Newpath))
err := os.Link(s.toLocalPath(p.Oldpath), s.toLocalPath(p.Newpath))
return statusFromError(p.ID, err)
}
19 changes: 0 additions & 19 deletions request-plan9.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
package sftp

import (
"path"
"path/filepath"
"syscall"
)

Expand All @@ -15,20 +13,3 @@ func fakeFileInfoSys() interface{} {
func testOsSys(sys interface{}) error {
return nil
}

func toLocalPath(p string) string {
lp := filepath.FromSlash(p)

if path.IsAbs(p) {
tmp := lp[1:]

if filepath.IsAbs(tmp) {
// If the FromSlash without any starting slashes is absolute,
// then we have a filepath encoded with a prefix '/'.
// e.g. "/#s/boot" to "#s/boot"
return tmp
}
}

return lp
}
4 changes: 0 additions & 4 deletions request-unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,3 @@ func testOsSys(sys interface{}) error {
}
return nil
}

func toLocalPath(p string) string {
return p
}
31 changes: 0 additions & 31 deletions request_windows.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package sftp

import (
"path"
"path/filepath"
"syscall"
)

Expand All @@ -13,32 +11,3 @@ func fakeFileInfoSys() interface{} {
func testOsSys(sys interface{}) error {
return nil
}

func toLocalPath(p string) string {
lp := filepath.FromSlash(p)

if path.IsAbs(p) {
tmp := lp
for len(tmp) > 0 && tmp[0] == '\\' {
tmp = tmp[1:]
}

if filepath.IsAbs(tmp) {
// If the FromSlash without any starting slashes is absolute,
// then we have a filepath encoded with a prefix '/'.
// e.g. "/C:/Windows" to "C:\\Windows"
return tmp
}

tmp += "\\"

if filepath.IsAbs(tmp) {
// If the FromSlash without any starting slashes but with extra end slash is absolute,
// then we have a filepath encoded with a prefix '/' and a dropped '/' at the end.
// e.g. "/C:" to "C:\\"
return tmp
}
}

return lp
}
40 changes: 26 additions & 14 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type Server struct {
openFiles map[string]*os.File
openFilesLock sync.RWMutex
handleCount int
workDir string
}

func (svr *Server) nextHandle(f *os.File) string {
Expand Down Expand Up @@ -128,6 +129,16 @@ func WithAllocator() ServerOption {
}
}

// WithServerWorkingDirectory sets a working directory to use as base
// for relative paths.
// If unset the default is current working directory (os.Getwd).
func WithServerWorkingDirectory(workDir string) ServerOption {
return func(s *Server) error {
s.workDir = cleanPath(workDir)
return nil
}
}

type rxPacket struct {
pktType fxp
pktBytes []byte
Expand Down Expand Up @@ -174,7 +185,7 @@ func handlePacket(s *Server, p orderedRequest) error {
}
case *sshFxpStatPacket:
// stat the requested file
info, err := os.Stat(toLocalPath(p.Path))
info, err := os.Stat(s.toLocalPath(p.Path))
rpkt = &sshFxpStatResponse{
ID: p.ID,
info: info,
Expand All @@ -184,7 +195,7 @@ func handlePacket(s *Server, p orderedRequest) error {
}
case *sshFxpLstatPacket:
// stat the requested file
info, err := os.Lstat(toLocalPath(p.Path))
info, err := os.Lstat(s.toLocalPath(p.Path))
rpkt = &sshFxpStatResponse{
ID: p.ID,
info: info,
Expand All @@ -208,24 +219,24 @@ func handlePacket(s *Server, p orderedRequest) error {
}
case *sshFxpMkdirPacket:
// TODO FIXME: ignore flags field
err := os.Mkdir(toLocalPath(p.Path), 0755)
err := os.Mkdir(s.toLocalPath(p.Path), 0o755)
rpkt = statusFromError(p.ID, err)
case *sshFxpRmdirPacket:
err := os.Remove(toLocalPath(p.Path))
err := os.Remove(s.toLocalPath(p.Path))
rpkt = statusFromError(p.ID, err)
case *sshFxpRemovePacket:
err := os.Remove(toLocalPath(p.Filename))
err := os.Remove(s.toLocalPath(p.Filename))
rpkt = statusFromError(p.ID, err)
case *sshFxpRenamePacket:
err := os.Rename(toLocalPath(p.Oldpath), toLocalPath(p.Newpath))
err := os.Rename(s.toLocalPath(p.Oldpath), s.toLocalPath(p.Newpath))
rpkt = statusFromError(p.ID, err)
case *sshFxpSymlinkPacket:
err := os.Symlink(toLocalPath(p.Targetpath), toLocalPath(p.Linkpath))
err := os.Symlink(s.toLocalPath(p.Targetpath), s.toLocalPath(p.Linkpath))
rpkt = statusFromError(p.ID, err)
case *sshFxpClosePacket:
rpkt = statusFromError(p.ID, s.closeHandle(p.Handle))
case *sshFxpReadlinkPacket:
f, err := os.Readlink(toLocalPath(p.Path))
f, err := os.Readlink(s.toLocalPath(p.Path))
rpkt = &sshFxpNamePacket{
ID: p.ID,
NameAttrs: []*sshFxpNameAttr{
Expand All @@ -240,7 +251,7 @@ func handlePacket(s *Server, p orderedRequest) error {
rpkt = statusFromError(p.ID, err)
}
case *sshFxpRealpathPacket:
f, err := filepath.Abs(toLocalPath(p.Path))
f, err := filepath.Abs(s.toLocalPath(p.Path))
f = cleanPath(f)
rpkt = &sshFxpNamePacket{
ID: p.ID,
Expand All @@ -256,13 +267,14 @@ func handlePacket(s *Server, p orderedRequest) error {
rpkt = statusFromError(p.ID, err)
}
case *sshFxpOpendirPacket:
p.Path = toLocalPath(p.Path)
lp := s.toLocalPath(p.Path)

if stat, err := os.Stat(p.Path); err != nil {
if stat, err := os.Stat(lp); err != nil {
rpkt = statusFromError(p.ID, err)
} else if !stat.IsDir() {
rpkt = statusFromError(p.ID, &os.PathError{
Path: p.Path, Err: syscall.ENOTDIR})
Path: lp, Err: syscall.ENOTDIR,
})
} else {
rpkt = (&sshFxpOpenPacket{
ID: p.ID,
Expand Down Expand Up @@ -446,7 +458,7 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
osFlags |= os.O_EXCL
}

f, err := os.OpenFile(toLocalPath(p.Path), osFlags, 0644)
f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, 0o644)
if err != nil {
return statusFromError(p.ID, err)
}
Expand Down Expand Up @@ -484,7 +496,7 @@ func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket {
b := p.Attrs.([]byte)
var err error

p.Path = toLocalPath(p.Path)
p.Path = svr.toLocalPath(p.Path)

debug("setstat name \"%s\"", p.Path)
if (p.Flags & sshFileXferAttrSize) != 0 {
Expand Down
30 changes: 30 additions & 0 deletions server_plan9.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//go:build plan9
// +build plan9
mafredri marked this conversation as resolved.
Show resolved Hide resolved

package sftp

import (
"path"
"path/filepath"
)

func (s *Server) toLocalPath(p string) string {
if s.workDir != "" && !path.IsAbs(p) {
p = path.Join(s.workDir, p)
}

lp := filepath.FromSlash(p)

if path.IsAbs(p) {
tmp := lp[1:]

if filepath.IsAbs(tmp) {
// If the FromSlash without any starting slashes is absolute,
// then we have a filepath encoded with a prefix '/'.
// e.g. "/#s/boot" to "#s/boot"
return tmp
}
}

return lp
}
16 changes: 16 additions & 0 deletions server_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//go:build !windows && !plan9
// +build !windows,!plan9

package sftp

import (
"path"
)

func (s *Server) toLocalPath(p string) string {
if s.workDir != "" && !path.IsAbs(p) {
p = path.Join(s.workDir, p)
}

return p
}
87 changes: 87 additions & 0 deletions server_unix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//go:build !windows && !plan9
// +build !windows,!plan9
mafredri marked this conversation as resolved.
Show resolved Hide resolved

package sftp

import (
"testing"
)

func TestServer_toLocalPath(t *testing.T) {
tests := []struct {
name string
withWorkDir string
p string
want string
}{
{
name: "empty path with no workdir",
p: "",
want: "",
},
{
name: "relative path with no workdir",
p: "file",
want: "file",
},
{
name: "absolute path with no workdir",
p: "/file",
want: "/file",
},
{
name: "workdir and empty path",
withWorkDir: "/home/user",
p: "",
want: "/home/user",
},
{
name: "workdir and relative path",
withWorkDir: "/home/user",
p: "file",
want: "/home/user/file",
},
{
name: "workdir and relative path with .",
withWorkDir: "/home/user",
p: ".",
want: "/home/user",
},
{
name: "workdir and relative path with . and file",
withWorkDir: "/home/user",
p: "./file",
want: "/home/user/file",
},
{
name: "workdir and absolute path",
withWorkDir: "/home/user",
p: "/file",
want: "/file",
},
{
name: "workdir and non-unixy path prefixes workdir",
withWorkDir: "/home/user",
p: "C:\\file",
// This may look like a bug but it is the result of passing
// invalid input (a non-unixy path) to the server.
want: "/home/user/C:\\file",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// We don't need to initialize the Server further to test
// toLocalPath behavior.
s := &Server{}
if tt.withWorkDir != "" {
if err := WithServerWorkingDirectory(tt.withWorkDir)(s); err != nil {
t.Fatal(err)
}
}

if got := s.toLocalPath(tt.p); got != tt.want {
t.Errorf("Server.toLocalPath() = %v, want %v", got, tt.want)
mafredri marked this conversation as resolved.
Show resolved Hide resolved
}
})
}
}
39 changes: 39 additions & 0 deletions server_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package sftp

import (
"path"
"path/filepath"
)

func (s *Server) toLocalPath(p string) string {
if s.workDir != "" && !path.IsAbs(p) {
p = path.Join(s.workDir, p)
}

lp := filepath.FromSlash(p)

if path.IsAbs(p) {
tmp := lp
for len(tmp) > 0 && tmp[0] == '\\' {
tmp = tmp[1:]
}

if filepath.IsAbs(tmp) {
// If the FromSlash without any starting slashes is absolute,
// then we have a filepath encoded with a prefix '/'.
// e.g. "/C:/Windows" to "C:\\Windows"
return tmp
}

tmp += "\\"

if filepath.IsAbs(tmp) {
// If the FromSlash without any starting slashes but with extra end slash is absolute,
// then we have a filepath encoded with a prefix '/' and a dropped '/' at the end.
// e.g. "/C:" to "C:\\"
return tmp
}
}

return lp
}
Loading