Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 7 additions & 76 deletions api/observability/tracing/ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"fmt"
"net"
"sync"
"sync/atomic"

"github.com/gravitational/trace"
"go.opentelemetry.io/otel/attribute"
Expand All @@ -29,6 +28,7 @@ import (
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport/api/observability/tracing"
"github.com/gravitational/teleport/api/utils/sshutils"
)

// Client is a wrapper around ssh.Client that adds tracing support.
Expand Down Expand Up @@ -172,11 +172,11 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) {
// tracing context so that spans may be correlated properly over the ssh
// connection. The handling of channel requests from the underlying SSH
// session can be controlled with chanReqCallback.
func (c *Client) NewSessionWithRequestCallback(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) {
func (c *Client) NewSessionWithRequestCallback(ctx context.Context, chanReqCallback sshutils.ChannelRequestCallback) (*Session, error) {
return c.newSession(ctx, chanReqCallback)
}

func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) {
func (c *Client) newSession(ctx context.Context, chanReqCallback sshutils.ChannelRequestCallback) (*Session, error) {
tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName)

ctx, span := tracer.Start(
Expand Down Expand Up @@ -229,54 +229,18 @@ type clientWrapper struct {
contexts map[string][]context.Context
}

// ChannelRequestCallback allows the handling of channel requests
// to be customized. nil can be returned if you don't want
// golang/x/crypto/ssh to handle the request.
type ChannelRequestCallback func(req *ssh.Request) *ssh.Request

// NewSession opens a new Session for this client.
func (c *clientWrapper) NewSession(callback ChannelRequestCallback) (*Session, error) {
func (c *clientWrapper) NewSession(callback sshutils.ChannelRequestCallback) (*Session, error) {
// create a client that will defer to us when
// opening the "session" channel so that we
// can add an Envelope to the request
client := &ssh.Client{
Conn: c,
}

var session *ssh.Session
var err error
if callback != nil {
// open a session manually so we can take ownership of the
// requests chan
ch, originalReqs, openChannelErr := client.OpenChannel("session", nil)
if openChannelErr != nil {
return nil, trace.Wrap(openChannelErr)
}

// pass the channel requests to the provided callback and
// forward them to another chan so golang.org/x/crypto/ssh
// can handle Session exiting correctly
reqs := make(chan *ssh.Request, cap(originalReqs))
go func() {
defer close(reqs)

for req := range originalReqs {
if req := callback(req); req != nil {
reqs <- req
}
}
}()

session, err = newCryptoSSHSession(ch, reqs)
if err != nil {
_ = ch.Close()
return nil, trace.Wrap(err)
}
} else {
session, err = client.NewSession()
if err != nil {
return nil, trace.Wrap(err)
}
session, err := sshutils.NewSession(client, callback)
if err != nil {
return nil, trace.Wrap(err)
}

// wrap the session so all session requests on the channel
Expand All @@ -287,39 +251,6 @@ func (c *clientWrapper) NewSession(callback ChannelRequestCallback) (*Session, e
}, nil
}

// wrappedSSHConn allows an SSH session to be created while also allowing
// callers to take ownership of the SSH channel requests chan.
type wrappedSSHConn struct {
ssh.Conn

channelOpened atomic.Bool

ch ssh.Channel
reqs <-chan *ssh.Request
}

func (f *wrappedSSHConn) OpenChannel(_ string, _ []byte) (ssh.Channel, <-chan *ssh.Request, error) {
if !f.channelOpened.CompareAndSwap(false, true) {
panic("wrappedSSHConn OpenChannel called more than once")
}

return f.ch, f.reqs, nil
}

// newCryptoSSHSession allows callers to take ownership of the SSH
// channel requests chan and allow callers to handle SSH channel requests.
// golang.org/x/crypto/ssh.(Client).NewSession takes ownership of all
// SSH channel requests and doesn't allow the caller to view or reply
// to them, so this workaround is needed.
func newCryptoSSHSession(ch ssh.Channel, reqs <-chan *ssh.Request) (*ssh.Session, error) {
return (&ssh.Client{
Conn: &wrappedSSHConn{
ch: ch,
reqs: reqs,
},
}).NewSession()
}

// Dial initiates a connection to the addr from the remote host.
func (c *clientWrapper) Dial(n, addr string) (net.Conn, error) {
// create a client that will defer to us when
Expand Down
26 changes: 0 additions & 26 deletions api/observability/tracing/ssh/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,29 +242,3 @@ func TestSetEnvs(t *testing.T) {
default:
}
}

type mockSSHChannel struct {
ssh.Channel
}

func TestWrappedSSHConn(t *testing.T) {
sshCh := new(mockSSHChannel)
reqs := make(<-chan *ssh.Request)

// ensure that OpenChannel returns the same SSH channel and requests
// chan that wrappedSSHConn was given
wrappedConn := &wrappedSSHConn{
ch: sshCh,
reqs: reqs,
}
retCh, retReqs, err := wrappedConn.OpenChannel("", nil)
require.NoError(t, err)
require.Equal(t, sshCh, retCh)
require.Equal(t, reqs, retReqs)

// ensure the wrapped SSH conn will panic if OpenChannel is called
// twice
require.Panics(t, func() {
wrappedConn.OpenChannel("", nil)
})
}
106 changes: 106 additions & 0 deletions api/utils/sshutils/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright 2025 Gravitational, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sshutils

import (
"sync/atomic"

"github.com/gravitational/trace"
"golang.org/x/crypto/ssh"
)

// ChannelRequestCallback allows the handling of channel requests
// to be customized. nil can be returned if you don't want
// golang/x/crypto/ssh to handle the request.
type ChannelRequestCallback func(req *ssh.Request) *ssh.Request

// NewSession opens a new Session for this client.
func NewSession(client *ssh.Client, callback ChannelRequestCallback) (*ssh.Session, error) {
// No custom request handling needed. We can use the basic golang/x/crypto/ssh implementation.
if callback == nil {
session, err := client.NewSession()
if err != nil {
return nil, trace.Wrap(err)
}
return session, nil
}

// open a session manually so we can take ownership of the
// requests chan
ch, originalReqs, openChannelErr := client.OpenChannel("session", nil)
if openChannelErr != nil {
return nil, trace.Wrap(openChannelErr)
}

handleReqs := originalReqs
if callback != nil {
reqs := make(chan *ssh.Request, cap(originalReqs))
handleReqs = reqs

// pass the channel requests to the provided callback and
// forward them to another chan so golang.org/x/crypto/ssh
// can handle Session exiting correctly
go func() {
defer close(reqs)

for req := range originalReqs {
if req := callback(req); req != nil {
reqs <- req
}
}
}()
}

session, err := newCryptoSSHSession(ch, handleReqs)
if err != nil {
_ = ch.Close()
return nil, trace.Wrap(err)
}

return session, nil
}

// wrappedSSHConn allows an SSH session to be created while also allowing
// callers to take ownership of the SSH channel requests chan.
type wrappedSSHConn struct {
ssh.Conn

channelOpened atomic.Bool

ch ssh.Channel
reqs <-chan *ssh.Request
}

func (f *wrappedSSHConn) OpenChannel(_ string, _ []byte) (ssh.Channel, <-chan *ssh.Request, error) {
if !f.channelOpened.CompareAndSwap(false, true) {
panic("WrappedSSHConn.OpenChannel called more than once")
}

return f.ch, f.reqs, nil
}

// newCryptoSSHSession allows callers to take ownership of the SSH
// channel requests chan and allow callers to handle SSH channel requests.
// golang.org/x/crypto/ssh.(Client).NewSession takes ownership of all
// SSH channel requests and doesn't allow the caller to view or reply
// to them, so this workaround is needed.
func newCryptoSSHSession(ch ssh.Channel, reqs <-chan *ssh.Request) (*ssh.Session, error) {
return (&ssh.Client{
Conn: &wrappedSSHConn{
ch: ch,
reqs: reqs,
},
}).NewSession()
}
48 changes: 48 additions & 0 deletions api/utils/sshutils/session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2025 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sshutils

import (
"testing"

"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)

type mockSSHChannel struct {
ssh.Channel
}

func TestWrappedSSHConn(t *testing.T) {
sshCh := new(mockSSHChannel)
reqs := make(<-chan *ssh.Request)

// ensure that OpenChannel returns the same SSH channel and requests
// chan that wrappedSSHConn was given
wrappedConn := &wrappedSSHConn{
ch: sshCh,
reqs: reqs,
}
retCh, retReqs, err := wrappedConn.OpenChannel("", nil)
require.NoError(t, err)
require.Equal(t, sshCh, retCh)
require.Equal(t, reqs, retReqs)

// ensure the wrapped SSH conn will panic if OpenChannel is called
// twice
require.Panics(t, func() {
wrappedConn.OpenChannel("", nil)
})
}
3 changes: 2 additions & 1 deletion lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ import (
"github.com/gravitational/teleport/api/utils/keys/hardwarekey"
"github.com/gravitational/teleport/api/utils/pingconn"
"github.com/gravitational/teleport/api/utils/prompt"
apisshutils "github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/auth/touchid"
wancli "github.com/gravitational/teleport/lib/auth/webauthncli"
Expand Down Expand Up @@ -1262,7 +1263,7 @@ type TeleportClient struct {

// OnChannelRequest gets called when SSH channel requests are
// received. It's safe to keep it nil.
OnChannelRequest tracessh.ChannelRequestCallback
OnChannelRequest apisshutils.ChannelRequestCallback

// OnShellCreated gets called when the shell is created. It's
// safe to keep it nil.
Expand Down
3 changes: 2 additions & 1 deletion lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import (
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
Expand Down Expand Up @@ -391,7 +392,7 @@ func NewNodeClient(ctx context.Context, sshConfig *ssh.ClientConfig, conn net.Co
// RunInteractiveShell creates an interactive shell on the node and copies stdin/stdout/stderr
// to and from the node and local shell. This will block until the interactive shell on the node
// is terminated.
func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, chanReqCallback tracessh.ChannelRequestCallback, beforeStart func(io.Writer)) error {
func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, chanReqCallback sshutils.ChannelRequestCallback, beforeStart func(io.Writer)) error {
ctx, span := c.Tracer.Start(
ctx,
"nodeClient/RunInteractiveShell",
Expand Down
Loading
Loading