Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 69 additions & 4 deletions api/observability/tracing/ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ import (
"sync"
"sync/atomic"

"github.com/google/uuid"
"github.com/gravitational/trace"
"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport/api/observability/tracing"
"github.com/gravitational/teleport/api/types"
)

// Client is a wrapper around ssh.Client that adds tracing support.
Expand Down Expand Up @@ -184,9 +186,65 @@ func (c *Client) OpenChannel(
}, reqs, err
}

// NewSession creates a new SSH session that is passed tracing context
// so that spans may be correlated properly over the ssh connection.
// SessionParams are session parameters supported by Teleport to provide additional
// session context or parameters to the server.
type SessionParams struct {
// WebProxyAddr is the address of the proxy forwarding the SSH connection to the target server.
WebProxyAddr string
// Reason is a reason attached to started sessions meant to describe their intent.
Reason string
// Invited is a list of people invited to a session.
Invited []string
// DisplayParticipantRequirements is set if debug information about participants requirements
// should be printed in moderated sessions.
DisplayParticipantRequirements bool
// JoinSessionID is the ID of a session to join.
JoinSessionID string
// JoinMode is the participant mode to join the session with.
// Required if JoinSessionID is set.
JoinMode types.SessionParticipantMode
// ModeratedSessionID is an optional parameter sent during SCP requests to specify which moderated session
// to check for valid FileTransferRequests.
ModeratedSessionID string
}

// ParseSessionParams unmarshals session parameters which have been [ssh.Marshal]ed by the client
// and provided as extra data in the session channel request. If the provided data is empty, nil params
// will be returned with a nil error.
func ParseSessionParams(data []byte) (*SessionParams, error) {
if len(data) == 0 {
return nil, nil
}

var params SessionParams
if err := ssh.Unmarshal(data, &params); err != nil {
return nil, trace.Wrap(err)
}

if params.JoinSessionID != "" {
if _, err := uuid.Parse(params.JoinSessionID); err != nil {
return nil, trace.Wrap(err, "failed to parse join session ID: %v", params.JoinSessionID)
}

switch params.JoinMode {
case types.SessionModeratorMode, types.SessionObserverMode, types.SessionPeerMode:
default:
return nil, trace.BadParameter("Unrecognized session participant mode: %q", params.JoinMode)
}
}

return &params, nil
}

// NewSession creates a new SSH session. This session is passed a tracing context so that
// spans may be correlated properly over the ssh connection.
func (c *Client) NewSession(ctx context.Context) (*Session, error) {
return c.NewSessionWithParams(ctx, nil)
}

// NewSessionWithParams creates a new SSH session with the given (optional) params. This session is
// passed a tracing context so that spans may be correlated properly over the ssh connection.
func (c *Client) NewSessionWithParams(ctx context.Context, sessionParams *SessionParams) (*Session, error) {
tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName)

ctx, span := tracer.Start(
Expand All @@ -213,9 +271,16 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) {
contexts: make(map[string][]context.Context),
}

// If we are connected to a Teleport server, send session params in the session request.
// If the server does not support session parameters in the extra data, it will be ignored.
var sessionData []byte
if sessionParams != nil && c.capability == tracingSupported {
sessionData = ssh.Marshal(sessionParams)
}

// open a session manually so we can take ownership of the
// requests chan
ch, reqs, err := wrapper.OpenChannel("session", nil)
ch, reqs, err := wrapper.OpenChannel("session", sessionData)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -236,7 +301,7 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) {
}

// RequestHandlerFn is an ssh request handler function.
type RequestHandlerFn func(ctx context.Context, ch *ssh.Request)
type RequestHandlerFn func(ctx context.Context, req *ssh.Request)

// HandleSessionRequest registers a handler for any incoming [ssh.Request] matching the
// provided type within a session. If the type is already being handled, an error is returned.
Expand Down
15 changes: 14 additions & 1 deletion constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -831,9 +831,22 @@ const (
CurrentSessionIDRequest = "current-session-id@goteleport.com"

// SessionIDQueryRequest is sent by clients to ask servers if they
// will generate their own session ID when a new session is created.
// will generate and share their own session ID when a new session
// is started (session and exec/shell channels accepted).
//
// TODO(Joerger): DELETE IN v20.0.0
// All v17+ servers set the session ID. v19+ clients stop checking.
SessionIDQueryRequest = "session-id-query@goteleport.com"

// SessionIDQueryRequestV2 is sent by clients to ask servers if they
// will generate and share their own session ID when a new session
// channel is accepted, rather than when the shell/exec channel is.
//
// TODO(Joerger): DELETE IN v21.0.0
// all v19+ servers set the session ID directly after accepting the session channel.
// clients should stop checking in v21, and servers should stop responding to the query in v22.
SessionIDQueryRequestV2 = "session-id-query-v2@goteleport.com"

// ForceTerminateRequest is an SSH request to forcefully terminate a session.
ForceTerminateRequest = "x-teleport-force-terminate"

Expand Down
52 changes: 31 additions & 21 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ func TestIntegrations(t *testing.T) {
t.Run("BPFSessionDifferentiation", suite.bind(testBPFSessionDifferentiation))
t.Run("ClientIdleConnection", suite.bind(testClientIdleConnection))
t.Run("CmdLabels", suite.bind(testCmdLabels))
t.Run("CreateAndUpdateTrustedClusters", suite.bind(testCreateAndUpdateTrustedClusters))
t.Run("ControlMaster", suite.bind(testControlMaster))
t.Run("X11Forwarding", suite.bind(testX11Forwarding))
t.Run("CustomReverseTunnel", suite.bind(testCustomReverseTunnel))
t.Run("DataTransfer", suite.bind(testDataTransfer))
t.Run("DifferentPinnedIP", suite.bind(testDifferentPinnedIP))
Expand Down Expand Up @@ -200,12 +200,12 @@ func TestIntegrations(t *testing.T) {
t.Run("TrustedClustersRoleMapChanges", suite.bind(testTrustedClustersRoleMapChanges))
t.Run("TrustedClustersWithLabels", suite.bind(testTrustedClustersWithLabels))
t.Run("TrustedClustersSkipNameValidation", suite.bind(testTrustedClustersSkipNameValidation))
t.Run("CreateAndUpdateTrustedClusters", suite.bind(testCreateAndUpdateTrustedClusters))
t.Run("TrustedTunnelNode", suite.bind(testTrustedTunnelNode))
t.Run("TwoClustersProxy", suite.bind(testTwoClustersProxy))
t.Run("TwoClustersTunnel", suite.bind(testTwoClustersTunnel))
t.Run("UUIDBasedProxy", suite.bind(testUUIDBasedProxy))
t.Run("WindowChange", suite.bind(testWindowChange))
t.Run("X11Forwarding", suite.bind(testX11Forwarding))
}

// testDifferentPinnedIP tests connection is rejected when source IP doesn't match the pinned one
Expand Down Expand Up @@ -522,6 +522,7 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) {
}
}

// Test streaming events and recording.
capturedStream, sessionEvents := streamSession(ctx, t, site, sessionID)

findByType := func(et string) apievents.AuditEvent {
Expand All @@ -532,39 +533,23 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) {
}
return nil
}
// helper that asserts that a session event is also included in the
// general audit log.
requireInAuditLog := func(t *testing.T, sessionEvent apievents.AuditEvent) {
t.Helper()
auditEvents, _, err := site.SearchEvents(ctx, events.SearchEventsRequest{
To: time.Now(),
EventTypes: []string{sessionEvent.GetType()},
})
require.NoError(t, err)
require.True(t, slices.ContainsFunc(auditEvents, func(ae apievents.AuditEvent) bool {
return ae.GetID() == sessionEvent.GetID()
}))
}

// there should always be 'session.start' event (and it must be first)
first := sessionEvents[0].(*apievents.SessionStart)
start := findByType(events.SessionStartEvent).(*apievents.SessionStart)
require.Equal(t, first, start)
require.Equal(t, sessionID, start.SessionID)
require.NotEmpty(t, start.TerminalSize)
requireInAuditLog(t, start)

// there should always be 'session.end' event
end := findByType(events.SessionEndEvent).(*apievents.SessionEnd)
require.NotNil(t, end)
require.Equal(t, sessionID, end.SessionID)
requireInAuditLog(t, end)

// there should always be 'session.leave' event
leave := findByType(events.SessionLeaveEvent).(*apievents.SessionLeave)
require.NotNil(t, leave)
require.Equal(t, sessionID, leave.SessionID)
requireInAuditLog(t, leave)

// all of them should have a proper time
for _, e := range sessionEvents {
Expand All @@ -575,6 +560,31 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) {
recorded := replaceNewlines(capturedStream)
require.Regexp(t, ".*exit.*", recorded)
require.Regexp(t, ".*echo hi.*", recorded)

sessionEvents, _, err = site.SearchEvents(ctx, events.SearchEventsRequest{
From: time.Time{},
To: time.Now(),
EventTypes: []string{
events.SessionStartEvent,
events.SessionLeaveEvent,
events.SessionEndEvent,
},
})
require.NoError(t, err)

// Check that the events found above in the session stream show up in the backend.
require.True(t, slices.ContainsFunc(sessionEvents, func(ae apievents.AuditEvent) bool {
return ae.GetID() == start.GetID()
}), "expected session events to contain session.start event")
require.True(t, slices.ContainsFunc(sessionEvents, func(ae apievents.AuditEvent) bool {
return ae.GetID() == end.GetID()
}), "expected session events to contain session.end event")
require.True(t, slices.ContainsFunc(sessionEvents, func(ae apievents.AuditEvent) bool {
return ae.GetID() == leave.GetID()
}), "expected session events to contain session.leave event")

// Ensure there are no duplicate events, e.g. from proxy recording mode.
require.Len(t, sessionEvents, 3, "%d unexpected duplicate events", len(sessionEvents)-4)
})
}
}
Expand Down Expand Up @@ -1120,7 +1130,7 @@ func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) {
)
assert.NoError(t, err)

errCh <- nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil)
errCh <- nodeClient.RunInteractiveShell(ctx, "", "", nil)
assert.NoError(t, nodeClient.Close())
}()

Expand Down Expand Up @@ -7996,7 +8006,7 @@ func testModeratedSFTP(t *testing.T, suite *integrationTestSuite) {
isNilOrEOFErr(t, transferSess.Close())
})

err = transferSess.Setenv(ctx, string(telesftp.ModeratedSessionID), sessTracker.GetSessionID())
err = transferSess.Setenv(ctx, string(telesftp.EnvModeratedSessionID), sessTracker.GetSessionID())
require.NoError(t, err)

err = transferSess.RequestSubsystem(ctx, teleport.SFTPSubsystem)
Expand Down Expand Up @@ -8058,7 +8068,7 @@ func testModeratedSFTP(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, transferSess.Close())
})

err = transferSess.Setenv(ctx, string(telesftp.ModeratedSessionID), sessTracker.GetSessionID())
err = transferSess.Setenv(ctx, string(telesftp.EnvModeratedSessionID), sessTracker.GetSessionID())
require.NoError(t, err)

// Test that only operations needed to complete the download
Expand Down
19 changes: 2 additions & 17 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ import (
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/session"
alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/sshutils/sftp"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -2303,7 +2302,7 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt
// Reuse the existing nodeClient we connected above.
return nodeClient.RunCommand(ctx, command)
}
return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil))
return trace.Wrap(nodeClient.RunInteractiveShell(ctx, "", "", nil))
}

func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []TargetNode, command []string) error {
Expand Down Expand Up @@ -2457,7 +2456,7 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan
}

// running shell with a given session means "join" it:
err = nc.RunInteractiveShell(ctx, mode, session, beforeStart)
err = nc.RunInteractiveShell(ctx, sessionID.String(), mode, beforeStart)
return trace.Wrap(err)
}

Expand Down Expand Up @@ -3155,20 +3154,6 @@ func (tc *TeleportClient) writeCommandResults(nodes []execResult) error {
return nil
}

func (tc *TeleportClient) newSessionEnv() map[string]string {
env := map[string]string{
teleport.SSHSessionWebProxyAddr: tc.WebProxyAddr,
}
if tc.SessionID != "" {
env[sshutils.SessionEnvVar] = tc.SessionID
}

for key, val := range tc.ExtraEnvs {
env[key] = val
}
return env
}

// getProxyLogin determines which SSH principal to use when connecting to proxy.
func (tc *TeleportClient) getProxySSHPrincipal() string {
if tc.ProxySSHPrincipal != "" {
Expand Down
Loading
Loading