diff --git a/go/vt/sqlparser/ast_rewriting.go b/go/vt/sqlparser/ast_rewriting.go index df9b0af9b38..885ce9e9609 100644 --- a/go/vt/sqlparser/ast_rewriting.go +++ b/go/vt/sqlparser/ast_rewriting.go @@ -52,13 +52,14 @@ func RewriteAST(in Statement, keyspace string) (*RewriteASTResult, error) { if err != nil { return nil, err } + if setRewriter.err != nil { + return nil, setRewriter.err + } + out, ok := result.(Statement) if !ok { return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "statement rewriting returned a non statement: %s", String(out)) } - if setRewriter.err != nil { - return nil, setRewriter.err - } r := &RewriteASTResult{ AST: out, @@ -342,3 +343,192 @@ func SystemSchema(schema string) bool { strings.EqualFold(schema, "sys") || strings.EqualFold(schema, "mysql") } + +// RewriteToCNF walks the input AST and rewrites any boolean logic into CNF +// Note: In order to re-plan, we need to empty the accumulated metadata in the AST, +// so ColName.Metadata will be nil:ed out as part of this rewrite +func RewriteToCNF(ast SQLNode) (SQLNode, error) { + var err error + for { + finishedRewrite := true + ast, err = Rewrite(ast, func(cursor *Cursor) bool { + if e, isExpr := cursor.node.(Expr); isExpr { + rewritten, didRewrite := rewriteToCNFExpr(e) + if didRewrite { + finishedRewrite = false + cursor.Replace(rewritten) + } + } + if col, isCol := cursor.node.(*ColName); isCol { + col.Metadata = nil + } + return true + }, nil) + if err != nil { + return nil, err + } + + if finishedRewrite { + return ast, nil + } + } +} + +func distinctOr(in *OrExpr) (Expr, bool) { + todo := []*OrExpr{in} + var leaves []Expr + for len(todo) > 0 { + curr := todo[0] + todo = todo[1:] + addAnd := func(in Expr) { + and, ok := in.(*OrExpr) + if ok { + todo = append(todo, and) + } else { + leaves = append(leaves, in) + } + } + addAnd(curr.Left) + addAnd(curr.Right) + } + original := len(leaves) + var predicates []Expr + +outer1: + for len(leaves) > 0 { + curr := leaves[0] + leaves = leaves[1:] + for _, alreadyIn := range predicates { + if EqualsExpr(alreadyIn, curr) { + continue outer1 + } + } + predicates = append(predicates, curr) + } + if original == len(predicates) { + return in, false + } + var result Expr + for i, curr := range predicates { + if i == 0 { + result = curr + continue + } + result = &OrExpr{Left: result, Right: curr} + } + return result, true +} +func distinctAnd(in *AndExpr) (Expr, bool) { + todo := []*AndExpr{in} + var leaves []Expr + for len(todo) > 0 { + curr := todo[0] + todo = todo[1:] + addAnd := func(in Expr) { + and, ok := in.(*AndExpr) + if ok { + todo = append(todo, and) + } else { + leaves = append(leaves, in) + } + } + addAnd(curr.Left) + addAnd(curr.Right) + } + original := len(leaves) + var predicates []Expr + +outer1: + for len(leaves) > 0 { + curr := leaves[0] + leaves = leaves[1:] + for _, alreadyIn := range predicates { + if EqualsExpr(alreadyIn, curr) { + continue outer1 + } + } + predicates = append(predicates, curr) + } + if original == len(predicates) { + return in, false + } + var result Expr + for i, curr := range predicates { + if i == 0 { + result = curr + continue + } + result = &AndExpr{Left: result, Right: curr} + } + return result, true +} + +func rewriteToCNFExpr(expr Expr) (Expr, bool) { + switch expr := expr.(type) { + case *NotExpr: + switch child := expr.Expr.(type) { + case *NotExpr: + // NOT NOT A => A + return child.Expr, true + case *OrExpr: + // DeMorgan Rewriter + // NOT (A OR B) => NOT A AND NOT B + return &AndExpr{Right: &NotExpr{Expr: child.Right}, Left: &NotExpr{Expr: child.Left}}, true + case *AndExpr: + // DeMorgan Rewriter + // NOT (A AND B) => NOT A OR NOT B + return &OrExpr{Right: &NotExpr{Expr: child.Right}, Left: &NotExpr{Expr: child.Left}}, true + } + case *OrExpr: + or := expr + if and, ok := or.Left.(*AndExpr); ok { + // Simplification + // (A AND B) OR A => A + if EqualsExpr(or.Right, and.Left) || EqualsExpr(or.Right, and.Right) { + return or.Right, true + } + // Distribution Law + // (A AND B) OR C => (A OR C) AND (B OR C) + return &AndExpr{Left: &OrExpr{Left: and.Left, Right: or.Right}, Right: &OrExpr{Left: and.Right, Right: or.Right}}, true + } + if and, ok := or.Right.(*AndExpr); ok { + // Simplification + // A OR (A AND B) => A + if EqualsExpr(or.Left, and.Left) || EqualsExpr(or.Left, and.Right) { + return or.Left, true + } + // Distribution Law + // C OR (A AND B) => (C OR A) AND (C OR B) + return &AndExpr{Left: &OrExpr{Left: or.Left, Right: and.Left}, Right: &OrExpr{Left: or.Left, Right: and.Right}}, true + } + // Try to make distinct + return distinctOr(expr) + + case *XorExpr: + // DeMorgan Rewriter + // (A XOR B) => (A OR B) AND NOT (A AND B) + return &AndExpr{Left: &OrExpr{Left: expr.Left, Right: expr.Right}, Right: &NotExpr{Expr: &AndExpr{Left: expr.Left, Right: expr.Right}}}, true + case *AndExpr: + res, rewritten := distinctAnd(expr) + if rewritten { + return res, rewritten + } + and := expr + if or, ok := and.Left.(*OrExpr); ok { + // Simplification + // (A OR B) AND A => A + if EqualsExpr(or.Left, and.Right) || EqualsExpr(or.Right, and.Right) { + return and.Right, true + } + } + if or, ok := and.Right.(*OrExpr); ok { + // Simplification + // A OR (A AND B) => A + if EqualsExpr(or.Left, and.Left) || EqualsExpr(or.Right, and.Left) { + return or.Left, true + } + } + + } + return expr, false +} diff --git a/go/vt/sqlparser/ast_rewriting_test.go b/go/vt/sqlparser/ast_rewriting_test.go index c2368d81514..1d2b746755b 100644 --- a/go/vt/sqlparser/ast_rewriting_test.go +++ b/go/vt/sqlparser/ast_rewriting_test.go @@ -311,3 +311,99 @@ func TestRewritesWithDefaultKeyspace(in *testing.T) { }) } } + +func TestRewriteToCNF(in *testing.T) { + tests := []struct { + in string + expected string + }{{ + in: "not (not A = 3)", + expected: "A = 3", + }, { + in: "not (A = 3 and B = 2)", + expected: "not A = 3 or not B = 2", + }, { + in: "not (A = 3 or B = 2)", + expected: "not A = 3 and not B = 2", + }, { + in: "A xor B", + expected: "(A or B) and not (A and B)", + }, { + in: "(A and B) or C", + expected: "(A or C) and (B or C)", + }, { + in: "C or (A and B)", + expected: "(C or A) and (C or B)", + }, { + in: "A and A", + expected: "A", + }, { + in: "A OR A", + expected: "A", + }, { + in: "A OR (A AND B)", + expected: "A", + }, { + in: "A OR (B AND A)", + expected: "A", + }, { + in: "(A AND B) OR A", + expected: "A", + }, { + in: "(B AND A) OR A", + expected: "A", + }, { + in: "(A and B) and (B and A)", + expected: "A and B", + }, { + in: "(A or B) and A", + expected: "A", + }, { + in: "A and (A or B)", + expected: "A", + }} + + for _, tc := range tests { + in.Run(tc.in, func(t *testing.T) { + stmt, err := Parse("SELECT * FROM T WHERE " + tc.in) + require.NoError(t, err) + + expr := stmt.(*Select).Where.Expr + expr, didRewrite := rewriteToCNFExpr(expr) + assert.True(t, didRewrite) + assert.Equal(t, tc.expected, String(expr)) + }) + } +} + +func TestFixedPointRewriteToCNF(in *testing.T) { + tests := []struct { + in string + expected string + }{{ + in: "A xor B", + expected: "(A or B) and (not A or not B)", + }, { + in: "(A and B) and (B and A) and (B and A) and (A and B)", + expected: "A and B", + }, { + in: "((A and B) OR (A and C) OR (A and D)) and E and F", + expected: "A and ((A or B) and (B or C or A)) and ((A or D) and ((B or A or D) and (B or C or D))) and E and F", + }, { + in: "(A and B) OR (A and C)", + expected: "A and ((B or A) and (B or C))", + }} + + for _, tc := range tests { + in.Run(tc.in, func(t *testing.T) { + require := require.New(t) + stmt, err := Parse("SELECT * FROM T WHERE " + tc.in) + require.NoError(err) + + expr := stmt.(*Select).Where.Expr + output, err := RewriteToCNF(expr) + require.NoError(err) + assert.Equal(t, tc.expected, String(output)) + }) + } +} diff --git a/go/vt/sqlparser/precedence_test.go b/go/vt/sqlparser/precedence_test.go index 801a8faa1d2..7b917f8e698 100644 --- a/go/vt/sqlparser/precedence_test.go +++ b/go/vt/sqlparser/precedence_test.go @@ -132,6 +132,7 @@ func TestParens(t *testing.T) { {in: "(a & b) | c", expected: "a & b | c"}, {in: "not (a=b and c=d)", expected: "not (a = b and c = d)"}, {in: "not (a=b) and c=d", expected: "not a = b and c = d"}, + {in: "(not (a=b)) and c=d", expected: "not a = b and c = d"}, {in: "-(12)", expected: "-12"}, {in: "-(12 + 12)", expected: "-(12 + 12)"}, {in: "(1 > 2) and (1 = b)", expected: "1 > 2 and 1 = b"}, diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index c3b5e682a1d..846c920ab19 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -599,7 +599,7 @@ func locateFile(name string) string { } func BenchmarkPlanner(b *testing.B) { - filenames := []string{"from_cases.txt", "filter_cases.txt", "large_cases.txt", "aggr_cases.txt", "memory_sort_cases.txt", "select_cases.txt", "union_cases.txt", "wireup_cases.txt"} + filenames := []string{"from_cases.txt", "filter_cases.txt", "large_cases.txt", "aggr_cases.txt", "select_cases.txt", "union_cases.txt"} vschema := &vschemaWrapper{ v: loadSchema(b, "schema_test.json"), sysVarEnabled: true, diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index 2ce06046ca2..110e4e2c5cf 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -20,6 +20,8 @@ import ( "errors" "fmt" + "vitess.io/vitess/go/vt/orchestrator/external/golib/log" + "vitess.io/vitess/go/vt/vtgate/semantics" "vitess.io/vitess/go/vt/key" @@ -45,17 +47,61 @@ func buildSelectPlan(query string) func(sqlparser.Statement, sqlparser.BindVars, return p, nil } - pb := newPrimitiveBuilder(vschema, newJointab(reservedVars)) - if err := pb.processSelect(sel, reservedVars, nil, query); err != nil { - return nil, err + getPlan := func(sel *sqlparser.Select) (logicalPlan, error) { + pb := newPrimitiveBuilder(vschema, newJointab(reservedVars)) + if err := pb.processSelect(sel, reservedVars, nil, query); err != nil { + return nil, err + } + if err := pb.plan.Wireup(pb.plan, pb.jt); err != nil { + return nil, err + } + return pb.plan, nil } - if err := pb.plan.Wireup(pb.plan, pb.jt); err != nil { + + plan, err := getPlan(sel) + if err != nil { return nil, err } - return pb.plan.Primitive(), nil + + if shouldRetryWithCNFRewriting(plan) { + // by transforming the predicates to CNF, the planner will sometimes find better plans + primitive := rewriteToCNFAndReplan(stmt, getPlan) + if primitive != nil { + return primitive, nil + } + } + return plan.Primitive(), nil } } +func rewriteToCNFAndReplan(stmt sqlparser.Statement, getPlan func(sel *sqlparser.Select) (logicalPlan, error)) engine.Primitive { + rewritten, err := sqlparser.RewriteToCNF(stmt) + if err == nil { + sel2, isSelect := rewritten.(*sqlparser.Select) + if isSelect { + log.Infof("retrying plan after cnf: %s", sqlparser.String(sel2)) + plan2, err := getPlan(sel2) + if err == nil && !shouldRetryWithCNFRewriting(plan2) { + // we only use this new plan if it's better than the old one we got + return plan2.Primitive() + } + } + } + return nil +} + +func shouldRetryWithCNFRewriting(plan logicalPlan) bool { + routePlan, isRoute := plan.(*route) + if !isRoute { + return false + } + // if we have a I_S query, but have not found table_schema or table_name, let's try CNF + return routePlan.eroute.Opcode == engine.SelectDBA && + routePlan.eroute.SysTableTableName == nil && + routePlan.eroute.SysTableTableSchema == nil + +} + func pushProjection(expr *sqlparser.AliasedExpr, plan logicalPlan, semTable *semantics.SemTable) (firstOffset int, err error) { switch node := plan.(type) { case *route: diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.txt b/go/vt/vtgate/planbuilder/testdata/filter_cases.txt index 23b246e4858..125be6f87d0 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.txt @@ -1886,3 +1886,21 @@ Gen4 plan same as above ] } } + +# able to isolate table_schema value even when hidden inside of ORs +"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA = 'ks' and other_column = 42) OR (TABLE_SCHEMA = 'ks' and foobar = 'value')" +{ + "QueryType": "SELECT", + "Original": "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE (TABLE_SCHEMA = 'ks' and other_column = 42) OR (TABLE_SCHEMA = 'ks' and foobar = 'value')", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select * from INFORMATION_SCHEMA.`TABLES` where 1 != 1", + "Query": "select * from INFORMATION_SCHEMA.`TABLES` where TABLE_SCHEMA = :__vtschemaname and (other_column = 42 or TABLE_SCHEMA = 'ks') and (other_column = 42 or foobar = 'value')", + "SysTableTableSchema": "VARBINARY(\"ks\")" + } +}