Skip to content

Commit

Permalink
Use Connect() and Probe() from csi-lib-utils
Browse files Browse the repository at this point in the history
  • Loading branch information
jsafrane committed Feb 22, 2019
1 parent c3ca24d commit a9543a0
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 71 deletions.
28 changes: 15 additions & 13 deletions cmd/csi-provisioner/csi-provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"k8s.io/klog"

flag "github.com/spf13/pflag"
"google.golang.org/grpc"

ctrl "github.com/kubernetes-csi/external-provisioner/pkg/controller"
snapclientset "github.com/kubernetes-csi/external-snapshotter/pkg/client/clientset/versioned"
Expand All @@ -49,7 +48,7 @@ var (
master = flag.String("master", "", "Master URL to build a client config from. Either this or kubeconfig needs to be set if the provisioner is being run out of cluster.")
kubeconfig = flag.String("kubeconfig", "", "Absolute path to the kubeconfig file. Either this or master needs to be set if the provisioner is being run out of cluster.")
csiEndpoint = flag.String("csi-address", "/run/csi/socket", "The gRPC endpoint for Target CSI Volume.")
connectionTimeout = flag.Duration("connection-timeout", 10*time.Second, "Timeout for waiting for CSI driver socket.")
connectionTimeout = flag.Duration("connection-timeout", 0, "This option is deprecated.")
volumeNamePrefix = flag.String("volume-name-prefix", "pvc", "Prefix to apply to the name of a created volume.")
volumeNameUUIDLength = flag.Int("volume-name-uuid-length", -1, "Truncates generated UUID of a created volume to this length. Defaults behavior is to NOT truncate.")
showVersion = flag.Bool("version", false, "Show version.")
Expand Down Expand Up @@ -78,6 +77,10 @@ func init() {
flag.Set("logtostderr", "true")
flag.Parse()

if *connectionTimeout != 0 {
klog.Warningf("Warning: option -connection-timeout is deprecated and has no effect")
}

if err := utilfeature.DefaultFeatureGate.SetFromMap(featureGates); err != nil {
klog.Fatal(err)
}
Expand Down Expand Up @@ -127,17 +130,16 @@ func init() {
klog.Fatalf("Error getting server version: %v", err)
}

// Provisioner will stay in Init until driver opens csi socket, once it's done
// controller will exit this loop and proceed normally.
socketDown := true
grpcClient := &grpc.ClientConn{}
for socketDown {
grpcClient, err = ctrl.Connect(*csiEndpoint, *connectionTimeout)
if err == nil {
socketDown = false
continue
}
time.Sleep(10 * time.Second)
grpcClient, err := ctrl.Connect(*csiEndpoint)
if err != nil {
klog.Error(err.Error())
os.Exit(1)
}

err = ctrl.Probe(grpcClient, *operationTimeout)
if err != nil {
klog.Error(err.Error())
os.Exit(1)
}

// Autodetect provisioner name
Expand Down
52 changes: 7 additions & 45 deletions pkg/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"context"
"fmt"
"math"
"net"
"os"
"strings"
"time"
Expand All @@ -33,6 +32,7 @@ import (
"github.com/kubernetes-sigs/sig-storage-lib-external-provisioner/controller"
"github.com/kubernetes-sigs/sig-storage-lib-external-provisioner/util"

"github.com/kubernetes-csi/csi-lib-utils/connection"
"github.com/kubernetes-csi/csi-lib-utils/protosanitizer"
snapapi "github.com/kubernetes-csi/external-snapshotter/pkg/apis/volumesnapshot/v1alpha1"
snapclientset "github.com/kubernetes-csi/external-snapshotter/pkg/client/clientset/versioned"
Expand All @@ -44,7 +44,6 @@ import (
"k8s.io/client-go/rest"

"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"

"github.com/container-storage-interface/spec/lib/go/csi"
csiclientset "k8s.io/csi-api/pkg/client/clientset/versioned"
Expand Down Expand Up @@ -185,55 +184,18 @@ func logGRPC(ctx context.Context, method string, req, reply interface{}, cc *grp
return err
}

func Connect(address string, timeout time.Duration) (*grpc.ClientConn, error) {
klog.V(2).Infof("Connecting to %s", address)
dialOptions := []grpc.DialOption{
grpc.WithInsecure(),
grpc.WithBackoffMaxDelay(time.Second),
grpc.WithUnaryInterceptor(logGRPC),
}
if strings.HasPrefix(address, "/") {
dialOptions = append(dialOptions, grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("unix", addr, timeout)
}))
}
conn, err := grpc.Dial(address, dialOptions...)
func Connect(address string) (*grpc.ClientConn, error) {
return connection.Connect(address, connection.OnConnectionLoss(connection.ExitOnConnectionLoss()))
}

if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
for {
if !conn.WaitForStateChange(ctx, conn.GetState()) {
klog.V(4).Infof("Connection timed out")
return conn, fmt.Errorf("Connection timed out")
}
if conn.GetState() == connectivity.Ready {
klog.V(3).Infof("Connected")
return conn, nil
}
klog.V(4).Infof("Still trying, connection is %s", conn.GetState())
}
func Probe(conn *grpc.ClientConn, singleCallTimeout time.Duration) error {
return connection.ProbeForever(conn, singleCallTimeout)
}

func GetDriverName(conn *grpc.ClientConn, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

client := csi.NewIdentityClient(conn)

req := csi.GetPluginInfoRequest{}

rsp, err := client.GetPluginInfo(ctx, &req)
if err != nil {
return "", err
}
name := rsp.GetName()
if name == "" {
return "", fmt.Errorf("name is empty")
}
return name, nil
return connection.GetDriverName(ctx, conn)
}

func getDriverCapabilities(conn *grpc.ClientConn, timeout time.Duration) (sets.Int, error) {
Expand Down
57 changes: 44 additions & 13 deletions pkg/controller/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@ import (
"context"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"strconv"
"testing"
"time"

"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/golang/mock/gomock"
"github.com/kubernetes-csi/csi-lib-utils/connection"
"github.com/kubernetes-csi/csi-test/driver"
"github.com/kubernetes-csi/external-provisioner/pkg/features"
crdv1 "github.com/kubernetes-csi/external-snapshotter/pkg/apis/volumesnapshot/v1alpha1"
Expand Down Expand Up @@ -62,8 +66,8 @@ type csiConnection struct {
conn *grpc.ClientConn
}

func New(address string, timeout time.Duration) (csiConnection, error) {
conn, err := Connect(address, timeout)
func New(address string) (csiConnection, error) {
conn, err := connection.Connect(address)
if err != nil {
return csiConnection{}, err
}
Expand All @@ -72,7 +76,7 @@ func New(address string, timeout time.Duration) (csiConnection, error) {
}, nil
}

func createMockServer(t *testing.T) (*gomock.Controller,
func createMockServer(t *testing.T, tmpdir string) (*gomock.Controller,
*driver.MockCSIDriver,
*driver.MockIdentityServer,
*driver.MockControllerServer,
Expand All @@ -85,18 +89,26 @@ func createMockServer(t *testing.T) (*gomock.Controller,
Identity: identityServer,
Controller: controllerServer,
})
drv.Start()
drv.StartOnAddress("unix", filepath.Join(tmpdir, "csi.sock"))

// Create a client connection to it
addr := drv.Address()
csiConn, err := New(addr, timeout)
csiConn, err := New(addr)
if err != nil {
return nil, nil, nil, nil, csiConnection{}, err
}

return mockController, drv, identityServer, controllerServer, csiConn, nil
}

func tempDir(t *testing.T) string {
dir, err := ioutil.TempDir("", "external-attacher-test-")
if err != nil {
t.Fatalf("Cannot create temporary directory: %s", err)
}
return dir
}

func TestGetPluginName(t *testing.T) {
test := struct {
name string
Expand All @@ -121,7 +133,9 @@ func TestGetPluginName(t *testing.T) {
},
}

mockController, driver, identityServer, _, csiConn, err := createMockServer(t)
tmpdir := tempDir(t)
defer os.RemoveAll(tmpdir)
mockController, driver, identityServer, _, csiConn, err := createMockServer(t, tmpdir)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -326,7 +340,9 @@ func TestGetDriverCapabilities(t *testing.T) {
},
}...)

mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t)
tmpdir := tempDir(t)
defer os.RemoveAll(tmpdir)
mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t, tmpdir)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -448,7 +464,9 @@ func TestGetDriverName(t *testing.T) {
},
}

mockController, driver, identityServer, _, csiConn, err := createMockServer(t)
tmpdir := tempDir(t)
defer os.RemoveAll(tmpdir)
mockController, driver, identityServer, _, csiConn, err := createMockServer(t, tmpdir)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -527,7 +545,10 @@ func TestBytesToQuantity(t *testing.T) {
func TestCreateDriverReturnsInvalidCapacityDuringProvision(t *testing.T) {
// Set up mocks
var requestedBytes int64 = 100
mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t)

tmpdir := tempDir(t)
defer os.RemoveAll(tmpdir)
mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t, tmpdir)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1349,7 +1370,9 @@ func newSnapshot(name, className, boundToContent, snapshotUID, claimName string,
func runProvisionTest(t *testing.T, k string, tc provisioningTestcase, requestedBytes int64) {
t.Logf("Running test: %v", k)

mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t)
tmpdir := tempDir(t)
defer os.RemoveAll(tmpdir)
mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t, tmpdir)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1712,7 +1735,9 @@ func TestProvisionFromSnapshot(t *testing.T) {
},
}

mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t)
tmpdir := tempDir(t)
defer os.RemoveAll(tmpdir)
mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t, tmpdir)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1820,7 +1845,10 @@ func TestProvisionWithTopology(t *testing.T) {
}

const requestBytes = 100
mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t)

tmpdir := tempDir(t)
defer os.RemoveAll(tmpdir)
mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t, tmpdir)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1858,7 +1886,10 @@ func TestProvisionWithTopology(t *testing.T) {
func TestProvisionWithMountOptions(t *testing.T) {
expectedOptions := []string{"foo=bar", "baz=qux"}
const requestBytes = 100
mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t)

tmpdir := tempDir(t)
defer os.RemoveAll(tmpdir)
mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t, tmpdir)
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit a9543a0

Please sign in to comment.