diff --git a/api/types/session_tracker.go b/api/types/session_tracker.go index 91ed49376a263..37ad0984b2722 100644 --- a/api/types/session_tracker.go +++ b/api/types/session_tracker.go @@ -78,8 +78,8 @@ type SessionTracker interface { // GetAddress returns the address of the session target. GetAddress() string - // GetClustername returns the name of the cluster. - GetClustername() string + // GetClusterName returns the name of the cluster. + GetClusterName() string // GetLogin returns the target machine username used for this session. GetLogin() string @@ -272,7 +272,7 @@ func (s *SessionTrackerV1) GetAddress() string { } // GetClustername returns the name of the cluster the session is running in. -func (s *SessionTrackerV1) GetClustername() string { +func (s *SessionTrackerV1) GetClusterName() string { return s.Spec.ClusterName } diff --git a/constants.go b/constants.go index 59b4e1ef2cda7..f6e001e0616a2 100644 --- a/constants.go +++ b/constants.go @@ -639,6 +639,9 @@ const ( // ForceTerminateRequest is an SSH request to forcefully terminate a session. ForceTerminateRequest = "x-teleport-force-terminate" + // TerminalSizeRequest is a request for the terminal size of the session. + TerminalSizeRequest = "x-teleport-terminal-size" + // MFAPresenceRequest is an SSH request to notify clients that MFA presence is required for a session. MFAPresenceRequest = "x-teleport-mfa-presence" @@ -654,6 +657,10 @@ const ( // EnvSSHSessionDisplayParticipantRequirements is set to true or false to indicate if participant // requirement information should be printed. EnvSSHSessionDisplayParticipantRequirements = "TELEPORT_SESSION_PARTICIPANT_REQUIREMENTS" + + // SSHSessionJoinPrincipal is the SSH principal used when joining sessions. + // This starts with a hyphen so it isn't a valid unix login. + SSHSessionJoinPrincipal = "-teleport-internal-join" ) const ( diff --git a/integration/integration_test.go b/integration/integration_test.go index fabe6300142ee..323ea51464113 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -810,8 +810,8 @@ func testSSHTracker(t *testing.T, suite *integrationTestSuite) { require.NoError(t, err) cl.Stdout = personA cl.Stdin = personA - personA.Type("\aecho hi\n\r") go cl.SSH(ctx, []string{}, false) + personA.Type("\aecho hi\n\r") condition := func() bool { // verify that the tracker was created @@ -821,7 +821,7 @@ func testSSHTracker(t *testing.T, suite *integrationTestSuite) { } // wait for the tracker to be created - require.Eventually(t, condition, time.Minute, time.Millisecond*100) + require.Eventually(t, condition, time.Minute*5, time.Millisecond*100) } // testInteractive covers SSH into shell and joining the same session from another client diff --git a/integration/kube_integration_test.go b/integration/kube_integration_test.go index 5e0797d955a0f..a7c515fe808ba 100644 --- a/integration/kube_integration_test.go +++ b/integration/kube_integration_test.go @@ -1526,8 +1526,19 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) { }, }) require.NoError(t, err) + joinRole, err := types.NewRole("participant", types.RoleSpecV5{ + Allow: types.RoleConditions{ + JoinSessions: []*types.SessionJoinPolicy{{ + Name: "foo", + Roles: []string{"kubemaster"}, + Kinds: []string{string(types.KubernetesSessionKind)}, + Modes: []string{string(types.SessionPeerMode)}, + }}, + }, + }) + require.NoError(t, err) teleport.AddUserWithRole(hostUsername, role) - teleport.AddUserWithRole(participantUsername, role) + teleport.AddUserWithRole(participantUsername, joinRole) err = teleport.CreateEx(t, nil, tconf) require.NoError(t, err) diff --git a/lib/auth/auth.go b/lib/auth/auth.go index c259a30a0f525..46c1c3e95d1d0 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -1045,6 +1045,10 @@ func (a *Server) generateUserCert(req certRequest) (*proto.Certs, error) { return nil, trace.Wrap(err) } + // Add the special join-only principal used for joining sessions. + // All users have access to this and join RBAC rules are checked after the connection is established. + allowedLogins = append(allowedLogins, "-teleport-internal-join") + params := services.UserCertParams{ CASigner: caSigner, CASigningAlg: sshutils.GetSigningAlgName(userCA), diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index 21f243e4099a7..ed6081371d8d1 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -282,7 +282,7 @@ func TestAuthenticateSSHUser(t *testing.T) { gotSSHCert, err := sshutils.ParseCertificate(resp.Cert) require.NoError(t, err) require.Equal(t, gotSSHCert.Key, inSSHPub) - require.Equal(t, gotSSHCert.ValidPrincipals, []string{user}) + require.Equal(t, gotSSHCert.ValidPrincipals, []string{user, teleport.SSHSessionJoinPrincipal}) // Verify the public key and Subject in TLS cert. inCryptoPub := inSSHPub.(ssh.CryptoPublicKey).CryptoPublicKey() gotTLSCert, err := tlsca.ParseCertificatePEM(resp.TLSCert) @@ -291,7 +291,7 @@ func TestAuthenticateSSHUser(t *testing.T) { wantID := tlsca.Identity{ Username: user, Groups: []string{role.GetName()}, - Principals: []string{user}, + Principals: []string{user, teleport.SSHSessionJoinPrincipal}, KubernetesUsers: []string{user}, KubernetesGroups: []string{"system:masters"}, Expires: gotTLSCert.NotAfter, @@ -320,7 +320,7 @@ func TestAuthenticateSSHUser(t *testing.T) { wantID = tlsca.Identity{ Username: user, Groups: []string{role.GetName()}, - Principals: []string{user}, + Principals: []string{user, teleport.SSHSessionJoinPrincipal}, KubernetesUsers: []string{user}, KubernetesGroups: []string{"system:masters"}, // It's OK to use a non-existent kube cluster for leaf teleport @@ -364,7 +364,7 @@ func TestAuthenticateSSHUser(t *testing.T) { wantID = tlsca.Identity{ Username: user, Groups: []string{role.GetName()}, - Principals: []string{user}, + Principals: []string{user, teleport.SSHSessionJoinPrincipal}, KubernetesUsers: []string{user}, KubernetesGroups: []string{"system:masters"}, KubernetesCluster: "root-kube-cluster", @@ -397,7 +397,7 @@ func TestAuthenticateSSHUser(t *testing.T) { wantID = tlsca.Identity{ Username: user, Groups: []string{role.GetName()}, - Principals: []string{user}, + Principals: []string{user, teleport.SSHSessionJoinPrincipal}, KubernetesUsers: []string{user}, KubernetesGroups: []string{"system:masters"}, KubernetesCluster: "root-kube-cluster", @@ -439,7 +439,7 @@ func TestAuthenticateSSHUser(t *testing.T) { wantID = tlsca.Identity{ Username: user, Groups: []string{role.GetName()}, - Principals: []string{user}, + Principals: []string{user, teleport.SSHSessionJoinPrincipal}, KubernetesUsers: []string{user}, KubernetesGroups: []string{"system:masters"}, KubernetesCluster: "root-kube-cluster", @@ -472,7 +472,7 @@ func TestAuthenticateSSHUser(t *testing.T) { wantID = tlsca.Identity{ Username: user, Groups: []string{role.GetName()}, - Principals: []string{user}, + Principals: []string{user, teleport.SSHSessionJoinPrincipal}, KubernetesUsers: []string{user}, KubernetesGroups: []string{"system:masters"}, KubernetesCluster: "root-kube-cluster", diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index b03324fd8e043..8bd128518d9b9 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -291,8 +291,8 @@ func (a *ServerWithRoles) GetActiveSessionTrackers(ctx context.Context) ([]types var filteredSessions []types.SessionTracker - for _, session := range sessions { - evaluator := NewSessionAccessEvaluator(session.GetHostPolicySets(), session.GetSessionKind()) + for _, sess := range sessions { + evaluator := NewSessionAccessEvaluator(sess.GetHostPolicySets(), sess.GetSessionKind()) joinerRoles, err := a.authServer.GetRoles(ctx) if err != nil { return nil, trace.Wrap(err) @@ -300,10 +300,42 @@ func (a *ServerWithRoles) GetActiveSessionTrackers(ctx context.Context) ([]types modes, err := evaluator.CanJoin(SessionAccessContext{Roles: joinerRoles}) if err == nil || len(modes) > 0 { - filteredSessions = append(filteredSessions, session) + // Apply RFD 45 RBAC rules to the session if it's SSH. + // This is a bit of a hack. It converts to the old legacy format + // which we don't have all data for, luckily the fields we don't have aren't made available + // to the RBAC filter anyway. + if sess.GetKind() == types.KindSSHSession { + ruleCtx := &services.Context{User: a.context.User} + ruleCtx.SSHSession = &session.Session{ + ID: session.ID(sess.GetSessionID()), + Namespace: apidefaults.Namespace, + Login: sess.GetLogin(), + Created: sess.GetCreated(), + LastActive: a.authServer.GetClock().Now(), + ServerID: sess.GetAddress(), + ServerAddr: sess.GetAddress(), + ServerHostname: sess.GetHostname(), + ClusterName: sess.GetClusterName(), + } + + for _, participant := range sess.GetParticipants() { + // We only need to fill in User here since other fields get discarded anyway. + ruleCtx.SSHSession.Parties = append(ruleCtx.SSHSession.Parties, session.Party{ + User: participant.User, + }) + } + + // Skip past it if there's a deny rule in place blocking access. + if err := a.context.Checker.CheckAccessToRule(ruleCtx, apidefaults.Namespace, types.KindSSHSession, types.VerbList, true /* silent */); err != nil { + continue + } + } + + filteredSessions = append(filteredSessions, sess) + } else { + log.Warnf("Session %v is not allowed to join: %v", sess.GetSessionID(), err) } } - return filteredSessions, nil } diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 94a423834ed82..aad7d67f64040 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -461,7 +461,8 @@ func TestGenerateUserCertsWithRoleRequest(t *testing.T) { require.NoError(t, err) if len(tt.expectPrincipals) > 0 { - require.ElementsMatch(t, tt.expectPrincipals, userCert.ValidPrincipals, "principals must match") + expectPrincipals := append(tt.expectPrincipals, teleport.SSHSessionJoinPrincipal) + require.ElementsMatch(t, expectPrincipals, userCert.ValidPrincipals, "principals must match") } if tt.expectRoles != nil { diff --git a/lib/auth/session_access.go b/lib/auth/session_access.go index a271b7c9cb1b3..f1af91b4d6963 100644 --- a/lib/auth/session_access.go +++ b/lib/auth/session_access.go @@ -171,16 +171,21 @@ func (e *SessionAccessEvaluator) matchesKind(allow []string) bool { return false } +func HasV5Role(roles []types.Role) bool { + for _, role := range roles { + if role.GetVersion() == types.V5 { + return true + } + } + + return false +} + // CanJoin returns the modes a user has access to join a session with. // If the list is empty, the user doesn't have access to join the session at all. func (e *SessionAccessEvaluator) CanJoin(user SessionAccessContext) ([]types.SessionParticipantMode, error) { - supported, err := e.supportsSessionAccessControls() - if err != nil { - return nil, trace.Wrap(err) - } - // If we don't support session access controls, return the default mode set that was supported prior to Moderated Sessions. - if !supported { + if !HasV5Role(user.Roles) { return preAccessControlsModes(e.kind), nil } @@ -217,16 +222,6 @@ type PolicyOptions struct { TerminateOnLeave bool } -func (e *SessionAccessEvaluator) hasPolicies() bool { - for _, policySet := range e.policySets { - if len(policySet.RequireSessionJoin) > 0 { - return true - } - } - - return false -} - // Generate a pretty-printed string of precise requirements for session start suitable for user display. func (e *SessionAccessEvaluator) PrettyRequirementsList() string { s := new(strings.Builder) @@ -262,16 +257,6 @@ func (e *SessionAccessEvaluator) extractApplicablePolicies(set *types.SessionTra // FulfilledFor checks if a given session may run with a list of participants. func (e *SessionAccessEvaluator) FulfilledFor(participants []SessionAccessContext) (bool, PolicyOptions, error) { - supported, err := e.supportsSessionAccessControls() - if err != nil { - return false, PolicyOptions{}, trace.Wrap(err) - } - - // If advanced access controls are supported or no require policies are defined, we allow by default. - if !e.hasPolicies() || !supported { - return true, PolicyOptions{TerminateOnLeave: true}, nil - } - options := PolicyOptions{TerminateOnLeave: true} // Check every policy set to check if it's fulfilled. @@ -341,28 +326,6 @@ policySetLoop: return true, options, nil } -// supportsSessionAccessControls checks if moderated sessions-style access controls can be applied to the session. -// If a set only has v4 or earlier roles, we don't want to apply the access checks to SSH sessions. -// -// This only applies to SSH sessions since they previously had no access control for joining sessions. -// We don't need this fallback behaviour for multiparty kubernetes since it's a new feature. -func (e *SessionAccessEvaluator) supportsSessionAccessControls() (bool, error) { - if e.kind == types.SSHSessionKind { - for _, policySet := range e.policySets { - switch policySet.Version { - case types.V1, types.V2, types.V3, types.V4: - return false, nil - case types.V5: - return true, nil - default: - return false, trace.BadParameter("unsupported role version: %v", policySet.Version) - } - } - } - - return true, nil -} - func preAccessControlsModes(kind types.SessionKind) []types.SessionParticipantMode { switch kind { case types.SSHSessionKind: diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index 6f1fb00746a4d..30d2c32427747 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -1755,10 +1755,10 @@ func (s *TLSSuite) TestAccessRequest(c *check.C) { c.Assert(certRequests(userCerts.TLS), check.HasLen, 0) // verify that cert for user with no static logins is generated with - // exactly one login and that it is an invalid unix login (indicated + // exactly two logins and that the one that isn't a join principal is an invalid unix login (indicated // by preceding dash (-). logins := certLogins(userCerts.SSH) - c.Assert(len(logins), check.Equals, 1) + c.Assert(len(logins), check.Equals, 2) c.Assert(rune(logins[0][0]), check.Equals, '-') // attempt to apply request in PENDING state (should fail) @@ -1785,7 +1785,7 @@ func (s *TLSSuite) TestAccessRequest(c *check.C) { // verify that dynamically applied role granted a login, // which is is valid and has replaced the dummy login. logins = certLogins(userCerts.SSH) - c.Assert(len(logins), check.Equals, 1) + c.Assert(len(logins), check.Equals, 2) c.Assert(rune(logins[0][0]), check.Not(check.Equals), '-') elevatedCert, err := tls.X509KeyPair(userCerts.TLS, priv) diff --git a/lib/client/api.go b/lib/client/api.go index 2a681e76287c8..1feffa4aa1ba9 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -46,6 +46,7 @@ import ( "golang.org/x/crypto/ssh/agent" "golang.org/x/term" + "github.com/coreos/go-semver/semver" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/client/webclient" @@ -1535,6 +1536,27 @@ func (tc *TeleportClient) startPortForwarding(ctx context.Context, nodeClient *N } } +func getLegacySession(site auth.ClientI, namespace string, sessionID string, errMsg string) (*session.Session, error) { + sessions, err := site.GetSessions(namespace) + if err != nil { + return nil, trace.Wrap(err) + } + + var sess *session.Session + for _, s := range sessions { + if s.ID == session.ID(sessionID) { + sess = &s + break + } + } + + if sess == nil { + return nil, trace.NotFound(errMsg) + } + + return sess, nil +} + // Join connects to the existing/active SSH session func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipantMode, namespace string, sessionID session.ID, input io.Reader) (err error) { if namespace == "" { @@ -1560,54 +1582,47 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan return trace.Wrap(err) } - // find the session ID on the site: - sessions, err := site.GetSessions(namespace) + var sess *session.Session + ping, err := site.Ping(ctx) if err != nil { return trace.Wrap(err) } - var session *session.Session - for _, s := range sessions { - if s.ID == sessionID { - session = &s - break + // TODO(xacrimon): DELETE IN 10.0 + nodeLogin := teleport.SSHSessionJoinPrincipal + if semver.New(ping.ServerVersion).LessThan(*auth.MinSupportedModeratedSessionsVersion) { + nodeLogin = tc.HostLogin + sess, err = getLegacySession(site, namespace, string(sessionID), notFoundErrorMessage) + if err != nil { + return trace.Wrap(err) + } + } else { + sessions, err := site.GetActiveSessionTrackers(ctx) + if err != nil { + return trace.Wrap(err) } - } - if session == nil { - return trace.NotFound(notFoundErrorMessage) - } - // pick the 1st party of the session and use his server ID to connect to - if len(session.Parties) == 0 { - return trace.NotFound(notFoundErrorMessage) - } - serverID := session.Parties[0].ServerID + for _, sessionIter := range sessions { + if sessionIter.GetSessionID() == string(sessionID) { + sess = &session.Session{ + ID: session.ID(sessionIter.GetSessionID()), + Namespace: apidefaults.Namespace, + ServerID: sessionIter.GetAddress() + ":0", + } + break + } + } - // find a server address by its ID - nodes, err := site.GetNodes(ctx, namespace) - if err != nil { - return trace.Wrap(err) - } - var node types.Server - for _, n := range nodes { - if n.GetName() == serverID { - node = n - break + if sess == nil { + return trace.NotFound(notFoundErrorMessage) } } - if node == nil { - return trace.NotFound(notFoundErrorMessage) - } - target := node.GetAddr() - if target == "" { - // address is empty, try dialing by UUID instead - target = fmt.Sprintf("%s:0", serverID) - } + // connect to server: nc, err := proxyClient.ConnectToNode(ctx, NodeAddr{ - Addr: target, + Addr: sess.ServerID, Namespace: tc.Namespace, Cluster: tc.SiteName, - }, tc.Config.HostLogin, false) + }, nodeLogin, false) if err != nil { return trace.Wrap(err) } @@ -1623,13 +1638,13 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan if mode == types.SessionModeratorMode { beforeStart = func(out io.Writer) { nc.OnMFA = func() { - runPresenceTask(presenceCtx, out, site, tc, string(session.ID)) + runPresenceTask(presenceCtx, out, site, tc, string(sess.ID)) } } } // running shell with a given session means "join" it: - err = tc.runShell(ctx, nc, mode, session, beforeStart) + err = tc.runShell(ctx, nc, mode, sess, beforeStart) return trace.Wrap(err) } diff --git a/lib/client/client.go b/lib/client/client.go index 4d0c8298f3a88..66cb4c2cf156c 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -46,6 +46,7 @@ import ( "github.com/gravitational/teleport/lib/sshutils/scp" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/socks" + "github.com/moby/term" "github.com/gravitational/trace" ) @@ -1611,6 +1612,24 @@ func (c *NodeClient) dynamicListenAndForward(ctx context.Context, ln net.Listene } } +// GetRemoteTerminalSize fetches the terminal size of a given SSH session. +func (c *NodeClient) GetRemoteTerminalSize(sessionID string) (*term.Winsize, error) { + ok, payload, err := c.Client.SendRequest(teleport.TerminalSizeRequest, true, []byte(sessionID)) + if err != nil { + return nil, trace.Wrap(err) + } else if !ok { + return nil, trace.BadParameter("failed to get terminal size") + } + + ws := new(term.Winsize) + err = json.Unmarshal(payload, ws) + if err != nil { + return nil, trace.Wrap(err) + } + + return ws, nil +} + // Close closes client and it's operations func (c *NodeClient) Close() error { return c.Client.Close() diff --git a/lib/client/kubesession.go b/lib/client/kubesession.go index 1b07c0a41d29d..81b99245f87a5 100644 --- a/lib/client/kubesession.go +++ b/lib/client/kubesession.go @@ -166,7 +166,7 @@ func (s *KubeSession) handleMFA(ctx context.Context, tc *TeleportClient, mode ty return trace.Wrap(err) } - auth, err := proxy.ConnectToCluster(ctx, s.meta.GetClustername(), false) + auth, err := proxy.ConnectToCluster(ctx, s.meta.GetClusterName(), false) if err != nil { return trace.Wrap(err) } diff --git a/lib/client/session.go b/lib/client/session.go index 000140077b030..78f3a2862e242 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -129,16 +129,26 @@ func newSession(client *NodeClient, // if we're joining an existing session, we need to assume that session's // existing/current terminal size: if joinSession != nil { - ns.id = joinSession.ID - ns.namespace = joinSession.Namespace - tsize := joinSession.TerminalParams.Winsize() + id := string(joinSession.ID) + terminalSize := joinSession.TerminalParams + + // if these are zero, we're receiving data from a session tracker + if terminalSize.H == 0 && terminalSize.W == 0 { + tsize, err := client.GetRemoteTerminalSize(id) + if err != nil { + return nil, trace.Wrap(err) + } + terminalSize.W, terminalSize.H = int(tsize.Width), int(tsize.Height) + } + + ns.id = session.ID(id) + ns.namespace = joinSession.Namespace if ns.terminal.IsAttached() { - err = ns.terminal.Resize(int16(tsize.Width), int16(tsize.Height)) + err = ns.terminal.Resize(int16(terminalSize.W), int16(terminalSize.H)) if err != nil { log.Error(err) } - } // new session! } else { diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index 9af5d05bc2da9..1984b7d99b3ea 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -256,7 +256,7 @@ func NewForwarder(cfg ForwarderConfig) (*Forwarder, error) { fwd.router.POST("/api/:ver/namespaces/:podNamespace/pods/:podName/portforward", fwd.withAuth(fwd.portForward)) fwd.router.GET("/api/:ver/namespaces/:podNamespace/pods/:podName/portforward", fwd.withAuth(fwd.portForward)) - fwd.router.GET("/api/:ver/teleport/join/:session", fwd.withAuth(fwd.join)) + fwd.router.GET("/api/:ver/teleport/join/:session", fwd.withAuthPassthrough(fwd.join)) fwd.router.NotFound = fwd.withAuthStd(fwd.catchAll) @@ -443,6 +443,21 @@ func (f *Forwarder) withAuthStd(handler handlerWithAuthFuncStd) http.HandlerFunc }, f.formatResponseError) } +// acquireConnectionLockWithIdentity acquires a connection lock under a given identity. +func (f *Forwarder) acquireConnectionLockWithIdentity(ctx context.Context, identity *authContext) error { + user := identity.Identity.GetIdentity().Username + roles, err := getRolesByName(f, identity.Identity.GetIdentity().Groups) + if err != nil { + return trace.Wrap(err) + } + + if err := f.acquireConnectionLock(ctx, user, roles); err != nil { + return trace.Wrap(err) + } + + return nil +} + func (f *Forwarder) withAuth(handler handlerWithAuthFunc) httprouter.Handle { return httplib.MakeHandlerWithErrorWriter(func(w http.ResponseWriter, req *http.Request, p httprouter.Params) (interface{}, error) { authContext, err := f.authenticate(req) @@ -452,14 +467,26 @@ func (f *Forwarder) withAuth(handler handlerWithAuthFunc) httprouter.Handle { if err := f.authorize(req.Context(), authContext); err != nil { return nil, trace.Wrap(err) } - - user := authContext.Identity.GetIdentity().Username - roles, err := getRolesByName(f, authContext.Identity.GetIdentity().Groups) + err = f.acquireConnectionLockWithIdentity(req.Context(), authContext) if err != nil { return nil, trace.Wrap(err) } + return handler(authContext, w, req, p) + }, f.formatResponseError) +} - if err := f.AcquireConnectionLock(req.Context(), user, roles); err != nil { +// withAuthPassthrough authenticates the request and fetches information but doesn't deny if the user +// doesn't have RBAC access to the Kubernetes cluster. +func (f *Forwarder) withAuthPassthrough(handler handlerWithAuthFunc) httprouter.Handle { + return httplib.MakeHandlerWithErrorWriter(func(w http.ResponseWriter, req *http.Request, p httprouter.Params) (interface{}, error) { + authContext, err := f.authenticate(req) + if err != nil { + if !trace.IsAccessDenied(err) && !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + } + err = f.acquireConnectionLockWithIdentity(req.Context(), authContext) + if err != nil { return nil, trace.Wrap(err) } return handler(authContext, w, req, p) @@ -899,10 +926,10 @@ func wsProxy(wsSource *websocket.Conn, wsTarget *websocket.Conn) error { return trace.Wrap(err) } -// AcquireConnectionLock acquires a semaphore used to limit connections to the Kubernetes agent. +// acquireConnectionLock acquires a semaphore used to limit connections to the Kubernetes agent. // The semaphore is releasted when the request is returned/connection is closed. // Returns an error if a semaphore could not be acquired. -func (f *Forwarder) AcquireConnectionLock(ctx context.Context, user string, roles services.RoleSet) error { +func (f *Forwarder) acquireConnectionLock(ctx context.Context, user string, roles services.RoleSet) error { maxConnections := roles.MaxKubernetesConnections() if maxConnections == 0 { return nil diff --git a/lib/kube/proxy/forwarder_test.go b/lib/kube/proxy/forwarder_test.go index fba87ec4c4695..5386d81ea90e5 100644 --- a/lib/kube/proxy/forwarder_test.go +++ b/lib/kube/proxy/forwarder_test.go @@ -1067,7 +1067,8 @@ func newTestForwarder(ctx context.Context, cfg ForwarderConfig) *Forwarder { type mockSemaphoreClient struct { auth.ClientI - sem types.Semaphores + sem types.Semaphores + roles map[string]types.Role } func (m *mockSemaphoreClient) AcquireSemaphore(ctx context.Context, params types.AcquireSemaphoreRequest) (*types.SemaphoreLease, error) { @@ -1078,6 +1079,15 @@ func (m *mockSemaphoreClient) CancelSemaphoreLease(ctx context.Context, lease ty return m.sem.CancelSemaphoreLease(ctx, lease) } +func (m *mockSemaphoreClient) GetRole(ctx context.Context, name string) (types.Role, error) { + role, ok := m.roles[name] + if !ok { + return nil, trace.NotFound("role %q not found", name) + } + + return role, nil +} + func TestKubernetesConnectionLimit(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -1130,13 +1140,28 @@ func TestKubernetesConnectionLimit(t *testing.T) { require.NoError(t, err) sem := local.NewPresenceService(backend) - client := &mockSemaphoreClient{sem: sem} + client := &mockSemaphoreClient{ + sem: sem, + roles: map[string]types.Role{testCase.role.GetName(): testCase.role}, + } + forwarder := newTestForwarder(ctx, ForwarderConfig{ - AuthClient: client, + AuthClient: client, + CachingAuthClient: client, }) + identity := &authContext{ + Context: auth.Context{ + User: user, + Identity: auth.WrapIdentity(tlsca.Identity{ + Username: user.GetName(), + Groups: []string{testCase.role.GetName()}, + }), + }, + } + for i := 0; i < testCase.connections; i++ { - err = forwarder.AcquireConnectionLock(ctx, user.GetName(), services.NewRoleSet(testCase.role)) + err = forwarder.acquireConnectionLockWithIdentity(ctx, identity) if i == testCase.connections-1 { testCase.assert(t, err) } diff --git a/lib/srv/authhandlers.go b/lib/srv/authhandlers.go index 0615da51ca3f8..9a872377c5e4e 100644 --- a/lib/srv/authhandlers.go +++ b/lib/srv/authhandlers.go @@ -30,6 +30,7 @@ import ( "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/sshutils" @@ -433,6 +434,11 @@ func (h *AuthHandlers) canLoginWithRBAC(cert *ssh.Certificate, clusterName strin return trace.Wrap(err) } + // we don't need to check the RBAC for the node if they are only allowed to join sessions + if osUser == teleport.SSHSessionJoinPrincipal && auth.HasV5Role(roles) { + return nil + } + ap, err := h.c.AccessPoint.GetAuthPreference(ctx) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 4b968a05ee36c..91afab1fe2899 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -336,6 +336,9 @@ type ServerContext struct { // x11Config holds the xauth and XServer listener config for this session. x11Config *X11Config + + // JoinOnly is set if the connection was created using a join-only principal and may only be used to join other sessions. + JoinOnly bool } // NewServerContext creates a new *ServerContext which is used to pass and diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 98df7bf560a71..fa7bdb06770cc 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -896,6 +896,28 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, scx *srv.ServerContext) error { scx.Debugf("Handling request %v, want reply %v.", req.Type, req.WantReply) + // Certs with a join-only principal can only use a + // subset of all the possible request types. + if scx.JoinOnly { + switch req.Type { + case sshutils.PTYRequest: + return s.termHandlers.HandlePTYReq(ch, req, scx) + case sshutils.ShellRequest: + return s.termHandlers.HandleShell(ch, req, scx) + case sshutils.WindowChangeRequest: + return s.termHandlers.HandleWinChange(ch, req, scx) + case teleport.ForceTerminateRequest: + return s.termHandlers.HandleForceTerminate(ch, req, scx) + case sshutils.EnvRequest: + // We ignore all SSH setenv requests for join-only principals. + // SSH will send them anyway but it seems fine to silently drop them. + case sshutils.SubsystemRequest: + return s.handleSubsystem(ch, req, scx) + default: + return trace.AccessDenied("attempted %v request in join-only mode", req.Type) + } + } + switch req.Type { case sshutils.ExecRequest: return s.termHandlers.HandleExec(ch, req, scx) diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 6b3a5683d5e9d..231452bc7e59f 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -925,6 +925,8 @@ func (s *Server) HandleRequest(r *ssh.Request) { s.handleRecordingProxy(r) case teleport.VersionRequest: s.handleVersionRequest(r) + case teleport.TerminalSizeRequest: + s.termHandlers.HandleTerminalSize(r) default: if r.WantReply { if err := r.Reply(false, nil); err != nil { @@ -1441,6 +1443,28 @@ func (s *Server) dispatch(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerConte } } + // Certs with a join-only principal can only use a + // subset of all the possible request types. + if ctx.JoinOnly { + switch req.Type { + case sshutils.PTYRequest: + return s.termHandlers.HandlePTYReq(ch, req, ctx) + case sshutils.ShellRequest: + return s.termHandlers.HandleShell(ch, req, ctx) + case sshutils.WindowChangeRequest: + return s.termHandlers.HandleWinChange(ch, req, ctx) + case teleport.ForceTerminateRequest: + return s.termHandlers.HandleForceTerminate(ch, req, ctx) + case sshutils.EnvRequest: + // We ignore all SSH setenv requests for join-only principals. + // SSH will send them anyway but it seems fine to silently drop them. + case sshutils.SubsystemRequest: + return s.handleSubsystem(ch, req, ctx) + default: + return trace.AccessDenied("attempted %v request in join-only mode", req.Type) + } + } + switch req.Type { case sshutils.ExecRequest: return s.termHandlers.HandleExec(ch, req, ctx) diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 92447b728d97c..a811090a6612c 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -37,6 +37,7 @@ import ( rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" + "github.com/moby/term" "github.com/google/uuid" "github.com/gravitational/trace" @@ -183,6 +184,11 @@ func (s *SessionRegistry) OpenSession(ch ssh.Channel, ctx *ServerContext) error return nil } + + if ctx.JoinOnly { + return trace.AccessDenied("join-only mode was used to create this connection but attempted to create a new session.") + } + // session not found? need to create one. start by getting/generating an ID for it sid, found := ctx.GetEnv(sshutils.SessionEnvVar) if !found { @@ -259,6 +265,19 @@ func (s *SessionRegistry) ForceTerminate(ctx *ServerContext) error { return nil } +// GetTerminalSize fetches the terminal size of an active SSH session. +func (s *SessionRegistry) GetTerminalSize(sessionID string) (*term.Winsize, error) { + s.sessionsMux.Lock() + defer s.sessionsMux.Unlock() + + sess := s.sessions[rsession.ID(sessionID)] + if sess == nil { + return nil, trace.NotFound("No session found in context.") + } + + return sess.term.GetWinSize() +} + // NotifyWinChange is called to notify all members in the party that the PTY // size has changed. The notification is sent as a global SSH request and it // is the responsibility of the client to update it's window size upon receipt. @@ -1390,12 +1409,6 @@ func (s *session) checkIfStart() (bool, auth.PolicyOptions, error) { // addParty is called when a new party joins the session. func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { - if s.login != p.login { - return trace.AccessDenied( - "can't switch users from %v to %v for session %v", - s.login, p.login, s.id) - } - s.mu.Lock() defer s.mu.Unlock() diff --git a/lib/srv/termhandlers.go b/lib/srv/termhandlers.go index 6633d46918420..ddce856b2208a 100644 --- a/lib/srv/termhandlers.go +++ b/lib/srv/termhandlers.go @@ -17,6 +17,8 @@ limitations under the License. package srv import ( + "encoding/json" + "golang.org/x/crypto/ssh" rsession "github.com/gravitational/teleport/lib/session" @@ -141,6 +143,22 @@ func (t *TermHandlers) HandleForceTerminate(ch ssh.Channel, req *ssh.Request, ct return trace.Wrap(err) } +func (t *TermHandlers) HandleTerminalSize(req *ssh.Request) error { + sessionID := string(req.Payload) + size, err := t.SessionRegistry.GetTerminalSize(sessionID) + if err != nil { + return trace.Wrap(err) + } + + payload, err := json.Marshal(size) + if err != nil { + return trace.Wrap(err) + } + + req.Reply(true, payload) + return nil +} + func parseExecRequest(req *ssh.Request, ctx *ServerContext) (Exec, error) { var err error diff --git a/tool/tsh/kube.go b/tool/tsh/kube.go index 33d397592c237..701ee71b50d06 100644 --- a/tool/tsh/kube.go +++ b/tool/tsh/kube.go @@ -135,7 +135,7 @@ func (c *kubeJoinCommand) run(cf *CLIConf) error { return trace.Wrap(err) } - cluster := meta.GetClustername() + cluster := meta.GetClusterName() kubeCluster := meta.GetKubeCluster() var k *client.Key