Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions cmd/gce-pd-csi-driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,15 @@ func handle() {
if err != nil {
klog.Fatalf("Failed to set up metadata service: %v", err.Error())
}
isDataCacheEnabledNodePool, err := isDataCacheEnabledNodePool(ctx, *nodeName)
if err != nil {
klog.Fatalf("Failed to get node info from API server: %v", err.Error())
Comment thread
sunnylovestiramisu marked this conversation as resolved.
}
nsArgs := driver.NodeServerArgs{
EnableDeviceInUseCheck: *enableDeviceInUseCheck,
DeviceInUseTimeout: *deviceInUseTimeout,
EnableDataCache: *enableDataCacheFlag,
DataCacheEnabledNodePool: isDataCacheEnabledNodePool(ctx, *nodeName),
DataCacheEnabledNodePool: isDataCacheEnabledNodePool,
}
nodeServer = driver.NewNodeServer(gceDriver, mounter, deviceUtils, meta, statter, nsArgs)
if *maxConcurrentFormatAndMount > 0 {
Expand Down Expand Up @@ -347,14 +351,15 @@ func urlFlag(target **url.URL, name string, usage string) {
})
}

func isDataCacheEnabledNodePool(ctx context.Context, nodeName string) bool {
if nodeName != common.TestNode { // disregard logic below when E2E testing.
func isDataCacheEnabledNodePool(ctx context.Context, nodeName string) (bool, error) {
if !*enableDataCacheFlag {
return false, nil
}
if len(nodeName) > 0 && nodeName != common.TestNode { // disregard logic below when E2E testing.
dataCacheLSSDCount, err := driver.GetDataCacheCountFromNodeLabel(ctx, nodeName)
if err != nil || dataCacheLSSDCount == 0 {
return false
}
return dataCacheLSSDCount != 0, err
}
return true
return true, nil
}

func fetchLssdsForRaiding(lssdCount int) ([]string, error) {
Expand Down
33 changes: 28 additions & 5 deletions pkg/gce-pd-csi-driver/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ import (
"regexp"
"strconv"
"strings"
"time"

csi "github.com/container-storage-interface/spec/lib/go/csi"
fsnotify "github.com/fsnotify/fsnotify"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/klog/v2"
Expand Down Expand Up @@ -242,18 +245,15 @@ func ValidateDataCacheConfig(dataCacheMode string, dataCacheSize string, ctx con

func GetDataCacheCountFromNodeLabel(ctx context.Context, nodeName string) (int, error) {
cfg, err := rest.InClusterConfig()
// We want to capture API errors with node label fetching, so return -1
// in those cases instead of 0.
if err != nil {
return 0, err
}
kubeClient, err := kubernetes.NewForConfig(cfg)
if err != nil {
return 0, err
}
node, err := kubeClient.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{})
node, err := getNodeWithRetry(ctx, kubeClient, nodeName)
if err != nil {
// We could retry, but this error will also crashloop the driver which may be as good a way to retry as any.
return 0, err
}
if val, found := node.GetLabels()[fmt.Sprintf(common.NodeLabelPrefix, common.DataCacheLssdCountLabel)]; found {
Expand All @@ -264,10 +264,33 @@ func GetDataCacheCountFromNodeLabel(ctx context.Context, nodeName string) (int,
klog.V(4).Infof("Number of local SSDs requested for Data Cache: %v", dataCacheCount)
return dataCacheCount, nil
}
// This will be returned for a non-Data-Cache node pool
return 0, nil
}

func getNodeWithRetry(ctx context.Context, kubeClient *kubernetes.Clientset, nodeName string) (*v1.Node, error) {
var nodeObj *v1.Node
backoff := wait.Backoff{
Duration: 1 * time.Second,
Factor: 2.0,
Steps: 5,
}
err := wait.ExponentialBackoffWithContext(ctx, backoff, func() (bool, error) {
node, err := kubeClient.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{})
if err != nil {
klog.Warningf("Error getting node %s: %v, retrying...\n", nodeName, err)
return false, nil
}
nodeObj = node
klog.V(4).Infof("Successfully retrieved node info %s\n", nodeName)
return true, nil
})

if err != nil {
klog.Errorf("Failed to get node %s after retries: %v\n", nodeName, err)
}
return nodeObj, err
}

func FetchRaidedLssdCountForDatacache() (int, error) {
raidedPath, err := fetchRAIDedLocalSsdPath()
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion test/e2e/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ func GCEClientAndDriverSetup(instance *remote.InstanceInfo, driverConfig DriverC
fmt.Sprintf("--fallback-requisite-zones=%s", strings.Join(driverConfig.Zones, ",")),
}

extra_flags = append(extra_flags, fmt.Sprintf("--node-name=%s", utilcommon.TestNode))
if instance.GetLocalSSD() > 0 {
extra_flags = append(extra_flags, "--enable-data-cache")
extra_flags = append(extra_flags, fmt.Sprintf("--node-name=%s", utilcommon.TestNode))
}
extra_flags = append(extra_flags, fmt.Sprintf("--compute-endpoint=%s", driverConfig.ComputeEndpoint))
extra_flags = append(extra_flags, driverConfig.ExtraFlags...)
Expand Down