diff --git a/cmd/livenessprobe/livenessprobe_test.go b/cmd/livenessprobe/livenessprobe_test.go index 38435338..7fc3ea2b 100644 --- a/cmd/livenessprobe/livenessprobe_test.go +++ b/cmd/livenessprobe/livenessprobe_test.go @@ -17,15 +17,17 @@ limitations under the License. package main import ( + "flag" + "fmt" + "io/ioutil" "net/http" "net/http/httptest" + "os" "testing" csi "github.com/container-storage-interface/spec/lib/go/csi" "github.com/golang/mock/gomock" - connlib "github.com/kubernetes-csi/csi-lib-utils/connection" "github.com/kubernetes-csi/csi-test/driver" - "google.golang.org/grpc" ) const ( @@ -38,8 +40,7 @@ func createMockServer(t *testing.T) ( *driver.MockIdentityServer, *driver.MockControllerServer, *driver.MockNodeServer, - *grpc.ClientConn, - error) { + func()) { // Start the mock server mockController := gomock.NewController(t) identityServer := driver.NewMockIdentityServer(mockController) @@ -50,26 +51,69 @@ func createMockServer(t *testing.T) ( Controller: controllerServer, Node: nodeServer, }) - drv.Start() - // Create a client connection to it - addr := drv.Address() - csiConn, err := connlib.Connect(addr) + tmpDir, err := ioutil.TempDir("", "livenessprobe_test.*") if err != nil { - return nil, nil, nil, nil, nil, nil, err + t.Errorf("failed to create a temporary socket file name: %v", err) } - return mockController, drv, identityServer, controllerServer, nodeServer, csiConn, nil + csiEndpoint := fmt.Sprintf("%s/csi.sock", tmpDir) + err = drv.StartOnAddress("unix", csiEndpoint) + if err != nil { + t.Errorf("failed to start the csi driver at %s: %v", csiEndpoint, err) + } + + return mockController, drv, identityServer, controllerServer, nodeServer, func() { + mockController.Finish() + drv.Stop() + os.RemoveAll(csiEndpoint) + } } func TestProbe(t *testing.T) { - mockController, driver, idServer, _, _, csiConn, err := createMockServer(t) + _, driver, idServer, _, _, cleanUpFunc := createMockServer(t) + defer cleanUpFunc() + + flag.Set("csi-address", driver.Address()) + flag.Parse() + + var injectedErr error + + inProbe := &csi.ProbeRequest{} + outProbe := &csi.ProbeResponse{} + idServer.EXPECT().Probe(gomock.Any(), inProbe).Return(outProbe, injectedErr).Times(1) + + hp := &healthProbe{driverName: driverName} + + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.URL.String() == "/healthz" { + hp.checkProbe(rw, req) + } + })) + defer server.Close() + + httpreq, err := http.NewRequest("GET", fmt.Sprintf("%s/healthz", server.URL), nil) if err != nil { - t.Fatal(err) + t.Fatalf("failed to build test request for health check: %v", err) } - defer mockController.Finish() - defer driver.Stop() - defer csiConn.Close() + + httpresp, err := http.DefaultClient.Do(httpreq) + if err != nil { + t.Errorf("failed to check probe: %v", err) + } + + expectedStatusCode := http.StatusOK + if httpresp.StatusCode != expectedStatusCode { + t.Errorf("expected status code %d but got %d", expectedStatusCode, httpresp.StatusCode) + } +} + +func TestProbe_issue68(t *testing.T) { + _, driver, idServer, _, _, cleanUpFunc := createMockServer(t) + defer cleanUpFunc() + + flag.Set("csi-address", driver.Address()) + flag.Parse() var injectedErr error @@ -77,10 +121,7 @@ func TestProbe(t *testing.T) { outProbe := &csi.ProbeResponse{} idServer.EXPECT().Probe(gomock.Any(), inProbe).Return(outProbe, injectedErr).Times(1) - hp := &healthProbe{ - conn: csiConn, - driverName: driverName, - } + hp := &healthProbe{driverName: driverName} server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { if req.URL.String() == "/healthz" { @@ -89,12 +130,38 @@ func TestProbe(t *testing.T) { })) defer server.Close() - httpreq, err := http.NewRequest("GET", server.URL+"/healthz", nil) + httpreq, err := http.NewRequest("GET", fmt.Sprintf("%s/healthz", server.URL), nil) if err != nil { t.Fatalf("failed to build test request for health check: %v", err) } - _, err = http.DefaultClient.Do(httpreq) + + httpresp, err := http.DefaultClient.Do(httpreq) if err != nil { t.Errorf("failed to check probe: %v", err) } + + expectedStatusCode := http.StatusOK + if httpresp.StatusCode != expectedStatusCode { + t.Errorf("expected status code %d but got %d", expectedStatusCode, httpresp.StatusCode) + } + + err = os.Remove(driver.Address()) + if err != nil { + t.Errorf("failed to remove the csi driver socket file: %v", err) + } + + httpreq, err = http.NewRequest("GET", fmt.Sprintf("%s/healthz", server.URL), nil) + if err != nil { + t.Fatalf("failed to build test request for health check: %v", err) + } + + httpresp, err = http.DefaultClient.Do(httpreq) + if err != nil { + t.Errorf("failed to check probe: %v", err) + } + + expectedStatusCode = http.StatusInternalServerError + if httpresp.StatusCode != expectedStatusCode { + t.Errorf("expected status code %d but got %d", expectedStatusCode, httpresp.StatusCode) + } } diff --git a/cmd/livenessprobe/main.go b/cmd/livenessprobe/main.go index e0e310b5..1ad7a6d4 100644 --- a/cmd/livenessprobe/main.go +++ b/cmd/livenessprobe/main.go @@ -21,6 +21,7 @@ import ( "flag" "net" "net/http" + "sync" "time" "k8s.io/klog" @@ -39,7 +40,6 @@ var ( ) type healthProbe struct { - conn *grpc.ClientConn driverName string } @@ -47,8 +47,17 @@ func (h *healthProbe) checkProbe(w http.ResponseWriter, req *http.Request) { ctx, cancel := context.WithTimeout(req.Context(), *probeTimeout) defer cancel() + conn, err := acquireConnection(ctx) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + klog.Errorf("failed to establish connection to CSI driver: %v", err) + return + } + defer conn.Close() + klog.V(5).Infof("Sending probe request to CSI driver %q", h.driverName) - ready, err := rpc.Probe(ctx, h.conn) + ready, err := rpc.Probe(ctx, conn) if err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte(err.Error())) @@ -68,12 +77,42 @@ func (h *healthProbe) checkProbe(w http.ResponseWriter, req *http.Request) { klog.V(5).Infof("Health check succeeded") } +// acquireConnection wraps the connlib.Connect but adding support to context +// cancelation. +func acquireConnection(ctx context.Context) (conn *grpc.ClientConn, err error) { + var m sync.Mutex + var canceled bool + ready := make(chan bool) + go func() { + conn, err = connlib.Connect(*csiAddress) + + m.Lock() + defer m.Unlock() + if err != nil && canceled { + conn.Close() + } + + close(ready) + }() + + select { + case <-ctx.Done(): + m.Lock() + defer m.Unlock() + canceled = true + return nil, ctx.Err() + + case <-ready: + return conn, err + } +} + func main() { klog.InitFlags(nil) flag.Set("logtostderr", "true") flag.Parse() - csiConn, err := connlib.Connect(*csiAddress) + csiConn, err := acquireConnection(context.Background()) if err != nil { // connlib should retry forever so a returned error should mean // the grpc client is misconfigured rather than an error on the network @@ -82,13 +121,13 @@ func main() { klog.Infof("calling CSI driver to discover driver name") csiDriverName, err := rpc.GetDriverName(context.Background(), csiConn) + csiConn.Close() if err != nil { klog.Fatalf("failed to get CSI driver name: %v", err) } klog.Infof("CSI driver name: %q", csiDriverName) hp := &healthProbe{ - conn: csiConn, driverName: csiDriverName, }