diff --git a/clientconn.go b/clientconn.go index a1e1a98006a7..93b0aa2c06ad 100644 --- a/clientconn.go +++ b/clientconn.go @@ -68,6 +68,9 @@ var ( errConnClosing = errors.New("grpc: the connection is closing") // errBalancerClosed indicates that the balancer is closed. errBalancerClosed = errors.New("grpc: balancer is closed") + // invalidDefaultServiceConfigErrPrefix is used to prefix the json parsing error for the default + // service config. + invalidDefaultServiceConfigErrPrefix = "grpc: the provided default service config is invalid" ) // The following errors are returned from Dial and DialContext @@ -173,6 +176,13 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * } } + if cc.dopts.defaultServiceConfigRawJSON != nil { + sc, err := parseServiceConfig(*cc.dopts.defaultServiceConfigRawJSON) + if err != nil { + return nil, fmt.Errorf("%s: %v", invalidDefaultServiceConfigErrPrefix, err) + } + cc.dopts.defaultServiceConfig = sc + } cc.mkp = cc.dopts.copts.KeepaliveParams if cc.dopts.copts.Dialer == nil { @@ -214,7 +224,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * select { case sc, ok := <-cc.dopts.scChan: if ok { - cc.sc = sc + cc.sc = &sc scSet = true } default: @@ -260,7 +270,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * select { case sc, ok := <-cc.dopts.scChan: if ok { - cc.sc = sc + cc.sc = &sc } case <-ctx.Done(): return nil, ctx.Err() @@ -382,8 +392,7 @@ type ClientConn struct { mu sync.RWMutex resolverWrapper *ccResolverWrapper - sc ServiceConfig - scRaw string + sc *ServiceConfig conns map[*addrConn]struct{} // Keepalive parameter can be updated if a GoAway is received. mkp keepalive.ClientParameters @@ -429,8 +438,7 @@ func (cc *ClientConn) scWatcher() { cc.mu.Lock() // TODO: load balance policy runtime change is ignored. // We may revisit this decision in the future. - cc.sc = sc - cc.scRaw = "" + cc.sc = &sc cc.mu.Unlock() case <-cc.ctx.Done(): return @@ -457,6 +465,24 @@ func (cc *ClientConn) waitForResolvedAddrs(ctx context.Context) error { } } +// gRPC should resort to default service config when: +// * resolver service config is disabled +// * or, resolver does not return a service config or returns an invalid one. +func (cc *ClientConn) fallbackToDefaultServiceConfig(sc string) bool { + if cc.dopts.disableServiceConfig { + return true + } + // The logic below is temporary, will be removed once we change the resolver.State ServiceConfig field type. + // Right now, we assume that empty service config string means resolver does not return a config. + if sc == "" { + return true + } + // TODO: the logic below is temporary. Once we finish the logic to validate service config + // in resolver, we will replace the logic below. + _, err := parseServiceConfig(sc) + return err != nil +} + func (cc *ClientConn) updateResolverState(s resolver.State) error { cc.mu.Lock() defer cc.mu.Unlock() @@ -467,29 +493,26 @@ func (cc *ClientConn) updateResolverState(s resolver.State) error { return nil } - if !cc.dopts.disableServiceConfig && cc.scRaw != s.ServiceConfig { - // New service config; apply it. + if cc.fallbackToDefaultServiceConfig(s.ServiceConfig) { + if cc.dopts.defaultServiceConfig != nil && cc.sc == nil { + cc.applyServiceConfig(cc.dopts.defaultServiceConfig) + } + } else { + // TODO: the parsing logic below will be moved inside resolver. sc, err := parseServiceConfig(s.ServiceConfig) if err != nil { - fmt.Println("error parsing config: ", err) return err } - cc.scRaw = s.ServiceConfig - cc.sc = sc - - if cc.sc.retryThrottling != nil { - newThrottler := &retryThrottler{ - tokens: cc.sc.retryThrottling.MaxTokens, - max: cc.sc.retryThrottling.MaxTokens, - thresh: cc.sc.retryThrottling.MaxTokens / 2, - ratio: cc.sc.retryThrottling.TokenRatio, - } - cc.retryThrottler.Store(newThrottler) - } else { - cc.retryThrottler.Store((*retryThrottler)(nil)) + if cc.sc == nil || cc.sc.rawJSONString != s.ServiceConfig { + cc.applyServiceConfig(sc) } } + // update the service config that will be sent to balancer. + if cc.sc != nil { + s.ServiceConfig = cc.sc.rawJSONString + } + if cc.dopts.balancerBuilder == nil { // Only look at balancer types and switch balancer if balancer dial // option is not set. @@ -504,7 +527,7 @@ func (cc *ClientConn) updateResolverState(s resolver.State) error { // TODO: use new loadBalancerConfig field with appropriate priority. if isGRPCLB { newBalancerName = grpclbName - } else if cc.sc.LB != nil { + } else if cc.sc != nil && cc.sc.LB != nil { newBalancerName = *cc.sc.LB } else { newBalancerName = PickFirstBalancerName @@ -724,6 +747,9 @@ func (cc *ClientConn) GetMethodConfig(method string) MethodConfig { // TODO: Avoid the locking here. cc.mu.RLock() defer cc.mu.RUnlock() + if cc.sc == nil { + return MethodConfig{} + } m, ok := cc.sc.Methods[method] if !ok { i := strings.LastIndex(method, "/") @@ -735,6 +761,9 @@ func (cc *ClientConn) GetMethodConfig(method string) MethodConfig { func (cc *ClientConn) healthCheckConfig() *healthCheckConfig { cc.mu.RLock() defer cc.mu.RUnlock() + if cc.sc == nil { + return nil + } return cc.sc.healthCheckConfig } @@ -748,6 +777,28 @@ func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method st return t, done, nil } +func (cc *ClientConn) applyServiceConfig(sc *ServiceConfig) error { + if sc == nil { + // should never reach here. + return fmt.Errorf("got nil pointer for service config") + } + cc.sc = sc + + if cc.sc.retryThrottling != nil { + newThrottler := &retryThrottler{ + tokens: cc.sc.retryThrottling.MaxTokens, + max: cc.sc.retryThrottling.MaxTokens, + thresh: cc.sc.retryThrottling.MaxTokens / 2, + ratio: cc.sc.retryThrottling.TokenRatio, + } + cc.retryThrottler.Store(newThrottler) + } else { + cc.retryThrottler.Store((*retryThrottler)(nil)) + } + + return nil +} + func (cc *ClientConn) resolveNow(o resolver.ResolveNowOption) { cc.mu.RLock() r := cc.resolverWrapper diff --git a/clientconn_test.go b/clientconn_test.go index 9b5ac03547b6..ab1db5664049 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -24,6 +24,7 @@ import ( "fmt" "math" "net" + "strings" "sync/atomic" "testing" "time" @@ -1232,3 +1233,90 @@ func (s) TestUpdateAddresses_RetryFromFirstAddr(t *testing.T) { t.Fatal("timed out waiting for any server to be contacted after tryUpdateAddrs") } } + +func (s) TestDefaultServiceConfig(t *testing.T) { + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + addr := r.Scheme() + ":///non.existent" + js := `{ + "methodConfig": [ + { + "name": [ + { + "service": "foo", + "method": "bar" + } + ], + "waitForReady": true + } + ] +}` + testInvalidDefaultServiceConfig(t) + testDefaultServiceConfigWhenResolverServiceConfigDisabled(t, r, addr, js) + testDefaultServiceConfigWhenResolverDoesNotReturnServiceConfig(t, r, addr, js) + testDefaultServiceConfigWhenResolverReturnInvalidServiceConfig(t, r, addr, js) +} + +func verifyWaitForReadyEqualsTrue(cc *ClientConn) bool { + var i int + for i = 0; i < 10; i++ { + mc := cc.GetMethodConfig("/foo/bar") + if mc.WaitForReady != nil && *mc.WaitForReady == true { + break + } + time.Sleep(100 * time.Millisecond) + } + return i != 10 +} + +func testInvalidDefaultServiceConfig(t *testing.T) { + _, err := Dial("fake.com", WithInsecure(), WithDefaultServiceConfig("")) + if !strings.Contains(err.Error(), invalidDefaultServiceConfigErrPrefix) { + t.Fatalf("Dial got err: %v, want err contains: %v", err, invalidDefaultServiceConfigErrPrefix) + } +} + +func testDefaultServiceConfigWhenResolverServiceConfigDisabled(t *testing.T, r resolver.Resolver, addr string, js string) { + cc, err := Dial(addr, WithInsecure(), WithDisableServiceConfig(), WithDefaultServiceConfig(js)) + if err != nil { + t.Fatalf("Dial(%s, _) = _, %v, want _, ", addr, err) + } + defer cc.Close() + // Resolver service config gets ignored since resolver service config is disabled. + r.(*manual.Resolver).UpdateState(resolver.State{ + Addresses: []resolver.Address{{Addr: addr}}, + ServiceConfig: "{}", + }) + if !verifyWaitForReadyEqualsTrue(cc) { + t.Fatal("default service config failed to be applied after 1s") + } +} + +func testDefaultServiceConfigWhenResolverDoesNotReturnServiceConfig(t *testing.T, r resolver.Resolver, addr string, js string) { + cc, err := Dial(addr, WithInsecure(), WithDefaultServiceConfig(js)) + if err != nil { + t.Fatalf("Dial(%s, _) = _, %v, want _, ", addr, err) + } + defer cc.Close() + r.(*manual.Resolver).UpdateState(resolver.State{ + Addresses: []resolver.Address{{Addr: addr}}, + }) + if !verifyWaitForReadyEqualsTrue(cc) { + t.Fatal("default service config failed to be applied after 1s") + } +} + +func testDefaultServiceConfigWhenResolverReturnInvalidServiceConfig(t *testing.T, r resolver.Resolver, addr string, js string) { + cc, err := Dial(addr, WithInsecure(), WithDefaultServiceConfig(js)) + if err != nil { + t.Fatalf("Dial(%s, _) = _, %v, want _, ", addr, err) + } + defer cc.Close() + r.(*manual.Resolver).UpdateState(resolver.State{ + Addresses: []resolver.Address{{Addr: addr}}, + ServiceConfig: "{something wrong,}", + }) + if !verifyWaitForReadyEqualsTrue(cc) { + t.Fatal("default service config failed to be applied after 1s") + } +} diff --git a/dialoptions.go b/dialoptions.go index a0743a9e75fa..e114fecbb7b4 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -55,14 +55,16 @@ type dialOptions struct { // balancer, and also by WithBalancerName dial option. balancerBuilder balancer.Builder // This is to support grpclb. - resolverBuilder resolver.Builder - reqHandshake envconfig.RequireHandshakeSetting - channelzParentID int64 - disableServiceConfig bool - disableRetry bool - disableHealthCheck bool - healthCheckFunc internal.HealthChecker - minConnectTimeout func() time.Duration + resolverBuilder resolver.Builder + reqHandshake envconfig.RequireHandshakeSetting + channelzParentID int64 + disableServiceConfig bool + disableRetry bool + disableHealthCheck bool + healthCheckFunc internal.HealthChecker + minConnectTimeout func() time.Duration + defaultServiceConfig *ServiceConfig // defaultServiceConfig is parsed from defaultServiceConfigRawJSON. + defaultServiceConfigRawJSON *string } // DialOption configures how we set up the connection. @@ -441,12 +443,27 @@ func WithChannelzParentID(id int64) DialOption { // WithDisableServiceConfig returns a DialOption that causes grpc to ignore any // service config provided by the resolver and provides a hint to the resolver // to not fetch service configs. +// +// Note that, this dial option only disables service config from resolver. If +// default service config is provided, grpc will use the default service config. func WithDisableServiceConfig() DialOption { return newFuncDialOption(func(o *dialOptions) { o.disableServiceConfig = true }) } +// WithDefaultServiceConfig returns a DialOption that configures the default +// service config, which will be used in cases where: +// 1. WithDisableServiceConfig is called. +// 2. Resolver does not return service config or if the resolver gets and invalid config. +// +// This API is EXPERIMENTAL. +func WithDefaultServiceConfig(s string) DialOption { + return newFuncDialOption(func(o *dialOptions) { + o.defaultServiceConfigRawJSON = &s + }) +} + // WithDisableRetry returns a DialOption that disables retries, even if the // service config enables them. This does not impact transparent retries, which // will happen automatically if no data is written to the wire or if the RPC is diff --git a/resolver_conn_wrapper.go b/resolver_conn_wrapper.go index 176da7bd3c48..e9cef3a92b55 100644 --- a/resolver_conn_wrapper.go +++ b/resolver_conn_wrapper.go @@ -118,7 +118,7 @@ func (ccr *ccResolverWrapper) UpdateState(s resolver.State) { ccr.curState = s } -// NewAddress is called by the resolver implemenetion to send addresses to gRPC. +// NewAddress is called by the resolver implementation to send addresses to gRPC. func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) { if ccr.isDone() { return @@ -131,7 +131,7 @@ func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) { ccr.cc.updateResolverState(ccr.curState) } -// NewServiceConfig is called by the resolver implemenetion to send service +// NewServiceConfig is called by the resolver implementation to send service // configs to gRPC. func (ccr *ccResolverWrapper) NewServiceConfig(sc string) { if ccr.isDone() { diff --git a/service_config.go b/service_config.go index 982a3bf21e65..1c5227426f49 100644 --- a/service_config.go +++ b/service_config.go @@ -99,6 +99,9 @@ type ServiceConfig struct { // healthCheckConfig must be set as one of the requirement to enable LB channel // health check. healthCheckConfig *healthCheckConfig + // rawJSONString stores service config json string that get parsed into + // this service config struct. + rawJSONString string } // healthCheckConfig defines the go-native version of the LB channel health check config. @@ -238,24 +241,22 @@ type jsonSC struct { HealthCheckConfig *healthCheckConfig } -func parseServiceConfig(js string) (ServiceConfig, error) { - if len(js) == 0 { - return ServiceConfig{}, fmt.Errorf("no JSON service config provided") - } +func parseServiceConfig(js string) (*ServiceConfig, error) { var rsc jsonSC err := json.Unmarshal([]byte(js), &rsc) if err != nil { grpclog.Warningf("grpc: parseServiceConfig error unmarshaling %s due to %v", js, err) - return ServiceConfig{}, err + return nil, err } sc := ServiceConfig{ LB: rsc.LoadBalancingPolicy, Methods: make(map[string]MethodConfig), retryThrottling: rsc.RetryThrottling, healthCheckConfig: rsc.HealthCheckConfig, + rawJSONString: js, } if rsc.MethodConfig == nil { - return sc, nil + return &sc, nil } for _, m := range *rsc.MethodConfig { @@ -265,7 +266,7 @@ func parseServiceConfig(js string) (ServiceConfig, error) { d, err := parseDuration(m.Timeout) if err != nil { grpclog.Warningf("grpc: parseServiceConfig error unmarshaling %s due to %v", js, err) - return ServiceConfig{}, err + return nil, err } mc := MethodConfig{ @@ -274,7 +275,7 @@ func parseServiceConfig(js string) (ServiceConfig, error) { } if mc.retryPolicy, err = convertRetryPolicy(m.RetryPolicy); err != nil { grpclog.Warningf("grpc: parseServiceConfig error unmarshaling %s due to %v", js, err) - return ServiceConfig{}, err + return nil, err } if m.MaxRequestMessageBytes != nil { if *m.MaxRequestMessageBytes > int64(maxInt) { @@ -305,7 +306,7 @@ func parseServiceConfig(js string) (ServiceConfig, error) { sc.retryThrottling = nil } } - return sc, nil + return &sc, nil } func convertRetryPolicy(jrp *jsonRetryPolicy) (p *retryPolicy, err error) { diff --git a/service_config_test.go b/service_config_test.go index 020b643f89c8..a21416303c75 100644 --- a/service_config_test.go +++ b/service_config_test.go @@ -29,7 +29,7 @@ import ( func (s) TestParseLoadBalancer(t *testing.T) { testcases := []struct { scjs string - wantSC ServiceConfig + wantSC *ServiceConfig wantErr bool }{ { @@ -47,7 +47,7 @@ func (s) TestParseLoadBalancer(t *testing.T) { } ] }`, - ServiceConfig{ + &ServiceConfig{ LB: newString("round_robin"), Methods: map[string]MethodConfig{ "/foo/Bar": { @@ -72,14 +72,14 @@ func (s) TestParseLoadBalancer(t *testing.T) { } ] }`, - ServiceConfig{}, + nil, true, }, } for _, c := range testcases { sc, err := parseServiceConfig(c.scjs) - if c.wantErr != (err != nil) || !reflect.DeepEqual(sc, c.wantSC) { + if c.wantErr != (err != nil) || !scCompareWithRawJSONSkipped(sc, c.wantSC) { t.Fatalf("parseServiceConfig(%s) = %+v, %v, want %+v, %v", c.scjs, sc, err, c.wantSC, c.wantErr) } } @@ -88,7 +88,7 @@ func (s) TestParseLoadBalancer(t *testing.T) { func (s) TestParseWaitForReady(t *testing.T) { testcases := []struct { scjs string - wantSC ServiceConfig + wantSC *ServiceConfig wantErr bool }{ { @@ -105,7 +105,7 @@ func (s) TestParseWaitForReady(t *testing.T) { } ] }`, - ServiceConfig{ + &ServiceConfig{ Methods: map[string]MethodConfig{ "/foo/Bar": { WaitForReady: newBool(true), @@ -128,7 +128,7 @@ func (s) TestParseWaitForReady(t *testing.T) { } ] }`, - ServiceConfig{ + &ServiceConfig{ Methods: map[string]MethodConfig{ "/foo/Bar": { WaitForReady: newBool(false), @@ -160,14 +160,14 @@ func (s) TestParseWaitForReady(t *testing.T) { } ] }`, - ServiceConfig{}, + nil, true, }, } for _, c := range testcases { sc, err := parseServiceConfig(c.scjs) - if c.wantErr != (err != nil) || !reflect.DeepEqual(sc, c.wantSC) { + if c.wantErr != (err != nil) || !scCompareWithRawJSONSkipped(sc, c.wantSC) { t.Fatalf("parseServiceConfig(%s) = %+v, %v, want %+v, %v", c.scjs, sc, err, c.wantSC, c.wantErr) } } @@ -176,7 +176,7 @@ func (s) TestParseWaitForReady(t *testing.T) { func (s) TestPraseTimeOut(t *testing.T) { testcases := []struct { scjs string - wantSC ServiceConfig + wantSC *ServiceConfig wantErr bool }{ { @@ -193,7 +193,7 @@ func (s) TestPraseTimeOut(t *testing.T) { } ] }`, - ServiceConfig{ + &ServiceConfig{ Methods: map[string]MethodConfig{ "/foo/Bar": { Timeout: newDuration(time.Second), @@ -216,7 +216,7 @@ func (s) TestPraseTimeOut(t *testing.T) { } ] }`, - ServiceConfig{}, + nil, true, }, { @@ -242,14 +242,14 @@ func (s) TestPraseTimeOut(t *testing.T) { } ] }`, - ServiceConfig{}, + nil, true, }, } for _, c := range testcases { sc, err := parseServiceConfig(c.scjs) - if c.wantErr != (err != nil) || !reflect.DeepEqual(sc, c.wantSC) { + if c.wantErr != (err != nil) || !scCompareWithRawJSONSkipped(sc, c.wantSC) { t.Fatalf("parseServiceConfig(%s) = %+v, %v, want %+v, %v", c.scjs, sc, err, c.wantSC, c.wantErr) } } @@ -258,7 +258,7 @@ func (s) TestPraseTimeOut(t *testing.T) { func (s) TestPraseMsgSize(t *testing.T) { testcases := []struct { scjs string - wantSC ServiceConfig + wantSC *ServiceConfig wantErr bool }{ { @@ -276,7 +276,7 @@ func (s) TestPraseMsgSize(t *testing.T) { } ] }`, - ServiceConfig{ + &ServiceConfig{ Methods: map[string]MethodConfig{ "/foo/Bar": { MaxReqSize: newInt(1024), @@ -311,14 +311,14 @@ func (s) TestPraseMsgSize(t *testing.T) { } ] }`, - ServiceConfig{}, + nil, true, }, } for _, c := range testcases { sc, err := parseServiceConfig(c.scjs) - if c.wantErr != (err != nil) || !reflect.DeepEqual(sc, c.wantSC) { + if c.wantErr != (err != nil) || !scCompareWithRawJSONSkipped(sc, c.wantSC) { t.Fatalf("parseServiceConfig(%s) = %+v, %v, want %+v, %v", c.scjs, sc, err, c.wantSC, c.wantErr) } } @@ -384,3 +384,15 @@ func newDuration(b time.Duration) *time.Duration { func newString(b string) *string { return &b } + +func scCompareWithRawJSONSkipped(s1, s2 *ServiceConfig) bool { + if s1 == nil && s2 == nil { + return true + } + if (s1 == nil) != (s2 == nil) { + return false + } + s1.rawJSONString = "" + s2.rawJSONString = "" + return reflect.DeepEqual(s1, s2) +}