diff --git a/api/gen/proto/go/teleport/decision/v1alpha1/ssh_identity.pb.go b/api/gen/proto/go/teleport/decision/v1alpha1/ssh_identity.pb.go index 8045b6575cdf9..5ebc9995e65c6 100644 --- a/api/gen/proto/go/teleport/decision/v1alpha1/ssh_identity.pb.go +++ b/api/gen/proto/go/teleport/decision/v1alpha1/ssh_identity.pb.go @@ -289,7 +289,9 @@ type SSHIdentity struct { JoinToken string `protobuf:"bytes,34,opt,name=join_token,json=joinToken,proto3" json:"join_token,omitempty"` // ScopePin is an optional pin that ties the certificate to a specific scope and set of scoped roles. When // set, the Roles field must not be set. - ScopePin *v11.Pin `protobuf:"bytes,35,opt,name=scope_pin,json=scopePin,proto3" json:"scope_pin,omitempty"` + ScopePin *v11.Pin `protobuf:"bytes,35,opt,name=scope_pin,json=scopePin,proto3" json:"scope_pin,omitempty"` + // The scope associated with a host identity. + AgentScope string `protobuf:"bytes,36,opt,name=agent_scope,json=agentScope,proto3" json:"agent_scope,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -569,6 +571,13 @@ func (x *SSHIdentity) GetScopePin() *v11.Pin { return nil } +func (x *SSHIdentity) GetAgentScope() string { + if x != nil { + return x.AgentScope + } + return "" +} + // CertExtension represents a key/value for a certificate extension. This type must // be kept up to date with types.CertExtension. type CertExtension struct { @@ -654,7 +663,7 @@ const file_teleport_decision_v1alpha1_ssh_identity_proto_rawDesc = "" + "-teleport/decision/v1alpha1/ssh_identity.proto\x12\x1ateleport.decision.v1alpha1\x1a\x1fgoogle/protobuf/timestamp.proto\x1a-teleport/decision/v1alpha1/tls_identity.proto\x1a\x1fteleport/scopes/v1/scopes.proto\x1a\x1dteleport/trait/v1/trait.proto\"X\n" + "\fSSHAuthority\x12!\n" + "\fcluster_name\x18\x01 \x01(\tR\vclusterName\x12%\n" + - "\x0eauthority_type\x18\x02 \x01(\tR\rauthorityType\"\xef\v\n" + + "\x0eauthority_type\x18\x02 \x01(\tR\rauthorityType\"\x90\f\n" + "\vSSHIdentity\x12\x1f\n" + "\vvalid_after\x18\x01 \x01(\x04R\n" + "validAfter\x12!\n" + @@ -698,7 +707,9 @@ const file_teleport_decision_v1alpha1_ssh_identity_proto_rawDesc = "" + "\x0fgithub_username\x18! \x01(\tR\x0egithubUsername\x12\x1d\n" + "\n" + "join_token\x18\" \x01(\tR\tjoinToken\x124\n" + - "\tscope_pin\x18# \x01(\v2\x17.teleport.scopes.v1.PinR\bscopePin\"\xbf\x01\n" + + "\tscope_pin\x18# \x01(\v2\x17.teleport.scopes.v1.PinR\bscopePin\x12\x1f\n" + + "\vagent_scope\x18$ \x01(\tR\n" + + "agentScope\"\xbf\x01\n" + "\rCertExtension\x12A\n" + "\x04type\x18\x01 \x01(\x0e2-.teleport.decision.v1alpha1.CertExtensionTypeR\x04type\x12A\n" + "\x04mode\x18\x02 \x01(\x0e2-.teleport.decision.v1alpha1.CertExtensionModeR\x04mode\x12\x12\n" + diff --git a/api/proto/teleport/decision/v1alpha1/ssh_identity.proto b/api/proto/teleport/decision/v1alpha1/ssh_identity.proto index dba0207bb8290..add9a21342d18 100644 --- a/api/proto/teleport/decision/v1alpha1/ssh_identity.proto +++ b/api/proto/teleport/decision/v1alpha1/ssh_identity.proto @@ -164,6 +164,9 @@ message SSHIdentity { // ScopePin is an optional pin that ties the certificate to a specific scope and set of scoped roles. When // set, the Roles field must not be set. teleport.scopes.v1.Pin scope_pin = 35; + + // The scope associated with a host identity. + string agent_scope = 36; } // CertExtensionMode specifies the type of extension to use in the cert. This type diff --git a/api/types/provisioning.go b/api/types/provisioning.go index 485c0131b2308..bd63f564ffb97 100644 --- a/api/types/provisioning.go +++ b/api/types/provisioning.go @@ -178,6 +178,11 @@ type ProvisionToken interface { // join methods where the name is secret. This should be used when logging // the token name. GetSafeName() string + + // GetAssignedScope always returns an empty string because a [ProvisionToken] is always + // unscoped + GetAssignedScope() string + // Clone creates a copy of the token. Clone() ProvisionToken } @@ -652,6 +657,12 @@ func (p *ProvisionTokenV2) GetSafeName() string { return name } +// GetAssignedScope always returns an empty string because a [ProvisionTokenV2] is always +// unscoped +func (p *ProvisionTokenV2) GetAssignedScope() string { + return "" +} + // String returns the human readable representation of a provisioning token. func (p ProvisionTokenV2) String() string { expires := "never" diff --git a/constants.go b/constants.go index 73365df4a381b..47b4defc1846f 100644 --- a/constants.go +++ b/constants.go @@ -464,9 +464,14 @@ const ( ) const ( - // CertExtensionScopePin is used to pin a certificate to a specific scope and - // set of scoped roles. + // CertExtensionScopePin is used to pin a user certificate to a specific scope and + // set of scoped roles. This constrains a user's access to resources based on both + // the scoping rules and scoped roles defined. CertExtensionScopePin = "scope-pin@goteleport.com" + // CertExtensionAgentScope is used to pin an agent/host certificate to a specific scope. + // This constrains other identities' access to the agent itself as well as the agent's + // access to other resources based on scoping rules. + CertExtensionAgentScope = "agent-scope@goteleport.com" // CertExtensionPermitX11Forwarding allows X11 forwarding for certificate CertExtensionPermitX11Forwarding = "permit-X11-forwarding" // CertExtensionPermitAgentForwarding allows agent forwarding for certificate diff --git a/integration/utmp_integration_test.go b/integration/utmp_integration_test.go index 46125dd2094df..53897c4b00ebb 100644 --- a/integration/utmp_integration_test.go +++ b/integration/utmp_integration_test.go @@ -264,7 +264,7 @@ func newSrvCtx(ctx context.Context, t *testing.T) *SrvCtx { Role: types.RoleNode, PublicSSHKey: sshPublicKey, PublicTLSKey: tlsPublicKey, - }) + }, "") require.NoError(t, err) // set up user CA and set up a user that has access to the server diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 62556ddfa227b..f2a5e020588f5 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -5128,7 +5128,7 @@ func ExtractHostID(hostName string, clusterName string) (string, error) { // GenerateHostCerts generates new host certificates (signed // by the host certificate authority) for a node. -func (a *Server) GenerateHostCerts(ctx context.Context, req *proto.HostCertsRequest) (*proto.Certs, error) { +func (a *Server) GenerateHostCerts(ctx context.Context, req *proto.HostCertsRequest, scope string) (*proto.Certs, error) { if err := req.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } @@ -5262,6 +5262,7 @@ func (a *Server) GenerateHostCerts(ctx context.Context, req *proto.HostCertsRequ ClusterName: clusterName.GetClusterName(), SystemRole: req.Role, Principals: req.AdditionalPrincipals, + AgentScope: scope, }, }) if err != nil { @@ -5283,6 +5284,7 @@ func (a *Server) GenerateHostCerts(ctx context.Context, req *proto.HostCertsRequ Groups: []string{req.Role.String()}, TeleportCluster: clusterName.GetClusterName(), SystemRoles: systemRoles, + AgentScope: scope, } subject, err := identity.Subject() if err != nil { @@ -5315,6 +5317,7 @@ func (a *Server) GenerateHostCerts(ctx context.Context, req *proto.HostCertsRequ if err != nil { return nil, trace.Wrap(err) } + return &proto.Certs{ SSH: hostSSHCert, TLS: hostTLSCert, diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index bea1429393fc3..ee9d782f1ed46 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -660,7 +660,8 @@ func (a *ServerWithRoles) GenerateHostCerts(ctx context.Context, req *proto.Host if !a.hasBuiltinRole(req.Role) && req.Role != types.RoleInstance { return nil, trace.AccessDenied("roles do not match: %v and %v", existingRoles, req.Role) } - return a.authServer.GenerateHostCerts(ctx, req) + identity := a.context.Identity.GetIdentity() + return a.authServer.GenerateHostCerts(ctx, req, identity.AgentScope) } // checkAdditionalSystemRoles verifies additional system roles in host cert request. diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index ff826aa20a23c..2499dc50bbe05 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -7246,6 +7246,66 @@ func TestGenerateHostCert(t *testing.T) { } } +// newScopedTestServerForHost creates a self-cleaning `ServerWithRoles`, configured +// for a given host +func newScopedTestServerForHost(t *testing.T, srv *authtest.AuthServer, hostID, scope string, role types.SystemRole) *auth.ServerWithRoles { + authzContext := authz.ContextWithUser(t.Context(), authtest.TestScopedHost(srv.ClusterName, hostID, scope, role).I) + ctxIdentity, err := srv.Authorizer.Authorize(authzContext) + require.NoError(t, err) + + authWithRole := auth.NewServerWithRoles( + srv.AuthServer, + srv.AuditLog, + *ctxIdentity, + ) + + t.Cleanup(func() { authWithRole.Close() }) + + return authWithRole +} + +func TestGenerateHostCertsScoped(t *testing.T) { + ctx := context.Background() + srv := newTestTLSServer(t) + + scope := "/aa/bb" + hostID := "testhost" + roles := types.SystemRoles{types.RoleNode} + + s := newScopedTestServerForHost(t, srv.AuthServer, hostID, scope, types.RoleNode) + + _, sshPub, err := testauthority.New().GenerateKeyPair() + require.NoError(t, err) + tlsKey, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.ECDSAP256) + require.NoError(t, err) + tlsPubPEM, err := keys.MarshalPublicKey(tlsKey.Public()) + require.NoError(t, err) + + certs, err := s.GenerateHostCerts(ctx, &proto.HostCertsRequest{ + PublicTLSKey: tlsPubPEM, + PublicSSHKey: sshPub, + HostID: hostID, + Role: types.RoleInstance, + SystemRoles: roles, + }) + require.NoError(t, err) + + // ensure scope encoded in TLS cert matches the auth identity + tlsCert, err := tlsca.ParseCertificatePEM(certs.TLS) + require.NoError(t, err) + + tlsIdent, err := tlsca.FromSubject(tlsCert.Subject, tlsCert.NotAfter) + require.NoError(t, err) + require.Equal(t, scope, tlsIdent.AgentScope) + + // ensure scope encoded in SSH cert matches the auth identity + sshCert, err := sshutils.ParseCertificate(certs.SSH) + require.NoError(t, err) + sshIdent, err := sshca.DecodeIdentity(sshCert) + require.NoError(t, err) + require.Equal(t, scope, sshIdent.AgentScope) +} + // TestLocalServiceRolesHavePermissionsForUploaderService verifies that all of Teleport's // builtin roles have permissions to execute the calls required by the uploader service. // This is because only one uploader service runs per Teleport process, and it will use diff --git a/lib/auth/authtest/authtest.go b/lib/auth/authtest/authtest.go index 0e5cfa829b36a..59b005350290f 100644 --- a/lib/auth/authtest/authtest.go +++ b/lib/auth/authtest/authtest.go @@ -696,7 +696,7 @@ func generateCertificate(authServer *auth.Server, identity TestIdentity) ([]byte PublicTLSKey: tlsPublicKeyPEM, PublicSSHKey: sshPublicKeyPEM, SystemRoles: id.AdditionalSystemRoles, - }) + }, "") if err != nil { return nil, nil, trace.Wrap(err) } @@ -709,7 +709,7 @@ func generateCertificate(authServer *auth.Server, identity TestIdentity) ([]byte Role: id.Role, PublicTLSKey: tlsPublicKeyPEM, PublicSSHKey: sshPublicKeyPEM, - }) + }, "") if err != nil { return nil, nil, trace.Wrap(err) } @@ -1044,6 +1044,24 @@ func TestBuiltin(role types.SystemRole) TestIdentity { } } +// TestScopedHost returns TestIdentity for a scoped host +func TestScopedHost(clusterName, hostID, scope string, role types.SystemRole) TestIdentity { + username := hostID + if clusterName != "" { + username = utils.HostFQDN(hostID, clusterName) + } + return TestIdentity{ + I: authz.BuiltinRole{ + Role: types.RoleInstance, + Username: username, + AdditionalSystemRoles: types.SystemRoles{role}, + Identity: tlsca.Identity{ + AgentScope: scope, + }, + }, + } +} + // TestServerID returns a TestIdentity for a node with the passed in serverID. func TestServerID(role types.SystemRole, serverID string) TestIdentity { return TestIdentity{ @@ -1321,7 +1339,7 @@ func NewServerIdentity(clt *auth.Server, hostID string, role types.SystemRole) ( Role: role, PublicSSHKey: ssh.MarshalAuthorizedKey(sshPubKey), PublicTLSKey: tlsPubKey, - }) + }, "") if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index ec84f0f2f6843..aaab17cb350c7 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -6070,9 +6070,10 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { authpb.RegisterJoinServiceServer(server, legacyJoinServiceServer) joinv1.RegisterJoinServiceServer(server, join.NewServer(&join.ServerConfig{ - Authorizer: cfg.Authorizer, - AuthService: cfg.AuthServer, - FIPS: cfg.AuthServer.fips, + Authorizer: cfg.Authorizer, + AuthService: cfg.AuthServer, + FIPS: cfg.AuthServer.fips, + ScopedTokenService: cfg.AuthServer.Services, })) integrationServiceServer, err := integrationv1.NewService(&integrationv1.ServiceConfig{ diff --git a/lib/auth/init.go b/lib/auth/init.go index bd58f1c055dfd..ec7ab01290ad9 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -1585,7 +1585,7 @@ func GenerateIdentity(a *Server, id state.IdentityID, additionalPrincipals, dnsN DNSNames: dnsNames, PublicSSHKey: ssh.MarshalAuthorizedKey(sshPub), PublicTLSKey: tlsPub, - }) + }, "") if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/join.go b/lib/auth/join.go index e630d0bcfa176..9ec86812d0dcf 100644 --- a/lib/auth/join.go +++ b/lib/auth/join.go @@ -44,6 +44,7 @@ import ( "github.com/gravitational/teleport/lib/join" "github.com/gravitational/teleport/lib/join/joinutils" "github.com/gravitational/teleport/lib/join/legacyjoin" + "github.com/gravitational/teleport/lib/join/provision" ) // checkTokenJoinRequestCommon checks all token join rules that are common to @@ -66,7 +67,7 @@ func (a *Server) checkTokenJoinRequestCommon(ctx context.Context, req *types.Reg return nil, trace.AccessDenied("%q can not join the cluster with role %q, %s", req.NodeName, req.Role, msg) } - if err := join.ProvisionTokenAllowsRole(provisionToken, req.Role); err != nil { + if err := join.TokenAllowsRole(provisionToken, req.Role); err != nil { return nil, trace.Wrap(err) } return provisionToken, nil @@ -348,13 +349,13 @@ func makeBotCertsParams(req *types.RegisterUsingTokenRequest, rawClaims any, att // of a cluster join attempt. func (a *Server) GenerateBotCertsForJoin( ctx context.Context, - provisionToken types.ProvisionToken, + token provision.Token, params *join.BotCertsParams, ) (*proto.Certs, string, error) { // bots use this endpoint but get a user cert // botResourceName must be set, enforced in CheckAndSetDefaults - botName := provisionToken.GetBotName() - joinMethod := provisionToken.GetJoinMethod() + botName := token.GetBotName() + joinMethod := token.GetJoinMethod() // Check this is a join method for bots we support. if !slices.Contains(machineidv1.SupportedJoinMethods, joinMethod) { @@ -391,7 +392,7 @@ func (a *Server) GenerateBotCertsForJoin( }, BotName: botName, Method: string(joinMethod), - TokenName: provisionToken.GetSafeName(), + TokenName: token.GetSafeName(), UserName: machineidv1.BotResourceName(botName), ConnectionMetadata: apievents.ConnectionMetadata{ RemoteAddr: params.RemoteAddr, @@ -416,7 +417,7 @@ func (a *Server) GenerateBotCertsForJoin( JoinMethod: string(joinMethod), } if joinMethod != types.JoinMethodToken { - params.Attrs.Meta.JoinTokenName = provisionToken.GetName() + params.Attrs.Meta.JoinTokenName = token.GetName() } auth := &machineidv1pb.BotInstanceStatusAuthentication{ @@ -424,8 +425,8 @@ func (a *Server) GenerateBotCertsForJoin( // TODO: GetSafeName may not return an appropriate value for later // comparison / locking purposes, and this also shouldn't contain // secrets. Should we hash it? - JoinToken: provisionToken.GetSafeName(), - JoinMethod: string(provisionToken.GetJoinMethod()), + JoinToken: token.GetSafeName(), + JoinMethod: string(token.GetJoinMethod()), PublicKey: params.PublicTLSKey, JoinAttrs: params.Attrs, } @@ -458,9 +459,9 @@ func (a *Server) GenerateBotCertsForJoin( if shouldDeleteToken { // delete ephemeral bot join tokens so they can't be re-used - if err := a.DeleteToken(ctx, provisionToken.GetName()); err != nil { + if err := a.DeleteToken(ctx, token.GetName()); err != nil { a.logger.WarnContext(ctx, "Could not delete bot provision token after generating certs", - "provision_token", provisionToken.GetSafeName(), + "provision_token", token.GetSafeName(), "error", err, ) } @@ -481,7 +482,7 @@ func (a *Server) GenerateBotCertsForJoin( // result of a cluster join attempt. func (a *Server) GenerateHostCertsForJoin( ctx context.Context, - provisionToken types.ProvisionToken, + token provision.Token, params *join.HostCertsParams, ) (*proto.Certs, error) { // instance certs include an additional field that specifies the list of @@ -489,7 +490,7 @@ func (a *Server) GenerateHostCertsForJoin( var systemRoles types.SystemRoles if params.SystemRole == types.RoleInstance { systemRolesSet := make(map[types.SystemRole]struct{}) - for _, r := range provisionToken.GetRoles() { + for _, r := range token.GetRoles() { if r.IsLocalService() { systemRolesSet[r] = struct{}{} } else { @@ -518,7 +519,7 @@ func (a *Server) GenerateHostCertsForJoin( RemoteAddr: params.RemoteAddr, DNSNames: params.DNSNames, SystemRoles: systemRoles, - }) + }, token.GetAssignedScope()) if err != nil { return nil, trace.Wrap(err) } @@ -548,9 +549,9 @@ func (a *Server) GenerateHostCertsForJoin( }, NodeName: params.HostName, Role: string(params.SystemRole), - Method: string(provisionToken.GetJoinMethod()), - TokenName: provisionToken.GetSafeName(), - TokenExpires: provisionToken.Expiry(), + Method: string(token.GetJoinMethod()), + TokenName: token.GetSafeName(), + TokenExpires: token.Expiry(), HostID: params.HostID, ConnectionMetadata: apievents.ConnectionMetadata{ RemoteAddr: params.RemoteAddr, diff --git a/lib/auth/register.go b/lib/auth/register.go index 43120e2c16428..fa4cd12d60489 100644 --- a/lib/auth/register.go +++ b/lib/auth/register.go @@ -70,7 +70,7 @@ func LocalRegister(id state.IdentityID, authServer *Server, additionalPrincipals PublicSSHKey: ssh.MarshalAuthorizedKey(sshPub), PublicTLSKey: tlsPub, SystemRoles: systemRoles, - }) + }, "") if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/scopes/joining/service_test.go b/lib/auth/scopes/joining/service_test.go index 526b18c7774a0..66302ed014576 100644 --- a/lib/auth/scopes/joining/service_test.go +++ b/lib/auth/scopes/joining/service_test.go @@ -30,7 +30,6 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" - "github.com/gravitational/teleport/api/defaults" headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" joiningv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/scopes/joining/v1" scopesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/scopes/v1" @@ -63,8 +62,7 @@ func TestScopedJoiningService(t *testing.T) { Kind: types.KindScopedToken, Version: types.V1, Metadata: &headerv1.Metadata{ - Name: "testtoken", - Namespace: defaults.Namespace, + Name: "testtoken", }, Scope: "/test", Spec: &joiningv1.ScopedTokenSpec{ diff --git a/lib/auth/state/identity.go b/lib/auth/state/identity.go index 4c99fbe5520c9..6d06afcc549de 100644 --- a/lib/auth/state/identity.go +++ b/lib/auth/state/identity.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/trace" "golang.org/x/crypto/ssh" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client/proto" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" @@ -62,7 +63,7 @@ func (id *IdentityID) String() string { // Identity is collection of certificates and signers that represent server identity type Identity struct { - // ID specifies server unique ID, name and role + // ID specifies server unique ID, name, role, and scope ID IdentityID // KeyBytes is a PEM encoded private key KeyBytes []byte @@ -85,6 +86,8 @@ type Identity struct { ClusterName string // SystemRoles is a list of additional system roles. SystemRoles []string + // AgentScope is the scope an identity is constrained to. + AgentScope string } // HasSystemRole checks if this identity encompasses the supplied system role. @@ -330,6 +333,7 @@ func ReadSSHIdentityFromKeyPair(keyBytes, certBytes []byte) (*Identity, error) { return nil, trace.BadParameter("missing cert extension %v", utils.CertExtensionAuthority) } + agentScope := cert.Permissions.Extensions[teleport.CertExtensionAgentScope] return &Identity{ ID: IdentityID{HostUUID: cert.ValidPrincipals[0], Role: role}, ClusterName: clusterName, @@ -337,5 +341,6 @@ func ReadSSHIdentityFromKeyPair(keyBytes, certBytes []byte) (*Identity, error) { CertBytes: certBytes, KeySigner: certSigner, Cert: cert, + AgentScope: agentScope, }, nil } diff --git a/lib/decision/ssh_identity.go b/lib/decision/ssh_identity.go index 8ed4c6a0b7607..a6a4cf3bb4f51 100644 --- a/lib/decision/ssh_identity.go +++ b/lib/decision/ssh_identity.go @@ -67,6 +67,7 @@ func SSHIdentityToSSHCA(id *decisionpb.SSHIdentity) *sshca.Identity { DeviceCredentialID: id.DeviceCredentialId, GitHubUserID: id.GithubUserId, GitHubUsername: id.GithubUsername, + AgentScope: id.AgentScope, } } @@ -111,6 +112,7 @@ func SSHIdentityFromSSHCA(id *sshca.Identity) *decisionpb.SSHIdentity { GithubUserId: id.GitHubUserID, GithubUsername: id.GitHubUsername, JoinToken: id.JoinToken, + AgentScope: id.AgentScope, } } diff --git a/lib/decision/ssh_identity_test.go b/lib/decision/ssh_identity_test.go index 45686eb47fa8f..c16e0141be2f3 100644 --- a/lib/decision/ssh_identity_test.go +++ b/lib/decision/ssh_identity_test.go @@ -88,6 +88,7 @@ func TestSSHIdentityConversion(t *testing.T) { DeviceCredentialID: "cred", GitHubUserID: "github", GitHubUsername: "ghuser", + AgentScope: "/foo", } ignores := []string{ diff --git a/lib/join/boundkeypair/boundkeypair.go b/lib/join/boundkeypair/boundkeypair.go index dd9c8527429a3..f58c11d472849 100644 --- a/lib/join/boundkeypair/boundkeypair.go +++ b/lib/join/boundkeypair/boundkeypair.go @@ -40,6 +40,7 @@ import ( "github.com/gravitational/teleport/lib/join/internal/authz" "github.com/gravitational/teleport/lib/join/internal/diagnostic" "github.com/gravitational/teleport/lib/join/internal/messages" + "github.com/gravitational/teleport/lib/join/provision" "github.com/gravitational/teleport/lib/jwt" "github.com/gravitational/teleport/lib/services/readonly" libsshutils "github.com/gravitational/teleport/lib/sshutils" @@ -529,7 +530,7 @@ type JoinParams struct { // Diag is the join attempt diagnostic. Diag *diagnostic.Diagnostic // ProvisionToken is the provision token used for the join attempt. - ProvisionToken types.ProvisionToken + ProvisionToken provision.Token // ClientInit is the ClientInit message sent by the joining client. ClientInit *messages.ClientInit // BoundKeypairInit is the BoundKeypairInit message sent by the joining client. diff --git a/lib/join/ec2join/ec2.go b/lib/join/ec2join/ec2.go index 854467e7b5e34..24cdae10a639a 100644 --- a/lib/join/ec2join/ec2.go +++ b/lib/join/ec2join/ec2.go @@ -41,6 +41,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/join/provision" "github.com/gravitational/teleport/lib/services" awsutils "github.com/gravitational/teleport/lib/utils/aws" "github.com/gravitational/teleport/lib/utils/aws/stsutils" @@ -157,7 +158,7 @@ func parseAndVerifyIID(iidBytes []byte) (*imds.InstanceIdentityDocument, error) return &iid, nil } -func checkPendingTime(iid *imds.InstanceIdentityDocument, provisionToken types.ProvisionToken, clock clockwork.Clock) error { +func checkPendingTime(iid *imds.InstanceIdentityDocument, provisionToken provision.Token, clock clockwork.Clock) error { timeSinceInstanceStart := clock.Since(iid.PendingTime) // Sanity check IID is not from the future. Allow 1 minute of clock drift. if timeSinceInstanceStart < -1*time.Minute { @@ -342,7 +343,7 @@ func tryToDetectIdentityReuse(ctx context.Context, params *CheckEC2RequestParams // CheckEC2RequestParams holds parameters for checking an EC2-method join request. type CheckEC2RequestParams struct { // ProvisionToken is the provision token being used. - ProvisionToken types.ProvisionToken + ProvisionToken provision.Token // Role is the system role being requested. Role types.SystemRole // Document is a signed EC2 Instance Identity Document. diff --git a/lib/join/iamjoin/iam.go b/lib/join/iamjoin/iam.go index c676042608439..9a8a7f1772ea2 100644 --- a/lib/join/iamjoin/iam.go +++ b/lib/join/iamjoin/iam.go @@ -35,6 +35,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/join/iam" "github.com/gravitational/teleport/lib/join/joinutils" + "github.com/gravitational/teleport/lib/join/provision" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/aws" ) @@ -265,7 +266,7 @@ type CheckIAMRequestParams struct { // Challenge is the challenge that was sent to the client. Challenge string // ProvisionToken is the provision token being used. - ProvisionToken types.ProvisionToken + ProvisionToken provision.Token // STSIdentityRequest is the signed sts:GetCallerIdentity request sent by // the client in response to the challenge. STSIdentityRequest []byte diff --git a/lib/join/internal/authz/authz.go b/lib/join/internal/authz/authz.go index 54fc4d92abfa1..b30a92894ceb1 100644 --- a/lib/join/internal/authz/authz.go +++ b/lib/join/internal/authz/authz.go @@ -39,4 +39,6 @@ type Context struct { BotGeneration uint64 // BotInstanceID is an authenticated Bot ID. BotInstanceID string + // Scope is the assigned scope of the authenticated client. + Scope string } diff --git a/lib/join/join_test.go b/lib/join/join_test.go index 2c1d071e2ab30..4503cc37b7781 100644 --- a/lib/join/join_test.go +++ b/lib/join/join_test.go @@ -32,10 +32,13 @@ import ( "golang.org/x/net/http2" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" "github.com/gravitational/teleport/api/constants" + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" joinv1proto "github.com/gravitational/teleport/api/gen/proto/go/teleport/join/v1" + joiningv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/scopes/joining/v1" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" apiutils "github.com/gravitational/teleport/api/utils" @@ -52,7 +55,7 @@ import ( "github.com/gravitational/teleport/lib/utils/testutils" ) -// TestJoin tests the full cycle of proxy and node joining via the join service. +// TestJoinToken tests the full cycle of proxy and node joining via the join service. // // It first sets up a fake auth service running the gRPC join service. // @@ -62,7 +65,7 @@ import ( // // Finally, it tests various scenarios where a node attempts to join by // connecting to the proxy's gRPC join service. -func TestJoin(t *testing.T) { +func TestJoinToken(t *testing.T) { t.Parallel() token1, err := types.NewProvisionTokenFromSpec("token1", time.Now().Add(time.Minute), types.ProvisionTokenSpecV2{ @@ -84,10 +87,45 @@ func TestJoin(t *testing.T) { }) require.NoError(t, err) + scopedToken1 := &joiningv1.ScopedToken{ + Kind: types.KindScopedToken, + Version: types.V1, + Scope: "/aa", + Metadata: &headerv1.Metadata{ + Name: "scoped1", + }, + Spec: &joiningv1.ScopedTokenSpec{ + AssignedScope: "/aa/bb", + Roles: []string{types.RoleNode.String()}, + JoinMethod: string(types.JoinMethodToken), + }, + } + scopedToken2 := proto.CloneOf(scopedToken1) + scopedToken2.Metadata.Name = "scoped2" + scopedToken2.Spec.AssignedScope = "/aa/cc" + + scopedToken3 := proto.CloneOf(scopedToken1) + scopedToken3.Metadata.Name = "scoped3" + authService := newFakeAuthService(t) require.NoError(t, authService.Auth().UpsertToken(t.Context(), token1)) require.NoError(t, authService.Auth().UpsertToken(t.Context(), token2)) + _, err = authService.Auth().CreateScopedToken(t.Context(), &joiningv1.CreateScopedTokenRequest{ + Token: scopedToken1, + }) + require.NoError(t, err) + + _, err = authService.Auth().CreateScopedToken(t.Context(), &joiningv1.CreateScopedTokenRequest{ + Token: scopedToken2, + }) + require.NoError(t, err) + + _, err = authService.Auth().CreateScopedToken(t.Context(), &joiningv1.CreateScopedTokenRequest{ + Token: scopedToken3, + }) + require.NoError(t, err) + proxy := newFakeProxy(authService) proxy.join(t) proxyListener, err := net.Listen("tcp", "127.0.0.1:0") @@ -175,6 +213,89 @@ func TestJoin(t *testing.T) { require.ElementsMatch(t, expectedSystemRoles, newIdentity.SystemRoles) }) + t.Run("join and rejoin with scoped token", func(t *testing.T) { + // Node initially joins by connecting to the proxy's gRPC service. + identity, err := joinViaProxy( + t.Context(), + scopedToken1.GetMetadata().GetName(), + proxyListener.Addr(), + ) + require.NoError(t, err) + // Make sure the result contains a host ID and expected certificate roles. + require.NotEmpty(t, identity.ID.HostUUID) + require.Equal(t, types.RoleInstance, identity.ID.Role) + expectedSystemRoles := slices.DeleteFunc( + scopedToken1.GetSpec().GetRoles(), + func(s string) bool { return s == types.RoleInstance.String() }, + ) + require.ElementsMatch(t, expectedSystemRoles, identity.SystemRoles) + + require.Equal(t, scopedToken1.GetSpec().GetAssignedScope(), identity.AgentScope) + // Build an auth client with the new identity. + tlsConfig, err := identity.TLSConfig(nil /*cipherSuites*/) + require.NoError(t, err) + authClient, err := authService.TLS.NewClientWithCert(tlsConfig.Certificates[0]) + require.NoError(t, err) + + // Node can rejoin with a different token assigning the same scope + // by dialing the auth service with an auth client authenticated with + // its original credentials. + // + // It should get back its original host ID and the combined roles of + // its original certificate and the new token. + newIdentity, err := rejoinViaAuthClient( + t.Context(), + scopedToken3.GetMetadata().GetName(), + authClient, + ) + require.NoError(t, err) + require.Equal(t, identity.AgentScope, newIdentity.AgentScope) + require.Equal(t, identity.ID.HostUUID, newIdentity.ID.HostUUID) + require.Equal(t, identity.ID.NodeName, newIdentity.ID.NodeName) + require.Equal(t, identity.ID.Role, newIdentity.ID.Role) + expectedSystemRoles = slices.DeleteFunc( + apiutils.Deduplicate(slices.Concat( + scopedToken1.GetSpec().GetRoles(), + scopedToken3.GetSpec().GetRoles(), + )), + func(s string) bool { return s == types.RoleInstance.String() }, + ) + require.ElementsMatch(t, expectedSystemRoles, newIdentity.SystemRoles) + }) + + t.Run("join and rejoin with mismatched scoped tokens", func(t *testing.T) { + // Node initially joins by connecting to the proxy's gRPC service. + identity, err := joinViaProxy( + t.Context(), + scopedToken1.GetMetadata().GetName(), + proxyListener.Addr(), + ) + require.NoError(t, err) + // Make sure the result contains a host ID and expected certificate roles. + require.NotEmpty(t, identity.ID.HostUUID) + require.Equal(t, types.RoleInstance, identity.ID.Role) + expectedSystemRoles := slices.DeleteFunc( + scopedToken1.GetSpec().GetRoles(), + func(s string) bool { return s == types.RoleInstance.String() }, + ) + require.ElementsMatch(t, expectedSystemRoles, identity.SystemRoles) + + require.Equal(t, scopedToken1.GetSpec().GetAssignedScope(), identity.AgentScope) + // Build an auth client with the new identity. + tlsConfig, err := identity.TLSConfig(nil /*cipherSuites*/) + require.NoError(t, err) + authClient, err := authService.TLS.NewClientWithCert(tlsConfig.Certificates[0]) + require.NoError(t, err) + + // Node cannot rejoin with a different token assigning a different scope. + _, err = rejoinViaAuthClient( + t.Context(), + scopedToken2.GetMetadata().GetName(), + authClient, + ) + require.Error(t, err) + }) + t.Run("join and rejoin with bad token", func(t *testing.T) { // Node joins by connecting to the proxy's gRPC service. identity, err := joinViaProxy( diff --git a/lib/join/provision/token.go b/lib/join/provision/token.go new file mode 100644 index 0000000000000..59cf1823082d2 --- /dev/null +++ b/lib/join/provision/token.go @@ -0,0 +1,49 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package provision + +import ( + "time" + + "github.com/gravitational/teleport/api/types" +) + +// A Token is used in the join service to facilitate provisioning. +type Token interface { + // GetName returns the name of the token. + GetName() string + // GetSafeName returns the name of the token, sanitized appropriately for + // join methods where the name is secret. This should be used when logging + // the token name. + GetSafeName() string + // GetJoinMethod returns joining method that must be used with this token. + GetJoinMethod() types.JoinMethod + // GetRoles returns a list of teleport roles that will be granted to the + // resources provisioned with this token. + GetRoles() types.SystemRoles + // Expiry returns the token's expiration time. + Expiry() time.Time + // GetBotName returns the BotName field which must be set for joining bots. + GetBotName() string + // GetAssignedScope returns the scope that will be assigned to provisioned resources + // provisioned using the wrapped [joiningv1.ScopedToken]. + GetAssignedScope() string + // GetAllowRules returns the list of allow rules. + GetAllowRules() []*types.TokenRule + // GetAWSIIDTTL returns the TTL of EC2 IIDs + GetAWSIIDTTL() types.Duration +} diff --git a/lib/join/server.go b/lib/join/server.go index 3ad1e40edc52c..f1990d8b95675 100644 --- a/lib/join/server.go +++ b/lib/join/server.go @@ -26,6 +26,7 @@ import ( "log/slog" "slices" "strings" + "sync" "time" "github.com/gravitational/trace" @@ -49,6 +50,8 @@ import ( "github.com/gravitational/teleport/lib/join/internal/diagnostic" "github.com/gravitational/teleport/lib/join/internal/messages" "github.com/gravitational/teleport/lib/join/joinutils" + "github.com/gravitational/teleport/lib/join/provision" + "github.com/gravitational/teleport/lib/scopes/joining" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/utils" @@ -62,8 +65,8 @@ var log = logutils.NewPackageLogger(teleport.ComponentKey, "join") // JoinServer to implement joining. type AuthService interface { ValidateToken(ctx context.Context, tokenName string) (types.ProvisionToken, error) - GenerateHostCertsForJoin(ctx context.Context, provisionToken types.ProvisionToken, req *HostCertsParams) (*proto.Certs, error) - GenerateBotCertsForJoin(ctx context.Context, provisionToken types.ProvisionToken, req *BotCertsParams) (*proto.Certs, string, error) + GenerateHostCertsForJoin(ctx context.Context, token provision.Token, req *HostCertsParams) (*proto.Certs, error) + GenerateBotCertsForJoin(ctx context.Context, token provision.Token, req *BotCertsParams) (*proto.Certs, string, error) EmitAuditEvent(ctx context.Context, e apievents.AuditEvent) error GetAuthPreference(ctx context.Context) (types.AuthPreference, error) GetReadOnlyAuthPreference(context.Context) (readonly.AuthPreference, error) @@ -82,9 +85,10 @@ type AuthService interface { // ServerConfig holds configuration parameters for [Server]. type ServerConfig struct { - AuthService AuthService - Authorizer authz.Authorizer - FIPS bool + AuthService AuthService + Authorizer authz.Authorizer + FIPS bool + ScopedTokenService services.ScopedTokenService } // Server implements cluster joining for nodes and bots. @@ -99,6 +103,63 @@ func NewServer(cfg *ServerConfig) *Server { } } +// getProvisionToken attempts to resolve a name to a [provision.Token] by first attempting to +// fetch a [joiningv1.ScopedToken] and then falling back to a [types.ProvisionTokenV2] if a +// scoped token can not be found. +func (s *Server) getProvisionToken(ctx context.Context, name string) (provision.Token, error) { + var scoped provision.Token + var scopedErr error + + var classic provision.Token + var classicErr error + + wg := &sync.WaitGroup{} + wg.Go(func() { + tok, err := s.cfg.ScopedTokenService.UseScopedToken(ctx, name) + if err != nil { + scopedErr = err + return + } + + scoped, scopedErr = joining.NewToken(tok) + }) + wg.Go(func() { + // Fetch the provision token and validate that it is not expired. + classic, classicErr = s.cfg.AuthService.ValidateToken(ctx, name) + }) + wg.Wait() + + // we explicitly disallow a join if the provided token name returns both a scoped and classic provision token + if scoped != nil && classic != nil { + return nil, trace.AccessDenied("joining with an ambiguous token name is not permitted") + } + + if scoped != nil { + return scoped, nil + } + + if classic != nil { + return classic, nil + } + + // if both errors are [trace.NotFoundError], just return a single err + if trace.IsNotFound(scopedErr) && trace.IsNotFound(classicErr) { + return nil, trace.NotFound("token expired or not found") + } + + // prefer reporting errors other than [trace.NotFoundError] + if trace.IsNotFound(scopedErr) { + return nil, trace.Wrap(classicErr) + } + + if trace.IsNotFound(classicErr) { + return nil, trace.Wrap(scopedErr) + } + + // return both errors as an aggregate if we couldn't reasonably return one + return nil, trace.NewAggregate(scopedErr, classicErr) +} + // Join implements cluster joining for nodes and bots. // // It returns credentials for a node or bot to join the Teleport cluster using @@ -145,32 +206,36 @@ func (s *Server) Join(stream messages.ServerStream) (err error) { return trace.Wrap(err) } - // Fetch the provision token and validate that it is not expired. - provisionToken, err := s.cfg.AuthService.ValidateToken(ctx, clientInit.TokenName) + token, err := s.getProvisionToken(ctx, clientInit.TokenName) if err != nil { return trace.Wrap(err) } + // Set any diagnostic info we can get from the token. diag.Set(func(i *diagnostic.Info) { - i.SafeTokenName = provisionToken.GetSafeName() - i.TokenJoinMethod = string(configuredJoinMethod(provisionToken)) - i.TokenExpires = provisionToken.Expiry() - i.BotName = provisionToken.GetBotName() + i.SafeTokenName = token.GetSafeName() + i.TokenJoinMethod = string(configuredJoinMethod(token)) + i.TokenExpires = token.Expiry() + i.BotName = token.GetBotName() }) // Validate that the requested join method matches the join method // configured on the token, or that the client did not specify a specific // join method and allow the server to choose it from the token. - joinMethod, err := checkJoinMethod(provisionToken, clientInit.JoinMethod) + joinMethod, err := checkJoinMethod(token, clientInit.JoinMethod) if err != nil { return trace.Wrap(err) } // Assert that the provision token allows the requested system role. - if err := ProvisionTokenAllowsRole(provisionToken, types.SystemRole(clientInit.SystemRole)); err != nil { + if err := TokenAllowsRole(token, types.SystemRole(clientInit.SystemRole)); err != nil { return trace.Wrap(err) } + if authCtx.IsInstance && authCtx.Scope != token.GetAssignedScope() { + return trace.BadParameter("tried to re-join instance from scope %q into %q", authCtx.Scope, token.GetAssignedScope()) + } + authPref, err := s.cfg.AuthService.GetAuthPreference(ctx) if err != nil { return trace.Wrap(err, "getting cluster auth preference") @@ -187,7 +252,7 @@ func (s *Server) Join(stream messages.ServerStream) (err error) { } // Call out to the handler for the specific join method. - result, err := s.handleJoinMethod(stream, authCtx, clientInit, provisionToken, joinMethod) + result, err := s.handleJoinMethod(stream, authCtx, clientInit, token, joinMethod) if err != nil { return trace.Wrap(err) } @@ -200,20 +265,20 @@ func (s *Server) handleJoinMethod( stream messages.ServerStream, authCtx *joinauthz.Context, clientInit *messages.ClientInit, - provisionToken types.ProvisionToken, + token provision.Token, joinMethod types.JoinMethod, ) (messages.Response, error) { switch joinMethod { case types.JoinMethodToken: - return s.handleTokenJoin(stream, authCtx, clientInit, provisionToken) + return s.handleTokenJoin(stream, authCtx, clientInit, token) case types.JoinMethodBoundKeypair: - return s.handleBoundKeypairJoin(stream, authCtx, clientInit, provisionToken) + return s.handleBoundKeypairJoin(stream, authCtx, clientInit, token) case types.JoinMethodIAM: - return s.handleIAMJoin(stream, authCtx, clientInit, provisionToken) + return s.handleIAMJoin(stream, authCtx, clientInit, token) case types.JoinMethodEC2: - return s.handleEC2Join(stream, authCtx, clientInit, provisionToken) + return s.handleEC2Join(stream, authCtx, clientInit, token) case types.JoinMethodEnv0: - return s.handleOIDCJoin(stream, authCtx, clientInit, provisionToken, s.validateEnv0Token) + return s.handleOIDCJoin(stream, authCtx, clientInit, token, s.validateEnv0Token) default: // TODO(nklaassen): implement checks for all join methods. return nil, trace.NotImplemented("join method %s is not yet implemented by the new join service", joinMethod) @@ -290,11 +355,12 @@ func (s *Server) authenticate(ctx context.Context, diag *diagnostic.Diagnostic, HostID: hostID, BotInstanceID: botInstanceID, BotGeneration: botGeneration, + Scope: id.AgentScope, }, nil } -func checkJoinMethod(provisionToken types.ProvisionToken, requestedJoinMethod *string) (types.JoinMethod, error) { - tokenJoinMethod := configuredJoinMethod(provisionToken) +func checkJoinMethod(token provision.Token, requestedJoinMethod *string) (types.JoinMethod, error) { + tokenJoinMethod := configuredJoinMethod(token) if requestedJoinMethod == nil { // Auto join method mode, the client didn't specify so use whatever is on the token. return tokenJoinMethod, nil @@ -307,14 +373,14 @@ func checkJoinMethod(provisionToken types.ProvisionToken, requestedJoinMethod *s return tokenJoinMethod, nil } -// ProvisionTokenAllowsRole asserts that the given provision token allows the +// TokenAllowsRole asserts that the given provision token allows the // requested role, or else it returns an error. -func ProvisionTokenAllowsRole(provisionToken types.ProvisionToken, role types.SystemRole) error { +func TokenAllowsRole(token provision.Token, role types.SystemRole) error { // Instance certs can be requested if the provision token allows at least // one local service role (e.g. proxy, node, etc). if role == types.RoleInstance { hasLocalServiceRole := false - for _, role := range provisionToken.GetRoles() { + for _, role := range token.GetRoles() { if role.IsLocalService() { hasLocalServiceRole = true break @@ -326,7 +392,7 @@ func ProvisionTokenAllowsRole(provisionToken types.ProvisionToken, role types.Sy } // Make sure the caller is requesting a role allowed by the token. - if !provisionToken.GetRoles().Include(role) && role != types.RoleInstance { + if !token.GetRoles().Include(role) && role != types.RoleInstance { return trace.BadParameter("can not join the cluster, the token does not allow role %s", role) } @@ -339,15 +405,15 @@ func (s *Server) makeResult( authCtx *joinauthz.Context, clientInit *messages.ClientInit, clientParams *messages.ClientParams, - provisionToken types.ProvisionToken, + token provision.Token, rawClaims any, attrs *workloadidentityv1pb.JoinAttrs, ) (messages.Response, error) { switch types.SystemRole(clientInit.SystemRole) { case types.RoleInstance: - return s.makeHostResult(ctx, diag, authCtx, clientParams.HostParams, provisionToken, rawClaims) + return s.makeHostResult(ctx, diag, authCtx, clientParams.HostParams, token, rawClaims) case types.RoleBot: - result, _, err := s.makeBotResult(ctx, diag, authCtx, clientParams.BotParams, provisionToken, rawClaims, attrs) + result, _, err := s.makeBotResult(ctx, diag, authCtx, clientParams.BotParams, token, rawClaims, attrs) return result, trace.Wrap(err) default: return nil, trace.NotImplemented("new join service only supports Instance and Bot system roles, client requested %s", clientInit.SystemRole) @@ -359,14 +425,14 @@ func (s *Server) makeHostResult( diag *diagnostic.Diagnostic, authCtx *joinauthz.Context, hostParams *messages.HostParams, - provisionToken types.ProvisionToken, + token provision.Token, rawClaims any, ) (*messages.HostResult, error) { - certsParams, err := makeHostCertsParams(ctx, diag, authCtx, hostParams, configuredJoinMethod(provisionToken), rawClaims) + certsParams, err := makeHostCertsParams(ctx, diag, authCtx, hostParams, configuredJoinMethod(token), token.GetAssignedScope(), rawClaims) if err != nil { return nil, trace.Wrap(err) } - certs, err := s.cfg.AuthService.GenerateHostCertsForJoin(ctx, provisionToken, certsParams) + certs, err := s.cfg.AuthService.GenerateHostCertsForJoin(ctx, token, certsParams) if err != nil { return nil, trace.Wrap(err) } @@ -388,6 +454,7 @@ func makeHostCertsParams( authCtx *joinauthz.Context, hostParams *messages.HostParams, joinMethod types.JoinMethod, + scope string, rawClaims any, ) (*HostCertsParams, error) { // GenerateHostCertsForJoin requires the TLS key to be PEM-encoded. @@ -445,7 +512,7 @@ func (s *Server) makeBotResult( diag *diagnostic.Diagnostic, authCtx *joinauthz.Context, botParams *messages.BotParams, - provisionToken types.ProvisionToken, + token provision.Token, rawClaims any, attrs *workloadidentityv1pb.JoinAttrs, ) (*messages.BotResult, string, error) { @@ -453,7 +520,7 @@ func (s *Server) makeBotResult( if err != nil { return nil, "", trace.Wrap(err) } - certs, botInstanceID, err := s.cfg.AuthService.GenerateBotCertsForJoin(ctx, provisionToken, certsParams) + certs, botInstanceID, err := s.cfg.AuthService.GenerateBotCertsForJoin(ctx, token, certsParams) if err != nil { return nil, "", trace.Wrap(err) } @@ -639,7 +706,7 @@ func makeAuditEvent(info diagnostic.Info, attributesStruct *apievents.Struct) ap } } -func configuredJoinMethod(token types.ProvisionToken) types.JoinMethod { +func configuredJoinMethod(token provision.Token) types.JoinMethod { method := token.GetJoinMethod() if method == types.JoinMethodUnspecified { return types.JoinMethodToken diff --git a/lib/join/server_boundkeypair.go b/lib/join/server_boundkeypair.go index a7d28aff045eb..38d9f152ce993 100644 --- a/lib/join/server_boundkeypair.go +++ b/lib/join/server_boundkeypair.go @@ -30,6 +30,7 @@ import ( "github.com/gravitational/teleport/lib/join/internal/messages" "github.com/gravitational/teleport/lib/join/joinutils" "github.com/gravitational/teleport/lib/join/legacyjoin" + "github.com/gravitational/teleport/lib/join/provision" ) // handleBoundKeypairJoin handles join attempts for the bound keypair join @@ -58,7 +59,7 @@ func (s *Server) handleBoundKeypairJoin( stream messages.ServerStream, authCtx *authz.Context, clientInit *messages.ClientInit, - provisionToken types.ProvisionToken, + token provision.Token, ) (*messages.BotResult, error) { ctx := stream.Context() diag := stream.Diagnostic() @@ -67,6 +68,13 @@ func (s *Server) handleBoundKeypairJoin( if clientInit.SystemRole != types.RoleBot.String() { return nil, trace.BadParameter("bound keypair joining is only supported for bots") } + + // Scoped tokens currently validate against being created with the bot role, but just in case + // we'll check and return a more helpful error message if one happens to make it through. + if token.GetAssignedScope() != "" { + return nil, trace.BadParameter("bound keypair joining is not supported by scoped tokens") + } + boundKeypairInit, err := messages.RecvRequest[*messages.BoundKeypairInit](stream) if err != nil { return nil, trace.Wrap(err) @@ -97,7 +105,7 @@ func (s *Server) handleBoundKeypairJoin( return nil, "", trace.Wrap(err) } botCertsParams.PreviousBotInstanceID = previousBotInstanceID - protoCerts, botInstanceID, err := s.cfg.AuthService.GenerateBotCertsForJoin(ctx, provisionToken, botCertsParams) + protoCerts, botInstanceID, err := s.cfg.AuthService.GenerateBotCertsForJoin(ctx, token, botCertsParams) if err != nil { return nil, "", trace.Wrap(err) } @@ -111,7 +119,7 @@ func (s *Server) handleBoundKeypairJoin( AuthService: s.cfg.AuthService, AuthCtx: authCtx, Diag: diag, - ProvisionToken: provisionToken, + ProvisionToken: token, ClientInit: clientInit, BoundKeypairInit: boundKeypairInit, IssueChallenge: issueChallenge, @@ -183,7 +191,7 @@ func AdaptRegisterUsingBoundKeypairMethod( } // Assert that the provision token allows the requested system role. - if err := ProvisionTokenAllowsRole(provisionToken, req.JoinRequest.Role); err != nil { + if err := TokenAllowsRole(provisionToken, req.JoinRequest.Role); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/join/server_ec2.go b/lib/join/server_ec2.go index f1bd67a46061d..159cf4eb4b632 100644 --- a/lib/join/server_ec2.go +++ b/lib/join/server_ec2.go @@ -23,6 +23,7 @@ import ( "github.com/gravitational/teleport/lib/join/ec2join" "github.com/gravitational/teleport/lib/join/internal/authz" "github.com/gravitational/teleport/lib/join/internal/messages" + "github.com/gravitational/teleport/lib/join/provision" ) // handleEC2Join handles join attempts for the IAM join method. @@ -41,7 +42,7 @@ func (s *Server) handleEC2Join( stream messages.ServerStream, authCtx *authz.Context, clientInit *messages.ClientInit, - provisionToken types.ProvisionToken, + provisionToken provision.Token, ) (messages.Response, error) { // Receive the EC2Init message from the client. ec2Init, err := messages.RecvRequest[*messages.EC2Init](stream) diff --git a/lib/join/server_env0.go b/lib/join/server_env0.go index ed87e883e159e..838b7a6fc7c74 100644 --- a/lib/join/server_env0.go +++ b/lib/join/server_env0.go @@ -24,6 +24,7 @@ import ( workloadidentityv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/join/env0" + "github.com/gravitational/teleport/lib/join/provision" ) type Env0TokenValidator interface { @@ -37,7 +38,7 @@ type Env0TokenValidator interface { // suitable for use in `handleOIDCJoin` func (s *Server) validateEnv0Token( ctx context.Context, - provisionToken types.ProvisionToken, + provisionToken provision.Token, idToken []byte, ) (any, *workloadidentityv1.JoinAttrs, error) { verifiedIdentity, err := s.cfg.AuthService.GetEnv0IDTokenValidator().ValidateToken(ctx, idToken) diff --git a/lib/join/server_iam.go b/lib/join/server_iam.go index fc9e180b1ee57..75c3404b22d4c 100644 --- a/lib/join/server_iam.go +++ b/lib/join/server_iam.go @@ -20,11 +20,11 @@ import ( "github.com/gravitational/trace" workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" - "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/join/iamjoin" "github.com/gravitational/teleport/lib/join/internal/authz" "github.com/gravitational/teleport/lib/join/internal/diagnostic" "github.com/gravitational/teleport/lib/join/internal/messages" + "github.com/gravitational/teleport/lib/join/provision" ) // handleIAMJoin handles join attempts for the IAM join method. @@ -45,7 +45,7 @@ func (s *Server) handleIAMJoin( stream messages.ServerStream, authCtx *authz.Context, clientInit *messages.ClientInit, - provisionToken types.ProvisionToken, + token provision.Token, ) (messages.Response, error) { // Receive the IAMInit message from the client. iamInit, err := messages.RecvRequest[*messages.IAMInit](stream) @@ -76,7 +76,7 @@ func (s *Server) handleIAMJoin( // the verified identity matches allow rules in the provision token. verifiedIdentity, err := iamjoin.CheckIAMRequest(stream.Context(), &iamjoin.CheckIAMRequestParams{ Challenge: challenge, - ProvisionToken: provisionToken, + ProvisionToken: token, STSIdentityRequest: solution.STSIdentityRequest, HTTPClient: s.cfg.AuthService.GetHTTPClientForAWSSTS(), FIPS: s.cfg.FIPS, @@ -98,7 +98,7 @@ func (s *Server) handleIAMJoin( authCtx, clientInit, &iamInit.ClientParams, - provisionToken, + token, verifiedIdentity, &workloadidentityv1pb.JoinAttrs{ Iam: verifiedIdentity.JoinAttrs(), diff --git a/lib/join/server_oidc.go b/lib/join/server_oidc.go index 03d25078c6cf9..ca505b4467f43 100644 --- a/lib/join/server_oidc.go +++ b/lib/join/server_oidc.go @@ -24,17 +24,17 @@ import ( "github.com/gravitational/trace" workloadidentityv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" - "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/join/internal/authz" "github.com/gravitational/teleport/lib/join/internal/diagnostic" "github.com/gravitational/teleport/lib/join/internal/messages" + "github.com/gravitational/teleport/lib/join/provision" ) // oidcTokenValidator is a function type that validates an OIDC token and checks that // it matches an allow rule configured in the provision token. type oidcTokenValidator func( ctx context.Context, - provisionToken types.ProvisionToken, + provisionToken provision.Token, idToken []byte, ) (rawClaims any, joinAttrs *workloadidentityv1.JoinAttrs, err error) @@ -43,7 +43,7 @@ func (s *Server) handleOIDCJoin( stream messages.ServerStream, authCtx *authz.Context, clientInit *messages.ClientInit, - provisionToken types.ProvisionToken, + provisionToken provision.Token, validator oidcTokenValidator, ) (messages.Response, error) { // Receive the OIDCInit message from the client. diff --git a/lib/join/server_token.go b/lib/join/server_token.go index c8b11aa1d510b..fb62500863321 100644 --- a/lib/join/server_token.go +++ b/lib/join/server_token.go @@ -19,9 +19,9 @@ package join import ( "github.com/gravitational/trace" - "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/join/internal/authz" "github.com/gravitational/teleport/lib/join/internal/messages" + "github.com/gravitational/teleport/lib/join/provision" ) // handleTokenJoin handles join attempts for the token join method. @@ -29,7 +29,7 @@ func (s *Server) handleTokenJoin( stream messages.ServerStream, authCtx *authz.Context, clientInit *messages.ClientInit, - provisionToken types.ProvisionToken, + token provision.Token, ) (messages.Response, error) { // Receive the TokenInit message from the client. tokenInit, err := messages.RecvRequest[*messages.TokenInit](stream) @@ -47,7 +47,7 @@ func (s *Server) handleTokenJoin( authCtx, clientInit, &tokenInit.ClientParams, - provisionToken, + token, nil, /*rawClaims*/ nil, /*attrs*/ ) diff --git a/lib/join/token/scoped.go b/lib/join/token/scoped.go deleted file mode 100644 index b04669ceecbc5..0000000000000 --- a/lib/join/token/scoped.go +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Teleport - * Copyright (C) 2025 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package token - -import ( - "strings" - "time" - - "github.com/gravitational/trace" - - joiningv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/scopes/joining/v1" - "github.com/gravitational/teleport/api/types" -) - -// Scoped wraps a [joiningv1.ScopedToken] such that it can be used to provision -// resources. -type Scoped struct { - token *joiningv1.ScopedToken - joinMethod types.JoinMethod - roles types.SystemRoles -} - -// NewScoped returns the wrapped version of the given [joiningv1.ScopedToken]. -// It will return an error if the configured join method is not a valid -// [types.JoinMethod] or if any of the configured roles are not a valid -// [types.SystemRole]. The validated join method and roles are cached on the -// [Scoped] wrapper itself so they can be read without repeating validation. -func NewScoped(token *joiningv1.ScopedToken) (*Scoped, error) { - joinMethod := types.JoinMethod(token.GetSpec().GetJoinMethod()) - if err := types.ValidateJoinMethod(joinMethod); err != nil { - return nil, trace.Wrap(err) - } - - roles, err := types.NewTeleportRoles(token.GetSpec().GetRoles()) - if err != nil { - return nil, trace.Wrap(err) - } - - return &Scoped{token: token, joinMethod: joinMethod, roles: roles}, nil -} - -// GetName returns the name of a [joiningv1.ScopedToken]. -func (s *Scoped) GetName() string { - return s.token.GetMetadata().GetName() -} - -// GetJoinMethod returns the cached [types.JoinMethod] generated when the -// [joiningv1.ScopedToken] was wrapped. -func (s *Scoped) GetJoinMethod() types.JoinMethod { - return s.joinMethod -} - -// GetRoles returns the cached [types.SystemRoles] generated when the -// [joiningv1.ScopedToken] was wrapped. -func (s *Scoped) GetRoles() types.SystemRoles { - return s.roles -} - -// GetSafeName returns the name of the scoped token, sanitized appropriately -// for join methods where the name is secret. This should be used when logging -// the token name. -func (s *Scoped) GetSafeName() string { - return GetSafeScopedTokenName(s.token) -} - -// Expiry returns the [time.Time] representing when the wrapped -// [joiningv1.ScopedToken] will expire. -func (s *Scoped) Expiry() time.Time { - return s.token.GetMetadata().GetExpires().AsTime() -} - -// GetSafeScopedTokenName returns the name of the scoped token, sanitized -// appropriately for join methods where the name is secret. This should be used -// when logging the token name. -func GetSafeScopedTokenName(token *joiningv1.ScopedToken) string { - name := token.GetMetadata().GetName() - if types.JoinMethod(token.GetSpec().GetJoinMethod()) != types.JoinMethodToken { - return name - } - - // If the token name is short, we just blank the whole thing. - if len(name) < 16 { - return strings.Repeat("*", len(name)) - } - - // If the token name is longer, we can show the last 25% of it to help - // the operator identify it. - hiddenBefore := int(0.75 * float64(len(name))) - name = name[hiddenBefore:] - name = strings.Repeat("*", hiddenBefore) + name - return name -} diff --git a/lib/scopes/joining/token.go b/lib/scopes/joining/token.go index 8b1880b7ee62d..3cc45021f3ad7 100644 --- a/lib/scopes/joining/token.go +++ b/lib/scopes/joining/token.go @@ -17,6 +17,9 @@ package joining import ( + "strings" + "time" + "github.com/gravitational/trace" joiningv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/scopes/joining/v1" @@ -123,3 +126,140 @@ func WeakValidateToken(token *joiningv1.ScopedToken) error { return nil } + +// ValidateTokenForUse checks if a given scoped token can be used for +// provisioning. +func ValidateTokenForUse(token *joiningv1.ScopedToken) error { + if err := WeakValidateToken(token); err != nil { + return trace.Wrap(err) + } + + ttl := token.GetMetadata().GetExpires() + if ttl == nil || ttl.AsTime().IsZero() { + return nil + } + + now := time.Now().UTC() + if ttl.AsTime().Before(now) { + return trace.LimitExceeded("scoped token is expired") + } + + return nil +} + +// Token wraps a [joiningv1.ScopedToken] such that it can be used to provision +// resources. +type Token struct { + scoped *joiningv1.ScopedToken + joinMethod types.JoinMethod + roles types.SystemRoles +} + +// NewToken returns the wrapped version of the given [joiningv1.ScopedToken]. +// It will return an error if the configured join method is not a valid +// [types.JoinMethod] or if any of the configured roles are not a valid +// [types.SystemRole]. The validated join method and roles are cached on the +// [Scoped] wrapper itself so they can be read without repeating validation. +func NewToken(token *joiningv1.ScopedToken) (*Token, error) { + joinMethod := types.JoinMethod(token.GetSpec().GetJoinMethod()) + if err := types.ValidateJoinMethod(joinMethod); err != nil { + return nil, trace.Wrap(err) + } + + roles, err := types.NewTeleportRoles(token.GetSpec().GetRoles()) + if err != nil { + return nil, trace.Wrap(err) + } + + return &Token{scoped: token, joinMethod: joinMethod, roles: roles}, nil +} + +// GetName returns the name of a [joiningv1.ScopedToken]. +func (t *Token) GetName() string { + if t == nil { + return "" + } + + return t.scoped.GetMetadata().GetName() +} + +// GetJoinMethod returns the cached [types.JoinMethod] generated when the +// [joiningv1.ScopedToken] was wrapped. +func (t *Token) GetJoinMethod() types.JoinMethod { + if t == nil { + return types.JoinMethodUnspecified + } + + return t.joinMethod +} + +// GetRoles returns the cached [types.SystemRoles] generated when the +// [joiningv1.ScopedToken] was wrapped. +func (t *Token) GetRoles() types.SystemRoles { + if t == nil { + return nil + } + return t.roles +} + +// GetSafeName returns the name of the scoped token, sanitized appropriately +// for join methods where the name is secret. This should be used when logging +// the token name. +func (t *Token) GetSafeName() string { + return GetSafeScopedTokenName(t.scoped) +} + +// GetSafeScopedTokenName returns the name of the scoped token, sanitized +// appropriately for join methods where the name is secret. This should be used +// when logging the token name. +func GetSafeScopedTokenName(token *joiningv1.ScopedToken) string { + name := token.GetMetadata().GetName() + if types.JoinMethod(token.GetSpec().GetJoinMethod()) != types.JoinMethodToken { + return name + } + + // If the token name is short, we just blank the whole thing. + if len(name) < 16 { + return strings.Repeat("*", len(name)) + } + + // If the token name is longer, we can show the last 25% of it to help + // the operator identify it. + hiddenBefore := int(0.75 * float64(len(name))) + name = name[hiddenBefore:] + name = strings.Repeat("*", hiddenBefore) + name + return name +} + +// Expiry returns the [time.Time] representing when the wrapped +// [joiningv1.ScopedToken] will expire. +func (t *Token) Expiry() time.Time { + expiry := t.scoped.GetMetadata().GetExpires() + if expiry == nil { + return time.Time{} + } + + return expiry.AsTime() +} + +// GetBotName returns an empty string because scoped tokens do not currently +// support configuring a bot name. +func (t *Token) GetBotName() string { + return "" +} + +// GetAssignedScope returns the scope that will be assigned to resources +// provisioned using the wrapped [joiningv1.ScopedToken]. +func (t *Token) GetAssignedScope() string { + return t.scoped.GetSpec().GetAssignedScope() +} + +// GetAllowRules returns the list of allow rules. +func (t *Token) GetAllowRules() []*types.TokenRule { + return nil +} + +// GetAWSIIDTTL returns the TTL of EC2 IIDs +func (t *Token) GetAWSIIDTTL() types.Duration { + return types.NewDuration(0) +} diff --git a/lib/service/connect.go b/lib/service/connect.go index 94b002e3cc45c..33be9860f501d 100644 --- a/lib/service/connect.go +++ b/lib/service/connect.go @@ -438,6 +438,17 @@ func (process *TeleportProcess) getCertAuthority(conn *Connector, id types.CertA return conn.Client.GetCertAuthority(ctx, id, loadPrivateKeys) } +// localReRegister wraps an [*auth.Server] and removes the scope parameter from the signature of +// [auth.Server.GenerateHostCerts] +type localReRegister struct { + *auth.Server +} + +// GenerateHostCerts allows for generating host certs without providing a scope. +func (l localReRegister) GenerateHostCerts(ctx context.Context, req *proto.HostCertsRequest) (*proto.Certs, error) { + return l.Server.GenerateHostCerts(ctx, req, "") +} + // reRegister receives new identity credentials for proxy, node and auth. // In case if auth servers, the role is 'TeleportAdmin' and instead of using // TLS client this method uses the local auth server. @@ -453,7 +464,7 @@ func (process *TeleportProcess) reRegister(conn *Connector, additionalPrincipals var clt auth.ReRegisterClient = conn.Client var remoteAddr string if srv := process.getLocalAuth(); srv != nil { - clt = srv + clt = localReRegister{srv} // auth server typically extracts remote addr from conn. since we're using the local auth // directly we must supply a reasonable remote addr value. preferably the advertise IP, but // otherwise localhost. this behavior must be kept consistent with the equivalent behavior diff --git a/lib/services/local/scoped_tokens.go b/lib/services/local/scoped_tokens.go index 7cf0f12e16925..794219340fb17 100644 --- a/lib/services/local/scoped_tokens.go +++ b/lib/services/local/scoped_tokens.go @@ -84,6 +84,24 @@ func (s *ScopedTokenService) GetScopedToken(ctx context.Context, req *joiningv1. return &joiningv1.GetScopedTokenResponse{Token: token}, nil } +// UseScopedToken fetches a scoped join token by unique name and checks if it +// can be used for provisioning. Expired tokens will be deleted. +func (s *ScopedTokenService) UseScopedToken(ctx context.Context, name string) (*joiningv1.ScopedToken, error) { + token, err := s.svc.GetResource(ctx, name) + if err != nil { + return nil, trace.Wrap(err) + } + if err := joining.ValidateTokenForUse(token); err != nil { + if trace.IsLimitExceeded(err) { + if err := s.svc.DeleteResource(ctx, name); err != nil { + return nil, trace.LimitExceeded("cleaning up expired token: %v", err) + } + } + return nil, trace.Wrap(err) + } + return token, nil +} + func evalScopeFilter(filter *scopesv1.Filter, scope string) bool { if filter == nil { return true diff --git a/lib/services/local/scoped_tokens_test.go b/lib/services/local/scoped_tokens_test.go index b8bb5d4b338f7..49003fc6bcb45 100644 --- a/lib/services/local/scoped_tokens_test.go +++ b/lib/services/local/scoped_tokens_test.go @@ -20,14 +20,17 @@ import ( "cmp" "slices" "testing" + "time" gocmp "github.com/google/go-cmp/cmp" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/timestamppb" - "github.com/gravitational/teleport/api/defaults" headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" joiningv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/scopes/joining/v1" scopesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/scopes/v1" @@ -38,9 +41,14 @@ import ( ) func TestScopedTokenService(t *testing.T) { - bk, err := memory.New(memory.Config{}) + clock := clockwork.NewFakeClock() + clock.Advance(-30 * time.Hour) + bk, err := memory.New(memory.Config{ + Clock: clock, + }) require.NoError(t, err) - service, err := local.NewScopedTokenService(backend.NewSanitizer(bk)) + // service, err := local.NewScopedTokenService(backend.NewSanitizer(bk)) + service, err := local.NewScopedTokenService(bk) require.NoError(t, err) ctx := t.Context() @@ -49,8 +57,7 @@ func TestScopedTokenService(t *testing.T) { Kind: types.KindScopedToken, Version: types.V1, Metadata: &headerv1.Metadata{ - Name: "testtoken", - Namespace: defaults.Namespace, + Name: "testtoken", }, Scope: "/test", Spec: &joiningv1.ScopedTokenSpec{ @@ -75,6 +82,54 @@ func TestScopedTokenService(t *testing.T) { }) require.NoError(t, err) assert.Empty(t, gocmp.Diff(created.Token, fetched.Token, cmpOpts...)) + + _, err = service.DeleteScopedToken(ctx, &joiningv1.DeleteScopedTokenRequest{ + Name: fetched.Token.Metadata.Name, + }) + require.NoError(t, err) + _, err = service.GetScopedToken(ctx, &joiningv1.GetScopedTokenRequest{ + Name: fetched.Token.Metadata.Name, + }) + require.True(t, trace.IsNotFound(err)) + + expiredToken := proto.CloneOf(token) + expiredToken.Metadata.Name = "expiredtoken" + expiredToken.Metadata.Expires = timestamppb.New(time.Now().UTC().Add(-25 * time.Hour)) + + activeToken := proto.CloneOf(token) + activeToken.Metadata.Name = "activetoken" + activeToken.Metadata.Expires = timestamppb.New(time.Now().UTC().Add(25 * time.Hour)) + + expiredRes, err := service.CreateScopedToken(ctx, &joiningv1.CreateScopedTokenRequest{ + Token: expiredToken, + }) + require.NoError(t, err) + + activeRes, err := service.CreateScopedToken(ctx, &joiningv1.CreateScopedTokenRequest{ + Token: activeToken, + }) + require.NoError(t, err) + + // expired tokens should error and delete the token + expiredToken, err = service.UseScopedToken(ctx, expiredRes.Token.Metadata.Name) + require.True(t, trace.IsLimitExceeded(err)) + require.Nil(t, expiredToken) + + _, err = service.GetScopedToken(ctx, &joiningv1.GetScopedTokenRequest{ + Name: expiredRes.Token.Metadata.Name, + }) + require.True(t, trace.IsNotFound(err)) + + // active tokens should function like a get + activeToken, err = service.UseScopedToken(ctx, activeToken.Metadata.Name) + require.NoError(t, err) + assert.Empty(t, gocmp.Diff(activeToken, activeRes.Token, cmpOpts...)) + + fetchedActive, err := service.GetScopedToken(ctx, &joiningv1.GetScopedTokenRequest{ + Name: activeRes.Token.Metadata.Name, + }) + require.NoError(t, err) + assert.Empty(t, gocmp.Diff(activeRes.Token, fetchedActive.Token, cmpOpts...)) } func TestScopedTokenList(t *testing.T) { @@ -89,8 +144,7 @@ func TestScopedTokenList(t *testing.T) { Kind: types.KindScopedToken, Version: types.V1, Metadata: &headerv1.Metadata{ - Name: "test", - Namespace: defaults.Namespace, + Name: "test", }, Scope: "/test", Spec: &joiningv1.ScopedTokenSpec{ diff --git a/lib/services/scoped_tokens.go b/lib/services/scoped_tokens.go index 80b0234d9090a..7900d9dc37111 100644 --- a/lib/services/scoped_tokens.go +++ b/lib/services/scoped_tokens.go @@ -30,6 +30,10 @@ type ScopedTokenService interface { // GetScopedToken fetches a scoped join token by unique name GetScopedToken(ctx context.Context, req *joiningv1.GetScopedTokenRequest) (*joiningv1.GetScopedTokenResponse, error) + // UseScopedToken fetches a scoped join token by unique name and checks if it can + // be used for provisioning. If the token is expired, [UseScopedToken] should also + // delete it from the backend. + UseScopedToken(ctx context.Context, name string) (*joiningv1.ScopedToken, error) // ListScopedTokens retrieves a paginated list of scoped join tokens ListScopedTokens(ctx context.Context, req *joiningv1.ListScopedTokensRequest) (*joiningv1.ListScopedTokensResponse, error) diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index ba6c0fefb4043..9fc54ac3ae8c1 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -2421,7 +2421,7 @@ func newRawNode(t *testing.T, authSrv *auth.Server) *rawNode { DNSNames: []string{hostname}, PublicSSHKey: pub, PublicTLSKey: tlsPub, - }) + }, "") require.NoError(t, err) signer, err := sshutils.NewSigner(priv, certs.SSH) @@ -3305,7 +3305,7 @@ func newSigner(t testing.TB, ctx context.Context, testServer *authtest.Server) s Role: types.RoleNode, PublicSSHKey: pub, PublicTLSKey: tlsPub, - }) + }, "") require.NoError(t, err) // set up user CA and set up a user that has access to the server diff --git a/lib/sshca/identity.go b/lib/sshca/identity.go index 6de28b8facdb2..017516b01bce0 100644 --- a/lib/sshca/identity.go +++ b/lib/sshca/identity.go @@ -144,6 +144,8 @@ type Identity struct { // GitHubUsername indicates the GitHub username identified by the GitHub // connector. GitHubUsername string + // AgentScope is the scope this identity belongs to. + AgentScope string } // Encode encodes the identity into an ssh certificate. Note that the returned certificate is incomplete @@ -187,6 +189,10 @@ func (i *Identity) Encode(certFormat string) (*ssh.Certificate, error) { cert.Permissions.Extensions[utils.CertExtensionAuthority] = i.ClusterName } + if i.AgentScope != "" { + cert.Permissions.Extensions[teleport.CertExtensionAgentScope] = i.AgentScope + } + // --- user extensions --- if i.ScopePin != nil { @@ -411,6 +417,7 @@ func DecodeIdentity(cert *ssh.Certificate) (*Identity, error) { ident.ScopePin = &pin } + ident.AgentScope = takeValue(teleport.CertExtensionAgentScope) ident.PermitX11Forwarding = takeBool(teleport.CertExtensionPermitX11Forwarding) ident.PermitAgentForwarding = takeBool(teleport.CertExtensionPermitAgentForwarding) ident.PermitPortForwarding = takeBool(teleport.CertExtensionPermitPortForwarding) diff --git a/lib/sshca/identity_test.go b/lib/sshca/identity_test.go index 38d06bf48ae19..bbd45db43230f 100644 --- a/lib/sshca/identity_test.go +++ b/lib/sshca/identity_test.go @@ -90,6 +90,7 @@ func TestIdentityConversion(t *testing.T) { DeviceCredentialID: "cred", GitHubUserID: "github", GitHubUsername: "ghuser", + AgentScope: "/foo", } ignores := []string{ diff --git a/lib/sshca/sshca.go b/lib/sshca/sshca.go index 8ac14289823fa..dbc22643a3fa1 100644 --- a/lib/sshca/sshca.go +++ b/lib/sshca/sshca.go @@ -26,6 +26,7 @@ import ( "golang.org/x/crypto/ssh" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/lib/scopes" ) // Authority implements minimal key-management facility for generating OpenSSH @@ -78,6 +79,11 @@ func (r *HostCertificateRequest) Check() error { if err := r.Identity.SystemRole.Check(); err != nil { return trace.Wrap(err) } + if r.Identity.AgentScope != "" { + if err := scopes.StrongValidate(r.Identity.AgentScope); err != nil { + return trace.Wrap(err) + } + } return nil } diff --git a/lib/tlsca/ca.go b/lib/tlsca/ca.go index 8bd79ab750512..a0a2e6ee6ba65 100644 --- a/lib/tlsca/ca.go +++ b/lib/tlsca/ca.go @@ -124,6 +124,8 @@ type Identity struct { // ScopePin is an optional pin that ties the certificate to a specific scope and set of scoped roles. When // set, the Groups field must not be set. ScopePin *scopesv1.Pin + // AgentScope is the scope this identity belongs to. + AgentScope string // Impersonator is a username of a user impersonating this user Impersonator string // Groups is a list of groups (Teleport roles) encoded in the identity @@ -611,6 +613,9 @@ var ( // ScopePinASN1ExtensionOID is an extension OID that contains the scope pin // used to tie the certificate to a specific scope and set of scoped roles. ScopePinASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 2, 24} + // AgentScopeASN1ExtensionOID is an extension OID that contains the agent scope + // used to tie the certificate to a spec + AgentScopeASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 2, 25} ) // Device Trust OIDs. @@ -938,6 +943,14 @@ func (id *Identity) Subject() (pkix.Name, error) { }) } + if id.AgentScope != "" { + subject.ExtraNames = append(subject.ExtraNames, + pkix.AttributeTypeAndValue{ + Type: AgentScopeASN1ExtensionOID, + Value: id.AgentScope, + }) + } + if id.UserType != "" { subject.ExtraNames = append(subject.ExtraNames, pkix.AttributeTypeAndValue{ @@ -1239,6 +1252,9 @@ func FromSubject(subject pkix.Name, expires time.Time) (*Identity, error) { } id.ScopePin = &pin } + case attr.Type.Equal(AgentScopeASN1ExtensionOID): + id.AgentScope = attr.Value.(string) + case attr.Type.Equal(AllowedResourcesASN1ExtensionOID): allowedResourcesStr, ok := attr.Value.(string) if ok { diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index abf45224feb97..cd1e8fdb5d229 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -331,7 +331,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { Role: types.RoleNode, PublicSSHKey: pub, PublicTLSKey: tlsPub, - }) + }, "") require.NoError(t, err) signer, err := sshutils.NewSigner(priv, certs.SSH) @@ -641,7 +641,7 @@ func (s *WebSuite) addNode(t *testing.T, uuid string, hostname string, address s Role: types.RoleNode, PublicSSHKey: pub, PublicTLSKey: tlsPub, - }) + }, "") require.NoError(t, err) signer, err := sshutils.NewSigner(priv, certs.SSH) @@ -8328,7 +8328,7 @@ func newWebPack(t *testing.T, numProxies int, opts ...webPackOptions) *webPack { Role: types.RoleNode, PublicSSHKey: pub, PublicTLSKey: tlsPub, - }) + }, "") require.NoError(t, err) signer, err := sshutils.NewSigner(priv, certs.SSH) diff --git a/tool/tctl/common/scoped_token_command.go b/tool/tctl/common/scoped_token_command.go index 8564574d0edfa..f731a9575d8f8 100644 --- a/tool/tctl/common/scoped_token_command.go +++ b/tool/tctl/common/scoped_token_command.go @@ -42,7 +42,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/itertools/stream" - "github.com/gravitational/teleport/lib/join/token" + "github.com/gravitational/teleport/lib/scopes/joining" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/utils" commonclient "github.com/gravitational/teleport/tool/tctl/common/client" @@ -271,7 +271,7 @@ func (c *ScopedTokensCommand) List(ctx context.Context, client *authclient.Clien if c.withSecrets { return tok.GetMetadata().GetName() } - return token.GetSafeScopedTokenName(tok) + return joining.GetSafeScopedTokenName(tok) } switch c.format { case teleport.JSON: