diff --git a/api/observability/tracing/ssh/client.go b/api/observability/tracing/ssh/client.go index dd0a20c1d3217..9b1a2f1149be5 100644 --- a/api/observability/tracing/ssh/client.go +++ b/api/observability/tracing/ssh/client.go @@ -36,6 +36,9 @@ type Client struct { *ssh.Client opts []tracing.Option capability tracingCapability + + requestHandlersMu sync.Mutex + requestHandlers map[string]RequestHandlerFn } type tracingCapability int @@ -56,9 +59,10 @@ const ( // of whether they should provide tracing context. func NewClient(c ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request, opts ...tracing.Option) *Client { clt := &Client{ - Client: ssh.NewClient(c, chans, reqs), - opts: opts, - capability: tracingUnsupported, + Client: ssh.NewClient(c, chans, reqs), + opts: opts, + capability: tracingUnsupported, + requestHandlers: map[string]RequestHandlerFn{}, } if bytes.HasPrefix(clt.ServerVersion(), []byte("SSH-2.0-Teleport")) { @@ -89,7 +93,7 @@ func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, err ) defer span.End() - // create the wrapper while the lock is held + // create a new wrapper to propagate tracing span context. wrapper := &clientWrapper{ capability: c.capability, Conn: c.Client.Conn, @@ -165,18 +169,6 @@ func (c *Client) OpenChannel( // NewSession creates a new SSH session that is passed tracing context // so that spans may be correlated properly over the ssh connection. func (c *Client) NewSession(ctx context.Context) (*Session, error) { - return c.newSession(ctx, nil) -} - -// NewSessionWithRequestCallback creates a new SSH session that is passed -// tracing context so that spans may be correlated properly over the ssh -// connection. The handling of channel requests from the underlying SSH -// session can be controlled with chanReqCallback. -func (c *Client) NewSessionWithRequestCallback(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) { - return c.newSession(ctx, chanReqCallback) -} - -func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) { tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName) ctx, span := tracer.Start( @@ -194,7 +186,7 @@ func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestC ) defer span.End() - // create the wrapper while the lock is still held + // create a new wrapper to propagate tracing span context. wrapper := &clientWrapper{ capability: c.capability, Conn: c.Client.Conn, @@ -203,9 +195,92 @@ func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestC contexts: make(map[string][]context.Context), } - // get a session from the wrapper - session, err := wrapper.NewSession(chanReqCallback) - return session, trace.Wrap(err) + // open a session manually so we can take ownership of the + // requests chan + ch, reqs, err := wrapper.OpenChannel("session", nil) + if err != nil { + return nil, trace.Wrap(err) + } + + unhandledReqs := c.serveSessionRequests(ctx, reqs) + session, err := newCryptoSSHSession(ch, unhandledReqs) + if err != nil { + _ = ch.Close() + return nil, trace.Wrap(err) + } + + // wrap the session so all session requests on the channel + // can be traced + return &Session{ + Session: session, + wrapper: wrapper, + }, nil +} + +// RequestHandlerFn is an ssh request handler function. +type RequestHandlerFn func(ctx context.Context, ch *ssh.Request) + +// HandleSessionRequest registers a handler for any incoming [ssh.Request] matching the +// provided type within a session. If the type is already being handled, an error is returned. +// All registered handlers are consumed by the next call to [Client.NewSession]. +func (c *Client) HandleSessionRequest(ctx context.Context, requestType string, handlerFn RequestHandlerFn) error { + c.requestHandlersMu.Lock() + defer c.requestHandlersMu.Unlock() + + if _, ok := c.requestHandlers[requestType]; ok { + return trace.AlreadyExists("ssh request type %q is already being handled for this session", requestType) + } + + c.requestHandlers[requestType] = handlerFn + return nil +} + +// serveSessionRequests from the remote side with registered handlers. +// +// This method consumes all registered handlers so that the next call to +// [Client.NewSession] will not reuse the same handlers. +func (c *Client) serveSessionRequests(ctx context.Context, in <-chan *ssh.Request) <-chan *ssh.Request { + c.requestHandlersMu.Lock() + requestHandlers := c.requestHandlers + c.requestHandlers = make(map[string]RequestHandlerFn) + c.requestHandlersMu.Unlock() + + // Capture requests not handled by registered request handlers and + // pass them to the crypto [ssh.Session]. + unhandledReqs := make(chan *ssh.Request, cap(in)) + + tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName) + go func() { + defer close(unhandledReqs) + for req := range in { + ctx, span := tracer.Start( + ctx, + fmt.Sprintf("ssh.HandleRequests/%s", req.Type), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + append( + peerAttr(c.Conn.RemoteAddr()), + semconv.RPCServiceKey.String("ssh.Client"), + semconv.RPCMethodKey.String("HandleRequests"), + semconv.RPCSystemKey.String("ssh"), + )..., + ), + ) + + handler, ok := requestHandlers[req.Type] + if ok { + handler(ctx, req) + } else { + // Pass on requests without a registered handler. These will be + // handled by the default x/crypto/ssh request handler. + unhandledReqs <- req + } + + span.End() + } + }() + + return unhandledReqs } // clientWrapper wraps the ssh.Conn for individual ssh.Client @@ -229,64 +304,6 @@ type clientWrapper struct { contexts map[string][]context.Context } -// ChannelRequestCallback allows the handling of channel requests -// to be customized. nil can be returned if you don't want -// golang/x/crypto/ssh to handle the request. -type ChannelRequestCallback func(req *ssh.Request) *ssh.Request - -// NewSession opens a new Session for this client. -func (c *clientWrapper) NewSession(callback ChannelRequestCallback) (*Session, error) { - // create a client that will defer to us when - // opening the "session" channel so that we - // can add an Envelope to the request - client := &ssh.Client{ - Conn: c, - } - - var session *ssh.Session - var err error - if callback != nil { - // open a session manually so we can take ownership of the - // requests chan - ch, originalReqs, openChannelErr := client.OpenChannel("session", nil) - if openChannelErr != nil { - return nil, trace.Wrap(openChannelErr) - } - - // pass the channel requests to the provided callback and - // forward them to another chan so golang.org/x/crypto/ssh - // can handle Session exiting correctly - reqs := make(chan *ssh.Request, cap(originalReqs)) - go func() { - defer close(reqs) - - for req := range originalReqs { - if req := callback(req); req != nil { - reqs <- req - } - } - }() - - session, err = newCryptoSSHSession(ch, reqs) - if err != nil { - _ = ch.Close() - return nil, trace.Wrap(err) - } - } else { - session, err = client.NewSession() - if err != nil { - return nil, trace.Wrap(err) - } - } - - // wrap the session so all session requests on the channel - // can be traced - return &Session{ - Session: session, - wrapper: c, - }, nil -} - // wrappedSSHConn allows an SSH session to be created while also allowing // callers to take ownership of the SSH channel requests chan. type wrappedSSHConn struct { diff --git a/api/observability/tracing/ssh/client_test.go b/api/observability/tracing/ssh/client_test.go index b59549f2181bc..c7e8b5d7b9c38 100644 --- a/api/observability/tracing/ssh/client_test.go +++ b/api/observability/tracing/ssh/client_test.go @@ -21,7 +21,7 @@ import ( "testing" "time" - "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" ) @@ -48,9 +48,8 @@ func TestIsTracingSupported(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) - errChan := make(chan error, 5) - srv := newServer(t, tt.expectedCapability, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { + srv := newServer(t, tt.srvVersion, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { go ssh.DiscardRequests(requests) for { @@ -64,29 +63,17 @@ func TestIsTracingSupported(t *testing.T) { } if err := ch.Reject(ssh.Prohibited, "no channels allowed"); err != nil { - errChan <- trace.Wrap(err, "rejecting channel") + assert.NoError(t, err, "rejecting channel") return } } } }) - if tt.srvVersion != "" { - srv.config.ServerVersion = tt.srvVersion - } - - go srv.Run(errChan) - conn, chans, reqs := srv.GetClient(t) client := NewClient(conn, chans, reqs) require.Equal(t, tt.expectedCapability, client.capability) - - select { - case err := <-errChan: - require.NoError(t, err) - default: - } }) } } @@ -104,14 +91,13 @@ func TestSetEnvs(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) - errChan := make(chan error, 5) expected := map[string]string{"a": "1", "b": "2", "c": "3"} // used to collect individual envs requests envReqC := make(chan envReqParams, 3) - srv := newServer(t, tracingSupported, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { + srv := newServer(t, tracingSupportedVersion, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { for { select { case <-ctx.Done(): @@ -123,7 +109,7 @@ func TestSetEnvs(t *testing.T) { case ch.ChannelType() == "session": ch, reqs, err := ch.Accept() if err != nil { - errChan <- trace.Wrap(err, "failed to accept session channel") + assert.NoError(t, err, "failed to accept session channel") return } @@ -178,7 +164,7 @@ func TestSetEnvs(t *testing.T) { _ = req.Reply(true, nil) default: // out of order or unexpected message _ = req.Reply(false, []byte(fmt.Sprintf("unexpected ssh request %s on iteration %d", req.Type, i))) - errChan <- err + assert.NoError(t, err) return } } @@ -186,7 +172,7 @@ func TestSetEnvs(t *testing.T) { }() default: if err := ch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unexpected channel %s", ch.ChannelType())); err != nil { - errChan <- err + assert.NoError(t, err) return } } @@ -194,8 +180,6 @@ func TestSetEnvs(t *testing.T) { } }) - go srv.Run(errChan) - // create a client and open a session conn, chans, reqs := srv.GetClient(t) client := NewClient(conn, chans, reqs) @@ -235,12 +219,6 @@ func TestSetEnvs(t *testing.T) { require.Equal(t, v, actual, "expected value %s for env %s, got %s", v, k, actual) } }) - - select { - case err := <-errChan: - require.NoError(t, err) - default: - } } type mockSSHChannel struct { @@ -268,3 +246,136 @@ func TestWrappedSSHConn(t *testing.T) { wrappedConn.OpenChannel("", nil) }) } + +// TestGlobalAndSessionRequests tests that the tracing client correctly handles global and session requests. +func TestGlobalAndSessionRequests(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // pingRequest is an example request type. Whether sent by the server or client in + // a global or session context, the receiver should give an ok as the reply. + pingRequest := "ping@goteleport.com" + + clientGlobalReply := make(chan bool, 1) + clientSessionReply := make(chan bool, 1) + + srv := newServer(t, tracingSupportedVersion, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { + // Send a ping request when the client connection is established. + ok, _, err := conn.SendRequest(pingRequest, true, nil) + assert.NoError(t, err, "server failed to send global ping request") + clientGlobalReply <- ok + + for { + select { + case <-ctx.Done(): + return + case req := <-requests: + switch req.Type { + case pingRequest: + err := req.Reply(true, nil) + assert.NoError(t, err, "server failed to reply to global ping request") + default: + err := req.Reply(false, nil) + assert.NoError(t, err, "server failed to reply to global %q request", req.Type) + } + case ch := <-channels: + switch { + case ch == nil: + return + case ch.ChannelType() == "session": + ch, reqs, err := ch.Accept() + if err != nil { + assert.NoError(t, err, "failed to accept session channel") + return + } + + go func() { + defer ch.Close() + for { + select { + case <-ctx.Done(): + return + case req := <-reqs: + switch req.Type { + case pingRequest: + err := req.Reply(true, nil) + assert.NoError(t, err, "server failed to reply to session ping request") + } + continue + } + } + }() + + // Send a ping request when the session is established. + ok, err := ch.SendRequest(pingRequest, true, nil) + assert.NoError(t, err, "server failed to send ping request") + clientSessionReply <- ok + default: + err := ch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unexpected channel %s", ch.ChannelType())) + assert.NoError(t, err) + } + } + } + }) + + conn, chans, reqs := srv.GetClient(t) + client := NewClient(conn, chans, reqs) + + // The client should reply false to any global request from the server, as we + // don't currently support a mechanism for the client to register global handlers. + select { + case reply := <-clientGlobalReply: + require.False(t, reply, "Expected the client to reply false to global ping request") + case <-time.After(10 * time.Second): + t.Fatalf("Failed to receive client global reply to ping request") + } + + // The server should reply true to a global ping request. + ok, _, err := client.SendRequest(ctx, pingRequest, true, nil) + require.True(t, ok, "Expected the server to reply true to global ping request") + require.NoError(t, err) + + // If the client isn't setup to handle session requests, it should reply false to them. + // The client should reply true to a session ping request. + _, err = client.NewSession(ctx) + require.NoError(t, err) + + select { + case reply := <-clientSessionReply: + require.False(t, reply, "Expected the client to reply false to session ping request") + case <-time.After(10 * time.Second): + t.Fatalf("Failed to receive client session reply to ping request") + } + + // The client should reply true to a session ping request. + err = client.HandleSessionRequest(ctx, pingRequest, func(ctx context.Context, req *ssh.Request) { + err := req.Reply(true, nil) + assert.NoError(t, err) + }) + require.NoError(t, err) + _, err = client.NewSession(ctx) + require.NoError(t, err) + + select { + case reply := <-clientSessionReply: + require.True(t, reply, "Expected the client to reply true to session ping request") + case <-time.After(10 * time.Second): + t.Fatalf("Failed to receive client session reply to ping request") + } + + // New Sessions do not reuse previously registered handlers. + session, err := client.NewSession(ctx) + require.NoError(t, err) + + select { + case reply := <-clientSessionReply: + require.False(t, reply, "Expected the client to reply false to session ping request") + case <-time.After(10 * time.Second): + t.Fatalf("Failed to receive client session reply to ping request") + } + + // The server should reply true to a session ping request. + ok, err = session.SendRequest(ctx, pingRequest, true, nil) + require.NoError(t, err) + require.True(t, ok, "Expected the server to reply true to session ping request") +} diff --git a/api/observability/tracing/ssh/ssh_test.go b/api/observability/tracing/ssh/ssh_test.go index e7618a3be15e3..7d43d7c8fe244 100644 --- a/api/observability/tracing/ssh/ssh_test.go +++ b/api/observability/tracing/ssh/ssh_test.go @@ -45,31 +45,6 @@ type server struct { hSigner ssh.Signer } -func (s *server) Run(errC chan error) { - for { - conn, err := s.listener.Accept() - if err != nil { - if !errors.Is(err, net.ErrClosed) { - errC <- err - } - return - } - - go func() { - sconn, chans, reqs, err := ssh.NewServerConn(conn, s.config) - if err != nil { - errC <- err - return - } - s.handler(sconn, chans, reqs) - }() - } -} - -func (s *server) Stop() error { - return s.listener.Close() -} - func generateSigner(t *testing.T) ssh.Signer { _, private, err := ed25519.GenerateKey(rand.Reader) require.NoError(t, err) @@ -91,18 +66,18 @@ func (s *server) GetClient(t *testing.T) (ssh.Conn, <-chan ssh.NewChannel, <-cha return sconn, nc, r } -func newServer(t *testing.T, tracingCap tracingCapability, handler func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request)) *server { +const ( + tracingSupportedVersion = "SSH-2.0-Teleport" + tracingUnsupportedVersion = "SSH-2.0" +) + +func newServer(t *testing.T, version string, handler func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request)) *server { listener, err := net.Listen("tcp", "localhost:0") require.NoError(t, err) cSigner := generateSigner(t) hSigner := generateSigner(t) - version := "SSH-2.0-Teleport" - if tracingCap != tracingSupported { - version = "SSH-2.0" - } - config := &ssh.ServerConfig{ NoClientAuth: true, ServerVersion: version, @@ -117,7 +92,33 @@ func newServer(t *testing.T, tracingCap tracingCapability, handler func(*ssh.Ser hSigner: hSigner, } - t.Cleanup(func() { require.NoError(t, srv.Stop()) }) + errC := make(chan error, 1) + go func() { + defer close(errC) + for { + conn, err := srv.listener.Accept() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + errC <- err + } + return + } + + go func() { + sconn, chans, reqs, err := ssh.NewServerConn(conn, srv.config) + if err != nil { + errC <- err + return + } + srv.handler(sconn, chans, reqs) + }() + } + }() + + t.Cleanup(func() { + require.NoError(t, srv.listener.Close()) + require.NoError(t, <-errC) + }) return srv } @@ -300,8 +301,12 @@ func TestClient(t *testing.T) { ctx: ctx, } - srv := newServer(t, tt.tracingSupported, handler.handle) - go srv.Run(errChan) + version := tracingSupportedVersion + if tt.tracingSupported != tracingSupported { + version = tracingUnsupportedVersion + } + + srv := newServer(t, version, handler.handle) tp := sdktrace.NewTracerProvider() conn, chans, reqs := srv.GetClient(t) diff --git a/integration/integration_test.go b/integration/integration_test.go index 334e6c0569e59..91a12518eab9f 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -1131,7 +1131,7 @@ func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) { ) assert.NoError(t, err) - errCh <- nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil, nil) + errCh <- nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil) assert.NoError(t, nodeClient.Close()) }() diff --git a/lib/client/api.go b/lib/client/api.go index ebcd78022b449..f9f5bc8b1ca3f 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -1260,10 +1260,6 @@ type TeleportClient struct { localAgent *LocalKeyAgent - // OnChannelRequest gets called when SSH channel requests are - // received. It's safe to keep it nil. - OnChannelRequest tracessh.ChannelRequestCallback - // OnShellCreated gets called when the shell is created. It's // safe to keep it nil. OnShellCreated ShellCreatedCallback @@ -2259,7 +2255,7 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt // Reuse the existing nodeClient we connected above. return nodeClient.RunCommand(ctx, command) } - return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, tc.OnChannelRequest, nil)) + return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil)) } func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []TargetNode, command []string) error { @@ -2413,7 +2409,7 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan } // running shell with a given session means "join" it: - err = nc.RunInteractiveShell(ctx, mode, session, tc.OnChannelRequest, beforeStart) + err = nc.RunInteractiveShell(ctx, mode, session, beforeStart) return trace.Wrap(err) } diff --git a/lib/client/client.go b/lib/client/client.go index 21dffe27825b1..c2ff704c1dac2 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -391,7 +391,7 @@ func NewNodeClient(ctx context.Context, sshConfig *ssh.ClientConfig, conn net.Co // RunInteractiveShell creates an interactive shell on the node and copies stdin/stdout/stderr // to and from the node and local shell. This will block until the interactive shell on the node // is terminated. -func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, chanReqCallback tracessh.ChannelRequestCallback, beforeStart func(io.Writer)) error { +func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, beforeStart func(io.Writer)) error { ctx, span := c.Tracer.Start( ctx, "nodeClient/RunInteractiveShell", @@ -420,7 +420,7 @@ func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.Session return trace.Wrap(err) } - if err = nodeSession.runShell(ctx, mode, c.TC.OnChannelRequest, beforeStart, c.TC.OnShellCreated); err != nil { + if err = nodeSession.runShell(ctx, mode, beforeStart, c.TC.OnShellCreated); err != nil { var exitErr *ssh.ExitError var exitMissingErr *ssh.ExitMissingError switch err := trace.Unwrap(err); { @@ -621,7 +621,8 @@ func (c *NodeClient) RunCommand(ctx context.Context, command []string, opts ...R return trace.Wrap(err) } defer nodeSession.Close() - err = nodeSession.runCommand(ctx, types.SessionPeerMode, command, c.TC.OnChannelRequest, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand) + + err = nodeSession.runCommand(ctx, types.SessionPeerMode, command, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand) if err != nil { c.TC.SetExitStatus(getExitStatus(err)) } diff --git a/lib/client/session.go b/lib/client/session.go index 53ad5eecc7c99..69cdbddb0c6ec 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -171,7 +171,7 @@ func (ns *NodeSession) NodeClient() *NodeClient { return ns.nodeClient } -func (ns *NodeSession) regularSession(ctx context.Context, chanReqCallback tracessh.ChannelRequestCallback, sessionCallback func(s *tracessh.Session) error) error { +func (ns *NodeSession) regularSession(ctx context.Context, sessionCallback func(s *tracessh.Session) error) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/regularSession", @@ -179,7 +179,7 @@ func (ns *NodeSession) regularSession(ctx context.Context, chanReqCallback trace ) defer span.End() - session, err := ns.createServerSession(ctx, chanReqCallback) + session, err := ns.createServerSession(ctx) if err != nil { return trace.Wrap(err) } @@ -191,7 +191,7 @@ func (ns *NodeSession) regularSession(ctx context.Context, chanReqCallback trace type interactiveCallback func(serverSession *tracessh.Session, shell io.ReadWriteCloser) error -func (ns *NodeSession) createServerSession(ctx context.Context, chanReqCallback tracessh.ChannelRequestCallback) (*tracessh.Session, error) { +func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Session, error) { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/createServerSession", @@ -199,7 +199,7 @@ func (ns *NodeSession) createServerSession(ctx context.Context, chanReqCallback ) defer span.End() - sess, err := ns.nodeClient.Client.NewSessionWithRequestCallback(ctx, chanReqCallback) + sess, err := ns.nodeClient.Client.NewSession(ctx) if err != nil { return nil, trace.Wrap(err) } @@ -267,7 +267,7 @@ func selectKeyAgent(ctx context.Context, tc *TeleportClient) sshagent.ClientGett // interactiveSession creates an interactive session on the remote node, executes // the given callback on it, and waits for the session to end -func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.SessionParticipantMode, chanReqCallback tracessh.ChannelRequestCallback, sessionCallback interactiveCallback) error { +func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.SessionParticipantMode, sessionCallback interactiveCallback) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/interactiveSession", @@ -281,7 +281,7 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio termType = teleport.SafeTerminalType } // create the server-side session: - sess, err := ns.createServerSession(ctx, chanReqCallback) + sess, err := ns.createServerSession(ctx) if err != nil { return trace.Wrap(err) } @@ -511,8 +511,8 @@ func (s *sessionWriter) Write(p []byte) (int, error) { } // runShell executes user's shell on the remote node under an interactive session -func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipantMode, chanReqCallback tracessh.ChannelRequestCallback, beforeStart func(io.Writer), shellCallback ShellCreatedCallback) error { - return ns.interactiveSession(ctx, mode, chanReqCallback, func(s *tracessh.Session, shell io.ReadWriteCloser) error { +func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipantMode, beforeStart func(io.Writer), shellCallback ShellCreatedCallback) error { + return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, shell io.ReadWriteCloser) error { w := &sessionWriter{ tshOut: ns.nodeClient.TC.Stdout, session: s, @@ -540,7 +540,7 @@ func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipa // runCommand executes a "exec" request either in interactive mode (with a // TTY attached) or non-intractive mode (no TTY). -func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionParticipantMode, cmd []string, chanReqCallback tracessh.ChannelRequestCallback, shellCallback ShellCreatedCallback, interactive bool) error { +func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionParticipantMode, cmd []string, shellCallback ShellCreatedCallback, interactive bool) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/runCommand", @@ -554,7 +554,7 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici // keyboard based signals will be propogated to the TTY on the server which is // where all signal handling will occur. if interactive { - return ns.interactiveSession(ctx, mode, chanReqCallback, func(s *tracessh.Session, term io.ReadWriteCloser) error { + return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, term io.ReadWriteCloser) error { err := s.Start(ctx, strings.Join(cmd, " ")) if err != nil { return trace.Wrap(err) @@ -581,7 +581,7 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici // Unfortunately at the moment the Go SSH library Teleport uses does not // support sending SSH_MSG_DISCONNECT. Instead we close the SSH channel and // SSH client, and try and exit as gracefully as possible. - return ns.regularSession(ctx, chanReqCallback, func(s *tracessh.Session) error { + return ns.regularSession(ctx, func(s *tracessh.Session) error { errCh := make(chan error, 1) go func() { errCh <- s.Run(ctx, strings.Join(cmd, " ")) diff --git a/lib/web/sessions.go b/lib/web/sessions.go index b5a88b27c3761..a43bed02c868b 100644 --- a/lib/web/sessions.go +++ b/lib/web/sessions.go @@ -1350,25 +1350,23 @@ func prepareToReceiveSessionID(ctx context.Context, log *slog.Logger, nc *client // send the session ID received from the server var gotSessionID atomic.Bool sessionIDFromServer := make(chan session.ID, 1) - nc.TC.OnChannelRequest = func(req *ssh.Request) *ssh.Request { - // ignore unrelated requests and handle only the first session - // ID request - if req.Type != teleport.CurrentSessionIDRequest || gotSessionID.Load() { - return req + + nc.Client.HandleSessionRequest(ctx, teleport.CurrentSessionIDRequest, func(ctx context.Context, req *ssh.Request) { + // only handle the first session ID request + if gotSessionID.Load() { + return } sid, err := session.ParseID(string(req.Payload)) if err != nil { log.WarnContext(ctx, "Unable to parse session ID", "error", err) - return nil + return } if gotSessionID.CompareAndSwap(false, true) { sessionIDFromServer <- *sid } - - return nil - } + }) // If the session is about to close and we haven't received a session // ID yet, ask if the server even supports sending one. Send the diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 3319061ac5e9b..d8b3427b0e5fc 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -867,7 +867,7 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor // Establish SSH connection to the server. This function will block until // either an error occurs or it completes successfully. - if err = nc.RunInteractiveShell(ctx, t.participantMode, t.tracker, nil, beforeStart); err != nil { + if err = nc.RunInteractiveShell(ctx, t.participantMode, t.tracker, beforeStart); err != nil { if !t.closedByClient.Load() { t.stream.WriteError(ctx, err.Error()) }