From 7eb2e3e916ffefb42fd9b0070386e67992f941ce Mon Sep 17 00:00:00 2001 From: Nic Klaassen Date: Thu, 6 Nov 2025 16:32:15 -0800 Subject: [PATCH 1/3] Azure method support in new join service --- lib/auth/auth.go | 4 + lib/auth/bot_test.go | 92 ----- lib/auth/export_test.go | 31 -- lib/auth/join/join.go | 14 +- lib/auth/join_azure_legacy.go | 119 ++++++ .../join_azure.go => join/azurejoin/azure.go} | 274 ++++++------- lib/{auth => join/azurejoin}/azure_certs.go | 2 +- .../azurejoin}/azure_certs_test.go | 6 +- .../azurejoin}/join_azure_test.go | 382 ++++++++++++------ lib/join/joinclient/join.go | 3 + lib/join/joinclient/join_azure.go | 77 ++++ lib/join/server.go | 4 + lib/join/server_azure.go | 108 +++++ 13 files changed, 696 insertions(+), 420 deletions(-) create mode 100644 lib/auth/join_azure_legacy.go rename lib/{auth/join_azure.go => join/azurejoin/azure.go} (67%) rename lib/{auth => join/azurejoin}/azure_certs.go (99%) rename lib/{auth => join/azurejoin}/azure_certs_test.go (92%) rename lib/{auth => join/azurejoin}/join_azure_test.go (77%) create mode 100644 lib/join/joinclient/join_azure.go create mode 100644 lib/join/server_azure.go diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 74c44d6827cfe..92cba647e179c 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -114,6 +114,7 @@ import ( iterstream "github.com/gravitational/teleport/lib/itertools/stream" "github.com/gravitational/teleport/lib/join" "github.com/gravitational/teleport/lib/join/azuredevops" + "github.com/gravitational/teleport/lib/join/azurejoin" "github.com/gravitational/teleport/lib/join/bitbucket" joinboundkeypair "github.com/gravitational/teleport/lib/join/boundkeypair" "github.com/gravitational/teleport/lib/join/circleci" @@ -1339,6 +1340,9 @@ type Server struct { // override the implementation used in tests. env0IDTokenValidator join.Env0TokenValidator + // azureJoinConfig holds configuration for the Azure join method. + azureJoinConfig *azurejoin.AzureJoinConfig + // loadAllCAs tells tsh to load the host CAs for all clusters when trying to ssh into a node. loadAllCAs bool diff --git a/lib/auth/bot_test.go b/lib/auth/bot_test.go index dc6bbf3ccac0e..14160fb79aeb2 100644 --- a/lib/auth/bot_test.go +++ b/lib/auth/bot_test.go @@ -23,10 +23,6 @@ import ( "context" "crypto" "crypto/tls" - "crypto/x509" - "encoding/base64" - "encoding/json" - "encoding/pem" "errors" "fmt" "io" @@ -36,10 +32,8 @@ import ( "text/template" "time" - "github.com/digitorus/pkcs7" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -58,16 +52,13 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/keys" - "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/auth/authtest" "github.com/gravitational/teleport/lib/auth/machineid/machineidv1" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/auth/testauthority" - "github.com/gravitational/teleport/lib/cloud/azure" libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" - "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/join/iamjoin" "github.com/gravitational/teleport/lib/join/joinclient" "github.com/gravitational/teleport/lib/kube/token" @@ -741,89 +732,6 @@ func TestRegisterBot_RemoteAddr(t *testing.T) { require.NoError(t, err) checkCertLoginIP(t, certs.TLS, remoteAddr) }) - - t.Run("Azure method", func(t *testing.T) { - subID := uuid.NewString() - resourceGroup := "rg" - rsID := vmResourceID(subID, resourceGroup, "test-vm") - vmID := "vmID" - - accessToken, err := makeToken(rsID, "", a.GetClock().Now()) - require.NoError(t, err) - - // add token to auth server - azureTokenName := "azure-test-token" - azureToken, err := types.NewProvisionTokenFromSpec( - azureTokenName, - time.Now().Add(time.Minute), - types.ProvisionTokenSpecV2{ - Roles: []types.SystemRole{types.RoleBot}, - Azure: &types.ProvisionTokenSpecV2Azure{Allow: []*types.ProvisionTokenSpecV2Azure_Rule{{Subscription: subID}}}, - BotName: botName, - JoinMethod: types.JoinMethodAzure, - }) - require.NoError(t, err) - require.NoError(t, a.UpsertToken(ctx, azureToken)) - - vmClient := &mockAzureVMClient{ - vms: map[string]*azure.VirtualMachine{ - rsID: { - ID: rsID, - Name: "test-vm", - Subscription: subID, - ResourceGroup: resourceGroup, - VMID: vmID, - }, - }, - } - getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{ - subID: vmClient, - }) - - tlsConfig, err := fixtures.LocalTLSConfig() - require.NoError(t, err) - - block, _ := pem.Decode(fixtures.LocalhostKey) - pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) - require.NoError(t, err) - - certs, err := a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { - ad := auth.AttestedData{ - Nonce: challenge, - SubscriptionID: subID, - ID: vmID, - } - adBytes, err := json.Marshal(&ad) - require.NoError(t, err) - s, err := pkcs7.NewSignedData(adBytes) - require.NoError(t, err) - require.NoError(t, s.AddSigner(tlsConfig.Certificate, pkey, pkcs7.SignerInfoConfig{})) - signature, err := s.Finish() - require.NoError(t, err) - signedAD := auth.SignedAttestedData{ - Encoding: "pkcs7", - Signature: base64.StdEncoding.EncodeToString(signature), - } - signedADBytes, err := json.Marshal(&signedAD) - require.NoError(t, err) - - req := &proto.RegisterUsingAzureMethodRequest{ - RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ - Token: azureTokenName, - HostID: "test-node", - Role: types.RoleBot, - PublicSSHKey: sshPubKey, - PublicTLSKey: tlsPubKey, - RemoteAddr: remoteAddr, - }, - AttestedData: signedADBytes, - AccessToken: accessToken, - } - return req, nil - }, auth.WithAzureCerts([]*x509.Certificate{tlsConfig.Certificate}), auth.WithAzureVerifyFunc(mockVerifyToken(nil)), auth.WithAzureVMClientGetter(getVMClient)) - require.NoError(t, err) - checkCertLoginIP(t, certs.TLS, remoteAddr) - }) } func responseFromAWSIdentity(id iamjoin.AWSIdentity) string { diff --git a/lib/auth/export_test.go b/lib/auth/export_test.go index a2eb46bc3a737..3d0dee5ce4e09 100644 --- a/lib/auth/export_test.go +++ b/lib/auth/export_test.go @@ -70,8 +70,6 @@ const ( MaxUserAgentLen = maxUserAgentLen ForwardedTag = forwardedTag - - AzureAccessTokenAudience = azureAccessTokenAudience ) var ( @@ -307,10 +305,6 @@ func TrimUserAgent(userAgent string) string { return trimUserAgent(userAgent) } -func IsAllowedDomain(cn string, domains []string) bool { - return isAllowedDomain(cn, domains) -} - func GetSnowflakeJWTParams(ctx context.Context, accountName, userName string, publicKey []byte) (string, string) { return getSnowflakeJWTParams(ctx, accountName, userName, publicKey) } @@ -336,31 +330,6 @@ func CheckHeaders(headers http.Header, challenge string, clock clockwork.Clock) } type GitHubManager = githubManager -type AttestedData = attestedData -type SignedAttestedData = signedAttestedData -type AzureRegisterOption = azureRegisterOption -type AzureRegisterConfig = azureRegisterConfig -type AzureVMClientGetter = vmClientGetter -type AzureVerifyTokenFunc = azureVerifyTokenFunc -type AccessTokenClaims = accessTokenClaims - -func WithAzureCerts(certs []*x509.Certificate) AzureRegisterOption { - return func(cfg *AzureRegisterConfig) { - cfg.certificateAuthorities = certs - } -} - -func WithAzureVerifyFunc(verify azureVerifyTokenFunc) AzureRegisterOption { - return func(cfg *AzureRegisterConfig) { - cfg.verify = verify - } -} - -func WithAzureVMClientGetter(getVMClient vmClientGetter) AzureRegisterOption { - return func(cfg *AzureRegisterConfig) { - cfg.getVMClient = getVMClient - } -} func (s *TLSServer) GRPCServer() *GRPCServer { return s.grpcServer diff --git a/lib/auth/join/join.go b/lib/auth/join/join.go index bd286b53928cb..91c7ad23d6c14 100644 --- a/lib/auth/join/join.go +++ b/lib/auth/join/join.go @@ -78,6 +78,15 @@ type AzureParams struct { // ClientID is the client ID of the managed identity for Teleport to assume // when authenticating a node. ClientID string + // IMDSClient overrides the client used to fetch data from Azure IMDS. + IMDSClient AzureIMDSClient +} + +// AzureIMDSClient is a client to Azure's IMDS. +type AzureIMDSClient interface { + IsAvailable(context.Context) bool + GetAttestedData(ctx context.Context, nonce string) ([]byte, error) + GetAccessToken(ctx context.Context, clientID string) (string, error) } // GitlabParams is the parameters specific to the gitlab join method. @@ -875,7 +884,10 @@ func registerUsingAzureMethod( ctx context.Context, client joinServiceClient, token string, hostKeys *newHostKeys, params RegisterParams, ) (*proto.Certs, error) { certs, err := client.RegisterUsingAzureMethod(ctx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { - imds := azure.NewInstanceMetadataClient() + imds := params.AzureParams.IMDSClient + if imds == nil { + imds = azure.NewInstanceMetadataClient() + } if !imds.IsAvailable(ctx) { return nil, trace.AccessDenied("could not reach instance metadata. Is Teleport running on an Azure VM?") } diff --git a/lib/auth/join_azure_legacy.go b/lib/auth/join_azure_legacy.go new file mode 100644 index 0000000000000..9032f084e6aeb --- /dev/null +++ b/lib/auth/join_azure_legacy.go @@ -0,0 +1,119 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package auth + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/join/azurejoin" + "github.com/gravitational/teleport/lib/join/legacyjoin" +) + +// RegisterUsingAzureMethod registers the caller using the Azure join method +// and returns signed certs to join the cluster. +// +// The caller must provide a ChallengeResponseFunc which returns a +// *proto.RegisterUsingAzureMethodRequest with a signed attested data document +// including the challenge as a nonce. +// +// TODO(nklaassen): DELETE IN 20 when removing the legacy join service. +func (a *Server) RegisterUsingAzureMethod( + ctx context.Context, + challengeResponse client.RegisterAzureChallengeResponseFunc, +) (certs *proto.Certs, err error) { + var provisionToken types.ProvisionToken + var joinRequest *types.RegisterUsingTokenRequest + defer func() { + // Emit a log message and audit event on join failure. + if err != nil { + a.handleJoinFailure(ctx, err, provisionToken, nil, joinRequest) + } + }() + + if legacyjoin.Disabled() { + return nil, trace.Wrap(legacyjoin.ErrDisabled) + } + + challenge, err := azurejoin.GenerateAzureChallenge() + if err != nil { + return nil, trace.Wrap(err) + } + req, err := challengeResponse(challenge) + if err != nil { + return nil, trace.Wrap(err) + } + joinRequest = req.RegisterUsingTokenRequest + + if err := req.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + provisionToken, err = a.checkTokenJoinRequestCommon(ctx, req.RegisterUsingTokenRequest) + if err != nil { + return nil, trace.Wrap(err) + } + if provisionToken.GetJoinMethod() != types.JoinMethodAzure { + return nil, trace.AccessDenied("this token does not support the Azure join method") + } + + ptv2, ok := provisionToken.(*types.ProvisionTokenV2) + if !ok { + return nil, trace.Wrap(err, "Azure join method only supports ProvisionTokenV2, got %T", provisionToken) + } + + joinAttrs, err := azurejoin.CheckAzureRequest(ctx, azurejoin.CheckAzureRequestParams{ + AzureJoinConfig: a.GetAzureJoinConfig(), + Token: ptv2, + Challenge: challenge, + AttestedData: req.AttestedData, + AccessToken: req.AccessToken, + Logger: a.logger, + Clock: a.GetClock(), + }) + if err != nil { + return nil, trace.Wrap(err, "checking Azure challenge response") + } + + if req.RegisterUsingTokenRequest.Role == types.RoleBot { + params := makeBotCertsParams(req.RegisterUsingTokenRequest, nil /*rawClaims*/, &workloadidentityv1pb.JoinAttrs{ + Azure: joinAttrs, + }) + certs, _, err := a.GenerateBotCertsForJoin(ctx, provisionToken, params) + return certs, trace.Wrap(err) + } + params := makeHostCertsParams(req.RegisterUsingTokenRequest, nil /*rawClaims*/) + certs, err = a.GenerateHostCertsForJoin(ctx, provisionToken, params) + return certs, trace.Wrap(err) +} + +// GetAzureJoinConfig gets configuration options for azure joining. +func (a *Server) GetAzureJoinConfig() *azurejoin.AzureJoinConfig { + return a.azureJoinConfig +} + +// SetAzureJoinConfig sets configuration options for azure joining. +func (a *Server) SetAzureJoinConfig(c *azurejoin.AzureJoinConfig) { + a.azureJoinConfig = c +} diff --git a/lib/auth/join_azure.go b/lib/join/azurejoin/azure.go similarity index 67% rename from lib/auth/join_azure.go rename to lib/join/azurejoin/azure.go index ff54227baf252..bc5f22c224cfc 100644 --- a/lib/auth/join_azure.go +++ b/lib/join/azurejoin/azure.go @@ -1,22 +1,20 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package auth +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package azurejoin import ( "cmp" @@ -37,21 +35,19 @@ import ( "github.com/digitorus/pkcs7" "github.com/go-jose/go-jose/v3/jwt" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/zitadel/oidc/v3/pkg/oidc" - "github.com/gravitational/teleport/api/client" - "github.com/gravitational/teleport/api/client/proto" workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/join/joinutils" - "github.com/gravitational/teleport/lib/join/legacyjoin" liboidc "github.com/gravitational/teleport/lib/oidc" "github.com/gravitational/teleport/lib/utils" ) const ( - azureAccessTokenAudience = "https://management.azure.com/" + AzureAccessTokenAudience = "https://management.azure.com/" // azureUserAgent specifies the Azure User-Agent identification for telemetry. azureUserAgent = "teleport" @@ -64,7 +60,8 @@ const ( // Structs for unmarshaling attested data. Schema can be found at // https://learn.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service?tabs=linux#response-2 -type signedAttestedData struct { +// SignedAttestedData models the response from the attested data IMDS endpoint. +type SignedAttestedData struct { Encoding string `json:"encoding"` Signature string `json:"signature"` } @@ -80,7 +77,8 @@ type timestamp struct { ExpiresOn string `json:"expiresOn"` } -type attestedData struct { +// AttestedData models the decoded data returned from the attested data IMDS endpoint. +type AttestedData struct { LicenseType string `json:"licenseType"` Nonce string `json:"nonce"` Plan plan `json:"plan"` @@ -90,7 +88,8 @@ type attestedData struct { SKU string `json:"sku"` } -type accessTokenClaims struct { +// AccessTokenClaims models the claims in an Azure access token. +type AccessTokenClaims struct { oidc.TokenClaims TenantID string `json:"tid"` Version string `json:"ver"` @@ -111,7 +110,7 @@ type accessTokenClaims struct { AzureResourceID string `json:"xms_az_rid"` } -func (c *accessTokenClaims) AsJWTClaims() jwt.Claims { +func (c *AccessTokenClaims) asJWTClaims() jwt.Claims { return jwt.Claims{ Issuer: c.Issuer, Subject: c.Subject, @@ -123,24 +122,32 @@ func (c *accessTokenClaims) AsJWTClaims() jwt.Claims { } } -type azureVerifyTokenFunc func(ctx context.Context, rawIDToken string) (*accessTokenClaims, error) - -type vmClientGetter func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) - -type azureRegisterConfig struct { - certificateAuthorities []*x509.Certificate - verify azureVerifyTokenFunc - getVMClient vmClientGetter +// AzureVerifyTokenFunc is a function type that verifies an azure VM token. +type AzureVerifyTokenFunc func(ctx context.Context, rawIDToken string) (*AccessTokenClaims, error) + +// VMClientGetter is a function type that returns an Azure VM client for a +// given subscription authenticated with a given static token credential. +type VMClientGetter func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) + +// AzureJoinConfig holds configurable options for Azure joining. +type AzureJoinConfig struct { + // CertificateAuthorities, if set, overrides the root certificate + // authorities used to verify VM attested data. + CertificateAuthorities []*x509.Certificate + // Verify, if set, overrides the function used to verify azure VM tokens. + Verify AzureVerifyTokenFunc + // GetVMClient, if set, overrides the function used to get Azure VM clients. + GetVMClient VMClientGetter } -func azureVerifyFuncFromOIDCVerifier(clientID string) azureVerifyTokenFunc { - return func(ctx context.Context, rawIDToken string) (*accessTokenClaims, error) { +func azureVerifyFuncFromOIDCVerifier(clientID string) AzureVerifyTokenFunc { + return func(ctx context.Context, rawIDToken string) (*AccessTokenClaims, error) { token, err := jwt.ParseSigned(rawIDToken) if err != nil { return nil, trace.Wrap(err) } // Need to get the tenant ID before we verify so we can construct the issuer URL. - var unverifiedClaims accessTokenClaims + var unverifiedClaims AccessTokenClaims if err := token.UnsafeClaimsWithoutVerification(&unverifiedClaims); err != nil { return nil, trace.Wrap(err) } @@ -148,24 +155,24 @@ func azureVerifyFuncFromOIDCVerifier(clientID string) azureVerifyTokenFunc { if err != nil { return nil, trace.Wrap(err) } - return liboidc.ValidateToken[*accessTokenClaims](ctx, issuer, clientID, rawIDToken) + return liboidc.ValidateToken[*AccessTokenClaims](ctx, issuer, clientID, rawIDToken) } } -func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error { - if cfg.verify == nil { - cfg.verify = azureVerifyFuncFromOIDCVerifier(azureAccessTokenAudience) +func (cfg *AzureJoinConfig) checkAndSetDefaults() error { + if cfg.Verify == nil { + cfg.Verify = azureVerifyFuncFromOIDCVerifier(AzureAccessTokenAudience) } - if cfg.certificateAuthorities == nil { + if cfg.CertificateAuthorities == nil { certs, err := getAzureRootCerts() if err != nil { return trace.Wrap(err) } - cfg.certificateAuthorities = certs + cfg.CertificateAuthorities = certs } - if cfg.getVMClient == nil { - cfg.getVMClient = func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) { + if cfg.GetVMClient == nil { + cfg.GetVMClient = func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) { // The User-Agent is added for debugging purposes. It helps identify // and isolate teleport traffic. opts := &armpolicy.ClientOptions{ @@ -182,13 +189,11 @@ func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error { return nil } -type azureRegisterOption func(cfg *azureRegisterConfig) - // parseAndVeryAttestedData verifies that an attested data document was signed // by Azure. If verification is successful, it returns the ID of the VM that // produced the document. func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge string, certs []*x509.Certificate) (subscriptionID, vmID string, err error) { - var signedAD signedAttestedData + var signedAD SignedAttestedData if err := utils.FastUnmarshal(adBytes, &signedAD); err != nil { return "", "", trace.Wrap(err) } @@ -206,7 +211,7 @@ func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge s if err != nil { return "", "", trace.Wrap(err) } - var ad attestedData + var ad AttestedData if err := utils.FastUnmarshal(p7.Content, &ad); err != nil { return "", "", trace.Wrap(err) } @@ -244,14 +249,14 @@ func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge s // correct Azure VM. Returns the Azure join attributes func verifyVMIdentity( ctx context.Context, - cfg *azureRegisterConfig, + cfg *AzureJoinConfig, accessToken, subscriptionID, vmID string, requestStart time.Time, logger *slog.Logger, ) (joinAttrs *workloadidentityv1pb.JoinAttrsAzure, err error) { - tokenClaims, err := cfg.verify(ctx, accessToken) + tokenClaims, err := cfg.Verify(ctx, accessToken) if err != nil { return nil, trace.Wrap(err) } @@ -270,11 +275,11 @@ func verifyVMIdentity( expectedClaims := jwt.Expected{ Issuer: expectedIssuer, - Audience: jwt.Audience{azureAccessTokenAudience}, + Audience: jwt.Audience{AzureAccessTokenAudience}, Time: requestStart, } - if err := tokenClaims.AsJWTClaims().Validate(expectedClaims); err != nil { + if err := tokenClaims.asJWTClaims().Validate(expectedClaims); err != nil { return nil, trace.Wrap(err) } @@ -299,7 +304,7 @@ func verifyVMIdentity( Token: accessToken, ExpiresOn: tokenClaims.GetExpiration(), }) - vmClient, err := cfg.getVMClient(subscriptionID, tokenCredential) + vmClient, err := cfg.GetVMClient(subscriptionID, tokenCredential) if err != nil { return nil, trace.Wrap(err) } @@ -339,7 +344,7 @@ func verifyVMIdentity( } // claimsToIdentifiers returns the vm identifiers from the provided claims. -func claimsToIdentifiers(tokenClaims *accessTokenClaims) (subscriptionID, resourceGroupID string, err error) { +func claimsToIdentifiers(tokenClaims *AccessTokenClaims) (subscriptionID, resourceGroupID string, err error) { // xms_az_rid claim is omitted when the VM is assigned a System-Assigned Identity. // The xms_mirid claim should be used instead. rid := cmp.Or(tokenClaims.AzureResourceID, tokenClaims.ManangedIdentityResourceID) @@ -369,6 +374,7 @@ func checkAzureAllowRules(vmID string, attrs *workloadidentityv1pb.JoinAttrsAzur } return trace.AccessDenied("instance %v did not match any allow rules in token %v", vmID, token.GetName()) } + func azureResourceGroupIsAllowed(allowedResourceGroups []string, vmResourceGroup string) bool { if len(allowedResourceGroups) == 0 { return true @@ -395,126 +401,78 @@ func azureJoinToAttrs(subscriptionID, resourceGroupID string) *workloadidentityv } } -func (a *Server) checkAzureRequest( - ctx context.Context, - challenge string, - req *proto.RegisterUsingAzureMethodRequest, - cfg *azureRegisterConfig, -) (*workloadidentityv1pb.JoinAttrsAzure, error) { - requestStart := a.clock.Now() - tokenName := req.RegisterUsingTokenRequest.Token - provisionToken, err := a.GetToken(ctx, tokenName) - if err != nil { - return nil, trace.Wrap(err) - } - if provisionToken.GetJoinMethod() != types.JoinMethodAzure { - return nil, trace.AccessDenied("this token does not support the Azure join method") - } - token, ok := provisionToken.(*types.ProvisionTokenV2) - if !ok { - return nil, trace.BadParameter("azure join method only supports ProvisionTokenV2, '%T' was provided", provisionToken) - } - - subID, vmID, err := parseAndVerifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities) - if err != nil { - return nil, trace.Wrap(err) - } - - attrs, err := verifyVMIdentity(ctx, cfg, req.AccessToken, subID, vmID, requestStart, a.logger) - if err != nil { - return nil, trace.Wrap(err) - } - if err := checkAzureAllowRules(vmID, attrs, token); err != nil { - return attrs, trace.Wrap(err) - } - - return attrs, nil +// CheckAzureRequestParams holds all parameters for [CheckAzureRequest]. +type CheckAzureRequestParams struct { + // AzureJoinConfig holds configurable options for Azure joining. + AzureJoinConfig *AzureJoinConfig + // Token is the token used for the incoming request. + Token *types.ProvisionTokenV2 + // Challenge is the challenge that was issued. + Challenge string + // AttestedData is the Azure attested data that was returned by the joining + // client. It must include the challenge as a nonce. + AttestedData []byte + // AccessToken is the Azure access token that was returned by the joining client + AccessToken string + // Logger will be used for logging. + Logger *slog.Logger + // Clock overrides the system time. + Clock clockwork.Clock } -func generateAzureChallenge() (string, error) { - challenge, err := joinutils.GenerateChallenge(base64.RawURLEncoding, 24) - return challenge, trace.Wrap(err) +func (p *CheckAzureRequestParams) checkAndSetDefaults() error { + switch { + case p.AzureJoinConfig == nil: + p.AzureJoinConfig = &AzureJoinConfig{} + case p.Token == nil: + return trace.BadParameter("Token is required") + case len(p.Challenge) == 0: + return trace.BadParameter("Challenge is required") + case len(p.AttestedData) == 0: + return trace.BadParameter("AttestedData is required") + case len(p.AccessToken) == 0: + return trace.BadParameter("AccessToken is required") + case p.Logger == nil: + return trace.BadParameter("Logger is required") + case p.Clock == nil: + p.Clock = clockwork.NewRealClock() + } + return trace.Wrap(p.AzureJoinConfig.checkAndSetDefaults()) } -// RegisterUsingAzureMethodWithOpts registers the caller using the Azure join method -// and returns signed certs to join the cluster. -// -// The caller must provide a ChallengeResponseFunc which returns a -// *proto.RegisterUsingAzureMethodRequest with a signed attested data document -// including the challenge as a nonce. -func (a *Server) RegisterUsingAzureMethodWithOpts( - ctx context.Context, - challengeResponse client.RegisterAzureChallengeResponseFunc, - opts ...azureRegisterOption, -) (certs *proto.Certs, err error) { - var provisionToken types.ProvisionToken - var joinRequest *types.RegisterUsingTokenRequest - defer func() { - // Emit a log message and audit event on join failure. - if err != nil { - a.handleJoinFailure(ctx, err, provisionToken, nil, joinRequest) - } - }() - - if legacyjoin.Disabled() { - return nil, trace.Wrap(legacyjoin.ErrDisabled) - } - - cfg := &azureRegisterConfig{} - for _, opt := range opts { - opt(cfg) - } - if err := cfg.CheckAndSetDefaults(ctx); err != nil { +// CheckAzureRequest checks an azure join request by verifying the VMs claims +// and checking that they match an allow rule from the join token. +func CheckAzureRequest(ctx context.Context, params CheckAzureRequestParams) (*workloadidentityv1pb.JoinAttrsAzure, error) { + if err := params.checkAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } + requestStart := params.Clock.Now() - challenge, err := generateAzureChallenge() - if err != nil { - return nil, trace.Wrap(err) - } - req, err := challengeResponse(challenge) + subID, vmID, err := parseAndVerifyAttestedData( + ctx, + params.AttestedData, + params.Challenge, + params.AzureJoinConfig.CertificateAuthorities, + ) if err != nil { return nil, trace.Wrap(err) } - joinRequest = req.RegisterUsingTokenRequest - if err := req.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - - provisionToken, err = a.checkTokenJoinRequestCommon(ctx, req.RegisterUsingTokenRequest) + attrs, err := verifyVMIdentity(ctx, params.AzureJoinConfig, params.AccessToken, subID, vmID, requestStart, params.Logger) if err != nil { return nil, trace.Wrap(err) } - - joinAttrs, err := a.checkAzureRequest(ctx, challenge, req, cfg) - if err != nil { - return nil, trace.Wrap(err) + if err := checkAzureAllowRules(vmID, attrs, params.Token); err != nil { + return attrs, trace.Wrap(err) } - if req.RegisterUsingTokenRequest.Role == types.RoleBot { - params := makeBotCertsParams(req.RegisterUsingTokenRequest, nil /*rawClaims*/, &workloadidentityv1pb.JoinAttrs{ - Azure: joinAttrs, - }) - certs, _, err := a.GenerateBotCertsForJoin(ctx, provisionToken, params) - return certs, trace.Wrap(err) - } - params := makeHostCertsParams(req.RegisterUsingTokenRequest, nil /*rawClaims*/) - certs, err = a.GenerateHostCertsForJoin(ctx, provisionToken, params) - return certs, trace.Wrap(err) + return attrs, nil } -// RegisterUsingAzureMethod registers the caller using the Azure join method -// and returns signed certs to join the cluster. -// -// The caller must provide a ChallengeResponseFunc which returns a -// *proto.RegisterUsingAzureMethodRequest with a signed attested data document -// including the challenge as a nonce. -func (a *Server) RegisterUsingAzureMethod( - ctx context.Context, - challengeResponse client.RegisterAzureChallengeResponseFunc, -) (certs *proto.Certs, err error) { - return a.RegisterUsingAzureMethodWithOpts(ctx, challengeResponse) +// GenerateAzureChallenge generates a challenge for the Azure join method. +func GenerateAzureChallenge() (string, error) { + challenge, err := joinutils.GenerateChallenge(base64.RawURLEncoding, 24) + return challenge, trace.Wrap(err) } // fixAzureSigningAlgorithm fixes a mismatch between the object IDs of the diff --git a/lib/auth/azure_certs.go b/lib/join/azurejoin/azure_certs.go similarity index 99% rename from lib/auth/azure_certs.go rename to lib/join/azurejoin/azure_certs.go index cea969aadc87b..a0893eed9c7e0 100644 --- a/lib/auth/azure_certs.go +++ b/lib/join/azurejoin/azure_certs.go @@ -16,7 +16,7 @@ * along with this program. If not, see . */ -package auth +package azurejoin import ( "context" diff --git a/lib/auth/azure_certs_test.go b/lib/join/azurejoin/azure_certs_test.go similarity index 92% rename from lib/auth/azure_certs_test.go rename to lib/join/azurejoin/azure_certs_test.go index abeb857f0edb7..7ebb1acccf004 100644 --- a/lib/auth/azure_certs_test.go +++ b/lib/join/azurejoin/azure_certs_test.go @@ -16,14 +16,12 @@ * along with this program. If not, see . */ -package auth_test +package azurejoin import ( "testing" "github.com/stretchr/testify/require" - - "github.com/gravitational/teleport/lib/auth" ) func TestIsAllowedDomain(t *testing.T) { @@ -71,7 +69,7 @@ func TestIsAllowedDomain(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - tc.assert(t, auth.IsAllowedDomain(tc.url, allowedDomains)) + tc.assert(t, isAllowedDomain(tc.url, allowedDomains)) }) } } diff --git a/lib/auth/join_azure_test.go b/lib/join/azurejoin/join_azure_test.go similarity index 77% rename from lib/auth/join_azure_test.go rename to lib/join/azurejoin/join_azure_test.go index deaa014dcc3fa..2133e93eb7b65 100644 --- a/lib/auth/join_azure_test.go +++ b/lib/join/azurejoin/join_azure_test.go @@ -16,10 +16,11 @@ * along with this program. If not, see . */ -package auth_test +package azurejoin_test import ( "context" + "crypto" "crypto/x509" "encoding/base64" "encoding/json" @@ -36,13 +37,17 @@ import ( "github.com/stretchr/testify/require" "github.com/zitadel/oidc/v3/pkg/oidc" - "github.com/gravitational/teleport/api/client/proto" + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authtest" - "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/auth/machineid/machineidv1" + "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/fixtures" + "github.com/gravitational/teleport/lib/join/azurejoin" + "github.com/gravitational/teleport/lib/join/joinclient" + "github.com/gravitational/teleport/lib/tlsca" ) type mockAzureVMClient struct { @@ -67,7 +72,7 @@ func (m *mockAzureVMClient) GetByVMID(_ context.Context, vmID string) (*azure.Vi return nil, trace.NotFound("no vm with id %q", vmID) } -func makeVMClientGetter(clients map[string]*mockAzureVMClient) auth.AzureVMClientGetter { +func makeVMClientGetter(clients map[string]*mockAzureVMClient) azurejoin.VMClientGetter { return func(subscriptionID string, _ *azure.StaticCredential) (azure.VirtualMachinesClient, error) { if client, ok := clients[subscriptionID]; ok { return client, nil @@ -76,18 +81,6 @@ func makeVMClientGetter(clients map[string]*mockAzureVMClient) auth.AzureVMClien } } -type azureChallengeResponseConfig struct { - Challenge string -} - -type azureChallengeResponseOption func(*azureChallengeResponseConfig) - -func withChallengeAzure(challenge string) azureChallengeResponseOption { - return func(cfg *azureChallengeResponseConfig) { - cfg.Challenge = challenge - } -} - func vmssResourceID(subscription, resourceGroup, name string) string { return resourceID("Microsoft.Compute/virtualMachineScaleSets", subscription, resourceGroup, name) } @@ -107,8 +100,8 @@ func resourceID(resourceType, subscription, resourceGroup, name string) string { ) } -func mockVerifyToken(err error) auth.AzureVerifyTokenFunc { - return func(_ context.Context, rawToken string) (*auth.AccessTokenClaims, error) { +func mockVerifyToken(err error) azurejoin.AzureVerifyTokenFunc { + return func(_ context.Context, rawToken string) (*azurejoin.AccessTokenClaims, error) { if err != nil { return nil, err } @@ -116,7 +109,7 @@ func mockVerifyToken(err error) auth.AzureVerifyTokenFunc { if err != nil { return nil, trace.Wrap(err) } - var claims auth.AccessTokenClaims + var claims azurejoin.AccessTokenClaims if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { return nil, trace.Wrap(err) } @@ -132,10 +125,10 @@ func makeToken(managedIdentityResourceID, azureResourceID string, issueTime time if err != nil { return "", trace.Wrap(err) } - claims := auth.AccessTokenClaims{ + claims := azurejoin.AccessTokenClaims{ TokenClaims: oidc.TokenClaims{ Issuer: "https://sts.windows.net/test-tenant-id/", - Audience: []string{auth.AzureAccessTokenAudience}, + Audience: []string{azurejoin.AzureAccessTokenAudience}, Subject: "test", IssuedAt: oidc.FromTime(issueTime), NotBefore: oidc.FromTime(issueTime), @@ -154,14 +147,19 @@ func makeToken(managedIdentityResourceID, azureResourceID string, issueTime time return raw, nil } -func TestAuth_RegisterUsingAzureMethod(t *testing.T) { +func TestJoinAzure(t *testing.T) { t.Parallel() - ctx := t.Context() - p := newAuthSuite(t) - a := p.a - sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair() + server, err := authtest.NewTestServer(authtest.ServerConfig{ + Auth: authtest.AuthServerConfig{ + Dir: t.TempDir(), + }, + }) + require.NoError(t, err) + a := server.Auth() + + nopClient, err := server.NewClient(authtest.TestNop()) require.NoError(t, err) tlsConfig, err := fixtures.LocalTLSConfig() @@ -171,9 +169,6 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) require.NoError(t, err) - tlsPublicKey, err := authtest.PrivateKeyToPublicKeyTLS(sshPrivateKey) - require.NoError(t, err) - isAccessDenied := func(t require.TestingT, err error, _ ...any) { require.True(t, trace.IsAccessDenied(err), "expected Access Denied error, actual error: %v", err) } @@ -197,10 +192,10 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { tokenVMID string requestTokenName string tokenSpec types.ProvisionTokenSpecV2 - challengeResponseOptions []azureChallengeResponseOption + overrideReturnedChallenge string challengeResponseErr error certs []*x509.Certificate - verify auth.AzureVerifyTokenFunc + verify azurejoin.AzureVerifyTokenFunc assertError require.ErrorAssertionFunc }{ { @@ -343,12 +338,10 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { }, JoinMethod: types.JoinMethodAzure, }, - challengeResponseOptions: []azureChallengeResponseOption{ - withChallengeAzure("wrong-challenge"), - }, - verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, - assertError: isAccessDenied, + overrideReturnedChallenge: "wrong-challenge", + verify: mockVerifyToken(nil), + certs: []*x509.Certificate{tlsConfig.Certificate}, + assertError: isAccessDenied, }, { name: "invalid signature", @@ -458,6 +451,27 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + vmClient := &mockAzureVMClient{ + vms: map[string]*azure.VirtualMachine{ + defaultVMResourceID: { + ID: defaultVMResourceID, + Name: defaultVMName, + Subscription: defaultSubscription, + ResourceGroup: defaultResourceGroup, + VMID: defaultVMID, + }, + }, + } + getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{ + defaultSubscription: vmClient, + }) + + a.SetAzureJoinConfig(&azurejoin.AzureJoinConfig{ + CertificateAuthorities: tc.certs, + Verify: tc.verify, + GetVMClient: getVMClient, + }) + token, err := types.NewProvisionTokenFromSpec( "test-token", time.Now().Add(time.Minute), @@ -476,74 +490,65 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { accessToken, err := makeToken(mirID, "", a.GetClock().Now()) require.NoError(t, err) - vmClient := &mockAzureVMClient{ - vms: map[string]*azure.VirtualMachine{ - defaultVMResourceID: { - ID: defaultVMResourceID, - Name: defaultVMName, - Subscription: defaultSubscription, - ResourceGroup: defaultResourceGroup, - VMID: defaultVMID, - }, - }, + imdsClient := &fakeIMDSClient{ + accessToken: accessToken, + accessTokenErr: tc.challengeResponseErr, + overrideChallenge: tc.overrideReturnedChallenge, + signingCert: tlsConfig.Certificate, + signingKey: pkey, + subscription: tc.tokenSubscription, + vmID: tc.tokenVMID, } - getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{ - defaultSubscription: vmClient, - }) - - _, err = a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { - cfg := &azureChallengeResponseConfig{Challenge: challenge} - for _, opt := range tc.challengeResponseOptions { - opt(cfg) - } - - ad := auth.AttestedData{ - Nonce: cfg.Challenge, - SubscriptionID: tc.tokenSubscription, - ID: tc.tokenVMID, - } - adBytes, err := json.Marshal(&ad) - require.NoError(t, err) - s, err := pkcs7.NewSignedData(adBytes) - require.NoError(t, err) - require.NoError(t, s.AddSigner(tlsConfig.Certificate, pkey, pkcs7.SignerInfoConfig{})) - signature, err := s.Finish() - require.NoError(t, err) - signedAD := auth.SignedAttestedData{ - Encoding: "pkcs7", - Signature: base64.StdEncoding.EncodeToString(signature), - } - signedADBytes, err := json.Marshal(&signedAD) - require.NoError(t, err) - req := &proto.RegisterUsingAzureMethodRequest{ - RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ - Token: tc.requestTokenName, - HostID: "test-node", - Role: types.RoleNode, - PublicSSHKey: sshPublicKey, - PublicTLSKey: tlsPublicKey, + t.Run("legacy", func(t *testing.T) { + _, err = joinclient.LegacyJoin(ctx, joinclient.JoinParams{ + Token: tc.requestTokenName, + JoinMethod: types.JoinMethodAzure, + ID: state.IdentityID{ + Role: types.RoleInstance, + HostUUID: "testuuid", }, - AttestedData: signedADBytes, - AccessToken: accessToken, - } - return req, tc.challengeResponseErr - }, auth.WithAzureCerts(tc.certs), auth.WithAzureVerifyFunc(tc.verify), auth.WithAzureVMClientGetter(getVMClient)) - tc.assertError(t, err) + AuthClient: nopClient, + AzureParams: joinclient.AzureParams{ + ClientID: tc.tokenVMID, + IMDSClient: imdsClient, + }, + }) + tc.assertError(t, err) + }) + t.Run("new", func(t *testing.T) { + _, err = joinclient.Join(ctx, joinclient.JoinParams{ + Token: tc.requestTokenName, + ID: state.IdentityID{ + Role: types.RoleInstance, + }, + AuthClient: nopClient, + AzureParams: joinclient.AzureParams{ + ClientID: tc.tokenVMID, + IMDSClient: imdsClient, + }, + }) + tc.assertError(t, err) + }) }) } } // TestAuth_RegisterUsingAzureClaims tests the Azure join method by verifying // joining VMs by the token claims rather than from the Azure VM API. -func TestAuth_RegisterUsingAzureClaims(t *testing.T) { +func TestJoinAzureClaims(t *testing.T) { t.Parallel() - ctx := t.Context() - p := newAuthSuite(t) - a := p.a - sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair() + server, err := authtest.NewTestServer(authtest.ServerConfig{ + Auth: authtest.AuthServerConfig{ + Dir: t.TempDir(), + }, + }) + require.NoError(t, err) + a := server.Auth() + + nopClient, err := server.NewClient(authtest.TestNop()) require.NoError(t, err) tlsConfig, err := fixtures.LocalTLSConfig() @@ -553,9 +558,6 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) require.NoError(t, err) - tlsPublicKey, err := authtest.PrivateKeyToPublicKeyTLS(sshPrivateKey) - require.NoError(t, err) - isAccessDenied := func(t require.TestingT, err error, _ ...any) { require.True(t, trace.IsAccessDenied(err), "expected Access Denied error, actual error: %v", err) } @@ -565,6 +567,17 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { defaultIdentityName := "test-id" defaultVMID := "my-vm-id" + botName := "botty" + _, err = machineidv1.UpsertBot(ctx, a, &machineidv1pb.Bot{ + Kind: types.KindBot, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: botName, + }, + Spec: &machineidv1pb.BotSpec{}, + }, a.GetClock().Now(), "") + require.NoError(t, err) + tests := []struct { name string tokenManagedIdentityResourceID string @@ -573,10 +586,9 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { tokenVMID string requestTokenName string tokenSpec types.ProvisionTokenSpecV2 - challengeResponseOptions []azureChallengeResponseOption challengeResponseErr error certs []*x509.Certificate - verify auth.AzureVerifyTokenFunc + verify azurejoin.AzureVerifyTokenFunc assertError require.ErrorAssertionFunc }{ { @@ -807,45 +819,149 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { defaultSubscription: vmClient, }) - _, err = a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { - cfg := &azureChallengeResponseConfig{Challenge: challenge} - for _, opt := range tc.challengeResponseOptions { - opt(cfg) - } + a.SetAzureJoinConfig(&azurejoin.AzureJoinConfig{ + CertificateAuthorities: tc.certs, + Verify: tc.verify, + GetVMClient: getVMClient, + }) - ad := auth.AttestedData{ - Nonce: cfg.Challenge, - SubscriptionID: tc.tokenSubscription, - ID: tc.tokenVMID, - } - adBytes, err := json.Marshal(&ad) - require.NoError(t, err) - s, err := pkcs7.NewSignedData(adBytes) - require.NoError(t, err) - require.NoError(t, s.AddSigner(tlsConfig.Certificate, pkey, pkcs7.SignerInfoConfig{})) - signature, err := s.Finish() - require.NoError(t, err) - signedAD := auth.SignedAttestedData{ - Encoding: "pkcs7", - Signature: base64.StdEncoding.EncodeToString(signature), - } - signedADBytes, err := json.Marshal(&signedAD) + imdsClient := &fakeIMDSClient{ + accessToken: accessToken, + accessTokenErr: tc.challengeResponseErr, + signingCert: tlsConfig.Certificate, + signingKey: pkey, + subscription: tc.tokenSubscription, + vmID: tc.tokenVMID, + } + + t.Run("legacy", func(t *testing.T) { + // Try to join via the legacy join service. + _, err = joinclient.LegacyJoin(ctx, joinclient.JoinParams{ + Token: tc.requestTokenName, + JoinMethod: types.JoinMethodAzure, + ID: state.IdentityID{ + Role: types.RoleInstance, + HostUUID: "testuuid", + }, + AuthClient: nopClient, + AzureParams: joinclient.AzureParams{ + ClientID: tc.tokenVMID, + IMDSClient: imdsClient, + }, + }) + tc.assertError(t, err) + }) + t.Run("new", func(t *testing.T) { + // Try to join via the new join service. + _, err = joinclient.Join(ctx, joinclient.JoinParams{ + Token: tc.requestTokenName, + ID: state.IdentityID{ + Role: types.RoleInstance, + }, + AuthClient: nopClient, + AzureParams: joinclient.AzureParams{ + ClientID: tc.tokenVMID, + IMDSClient: imdsClient, + }, + }) + tc.assertError(t, err) + }) + t.Run("bot", func(t *testing.T) { + // Try to join as a bot. + tokenSpec := tc.tokenSpec + tokenSpec.BotName = botName + tokenSpec.Roles = types.SystemRoles{types.RoleBot} + token, err := types.NewProvisionTokenFromSpec( + "test-token", + time.Now().Add(time.Minute), + tokenSpec) require.NoError(t, err) + require.NoError(t, a.UpsertToken(ctx, token)) - req := &proto.RegisterUsingAzureMethodRequest{ - RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ - Token: tc.requestTokenName, - HostID: "test-node", - Role: types.RoleNode, - PublicSSHKey: sshPublicKey, - PublicTLSKey: tlsPublicKey, + result, err := joinclient.Join(ctx, joinclient.JoinParams{ + Token: tc.requestTokenName, + ID: state.IdentityID{ + Role: types.RoleBot, }, - AttestedData: signedADBytes, - AccessToken: accessToken, + AuthClient: nopClient, + AzureParams: joinclient.AzureParams{ + ClientID: tc.tokenVMID, + IMDSClient: imdsClient, + }, + }) + tc.assertError(t, err) + if err != nil { + return } - return req, tc.challengeResponseErr - }, auth.WithAzureCerts(tc.certs), auth.WithAzureVerifyFunc(tc.verify), auth.WithAzureVMClientGetter(getVMClient)) - tc.assertError(t, err) + + cert, err := tlsca.ParseCertificatePEM(result.Certs.TLS) + require.NoError(t, err) + identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) + require.NoError(t, err) + + // Make sure the LoginIP was set on the identity. + require.NotEmpty(t, identity.LoginIP) + + // Make sure the JoinAttributes were set. + require.NotNil(t, identity.JoinAttributes) + require.NotNil(t, identity.JoinAttributes.Azure) + require.Equal(t, tc.tokenSubscription, identity.JoinAttributes.Azure.Subscription) + }) }) } } + +type fakeIMDSClient struct { + accessToken string + accessTokenErr error + + // overrideChallenge overrides the challenge/nonce included in attested data. + overrideChallenge string + signingCert *x509.Certificate + signingKey crypto.Signer + subscription string + vmID string +} + +func (c *fakeIMDSClient) IsAvailable(_ context.Context) bool { + return true +} + +func (c *fakeIMDSClient) GetAttestedData(_ context.Context, nonce string) ([]byte, error) { + ad := azurejoin.AttestedData{ + Nonce: nonce, + SubscriptionID: c.subscription, + ID: c.vmID, + } + if c.overrideChallenge != "" { + ad.Nonce = c.overrideChallenge + } + adBytes, err := json.Marshal(&ad) + if err != nil { + return nil, trace.Wrap(err) + } + s, err := pkcs7.NewSignedData(adBytes) + if err != nil { + return nil, trace.Wrap(err) + } + if err := s.AddSigner(c.signingCert, c.signingKey, pkcs7.SignerInfoConfig{}); err != nil { + return nil, trace.Wrap(err) + } + signature, err := s.Finish() + if err != nil { + return nil, trace.Wrap(err) + } + signedAD := azurejoin.SignedAttestedData{ + Encoding: "pkcs7", + Signature: base64.StdEncoding.EncodeToString(signature), + } + signedADBytes, err := json.Marshal(&signedAD) + if err != nil { + return nil, trace.Wrap(err) + } + return signedADBytes, nil +} + +func (c *fakeIMDSClient) GetAccessToken(_ context.Context, clientID string) (string, error) { + return c.accessToken, trace.Wrap(c.accessTokenErr) +} diff --git a/lib/join/joinclient/join.go b/lib/join/joinclient/join.go index 1374b9b5a030d..7dc1f80cace3b 100644 --- a/lib/join/joinclient/join.go +++ b/lib/join/joinclient/join.go @@ -204,6 +204,7 @@ func joinWithClient(ctx context.Context, params JoinParams, client *joinv1.Clien case types.JoinMethodUnspecified: // leave joinMethodPtr nil to let the server pick based on the token case types.JoinMethodToken, + types.JoinMethodAzure, types.JoinMethodAzureDevops, types.JoinMethodBitbucket, types.JoinMethodBoundKeypair, @@ -306,6 +307,8 @@ func joinWithMethod( switch types.JoinMethod(method) { case types.JoinMethodToken: return tokenJoin(stream, clientParams) + case types.JoinMethodAzure: + return azureJoin(ctx, stream, joinParams, clientParams) case types.JoinMethodAzureDevops: if joinParams.IDToken == "" { joinParams.IDToken, err = azuredevops.NewIDTokenSource(os.Getenv).GetIDToken(ctx) diff --git a/lib/join/joinclient/join_azure.go b/lib/join/joinclient/join_azure.go new file mode 100644 index 0000000000000..9030c5a465fc9 --- /dev/null +++ b/lib/join/joinclient/join_azure.go @@ -0,0 +1,77 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package joinclient + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/cloud/imds/azure" + "github.com/gravitational/teleport/lib/join/internal/messages" +) + +func azureJoin(ctx context.Context, stream messages.ClientStream, joinParams JoinParams, clientParams messages.ClientParams) (messages.Response, error) { + // The Azure join method involves the following messages: + // + // client->server ClientInit + // client<-server ServerInit + // client->server AzureInit + // client<-server AzureChallenge + // client->server AzureChallengeSolution + // client<-server Result + // + // At this point the ServerInit messages has already been received, what's + // left is to send the AzureInit message, handle the challenge-response, and + // receive and return the final result. + if err := stream.Send(&messages.AzureInit{ + ClientParams: clientParams, + }); err != nil { + return nil, trace.Wrap(err, "sending AzureInit") + } + + challenge, err := messages.RecvResponse[*messages.AzureChallenge](stream) + if err != nil { + return nil, trace.Wrap(err, "receiving AzureChallenge") + } + + imds := joinParams.AzureParams.IMDSClient + if imds == nil { + imds = azure.NewInstanceMetadataClient() + } + if !imds.IsAvailable(ctx) { + return nil, trace.AccessDenied("could not reach instance metadata. Is Teleport running on an Azure VM?") + } + ad, err := imds.GetAttestedData(ctx, challenge.Challenge) + if err != nil { + return nil, trace.Wrap(err) + } + accessToken, err := imds.GetAccessToken(ctx, joinParams.AzureParams.ClientID) + if err != nil { + return nil, trace.Wrap(err) + } + + if err := stream.Send(&messages.AzureChallengeSolution{ + AttestedData: ad, + AccessToken: accessToken, + }); err != nil { + return nil, trace.Wrap(err, "sending AzureChallengeSolution") + } + + result, err := stream.Recv() + return result, trace.Wrap(err, "receiving join result") +} diff --git a/lib/join/server.go b/lib/join/server.go index 08b3d9701449b..5d297b03d6ec1 100644 --- a/lib/join/server.go +++ b/lib/join/server.go @@ -46,6 +46,7 @@ import ( "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/join/azuredevops" + "github.com/gravitational/teleport/lib/join/azurejoin" "github.com/gravitational/teleport/lib/join/bitbucket" "github.com/gravitational/teleport/lib/join/circleci" "github.com/gravitational/teleport/lib/join/ec2join" @@ -104,6 +105,7 @@ type AuthService interface { GetK8sOIDCValidator() *kubetoken.KubernetesOIDCTokenValidator GetSpaceliftIDTokenValidator() spacelift.Validator GetTerraformIDTokenValidator() terraformcloud.Validator + GetAzureJoinConfig() *azurejoin.AzureJoinConfig services.Presence } @@ -296,6 +298,8 @@ func (s *Server) handleJoinMethod( joinMethod types.JoinMethod, ) (messages.Response, error) { switch joinMethod { + case types.JoinMethodAzure: + return s.handleAzureJoin(stream, authCtx, clientInit, token) case types.JoinMethodAzureDevops: return s.handleOIDCJoin(stream, authCtx, clientInit, token, s.validateAzureDevopsToken) case types.JoinMethodBitbucket: diff --git a/lib/join/server_azure.go b/lib/join/server_azure.go new file mode 100644 index 0000000000000..db00379b5a37a --- /dev/null +++ b/lib/join/server_azure.go @@ -0,0 +1,108 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package join + +import ( + "github.com/gravitational/trace" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/join/azurejoin" + "github.com/gravitational/teleport/lib/join/internal/authz" + "github.com/gravitational/teleport/lib/join/internal/messages" + "github.com/gravitational/teleport/lib/join/provision" +) + +// handleAzureJoin handles join attempts for the Azure join method. +// +// The Azure join method involves the following messages: +// +// client->server ClientInit +// client<-server ServerInit +// client->server AzureInit +// client<-server AzureChallenge +// client->server AzureChallengeSolution +// client<-server Result +// +// At this point the ServerInit message has already been sent, what's left is +// to receive the AzureInit message, handle the challenge-response, and send the +// final result if everything checks out. +func (s *Server) handleAzureJoin( + stream messages.ServerStream, + authCtx *authz.Context, + clientInit *messages.ClientInit, + token provision.Token, +) (messages.Response, error) { + // Receive the AzureInit message from the client. + azureInit, err := messages.RecvRequest[*messages.AzureInit](stream) + if err != nil { + return nil, trace.Wrap(err, "receiving AzureInit message") + } + // Set any diagnostic info from the ClientParams. + setDiagnosticClientParams(stream.Diagnostic(), &azureInit.ClientParams) + + // Generate and send the challenge. + challenge, err := azurejoin.GenerateAzureChallenge() + if err != nil { + return nil, trace.Wrap(err, "generating challenge") + } + if err := stream.Send(&messages.AzureChallenge{ + Challenge: challenge, + }); err != nil { + return nil, trace.Wrap(err, "sending AzureChallenge") + } + + // Receive the solution from the client. + solution, err := messages.RecvRequest[*messages.AzureChallengeSolution](stream) + if err != nil { + return nil, trace.Wrap(err, "receiving AzureChallengeSolution") + } + + ptv2, ok := token.(*types.ProvisionTokenV2) + if !ok { + return nil, trace.BadParameter("Azure join method only supports ProvisionTokenV2, got %T", token) + } + + // Verify the client's idenitty and make sure it matches an allow rule in the provision token. + claims, err := azurejoin.CheckAzureRequest(stream.Context(), azurejoin.CheckAzureRequestParams{ + AzureJoinConfig: s.cfg.AuthService.GetAzureJoinConfig(), + Token: ptv2, + Challenge: challenge, + AttestedData: solution.AttestedData, + AccessToken: solution.AccessToken, + Logger: log, + Clock: s.cfg.AuthService.GetClock(), + }) + if err != nil { + return nil, trace.Wrap(err, "checking Azure challenge solution") + } + + // Make and return the final result message. + result, err := s.makeResult( + stream.Context(), + stream.Diagnostic(), + authCtx, + clientInit, + &azureInit.ClientParams, + token, + nil, // rawClaims + &workloadidentityv1pb.JoinAttrs{ + Azure: claims, + }, + ) + return result, trace.Wrap(err) +} From 4b78af1e23d73585c0c38eab95e704d8ce01c4b6 Mon Sep 17 00:00:00 2001 From: Nic Klaassen Date: Wed, 19 Nov 2025 17:24:17 -0800 Subject: [PATCH 2/3] fix comment typo --- lib/join/server_azure.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/join/server_azure.go b/lib/join/server_azure.go index db00379b5a37a..744d70d0dd911 100644 --- a/lib/join/server_azure.go +++ b/lib/join/server_azure.go @@ -77,7 +77,7 @@ func (s *Server) handleAzureJoin( return nil, trace.BadParameter("Azure join method only supports ProvisionTokenV2, got %T", token) } - // Verify the client's idenitty and make sure it matches an allow rule in the provision token. + // Verify the client's identity and make sure it matches an allow rule in the provision token. claims, err := azurejoin.CheckAzureRequest(stream.Context(), azurejoin.CheckAzureRequestParams{ AzureJoinConfig: s.cfg.AuthService.GetAzureJoinConfig(), Token: ptv2, From 72dacf7ed1591f32d6f37e268103b592b4a5d6cc Mon Sep 17 00:00:00 2001 From: Nic Klaassen Date: Wed, 19 Nov 2025 17:22:21 -0800 Subject: [PATCH 3/3] have client send intermediate CA --- lib/auth/join/join.go | 4 + lib/join/azurejoin/azure.go | 100 +++++++++----- lib/join/azurejoin/azure_certs.go | 8 +- lib/join/azurejoin/join_azure_test.go | 184 ++++++++++++++++++++------ lib/join/joinclient/join_azure.go | 62 ++++++++- lib/join/server_azure.go | 10 ++ 6 files changed, 286 insertions(+), 82 deletions(-) diff --git a/lib/auth/join/join.go b/lib/auth/join/join.go index 91c7ad23d6c14..d5fae28d72cf1 100644 --- a/lib/auth/join/join.go +++ b/lib/auth/join/join.go @@ -80,6 +80,10 @@ type AzureParams struct { ClientID string // IMDSClient overrides the client used to fetch data from Azure IMDS. IMDSClient AzureIMDSClient + // IssuerHTTPClient, if set, overrides the default HTTP client used to + // fetch the intermediate CA which issued the attested data document + // signing certificate. Only used when joining via the new join service. + IssuerHTTPClient utils.HTTPDoClient } // AzureIMDSClient is a client to Azure's IMDS. diff --git a/lib/join/azurejoin/azure.go b/lib/join/azurejoin/azure.go index bc5f22c224cfc..7d2f5e7957901 100644 --- a/lib/join/azurejoin/azure.go +++ b/lib/join/azurejoin/azure.go @@ -21,7 +21,6 @@ import ( "context" "crypto/x509" "encoding/base64" - "encoding/pem" "log/slog" "net/url" "slices" @@ -41,6 +40,7 @@ import ( workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud/azure" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/join/joinutils" liboidc "github.com/gravitational/teleport/lib/oidc" "github.com/gravitational/teleport/lib/utils" @@ -138,6 +138,10 @@ type AzureJoinConfig struct { Verify AzureVerifyTokenFunc // GetVMClient, if set, overrides the function used to get Azure VM clients. GetVMClient VMClientGetter + // IssuerHTTPClient, if set, overrides the default HTTP client used to + // fetch the intermediate CA which issued the attested data document + // signing certificate. + IssuerHTTPClient utils.HTTPDoClient } func azureVerifyFuncFromOIDCVerifier(clientID string) AzureVerifyTokenFunc { @@ -186,55 +190,58 @@ func (cfg *AzureJoinConfig) checkAndSetDefaults() error { return client, trace.Wrap(err) } } + if cfg.IssuerHTTPClient == nil { + httpClient, err := defaults.HTTPClient() + if err != nil { + return trace.Wrap(err) + } + cfg.IssuerHTTPClient = httpClient + } return nil } // parseAndVeryAttestedData verifies that an attested data document was signed // by Azure. If verification is successful, it returns the ID of the VM that // produced the document. -func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge string, certs []*x509.Certificate) (subscriptionID, vmID string, err error) { - var signedAD SignedAttestedData - if err := utils.FastUnmarshal(adBytes, &signedAD); err != nil { - return "", "", trace.Wrap(err) - } - if signedAD.Encoding != "pkcs7" { - return "", "", trace.AccessDenied("unsupported signature type: %v", signedAD.Encoding) - } - - sigPEM := "-----BEGIN PKCS7-----\n" + signedAD.Signature + "\n-----END PKCS7-----" - sigBER, _ := pem.Decode([]byte(sigPEM)) - if sigBER == nil { - return "", "", trace.AccessDenied("unable to decode attested data document") - } - - p7, err := pkcs7.Parse(sigBER.Bytes) +func parseAndVerifyAttestedData( + ctx context.Context, + cfg *AzureJoinConfig, + adBytes []byte, + intermediates []byte, + challenge string, +) (subscriptionID, vmID string, err error) { + ad, p7, err := ParseAttestedData(adBytes) if err != nil { return "", "", trace.Wrap(err) } - var ad AttestedData - if err := utils.FastUnmarshal(p7.Content, &ad); err != nil { - return "", "", trace.Wrap(err) - } if ad.Nonce != challenge { return "", "", trace.AccessDenied("challenge is missing or does not match") } - if len(p7.Certificates) == 0 { return "", "", trace.AccessDenied("no certificates for signature") } fixAzureSigningAlgorithm(p7) - // Azure only sends the leaf cert, so we have to fetch the intermediate. - intermediate, err := getAzureIssuerCert(ctx, p7.Certificates[0]) - if err != nil { - return "", "", trace.Wrap(err) - } - if intermediate != nil { - p7.Certificates = append(p7.Certificates, intermediate) + if len(intermediates) > 0 { + // Client explicitly sent intermediate CAs, included them. + intermediates, err := x509.ParseCertificates(intermediates) + if err != nil { + return "", "", trace.Wrap(err, "parsing intermediate certificates sent by client") + } + p7.Certificates = append(p7.Certificates, intermediates...) + } else { + // Client did not send intermediates, fetch them from Azure. + intermediate, err := getAzureIssuerCert(ctx, p7.Certificates[0], cfg.IssuerHTTPClient) + if err != nil { + return "", "", trace.Wrap(err) + } + if intermediate != nil { + p7.Certificates = append(p7.Certificates, intermediate) + } } pool := x509.NewCertPool() - for _, cert := range certs { + for _, cert := range cfg.CertificateAuthorities { pool.AddCert(cert) } @@ -245,6 +252,33 @@ func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge s return ad.SubscriptionID, ad.ID, nil } +// ParseAttestedData returns the parsed VM attested data and a PKCS7 structure +// which can be used to verify the signature. +func ParseAttestedData(adBytes []byte) (*AttestedData, *pkcs7.PKCS7, error) { + var signedAD SignedAttestedData + if err := utils.FastUnmarshal(adBytes, &signedAD); err != nil { + return nil, nil, trace.Wrap(err) + } + if signedAD.Encoding != "pkcs7" { + return nil, nil, trace.AccessDenied("unsupported signature type: %v", signedAD.Encoding) + } + + sigDER, err := base64.StdEncoding.DecodeString(signedAD.Signature) + if err != nil { + return nil, nil, trace.Wrap(err, "decoding attested data document from base64") + } + + p7, err := pkcs7.Parse(sigDER) + if err != nil { + return nil, nil, trace.Wrap(err) + } + var ad AttestedData + if err := utils.FastUnmarshal(p7.Content, &ad); err != nil { + return nil, nil, trace.Wrap(err) + } + return &ad, p7, nil +} + // verifyVMIdentity verifies that the provided access token came from the // correct Azure VM. Returns the Azure join attributes func verifyVMIdentity( @@ -412,6 +446,9 @@ type CheckAzureRequestParams struct { // AttestedData is the Azure attested data that was returned by the joining // client. It must include the challenge as a nonce. AttestedData []byte + // Intermediate encodes the intermediate CAs that issued the leaf certificate + // used to sign the attested data document, in x509 DER format. + Intermediate []byte // AccessToken is the Azure access token that was returned by the joining client AccessToken string // Logger will be used for logging. @@ -450,9 +487,10 @@ func CheckAzureRequest(ctx context.Context, params CheckAzureRequestParams) (*wo subID, vmID, err := parseAndVerifyAttestedData( ctx, + params.AzureJoinConfig, params.AttestedData, + params.Intermediate, params.Challenge, - params.AzureJoinConfig.CertificateAuthorities, ) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/join/azurejoin/azure_certs.go b/lib/join/azurejoin/azure_certs.go index a0893eed9c7e0..431e7bcecb35b 100644 --- a/lib/join/azurejoin/azure_certs.go +++ b/lib/join/azurejoin/azure_certs.go @@ -28,7 +28,6 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/utils" ) @@ -49,16 +48,11 @@ func isAllowedDomain(cn string, domains []string) bool { } // getAzureIssuerCert fetches a x509 certificate's issuing certificate. -func getAzureIssuerCert(ctx context.Context, cert *x509.Certificate) (*x509.Certificate, error) { +func getAzureIssuerCert(ctx context.Context, cert *x509.Certificate, httpClient utils.HTTPDoClient) (*x509.Certificate, error) { if len(cert.IssuingCertificateURL) == 0 { return nil, nil } - httpClient, err := defaults.HTTPClient() - if err != nil { - return nil, trace.Wrap(err) - } - // Azure sends only one issuing cert. issuerURL := cert.IssuingCertificateURL[0] commonName := cert.Subject.CommonName diff --git a/lib/join/azurejoin/join_azure_test.go b/lib/join/azurejoin/join_azure_test.go index 2133e93eb7b65..9752766d8d769 100644 --- a/lib/join/azurejoin/join_azure_test.go +++ b/lib/join/azurejoin/join_azure_test.go @@ -19,13 +19,19 @@ package azurejoin_test import ( + "bytes" "context" "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "crypto/x509" + "crypto/x509/pkix" "encoding/base64" "encoding/json" - "encoding/pem" "fmt" + "io" + "net/http" "testing" "time" @@ -44,7 +50,6 @@ import ( "github.com/gravitational/teleport/lib/auth/machineid/machineidv1" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/cloud/azure" - "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/join/azurejoin" "github.com/gravitational/teleport/lib/join/joinclient" "github.com/gravitational/teleport/lib/tlsca" @@ -162,12 +167,14 @@ func TestJoinAzure(t *testing.T) { nopClient, err := server.NewClient(authtest.TestNop()) require.NoError(t, err) - tlsConfig, err := fixtures.LocalTLSConfig() - require.NoError(t, err) - - block, _ := pem.Decode(fixtures.LocalhostKey) - pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + caChain := newFakeAzureCAChain(t) + httpClient := newFakeAzureIssuerHTTPClient(caChain.intermediateCertDER) + instanceKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) + instanceCert := caChain.issueLeafCert(t, + instanceKey.Public(), + "instance.metadata.azure.com", + "http://www.microsoft.com/pkiops/certs/testcert.crt") isAccessDenied := func(t require.TestingT, err error, _ ...any) { require.True(t, trace.IsAccessDenied(err), "expected Access Denied error, actual error: %v", err) @@ -216,7 +223,7 @@ func TestJoinAzure(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, { @@ -237,7 +244,7 @@ func TestJoinAzure(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, { @@ -257,7 +264,7 @@ func TestJoinAzure(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -277,7 +284,7 @@ func TestJoinAzure(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, challengeResponseErr: trace.BadParameter("test error"), assertError: isBadParameter, }, @@ -298,7 +305,7 @@ func TestJoinAzure(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -319,7 +326,7 @@ func TestJoinAzure(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -340,7 +347,7 @@ func TestJoinAzure(t *testing.T) { }, overrideReturnedChallenge: "wrong-challenge", verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -381,7 +388,7 @@ func TestJoinAzure(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -402,7 +409,7 @@ func TestJoinAzure(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -423,7 +430,7 @@ func TestJoinAzure(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, { @@ -444,7 +451,7 @@ func TestJoinAzure(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, } @@ -470,6 +477,7 @@ func TestJoinAzure(t *testing.T) { CertificateAuthorities: tc.certs, Verify: tc.verify, GetVMClient: getVMClient, + IssuerHTTPClient: httpClient, }) token, err := types.NewProvisionTokenFromSpec( @@ -494,8 +502,8 @@ func TestJoinAzure(t *testing.T) { accessToken: accessToken, accessTokenErr: tc.challengeResponseErr, overrideChallenge: tc.overrideReturnedChallenge, - signingCert: tlsConfig.Certificate, - signingKey: pkey, + signingCert: instanceCert, + signingKey: instanceKey, subscription: tc.tokenSubscription, vmID: tc.tokenVMID, } @@ -524,8 +532,9 @@ func TestJoinAzure(t *testing.T) { }, AuthClient: nopClient, AzureParams: joinclient.AzureParams{ - ClientID: tc.tokenVMID, - IMDSClient: imdsClient, + ClientID: tc.tokenVMID, + IMDSClient: imdsClient, + IssuerHTTPClient: httpClient, }, }) tc.assertError(t, err) @@ -551,12 +560,14 @@ func TestJoinAzureClaims(t *testing.T) { nopClient, err := server.NewClient(authtest.TestNop()) require.NoError(t, err) - tlsConfig, err := fixtures.LocalTLSConfig() - require.NoError(t, err) - - block, _ := pem.Decode(fixtures.LocalhostKey) - pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + caChain := newFakeAzureCAChain(t) + httpClient := newFakeAzureIssuerHTTPClient(caChain.intermediateCertDER) + instanceKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) + instanceCert := caChain.issueLeafCert(t, + instanceKey.Public(), + "instance.metadata.azure.com", + "http://www.microsoft.com/pkiops/certs/testcert.crt") isAccessDenied := func(t require.TestingT, err error, _ ...any) { require.True(t, trace.IsAccessDenied(err), "expected Access Denied error, actual error: %v", err) @@ -610,7 +621,7 @@ func TestJoinAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, { @@ -632,7 +643,7 @@ func TestJoinAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -654,7 +665,7 @@ func TestJoinAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -677,7 +688,7 @@ func TestJoinAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, { @@ -700,7 +711,7 @@ func TestJoinAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -723,7 +734,7 @@ func TestJoinAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -746,7 +757,7 @@ func TestJoinAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, { @@ -768,7 +779,7 @@ func TestJoinAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -790,7 +801,7 @@ func TestJoinAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, } @@ -823,13 +834,14 @@ func TestJoinAzureClaims(t *testing.T) { CertificateAuthorities: tc.certs, Verify: tc.verify, GetVMClient: getVMClient, + IssuerHTTPClient: httpClient, }) imdsClient := &fakeIMDSClient{ accessToken: accessToken, accessTokenErr: tc.challengeResponseErr, - signingCert: tlsConfig.Certificate, - signingKey: pkey, + signingCert: instanceCert, + signingKey: instanceKey, subscription: tc.tokenSubscription, vmID: tc.tokenVMID, } @@ -860,8 +872,9 @@ func TestJoinAzureClaims(t *testing.T) { }, AuthClient: nopClient, AzureParams: joinclient.AzureParams{ - ClientID: tc.tokenVMID, - IMDSClient: imdsClient, + ClientID: tc.tokenVMID, + IMDSClient: imdsClient, + IssuerHTTPClient: httpClient, }, }) tc.assertError(t, err) @@ -885,8 +898,9 @@ func TestJoinAzureClaims(t *testing.T) { }, AuthClient: nopClient, AzureParams: joinclient.AzureParams{ - ClientID: tc.tokenVMID, - IMDSClient: imdsClient, + ClientID: tc.tokenVMID, + IMDSClient: imdsClient, + IssuerHTTPClient: httpClient, }, }) tc.assertError(t, err) @@ -965,3 +979,89 @@ func (c *fakeIMDSClient) GetAttestedData(_ context.Context, nonce string) ([]byt func (c *fakeIMDSClient) GetAccessToken(_ context.Context, clientID string) (string, error) { return c.accessToken, trace.Wrap(c.accessTokenErr) } + +type fakeAzureCAChain struct { + intermediateKey crypto.Signer + intermediateCert *x509.Certificate + intermediateCertDER []byte + rootCert *x509.Certificate +} + +func newFakeAzureCAChain(t *testing.T) *fakeAzureCAChain { + rootKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + rootCertTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "test root CA", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageCertSign, + IsCA: true, + BasicConstraintsValid: true, + } + rootCertDER, err := x509.CreateCertificate(rand.Reader, rootCertTemplate, rootCertTemplate, rootKey.Public(), rootKey) + require.NoError(t, err) + rootCert, err := x509.ParseCertificate(rootCertDER) + require.NoError(t, err) + + intermediateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + intermediateCertTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "test intermediate CA", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageCertSign, + IsCA: true, + BasicConstraintsValid: true, + } + intermediateCertDER, err := x509.CreateCertificate(rand.Reader, intermediateCertTemplate, rootCert, intermediateKey.Public(), rootKey) + require.NoError(t, err) + intermediateCert, err := x509.ParseCertificate(intermediateCertDER) + require.NoError(t, err) + + return &fakeAzureCAChain{ + intermediateKey: intermediateKey, + intermediateCert: intermediateCert, + intermediateCertDER: intermediateCertDER, + rootCert: rootCert, + } +} + +func (c *fakeAzureCAChain) issueLeafCert(t *testing.T, pub crypto.PublicKey, commonName, issuerURL string) *x509.Certificate { + leafCertTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: commonName, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + IssuingCertificateURL: []string{issuerURL}, + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + } + leafCertDER, err := x509.CreateCertificate(rand.Reader, leafCertTemplate, c.intermediateCert, pub, c.intermediateKey) + require.NoError(t, err) + leafCert, err := x509.ParseCertificate(leafCertDER) + require.NoError(t, err) + return leafCert +} + +type fakeAzureIssuerHTTPClient struct { + issuerCertDER []byte + called int +} + +func newFakeAzureIssuerHTTPClient(issuerCertDER []byte) *fakeAzureIssuerHTTPClient { + return &fakeAzureIssuerHTTPClient{ + issuerCertDER: issuerCertDER, + } +} +func (c *fakeAzureIssuerHTTPClient) Do(req *http.Request) (*http.Response, error) { + c.called++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(c.issuerCertDER)), + }, nil +} diff --git a/lib/join/joinclient/join_azure.go b/lib/join/joinclient/join_azure.go index 9030c5a465fc9..a359d68e76f36 100644 --- a/lib/join/joinclient/join_azure.go +++ b/lib/join/joinclient/join_azure.go @@ -18,11 +18,17 @@ package joinclient import ( "context" + "crypto/x509" + "net/http" "github.com/gravitational/trace" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/cloud/imds/azure" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/join/azurejoin" "github.com/gravitational/teleport/lib/join/internal/messages" + "github.com/gravitational/teleport/lib/utils" ) func azureJoin(ctx context.Context, stream messages.ClientStream, joinParams JoinParams, clientParams messages.ClientParams) (messages.Response, error) { @@ -58,15 +64,20 @@ func azureJoin(ctx context.Context, stream messages.ClientStream, joinParams Joi } ad, err := imds.GetAttestedData(ctx, challenge.Challenge) if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "getting attested data document") + } + intermediate, err := getIntermediate(ctx, joinParams.AzureParams.IssuerHTTPClient, ad) + if err != nil { + return nil, trace.Wrap(err, "getting intermediate CA for attested data") } accessToken, err := imds.GetAccessToken(ctx, joinParams.AzureParams.ClientID) if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "getting access token") } if err := stream.Send(&messages.AzureChallengeSolution{ AttestedData: ad, + Intermediate: intermediate, AccessToken: accessToken, }); err != nil { return nil, trace.Wrap(err, "sending AzureChallengeSolution") @@ -75,3 +86,50 @@ func azureJoin(ctx context.Context, stream messages.ClientStream, joinParams Joi result, err := stream.Recv() return result, trace.Wrap(err, "receiving join result") } + +func getIntermediate(ctx context.Context, httpClient utils.HTTPDoClient, ad []byte) ([]byte, error) { + _, p7, err := azurejoin.ParseAttestedData(ad) + if err != nil { + return nil, trace.Wrap(err, "parsing attested data document") + } + if len(p7.Certificates) == 0 { + return nil, trace.Errorf("attested data signature has no certificates") + } + leafCert := p7.Certificates[0] + if len(leafCert.IssuingCertificateURL) == 0 { + return nil, trace.Errorf("attested data leaf certificate has no issuing certificate URL") + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, leafCert.IssuingCertificateURL[0], nil /*body*/) + if err != nil { + return nil, trace.Wrap(err, "building HTTP request") + } + + if httpClient == nil { + httpClient, err = defaults.HTTPClient() + if err != nil { + return nil, trace.Wrap(err) + } + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, trace.Wrap(err, "fetching intermediate certificate") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, trace.Errorf("failed to fetch intermediate cert, got HTTP status code %d", resp.StatusCode) + } + + body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize) + if err != nil { + return nil, trace.Wrap(err, "reading HTTP response body") + } + + if _, err := x509.ParseCertificates(body); err != nil { + return nil, trace.Wrap(err, "parsing intermediate certificate") + } + + return body, nil +} diff --git a/lib/join/server_azure.go b/lib/join/server_azure.go index 744d70d0dd911..c0687eb8753cd 100644 --- a/lib/join/server_azure.go +++ b/lib/join/server_azure.go @@ -72,6 +72,15 @@ func (s *Server) handleAzureJoin( return nil, trace.Wrap(err, "receiving AzureChallengeSolution") } + switch { + case len(solution.AttestedData) == 0: + return nil, trace.BadParameter("client did not send attested data") + case len(solution.Intermediate) == 0: + return nil, trace.BadParameter("client did not send intermediate CAs") + case len(solution.AccessToken) == 0: + return nil, trace.BadParameter("client did not send access token") + } + ptv2, ok := token.(*types.ProvisionTokenV2) if !ok { return nil, trace.BadParameter("Azure join method only supports ProvisionTokenV2, got %T", token) @@ -83,6 +92,7 @@ func (s *Server) handleAzureJoin( Token: ptv2, Challenge: challenge, AttestedData: solution.AttestedData, + Intermediate: solution.Intermediate, AccessToken: solution.AccessToken, Logger: log, Clock: s.cfg.AuthService.GetClock(),