Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ releases
/dist/
/vthook/
/bin/
/vtdataroot/
*vtdataroot/
venv

.scannerwork
Expand Down
1 change: 0 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,6 @@ github.com/spyzhov/ajson v0.4.2 h1:JMByd/jZApPKDvNsmO90X2WWGbmT2ahDFp73QhZbg3s=
github.com/spyzhov/ajson v0.4.2/go.mod h1:63V+CGM6f1Bu/p4nLIN8885ojBdt88TbLoSFzyqMuVA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
Expand Down
5 changes: 5 additions & 0 deletions go/test/endtoend/vtgate/system_schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ func TestInformationSchemaQuery(t *testing.T) {
assertResultIsEmpty(t, conn, "table_schema = 'performance_schema' and table_name = 'foo'")
assertSingleRowIsReturned(t, conn, "table_schema = 'vt_ks' and table_name = 't1'", "vt_ks")
assertSingleRowIsReturned(t, conn, "table_schema = 'ks' and table_name = 't1'", "vt_ks")
// run end to end test for in statement.
assertSingleRowIsReturned(t, conn, "table_schema IN ('ks')", "vt_ks")
assertSingleRowIsReturned(t, conn, "table_schema IN ('vt_ks')", "vt_ks")
assertSingleRowIsReturned(t, conn, "table_schema IN ('ks') and table_name = 't1'", "vt_ks")
assertSingleRowIsReturned(t, conn, "table_schema IN ('ks') and table_name IN ('t1')", "vt_ks")
}

func assertResultIsEmpty(t *testing.T, conn *mysql.Conn, pre string) {
Expand Down
30 changes: 28 additions & 2 deletions go/vt/sqlparser/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,39 @@ func ASTToStatementType(stmt Statement) StatementType {

//CanNormalize takes Statement and returns if the statement can be normalized.
func CanNormalize(stmt Statement) bool {
switch stmt.(type) {
case *Select, *Union, *Insert, *Update, *Delete, *Set, *CallProc, *Stream: // TODO: we could merge this logic into ASTrewriter
switch s := stmt.(type) {
case *Select:
// Skip the system table normalization, normalization will cause schemaname replacement chaos.
return !containSystemTable(s.From)
case *Union, *Insert, *Update, *Delete, *Set, *CallProc, *Stream: // TODO: we could merge this logic into ASTrewriter
return true
}
return false
}

func containSystemTable(tableExprs TableExprs) bool {
if len(tableExprs) == 1 {
return isSystemTable(tableExprs[0])
}
return isSystemTable(tableExprs[0]) || containSystemTable(tableExprs[1:])
}

func isSystemTable(expr TableExpr) bool {
switch tableExpr := expr.(type) {
case *AliasedTableExpr:
switch table := tableExpr.Expr.(type) {
// Derived table possible, but we will leave it for now.
case TableName:
return SystemSchema(table.Qualifier.String())
}
case *ParenTableExpr:
return containSystemTable(tableExpr.Exprs)
case *JoinTableExpr:
return isSystemTable(tableExpr.LeftExpr) || isSystemTable(tableExpr.RightExpr)
}
return false
}

// CachePlan takes Statement and returns true if the query plan should be cached
func CachePlan(stmt Statement) bool {
switch stmt.(type) {
Expand Down
61 changes: 61 additions & 0 deletions go/vt/sqlparser/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,67 @@ func TestGetTableName(t *testing.T) {
}
}

func TestCanNormalize(t *testing.T) {
testcases := []struct {
in string
out bool
}{{
in: "select * from t",
out: true,
}, {
in: "select * from t.t",
out: true,
}, {
in: "select * from (select * from t) as tt",
out: true,
}, {
in: "describe t",
out: false,
}, {
in: "select * from information_schema.views",
out: false,
}, {
in: "select * from information_schema.tables",
out: false,
}, {
in: "select * from information_schema.columns",
out: false,
}, {
in: "select * from performance_schema.users",
out: false,
}, {
in: "select * from sys.version_patch",
out: false,
}, {
in: "select * from mysql.user",
out: false,
}, {
in: "select * from (select * from information_schema.columns) as tt",
out: true,
}, {
in: "select * from information_schema.columns as info join t on info.a = t.t",
out: false,
}, {
in: "select * from information_schema.columns as info join sys.version_patch",
out: false,
}, {
in: "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA IN ('commerce')",
out: false,
}}

for _, tc := range testcases {
tree, err := Parse(tc.in)
if err != nil {
t.Error(err)
continue
}
out := CanNormalize(tree)
if out != tc.out {
t.Errorf("CanNormalize('%v'): %v, want %v", tc.in, out, tc.out)
}
}
}

func TestIsColName(t *testing.T) {
testcases := []struct {
in Expr
Expand Down
1 change: 0 additions & 1 deletion go/vt/sqlparser/expression_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ func Convert(e Expr) (evalengine.Expr, error) {
Left: left,
Right: right,
}, nil

}
return nil, ErrExprNotSupported
}
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ func TestBuilder(query string, vschema ContextVSchema) (*engine.Plan, error) {
if err != nil {
return nil, err
}
// TODO(Scott): the actual execution path go through prepareAST first to normalize the variables. which could be different than this one.
result, err := sqlparser.RewriteAST(stmt, "")
if err != nil {
return nil, err
}

reservedVars := sqlparser.NewReservedVars("vtg", reserved)
return BuildFromStmt(query, result.AST, reservedVars, vschema, result.BindVarNeeds, true, true)
}
Expand Down
1 change: 0 additions & 1 deletion go/vt/vtgate/planbuilder/fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// +build gofuzz

package planbuilder

Expand Down
150 changes: 136 additions & 14 deletions go/vt/vtgate/planbuilder/system_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,28 @@ package planbuilder

import (
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/log"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/engine"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)

func (pb *primitiveBuilder) findSysInfoRoutingPredicates(expr sqlparser.Expr, rut *route) error {
isTableSchema, out, err := extractInfoSchemaRoutingPredicate(expr)
tableSchemas, tableNames, err := extractInfoSchemaRoutingPredicate(expr)
if err != nil {
return err
}
if out == nil {
if len(tableSchemas) == 0 && len(tableNames) == 0 {
// we didn't find a predicate to use for routing, so we just exit early
return nil
}

if isTableSchema {
rut.eroute.SysTableTableSchema = append(rut.eroute.SysTableTableSchema, out)
} else {
rut.eroute.SysTableTableName = append(rut.eroute.SysTableTableName, out)
if len(tableSchemas) > 0 {
rut.eroute.SysTableTableSchema = append(rut.eroute.SysTableTableSchema, tableSchemas...)
}

if len(tableNames) > 0 {
rut.eroute.SysTableTableName = append(rut.eroute.SysTableTableName, tableNames...)
}

return nil
Expand Down Expand Up @@ -73,7 +76,7 @@ func isTableNameCol(col *sqlparser.ColName) bool {
return col.Name.EqualString("table_name")
}

func extractInfoSchemaRoutingPredicate(in sqlparser.Expr) (bool, evalengine.Expr, error) {
func extractInfoSchemaRoutingPredicate(in sqlparser.Expr) ([]evalengine.Expr, []evalengine.Expr, error) {
switch cmp := in.(type) {
case *sqlparser.ComparisonExpr:
if cmp.Operator == sqlparser.EqualOp {
Expand All @@ -84,22 +87,141 @@ func extractInfoSchemaRoutingPredicate(in sqlparser.Expr) (bool, evalengine.Expr
if err == sqlparser.ErrExprNotSupported {
// This just means we can't rewrite this particular expression,
// not that we have to exit altogether
return false, nil, nil
return nil, nil, nil
}
return false, nil, err
return nil, nil, err
}
var name string
exprs := []evalengine.Expr{evalExpr}
if isSchemaName {
replaceOther(sqlparser.NewArgument(sqltypes.BvSchemaName))
return exprs, nil, nil
}
replaceOther(sqlparser.NewArgument(engine.BvTableName))
return nil, exprs, nil
}
} else if cmp.Operator == sqlparser.InOp || cmp.Operator == sqlparser.NotInOp {
// left side has to be the column, i.e (1, 2) IN column is not allowed.
// At least one column has to be DB name or table name.
colNames := checkAndSplitColumns(cmp.Left)
if colNames == nil {
return nil, nil, nil
}
valTuples := splitVals(cmp.Right, len(colNames))
// check if the val tuples format is correct.
if valTuples == nil {
return nil, nil, nil
}

sysTableSchemas := make([]evalengine.Expr, 0, len(valTuples))
sysTableNames := make([]evalengine.Expr, 0, len(valTuples))
for index, col := range colNames {
isSchema, isTable := isTableSchemaOrName(col)
var name string
if isSchema {
name = sqltypes.BvSchemaName
} else {
} else if isTable {
name = engine.BvTableName
} else {
// only need to rewrite the SysTable and SysSchema
continue
}

for _, tuple := range valTuples {
expr := tuple[index]
if shouldRewrite(expr) {
tuple[index] = sqlparser.Argument(name)
evalExpr, err := sqlparser.Convert(expr)
if err != nil {
if err == sqlparser.ErrExprNotSupported {
// This just means we can't rewrite this particular expression,
// not that we have to exit altogether
return nil, nil, nil
}
return nil, nil, err
}
if isSchema {
sysTableSchemas = append(sysTableSchemas, evalExpr)
} else if isTable {
sysTableNames = append(sysTableNames, evalExpr)
}
}
}
replaceOther(sqlparser.NewArgument(name))
return isSchemaName, evalExpr, nil
}
// construct right side, rows of tuples of __vtschemaname or database()
cmp.Right = populateValTuple(valTuples, len(colNames))
return sysTableSchemas, sysTableNames, nil
}
}
return nil, nil, nil
}

func populateValTuple(valTuples []sqlparser.ValTuple, numOfCol int) sqlparser.ValTuple {
var retValTuples sqlparser.ValTuple
retValTuples = make([]sqlparser.Expr, 0, len(valTuples))
for _, tuple := range valTuples {
if numOfCol == 1 {
// only one col per row, of colName type.
retValTuples = append(retValTuples, tuple[0])
} else {
retValTuples = append(retValTuples, tuple)
}
}
return retValTuples
}

// Convert the right side of In ops to a list of rows.
func splitVals(e sqlparser.Expr, numOfCols int) []sqlparser.ValTuple {
// could either be (1, 2, 3) or ((1,2), (3,5))
expressions, ok := e.(sqlparser.ValTuple)
if !ok {
log.Errorf("Unsupported type, expecting val tuple %v", e)
return nil
}
valTuples := make([]sqlparser.ValTuple, 0, len(expressions))

for _, tuple := range expressions {
if numOfCols == 1 {
// values could be literal, float or other types.
valTuple := []sqlparser.Expr{tuple}
valTuples = append(valTuples, valTuple)
} else {
valTuple, ok := tuple.(sqlparser.ValTuple)
if !ok {
log.Errorf("Unsupported type, expecting a list of val tuple %v", tuple)
return nil
}
valTuples = append(valTuples, valTuple)
}

}
return valTuples
}

// Convert the left side of In ops to a list of columns.
func checkAndSplitColumns(e sqlparser.Expr) []sqlparser.Expr {
colNames := make([]sqlparser.Expr, 0)
switch cols := e.(type) {
case sqlparser.ValTuple:
colNames = cols
case *sqlparser.ColName:
colNames = append(colNames, cols)
default:
// unexpected left side of the in ops.
return nil
}
containSystemTable := false
for _, col := range colNames {
containsDB, containsTable := isTableSchemaOrName(col)
if containsDB || containsTable {
containSystemTable = true
break
}
}
if !containSystemTable {
log.Infof("left side of (not) in operator don't have a DB name or table name, don't need to rewrite. ")
return nil
}
return false, nil, nil
return colNames
}

func shouldRewrite(e sqlparser.Expr) bool {
Expand Down
Loading