diff --git a/cmd/gce-pd-csi-driver/main.go b/cmd/gce-pd-csi-driver/main.go index 9f3f1554d..b1ac5757f 100644 --- a/cmd/gce-pd-csi-driver/main.go +++ b/cmd/gce-pd-csi-driver/main.go @@ -98,6 +98,8 @@ var ( diskTopology = flag.Bool("disk-topology", false, "If set to true, the driver will add a disk-type.gke.io/[disk-type] topology label when the StorageClass has the use-allowed-disk-topology parameter set to true. That topology label is included in the Topologies returned in CreateVolumeResponse. This flag is disabled by default.") + dynamicVolumes = flag.Bool("dynamic-volumes", false, "If set to true, the CSI driver will automatically select a compatible disk type based on the presence of the dynamic-volume parameter and disk types defined in the StorageClass. Disabled by default.") + diskCacheSyncPeriod = flag.Duration("disk-cache-sync-period", 10*time.Minute, "Period for the disk cache to check the /dev/disk/by-id/ directory and evaluate the symlinks") enableDiskSizeValidation = flag.Bool("enable-disk-size-validation", false, "If set to true, the driver will validate that the requested disk size is matches the physical disk size. 
This flag is disabled by default.") @@ -255,6 +257,7 @@ func handle() { args := &driver.GCEControllerServerArgs{ EnableDiskTopology: *diskTopology, EnableDiskSizeValidation: *enableDiskSizeValidation, + EnableDynamicVolumes: *dynamicVolumes, } controllerServer = driver.NewControllerServer(gceDriver, cloudProvider, initialBackoffDuration, maxBackoffDuration, fallbackRequisiteZones, *enableStoragePoolsFlag, *enableDataCacheFlag, multiZoneVolumeHandleConfig, listVolumesConfig, provisionableDisksConfig, *enableHdHAFlag, args) @@ -297,6 +300,8 @@ func handle() { SysfsPath: "/sys", MetricsManager: metricsManager, DeviceCache: deviceCache, + EnableDynamicVolumes: *dynamicVolumes, + NodeName: *nodeName, } nodeServer = driver.NewNodeServer(gceDriver, mounter, deviceUtils, meta, statter, nsArgs) diff --git a/pkg/common/utils.go b/pkg/common/utils.go index b7448cd51..e86373881 100644 --- a/pkg/common/utils.go +++ b/pkg/common/utils.go @@ -799,6 +799,11 @@ func MapNumber(vCPUs int64, limitMap []MachineHyperdiskLimit) int64 { return 15 } +// HasDiskTypeLabelKeyPrefix checks if the label key starts with the DiskTypeKeyPrefix. 
+func HasDiskTypeLabelKeyPrefix(labelKey string) bool { + return strings.HasPrefix(labelKey, DiskTypeKeyPrefix) +} + func DiskTypeLabelKey(diskType string) string { return fmt.Sprintf("%s/%s", DiskTypeKeyPrefix, diskType) } diff --git a/pkg/gce-pd-csi-driver/controller.go b/pkg/gce-pd-csi-driver/controller.go index e7b04b012..9855098b5 100644 --- a/pkg/gce-pd-csi-driver/controller.go +++ b/pkg/gce-pd-csi-driver/controller.go @@ -129,6 +129,7 @@ type GCEControllerServer struct { type GCEControllerServerArgs struct { EnableDiskTopology bool EnableDiskSizeValidation bool + EnableDynamicVolumes bool } type MultiZoneVolumeHandleConfig struct { diff --git a/pkg/gce-pd-csi-driver/gce-pd-driver.go b/pkg/gce-pd-csi-driver/gce-pd-driver.go index c1740df80..a0eefe960 100644 --- a/pkg/gce-pd-csi-driver/gce-pd-driver.go +++ b/pkg/gce-pd-csi-driver/gce-pd-driver.go @@ -160,6 +160,8 @@ func NewNodeServer(gceDriver *GCEDriver, mounter *mount.SafeFormatAndMount, devi SysfsPath: args.SysfsPath, metricsManager: args.MetricsManager, DeviceCache: args.DeviceCache, + EnableDynamicVolumes: args.EnableDynamicVolumes, + nodeName: args.NodeName, } } diff --git a/pkg/gce-pd-csi-driver/node.go b/pkg/gce-pd-csi-driver/node.go index ce8e6afbb..4460405e8 100644 --- a/pkg/gce-pd-csi-driver/node.go +++ b/pkg/gce-pd-csi-driver/node.go @@ -32,9 +32,9 @@ import ( csi "github.com/container-storage-interface/spec/lib/go/csi" + corev1 "k8s.io/api/core/v1" "k8s.io/klog/v2" "k8s.io/mount-utils" - "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/common" "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/deviceutils" metadataservice "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/gce-cloud-provider/metadata" @@ -54,6 +54,7 @@ type GCENodeServer struct { EnableDataCache bool DataCacheEnabledNodePool bool SysfsPath string + nodeName string // A map storing all volumes with ongoing operations so that additional operations // for that same volume (as defined by VolumeID) return an Aborted 
error @@ -82,6 +83,8 @@ type GCENodeServer struct { metricsManager *metrics.MetricsManager // A cache of the device paths for the volumes that are attached to the node. DeviceCache *linkcache.DeviceCache + + EnableDynamicVolumes bool } type NodeServerArgs struct { @@ -98,8 +101,12 @@ type NodeServerArgs struct { // SysfsPath defaults to "/sys", except if it's a unit test. SysfsPath string + NodeName string + MetricsManager *metrics.MetricsManager DeviceCache *linkcache.DeviceCache + + EnableDynamicVolumes bool } var _ csi.NodeServer = &GCENodeServer{} @@ -166,6 +173,15 @@ func (ns *GCENodeServer) WithSerializedFormatAndMount(timeout time.Duration, max return ns } +// GetNodeName returns the node name, prioritizing the override value (from Downward API) +// over the metadata service if available. +func (ns *GCENodeServer) GetNodeName() string { + if ns.nodeName != "" { + return ns.nodeName + } + return ns.MetadataService.GetName() +} + func (ns *GCENodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { // Validate Arguments targetPath := req.GetTargetPath() @@ -686,9 +702,25 @@ func (ns *GCENodeServer) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRe Segments: map[string]string{common.TopologyKeyZone: ns.MetadataService.GetZone()}, } + node, err := k8sclient.GetNodeWithRetry(ctx, ns.GetNodeName()) + if err != nil { + klog.Errorf("Failed to get node %s: %v. 
The error is ignored so that the driver can register", ns.GetNodeName(), err.Error()) + } + + if ns.EnableDynamicVolumes { + labels, err := ns.getDiskTypeLabels(node) + if err != nil { + klog.Errorf("Failed to fetch disk type topology labels: %v", err) + } + + for k, v := range labels { + top.Segments[k] = v + } + } + nodeID := common.CreateNodeID(ns.MetadataService.GetProject(), ns.MetadataService.GetZone(), ns.MetadataService.GetName()) - volumeLimits, err := ns.GetVolumeLimits(ctx) + volumeLimits, err := ns.getVolumeLimits(ctx, node) if err != nil { klog.Errorf("GetVolumeLimits failed: %v. The error is ignored so that the driver can register", err.Error()) // No error should be returned from NodeGetInfo, otherwise the driver will not register @@ -850,7 +882,7 @@ func (ns *GCENodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpa }, nil } -func (ns *GCENodeServer) GetVolumeLimits(ctx context.Context) (int64, error) { +func (ns *GCENodeServer) getVolumeLimits(ctx context.Context, node *corev1.Node) (int64, error) { // Machine-type format: n1-type-CPUS or custom-CPUS-RAM or f1/g1-type machineType := ns.MetadataService.GetMachineType() @@ -862,7 +894,7 @@ func (ns *GCENodeServer) GetVolumeLimits(ctx context.Context) (int64, error) { } // Get attach limit override from label - attachLimitOverride, err := GetAttachLimitsOverrideFromNodeLabel(ctx, ns.MetadataService.GetName()) + attachLimitOverride, err := getAttachLimitsOverrideFromNodeLabel(node) if err == nil && attachLimitOverride > 0 && attachLimitOverride < 128 { return attachLimitOverride, nil } else { @@ -924,10 +956,10 @@ func (ns *GCENodeServer) GetVolumeLimits(ctx context.Context) (int64, error) { return volumeLimitBig, nil } -func GetAttachLimitsOverrideFromNodeLabel(ctx context.Context, nodeName string) (int64, error) { - node, err := k8sclient.GetNodeWithRetry(ctx, nodeName) - if err != nil { - return 0, err +func getAttachLimitsOverrideFromNodeLabel(node *corev1.Node) (int64, error) { + // 
If the node is nil there are no labels to read, so return an error instead of an override + if node == nil { + return 0, fmt.Errorf("node is nil") } if val, found := node.GetLabels()[fmt.Sprintf(common.NodeRestrictionLabelPrefix, common.AttachLimitOverrideLabel)]; found { attachLimitOverrideForNode, err := strconv.ParseInt(val, 10, 64) @@ -939,3 +971,17 @@ func GetAttachLimitsOverrideFromNodeLabel(ctx context.Context, nodeName string) } return 0, nil } + +func (ns *GCENodeServer) getDiskTypeLabels(node *corev1.Node) (map[string]string, error) { + if node == nil { + return nil, fmt.Errorf("node is nil") + } + lbls := make(map[string]string) + for k, v := range node.GetLabels() { + if common.HasDiskTypeLabelKeyPrefix(k) { + lbls[k] = v + } + } + + return lbls, nil +} diff --git a/pkg/gce-pd-csi-driver/node_test.go b/pkg/gce-pd-csi-driver/node_test.go index e7f4d927e..13e065044 100644 --- a/pkg/gce-pd-csi-driver/node_test.go +++ b/pkg/gce-pd-csi-driver/node_test.go @@ -20,6 +20,7 @@ import ( "fmt" "os" "path/filepath" + "sort" "strings" "testing" "time" @@ -31,6 +32,9 @@ import ( "github.com/google/go-cmp/cmp" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/testing/protocmp" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/mount-utils" "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/common" "sigs.k8s.io/gcp-compute-persistent-disk-csi-driver/pkg/deviceutils" @@ -236,11 +240,11 @@ func TestNodeGetVolumeStats(t *testing.T) { func TestNodeGetVolumeLimits(t *testing.T) { gceDriver := getTestGCEDriver(t) ns := gceDriver.ns - req := &csi.NodeGetInfoRequest{} testCases := []struct { name string machineType string + node *corev1.Node expVolumeLimit int64 expectError bool }{ @@ -430,6 +434,43 @@ func TestNodeGetVolumeLimits(t *testing.T) { name: "a4x-medgpu-nolssd", // does not exist, testing edge case machineType: "a4x-medgpu-nolssd", expVolumeLimit: volumeLimitBig, + expectError: true, + }, + { + name: "attach 
limit override", + machineType: "n1-standard-1", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + fmt.Sprintf(common.NodeRestrictionLabelPrefix, common.AttachLimitOverrideLabel): "63", + }, + }, + }, + expVolumeLimit: 63, + }, + { + name: "invalid attach limit override", + machineType: "n1-standard-1", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + fmt.Sprintf(common.NodeRestrictionLabelPrefix, common.AttachLimitOverrideLabel): "invalid", + }, + }, + }, + expVolumeLimit: volumeLimitBig, + }, + { + name: "attach limit override out of bounds", + machineType: "n1-standard-1", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + fmt.Sprintf(common.NodeRestrictionLabelPrefix, common.AttachLimitOverrideLabel): "9999", + }, + }, + }, + expVolumeLimit: volumeLimitBig, }, } @@ -437,18 +478,15 @@ func TestNodeGetVolumeLimits(t *testing.T) { t.Logf("Test case: %s", tc.name) metadataservice.SetMachineType(tc.machineType) - res, err := ns.NodeGetInfo(context.Background(), req) + volumeLimit, err := ns.getVolumeLimits(context.Background(), tc.node) if err != nil && !tc.expectError { t.Fatalf("Failed to get node info: %v", err) } - volumeLimit := res.GetMaxVolumesPerNode() if volumeLimit != tc.expVolumeLimit { t.Fatalf("Expected volume limit: %v, got %v, for machine-type: %v", tc.expVolumeLimit, volumeLimit, tc.machineType) } - - t.Logf("Get node info: %v", res) } } @@ -1679,3 +1717,177 @@ func TestBlockingFormatAndMount(t *testing.T) { gceDriver := getTestBlockingFormatAndMountGCEDriver(t, readyToExecute) runBlockingFormatAndMount(t, gceDriver, readyToExecute) } + +func TestGetDiskTypeLabels(t *testing.T) { + const ( + nodeName = "test-node" + diskA = common.DiskTypeKeyPrefix + "/disk-a" + diskB = common.DiskTypeKeyPrefix + "/disk-b" + ) + + testCases := []struct { + desc string + node *corev1.Node + want []string + wantError bool + }{ + { + desc: "no topology labels", + 
node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + Labels: map[string]string{"foo": "bar"}, + }, + }, + want: nil, + }, + { + desc: "multiple topology labels", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + Labels: map[string]string{ + diskA: "true", + diskB: "true", + }, + }, + }, + want: []string{diskA, diskB}, + }, + { + desc: "node not found", + node: nil, + wantError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + gceDriver := getTestGCEDriverWithCustomMounter(t, mountmanager.NewFakeSafeMounter(), &NodeServerArgs{}) + ns := gceDriver.ns + + lbls, err := ns.getDiskTypeLabels(tc.node) + if tc.wantError { + if err == nil { + t.Fatalf("expected error but got none") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var got []string + for key := range lbls { + got = append(got, key) + } + sort.Strings(got) + + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("unexpected topology labels (-want +got):\n%s", diff) + } + }) + } +} + +func TestNodeGetInfo(t *testing.T) { + const ( + machineType = "n1-standard-4" + zone = "us-central1-b" + name = "test-node" + ) + tests := []struct { + desc string + nodeNameOverride string + want *csi.NodeGetInfoResponse + }{ + { + desc: "success", + want: &csi.NodeGetInfoResponse{ + NodeId: fmt.Sprintf("projects/test-project/zones/%s/instances/%s", zone, name), + MaxVolumesPerNode: volumeLimitBig, + AccessibleTopology: &csi.Topology{ + Segments: map[string]string{ + common.TopologyKeyZone: zone, + }, + }, + }, + }, + { + desc: "success with nodeNameOverride", + nodeNameOverride: "override-node-name", + want: &csi.NodeGetInfoResponse{ + NodeId: fmt.Sprintf("projects/test-project/zones/%s/instances/%s", zone, name), + MaxVolumesPerNode: volumeLimitBig, + AccessibleTopology: &csi.Topology{ + Segments: map[string]string{ + common.TopologyKeyZone: zone, + }, + }, + }, + }, + } + for _, tc := range tests { + 
t.Run(tc.desc, func(t *testing.T) { + var ns *GCENodeServer + if tc.nodeNameOverride != "" { + args := &NodeServerArgs{ + DeviceCache: linkcache.NewTestDeviceCache(1*time.Minute, linkcache.NewTestNodeWithVolumes([]string{defaultVolumeID})), + NodeName: tc.nodeNameOverride, + } + gceDriver := getTestGCEDriverWithCustomMounter(t, mountmanager.NewFakeSafeMounter(), args) + ns = gceDriver.ns + } else { + gceDriver := getTestGCEDriver(t) + ns = gceDriver.ns + } + req := &csi.NodeGetInfoRequest{} + metadataservice.SetMachineType(machineType) + metadataservice.SetZone(zone) + metadataservice.SetName(name) + + got, err := ns.NodeGetInfo(context.Background(), req) + if err != nil { + t.Fatalf("Failed to get node info: %v", err) + } + + if diff := cmp.Diff(tc.want, got, protocmp.Transform()); diff != "" { + t.Fatalf("NodeGetInfo() returned unexpected diff (-want +got):\n%s", diff) + } + }) + } +} + +func TestGetNodeName(t *testing.T) { + tests := []struct { + desc string + nodeNameOverride string + metadataName string + want string + }{ + { + desc: "returns metadata service name when override is empty", + metadataName: "metadata-node", + want: "metadata-node", + }, + { + desc: "returns override when set", + nodeNameOverride: "override-node", + metadataName: "metadata-node", + want: "override-node", + }, + } + for _, tc := range tests { + t.Run(tc.desc, func(t *testing.T) { + args := &NodeServerArgs{NodeName: tc.nodeNameOverride} + gceDriver := getTestGCEDriverWithCustomMounter(t, mountmanager.NewFakeSafeMounter(), args) + ns := gceDriver.ns + metadataservice.SetName(tc.metadataName) + + got := ns.GetNodeName() + if got != tc.want { + t.Errorf("GetNodeName() = %q, want %q", got, tc.want) + } + }) + } +}