Skip to content
Merged
173 changes: 95 additions & 78 deletions api/observability/tracing/ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ type Client struct {
*ssh.Client
opts []tracing.Option
capability tracingCapability

requestHandlersMu sync.Mutex
requestHandlers map[string]RequestHandlerFn
}

type tracingCapability int
Expand All @@ -56,9 +59,10 @@ const (
// of whether they should provide tracing context.
func NewClient(c ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request, opts ...tracing.Option) *Client {
clt := &Client{
Client: ssh.NewClient(c, chans, reqs),
opts: opts,
capability: tracingUnsupported,
Client: ssh.NewClient(c, chans, reqs),
opts: opts,
capability: tracingUnsupported,
requestHandlers: map[string]RequestHandlerFn{},
}

if bytes.HasPrefix(clt.ServerVersion(), []byte("SSH-2.0-Teleport")) {
Expand Down Expand Up @@ -89,7 +93,7 @@ func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, err
)
defer span.End()

// create the wrapper while the lock is held
// create a new wrapper to propagate tracing span context.
wrapper := &clientWrapper{
capability: c.capability,
Conn: c.Client.Conn,
Expand Down Expand Up @@ -165,18 +169,6 @@ func (c *Client) OpenChannel(
// NewSession creates a new SSH session that is passed tracing context
// so that spans may be correlated properly over the ssh connection.
func (c *Client) NewSession(ctx context.Context) (*Session, error) {
return c.newSession(ctx, nil)
}

// NewSessionWithRequestCallback creates a new SSH session that is passed
// tracing context so that spans may be correlated properly over the ssh
// connection. The handling of channel requests from the underlying SSH
// session can be controlled with chanReqCallback.
func (c *Client) NewSessionWithRequestCallback(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) {
return c.newSession(ctx, chanReqCallback)
}

func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) {
tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName)

ctx, span := tracer.Start(
Expand All @@ -194,7 +186,7 @@ func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestC
)
defer span.End()

// create the wrapper while the lock is still held
// create a new wrapper to propagate tracing span context.
wrapper := &clientWrapper{
capability: c.capability,
Conn: c.Client.Conn,
Expand All @@ -203,9 +195,92 @@ func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestC
contexts: make(map[string][]context.Context),
}

// get a session from the wrapper
session, err := wrapper.NewSession(chanReqCallback)
return session, trace.Wrap(err)
// open a session manually so we can take ownership of the
// requests chan
ch, reqs, err := wrapper.OpenChannel("session", nil)
if err != nil {
return nil, trace.Wrap(err)
}

unhandledReqs := c.serveSessionRequests(ctx, reqs)
session, err := newCryptoSSHSession(ch, unhandledReqs)
if err != nil {
_ = ch.Close()
return nil, trace.Wrap(err)
}

// wrap the session so all session requests on the channel
// can be traced
return &Session{
Session: session,
wrapper: wrapper,
}, nil
}

// RequestHandlerFn is an ssh request handler function.
type RequestHandlerFn func(ctx context.Context, ch *ssh.Request)

// HandleSessionRequest registers a handler for any incoming [ssh.Request] matching the
// provided type within a session. If the type is already being handled, an error is returned.
// All registered handlers are consumed by the next call to [Client.NewSession].
func (c *Client) HandleSessionRequest(ctx context.Context, requestType string, handlerFn RequestHandlerFn) error {
c.requestHandlersMu.Lock()
defer c.requestHandlersMu.Unlock()

if _, ok := c.requestHandlers[requestType]; ok {
return trace.AlreadyExists("ssh request type %q is already being handled for this session", requestType)
}

c.requestHandlers[requestType] = handlerFn
return nil
}

// serveSessionRequests from the remote side with registered handlers.
//
// This method consumes all registered handlers so that the next call to
// [Client.NewSession] will not reuse the same handlers.
func (c *Client) serveSessionRequests(ctx context.Context, in <-chan *ssh.Request) <-chan *ssh.Request {
c.requestHandlersMu.Lock()
requestHandlers := c.requestHandlers
c.requestHandlers = make(map[string]RequestHandlerFn)
c.requestHandlersMu.Unlock()

// Capture requests not handled by registered request handlers and
// pass them to the crypto [ssh.Session].
unhandledReqs := make(chan *ssh.Request, cap(in))

tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName)
go func() {
defer close(unhandledReqs)
for req := range in {
ctx, span := tracer.Start(
ctx,
fmt.Sprintf("ssh.HandleRequests/%s", req.Type),
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
oteltrace.WithAttributes(
append(
peerAttr(c.Conn.RemoteAddr()),
semconv.RPCServiceKey.String("ssh.Client"),
semconv.RPCMethodKey.String("HandleRequests"),
semconv.RPCSystemKey.String("ssh"),
)...,
),
)

handler, ok := requestHandlers[req.Type]
if ok {
handler(ctx, req)
} else {
// Pass on requests without a registered handler. These will be
// handled by the default x/crypto/ssh request handler.
unhandledReqs <- req
}

span.End()
}
}()

return unhandledReqs
}

// clientWrapper wraps the ssh.Conn for individual ssh.Client
Expand All @@ -229,64 +304,6 @@ type clientWrapper struct {
contexts map[string][]context.Context
}

// ChannelRequestCallback allows the handling of channel requests
// to be customized. nil can be returned if you don't want
// golang/x/crypto/ssh to handle the request.
type ChannelRequestCallback func(req *ssh.Request) *ssh.Request

// NewSession opens a new Session for this client.
func (c *clientWrapper) NewSession(callback ChannelRequestCallback) (*Session, error) {
// create a client that will defer to us when
// opening the "session" channel so that we
// can add an Envelope to the request
client := &ssh.Client{
Conn: c,
}

var session *ssh.Session
var err error
if callback != nil {
// open a session manually so we can take ownership of the
// requests chan
ch, originalReqs, openChannelErr := client.OpenChannel("session", nil)
if openChannelErr != nil {
return nil, trace.Wrap(openChannelErr)
}

// pass the channel requests to the provided callback and
// forward them to another chan so golang.org/x/crypto/ssh
// can handle Session exiting correctly
reqs := make(chan *ssh.Request, cap(originalReqs))
go func() {
defer close(reqs)

for req := range originalReqs {
if req := callback(req); req != nil {
reqs <- req
}
}
}()

session, err = newCryptoSSHSession(ch, reqs)
if err != nil {
_ = ch.Close()
return nil, trace.Wrap(err)
}
} else {
session, err = client.NewSession()
if err != nil {
return nil, trace.Wrap(err)
}
}

// wrap the session so all session requests on the channel
// can be traced
return &Session{
Session: session,
wrapper: c,
}, nil
}

// wrappedSSHConn allows an SSH session to be created while also allowing
// callers to take ownership of the SSH channel requests chan.
type wrappedSSHConn struct {
Expand Down
Loading
Loading