diff --git a/go/test/endtoend/vtgate/dbnameoverride/tablet_master_test.go b/go/test/endtoend/vtgate/dbnameoverride/tablet_master_test.go deleted file mode 100644 index d4bae5200da..00000000000 --- a/go/test/endtoend/vtgate/dbnameoverride/tablet_master_test.go +++ /dev/null @@ -1,134 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -package master - -import ( - "context" - "flag" - "os" - "testing" - - "github.com/stretchr/testify/require" - "vitess.io/vitess/go/mysql" - - "vitess.io/vitess/go/test/endtoend/cluster" -) - -var ( - clusterInstance *cluster.LocalProcessCluster - vtParams mysql.ConnParams - hostname = "localhost" - keyspaceName = "ks" - cell = "zone1" - sqlSchema = ` - create table t1( - id bigint, - value varchar(16), - primary key(id) - ) Engine=InnoDB; -` - - vSchema = ` - { - "sharded": true, - "vindexes": { - "hash": { - "type": "hash" - } - }, - "tables": { - "t1": { - "column_vindexes": [ - { - "column": "id", - "name": "hash" - } - ] - } - } - }` -) - -const dbName = "myDbName" - -func TestMain(m *testing.M) { - defer cluster.PanicHandler(nil) - flag.Parse() - - exitCode := func() int { - clusterInstance = cluster.NewCluster(cell, hostname) - defer clusterInstance.Teardown() - - // Start topo server - err := clusterInstance.StartTopo() - if err != nil { - return 1 - } - - // Set extra tablet args for lock timeout - clusterInstance.VtTabletExtraArgs = []string{ - "-init_db_name_override", dbName, - } - - // Start keyspace - keyspace := &cluster.Keyspace{ - Name: keyspaceName, - SchemaSQL: sqlSchema, - VSchema: vSchema, - } - - if err = clusterInstance.StartUnshardedKeyspace(*keyspace, 1, false); err != nil { - return 1 - } - - if err = clusterInstance.StartVtgate(); err != nil { - return 1 - } - vtParams = mysql.ConnParams{ - Host: clusterInstance.Hostname, - Port: clusterInstance.VtgateMySQLPort, - } - - return m.Run() - }() - os.Exit(exitCode) -} - -func TestDbNameOverride(t *testing.T) { - defer cluster.PanicHandler(t) - ctx := context.Background() - conn, err := mysql.Connect(ctx, &vtParams) - require.Nil(t, err) - defer conn.Close() - qr, err := conn.ExecuteFetch("SELECT database() FROM information_schema.tables WHERE table_schema = database()", 1000, true) - - require.Nil(t, err) - require.Equal(t, 1, len(qr.Rows), "did not get enough rows back") - require.Equal(t, dbName, qr.Rows[0][0].ToString()) -} - -func TestInformationSchemaQuery(t *testing.T) { - defer cluster.PanicHandler(t) - ctx := context.Background() - conn, err := mysql.Connect(ctx, &vtParams) - require.Nil(t, err) - defer conn.Close() - qr, err := conn.ExecuteFetch("SELECT TABLE_NAME FROM information_schema.tables WHERE table_schema = 'ks'", 1000, true) - - require.Nil(t, err) - require.Equal(t, 1, len(qr.Rows), "did not get enough rows back") - require.Equal(t, "t1", qr.Rows[0][0].ToString()) -} diff --git a/go/test/endtoend/vtgate/misc_test.go b/go/test/endtoend/vtgate/misc_test.go index d69e06008b4..6d5f1593447 100644 --- a/go/test/endtoend/vtgate/misc_test.go +++ b/go/test/endtoend/vtgate/misc_test.go @@ -288,6 +288,37 @@ func TestShowTablesWithWhereClause(t *testing.T) { assertMatches(t, conn, "show tables from ks where Tables_in_ks='t3'", `[[VARCHAR("t3")]]`) } +func TestDbNameOverride(t *testing.T) { + defer cluster.PanicHandler(t) + ctx := context.Background() + conn, err := mysql.Connect(ctx, &vtParams) + require.Nil(t, err) + defer conn.Close() + qr, err := conn.ExecuteFetch("SELECT distinct database() FROM information_schema.tables WHERE table_schema = database()", 1000, true) + + require.Nil(t, err) + assert.Equal(t, 1, len(qr.Rows), "did not get enough rows back") + assert.Equal(t, "vt_ks", qr.Rows[0][0].ToString()) +} + +func TestInformationSchemaQuery(t *testing.T) { + defer cluster.PanicHandler(t) + ctx := context.Background() + conn, err := mysql.Connect(ctx, &vtParams) + require.Nil(t, err) + defer conn.Close() + + qr, err := conn.ExecuteFetch("SELECT distinct table_schema FROM information_schema.tables WHERE table_schema = 'ks'", 1000, true) + require.Nil(t, err) + assert.Equal(t, 1, len(qr.Rows), "did not get enough rows back") + assert.Equal(t, "vt_ks", qr.Rows[0][0].ToString()) + + qr, err = conn.ExecuteFetch("SELECT distinct table_schema FROM information_schema.tables WHERE table_schema = 'vt_ks'", 1000, true) + require.Nil(t, err) + assert.Equal(t, 1, len(qr.Rows), "did not get enough rows back") + assert.Equal(t, "vt_ks", qr.Rows[0][0].ToString()) +} + func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) { t.Helper() qr := exec(t, conn, query) diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index 962665a6a21..a6843a28aa7 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -186,6 +186,8 @@ var routeName = map[RouteOpcode]string{ var ( partialSuccessScatterQueries = stats.NewCounter("PartialSuccessScatterQueries", "Count of partially successful scatter queries") + // BvSchemaName is bind variable to be sent down to vttablet for schema name. + BvSchemaName = "__vtschemaname" ) // MarshalJSON serializes the RouteOpcode as a JSON string. @@ -391,6 +393,7 @@ func (route *Route) paramsSystemQuery(vcursor VCursor, bindVars map[string]*quer } if keyspace == "" { keyspace = result.Value().ToString() + bindVars[BvSchemaName] = sqltypes.StringBindVariable(keyspace) } else if other := result.Value().ToString(); keyspace != other { return nil, nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "can't use more than one keyspace per system table query - found both '%s' and '%s'", keyspace, other) } @@ -402,7 +405,11 @@ func (route *Route) paramsSystemQuery(vcursor VCursor, bindVars map[string]*quer destinations, _, err := vcursor.ResolveDestinations(keyspace, nil, []key.Destination{key.DestinationAnyShard{}}) if err != nil { - return nil, nil, vterrors.Wrapf(err, "failed to find information about keyspace `%s`", keyspace) + // Check with assigned route keyspace. + destinations, _, err = vcursor.ResolveDestinations(route.Keyspace.Name, nil, []key.Destination{key.DestinationAnyShard{}}) + if err != nil { + return nil, nil, vterrors.Wrapf(err, "failed to find information about keyspace `%s`", keyspace) + } } return destinations, []map[string]*querypb.BindVariable{bindVars}, nil } diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index a5c5fab208a..2a5f0c97933 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -257,7 +257,7 @@ func (r *rewriter) rewriteTableSchema(cursor *sqlparser.Cursor) bool { return false } r.tableNameExpressions = append(r.tableNameExpressions, evalExpr) - parent.Right = sqlparser.NewArgument([]byte(":__vtschemaname")) + parent.Right = sqlparser.NewArgument([]byte(":" + engine.BvSchemaName)) } } }