Skip to content

Commit ff2d607

Browse files
committed
client: block RPCs early until the resolver has returned addresses
1 parent ab525e9 commit ff2d607

File tree

3 files changed

+83
-0
lines changed

3 files changed

+83
-0
lines changed

Diff for: clientconn.go

+26
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
129129
dopts: defaultDialOptions(),
130130
blockingpicker: newPickerWrapper(),
131131
czData: new(channelzData),
132+
resolvedAddrs: make(chan struct{}),
132133
}
133134
cc.retryThrottler.Store((*retryThrottler)(nil))
134135
cc.ctx, cc.cancel = context.WithCancel(context.Background())
@@ -399,6 +400,10 @@ type ClientConn struct {
399400
balancerWrapper *ccBalancerWrapper
400401
retryThrottler atomic.Value
401402

403+
resolvedAddrsOnce sync.Once
404+
resolvedAddrs chan struct{}
405+
hasResolvedAddrs int32
406+
402407
channelzID int64 // channelz unique identification number
403408
czData *channelzData
404409
}
@@ -444,6 +449,26 @@ func (cc *ClientConn) scWatcher() {
444449
}
445450
}
446451

452+
// waitForResolvedAddrs blocks until the resolver has provided addresses or the
453+
// context expires. Returns nil unless the context expires first; otherwise
454+
// returns a status error based on the context.
455+
func (cc *ClientConn) waitForResolvedAddrs(ctx context.Context) error {
456+
// This is on the RPC path, so we use an atomic to avoid the need to do a
457+
// more-expensive "select" below after the resolver has returned once.
458+
if atomic.LoadInt32(&cc.hasResolvedAddrs) != 0 {
459+
return nil
460+
}
461+
select {
462+
case <-cc.resolvedAddrs:
463+
atomic.StoreInt32(&cc.hasResolvedAddrs, 1)
464+
return nil
465+
case <-ctx.Done():
466+
return status.FromContextError(ctx.Err()).Err()
467+
case <-cc.ctx.Done():
468+
return ErrClientConnClosing
469+
}
470+
}
471+
447472
func (cc *ClientConn) handleResolvedAddrs(addrs []resolver.Address, err error) {
448473
cc.mu.Lock()
449474
defer cc.mu.Unlock()
@@ -457,6 +482,7 @@ func (cc *ClientConn) handleResolvedAddrs(addrs []resolver.Address, err error) {
457482
}
458483

459484
cc.curAddresses = addrs
485+
cc.resolvedAddrsOnce.Do(func() { close(cc.resolvedAddrs) })
460486

461487
if cc.dopts.balancerBuilder == nil {
462488
// Only look at balancer types and switch balancer if balancer dial

Diff for: stream.go

+5
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,11 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
160160
}()
161161
}
162162
c := defaultCallInfo()
163+
// Provide an opportunity for the first RPC to see the first service config
164+
// provided by the resolver.
165+
if err := cc.waitForResolvedAddrs(ctx); err != nil {
166+
return err
167+
}
163168
mc := cc.GetMethodConfig(method)
164169
if mc.WaitForReady != nil {
165170
c.failFast = !*mc.WaitForReady

Diff for: test/end2end_test.go

+52
Original file line numberDiff line numberDiff line change
@@ -7103,3 +7103,55 @@ func (lis notifyingListener) Accept() (net.Conn, error) {
71037103
defer lis.connEstablished.Fire()
71047104
return lis.Listener.Accept()
71057105
}
7106+
7107+
func TestRPCWaitsForResolver(t *testing.T) {
7108+
te := testServiceConfigSetup(t, tcpClearRREnv)
7109+
te.startServer(&testServer{security: tcpClearRREnv.security})
7110+
defer te.tearDown()
7111+
r, rcleanup := manual.GenerateAndRegisterManualResolver()
7112+
defer rcleanup()
7113+
7114+
te.resolverScheme = r.Scheme()
7115+
te.nonBlockingDial = true
7116+
cc := te.clientConn()
7117+
tc := testpb.NewTestServiceClient(cc)
7118+
7119+
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
7120+
defer cancel()
7121+
// With no resolved addresses yet, this will timeout.
7122+
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.DeadlineExceeded {
7123+
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
7124+
}
7125+
7126+
ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
7127+
defer cancel()
7128+
go func() {
7129+
time.Sleep(time.Second)
7130+
r.NewServiceConfig(`{
7131+
"methodConfig": [
7132+
{
7133+
"name": [
7134+
{
7135+
"service": "grpc.testing.TestService",
7136+
"method": "UnaryCall"
7137+
}
7138+
],
7139+
"maxRequestMessageBytes": 0
7140+
}
7141+
]
7142+
}`)
7143+
r.NewAddress([]resolver.Address{{Addr: te.srvAddr}})
7144+
}()
7145+
// We wait a second before providing a service config and resolving
7146+
// addresses. So this will wait for that and then honor the
7147+
// maxRequestMessageBytes it contains.
7148+
if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{ResponseType: testpb.PayloadType_UNCOMPRESSABLE}); status.Code(err) != codes.ResourceExhausted {
7149+
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, nil", err)
7150+
}
7151+
if got := ctx.Err(); got != nil {
7152+
t.Fatalf("ctx.Err() = %v; want nil (deadline should be set short by service config)", got)
7153+
}
7154+
if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
7155+
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, nil", err)
7156+
}
7157+
}

0 commit comments

Comments
 (0)