diff --git a/azure/scope/cluster.go b/azure/scope/cluster.go index 0eb7c6725a2..b323a343511 100644 --- a/azure/scope/cluster.go +++ b/azure/scope/cluster.go @@ -35,6 +35,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure/services/loadbalancers" "sigs.k8s.io/cluster-api-provider-azure/azure/services/natgateways" "sigs.k8s.io/cluster-api-provider-azure/azure/services/routetables" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/securitygroups" "sigs.k8s.io/cluster-api-provider-azure/azure/services/subnets" "sigs.k8s.io/cluster-api-provider-azure/azure/services/virtualnetworks" "sigs.k8s.io/cluster-api-provider-azure/azure/services/vnetpeerings" @@ -277,12 +278,14 @@ func (s *ClusterScope) NatGatewaySpecs() []azure.ResourceSpecGetter { } // NSGSpecs returns the security group specs. -func (s *ClusterScope) NSGSpecs() []azure.NSGSpec { - nsgspecs := make([]azure.NSGSpec, len(s.AzureCluster.Spec.NetworkSpec.Subnets)) +func (s *ClusterScope) NSGSpecs() []azure.ResourceSpecGetter { + nsgspecs := make([]azure.ResourceSpecGetter, len(s.AzureCluster.Spec.NetworkSpec.Subnets)) for i, subnet := range s.AzureCluster.Spec.NetworkSpec.Subnets { - nsgspecs[i] = azure.NSGSpec{ + nsgspecs[i] = &securitygroups.NSGSpec{ Name: subnet.SecurityGroup.Name, SecurityRules: subnet.SecurityGroup.SecurityRules, + ResourceGroup: s.ResourceGroup(), + Location: s.Location(), } } @@ -696,6 +699,7 @@ func (s *ClusterScope) PatchObject(ctx context.Context) error { infrav1.BastionHostReadyCondition, infrav1.VNetReadyCondition, infrav1.SubnetsReadyCondition, + infrav1.SecurityGroupsReadyCondition, }}) } diff --git a/azure/services/securitygroups/client.go b/azure/services/securitygroups/client.go index 9864b5aca47..ae7e3bc11ce 100644 --- a/azure/services/securitygroups/client.go +++ b/azure/services/securitygroups/client.go @@ -18,27 +18,23 @@ package securitygroups import ( "context" + "encoding/json" "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2021-02-01/network" "github.com/Azure/go-autorest/autorest" + azureautorest "github.com/Azure/go-autorest/autorest/azure" + "github.com/pkg/errors" + infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" + "sigs.k8s.io/cluster-api-provider-azure/util/reconciler" "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) -// client wraps go-sdk. -type client interface { - Get(context.Context, string, string) (network.SecurityGroup, error) - CreateOrUpdate(context.Context, string, string, network.SecurityGroup) error - Delete(context.Context, string, string) error -} - // azureClient contains the Azure go-sdk Client. type azureClient struct { securitygroups network.SecurityGroupsClient } -var _ client = (*azureClient)(nil) - // newClient creates a new VM client from subscription ID. func newClient(auth azure.Authorizer) *azureClient { c := newSecurityGroupsClient(auth.SubscriptionID(), auth.BaseURI(), auth.Authorizer()) @@ -53,58 +49,126 @@ func newSecurityGroupsClient(subscriptionID string, baseURI string, authorizer a } // Get gets the specified network security group. -func (ac *azureClient) Get(ctx context.Context, resourceGroupName, sgName string) (network.SecurityGroup, error) { - ctx, _, done := tele.StartSpanWithLogger(ctx, "securitygroups.AzureClient.Get") +func (ac *azureClient) Get(ctx context.Context, spec azure.ResourceSpecGetter) (result interface{}, err error) { + ctx, _, done := tele.StartSpanWithLogger(ctx, "securitygroups.azureClient.Get") defer done() - return ac.securitygroups.Get(ctx, resourceGroupName, sgName, "") + return ac.securitygroups.Get(ctx, spec.ResourceGroupName(), spec.ResourceName(), "") } -// CreateOrUpdate creates or updates a network security group in the specified resource group. -func (ac *azureClient) CreateOrUpdate(ctx context.Context, resourceGroupName string, sgName string, sg network.SecurityGroup) error { - ctx, _, done := tele.StartSpanWithLogger(ctx, "securitygroups.AzureClient.CreateOrUpdate") +// CreateOrUpdateAsync creates or updates a network security group in the specified resource group. +// It sends a PUT request to Azure and if accepted without error, the func will return a Future which can be used to track the ongoing +// progress of the operation. +func (ac *azureClient) CreateOrUpdateAsync(ctx context.Context, spec azure.ResourceSpecGetter, parameters interface{}) (result interface{}, future azureautorest.FutureAPI, err error) { + ctx, _, done := tele.StartSpanWithLogger(ctx, "securitygroups.azureClient.CreateOrUpdate") defer done() + sg, ok := parameters.(network.SecurityGroup) + if !ok { + return nil, nil, errors.Errorf("%T is not a network.SecurityGroup", parameters) + } + var etag string if sg.Etag != nil { etag = *sg.Etag } - req, err := ac.securitygroups.CreateOrUpdatePreparer(ctx, resourceGroupName, sgName, sg) + req, err := ac.securitygroups.CreateOrUpdatePreparer(ctx, spec.ResourceGroupName(), spec.ResourceName(), sg) if err != nil { err = autorest.NewErrorWithError(err, "network.SecurityGroupsClient", "CreateOrUpdate", nil, "Failure preparing request") - return err + return nil, nil, err } if etag != "" { req.Header.Add("If-Match", etag) } - future, err := ac.securitygroups.CreateOrUpdateSender(req) + createFuture, err := ac.securitygroups.CreateOrUpdateSender(req) if err != nil { - err = autorest.NewErrorWithError(err, "network.SecurityGroupsClient", "CreateOrUpdate", future.Response(), "Failure sending request") - return err + err = autorest.NewErrorWithError(err, "network.SecurityGroupsClient", "CreateOrUpdate", createFuture.Response(), "Failure sending request") + return nil, nil, err } - err = future.WaitForCompletionRef(ctx, ac.securitygroups.Client) + ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureCallTimeout) + defer cancel() + + err = createFuture.WaitForCompletionRef(ctx, ac.securitygroups.Client) if err != nil { - return err + // if an error occurs, return the future. + // this means the long-running operation didn't finish in the specified timeout. + return nil, &createFuture, err } - _, err = future.Result(ac.securitygroups) - return err + result, err = createFuture.Result(ac.securitygroups) + // if the operation completed, return a nil future. + return result, nil, err } -// Delete deletes the specified network security group. -func (ac *azureClient) Delete(ctx context.Context, resourceGroupName, sgName string) error { - ctx, _, done := tele.StartSpanWithLogger(ctx, "securitygroups.AzureClient.Delete") +// Delete deletes the specified network security group. DeleteAsync sends a DELETE +// request to Azure and if accepted without error, the func will return a Future which can be used to track the ongoing +// progress of the operation. +func (ac *azureClient) DeleteAsync(ctx context.Context, spec azure.ResourceSpecGetter) (future azureautorest.FutureAPI, err error) { + ctx, _, done := tele.StartSpanWithLogger(ctx, "securitygroups.azureClient.Delete") defer done() - future, err := ac.securitygroups.Delete(ctx, resourceGroupName, sgName) + deleteFuture, err := ac.securitygroups.Delete(ctx, spec.ResourceGroupName(), spec.ResourceName()) if err != nil { - return err + return nil, err } - err = future.WaitForCompletionRef(ctx, ac.securitygroups.Client) + + ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureCallTimeout) + defer cancel() + + err = deleteFuture.WaitForCompletionRef(ctx, ac.securitygroups.Client) + if err != nil { + // if an error occurs, return the future. + // this means the long-running operation didn't finish in the specified timeout. + return &deleteFuture, err + } + _, err = deleteFuture.Result(ac.securitygroups) + // if the operation completed, return a nil future. + return nil, err +} + +// IsDone returns true if the long-running operation has completed. +func (ac *azureClient) IsDone(ctx context.Context, future azureautorest.FutureAPI) (isDone bool, err error) { + ctx, _, done := tele.StartSpanWithLogger(ctx, "securitygroups.azureClient.IsDone") + defer done() + + isDone, err = future.DoneWithContext(ctx, ac.securitygroups) if err != nil { - return err + return false, errors.Wrap(err, "failed checking if the operation was complete") + } + + return isDone, nil +} + +// Result fetches the result of a long-running operation future. +func (ac *azureClient) Result(ctx context.Context, future azureautorest.FutureAPI, futureType string) (result interface{}, err error) { + _, _, done := tele.StartSpanWithLogger(ctx, "securitygroups.azureClient.Result") + defer done() + + if future == nil { + return nil, errors.Errorf("cannot get result from nil future") + } + + switch futureType { + case infrav1.PutFuture: + // Marshal and Unmarshal the future to put it into the correct future type so we can access the Result function. + // Unfortunately the FutureAPI can't be casted directly to SecurityGroupsCreateOrUpdateFuture because it is a azureautorest.Future, which doesn't implement the Result function. See PR #1686 for discussion on alternatives. + // It was converted back to a generic azureautorest.Future from the CAPZ infrav1.Future type stored in Status: https://github.com/kubernetes-sigs/cluster-api-provider-azure/blob/main/azure/converters/futures.go#L49. + var createFuture *network.SecurityGroupsCreateOrUpdateFuture + jsonData, err := future.MarshalJSON() + if err != nil { + return nil, errors.Wrap(err, "failed to marshal future") + } + if err := json.Unmarshal(jsonData, &createFuture); err != nil { + return nil, errors.Wrap(err, "failed to unmarshal future data") + } + return createFuture.Result(ac.securitygroups) + + case infrav1.DeleteFuture: + // Delete does not return a result security group. + return nil, nil + + default: + return nil, errors.Errorf("unknown future type %q", futureType) } - _, err = future.Result(ac.securitygroups) - return err } diff --git a/azure/services/securitygroups/mock_securitygroups/client_mock.go b/azure/services/securitygroups/mock_securitygroups/client_mock.go deleted file mode 100644 index d3612742306..00000000000 --- a/azure/services/securitygroups/mock_securitygroups/client_mock.go +++ /dev/null @@ -1,95 +0,0 @@ -/* -Copyright The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Code generated by MockGen. DO NOT EDIT. -// Source: ../client.go - -// Package mock_securitygroups is a generated GoMock package. -package mock_securitygroups - -import ( - context "context" - reflect "reflect" - - network "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2021-02-01/network" - gomock "github.com/golang/mock/gomock" -) - -// Mockclient is a mock of client interface. -type Mockclient struct { - ctrl *gomock.Controller - recorder *MockclientMockRecorder -} - -// MockclientMockRecorder is the mock recorder for Mockclient. -type MockclientMockRecorder struct { - mock *Mockclient -} - -// NewMockclient creates a new mock instance. -func NewMockclient(ctrl *gomock.Controller) *Mockclient { - mock := &Mockclient{ctrl: ctrl} - mock.recorder = &MockclientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *Mockclient) EXPECT() *MockclientMockRecorder { - return m.recorder -} - -// CreateOrUpdate mocks base method. -func (m *Mockclient) CreateOrUpdate(arg0 context.Context, arg1, arg2 string, arg3 network.SecurityGroup) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateOrUpdate", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// CreateOrUpdate indicates an expected call of CreateOrUpdate. -func (mr *MockclientMockRecorder) CreateOrUpdate(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrUpdate", reflect.TypeOf((*Mockclient)(nil).CreateOrUpdate), arg0, arg1, arg2, arg3) -} - -// Delete mocks base method. -func (m *Mockclient) Delete(arg0 context.Context, arg1, arg2 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Delete", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// Delete indicates an expected call of Delete. -func (mr *MockclientMockRecorder) Delete(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*Mockclient)(nil).Delete), arg0, arg1, arg2) -} - -// Get mocks base method. -func (m *Mockclient) Get(arg0 context.Context, arg1, arg2 string) (network.SecurityGroup, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2) - ret0, _ := ret[0].(network.SecurityGroup) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Get indicates an expected call of Get. -func (mr *MockclientMockRecorder) Get(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*Mockclient)(nil).Get), arg0, arg1, arg2) -} diff --git a/azure/services/securitygroups/mock_securitygroups/doc.go b/azure/services/securitygroups/mock_securitygroups/doc.go index 921787947d9..c7be1215c92 100644 --- a/azure/services/securitygroups/mock_securitygroups/doc.go +++ b/azure/services/securitygroups/mock_securitygroups/doc.go @@ -15,8 +15,6 @@ limitations under the License. */ // Run go generate to regenerate this mock. -//go:generate ../../../../hack/tools/bin/mockgen -destination client_mock.go -package mock_securitygroups -source ../client.go Client //go:generate ../../../../hack/tools/bin/mockgen -destination securitygroups_mock.go -package mock_securitygroups -source ../securitygroups.go NSGScope -//go:generate /usr/bin/env bash -c "cat ../../../../hack/boilerplate/boilerplate.generatego.txt client_mock.go > _client_mock.go && mv _client_mock.go client_mock.go" //go:generate /usr/bin/env bash -c "cat ../../../../hack/boilerplate/boilerplate.generatego.txt securitygroups_mock.go > _securitygroups_mock.go && mv _securitygroups_mock.go securitygroups_mock.go" package mock_securitygroups //nolint diff --git a/azure/services/securitygroups/mock_securitygroups/securitygroups_mock.go b/azure/services/securitygroups/mock_securitygroups/securitygroups_mock.go index d4c78c48579..f35a663d83f 100644 --- a/azure/services/securitygroups/mock_securitygroups/securitygroups_mock.go +++ b/azure/services/securitygroups/mock_securitygroups/securitygroups_mock.go @@ -27,6 +27,7 @@ import ( gomock "github.com/golang/mock/gomock" v1beta1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" azure "sigs.k8s.io/cluster-api-provider-azure/azure" + v1beta10 "sigs.k8s.io/cluster-api/api/v1beta1" ) // MockNSGScope is a mock of NSGScope interface. @@ -52,62 +53,6 @@ func (m *MockNSGScope) EXPECT() *MockNSGScopeMockRecorder { return m.recorder } -// APIServerLB mocks base method. -func (m *MockNSGScope) APIServerLB() *v1beta1.LoadBalancerSpec { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "APIServerLB") - ret0, _ := ret[0].(*v1beta1.LoadBalancerSpec) - return ret0 -} - -// APIServerLB indicates an expected call of APIServerLB. -func (mr *MockNSGScopeMockRecorder) APIServerLB() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "APIServerLB", reflect.TypeOf((*MockNSGScope)(nil).APIServerLB)) -} - -// APIServerLBName mocks base method. -func (m *MockNSGScope) APIServerLBName() string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "APIServerLBName") - ret0, _ := ret[0].(string) - return ret0 -} - -// APIServerLBName indicates an expected call of APIServerLBName. -func (mr *MockNSGScopeMockRecorder) APIServerLBName() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "APIServerLBName", reflect.TypeOf((*MockNSGScope)(nil).APIServerLBName)) -} - -// APIServerLBPoolName mocks base method. -func (m *MockNSGScope) APIServerLBPoolName(arg0 string) string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "APIServerLBPoolName", arg0) - ret0, _ := ret[0].(string) - return ret0 -} - -// APIServerLBPoolName indicates an expected call of APIServerLBPoolName. -func (mr *MockNSGScopeMockRecorder) APIServerLBPoolName(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "APIServerLBPoolName", reflect.TypeOf((*MockNSGScope)(nil).APIServerLBPoolName), arg0) -} - -// AdditionalTags mocks base method. -func (m *MockNSGScope) AdditionalTags() v1beta1.Tags { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AdditionalTags") - ret0, _ := ret[0].(v1beta1.Tags) - return ret0 -} - -// AdditionalTags indicates an expected call of AdditionalTags. -func (mr *MockNSGScopeMockRecorder) AdditionalTags() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AdditionalTags", reflect.TypeOf((*MockNSGScope)(nil).AdditionalTags)) -} - // Authorizer mocks base method. func (m *MockNSGScope) Authorizer() autorest.Authorizer { m.ctrl.T.Helper() @@ -122,20 +67,6 @@ func (mr *MockNSGScopeMockRecorder) Authorizer() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Authorizer", reflect.TypeOf((*MockNSGScope)(nil).Authorizer)) } -// AvailabilitySetEnabled mocks base method. -func (m *MockNSGScope) AvailabilitySetEnabled() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AvailabilitySetEnabled") - ret0, _ := ret[0].(bool) - return ret0 -} - -// AvailabilitySetEnabled indicates an expected call of AvailabilitySetEnabled. -func (mr *MockNSGScopeMockRecorder) AvailabilitySetEnabled() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AvailabilitySetEnabled", reflect.TypeOf((*MockNSGScope)(nil).AvailabilitySetEnabled)) -} - // BaseURI mocks base method. func (m *MockNSGScope) BaseURI() string { m.ctrl.T.Helper() @@ -192,88 +123,30 @@ func (mr *MockNSGScopeMockRecorder) CloudEnvironment() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloudEnvironment", reflect.TypeOf((*MockNSGScope)(nil).CloudEnvironment)) } -// CloudProviderConfigOverrides mocks base method. -func (m *MockNSGScope) CloudProviderConfigOverrides() *v1beta1.CloudProviderConfigOverrides { +// DeleteLongRunningOperationState mocks base method. +func (m *MockNSGScope) DeleteLongRunningOperationState(arg0, arg1 string) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CloudProviderConfigOverrides") - ret0, _ := ret[0].(*v1beta1.CloudProviderConfigOverrides) - return ret0 -} - -// CloudProviderConfigOverrides indicates an expected call of CloudProviderConfigOverrides. -func (mr *MockNSGScopeMockRecorder) CloudProviderConfigOverrides() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloudProviderConfigOverrides", reflect.TypeOf((*MockNSGScope)(nil).CloudProviderConfigOverrides)) -} - -// ClusterName mocks base method. -func (m *MockNSGScope) ClusterName() string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ClusterName") - ret0, _ := ret[0].(string) - return ret0 -} - -// ClusterName indicates an expected call of ClusterName. -func (mr *MockNSGScopeMockRecorder) ClusterName() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterName", reflect.TypeOf((*MockNSGScope)(nil).ClusterName)) -} - -// ControlPlaneRouteTable mocks base method. -func (m *MockNSGScope) ControlPlaneRouteTable() v1beta1.RouteTable { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ControlPlaneRouteTable") - ret0, _ := ret[0].(v1beta1.RouteTable) - return ret0 -} - -// ControlPlaneRouteTable indicates an expected call of ControlPlaneRouteTable. -func (mr *MockNSGScopeMockRecorder) ControlPlaneRouteTable() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControlPlaneRouteTable", reflect.TypeOf((*MockNSGScope)(nil).ControlPlaneRouteTable)) + m.ctrl.Call(m, "DeleteLongRunningOperationState", arg0, arg1) } -// ControlPlaneSubnet mocks base method. -func (m *MockNSGScope) ControlPlaneSubnet() v1beta1.SubnetSpec { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ControlPlaneSubnet") - ret0, _ := ret[0].(v1beta1.SubnetSpec) - return ret0 -} - -// ControlPlaneSubnet indicates an expected call of ControlPlaneSubnet. -func (mr *MockNSGScopeMockRecorder) ControlPlaneSubnet() *gomock.Call { +// DeleteLongRunningOperationState indicates an expected call of DeleteLongRunningOperationState. +func (mr *MockNSGScopeMockRecorder) DeleteLongRunningOperationState(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControlPlaneSubnet", reflect.TypeOf((*MockNSGScope)(nil).ControlPlaneSubnet)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteLongRunningOperationState", reflect.TypeOf((*MockNSGScope)(nil).DeleteLongRunningOperationState), arg0, arg1) } -// FailureDomains mocks base method. -func (m *MockNSGScope) FailureDomains() []string { +// GetLongRunningOperationState mocks base method. +func (m *MockNSGScope) GetLongRunningOperationState(arg0, arg1 string) *v1beta1.Future { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FailureDomains") - ret0, _ := ret[0].([]string) + ret := m.ctrl.Call(m, "GetLongRunningOperationState", arg0, arg1) + ret0, _ := ret[0].(*v1beta1.Future) return ret0 } -// FailureDomains indicates an expected call of FailureDomains. -func (mr *MockNSGScopeMockRecorder) FailureDomains() *gomock.Call { +// GetLongRunningOperationState indicates an expected call of GetLongRunningOperationState. +func (mr *MockNSGScopeMockRecorder) GetLongRunningOperationState(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FailureDomains", reflect.TypeOf((*MockNSGScope)(nil).FailureDomains)) -} - -// GetPrivateDNSZoneName mocks base method. -func (m *MockNSGScope) GetPrivateDNSZoneName() string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetPrivateDNSZoneName") - ret0, _ := ret[0].(string) - return ret0 -} - -// GetPrivateDNSZoneName indicates an expected call of GetPrivateDNSZoneName. -func (mr *MockNSGScopeMockRecorder) GetPrivateDNSZoneName() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateDNSZoneName", reflect.TypeOf((*MockNSGScope)(nil).GetPrivateDNSZoneName)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLongRunningOperationState", reflect.TypeOf((*MockNSGScope)(nil).GetLongRunningOperationState), arg0, arg1) } // HashKey mocks base method. @@ -290,34 +163,6 @@ func (mr *MockNSGScopeMockRecorder) HashKey() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HashKey", reflect.TypeOf((*MockNSGScope)(nil).HashKey)) } -// IsAPIServerPrivate mocks base method. -func (m *MockNSGScope) IsAPIServerPrivate() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsAPIServerPrivate") - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsAPIServerPrivate indicates an expected call of IsAPIServerPrivate. -func (mr *MockNSGScopeMockRecorder) IsAPIServerPrivate() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsAPIServerPrivate", reflect.TypeOf((*MockNSGScope)(nil).IsAPIServerPrivate)) -} - -// IsIPv6Enabled mocks base method. -func (m *MockNSGScope) IsIPv6Enabled() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsIPv6Enabled") - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsIPv6Enabled indicates an expected call of IsIPv6Enabled. -func (mr *MockNSGScopeMockRecorder) IsIPv6Enabled() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsIPv6Enabled", reflect.TypeOf((*MockNSGScope)(nil).IsIPv6Enabled)) -} - // IsVnetManaged mocks base method. func (m *MockNSGScope) IsVnetManaged() bool { m.ctrl.T.Helper() @@ -332,25 +177,11 @@ func (mr *MockNSGScopeMockRecorder) IsVnetManaged() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsVnetManaged", reflect.TypeOf((*MockNSGScope)(nil).IsVnetManaged)) } -// Location mocks base method. -func (m *MockNSGScope) Location() string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Location") - ret0, _ := ret[0].(string) - return ret0 -} - -// Location indicates an expected call of Location. -func (mr *MockNSGScopeMockRecorder) Location() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Location", reflect.TypeOf((*MockNSGScope)(nil).Location)) -} - // NSGSpecs mocks base method. -func (m *MockNSGScope) NSGSpecs() []azure.NSGSpec { +func (m *MockNSGScope) NSGSpecs() []azure.ResourceSpecGetter { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "NSGSpecs") - ret0, _ := ret[0].([]azure.NSGSpec) + ret0, _ := ret[0].([]azure.ResourceSpecGetter) return ret0 } @@ -360,100 +191,16 @@ func (mr *MockNSGScopeMockRecorder) NSGSpecs() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NSGSpecs", reflect.TypeOf((*MockNSGScope)(nil).NSGSpecs)) } -// NodeSubnets mocks base method. -func (m *MockNSGScope) NodeSubnets() []v1beta1.SubnetSpec { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NodeSubnets") - ret0, _ := ret[0].([]v1beta1.SubnetSpec) - return ret0 -} - -// NodeSubnets indicates an expected call of NodeSubnets. -func (mr *MockNSGScopeMockRecorder) NodeSubnets() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeSubnets", reflect.TypeOf((*MockNSGScope)(nil).NodeSubnets)) -} - -// OutboundLBName mocks base method. -func (m *MockNSGScope) OutboundLBName(arg0 string) string { +// SetLongRunningOperationState mocks base method. +func (m *MockNSGScope) SetLongRunningOperationState(arg0 *v1beta1.Future) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OutboundLBName", arg0) - ret0, _ := ret[0].(string) - return ret0 + m.ctrl.Call(m, "SetLongRunningOperationState", arg0) } -// OutboundLBName indicates an expected call of OutboundLBName. -func (mr *MockNSGScopeMockRecorder) OutboundLBName(arg0 interface{}) *gomock.Call { +// SetLongRunningOperationState indicates an expected call of SetLongRunningOperationState. +func (mr *MockNSGScopeMockRecorder) SetLongRunningOperationState(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OutboundLBName", reflect.TypeOf((*MockNSGScope)(nil).OutboundLBName), arg0) -} - -// OutboundPoolName mocks base method. -func (m *MockNSGScope) OutboundPoolName(arg0 string) string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OutboundPoolName", arg0) - ret0, _ := ret[0].(string) - return ret0 -} - -// OutboundPoolName indicates an expected call of OutboundPoolName. -func (mr *MockNSGScopeMockRecorder) OutboundPoolName(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OutboundPoolName", reflect.TypeOf((*MockNSGScope)(nil).OutboundPoolName), arg0) -} - -// ResourceGroup mocks base method. -func (m *MockNSGScope) ResourceGroup() string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ResourceGroup") - ret0, _ := ret[0].(string) - return ret0 -} - -// ResourceGroup indicates an expected call of ResourceGroup. -func (mr *MockNSGScopeMockRecorder) ResourceGroup() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResourceGroup", reflect.TypeOf((*MockNSGScope)(nil).ResourceGroup)) -} - -// SetSubnet mocks base method. -func (m *MockNSGScope) SetSubnet(arg0 v1beta1.SubnetSpec) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetSubnet", arg0) -} - -// SetSubnet indicates an expected call of SetSubnet. -func (mr *MockNSGScopeMockRecorder) SetSubnet(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSubnet", reflect.TypeOf((*MockNSGScope)(nil).SetSubnet), arg0) -} - -// Subnet mocks base method. -func (m *MockNSGScope) Subnet(arg0 string) v1beta1.SubnetSpec { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Subnet", arg0) - ret0, _ := ret[0].(v1beta1.SubnetSpec) - return ret0 -} - -// Subnet indicates an expected call of Subnet. -func (mr *MockNSGScopeMockRecorder) Subnet(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subnet", reflect.TypeOf((*MockNSGScope)(nil).Subnet), arg0) -} - -// Subnets mocks base method. -func (m *MockNSGScope) Subnets() v1beta1.Subnets { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Subnets") - ret0, _ := ret[0].(v1beta1.Subnets) - return ret0 -} - -// Subnets indicates an expected call of Subnets. -func (mr *MockNSGScopeMockRecorder) Subnets() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subnets", reflect.TypeOf((*MockNSGScope)(nil).Subnets)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLongRunningOperationState", reflect.TypeOf((*MockNSGScope)(nil).SetLongRunningOperationState), arg0) } // SubscriptionID mocks base method. @@ -484,16 +231,38 @@ func (mr *MockNSGScopeMockRecorder) TenantID() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TenantID", reflect.TypeOf((*MockNSGScope)(nil).TenantID)) } -// Vnet mocks base method. -func (m *MockNSGScope) Vnet() *v1beta1.VnetSpec { +// UpdateDeleteStatus mocks base method. +func (m *MockNSGScope) UpdateDeleteStatus(arg0 v1beta10.ConditionType, arg1 string, arg2 error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Vnet") - ret0, _ := ret[0].(*v1beta1.VnetSpec) - return ret0 + m.ctrl.Call(m, "UpdateDeleteStatus", arg0, arg1, arg2) +} + +// UpdateDeleteStatus indicates an expected call of UpdateDeleteStatus. +func (mr *MockNSGScopeMockRecorder) UpdateDeleteStatus(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateDeleteStatus", reflect.TypeOf((*MockNSGScope)(nil).UpdateDeleteStatus), arg0, arg1, arg2) +} + +// UpdatePatchStatus mocks base method. +func (m *MockNSGScope) UpdatePatchStatus(arg0 v1beta10.ConditionType, arg1 string, arg2 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatePatchStatus", arg0, arg1, arg2) +} + +// UpdatePatchStatus indicates an expected call of UpdatePatchStatus. +func (mr *MockNSGScopeMockRecorder) UpdatePatchStatus(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePatchStatus", reflect.TypeOf((*MockNSGScope)(nil).UpdatePatchStatus), arg0, arg1, arg2) +} + +// UpdatePutStatus mocks base method. +func (m *MockNSGScope) UpdatePutStatus(arg0 v1beta10.ConditionType, arg1 string, arg2 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatePutStatus", arg0, arg1, arg2) } -// Vnet indicates an expected call of Vnet. -func (mr *MockNSGScopeMockRecorder) Vnet() *gomock.Call { +// UpdatePutStatus indicates an expected call of UpdatePutStatus. +func (mr *MockNSGScopeMockRecorder) UpdatePutStatus(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Vnet", reflect.TypeOf((*MockNSGScope)(nil).Vnet)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePutStatus", reflect.TypeOf((*MockNSGScope)(nil).UpdatePutStatus), arg0, arg1, arg2) } diff --git a/azure/services/securitygroups/securitygroups.go b/azure/services/securitygroups/securitygroups.go index 1ce0c0a1f8f..5a53d476817 100644 --- a/azure/services/securitygroups/securitygroups.go +++ b/azure/services/securitygroups/securitygroups.go @@ -18,142 +18,107 @@ package securitygroups import ( "context" - "strings" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2021-02-01/network" - "github.com/Azure/go-autorest/autorest/to" - "github.com/pkg/errors" + infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" - "sigs.k8s.io/cluster-api-provider-azure/azure/converters" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/async" + "sigs.k8s.io/cluster-api-provider-azure/util/reconciler" "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) +const serviceName = "securitygroups" + // NSGScope defines the scope interface for a security groups service. type NSGScope interface { - azure.ClusterDescriber - azure.NetworkDescriber - NSGSpecs() []azure.NSGSpec + azure.Authorizer + azure.AsyncStatusUpdater + NSGSpecs() []azure.ResourceSpecGetter + IsVnetManaged() bool } // Service provides operations on Azure resources. type Service struct { Scope NSGScope - client + async.Reconciler } // New creates a new service. func New(scope NSGScope) *Service { + client := newClient(scope) return &Service{ - Scope: scope, - client: newClient(scope), + Scope: scope, + Reconciler: async.New(scope, client, client), } } -// Reconcile gets/creates/updates a network security group. +// Reconcile gets/creates/updates network security groups. func (s *Service) Reconcile(ctx context.Context) error { ctx, log, done := tele.StartSpanWithLogger(ctx, "securitygroups.Service.Reconcile") defer done() + ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureServiceReconcileTimeout) + defer cancel() + + // Only create the NSGs if their lifecycle is managed by this controller. if !s.Scope.IsVnetManaged() { - log.V(4).Info("Skipping network security group reconcile in custom VNet mode") + log.V(4).Info("Skipping network security groups reconcile in custom VNet mode") return nil } - for _, nsgSpec := range s.Scope.NSGSpecs() { - securityRules := make([]network.SecurityRule, 0) - var etag *string - - existingNSG, err := s.client.Get(ctx, s.Scope.ResourceGroup(), nsgSpec.Name) - switch { - case err != nil && !azure.ResourceNotFound(err): - return errors.Wrapf(err, "failed to get NSG %s in %s", nsgSpec.Name, s.Scope.ResourceGroup()) - case err == nil: - // security group already exists - // We append the existing NSG etag to the header to ensure we only apply the updates if the NSG has not been modified. - etag = existingNSG.Etag - // Check if the expected rules are present - update := false - securityRules = *existingNSG.SecurityRules - for _, rule := range nsgSpec.SecurityRules { - sdkRule := converters.SecurityRuleToSDK(rule) - if !ruleExists(securityRules, sdkRule) { - update = true - securityRules = append(securityRules, sdkRule) - } - } - if !update { - // Skip update for NSG as the required default rules are present - log.V(2).Info("security group exists and no default rules are missing, skipping update", "security group", nsgSpec.Name) - continue - } - default: - log.V(2).Info("creating security group", "security group", nsgSpec.Name) - for _, rule := range nsgSpec.SecurityRules { - securityRules = append(securityRules, converters.SecurityRuleToSDK(rule)) - } - } - sg := network.SecurityGroup{ - Location: to.StringPtr(s.Scope.Location()), - SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ - SecurityRules: &securityRules, - }, - Etag: etag, - } - err = s.client.CreateOrUpdate(ctx, s.Scope.ResourceGroup(), nsgSpec.Name, sg) - if err != nil { - return errors.Wrapf(err, "failed to create or update security group %s in resource group %s", nsgSpec.Name, s.Scope.ResourceGroup()) - } - - log.V(2).Info("successfully created or updated security group", "security group", nsgSpec.Name) + specs := s.Scope.NSGSpecs() + if len(specs) == 0 { + return nil } - return nil -} -func ruleExists(rules []network.SecurityRule, rule network.SecurityRule) bool { - for _, existingRule := range rules { - if !strings.EqualFold(to.String(existingRule.Name), to.String(rule.Name)) { - continue - } - if !strings.EqualFold(to.String(existingRule.DestinationPortRange), to.String(rule.DestinationPortRange)) { - continue - } - if existingRule.Protocol != network.SecurityRuleProtocolTCP && - existingRule.Access != network.SecurityRuleAccessAllow && - existingRule.Direction != network.SecurityRuleDirectionInbound { - continue - } - if !strings.EqualFold(to.String(existingRule.SourcePortRange), "*") && - !strings.EqualFold(to.String(existingRule.SourceAddressPrefix), "*") && - !strings.EqualFold(to.String(existingRule.DestinationAddressPrefix), "*") { - continue + var resErr error + + // We go through the list of security groups to reconcile each one, independently of the result of the previous one. + // If multiple errors occur, we return the most pressing one. + // Order of precedence (highest -> lowest) is: error that is not an operationNotDoneError (i.e. error creating) -> operationNotDoneError (i.e. creating in progress) -> no error (i.e. created) + for _, nsgSpec := range specs { + if _, err := s.CreateResource(ctx, nsgSpec, serviceName); err != nil { + if !azure.IsOperationNotDoneError(err) || resErr == nil { + resErr = err + } } - return true } - return false + + s.Scope.UpdatePutStatus(infrav1.SecurityGroupsReadyCondition, serviceName, resErr) + return resErr } -// Delete deletes the network security group with the provided name. +// Delete deletes network security groups. func (s *Service) Delete(ctx context.Context) error { ctx, log, done := tele.StartSpanWithLogger(ctx, "securitygroups.Service.Delete") defer done() + ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureServiceReconcileTimeout) + defer cancel() + + // Only delete the NSG if its lifecycle is managed by this controller. if !s.Scope.IsVnetManaged() { - log.V(4).Info("Skipping network security group delete in custom VNet mode") + log.V(4).Info("Skipping network security groups delete in custom VNet mode") return nil } - for _, nsgSpec := range s.Scope.NSGSpecs() { - log.V(2).Info("deleting security group", "security group", nsgSpec.Name) - err := s.client.Delete(ctx, s.Scope.ResourceGroup(), nsgSpec.Name) - if err != nil && azure.ResourceNotFound(err) { - // already deleted - continue - } - if err != nil { - return errors.Wrapf(err, "failed to delete security group %s in resource group %s", nsgSpec.Name, s.Scope.ResourceGroup()) - } + specs := s.Scope.NSGSpecs() + if len(specs) == 0 { + return nil + } + + var result error - log.V(2).Info("successfully deleted security group", "security group", nsgSpec.Name) + // We go through the list of security groups to delete each one, independently of the result of the previous one. + // If multiple errors occur, we return the most pressing one. + // Order of precedence (highest -> lowest) is: error that is not an operationNotDoneError (i.e. error deleting) -> operationNotDoneError (i.e. deleting in progress) -> no error (i.e. deleted) + for _, nsgSpec := range specs { + if err := s.DeleteResource(ctx, nsgSpec, serviceName); err != nil { + if !azure.IsOperationNotDoneError(err) || result == nil { + result = err + } + } } - return nil + + s.Scope.UpdateDeleteStatus(infrav1.SecurityGroupsReadyCondition, serviceName, result) + return result } diff --git a/azure/services/securitygroups/securitygroups_test.go b/azure/services/securitygroups/securitygroups_test.go index 7cfaa874fd8..75ac2b18955 100644 --- a/azure/services/securitygroups/securitygroups_test.go +++ b/azure/services/securitygroups/securitygroups_test.go @@ -18,211 +18,113 @@ package securitygroups import ( "context" - "net/http" "testing" "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2021-02-01/network" - "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/to" "github.com/golang/mock/gomock" . "github.com/onsi/gomega" + "github.com/pkg/errors" infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/async/mock_async" "sigs.k8s.io/cluster-api-provider-azure/azure/services/securitygroups/mock_securitygroups" gomockinternal "sigs.k8s.io/cluster-api-provider-azure/internal/test/matchers/gomock" ) +var ( + fakeNSG = NSGSpec{ + Name: "test-nsg", + Location: "test-location", + SecurityRules: infrav1.SecurityRules{ + { + Name: "allow_ssh", + Description: "Allow SSH", + Priority: 2200, + Protocol: infrav1.SecurityGroupProtocolTCP, + Direction: infrav1.SecurityRuleDirectionInbound, + Source: to.StringPtr("*"), + SourcePorts: to.StringPtr("*"), + Destination: to.StringPtr("*"), + DestinationPorts: to.StringPtr("22"), + }, + { + Name: "other_rule", + Description: "Test Rule", + Priority: 500, + Protocol: infrav1.SecurityGroupProtocolTCP, + Direction: infrav1.SecurityRuleDirectionInbound, + Source: to.StringPtr("*"), + SourcePorts: to.StringPtr("*"), + Destination: to.StringPtr("*"), + DestinationPorts: to.StringPtr("80"), + }, + }, + ResourceGroup: "test-group", + } + fakeNSG2 = NSGSpec{ + Name: "test-nsg-2", + Location: "test-location", + SecurityRules: infrav1.SecurityRules{}, + ResourceGroup: "test-group", + } + errFake = errors.New("this is an error") + notDoneError = azure.NewOperationNotDoneError(&infrav1.Future{}) +) + func TestReconcileSecurityGroups(t *testing.T) { testcases := []struct { - name string - expect func(s *mock_securitygroups.MockNSGScopeMockRecorder, m *mock_securitygroups.MockclientMockRecorder) + name string + expectedError string + expect func(s *mock_securitygroups.MockNSGScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) }{ { - name: "security groups do not exist", - expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, m *mock_securitygroups.MockclientMockRecorder) { - s.NSGSpecs().Return([]azure.NSGSpec{ - { - Name: "nsg-one", - SecurityRules: infrav1.SecurityRules{ - { - Name: "first-rule", - Description: "a test rule", - Protocol: infrav1.SecurityGroupProtocolAll, - Priority: 400, - SourcePorts: to.StringPtr("*"), - DestinationPorts: to.StringPtr("*"), - Source: to.StringPtr("*"), - Destination: to.StringPtr("*"), - Direction: infrav1.SecurityRuleDirectionInbound, - }, - { - Name: "second-rule", - Description: "another test rule", - Protocol: infrav1.SecurityGroupProtocolAll, - Priority: 450, - SourcePorts: to.StringPtr("*"), - DestinationPorts: to.StringPtr("*"), - Source: to.StringPtr("*"), - Destination: to.StringPtr("*"), - Direction: infrav1.SecurityRuleDirectionInbound, - }, - }, - }, - { - Name: "nsg-two", - SecurityRules: infrav1.SecurityRules{}, - }, - }) + name: "create multiple security groups succeeds, should return no error", + expectedError: "", + expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { s.IsVnetManaged().Return(true) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("test-location") - m.Get(gomockinternal.AContext(), "my-rg", "nsg-one").Return(network.SecurityGroup{}, autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 404}, "Not found")) - m.CreateOrUpdate(gomockinternal.AContext(), "my-rg", "nsg-one", gomockinternal.DiffEq(network.SecurityGroup{ - SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ - SecurityRules: &[]network.SecurityRule{ - { - SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ - Description: to.StringPtr("a test rule"), - SourcePortRange: to.StringPtr("*"), - DestinationPortRange: to.StringPtr("*"), - SourceAddressPrefix: to.StringPtr("*"), - DestinationAddressPrefix: to.StringPtr("*"), - Protocol: "*", - Direction: "Inbound", - Access: "Allow", - Priority: to.Int32Ptr(400), - }, - Name: to.StringPtr("first-rule"), - }, - { - SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ - Description: to.StringPtr("another test rule"), - SourcePortRange: to.StringPtr("*"), - DestinationPortRange: to.StringPtr("*"), - SourceAddressPrefix: to.StringPtr("*"), - DestinationAddressPrefix: to.StringPtr("*"), - Protocol: "*", - Direction: "Inbound", - Access: "Allow", - Priority: to.Int32Ptr(450), - }, - Name: to.StringPtr("second-rule"), - }, - }, - }, - Etag: nil, - Location: to.StringPtr("test-location"), - })) - m.Get(gomockinternal.AContext(), "my-rg", "nsg-two").Return(network.SecurityGroup{}, autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 404}, "Not found")) - m.CreateOrUpdate(gomockinternal.AContext(), "my-rg", "nsg-two", gomockinternal.DiffEq(network.SecurityGroup{ - SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ - SecurityRules: &[]network.SecurityRule{}, - }, - Etag: nil, - Location: to.StringPtr("test-location"), - })) + s.NSGSpecs().Return([]azure.ResourceSpecGetter{&fakeNSG, &fakeNSG2}) + r.CreateResource(gomockinternal.AContext(), &fakeNSG, serviceName).Return(nil, nil) + r.CreateResource(gomockinternal.AContext(), &fakeNSG2, serviceName).Return(nil, nil) + s.UpdatePutStatus(infrav1.SecurityGroupsReadyCondition, serviceName, nil) }, - }, { - name: "security group exists", - expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, m *mock_securitygroups.MockclientMockRecorder) { - s.NSGSpecs().Return([]azure.NSGSpec{ - { - Name: "nsg-one", - SecurityRules: infrav1.SecurityRules{ - { - Name: "first-rule", - Description: "a test rule", - Protocol: "*", - Priority: 400, - SourcePorts: to.StringPtr("*"), - DestinationPorts: to.StringPtr("*"), - Source: to.StringPtr("*"), - Destination: to.StringPtr("*"), - Direction: infrav1.SecurityRuleDirectionOutbound, - }, - }, - }, - { - Name: "nsg-two", - SecurityRules: infrav1.SecurityRules{}, - }, - }) - s.IsVnetManaged().AnyTimes().Return(true) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("test-location") - m.Get(gomockinternal.AContext(), "my-rg", "nsg-one").Return(network.SecurityGroup{ - Response: autorest.Response{}, - SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ - SecurityRules: &[]network.SecurityRule{ - { - SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ - Description: to.StringPtr("a different rule"), - Protocol: "*", - SourcePortRange: to.StringPtr("*"), - DestinationPortRange: to.StringPtr("*"), - SourceAddressPrefix: to.StringPtr("*"), - DestinationAddressPrefix: to.StringPtr("*"), - Priority: to.Int32Ptr(300), - Access: network.SecurityRuleAccessDeny, - Direction: network.SecurityRuleDirectionOutbound, - }, - Name: to.StringPtr("foo-rule"), - }, - }, - }, - Etag: to.StringPtr("test-etag"), - ID: to.StringPtr("fake/nsg/id"), - Name: to.StringPtr("nsg-one"), - }, nil) - m.CreateOrUpdate(gomockinternal.AContext(), "my-rg", "nsg-one", gomockinternal.DiffEq(network.SecurityGroup{ - SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ - SecurityRules: &[]network.SecurityRule{ - { - SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ - Description: to.StringPtr("a different rule"), - SourcePortRange: to.StringPtr("*"), - DestinationPortRange: to.StringPtr("*"), - SourceAddressPrefix: to.StringPtr("*"), - DestinationAddressPrefix: to.StringPtr("*"), - Protocol: "*", - Direction: "Outbound", - Access: "Deny", - Priority: to.Int32Ptr(300), - }, - Name: to.StringPtr("foo-rule"), - }, - { - SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ - Description: to.StringPtr("a test rule"), - SourcePortRange: to.StringPtr("*"), - DestinationPortRange: to.StringPtr("*"), - SourceAddressPrefix: to.StringPtr("*"), - DestinationAddressPrefix: to.StringPtr("*"), - Protocol: "*", - Direction: "Outbound", - Access: "Allow", - Priority: to.Int32Ptr(400), - }, - Name: to.StringPtr("first-rule"), - }, - }, - }, - Etag: to.StringPtr("test-etag"), - Location: to.StringPtr("test-location"), - })) - m.Get(gomockinternal.AContext(), "my-rg", "nsg-two").Return(network.SecurityGroup{ - Response: autorest.Response{}, - SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ - SecurityRules: &[]network.SecurityRule{}, - }, - Etag: to.StringPtr("test-etag"), - ID: to.StringPtr("fake/nsg/two/id"), - Name: to.StringPtr("nsg-two"), - }, nil) + }, + { + name: "first security groups create fails, should return error", + expectedError: errFake.Error(), + expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { + s.IsVnetManaged().Return(true) + s.NSGSpecs().Return([]azure.ResourceSpecGetter{&fakeNSG, &fakeNSG2}) + r.CreateResource(gomockinternal.AContext(), &fakeNSG, serviceName).Return(nil, errFake) + r.CreateResource(gomockinternal.AContext(), &fakeNSG2, serviceName).Return(nil, nil) + s.UpdatePutStatus(infrav1.SecurityGroupsReadyCondition, serviceName, errFake) }, - }, { - name: "skipping network security group reconcile in custom VNet mode", - expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, m *mock_securitygroups.MockclientMockRecorder) { + }, + { + name: "first sg create fails, second sg create not done, should return create error", + expectedError: errFake.Error(), + expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { + s.IsVnetManaged().Return(true) + s.NSGSpecs().Return([]azure.ResourceSpecGetter{&fakeNSG, &fakeNSG2}) + r.CreateResource(gomockinternal.AContext(), &fakeNSG, serviceName).Return(nil, errFake) + r.CreateResource(gomockinternal.AContext(), &fakeNSG2, serviceName).Return(nil, notDoneError) + s.UpdatePutStatus(infrav1.SecurityGroupsReadyCondition, serviceName, errFake) + }, + }, + { + name: "security groups create not done, should return not done error", + expectedError: notDoneError.Error(), + expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { + s.IsVnetManaged().Return(true) + s.NSGSpecs().Return([]azure.ResourceSpecGetter{&fakeNSG}) + r.CreateResource(gomockinternal.AContext(), &fakeNSG, serviceName).Return(nil, notDoneError) + s.UpdatePutStatus(infrav1.SecurityGroupsReadyCondition, serviceName, notDoneError) + }, + }, + { + name: "vnet is not managed, should skip reconcile", + expectedError: "", + expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { s.IsVnetManaged().Return(false) }, }, @@ -236,79 +138,79 @@ func TestReconcileSecurityGroups(t *testing.T) { defer mockCtrl.Finish() scopeMock := mock_securitygroups.NewMockNSGScope(mockCtrl) - clientMock := mock_securitygroups.NewMockclient(mockCtrl) + reconcilerMock := mock_async.NewMockReconciler(mockCtrl) - tc.expect(scopeMock.EXPECT(), clientMock.EXPECT()) + tc.expect(scopeMock.EXPECT(), reconcilerMock.EXPECT()) s := &Service{ - Scope: scopeMock, - client: clientMock, + Scope: scopeMock, + Reconciler: reconcilerMock, } - g.Expect(s.Reconcile(context.TODO())).To(Succeed()) + err := s.Reconcile(context.TODO()) + if tc.expectedError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err).To(MatchError(tc.expectedError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } }) } } func TestDeleteSecurityGroups(t *testing.T) { testcases := []struct { - name string - expect func(s *mock_securitygroups.MockNSGScopeMockRecorder, m *mock_securitygroups.MockclientMockRecorder) + name string + expectedError string + expect func(s *mock_securitygroups.MockNSGScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) }{ { - name: "security groups exist", - expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, m *mock_securitygroups.MockclientMockRecorder) { - s.NSGSpecs().Return([]azure.NSGSpec{ - { - Name: "nsg-one", - SecurityRules: infrav1.SecurityRules{ - { - Name: "first-rule", - Description: "a test rule", - Protocol: "all", - Priority: 400, - SourcePorts: to.StringPtr("*"), - DestinationPorts: to.StringPtr("*"), - Source: to.StringPtr("*"), - Destination: to.StringPtr("*"), - Direction: infrav1.SecurityRuleDirectionInbound, - }, - }, - }, - { - Name: "nsg-two", - SecurityRules: infrav1.SecurityRules{}, - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") + name: "delete multiple security groups succeeds, should return no error", + expectedError: "", + expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { + s.IsVnetManaged().Return(true) + s.NSGSpecs().Return([]azure.ResourceSpecGetter{&fakeNSG, &fakeNSG2}) + r.DeleteResource(gomockinternal.AContext(), &fakeNSG, serviceName).Return(nil) + r.DeleteResource(gomockinternal.AContext(), &fakeNSG2, serviceName).Return(nil) + s.UpdateDeleteStatus(infrav1.SecurityGroupsReadyCondition, serviceName, nil) + }, + }, + { + name: "first security groups delete fails, should return an error", + expectedError: errFake.Error(), + expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { + s.IsVnetManaged().Return(true) + s.NSGSpecs().Return([]azure.ResourceSpecGetter{&fakeNSG, &fakeNSG2}) + r.DeleteResource(gomockinternal.AContext(), &fakeNSG, serviceName).Return(errFake) + r.DeleteResource(gomockinternal.AContext(), &fakeNSG2, serviceName).Return(nil) + s.UpdateDeleteStatus(infrav1.SecurityGroupsReadyCondition, serviceName, errFake) + }, + }, + { + name: "first security groups delete fails and second security groups create not done, should return an error", + expectedError: errFake.Error(), + expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { s.IsVnetManaged().Return(true) - m.Delete(gomockinternal.AContext(), "my-rg", "nsg-one") - m.Delete(gomockinternal.AContext(), "my-rg", "nsg-two") + s.NSGSpecs().Return([]azure.ResourceSpecGetter{&fakeNSG, &fakeNSG2}) + r.DeleteResource(gomockinternal.AContext(), &fakeNSG, serviceName).Return(errFake) + r.DeleteResource(gomockinternal.AContext(), &fakeNSG2, serviceName).Return(notDoneError) + s.UpdateDeleteStatus(infrav1.SecurityGroupsReadyCondition, serviceName, errFake) }, }, { - name: "security group already deleted", - expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, m *mock_securitygroups.MockclientMockRecorder) { - s.NSGSpecs().Return([]azure.NSGSpec{ - { - Name: "nsg-one", - SecurityRules: infrav1.SecurityRules{}, - }, - { - Name: "nsg-two", - SecurityRules: infrav1.SecurityRules{}, - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") + name: "security groups delete not done, should return not done error", + expectedError: notDoneError.Error(), + expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { s.IsVnetManaged().Return(true) - m.Delete(gomockinternal.AContext(), "my-rg", "nsg-one"). - Return(autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 404}, "Not found")) - m.Delete(gomockinternal.AContext(), "my-rg", "nsg-two") + s.NSGSpecs().Return([]azure.ResourceSpecGetter{&fakeNSG}) + r.DeleteResource(gomockinternal.AContext(), &fakeNSG, serviceName).Return(notDoneError) + s.UpdateDeleteStatus(infrav1.SecurityGroupsReadyCondition, serviceName, notDoneError) }, }, { - name: "skipping network security group delete in custom VNet mode", - expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, m *mock_securitygroups.MockclientMockRecorder) { + name: "vnet is not managed, should skip delete", + expectedError: "", + expect: func(s *mock_securitygroups.MockNSGScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { s.IsVnetManaged().Return(false) }, }, @@ -322,16 +224,64 @@ func TestDeleteSecurityGroups(t *testing.T) { defer mockCtrl.Finish() scopeMock := mock_securitygroups.NewMockNSGScope(mockCtrl) - clientMock := mock_securitygroups.NewMockclient(mockCtrl) + reconcilerMock := mock_async.NewMockReconciler(mockCtrl) - tc.expect(scopeMock.EXPECT(), clientMock.EXPECT()) + tc.expect(scopeMock.EXPECT(), reconcilerMock.EXPECT()) s := &Service{ - Scope: scopeMock, - client: clientMock, + Scope: scopeMock, + Reconciler: reconcilerMock, } - g.Expect(s.Delete(context.TODO())).To(Succeed()) + err := s.Delete(context.TODO()) + if tc.expectedError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err).To(MatchError(tc.expectedError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } }) } } + +var ( + ruleA = network.SecurityRule{ + Name: to.StringPtr("A"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Description: to.StringPtr("this is rule A"), + Protocol: network.SecurityRuleProtocolTCP, + DestinationPortRange: to.StringPtr("*"), + SourcePortRange: to.StringPtr("*"), + DestinationAddressPrefix: to.StringPtr("*"), + SourceAddressPrefix: to.StringPtr("*"), + Priority: to.Int32Ptr(100), + Direction: network.SecurityRuleDirectionInbound, + }, + } + ruleB = network.SecurityRule{ + Name: to.StringPtr("B"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Description: to.StringPtr("this is rule B"), + Protocol: network.SecurityRuleProtocolTCP, + DestinationPortRange: to.StringPtr("*"), + SourcePortRange: to.StringPtr("*"), + DestinationAddressPrefix: to.StringPtr("*"), + SourceAddressPrefix: to.StringPtr("*"), + Priority: to.Int32Ptr(100), + Direction: network.SecurityRuleDirectionOutbound, + }, + } + ruleBModified = network.SecurityRule{ + Name: to.StringPtr("B"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Description: to.StringPtr("this is rule B"), + Protocol: network.SecurityRuleProtocolTCP, + DestinationPortRange: to.StringPtr("80"), + SourcePortRange: to.StringPtr("*"), + DestinationAddressPrefix: to.StringPtr("*"), + SourceAddressPrefix: to.StringPtr("*"), + Priority: to.Int32Ptr(100), + Direction: network.SecurityRuleDirectionOutbound, + }, + } +) diff --git a/azure/services/securitygroups/spec.go b/azure/services/securitygroups/spec.go new file mode 100644 index 00000000000..857f2f1605f --- /dev/null +++ b/azure/services/securitygroups/spec.go @@ -0,0 +1,117 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package securitygroups + +import ( + "strings" + + "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2021-02-01/network" + "github.com/Azure/go-autorest/autorest/to" + "github.com/pkg/errors" + infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" + "sigs.k8s.io/cluster-api-provider-azure/azure/converters" +) + +// NSGSpec defines the specification for a security group. +type NSGSpec struct { + Name string + SecurityRules infrav1.SecurityRules + Location string + ResourceGroup string +} + +// ResourceName returns the name of the security group. +func (s *NSGSpec) ResourceName() string { + return s.Name +} + +// ResourceGroupName returns the name of the resource group. +func (s *NSGSpec) ResourceGroupName() string { + return s.ResourceGroup +} + +// OwnerResourceName is a no-op for security groups. +func (s *NSGSpec) OwnerResourceName() string { + return "" +} + +// Parameters returns the parameters for the security group. +func (s *NSGSpec) Parameters(existing interface{}) (interface{}, error) { + securityRules := make([]network.SecurityRule, 0) + var etag *string + + if existing != nil { + existingNSG, ok := existing.(network.SecurityGroup) + if !ok { + return nil, errors.Errorf("%T is not a network.SecurityGroup", existing) + } + // security group already exists + // We append the existing NSG etag to the header to ensure we only apply the updates if the NSG has not been modified. + etag = existingNSG.Etag + // Check if the expected rules are present + update := false + securityRules = *existingNSG.SecurityRules + for _, rule := range s.SecurityRules { + sdkRule := converters.SecurityRuleToSDK(rule) + if !ruleExists(securityRules, sdkRule) { + update = true + securityRules = append(securityRules, sdkRule) + } + } + if !update { + // Skip update for NSG as the required default rules are present + return nil, nil + } + } else { + // new security group + for _, rule := range s.SecurityRules { + securityRules = append(securityRules, converters.SecurityRuleToSDK(rule)) + } + } + + return network.SecurityGroup{ + Location: to.StringPtr(s.Location), + SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ + SecurityRules: &securityRules, + }, + Etag: etag, + }, nil +} + +// TODO: review this logic and make sure it is what we want. It seems incorrect to skip rules that don't have a certain protocol, etc. +func ruleExists(rules []network.SecurityRule, rule network.SecurityRule) bool { + for _, existingRule := range rules { + if !strings.EqualFold(to.String(existingRule.Name), to.String(rule.Name)) { + continue + } + if !strings.EqualFold(to.String(existingRule.DestinationPortRange), to.String(rule.DestinationPortRange)) { + continue + } + if existingRule.Protocol != network.SecurityRuleProtocolTCP && + existingRule.Access != network.SecurityRuleAccessAllow && + existingRule.Direction != network.SecurityRuleDirectionInbound { + continue + } + if !strings.EqualFold(to.String(existingRule.SourcePortRange), "*") && + !strings.EqualFold(to.String(existingRule.SourceAddressPrefix), "*") && + !strings.EqualFold(to.String(existingRule.DestinationAddressPrefix), "*") { + continue + } + return true + } + return false +} diff --git a/azure/services/securitygroups/spec_test.go b/azure/services/securitygroups/spec_test.go new file mode 100644 index 00000000000..a7bb7b8cc43 --- /dev/null +++ b/azure/services/securitygroups/spec_test.go @@ -0,0 +1,214 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package securitygroups + +import ( + "testing" + + "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2021-02-01/network" + "github.com/Azure/go-autorest/autorest/to" + . "github.com/onsi/gomega" + infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" + "sigs.k8s.io/cluster-api-provider-azure/azure/converters" +) + +var ( + sshRule = infrav1.SecurityRule{ + Name: "allow_ssh", + Description: "Allow SSH", + Priority: 2200, + Protocol: infrav1.SecurityGroupProtocolTCP, + Direction: infrav1.SecurityRuleDirectionInbound, + Source: to.StringPtr("*"), + SourcePorts: to.StringPtr("*"), + Destination: to.StringPtr("*"), + DestinationPorts: to.StringPtr("22"), + } + otherRule = infrav1.SecurityRule{ + Name: "other_rule", + Description: "Test Rule", + Priority: 500, + Protocol: infrav1.SecurityGroupProtocolTCP, + Direction: infrav1.SecurityRuleDirectionInbound, + Source: to.StringPtr("*"), + SourcePorts: to.StringPtr("*"), + Destination: to.StringPtr("*"), + DestinationPorts: to.StringPtr("80"), + } + customRule = infrav1.SecurityRule{ + Name: "custom_rule", + Description: "Test Rule", + Priority: 501, + Protocol: infrav1.SecurityGroupProtocolTCP, + Direction: infrav1.SecurityRuleDirectionOutbound, + Source: to.StringPtr("*"), + SourcePorts: to.StringPtr("*"), + Destination: to.StringPtr("*"), + DestinationPorts: to.StringPtr("80"), + } +) + +func TestParameters(t *testing.T) { + testcases := []struct { + name string + spec *NSGSpec + existing interface{} + expect func(g *WithT, result interface{}) + expectedError string + }{ + { + name: "NSG already exists with all rules present", + spec: &NSGSpec{ + Name: "test-nsg", + Location: "test-location", + SecurityRules: infrav1.SecurityRules{ + sshRule, + otherRule, + }, + ResourceGroup: "test-group", + }, + existing: network.SecurityGroup{ + Name: to.StringPtr("test-nsg"), + SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ + SecurityRules: &[]network.SecurityRule{ + converters.SecurityRuleToSDK(sshRule), + converters.SecurityRuleToSDK(otherRule), + }, + }, + }, + expect: func(g *WithT, result interface{}) { + g.Expect(result).To(BeNil()) + }, + }, + { + name: "NSG already exists but missing a rule", + spec: &NSGSpec{ + Name: "test-nsg", + Location: "test-location", + SecurityRules: infrav1.SecurityRules{ + sshRule, + otherRule, + }, + ResourceGroup: "test-group", + }, + existing: network.SecurityGroup{ + Name: to.StringPtr("test-nsg"), + Location: to.StringPtr("test-location"), + Etag: to.StringPtr("fake-etag"), + SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ + SecurityRules: &[]network.SecurityRule{ + converters.SecurityRuleToSDK(sshRule), + converters.SecurityRuleToSDK(customRule), + }, + }, + }, + expect: func(g *WithT, result interface{}) { + g.Expect(result).To(BeAssignableToTypeOf(network.SecurityGroup{})) + g.Expect(result).To(Equal(network.SecurityGroup{ + Location: to.StringPtr("test-location"), + Etag: to.StringPtr("fake-etag"), + SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ + SecurityRules: &[]network.SecurityRule{ + converters.SecurityRuleToSDK(sshRule), + converters.SecurityRuleToSDK(customRule), + converters.SecurityRuleToSDK(otherRule), + }, + }, + })) + }, + }, + { + name: "NSG does not exist", + spec: &NSGSpec{ + Name: "test-nsg", + Location: "test-location", + SecurityRules: infrav1.SecurityRules{ + sshRule, + otherRule, + }, + ResourceGroup: "test-group", + }, + existing: nil, + expect: func(g *WithT, result interface{}) { + g.Expect(result).To(BeAssignableToTypeOf(network.SecurityGroup{})) + g.Expect(result).To(Equal(network.SecurityGroup{ + SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ + SecurityRules: &[]network.SecurityRule{ + converters.SecurityRuleToSDK(sshRule), + converters.SecurityRuleToSDK(otherRule), + }, + }, + Location: to.StringPtr("test-location"), + })) + }, + }, + } + + for _, tc := range testcases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + t.Parallel() + + result, err := tc.spec.Parameters(tc.existing) + if tc.expectedError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err).To(MatchError(tc.expectedError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } + tc.expect(g, result) + }) + } +} + +func TestRuleExists(t *testing.T) { + testcases := []struct { + name string + rules []network.SecurityRule + rule network.SecurityRule + expected bool + }{ + { + name: "rule doesn't exitst", + rules: []network.SecurityRule{ruleA}, + rule: ruleB, + expected: false, + }, + { + name: "rule exists", + rules: []network.SecurityRule{ruleA, ruleB}, + rule: ruleB, + expected: true, + }, + { + name: "rule exists but has been modified", + rules: []network.SecurityRule{ruleA, ruleB}, + rule: ruleBModified, + expected: false, + }, + } + for _, tc := range testcases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + t.Parallel() + result := ruleExists(tc.rules, tc.rule) + g.Expect(result).To(Equal(tc.expected)) + }) + } +} diff --git a/azure/types.go b/azure/types.go index b44ba210f28..6146eee2e39 100644 --- a/azure/types.go +++ b/azure/types.go @@ -50,12 +50,6 @@ const ( VirtualMachineScaleSet = "VirtualMachineScaleSet" ) -// NSGSpec defines the specification for a Security Group. -type NSGSpec struct { - Name string - SecurityRules infrav1.SecurityRules -} - // ScaleSetSpec defines the specification for a Scale Set. type ScaleSetSpec struct { Name string