diff --git a/optgen/cmd/support/agg_gen.go b/optgen/cmd/support/agg_gen.go index 4ab2c37c00..cbf97498f2 100644 --- a/optgen/cmd/support/agg_gen.go +++ b/optgen/cmd/support/agg_gen.go @@ -85,7 +85,7 @@ func (g *AggGen) genAggConstructor(define AggDef) { fmt.Fprintf(g.w, "func New%s(e sql.Expression) *%s {\n", define.Name, define.Name) fmt.Fprintf(g.w, " return &%s{\n", define.Name) fmt.Fprintf(g.w, " unaryAggBase{\n") - fmt.Fprintf(g.w, " UnaryExpression: expression.UnaryExpression{Child: e},\n") + fmt.Fprintf(g.w, " Child: e,\n") fmt.Fprintf(g.w, " functionName: \"%s\",\n", define.Name) fmt.Fprintf(g.w, " description: \"%s\",\n", define.Desc) fmt.Fprintf(g.w, " },\n") diff --git a/optgen/cmd/support/agg_gen_test.go b/optgen/cmd/support/agg_gen_test.go index ed99bd7b42..d1abb01b9f 100644 --- a/optgen/cmd/support/agg_gen_test.go +++ b/optgen/cmd/support/agg_gen_test.go @@ -41,7 +41,7 @@ func TestAggGen(t *testing.T) { func NewTest(e sql.Expression) *Test { return &Test{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, + Child: e, functionName: "Test", description: "Test description", }, diff --git a/sql/analyzer/replace_sort.go b/sql/analyzer/replace_sort.go index 1d27a4bfea..4d3d89e59a 100644 --- a/sql/analyzer/replace_sort.go +++ b/sql/analyzer/replace_sort.go @@ -373,7 +373,7 @@ func replaceAgg(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope, var sf sql.SortField switch agg := gb.SelectDeps[0].(type) { case *aggregation.Max: - gf, ok := agg.UnaryExpression.Child.(*expression.GetField) + gf, ok := agg.Child.(*expression.GetField) if !ok { return n, transform.SameTree, nil } @@ -382,7 +382,7 @@ func replaceAgg(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope, Order: sql.Descending, } case *aggregation.Min: - gf, ok := agg.UnaryExpression.Child.(*expression.GetField) + gf, ok := agg.Child.(*expression.GetField) if !ok { return n, transform.SameTree, nil } diff --git a/sql/analyzer/symbol_resolution.go b/sql/analyzer/symbol_resolution.go index 71827c3cd3..2c9ad68c24 100644 --- a/sql/analyzer/symbol_resolution.go +++ b/sql/analyzer/symbol_resolution.go @@ -178,12 +178,22 @@ func findSubqueryExpr(n sql.Node) *plan.Subquery { // hasMatchAgainstExpr searches for an *expression.MatchAgainst within the node's expressions func hasMatchAgainstExpr(node sql.Node) bool { var foundMatchAgainstExpr bool - transform.InspectExpressions(node, func(expr sql.Expression) bool { - _, isMatchAgainstExpr := expr.(*expression.MatchAgainst) - if isMatchAgainstExpr { - foundMatchAgainstExpr = true + transform.Inspect(node, func(n sql.Node) (cont bool) { + if ne, ok := n.(sql.Expressioner); ok { + for _, expr := range ne.Expressions() { + stop := transform.InspectExpr(expr, func(e sql.Expression) (stop bool) { + if _, isMatchAgainst := e.(*expression.MatchAgainst); isMatchAgainst { + foundMatchAgainstExpr = true + return true + } + return false + }) + if stop { + return false + } + } } - return !foundMatchAgainstExpr + return true }) return foundMatchAgainstExpr } diff --git a/sql/expression/alias.go b/sql/expression/alias.go index ea587555c9..1739bdb5f0 100644 --- a/sql/expression/alias.go +++ b/sql/expression/alias.go @@ -80,7 +80,7 @@ var _ sql.CollationCoercible = (*AliasReference)(nil) // Alias is a node that gives a name to an expression. type Alias struct { - UnaryExpression + UnaryExpressionStub name string unreferencable bool id sql.ColumnId @@ -92,7 +92,7 @@ var _ sql.CollationCoercible = (*Alias)(nil) // NewAlias returns a new Alias node. func NewAlias(name string, expr sql.Expression) *Alias { - return &Alias{UnaryExpression{expr}, name, false, 0} + return &Alias{UnaryExpressionStub{expr}, name, false, 0} } // AsUnreferencable marks the alias outside of scope referencing diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index dc42d6a51d..a64a94c8f3 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -681,7 +681,7 @@ func mult(lval, rval interface{}) (interface{}, error) { // UnaryMinus is an unary minus operator. type UnaryMinus struct { - UnaryExpression + UnaryExpressionStub } var _ sql.Expression = (*UnaryMinus)(nil) @@ -689,7 +689,7 @@ var _ sql.CollationCoercible = (*UnaryMinus)(nil) // NewUnaryMinus creates a new UnaryMinus expression node. func NewUnaryMinus(child sql.Expression) *UnaryMinus { - return &UnaryMinus{UnaryExpression{Child: child}} + return &UnaryMinus{UnaryExpressionStub{Child: child}} } // Eval implements the sql.Expression interface. diff --git a/sql/expression/auto_increment.go b/sql/expression/auto_increment.go index bd31a1651a..4afb947bd2 100644 --- a/sql/expression/auto_increment.go +++ b/sql/expression/auto_increment.go @@ -31,7 +31,7 @@ var ( // AutoIncrement implements AUTO_INCREMENT type AutoIncrement struct { - UnaryExpression + UnaryExpressionStub autoTbl sql.AutoIncrementTable autoCol *sql.Column } @@ -58,7 +58,7 @@ func NewAutoIncrement(ctx *sql.Context, table sql.Table, given sql.Expression) ( } return &AutoIncrement{ - UnaryExpression{Child: given}, + UnaryExpressionStub{Child: given}, autoTbl, autoCol, }, nil @@ -72,7 +72,7 @@ func NewAutoIncrementForColumn(ctx *sql.Context, table sql.Table, autoCol *sql.C } return &AutoIncrement{ - UnaryExpression{Child: given}, + UnaryExpressionStub{Child: given}, autoTbl, autoCol, }, nil @@ -159,7 +159,7 @@ func (i *AutoIncrement) WithChildren(children ...sql.Expression) (sql.Expression return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 1) } return &AutoIncrement{ - UnaryExpression{Child: children[0]}, + UnaryExpressionStub{Child: children[0]}, i.autoTbl, i.autoCol, }, nil diff --git a/sql/expression/auto_uuid.go b/sql/expression/auto_uuid.go index cce8111cb9..836f8e76aa 100644 --- a/sql/expression/auto_uuid.go +++ b/sql/expression/auto_uuid.go @@ -26,7 +26,7 @@ import ( // AutoUuid is an expression that captures an automatically generated UUID value and stores it in the session for // later retrieval. AutoUuid is intended to only be used directly on top of a UUID function. type AutoUuid struct { - UnaryExpression + UnaryExpressionStub uuidCol *sql.Column foundUuid bool } @@ -38,8 +38,8 @@ var _ sql.CollationCoercible = (*AutoUuid)(nil) // because of package import cycles, we can't enforce that directly here. func NewAutoUuid(_ *sql.Context, col *sql.Column, child sql.Expression) *AutoUuid { return &AutoUuid{ - UnaryExpression: UnaryExpression{Child: child}, - uuidCol: col, + UnaryExpressionStub: UnaryExpressionStub{Child: child}, + uuidCol: col, } } @@ -94,9 +94,9 @@ func (au *AutoUuid) WithChildren(children ...sql.Expression) (sql.Expression, er return nil, sql.ErrInvalidChildrenNumber.New(au, len(children), 1) } return &AutoUuid{ - UnaryExpression: UnaryExpression{Child: children[0]}, - uuidCol: au.uuidCol, - foundUuid: au.foundUuid, + UnaryExpressionStub: UnaryExpressionStub{Child: children[0]}, + uuidCol: au.uuidCol, + foundUuid: au.foundUuid, }, nil } diff --git a/sql/expression/binary.go b/sql/expression/binary.go index a7ae8c0507..efdd2ffea2 100644 --- a/sql/expression/binary.go +++ b/sql/expression/binary.go @@ -28,14 +28,14 @@ import ( // // cc: https://dev.mysql.com/doc/refman/8.0/en/cast-functions.html#operator_binary type Binary struct { - UnaryExpression + UnaryExpressionStub } var _ sql.Expression = (*Binary)(nil) var _ sql.CollationCoercible = (*Binary)(nil) func NewBinary(e sql.Expression) sql.Expression { - return &Binary{UnaryExpression{Child: e}} + return &Binary{UnaryExpressionStub{Child: e}} } func (b *Binary) String() string { diff --git a/sql/expression/boolean.go b/sql/expression/boolean.go index 85b8a09d42..328ab805da 100644 --- a/sql/expression/boolean.go +++ b/sql/expression/boolean.go @@ -23,7 +23,7 @@ import ( // Not is a node that negates an expression. type Not struct { - UnaryExpression + UnaryExpressionStub } var _ sql.Expression = (*Not)(nil) @@ -31,7 +31,7 @@ var _ sql.CollationCoercible = (*Not)(nil) // NewNot returns a new Not node. func NewNot(child sql.Expression) *Not { - return &Not{UnaryExpression{child}} + return &Not{UnaryExpressionStub{child}} } // Type implements the Expression interface. diff --git a/sql/expression/common.go b/sql/expression/common.go index 180cfed5fa..cae38b42f9 100644 --- a/sql/expression/common.go +++ b/sql/expression/common.go @@ -32,32 +32,36 @@ func IsBinary(e sql.Expression) bool { return len(e.Children()) == 2 } -// UnaryExpression is an expression that has only one child. -type UnaryExpression struct { +type UnaryExpression interface { + sql.Expression + UnaryChild() sql.Expression +} + +// UnaryExpressionStub is an expression that has only one child. +type UnaryExpressionStub struct { Child sql.Expression } +// UnaryChild implements the UnaryExpression interface. +func (p *UnaryExpressionStub) UnaryChild() sql.Expression { + return p.Child +} + // Children implements the Expression interface. -func (p *UnaryExpression) Children() []sql.Expression { +func (p *UnaryExpressionStub) Children() []sql.Expression { return []sql.Expression{p.Child} } // Resolved implements the Expression interface. -func (p *UnaryExpression) Resolved() bool { +func (p *UnaryExpressionStub) Resolved() bool { return p.Child.Resolved() } // IsNullable returns whether the expression can be null. -func (p *UnaryExpression) IsNullable() bool { +func (p *UnaryExpressionStub) IsNullable() bool { return p.Child.IsNullable() } -// BinaryExpressionStub is an expression that has two children. -type BinaryExpressionStub struct { - LeftChild sql.Expression - RightChild sql.Expression -} - // BinaryExpression is an expression that has two children type BinaryExpression interface { sql.Expression @@ -65,6 +69,12 @@ type BinaryExpression interface { Right() sql.Expression } +// BinaryExpressionStub is an expression that has two children. +type BinaryExpressionStub struct { + LeftChild sql.Expression + RightChild sql.Expression +} + func (p *BinaryExpressionStub) Left() sql.Expression { return p.LeftChild } diff --git a/sql/expression/convert.go b/sql/expression/convert.go index cb1a663456..413f072d58 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -64,7 +64,7 @@ const ( // Convert represent a CAST(x AS T) or CONVERT(x, T) operation that casts x expression to type T. type Convert struct { - UnaryExpression + UnaryExpressionStub // cachedDecimalType is the cached Decimal type for this convert expression. Because new Decimal types // must be created with their specific scale and precision values, unlike other types, we cache the created @@ -88,8 +88,8 @@ var _ sql.CollationCoercible = (*Convert)(nil) func NewConvert(expr sql.Expression, castToType string) *Convert { disableRounding(expr) return &Convert{ - UnaryExpression: UnaryExpression{Child: expr}, - castToType: strings.ToLower(castToType), + UnaryExpressionStub: UnaryExpressionStub{Child: expr}, + castToType: strings.ToLower(castToType), } } @@ -99,10 +99,10 @@ func NewConvert(expr sql.Expression, castToType string) *Convert { func NewConvertWithLengthAndScale(expr sql.Expression, castToType string, typeLength, typeScale int) *Convert { disableRounding(expr) return &Convert{ - UnaryExpression: UnaryExpression{Child: expr}, - castToType: strings.ToLower(castToType), - typeLength: typeLength, - typeScale: typeScale, + UnaryExpressionStub: UnaryExpressionStub{Child: expr}, + castToType: strings.ToLower(castToType), + typeLength: typeLength, + typeScale: typeScale, } } diff --git a/sql/expression/convertusing.go b/sql/expression/convertusing.go index 213c2de440..f5258fa0f2 100644 --- a/sql/expression/convertusing.go +++ b/sql/expression/convertusing.go @@ -23,7 +23,7 @@ import ( // ConvertUsing represents a CONVERT(X USING T) operation that casts the expression X to the character set T. type ConvertUsing struct { - UnaryExpression + UnaryExpressionStub TargetCharSet sql.CharacterSetID } @@ -32,8 +32,8 @@ var _ sql.CollationCoercible = (*ConvertUsing)(nil) func NewConvertUsing(expr sql.Expression, targetCharSet sql.CharacterSetID) *ConvertUsing { return &ConvertUsing{ - UnaryExpression: UnaryExpression{Child: expr}, - TargetCharSet: targetCharSet, + UnaryExpressionStub: UnaryExpressionStub{Child: expr}, + TargetCharSet: targetCharSet, } } diff --git a/sql/expression/function/absval.go b/sql/expression/function/absval.go index 1e0caa712b..c848e059e3 100644 --- a/sql/expression/function/absval.go +++ b/sql/expression/function/absval.go @@ -25,7 +25,7 @@ import ( // AbsVal is a function that takes the absolute value of a number type AbsVal struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*AbsVal)(nil) @@ -33,7 +33,7 @@ var _ sql.CollationCoercible = (*AbsVal)(nil) // NewAbsVal creates a new AbsVal expression. func NewAbsVal(e sql.Expression) sql.Expression { - return &AbsVal{expression.UnaryExpression{Child: e}} + return &AbsVal{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/aggregation/common.go b/sql/expression/function/aggregation/common.go index f9c88af7f0..2765454db2 100644 --- a/sql/expression/function/aggregation/common.go +++ b/sql/expression/function/aggregation/common.go @@ -26,7 +26,7 @@ var ErrEvalUnsupportedOnAggregation = errors.NewKind("Unimplemented %s.Eval(). T // unaryAggBase is the generic embedded class optgen // uses to codegen single expression aggregate functions. type unaryAggBase struct { - expression.UnaryExpression + Child sql.Expression typ sql.Type window *sql.WindowDefinition functionName string @@ -76,6 +76,11 @@ func (a *unaryAggBase) WithId(id sql.ColumnId) sql.IdExpression { return &ret } +// IsNullable returns whether the expression can be null. +func (a *unaryAggBase) IsNullable() bool { + return a.Child.IsNullable() +} + // CollationCoercibility implements the interface sql.CollationCoercible. func (a *unaryAggBase) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, a.Child) @@ -96,7 +101,8 @@ func (a *unaryAggBase) Children() []sql.Expression { func (a *unaryAggBase) Resolved() bool { if _, ok := a.Child.(*expression.Star); ok { return true - } else if !a.Child.Resolved() { + } + if !a.Child.Resolved() { return false } if a.window == nil { @@ -112,7 +118,7 @@ func (a *unaryAggBase) WithChildren(children ...sql.Expression) (sql.Expression, } na := *a - na.UnaryExpression = expression.UnaryExpression{Child: children[0]} + na.Child = children[0] if len(children) > 1 && a.window != nil { w, err := a.window.FromExpressions(children[1:]) if err != nil { diff --git a/sql/expression/function/aggregation/unary_aggs.og.go b/sql/expression/function/aggregation/unary_aggs.og.go index 700ad0e47d..3c4f3654f0 100644 --- a/sql/expression/function/aggregation/unary_aggs.og.go +++ b/sql/expression/function/aggregation/unary_aggs.og.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -22,9 +21,9 @@ var _ sql.WindowAdaptableExpression = (*AnyValue)(nil) func NewAnyValue(e sql.Expression) *AnyValue { return &AnyValue{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "AnyValue", - description: "returns any single value in the grouped rows", + Child: e, + functionName: "AnyValue", + description: "returns any single value in the grouped rows", }, } } @@ -101,9 +100,9 @@ var _ sql.WindowAdaptableExpression = (*Avg)(nil) func NewAvg(e sql.Expression) *Avg { return &Avg{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "Avg", - description: "returns the average value of expr in all rows.", + Child: e, + functionName: "Avg", + description: "returns the average value of expr in all rows.", }, } } @@ -180,9 +179,9 @@ var _ sql.WindowAdaptableExpression = (*BitAnd)(nil) func NewBitAnd(e sql.Expression) *BitAnd { return &BitAnd{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "BitAnd", - description: "returns the bitwise AND of all bits in expr.", + Child: e, + functionName: "BitAnd", + description: "returns the bitwise AND of all bits in expr.", }, } } @@ -259,9 +258,9 @@ var _ sql.WindowAdaptableExpression = (*BitOr)(nil) func NewBitOr(e sql.Expression) *BitOr { return &BitOr{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "BitOr", - description: "returns the bitwise OR of all bits in expr.", + Child: e, + functionName: "BitOr", + description: "returns the bitwise OR of all bits in expr.", }, } } @@ -338,9 +337,9 @@ var _ sql.WindowAdaptableExpression = (*BitXor)(nil) func NewBitXor(e sql.Expression) *BitXor { return &BitXor{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "BitXor", - description: "returns the bitwise XOR of all bits in expr.", + Child: e, + functionName: "BitXor", + description: "returns the bitwise XOR of all bits in expr.", }, } } @@ -417,9 +416,9 @@ var _ sql.WindowAdaptableExpression = (*Count)(nil) func NewCount(e sql.Expression) *Count { return &Count{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "Count", - description: "returns a count of the number of non-NULL values of expr in the rows retrieved by a SELECT statement.", + Child: e, + functionName: "Count", + description: "returns a count of the number of non-NULL values of expr in the rows retrieved by a SELECT statement.", }, } } @@ -496,9 +495,9 @@ var _ sql.WindowAdaptableExpression = (*First)(nil) func NewFirst(e sql.Expression) *First { return &First{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "First", - description: "returns the first value in a sequence of elements of an aggregation.", + Child: e, + functionName: "First", + description: "returns the first value in a sequence of elements of an aggregation.", }, } } @@ -575,9 +574,9 @@ var _ sql.WindowAdaptableExpression = (*JsonArray)(nil) func NewJsonArray(e sql.Expression) *JsonArray { return &JsonArray{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "JsonArray", - description: "returns result set as a single JSON array.", + Child: e, + functionName: "JsonArray", + description: "returns result set as a single JSON array.", }, } } @@ -654,9 +653,9 @@ var _ sql.WindowAdaptableExpression = (*Last)(nil) func NewLast(e sql.Expression) *Last { return &Last{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "Last", - description: "returns the last value in a sequence of elements of an aggregation.", + Child: e, + functionName: "Last", + description: "returns the last value in a sequence of elements of an aggregation.", }, } } @@ -733,9 +732,9 @@ var _ sql.WindowAdaptableExpression = (*Max)(nil) func NewMax(e sql.Expression) *Max { return &Max{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "Max", - description: "returns the maximum value of expr in all rows.", + Child: e, + functionName: "Max", + description: "returns the maximum value of expr in all rows.", }, } } @@ -812,9 +811,9 @@ var _ sql.WindowAdaptableExpression = (*Min)(nil) func NewMin(e sql.Expression) *Min { return &Min{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "Min", - description: "returns the minimum value of expr in all rows.", + Child: e, + functionName: "Min", + description: "returns the minimum value of expr in all rows.", }, } } @@ -891,9 +890,9 @@ var _ sql.WindowAdaptableExpression = (*Sum)(nil) func NewSum(e sql.Expression) *Sum { return &Sum{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "Sum", - description: "returns the sum of expr in all rows", + Child: e, + functionName: "Sum", + description: "returns the sum of expr in all rows", }, } } @@ -970,9 +969,9 @@ var _ sql.WindowAdaptableExpression = (*StdDevPop)(nil) func NewStdDevPop(e sql.Expression) *StdDevPop { return &StdDevPop{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "StdDevPop", - description: "returns the population standard deviation of expr", + Child: e, + functionName: "StdDevPop", + description: "returns the population standard deviation of expr", }, } } @@ -1049,9 +1048,9 @@ var _ sql.WindowAdaptableExpression = (*StdDevSamp)(nil) func NewStdDevSamp(e sql.Expression) *StdDevSamp { return &StdDevSamp{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "StdDevSamp", - description: "returns the sample standard deviation of expr", + Child: e, + functionName: "StdDevSamp", + description: "returns the sample standard deviation of expr", }, } } @@ -1128,9 +1127,9 @@ var _ sql.WindowAdaptableExpression = (*VarPop)(nil) func NewVarPop(e sql.Expression) *VarPop { return &VarPop{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "VarPop", - description: "returns the population variance of expr", + Child: e, + functionName: "VarPop", + description: "returns the population variance of expr", }, } } @@ -1207,9 +1206,9 @@ var _ sql.WindowAdaptableExpression = (*VarSamp)(nil) func NewVarSamp(e sql.Expression) *VarSamp { return &VarSamp{ unaryAggBase{ - UnaryExpression: expression.UnaryExpression{Child: e}, - functionName: "VarSamp", - description: "returns the sample variance of expr", + Child: e, + functionName: "VarSamp", + description: "returns the sample variance of expr", }, } } diff --git a/sql/expression/function/aggregation/window/first_value.go b/sql/expression/function/aggregation/window/first_value.go index efddda37c4..0aaee9306d 100644 --- a/sql/expression/function/aggregation/window/first_value.go +++ b/sql/expression/function/aggregation/window/first_value.go @@ -21,15 +21,14 @@ import ( "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation" ) type FirstValue struct { + Child sql.Expression window *sql.WindowDefinition - expression.UnaryExpression - pos int - id sql.ColumnId + pos int + id sql.ColumnId } var _ sql.FunctionExpression = (*FirstValue)(nil) @@ -38,7 +37,7 @@ var _ sql.WindowAdaptableExpression = (*FirstValue)(nil) var _ sql.CollationCoercible = (*FirstValue)(nil) func NewFirstValue(e sql.Expression) sql.Expression { - return &FirstValue{UnaryExpression: expression.UnaryExpression{Child: e}} + return &FirstValue{Child: e} } // Id implements sql.IdExpression @@ -63,7 +62,7 @@ func (f *FirstValue) Window() *sql.WindowDefinition { return f.window } -// IsNullable implements sql.Expression +// Resolved implements sql.Expression func (f *FirstValue) Resolved() bool { return windowResolved(f.window) } diff --git a/sql/expression/function/aggregation/window/last_value.go b/sql/expression/function/aggregation/window/last_value.go index 89d0161eca..2350f80933 100644 --- a/sql/expression/function/aggregation/window/last_value.go +++ b/sql/expression/function/aggregation/window/last_value.go @@ -21,15 +21,14 @@ import ( "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation" ) type LastValue struct { + Child sql.Expression window *sql.WindowDefinition - expression.UnaryExpression - pos int - id sql.ColumnId + pos int + id sql.ColumnId } var _ sql.FunctionExpression = (*LastValue)(nil) @@ -38,7 +37,7 @@ var _ sql.WindowAdaptableExpression = (*LastValue)(nil) var _ sql.CollationCoercible = (*LastValue)(nil) func NewLastValue(e sql.Expression) sql.Expression { - return &LastValue{window: nil, UnaryExpression: expression.UnaryExpression{Child: e}} + return &LastValue{window: nil, Child: e} } // Id implements sql.IdExpression @@ -63,7 +62,7 @@ func (f *LastValue) Window() *sql.WindowDefinition { return f.window } -// IsNullable implements sql.Expression +// Resolved implements sql.Expression func (f *LastValue) Resolved() bool { return windowResolved(f.window) } diff --git a/sql/expression/function/ceil_round_floor.go b/sql/expression/function/ceil_round_floor.go index b289dcfcb7..13b6fa2b9f 100644 --- a/sql/expression/function/ceil_round_floor.go +++ b/sql/expression/function/ceil_round_floor.go @@ -47,7 +47,7 @@ func numericRetType(inputType sql.Type) sql.Type { // Ceil returns the smallest integer value not less than X. type Ceil struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Ceil)(nil) @@ -55,7 +55,7 @@ var _ sql.CollationCoercible = (*Ceil)(nil) // NewCeil creates a new Ceil expression. func NewCeil(num sql.Expression) sql.Expression { - return &Ceil{expression.UnaryExpression{Child: num}} + return &Ceil{expression.UnaryExpressionStub{Child: num}} } // FunctionName implements sql.FunctionExpression @@ -130,7 +130,7 @@ func (c *Ceil) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // Floor returns the biggest integer value not less than X. type Floor struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Floor)(nil) @@ -138,7 +138,7 @@ var _ sql.CollationCoercible = (*Floor)(nil) // NewFloor returns a new Floor expression. func NewFloor(num sql.Expression) sql.Expression { - return &Floor{expression.UnaryExpression{Child: num}} + return &Floor{expression.UnaryExpressionStub{Child: num}} } // FunctionName implements sql.FunctionExpression @@ -216,7 +216,8 @@ func (f *Floor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // digits of it's integer part set to 0. If d is not specified or nil/null // it defaults to 0. type Round struct { - expression.BinaryExpressionStub + Num sql.Expression + Dec sql.Expression } var _ sql.FunctionExpression = (*Round)(nil) @@ -225,16 +226,19 @@ var _ sql.CollationCoercible = (*Round)(nil) // NewRound returns a new Round expression. func NewRound(args ...sql.Expression) (sql.Expression, error) { argLen := len(args) - if argLen == 0 || argLen > 2 { + switch argLen { + case 1: + return &Round{ + Num: args[0], + }, nil + case 2: + return &Round{ + Num: args[0], + Dec: args[1], + }, nil + default: return nil, sql.ErrInvalidArgumentNumber.New("ROUND", "1 or 2", argLen) } - - var right sql.Expression - if len(args) == 2 { - right = args[1] - } - - return &Round{expression.BinaryExpressionStub{LeftChild: args[0], RightChild: right}}, nil } // FunctionName implements sql.FunctionExpression @@ -249,16 +253,15 @@ func (r *Round) Description() string { // Children implements the Expression interface. func (r *Round) Children() []sql.Expression { - if r.RightChild == nil { - return []sql.Expression{r.LeftChild} + if r.Dec == nil { + return []sql.Expression{r.Num} } - - return r.BinaryExpressionStub.Children() + return []sql.Expression{r.Num, r.Dec} } // Eval implements the Expression interface. func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - val, err := r.LeftChild.Eval(ctx, row) + val, err := r.Num.Eval(ctx, row) if err != nil { return nil, err } @@ -272,9 +275,9 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } prec := int32(0) - if r.RightChild != nil { + if r.Dec != nil { var tmp any - tmp, err = r.RightChild.Eval(ctx, row) + tmp, err = r.Dec.Eval(ctx, row) if err != nil { return nil, err } @@ -301,7 +304,7 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { var res interface{} tmp := val.(decimal.Decimal).Round(prec) - lType := r.LeftChild.Type() + lType := r.Num.Type() if types.IsSigned(lType) { res, _, err = types.Int64.Convert(ctx, tmp) } else if types.IsUnsigned(lType) { @@ -322,25 +325,24 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // IsNullable implements the Expression interface. func (r *Round) IsNullable() bool { - return r.LeftChild.IsNullable() + return r.Num.IsNullable() && (r == nil || r.Dec.IsNullable()) } func (r *Round) String() string { - if r.RightChild == nil { - return fmt.Sprintf("%s(%s,0)", r.FunctionName(), r.LeftChild.String()) + if r.Dec == nil { + return fmt.Sprintf("%s(%s,0)", r.FunctionName(), r.Num.String()) } - - return fmt.Sprintf("%s(%s,%s)", r.FunctionName(), r.LeftChild.String(), r.RightChild.String()) + return fmt.Sprintf("%s(%s,%s)", r.FunctionName(), r.Num.String(), r.Dec.String()) } // Resolved implements the Expression interface. func (r *Round) Resolved() bool { - return r.LeftChild.Resolved() && (r.RightChild == nil || r.RightChild.Resolved()) + return r.Num.Resolved() && (r.Dec == nil || r.Dec.Resolved()) } // Type implements the Expression interface. func (r *Round) Type() sql.Type { - return numericRetType(r.LeftChild.Type()) + return numericRetType(r.Num.Type()) } // CollationCoercibility implements the interface sql.CollationCoercible. diff --git a/sql/expression/function/collation.go b/sql/expression/function/collation.go index efcf8255d1..d6248d277b 100644 --- a/sql/expression/function/collation.go +++ b/sql/expression/function/collation.go @@ -25,7 +25,7 @@ import ( // Collation is a function that returns the collation of the inner expression. type Collation struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Collation)(nil) @@ -33,7 +33,7 @@ var _ sql.CollationCoercible = (*Collation)(nil) // NewCollation creates a new Collation expression. func NewCollation(e sql.Expression) sql.Expression { - return &Collation{expression.UnaryExpression{Child: e}} + return &Collation{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression @@ -86,7 +86,7 @@ func (*Collation) CollationCoercibility(ctx *sql.Context) (collation sql.Collati // Coercibility is a function that returns the coercibility of the inner expression. type Coercibility struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Coercibility)(nil) @@ -94,7 +94,7 @@ var _ sql.CollationCoercible = (*Coercibility)(nil) // NewCoercibility creates a new Coercibility expression. func NewCoercibility(e sql.Expression) sql.Expression { - return &Coercibility{expression.UnaryExpression{Child: e}} + return &Coercibility{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression @@ -147,7 +147,7 @@ func (*Coercibility) CollationCoercibility(ctx *sql.Context) (collation sql.Coll // Charset is a function that returns the character set of the inner expression. type Charset struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Charset)(nil) @@ -155,7 +155,7 @@ var _ sql.CollationCoercible = (*Charset)(nil) // NewCharset creates a new Charset expression. func NewCharset(e sql.Expression) sql.Expression { - return &Charset{expression.UnaryExpression{Child: e}} + return &Charset{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/days.go b/sql/expression/function/days.go index 5208b42dd7..fc633b4ad2 100644 --- a/sql/expression/function/days.go +++ b/sql/expression/function/days.go @@ -25,7 +25,7 @@ import ( // ToDays is a function that converts a date to a number of days since year 0. type ToDays struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*ToDays)(nil) @@ -33,7 +33,7 @@ var _ sql.CollationCoercible = (*ToDays)(nil) // NewToDays creates a new ToDays function. func NewToDays(date sql.Expression) sql.Expression { - return &ToDays{expression.UnaryExpression{Child: date}} + return &ToDays{expression.UnaryExpressionStub{Child: date}} } // CollationCoercibility implements sql.CollationCoercible @@ -114,7 +114,7 @@ func (t *ToDays) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // FromDays is a function that returns date for a given number of days since year 0. type FromDays struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*FromDays)(nil) @@ -122,7 +122,7 @@ var _ sql.CollationCoercible = (*FromDays)(nil) // NewFromDays creates a new FromDays function. func NewFromDays(days sql.Expression) sql.Expression { - return &FromDays{expression.UnaryExpression{Child: days}} + return &FromDays{expression.UnaryExpressionStub{Child: days}} } // CollationCoercibility implements sql.CollationCoercible @@ -238,7 +238,7 @@ func (f *FromDays) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // LastDay is a function that returns the date at the last day of the month. type LastDay struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*LastDay)(nil) @@ -246,7 +246,7 @@ var _ sql.CollationCoercible = (*LastDay)(nil) // NewLastDay creates a new LastDay function. func NewLastDay(date sql.Expression) sql.Expression { - return &LastDay{expression.UnaryExpression{Child: date}} + return &LastDay{expression.UnaryExpressionStub{Child: date}} } // CollationCoercibility implements sql.CollationCoercible diff --git a/sql/expression/function/function.go b/sql/expression/function/function.go index 80a55cce1d..a5f1a6fbe0 100644 --- a/sql/expression/function/function.go +++ b/sql/expression/function/function.go @@ -23,7 +23,7 @@ import ( ) type UnaryFunc struct { - expression.UnaryExpression + expression.UnaryExpressionStub // The type returned by the function RetType sql.Type // Name is the name of the function @@ -32,9 +32,9 @@ type UnaryFunc struct { func NewUnaryFunc(arg sql.Expression, name string, returnType sql.Type) *UnaryFunc { return &UnaryFunc{ - UnaryExpression: expression.UnaryExpression{Child: arg}, - Name: name, - RetType: returnType, + UnaryExpressionStub: expression.UnaryExpressionStub{Child: arg}, + Name: name, + RetType: returnType, } } diff --git a/sql/expression/function/inet_convert.go b/sql/expression/function/inet_convert.go index 5ec844383a..d78ca94358 100644 --- a/sql/expression/function/inet_convert.go +++ b/sql/expression/function/inet_convert.go @@ -28,14 +28,14 @@ import ( ) type InetAton struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*InetAton)(nil) var _ sql.CollationCoercible = (*InetAton)(nil) func NewInetAton(val sql.Expression) sql.Expression { - return &InetAton{expression.UnaryExpression{Child: val}} + return &InetAton{expression.UnaryExpressionStub{Child: val}} } // FunctionName implements sql.FunctionExpression @@ -108,14 +108,14 @@ func (i *InetAton) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } type Inet6Aton struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Inet6Aton)(nil) var _ sql.CollationCoercible = (*Inet6Aton)(nil) func NewInet6Aton(val sql.Expression) sql.Expression { - return &Inet6Aton{expression.UnaryExpression{Child: val}} + return &Inet6Aton{expression.UnaryExpressionStub{Child: val}} } // FunctionName implements sql.FunctionExpression @@ -188,14 +188,14 @@ func (i *Inet6Aton) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } type InetNtoa struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*InetNtoa)(nil) var _ sql.CollationCoercible = (*InetNtoa)(nil) func NewInetNtoa(val sql.Expression) sql.Expression { - return &InetNtoa{expression.UnaryExpression{Child: val}} + return &InetNtoa{expression.UnaryExpressionStub{Child: val}} } // FunctionName implements sql.FunctionExpression @@ -261,14 +261,14 @@ func (i *InetNtoa) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } type Inet6Ntoa struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Inet6Ntoa)(nil) var _ sql.CollationCoercible = (*Inet6Ntoa)(nil) func NewInet6Ntoa(val sql.Expression) sql.Expression { - return &Inet6Ntoa{expression.UnaryExpression{Child: val}} + return &Inet6Ntoa{expression.UnaryExpressionStub{Child: val}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/is_ip.go b/sql/expression/function/is_ip.go index 70fcd9c5e9..9171905cd2 100644 --- a/sql/expression/function/is_ip.go +++ b/sql/expression/function/is_ip.go @@ -26,14 +26,14 @@ import ( ) type IsIPv4 struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*IsIPv4)(nil) var _ sql.CollationCoercible = (*IsIPv4)(nil) func NewIsIPv4(val sql.Expression) sql.Expression { - return &IsIPv4{expression.UnaryExpression{Child: val}} + return &IsIPv4{expression.UnaryExpressionStub{Child: val}} } // FunctionName implements sql.FunctionExpression @@ -94,14 +94,14 @@ func (i *IsIPv4) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } type IsIPv6 struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*IsIPv6)(nil) var _ sql.CollationCoercible = (*IsIPv6)(nil) func NewIsIPv6(val sql.Expression) sql.Expression { - return &IsIPv6{expression.UnaryExpression{Child: val}} + return &IsIPv6{expression.UnaryExpressionStub{Child: val}} } // FunctionName implements sql.FunctionExpression @@ -162,14 +162,14 @@ func (i *IsIPv6) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } type IsIPv4Compat struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*IsIPv4Compat)(nil) var _ sql.CollationCoercible = (*IsIPv4Compat)(nil) func NewIsIPv4Compat(val sql.Expression) sql.Expression { - return &IsIPv4Compat{expression.UnaryExpression{Child: val}} + return &IsIPv4Compat{expression.UnaryExpressionStub{Child: val}} } // FunctionName implements sql.FunctionExpression @@ -234,14 +234,14 @@ func (i *IsIPv4Compat) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) } type IsIPv4Mapped struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*IsIPv4Mapped)(nil) var _ sql.CollationCoercible = (*IsIPv4Mapped)(nil) func NewIsIPv4Mapped(val sql.Expression) sql.Expression { - return &IsIPv4Mapped{expression.UnaryExpression{Child: val}} + return &IsIPv4Mapped{expression.UnaryExpressionStub{Child: val}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/isbinary.go b/sql/expression/function/isbinary.go index 70b3119b9f..b3af41f269 100644 --- a/sql/expression/function/isbinary.go +++ b/sql/expression/function/isbinary.go @@ -25,7 +25,7 @@ import ( // IsBinary is a function that returns whether a blob is binary or not. type IsBinary struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*IsBinary)(nil) @@ -33,7 +33,7 @@ var _ sql.CollationCoercible = (*IsBinary)(nil) // NewIsBinary creates a new IsBinary expression. func NewIsBinary(e sql.Expression) sql.Expression { - return &IsBinary{expression.UnaryExpression{Child: e}} + return &IsBinary{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/isnull.go b/sql/expression/function/isnull.go index 55ffa22040..51016e7d05 100644 --- a/sql/expression/function/isnull.go +++ b/sql/expression/function/isnull.go @@ -24,7 +24,7 @@ import ( // IsNull is a function that returns whether a value is null or not. type IsNull struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*IsNull)(nil) @@ -32,7 +32,7 @@ var _ sql.CollationCoercible = (*IsNull)(nil) // NewIsNull creates a new IsNull expression. func NewIsNull(e sql.Expression) sql.Expression { - return &IsNull{expression.UnaryExpression{Child: e}} + return &IsNull{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/json/json_pretty.go b/sql/expression/function/json/json_pretty.go index 627dde26ce..6eceb0c0b1 100644 --- a/sql/expression/function/json/json_pretty.go +++ b/sql/expression/function/json/json_pretty.go @@ -41,14 +41,14 @@ import ( // // https://dev.mysql.com/doc/refman/8.0/en/json-utility-functions.html#function_json-pretty type JSONPretty struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = &JSONPretty{} // NewJSONPretty creates a new JSONPretty function. func NewJSONPretty(arg sql.Expression) sql.Expression { - return &JSONPretty{expression.UnaryExpression{Child: arg}} + return &JSONPretty{expression.UnaryExpressionStub{Child: arg}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/json/json_quote.go b/sql/expression/function/json/json_quote.go index c3f193d2ae..e88443975c 100644 --- a/sql/expression/function/json/json_quote.go +++ b/sql/expression/function/json/json_quote.go @@ -35,7 +35,7 @@ import ( // // https://dev.mysql.com/doc/refman/8.0/en/json-creation-functions.html#function_json-quote type JSONQuote struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*JSONQuote)(nil) @@ -43,7 +43,7 @@ var _ sql.CollationCoercible = (*JSONQuote)(nil) // NewJSONQuote creates a new JSONQuote UDF. func NewJSONQuote(json sql.Expression) sql.Expression { - return &JSONQuote{expression.UnaryExpression{Child: json}} + return &JSONQuote{expression.UnaryExpressionStub{Child: json}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/json/json_unquote.go b/sql/expression/function/json/json_unquote.go index 32c619b664..f2663af0b7 100644 --- a/sql/expression/function/json/json_unquote.go +++ b/sql/expression/function/json/json_unquote.go @@ -28,7 +28,7 @@ import ( // Returns NULL if the argument is NULL. // An error occurs if the value starts and ends with double quotes but is not a valid JSON string literal. type JSONUnquote struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*JSONUnquote)(nil) @@ -36,7 +36,7 @@ var _ sql.CollationCoercible = (*JSONUnquote)(nil) // NewJSONUnquote creates a new JSONUnquote UDF. func NewJSONUnquote(json sql.Expression) sql.Expression { - return &JSONUnquote{expression.UnaryExpression{Child: json}} + return &JSONUnquote{expression.UnaryExpressionStub{Child: json}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/length.go b/sql/expression/function/length.go index 9c1d5833f9..b8b079ccf9 100644 --- a/sql/expression/function/length.go +++ b/sql/expression/function/length.go @@ -28,7 +28,7 @@ import ( // Length returns the length of a string or binary content, either in bytes // or characters. type Length struct { - expression.UnaryExpression + expression.UnaryExpressionStub CountType CountType } @@ -47,12 +47,12 @@ const ( // NewLength returns a new LENGTH function. func NewLength(e sql.Expression) sql.Expression { - return &Length{expression.UnaryExpression{Child: e}, NumBytes} + return &Length{expression.UnaryExpressionStub{Child: e}, NumBytes} } // NewCharLength returns a new CHAR_LENGTH function. func NewCharLength(e sql.Expression) sql.Expression { - return &Length{expression.UnaryExpression{Child: e}, NumChars} + return &Length{expression.UnaryExpressionStub{Child: e}, NumChars} } // FunctionName implements sql.FunctionExpression @@ -83,7 +83,7 @@ func (l *Length) WithChildren(children ...sql.Expression) (sql.Expression, error return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) } - return &Length{expression.UnaryExpression{Child: children[0]}, l.CountType}, nil + return &Length{expression.UnaryExpressionStub{Child: children[0]}, l.CountType}, nil } // Type implements the sql.Expression interface. diff --git a/sql/expression/function/locks.go b/sql/expression/function/locks.go index 8293191947..cf17cdc3d8 100644 --- a/sql/expression/function/locks.go +++ b/sql/expression/function/locks.go @@ -46,7 +46,7 @@ func (nl *NamedLockFunction) evalLockLogic(ctx *sql.Context, fn lockFuncLogic, r // NamedLockFunction is a sql function that takes just the name of a lock as an argument type NamedLockFunction struct { - expression.UnaryExpression + expression.UnaryExpressionStub retType sql.Type ls *sql.LockSubsystem funcName string @@ -128,10 +128,10 @@ func NewIsFreeLock(ls *sql.LockSubsystem) sql.CreateFunc1Args { return func(e sql.Expression) sql.Expression { return &IsFreeLock{ NamedLockFunction: NamedLockFunction{ - UnaryExpression: expression.UnaryExpression{e}, - ls: ls, - funcName: "is_free_lock", - retType: types.Int8, + UnaryExpressionStub: expression.UnaryExpressionStub{e}, + ls: ls, + funcName: "is_free_lock", + retType: types.Int8, }, } } @@ -170,10 +170,10 @@ func NewIsUsedLock(ls *sql.LockSubsystem) sql.CreateFunc1Args { return func(e sql.Expression) sql.Expression { return &IsUsedLock{ NamedLockFunction: NamedLockFunction{ - UnaryExpression: expression.UnaryExpression{e}, - ls: ls, - funcName: "is_used_lock", - retType: types.Uint32, + UnaryExpressionStub: expression.UnaryExpressionStub{e}, + ls: ls, + funcName: "is_used_lock", + retType: types.Uint32, }, } } @@ -212,10 +212,10 @@ func NewReleaseLock(ls *sql.LockSubsystem) sql.CreateFunc1Args { return func(e sql.Expression) sql.Expression { return &ReleaseLock{ NamedLockFunction: NamedLockFunction{ - UnaryExpression: expression.UnaryExpression{e}, - ls: ls, - funcName: "release_lock", - retType: types.Int8, + UnaryExpressionStub: expression.UnaryExpressionStub{e}, + ls: ls, + funcName: "release_lock", + retType: types.Int8, }, } } diff --git a/sql/expression/function/logarithm.go b/sql/expression/function/logarithm.go index 3943898ed8..aa6cabd92f 100644 --- a/sql/expression/function/logarithm.go +++ b/sql/expression/function/logarithm.go @@ -39,7 +39,7 @@ func NewLogBaseFunc(base float64) func(e sql.Expression) sql.Expression { // LogBase is a function that returns the logarithm of a value with a specific base. type LogBase struct { - expression.UnaryExpression + expression.UnaryExpressionStub base float64 } @@ -48,7 +48,7 @@ var _ sql.CollationCoercible = (*LogBase)(nil) // NewLogBase creates a new LogBase expression. func NewLogBase(base float64, e sql.Expression) sql.Expression { - return &LogBase{UnaryExpression: expression.UnaryExpression{Child: e}, base: base} + return &LogBase{UnaryExpressionStub: expression.UnaryExpressionStub{Child: e}, base: base} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/lower_upper.go b/sql/expression/function/lower_upper.go index fceb0cbae2..d78210e554 100644 --- a/sql/expression/function/lower_upper.go +++ b/sql/expression/function/lower_upper.go @@ -24,7 +24,7 @@ import ( // Lower is a function that returns the lowercase of the text provided. type Lower struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Lower)(nil) @@ -32,7 +32,7 @@ var _ sql.CollationCoercible = (*Lower)(nil) // NewLower creates a new Lower expression. func NewLower(e sql.Expression) sql.Expression { - return &Lower{expression.UnaryExpression{Child: e}} + return &Lower{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression @@ -90,7 +90,7 @@ func (l *Lower) CollationCoercibility(ctx *sql.Context) (collation sql.Collation // Upper is a function that returns the UPPERCASE of the text provided. type Upper struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Upper)(nil) @@ -98,7 +98,7 @@ var _ sql.CollationCoercible = (*Upper)(nil) // NewUpper creates a new Lower expression. func NewUpper(e sql.Expression) sql.Expression { - return &Upper{expression.UnaryExpression{Child: e}} + return &Upper{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/queryinfo.go b/sql/expression/function/queryinfo.go index 799a461bed..636dd8c4f5 100644 --- a/sql/expression/function/queryinfo.go +++ b/sql/expression/function/queryinfo.go @@ -20,7 +20,6 @@ import ( "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -101,7 +100,6 @@ func NewLastInsertUuid(children ...sql.Expression) (sql.Expression, error) { if len(children) > 0 { return nil, sql.ErrInvalidChildrenNumber.New((&LastInsertUuid{}).String(), len(children), 0) } - return &LastInsertUuid{}, nil } @@ -153,7 +151,7 @@ func (l *LastInsertUuid) Description() string { // LastInsertId implements the LAST_INSERT_ID() function // https://dev.mysql.com/doc/refman/8.0/en/information-functions.html#function_last-insert-id type LastInsertId struct { - expression.UnaryExpression + Child sql.Expression } func NewLastInsertId(children ...sql.Expression) (sql.Expression, error) { @@ -161,7 +159,7 @@ func NewLastInsertId(children ...sql.Expression) (sql.Expression, error) { case 0: return &LastInsertId{}, nil case 1: - return &LastInsertId{UnaryExpression: expression.UnaryExpression{Child: children[0]}}, nil + return &LastInsertId{Child: children[0]}, nil default: return nil, sql.ErrInvalidArgumentNumber.New("LastInsertId", len(children), 1) } diff --git a/sql/expression/function/random_bytes.go b/sql/expression/function/random_bytes.go index a11e9049c7..63113cccdf 100644 --- a/sql/expression/function/random_bytes.go +++ b/sql/expression/function/random_bytes.go @@ -29,7 +29,7 @@ const randomBytesMax = 1024 // RandomBytes returns a random binary string of the given length. type RandomBytes struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*RandomBytes)(nil) @@ -37,7 +37,7 @@ var _ sql.CollationCoercible = (*RandomBytes)(nil) // NewRandomBytes returns a new RANDOM_BYTES function. func NewRandomBytes(e sql.Expression) sql.Expression { - return &RandomBytes{expression.UnaryExpression{Child: e}} + return &RandomBytes{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/reverse_repeat_replace.go b/sql/expression/function/reverse_repeat_replace.go index b8287a2763..e5a5f4cad4 100644 --- a/sql/expression/function/reverse_repeat_replace.go +++ b/sql/expression/function/reverse_repeat_replace.go @@ -27,7 +27,7 @@ import ( // Reverse is a function that returns the reverse of the text provided. type Reverse struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Reverse)(nil) @@ -35,7 +35,7 @@ var _ sql.CollationCoercible = (*Reverse)(nil) // NewReverse creates a new Reverse expression. func NewReverse(e sql.Expression) sql.Expression { - return &Reverse{expression.UnaryExpression{Child: e}} + return &Reverse{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/sleep.go b/sql/expression/function/sleep.go index b1fcecbdc3..d397e1708e 100644 --- a/sql/expression/function/sleep.go +++ b/sql/expression/function/sleep.go @@ -28,7 +28,7 @@ import ( // and returns 0. // It can be useful to test timeouts or long queries. type Sleep struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Sleep)(nil) @@ -36,7 +36,7 @@ var _ sql.CollationCoercible = (*Sleep)(nil) // NewSleep creates a new Sleep expression. func NewSleep(e sql.Expression) sql.Expression { - return &Sleep{expression.UnaryExpression{Child: e}} + return &Sleep{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/soundex.go b/sql/expression/function/soundex.go index 1f7b34567c..1ec7fc182f 100644 --- a/sql/expression/function/soundex.go +++ b/sql/expression/function/soundex.go @@ -29,7 +29,7 @@ import ( // soundex string is four characters long, but the SOUNDEX() function returns // an arbitrarily long string. type Soundex struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Soundex)(nil) @@ -37,7 +37,7 @@ var _ sql.CollationCoercible = (*Soundex)(nil) // NewSoundex creates a new Soundex expression. func NewSoundex(e sql.Expression) sql.Expression { - return &Soundex{expression.UnaryExpression{Child: e}} + return &Soundex{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/spatial/st_area.go b/sql/expression/function/spatial/st_area.go index eb76e26f78..93a90f4364 100644 --- a/sql/expression/function/spatial/st_area.go +++ b/sql/expression/function/spatial/st_area.go @@ -24,7 +24,7 @@ import ( // Area is a function that returns the Area of a Polygon type Area struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Area)(nil) @@ -32,7 +32,7 @@ var _ sql.CollationCoercible = (*Area)(nil) // NewArea creates a new Area expression. func NewArea(arg sql.Expression) sql.Expression { - return &Area{expression.UnaryExpression{Child: arg}} + return &Area{expression.UnaryExpressionStub{Child: arg}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/spatial/st_dimension.go b/sql/expression/function/spatial/st_dimension.go index 90f7a9edfa..3bd5d45137 100644 --- a/sql/expression/function/spatial/st_dimension.go +++ b/sql/expression/function/spatial/st_dimension.go @@ -24,7 +24,7 @@ import ( // Dimension is a function that converts a spatial type into WKT format (alias for AsText) type Dimension struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Dimension)(nil) @@ -32,7 +32,7 @@ var _ sql.CollationCoercible = (*Dimension)(nil) // NewDimension creates a new point expression. func NewDimension(e sql.Expression) sql.Expression { - return &Dimension{expression.UnaryExpression{Child: e}} + return &Dimension{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/spatial/st_linestring.go b/sql/expression/function/spatial/st_linestring.go index d67b5a0dca..ed7b33c6b8 100644 --- a/sql/expression/function/spatial/st_linestring.go +++ b/sql/expression/function/spatial/st_linestring.go @@ -24,7 +24,7 @@ import ( // StartPoint is a function that returns the first point of a LineString type StartPoint struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*StartPoint)(nil) @@ -32,7 +32,7 @@ var _ sql.CollationCoercible = (*StartPoint)(nil) // NewStartPoint creates a new StartPoint expression. func NewStartPoint(arg sql.Expression) sql.Expression { - return &StartPoint{expression.UnaryExpression{Child: arg}} + return &StartPoint{expression.UnaryExpressionStub{Child: arg}} } // FunctionName implements sql.FunctionExpression @@ -96,7 +96,7 @@ func (s *StartPoint) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // EndPoint is a function that returns the last point of a LineString type EndPoint struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*EndPoint)(nil) @@ -104,7 +104,7 @@ var _ sql.CollationCoercible = (*EndPoint)(nil) // NewEndPoint creates a new EndPoint expression. func NewEndPoint(arg sql.Expression) sql.Expression { - return &EndPoint{expression.UnaryExpression{Child: arg}} + return &EndPoint{expression.UnaryExpressionStub{Child: arg}} } // FunctionName implements sql.FunctionExpression @@ -168,7 +168,7 @@ func (e *EndPoint) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // IsClosed is a function that checks if a LineString or MultiLineString is close type IsClosed struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*IsClosed)(nil) @@ -176,7 +176,7 @@ var _ sql.CollationCoercible = (*IsClosed)(nil) // NewIsClosed creates a new EndPoint expression. func NewIsClosed(arg sql.Expression) sql.Expression { - return &IsClosed{expression.UnaryExpression{Child: arg}} + return &IsClosed{expression.UnaryExpressionStub{Child: arg}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/spatial/st_swapxy.go b/sql/expression/function/spatial/st_swapxy.go index dc9ffa5c17..9161e77fdc 100644 --- a/sql/expression/function/spatial/st_swapxy.go +++ b/sql/expression/function/spatial/st_swapxy.go @@ -24,7 +24,7 @@ import ( // SwapXY is a function that returns a spatial type with their X and Y values swapped type SwapXY struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*SwapXY)(nil) @@ -32,7 +32,7 @@ var _ sql.CollationCoercible = (*SwapXY)(nil) // NewSwapXY creates a new point expression. func NewSwapXY(e sql.Expression) sql.Expression { - return &SwapXY{expression.UnaryExpression{Child: e}} + return &SwapXY{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/spatial/wkb.go b/sql/expression/function/spatial/wkb.go index 730bf5e1c5..722329827f 100644 --- a/sql/expression/function/spatial/wkb.go +++ b/sql/expression/function/spatial/wkb.go @@ -25,7 +25,7 @@ import ( // AsWKB is a function that converts a spatial type into WKB format (alias for AsBinary) type AsWKB struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*AsWKB)(nil) @@ -33,7 +33,7 @@ var _ sql.CollationCoercible = (*AsWKB)(nil) // NewAsWKB creates a new point expression. func NewAsWKB(e sql.Expression) sql.Expression { - return &AsWKB{expression.UnaryExpression{Child: e}} + return &AsWKB{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/spatial/wkt.go b/sql/expression/function/spatial/wkt.go index fbde555695..74358524a5 100644 --- a/sql/expression/function/spatial/wkt.go +++ b/sql/expression/function/spatial/wkt.go @@ -26,7 +26,7 @@ import ( // AsWKT is a function that converts a spatial type into WKT format (alias for AsText) type AsWKT struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*AsWKT)(nil) @@ -34,7 +34,7 @@ var _ sql.CollationCoercible = (*AsWKT)(nil) // NewAsWKT creates a new point expression. func NewAsWKT(e sql.Expression) sql.Expression { - return &AsWKT{expression.UnaryExpression{Child: e}} + return &AsWKT{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/sqrt_power.go b/sql/expression/function/sqrt_power.go index 818e228b25..eaf8bc98b2 100644 --- a/sql/expression/function/sqrt_power.go +++ b/sql/expression/function/sqrt_power.go @@ -27,7 +27,7 @@ import ( // Sqrt is a function that returns the square value of the number provided. type Sqrt struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Sqrt)(nil) @@ -35,7 +35,7 @@ var _ sql.CollationCoercible = (*Sqrt)(nil) // NewSqrt creates a new Sqrt expression. func NewSqrt(e sql.Expression) sql.Expression { - return &Sqrt{expression.UnaryExpression{Child: e}} + return &Sqrt{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/time.go b/sql/expression/function/time.go index 01f9065dde..0bde0f3643 100644 --- a/sql/expression/function/time.go +++ b/sql/expression/function/time.go @@ -50,7 +50,7 @@ func getDate(ctx *sql.Context, val interface{}) (interface{}, error) { } func getDatePart(ctx *sql.Context, - u expression.UnaryExpression, + u expression.UnaryExpressionStub, row sql.Row, f func(interface{}) interface{}) (interface{}, error) { val, err := u.Child.Eval(ctx, row) @@ -75,7 +75,7 @@ func getDatePart(ctx *sql.Context, // Year is a function that returns the year of a date. type Year struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Year)(nil) @@ -83,7 +83,7 @@ var _ sql.CollationCoercible = (*Year)(nil) // NewYear creates a new Year UDF. func NewYear(date sql.Expression) sql.Expression { - return &Year{expression.UnaryExpression{Child: date}} + return &Year{expression.UnaryExpressionStub{Child: date}} } // FunctionName implements sql.FunctionExpression @@ -108,7 +108,7 @@ func (*Year) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, // Eval implements the Expression interface. func (y *Year) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - return getDatePart(ctx, y.UnaryExpression, row, year) + return getDatePart(ctx, y.UnaryExpressionStub, row, year) } // WithChildren implements the Expression interface. @@ -120,7 +120,7 @@ func (y *Year) WithChildren(children ...sql.Expression) (sql.Expression, error) } type Quarter struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Quarter)(nil) @@ -128,7 +128,7 @@ var _ sql.CollationCoercible = (*Quarter)(nil) // NewQuarter creates a new Month UDF. func NewQuarter(date sql.Expression) sql.Expression { - return &Quarter{expression.UnaryExpression{Child: date}} + return &Quarter{expression.UnaryExpressionStub{Child: date}} } // FunctionName implements sql.FunctionExpression @@ -158,7 +158,7 @@ func (q *Quarter) CollationCoercibility(ctx *sql.Context) (collation sql.Collati // Eval implements the Expression interface. func (q *Quarter) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - return getDatePart(ctx, q.UnaryExpression, row, quarter) + return getDatePart(ctx, q.UnaryExpressionStub, row, quarter) } // WithChildren implements the Expression interface. @@ -171,7 +171,7 @@ func (q *Quarter) WithChildren(children ...sql.Expression) (sql.Expression, erro // Month is a function that returns the month of a date. type Month struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Month)(nil) @@ -179,7 +179,7 @@ var _ sql.CollationCoercible = (*Month)(nil) // NewMonth creates a new Month UDF. func NewMonth(date sql.Expression) sql.Expression { - return &Month{expression.UnaryExpression{Child: date}} + return &Month{expression.UnaryExpressionStub{Child: date}} } // FunctionName implements sql.FunctionExpression @@ -209,7 +209,7 @@ func (*Month) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID // Eval implements the Expression interface. func (m *Month) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - return getDatePart(ctx, m.UnaryExpression, row, month) + return getDatePart(ctx, m.UnaryExpressionStub, row, month) } // WithChildren implements the Expression interface. @@ -222,7 +222,7 @@ func (m *Month) WithChildren(children ...sql.Expression) (sql.Expression, error) // Day is a function that returns the day of a date. type Day struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Day)(nil) @@ -230,7 +230,7 @@ var _ sql.CollationCoercible = (*Day)(nil) // NewDay creates a new Day UDF. func NewDay(date sql.Expression) sql.Expression { - return &Day{expression.UnaryExpression{Child: date}} + return &Day{expression.UnaryExpressionStub{Child: date}} } // FunctionName implements sql.FunctionExpression @@ -260,7 +260,7 @@ func (*Day) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, // Eval implements the Expression interface. func (d *Day) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - return getDatePart(ctx, d.UnaryExpression, row, day) + return getDatePart(ctx, d.UnaryExpressionStub, row, day) } // WithChildren implements the Expression interface. @@ -274,7 +274,7 @@ func (d *Day) WithChildren(children ...sql.Expression) (sql.Expression, error) { // Weekday is a function that returns the weekday of a date where 0 = Monday, // ..., 6 = Sunday. type Weekday struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Weekday)(nil) @@ -282,7 +282,7 @@ var _ sql.CollationCoercible = (*Weekday)(nil) // NewWeekday creates a new Weekday UDF. func NewWeekday(date sql.Expression) sql.Expression { - return &Weekday{expression.UnaryExpression{Child: date}} + return &Weekday{expression.UnaryExpressionStub{Child: date}} } // FunctionName implements sql.FunctionExpression @@ -312,7 +312,7 @@ func (*Weekday) CollationCoercibility(ctx *sql.Context) (collation sql.Collation // Eval implements the Expression interface. func (d *Weekday) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - return getDatePart(ctx, d.UnaryExpression, row, weekday) + return getDatePart(ctx, d.UnaryExpressionStub, row, weekday) } // WithChildren implements the Expression interface. @@ -325,7 +325,7 @@ func (d *Weekday) WithChildren(children ...sql.Expression) (sql.Expression, erro // Hour is a function that returns the hour of a date. type Hour struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Hour)(nil) @@ -333,7 +333,7 @@ var _ sql.CollationCoercible = (*Hour)(nil) // NewHour creates a new Hour UDF. func NewHour(date sql.Expression) sql.Expression { - return &Hour{expression.UnaryExpression{Child: date}} + return &Hour{expression.UnaryExpressionStub{Child: date}} } // FunctionName implements sql.FunctionExpression @@ -358,7 +358,7 @@ func (*Hour) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, // Eval implements the Expression interface. func (h *Hour) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - return getDatePart(ctx, h.UnaryExpression, row, hour) + return getDatePart(ctx, h.UnaryExpressionStub, row, hour) } // WithChildren implements the Expression interface. @@ -371,7 +371,7 @@ func (h *Hour) WithChildren(children ...sql.Expression) (sql.Expression, error) // Minute is a function that returns the minute of a date. type Minute struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Minute)(nil) @@ -379,7 +379,7 @@ var _ sql.CollationCoercible = (*Minute)(nil) // NewMinute creates a new Minute UDF. func NewMinute(date sql.Expression) sql.Expression { - return &Minute{expression.UnaryExpression{Child: date}} + return &Minute{expression.UnaryExpressionStub{Child: date}} } // FunctionName implements sql.FunctionExpression @@ -404,7 +404,7 @@ func (*Minute) CollationCoercibility(ctx *sql.Context) (collation sql.CollationI // Eval implements the Expression interface. func (m *Minute) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - return getDatePart(ctx, m.UnaryExpression, row, minute) + return getDatePart(ctx, m.UnaryExpressionStub, row, minute) } // WithChildren implements the Expression interface. @@ -417,7 +417,7 @@ func (m *Minute) WithChildren(children ...sql.Expression) (sql.Expression, error // Second is a function that returns the second of a date. type Second struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Second)(nil) @@ -425,7 +425,7 @@ var _ sql.CollationCoercible = (*Second)(nil) // NewSecond creates a new Second UDF. func NewSecond(date sql.Expression) sql.Expression { - return &Second{expression.UnaryExpression{Child: date}} + return &Second{expression.UnaryExpressionStub{Child: date}} } // FunctionName implements sql.FunctionExpression @@ -450,7 +450,7 @@ func (*Second) CollationCoercibility(ctx *sql.Context) (collation sql.CollationI // Eval implements the Expression interface. func (s *Second) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - return getDatePart(ctx, s.UnaryExpression, row, second) + return getDatePart(ctx, s.UnaryExpressionStub, row, second) } // WithChildren implements the Expression interface. @@ -464,7 +464,7 @@ func (s *Second) WithChildren(children ...sql.Expression) (sql.Expression, error // DayOfWeek is a function that returns the day of the week from a date where // 1 = Sunday, ..., 7 = Saturday. type DayOfWeek struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*DayOfWeek)(nil) @@ -472,7 +472,7 @@ var _ sql.CollationCoercible = (*DayOfWeek)(nil) // NewDayOfWeek creates a new DayOfWeek UDF. func NewDayOfWeek(date sql.Expression) sql.Expression { - return &DayOfWeek{expression.UnaryExpression{Child: date}} + return &DayOfWeek{expression.UnaryExpressionStub{Child: date}} } // FunctionName implements sql.FunctionExpression @@ -502,7 +502,7 @@ func (*DayOfWeek) CollationCoercibility(ctx *sql.Context) (collation sql.Collati // Eval implements the Expression interface. func (d *DayOfWeek) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - return getDatePart(ctx, d.UnaryExpression, row, dayOfWeek) + return getDatePart(ctx, d.UnaryExpressionStub, row, dayOfWeek) } // WithChildren implements the Expression interface. @@ -515,7 +515,7 @@ func (d *DayOfWeek) WithChildren(children ...sql.Expression) (sql.Expression, er // DayOfYear is a function that returns the day of the year from a date. type DayOfYear struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*DayOfYear)(nil) @@ -523,7 +523,7 @@ var _ sql.CollationCoercible = (*DayOfYear)(nil) // NewDayOfYear creates a new DayOfYear UDF. func NewDayOfYear(date sql.Expression) sql.Expression { - return &DayOfYear{expression.UnaryExpression{Child: date}} + return &DayOfYear{expression.UnaryExpressionStub{Child: date}} } // FunctionName implements sql.FunctionExpression @@ -553,7 +553,7 @@ func (*DayOfYear) CollationCoercibility(ctx *sql.Context) (collation sql.Collati // Eval implements the Expression interface. func (d *DayOfYear) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - return getDatePart(ctx, d.UnaryExpression, row, dayOfYear) + return getDatePart(ctx, d.UnaryExpressionStub, row, dayOfYear) } // WithChildren implements the Expression interface. @@ -1292,7 +1292,7 @@ func (ut *UTCTimestamp) WithChildren(children ...sql.Expression) (sql.Expression // Date a function takes the DATE part out from a datetime expression. type Date struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Date)(nil) @@ -1310,7 +1310,7 @@ func (d *Date) Description() string { // NewDate returns a new Date node. func NewDate(date sql.Expression) sql.Expression { - return &Date{expression.UnaryExpression{Child: date}} + return &Date{expression.UnaryExpressionStub{Child: date}} } func (d *Date) String() string { return fmt.Sprintf("DATE(%s)", d.Child) } @@ -1365,7 +1365,7 @@ func (d *Date) WithChildren(children ...sql.Expression) (sql.Expression, error) // UnaryDatetimeFunc is a sql.Function which takes a single datetime argument type UnaryDatetimeFunc struct { - expression.UnaryExpression + expression.UnaryExpressionStub // SQLType is the return type of the function SQLType sql.Type // Name is the name of the function @@ -1373,7 +1373,7 @@ type UnaryDatetimeFunc struct { } func NewUnaryDatetimeFunc(arg sql.Expression, name string, sqlType sql.Type) *UnaryDatetimeFunc { - return &UnaryDatetimeFunc{UnaryExpression: expression.UnaryExpression{Child: arg}, Name: name, SQLType: sqlType} + return &UnaryDatetimeFunc{UnaryExpressionStub: expression.UnaryExpressionStub{Child: arg}, Name: name, SQLType: sqlType} } // FunctionName implements sql.FunctionExpression @@ -1500,7 +1500,7 @@ func NewMicrosecond(arg sql.Expression) sql.Expression { } func (m *Microsecond) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - return getDatePart(ctx, m.UnaryExpression, row, microsecond) + return getDatePart(ctx, m.UnaryExpressionStub, row, microsecond) } func (m *Microsecond) WithChildren(children ...sql.Expression) (sql.Expression, error) { @@ -1757,7 +1757,7 @@ func (c *CurrTime) WithChildren(children ...sql.Expression) (sql.Expression, err // Time is a function takes the Time part out from a datetime expression. type Time struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*Time)(nil) @@ -1765,7 +1765,7 @@ var _ sql.CollationCoercible = (*Time)(nil) // NewTime returns a new Date node. func NewTime(time sql.Expression) sql.Expression { - return &Time{expression.UnaryExpression{Child: time}} + return &Time{expression.UnaryExpressionStub{Child: time}} } func (t *Time) FunctionName() string { @@ -1792,7 +1792,7 @@ func (*Time) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, // Eval implements the Expression interface. func (t *Time) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - v, err := t.UnaryExpression.Child.Eval(ctx, row) + v, err := t.UnaryExpressionStub.Child.Eval(ctx, row) if err != nil { return nil, err } diff --git a/sql/expression/function/tobase64_frombase64.go b/sql/expression/function/tobase64_frombase64.go index e740b5311e..f31ad774b9 100644 --- a/sql/expression/function/tobase64_frombase64.go +++ b/sql/expression/function/tobase64_frombase64.go @@ -30,7 +30,7 @@ import ( // ToBase64 is a function to encode a string to the Base64 format // using the same dialect that MySQL's TO_BASE64 uses type ToBase64 struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*ToBase64)(nil) @@ -38,7 +38,7 @@ var _ sql.CollationCoercible = (*ToBase64)(nil) // NewToBase64 creates a new ToBase64 expression. func NewToBase64(e sql.Expression) sql.Expression { - return &ToBase64{expression.UnaryExpression{Child: e}} + return &ToBase64{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression @@ -139,7 +139,7 @@ func (*ToBase64) CollationCoercibility(ctx *sql.Context) (collation sql.Collatio // FromBase64 is a function to decode a Base64-formatted string // using the same dialect that MySQL's FROM_BASE64 uses type FromBase64 struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.FunctionExpression = (*FromBase64)(nil) @@ -147,7 +147,7 @@ var _ sql.CollationCoercible = (*FromBase64)(nil) // NewFromBase64 creates a new FromBase64 expression. func NewFromBase64(e sql.Expression) sql.Expression { - return &FromBase64{expression.UnaryExpression{Child: e}} + return &FromBase64{expression.UnaryExpressionStub{Child: e}} } // FunctionName implements sql.FunctionExpression diff --git a/sql/expression/function/trim_ltrim_rtrim.go b/sql/expression/function/trim_ltrim_rtrim.go index 0222b96ded..bb90184b47 100644 --- a/sql/expression/function/trim_ltrim_rtrim.go +++ b/sql/expression/function/trim_ltrim_rtrim.go @@ -162,11 +162,11 @@ func (t Trim) WithChildren(children ...sql.Expression) (sql.Expression, error) { } type LeftTrim struct { - expression.UnaryExpression + expression.UnaryExpressionStub } func NewLeftTrim(str sql.Expression) sql.Expression { - return &LeftTrim{expression.UnaryExpression{Child: str}} + return &LeftTrim{expression.UnaryExpressionStub{Child: str}} } var _ sql.FunctionExpression = (*LeftTrim)(nil) @@ -231,11 +231,11 @@ func (t *LeftTrim) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } type RightTrim struct { - expression.UnaryExpression + expression.UnaryExpressionStub } func NewRightTrim(str sql.Expression) sql.Expression { - return &RightTrim{expression.UnaryExpression{Child: str}} + return &RightTrim{expression.UnaryExpressionStub{Child: str}} } var _ sql.FunctionExpression = (*RightTrim)(nil) diff --git a/sql/expression/function/values.go b/sql/expression/function/values.go index af5f4e25a2..1c206ccfbc 100644 --- a/sql/expression/function/values.go +++ b/sql/expression/function/values.go @@ -26,7 +26,7 @@ import ( // INSERT INTO table (pk, v1, v2) VALUES (1, 3, 5), (2, 4, 6) ON DUPLICATE KEY UPDATE v2 = values(v1) * 10; // the values inserted into v2 would be 30 and 40. type Values struct { - expression.UnaryExpression + expression.UnaryExpressionStub Value interface{} } @@ -36,8 +36,8 @@ var _ sql.CollationCoercible = (*Values)(nil) // NewValues creates a new Values function. func NewValues(col sql.Expression) sql.Expression { return &Values{ - UnaryExpression: expression.UnaryExpression{Child: col}, - Value: nil, + UnaryExpressionStub: expression.UnaryExpressionStub{Child: col}, + Value: nil, } } diff --git a/sql/expression/function/vector/conversion.go b/sql/expression/function/vector/conversion.go index 0542cd454e..43c27637d4 100644 --- a/sql/expression/function/vector/conversion.go +++ b/sql/expression/function/vector/conversion.go @@ -24,7 +24,7 @@ import ( // StringToVector converts a JSON string representation to a vector type StringToVector struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.Expression = (*StringToVector)(nil) @@ -32,7 +32,7 @@ var _ sql.FunctionExpression = (*StringToVector)(nil) var _ sql.CollationCoercible = (*StringToVector)(nil) func NewStringToVector(e sql.Expression) sql.Expression { - return &StringToVector{UnaryExpression: expression.UnaryExpression{Child: e}} + return &StringToVector{UnaryExpressionStub: expression.UnaryExpressionStub{Child: e}} } func (s *StringToVector) FunctionName() string { @@ -82,7 +82,7 @@ func (s *StringToVector) Eval(ctx *sql.Context, row sql.Row) (interface{}, error // VectorToString converts a vector to a JSON string representation type VectorToString struct { - expression.UnaryExpression + expression.UnaryExpressionStub } var _ sql.Expression = (*VectorToString)(nil) @@ -90,7 +90,7 @@ var _ sql.FunctionExpression = (*VectorToString)(nil) var _ sql.CollationCoercible = (*VectorToString)(nil) func NewVectorToString(e sql.Expression) sql.Expression { - return &VectorToString{UnaryExpression: expression.UnaryExpression{Child: e}} + return &VectorToString{UnaryExpressionStub: expression.UnaryExpressionStub{Child: e}} } func (v *VectorToString) FunctionName() string { diff --git a/sql/expression/interval.go b/sql/expression/interval.go index 6f4dee4dca..a97aac2a43 100644 --- a/sql/expression/interval.go +++ b/sql/expression/interval.go @@ -30,7 +30,7 @@ import ( // Interval defines a time duration. type Interval struct { - UnaryExpression + UnaryExpressionStub Unit string } @@ -39,7 +39,7 @@ var _ sql.CollationCoercible = (*Interval)(nil) // NewInterval creates a new interval expression. func NewInterval(child sql.Expression, unit string) *Interval { - return &Interval{UnaryExpression{Child: child}, strings.ToUpper(unit)} + return &Interval{UnaryExpressionStub{Child: child}, strings.ToUpper(unit)} } // Type implements the sql.Expression interface. diff --git a/sql/expression/isnull.go b/sql/expression/isnull.go index 109e915b86..d5ace9166d 100644 --- a/sql/expression/isnull.go +++ b/sql/expression/isnull.go @@ -21,7 +21,7 @@ import ( // IsNull is an expression that checks if an expression is null. type IsNull struct { - UnaryExpression + UnaryExpressionStub } var _ sql.Expression = (*IsNull)(nil) @@ -30,7 +30,7 @@ var _ sql.IsNullExpression = (*IsNull)(nil) // NewIsNull creates a new IsNull expression. func NewIsNull(child sql.Expression) *IsNull { - return &IsNull{UnaryExpression{child}} + return &IsNull{UnaryExpressionStub{child}} } // IsNullExpression implements the sql.IsNullExpression interface. This function exsists primarily diff --git a/sql/expression/istrue.go b/sql/expression/istrue.go index 4e2df2021a..2f4cc2a89e 100644 --- a/sql/expression/istrue.go +++ b/sql/expression/istrue.go @@ -23,7 +23,7 @@ import ( // IsTrue is an expression that checks if an expression is true. type IsTrue struct { - UnaryExpression + UnaryExpressionStub invert bool } @@ -35,12 +35,12 @@ const IsFalseStr = "IS FALSE" // NewIsTrue creates a new IsTrue expression. func NewIsTrue(child sql.Expression) *IsTrue { - return &IsTrue{UnaryExpression: UnaryExpression{child}} + return &IsTrue{UnaryExpressionStub: UnaryExpressionStub{child}} } // NewIsFalse creates a new IsTrue expression with its boolean sense inverted (IsFalse, effectively). func NewIsFalse(child sql.Expression) *IsTrue { - return &IsTrue{UnaryExpression: UnaryExpression{child}, invert: true} + return &IsTrue{UnaryExpressionStub: UnaryExpressionStub{child}, invert: true} } // Type implements the Expression interface. diff --git a/sql/transform/expr.go b/sql/transform/expr.go index d7f6b96c94..7727194188 100644 --- a/sql/transform/expr.go +++ b/sql/transform/expr.go @@ -96,14 +96,31 @@ func Exprs(e []sql.Expression, f ExprFunc) ([]sql.Expression, TreeIdentity, erro return newExprs, NewTree, nil } -// InspectExpr traverses the given expression tree from the bottom up, breaking if -// stop = true. Returns a bool indicating whether traversal was interrupted. -func InspectExpr(expr sql.Expression, f func(sql.Expression) bool) bool { - children := expr.Children() - for _, child := range children { - if InspectExpr(child, f) { +// InspectExpr performs a post-order traversal of the sql.Expression tree; +// First, `f` is called on `expr.Children()` and if stop = false, then InspectExpr is recursively called on node's +// children. +// TODO: this conflicts with transform.Inspect which performs a pre-order traversal and stops when cont = false. +func InspectExpr(expr sql.Expression, f func(sql.Expression) bool) (stop bool) { + // Avoid allocating []sql.Expression + switch e := expr.(type) { + case expression.UnaryExpression: + if InspectExpr(e.UnaryChild(), f) { + return true + } + case expression.BinaryExpression: + if InspectExpr(e.Left(), f) { return true } + if InspectExpr(e.Right(), f) { + return true + } + default: + children := e.Children() + for _, child := range children { + if InspectExpr(child, f) { + return true + } + } } if f(expr) { return true diff --git a/sql/transform/walk.go b/sql/transform/walk.go index f9d1e64aa0..31d776c8c3 100644 --- a/sql/transform/walk.go +++ b/sql/transform/walk.go @@ -39,8 +39,6 @@ func Walk(v Visitor, node sql.Node) { for _, child := range node.Children() { Walk(v, child) } - - v.Visit(nil) } type inspector func(sql.Node) bool @@ -52,12 +50,27 @@ func (f inspector) Visit(node sql.Node) Visitor { return nil } -// Inspect traverses the plan in depth-first order: It starts by calling -// f(node); node must not be nil. If f returns true, Inspect invokes f -// recursively for each of the children of node, followed by a call of -// f(nil). -func Inspect(node sql.Node, f func(sql.Node) bool) { - Walk(inspector(f), node) +// Inspect performs a pre-order traversal of the sql.Node tree; +// First, it does f(node) and if cont = true, then Inspect is recursively called on node's children. +// TODO: this conflicts with transform.InspectExpr which performs a post-order traversal and stops when stop = true. +func Inspect(node sql.Node, f func(sql.Node) bool) (cont bool) { + if !f(node) { + return false + } + + // Avoid allocating []sql.Expression + switch n := node.(type) { + case sql.UnaryNode: + Inspect(n.Child(), f) + case sql.BinaryNode: + Inspect(n.Left(), f) + Inspect(n.Right(), f) + default: + for _, child := range n.Children() { + Inspect(child, f) + } + } + return true } // WalkExpressions traverses the plan and calls sql.Walk on any expression it finds. diff --git a/sql/transform/walk_test.go b/sql/transform/walk_test.go index 94b270c3f9..a96bbb2a56 100644 --- a/sql/transform/walk_test.go +++ b/sql/transform/walk_test.go @@ -39,7 +39,7 @@ func TestWalk(t *testing.T) { Walk(f, a3) require.Equal(t, - []sql.Node{a3, a2, c1, a1, nil, b1, nil, nil, nil, nil}, + []sql.Node{a3, a2, c1, a1, b1}, visited, ) @@ -55,7 +55,7 @@ func TestWalk(t *testing.T) { Walk(f, a3) require.Equal(t, - []sql.Node{a3, a2, c1, nil, nil}, + []sql.Node{a3, a2, c1}, visited, ) } @@ -83,7 +83,7 @@ func TestInspect(t *testing.T) { Inspect(a3, f) require.Equal(t, - []sql.Node{a3, a2, c1, a1, nil, b1, nil, nil, nil, nil}, + []sql.Node{a3, a2, c1, a1, b1}, visited, ) @@ -99,7 +99,7 @@ func TestInspect(t *testing.T) { Inspect(a3, f) require.Equal(t, - []sql.Node{a3, a2, c1, nil, nil}, + []sql.Node{a3, a2, c1}, visited, ) }