diff --git a/client/auth.go b/client/auth.go index fc65332e0..cb479f3ba 100644 --- a/client/auth.go +++ b/client/auth.go @@ -304,7 +304,7 @@ func (c *Conn) writeAuthHandshake() error { } currentSequence := c.Sequence - c.Conn = packet.NewConn(tlsConn) + c.Conn = packet.NewBufferedConn(tlsConn, c.BufferSize) c.Sequence = currentSequence } diff --git a/client/conn.go b/client/conn.go index c7be06b85..1358d5f52 100644 --- a/client/conn.go +++ b/client/conn.go @@ -18,6 +18,8 @@ import ( "github.com/go-mysql-org/go-mysql/utils" ) +const defaultBufferSize = 65536 // 64kb + type Option func(*Conn) error type Conn struct { @@ -33,6 +35,9 @@ type Conn struct { ReadTimeout time.Duration WriteTimeout time.Duration + // The buffer size to use in the packet connection + BufferSize int + serverVersion string // server capabilities capability uint32 @@ -94,6 +99,7 @@ type Dialer func(ctx context.Context, network, address string) (net.Conn, error) func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbName string, dialer Dialer, options ...Option) (*Conn, error) { c := new(Conn) + c.BufferSize = defaultBufferSize c.attributes = map[string]string{ "_client_name": "go-mysql", // "_client_version": "0.1", @@ -129,7 +135,7 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam } } - c.Conn = packet.NewConnWithTimeout(conn, c.ReadTimeout, c.WriteTimeout) + c.Conn = packet.NewConnWithTimeout(conn, c.ReadTimeout, c.WriteTimeout, c.BufferSize) if c.tlsConfig != nil { seq := c.Conn.Sequence c.Conn = packet.NewTLSConnWithTimeout(conn, c.ReadTimeout, c.WriteTimeout) diff --git a/driver/driver_options_test.go b/driver/driver_options_test.go index b7be7672f..e0a9820a8 100644 --- a/driver/driver_options_test.go +++ b/driver/driver_options_test.go @@ -8,6 +8,7 @@ import ( "math" "net" "reflect" + "strconv" "strings" "testing" "time" @@ -73,6 +74,29 @@ func TestDriverOptions_ConnectTimeout(t *testing.T) { conn.Close() } +func TestDriverOptions_BufferSize(t *testing.T) { + log.SetLevel(log.LevelDebug) + srv := CreateMockServer(t) + defer srv.Stop() + + SetDSNOptions(map[string]DriverOption{ + "bufferSize": func(c *client.Conn, value string) error { + var err error + c.BufferSize, err = strconv.Atoi(value) + return err + }, + }) + + conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?bufferSize=4096") + require.NoError(t, err) + + rows, err := conn.QueryContext(context.TODO(), "select * from table;") + require.NotNil(t, rows) + require.NoError(t, err) + + conn.Close() +} + func TestDriverOptions_ReadTimeout(t *testing.T) { log.SetLevel(log.LevelDebug) srv := CreateMockServer(t) diff --git a/packet/conn.go b/packet/conn.go index 9901e34be..d6a8f36b5 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -53,10 +53,14 @@ type Conn struct { } func NewConn(conn net.Conn) *Conn { + return NewBufferedConn(conn, 65536) // 64kb +} + +func NewBufferedConn(conn net.Conn, bufferSize int) *Conn { c := new(Conn) c.Conn = conn - c.br = bufio.NewReaderSize(c, 65536) // 64kb + c.br = bufio.NewReaderSize(c, bufferSize) c.reader = c.br c.copyNBuf = make([]byte, DefaultBufferSize) @@ -64,8 +68,8 @@ func NewConn(conn net.Conn) *Conn { return c } -func NewConnWithTimeout(conn net.Conn, readTimeout, writeTimeout time.Duration) *Conn { - c := NewConn(conn) +func NewConnWithTimeout(conn net.Conn, readTimeout, writeTimeout time.Duration, bufferSize int) *Conn { + c := NewBufferedConn(conn, bufferSize) c.readTimeout = readTimeout c.writeTimeout = writeTimeout return c