diff --git a/lib/vnet/ssh_proxy.go b/lib/vnet/ssh_proxy.go new file mode 100644 index 0000000000000..35c1f44e17b13 --- /dev/null +++ b/lib/vnet/ssh_proxy.go @@ -0,0 +1,251 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package vnet + +import ( + "context" + "errors" + "log/slog" + "sync" + + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport/lib/utils" +) + +// sshConn represents an established SSH client or server connection together with its streams of incoming channels and global requests. +type sshConn struct { + conn ssh.Conn + chans <-chan ssh.NewChannel + reqs <-chan *ssh.Request +} + +// proxySSHConnection transparently proxies SSH channels and requests +// between 2 established SSH connections. serverConn represents an incoming SSH +// connection where this proxy acts as a server, clientConn represents an outgoing +// SSH connection where this proxy acts as a client. It blocks until every proxy task has exited. +func proxySSHConnection( + ctx context.Context, + serverConn sshConn, + clientConn sshConn, +) { + closeConnections := sync.OnceFunc(func() { + clientConn.conn.Close() + serverConn.conn.Close() + }) + // Close both connections if the context is canceled. + stop := context.AfterFunc(ctx, closeConnections) + defer stop() + + // Avoid leaking goroutines by tracking them with a waitgroup. + // If any task exits make sure to close both connections so that all other + // tasks can terminate. 
+ var wg sync.WaitGroup + runTask := func(task func()) { + wg.Add(1) + go func() { + task() + closeConnections() + wg.Done() + }() + } + + // Proxy channels initiated by either connection. + runTask(func() { + proxyChannels(ctx, serverConn.conn, clientConn.chans, closeConnections) + }) + runTask(func() { + proxyChannels(ctx, clientConn.conn, serverConn.chans, closeConnections) + }) + + // Proxy global requests in both directions. + runTask(func() { + proxyGlobalRequests(ctx, serverConn.conn, clientConn.reqs, closeConnections) + }) + runTask(func() { + proxyGlobalRequests(ctx, clientConn.conn, serverConn.reqs, closeConnections) + }) + + wg.Wait() +} + +func proxyChannels( + ctx context.Context, + targetConn ssh.Conn, + chans <-chan ssh.NewChannel, + closeConnections func(), +) { + // Proxy each SSH channel in its own goroutine, make sure they don't leak by + // tracking with a WaitGroup. + var wg sync.WaitGroup + for newChan := range chans { + wg.Add(1) + go func() { + defer wg.Done() + proxyChannel(ctx, targetConn, newChan, closeConnections) + }() + } + wg.Wait() +} + +func proxyChannel( + ctx context.Context, + targetConn ssh.Conn, + newChan ssh.NewChannel, + closeConnections func(), +) { + log := log.With("channel_type", newChan.ChannelType()) + log.DebugContext(ctx, "Proxying new SSH channel") + + // Try to open a corresponding channel on the target. + targetChan, targetChanRequests, err := targetConn.OpenChannel( + newChan.ChannelType(), newChan.ExtraData()) + if err != nil { + // Failed to open the channel on the target, newChan must be rejected. + var ( + rejectionReason ssh.RejectionReason + rejectionMessage string + openChannelErr *ssh.OpenChannelError + ) + if errors.As(err, &openChannelErr) { + // The target rejected the channel, this is totally expected. 
+ rejectionReason = openChannelErr.Reason + rejectionMessage = openChannelErr.Message + } else { + // We got an unexpected error type trying to open the channel on the + // target, this is fatal, log and kill the connection. + log.DebugContext(ctx, "Unexpected error opening SSH channel on target", + "error", err) + closeConnections() + // newChan still has to be rejected below to satisfy the crypto/ssh + // API, but the underlying network connection is already closed so + // we just leave the reason and message empty. + } + if err := newChan.Reject(rejectionReason, rejectionMessage); err != nil { + // Failed to reject the incoming channel, this is fatal, log and + // kill the connection. + log.DebugContext(ctx, "Failed to reject SSH channel request", + "error", err) + closeConnections() + } + return + } + + // Now that the target accepted the channel, accept the incoming channel + // request. + incomingChan, incomingChanRequests, err := newChan.Accept() + if err != nil { + // Failing to accept an incoming channel request that the target already + // accepted is fatal. Kill the connection, close the channel we + // just opened on the target and drain the request channel. + log.DebugContext(ctx, "Failed to accept SSH channel request already accepted by the target, killing the connection", + "error", err) + closeConnections() + go ssh.DiscardRequests(targetChanRequests) + _ = targetChan.Close() + return + } + + // Copy channel requests in both directions concurrently. If either fails or + // exits it will cancel the context so that utils.ProxyConn below will close + // both channels so the other goroutine can also exit. 
+ var wg sync.WaitGroup + wg.Add(2) + ctx, cancel := context.WithCancel(ctx) + go func() { + proxyChannelRequests(ctx, log, targetChan, incomingChanRequests, cancel) + cancel() + wg.Done() + }() + go func() { + proxyChannelRequests(ctx, log, incomingChan, targetChanRequests, cancel) + cancel() + wg.Done() + }() + + // ProxyConn copies channel data bidirectionally. If the context is + // canceled it will terminate, it always closes both channels before + // returning. + if err := utils.ProxyConn(ctx, incomingChan, targetChan); err != nil && + !utils.IsOKNetworkError(err) && !errors.Is(err, context.Canceled) { + log.DebugContext(ctx, "Unexpected error proxying channel data", "error", err) + } + + // Wait for both request-proxying goroutines to terminate. + wg.Wait() +} + +func proxyChannelRequests( + ctx context.Context, + log *slog.Logger, + targetChan ssh.Channel, + reqs <-chan *ssh.Request, + closeChannels func(), +) { + log = log.With("request_layer", "channel") + sendRequest := func(name string, wantReply bool, payload []byte) (bool, []byte, error) { + ok, err := targetChan.SendRequest(name, wantReply, payload) + // Replies to channel requests never have a payload. 
+ return ok, nil, err + } + proxyRequests(ctx, log, sendRequest, reqs, closeChannels) +} + +func proxyGlobalRequests( + ctx context.Context, + targetConn ssh.Conn, + reqs <-chan *ssh.Request, + closeConnections func(), +) { + log := log.With("request_layer", "global") + sendRequest := targetConn.SendRequest + proxyRequests(ctx, log, sendRequest, reqs, closeConnections) +} + +func proxyRequests( + ctx context.Context, + log *slog.Logger, + sendRequest func(name string, wantReply bool, payload []byte) (bool, []byte, error), + reqs <-chan *ssh.Request, + closeRequestSources func(), +) { + for req := range reqs { + log := log.With("request_type", req.Type) + log.DebugContext(ctx, "Proxying SSH request") + ok, reply, err := sendRequest(req.Type, req.WantReply, req.Payload) + if err != nil { + // We failed to send the request, the target must be dead. NOTE(review): log already carries request_type via log.With above, so repeating that attribute in the Debug calls in this loop is redundant. + log.DebugContext(ctx, "Failed to forward SSH request", "request_type", req.Type, "error", err) + // Close both connections or channels to clean up but we must + // continue handling requests on the chan until it is closed by + // crypto/ssh. + closeRequestSources() + _ = req.Reply(false, nil) + continue + } + if err := req.Reply(ok, reply); err != nil { + // A reply was expected and returned by the target but we failed to + // forward it back, the connection that initiated the request must + // be dead. + log.DebugContext(ctx, "Failed to reply to SSH request", "request_type", req.Type, "error", err) + // Close both connections or channels to clean up but we must + // continue handling requests on the chan until it is closed by + // crypto/ssh. + closeRequestSources() + } + } +} diff --git a/lib/vnet/ssh_proxy_test.go b/lib/vnet/ssh_proxy_test.go new file mode 100644 index 0000000000000..e9b5d4f4b6f87 --- /dev/null +++ b/lib/vnet/ssh_proxy_test.go @@ -0,0 +1,345 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. 
+// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package vnet + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "fmt" + "io" + "net" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + "google.golang.org/grpc/test/bufconn" + + "github.com/gravitational/teleport/lib/utils" +) + +// TestProxySSHConnection exercises [proxySSHConnection] to test that it +// transparently proxies SSH channels and requests. +// +// The test starts a target SSH server implemented in this file that handles +// channels and requests of type "echo", each handler echoes input back to +// output. +// +// The test also starts a proxy server that proxies incoming SSH connections to +// the target server using [proxySSHConnection]. +// +// The test asserts that connecting directly to the target server or to the +// proxy appears identical to the client. 
+func TestProxySSHConnection(t *testing.T) { + ctx := context.Background() + + proxyListener := bufconn.Listen(100) + serverListener := bufconn.Listen(100) + + targetServerConfig := sshServerConfig(t) + proxyServerConfig := sshServerConfig(t) + + proxyClientConfig := sshClientConfig(t) + + utils.RunTestBackgroundTask(ctx, t, &utils.TestBackgroundTask{ + Name: "target server", + Task: func(ctx context.Context) error { + return runTestSSHServer(serverListener, targetServerConfig) + }, + Terminate: func() error { + return trace.Wrap(serverListener.Close()) + }, + }) + utils.RunTestBackgroundTask(ctx, t, &utils.TestBackgroundTask{ + Name: "proxy server", + Task: func(ctx context.Context) error { + return runTestSSHProxy(ctx, + proxyListener, + proxyServerConfig, + serverListener, + proxyClientConfig, + ) + }, + Terminate: func() error { + return trace.Wrap(proxyListener.Close()) + }, + }) + + // Test with a direct connection to the test server and a proxied connection + // to make sure the behavior is indistinguishable. + t.Run("direct", func(t *testing.T) { + testSSHConnection(t, serverListener) + }) + for i := range 4 { + // Run the proxied test multiple times to make sure proxySSHConnection + // actually returns when the connection ends. The test proxy serves one connection at a time, so a hang would stall the next run. + t.Run(fmt.Sprintf("proxied_%d", i), func(t *testing.T) { + testSSHConnection(t, proxyListener) + }) + } +} + +func testSSHConnection(t *testing.T, dial dialer) { + tcpConn, err := dial.Dial() + require.NoError(t, err) + defer tcpConn.Close() + + clientConfig := sshClientConfig(t) + sshConn, chans, reqs, err := ssh.NewClientConn(tcpConn, "localhost", clientConfig) + require.NoError(t, err) + defer sshConn.Close() + go ssh.DiscardRequests(reqs) + go func() { + for newChan := range chans { + newChan.Reject(ssh.Prohibited, "test") + } + }() + + // Try sending some global requests. + t.Run("global requests", func(t *testing.T) { + testGlobalRequests(t, sshConn) + }) + + // Try opening a channel that the target server will reject. 
+ t.Run("unexpected channel", func(t *testing.T) { + _, _, err := sshConn.OpenChannel("unexpected", nil) + require.Error(t, err) + require.ErrorAs(t, err, new(*ssh.OpenChannelError)) + }) + + // Try opening a channel that echoes input data back to output, run + // it twice to make sure multiple channels can be opened. + // testEchoChannel will also send channel requests. + t.Run("echo channel 1", func(t *testing.T) { + testEchoChannel(t, sshConn) + }) + t.Run("echo channel 2", func(t *testing.T) { + testEchoChannel(t, sshConn) + }) +} + +func testGlobalRequests(t *testing.T, conn ssh.Conn) { + // Send an echo request. + msg := []byte("hello") + reply, replyPayload, err := conn.SendRequest("echo", true, msg) + assert.NoError(t, err) + assert.True(t, reply) + assert.Equal(t, msg, replyPayload) + + // Send an unexpected request type. + reply, replyPayload, err = conn.SendRequest("unexpected", true, msg) + assert.NoError(t, err) + assert.False(t, reply) + assert.Empty(t, replyPayload) +} + +func testEchoChannel(t *testing.T, conn ssh.Conn) { + ch, reqs, err := conn.OpenChannel("echo", nil) + require.NoError(t, err) + go ssh.DiscardRequests(reqs) + defer ch.Close() + + // Try sending a message over the SSH channel and asserting that it is + // echoed back. + msg := []byte("hello") + _, err = ch.Write(msg) + require.NoError(t, err) + var buf [16]byte + n, err := ch.Read(buf[:]) + require.NoError(t, err) + require.Equal(t, len(msg), n) + require.Equal(t, msg, buf[:n]) + + // Try sending a channel request that expects a reply. + reply, err := ch.SendRequest("echo", true, nil) + require.NoError(t, err) + require.True(t, reply) + + // The test server replies false to channel requests with type other than + // "echo". + reply, err = ch.SendRequest("unknown", true, nil) + require.NoError(t, err) + require.False(t, reply) +} + +type dialer interface { + Dial() (net.Conn, error) +} + +// runTestSSHProxy runs an SSH proxy server. 
The function under test +// [proxySSHConnection] requires an established client and server SSH connection +// and only handles proxying SSH requests and channels between them, this server +// is the glue that handles accepting connections from a listener, dialing the +// target server, completing SSH handshakes with each of them, and then finally +// calling [proxySSHConnection]. Connections are proxied one at a time. +func runTestSSHProxy( + ctx context.Context, + lis net.Listener, + serverCfg *ssh.ServerConfig, + serverDialer dialer, + clientCfg *ssh.ClientConfig, +) error { + for { + incomingConn, err := lis.Accept() + if err != nil { + if err.Error() == "closed" { + return nil + } + return trace.Wrap(err) + } + outgoingConn, err := serverDialer.Dial() + if err != nil { + incomingConn.Close() + if err.Error() == "closed" { + return nil + } + return trace.Wrap(err) + } + if err := runTestSSHProxyInstance( + ctx, + incomingConn, + serverCfg, + outgoingConn, + clientCfg, + ); err != nil { + return trace.Wrap(err) + } + } +} + +func runTestSSHProxyInstance( + ctx context.Context, + incomingConn net.Conn, + serverCfg *ssh.ServerConfig, + outgoingConn net.Conn, + clientCfg *ssh.ClientConfig, +) error { + defer incomingConn.Close() + defer outgoingConn.Close() + incomingSSHConn, incomingChans, incomingReqs, err := ssh.NewServerConn(incomingConn, serverCfg) + if err != nil { + return trace.Wrap(err) + } + defer incomingSSHConn.Close() + outgoingSSHConn, outgoingChans, outgoingReqs, err := ssh.NewClientConn(outgoingConn, "localhost", clientCfg) + if err != nil { + return trace.Wrap(err, "proxying SSH conn in test") + } + defer outgoingSSHConn.Close() + proxySSHConnection(ctx, sshConn{ + conn: incomingSSHConn, + chans: incomingChans, + reqs: incomingReqs, + }, sshConn{ + conn: outgoingSSHConn, + chans: outgoingChans, + reqs: outgoingReqs, + }) + return trace.Wrap(err) +} + +// runTestSSHServer runs a test SSH server that responds to new channel +// requests, global requests, and channel requests of type 
"echo". It handles +// each by replying with an "echo" of the input. Connections are served one at a time. +func runTestSSHServer(lis net.Listener, cfg *ssh.ServerConfig) error { + for { + tcpConn, err := lis.Accept() + if err != nil { + if err.Error() == "closed" { + return nil + } + return trace.Wrap(err) + } + if err := runTestSSHServerInstance(tcpConn, cfg); err != nil { + return trace.Wrap(err) + } + } +} + +func runTestSSHServerInstance(tcpConn net.Conn, cfg *ssh.ServerConfig) error { + sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, cfg) + if err != nil { + return trace.Wrap(err) + } + go func() { + handleEchoRequests(reqs) + sshConn.Close() + }() + handleEchoChannels(chans) + sshConn.Close() + return nil +} + +func handleEchoRequests(reqs <-chan *ssh.Request) { + for req := range reqs { + switch req.Type { + case "echo": + req.Reply(true, req.Payload) + default: + req.Reply(false, nil) + } + } +} + +func handleEchoChannels(chans <-chan ssh.NewChannel) { + for newChan := range chans { + switch newChan.ChannelType() { + case "echo": + go handleEchoChannel(newChan) + default: + newChan.Reject(ssh.UnknownChannelType, "unknown channel type") + } + } +} + +func handleEchoChannel(newChan ssh.NewChannel) { + ch, reqs, err := newChan.Accept() + if err != nil { + return + } + go handleEchoRequests(reqs) + io.Copy(ch, ch) +} + +func sshServerConfig(t *testing.T) *ssh.ServerConfig { + _, serverKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + hostSigner, err := ssh.NewSignerFromSigner(serverKey) + require.NoError(t, err) + serverConfig := &ssh.ServerConfig{ + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + // We're not testing SSH authentication here, just accept any user key. 
+ return nil, nil + }, + } + serverConfig.AddHostKey(hostSigner) + return serverConfig +} + +func sshClientConfig(t *testing.T) *ssh.ClientConfig { + _, clientKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + clientSigner, err := ssh.NewSignerFromSigner(clientKey) + require.NoError(t, err) + return &ssh.ClientConfig{ + Auth: []ssh.AuthMethod{ssh.PublicKeys(clientSigner)}, + // We're not testing SSH authentication here, just accept any host key. + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } +}