Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion optgen/cmd/support/agg_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (g *AggGen) genAggConstructor(define AggDef) {
fmt.Fprintf(g.w, "func New%s(e sql.Expression) *%s {\n", define.Name, define.Name)
fmt.Fprintf(g.w, " return &%s{\n", define.Name)
fmt.Fprintf(g.w, " unaryAggBase{\n")
fmt.Fprintf(g.w, " UnaryExpression: expression.UnaryExpression{Child: e},\n")
fmt.Fprintf(g.w, " Child: e,\n")
fmt.Fprintf(g.w, " functionName: \"%s\",\n", define.Name)
fmt.Fprintf(g.w, " description: \"%s\",\n", define.Desc)
fmt.Fprintf(g.w, " },\n")
Expand Down
2 changes: 1 addition & 1 deletion optgen/cmd/support/agg_gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestAggGen(t *testing.T) {
func NewTest(e sql.Expression) *Test {
return &Test{
unaryAggBase{
UnaryExpression: expression.UnaryExpression{Child: e},
Child: e,
functionName: "Test",
description: "Test description",
},
Expand Down
4 changes: 2 additions & 2 deletions sql/analyzer/replace_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ func replaceAgg(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope,
var sf sql.SortField
switch agg := gb.SelectDeps[0].(type) {
case *aggregation.Max:
gf, ok := agg.UnaryExpression.Child.(*expression.GetField)
gf, ok := agg.Child.(*expression.GetField)
if !ok {
return n, transform.SameTree, nil
}
Expand All @@ -382,7 +382,7 @@ func replaceAgg(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope,
Order: sql.Descending,
}
case *aggregation.Min:
gf, ok := agg.UnaryExpression.Child.(*expression.GetField)
gf, ok := agg.Child.(*expression.GetField)
if !ok {
return n, transform.SameTree, nil
}
Expand Down
20 changes: 15 additions & 5 deletions sql/analyzer/symbol_resolution.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,22 @@ func findSubqueryExpr(n sql.Node) *plan.Subquery {
// hasMatchAgainstExpr searches for an *expression.MatchAgainst within the node's expressions
func hasMatchAgainstExpr(node sql.Node) bool {
var foundMatchAgainstExpr bool
transform.InspectExpressions(node, func(expr sql.Expression) bool {
_, isMatchAgainstExpr := expr.(*expression.MatchAgainst)
if isMatchAgainstExpr {
foundMatchAgainstExpr = true
transform.Inspect(node, func(n sql.Node) (cont bool) {
if ne, ok := n.(sql.Expressioner); ok {
for _, expr := range ne.Expressions() {
stop := transform.InspectExpr(expr, func(e sql.Expression) (stop bool) {
if _, isMatchAgainst := e.(*expression.MatchAgainst); isMatchAgainst {
foundMatchAgainstExpr = true
return true
}
return false
})
if stop {
return false
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this backwards? Don't you want to stop as soon as you find a foundMatchAgainst expr?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does stop once foundMatchAgainst expr is found. The logic is backwards for transform.Inspect vs transform.InspectExpr. I was initially going to swap the logic in this PR, but the changes quickly got unwieldy, so I left the TODOs for now.

}
}
}
return !foundMatchAgainstExpr
return true
})
return foundMatchAgainstExpr
}
Expand Down
4 changes: 2 additions & 2 deletions sql/expression/alias.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ var _ sql.CollationCoercible = (*AliasReference)(nil)

// Alias is a node that gives a name to an expression.
type Alias struct {
UnaryExpression
UnaryExpressionStub
name string
unreferencable bool
id sql.ColumnId
Expand All @@ -92,7 +92,7 @@ var _ sql.CollationCoercible = (*Alias)(nil)

// NewAlias returns a new Alias node.
func NewAlias(name string, expr sql.Expression) *Alias {
return &Alias{UnaryExpression{expr}, name, false, 0}
return &Alias{UnaryExpressionStub{expr}, name, false, 0}
}

// AsUnreferencable marks the alias outside of scope referencing
Expand Down
4 changes: 2 additions & 2 deletions sql/expression/arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -681,15 +681,15 @@ func mult(lval, rval interface{}) (interface{}, error) {

// UnaryMinus is an unary minus operator.
type UnaryMinus struct {
UnaryExpression
UnaryExpressionStub
}

var _ sql.Expression = (*UnaryMinus)(nil)
var _ sql.CollationCoercible = (*UnaryMinus)(nil)

// NewUnaryMinus creates a new UnaryMinus expression node.
func NewUnaryMinus(child sql.Expression) *UnaryMinus {
return &UnaryMinus{UnaryExpression{Child: child}}
return &UnaryMinus{UnaryExpressionStub{Child: child}}
}

// Eval implements the sql.Expression interface.
Expand Down
8 changes: 4 additions & 4 deletions sql/expression/auto_increment.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ var (

// AutoIncrement implements AUTO_INCREMENT
type AutoIncrement struct {
UnaryExpression
UnaryExpressionStub
autoTbl sql.AutoIncrementTable
autoCol *sql.Column
}
Expand All @@ -58,7 +58,7 @@ func NewAutoIncrement(ctx *sql.Context, table sql.Table, given sql.Expression) (
}

return &AutoIncrement{
UnaryExpression{Child: given},
UnaryExpressionStub{Child: given},
autoTbl,
autoCol,
}, nil
Expand All @@ -72,7 +72,7 @@ func NewAutoIncrementForColumn(ctx *sql.Context, table sql.Table, autoCol *sql.C
}

return &AutoIncrement{
UnaryExpression{Child: given},
UnaryExpressionStub{Child: given},
autoTbl,
autoCol,
}, nil
Expand Down Expand Up @@ -159,7 +159,7 @@ func (i *AutoIncrement) WithChildren(children ...sql.Expression) (sql.Expression
return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 1)
}
return &AutoIncrement{
UnaryExpression{Child: children[0]},
UnaryExpressionStub{Child: children[0]},
i.autoTbl,
i.autoCol,
}, nil
Expand Down
12 changes: 6 additions & 6 deletions sql/expression/auto_uuid.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
// AutoUuid is an expression that captures an automatically generated UUID value and stores it in the session for
// later retrieval. AutoUuid is intended to only be used directly on top of a UUID function.
type AutoUuid struct {
UnaryExpression
UnaryExpressionStub
uuidCol *sql.Column
foundUuid bool
}
Expand All @@ -38,8 +38,8 @@ var _ sql.CollationCoercible = (*AutoUuid)(nil)
// because of package import cycles, we can't enforce that directly here.
func NewAutoUuid(_ *sql.Context, col *sql.Column, child sql.Expression) *AutoUuid {
return &AutoUuid{
UnaryExpression: UnaryExpression{Child: child},
uuidCol: col,
UnaryExpressionStub: UnaryExpressionStub{Child: child},
uuidCol: col,
}
}

Expand Down Expand Up @@ -94,9 +94,9 @@ func (au *AutoUuid) WithChildren(children ...sql.Expression) (sql.Expression, er
return nil, sql.ErrInvalidChildrenNumber.New(au, len(children), 1)
}
return &AutoUuid{
UnaryExpression: UnaryExpression{Child: children[0]},
uuidCol: au.uuidCol,
foundUuid: au.foundUuid,
UnaryExpressionStub: UnaryExpressionStub{Child: children[0]},
uuidCol: au.uuidCol,
foundUuid: au.foundUuid,
}, nil
}

Expand Down
4 changes: 2 additions & 2 deletions sql/expression/binary.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ import (
//
// cc: https://dev.mysql.com/doc/refman/8.0/en/cast-functions.html#operator_binary
type Binary struct {
UnaryExpression
UnaryExpressionStub
}

var _ sql.Expression = (*Binary)(nil)
var _ sql.CollationCoercible = (*Binary)(nil)

func NewBinary(e sql.Expression) sql.Expression {
return &Binary{UnaryExpression{Child: e}}
return &Binary{UnaryExpressionStub{Child: e}}
}

func (b *Binary) String() string {
Expand Down
4 changes: 2 additions & 2 deletions sql/expression/boolean.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ import (

// Not is a node that negates an expression.
type Not struct {
UnaryExpression
UnaryExpressionStub
}

var _ sql.Expression = (*Not)(nil)
var _ sql.CollationCoercible = (*Not)(nil)

// NewNot returns a new Not node.
func NewNot(child sql.Expression) *Not {
return &Not{UnaryExpression{child}}
return &Not{UnaryExpressionStub{child}}
}

// Type implements the Expression interface.
Expand Down
32 changes: 21 additions & 11 deletions sql/expression/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,39 +32,49 @@ func IsBinary(e sql.Expression) bool {
return len(e.Children()) == 2
}

// UnaryExpression is an expression that has only one child.
type UnaryExpression struct {
type UnaryExpression interface {
sql.Expression
UnaryChild() sql.Expression
}

// UnaryExpressionStub is an expression that has only one child.
type UnaryExpressionStub struct {
Child sql.Expression
}

// UnaryChild implements the UnaryExpression interface.
func (p *UnaryExpressionStub) UnaryChild() sql.Expression {
return p.Child
}

// Children implements the Expression interface.
func (p *UnaryExpression) Children() []sql.Expression {
func (p *UnaryExpressionStub) Children() []sql.Expression {
return []sql.Expression{p.Child}
}

// Resolved implements the Expression interface.
func (p *UnaryExpression) Resolved() bool {
func (p *UnaryExpressionStub) Resolved() bool {
return p.Child.Resolved()
}

// IsNullable returns whether the expression can be null.
func (p *UnaryExpression) IsNullable() bool {
func (p *UnaryExpressionStub) IsNullable() bool {
return p.Child.IsNullable()
}

// BinaryExpressionStub is an expression that has two children.
type BinaryExpressionStub struct {
LeftChild sql.Expression
RightChild sql.Expression
}

// BinaryExpression is an expression that has two children
type BinaryExpression interface {
sql.Expression
Left() sql.Expression
Right() sql.Expression
}

// BinaryExpressionStub is an expression that has two children.
type BinaryExpressionStub struct {
LeftChild sql.Expression
RightChild sql.Expression
}

func (p *BinaryExpressionStub) Left() sql.Expression {
return p.LeftChild
}
Expand Down
14 changes: 7 additions & 7 deletions sql/expression/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ const (

// Convert represent a CAST(x AS T) or CONVERT(x, T) operation that casts x expression to type T.
type Convert struct {
UnaryExpression
UnaryExpressionStub

// cachedDecimalType is the cached Decimal type for this convert expression. Because new Decimal types
// must be created with their specific scale and precision values, unlike other types, we cache the created
Expand All @@ -88,8 +88,8 @@ var _ sql.CollationCoercible = (*Convert)(nil)
func NewConvert(expr sql.Expression, castToType string) *Convert {
disableRounding(expr)
return &Convert{
UnaryExpression: UnaryExpression{Child: expr},
castToType: strings.ToLower(castToType),
UnaryExpressionStub: UnaryExpressionStub{Child: expr},
castToType: strings.ToLower(castToType),
}
}

Expand All @@ -99,10 +99,10 @@ func NewConvert(expr sql.Expression, castToType string) *Convert {
func NewConvertWithLengthAndScale(expr sql.Expression, castToType string, typeLength, typeScale int) *Convert {
disableRounding(expr)
return &Convert{
UnaryExpression: UnaryExpression{Child: expr},
castToType: strings.ToLower(castToType),
typeLength: typeLength,
typeScale: typeScale,
UnaryExpressionStub: UnaryExpressionStub{Child: expr},
castToType: strings.ToLower(castToType),
typeLength: typeLength,
typeScale: typeScale,
}
}

Expand Down
6 changes: 3 additions & 3 deletions sql/expression/convertusing.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (

// ConvertUsing represents a CONVERT(X USING T) operation that casts the expression X to the character set T.
type ConvertUsing struct {
UnaryExpression
UnaryExpressionStub
TargetCharSet sql.CharacterSetID
}

Expand All @@ -32,8 +32,8 @@ var _ sql.CollationCoercible = (*ConvertUsing)(nil)

func NewConvertUsing(expr sql.Expression, targetCharSet sql.CharacterSetID) *ConvertUsing {
return &ConvertUsing{
UnaryExpression: UnaryExpression{Child: expr},
TargetCharSet: targetCharSet,
UnaryExpressionStub: UnaryExpressionStub{Child: expr},
TargetCharSet: targetCharSet,
}
}

Expand Down
4 changes: 2 additions & 2 deletions sql/expression/function/absval.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ import (

// AbsVal is a function that takes the absolute value of a number
type AbsVal struct {
expression.UnaryExpression
expression.UnaryExpressionStub
}

var _ sql.FunctionExpression = (*AbsVal)(nil)
var _ sql.CollationCoercible = (*AbsVal)(nil)

// NewAbsVal creates a new AbsVal expression.
func NewAbsVal(e sql.Expression) sql.Expression {
return &AbsVal{expression.UnaryExpression{Child: e}}
return &AbsVal{expression.UnaryExpressionStub{Child: e}}
}

// FunctionName implements sql.FunctionExpression
Expand Down
12 changes: 9 additions & 3 deletions sql/expression/function/aggregation/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ var ErrEvalUnsupportedOnAggregation = errors.NewKind("Unimplemented %s.Eval(). T
// unaryAggBase is the generic embedded class optgen
// uses to codegen single expression aggregate functions.
type unaryAggBase struct {
expression.UnaryExpression
Child sql.Expression
typ sql.Type
window *sql.WindowDefinition
functionName string
Expand Down Expand Up @@ -76,6 +76,11 @@ func (a *unaryAggBase) WithId(id sql.ColumnId) sql.IdExpression {
return &ret
}

// IsNullable returns whether the expression can be null.
func (a *unaryAggBase) IsNullable() bool {
return a.Child.IsNullable()
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (a *unaryAggBase) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.GetCoercibility(ctx, a.Child)
Expand All @@ -96,7 +101,8 @@ func (a *unaryAggBase) Children() []sql.Expression {
func (a *unaryAggBase) Resolved() bool {
if _, ok := a.Child.(*expression.Star); ok {
return true
} else if !a.Child.Resolved() {
}
if !a.Child.Resolved() {
return false
}
if a.window == nil {
Expand All @@ -112,7 +118,7 @@ func (a *unaryAggBase) WithChildren(children ...sql.Expression) (sql.Expression,
}

na := *a
na.UnaryExpression = expression.UnaryExpression{Child: children[0]}
na.Child = children[0]
if len(children) > 1 && a.window != nil {
w, err := a.window.FromExpressions(children[1:])
if err != nil {
Expand Down
Loading
Loading