Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open a new connection for each probe call #69

Merged
merged 5 commits into from
May 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 88 additions & 21 deletions cmd/livenessprobe/livenessprobe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@ limitations under the License.
package main

import (
"flag"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"

csi "github.com/container-storage-interface/spec/lib/go/csi"
"github.com/golang/mock/gomock"
connlib "github.com/kubernetes-csi/csi-lib-utils/connection"
"github.com/kubernetes-csi/csi-test/driver"
"google.golang.org/grpc"
)

const (
Expand All @@ -38,8 +40,7 @@ func createMockServer(t *testing.T) (
*driver.MockIdentityServer,
*driver.MockControllerServer,
*driver.MockNodeServer,
*grpc.ClientConn,
error) {
func()) {
// Start the mock server
mockController := gomock.NewController(t)
identityServer := driver.NewMockIdentityServer(mockController)
Expand All @@ -50,37 +51,77 @@ func createMockServer(t *testing.T) (
Controller: controllerServer,
Node: nodeServer,
})
drv.Start()

// Create a client connection to it
addr := drv.Address()
csiConn, err := connlib.Connect(addr)
tmpDir, err := ioutil.TempDir("", "livenessprobe_test.*")
if err != nil {
return nil, nil, nil, nil, nil, nil, err
t.Errorf("failed to create a temporary socket file name: %v", err)
}

return mockController, drv, identityServer, controllerServer, nodeServer, csiConn, nil
csiEndpoint := fmt.Sprintf("%s/csi.sock", tmpDir)
err = drv.StartOnAddress("unix", csiEndpoint)
if err != nil {
t.Errorf("failed to start the csi driver at %s: %v", csiEndpoint, err)
}

return mockController, drv, identityServer, controllerServer, nodeServer, func() {
mockController.Finish()
drv.Stop()
os.RemoveAll(csiEndpoint)
}
}

func TestProbe(t *testing.T) {
mockController, driver, idServer, _, _, csiConn, err := createMockServer(t)
_, driver, idServer, _, _, cleanUpFunc := createMockServer(t)
defer cleanUpFunc()

flag.Set("csi-address", driver.Address())
flag.Parse()

var injectedErr error

inProbe := &csi.ProbeRequest{}
outProbe := &csi.ProbeResponse{}
idServer.EXPECT().Probe(gomock.Any(), inProbe).Return(outProbe, injectedErr).Times(1)

hp := &healthProbe{driverName: driverName}

server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.String() == "/healthz" {
hp.checkProbe(rw, req)
}
}))
defer server.Close()

httpreq, err := http.NewRequest("GET", fmt.Sprintf("%s/healthz", server.URL), nil)
if err != nil {
t.Fatal(err)
t.Fatalf("failed to build test request for health check: %v", err)
}
defer mockController.Finish()
defer driver.Stop()
defer csiConn.Close()

httpresp, err := http.DefaultClient.Do(httpreq)
if err != nil {
t.Errorf("failed to check probe: %v", err)
}

expectedStatusCode := http.StatusOK
if httpresp.StatusCode != expectedStatusCode {
t.Errorf("expected status code %d but got %d", expectedStatusCode, httpresp.StatusCode)
}
}

func TestProbe_issue68(t *testing.T) {
_, driver, idServer, _, _, cleanUpFunc := createMockServer(t)
defer cleanUpFunc()

flag.Set("csi-address", driver.Address())
flag.Parse()

var injectedErr error

inProbe := &csi.ProbeRequest{}
outProbe := &csi.ProbeResponse{}
idServer.EXPECT().Probe(gomock.Any(), inProbe).Return(outProbe, injectedErr).Times(1)

hp := &healthProbe{
conn: csiConn,
driverName: driverName,
}
hp := &healthProbe{driverName: driverName}

server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.String() == "/healthz" {
Expand All @@ -89,12 +130,38 @@ func TestProbe(t *testing.T) {
}))
defer server.Close()

httpreq, err := http.NewRequest("GET", server.URL+"/healthz", nil)
httpreq, err := http.NewRequest("GET", fmt.Sprintf("%s/healthz", server.URL), nil)
if err != nil {
t.Fatalf("failed to build test request for health check: %v", err)
}
_, err = http.DefaultClient.Do(httpreq)

httpresp, err := http.DefaultClient.Do(httpreq)
if err != nil {
t.Errorf("failed to check probe: %v", err)
}

expectedStatusCode := http.StatusOK
if httpresp.StatusCode != expectedStatusCode {
t.Errorf("expected status code %d but got %d", expectedStatusCode, httpresp.StatusCode)
}

err = os.Remove(driver.Address())
if err != nil {
t.Errorf("failed to remove the csi driver socket file: %v", err)
}

httpreq, err = http.NewRequest("GET", fmt.Sprintf("%s/healthz", server.URL), nil)
if err != nil {
t.Fatalf("failed to build test request for health check: %v", err)
}

httpresp, err = http.DefaultClient.Do(httpreq)
if err != nil {
t.Errorf("failed to check probe: %v", err)
}

expectedStatusCode = http.StatusInternalServerError
if httpresp.StatusCode != expectedStatusCode {
t.Errorf("expected status code %d but got %d", expectedStatusCode, httpresp.StatusCode)
}
}
47 changes: 43 additions & 4 deletions cmd/livenessprobe/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"flag"
"net"
"net/http"
"sync"
"time"

"k8s.io/klog"
Expand All @@ -39,16 +40,24 @@ var (
)

type healthProbe struct {
conn *grpc.ClientConn
driverName string
}

func (h *healthProbe) checkProbe(w http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), *probeTimeout)
defer cancel()

conn, err := acquireConnection(ctx)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error()))
klog.Errorf("failed to establish connection to CSI driver: %v", err)
return
}
defer conn.Close()

klog.V(5).Infof("Sending probe request to CSI driver %q", h.driverName)
ready, err := rpc.Probe(ctx, h.conn)
ready, err := rpc.Probe(ctx, conn)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error()))
Expand All @@ -68,12 +77,42 @@ func (h *healthProbe) checkProbe(w http.ResponseWriter, req *http.Request) {
klog.V(5).Infof("Health check succeeded")
}

// acquireConnection wraps the connlib.Connect but adding support to context
// cancelation.
func acquireConnection(ctx context.Context) (conn *grpc.ClientConn, err error) {
var m sync.Mutex
var canceled bool
ready := make(chan bool)
go func() {
conn, err = connlib.Connect(*csiAddress)

m.Lock()
defer m.Unlock()
if err != nil && canceled {
conn.Close()
}

close(ready)
}()

select {
case <-ctx.Done():
m.Lock()
defer m.Unlock()
canceled = true
return nil, ctx.Err()

case <-ready:
return conn, err
}
}

func main() {
klog.InitFlags(nil)
flag.Set("logtostderr", "true")
flag.Parse()

csiConn, err := connlib.Connect(*csiAddress)
csiConn, err := acquireConnection(context.Background())
if err != nil {
// connlib should retry forever so a returned error should mean
// the grpc client is misconfigured rather than an error on the network
Expand All @@ -82,13 +121,13 @@ func main() {

klog.Infof("calling CSI driver to discover driver name")
csiDriverName, err := rpc.GetDriverName(context.Background(), csiConn)
csiConn.Close()
if err != nil {
klog.Fatalf("failed to get CSI driver name: %v", err)
}
klog.Infof("CSI driver name: %q", csiDriverName)

hp := &healthProbe{
conn: csiConn,
driverName: csiDriverName,
}

Expand Down