Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 57 additions & 22 deletions integration/port_forwarding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package integration
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -77,6 +78,24 @@ func waitForSessionToBeEstablished(ctx context.Context, namespace string, site a
}
}

// testPingLocalServer checks whether or not an HTTP server is serving on
// localhost at the given port.
func testPingLocalServer(t *testing.T, port int, expectSuccess bool) {
addr := fmt.Sprintf("http://%s:%d/", "localhost", port)
r, err := http.Get(addr)

if r != nil {
r.Body.Close()
}

if expectSuccess {
require.NoError(t, err)
require.NotNil(t, r)
} else {
require.Error(t, err)
}
}

func testPortForwarding(t *testing.T, suite *integrationTestSuite) {
invalidOSLogin := uuid.NewString()[:12]
notFound := false
Expand Down Expand Up @@ -214,18 +233,33 @@ func testPortForwarding(t *testing.T, suite *integrationTestSuite) {

site := instance.GetSiteAPI(helpers.Site)

// ...and a running dummy server
remoteSvr := httptest.NewServer(http.HandlerFunc(
// ...and a pair of running dummy servers
handler := http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("Hello, World"))
}))
})
remoteListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
remoteSvr := httptest.NewUnstartedServer(handler)
remoteSvr.Listener = remoteListener
remoteSvr.Start()
defer remoteSvr.Close()

localListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
localSvr := httptest.NewUnstartedServer(handler)
localSvr.Listener = localListener
localSvr.Start()
defer localSvr.Close()

// ... and a client connection that was launched with port
// forwarding enabled to that dummy server
localPort := newPortValue()
remotePort, err := extractPort(remoteSvr)
// forwarding enabled to the dummy servers
localClientPort := newPortValue()
remoteServerPort, err := extractPort(remoteSvr)
require.NoError(t, err)
remoteClientPort := newPortValue()
localServerPort, err := extractPort(localSvr)
require.NoError(t, err)

nodeSSHPort := helpers.Port(t, instance.SSH)
Expand All @@ -239,9 +273,17 @@ func testPortForwarding(t *testing.T, suite *integrationTestSuite) {
cl.Config.LocalForwardPorts = []client.ForwardedPort{
{
SrcIP: "127.0.0.1",
SrcPort: localPort,
SrcPort: localClientPort,
DestHost: "localhost",
DestPort: remotePort,
DestPort: remoteServerPort,
},
}
cl.Config.RemoteForwardPorts = []client.ForwardedPort{
{
SrcIP: "localhost",
SrcPort: remoteClientPort,
DestHost: "127.0.0.1",
DestPort: localServerPort,
},
}
term := NewTerminal(250)
Expand All @@ -259,20 +301,13 @@ func testPortForwarding(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)

// When everything is *finally* set up, and I attempt to use the
// forwarded connection
localURL := fmt.Sprintf("http://%s:%d/", "localhost", localPort)
r, err := http.Get(localURL)

if r != nil {
r.Body.Close()
}

if tt.expectSuccess {
require.NoError(t, err)
require.NotNil(t, r)
} else {
require.Error(t, err)
}
// forwarded connections
t.Run("local forwarding", func(t *testing.T) {
testPingLocalServer(t, localClientPort, tt.expectSuccess)
})
t.Run("remote forwarding", func(t *testing.T) {
testPingLocalServer(t, remoteClientPort, tt.expectSuccess)
})
})
}
}
22 changes: 22 additions & 0 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ const (
ForwardAgentLocal
)

const remoteForwardUnsupportedMessage = "ssh: tcpip-forward request denied by peer"

var log = logrus.WithFields(logrus.Fields{
trace.Component: teleport.ComponentClient,
})
Expand Down Expand Up @@ -312,6 +314,10 @@ type Config struct {
// port forwarding (parameters to -D ssh flag).
DynamicForwardedPorts DynamicForwardedPorts

// RemoteForwardPorts are the list of ports the remote connection listens on
// for remote port forwarding (parameters to -R ssh flag).
RemoteForwardPorts ForwardedPorts

// HostKeyCallback will be called to check host keys of the remote
// node, if not specified will be using CheckHostSignature function
// that uses local cache to validate hosts
Expand Down Expand Up @@ -1958,6 +1964,22 @@ func (tc *TeleportClient) startPortForwarding(ctx context.Context, nodeClient *N
}
go nodeClient.dynamicListenAndForward(ctx, socket, addr)
}
for _, fp := range tc.Config.RemoteForwardPorts {
addr := net.JoinHostPort(fp.SrcIP, strconv.Itoa(fp.SrcPort))
socket, err := nodeClient.Client.Listen("tcp", addr)
if err != nil {
// We log the error here instead of returning it to be consistent with
// the other port forwarding methods, which don't stop the session
// if forwarding fails.
message := fmt.Sprintf("Failed to bind on remote host to %v: %v.", addr, err)
if strings.Contains(err.Error(), remoteForwardUnsupportedMessage) {
message = "Node does not support remote port forwarding (-R)."
}
log.Error(message)
} else {
go nodeClient.remoteListenAndForward(ctx, socket, net.JoinHostPort(fp.DestHost, strconv.Itoa(fp.DestPort)), addr)
}
}
return nil
}

Expand Down
25 changes: 25 additions & 0 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2020,6 +2020,31 @@ func (c *NodeClient) dynamicListenAndForward(ctx context.Context, ln net.Listene
log.WithError(ctx.Err()).Infof("Shutting down dynamic port forwarding.")
}

// remoteListenAndForward requests a listening socket and forwards all incoming
// commands to the local address through the SSH tunnel.
func (c *NodeClient) remoteListenAndForward(ctx context.Context, ln net.Listener, localAddr, remoteAddr string) {
defer ln.Close()
log := log.WithField("localAddr", localAddr).WithField("remoteAddr", remoteAddr)
log.Infof("Starting remote port forwarding")

for ctx.Err() == nil {
conn, err := acceptWithContext(ctx, ln)
if err != nil {
if ctx.Err() == nil {
log.WithError(err).Errorf("Remote port forwarding failed.")
}
continue
}

go func() {
if err := proxyConnection(ctx, conn, localAddr, &net.Dialer{}); err != nil {
log.WithError(err).Warnf("Failed to proxy connection")
}
}()
}
log.WithError(ctx.Err()).Infof("Shutting down remote port forwarding.")
}

// GetRemoteTerminalSize fetches the terminal size of a given SSH session.
func (c *NodeClient) GetRemoteTerminalSize(ctx context.Context, sessionID string) (*term.Winsize, error) {
ctx, span := c.Tracer.Start(
Expand Down
11 changes: 11 additions & 0 deletions tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ type CLIConf struct {
// DynamicForwardedPorts is port forwarding using SOCKS5. It is similar to
// "ssh -D 8080 example.com".
DynamicForwardedPorts []string
// -R flag for ssh. Remote port forwarding like 'ssh -R 80:localhost:80 -R 443:localhost:443'
RemoteForwardPorts []string
// ForwardAgent agent to target node. Equivalent of -A for OpenSSH.
ForwardAgent bool
// ProxyJump is an optional -J flag pointing to the list of jumphosts,
Expand Down Expand Up @@ -740,6 +742,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
ssh.Flag("forward-agent", "Forward agent to target node").Short('A').BoolVar(&cf.ForwardAgent)
ssh.Flag("forward", "Forward localhost connections to remote server").Short('L').StringsVar(&cf.LocalForwardPorts)
ssh.Flag("dynamic-forward", "Forward localhost connections to remote server using SOCKS5").Short('D').StringsVar(&cf.DynamicForwardedPorts)
ssh.Flag("remote-forward", "Forward remote connections to localhost").Short('R').StringsVar(&cf.RemoteForwardPorts)
ssh.Flag("local", "Execute command on localhost after connecting to SSH node").Default("false").BoolVar(&cf.LocalExec)
ssh.Flag("tty", "Allocate TTY").Short('t').BoolVar(&cf.Interactive)
ssh.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName)
Expand Down Expand Up @@ -3641,6 +3644,11 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err
return nil, trace.Wrap(err)
}

rPorts, err := client.ParsePortForwardSpec(cf.RemoteForwardPorts)
if err != nil {
return nil, trace.Wrap(err)
}

// 1: start with the defaults
c := client.MakeDefaultConfig()

Expand Down Expand Up @@ -3786,6 +3794,9 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err
if len(dPorts) > 0 {
c.DynamicForwardedPorts = dPorts
}
if len(rPorts) > 0 {
c.RemoteForwardPorts = rPorts
}
if cf.SiteName != "" {
c.SiteName = cf.SiteName
}
Expand Down