From 88153717566c7d7ab88bd26e351550568f61b11c Mon Sep 17 00:00:00 2001 From: Seth Jennings Date: Wed, 28 Sep 2022 12:34:15 -0500 Subject: [PATCH] allow certain EndpointAccess transitions --- .../hostedcontrolplane_controller.go | 29 +++- .../hostedcluster/hostedcluster_webhook.go | 20 +++ .../hostedcluster_webhook_test.go | 130 ++++++++++++++++++ 3 files changed, 177 insertions(+), 2 deletions(-) diff --git a/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller.go b/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller.go index f1155d49706..ea842f6187a 100644 --- a/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller.go +++ b/control-plane-operator/controllers/hostedcontrolplane/hostedcontrolplane_controller.go @@ -860,8 +860,8 @@ func (r *HostedControlPlaneReconciler) reconcileAPIServerService(ctx context.Con } if serviceStrategy.Type == hyperv1.Route { + externalRoute := manifests.KubeAPIServerExternalRoute(hcp.Namespace) if util.IsPublicHCP(hcp) { - externalRoute := manifests.KubeAPIServerExternalRoute(hcp.Namespace) if _, err := createOrUpdate(ctx, r.Client, externalRoute, func() error { kas.ReconcileRoute(externalRoute, serviceStrategy.Route.Hostname) if externalRoute.Annotations == nil { @@ -872,6 +872,18 @@ func (r *HostedControlPlaneReconciler) reconcileAPIServerService(ctx context.Con }); err != nil { return fmt.Errorf("failed to reconcile apiserver external route: %w", err) } + } else { + // Remove the external route if it exists + err := r.Get(ctx, client.ObjectKeyFromObject(externalRoute), externalRoute) + if err != nil { + if !apierrors.IsNotFound(err) { + return fmt.Errorf("failed to check whether kube-apiserver external route exists: %w", err) + } + } else { + if err := r.Delete(ctx, externalRoute); err != nil { + return fmt.Errorf("failed to delete kube-apiserver external route: %w", err) + } + } } // We do not need to enumerate all possible addresses, because we use the KAS as default backend through a custom @@ -2569,6 +2581,7 @@ func (r *HostedControlPlaneReconciler) reconcileRouter(ctx context.Context, hcp } var canonicalHostname string + pubSvc := manifests.RouterPublicService(hcp.Namespace) if util.IsPrivateHCP(hcp) { svc := manifests.PrivateRouterService(hcp.Namespace) if _, err := createOrUpdate(ctx, r.Client, svc, func() error { @@ -2579,10 +2592,22 @@ func (r *HostedControlPlaneReconciler) reconcileRouter(ctx context.Context, hcp if (!util.IsPublicHCP(hcp) || !exposeAPIThroughRouter) && len(svc.Status.LoadBalancer.Ingress) > 0 { canonicalHostname = svc.Status.LoadBalancer.Ingress[0].Hostname } + if !util.IsPublicHCP(hcp) { + // Remove the public router Service if it exists + err := r.Get(ctx, client.ObjectKeyFromObject(pubSvc), pubSvc) + if err != nil { + if !apierrors.IsNotFound(err) { + return fmt.Errorf("failed to check whether public router service exists: %w", err) + } + } else { + if err := r.Delete(ctx, pubSvc); err != nil { + return fmt.Errorf("failed to delete public router service: %w", err) + } + } + } } if util.IsPublicHCP(hcp) && exposeAPIThroughRouter { - pubSvc := manifests.RouterPublicService(hcp.Namespace) if _, err := createOrUpdate(ctx, r.Client, pubSvc, func() error { return ingress.ReconcileRouterService(pubSvc, config.OwnerRefFrom(hcp), util.APIPortWithDefault(hcp, config.DefaultAPIServerPort), false) }); err != nil { diff --git a/hypershift-operator/controllers/hostedcluster/hostedcluster_webhook.go b/hypershift-operator/controllers/hostedcluster/hostedcluster_webhook.go index d1ad7294b5f..3cf27f41778 100644 --- a/hypershift-operator/controllers/hostedcluster/hostedcluster_webhook.go +++ b/hypershift-operator/controllers/hostedcluster/hostedcluster_webhook.go @@ -227,6 +227,22 @@ func validateStructDeepEqual(x reflect.Value, y reflect.Value, path *field.Path, return errs } +func validateEndpointAccess(new *hyperv1.PlatformSpec, old *hyperv1.PlatformSpec) error { + if old.Type != hyperv1.AWSPlatform || new.Type != hyperv1.AWSPlatform || old.AWS == nil || new.AWS == nil { + return nil + } + if old.AWS.EndpointAccess == new.AWS.EndpointAccess { + return nil + } + if old.AWS.EndpointAccess == hyperv1.Public || new.AWS.EndpointAccess == hyperv1.Public { + return fmt.Errorf("transitioning from EndpointAccess %s to %s is not allowed", old.AWS.EndpointAccess, new.AWS.EndpointAccess) + } + // Clear EndpointAccess for further validation + old.AWS.EndpointAccess = "" + new.AWS.EndpointAccess = "" + return nil +} + // validateStructEqual uses introspection to walk through the fields of a struct and check // for differences. Any differences are flagged as an invalid change to an immutable field. func validateStructEqual(x any, y any, path *field.Path) field.ErrorList { @@ -269,6 +285,10 @@ func validateHostedClusterUpdate(new *hyperv1.HostedCluster, old *hyperv1.Hosted old.Spec.Networking.APIServer.Port = new.Spec.Networking.APIServer.Port } + if err := validateEndpointAccess(&new.Spec.Platform, &old.Spec.Platform); err != nil { + return err + } + errs := validateStructEqual(new.Spec, old.Spec, field.NewPath("HostedCluster.spec")) return errs.ToAggregate() diff --git a/hypershift-operator/controllers/hostedcluster/hostedcluster_webhook_test.go b/hypershift-operator/controllers/hostedcluster/hostedcluster_webhook_test.go index 66693ea4d42..36135ef0d5e 100644 --- a/hypershift-operator/controllers/hostedcluster/hostedcluster_webhook_test.go +++ b/hypershift-operator/controllers/hostedcluster/hostedcluster_webhook_test.go @@ -452,3 +452,133 @@ func TestValidateHostedClusterCreate(t *testing.T) { }) } } + +func Test_validateEndpointAccess(t *testing.T) { + type args struct { + new *hyperv1.PlatformSpec + old *hyperv1.PlatformSpec + } + tests := []struct { + name string + args args + wantErr bool + }{ + + { + name: "non-AWS should pass", + args: args{ + new: &hyperv1.PlatformSpec{ + Type: hyperv1.AgentPlatform, + }, + old: &hyperv1.PlatformSpec{ + Type: hyperv1.AgentPlatform, + }, + }, + }, + { + name: "nil AWS platform passes", + args: args{ + new: &hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + }, + old: &hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + }, + }, + }, + { + name: "no EndpointAccess changes passes", + args: args{ + new: &hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + EndpointAccess: hyperv1.Public, + }, + }, + old: &hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + EndpointAccess: hyperv1.Public, + }, + }, + }, + }, + { + name: "Private to PublicAndPrivate passes", + args: args{ + new: &hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + EndpointAccess: hyperv1.PublicAndPrivate, + }, + }, + old: &hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + EndpointAccess: hyperv1.Private, + }, + }, + }, + }, + { + name: "PublicAndPrivate to Private passes", + args: args{ + new: &hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + EndpointAccess: hyperv1.Private, + }, + }, + old: &hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + EndpointAccess: hyperv1.PublicAndPrivate, + }, + }, + }, + }, + { + name: "Public to Private fails", + args: args{ + new: &hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + EndpointAccess: hyperv1.Private, + }, + }, + old: &hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + EndpointAccess: hyperv1.Public, + }, + }, + }, + wantErr: true, + }, + { + name: "Public to PublicAndPrivate fails", + args: args{ + new: &hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + EndpointAccess: hyperv1.PublicAndPrivate, + }, + }, + old: &hyperv1.PlatformSpec{ + Type: hyperv1.AWSPlatform, + AWS: &hyperv1.AWSPlatformSpec{ + EndpointAccess: hyperv1.Public, + }, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := validateEndpointAccess(tt.args.new, tt.args.old); (err != nil) != tt.wantErr { + t.Errorf("validateEndpointAccess() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}