From 9fab3e0969d85d2b24ac1f1279c891d91b95dea7 Mon Sep 17 00:00:00 2001 From: Yuxuan Li Date: Mon, 25 Mar 2019 14:33:16 -0700 Subject: [PATCH] fix reviews again --- clientconn.go | 53 ++++++++++++++++++++++++------------------ clientconn_test.go | 24 +++++++++++++++++++ dialoptions.go | 30 +++++++++++------------- service_config.go | 12 +++------- service_config_test.go | 46 ++++++++++++++++++++---------------- 5 files changed, 96 insertions(+), 69 deletions(-) diff --git a/clientconn.go b/clientconn.go index bea3efe5435a..0c14116049c2 100644 --- a/clientconn.go +++ b/clientconn.go @@ -86,6 +86,9 @@ var ( // errCredentialsConflict indicates that grpc.WithTransportCredentials() // and grpc.WithInsecure() are both called for a connection. errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)") + // errInvalidDefaultServiceConfig indicates that grpc.WithDefaultServiceConfig(string) provides + // an invalid service config string. + errInvalidDefaultServiceConfig = errors.New("grpc: the provided default service config is invalid") ) const ( @@ -173,6 +176,13 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * } } + if cc.dopts.defaultServiceConfigRawJSON != nil { + sc, err := parseServiceConfig(*cc.dopts.defaultServiceConfigRawJSON) + if err != nil { + return nil, errInvalidDefaultServiceConfig + } + cc.dopts.defaultServiceConfig = sc + } cc.mkp = cc.dopts.copts.KeepaliveParams if cc.dopts.copts.Dialer == nil { @@ -457,26 +467,25 @@ func (cc *ClientConn) waitForResolvedAddrs(ctx context.Context) error { } } -// Apply default service config when default service config is configured and: +// gRPC should resort to default service config when: // * resolver service config is disabled // * or, resolver does not return a service config or returns an invalid one. -func (cc *ClientConn) shouldApplyDefaultServiceConfig(sc string) bool { - if cc.dopts.defaultServiceConfig != nil { - if cc.dopts.disableServiceConfig { - return true - } - // The logic below is temporary, will be removed once we change the resolver.State ServiceConfig field type. - // Right now, we assume that empty service config string means resolver does not return a config. - if sc == "" { - return false - } - // TODO: the logic below is temporary. Once we finish the logic to validate service config - // in resolver, we will replace the logic below. - _, err := parseServiceConfig(sc) - if err != nil { - return true - } +func (cc *ClientConn) fallbackToDefaultServiceConfig(sc string) bool { + if cc.dopts.disableServiceConfig { + return true } + // The logic below is temporary, will be removed once we change the resolver.State ServiceConfig field type. + // Right now, we assume that empty service config string means resolver does not return a config. + if sc == "" { + return true + } + // TODO: the logic below is temporary. Once we finish the logic to validate service config + // in resolver, we will replace the logic below. + _, err := parseServiceConfig(sc) + if err != nil { + return true + } + return false } @@ -490,14 +499,14 @@ func (cc *ClientConn) updateResolverState(s resolver.State) error { return nil } - if cc.shouldApplyDefaultServiceConfig(s.ServiceConfig) { - if cc.sc.rawJSONString == nil { + if cc.fallbackToDefaultServiceConfig(s.ServiceConfig) { + if cc.dopts.defaultServiceConfig != nil && cc.sc.rawJSONString == nil { cc.applyServiceConfig(cc.dopts.defaultServiceConfig) } } else { // TODO: the parsing logic below will be moved inside resolver. sc, err := parseServiceConfig(s.ServiceConfig) - if err != nil && s.ServiceConfig != "" { // s.ServiceConfig != "" is a temporary special case. + if err != nil { return err } if cc.sc.rawJSONString == nil || *cc.sc.rawJSONString != s.ServiceConfig { @@ -763,11 +772,9 @@ func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method st return t, done, nil } -// Parse and apply the service config. If sc is passed as a non-nil pointer, which indicates we have -// a parsed service config, we will skip the parsing. It will also skip the whole processing if -// the new service config is the same as the old one. func (cc *ClientConn) applyServiceConfig(sc *ServiceConfig) error { if sc == nil { + // should never reach here. return fmt.Errorf("got nil pointer for service config") } cc.sc = *sc diff --git a/clientconn_test.go b/clientconn_test.go index 84f28f63b3f4..ba6550fe882b 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -1250,8 +1250,10 @@ func (s) TestDefaultServiceConfig(t *testing.T) { } ] }` + testInvalidDefaultServiceConfig(t) testDefaultServiceConfigWhenResolverServiceConfigDisabled(t, r, addr, js) testDefaultServiceConfigWhenResolverDoesNotReturnServiceConfig(t, r, addr, js) + testDefaultServiceConfigWhenResolverReturnInvalidServiceConfig(t, r, addr, js) } func verifyWaitForReadyEqualsTrue(cc *ClientConn) bool { @@ -1266,6 +1268,13 @@ func verifyWaitForReadyEqualsTrue(cc *ClientConn) bool { return i != 10 } +func testInvalidDefaultServiceConfig(t *testing.T) { + _, err := Dial("fake.com", WithInsecure(), WithDefaultServiceConfig("")) + if err != errInvalidDefaultServiceConfig { + t.Fatalf("Dial got err: %v, want: %v", err, errInvalidDefaultServiceConfig) + } +} + func testDefaultServiceConfigWhenResolverServiceConfigDisabled(t *testing.T, r resolver.Resolver, addr string, js string) { cc, err := Dial(addr, WithInsecure(), WithDisableServiceConfig(), WithDefaultServiceConfig(js)) if err != nil { @@ -1295,3 +1304,18 @@ func testDefaultServiceConfigWhenResolverDoesNotReturnServiceConfig(t *testing.T t.Fatal("default service config failed to be applied after 1s") } } + +func testDefaultServiceConfigWhenResolverReturnInvalidServiceConfig(t *testing.T, r resolver.Resolver, addr string, js string) { + cc, err := Dial(addr, WithInsecure(), WithDefaultServiceConfig(js)) + if err != nil { + t.Fatalf("Dial(%s, _) = _, %v, want _, ", addr, err) + } + defer cc.Close() + r.(*manual.Resolver).UpdateState(resolver.State{ + Addresses: []resolver.Address{{Addr: addr}}, + ServiceConfig: "{something wrong,}", + }) + if !verifyWaitForReadyEqualsTrue(cc) { + t.Fatal("default service config failed to be applied after 1s") + } +} diff --git a/dialoptions.go b/dialoptions.go index 51a58a9e0404..a9ccf496c2ff 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -55,15 +55,16 @@ type dialOptions struct { // balancer, and also by WithBalancerName dial option. balancerBuilder balancer.Builder // This is to support grpclb. - resolverBuilder resolver.Builder - reqHandshake envconfig.RequireHandshakeSetting - channelzParentID int64 - disableServiceConfig bool - disableRetry bool - disableHealthCheck bool - healthCheckFunc internal.HealthChecker - minConnectTimeout func() time.Duration - defaultServiceConfig *ServiceConfig + resolverBuilder resolver.Builder + reqHandshake envconfig.RequireHandshakeSetting + channelzParentID int64 + disableServiceConfig bool + disableRetry bool + disableHealthCheck bool + healthCheckFunc internal.HealthChecker + minConnectTimeout func() time.Duration + defaultServiceConfig *ServiceConfig + defaultServiceConfigRawJSON *string } // DialOption configures how we set up the connection. @@ -452,16 +453,11 @@ func WithDisableServiceConfig() DialOption { // will be used in cases where: // 1. WithDisableServiceConfig is called. // 2. Resolver does not return service config or if the resolver gets and invalid config. -// It is strongly recommended that caller of this function verifies the validity of the input string -// by using the grpc.ValidateServiceConfig function. +// +// This API is EXPERIMENTAL. func WithDefaultServiceConfig(s string) DialOption { return newFuncDialOption(func(o *dialOptions) { - sc, err := parseServiceConfig(s) - if err != nil { - grpclog.Warningf("the provided service config is invalid, err: %v", err) - return - } - o.defaultServiceConfig = sc + o.defaultServiceConfigRawJSON = &s }) } diff --git a/service_config.go b/service_config.go index 931108778a4a..ff401e1c4579 100644 --- a/service_config.go +++ b/service_config.go @@ -247,7 +247,7 @@ func parseServiceConfig(js string) (*ServiceConfig, error) { err := json.Unmarshal([]byte(js), &rsc) if err != nil { grpclog.Warningf("grpc: parseServiceConfig error unmarshaling %s due to %v", js, err) - return &ServiceConfig{}, err + return nil, err } sc := ServiceConfig{ LB: rsc.LoadBalancingPolicy, @@ -267,7 +267,7 @@ func parseServiceConfig(js string) (*ServiceConfig, error) { d, err := parseDuration(m.Timeout) if err != nil { grpclog.Warningf("grpc: parseServiceConfig error unmarshaling %s due to %v", js, err) - return &ServiceConfig{}, err + return nil, err } mc := MethodConfig{ @@ -276,7 +276,7 @@ func parseServiceConfig(js string) (*ServiceConfig, error) { } if mc.retryPolicy, err = convertRetryPolicy(m.RetryPolicy); err != nil { grpclog.Warningf("grpc: parseServiceConfig error unmarshaling %s due to %v", js, err) - return &ServiceConfig{}, err + return nil, err } if m.MaxRequestMessageBytes != nil { if *m.MaxRequestMessageBytes > int64(maxInt) { @@ -372,9 +372,3 @@ func getMaxSize(mcMax, doptMax *int, defaultVal int) *int { func newInt(b int) *int { return &b } - -// ValidateServiceConfig validates the input service config json string and returns the error. -func ValidateServiceConfig(js string) error { - _, err := parseServiceConfig(js) - return err -} diff --git a/service_config_test.go b/service_config_test.go index 18c81eeb121e..afa02a990d04 100644 --- a/service_config_test.go +++ b/service_config_test.go @@ -29,7 +29,7 @@ import ( func (s) TestParseLoadBalancer(t *testing.T) { testcases := []struct { scjs string - wantSC ServiceConfig + wantSC *ServiceConfig wantErr bool }{ { @@ -47,7 +47,7 @@ func (s) TestParseLoadBalancer(t *testing.T) { } ] }`, - ServiceConfig{ + &ServiceConfig{ LB: newString("round_robin"), Methods: map[string]MethodConfig{ "/foo/Bar": { @@ -72,15 +72,15 @@ func (s) TestParseLoadBalancer(t *testing.T) { } ] }`, - ServiceConfig{}, + nil, true, }, } for _, c := range testcases { sc, err := parseServiceConfig(c.scjs) - if c.wantErr != (err != nil) || !scCompareWithRawJSONSkipped(*sc, c.wantSC) { - t.Fatalf("parseServiceConfig(%s) = %+v, %v, want %+v, %v", c.scjs, *sc, err, c.wantSC, c.wantErr) + if c.wantErr != (err != nil) || !scCompareWithRawJSONSkipped(sc, c.wantSC) { + t.Fatalf("parseServiceConfig(%s) = %+v, %v, want %+v, %v", c.scjs, sc, err, c.wantSC, c.wantErr) } } } @@ -88,7 +88,7 @@ func (s) TestParseLoadBalancer(t *testing.T) { func (s) TestParseWaitForReady(t *testing.T) { testcases := []struct { scjs string - wantSC ServiceConfig + wantSC *ServiceConfig wantErr bool }{ { @@ -105,7 +105,7 @@ func (s) TestParseWaitForReady(t *testing.T) { } ] }`, - ServiceConfig{ + &ServiceConfig{ Methods: map[string]MethodConfig{ "/foo/Bar": { WaitForReady: newBool(true), @@ -128,7 +128,7 @@ func (s) TestParseWaitForReady(t *testing.T) { } ] }`, - ServiceConfig{ + &ServiceConfig{ Methods: map[string]MethodConfig{ "/foo/Bar": { WaitForReady: newBool(false), @@ -160,14 +160,14 @@ func (s) TestParseWaitForReady(t *testing.T) { } ] }`, - ServiceConfig{}, + nil, true, }, } for _, c := range testcases { sc, err := parseServiceConfig(c.scjs) - if c.wantErr != (err != nil) || !scCompareWithRawJSONSkipped(*sc, c.wantSC) { + if c.wantErr != (err != nil) || !scCompareWithRawJSONSkipped(sc, c.wantSC) { t.Fatalf("parseServiceConfig(%s) = %+v, %v, want %+v, %v", c.scjs, sc, err, c.wantSC, c.wantErr) } } @@ -176,7 +176,7 @@ func (s) TestParseWaitForReady(t *testing.T) { func (s) TestPraseTimeOut(t *testing.T) { testcases := []struct { scjs string - wantSC ServiceConfig + wantSC *ServiceConfig wantErr bool }{ { @@ -193,7 +193,7 @@ func (s) TestPraseTimeOut(t *testing.T) { } ] }`, - ServiceConfig{ + &ServiceConfig{ Methods: map[string]MethodConfig{ "/foo/Bar": { Timeout: newDuration(time.Second), @@ -216,7 +216,7 @@ func (s) TestPraseTimeOut(t *testing.T) { } ] }`, - ServiceConfig{}, + nil, true, }, { @@ -242,14 +242,14 @@ func (s) TestPraseTimeOut(t *testing.T) { } ] }`, - ServiceConfig{}, + nil, true, }, } for _, c := range testcases { sc, err := parseServiceConfig(c.scjs) - if c.wantErr != (err != nil) || !scCompareWithRawJSONSkipped(*sc, c.wantSC) { + if c.wantErr != (err != nil) || !scCompareWithRawJSONSkipped(sc, c.wantSC) { t.Fatalf("parseServiceConfig(%s) = %+v, %v, want %+v, %v", c.scjs, sc, err, c.wantSC, c.wantErr) } } @@ -258,7 +258,7 @@ func (s) TestPraseTimeOut(t *testing.T) { func (s) TestPraseMsgSize(t *testing.T) { testcases := []struct { scjs string - wantSC ServiceConfig + wantSC *ServiceConfig wantErr bool }{ { @@ -276,7 +276,7 @@ func (s) TestPraseMsgSize(t *testing.T) { } ] }`, - ServiceConfig{ + &ServiceConfig{ Methods: map[string]MethodConfig{ "/foo/Bar": { MaxReqSize: newInt(1024), @@ -311,14 +311,14 @@ func (s) TestPraseMsgSize(t *testing.T) { } ] }`, - ServiceConfig{}, + nil, true, }, } for _, c := range testcases { sc, err := parseServiceConfig(c.scjs) - if c.wantErr != (err != nil) || !scCompareWithRawJSONSkipped(*sc, c.wantSC) { + if c.wantErr != (err != nil) || !scCompareWithRawJSONSkipped(sc, c.wantSC) { t.Fatalf("parseServiceConfig(%s) = %+v, %v, want %+v, %v", c.scjs, sc, err, c.wantSC, c.wantErr) } } @@ -385,7 +385,13 @@ func newString(b string) *string { return &b } -func scCompareWithRawJSONSkipped(s1, s2 ServiceConfig) bool { +func scCompareWithRawJSONSkipped(s1, s2 *ServiceConfig) bool { + if s1 == nil && s2 == nil { + return true + } + if (s1 == nil) != (s2 == nil) { + return false + } s1.rawJSONString = nil s2.rawJSONString = nil return reflect.DeepEqual(s1, s2)