Skip to content

Commit

Permalink
allow disabling the default golang database retry behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
dvilaverde committed Jul 22, 2024
1 parent 8d0b3e3 commit 9fc4fd7
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 9 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,18 @@ golang's [ParseDuration](https://pkg.go.dev/time#ParseDuration) format.
| --------- | --------- | ----------------------------------------------- |
| duration | 0 | user:pass@localhost/mydb?writeTimeout=1m30s |

#### `retries`

Allows disabling the golang `database/sql` default behavior to retry errors
when `ErrBadConn` is returned by the driver. When retries are disabled
this driver will not return `ErrBadConn` from the `database/sql` package.

Valid values are `on` (default) and `off`.

| Type | Default | Example |
| --------- | --------- | ----------------------------------------------- |
| string | on | user:pass@localhost/mydb?retries=off |

### Custom Driver Options

The driver package exposes the function `SetDSNOptions`, allowing for modification of the
Expand Down
49 changes: 40 additions & 9 deletions driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ func parseDSN(dsn string) (connInfo, error) {
// Open takes a supplied DSN string and opens a connection
// See ParseDSN for more information on the form of the DSN
func (d driver) Open(dsn string) (sqldriver.Conn, error) {
var c *client.Conn
var (
c *client.Conn
// by default database/sql driver retries will be enabled
retries = true
)

ci, err := parseDSN(dsn)

Expand Down Expand Up @@ -134,6 +138,10 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
if timeout, err = time.ParseDuration(value[0]); err != nil {
return nil, errors.Wrap(err, "invalid duration value for timeout option")
}
} else if key == "retries" && len(value) > 0 {
// by default keep the golang database/sql retry behavior enabled unless
// the retries driver option is explicitly set to 'false'
retries = !strings.EqualFold(value[0], "off")
} else {
if option, ok := options[key]; ok {
opt := func(o DriverOption, v string) client.Option {
Expand Down Expand Up @@ -161,15 +169,27 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
return nil, err
}

return &conn{c}, nil
// if retries is true then return sqldriver.ErrBadConn which will trigger up to 3
// retries by the database/sql package. If retries are disabled then we'll return
// the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry.
// In this case the sqldriver.Validator interface is implemented and will return
// false for IsValid() signaling that the connection is bad and should be discarded.
return &conn{Conn: c, state: &state{valid: true, useDriverErrors: !retries}}, nil
}

type CheckNamedValueFunc func(*sqldriver.NamedValue) error

var _ sqldriver.NamedValueChecker = &conn{}
var _ sqldriver.Validator = &conn{}

type state struct {
valid bool
useDriverErrors bool
}

type conn struct {
*client.Conn
state *state
}

func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error {
Expand All @@ -190,13 +210,17 @@ func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error {
return sqldriver.ErrSkip
}

func (c *conn) IsValid() bool {
return c.state.valid
}

func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
st, err := c.Conn.Prepare(query)
if err != nil {
return nil, errors.Trace(err)
}

return &stmt{st}, nil
return &stmt{Stmt: st, connectionState: c.state}, nil
}

func (c *conn) Close() error {
Expand All @@ -222,10 +246,16 @@ func buildArgs(args []sqldriver.Value) []interface{} {
return a
}

func replyError(err error) error {
if mysql.ErrorEqual(err, mysql.ErrBadConn) {
func (st *state) replyError(err error) error {
isBadConnection := mysql.ErrorEqual(err, mysql.ErrBadConn)

if !st.useDriverErrors && isBadConnection {
return sqldriver.ErrBadConn
} else {
// if we have a bad connection, this mark the state of this connection as not valid
// do the database/sql package can discard it instead of placing it back in the
// sql.DB pool.
st.valid = !isBadConnection
return errors.Trace(err)
}
}
Expand All @@ -234,7 +264,7 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err
a := buildArgs(args)
r, err := c.Conn.Execute(query, a...)
if err != nil {
return nil, replyError(err)
return nil, c.state.replyError(err)
}
return &result{r}, nil
}
Expand All @@ -243,13 +273,14 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro
a := buildArgs(args)
r, err := c.Conn.Execute(query, a...)
if err != nil {
return nil, replyError(err)
return nil, c.state.replyError(err)
}
return newRows(r.Resultset)
}

type stmt struct {
*client.Stmt
connectionState *state
}

func (s *stmt) Close() error {
Expand All @@ -264,7 +295,7 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
a := buildArgs(args)
r, err := s.Stmt.Execute(a...)
if err != nil {
return nil, replyError(err)
return nil, s.connectionState.replyError(err)
}
return &result{r}, nil
}
Expand All @@ -273,7 +304,7 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
a := buildArgs(args)
r, err := s.Stmt.Execute(a...)
if err != nil {
return nil, replyError(err)
return nil, s.connectionState.replyError(err)
}
return newRows(r.Resultset)
}
Expand Down
62 changes: 62 additions & 0 deletions driver/driver_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,51 @@ type testServer struct {
}

type mockHandler struct {
// the number of times a query executed
queryCount int
}

func TestDriverOptions_SetRetriesOn(t *testing.T) {
log.SetLevel(log.LevelDebug)
srv := CreateMockServer(t)
defer srv.Stop()

conn, err := sql.Open("mysql", "[email protected]:3307/test?readTimeout=1s")
defer func() {
_ = conn.Close()
}()
require.NoError(t, err)

rows, err := conn.QueryContext(context.TODO(), "select * from slow;")
require.Nil(t, rows)

// we want to get a golang database/sql/driver ErrBadConn
require.ErrorIs(t, err, sqlDriver.ErrBadConn)

// here we issue assert that even though we only issued 1 query, that the retries
// remained on and there were 3 calls to the DB.
require.Equal(t, 3, srv.handler.queryCount)
}

func TestDriverOptions_SetRetriesOff(t *testing.T) {
log.SetLevel(log.LevelDebug)
srv := CreateMockServer(t)
defer srv.Stop()

conn, err := sql.Open("mysql", "[email protected]:3307/test?readTimeout=1s&retries=off")
defer func() {
_ = conn.Close()
}()
require.NoError(t, err)

rows, err := conn.QueryContext(context.TODO(), "select * from slow;")
require.Nil(t, rows)
// we want the native error from this driver implementation
require.ErrorIs(t, err, mysql.ErrBadConn)

// here we issue assert that even though we only issued 1 query, that the retries
// remained on and there were 3 calls to the DB.
require.Equal(t, 1, srv.handler.queryCount)
}

func TestDriverOptions_SetCollation(t *testing.T) {
Expand Down Expand Up @@ -65,6 +110,9 @@ func TestDriverOptions_ConnectTimeout(t *testing.T) {
defer srv.Stop()

conn, err := sql.Open("mysql", "[email protected]:3307/test?timeout=1s")
defer func() {
_ = conn.Close()
}()
require.NoError(t, err)

rows, err := conn.QueryContext(context.TODO(), "select * from table;")
Expand All @@ -88,6 +136,9 @@ func TestDriverOptions_BufferSize(t *testing.T) {
})

conn, err := sql.Open("mysql", "[email protected]:3307/test?bufferSize=4096")
defer func() {
_ = conn.Close()
}()
require.NoError(t, err)

rows, err := conn.QueryContext(context.TODO(), "select * from table;")
Expand All @@ -103,6 +154,9 @@ func TestDriverOptions_ReadTimeout(t *testing.T) {
defer srv.Stop()

conn, err := sql.Open("mysql", "[email protected]:3307/test?readTimeout=1s")
defer func() {
_ = conn.Close()
}()
require.NoError(t, err)

rows, err := conn.QueryContext(context.TODO(), "select * from slow;")
Expand Down Expand Up @@ -134,11 +188,15 @@ func TestDriverOptions_writeTimeout(t *testing.T) {
require.Contains(t, err.Error(), "missing unit in duration")
require.Error(t, err)
require.Nil(t, result)
require.NoError(t, conn.Close())

// use an almost zero (1ns) writeTimeout to ensure the insert statement
// can't write before the timeout. Just want to make sure ExecContext()
// will throw an error.
conn, err = sql.Open("mysql", "[email protected]:3307/test?writeTimeout=1ns")
defer func() {
_ = conn.Close()
}()
require.NoError(t, err)

// ExecContext() should fail due to the write timeout of 1ns
Expand All @@ -165,6 +223,9 @@ func TestDriverOptions_namedValueChecker(t *testing.T) {
srv := CreateMockServer(t)
defer srv.Stop()
conn, err := sql.Open("mysql", "[email protected]:3307/test?writeTimeout=1s")
defer func() {
_ = conn.Close()
}()
require.NoError(t, err)
defer conn.Close()

Expand Down Expand Up @@ -248,6 +309,7 @@ func (h *mockHandler) UseDB(dbName string) error {
}

func (h *mockHandler) handleQuery(query string, binary bool, args []interface{}) (*mysql.Result, error) {
h.queryCount++
ss := strings.Split(query, " ")
switch strings.ToLower(ss[0]) {
case "select":
Expand Down

0 comments on commit 9fc4fd7

Please sign in to comment.