diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go index 0f3b14dc9d8f1..c1d0da6be1c5f 100644 --- a/lib/auth/join_azure.go +++ b/lib/auth/join_azure.go @@ -54,6 +54,8 @@ const ( azureUserAgent = "teleport" // azureVirtualMachine specifies the Azure virtual machine resource type. azureVirtualMachine = "virtualMachines" + // azureVirtualMachineScaleSet specifies the Azure virtual machine scale set resource type. + azureVirtualMachineScaleSet = "virtualMachineScaleSets" ) // Structs for unmarshaling attested data. Schema can be found at @@ -348,10 +350,14 @@ func claimsToIdentifiers(tokenClaims *accessTokenClaims) (subscriptionID, resour if err != nil { return "", "", trace.Wrap(err, "failed to parse resource id from claims") } - if !slices.Contains(resourceID.ResourceType.Types, azureVirtualMachine) { - return "", "", trace.BadParameter("unexpected resource type: %q", resourceID.ResourceType.Type) + + for _, resourceType := range resourceID.ResourceType.Types { + switch resourceType { + case azureVirtualMachine, azureVirtualMachineScaleSet: + return resourceID.SubscriptionID, resourceID.ResourceGroupName, nil + } } - return resourceID.SubscriptionID, resourceID.ResourceGroupName, nil + return "", "", trace.BadParameter("unexpected resource type: %q", resourceID.ResourceType.Type) } func checkAzureAllowRules(vmID string, attrs *workloadidentityv1pb.JoinAttrsAzure, token *types.ProvisionTokenV2) error { diff --git a/lib/auth/join_azure_test.go b/lib/auth/join_azure_test.go index 29b4d0e483db9..1214c7ac22074 100644 --- a/lib/auth/join_azure_test.go +++ b/lib/auth/join_azure_test.go @@ -103,6 +103,10 @@ func withChallengeAzure(challenge string) azureChallengeResponseOption { } } +func vmssResourceID(subscription, resourceGroup, name string) string { + return resourceID("Microsoft.Compute/virtualMachineScaleSets", subscription, resourceGroup, name) +} + func vmResourceID(subscription, resourceGroup, name string) string { return resourceID("Microsoft.Compute/virtualMachines", subscription, resourceGroup, name) } @@ -776,6 +780,28 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { certs: []*x509.Certificate{tlsConfig.Certificate}, assertError: isAccessDenied, }, + { + name: "vmss resource type", + requestTokenName: "test-token", + tokenSubscription: "token-subscription", + tokenVMID: defaultVMID, + tokenManagedIdentityResourceID: vmssResourceID("token-subscription", defaultResourceGroup, defaultVMName), + tokenSpec: types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleNode}, + Azure: &types.ProvisionTokenSpecV2Azure{ + Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ + { + Subscription: "token-subscription", + ResourceGroups: []string{defaultResourceGroup}, + }, + }, + }, + JoinMethod: types.JoinMethodAzure, + }, + verify: mockVerifyToken(nil), + certs: []*x509.Certificate{tlsConfig.Certificate}, + assertError: require.NoError, + }, } for _, tc := range tests {