Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go/mysql/auth_server_clientcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestValidCert(t *testing.T) {
authServer := newAuthServerClientCert(string(MysqlClearPassword))

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -118,7 +118,7 @@ func TestNoCert(t *testing.T) {
authServer := newAuthServerClientCert(string(MysqlClearPassword))

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down
10 changes: 5 additions & 5 deletions go/mysql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func TestTLSClientDisabled(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -223,7 +223,7 @@ func TestTLSClientPreferredDefault(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -296,7 +296,7 @@ func TestTLSClientRequired(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -343,7 +343,7 @@ func TestTLSClientVerifyCA(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -426,7 +426,7 @@ func TestTLSClientVerifyIdentity(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false)
require.NoError(t, err)
defer l.Close()

Expand Down
21 changes: 4 additions & 17 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ import (
"sync/atomic"
"time"

"github.com/spf13/pflag"

"vitess.io/vitess/go/bucketpool"
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/mysql/sqlerror"
Expand All @@ -40,7 +38,6 @@ import (
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/servenv"
"vitess.io/vitess/go/vt/vterrors"
)

Expand Down Expand Up @@ -70,19 +67,6 @@ const (
ephemeralRead
)

var (
mysqlMultiQuery = false
)

func registerConnFlags(fs *pflag.FlagSet) {
fs.BoolVar(&mysqlMultiQuery, "mysql-server-multi-query-protocol", mysqlMultiQuery, "If set, the server will use the new implementation of handling queries where-in multiple queries are sent together.")
}

func init() {
servenv.OnParseFor("vtgate", registerConnFlags)
servenv.OnParseFor("vtcombo", registerConnFlags)
}

// A Getter has a Get()
type Getter interface {
Get() *querypb.VTGateCallerID
Expand Down Expand Up @@ -222,6 +206,8 @@ type Conn struct {
// This is currently used for testing.
keepAliveOn bool

multiQuery bool

// mu protects the fields below
mu sync.Mutex
// cancel keep the cancel function for the current executing query.
Expand Down Expand Up @@ -298,6 +284,7 @@ func newServerConn(conn net.Conn, listener *Listener) *Conn {
keepAliveOn: enabledKeepAlive,
flushDelay: listener.flushDelay,
truncateErrLen: listener.truncateErrLen,
multiQuery: listener.multiQuery,
}

if listener.connReadBufferSize > 0 {
Expand Down Expand Up @@ -930,7 +917,7 @@ func (c *Conn) handleNextCommand(handler Handler) bool {
res := c.execQuery("use "+sqlescape.EscapeID(db), handler, false)
return res != connErr
case ComQuery:
if mysqlMultiQuery {
if c.multiQuery {
return c.handleComQueryMulti(handler, data)
}
return c.handleComQuery(handler, data)
Expand Down
2 changes: 2 additions & 0 deletions go/mysql/conn_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ type ConnParams struct {
FlushDelay time.Duration

TruncateErrLen int

MultiQuery bool
}

// EnableSSL will set the right flag on the parameters.
Expand Down
29 changes: 4 additions & 25 deletions go/mysql/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -804,14 +804,10 @@ func TestIsEOFPacket(t *testing.T) {
}

func TestMultiStatementStopsOnError(t *testing.T) {
origMysqlMultiQuery := mysqlMultiQuery
defer func() {
mysqlMultiQuery = origMysqlMultiQuery
}()
for _, b := range []bool{true, false} {
t.Run(fmt.Sprintf("MultiQueryProtocol: %v", b), func(t *testing.T) {
mysqlMultiQuery = b
listener, sConn, cConn := createSocketPair(t)
sConn.multiQuery = b
sConn.Capabilities |= CapabilityClientMultiStatements
defer func() {
listener.Close()
Expand Down Expand Up @@ -839,14 +835,10 @@ func TestMultiStatementStopsOnError(t *testing.T) {
}

func TestEmptyQuery(t *testing.T) {
origMysqlMultiQuery := mysqlMultiQuery
defer func() {
mysqlMultiQuery = origMysqlMultiQuery
}()
for _, b := range []bool{true, false} {
t.Run(fmt.Sprintf("MultiQueryProtocol: %v", b), func(t *testing.T) {
mysqlMultiQuery = b
listener, sConn, cConn := createSocketPair(t)
sConn.multiQuery = b
sConn.Capabilities |= CapabilityClientMultiStatements
defer func() {
listener.Close()
Expand All @@ -873,14 +865,10 @@ func TestEmptyQuery(t *testing.T) {
}

func TestMultiStatement(t *testing.T) {
origMysqlMultiQuery := mysqlMultiQuery
defer func() {
mysqlMultiQuery = origMysqlMultiQuery
}()
for _, b := range []bool{true, false} {
t.Run(fmt.Sprintf("MultiQueryProtocol: %v", b), func(t *testing.T) {
mysqlMultiQuery = b
listener, sConn, cConn := createSocketPair(t)
sConn.multiQuery = b
sConn.Capabilities |= CapabilityClientMultiStatements
defer func() {
listener.Close()
Expand Down Expand Up @@ -930,14 +918,10 @@ func TestMultiStatement(t *testing.T) {
}

func TestMultiStatementOnSplitError(t *testing.T) {
origMysqlMultiQuery := mysqlMultiQuery
defer func() {
mysqlMultiQuery = origMysqlMultiQuery
}()
for _, b := range []bool{true, false} {
t.Run(fmt.Sprintf("MultiQueryProtocol: %v", b), func(t *testing.T) {
mysqlMultiQuery = b
listener, sConn, cConn := createSocketPair(t)
sConn.multiQuery = b
sConn.Capabilities |= CapabilityClientMultiStatements
defer func() {
listener.Close()
Expand Down Expand Up @@ -990,13 +974,8 @@ func TestInitDbAgainstWrongDbDoesNotDropConnection(t *testing.T) {
}

func TestConnectionErrorWhileWritingComQuery(t *testing.T) {
origMysqlMultiQuery := mysqlMultiQuery
defer func() {
mysqlMultiQuery = origMysqlMultiQuery
}()
for _, b := range []bool{true, false} {
t.Run(fmt.Sprintf("MultiQueryProtocol: %v", b), func(t *testing.T) {
mysqlMultiQuery = b
// Set the conn for the server connection to the simulated connection which always returns an error on writing
sConn := newConn(testConn{
writeToPass: []bool{false, true},
Expand Down
2 changes: 1 addition & 1 deletion go/mysql/fakesqldb/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func NewWithEnv(t testing.TB, env *vtenv.Environment) *DB {
authServer := mysql.NewAuthServerNone()

// Start listening.
db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false, 0, 0)
db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestClearTextClientAuth(t *testing.T) {
defer authServer.close()

// Create the listener.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -100,7 +100,7 @@ func TestSSLConnection(t *testing.T) {
defer authServer.close()

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false)
require.NoError(t, err, "NewListener failed: %v", err)
host := l.Addr().(*net.TCPAddr).IP.String()
port := l.Addr().(*net.TCPAddr).Port
Expand Down
6 changes: 3 additions & 3 deletions go/mysql/mysql_fuzzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ func createFuzzingSocketPair() (net.Listener, *Conn, *Conn) {
}

// Create a Conn on both sides.
cConn := newConn(clientConn, DefaultFlushDelay)
sConn := newConn(serverConn, DefaultFlushDelay)
cConn := newConn(clientConn, DefaultFlushDelay, 0, false)
sConn := newConn(serverConn, DefaultFlushDelay, 0, false)

return listener, sConn, cConn
}
Expand Down Expand Up @@ -197,7 +197,7 @@ func FuzzHandleNextCommand(data []byte) int {
writeToPass: []bool{false},
pos: -1,
queryPacket: data,
}, DefaultFlushDelay)
}, DefaultFlushDelay, 0, false)
sConn.PrepareData = map[uint32]*PrepareData{}

handler := &fuzztestRun{}
Expand Down
11 changes: 9 additions & 2 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ type Listener struct {
// connBufferPooling configures if vtgate server pools connection buffers
connBufferPooling bool

multiQuery bool

// connKeepAlivePeriod is period between tcp keep-alives.
connKeepAlivePeriod time.Duration

Expand Down Expand Up @@ -236,6 +238,7 @@ func NewFromListener(
connBufferPooling bool,
keepAlivePeriod time.Duration,
flushDelay time.Duration,
multiQuery bool,
) (*Listener, error) {
cfg := ListenerConfig{
Listener: l,
Expand All @@ -247,6 +250,7 @@ func NewFromListener(
ConnBufferPooling: connBufferPooling,
ConnKeepAlivePeriod: keepAlivePeriod,
FlushDelay: flushDelay,
MultiQuery: multiQuery,
}
return NewListenerWithConfig(cfg)
}
Expand All @@ -262,17 +266,18 @@ func NewListener(
connBufferPooling bool,
keepAlivePeriod time.Duration,
flushDelay time.Duration,
multiQuery bool,
) (*Listener, error) {
listener, err := net.Listen(protocol, address)
if err != nil {
return nil, err
}
if proxyProtocol {
proxyListener := &proxyproto.Listener{Listener: listener}
return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay)
return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay, multiQuery)
}

return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay)
return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay, multiQuery)
}

// ListenerConfig should be used with NewListenerWithConfig to specify listener parameters.
Expand All @@ -289,6 +294,7 @@ type ListenerConfig struct {
ConnBufferPooling bool
ConnKeepAlivePeriod time.Duration
FlushDelay time.Duration
MultiQuery bool
}

// NewListenerWithConfig creates new listener using provided config. There are
Expand Down Expand Up @@ -317,6 +323,7 @@ func NewListenerWithConfig(cfg ListenerConfig) (*Listener, error) {
connBufferPooling: cfg.ConnBufferPooling,
connKeepAlivePeriod: cfg.ConnKeepAlivePeriod,
flushDelay: cfg.FlushDelay,
multiQuery: cfg.MultiQuery,
truncateErrLen: cfg.Handler.Env().TruncateErrLen(),
charset: cfg.Handler.Env().CollationEnv().DefaultConnectionCharset(),
}, nil
Expand Down
Loading
Loading