Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions go/test/endtoend/vtgate/system_schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,13 @@ func TestInformationSchemaQuery(t *testing.T) {
require.NoError(t, err)
defer conn.Close()

assertSingleRowIsReturned(t, conn, "table_schema = 'ks'")
assertSingleRowIsReturned(t, conn, "table_schema = 'vt_ks'")
assertSingleRowIsReturned(t, conn, "table_schema = 'ks'", "vt_ks")
assertSingleRowIsReturned(t, conn, "table_schema = 'vt_ks'", "vt_ks")
assertResultIsEmpty(t, conn, "table_schema = 'NONE'")
assertSingleRowIsReturned(t, conn, "table_schema = 'performance_schema'", "performance_schema")
assertResultIsEmpty(t, conn, "table_schema = 'PERFORMANCE_SCHEMA'")
assertSingleRowIsReturned(t, conn, "table_schema = 'performance_schema' and table_name = 'users'", "performance_schema")
assertResultIsEmpty(t, conn, "table_schema = 'performance_schema' and table_name = 'foo'")
}

func assertResultIsEmpty(t *testing.T, conn *mysql.Conn, pre string) {
Expand All @@ -84,12 +88,12 @@ func assertResultIsEmpty(t *testing.T, conn *mysql.Conn, pre string) {
})
}

func assertSingleRowIsReturned(t *testing.T, conn *mysql.Conn, predicate string) {
func assertSingleRowIsReturned(t *testing.T, conn *mysql.Conn, predicate string, expectedKs string) {
t.Run(predicate, func(t *testing.T) {
qr, err := conn.ExecuteFetch("SELECT distinct table_schema FROM information_schema.tables WHERE "+predicate, 1000, true)
require.NoError(t, err)
assert.Equal(t, 1, len(qr.Rows), "did not get enough rows back")
assert.Equal(t, "vt_ks", qr.Rows[0][0].ToString())
assert.Equal(t, expectedKs, qr.Rows[0][0].ToString())
})
}

Expand Down
69 changes: 40 additions & 29 deletions go/vt/vtgate/engine/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,53 +417,59 @@ func (route *Route) routeInfoSchemaQuery(vcursor VCursor, bindVars map[string]*q
Row: []sqltypes.Value{},
}

var specifiedKS string
if route.SysTableTableSchema != nil {
result, err := route.SysTableTableSchema.Evaluate(env)
if err != nil {
return nil, err
}
specifiedKS = result.Value().ToString()
bindVars[sqltypes.BvSchemaName] = sqltypes.StringBindVariable(specifiedKS)
}

var tableName string
if route.SysTableTableName != nil {
// the use has specified a table_name - let's check if it's a routed table
rss, err := route.paramsRoutedTable(vcursor, env, bindVars)
val, err := route.SysTableTableName.Evaluate(env)
if err != nil {
return nil, err
}
tableName = val.Value().ToString()
bindVars[BvTableName] = sqltypes.StringBindVariable(tableName)
}

// if the table_schema is system system, route to default keyspace.
if sqlparser.SystemSchema(specifiedKS) {
return defaultRoute()
}

// the use has specified a table_name - let's check if it's a routed table
if tableName != "" {
rss, err := route.paramsRoutedTable(vcursor, bindVars, specifiedKS, tableName)
if err != nil {
return nil, err
}
if rss != nil {
return rss, nil
}
// it was not a routed table, and we dont have a schema name to look up. give up
if route.SysTableTableSchema == nil {
return defaultRoute()
}
}

// we only have table_schema to work with
result, err := route.SysTableTableSchema.Evaluate(env)
if err != nil {
return nil, err
// it was not a routed table, and we dont have a schema name to look up. give up
if specifiedKS == "" {
return defaultRoute()
}
specifiedKS := result.Value().ToString()

// we only have table_schema to work with
destinations, _, err := vcursor.ResolveDestinations(specifiedKS, nil, []key.Destination{key.DestinationAnyShard{}})
if err != nil {
log.Errorf("failed to route information_schema query to keyspace [%s]", specifiedKS)
bindVars[sqltypes.BvSchemaName] = sqltypes.StringBindVariable(specifiedKS)
return defaultRoute()
}
bindVars[sqltypes.BvReplaceSchemaName] = sqltypes.Int64BindVariable(1)
setReplaceSchemaName(bindVars)
return destinations, nil
}

func (route *Route) paramsRoutedTable(vcursor VCursor, env evalengine.ExpressionEnv, bindVars map[string]*querypb.BindVariable) ([]*srvtopo.ResolvedShard, error) {
val, err := route.SysTableTableName.Evaluate(env)
if err != nil {
return nil, err
}
tableName := val.Value().ToString()

var tableSchema string
if route.SysTableTableSchema != nil {
val, err := route.SysTableTableSchema.Evaluate(env)
if err != nil {
return nil, err
}
tableSchema = val.Value().ToString()
}

func (route *Route) paramsRoutedTable(vcursor VCursor, bindVars map[string]*querypb.BindVariable, tableSchema string, tableName string) ([]*srvtopo.ResolvedShard, error) {
tbl := sqlparser.TableName{
Name: sqlparser.NewTableIdent(tableName),
Qualifier: sqlparser.NewTableIdent(tableSchema),
Expand All @@ -478,7 +484,7 @@ func (route *Route) paramsRoutedTable(vcursor VCursor, env evalengine.Expression
shards, _, err := vcursor.ResolveDestinations(destination.Keyspace.Name, nil, []key.Destination{key.DestinationAnyShard{}})
bindVars[BvTableName] = sqltypes.StringBindVariable(destination.Name.String())
if tableSchema != "" {
bindVars[sqltypes.BvReplaceSchemaName] = sqltypes.Int64BindVariable(1)
setReplaceSchemaName(bindVars)
}
return shards, err
}
Expand All @@ -488,6 +494,11 @@ func (route *Route) paramsRoutedTable(vcursor VCursor, env evalengine.Expression
return nil, nil
}

func setReplaceSchemaName(bindVars map[string]*querypb.BindVariable) {
delete(bindVars, sqltypes.BvSchemaName)
bindVars[sqltypes.BvReplaceSchemaName] = sqltypes.Int64BindVariable(1)
}

func (route *Route) paramsAnyShard(vcursor VCursor, bindVars map[string]*querypb.BindVariable) ([]*srvtopo.ResolvedShard, []map[string]*querypb.BindVariable, error) {
rss, _, err := vcursor.ResolveDestinations(route.Keyspace.Name, nil, []key.Destination{key.DestinationAnyShard{}})
if err != nil {
Expand Down
Loading