Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,6 @@ github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 h1:NfWmngMi1CYU
github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72/go.mod h1:ZWUdY4iszqRQ8OcoXClkxiAVAoWoK3cq0Hvv4ddGRuM=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
github.com/dolthub/vitess v0.0.0-20240110233415-e46007d964c0 h1:P8wb4dR5krirPa0swEJbEObc/I7GaAM/01nOnuQrl0c=
github.com/dolthub/vitess v0.0.0-20240110233415-e46007d964c0/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dolthub/vitess v0.0.0-20240117061527-f9260279b3d3 h1:nEwq2/8gTI2jm/4APIMTrWNDDRCn8AWJjrCbH+d7CJc=
github.com/dolthub/vitess v0.0.0-20240117061527-f9260279b3d3/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dolthub/vitess v0.0.0-20240117195812-420942cccb48 h1:Bdsy71WXx4yvK71IFwIqQ2duL5a/y15EuKEhVN51bSE=
github.com/dolthub/vitess v0.0.0-20240117195812-420942cccb48/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dolthub/vitess v0.0.0-20240117220136-123ca09b8929 h1:6SExRtdwbcNPl7q09SXxtnwk+pVdhrsd0ap1DVfphEg=
github.com/dolthub/vitess v0.0.0-20240117220136-123ca09b8929/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dolthub/vitess v0.0.0-20240117231546-55b8c7b39462 h1:So1KO202cb047yWg5X27xRso6tkSYmU0Yu96JIVsaEU=
github.com/dolthub/vitess v0.0.0-20240117231546-55b8c7b39462/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/assign_update_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func getTablesToBeUpdated(node sql.Node) map[string]struct{} {
transform.InspectExpressions(node, func(e sql.Expression) bool {
switch e := e.(type) {
case *expression.SetField:
gf := e.Left.(*expression.GetField)
gf := e.LeftChild.(*expression.GetField)
ret[gf.Table()] = struct{}{}
return false
}
Expand Down
4 changes: 2 additions & 2 deletions sql/analyzer/fix_exec_indexes.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,8 @@ func (s *idxScope) visitSelf(n sql.Node) error {
}
// left uses destination schema
// right uses |rightSchema|
newLeft := fixExprToScope(set.Left, dstScope)
newRight := fixExprToScope(set.Right, rightScope)
newLeft := fixExprToScope(set.LeftChild, dstScope)
newRight := fixExprToScope(set.RightChild, rightScope)
s.expressions = append(s.expressions, expression.NewSetField(newLeft, newRight))
}
for _, c := range n.Checks() {
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/hoist_filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func recurseSubqueryForOuterFilters(n sql.Node, a *Analyzer, corr sql.ColSet) (s
var sq *plan.Subquery
switch e := e.(type) {
case *plan.InSubquery:
sq, _ = e.Right.(*plan.Subquery)
sq, _ = e.RightChild.(*plan.Subquery)
case *plan.ExistsSubquery:
sq = e.Query
default:
Expand Down
48 changes: 24 additions & 24 deletions sql/analyzer/optimization_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,50 +244,50 @@ func simplifyFilters(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.S
expression.NewLessThanOrEqual(e.Val, e.Upper),
), transform.NewTree, nil
case *expression.Or:
if isTrue(e.Left) {
return e.Left, transform.NewTree, nil
if isTrue(e.LeftChild) {
return e.LeftChild, transform.NewTree, nil
}

if isTrue(e.Right) {
return e.Right, transform.NewTree, nil
if isTrue(e.RightChild) {
return e.RightChild, transform.NewTree, nil
}

if isFalse(e.Left) {
return e.Right, transform.NewTree, nil
if isFalse(e.LeftChild) {
return e.RightChild, transform.NewTree, nil
}

if isFalse(e.Right) {
return e.Left, transform.NewTree, nil
if isFalse(e.RightChild) {
return e.LeftChild, transform.NewTree, nil
}

return e, transform.SameTree, nil
case *expression.And:
if isFalse(e.Left) {
return e.Left, transform.NewTree, nil
if isFalse(e.LeftChild) {
return e.LeftChild, transform.NewTree, nil
}

if isFalse(e.Right) {
return e.Right, transform.NewTree, nil
if isFalse(e.RightChild) {
return e.RightChild, transform.NewTree, nil
}

if isTrue(e.Left) {
return e.Right, transform.NewTree, nil
if isTrue(e.LeftChild) {
return e.RightChild, transform.NewTree, nil
}

if isTrue(e.Right) {
return e.Left, transform.NewTree, nil
if isTrue(e.RightChild) {
return e.LeftChild, transform.NewTree, nil
}

return e, transform.SameTree, nil
case *expression.Like:
// if the charset is not utf8mb4, the last character used in optimization rule does not work
coll, _ := sql.GetCoercibility(ctx, e.Left)
coll, _ := sql.GetCoercibility(ctx, e.LeftChild)
charset := coll.CharacterSet()
if charset != sql.CharacterSet_utf8mb4 {
return e, transform.SameTree, nil
}
// TODO: maybe more cases to simplify
r, ok := e.Right.(*expression.Literal)
r, ok := e.RightChild.(*expression.Literal)
if !ok {
return e, transform.SameTree, nil
}
Expand All @@ -310,7 +310,7 @@ func simplifyFilters(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.S
// if there are also no multiple character wildcards, this is just a plain equals
numWild := strings.Count(valStr, "%") - strings.Count(valStr, "\\%")
if numWild == 0 {
return expression.NewEquals(e.Left, e.Right), transform.NewTree, nil
return expression.NewEquals(e.LeftChild, e.RightChild), transform.NewTree, nil
}
// if there are many multiple character wildcards, don't simplify
if numWild != 1 {
Expand All @@ -328,10 +328,10 @@ func simplifyFilters(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.S
return e, transform.SameTree, nil
}
valStr = valStr[:len(valStr)-1]
newRightLower := expression.NewLiteral(valStr, e.Right.Type())
newRightLower := expression.NewLiteral(valStr, e.RightChild.Type())
valStr += string(byte(255)) // append largest possible character as upper bound
newRightUpper := expression.NewLiteral(valStr, e.Right.Type())
newExpr := expression.NewAnd(expression.NewGreaterThanOrEqual(e.Left, newRightLower), expression.NewLessThanOrEqual(e.Left, newRightUpper))
newRightUpper := expression.NewLiteral(valStr, e.RightChild.Type())
newExpr := expression.NewAnd(expression.NewGreaterThanOrEqual(e.LeftChild, newRightLower), expression.NewLessThanOrEqual(e.LeftChild, newRightUpper))
return newExpr, transform.NewTree, nil
case *expression.Literal, expression.Tuple, *expression.Interval, *expression.CollatedExpression, *expression.MatchAgainst:
return e, transform.SameTree, nil
Expand Down Expand Up @@ -435,14 +435,14 @@ func pushNotFiltersHelper(e sql.Expression) (sql.Expression, error) {
// NOT(AND(left,right))=>OR(NOT(left), NOT(right))
if not, _ := e.(*expression.Not); not != nil {
if f, _ := not.Child.(*expression.And); f != nil {
return pushNotFiltersHelper(expression.NewOr(expression.NewNot(f.Left), expression.NewNot(f.Right)))
return pushNotFiltersHelper(expression.NewOr(expression.NewNot(f.LeftChild), expression.NewNot(f.RightChild)))
}
}

// NOT(OR(left,right))=>AND(NOT(left), NOT(right))
if not, _ := e.(*expression.Not); not != nil {
if f, _ := not.Child.(*expression.Or); f != nil {
return pushNotFiltersHelper(expression.NewAnd(expression.NewNot(f.Left), expression.NewNot(f.Right)))
return pushNotFiltersHelper(expression.NewAnd(expression.NewNot(f.LeftChild), expression.NewNot(f.RightChild)))
}
}

Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func validateCreateTrigger(ctx *sql.Context, a *Analyzer, node sql.Node, scope *

switch e := e.(type) {
case *expression.SetField:
switch left := e.Left.(type) {
switch left := e.LeftChild.(type) {
case column:
if strings.ToLower(left.Table()) == "old" {
err = sql.ErrInvalidUpdateOfOldRow.New()
Expand Down
4 changes: 2 additions & 2 deletions sql/analyzer/unnest_insubqueries.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ func unnestInSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S
var max1 bool
switch e := candE.(type) {
case *plan.InSubquery:
sq, _ = e.Right.(*plan.Subquery)
l = e.Left
sq, _ = e.RightChild.(*plan.Subquery)
l = e.LeftChild

joinF = expression.NewEquals(nil, nil)
case expression.Comparer:
Expand Down
10 changes: 5 additions & 5 deletions sql/analyzer/validation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -922,23 +922,23 @@ func validateExprSem(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
func validateSem(e sql.Expression) error {
switch e := e.(type) {
case *expression.And:
if err := logicalSem(e.BinaryExpression); err != nil {
if err := logicalSem(e.BinaryExpressionStub); err != nil {
return err
}
case *expression.Or:
if err := logicalSem(e.BinaryExpression); err != nil {
if err := logicalSem(e.BinaryExpressionStub); err != nil {
return err
}
default:
}
return nil
}

func logicalSem(e expression.BinaryExpression) error {
if lc := fds(e.Left); lc != 1 {
func logicalSem(e expression.BinaryExpressionStub) error {
if lc := fds(e.LeftChild); lc != 1 {
return sql.ErrInvalidOperandColumns.New(1, lc)
}
if rc := fds(e.Right); rc != 1 {
if rc := fds(e.RightChild); rc != 1 {
return sql.ErrInvalidOperandColumns.New(1, rc)
}
return nil
Expand Down
49 changes: 20 additions & 29 deletions sql/expression/arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ func arithmeticWarning(ctx *sql.Context, errCode int, errMsg string) {
// Arithmetic expression in the future.
type ArithmeticOp interface {
sql.Expression
LeftChild() sql.Expression
RightChild() sql.Expression
BinaryExpression
SetOpCount(int32)
Operator() string
}
Expand All @@ -68,14 +67,14 @@ var _ sql.CollationCoercible = (*Arithmetic)(nil)

// Arithmetic expressions include plus, minus and multiplication (+, -, *) operations.
type Arithmetic struct {
BinaryExpression
BinaryExpressionStub
Op string
ops int32
}

// NewArithmetic creates a new Arithmetic sql.Expression.
func NewArithmetic(left, right sql.Expression, op string) *Arithmetic {
a := &Arithmetic{BinaryExpression{Left: left, Right: right}, op, 0}
a := &Arithmetic{BinaryExpressionStub{LeftChild: left, RightChild: right}, op, 0}
ops := countArithmeticOps(a)
setArithmeticOps(a, ops)
return a
Expand All @@ -96,14 +95,6 @@ func NewMult(left, right sql.Expression) *Arithmetic {
return NewArithmetic(left, right, sqlparser.MultStr)
}

func (a *Arithmetic) LeftChild() sql.Expression {
return a.Left
}

func (a *Arithmetic) RightChild() sql.Expression {
return a.Right
}

func (a *Arithmetic) Operator() string {
return a.Op
}
Expand All @@ -113,11 +104,11 @@ func (a *Arithmetic) SetOpCount(i int32) {
}

func (a *Arithmetic) String() string {
return fmt.Sprintf("(%s %s %s)", a.Left, a.Op, a.Right)
return fmt.Sprintf("(%s %s %s)", a.LeftChild, a.Op, a.RightChild)
}

func (a *Arithmetic) DebugString() string {
return fmt.Sprintf("(%s %s %s)", sql.DebugString(a.Left), a.Op, sql.DebugString(a.Right))
return fmt.Sprintf("(%s %s %s)", sql.DebugString(a.LeftChild), a.Op, sql.DebugString(a.RightChild))
}

// IsNullable implements the sql.Expression interface.
Expand All @@ -126,23 +117,23 @@ func (a *Arithmetic) IsNullable() bool {
return true
}

return a.BinaryExpression.IsNullable()
return a.BinaryExpressionStub.IsNullable()
}

// Type returns the greatest type for given operation.
func (a *Arithmetic) Type() sql.Type {
//TODO: what if both BindVars? should be constant folded
rTyp := a.Right.Type()
rTyp := a.RightChild.Type()
if types.IsDeferredType(rTyp) {
return rTyp
}
lTyp := a.Left.Type()
lTyp := a.LeftChild.Type()
if types.IsDeferredType(lTyp) {
return lTyp
}

// applies for + and - ops
if isInterval(a.Left) || isInterval(a.Right) {
if isInterval(a.LeftChild) || isInterval(a.RightChild) {
// TODO: we might need to truncate precision here
return types.DatetimeMaxPrecision
}
Expand All @@ -167,7 +158,7 @@ func (a *Arithmetic) Type() sql.Type {
}

if a.Op == sqlparser.MultStr {
return floatOrDecimalTypeForMult(a.Left, a.Right)
return floatOrDecimalTypeForMult(a.LeftChild, a.RightChild)
} else {
return getFloatOrMaxDecimalType(a, false)
}
Expand Down Expand Up @@ -225,25 +216,25 @@ func (a *Arithmetic) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{},
var lval, rval interface{}
var err error

if i, ok := a.Left.(*Interval); ok {
if i, ok := a.LeftChild.(*Interval); ok {
lval, err = i.EvalDelta(ctx, row)
if err != nil {
return nil, nil, err
}
} else {
lval, err = a.Left.Eval(ctx, row)
lval, err = a.LeftChild.Eval(ctx, row)
if err != nil {
return nil, nil, err
}
}

if i, ok := a.Right.(*Interval); ok {
if i, ok := a.RightChild.(*Interval); ok {
rval, err = i.EvalDelta(ctx, row)
if err != nil {
return nil, nil, err
}
} else {
rval, err = a.Right.Eval(ctx, row)
rval, err = a.RightChild.Eval(ctx, row)
if err != nil {
return nil, nil, err
}
Expand All @@ -255,8 +246,8 @@ func (a *Arithmetic) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{},
func (a *Arithmetic) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}, error) {
typ := a.Type()

lIsTimeType := types.IsTime(a.Left.Type())
rIsTimeType := types.IsTime(a.Right.Type())
lIsTimeType := types.IsTime(a.LeftChild.Type())
rIsTimeType := types.IsTime(a.RightChild.Type())

if i, ok := left.(*TimeDelta); ok {
left = i
Expand Down Expand Up @@ -296,7 +287,7 @@ func countArithmeticOps(e sql.Expression) int32 {
}

if a, ok := e.(ArithmeticOp); ok {
return countDivs(a.LeftChild()) + 1
return countDivs(a.Left()) + 1
}

return 0
Expand All @@ -311,8 +302,8 @@ func setArithmeticOps(e sql.Expression, opScale int32) {

if a, ok := e.(ArithmeticOp); ok {
a.SetOpCount(opScale)
setDivs(a.LeftChild(), opScale)
setDivs(a.RightChild(), opScale)
setDivs(a.Left(), opScale)
setDivs(a.Right(), opScale)
}

return
Expand All @@ -330,7 +321,7 @@ func isOutermostArithmeticOp(e sql.Expression, d, dScale int32) bool {
if d == dScale {
return true
} else {
return isOutermostDiv(a.LeftChild(), d, dScale)
return isOutermostDiv(a.Left(), d, dScale)
}
}

Expand Down
Loading