diff --git a/integration/port_forwarding_test.go b/integration/port_forwarding_test.go index 31bf61a0b510d..57445e7339732 100644 --- a/integration/port_forwarding_test.go +++ b/integration/port_forwarding_test.go @@ -21,6 +21,7 @@ package integration import ( "context" "fmt" + "net" "net/http" "net/http/httptest" "net/url" @@ -77,6 +78,24 @@ func waitForSessionToBeEstablished(ctx context.Context, namespace string, site a } } +// testPingLocalServer checks whether or not an HTTP server is serving on +// localhost at the given port. +func testPingLocalServer(t *testing.T, port int, expectSuccess bool) { + addr := fmt.Sprintf("http://%s:%d/", "localhost", port) + r, err := http.Get(addr) + + if r != nil { + r.Body.Close() + } + + if expectSuccess { + require.NoError(t, err) + require.NotNil(t, r) + } else { + require.Error(t, err) + } +} + func testPortForwarding(t *testing.T, suite *integrationTestSuite) { invalidOSLogin := uuid.NewString()[:12] notFound := false @@ -214,18 +233,33 @@ func testPortForwarding(t *testing.T, suite *integrationTestSuite) { site := instance.GetSiteAPI(helpers.Site) - // ...and a running dummy server - remoteSvr := httptest.NewServer(http.HandlerFunc( + // ...and a pair of running dummy servers + handler := http.HandlerFunc( func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("Hello, World")) - })) + }) + remoteListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + remoteSvr := httptest.NewUnstartedServer(handler) + remoteSvr.Listener = remoteListener + remoteSvr.Start() defer remoteSvr.Close() + localListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + localSvr := httptest.NewUnstartedServer(handler) + localSvr.Listener = localListener + localSvr.Start() + defer localSvr.Close() + // ... and a client connection that was launched with port - // forwarding enabled to that dummy server - localPort := newPortValue() - remotePort, err := extractPort(remoteSvr) + // forwarding enabled to the dummy servers + localClientPort := newPortValue() + remoteServerPort, err := extractPort(remoteSvr) + require.NoError(t, err) + remoteClientPort := newPortValue() + localServerPort, err := extractPort(localSvr) require.NoError(t, err) nodeSSHPort := helpers.Port(t, instance.SSH) @@ -239,9 +273,17 @@ func testPortForwarding(t *testing.T, suite *integrationTestSuite) { cl.Config.LocalForwardPorts = []client.ForwardedPort{ { SrcIP: "127.0.0.1", - SrcPort: localPort, + SrcPort: localClientPort, DestHost: "localhost", - DestPort: remotePort, + DestPort: remoteServerPort, + }, + } + cl.Config.RemoteForwardPorts = []client.ForwardedPort{ + { + SrcIP: "localhost", + SrcPort: remoteClientPort, + DestHost: "127.0.0.1", + DestPort: localServerPort, }, } term := NewTerminal(250) @@ -259,20 +301,13 @@ func testPortForwarding(t *testing.T, suite *integrationTestSuite) { require.NoError(t, err) // When everything is *finally* set up, and I attempt to use the - // forwarded connection - localURL := fmt.Sprintf("http://%s:%d/", "localhost", localPort) - r, err := http.Get(localURL) - - if r != nil { - r.Body.Close() - } - - if tt.expectSuccess { - require.NoError(t, err) - require.NotNil(t, r) - } else { - require.Error(t, err) - } + // forwarded connections + t.Run("local forwarding", func(t *testing.T) { + testPingLocalServer(t, localClientPort, tt.expectSuccess) + }) + t.Run("remote forwarding", func(t *testing.T) { + testPingLocalServer(t, remoteClientPort, tt.expectSuccess) + }) }) } } diff --git a/lib/client/api.go b/lib/client/api.go index e421e8d049885..f76f52f22f067 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -128,6 +128,8 @@ const ( ForwardAgentLocal ) +const remoteForwardUnsupportedMessage = "ssh: tcpip-forward request denied by peer" + var log = logrus.WithFields(logrus.Fields{ trace.Component: teleport.ComponentClient, }) @@ -312,6 +314,10 @@ type Config struct { // port forwarding (parameters to -D ssh flag). DynamicForwardedPorts DynamicForwardedPorts + // RemoteForwardPorts are the list of ports the remote connection listens on + // for remote port forwarding (parameters to -R ssh flag). + RemoteForwardPorts ForwardedPorts + // HostKeyCallback will be called to check host keys of the remote // node, if not specified will be using CheckHostSignature function // that uses local cache to validate hosts @@ -1958,6 +1964,22 @@ func (tc *TeleportClient) startPortForwarding(ctx context.Context, nodeClient *N } go nodeClient.dynamicListenAndForward(ctx, socket, addr) } + for _, fp := range tc.Config.RemoteForwardPorts { + addr := net.JoinHostPort(fp.SrcIP, strconv.Itoa(fp.SrcPort)) + socket, err := nodeClient.Client.Listen("tcp", addr) + if err != nil { + // We log the error here instead of returning it to be consistent with + // the other port forwarding methods, which don't stop the session + // if forwarding fails. + message := fmt.Sprintf("Failed to bind on remote host to %v: %v.", addr, err) + if strings.Contains(err.Error(), remoteForwardUnsupportedMessage) { + message = "Node does not support remote port forwarding (-R)." + } + log.Error(message) + } else { + go nodeClient.remoteListenAndForward(ctx, socket, net.JoinHostPort(fp.DestHost, strconv.Itoa(fp.DestPort)), addr) + } + } return nil } diff --git a/lib/client/client.go b/lib/client/client.go index c4aa5547f3edb..e7b066b4763ce 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -2020,6 +2020,31 @@ func (c *NodeClient) dynamicListenAndForward(ctx context.Context, ln net.Listene log.WithError(ctx.Err()).Infof("Shutting down dynamic port forwarding.") } +// remoteListenAndForward requests a listening socket and forwards all incoming +// commands to the local address through the SSH tunnel. +func (c *NodeClient) remoteListenAndForward(ctx context.Context, ln net.Listener, localAddr, remoteAddr string) { + defer ln.Close() + log := log.WithField("localAddr", localAddr).WithField("remoteAddr", remoteAddr) + log.Infof("Starting remote port forwarding") + + for ctx.Err() == nil { + conn, err := acceptWithContext(ctx, ln) + if err != nil { + if ctx.Err() == nil { + log.WithError(err).Errorf("Remote port forwarding failed.") + } + continue + } + + go func() { + if err := proxyConnection(ctx, conn, localAddr, &net.Dialer{}); err != nil { + log.WithError(err).Warnf("Failed to proxy connection") + } + }() + } + log.WithError(ctx.Err()).Infof("Shutting down remote port forwarding.") +} + // GetRemoteTerminalSize fetches the terminal size of a given SSH session. func (c *NodeClient) GetRemoteTerminalSize(ctx context.Context, sessionID string) (*term.Winsize, error) { ctx, span := c.Tracer.Start( diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index ebfaaac9ee8d7..1d0030471db8e 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -177,6 +177,8 @@ type CLIConf struct { // DynamicForwardedPorts is port forwarding using SOCKS5. It is similar to // "ssh -D 8080 example.com". DynamicForwardedPorts []string + // -R flag for ssh. Remote port forwarding like 'ssh -R 80:localhost:80 -R 443:localhost:443' + RemoteForwardPorts []string // ForwardAgent agent to target node. Equivalent of -A for OpenSSH. ForwardAgent bool // ProxyJump is an optional -J flag pointing to the list of jumphosts, @@ -740,6 +742,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { ssh.Flag("forward-agent", "Forward agent to target node").Short('A').BoolVar(&cf.ForwardAgent) ssh.Flag("forward", "Forward localhost connections to remote server").Short('L').StringsVar(&cf.LocalForwardPorts) ssh.Flag("dynamic-forward", "Forward localhost connections to remote server using SOCKS5").Short('D').StringsVar(&cf.DynamicForwardedPorts) + ssh.Flag("remote-forward", "Forward remote connections to localhost").Short('R').StringsVar(&cf.RemoteForwardPorts) ssh.Flag("local", "Execute command on localhost after connecting to SSH node").Default("false").BoolVar(&cf.LocalExec) ssh.Flag("tty", "Allocate TTY").Short('t').BoolVar(&cf.Interactive) ssh.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName) @@ -3641,6 +3644,11 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err return nil, trace.Wrap(err) } + rPorts, err := client.ParsePortForwardSpec(cf.RemoteForwardPorts) + if err != nil { + return nil, trace.Wrap(err) + } + // 1: start with the defaults c := client.MakeDefaultConfig() @@ -3786,6 +3794,9 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err if len(dPorts) > 0 { c.DynamicForwardedPorts = dPorts } + if len(rPorts) > 0 { + c.RemoteForwardPorts = rPorts + } if cf.SiteName != "" { c.SiteName = cf.SiteName }