Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 95 additions & 78 deletions api/observability/tracing/ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ type Client struct {
*ssh.Client
opts []tracing.Option
capability tracingCapability

requestHandlersMu sync.Mutex
requestHandlers map[string]RequestHandlerFn
}

type tracingCapability int
Expand All @@ -56,9 +59,10 @@ const (
// of whether they should provide tracing context.
func NewClient(c ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request, opts ...tracing.Option) *Client {
clt := &Client{
Client: ssh.NewClient(c, chans, reqs),
opts: opts,
capability: tracingUnsupported,
Client: ssh.NewClient(c, chans, reqs),
opts: opts,
capability: tracingUnsupported,
requestHandlers: map[string]RequestHandlerFn{},
}

if bytes.HasPrefix(clt.ServerVersion(), []byte("SSH-2.0-Teleport")) {
Expand Down Expand Up @@ -89,7 +93,7 @@ func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, err
)
defer span.End()

// create the wrapper while the lock is held
// create a new wrapper to propagate tracing span context.
wrapper := &clientWrapper{
capability: c.capability,
Conn: c.Client.Conn,
Expand Down Expand Up @@ -165,18 +169,6 @@ func (c *Client) OpenChannel(
// NewSession creates a new SSH session that is passed tracing context
// so that spans may be correlated properly over the ssh connection.
func (c *Client) NewSession(ctx context.Context) (*Session, error) {
return c.newSession(ctx, nil)
}

// NewSessionWithRequestCallback creates a new SSH session that is passed
// tracing context so that spans may be correlated properly over the ssh
// connection. The handling of channel requests from the underlying SSH
// session can be controlled with chanReqCallback.
func (c *Client) NewSessionWithRequestCallback(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) {
return c.newSession(ctx, chanReqCallback)
}

func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) {
tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName)

ctx, span := tracer.Start(
Expand All @@ -194,7 +186,7 @@ func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestC
)
defer span.End()

// create the wrapper while the lock is still held
// create a new wrapper to propagate tracing span context.
wrapper := &clientWrapper{
capability: c.capability,
Conn: c.Client.Conn,
Expand All @@ -203,9 +195,92 @@ func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestC
contexts: make(map[string][]context.Context),
}

// get a session from the wrapper
session, err := wrapper.NewSession(chanReqCallback)
return session, trace.Wrap(err)
// open a session manually so we can take ownership of the
// requests chan
ch, reqs, err := wrapper.OpenChannel("session", nil)
if err != nil {
return nil, trace.Wrap(err)
}

unhandledReqs := c.serveSessionRequests(ctx, reqs)
session, err := newCryptoSSHSession(ch, unhandledReqs)
if err != nil {
_ = ch.Close()
return nil, trace.Wrap(err)
}

// wrap the session so all session requests on the channel
// can be traced
return &Session{
Session: session,
wrapper: wrapper,
}, nil
}

// RequestHandlerFn is an ssh request handler function.
type RequestHandlerFn func(ctx context.Context, ch *ssh.Request)

// HandleSessionRequest registers a handler for any incoming [ssh.Request] matching the
// provided type within a session. If the type is already being handled, an error is returned.
// All registered handlers are consumed by the next call to [Client.NewSession].
func (c *Client) HandleSessionRequest(ctx context.Context, requestType string, handlerFn RequestHandlerFn) error {
c.requestHandlersMu.Lock()
defer c.requestHandlersMu.Unlock()

if _, ok := c.requestHandlers[requestType]; ok {
return trace.AlreadyExists("ssh request type %q is already being handled for this session", requestType)
}

c.requestHandlers[requestType] = handlerFn
return nil
}

// serveSessionRequests from the remote side with registered handlers.
//
// This method consumes all registered handlers so that the next call to
// [Client.NewSession] will not reuse the same handlers.
func (c *Client) serveSessionRequests(ctx context.Context, in <-chan *ssh.Request) <-chan *ssh.Request {
c.requestHandlersMu.Lock()
requestHandlers := c.requestHandlers
c.requestHandlers = make(map[string]RequestHandlerFn)
c.requestHandlersMu.Unlock()

// Capture requests not handled by registered request handlers and
// pass them to the crypto [ssh.Session].
unhandledReqs := make(chan *ssh.Request, cap(in))

tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName)
go func() {
defer close(unhandledReqs)
for req := range in {
ctx, span := tracer.Start(
ctx,
fmt.Sprintf("ssh.HandleRequests/%s", req.Type),
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
oteltrace.WithAttributes(
append(
peerAttr(c.Conn.RemoteAddr()),
semconv.RPCServiceKey.String("ssh.Client"),
semconv.RPCMethodKey.String("HandleRequests"),
semconv.RPCSystemKey.String("ssh"),
)...,
),
)

handler, ok := requestHandlers[req.Type]
if ok {
handler(ctx, req)
} else {
// Pass on requests without a registered handler. These will be
// handled by the default x/crypto/ssh request handler.
unhandledReqs <- req
}

span.End()
}
}()

return unhandledReqs
}

// clientWrapper wraps the ssh.Conn for individual ssh.Client
Expand All @@ -229,64 +304,6 @@ type clientWrapper struct {
contexts map[string][]context.Context
}

// ChannelRequestCallback allows the handling of channel requests
// to be customized. nil can be returned if you don't want
// golang/x/crypto/ssh to handle the request.
type ChannelRequestCallback func(req *ssh.Request) *ssh.Request

// NewSession opens a new Session for this client.
func (c *clientWrapper) NewSession(callback ChannelRequestCallback) (*Session, error) {
// create a client that will defer to us when
// opening the "session" channel so that we
// can add an Envelope to the request
client := &ssh.Client{
Conn: c,
}

var session *ssh.Session
var err error
if callback != nil {
// open a session manually so we can take ownership of the
// requests chan
ch, originalReqs, openChannelErr := client.OpenChannel("session", nil)
if openChannelErr != nil {
return nil, trace.Wrap(openChannelErr)
}

// pass the channel requests to the provided callback and
// forward them to another chan so golang.org/x/crypto/ssh
// can handle Session exiting correctly
reqs := make(chan *ssh.Request, cap(originalReqs))
go func() {
defer close(reqs)

for req := range originalReqs {
if req := callback(req); req != nil {
reqs <- req
}
}
}()

session, err = newCryptoSSHSession(ch, reqs)
if err != nil {
_ = ch.Close()
return nil, trace.Wrap(err)
}
} else {
session, err = client.NewSession()
if err != nil {
return nil, trace.Wrap(err)
}
}

// wrap the session so all session requests on the channel
// can be traced
return &Session{
Session: session,
wrapper: c,
}, nil
}

// wrappedSSHConn allows an SSH session to be created while also allowing
// callers to take ownership of the SSH channel requests chan.
type wrappedSSHConn struct {
Expand Down
Loading
Loading