diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go index 311fea19ca14e..71f47ab5e9273 100644 --- a/lib/auth/join_azure.go +++ b/lib/auth/join_azure.go @@ -257,6 +257,8 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken // If the token is from the system-assigned managed identity, the resource ID // is for the VM itself and we can use it to look up the VM. + // This will also match scale set VMs (VMSS), the vmClient is responsible + // for properly retrieving their information. if slices.Contains(resourceID.ResourceType.Types, "virtualMachines") { vm, err = vmClient.Get(ctx, tokenClaims.ResourceID) if err != nil { diff --git a/lib/cloud/azure/mocks.go b/lib/cloud/azure/mocks.go index 59133720d6f27..bd79d55969f85 100644 --- a/lib/cloud/azure/mocks.go +++ b/lib/cloud/azure/mocks.go @@ -538,6 +538,18 @@ func (m *ARMComputeMock) Get(_ context.Context, _ string, _ string, _ *armcomput }, m.GetErr } +// ARMComputeScaleSetMock mocks armcompute.VirtualMachineScaleSetVMsClient. +type ARMScaleSetMock struct { + GetResult armcompute.VirtualMachineScaleSetVM + GetErr error +} + +func (m *ARMScaleSetMock) Get(ctx context.Context, resourceGroupName string, vmScaleSetName string, instanceID string, options *armcompute.VirtualMachineScaleSetVMsClientGetOptions) (armcompute.VirtualMachineScaleSetVMsClientGetResponse, error) { + return armcompute.VirtualMachineScaleSetVMsClientGetResponse{ + VirtualMachineScaleSetVM: m.GetResult, + }, m.GetErr +} + // ARMSQLServerMock mocks armSQLServerClient type ARMSQLServerMock struct { NoAuth bool diff --git a/lib/cloud/azure/vm.go b/lib/cloud/azure/vm.go index 48572cbc2066e..021d47337619d 100644 --- a/lib/cloud/azure/vm.go +++ b/lib/cloud/azure/vm.go @@ -31,6 +31,10 @@ import ( "github.com/gravitational/teleport/api/types" ) +// virtualScaleSetUniformVMResourceType represents the resource type of uniform +// virtual scale set VMs. +const virtualScaleSetUniformVMResourceType = "virtualMachineScaleSets/virtualMachines" + // armCompute provides an interface for an Azure virtual machine client. type armCompute interface { // Get retrieves information about an Azure virtual machine. @@ -41,9 +45,16 @@ type armCompute interface { NewListAllPager(opts *armcompute.VirtualMachinesClientListAllOptions) *runtime.Pager[armcompute.VirtualMachinesClientListAllResponse] } +// scaleSet provides an interfaces for an Azure VM scale set client. +type scaleSet interface { + // Get retrieves a virtual machine from a VM scale set. + Get(ctx context.Context, resourceGroupName string, vmScaleSetName string, instanceID string, options *armcompute.VirtualMachineScaleSetVMsClientGetOptions) (armcompute.VirtualMachineScaleSetVMsClientGetResponse, error) +} + // VirtualMachinesClient is a client for Azure virtual machines. type VirtualMachinesClient interface { - // Get returns the virtual machine for the given resource ID. + // Get returns the virtual machine (including scale set VMs) for the given + // resource ID. Get(ctx context.Context, resourceID string) (*VirtualMachine, error) // GetByVMID returns the virtual machine for a given VM ID. GetByVMID(ctx context.Context, vmID string) (*VirtualMachine, error) @@ -76,6 +87,8 @@ type Identity struct { type vmClient struct { // api is the Azure virtual machine client. api armCompute + // scaleSetAPI is the Azure VM scale set client. + scaleSetAPI scaleSet } // NewVirtualMachinesClient creates a new Azure virtual machines client by @@ -85,57 +98,98 @@ func NewVirtualMachinesClient(subscription string, cred azcore.TokenCredential, if err != nil { return nil, trace.Wrap(err) } + scaleSetAPI, err := armcompute.NewVirtualMachineScaleSetVMsClient(subscription, cred, options) + if err != nil { + return nil, trace.Wrap(err) + } - return NewVirtualMachinesClientByAPI(computeAPI), nil + return NewVirtualMachinesClientByAPI(computeAPI, scaleSetAPI), nil } // NewVirtualMachinesClientByAPI creates a new Azure virtual machines client by // ARM API client. -func NewVirtualMachinesClientByAPI(api armCompute) VirtualMachinesClient { +func NewVirtualMachinesClientByAPI(api armCompute, scaleSetAPI scaleSet) VirtualMachinesClient { return &vmClient{ - api: api, + api: api, + scaleSetAPI: scaleSetAPI, } } -func parseVirtualMachine(vm *armcompute.VirtualMachine) (*VirtualMachine, error) { - resourceID, err := arm.ParseResourceID(*vm.ID) +type vmTypes interface { + *armcompute.VirtualMachine | *armcompute.VirtualMachineScaleSetVM +} + +func parseVirtualMachine[T vmTypes](vm T) (*VirtualMachine, error) { + var ( + id string + name string + identity *armcompute.VirtualMachineIdentity + vmID *string + ) + + switch v := any(vm).(type) { + case *armcompute.VirtualMachine: + id = *v.ID + name = *v.Name + identity = v.Identity + if v.Properties != nil { + vmID = v.Properties.VMID + } + + case *armcompute.VirtualMachineScaleSetVM: + id = *v.ID + name = *v.Name + identity = v.Identity + if v.Properties != nil { + vmID = v.Properties.VMID + } + } + + resourceID, err := arm.ParseResourceID(id) if err != nil { return nil, trace.Wrap(err) } var identities []Identity - if vm.Identity != nil { - if systemAssigned := StringVal(vm.Identity.PrincipalID); systemAssigned != "" { + if identity != nil { + if systemAssigned := StringVal(identity.PrincipalID); systemAssigned != "" { identities = append(identities, Identity{ResourceID: systemAssigned}) } - for identityID := range vm.Identity.UserAssignedIdentities { + for identityID := range identity.UserAssignedIdentities { identities = append(identities, Identity{ResourceID: identityID}) } } - var vmID string - if vm.Properties != nil { - vmID = *vm.Properties.VMID - } - return &VirtualMachine{ - ID: *vm.ID, - Name: *vm.Name, + ID: id, + Name: name, Subscription: resourceID.SubscriptionID, ResourceGroup: resourceID.ResourceGroupName, - VMID: vmID, + VMID: StringVal(vmID), Identities: identities, }, nil } -// Get returns the virtual machine for the given resource ID. +// Get returns the virtual machine (including scale set VMs) for the given +// resource ID. +// +// The virtual machine scale set (VMSS) supports two types of orchestration +// modes: uniform and flexible. Both have different resource ID format from the +// instance metadata API. A VM from a uniform VMSS has a different resource ID +// and requires a different API to retrieve its information. Flexible VMSS VMs +// use the same resource ID format as regular VMs and don't require special +// handling. func (c *vmClient) Get(ctx context.Context, resourceID string) (*VirtualMachine, error) { parsedResourceID, err := arm.ParseResourceID(resourceID) if err != nil { return nil, trace.Wrap(err) } + if parsedResourceID.ResourceType.Type == virtualScaleSetUniformVMResourceType { + return c.getScaleSetVM(ctx, parsedResourceID) + } + resp, err := c.api.Get(ctx, parsedResourceID.ResourceGroupName, parsedResourceID.Name, nil) if err != nil { return nil, trace.Wrap(err) @@ -164,6 +218,20 @@ func (c *vmClient) GetByVMID(ctx context.Context, vmID string) (*VirtualMachine, return nil, trace.NotFound("no VM with ID %q", vmID) } +func (c *vmClient) getScaleSetVM(ctx context.Context, resourceID *arm.ResourceID) (*VirtualMachine, error) { + if resourceID.Parent == nil { + return nil, trace.BadParameter("expected resource ID to include scale set as parent resource") + } + + resp, err := c.scaleSetAPI.Get(ctx, resourceID.ResourceGroupName, resourceID.Parent.Name, resourceID.Name, nil) + if err != nil { + return nil, trace.Wrap(err) + } + + result, err := parseVirtualMachine(&resp.VirtualMachineScaleSetVM) + return result, trace.Wrap(err) +} + type vmPager struct { more func() bool nextPage func(context.Context) ([]*armcompute.VirtualMachine, error) diff --git a/lib/cloud/azure/vm_test.go b/lib/cloud/azure/vm_test.go index 5de2ea7cdfc96..76939d1b522f5 100644 --- a/lib/cloud/azure/vm_test.go +++ b/lib/cloud/azure/vm_test.go @@ -137,7 +137,116 @@ func TestGetVirtualMachine(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - vmClient := NewVirtualMachinesClientByAPI(tc.client) + vmClient := NewVirtualMachinesClientByAPI(tc.client, nil /* scaleSetAPI */) + + vm, err := vmClient.Get(ctx, tc.resourceID) + tc.assertError(t, err) + tc.assertVM(t, vm) + }) + } +} + +func TestGetScaleSetVirtualMachine(t *testing.T) { + ctx := context.Background() + validResourceID := "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss/virtualMachines/0" + + for _, tc := range []struct { + desc string + resourceID string + client *ARMScaleSetMock + assertError require.ErrorAssertionFunc + assertVM require.ValueAssertionFunc + }{ + { + desc: "vm with valid user identities", + resourceID: validResourceID, + client: &ARMScaleSetMock{ + GetResult: armcompute.VirtualMachineScaleSetVM{ + ID: to.Ptr(validResourceID), + Name: to.Ptr("name"), + Identity: &armcompute.VirtualMachineIdentity{ + PrincipalID: to.Ptr("system assigned"), + Type: to.Ptr(armcompute.ResourceIdentityTypeSystemAssigned), + UserAssignedIdentities: map[string]*armcompute.UserAssignedIdentitiesValue{ + "identity1": {}, + "identity2": {}, + }, + }, + }, + }, + assertError: require.NoError, + assertVM: func(t require.TestingT, val interface{}, _ ...interface{}) { + require.NotNil(t, val) + vm, ok := val.(*VirtualMachine) + require.Truef(t, ok, "expected *VirtualMachine, got %T", val) + require.Equal(t, vm.ID, validResourceID) + require.Equal(t, "name", vm.Name) + require.ElementsMatch(t, []Identity{ + {ResourceID: "system assigned"}, + {ResourceID: "identity1"}, + {ResourceID: "identity2"}, + }, vm.Identities) + }, + }, + { + desc: "vm without identity", + resourceID: validResourceID, + client: &ARMScaleSetMock{ + GetResult: armcompute.VirtualMachineScaleSetVM{ + ID: to.Ptr(validResourceID), + Name: to.Ptr("name"), + }, + }, + assertError: require.NoError, + assertVM: func(t require.TestingT, val interface{}, _ ...interface{}) { + require.NotNil(t, val) + vm, ok := val.(*VirtualMachine) + require.Truef(t, ok, "expected *VirtualMachine, got %T", val) + require.Equal(t, vm.ID, validResourceID) + require.Equal(t, "name", vm.Name) + require.Empty(t, vm.Identities) + }, + }, + { + desc: "vm with only user managed identities", + resourceID: validResourceID, + client: &ARMScaleSetMock{ + GetResult: armcompute.VirtualMachineScaleSetVM{ + ID: to.Ptr(validResourceID), + Name: to.Ptr("name"), + Identity: &armcompute.VirtualMachineIdentity{ + UserAssignedIdentities: map[string]*armcompute.UserAssignedIdentitiesValue{ + "identity1": {}, + "identity2": {}, + }, + }, + }, + }, + assertError: require.NoError, + assertVM: func(t require.TestingT, val interface{}, _ ...interface{}) { + require.NotNil(t, val) + vm, ok := val.(*VirtualMachine) + require.Truef(t, ok, "expected *VirtualMachine, got %T", val) + require.Equal(t, vm.ID, validResourceID) + require.Equal(t, "name", vm.Name) + require.ElementsMatch(t, []Identity{ + {ResourceID: "identity1"}, + {ResourceID: "identity2"}, + }, vm.Identities) + }, + }, + { + desc: "client error", + resourceID: validResourceID, + client: &ARMScaleSetMock{ + GetErr: fmt.Errorf("client error"), + }, + assertError: require.Error, + assertVM: require.Nil, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + vmClient := NewVirtualMachinesClientByAPI(nil /* api */, tc.client) vm, err := vmClient.Get(ctx, tc.resourceID) tc.assertError(t, err) @@ -184,7 +293,7 @@ func TestListVirtualMachines(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - client := NewVirtualMachinesClientByAPI(mockAPI) + client := NewVirtualMachinesClientByAPI(mockAPI, nil /* scaleSetAPI */) vms, err := client.ListVirtualMachines(context.Background(), tc.resourceGroup) require.NoError(t, err) diff --git a/lib/cloud/mocks/azure.go b/lib/cloud/mocks/azure.go index adad1487f2e18..bab528c4dfd09 100644 --- a/lib/cloud/mocks/azure.go +++ b/lib/cloud/mocks/azure.go @@ -22,6 +22,8 @@ import ( "context" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "k8s.io/client-go/rest" @@ -57,3 +59,37 @@ func (a *AKSMock) ClusterCredentials(ctx context.Context, cfg azure.ClusterCrede } return nil, time.Now(), trace.NotFound("cluster not found") } + +// AzureVM generates Azure VM resource. +func AzureVM(identities []string) armcompute.VirtualMachine { + identitiesMap := make(map[string]*armcompute.UserAssignedIdentitiesValue) + for _, identity := range identities { + identitiesMap[identity] = &armcompute.UserAssignedIdentitiesValue{} + } + + return armcompute.VirtualMachine{ + ID: to.Ptr("/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.compute/virtualmachines/vm"), + Name: to.Ptr("vm"), + Identity: &armcompute.VirtualMachineIdentity{ + PrincipalID: to.Ptr("00000000-0000-0000-0000-000000000000"), + UserAssignedIdentities: identitiesMap, + }, + } +} + +// AzureScaleSetVM generates Azure scale set VM resource. +func AzureScaleSetVM(identities []string) armcompute.VirtualMachineScaleSetVM { + identitiesMap := make(map[string]*armcompute.UserAssignedIdentitiesValue) + for _, identity := range identities { + identitiesMap[identity] = &armcompute.UserAssignedIdentitiesValue{} + } + + return armcompute.VirtualMachineScaleSetVM{ + ID: to.Ptr("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss/virtualMachines/0"), + Name: to.Ptr("vm"), + Identity: &armcompute.VirtualMachineIdentity{ + PrincipalID: to.Ptr("00000000-0000-0000-0000-000000000000"), + UserAssignedIdentities: identitiesMap, + }, + } +} diff --git a/lib/srv/db/common/auth_test.go b/lib/srv/db/common/auth_test.go index 4fe788a68b0dd..370b5afe06542 100644 --- a/lib/srv/db/common/auth_test.go +++ b/lib/srv/db/common/auth_test.go @@ -28,8 +28,6 @@ import ( "testing" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -278,8 +276,8 @@ func TestGetAzureIdentityResourceID(t *testing.T) { instanceType: types.InstanceMetadataTypeAzure, }, AzureVirtualMachines: libcloudazure.NewVirtualMachinesClientByAPI(&libcloudazure.ARMComputeMock{ - GetResult: generateAzureVM(t, []string{identityResourceID(t, "identity")}), - }), + GetResult: mocks.AzureVM([]string{identityResourceID(t, "identity")}), + }, nil /* scaleSetAPI */), }, errAssertion: require.NoError, resourceIDAssertion: func(requireT require.TestingT, value interface{}, _ ...interface{}) { @@ -295,8 +293,8 @@ func TestGetAzureIdentityResourceID(t *testing.T) { instanceType: types.InstanceMetadataTypeAzure, }, AzureVirtualMachines: libcloudazure.NewVirtualMachinesClientByAPI(&libcloudazure.ARMComputeMock{ - GetResult: generateAzureVM(t, []string{identityResourceID(t, "identity")}), - }), + GetResult: mocks.AzureVM([]string{identityResourceID(t, "identity")}), + }, nil /* scaleSetAPI */), }, errAssertion: require.Error, resourceIDAssertion: require.Empty, @@ -310,8 +308,8 @@ func TestGetAzureIdentityResourceID(t *testing.T) { instanceType: types.InstanceMetadataTypeAzure, }, AzureVirtualMachines: libcloudazure.NewVirtualMachinesClientByAPI(&libcloudazure.ARMComputeMock{ - GetResult: generateAzureVM(t, []string{"identity"}), - }), + GetResult: mocks.AzureVM([]string{"identity"}), + }, nil /* scaleSetAPI */), }, errAssertion: require.Error, resourceIDAssertion: require.Empty, @@ -338,7 +336,81 @@ func TestGetAzureIdentityResourceID(t *testing.T) { }, AzureVirtualMachines: libcloudazure.NewVirtualMachinesClientByAPI(&libcloudazure.ARMComputeMock{ GetErr: errors.New("failed to get VM"), - }), + }, nil /* scaleSetAPI */), + }, + errAssertion: require.Error, + resourceIDAssertion: require.Empty, + }, + { + desc: "scale set vm running on Azure and identity is attached", + identityName: "identity", + clients: &cloud.TestCloudClients{ + InstanceMetadata: &imdsMock{ + id: "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss/virtualMachines/0", + instanceType: types.InstanceMetadataTypeAzure, + }, + AzureVirtualMachines: libcloudazure.NewVirtualMachinesClientByAPI( + nil, /* api */ + &libcloudazure.ARMScaleSetMock{ + GetResult: mocks.AzureScaleSetVM([]string{identityResourceID(t, "identity")}), + }, + ), + }, + errAssertion: require.NoError, + resourceIDAssertion: func(requireT require.TestingT, value interface{}, _ ...interface{}) { + require.Equal(requireT, identityResourceID(t, "identity"), value) + }, + }, + { + desc: "scale set vm running on Azure without the identity", + identityName: "random-identity-not-attached", + clients: &cloud.TestCloudClients{ + InstanceMetadata: &imdsMock{ + id: "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss/virtualMachines/0", + instanceType: types.InstanceMetadataTypeAzure, + }, + AzureVirtualMachines: libcloudazure.NewVirtualMachinesClientByAPI( + nil, /* api */ + &libcloudazure.ARMScaleSetMock{ + GetResult: mocks.AzureScaleSetVM([]string{identityResourceID(t, "identity")}), + }, + ), + }, + errAssertion: require.Error, + resourceIDAssertion: require.Empty, + }, + { + desc: "scale set vm running on Azure wrong format identity", + identityName: "random-identity-not-attached", + clients: &cloud.TestCloudClients{ + InstanceMetadata: &imdsMock{ + id: "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss/virtualMachines/0", + instanceType: types.InstanceMetadataTypeAzure, + }, + AzureVirtualMachines: libcloudazure.NewVirtualMachinesClientByAPI( + nil, /* api */ + &libcloudazure.ARMScaleSetMock{ + GetResult: mocks.AzureScaleSetVM([]string{"identity"}), + }, + ), + }, + errAssertion: require.Error, + resourceIDAssertion: require.Empty, + }, + { + desc: "scale set vm running but failed to get VM", + identityName: "identity", + clients: &cloud.TestCloudClients{ + InstanceMetadata: &imdsMock{ + id: "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss/virtualMachines/0", + instanceType: types.InstanceMetadataTypeAzure, + }, + AzureVirtualMachines: libcloudazure.NewVirtualMachinesClientByAPI( + nil, /* api */ + &libcloudazure.ARMScaleSetMock{ + GetErr: trace.NotFound("vm not found"), + }, + ), }, errAssertion: require.Error, resourceIDAssertion: require.Empty, @@ -375,7 +447,7 @@ func TestGetAzureIdentityResourceIDCache(t *testing.T) { id: "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.compute/virtualmachines/vm", instanceType: types.InstanceMetadataTypeAzure, }, - AzureVirtualMachines: libcloudazure.NewVirtualMachinesClientByAPI(virtualMachinesMock), + AzureVirtualMachines: libcloudazure.NewVirtualMachinesClientByAPI(virtualMachinesMock, nil /* scaleSetAPI */), }, }) require.NoError(t, err) @@ -387,7 +459,7 @@ func TestGetAzureIdentityResourceIDCache(t *testing.T) { // Change mock to return the VM. virtualMachinesMock.GetErr = nil - virtualMachinesMock.GetResult = generateAzureVM(t, []string{identityResourceID(t, "identity")}) + virtualMachinesMock.GetResult = mocks.AzureVM([]string{identityResourceID(t, "identity")}) // Advance the clock to force cache expiration. clock.Advance(azureVirtualMachineCacheTTL + time.Second) @@ -927,25 +999,6 @@ func identityResourceID(t *testing.T, identityName string) string { return fmt.Sprintf("/subscriptions/sub-id/resourceGroups/group-name/providers/Microsoft.ManagedIdentity/userAssignedIdentities/%s", identityName) } -// generateAzureVM generates Azure VM resource. -func generateAzureVM(t *testing.T, identities []string) armcompute.VirtualMachine { - t.Helper() - - identitiesMap := make(map[string]*armcompute.UserAssignedIdentitiesValue) - for _, identity := range identities { - identitiesMap[identity] = &armcompute.UserAssignedIdentitiesValue{} - } - - return armcompute.VirtualMachine{ - ID: to.Ptr("/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.compute/virtualmachines/vm"), - Name: to.Ptr("vm"), - Identity: &armcompute.VirtualMachineIdentity{ - PrincipalID: to.Ptr("00000000-0000-0000-0000-000000000000"), - UserAssignedIdentities: identitiesMap, - }, - } -} - // authClientMock is a mock that implements AuthClient interface. type authClientMock struct { } diff --git a/lib/srv/server/azure_watcher_test.go b/lib/srv/server/azure_watcher_test.go index 6c3989bc91a75..875162622b5d0 100644 --- a/lib/srv/server/azure_watcher_test.go +++ b/lib/srv/server/azure_watcher_test.go @@ -76,7 +76,7 @@ func TestAzureWatcher(t *testing.T) { }, }, }, - }), + }, nil /* scaleSetAPI */), } tests := []struct {