Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 69 additions & 4 deletions api/observability/tracing/ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ import (
"sync"
"sync/atomic"

"github.com/google/uuid"
"github.com/gravitational/trace"
"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport/api/observability/tracing"
"github.com/gravitational/teleport/api/types"
)

// Client is a wrapper around ssh.Client that adds tracing support.
Expand Down Expand Up @@ -166,9 +168,65 @@ func (c *Client) OpenChannel(
}, reqs, err
}

// NewSession creates a new SSH session that is passed tracing context
// so that spans may be correlated properly over the ssh connection.
// SessionParams are session parameters supported by Teleport to provide additional
// session context or parameters to the server.
type SessionParams struct {
// WebProxyAddr is the address of the proxy forwarding the SSH connection to the target server.
WebProxyAddr string
// Reason is a reason attached to started sessions meant to describe their intent.
Reason string
// Invited is a list of people invited to a session.
Invited []string
// DisplayParticipantRequirements is set if debug information about participants requirements
// should be printed in moderated sessions.
DisplayParticipantRequirements bool
// JoinSessionID is the ID of a session to join.
JoinSessionID string
// JoinMode is the participant mode to join the session with.
// Required if JoinSessionID is set.
JoinMode types.SessionParticipantMode
// ModeratedSessionID is an optional parameter sent during SCP requests to specify which moderated session
// to check for valid FileTransferRequests.
ModeratedSessionID string
}

// ParseSessionParams unmarshals session parameters which have been [ssh.Marshal]ed by the client
// and provided as extra data in the session channel request. If the provided data is empty, nil params
// will be returned with a nil error.
func ParseSessionParams(data []byte) (*SessionParams, error) {
if len(data) == 0 {
return nil, nil
}

var params SessionParams
if err := ssh.Unmarshal(data, &params); err != nil {
return nil, trace.Wrap(err)
}

if params.JoinSessionID != "" {
if _, err := uuid.Parse(params.JoinSessionID); err != nil {
return nil, trace.Wrap(err, "failed to parse join session ID: %v", params.JoinSessionID)
}

switch params.JoinMode {
case types.SessionModeratorMode, types.SessionObserverMode, types.SessionPeerMode:
default:
return nil, trace.BadParameter("Unrecognized session participant mode: %q", params.JoinMode)
}
}

return &params, nil
}

// NewSession creates a new SSH session. This session is passed a tracing context so that
// spans may be correlated properly over the ssh connection.
func (c *Client) NewSession(ctx context.Context) (*Session, error) {
return c.NewSessionWithParams(ctx, nil)
}

// NewSessionWithParams creates a new SSH session with the given (optional) params. This session is
// passed a tracing context so that spans may be correlated properly over the ssh connection.
func (c *Client) NewSessionWithParams(ctx context.Context, sessionParams *SessionParams) (*Session, error) {
tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName)

ctx, span := tracer.Start(
Expand All @@ -195,9 +253,16 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) {
contexts: make(map[string][]context.Context),
}

// If we are connected to a Teleport server, send session params in the session request.
Comment thread
rosstimothy marked this conversation as resolved.
// If the server does not support session parameters in the extra data, it will be ignored.
var sessionData []byte
if sessionParams != nil && c.capability == tracingSupported {
sessionData = ssh.Marshal(sessionParams)
}

// open a session manually so we can take ownership of the
// requests chan
ch, reqs, err := wrapper.OpenChannel("session", nil)
ch, reqs, err := wrapper.OpenChannel("session", sessionData)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -218,7 +283,7 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) {
}

// RequestHandlerFn is an ssh request handler function.
type RequestHandlerFn func(ctx context.Context, ch *ssh.Request)
type RequestHandlerFn func(ctx context.Context, req *ssh.Request)

// HandleSessionRequest registers a handler for any incoming [ssh.Request] matching the
// provided type within a session. If the type is already being handled, an error is returned.
Expand Down
6 changes: 3 additions & 3 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,7 @@ func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) {
)
assert.NoError(t, err)

errCh <- nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil)
errCh <- nodeClient.RunInteractiveShell(ctx, "", "", nil)
assert.NoError(t, nodeClient.Close())
}()

Expand Down Expand Up @@ -7978,7 +7978,7 @@ func testModeratedSFTP(t *testing.T, suite *integrationTestSuite) {
isNilOrEOFErr(t, transferSess.Close())
})

err = transferSess.Setenv(ctx, string(telesftp.ModeratedSessionID), sessTracker.GetSessionID())
err = transferSess.Setenv(ctx, telesftp.EnvModeratedSessionID, sessTracker.GetSessionID())
require.NoError(t, err)

err = transferSess.RequestSubsystem(ctx, teleport.SFTPSubsystem)
Expand Down Expand Up @@ -8040,7 +8040,7 @@ func testModeratedSFTP(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, transferSess.Close())
})

err = transferSess.Setenv(ctx, string(telesftp.ModeratedSessionID), sessTracker.GetSessionID())
err = transferSess.Setenv(ctx, telesftp.EnvModeratedSessionID, sessTracker.GetSessionID())
require.NoError(t, err)

// Test that only operations needed to complete the download
Expand Down
20 changes: 3 additions & 17 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"errors"
"fmt"
"io"
"maps"
"net"
"net/url"
"os"
Expand Down Expand Up @@ -94,7 +93,6 @@ import (
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/session"
alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/sshutils/sftp"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -2285,7 +2283,7 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt
// Reuse the existing nodeClient we connected above.
return nodeClient.RunCommand(ctx, command)
}
return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil))
return trace.Wrap(nodeClient.RunInteractiveShell(ctx, "", "", nil))
}

func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []TargetNode, command []string) error {
Expand Down Expand Up @@ -2439,7 +2437,7 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan
}

// running shell with a given session means "join" it:
err = nc.RunInteractiveShell(ctx, mode, session, beforeStart)
err = nc.RunInteractiveShell(ctx, sessionID.String(), mode, beforeStart)
return trace.Wrap(err)
}

Expand Down Expand Up @@ -2685,7 +2683,7 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, clt *ClusterClient,
return trace.Wrap(err)
}

return trace.Wrap(nodeClient.TransferFiles(ctx, cfg))
return trace.Wrap(nodeClient.TransferFiles(ctx, cfg, "" /*moderatedSessionID*/))
}

// ListNodesWithFilters returns all nodes that match the filters in the current cluster
Expand Down Expand Up @@ -3113,18 +3111,6 @@ func (tc *TeleportClient) writeCommandResults(nodes []execResult) error {
return nil
}

func (tc *TeleportClient) newSessionEnv() map[string]string {
env := map[string]string{
teleport.SSHSessionWebProxyAddr: tc.WebProxyAddr,
}
if tc.SessionID != "" {
env[sshutils.SessionEnvVar] = tc.SessionID
}

maps.Copy(env, tc.ExtraEnvs)
return env
}

// getProxyLogin determines which SSH principal to use when connecting to proxy.
func (tc *TeleportClient) getProxySSHPrincipal() string {
if tc.ProxySSHPrincipal != "" {
Expand Down
56 changes: 32 additions & 24 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"net"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -388,39 +387,32 @@ func NewNodeClient(ctx context.Context, sshConfig *ssh.ClientConfig, conn net.Co
return nc, nil
}

// RunInteractiveShell creates an interactive shell on the node and copies stdin/stdout/stderr
// RunInteractiveShell creates or joins an interactive shell on the node and copies stdin/stdout/stderr
// to and from the node and local shell. This will block until the interactive shell on the node
// is terminated.
func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, beforeStart func(io.Writer)) error {
func (c *NodeClient) RunInteractiveShell(ctx context.Context, joinSessionID string, joinMode types.SessionParticipantMode, beforeStart func(io.Writer)) error {
ctx, span := c.Tracer.Start(
ctx,
"nodeClient/RunInteractiveShell",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
defer span.End()

env := c.TC.newSessionEnv()
env[teleport.EnvSSHJoinMode] = string(mode)
env[teleport.EnvSSHSessionReason] = c.TC.Config.Reason
env[teleport.EnvSSHSessionDisplayParticipantRequirements] = strconv.FormatBool(c.TC.Config.DisplayParticipantRequirements)
encoded, err := json.Marshal(&c.TC.Config.Invited)
if err != nil {
return trace.Wrap(err)
sessionParams := &tracessh.SessionParams{
WebProxyAddr: c.WebProxyAddr(),
Reason: c.TC.Config.Reason,
Invited: c.TC.Config.Invited,
DisplayParticipantRequirements: c.TC.Config.DisplayParticipantRequirements,
JoinSessionID: joinSessionID,
JoinMode: joinMode,
}
env[teleport.EnvSSHSessionInvited] = string(encoded)

// Overwrite "SSH_SESSION_WEBPROXY_ADDR" with the public addr reported by the proxy. Otherwise,
// this would be set to the localhost addr (tc.WebProxyAddr) used for Web UI client connections.
if c.ProxyPublicAddr != "" && c.TC.WebProxyAddr != c.ProxyPublicAddr {
env[teleport.SSHSessionWebProxyAddr] = c.ProxyPublicAddr
}

nodeSession, err := newSession(ctx, c, sessToJoin, env, c.TC.Stdin, c.TC.Stdout, c.TC.Stderr, !c.TC.DisableEscapeSequences)
nodeSession, err := newSession(ctx, c, sessionParams, c.TC.Stdin, c.TC.Stdout, c.TC.Stderr, !c.TC.DisableEscapeSequences)
if err != nil {
return trace.Wrap(err)
}

if err = nodeSession.runShell(ctx, mode, beforeStart, c.TC.OnShellCreated); err != nil {
if err = nodeSession.runShell(ctx, sessionParams, beforeStart, c.TC.OnShellCreated); err != nil {
var exitErr *ssh.ExitError
var exitMissingErr *ssh.ExitMissingError
switch err := trace.Unwrap(err); {
Expand Down Expand Up @@ -616,13 +608,19 @@ func (c *NodeClient) RunCommand(ctx context.Context, command []string, opts ...R
}
}

nodeSession, err := newSession(ctx, c, nil, c.TC.newSessionEnv(), c.TC.Stdin, stdout, stderr, !c.TC.DisableEscapeSequences)
sessionParams := &tracessh.SessionParams{
WebProxyAddr: c.WebProxyAddr(),
Reason: c.TC.Config.Reason,
Invited: c.TC.Config.Invited,
DisplayParticipantRequirements: c.TC.Config.DisplayParticipantRequirements,
}

nodeSession, err := newSession(ctx, c, sessionParams, c.TC.Stdin, stdout, stderr, !c.TC.DisableEscapeSequences)
if err != nil {
return trace.Wrap(err)
}
defer nodeSession.Close()

err = nodeSession.runCommand(ctx, types.SessionPeerMode, command, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand)
err = nodeSession.runCommand(ctx, sessionParams, command, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand)
if err != nil {
c.TC.SetExitStatus(getExitStatus(err))
}
Expand Down Expand Up @@ -745,15 +743,15 @@ func newClientConn(
}

// TransferFiles transfers files over SFTP.
func (c *NodeClient) TransferFiles(ctx context.Context, cfg *sftp.Config) error {
func (c *NodeClient) TransferFiles(ctx context.Context, cfg *sftp.Config, moderatedSessionID string) error {
ctx, span := c.Tracer.Start(
ctx,
"nodeClient/TransferFiles",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
defer span.End()

if err := cfg.TransferFiles(ctx, c.Client.Client); err != nil {
if err := cfg.TransferFiles(ctx, c.Client, moderatedSessionID); err != nil {
// TODO(tross): DELETE IN 19.0.0 - Older versions of Teleport would return
// a trace.BadParameter error when ~user path expansion was rejected, and
// reauthentication logic is attempted on BadParameter errors.
Expand Down Expand Up @@ -1029,3 +1027,13 @@ func GetPaginatedSessions(ctx context.Context, fromUTC, toUTC time.Time, pageSiz
}
return sessions, nil
}

// WebProxyAddr is the address of the proxy forwarding the SSH connection to the target server.
func (c *NodeClient) WebProxyAddr() string {
// Prioritize the public addr reported by the proxy. Otherwise, this would
// return the localhost addr used for Web UI client connections.
if c.ProxyPublicAddr != "" {
return c.ProxyPublicAddr
}
return c.TC.WebProxyAddr
}
12 changes: 2 additions & 10 deletions lib/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import (
"github.com/gravitational/teleport/api/client/proto"
tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh"
"github.com/gravitational/teleport/lib/observability/tracing"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/tlsca"
)

Expand All @@ -48,27 +47,20 @@ func TestHelperFunctions(t *testing.T) {

func TestNewSession(t *testing.T) {
nc := &NodeClient{
TC: &TeleportClient{},
Tracer: tracing.NoopProvider().Tracer("test"),
}

ctx := context.Background()
// defaults:
ses, err := newSession(ctx, nc, nil, nil, nil, nil, nil, true)
ses, err := newSession(ctx, nc, nil, nil, nil, nil, true)
require.NoError(t, err)
require.NotNil(t, ses)
require.Equal(t, nc, ses.NodeClient())
require.NotNil(t, ses.env)
require.Equal(t, os.Stderr, ses.terminal.Stderr())
require.Equal(t, os.Stdout, ses.terminal.Stdout())
require.Equal(t, os.Stdin, ses.terminal.Stdin())

// pass environ map
env := map[string]string{
sshutils.SessionEnvVar: "session-id",
}
ses, err = newSession(ctx, nc, nil, env, nil, nil, nil, true)
require.NoError(t, err)
require.NotNil(t, ses)
}

// TestProxyConnection verifies that client or server-side disconnect
Expand Down
Loading
Loading