From 002a5473a8917523faaae63b56d2bfcf60b857d5 Mon Sep 17 00:00:00 2001 From: Jim McDonald Date: Wed, 21 Oct 2020 19:35:10 +0100 Subject: [PATCH] Ensure GRPC service shuts down on context cancel. --- CHANGELOG.md | 1 + services/api/grpc/service.go | 7 +++++++ testing/daemon/daemon_test.go | 33 ++++++++++++++++++++++++++++++++- 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bccaf32..0205d94 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ # Development + - Ensure GRPC service shuts down on context cancel - Add `--version` flag to print software version # Version 0.9.0 diff --git a/services/api/grpc/service.go b/services/api/grpc/service.go index 1128cc3..165500e 100644 --- a/services/api/grpc/service.go +++ b/services/api/grpc/service.go @@ -120,6 +120,13 @@ func New(ctx context.Context, params ...Parameter) (*Service, error) { if err != nil { return nil, errors.Wrap(err, "failed to start API server") } + + // Cancel service on context done. + go func() { + <-ctx.Done() + s.grpcServer.GracefulStop() + }() + return s, nil } diff --git a/testing/daemon/daemon_test.go b/testing/daemon/daemon_test.go index 877323a..ac98a0f 100644 --- a/testing/daemon/daemon_test.go +++ b/testing/daemon/daemon_test.go @@ -15,7 +15,12 @@ package daemon_test import ( "context" + "fmt" + "math/rand" + "net" + "os" "testing" + "time" "github.com/attestantio/dirk/testing/daemon" "github.com/stretchr/testify/require" @@ -23,6 +28,32 @@ import ( func TestDaemon(t *testing.T) { ctx := context.Background() - _, _, err := daemon.New(ctx, "", 1, 12345, map[uint64]string{1: "server-test01:12345"}) + // #nosec G404 + port := uint32(12000 + rand.Intn(4000)) + _, path, err := daemon.New(ctx, "", 1, port, map[uint64]string{1: fmt.Sprintf("signer-test01:%d", port)}) require.NoError(t, err) + os.RemoveAll(path) +} + +func TestCancelDaemon(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + // #nosec G404 + port := uint32(12000 + rand.Intn(4000)) + _, path, err := daemon.New(ctx, "", 1, port, map[uint64]string{1: fmt.Sprintf("signer-test01:%d", port)}) + require.NoError(t, err) + defer os.RemoveAll(path) + require.True(t, endpointAlive(fmt.Sprintf("signer-test01:%d", port))) + cancel() + // Sleep for a second to allow graceful stop of the daemon. + time.Sleep(time.Second) + require.False(t, endpointAlive(fmt.Sprintf("signer-test01:%d", port))) +} + +func endpointAlive(address string) bool { + conn, err := net.DialTimeout("tcp", address, 5*time.Second) + if err == nil { + conn.Close() + return true + } + return false }