diff --git a/go/mysql/auth_server_clientcert_test.go b/go/mysql/auth_server_clientcert_test.go index 72a1ecce87c..28577c7c5b3 100644 --- a/go/mysql/auth_server_clientcert_test.go +++ b/go/mysql/auth_server_clientcert_test.go @@ -50,7 +50,7 @@ func TestValidCert(t *testing.T) { authServer := newAuthServerClientCert(string(MysqlClearPassword)) // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() @@ -118,7 +118,7 @@ func TestNoCert(t *testing.T) { authServer := newAuthServerClientCert(string(MysqlClearPassword)) // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() diff --git a/go/mysql/client_test.go b/go/mysql/client_test.go index da577e338b2..703ac56a1ed 100644 --- a/go/mysql/client_test.go +++ b/go/mysql/client_test.go @@ -151,7 +151,7 @@ func TestTLSClientDisabled(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) defer l.Close() @@ -223,7 +223,7 @@ func TestTLSClientPreferredDefault(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) defer l.Close() @@ -296,7 +296,7 @@ func TestTLSClientRequired(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) defer l.Close() @@ -343,7 +343,7 @@ func TestTLSClientVerifyCA(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) defer l.Close() @@ -426,7 +426,7 @@ func TestTLSClientVerifyIdentity(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) defer l.Close() diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 07960ec4146..c16890c8b2e 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -30,8 +30,6 @@ import ( "sync/atomic" "time" - "github.com/spf13/pflag" - "vitess.io/vitess/go/bucketpool" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/sqlerror" @@ -40,7 +38,6 @@ import ( "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vterrors" ) @@ -70,19 +67,6 @@ const ( ephemeralRead ) -var ( - mysqlMultiQuery = false -) - -func registerConnFlags(fs *pflag.FlagSet) { - fs.BoolVar(&mysqlMultiQuery, "mysql-server-multi-query-protocol", mysqlMultiQuery, "If set, the server will use the new implementation of handling queries where-in multiple queries are sent together.") -} - -func init() { - servenv.OnParseFor("vtgate", registerConnFlags) - servenv.OnParseFor("vtcombo", registerConnFlags) -} - // A Getter has a Get() type Getter interface { Get() *querypb.VTGateCallerID @@ -222,6 +206,8 @@ type Conn struct { // This is currently used for testing. keepAliveOn bool + multiQuery bool + // mu protects the fields below mu sync.Mutex // cancel keep the cancel function for the current executing query. @@ -298,6 +284,7 @@ func newServerConn(conn net.Conn, listener *Listener) *Conn { keepAliveOn: enabledKeepAlive, flushDelay: listener.flushDelay, truncateErrLen: listener.truncateErrLen, + multiQuery: listener.multiQuery, } if listener.connReadBufferSize > 0 { @@ -930,7 +917,7 @@ func (c *Conn) handleNextCommand(handler Handler) bool { res := c.execQuery("use "+sqlescape.EscapeID(db), handler, false) return res != connErr case ComQuery: - if mysqlMultiQuery { + if c.multiQuery { return c.handleComQueryMulti(handler, data) } return c.handleComQuery(handler, data) diff --git a/go/mysql/conn_params.go b/go/mysql/conn_params.go index 46e733f6021..1a16a409b61 100644 --- a/go/mysql/conn_params.go +++ b/go/mysql/conn_params.go @@ -65,6 +65,8 @@ type ConnParams struct { FlushDelay time.Duration TruncateErrLen int + + MultiQuery bool } // EnableSSL will set the right flag on the parameters. diff --git a/go/mysql/conn_test.go b/go/mysql/conn_test.go index 96f707eec5e..ccf5614498a 100644 --- a/go/mysql/conn_test.go +++ b/go/mysql/conn_test.go @@ -804,14 +804,10 @@ func TestIsEOFPacket(t *testing.T) { } func TestMultiStatementStopsOnError(t *testing.T) { - origMysqlMultiQuery := mysqlMultiQuery - defer func() { - mysqlMultiQuery = origMysqlMultiQuery - }() for _, b := range []bool{true, false} { t.Run(fmt.Sprintf("MultiQueryProtocol: %v", b), func(t *testing.T) { - mysqlMultiQuery = b listener, sConn, cConn := createSocketPair(t) + sConn.multiQuery = b sConn.Capabilities |= CapabilityClientMultiStatements defer func() { listener.Close() @@ -839,14 +835,10 @@ func TestMultiStatementStopsOnError(t *testing.T) { } func TestEmptyQuery(t *testing.T) { - origMysqlMultiQuery := mysqlMultiQuery - defer func() { - mysqlMultiQuery = origMysqlMultiQuery - }() for _, b := range []bool{true, false} { t.Run(fmt.Sprintf("MultiQueryProtocol: %v", b), func(t *testing.T) { - mysqlMultiQuery = b listener, sConn, cConn := createSocketPair(t) + sConn.multiQuery = b sConn.Capabilities |= CapabilityClientMultiStatements defer func() { listener.Close() @@ -873,14 +865,10 @@ func TestEmptyQuery(t *testing.T) { } func TestMultiStatement(t *testing.T) { - origMysqlMultiQuery := mysqlMultiQuery - defer func() { - mysqlMultiQuery = origMysqlMultiQuery - }() for _, b := range []bool{true, false} { t.Run(fmt.Sprintf("MultiQueryProtocol: %v", b), func(t *testing.T) { - mysqlMultiQuery = b listener, sConn, cConn := createSocketPair(t) + sConn.multiQuery = b sConn.Capabilities |= CapabilityClientMultiStatements defer func() { listener.Close() @@ -930,14 +918,10 @@ func TestMultiStatement(t *testing.T) { } func TestMultiStatementOnSplitError(t *testing.T) { - origMysqlMultiQuery := mysqlMultiQuery - defer func() { - mysqlMultiQuery = origMysqlMultiQuery - }() for _, b := range []bool{true, false} { t.Run(fmt.Sprintf("MultiQueryProtocol: %v", b), func(t *testing.T) { - mysqlMultiQuery = b listener, sConn, cConn := createSocketPair(t) + sConn.multiQuery = b sConn.Capabilities |= CapabilityClientMultiStatements defer func() { listener.Close() @@ -990,13 +974,8 @@ func TestInitDbAgainstWrongDbDoesNotDropConnection(t *testing.T) { } func TestConnectionErrorWhileWritingComQuery(t *testing.T) { - origMysqlMultiQuery := mysqlMultiQuery - defer func() { - mysqlMultiQuery = origMysqlMultiQuery - }() for _, b := range []bool{true, false} { t.Run(fmt.Sprintf("MultiQueryProtocol: %v", b), func(t *testing.T) { - mysqlMultiQuery = b // Set the conn for the server connection to the simulated connection which always returns an error on writing sConn := newConn(testConn{ writeToPass: []bool{false, true}, diff --git a/go/mysql/fakesqldb/server.go b/go/mysql/fakesqldb/server.go index 2d9ace82b52..738646591bd 100644 --- a/go/mysql/fakesqldb/server.go +++ b/go/mysql/fakesqldb/server.go @@ -198,7 +198,7 @@ func NewWithEnv(t testing.TB, env *vtenv.Environment) *DB { authServer := mysql.NewAuthServerNone() // Start listening. - db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false, 0, 0) + db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false, 0, 0, false) if err != nil { t.Fatalf("NewListener failed: %v", err) } diff --git a/go/mysql/handshake_test.go b/go/mysql/handshake_test.go index 13ed1099e58..82309fb9fe1 100644 --- a/go/mysql/handshake_test.go +++ b/go/mysql/handshake_test.go @@ -47,7 +47,7 @@ func TestClearTextClientAuth(t *testing.T) { defer authServer.close() // Create the listener. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() @@ -100,7 +100,7 @@ func TestSSLConnection(t *testing.T) { defer authServer.close() // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err, "NewListener failed: %v", err) host := l.Addr().(*net.TCPAddr).IP.String() port := l.Addr().(*net.TCPAddr).Port diff --git a/go/mysql/mysql_fuzzer.go b/go/mysql/mysql_fuzzer.go index 3ec302151f2..5ce82cd56c0 100644 --- a/go/mysql/mysql_fuzzer.go +++ b/go/mysql/mysql_fuzzer.go @@ -77,8 +77,8 @@ func createFuzzingSocketPair() (net.Listener, *Conn, *Conn) { } // Create a Conn on both sides. - cConn := newConn(clientConn, DefaultFlushDelay) - sConn := newConn(serverConn, DefaultFlushDelay) + cConn := newConn(clientConn, DefaultFlushDelay, 0, false) + sConn := newConn(serverConn, DefaultFlushDelay, 0, false) return listener, sConn, cConn } @@ -197,7 +197,7 @@ func FuzzHandleNextCommand(data []byte) int { writeToPass: []bool{false}, pos: -1, queryPacket: data, - }, DefaultFlushDelay) + }, DefaultFlushDelay, 0, false) sConn.PrepareData = map[uint32]*PrepareData{} handler := &fuzztestRun{} diff --git a/go/mysql/server.go b/go/mysql/server.go index 46984f0e336..e6a274f7782 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -201,6 +201,8 @@ type Listener struct { // connBufferPooling configures if vtgate server pools connection buffers connBufferPooling bool + multiQuery bool + // connKeepAlivePeriod is period between tcp keep-alives. connKeepAlivePeriod time.Duration @@ -236,6 +238,7 @@ func NewFromListener( connBufferPooling bool, keepAlivePeriod time.Duration, flushDelay time.Duration, + multiQuery bool, ) (*Listener, error) { cfg := ListenerConfig{ Listener: l, @@ -247,6 +250,7 @@ func NewFromListener( ConnBufferPooling: connBufferPooling, ConnKeepAlivePeriod: keepAlivePeriod, FlushDelay: flushDelay, + MultiQuery: multiQuery, } return NewListenerWithConfig(cfg) } @@ -262,6 +266,7 @@ func NewListener( connBufferPooling bool, keepAlivePeriod time.Duration, flushDelay time.Duration, + multiQuery bool, ) (*Listener, error) { listener, err := net.Listen(protocol, address) if err != nil { @@ -269,10 +274,10 @@ func NewListener( } if proxyProtocol { proxyListener := &proxyproto.Listener{Listener: listener} - return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay) + return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay, multiQuery) } - return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay) + return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod, flushDelay, multiQuery) } // ListenerConfig should be used with NewListenerWithConfig to specify listener parameters. @@ -289,6 +294,7 @@ type ListenerConfig struct { ConnBufferPooling bool ConnKeepAlivePeriod time.Duration FlushDelay time.Duration + MultiQuery bool } // NewListenerWithConfig creates new listener using provided config. There are @@ -317,6 +323,7 @@ func NewListenerWithConfig(cfg ListenerConfig) (*Listener, error) { connBufferPooling: cfg.ConnBufferPooling, connKeepAlivePeriod: cfg.ConnKeepAlivePeriod, flushDelay: cfg.FlushDelay, + multiQuery: cfg.MultiQuery, truncateErrLen: cfg.Handler.Env().TruncateErrLen(), charset: cfg.Handler.Env().CollationEnv().DefaultConnectionCharset(), }, nil diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index b1c86576175..781c142e7eb 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -300,7 +300,7 @@ func TestConnectionFromListener(t *testing.T) { listener, err := net.Listen("tcp", "127.0.0.1:") require.NoError(t, err, "net.Listener failed") - l, err := NewFromListener(listener, authServer, th, 0, 0, false, 0, 0) + l, err := NewFromListener(listener, authServer, th, 0, 0, false, 0, 0, false) require.NoError(t, err, "NewListener failed") host, port := getHostPort(t, l.Addr()) fmt.Printf("host: %s, port: %d\n", host, port) @@ -330,7 +330,7 @@ func TestConnectionWithoutSourceHost(t *testing.T) { }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err, "NewListener failed") host, port := getHostPort(t, l.Addr()) // Setup the right parameters. @@ -362,7 +362,7 @@ func TestConnectionWithSourceHost(t *testing.T) { } defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err, "NewListener failed") host, port := getHostPort(t, l.Addr()) // Setup the right parameters. @@ -394,7 +394,7 @@ func TestConnectionUseMysqlNativePasswordWithSourceHost(t *testing.T) { } defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err, "NewListener failed") host, port := getHostPort(t, l.Addr()) // Setup the right parameters. @@ -431,7 +431,7 @@ func TestConnectionUnixSocket(t *testing.T) { os.Remove(unixSocket.Name()) - l, err := NewListener("unix", unixSocket.Name(), authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("unix", unixSocket.Name(), authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err, "NewListener failed") // Setup the right parameters. params := &ConnParams{ @@ -458,7 +458,7 @@ func TestClientFoundRows(t *testing.T) { }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err, "NewListener failed") host, port := getHostPort(t, l.Addr()) // Setup the right parameters. @@ -502,7 +502,7 @@ func TestConnCounts(t *testing.T) { }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err, "NewListener failed") host, port := getHostPort(t, l.Addr()) // Test with one new connection. @@ -556,7 +556,7 @@ func TestServer(t *testing.T) { }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) host, port := getHostPort(t, l.Addr()) // Setup the right parameters. @@ -656,7 +656,7 @@ func TestServerStats(t *testing.T) { }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) host, port := getHostPort(t, l.Addr()) // Setup the right parameters. @@ -744,7 +744,7 @@ func TestClearTextServer(t *testing.T) { }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) host, port := getHostPort(t, l.Addr()) // Setup the right parameters. @@ -817,7 +817,7 @@ func TestDialogServer(t *testing.T) { }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) l.AllowClearTextWithoutTLS.Store(true) host, port := getHostPort(t, l.Addr()) @@ -866,7 +866,7 @@ func TestTLSServer(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) host := l.Addr().(*net.TCPAddr).IP.String() port := l.Addr().(*net.TCPAddr).Port @@ -974,7 +974,7 @@ func TestTLSRequired(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err = NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err = NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) host := l.Addr().(*net.TCPAddr).IP.String() port := l.Addr().(*net.TCPAddr).Port @@ -1047,7 +1047,7 @@ func TestCachingSha2PasswordAuthWithTLS(t *testing.T) { tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err, "NewListener failed: %v", err) host := l.Addr().(*net.TCPAddr).IP.String() port := l.Addr().(*net.TCPAddr).Port @@ -1133,7 +1133,7 @@ func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err, "NewListener failed: %v", err) host := l.Addr().(*net.TCPAddr).IP.String() port := l.Addr().(*net.TCPAddr).Port @@ -1190,7 +1190,7 @@ func TestCachingSha2PasswordAuthWithoutTLS(t *testing.T) { defer authServer.close() // Create the listener. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err, "NewListener failed: %v", err) host := l.Addr().(*net.TCPAddr).IP.String() port := l.Addr().(*net.TCPAddr).Port @@ -1230,7 +1230,7 @@ func TestErrorCodes(t *testing.T) { }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) host, port := getHostPort(t, l.Addr()) // Setup the right parameters. @@ -1407,7 +1407,7 @@ func TestListenerShutdown(t *testing.T) { }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) host, port := getHostPort(t, l.Addr()) // Setup the right parameters. @@ -1480,7 +1480,7 @@ func TestServerFlush(t *testing.T) { mysqlServerFlushDelay := 10 * time.Millisecond th := &testHandler{} - l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0, mysqlServerFlushDelay) + l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0, mysqlServerFlushDelay, false) require.NoError(t, err) host, port := getHostPort(t, l.Addr()) params := &ConnParams{ @@ -1527,7 +1527,7 @@ func TestTcpKeepAlive(t *testing.T) { ctx := utils.LeakCheckContext(t) th := &testHandler{} - l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0, 0) + l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) host, port := getHostPort(t, l.Addr()) params := &ConnParams{ diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 7e285f157a4..265d8b7c9cb 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -78,6 +78,7 @@ var ( mysqlDrainOnTerm bool mysqlServerFlushDelay = 100 * time.Millisecond + mysqlServerMultiQuery = false ) func registerPluginFlags(fs *pflag.FlagSet) { @@ -104,6 +105,7 @@ func registerPluginFlags(fs *pflag.FlagSet) { fs.DurationVar(&mysqlServerFlushDelay, "mysql_server_flush_delay", mysqlServerFlushDelay, "Delay after which buffered response will be flushed to the client.") fs.StringVar(&mysqlDefaultWorkloadName, "mysql_default_workload", mysqlDefaultWorkloadName, "Default session workload (OLTP, OLAP, DBA)") fs.BoolVar(&mysqlDrainOnTerm, "mysql-server-drain-onterm", mysqlDrainOnTerm, "If set, the server waits for --onterm_timeout for already connected clients to complete their in flight work") + fs.BoolVar(&mysqlServerMultiQuery, "mysql-server-multi-query-protocol", mysqlServerMultiQuery, "If set, the server will use the new implementation of handling queries where-in multiple queries are sent together.") } // vtgateHandler implements the Listener interface. @@ -620,6 +622,7 @@ func initMySQLProtocol(vtgate *VTGate) *mysqlServer { mysqlConnBufferPooling, mysqlKeepAlivePeriod, mysqlServerFlushDelay, + mysqlServerMultiQuery, ) if err != nil { log.Exitf("mysql.NewListener failed: %v", err) @@ -665,6 +668,7 @@ func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mys mysqlConnBufferPooling, mysqlKeepAlivePeriod, mysqlServerFlushDelay, + mysqlServerMultiQuery, ) switch err := err.(type) { @@ -698,6 +702,7 @@ func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mys mysqlConnBufferPooling, mysqlKeepAlivePeriod, mysqlServerFlushDelay, + mysqlServerMultiQuery, ) return listener, listenerErr default: diff --git a/go/vt/vtgate/plugin_mysql_server_test.go b/go/vt/vtgate/plugin_mysql_server_test.go index a311790d771..a77d66d32b0 100644 --- a/go/vt/vtgate/plugin_mysql_server_test.go +++ b/go/vt/vtgate/plugin_mysql_server_test.go @@ -784,7 +784,7 @@ func TestComQueryMulti(t *testing.T) { executor, _, _, _, _ := createExecutorEnv(t) th := &testHandler{} - listener, err := mysql.NewListener("tcp", "127.0.0.1:", mysql.NewAuthServerNone(), th, 0, 0, false, false, 0, 0) + listener, err := mysql.NewListener("tcp", "127.0.0.1:", mysql.NewAuthServerNone(), th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) defer listener.Close() @@ -825,7 +825,7 @@ func TestGracefulShutdown(t *testing.T) { vh := newVtgateHandler(&VTGate{executor: executor, timings: timings, rowsReturned: rowsReturned, rowsAffected: rowsAffected, queryTextCharsProcessed: queryTextCharsProcessed}) th := &testHandler{} - listener, err := mysql.NewListener("tcp", "127.0.0.1:", mysql.NewAuthServerNone(), th, 0, 0, false, false, 0, 0) + listener, err := mysql.NewListener("tcp", "127.0.0.1:", mysql.NewAuthServerNone(), th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) defer listener.Close() @@ -863,7 +863,7 @@ func TestGracefulShutdownWithTransaction(t *testing.T) { vh := newVtgateHandler(&VTGate{executor: executor, timings: timings, rowsReturned: rowsReturned, rowsAffected: rowsAffected, queryTextCharsProcessed: queryTextCharsProcessed}) th := &testHandler{} - listener, err := mysql.NewListener("tcp", "127.0.0.1:", mysql.NewAuthServerNone(), th, 0, 0, false, false, 0, 0) + listener, err := mysql.NewListener("tcp", "127.0.0.1:", mysql.NewAuthServerNone(), th, 0, 0, false, false, 0, 0, false) require.NoError(t, err) defer listener.Close()