diff --git a/lib/proxy/clusterdial/dial.go b/lib/proxy/clusterdial/dial.go index 13822db268f72..59a01e9a97dd8 100644 --- a/lib/proxy/clusterdial/dial.go +++ b/lib/proxy/clusterdial/dial.go @@ -19,21 +19,21 @@ import ( "github.com/gravitational/trace" - "github.com/gravitational/teleport/lib/proxy" + "github.com/gravitational/teleport/lib/proxy/peer" "github.com/gravitational/teleport/lib/reversetunnel" ) -// ClusterDialerFunc is a function that implements a proxy.ClusterDialer. -type ClusterDialerFunc func(clusterName string, request proxy.DialParams) (net.Conn, error) +// ClusterDialerFunc is a function that implements a peer.ClusterDialer. +type ClusterDialerFunc func(clusterName string, request peer.DialParams) (net.Conn, error) // Dial dials makes a dial request to the given cluster. -func (f ClusterDialerFunc) Dial(clusterName string, request proxy.DialParams) (net.Conn, error) { +func (f ClusterDialerFunc) Dial(clusterName string, request peer.DialParams) (net.Conn, error) { return f(clusterName, request) } // NewClusterDialer implements proxy.ClusterDialer for a reverse tunnel server. func NewClusterDialer(server reversetunnel.Server) ClusterDialerFunc { - return ClusterDialerFunc(func(clusterName string, request proxy.DialParams) (net.Conn, error) { + return func(clusterName string, request peer.DialParams) (net.Conn, error) { site, err := server.GetSite(clusterName) if err != nil { return nil, trace.Wrap(err) @@ -52,5 +52,5 @@ func NewClusterDialer(server reversetunnel.Server) ClusterDialerFunc { return nil, trace.Wrap(err) } return conn, nil - }) + } } diff --git a/lib/proxy/auth.go b/lib/proxy/peer/auth.go similarity index 99% rename from lib/proxy/auth.go rename to lib/proxy/peer/auth.go index 622258aa2f22a..537f714dd1ef3 100644 --- a/lib/proxy/auth.go +++ b/lib/proxy/peer/auth.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/client.go b/lib/proxy/peer/client.go similarity index 99% rename from lib/proxy/client.go rename to lib/proxy/peer/client.go index cd83b05a6caf4..b80dc83a77e4e 100644 --- a/lib/proxy/client.go +++ b/lib/proxy/peer/client.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/client_test.go b/lib/proxy/peer/client_test.go similarity index 99% rename from lib/proxy/client_test.go rename to lib/proxy/peer/client_test.go index 6253a2c3c8764..fa942c372ba3e 100644 --- a/lib/proxy/client_test.go +++ b/lib/proxy/peer/client_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "crypto/tls" diff --git a/lib/proxy/clientmetrics.go b/lib/proxy/peer/clientmetrics.go similarity index 99% rename from lib/proxy/clientmetrics.go rename to lib/proxy/peer/clientmetrics.go index da6c0c0676143..e67d571d47e7f 100644 --- a/lib/proxy/clientmetrics.go +++ b/lib/proxy/peer/clientmetrics.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "github.com/gravitational/trace" diff --git a/lib/proxy/conn.go b/lib/proxy/peer/conn.go similarity index 99% rename from lib/proxy/conn.go rename to lib/proxy/peer/conn.go index 1fc1f5d5bc16e..0e7bf1bfde5a4 100644 --- a/lib/proxy/conn.go +++ b/lib/proxy/peer/conn.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/conn_test.go b/lib/proxy/peer/conn_test.go similarity index 99% rename from lib/proxy/conn_test.go rename to lib/proxy/peer/conn_test.go index 751e61d1fc968..fbe0bfbbba334 100644 --- a/lib/proxy/conn_test.go +++ b/lib/proxy/peer/conn_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/helpers_test.go b/lib/proxy/peer/helpers_test.go similarity index 99% rename from lib/proxy/helpers_test.go rename to lib/proxy/peer/helpers_test.go index afb2d10b43ef8..60036def48a47 100644 --- a/lib/proxy/helpers_test.go +++ b/lib/proxy/peer/helpers_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "crypto/rand" diff --git a/lib/proxy/interceptor.go b/lib/proxy/peer/interceptor.go similarity index 99% rename from lib/proxy/interceptor.go rename to lib/proxy/peer/interceptor.go index de838322c48ff..19f103b9b8aac 100644 --- a/lib/proxy/interceptor.go +++ b/lib/proxy/peer/interceptor.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/reporter.go b/lib/proxy/peer/reporter.go similarity index 99% rename from lib/proxy/reporter.go rename to lib/proxy/peer/reporter.go index 068d2b09b77d2..3656fc1f99899 100644 --- a/lib/proxy/reporter.go +++ b/lib/proxy/peer/reporter.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "time" diff --git a/lib/proxy/server.go b/lib/proxy/peer/server.go similarity index 99% rename from lib/proxy/server.go rename to lib/proxy/peer/server.go index 0314e2bd57db2..82a4671a3e629 100644 --- a/lib/proxy/server.go +++ b/lib/proxy/peer/server.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "crypto/tls" diff --git a/lib/proxy/server_test.go b/lib/proxy/peer/server_test.go similarity index 99% rename from lib/proxy/server_test.go rename to lib/proxy/peer/server_test.go index 546f99de288c9..66a4845ec0bf2 100644 --- a/lib/proxy/server_test.go +++ b/lib/proxy/peer/server_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "testing" diff --git a/lib/proxy/servermetrics.go b/lib/proxy/peer/servermetrics.go similarity index 99% rename from lib/proxy/servermetrics.go rename to lib/proxy/peer/servermetrics.go index 5adc1c4f4bc90..bd7bd85170229 100644 --- a/lib/proxy/servermetrics.go +++ b/lib/proxy/peer/servermetrics.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "github.com/gravitational/trace" diff --git a/lib/proxy/service.go b/lib/proxy/peer/service.go similarity index 99% rename from lib/proxy/service.go rename to lib/proxy/peer/service.go index 053c098fea3de..600ef470cb922 100644 --- a/lib/proxy/service.go +++ b/lib/proxy/peer/service.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "net" diff --git a/lib/proxy/service_test.go b/lib/proxy/peer/service_test.go similarity index 99% rename from lib/proxy/service_test.go rename to lib/proxy/peer/service_test.go index e6f9b576548c3..0cba1b7e36f3b 100644 --- a/lib/proxy/service_test.go +++ b/lib/proxy/peer/service_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/stats.go b/lib/proxy/peer/stats.go similarity index 98% rename from lib/proxy/stats.go rename to lib/proxy/peer/stats.go index 0d37b741718e9..dc0cd89501308 100644 --- a/lib/proxy/stats.go +++ b/lib/proxy/peer/stats.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" @@ -51,7 +51,7 @@ func (s *statsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) con return ctx } -// HandleRPC implements per-Connection stats reporting. +// HandleConn implements per-Connection stats reporting. func (s *statsHandler) HandleConn(ctx context.Context, connStats stats.ConnStats) { // client connection stats are monitored by the monitor() function in client.go if connStats.IsClient() { diff --git a/lib/proxy/router.go b/lib/proxy/router.go new file mode 100644 index 0000000000000..6925147d345fe --- /dev/null +++ b/lib/proxy/router.go @@ -0,0 +1,393 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "fmt" + "net" + "strconv" + + "github.com/google/uuid" + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/attribute" + oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/exp/slices" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/observability/tracing" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/teleagent" + "github.com/gravitational/teleport/lib/utils" +) + +type serverResolverFn = func(ctx context.Context, host, port string, site site) (types.Server, error) + +// SiteGetter provides access to connected local or remote sites +type SiteGetter interface { + // GetSite returns the site matching the provided clusterName + GetSite(clusterName string) (reversetunnel.RemoteSite, error) +} + +// RemoteClusterGetter provides access to remote cluster resources +type RemoteClusterGetter interface { + // GetRemoteCluster returns a remote cluster by name + GetRemoteCluster(clusterName string) (types.RemoteCluster, error) +} + +// RouterConfig contains all the dependencies required +// by the Router +type RouterConfig struct { + // ClusterName indicates which cluster the router is for + ClusterName string + // Log is the logger to use + Log *logrus.Entry + // AccessPoint is the proxy cache + RemoteClusterGetter RemoteClusterGetter + // SiteGetter allows looking up sites + SiteGetter SiteGetter + // TracerProvider allows tracers to be created + TracerProvider oteltrace.TracerProvider + + // serverResolver is used to resolve hosts, used by tests + serverResolver serverResolverFn +} + +// CheckAndSetDefaults ensures the required items were populated +func (c *RouterConfig) CheckAndSetDefaults() error { + if c.Log == nil { + c.Log = logrus.WithField(trace.Component, "Router") + } + + if c.ClusterName == "" { + return trace.BadParameter("ClusterName must be provided") + } + + if c.RemoteClusterGetter == nil { + return trace.BadParameter("RemoteClusterGetter must be provided") + } + + if c.SiteGetter == nil { + return trace.BadParameter("SiteGetter must be provided") + } + + if c.TracerProvider == nil { + c.TracerProvider = tracing.DefaultProvider() + } + + if c.serverResolver == nil { + c.serverResolver = getServer + } + + return nil +} + +// Router is used by the proxy to establish connections to both +// nodes and other clusters. +type Router struct { + clusterName string + log *logrus.Entry + clusterGetter RemoteClusterGetter + localSite reversetunnel.RemoteSite + siteGetter SiteGetter + tracer oteltrace.Tracer + serverResolver serverResolverFn +} + +// NewRouter creates and returns a Router that is populated +// from the provided RouterConfig. +func NewRouter(cfg RouterConfig) (*Router, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + localSite, err := cfg.SiteGetter.GetSite(cfg.ClusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + return &Router{ + clusterName: cfg.ClusterName, + log: cfg.Log, + clusterGetter: cfg.RemoteClusterGetter, + localSite: localSite, + siteGetter: cfg.SiteGetter, + tracer: cfg.TracerProvider.Tracer("Router"), + serverResolver: cfg.serverResolver, + }, nil + +} + +// DialHost dials the node that matches the provided host, port and cluster. If no matching node +// is found an error is returned. If more than one matching node is found and the cluster networking +// configuration is not set to route to the most recent an error is returned. +func (r *Router) DialHost(ctx context.Context, from net.Addr, host, port, clusterName string, accessChecker services.AccessChecker, agentGetter teleagent.Getter) (net.Conn, error) { + ctx, span := r.tracer.Start( + ctx, + "router/DialHost", + oteltrace.WithAttributes( + attribute.String("host", host), + attribute.String("port", port), + attribute.String("site", clusterName), + ), + ) + defer span.End() + + site := r.localSite + if clusterName != r.clusterName { + remoteSite, err := r.getRemoteCluster(ctx, clusterName, accessChecker) + if err != nil { + return nil, trace.Wrap(err) + } + site = remoteSite + } + + span.AddEvent("looking up server") + target, err := r.serverResolver(ctx, host, port, remoteSite{site}) + if err != nil { + return nil, trace.Wrap(err) + } + span.AddEvent("retrieved target server") + + principals := []string{host} + + var ( + serverID string + serverAddr string + proxyIDs []string + ) + if target != nil { + proxyIDs = target.GetProxyIDs() + serverID = fmt.Sprintf("%v.%v", target.GetName(), clusterName) + + // add hostUUID.cluster to the principals + principals = append(principals, serverID) + + // add ip if it exists to the principals + serverAddr = target.GetAddr() + + switch { + case serverAddr != "": + h, _, err := net.SplitHostPort(serverAddr) + if err != nil { + return nil, trace.Wrap(err) + } + + principals = append(principals, h) + case serverAddr == "" && target.GetUseTunnel(): + serverAddr = reversetunnel.LocalNode + } + } else { + if port == "" || port == "0" { + port = strconv.Itoa(defaults.SSHServerListenPort) + } + + serverAddr = net.JoinHostPort(host, port) + r.log.Warnf("server lookup failed: using default=%v", serverAddr) + } + + conn, err := site.Dial(reversetunnel.DialParams{ + From: from, + To: &utils.NetAddr{AddrNetwork: "tcp", Addr: serverAddr}, + GetUserAgent: agentGetter, + Address: host, + ServerID: serverID, + ProxyIDs: proxyIDs, + Principals: principals, + ConnType: types.NodeTunnel, + }) + + return conn, trace.Wrap(err) +} + +// getRemoteCluster looks up the provided clusterName to determine if a remote site exists with +// that name and determines if the user has access to it. +func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, checker services.AccessChecker) (reversetunnel.RemoteSite, error) { + _, span := r.tracer.Start( + ctx, + "router/getRemoteCluster", + oteltrace.WithAttributes( + attribute.String("site", clusterName), + ), + ) + defer span.End() + + site, err := r.siteGetter.GetSite(clusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + rc, err := r.clusterGetter.GetRemoteCluster(clusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + if err := checker.CheckAccessToRemoteCluster(rc); err != nil { + return nil, utils.OpaqueAccessDenied(err) + } + + return site, nil +} + +// site is the minimum interface needed to match servers +// for a reversetunnel.RemoteSite. It makes testing easier. +type site interface { + GetNodes(fn func(n services.Node) bool) ([]types.Server, error) + GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error) +} + +// remoteSite is a site implementation that wraps +// a reversetunnel.RemoteSite +type remoteSite struct { + site reversetunnel.RemoteSite +} + +// GetNodes uses the wrapped sites NodeWatcher to filter nodes +func (r remoteSite) GetNodes(fn func(n services.Node) bool) ([]types.Server, error) { + watcher, err := r.site.NodeWatcher() + if err != nil { + return nil, trace.Wrap(err) + } + + return watcher.GetNodes(fn), nil +} + +// GetClusterNetworkingConfig uses the wrapped sites cache to retrieve the ClusterNetworkingConfig +func (r remoteSite) GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error) { + ap, err := r.site.CachingAccessPoint() + if err != nil { + return nil, trace.Wrap(err) + } + + cfg, err := ap.GetClusterNetworkingConfig(ctx, opts...) + return cfg, trace.Wrap(err) +} + +// getServer attempts to locate a node matching the provided host and port in +// the provided site. +func getServer(ctx context.Context, host, port string, site site) (types.Server, error) { + if site == nil { + return nil, trace.BadParameter("invalid remote site provided") + } + + strategy := types.RoutingStrategy_UNAMBIGUOUS_MATCH + if cfg, err := site.GetClusterNetworkingConfig(ctx); err == nil { + strategy = cfg.GetRoutingStrategy() + } + + _, err := uuid.Parse(host) + dialByID := err == nil || utils.IsEC2NodeID(host) + + ips, _ := net.LookupHost(host) + + var unambiguousIDMatch bool + matches, err := site.GetNodes(func(server services.Node) bool { + if unambiguousIDMatch { + return false + } + + // if host is a UUID or EC2 ID match only + // by server name and treat matches as unambiguous + if dialByID && server.GetName() == host { + unambiguousIDMatch = true + return true + } + + // if the server has connected over a reverse tunnel + // then match only by hostname + if server.GetUseTunnel() { + return host == server.GetHostname() + } + + ip, nodePort, err := net.SplitHostPort(server.GetAddr()) + if err != nil { + return false + } + + if (host == ip || host == server.GetHostname() || slices.Contains(ips, ip)) && + (port == "" || port == "0" || port == nodePort) { + return true + } + + return false + }) + if err != nil { + return nil, trace.Wrap(err) + } + + var server types.Server + switch { + case strategy == types.RoutingStrategy_MOST_RECENT: + for _, m := range matches { + if server == nil || m.Expiry().After(server.Expiry()) { + server = m + } + } + case len(matches) > 1: + return nil, trace.NotFound(teleport.NodeIsAmbiguous) + case len(matches) == 1: + server = matches[0] + } + + if dialByID && server == nil { + idType := "UUID" + if utils.IsEC2NodeID(host) { + idType = "EC2" + } + + return nil, trace.NotFound("unable to locate node matching %s-like target %s", idType, host) + } + + return server, nil + +} + +// DialSite establishes a connection to the auth server in the provided +// cluster. If the clusterName is an empty string then a connection to +// the local auth server will be established. +func (r *Router) DialSite(ctx context.Context, clusterName string) (net.Conn, error) { + _, span := r.tracer.Start( + ctx, + "router/DialSite", + oteltrace.WithAttributes( + attribute.String("site", clusterName), + ), + ) + defer span.End() + + // default to local cluster if one wasn't provided + if clusterName == "" { + clusterName = r.clusterName + } + + // dial the local auth server + if clusterName == r.clusterName { + conn, err := r.localSite.DialAuthServer() + return conn, trace.Wrap(err) + } + + // lookup the site and dial its auth server + site, err := r.siteGetter.GetSite(clusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + conn, err := site.DialAuthServer() + return conn, trace.Wrap(err) +} diff --git a/lib/proxy/router_test.go b/lib/proxy/router_test.go new file mode 100644 index 0000000000000..019e665517b8d --- /dev/null +++ b/lib/proxy/router_test.go @@ -0,0 +1,509 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "net" + "testing" + + "github.com/google/uuid" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/observability/tracing" + "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" +) + +type testSite struct { + cfg types.ClusterNetworkingConfig + nodes []types.Server +} + +func (t testSite) GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error) { + return t.cfg, nil +} + +func (t testSite) GetNodes(fn func(n services.Node) bool) ([]types.Server, error) { + var out []types.Server + for _, s := range t.nodes { + if fn(s) { + out = append(out, s) + } + } + + return out, nil +} + +type server struct { + name string + hostname string + addr string + tunnel bool +} + +func createServers(srvs []server) []types.Server { + out := make([]types.Server, 0, len(srvs)) + for _, s := range srvs { + srv := &types.ServerV2{ + Kind: types.KindNode, + Version: types.V2, + Metadata: types.Metadata{ + Name: s.name, + }, + Spec: types.ServerSpecV2{ + Addr: s.addr, + Hostname: s.hostname, + UseTunnel: s.tunnel, + }, + } + out = append(out, srv) + } + + return out +} + +func TestGetServers(t *testing.T) { + t.Parallel() + + mostRecentCfg := types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + RoutingStrategy: types.RoutingStrategy_MOST_RECENT, + }, + } + + unambiguousCfg := types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + RoutingStrategy: types.RoutingStrategy_UNAMBIGUOUS_MATCH, + }, + } + + hostID := uuid.NewString() + const ec2ID = "012345678901-i-01234567890abcdef" + + servers := createServers([]server{ + { + name: hostID, + hostname: "llama", + addr: "llama:123", + }, + { + name: "llama", + hostname: "llama", + addr: "llama:123", + tunnel: true, + }, + { + name: "llama", + hostname: hostID, + addr: "llama:123", + }, + { + name: ec2ID, + hostname: "node.aws", + addr: "node.aws:123", + }, + { + name: "node.aws", + hostname: "node.aws", + addr: "node.aws:123", + tunnel: true, + }, + { + name: "node.aws", + hostname: ec2ID, + addr: "node.aws:123", + }, + { + name: "alpaca", + hostname: "alpaca", + addr: "alpaca:123", + tunnel: true, + }, + { + name: "alpaca", + hostname: "localhost", + addr: "alpaca:987", + tunnel: true, + }, + { + name: "goat", + hostname: "goat", + addr: "1.2.3.4:123", + }, + { + name: "sheep", + hostname: "sheep", + addr: "sheep.bah:0", + }, + { + name: "sheep2", + hostname: "sheep", + addr: "sheep.bah:0", + }, + { + name: "lion", + hostname: "lion", + addr: "lion.roar", + }, + }) + + cases := []struct { + name string + host string + port string + site testSite + errAssertion require.ErrorAssertionFunc + serverAssertion func(t *testing.T, srv types.Server) + }{ + { + name: "no matches for hostname", + site: testSite{cfg: &unambiguousCfg}, + host: "test", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + { + name: "no matches for uuid", + site: testSite{cfg: &mostRecentCfg}, + host: uuid.NewString(), + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.True(t, trace.IsNotFound(err), i...) + }, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + { + name: "no matches for ec2 id", + site: testSite{cfg: &unambiguousCfg}, + host: "123456789012-i-1234567890abcdef0", + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.True(t, trace.IsNotFound(err), i...) + }, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + { + name: "ambiguous match fails", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: "sheep", + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorIs(t, err, trace.NotFound(teleport.NodeIsAmbiguous)) + }, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + { + name: "ambiguous match returns most recent", + site: testSite{cfg: &mostRecentCfg, nodes: servers}, + host: "sheep", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "sheep", srv.GetHostname()) + }, + }, + { + name: "match by uuid", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: hostID, + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "llama", srv.GetHostname()) + }, + }, + { + name: "match by ec2 id", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: ec2ID, + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "node.aws", srv.GetHostname()) + }, + }, + { + name: "match by ip", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: "1.2.3.4", + port: "123", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "goat", srv.GetHostname()) + }, + }, + { + name: "match by host only for tunnels", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: "alpaca", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "alpaca", srv.GetHostname()) + }, + }, + { + name: "failure on invalid addresses", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: "lion", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + } + + ctx := context.Background() + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + srv, err := getServer(ctx, tt.host, tt.port, tt.site) + tt.errAssertion(t, err) + tt.serverAssertion(t, srv) + }) + } +} + +func serverResolver(srv types.Server, err error) serverResolverFn { + return func(ctx context.Context, host, port string, site site) (types.Server, error) { + return srv, err + } +} + +type tunnel struct { + reversetunnel.Tunnel + + site reversetunnel.RemoteSite + err error +} + +func (t tunnel) GetSite(cluster string) (reversetunnel.RemoteSite, error) { + return t.site, t.err +} + +type testRemoteSite struct { + reversetunnel.RemoteSite + conn net.Conn + err error +} + +func (r testRemoteSite) Dial(reversetunnel.DialParams) (net.Conn, error) { + return r.conn, r.err +} + +func (r testRemoteSite) DialAuthServer() (net.Conn, error) { + return r.conn, r.err +} + +type fakeConn struct { + net.Conn +} + +func TestRouter_DialHost(t *testing.T) { + t.Parallel() + + logger := utils.NewLoggerForTests().WithField(trace.Component, "test") + + srv := &types.ServerV2{ + Kind: types.KindNode, + Version: types.V2, + Metadata: types.Metadata{ + Name: uuid.NewString(), + }, + Spec: types.ServerSpecV2{ + Addr: "127.0.0.1:8889", + Hostname: "test", + }, + } + + cases := []struct { + name string + router Router + assertion func(t *testing.T, conn net.Conn, err error) + }{ + { + name: "failure resolving node", + router: Router{ + clusterName: "test", + log: logger, + tracer: tracing.NoopTracer("test"), + serverResolver: serverResolver(nil, trace.NotFound(teleport.NodeIsAmbiguous)), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.Nil(t, conn) + }, + }, + { + name: "failure looking up cluster", + router: Router{ + clusterName: "leaf", + siteGetter: tunnel{err: trace.NotFound("unknown cluster")}, + log: logger, + tracer: tracing.NoopTracer("test"), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsNotFound(err)) + require.Nil(t, conn) + }, + }, + { + name: "dial failure", + router: Router{ + clusterName: "test", + log: logger, + localSite: &testRemoteSite{err: trace.ConnectionProblem(context.DeadlineExceeded, "connection refused")}, + tracer: tracing.NoopTracer("test"), + serverResolver: serverResolver(srv, nil), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsConnectionProblem(err)) + require.Nil(t, conn) + }, + }, + { + name: "dial success", + router: Router{ + clusterName: "test", + log: logger, + localSite: &testRemoteSite{conn: fakeConn{}}, + tracer: tracing.NoopTracer("test"), + serverResolver: serverResolver(srv, nil), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.NoError(t, err) + require.NotNil(t, conn) + }, + }, + } + + ctx := context.Background() + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + conn, err := tt.router.DialHost(ctx, &utils.NetAddr{}, "host", "0", "test", nil, nil) + tt.assertion(t, conn, err) + }) + } + +} + +func TestRouter_DialSite(t *testing.T) { + t.Parallel() + + const cluster = "test" + logger := utils.NewLoggerForTests().WithField(trace.Component, cluster) + + cases := []struct { + name string + cluster string + localSite testRemoteSite + tunnel tunnel + assertion func(t *testing.T, conn net.Conn, err error) + }{ + { + name: "failure to dial local site", + cluster: cluster, + localSite: testRemoteSite{err: trace.ConnectionProblem(context.DeadlineExceeded, "connection refused")}, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsConnectionProblem(err)) + require.Nil(t, conn) + }, + }, + { + name: "successfully dial local site", + cluster: cluster, + localSite: testRemoteSite{conn: fakeConn{}}, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.NoError(t, err) + require.NotNil(t, conn) + }, + }, + + { + name: "default to dialing local site", + localSite: testRemoteSite{conn: fakeConn{}}, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.NoError(t, err) + require.NotNil(t, conn) + }, + }, + { + name: "failure to dial remote site", + cluster: "leaf", + tunnel: tunnel{ + site: testRemoteSite{err: trace.ConnectionProblem(context.DeadlineExceeded, "connection refused")}, + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsConnectionProblem(err)) + require.Nil(t, conn) + }, + }, + { + name: "unknown cluster", + cluster: "fake", + tunnel: tunnel{ + err: trace.NotFound("unknown cluster"), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsNotFound(err)) + require.Nil(t, conn) + }, + }, + { + name: "successfully dial remote site", + cluster: "leaf", + tunnel: tunnel{ + site: testRemoteSite{conn: fakeConn{}}, + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.NoError(t, err) + require.NotNil(t, conn) + }, + }, + } + + ctx := context.Background() + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + router := Router{ + clusterName: cluster, + log: logger, + localSite: &tt.localSite, + siteGetter: tt.tunnel, + tracer: tracing.NoopTracer(cluster), + } + + conn, err := router.DialSite(ctx, tt.cluster) + tt.assertion(t, conn, err) + }) + } +} diff --git a/lib/reversetunnel/api.go b/lib/reversetunnel/api.go index 2866533ac40f5..d7886fe7ee002 100644 --- a/lib/reversetunnel/api.go +++ b/lib/reversetunnel/api.go @@ -25,7 +25,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/proxy" + "github.com/gravitational/teleport/lib/proxy/peer" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleagent" ) @@ -140,7 +140,7 @@ type Server interface { // Wait waits for server to close all outstanding operations Wait() // GetProxyPeerClient returns the proxy peer client - GetProxyPeerClient() *proxy.Client + GetProxyPeerClient() *peer.Client } const ( diff --git a/lib/reversetunnel/conn_metric.go b/lib/reversetunnel/conn_metric.go index b8c39ed8c0b0e..ab6300884c635 100644 --- a/lib/reversetunnel/conn_metric.go +++ b/lib/reversetunnel/conn_metric.go @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package reversetunnel import ( @@ -27,16 +28,16 @@ import ( type dialType string const ( - // direct is a direct dialed connection. - direct dialType = "direct" - // peer is a connection established through a peer proxy. - peer dialType = "peer" - // tunnel is a connection established over a local reverse tunnel initiated + // dialTypeDirect is a direct dialed connection. + dialTypeDirect dialType = "direct" + // dialTypePeer is a connection established through a peer proxy. + dialTypePeer dialType = "peer" + // dialTypeTunnel is a connection established over a local reverse tunnel initiated // by a client. - tunnel dialType = "tunnel" - // peerTunnel is a connection established over a local reverse tunnel + dialTypeTunnel dialType = "tunnel" + // dialTypePeerTunnel is a connection established over a local reverse tunnel // initiated by a peer proxy. - peerTunnel dialType = "peer-tunnel" + dialTypePeerTunnel dialType = "peer-tunnel" ) // metricConn reports metrics for reversetunnel connections. diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 9e6bdf2b08235..558f9762543f8 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -36,7 +36,7 @@ import ( "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/observability/metrics" - "github.com/gravitational/teleport/lib/proxy" + "github.com/gravitational/teleport/lib/proxy/peer" "github.com/gravitational/teleport/lib/reversetunnel/track" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/forward" @@ -150,7 +150,7 @@ type localSite struct { offlineThreshold time.Duration // peerClient is the proxy peering client - peerClient *proxy.Client + peerClient *peer.Client // periodicFunctionInterval defines the interval period functions run at periodicFunctionInterval time.Duration @@ -447,9 +447,9 @@ func (s *localSite) getConn(params DialParams) (conn net.Conn, useTunnel bool, e // return a tunnel connection to that node. Otherwise net.Dial to the target host. conn, tunnelErr = s.dialTunnel(dreq) if tunnelErr == nil { - dt := tunnel + dt := dialTypeTunnel if params.FromPeerProxy { - dt = peerTunnel + dt = dialTypePeerTunnel } return newMetricConn(conn, dt, dialStart, s.srv.Clock), true, nil @@ -462,7 +462,7 @@ func (s *localSite) getConn(params DialParams) (conn net.Conn, useTunnel bool, e params.ProxyIDs, params.ServerID, params.From, params.To, params.ConnType, ) if peerErr == nil { - return newMetricConn(conn, peer, dialStart, s.srv.Clock), true, nil + return newMetricConn(conn, dialTypePeer, dialStart, s.srv.Clock), true, nil } s.log.WithError(peerErr).WithField("address", dreq.Address).Debug("Error occurred while dialing over peer proxy.") } @@ -494,7 +494,7 @@ func (s *localSite) getConn(params DialParams) (conn net.Conn, useTunnel bool, e } // Return a direct dialed connection. - return newMetricConn(conn, direct, dialStart, s.srv.Clock), false, nil + return newMetricConn(conn, dialTypeDirect, dialStart, s.srv.Clock), false, nil } func (s *localSite) addConn(nodeID string, connType types.TunnelType, conn net.Conn, sconn ssh.Conn) (*remoteConn, error) { diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 63701b819a018..df7d5dcda0e89 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -43,7 +43,7 @@ import ( "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/observability/metrics" - "github.com/gravitational/teleport/lib/proxy" + "github.com/gravitational/teleport/lib/proxy/peer" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" @@ -191,7 +191,7 @@ type Config struct { NewCachingAccessPointOldProxy auth.NewRemoteProxyCachingAccessPoint // PeerClient is a client to peer proxy servers. - PeerClient *proxy.Client + PeerClient *peer.Client // LockWatcher is a lock watcher. LockWatcher *services.LockWatcher @@ -985,7 +985,7 @@ func (s *server) GetSite(name string) (RemoteSite, error) { } // GetProxyPeerClient returns the proxy peer client -func (s *server) GetProxyPeerClient() *proxy.Client { +func (s *server) GetProxyPeerClient() *peer.Client { return s.PeerClient } diff --git a/lib/service/service.go b/lib/service/service.go index 57e7c149fa5ac..01b3735e7293e 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -95,6 +95,7 @@ import ( "github.com/gravitational/teleport/lib/plugin" "github.com/gravitational/teleport/lib/proxy" "github.com/gravitational/teleport/lib/proxy/clusterdial" + "github.com/gravitational/teleport/lib/proxy/peer" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" @@ -3376,11 +3377,11 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { // register SSH reverse tunnel server that accepts connections // from remote teleport nodes var tsrv reversetunnel.Server - var peerClient *proxy.Client + var peerClient *peer.Client if !process.Config.Proxy.DisableReverseTunnel { if listeners.proxy != nil { - peerClient, err = proxy.NewClient(proxy.ClientConfig{ + peerClient, err = peer.NewClient(peer.ClientConfig{ Context: process.ExitContext(), ID: process.Config.HostUUID, AuthClient: conn.Client, @@ -3441,6 +3442,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return nil }) } + if !process.Config.Proxy.DisableTLS { tlsConfigWeb, err = process.setupProxyTLSConfig(conn, tsrv, accessPoint, clusterName) if err != nil { @@ -3559,14 +3561,14 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } var peerAddrString string - var proxyServer *proxy.Server + var proxyServer *peer.Server if !process.Config.Proxy.DisableReverseTunnel && listeners.proxy != nil { peerAddr, err := process.Config.Proxy.publicPeerAddr() if err != nil { return trace.Wrap(err) } peerAddrString = peerAddr.String() - proxyServer, err = proxy.NewServer(proxy.ServerConfig{ + proxyServer, err = peer.NewServer(peer.ServerConfig{ AccessCache: accessPoint, Listener: listeners.proxy, TLSConfig: serverTLSConfig, @@ -3593,6 +3595,21 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { }) } + var proxyRouter *proxy.Router + if !process.Config.Proxy.DisableReverseTunnel { + router, err := proxy.NewRouter(proxy.RouterConfig{ + ClusterName: clusterName, + Log: process.log.WithField(trace.Component, "router"), + RemoteClusterGetter: accessPoint, + SiteGetter: tsrv, + TracerProvider: process.TracingProvider, + }) + if err != nil { + return trace.Wrap(err) + } + proxyRouter = router + } + sshProxy, err := regular.New(cfg.Proxy.SSHAddr, cfg.Hostname, []ssh.Signer{conn.ServerIdentity.KeySigner}, @@ -3602,7 +3619,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { process.proxyPublicAddr(), conn.Client, regular.SetLimiter(proxyLimiter), - regular.SetProxyMode(peerAddrString, tsrv, accessPoint), + regular.SetProxyMode(peerAddrString, tsrv, accessPoint, proxyRouter), regular.SetCiphers(cfg.Ciphers), regular.SetKEXAlgorithms(cfg.KEXAlgorithms), regular.SetMACAlgorithms(cfg.MACAlgorithms), diff --git a/lib/srv/regular/proxy.go b/lib/srv/regular/proxy.go index 9baae13730055..76283f86f528a 100644 --- a/lib/srv/regular/proxy.go +++ b/lib/srv/regular/proxy.go @@ -23,26 +23,19 @@ import ( "fmt" "io" "net" - "strconv" "strings" - "sync" - "github.com/google/uuid" "github.com/gravitational/trace" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" - "golang.org/x/exp/slices" "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/observability/tracing" - "github.com/gravitational/teleport/api/types" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" - "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/observability/metrics" - "github.com/gravitational/teleport/lib/reversetunnel" - "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/proxy" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" @@ -70,20 +63,20 @@ var ( // failedConnectingToNode counts failed attempts to connect to a node Help: "Number of SSH connection attempts to a node. Use with `failed_connect_to_node_attempts_total` to get the failure rate.", }, ) - - prometheusCollectors = []prometheus.Collector{proxiedSessions, failedConnectingToNode, connectingToNode} ) +func init() { + metrics.RegisterPrometheusCollectors(proxiedSessions, failedConnectingToNode, connectingToNode) +} + // proxySubsys implements an SSH subsystem for proxying listening sockets from // remote hosts to a proxy client (AKA port mapping) type proxySubsys struct { proxySubsysRequest - srv *Server - ctx *srv.ServerContext - log *logrus.Entry - closeC chan struct{} - error error - closeOnce sync.Once + router *proxy.Router + ctx *srv.ServerContext + log *logrus.Entry + closeC chan error } // parseProxySubsys looks at the requested subsystem name and returns a fully configured @@ -110,8 +103,9 @@ func parseProxySubsysRequest(request string) (proxySubsysRequest, error) { } requestBody := strings.TrimPrefix(request, prefix) namespace := apidefaults.Namespace - var err error parts := strings.Split(requestBody, "@") + + var err error switch { case len(parts) == 0: // "proxy:" return proxySubsysRequest{}, trace.BadParameter(paramMessage) @@ -190,11 +184,6 @@ func (p *proxySubsysRequest) SetDefaults() { // a port forwarding request, used to implement ProxyJump feature in proxy // and reuse the code func newProxySubsys(ctx *srv.ServerContext, srv *Server, req proxySubsysRequest) (*proxySubsys, error) { - err := metrics.RegisterPrometheusCollectors(prometheusCollectors...) - if err != nil { - return nil, trace.Wrap(err) - } - req.SetDefaults() if req.clusterName == "" && ctx.Identity.RouteToCluster != "" { log.Debugf("Proxy subsystem: routing user %q to cluster %q based on the route to cluster extension.", @@ -212,12 +201,12 @@ func newProxySubsys(ctx *srv.ServerContext, srv *Server, req proxySubsysRequest) return &proxySubsys{ proxySubsysRequest: req, ctx: ctx, - srv: srv, log: logrus.WithFields(logrus.Fields{ trace.Component: teleport.ComponentSubsystemProxy, trace.ComponentFields: map[string]string{}, }), - closeC: make(chan struct{}), + closeC: make(chan error), + router: srv.router, }, nil } @@ -239,12 +228,8 @@ func (t *proxySubsys) Start(ctx context.Context, sconn *ssh.ServerConn, ch ssh.C }) t.log.Debugf("Starting subsystem") - var ( - site reversetunnel.RemoteSite - err error - tunnel = t.srv.tunnelWithAccessChecker(serverContext) - clientAddr = sconn.RemoteAddr() - ) + clientAddr := sconn.RemoteAddr() + // did the client pass us a true client IP ahead of time via an environment variable? // (usually the web client would do that) trueClientIP, ok := serverContext.GetEnv(sshutils.TrueClientAddrVar) @@ -254,185 +239,41 @@ func (t *proxySubsys) Start(ctx context.Context, sconn *ssh.ServerConn, ch ssh.C clientAddr = a } } - // get the cluster by name: - if t.clusterName != "" { - site, err = tunnel.GetSite(t.clusterName) - if err != nil { - t.log.Warn(err) - return trace.Wrap(err) - } - } - // connecting to a specific host: - if t.host != "" { - // no site given? use the 1st one: - if site == nil { - sites, err := tunnel.GetSites() - if err != nil { - return trace.Wrap(err) - } - if len(sites) == 0 { - t.log.Error("Not connected to any remote clusters") - return trace.NotFound("no connected sites") - } - site = sites[0] - t.clusterName = site.GetName() - t.log.Debugf("Cluster not specified. connecting to default='%s'", site.GetName()) - } - return t.proxyToHost(ctx, site, clientAddr, ch) + + // connect to a site's auth server + if t.host == "" { + return t.proxyToSite(ctx, ch, t.clusterName) } - // connect to a site's auth server: - return t.proxyToSite(serverContext, site, clientAddr, ch) + + // connect to a server + return t.proxyToHost(ctx, ch, clientAddr) } // proxyToSite establishes a proxy connection from the connected SSH client to the // auth server of the requested remote site -func (t *proxySubsys) proxyToSite( - ctx *srv.ServerContext, site reversetunnel.RemoteSite, remoteAddr net.Addr, ch ssh.Channel) error { - conn, err := site.DialAuthServer() +func (t *proxySubsys) proxyToSite(ctx context.Context, ch ssh.Channel, clusterName string) error { + t.log.Debugf("Connecting to site: %v", clusterName) + + conn, err := t.router.DialSite(ctx, clusterName) if err != nil { return trace.Wrap(err) } - t.log.Infof("Connected to auth server: %v", conn.RemoteAddr()) + t.log.Infof("Connected to cluster %v at %v", clusterName, conn.RemoteAddr()) proxiedSessions.Inc() + go func() { - var err error - defer func() { - t.close(err) - }() - defer ch.Close() - _, err = io.Copy(ch, conn) - }() - go func() { - var err error - defer func() { - t.close(err) - }() - defer conn.Close() - _, err = io.Copy(conn, ch) + t.close(utils.ProxyConn(ctx, ch, conn)) }() - return nil } // proxyToHost establishes a proxy connection from the connected SSH client to the // requested remote node (t.host:t.port) via the given site -func (t *proxySubsys) proxyToHost( - ctx context.Context, site reversetunnel.RemoteSite, remoteAddr net.Addr, ch ssh.Channel) error { - // - // first, lets fetch a list of servers at the given site. this allows us to - // match the given "host name" against node configuration (their 'nodename' setting) - // - // but failing to fetch the list of servers is also OK, we'll use standard - // network resolution (by IP or DNS) - // - var ( - strategy types.RoutingStrategy - nodeWatcher NodesGetter - err error - ) - localCluster, _ := t.srv.proxyAccessPoint.GetClusterName() - // going to "local" CA? lets use the caching 'auth service' directly and avoid - // hitting the reverse tunnel link (it can be offline if the CA is down) - if site.GetName() == localCluster.GetName() { - nodeWatcher = t.srv.nodeWatcher +func (t *proxySubsys) proxyToHost(ctx context.Context, ch ssh.Channel, remoteAddr net.Addr) error { + t.log.Debugf("proxy connecting to host=%v port=%v, exact port=%v", t.host, t.port, t.SpecifiedPort()) - cfg, err := t.srv.authService.GetClusterNetworkingConfig(ctx) - if err != nil { - t.log.Warn(err) - } else { - strategy = cfg.GetRoutingStrategy() - } - } else { - // "remote" CA? use a reverse tunnel to talk to it: - siteClient, err := site.CachingAccessPoint() - if err != nil { - t.log.Warn(err) - } else { - watcher, err := site.NodeWatcher() - if err != nil { - t.log.Warn(err) - } else { - nodeWatcher = watcher - } - - cfg, err := siteClient.GetClusterNetworkingConfig(ctx) - if err != nil { - t.log.Warn(err) - } else { - strategy = cfg.GetRoutingStrategy() - } - } - } - - // if port is 0, it means the client wants us to figure out - // which port to use - t.log.Debugf("proxy connecting to host=%v port=%v, exact port=%v, strategy=%s", t.host, t.port, t.SpecifiedPort(), strategy) - - // determine which server to connect to - server, err := t.getMatchingServer(nodeWatcher, strategy) - if err != nil { - return trace.Wrap(err) - } - - // Create a slice of principals that will be added into the host certificate. - // Here t.host is either an IP address or a DNS name as the user requested. - principals := []string{t.host} - - // Used to store the server ID (hostUUID.clusterName) of a Teleport node. - var serverID string - - // Resolve the IP address to dial to because the hostname may not be - // DNS resolvable. - var ( - serverAddr string - proxyIDs []string - ) - - if server != nil { - // Add hostUUID.clusterName to list of principals. - serverID = fmt.Sprintf("%v.%v", server.GetName(), t.clusterName) - principals = append(principals, serverID) - proxyIDs = server.GetProxyIDs() - - // Add IP address (if it exists) of the node to list of principals. - serverAddr = server.GetAddr() - if serverAddr != "" { - host, _, err := net.SplitHostPort(serverAddr) - if err != nil { - return trace.Wrap(err) - } - principals = append(principals, host) - } else if server.GetUseTunnel() { - serverAddr = reversetunnel.LocalNode - } - } else { - if !t.SpecifiedPort() { - t.port = strconv.Itoa(defaults.SSHServerListenPort) - } - serverAddr = net.JoinHostPort(t.host, t.port) - t.log.Warnf("server lookup failed: using default=%v", serverAddr) - } - - // Pass the agent along to the site. If the proxy is in recording mode, this - // agent is used to perform user authentication. Pass the DNS name to the - // dialer as well so the forwarding proxy can generate a host certificate - // with the correct hostname). - toAddr := &utils.NetAddr{ - AddrNetwork: "tcp", - Addr: serverAddr, - } - connectingToNode.Inc() - conn, err := site.Dial(reversetunnel.DialParams{ - From: remoteAddr, - To: toAddr, - GetUserAgent: t.ctx.StartAgentChannel, - Address: t.host, - ServerID: serverID, - ProxyIDs: proxyIDs, - Principals: principals, - ConnType: types.NodeTunnel, - }) + conn, err := t.router.DialHost(ctx, remoteAddr, t.host, t.port, t.clusterName, t.ctx.Identity.AccessChecker, t.ctx.StartAgentChannel) if err != nil { failedConnectingToNode.Inc() return trace.Wrap(err) @@ -443,126 +284,20 @@ func (t *proxySubsys) proxyToHost( t.doHandshake(ctx, remoteAddr, ch, conn) proxiedSessions.Inc() + go func() { - var err error - defer func() { - t.close(err) - }() - defer ch.Close() - _, err = io.Copy(ch, conn) - }() - go func() { - var err error - defer func() { - t.close(err) - }() - defer conn.Close() - _, err = io.Copy(conn, ch) + t.close(utils.ProxyConn(ctx, ch, conn)) }() - return nil } -// NodesGetter is a function that retrieves a subset of nodes matching -// the filter criteria. -type NodesGetter interface { - GetNodes(fn func(n services.Node) bool) []types.Server -} - -// getMatchingServer determines the server to connect to from the provided servers. Duplicate entries are treated -// differently based on strategy. Legacy behavior of returning an ambiguous error occurs if the strategy -// is types.RoutingStrategy_UNAMBIGUOUS_MATCH. When the strategy is types.RoutingStrategy_MOST_RECENT then -// the server that has heartbeated most recently will be returned instead of an error. If no matches are found then -// both the types.Server and error returned will be nil. -func (t *proxySubsys) getMatchingServer(watcher NodesGetter, strategy types.RoutingStrategy) (types.Server, error) { - if watcher == nil { - return nil, trace.NotFound("unable to retrieve nodes matching host %s", t.host) - } - - // check if hostname is a valid uuid or EC2 node ID. If it is, we will - // preferentially match by node ID over node hostname. - _, err := uuid.Parse(t.host) - hostIsUniqueID := err == nil || utils.IsEC2NodeID(t.host) - - ips, _ := net.LookupHost(t.host) - - var unambiguousIDMatch bool - // enumerate and try to find a server with self-registered with a matching name/IP: - matches := watcher.GetNodes(func(server services.Node) bool { - if unambiguousIDMatch { - return false - } - - // If the host parameter is a UUID or EC2 node ID, and it matches the - // Node ID, treat this as an unambiguous match. - if hostIsUniqueID && server.GetName() == t.host { - unambiguousIDMatch = true - return true - } - // If the server has connected over a reverse tunnel, match only on hostname. - if server.GetUseTunnel() { - return t.host == server.GetHostname() - } - - ip, port, err := net.SplitHostPort(server.GetAddr()) - if err != nil { - t.log.Errorf("Failed to parse address %q: %v.", server.GetAddr(), err) - return false - } - if t.host == ip || t.host == server.GetHostname() || slices.Contains(ips, ip) { - if !t.SpecifiedPort() || t.port == port { - return true - } - } - return false - }) - - var server types.Server - switch { - case strategy == types.RoutingStrategy_MOST_RECENT: - // find the most recent of all the matches - for _, m := range matches { - if server == nil || m.Expiry().After(server.Expiry()) { - server = m - } - } - case len(matches) > 1: - // if we matched more than one server, then the target was ambiguous. - return nil, trace.NotFound(teleport.NodeIsAmbiguous) - case len(matches) == 1: - server = matches[0] - } - - // If we matched zero nodes but hostname is a UUID (or EC2 node ID) then it - // isn't sane to fallback to dns based resolution. This has the unfortunate - // consequence of preventing users from calling OpenSSH nodes which happen - // to use hostnames which are also valid UUIDs. This restriction is - // necessary in order to protect users attempting to connect to a node by - // UUID from being re-routed to an unintended target if the node is offline. - // This restriction can be lifted if we decide to move to explicit UUID - // based resolution in the future. - if hostIsUniqueID && server == nil { - idType := "UUID" - if utils.IsEC2NodeID(t.host) { - idType = "EC2" - } - return nil, trace.NotFound("unable to locate node matching %s-like target %s", idType, t.host) - } - - return server, nil -} - func (t *proxySubsys) close(err error) { - t.closeOnce.Do(func() { - proxiedSessions.Dec() - t.error = err - close(t.closeC) - }) + proxiedSessions.Dec() + t.closeC <- err } func (t *proxySubsys) Wait() error { - <-t.closeC - return t.error + return <-t.closeC } // doHandshake allows a proxy server to send additional information (client IP) diff --git a/lib/srv/regular/proxy_test.go b/lib/srv/regular/proxy_test.go index d6259bbb209cf..8fef4a19d7186 100644 --- a/lib/srv/regular/proxy_test.go +++ b/lib/srv/regular/proxy_test.go @@ -18,14 +18,10 @@ package regular import ( "testing" - "time" - "github.com/google/uuid" "github.com/stretchr/testify/require" apidefaults "github.com/gravitational/teleport/api/defaults" - "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" ) @@ -130,210 +126,3 @@ func TestParseBadRequests(t *testing.T) { }) } } - -type nodeGetter struct { - servers []types.Server -} - -func (n nodeGetter) GetNodes(fn func(n services.Node) bool) []types.Server { - var servers []types.Server - for _, s := range n.servers { - if fn(s) { - servers = append(servers, s) - } - } - - return servers -} - -func TestProxySubsys_getMatchingServer(t *testing.T) { - t.Parallel() - - serverUUID := uuid.NewString() - - ec2NodeID := "123456789012-i-abcdef12345678901" - - setExpiry := func(time time.Time) func(server types.Server) { - return func(server types.Server) { - server.SetExpiry(time) - } - } - - createServer := func(name string, spec types.ServerSpecV2, opts ...func(server types.Server)) types.Server { - t.Helper() - - server, err := types.NewServer(name, types.KindNode, spec) - require.NoError(t, err) - - for _, opt := range opts { - opt(server) - } - - return server - } - - servers := []types.Server{ - createServer(serverUUID, types.ServerSpecV2{ - Hostname: "127.0.0.1", - Addr: "127.0.0.1:80", - }, setExpiry(time.Now().Add(-time.Hour))), - createServer("server2", types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:80", - }, setExpiry(time.Now().Add(time.Hour*24))), - createServer("server3", types.ServerSpecV2{ - Hostname: serverUUID, - Addr: "127.0.0.1:", - }), - createServer(ec2NodeID, types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:", - }), - } - - cases := []struct { - desc string - req proxySubsysRequest - strategy types.RoutingStrategy - servers []types.Server - expectError require.ErrorAssertionFunc - expectServer func(servers []types.Server) types.Server - }{ - { - desc: "No matches found", - expectError: require.NoError, - }, - { - desc: "No matches found for UUID host", - expectError: require.Error, - servers: []types.Server{createServer(uuid.NewString(), types.ServerSpecV2{ - Addr: "127.0.0.1:0", - })}, - req: proxySubsysRequest{ - host: uuid.NewString(), - }, - }, - { - desc: "Match by UUID", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[0] - }, - servers: servers, - req: proxySubsysRequest{ - host: serverUUID, - }, - }, - { - desc: "Match by EC2 ID", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[3] - }, - servers: servers, - req: proxySubsysRequest{ - host: ec2NodeID, - }, - }, - { - desc: "Match Tunnel By Host Only", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[0] - }, - servers: []types.Server{ - createServer("server1", types.ServerSpecV2{ - Addr: "127.0.0.1", - Hostname: "127.0.0.1", - UseTunnel: true, - }), - createServer("server2", types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:80", - UseTunnel: true, - }), - }, - req: proxySubsysRequest{ - host: "127.0.0.1", - port: "80", - }, - }, - { - desc: "Match by IP", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[1] - }, - servers: []types.Server{ - createServer("server1", types.ServerSpecV2{ - Addr: "127.0.0.1:0", - Hostname: "127.0.0.1", - }), - createServer("server2", types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:80", - }), - }, - req: proxySubsysRequest{ - host: "127.0.0.1", - port: "80", - }, - }, - { - desc: "Match by hostname", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[1] - }, - servers: []types.Server{ - createServer("server1", types.ServerSpecV2{ - Addr: "127.0.0.1:0", - Hostname: "localhost", - }), - createServer("server2", types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:80", - }), - }, - req: proxySubsysRequest{ - host: "localhost", - port: "80", - }, - }, - { - desc: "Ambiguous match", - expectError: require.Error, - servers: servers, - req: proxySubsysRequest{ - host: "localhost", - }, - }, - { - desc: "Most Recent match", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[1] - }, - servers: servers, - strategy: types.RoutingStrategy_MOST_RECENT, - req: proxySubsysRequest{ - host: "localhost", - }, - }, - } - - for _, tt := range cases { - t.Run(tt.desc, func(t *testing.T) { - subsystem := proxySubsys{ - proxySubsysRequest: tt.req, - srv: &Server{}, - } - - server, err := subsystem.getMatchingServer(nodeGetter{tt.servers}, tt.strategy) - tt.expectError(t, err) - if tt.expectServer != nil { - require.Equal(t, tt.expectServer(tt.servers), server) - } - }) - } -} diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 65b587e867140..b07c5c79003c6 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -55,6 +55,7 @@ import ( "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/pam" + "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" @@ -223,6 +224,10 @@ type Server struct { // tracerProvider is used to create tracers capable // of starting spans. tracerProvider oteltrace.TracerProvider + + // router used by subsystem requests to connect to nodes + // and clusters + router *proxy.Router } // TargetMetadata returns metadata about the server. @@ -422,7 +427,7 @@ func SetShell(shell string) ServerOption { } // SetProxyMode starts this server in SSH proxying mode -func SetProxyMode(peerAddr string, tsrv reversetunnel.Tunnel, ap auth.ReadProxyAccessPoint) ServerOption { +func SetProxyMode(peerAddr string, tsrv reversetunnel.Tunnel, ap auth.ReadProxyAccessPoint, router *proxy.Router) ServerOption { return func(s *Server) error { // always set proxy mode to true, // because in some tests reverse tunnel is disabled, @@ -431,6 +436,7 @@ func SetProxyMode(peerAddr string, tsrv reversetunnel.Tunnel, ap auth.ReadProxyA s.proxyTun = tsrv s.proxyAccessPoint = ap s.peerAddr = peerAddr + s.router = router return nil } } diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index d8d86c91e0e2e..e4ac44763c481 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -59,7 +59,9 @@ import ( "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/limiter" + "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/pam" + libproxy "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" @@ -1383,6 +1385,15 @@ func TestProxyRoundRobin(t *testing.T) { require.NoError(t, reverseTunnelServer.Start()) defer reverseTunnelServer.Close() + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: f.testSrv.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: proxyClient, + SiteGetter: reverseTunnelServer, + TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + proxy, err := New( utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), @@ -1392,7 +1403,7 @@ func TestProxyRoundRobin(t *testing.T) { "", utils.NetAddr{}, proxyClient, - SetProxyMode("", reverseTunnelServer, proxyClient), + SetProxyMode("", reverseTunnelServer, proxyClient, router), SetEmitter(nodeClient), SetNamespace(apidefaults.Namespace), SetPAMConfig(&pam.Config{Enabled: false}), @@ -1503,6 +1514,15 @@ func TestProxyDirectAccess(t *testing.T) { nodeClient, _ := newNodeClient(t, f.testSrv) + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: f.testSrv.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: proxyClient, + SiteGetter: reverseTunnelServer, + TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + proxy, err := New( utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), @@ -1512,7 +1532,7 @@ func TestProxyDirectAccess(t *testing.T) { "", utils.NetAddr{}, proxyClient, - SetProxyMode("", reverseTunnelServer, proxyClient), + SetProxyMode("", reverseTunnelServer, proxyClient, router), SetEmitter(nodeClient), SetNamespace(apidefaults.Namespace), SetPAMConfig(&pam.Config{Enabled: false}), @@ -2233,6 +2253,15 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { nodeClient, _ := newNodeClient(t, f.testSrv) + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: f.testSrv.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: proxyClient, + SiteGetter: reverseTunnelServer, + TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + proxy, err := New( utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), @@ -2242,7 +2271,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { "", utils.NetAddr{}, proxyClient, - SetProxyMode("", reverseTunnelServer, proxyClient), + SetProxyMode("", reverseTunnelServer, proxyClient, router), SetEmitter(nodeClient), SetNamespace(apidefaults.Namespace), SetPAMConfig(&pam.Config{Enabled: false}), diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 5bf9d6eec8ac1..645968c846383 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -95,7 +95,9 @@ import ( kubeproxy "github.com/gravitational/teleport/lib/kube/proxy" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/pam" + libproxy "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/secret" @@ -329,6 +331,15 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { require.NoError(t, err) s.proxyTunnel = revTunServer + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: s.server.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: s.proxyClient, + SiteGetter: revTunServer, + TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + // proxy server: s.proxy, err = regular.New( utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, @@ -340,7 +351,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { utils.NetAddr{}, s.proxyClient, regular.SetUUID(proxyID), - regular.SetProxyMode("", revTunServer, s.proxyClient), + regular.SetProxyMode("", revTunServer, s.proxyClient, router), regular.SetEmitter(s.proxyClient), regular.SetNamespace(apidefaults.Namespace), regular.SetBPF(&bpf.NOP{}), @@ -5847,6 +5858,15 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula require.NoError(t, err) t.Cleanup(func() { require.NoError(t, revTunServer.Close()) }) + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: authServer.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: client, + SiteGetter: revTunServer, + TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + proxyServer, err := regular.New( utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, authServer.ClusterName(), @@ -5857,7 +5877,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula utils.NetAddr{AddrNetwork: "tcp", Addr: "proxy-1.example.com:443"}, client, regular.SetUUID(proxyID), - regular.SetProxyMode("", revTunServer, client), + regular.SetProxyMode("", revTunServer, client, router), regular.SetEmitter(client), regular.SetNamespace(apidefaults.Namespace), regular.SetBPF(&bpf.NOP{}),