From 20ece66eda32c0f957b09acfced3a43deaebf3e9 Mon Sep 17 00:00:00 2001 From: Marek Siarkowicz Date: Thu, 2 Mar 2023 12:31:13 +0100 Subject: [PATCH 1/6] test: Test etcd watch stream starvation under high read response load when sharing the same connection Signed-off-by: Marek Siarkowicz --- pkg/stringutil/rand.go | 6 +- tests/common/grpc_test.go | 227 +++++++++++++++++++++ tests/framework/config/cluster.go | 15 +- tests/framework/e2e/e2e.go | 1 + tests/framework/integration/cluster.go | 54 +++-- tests/framework/integration/integration.go | 19 +- 6 files changed, 283 insertions(+), 39 deletions(-) create mode 100644 tests/common/grpc_test.go diff --git a/pkg/stringutil/rand.go b/pkg/stringutil/rand.go index a15b0de0c08f..96d9df311cfb 100644 --- a/pkg/stringutil/rand.go +++ b/pkg/stringutil/rand.go @@ -24,7 +24,7 @@ func UniqueStrings(slen uint, n int) (ss []string) { exist := make(map[string]struct{}) ss = make([]string, 0, n) for len(ss) < n { - s := randString(slen) + s := RandString(slen) if _, ok := exist[s]; !ok { ss = append(ss, s) exist[s] = struct{}{} @@ -37,14 +37,14 @@ func UniqueStrings(slen uint, n int) (ss []string) { func RandomStrings(slen uint, n int) (ss []string) { ss = make([]string, 0, n) for i := 0; i < n; i++ { - ss = append(ss, randString(slen)) + ss = append(ss, RandString(slen)) } return ss } const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" -func randString(l uint) string { +func RandString(l uint) string { rand.Seed(time.Now().UnixNano()) s := make([]byte, l) for i := 0; i < int(l); i++ { diff --git a/tests/common/grpc_test.go b/tests/common/grpc_test.go new file mode 100644 index 000000000000..50af5f6dee77 --- /dev/null +++ b/tests/common/grpc_test.go @@ -0,0 +1,227 @@ +// Copyright 2023 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "context" + "fmt" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/pkg/v3/stringutil" + "go.etcd.io/etcd/tests/v3/framework/config" + "go.etcd.io/etcd/tests/v3/framework/integration" + "go.etcd.io/etcd/tests/v3/framework/interfaces" + "golang.org/x/sync/errgroup" +) + +const ( + watchResponsePeriod = 100 * time.Millisecond + watchTestDuration = 5 * time.Second + maxWatchDelay = 100 * time.Millisecond + // Configuration for continuous read load on etcd + numberOfPreexistingKeys = 100 + sizeOfPreexistingValues = 10000 + readLoadConcurrency = 10 +) + +func TestWatchDelayForPeriodicProgressNotification(t *testing.T) { + testRunner.BeforeTest(t) + for _, tc := range clusterTestCases() { + tc := tc + tc.config.WatchProgressNotifyInterval = watchResponsePeriod + t.Run(tc.name, func(t *testing.T) { + clus := testRunner.NewCluster(context.Background(), t, config.WithClusterConfig(tc.config)) + defer clus.Close() + c := client(t, clus, tc.config) + require.NoError(t, fillEtcdWithData(context.Background(), c, numberOfPreexistingKeys, sizeOfPreexistingValues)) + + ctx, cancel := context.WithTimeout(context.Background(), watchTestDuration) + defer cancel() + g := errgroup.Group{} + continuouslyExecuteGetAll(ctx, t, &g, c) + validateWatchDelay(t, c.Watch(ctx, "fake-key", clientv3.WithProgressNotify())) + require.NoError(t, g.Wait()) + }) + } +} + +func TestWatchDelayForManualProgressNotification(t *testing.T) { + testRunner.BeforeTest(t) + for _, tc := range clusterTestCases() { + t.Run(tc.name, func(t *testing.T) { + clus := testRunner.NewCluster(context.Background(), t, config.WithClusterConfig(tc.config)) + defer clus.Close() + c := client(t, clus, tc.config) + require.NoError(t, fillEtcdWithData(context.Background(), c, numberOfPreexistingKeys, sizeOfPreexistingValues)) + + ctx, cancel := context.WithTimeout(context.Background(), watchTestDuration) + defer cancel() + g := errgroup.Group{} + continuouslyExecuteGetAll(ctx, t, &g, c) + g.Go(func() error { + for { + err := c.RequestProgress(ctx) + if err != nil { + // Cannot use error.Is(err, context.DeadlineExceeded) as GRPC can wrap it like status.New(status.Unknown, context.DeadlineExceeded) + if strings.Contains(err.Error(), "context deadline exceeded") { + return nil + } else { + return err + } + } + time.Sleep(watchResponsePeriod) + } + }) + validateWatchDelay(t, c.Watch(ctx, "fake-key")) + require.NoError(t, g.Wait()) + }) + } +} + +func TestWatchDelayForEvent(t *testing.T) { + testRunner.BeforeTest(t) + for _, tc := range clusterTestCases() { + t.Run(tc.name, func(t *testing.T) { + clus := testRunner.NewCluster(context.Background(), t, config.WithClusterConfig(tc.config)) + defer clus.Close() + c := client(t, clus, tc.config) + require.NoError(t, fillEtcdWithData(context.Background(), c, numberOfPreexistingKeys, sizeOfPreexistingValues)) + + ctx, cancel := context.WithTimeout(context.Background(), watchTestDuration) + defer cancel() + g := errgroup.Group{} + g.Go(func() error { + i := 0 + for { + _, err := c.Put(ctx, "key", fmt.Sprintf("%d", i)) + if err != nil { + // Cannot use error.Is(err, context.DeadlineExceeded) as GRPC can wrap it like status.New(status.Unknown, context.DeadlineExceeded) + if strings.Contains(err.Error(), "context deadline exceeded") { + return nil + } else { + return err + } + } + time.Sleep(watchResponsePeriod) + } + }) + continuouslyExecuteGetAll(ctx, t, &g, c) + validateWatchDelay(t, c.Watch(ctx, "key")) + require.NoError(t, g.Wait()) + }) + } +} + +func validateWatchDelay(t *testing.T, watch clientv3.WatchChan) { + start := time.Now() + var maxDelay time.Duration + for range watch { + sinceLast := time.Since(start) + if sinceLast > watchResponsePeriod+maxWatchDelay { + t.Errorf("Unexpected watch response delayed over allowed threshold %s, delay: %s", maxWatchDelay, sinceLast-watchResponsePeriod) + } else { + t.Logf("Got watch response, since last: %s", sinceLast) + } + if sinceLast > maxDelay { + maxDelay = sinceLast + } + start = time.Now() + } + sinceLast := time.Since(start) + if sinceLast > maxDelay && sinceLast > watchResponsePeriod+maxWatchDelay { + t.Errorf("Unexpected watch response delayed over allowed threshold %s, delay: unknown", maxWatchDelay) + t.Errorf("Test finished while in middle of delayed response, measured delay: %s", sinceLast-watchResponsePeriod) + t.Logf("Please increase the test duration to measure delay") + } else { + t.Logf("Max delay: %s", maxDelay-watchResponsePeriod) + } +} + +func fillEtcdWithData(ctx context.Context, c *clientv3.Client, keyCount int, valueSize uint) error { + g := errgroup.Group{} + concurrency := 10 + keysPerRoutine := keyCount / concurrency + for i := 0; i < concurrency; i++ { + i := i + g.Go(func() error { + for j := 0; j < keysPerRoutine; j++ { + _, err := c.Put(ctx, fmt.Sprintf("%d", i*keysPerRoutine+j), stringutil.RandString(valueSize)) + if err != nil { + return err + } + } + return nil + }) + } + return g.Wait() +} + +func continuouslyExecuteGetAll(ctx context.Context, t *testing.T, g *errgroup.Group, c *clientv3.Client) { + mux := sync.RWMutex{} + size := 0 + for i := 0; i < readLoadConcurrency; i++ { + g.Go(func() error { + for { + _, err := c.Get(ctx, "", clientv3.WithPrefix()) + if err != nil { + // Cannot use error.Is(err, context.DeadlineExceeded) as GRPC can wrap it like status.New(status.Unknown, context.DeadlineExceeded) + if strings.Contains(err.Error(), "context deadline exceeded") { + return nil + } else { + return err + } + } + mux.Lock() + size += numberOfPreexistingKeys * sizeOfPreexistingValues + mux.Unlock() + } + }) + } + g.Go(func() error { + lastSize := size + for range time.Tick(time.Second) { + select { + case <-ctx.Done(): + return nil + default: + } + mux.RLock() + t.Logf("Generating read load around %.1f MB/s", float64(size-lastSize)/1000/1000) + lastSize = size + mux.RUnlock() + } + return nil + }) +} + +func client(t *testing.T, clus interfaces.Cluster, cfg config.ClusterConfig) *clientv3.Client { + tls, err := integration.TlsInfo(t, cfg.ClientTLS) + if err != nil { + t.Fatal(err) + } + c, err := integration.Client(clus.Endpoints(), tls) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + c.Close() + }) + return c +} diff --git a/tests/framework/config/cluster.go b/tests/framework/config/cluster.go index a8c23d7082cc..facbfce1f5a6 100644 --- a/tests/framework/config/cluster.go +++ b/tests/framework/config/cluster.go @@ -29,13 +29,14 @@ const ( ) type ClusterConfig struct { - ClusterSize int - PeerTLS TLSConfig - ClientTLS TLSConfig - QuotaBackendBytes int64 - StrictReconfigCheck bool - AuthToken string - SnapshotCount int + ClusterSize int + PeerTLS TLSConfig + ClientTLS TLSConfig + QuotaBackendBytes int64 + StrictReconfigCheck bool + AuthToken string + SnapshotCount int + WatchProgressNotifyInterval time.Duration // ClusterContext is used by "e2e" or "integration" to extend the // ClusterConfig. The common test cases shouldn't care about what diff --git a/tests/framework/e2e/e2e.go b/tests/framework/e2e/e2e.go index 1c3b1830335c..44cce9498b4d 100644 --- a/tests/framework/e2e/e2e.go +++ b/tests/framework/e2e/e2e.go @@ -51,6 +51,7 @@ func (e e2eRunner) NewCluster(ctx context.Context, t testing.TB, opts ...config. WithStrictReconfigCheck(cfg.StrictReconfigCheck), WithAuthTokenOpts(cfg.AuthToken), WithSnapshotCount(cfg.SnapshotCount), + WithWatchProcessNotifyInterval(cfg.WatchProgressNotifyInterval), ) if cfg.ClusterContext != nil { diff --git a/tests/framework/integration/cluster.go b/tests/framework/integration/cluster.go index 1c26d0700747..580c5e533325 100644 --- a/tests/framework/integration/cluster.go +++ b/tests/framework/integration/cluster.go @@ -1439,14 +1439,13 @@ func (c *Cluster) Endpoints() []string { } func (c *Cluster) ClusterClient(t testing.TB, opts ...framecfg.ClientOption) (client *clientv3.Client, err error) { - cfg, err := c.newClientCfg() - if err != nil { - return nil, err + if c.Cfg.ClientMaxCallSendMsgSize != 0 { + opts = append(opts, WithClientMaxCallSendMsgSize(c.Cfg.ClientMaxCallSendMsgSize)) } - for _, opt := range opts { - opt(cfg) + if c.Cfg.ClientMaxCallRecvMsgSize != 0 { + opts = append(opts, WithClientMaxCallRecvMsgSize(c.Cfg.ClientMaxCallRecvMsgSize)) } - client, err = newClientV3(*cfg) + client, err = Client(c.Endpoints(), c.Cfg.ClientTLS, opts...) if err != nil { return nil, err } @@ -1456,6 +1455,25 @@ func (c *Cluster) ClusterClient(t testing.TB, opts ...framecfg.ClientOption) (cl return client, nil } +func Client(endpoints []string, tlscfg *transport.TLSInfo, opts ...framecfg.ClientOption) (client *clientv3.Client, err error) { + cfg := &clientv3.Config{ + Endpoints: endpoints, + DialTimeout: 5 * time.Second, + DialOptions: []grpc.DialOption{grpc.WithBlock()}, + } + if tlscfg != nil { + tls, err := tlscfg.ClientConfig() + if err != nil { + return nil, err + } + cfg.TLS = tls + } + for _, opt := range opts { + opt(cfg) + } + return newClientV3(*cfg) +} + func WithAuth(userName, password string) framecfg.ClientOption { return func(c any) { cfg := c.(*clientv3.Config) @@ -1464,22 +1482,18 @@ func WithAuth(userName, password string) framecfg.ClientOption { } } -func (c *Cluster) newClientCfg() (*clientv3.Config, error) { - cfg := &clientv3.Config{ - Endpoints: c.Endpoints(), - DialTimeout: 5 * time.Second, - DialOptions: []grpc.DialOption{grpc.WithBlock()}, - MaxCallSendMsgSize: c.Cfg.ClientMaxCallSendMsgSize, - MaxCallRecvMsgSize: c.Cfg.ClientMaxCallRecvMsgSize, +func WithClientMaxCallSendMsgSize(value int) framecfg.ClientOption { + return func(c any) { + cfg := c.(*clientv3.Config) + cfg.MaxCallSendMsgSize = value } - if c.Cfg.ClientTLS != nil { - tls, err := c.Cfg.ClientTLS.ClientConfig() - if err != nil { - return nil, err - } - cfg.TLS = tls +} + +func WithClientMaxCallRecvMsgSize(value int) framecfg.ClientOption { + return func(c any) { + cfg := c.(*clientv3.Config) + cfg.MaxCallRecvMsgSize = value } - return cfg, nil } // NewClientV3 creates a new grpc client connection to the member diff --git a/tests/framework/integration/integration.go b/tests/framework/integration/integration.go index 8d5f786e7177..330311c7bab3 100644 --- a/tests/framework/integration/integration.go +++ b/tests/framework/integration/integration.go @@ -51,17 +51,18 @@ func (e integrationRunner) NewCluster(ctx context.Context, t testing.TB, opts .. var err error cfg := config.NewClusterConfig(opts...) integrationCfg := ClusterConfig{ - Size: cfg.ClusterSize, - QuotaBackendBytes: cfg.QuotaBackendBytes, - DisableStrictReconfigCheck: !cfg.StrictReconfigCheck, - AuthToken: cfg.AuthToken, - SnapshotCount: uint64(cfg.SnapshotCount), - } - integrationCfg.ClientTLS, err = tlsInfo(t, cfg.ClientTLS) + Size: cfg.ClusterSize, + QuotaBackendBytes: cfg.QuotaBackendBytes, + DisableStrictReconfigCheck: !cfg.StrictReconfigCheck, + AuthToken: cfg.AuthToken, + SnapshotCount: uint64(cfg.SnapshotCount), + WatchProgressNotifyInterval: cfg.WatchProgressNotifyInterval, + } + integrationCfg.ClientTLS, err = TlsInfo(t, cfg.ClientTLS) if err != nil { t.Fatalf("ClientTLS: %s", err) } - integrationCfg.PeerTLS, err = tlsInfo(t, cfg.PeerTLS) + integrationCfg.PeerTLS, err = TlsInfo(t, cfg.PeerTLS) if err != nil { t.Fatalf("PeerTLS: %s", err) } @@ -72,7 +73,7 @@ func (e integrationRunner) NewCluster(ctx context.Context, t testing.TB, opts .. } } -func tlsInfo(t testing.TB, cfg config.TLSConfig) (*transport.TLSInfo, error) { +func TlsInfo(t testing.TB, cfg config.TLSConfig) (*transport.TLSInfo, error) { switch cfg { case config.NoTLS: return nil, nil From 9c922eaa696621b22e59a54d0fe173da615a030c Mon Sep 17 00:00:00 2001 From: Marek Siarkowicz Date: Fri, 3 Mar 2023 14:32:29 +0100 Subject: [PATCH 2/6] server: Refactor common parts between secure and insecure serving Signed-off-by: Marek Siarkowicz --- server/embed/serve.go | 71 +++++++++++++++++++------------------------ 1 file changed, 32 insertions(+), 39 deletions(-) diff --git a/server/embed/serve.go b/server/embed/serve.go index 7fff618a687c..dbd19946edb0 100644 --- a/server/embed/serve.go +++ b/server/embed/serve.go @@ -16,6 +16,7 @@ package embed import ( "context" + "crypto/tls" "fmt" "io" defaultLog "log" @@ -114,25 +115,30 @@ func (sctx *serveCtx) serve( servElection := v3election.NewElectionServer(v3c) servLock := v3lock.NewLockServer(v3c) - var gs *grpc.Server - defer func() { - if err != nil && gs != nil { - sctx.lg.Warn("stopping grpc server due to error", zap.Error(err)) - gs.Stop() - sctx.lg.Warn("stopped grpc server due to error", zap.Error(err)) - } - }() + var tlscfg *tls.Config // Make sure serversC is closed even if we prematurely exit the function. defer close(sctx.serversC) + if sctx.secure { + tlscfg, err = tlsinfo.ServerConfig() + if err != nil { + return err + } + } + gs := v3rpc.Server(s, tlscfg, nil, gopts...) + defer func() { + sctx.lg.Warn("stopping grpc server due to error", zap.Error(err)) + gs.Stop() + sctx.lg.Warn("stopped grpc server due to error", zap.Error(err)) + }() + v3electionpb.RegisterElectionServer(gs, servElection) + v3lockpb.RegisterLockServer(gs, servLock) + if sctx.serviceRegister != nil { + sctx.serviceRegister(gs) + } + var srv *http.Server if sctx.insecure { - gs = v3rpc.Server(s, nil, nil, gopts...) - v3electionpb.RegisterElectionServer(gs, servElection) - v3lockpb.RegisterLockServer(gs, servLock) - if sctx.serviceRegister != nil { - sctx.serviceRegister(gs) - } grpcl := m.Match(cmux.HTTP2()) go func() { errHandler(gs.Serve(grpcl)) }() @@ -147,35 +153,19 @@ func (sctx *serveCtx) serve( httpmux := sctx.createMux(gwmux, handler) - srvhttp := &http.Server{ + srv = &http.Server{ Handler: createAccessController(sctx.lg, s, httpmux), ErrorLog: logger, // do not log user error } - if err := configureHttpServer(srvhttp, s.Cfg); err != nil { + if err := configureHttpServer(srv, s.Cfg); err != nil { sctx.lg.Error("Configure http server failed", zap.Error(err)) return err } httpl := m.Match(cmux.HTTP1()) - go func() { errHandler(srvhttp.Serve(httpl)) }() - - sctx.serversC <- &servers{grpc: gs, http: srvhttp} - sctx.lg.Info( - "serving client traffic insecurely; this is strongly discouraged!", - zap.String("address", sctx.l.Addr().String()), - ) + go func() { errHandler(srv.Serve(httpl)) }() } if sctx.secure { - tlscfg, tlsErr := tlsinfo.ServerConfig() - if tlsErr != nil { - return tlsErr - } - gs = v3rpc.Server(s, tlscfg, nil, gopts...) - v3electionpb.RegisterElectionServer(gs, servElection) - v3lockpb.RegisterLockServer(gs, servLock) - if sctx.serviceRegister != nil { - sctx.serviceRegister(gs) - } handler = grpcHandlerFunc(gs, handler) var gwmux *gw.ServeMux @@ -199,7 +189,7 @@ func (sctx *serveCtx) serve( // TODO: add debug flag; enable logging when debug flag is set httpmux := sctx.createMux(gwmux, handler) - srv := &http.Server{ + srv = &http.Server{ Handler: createAccessController(sctx.lg, s, httpmux), TLSConfig: tlscfg, ErrorLog: logger, // do not log user error @@ -210,13 +200,16 @@ func (sctx *serveCtx) serve( } go func() { errHandler(srv.Serve(tlsl)) }() - sctx.serversC <- &servers{secure: true, grpc: gs, http: srv} - sctx.lg.Info( - "serving client traffic securely", - zap.String("address", sctx.l.Addr().String()), - ) } + sctx.serversC <- &servers{secure: sctx.secure, grpc: gs, http: srv} + + msg := "serving client traffic securely" + if sctx.insecure { + msg = "serving client traffic insecurely; this is strongly discouraged!" + } + sctx.lg.Info(msg, zap.String("address", sctx.l.Addr().String())) + return m.Serve() } From 6bdb77cd181a0c0d0bb699320146640e673ab232 Mon Sep 17 00:00:00 2001 From: Marek Siarkowicz Date: Fri, 3 Mar 2023 14:43:47 +0100 Subject: [PATCH 3/6] server: Refactor common grpc gateway code Signed-off-by: Marek Siarkowicz --- server/embed/serve.go | 46 +++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/server/embed/serve.go b/server/embed/serve.go index dbd19946edb0..e93aa20a4d5b 100644 --- a/server/embed/serve.go +++ b/server/embed/serve.go @@ -28,7 +28,7 @@ import ( etcdservergw "go.etcd.io/etcd/api/v3/etcdserverpb/gw" "go.etcd.io/etcd/client/pkg/v3/transport" - "go.etcd.io/etcd/client/v3/credentials" + clientcreds "go.etcd.io/etcd/client/v3/credentials" "go.etcd.io/etcd/pkg/v3/debugutil" "go.etcd.io/etcd/pkg/v3/httputil" "go.etcd.io/etcd/server/v3/config" @@ -49,6 +49,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/trace" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" ) @@ -137,20 +138,32 @@ func (sctx *serveCtx) serve( sctx.serviceRegister(gs) } var srv *http.Server + var gwmux *gw.ServeMux + if s.Cfg.EnableGRPCGateway { + var creds credentials.TransportCredentials + if sctx.insecure { + creds = insecure.NewCredentials() + } + if sctx.secure { + if s.Cfg.EnableGRPCGateway { + dtls := tlscfg.Clone() + // trust local server + dtls.InsecureSkipVerify = true + bundle := clientcreds.NewBundle(clientcreds.Config{TLSConfig: dtls}) + creds = bundle.TransportCredentials() + } + } + gwmux, err = sctx.registerGateway([]grpc.DialOption{grpc.WithTransportCredentials(creds)}) + if err != nil { + sctx.lg.Error("registerGateway failed", zap.Error(err)) + return err + } + } if sctx.insecure { grpcl := m.Match(cmux.HTTP2()) go func() { errHandler(gs.Serve(grpcl)) }() - var gwmux *gw.ServeMux - if s.Cfg.EnableGRPCGateway { - gwmux, err = sctx.registerGateway([]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}) - if err != nil { - sctx.lg.Error("registerGateway failed", zap.Error(err)) - return err - } - } - httpmux := sctx.createMux(gwmux, handler) srv = &http.Server{ @@ -168,19 +181,6 @@ func (sctx *serveCtx) serve( if sctx.secure { handler = grpcHandlerFunc(gs, handler) - var gwmux *gw.ServeMux - if s.Cfg.EnableGRPCGateway { - dtls := tlscfg.Clone() - // trust local server - dtls.InsecureSkipVerify = true - bundle := credentials.NewBundle(credentials.Config{TLSConfig: dtls}) - opts := []grpc.DialOption{grpc.WithTransportCredentials(bundle.TransportCredentials())} - gwmux, err = sctx.registerGateway(opts) - if err != nil { - return err - } - } - var tlsl net.Listener tlsl, err = transport.NewTLSListener(m.Match(cmux.Any()), tlsinfo) if err != nil { From 72a0c2965585059da6b364cdc75fe771a6ab9fae Mon Sep 17 00:00:00 2001 From: Marek Siarkowicz Date: Fri, 3 Mar 2023 14:55:32 +0100 Subject: [PATCH 4/6] server: Refactor common part of secure and insecure serve Signed-off-by: Marek Siarkowicz --- server/embed/serve.go | 54 +++++++++++++++++-------------------------- 1 file changed, 21 insertions(+), 33 deletions(-) diff --git a/server/embed/serve.go b/server/embed/serve.go index e93aa20a4d5b..7a0c44021470 100644 --- a/server/embed/serve.go +++ b/server/embed/serve.go @@ -111,7 +111,6 @@ func (sctx *serveCtx) serve( sctx.lg.Info("ready to serve client requests") - m := cmux.New(sctx.l) v3c := v3client.New(s) servElection := v3election.NewElectionServer(v3c) servLock := v3lock.NewLockServer(v3c) @@ -159,47 +158,36 @@ func (sctx *serveCtx) serve( return err } } - + listener := sctx.l + if sctx.secure { + handler = grpcHandlerFunc(gs, handler) + listener, err = transport.NewTLSListener(listener, tlsinfo) + if err != nil { + return err + } + } + m := cmux.New(listener) + // TODO: add debug flag; enable logging when debug flag is set + httpmux := sctx.createMux(gwmux, handler) + srv = &http.Server{ + Handler: createAccessController(sctx.lg, s, httpmux), + TLSConfig: tlscfg, + ErrorLog: logger, // do not log user error + } + if err := configureHttpServer(srv, s.Cfg); err != nil { + sctx.lg.Error("Configure https server failed", zap.Error(err)) + return err + } if sctx.insecure { grpcl := m.Match(cmux.HTTP2()) go func() { errHandler(gs.Serve(grpcl)) }() - httpmux := sctx.createMux(gwmux, handler) - - srv = &http.Server{ - Handler: createAccessController(sctx.lg, s, httpmux), - ErrorLog: logger, // do not log user error - } - if err := configureHttpServer(srv, s.Cfg); err != nil { - sctx.lg.Error("Configure http server failed", zap.Error(err)) - return err - } httpl := m.Match(cmux.HTTP1()) go func() { errHandler(srv.Serve(httpl)) }() } if sctx.secure { - handler = grpcHandlerFunc(gs, handler) - - var tlsl net.Listener - tlsl, err = transport.NewTLSListener(m.Match(cmux.Any()), tlsinfo) - if err != nil { - return err - } - // TODO: add debug flag; enable logging when debug flag is set - httpmux := sctx.createMux(gwmux, handler) - - srv = &http.Server{ - Handler: createAccessController(sctx.lg, s, httpmux), - TLSConfig: tlscfg, - ErrorLog: logger, // do not log user error - } - if err := configureHttpServer(srv, s.Cfg); err != nil { - sctx.lg.Error("Configure https server failed", zap.Error(err)) - return err - } - go func() { errHandler(srv.Serve(tlsl)) }() - + go func() { errHandler(srv.Serve(listener)) }() } sctx.serversC <- &servers{secure: sctx.secure, grpc: gs, http: srv} From cc7640403a301c8efbbb6791df92d98f81f86132 Mon Sep 17 00:00:00 2001 From: Marek Siarkowicz Date: Tue, 7 Mar 2023 16:07:05 +0100 Subject: [PATCH 5/6] server: Fix cmux shutdown --- server/embed/etcd.go | 1 + server/embed/serve.go | 12 ++++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/server/embed/etcd.go b/server/embed/etcd.go index a03b4f1c9fe9..1e14eed12466 100644 --- a/server/embed/etcd.go +++ b/server/embed/etcd.go @@ -455,6 +455,7 @@ func (e *Etcd) Close() { } func stopServers(ctx context.Context, ss *servers) { + ss.cmux.Close() // first, close the http.Server ss.http.Shutdown(ctx) // do not grpc.Server.GracefulStop with TLS enabled etcd server diff --git a/server/embed/serve.go b/server/embed/serve.go index 7a0c44021470..ad725cd8bc40 100644 --- a/server/embed/serve.go +++ b/server/embed/serve.go @@ -73,6 +73,7 @@ type servers struct { secure bool grpc *grpc.Server http *http.Server + cmux cmux.CMux } func newServeCtx(lg *zap.Logger) *serveCtx { @@ -118,7 +119,12 @@ func (sctx *serveCtx) serve( var tlscfg *tls.Config // Make sure serversC is closed even if we prematurely exit the function. - defer close(sctx.serversC) + var closed bool + defer func() { + if !closed { + close(sctx.serversC) + } + }() if sctx.secure { tlscfg, err = tlsinfo.ServerConfig() if err != nil { @@ -190,7 +196,9 @@ func (sctx *serveCtx) serve( go func() { errHandler(srv.Serve(listener)) }() } - sctx.serversC <- &servers{secure: sctx.secure, grpc: gs, http: srv} + sctx.serversC <- &servers{secure: sctx.secure, grpc: gs, http: srv, cmux: m} + close(sctx.serversC) + closed = true msg := "serving client traffic securely" if sctx.insecure { From 17a73ea4f788b845bfb7a230ab81fd033b131228 Mon Sep 17 00:00:00 2001 From: Marek Siarkowicz Date: Fri, 3 Mar 2023 15:02:38 +0100 Subject: [PATCH 6/6] Fix #15402 by moving grpc server from under http server Signed-off-by: Marek Siarkowicz --- server/embed/serve.go | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/server/embed/serve.go b/server/embed/serve.go index ad725cd8bc40..d8022e3da944 100644 --- a/server/embed/serve.go +++ b/server/embed/serve.go @@ -131,7 +131,7 @@ func (sctx *serveCtx) serve( return err } } - gs := v3rpc.Server(s, tlscfg, nil, gopts...) + gs := v3rpc.Server(s, nil, nil, gopts...) defer func() { sctx.lg.Warn("stopping grpc server due to error", zap.Error(err)) gs.Stop() @@ -166,7 +166,6 @@ func (sctx *serveCtx) serve( } listener := sctx.l if sctx.secure { - handler = grpcHandlerFunc(gs, handler) listener, err = transport.NewTLSListener(listener, tlsinfo) if err != nil { return err @@ -193,7 +192,11 @@ func (sctx *serveCtx) serve( } if sctx.secure { - go func() { errHandler(srv.Serve(listener)) }() + grpcl := m.MatchWithWriters(cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc")) + go func() { errHandler(gs.Serve(grpcl)) }() + + httpl := m.Match(cmux.Any()) + go func() { errHandler(srv.Serve(httpl)) }() } sctx.serversC <- &servers{secure: sctx.secure, grpc: gs, http: srv, cmux: m} @@ -216,23 +219,6 @@ func configureHttpServer(srv *http.Server, cfg config.ServerConfig) error { }) } -// grpcHandlerFunc returns an http.Handler that delegates to grpcServer on incoming gRPC -// connections or otherHandler otherwise. Given in gRPC docs. -func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Handler { - if otherHandler == nil { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - grpcServer.ServeHTTP(w, r) - }) - } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { - grpcServer.ServeHTTP(w, r) - } else { - otherHandler.ServeHTTP(w, r) - } - }) -} - type registerHandlerFunc func(context.Context, *gw.ServeMux, *grpc.ClientConn) error func (sctx *serveCtx) registerGateway(opts []grpc.DialOption) (*gw.ServeMux, error) {