diff --git a/api/observability/tracing/ssh/client.go b/api/observability/tracing/ssh/client.go index dd0a20c1d3217..068999656d659 100644 --- a/api/observability/tracing/ssh/client.go +++ b/api/observability/tracing/ssh/client.go @@ -20,7 +20,6 @@ import ( "fmt" "net" "sync" - "sync/atomic" "github.com/gravitational/trace" "go.opentelemetry.io/otel/attribute" @@ -29,6 +28,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport/api/observability/tracing" + "github.com/gravitational/teleport/api/utils/sshutils" ) // Client is a wrapper around ssh.Client that adds tracing support. @@ -172,11 +172,11 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) { // tracing context so that spans may be correlated properly over the ssh // connection. The handling of channel requests from the underlying SSH // session can be controlled with chanReqCallback. -func (c *Client) NewSessionWithRequestCallback(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) { +func (c *Client) NewSessionWithRequestCallback(ctx context.Context, chanReqCallback sshutils.ChannelRequestCallback) (*Session, error) { return c.newSession(ctx, chanReqCallback) } -func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) { +func (c *Client) newSession(ctx context.Context, chanReqCallback sshutils.ChannelRequestCallback) (*Session, error) { tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName) ctx, span := tracer.Start( @@ -229,13 +229,8 @@ type clientWrapper struct { contexts map[string][]context.Context } -// ChannelRequestCallback allows the handling of channel requests -// to be customized. nil can be returned if you don't want -// golang/x/crypto/ssh to handle the request. -type ChannelRequestCallback func(req *ssh.Request) *ssh.Request - // NewSession opens a new Session for this client. -func (c *clientWrapper) NewSession(callback ChannelRequestCallback) (*Session, error) { +func (c *clientWrapper) NewSession(callback sshutils.ChannelRequestCallback) (*Session, error) { // create a client that will defer to us when // opening the "session" channel so that we // can add an Envelope to the request @@ -243,40 +238,9 @@ func (c *clientWrapper) NewSession(callback ChannelRequestCallback) (*Session, e Conn: c, } - var session *ssh.Session - var err error - if callback != nil { - // open a session manually so we can take ownership of the - // requests chan - ch, originalReqs, openChannelErr := client.OpenChannel("session", nil) - if openChannelErr != nil { - return nil, trace.Wrap(openChannelErr) - } - - // pass the channel requests to the provided callback and - // forward them to another chan so golang.org/x/crypto/ssh - // can handle Session exiting correctly - reqs := make(chan *ssh.Request, cap(originalReqs)) - go func() { - defer close(reqs) - - for req := range originalReqs { - if req := callback(req); req != nil { - reqs <- req - } - } - }() - - session, err = newCryptoSSHSession(ch, reqs) - if err != nil { - _ = ch.Close() - return nil, trace.Wrap(err) - } - } else { - session, err = client.NewSession() - if err != nil { - return nil, trace.Wrap(err) - } + session, err := sshutils.NewSession(client, callback) + if err != nil { + return nil, trace.Wrap(err) } // wrap the session so all session requests on the channel @@ -287,39 +251,6 @@ func (c *clientWrapper) NewSession(callback ChannelRequestCallback) (*Session, e }, nil } -// wrappedSSHConn allows an SSH session to be created while also allowing -// callers to take ownership of the SSH channel requests chan. -type wrappedSSHConn struct { - ssh.Conn - - channelOpened atomic.Bool - - ch ssh.Channel - reqs <-chan *ssh.Request -} - -func (f *wrappedSSHConn) OpenChannel(_ string, _ []byte) (ssh.Channel, <-chan *ssh.Request, error) { - if !f.channelOpened.CompareAndSwap(false, true) { - panic("wrappedSSHConn OpenChannel called more than once") - } - - return f.ch, f.reqs, nil -} - -// newCryptoSSHSession allows callers to take ownership of the SSH -// channel requests chan and allow callers to handle SSH channel requests. -// golang.org/x/crypto/ssh.(Client).NewSession takes ownership of all -// SSH channel requests and doesn't allow the caller to view or reply -// to them, so this workaround is needed. -func newCryptoSSHSession(ch ssh.Channel, reqs <-chan *ssh.Request) (*ssh.Session, error) { - return (&ssh.Client{ - Conn: &wrappedSSHConn{ - ch: ch, - reqs: reqs, - }, - }).NewSession() -} - // Dial initiates a connection to the addr from the remote host. func (c *clientWrapper) Dial(n, addr string) (net.Conn, error) { // create a client that will defer to us when diff --git a/api/observability/tracing/ssh/client_test.go b/api/observability/tracing/ssh/client_test.go index b59549f2181bc..3fb7a297369ac 100644 --- a/api/observability/tracing/ssh/client_test.go +++ b/api/observability/tracing/ssh/client_test.go @@ -242,29 +242,3 @@ func TestSetEnvs(t *testing.T) { default: } } - -type mockSSHChannel struct { - ssh.Channel -} - -func TestWrappedSSHConn(t *testing.T) { - sshCh := new(mockSSHChannel) - reqs := make(<-chan *ssh.Request) - - // ensure that OpenChannel returns the same SSH channel and requests - // chan that wrappedSSHConn was given - wrappedConn := &wrappedSSHConn{ - ch: sshCh, - reqs: reqs, - } - retCh, retReqs, err := wrappedConn.OpenChannel("", nil) - require.NoError(t, err) - require.Equal(t, sshCh, retCh) - require.Equal(t, reqs, retReqs) - - // ensure the wrapped SSH conn will panic if OpenChannel is called - // twice - require.Panics(t, func() { - wrappedConn.OpenChannel("", nil) - }) -} diff --git a/api/utils/sshutils/session.go b/api/utils/sshutils/session.go new file mode 100644 index 0000000000000..65332a535a784 --- /dev/null +++ b/api/utils/sshutils/session.go @@ -0,0 +1,106 @@ +// Copyright 2025 Gravitational, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sshutils + +import ( + "sync/atomic" + + "github.com/gravitational/trace" + "golang.org/x/crypto/ssh" +) + +// ChannelRequestCallback allows the handling of channel requests +// to be customized. nil can be returned if you don't want +// golang/x/crypto/ssh to handle the request. +type ChannelRequestCallback func(req *ssh.Request) *ssh.Request + +// NewSession opens a new Session for this client. +func NewSession(client *ssh.Client, callback ChannelRequestCallback) (*ssh.Session, error) { + // No custom request handling needed. We can use the basic golang/x/crypto/ssh implementation. + if callback == nil { + session, err := client.NewSession() + if err != nil { + return nil, trace.Wrap(err) + } + return session, nil + } + + // open a session manually so we can take ownership of the + // requests chan + ch, originalReqs, openChannelErr := client.OpenChannel("session", nil) + if openChannelErr != nil { + return nil, trace.Wrap(openChannelErr) + } + + handleReqs := originalReqs + if callback != nil { + reqs := make(chan *ssh.Request, cap(originalReqs)) + handleReqs = reqs + + // pass the channel requests to the provided callback and + // forward them to another chan so golang.org/x/crypto/ssh + // can handle Session exiting correctly + go func() { + defer close(reqs) + + for req := range originalReqs { + if req := callback(req); req != nil { + reqs <- req + } + } + }() + } + + session, err := newCryptoSSHSession(ch, handleReqs) + if err != nil { + _ = ch.Close() + return nil, trace.Wrap(err) + } + + return session, nil +} + +// wrappedSSHConn allows an SSH session to be created while also allowing +// callers to take ownership of the SSH channel requests chan. +type wrappedSSHConn struct { + ssh.Conn + + channelOpened atomic.Bool + + ch ssh.Channel + reqs <-chan *ssh.Request +} + +func (f *wrappedSSHConn) OpenChannel(_ string, _ []byte) (ssh.Channel, <-chan *ssh.Request, error) { + if !f.channelOpened.CompareAndSwap(false, true) { + panic("WrappedSSHConn.OpenChannel called more than once") + } + + return f.ch, f.reqs, nil +} + +// newCryptoSSHSession allows callers to take ownership of the SSH +// channel requests chan and allow callers to handle SSH channel requests. +// golang.org/x/crypto/ssh.(Client).NewSession takes ownership of all +// SSH channel requests and doesn't allow the caller to view or reply +// to them, so this workaround is needed. +func newCryptoSSHSession(ch ssh.Channel, reqs <-chan *ssh.Request) (*ssh.Session, error) { + return (&ssh.Client{ + Conn: &wrappedSSHConn{ + ch: ch, + reqs: reqs, + }, + }).NewSession() +} diff --git a/api/utils/sshutils/session_test.go b/api/utils/sshutils/session_test.go new file mode 100644 index 0000000000000..87e104a52641c --- /dev/null +++ b/api/utils/sshutils/session_test.go @@ -0,0 +1,48 @@ +// Copyright 2025 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sshutils + +import ( + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +type mockSSHChannel struct { + ssh.Channel +} + +func TestWrappedSSHConn(t *testing.T) { + sshCh := new(mockSSHChannel) + reqs := make(<-chan *ssh.Request) + + // ensure that OpenChannel returns the same SSH channel and requests + // chan that wrappedSSHConn was given + wrappedConn := &wrappedSSHConn{ + ch: sshCh, + reqs: reqs, + } + retCh, retReqs, err := wrappedConn.OpenChannel("", nil) + require.NoError(t, err) + require.Equal(t, sshCh, retCh) + require.Equal(t, reqs, retReqs) + + // ensure the wrapped SSH conn will panic if OpenChannel is called + // twice + require.Panics(t, func() { + wrappedConn.OpenChannel("", nil) + }) +} diff --git a/lib/client/api.go b/lib/client/api.go index d26bc9d0d9d15..6d02e31c75385 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -75,6 +75,7 @@ import ( "github.com/gravitational/teleport/api/utils/keys/hardwarekey" "github.com/gravitational/teleport/api/utils/pingconn" "github.com/gravitational/teleport/api/utils/prompt" + apisshutils "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/auth/touchid" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" @@ -1262,7 +1263,7 @@ type TeleportClient struct { // OnChannelRequest gets called when SSH channel requests are // received. It's safe to keep it nil. - OnChannelRequest tracessh.ChannelRequestCallback + OnChannelRequest apisshutils.ChannelRequestCallback // OnShellCreated gets called when the shell is created. It's // safe to keep it nil. diff --git a/lib/client/client.go b/lib/client/client.go index 21dffe27825b1..fef31ca8d4363 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -49,6 +49,7 @@ import ( apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/api/utils/retryutils" + "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" @@ -391,7 +392,7 @@ func NewNodeClient(ctx context.Context, sshConfig *ssh.ClientConfig, conn net.Co // RunInteractiveShell creates an interactive shell on the node and copies stdin/stdout/stderr // to and from the node and local shell. This will block until the interactive shell on the node // is terminated. -func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, chanReqCallback tracessh.ChannelRequestCallback, beforeStart func(io.Writer)) error { +func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, chanReqCallback sshutils.ChannelRequestCallback, beforeStart func(io.Writer)) error { ctx, span := c.Tracer.Start( ctx, "nodeClient/RunInteractiveShell", diff --git a/lib/client/session.go b/lib/client/session.go index 9edc9c5b7ab2b..edbe21fe9e1b7 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -41,6 +41,7 @@ import ( "github.com/gravitational/teleport" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" + apisshutils "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/client/escape" "github.com/gravitational/teleport/lib/client/terminal" "github.com/gravitational/teleport/lib/defaults" @@ -178,7 +179,7 @@ func (ns *NodeSession) NodeClient() *NodeClient { return ns.nodeClient } -func (ns *NodeSession) regularSession(ctx context.Context, chanReqCallback tracessh.ChannelRequestCallback, sessionCallback func(s *tracessh.Session) error) error { +func (ns *NodeSession) regularSession(ctx context.Context, chanReqCallback apisshutils.ChannelRequestCallback, sessionCallback func(s *tracessh.Session) error) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/regularSession", @@ -198,7 +199,7 @@ func (ns *NodeSession) regularSession(ctx context.Context, chanReqCallback trace type interactiveCallback func(serverSession *tracessh.Session, shell io.ReadWriteCloser) error -func (ns *NodeSession) createServerSession(ctx context.Context, chanReqCallback tracessh.ChannelRequestCallback) (*tracessh.Session, error) { +func (ns *NodeSession) createServerSession(ctx context.Context, chanReqCallback apisshutils.ChannelRequestCallback) (*tracessh.Session, error) { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/createServerSession", @@ -272,7 +273,7 @@ func selectKeyAgent(ctx context.Context, tc *TeleportClient) sshagent.ClientGett // interactiveSession creates an interactive session on the remote node, executes // the given callback on it, and waits for the session to end -func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.SessionParticipantMode, chanReqCallback tracessh.ChannelRequestCallback, sessionCallback interactiveCallback) error { +func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.SessionParticipantMode, chanReqCallback apisshutils.ChannelRequestCallback, sessionCallback interactiveCallback) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/interactiveSession", @@ -516,7 +517,7 @@ func (s *sessionWriter) Write(p []byte) (int, error) { } // runShell executes user's shell on the remote node under an interactive session -func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipantMode, chanReqCallback tracessh.ChannelRequestCallback, beforeStart func(io.Writer), shellCallback ShellCreatedCallback) error { +func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipantMode, chanReqCallback apisshutils.ChannelRequestCallback, beforeStart func(io.Writer), shellCallback ShellCreatedCallback) error { return ns.interactiveSession(ctx, mode, chanReqCallback, func(s *tracessh.Session, shell io.ReadWriteCloser) error { w := &sessionWriter{ tshOut: ns.nodeClient.TC.Stdout, @@ -545,7 +546,7 @@ func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipa // runCommand executes a "exec" request either in interactive mode (with a // TTY attached) or non-intractive mode (no TTY). -func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionParticipantMode, cmd []string, chanReqCallback tracessh.ChannelRequestCallback, shellCallback ShellCreatedCallback, interactive bool) error { +func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionParticipantMode, cmd []string, chanReqCallback apisshutils.ChannelRequestCallback, shellCallback ShellCreatedCallback, interactive bool) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/runCommand",