Skip to content

Commit

Permalink
Ensure GRPC service shuts down on context cancel.
Browse files Browse the repository at this point in the history
  • Loading branch information
mcdee committed Oct 21, 2020
1 parent 4a00eb0 commit 002a547
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Development
- Ensure GRPC service shuts down on context cancel
- Add `--version` flag to print software version

# Version 0.9.0
Expand Down
7 changes: 7 additions & 0 deletions services/api/grpc/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ func New(ctx context.Context, params ...Parameter) (*Service, error) {
if err != nil {
return nil, errors.Wrap(err, "failed to start API server")
}

// Cancel service on context done.
go func() {
<-ctx.Done()
s.grpcServer.GracefulStop()
}()

return s, nil
}

Expand Down
33 changes: 32 additions & 1 deletion testing/daemon/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,45 @@ package daemon_test

import (
"context"
"fmt"
"math/rand"
"net"
"os"
"testing"
"time"

"github.com/attestantio/dirk/testing/daemon"
"github.com/stretchr/testify/require"
)

func TestDaemon(t *testing.T) {
ctx := context.Background()
_, _, err := daemon.New(ctx, "", 1, 12345, map[uint64]string{1: "server-test01:12345"})
// #nosec G404
port := uint32(12000 + rand.Intn(4000))
_, path, err := daemon.New(ctx, "", 1, port, map[uint64]string{1: fmt.Sprintf("signer-test01:%d", port)})
require.NoError(t, err)
os.RemoveAll(path)
}

func TestCancelDaemon(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
// #nosec G404
port := uint32(12000 + rand.Intn(4000))
_, path, err := daemon.New(ctx, "", 1, port, map[uint64]string{1: fmt.Sprintf("signer-test01:%d", port)})
require.NoError(t, err)
defer os.RemoveAll(path)
require.True(t, endpointAlive(fmt.Sprintf("signer-test01:%d", port)))
cancel()
// Sleep for a second to allow graceful stop of the daemon.
time.Sleep(time.Second)
require.False(t, endpointAlive(fmt.Sprintf("signer-test01:%d", port)))
}

func endpointAlive(address string) bool {
conn, err := net.DialTimeout("tcp", address, 5*time.Second)
if err == nil {
conn.Close()
return true
}
return false
}

0 comments on commit 002a547

Please sign in to comment.