diff --git a/go/test/endtoend/vtgate/system_schema_test.go b/go/test/endtoend/vtgate/system_schema_test.go index e5a100bd935..5954ee5af76 100644 --- a/go/test/endtoend/vtgate/system_schema_test.go +++ b/go/test/endtoend/vtgate/system_schema_test.go @@ -71,9 +71,13 @@ func TestInformationSchemaQuery(t *testing.T) { require.NoError(t, err) defer conn.Close() - assertSingleRowIsReturned(t, conn, "table_schema = 'ks'") - assertSingleRowIsReturned(t, conn, "table_schema = 'vt_ks'") + assertSingleRowIsReturned(t, conn, "table_schema = 'ks'", "vt_ks") + assertSingleRowIsReturned(t, conn, "table_schema = 'vt_ks'", "vt_ks") assertResultIsEmpty(t, conn, "table_schema = 'NONE'") + assertSingleRowIsReturned(t, conn, "table_schema = 'performance_schema'", "performance_schema") + assertResultIsEmpty(t, conn, "table_schema = 'PERFORMANCE_SCHEMA'") + assertSingleRowIsReturned(t, conn, "table_schema = 'performance_schema' and table_name = 'users'", "performance_schema") + assertResultIsEmpty(t, conn, "table_schema = 'performance_schema' and table_name = 'foo'") } func assertResultIsEmpty(t *testing.T, conn *mysql.Conn, pre string) { @@ -84,12 +88,12 @@ func assertResultIsEmpty(t *testing.T, conn *mysql.Conn, pre string) { }) } -func assertSingleRowIsReturned(t *testing.T, conn *mysql.Conn, predicate string) { +func assertSingleRowIsReturned(t *testing.T, conn *mysql.Conn, predicate string, expectedKs string) { t.Run(predicate, func(t *testing.T) { qr, err := conn.ExecuteFetch("SELECT distinct table_schema FROM information_schema.tables WHERE "+predicate, 1000, true) require.NoError(t, err) assert.Equal(t, 1, len(qr.Rows), "did not get enough rows back") - assert.Equal(t, "vt_ks", qr.Rows[0][0].ToString()) + assert.Equal(t, expectedKs, qr.Rows[0][0].ToString()) }) } diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index 71067fa51a9..61407b556f2 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -417,53 +417,59 @@ func (route *Route) routeInfoSchemaQuery(vcursor VCursor, bindVars map[string]*q Row: []sqltypes.Value{}, } + var specifiedKS string + if route.SysTableTableSchema != nil { + result, err := route.SysTableTableSchema.Evaluate(env) + if err != nil { + return nil, err + } + specifiedKS = result.Value().ToString() + bindVars[sqltypes.BvSchemaName] = sqltypes.StringBindVariable(specifiedKS) + } + + var tableName string if route.SysTableTableName != nil { - // the use has specified a table_name - let's check if it's a routed table - rss, err := route.paramsRoutedTable(vcursor, env, bindVars) + val, err := route.SysTableTableName.Evaluate(env) + if err != nil { + return nil, err + } + tableName = val.Value().ToString() + bindVars[BvTableName] = sqltypes.StringBindVariable(tableName) + } + + // if the table_schema is system system, route to default keyspace. + if sqlparser.SystemSchema(specifiedKS) { + return defaultRoute() + } + + // the use has specified a table_name - let's check if it's a routed table + if tableName != "" { + rss, err := route.paramsRoutedTable(vcursor, bindVars, specifiedKS, tableName) if err != nil { return nil, err } if rss != nil { return rss, nil } - // it was not a routed table, and we dont have a schema name to look up. give up - if route.SysTableTableSchema == nil { - return defaultRoute() - } } - // we only have table_schema to work with - result, err := route.SysTableTableSchema.Evaluate(env) - if err != nil { - return nil, err + // it was not a routed table, and we dont have a schema name to look up. give up + if specifiedKS == "" { + return defaultRoute() } - specifiedKS := result.Value().ToString() + + // we only have table_schema to work with destinations, _, err := vcursor.ResolveDestinations(specifiedKS, nil, []key.Destination{key.DestinationAnyShard{}}) if err != nil { log.Errorf("failed to route information_schema query to keyspace [%s]", specifiedKS) bindVars[sqltypes.BvSchemaName] = sqltypes.StringBindVariable(specifiedKS) return defaultRoute() } - bindVars[sqltypes.BvReplaceSchemaName] = sqltypes.Int64BindVariable(1) + setReplaceSchemaName(bindVars) return destinations, nil } -func (route *Route) paramsRoutedTable(vcursor VCursor, env evalengine.ExpressionEnv, bindVars map[string]*querypb.BindVariable) ([]*srvtopo.ResolvedShard, error) { - val, err := route.SysTableTableName.Evaluate(env) - if err != nil { - return nil, err - } - tableName := val.Value().ToString() - - var tableSchema string - if route.SysTableTableSchema != nil { - val, err := route.SysTableTableSchema.Evaluate(env) - if err != nil { - return nil, err - } - tableSchema = val.Value().ToString() - } - +func (route *Route) paramsRoutedTable(vcursor VCursor, bindVars map[string]*querypb.BindVariable, tableSchema string, tableName string) ([]*srvtopo.ResolvedShard, error) { tbl := sqlparser.TableName{ Name: sqlparser.NewTableIdent(tableName), Qualifier: sqlparser.NewTableIdent(tableSchema), @@ -478,7 +484,7 @@ func (route *Route) paramsRoutedTable(vcursor VCursor, env evalengine.Expression shards, _, err := vcursor.ResolveDestinations(destination.Keyspace.Name, nil, []key.Destination{key.DestinationAnyShard{}}) bindVars[BvTableName] = sqltypes.StringBindVariable(destination.Name.String()) if tableSchema != "" { - bindVars[sqltypes.BvReplaceSchemaName] = sqltypes.Int64BindVariable(1) + setReplaceSchemaName(bindVars) } return shards, err } @@ -488,6 +494,11 @@ func (route *Route) paramsRoutedTable(vcursor VCursor, env evalengine.Expression return nil, nil } +func setReplaceSchemaName(bindVars map[string]*querypb.BindVariable) { + delete(bindVars, sqltypes.BvSchemaName) + bindVars[sqltypes.BvReplaceSchemaName] = sqltypes.Int64BindVariable(1) +} + func (route *Route) paramsAnyShard(vcursor VCursor, bindVars map[string]*querypb.BindVariable) ([]*srvtopo.ResolvedShard, []map[string]*querypb.BindVariable, error) { rss, _, err := vcursor.ResolveDestinations(route.Keyspace.Name, nil, []key.Destination{key.DestinationAnyShard{}}) if err != nil { diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 79d92a4f50a..27e4e96acc0 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -18,7 +18,6 @@ package vtgate import ( "fmt" - "reflect" "strings" "testing" @@ -54,30 +53,35 @@ func TestSelectNext(t *testing.T) { BindVariables: map[string]*querypb.BindVariable{"n": sqltypes.Int64BindVariable(2)}, }} - if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { - t.Errorf("sbclookup.Queries:\n%v, want\n%v\n", sbclookup.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbclookup.Queries) } func TestSelectDBA(t *testing.T) { executor, sbc1, _, _ := createLegacyExecutorEnv() query := "select * from INFORMATION_SCHEMA.foo" - _, err := executor.Execute( - context.Background(), - "TestSelectDBA", + _, err := executor.Execute(context.Background(), "TestSelectDBA", NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), - query, - map[string]*querypb.BindVariable{}, + query, map[string]*querypb.BindVariable{}, ) require.NoError(t, err) - wantQueries := []*querypb.BoundQuery{{ - Sql: query, - BindVariables: map[string]*querypb.BindVariable{}, - }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbclookup.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + wantQueries := []*querypb.BoundQuery{{Sql: query, BindVariables: map[string]*querypb.BindVariable{}}} + utils.MustMatch(t, wantQueries, sbc1.Queries) + sbc1.Queries = nil + + query = "SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = 'performance_schema' AND table_name = 'foo'" + _, err = executor.Execute(context.Background(), "TestSelectDBA", + NewSafeSession(&vtgatepb.Session{TargetString: "TestExecutor"}), + query, map[string]*querypb.BindVariable{}, + ) + require.NoError(t, err) + wantQueries = []*querypb.BoundQuery{{Sql: "select COUNT(*) from INFORMATION_SCHEMA.`TABLES` where table_schema = :__vtschemaname and table_name = :__vttablename", + BindVariables: map[string]*querypb.BindVariable{ + "__vtschemaname": sqltypes.StringBindVariable("performance_schema"), + "__vttablename": sqltypes.StringBindVariable("foo"), + }}} + utils.MustMatch(t, wantQueries, sbc1.Queries) + } func TestUnsharded(t *testing.T) { @@ -89,9 +93,7 @@ func TestUnsharded(t *testing.T) { Sql: "select id from music_user_map where id = 1", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { - t.Errorf("sbclookup.Queries: %+v, want %+v\n", sbclookup.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbclookup.Queries) } func TestUnshardedComments(t *testing.T) { @@ -103,9 +105,7 @@ func TestUnshardedComments(t *testing.T) { Sql: "/* leading */ select id from music_user_map where id = 1 /* trailing */", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { - t.Errorf("sbclookup.Queries: %+v, want %+v\n", sbclookup.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbclookup.Queries) _, err = executorExec(executor, "update music_user_map set id = 1 /* trailing */", nil) require.NoError(t, err) @@ -116,9 +116,7 @@ func TestUnshardedComments(t *testing.T) { Sql: "update music_user_map set id = 1 /* trailing */", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { - t.Errorf("sbclookup.Queries: %+v, want %+v\n", sbclookup.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbclookup.Queries) sbclookup.Queries = nil _, err = executorExec(executor, "delete from music_user_map /* trailing */", nil) @@ -127,9 +125,7 @@ func TestUnshardedComments(t *testing.T) { Sql: "delete from music_user_map /* trailing */", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { - t.Errorf("sbclookup.Queries: %+v, want %+v\n", sbclookup.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbclookup.Queries) sbclookup.Queries = nil _, err = executorExec(executor, "insert into music_user_map values (1) /* trailing */", nil) @@ -138,9 +134,7 @@ func TestUnshardedComments(t *testing.T) { Sql: "insert into music_user_map values (1) /* trailing */", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { - t.Errorf("sbclookup.Queries: %+v, want %+v\n", sbclookup.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbclookup.Queries) } func TestStreamUnsharded(t *testing.T) { @@ -214,12 +208,7 @@ func TestStreamBuffering(t *testing.T) { for r := range results { gotResults = append(gotResults, r) } - if !reflect.DeepEqual(gotResults, wantResults) { - t.Logf("len: %d", len(gotResults)) - for i := range gotResults { - t.Errorf("Buffered streaming:\n%v, want\n%v", gotResults[i], wantResults[i]) - } - } + utils.MustMatch(t, wantResults, gotResults) } func TestStreamLimitOffset(t *testing.T) { @@ -713,9 +702,7 @@ func TestSelectEqual(t *testing.T) { Sql: "select id from user where id = 1", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) if sbc2.Queries != nil { t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) } @@ -727,9 +714,7 @@ func TestSelectEqual(t *testing.T) { Sql: "select id from user where id = 3", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc2.Queries, wantQueries) { - t.Errorf("sbc2.Queries: %+v, want %+v\n", sbc2.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc2.Queries) if execCount := sbc1.ExecCount.Get(); execCount != 1 { t.Errorf("sbc1.ExecCount: %v, want 1\n", execCount) } @@ -744,9 +729,7 @@ func TestSelectEqual(t *testing.T) { Sql: "select id from user where id = '3'", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc2.Queries, wantQueries) { - t.Errorf("sbc2.Queries: %+v, want %+v\n", sbc2.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc2.Queries) if execCount := sbc1.ExecCount.Get(); execCount != 1 { t.Errorf("sbc1.ExecCount: %v, want 1\n", execCount) } @@ -765,9 +748,7 @@ func TestSelectEqual(t *testing.T) { Sql: "select id from user where `name` = 'foo'", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) vars, err := sqltypes.BuildBindVariable([]interface{}{sqltypes.NewVarBinary("foo")}) require.NoError(t, err) wantQueries = []*querypb.BoundQuery{{ @@ -776,9 +757,7 @@ func TestSelectEqual(t *testing.T) { "name": vars, }, }} - if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { - t.Errorf("sbclookup.Queries: %+v, want %+v\n", sbclookup.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbclookup.Queries) } func TestSelectDual(t *testing.T) { @@ -790,15 +769,11 @@ func TestSelectDual(t *testing.T) { Sql: "select @@aa.bb from dual", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) _, err = executorExec(executor, "select @@aa.bb from TestUnsharded.dual", nil) require.NoError(t, err) - if !reflect.DeepEqual(lookup.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, lookup.Queries) } func TestSelectComments(t *testing.T) { @@ -810,9 +785,7 @@ func TestSelectComments(t *testing.T) { Sql: "/* leading */ select id from user where id = 1 /* trailing */", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) if sbc2.Queries != nil { t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) } @@ -831,9 +804,7 @@ func TestSelectNormalize(t *testing.T) { "vtg1": sqltypes.TestBindVariable(int64(1)), }, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) if sbc2.Queries != nil { t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) } @@ -864,9 +835,7 @@ func TestSelectCaseSensitivity(t *testing.T) { Sql: "select Id from user where iD = 1", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) if sbc2.Queries != nil { t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) } @@ -894,9 +863,7 @@ func TestSelectKeyRange(t *testing.T) { Sql: "select krcol_unique, krcol from keyrange_table where krcol = 1", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) if sbc2.Queries != nil { t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) } @@ -912,9 +879,7 @@ func TestSelectKeyRangeUnique(t *testing.T) { Sql: "select krcol_unique, krcol from keyrange_table where krcol_unique = 1", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) if sbc2.Queries != nil { t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) } @@ -933,9 +898,7 @@ func TestSelectIN(t *testing.T) { "__vals": sqltypes.TestBindVariable([]interface{}{int64(1)}), }, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) if sbc2.Queries != nil { t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) } @@ -952,18 +915,14 @@ func TestSelectIN(t *testing.T) { "__vals": sqltypes.TestBindVariable([]interface{}{int64(1)}), }, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) wantQueries = []*querypb.BoundQuery{{ Sql: "select id from user where id in ::__vals", BindVariables: map[string]*querypb.BindVariable{ "__vals": sqltypes.TestBindVariable([]interface{}{int64(3)}), }, }} - if !reflect.DeepEqual(sbc2.Queries, wantQueries) { - t.Errorf("sbc2.Queries: %+v, want %+v\n", sbc2.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc2.Queries) // In is a bind variable list, that will end up on two shards. // This is using an []interface{} for the bind variable list. @@ -980,9 +939,7 @@ func TestSelectIN(t *testing.T) { "vals": sqltypes.TestBindVariable([]interface{}{int64(1), int64(3)}), }, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) wantQueries = []*querypb.BoundQuery{{ Sql: "select id from user where id in ::__vals", BindVariables: map[string]*querypb.BindVariable{ @@ -990,9 +947,7 @@ func TestSelectIN(t *testing.T) { "vals": sqltypes.TestBindVariable([]interface{}{int64(1), int64(3)}), }, }} - if !reflect.DeepEqual(sbc2.Queries, wantQueries) { - t.Errorf("sbc2.Queries: %+v, want %+v\n", sbc2.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc2.Queries) // Convert a non-list bind variable. sbc1.Queries = nil @@ -1007,9 +962,7 @@ func TestSelectIN(t *testing.T) { Sql: "select id from user where `name` = 'foo'", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) vars, err := sqltypes.BuildBindVariable([]interface{}{sqltypes.NewVarBinary("foo")}) require.NoError(t, err) wantQueries = []*querypb.BoundQuery{{ @@ -1018,9 +971,7 @@ func TestSelectIN(t *testing.T) { "name": vars, }, }} - if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { - t.Errorf("sbclookup.Queries: %+v, want %+v\n", sbclookup.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbclookup.Queries) } func TestStreamSelectIN(t *testing.T) { @@ -1065,9 +1016,7 @@ func TestStreamSelectIN(t *testing.T) { "name": vars, }, }} - if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { - t.Errorf("sbclookup.Queries: %+v, want %+v\n", sbclookup.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbclookup.Queries) } func TestSelectScatter(t *testing.T) { @@ -1096,9 +1045,7 @@ func TestSelectScatter(t *testing.T) { BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { - if !reflect.DeepEqual(conn.Queries, wantQueries) { - t.Errorf("conn.Queries = %#v, want %#v", conn.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, conn.Queries) } testQueryLog(t, logChan, "TestExecute", "SELECT", wantQueries[0].Sql, 8) } @@ -1241,9 +1188,7 @@ func TestSelectScatterOrderBy(t *testing.T) { BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { - if !reflect.DeepEqual(conn.Queries, wantQueries) { - t.Errorf("conn.Queries = %#v, want %#v", conn.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, conn.Queries) } wantResult := &sqltypes.Result{ @@ -1264,9 +1209,7 @@ func TestSelectScatterOrderBy(t *testing.T) { wantResult.Rows = append(wantResult.Rows, row) } } - if !reflect.DeepEqual(gotResult, wantResult) { - t.Errorf("scatter order by:\n%v, want\n%v", gotResult, wantResult) - } + utils.MustMatch(t, wantResult, gotResult) } // TestSelectScatterOrderByVarChar will run an ORDER BY query that will scatter out to 8 shards and return the 8 rows (one per shard) sorted. @@ -1312,9 +1255,7 @@ func TestSelectScatterOrderByVarChar(t *testing.T) { BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { - if !reflect.DeepEqual(conn.Queries, wantQueries) { - t.Errorf("conn.Queries = %#v, want %#v", conn.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, conn.Queries) } wantResult := &sqltypes.Result{ @@ -1335,9 +1276,7 @@ func TestSelectScatterOrderByVarChar(t *testing.T) { wantResult.Rows = append(wantResult.Rows, row) } } - if !reflect.DeepEqual(gotResult, wantResult) { - t.Errorf("scatter order by:\n%v, want\n%v", gotResult, wantResult) - } + utils.MustMatch(t, wantResult, gotResult) } func TestStreamSelectScatterOrderBy(t *testing.T) { @@ -1378,9 +1317,7 @@ func TestStreamSelectScatterOrderBy(t *testing.T) { BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { - if !reflect.DeepEqual(conn.Queries, wantQueries) { - t.Errorf("conn.Queries = %#v, want %#v", conn.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, conn.Queries) } wantResult := &sqltypes.Result{ @@ -1396,9 +1333,7 @@ func TestStreamSelectScatterOrderBy(t *testing.T) { } wantResult.Rows = append(wantResult.Rows, row, row) } - if !reflect.DeepEqual(gotResult, wantResult) { - t.Errorf("scatter order by:\n%v, want\n%v", gotResult, wantResult) - } + utils.MustMatch(t, wantResult, gotResult) } func TestStreamSelectScatterOrderByVarChar(t *testing.T) { @@ -1440,9 +1375,7 @@ func TestStreamSelectScatterOrderByVarChar(t *testing.T) { BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { - if !reflect.DeepEqual(conn.Queries, wantQueries) { - t.Errorf("conn.Queries = %#v, want %#v", conn.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, conn.Queries) } wantResult := &sqltypes.Result{ @@ -1458,9 +1391,7 @@ func TestStreamSelectScatterOrderByVarChar(t *testing.T) { } wantResult.Rows = append(wantResult.Rows, row, row) } - if !reflect.DeepEqual(gotResult, wantResult) { - t.Errorf("scatter order by:\n%v, want\n%v", gotResult, wantResult) - } + utils.MustMatch(t, wantResult, gotResult) } // TestSelectScatterAggregate will run an aggregate query that will scatter out to 8 shards and return 4 aggregated rows. @@ -1502,9 +1433,7 @@ func TestSelectScatterAggregate(t *testing.T) { BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { - if !reflect.DeepEqual(conn.Queries, wantQueries) { - t.Errorf("conn.Queries = %#v, want %#v", conn.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, conn.Queries) } wantResult := &sqltypes.Result{ @@ -1522,9 +1451,7 @@ func TestSelectScatterAggregate(t *testing.T) { } wantResult.Rows = append(wantResult.Rows, row) } - if !reflect.DeepEqual(gotResult, wantResult) { - t.Errorf("scatter order by:\n%v, want\n%v", gotResult, wantResult) - } + utils.MustMatch(t, wantResult, gotResult) } func TestStreamSelectScatterAggregate(t *testing.T) { @@ -1565,9 +1492,7 @@ func TestStreamSelectScatterAggregate(t *testing.T) { BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { - if !reflect.DeepEqual(conn.Queries, wantQueries) { - t.Errorf("conn.Queries = %#v, want %#v", conn.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, conn.Queries) } wantResult := &sqltypes.Result{ @@ -1583,9 +1508,7 @@ func TestStreamSelectScatterAggregate(t *testing.T) { } wantResult.Rows = append(wantResult.Rows, row) } - if !reflect.DeepEqual(gotResult, wantResult) { - t.Errorf("scatter order by:\n%v, want\n%v", gotResult, wantResult) - } + utils.MustMatch(t, wantResult, gotResult) } // TestSelectScatterLimit will run a limit query (ordered for consistency) against @@ -1628,9 +1551,7 @@ func TestSelectScatterLimit(t *testing.T) { BindVariables: map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)}, }} for _, conn := range conns { - if !reflect.DeepEqual(conn.Queries, wantQueries) { - t.Errorf("got: conn.Queries = %v, want: %v", conn.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, conn.Queries) } wantResult := &sqltypes.Result{ @@ -1655,9 +1576,7 @@ func TestSelectScatterLimit(t *testing.T) { sqltypes.NewInt32(2), }) - if !reflect.DeepEqual(gotResult, wantResult) { - t.Errorf("scatter order by:\n%v, want\n%v", gotResult, wantResult) - } + utils.MustMatch(t, wantResult, gotResult) } // TestStreamSelectScatterLimit will run a streaming limit query (ordered for consistency) against @@ -1700,9 +1619,7 @@ func TestStreamSelectScatterLimit(t *testing.T) { BindVariables: map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)}, }} for _, conn := range conns { - if !reflect.DeepEqual(conn.Queries, wantQueries) { - t.Errorf("got: conn.Queries = %v, want: %v", conn.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, conn.Queries) } wantResult := &sqltypes.Result{ @@ -1725,9 +1642,7 @@ func TestStreamSelectScatterLimit(t *testing.T) { sqltypes.NewInt32(2), }) - if !reflect.DeepEqual(gotResult, wantResult) { - t.Errorf("scatter order by:\n%v, want\n%v", gotResult, wantResult) - } + utils.MustMatch(t, wantResult, gotResult) } // TODO(sougou): stream and non-stream testing are very similar. @@ -1744,16 +1659,12 @@ func TestSimpleJoin(t *testing.T) { Sql: "select u1.id from user as u1 where u1.id = 1", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) wantQueries = []*querypb.BoundQuery{{ Sql: "select u2.id from user as u2 where u2.id = 3", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc2.Queries, wantQueries) { - t.Errorf("sbc2.Queries: %+v, want %+v\n", sbc2.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc2.Queries) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ sandboxconn.SingleRowResult.Fields[0], @@ -1786,16 +1697,12 @@ func TestJoinComments(t *testing.T) { Sql: "select u1.id from user as u1 where u1.id = 1 /* trailing */", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) wantQueries = []*querypb.BoundQuery{{ Sql: "select u2.id from user as u2 where u2.id = 3 /* trailing */", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc2.Queries, wantQueries) { - t.Errorf("sbc2.Queries: %+v, want %+v\n", sbc2.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc2.Queries) testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 2) } @@ -1812,16 +1719,12 @@ func TestSimpleJoinStream(t *testing.T) { Sql: "select u1.id from user as u1 where u1.id = 1", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) wantQueries = []*querypb.BoundQuery{{ Sql: "select u2.id from user as u2 where u2.id = 3", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc2.Queries, wantQueries) { - t.Errorf("sbc2.Queries: %+v, want %+v\n", sbc2.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc2.Queries) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ sandboxconn.SingleRowResult.Fields[0], @@ -1867,9 +1770,7 @@ func TestVarJoin(t *testing.T) { Sql: "select u1.id, u1.col from user as u1 where u1.id = 1", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) // We have to use string representation because bindvars type is too complex. got := fmt.Sprintf("%+v", sbc2.Queries) want := `[sql:"select u2.id from user as u2 where u2.id = :u1_col" bind_variables: > ]` @@ -1905,9 +1806,7 @@ func TestVarJoinStream(t *testing.T) { Sql: "select u1.id, u1.col from user as u1 where u1.id = 1", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) // We have to use string representation because bindvars type is too complex. got := fmt.Sprintf("%+v", sbc2.Queries) want := `[sql:"select u2.id from user as u2 where u2.id = :u1_col" bind_variables: > ]` @@ -2030,9 +1929,7 @@ func TestEmptyJoin(t *testing.T) { "u1_col": sqltypes.NullBindVariable, }, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries:\n%v, want\n%v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ {Name: "id", Type: sqltypes.Int32}, @@ -2068,9 +1965,7 @@ func TestEmptyJoinStream(t *testing.T) { "u1_col": sqltypes.NullBindVariable, }, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ {Name: "id", Type: sqltypes.Int32}, @@ -2113,9 +2008,7 @@ func TestEmptyJoinRecursive(t *testing.T) { "u2_col": sqltypes.NullBindVariable, }, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries:\n%+v, want\n%+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ {Name: "id", Type: sqltypes.Int32}, @@ -2159,9 +2052,7 @@ func TestEmptyJoinRecursiveStream(t *testing.T) { "u2_col": sqltypes.NullBindVariable, }, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ {Name: "id", Type: sqltypes.Int32}, @@ -2195,9 +2086,7 @@ func TestCrossShardSubquery(t *testing.T) { Sql: "select u1.id as id1, u1.col from user as u1 where u1.id = 1", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) // We have to use string representation because bindvars type is too complex. got := fmt.Sprintf("%+v", sbc2.Queries) want := `[sql:"select u2.id from user as u2 where u2.id = :u1_col" bind_variables: > ]` @@ -2240,9 +2129,7 @@ func TestCrossShardSubqueryStream(t *testing.T) { Sql: "select u1.id as id1, u1.col from user as u1 where u1.id = 1", BindVariables: map[string]*querypb.BindVariable{}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries:\n%+v, want\n%+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) // We have to use string representation because bindvars type is too complex. got := fmt.Sprintf("%+v", sbc2.Queries) want := `[sql:"select u2.id from user as u2 where u2.id = :u1_col" bind_variables: > ]` @@ -2288,9 +2175,7 @@ func TestCrossShardSubqueryGetFields(t *testing.T) { "u1_col": sqltypes.NullBindVariable, }, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries:\n%+v, want\n%+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ @@ -2318,9 +2203,7 @@ func TestSelectBindvarswithPrepare(t *testing.T) { Sql: "select id from user where 1 != 1", BindVariables: map[string]*querypb.BindVariable{"id": sqltypes.Int64BindVariable(1)}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) if sbc2.Queries != nil { t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) } diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.txt b/go/vt/vtgate/planbuilder/testdata/from_cases.txt index 7c601940575..4be60942a61 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.txt @@ -2766,3 +2766,22 @@ Gen4 plan same as above "SysTableTableSchema": "VARBINARY(\"a\")" } } + +# system schema in where clause of information_schema query +"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = 'performance_schema' AND table_name = 'foo'" +{ + "QueryType": "SELECT", + "Original": "SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = 'performance_schema' AND table_name = 'foo'", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select COUNT(*) from INFORMATION_SCHEMA.`TABLES` where 1 != 1", + "Query": "select COUNT(*) from INFORMATION_SCHEMA.`TABLES` where table_schema = :__vtschemaname and table_name = :__vttablename", + "SysTableTableName": "VARBINARY(\"foo\")", + "SysTableTableSchema": "VARBINARY(\"performance_schema\")" + } +}