Skip to content

Commit cceeb34

Browse files
committed
client: block RPCs early until the resolver has returned addresses
1 parent 59a2cfb commit cceeb34

File tree

3 files changed

+83
-0
lines changed

3 files changed

+83
-0
lines changed

Diff for: clientconn.go

+26
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
130130
dopts: defaultDialOptions(),
131131
blockingpicker: newPickerWrapper(),
132132
czData: new(channelzData),
133+
resolvedAddrs: make(chan struct{}),
133134
}
134135
cc.retryThrottler.Store((*retryThrottler)(nil))
135136
cc.ctx, cc.cancel = context.WithCancel(context.Background())
@@ -402,6 +403,10 @@ type ClientConn struct {
402403
balancerWrapper *ccBalancerWrapper
403404
retryThrottler atomic.Value
404405

406+
resolvedAddrsOnce sync.Once
407+
resolvedAddrs chan struct{}
408+
hasResolvedAddrs int32
409+
405410
channelzID int64 // channelz unique identification number
406411
czData *channelzData
407412
}
@@ -447,6 +452,26 @@ func (cc *ClientConn) scWatcher() {
447452
}
448453
}
449454

455+
// waitForResolvedAddrs blocks until the resolver has provided addresses or the
456+
// context expires. Returns nil unless the context expires first; otherwise
457+
// returns a status error based on the context.
458+
func (cc *ClientConn) waitForResolvedAddrs(ctx context.Context) error {
459+
// This is on the RPC path, so we use an atomic to avoid the need to do a
460+
// more-expensive "select" below after the resolver has returned once.
461+
if atomic.LoadInt32(&cc.hasResolvedAddrs) != 0 {
462+
return nil
463+
}
464+
select {
465+
case <-cc.resolvedAddrs:
466+
atomic.StoreInt32(&cc.hasResolvedAddrs, 1)
467+
return nil
468+
case <-ctx.Done():
469+
return status.FromContextError(ctx.Err()).Err()
470+
case <-cc.ctx.Done():
471+
return ErrClientConnClosing
472+
}
473+
}
474+
450475
func (cc *ClientConn) handleResolvedAddrs(addrs []resolver.Address, err error) {
451476
cc.mu.Lock()
452477
defer cc.mu.Unlock()
@@ -460,6 +485,7 @@ func (cc *ClientConn) handleResolvedAddrs(addrs []resolver.Address, err error) {
460485
}
461486

462487
cc.curAddresses = addrs
488+
cc.resolvedAddrsOnce.Do(func() { close(cc.resolvedAddrs) })
463489

464490
if cc.dopts.balancerBuilder == nil {
465491
// Only look at balancer types and switch balancer if balancer dial

Diff for: stream.go

+5
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,11 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
166166
}()
167167
}
168168
c := defaultCallInfo()
169+
// Provide an opportunity for the first RPC to see the first service config
170+
// provided by the resolver.
171+
if err := cc.waitForResolvedAddrs(ctx); err != nil {
172+
return nil, err
173+
}
169174
mc := cc.GetMethodConfig(method)
170175
if mc.WaitForReady != nil {
171176
c.failFast = !*mc.WaitForReady

Diff for: test/end2end_test.go

+52
Original file line numberDiff line numberDiff line change
@@ -7147,3 +7147,55 @@ func (lis notifyingListener) Accept() (net.Conn, error) {
71477147
defer lis.connEstablished.Fire()
71487148
return lis.Listener.Accept()
71497149
}
7150+
7151+
func TestRPCWaitsForResolver(t *testing.T) {
7152+
te := testServiceConfigSetup(t, tcpClearRREnv)
7153+
te.startServer(&testServer{security: tcpClearRREnv.security})
7154+
defer te.tearDown()
7155+
r, rcleanup := manual.GenerateAndRegisterManualResolver()
7156+
defer rcleanup()
7157+
7158+
te.resolverScheme = r.Scheme()
7159+
te.nonBlockingDial = true
7160+
cc := te.clientConn()
7161+
tc := testpb.NewTestServiceClient(cc)
7162+
7163+
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
7164+
defer cancel()
7165+
// With no resolved addresses yet, this will timeout.
7166+
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.DeadlineExceeded {
7167+
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
7168+
}
7169+
7170+
ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
7171+
defer cancel()
7172+
go func() {
7173+
time.Sleep(time.Second)
7174+
r.NewServiceConfig(`{
7175+
"methodConfig": [
7176+
{
7177+
"name": [
7178+
{
7179+
"service": "grpc.testing.TestService",
7180+
"method": "UnaryCall"
7181+
}
7182+
],
7183+
"maxRequestMessageBytes": 0
7184+
}
7185+
]
7186+
}`)
7187+
r.NewAddress([]resolver.Address{{Addr: te.srvAddr}})
7188+
}()
7189+
// We wait a second before providing a service config and resolving
7190+
// addresses. So this will wait for that and then honor the
7191+
// maxRequestMessageBytes it contains.
7192+
if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{ResponseType: testpb.PayloadType_UNCOMPRESSABLE}); status.Code(err) != codes.ResourceExhausted {
7193+
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, nil", err)
7194+
}
7195+
if got := ctx.Err(); got != nil {
7196+
t.Fatalf("ctx.Err() = %v; want nil (deadline should be set short by service config)", got)
7197+
}
7198+
if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
7199+
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, nil", err)
7200+
}
7201+
}

0 commit comments

Comments
 (0)