diff --git a/cmd/gce-pd-csi-driver/main.go b/cmd/gce-pd-csi-driver/main.go
index 726a69506..f470947c9 100644
--- a/cmd/gce-pd-csi-driver/main.go
+++ b/cmd/gce-pd-csi-driver/main.go
@@ -303,6 +303,7 @@ func handle() {
 			MetricsManager:           metricsManager,
 			DeviceCache:              deviceCache,
 			EnableDynamicVolumes:     *dynamicVolumes,
+			NodeName:                 *nodeName,
 		}
 		nodeServer = driver.NewNodeServer(gceDriver, mounter, deviceUtils, meta, statter, nsArgs)
 
diff --git a/pkg/gce-pd-csi-driver/gce-pd-driver.go b/pkg/gce-pd-csi-driver/gce-pd-driver.go
index 3f7988f71..cf7d2a0f4 100644
--- a/pkg/gce-pd-csi-driver/gce-pd-driver.go
+++ b/pkg/gce-pd-csi-driver/gce-pd-driver.go
@@ -161,6 +161,7 @@ func NewNodeServer(gceDriver *GCEDriver, mounter *mount.SafeFormatAndMount, devi
 		metricsManager:           args.MetricsManager,
 		DeviceCache:              args.DeviceCache,
 		EnableDynamicVolumes:     args.EnableDynamicVolumes,
+		nodeName:                 args.NodeName,
 	}
 }
 
diff --git a/pkg/gce-pd-csi-driver/node.go b/pkg/gce-pd-csi-driver/node.go
index 19c25f59d..8ba465a9e 100644
--- a/pkg/gce-pd-csi-driver/node.go
+++ b/pkg/gce-pd-csi-driver/node.go
@@ -57,6 +57,7 @@ type GCENodeServer struct {
 	EnableDataCache          bool
 	DataCacheEnabledNodePool bool
 	SysfsPath                string
+	nodeName                 string
 
 	// A map storing all volumes with ongoing operations so that additional operations
 	// for that same volume (as defined by VolumeID) return an Aborted error
@@ -103,6 +104,8 @@ type NodeServerArgs struct {
 	// SysfsPath defaults to "/sys", except if it's a unit test.
 	SysfsPath string
 
+	NodeName string
+
 	MetricsManager *metrics.MetricsManager
 
 	DeviceCache *linkcache.DeviceCache
@@ -187,6 +190,15 @@ func (ns *GCENodeServer) WithSerializedFormatAndMount(timeout time.Duration, max
 	return ns
 }
 
+// GetNodeName returns the node name override (e.g. set from the Downward API) when it is
+// non-empty; otherwise it returns the name reported by the metadata service.
+func (ns *GCENodeServer) GetNodeName() string {
+	if ns.nodeName != "" {
+		return ns.nodeName
+	}
+	return ns.MetadataService.GetName()
+}
+
 func (ns *GCENodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) {
 	// Validate Arguments
 	targetPath := req.GetTargetPath()
@@ -731,9 +743,9 @@ func (ns *GCENodeServer) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRe
 		Segments: map[string]string{constants.TopologyKeyZone: ns.MetadataService.GetZone()},
 	}
 
-	node, err := k8sclient.GetNodeWithRetry(ctx, ns.MetadataService.GetName())
+	node, err := k8sclient.GetNodeWithRetry(ctx, ns.GetNodeName())
 	if err != nil {
-		klog.Errorf("Failed to get node %s: %v. The error is ignored so that the driver can register", ns.MetadataService.GetName(), err.Error())
+		klog.Errorf("Failed to get node %s: %v. The error is ignored so that the driver can register", ns.GetNodeName(), err.Error())
 	}
 
 	if ns.EnableDynamicVolumes {
@@ -747,7 +759,7 @@ func (ns *GCENodeServer) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRe
 		}
 	}
 
-	nodeID := common.CreateNodeID(ns.MetadataService.GetProject(), ns.MetadataService.GetZone(), ns.MetadataService.GetName())
+	nodeID := common.CreateNodeID(ns.MetadataService.GetProject(), ns.MetadataService.GetZone(), ns.GetNodeName())
 
 	volumeLimits, err := ns.getVolumeLimits(ctx, node)
 	if err != nil {
diff --git a/pkg/gce-pd-csi-driver/node_test.go b/pkg/gce-pd-csi-driver/node_test.go
index e5220f9fb..7cad050bd 100644
--- a/pkg/gce-pd-csi-driver/node_test.go
+++ b/pkg/gce-pd-csi-driver/node_test.go
@@ -1855,8 +1855,9 @@ func TestNodeGetInfo(t *testing.T) {
 		name        = "test-node"
 	)
 	tests := []struct {
-		desc string
-		want *csi.NodeGetInfoResponse
+		desc             string
+		nodeNameOverride string
+		want             *csi.NodeGetInfoResponse
 	}{
 		{
 			desc: "success",
@@ -1870,11 +1871,34 @@ func TestNodeGetInfo(t *testing.T) {
 				},
 			},
 		},
+		{
+			desc:             "success with nodeNameOverride",
+			nodeNameOverride: "override-node-name",
+			want: &csi.NodeGetInfoResponse{
+				NodeId:            fmt.Sprintf("projects/test-project/zones/%s/instances/%s", zone, "override-node-name"),
+				MaxVolumesPerNode: volumeLimitBig,
+				AccessibleTopology: &csi.Topology{
+					Segments: map[string]string{
+						constants.TopologyKeyZone: zone,
+					},
+				},
+			},
+		},
 	}
 	for _, tc := range tests {
 		t.Run(tc.desc, func(t *testing.T) {
-			gceDriver := getTestGCEDriver(t)
-			ns := gceDriver.ns
+			var ns *GCENodeServer
+			if tc.nodeNameOverride != "" {
+				args := &NodeServerArgs{
+					DeviceCache: linkcache.NewTestDeviceCache(1*time.Minute, linkcache.NewTestNodeWithVolumes([]string{defaultVolumeID})),
+					NodeName:    tc.nodeNameOverride,
+				}
+				gceDriver := getTestGCEDriverWithCustomMounter(t, mountmanager.NewFakeSafeMounter(), args)
+				ns = gceDriver.ns
+			} else {
+				gceDriver := getTestGCEDriver(t)
+				ns = gceDriver.ns
+			}
 			req := &csi.NodeGetInfoRequest{}
 			metadataservice.SetMachineType(machineType)
 			metadataservice.SetZone(zone)
@@ -1891,3 +1915,44 @@ func TestNodeGetInfo(t *testing.T) {
 		})
 	}
 }
+
+// TestGetNodeName verifies that GetNodeName prefers the configured NodeName override
+// and falls back to the metadata service name when no override is set.
+func TestGetNodeName(t *testing.T) {
+	tests := []struct {
+		desc             string
+		nodeNameOverride string
+		metadataName     string
+		want             string
+	}{
+		{
+			desc:         "returns metadata service name when override is empty",
+			metadataName: "metadata-node",
+			want:         "metadata-node",
+		},
+		{
+			desc:             "returns override when set",
+			nodeNameOverride: "override-node",
+			metadataName:     "metadata-node",
+			want:             "override-node",
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.desc, func(t *testing.T) {
+			args := &NodeServerArgs{NodeName: tc.nodeNameOverride}
+			gceDriver := getTestGCEDriverWithCustomMounter(t, mountmanager.NewFakeSafeMounter(), args)
+			ns := gceDriver.ns
+
+			// metadataservice.SetName is a package-level setter, so the value would
+			// otherwise leak into tests that run later; put the previous name back.
+			prevName := ns.MetadataService.GetName()
+			t.Cleanup(func() { metadataservice.SetName(prevName) })
+			metadataservice.SetName(tc.metadataName)
+
+			got := ns.GetNodeName()
+			if got != tc.want {
+				t.Errorf("GetNodeName() = %q, want %q", got, tc.want)
+			}
+		})
+	}
+}