From 08d03fde4306cd421f8f9db46db885933bccf60c Mon Sep 17 00:00:00 2001 From: Andrew Burke Date: Thu, 11 Jan 2024 11:03:37 -0800 Subject: [PATCH] Route to server by public addr This change fixes a bug where tsh ssh could not dial a node with its public address. --- api/utils/route.go | 17 ++--- api/utils/route_test.go | 135 ++++++++++++++++++++++++++++++++++++++++ lib/services/watcher.go | 2 + 3 files changed, 147 insertions(+), 7 deletions(-) diff --git a/api/utils/route.go b/api/utils/route.go index 67100bdcbb01f..3925519f0cec1 100644 --- a/api/utils/route.go +++ b/api/utils/route.go @@ -59,6 +59,7 @@ type RouteableServer interface { GetHostname() string GetAddr() string GetUseTunnel() bool + GetPublicAddr() string } // RouteToServer checks if this route matcher wants to route to the supplied server. @@ -75,14 +76,16 @@ func (m *SSHRouteMatcher) RouteToServer(server RouteableServer) bool { return m.routeToHostname(server.GetHostname()) } - ip, nodePort, err := net.SplitHostPort(server.GetAddr()) - if err != nil { - return false - } + for _, addr := range []string{server.GetAddr(), server.GetPublicAddr()} { + ip, nodePort, err := net.SplitHostPort(addr) + if err != nil { + continue + } - if (m.targetHost == ip || m.routeToHostname(server.GetHostname()) || slices.Contains(m.ips, ip)) && - (m.targetPort == "" || m.targetPort == "0" || m.targetPort == nodePort) { - return true + if (m.targetHost == ip || m.routeToHostname(server.GetHostname()) || slices.Contains(m.ips, ip)) && + (m.targetPort == "" || m.targetPort == "0" || m.targetPort == nodePort) { + return true + } } return false diff --git a/api/utils/route_test.go b/api/utils/route_test.go index ef3118679ba2f..f3add75ab52a5 100644 --- a/api/utils/route_test.go +++ b/api/utils/route_test.go @@ -17,6 +17,7 @@ package utils import ( "testing" + "github.com/google/uuid" "github.com/stretchr/testify/require" ) @@ -103,3 +104,137 @@ func TestSSHRouteMatcherHostnameMatching(t *testing.T) { require.Equal(t, tt.match, matcher.routeToHostname(tt.principal), "desc=%q", tt.desc) } } + +type mockRouteableServer struct { + name string + hostname string + addr string + useTunnel bool + publicAddr string +} + +func (m mockRouteableServer) GetName() string { + return m.name +} + +func (m mockRouteableServer) GetHostname() string { + return m.hostname +} + +func (m mockRouteableServer) GetAddr() string { + return m.addr +} + +func (m mockRouteableServer) GetUseTunnel() bool { + return m.useTunnel +} + +func (m mockRouteableServer) GetPublicAddr() string { + return m.publicAddr +} + +func TestRouteToServer(t *testing.T) { + t.Parallel() + testUUID := uuid.NewString() + + matchAddrServer := mockRouteableServer{ + name: "test", + addr: "example.com:1111", + publicAddr: "public.example.com:1111", + } + + tests := []struct { + name string + matcher SSHRouteMatcher + server RouteableServer + assert require.BoolAssertionFunc + }{ + { + name: "no match", + matcher: NewSSHRouteMatcher(testUUID, "", true), + server: mockRouteableServer{ + name: "test", + addr: "localhost", + hostname: "example.com", + publicAddr: "example.com", + }, + assert: require.False, + }, + { + name: "match by server name", + matcher: NewSSHRouteMatcher(testUUID, "", true), + server: mockRouteableServer{ + name: testUUID, + addr: "localhost", + hostname: "example.com", + publicAddr: "example.com", + }, + assert: require.True, + }, + { + name: "match by hostname over tunnel", + matcher: NewSSHRouteMatcher("example.com", "", true), + server: mockRouteableServer{ + name: testUUID, + addr: "addr.example.com", + hostname: "example.com", + publicAddr: "public.example.com", + useTunnel: true, + }, + assert: require.True, + }, + { + name: "mismatch hostname over tunnel", + matcher: NewSSHRouteMatcher("example.com", "", true), + server: mockRouteableServer{ + name: testUUID, + addr: "example.com", + hostname: "fake.example.com", + publicAddr: "example.com", + useTunnel: true, + }, + assert: require.False, + }, + { + name: "match addr", + matcher: NewSSHRouteMatcher("example.com", "1111", true), + server: matchAddrServer, + assert: require.True, + }, + { + name: "match addr with empty port", + matcher: NewSSHRouteMatcher("example.com", "", true), + server: matchAddrServer, + assert: require.True, + }, + { + name: "mismatch addr with wrong port", + matcher: NewSSHRouteMatcher("example.com", "2222", true), + server: matchAddrServer, + assert: require.False, + }, + { + name: "match public addr", + matcher: NewSSHRouteMatcher("public.example.com", "1111", true), + server: matchAddrServer, + assert: require.True, + }, + { + name: "match public addr with empty port", + matcher: NewSSHRouteMatcher("public.example.com", "", true), + server: matchAddrServer, + assert: require.True, + }, + { + name: "mismatch public addr with wrong port", + matcher: NewSSHRouteMatcher("public.example.com", "2222", true), + server: matchAddrServer, + assert: require.False, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tc.assert(t, tc.matcher.RouteToServer(tc.server)) + }) + } +} diff --git a/lib/services/watcher.go b/lib/services/watcher.go index d9850b8f115ce..15eaafeae9999 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -1720,6 +1720,8 @@ type Node interface { GetTeleportVersion() string // GetAddr return server address GetAddr() string + // GetPublicAddr returns a public address where this server can be reached. + GetPublicAddr() string // GetHostname returns server hostname GetHostname() string // GetNamespace returns server namespace