diff --git a/api/utils/route.go b/api/utils/route.go index 7af4d96510171..03df193065978 100644 --- a/api/utils/route.go +++ b/api/utils/route.go @@ -15,11 +15,14 @@ package utils import ( + "context" + "errors" "net" "slices" "unicode/utf8" "github.com/google/uuid" + "github.com/gravitational/trace" "github.com/gravitational/teleport/api/utils/aws" ) @@ -29,26 +32,68 @@ import ( // to let other parts of teleport easily find matching servers when generating // error messages or building access requests. type SSHRouteMatcher struct { - targetHost string - targetPort string - caseInsensitive bool - ips []string - matchServerIDs bool + cfg SSHRouteMatcherConfig + ips []string + matchServerIDs bool +} + +// SSHRouteMatcherConfig configures an SSHRouteMatcher. +type SSHRouteMatcherConfig struct { + // Host is the target host that we want to route to. + Host string + // Port is an optional target port. If empty or zero + // it will match servers listening on any port. + Port string + // Resolver can be set to override default hostname lookup + // behavior (used in tests). + Resolver HostResolver + // CaseInsensitive enabled case insensitive routing when true. + CaseInsensitive bool +} + +// HostResolver provides an interface matching the net.Resolver.LookupHost method. Typically +// only used as a means of overriding dns resolution behavior in tests. +type HostResolver interface { + // LookupHost performs a hostname lookup. See net.Resolver.LookupHost for details. + LookupHost(ctx context.Context, host string) (addrs []string, err error) +} + +var errEmptyHost = errors.New("cannot route to empty target host") + +// NewSSHRouteMatcherFromConfig sets up an ssh route matcher from the supplied configuration. +func NewSSHRouteMatcherFromConfig(cfg SSHRouteMatcherConfig) (*SSHRouteMatcher, error) { + if cfg.Host == "" { + return nil, trace.Wrap(errEmptyHost) + } + + if cfg.Resolver == nil { + cfg.Resolver = net.DefaultResolver + } + + m := newSSHRouteMatcher(cfg) + return &m, nil } // NewSSHRouteMatcher builds a new matcher for ssh routing decisions. func NewSSHRouteMatcher(host, port string, caseInsensitive bool) SSHRouteMatcher { - _, err := uuid.Parse(host) - dialByID := err == nil || aws.IsEC2NodeID(host) + return newSSHRouteMatcher(SSHRouteMatcherConfig{ + Host: host, + Port: port, + CaseInsensitive: caseInsensitive, + Resolver: net.DefaultResolver, + }) +} + +func newSSHRouteMatcher(cfg SSHRouteMatcherConfig) SSHRouteMatcher { + _, err := uuid.Parse(cfg.Host) + dialByID := err == nil || aws.IsEC2NodeID(cfg.Host) - ips, _ := net.LookupHost(host) + ips, _ := cfg.Resolver.LookupHost(context.Background(), cfg.Host) return SSHRouteMatcher{ - targetHost: host, - targetPort: port, - caseInsensitive: caseInsensitive, - ips: ips, - matchServerIDs: dialByID, + cfg: cfg, + ips: ips, + matchServerIDs: dialByID, } } @@ -64,10 +109,23 @@ type RouteableServer interface { // RouteToServer checks if this route matcher wants to route to the supplied server. func (m *SSHRouteMatcher) RouteToServer(server RouteableServer) bool { + return m.RouteToServerScore(server) > 0 +} + +const ( + notMatch = 0 + indirectMatch = 1 + directMatch = 2 +) + +// RouteToServerScore checks wether this route matcher wants to route to the supplied server +// and represents the result of that check as an integer score indicating the strength of the +// match. Positive scores indicate a match, higher being stronger. +func (m *SSHRouteMatcher) RouteToServerScore(server RouteableServer) (score int) { // if host is a UUID or EC2 ID match only // by server name and treat matches as unambiguous - if m.matchServerIDs && server.GetName() == m.targetHost { - return true + if m.matchServerIDs && server.GetName() == m.cfg.Host { + return directMatch } hostnameMatch := m.routeToHostname(server.GetHostname()) @@ -75,34 +133,46 @@ func (m *SSHRouteMatcher) RouteToServer(server RouteableServer) bool { // if the server has connected over a reverse tunnel // then match only by hostname. if server.GetUseTunnel() { - return hostnameMatch + if hostnameMatch { + return directMatch + } + return notMatch } - matchAddr := func(addr string) bool { + matchAddr := func(addr string) int { ip, nodePort, err := net.SplitHostPort(addr) if err != nil { - return false + return notMatch } - if (m.targetHost == ip || hostnameMatch || slices.Contains(m.ips, ip)) && - (m.targetPort == "" || m.targetPort == "0" || m.targetPort == nodePort) { - return true + if m.cfg.Port != "" && m.cfg.Port != "0" && m.cfg.Port != nodePort { + // if port is well-specified and does not match, don't bother + // continuing the check. + return notMatch } - return false - } + if hostnameMatch || m.cfg.Host == ip { + // server presents a hostname or addr that exactly matches + // our target. + return directMatch + } + + if slices.Contains(m.ips, ip) { + // server presents an addr that indirectly matches our target + // due to dns resolution. + return indirectMatch + } - if matchAddr(server.GetAddr()) { - return true + return notMatch } + score = matchAddr(server.GetAddr()) + for _, addr := range server.GetPublicAddrs() { - if matchAddr(addr) { - return true - } + score = max(score, matchAddr(addr)) } - return false + return score } // routeToHostname helps us perform a special kind of case-insensitive comparison. SSH certs do not generally @@ -112,21 +182,21 @@ func (m *SSHRouteMatcher) RouteToServer(server RouteableServer) bool { // the literal hostname and a lowered version of the hostname, meaning that it is sane to route a request for host 'foo' to // host 'Foo', but it is not sane to route a request for host 'Bar' to host 'bar'. func (m *SSHRouteMatcher) routeToHostname(principal string) bool { - if !m.caseInsensitive { - return m.targetHost == principal + if !m.cfg.CaseInsensitive { + return m.cfg.Host == principal } - if len(m.targetHost) != len(principal) { + if len(m.cfg.Host) != len(principal) { return false } // the below is modeled off of the fast ASCII path of strings.EqualFold - for i := 0; i < len(principal) && i < len(m.targetHost); i++ { + for i := 0; i < len(principal) && i < len(m.cfg.Host); i++ { pr := principal[i] - hr := m.targetHost[i] + hr := m.cfg.Host[i] if pr|hr >= utf8.RuneSelf { // not pure-ascii, fallback to literal comparison - return m.targetHost == principal + return m.cfg.Host == principal } // Easy case. @@ -146,7 +216,7 @@ func (m *SSHRouteMatcher) routeToHostname(principal string) bool { // IsEmpty checks if this route matcher has had a hostname set. func (m *SSHRouteMatcher) IsEmpty() bool { - return m.targetHost == "" + return m.cfg.Host == "" } // MatchesServerIDs checks if this matcher wants to perform server ID matching. diff --git a/api/utils/route_test.go b/api/utils/route_test.go index eff3762f021e6..5de05efa673d8 100644 --- a/api/utils/route_test.go +++ b/api/utils/route_test.go @@ -15,6 +15,7 @@ package utils import ( + "context" "testing" "github.com/google/uuid" @@ -244,3 +245,81 @@ func TestRouteToServer(t *testing.T) { }) } } + +type mockHostResolver struct { + ips []string +} + +func (r mockHostResolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) { + return r.ips, nil +} + +// TestSSHRouteMatcherScoring verifies the expected scoring behavior of SSHRouteMatcher. +func TestSSHRouteMatcherScoring(t *testing.T) { + t.Parallel() + + // set up matcher with mock resolver in order to control ips + matcher, err := NewSSHRouteMatcherFromConfig(SSHRouteMatcherConfig{ + Host: "foo.example.com", + Resolver: mockHostResolver{ + ips: []string{ + "1.2.3.4", + "4.5.6.7", + }, + }, + }) + require.NoError(t, err) + + tts := []struct { + desc string + hostname string + addrs []string + score int + }{ + { + desc: "multi factor match", + hostname: "foo.example.com", + addrs: []string{ + "1.2.3.4:0", + }, + score: directMatch, + }, + { + desc: "ip match only", + hostname: "bar.example.com", + addrs: []string{ + "1.2.3.4:0", + }, + score: indirectMatch, + }, + { + desc: "hostname match only", + hostname: "foo.example.com", + addrs: []string{ + "7.7.7.7:0", + }, + score: directMatch, + }, + { + desc: "not match", + hostname: "bar.example.com", + addrs: []string{ + "0.0.0.0:0", + "1.1.1.1:0", + }, + score: notMatch, + }, + } + + for _, tt := range tts { + t.Run(tt.desc, func(t *testing.T) { + score := matcher.RouteToServerScore(mockRouteableServer{ + name: uuid.NewString(), + hostname: tt.hostname, + publicAddr: tt.addrs, + }) + + require.Equal(t, tt.score, score) + }) + } +} diff --git a/lib/proxy/router.go b/lib/proxy/router.go index bbd67acf6d351..3e29c4078b366 100644 --- a/lib/proxy/router.go +++ b/lib/proxy/router.go @@ -417,6 +417,13 @@ func (r remoteSite) GetClusterNetworkingConfig(ctx context.Context) (types.Clust // getServer attempts to locate a node matching the provided host and port in // the provided site. func getServer(ctx context.Context, host, port string, site site) (types.Server, error) { + return getServerWithResolver(ctx, host, port, site, nil /* use default resolver */) +} + +// getServerWithResolver attempts to locate a node matching the provided host and port in +// the provided site. The resolver argument is used in certain tests to mock DNS resolution +// and can generally be left nil. +func getServerWithResolver(ctx context.Context, host, port string, site site, resolver apiutils.HostResolver) (types.Server, error) { if site == nil { return nil, trace.BadParameter("invalid remote site provided") } @@ -428,10 +435,27 @@ func getServer(ctx context.Context, host, port string, site site) (types.Server, caseInsensitiveRouting = cfg.GetCaseInsensitiveRouting() } - routeMatcher := apiutils.NewSSHRouteMatcher(host, port, caseInsensitiveRouting) + routeMatcher, err := apiutils.NewSSHRouteMatcherFromConfig(apiutils.SSHRouteMatcherConfig{ + Host: host, + Port: port, + CaseInsensitive: caseInsensitiveRouting, + Resolver: resolver, + }) + if err != nil { + return nil, trace.Wrap(err) + } + var maxScore int + scores := make(map[string]int) matches, err := site.GetNodes(ctx, func(server services.Node) bool { - return routeMatcher.RouteToServer(server) + score := routeMatcher.RouteToServerScore(server) + if score < 1 { + return false + } + + scores[server.GetName()] = score + maxScore = max(maxScore, score) + return true }) if err != nil { return nil, trace.Wrap(err) @@ -449,6 +473,21 @@ func getServer(ctx context.Context, host, port string, site site) (types.Server, } } + if len(matches) > 1 { + // in the event of multiple matches, some matches may be of higher quality than others + // (e.g. matching an ip/hostname directly versus matching a resolved ip). if we have a + // mix of match qualities, filter out the lower quality matches to reduce ambiguity. + filtered := matches[:0] + for _, m := range matches { + if scores[m.GetName()] < maxScore { + continue + } + + filtered = append(filtered, m) + } + matches = filtered + } + var server types.Server switch { case strategy == types.RoutingStrategy_MOST_RECENT: diff --git a/lib/proxy/router_test.go b/lib/proxy/router_test.go index b668df5cf72f3..c3d21f15c0fa3 100644 --- a/lib/proxy/router_test.go +++ b/lib/proxy/router_test.go @@ -90,6 +90,135 @@ func createServers(srvs []server) []types.Server { return out } +type mockHostResolver struct { + hosts map[string][]string +} + +func (r *mockHostResolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) { + return r.hosts[host], nil +} + +// TestRouteScoring verifies expected behavior in the specific cases where multiple matches +// of different quality are made. +func TestRouteScoring(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // set up various servers with overlapping IPs and hostnames + servers := createServers([]server{ + { + name: uuid.NewString(), + hostname: "one.example.com", + addr: "1.2.3.4:123", + }, + { + name: uuid.NewString(), + hostname: "two.example.com", + addr: "1.2.3.4:456", + }, + { + name: uuid.NewString(), + hostname: "dupe.example.com", + addr: "1.2.3.4:789", + }, + { + name: uuid.NewString(), + hostname: "dupe.example.com", + addr: "1.2.3.4:1011", + }, + { + name: uuid.NewString(), + hostname: "blue.example.com", + addr: "2.3.4.5:22", + }, + }) + + // scoring behavior is independent of routing strategy so we just + // use the most strict config for all cases. + site := &testSite{ + cfg: &types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + RoutingStrategy: types.RoutingStrategy_UNAMBIGUOUS_MATCH, + }, + }, + nodes: servers, + } + + // set up resolver + resolver := &mockHostResolver{ + hosts: map[string][]string{ + // register a hostname that only indirectly maps to a node + "red.example.com": []string{"2.3.4.5"}, + }, + } + + for _, s := range servers { + resolver.hosts[s.GetHostname()] = []string{"1.2.3.4"} + } + + tts := []struct { + desc string + host, port string + expect string + ambiguous bool + }{ + { + // this is the primary case that route scoring was implemented to solve. prior to scoring, + // dialing by a hostname that is itself unambiguous but resolves to an ip that + // *is* ambiguous would result in an unexpected ambiguous host error, despite the fact that + // what the user typed in was clearly unambiguous. + desc: "dial by hostname", + host: "one.example.com", + expect: "one.example.com", + }, + { + desc: "dial by ip only", + host: "2.3.4.5", + expect: "blue.example.com", + }, + { + desc: "dial by ip and port", + host: "1.2.3.4", + port: "456", + expect: "two.example.com", + }, + { + desc: "ambiguous hostname dial", + host: "dupe.example.com", + ambiguous: true, + }, + { + desc: "ambiguous ip dial", + host: "1.2.3.4", + ambiguous: true, + }, + { + desc: "disambiguate by port", + host: "dupe.example.com", + port: "789", + expect: "dupe.example.com", + }, + { + desc: "indirect ip resolve", + host: "red.example.com", + expect: "blue.example.com", + }, + } + + for _, tt := range tts { + t.Run(tt.desc, func(t *testing.T) { + srv, err := getServerWithResolver(ctx, tt.host, tt.port, site, resolver) + if tt.ambiguous { + require.ErrorIs(t, err, trace.NotFound(teleport.NodeIsAmbiguous)) + return + } + require.Equal(t, tt.expect, srv.GetHostname()) + }) + } +} + func TestGetServers(t *testing.T) { t.Parallel()