From 6552a108880f139f260ff32491f8dfdfd585e91d Mon Sep 17 00:00:00 2001
From: Tim Ross
Date: Thu, 17 Nov 2022 11:56:25 -0500
Subject: [PATCH] Make session control logic reusable

Session control logic existed within `HandleNewConn` of
`srv/regular.Server`. This prevented any of it from being used by other
components that also needed to enforce session control. All the logic
from within `HandleNewConn` was refactored into a new
`srv.SessionController` object, which the `regular.Server` now uses to
perform session control.

A few additional changes were needed to accommodate session control now
living outside the server and to make tests easier to write: namely,
`srv.ComputeLockTargets` no longer takes a `Server` as a parameter, and
`services.SemaphoreLock` now leverages a clock.

This is step 2 in addressing #15167. Before the web apiserver can
leverage the newly introduced proxy.Router and bypass making ssh
connections to the proxy ssh server, it needs to be able to perform
session control.
---
 integration/utmp_integration_test.go |  13 +
 lib/service/service.go               |  52 +++-
 lib/services/semaphore.go            |  20 +-
 lib/srv/ctx.go                       |  31 +-
 lib/srv/regular/sshserver.go         | 136 ++-------
 lib/srv/regular/sshserver_test.go    |  69 ++++-
 lib/srv/session_control.go           | 268 ++++++++++++++++++
 lib/srv/session_control_test.go      | 408 +++++++++++++++++++++++++++
 lib/web/apiserver_test.go            |  48 ++++
 9 files changed, 902 insertions(+), 143 deletions(-)
 create mode 100644 lib/srv/session_control.go
 create mode 100644 lib/srv/session_control_test.go

diff --git a/integration/utmp_integration_test.go b/integration/utmp_integration_test.go
index 0162a12508dee..61cd3c1956532 100644
--- a/integration/utmp_integration_test.go
+++ b/integration/utmp_integration_test.go
@@ -40,6 +40,7 @@ import (
 	"github.com/gravitational/teleport/lib/pam"
 	restricted "github.com/gravitational/teleport/lib/restrictedsession"
 	"github.com/gravitational/teleport/lib/services"
+	"github.com/gravitational/teleport/lib/srv"
 	"github.com/gravitational/teleport/lib/srv/regular"
 	"github.com/gravitational/teleport/lib/srv/uacc"
 	"github.com/gravitational/teleport/lib/sshutils"
@@ -257,8 +258,19 @@ func newSrvCtx(ctx context.Context, t *testing.T) *SrvCtx {
 	require.NoError(t, err)
 	t.Cleanup(lockWatcher.Close)

+	nodeSessionController, err := srv.NewSessionController(srv.SessionControllerConfig{
+		Semaphores:   s.nodeClient,
+		AccessPoint:  s.nodeClient,
+		LockEnforcer: lockWatcher,
+		Emitter:      s.nodeClient,
+		Component:    teleport.ComponentNode,
+		ServerID:     s.nodeID,
+	})
+	require.NoError(t, err)
+
 	nodeDir := t.TempDir()
 	srv, err := regular.New(
+		ctx,
 		utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"},
 		s.server.ClusterName(),
 		[]ssh.Signer{s.signer},
@@ -286,6 +298,7 @@ func newSrvCtx(ctx context.Context, t *testing.T) *SrvCtx {
 		regular.SetClock(s.clock),
 		regular.SetUtmpPath(utmpPath, utmpPath),
 		regular.SetLockWatcher(lockWatcher),
+		regular.SetSessionController(nodeSessionController),
 	)
 	require.NoError(t, err)
 	s.srv = srv
diff --git a/lib/service/service.go b/lib/service/service.go
index c3e06dbd3aa40..b79acd34d949a 100644
--- a/lib/service/service.go
+++ b/lib/service/service.go
@@ -2288,7 +2288,29 @@ func (process *TeleportProcess) initSSH() error {

 	storagePresence := local.NewPresenceService(process.storage.BackendStorage)

-	s, err := regular.New(cfg.SSH.Addr,
+	// read the host UUID:
+	serverID, err := utils.ReadOrMakeHostUUID(cfg.DataDir)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{
+		Semaphores:     authClient,
+		AccessPoint:    authClient,
+		LockEnforcer:   lockWatcher,
+		Emitter:        &events.StreamerAndEmitter{Emitter: asyncEmitter, Streamer: streamer},
+		Component:      teleport.ComponentNode,
+		Logger:         process.log.WithField(trace.Component, "sessionctrl"),
+		TracerProvider: process.TracingProvider,
+		ServerID:       serverID,
+	})
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	s, err := regular.New(
+		process.ExitContext(),
+		cfg.SSH.Addr,
 		cfg.Hostname,
 		[]ssh.Signer{conn.ServerIdentity.KeySigner},
 		authClient,
@@ -2320,6 +2342,8 @@ func (process *TeleportProcess) initSSH() error {
 		regular.SetCreateHostUser(!cfg.SSH.DisableCreateHostUser),
 		regular.SetStoragePresenceService(storagePresence),
 		regular.SetInventoryControlHandle(process.inventoryHandle),
+		regular.SetTracerProvider(process.TracingProvider),
+		regular.SetSessionController(sessionController),
 	)
 	if err != nil {
 		return trace.Wrap(err)
@@ -3627,7 +3651,29 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
 		proxyRouter = router
 	}

-	sshProxy, err := regular.New(cfg.Proxy.SSHAddr,
+	// read the host UUID:
+	serverID, err := utils.ReadOrMakeHostUUID(cfg.DataDir)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{
+		Semaphores:     accessPoint,
+		AccessPoint:    accessPoint,
+		LockEnforcer:   lockWatcher,
+		Emitter:        &events.StreamerAndEmitter{Emitter: asyncEmitter, Streamer: streamer},
+		Component:      teleport.ComponentProxy,
+		Logger:         process.log.WithField(trace.Component, "sessionctrl"),
+		TracerProvider: process.TracingProvider,
+		ServerID:       serverID,
+	})
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	sshProxy, err := regular.New(
+		process.ExitContext(),
+		cfg.Proxy.SSHAddr,
 		cfg.Hostname,
 		[]ssh.Signer{conn.ServerIdentity.KeySigner},
 		accessPoint,
@@ -3651,6 +3697,8 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
 		// accurately checked later when an SCP/SFTP request hits the
 		// destination Node.
 		regular.SetAllowFileCopying(true),
+		regular.SetTracerProvider(process.TracingProvider),
+		regular.SetSessionController(sessionController),
 	)
 	if err != nil {
 		return trace.Wrap(err)
diff --git a/lib/services/semaphore.go b/lib/services/semaphore.go
index cc72f67930e89..78df632e1888b 100644
--- a/lib/services/semaphore.go
+++ b/lib/services/semaphore.go
@@ -20,6 +20,7 @@ import (
 	"time"

 	"github.com/gravitational/trace"
+	"github.com/jonboulle/clockwork"
 	log "github.com/sirupsen/logrus"

 	"github.com/gravitational/teleport/api/types"
@@ -39,10 +40,16 @@ type SemaphoreLockConfig struct {
 	TickRate time.Duration
 	// Params holds the semaphore lease acquisition parameters.
Params types.AcquireSemaphoreRequest + // Clock used to alter time in tests + Clock clockwork.Clock } // CheckAndSetDefaults checks and sets default parameters func (l *SemaphoreLockConfig) CheckAndSetDefaults() error { + if l.Clock == nil { + l.Clock = clockwork.NewRealClock() + } + if l.Service == nil { return trace.BadParameter("missing semaphore service") } @@ -59,7 +66,7 @@ func (l *SemaphoreLockConfig) CheckAndSetDefaults() error { return trace.BadParameter("tick-rate must be less than expiry") } if l.Params.Expires.IsZero() { - l.Params.Expires = time.Now().UTC().Add(l.Expiry) + l.Params.Expires = l.Clock.Now().UTC().Add(l.Expiry) } if err := l.Params.Check(); err != nil { return trace.Wrap(err) @@ -73,7 +80,7 @@ type SemaphoreLock struct { cfg SemaphoreLockConfig lease0 types.SemaphoreLease retry retryutils.Retry - ticker *time.Ticker + ticker clockwork.Ticker doneC chan struct{} closeOnce sync.Once renewalC chan struct{} @@ -140,7 +147,7 @@ func (l *SemaphoreLock) keepAlive(ctx context.Context) { // cancellation/expiry. return } - if lease.Expires.After(time.Now().UTC()) { + if lease.Expires.After(l.cfg.Clock.Now().UTC()) { // parent context is closed. create orphan context with generous // timeout for lease cancellation scope. this will not block any // caller that is not explicitly waiting on the final error value. @@ -157,7 +164,7 @@ func (l *SemaphoreLock) keepAlive(ctx context.Context) { Outer: for { select { - case tick := <-l.ticker.C: + case tick := <-l.ticker.Chan(): leaseContext, leaseCancel := context.WithDeadline(ctx, lease.Expires) nextLease := lease nextLease.Expires = tick.Add(l.cfg.Expiry) @@ -185,7 +192,7 @@ Outer: l.retry.Inc() select { case <-l.retry.After(): - case tick = <-l.ticker.C: + case tick = <-l.ticker.Chan(): // check to make sure that we still have some time on the lease. the default tick rate would have // us waking _as_ the lease expires here, but if we're working with a higher tick rate, its worth // retrying again. 
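The `Clock` threaded through `SemaphoreLockConfig` above is what makes lease renewal testable without real-time waiting. A minimal sketch of the pattern this enables, using a clockwork fake clock; the helper name and `semSvc` parameter are illustrative assumptions, not part of this patch:

    package services_test

    import (
    	"context"
    	"testing"
    	"time"

    	"github.com/jonboulle/clockwork"

    	"github.com/gravitational/teleport/api/types"
    	"github.com/gravitational/teleport/lib/services"
    )

    // acquireWithFakeClock sketches deterministic renewal testing; semSvc is
    // any types.Semaphores implementation (e.g. an auth client).
    func acquireWithFakeClock(t *testing.T, semSvc types.Semaphores) {
    	clock := clockwork.NewFakeClock()
    	lock, err := services.AcquireSemaphoreLock(context.Background(), services.SemaphoreLockConfig{
    		Service: semSvc,
    		Clock:   clock,
    		Expiry:  time.Minute,
    		Params: types.AcquireSemaphoreRequest{
    			SemaphoreKind: types.SemaphoreKindConnection,
    			SemaphoreName: "alice",
    			MaxLeases:     1,
    			Holder:        "node-1",
    		},
    	})
    	if err != nil {
    		t.Fatal(err)
    	}

    	// The keepalive ticker now comes from the injected clock, so
    	// advancing it drives renewal attempts without real-time waiting.
    	clock.Advance(30 * time.Second)

    	// Done() closes only if the lease is lost or canceled.
    	select {
    	case <-lock.Done():
    		t.Fatal("lease unexpectedly lost")
    	default:
    	}
    }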
@@ -247,6 +254,7 @@ func AcquireSemaphoreLock(ctx context.Context, cfg SemaphoreLockConfig) (*Semaph Max: cfg.Expiry / 4, Step: cfg.Expiry / 16, Jitter: retryutils.NewJitter(), + Clock: cfg.Clock, }) if err != nil { return nil, trace.Wrap(err) @@ -259,7 +267,7 @@ func AcquireSemaphoreLock(ctx context.Context, cfg SemaphoreLockConfig) (*Semaph cfg: cfg, lease0: *lease, retry: retry, - ticker: time.NewTicker(cfg.TickRate), + ticker: cfg.Clock.NewTicker(cfg.TickRate), doneC: make(chan struct{}), renewalC: make(chan struct{}), cond: sync.NewCond(&sync.Mutex{}), diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 6a8c2d9cd948d..18e794204f5b9 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -435,13 +435,15 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s trace.ComponentFields: fields, }) - lockTargets, err := ComputeLockTargets(srv, identityContext) + clusterName, err := srv.GetAccessPoint().GetClusterName() if err != nil { - return nil, nil, trace.Wrap(err) + childErr := child.Close() + return nil, nil, trace.NewAggregate(err, childErr) } + monitorConfig := MonitorConfig{ LockWatcher: child.srv.GetLockWatcher(), - LockTargets: lockTargets, + LockTargets: ComputeLockTargets(clusterName.GetClusterName(), srv.HostUUID(), identityContext), LockingMode: identityContext.AccessChecker.LockingMode(authPref.GetLockingMode()), DisconnectExpiredCert: child.disconnectExpiredCert, ClientIdleTimeout: child.clientIdleTimeout, @@ -1147,28 +1149,19 @@ func newUaccMetadata(c *ServerContext) (*UaccMetadata, error) { }, nil } -// ComputeLockTargets computes lock targets inferred from a Server -// and an IdentityContext. -func ComputeLockTargets(s Server, id IdentityContext) ([]types.LockTarget, error) { - clusterName, err := s.GetAccessPoint().GetClusterName() - if err != nil { - return nil, trace.Wrap(err) - } +// ComputeLockTargets computes lock targets inferred from the clusterName, serverID and IdentityContext. +func ComputeLockTargets(clusterName, serverID string, id IdentityContext) []types.LockTarget { lockTargets := []types.LockTarget{ {User: id.TeleportUser}, {Login: id.Login}, - {Node: s.HostUUID()}, - {Node: auth.HostFQDN(s.HostUUID(), clusterName.GetClusterName())}, + {Node: serverID}, + {Node: auth.HostFQDN(serverID, clusterName)}, {MFADevice: id.Certificate.Extensions[teleport.CertExtensionMFAVerified]}, } roles := apiutils.Deduplicate(append(id.AccessChecker.RoleNames(), id.UnmappedRoles...)) - lockTargets = append(lockTargets, - services.RolesToLockTargets(roles)..., - ) - lockTargets = append(lockTargets, - services.AccessRequestsToLockTargets(id.ActiveRequests)..., - ) - return lockTargets, nil + lockTargets = append(lockTargets, services.RolesToLockTargets(roles)...) + lockTargets = append(lockTargets, services.AccessRequestsToLockTargets(id.ActiveRequests)...) + return lockTargets } // SetRequest sets the ssh request that was issued by the client. 
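With the `Server` parameter removed, `ComputeLockTargets` above is a pure function of a cluster name, a server ID, and an identity, so components that never construct a `regular.Server` can reuse the same lock enforcement. A rough sketch of a caller; the helper and its arguments are illustrative only:

    import (
    	"github.com/gravitational/trace"

    	"github.com/gravitational/teleport/api/constants"
    	"github.com/gravitational/teleport/lib/services"
    	"github.com/gravitational/teleport/lib/srv"
    )

    // enforceLocks is an illustrative helper, not part of this patch: any
    // component can now perform the same lock check the SSH server does.
    func enforceLocks(
    	watcher *services.LockWatcher,
    	mode constants.LockingMode,
    	clusterName, serverID string,
    	identity srv.IdentityContext,
    ) error {
    	// Targets implied by the connection: user, login, node (by UUID and
    	// by FQDN), MFA device, roles, and access requests.
    	targets := srv.ComputeLockTargets(clusterName, serverID, identity)
    	return trace.Wrap(watcher.CheckLockInForce(mode, targets...))
    }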
diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index b07c5c79003c6..01ed7820198f9 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -32,7 +32,6 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" semconv "go.opentelemetry.io/otel/semconv/v1.10.0" oteltrace "go.opentelemetry.io/otel/trace" @@ -45,7 +44,6 @@ import ( tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/bpf" "github.com/gravitational/teleport/lib/defaults" @@ -53,7 +51,6 @@ import ( "github.com/gravitational/teleport/lib/inventory" "github.com/gravitational/teleport/lib/labels" "github.com/gravitational/teleport/lib/limiter" - "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/pam" "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" @@ -73,13 +70,6 @@ var ( log = logrus.WithFields(logrus.Fields{ trace.Component: teleport.ComponentNode, }) - - userSessionLimitHitCount = prometheus.NewCounter( - prometheus.CounterOpts{ - Name: teleport.MetricUserMaxConcurrentSessionsHit, - Help: "Number of times a user exceeded their max concurrent ssh connections", - }, - ) ) // Server implements SSH server that uses configuration backend and @@ -228,6 +218,10 @@ type Server struct { // router used by subsystem requests to connect to nodes // and clusters router *proxy.Router + + // sessionController is used to restrict new sessions + // based on locks and cluster preferences + sessionController *srv.SessionController } // TargetMetadata returns metadata about the server. @@ -658,8 +652,18 @@ func SetTracerProvider(provider oteltrace.TracerProvider) ServerOption { } } +// SetSessionController sets the session controller. 
+func SetSessionController(controller *srv.SessionController) ServerOption {
+	return func(s *Server) error {
+		s.sessionController = controller
+		return nil
+	}
+}
+
 // New returns an unstarted server
-func New(addr utils.NetAddr,
+func New(
+	ctx context.Context,
+	addr utils.NetAddr,
 	hostname string,
 	signers []ssh.Signer,
 	authService srv.AccessPoint,
@@ -669,18 +673,13 @@ func New(addr utils.NetAddr,
 	auth auth.ClientI,
 	options ...ServerOption,
 ) (*Server, error) {
-	err := metrics.RegisterPrometheusCollectors(userSessionLimitHitCount)
-	if err != nil {
-		return nil, trace.Wrap(err)
-	}
-
 	// read the host UUID:
 	uuid, err := utils.ReadOrMakeHostUUID(dataDir)
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}

-	ctx, cancel := context.WithCancel(context.TODO())
+	ctx, cancel := context.WithCancel(ctx)
 	s := &Server{
 		addr:        addr,
 		authService: authService,
@@ -723,6 +722,10 @@ func New(addr utils.NetAddr,
 		return nil, trace.BadParameter("setup valid LockWatcher parameter using SetLockWatcher")
 	}

+	if s.sessionController == nil {
+		return nil, trace.BadParameter("setup valid SessionController parameter using SetSessionController")
+	}
+
 	if s.connectedProxyGetter == nil {
 		s.connectedProxyGetter = reversetunnel.NewConnectedProxyGetter()
 	}
@@ -1070,104 +1073,9 @@ func (s *Server) HandleNewConn(ctx context.Context, ccx *sshutils.ConnectionCont
 	if err != nil {
 		return ctx, trace.Wrap(err)
 	}
-	authPref, err := s.GetAccessPoint().GetAuthPreference(ctx)
-	if err != nil {
-		return ctx, trace.Wrap(err)
-	}
-	lockingMode := identityContext.AccessChecker.LockingMode(authPref.GetLockingMode())
-
-	event := &apievents.SessionReject{
-		Metadata: apievents.Metadata{
-			Type: events.SessionRejectedEvent,
-			Code: events.SessionRejectedCode,
-		},
-		UserMetadata: identityContext.GetUserMetadata(),
-		ConnectionMetadata: apievents.ConnectionMetadata{
-			Protocol:   events.EventProtocolSSH,
-			LocalAddr:  ccx.ServerConn.LocalAddr().String(),
-			RemoteAddr: ccx.ServerConn.RemoteAddr().String(),
-		},
-		ServerMetadata: apievents.ServerMetadata{
-			ServerID:        s.uuid,
-			ServerNamespace: s.GetNamespace(),
-		},
-	}
-	lockTargets, err := srv.ComputeLockTargets(s, identityContext)
-	if err != nil {
-		return ctx, trace.Wrap(err)
-	}
-	if lockErr := s.lockWatcher.CheckLockInForce(lockingMode, lockTargets...); lockErr != nil {
-		event.Reason = lockErr.Error()
-		if err := s.EmitAuditEvent(s.ctx, event); err != nil {
-			s.Logger.WithError(err).Warn("Failed to emit session reject event.")
-		}
-		return ctx, trace.Wrap(lockErr)
-	}
-
-	// Check that the required private key policy, defined by roles and auth pref,
-	// is met by this Identity's ssh certificate.
-	identityPolicy := identityContext.Certificate.Extensions[teleport.CertExtensionPrivateKeyPolicy]
-	requiredPolicy := identityContext.AccessChecker.PrivateKeyPolicy(authPref.GetPrivateKeyPolicy())
-	if err := requiredPolicy.VerifyPolicy(keys.PrivateKeyPolicy(identityPolicy)); err != nil {
-		return nil, trace.Wrap(err)
-	}
-
-	// Don't apply the following checks in non-node contexts.
-	if s.Component() != teleport.ComponentNode {
-		return ctx, nil
-	}
-
-	maxConnections := identityContext.AccessChecker.MaxConnections()
-	if maxConnections == 0 {
-		// concurrent session control is not active, nothing
-		// else needs to be done here.
- return ctx, nil - } - - netConfig, err := s.GetAccessPoint().GetClusterNetworkingConfig(ctx) - if err != nil { - return ctx, trace.Wrap(err) - } - - semLock, err := services.AcquireSemaphoreLock(ctx, services.SemaphoreLockConfig{ - Service: s.authService, - Expiry: netConfig.GetSessionControlTimeout(), - Params: types.AcquireSemaphoreRequest{ - SemaphoreKind: types.SemaphoreKindConnection, - SemaphoreName: identityContext.TeleportUser, - MaxLeases: maxConnections, - Holder: s.uuid, - }, - }) - if err != nil { - if strings.Contains(err.Error(), teleport.MaxLeases) { - // user has exceeded their max concurrent ssh connections. - userSessionLimitHitCount.Inc() - event.Reason = events.SessionRejectedEvent - event.Maximum = maxConnections - if err := s.EmitAuditEvent(s.ctx, event); err != nil { - s.Logger.WithError(err).Warn("Failed to emit session reject event.") - } - err = trace.AccessDenied("too many concurrent ssh connections for user %q (max=%d)", - identityContext.TeleportUser, - maxConnections, - ) - } - return ctx, trace.Wrap(err) - } - - // ensure that losing the lock closes the connection context. Under normal - // conditions, cancellation propagates from the connection context to the - // lock, but if we lose the lock due to some error (e.g. poor connectivity - // to auth server) then cancellation propagates in the other direction. - go func() { - // TODO(fspmarshall): If lock was lost due to error, find a way to propagate - // an error message to user. - <-semLock.Done() - ccx.Close() - }() - return ctx, nil + ctx, err = s.sessionController.AcquireSessionContext(ctx, identityContext, ccx.ServerConn.LocalAddr().String(), ccx.ServerConn.RemoteAddr().String(), ccx) + return ctx, trace.Wrap(err) } // HandleNewChan is called when new channel is opened diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index e4ac44763c481..cb61a0dc7feb3 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -213,6 +213,18 @@ func newCustomFixture(t *testing.T, mutateCfg func(*auth.TestServerConfig), sshO require.NoError(t, err) t.Cleanup(func() { require.NoError(t, nodeClient.Close()) }) + lockWatcher := newLockWatcher(ctx, t, nodeClient) + + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: nodeClient, + AccessPoint: nodeClient, + LockEnforcer: lockWatcher, + Emitter: nodeClient, + Component: teleport.ComponentNode, + ServerID: nodeID, + }) + require.NoError(t, err) + nodeDir := t.TempDir() serverOptions := []ServerOption{ SetUUID(nodeID), @@ -232,13 +244,15 @@ func newCustomFixture(t *testing.T, mutateCfg func(*auth.TestServerConfig), sshO SetBPF(&bpf.NOP{}), SetRestrictedSessionManager(&restricted.NOP{}), SetClock(clock), - SetLockWatcher(newLockWatcher(ctx, t, nodeClient)), + SetLockWatcher(lockWatcher), SetX11ForwardingConfig(&x11.ServerConfig{}), + SetSessionController(sessionController), } serverOptions = append(serverOptions, sshOpts...) 
sshSrv, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, testServer.ClusterName(), []ssh.Signer{signer}, @@ -1394,7 +1408,18 @@ func TestProxyRoundRobin(t *testing.T) { }) require.NoError(t, err) + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: proxyClient, + AccessPoint: proxyClient, + LockEnforcer: lockWatcher, + Emitter: proxyClient, + Component: teleport.ComponentNode, + ServerID: hostID, + }) + require.NoError(t, err) + proxy, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), []ssh.Signer{f.signer}, @@ -1412,6 +1437,7 @@ func TestProxyRoundRobin(t *testing.T) { SetClock(f.clock), SetLockWatcher(lockWatcher), SetNodeWatcher(nodeWatcher), + SetSessionController(sessionController), ) require.NoError(t, err) require.NoError(t, proxy.Start()) @@ -1523,7 +1549,18 @@ func TestProxyDirectAccess(t *testing.T) { }) require.NoError(t, err) + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: nodeClient, + AccessPoint: nodeClient, + LockEnforcer: lockWatcher, + Emitter: nodeClient, + Component: teleport.ComponentNode, + ServerID: hostID, + }) + require.NoError(t, err) + proxy, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), []ssh.Signer{f.signer}, @@ -1541,6 +1578,7 @@ func TestProxyDirectAccess(t *testing.T) { SetClock(f.clock), SetLockWatcher(lockWatcher), SetNodeWatcher(nodeWatcher), + SetSessionController(sessionController), ) require.NoError(t, err) require.NoError(t, proxy.Start()) @@ -1691,8 +1729,22 @@ func TestLimiter(t *testing.T) { require.NoError(t, err) nodeClient, _ := newNodeClient(t, f.testSrv) + + lockWatcher := newLockWatcher(ctx, t, nodeClient) + + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: nodeClient, + AccessPoint: nodeClient, + LockEnforcer: lockWatcher, + Emitter: nodeClient, + Component: teleport.ComponentNode, + ServerID: hostID, + }) + require.NoError(t, err) + nodeStateDir := t.TempDir() srv, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, f.testSrv.ClusterName(), []ssh.Signer{f.signer}, @@ -1709,7 +1761,8 @@ func TestLimiter(t *testing.T) { SetBPF(&bpf.NOP{}), SetRestrictedSessionManager(&restricted.NOP{}), SetClock(f.clock), - SetLockWatcher(newLockWatcher(ctx, t, nodeClient)), + SetLockWatcher(lockWatcher), + SetSessionController(sessionController), ) require.NoError(t, err) require.NoError(t, srv.Start()) @@ -2262,7 +2315,18 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { }) require.NoError(t, err) + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: proxyClient, + AccessPoint: proxyClient, + LockEnforcer: lockWatcher, + Emitter: proxyClient, + Component: teleport.ComponentProxy, + ServerID: hostID, + }) + require.NoError(t, err) + proxy, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), []ssh.Signer{f.signer}, @@ -2280,6 +2344,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { SetClock(f.clock), SetLockWatcher(lockWatcher), SetNodeWatcher(nodeWatcher), + SetSessionController(sessionController), ) require.NoError(t, err) require.NoError(t, proxy.Start()) diff --git a/lib/srv/session_control.go b/lib/srv/session_control.go new file mode 100644 index 0000000000000..92920b51512fc --- /dev/null +++ b/lib/srv/session_control.go @@ -0,0 +1,268 @@ +// Copyright 2022 Gravitational, 
Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package srv + +import ( + "context" + "io" + "strings" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + oteltrace "go.opentelemetry.io/otel/trace" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/observability/tracing" + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/observability/metrics" + "github.com/gravitational/teleport/lib/services" +) + +var ( + userSessionLimitHitCount = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: teleport.MetricUserMaxConcurrentSessionsHit, + Help: "Number of times a user exceeded their max concurrent ssh connections", + }, + ) +) + +func init() { + _ = metrics.RegisterPrometheusCollectors(userSessionLimitHitCount) +} + +// LockEnforcer determines whether a lock is being enforced on the provided targets +type LockEnforcer interface { + CheckLockInForce(mode constants.LockingMode, targets ...types.LockTarget) error +} + +// SessionControllerConfig contains dependencies needed to +// create a SessionController +type SessionControllerConfig struct { + // Semaphores is used to obtain a semaphore lock when max sessions are defined + Semaphores types.Semaphores + // AccessPoint is the cache used to get cluster information + AccessPoint AccessPoint + // LockEnforcer is used to determine if locks should prevent a session + LockEnforcer LockEnforcer + // Emitter is used to emit session rejection events + Emitter apievents.Emitter + // Component is the component running the session controller. 
Nodes and Proxies + // have different flows + Component string + // Logger is used to emit log entries + Logger *logrus.Entry + // TracerProvider creates a tracer so that spans may be emitted + TracerProvider oteltrace.TracerProvider + // ServerID is the UUID of the server + ServerID string + // Clock used in tests to change time + Clock clockwork.Clock + + tracer oteltrace.Tracer +} + +// CheckAndSetDefaults ensures all the required dependencies were +// provided and sets any optional values to their defaults +func (c *SessionControllerConfig) CheckAndSetDefaults() error { + if c.Semaphores == nil { + return trace.BadParameter("Semaphores must be provided") + } + + if c.AccessPoint == nil { + return trace.BadParameter("AccessPoint must be provided") + } + + if c.LockEnforcer == nil { + return trace.BadParameter("LockWatcher must be provided") + } + + if c.Emitter == nil { + return trace.BadParameter("Emitter must be provided") + } + + if c.Component == "" { + return trace.BadParameter("Component must be provided") + } + + if c.TracerProvider == nil { + c.TracerProvider = tracing.DefaultProvider() + } + + if c.Logger == nil { + c.Logger = logrus.WithField(trace.Component, "SessionCtrl") + } + + if c.Clock == nil { + c.Clock = clockwork.NewRealClock() + } + + c.tracer = c.TracerProvider.Tracer("SessionController") + + return nil +} + +// SessionController enforces session control restrictions required by +// locks, private key policy, and max connection limits +type SessionController struct { + cfg SessionControllerConfig +} + +// NewSessionController creates a SessionController from the provided config. If any +// of the required parameters in the SessionControllerConfig are not provided an +// error is returned. +func NewSessionController(cfg SessionControllerConfig) (*SessionController, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + return &SessionController{cfg: cfg}, nil +} + +// AcquireSessionContext attempts to create a context for the session. If the session is +// not allowed due to session control an error is returned. The returned +// context is scoped to the session and will be canceled in the event the semaphore lock +// is no longer held. The closers provided are immediately closed when the semaphore lock +// is released as well. +func (s *SessionController) AcquireSessionContext(ctx context.Context, identity IdentityContext, localAddr, remoteAddr string, closers ...io.Closer) (context.Context, error) { + // create a separate context for tracing the operations + // within that doesn't leak into the returned context + spanCtx, span := s.cfg.tracer.Start(ctx, "SessionController/AcquireSessionContext") + defer span.End() + + authPref, err := s.cfg.AccessPoint.GetAuthPreference(spanCtx) + if err != nil { + return ctx, trace.Wrap(err) + } + + clusterName, err := s.cfg.AccessPoint.GetClusterName() + if err != nil { + return ctx, trace.Wrap(err) + } + + lockingMode := identity.AccessChecker.LockingMode(authPref.GetLockingMode()) + lockTargets := ComputeLockTargets(clusterName.GetClusterName(), s.cfg.ServerID, identity) + + if lockErr := s.cfg.LockEnforcer.CheckLockInForce(lockingMode, lockTargets...); lockErr != nil { + s.emitRejection(spanCtx, identity.GetUserMetadata(), localAddr, remoteAddr, lockErr.Error(), 0) + return ctx, trace.Wrap(lockErr) + } + + // Check that the required private key policy, defined by roles and auth pref, + // is met by this Identity's ssh certificate. 
+ identityPolicy := identity.Certificate.Extensions[teleport.CertExtensionPrivateKeyPolicy] + requiredPolicy := identity.AccessChecker.PrivateKeyPolicy(authPref.GetPrivateKeyPolicy()) + if err := requiredPolicy.VerifyPolicy(keys.PrivateKeyPolicy(identityPolicy)); err != nil { + return ctx, trace.Wrap(err) + } + + // Don't apply the following checks in non-node contexts. + if s.cfg.Component != teleport.ComponentNode { + return ctx, nil + } + + maxConnections := identity.AccessChecker.MaxConnections() + if maxConnections == 0 { + // concurrent session control is not active, nothing + // else needs to be done here. + return ctx, nil + } + + netConfig, err := s.cfg.AccessPoint.GetClusterNetworkingConfig(spanCtx) + if err != nil { + return ctx, trace.Wrap(err) + } + + semLock, err := services.AcquireSemaphoreLock(spanCtx, services.SemaphoreLockConfig{ + Service: s.cfg.Semaphores, + Clock: s.cfg.Clock, + Expiry: netConfig.GetSessionControlTimeout(), + Params: types.AcquireSemaphoreRequest{ + SemaphoreKind: types.SemaphoreKindConnection, + SemaphoreName: identity.TeleportUser, + MaxLeases: maxConnections, + Holder: s.cfg.ServerID, + }, + }) + if err != nil { + if strings.Contains(err.Error(), teleport.MaxLeases) { + // user has exceeded their max concurrent ssh connections. + userSessionLimitHitCount.Inc() + s.emitRejection(spanCtx, identity.GetUserMetadata(), localAddr, remoteAddr, events.SessionRejectedEvent, maxConnections) + + return ctx, trace.AccessDenied("too many concurrent ssh connections for user %q (max=%d)", identity.TeleportUser, maxConnections) + } + + return ctx, trace.Wrap(err) + } + + ctx, cancel := context.WithCancel(ctx) + // ensure that losing the lock closes the connection context. Under normal + // conditions, cancellation propagates from the connection context to the + // lock, but if we lose the lock due to some error (e.g. poor connectivity + // to auth server) then cancellation propagates in the other direction. + go func() { + // TODO(fspmarshall): If lock was lost due to error, find a way to propagate + // an error message to user. 
+ <-semLock.Done() + cancel() + + // close any provided closers + for _, closer := range closers { + _ = closer.Close() + } + }() + + return ctx, nil +} + +// emitRejection emits a SessionRejectedEvent with the provided information +func (s *SessionController) emitRejection(ctx context.Context, userMetadata apievents.UserMetadata, localAddr, remoteAddr string, reason string, max int64) { + // link a background context to the current span so things + // are related but while still allowing the audit event to + // not be tied to the request scoped context + emitCtx := oteltrace.ContextWithSpanContext(context.Background(), oteltrace.SpanContextFromContext(ctx)) + + ctx, span := s.cfg.tracer.Start(emitCtx, "SessionController/emitRejection") + defer span.End() + + if err := s.cfg.Emitter.EmitAuditEvent(ctx, &apievents.SessionReject{ + Metadata: apievents.Metadata{ + Type: events.SessionRejectedEvent, + Code: events.SessionRejectedCode, + }, + UserMetadata: userMetadata, + ConnectionMetadata: apievents.ConnectionMetadata{ + Protocol: events.EventProtocolSSH, + LocalAddr: localAddr, + RemoteAddr: remoteAddr, + }, + ServerMetadata: apievents.ServerMetadata{ + ServerID: s.cfg.ServerID, + ServerNamespace: apidefaults.Namespace, + }, + Reason: reason, + Maximum: max, + }); err != nil { + s.cfg.Logger.WithError(err).Warn("Failed to emit session reject event.") + } +} diff --git a/lib/srv/session_control_test.go b/lib/srv/session_control_test.go new file mode 100644 index 0000000000000..9dfd6b144610d --- /dev/null +++ b/lib/srv/session_control_test.go @@ -0,0 +1,408 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package srv + +import ( + "context" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/events/eventstest" + "github.com/gravitational/teleport/lib/services" +) + +type mockLockEnforcer struct { + lockInForceErr error +} + +func (m mockLockEnforcer) CheckLockInForce(constants.LockingMode, ...types.LockTarget) error { + return m.lockInForceErr +} + +type mockAccessPoint struct { + AccessPoint + + authPreference types.AuthPreference + clusterName types.ClusterName + netConfig types.ClusterNetworkingConfig +} + +func (m mockAccessPoint) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) { + return m.authPreference, nil +} + +func (m mockAccessPoint) GetClusterName(opts ...services.MarshalOption) (types.ClusterName, error) { + return m.clusterName, nil +} + +func (m mockAccessPoint) GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error) { + return m.netConfig, nil +} + +type mockSemaphores struct { + types.Semaphores + + lease *types.SemaphoreLease + acquireErr error +} + +func (m mockSemaphores) AcquireSemaphore(ctx context.Context, params types.AcquireSemaphoreRequest) (*types.SemaphoreLease, error) { + return m.lease, m.acquireErr +} + +func (m mockSemaphores) CancelSemaphoreLease(ctx context.Context, lease types.SemaphoreLease) error { + return nil +} + +type mockAccessChecker struct { + services.AccessChecker + + lockMode constants.LockingMode + maxConnections int64 + keyPolicy keys.PrivateKeyPolicy + roleNames []string +} + +func (m mockAccessChecker) LockingMode(defaultMode constants.LockingMode) constants.LockingMode { + return m.lockMode +} + +func (m mockAccessChecker) MaxConnections() int64 { + return m.maxConnections +} + +func (m mockAccessChecker) PrivateKeyPolicy(defaultPolicy keys.PrivateKeyPolicy) keys.PrivateKeyPolicy { + return m.keyPolicy +} +func (m mockAccessChecker) RoleNames() []string { + return m.roleNames +} + +func TestSessionController_AcquireSessionContext(t *testing.T) { + t.Parallel() + + clock := clockwork.NewFakeClock() + emitter := &eventstest.MockEmitter{} + + cases := []struct { + name string + cfg SessionControllerConfig + identity IdentityContext + assertion func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) + }{ + { + name: "proxy: access allowed", + cfg: SessionControllerConfig{ + Semaphores: mockSemaphores{}, + AccessPoint: mockAccessPoint{ + netConfig: &types.ClusterNetworkingConfigV2{}, + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + }, + LockEnforcer: mockLockEnforcer{}, + Emitter: emitter, + Component: teleport.ComponentProxy, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleport.CertExtensionPrivateKeyPolicy: string(keys.PrivateKeyPolicyNone), + }, + }, + }, 
+ AccessChecker: mockAccessChecker{ + keyPolicy: keys.PrivateKeyPolicyNone, + maxConnections: 1, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.NoError(t, err) + require.NotNil(t, ctx) + require.Empty(t, emitter.Events()) + }, + }, + { + name: "node: access allowed", + cfg: SessionControllerConfig{ + Clock: clock, + Semaphores: mockSemaphores{ + lease: &types.SemaphoreLease{ + SemaphoreKind: types.SemaphoreKindConnection, + SemaphoreName: "test", + LeaseID: "1", + Expires: clock.Now().Add(time.Minute), + }, + }, + AccessPoint: mockAccessPoint{ + netConfig: &types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + SessionControlTimeout: types.NewDuration(time.Minute), + }, + }, + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + }, + LockEnforcer: mockLockEnforcer{}, + Emitter: emitter, + Component: teleport.ComponentNode, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleport.CertExtensionPrivateKeyPolicy: string(keys.PrivateKeyPolicyNone), + }, + }, + }, + AccessChecker: mockAccessChecker{ + keyPolicy: keys.PrivateKeyPolicyNone, + maxConnections: 1, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.NoError(t, err) + require.NotNil(t, ctx) + require.Empty(t, emitter.Events()) + }, + }, + { + name: "session rejected due to lock", + cfg: SessionControllerConfig{ + Clock: clock, + Semaphores: mockSemaphores{}, + AccessPoint: mockAccessPoint{ + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + }, + LockEnforcer: mockLockEnforcer{ + lockInForceErr: trace.AccessDenied("lock in force"), + }, + Emitter: emitter, + Component: teleport.ComponentNode, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleport.CertExtensionPrivateKeyPolicy: string(keys.PrivateKeyPolicyNone), + }, + }, + }, + AccessChecker: mockAccessChecker{ + keyPolicy: keys.PrivateKeyPolicyNone, + maxConnections: 1, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.ErrorIs(t, err, trace.AccessDenied("lock in force")) + require.NotNil(t, ctx) + require.Len(t, emitter.Events(), 1) + + evt, ok := emitter.Events()[0].(*apievents.SessionReject) + require.True(t, ok) + require.Equal(t, events.SessionRejectedEvent, evt.Metadata.Type) + require.Equal(t, events.SessionRejectedCode, evt.Metadata.Code) + require.Equal(t, events.EventProtocolSSH, evt.ConnectionMetadata.Protocol) + require.Equal(t, "lock in force", evt.Reason) + }, + }, + { + name: "session rejected due to private key policy", + cfg: SessionControllerConfig{ + Clock: clock, + Semaphores: mockSemaphores{}, + AccessPoint: mockAccessPoint{ + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: 
types.ClusterNameSpecV2{ClusterName: "llama"}}, + }, + LockEnforcer: mockLockEnforcer{}, + Emitter: emitter, + Component: teleport.ComponentNode, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleport.CertExtensionPrivateKeyPolicy: string(keys.PrivateKeyPolicyNone), + }, + }, + }, + AccessChecker: mockAccessChecker{ + keyPolicy: keys.PrivateKeyPolicyHardwareKey, + maxConnections: 1, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.Error(t, err) + require.True(t, trace.IsBadParameter(err)) + require.NotNil(t, ctx) + require.Empty(t, emitter.Events()) + }, + }, + { + name: "session rejected due to connection limit", + cfg: SessionControllerConfig{ + Clock: clock, + Semaphores: mockSemaphores{ + acquireErr: trace.LimitExceeded(teleport.MaxLeases), + }, + AccessPoint: mockAccessPoint{ + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + netConfig: &types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + SessionControlTimeout: types.NewDuration(time.Minute), + }, + }, + }, + LockEnforcer: mockLockEnforcer{}, + Emitter: emitter, + Component: teleport.ComponentNode, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleport.CertExtensionPrivateKeyPolicy: string(keys.PrivateKeyPolicyNone), + }, + }, + }, + AccessChecker: mockAccessChecker{ + keyPolicy: keys.PrivateKeyPolicyNone, + maxConnections: 1, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + require.NotNil(t, ctx) + require.Len(t, emitter.Events(), 1) + + evt, ok := emitter.Events()[0].(*apievents.SessionReject) + require.True(t, ok) + require.Equal(t, events.SessionRejectedEvent, evt.Metadata.Type) + require.Equal(t, events.SessionRejectedCode, evt.Metadata.Code) + require.Equal(t, events.EventProtocolSSH, evt.ConnectionMetadata.Protocol) + require.Equal(t, events.SessionRejectedEvent, evt.Reason) + require.Equal(t, int64(1), evt.Maximum) + }, + }, + { + name: "no connection limits prevent acquiring semaphore lock", + cfg: SessionControllerConfig{ + Clock: clock, + Semaphores: mockSemaphores{ + acquireErr: trace.LimitExceeded(teleport.MaxLeases), + }, + AccessPoint: mockAccessPoint{ + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + netConfig: &types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + SessionControlTimeout: types.NewDuration(time.Minute), + }, + }, + }, + LockEnforcer: mockLockEnforcer{}, + Emitter: emitter, + Component: teleport.ComponentNode, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleport.CertExtensionPrivateKeyPolicy: string(keys.PrivateKeyPolicyNone), + }, + }, + }, + AccessChecker: 
mockAccessChecker{
+					keyPolicy:      keys.PrivateKeyPolicyNone,
+					maxConnections: 0,
+				},
+			},
+			assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) {
+				require.NoError(t, err)
+				require.NotNil(t, ctx)
+				require.Empty(t, emitter.Events())
+			},
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			emitter.Reset()
+			ctrl, err := NewSessionController(tt.cfg)
+			require.NoError(t, err)
+
+			ctx, err := ctrl.AcquireSessionContext(context.Background(), tt.identity, "127.0.0.1:1", "127.0.0.1:2")
+			tt.assertion(t, ctx, err, emitter)
+		})
+	}
+}
diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go
index f3911e5cf324f..6d3d15d6c63f9 100644
--- a/lib/web/apiserver_test.go
+++ b/lib/web/apiserver_test.go
@@ -252,9 +252,20 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite {
 	})
 	require.NoError(t, err)

+	nodeSessionController, err := srv.NewSessionController(srv.SessionControllerConfig{
+		Semaphores:   nodeClient,
+		AccessPoint:  nodeClient,
+		LockEnforcer: nodeLockWatcher,
+		Emitter:      nodeClient,
+		Component:    teleport.ComponentNode,
+		ServerID:     nodeID,
+	})
+	require.NoError(t, err)
+
 	// create SSH service:
 	nodeDataDir := t.TempDir()
 	node, err := regular.New(
+		ctx,
 		utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"},
 		s.server.ClusterName(),
 		[]ssh.Signer{signer},
@@ -272,6 +283,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite {
 		regular.SetRestrictedSessionManager(&restricted.NOP{}),
 		regular.SetClock(s.clock),
 		regular.SetLockWatcher(nodeLockWatcher),
+		regular.SetSessionController(nodeSessionController),
 	)
 	require.NoError(t, err)
 	s.node = node
@@ -347,8 +359,19 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite {
 	})
 	require.NoError(t, err)

+	proxySessionController, err := srv.NewSessionController(srv.SessionControllerConfig{
+		Semaphores:   s.proxyClient,
+		AccessPoint:  s.proxyClient,
+		LockEnforcer: proxyLockWatcher,
+		Emitter:      s.proxyClient,
+		Component:    teleport.ComponentProxy,
+		ServerID:     proxyID,
+	})
+	require.NoError(t, err)
+
 	// proxy server:
 	s.proxy, err = regular.New(
+		ctx,
 		utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"},
 		s.server.ClusterName(),
 		[]ssh.Signer{signer},
@@ -366,6 +389,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite {
 		regular.SetClock(s.clock),
 		regular.SetLockWatcher(proxyLockWatcher),
 		regular.SetNodeWatcher(proxyNodeWatcher),
+		regular.SetSessionController(proxySessionController),
 	)
 	require.NoError(t, err)

@@ -5907,9 +5931,20 @@ func newWebPack(t *testing.T, numProxies int) *webPack {
 	require.NoError(t, err)
 	t.Cleanup(nodeLockWatcher.Close)

+	nodeSessionController, err := srv.NewSessionController(srv.SessionControllerConfig{
+		Semaphores:   nodeClient,
+		AccessPoint:  nodeClient,
+		LockEnforcer: nodeLockWatcher,
+		Emitter:      nodeClient,
+		Component:    teleport.ComponentNode,
+		ServerID:     nodeID,
+	})
+	require.NoError(t, err)
+
 	// create SSH service:
 	nodeDataDir := t.TempDir()
 	node, err := regular.New(
+		ctx,
 		utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"},
 		server.TLS.ClusterName(),
 		hostSigners,
@@ -5927,6 +5962,7 @@ func newWebPack(t *testing.T, numProxies int) *webPack {
 		regular.SetRestrictedSessionManager(&restricted.NOP{}),
 		regular.SetClock(clock),
 		regular.SetLockWatcher(nodeLockWatcher),
+		regular.SetSessionController(nodeSessionController),
 	)
 	require.NoError(t, err)

@@ -6034,7 +6070,18 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula
 	})
 	require.NoError(t,
err) + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: client, + AccessPoint: client, + LockEnforcer: proxyLockWatcher, + Emitter: client, + Component: teleport.ComponentProxy, + ServerID: proxyID, + }) + require.NoError(t, err) + proxyServer, err := regular.New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, authServer.ClusterName(), hostSigners, @@ -6052,6 +6099,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula regular.SetClock(clock), regular.SetLockWatcher(proxyLockWatcher), regular.SetNodeWatcher(proxyNodeWatcher), + regular.SetSessionController(sessionController), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, proxyServer.Close()) })
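For context on where this is headed, step 3 of #15167 would have a component such as the web apiserver wire up the controller roughly as follows. This is a sketch under the assumption that the caller already holds a config and an established connection; the function and variable names are illustrative, not part of this patch:

    import (
    	"context"
    	"net"

    	"github.com/gravitational/trace"

    	"github.com/gravitational/teleport/lib/srv"
    )

    // acquireSession is an illustrative sketch of a future consumer.
    func acquireSession(
    	ctx context.Context,
    	cfg srv.SessionControllerConfig,
    	identity srv.IdentityContext,
    	conn net.Conn,
    ) (context.Context, error) {
    	controller, err := srv.NewSessionController(cfg)
    	if err != nil {
    		return ctx, trace.Wrap(err)
    	}

    	// Gate the new session on locks, private key policy, and (for nodes)
    	// concurrent-connection limits. The returned context is canceled, and
    	// conn is closed, if the semaphore lease backing the session is lost.
    	return controller.AcquireSessionContext(ctx, identity,
    		conn.LocalAddr().String(), conn.RemoteAddr().String(), conn)
    }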