Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 193 additions & 3 deletions go/vt/sqlparser/ast_rewriting.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,14 @@ func RewriteAST(in Statement, keyspace string) (*RewriteASTResult, error) {
if err != nil {
return nil, err
}
if setRewriter.err != nil {
return nil, setRewriter.err
}

out, ok := result.(Statement)
if !ok {
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "statement rewriting returned a non statement: %s", String(out))
}
if setRewriter.err != nil {
return nil, setRewriter.err
}

r := &RewriteASTResult{
AST: out,
Expand Down Expand Up @@ -342,3 +343,192 @@ func SystemSchema(schema string) bool {
strings.EqualFold(schema, "sys") ||
strings.EqualFold(schema, "mysql")
}

// RewriteToCNF walks the input AST and rewrites any boolean logic into
// conjunctive normal form (CNF). The tree is traversed repeatedly, applying
// rewriteToCNFExpr to every expression node, until a complete pass performs
// no rewrite.
//
// Note: in order to re-plan, the metadata accumulated in the AST has to be
// emptied, so ColName.Metadata is set to nil on every column visited as part
// of this rewrite.
//
// It returns the rewritten tree, or a nil node together with the error
// reported by Rewrite.
func RewriteToCNF(ast SQLNode) (SQLNode, error) {
	for {
		changed := false
		visit := func(cursor *Cursor) bool {
			if e, isExpr := cursor.node.(Expr); isExpr {
				if replacement, ok := rewriteToCNFExpr(e); ok {
					changed = true
					cursor.Replace(replacement)
				}
			}
			// Deliberately checked after the potential replacement above.
			if col, isCol := cursor.node.(*ColName); isCol {
				col.Metadata = nil
			}
			return true
		}

		next, err := Rewrite(ast, visit, nil)
		if err != nil {
			return nil, err
		}
		ast = next

		// A pass without any rewrite means a fixed point has been reached.
		if !changed {
			return ast, nil
		}
	}
}

// distinctOr flattens a tree of nested OR expressions into its non-OR
// operands, drops every operand that EqualsExpr reports as equal to one seen
// earlier, and rebuilds a left-deep OR chain from the remaining operands.
// When no operand was dropped it returns the input unchanged and false.
func distinctOr(in *OrExpr) (Expr, bool) {
	// Walk the nested ORs in FIFO order, collecting the non-OR operands.
	var operands []Expr
	queue := []*OrExpr{in}
	for i := 0; i < len(queue); i++ {
		node := queue[i]
		for _, child := range [2]Expr{node.Left, node.Right} {
			if nested, isOr := child.(*OrExpr); isOr {
				queue = append(queue, nested)
				continue
			}
			operands = append(operands, child)
		}
	}

	// Keep only the first occurrence of each operand.
	var unique []Expr
	for _, candidate := range operands {
		duplicate := false
		for _, kept := range unique {
			if EqualsExpr(kept, candidate) {
				duplicate = true
				break
			}
		}
		if !duplicate {
			unique = append(unique, candidate)
		}
	}
	if len(unique) == len(operands) {
		return in, false
	}

	// At least one duplicate was dropped, so unique holds at least one operand.
	combined := unique[0]
	for _, next := range unique[1:] {
		combined = &OrExpr{Left: combined, Right: next}
	}
	return combined, true
}
// distinctAnd flattens a tree of nested AND expressions into its non-AND
// operands, drops every operand that EqualsExpr reports as equal to one seen
// earlier, and rebuilds a left-deep AND chain from the remaining operands.
// When no operand was dropped it returns the input unchanged and false.
func distinctAnd(in *AndExpr) (Expr, bool) {
	// Walk the nested ANDs in FIFO order, collecting the non-AND operands.
	var operands []Expr
	queue := []*AndExpr{in}
	for i := 0; i < len(queue); i++ {
		node := queue[i]
		for _, child := range [2]Expr{node.Left, node.Right} {
			if nested, isAnd := child.(*AndExpr); isAnd {
				queue = append(queue, nested)
				continue
			}
			operands = append(operands, child)
		}
	}

	// Keep only the first occurrence of each operand.
	var unique []Expr
	for _, candidate := range operands {
		duplicate := false
		for _, kept := range unique {
			if EqualsExpr(kept, candidate) {
				duplicate = true
				break
			}
		}
		if !duplicate {
			unique = append(unique, candidate)
		}
	}
	if len(unique) == len(operands) {
		return in, false
	}

	// At least one duplicate was dropped, so unique holds at least one operand.
	combined := unique[0]
	for _, next := range unique[1:] {
		combined = &AndExpr{Left: combined, Right: next}
	}
	return combined, true
}

// rewriteToCNFExpr applies at most one CNF rewriting step to the root of expr.
//
// It returns the rewritten expression and true when a rule matched, or the
// input expression and false when no rule applies. Only the root node is
// inspected; RewriteToCNF calls this for every node and repeats until no more
// rewrites happen.
//
// Rules applied:
//   - NOT NOT A              => A
//   - NOT (A OR B)           => NOT A AND NOT B
//   - NOT (A AND B)          => NOT A OR NOT B
//   - (A AND B) OR A         => A (absorption, either operand order)
//   - (A AND B) OR C         => (A OR C) AND (B OR C)
//   - A XOR B                => (A OR B) AND NOT (A AND B)
//   - (A OR B) AND A         => A (absorption, either operand order)
//   - duplicate operands of an OR / AND chain are removed
func rewriteToCNFExpr(expr Expr) (Expr, bool) {
	switch expr := expr.(type) {
	case *NotExpr:
		switch child := expr.Expr.(type) {
		case *NotExpr:
			// NOT NOT A => A
			return child.Expr, true
		case *OrExpr:
			// DeMorgan Rewriter
			// NOT (A OR B) => NOT A AND NOT B
			return &AndExpr{Right: &NotExpr{Expr: child.Right}, Left: &NotExpr{Expr: child.Left}}, true
		case *AndExpr:
			// DeMorgan Rewriter
			// NOT (A AND B) => NOT A OR NOT B
			return &OrExpr{Right: &NotExpr{Expr: child.Right}, Left: &NotExpr{Expr: child.Left}}, true
		}
	case *OrExpr:
		or := expr
		if and, ok := or.Left.(*AndExpr); ok {
			// Simplification (absorption)
			// (A AND B) OR A => A
			if EqualsExpr(or.Right, and.Left) || EqualsExpr(or.Right, and.Right) {
				return or.Right, true
			}
			// Distribution Law
			// (A AND B) OR C => (A OR C) AND (B OR C)
			return &AndExpr{Left: &OrExpr{Left: and.Left, Right: or.Right}, Right: &OrExpr{Left: and.Right, Right: or.Right}}, true
		}
		if and, ok := or.Right.(*AndExpr); ok {
			// Simplification (absorption)
			// A OR (A AND B) => A
			if EqualsExpr(or.Left, and.Left) || EqualsExpr(or.Left, and.Right) {
				return or.Left, true
			}
			// Distribution Law
			// C OR (A AND B) => (C OR A) AND (C OR B)
			return &AndExpr{Left: &OrExpr{Left: or.Left, Right: and.Left}, Right: &OrExpr{Left: or.Left, Right: and.Right}}, true
		}
		// Try to make distinct
		return distinctOr(expr)

	case *XorExpr:
		// XOR elimination
		// (A XOR B) => (A OR B) AND NOT (A AND B)
		return &AndExpr{Left: &OrExpr{Left: expr.Left, Right: expr.Right}, Right: &NotExpr{Expr: &AndExpr{Left: expr.Left, Right: expr.Right}}}, true
	case *AndExpr:
		// Remove duplicated operands first; the absorption rules below are
		// only tried when the AND chain is already distinct.
		res, rewritten := distinctAnd(expr)
		if rewritten {
			return res, rewritten
		}
		and := expr
		if or, ok := and.Left.(*OrExpr); ok {
			// Simplification (absorption)
			// (A OR B) AND A => A
			if EqualsExpr(or.Left, and.Right) || EqualsExpr(or.Right, and.Right) {
				return and.Right, true
			}
		}
		if or, ok := and.Right.(*OrExpr); ok {
			// Simplification (absorption)
			// A AND (A OR B) => A
			// The surviving operand is the left side of the AND, not the left
			// side of the OR: for "A AND (B OR A)" the latter would be B.
			if EqualsExpr(or.Left, and.Left) || EqualsExpr(or.Right, and.Left) {
				return and.Left, true
			}
		}

	}
	return expr, false
}
96 changes: 96 additions & 0 deletions go/vt/sqlparser/ast_rewriting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,99 @@ func TestRewritesWithDefaultKeyspace(in *testing.T) {
})
}
}

// TestRewriteToCNF checks single-step rewrites: every input must be rewritten
// by exactly one call to rewriteToCNFExpr into the expected expression.
func TestRewriteToCNF(in *testing.T) {
	type testCase struct {
		in       string
		expected string
	}
	cases := []testCase{
		{in: "not (not A = 3)", expected: "A = 3"},
		{in: "not (A = 3 and B = 2)", expected: "not A = 3 or not B = 2"},
		{in: "not (A = 3 or B = 2)", expected: "not A = 3 and not B = 2"},
		{in: "A xor B", expected: "(A or B) and not (A and B)"},
		{in: "(A and B) or C", expected: "(A or C) and (B or C)"},
		{in: "C or (A and B)", expected: "(C or A) and (C or B)"},
		{in: "A and A", expected: "A"},
		{in: "A OR A", expected: "A"},
		{in: "A OR (A AND B)", expected: "A"},
		{in: "A OR (B AND A)", expected: "A"},
		{in: "(A AND B) OR A", expected: "A"},
		{in: "(B AND A) OR A", expected: "A"},
		{in: "(A and B) and (B and A)", expected: "A and B"},
		{in: "(A or B) and A", expected: "A"},
		{in: "A and (A or B)", expected: "A"},
	}

	for _, tc := range cases {
		in.Run(tc.in, func(t *testing.T) {
			// Wrap the predicate in a query so the parser can handle it.
			stmt, err := Parse("SELECT * FROM T WHERE " + tc.in)
			require.NoError(t, err)

			where := stmt.(*Select).Where.Expr
			rewritten, didRewrite := rewriteToCNFExpr(where)
			assert.True(t, didRewrite)
			assert.Equal(t, tc.expected, String(rewritten))
		})
	}
}

// TestFixedPointRewriteToCNF checks the result of running RewriteToCNF, which
// keeps rewriting until no more rules apply, on whole WHERE predicates.
func TestFixedPointRewriteToCNF(in *testing.T) {
	type testCase struct {
		in       string
		expected string
	}
	cases := []testCase{
		{
			in:       "A xor B",
			expected: "(A or B) and (not A or not B)",
		},
		{
			in:       "(A and B) and (B and A) and (B and A) and (A and B)",
			expected: "A and B",
		},
		{
			in:       "((A and B) OR (A and C) OR (A and D)) and E and F",
			expected: "A and ((A or B) and (B or C or A)) and ((A or D) and ((B or A or D) and (B or C or D))) and E and F",
		},
		{
			in:       "(A and B) OR (A and C)",
			expected: "A and ((B or A) and (B or C))",
		},
	}

	for _, tc := range cases {
		in.Run(tc.in, func(t *testing.T) {
			req := require.New(t)

			// Wrap the predicate in a query so the parser can handle it.
			stmt, err := Parse("SELECT * FROM T WHERE " + tc.in)
			req.NoError(err)

			where := stmt.(*Select).Where.Expr
			output, err := RewriteToCNF(where)
			req.NoError(err)
			assert.Equal(t, tc.expected, String(output))
		})
	}
}
1 change: 1 addition & 0 deletions go/vt/sqlparser/precedence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ func TestParens(t *testing.T) {
{in: "(a & b) | c", expected: "a & b | c"},
{in: "not (a=b and c=d)", expected: "not (a = b and c = d)"},
{in: "not (a=b) and c=d", expected: "not a = b and c = d"},
{in: "(not (a=b)) and c=d", expected: "not a = b and c = d"},
{in: "-(12)", expected: "-12"},
{in: "-(12 + 12)", expected: "-(12 + 12)"},
{in: "(1 > 2) and (1 = b)", expected: "1 > 2 and 1 = b"},
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ func locateFile(name string) string {
}

func BenchmarkPlanner(b *testing.B) {
filenames := []string{"from_cases.txt", "filter_cases.txt", "large_cases.txt", "aggr_cases.txt", "memory_sort_cases.txt", "select_cases.txt", "union_cases.txt", "wireup_cases.txt"}
filenames := []string{"from_cases.txt", "filter_cases.txt", "large_cases.txt", "aggr_cases.txt", "select_cases.txt", "union_cases.txt"}
vschema := &vschemaWrapper{
v: loadSchema(b, "schema_test.json"),
sysVarEnabled: true,
Expand Down
Loading