diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 9ee244b6dd4..adf44dac897 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -29,6 +29,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/sqlescape" + "vitess.io/vitess/go/bucketpool" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/sync2" @@ -774,7 +776,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { case ComInitDB: db := c.parseComInitDB(data) c.recycleReadPacket() - if err := c.execQuery(fmt.Sprintf("use `%s`", db), handler, false); err != nil { + if err := c.execQuery("use "+sqlescape.EscapeID(db), handler, false); err != nil { return err } case ComQuery: diff --git a/go/mysql/server.go b/go/mysql/server.go index bc6e1e93266..bc3ecc8dea7 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -18,13 +18,14 @@ package mysql import ( "crypto/tls" - "fmt" "io" "net" "strings" "sync/atomic" "time" + "vitess.io/vitess/go/sqlescape" + proxyproto "github.com/pires/go-proxyproto" "vitess.io/vitess/go/netutil" "vitess.io/vitess/go/sqltypes" @@ -443,7 +444,7 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti // Set initial db name. if c.schemaName != "" { - err = l.handler.ComQuery(c, fmt.Sprintf("use `%s`", c.schemaName), func(result *sqltypes.Result) error { + err = l.handler.ComQuery(c, "use "+sqlescape.EscapeID(c.schemaName), func(result *sqltypes.Result) error { return nil }) if err != nil { diff --git a/go/vt/vttablet/tabletmanager/rpc_lock_tables.go b/go/vt/vttablet/tabletmanager/rpc_lock_tables.go index 149991bb3d8..71163e70bcb 100644 --- a/go/vt/vttablet/tabletmanager/rpc_lock_tables.go +++ b/go/vt/vttablet/tabletmanager/rpc_lock_tables.go @@ -110,7 +110,7 @@ func (tm *TabletManager) lockTablesUsingLockTables(conn *dbconnpool.DBConnection tableNames = append(tableNames, fmt.Sprintf("%s READ", sqlescape.EscapeID(name))) } lockStatement := fmt.Sprintf("LOCK TABLES %v", strings.Join(tableNames, ", ")) - _, err := conn.ExecuteFetch(fmt.Sprintf("USE %s", tm.DBConfigs.DBName), 0, false) + _, err := conn.ExecuteFetch("USE "+sqlescape.EscapeID(tm.DBConfigs.DBName), 0, false) if err != nil { return err } diff --git a/go/vt/vttablet/tabletmanager/rpc_query.go b/go/vt/vttablet/tabletmanager/rpc_query.go index 1c919764ff8..70775bee7c6 100644 --- a/go/vt/vttablet/tabletmanager/rpc_query.go +++ b/go/vt/vttablet/tabletmanager/rpc_query.go @@ -18,6 +18,7 @@ package tabletmanager import ( "golang.org/x/net/context" + "vitess.io/vitess/go/sqlescape" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/log" @@ -44,7 +45,7 @@ func (tm *TabletManager) ExecuteFetchAsDba(ctx context.Context, query []byte, db if dbName != "" { // This execute might fail if db does not exist. // Error is ignored because given query might create this database. - conn.ExecuteFetch("USE "+dbName, 1, false) + conn.ExecuteFetch("USE "+sqlescape.EscapeID(dbName), 1, false) } // run the query @@ -81,7 +82,7 @@ func (tm *TabletManager) ExecuteFetchAsAllPrivs(ctx context.Context, query []byt if dbName != "" { // This execute might fail if db does not exist. // Error is ignored because given query might create this database. - conn.ExecuteFetch("USE "+dbName, 1, false) + conn.ExecuteFetch("USE "+sqlescape.EscapeID(dbName), 1, false) } // run the query diff --git a/go/vt/vttablet/tabletmanager/rpc_query_test.go b/go/vt/vttablet/tabletmanager/rpc_query_test.go new file mode 100644 index 00000000000..d5b386ce3ba --- /dev/null +++ b/go/vt/vttablet/tabletmanager/rpc_query_test.go @@ -0,0 +1,52 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tabletmanager + +import ( + "context" + "testing" + + "vitess.io/vitess/go/sqltypes" + + "vitess.io/vitess/go/mysql/fakesqldb" + + "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/vt/dbconfigs" + "vitess.io/vitess/go/vt/mysqlctl/fakemysqldaemon" + "vitess.io/vitess/go/vt/vttablet/tabletservermock" + + "github.com/stretchr/testify/require" +) + +func TestTabletManager_ExecuteFetchAsDba(t *testing.T) { + ctx := context.Background() + cp := mysql.ConnParams{} + db := fakesqldb.New(t) + db.AddQueryPattern(".*", &sqltypes.Result{}) + daemon := fakemysqldaemon.NewFakeMysqlDaemon(db) + + dbName := " escap`e me " + tm := &TabletManager{ + MysqlDaemon: daemon, + DBConfigs: dbconfigs.NewTestDBConfigs(cp, cp, dbName), + QueryServiceControl: tabletservermock.NewController(), + } + + _, err := tm.ExecuteFetchAsDba(ctx, []byte("select 42"), dbName, 10, false, false) + require.NoError(t, err) + require.Equal(t, "use ` escap``e me `;select 42", db.QueryLog()) +} diff --git a/go/vt/wrangler/testlib/apply_schema_flaky_test.go b/go/vt/wrangler/testlib/apply_schema_flaky_test.go index 534d2fff8e2..e2a0b438160 100644 --- a/go/vt/wrangler/testlib/apply_schema_flaky_test.go +++ b/go/vt/wrangler/testlib/apply_schema_flaky_test.go @@ -95,7 +95,7 @@ func TestApplySchema_AllowLongUnavailability(t *testing.T) { ft.FakeMysqlDaemon.PreflightSchemaChangeResult = preflightSchemaChanges } - changeToDb := "USE vt_ks" + changeToDb := "USE `vt_ks`" addColumn := "ALTER TABLE table1 ADD COLUMN new_id bigint(20)" db.AddQuery(changeToDb, &sqltypes.Result{}) db.AddQuery(addColumn, &sqltypes.Result{}) diff --git a/go/vt/wrangler/testlib/copy_schema_shard_test.go b/go/vt/wrangler/testlib/copy_schema_shard_test.go index 35523a7dd26..bd758a019d4 100644 --- a/go/vt/wrangler/testlib/copy_schema_shard_test.go +++ b/go/vt/wrangler/testlib/copy_schema_shard_test.go @@ -98,7 +98,7 @@ func copySchema(t *testing.T, useShardAsSource bool) { sourceMaster.FakeMysqlDaemon.Schema = schema sourceRdonly.FakeMysqlDaemon.Schema = schema - changeToDb := "USE vt_ks" + changeToDb := "USE `vt_ks`" createDb := "CREATE DATABASE `vt_ks` /*!40100 DEFAULT CHARACTER SET utf8 */" createTable := "CREATE TABLE `vt_ks`.`table1` (\n" + " `id` bigint(20) NOT NULL AUTO_INCREMENT,\n" +