diff --git a/integration/utmp_integration_test.go b/integration/utmp_integration_test.go index 0162a12508dee..61cd3c1956532 100644 --- a/integration/utmp_integration_test.go +++ b/integration/utmp_integration_test.go @@ -40,6 +40,7 @@ import ( "github.com/gravitational/teleport/lib/pam" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/srv/regular" "github.com/gravitational/teleport/lib/srv/uacc" "github.com/gravitational/teleport/lib/sshutils" @@ -257,8 +258,19 @@ func newSrvCtx(ctx context.Context, t *testing.T) *SrvCtx { require.NoError(t, err) t.Cleanup(lockWatcher.Close) + nodeSessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: s.nodeClient, + AccessPoint: s.nodeClient, + LockEnforcer: lockWatcher, + Emitter: s.nodeClient, + Component: teleport.ComponentNode, + ServerID: s.nodeID, + }) + require.NoError(t, err) + nodeDir := t.TempDir() srv, err := regular.New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, s.server.ClusterName(), []ssh.Signer{s.signer}, @@ -286,6 +298,7 @@ func newSrvCtx(ctx context.Context, t *testing.T) *SrvCtx { regular.SetClock(s.clock), regular.SetUtmpPath(utmpPath, utmpPath), regular.SetLockWatcher(lockWatcher), + regular.SetSessionController(nodeSessionController), ) require.NoError(t, err) s.srv = srv diff --git a/lib/client/client.go b/lib/client/client.go index f899e580a81e5..9c23925691b40 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -23,6 +23,7 @@ import ( "crypto/x509" "encoding/json" "errors" + "fmt" "io" "net" "os" @@ -84,7 +85,6 @@ type NodeClient struct { Namespace string Tracer oteltrace.Tracer Client *tracessh.Client - Proxy *ProxyClient TC *TeleportClient OnMFA func() FIPSEnabled bool @@ -1595,6 +1595,7 @@ func (proxy *ProxyClient) ConnectToNode(ctx context.Context, nodeAddress NodeDet return nil, trace.ConnectionProblem(err, "failed connecting to node %v. %s", nodeName(nodeAddress.Addr), serverErrorMsg) } + pipeNetConn := utils.NewPipeNetConn( proxyReader, proxyWriter, @@ -1608,35 +1609,9 @@ func (proxy *ProxyClient) ConnectToNode(ctx context.Context, nodeAddress NodeDet Auth: authMethods, HostKeyCallback: proxy.hostKeyCallback, } - conn, chans, reqs, err := newClientConn(ctx, pipeNetConn, nodeAddress.ProxyFormat(), sshConfig) - if err != nil { - if utils.IsHandshakeFailedError(err) { - proxySession.Close() - return nil, trace.AccessDenied(`access denied to %v connecting to %v`, user, nodeAddress) - } - return nil, trace.Wrap(err) - } - - // We pass an empty channel which we close right away to ssh.NewClient - // because the client need to handle requests itself. - emptyCh := make(chan *ssh.Request) - close(emptyCh) - - nc := &NodeClient{ - Client: tracessh.NewClient(conn, chans, emptyCh), - Proxy: proxy, - Namespace: apidefaults.Namespace, - TC: proxy.teleportClient, - Tracer: proxy.Tracer, - FIPSEnabled: details.FIPSEnabled, - } - // Start a goroutine that will run for the duration of the client to process - // global requests from the client. Teleport clients will use this to update - // terminal sizes when the remote PTY size has changed. - go nc.handleGlobalRequests(ctx, reqs) - - return nc, nil + nc, err := NewNodeClient(ctx, sshConfig, pipeNetConn, nodeAddress.ProxyFormat(), proxy.teleportClient, details.FIPSEnabled) + return nc, trace.Wrap(err) } // PortForwardToNode connects to the ssh server via Proxy @@ -1680,11 +1655,28 @@ func (proxy *ProxyClient) PortForwardToNode(ctx context.Context, nodeAddress Nod Auth: authMethods, HostKeyCallback: proxy.hostKeyCallback, } - conn, chans, reqs, err := newClientConn(ctx, proxyConn, nodeAddress.Addr, sshConfig) + + nc, err := NewNodeClient(ctx, sshConfig, proxyConn, nodeAddress.Addr, proxy.teleportClient, details.FIPSEnabled) + return nc, trace.Wrap(err) +} + +// NewNodeClient constructs a NodeClient that is connected to the node at nodeAddress +func NewNodeClient(ctx context.Context, sshConfig *ssh.ClientConfig, conn net.Conn, nodeAddress string, tc *TeleportClient, fipsEnabled bool) (*NodeClient, error) { + ctx, span := tc.Tracer.Start( + ctx, + "NewNodeClient", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + attribute.String("node", nodeAddress), + ), + ) + defer span.End() + + sshconn, chans, reqs, err := newClientConn(ctx, conn, nodeAddress, sshConfig) if err != nil { if utils.IsHandshakeFailedError(err) { - proxyConn.Close() - return nil, trace.AccessDenied(`access denied to %v connecting to %v`, user, nodeAddress) + conn.Close() + return nil, trace.AccessDenied(`access denied to %v connecting to %v`, sshConfig.User, nodeAddress) } return nil, trace.Wrap(err) } @@ -1695,11 +1687,11 @@ func (proxy *ProxyClient) PortForwardToNode(ctx context.Context, nodeAddress Nod close(emptyCh) nc := &NodeClient{ - Client: tracessh.NewClient(conn, chans, emptyCh), - Proxy: proxy, - Namespace: apidefaults.Namespace, - TC: proxy.teleportClient, - Tracer: proxy.Tracer, + Client: tracessh.NewClient(sshconn, chans, emptyCh), + Namespace: apidefaults.Namespace, + TC: tc, + Tracer: tc.Tracer, + FIPSEnabled: fipsEnabled, } // Start a goroutine that will run for the duration of the client to process @@ -1710,6 +1702,56 @@ func (proxy *ProxyClient) PortForwardToNode(ctx context.Context, nodeAddress Nod return nc, nil } +// RunInteractiveShell creates an interactive shell on the node and copies stdin/stdout/stderr +// to and from the node and local shell. This will block until the interactive shell on the node +// is terminated. +func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker) error { + ctx, span := c.Tracer.Start( + ctx, + "nodeClient/RunInteractiveShell", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + + env := make(map[string]string) + env[teleport.EnvSSHJoinMode] = string(mode) + env[teleport.EnvSSHSessionReason] = c.TC.Config.Reason + env[teleport.EnvSSHSessionDisplayParticipantRequirements] = strconv.FormatBool(c.TC.Config.DisplayParticipantRequirements) + encoded, err := json.Marshal(&c.TC.Config.Invited) + if err != nil { + return trace.Wrap(err) + } + + env[teleport.EnvSSHSessionInvited] = string(encoded) + for key, value := range c.TC.Env { + env[key] = value + } + + nodeSession, err := newSession(ctx, c, sessToJoin, env, c.TC.Stdin, c.TC.Stdout, c.TC.Stderr, c.TC.EnableEscapeSequences) + if err != nil { + return trace.Wrap(err) + } + + if err = nodeSession.runShell(ctx, mode, nil, c.TC.OnShellCreated); err != nil { + switch e := trace.Unwrap(err).(type) { + case *ssh.ExitError: + c.TC.ExitStatus = e.ExitStatus() + case *ssh.ExitMissingError: + c.TC.ExitStatus = 1 + } + + return trace.Wrap(err) + } + + if nodeSession.ExitMsg == "" { + fmt.Fprintln(c.TC.Stderr, "the connection was closed on the remote side at ", time.Now().Format(time.RFC822)) + } else { + fmt.Fprintln(c.TC.Stderr, nodeSession.ExitMsg) + } + + return nil +} + func (c *NodeClient) handleGlobalRequests(ctx context.Context, requestCh <-chan *ssh.Request) { for { select { diff --git a/lib/client/session.go b/lib/client/session.go index 2e39bc33520da..120f92b840609 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -246,7 +246,7 @@ func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Sessi // if agent forwarding was requested (and we have a agent to forward), // forward the agent to endpoint. - tc := ns.nodeClient.Proxy.teleportClient + tc := ns.nodeClient.TC targetAgent := selectKeyAgent(tc) if targetAgent != nil { diff --git a/lib/service/service.go b/lib/service/service.go index 53c1f2d49b0f5..10850cc8aec67 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2285,7 +2285,29 @@ func (process *TeleportProcess) initSSH() error { storagePresence := local.NewPresenceService(process.storage.BackendStorage) - s, err := regular.New(cfg.SSH.Addr, + // read the host UUID: + serverID, err := utils.ReadOrMakeHostUUID(cfg.DataDir) + if err != nil { + return trace.Wrap(err) + } + + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: authClient, + AccessPoint: authClient, + LockEnforcer: lockWatcher, + Emitter: &events.StreamerAndEmitter{Emitter: asyncEmitter, Streamer: streamer}, + Component: teleport.ComponentNode, + Logger: process.log.WithField(trace.Component, "sessionctrl"), + TracerProvider: process.TracingProvider, + ServerID: serverID, + }) + if err != nil { + return trace.Wrap(err) + } + + s, err := regular.New( + process.ExitContext(), + cfg.SSH.Addr, cfg.Hostname, []ssh.Signer{conn.ServerIdentity.KeySigner}, authClient, @@ -2317,6 +2339,8 @@ func (process *TeleportProcess) initSSH() error { regular.SetCreateHostUser(!cfg.SSH.DisableCreateHostUser), regular.SetStoragePresenceService(storagePresence), regular.SetInventoryControlHandle(process.inventoryHandle), + regular.SetTracerProvider(process.TracingProvider), + regular.SetSessionController(sessionController), ) if err != nil { return trace.Wrap(err) @@ -3447,6 +3471,42 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } } + var proxyRouter *proxy.Router + if !process.Config.Proxy.DisableReverseTunnel { + router, err := proxy.NewRouter(proxy.RouterConfig{ + ClusterName: clusterName, + Log: process.log.WithField(trace.Component, "router"), + RemoteClusterGetter: accessPoint, + SiteGetter: tsrv, + TracerProvider: process.TracingProvider, + }) + if err != nil { + return trace.Wrap(err) + } + + proxyRouter = router + } + + // read the host UUID: + serverID, err := utils.ReadOrMakeHostUUID(cfg.DataDir) + if err != nil { + return trace.Wrap(err) + } + + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: accessPoint, + AccessPoint: accessPoint, + LockEnforcer: lockWatcher, + Emitter: &events.StreamerAndEmitter{Emitter: asyncEmitter, Streamer: streamer}, + Component: teleport.ComponentProxy, + Logger: process.log.WithField(trace.Component, "sessionctrl"), + TracerProvider: process.TracingProvider, + ServerID: serverID, + }) + if err != nil { + return trace.Wrap(err) + } + // Register web proxy server alpnHandlerForWeb := &alpnproxy.ConnectionHandlerWrapper{} var webServer *http.Server @@ -3512,6 +3572,8 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { ALPNHandler: alpnHandlerForWeb.HandleConnection, ProxyKubeAddr: proxyKubeAddr, TraceClient: traceClt, + Router: proxyRouter, + SessionControl: sessionController, } webHandler, err = web.NewHandler(webConfig) if err != nil { @@ -3609,22 +3671,9 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { }) } - var proxyRouter *proxy.Router - if !process.Config.Proxy.DisableReverseTunnel { - router, err := proxy.NewRouter(proxy.RouterConfig{ - ClusterName: clusterName, - Log: process.log.WithField(trace.Component, "router"), - RemoteClusterGetter: accessPoint, - SiteGetter: tsrv, - TracerProvider: process.TracingProvider, - }) - if err != nil { - return trace.Wrap(err) - } - proxyRouter = router - } - - sshProxy, err := regular.New(cfg.Proxy.SSHAddr, + sshProxy, err := regular.New( + process.ExitContext(), + cfg.SSH.Addr, cfg.Hostname, []ssh.Signer{conn.ServerIdentity.KeySigner}, accessPoint, @@ -3648,6 +3697,8 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { // accurately checked later when an SCP/SFTP request hits the // destination Node. regular.SetAllowFileCopying(true), + regular.SetTracerProvider(process.TracingProvider), + regular.SetSessionController(sessionController), ) if err != nil { return trace.Wrap(err) diff --git a/lib/services/semaphore.go b/lib/services/semaphore.go index cc72f67930e89..78df632e1888b 100644 --- a/lib/services/semaphore.go +++ b/lib/services/semaphore.go @@ -20,6 +20,7 @@ import ( "time" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" @@ -39,10 +40,16 @@ type SemaphoreLockConfig struct { TickRate time.Duration // Params holds the semaphore lease acquisition parameters. Params types.AcquireSemaphoreRequest + // Clock used to alter time in tests + Clock clockwork.Clock } // CheckAndSetDefaults checks and sets default parameters func (l *SemaphoreLockConfig) CheckAndSetDefaults() error { + if l.Clock == nil { + l.Clock = clockwork.NewRealClock() + } + if l.Service == nil { return trace.BadParameter("missing semaphore service") } @@ -59,7 +66,7 @@ func (l *SemaphoreLockConfig) CheckAndSetDefaults() error { return trace.BadParameter("tick-rate must be less than expiry") } if l.Params.Expires.IsZero() { - l.Params.Expires = time.Now().UTC().Add(l.Expiry) + l.Params.Expires = l.Clock.Now().UTC().Add(l.Expiry) } if err := l.Params.Check(); err != nil { return trace.Wrap(err) @@ -73,7 +80,7 @@ type SemaphoreLock struct { cfg SemaphoreLockConfig lease0 types.SemaphoreLease retry retryutils.Retry - ticker *time.Ticker + ticker clockwork.Ticker doneC chan struct{} closeOnce sync.Once renewalC chan struct{} @@ -140,7 +147,7 @@ func (l *SemaphoreLock) keepAlive(ctx context.Context) { // cancellation/expiry. return } - if lease.Expires.After(time.Now().UTC()) { + if lease.Expires.After(l.cfg.Clock.Now().UTC()) { // parent context is closed. create orphan context with generous // timeout for lease cancellation scope. this will not block any // caller that is not explicitly waiting on the final error value. @@ -157,7 +164,7 @@ func (l *SemaphoreLock) keepAlive(ctx context.Context) { Outer: for { select { - case tick := <-l.ticker.C: + case tick := <-l.ticker.Chan(): leaseContext, leaseCancel := context.WithDeadline(ctx, lease.Expires) nextLease := lease nextLease.Expires = tick.Add(l.cfg.Expiry) @@ -185,7 +192,7 @@ Outer: l.retry.Inc() select { case <-l.retry.After(): - case tick = <-l.ticker.C: + case tick = <-l.ticker.Chan(): // check to make sure that we still have some time on the lease. the default tick rate would have // us waking _as_ the lease expires here, but if we're working with a higher tick rate, its worth // retrying again. @@ -247,6 +254,7 @@ func AcquireSemaphoreLock(ctx context.Context, cfg SemaphoreLockConfig) (*Semaph Max: cfg.Expiry / 4, Step: cfg.Expiry / 16, Jitter: retryutils.NewJitter(), + Clock: cfg.Clock, }) if err != nil { return nil, trace.Wrap(err) @@ -259,7 +267,7 @@ func AcquireSemaphoreLock(ctx context.Context, cfg SemaphoreLockConfig) (*Semaph cfg: cfg, lease0: *lease, retry: retry, - ticker: time.NewTicker(cfg.TickRate), + ticker: cfg.Clock.NewTicker(cfg.TickRate), doneC: make(chan struct{}), renewalC: make(chan struct{}), cond: sync.NewCond(&sync.Mutex{}), diff --git a/lib/srv/authhandlers.go b/lib/srv/authhandlers.go index 936325244fccd..086ee2847ffd5 100644 --- a/lib/srv/authhandlers.go +++ b/lib/srv/authhandlers.go @@ -154,7 +154,7 @@ func (h *AuthHandlers) CreateIdentityContext(sconn *ssh.ServerConn) (IdentityCon } identity.Impersonator = certificate.Extensions[teleport.CertExtensionImpersonator] - accessRequestIDs, err := parseAccessRequestIDs(certificate.Extensions[teleport.CertExtensionTeleportActiveRequests]) + accessRequestIDs, err := ParseAccessRequestIDs(certificate.Extensions[teleport.CertExtensionTeleportActiveRequests]) if err != nil { return IdentityContext{}, trace.Wrap(err) } @@ -626,7 +626,7 @@ type AccessRequests struct { IDs []string `json:"access_requests"` } -func parseAccessRequestIDs(str string) ([]string, error) { +func ParseAccessRequestIDs(str string) ([]string, error) { var accessRequestIDs []string var ar AccessRequests diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 6a8c2d9cd948d..18e794204f5b9 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -435,13 +435,15 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s trace.ComponentFields: fields, }) - lockTargets, err := ComputeLockTargets(srv, identityContext) + clusterName, err := srv.GetAccessPoint().GetClusterName() if err != nil { - return nil, nil, trace.Wrap(err) + childErr := child.Close() + return nil, nil, trace.NewAggregate(err, childErr) } + monitorConfig := MonitorConfig{ LockWatcher: child.srv.GetLockWatcher(), - LockTargets: lockTargets, + LockTargets: ComputeLockTargets(clusterName.GetClusterName(), srv.HostUUID(), identityContext), LockingMode: identityContext.AccessChecker.LockingMode(authPref.GetLockingMode()), DisconnectExpiredCert: child.disconnectExpiredCert, ClientIdleTimeout: child.clientIdleTimeout, @@ -1147,28 +1149,19 @@ func newUaccMetadata(c *ServerContext) (*UaccMetadata, error) { }, nil } -// ComputeLockTargets computes lock targets inferred from a Server -// and an IdentityContext. -func ComputeLockTargets(s Server, id IdentityContext) ([]types.LockTarget, error) { - clusterName, err := s.GetAccessPoint().GetClusterName() - if err != nil { - return nil, trace.Wrap(err) - } +// ComputeLockTargets computes lock targets inferred from the clusterName, serverID and IdentityContext. +func ComputeLockTargets(clusterName, serverID string, id IdentityContext) []types.LockTarget { lockTargets := []types.LockTarget{ {User: id.TeleportUser}, {Login: id.Login}, - {Node: s.HostUUID()}, - {Node: auth.HostFQDN(s.HostUUID(), clusterName.GetClusterName())}, + {Node: serverID}, + {Node: auth.HostFQDN(serverID, clusterName)}, {MFADevice: id.Certificate.Extensions[teleport.CertExtensionMFAVerified]}, } roles := apiutils.Deduplicate(append(id.AccessChecker.RoleNames(), id.UnmappedRoles...)) - lockTargets = append(lockTargets, - services.RolesToLockTargets(roles)..., - ) - lockTargets = append(lockTargets, - services.AccessRequestsToLockTargets(id.ActiveRequests)..., - ) - return lockTargets, nil + lockTargets = append(lockTargets, services.RolesToLockTargets(roles)...) + lockTargets = append(lockTargets, services.AccessRequestsToLockTargets(id.ActiveRequests)...) + return lockTargets } // SetRequest sets the ssh request that was issued by the client. diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index b07c5c79003c6..01ed7820198f9 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -32,7 +32,6 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" semconv "go.opentelemetry.io/otel/semconv/v1.10.0" oteltrace "go.opentelemetry.io/otel/trace" @@ -45,7 +44,6 @@ import ( tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/bpf" "github.com/gravitational/teleport/lib/defaults" @@ -53,7 +51,6 @@ import ( "github.com/gravitational/teleport/lib/inventory" "github.com/gravitational/teleport/lib/labels" "github.com/gravitational/teleport/lib/limiter" - "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/pam" "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" @@ -73,13 +70,6 @@ var ( log = logrus.WithFields(logrus.Fields{ trace.Component: teleport.ComponentNode, }) - - userSessionLimitHitCount = prometheus.NewCounter( - prometheus.CounterOpts{ - Name: teleport.MetricUserMaxConcurrentSessionsHit, - Help: "Number of times a user exceeded their max concurrent ssh connections", - }, - ) ) // Server implements SSH server that uses configuration backend and @@ -228,6 +218,10 @@ type Server struct { // router used by subsystem requests to connect to nodes // and clusters router *proxy.Router + + // sessionController is used to restrict new sessions + // based on locks and cluster preferences + sessionController *srv.SessionController } // TargetMetadata returns metadata about the server. @@ -658,8 +652,18 @@ func SetTracerProvider(provider oteltrace.TracerProvider) ServerOption { } } +// SetSessionController sets the session controller. +func SetSessionController(controller *srv.SessionController) ServerOption { + return func(s *Server) error { + s.sessionController = controller + return nil + } +} + // New returns an unstarted server -func New(addr utils.NetAddr, +func New( + ctx context.Context, + addr utils.NetAddr, hostname string, signers []ssh.Signer, authService srv.AccessPoint, @@ -669,18 +673,13 @@ func New(addr utils.NetAddr, auth auth.ClientI, options ...ServerOption, ) (*Server, error) { - err := metrics.RegisterPrometheusCollectors(userSessionLimitHitCount) - if err != nil { - return nil, trace.Wrap(err) - } - // read the host UUID: uuid, err := utils.ReadOrMakeHostUUID(dataDir) if err != nil { return nil, trace.Wrap(err) } - ctx, cancel := context.WithCancel(context.TODO()) + ctx, cancel := context.WithCancel(ctx) s := &Server{ addr: addr, authService: authService, @@ -723,6 +722,10 @@ func New(addr utils.NetAddr, return nil, trace.BadParameter("setup valid LockWatcher parameter using SetLockWatcher") } + if s.sessionController == nil { + return nil, trace.BadParameter("setup valid SessionControl parameter using SetSessionControl") + } + if s.connectedProxyGetter == nil { s.connectedProxyGetter = reversetunnel.NewConnectedProxyGetter() } @@ -1070,104 +1073,9 @@ func (s *Server) HandleNewConn(ctx context.Context, ccx *sshutils.ConnectionCont if err != nil { return ctx, trace.Wrap(err) } - authPref, err := s.GetAccessPoint().GetAuthPreference(ctx) - if err != nil { - return ctx, trace.Wrap(err) - } - lockingMode := identityContext.AccessChecker.LockingMode(authPref.GetLockingMode()) - - event := &apievents.SessionReject{ - Metadata: apievents.Metadata{ - Type: events.SessionRejectedEvent, - Code: events.SessionRejectedCode, - }, - UserMetadata: identityContext.GetUserMetadata(), - ConnectionMetadata: apievents.ConnectionMetadata{ - Protocol: events.EventProtocolSSH, - LocalAddr: ccx.ServerConn.LocalAddr().String(), - RemoteAddr: ccx.ServerConn.RemoteAddr().String(), - }, - ServerMetadata: apievents.ServerMetadata{ - ServerID: s.uuid, - ServerNamespace: s.GetNamespace(), - }, - } - lockTargets, err := srv.ComputeLockTargets(s, identityContext) - if err != nil { - return ctx, trace.Wrap(err) - } - if lockErr := s.lockWatcher.CheckLockInForce(lockingMode, lockTargets...); lockErr != nil { - event.Reason = lockErr.Error() - if err := s.EmitAuditEvent(s.ctx, event); err != nil { - s.Logger.WithError(err).Warn("Failed to emit session reject event.") - } - return ctx, trace.Wrap(lockErr) - } - - // Check that the required private key policy, defined by roles and auth pref, - // is met by this Identity's ssh certificate. - identityPolicy := identityContext.Certificate.Extensions[teleport.CertExtensionPrivateKeyPolicy] - requiredPolicy := identityContext.AccessChecker.PrivateKeyPolicy(authPref.GetPrivateKeyPolicy()) - if err := requiredPolicy.VerifyPolicy(keys.PrivateKeyPolicy(identityPolicy)); err != nil { - return nil, trace.Wrap(err) - } - - // Don't apply the following checks in non-node contexts. - if s.Component() != teleport.ComponentNode { - return ctx, nil - } - - maxConnections := identityContext.AccessChecker.MaxConnections() - if maxConnections == 0 { - // concurrent session control is not active, nothing - // else needs to be done here. - return ctx, nil - } - - netConfig, err := s.GetAccessPoint().GetClusterNetworkingConfig(ctx) - if err != nil { - return ctx, trace.Wrap(err) - } - - semLock, err := services.AcquireSemaphoreLock(ctx, services.SemaphoreLockConfig{ - Service: s.authService, - Expiry: netConfig.GetSessionControlTimeout(), - Params: types.AcquireSemaphoreRequest{ - SemaphoreKind: types.SemaphoreKindConnection, - SemaphoreName: identityContext.TeleportUser, - MaxLeases: maxConnections, - Holder: s.uuid, - }, - }) - if err != nil { - if strings.Contains(err.Error(), teleport.MaxLeases) { - // user has exceeded their max concurrent ssh connections. - userSessionLimitHitCount.Inc() - event.Reason = events.SessionRejectedEvent - event.Maximum = maxConnections - if err := s.EmitAuditEvent(s.ctx, event); err != nil { - s.Logger.WithError(err).Warn("Failed to emit session reject event.") - } - err = trace.AccessDenied("too many concurrent ssh connections for user %q (max=%d)", - identityContext.TeleportUser, - maxConnections, - ) - } - return ctx, trace.Wrap(err) - } - - // ensure that losing the lock closes the connection context. Under normal - // conditions, cancellation propagates from the connection context to the - // lock, but if we lose the lock due to some error (e.g. poor connectivity - // to auth server) then cancellation propagates in the other direction. - go func() { - // TODO(fspmarshall): If lock was lost due to error, find a way to propagate - // an error message to user. - <-semLock.Done() - ccx.Close() - }() - return ctx, nil + ctx, err = s.sessionController.AcquireSessionContext(ctx, identityContext, ccx.ServerConn.LocalAddr().String(), ccx.ServerConn.RemoteAddr().String(), ccx) + return ctx, trace.Wrap(err) } // HandleNewChan is called when new channel is opened diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index a67d1e7ee8c08..e15ff014b8c2d 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -213,6 +213,18 @@ func newCustomFixture(t *testing.T, mutateCfg func(*auth.TestServerConfig), sshO require.NoError(t, err) t.Cleanup(func() { require.NoError(t, nodeClient.Close()) }) + lockWatcher := newLockWatcher(ctx, t, nodeClient) + + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: nodeClient, + AccessPoint: nodeClient, + LockEnforcer: lockWatcher, + Emitter: nodeClient, + Component: teleport.ComponentNode, + ServerID: nodeID, + }) + require.NoError(t, err) + nodeDir := t.TempDir() serverOptions := []ServerOption{ SetUUID(nodeID), @@ -232,13 +244,15 @@ func newCustomFixture(t *testing.T, mutateCfg func(*auth.TestServerConfig), sshO SetBPF(&bpf.NOP{}), SetRestrictedSessionManager(&restricted.NOP{}), SetClock(clock), - SetLockWatcher(newLockWatcher(ctx, t, nodeClient)), + SetLockWatcher(lockWatcher), SetX11ForwardingConfig(&x11.ServerConfig{}), + SetSessionController(sessionController), } serverOptions = append(serverOptions, sshOpts...) sshSrv, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, testServer.ClusterName(), []ssh.Signer{signer}, @@ -1408,7 +1422,18 @@ func TestProxyRoundRobin(t *testing.T) { }) require.NoError(t, err) + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: proxyClient, + AccessPoint: proxyClient, + LockEnforcer: lockWatcher, + Emitter: proxyClient, + Component: teleport.ComponentNode, + ServerID: hostID, + }) + require.NoError(t, err) + proxy, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), []ssh.Signer{f.signer}, @@ -1426,6 +1451,7 @@ func TestProxyRoundRobin(t *testing.T) { SetClock(f.clock), SetLockWatcher(lockWatcher), SetNodeWatcher(nodeWatcher), + SetSessionController(sessionController), ) require.NoError(t, err) require.NoError(t, proxy.Start()) @@ -1537,7 +1563,18 @@ func TestProxyDirectAccess(t *testing.T) { }) require.NoError(t, err) + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: nodeClient, + AccessPoint: nodeClient, + LockEnforcer: lockWatcher, + Emitter: nodeClient, + Component: teleport.ComponentNode, + ServerID: hostID, + }) + require.NoError(t, err) + proxy, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), []ssh.Signer{f.signer}, @@ -1555,6 +1592,7 @@ func TestProxyDirectAccess(t *testing.T) { SetClock(f.clock), SetLockWatcher(lockWatcher), SetNodeWatcher(nodeWatcher), + SetSessionController(sessionController), ) require.NoError(t, err) require.NoError(t, proxy.Start()) @@ -1705,8 +1743,22 @@ func TestLimiter(t *testing.T) { require.NoError(t, err) nodeClient, _ := newNodeClient(t, f.testSrv) + + lockWatcher := newLockWatcher(ctx, t, nodeClient) + + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: nodeClient, + AccessPoint: nodeClient, + LockEnforcer: lockWatcher, + Emitter: nodeClient, + Component: teleport.ComponentNode, + ServerID: hostID, + }) + require.NoError(t, err) + nodeStateDir := t.TempDir() srv, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, f.testSrv.ClusterName(), []ssh.Signer{f.signer}, @@ -1723,7 +1775,8 @@ func TestLimiter(t *testing.T) { SetBPF(&bpf.NOP{}), SetRestrictedSessionManager(&restricted.NOP{}), SetClock(f.clock), - SetLockWatcher(newLockWatcher(ctx, t, nodeClient)), + SetLockWatcher(lockWatcher), + SetSessionController(sessionController), ) require.NoError(t, err) require.NoError(t, srv.Start()) @@ -2276,7 +2329,18 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { }) require.NoError(t, err) + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: proxyClient, + AccessPoint: proxyClient, + LockEnforcer: lockWatcher, + Emitter: proxyClient, + Component: teleport.ComponentProxy, + ServerID: hostID, + }) + require.NoError(t, err) + proxy, err := New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, f.testSrv.ClusterName(), []ssh.Signer{f.signer}, @@ -2294,6 +2358,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { SetClock(f.clock), SetLockWatcher(lockWatcher), SetNodeWatcher(nodeWatcher), + SetSessionController(sessionController), ) require.NoError(t, err) require.NoError(t, proxy.Start()) diff --git a/lib/srv/sess_test.go b/lib/srv/sess_test.go index a0528b1b1cee8..b2e93a09b6434 100644 --- a/lib/srv/sess_test.go +++ b/lib/srv/sess_test.go @@ -80,7 +80,7 @@ func TestParseAccessRequestIDs(t *testing.T) { } for _, tt := range testCases { t.Run(tt.comment, func(t *testing.T) { - out, err := parseAccessRequestIDs(tt.input) + out, err := ParseAccessRequestIDs(tt.input) tt.assertErr(t, err) require.Equal(t, out, tt.result) }) diff --git a/lib/srv/session_control.go b/lib/srv/session_control.go new file mode 100644 index 0000000000000..92920b51512fc --- /dev/null +++ b/lib/srv/session_control.go @@ -0,0 +1,268 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package srv + +import ( + "context" + "io" + "strings" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + oteltrace "go.opentelemetry.io/otel/trace" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/observability/tracing" + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/observability/metrics" + "github.com/gravitational/teleport/lib/services" +) + +var ( + userSessionLimitHitCount = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: teleport.MetricUserMaxConcurrentSessionsHit, + Help: "Number of times a user exceeded their max concurrent ssh connections", + }, + ) +) + +func init() { + _ = metrics.RegisterPrometheusCollectors(userSessionLimitHitCount) +} + +// LockEnforcer determines whether a lock is being enforced on the provided targets +type LockEnforcer interface { + CheckLockInForce(mode constants.LockingMode, targets ...types.LockTarget) error +} + +// SessionControllerConfig contains dependencies needed to +// create a SessionController +type SessionControllerConfig struct { + // Semaphores is used to obtain a semaphore lock when max sessions are defined + Semaphores types.Semaphores + // AccessPoint is the cache used to get cluster information + AccessPoint AccessPoint + // LockEnforcer is used to determine if locks should prevent a session + LockEnforcer LockEnforcer + // Emitter is used to emit session rejection events + Emitter apievents.Emitter + // Component is the component running the session controller. Nodes and Proxies + // have different flows + Component string + // Logger is used to emit log entries + Logger *logrus.Entry + // TracerProvider creates a tracer so that spans may be emitted + TracerProvider oteltrace.TracerProvider + // ServerID is the UUID of the server + ServerID string + // Clock used in tests to change time + Clock clockwork.Clock + + tracer oteltrace.Tracer +} + +// CheckAndSetDefaults ensures all the required dependencies were +// provided and sets any optional values to their defaults +func (c *SessionControllerConfig) CheckAndSetDefaults() error { + if c.Semaphores == nil { + return trace.BadParameter("Semaphores must be provided") + } + + if c.AccessPoint == nil { + return trace.BadParameter("AccessPoint must be provided") + } + + if c.LockEnforcer == nil { + return trace.BadParameter("LockWatcher must be provided") + } + + if c.Emitter == nil { + return trace.BadParameter("Emitter must be provided") + } + + if c.Component == "" { + return trace.BadParameter("Component must be provided") + } + + if c.TracerProvider == nil { + c.TracerProvider = tracing.DefaultProvider() + } + + if c.Logger == nil { + c.Logger = logrus.WithField(trace.Component, "SessionCtrl") + } + + if c.Clock == nil { + c.Clock = clockwork.NewRealClock() + } + + c.tracer = c.TracerProvider.Tracer("SessionController") + + return nil +} + +// SessionController enforces session control restrictions required by +// locks, private key policy, and max connection limits +type SessionController struct { + cfg SessionControllerConfig +} + +// NewSessionController creates a SessionController from the provided config. If any +// of the required parameters in the SessionControllerConfig are not provided an +// error is returned. +func NewSessionController(cfg SessionControllerConfig) (*SessionController, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + return &SessionController{cfg: cfg}, nil +} + +// AcquireSessionContext attempts to create a context for the session. If the session is +// not allowed due to session control an error is returned. The returned +// context is scoped to the session and will be canceled in the event the semaphore lock +// is no longer held. The closers provided are immediately closed when the semaphore lock +// is released as well. +func (s *SessionController) AcquireSessionContext(ctx context.Context, identity IdentityContext, localAddr, remoteAddr string, closers ...io.Closer) (context.Context, error) { + // create a separate context for tracing the operations + // within that doesn't leak into the returned context + spanCtx, span := s.cfg.tracer.Start(ctx, "SessionController/AcquireSessionContext") + defer span.End() + + authPref, err := s.cfg.AccessPoint.GetAuthPreference(spanCtx) + if err != nil { + return ctx, trace.Wrap(err) + } + + clusterName, err := s.cfg.AccessPoint.GetClusterName() + if err != nil { + return ctx, trace.Wrap(err) + } + + lockingMode := identity.AccessChecker.LockingMode(authPref.GetLockingMode()) + lockTargets := ComputeLockTargets(clusterName.GetClusterName(), s.cfg.ServerID, identity) + + if lockErr := s.cfg.LockEnforcer.CheckLockInForce(lockingMode, lockTargets...); lockErr != nil { + s.emitRejection(spanCtx, identity.GetUserMetadata(), localAddr, remoteAddr, lockErr.Error(), 0) + return ctx, trace.Wrap(lockErr) + } + + // Check that the required private key policy, defined by roles and auth pref, + // is met by this Identity's ssh certificate. + identityPolicy := identity.Certificate.Extensions[teleport.CertExtensionPrivateKeyPolicy] + requiredPolicy := identity.AccessChecker.PrivateKeyPolicy(authPref.GetPrivateKeyPolicy()) + if err := requiredPolicy.VerifyPolicy(keys.PrivateKeyPolicy(identityPolicy)); err != nil { + return ctx, trace.Wrap(err) + } + + // Don't apply the following checks in non-node contexts. + if s.cfg.Component != teleport.ComponentNode { + return ctx, nil + } + + maxConnections := identity.AccessChecker.MaxConnections() + if maxConnections == 0 { + // concurrent session control is not active, nothing + // else needs to be done here. + return ctx, nil + } + + netConfig, err := s.cfg.AccessPoint.GetClusterNetworkingConfig(spanCtx) + if err != nil { + return ctx, trace.Wrap(err) + } + + semLock, err := services.AcquireSemaphoreLock(spanCtx, services.SemaphoreLockConfig{ + Service: s.cfg.Semaphores, + Clock: s.cfg.Clock, + Expiry: netConfig.GetSessionControlTimeout(), + Params: types.AcquireSemaphoreRequest{ + SemaphoreKind: types.SemaphoreKindConnection, + SemaphoreName: identity.TeleportUser, + MaxLeases: maxConnections, + Holder: s.cfg.ServerID, + }, + }) + if err != nil { + if strings.Contains(err.Error(), teleport.MaxLeases) { + // user has exceeded their max concurrent ssh connections. + userSessionLimitHitCount.Inc() + s.emitRejection(spanCtx, identity.GetUserMetadata(), localAddr, remoteAddr, events.SessionRejectedEvent, maxConnections) + + return ctx, trace.AccessDenied("too many concurrent ssh connections for user %q (max=%d)", identity.TeleportUser, maxConnections) + } + + return ctx, trace.Wrap(err) + } + + ctx, cancel := context.WithCancel(ctx) + // ensure that losing the lock closes the connection context. Under normal + // conditions, cancellation propagates from the connection context to the + // lock, but if we lose the lock due to some error (e.g. poor connectivity + // to auth server) then cancellation propagates in the other direction. + go func() { + // TODO(fspmarshall): If lock was lost due to error, find a way to propagate + // an error message to user. + <-semLock.Done() + cancel() + + // close any provided closers + for _, closer := range closers { + _ = closer.Close() + } + }() + + return ctx, nil +} + +// emitRejection emits a SessionRejectedEvent with the provided information +func (s *SessionController) emitRejection(ctx context.Context, userMetadata apievents.UserMetadata, localAddr, remoteAddr string, reason string, max int64) { + // link a background context to the current span so things + // are related but while still allowing the audit event to + // not be tied to the request scoped context + emitCtx := oteltrace.ContextWithSpanContext(context.Background(), oteltrace.SpanContextFromContext(ctx)) + + ctx, span := s.cfg.tracer.Start(emitCtx, "SessionController/emitRejection") + defer span.End() + + if err := s.cfg.Emitter.EmitAuditEvent(ctx, &apievents.SessionReject{ + Metadata: apievents.Metadata{ + Type: events.SessionRejectedEvent, + Code: events.SessionRejectedCode, + }, + UserMetadata: userMetadata, + ConnectionMetadata: apievents.ConnectionMetadata{ + Protocol: events.EventProtocolSSH, + LocalAddr: localAddr, + RemoteAddr: remoteAddr, + }, + ServerMetadata: apievents.ServerMetadata{ + ServerID: s.cfg.ServerID, + ServerNamespace: apidefaults.Namespace, + }, + Reason: reason, + Maximum: max, + }); err != nil { + s.cfg.Logger.WithError(err).Warn("Failed to emit session reject event.") + } +} diff --git a/lib/srv/session_control_test.go b/lib/srv/session_control_test.go new file mode 100644 index 0000000000000..9dfd6b144610d --- /dev/null +++ b/lib/srv/session_control_test.go @@ -0,0 +1,408 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package srv + +import ( + "context" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/events/eventstest" + "github.com/gravitational/teleport/lib/services" +) + +type mockLockEnforcer struct { + lockInForceErr error +} + +func (m mockLockEnforcer) CheckLockInForce(constants.LockingMode, ...types.LockTarget) error { + return m.lockInForceErr +} + +type mockAccessPoint struct { + AccessPoint + + authPreference types.AuthPreference + clusterName types.ClusterName + netConfig types.ClusterNetworkingConfig +} + +func (m mockAccessPoint) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) { + return m.authPreference, nil +} + +func (m mockAccessPoint) GetClusterName(opts ...services.MarshalOption) (types.ClusterName, error) { + return m.clusterName, nil +} + +func (m mockAccessPoint) GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error) { + return m.netConfig, nil +} + +type mockSemaphores struct { + types.Semaphores + + lease *types.SemaphoreLease + acquireErr error +} + +func (m mockSemaphores) AcquireSemaphore(ctx context.Context, params types.AcquireSemaphoreRequest) (*types.SemaphoreLease, error) { + return m.lease, m.acquireErr +} + +func (m mockSemaphores) CancelSemaphoreLease(ctx context.Context, lease types.SemaphoreLease) error { + return nil +} + +type mockAccessChecker struct { + services.AccessChecker + + lockMode constants.LockingMode + maxConnections int64 + keyPolicy keys.PrivateKeyPolicy + roleNames []string +} + +func (m mockAccessChecker) LockingMode(defaultMode constants.LockingMode) constants.LockingMode { + return m.lockMode +} + +func (m mockAccessChecker) MaxConnections() int64 { + return m.maxConnections +} + +func (m mockAccessChecker) PrivateKeyPolicy(defaultPolicy keys.PrivateKeyPolicy) keys.PrivateKeyPolicy { + return m.keyPolicy +} +func (m mockAccessChecker) RoleNames() []string { + return m.roleNames +} + +func TestSessionController_AcquireSessionContext(t *testing.T) { + t.Parallel() + + clock := clockwork.NewFakeClock() + emitter := &eventstest.MockEmitter{} + + cases := []struct { + name string + cfg SessionControllerConfig + identity IdentityContext + assertion func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) + }{ + { + name: "proxy: access allowed", + cfg: SessionControllerConfig{ + Semaphores: mockSemaphores{}, + AccessPoint: mockAccessPoint{ + netConfig: &types.ClusterNetworkingConfigV2{}, + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + }, + LockEnforcer: mockLockEnforcer{}, + Emitter: emitter, + Component: teleport.ComponentProxy, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleport.CertExtensionPrivateKeyPolicy: string(keys.PrivateKeyPolicyNone), + }, + }, + }, + AccessChecker: mockAccessChecker{ + keyPolicy: keys.PrivateKeyPolicyNone, + maxConnections: 1, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.NoError(t, err) + require.NotNil(t, ctx) + require.Empty(t, emitter.Events()) + }, + }, + { + name: "node: access allowed", + cfg: SessionControllerConfig{ + Clock: clock, + Semaphores: mockSemaphores{ + lease: &types.SemaphoreLease{ + SemaphoreKind: types.SemaphoreKindConnection, + SemaphoreName: "test", + LeaseID: "1", + Expires: clock.Now().Add(time.Minute), + }, + }, + AccessPoint: mockAccessPoint{ + netConfig: &types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + SessionControlTimeout: types.NewDuration(time.Minute), + }, + }, + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + }, + LockEnforcer: mockLockEnforcer{}, + Emitter: emitter, + Component: teleport.ComponentNode, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleport.CertExtensionPrivateKeyPolicy: string(keys.PrivateKeyPolicyNone), + }, + }, + }, + AccessChecker: mockAccessChecker{ + keyPolicy: keys.PrivateKeyPolicyNone, + maxConnections: 1, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.NoError(t, err) + require.NotNil(t, ctx) + require.Empty(t, emitter.Events()) + }, + }, + { + name: "session rejected due to lock", + cfg: SessionControllerConfig{ + Clock: clock, + Semaphores: mockSemaphores{}, + AccessPoint: mockAccessPoint{ + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + }, + LockEnforcer: mockLockEnforcer{ + lockInForceErr: trace.AccessDenied("lock in force"), + }, + Emitter: emitter, + Component: teleport.ComponentNode, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleport.CertExtensionPrivateKeyPolicy: string(keys.PrivateKeyPolicyNone), + }, + }, + }, + AccessChecker: mockAccessChecker{ + keyPolicy: keys.PrivateKeyPolicyNone, + maxConnections: 1, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.ErrorIs(t, err, trace.AccessDenied("lock in force")) + require.NotNil(t, ctx) + require.Len(t, emitter.Events(), 1) + + evt, ok := emitter.Events()[0].(*apievents.SessionReject) + require.True(t, ok) + require.Equal(t, events.SessionRejectedEvent, evt.Metadata.Type) + require.Equal(t, events.SessionRejectedCode, evt.Metadata.Code) + require.Equal(t, events.EventProtocolSSH, evt.ConnectionMetadata.Protocol) + require.Equal(t, "lock in force", evt.Reason) + }, + }, + { + name: "session rejected due to private key policy", + cfg: SessionControllerConfig{ + Clock: clock, + Semaphores: mockSemaphores{}, + AccessPoint: mockAccessPoint{ + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + }, + LockEnforcer: mockLockEnforcer{}, + Emitter: emitter, + Component: teleport.ComponentNode, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleport.CertExtensionPrivateKeyPolicy: string(keys.PrivateKeyPolicyNone), + }, + }, + }, + AccessChecker: mockAccessChecker{ + keyPolicy: keys.PrivateKeyPolicyHardwareKey, + maxConnections: 1, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.Error(t, err) + require.True(t, trace.IsBadParameter(err)) + require.NotNil(t, ctx) + require.Empty(t, emitter.Events()) + }, + }, + { + name: "session rejected due to connection limit", + cfg: SessionControllerConfig{ + Clock: clock, + Semaphores: mockSemaphores{ + acquireErr: trace.LimitExceeded(teleport.MaxLeases), + }, + AccessPoint: mockAccessPoint{ + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + netConfig: &types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + SessionControlTimeout: types.NewDuration(time.Minute), + }, + }, + }, + LockEnforcer: mockLockEnforcer{}, + Emitter: emitter, + Component: teleport.ComponentNode, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleport.CertExtensionPrivateKeyPolicy: string(keys.PrivateKeyPolicyNone), + }, + }, + }, + AccessChecker: mockAccessChecker{ + keyPolicy: keys.PrivateKeyPolicyNone, + maxConnections: 1, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + require.NotNil(t, ctx) + require.Len(t, emitter.Events(), 1) + + evt, ok := emitter.Events()[0].(*apievents.SessionReject) + require.True(t, ok) + require.Equal(t, events.SessionRejectedEvent, evt.Metadata.Type) + require.Equal(t, events.SessionRejectedCode, evt.Metadata.Code) + require.Equal(t, events.EventProtocolSSH, evt.ConnectionMetadata.Protocol) + require.Equal(t, events.SessionRejectedEvent, evt.Reason) + require.Equal(t, int64(1), evt.Maximum) + }, + }, + { + name: "no connection limits prevent acquiring semaphore lock", + cfg: SessionControllerConfig{ + Clock: clock, + Semaphores: mockSemaphores{ + acquireErr: trace.LimitExceeded(teleport.MaxLeases), + }, + AccessPoint: mockAccessPoint{ + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + LockingMode: constants.LockingModeStrict, + }, + }, + clusterName: &types.ClusterNameV2{Spec: types.ClusterNameSpecV2{ClusterName: "llama"}}, + netConfig: &types.ClusterNetworkingConfigV2{ + Spec: types.ClusterNetworkingConfigSpecV2{ + SessionControlTimeout: types.NewDuration(time.Minute), + }, + }, + }, + LockEnforcer: mockLockEnforcer{}, + Emitter: emitter, + Component: teleport.ComponentNode, + ServerID: "1234", + }, + identity: IdentityContext{ + TeleportUser: "alpaca", + Login: "alpaca", + Certificate: &ssh.Certificate{ + Permissions: ssh.Permissions{ + Extensions: map[string]string{ + teleport.CertExtensionPrivateKeyPolicy: string(keys.PrivateKeyPolicyNone), + }, + }, + }, + AccessChecker: mockAccessChecker{ + keyPolicy: keys.PrivateKeyPolicyNone, + maxConnections: 0, + }, + }, + assertion: func(t *testing.T, ctx context.Context, err error, emitter *eventstest.MockEmitter) { + require.NoError(t, err) + require.NotNil(t, ctx) + require.Empty(t, emitter.Events(), 0) + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + emitter.Reset() + ctrl, err := NewSessionController(tt.cfg) + require.NoError(t, err) + + ctx, err := ctrl.AcquireSessionContext(context.Background(), tt.identity, "127.0.0.1:1", "127.0.0.1:2") + tt.assertion(t, ctx, err, emitter) + + }) + } +} diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index eaee353674ec2..c0ef8493be5b7 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -71,10 +71,12 @@ import ( "github.com/gravitational/teleport/lib/jwt" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/plugin" + "github.com/gravitational/teleport/lib/proxy" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/secret" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/app" "github.com/gravitational/teleport/lib/web/ui" @@ -213,6 +215,16 @@ type Config struct { // TraceClient is used to forward spans to the upstream collector for the UI TraceClient otlptrace.Client + + // Router is used to route ssh sessions to hosts + Router *proxy.Router + + // SessionControl is used to determine if users are + // allowed to spawn new sessions + SessionControl *srv.SessionController + + // TracerProvider generates tracers to create spans with + TracerProvider oteltrace.TracerProvider } type APIHandler struct { @@ -2325,6 +2337,40 @@ func (h *Handler) clusterLoginAlertsGet(w http.ResponseWriter, r *http.Request, }, nil } +// createIdentityContext creates a srv.IdentityContext from the ssh cert of the user +// stored within the SessionContext. +func createIdentityContext(login string, sessionCtx *SessionContext) (srv.IdentityContext, error) { + accessChecker, err := sessionCtx.GetUserAccessChecker() + if err != nil { + return srv.IdentityContext{}, trace.Wrap(err) + } + + sshCert, err := sessionCtx.GetSSHCertificate() + if err != nil { + return srv.IdentityContext{}, trace.Wrap(err) + } + + unmappedRoles, err := services.ExtractRolesFromCert(sshCert) + if err != nil { + return srv.IdentityContext{}, trace.Wrap(err) + } + + accessRequestIDs, err := srv.ParseAccessRequestIDs(sshCert.Extensions[teleport.CertExtensionTeleportActiveRequests]) + if err != nil { + return srv.IdentityContext{}, trace.Wrap(err) + } + + return srv.IdentityContext{ + AccessChecker: accessChecker, + TeleportUser: sessionCtx.user, + Login: login, + Certificate: sshCert, + UnmappedRoles: unmappedRoles, + ActiveRequests: accessRequestIDs, + Impersonator: sshCert.Extensions[teleport.CertExtensionImpersonator], + }, nil +} + // siteNodeConnect connect to the site node // // GET /v1/webapi/sites/:site/namespaces/:namespace/connect?access_token=bearer_token¶ms= @@ -2339,7 +2385,7 @@ func (h *Handler) siteNodeConnect( w http.ResponseWriter, r *http.Request, p httprouter.Params, - ctx *SessionContext, + sessionCtx *SessionContext, site reversetunnel.RemoteSite, ) (interface{}, error) { q := r.URL.Query() @@ -2353,7 +2399,17 @@ func (h *Handler) siteNodeConnect( } h.log.Debugf("New terminal request for ns=%s, server=%s, login=%s, sid=%s, websid=%s.", - req.Namespace, req.Server, req.Login, req.SessionID, ctx.GetSessionID()) + req.Namespace, req.Server, req.Login, req.SessionID, sessionCtx.GetSessionID()) + + identity, err := createIdentityContext(req.Login, sessionCtx) + if err != nil { + return nil, trace.Wrap(err) + } + + ctx, err := h.cfg.SessionControl.AcquireSessionContext(r.Context(), identity, h.cfg.ProxyWebAddr.Addr, r.RemoteAddr) + if err != nil { + return nil, trace.Wrap(err) + } authAccessPoint, err := site.CachingAccessPoint() if err != nil { @@ -2361,7 +2417,7 @@ func (h *Handler) siteNodeConnect( return nil, trace.Wrap(err) } - netConfig, err := authAccessPoint.GetClusterNetworkingConfig(r.Context()) + netConfig, err := authAccessPoint.GetClusterNetworkingConfig(ctx) if err != nil { h.log.WithError(err).Debug("Unable to fetch cluster networking config.") return nil, trace.Wrap(err) @@ -2372,12 +2428,18 @@ func (h *Handler) siteNodeConnect( req.ProxyHostPort = h.ProxyHostPort() req.Cluster = site.GetName() - clt, err := ctx.GetUserClient(site) + clt, err := sessionCtx.GetUserClient(site) if err != nil { return nil, trace.Wrap(err) } - term, err := NewTerminal(r.Context(), *req, clt, ctx) + term, err := NewTerminal(ctx, TerminalHandlerConfig{ + Req: *req, + AuthProvider: clt, + SessionCtx: sessionCtx, + Router: h.cfg.Router, + TracerProvider: h.cfg.TracerProvider, + }) if err != nil { h.log.WithError(err).Error("Unable to create terminal.") return nil, trace.Wrap(err) diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index a81500f7af015..20064fd16d949 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -108,7 +108,7 @@ import ( "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/pam" - libproxy "github.com/gravitational/teleport/lib/proxy" + "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/secret" @@ -135,8 +135,9 @@ type WebSuite struct { proxyTunnel reversetunnel.Server srvID string - user string - webServer *httptest.Server + user string + webServer *httptest.Server + webHandler *APIHandler mockU2F *mocku2f.Key server *auth.TestServer @@ -256,9 +257,20 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { }) require.NoError(t, err) + nodeSessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: nodeClient, + AccessPoint: nodeClient, + LockEnforcer: nodeLockWatcher, + Emitter: nodeClient, + Component: teleport.ComponentNode, + ServerID: nodeID, + }) + require.NoError(t, err) + // create SSH service: nodeDataDir := t.TempDir() node, err := regular.New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, s.server.ClusterName(), []ssh.Signer{signer}, @@ -276,6 +288,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(s.clock), regular.SetLockWatcher(nodeLockWatcher), + regular.SetSessionController(nodeSessionController), ) require.NoError(t, err) s.node = node @@ -342,7 +355,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { require.NoError(t, err) s.proxyTunnel = revTunServer - router, err := libproxy.NewRouter(libproxy.RouterConfig{ + router, err := proxy.NewRouter(proxy.RouterConfig{ ClusterName: s.server.ClusterName(), Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), RemoteClusterGetter: s.proxyClient, @@ -351,8 +364,19 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { }) require.NoError(t, err) + proxySessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: s.proxyClient, + AccessPoint: s.proxyClient, + LockEnforcer: proxyLockWatcher, + Emitter: s.proxyClient, + Component: teleport.ComponentProxy, + ServerID: proxyID, + }) + require.NoError(t, err) + // proxy server: s.proxy, err = regular.New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, s.server.ClusterName(), []ssh.Signer{signer}, @@ -370,6 +394,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { regular.SetClock(s.clock), regular.SetLockWatcher(proxyLockWatcher), regular.SetNodeWatcher(proxyNodeWatcher), + regular.SetSessionController(proxySessionController), ) require.NoError(t, err) @@ -392,10 +417,13 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { StaticFS: fs, cachedSessionLingeringThreshold: &sessionLingeringThreshold, ProxySettings: &mockProxySettings{}, + SessionControl: proxySessionController, + Router: router, }, SetSessionStreamPollPeriod(200*time.Millisecond), SetClock(s.clock)) require.NoError(t, err) s.webServer = httptest.NewUnstartedServer(handler) + s.webHandler = handler s.webServer.StartTLS() err = s.proxy.Start() require.NoError(t, err) @@ -1012,7 +1040,7 @@ func TestClusterAlertsGet(t *testing.T) { func TestSiteNodeConnectInvalidSessionID(t *testing.T) { t.Parallel() s := newWebSuite(t) - _, err := s.makeTerminal(t, s.authPack(t, "foo"), withSessionID(session.ID("/../../../foo"))) + _, err := s.makeTerminal(t, s.authPack(t, "foo"), withSessionID("/../../../foo")) require.Error(t, err) } @@ -1155,7 +1183,7 @@ func TestNewTerminalHandler(t *testing.T) { }, }, { - expectedErr: "bad term dimensions", + expectedErr: "invalid dimensions", authProvider: makeProvider(validNode), req: TerminalRequest{ SessionID: validSID, @@ -1181,7 +1209,12 @@ func TestNewTerminalHandler(t *testing.T) { ctx := context.Background() for _, testCase := range validCases { - term, err := NewTerminal(ctx, testCase.req, testCase.authProvider, nil) + term, err := NewTerminal(ctx, TerminalHandlerConfig{ + Req: testCase.req, + AuthProvider: testCase.authProvider, + SessionCtx: &SessionContext{}, + Router: &proxy.Router{}, + }) require.NoError(t, err) require.Empty(t, cmp.Diff(testCase.req, term.params)) require.Equal(t, testCase.expectedHost, testCase.expectedHost) @@ -1189,7 +1222,12 @@ func TestNewTerminalHandler(t *testing.T) { } for _, testCase := range invalidCases { - _, err := NewTerminal(ctx, testCase.req, testCase.authProvider, nil) + _, err := NewTerminal(ctx, TerminalHandlerConfig{ + Req: testCase.req, + AuthProvider: testCase.authProvider, + SessionCtx: &SessionContext{}, + Router: &proxy.Router{}, + }) require.Regexp(t, ".*"+testCase.expectedErr+".*", err.Error()) } } @@ -5697,6 +5735,14 @@ func (mock authProviderMock) GetSessionTracker(ctx context.Context, sessionID st return nil, trace.NotFound("foo") } +func (mock authProviderMock) IsMFARequired(ctx context.Context, req *authproto.IsMFARequiredRequest) (*authproto.IsMFARequiredResponse, error) { + return nil, nil +} + +func (mock authProviderMock) GenerateUserSingleUseCerts(ctx context.Context) (authproto.AuthService_GenerateUserSingleUseCertsClient, error) { + return nil, nil +} + type terminalOpt func(t *TerminalRequest) func withSessionID(sid session.ID) terminalOpt { @@ -5935,9 +5981,20 @@ func newWebPack(t *testing.T, numProxies int) *webPack { require.NoError(t, err) t.Cleanup(nodeLockWatcher.Close) + nodeSessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: nodeClient, + AccessPoint: nodeClient, + LockEnforcer: nodeLockWatcher, + Emitter: nodeClient, + Component: teleport.ComponentNode, + ServerID: nodeID, + }) + require.NoError(t, err) + // create SSH service: nodeDataDir := t.TempDir() node, err := regular.New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, server.TLS.ClusterName(), hostSigners, @@ -5955,6 +6012,7 @@ func newWebPack(t *testing.T, numProxies int) *webPack { regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(clock), regular.SetLockWatcher(nodeLockWatcher), + regular.SetSessionController(nodeSessionController), ) require.NoError(t, err) @@ -5962,7 +6020,7 @@ func newWebPack(t *testing.T, numProxies int) *webPack { t.Cleanup(func() { require.NoError(t, node.Close()) }) require.NoError(t, auth.CreateUploaderDir(nodeDataDir)) - var proxies []*proxy + var proxies []*testProxy for p := 0; p < numProxies; p++ { proxyID := fmt.Sprintf("proxy%v", p) proxies = append(proxies, createProxy(ctx, t, proxyID, node, server.TLS, hostSigners, clock)) @@ -5990,7 +6048,7 @@ func newWebPack(t *testing.T, numProxies int) *webPack { func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regular.Server, authServer *auth.TestTLSServer, hostSigners []ssh.Signer, clock clockwork.FakeClock, -) *proxy { +) *testProxy { // create reverse tunnel service: client, err := authServer.NewClient(auth.TestIdentity{ I: auth.BuiltinRole{ @@ -6053,7 +6111,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula require.NoError(t, err) t.Cleanup(func() { require.NoError(t, revTunServer.Close()) }) - router, err := libproxy.NewRouter(libproxy.RouterConfig{ + router, err := proxy.NewRouter(proxy.RouterConfig{ ClusterName: authServer.ClusterName(), Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), RemoteClusterGetter: client, @@ -6062,7 +6120,18 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula }) require.NoError(t, err) + sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ + Semaphores: client, + AccessPoint: client, + LockEnforcer: proxyLockWatcher, + Emitter: client, + Component: teleport.ComponentProxy, + ServerID: proxyID, + }) + require.NoError(t, err) + proxyServer, err := regular.New( + ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, authServer.ClusterName(), hostSigners, @@ -6080,6 +6149,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula regular.SetClock(clock), regular.SetLockWatcher(proxyLockWatcher), regular.SetNodeWatcher(proxyNodeWatcher), + regular.SetSessionController(sessionController), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, proxyServer.Close()) }) @@ -6099,6 +6169,8 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula Emitter: client, StaticFS: fs, ProxySettings: &mockProxySettings{}, + SessionControl: sessionController, + Router: router, }, SetSessionStreamPollPeriod(200*time.Millisecond), SetClock(clock)) require.NoError(t, err) @@ -6128,7 +6200,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula url, err := url.Parse("https://" + webServer.Listener.Addr().String()) require.NoError(t, err) - return &proxy{ + return &testProxy{ clock: clock, auth: authServer, client: client, @@ -6146,13 +6218,13 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula // transition the test suite to use the testing package // directly. type webPack struct { - proxies []*proxy + proxies []*testProxy server *auth.TestServer node *regular.Server clock clockwork.FakeClock } -type proxy struct { +type testProxy struct { clock clockwork.FakeClock client *auth.Client auth *auth.TestTLSServer @@ -6166,7 +6238,7 @@ type proxy struct { // authPack returns new authenticated package consisting of created valid // user, otp token, created web session and authenticated client. -func (r *proxy) authPack(t *testing.T, teleportUser string, roles []types.Role) *authPack { +func (r *testProxy) authPack(t *testing.T, teleportUser string, roles []types.Role) *authPack { ctx := context.Background() const ( pass = "abc123" @@ -6227,7 +6299,7 @@ func (r *proxy) authPack(t *testing.T, teleportUser string, roles []types.Role) } } -func (r *proxy) authPackFromPack(t *testing.T, pack *authPack) *authPack { +func (r *testProxy) authPackFromPack(t *testing.T, pack *authPack) *authPack { jar, err := cookiejar.New(nil) require.NoError(t, err) @@ -6239,7 +6311,7 @@ func (r *proxy) authPackFromPack(t *testing.T, pack *authPack) *authPack { return &result } -func (r *proxy) authPackFromResponse(t *testing.T, httpResp *roundtrip.Response) *authPack { +func (r *testProxy) authPackFromResponse(t *testing.T, httpResp *roundtrip.Response) *authPack { var resp *CreateSessionResponse require.NoError(t, json.Unmarshal(httpResp.Bytes(), &resp)) @@ -6271,7 +6343,7 @@ func defaultRoleForNewUser(teleUser types.User, login string) types.Role { return role } -func (r *proxy) createUser(ctx context.Context, t *testing.T, user, login, pass, otpSecret string, roles []types.Role) { +func (r *testProxy) createUser(ctx context.Context, t *testing.T, user, login, pass, otpSecret string, roles []types.Role) { teleUser, err := types.NewUser(user) require.NoError(t, err) @@ -6304,14 +6376,14 @@ func (r *proxy) createUser(ctx context.Context, t *testing.T, user, login, pass, } } -func (r *proxy) newClient(t *testing.T, opts ...roundtrip.ClientParam) *client.WebClient { +func (r *testProxy) newClient(t *testing.T, opts ...roundtrip.ClientParam) *client.WebClient { opts = append(opts, roundtrip.HTTPClient(client.NewInsecureWebClient())) clt, err := client.NewWebClient(r.webURL.String(), opts...) require.NoError(t, err) return clt } -func (r *proxy) makeTerminal(t *testing.T, pack *authPack, sessionID session.ID) *websocket.Conn { +func (r *testProxy) makeTerminal(t *testing.T, pack *authPack, sessionID session.ID) *websocket.Conn { u := url.URL{ Host: r.webURL.Host, Scheme: client.WSS, @@ -6353,7 +6425,7 @@ func (r *proxy) makeTerminal(t *testing.T, pack *authPack, sessionID session.ID) return ws } -func (r *proxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID session.ID, addr net.Addr) *websocket.Conn { +func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID session.ID, addr net.Addr) *websocket.Conn { u := url.URL{ Host: r.webURL.Host, Scheme: client.WSS, diff --git a/lib/web/terminal.go b/lib/web/terminal.go index d1e1835058e53..6e9655da9bcd3 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -20,6 +20,7 @@ import ( "context" "encoding/json" "io" + "net" "net/http" "strconv" "strings" @@ -37,6 +38,7 @@ import ( "github.com/gravitational/teleport" authproto "github.com/gravitational/teleport/api/client/proto" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/observability/tracing" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" @@ -45,6 +47,8 @@ import ( "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/proxy" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" @@ -88,28 +92,77 @@ type AuthProvider interface { GetNodes(ctx context.Context, namespace string) ([]types.Server, error) GetSessionEvents(namespace string, sid session.ID, after int, includePrintEvents bool) ([]events.EventFields, error) GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) + IsMFARequired(ctx context.Context, req *authproto.IsMFARequiredRequest) (*authproto.IsMFARequiredResponse, error) + GenerateUserSingleUseCerts(ctx context.Context) (authproto.AuthService_GenerateUserSingleUseCertsClient, error) } -// NewTerminal creates a web-based terminal based on WebSockets and returns a -// new TerminalHandler. -func NewTerminal(ctx context.Context, req TerminalRequest, authProvider AuthProvider, sessCtx *SessionContext) (*TerminalHandler, error) { - ctx, span := tracing.DefaultProvider().Tracer("terminal").Start(ctx, "terminal/NewTerminal") - defer span.End() +// TerminalHandlerConfig contains the configuration options necessary to +// correctly setup the TerminalHandler +type TerminalHandlerConfig struct { + // Req are the terminal parameters from the UI + Req TerminalRequest + // AuthProvider is used to communicate with the auth server + AuthProvider AuthProvider + // SessionCtx is the user specific session context + SessionCtx *SessionContext + // Router determines how connections to nodes are created + Router *proxy.Router + // TracerProvider is used to create the tracer + TracerProvider oteltrace.TracerProvider + // tracer is used to create spans + tracer oteltrace.Tracer +} + +// CheckAndSetDefaults validates the provided dependencies +// are valid and sets defaults for any optional items. +func (c *TerminalHandlerConfig) CheckAndSetDefaults() error { + if c.AuthProvider == nil { + return trace.BadParameter("AuthProvider must be provided") + } + + if c.SessionCtx == nil { + return trace.BadParameter("SessionCtx must be provided") + } + + if c.Router == nil { + return trace.BadParameter("Router must be provided") + } // Make sure whatever session is requested is a valid session. - _, err := session.ParseID(string(req.SessionID)) + _, err := session.ParseID(string(c.Req.SessionID)) if err != nil { - return nil, trace.BadParameter("sid: invalid session id") + return trace.BadParameter("invalid session id provided") + } + + if c.Req.Login == "" { + return trace.BadParameter("invalid login provided") + } + + if c.Req.Term.W <= 0 || c.Req.Term.H <= 0 || + c.Req.Term.W >= 4096 || c.Req.Term.H >= 4096 { + return trace.BadParameter("invalid dimensions(%dx%d)", c.Req.Term.W, c.Req.Term.H) } - if req.Login == "" { - return nil, trace.BadParameter("login: missing login") + if c.TracerProvider == nil { + c.TracerProvider = tracing.DefaultProvider() } - if req.Term.W <= 0 || req.Term.H <= 0 { - return nil, trace.BadParameter("term: bad term dimensions") + + c.tracer = c.TracerProvider.Tracer("webterminal") + + return nil +} + +// NewTerminal creates a web-based terminal based on WebSockets and returns a +// new TerminalHandler. +func NewTerminal(ctx context.Context, cfg TerminalHandlerConfig) (*TerminalHandler, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) } - servers, err := authProvider.GetNodes(ctx, req.Namespace) + ctx, span := cfg.tracer.Start(ctx, "NewTerminal") + defer span.End() + + servers, err := cfg.AuthProvider.GetNodes(ctx, apidefaults.Namespace) if err != nil { return nil, trace.Wrap(err) } @@ -118,13 +171,13 @@ func NewTerminal(ctx context.Context, req TerminalRequest, authProvider AuthProv // // All proxies will support lookup by uuid, so host/port lookup // and fallback can be dropped entirely. - hostName, hostPort, err := resolveServerHostPort(req.Server, servers) + hostName, hostPort, err := resolveServerHostPort(cfg.Req.Server, servers) if err != nil { - return nil, trace.BadParameter("invalid server name %q: %v", req.Server, err) + return nil, trace.BadParameter("invalid server name %q: %v", cfg.Req.Server, err) } var join bool - _, err = authProvider.GetSessionTracker(ctx, string(req.SessionID)) + _, err = cfg.AuthProvider.GetSessionTracker(ctx, string(cfg.Req.SessionID)) switch { case trace.IsNotFound(err): join = false @@ -137,17 +190,20 @@ func NewTerminal(ctx context.Context, req TerminalRequest, authProvider AuthProv return &TerminalHandler{ log: logrus.WithFields(logrus.Fields{ trace.Component: teleport.ComponentWebsocket, + "session_id": cfg.Req.SessionID.String(), }), - params: req, - ctx: sessCtx, + params: cfg.Req, + ctx: cfg.SessionCtx, hostName: hostName, hostPort: hostPort, - hostUUID: req.Server, - authProvider: authProvider, + hostUUID: cfg.Req.Server, + authProvider: cfg.AuthProvider, encoder: unicode.UTF8.NewEncoder(), decoder: unicode.UTF8.NewDecoder(), wsLock: &sync.Mutex{}, join: join, + router: cfg.Router, + tracer: cfg.tracer, }, nil } @@ -200,6 +256,12 @@ type TerminalHandler struct { // join is set if we're joining an existing session join bool + + // router is used to dial the host + router *proxy.Router + + // tracer creates spans + tracer oteltrace.Tracer } // ServeHTTP builds a connection to the remote node and then pumps back two types of @@ -220,7 +282,7 @@ func (t *TerminalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ws, err := upgrader.Upgrade(w, r, nil) if err != nil { errMsg := "Error upgrading to websocket" - t.log.Errorf("%v: %v", errMsg, err) + t.log.WithError(err).Error(errMsg) http.Error(w, errMsg, http.StatusInternalServerError) return } @@ -260,7 +322,7 @@ func (t *TerminalHandler) startPingLoop(ws *websocket.Conn) { // If this is just a temporary issue, we will retry shortly anyway. deadline := time.Now().Add(time.Second) if err := ws.WriteControl(websocket.PingMessage, nil, deadline); err != nil { - t.log.Errorf("Unable to send ping frame to web client: %v.", err) + t.log.WithError(err).Error("Unable to send ping frame to web client") t.Close() return } @@ -286,15 +348,12 @@ func (t *TerminalHandler) handler(ws *websocket.Conn, r *http.Request) { // the terminal. tc, err := t.makeClient(ws, r) if err != nil { - t.log.WithError(err).Infof("Failed creating a client for session %v.", t.params.SessionID) - writeErr := t.writeError(err, ws) - if writeErr != nil { - t.log.WithError(writeErr).Warnf("Unable to send error to terminal.") - } + t.log.WithError(err).Info("Failed creating a client for session") + t.writeError(err, ws) return } - t.log.Debugf("Creating websocket stream for %v.", t.params.SessionID) + t.log.Debug("Creating websocket stream") // Update the read deadline upon receiving a pong message. ws.SetPongHandler(func(_ string) error { @@ -311,7 +370,7 @@ func (t *TerminalHandler) handler(ws *websocket.Conn, r *http.Request) { // Block until the terminal session is complete. <-t.terminalContext.Done() - t.log.Debugf("Closing websocket stream for %v.", t.params.SessionID) + t.log.Debug("Closing websocket stream") } // makeClient builds a *client.TeleportClient for the connection. @@ -347,7 +406,7 @@ func (t *TerminalHandler) makeClient(ws *websocket.Conn, r *http.Request) (*clie clientConfig.HostPort = t.hostPort clientConfig.Env = map[string]string{sshutils.SessionEnvVar: string(t.params.SessionID)} clientConfig.ClientAddr = r.RemoteAddr - clientConfig.Tracer = tracing.DefaultProvider().Tracer("TerminalHandler") + clientConfig.Tracer = t.tracer if len(t.params.InteractiveCommand) > 0 { clientConfig.Interactive = true @@ -368,46 +427,108 @@ func (t *TerminalHandler) makeClient(ws *websocket.Conn, r *http.Request) (*clie return false, nil } - if err := t.issueSessionMFACerts(ctx, tc, ws); err != nil { - return nil, trace.Wrap(err) - } - return tc, nil } +// issueSessionMFACerts performs the mfa ceremony to retrieve new certs that can be +// used to access nodes which require per-session mfa. The ceremony is performed directly +// to make use of the authProvider already established for the session instead of leveraging +// the TeleportClient which would require dialing the auth server a second time. func (t *TerminalHandler) issueSessionMFACerts(ctx context.Context, tc *client.TeleportClient, ws *websocket.Conn) error { - ctx, span := tracing.DefaultProvider().Tracer("terminal").Start(ctx, "terminal/issueSessionMFACerts") + ctx, span := t.tracer.Start(ctx, "terminal/issueSessionMFACerts") defer span.End() - pc, err := tc.ConnectToProxy(ctx) + log.Debug("Attempting to issue a single-use user certificate with an MFA check.") + stream, err := t.authProvider.GenerateUserSingleUseCerts(ctx) if err != nil { return trace.Wrap(err) } - defer pc.Close() + defer func() { + stream.CloseSend() + stream.Recv() + }() pk, err := keys.ParsePrivateKey(t.ctx.session.GetPriv()) if err != nil { return trace.Wrap(err) } - key, err := pc.IssueUserCertsWithMFA(ctx, client.ReissueParams{ - RouteToCluster: t.params.Cluster, - NodeName: t.params.Server, - ExistingCreds: &client.Key{ - PrivateKey: pk, - Cert: t.ctx.session.GetPub(), - TLSCert: t.ctx.session.GetTLSCert(), - }, - }, promptMFAChallenge(ws, t.wsLock, protobufMFACodec{})) + key := &client.Key{ + PrivateKey: pk, + Cert: t.ctx.session.GetPub(), + TLSCert: t.ctx.session.GetTLSCert(), + } + + tlsCert, err := key.TeleportTLSCertificate() + if err != nil { + return trace.Wrap(err) + } + + if err := stream.Send( + &authproto.UserSingleUseCertsRequest{ + Request: &authproto.UserSingleUseCertsRequest_Init{ + Init: &authproto.UserCertsRequest{ + PublicKey: key.MarshalSSHPublicKey(), + Username: tlsCert.Subject.CommonName, + Expires: tlsCert.NotAfter, + RouteToCluster: t.params.Cluster, + NodeName: t.params.Server, + Usage: authproto.UserCertsRequest_SSH, + Format: tc.CertificateFormat, + }, + }, + }); err != nil { + return trace.Wrap(err) + } + + resp, err := stream.Recv() if err != nil { return trace.Wrap(err) } + challenge := resp.GetMFAChallenge() + if challenge == nil { + return trace.BadParameter("server sent a %T on GenerateUserSingleUseCerts, expected MFAChallenge", resp.Response) + } + + span.AddEvent("prompting user with mfa challenge") + assertion, err := promptMFAChallenge(ws, t.wsLock, protobufMFACodec{})(ctx, tc.WebProxyAddr, challenge) + if err != nil { + return trace.Wrap(err) + } + span.AddEvent("user completed mfa challenge") + + err = stream.Send(&authproto.UserSingleUseCertsRequest{Request: &authproto.UserSingleUseCertsRequest_MFAResponse{MFAResponse: assertion}}) + if err != nil { + return trace.Wrap(err) + } + + resp, err = stream.Recv() + if err != nil { + return trace.Wrap(err) + } + + certResp := resp.GetCert() + if certResp == nil { + return trace.BadParameter("server sent a %T on GenerateUserSingleUseCerts, expected SingleUseUserCert", resp.Response) + } + + switch crt := certResp.Cert.(type) { + case *authproto.SingleUseUserCert_SSH: + key.Cert = crt.SSH + default: + return trace.BadParameter("server sent a %T SingleUseUserCert in response", certResp.Cert) + } + + key.ClusterName = t.params.Cluster + am, err := key.AsAuthMethod() if err != nil { return trace.Wrap(err) } + tc.AuthMethods = []ssh.AuthMethod{am} + return nil } @@ -461,31 +582,101 @@ func promptMFAChallenge( // streamTerminal opens a SSH connection to the remote host and streams // events back to the web client. func (t *TerminalHandler) streamTerminal(ws *websocket.Conn, tc *client.TeleportClient) { - defer t.terminalCancel() + ctx, span := t.tracer.Start(t.terminalContext, "terminal/streamTerminal") + defer span.End() - // Establish SSH connection to the server. This function will block until - // either an error occurs or it completes successfully. - err := tc.SSH(t.terminalContext, t.params.InteractiveCommand, false) + defer t.terminalCancel() - // TODO IN: 5.0 - // - // Make connecting by UUID the default instead of the fallback. - // - if err != nil && strings.Contains(err.Error(), teleport.NodeIsAmbiguous) { - t.log.Debugf("Ambiguous hostname %q, attempting to connect by UUID (%q).", t.hostName, t.hostUUID) - tc.Host = t.hostUUID - // We don't technically need to zero the HostPort, but future version won't look up - // HostPort when connecting by UUID, so its best to keep behavior consistent. - tc.HostPort = 0 - err = tc.SSH(t.terminalContext, t.params.InteractiveCommand, false) + accessChecker, err := t.ctx.GetUserAccessChecker() + if err != nil { + t.log.WithError(err).Warn("Unable to stream terminal - failed to get access checker") + t.writeError(err, ws) + return } + conn, err := t.router.DialHost(ctx, ws.RemoteAddr(), t.hostName, strconv.Itoa(t.hostPort), tc.SiteName, accessChecker, nil) if err != nil { - t.log.Warnf("Unable to stream terminal: %v.", err) - er := t.writeError(err, ws) - if er != nil { - t.log.Warnf("Unable to send error to terminal: %v: %v.", err, er) + t.log.WithError(err).Warn("Unable to stream terminal - failed to dial host.") + t.writeError(err, ws) + return + } + + defer func() { + if conn == nil { + return + } + + if err := conn.Close(); err != nil && !utils.IsUseOfClosedNetworkError(err) { + t.log.WithError(err).Warn("Failed to close connection to host") } + }() + + sshConfig := &ssh.ClientConfig{ + User: tc.HostLogin, + Auth: tc.AuthMethods, + HostKeyCallback: tc.HostKeyCallback, + } + + nc, connectErr := client.NewNodeClient(ctx, sshConfig, conn, net.JoinHostPort(t.hostName, strconv.Itoa(t.hostPort)), tc, modules.GetModules().IsBoringBinary()) + switch { + case connectErr != nil && !trace.IsAccessDenied(connectErr): // catastrophic error, return it + t.log.WithError(connectErr).Warn("Unable to stream terminal - failed to create node client") + t.writeError(connectErr, ws) + return + case connectErr != nil && trace.IsAccessDenied(connectErr): // see if per session mfa would allow access + mfaRequiredResp, err := t.authProvider.IsMFARequired(ctx, &authproto.IsMFARequiredRequest{ + Target: &authproto.IsMFARequiredRequest_Node{ + Node: &authproto.NodeLogin{ + Node: t.params.Server, + Login: tc.HostLogin, + }, + }, + }) + if err != nil { + t.log.WithError(err).Warn("Unable to stream terminal - failed to determine if per session mfa is required") + // write the original connect error + t.writeError(connectErr, ws) + return + } + + if !mfaRequiredResp.Required { + t.log.WithError(connectErr).Warn("Unable to stream terminal - user does not have access to host") + // write the original connect error + t.writeError(connectErr, ws) + return + } + + // perform mfa ceremony and retrieve new certs + if err := t.issueSessionMFACerts(ctx, tc, ws); err != nil { + t.log.WithError(err).Warn("Unable to stream terminal - failed to perform mfa ceremony") + t.writeError(err, ws) + return + } + + // update auth methods + sshConfig.Auth = tc.AuthMethods + + // connect to the node again with the new certs + conn, err = t.router.DialHost(ctx, ws.RemoteAddr(), t.hostName, strconv.Itoa(t.hostPort), tc.SiteName, accessChecker, nil) + if err != nil { + t.log.WithError(err).Warn("Unable to stream terminal - failed to dial host") + t.writeError(err, ws) + return + } + + nc, err = client.NewNodeClient(ctx, sshConfig, conn, net.JoinHostPort(t.hostName, strconv.Itoa(t.hostPort)), tc, modules.GetModules().IsBoringBinary()) + if err != nil { + t.log.WithError(err).Warn("Unable to stream terminal - failed to create node client") + t.writeError(err, ws) + return + } + } + + // Establish SSH connection to the server. This function will block until + // either an error occurs or it completes successfully. + if err = nc.RunInteractiveShell(ctx, types.SessionPeerMode, nil); err != nil { + t.log.WithError(err).Warn("Unable to stream terminal - failure running interactive shell") + t.writeError(err, ws) return } @@ -497,17 +688,19 @@ func (t *TerminalHandler) streamTerminal(ws *websocket.Conn, tc *client.Teleport } envelopeBytes, err := proto.Marshal(envelope) if err != nil { - t.log.Errorf("Unable to marshal close event for web client.") + t.log.WithError(err).Error("Unable to marshal close event for web client.") return } + t.wsLock.Lock() err = ws.WriteMessage(websocket.BinaryMessage, envelopeBytes) t.wsLock.Unlock() if err != nil { - t.log.Errorf("Unable to send close event to web client.") + t.log.WithError(err).Error("Unable to send close event to web client.") return } - t.log.Debugf("Sent close event to web client.") + + t.log.Debug("Sent close event to web client.") } // streamEvents receives events over the SSH connection and forwards them to @@ -520,16 +713,16 @@ func (t *TerminalHandler) streamEvents(ws *websocket.Conn, tc *client.TeleportCl data, err := json.Marshal(event) logger := t.log.WithField("event", event.GetType()) if err != nil { - logger.Errorf("Unable to marshal audit event: %v.", err) + logger.WithError(err).Errorf("Unable to marshal audit event") continue } - t.log.Debugf("Sending audit event %v to web client.", event.GetType()) + logger.Debug("Sending audit event to web client.") // UTF-8 encode the error message and then wrap it in a raw envelope. encodedPayload, err := t.encoder.String(string(data)) if err != nil { - logger.Debugf("Unable to send audit event to web client: %v.", err) + logger.WithError(err).Debug("Unable to send audit event to web client") continue } envelope := &Envelope{ @@ -539,7 +732,7 @@ func (t *TerminalHandler) streamEvents(ws *websocket.Conn, tc *client.TeleportCl } envelopeBytes, err := proto.Marshal(envelope) if err != nil { - logger.Debugf("Unable to send audit event to web client: %v.", err) + logger.WithError(err).Debug("Unable to send audit event to web client") continue } @@ -548,7 +741,7 @@ func (t *TerminalHandler) streamEvents(ws *websocket.Conn, tc *client.TeleportCl err = ws.WriteMessage(websocket.BinaryMessage, envelopeBytes) t.wsLock.Unlock() if err != nil { - logger.Errorf("Unable to send audit event to web client: %v.", err) + logger.WithError(err).Error("Unable to send audit event to web client") continue } // Once the terminal stream is over (and the close envelope has been sent), @@ -572,16 +765,14 @@ func (t *TerminalHandler) windowChange(ctx context.Context, params *session.Term } // writeError displays an error in the terminal window. -func (t *TerminalHandler) writeError(err error, ws *websocket.Conn) error { +func (t *TerminalHandler) writeError(err error, ws *websocket.Conn) { // Replace \n with \r\n so the message correctly aligned. r := strings.NewReplacer("\r\n", "\r\n", "\n", "\r\n") errMessage := r.Replace(err.Error()) - _, err = t.write([]byte(errMessage), ws) - if err != nil { - return trace.Wrap(err) - } - return nil + if _, writeErr := t.write([]byte(errMessage), ws); writeErr != nil { + t.log.WithError(writeErr).Warnf("Unable to send error to terminal: %v", err) + } } // resolveServerHostPort parses server name and attempts to resolve hostname