diff --git a/go.sum b/go.sum index 631c908d4b..aaee9a0082 100644 --- a/go.sum +++ b/go.sum @@ -58,14 +58,6 @@ github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 h1:NfWmngMi1CYU github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72/go.mod h1:ZWUdY4iszqRQ8OcoXClkxiAVAoWoK3cq0Hvv4ddGRuM= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20240110233415-e46007d964c0 h1:P8wb4dR5krirPa0swEJbEObc/I7GaAM/01nOnuQrl0c= -github.com/dolthub/vitess v0.0.0-20240110233415-e46007d964c0/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= -github.com/dolthub/vitess v0.0.0-20240117061527-f9260279b3d3 h1:nEwq2/8gTI2jm/4APIMTrWNDDRCn8AWJjrCbH+d7CJc= -github.com/dolthub/vitess v0.0.0-20240117061527-f9260279b3d3/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= -github.com/dolthub/vitess v0.0.0-20240117195812-420942cccb48 h1:Bdsy71WXx4yvK71IFwIqQ2duL5a/y15EuKEhVN51bSE= -github.com/dolthub/vitess v0.0.0-20240117195812-420942cccb48/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= -github.com/dolthub/vitess v0.0.0-20240117220136-123ca09b8929 h1:6SExRtdwbcNPl7q09SXxtnwk+pVdhrsd0ap1DVfphEg= -github.com/dolthub/vitess v0.0.0-20240117220136-123ca09b8929/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= github.com/dolthub/vitess v0.0.0-20240117231546-55b8c7b39462 h1:So1KO202cb047yWg5X27xRso6tkSYmU0Yu96JIVsaEU= github.com/dolthub/vitess v0.0.0-20240117231546-55b8c7b39462/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= diff --git a/sql/analyzer/assign_update_join.go b/sql/analyzer/assign_update_join.go index 2aa333205e..c2184aaa05 100644 --- a/sql/analyzer/assign_update_join.go +++ b/sql/analyzer/assign_update_join.go @@ -87,7 +87,7 @@ func getTablesToBeUpdated(node sql.Node) map[string]struct{} { transform.InspectExpressions(node, func(e sql.Expression) bool { switch e := e.(type) { case *expression.SetField: - gf := e.Left.(*expression.GetField) + gf := e.LeftChild.(*expression.GetField) ret[gf.Table()] = struct{}{} return false } diff --git a/sql/analyzer/fix_exec_indexes.go b/sql/analyzer/fix_exec_indexes.go index 20c5185370..43ab41d4e1 100644 --- a/sql/analyzer/fix_exec_indexes.go +++ b/sql/analyzer/fix_exec_indexes.go @@ -343,8 +343,8 @@ func (s *idxScope) visitSelf(n sql.Node) error { } // left uses destination schema // right uses |rightSchema| - newLeft := fixExprToScope(set.Left, dstScope) - newRight := fixExprToScope(set.Right, rightScope) + newLeft := fixExprToScope(set.LeftChild, dstScope) + newRight := fixExprToScope(set.RightChild, rightScope) s.expressions = append(s.expressions, expression.NewSetField(newLeft, newRight)) } for _, c := range n.Checks() { diff --git a/sql/analyzer/hoist_filters.go b/sql/analyzer/hoist_filters.go index ba1642740b..7609577c4b 100644 --- a/sql/analyzer/hoist_filters.go +++ b/sql/analyzer/hoist_filters.go @@ -85,7 +85,7 @@ func recurseSubqueryForOuterFilters(n sql.Node, a *Analyzer, corr sql.ColSet) (s var sq *plan.Subquery switch e := e.(type) { case *plan.InSubquery: - sq, _ = e.Right.(*plan.Subquery) + sq, _ = e.RightChild.(*plan.Subquery) case *plan.ExistsSubquery: sq = e.Query default: diff --git a/sql/analyzer/optimization_rules.go b/sql/analyzer/optimization_rules.go index d776efec51..a1788b3c70 100644 --- a/sql/analyzer/optimization_rules.go +++ b/sql/analyzer/optimization_rules.go @@ -244,50 +244,50 @@ func simplifyFilters(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.S expression.NewLessThanOrEqual(e.Val, e.Upper), ), transform.NewTree, nil case *expression.Or: - if isTrue(e.Left) { - return e.Left, transform.NewTree, nil + if isTrue(e.LeftChild) { + return e.LeftChild, transform.NewTree, nil } - if isTrue(e.Right) { - return e.Right, transform.NewTree, nil + if isTrue(e.RightChild) { + return e.RightChild, transform.NewTree, nil } - if isFalse(e.Left) { - return e.Right, transform.NewTree, nil + if isFalse(e.LeftChild) { + return e.RightChild, transform.NewTree, nil } - if isFalse(e.Right) { - return e.Left, transform.NewTree, nil + if isFalse(e.RightChild) { + return e.LeftChild, transform.NewTree, nil } return e, transform.SameTree, nil case *expression.And: - if isFalse(e.Left) { - return e.Left, transform.NewTree, nil + if isFalse(e.LeftChild) { + return e.LeftChild, transform.NewTree, nil } - if isFalse(e.Right) { - return e.Right, transform.NewTree, nil + if isFalse(e.RightChild) { + return e.RightChild, transform.NewTree, nil } - if isTrue(e.Left) { - return e.Right, transform.NewTree, nil + if isTrue(e.LeftChild) { + return e.RightChild, transform.NewTree, nil } - if isTrue(e.Right) { - return e.Left, transform.NewTree, nil + if isTrue(e.RightChild) { + return e.LeftChild, transform.NewTree, nil } return e, transform.SameTree, nil case *expression.Like: // if the charset is not utf8mb4, the last character used in optimization rule does not work - coll, _ := sql.GetCoercibility(ctx, e.Left) + coll, _ := sql.GetCoercibility(ctx, e.LeftChild) charset := coll.CharacterSet() if charset != sql.CharacterSet_utf8mb4 { return e, transform.SameTree, nil } // TODO: maybe more cases to simplify - r, ok := e.Right.(*expression.Literal) + r, ok := e.RightChild.(*expression.Literal) if !ok { return e, transform.SameTree, nil } @@ -310,7 +310,7 @@ func simplifyFilters(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.S // if there are also no multiple character wildcards, this is just a plain equals numWild := strings.Count(valStr, "%") - strings.Count(valStr, "\\%") if numWild == 0 { - return expression.NewEquals(e.Left, e.Right), transform.NewTree, nil + return expression.NewEquals(e.LeftChild, e.RightChild), transform.NewTree, nil } // if there are many multiple character wildcards, don't simplify if numWild != 1 { @@ -328,10 +328,10 @@ func simplifyFilters(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.S return e, transform.SameTree, nil } valStr = valStr[:len(valStr)-1] - newRightLower := expression.NewLiteral(valStr, e.Right.Type()) + newRightLower := expression.NewLiteral(valStr, e.RightChild.Type()) valStr += string(byte(255)) // append largest possible character as upper bound - newRightUpper := expression.NewLiteral(valStr, e.Right.Type()) - newExpr := expression.NewAnd(expression.NewGreaterThanOrEqual(e.Left, newRightLower), expression.NewLessThanOrEqual(e.Left, newRightUpper)) + newRightUpper := expression.NewLiteral(valStr, e.RightChild.Type()) + newExpr := expression.NewAnd(expression.NewGreaterThanOrEqual(e.LeftChild, newRightLower), expression.NewLessThanOrEqual(e.LeftChild, newRightUpper)) return newExpr, transform.NewTree, nil case *expression.Literal, expression.Tuple, *expression.Interval, *expression.CollatedExpression, *expression.MatchAgainst: return e, transform.SameTree, nil @@ -435,14 +435,14 @@ func pushNotFiltersHelper(e sql.Expression) (sql.Expression, error) { // NOT(AND(left,right))=>OR(NOT(left), NOT(right)) if not, _ := e.(*expression.Not); not != nil { if f, _ := not.Child.(*expression.And); f != nil { - return pushNotFiltersHelper(expression.NewOr(expression.NewNot(f.Left), expression.NewNot(f.Right))) + return pushNotFiltersHelper(expression.NewOr(expression.NewNot(f.LeftChild), expression.NewNot(f.RightChild))) } } // NOT(OR(left,right))=>AND(NOT(left), NOT(right)) if not, _ := e.(*expression.Not); not != nil { if f, _ := not.Child.(*expression.Or); f != nil { - return pushNotFiltersHelper(expression.NewAnd(expression.NewNot(f.Left), expression.NewNot(f.Right))) + return pushNotFiltersHelper(expression.NewAnd(expression.NewNot(f.LeftChild), expression.NewNot(f.RightChild))) } } diff --git a/sql/analyzer/triggers.go b/sql/analyzer/triggers.go index 05f006832b..e5baaa4eb1 100644 --- a/sql/analyzer/triggers.go +++ b/sql/analyzer/triggers.go @@ -69,7 +69,7 @@ func validateCreateTrigger(ctx *sql.Context, a *Analyzer, node sql.Node, scope * switch e := e.(type) { case *expression.SetField: - switch left := e.Left.(type) { + switch left := e.LeftChild.(type) { case column: if strings.ToLower(left.Table()) == "old" { err = sql.ErrInvalidUpdateOfOldRow.New() diff --git a/sql/analyzer/unnest_insubqueries.go b/sql/analyzer/unnest_insubqueries.go index bbe5838ae1..367ca8a3b2 100644 --- a/sql/analyzer/unnest_insubqueries.go +++ b/sql/analyzer/unnest_insubqueries.go @@ -97,8 +97,8 @@ func unnestInSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S var max1 bool switch e := candE.(type) { case *plan.InSubquery: - sq, _ = e.Right.(*plan.Subquery) - l = e.Left + sq, _ = e.RightChild.(*plan.Subquery) + l = e.LeftChild joinF = expression.NewEquals(nil, nil) case expression.Comparer: diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index 90244075a0..f116877e47 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -922,11 +922,11 @@ func validateExprSem(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop func validateSem(e sql.Expression) error { switch e := e.(type) { case *expression.And: - if err := logicalSem(e.BinaryExpression); err != nil { + if err := logicalSem(e.BinaryExpressionStub); err != nil { return err } case *expression.Or: - if err := logicalSem(e.BinaryExpression); err != nil { + if err := logicalSem(e.BinaryExpressionStub); err != nil { return err } default: @@ -934,11 +934,11 @@ func validateSem(e sql.Expression) error { return nil } -func logicalSem(e expression.BinaryExpression) error { - if lc := fds(e.Left); lc != 1 { +func logicalSem(e expression.BinaryExpressionStub) error { + if lc := fds(e.LeftChild); lc != 1 { return sql.ErrInvalidOperandColumns.New(1, lc) } - if rc := fds(e.Right); rc != 1 { + if rc := fds(e.RightChild); rc != 1 { return sql.ErrInvalidOperandColumns.New(1, rc) } return nil diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 6817b17437..eeaeed96b6 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -57,8 +57,7 @@ func arithmeticWarning(ctx *sql.Context, errCode int, errMsg string) { // Arithmetic expression in the future. type ArithmeticOp interface { sql.Expression - LeftChild() sql.Expression - RightChild() sql.Expression + BinaryExpression SetOpCount(int32) Operator() string } @@ -68,14 +67,14 @@ var _ sql.CollationCoercible = (*Arithmetic)(nil) // Arithmetic expressions include plus, minus and multiplication (+, -, *) operations. type Arithmetic struct { - BinaryExpression + BinaryExpressionStub Op string ops int32 } // NewArithmetic creates a new Arithmetic sql.Expression. func NewArithmetic(left, right sql.Expression, op string) *Arithmetic { - a := &Arithmetic{BinaryExpression{Left: left, Right: right}, op, 0} + a := &Arithmetic{BinaryExpressionStub{LeftChild: left, RightChild: right}, op, 0} ops := countArithmeticOps(a) setArithmeticOps(a, ops) return a @@ -96,14 +95,6 @@ func NewMult(left, right sql.Expression) *Arithmetic { return NewArithmetic(left, right, sqlparser.MultStr) } -func (a *Arithmetic) LeftChild() sql.Expression { - return a.Left -} - -func (a *Arithmetic) RightChild() sql.Expression { - return a.Right -} - func (a *Arithmetic) Operator() string { return a.Op } @@ -113,11 +104,11 @@ func (a *Arithmetic) SetOpCount(i int32) { } func (a *Arithmetic) String() string { - return fmt.Sprintf("(%s %s %s)", a.Left, a.Op, a.Right) + return fmt.Sprintf("(%s %s %s)", a.LeftChild, a.Op, a.RightChild) } func (a *Arithmetic) DebugString() string { - return fmt.Sprintf("(%s %s %s)", sql.DebugString(a.Left), a.Op, sql.DebugString(a.Right)) + return fmt.Sprintf("(%s %s %s)", sql.DebugString(a.LeftChild), a.Op, sql.DebugString(a.RightChild)) } // IsNullable implements the sql.Expression interface. @@ -126,23 +117,23 @@ func (a *Arithmetic) IsNullable() bool { return true } - return a.BinaryExpression.IsNullable() + return a.BinaryExpressionStub.IsNullable() } // Type returns the greatest type for given operation. func (a *Arithmetic) Type() sql.Type { //TODO: what if both BindVars? should be constant folded - rTyp := a.Right.Type() + rTyp := a.RightChild.Type() if types.IsDeferredType(rTyp) { return rTyp } - lTyp := a.Left.Type() + lTyp := a.LeftChild.Type() if types.IsDeferredType(lTyp) { return lTyp } // applies for + and - ops - if isInterval(a.Left) || isInterval(a.Right) { + if isInterval(a.LeftChild) || isInterval(a.RightChild) { // TODO: we might need to truncate precision here return types.DatetimeMaxPrecision } @@ -167,7 +158,7 @@ func (a *Arithmetic) Type() sql.Type { } if a.Op == sqlparser.MultStr { - return floatOrDecimalTypeForMult(a.Left, a.Right) + return floatOrDecimalTypeForMult(a.LeftChild, a.RightChild) } else { return getFloatOrMaxDecimalType(a, false) } @@ -225,25 +216,25 @@ func (a *Arithmetic) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, var lval, rval interface{} var err error - if i, ok := a.Left.(*Interval); ok { + if i, ok := a.LeftChild.(*Interval); ok { lval, err = i.EvalDelta(ctx, row) if err != nil { return nil, nil, err } } else { - lval, err = a.Left.Eval(ctx, row) + lval, err = a.LeftChild.Eval(ctx, row) if err != nil { return nil, nil, err } } - if i, ok := a.Right.(*Interval); ok { + if i, ok := a.RightChild.(*Interval); ok { rval, err = i.EvalDelta(ctx, row) if err != nil { return nil, nil, err } } else { - rval, err = a.Right.Eval(ctx, row) + rval, err = a.RightChild.Eval(ctx, row) if err != nil { return nil, nil, err } @@ -255,8 +246,8 @@ func (a *Arithmetic) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, func (a *Arithmetic) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}, error) { typ := a.Type() - lIsTimeType := types.IsTime(a.Left.Type()) - rIsTimeType := types.IsTime(a.Right.Type()) + lIsTimeType := types.IsTime(a.LeftChild.Type()) + rIsTimeType := types.IsTime(a.RightChild.Type()) if i, ok := left.(*TimeDelta); ok { left = i @@ -296,7 +287,7 @@ func countArithmeticOps(e sql.Expression) int32 { } if a, ok := e.(ArithmeticOp); ok { - return countDivs(a.LeftChild()) + 1 + return countDivs(a.Left()) + 1 } return 0 @@ -311,8 +302,8 @@ func setArithmeticOps(e sql.Expression, opScale int32) { if a, ok := e.(ArithmeticOp); ok { a.SetOpCount(opScale) - setDivs(a.LeftChild(), opScale) - setDivs(a.RightChild(), opScale) + setDivs(a.Left(), opScale) + setDivs(a.Right(), opScale) } return @@ -330,7 +321,7 @@ func isOutermostArithmeticOp(e sql.Expression, d, dScale int32) bool { if d == dScale { return true } else { - return isOutermostDiv(a.LeftChild(), d, dScale) + return isOutermostDiv(a.Left(), d, dScale) } } diff --git a/sql/expression/bit_ops.go b/sql/expression/bit_ops.go index 528994898f..ba1f24bd20 100644 --- a/sql/expression/bit_ops.go +++ b/sql/expression/bit_ops.go @@ -30,7 +30,7 @@ import ( // BitOp expressions include BIT -AND, -OR and -XOR (&, | and ^) operations // https://dev.mysql.com/doc/refman/8.0/en/bit-functions.html type BitOp struct { - BinaryExpression + BinaryExpressionStub Op string } @@ -39,7 +39,7 @@ var _ sql.CollationCoercible = (*BitOp)(nil) // NewBitOp creates a new BitOp sql.Expression. func NewBitOp(left, right sql.Expression, op string) *BitOp { - return &BitOp{BinaryExpression{Left: left, Right: right}, op} + return &BitOp{BinaryExpressionStub{LeftChild: left, RightChild: right}, op} } // NewBitAnd creates a new BitOp & sql.Expression. @@ -68,25 +68,25 @@ func NewShiftRight(left, right sql.Expression) *BitOp { } func (b *BitOp) String() string { - return fmt.Sprintf("(%s %s %s)", b.Left, b.Op, b.Right) + return fmt.Sprintf("(%s %s %s)", b.LeftChild, b.Op, b.RightChild) } func (b *BitOp) DebugString() string { - return fmt.Sprintf("(%s %s %s)", sql.DebugString(b.Left), b.Op, sql.DebugString(b.Right)) + return fmt.Sprintf("(%s %s %s)", sql.DebugString(b.LeftChild), b.Op, sql.DebugString(b.RightChild)) } // IsNullable implements the sql.Expression interface. func (b *BitOp) IsNullable() bool { - return b.BinaryExpression.IsNullable() + return b.BinaryExpressionStub.IsNullable() } // Type returns the greatest type for given operation. func (b *BitOp) Type() sql.Type { - rTyp := b.Right.Type() + rTyp := b.RightChild.Type() if types.IsDeferredType(rTyp) { return rTyp } - lTyp := b.Left.Type() + lTyp := b.LeftChild.Type() if types.IsDeferredType(lTyp) { return lTyp } @@ -154,12 +154,12 @@ func (b *BitOp) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, inter var err error // bit ops used with Interval error is caught at parsing the query - lval, err = b.Left.Eval(ctx, row) + lval, err = b.LeftChild.Eval(ctx, row) if err != nil { return nil, nil, err } - rval, err = b.Right.Eval(ctx, row) + rval, err = b.RightChild.Eval(ctx, row) if err != nil { return nil, nil, err } @@ -170,8 +170,8 @@ func (b *BitOp) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, inter func (b *BitOp) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}, error) { typ := b.Type() - left = convertValueToType(ctx, typ, left, types.IsTime(b.Left.Type())) - right = convertValueToType(ctx, typ, right, types.IsTime(b.Right.Type())) + left = convertValueToType(ctx, typ, left, types.IsTime(b.LeftChild.Type())) + right = convertValueToType(ctx, typ, right, types.IsTime(b.RightChild.Type())) return left, right, nil } diff --git a/sql/expression/common.go b/sql/expression/common.go index 4cdbd4a80b..180cfed5fa 100644 --- a/sql/expression/common.go +++ b/sql/expression/common.go @@ -32,7 +32,7 @@ func IsBinary(e sql.Expression) bool { return len(e.Children()) == 2 } -// UnaryExpression is an expression that has only one children. +// UnaryExpression is an expression that has only one child. type UnaryExpression struct { Child sql.Expression } @@ -52,25 +52,40 @@ func (p *UnaryExpression) IsNullable() bool { return p.Child.IsNullable() } -// BinaryExpression is an expression that has two children. -type BinaryExpression struct { - Left sql.Expression - Right sql.Expression +// BinaryExpressionStub is an expression that has two children. +type BinaryExpressionStub struct { + LeftChild sql.Expression + RightChild sql.Expression +} + +// BinaryExpression is an expression that has two children +type BinaryExpression interface { + sql.Expression + Left() sql.Expression + Right() sql.Expression +} + +func (p *BinaryExpressionStub) Left() sql.Expression { + return p.LeftChild +} + +func (p *BinaryExpressionStub) Right() sql.Expression { + return p.RightChild } // Children implements the Expression interface. -func (p *BinaryExpression) Children() []sql.Expression { - return []sql.Expression{p.Left, p.Right} +func (p *BinaryExpressionStub) Children() []sql.Expression { + return []sql.Expression{p.LeftChild, p.RightChild} } // Resolved implements the Expression interface. -func (p *BinaryExpression) Resolved() bool { - return p.Left.Resolved() && p.Right.Resolved() +func (p *BinaryExpressionStub) Resolved() bool { + return p.LeftChild.Resolved() && p.RightChild.Resolved() } // IsNullable returns whether the expression can be null. -func (p *BinaryExpression) IsNullable() bool { - return p.Left.IsNullable() || p.Right.IsNullable() +func (p *BinaryExpressionStub) IsNullable() bool { + return p.LeftChild.IsNullable() || p.RightChild.IsNullable() } type NaryExpression struct { diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index f159ddb22a..1023fc43f1 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -64,11 +64,11 @@ func PreciseComparison(e sql.Expression) bool { } type comparison struct { - BinaryExpression + BinaryExpressionStub } func newComparison(left, right sql.Expression) comparison { - return comparison{BinaryExpression{left, right}} + return comparison{BinaryExpressionStub{left, right}} } // CollationCoercibility implements the interface sql.CollationCoercible. @@ -260,10 +260,10 @@ func (*comparison) Type() sql.Type { } // Left implements Comparer interface -func (c *comparison) Left() sql.Expression { return c.BinaryExpression.Left } +func (c *comparison) Left() sql.Expression { return c.BinaryExpressionStub.LeftChild } // Right implements Comparer interface -func (c *comparison) Right() sql.Expression { return c.BinaryExpression.Right } +func (c *comparison) Right() sql.Expression { return c.BinaryExpressionStub.RightChild } // Equals is a comparison that checks an expression is equal to another. type Equals struct { diff --git a/sql/expression/div.go b/sql/expression/div.go index 3a6d698667..3494f6450e 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -46,7 +46,7 @@ var _ sql.CollationCoercible = (*Div)(nil) // Div expression represents "/" arithmetic operation type Div struct { - BinaryExpression + BinaryExpressionStub ops int32 // divScale is number of continuous division operations; this value will be available of all layers divScale int32 @@ -59,7 +59,7 @@ type Div struct { // NewDiv creates a new Div / sql.Expression. func NewDiv(left, right sql.Expression) *Div { a := &Div{ - BinaryExpression: BinaryExpression{Left: left, Right: right}, + BinaryExpressionStub: BinaryExpressionStub{LeftChild: left, RightChild: right}, curIntermediatePrecisionInc: 0, } a.leftmostScale.Store(0) @@ -70,14 +70,6 @@ func NewDiv(left, right sql.Expression) *Div { return a } -func (d *Div) LeftChild() sql.Expression { - return d.Left -} - -func (d *Div) RightChild() sql.Expression { - return d.Right -} - func (d *Div) Operator() string { return sqlparser.DivStr } @@ -87,16 +79,16 @@ func (d *Div) SetOpCount(i int32) { } func (d *Div) String() string { - return fmt.Sprintf("(%s / %s)", d.Left, d.Right) + return fmt.Sprintf("(%s / %s)", d.LeftChild, d.RightChild) } func (d *Div) DebugString() string { - return fmt.Sprintf("(%s / %s)", sql.DebugString(d.Left), sql.DebugString(d.Right)) + return fmt.Sprintf("(%s / %s)", sql.DebugString(d.LeftChild), sql.DebugString(d.RightChild)) } // IsNullable implements the sql.Expression interface. func (d *Div) IsNullable() bool { - return d.BinaryExpression.IsNullable() + return d.BinaryExpressionStub.IsNullable() } // Type returns the result type for this division expression. For nested division expressions, we prefer sending @@ -181,7 +173,7 @@ func (d *Div) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interfa var err error // division used with Interval error is caught at parsing the query - lval, err = d.Left.Eval(ctx, row) + lval, err = d.LeftChild.Eval(ctx, row) if err != nil { return nil, nil, err } @@ -189,7 +181,7 @@ func (d *Div) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interfa // this operation is only done on the left value as the scale/fraction part of the leftmost value // is used to calculate the scale of the final result. If the value is GetField of decimal type column // the decimal value evaluated does not always match the scale of column type definition - if dt, ok := d.Left.Type().(sql.DecimalType); ok { + if dt, ok := d.LeftChild.Type().(sql.DecimalType); ok { if dVal, ok := lval.(decimal.Decimal); ok { ts := int32(dt.Scale()) if ts > dVal.Exponent()*-1 { @@ -201,7 +193,7 @@ func (d *Div) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interfa } } - rval, err = d.Right.Eval(ctx, row) + rval, err = d.RightChild.Eval(ctx, row) if err != nil { return nil, nil, err } @@ -218,8 +210,8 @@ func (d *Div) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interfa // should be preserved. func (d *Div) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}) { typ := d.internalType() - lIsTimeType := types.IsTime(d.Left.Type()) - rIsTimeType := types.IsTime(d.Right.Type()) + lIsTimeType := types.IsTime(d.LeftChild.Type()) + rIsTimeType := types.IsTime(d.RightChild.Type()) if types.IsFloat(typ) { left = convertValueToType(ctx, typ, left, lIsTimeType) @@ -303,11 +295,11 @@ func (d *Div) div(ctx *sql.Context, lval, rval interface{}) (interface{}, error) // in order to match MySQL's behavior. func (d *Div) determineResultType(outermostResult bool) sql.Type { //TODO: what if both BindVars? should be constant folded - rTyp := d.Right.Type() + rTyp := d.RightChild.Type() if types.IsDeferredType(rTyp) { return rTyp } - lTyp := d.Left.Type() + lTyp := d.LeftChild.Type() if types.IsDeferredType(lTyp) { return lTyp } @@ -482,11 +474,11 @@ func countDivs(e sql.Expression) int32 { } if a, ok := e.(*Div); ok { - return countDivs(a.Left) + 1 + return countDivs(a.LeftChild) + 1 } if a, ok := e.(ArithmeticOp); ok { - return countDivs(a.LeftChild()) + return countDivs(a.Left()) } return 0 @@ -502,13 +494,13 @@ func setDivs(e sql.Expression, dScale int32) { if a, ok := e.(*Div); ok { a.divScale = dScale - setDivs(a.Left, dScale) - setDivs(a.Right, dScale) + setDivs(a.LeftChild, dScale) + setDivs(a.RightChild, dScale) } if a, ok := e.(ArithmeticOp); ok { - setDivs(a.LeftChild(), dScale) - setDivs(a.RightChild(), dScale) + setDivs(a.Left(), dScale) + setDivs(a.Right(), dScale) } return @@ -524,14 +516,14 @@ func getScaleOfLeftmostValue(ctx *sql.Context, row sql.Row, e sql.Expression, d, if a, ok := e.(*Div); ok { d = d + 1 if d == dScale { - lval, err := a.Left.Eval(ctx, row) + lval, err := a.LeftChild.Eval(ctx, row) if err != nil { return 0 } _, s := GetPrecisionAndScale(lval) // the leftmost value can be row value of decimal type column // the evaluated value does not always match the scale of column type definition - typ := a.Left.Type() + typ := a.LeftChild.Type() if dt, dok := typ.(sql.DecimalType); dok { ts := dt.Scale() if ts > s { @@ -540,7 +532,7 @@ func getScaleOfLeftmostValue(ctx *sql.Context, row sql.Row, e sql.Expression, d, } return int32(s) } else { - return getScaleOfLeftmostValue(ctx, row, a.Left, d, dScale) + return getScaleOfLeftmostValue(ctx, row, a.LeftChild, d, dScale) } } @@ -568,10 +560,10 @@ func isOutermostDiv(e sql.Expression, d, dScale int32) bool { if d == dScale { return true } else { - return isOutermostDiv(a.Left, d, dScale) + return isOutermostDiv(a.LeftChild, d, dScale) } } else if a, ok := e.(ArithmeticOp); ok { - return isOutermostDiv(a.LeftChild(), d, dScale) + return isOutermostDiv(a.Left(), d, dScale) } return false @@ -637,21 +629,21 @@ func getPrecInc(e sql.Expression, cur int) int { if d.curIntermediatePrecisionInc > cur { return d.curIntermediatePrecisionInc } - l := getPrecInc(d.Left, cur) + l := getPrecInc(d.LeftChild, cur) if l > cur { cur = l } - r := getPrecInc(d.Right, cur) + r := getPrecInc(d.RightChild, cur) if r > cur { cur = r } return cur } else if d, ok := e.(ArithmeticOp); ok { - l := getPrecInc(d.LeftChild(), cur) + l := getPrecInc(d.Left(), cur) if l > cur { cur = l } - r := getPrecInc(d.RightChild(), cur) + r := getPrecInc(d.Right(), cur) if r > cur { cur = r } @@ -666,26 +658,18 @@ var _ sql.CollationCoercible = (*IntDiv)(nil) // IntDiv expression represents integer "div" arithmetic operation type IntDiv struct { - BinaryExpression + BinaryExpressionStub ops int32 } // NewIntDiv creates a new IntDiv 'div' sql.Expression. func NewIntDiv(left, right sql.Expression) *IntDiv { - a := &IntDiv{BinaryExpression{Left: left, Right: right}, 0} + a := &IntDiv{BinaryExpressionStub{LeftChild: left, RightChild: right}, 0} ops := countArithmeticOps(a) setArithmeticOps(a, ops) return a } -func (i *IntDiv) LeftChild() sql.Expression { - return i.Left -} - -func (i *IntDiv) RightChild() sql.Expression { - return i.Right -} - func (i *IntDiv) Operator() string { return sqlparser.IntDivStr } @@ -695,22 +679,22 @@ func (i *IntDiv) SetOpCount(i2 int32) { } func (i *IntDiv) String() string { - return fmt.Sprintf("(%s div %s)", i.Left, i.Right) + return fmt.Sprintf("(%s div %s)", i.LeftChild, i.RightChild) } func (i *IntDiv) DebugString() string { - return fmt.Sprintf("(%s div %s)", sql.DebugString(i.Left), sql.DebugString(i.Right)) + return fmt.Sprintf("(%s div %s)", sql.DebugString(i.LeftChild), sql.DebugString(i.RightChild)) } // IsNullable implements the sql.Expression interface. func (i *IntDiv) IsNullable() bool { - return i.BinaryExpression.IsNullable() + return i.BinaryExpressionStub.IsNullable() } // Type returns the greatest type for given operation. func (i *IntDiv) Type() sql.Type { - lTyp := i.Left.Type() - rTyp := i.Right.Type() + lTyp := i.LeftChild.Type() + rTyp := i.RightChild.Type() if types.IsUnsigned(lTyp) || types.IsUnsigned(rTyp) { return types.Uint64 @@ -753,12 +737,12 @@ func (i *IntDiv) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, inte var err error // int division used with Interval error is caught at parsing the query - lval, err = i.Left.Eval(ctx, row) + lval, err = i.LeftChild.Eval(ctx, row) if err != nil { return nil, nil, err } - rval, err = i.Right.Eval(ctx, row) + rval, err = i.RightChild.Eval(ctx, row) if err != nil { return nil, nil, err } @@ -774,7 +758,7 @@ func (i *IntDiv) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, inte // should be preserved. func (i *IntDiv) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}) { var typ sql.Type - lTyp, rTyp := i.Left.Type(), i.Right.Type() + lTyp, rTyp := i.LeftChild.Type(), i.RightChild.Type() lIsTimeType := types.IsTime(lTyp) rIsTimeType := types.IsTime(rTyp) diff --git a/sql/expression/function/ceil_round_floor.go b/sql/expression/function/ceil_round_floor.go index 8114bd00c6..2d5a18c6f7 100644 --- a/sql/expression/function/ceil_round_floor.go +++ b/sql/expression/function/ceil_round_floor.go @@ -198,7 +198,7 @@ func (f *Floor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // digits of it's integer part set to 0. If d is not specified or nil/null // it defaults to 0. type Round struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*Round)(nil) @@ -216,7 +216,7 @@ func NewRound(args ...sql.Expression) (sql.Expression, error) { right = args[1] } - return &Round{expression.BinaryExpression{Left: args[0], Right: right}}, nil + return &Round{expression.BinaryExpressionStub{LeftChild: args[0], RightChild: right}}, nil } // FunctionName implements sql.FunctionExpression @@ -231,16 +231,16 @@ func (r *Round) Description() string { // Children implements the Expression interface. func (r *Round) Children() []sql.Expression { - if r.Right == nil { - return []sql.Expression{r.Left} + if r.RightChild == nil { + return []sql.Expression{r.LeftChild} } - return r.BinaryExpression.Children() + return r.BinaryExpressionStub.Children() } // Eval implements the Expression interface. func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - val, err := r.Left.Eval(ctx, row) + val, err := r.LeftChild.Eval(ctx, row) if err != nil { return nil, err } @@ -257,9 +257,9 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } prec := int32(0) - if r.Right != nil { + if r.RightChild != nil { var tmp interface{} - tmp, err = r.Right.Eval(ctx, row) + tmp, err = r.RightChild.Eval(ctx, row) if err != nil { return nil, err } @@ -288,15 +288,15 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { var res interface{} tmp := val.(decimal.Decimal).Round(prec) - if types.IsSigned(r.Left.Type()) { + if types.IsSigned(r.LeftChild.Type()) { res, _, err = types.Int64.Convert(tmp) - } else if types.IsUnsigned(r.Left.Type()) { + } else if types.IsUnsigned(r.LeftChild.Type()) { res, _, err = types.Uint64.Convert(tmp) - } else if types.IsFloat(r.Left.Type()) { + } else if types.IsFloat(r.LeftChild.Type()) { res, _, err = types.Float64.Convert(tmp) - } else if types.IsDecimal(r.Left.Type()) { + } else if types.IsDecimal(r.LeftChild.Type()) { res = tmp - } else if types.IsTextBlob(r.Left.Type()) { + } else if types.IsTextBlob(r.LeftChild.Type()) { res, _, err = types.Float64.Convert(tmp) } @@ -305,25 +305,25 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // IsNullable implements the Expression interface. func (r *Round) IsNullable() bool { - return r.Left.IsNullable() + return r.LeftChild.IsNullable() } func (r *Round) String() string { - if r.Right == nil { - return fmt.Sprintf("%s(%s,0)", r.FunctionName(), r.Left.String()) + if r.RightChild == nil { + return fmt.Sprintf("%s(%s,0)", r.FunctionName(), r.LeftChild.String()) } - return fmt.Sprintf("%s(%s,%s)", r.FunctionName(), r.Left.String(), r.Right.String()) + return fmt.Sprintf("%s(%s,%s)", r.FunctionName(), r.LeftChild.String(), r.RightChild.String()) } // Resolved implements the Expression interface. func (r *Round) Resolved() bool { - return r.Left.Resolved() && (r.Right == nil || r.Right.Resolved()) + return r.LeftChild.Resolved() && (r.RightChild == nil || r.RightChild.Resolved()) } // Type implements the Expression interface. func (r *Round) Type() sql.Type { - leftChildType := r.Left.Type() + leftChildType := r.LeftChild.Type() if types.IsNumber(leftChildType) { return leftChildType } diff --git a/sql/expression/function/date_format.go b/sql/expression/function/date_format.go index 3c805a931e..07eb718d18 100644 --- a/sql/expression/function/date_format.go +++ b/sql/expression/function/date_format.go @@ -251,7 +251,7 @@ func formatDate(format string, t time.Time) (string, error) { // DateFormat function returns a string representation of the date specified in the format specified type DateFormat struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*DateFormat)(nil) @@ -270,20 +270,20 @@ func (f *DateFormat) Description() string { // NewDateFormat returns a new DateFormat UDF func NewDateFormat(ex, value sql.Expression) sql.Expression { return &DateFormat{ - expression.BinaryExpression{ - Left: ex, - Right: value, + expression.BinaryExpressionStub{ + LeftChild: ex, + RightChild: value, }, } } // Eval implements the Expression interface. func (f *DateFormat) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - if f.Left == nil || f.Right == nil { + if f.LeftChild == nil || f.RightChild == nil { return nil, nil } - left, err := f.Left.Eval(ctx, row) + left, err := f.LeftChild.Eval(ctx, row) if err != nil { return nil, err } @@ -300,7 +300,7 @@ func (f *DateFormat) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { t := timeVal.(time.Time) - right, err := f.Right.Eval(ctx, row) + right, err := f.RightChild.Eval(ctx, row) if err != nil { return nil, err } @@ -330,17 +330,17 @@ func (*DateFormat) CollationCoercibility(ctx *sql.Context) (collation sql.Collat // IsNullable implements the Expression interface. func (f *DateFormat) IsNullable() bool { - if types.IsNull(f.Left) { - if types.IsNull(f.Right) { + if types.IsNull(f.LeftChild) { + if types.IsNull(f.RightChild) { return true } - return f.Right.IsNullable() + return f.RightChild.IsNullable() } - return f.Left.IsNullable() + return f.LeftChild.IsNullable() } func (f *DateFormat) String() string { - return fmt.Sprintf("date_format(%s, %s)", f.Left, f.Right) + return fmt.Sprintf("date_format(%s, %s)", f.LeftChild, f.RightChild) } // WithChildren implements the Expression interface. diff --git a/sql/expression/function/extract.go b/sql/expression/function/extract.go index 7d28f1cb99..26d587235d 100644 --- a/sql/expression/function/extract.go +++ b/sql/expression/function/extract.go @@ -25,7 +25,7 @@ import ( // Extract takes out the specified unit(s) from the time expression. type Extract struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*Extract)(nil) @@ -34,9 +34,9 @@ var _ sql.CollationCoercible = (*Extract)(nil) // NewExtract creates a new Extract expression. func NewExtract(e1, e2 sql.Expression) sql.Expression { return &Extract{ - expression.BinaryExpression{ - Left: e1, - Right: e2, + expression.BinaryExpressionStub{ + LeftChild: e1, + RightChild: e2, }, } } @@ -60,7 +60,7 @@ func (*Extract) CollationCoercibility(ctx *sql.Context) (collation sql.Collation } func (td *Extract) String() string { - return fmt.Sprintf("%s(%s from %s)", td.FunctionName(), td.Left, td.Right) + return fmt.Sprintf("%s(%s from %s)", td.FunctionName(), td.LeftChild, td.RightChild) } // WithChildren implements the Expression interface. @@ -73,11 +73,11 @@ func (td *Extract) WithChildren(children ...sql.Expression) (sql.Expression, err // Eval implements the Expression interface. func (td *Extract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - if td.Left == nil || td.Right == nil { + if td.LeftChild == nil || td.RightChild == nil { return nil, nil } - left, err := td.Left.Eval(ctx, row) + left, err := td.LeftChild.Eval(ctx, row) if err != nil { return nil, err } @@ -91,7 +91,7 @@ func (td *Extract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, fmt.Errorf("unit is not string type") } - right, err := td.Right.Eval(ctx, row) + right, err := td.RightChild.Eval(ctx, row) if err != nil { return nil, err } @@ -128,7 +128,7 @@ func (td *Extract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { case "MONTH": return int(dateTime.Month()), nil case "WEEK": - date, err := getDate(ctx, expression.UnaryExpression{Child: td.Right}, row) + date, err := getDate(ctx, expression.UnaryExpression{Child: td.RightChild}, row) if err != nil { return nil, err } diff --git a/sql/expression/function/find_in_set.go b/sql/expression/function/find_in_set.go index 80699a9355..7bc60f8cdd 100644 --- a/sql/expression/function/find_in_set.go +++ b/sql/expression/function/find_in_set.go @@ -25,7 +25,7 @@ import ( // FindInSet takes out the specified unit(s) from the time expression. type FindInSet struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*FindInSet)(nil) @@ -34,9 +34,9 @@ var _ sql.CollationCoercible = (*FindInSet)(nil) // NewFindInSet creates a new FindInSet expression. func NewFindInSet(e1, e2 sql.Expression) sql.Expression { return &FindInSet{ - expression.BinaryExpression{ - Left: e1, - Right: e2, + expression.BinaryExpressionStub{ + LeftChild: e1, + RightChild: e2, }, } } @@ -60,7 +60,7 @@ func (*FindInSet) CollationCoercibility(ctx *sql.Context) (collation sql.Collati } func (f *FindInSet) String() string { - return fmt.Sprintf("%s(%s from %s)", f.FunctionName(), f.Left, f.Right) + return fmt.Sprintf("%s(%s from %s)", f.FunctionName(), f.LeftChild, f.RightChild) } // WithChildren implements the Expression interface. @@ -73,16 +73,16 @@ func (f *FindInSet) WithChildren(children ...sql.Expression) (sql.Expression, er // Eval implements the Expression interface. func (f *FindInSet) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - if f.Left == nil || f.Right == nil { + if f.LeftChild == nil || f.RightChild == nil { return nil, nil } - left, err := f.Left.Eval(ctx, row) + left, err := f.LeftChild.Eval(ctx, row) if err != nil { return nil, err } - right, err := f.Right.Eval(ctx, row) + right, err := f.RightChild.Eval(ctx, row) if err != nil { return nil, err } @@ -103,7 +103,7 @@ func (f *FindInSet) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } var r string - rType := f.Right.Type() + rType := f.RightChild.Type() if setType, ok := rType.(types.SetType); ok { // TODO: set type should take advantage of bit arithmetic r, err = setType.BitsToString(right.(uint64)) @@ -124,8 +124,8 @@ func (f *FindInSet) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { r = rVal.(string) } - leftColl, leftCoer := sql.GetCoercibility(ctx, f.Left) - rightColl, rightCoer := sql.GetCoercibility(ctx, f.Right) + leftColl, leftCoer := sql.GetCoercibility(ctx, f.LeftChild) + rightColl, rightCoer := sql.GetCoercibility(ctx, f.RightChild) collPref, _ := sql.ResolveCoercibility(leftColl, leftCoer, rightColl, rightCoer) strType := types.CreateLongText(collPref) diff --git a/sql/expression/function/hash.go b/sql/expression/function/hash.go index 0be1305cea..03e5ee9cb0 100644 --- a/sql/expression/function/hash.go +++ b/sql/expression/function/hash.go @@ -142,7 +142,7 @@ func (f *SHA1) WithChildren(children ...sql.Expression) (sql.Expression, error) // SHA2 function returns the SHA-224/256/384/512 hash of the input. // https://dev.mysql.com/doc/refman/8.0/en/encryption-functions.html#function_sha2 type SHA2 struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*SHA2)(nil) @@ -150,7 +150,7 @@ var _ sql.CollationCoercible = (*SHA2)(nil) // NewSHA2 returns a new SHA2 function expression func NewSHA2(arg, count sql.Expression) sql.Expression { - return &SHA2{expression.BinaryExpression{Left: arg, Right: count}} + return &SHA2{expression.BinaryExpressionStub{LeftChild: arg, RightChild: count}} } // Description implements sql.FunctionExpression @@ -165,14 +165,14 @@ func (*SHA2) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, // Eval implements sql.Expression func (f *SHA2) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - arg, err := f.Left.Eval(ctx, row) + arg, err := f.LeftChild.Eval(ctx, row) if err != nil { return nil, err } if arg == nil { return nil, nil } - countArg, err := f.Right.Eval(ctx, row) + countArg, err := f.RightChild.Eval(ctx, row) if err != nil { return nil, err } @@ -217,7 +217,7 @@ func (f *SHA2) FunctionName() string { // String implements sql.Expression func (f *SHA2) String() string { - return fmt.Sprintf("%s(%s,%s)", f.FunctionName(), f.Left, f.Right) + return fmt.Sprintf("%s(%s,%s)", f.FunctionName(), f.LeftChild, f.RightChild) } // Type implements sql.Expression diff --git a/sql/expression/function/ifnull.go b/sql/expression/function/ifnull.go index 845eff0d6e..9f5e4f8709 100644 --- a/sql/expression/function/ifnull.go +++ b/sql/expression/function/ifnull.go @@ -24,7 +24,7 @@ import ( // IfNull function returns the specified value IF the expression is NULL, otherwise return the expression. type IfNull struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*IfNull)(nil) @@ -33,9 +33,9 @@ var _ sql.CollationCoercible = (*IfNull)(nil) // NewIfNull returns a new IFNULL UDF func NewIfNull(ex, value sql.Expression) sql.Expression { return &IfNull{ - expression.BinaryExpression{ - Left: ex, - Right: value, + expression.BinaryExpressionStub{ + LeftChild: ex, + RightChild: value, }, } } @@ -52,7 +52,7 @@ func (f *IfNull) Description() string { // Eval implements the Expression interface. func (f *IfNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - left, err := f.Left.Eval(ctx, row) + left, err := f.LeftChild.Eval(ctx, row) if err != nil { return nil, err } @@ -60,7 +60,7 @@ func (f *IfNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return left, nil } - right, err := f.Right.Eval(ctx, row) + right, err := f.RightChild.Eval(ctx, row) if err != nil { return nil, err } @@ -69,39 +69,39 @@ func (f *IfNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // Type implements the Expression interface. func (f *IfNull) Type() sql.Type { - if types.IsNull(f.Left) { - if types.IsNull(f.Right) { + if types.IsNull(f.LeftChild) { + if types.IsNull(f.RightChild) { return types.Null } - return f.Right.Type() + return f.RightChild.Type() } - return f.Left.Type() + return f.LeftChild.Type() } // CollationCoercibility implements the interface sql.CollationCoercible. func (f *IfNull) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { - if types.IsNull(f.Left) { - if types.IsNull(f.Right) { + if types.IsNull(f.LeftChild) { + if types.IsNull(f.RightChild) { return sql.Collation_binary, 6 } - return sql.GetCoercibility(ctx, f.Right) + return sql.GetCoercibility(ctx, f.RightChild) } - return sql.GetCoercibility(ctx, f.Left) + return sql.GetCoercibility(ctx, f.LeftChild) } // IsNullable implements the Expression interface. func (f *IfNull) IsNullable() bool { - if types.IsNull(f.Left) { - if types.IsNull(f.Right) { + if types.IsNull(f.LeftChild) { + if types.IsNull(f.RightChild) { return true } - return f.Right.IsNullable() + return f.RightChild.IsNullable() } - return f.Left.IsNullable() + return f.LeftChild.IsNullable() } func (f *IfNull) String() string { - return fmt.Sprintf("ifnull(%s, %s)", f.Left, f.Right) + return fmt.Sprintf("ifnull(%s, %s)", f.LeftChild, f.RightChild) } // WithChildren implements the Expression interface. diff --git a/sql/expression/function/locks.go b/sql/expression/function/locks.go index eb7214e25b..dff050d9bd 100644 --- a/sql/expression/function/locks.go +++ b/sql/expression/function/locks.go @@ -269,7 +269,7 @@ func IsUsedLockFunc(ctx *sql.Context, ls *sql.LockSubsystem, lockName string) (i // GetLock is a SQL function implementing get_lock type GetLock struct { - expression.BinaryExpression + expression.BinaryExpressionStub ls *sql.LockSubsystem } @@ -279,7 +279,7 @@ var _ sql.CollationCoercible = (*GetLock)(nil) // CreateNewGetLock returns a new GetLock object func CreateNewGetLock(ls *sql.LockSubsystem) func(e1, e2 sql.Expression) sql.Expression { return func(e1, e2 sql.Expression) sql.Expression { - return &GetLock{expression.BinaryExpression{e1, e2}, ls} + return &GetLock{expression.BinaryExpressionStub{e1, e2}, ls} } } @@ -295,11 +295,11 @@ func (gl *GetLock) Description() string { // Eval implements the Expression interface. func (gl *GetLock) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - if gl.Left == nil { + if gl.LeftChild == nil { return nil, nil } - leftVal, err := gl.Left.Eval(ctx, row) + leftVal, err := gl.LeftChild.Eval(ctx, row) if err != nil { return nil, err @@ -309,11 +309,11 @@ func (gl *GetLock) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - if gl.Right == nil { + if gl.RightChild == nil { return nil, nil } - rightVal, err := gl.Right.Eval(ctx, row) + rightVal, err := gl.RightChild.Eval(ctx, row) if err != nil { return nil, err @@ -323,14 +323,14 @@ func (gl *GetLock) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - s, ok := gl.Left.Type().(sql.StringType) + s, ok := gl.LeftChild.Type().(sql.StringType) if !ok { - return nil, ErrIllegalLockNameArgType.New(gl.Left.Type().String(), gl.FunctionName()) + return nil, ErrIllegalLockNameArgType.New(gl.LeftChild.Type().String(), gl.FunctionName()) } lockName, err := types.ConvertToString(leftVal, s) if err != nil { - return nil, fmt.Errorf("%w; %s", ErrIllegalLockNameArgType.New(gl.Left.Type().String(), gl.FunctionName()), err) + return nil, fmt.Errorf("%w; %s", ErrIllegalLockNameArgType.New(gl.LeftChild.Type().String(), gl.FunctionName()), err) } timeout, _, err := types.Int64.Convert(rightVal) @@ -354,7 +354,7 @@ func (gl *GetLock) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // String implements the fmt.Stringer interface. func (gl *GetLock) String() string { - return fmt.Sprintf("get_lock(%s, %s)", gl.Left.String(), gl.Right.String()) + return fmt.Sprintf("get_lock(%s, %s)", gl.LeftChild.String(), gl.RightChild.String()) } // IsNullable implements the Expression interface. @@ -368,7 +368,7 @@ func (gl *GetLock) WithChildren(children ...sql.Expression) (sql.Expression, err return nil, sql.ErrInvalidChildrenNumber.New(gl, len(children), 1) } - return &GetLock{expression.BinaryExpression{Left: children[0], Right: children[1]}, gl.ls}, nil + return &GetLock{expression.BinaryExpressionStub{LeftChild: children[0], RightChild: children[1]}, gl.ls}, nil } // Type implements the Expression interface. diff --git a/sql/expression/function/logarithm.go b/sql/expression/function/logarithm.go index 20a154aa4f..7846ffb111 100644 --- a/sql/expression/function/logarithm.go +++ b/sql/expression/function/logarithm.go @@ -138,7 +138,7 @@ func (l *LogBase) Eval( // Log is a function that returns the natural logarithm of a value. type Log struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*Log)(nil) @@ -152,9 +152,9 @@ func NewLog(args ...sql.Expression) (sql.Expression, error) { } if argLen == 1 { - return &Log{expression.BinaryExpression{Left: expression.NewLiteral(math.E, types.Float64), Right: args[0]}}, nil + return &Log{expression.BinaryExpressionStub{LeftChild: expression.NewLiteral(math.E, types.Float64), RightChild: args[0]}}, nil } else { - return &Log{expression.BinaryExpression{Left: args[0], Right: args[1]}}, nil + return &Log{expression.BinaryExpressionStub{LeftChild: args[0], RightChild: args[1]}}, nil } } @@ -169,7 +169,7 @@ func (l *Log) Description() string { } func (l *Log) String() string { - return fmt.Sprintf("%s(%s,%s)", l.FunctionName(), l.Left, l.Right) + return fmt.Sprintf("%s(%s,%s)", l.FunctionName(), l.LeftChild, l.RightChild) } // WithChildren implements the Expression interface. @@ -179,7 +179,7 @@ func (l *Log) WithChildren(children ...sql.Expression) (sql.Expression, error) { // Children implements the Expression interface. func (l *Log) Children() []sql.Expression { - return []sql.Expression{l.Left, l.Right} + return []sql.Expression{l.LeftChild, l.RightChild} } // Type returns the resultant type of the function. @@ -194,7 +194,7 @@ func (*Log) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, // IsNullable implements the Expression interface. func (l *Log) IsNullable() bool { - return l.Left.IsNullable() || l.Right.IsNullable() + return l.LeftChild.IsNullable() || l.RightChild.IsNullable() } // Eval implements the Expression interface. @@ -202,7 +202,7 @@ func (l *Log) Eval( ctx *sql.Context, row sql.Row, ) (interface{}, error) { - left, err := l.Left.Eval(ctx, row) + left, err := l.LeftChild.Eval(ctx, row) if err != nil { return nil, err } @@ -216,7 +216,7 @@ func (l *Log) Eval( return nil, sql.ErrInvalidType.New(reflect.TypeOf(left)) } - right, err := l.Right.Eval(ctx, row) + right, err := l.RightChild.Eval(ctx, row) if err != nil { return nil, err } diff --git a/sql/expression/function/nullif.go b/sql/expression/function/nullif.go index 97386485c6..063fdef226 100644 --- a/sql/expression/function/nullif.go +++ b/sql/expression/function/nullif.go @@ -24,7 +24,7 @@ import ( // NullIf function compares two expressions and returns NULL if they are equal. Otherwise, the first expression is returned. type NullIf struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*NullIf)(nil) @@ -33,9 +33,9 @@ var _ sql.CollationCoercible = (*NullIf)(nil) // NewNullIf returns a new NULLIF UDF func NewNullIf(ex1, ex2 sql.Expression) sql.Expression { return &NullIf{ - expression.BinaryExpression{ - Left: ex1, - Right: ex2, + expression.BinaryExpressionStub{ + LeftChild: ex1, + RightChild: ex2, }, } } @@ -52,11 +52,11 @@ func (f *NullIf) Description() string { // Eval implements the Expression interface. func (f *NullIf) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - if types.IsNull(f.Left) && types.IsNull(f.Right) { + if types.IsNull(f.LeftChild) && types.IsNull(f.RightChild) { return nil, nil } - val, err := expression.NewEquals(f.Left, f.Right).Eval(ctx, row) + val, err := expression.NewEquals(f.LeftChild, f.RightChild).Eval(ctx, row) if err != nil { return nil, err } @@ -64,24 +64,24 @@ func (f *NullIf) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - return f.Left.Eval(ctx, row) + return f.LeftChild.Eval(ctx, row) } // Type implements the Expression interface. func (f *NullIf) Type() sql.Type { - if types.IsNull(f.Left) { + if types.IsNull(f.LeftChild) { return types.Null } - return f.Left.Type() + return f.LeftChild.Type() } // CollationCoercibility implements the interface sql.CollationCoercible. func (f *NullIf) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { - if types.IsNull(f.Left) { + if types.IsNull(f.LeftChild) { return sql.Collation_binary, 6 } - return sql.GetCoercibility(ctx, f.Left) + return sql.GetCoercibility(ctx, f.LeftChild) } // IsNullable implements the Expression interface. @@ -90,7 +90,7 @@ func (f *NullIf) IsNullable() bool { } func (f *NullIf) String() string { - return fmt.Sprintf("%s(%s,%s)", f.FunctionName(), f.Left, f.Right) + return fmt.Sprintf("%s(%s,%s)", f.FunctionName(), f.LeftChild, f.RightChild) } // WithChildren implements the Expression interface. diff --git a/sql/expression/function/reverse_repeat_replace.go b/sql/expression/function/reverse_repeat_replace.go index 2f04c2e5de..d23a7f0a30 100644 --- a/sql/expression/function/reverse_repeat_replace.go +++ b/sql/expression/function/reverse_repeat_replace.go @@ -101,7 +101,7 @@ var ErrNegativeRepeatCount = errors.NewKind("negative Repeat count: %v") // Repeat is a function that returns the string repeated n times. type Repeat struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*Repeat)(nil) @@ -109,7 +109,7 @@ var _ sql.CollationCoercible = (*Repeat)(nil) // NewRepeat creates a new Repeat expression. func NewRepeat(str sql.Expression, count sql.Expression) sql.Expression { - return &Repeat{expression.BinaryExpression{Left: str, Right: count}} + return &Repeat{expression.BinaryExpressionStub{LeftChild: str, RightChild: count}} } // FunctionName implements sql.FunctionExpression @@ -123,7 +123,7 @@ func (r *Repeat) Description() string { } func (r *Repeat) String() string { - return fmt.Sprintf("repeat(%s, %s)", r.Left, r.Right) + return fmt.Sprintf("repeat(%s, %s)", r.LeftChild, r.RightChild) } // Type implements the Expression interface. @@ -133,8 +133,8 @@ func (r *Repeat) Type() sql.Type { // CollationCoercibility implements the interface sql.CollationCoercible. func (r *Repeat) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { - leftCollation, leftCoercibility := sql.GetCoercibility(ctx, r.Left) - rightCollation, rightCoercibility := sql.GetCoercibility(ctx, r.Right) + leftCollation, leftCoercibility := sql.GetCoercibility(ctx, r.LeftChild) + rightCollation, rightCoercibility := sql.GetCoercibility(ctx, r.RightChild) return sql.ResolveCoercibility(leftCollation, leftCoercibility, rightCollation, rightCoercibility) } @@ -152,7 +152,7 @@ func (r *Repeat) Eval( row sql.Row, ) (interface{}, error) { //TODO: handle collations - str, err := r.Left.Eval(ctx, row) + str, err := r.LeftChild.Eval(ctx, row) if str == nil || err != nil { return nil, err } @@ -162,7 +162,7 @@ func (r *Repeat) Eval( return nil, err } - count, err := r.Right.Eval(ctx, row) + count, err := r.RightChild.Eval(ctx, row) if count == nil || err != nil { return nil, err } diff --git a/sql/expression/function/spatial/st_equals.go b/sql/expression/function/spatial/st_equals.go index f7e3e5c27d..61b4cc17ca 100644 --- a/sql/expression/function/spatial/st_equals.go +++ b/sql/expression/function/spatial/st_equals.go @@ -24,7 +24,7 @@ import ( // STEquals is a function that returns the STEquals of a LineString type STEquals struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*STEquals)(nil) @@ -32,9 +32,9 @@ var _ sql.FunctionExpression = (*STEquals)(nil) // NewSTEquals creates a new STEquals expression. func NewSTEquals(g1, g2 sql.Expression) sql.Expression { return &STEquals{ - expression.BinaryExpression{ - Left: g1, - Right: g2, + expression.BinaryExpressionStub{ + LeftChild: g1, + RightChild: g2, }, } } @@ -55,7 +55,7 @@ func (s *STEquals) Type() sql.Type { } func (s *STEquals) String() string { - return fmt.Sprintf("ST_EQUALS(%s, %s)", s.Left, s.Right) + return fmt.Sprintf("ST_EQUALS(%s, %s)", s.LeftChild, s.RightChild) } // WithChildren implements the Expression interface. @@ -74,11 +74,11 @@ func isEqual(g1 types.GeometryValue, g2 types.GeometryValue) bool { // Eval implements the sql.Expression interface. func (s *STEquals) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - geom1, err := s.Left.Eval(ctx, row) + geom1, err := s.LeftChild.Eval(ctx, row) if err != nil { return nil, err } - geom2, err := s.Right.Eval(ctx, row) + geom2, err := s.RightChild.Eval(ctx, row) if err != nil { return nil, err } diff --git a/sql/expression/function/spatial/st_intersects.go b/sql/expression/function/spatial/st_intersects.go index 3fddb0011d..44f72b6a1c 100644 --- a/sql/expression/function/spatial/st_intersects.go +++ b/sql/expression/function/spatial/st_intersects.go @@ -25,7 +25,7 @@ import ( // Intersects is a function that returns true if the two geometries intersect type Intersects struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*Intersects)(nil) @@ -34,9 +34,9 @@ var _ sql.CollationCoercible = (*Intersects)(nil) // NewIntersects creates a new Intersects expression. func NewIntersects(g1, g2 sql.Expression) sql.Expression { return &Intersects{ - expression.BinaryExpression{ - Left: g1, - Right: g2, + expression.BinaryExpressionStub{ + LeftChild: g1, + RightChild: g2, }, } } @@ -62,11 +62,11 @@ func (*Intersects) CollationCoercibility(ctx *sql.Context) (collation sql.Collat } func (i *Intersects) String() string { - return fmt.Sprintf("%s(%s,%s)", i.FunctionName(), i.Left, i.Right) + return fmt.Sprintf("%s(%s,%s)", i.FunctionName(), i.LeftChild, i.RightChild) } func (i *Intersects) DebugString() string { - return fmt.Sprintf("%s(%s,%s)", i.FunctionName(), sql.DebugString(i.Left), sql.DebugString(i.Right)) + return fmt.Sprintf("%s(%s,%s)", i.FunctionName(), sql.DebugString(i.LeftChild), sql.DebugString(i.RightChild)) } // WithChildren implements the Expression interface. @@ -355,11 +355,11 @@ func validateGeomComp(geom1, geom2 interface{}, funcName string) (types.Geometry // Eval implements the sql.Expression interface. func (i *Intersects) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - geom1, err := i.Left.Eval(ctx, row) + geom1, err := i.LeftChild.Eval(ctx, row) if err != nil { return nil, err } - geom2, err := i.Right.Eval(ctx, row) + geom2, err := i.RightChild.Eval(ctx, row) if err != nil { return nil, err } diff --git a/sql/expression/function/spatial/st_within.go b/sql/expression/function/spatial/st_within.go index ae728105bd..3b451f85ca 100644 --- a/sql/expression/function/spatial/st_within.go +++ b/sql/expression/function/spatial/st_within.go @@ -25,7 +25,7 @@ import ( // Within is a function that true if left is spatially within right type Within struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*Within)(nil) @@ -34,9 +34,9 @@ var _ sql.CollationCoercible = (*Within)(nil) // NewWithin creates a new Within expression. func NewWithin(g1, g2 sql.Expression) sql.Expression { return &Within{ - expression.BinaryExpression{ - Left: g1, - Right: g2, + expression.BinaryExpressionStub{ + LeftChild: g1, + RightChild: g2, }, } } @@ -62,7 +62,7 @@ func (*Within) CollationCoercibility(ctx *sql.Context) (collation sql.CollationI } func (w *Within) String() string { - return fmt.Sprintf("%s(%s,%s)", w.FunctionName(), w.Left, w.Right) + return fmt.Sprintf("%s(%s,%s)", w.FunctionName(), w.LeftChild, w.RightChild) } // WithChildren implements the Expression interface. @@ -226,11 +226,11 @@ func isWithin(g1, g2 types.GeometryValue) bool { // Eval implements the sql.Expression interface. func (w *Within) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - geom1, err := w.Left.Eval(ctx, row) + geom1, err := w.LeftChild.Eval(ctx, row) if err != nil { return nil, err } - geom2, err := w.Right.Eval(ctx, row) + geom2, err := w.RightChild.Eval(ctx, row) if err != nil { return nil, err } diff --git a/sql/expression/function/sqrt_power.go b/sql/expression/function/sqrt_power.go index dd252d7ad6..5ad5b3ab13 100644 --- a/sql/expression/function/sqrt_power.go +++ b/sql/expression/function/sqrt_power.go @@ -100,7 +100,7 @@ func (s *Sqrt) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // Power is a function that returns value of X raised to the power of Y. type Power struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*Power)(nil) @@ -109,9 +109,9 @@ var _ sql.CollationCoercible = (*Power)(nil) // NewPower creates a new Power expression. func NewPower(e1, e2 sql.Expression) sql.Expression { return &Power{ - expression.BinaryExpression{ - Left: e1, - Right: e2, + expression.BinaryExpressionStub{ + LeftChild: e1, + RightChild: e2, }, } } @@ -135,10 +135,10 @@ func (*Power) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID } // IsNullable implements the Expression interface. -func (p *Power) IsNullable() bool { return p.Left.IsNullable() || p.Right.IsNullable() } +func (p *Power) IsNullable() bool { return p.LeftChild.IsNullable() || p.RightChild.IsNullable() } func (p *Power) String() string { - return fmt.Sprintf("power(%s, %s)", p.Left, p.Right) + return fmt.Sprintf("power(%s, %s)", p.LeftChild, p.RightChild) } // WithChildren implements the Expression interface. @@ -151,7 +151,7 @@ func (p *Power) WithChildren(children ...sql.Expression) (sql.Expression, error) // Eval implements the Expression interface. func (p *Power) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - left, err := p.Left.Eval(ctx, row) + left, err := p.LeftChild.Eval(ctx, row) if err != nil { return nil, err } @@ -165,7 +165,7 @@ func (p *Power) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - right, err := p.Right.Eval(ctx, row) + right, err := p.RightChild.Eval(ctx, row) if err != nil { return nil, err } diff --git a/sql/expression/function/strcmp.go b/sql/expression/function/strcmp.go index a4f262ad57..b1c4586eda 100644 --- a/sql/expression/function/strcmp.go +++ b/sql/expression/function/strcmp.go @@ -24,7 +24,7 @@ import ( // StrCmp compares two strings type StrCmp struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*StrCmp)(nil) @@ -33,9 +33,9 @@ var _ sql.CollationCoercible = (*StrCmp)(nil) // NewStrCmp creates a new NewStrCmp UDF. func NewStrCmp(e1, e2 sql.Expression) sql.Expression { return &StrCmp{ - expression.BinaryExpression{ - Left: e1, - Right: e2, + expression.BinaryExpressionStub{ + LeftChild: e1, + RightChild: e2, }, } } @@ -57,13 +57,13 @@ func (s *StrCmp) Type() sql.Type { // CollationCoercibility implements the interface sql.CollationCoercible. func (s *StrCmp) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { - leftCollation, leftCoercibility := sql.GetCoercibility(ctx, s.Left) - rightCollation, rightCoercibility := sql.GetCoercibility(ctx, s.Right) + leftCollation, leftCoercibility := sql.GetCoercibility(ctx, s.LeftChild) + rightCollation, rightCoercibility := sql.GetCoercibility(ctx, s.RightChild) return sql.ResolveCoercibility(leftCollation, leftCoercibility, rightCollation, rightCoercibility) } func (s *StrCmp) String() string { - return fmt.Sprintf("%s(%s,%s)", s.FunctionName(), s.Left, s.Right) + return fmt.Sprintf("%s(%s,%s)", s.FunctionName(), s.LeftChild, s.RightChild) } // WithChildren implements the Expression interface. @@ -75,11 +75,11 @@ func (s *StrCmp) WithChildren(children ...sql.Expression) (sql.Expression, error } func (s *StrCmp) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - if s.Left == nil || s.Right == nil { + if s.LeftChild == nil || s.RightChild == nil { return nil, nil } - expr1, err := s.Left.Eval(ctx, row) + expr1, err := s.LeftChild.Eval(ctx, row) if err != nil { return nil, err } @@ -87,7 +87,7 @@ func (s *StrCmp) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - expr2, err := s.Right.Eval(ctx, row) + expr2, err := s.RightChild.Eval(ctx, row) if err != nil { return nil, err } diff --git a/sql/expression/function/time_format.go b/sql/expression/function/time_format.go index bc69b528d4..78070d96ea 100644 --- a/sql/expression/function/time_format.go +++ b/sql/expression/function/time_format.go @@ -73,7 +73,7 @@ func formatTime(format string, t time.Time) (string, error) { // TimeFormat function returns a string representation of the date specified in the format specified type TimeFormat struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*TimeFormat)(nil) @@ -92,20 +92,20 @@ func (f *TimeFormat) Description() string { // NewTimeFormat returns a new TimeFormat UDF func NewTimeFormat(ex, value sql.Expression) sql.Expression { return &TimeFormat{ - expression.BinaryExpression{ - Left: ex, - Right: value, + expression.BinaryExpressionStub{ + LeftChild: ex, + RightChild: value, }, } } // Eval implements the Expression interface. func (f *TimeFormat) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - if f.Left == nil || f.Right == nil { + if f.LeftChild == nil || f.RightChild == nil { return nil, nil } - left, err := f.Left.Eval(ctx, row) + left, err := f.LeftChild.Eval(ctx, row) if err != nil { return nil, err } @@ -119,7 +119,7 @@ func (f *TimeFormat) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - right, err := f.Right.Eval(ctx, row) + right, err := f.RightChild.Eval(ctx, row) if err != nil { return nil, err } @@ -151,17 +151,17 @@ func (*TimeFormat) CollationCoercibility(ctx *sql.Context) (collation sql.Collat // IsNullable implements the Expression interface. func (f *TimeFormat) IsNullable() bool { - if types.IsNull(f.Left) { - if types.IsNull(f.Right) { + if types.IsNull(f.LeftChild) { + if types.IsNull(f.RightChild) { return true } - return f.Right.IsNullable() + return f.RightChild.IsNullable() } - return f.Left.IsNullable() + return f.LeftChild.IsNullable() } func (f *TimeFormat) String() string { - return fmt.Sprintf("%s(%s,%s)", f.FunctionName(), f.Left, f.Right) + return fmt.Sprintf("%s(%s,%s)", f.FunctionName(), f.LeftChild, f.RightChild) } // WithChildren implements the Expression interface. diff --git a/sql/expression/function/timediff.go b/sql/expression/function/timediff.go index 67e21f751f..62b52f3ac7 100644 --- a/sql/expression/function/timediff.go +++ b/sql/expression/function/timediff.go @@ -29,7 +29,7 @@ import ( // TimeDiff subtracts the second argument from the first expressed as a time value. type TimeDiff struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*TimeDiff)(nil) @@ -38,9 +38,9 @@ var _ sql.CollationCoercible = (*TimeDiff)(nil) // NewTimeDiff creates a new NewTimeDiff expression. func NewTimeDiff(e1, e2 sql.Expression) sql.Expression { return &TimeDiff{ - expression.BinaryExpression{ - Left: e1, - Right: e2, + expression.BinaryExpressionStub{ + LeftChild: e1, + RightChild: e2, }, } } @@ -64,7 +64,7 @@ func (*TimeDiff) CollationCoercibility(ctx *sql.Context) (collation sql.Collatio } func (td *TimeDiff) String() string { - return fmt.Sprintf("%s(%s,%s)", td.FunctionName(), td.Left, td.Right) + return fmt.Sprintf("%s(%s,%s)", td.FunctionName(), td.LeftChild, td.RightChild) } // WithChildren implements the Expression interface. @@ -89,16 +89,16 @@ func convToDateOrTime(val interface{}) (interface{}, error) { // Eval implements the Expression interface. func (td *TimeDiff) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - if td.Left == nil || td.Right == nil { + if td.LeftChild == nil || td.RightChild == nil { return nil, nil } - left, err := td.Left.Eval(ctx, row) + left, err := td.LeftChild.Eval(ctx, row) if err != nil { return nil, err } - right, err := td.Right.Eval(ctx, row) + right, err := td.RightChild.Eval(ctx, row) if err != nil { return nil, err } @@ -149,7 +149,7 @@ func (td *TimeDiff) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // DateDiff returns expr1 − expr2 expressed as a value in days from one date to the other. type DateDiff struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.FunctionExpression = (*DateDiff)(nil) @@ -158,9 +158,9 @@ var _ sql.CollationCoercible = (*DateDiff)(nil) // NewDateDiff creates a new DATEDIFF() function. func NewDateDiff(expr1, expr2 sql.Expression) sql.Expression { return &DateDiff{ - expression.BinaryExpression{ - Left: expr1, - Right: expr2, + expression.BinaryExpressionStub{ + LeftChild: expr1, + RightChild: expr2, }, } } @@ -193,11 +193,11 @@ func (d *DateDiff) WithChildren(children ...sql.Expression) (sql.Expression, err // Eval implements the sql.Expression interface. func (d *DateDiff) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - if d.Left == nil || d.Right == nil { + if d.LeftChild == nil || d.RightChild == nil { return nil, nil } - expr1, err := d.Left.Eval(ctx, row) + expr1, err := d.LeftChild.Eval(ctx, row) if err != nil { return nil, err } @@ -213,7 +213,7 @@ func (d *DateDiff) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { expr1str := expr1.(time.Time).String()[:10] expr1, _, _ = types.DatetimeMaxPrecision.Convert(expr1str) - expr2, err := d.Right.Eval(ctx, row) + expr2, err := d.RightChild.Eval(ctx, row) if err != nil { return nil, err } @@ -238,7 +238,7 @@ func (d *DateDiff) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } func (d *DateDiff) String() string { - return fmt.Sprintf("DATEDIFF(%s, %s)", d.Left, d.Right) + return fmt.Sprintf("DATEDIFF(%s, %s)", d.LeftChild, d.RightChild) } // TimestampDiff returns expr1 − expr2 expressed as a value in unit specified. diff --git a/sql/expression/in.go b/sql/expression/in.go index 3c3c46ca2b..cca8cb92c9 100644 --- a/sql/expression/in.go +++ b/sql/expression/in.go @@ -24,7 +24,7 @@ import ( // InTuple is an expression that checks an expression is inside a list of expressions. type InTuple struct { - BinaryExpression + BinaryExpressionStub } // We implement Comparer because we have a Left() and a Right(), but we can't be Compare()d @@ -45,16 +45,16 @@ func (*InTuple) CollationCoercibility(ctx *sql.Context) (collation sql.Collation } func (in *InTuple) Left() sql.Expression { - return in.BinaryExpression.Left + return in.BinaryExpressionStub.LeftChild } func (in *InTuple) Right() sql.Expression { - return in.BinaryExpression.Right + return in.BinaryExpressionStub.RightChild } // NewInTuple creates an InTuple expression. func NewInTuple(left sql.Expression, right sql.Expression) *InTuple { - return &InTuple{BinaryExpression{left, right}} + return &InTuple{BinaryExpressionStub{left, right}} } // Eval implements the Expression interface. diff --git a/sql/expression/like.go b/sql/expression/like.go index d49cb1364f..4cd89020d0 100644 --- a/sql/expression/like.go +++ b/sql/expression/like.go @@ -33,7 +33,7 @@ func newDefaultLikeMatcher(likeStr string) (regex.DisposableMatcher, error) { // Like performs pattern matching against two strings. type Like struct { - BinaryExpression + BinaryExpressionStub Escape sql.Expression pool *sync.Pool once sync.Once @@ -59,11 +59,11 @@ func NewLike(left, right, escape sql.Expression) sql.Expression { }) return &Like{ - BinaryExpression: BinaryExpression{left, right}, - Escape: escape, - pool: nil, - once: sync.Once{}, - cached: cached, + BinaryExpressionStub: BinaryExpressionStub{left, right}, + Escape: escape, + pool: nil, + once: sync.Once{}, + cached: cached, } } @@ -72,8 +72,8 @@ func (l *Like) Type() sql.Type { return types.Boolean } // CollationCoercibility implements the interface sql.CollationCoercible. func (l *Like) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { - leftCollation, leftCoercibility := sql.GetCoercibility(ctx, l.Left) - rightCollation, rightCoercibility := sql.GetCoercibility(ctx, l.Right) + leftCollation, leftCoercibility := sql.GetCoercibility(ctx, l.LeftChild) + rightCollation, rightCoercibility := sql.GetCoercibility(ctx, l.RightChild) return sql.ResolveCoercibility(leftCollation, leftCoercibility, rightCollation, rightCoercibility) } @@ -82,7 +82,7 @@ func (l *Like) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { span, ctx := ctx.Span("expression.Like") defer span.End() - left, err := l.Left.Eval(ctx, row) + left, err := l.LeftChild.Eval(ctx, row) if err != nil || left == nil { return nil, err } @@ -137,7 +137,7 @@ func (l *Like) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } func (l *Like) evalRight(ctx *sql.Context, row sql.Row) (right *string, escape rune, err error) { - rightVal, err := l.Right.Eval(ctx, row) + rightVal, err := l.RightChild.Eval(ctx, row) if err != nil || rightVal == nil { return nil, 0, err } @@ -175,7 +175,7 @@ func (l *Like) evalRight(ctx *sql.Context, row sql.Row) (right *string, escape r } func (l *Like) String() string { - return fmt.Sprintf("%s LIKE %s", l.Left, l.Right) + return fmt.Sprintf("%s LIKE %s", l.LeftChild, l.RightChild) } // WithChildren implements the Expression interface. diff --git a/sql/expression/logic.go b/sql/expression/logic.go index e3385d4338..a244f55c7f 100644 --- a/sql/expression/logic.go +++ b/sql/expression/logic.go @@ -23,7 +23,7 @@ import ( // And checks whether two expressions are true. type And struct { - BinaryExpression + BinaryExpressionStub } var _ sql.Expression = (*And)(nil) @@ -31,7 +31,7 @@ var _ sql.CollationCoercible = (*And)(nil) // NewAnd creates a new And expression. func NewAnd(left, right sql.Expression) sql.Expression { - return &And{BinaryExpression{Left: left, Right: right}} + return &And{BinaryExpressionStub{LeftChild: left, RightChild: right}} } // JoinAnd joins several expressions with And. @@ -66,8 +66,8 @@ func SplitConjunction(expr sql.Expression) []sql.Expression { } return append( - SplitConjunction(and.Left), - SplitConjunction(and.Right)..., + SplitConjunction(and.LeftChild), + SplitConjunction(and.RightChild)..., ) } @@ -82,19 +82,19 @@ func SplitDisjunction(expr sql.Expression) []sql.Expression { } return append( - SplitDisjunction(and.Left), - SplitDisjunction(and.Right)..., + SplitDisjunction(and.LeftChild), + SplitDisjunction(and.RightChild)..., ) } func (a *And) String() string { - return fmt.Sprintf("(%s AND %s)", a.Left, a.Right) + return fmt.Sprintf("(%s AND %s)", a.LeftChild, a.RightChild) } func (a *And) DebugString() string { pr := sql.NewTreePrinter() _ = pr.WriteNode("AND") - children := []string{sql.DebugString(a.Left), sql.DebugString(a.Right)} + children := []string{sql.DebugString(a.LeftChild), sql.DebugString(a.RightChild)} _ = pr.WriteChildren(children...) return pr.String() } @@ -111,7 +111,7 @@ func (*And) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, // Eval implements the Expression interface. func (a *And) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - lval, err := a.Left.Eval(ctx, row) + lval, err := a.LeftChild.Eval(ctx, row) if err != nil { return nil, err } @@ -122,7 +122,7 @@ func (a *And) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } - rval, err := a.Right.Eval(ctx, row) + rval, err := a.RightChild.Eval(ctx, row) if err != nil { return nil, err } @@ -150,7 +150,7 @@ func (a *And) WithChildren(children ...sql.Expression) (sql.Expression, error) { // Or checks whether one of the two given expressions is true. type Or struct { - BinaryExpression + BinaryExpressionStub } var _ sql.Expression = (*Or)(nil) @@ -158,7 +158,7 @@ var _ sql.CollationCoercible = (*Or)(nil) // NewOr creates a new Or expression. func NewOr(left, right sql.Expression) sql.Expression { - return &Or{BinaryExpression{Left: left, Right: right}} + return &Or{BinaryExpressionStub{LeftChild: left, RightChild: right}} } // JoinOr joins several expressions with Or. @@ -183,13 +183,13 @@ func JoinOr(exprs ...sql.Expression) sql.Expression { } func (o *Or) String() string { - return fmt.Sprintf("(%s OR %s)", o.Left, o.Right) + return fmt.Sprintf("(%s OR %s)", o.LeftChild, o.RightChild) } func (o *Or) DebugString() string { pr := sql.NewTreePrinter() _ = pr.WriteNode("Or") - children := []string{sql.DebugString(o.Left), sql.DebugString(o.Right)} + children := []string{sql.DebugString(o.LeftChild), sql.DebugString(o.RightChild)} _ = pr.WriteChildren(children...) return pr.String() } @@ -206,7 +206,7 @@ func (*Or) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, c // Eval implements the Expression interface. func (o *Or) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - lval, err := o.Left.Eval(ctx, row) + lval, err := o.LeftChild.Eval(ctx, row) if err != nil { return nil, err } @@ -217,7 +217,7 @@ func (o *Or) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } - rval, err := o.Right.Eval(ctx, row) + rval, err := o.RightChild.Eval(ctx, row) if err != nil { return nil, err } @@ -247,7 +247,7 @@ func (o *Or) WithChildren(children ...sql.Expression) (sql.Expression, error) { // Xor checks whether only one of the two given expressions is true. type Xor struct { - BinaryExpression + BinaryExpressionStub } var _ sql.Expression = (*Xor)(nil) @@ -255,15 +255,15 @@ var _ sql.CollationCoercible = (*Xor)(nil) // NewXor creates a new Xor expression. func NewXor(left, right sql.Expression) sql.Expression { - return &Xor{BinaryExpression{Left: left, Right: right}} + return &Xor{BinaryExpressionStub{LeftChild: left, RightChild: right}} } func (x *Xor) String() string { - return fmt.Sprintf("(%s XOR %s)", x.Left, x.Right) + return fmt.Sprintf("(%s XOR %s)", x.LeftChild, x.RightChild) } func (x *Xor) DebugString() string { - return fmt.Sprintf("%s XOR %s", sql.DebugString(x.Left), sql.DebugString(x.Right)) + return fmt.Sprintf("%s XOR %s", sql.DebugString(x.LeftChild), sql.DebugString(x.RightChild)) } // Type implements the Expression interface. @@ -278,7 +278,7 @@ func (*Xor) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, // Eval implements the Expression interface. func (x *Xor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - lval, err := x.Left.Eval(ctx, row) + lval, err := x.LeftChild.Eval(ctx, row) if err != nil { return nil, err } @@ -290,7 +290,7 @@ func (x *Xor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - rval, err := x.Right.Eval(ctx, row) + rval, err := x.RightChild.Eval(ctx, row) if err != nil { return nil, err } diff --git a/sql/expression/mod.go b/sql/expression/mod.go index 856263c36f..9fabb6e535 100644 --- a/sql/expression/mod.go +++ b/sql/expression/mod.go @@ -30,7 +30,7 @@ var _ sql.CollationCoercible = (*Mod)(nil) // Mod expression represents "%" arithmetic operation type Mod struct { - BinaryExpression + BinaryExpressionStub ops int32 } @@ -39,7 +39,7 @@ var _ sql.CollationCoercible = (*Mod)(nil) // NewMod creates a new Mod sql.Expression. func NewMod(left, right sql.Expression) *Mod { - a := &Mod{BinaryExpression{Left: left, Right: right}, 0} + a := &Mod{BinaryExpressionStub{LeftChild: left, RightChild: right}, 0} ops := countArithmeticOps(a) setArithmeticOps(a, ops) return a @@ -53,14 +53,6 @@ func (m *Mod) Description() string { return "returns the remainder of the first argument divided by the second argument" } -func (m *Mod) LeftChild() sql.Expression { - return m.Left -} - -func (m *Mod) RightChild() sql.Expression { - return m.Right -} - func (m *Mod) Operator() string { return sqlparser.ModStr } @@ -70,26 +62,26 @@ func (m *Mod) SetOpCount(i int32) { } func (m *Mod) String() string { - return fmt.Sprintf("(%s %% %s)", m.Left, m.Right) + return fmt.Sprintf("(%s %% %s)", m.LeftChild, m.RightChild) } func (m *Mod) DebugString() string { - return fmt.Sprintf("(%s %% %s)", sql.DebugString(m.Left), sql.DebugString(m.Right)) + return fmt.Sprintf("(%s %% %s)", sql.DebugString(m.LeftChild), sql.DebugString(m.RightChild)) } // IsNullable implements the sql.Expression interface. func (m *Mod) IsNullable() bool { - return m.BinaryExpression.IsNullable() + return m.BinaryExpressionStub.IsNullable() } // Type returns the greatest type for given operation. func (m *Mod) Type() sql.Type { //TODO: what if both BindVars? should be constant folded - rTyp := m.Right.Type() + rTyp := m.RightChild.Type() if types.IsDeferredType(rTyp) { return rTyp } - lTyp := m.Left.Type() + lTyp := m.LeftChild.Type() if types.IsDeferredType(lTyp) { return lTyp } @@ -137,12 +129,12 @@ func (m *Mod) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interfa var err error // mod used with Interval error is caught at parsing the query - lval, err = m.Left.Eval(ctx, row) + lval, err = m.LeftChild.Eval(ctx, row) if err != nil { return nil, nil, err } - rval, err = m.Right.Eval(ctx, row) + rval, err = m.RightChild.Eval(ctx, row) if err != nil { return nil, nil, err } @@ -152,8 +144,8 @@ func (m *Mod) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interfa func (m *Mod) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}) { typ := m.Type() - lIsTimeType := types.IsTime(m.Left.Type()) - rIsTimeType := types.IsTime(m.Right.Type()) + lIsTimeType := types.IsTime(m.LeftChild.Type()) + rIsTimeType := types.IsTime(m.RightChild.Type()) if types.IsFloat(typ) { left = convertValueToType(ctx, typ, left, lIsTimeType) diff --git a/sql/expression/set.go b/sql/expression/set.go index cf76eb7561..e7e72e95f2 100644 --- a/sql/expression/set.go +++ b/sql/expression/set.go @@ -27,7 +27,7 @@ var errCannotSetField = errors.NewKind("Expected GetField expression on left but // SetField updates the value of a field or a system variable type SetField struct { - BinaryExpression + BinaryExpressionStub } var _ sql.Expression = (*SetField)(nil) @@ -35,39 +35,39 @@ var _ sql.CollationCoercible = (*SetField)(nil) // NewSetField creates a new SetField expression. func NewSetField(left, expr sql.Expression) sql.Expression { - return &SetField{BinaryExpression{Left: left, Right: expr}} + return &SetField{BinaryExpressionStub{LeftChild: left, RightChild: expr}} } func (s *SetField) String() string { - return fmt.Sprintf("SET %s = %s", s.Left, s.Right) + return fmt.Sprintf("SET %s = %s", s.LeftChild, s.RightChild) } func (s *SetField) DebugString() string { - return fmt.Sprintf("SET %s = %s", sql.DebugString(s.Left), sql.DebugString(s.Right)) + return fmt.Sprintf("SET %s = %s", sql.DebugString(s.LeftChild), sql.DebugString(s.RightChild)) } // Type implements the Expression interface. func (s *SetField) Type() sql.Type { - return s.Left.Type() + return s.LeftChild.Type() } // CollationCoercibility implements the interface sql.CollationCoercible. func (s *SetField) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { - return sql.GetCoercibility(ctx, s.Left) + return sql.GetCoercibility(ctx, s.LeftChild) } // Eval implements the Expression interface. // Returns a copy of the given row with an updated value. func (s *SetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - getField, ok := s.Left.(*GetField) + getField, ok := s.LeftChild.(*GetField) if !ok { - return nil, errCannotSetField.New(s.Left) + return nil, errCannotSetField.New(s.LeftChild) } if getField.fieldIndex < 0 || getField.fieldIndex >= len(row) { return nil, ErrIndexOutOfBounds.New(getField.fieldIndex, len(row)) } - val, err := s.Right.Eval(ctx, row) + val, err := s.RightChild.Eval(ctx, row) if err != nil { return nil, err } diff --git a/sql/plan/insubquery.go b/sql/plan/insubquery.go index 57b948f109..6a9c2e30c9 100644 --- a/sql/plan/insubquery.go +++ b/sql/plan/insubquery.go @@ -26,7 +26,7 @@ import ( // instead of the expression package, because Subquery is itself in the plan package (because it functions more like a // plan node than an expression in its evaluation). type InSubquery struct { - expression.BinaryExpression + expression.BinaryExpressionStub } var _ sql.Expression = (*InSubquery)(nil) @@ -44,15 +44,15 @@ func (*InSubquery) CollationCoercibility(ctx *sql.Context) (collation sql.Collat // NewInSubquery creates an InSubquery expression. func NewInSubquery(left sql.Expression, right sql.Expression) *InSubquery { - return &InSubquery{expression.BinaryExpression{Left: left, Right: right}} + return &InSubquery{expression.BinaryExpressionStub{LeftChild: left, RightChild: right}} } var nilKey, _ = sql.HashOf(sql.NewRow(nil)) // Eval implements the Expression interface. func (in *InSubquery) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - typ := in.Left.Type().Promote() - left, err := in.Left.Eval(ctx, row) + typ := in.LeftChild.Type().Promote() + left, err := in.LeftChild.Eval(ctx, row) if err != nil { return nil, err } @@ -69,7 +69,7 @@ func (in *InSubquery) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - switch right := in.Right.(type) { + switch right := in.RightChild.(type) { case *Subquery: if types.NumColumns(typ) != types.NumColumns(right.Type()) { return nil, sql.ErrInvalidOperandColumns.New(types.NumColumns(typ), types.NumColumns(right.Type())) @@ -138,8 +138,8 @@ func (in *InSubquery) WithChildren(children ...sql.Expression) (sql.Expression, func (in *InSubquery) Describe(options sql.DescribeOptions) string { pr := sql.NewTreePrinter() _ = pr.WriteNode("InSubquery") - children := []string{fmt.Sprintf("left: %s", sql.Describe(in.Left, options)), - fmt.Sprintf("right: %s", sql.Describe(in.Right, options))} + children := []string{fmt.Sprintf("left: %s", sql.Describe(in.Left(), options)), + fmt.Sprintf("right: %s", sql.Describe(in.Right(), options))} _ = pr.WriteChildren(children...) return pr.String() } @@ -164,12 +164,12 @@ func (in *InSubquery) DebugString() string { // Children implements the Expression interface. func (in *InSubquery) Children() []sql.Expression { - return []sql.Expression{in.Left, in.Right} + return []sql.Expression{in.LeftChild, in.RightChild} } // Dispose implements sql.Disposable func (in *InSubquery) Dispose() { - if sq, ok := in.Right.(*Subquery); ok { + if sq, ok := in.RightChild.(*Subquery); ok { sq.Dispose() } } diff --git a/sql/planbuilder/dml.go b/sql/planbuilder/dml.go index d641a6a32e..69f5c0f795 100644 --- a/sql/planbuilder/dml.go +++ b/sql/planbuilder/dml.go @@ -316,7 +316,7 @@ func isColumnUpdated(col *sql.Column, updateExprs []sql.Expression) bool { if !ok { continue } - gf, ok := sf.Left.(*expression.GetField) + gf, ok := sf.LeftChild.(*expression.GetField) if !ok { continue } @@ -610,7 +610,7 @@ func getTablesToBeUpdated(node sql.Node) map[string]struct{} { transform.InspectExpressions(node, func(e sql.Expression) bool { switch e := e.(type) { case *expression.SetField: - gf := e.Left.(*expression.GetField) + gf := e.LeftChild.(*expression.GetField) ret[gf.Table()] = struct{}{} return false } diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 185c7c3564..89cd9eaaf3 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -41,7 +41,30 @@ func (b *Builder) buildWhere(inScope *scope, where *ast.Where) { inScope.node = filterNode } -func (b *Builder) buildScalar(inScope *scope, e ast.Expr) sql.Expression { +func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { + defer func() { + if !(b.bindCtx == nil || b.bindCtx.resolveOnly) { + return + } + + if be, ok := ex.(expression.BinaryExpression); ok { + left := be.Left() + right := be.Right() + if leftBindVar, ok := left.(*expression.BindVar); ok { + if typ, ok := hasColumnType(right); ok { + leftBindVar.Typ = typ + left = leftBindVar + } + } else if rightBindVar, ok := right.(*expression.BindVar); ok { + if typ, ok := hasColumnType(left); ok { + rightBindVar.Typ = typ + right = rightBindVar + } + } + ex, _ = be.WithChildren(left, right) + } + }() + switch v := e.(type) { case *ast.Default: return expression.WrapExpression(expression.NewDefaultColumn(v.ColName)) @@ -472,8 +495,6 @@ func (b *Builder) buildComparison(inScope *scope, c *ast.ComparisonExpr) sql.Exp left := b.buildScalar(inScope, c.Left) right := b.buildScalar(inScope, c.Right) - left, right = b.annotateBindvarsWithTypeInfo(c, left, right) - var escape sql.Expression = nil if c.Escape != nil { escape = b.buildScalar(inScope, c.Escape) @@ -531,26 +552,6 @@ func (b *Builder) buildComparison(inScope *scope, c *ast.ComparisonExpr) sql.Exp return nil } -// annotateBindvarsWithTypeInfo assigns the type of the column expression the bindvar on left and right, if possible. -// This only works if one side of the comparison is a bindvar and the other is a column expression. -// Otherwise, |left| and |right| are returned unchanged. -func (b *Builder) annotateBindvarsWithTypeInfo(c *ast.ComparisonExpr, left sql.Expression, right sql.Expression) (sql.Expression, sql.Expression) { - if leftBind, ok := c.Left.(*ast.SQLVal); ok && b.shouldAssignBindvarType(leftBind) { - leftBindVar := left.(*expression.BindVar) - if typ, ok := hasColumnType(right); ok { - leftBindVar.Typ = typ - left = leftBindVar - } - } else if rightBind, ok := c.Right.(*ast.SQLVal); ok && b.shouldAssignBindvarType(rightBind) { - rightBindVar := right.(*expression.BindVar) - if typ, ok := hasColumnType(left); ok { - rightBindVar.Typ = typ - right = rightBindVar - } - } - return left, right -} - func hasColumnType(e sql.Expression) (sql.Type, bool) { var typ sql.Type sql.Inspect(e, func(e sql.Expression) bool { diff --git a/sql/rowexec/dml_iters.go b/sql/rowexec/dml_iters.go index 94f8586dea..19c974e1e4 100644 --- a/sql/rowexec/dml_iters.go +++ b/sql/rowexec/dml_iters.go @@ -133,7 +133,7 @@ func shouldUseTriggerStatementForReturnRow(stmt sql.Node) bool { switch logic := n.(type) { case *plan.Set: for _, expr := range logic.Exprs { - sql.Inspect(expr.(*expression.SetField).Left, func(e sql.Expression) bool { + sql.Inspect(expr.(*expression.SetField).LeftChild, func(e sql.Expression) bool { if _, ok := e.(*expression.GetField); ok { hasSetField = true return false @@ -239,7 +239,7 @@ func shouldUseLogicResult(logic sql.Node, row sql.Row) (bool, sql.Row) { case *plan.Set: hasSetField := false for _, expr := range logic.Exprs { - sql.Inspect(expr.(*expression.SetField).Left, func(e sql.Expression) bool { + sql.Inspect(expr.(*expression.SetField).LeftChild, func(e sql.Expression) bool { if _, ok := e.(*expression.GetField); ok { hasSetField = true return false @@ -256,7 +256,7 @@ func shouldUseLogicResult(logic sql.Node, row sql.Row) (bool, sql.Row) { return true } for _, expr := range set.Exprs { - sql.Inspect(expr.(*expression.SetField).Left, func(e sql.Expression) bool { + sql.Inspect(expr.(*expression.SetField).LeftChild, func(e sql.Expression) bool { if _, ok := e.(*expression.GetField); ok { hasSetField = true return false diff --git a/sql/rowexec/insert.go b/sql/rowexec/insert.go index fe79001fce..700bb0b4b2 100644 --- a/sql/rowexec/insert.go +++ b/sql/rowexec/insert.go @@ -222,7 +222,7 @@ func getFieldIndexFromUpdateExpr(updateExpr sql.Expression) (int, bool) { return 0, false } - getField, ok := setField.Left.(*expression.GetField) + getField, ok := setField.LeftChild.(*expression.GetField) if !ok { return 0, false } diff --git a/sql/rowexec/rel.go b/sql/rowexec/rel.go index d435d48491..056e56caf7 100644 --- a/sql/rowexec/rel.go +++ b/sql/rowexec/rel.go @@ -353,23 +353,23 @@ func (b *BaseBuilder) buildSet(ctx *sql.Context, n *plan.Set, row sql.Row) (sql. return nil, fmt.Errorf("unsupported type for set: %T", v) } - switch left := setField.Left.(type) { + switch left := setField.LeftChild.(type) { case *expression.SystemVar: - err := setSystemVar(ctx, left, setField.Right, row) + err := setSystemVar(ctx, left, setField.RightChild, row) if err != nil { return nil, err } case *expression.UserVar: - err := setUserVar(ctx, left, setField.Right, row) + err := setUserVar(ctx, left, setField.RightChild, row) if err != nil { return nil, err } case *expression.ProcedureParam: - value, err := setField.Right.Eval(ctx, row) + value, err := setField.RightChild.Eval(ctx, row) if err != nil { return nil, err } - err = left.Set(value, setField.Right.Type()) + err = left.Set(value, setField.RightChild.Type()) if err != nil { return nil, err } diff --git a/sql/rowexec/rel_iters.go b/sql/rowexec/rel_iters.go index 48ca66b968..2b5392b3d9 100644 --- a/sql/rowexec/rel_iters.go +++ b/sql/rowexec/rel_iters.go @@ -571,7 +571,7 @@ func defaultValFromProjectExpr(e sql.Expression) (*sql.ColumnDefaultValue, bool) func defaultValFromSetExpression(e sql.Expression) (*sql.ColumnDefaultValue, bool) { if sf, ok := e.(*expression.SetField); ok { - return defaultValFromProjectExpr(sf.Right) + return defaultValFromProjectExpr(sf.RightChild) } return nil, false }