From bd22444bc8c3d2bb4ee31b22d93c31f4500c3826 Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Wed, 9 Nov 2022 17:35:20 -0500 Subject: [PATCH 1/3] create package to contain proxy peering code --- lib/proxy/clusterdial/dial.go | 12 ++++++------ lib/proxy/{ => peer}/auth.go | 2 +- lib/proxy/{ => peer}/client.go | 2 +- lib/proxy/{ => peer}/client_test.go | 2 +- lib/proxy/{ => peer}/clientmetrics.go | 2 +- lib/proxy/{ => peer}/conn.go | 2 +- lib/proxy/{ => peer}/conn_test.go | 2 +- lib/proxy/{ => peer}/helpers_test.go | 2 +- lib/proxy/{ => peer}/interceptor.go | 2 +- lib/proxy/{ => peer}/reporter.go | 2 +- lib/proxy/{ => peer}/server.go | 2 +- lib/proxy/{ => peer}/server_test.go | 2 +- lib/proxy/{ => peer}/servermetrics.go | 2 +- lib/proxy/{ => peer}/service.go | 2 +- lib/proxy/{ => peer}/service_test.go | 2 +- lib/proxy/{ => peer}/stats.go | 4 ++-- lib/reversetunnel/api.go | 4 ++-- lib/reversetunnel/conn_metric.go | 17 +++++++++-------- lib/reversetunnel/localsite.go | 12 ++++++------ lib/reversetunnel/srv.go | 6 +++--- lib/service/service.go | 10 +++++----- 21 files changed, 47 insertions(+), 46 deletions(-) rename lib/proxy/{ => peer}/auth.go (99%) rename lib/proxy/{ => peer}/client.go (99%) rename lib/proxy/{ => peer}/client_test.go (99%) rename lib/proxy/{ => peer}/clientmetrics.go (99%) rename lib/proxy/{ => peer}/conn.go (99%) rename lib/proxy/{ => peer}/conn_test.go (99%) rename lib/proxy/{ => peer}/helpers_test.go (99%) rename lib/proxy/{ => peer}/interceptor.go (99%) rename lib/proxy/{ => peer}/reporter.go (99%) rename lib/proxy/{ => peer}/server.go (99%) rename lib/proxy/{ => peer}/server_test.go (99%) rename lib/proxy/{ => peer}/servermetrics.go (99%) rename lib/proxy/{ => peer}/service.go (99%) rename lib/proxy/{ => peer}/service_test.go (99%) rename lib/proxy/{ => peer}/stats.go (98%) diff --git a/lib/proxy/clusterdial/dial.go b/lib/proxy/clusterdial/dial.go index 13822db268f72..59a01e9a97dd8 100644 --- a/lib/proxy/clusterdial/dial.go +++ b/lib/proxy/clusterdial/dial.go @@ -19,21 +19,21 @@ import ( "github.com/gravitational/trace" - "github.com/gravitational/teleport/lib/proxy" + "github.com/gravitational/teleport/lib/proxy/peer" "github.com/gravitational/teleport/lib/reversetunnel" ) -// ClusterDialerFunc is a function that implements a proxy.ClusterDialer. -type ClusterDialerFunc func(clusterName string, request proxy.DialParams) (net.Conn, error) +// ClusterDialerFunc is a function that implements a peer.ClusterDialer. +type ClusterDialerFunc func(clusterName string, request peer.DialParams) (net.Conn, error) // Dial dials makes a dial request to the given cluster. -func (f ClusterDialerFunc) Dial(clusterName string, request proxy.DialParams) (net.Conn, error) { +func (f ClusterDialerFunc) Dial(clusterName string, request peer.DialParams) (net.Conn, error) { return f(clusterName, request) } // NewClusterDialer implements proxy.ClusterDialer for a reverse tunnel server. func NewClusterDialer(server reversetunnel.Server) ClusterDialerFunc { - return ClusterDialerFunc(func(clusterName string, request proxy.DialParams) (net.Conn, error) { + return func(clusterName string, request peer.DialParams) (net.Conn, error) { site, err := server.GetSite(clusterName) if err != nil { return nil, trace.Wrap(err) @@ -52,5 +52,5 @@ func NewClusterDialer(server reversetunnel.Server) ClusterDialerFunc { return nil, trace.Wrap(err) } return conn, nil - }) + } } diff --git a/lib/proxy/auth.go b/lib/proxy/peer/auth.go similarity index 99% rename from lib/proxy/auth.go rename to lib/proxy/peer/auth.go index 622258aa2f22a..537f714dd1ef3 100644 --- a/lib/proxy/auth.go +++ b/lib/proxy/peer/auth.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/client.go b/lib/proxy/peer/client.go similarity index 99% rename from lib/proxy/client.go rename to lib/proxy/peer/client.go index cd83b05a6caf4..b80dc83a77e4e 100644 --- a/lib/proxy/client.go +++ b/lib/proxy/peer/client.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/client_test.go b/lib/proxy/peer/client_test.go similarity index 99% rename from lib/proxy/client_test.go rename to lib/proxy/peer/client_test.go index 6253a2c3c8764..fa942c372ba3e 100644 --- a/lib/proxy/client_test.go +++ b/lib/proxy/peer/client_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "crypto/tls" diff --git a/lib/proxy/clientmetrics.go b/lib/proxy/peer/clientmetrics.go similarity index 99% rename from lib/proxy/clientmetrics.go rename to lib/proxy/peer/clientmetrics.go index da6c0c0676143..e67d571d47e7f 100644 --- a/lib/proxy/clientmetrics.go +++ b/lib/proxy/peer/clientmetrics.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "github.com/gravitational/trace" diff --git a/lib/proxy/conn.go b/lib/proxy/peer/conn.go similarity index 99% rename from lib/proxy/conn.go rename to lib/proxy/peer/conn.go index 1fc1f5d5bc16e..0e7bf1bfde5a4 100644 --- a/lib/proxy/conn.go +++ b/lib/proxy/peer/conn.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/conn_test.go b/lib/proxy/peer/conn_test.go similarity index 99% rename from lib/proxy/conn_test.go rename to lib/proxy/peer/conn_test.go index 751e61d1fc968..fbe0bfbbba334 100644 --- a/lib/proxy/conn_test.go +++ b/lib/proxy/peer/conn_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/helpers_test.go b/lib/proxy/peer/helpers_test.go similarity index 99% rename from lib/proxy/helpers_test.go rename to lib/proxy/peer/helpers_test.go index afb2d10b43ef8..60036def48a47 100644 --- a/lib/proxy/helpers_test.go +++ b/lib/proxy/peer/helpers_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "crypto/rand" diff --git a/lib/proxy/interceptor.go b/lib/proxy/peer/interceptor.go similarity index 99% rename from lib/proxy/interceptor.go rename to lib/proxy/peer/interceptor.go index de838322c48ff..19f103b9b8aac 100644 --- a/lib/proxy/interceptor.go +++ b/lib/proxy/peer/interceptor.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/reporter.go b/lib/proxy/peer/reporter.go similarity index 99% rename from lib/proxy/reporter.go rename to lib/proxy/peer/reporter.go index 068d2b09b77d2..3656fc1f99899 100644 --- a/lib/proxy/reporter.go +++ b/lib/proxy/peer/reporter.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "time" diff --git a/lib/proxy/server.go b/lib/proxy/peer/server.go similarity index 99% rename from lib/proxy/server.go rename to lib/proxy/peer/server.go index 0314e2bd57db2..82a4671a3e629 100644 --- a/lib/proxy/server.go +++ b/lib/proxy/peer/server.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "crypto/tls" diff --git a/lib/proxy/server_test.go b/lib/proxy/peer/server_test.go similarity index 99% rename from lib/proxy/server_test.go rename to lib/proxy/peer/server_test.go index 546f99de288c9..66a4845ec0bf2 100644 --- a/lib/proxy/server_test.go +++ b/lib/proxy/peer/server_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "testing" diff --git a/lib/proxy/servermetrics.go b/lib/proxy/peer/servermetrics.go similarity index 99% rename from lib/proxy/servermetrics.go rename to lib/proxy/peer/servermetrics.go index 5adc1c4f4bc90..bd7bd85170229 100644 --- a/lib/proxy/servermetrics.go +++ b/lib/proxy/peer/servermetrics.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "github.com/gravitational/trace" diff --git a/lib/proxy/service.go b/lib/proxy/peer/service.go similarity index 99% rename from lib/proxy/service.go rename to lib/proxy/peer/service.go index 053c098fea3de..600ef470cb922 100644 --- a/lib/proxy/service.go +++ b/lib/proxy/peer/service.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "net" diff --git a/lib/proxy/service_test.go b/lib/proxy/peer/service_test.go similarity index 99% rename from lib/proxy/service_test.go rename to lib/proxy/peer/service_test.go index e6f9b576548c3..0cba1b7e36f3b 100644 --- a/lib/proxy/service_test.go +++ b/lib/proxy/peer/service_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/stats.go b/lib/proxy/peer/stats.go similarity index 98% rename from lib/proxy/stats.go rename to lib/proxy/peer/stats.go index 0d37b741718e9..dc0cd89501308 100644 --- a/lib/proxy/stats.go +++ b/lib/proxy/peer/stats.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" @@ -51,7 +51,7 @@ func (s *statsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) con return ctx } -// HandleRPC implements per-Connection stats reporting. +// HandleConn implements per-Connection stats reporting. func (s *statsHandler) HandleConn(ctx context.Context, connStats stats.ConnStats) { // client connection stats are monitored by the monitor() function in client.go if connStats.IsClient() { diff --git a/lib/reversetunnel/api.go b/lib/reversetunnel/api.go index 2866533ac40f5..d7886fe7ee002 100644 --- a/lib/reversetunnel/api.go +++ b/lib/reversetunnel/api.go @@ -25,7 +25,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/proxy" + "github.com/gravitational/teleport/lib/proxy/peer" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleagent" ) @@ -140,7 +140,7 @@ type Server interface { // Wait waits for server to close all outstanding operations Wait() // GetProxyPeerClient returns the proxy peer client - GetProxyPeerClient() *proxy.Client + GetProxyPeerClient() *peer.Client } const ( diff --git a/lib/reversetunnel/conn_metric.go b/lib/reversetunnel/conn_metric.go index b8c39ed8c0b0e..d2913fded9873 100644 --- a/lib/reversetunnel/conn_metric.go +++ b/lib/reversetunnel/conn_metric.go @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package reversetunnel import ( @@ -27,16 +28,16 @@ import ( type dialType string const ( - // direct is a direct dialed connection. - direct dialType = "direct" - // peer is a connection established through a peer proxy. - peer dialType = "peer" - // tunnel is a connection established over a local reverse tunnel initiated + // dialType_direct is a direct dialed connection. + dialType_direct dialType = "direct" + // dialType_peer is a connection established through a peer proxy. + dialType_peer dialType = "peer" + // dialType_tunnel is a connection established over a local reverse tunnel initiated // by a client. - tunnel dialType = "tunnel" - // peerTunnel is a connection established over a local reverse tunnel + dialType_tunnel dialType = "tunnel" + // dialType_peerTunnel is a connection established over a local reverse tunnel // initiated by a peer proxy. - peerTunnel dialType = "peer-tunnel" + dialType_peerTunnel dialType = "peer-tunnel" ) // metricConn reports metrics for reversetunnel connections. diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 9e6bdf2b08235..24ae4642ae400 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -36,7 +36,7 @@ import ( "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/observability/metrics" - "github.com/gravitational/teleport/lib/proxy" + "github.com/gravitational/teleport/lib/proxy/peer" "github.com/gravitational/teleport/lib/reversetunnel/track" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/forward" @@ -150,7 +150,7 @@ type localSite struct { offlineThreshold time.Duration // peerClient is the proxy peering client - peerClient *proxy.Client + peerClient *peer.Client // periodicFunctionInterval defines the interval period functions run at periodicFunctionInterval time.Duration @@ -447,9 +447,9 @@ func (s *localSite) getConn(params DialParams) (conn net.Conn, useTunnel bool, e // return a tunnel connection to that node. Otherwise net.Dial to the target host. conn, tunnelErr = s.dialTunnel(dreq) if tunnelErr == nil { - dt := tunnel + dt := dialType_tunnel if params.FromPeerProxy { - dt = peerTunnel + dt = dialType_peerTunnel } return newMetricConn(conn, dt, dialStart, s.srv.Clock), true, nil @@ -462,7 +462,7 @@ func (s *localSite) getConn(params DialParams) (conn net.Conn, useTunnel bool, e params.ProxyIDs, params.ServerID, params.From, params.To, params.ConnType, ) if peerErr == nil { - return newMetricConn(conn, peer, dialStart, s.srv.Clock), true, nil + return newMetricConn(conn, dialType_peer, dialStart, s.srv.Clock), true, nil } s.log.WithError(peerErr).WithField("address", dreq.Address).Debug("Error occurred while dialing over peer proxy.") } @@ -494,7 +494,7 @@ func (s *localSite) getConn(params DialParams) (conn net.Conn, useTunnel bool, e } // Return a direct dialed connection. - return newMetricConn(conn, direct, dialStart, s.srv.Clock), false, nil + return newMetricConn(conn, dialType_direct, dialStart, s.srv.Clock), false, nil } func (s *localSite) addConn(nodeID string, connType types.TunnelType, conn net.Conn, sconn ssh.Conn) (*remoteConn, error) { diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 63701b819a018..df7d5dcda0e89 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -43,7 +43,7 @@ import ( "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/observability/metrics" - "github.com/gravitational/teleport/lib/proxy" + "github.com/gravitational/teleport/lib/proxy/peer" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" @@ -191,7 +191,7 @@ type Config struct { NewCachingAccessPointOldProxy auth.NewRemoteProxyCachingAccessPoint // PeerClient is a client to peer proxy servers. - PeerClient *proxy.Client + PeerClient *peer.Client // LockWatcher is a lock watcher. LockWatcher *services.LockWatcher @@ -985,7 +985,7 @@ func (s *server) GetSite(name string) (RemoteSite, error) { } // GetProxyPeerClient returns the proxy peer client -func (s *server) GetProxyPeerClient() *proxy.Client { +func (s *server) GetProxyPeerClient() *peer.Client { return s.PeerClient } diff --git a/lib/service/service.go b/lib/service/service.go index 57e7c149fa5ac..fcfa24107c8bf 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -93,8 +93,8 @@ import ( "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/plugin" - "github.com/gravitational/teleport/lib/proxy" "github.com/gravitational/teleport/lib/proxy/clusterdial" + "github.com/gravitational/teleport/lib/proxy/peer" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" @@ -3376,11 +3376,11 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { // register SSH reverse tunnel server that accepts connections // from remote teleport nodes var tsrv reversetunnel.Server - var peerClient *proxy.Client + var peerClient *peer.Client if !process.Config.Proxy.DisableReverseTunnel { if listeners.proxy != nil { - peerClient, err = proxy.NewClient(proxy.ClientConfig{ + peerClient, err = peer.NewClient(peer.ClientConfig{ Context: process.ExitContext(), ID: process.Config.HostUUID, AuthClient: conn.Client, @@ -3559,14 +3559,14 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } var peerAddrString string - var proxyServer *proxy.Server + var proxyServer *peer.Server if !process.Config.Proxy.DisableReverseTunnel && listeners.proxy != nil { peerAddr, err := process.Config.Proxy.publicPeerAddr() if err != nil { return trace.Wrap(err) } peerAddrString = peerAddr.String() - proxyServer, err = proxy.NewServer(proxy.ServerConfig{ + proxyServer, err = peer.NewServer(peer.ServerConfig{ AccessCache: accessPoint, Listener: listeners.proxy, TLSConfig: serverTLSConfig, From fd4132beda727bf681fa4fe9aeaa122d02a19749 Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Thu, 10 Nov 2022 16:41:22 -0500 Subject: [PATCH 2/3] Refactor proxy routing logic into a reusable object Routing logic existed within an unexported handler of ssh subsystem requests, which prevented it from being reused by other components within the proxy, like the webapi server. This causes significant latency issues for web sessions because the web apiserver is required to dial the proxy ssh server to determine how to route to the host. Since the web apiserver and the proxy ssh server exist in the same process this is an entirely unnecesarry step that could be avoided if the routing and ability to established connections were shared throughout the proxy. A new `proxy.Router` object is introduced which contains all the logic that used to exist in `regular.proxySubsys` for determining how to connect to servers and clusters. All routing within the `regular.proxySubsys` now leverages the `proxy.Router` to dial the target. This is step 1 in addressing #15167. Now that the `proxy.Router` exists `web.APIServer` will be able to make use of it to avoid dialing the same process to establish connections. --- lib/proxy/router.go | 408 ++++++++++++++++++++++++ lib/proxy/router_test.go | 509 ++++++++++++++++++++++++++++++ lib/reversetunnel/conn_metric.go | 16 +- lib/reversetunnel/localsite.go | 8 +- lib/service/service.go | 20 +- lib/srv/regular/proxy.go | 339 +++----------------- lib/srv/regular/proxy_test.go | 211 ------------- lib/srv/regular/sshserver.go | 8 +- lib/srv/regular/sshserver_test.go | 35 +- lib/web/apiserver_test.go | 24 +- 10 files changed, 1046 insertions(+), 532 deletions(-) create mode 100644 lib/proxy/router.go create mode 100644 lib/proxy/router_test.go diff --git a/lib/proxy/router.go b/lib/proxy/router.go new file mode 100644 index 0000000000000..b527608b16a26 --- /dev/null +++ b/lib/proxy/router.go @@ -0,0 +1,408 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "fmt" + "net" + "strconv" + + "github.com/google/uuid" + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/attribute" + oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/exp/slices" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/observability/tracing" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/teleagent" + "github.com/gravitational/teleport/lib/utils" +) + +type serverResolverFn = func(ctx context.Context, host, port string, site site) (types.Server, error) + +// SiteGetter provides access to connected local or remote sites +type SiteGetter interface { + // GetSite returns the site matching the provided clusterName + GetSite(clusterName string) (reversetunnel.RemoteSite, error) +} + +// NopSiteGetter is an implementation of SiteGetter that always +// returns NopSite. Only used when the reverse tunnel server is disabled +// in tests. +type NopSiteGetter struct{} + +func (n NopSiteGetter) GetSite(string) (reversetunnel.RemoteSite, error) { + return NopSite{}, nil +} + +// NopSite is a reversetunnel.RemoteSite that is +// unimplemented and returned by NopSiteGetter.GetSite +type NopSite struct { + reversetunnel.RemoteSite +} + +// RemoteClusterGetter provides access to remote cluster resources +type RemoteClusterGetter interface { + // GetRemoteCluster returns a remote cluster by name + GetRemoteCluster(clusterName string) (types.RemoteCluster, error) +} + +// RouterConfig contains all the dependencies required +// by the Router +type RouterConfig struct { + // ClusterName indicates which cluster the router is for + ClusterName string + // Log is the logger to use + Log *logrus.Entry + // AccessPoint is the proxy cache + RemoteClusterGetter RemoteClusterGetter + // SiteGetter allows looking up sites + SiteGetter SiteGetter + // TracerProvider allows tracers to be created + TracerProvider oteltrace.TracerProvider + + // serverResolver is used to resolve hosts, used by tests + serverResolver serverResolverFn +} + +// CheckAndSetDefaults ensures the required items were populated +func (c *RouterConfig) CheckAndSetDefaults() error { + if c.Log == nil { + c.Log = logrus.WithField(trace.Component, "Router") + } + + if c.ClusterName == "" { + return trace.BadParameter("ClusterName must be provided") + } + + if c.RemoteClusterGetter == nil { + return trace.BadParameter("RemoteClusterGetter must be provided") + } + + if c.SiteGetter == nil { + return trace.BadParameter("SiteGetter must be provided") + } + + if c.TracerProvider == nil { + c.TracerProvider = tracing.DefaultProvider() + } + + if c.serverResolver == nil { + c.serverResolver = getServer + } + + return nil +} + +// Router is used by the proxy to establish connections to both +// nodes and other clusters. +type Router struct { + clusterName string + log *logrus.Entry + clusterGetter RemoteClusterGetter + localSite reversetunnel.RemoteSite + siteGetter SiteGetter + tracer oteltrace.Tracer + serverResolver serverResolverFn +} + +// NewRouter creates and returns a Router that is populated +// from the provided RouterConfig. +func NewRouter(cfg RouterConfig) (*Router, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + localSite, err := cfg.SiteGetter.GetSite(cfg.ClusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + return &Router{ + clusterName: cfg.ClusterName, + log: cfg.Log, + clusterGetter: cfg.RemoteClusterGetter, + localSite: localSite, + siteGetter: cfg.SiteGetter, + tracer: cfg.TracerProvider.Tracer("Router"), + serverResolver: cfg.serverResolver, + }, nil + +} + +// DialHost dials the node that matches the provided host, port and cluster. If no matching node +// is found an error is returned. If more than one matching node is found and the cluster networking +// configuration is not set to route to the most recent an error is returned. +func (r *Router) DialHost(ctx context.Context, from net.Addr, host, port, clusterName string, accessChecker services.AccessChecker, agentGetter teleagent.Getter) (net.Conn, error) { + ctx, span := r.tracer.Start( + ctx, + "router/DialHost", + oteltrace.WithAttributes( + attribute.String("host", host), + attribute.String("port", port), + attribute.String("site", clusterName), + ), + ) + defer span.End() + + site := r.localSite + if clusterName != r.clusterName { + remoteSite, err := r.getRemoteCluster(ctx, clusterName, accessChecker) + if err != nil { + return nil, trace.Wrap(err) + } + site = remoteSite + } + + span.AddEvent("looking up server") + target, err := r.serverResolver(ctx, host, port, remoteSite{site}) + if err != nil { + return nil, trace.Wrap(err) + } + span.AddEvent("retrieved target server") + + principals := []string{host} + + var ( + serverID string + serverAddr string + proxyIDs []string + ) + if target != nil { + proxyIDs = target.GetProxyIDs() + serverID = fmt.Sprintf("%v.%v", target.GetName(), clusterName) + + // add hostUUID.cluster to the principals + principals = append(principals, serverID) + + // add ip if it exists to the principals + serverAddr = target.GetAddr() + + switch { + case serverAddr != "": + h, _, err := net.SplitHostPort(serverAddr) + if err != nil { + return nil, trace.Wrap(err) + } + + principals = append(principals, h) + case serverAddr == "" && target.GetUseTunnel(): + serverAddr = reversetunnel.LocalNode + } + } else { + if port == "" || port == "0" { + port = strconv.Itoa(defaults.SSHServerListenPort) + } + + serverAddr = net.JoinHostPort(host, port) + r.log.Warnf("server lookup failed: using default=%v", serverAddr) + } + + conn, err := site.Dial(reversetunnel.DialParams{ + From: from, + To: &utils.NetAddr{AddrNetwork: "tcp", Addr: serverAddr}, + GetUserAgent: agentGetter, + Address: host, + ServerID: serverID, + ProxyIDs: proxyIDs, + Principals: principals, + ConnType: types.NodeTunnel, + }) + + return conn, trace.Wrap(err) +} + +// getRemoteCluster looks up the provided clusterName to determine if a remote site exists with +// that name and determines if the user has access to it. +func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, checker services.AccessChecker) (reversetunnel.RemoteSite, error) { + _, span := r.tracer.Start( + ctx, + "router/getRemoteCluster", + oteltrace.WithAttributes( + attribute.String("site", clusterName), + ), + ) + defer span.End() + + site, err := r.siteGetter.GetSite(clusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + rc, err := r.clusterGetter.GetRemoteCluster(clusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + if err := checker.CheckAccessToRemoteCluster(rc); err != nil { + return nil, utils.OpaqueAccessDenied(err) + } + + return site, nil +} + +// site is the minimum interface needed to match servers +// for a reversetunnel.RemoteSite. It makes testing easier. +type site interface { + GetNodes(fn func(n services.Node) bool) ([]types.Server, error) + GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error) +} + +// remoteSite is a site implementation that wraps +// a reversetunnel.RemoteSite +type remoteSite struct { + site reversetunnel.RemoteSite +} + +// GetNodes uses the wrapped sites NodeWatcher to filter nodes +func (r remoteSite) GetNodes(fn func(n services.Node) bool) ([]types.Server, error) { + watcher, err := r.site.NodeWatcher() + if err != nil { + return nil, trace.Wrap(err) + } + + return watcher.GetNodes(fn), nil +} + +// GetClusterNetworkingConfig uses the wrapped sites cache to retrieve the ClusterNetworkingConfig +func (r remoteSite) GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error) { + ap, err := r.site.CachingAccessPoint() + if err != nil { + return nil, trace.Wrap(err) + } + + cfg, err := ap.GetClusterNetworkingConfig(ctx, opts...) + return cfg, trace.Wrap(err) +} + +// getServer attempts to locate a node matching the provided host and port in +// the provided site. +func getServer(ctx context.Context, host, port string, site site) (types.Server, error) { + if site == nil { + return nil, trace.BadParameter("invalid remote site provided") + } + + strategy := types.RoutingStrategy_UNAMBIGUOUS_MATCH + if cfg, err := site.GetClusterNetworkingConfig(ctx); err == nil { + strategy = cfg.GetRoutingStrategy() + } + + _, err := uuid.Parse(host) + dialByID := err == nil || utils.IsEC2NodeID(host) + + ips, _ := net.LookupHost(host) + + var unambiguousIDMatch bool + matches, err := site.GetNodes(func(server services.Node) bool { + if unambiguousIDMatch { + return false + } + + // if host is a UUID or EC2 ID match only + // by server name and treat matches as unambiguous + if dialByID && server.GetName() == host { + unambiguousIDMatch = true + return true + } + + // if the server has connected over a reverse tunnel + // then match only by hostname + if server.GetUseTunnel() { + return host == server.GetHostname() + } + + ip, nodePort, err := net.SplitHostPort(server.GetAddr()) + if err != nil { + return false + } + + if (host == ip || host == server.GetHostname() || slices.Contains(ips, ip)) && + (port == "" || port == "0" || port == nodePort) { + return true + } + + return false + }) + if err != nil { + return nil, trace.Wrap(err) + } + + var server types.Server + switch { + case strategy == types.RoutingStrategy_MOST_RECENT: + for _, m := range matches { + if server == nil || m.Expiry().After(server.Expiry()) { + server = m + } + } + case len(matches) > 1: + return nil, trace.NotFound(teleport.NodeIsAmbiguous) + case len(matches) == 1: + server = matches[0] + } + + if dialByID && server == nil { + idType := "UUID" + if utils.IsEC2NodeID(host) { + idType = "EC2" + } + + return nil, trace.NotFound("unable to locate node matching %s-like target %s", idType, host) + } + + return server, nil + +} + +// DialSite establishes a connection to the auth server in the provided +// cluster. If the clusterName is an empty string then a connection to +// the local auth server will be established. +func (r *Router) DialSite(ctx context.Context, clusterName string) (net.Conn, error) { + _, span := r.tracer.Start( + ctx, + "router/DialSite", + oteltrace.WithAttributes( + attribute.String("site", clusterName), + ), + ) + defer span.End() + + // default to local cluster if one wasn't provided + if clusterName == "" { + clusterName = r.clusterName + } + + // dial the local auth server + if clusterName == r.clusterName { + conn, err := r.localSite.DialAuthServer() + return conn, trace.Wrap(err) + } + + // lookup the site and dial its auth server + site, err := r.siteGetter.GetSite(clusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + conn, err := site.DialAuthServer() + return conn, trace.Wrap(err) +} diff --git a/lib/proxy/router_test.go b/lib/proxy/router_test.go new file mode 100644 index 0000000000000..019e665517b8d --- /dev/null +++ b/lib/proxy/router_test.go @@ -0,0 +1,509 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "net" + "testing" + + "github.com/google/uuid" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/observability/tracing" + "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" +) + +type testSite struct { + cfg types.ClusterNetworkingConfig + nodes []types.Server +} + +func (t testSite) GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error) { + return t.cfg, nil +} + +func (t testSite) GetNodes(fn func(n services.Node) bool) ([]types.Server, error) { + var out []types.Server + for _, s := range t.nodes { + if fn(s) { + out = append(out, s) + } + } + + return out, nil +} + +type server struct { + name string + hostname string + addr string + tunnel bool +} + +func createServers(srvs []server) []types.Server { + out := make([]types.Server, 0, len(srvs)) + for _, s := range srvs { + srv := &types.ServerV2{ + Kind: types.KindNode, + Version: types.V2, + Metadata: types.Metadata{ + Name: s.name, + }, + Spec: types.ServerSpecV2{ + Addr: s.addr, + Hostname: s.hostname, + UseTunnel: s.tunnel, + }, + } + out = append(out, srv) + } + + return out +} + +func TestGetServers(t *testing.T) { + t.Parallel() + + mostRecentCfg := types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + RoutingStrategy: types.RoutingStrategy_MOST_RECENT, + }, + } + + unambiguousCfg := types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + RoutingStrategy: types.RoutingStrategy_UNAMBIGUOUS_MATCH, + }, + } + + hostID := uuid.NewString() + const ec2ID = "012345678901-i-01234567890abcdef" + + servers := createServers([]server{ + { + name: hostID, + hostname: "llama", + addr: "llama:123", + }, + { + name: "llama", + hostname: "llama", + addr: "llama:123", + tunnel: true, + }, + { + name: "llama", + hostname: hostID, + addr: "llama:123", + }, + { + name: ec2ID, + hostname: "node.aws", + addr: "node.aws:123", + }, + { + name: "node.aws", + hostname: "node.aws", + addr: "node.aws:123", + tunnel: true, + }, + { + name: "node.aws", + hostname: ec2ID, + addr: "node.aws:123", + }, + { + name: "alpaca", + hostname: "alpaca", + addr: "alpaca:123", + tunnel: true, + }, + { + name: "alpaca", + hostname: "localhost", + addr: "alpaca:987", + tunnel: true, + }, + { + name: "goat", + hostname: "goat", + addr: "1.2.3.4:123", + }, + { + name: "sheep", + hostname: "sheep", + addr: "sheep.bah:0", + }, + { + name: "sheep2", + hostname: "sheep", + addr: "sheep.bah:0", + }, + { + name: "lion", + hostname: "lion", + addr: "lion.roar", + }, + }) + + cases := []struct { + name string + host string + port string + site testSite + errAssertion require.ErrorAssertionFunc + serverAssertion func(t *testing.T, srv types.Server) + }{ + { + name: "no matches for hostname", + site: testSite{cfg: &unambiguousCfg}, + host: "test", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + { + name: "no matches for uuid", + site: testSite{cfg: &mostRecentCfg}, + host: uuid.NewString(), + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.True(t, trace.IsNotFound(err), i...) + }, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + { + name: "no matches for ec2 id", + site: testSite{cfg: &unambiguousCfg}, + host: "123456789012-i-1234567890abcdef0", + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.True(t, trace.IsNotFound(err), i...) + }, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + { + name: "ambiguous match fails", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: "sheep", + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorIs(t, err, trace.NotFound(teleport.NodeIsAmbiguous)) + }, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + { + name: "ambiguous match returns most recent", + site: testSite{cfg: &mostRecentCfg, nodes: servers}, + host: "sheep", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "sheep", srv.GetHostname()) + }, + }, + { + name: "match by uuid", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: hostID, + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "llama", srv.GetHostname()) + }, + }, + { + name: "match by ec2 id", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: ec2ID, + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "node.aws", srv.GetHostname()) + }, + }, + { + name: "match by ip", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: "1.2.3.4", + port: "123", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "goat", srv.GetHostname()) + }, + }, + { + name: "match by host only for tunnels", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: "alpaca", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "alpaca", srv.GetHostname()) + }, + }, + { + name: "failure on invalid addresses", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: "lion", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + } + + ctx := context.Background() + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + srv, err := getServer(ctx, tt.host, tt.port, tt.site) + tt.errAssertion(t, err) + tt.serverAssertion(t, srv) + }) + } +} + +func serverResolver(srv types.Server, err error) serverResolverFn { + return func(ctx context.Context, host, port string, site site) (types.Server, error) { + return srv, err + } +} + +type tunnel struct { + reversetunnel.Tunnel + + site reversetunnel.RemoteSite + err error +} + +func (t tunnel) GetSite(cluster string) (reversetunnel.RemoteSite, error) { + return t.site, t.err +} + +type testRemoteSite struct { + reversetunnel.RemoteSite + conn net.Conn + err error +} + +func (r testRemoteSite) Dial(reversetunnel.DialParams) (net.Conn, error) { + return r.conn, r.err +} + +func (r testRemoteSite) DialAuthServer() (net.Conn, error) { + return r.conn, r.err +} + +type fakeConn struct { + net.Conn +} + +func TestRouter_DialHost(t *testing.T) { + t.Parallel() + + logger := utils.NewLoggerForTests().WithField(trace.Component, "test") + + srv := &types.ServerV2{ + Kind: types.KindNode, + Version: types.V2, + Metadata: types.Metadata{ + Name: uuid.NewString(), + }, + Spec: types.ServerSpecV2{ + Addr: "127.0.0.1:8889", + Hostname: "test", + }, + } + + cases := []struct { + name string + router Router + assertion func(t *testing.T, conn net.Conn, err error) + }{ + { + name: "failure resolving node", + router: Router{ + clusterName: "test", + log: logger, + tracer: tracing.NoopTracer("test"), + serverResolver: serverResolver(nil, trace.NotFound(teleport.NodeIsAmbiguous)), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.Nil(t, conn) + }, + }, + { + name: "failure looking up cluster", + router: Router{ + clusterName: "leaf", + siteGetter: tunnel{err: trace.NotFound("unknown cluster")}, + log: logger, + tracer: tracing.NoopTracer("test"), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsNotFound(err)) + require.Nil(t, conn) + }, + }, + { + name: "dial failure", + router: Router{ + clusterName: "test", + log: logger, + localSite: &testRemoteSite{err: trace.ConnectionProblem(context.DeadlineExceeded, "connection refused")}, + tracer: tracing.NoopTracer("test"), + serverResolver: serverResolver(srv, nil), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsConnectionProblem(err)) + require.Nil(t, conn) + }, + }, + { + name: "dial success", + router: Router{ + clusterName: "test", + log: logger, + localSite: &testRemoteSite{conn: fakeConn{}}, + tracer: tracing.NoopTracer("test"), + serverResolver: serverResolver(srv, nil), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.NoError(t, err) + require.NotNil(t, conn) + }, + }, + } + + ctx := context.Background() + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + conn, err := tt.router.DialHost(ctx, &utils.NetAddr{}, "host", "0", "test", nil, nil) + tt.assertion(t, conn, err) + }) + } + +} + +func TestRouter_DialSite(t *testing.T) { + t.Parallel() + + const cluster = "test" + logger := utils.NewLoggerForTests().WithField(trace.Component, cluster) + + cases := []struct { + name string + cluster string + localSite testRemoteSite + tunnel tunnel + assertion func(t *testing.T, conn net.Conn, err error) + }{ + { + name: "failure to dial local site", + cluster: cluster, + localSite: testRemoteSite{err: trace.ConnectionProblem(context.DeadlineExceeded, "connection refused")}, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsConnectionProblem(err)) + require.Nil(t, conn) + }, + }, + { + name: "successfully dial local site", + cluster: cluster, + localSite: testRemoteSite{conn: fakeConn{}}, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.NoError(t, err) + require.NotNil(t, conn) + }, + }, + + { + name: "default to dialing local site", + localSite: testRemoteSite{conn: fakeConn{}}, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.NoError(t, err) + require.NotNil(t, conn) + }, + }, + { + name: "failure to dial remote site", + cluster: "leaf", + tunnel: tunnel{ + site: testRemoteSite{err: trace.ConnectionProblem(context.DeadlineExceeded, "connection refused")}, + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsConnectionProblem(err)) + require.Nil(t, conn) + }, + }, + { + name: "unknown cluster", + cluster: "fake", + tunnel: tunnel{ + err: trace.NotFound("unknown cluster"), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsNotFound(err)) + require.Nil(t, conn) + }, + }, + { + name: "successfully dial remote site", + cluster: "leaf", + tunnel: tunnel{ + site: testRemoteSite{conn: fakeConn{}}, + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.NoError(t, err) + require.NotNil(t, conn) + }, + }, + } + + ctx := context.Background() + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + router := Router{ + clusterName: cluster, + log: logger, + localSite: &tt.localSite, + siteGetter: tt.tunnel, + tracer: tracing.NoopTracer(cluster), + } + + conn, err := router.DialSite(ctx, tt.cluster) + tt.assertion(t, conn, err) + }) + } +} diff --git a/lib/reversetunnel/conn_metric.go b/lib/reversetunnel/conn_metric.go index d2913fded9873..ab6300884c635 100644 --- a/lib/reversetunnel/conn_metric.go +++ b/lib/reversetunnel/conn_metric.go @@ -28,16 +28,16 @@ import ( type dialType string const ( - // dialType_direct is a direct dialed connection. - dialType_direct dialType = "direct" - // dialType_peer is a connection established through a peer proxy. - dialType_peer dialType = "peer" - // dialType_tunnel is a connection established over a local reverse tunnel initiated + // dialTypeDirect is a direct dialed connection. + dialTypeDirect dialType = "direct" + // dialTypePeer is a connection established through a peer proxy. + dialTypePeer dialType = "peer" + // dialTypeTunnel is a connection established over a local reverse tunnel initiated // by a client. - dialType_tunnel dialType = "tunnel" - // dialType_peerTunnel is a connection established over a local reverse tunnel + dialTypeTunnel dialType = "tunnel" + // dialTypePeerTunnel is a connection established over a local reverse tunnel // initiated by a peer proxy. - dialType_peerTunnel dialType = "peer-tunnel" + dialTypePeerTunnel dialType = "peer-tunnel" ) // metricConn reports metrics for reversetunnel connections. diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 24ae4642ae400..558f9762543f8 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -447,9 +447,9 @@ func (s *localSite) getConn(params DialParams) (conn net.Conn, useTunnel bool, e // return a tunnel connection to that node. Otherwise net.Dial to the target host. conn, tunnelErr = s.dialTunnel(dreq) if tunnelErr == nil { - dt := dialType_tunnel + dt := dialTypeTunnel if params.FromPeerProxy { - dt = dialType_peerTunnel + dt = dialTypePeerTunnel } return newMetricConn(conn, dt, dialStart, s.srv.Clock), true, nil @@ -462,7 +462,7 @@ func (s *localSite) getConn(params DialParams) (conn net.Conn, useTunnel bool, e params.ProxyIDs, params.ServerID, params.From, params.To, params.ConnType, ) if peerErr == nil { - return newMetricConn(conn, dialType_peer, dialStart, s.srv.Clock), true, nil + return newMetricConn(conn, dialTypePeer, dialStart, s.srv.Clock), true, nil } s.log.WithError(peerErr).WithField("address", dreq.Address).Debug("Error occurred while dialing over peer proxy.") } @@ -494,7 +494,7 @@ func (s *localSite) getConn(params DialParams) (conn net.Conn, useTunnel bool, e } // Return a direct dialed connection. - return newMetricConn(conn, dialType_direct, dialStart, s.srv.Clock), false, nil + return newMetricConn(conn, dialTypeDirect, dialStart, s.srv.Clock), false, nil } func (s *localSite) addConn(nodeID string, connType types.TunnelType, conn net.Conn, sconn ssh.Conn) (*remoteConn, error) { diff --git a/lib/service/service.go b/lib/service/service.go index fcfa24107c8bf..64bcab56a26b7 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -93,6 +93,7 @@ import ( "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/plugin" + "github.com/gravitational/teleport/lib/proxy" "github.com/gravitational/teleport/lib/proxy/clusterdial" "github.com/gravitational/teleport/lib/proxy/peer" restricted "github.com/gravitational/teleport/lib/restrictedsession" @@ -3441,6 +3442,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return nil }) } + if !process.Config.Proxy.DisableTLS { tlsConfigWeb, err = process.setupProxyTLSConfig(conn, tsrv, accessPoint, clusterName) if err != nil { @@ -3593,6 +3595,22 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { }) } + var siteGetter proxy.SiteGetter = tsrv + if process.Config.Proxy.DisableReverseTunnel { + siteGetter = proxy.NopSiteGetter{} + } + + proxyRouter, err := proxy.NewRouter(proxy.RouterConfig{ + ClusterName: clusterName, + Log: process.log.WithField(trace.Component, "router"), + RemoteClusterGetter: accessPoint, + SiteGetter: siteGetter, + TracerProvider: process.TracingProvider, + }) + if err != nil { + return trace.Wrap(err) + } + sshProxy, err := regular.New(cfg.Proxy.SSHAddr, cfg.Hostname, []ssh.Signer{conn.ServerIdentity.KeySigner}, @@ -3602,7 +3620,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { process.proxyPublicAddr(), conn.Client, regular.SetLimiter(proxyLimiter), - regular.SetProxyMode(peerAddrString, tsrv, accessPoint), + regular.SetProxyMode(peerAddrString, tsrv, accessPoint, proxyRouter), regular.SetCiphers(cfg.Ciphers), regular.SetKEXAlgorithms(cfg.KEXAlgorithms), regular.SetMACAlgorithms(cfg.MACAlgorithms), diff --git a/lib/srv/regular/proxy.go b/lib/srv/regular/proxy.go index 9baae13730055..76283f86f528a 100644 --- a/lib/srv/regular/proxy.go +++ b/lib/srv/regular/proxy.go @@ -23,26 +23,19 @@ import ( "fmt" "io" "net" - "strconv" "strings" - "sync" - "github.com/google/uuid" "github.com/gravitational/trace" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" - "golang.org/x/exp/slices" "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/observability/tracing" - "github.com/gravitational/teleport/api/types" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" - "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/observability/metrics" - "github.com/gravitational/teleport/lib/reversetunnel" - "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/proxy" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" @@ -70,20 +63,20 @@ var ( // failedConnectingToNode counts failed attempts to connect to a node Help: "Number of SSH connection attempts to a node. Use with `failed_connect_to_node_attempts_total` to get the failure rate.", }, ) - - prometheusCollectors = []prometheus.Collector{proxiedSessions, failedConnectingToNode, connectingToNode} ) +func init() { + metrics.RegisterPrometheusCollectors(proxiedSessions, failedConnectingToNode, connectingToNode) +} + // proxySubsys implements an SSH subsystem for proxying listening sockets from // remote hosts to a proxy client (AKA port mapping) type proxySubsys struct { proxySubsysRequest - srv *Server - ctx *srv.ServerContext - log *logrus.Entry - closeC chan struct{} - error error - closeOnce sync.Once + router *proxy.Router + ctx *srv.ServerContext + log *logrus.Entry + closeC chan error } // parseProxySubsys looks at the requested subsystem name and returns a fully configured @@ -110,8 +103,9 @@ func parseProxySubsysRequest(request string) (proxySubsysRequest, error) { } requestBody := strings.TrimPrefix(request, prefix) namespace := apidefaults.Namespace - var err error parts := strings.Split(requestBody, "@") + + var err error switch { case len(parts) == 0: // "proxy:" return proxySubsysRequest{}, trace.BadParameter(paramMessage) @@ -190,11 +184,6 @@ func (p *proxySubsysRequest) SetDefaults() { // a port forwarding request, used to implement ProxyJump feature in proxy // and reuse the code func newProxySubsys(ctx *srv.ServerContext, srv *Server, req proxySubsysRequest) (*proxySubsys, error) { - err := metrics.RegisterPrometheusCollectors(prometheusCollectors...) - if err != nil { - return nil, trace.Wrap(err) - } - req.SetDefaults() if req.clusterName == "" && ctx.Identity.RouteToCluster != "" { log.Debugf("Proxy subsystem: routing user %q to cluster %q based on the route to cluster extension.", @@ -212,12 +201,12 @@ func newProxySubsys(ctx *srv.ServerContext, srv *Server, req proxySubsysRequest) return &proxySubsys{ proxySubsysRequest: req, ctx: ctx, - srv: srv, log: logrus.WithFields(logrus.Fields{ trace.Component: teleport.ComponentSubsystemProxy, trace.ComponentFields: map[string]string{}, }), - closeC: make(chan struct{}), + closeC: make(chan error), + router: srv.router, }, nil } @@ -239,12 +228,8 @@ func (t *proxySubsys) Start(ctx context.Context, sconn *ssh.ServerConn, ch ssh.C }) t.log.Debugf("Starting subsystem") - var ( - site reversetunnel.RemoteSite - err error - tunnel = t.srv.tunnelWithAccessChecker(serverContext) - clientAddr = sconn.RemoteAddr() - ) + clientAddr := sconn.RemoteAddr() + // did the client pass us a true client IP ahead of time via an environment variable? // (usually the web client would do that) trueClientIP, ok := serverContext.GetEnv(sshutils.TrueClientAddrVar) @@ -254,185 +239,41 @@ func (t *proxySubsys) Start(ctx context.Context, sconn *ssh.ServerConn, ch ssh.C clientAddr = a } } - // get the cluster by name: - if t.clusterName != "" { - site, err = tunnel.GetSite(t.clusterName) - if err != nil { - t.log.Warn(err) - return trace.Wrap(err) - } - } - // connecting to a specific host: - if t.host != "" { - // no site given? use the 1st one: - if site == nil { - sites, err := tunnel.GetSites() - if err != nil { - return trace.Wrap(err) - } - if len(sites) == 0 { - t.log.Error("Not connected to any remote clusters") - return trace.NotFound("no connected sites") - } - site = sites[0] - t.clusterName = site.GetName() - t.log.Debugf("Cluster not specified. connecting to default='%s'", site.GetName()) - } - return t.proxyToHost(ctx, site, clientAddr, ch) + + // connect to a site's auth server + if t.host == "" { + return t.proxyToSite(ctx, ch, t.clusterName) } - // connect to a site's auth server: - return t.proxyToSite(serverContext, site, clientAddr, ch) + + // connect to a server + return t.proxyToHost(ctx, ch, clientAddr) } // proxyToSite establishes a proxy connection from the connected SSH client to the // auth server of the requested remote site -func (t *proxySubsys) proxyToSite( - ctx *srv.ServerContext, site reversetunnel.RemoteSite, remoteAddr net.Addr, ch ssh.Channel) error { - conn, err := site.DialAuthServer() +func (t *proxySubsys) proxyToSite(ctx context.Context, ch ssh.Channel, clusterName string) error { + t.log.Debugf("Connecting to site: %v", clusterName) + + conn, err := t.router.DialSite(ctx, clusterName) if err != nil { return trace.Wrap(err) } - t.log.Infof("Connected to auth server: %v", conn.RemoteAddr()) + t.log.Infof("Connected to cluster %v at %v", clusterName, conn.RemoteAddr()) proxiedSessions.Inc() + go func() { - var err error - defer func() { - t.close(err) - }() - defer ch.Close() - _, err = io.Copy(ch, conn) - }() - go func() { - var err error - defer func() { - t.close(err) - }() - defer conn.Close() - _, err = io.Copy(conn, ch) + t.close(utils.ProxyConn(ctx, ch, conn)) }() - return nil } // proxyToHost establishes a proxy connection from the connected SSH client to the // requested remote node (t.host:t.port) via the given site -func (t *proxySubsys) proxyToHost( - ctx context.Context, site reversetunnel.RemoteSite, remoteAddr net.Addr, ch ssh.Channel) error { - // - // first, lets fetch a list of servers at the given site. this allows us to - // match the given "host name" against node configuration (their 'nodename' setting) - // - // but failing to fetch the list of servers is also OK, we'll use standard - // network resolution (by IP or DNS) - // - var ( - strategy types.RoutingStrategy - nodeWatcher NodesGetter - err error - ) - localCluster, _ := t.srv.proxyAccessPoint.GetClusterName() - // going to "local" CA? lets use the caching 'auth service' directly and avoid - // hitting the reverse tunnel link (it can be offline if the CA is down) - if site.GetName() == localCluster.GetName() { - nodeWatcher = t.srv.nodeWatcher +func (t *proxySubsys) proxyToHost(ctx context.Context, ch ssh.Channel, remoteAddr net.Addr) error { + t.log.Debugf("proxy connecting to host=%v port=%v, exact port=%v", t.host, t.port, t.SpecifiedPort()) - cfg, err := t.srv.authService.GetClusterNetworkingConfig(ctx) - if err != nil { - t.log.Warn(err) - } else { - strategy = cfg.GetRoutingStrategy() - } - } else { - // "remote" CA? use a reverse tunnel to talk to it: - siteClient, err := site.CachingAccessPoint() - if err != nil { - t.log.Warn(err) - } else { - watcher, err := site.NodeWatcher() - if err != nil { - t.log.Warn(err) - } else { - nodeWatcher = watcher - } - - cfg, err := siteClient.GetClusterNetworkingConfig(ctx) - if err != nil { - t.log.Warn(err) - } else { - strategy = cfg.GetRoutingStrategy() - } - } - } - - // if port is 0, it means the client wants us to figure out - // which port to use - t.log.Debugf("proxy connecting to host=%v port=%v, exact port=%v, strategy=%s", t.host, t.port, t.SpecifiedPort(), strategy) - - // determine which server to connect to - server, err := t.getMatchingServer(nodeWatcher, strategy) - if err != nil { - return trace.Wrap(err) - } - - // Create a slice of principals that will be added into the host certificate. - // Here t.host is either an IP address or a DNS name as the user requested. - principals := []string{t.host} - - // Used to store the server ID (hostUUID.clusterName) of a Teleport node. - var serverID string - - // Resolve the IP address to dial to because the hostname may not be - // DNS resolvable. - var ( - serverAddr string - proxyIDs []string - ) - - if server != nil { - // Add hostUUID.clusterName to list of principals. - serverID = fmt.Sprintf("%v.%v", server.GetName(), t.clusterName) - principals = append(principals, serverID) - proxyIDs = server.GetProxyIDs() - - // Add IP address (if it exists) of the node to list of principals. - serverAddr = server.GetAddr() - if serverAddr != "" { - host, _, err := net.SplitHostPort(serverAddr) - if err != nil { - return trace.Wrap(err) - } - principals = append(principals, host) - } else if server.GetUseTunnel() { - serverAddr = reversetunnel.LocalNode - } - } else { - if !t.SpecifiedPort() { - t.port = strconv.Itoa(defaults.SSHServerListenPort) - } - serverAddr = net.JoinHostPort(t.host, t.port) - t.log.Warnf("server lookup failed: using default=%v", serverAddr) - } - - // Pass the agent along to the site. If the proxy is in recording mode, this - // agent is used to perform user authentication. Pass the DNS name to the - // dialer as well so the forwarding proxy can generate a host certificate - // with the correct hostname). - toAddr := &utils.NetAddr{ - AddrNetwork: "tcp", - Addr: serverAddr, - } - connectingToNode.Inc() - conn, err := site.Dial(reversetunnel.DialParams{ - From: remoteAddr, - To: toAddr, - GetUserAgent: t.ctx.StartAgentChannel, - Address: t.host, - ServerID: serverID, - ProxyIDs: proxyIDs, - Principals: principals, - ConnType: types.NodeTunnel, - }) + conn, err := t.router.DialHost(ctx, remoteAddr, t.host, t.port, t.clusterName, t.ctx.Identity.AccessChecker, t.ctx.StartAgentChannel) if err != nil { failedConnectingToNode.Inc() return trace.Wrap(err) @@ -443,126 +284,20 @@ func (t *proxySubsys) proxyToHost( t.doHandshake(ctx, remoteAddr, ch, conn) proxiedSessions.Inc() + go func() { - var err error - defer func() { - t.close(err) - }() - defer ch.Close() - _, err = io.Copy(ch, conn) - }() - go func() { - var err error - defer func() { - t.close(err) - }() - defer conn.Close() - _, err = io.Copy(conn, ch) + t.close(utils.ProxyConn(ctx, ch, conn)) }() - return nil } -// NodesGetter is a function that retrieves a subset of nodes matching -// the filter criteria. -type NodesGetter interface { - GetNodes(fn func(n services.Node) bool) []types.Server -} - -// getMatchingServer determines the server to connect to from the provided servers. Duplicate entries are treated -// differently based on strategy. Legacy behavior of returning an ambiguous error occurs if the strategy -// is types.RoutingStrategy_UNAMBIGUOUS_MATCH. When the strategy is types.RoutingStrategy_MOST_RECENT then -// the server that has heartbeated most recently will be returned instead of an error. If no matches are found then -// both the types.Server and error returned will be nil. -func (t *proxySubsys) getMatchingServer(watcher NodesGetter, strategy types.RoutingStrategy) (types.Server, error) { - if watcher == nil { - return nil, trace.NotFound("unable to retrieve nodes matching host %s", t.host) - } - - // check if hostname is a valid uuid or EC2 node ID. If it is, we will - // preferentially match by node ID over node hostname. - _, err := uuid.Parse(t.host) - hostIsUniqueID := err == nil || utils.IsEC2NodeID(t.host) - - ips, _ := net.LookupHost(t.host) - - var unambiguousIDMatch bool - // enumerate and try to find a server with self-registered with a matching name/IP: - matches := watcher.GetNodes(func(server services.Node) bool { - if unambiguousIDMatch { - return false - } - - // If the host parameter is a UUID or EC2 node ID, and it matches the - // Node ID, treat this as an unambiguous match. - if hostIsUniqueID && server.GetName() == t.host { - unambiguousIDMatch = true - return true - } - // If the server has connected over a reverse tunnel, match only on hostname. - if server.GetUseTunnel() { - return t.host == server.GetHostname() - } - - ip, port, err := net.SplitHostPort(server.GetAddr()) - if err != nil { - t.log.Errorf("Failed to parse address %q: %v.", server.GetAddr(), err) - return false - } - if t.host == ip || t.host == server.GetHostname() || slices.Contains(ips, ip) { - if !t.SpecifiedPort() || t.port == port { - return true - } - } - return false - }) - - var server types.Server - switch { - case strategy == types.RoutingStrategy_MOST_RECENT: - // find the most recent of all the matches - for _, m := range matches { - if server == nil || m.Expiry().After(server.Expiry()) { - server = m - } - } - case len(matches) > 1: - // if we matched more than one server, then the target was ambiguous. - return nil, trace.NotFound(teleport.NodeIsAmbiguous) - case len(matches) == 1: - server = matches[0] - } - - // If we matched zero nodes but hostname is a UUID (or EC2 node ID) then it - // isn't sane to fallback to dns based resolution. This has the unfortunate - // consequence of preventing users from calling OpenSSH nodes which happen - // to use hostnames which are also valid UUIDs. This restriction is - // necessary in order to protect users attempting to connect to a node by - // UUID from being re-routed to an unintended target if the node is offline. - // This restriction can be lifted if we decide to move to explicit UUID - // based resolution in the future. - if hostIsUniqueID && server == nil { - idType := "UUID" - if utils.IsEC2NodeID(t.host) { - idType = "EC2" - } - return nil, trace.NotFound("unable to locate node matching %s-like target %s", idType, t.host) - } - - return server, nil -} - func (t *proxySubsys) close(err error) { - t.closeOnce.Do(func() { - proxiedSessions.Dec() - t.error = err - close(t.closeC) - }) + proxiedSessions.Dec() + t.closeC <- err } func (t *proxySubsys) Wait() error { - <-t.closeC - return t.error + return <-t.closeC } // doHandshake allows a proxy server to send additional information (client IP) diff --git a/lib/srv/regular/proxy_test.go b/lib/srv/regular/proxy_test.go index d6259bbb209cf..8fef4a19d7186 100644 --- a/lib/srv/regular/proxy_test.go +++ b/lib/srv/regular/proxy_test.go @@ -18,14 +18,10 @@ package regular import ( "testing" - "time" - "github.com/google/uuid" "github.com/stretchr/testify/require" apidefaults "github.com/gravitational/teleport/api/defaults" - "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" ) @@ -130,210 +126,3 @@ func TestParseBadRequests(t *testing.T) { }) } } - -type nodeGetter struct { - servers []types.Server -} - -func (n nodeGetter) GetNodes(fn func(n services.Node) bool) []types.Server { - var servers []types.Server - for _, s := range n.servers { - if fn(s) { - servers = append(servers, s) - } - } - - return servers -} - -func TestProxySubsys_getMatchingServer(t *testing.T) { - t.Parallel() - - serverUUID := uuid.NewString() - - ec2NodeID := "123456789012-i-abcdef12345678901" - - setExpiry := func(time time.Time) func(server types.Server) { - return func(server types.Server) { - server.SetExpiry(time) - } - } - - createServer := func(name string, spec types.ServerSpecV2, opts ...func(server types.Server)) types.Server { - t.Helper() - - server, err := types.NewServer(name, types.KindNode, spec) - require.NoError(t, err) - - for _, opt := range opts { - opt(server) - } - - return server - } - - servers := []types.Server{ - createServer(serverUUID, types.ServerSpecV2{ - Hostname: "127.0.0.1", - Addr: "127.0.0.1:80", - }, setExpiry(time.Now().Add(-time.Hour))), - createServer("server2", types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:80", - }, setExpiry(time.Now().Add(time.Hour*24))), - createServer("server3", types.ServerSpecV2{ - Hostname: serverUUID, - Addr: "127.0.0.1:", - }), - createServer(ec2NodeID, types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:", - }), - } - - cases := []struct { - desc string - req proxySubsysRequest - strategy types.RoutingStrategy - servers []types.Server - expectError require.ErrorAssertionFunc - expectServer func(servers []types.Server) types.Server - }{ - { - desc: "No matches found", - expectError: require.NoError, - }, - { - desc: "No matches found for UUID host", - expectError: require.Error, - servers: []types.Server{createServer(uuid.NewString(), types.ServerSpecV2{ - Addr: "127.0.0.1:0", - })}, - req: proxySubsysRequest{ - host: uuid.NewString(), - }, - }, - { - desc: "Match by UUID", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[0] - }, - servers: servers, - req: proxySubsysRequest{ - host: serverUUID, - }, - }, - { - desc: "Match by EC2 ID", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[3] - }, - servers: servers, - req: proxySubsysRequest{ - host: ec2NodeID, - }, - }, - { - desc: "Match Tunnel By Host Only", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[0] - }, - servers: []types.Server{ - createServer("server1", types.ServerSpecV2{ - Addr: "127.0.0.1", - Hostname: "127.0.0.1", - UseTunnel: true, - }), - createServer("server2", types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:80", - UseTunnel: true, - }), - }, - req: proxySubsysRequest{ - host: "127.0.0.1", - port: "80", - }, - }, - { - desc: "Match by IP", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[1] - }, - servers: []types.Server{ - createServer("server1", types.ServerSpecV2{ - Addr: "127.0.0.1:0", - Hostname: "127.0.0.1", - }), - createServer("server2", types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:80", - }), - }, - req: proxySubsysRequest{ - host: "127.0.0.1", - port: "80", - }, - }, - { - desc: "Match by hostname", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[1] - }, - servers: []types.Server{ - createServer("server1", types.ServerSpecV2{ - Addr: "127.0.0.1:0", - Hostname: "localhost", - }), - createServer("server2", types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:80", - }), - }, - req: proxySubsysRequest{ - host: "localhost", - port: "80", - }, - }, - { - desc: "Ambiguous match", - expectError: require.Error, - servers: servers, - req: proxySubsysRequest{ - host: "localhost", - }, - }, - { - desc: "Most Recent match", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[1] - }, - servers: servers, - strategy: types.RoutingStrategy_MOST_RECENT, - req: proxySubsysRequest{ - host: "localhost", - }, - }, - } - - for _, tt := range cases { - t.Run(tt.desc, func(t *testing.T) { - subsystem := proxySubsys{ - proxySubsysRequest: tt.req, - srv: &Server{}, - } - - server, err := subsystem.getMatchingServer(nodeGetter{tt.servers}, tt.strategy) - tt.expectError(t, err) - if tt.expectServer != nil { - require.Equal(t, tt.expectServer(tt.servers), server) - } - }) - } -} diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 65b587e867140..b07c5c79003c6 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -55,6 +55,7 @@ import ( "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/pam" + "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" @@ -223,6 +224,10 @@ type Server struct { // tracerProvider is used to create tracers capable // of starting spans. tracerProvider oteltrace.TracerProvider + + // router used by subsystem requests to connect to nodes + // and clusters + router *proxy.Router } // TargetMetadata returns metadata about the server. @@ -422,7 +427,7 @@ func SetShell(shell string) ServerOption { } // SetProxyMode starts this server in SSH proxying mode -func SetProxyMode(peerAddr string, tsrv reversetunnel.Tunnel, ap auth.ReadProxyAccessPoint) ServerOption { +func SetProxyMode(peerAddr string, tsrv reversetunnel.Tunnel, ap auth.ReadProxyAccessPoint, router *proxy.Router) ServerOption { return func(s *Server) error { // always set proxy mode to true, // because in some tests reverse tunnel is disabled, @@ -431,6 +436,7 @@ func SetProxyMode(peerAddr string, tsrv reversetunnel.Tunnel, ap auth.ReadProxyA s.proxyTun = tsrv s.proxyAccessPoint = ap s.peerAddr = peerAddr + s.router = router return nil } } diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index d8d86c91e0e2e..e4ac44763c481 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -59,7 +59,9 @@ import ( "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/limiter" + "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/pam" + libproxy "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" @@ -1383,6 +1385,15 @@ func TestProxyRoundRobin(t *testing.T) { require.NoError(t, reverseTunnelServer.Start()) defer reverseTunnelServer.Close() + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: f.testSrv.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: proxyClient, + SiteGetter: reverseTunnelServer, + TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + proxy, err := New( utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), @@ -1392,7 +1403,7 @@ func TestProxyRoundRobin(t *testing.T) { "", utils.NetAddr{}, proxyClient, - SetProxyMode("", reverseTunnelServer, proxyClient), + SetProxyMode("", reverseTunnelServer, proxyClient, router), SetEmitter(nodeClient), SetNamespace(apidefaults.Namespace), SetPAMConfig(&pam.Config{Enabled: false}), @@ -1503,6 +1514,15 @@ func TestProxyDirectAccess(t *testing.T) { nodeClient, _ := newNodeClient(t, f.testSrv) + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: f.testSrv.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: proxyClient, + SiteGetter: reverseTunnelServer, + TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + proxy, err := New( utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), @@ -1512,7 +1532,7 @@ func TestProxyDirectAccess(t *testing.T) { "", utils.NetAddr{}, proxyClient, - SetProxyMode("", reverseTunnelServer, proxyClient), + SetProxyMode("", reverseTunnelServer, proxyClient, router), SetEmitter(nodeClient), SetNamespace(apidefaults.Namespace), SetPAMConfig(&pam.Config{Enabled: false}), @@ -2233,6 +2253,15 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { nodeClient, _ := newNodeClient(t, f.testSrv) + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: f.testSrv.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: proxyClient, + SiteGetter: reverseTunnelServer, + TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + proxy, err := New( utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), @@ -2242,7 +2271,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { "", utils.NetAddr{}, proxyClient, - SetProxyMode("", reverseTunnelServer, proxyClient), + SetProxyMode("", reverseTunnelServer, proxyClient, router), SetEmitter(nodeClient), SetNamespace(apidefaults.Namespace), SetPAMConfig(&pam.Config{Enabled: false}), diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 5bf9d6eec8ac1..645968c846383 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -95,7 +95,9 @@ import ( kubeproxy "github.com/gravitational/teleport/lib/kube/proxy" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/pam" + libproxy "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/secret" @@ -329,6 +331,15 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { require.NoError(t, err) s.proxyTunnel = revTunServer + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: s.server.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: s.proxyClient, + SiteGetter: revTunServer, + TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + // proxy server: s.proxy, err = regular.New( utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, @@ -340,7 +351,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { utils.NetAddr{}, s.proxyClient, regular.SetUUID(proxyID), - regular.SetProxyMode("", revTunServer, s.proxyClient), + regular.SetProxyMode("", revTunServer, s.proxyClient, router), regular.SetEmitter(s.proxyClient), regular.SetNamespace(apidefaults.Namespace), regular.SetBPF(&bpf.NOP{}), @@ -5847,6 +5858,15 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula require.NoError(t, err) t.Cleanup(func() { require.NoError(t, revTunServer.Close()) }) + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: authServer.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: client, + SiteGetter: revTunServer, + TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + proxyServer, err := regular.New( utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, authServer.ClusterName(), @@ -5857,7 +5877,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula utils.NetAddr{AddrNetwork: "tcp", Addr: "proxy-1.example.com:443"}, client, regular.SetUUID(proxyID), - regular.SetProxyMode("", revTunServer, client), + regular.SetProxyMode("", revTunServer, client, router), regular.SetEmitter(client), regular.SetNamespace(apidefaults.Namespace), regular.SetBPF(&bpf.NOP{}), From a278ef2483099227f5de616b1eb4a4d96007e215 Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Thu, 17 Nov 2022 08:06:03 -0500 Subject: [PATCH 3/3] don't create a router when reverse tunnels are disabled --- lib/proxy/router.go | 15 --------------- lib/service/service.go | 27 +++++++++++++-------------- 2 files changed, 13 insertions(+), 29 deletions(-) diff --git a/lib/proxy/router.go b/lib/proxy/router.go index b527608b16a26..6925147d345fe 100644 --- a/lib/proxy/router.go +++ b/lib/proxy/router.go @@ -45,21 +45,6 @@ type SiteGetter interface { GetSite(clusterName string) (reversetunnel.RemoteSite, error) } -// NopSiteGetter is an implementation of SiteGetter that always -// returns NopSite. Only used when the reverse tunnel server is disabled -// in tests. -type NopSiteGetter struct{} - -func (n NopSiteGetter) GetSite(string) (reversetunnel.RemoteSite, error) { - return NopSite{}, nil -} - -// NopSite is a reversetunnel.RemoteSite that is -// unimplemented and returned by NopSiteGetter.GetSite -type NopSite struct { - reversetunnel.RemoteSite -} - // RemoteClusterGetter provides access to remote cluster resources type RemoteClusterGetter interface { // GetRemoteCluster returns a remote cluster by name diff --git a/lib/service/service.go b/lib/service/service.go index 64bcab56a26b7..01b3735e7293e 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -3595,20 +3595,19 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { }) } - var siteGetter proxy.SiteGetter = tsrv - if process.Config.Proxy.DisableReverseTunnel { - siteGetter = proxy.NopSiteGetter{} - } - - proxyRouter, err := proxy.NewRouter(proxy.RouterConfig{ - ClusterName: clusterName, - Log: process.log.WithField(trace.Component, "router"), - RemoteClusterGetter: accessPoint, - SiteGetter: siteGetter, - TracerProvider: process.TracingProvider, - }) - if err != nil { - return trace.Wrap(err) + var proxyRouter *proxy.Router + if !process.Config.Proxy.DisableReverseTunnel { + router, err := proxy.NewRouter(proxy.RouterConfig{ + ClusterName: clusterName, + Log: process.log.WithField(trace.Component, "router"), + RemoteClusterGetter: accessPoint, + SiteGetter: tsrv, + TracerProvider: process.TracingProvider, + }) + if err != nil { + return trace.Wrap(err) + } + proxyRouter = router } sshProxy, err := regular.New(cfg.Proxy.SSHAddr,