diff --git a/connection.go b/connection.go index 992b494..a7a4875 100644 --- a/connection.go +++ b/connection.go @@ -102,10 +102,10 @@ type ConnectionHandler interface { OnConnect(context.Context, *Connection, GenericClient, *Server) error // OnConnectError is called whenever there is an error during connection. - OnConnectError(err error, reconnectThrottleDuration time.Duration) + OnConnectError(ctx context.Context, err error, reconnectThrottleDuration time.Duration) // OnDoCommandError is called whenever there is an error during DoCommand - OnDoCommandError(err error, nextTime time.Duration) + OnDoCommandError(ctx context.Context, err error, nextTime time.Duration) // OnDisconnected is called whenever the connection notices it // is disconnected. @@ -115,12 +115,12 @@ type ConnectionHandler interface { // an RPC function passed to Connection.DoCommand(), and // should return whether or not that error signifies that that // RPC should retried (with backoff) - ShouldRetry(name string, err error) bool + ShouldRetry(ctx context.Context, name string, err error) bool // ShouldRetryOnConnect is called whenever an error is returned // during connection establishment, and should return whether or // not the connection should be established again. - ShouldRetryOnConnect(err error) bool + ShouldRetryOnConnect(ctx context.Context, err error) bool // HandlerName returns a string representing the type of the connection // handler. @@ -405,12 +405,14 @@ func (c *Connection) DoCommand(ctx context.Context, name string, // immediately when ctx is canceled. will // retry connectivity errors w/backoff. throttleErr := rpcFunc(rawClient) - if throttleErr != nil && c.handler.ShouldRetry(name, throttleErr) { + if throttleErr != nil && c.handler.ShouldRetry(ctx, name, throttleErr) { return throttleErr } rpcErr = throttleErr return nil - }, c.doCommandBackoff, c.handler.OnDoCommandError) + }, c.doCommandBackoff, func(err error, nextTime time.Duration) { + c.handler.OnDoCommandError(ctx, err, nextTime) + }) // RetryNotify gave up. if throttleErr != nil { @@ -506,7 +508,7 @@ func (c *Connection) doReconnect(ctx context.Context, disconnectStatus Disconnec return nil default: } - if !c.handler.ShouldRetryOnConnect(err) { + if !c.handler.ShouldRetryOnConnect(ctx, err) { // A fatal error happened. *reconnectErrPtr = err // short-circuit Retry @@ -515,7 +517,9 @@ func (c *Connection) doReconnect(ctx context.Context, disconnectStatus Disconnec return err }, c.reconnectBackoff, // give the caller a chance to log any other error or adjust state - c.handler.OnConnectError) + func(err error, reconnectThrottleDuration time.Duration) { + c.handler.OnConnectError(ctx, err, reconnectThrottleDuration) + }) if err != nil { // this shouldn't happen, but just in case. diff --git a/connection_test.go b/connection_test.go index 658b9bf..ce7fa27 100644 --- a/connection_test.go +++ b/connection_test.go @@ -32,12 +32,12 @@ func (ut *unitTester) OnConnect(context.Context, *Connection, GenericClient, *Se } // OnConnectError implements the ConnectionHandler interface. -func (ut *unitTester) OnConnectError(error, time.Duration) { +func (ut *unitTester) OnConnectError(context.Context, error, time.Duration) { ut.numConnectErrors++ } // OnDoCommandError implements the ConnectionHandler interace -func (ut *unitTester) OnDoCommandError(error, time.Duration) { +func (ut *unitTester) OnDoCommandError(context.Context, error, time.Duration) { } // OnDisconnected implements the ConnectionHandler interface. @@ -46,7 +46,7 @@ func (ut *unitTester) OnDisconnected(context.Context, DisconnectStatus) { } // ShouldRetry implements the ConnectionHandler interface. -func (ut *unitTester) ShouldRetry(name string, err error) bool { +func (ut *unitTester) ShouldRetry(ctx context.Context, name string, err error) bool { _, isThrottle := err.(throttleError) return isThrottle } @@ -54,7 +54,7 @@ func (ut *unitTester) ShouldRetry(name string, err error) bool { var errCanceled = errors.New("Canceled!") // ShouldRetryOnConnect implements the ConnectionHandler interface. -func (ut *unitTester) ShouldRetryOnConnect(err error) bool { +func (ut *unitTester) ShouldRetryOnConnect(cxt context.Context, err error) bool { return err != errCanceled } diff --git a/connection_test_util.go b/connection_test_util.go index c222e50..9e244c8 100644 --- a/connection_test_util.go +++ b/connection_test_util.go @@ -18,20 +18,20 @@ func (testConnectionHandler) OnConnect(context.Context, *Connection, GenericClie return nil } -func (testConnectionHandler) OnConnectError(err error, reconnectThrottleDuration time.Duration) { +func (testConnectionHandler) OnConnectError(ctx context.Context, err error, reconnectThrottleDuration time.Duration) { } -func (testConnectionHandler) OnDoCommandError(err error, nextTime time.Duration) { +func (testConnectionHandler) OnDoCommandError(ctx context.Context, err error, nextTime time.Duration) { } func (testConnectionHandler) OnDisconnected(ctx context.Context, status DisconnectStatus) { } -func (testConnectionHandler) ShouldRetry(name string, err error) bool { +func (testConnectionHandler) ShouldRetry(ctx context.Context, name string, err error) bool { return false } -func (testConnectionHandler) ShouldRetryOnConnect(err error) bool { +func (testConnectionHandler) ShouldRetryOnConnect(ctx context.Context, err error) bool { return false }