From b9de75e275040c25784e2ee95ae27f6a627d0b67 Mon Sep 17 00:00:00 2001 From: Yuki Yugui Sonoda Date: Mon, 23 Apr 2018 02:47:24 +0900 Subject: [PATCH] Support UNIX domain socket in the example servers (#609) * Lets example servers gracefully shutdown * Support UNIX domain sockets in the example servers --- Makefile | 1 + examples/BUILD.bazel | 5 +- examples/gateway/BUILD.bazel | 13 ++++ examples/gateway/gateway.go | 71 +++++++++++++++++ examples/integration_test.go | 5 ++ examples/main.go | 71 ++++++++--------- examples/main_test.go | 15 +++- examples/proto_error_test.go | 88 ++++++++++++++++------ examples/server/BUILD.bazel | 1 + examples/server/cmd/example-server/main.go | 9 ++- examples/server/main.go | 21 +++++- 11 files changed, 231 insertions(+), 69 deletions(-) create mode 100644 examples/gateway/BUILD.bazel create mode 100644 examples/gateway/gateway.go diff --git a/Makefile b/Makefile index cc880094592..abac26e6609 100644 --- a/Makefile +++ b/Makefile @@ -133,6 +133,7 @@ $(ABE_EXAMPLE_SRCS): $(ABE_EXAMPLE_SPEC) examples: $(EXAMPLE_SVCSRCS) $(EXAMPLE_GWSRCS) $(EXAMPLE_DEPSRCS) $(EXAMPLE_SWAGGERSRCS) $(EXAMPLE_CLIENT_SRCS) test: examples go test -race $(PKG)/... + go test -race $(PKG)/examples -args -network=unix -endpoint=test.sock lint: golint --set_exit_status $(PKG)/runtime diff --git a/examples/BUILD.bazel b/examples/BUILD.bazel index a6964471274..5ca5fa4ae2e 100644 --- a/examples/BUILD.bazel +++ b/examples/BUILD.bazel @@ -7,10 +7,9 @@ go_library( srcs = ["main.go"], importpath = "github.com/grpc-ecosystem/grpc-gateway/examples", deps = [ - "//examples/examplepb:go_default_library", + "//examples/gateway:go_default_library", "//runtime:go_default_library", "@com_github_golang_glog//:go_default_library", - "@org_golang_google_grpc//:go_default_library", ], ) @@ -36,10 +35,12 @@ go_test( "//examples/server:go_default_library", "//examples/sub:go_default_library", "//runtime:go_default_library", + "@com_github_golang_glog//:go_default_library", "@com_github_golang_protobuf//jsonpb:go_default_library", "@com_github_golang_protobuf//proto:go_default_library", "@com_github_golang_protobuf//ptypes/empty:go_default_library", "@org_golang_google_genproto//googleapis/rpc/status:go_default_library", "@org_golang_google_grpc//codes:go_default_library", + "@org_golang_x_net//context:go_default_library", ], ) diff --git a/examples/gateway/BUILD.bazel b/examples/gateway/BUILD.bazel new file mode 100644 index 00000000000..6aefb52a02b --- /dev/null +++ b/examples/gateway/BUILD.bazel @@ -0,0 +1,13 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "go_default_library", + srcs = ["gateway.go"], + importpath = "github.com/grpc-ecosystem/grpc-gateway/examples/gateway", + visibility = ["//visibility:public"], + deps = [ + "//examples/examplepb:go_default_library", + "//runtime:go_default_library", + "@org_golang_google_grpc//:go_default_library", + ], +) diff --git a/examples/gateway/gateway.go b/examples/gateway/gateway.go new file mode 100644 index 00000000000..cc503dd2814 --- /dev/null +++ b/examples/gateway/gateway.go @@ -0,0 +1,71 @@ +package gateway + +import ( + "context" + "net" + "net/http" + "time" + + "github.com/grpc-ecosystem/grpc-gateway/examples/examplepb" + gwruntime "github.com/grpc-ecosystem/grpc-gateway/runtime" + "google.golang.org/grpc" +) + +type optSet struct { + mux []gwruntime.ServeMuxOption + dial []grpc.DialOption + + echoEndpoint, abeEndpoint, flowEndpoint string +} + +// newGateway returns a new gateway server which translates HTTP into gRPC. +func newGateway(ctx context.Context, opts optSet) (http.Handler, error) { + mux := gwruntime.NewServeMux(opts.mux...) + + err := examplepb.RegisterEchoServiceHandlerFromEndpoint(ctx, mux, opts.echoEndpoint, opts.dial) + if err != nil { + return nil, err + } + err = examplepb.RegisterStreamServiceHandlerFromEndpoint(ctx, mux, opts.abeEndpoint, opts.dial) + if err != nil { + return nil, err + } + err = examplepb.RegisterABitOfEverythingServiceHandlerFromEndpoint(ctx, mux, opts.abeEndpoint, opts.dial) + if err != nil { + return nil, err + } + err = examplepb.RegisterFlowCombinationHandlerFromEndpoint(ctx, mux, opts.flowEndpoint, opts.dial) + if err != nil { + return nil, err + } + return mux, nil +} + +// NewTCPGateway returns a new gateway server which connect to the gRPC service with TCP. +// "addr" must be a valid TCP address with a port number. +func NewTCPGateway(ctx context.Context, addr string, opts ...gwruntime.ServeMuxOption) (http.Handler, error) { + return newGateway(ctx, optSet{ + mux: opts, + dial: []grpc.DialOption{grpc.WithInsecure()}, + echoEndpoint: addr, + abeEndpoint: addr, + flowEndpoint: addr, + }) +} + +// NewUnixGatway returns a new gateway server which connect to the gRPC service with a unix domain socket. +// "addr" must be a valid path to the socket. +func NewUnixGateway(ctx context.Context, addr string, opts ...gwruntime.ServeMuxOption) (http.Handler, error) { + return newGateway(ctx, optSet{ + mux: opts, + dial: []grpc.DialOption{ + grpc.WithInsecure(), + grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("unix", addr, timeout) + }), + }, + echoEndpoint: addr, + abeEndpoint: addr, + flowEndpoint: addr, + }) +} diff --git a/examples/integration_test.go b/examples/integration_test.go index 5cc2ed70e50..a00379795a7 100644 --- a/examples/integration_test.go +++ b/examples/integration_test.go @@ -42,8 +42,13 @@ func TestEcho(t *testing.T) { } func TestForwardResponseOption(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { if err := Run( + ctx, ":8081", runtime.WithForwardResponseOption( func(_ context.Context, w http.ResponseWriter, _ proto.Message) error { diff --git a/examples/main.go b/examples/main.go index 18942b3276c..bb2b55614da 100644 --- a/examples/main.go +++ b/examples/main.go @@ -1,49 +1,24 @@ package main import ( + "context" "flag" + "fmt" "net/http" "path" "strings" - "context" "github.com/golang/glog" - "github.com/grpc-ecosystem/grpc-gateway/examples/examplepb" + "github.com/grpc-ecosystem/grpc-gateway/examples/gateway" "github.com/grpc-ecosystem/grpc-gateway/runtime" - "google.golang.org/grpc" ) var ( - echoEndpoint = flag.String("echo_endpoint", "localhost:9090", "endpoint of EchoService") - abeEndpoint = flag.String("more_endpoint", "localhost:9090", "endpoint of ABitOfEverythingService") - flowEndpoint = flag.String("flow_endpoint", "localhost:9090", "endpoint of FlowCombination") - + endpoint = flag.String("endpoint", "localhost:9090", "endpoint of the gRPC service") + network = flag.String("network", "tcp", `one of "tcp" or "unix". Must be consistent to -endpoint`) swaggerDir = flag.String("swagger_dir", "examples/examplepb", "path to the directory which contains swagger definitions") ) -// newGateway returns a new gateway server which translates HTTP into gRPC. -func newGateway(ctx context.Context, opts ...runtime.ServeMuxOption) (http.Handler, error) { - mux := runtime.NewServeMux(opts...) - dialOpts := []grpc.DialOption{grpc.WithInsecure()} - err := examplepb.RegisterEchoServiceHandlerFromEndpoint(ctx, mux, *echoEndpoint, dialOpts) - if err != nil { - return nil, err - } - err = examplepb.RegisterStreamServiceHandlerFromEndpoint(ctx, mux, *abeEndpoint, dialOpts) - if err != nil { - return nil, err - } - err = examplepb.RegisterABitOfEverythingServiceHandlerFromEndpoint(ctx, mux, *abeEndpoint, dialOpts) - if err != nil { - return nil, err - } - err = examplepb.RegisterFlowCombinationHandlerFromEndpoint(ctx, mux, *flowEndpoint, dialOpts) - if err != nil { - return nil, err - } - return mux, nil -} - func serveSwagger(w http.ResponseWriter, r *http.Request) { if !strings.HasSuffix(r.URL.Path, ".swagger.json") { glog.Errorf("Not Found: %s", r.URL.Path) @@ -81,9 +56,19 @@ func preflightHandler(w http.ResponseWriter, r *http.Request) { return } +func newGateway(ctx context.Context, opts ...runtime.ServeMuxOption) (http.Handler, error) { + switch *network { + case "tcp": + return gateway.NewTCPGateway(ctx, *endpoint, opts...) + case "unix": + return gateway.NewUnixGateway(ctx, *endpoint, opts...) + default: + return nil, fmt.Errorf("unsupported network type %q:", *network) + } +} + // Run starts a HTTP server and blocks forever if successful. -func Run(address string, opts ...runtime.ServeMuxOption) error { - ctx := context.Background() +func Run(ctx context.Context, address string, opts ...runtime.ServeMuxOption) error { ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -96,14 +81,32 @@ func Run(address string, opts ...runtime.ServeMuxOption) error { } mux.Handle("/", gw) - return http.ListenAndServe(address, allowCORS(mux)) + s := &http.Server{ + Addr: address, + Handler: allowCORS(mux), + } + go func() { + <-ctx.Done() + glog.Infof("Shutting down the http server") + if err := s.Shutdown(context.Background()); err != nil { + glog.Errorf("Failed to shutdown http server: %v", err) + } + }() + + glog.Infof("Starting listening at %s", address) + if err := s.ListenAndServe(); err != http.ErrServerClosed { + glog.Errorf("Failed to listen and serve: %v", err) + return err + } + return nil } func main() { flag.Parse() defer glog.Flush() - if err := Run(":8080"); err != nil { + ctx := context.Background() + if err := Run(ctx, ":8080"); err != nil { glog.Fatal(err) } } diff --git a/examples/main_test.go b/examples/main_test.go index 2742c385bb2..9c98c6f7a4e 100644 --- a/examples/main_test.go +++ b/examples/main_test.go @@ -7,18 +7,20 @@ import ( "testing" "time" + "github.com/golang/glog" server "github.com/grpc-ecosystem/grpc-gateway/examples/server" + "golang.org/x/net/context" ) -func runServers() <-chan error { +func runServers(ctx context.Context) <-chan error { ch := make(chan error, 2) go func() { - if err := server.Run(); err != nil { + if err := server.Run(ctx, *network, *endpoint); err != nil { ch <- fmt.Errorf("cannot run grpc service: %v", err) } }() go func() { - if err := Run(":8080"); err != nil { + if err := Run(ctx, ":8080"); err != nil { ch <- fmt.Errorf("cannot run gateway service: %v", err) } }() @@ -27,7 +29,11 @@ func runServers() <-chan error { func TestMain(m *testing.M) { flag.Parse() - errCh := runServers() + defer glog.Flush() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + errCh := runServers(ctx) ch := make(chan int, 1) go func() { @@ -40,6 +46,7 @@ func TestMain(m *testing.M) { fmt.Fprintln(os.Stderr, err) os.Exit(1) case status := <-ch: + cancel() os.Exit(status) } } diff --git a/examples/proto_error_test.go b/examples/proto_error_test.go index de3b638f736..df0a99487b8 100644 --- a/examples/proto_error_test.go +++ b/examples/proto_error_test.go @@ -10,24 +10,32 @@ import ( "github.com/golang/protobuf/jsonpb" "github.com/grpc-ecosystem/grpc-gateway/runtime" + "golang.org/x/net/context" spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" ) +func runServer(ctx context.Context, t *testing.T, port uint16) { + opt := runtime.WithProtoErrorHandler(runtime.DefaultHTTPProtoErrorHandler) + if err := Run(ctx, fmt.Sprintf(":%d", port), opt); err != nil { + t.Errorf("gw.Run() failed with %v; want success", err) + } +} + func TestWithProtoErrorHandler(t *testing.T) { - go func() { - if err := Run( - ":8082", - runtime.WithProtoErrorHandler(runtime.DefaultHTTPProtoErrorHandler), - ); err != nil { - t.Errorf("gw.Run() failed with %v; want success", err) - return - } - }() + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + const port = 8082 + go runServer(ctx, t, port) + // Waiting for the server's getting available. + // TODO(yugui) find a better way to wait time.Sleep(100 * time.Millisecond) - testEcho(t, 8082, "application/json") - testEchoBody(t, 8082) + + testEcho(t, port, "application/json") + testEchoBody(t, port) } func TestABEWithProtoErrorHandler(t *testing.T) { @@ -36,19 +44,29 @@ func TestABEWithProtoErrorHandler(t *testing.T) { return } - testABECreate(t, 8082) - testABECreateBody(t, 8082) - testABEBulkCreate(t, 8082) - testABELookup(t, 8082) - testABELookupNotFoundWithProtoError(t) - testABEList(t, 8082) - testABEBulkEcho(t, 8082) - testABEBulkEchoZeroLength(t, 8082) - testAdditionalBindings(t, 8082) + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + const port = 8083 + go runServer(ctx, t, port) + // Waiting for the server's getting available. + // TODO(yugui) find a better way to wait + time.Sleep(100 * time.Millisecond) + + testABECreate(t, port) + testABECreateBody(t, port) + testABEBulkCreate(t, port) + testABELookup(t, port) + testABELookupNotFoundWithProtoError(t, port) + testABEList(t, port) + testABEBulkEcho(t, port) + testABEBulkEchoZeroLength(t, port) + testAdditionalBindings(t, port) } -func testABELookupNotFoundWithProtoError(t *testing.T) { - url := "http://localhost:8082/v1/example/a_bit_of_everything" +func testABELookupNotFoundWithProtoError(t *testing.T, port uint16) { + url := fmt.Sprintf("http://localhost:%d/v1/example/a_bit_of_everything", port) uuid := "not_exist" url = fmt.Sprintf("%s/%s", url, uuid) resp, err := http.Get(url) @@ -98,7 +116,18 @@ func testABELookupNotFoundWithProtoError(t *testing.T) { } func TestUnknownPathWithProtoError(t *testing.T) { - url := "http://localhost:8082" + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + const port = 8084 + go runServer(ctx, t, port) + + // Waiting for the server's getting available. + // TODO(yugui) find a better way to wait + time.Sleep(100 * time.Millisecond) + + url := fmt.Sprintf("http://localhost:%d", port) resp, err := http.Post(url, "application/json", strings.NewReader("{}")) if err != nil { t.Errorf("http.Post(%q) failed with %v; want success", url, err) @@ -134,7 +163,18 @@ func TestUnknownPathWithProtoError(t *testing.T) { } func TestMethodNotAllowedWithProtoError(t *testing.T) { - url := "http://localhost:8082/v1/example/echo/myid" + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + const port = 8085 + go runServer(ctx, t, port) + + // Waiting for the server's getting available. + // TODO(yugui) find a better way to wait + time.Sleep(100 * time.Millisecond) + + url := fmt.Sprintf("http://localhost:%d/v1/example/echo/myid", port) resp, err := http.Get(url) if err != nil { t.Errorf("http.Post(%q) failed with %v; want success", url, err) diff --git a/examples/server/BUILD.bazel b/examples/server/BUILD.bazel index 0982db9636f..8c34fc6d6c1 100644 --- a/examples/server/BUILD.bazel +++ b/examples/server/BUILD.bazel @@ -25,5 +25,6 @@ go_library( "@org_golang_google_grpc//codes:go_default_library", "@org_golang_google_grpc//metadata:go_default_library", "@org_golang_google_grpc//status:go_default_library", + "@org_golang_x_net//context:go_default_library", ], ) diff --git a/examples/server/cmd/example-server/main.go b/examples/server/cmd/example-server/main.go index 34b319ab4ed..31e182c4b12 100644 --- a/examples/server/cmd/example-server/main.go +++ b/examples/server/cmd/example-server/main.go @@ -1,17 +1,24 @@ package main import ( + "context" "flag" "github.com/golang/glog" "github.com/grpc-ecosystem/grpc-gateway/examples/server" ) +var ( + addr = flag.String("addr", ":9090", "endpoint of the gRPC service") + network = flag.String("network", "tcp", "a valid network type which is consistent to -addr") +) + func main() { flag.Parse() defer glog.Flush() - if err := server.Run(); err != nil { + ctx := context.Background() + if err := server.Run(ctx, *network, *addr); err != nil { glog.Fatal(err) } } diff --git a/examples/server/main.go b/examples/server/main.go index c5e6cb6f97f..a24f16b7cab 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -3,15 +3,25 @@ package server import ( "net" + "github.com/golang/glog" examples "github.com/grpc-ecosystem/grpc-gateway/examples/examplepb" + "golang.org/x/net/context" "google.golang.org/grpc" ) -func Run() error { - l, err := net.Listen("tcp", ":9090") +// Run starts the example gRPC service. +// "network" and "address" are passed to net.Listen. +func Run(ctx context.Context, network, address string) error { + l, err := net.Listen(network, address) if err != nil { return err } + defer func() { + if err := l.Close(); err != nil { + glog.Errorf("Failed to close %s %s: %v", network, address, err) + } + }() + s := grpc.NewServer() examples.RegisterEchoServiceServer(s, newEchoServer()) examples.RegisterFlowCombinationServer(s, newFlowCombinationServer()) @@ -20,6 +30,9 @@ func Run() error { examples.RegisterABitOfEverythingServiceServer(s, abe) examples.RegisterStreamServiceServer(s, abe) - s.Serve(l) - return nil + go func() { + defer s.GracefulStop() + <-ctx.Done() + }() + return s.Serve(l) }