Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 107 additions & 35 deletions api/utils/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
package utils

import (
"context"
"errors"
"net"
"unicode/utf8"

"github.com/google/uuid"
"github.com/gravitational/trace"
"golang.org/x/exp/slices"

"github.com/gravitational/teleport/api/utils/aws"
Expand All @@ -29,26 +32,68 @@ import (
// to let other parts of teleport easily find matching servers when generating
// error messages or building access requests.
type SSHRouteMatcher struct {
targetHost string
targetPort string
caseInsensitive bool
ips []string
matchServerIDs bool
cfg SSHRouteMatcherConfig
ips []string
matchServerIDs bool
}

// SSHRouteMatcherConfig configures an SSHRouteMatcher.
type SSHRouteMatcherConfig struct {
// Host is the target host that we want to route to.
Host string
// Port is an optional target port. If empty or zero
// it will match servers listening on any port.
Port string
// Resolver can be set to override default hostname lookup
// behavior (used in tests).
Resolver HostResolver
// CaseInsensitive enabled case insensitive routing when true.
CaseInsensitive bool
}

// HostResolver provides an interface matching the net.Resolver.LookupHost method. Typically
// only used as a means of overriding dns resolution behavior in tests.
type HostResolver interface {
// LookupHost performs a hostname lookup. See net.Resolver.LookupHost for details.
LookupHost(ctx context.Context, host string) (addrs []string, err error)
}

var errEmptyHost = errors.New("cannot route to empty target host")

// NewSSHRouteMatcherFromConfig sets up an ssh route matcher from the supplied configuration.
func NewSSHRouteMatcherFromConfig(cfg SSHRouteMatcherConfig) (*SSHRouteMatcher, error) {
if cfg.Host == "" {
return nil, trace.Wrap(errEmptyHost)
}

if cfg.Resolver == nil {
cfg.Resolver = net.DefaultResolver
}

m := newSSHRouteMatcher(cfg)
return &m, nil
}

// NewSSHRouteMatcher builds a new matcher for ssh routing decisions.
func NewSSHRouteMatcher(host, port string, caseInsensitive bool) SSHRouteMatcher {
_, err := uuid.Parse(host)
dialByID := err == nil || aws.IsEC2NodeID(host)
return newSSHRouteMatcher(SSHRouteMatcherConfig{
Host: host,
Port: port,
CaseInsensitive: caseInsensitive,
Resolver: net.DefaultResolver,
})
}

ips, _ := net.LookupHost(host)
func newSSHRouteMatcher(cfg SSHRouteMatcherConfig) SSHRouteMatcher {
_, err := uuid.Parse(cfg.Host)
dialByID := err == nil || aws.IsEC2NodeID(cfg.Host)

ips, _ := cfg.Resolver.LookupHost(context.Background(), cfg.Host)

return SSHRouteMatcher{
targetHost: host,
targetPort: port,
caseInsensitive: caseInsensitive,
ips: ips,
matchServerIDs: dialByID,
cfg: cfg,
ips: ips,
matchServerIDs: dialByID,
}
}

Expand All @@ -64,45 +109,72 @@ type RouteableServer interface {

// RouteToServer checks if this route matcher wants to route to the supplied server.
func (m *SSHRouteMatcher) RouteToServer(server RouteableServer) bool {
return m.RouteToServerScore(server) > 0
}

const (
notMatch = 0
indirectMatch = 1
directMatch = 2
)

// RouteToServerScore checks wether this route matcher wants to route to the supplied server
// and represents the result of that check as an integer score indicating the strength of the
// match. Positive scores indicate a match, higher being stronger.
func (m *SSHRouteMatcher) RouteToServerScore(server RouteableServer) (score int) {
// if host is a UUID or EC2 ID match only
// by server name and treat matches as unambiguous
if m.matchServerIDs && server.GetName() == m.targetHost {
return true
if m.matchServerIDs && server.GetName() == m.cfg.Host {
return directMatch
}

hostnameMatch := m.routeToHostname(server.GetHostname())

// if the server has connected over a reverse tunnel
// then match only by hostname.
if server.GetUseTunnel() {
return hostnameMatch
if hostnameMatch {
return directMatch
}
return notMatch
}

matchAddr := func(addr string) bool {
matchAddr := func(addr string) int {
ip, nodePort, err := net.SplitHostPort(addr)
if err != nil {
return false
return notMatch
}

if (m.targetHost == ip || hostnameMatch || slices.Contains(m.ips, ip)) &&
(m.targetPort == "" || m.targetPort == "0" || m.targetPort == nodePort) {
return true
if m.cfg.Port != "" && m.cfg.Port != "0" && m.cfg.Port != nodePort {
// if port is well-specified and does not match, don't bother
// continuing the check.
return notMatch
}

return false
}
if hostnameMatch || m.cfg.Host == ip {
// server presents a hostname or addr that exactly matches
// our target.
return directMatch
}

if matchAddr(server.GetAddr()) {
return true
if slices.Contains(m.ips, ip) {
// server presents an addr that indirectly matches our target
// due to dns resolution.
return indirectMatch
}

return notMatch
}

score = matchAddr(server.GetAddr())

for _, addr := range server.GetPublicAddrs() {
if matchAddr(addr) {
return true
if s := matchAddr(addr); s > score {
score = s
}
}

return false
return score
}

// routeToHostname helps us perform a special kind of case-insensitive comparison. SSH certs do not generally
Expand All @@ -112,21 +184,21 @@ func (m *SSHRouteMatcher) RouteToServer(server RouteableServer) bool {
// the literal hostname and a lowered version of the hostname, meaning that it is sane to route a request for host 'foo' to
// host 'Foo', but it is not sane to route a request for host 'Bar' to host 'bar'.
func (m *SSHRouteMatcher) routeToHostname(principal string) bool {
if !m.caseInsensitive {
return m.targetHost == principal
if !m.cfg.CaseInsensitive {
return m.cfg.Host == principal
}

if len(m.targetHost) != len(principal) {
if len(m.cfg.Host) != len(principal) {
return false
}

// the below is modeled off of the fast ASCII path of strings.EqualFold
for i := 0; i < len(principal) && i < len(m.targetHost); i++ {
for i := 0; i < len(principal) && i < len(m.cfg.Host); i++ {
pr := principal[i]
hr := m.targetHost[i]
hr := m.cfg.Host[i]
if pr|hr >= utf8.RuneSelf {
// not pure-ascii, fallback to literal comparison
return m.targetHost == principal
return m.cfg.Host == principal
}

// Easy case.
Expand All @@ -146,7 +218,7 @@ func (m *SSHRouteMatcher) routeToHostname(principal string) bool {

// IsEmpty checks if this route matcher has had a hostname set.
func (m *SSHRouteMatcher) IsEmpty() bool {
return m.targetHost == ""
return m.cfg.Host == ""
}

// MatchesServerIDs checks if this matcher wants to perform server ID matching.
Expand Down
79 changes: 79 additions & 0 deletions api/utils/route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package utils

import (
"context"
"testing"

"github.com/google/uuid"
Expand Down Expand Up @@ -244,3 +245,81 @@ func TestRouteToServer(t *testing.T) {
})
}
}

type mockHostResolver struct {
ips []string
}

func (r mockHostResolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) {
return r.ips, nil
}

// TestSSHRouteMatcherScoring verifies the expected scoring behavior of SSHRouteMatcher.
func TestSSHRouteMatcherScoring(t *testing.T) {
t.Parallel()

// set up matcher with mock resolver in order to control ips
matcher, err := NewSSHRouteMatcherFromConfig(SSHRouteMatcherConfig{
Host: "foo.example.com",
Resolver: mockHostResolver{
ips: []string{
"1.2.3.4",
"4.5.6.7",
},
},
})
require.NoError(t, err)

tts := []struct {
desc string
hostname string
addrs []string
score int
}{
{
desc: "multi factor match",
hostname: "foo.example.com",
addrs: []string{
"1.2.3.4:0",
},
score: directMatch,
},
{
desc: "ip match only",
hostname: "bar.example.com",
addrs: []string{
"1.2.3.4:0",
},
score: indirectMatch,
},
{
desc: "hostname match only",
hostname: "foo.example.com",
addrs: []string{
"7.7.7.7:0",
},
score: directMatch,
},
{
desc: "not match",
hostname: "bar.example.com",
addrs: []string{
"0.0.0.0:0",
"1.1.1.1:0",
},
score: notMatch,
},
}

for _, tt := range tts {
t.Run(tt.desc, func(t *testing.T) {
score := matcher.RouteToServerScore(mockRouteableServer{
name: uuid.NewString(),
hostname: tt.hostname,
publicAddr: tt.addrs,
})

require.Equal(t, tt.score, score)
})
}
}
43 changes: 41 additions & 2 deletions lib/proxy/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,13 @@ func (r remoteSite) GetClusterNetworkingConfig(ctx context.Context, opts ...serv
// getServer attempts to locate a node matching the provided host and port in
// the provided site.
func getServer(ctx context.Context, host, port string, site site) (types.Server, error) {
return getServerWithResolver(ctx, host, port, site, nil /* use default resolver */)
}

// getServerWithResolver attempts to locate a node matching the provided host and port in
// the provided site. The resolver argument is used in certain tests to mock DNS resolution
// and can generally be left nil.
func getServerWithResolver(ctx context.Context, host, port string, site site, resolver apiutils.HostResolver) (types.Server, error) {
if site == nil {
return nil, trace.BadParameter("invalid remote site provided")
}
Expand All @@ -459,10 +466,27 @@ func getServer(ctx context.Context, host, port string, site site) (types.Server,
caseInsensitiveRouting = cfg.GetCaseInsensitiveRouting()
}

routeMatcher := apiutils.NewSSHRouteMatcher(host, port, caseInsensitiveRouting)
routeMatcher, err := apiutils.NewSSHRouteMatcherFromConfig(apiutils.SSHRouteMatcherConfig{
Host: host,
Port: port,
CaseInsensitive: caseInsensitiveRouting,
Resolver: resolver,
})
if err != nil {
return nil, trace.Wrap(err)
}

var maxScore int
scores := make(map[string]int)
matches, err := site.GetNodes(ctx, func(server services.Node) bool {
return routeMatcher.RouteToServer(server)
score := routeMatcher.RouteToServerScore(server)
if score < 1 {
return false
}

scores[server.GetName()] = score
maxScore = max(maxScore, score)
return true
})
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -480,6 +504,21 @@ func getServer(ctx context.Context, host, port string, site site) (types.Server,
}
}

if len(matches) > 1 {
// in the event of multiple matches, some matches may be of higher quality than others
// (e.g. matching an ip/hostname directly versus matching a resolved ip). if we have a
// mix of match qualities, filter out the lower quality matches to reduce ambiguity.
filtered := matches[:0]
for _, m := range matches {
if scores[m.GetName()] < maxScore {
continue
}

filtered = append(filtered, m)
}
matches = filtered
}

var server types.Server
switch {
case strategy == types.RoutingStrategy_MOST_RECENT:
Expand Down
Loading