diff --git a/api/utils/route.go b/api/utils/route.go index 3925519f0cec1..c45860c3d1b88 100644 --- a/api/utils/route.go +++ b/api/utils/route.go @@ -59,7 +59,7 @@ type RouteableServer interface { GetHostname() string GetAddr() string GetUseTunnel() bool - GetPublicAddr() string + GetPublicAddrs() []string } // RouteToServer checks if this route matcher wants to route to the supplied server. @@ -70,22 +70,36 @@ func (m *SSHRouteMatcher) RouteToServer(server RouteableServer) bool { return true } + hostnameMatch := m.routeToHostname(server.GetHostname()) + // if the server has connected over a reverse tunnel // then match only by hostname. if server.GetUseTunnel() { - return m.routeToHostname(server.GetHostname()) + return hostnameMatch } - for _, addr := range []string{server.GetAddr(), server.GetPublicAddr()} { + matchAddr := func(addr string) bool { ip, nodePort, err := net.SplitHostPort(addr) if err != nil { - continue + return false } - if (m.targetHost == ip || m.routeToHostname(server.GetHostname()) || slices.Contains(m.ips, ip)) && + if (m.targetHost == ip || hostnameMatch || slices.Contains(m.ips, ip)) && (m.targetPort == "" || m.targetPort == "0" || m.targetPort == nodePort) { return true } + + return false + } + + if matchAddr(server.GetAddr()) { + return true + } + + for _, addr := range server.GetPublicAddrs() { + if matchAddr(addr) { + return true + } } return false diff --git a/api/utils/route_test.go b/api/utils/route_test.go index f3add75ab52a5..eff3762f021e6 100644 --- a/api/utils/route_test.go +++ b/api/utils/route_test.go @@ -110,7 +110,7 @@ type mockRouteableServer struct { hostname string addr string useTunnel bool - publicAddr string + publicAddr []string } func (m mockRouteableServer) GetName() string { @@ -129,7 +129,7 @@ func (m mockRouteableServer) GetUseTunnel() bool { return m.useTunnel } -func (m mockRouteableServer) GetPublicAddr() string { +func (m mockRouteableServer) GetPublicAddrs() []string { return m.publicAddr } @@ -140,7 +140,7 @@ func TestRouteToServer(t *testing.T) { matchAddrServer := mockRouteableServer{ name: "test", addr: "example.com:1111", - publicAddr: "public.example.com:1111", + publicAddr: []string{"node:1234", "public.example.com:1111"}, } tests := []struct { @@ -156,7 +156,7 @@ func TestRouteToServer(t *testing.T) { name: "test", addr: "localhost", hostname: "example.com", - publicAddr: "example.com", + publicAddr: []string{"example.com"}, }, assert: require.False, }, @@ -167,7 +167,7 @@ func TestRouteToServer(t *testing.T) { name: testUUID, addr: "localhost", hostname: "example.com", - publicAddr: "example.com", + publicAddr: []string{"example.com"}, }, assert: require.True, }, @@ -178,7 +178,7 @@ func TestRouteToServer(t *testing.T) { name: testUUID, addr: "addr.example.com", hostname: "example.com", - publicAddr: "public.example.com", + publicAddr: []string{"public.example.com"}, useTunnel: true, }, assert: require.True, @@ -190,7 +190,7 @@ func TestRouteToServer(t *testing.T) { name: testUUID, addr: "example.com", hostname: "fake.example.com", - publicAddr: "example.com", + publicAddr: []string{"example.com"}, useTunnel: true, }, assert: require.False, @@ -214,7 +214,13 @@ func TestRouteToServer(t *testing.T) { assert: require.False, }, { - name: "match public addr", + name: "match first public addr", + matcher: NewSSHRouteMatcher("node", "1234", true), + server: matchAddrServer, + assert: require.True, + }, + { + name: "match second public addr", matcher: NewSSHRouteMatcher("public.example.com", "1111", true), server: matchAddrServer, assert: require.True, diff --git a/lib/services/watcher.go b/lib/services/watcher.go index 15eaafeae9999..5ba8716328fb3 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -1720,8 +1720,8 @@ type Node interface { GetTeleportVersion() string // GetAddr return server address GetAddr() string - // GetPublicAddr returns a public address where this server can be reached. - GetPublicAddr() string + // GetPublicAddrs returns all public addresses where this server can be reached. + GetPublicAddrs() []string // GetHostname returns server hostname GetHostname() string // GetNamespace returns server namespace @@ -1732,7 +1732,7 @@ type Node interface { GetRotation() types.Rotation // GetUseTunnel gets if a reverse tunnel should be used to connect to this node. GetUseTunnel() bool - // GetProxyID returns a list of proxy ids this server is connected to. + // GetProxyIDs returns a list of proxy ids this server is connected to. GetProxyIDs() []string }