diff --git a/.gitignore b/.gitignore index 041443eeedd..ec0a429aa13 100644 --- a/.gitignore +++ b/.gitignore @@ -80,7 +80,7 @@ releases /dist/ /vthook/ /bin/ -/vtdataroot/ +*vtdataroot/ venv .scannerwork diff --git a/go.sum b/go.sum index 41e5864cac5..49128274505 100644 --- a/go.sum +++ b/go.sum @@ -693,7 +693,6 @@ github.com/spyzhov/ajson v0.4.2 h1:JMByd/jZApPKDvNsmO90X2WWGbmT2ahDFp73QhZbg3s= github.com/spyzhov/ajson v0.4.2/go.mod h1:63V+CGM6f1Bu/p4nLIN8885ojBdt88TbLoSFzyqMuVA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= diff --git a/go/test/endtoend/vtgate/system_schema_test.go b/go/test/endtoend/vtgate/system_schema_test.go index ad2fb0f5e93..96124020cc2 100644 --- a/go/test/endtoend/vtgate/system_schema_test.go +++ b/go/test/endtoend/vtgate/system_schema_test.go @@ -80,6 +80,11 @@ func TestInformationSchemaQuery(t *testing.T) { assertResultIsEmpty(t, conn, "table_schema = 'performance_schema' and table_name = 'foo'") assertSingleRowIsReturned(t, conn, "table_schema = 'vt_ks' and table_name = 't1'", "vt_ks") assertSingleRowIsReturned(t, conn, "table_schema = 'ks' and table_name = 't1'", "vt_ks") + // run end to end test for in statement. + assertSingleRowIsReturned(t, conn, "table_schema IN ('ks')", "vt_ks") + assertSingleRowIsReturned(t, conn, "table_schema IN ('vt_ks')", "vt_ks") + assertSingleRowIsReturned(t, conn, "table_schema IN ('ks') and table_name = 't1'", "vt_ks") + assertSingleRowIsReturned(t, conn, "table_schema IN ('ks') and table_name IN ('t1')", "vt_ks") } func assertResultIsEmpty(t *testing.T, conn *mysql.Conn, pre string) { diff --git a/go/vt/sqlparser/analyzer.go b/go/vt/sqlparser/analyzer.go index 29c62ca68ca..b4f61735568 100644 --- a/go/vt/sqlparser/analyzer.go +++ b/go/vt/sqlparser/analyzer.go @@ -120,13 +120,39 @@ func ASTToStatementType(stmt Statement) StatementType { //CanNormalize takes Statement and returns if the statement can be normalized. func CanNormalize(stmt Statement) bool { - switch stmt.(type) { - case *Select, *Union, *Insert, *Update, *Delete, *Set, *CallProc, *Stream: // TODO: we could merge this logic into ASTrewriter + switch s := stmt.(type) { + case *Select: + // Skip the system table normalization, normalization will cause schemaname replacement chaos. + return !containSystemTable(s.From) + case *Union, *Insert, *Update, *Delete, *Set, *CallProc, *Stream: // TODO: we could merge this logic into ASTrewriter return true } return false } +func containSystemTable(tableExprs TableExprs) bool { + if len(tableExprs) == 1 { + return isSystemTable(tableExprs[0]) + } + return isSystemTable(tableExprs[0]) || containSystemTable(tableExprs[1:]) +} + +func isSystemTable(expr TableExpr) bool { + switch tableExpr := expr.(type) { + case *AliasedTableExpr: + switch table := tableExpr.Expr.(type) { + // Derived table possible, but we will leave it for now. + case TableName: + return SystemSchema(table.Qualifier.String()) + } + case *ParenTableExpr: + return containSystemTable(tableExpr.Exprs) + case *JoinTableExpr: + return isSystemTable(tableExpr.LeftExpr) || isSystemTable(tableExpr.RightExpr) + } + return false +} + // CachePlan takes Statement and returns true if the query plan should be cached func CachePlan(stmt Statement) bool { switch stmt.(type) { diff --git a/go/vt/sqlparser/analyzer_test.go b/go/vt/sqlparser/analyzer_test.go index cf755b82418..cc1bd205883 100644 --- a/go/vt/sqlparser/analyzer_test.go +++ b/go/vt/sqlparser/analyzer_test.go @@ -237,6 +237,67 @@ func TestGetTableName(t *testing.T) { } } +func TestCanNormalize(t *testing.T) { + testcases := []struct { + in string + out bool + }{{ + in: "select * from t", + out: true, + }, { + in: "select * from t.t", + out: true, + }, { + in: "select * from (select * from t) as tt", + out: true, + }, { + in: "describe t", + out: false, + }, { + in: "select * from information_schema.views", + out: false, + }, { + in: "select * from information_schema.tables", + out: false, + }, { + in: "select * from information_schema.columns", + out: false, + }, { + in: "select * from performance_schema.users", + out: false, + }, { + in: "select * from sys.version_patch", + out: false, + }, { + in: "select * from mysql.user", + out: false, + }, { + in: "select * from (select * from information_schema.columns) as tt", + out: true, + }, { + in: "select * from information_schema.columns as info join t on info.a = t.t", + out: false, + }, { + in: "select * from information_schema.columns as info join sys.version_patch", + out: false, + }, { + in: "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA IN ('commerce')", + out: false, + }} + + for _, tc := range testcases { + tree, err := Parse(tc.in) + if err != nil { + t.Error(err) + continue + } + out := CanNormalize(tree) + if out != tc.out { + t.Errorf("CanNormalize('%v'): %v, want %v", tc.in, out, tc.out) + } + } +} + func TestIsColName(t *testing.T) { testcases := []struct { in Expr diff --git a/go/vt/sqlparser/expression_converter.go b/go/vt/sqlparser/expression_converter.go index e85fce2241d..1bf1c3202bb 100644 --- a/go/vt/sqlparser/expression_converter.go +++ b/go/vt/sqlparser/expression_converter.go @@ -71,7 +71,6 @@ func Convert(e Expr) (evalengine.Expr, error) { Left: left, Right: right, }, nil - } return nil, ErrExprNotSupported } diff --git a/go/vt/vtgate/planbuilder/builder.go b/go/vt/vtgate/planbuilder/builder.go index 5260ff184c4..da3c1abc78a 100644 --- a/go/vt/vtgate/planbuilder/builder.go +++ b/go/vt/vtgate/planbuilder/builder.go @@ -93,11 +93,11 @@ func TestBuilder(query string, vschema ContextVSchema) (*engine.Plan, error) { if err != nil { return nil, err } + // TODO(Scott): the actual execution path go through prepareAST first to normalize the variables. which could be different than this one. result, err := sqlparser.RewriteAST(stmt, "") if err != nil { return nil, err } - reservedVars := sqlparser.NewReservedVars("vtg", reserved) return BuildFromStmt(query, result.AST, reservedVars, vschema, result.BindVarNeeds, true, true) } diff --git a/go/vt/vtgate/planbuilder/fuzz.go b/go/vt/vtgate/planbuilder/fuzz.go index ff8d4968316..1996d558b9e 100644 --- a/go/vt/vtgate/planbuilder/fuzz.go +++ b/go/vt/vtgate/planbuilder/fuzz.go @@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -// +build gofuzz package planbuilder diff --git a/go/vt/vtgate/planbuilder/system_tables.go b/go/vt/vtgate/planbuilder/system_tables.go index 111d86bb41d..316091f71fa 100644 --- a/go/vt/vtgate/planbuilder/system_tables.go +++ b/go/vt/vtgate/planbuilder/system_tables.go @@ -18,25 +18,28 @@ package planbuilder import ( "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" "vitess.io/vitess/go/vt/vtgate/evalengine" ) func (pb *primitiveBuilder) findSysInfoRoutingPredicates(expr sqlparser.Expr, rut *route) error { - isTableSchema, out, err := extractInfoSchemaRoutingPredicate(expr) + tableSchemas, tableNames, err := extractInfoSchemaRoutingPredicate(expr) if err != nil { return err } - if out == nil { + if len(tableSchemas) == 0 && len(tableNames) == 0 { // we didn't find a predicate to use for routing, so we just exit early return nil } - if isTableSchema { - rut.eroute.SysTableTableSchema = append(rut.eroute.SysTableTableSchema, out) - } else { - rut.eroute.SysTableTableName = append(rut.eroute.SysTableTableName, out) + if len(tableSchemas) > 0 { + rut.eroute.SysTableTableSchema = append(rut.eroute.SysTableTableSchema, tableSchemas...) + } + + if len(tableNames) > 0 { + rut.eroute.SysTableTableName = append(rut.eroute.SysTableTableName, tableNames...) } return nil @@ -73,7 +76,7 @@ func isTableNameCol(col *sqlparser.ColName) bool { return col.Name.EqualString("table_name") } -func extractInfoSchemaRoutingPredicate(in sqlparser.Expr) (bool, evalengine.Expr, error) { +func extractInfoSchemaRoutingPredicate(in sqlparser.Expr) ([]evalengine.Expr, []evalengine.Expr, error) { switch cmp := in.(type) { case *sqlparser.ComparisonExpr: if cmp.Operator == sqlparser.EqualOp { @@ -84,22 +87,141 @@ func extractInfoSchemaRoutingPredicate(in sqlparser.Expr) (bool, evalengine.Expr if err == sqlparser.ErrExprNotSupported { // This just means we can't rewrite this particular expression, // not that we have to exit altogether - return false, nil, nil + return nil, nil, nil } - return false, nil, err + return nil, nil, err } - var name string + exprs := []evalengine.Expr{evalExpr} if isSchemaName { + replaceOther(sqlparser.NewArgument(sqltypes.BvSchemaName)) + return exprs, nil, nil + } + replaceOther(sqlparser.NewArgument(engine.BvTableName)) + return nil, exprs, nil + } + } else if cmp.Operator == sqlparser.InOp || cmp.Operator == sqlparser.NotInOp { + // left side has to be the column, i.e (1, 2) IN column is not allowed. + // At least one column has to be DB name or table name. + colNames := checkAndSplitColumns(cmp.Left) + if colNames == nil { + return nil, nil, nil + } + valTuples := splitVals(cmp.Right, len(colNames)) + // check if the val tuples format is correct. + if valTuples == nil { + return nil, nil, nil + } + + sysTableSchemas := make([]evalengine.Expr, 0, len(valTuples)) + sysTableNames := make([]evalengine.Expr, 0, len(valTuples)) + for index, col := range colNames { + isSchema, isTable := isTableSchemaOrName(col) + var name string + if isSchema { name = sqltypes.BvSchemaName - } else { + } else if isTable { name = engine.BvTableName + } else { + // only need to rewrite the SysTable and SysSchema + continue + } + + for _, tuple := range valTuples { + expr := tuple[index] + if shouldRewrite(expr) { + tuple[index] = sqlparser.Argument(name) + evalExpr, err := sqlparser.Convert(expr) + if err != nil { + if err == sqlparser.ErrExprNotSupported { + // This just means we can't rewrite this particular expression, + // not that we have to exit altogether + return nil, nil, nil + } + return nil, nil, err + } + if isSchema { + sysTableSchemas = append(sysTableSchemas, evalExpr) + } else if isTable { + sysTableNames = append(sysTableNames, evalExpr) + } + } } - replaceOther(sqlparser.NewArgument(name)) - return isSchemaName, evalExpr, nil } + // construct right side, rows of tuples of __vtschemaname or database() + cmp.Right = populateValTuple(valTuples, len(colNames)) + return sysTableSchemas, sysTableNames, nil + } + } + return nil, nil, nil +} + +func populateValTuple(valTuples []sqlparser.ValTuple, numOfCol int) sqlparser.ValTuple { + var retValTuples sqlparser.ValTuple + retValTuples = make([]sqlparser.Expr, 0, len(valTuples)) + for _, tuple := range valTuples { + if numOfCol == 1 { + // only one col per row, of colName type. + retValTuples = append(retValTuples, tuple[0]) + } else { + retValTuples = append(retValTuples, tuple) + } + } + return retValTuples +} + +// Convert the right side of In ops to a list of rows. +func splitVals(e sqlparser.Expr, numOfCols int) []sqlparser.ValTuple { + // could either be (1, 2, 3) or ((1,2), (3,5)) + expressions, ok := e.(sqlparser.ValTuple) + if !ok { + log.Errorf("Unsupported type, expecting val tuple %v", e) + return nil + } + valTuples := make([]sqlparser.ValTuple, 0, len(expressions)) + + for _, tuple := range expressions { + if numOfCols == 1 { + // values could be literal, float or other types. + valTuple := []sqlparser.Expr{tuple} + valTuples = append(valTuples, valTuple) + } else { + valTuple, ok := tuple.(sqlparser.ValTuple) + if !ok { + log.Errorf("Unsupported type, expecting a list of val tuple %v", tuple) + return nil + } + valTuples = append(valTuples, valTuple) } + + } + return valTuples +} + +// Convert the left side of In ops to a list of columns. +func checkAndSplitColumns(e sqlparser.Expr) []sqlparser.Expr { + colNames := make([]sqlparser.Expr, 0) + switch cols := e.(type) { + case sqlparser.ValTuple: + colNames = cols + case *sqlparser.ColName: + colNames = append(colNames, cols) + default: + // unexpected left side of the in ops. + return nil + } + containSystemTable := false + for _, col := range colNames { + containsDB, containsTable := isTableSchemaOrName(col) + if containsDB || containsTable { + containSystemTable = true + break + } + } + if !containSystemTable { + log.Infof("left side of (not) in operator don't have a DB name or table name, don't need to rewrite. ") + return nil } - return false, nil, nil + return colNames } func shouldRewrite(e sqlparser.Expr) bool { diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.txt b/go/vt/vtgate/planbuilder/testdata/filter_cases.txt index 480314df077..c9395ae9a12 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.txt @@ -1803,6 +1803,154 @@ Gen4 plan same as above } } +# query trying to query keyspace in clause +"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA IN ('user')" +{ + "QueryType": "SELECT", + "Original": "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA IN ('user')", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select * from INFORMATION_SCHEMA.`TABLES` where 1 != 1", + "Query": "select * from INFORMATION_SCHEMA.`TABLES` where TABLE_SCHEMA in (:__vtschemaname)", + "SysTableTableSchema": "[VARBINARY(\"user\")]" + } +} + +# query trying to query keyspace in clause with multiple values +"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA IN ('user','main')" +{ + "QueryType": "SELECT", + "Original": "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA IN ('user','main')", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select * from INFORMATION_SCHEMA.`TABLES` where 1 != 1", + "Query": "select * from INFORMATION_SCHEMA.`TABLES` where TABLE_SCHEMA in (:__vtschemaname, :__vtschemaname)", + "SysTableTableSchema": "[VARBINARY(\"user\"), VARBINARY(\"main\")]" + } +} + +# query trying to query keyspace not in clause with multiple values +"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA NOT IN ('user','main')" +{ + "QueryType": "SELECT", + "Original": "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA NOT IN ('user','main')", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select * from INFORMATION_SCHEMA.`TABLES` where 1 != 1", + "Query": "select * from INFORMATION_SCHEMA.`TABLES` where TABLE_SCHEMA not in (:__vtschemaname, :__vtschemaname)", + "SysTableTableSchema": "[VARBINARY(\"user\"), VARBINARY(\"main\")]" + } +} + +# query trying to query keyspace in clause with multiple values +"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA IN ('user','main',database())" +{ + "QueryType": "SELECT", + "Original": "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA IN ('user','main',database())", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select * from INFORMATION_SCHEMA.`TABLES` where 1 != 1", + "Query": "select * from INFORMATION_SCHEMA.`TABLES` where TABLE_SCHEMA in (:__vtschemaname, :__vtschemaname, database())", + "SysTableTableSchema": "[VARBINARY(\"user\"), VARBINARY(\"main\")]" + } +} + +# And two in-clause of sys table. +"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA IN ('ks') AND TABLE_NAME IN ('route1')" +{ + "QueryType": "SELECT", + "Original": "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA IN ('ks') AND TABLE_NAME IN ('route1')", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select * from INFORMATION_SCHEMA.`TABLES` where 1 != 1", + "Query": "select * from INFORMATION_SCHEMA.`TABLES` where TABLE_SCHEMA in (:__vtschemaname) and TABLE_NAME in (:__vttablename)", + "SysTableTableName": "[VARBINARY(\"route1\")]", + "SysTableTableSchema": "[VARBINARY(\"ks\")]" + } +} + +# schema name and table name of composite in clause. +"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA, TABLE_NAME) IN (('ks', 'route1'))" +{ + "QueryType": "SELECT", + "Original": "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA, TABLE_NAME) IN (('ks', 'route1'))", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select * from INFORMATION_SCHEMA.`TABLES` where 1 != 1", + "Query": "select * from INFORMATION_SCHEMA.`TABLES` where (TABLE_SCHEMA, TABLE_NAME) in ((:__vtschemaname, :__vttablename))", + "SysTableTableName": "[VARBINARY(\"route1\")]", + "SysTableTableSchema": "[VARBINARY(\"ks\")]" + } +} + +# schema name and table name of composite NOT in clause. +"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA, TABLE_NAME) NOT IN (('ks', 'route1'),('ks1','route2'))" +{ + "QueryType": "SELECT", + "Original": "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA, TABLE_NAME) NOT IN (('ks', 'route1'),('ks1','route2'))", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select * from INFORMATION_SCHEMA.`TABLES` where 1 != 1", + "Query": "select * from INFORMATION_SCHEMA.`TABLES` where (TABLE_SCHEMA, TABLE_NAME) not in ((:__vtschemaname, :__vttablename), (:__vtschemaname, :__vttablename))", + "SysTableTableName": "[VARBINARY(\"route1\"), VARBINARY(\"route2\")]", + "SysTableTableSchema": "[VARBINARY(\"ks\"), VARBINARY(\"ks1\")]" + } +} + +# schema name and table name of composite in clause. +"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA, TABLE_NAME, LOWER(COLUMN_NAME)) IN (('ks', 'route1', 'col'))" +{ + "QueryType": "SELECT", + "Original": "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA, TABLE_NAME, LOWER(COLUMN_NAME)) IN (('ks', 'route1', 'col'))", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select * from INFORMATION_SCHEMA.`TABLES` where 1 != 1", + "Query": "select * from INFORMATION_SCHEMA.`TABLES` where (TABLE_SCHEMA, TABLE_NAME, LOWER(COLUMN_NAME)) in ((:__vtschemaname, :__vttablename, 'col'))", + "SysTableTableName": "[VARBINARY(\"route1\")]", + "SysTableTableSchema": "[VARBINARY(\"ks\")]" + } +} + # information_schema query using database() func "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = database()" {