Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 9 additions & 53 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ import (
"github.com/gravitational/teleport/lib/shell"
alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/sshutils/scp"
"github.com/gravitational/teleport/lib/sshutils/sftp"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -2003,53 +2002,6 @@ func PlayFile(ctx context.Context, tarFile io.Reader, sid string) error {
return playSession(sessionEvents, stream)
}

// ExecuteSCP executes SCP command. It executes scp.Command using
// lower-level API integrations that mimic SCP CLI command behavior
func (tc *TeleportClient) ExecuteSCP(ctx context.Context, serverAddr string, cmd scp.Command) error {
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/ExecuteSCP",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
defer span.End()

// connect to proxy first:
if !tc.Config.ProxySpecified() {
return trace.BadParameter("proxy server is not specified")
}

clt, err := tc.ConnectToCluster(ctx)
if err != nil {
return trace.Wrap(err)
}
defer clt.Close()

nodeClient, err := tc.ConnectToNode(
ctx,
clt,
// We append the ":0" to tell the server to figure out the port for us.
NodeDetails{Addr: serverAddr + ":0", Namespace: tc.Namespace, Cluster: clt.ClusterName()},
tc.Config.HostLogin,
)
if err != nil {
tc.ExitStatus = 1
return trace.Wrap(err)
}

err = nodeClient.ExecuteSCP(ctx, cmd)
if err != nil {
// converts SSH error code to tc.ExitStatus
exitError, _ := trace.Unwrap(err).(*ssh.ExitError)
if exitError != nil {
tc.ExitStatus = exitError.ExitStatus()
}
return err

}

return nil
}

// SFTP securely copies files between Nodes or SSH servers using SFTP
func (tc *TeleportClient) SFTP(ctx context.Context, args []string, port int, opts sftp.Options, quiet bool) (err error) {
ctx, span := tc.Tracer.Start(
Expand Down Expand Up @@ -2107,7 +2059,7 @@ func (tc *TeleportClient) uploadConfig(args []string, port int, opts sftp.Option
// copy everything except the last arg (the destination)
dstPath := args[len(args)-1]

dst, addr, err := getSCPDestination(dstPath, port)
dst, addr, err := getSFTPDestination(dstPath, port)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -2129,7 +2081,7 @@ func (tc *TeleportClient) downloadConfig(args []string, port int, opts sftp.Opti
}

// args are guaranteed to have len(args) > 1
src, addr, err := getSCPDestination(args[0], port)
src, addr, err := getSFTPDestination(args[0], port)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -2145,8 +2097,8 @@ func (tc *TeleportClient) downloadConfig(args []string, port int, opts sftp.Opti
}, nil
}

func getSCPDestination(target string, port int) (dest *scp.Destination, addr string, err error) {
dest, err = scp.ParseSCPDestination(target)
func getSFTPDestination(target string, port int) (dest *sftp.Destination, addr string, err error) {
dest, err = sftp.ParseDestination(target)
if err != nil {
return nil, "", trace.Wrap(err)
}
Expand Down Expand Up @@ -2183,7 +2135,11 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr
client, err := tc.ConnectToNode(
ctx,
clt,
NodeDetails{Addr: nodeAddr, Namespace: tc.Namespace, Cluster: clt.ClusterName()},
NodeDetails{
Addr: nodeAddr,
Namespace: tc.Namespace,
Cluster: clt.ClusterName(),
},
hostLogin,
)
if err != nil {
Expand Down
103 changes: 0 additions & 103 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@ import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -54,7 +52,6 @@ import (
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/sshutils/scp"
"github.com/gravitational/teleport/lib/sshutils/sftp"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/socks"
Expand Down Expand Up @@ -1737,106 +1734,6 @@ func (proxy *ProxyClient) Close() error {
return trace.NewAggregate(proxy.Client.Close(), proxy.currentCluster.Close())
}

// ExecuteSCP runs remote scp command(shellCmd) on the remote server and
// runs local scp handler using SCP Command
func (c *NodeClient) ExecuteSCP(ctx context.Context, cmd scp.Command) error {
ctx, span := c.Tracer.Start(
ctx,
"nodeClient/ExecuteSCP",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
defer span.End()

shellCmd, err := cmd.GetRemoteShellCmd()
if err != nil {
return trace.Wrap(err)
}

s, err := c.Client.NewSession(ctx)
if err != nil {
return trace.Wrap(err)
}
defer s.Close()

// File transfers in a moderated session require these two variablesto check for
// approval on the ssh server. If they exist in the context, set them in our env vars
if moderatedSessionID, ok := ctx.Value(scp.ModeratedSessionID).(string); ok {
s.Setenv(ctx, string(scp.ModeratedSessionID), moderatedSessionID)
}
if fileTransferRequestID, ok := ctx.Value(scp.FileTransferRequestID).(string); ok {
s.Setenv(ctx, string(scp.FileTransferRequestID), fileTransferRequestID)
}

stdin, err := s.StdinPipe()
if err != nil {
return trace.Wrap(err)
}

stdout, err := s.StdoutPipe()
if err != nil {
return trace.Wrap(err)
}

// Stream scp's stderr so tsh gets the verbose remote error
// if the command fails
stderr, err := s.StderrPipe()
if err != nil {
return trace.Wrap(err)
}
go io.Copy(os.Stderr, stderr)

ch := utils.NewPipeNetConn(
stdout,
stdin,
utils.MultiCloser(),
&net.IPAddr{},
&net.IPAddr{},
)

execC := make(chan error, 1)
go func() {
err := cmd.Execute(ch)
if err != nil && !trace.IsEOF(err) {
log.WithError(err).Warn("Failed to execute SCP command.")
}
stdin.Close()
execC <- err
}()

runC := make(chan error, 1)
go func() {
err := s.Run(ctx, shellCmd)
if err != nil && errors.Is(err, &ssh.ExitMissingError{}) {
// TODO(dmitri): currently, if the session is aborted with (*session).Close,
// the remote side cannot send exit-status and this error results.
// To abort the session properly, Teleport needs to support `signal` request
err = nil
}
runC <- err
}()

var runErr error
select {
case <-ctx.Done():
if err := s.Close(); err != nil {
log.WithError(err).Debug("Failed to close the SSH session.")
}
err, runErr = <-execC, <-runC
case err = <-execC:
runErr = <-runC
case runErr = <-runC:
err = <-execC
}

if runErr != nil && (err == nil || trace.IsEOF(err)) {
err = runErr
}
if trace.IsEOF(err) {
err = nil
}
return trace.Wrap(err)
}

// TransferFiles transfers files over SFTP.
func (c *NodeClient) TransferFiles(ctx context.Context, cfg *sftp.Config) error {
ctx, span := c.Tracer.Start(
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/authhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func (h *AuthHandlers) CheckX11Forward(ctx *ServerContext) error {

func (h *AuthHandlers) CheckFileCopying(ctx *ServerContext) error {
if !ctx.Identity.AccessChecker.CanCopyFiles() {
return errRoleFileCopyingNotPermitted
return trace.Wrap(errRoleFileCopyingNotPermitted)
}
return nil
}
Expand Down
24 changes: 18 additions & 6 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ func writeChildError(w io.Writer, err error) {
Code: trace.ErrorToCode(err),
RawError: data,
})

}

// DecodeChildError consumes the output from a child
Expand Down Expand Up @@ -726,11 +725,11 @@ func (c *ServerContext) SetAllowFileCopying(allow bool) {
func (c *ServerContext) CheckFileCopyingAllowed() error {
// Check if remote file operations are disabled for this node.
if !c.AllowFileCopying {
return ErrNodeFileCopyingNotPermitted
return trace.Wrap(ErrNodeFileCopyingNotPermitted)
}
// Check if the user's RBAC role allows remote file operations.
if !c.Identity.AccessChecker.CanCopyFiles() {
return errRoleFileCopyingNotPermitted
return trace.Wrap(errRoleFileCopyingNotPermitted)
}

return nil
Expand All @@ -739,7 +738,7 @@ func (c *ServerContext) CheckFileCopyingAllowed() error {
// CheckSFTPAllowed returns an error if remote file operations via SCP
// or SFTP are not allowed by the user's role or the node's config, or
// if the user is not allowed to start unattended sessions.
func (c *ServerContext) CheckSFTPAllowed() error {
func (c *ServerContext) CheckSFTPAllowed(registry *SessionRegistry) error {
if err := c.CheckFileCopyingAllowed(); err != nil {
return trace.Wrap(err)
}
Expand All @@ -751,8 +750,21 @@ func (c *ServerContext) CheckSFTPAllowed() error {
if err != nil {
return trace.Wrap(err)
}
if !canStart {
return errCannotStartUnattendedSession
// canStart will be true for non-moderated sessions. If canStart is false, check to
// see if the request has been approved through a moderated session next.
if canStart {
return nil
}
if registry == nil {
return trace.Wrap(errCannotStartUnattendedSession)
}

approved, err := registry.isApprovedFileTransfer(c)
if err != nil {
return trace.Wrap(err)
}
if !approved {
return trace.Wrap(errCannotStartUnattendedSession)
}

return nil
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func TestCheckSFTPAllowed(t *testing.T) {
roles,
)

err := ctx.CheckSFTPAllowed()
err := ctx.CheckSFTPAllowed(nil)
if tt.expectedErr == nil {
require.NoError(t, err)
} else {
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2049,7 +2049,7 @@ func (s *Server) parseSubsystemRequest(req *ssh.Request, ch ssh.Channel, ctx *sr
case r.Name == teleport.GetHomeDirSubsystem:
return newHomeDirSubsys(), nil
case r.Name == sftpSubsystem:
if err := ctx.CheckSFTPAllowed(); err != nil {
if err := ctx.CheckSFTPAllowed(s.reg); err != nil {
s.replyError(ch, req, err)
return nil, trace.Wrap(err)
}
Expand Down
11 changes: 7 additions & 4 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ import (
"github.com/gravitational/teleport/lib/services"
rsession "github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/sshutils/scp"
"github.com/gravitational/teleport/lib/sshutils/sftp"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -353,7 +353,7 @@ func (s *SessionRegistry) isApprovedFileTransfer(scx *ServerContext) (bool, erro
// if a sessID and requestID environment variables were not set, return not approved and no error.
// This means the file transfer came from a non-moderated session. sessionID will be passed after a
// moderated session approval process has completed.
sessID, _ := scx.GetEnv(string(scp.ModeratedSessionID))
sessID, _ := scx.GetEnv(string(sftp.ModeratedSessionID))
if sessID == "" {
return false, nil
}
Expand All @@ -364,7 +364,7 @@ func (s *SessionRegistry) isApprovedFileTransfer(scx *ServerContext) (bool, erro
return false, trace.NotFound("Session not found")
}

requestID, _ := scx.GetEnv(string(scp.FileTransferRequestID))
requestID, _ := scx.GetEnv(string(sftp.FileTransferRequestID))
if requestID == "" {
return false, nil
}
Expand All @@ -379,7 +379,10 @@ func (s *SessionRegistry) isApprovedFileTransfer(scx *ServerContext) (bool, erro
return false, trace.AccessDenied("Teleport user does not match original requester")
}

incomingShellCmd := string(scx.sshRequest.Payload)
var incomingShellCmd string
if scx.sshRequest != nil {
incomingShellCmd = string(scx.sshRequest.Payload)
}
if incomingShellCmd != req.shellCmd {
return false, trace.AccessDenied("Incoming request does not match the approved request")
}
Expand Down
6 changes: 3 additions & 3 deletions lib/srv/sess_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import (
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/services"
rsession "github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/sshutils/scp"
"github.com/gravitational/teleport/lib/sshutils/sftp"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -208,8 +208,8 @@ func TestIsApprovedFileTransfer(t *testing.T) {
Payload: []byte("/usr/bin/scp -f ~/logs.txt"),
}

scx.SetEnv(string(scp.ModeratedSessionID), sess.ID())
scx.SetEnv(string(scp.FileTransferRequestID), tt.reqID)
scx.SetEnv(string(sftp.ModeratedSessionID), sess.ID())
scx.SetEnv(string(sftp.FileTransferRequestID), tt.reqID)
result, err := reg.isApprovedFileTransfer(scx)
if err != nil {
require.Equal(t, tt.expectedError, err.Error())
Expand Down
Loading