Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions balancer_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ type acBalancerWrapper struct {
// dropped or updated. This is required as closures can't be compared for
// equality.
healthData *healthData

shutdownMu sync.Mutex
shutdownCh chan struct{}
activeGofuncs sync.WaitGroup
}

// healthData holds data related to health state reporting.
Expand Down Expand Up @@ -343,16 +347,45 @@ func (acbw *acBalancerWrapper) String() string {
}

func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
acbw.ac.updateAddrs(addrs)
acbw.goFunc(func(shutdown <-chan struct{}) {
acbw.ac.updateAddrs(shutdown, addrs)
})
}

func (acbw *acBalancerWrapper) Connect() {
go acbw.ac.connect()
acbw.goFunc(acbw.ac.connect)
}

// goFunc runs fn on a new goroutine that is tracked by acbw.activeGofuncs.
//
// fn is handed the wrapper's current shutdown channel; the channel is created
// lazily here when none exists yet. Shutdown closes that channel and then
// waits on activeGofuncs, so fn is expected to return once the channel it was
// given is closed.
func (acbw *acBalancerWrapper) goFunc(fn func(shutdown <-chan struct{})) {
	acbw.shutdownMu.Lock()
	defer acbw.shutdownMu.Unlock()

	// Allocate the shutdown channel on first use (or on first use after a
	// previous Shutdown reset it to nil).
	if acbw.shutdownCh == nil {
		acbw.shutdownCh = make(chan struct{})
	}
	stop := acbw.shutdownCh

	// Add runs while shutdownMu is held; Shutdown also holds shutdownMu across
	// its close-and-Wait, so the two cannot interleave.
	acbw.activeGofuncs.Add(1)
	go func(stop <-chan struct{}) {
		defer acbw.activeGofuncs.Done()
		fn(stop)
	}(stop)
}

// Shutdown closes the wrapper's producers, removes its addrConn from the
// ClientConn with errConnDrain, and then stops every goroutine that was
// started through goFunc.
//
// Stopping works by closing the current shutdown channel (if one was ever
// created) and blocking until activeGofuncs reports that all tracked
// goroutines have returned. shutdownMu stays held for the whole wait, so any
// concurrent goFunc call blocks until Shutdown has finished. The channel field
// is reset to nil, which makes a repeated Shutdown a no-op for this step.
func (acbw *acBalancerWrapper) Shutdown() {
	acbw.closeProducers()
	acbw.ccb.cc.removeAddrConn(acbw.ac, errConnDrain)

	acbw.shutdownMu.Lock()
	defer acbw.shutdownMu.Unlock()

	// No channel means no goroutine has been started via goFunc since the last
	// Shutdown, so there is nothing to signal or wait for.
	if acbw.shutdownCh == nil {
		return
	}

	stop := acbw.shutdownCh
	acbw.shutdownCh = nil
	close(stop)
	acbw.activeGofuncs.Wait()
}

// NewStream begins a streaming RPC on the addrConn. If the addrConn is not
Expand Down
75 changes: 56 additions & 19 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ func (cc *ClientConn) incrCallsFailed() {
// connect starts creating a transport.
// It does nothing if the ac is not IDLE.
// TODO(bar) Move this to the addrConn section.
func (ac *addrConn) connect() {
func (ac *addrConn) connect(abort <-chan struct{}) {
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
if logger.V(2) {
Expand All @@ -994,7 +994,7 @@ func (ac *addrConn) connect() {
return
}

ac.resetTransportAndUnlock()
ac.resetTransportAndUnlock(abort)
}

// equalAddressesIgnoringBalAttributes returns true if a and b are considered equal.
Expand All @@ -1013,7 +1013,7 @@ func equalAddressesIgnoringBalAttributes(a, b []resolver.Address) bool {

// updateAddrs updates ac.addrs with the new addresses list and handles active
// connections or connection attempts.
func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
func (ac *addrConn) updateAddrs(abort <-chan struct{}, addrs []resolver.Address) {
addrs = copyAddresses(addrs)
limit := len(addrs)
if limit > 5 {
Expand Down Expand Up @@ -1069,7 +1069,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {

// Since we were connecting/connected, we should start a new connection
// attempt.
go ac.resetTransportAndUnlock()
ac.resetTransportAndUnlock(abort)
}

// getServerName determines the serverName to be used in the connection
Expand Down Expand Up @@ -1315,9 +1315,17 @@ func (ac *addrConn) adjustParams(r transport.GoAwayReason) {
// resetTransportAndUnlock unconditionally connects the addrConn.
//
// ac.mu must be held by the caller, and this function will guarantee it is released.
func (ac *addrConn) resetTransportAndUnlock() {
acCtx := ac.ctx
if acCtx.Err() != nil {
func (ac *addrConn) resetTransportAndUnlock(abort <-chan struct{}) {
ctx, cancel := context.WithCancel(ac.ctx)
go func() {
select {
case <-abort:
cancel()
case <-ctx.Done():
}
}()

if ctx.Err() != nil {
ac.mu.Unlock()
return
}
Expand Down Expand Up @@ -1345,7 +1353,7 @@ func (ac *addrConn) resetTransportAndUnlock() {
ac.updateConnectivityState(connectivity.Connecting, nil)
ac.mu.Unlock()

if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil {
if err := ac.tryAllAddrs(ctx, addrs, connectDeadline); err != nil {
if !errors.Is(err, context.Canceled) {
connectionAttemptsFailedMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.localityLabel)
} else {
Expand All @@ -1359,7 +1367,7 @@ func (ac *addrConn) resetTransportAndUnlock() {
// to ensure one resolution request per pass instead of per subconn failure.
ac.cc.resolveNow(resolver.ResolveNowOptions{})
ac.mu.Lock()
if acCtx.Err() != nil {
if ctx.Err() != nil {
// addrConn was torn down.
ac.mu.Unlock()
return
Expand All @@ -1380,13 +1388,13 @@ func (ac *addrConn) resetTransportAndUnlock() {
ac.mu.Unlock()
case <-b:
timer.Stop()
case <-acCtx.Done():
case <-ctx.Done():
timer.Stop()
return
}

ac.mu.Lock()
if acCtx.Err() == nil {
if ctx.Err() == nil {
ac.updateConnectivityState(connectivity.Idle, err)
}
ac.mu.Unlock()
Expand Down Expand Up @@ -1481,11 +1489,33 @@ func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, c
// new transport.
func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error {
addr.ServerName = ac.cc.getServerName(addr)

var healthCheckDoneCh <-chan struct{}
hctx, hcancel := context.WithCancel(ctx)
defer func() {
// If healthCheckDoneCh is nil, then a health check has not been
// started. Therefore, the health check context can be canceled because
// it is not in use.
if healthCheckDoneCh == nil {
hcancel()
}
}()

onClose := func(r transport.GoAwayReason) {
var healthCheckCompleteCh <-chan struct{}

ac.mu.Lock()
defer ac.mu.Unlock()
defer func() {
ac.mu.Unlock()
// If healthCheckCompleteCh is not nil, then hcancel() has been
// called and healthCheckCompleteCh is a copy of healthCheckDoneCh,
// as it was when ac.mu was held. Now wait for the health check to
// complete.
if healthCheckCompleteCh != nil {
<-healthCheckCompleteCh
}
}()

// adjust params based on GoAwayReason
ac.adjustParams(r)
if ctx.Err() != nil {
Expand All @@ -1496,6 +1526,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
return
}
hcancel()
healthCheckCompleteCh = healthCheckDoneCh
if ac.transport == nil {
// We're still connecting to this address, which could error. Do
// not update the connectivity state or resolve; these will happen
Expand All @@ -1521,7 +1552,6 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
logger.Infof("Creating new client transport to %q: %v", addr, err)
}
// newTr is either nil, or closed.
hcancel()
channelz.Warningf(logger, ac.channelz, "grpc: addrConn.createTransport failed to connect to %s. Err: %v", addr, err)
return err
}
Expand Down Expand Up @@ -1556,13 +1586,17 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
}
ac.curAddr = addr
ac.transport = newTr
ac.startHealthCheck(hctx) // Will set state to READY if appropriate.
healthCheckDoneCh = ac.startHealthCheck(hctx) // Will set state to READY if appropriate.
return nil
}

// startHealthCheck starts the health checking stream (RPC) to watch the health
// stats of this connection if health checking is requested and configured.
//
// A channel is returned that will be closed once the health check goroutine
// exits after ctx has been canceled, or nil if the health check requirements
// aren't met and no goroutine has been started.
//
// LB channel health checking is enabled when all requirements below are met:
// 1. it is not disabled by the user with the WithDisableHealthCheck DialOption
// 2. internal.HealthCheckFunc is set by importing the grpc/health package
Expand All @@ -1572,7 +1606,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
// It sets addrConn to READY if the health checking stream is not started.
//
// Caller must hold ac.mu.
func (ac *addrConn) startHealthCheck(ctx context.Context) {
func (ac *addrConn) startHealthCheck(ctx context.Context) <-chan struct{} {
var healthcheckManagingState bool
defer func() {
if !healthcheckManagingState {
Expand All @@ -1581,22 +1615,22 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
}()

if ac.cc.dopts.disableHealthCheck {
return
return nil
}
healthCheckConfig := ac.cc.healthCheckConfig()
if healthCheckConfig == nil {
return
return nil
}
if !ac.scopts.HealthCheckEnabled {
return
return nil
}
healthCheckFunc := internal.HealthCheckFunc
if healthCheckFunc == nil {
// The health package is not imported to set health check function.
//
// TODO: add a link to the health check doc in the error message.
channelz.Error(logger, ac.channelz, "Health check is requested but health check function is not set.")
return
return nil
}

healthcheckManagingState = true
Expand All @@ -1621,7 +1655,9 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
ac.updateConnectivityState(s, lastErr)
}
// Start the health checking stream.
done := make(chan struct{})
go func() {
defer close(done)
err := healthCheckFunc(ctx, newStream, setConnectivityState, healthCheckConfig.ServiceName)
if err != nil {
if status.Code(err) == codes.Unimplemented {
Expand All @@ -1631,6 +1667,7 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
}
}
}()
return done
}

func (ac *addrConn) resetConnectBackoff() {
Expand Down
6 changes: 3 additions & 3 deletions clientconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,11 +614,11 @@ func (s) TestResetConnectBackoff(t *testing.T) {
default:
}
}()
dialer := func(string, time.Duration) (net.Conn, error) {
dialer := func(context.Context, string) (net.Conn, error) {
dials <- struct{}{}
return nil, errors.New("failed to fake dial")
}
cc, err := NewClient("passthrough:///", WithTransportCredentials(insecure.NewCredentials()), WithDialer(dialer), withBackoff(backoffForever{}))
cc, err := NewClient("passthrough:///", WithTransportCredentials(insecure.NewCredentials()), WithContextDialer(dialer), withBackoff(backoffForever{}))
if err != nil {
t.Fatalf("grpc.NewClient() failed with error: %v, want: nil", err)
}
Expand Down Expand Up @@ -647,7 +647,7 @@ func (s) TestResetConnectBackoff(t *testing.T) {

func (s) TestBackoffCancel(t *testing.T) {
dialStrCh := make(chan string)
cc, err := NewClient("passthrough:///", WithTransportCredentials(insecure.NewCredentials()), WithDialer(func(t string, _ time.Duration) (net.Conn, error) {
cc, err := NewClient("passthrough:///", WithTransportCredentials(insecure.NewCredentials()), WithContextDialer(func(_ context.Context, t string) (net.Conn, error) {
dialStrCh <- t
return nil, fmt.Errorf("test dialer, always error")
}))
Expand Down
4 changes: 2 additions & 2 deletions dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,11 @@ func (s) TestDialContextFailFast(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
failErr := failFastError{}
dialer := func(string, time.Duration) (net.Conn, error) {
dialer := func(context.Context, string) (net.Conn, error) {
return nil, failErr
}

_, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithTransportCredentials(insecure.NewCredentials()), WithDialer(dialer), FailOnNonTempDialError(true))
_, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithTransportCredentials(insecure.NewCredentials()), WithContextDialer(dialer), FailOnNonTempDialError(true))
if terr, ok := err.(transport.ConnectionError); !ok || terr.Origin() != failErr {
t.Fatalf("DialContext() = _, %v, want _, %v", err, failErr)
}
Expand Down
10 changes: 6 additions & 4 deletions internal/testutils/pipe_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
package testutils

import (
"context"
"errors"
"net"
"time"
)

var errClosed = errors.New("closed")
Expand Down Expand Up @@ -78,14 +78,16 @@ func (p *PipeListener) Addr() net.Addr {
return pipeAddr{}
}

// Dialer dials a connection.
func (p *PipeListener) Dialer() func(string, time.Duration) (net.Conn, error) {
return func(string, time.Duration) (net.Conn, error) {
// ContextDialer dials a connection using a context.
func (p *PipeListener) ContextDialer() func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, _ string) (net.Conn, error) {
connChan := make(chan net.Conn)
select {
case p.c <- connChan:
case <-p.done:
return nil, errClosed
case <-ctx.Done():
return nil, ctx.Err()
}
conn, ok := <-connChan
if !ok {
Expand Down
Loading