diff --git a/integration/utmp_integration_test.go b/integration/utmp_integration_test.go index ee7e98f2a72e3..6f119f43567bd 100644 --- a/integration/utmp_integration_test.go +++ b/integration/utmp_integration_test.go @@ -40,6 +40,7 @@ import ( "github.com/gravitational/teleport/lib/pam" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/srv/regular" "github.com/gravitational/teleport/lib/srv/uacc" "github.com/gravitational/teleport/lib/sshutils" @@ -254,8 +255,19 @@ func newSrvCtx(ctx context.Context, t *testing.T) *SrvCtx { require.NoError(t, err) t.Cleanup(lockWatcher.Close) + nodeSessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: s.nodeClient, + AccessPoint: s.nodeClient, + LockEnforcer: lockWatcher, + Emitter: s.nodeClient, + Component: teleport.ComponentNode, + ServerID: s.nodeID, + }) + require.NoError(t, err) + nodeDir := t.TempDir() srv, err := regular.New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, s.server.ClusterName(), []ssh.Signer{s.signer}, @@ -284,6 +296,7 @@ func newSrvCtx(ctx context.Context, t *testing.T) *SrvCtx { regular.SetClock(s.clock), regular.SetUtmpPath(utmpPath, utmpPath), regular.SetLockWatcher(lockWatcher), + regular.SetSessionController(nodeSessionController), ) require.NoError(t, err) s.srv = srv diff --git a/lib/proxy/clusterdial/dial.go b/lib/proxy/clusterdial/dial.go index 65b9c4cee9267..59a01e9a97dd8 100644 --- a/lib/proxy/clusterdial/dial.go +++ b/lib/proxy/clusterdial/dial.go @@ -17,22 +17,23 @@ package clusterdial import ( "net" - "github.com/gravitational/teleport/lib/proxy" - "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/proxy/peer" + "github.com/gravitational/teleport/lib/reversetunnel" ) -// ClusterDialerFunc is a function that 
implements a proxy.ClusterDialer. -type ClusterDialerFunc func(clusterName string, request proxy.DialParams) (net.Conn, error) +// ClusterDialerFunc is a function that implements a peer.ClusterDialer. +type ClusterDialerFunc func(clusterName string, request peer.DialParams) (net.Conn, error) // Dial dials makes a dial request to the given cluster. -func (f ClusterDialerFunc) Dial(clusterName string, request proxy.DialParams) (net.Conn, error) { +func (f ClusterDialerFunc) Dial(clusterName string, request peer.DialParams) (net.Conn, error) { return f(clusterName, request) } // NewClusterDialer implements proxy.ClusterDialer for a reverse tunnel server. func NewClusterDialer(server reversetunnel.Server) ClusterDialerFunc { - return ClusterDialerFunc(func(clusterName string, request proxy.DialParams) (net.Conn, error) { + return func(clusterName string, request peer.DialParams) (net.Conn, error) { site, err := server.GetSite(clusterName) if err != nil { return nil, trace.Wrap(err) @@ -51,5 +52,5 @@ func NewClusterDialer(server reversetunnel.Server) ClusterDialerFunc { return nil, trace.Wrap(err) } return conn, nil - }) + } } diff --git a/lib/proxy/auth.go b/lib/proxy/peer/auth.go similarity index 99% rename from lib/proxy/auth.go rename to lib/proxy/peer/auth.go index 6e2888fd0f5e0..6af3781850fdf 100644 --- a/lib/proxy/auth.go +++ b/lib/proxy/peer/auth.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/client.go b/lib/proxy/peer/client.go similarity index 99% rename from lib/proxy/client.go rename to lib/proxy/peer/client.go index b3c4f1f270539..c17ed1efdc85a 100644 --- a/lib/proxy/client.go +++ b/lib/proxy/peer/client.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package proxy +package peer import ( "context" diff --git a/lib/proxy/client_test.go b/lib/proxy/peer/client_test.go similarity index 99% rename from lib/proxy/client_test.go rename to lib/proxy/peer/client_test.go index bfd558cb4a8f7..3ca05bbc80a6a 100644 --- a/lib/proxy/client_test.go +++ b/lib/proxy/peer/client_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "crypto/tls" diff --git a/lib/proxy/clientmetrics.go b/lib/proxy/peer/clientmetrics.go similarity index 99% rename from lib/proxy/clientmetrics.go rename to lib/proxy/peer/clientmetrics.go index 5c7265b9ae2b4..85207e43153f3 100644 --- a/lib/proxy/clientmetrics.go +++ b/lib/proxy/peer/clientmetrics.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "github.com/gravitational/teleport/lib/utils" diff --git a/lib/proxy/conn.go b/lib/proxy/peer/conn.go similarity index 99% rename from lib/proxy/conn.go rename to lib/proxy/peer/conn.go index f2e7a87f7419d..faa8a7c0f7f31 100644 --- a/lib/proxy/conn.go +++ b/lib/proxy/peer/conn.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/conn_test.go b/lib/proxy/peer/conn_test.go similarity index 99% rename from lib/proxy/conn_test.go rename to lib/proxy/peer/conn_test.go index 0ffec8a853bb4..5349ed143b74e 100644 --- a/lib/proxy/conn_test.go +++ b/lib/proxy/peer/conn_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package proxy +package peer import ( "context" diff --git a/lib/proxy/helpers_test.go b/lib/proxy/peer/helpers_test.go similarity index 99% rename from lib/proxy/helpers_test.go rename to lib/proxy/peer/helpers_test.go index 9ff3edc3d1173..8ef4b4f450983 100644 --- a/lib/proxy/helpers_test.go +++ b/lib/proxy/peer/helpers_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "crypto/rand" diff --git a/lib/proxy/interceptor.go b/lib/proxy/peer/interceptor.go similarity index 99% rename from lib/proxy/interceptor.go rename to lib/proxy/peer/interceptor.go index de838322c48ff..19f103b9b8aac 100644 --- a/lib/proxy/interceptor.go +++ b/lib/proxy/peer/interceptor.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" diff --git a/lib/proxy/reporter.go b/lib/proxy/peer/reporter.go similarity index 99% rename from lib/proxy/reporter.go rename to lib/proxy/peer/reporter.go index a6401f88ae11d..400e5881ff5d2 100644 --- a/lib/proxy/reporter.go +++ b/lib/proxy/peer/reporter.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "time" diff --git a/lib/proxy/server.go b/lib/proxy/peer/server.go similarity index 99% rename from lib/proxy/server.go rename to lib/proxy/peer/server.go index 7a4ad87b9279c..21741247d2fa8 100644 --- a/lib/proxy/server.go +++ b/lib/proxy/peer/server.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package proxy +package peer import ( "crypto/tls" diff --git a/lib/proxy/server_test.go b/lib/proxy/peer/server_test.go similarity index 99% rename from lib/proxy/server_test.go rename to lib/proxy/peer/server_test.go index 3ae4d7f877fc0..9463bd60ff26e 100644 --- a/lib/proxy/server_test.go +++ b/lib/proxy/peer/server_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "testing" diff --git a/lib/proxy/servermetrics.go b/lib/proxy/peer/servermetrics.go similarity index 99% rename from lib/proxy/servermetrics.go rename to lib/proxy/peer/servermetrics.go index f21acf9cd9210..ac0eea8b5882f 100644 --- a/lib/proxy/servermetrics.go +++ b/lib/proxy/peer/servermetrics.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "github.com/gravitational/teleport/lib/utils" diff --git a/lib/proxy/service.go b/lib/proxy/peer/service.go similarity index 99% rename from lib/proxy/service.go rename to lib/proxy/peer/service.go index 21f6db72bd989..a1f1a63521c18 100644 --- a/lib/proxy/service.go +++ b/lib/proxy/peer/service.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "net" diff --git a/lib/proxy/service_test.go b/lib/proxy/peer/service_test.go similarity index 99% rename from lib/proxy/service_test.go rename to lib/proxy/peer/service_test.go index 04f6aaf189dba..69d7de8378c37 100644 --- a/lib/proxy/service_test.go +++ b/lib/proxy/peer/service_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package proxy +package peer import ( "context" diff --git a/lib/proxy/stats.go b/lib/proxy/peer/stats.go similarity index 98% rename from lib/proxy/stats.go rename to lib/proxy/peer/stats.go index 0d37b741718e9..dc0cd89501308 100644 --- a/lib/proxy/stats.go +++ b/lib/proxy/peer/stats.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxy +package peer import ( "context" @@ -51,7 +51,7 @@ func (s *statsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) con return ctx } -// HandleRPC implements per-Connection stats reporting. +// HandleConn implements per-Connection stats reporting. func (s *statsHandler) HandleConn(ctx context.Context, connStats stats.ConnStats) { // client connection stats are monitored by the monitor() function in client.go if connStats.IsClient() { diff --git a/lib/proxy/router.go b/lib/proxy/router.go new file mode 100644 index 0000000000000..6925147d345fe --- /dev/null +++ b/lib/proxy/router.go @@ -0,0 +1,393 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package proxy
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"strconv"
+
+	"github.com/google/uuid"
+	"github.com/gravitational/trace"
+	"github.com/sirupsen/logrus"
+	"go.opentelemetry.io/otel/attribute"
+	oteltrace "go.opentelemetry.io/otel/trace"
+	"golang.org/x/exp/slices"
+
+	"github.com/gravitational/teleport"
+	"github.com/gravitational/teleport/api/observability/tracing"
+	"github.com/gravitational/teleport/api/types"
+	"github.com/gravitational/teleport/lib/defaults"
+	"github.com/gravitational/teleport/lib/reversetunnel"
+	"github.com/gravitational/teleport/lib/services"
+	"github.com/gravitational/teleport/lib/teleagent"
+	"github.com/gravitational/teleport/lib/utils"
+)
+
+type serverResolverFn = func(ctx context.Context, host, port string, site site) (types.Server, error)
+
+// SiteGetter provides access to connected local or remote sites
+type SiteGetter interface {
+	// GetSite returns the site matching the provided clusterName
+	GetSite(clusterName string) (reversetunnel.RemoteSite, error)
+}
+
+// RemoteClusterGetter provides access to remote cluster resources
+type RemoteClusterGetter interface {
+	// GetRemoteCluster returns a remote cluster by name
+	GetRemoteCluster(clusterName string) (types.RemoteCluster, error)
+}
+
+// RouterConfig contains all the dependencies required
+// by the Router
+type RouterConfig struct {
+	// ClusterName indicates which cluster the router is for
+	ClusterName string
+	// Log is the logger to use
+	Log *logrus.Entry
+	// RemoteClusterGetter provides access to remote cluster resources
+	RemoteClusterGetter RemoteClusterGetter
+	// SiteGetter allows looking up sites
+	SiteGetter SiteGetter
+	// TracerProvider allows tracers to be created
+	TracerProvider oteltrace.TracerProvider
+
+	// serverResolver is used to resolve hosts, used by tests
+	serverResolver serverResolverFn
+}
+
+// CheckAndSetDefaults ensures the required items were populated
+func (c *RouterConfig) CheckAndSetDefaults() error {
+	if c.Log == nil {
+		c.Log = logrus.WithField(trace.Component, "Router")
+	}
+
+	if c.ClusterName == "" {
+		return trace.BadParameter("ClusterName must be provided")
+	}
+
+	if c.RemoteClusterGetter == nil {
+		return trace.BadParameter("RemoteClusterGetter must be provided")
+	}
+
+	if c.SiteGetter == nil {
+		return trace.BadParameter("SiteGetter must be provided")
+	}
+
+	if c.TracerProvider == nil {
+		c.TracerProvider = tracing.DefaultProvider()
+	}
+
+	if c.serverResolver == nil {
+		c.serverResolver = getServer
+	}
+
+	return nil
+}
+
+// Router is used by the proxy to establish connections to both
+// nodes and other clusters.
+type Router struct {
+	clusterName    string
+	log            *logrus.Entry
+	clusterGetter  RemoteClusterGetter
+	localSite      reversetunnel.RemoteSite
+	siteGetter     SiteGetter
+	tracer         oteltrace.Tracer
+	serverResolver serverResolverFn
+}
+
+// NewRouter creates and returns a Router that is populated
+// from the provided RouterConfig.
+func NewRouter(cfg RouterConfig) (*Router, error) {
+	if err := cfg.CheckAndSetDefaults(); err != nil {
+		return nil, trace.Wrap(err)
+	}
+
+	localSite, err := cfg.SiteGetter.GetSite(cfg.ClusterName)
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+
+	return &Router{
+		clusterName:    cfg.ClusterName,
+		log:            cfg.Log,
+		clusterGetter:  cfg.RemoteClusterGetter,
+		localSite:      localSite,
+		siteGetter:     cfg.SiteGetter,
+		tracer:         cfg.TracerProvider.Tracer("Router"),
+		serverResolver: cfg.serverResolver,
+	}, nil
+
+}
+
+// DialHost dials the node that matches the provided host, port and cluster. If no matching node
+// is found an error is returned. If more than one matching node is found and the cluster networking
+// configuration is not set to route to the most recent an error is returned.
+func (r *Router) DialHost(ctx context.Context, from net.Addr, host, port, clusterName string, accessChecker services.AccessChecker, agentGetter teleagent.Getter) (net.Conn, error) { + ctx, span := r.tracer.Start( + ctx, + "router/DialHost", + oteltrace.WithAttributes( + attribute.String("host", host), + attribute.String("port", port), + attribute.String("site", clusterName), + ), + ) + defer span.End() + + site := r.localSite + if clusterName != r.clusterName { + remoteSite, err := r.getRemoteCluster(ctx, clusterName, accessChecker) + if err != nil { + return nil, trace.Wrap(err) + } + site = remoteSite + } + + span.AddEvent("looking up server") + target, err := r.serverResolver(ctx, host, port, remoteSite{site}) + if err != nil { + return nil, trace.Wrap(err) + } + span.AddEvent("retrieved target server") + + principals := []string{host} + + var ( + serverID string + serverAddr string + proxyIDs []string + ) + if target != nil { + proxyIDs = target.GetProxyIDs() + serverID = fmt.Sprintf("%v.%v", target.GetName(), clusterName) + + // add hostUUID.cluster to the principals + principals = append(principals, serverID) + + // add ip if it exists to the principals + serverAddr = target.GetAddr() + + switch { + case serverAddr != "": + h, _, err := net.SplitHostPort(serverAddr) + if err != nil { + return nil, trace.Wrap(err) + } + + principals = append(principals, h) + case serverAddr == "" && target.GetUseTunnel(): + serverAddr = reversetunnel.LocalNode + } + } else { + if port == "" || port == "0" { + port = strconv.Itoa(defaults.SSHServerListenPort) + } + + serverAddr = net.JoinHostPort(host, port) + r.log.Warnf("server lookup failed: using default=%v", serverAddr) + } + + conn, err := site.Dial(reversetunnel.DialParams{ + From: from, + To: &utils.NetAddr{AddrNetwork: "tcp", Addr: serverAddr}, + GetUserAgent: agentGetter, + Address: host, + ServerID: serverID, + ProxyIDs: proxyIDs, + Principals: principals, + ConnType: types.NodeTunnel, + }) + + return conn, 
trace.Wrap(err) +} + +// getRemoteCluster looks up the provided clusterName to determine if a remote site exists with +// that name and determines if the user has access to it. +func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, checker services.AccessChecker) (reversetunnel.RemoteSite, error) { + _, span := r.tracer.Start( + ctx, + "router/getRemoteCluster", + oteltrace.WithAttributes( + attribute.String("site", clusterName), + ), + ) + defer span.End() + + site, err := r.siteGetter.GetSite(clusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + rc, err := r.clusterGetter.GetRemoteCluster(clusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + if err := checker.CheckAccessToRemoteCluster(rc); err != nil { + return nil, utils.OpaqueAccessDenied(err) + } + + return site, nil +} + +// site is the minimum interface needed to match servers +// for a reversetunnel.RemoteSite. It makes testing easier. +type site interface { + GetNodes(fn func(n services.Node) bool) ([]types.Server, error) + GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error) +} + +// remoteSite is a site implementation that wraps +// a reversetunnel.RemoteSite +type remoteSite struct { + site reversetunnel.RemoteSite +} + +// GetNodes uses the wrapped sites NodeWatcher to filter nodes +func (r remoteSite) GetNodes(fn func(n services.Node) bool) ([]types.Server, error) { + watcher, err := r.site.NodeWatcher() + if err != nil { + return nil, trace.Wrap(err) + } + + return watcher.GetNodes(fn), nil +} + +// GetClusterNetworkingConfig uses the wrapped sites cache to retrieve the ClusterNetworkingConfig +func (r remoteSite) GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error) { + ap, err := r.site.CachingAccessPoint() + if err != nil { + return nil, trace.Wrap(err) + } + + cfg, err := ap.GetClusterNetworkingConfig(ctx, opts...) 
+ return cfg, trace.Wrap(err) +} + +// getServer attempts to locate a node matching the provided host and port in +// the provided site. +func getServer(ctx context.Context, host, port string, site site) (types.Server, error) { + if site == nil { + return nil, trace.BadParameter("invalid remote site provided") + } + + strategy := types.RoutingStrategy_UNAMBIGUOUS_MATCH + if cfg, err := site.GetClusterNetworkingConfig(ctx); err == nil { + strategy = cfg.GetRoutingStrategy() + } + + _, err := uuid.Parse(host) + dialByID := err == nil || utils.IsEC2NodeID(host) + + ips, _ := net.LookupHost(host) + + var unambiguousIDMatch bool + matches, err := site.GetNodes(func(server services.Node) bool { + if unambiguousIDMatch { + return false + } + + // if host is a UUID or EC2 ID match only + // by server name and treat matches as unambiguous + if dialByID && server.GetName() == host { + unambiguousIDMatch = true + return true + } + + // if the server has connected over a reverse tunnel + // then match only by hostname + if server.GetUseTunnel() { + return host == server.GetHostname() + } + + ip, nodePort, err := net.SplitHostPort(server.GetAddr()) + if err != nil { + return false + } + + if (host == ip || host == server.GetHostname() || slices.Contains(ips, ip)) && + (port == "" || port == "0" || port == nodePort) { + return true + } + + return false + }) + if err != nil { + return nil, trace.Wrap(err) + } + + var server types.Server + switch { + case strategy == types.RoutingStrategy_MOST_RECENT: + for _, m := range matches { + if server == nil || m.Expiry().After(server.Expiry()) { + server = m + } + } + case len(matches) > 1: + return nil, trace.NotFound(teleport.NodeIsAmbiguous) + case len(matches) == 1: + server = matches[0] + } + + if dialByID && server == nil { + idType := "UUID" + if utils.IsEC2NodeID(host) { + idType = "EC2" + } + + return nil, trace.NotFound("unable to locate node matching %s-like target %s", idType, host) + } + + return server, nil + +} + +// 
DialSite establishes a connection to the auth server in the provided +// cluster. If the clusterName is an empty string then a connection to +// the local auth server will be established. +func (r *Router) DialSite(ctx context.Context, clusterName string) (net.Conn, error) { + _, span := r.tracer.Start( + ctx, + "router/DialSite", + oteltrace.WithAttributes( + attribute.String("site", clusterName), + ), + ) + defer span.End() + + // default to local cluster if one wasn't provided + if clusterName == "" { + clusterName = r.clusterName + } + + // dial the local auth server + if clusterName == r.clusterName { + conn, err := r.localSite.DialAuthServer() + return conn, trace.Wrap(err) + } + + // lookup the site and dial its auth server + site, err := r.siteGetter.GetSite(clusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + conn, err := site.DialAuthServer() + return conn, trace.Wrap(err) +} diff --git a/lib/proxy/router_test.go b/lib/proxy/router_test.go new file mode 100644 index 0000000000000..019e665517b8d --- /dev/null +++ b/lib/proxy/router_test.go @@ -0,0 +1,509 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package proxy + +import ( + "context" + "net" + "testing" + + "github.com/google/uuid" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/observability/tracing" + "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" +) + +type testSite struct { + cfg types.ClusterNetworkingConfig + nodes []types.Server +} + +func (t testSite) GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error) { + return t.cfg, nil +} + +func (t testSite) GetNodes(fn func(n services.Node) bool) ([]types.Server, error) { + var out []types.Server + for _, s := range t.nodes { + if fn(s) { + out = append(out, s) + } + } + + return out, nil +} + +type server struct { + name string + hostname string + addr string + tunnel bool +} + +func createServers(srvs []server) []types.Server { + out := make([]types.Server, 0, len(srvs)) + for _, s := range srvs { + srv := &types.ServerV2{ + Kind: types.KindNode, + Version: types.V2, + Metadata: types.Metadata{ + Name: s.name, + }, + Spec: types.ServerSpecV2{ + Addr: s.addr, + Hostname: s.hostname, + UseTunnel: s.tunnel, + }, + } + out = append(out, srv) + } + + return out +} + +func TestGetServers(t *testing.T) { + t.Parallel() + + mostRecentCfg := types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + RoutingStrategy: types.RoutingStrategy_MOST_RECENT, + }, + } + + unambiguousCfg := types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + RoutingStrategy: types.RoutingStrategy_UNAMBIGUOUS_MATCH, + }, + } + + hostID := uuid.NewString() + const ec2ID = "012345678901-i-01234567890abcdef" + + servers := createServers([]server{ + { + name: hostID, + hostname: "llama", + addr: "llama:123", + }, + { + 
name: "llama", + hostname: "llama", + addr: "llama:123", + tunnel: true, + }, + { + name: "llama", + hostname: hostID, + addr: "llama:123", + }, + { + name: ec2ID, + hostname: "node.aws", + addr: "node.aws:123", + }, + { + name: "node.aws", + hostname: "node.aws", + addr: "node.aws:123", + tunnel: true, + }, + { + name: "node.aws", + hostname: ec2ID, + addr: "node.aws:123", + }, + { + name: "alpaca", + hostname: "alpaca", + addr: "alpaca:123", + tunnel: true, + }, + { + name: "alpaca", + hostname: "localhost", + addr: "alpaca:987", + tunnel: true, + }, + { + name: "goat", + hostname: "goat", + addr: "1.2.3.4:123", + }, + { + name: "sheep", + hostname: "sheep", + addr: "sheep.bah:0", + }, + { + name: "sheep2", + hostname: "sheep", + addr: "sheep.bah:0", + }, + { + name: "lion", + hostname: "lion", + addr: "lion.roar", + }, + }) + + cases := []struct { + name string + host string + port string + site testSite + errAssertion require.ErrorAssertionFunc + serverAssertion func(t *testing.T, srv types.Server) + }{ + { + name: "no matches for hostname", + site: testSite{cfg: &unambiguousCfg}, + host: "test", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + { + name: "no matches for uuid", + site: testSite{cfg: &mostRecentCfg}, + host: uuid.NewString(), + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.True(t, trace.IsNotFound(err), i...) + }, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + { + name: "no matches for ec2 id", + site: testSite{cfg: &unambiguousCfg}, + host: "123456789012-i-1234567890abcdef0", + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.True(t, trace.IsNotFound(err), i...) 
+ }, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + { + name: "ambiguous match fails", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: "sheep", + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorIs(t, err, trace.NotFound(teleport.NodeIsAmbiguous)) + }, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + { + name: "ambiguous match returns most recent", + site: testSite{cfg: &mostRecentCfg, nodes: servers}, + host: "sheep", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "sheep", srv.GetHostname()) + }, + }, + { + name: "match by uuid", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: hostID, + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "llama", srv.GetHostname()) + }, + }, + { + name: "match by ec2 id", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: ec2ID, + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "node.aws", srv.GetHostname()) + }, + }, + { + name: "match by ip", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: "1.2.3.4", + port: "123", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "goat", srv.GetHostname()) + }, + }, + { + name: "match by host only for tunnels", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: "alpaca", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.Equal(t, "alpaca", srv.GetHostname()) + }, + }, + { + name: "failure on invalid addresses", + site: testSite{cfg: &unambiguousCfg, nodes: servers}, + host: "lion", + 
errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Empty(t, srv) + }, + }, + } + + ctx := context.Background() + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + srv, err := getServer(ctx, tt.host, tt.port, tt.site) + tt.errAssertion(t, err) + tt.serverAssertion(t, srv) + }) + } +} + +func serverResolver(srv types.Server, err error) serverResolverFn { + return func(ctx context.Context, host, port string, site site) (types.Server, error) { + return srv, err + } +} + +type tunnel struct { + reversetunnel.Tunnel + + site reversetunnel.RemoteSite + err error +} + +func (t tunnel) GetSite(cluster string) (reversetunnel.RemoteSite, error) { + return t.site, t.err +} + +type testRemoteSite struct { + reversetunnel.RemoteSite + conn net.Conn + err error +} + +func (r testRemoteSite) Dial(reversetunnel.DialParams) (net.Conn, error) { + return r.conn, r.err +} + +func (r testRemoteSite) DialAuthServer() (net.Conn, error) { + return r.conn, r.err +} + +type fakeConn struct { + net.Conn +} + +func TestRouter_DialHost(t *testing.T) { + t.Parallel() + + logger := utils.NewLoggerForTests().WithField(trace.Component, "test") + + srv := &types.ServerV2{ + Kind: types.KindNode, + Version: types.V2, + Metadata: types.Metadata{ + Name: uuid.NewString(), + }, + Spec: types.ServerSpecV2{ + Addr: "127.0.0.1:8889", + Hostname: "test", + }, + } + + cases := []struct { + name string + router Router + assertion func(t *testing.T, conn net.Conn, err error) + }{ + { + name: "failure resolving node", + router: Router{ + clusterName: "test", + log: logger, + tracer: tracing.NoopTracer("test"), + serverResolver: serverResolver(nil, trace.NotFound(teleport.NodeIsAmbiguous)), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.Nil(t, conn) + }, + }, + { + name: "failure looking up cluster", + router: Router{ + clusterName: "leaf", + siteGetter: tunnel{err: trace.NotFound("unknown 
cluster")}, + log: logger, + tracer: tracing.NoopTracer("test"), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsNotFound(err)) + require.Nil(t, conn) + }, + }, + { + name: "dial failure", + router: Router{ + clusterName: "test", + log: logger, + localSite: &testRemoteSite{err: trace.ConnectionProblem(context.DeadlineExceeded, "connection refused")}, + tracer: tracing.NoopTracer("test"), + serverResolver: serverResolver(srv, nil), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsConnectionProblem(err)) + require.Nil(t, conn) + }, + }, + { + name: "dial success", + router: Router{ + clusterName: "test", + log: logger, + localSite: &testRemoteSite{conn: fakeConn{}}, + tracer: tracing.NoopTracer("test"), + serverResolver: serverResolver(srv, nil), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.NoError(t, err) + require.NotNil(t, conn) + }, + }, + } + + ctx := context.Background() + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + conn, err := tt.router.DialHost(ctx, &utils.NetAddr{}, "host", "0", "test", nil, nil) + tt.assertion(t, conn, err) + }) + } + +} + +func TestRouter_DialSite(t *testing.T) { + t.Parallel() + + const cluster = "test" + logger := utils.NewLoggerForTests().WithField(trace.Component, cluster) + + cases := []struct { + name string + cluster string + localSite testRemoteSite + tunnel tunnel + assertion func(t *testing.T, conn net.Conn, err error) + }{ + { + name: "failure to dial local site", + cluster: cluster, + localSite: testRemoteSite{err: trace.ConnectionProblem(context.DeadlineExceeded, "connection refused")}, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsConnectionProblem(err)) + require.Nil(t, conn) + }, + }, + { + name: "successfully dial local site", + cluster: cluster, + localSite: 
testRemoteSite{conn: fakeConn{}}, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.NoError(t, err) + require.NotNil(t, conn) + }, + }, + + { + name: "default to dialing local site", + localSite: testRemoteSite{conn: fakeConn{}}, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.NoError(t, err) + require.NotNil(t, conn) + }, + }, + { + name: "failure to dial remote site", + cluster: "leaf", + tunnel: tunnel{ + site: testRemoteSite{err: trace.ConnectionProblem(context.DeadlineExceeded, "connection refused")}, + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsConnectionProblem(err)) + require.Nil(t, conn) + }, + }, + { + name: "unknown cluster", + cluster: "fake", + tunnel: tunnel{ + err: trace.NotFound("unknown cluster"), + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.Error(t, err) + require.True(t, trace.IsNotFound(err)) + require.Nil(t, conn) + }, + }, + { + name: "successfully dial remote site", + cluster: "leaf", + tunnel: tunnel{ + site: testRemoteSite{conn: fakeConn{}}, + }, + assertion: func(t *testing.T, conn net.Conn, err error) { + require.NoError(t, err) + require.NotNil(t, conn) + }, + }, + } + + ctx := context.Background() + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + router := Router{ + clusterName: cluster, + log: logger, + localSite: &tt.localSite, + siteGetter: tt.tunnel, + tracer: tracing.NoopTracer(cluster), + } + + conn, err := router.DialSite(ctx, tt.cluster) + tt.assertion(t, conn, err) + }) + } +} diff --git a/lib/reversetunnel/api.go b/lib/reversetunnel/api.go index 2866533ac40f5..d7886fe7ee002 100644 --- a/lib/reversetunnel/api.go +++ b/lib/reversetunnel/api.go @@ -25,7 +25,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/proxy" + "github.com/gravitational/teleport/lib/proxy/peer" 
"github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleagent" ) @@ -140,7 +140,7 @@ type Server interface { // Wait waits for server to close all outstanding operations Wait() // GetProxyPeerClient returns the proxy peer client - GetProxyPeerClient() *proxy.Client + GetProxyPeerClient() *peer.Client } const ( diff --git a/lib/reversetunnel/conn_metric.go b/lib/reversetunnel/conn_metric.go index 79600bcd3fffe..10a49fa697b64 100644 --- a/lib/reversetunnel/conn_metric.go +++ b/lib/reversetunnel/conn_metric.go @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package reversetunnel import ( @@ -27,16 +28,16 @@ import ( type dialType string const ( - // direct is a direct dialed connection. - direct dialType = "direct" - // peer is a connection established through a peer proxy. - peer dialType = "peer" - // tunnel is a connection established over a local reverse tunnel initiated + // dialTypeDirect is a direct dialed connection. + dialTypeDirect dialType = "direct" + // dialTypePeer is a connection established through a peer proxy. + dialTypePeer dialType = "peer" + // dialTypeTunnel is a connection established over a local reverse tunnel initiated // by a client. - tunnel dialType = "tunnel" - // peerTunnel is a connection established over a local reverse tunnel + dialTypeTunnel dialType = "tunnel" + // dialTypePeerTunnel is a connection established over a local reverse tunnel // initiated by a peer proxy. - peerTunnel dialType = "peer-tunnel" + dialTypePeerTunnel dialType = "peer-tunnel" ) // metricConn reports metrics for reversetunnel connections. 
diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index c8414d6cb53ae..40ee4d97fc3f6 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -23,24 +23,24 @@ import ( "sync" "time" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/proxy" + "github.com/gravitational/teleport/lib/proxy/peer" "github.com/gravitational/teleport/lib/reversetunnel/track" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/forward" "github.com/gravitational/teleport/lib/utils" proxyutils "github.com/gravitational/teleport/lib/utils/proxy" - - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" - "golang.org/x/crypto/ssh" ) const ( @@ -149,7 +149,7 @@ type localSite struct { offlineThreshold time.Duration // peerClient is the proxy peering client - peerClient *proxy.Client + peerClient *peer.Client // periodicFunctionInterval defines the interval period functions run at periodicFunctionInterval time.Duration @@ -446,9 +446,9 @@ func (s *localSite) getConn(params DialParams) (conn net.Conn, useTunnel bool, e // return a tunnel connection to that node. Otherwise net.Dial to the target host. 
conn, tunnelErr = s.dialTunnel(dreq) if tunnelErr == nil { - dt := tunnel + dt := dialTypeTunnel if params.FromPeerProxy { - dt = peerTunnel + dt = dialTypePeerTunnel } return newMetricConn(conn, dt, dialStart, s.srv.Clock), true, nil @@ -461,7 +461,7 @@ func (s *localSite) getConn(params DialParams) (conn net.Conn, useTunnel bool, e params.ProxyIDs, params.ServerID, params.From, params.To, params.ConnType, ) if peerErr == nil { - return newMetricConn(conn, peer, dialStart, s.srv.Clock), true, nil + return newMetricConn(conn, dialTypePeer, dialStart, s.srv.Clock), true, nil } s.log.WithError(peerErr).WithField("address", dreq.Address).Debug("Error occurred while dialing over peer proxy.") } @@ -493,7 +493,7 @@ func (s *localSite) getConn(params DialParams) (conn net.Conn, useTunnel bool, e } // Return a direct dialed connection. - return newMetricConn(conn, direct, dialStart, s.srv.Clock), false, nil + return newMetricConn(conn, dialTypeDirect, dialStart, s.srv.Clock), false, nil } func (s *localSite) addConn(nodeID string, connType types.TunnelType, conn net.Conn, sconn ssh.Conn) (*remoteConn, error) { diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 08c2f3e57f197..8100b9ab9f66e 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -26,6 +26,12 @@ import ( "sync" "time" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/breaker" "github.com/gravitational/teleport/api/constants" @@ -35,17 +41,11 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/limiter" - "github.com/gravitational/teleport/lib/proxy" + "github.com/gravitational/teleport/lib/proxy/peer" "github.com/gravitational/teleport/lib/services" 
"github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" - - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" - "golang.org/x/crypto/ssh" ) var ( @@ -189,7 +189,7 @@ type Config struct { NewCachingAccessPointOldProxy auth.NewRemoteProxyCachingAccessPoint // PeerClient is a client to peer proxy servers. - PeerClient *proxy.Client + PeerClient *peer.Client // LockWatcher is a lock watcher. LockWatcher *services.LockWatcher @@ -1002,7 +1002,7 @@ func (s *server) GetSite(name string) (RemoteSite, error) { } // GetProxyPeerClient returns the proxy peer client -func (s *server) GetProxyPeerClient() *proxy.Client { +func (s *server) GetProxyPeerClient() *peer.Client { return s.PeerClient } diff --git a/lib/service/service.go b/lib/service/service.go index 5ea4a42b717da..39725ce8b92b5 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -96,6 +96,7 @@ import ( "github.com/gravitational/teleport/lib/plugin" "github.com/gravitational/teleport/lib/proxy" "github.com/gravitational/teleport/lib/proxy/clusterdial" + "github.com/gravitational/teleport/lib/proxy/peer" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" @@ -2304,7 +2305,29 @@ func (process *TeleportProcess) initSSH() error { storagePresence := local.NewPresenceService(process.storage) - s, err := regular.New(cfg.SSH.Addr, + // read the host UUID: + serverID, err := utils.ReadOrMakeHostUUID(cfg.DataDir) + if err != nil { + return trace.Wrap(err) + } + + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: authClient, + AccessPoint: authClient, + LockEnforcer: lockWatcher, + Emitter: &events.StreamerAndEmitter{Emitter: asyncEmitter, Streamer: streamer}, + 
Component: teleport.ComponentNode, + Logger: process.log.WithField(trace.Component, "sessionctrl"), + TracerProvider: process.TracingProvider, + ServerID: serverID, + }) + if err != nil { + return trace.Wrap(err) + } + + s, err := regular.New( + process.ExitContext(), + cfg.SSH.Addr, cfg.Hostname, []ssh.Signer{conn.ServerIdentity.KeySigner}, authClient, @@ -2337,6 +2360,8 @@ func (process *TeleportProcess) initSSH() error { regular.SetCreateHostUser(!cfg.SSH.DisableCreateHostUser), regular.SetStoragePresenceService(storagePresence), regular.SetInventoryControlHandle(process.inventoryHandle), + regular.SetTracerProvider(process.TracingProvider), + regular.SetSessionController(sessionController), ) if err != nil { return trace.Wrap(err) @@ -3372,11 +3397,11 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { // register SSH reverse tunnel server that accepts connections // from remote teleport nodes var tsrv reversetunnel.Server - var peerClient *proxy.Client + var peerClient *peer.Client if !process.Config.Proxy.DisableReverseTunnel { if listeners.proxy != nil { - peerClient, err = proxy.NewClient(proxy.ClientConfig{ + peerClient, err = peer.NewClient(peer.ClientConfig{ Context: process.ExitContext(), ID: process.Config.HostUUID, AuthClient: conn.Client, @@ -3437,6 +3462,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return nil }) } + if !process.Config.Proxy.DisableTLS { tlsConfigWeb, err = process.setupProxyTLSConfig(conn, tsrv, accessPoint, clusterName) if err != nil { @@ -3556,10 +3582,10 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } var peerAddr string - var proxyServer *proxy.Server + var proxyServer *peer.Server if !process.Config.Proxy.DisableReverseTunnel && listeners.proxy != nil { peerAddr = listeners.proxy.Addr().String() - proxyServer, err = proxy.NewServer(proxy.ServerConfig{ + proxyServer, err = peer.NewServer(peer.ServerConfig{ AccessCache: accessPoint, 
Listener: listeners.proxy, TLSConfig: serverTLSConfig, @@ -3586,7 +3612,44 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { }) } - sshProxy, err := regular.New(cfg.Proxy.SSHAddr, + var proxyRouter *proxy.Router + if !process.Config.Proxy.DisableReverseTunnel { + router, err := proxy.NewRouter(proxy.RouterConfig{ + ClusterName: clusterName, + Log: process.log.WithField(trace.Component, "router"), + RemoteClusterGetter: accessPoint, + SiteGetter: tsrv, + TracerProvider: process.TracingProvider, + }) + if err != nil { + return trace.Wrap(err) + } + proxyRouter = router + } + + // read the host UUID: + serverID, err := utils.ReadOrMakeHostUUID(cfg.DataDir) + if err != nil { + return trace.Wrap(err) + } + + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: accessPoint, + AccessPoint: accessPoint, + LockEnforcer: lockWatcher, + Emitter: &events.StreamerAndEmitter{Emitter: asyncEmitter, Streamer: streamer}, + Component: teleport.ComponentProxy, + Logger: process.log.WithField(trace.Component, "sessionctrl"), + TracerProvider: process.TracingProvider, + ServerID: serverID, + }) + if err != nil { + return trace.Wrap(err) + } + + sshProxy, err := regular.New( + process.ExitContext(), + cfg.SSH.Addr, cfg.Hostname, []ssh.Signer{conn.ServerIdentity.KeySigner}, accessPoint, @@ -3595,7 +3658,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { process.proxyPublicAddr(), conn.Client, regular.SetLimiter(proxyLimiter), - regular.SetProxyMode(peerAddr, tsrv, accessPoint), + regular.SetProxyMode(peerAddr, tsrv, accessPoint, proxyRouter), regular.SetSessionServer(conn.Client), regular.SetCiphers(cfg.Ciphers), regular.SetKEXAlgorithms(cfg.KEXAlgorithms), @@ -3611,6 +3674,8 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { // accurately checked later when an SCP/SFTP request hits the // destination Node. 
regular.SetAllowFileCopying(true), + regular.SetTracerProvider(process.TracingProvider), + regular.SetSessionController(sessionController), ) if err != nil { return trace.Wrap(err) diff --git a/lib/services/semaphore.go b/lib/services/semaphore.go index 54108ed41a012..211e5ca3167b2 100644 --- a/lib/services/semaphore.go +++ b/lib/services/semaphore.go @@ -19,12 +19,13 @@ import ( "sync" "time" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + log "github.com/sirupsen/logrus" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/utils" - - "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" ) type SemaphoreLockConfig struct { @@ -38,10 +39,16 @@ type SemaphoreLockConfig struct { TickRate time.Duration // Params holds the semaphore lease acquisition parameters. Params types.AcquireSemaphoreRequest + // Clock used to alter time in tests + Clock clockwork.Clock } // CheckAndSetDefaults checks and sets default parameters func (l *SemaphoreLockConfig) CheckAndSetDefaults() error { + if l.Clock == nil { + l.Clock = clockwork.NewRealClock() + } + if l.Service == nil { return trace.BadParameter("missing semaphore service") } @@ -58,7 +65,7 @@ func (l *SemaphoreLockConfig) CheckAndSetDefaults() error { return trace.BadParameter("tick-rate must be less than expiry") } if l.Params.Expires.IsZero() { - l.Params.Expires = time.Now().UTC().Add(l.Expiry) + l.Params.Expires = l.Clock.Now().UTC().Add(l.Expiry) } if err := l.Params.Check(); err != nil { return trace.Wrap(err) @@ -72,7 +79,7 @@ type SemaphoreLock struct { cfg SemaphoreLockConfig lease0 types.SemaphoreLease retry utils.Retry - ticker *time.Ticker + ticker clockwork.Ticker doneC chan struct{} closeOnce sync.Once renewalC chan struct{} @@ -139,7 +146,7 @@ func (l *SemaphoreLock) keepAlive(ctx context.Context) { // cancellation/expiry. 
return } - if lease.Expires.After(time.Now().UTC()) { + if lease.Expires.After(l.cfg.Clock.Now().UTC()) { // parent context is closed. create orphan context with generous // timeout for lease cancellation scope. this will not block any // caller that is not explicitly waiting on the final error value. @@ -156,7 +163,7 @@ func (l *SemaphoreLock) keepAlive(ctx context.Context) { Outer: for { select { - case tick := <-l.ticker.C: + case tick := <-l.ticker.Chan(): leaseContext, leaseCancel := context.WithDeadline(ctx, lease.Expires) nextLease := lease nextLease.Expires = tick.Add(l.cfg.Expiry) @@ -184,7 +191,7 @@ Outer: l.retry.Inc() select { case <-l.retry.After(): - case tick = <-l.ticker.C: + case tick = <-l.ticker.Chan(): // check to make sure that we still have some time on the lease. the default tick rate would have // us waking _as_ the lease expires here, but if we're working with a higher tick rate, its worth // retrying again. @@ -235,7 +242,7 @@ func AcquireSemaphoreWithRetry(ctx context.Context, req AcquireSemaphoreWithRetr } // AcquireSemaphoreLock attempts to acquire and hold a semaphore lease. If successfully acquired, -// background keepalive processes are started and an associated lock handle is returned. Cancelling +// background keepalive processes are started and an associated lock handle is returned. Canceling // the supplied context releases the semaphore. 
func AcquireSemaphoreLock(ctx context.Context, cfg SemaphoreLockConfig) (*SemaphoreLock, error) { if err := cfg.CheckAndSetDefaults(); err != nil { @@ -246,6 +253,7 @@ func AcquireSemaphoreLock(ctx context.Context, cfg SemaphoreLockConfig) (*Semaph Max: cfg.Expiry / 4, Step: cfg.Expiry / 16, Jitter: utils.NewJitter(), + Clock: cfg.Clock, }) if err != nil { return nil, trace.Wrap(err) @@ -258,7 +266,7 @@ func AcquireSemaphoreLock(ctx context.Context, cfg SemaphoreLockConfig) (*Semaph cfg: cfg, lease0: *lease, retry: retry, - ticker: time.NewTicker(cfg.TickRate), + ticker: cfg.Clock.NewTicker(cfg.TickRate), doneC: make(chan struct{}), renewalC: make(chan struct{}), cond: sync.NewCond(&sync.Mutex{}), diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 8af6e97be61c0..cb51826264b79 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -439,13 +439,15 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s trace.ComponentFields: fields, }) - lockTargets, err := ComputeLockTargets(srv, identityContext) + clusterName, err := srv.GetAccessPoint().GetClusterName() if err != nil { - return nil, nil, trace.Wrap(err) + childErr := child.Close() + return nil, nil, trace.NewAggregate(err, childErr) } + monitorConfig := MonitorConfig{ LockWatcher: child.srv.GetLockWatcher(), - LockTargets: lockTargets, + LockTargets: ComputeLockTargets(clusterName.GetClusterName(), srv.HostUUID(), identityContext), LockingMode: identityContext.AccessChecker.LockingMode(authPref.GetLockingMode()), DisconnectExpiredCert: child.disconnectExpiredCert, ClientIdleTimeout: child.clientIdleTimeout, @@ -1151,28 +1153,19 @@ func newUaccMetadata(c *ServerContext) (*UaccMetadata, error) { }, nil } -// ComputeLockTargets computes lock targets inferred from a Server -// and an IdentityContext. 
-func ComputeLockTargets(s Server, id IdentityContext) ([]types.LockTarget, error) { - clusterName, err := s.GetAccessPoint().GetClusterName() - if err != nil { - return nil, trace.Wrap(err) - } +// ComputeLockTargets computes lock targets inferred from the clusterName, serverID and IdentityContext. +func ComputeLockTargets(clusterName, serverID string, id IdentityContext) []types.LockTarget { lockTargets := []types.LockTarget{ {User: id.TeleportUser}, {Login: id.Login}, - {Node: s.HostUUID()}, - {Node: auth.HostFQDN(s.HostUUID(), clusterName.GetClusterName())}, + {Node: serverID}, + {Node: auth.HostFQDN(serverID, clusterName)}, {MFADevice: id.Certificate.Extensions[teleport.CertExtensionMFAVerified]}, } roles := apiutils.Deduplicate(append(id.AccessChecker.RoleNames(), id.UnmappedRoles...)) - lockTargets = append(lockTargets, - services.RolesToLockTargets(roles)..., - ) - lockTargets = append(lockTargets, - services.AccessRequestsToLockTargets(id.ActiveRequests)..., - ) - return lockTargets, nil + lockTargets = append(lockTargets, services.RolesToLockTargets(roles)...) + lockTargets = append(lockTargets, services.AccessRequestsToLockTargets(id.ActiveRequests)...) + return lockTargets } // SetRequest sets the ssh request that was issued by the client. 
diff --git a/lib/srv/regular/proxy.go b/lib/srv/regular/proxy.go index 0d7f073bded01..a13bf5113c60a 100644 --- a/lib/srv/regular/proxy.go +++ b/lib/srv/regular/proxy.go @@ -23,29 +23,21 @@ import ( "fmt" "io" "net" - "strconv" "strings" - "sync" + "github.com/gravitational/trace" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/observability/tracing" - "github.com/gravitational/teleport/api/types" - apiutils "github.com/gravitational/teleport/api/utils" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" - "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/reversetunnel" - "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/proxy" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" - - "github.com/google/uuid" - "github.com/gravitational/trace" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" ) var ( // failedConnectingToNode counts failed attempts to connect to a node @@ -70,30 +62,31 @@ var ( // failedConnectingToNode counts failed attempts to connect to a node Help: "Number of SSH connection attempts to a node. 
Use with `failed_connect_to_node_attempts_total` to get the failure rate.", }, ) - - prometheusCollectors = []prometheus.Collector{proxiedSessions, failedConnectingToNode, connectingToNode} ) +func init() { + utils.RegisterPrometheusCollectors(proxiedSessions, failedConnectingToNode, connectingToNode) +} + // proxySubsys implements an SSH subsystem for proxying listening sockets from // remote hosts to a proxy client (AKA port mapping) type proxySubsys struct { proxySubsysRequest - srv *Server - ctx *srv.ServerContext - log *logrus.Entry - closeC chan struct{} - error error - closeOnce sync.Once + router *proxy.Router + ctx *srv.ServerContext + log *logrus.Entry + closeC chan error } // parseProxySubsys looks at the requested subsystem name and returns a fully configured // proxy subsystem // // proxy subsystem name can take the following forms: -// "proxy:host:22" - standard SSH request to connect to host:22 on the 1st cluster -// "proxy:@clustername" - Teleport request to connect to an auth server for cluster with name 'clustername' -// "proxy:host:22@clustername" - Teleport request to connect to host:22 on cluster 'clustername' -// "proxy:host:22@namespace@clustername" +// +// "proxy:host:22" - standard SSH request to connect to host:22 on the 1st cluster +// "proxy:@clustername" - Teleport request to connect to an auth server for cluster with name 'clustername' +// "proxy:host:22@clustername" - Teleport request to connect to host:22 on cluster 'clustername' +// "proxy:host:22@namespace@clustername" func parseProxySubsysRequest(request string) (proxySubsysRequest, error) { log.Debugf("parse_proxy_subsys(%q)", request) var ( @@ -109,8 +102,9 @@ func parseProxySubsysRequest(request string) (proxySubsysRequest, error) { } requestBody := strings.TrimPrefix(request, prefix) namespace := apidefaults.Namespace - var err error parts := strings.Split(requestBody, "@") + + var err error switch { case len(parts) == 0: // "proxy:" return proxySubsysRequest{}, 
trace.BadParameter(paramMessage) @@ -189,11 +183,6 @@ func (p *proxySubsysRequest) SetDefaults() { // a port forwarding request, used to implement ProxyJump feature in proxy // and reuse the code func newProxySubsys(ctx *srv.ServerContext, srv *Server, req proxySubsysRequest) (*proxySubsys, error) { - err := utils.RegisterPrometheusCollectors(prometheusCollectors...) - if err != nil { - return nil, trace.Wrap(err) - } - req.SetDefaults() if req.clusterName == "" && ctx.Identity.RouteToCluster != "" { log.Debugf("Proxy subsystem: routing user %q to cluster %q based on the route to cluster extension.", @@ -211,12 +200,12 @@ func newProxySubsys(ctx *srv.ServerContext, srv *Server, req proxySubsysRequest) return &proxySubsys{ proxySubsysRequest: req, ctx: ctx, - srv: srv, log: logrus.WithFields(logrus.Fields{ trace.Component: teleport.ComponentSubsystemProxy, trace.ComponentFields: map[string]string{}, }), - closeC: make(chan struct{}), + closeC: make(chan error), + router: srv.router, }, nil } @@ -238,12 +227,8 @@ func (t *proxySubsys) Start(ctx context.Context, sconn *ssh.ServerConn, ch ssh.C }) t.log.Debugf("Starting subsystem") - var ( - site reversetunnel.RemoteSite - err error - tunnel = t.srv.tunnelWithAccessChecker(serverContext) - clientAddr = sconn.RemoteAddr() - ) + clientAddr := sconn.RemoteAddr() + // did the client pass us a true client IP ahead of time via an environment variable? // (usually the web client would do that) trueClientIP, ok := serverContext.GetEnv(sshutils.TrueClientAddrVar) @@ -253,185 +238,41 @@ func (t *proxySubsys) Start(ctx context.Context, sconn *ssh.ServerConn, ch ssh.C clientAddr = a } } - // get the cluster by name: - if t.clusterName != "" { - site, err = tunnel.GetSite(t.clusterName) - if err != nil { - t.log.Warn(err) - return trace.Wrap(err) - } - } - // connecting to a specific host: - if t.host != "" { - // no site given? 
use the 1st one: - if site == nil { - sites, err := tunnel.GetSites() - if err != nil { - return trace.Wrap(err) - } - if len(sites) == 0 { - t.log.Error("Not connected to any remote clusters") - return trace.NotFound("no connected sites") - } - site = sites[0] - t.clusterName = site.GetName() - t.log.Debugf("Cluster not specified. connecting to default='%s'", site.GetName()) - } - return t.proxyToHost(ctx, site, clientAddr, ch) + + // connect to a site's auth server + if t.host == "" { + return t.proxyToSite(ctx, ch, t.clusterName) } - // connect to a site's auth server: - return t.proxyToSite(serverContext, site, clientAddr, ch) + + // connect to a server + return t.proxyToHost(ctx, ch, clientAddr) } // proxyToSite establishes a proxy connection from the connected SSH client to the // auth server of the requested remote site -func (t *proxySubsys) proxyToSite( - ctx *srv.ServerContext, site reversetunnel.RemoteSite, remoteAddr net.Addr, ch ssh.Channel) error { - conn, err := site.DialAuthServer() +func (t *proxySubsys) proxyToSite(ctx context.Context, ch ssh.Channel, clusterName string) error { + t.log.Debugf("Connecting to site: %v", clusterName) + + conn, err := t.router.DialSite(ctx, clusterName) if err != nil { return trace.Wrap(err) } - t.log.Infof("Connected to auth server: %v", conn.RemoteAddr()) + t.log.Infof("Connected to cluster %v at %v", clusterName, conn.RemoteAddr()) proxiedSessions.Inc() + go func() { - var err error - defer func() { - t.close(err) - }() - defer ch.Close() - _, err = io.Copy(ch, conn) - }() - go func() { - var err error - defer func() { - t.close(err) - }() - defer conn.Close() - _, err = io.Copy(conn, ch) + t.close(utils.ProxyConn(ctx, ch, conn)) }() - return nil } // proxyToHost establishes a proxy connection from the connected SSH client to the // requested remote node (t.host:t.port) via the given site -func (t *proxySubsys) proxyToHost( - ctx context.Context, site reversetunnel.RemoteSite, remoteAddr net.Addr, ch ssh.Channel) 
error { - // - // first, lets fetch a list of servers at the given site. this allows us to - // match the given "host name" against node configuration (their 'nodename' setting) - // - // but failing to fetch the list of servers is also OK, we'll use standard - // network resolution (by IP or DNS) - // - var ( - strategy types.RoutingStrategy - nodeWatcher NodesGetter - err error - ) - localCluster, _ := t.srv.proxyAccessPoint.GetClusterName() - // going to "local" CA? lets use the caching 'auth service' directly and avoid - // hitting the reverse tunnel link (it can be offline if the CA is down) - if site.GetName() == localCluster.GetName() { - nodeWatcher = t.srv.nodeWatcher - - cfg, err := t.srv.authService.GetClusterNetworkingConfig(ctx) - if err != nil { - t.log.Warn(err) - } else { - strategy = cfg.GetRoutingStrategy() - } - } else { - // "remote" CA? use a reverse tunnel to talk to it: - siteClient, err := site.CachingAccessPoint() - if err != nil { - t.log.Warn(err) - } else { - watcher, err := site.NodeWatcher() - if err != nil { - t.log.Warn(err) - } else { - nodeWatcher = watcher - } - - cfg, err := siteClient.GetClusterNetworkingConfig(ctx) - if err != nil { - t.log.Warn(err) - } else { - strategy = cfg.GetRoutingStrategy() - } - } - } - - // if port is 0, it means the client wants us to figure out - // which port to use - t.log.Debugf("proxy connecting to host=%v port=%v, exact port=%v, strategy=%s", t.host, t.port, t.SpecifiedPort(), strategy) +func (t *proxySubsys) proxyToHost(ctx context.Context, ch ssh.Channel, remoteAddr net.Addr) error { + t.log.Debugf("proxy connecting to host=%v port=%v, exact port=%v", t.host, t.port, t.SpecifiedPort()) - // determine which server to connect to - server, err := t.getMatchingServer(nodeWatcher, strategy) - if err != nil { - return trace.Wrap(err) - } - - // Create a slice of principals that will be added into the host certificate. - // Here t.host is either an IP address or a DNS name as the user requested. 
- principals := []string{t.host} - - // Used to store the server ID (hostUUID.clusterName) of a Teleport node. - var serverID string - - // Resolve the IP address to dial to because the hostname may not be - // DNS resolvable. - var ( - serverAddr string - proxyIDs []string - ) - - if server != nil { - // Add hostUUID.clusterName to list of principals. - serverID = fmt.Sprintf("%v.%v", server.GetName(), t.clusterName) - principals = append(principals, serverID) - proxyIDs = server.GetProxyIDs() - - // Add IP address (if it exists) of the node to list of principals. - serverAddr = server.GetAddr() - if serverAddr != "" { - host, _, err := net.SplitHostPort(serverAddr) - if err != nil { - return trace.Wrap(err) - } - principals = append(principals, host) - } else if server.GetUseTunnel() { - serverAddr = reversetunnel.LocalNode - } - } else { - if !t.SpecifiedPort() { - t.port = strconv.Itoa(defaults.SSHServerListenPort) - } - serverAddr = net.JoinHostPort(t.host, t.port) - t.log.Warnf("server lookup failed: using default=%v", serverAddr) - } - - // Pass the agent along to the site. If the proxy is in recording mode, this - // agent is used to perform user authentication. Pass the DNS name to the - // dialer as well so the forwarding proxy can generate a host certificate - // with the correct hostname). 
- toAddr := &utils.NetAddr{ - AddrNetwork: "tcp", - Addr: serverAddr, - } - connectingToNode.Inc() - conn, err := site.Dial(reversetunnel.DialParams{ - From: remoteAddr, - To: toAddr, - GetUserAgent: t.ctx.StartAgentChannel, - Address: t.host, - ServerID: serverID, - ProxyIDs: proxyIDs, - Principals: principals, - ConnType: types.NodeTunnel, - }) + conn, err := t.router.DialHost(ctx, remoteAddr, t.host, t.port, t.clusterName, t.ctx.Identity.AccessChecker, t.ctx.StartAgentChannel) if err != nil { failedConnectingToNode.Inc() return trace.Wrap(err) @@ -442,126 +283,20 @@ func (t *proxySubsys) proxyToHost( t.doHandshake(ctx, remoteAddr, ch, conn) proxiedSessions.Inc() + go func() { - var err error - defer func() { - t.close(err) - }() - defer ch.Close() - _, err = io.Copy(ch, conn) - }() - go func() { - var err error - defer func() { - t.close(err) - }() - defer conn.Close() - _, err = io.Copy(conn, ch) + t.close(utils.ProxyConn(ctx, ch, conn)) }() - return nil } -// NodesGetter is a function that retrieves a subset of nodes matching -// the filter criteria. -type NodesGetter interface { - GetNodes(fn func(n services.Node) bool) []types.Server -} - -// getMatchingServer determines the server to connect to from the provided servers. Duplicate entries are treated -// differently based on strategy. Legacy behavior of returning an ambiguous error occurs if the strategy -// is types.RoutingStrategy_UNAMBIGUOUS_MATCH. When the strategy is types.RoutingStrategy_MOST_RECENT then -// the server that has heartbeated most recently will be returned instead of an error. If no matches are found then -// both the types.Server and error returned will be nil. -func (t *proxySubsys) getMatchingServer(watcher NodesGetter, strategy types.RoutingStrategy) (types.Server, error) { - if watcher == nil { - return nil, trace.NotFound("unable to retrieve nodes matching host %s", t.host) - } - - // check if hostname is a valid uuid or EC2 node ID. 
If it is, we will - // preferentially match by node ID over node hostname. - _, err := uuid.Parse(t.host) - hostIsUniqueID := err == nil || utils.IsEC2NodeID(t.host) - - ips, _ := net.LookupHost(t.host) - - var unambiguousIDMatch bool - // enumerate and try to find a server with self-registered with a matching name/IP: - matches := watcher.GetNodes(func(server services.Node) bool { - if unambiguousIDMatch { - return false - } - - // If the host parameter is a UUID or EC2 node ID, and it matches the - // Node ID, treat this as an unambiguous match. - if hostIsUniqueID && server.GetName() == t.host { - unambiguousIDMatch = true - return true - } - // If the server has connected over a reverse tunnel, match only on hostname. - if server.GetUseTunnel() { - return t.host == server.GetHostname() - } - - ip, port, err := net.SplitHostPort(server.GetAddr()) - if err != nil { - t.log.Errorf("Failed to parse address %q: %v.", server.GetAddr(), err) - return false - } - if t.host == ip || t.host == server.GetHostname() || apiutils.SliceContainsStr(ips, ip) { - if !t.SpecifiedPort() || t.port == port { - return true - } - } - return false - }) - - var server types.Server - switch { - case strategy == types.RoutingStrategy_MOST_RECENT: - // find the most recent of all the matches - for _, m := range matches { - if server == nil || m.Expiry().After(server.Expiry()) { - server = m - } - } - case len(matches) > 1: - // if we matched more than one server, then the target was ambiguous. - return nil, trace.NotFound(teleport.NodeIsAmbiguous) - case len(matches) == 1: - server = matches[0] - } - - // If we matched zero nodes but hostname is a UUID (or EC2 node ID) then it - // isn't sane to fallback to dns based resolution. This has the unfortunate - // consequence of preventing users from calling OpenSSH nodes which happen - // to use hostnames which are also valid UUIDs. 
This restriction is - // necessary in order to protect users attempting to connect to a node by - // UUID from being re-routed to an unintended target if the node is offline. - // This restriction can be lifted if we decide to move to explicit UUID - // based resolution in the future. - if hostIsUniqueID && server == nil { - idType := "UUID" - if utils.IsEC2NodeID(t.host) { - idType = "EC2" - } - return nil, trace.NotFound("unable to locate node matching %s-like target %s", idType, t.host) - } - - return server, nil -} - func (t *proxySubsys) close(err error) { - t.closeOnce.Do(func() { - proxiedSessions.Dec() - t.error = err - close(t.closeC) - }) + proxiedSessions.Dec() + t.closeC <- err } func (t *proxySubsys) Wait() error { - <-t.closeC - return t.error + return <-t.closeC } // doHandshake allows a proxy server to send additional information (client IP) diff --git a/lib/srv/regular/proxy_test.go b/lib/srv/regular/proxy_test.go index d6259bbb209cf..8fef4a19d7186 100644 --- a/lib/srv/regular/proxy_test.go +++ b/lib/srv/regular/proxy_test.go @@ -18,14 +18,10 @@ package regular import ( "testing" - "time" - "github.com/google/uuid" "github.com/stretchr/testify/require" apidefaults "github.com/gravitational/teleport/api/defaults" - "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" ) @@ -130,210 +126,3 @@ func TestParseBadRequests(t *testing.T) { }) } } - -type nodeGetter struct { - servers []types.Server -} - -func (n nodeGetter) GetNodes(fn func(n services.Node) bool) []types.Server { - var servers []types.Server - for _, s := range n.servers { - if fn(s) { - servers = append(servers, s) - } - } - - return servers -} - -func TestProxySubsys_getMatchingServer(t *testing.T) { - t.Parallel() - - serverUUID := uuid.NewString() - - ec2NodeID := "123456789012-i-abcdef12345678901" - - setExpiry := func(time time.Time) func(server types.Server) { - return func(server types.Server) { 
- server.SetExpiry(time) - } - } - - createServer := func(name string, spec types.ServerSpecV2, opts ...func(server types.Server)) types.Server { - t.Helper() - - server, err := types.NewServer(name, types.KindNode, spec) - require.NoError(t, err) - - for _, opt := range opts { - opt(server) - } - - return server - } - - servers := []types.Server{ - createServer(serverUUID, types.ServerSpecV2{ - Hostname: "127.0.0.1", - Addr: "127.0.0.1:80", - }, setExpiry(time.Now().Add(-time.Hour))), - createServer("server2", types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:80", - }, setExpiry(time.Now().Add(time.Hour*24))), - createServer("server3", types.ServerSpecV2{ - Hostname: serverUUID, - Addr: "127.0.0.1:", - }), - createServer(ec2NodeID, types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:", - }), - } - - cases := []struct { - desc string - req proxySubsysRequest - strategy types.RoutingStrategy - servers []types.Server - expectError require.ErrorAssertionFunc - expectServer func(servers []types.Server) types.Server - }{ - { - desc: "No matches found", - expectError: require.NoError, - }, - { - desc: "No matches found for UUID host", - expectError: require.Error, - servers: []types.Server{createServer(uuid.NewString(), types.ServerSpecV2{ - Addr: "127.0.0.1:0", - })}, - req: proxySubsysRequest{ - host: uuid.NewString(), - }, - }, - { - desc: "Match by UUID", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[0] - }, - servers: servers, - req: proxySubsysRequest{ - host: serverUUID, - }, - }, - { - desc: "Match by EC2 ID", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[3] - }, - servers: servers, - req: proxySubsysRequest{ - host: ec2NodeID, - }, - }, - { - desc: "Match Tunnel By Host Only", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[0] - }, - servers: 
[]types.Server{ - createServer("server1", types.ServerSpecV2{ - Addr: "127.0.0.1", - Hostname: "127.0.0.1", - UseTunnel: true, - }), - createServer("server2", types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:80", - UseTunnel: true, - }), - }, - req: proxySubsysRequest{ - host: "127.0.0.1", - port: "80", - }, - }, - { - desc: "Match by IP", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[1] - }, - servers: []types.Server{ - createServer("server1", types.ServerSpecV2{ - Addr: "127.0.0.1:0", - Hostname: "127.0.0.1", - }), - createServer("server2", types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:80", - }), - }, - req: proxySubsysRequest{ - host: "127.0.0.1", - port: "80", - }, - }, - { - desc: "Match by hostname", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[1] - }, - servers: []types.Server{ - createServer("server1", types.ServerSpecV2{ - Addr: "127.0.0.1:0", - Hostname: "localhost", - }), - createServer("server2", types.ServerSpecV2{ - Hostname: "localhost", - Addr: "127.0.0.1:80", - }), - }, - req: proxySubsysRequest{ - host: "localhost", - port: "80", - }, - }, - { - desc: "Ambiguous match", - expectError: require.Error, - servers: servers, - req: proxySubsysRequest{ - host: "localhost", - }, - }, - { - desc: "Most Recent match", - expectError: require.NoError, - expectServer: func(servers []types.Server) types.Server { - return servers[1] - }, - servers: servers, - strategy: types.RoutingStrategy_MOST_RECENT, - req: proxySubsysRequest{ - host: "localhost", - }, - }, - } - - for _, tt := range cases { - t.Run(tt.desc, func(t *testing.T) { - subsystem := proxySubsys{ - proxySubsysRequest: tt.req, - srv: &Server{}, - } - - server, err := subsystem.getMatchingServer(nodeGetter{tt.servers}, tt.strategy) - tt.expectError(t, err) - if tt.expectServer != nil { - require.Equal(t, tt.expectServer(tt.servers), server) - } 
- }) - } -} diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index cd399d7e1fd3e..f187cdf840dc6 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -30,6 +30,13 @@ import ( "strings" "sync" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + semconv "go.opentelemetry.io/otel/semconv/v1.10.0" + oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/crypto/ssh" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" @@ -45,6 +52,7 @@ import ( "github.com/gravitational/teleport/lib/labels" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/pam" + "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" @@ -55,14 +63,6 @@ import ( "github.com/gravitational/teleport/lib/sshutils/x11" "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" - - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" - semconv "go.opentelemetry.io/otel/semconv/v1.10.0" - oteltrace "go.opentelemetry.io/otel/trace" - "golang.org/x/crypto/ssh" ) const sftpSubsystem = "sftp" @@ -71,13 +71,6 @@ var ( log = logrus.WithFields(logrus.Fields{ trace.Component: teleport.ComponentNode, }) - - userSessionLimitHitCount = prometheus.NewCounter( - prometheus.CounterOpts{ - Name: teleport.MetricUserMaxConcurrentSessionsHit, - Help: "Number of times a user exceeded their max concurrent ssh connections", - }, - ) ) // Server implements SSH server that uses configuration backend and @@ -223,6 +216,14 @@ type Server struct { // tracerProvider is used to create tracers capable // of 
starting spans. tracerProvider oteltrace.TracerProvider + + // router used by subsystem requests to connect to nodes + // and clusters + router *proxy.Router + + // sessionController is used to restrict new sessions + // based on locks and cluster preferences + sessionController *srv.SessionController } // TargetMetadata returns metadata about the server. @@ -458,7 +459,7 @@ func SetSessionServer(sessionServer rsession.Service) ServerOption { } // SetProxyMode starts this server in SSH proxying mode -func SetProxyMode(peerAddr string, tsrv reversetunnel.Tunnel, ap auth.ReadProxyAccessPoint) ServerOption { +func SetProxyMode(peerAddr string, tsrv reversetunnel.Tunnel, ap auth.ReadProxyAccessPoint, router *proxy.Router) ServerOption { return func(s *Server) error { // always set proxy mode to true, // because in some tests reverse tunnel is disabled, @@ -467,6 +468,7 @@ func SetProxyMode(peerAddr string, tsrv reversetunnel.Tunnel, ap auth.ReadProxyA s.proxyTun = tsrv s.proxyAccessPoint = ap s.peerAddr = peerAddr + s.router = router return nil } } @@ -688,8 +690,18 @@ func SetTracerProvider(provider oteltrace.TracerProvider) ServerOption { } } +// SetSessionController sets the session controller. 
+func SetSessionController(controller *srv.SessionController) ServerOption { + return func(s *Server) error { + s.sessionController = controller + return nil + } +} + // New returns an unstarted server -func New(addr utils.NetAddr, +func New( + ctx context.Context, + addr utils.NetAddr, hostname string, signers []ssh.Signer, authService srv.AccessPoint, @@ -699,18 +711,13 @@ func New(addr utils.NetAddr, auth auth.ClientI, options ...ServerOption, ) (*Server, error) { - err := utils.RegisterPrometheusCollectors(userSessionLimitHitCount) - if err != nil { - return nil, trace.Wrap(err) - } - // read the host UUID: uuid, err := utils.ReadOrMakeHostUUID(dataDir) if err != nil { return nil, trace.Wrap(err) } - ctx, cancel := context.WithCancel(context.TODO()) + ctx, cancel := context.WithCancel(ctx) s := &Server{ addr: addr, authService: authService, @@ -753,6 +760,10 @@ func New(addr utils.NetAddr, return nil, trace.BadParameter("setup valid LockWatcher parameter using SetLockWatcher") } + if s.sessionController == nil { + return nil, trace.BadParameter("setup valid SessionController parameter using SetSessionController") + } + if s.connectedProxyGetter == nil { s.connectedProxyGetter = reversetunnel.NewConnectedProxyGetter() } @@ -1100,96 +1111,9 @@ func (s *Server) HandleNewConn(ctx context.Context, ccx *sshutils.ConnectionCont if err != nil { return ctx, trace.Wrap(err) } - authPref, err := s.GetAccessPoint().GetAuthPreference(ctx) - if err != nil { - return ctx, trace.Wrap(err) - } - lockingMode := identityContext.AccessChecker.LockingMode(authPref.GetLockingMode()) - event := &apievents.SessionReject{ - Metadata: apievents.Metadata{ - Type: events.SessionRejectedEvent, - Code: events.SessionRejectedCode, - }, - UserMetadata: identityContext.GetUserMetadata(), - ConnectionMetadata: apievents.ConnectionMetadata{ - Protocol: events.EventProtocolSSH, - LocalAddr: ccx.ServerConn.LocalAddr().String(), - RemoteAddr: ccx.ServerConn.RemoteAddr().String(), - }, - ServerMetadata:
apievents.ServerMetadata{ - ServerID: s.uuid, - ServerNamespace: s.GetNamespace(), - }, - } - - lockTargets, err := srv.ComputeLockTargets(s, identityContext) - if err != nil { - return ctx, trace.Wrap(err) - } - if lockErr := s.lockWatcher.CheckLockInForce(lockingMode, lockTargets...); lockErr != nil { - event.Reason = lockErr.Error() - if err := s.EmitAuditEvent(s.ctx, event); err != nil { - s.Logger.WithError(err).Warn("Failed to emit session reject event.") - } - return ctx, trace.Wrap(lockErr) - } - - // Don't apply the following checks in non-node contexts. - if s.Component() != teleport.ComponentNode { - return ctx, nil - } - - maxConnections := identityContext.AccessChecker.MaxConnections() - if maxConnections == 0 { - // concurrent session control is not active, nothing - // else needs to be done here. - return ctx, nil - } - - netConfig, err := s.GetAccessPoint().GetClusterNetworkingConfig(ctx) - if err != nil { - return ctx, trace.Wrap(err) - } - - semLock, err := services.AcquireSemaphoreLock(ctx, services.SemaphoreLockConfig{ - Service: s.authService, - Expiry: netConfig.GetSessionControlTimeout(), - Params: types.AcquireSemaphoreRequest{ - SemaphoreKind: types.SemaphoreKindConnection, - SemaphoreName: identityContext.TeleportUser, - MaxLeases: maxConnections, - Holder: s.uuid, - }, - }) - if err != nil { - if strings.Contains(err.Error(), teleport.MaxLeases) { - // user has exceeded their max concurrent ssh connections. - userSessionLimitHitCount.Inc() - event.Reason = events.SessionRejectedEvent - event.Maximum = maxConnections - if err := s.EmitAuditEvent(s.ctx, event); err != nil { - s.Logger.WithError(err).Warn("Failed to emit session reject event.") - } - err = trace.AccessDenied("too many concurrent ssh connections for user %q (max=%d)", - identityContext.TeleportUser, - maxConnections, - ) - } - return ctx, trace.Wrap(err) - } - - // ensure that losing the lock closes the connection context. 
Under normal - // conditions, cancellation propagates from the connection context to the - // lock, but if we lose the lock due to some error (e.g. poor connectivity - // to auth server) then cancellation propagates in the other direction. - go func() { - // TODO(fspmarshall): If lock was lost due to error, find a way to propagate - // an error message to user. - <-semLock.Done() - ccx.Close() - }() - return ctx, nil + ctx, err = s.sessionController.AcquireSessionContext(ctx, identityContext, ccx.ServerConn.LocalAddr().String(), ccx.ServerConn.RemoteAddr().String(), ccx) + return ctx, trace.Wrap(err) } // HandleNewChan is called when new channel is opened diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 702c9db064983..90a54c81e412f 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -59,7 +59,9 @@ import ( "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/limiter" + "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/pam" + libproxy "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" @@ -211,6 +213,18 @@ func newCustomFixture(t *testing.T, mutateCfg func(*auth.TestServerConfig), sshO require.NoError(t, err) t.Cleanup(func() { require.NoError(t, nodeClient.Close()) }) + lockWatcher := newLockWatcher(ctx, t, nodeClient) + + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: nodeClient, + AccessPoint: nodeClient, + LockEnforcer: lockWatcher, + Emitter: nodeClient, + Component: teleport.ComponentNode, + ServerID: nodeID, + }) + require.NoError(t, err) + nodeDir := t.TempDir() serverOptions := []ServerOption{ SetUUID(nodeID), @@ -231,13 +245,15 @@ 
func newCustomFixture(t *testing.T, mutateCfg func(*auth.TestServerConfig), sshO SetBPF(&bpf.NOP{}), SetRestrictedSessionManager(&restricted.NOP{}), SetClock(clock), - SetLockWatcher(newLockWatcher(ctx, t, nodeClient)), + SetLockWatcher(lockWatcher), SetX11ForwardingConfig(&x11.ServerConfig{}), + SetSessionController(sessionController), } serverOptions = append(serverOptions, sshOpts...) sshSrv, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, testServer.ClusterName(), []ssh.Signer{signer}, @@ -1398,7 +1414,27 @@ func TestProxyRoundRobin(t *testing.T) { require.NoError(t, reverseTunnelServer.Start()) defer reverseTunnelServer.Close() + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: f.testSrv.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: proxyClient, + SiteGetter: reverseTunnelServer, + TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: proxyClient, + AccessPoint: proxyClient, + LockEnforcer: lockWatcher, + Emitter: proxyClient, + Component: teleport.ComponentNode, + ServerID: hostID, + }) + require.NoError(t, err) + proxy, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), []ssh.Signer{f.signer}, @@ -1407,7 +1443,7 @@ func TestProxyRoundRobin(t *testing.T) { "", utils.NetAddr{}, proxyClient, - SetProxyMode("", reverseTunnelServer, proxyClient), + SetProxyMode("", reverseTunnelServer, proxyClient, router), SetSessionServer(proxyClient), SetEmitter(nodeClient), SetNamespace(apidefaults.Namespace), @@ -1417,6 +1453,7 @@ func TestProxyRoundRobin(t *testing.T) { SetClock(f.clock), SetLockWatcher(lockWatcher), SetNodeWatcher(nodeWatcher), + SetSessionController(sessionController), ) require.NoError(t, err) require.NoError(t, proxy.Start()) @@ -1519,7 +1556,27 @@ func TestProxyDirectAccess(t 
*testing.T) { nodeClient, _ := newNodeClient(t, f.testSrv) + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: f.testSrv.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: proxyClient, + SiteGetter: reverseTunnelServer, + TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: nodeClient, + AccessPoint: nodeClient, + LockEnforcer: lockWatcher, + Emitter: nodeClient, + Component: teleport.ComponentNode, + ServerID: hostID, + }) + require.NoError(t, err) + proxy, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), []ssh.Signer{f.signer}, @@ -1528,7 +1585,7 @@ func TestProxyDirectAccess(t *testing.T) { "", utils.NetAddr{}, proxyClient, - SetProxyMode("", reverseTunnelServer, proxyClient), + SetProxyMode("", reverseTunnelServer, proxyClient, router), SetSessionServer(proxyClient), SetEmitter(nodeClient), SetNamespace(apidefaults.Namespace), @@ -1538,6 +1595,7 @@ func TestProxyDirectAccess(t *testing.T) { SetClock(f.clock), SetLockWatcher(lockWatcher), SetNodeWatcher(nodeWatcher), + SetSessionController(sessionController), ) require.NoError(t, err) require.NoError(t, proxy.Start()) @@ -1688,8 +1746,22 @@ func TestLimiter(t *testing.T) { require.NoError(t, err) nodeClient, _ := newNodeClient(t, f.testSrv) + + lockWatcher := newLockWatcher(ctx, t, nodeClient) + + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: nodeClient, + AccessPoint: nodeClient, + LockEnforcer: lockWatcher, + Emitter: nodeClient, + Component: teleport.ComponentNode, + ServerID: hostID, + }) + require.NoError(t, err) + nodeStateDir := t.TempDir() srv, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, f.testSrv.ClusterName(), []ssh.Signer{f.signer}, @@ -1707,7 +1779,8 @@ func TestLimiter(t 
*testing.T) { SetBPF(&bpf.NOP{}), SetRestrictedSessionManager(&restricted.NOP{}), SetClock(f.clock), - SetLockWatcher(newLockWatcher(ctx, t, nodeClient)), + SetLockWatcher(lockWatcher), + SetSessionController(sessionController), ) require.NoError(t, err) require.NoError(t, srv.Start()) @@ -2251,7 +2324,27 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { nodeClient, _ := newNodeClient(t, f.testSrv) + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: f.testSrv.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: proxyClient, + SiteGetter: reverseTunnelServer, + TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: proxyClient, + AccessPoint: proxyClient, + LockEnforcer: lockWatcher, + Emitter: proxyClient, + Component: teleport.ComponentProxy, + ServerID: hostID, + }) + require.NoError(t, err) + proxy, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), []ssh.Signer{f.signer}, @@ -2260,7 +2353,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { "", utils.NetAddr{}, proxyClient, - SetProxyMode("", reverseTunnelServer, proxyClient), + SetProxyMode("", reverseTunnelServer, proxyClient, router), SetSessionServer(proxyClient), SetEmitter(nodeClient), SetNamespace(apidefaults.Namespace), @@ -2270,6 +2363,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { SetClock(f.clock), SetLockWatcher(lockWatcher), SetNodeWatcher(nodeWatcher), + SetSessionController(sessionController), ) require.NoError(t, err) require.NoError(t, proxy.Start()) diff --git a/lib/srv/session_control.go b/lib/srv/session_control.go new file mode 100644 index 0000000000000..0033aa074a7f3 --- /dev/null +++ b/lib/srv/session_control.go @@ -0,0 +1,259 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); 
+// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package srv + +import ( + "context" + "io" + "strings" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + oteltrace "go.opentelemetry.io/otel/trace" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/observability/tracing" + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" +) + +var ( + userSessionLimitHitCount = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: teleport.MetricUserMaxConcurrentSessionsHit, + Help: "Number of times a user exceeded their max concurrent ssh connections", + }, + ) +) + +func init() { + _ = utils.RegisterPrometheusCollectors(userSessionLimitHitCount) +} + +// LockEnforcer determines whether a lock is being enforced on the provided targets +type LockEnforcer interface { + CheckLockInForce(mode constants.LockingMode, targets ...types.LockTarget) error +} + +// SessionControllerConfig contains dependencies needed to +// create a SessionController +type SessionControllerConfig struct { + // Semaphores is used to obtain a semaphore lock when max sessions are defined + Semaphores types.Semaphores 
+ // AccessPoint is the cache used to get cluster information + AccessPoint AccessPoint + // LockEnforcer is used to determine if locks should prevent a session + LockEnforcer LockEnforcer + // Emitter is used to emit session rejection events + Emitter apievents.Emitter + // Component is the component running the session controller. Nodes and Proxies + have different flows + Component string + // Logger is used to emit log entries + Logger *logrus.Entry + // TracerProvider creates a tracer so that spans may be emitted + TracerProvider oteltrace.TracerProvider + // ServerID is the UUID of the server + ServerID string + // Clock used in tests to change time + Clock clockwork.Clock + + tracer oteltrace.Tracer +} + +// CheckAndSetDefaults ensures all the required dependencies were +// provided and sets any optional values to their defaults +func (c *SessionControllerConfig) CheckAndSetDefaults() error { + if c.Semaphores == nil { + return trace.BadParameter("Semaphores must be provided") + } + + if c.AccessPoint == nil { + return trace.BadParameter("AccessPoint must be provided") + } + + if c.LockEnforcer == nil { + return trace.BadParameter("LockEnforcer must be provided") + } + + if c.Emitter == nil { + return trace.BadParameter("Emitter must be provided") + } + + if c.Component == "" { + return trace.BadParameter("Component must be provided") + } + + if c.TracerProvider == nil { + c.TracerProvider = tracing.DefaultProvider() + } + + if c.Logger == nil { + c.Logger = logrus.WithField(trace.Component, "SessionCtrl") + } + + if c.Clock == nil { + c.Clock = clockwork.NewRealClock() + } + + c.tracer = c.TracerProvider.Tracer("SessionController") + + return nil +} + +// SessionController enforces session control restrictions required by +// locks, private key policy, and max connection limits +type SessionController struct { + cfg SessionControllerConfig +} + +// NewSessionController creates a SessionController from the provided config.
If any +// of the required parameters in the SessionControllerConfig are not provided an +// error is returned. +func NewSessionController(cfg SessionControllerConfig) (*SessionController, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + return &SessionController{cfg: cfg}, nil +} + +// AcquireSessionContext attempts to create a context for the session. If the session is +// not allowed due to session control an error is returned. The returned +// context is scoped to the session and will be canceled in the event the semaphore lock +// is no longer held. The closers provided are immediately closed when the semaphore lock +// is released as well. +func (s *SessionController) AcquireSessionContext(ctx context.Context, identity IdentityContext, localAddr, remoteAddr string, closers ...io.Closer) (context.Context, error) { + // create a separate context for tracing the operations + // within that doesn't leak into the returned context + spanCtx, span := s.cfg.tracer.Start(ctx, "SessionController/AcquireSessionContext") + defer span.End() + + authPref, err := s.cfg.AccessPoint.GetAuthPreference(spanCtx) + if err != nil { + return ctx, trace.Wrap(err) + } + + clusterName, err := s.cfg.AccessPoint.GetClusterName() + if err != nil { + return ctx, trace.Wrap(err) + } + + lockingMode := identity.AccessChecker.LockingMode(authPref.GetLockingMode()) + lockTargets := ComputeLockTargets(clusterName.GetClusterName(), s.cfg.ServerID, identity) + + if lockErr := s.cfg.LockEnforcer.CheckLockInForce(lockingMode, lockTargets...); lockErr != nil { + s.emitRejection(spanCtx, identity.GetUserMetadata(), localAddr, remoteAddr, lockErr.Error(), 0) + return ctx, trace.Wrap(lockErr) + } + + // Don't apply the following checks in non-node contexts. 
+ if s.cfg.Component != teleport.ComponentNode { + return ctx, nil + } + + maxConnections := identity.AccessChecker.MaxConnections() + if maxConnections == 0 { + // concurrent session control is not active, nothing + // else needs to be done here. + return ctx, nil + } + + netConfig, err := s.cfg.AccessPoint.GetClusterNetworkingConfig(spanCtx) + if err != nil { + return ctx, trace.Wrap(err) + } + + semLock, err := services.AcquireSemaphoreLock(spanCtx, services.SemaphoreLockConfig{ + Service: s.cfg.Semaphores, + Clock: s.cfg.Clock, + Expiry: netConfig.GetSessionControlTimeout(), + Params: types.AcquireSemaphoreRequest{ + SemaphoreKind: types.SemaphoreKindConnection, + SemaphoreName: identity.TeleportUser, + MaxLeases: maxConnections, + Holder: s.cfg.ServerID, + }, + }) + if err != nil { + if strings.Contains(err.Error(), teleport.MaxLeases) { + // user has exceeded their max concurrent ssh connections. + userSessionLimitHitCount.Inc() + s.emitRejection(spanCtx, identity.GetUserMetadata(), localAddr, remoteAddr, events.SessionRejectedEvent, maxConnections) + + return ctx, trace.AccessDenied("too many concurrent ssh connections for user %q (max=%d)", identity.TeleportUser, maxConnections) + } + + return ctx, trace.Wrap(err) + } + + ctx, cancel := context.WithCancel(ctx) + // ensure that losing the lock closes the connection context. Under normal + // conditions, cancellation propagates from the connection context to the + // lock, but if we lose the lock due to some error (e.g. poor connectivity + // to auth server) then cancellation propagates in the other direction. + go func() { + // TODO(fspmarshall): If lock was lost due to error, find a way to propagate + // an error message to user. 
+ <-semLock.Done() + cancel() + + // close any provided closers + for _, closer := range closers { + _ = closer.Close() + } + }() + + return ctx, nil +} + +// emitRejection emits a SessionRejectedEvent with the provided information +func (s *SessionController) emitRejection(ctx context.Context, userMetadata apievents.UserMetadata, localAddr, remoteAddr string, reason string, max int64) { + // link a background context to the current span so things + // are related but while still allowing the audit event to + // not be tied to the request scoped context + emitCtx := oteltrace.ContextWithSpanContext(context.Background(), oteltrace.SpanContextFromContext(ctx)) + + ctx, span := s.cfg.tracer.Start(emitCtx, "SessionController/emitRejection") + defer span.End() + + if err := s.cfg.Emitter.EmitAuditEvent(ctx, &apievents.SessionReject{ + Metadata: apievents.Metadata{ + Type: events.SessionRejectedEvent, + Code: events.SessionRejectedCode, + }, + UserMetadata: userMetadata, + ConnectionMetadata: apievents.ConnectionMetadata{ + Protocol: events.EventProtocolSSH, + LocalAddr: localAddr, + RemoteAddr: remoteAddr, + }, + ServerMetadata: apievents.ServerMetadata{ + ServerID: s.cfg.ServerID, + ServerNamespace: apidefaults.Namespace, + }, + Reason: reason, + Maximum: max, + }); err != nil { + s.cfg.Logger.WithError(err).Warn("Failed to emit session reject event.") + } +} diff --git a/lib/srv/session_control_test.go b/lib/srv/session_control_test.go new file mode 100644 index 0000000000000..657bdafe3ce9c --- /dev/null +++ b/lib/srv/session_control_test.go @@ -0,0 +1,348 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package srv + +import ( + "context" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/events/eventstest" + "github.com/gravitational/teleport/lib/services" +) + +type mockLockEnforcer struct { + lockInForceErr error +} + +func (m mockLockEnforcer) CheckLockInForce(constants.LockingMode, ...types.LockTarget) error { + return m.lockInForceErr +} + +type mockAccessPoint struct { + AccessPoint + + authPreference types.AuthPreference + clusterName types.ClusterName + netConfig types.ClusterNetworkingConfig +} + +func (m mockAccessPoint) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) { + return m.authPreference, nil +} + +func (m mockAccessPoint) GetClusterName(opts ...services.MarshalOption) (types.ClusterName, error) { + return m.clusterName, nil +} + +func (m mockAccessPoint) GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error) { + return m.netConfig, nil +} + +type mockSemaphores struct { + types.Semaphores + + lease *types.SemaphoreLease + acquireErr error +} + +func (m mockSemaphores) AcquireSemaphore(ctx context.Context, params types.AcquireSemaphoreRequest) 
(*types.SemaphoreLease, error) { + return m.lease, m.acquireErr +} + +func (m mockSemaphores) CancelSemaphoreLease(ctx context.Context, lease types.SemaphoreLease) error { + return nil +} + +type mockAccessChecker struct { + services.AccessChecker + + lockMode constants.LockingMode + maxConnections int64 + roleNames []string +} + +func (m mockAccessChecker) LockingMode(defaultMode constants.LockingMode) constants.LockingMode { + return m.lockMode +} + +func (m mockAccessChecker) MaxConnections() int64 { + return m.maxConnections +} + +func (m mockAccessChecker) RoleNames() []string { + return m.roleNames +} + +func TestSessionController_AcquireSessionContext(t *testing.T) { + t.Parallel() + + clock := clockwork.NewFakeClock() + emitter := &eventstest.MockEmitter{} + + cases := []struct { + name string + cfg SessionControllerConfig + identity IdentityContext + assertion func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) + }{ + { + name: "proxy: access allowed", + cfg: SessionControllerConfig{ + Semaphores: mockSemaphores{}, + AccessPoint: mockAccessPoint{ + netConfig: &types.ClusterNetworkingConfigV2{}, + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + }, + LockEnforcer: mockLockEnforcer{}, + Emitter: emitter, + Component: teleport.ComponentProxy, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{}, + }, + }, + AccessChecker: mockAccessChecker{ + maxConnections: 1, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.NoError(t, err) + require.NotNil(t, ctx) + require.Empty(t, emitter.Events()) + }, + }, + { + name: "node: access allowed", + cfg: 
SessionControllerConfig{ + Clock: clock, + Semaphores: mockSemaphores{ + lease: &types.SemaphoreLease{ + SemaphoreKind: types.SemaphoreKindConnection, + SemaphoreName: "test", + LeaseID: "1", + Expires: clock.Now().Add(time.Minute), + }, + }, + AccessPoint: mockAccessPoint{ + netConfig: &types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + SessionControlTimeout: types.NewDuration(time.Minute), + }, + }, + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + }, + LockEnforcer: mockLockEnforcer{}, + Emitter: emitter, + Component: teleport.ComponentNode, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{}, + }, + }, + AccessChecker: mockAccessChecker{ + maxConnections: 1, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.NoError(t, err) + require.NotNil(t, ctx) + require.Empty(t, emitter.Events()) + }, + }, + { + name: "session rejected due to lock", + cfg: SessionControllerConfig{ + Clock: clock, + Semaphores: mockSemaphores{}, + AccessPoint: mockAccessPoint{ + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + }, + LockEnforcer: mockLockEnforcer{ + lockInForceErr: trace.AccessDenied("lock in force"), + }, + Emitter: emitter, + Component: teleport.ComponentNode, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{}, + }, + }, + AccessChecker: 
mockAccessChecker{ + maxConnections: 1, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.ErrorIs(t, err, trace.AccessDenied("lock in force")) + require.NotNil(t, ctx) + require.Len(t, emitter.Events(), 1) + + evt, ok := emitter.Events()[0].(*apievents.SessionReject) + require.True(t, ok) + require.Equal(t, events.SessionRejectedEvent, evt.Metadata.Type) + require.Equal(t, events.SessionRejectedCode, evt.Metadata.Code) + require.Equal(t, events.EventProtocolSSH, evt.ConnectionMetadata.Protocol) + require.Equal(t, "lock in force", evt.Reason) + }, + }, + { + name: "session rejected due to connection limit", + cfg: SessionControllerConfig{ + Clock: clock, + Semaphores: mockSemaphores{ + acquireErr: trace.LimitExceeded(teleport.MaxLeases), + }, + AccessPoint: mockAccessPoint{ + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + netConfig: &types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + SessionControlTimeout: types.NewDuration(time.Minute), + }, + }, + }, + LockEnforcer: mockLockEnforcer{}, + Emitter: emitter, + Component: teleport.ComponentNode, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{}, + }, + }, + AccessChecker: mockAccessChecker{ + maxConnections: 1, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + require.NotNil(t, ctx) + require.Len(t, emitter.Events(), 1) + + evt, ok := emitter.Events()[0].(*apievents.SessionReject) + require.True(t, ok) + require.Equal(t, events.SessionRejectedEvent, evt.Metadata.Type) + 
require.Equal(t, events.SessionRejectedCode, evt.Metadata.Code) + require.Equal(t, events.EventProtocolSSH, evt.ConnectionMetadata.Protocol) + require.Equal(t, events.SessionRejectedEvent, evt.Reason) + require.Equal(t, int64(1), evt.Maximum) + }, + }, + { + name: "no connection limits prevent acquiring semaphore lock", + cfg: SessionControllerConfig{ + Clock: clock, + Semaphores: mockSemaphores{ + acquireErr: trace.LimitExceeded(teleport.MaxLeases), + }, + AccessPoint: mockAccessPoint{ + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + netConfig: &types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + SessionControlTimeout: types.NewDuration(time.Minute), + }, + }, + }, + LockEnforcer: mockLockEnforcer{}, + Emitter: emitter, + Component: teleport.ComponentNode, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{}, + }, + }, + AccessChecker: mockAccessChecker{ + maxConnections: 0, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.NoError(t, err) + require.NotNil(t, ctx) + require.Empty(t, emitter.Events(), 0) + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + emitter.Reset() + ctrl, err := NewSessionController(tt.cfg) + require.NoError(t, err) + + ctx, err := ctrl.AcquireSessionContext(context.Background(), tt.identity, "127.0.0.1:1", "127.0.0.1:2") + tt.assertion(t, ctx, err, emitter) + + }) + } +} diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 64a0cb79eb3e3..8f7123f8b7bd3 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -52,9 +52,25 @@ import ( "github.com/google/go-cmp/cmp" 
"github.com/google/uuid" "github.com/gorilla/websocket" - "github.com/gravitational/roundtrip" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/julienschmidt/httprouter" + lemma_secret "github.com/mailgun/lemma/secret" + "github.com/pquerna/otp/totp" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + "golang.org/x/exp/slices" + "golang.org/x/text/encoding/unicode" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/serializer" + kyaml "k8s.io/apimachinery/pkg/util/yaml" + authztypes "k8s.io/client-go/kubernetes/typed/authorization/v1" + "k8s.io/client-go/tools/clientcmd" + clientcmdapi "k8s.io/client-go/tools/clientcmd/api" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/breaker" @@ -82,7 +98,9 @@ import ( kubeproxy "github.com/gravitational/teleport/lib/kube/proxy" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/pam" + libproxy "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/secret" @@ -96,24 +114,6 @@ import ( "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/ui" - - "github.com/jonboulle/clockwork" - "github.com/julienschmidt/httprouter" - lemma_secret "github.com/mailgun/lemma/secret" - "github.com/pquerna/otp/totp" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" - "golang.org/x/crypto/ssh" - "golang.org/x/exp/slices" - "golang.org/x/text/encoding/unicode" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - 
"k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/runtime/serializer" - kyaml "k8s.io/apimachinery/pkg/util/yaml" - authztypes "k8s.io/client-go/kubernetes/typed/authorization/v1" - "k8s.io/client-go/tools/clientcmd" - clientcmdapi "k8s.io/client-go/tools/clientcmd/api" ) const hostID = "00000000-0000-0000-0000-000000000000" @@ -238,9 +238,20 @@ func newWebSuite(t *testing.T) *WebSuite { }) require.NoError(t, err) + nodeSessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: nodeClient, + AccessPoint: nodeClient, + LockEnforcer: nodeLockWatcher, + Emitter: nodeClient, + Component: teleport.ComponentNode, + ServerID: nodeID, + }) + require.NoError(t, err) + // create SSH service: nodeDataDir := t.TempDir() node, err := regular.New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, s.server.ClusterName(), []ssh.Signer{signer}, @@ -259,6 +270,7 @@ func newWebSuite(t *testing.T) *WebSuite { regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(s.clock), regular.SetLockWatcher(nodeLockWatcher), + regular.SetSessionController(nodeSessionController), ) require.NoError(t, err) s.node = node @@ -325,8 +337,28 @@ func newWebSuite(t *testing.T) *WebSuite { require.NoError(t, err) s.proxyTunnel = revTunServer + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: s.server.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: s.proxyClient, + SiteGetter: revTunServer, + TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + + proxySessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: s.proxyClient, + AccessPoint: s.proxyClient, + LockEnforcer: proxyLockWatcher, + Emitter: s.proxyClient, + Component: teleport.ComponentProxy, + ServerID: proxyID, + }) + require.NoError(t, err) + // proxy server: s.proxy, err = regular.New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: 
"127.0.0.1:0"}, s.server.ClusterName(), []ssh.Signer{signer}, @@ -336,7 +368,7 @@ func newWebSuite(t *testing.T) *WebSuite { utils.NetAddr{}, s.proxyClient, regular.SetUUID(proxyID), - regular.SetProxyMode("", revTunServer, s.proxyClient), + regular.SetProxyMode("", revTunServer, s.proxyClient, router), regular.SetSessionServer(s.proxyClient), regular.SetEmitter(s.proxyClient), regular.SetNamespace(apidefaults.Namespace), @@ -345,6 +377,7 @@ func newWebSuite(t *testing.T) *WebSuite { regular.SetClock(s.clock), regular.SetLockWatcher(proxyLockWatcher), regular.SetNodeWatcher(proxyNodeWatcher), + regular.SetSessionController(proxySessionController), ) require.NoError(t, err) @@ -4984,9 +5017,20 @@ func newWebPack(t *testing.T, numProxies int) *webPack { require.NoError(t, err) t.Cleanup(nodeLockWatcher.Close) + nodeSessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: nodeClient, + AccessPoint: nodeClient, + LockEnforcer: nodeLockWatcher, + Emitter: nodeClient, + Component: teleport.ComponentNode, + ServerID: nodeID, + }) + require.NoError(t, err) + // create SSH service: nodeDataDir := t.TempDir() node, err := regular.New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, server.TLS.ClusterName(), hostSigners, @@ -5005,6 +5049,7 @@ func newWebPack(t *testing.T, numProxies int) *webPack { regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(clock), regular.SetLockWatcher(nodeLockWatcher), + regular.SetSessionController(nodeSessionController), ) require.NoError(t, err) @@ -5103,7 +5148,27 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula require.NoError(t, err) t.Cleanup(func() { require.NoError(t, revTunServer.Close()) }) + router, err := libproxy.NewRouter(libproxy.RouterConfig{ + ClusterName: authServer.ClusterName(), + Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), + RemoteClusterGetter: client, + SiteGetter: revTunServer, + 
TracerProvider: tracing.NoopProvider(), + }) + require.NoError(t, err) + + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: client, + AccessPoint: client, + LockEnforcer: proxyLockWatcher, + Emitter: client, + Component: teleport.ComponentProxy, + ServerID: proxyID, + }) + require.NoError(t, err) + proxyServer, err := regular.New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, authServer.ClusterName(), hostSigners, @@ -5113,7 +5178,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula utils.NetAddr{}, client, regular.SetUUID(proxyID), - regular.SetProxyMode("", revTunServer, client), + regular.SetProxyMode("", revTunServer, client, router), regular.SetSessionServer(client), regular.SetEmitter(client), regular.SetNamespace(apidefaults.Namespace), @@ -5122,6 +5187,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula regular.SetClock(clock), regular.SetLockWatcher(proxyLockWatcher), regular.SetNodeWatcher(proxyNodeWatcher), + regular.SetSessionController(sessionController), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, proxyServer.Close()) })