diff --git a/balancer_wrapper.go b/balancer_wrapper.go index a1e56a3893cb..66e3ad1de8df 100644 --- a/balancer_wrapper.go +++ b/balancer_wrapper.go @@ -281,6 +281,10 @@ type acBalancerWrapper struct { // dropped or updated. This is required as closures can't be compared for // equality. healthData *healthData + + shutdownMu sync.Mutex + shutdownCh chan struct{} + activeGofuncs sync.WaitGroup } // healthData holds data related to health state reporting. @@ -343,16 +347,45 @@ func (acbw *acBalancerWrapper) String() string { } func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { - acbw.ac.updateAddrs(addrs) + acbw.goFunc(func(shutdown <-chan struct{}) { + acbw.ac.updateAddrs(shutdown, addrs) + }) } func (acbw *acBalancerWrapper) Connect() { - go acbw.ac.connect() + acbw.goFunc(acbw.ac.connect) +} + +func (acbw *acBalancerWrapper) goFunc(fn func(shutdown <-chan struct{})) { + acbw.shutdownMu.Lock() + defer acbw.shutdownMu.Unlock() + + shutdown := acbw.shutdownCh + if shutdown == nil { + shutdown = make(chan struct{}) + acbw.shutdownCh = shutdown + } + + acbw.activeGofuncs.Add(1) + go func() { + defer acbw.activeGofuncs.Done() + fn(shutdown) + }() } func (acbw *acBalancerWrapper) Shutdown() { acbw.closeProducers() acbw.ccb.cc.removeAddrConn(acbw.ac, errConnDrain) + + acbw.shutdownMu.Lock() + defer acbw.shutdownMu.Unlock() + + shutdown := acbw.shutdownCh + acbw.shutdownCh = nil + if shutdown != nil { + close(shutdown) + acbw.activeGofuncs.Wait() + } } // NewStream begins a streaming RPC on the addrConn. If the addrConn is not diff --git a/clientconn.go b/clientconn.go index 5dec2dacc0ba..cc6eda763510 100644 --- a/clientconn.go +++ b/clientconn.go @@ -977,7 +977,7 @@ func (cc *ClientConn) incrCallsFailed() { // connect starts creating a transport. // It does nothing if the ac is not IDLE. // TODO(bar) Move this to the addrConn section. -func (ac *addrConn) connect() { +func (ac *addrConn) connect(abort <-chan struct{}) { ac.mu.Lock() if ac.state == connectivity.Shutdown { if logger.V(2) { @@ -994,7 +994,7 @@ func (ac *addrConn) connect() { return } - ac.resetTransportAndUnlock() + ac.resetTransportAndUnlock(abort) } // equalAddressIgnoringBalAttributes returns true is a and b are considered equal. @@ -1013,7 +1013,7 @@ func equalAddressesIgnoringBalAttributes(a, b []resolver.Address) bool { // updateAddrs updates ac.addrs with the new addresses list and handles active // connections or connection attempts. -func (ac *addrConn) updateAddrs(addrs []resolver.Address) { +func (ac *addrConn) updateAddrs(abort <-chan struct{}, addrs []resolver.Address) { addrs = copyAddresses(addrs) limit := len(addrs) if limit > 5 { @@ -1069,7 +1069,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) { // Since we were connecting/connected, we should start a new connection // attempt. - go ac.resetTransportAndUnlock() + ac.resetTransportAndUnlock(abort) } // getServerName determines the serverName to be used in the connection @@ -1315,9 +1315,17 @@ func (ac *addrConn) adjustParams(r transport.GoAwayReason) { // resetTransportAndUnlock unconditionally connects the addrConn. // // ac.mu must be held by the caller, and this function will guarantee it is released. -func (ac *addrConn) resetTransportAndUnlock() { - acCtx := ac.ctx - if acCtx.Err() != nil { +func (ac *addrConn) resetTransportAndUnlock(abort <-chan struct{}) { + ctx, cancel := context.WithCancel(ac.ctx) + go func() { + select { + case <-abort: + cancel() + case <-ctx.Done(): + } + }() + + if ctx.Err() != nil { ac.mu.Unlock() return } @@ -1345,7 +1353,7 @@ func (ac *addrConn) resetTransportAndUnlock() { ac.updateConnectivityState(connectivity.Connecting, nil) ac.mu.Unlock() - if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil { + if err := ac.tryAllAddrs(ctx, addrs, connectDeadline); err != nil { if !errors.Is(err, context.Canceled) { connectionAttemptsFailedMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.localityLabel) } else { @@ -1359,7 +1367,7 @@ func (ac *addrConn) resetTransportAndUnlock() { // to ensure one resolution request per pass instead of per subconn failure. ac.cc.resolveNow(resolver.ResolveNowOptions{}) ac.mu.Lock() - if acCtx.Err() != nil { + if ctx.Err() != nil { // addrConn was torn down. ac.mu.Unlock() return @@ -1380,13 +1388,13 @@ func (ac *addrConn) resetTransportAndUnlock() { ac.mu.Unlock() case <-b: timer.Stop() - case <-acCtx.Done(): + case <-ctx.Done(): timer.Stop() return } ac.mu.Lock() - if acCtx.Err() == nil { + if ctx.Err() == nil { ac.updateConnectivityState(connectivity.Idle, err) } ac.mu.Unlock() @@ -1481,11 +1489,33 @@ func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, c // new transport. func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error { addr.ServerName = ac.cc.getServerName(addr) + + var healthCheckDoneCh <-chan struct{} hctx, hcancel := context.WithCancel(ctx) + defer func() { + // If healthCheckDoneCh is nil, then a health check has not been + // started. Therefore, the health check context can be canceled because + // it is not in use. + if healthCheckDoneCh == nil { + hcancel() + } + }() onClose := func(r transport.GoAwayReason) { + var healthCheckCompleteCh <-chan struct{} + ac.mu.Lock() - defer ac.mu.Unlock() + defer func() { + ac.mu.Unlock() + // If healthCheckCompleteCh is not nil, then hcancel() has been + // called and healthCheckCompleteCh is a copy of healthCheckDoneCh, + // as it was when ac.mu was held. Now wait for the health check to + // complete. + if healthCheckCompleteCh != nil { + <-healthCheckCompleteCh + } + }() + // adjust params based on GoAwayReason ac.adjustParams(r) if ctx.Err() != nil { @@ -1496,6 +1526,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, return } hcancel() + healthCheckCompleteCh = healthCheckDoneCh if ac.transport == nil { // We're still connecting to this address, which could error. Do // not update the connectivity state or resolve; these will happen @@ -1521,7 +1552,6 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, logger.Infof("Creating new client transport to %q: %v", addr, err) } // newTr is either nil, or closed. - hcancel() channelz.Warningf(logger, ac.channelz, "grpc: addrConn.createTransport failed to connect to %s. Err: %v", addr, err) return err } @@ -1556,13 +1586,17 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, } ac.curAddr = addr ac.transport = newTr - ac.startHealthCheck(hctx) // Will set state to READY if appropriate. + healthCheckDoneCh = ac.startHealthCheck(hctx) // Will set state to READY if appropriate. return nil } // startHealthCheck starts the health checking stream (RPC) to watch the health // stats of this connection if health checking is requested and configured. // +// A channel is returned that will be closed once the health check goroutine +// exits after ctx has been canceled, or nil if the health check requirements +// aren't met and no goroutine has been started. +// // LB channel health checking is enabled when all requirements below are met: // 1. it is not disabled by the user with the WithDisableHealthCheck DialOption // 2. internal.HealthCheckFunc is set by importing the grpc/health package @@ -1572,7 +1606,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, // It sets addrConn to READY if the health checking stream is not started. // // Caller must hold ac.mu. -func (ac *addrConn) startHealthCheck(ctx context.Context) { +func (ac *addrConn) startHealthCheck(ctx context.Context) <-chan struct{} { var healthcheckManagingState bool defer func() { if !healthcheckManagingState { @@ -1581,14 +1615,14 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) { }() if ac.cc.dopts.disableHealthCheck { - return + return nil } healthCheckConfig := ac.cc.healthCheckConfig() if healthCheckConfig == nil { - return + return nil } if !ac.scopts.HealthCheckEnabled { - return + return nil } healthCheckFunc := internal.HealthCheckFunc if healthCheckFunc == nil { @@ -1596,7 +1630,7 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) { // // TODO: add a link to the health check doc in the error message. channelz.Error(logger, ac.channelz, "Health check is requested but health check function is not set.") - return + return nil } healthcheckManagingState = true @@ -1621,7 +1655,9 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) { ac.updateConnectivityState(s, lastErr) } // Start the health checking stream. + done := make(chan struct{}) go func() { + defer close(done) err := healthCheckFunc(ctx, newStream, setConnectivityState, healthCheckConfig.ServiceName) if err != nil { if status.Code(err) == codes.Unimplemented { @@ -1631,6 +1667,7 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) { } } }() + return done } func (ac *addrConn) resetConnectBackoff() { diff --git a/clientconn_test.go b/clientconn_test.go index f1e977162d09..b2bdfdd9ccc1 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -614,11 +614,11 @@ func (s) TestResetConnectBackoff(t *testing.T) { default: } }() - dialer := func(string, time.Duration) (net.Conn, error) { + dialer := func(context.Context, string) (net.Conn, error) { dials <- struct{}{} return nil, errors.New("failed to fake dial") } - cc, err := NewClient("passthrough:///", WithTransportCredentials(insecure.NewCredentials()), WithDialer(dialer), withBackoff(backoffForever{})) + cc, err := NewClient("passthrough:///", WithTransportCredentials(insecure.NewCredentials()), WithContextDialer(dialer), withBackoff(backoffForever{})) if err != nil { t.Fatalf("grpc.NewClient() failed with error: %v, want: nil", err) } @@ -647,7 +647,7 @@ func (s) TestResetConnectBackoff(t *testing.T) { func (s) TestBackoffCancel(t *testing.T) { dialStrCh := make(chan string) - cc, err := NewClient("passthrough:///", WithTransportCredentials(insecure.NewCredentials()), WithDialer(func(t string, _ time.Duration) (net.Conn, error) { + cc, err := NewClient("passthrough:///", WithTransportCredentials(insecure.NewCredentials()), WithContextDialer(func(_ context.Context, t string) (net.Conn, error) { dialStrCh <- t return nil, fmt.Errorf("test dialer, always error") })) diff --git a/dial_test.go b/dial_test.go index 7b74aac79fa6..7ebd23a8d19b 100644 --- a/dial_test.go +++ b/dial_test.go @@ -218,11 +218,11 @@ func (s) TestDialContextFailFast(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() failErr := failFastError{} - dialer := func(string, time.Duration) (net.Conn, error) { + dialer := func(context.Context, string) (net.Conn, error) { return nil, failErr } - _, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithTransportCredentials(insecure.NewCredentials()), WithDialer(dialer), FailOnNonTempDialError(true)) + _, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithTransportCredentials(insecure.NewCredentials()), WithContextDialer(dialer), FailOnNonTempDialError(true)) if terr, ok := err.(transport.ConnectionError); !ok || terr.Origin() != failErr { t.Fatalf("DialContext() = _, %v, want _, %v", err, failErr) } diff --git a/internal/testutils/pipe_listener.go b/internal/testutils/pipe_listener.go index 6bd3bc0bea12..8142e17ae197 100644 --- a/internal/testutils/pipe_listener.go +++ b/internal/testutils/pipe_listener.go @@ -20,9 +20,9 @@ package testutils import ( + "context" "errors" "net" - "time" ) var errClosed = errors.New("closed") @@ -78,14 +78,16 @@ func (p *PipeListener) Addr() net.Addr { return pipeAddr{} } -// Dialer dials a connection. -func (p *PipeListener) Dialer() func(string, time.Duration) (net.Conn, error) { - return func(string, time.Duration) (net.Conn, error) { +// ContextDialer dials a connection using a context. +func (p *PipeListener) ContextDialer() func(context.Context, string) (net.Conn, error) { + return func(ctx context.Context, _ string) (net.Conn, error) { connChan := make(chan net.Conn) select { case p.c <- connChan: case <-p.done: return nil, errClosed + case <-ctx.Done(): + return nil, ctx.Err() } conn, ok := <-connChan if !ok { diff --git a/internal/testutils/pipe_listener_test.go b/internal/testutils/pipe_listener_test.go index 45cc27e97868..15a86246abed 100644 --- a/internal/testutils/pipe_listener_test.go +++ b/internal/testutils/pipe_listener_test.go @@ -19,6 +19,7 @@ package testutils_test import ( + "context" "testing" "time" @@ -26,6 +27,8 @@ import ( "google.golang.org/grpc/internal/testutils" ) +const defaultTestTimeout = 10 * time.Second + type s struct { grpctest.Tester } @@ -53,8 +56,10 @@ func (s) TestPipeListener(t *testing.T) { recvdBytes <- read }() - dl := pl.Dialer() - conn, err := dl("", time.Duration(0)) + dl := pl.ContextDialer() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := dl(ctx, "") if err != nil { t.Fatal(err) } @@ -85,8 +90,10 @@ func (s) TestUnblocking(t *testing.T) { { desc: "Accept unblocks Dial", blockFunc: func(pl *testutils.PipeListener, done chan struct{}) error { - dl := pl.Dialer() - _, err := dl("", time.Duration(0)) + dl := pl.ContextDialer() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + _, err := dl(ctx, "") close(done) return err }, @@ -99,8 +106,10 @@ func (s) TestUnblocking(t *testing.T) { desc: "Close unblocks Dial", blockFuncShouldError: true, // because pl.Close will be called blockFunc: func(pl *testutils.PipeListener, done chan struct{}) error { - dl := pl.Dialer() - _, err := dl("", time.Duration(0)) + dl := pl.ContextDialer() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + _, err := dl(ctx, "") close(done) return err }, @@ -116,8 +125,10 @@ func (s) TestUnblocking(t *testing.T) { return err }, unblockFunc: func(pl *testutils.PipeListener) error { - dl := pl.Dialer() - _, err := dl("", time.Duration(0)) + dl := pl.ContextDialer() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + _, err := dl(ctx, "") return err }, }, diff --git a/test/clientconn_state_transition_test.go b/test/clientconn_state_transition_test.go index 3d365bd8b378..0c9eb17a5f1e 100644 --- a/test/clientconn_state_transition_test.go +++ b/test/clientconn_state_transition_test.go @@ -188,7 +188,7 @@ func testStateTransitionSingleAddress(t *testing.T, wantStates []connectivity.St dopts := []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithDialer(pl.Dialer()), + grpc.WithContextDialer(pl.ContextDialer()), grpc.WithConnectParams(grpc.ConnectParams{ Backoff: backoff.Config{}, MinConnectTimeout: 100 * time.Millisecond, diff --git a/test/end2end_test.go b/test/end2end_test.go index 3859881f3f56..818f707aa2b9 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -408,7 +408,7 @@ type env struct { security string // The security protocol such as TLS, SSH, etc. httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS balancer string // One of "round_robin", "pick_first", or "". - customDialer func(string, string, time.Duration) (net.Conn, error) + customDialer func(context.Context, string, string) (net.Conn, error) } func (e env) runnable() bool { @@ -418,11 +418,12 @@ func (e env) runnable() bool { return true } -func (e env) dialer(addr string, timeout time.Duration) (net.Conn, error) { +func (e env) dialer(ctx context.Context, addr string) (net.Conn, error) { if e.customDialer != nil { - return e.customDialer(e.network, addr, timeout) + return e.customDialer(ctx, e.network, addr) } - return net.DialTimeout(e.network, addr, timeout) + d := net.Dialer{} + return d.DialContext(ctx, e.network, addr) } var ( @@ -759,7 +760,7 @@ func (d *nopDecompressor) Type() string { } func (te *test) configDial(opts ...grpc.DialOption) ([]grpc.DialOption, string) { - opts = append(opts, grpc.WithDialer(te.e.dialer), grpc.WithUserAgent(te.userAgent)) + opts = append(opts, grpc.WithContextDialer(te.e.dialer), grpc.WithUserAgent(te.userAgent)) if te.clientCompression { opts = append(opts, @@ -839,7 +840,7 @@ func (te *test) clientConnWithConnControl() (*grpc.ClientConn, *dialerWrapper) { opts, scheme := te.configDial() dw := &dialerWrapper{} // overwrite the dialer before - opts = append(opts, grpc.WithDialer(dw.dialer)) + opts = append(opts, grpc.WithContextDialer(dw.dialer)) var err error te.cc, err = grpc.NewClient(scheme+te.srvAddr, opts...) if err != nil { @@ -868,7 +869,9 @@ func (te *test) declareLogNoise(phrases ...string) { } func (te *test) withServerTester(fn func(st *serverTester)) { - c, err := te.e.dialer(te.srvAddr, 10*time.Second) + ctx, cancel := context.WithTimeout(te.ctx, 10*time.Second) + defer cancel() + c, err := te.e.dialer(ctx, te.srvAddr) if err != nil { te.t.Fatal(err) } @@ -925,8 +928,9 @@ func (l *lazyConn) Write(b []byte) (int, error) { func (s) TestContextDeadlineNotIgnored(t *testing.T) { e := noBalancerEnv var lc *lazyConn - e.customDialer = func(network, addr string, timeout time.Duration) (net.Conn, error) { - conn, err := net.DialTimeout(network, addr, timeout) + e.customDialer = func(ctx context.Context, network, addr string) (net.Conn, error) { + d := net.Dialer{} + conn, err := d.DialContext(ctx, network, addr) if err != nil { return nil, err } @@ -6184,7 +6188,7 @@ func (s) TestNetPipeConn(t *testing.T) { go s.Serve(pl) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc, err := grpc.NewClient("passthrough:///", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDialer(pl.Dialer())) + cc, err := grpc.NewClient("passthrough:///", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(pl.ContextDialer())) if err != nil { t.Fatalf("Error creating client: %v", err) } diff --git a/test/local_creds_test.go b/test/local_creds_test.go index 13dd4ae53808..300a54c5f59e 100644 --- a/test/local_creds_test.go +++ b/test/local_creds_test.go @@ -148,9 +148,10 @@ func (l *lisWrapper) Accept() (net.Conn, error) { return connWrapper{c, l.remote}, nil } -func spoofDialer(addr net.Addr) func(target string, t time.Duration) (net.Conn, error) { - return func(t string, d time.Duration) (net.Conn, error) { - c, err := net.DialTimeout("tcp", t, d) +func spoofDialer(addr net.Addr) func(ctx context.Context, target string) (net.Conn, error) { + return func(ctx context.Context, t string) (net.Conn, error) { + d := net.Dialer{} + c, err := d.DialContext(ctx, "tcp", t) if err != nil { return nil, err } @@ -182,7 +183,7 @@ func testLocalCredsE2EFail(t *testing.T, dopts []grpc.DialOption) error { stubserver.StartTestService(t, ss) defer ss.S.Stop() - cc, err := grpc.NewClient(lis.Addr().String(), append(dopts, grpc.WithDialer(spoofDialer(fakeServerAddr)))...) + cc, err := grpc.NewClient(lis.Addr().String(), append(dopts, grpc.WithContextDialer(spoofDialer(fakeServerAddr)))...) if err != nil { return fmt.Errorf("Failed to dial server: %v, %v", err, lis.Addr().String()) } diff --git a/test/rawConnWrapper.go b/test/rawConnWrapper.go index 4928056ebe53..317754ece80c 100644 --- a/test/rawConnWrapper.go +++ b/test/rawConnWrapper.go @@ -18,12 +18,12 @@ package test import ( "bytes" + "context" "fmt" "io" "net" "strings" "sync" - "time" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" @@ -67,8 +67,9 @@ type dialerWrapper struct { rcw *rawConnWrapper } -func (d *dialerWrapper) dialer(target string, t time.Duration) (net.Conn, error) { - c, err := net.DialTimeout("tcp", target, t) +func (d *dialerWrapper) dialer(ctx context.Context, target string) (net.Conn, error) { + dialer := net.Dialer{} + c, err := dialer.DialContext(ctx, "tcp", target) d.c = c d.rcw = newRawConnWrapperFromConn(c) return c, err