Skip to content

Commit

Permalink
Use Connect() and Probe() from csi-lib-utils
Browse files Browse the repository at this point in the history
  • Loading branch information
jsafrane committed Feb 22, 2019
1 parent c3ca24d commit 41d9159
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 90 deletions.
28 changes: 15 additions & 13 deletions cmd/csi-provisioner/csi-provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"k8s.io/klog"

flag "github.com/spf13/pflag"
"google.golang.org/grpc"

ctrl "github.com/kubernetes-csi/external-provisioner/pkg/controller"
snapclientset "github.com/kubernetes-csi/external-snapshotter/pkg/client/clientset/versioned"
Expand All @@ -49,7 +48,7 @@ var (
master = flag.String("master", "", "Master URL to build a client config from. Either this or kubeconfig needs to be set if the provisioner is being run out of cluster.")
kubeconfig = flag.String("kubeconfig", "", "Absolute path to the kubeconfig file. Either this or master needs to be set if the provisioner is being run out of cluster.")
csiEndpoint = flag.String("csi-address", "/run/csi/socket", "The gRPC endpoint for Target CSI Volume.")
connectionTimeout = flag.Duration("connection-timeout", 10*time.Second, "Timeout for waiting for CSI driver socket.")
connectionTimeout = flag.Duration("connection-timeout", 0, "This option is deprecated.")
volumeNamePrefix = flag.String("volume-name-prefix", "pvc", "Prefix to apply to the name of a created volume.")
volumeNameUUIDLength = flag.Int("volume-name-uuid-length", -1, "Truncates generated UUID of a created volume to this length. Defaults behavior is to NOT truncate.")
showVersion = flag.Bool("version", false, "Show version.")
Expand Down Expand Up @@ -78,6 +77,10 @@ func init() {
flag.Set("logtostderr", "true")
flag.Parse()

if *connectionTimeout != 0 {
klog.Warningf("Warning: option -connection-timeout is deprecated and has no effect")
}

if err := utilfeature.DefaultFeatureGate.SetFromMap(featureGates); err != nil {
klog.Fatal(err)
}
Expand Down Expand Up @@ -127,17 +130,16 @@ func init() {
klog.Fatalf("Error getting server version: %v", err)
}

// Provisioner will stay in Init until driver opens csi socket, once it's done
// controller will exit this loop and proceed normally.
socketDown := true
grpcClient := &grpc.ClientConn{}
for socketDown {
grpcClient, err = ctrl.Connect(*csiEndpoint, *connectionTimeout)
if err == nil {
socketDown = false
continue
}
time.Sleep(10 * time.Second)
grpcClient, err := ctrl.Connect(*csiEndpoint)
if err != nil {
klog.Error(err.Error())
os.Exit(1)
}

err = ctrl.Probe(grpcClient, *operationTimeout)
if err != nil {
klog.Error(err.Error())
os.Exit(1)
}

// Autodetect provisioner name
Expand Down
60 changes: 11 additions & 49 deletions pkg/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"context"
"fmt"
"math"
"net"
"os"
"strings"
"time"
Expand All @@ -33,6 +32,7 @@ import (
"github.com/kubernetes-sigs/sig-storage-lib-external-provisioner/controller"
"github.com/kubernetes-sigs/sig-storage-lib-external-provisioner/util"

"github.com/kubernetes-csi/csi-lib-utils/connection"
"github.com/kubernetes-csi/csi-lib-utils/protosanitizer"
snapapi "github.com/kubernetes-csi/external-snapshotter/pkg/apis/volumesnapshot/v1alpha1"
snapclientset "github.com/kubernetes-csi/external-snapshotter/pkg/client/clientset/versioned"
Expand All @@ -44,7 +44,6 @@ import (
"k8s.io/client-go/rest"

"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"

"github.com/container-storage-interface/spec/lib/go/csi"
csiclientset "k8s.io/csi-api/pkg/client/clientset/versioned"
Expand Down Expand Up @@ -111,31 +110,31 @@ const (

var (
provisionerSecretParams = deprecatedSecretParamsMap{
name: "Provisioner",
name: "Provisioner",
deprecatedSecretNameKey: provisionerSecretNameKey,
deprecatedSecretNamespaceKey: provisionerSecretNamespaceKey,
secretNameKey: prefixedProvisionerSecretNameKey,
secretNamespaceKey: prefixedProvisionerSecretNamespaceKey,
}

nodePublishSecretParams = deprecatedSecretParamsMap{
name: "NodePublish",
name: "NodePublish",
deprecatedSecretNameKey: nodePublishSecretNameKey,
deprecatedSecretNamespaceKey: nodePublishSecretNamespaceKey,
secretNameKey: prefixedNodePublishSecretNameKey,
secretNamespaceKey: prefixedNodePublishSecretNamespaceKey,
}

controllerPublishSecretParams = deprecatedSecretParamsMap{
name: "ControllerPublish",
name: "ControllerPublish",
deprecatedSecretNameKey: controllerPublishSecretNameKey,
deprecatedSecretNamespaceKey: controllerPublishSecretNamespaceKey,
secretNameKey: prefixedControllerPublishSecretNameKey,
secretNamespaceKey: prefixedControllerPublishSecretNamespaceKey,
}

nodeStageSecretParams = deprecatedSecretParamsMap{
name: "NodeStage",
name: "NodeStage",
deprecatedSecretNameKey: nodeStageSecretNameKey,
deprecatedSecretNamespaceKey: nodeStageSecretNamespaceKey,
secretNameKey: prefixedNodeStageSecretNameKey,
Expand Down Expand Up @@ -185,55 +184,18 @@ func logGRPC(ctx context.Context, method string, req, reply interface{}, cc *grp
return err
}

func Connect(address string, timeout time.Duration) (*grpc.ClientConn, error) {
klog.V(2).Infof("Connecting to %s", address)
dialOptions := []grpc.DialOption{
grpc.WithInsecure(),
grpc.WithBackoffMaxDelay(time.Second),
grpc.WithUnaryInterceptor(logGRPC),
}
if strings.HasPrefix(address, "/") {
dialOptions = append(dialOptions, grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("unix", addr, timeout)
}))
}
conn, err := grpc.Dial(address, dialOptions...)
func Connect(address string) (*grpc.ClientConn, error) {
return connection.Connect(address, connection.OnConnectionLoss(connection.ExitOnConnectionLoss()))
}

if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
for {
if !conn.WaitForStateChange(ctx, conn.GetState()) {
klog.V(4).Infof("Connection timed out")
return conn, fmt.Errorf("Connection timed out")
}
if conn.GetState() == connectivity.Ready {
klog.V(3).Infof("Connected")
return conn, nil
}
klog.V(4).Infof("Still trying, connection is %s", conn.GetState())
}
func Probe(conn *grpc.ClientConn, singleCallTimeout time.Duration) error {
return connection.ProbeForever(conn, singleCallTimeout)
}

func GetDriverName(conn *grpc.ClientConn, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

client := csi.NewIdentityClient(conn)

req := csi.GetPluginInfoRequest{}

rsp, err := client.GetPluginInfo(ctx, &req)
if err != nil {
return "", err
}
name := rsp.GetName()
if name == "" {
return "", fmt.Errorf("name is empty")
}
return name, nil
return connection.GetDriverName(ctx, conn)
}

func getDriverCapabilities(conn *grpc.ClientConn, timeout time.Duration) (sets.Int, error) {
Expand Down
Loading

0 comments on commit 41d9159

Please sign in to comment.