Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/auth/join_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken

// If the token is from the system-assigned managed identity, the resource ID
// is for the VM itself and we can use it to look up the VM.
// This will also match scale set VMs (VMSS), the vmClient is responsible
// for properly retrieving their information.
if slices.Contains(resourceID.ResourceType.Types, "virtualMachines") {
vm, err = vmClient.Get(ctx, tokenClaims.ResourceID)
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions lib/cloud/azure/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,18 @@ func (m *ARMComputeMock) Get(_ context.Context, _ string, _ string, _ *armcomput
}, m.GetErr
}

// ARMComputeScaleSetMock mocks armcompute.VirtualMachineScaleSetVMsClient.
type ARMScaleSetMock struct {
GetResult armcompute.VirtualMachineScaleSetVM
GetErr error
}

func (m *ARMScaleSetMock) Get(ctx context.Context, resourceGroupName string, vmScaleSetName string, instanceID string, options *armcompute.VirtualMachineScaleSetVMsClientGetOptions) (armcompute.VirtualMachineScaleSetVMsClientGetResponse, error) {
return armcompute.VirtualMachineScaleSetVMsClientGetResponse{
VirtualMachineScaleSetVM: m.GetResult,
}, m.GetErr
}

// ARMSQLServerMock mocks armSQLServerClient
type ARMSQLServerMock struct {
NoAuth bool
Expand Down
104 changes: 86 additions & 18 deletions lib/cloud/azure/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ import (
"github.com/gravitational/teleport/api/types"
)

// virtualScaleSetUniformVMResourceType represents the resource type of uniform
// virtual scale set VMs.
const virtualScaleSetUniformVMResourceType = "virtualMachineScaleSets/virtualMachines"

// armCompute provides an interface for an Azure virtual machine client.
type armCompute interface {
// Get retrieves information about an Azure virtual machine.
Expand All @@ -41,9 +45,16 @@ type armCompute interface {
NewListAllPager(opts *armcompute.VirtualMachinesClientListAllOptions) *runtime.Pager[armcompute.VirtualMachinesClientListAllResponse]
}

// scaleSet provides an interfaces for an Azure VM scale set client.
type scaleSet interface {
// Get retrieves a virtual machine from a VM scale set.
Get(ctx context.Context, resourceGroupName string, vmScaleSetName string, instanceID string, options *armcompute.VirtualMachineScaleSetVMsClientGetOptions) (armcompute.VirtualMachineScaleSetVMsClientGetResponse, error)
}

// VirtualMachinesClient is a client for Azure virtual machines.
type VirtualMachinesClient interface {
// Get returns the virtual machine for the given resource ID.
// Get returns the virtual machine (including scale set VMs) for the given
// resource ID.
Get(ctx context.Context, resourceID string) (*VirtualMachine, error)
// GetByVMID returns the virtual machine for a given VM ID.
GetByVMID(ctx context.Context, vmID string) (*VirtualMachine, error)
Expand Down Expand Up @@ -76,6 +87,8 @@ type Identity struct {
type vmClient struct {
// api is the Azure virtual machine client.
api armCompute
// scaleSetAPI is the Azure VM scale set client.
scaleSetAPI scaleSet
}

// NewVirtualMachinesClient creates a new Azure virtual machines client by
Expand All @@ -85,57 +98,98 @@ func NewVirtualMachinesClient(subscription string, cred azcore.TokenCredential,
if err != nil {
return nil, trace.Wrap(err)
}
scaleSetAPI, err := armcompute.NewVirtualMachineScaleSetVMsClient(subscription, cred, options)
if err != nil {
return nil, trace.Wrap(err)
}

return NewVirtualMachinesClientByAPI(computeAPI), nil
return NewVirtualMachinesClientByAPI(computeAPI, scaleSetAPI), nil
}

// NewVirtualMachinesClientByAPI creates a new Azure virtual machines client by
// ARM API client.
func NewVirtualMachinesClientByAPI(api armCompute) VirtualMachinesClient {
func NewVirtualMachinesClientByAPI(api armCompute, scaleSetAPI scaleSet) VirtualMachinesClient {
return &vmClient{
api: api,
api: api,
scaleSetAPI: scaleSetAPI,
}
}

func parseVirtualMachine(vm *armcompute.VirtualMachine) (*VirtualMachine, error) {
resourceID, err := arm.ParseResourceID(*vm.ID)
type vmTypes interface {
*armcompute.VirtualMachine | *armcompute.VirtualMachineScaleSetVM
}

func parseVirtualMachine[T vmTypes](vm T) (*VirtualMachine, error) {
var (
id string
name string
identity *armcompute.VirtualMachineIdentity
vmID *string
)

switch v := any(vm).(type) {
case *armcompute.VirtualMachine:
id = *v.ID
name = *v.Name
identity = v.Identity
if v.Properties != nil {
vmID = v.Properties.VMID
}

case *armcompute.VirtualMachineScaleSetVM:
id = *v.ID
name = *v.Name
identity = v.Identity
if v.Properties != nil {
vmID = v.Properties.VMID
}
}

resourceID, err := arm.ParseResourceID(id)
if err != nil {
return nil, trace.Wrap(err)
}

var identities []Identity
if vm.Identity != nil {
if systemAssigned := StringVal(vm.Identity.PrincipalID); systemAssigned != "" {
if identity != nil {
if systemAssigned := StringVal(identity.PrincipalID); systemAssigned != "" {
identities = append(identities, Identity{ResourceID: systemAssigned})
}

for identityID := range vm.Identity.UserAssignedIdentities {
for identityID := range identity.UserAssignedIdentities {
identities = append(identities, Identity{ResourceID: identityID})
}
}

var vmID string
if vm.Properties != nil {
vmID = *vm.Properties.VMID
}

return &VirtualMachine{
ID: *vm.ID,
Name: *vm.Name,
ID: id,
Name: name,
Subscription: resourceID.SubscriptionID,
ResourceGroup: resourceID.ResourceGroupName,
VMID: vmID,
VMID: StringVal(vmID),
Identities: identities,
}, nil
}

// Get returns the virtual machine for the given resource ID.
// Get returns the virtual machine (including scale set VMs) for the given
// resource ID.
//
// The virtual machine scale set (VMSS) supports two types of orchestration
// modes: uniform and flexible. Both have different resource ID format from the
// instance metadata API. A VM from a uniform VMSS has a different resource ID
// and requires a different API to retrieve its information. Flexible VMSS VMs
// use the same resource ID format as regular VMs and don't require special
// handling.
func (c *vmClient) Get(ctx context.Context, resourceID string) (*VirtualMachine, error) {
parsedResourceID, err := arm.ParseResourceID(resourceID)
if err != nil {
return nil, trace.Wrap(err)
}

if parsedResourceID.ResourceType.Type == virtualScaleSetUniformVMResourceType {
return c.getScaleSetVM(ctx, parsedResourceID)
}

resp, err := c.api.Get(ctx, parsedResourceID.ResourceGroupName, parsedResourceID.Name, nil)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -164,6 +218,20 @@ func (c *vmClient) GetByVMID(ctx context.Context, vmID string) (*VirtualMachine,
return nil, trace.NotFound("no VM with ID %q", vmID)
}

func (c *vmClient) getScaleSetVM(ctx context.Context, resourceID *arm.ResourceID) (*VirtualMachine, error) {
if resourceID.Parent == nil {
return nil, trace.BadParameter("expected resource ID to include scale set as parent resource")
}

resp, err := c.scaleSetAPI.Get(ctx, resourceID.ResourceGroupName, resourceID.Parent.Name, resourceID.Name, nil)
if err != nil {
return nil, trace.Wrap(err)
}

result, err := parseVirtualMachine(&resp.VirtualMachineScaleSetVM)
return result, trace.Wrap(err)
}

type vmPager struct {
more func() bool
nextPage func(context.Context) ([]*armcompute.VirtualMachine, error)
Expand Down
113 changes: 111 additions & 2 deletions lib/cloud/azure/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,116 @@ func TestGetVirtualMachine(t *testing.T) {
},
} {
t.Run(tc.desc, func(t *testing.T) {
vmClient := NewVirtualMachinesClientByAPI(tc.client)
vmClient := NewVirtualMachinesClientByAPI(tc.client, nil /* scaleSetAPI */)

vm, err := vmClient.Get(ctx, tc.resourceID)
tc.assertError(t, err)
tc.assertVM(t, vm)
})
}
}

func TestGetScaleSetVirtualMachine(t *testing.T) {
ctx := context.Background()
validResourceID := "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss/virtualMachines/0"

for _, tc := range []struct {
desc string
resourceID string
client *ARMScaleSetMock
assertError require.ErrorAssertionFunc
assertVM require.ValueAssertionFunc
}{
{
desc: "vm with valid user identities",
resourceID: validResourceID,
client: &ARMScaleSetMock{
GetResult: armcompute.VirtualMachineScaleSetVM{
ID: to.Ptr(validResourceID),
Name: to.Ptr("name"),
Identity: &armcompute.VirtualMachineIdentity{
PrincipalID: to.Ptr("system assigned"),
Type: to.Ptr(armcompute.ResourceIdentityTypeSystemAssigned),
UserAssignedIdentities: map[string]*armcompute.UserAssignedIdentitiesValue{
"identity1": {},
"identity2": {},
},
},
},
},
assertError: require.NoError,
assertVM: func(t require.TestingT, val interface{}, _ ...interface{}) {
require.NotNil(t, val)
vm, ok := val.(*VirtualMachine)
require.Truef(t, ok, "expected *VirtualMachine, got %T", val)
require.Equal(t, vm.ID, validResourceID)
require.Equal(t, "name", vm.Name)
require.ElementsMatch(t, []Identity{
{ResourceID: "system assigned"},
{ResourceID: "identity1"},
{ResourceID: "identity2"},
}, vm.Identities)
},
},
{
desc: "vm without identity",
resourceID: validResourceID,
client: &ARMScaleSetMock{
GetResult: armcompute.VirtualMachineScaleSetVM{
ID: to.Ptr(validResourceID),
Name: to.Ptr("name"),
},
},
assertError: require.NoError,
assertVM: func(t require.TestingT, val interface{}, _ ...interface{}) {
require.NotNil(t, val)
vm, ok := val.(*VirtualMachine)
require.Truef(t, ok, "expected *VirtualMachine, got %T", val)
require.Equal(t, vm.ID, validResourceID)
require.Equal(t, "name", vm.Name)
require.Empty(t, vm.Identities)
},
},
{
desc: "vm with only user managed identities",
resourceID: validResourceID,
client: &ARMScaleSetMock{
GetResult: armcompute.VirtualMachineScaleSetVM{
ID: to.Ptr(validResourceID),
Name: to.Ptr("name"),
Identity: &armcompute.VirtualMachineIdentity{
UserAssignedIdentities: map[string]*armcompute.UserAssignedIdentitiesValue{
"identity1": {},
"identity2": {},
},
},
},
},
assertError: require.NoError,
assertVM: func(t require.TestingT, val interface{}, _ ...interface{}) {
require.NotNil(t, val)
vm, ok := val.(*VirtualMachine)
require.Truef(t, ok, "expected *VirtualMachine, got %T", val)
require.Equal(t, vm.ID, validResourceID)
require.Equal(t, "name", vm.Name)
require.ElementsMatch(t, []Identity{
{ResourceID: "identity1"},
{ResourceID: "identity2"},
}, vm.Identities)
},
},
{
desc: "client error",
resourceID: validResourceID,
client: &ARMScaleSetMock{
GetErr: fmt.Errorf("client error"),
},
assertError: require.Error,
assertVM: require.Nil,
},
} {
t.Run(tc.desc, func(t *testing.T) {
vmClient := NewVirtualMachinesClientByAPI(nil /* api */, tc.client)

vm, err := vmClient.Get(ctx, tc.resourceID)
tc.assertError(t, err)
Expand Down Expand Up @@ -184,7 +293,7 @@ func TestListVirtualMachines(t *testing.T) {

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
client := NewVirtualMachinesClientByAPI(mockAPI)
client := NewVirtualMachinesClientByAPI(mockAPI, nil /* scaleSetAPI */)

vms, err := client.ListVirtualMachines(context.Background(), tc.resourceGroup)
require.NoError(t, err)
Expand Down
36 changes: 36 additions & 0 deletions lib/cloud/mocks/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"context"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"k8s.io/client-go/rest"
Expand Down Expand Up @@ -57,3 +59,37 @@ func (a *AKSMock) ClusterCredentials(ctx context.Context, cfg azure.ClusterCrede
}
return nil, time.Now(), trace.NotFound("cluster not found")
}

// AzureVM generates Azure VM resource.
func AzureVM(identities []string) armcompute.VirtualMachine {
identitiesMap := make(map[string]*armcompute.UserAssignedIdentitiesValue)
for _, identity := range identities {
identitiesMap[identity] = &armcompute.UserAssignedIdentitiesValue{}
}

return armcompute.VirtualMachine{
ID: to.Ptr("/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.compute/virtualmachines/vm"),
Name: to.Ptr("vm"),
Identity: &armcompute.VirtualMachineIdentity{
PrincipalID: to.Ptr("00000000-0000-0000-0000-000000000000"),
UserAssignedIdentities: identitiesMap,
},
}
}

// AzureScaleSetVM generates Azure scale set VM resource.
func AzureScaleSetVM(identities []string) armcompute.VirtualMachineScaleSetVM {
identitiesMap := make(map[string]*armcompute.UserAssignedIdentitiesValue)
for _, identity := range identities {
identitiesMap[identity] = &armcompute.UserAssignedIdentitiesValue{}
}

return armcompute.VirtualMachineScaleSetVM{
ID: to.Ptr("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss/virtualMachines/0"),
Name: to.Ptr("vm"),
Identity: &armcompute.VirtualMachineIdentity{
PrincipalID: to.Ptr("00000000-0000-0000-0000-000000000000"),
UserAssignedIdentities: identitiesMap,
},
}
}
Loading