diff --git a/lib/auth/auth.go b/lib/auth/auth.go
index 74c44d6827cfe..92cba647e179c 100644
--- a/lib/auth/auth.go
+++ b/lib/auth/auth.go
@@ -114,6 +114,7 @@ import (
iterstream "github.com/gravitational/teleport/lib/itertools/stream"
"github.com/gravitational/teleport/lib/join"
"github.com/gravitational/teleport/lib/join/azuredevops"
+ "github.com/gravitational/teleport/lib/join/azurejoin"
"github.com/gravitational/teleport/lib/join/bitbucket"
joinboundkeypair "github.com/gravitational/teleport/lib/join/boundkeypair"
"github.com/gravitational/teleport/lib/join/circleci"
@@ -1339,6 +1340,9 @@ type Server struct {
// override the implementation used in tests.
env0IDTokenValidator join.Env0TokenValidator
+ // azureJoinConfig holds configuration for the Azure join method.
+ azureJoinConfig *azurejoin.AzureJoinConfig
+
// loadAllCAs tells tsh to load the host CAs for all clusters when trying to ssh into a node.
loadAllCAs bool
diff --git a/lib/auth/bot_test.go b/lib/auth/bot_test.go
index dc6bbf3ccac0e..14160fb79aeb2 100644
--- a/lib/auth/bot_test.go
+++ b/lib/auth/bot_test.go
@@ -23,10 +23,6 @@ import (
"context"
"crypto"
"crypto/tls"
- "crypto/x509"
- "encoding/base64"
- "encoding/json"
- "encoding/pem"
"errors"
"fmt"
"io"
@@ -36,10 +32,8 @@ import (
"text/template"
"time"
- "github.com/digitorus/pkcs7"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
@@ -58,16 +52,13 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/utils/keys"
- "github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/auth/authtest"
"github.com/gravitational/teleport/lib/auth/machineid/machineidv1"
"github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/auth/testauthority"
- "github.com/gravitational/teleport/lib/cloud/azure"
libevents "github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/eventstest"
- "github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/join/iamjoin"
"github.com/gravitational/teleport/lib/join/joinclient"
"github.com/gravitational/teleport/lib/kube/token"
@@ -741,89 +732,6 @@ func TestRegisterBot_RemoteAddr(t *testing.T) {
require.NoError(t, err)
checkCertLoginIP(t, certs.TLS, remoteAddr)
})
-
- t.Run("Azure method", func(t *testing.T) {
- subID := uuid.NewString()
- resourceGroup := "rg"
- rsID := vmResourceID(subID, resourceGroup, "test-vm")
- vmID := "vmID"
-
- accessToken, err := makeToken(rsID, "", a.GetClock().Now())
- require.NoError(t, err)
-
- // add token to auth server
- azureTokenName := "azure-test-token"
- azureToken, err := types.NewProvisionTokenFromSpec(
- azureTokenName,
- time.Now().Add(time.Minute),
- types.ProvisionTokenSpecV2{
- Roles: []types.SystemRole{types.RoleBot},
- Azure: &types.ProvisionTokenSpecV2Azure{Allow: []*types.ProvisionTokenSpecV2Azure_Rule{{Subscription: subID}}},
- BotName: botName,
- JoinMethod: types.JoinMethodAzure,
- })
- require.NoError(t, err)
- require.NoError(t, a.UpsertToken(ctx, azureToken))
-
- vmClient := &mockAzureVMClient{
- vms: map[string]*azure.VirtualMachine{
- rsID: {
- ID: rsID,
- Name: "test-vm",
- Subscription: subID,
- ResourceGroup: resourceGroup,
- VMID: vmID,
- },
- },
- }
- getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{
- subID: vmClient,
- })
-
- tlsConfig, err := fixtures.LocalTLSConfig()
- require.NoError(t, err)
-
- block, _ := pem.Decode(fixtures.LocalhostKey)
- pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
- require.NoError(t, err)
-
- certs, err := a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
- ad := auth.AttestedData{
- Nonce: challenge,
- SubscriptionID: subID,
- ID: vmID,
- }
- adBytes, err := json.Marshal(&ad)
- require.NoError(t, err)
- s, err := pkcs7.NewSignedData(adBytes)
- require.NoError(t, err)
- require.NoError(t, s.AddSigner(tlsConfig.Certificate, pkey, pkcs7.SignerInfoConfig{}))
- signature, err := s.Finish()
- require.NoError(t, err)
- signedAD := auth.SignedAttestedData{
- Encoding: "pkcs7",
- Signature: base64.StdEncoding.EncodeToString(signature),
- }
- signedADBytes, err := json.Marshal(&signedAD)
- require.NoError(t, err)
-
- req := &proto.RegisterUsingAzureMethodRequest{
- RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{
- Token: azureTokenName,
- HostID: "test-node",
- Role: types.RoleBot,
- PublicSSHKey: sshPubKey,
- PublicTLSKey: tlsPubKey,
- RemoteAddr: remoteAddr,
- },
- AttestedData: signedADBytes,
- AccessToken: accessToken,
- }
- return req, nil
- }, auth.WithAzureCerts([]*x509.Certificate{tlsConfig.Certificate}), auth.WithAzureVerifyFunc(mockVerifyToken(nil)), auth.WithAzureVMClientGetter(getVMClient))
- require.NoError(t, err)
- checkCertLoginIP(t, certs.TLS, remoteAddr)
- })
}
func responseFromAWSIdentity(id iamjoin.AWSIdentity) string {
diff --git a/lib/auth/export_test.go b/lib/auth/export_test.go
index a2eb46bc3a737..3d0dee5ce4e09 100644
--- a/lib/auth/export_test.go
+++ b/lib/auth/export_test.go
@@ -70,8 +70,6 @@ const (
MaxUserAgentLen = maxUserAgentLen
ForwardedTag = forwardedTag
-
- AzureAccessTokenAudience = azureAccessTokenAudience
)
var (
@@ -307,10 +305,6 @@ func TrimUserAgent(userAgent string) string {
return trimUserAgent(userAgent)
}
-func IsAllowedDomain(cn string, domains []string) bool {
- return isAllowedDomain(cn, domains)
-}
-
func GetSnowflakeJWTParams(ctx context.Context, accountName, userName string, publicKey []byte) (string, string) {
return getSnowflakeJWTParams(ctx, accountName, userName, publicKey)
}
@@ -336,31 +330,6 @@ func CheckHeaders(headers http.Header, challenge string, clock clockwork.Clock)
}
type GitHubManager = githubManager
-type AttestedData = attestedData
-type SignedAttestedData = signedAttestedData
-type AzureRegisterOption = azureRegisterOption
-type AzureRegisterConfig = azureRegisterConfig
-type AzureVMClientGetter = vmClientGetter
-type AzureVerifyTokenFunc = azureVerifyTokenFunc
-type AccessTokenClaims = accessTokenClaims
-
-func WithAzureCerts(certs []*x509.Certificate) AzureRegisterOption {
- return func(cfg *AzureRegisterConfig) {
- cfg.certificateAuthorities = certs
- }
-}
-
-func WithAzureVerifyFunc(verify azureVerifyTokenFunc) AzureRegisterOption {
- return func(cfg *AzureRegisterConfig) {
- cfg.verify = verify
- }
-}
-
-func WithAzureVMClientGetter(getVMClient vmClientGetter) AzureRegisterOption {
- return func(cfg *AzureRegisterConfig) {
- cfg.getVMClient = getVMClient
- }
-}
func (s *TLSServer) GRPCServer() *GRPCServer {
return s.grpcServer
diff --git a/lib/auth/join/join.go b/lib/auth/join/join.go
index bd286b53928cb..d5fae28d72cf1 100644
--- a/lib/auth/join/join.go
+++ b/lib/auth/join/join.go
@@ -78,6 +78,19 @@ type AzureParams struct {
// ClientID is the client ID of the managed identity for Teleport to assume
// when authenticating a node.
ClientID string
+ // IMDSClient overrides the client used to fetch data from Azure IMDS.
+ IMDSClient AzureIMDSClient
+ // IssuerHTTPClient, if set, overrides the default HTTP client used to
+ // fetch the intermediate CA which issued the attested data document
+ // signing certificate. Only used when joining via the new join service.
+ IssuerHTTPClient utils.HTTPDoClient
+}
+
+// AzureIMDSClient is a client to Azure's IMDS.
+type AzureIMDSClient interface {
+ IsAvailable(context.Context) bool
+ GetAttestedData(ctx context.Context, nonce string) ([]byte, error)
+ GetAccessToken(ctx context.Context, clientID string) (string, error)
}
// GitlabParams is the parameters specific to the gitlab join method.
@@ -875,7 +888,10 @@ func registerUsingAzureMethod(
ctx context.Context, client joinServiceClient, token string, hostKeys *newHostKeys, params RegisterParams,
) (*proto.Certs, error) {
certs, err := client.RegisterUsingAzureMethod(ctx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
- imds := azure.NewInstanceMetadataClient()
+ imds := params.AzureParams.IMDSClient
+ if imds == nil {
+ imds = azure.NewInstanceMetadataClient()
+ }
if !imds.IsAvailable(ctx) {
return nil, trace.AccessDenied("could not reach instance metadata. Is Teleport running on an Azure VM?")
}
diff --git a/lib/auth/join_azure_legacy.go b/lib/auth/join_azure_legacy.go
new file mode 100644
index 0000000000000..9032f084e6aeb
--- /dev/null
+++ b/lib/auth/join_azure_legacy.go
@@ -0,0 +1,119 @@
+/*
+ * Teleport
+ * Copyright (C) 2023 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package auth
+
+import (
+ "context"
+
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/api/client"
+ "github.com/gravitational/teleport/api/client/proto"
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/join/azurejoin"
+ "github.com/gravitational/teleport/lib/join/legacyjoin"
+)
+
+// RegisterUsingAzureMethod registers the caller using the Azure join method
+// and returns signed certs to join the cluster.
+//
+// The caller must provide a ChallengeResponseFunc which returns a
+// *proto.RegisterUsingAzureMethodRequest with a signed attested data document
+// including the challenge as a nonce.
+//
+// TODO(nklaassen): DELETE IN 20 when removing the legacy join service.
+func (a *Server) RegisterUsingAzureMethod(
+ ctx context.Context,
+ challengeResponse client.RegisterAzureChallengeResponseFunc,
+) (certs *proto.Certs, err error) {
+ var provisionToken types.ProvisionToken
+ var joinRequest *types.RegisterUsingTokenRequest
+ defer func() {
+ // Emit a log message and audit event on join failure.
+ if err != nil {
+ a.handleJoinFailure(ctx, err, provisionToken, nil, joinRequest)
+ }
+ }()
+
+ if legacyjoin.Disabled() {
+ return nil, trace.Wrap(legacyjoin.ErrDisabled)
+ }
+
+ challenge, err := azurejoin.GenerateAzureChallenge()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ req, err := challengeResponse(challenge)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ joinRequest = req.RegisterUsingTokenRequest
+
+ if err := req.CheckAndSetDefaults(); err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ provisionToken, err = a.checkTokenJoinRequestCommon(ctx, req.RegisterUsingTokenRequest)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ if provisionToken.GetJoinMethod() != types.JoinMethodAzure {
+ return nil, trace.AccessDenied("this token does not support the Azure join method")
+ }
+
+ ptv2, ok := provisionToken.(*types.ProvisionTokenV2)
+ if !ok {
+ return nil, trace.Wrap(err, "Azure join method only supports ProvisionTokenV2, got %T", provisionToken)
+ }
+
+ joinAttrs, err := azurejoin.CheckAzureRequest(ctx, azurejoin.CheckAzureRequestParams{
+ AzureJoinConfig: a.GetAzureJoinConfig(),
+ Token: ptv2,
+ Challenge: challenge,
+ AttestedData: req.AttestedData,
+ AccessToken: req.AccessToken,
+ Logger: a.logger,
+ Clock: a.GetClock(),
+ })
+ if err != nil {
+ return nil, trace.Wrap(err, "checking Azure challenge response")
+ }
+
+ if req.RegisterUsingTokenRequest.Role == types.RoleBot {
+ params := makeBotCertsParams(req.RegisterUsingTokenRequest, nil /*rawClaims*/, &workloadidentityv1pb.JoinAttrs{
+ Azure: joinAttrs,
+ })
+ certs, _, err := a.GenerateBotCertsForJoin(ctx, provisionToken, params)
+ return certs, trace.Wrap(err)
+ }
+ params := makeHostCertsParams(req.RegisterUsingTokenRequest, nil /*rawClaims*/)
+ certs, err = a.GenerateHostCertsForJoin(ctx, provisionToken, params)
+ return certs, trace.Wrap(err)
+}
+
+// GetAzureJoinConfig gets configuration options for azure joining.
+func (a *Server) GetAzureJoinConfig() *azurejoin.AzureJoinConfig {
+ return a.azureJoinConfig
+}
+
+// SetAzureJoinConfig sets configuration options for azure joining.
+func (a *Server) SetAzureJoinConfig(c *azurejoin.AzureJoinConfig) {
+ a.azureJoinConfig = c
+}
diff --git a/lib/auth/join_azure.go b/lib/join/azurejoin/azure.go
similarity index 62%
rename from lib/auth/join_azure.go
rename to lib/join/azurejoin/azure.go
index ff54227baf252..7d2f5e7957901 100644
--- a/lib/auth/join_azure.go
+++ b/lib/join/azurejoin/azure.go
@@ -1,29 +1,26 @@
-/*
- * Teleport
- * Copyright (C) 2023 Gravitational, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- */
-
-package auth
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package azurejoin
import (
"cmp"
"context"
"crypto/x509"
"encoding/base64"
- "encoding/pem"
"log/slog"
"net/url"
"slices"
@@ -37,21 +34,20 @@ import (
"github.com/digitorus/pkcs7"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
"github.com/zitadel/oidc/v3/pkg/oidc"
- "github.com/gravitational/teleport/api/client"
- "github.com/gravitational/teleport/api/client/proto"
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/cloud/azure"
+ "github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/join/joinutils"
- "github.com/gravitational/teleport/lib/join/legacyjoin"
liboidc "github.com/gravitational/teleport/lib/oidc"
"github.com/gravitational/teleport/lib/utils"
)
const (
- azureAccessTokenAudience = "https://management.azure.com/"
+ AzureAccessTokenAudience = "https://management.azure.com/"
// azureUserAgent specifies the Azure User-Agent identification for telemetry.
azureUserAgent = "teleport"
@@ -64,7 +60,8 @@ const (
// Structs for unmarshaling attested data. Schema can be found at
// https://learn.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service?tabs=linux#response-2
-type signedAttestedData struct {
+// SignedAttestedData models the response from the attested data IMDS endpoint.
+type SignedAttestedData struct {
Encoding string `json:"encoding"`
Signature string `json:"signature"`
}
@@ -80,7 +77,8 @@ type timestamp struct {
ExpiresOn string `json:"expiresOn"`
}
-type attestedData struct {
+// AttestedData models the decoded data returned from the attested data IMDS endpoint.
+type AttestedData struct {
LicenseType string `json:"licenseType"`
Nonce string `json:"nonce"`
Plan plan `json:"plan"`
@@ -90,7 +88,8 @@ type attestedData struct {
SKU string `json:"sku"`
}
-type accessTokenClaims struct {
+// AccessTokenClaims models the claims in an Azure access token.
+type AccessTokenClaims struct {
oidc.TokenClaims
TenantID string `json:"tid"`
Version string `json:"ver"`
@@ -111,7 +110,7 @@ type accessTokenClaims struct {
AzureResourceID string `json:"xms_az_rid"`
}
-func (c *accessTokenClaims) AsJWTClaims() jwt.Claims {
+func (c *AccessTokenClaims) asJWTClaims() jwt.Claims {
return jwt.Claims{
Issuer: c.Issuer,
Subject: c.Subject,
@@ -123,24 +122,36 @@ func (c *accessTokenClaims) AsJWTClaims() jwt.Claims {
}
}
-type azureVerifyTokenFunc func(ctx context.Context, rawIDToken string) (*accessTokenClaims, error)
-
-type vmClientGetter func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error)
-
-type azureRegisterConfig struct {
- certificateAuthorities []*x509.Certificate
- verify azureVerifyTokenFunc
- getVMClient vmClientGetter
+// AzureVerifyTokenFunc is a function type that verifies an azure VM token.
+type AzureVerifyTokenFunc func(ctx context.Context, rawIDToken string) (*AccessTokenClaims, error)
+
+// VMClientGetter is a function type that returns an Azure VM client for a
+// given subscription authenticated with a given static token credential.
+type VMClientGetter func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error)
+
+// AzureJoinConfig holds configurable options for Azure joining.
+type AzureJoinConfig struct {
+ // CertificateAuthorities, if set, overrides the root certificate
+ // authorities used to verify VM attested data.
+ CertificateAuthorities []*x509.Certificate
+ // Verify, if set, overrides the function used to verify azure VM tokens.
+ Verify AzureVerifyTokenFunc
+ // GetVMClient, if set, overrides the function used to get Azure VM clients.
+ GetVMClient VMClientGetter
+ // IssuerHTTPClient, if set, overrides the default HTTP client used to
+ // fetch the intermediate CA which issued the attested data document
+ // signing certificate.
+ IssuerHTTPClient utils.HTTPDoClient
}
-func azureVerifyFuncFromOIDCVerifier(clientID string) azureVerifyTokenFunc {
- return func(ctx context.Context, rawIDToken string) (*accessTokenClaims, error) {
+func azureVerifyFuncFromOIDCVerifier(clientID string) AzureVerifyTokenFunc {
+ return func(ctx context.Context, rawIDToken string) (*AccessTokenClaims, error) {
token, err := jwt.ParseSigned(rawIDToken)
if err != nil {
return nil, trace.Wrap(err)
}
// Need to get the tenant ID before we verify so we can construct the issuer URL.
- var unverifiedClaims accessTokenClaims
+ var unverifiedClaims AccessTokenClaims
if err := token.UnsafeClaimsWithoutVerification(&unverifiedClaims); err != nil {
return nil, trace.Wrap(err)
}
@@ -148,24 +159,24 @@ func azureVerifyFuncFromOIDCVerifier(clientID string) azureVerifyTokenFunc {
if err != nil {
return nil, trace.Wrap(err)
}
- return liboidc.ValidateToken[*accessTokenClaims](ctx, issuer, clientID, rawIDToken)
+ return liboidc.ValidateToken[*AccessTokenClaims](ctx, issuer, clientID, rawIDToken)
}
}
-func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error {
- if cfg.verify == nil {
- cfg.verify = azureVerifyFuncFromOIDCVerifier(azureAccessTokenAudience)
+func (cfg *AzureJoinConfig) checkAndSetDefaults() error {
+ if cfg.Verify == nil {
+ cfg.Verify = azureVerifyFuncFromOIDCVerifier(AzureAccessTokenAudience)
}
- if cfg.certificateAuthorities == nil {
+ if cfg.CertificateAuthorities == nil {
certs, err := getAzureRootCerts()
if err != nil {
return trace.Wrap(err)
}
- cfg.certificateAuthorities = certs
+ cfg.CertificateAuthorities = certs
}
- if cfg.getVMClient == nil {
- cfg.getVMClient = func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) {
+ if cfg.GetVMClient == nil {
+ cfg.GetVMClient = func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) {
// The User-Agent is added for debugging purposes. It helps identify
// and isolate teleport traffic.
opts := &armpolicy.ClientOptions{
@@ -179,57 +190,58 @@ func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error {
return client, trace.Wrap(err)
}
}
+ if cfg.IssuerHTTPClient == nil {
+ httpClient, err := defaults.HTTPClient()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ cfg.IssuerHTTPClient = httpClient
+ }
return nil
}
-type azureRegisterOption func(cfg *azureRegisterConfig)
-
// parseAndVeryAttestedData verifies that an attested data document was signed
// by Azure. If verification is successful, it returns the ID of the VM that
// produced the document.
-func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge string, certs []*x509.Certificate) (subscriptionID, vmID string, err error) {
- var signedAD signedAttestedData
- if err := utils.FastUnmarshal(adBytes, &signedAD); err != nil {
- return "", "", trace.Wrap(err)
- }
- if signedAD.Encoding != "pkcs7" {
- return "", "", trace.AccessDenied("unsupported signature type: %v", signedAD.Encoding)
- }
-
- sigPEM := "-----BEGIN PKCS7-----\n" + signedAD.Signature + "\n-----END PKCS7-----"
- sigBER, _ := pem.Decode([]byte(sigPEM))
- if sigBER == nil {
- return "", "", trace.AccessDenied("unable to decode attested data document")
- }
-
- p7, err := pkcs7.Parse(sigBER.Bytes)
+func parseAndVerifyAttestedData(
+ ctx context.Context,
+ cfg *AzureJoinConfig,
+ adBytes []byte,
+ intermediates []byte,
+ challenge string,
+) (subscriptionID, vmID string, err error) {
+ ad, p7, err := ParseAttestedData(adBytes)
if err != nil {
return "", "", trace.Wrap(err)
}
- var ad attestedData
- if err := utils.FastUnmarshal(p7.Content, &ad); err != nil {
- return "", "", trace.Wrap(err)
- }
if ad.Nonce != challenge {
return "", "", trace.AccessDenied("challenge is missing or does not match")
}
-
if len(p7.Certificates) == 0 {
return "", "", trace.AccessDenied("no certificates for signature")
}
fixAzureSigningAlgorithm(p7)
- // Azure only sends the leaf cert, so we have to fetch the intermediate.
- intermediate, err := getAzureIssuerCert(ctx, p7.Certificates[0])
- if err != nil {
- return "", "", trace.Wrap(err)
- }
- if intermediate != nil {
- p7.Certificates = append(p7.Certificates, intermediate)
+ if len(intermediates) > 0 {
+ // Client explicitly sent intermediate CAs, included them.
+ intermediates, err := x509.ParseCertificates(intermediates)
+ if err != nil {
+ return "", "", trace.Wrap(err, "parsing intermediate certificates sent by client")
+ }
+ p7.Certificates = append(p7.Certificates, intermediates...)
+ } else {
+ // Client did not send intermediates, fetch them from Azure.
+ intermediate, err := getAzureIssuerCert(ctx, p7.Certificates[0], cfg.IssuerHTTPClient)
+ if err != nil {
+ return "", "", trace.Wrap(err)
+ }
+ if intermediate != nil {
+ p7.Certificates = append(p7.Certificates, intermediate)
+ }
}
pool := x509.NewCertPool()
- for _, cert := range certs {
+ for _, cert := range cfg.CertificateAuthorities {
pool.AddCert(cert)
}
@@ -240,18 +252,45 @@ func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge s
return ad.SubscriptionID, ad.ID, nil
}
+// ParseAttestedData returns the parsed VM attested data and a PKCS7 structure
+// which can be used to verify the signature.
+func ParseAttestedData(adBytes []byte) (*AttestedData, *pkcs7.PKCS7, error) {
+ var signedAD SignedAttestedData
+ if err := utils.FastUnmarshal(adBytes, &signedAD); err != nil {
+ return nil, nil, trace.Wrap(err)
+ }
+ if signedAD.Encoding != "pkcs7" {
+ return nil, nil, trace.AccessDenied("unsupported signature type: %v", signedAD.Encoding)
+ }
+
+ sigDER, err := base64.StdEncoding.DecodeString(signedAD.Signature)
+ if err != nil {
+ return nil, nil, trace.Wrap(err, "decoding attested data document from base64")
+ }
+
+ p7, err := pkcs7.Parse(sigDER)
+ if err != nil {
+ return nil, nil, trace.Wrap(err)
+ }
+ var ad AttestedData
+ if err := utils.FastUnmarshal(p7.Content, &ad); err != nil {
+ return nil, nil, trace.Wrap(err)
+ }
+ return &ad, p7, nil
+}
+
// verifyVMIdentity verifies that the provided access token came from the
// correct Azure VM. Returns the Azure join attributes
func verifyVMIdentity(
ctx context.Context,
- cfg *azureRegisterConfig,
+ cfg *AzureJoinConfig,
accessToken,
subscriptionID,
vmID string,
requestStart time.Time,
logger *slog.Logger,
) (joinAttrs *workloadidentityv1pb.JoinAttrsAzure, err error) {
- tokenClaims, err := cfg.verify(ctx, accessToken)
+ tokenClaims, err := cfg.Verify(ctx, accessToken)
if err != nil {
return nil, trace.Wrap(err)
}
@@ -270,11 +309,11 @@ func verifyVMIdentity(
expectedClaims := jwt.Expected{
Issuer: expectedIssuer,
- Audience: jwt.Audience{azureAccessTokenAudience},
+ Audience: jwt.Audience{AzureAccessTokenAudience},
Time: requestStart,
}
- if err := tokenClaims.AsJWTClaims().Validate(expectedClaims); err != nil {
+ if err := tokenClaims.asJWTClaims().Validate(expectedClaims); err != nil {
return nil, trace.Wrap(err)
}
@@ -299,7 +338,7 @@ func verifyVMIdentity(
Token: accessToken,
ExpiresOn: tokenClaims.GetExpiration(),
})
- vmClient, err := cfg.getVMClient(subscriptionID, tokenCredential)
+ vmClient, err := cfg.GetVMClient(subscriptionID, tokenCredential)
if err != nil {
return nil, trace.Wrap(err)
}
@@ -339,7 +378,7 @@ func verifyVMIdentity(
}
// claimsToIdentifiers returns the vm identifiers from the provided claims.
-func claimsToIdentifiers(tokenClaims *accessTokenClaims) (subscriptionID, resourceGroupID string, err error) {
+func claimsToIdentifiers(tokenClaims *AccessTokenClaims) (subscriptionID, resourceGroupID string, err error) {
// xms_az_rid claim is omitted when the VM is assigned a System-Assigned Identity.
// The xms_mirid claim should be used instead.
rid := cmp.Or(tokenClaims.AzureResourceID, tokenClaims.ManangedIdentityResourceID)
@@ -369,6 +408,7 @@ func checkAzureAllowRules(vmID string, attrs *workloadidentityv1pb.JoinAttrsAzur
}
return trace.AccessDenied("instance %v did not match any allow rules in token %v", vmID, token.GetName())
}
+
func azureResourceGroupIsAllowed(allowedResourceGroups []string, vmResourceGroup string) bool {
if len(allowedResourceGroups) == 0 {
return true
@@ -395,126 +435,82 @@ func azureJoinToAttrs(subscriptionID, resourceGroupID string) *workloadidentityv
}
}
-func (a *Server) checkAzureRequest(
- ctx context.Context,
- challenge string,
- req *proto.RegisterUsingAzureMethodRequest,
- cfg *azureRegisterConfig,
-) (*workloadidentityv1pb.JoinAttrsAzure, error) {
- requestStart := a.clock.Now()
- tokenName := req.RegisterUsingTokenRequest.Token
- provisionToken, err := a.GetToken(ctx, tokenName)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- if provisionToken.GetJoinMethod() != types.JoinMethodAzure {
- return nil, trace.AccessDenied("this token does not support the Azure join method")
- }
- token, ok := provisionToken.(*types.ProvisionTokenV2)
- if !ok {
- return nil, trace.BadParameter("azure join method only supports ProvisionTokenV2, '%T' was provided", provisionToken)
- }
-
- subID, vmID, err := parseAndVerifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities)
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- attrs, err := verifyVMIdentity(ctx, cfg, req.AccessToken, subID, vmID, requestStart, a.logger)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- if err := checkAzureAllowRules(vmID, attrs, token); err != nil {
- return attrs, trace.Wrap(err)
- }
-
- return attrs, nil
+// CheckAzureRequestParams holds all parameters for [CheckAzureRequest].
+type CheckAzureRequestParams struct {
+ // AzureJoinConfig holds configurable options for Azure joining.
+ AzureJoinConfig *AzureJoinConfig
+ // Token is the token used for the incoming request.
+ Token *types.ProvisionTokenV2
+ // Challenge is the challenge that was issued.
+ Challenge string
+ // AttestedData is the Azure attested data that was returned by the joining
+ // client. It must include the challenge as a nonce.
+ AttestedData []byte
+ // Intermediate encodes the intermediate CAs that issued the leaf certificate
+ // used to sign the attested data document, in x509 DER format.
+ Intermediate []byte
+ // AccessToken is the Azure access token that was returned by the joining client
+ AccessToken string
+ // Logger will be used for logging.
+ Logger *slog.Logger
+ // Clock overrides the system time.
+ Clock clockwork.Clock
}
-func generateAzureChallenge() (string, error) {
- challenge, err := joinutils.GenerateChallenge(base64.RawURLEncoding, 24)
- return challenge, trace.Wrap(err)
+func (p *CheckAzureRequestParams) checkAndSetDefaults() error {
+ switch {
+ case p.AzureJoinConfig == nil:
+ p.AzureJoinConfig = &AzureJoinConfig{}
+ case p.Token == nil:
+ return trace.BadParameter("Token is required")
+ case len(p.Challenge) == 0:
+ return trace.BadParameter("Challenge is required")
+ case len(p.AttestedData) == 0:
+ return trace.BadParameter("AttestedData is required")
+ case len(p.AccessToken) == 0:
+ return trace.BadParameter("AccessToken is required")
+ case p.Logger == nil:
+ return trace.BadParameter("Logger is required")
+ case p.Clock == nil:
+ p.Clock = clockwork.NewRealClock()
+ }
+ return trace.Wrap(p.AzureJoinConfig.checkAndSetDefaults())
}
-// RegisterUsingAzureMethodWithOpts registers the caller using the Azure join method
-// and returns signed certs to join the cluster.
-//
-// The caller must provide a ChallengeResponseFunc which returns a
-// *proto.RegisterUsingAzureMethodRequest with a signed attested data document
-// including the challenge as a nonce.
-func (a *Server) RegisterUsingAzureMethodWithOpts(
- ctx context.Context,
- challengeResponse client.RegisterAzureChallengeResponseFunc,
- opts ...azureRegisterOption,
-) (certs *proto.Certs, err error) {
- var provisionToken types.ProvisionToken
- var joinRequest *types.RegisterUsingTokenRequest
- defer func() {
- // Emit a log message and audit event on join failure.
- if err != nil {
- a.handleJoinFailure(ctx, err, provisionToken, nil, joinRequest)
- }
- }()
-
- if legacyjoin.Disabled() {
- return nil, trace.Wrap(legacyjoin.ErrDisabled)
- }
-
- cfg := &azureRegisterConfig{}
- for _, opt := range opts {
- opt(cfg)
- }
- if err := cfg.CheckAndSetDefaults(ctx); err != nil {
+// CheckAzureRequest checks an azure join request by verifying the VMs claims
+// and checking that they match an allow rule from the join token.
+func CheckAzureRequest(ctx context.Context, params CheckAzureRequestParams) (*workloadidentityv1pb.JoinAttrsAzure, error) {
+ if err := params.checkAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
+ requestStart := params.Clock.Now()
- challenge, err := generateAzureChallenge()
- if err != nil {
- return nil, trace.Wrap(err)
- }
- req, err := challengeResponse(challenge)
+ subID, vmID, err := parseAndVerifyAttestedData(
+ ctx,
+ params.AzureJoinConfig,
+ params.AttestedData,
+ params.Intermediate,
+ params.Challenge,
+ )
if err != nil {
return nil, trace.Wrap(err)
}
- joinRequest = req.RegisterUsingTokenRequest
- if err := req.CheckAndSetDefaults(); err != nil {
- return nil, trace.Wrap(err)
- }
-
- provisionToken, err = a.checkTokenJoinRequestCommon(ctx, req.RegisterUsingTokenRequest)
+ attrs, err := verifyVMIdentity(ctx, params.AzureJoinConfig, params.AccessToken, subID, vmID, requestStart, params.Logger)
if err != nil {
return nil, trace.Wrap(err)
}
-
- joinAttrs, err := a.checkAzureRequest(ctx, challenge, req, cfg)
- if err != nil {
- return nil, trace.Wrap(err)
+ if err := checkAzureAllowRules(vmID, attrs, params.Token); err != nil {
+ return attrs, trace.Wrap(err)
}
- if req.RegisterUsingTokenRequest.Role == types.RoleBot {
- params := makeBotCertsParams(req.RegisterUsingTokenRequest, nil /*rawClaims*/, &workloadidentityv1pb.JoinAttrs{
- Azure: joinAttrs,
- })
- certs, _, err := a.GenerateBotCertsForJoin(ctx, provisionToken, params)
- return certs, trace.Wrap(err)
- }
- params := makeHostCertsParams(req.RegisterUsingTokenRequest, nil /*rawClaims*/)
- certs, err = a.GenerateHostCertsForJoin(ctx, provisionToken, params)
- return certs, trace.Wrap(err)
+ return attrs, nil
}
-// RegisterUsingAzureMethod registers the caller using the Azure join method
-// and returns signed certs to join the cluster.
-//
-// The caller must provide a ChallengeResponseFunc which returns a
-// *proto.RegisterUsingAzureMethodRequest with a signed attested data document
-// including the challenge as a nonce.
-func (a *Server) RegisterUsingAzureMethod(
- ctx context.Context,
- challengeResponse client.RegisterAzureChallengeResponseFunc,
-) (certs *proto.Certs, err error) {
- return a.RegisterUsingAzureMethodWithOpts(ctx, challengeResponse)
+// GenerateAzureChallenge generates a challenge for the Azure join method.
+func GenerateAzureChallenge() (string, error) {
+ challenge, err := joinutils.GenerateChallenge(base64.RawURLEncoding, 24)
+ return challenge, trace.Wrap(err)
}
// fixAzureSigningAlgorithm fixes a mismatch between the object IDs of the
diff --git a/lib/auth/azure_certs.go b/lib/join/azurejoin/azure_certs.go
similarity index 98%
rename from lib/auth/azure_certs.go
rename to lib/join/azurejoin/azure_certs.go
index cea969aadc87b..431e7bcecb35b 100644
--- a/lib/auth/azure_certs.go
+++ b/lib/join/azurejoin/azure_certs.go
@@ -16,7 +16,7 @@
* along with this program. If not, see .
*/
-package auth
+package azurejoin
import (
"context"
@@ -28,7 +28,6 @@ import (
"github.com/gravitational/trace"
"github.com/gravitational/teleport"
- "github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/utils"
)
@@ -49,16 +48,11 @@ func isAllowedDomain(cn string, domains []string) bool {
}
// getAzureIssuerCert fetches a x509 certificate's issuing certificate.
-func getAzureIssuerCert(ctx context.Context, cert *x509.Certificate) (*x509.Certificate, error) {
+func getAzureIssuerCert(ctx context.Context, cert *x509.Certificate, httpClient utils.HTTPDoClient) (*x509.Certificate, error) {
if len(cert.IssuingCertificateURL) == 0 {
return nil, nil
}
- httpClient, err := defaults.HTTPClient()
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
// Azure sends only one issuing cert.
issuerURL := cert.IssuingCertificateURL[0]
commonName := cert.Subject.CommonName
diff --git a/lib/auth/azure_certs_test.go b/lib/join/azurejoin/azure_certs_test.go
similarity index 92%
rename from lib/auth/azure_certs_test.go
rename to lib/join/azurejoin/azure_certs_test.go
index abeb857f0edb7..7ebb1acccf004 100644
--- a/lib/auth/azure_certs_test.go
+++ b/lib/join/azurejoin/azure_certs_test.go
@@ -16,14 +16,12 @@
* along with this program. If not, see .
*/
-package auth_test
+package azurejoin
import (
"testing"
"github.com/stretchr/testify/require"
-
- "github.com/gravitational/teleport/lib/auth"
)
func TestIsAllowedDomain(t *testing.T) {
@@ -71,7 +69,7 @@ func TestIsAllowedDomain(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
- tc.assert(t, auth.IsAllowedDomain(tc.url, allowedDomains))
+ tc.assert(t, isAllowedDomain(tc.url, allowedDomains))
})
}
}
diff --git a/lib/auth/join_azure_test.go b/lib/join/azurejoin/join_azure_test.go
similarity index 65%
rename from lib/auth/join_azure_test.go
rename to lib/join/azurejoin/join_azure_test.go
index deaa014dcc3fa..9752766d8d769 100644
--- a/lib/auth/join_azure_test.go
+++ b/lib/join/azurejoin/join_azure_test.go
@@ -16,15 +16,22 @@
* along with this program. If not, see .
*/
-package auth_test
+package azurejoin_test
import (
+ "bytes"
"context"
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
"crypto/x509"
+ "crypto/x509/pkix"
"encoding/base64"
"encoding/json"
- "encoding/pem"
"fmt"
+ "io"
+ "net/http"
"testing"
"time"
@@ -36,13 +43,16 @@ import (
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc"
- "github.com/gravitational/teleport/api/client/proto"
+ headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
+ machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1"
"github.com/gravitational/teleport/api/types"
- "github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/authtest"
- "github.com/gravitational/teleport/lib/auth/testauthority"
+ "github.com/gravitational/teleport/lib/auth/machineid/machineidv1"
+ "github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/cloud/azure"
- "github.com/gravitational/teleport/lib/fixtures"
+ "github.com/gravitational/teleport/lib/join/azurejoin"
+ "github.com/gravitational/teleport/lib/join/joinclient"
+ "github.com/gravitational/teleport/lib/tlsca"
)
type mockAzureVMClient struct {
@@ -67,7 +77,7 @@ func (m *mockAzureVMClient) GetByVMID(_ context.Context, vmID string) (*azure.Vi
return nil, trace.NotFound("no vm with id %q", vmID)
}
-func makeVMClientGetter(clients map[string]*mockAzureVMClient) auth.AzureVMClientGetter {
+func makeVMClientGetter(clients map[string]*mockAzureVMClient) azurejoin.VMClientGetter {
return func(subscriptionID string, _ *azure.StaticCredential) (azure.VirtualMachinesClient, error) {
if client, ok := clients[subscriptionID]; ok {
return client, nil
@@ -76,18 +86,6 @@ func makeVMClientGetter(clients map[string]*mockAzureVMClient) auth.AzureVMClien
}
}
-type azureChallengeResponseConfig struct {
- Challenge string
-}
-
-type azureChallengeResponseOption func(*azureChallengeResponseConfig)
-
-func withChallengeAzure(challenge string) azureChallengeResponseOption {
- return func(cfg *azureChallengeResponseConfig) {
- cfg.Challenge = challenge
- }
-}
-
func vmssResourceID(subscription, resourceGroup, name string) string {
return resourceID("Microsoft.Compute/virtualMachineScaleSets", subscription, resourceGroup, name)
}
@@ -107,8 +105,8 @@ func resourceID(resourceType, subscription, resourceGroup, name string) string {
)
}
-func mockVerifyToken(err error) auth.AzureVerifyTokenFunc {
- return func(_ context.Context, rawToken string) (*auth.AccessTokenClaims, error) {
+func mockVerifyToken(err error) azurejoin.AzureVerifyTokenFunc {
+ return func(_ context.Context, rawToken string) (*azurejoin.AccessTokenClaims, error) {
if err != nil {
return nil, err
}
@@ -116,7 +114,7 @@ func mockVerifyToken(err error) auth.AzureVerifyTokenFunc {
if err != nil {
return nil, trace.Wrap(err)
}
- var claims auth.AccessTokenClaims
+ var claims azurejoin.AccessTokenClaims
if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, trace.Wrap(err)
}
@@ -132,10 +130,10 @@ func makeToken(managedIdentityResourceID, azureResourceID string, issueTime time
if err != nil {
return "", trace.Wrap(err)
}
- claims := auth.AccessTokenClaims{
+ claims := azurejoin.AccessTokenClaims{
TokenClaims: oidc.TokenClaims{
Issuer: "https://sts.windows.net/test-tenant-id/",
- Audience: []string{auth.AzureAccessTokenAudience},
+ Audience: []string{azurejoin.AzureAccessTokenAudience},
Subject: "test",
IssuedAt: oidc.FromTime(issueTime),
NotBefore: oidc.FromTime(issueTime),
@@ -154,25 +152,29 @@ func makeToken(managedIdentityResourceID, azureResourceID string, issueTime time
return raw, nil
}
-func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
+func TestJoinAzure(t *testing.T) {
t.Parallel()
-
ctx := t.Context()
- p := newAuthSuite(t)
- a := p.a
-
- sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair()
- require.NoError(t, err)
- tlsConfig, err := fixtures.LocalTLSConfig()
+ server, err := authtest.NewTestServer(authtest.ServerConfig{
+ Auth: authtest.AuthServerConfig{
+ Dir: t.TempDir(),
+ },
+ })
require.NoError(t, err)
+ a := server.Auth()
- block, _ := pem.Decode(fixtures.LocalhostKey)
- pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+ nopClient, err := server.NewClient(authtest.TestNop())
require.NoError(t, err)
- tlsPublicKey, err := authtest.PrivateKeyToPublicKeyTLS(sshPrivateKey)
+ caChain := newFakeAzureCAChain(t)
+ httpClient := newFakeAzureIssuerHTTPClient(caChain.intermediateCertDER)
+ instanceKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
+ instanceCert := caChain.issueLeafCert(t,
+ instanceKey.Public(),
+ "instance.metadata.azure.com",
+ "http://www.microsoft.com/pkiops/certs/testcert.crt")
isAccessDenied := func(t require.TestingT, err error, _ ...any) {
require.True(t, trace.IsAccessDenied(err), "expected Access Denied error, actual error: %v", err)
@@ -197,10 +199,10 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
tokenVMID string
requestTokenName string
tokenSpec types.ProvisionTokenSpecV2
- challengeResponseOptions []azureChallengeResponseOption
+ overrideReturnedChallenge string
challengeResponseErr error
certs []*x509.Certificate
- verify auth.AzureVerifyTokenFunc
+ verify azurejoin.AzureVerifyTokenFunc
assertError require.ErrorAssertionFunc
}{
{
@@ -221,7 +223,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: require.NoError,
},
{
@@ -242,7 +244,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: require.NoError,
},
{
@@ -262,7 +264,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: isAccessDenied,
},
{
@@ -282,7 +284,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
challengeResponseErr: trace.BadParameter("test error"),
assertError: isBadParameter,
},
@@ -303,7 +305,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: isAccessDenied,
},
{
@@ -324,7 +326,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: isAccessDenied,
},
{
@@ -343,12 +345,10 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
},
JoinMethod: types.JoinMethodAzure,
},
- challengeResponseOptions: []azureChallengeResponseOption{
- withChallengeAzure("wrong-challenge"),
- },
- verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
- assertError: isAccessDenied,
+ overrideReturnedChallenge: "wrong-challenge",
+ verify: mockVerifyToken(nil),
+ certs: []*x509.Certificate{caChain.rootCert},
+ assertError: isAccessDenied,
},
{
name: "invalid signature",
@@ -388,7 +388,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: isAccessDenied,
},
{
@@ -409,7 +409,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: isAccessDenied,
},
{
@@ -430,7 +430,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: require.NoError,
},
{
@@ -451,13 +451,35 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: require.NoError,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
+ vmClient := &mockAzureVMClient{
+ vms: map[string]*azure.VirtualMachine{
+ defaultVMResourceID: {
+ ID: defaultVMResourceID,
+ Name: defaultVMName,
+ Subscription: defaultSubscription,
+ ResourceGroup: defaultResourceGroup,
+ VMID: defaultVMID,
+ },
+ },
+ }
+ getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{
+ defaultSubscription: vmClient,
+ })
+
+ a.SetAzureJoinConfig(&azurejoin.AzureJoinConfig{
+ CertificateAuthorities: tc.certs,
+ Verify: tc.verify,
+ GetVMClient: getVMClient,
+ IssuerHTTPClient: httpClient,
+ })
+
token, err := types.NewProvisionTokenFromSpec(
"test-token",
time.Now().Add(time.Minute),
@@ -476,85 +498,76 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
accessToken, err := makeToken(mirID, "", a.GetClock().Now())
require.NoError(t, err)
- vmClient := &mockAzureVMClient{
- vms: map[string]*azure.VirtualMachine{
- defaultVMResourceID: {
- ID: defaultVMResourceID,
- Name: defaultVMName,
- Subscription: defaultSubscription,
- ResourceGroup: defaultResourceGroup,
- VMID: defaultVMID,
- },
- },
+ imdsClient := &fakeIMDSClient{
+ accessToken: accessToken,
+ accessTokenErr: tc.challengeResponseErr,
+ overrideChallenge: tc.overrideReturnedChallenge,
+ signingCert: instanceCert,
+ signingKey: instanceKey,
+ subscription: tc.tokenSubscription,
+ vmID: tc.tokenVMID,
}
- getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{
- defaultSubscription: vmClient,
- })
-
- _, err = a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
- cfg := &azureChallengeResponseConfig{Challenge: challenge}
- for _, opt := range tc.challengeResponseOptions {
- opt(cfg)
- }
-
- ad := auth.AttestedData{
- Nonce: cfg.Challenge,
- SubscriptionID: tc.tokenSubscription,
- ID: tc.tokenVMID,
- }
- adBytes, err := json.Marshal(&ad)
- require.NoError(t, err)
- s, err := pkcs7.NewSignedData(adBytes)
- require.NoError(t, err)
- require.NoError(t, s.AddSigner(tlsConfig.Certificate, pkey, pkcs7.SignerInfoConfig{}))
- signature, err := s.Finish()
- require.NoError(t, err)
- signedAD := auth.SignedAttestedData{
- Encoding: "pkcs7",
- Signature: base64.StdEncoding.EncodeToString(signature),
- }
- signedADBytes, err := json.Marshal(&signedAD)
- require.NoError(t, err)
- req := &proto.RegisterUsingAzureMethodRequest{
- RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{
- Token: tc.requestTokenName,
- HostID: "test-node",
- Role: types.RoleNode,
- PublicSSHKey: sshPublicKey,
- PublicTLSKey: tlsPublicKey,
+ t.Run("legacy", func(t *testing.T) {
+ _, err = joinclient.LegacyJoin(ctx, joinclient.JoinParams{
+ Token: tc.requestTokenName,
+ JoinMethod: types.JoinMethodAzure,
+ ID: state.IdentityID{
+ Role: types.RoleInstance,
+ HostUUID: "testuuid",
},
- AttestedData: signedADBytes,
- AccessToken: accessToken,
- }
- return req, tc.challengeResponseErr
- }, auth.WithAzureCerts(tc.certs), auth.WithAzureVerifyFunc(tc.verify), auth.WithAzureVMClientGetter(getVMClient))
- tc.assertError(t, err)
+ AuthClient: nopClient,
+ AzureParams: joinclient.AzureParams{
+ ClientID: tc.tokenVMID,
+ IMDSClient: imdsClient,
+ },
+ })
+ tc.assertError(t, err)
+ })
+ t.Run("new", func(t *testing.T) {
+ _, err = joinclient.Join(ctx, joinclient.JoinParams{
+ Token: tc.requestTokenName,
+ ID: state.IdentityID{
+ Role: types.RoleInstance,
+ },
+ AuthClient: nopClient,
+ AzureParams: joinclient.AzureParams{
+ ClientID: tc.tokenVMID,
+ IMDSClient: imdsClient,
+ IssuerHTTPClient: httpClient,
+ },
+ })
+ tc.assertError(t, err)
+ })
})
}
}
// TestAuth_RegisterUsingAzureClaims tests the Azure join method by verifying
// joining VMs by the token claims rather than from the Azure VM API.
-func TestAuth_RegisterUsingAzureClaims(t *testing.T) {
+func TestJoinAzureClaims(t *testing.T) {
t.Parallel()
-
ctx := t.Context()
- p := newAuthSuite(t)
- a := p.a
-
- sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair()
- require.NoError(t, err)
- tlsConfig, err := fixtures.LocalTLSConfig()
+ server, err := authtest.NewTestServer(authtest.ServerConfig{
+ Auth: authtest.AuthServerConfig{
+ Dir: t.TempDir(),
+ },
+ })
require.NoError(t, err)
+ a := server.Auth()
- block, _ := pem.Decode(fixtures.LocalhostKey)
- pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+ nopClient, err := server.NewClient(authtest.TestNop())
require.NoError(t, err)
- tlsPublicKey, err := authtest.PrivateKeyToPublicKeyTLS(sshPrivateKey)
+ caChain := newFakeAzureCAChain(t)
+ httpClient := newFakeAzureIssuerHTTPClient(caChain.intermediateCertDER)
+ instanceKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
+ instanceCert := caChain.issueLeafCert(t,
+ instanceKey.Public(),
+ "instance.metadata.azure.com",
+ "http://www.microsoft.com/pkiops/certs/testcert.crt")
isAccessDenied := func(t require.TestingT, err error, _ ...any) {
require.True(t, trace.IsAccessDenied(err), "expected Access Denied error, actual error: %v", err)
@@ -565,6 +578,17 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) {
defaultIdentityName := "test-id"
defaultVMID := "my-vm-id"
+ botName := "botty"
+ _, err = machineidv1.UpsertBot(ctx, a, &machineidv1pb.Bot{
+ Kind: types.KindBot,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: botName,
+ },
+ Spec: &machineidv1pb.BotSpec{},
+ }, a.GetClock().Now(), "")
+ require.NoError(t, err)
+
tests := []struct {
name string
tokenManagedIdentityResourceID string
@@ -573,10 +597,9 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) {
tokenVMID string
requestTokenName string
tokenSpec types.ProvisionTokenSpecV2
- challengeResponseOptions []azureChallengeResponseOption
challengeResponseErr error
certs []*x509.Certificate
- verify auth.AzureVerifyTokenFunc
+ verify azurejoin.AzureVerifyTokenFunc
assertError require.ErrorAssertionFunc
}{
{
@@ -598,7 +621,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: require.NoError,
},
{
@@ -620,7 +643,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: isAccessDenied,
},
{
@@ -642,7 +665,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: isAccessDenied,
},
{
@@ -665,7 +688,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: require.NoError,
},
{
@@ -688,7 +711,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: isAccessDenied,
},
{
@@ -711,7 +734,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: isAccessDenied,
},
{
@@ -734,7 +757,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: require.NoError,
},
{
@@ -756,7 +779,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: isAccessDenied,
},
{
@@ -778,7 +801,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) {
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
- certs: []*x509.Certificate{tlsConfig.Certificate},
+ certs: []*x509.Certificate{caChain.rootCert},
assertError: require.NoError,
},
}
@@ -807,45 +830,238 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) {
defaultSubscription: vmClient,
})
- _, err = a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
- cfg := &azureChallengeResponseConfig{Challenge: challenge}
- for _, opt := range tc.challengeResponseOptions {
- opt(cfg)
- }
+ a.SetAzureJoinConfig(&azurejoin.AzureJoinConfig{
+ CertificateAuthorities: tc.certs,
+ Verify: tc.verify,
+ GetVMClient: getVMClient,
+ IssuerHTTPClient: httpClient,
+ })
- ad := auth.AttestedData{
- Nonce: cfg.Challenge,
- SubscriptionID: tc.tokenSubscription,
- ID: tc.tokenVMID,
- }
- adBytes, err := json.Marshal(&ad)
- require.NoError(t, err)
- s, err := pkcs7.NewSignedData(adBytes)
- require.NoError(t, err)
- require.NoError(t, s.AddSigner(tlsConfig.Certificate, pkey, pkcs7.SignerInfoConfig{}))
- signature, err := s.Finish()
- require.NoError(t, err)
- signedAD := auth.SignedAttestedData{
- Encoding: "pkcs7",
- Signature: base64.StdEncoding.EncodeToString(signature),
- }
- signedADBytes, err := json.Marshal(&signedAD)
+ imdsClient := &fakeIMDSClient{
+ accessToken: accessToken,
+ accessTokenErr: tc.challengeResponseErr,
+ signingCert: instanceCert,
+ signingKey: instanceKey,
+ subscription: tc.tokenSubscription,
+ vmID: tc.tokenVMID,
+ }
+
+ t.Run("legacy", func(t *testing.T) {
+ // Try to join via the legacy join service.
+ _, err = joinclient.LegacyJoin(ctx, joinclient.JoinParams{
+ Token: tc.requestTokenName,
+ JoinMethod: types.JoinMethodAzure,
+ ID: state.IdentityID{
+ Role: types.RoleInstance,
+ HostUUID: "testuuid",
+ },
+ AuthClient: nopClient,
+ AzureParams: joinclient.AzureParams{
+ ClientID: tc.tokenVMID,
+ IMDSClient: imdsClient,
+ },
+ })
+ tc.assertError(t, err)
+ })
+ t.Run("new", func(t *testing.T) {
+ // Try to join via the new join service.
+ _, err = joinclient.Join(ctx, joinclient.JoinParams{
+ Token: tc.requestTokenName,
+ ID: state.IdentityID{
+ Role: types.RoleInstance,
+ },
+ AuthClient: nopClient,
+ AzureParams: joinclient.AzureParams{
+ ClientID: tc.tokenVMID,
+ IMDSClient: imdsClient,
+ IssuerHTTPClient: httpClient,
+ },
+ })
+ tc.assertError(t, err)
+ })
+ t.Run("bot", func(t *testing.T) {
+ // Try to join as a bot.
+ tokenSpec := tc.tokenSpec
+ tokenSpec.BotName = botName
+ tokenSpec.Roles = types.SystemRoles{types.RoleBot}
+ token, err := types.NewProvisionTokenFromSpec(
+ "test-token",
+ time.Now().Add(time.Minute),
+ tokenSpec)
require.NoError(t, err)
+ require.NoError(t, a.UpsertToken(ctx, token))
- req := &proto.RegisterUsingAzureMethodRequest{
- RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{
- Token: tc.requestTokenName,
- HostID: "test-node",
- Role: types.RoleNode,
- PublicSSHKey: sshPublicKey,
- PublicTLSKey: tlsPublicKey,
+ result, err := joinclient.Join(ctx, joinclient.JoinParams{
+ Token: tc.requestTokenName,
+ ID: state.IdentityID{
+ Role: types.RoleBot,
},
- AttestedData: signedADBytes,
- AccessToken: accessToken,
+ AuthClient: nopClient,
+ AzureParams: joinclient.AzureParams{
+ ClientID: tc.tokenVMID,
+ IMDSClient: imdsClient,
+ IssuerHTTPClient: httpClient,
+ },
+ })
+ tc.assertError(t, err)
+ if err != nil {
+ return
}
- return req, tc.challengeResponseErr
- }, auth.WithAzureCerts(tc.certs), auth.WithAzureVerifyFunc(tc.verify), auth.WithAzureVMClientGetter(getVMClient))
- tc.assertError(t, err)
+
+ cert, err := tlsca.ParseCertificatePEM(result.Certs.TLS)
+ require.NoError(t, err)
+ identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter)
+ require.NoError(t, err)
+
+ // Make sure the LoginIP was set on the identity.
+ require.NotEmpty(t, identity.LoginIP)
+
+ // Make sure the JoinAttributes were set.
+ require.NotNil(t, identity.JoinAttributes)
+ require.NotNil(t, identity.JoinAttributes.Azure)
+ require.Equal(t, tc.tokenSubscription, identity.JoinAttributes.Azure.Subscription)
+ })
})
}
}
+
+type fakeIMDSClient struct {
+ accessToken string
+ accessTokenErr error
+
+ // overrideChallenge overrides the challenge/nonce included in attested data.
+ overrideChallenge string
+ signingCert *x509.Certificate
+ signingKey crypto.Signer
+ subscription string
+ vmID string
+}
+
+func (c *fakeIMDSClient) IsAvailable(_ context.Context) bool {
+ return true
+}
+
+func (c *fakeIMDSClient) GetAttestedData(_ context.Context, nonce string) ([]byte, error) {
+ ad := azurejoin.AttestedData{
+ Nonce: nonce,
+ SubscriptionID: c.subscription,
+ ID: c.vmID,
+ }
+ if c.overrideChallenge != "" {
+ ad.Nonce = c.overrideChallenge
+ }
+ adBytes, err := json.Marshal(&ad)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ s, err := pkcs7.NewSignedData(adBytes)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ if err := s.AddSigner(c.signingCert, c.signingKey, pkcs7.SignerInfoConfig{}); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ signature, err := s.Finish()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ signedAD := azurejoin.SignedAttestedData{
+ Encoding: "pkcs7",
+ Signature: base64.StdEncoding.EncodeToString(signature),
+ }
+ signedADBytes, err := json.Marshal(&signedAD)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return signedADBytes, nil
+}
+
+func (c *fakeIMDSClient) GetAccessToken(_ context.Context, clientID string) (string, error) {
+ return c.accessToken, trace.Wrap(c.accessTokenErr)
+}
+
+type fakeAzureCAChain struct {
+ intermediateKey crypto.Signer
+ intermediateCert *x509.Certificate
+ intermediateCertDER []byte
+ rootCert *x509.Certificate
+}
+
+func newFakeAzureCAChain(t *testing.T) *fakeAzureCAChain {
+ rootKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ require.NoError(t, err)
+ rootCertTemplate := &x509.Certificate{
+ Subject: pkix.Name{
+ CommonName: "test root CA",
+ },
+ NotBefore: time.Now(),
+ NotAfter: time.Now().Add(time.Hour),
+ KeyUsage: x509.KeyUsageCertSign,
+ IsCA: true,
+ BasicConstraintsValid: true,
+ }
+ rootCertDER, err := x509.CreateCertificate(rand.Reader, rootCertTemplate, rootCertTemplate, rootKey.Public(), rootKey)
+ require.NoError(t, err)
+ rootCert, err := x509.ParseCertificate(rootCertDER)
+ require.NoError(t, err)
+
+ intermediateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ require.NoError(t, err)
+ intermediateCertTemplate := &x509.Certificate{
+ Subject: pkix.Name{
+ CommonName: "test intermediate CA",
+ },
+ NotBefore: time.Now(),
+ NotAfter: time.Now().Add(time.Hour),
+ KeyUsage: x509.KeyUsageCertSign,
+ IsCA: true,
+ BasicConstraintsValid: true,
+ }
+ intermediateCertDER, err := x509.CreateCertificate(rand.Reader, intermediateCertTemplate, rootCert, intermediateKey.Public(), rootKey)
+ require.NoError(t, err)
+ intermediateCert, err := x509.ParseCertificate(intermediateCertDER)
+ require.NoError(t, err)
+
+ return &fakeAzureCAChain{
+ intermediateKey: intermediateKey,
+ intermediateCert: intermediateCert,
+ intermediateCertDER: intermediateCertDER,
+ rootCert: rootCert,
+ }
+}
+
+func (c *fakeAzureCAChain) issueLeafCert(t *testing.T, pub crypto.PublicKey, commonName, issuerURL string) *x509.Certificate {
+ leafCertTemplate := &x509.Certificate{
+ Subject: pkix.Name{
+ CommonName: commonName,
+ },
+ NotBefore: time.Now(),
+ NotAfter: time.Now().Add(time.Hour),
+ IssuingCertificateURL: []string{issuerURL},
+ KeyUsage: x509.KeyUsageDigitalSignature,
+ BasicConstraintsValid: true,
+ }
+ leafCertDER, err := x509.CreateCertificate(rand.Reader, leafCertTemplate, c.intermediateCert, pub, c.intermediateKey)
+ require.NoError(t, err)
+ leafCert, err := x509.ParseCertificate(leafCertDER)
+ require.NoError(t, err)
+ return leafCert
+}
+
+type fakeAzureIssuerHTTPClient struct {
+ issuerCertDER []byte
+ called int
+}
+
+func newFakeAzureIssuerHTTPClient(issuerCertDER []byte) *fakeAzureIssuerHTTPClient {
+ return &fakeAzureIssuerHTTPClient{
+ issuerCertDER: issuerCertDER,
+ }
+}
+func (c *fakeAzureIssuerHTTPClient) Do(req *http.Request) (*http.Response, error) {
+ c.called++
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(bytes.NewReader(c.issuerCertDER)),
+ }, nil
+}
diff --git a/lib/join/joinclient/join.go b/lib/join/joinclient/join.go
index 1374b9b5a030d..7dc1f80cace3b 100644
--- a/lib/join/joinclient/join.go
+++ b/lib/join/joinclient/join.go
@@ -204,6 +204,7 @@ func joinWithClient(ctx context.Context, params JoinParams, client *joinv1.Clien
case types.JoinMethodUnspecified:
// leave joinMethodPtr nil to let the server pick based on the token
case types.JoinMethodToken,
+ types.JoinMethodAzure,
types.JoinMethodAzureDevops,
types.JoinMethodBitbucket,
types.JoinMethodBoundKeypair,
@@ -306,6 +307,8 @@ func joinWithMethod(
switch types.JoinMethod(method) {
case types.JoinMethodToken:
return tokenJoin(stream, clientParams)
+ case types.JoinMethodAzure:
+ return azureJoin(ctx, stream, joinParams, clientParams)
case types.JoinMethodAzureDevops:
if joinParams.IDToken == "" {
joinParams.IDToken, err = azuredevops.NewIDTokenSource(os.Getenv).GetIDToken(ctx)
diff --git a/lib/join/joinclient/join_azure.go b/lib/join/joinclient/join_azure.go
new file mode 100644
index 0000000000000..a359d68e76f36
--- /dev/null
+++ b/lib/join/joinclient/join_azure.go
@@ -0,0 +1,135 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package joinclient
+
+import (
+ "context"
+ "crypto/x509"
+ "net/http"
+
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/lib/cloud/imds/azure"
+ "github.com/gravitational/teleport/lib/defaults"
+ "github.com/gravitational/teleport/lib/join/azurejoin"
+ "github.com/gravitational/teleport/lib/join/internal/messages"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+func azureJoin(ctx context.Context, stream messages.ClientStream, joinParams JoinParams, clientParams messages.ClientParams) (messages.Response, error) {
+ // The Azure join method involves the following messages:
+ //
+ // client->server ClientInit
+ // client<-server ServerInit
+ // client->server AzureInit
+ // client<-server AzureChallenge
+ // client->server AzureChallengeSolution
+ // client<-server Result
+ //
+ // At this point the ServerInit messages has already been received, what's
+ // left is to send the AzureInit message, handle the challenge-response, and
+ // receive and return the final result.
+ if err := stream.Send(&messages.AzureInit{
+ ClientParams: clientParams,
+ }); err != nil {
+ return nil, trace.Wrap(err, "sending AzureInit")
+ }
+
+ challenge, err := messages.RecvResponse[*messages.AzureChallenge](stream)
+ if err != nil {
+ return nil, trace.Wrap(err, "receiving AzureChallenge")
+ }
+
+ imds := joinParams.AzureParams.IMDSClient
+ if imds == nil {
+ imds = azure.NewInstanceMetadataClient()
+ }
+ if !imds.IsAvailable(ctx) {
+ return nil, trace.AccessDenied("could not reach instance metadata. Is Teleport running on an Azure VM?")
+ }
+ ad, err := imds.GetAttestedData(ctx, challenge.Challenge)
+ if err != nil {
+ return nil, trace.Wrap(err, "getting attested data document")
+ }
+ intermediate, err := getIntermediate(ctx, joinParams.AzureParams.IssuerHTTPClient, ad)
+ if err != nil {
+ return nil, trace.Wrap(err, "getting intermediate CA for attested data")
+ }
+ accessToken, err := imds.GetAccessToken(ctx, joinParams.AzureParams.ClientID)
+ if err != nil {
+ return nil, trace.Wrap(err, "getting access token")
+ }
+
+ if err := stream.Send(&messages.AzureChallengeSolution{
+ AttestedData: ad,
+ Intermediate: intermediate,
+ AccessToken: accessToken,
+ }); err != nil {
+ return nil, trace.Wrap(err, "sending AzureChallengeSolution")
+ }
+
+ result, err := stream.Recv()
+ return result, trace.Wrap(err, "receiving join result")
+}
+
+func getIntermediate(ctx context.Context, httpClient utils.HTTPDoClient, ad []byte) ([]byte, error) {
+ _, p7, err := azurejoin.ParseAttestedData(ad)
+ if err != nil {
+ return nil, trace.Wrap(err, "parsing attested data document")
+ }
+ if len(p7.Certificates) == 0 {
+ return nil, trace.Errorf("attested data signature has no certificates")
+ }
+ leafCert := p7.Certificates[0]
+ if len(leafCert.IssuingCertificateURL) == 0 {
+ return nil, trace.Errorf("attested data leaf certificate has no issuing certificate URL")
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, leafCert.IssuingCertificateURL[0], nil /*body*/)
+ if err != nil {
+ return nil, trace.Wrap(err, "building HTTP request")
+ }
+
+ if httpClient == nil {
+ httpClient, err = defaults.HTTPClient()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ }
+
+ resp, err := httpClient.Do(req)
+ if err != nil {
+ return nil, trace.Wrap(err, "fetching intermediate certificate")
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, trace.Errorf("failed to fetch intermediate cert, got HTTP status code %d", resp.StatusCode)
+ }
+
+ body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize)
+ if err != nil {
+ return nil, trace.Wrap(err, "reading HTTP response body")
+ }
+
+ if _, err := x509.ParseCertificates(body); err != nil {
+ return nil, trace.Wrap(err, "parsing intermediate certificate")
+ }
+
+ return body, nil
+}
diff --git a/lib/join/server.go b/lib/join/server.go
index 08b3d9701449b..5d297b03d6ec1 100644
--- a/lib/join/server.go
+++ b/lib/join/server.go
@@ -46,6 +46,7 @@ import (
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/join/azuredevops"
+ "github.com/gravitational/teleport/lib/join/azurejoin"
"github.com/gravitational/teleport/lib/join/bitbucket"
"github.com/gravitational/teleport/lib/join/circleci"
"github.com/gravitational/teleport/lib/join/ec2join"
@@ -104,6 +105,7 @@ type AuthService interface {
GetK8sOIDCValidator() *kubetoken.KubernetesOIDCTokenValidator
GetSpaceliftIDTokenValidator() spacelift.Validator
GetTerraformIDTokenValidator() terraformcloud.Validator
+ GetAzureJoinConfig() *azurejoin.AzureJoinConfig
services.Presence
}
@@ -296,6 +298,8 @@ func (s *Server) handleJoinMethod(
joinMethod types.JoinMethod,
) (messages.Response, error) {
switch joinMethod {
+ case types.JoinMethodAzure:
+ return s.handleAzureJoin(stream, authCtx, clientInit, token)
case types.JoinMethodAzureDevops:
return s.handleOIDCJoin(stream, authCtx, clientInit, token, s.validateAzureDevopsToken)
case types.JoinMethodBitbucket:
diff --git a/lib/join/server_azure.go b/lib/join/server_azure.go
new file mode 100644
index 0000000000000..c0687eb8753cd
--- /dev/null
+++ b/lib/join/server_azure.go
@@ -0,0 +1,118 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package join
+
+import (
+ "github.com/gravitational/trace"
+
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/join/azurejoin"
+ "github.com/gravitational/teleport/lib/join/internal/authz"
+ "github.com/gravitational/teleport/lib/join/internal/messages"
+ "github.com/gravitational/teleport/lib/join/provision"
+)
+
+// handleAzureJoin handles join attempts for the Azure join method.
+//
+// The Azure join method involves the following messages:
+//
+// client->server ClientInit
+// client<-server ServerInit
+// client->server AzureInit
+// client<-server AzureChallenge
+// client->server AzureChallengeSolution
+// client<-server Result
+//
+// At this point the ServerInit message has already been sent, what's left is
+// to receive the AzureInit message, handle the challenge-response, and send the
+// final result if everything checks out.
+func (s *Server) handleAzureJoin(
+ stream messages.ServerStream,
+ authCtx *authz.Context,
+ clientInit *messages.ClientInit,
+ token provision.Token,
+) (messages.Response, error) {
+ // Receive the AzureInit message from the client.
+ azureInit, err := messages.RecvRequest[*messages.AzureInit](stream)
+ if err != nil {
+ return nil, trace.Wrap(err, "receiving AzureInit message")
+ }
+ // Set any diagnostic info from the ClientParams.
+ setDiagnosticClientParams(stream.Diagnostic(), &azureInit.ClientParams)
+
+ // Generate and send the challenge.
+ challenge, err := azurejoin.GenerateAzureChallenge()
+ if err != nil {
+ return nil, trace.Wrap(err, "generating challenge")
+ }
+ if err := stream.Send(&messages.AzureChallenge{
+ Challenge: challenge,
+ }); err != nil {
+ return nil, trace.Wrap(err, "sending AzureChallenge")
+ }
+
+ // Receive the solution from the client.
+ solution, err := messages.RecvRequest[*messages.AzureChallengeSolution](stream)
+ if err != nil {
+ return nil, trace.Wrap(err, "receiving AzureChallengeSolution")
+ }
+
+ switch {
+ case len(solution.AttestedData) == 0:
+ return nil, trace.BadParameter("client did not send attested data")
+ case len(solution.Intermediate) == 0:
+ return nil, trace.BadParameter("client did not send intermediate CAs")
+ case len(solution.AccessToken) == 0:
+ return nil, trace.BadParameter("client did not send access token")
+ }
+
+ ptv2, ok := token.(*types.ProvisionTokenV2)
+ if !ok {
+ return nil, trace.BadParameter("Azure join method only supports ProvisionTokenV2, got %T", token)
+ }
+
+ // Verify the client's identity and make sure it matches an allow rule in the provision token.
+ claims, err := azurejoin.CheckAzureRequest(stream.Context(), azurejoin.CheckAzureRequestParams{
+ AzureJoinConfig: s.cfg.AuthService.GetAzureJoinConfig(),
+ Token: ptv2,
+ Challenge: challenge,
+ AttestedData: solution.AttestedData,
+ Intermediate: solution.Intermediate,
+ AccessToken: solution.AccessToken,
+ Logger: log,
+ Clock: s.cfg.AuthService.GetClock(),
+ })
+ if err != nil {
+ return nil, trace.Wrap(err, "checking Azure challenge solution")
+ }
+
+ // Make and return the final result message.
+ result, err := s.makeResult(
+ stream.Context(),
+ stream.Diagnostic(),
+ authCtx,
+ clientInit,
+ &azureInit.ClientParams,
+ token,
+ nil, // rawClaims
+ &workloadidentityv1pb.JoinAttrs{
+ Azure: claims,
+ },
+ )
+ return result, trace.Wrap(err)
+}