Skip to content

Commit

Permalink
Make the provisioner name optional
Browse files Browse the repository at this point in the history
Driver name is used as provisioner when no provisioner is specified by the
user.
  • Loading branch information
jsafrane committed Oct 1, 2018
1 parent 5cfa08e commit 55d7c0e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
12 changes: 11 additions & 1 deletion cmd/csi-provisioner/csi-provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import (
)

var (
provisioner = flag.String("provisioner", "", "Name of the provisioner. The provisioner will only provision volumes for claims that request a StorageClass with a provisioner field set equal to this name.")
provisioner = flag.String("provisioner", "", "Name of the provisioner. The provisioner will only provision volumes for claims that request a StorageClass with a provisioner field set equal to this name. If omitted, CSI driver name is used.")
master = flag.String("master", "", "Master URL to build a client config from. Either this or kubeconfig needs to be set if the provisioner is being run out of cluster.")
kubeconfig = flag.String("kubeconfig", "", "Absolute path to the kubeconfig file. Either this or master needs to be set if the provisioner is being run out of cluster.")
csiEndpoint = flag.String("csi-address", "/run/csi/socket", "The gRPC endpoint for Target CSI Volume")
Expand Down Expand Up @@ -121,6 +121,16 @@ func init() {
}
time.Sleep(10 * time.Second)
}

// Autodetect provisioner name if necessary
if *provisioner == "" {
*provisioner, err = ctrl.GetDriverName(grpcClient, *connectionTimeout)
if err != nil {
glog.Fatalf("Error getting CSI driver name: %s", err)
}
glog.V(2).Infof("Detected CSI driver %q", *provisioner)
}

// Create the provisioner: it implements the Provisioner interface expected by
// the controller
csiProvisioner := ctrl.NewCSIProvisioner(clientset, csiAPIClient, *csiEndpoint, *connectionTimeout, identity, *volumeNamePrefix, *volumeNameUUIDLength, grpcClient, snapClient)
Expand Down
4 changes: 2 additions & 2 deletions pkg/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func Connect(address string, timeout time.Duration) (*grpc.ClientConn, error) {
}
}

func getDriverName(conn *grpc.ClientConn, timeout time.Duration) (string, error) {
func GetDriverName(conn *grpc.ClientConn, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

Expand Down Expand Up @@ -304,7 +304,7 @@ func checkDriverState(grpcClient *grpc.ClientConn, timeout time.Duration, needSn
}
}

driverName, err := getDriverName(grpcClient, timeout)
driverName, err := GetDriverName(grpcClient, timeout)
if err != nil {
return nil, fmt.Errorf("failed to get driver info: %v", err)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/controller/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func TestGetPluginName(t *testing.T) {
out := test.output[0]

identityServer.EXPECT().GetPluginInfo(gomock.Any(), in).Return(out, nil).Times(1)
oldName, err := getDriverName(csiConn.conn, timeout)
oldName, err := GetDriverName(csiConn.conn, timeout)
if err != nil {
t.Errorf("test %q: Failed to get driver's name", test.name)
}
Expand All @@ -132,7 +132,7 @@ func TestGetPluginName(t *testing.T) {

out = test.output[1]
identityServer.EXPECT().GetPluginInfo(gomock.Any(), in).Return(out, nil).Times(1)
newName, err := getDriverName(csiConn.conn, timeout)
newName, err := GetDriverName(csiConn.conn, timeout)
if err != nil {
t.Errorf("test %s: Failed to get driver's name", test.name)
}
Expand Down Expand Up @@ -360,7 +360,7 @@ func TestGetDriverName(t *testing.T) {
// Setup expectation
identityServer.EXPECT().GetPluginInfo(gomock.Any(), in).Return(out, injectedErr).Times(1)

name, err := getDriverName(csiConn.conn, timeout)
name, err := GetDriverName(csiConn.conn, timeout)
if test.expectError && err == nil {
t.Errorf("test %q: Expected error, got none", test.name)
}
Expand Down

0 comments on commit 55d7c0e

Please sign in to comment.