-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Fix scalar aggregation with literals in empty result sets #18477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2eb8f48
32a2621
1c80255
c59f31f
386ae98
e944d81
7b9c3c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,7 +35,12 @@ import ( | |
| // It contains the opcode and input column number. | ||
| type AggregateParams struct { | ||
| Opcode opcode.AggregateOpcode | ||
| Col int | ||
|
|
||
| // Input source specification - exactly one of these should be set: | ||
| // Col: Column index for simple column references (e.g., SUM(column_name)) | ||
| // EExpr: Evaluated expression for literals, parameters | ||
| Col int | ||
| EExpr evalengine.Expr | ||
|
|
||
| // These are used only for distinct opcodes. | ||
| KeyCol int | ||
|
|
@@ -53,15 +58,26 @@ type AggregateParams struct { | |
| CollationEnv *collations.Environment | ||
| } | ||
|
|
||
| func NewAggregateParam(opcode opcode.AggregateOpcode, col int, alias string, collationEnv *collations.Environment) *AggregateParams { | ||
| // NewAggregateParam creates a new aggregate param | ||
| func NewAggregateParam( | ||
| oc opcode.AggregateOpcode, | ||
| col int, | ||
| expr evalengine.Expr, | ||
| alias string, | ||
| collationEnv *collations.Environment, | ||
| ) *AggregateParams { | ||
| if expr != nil && oc != opcode.AggregateConstant { | ||
| panic(vterrors.VT13001("expr should be nil")) | ||
|
Comment on lines
+69
to
+70
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: this could have a better error message, one that explains why `expr` is expected to be nil.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is just defensive coding. If this ever happens, we messed up somewhere else. |
||
| } | ||
| out := &AggregateParams{ | ||
| Opcode: opcode, | ||
| Opcode: oc, | ||
| Col: col, | ||
| EExpr: expr, | ||
| Alias: alias, | ||
| WCol: -1, | ||
| CollationEnv: collationEnv, | ||
| } | ||
| if opcode.NeedsComparableValues() { | ||
| if oc.NeedsComparableValues() { | ||
| out.KeyCol = col | ||
| } | ||
| return out | ||
|
|
@@ -73,6 +89,9 @@ func (ap *AggregateParams) WAssigned() bool { | |
|
|
||
| func (ap *AggregateParams) String() string { | ||
| keyCol := strconv.Itoa(ap.Col) | ||
| if ap.EExpr != nil { | ||
| keyCol = sqlparser.String(ap.EExpr) | ||
| } | ||
| if ap.WAssigned() { | ||
| keyCol = fmt.Sprintf("%s|%d", keyCol, ap.WCol) | ||
| } | ||
|
|
@@ -89,7 +108,14 @@ func (ap *AggregateParams) String() string { | |
| return fmt.Sprintf("%s%s(%s)", ap.Opcode.String(), dispOrigOp, keyCol) | ||
| } | ||
|
|
||
| func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type { | ||
| func (ap *AggregateParams) typ(inputType querypb.Type, env *evalengine.ExpressionEnv, collID collations.ID) querypb.Type { | ||
| if ap.EExpr != nil { | ||
| value, err := eval(env, ap.EExpr, collID) | ||
| if err != nil { | ||
| return sqltypes.Unknown | ||
| } | ||
| return value.Type() | ||
| } | ||
| if ap.OrigOpcode != opcode.AggregateUnassigned { | ||
| return ap.OrigOpcode.SQLType(inputType) | ||
| } | ||
|
|
@@ -98,7 +124,7 @@ func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type { | |
|
|
||
| type aggregator interface { | ||
| add(row []sqltypes.Value) error | ||
| finish() sqltypes.Value | ||
| finish(env *evalengine.ExpressionEnv, coll collations.ID) (sqltypes.Value, error) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These new params are used only by `aggregatorConstant`. Can they be made part of the struct rather than the interface? |
||
| reset() | ||
| } | ||
|
|
||
|
|
@@ -151,8 +177,8 @@ func (a *aggregatorCount) add(row []sqltypes.Value) error { | |
| return nil | ||
| } | ||
|
|
||
| func (a *aggregatorCount) finish() sqltypes.Value { | ||
| return sqltypes.NewInt64(a.n) | ||
| func (a *aggregatorCount) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) { | ||
| return sqltypes.NewInt64(a.n), nil | ||
| } | ||
|
|
||
| func (a *aggregatorCount) reset() { | ||
|
|
@@ -164,13 +190,13 @@ type aggregatorCountStar struct { | |
| n int64 | ||
| } | ||
|
|
||
| func (a *aggregatorCountStar) add(_ []sqltypes.Value) error { | ||
| func (a *aggregatorCountStar) add([]sqltypes.Value) error { | ||
| a.n++ | ||
| return nil | ||
| } | ||
|
|
||
| func (a *aggregatorCountStar) finish() sqltypes.Value { | ||
| return sqltypes.NewInt64(a.n) | ||
| func (a *aggregatorCountStar) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) { | ||
| return sqltypes.NewInt64(a.n), nil | ||
| } | ||
|
|
||
| func (a *aggregatorCountStar) reset() { | ||
|
|
@@ -198,8 +224,8 @@ func (a *aggregatorMax) add(row []sqltypes.Value) (err error) { | |
| return a.minmax.Max(row[a.from]) | ||
| } | ||
|
|
||
| func (a *aggregatorMinMax) finish() sqltypes.Value { | ||
| return a.minmax.Result() | ||
| func (a *aggregatorMinMax) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) { | ||
| return a.minmax.Result(), nil | ||
| } | ||
|
|
||
| func (a *aggregatorMinMax) reset() { | ||
|
|
@@ -222,8 +248,8 @@ func (a *aggregatorSum) add(row []sqltypes.Value) error { | |
| return a.sum.Add(row[a.from]) | ||
| } | ||
|
|
||
| func (a *aggregatorSum) finish() sqltypes.Value { | ||
| return a.sum.Result() | ||
| func (a *aggregatorSum) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) { | ||
| return a.sum.Result(), nil | ||
| } | ||
|
|
||
| func (a *aggregatorSum) reset() { | ||
|
|
@@ -232,28 +258,51 @@ func (a *aggregatorSum) reset() { | |
| } | ||
|
|
||
| type aggregatorScalar struct { | ||
| from int | ||
| current sqltypes.Value | ||
| init bool | ||
| from int | ||
| current sqltypes.Value | ||
| hasValue bool | ||
| } | ||
|
|
||
| func (a *aggregatorScalar) add(row []sqltypes.Value) error { | ||
| if !a.init { | ||
| if !a.hasValue { | ||
| a.current = row[a.from] | ||
| a.init = true | ||
| a.hasValue = true | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| func (a *aggregatorScalar) finish() sqltypes.Value { | ||
| return a.current | ||
| func (a *aggregatorScalar) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) { | ||
| return a.current, nil | ||
| } | ||
|
|
||
| func (a *aggregatorScalar) reset() { | ||
| a.current = sqltypes.NULL | ||
| a.init = false | ||
| a.hasValue = false | ||
| } | ||
|
|
||
| type aggregatorConstant struct { | ||
| expr evalengine.Expr | ||
| } | ||
|
|
||
| func (*aggregatorConstant) add([]sqltypes.Value) error { | ||
| return nil | ||
| } | ||
|
|
||
| func (a *aggregatorConstant) finish(env *evalengine.ExpressionEnv, coll collations.ID) (sqltypes.Value, error) { | ||
| return eval(env, a.expr, coll) | ||
| } | ||
|
|
||
| func eval(env *evalengine.ExpressionEnv, eexpr evalengine.Expr, coll collations.ID) (sqltypes.Value, error) { | ||
| v, err := env.Evaluate(eexpr) | ||
| if err != nil { | ||
| return sqltypes.Value{}, err | ||
| } | ||
|
|
||
| return v.Value(coll), nil | ||
| } | ||
|
|
||
| func (*aggregatorConstant) reset() {} | ||
|
|
||
| type aggregatorGroupConcat struct { | ||
| from int | ||
| type_ sqltypes.Type | ||
|
|
@@ -275,11 +324,11 @@ func (a *aggregatorGroupConcat) add(row []sqltypes.Value) error { | |
| return nil | ||
| } | ||
|
|
||
| func (a *aggregatorGroupConcat) finish() sqltypes.Value { | ||
| func (a *aggregatorGroupConcat) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) { | ||
| if a.n == 0 { | ||
| return sqltypes.NULL | ||
| return sqltypes.NULL, nil | ||
| } | ||
| return sqltypes.MakeTrusted(a.type_, a.concat) | ||
| return sqltypes.MakeTrusted(a.type_, a.concat), nil | ||
| } | ||
|
|
||
| func (a *aggregatorGroupConcat) reset() { | ||
|
|
@@ -301,36 +350,44 @@ func (a *aggregatorGtid) add(row []sqltypes.Value) error { | |
| return nil | ||
| } | ||
|
|
||
| func (a *aggregatorGtid) finish() sqltypes.Value { | ||
| func (a *aggregatorGtid) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) { | ||
| gtid := binlogdatapb.VGtid{ShardGtids: a.shards} | ||
| return sqltypes.NewVarChar(gtid.String()) | ||
| return sqltypes.NewVarChar(gtid.String()), nil | ||
| } | ||
|
|
||
| func (a *aggregatorGtid) reset() { | ||
| a.shards = a.shards[:0] // safe to reuse because only the serialized form of a.shards is returned | ||
| } | ||
|
|
||
| type aggregationState []aggregator | ||
| type aggregationState struct { | ||
| env *evalengine.ExpressionEnv | ||
| aggregators []aggregator | ||
| coll collations.ID | ||
| } | ||
|
|
||
| func (a aggregationState) add(row []sqltypes.Value) error { | ||
| for _, st := range a { | ||
| func (a *aggregationState) add(row []sqltypes.Value) error { | ||
| for _, st := range a.aggregators { | ||
| if err := st.add(row); err != nil { | ||
| return err | ||
| } | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| func (a aggregationState) finish() (row []sqltypes.Value) { | ||
| row = make([]sqltypes.Value, 0, len(a)) | ||
| for _, st := range a { | ||
| row = append(row, st.finish()) | ||
| func (a *aggregationState) finish() ([]sqltypes.Value, error) { | ||
| row := make([]sqltypes.Value, 0, len(a.aggregators)) | ||
| for _, st := range a.aggregators { | ||
| v, err := st.finish(a.env, a.coll) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| row = append(row, v) | ||
| } | ||
| return | ||
| return row, nil | ||
| } | ||
|
|
||
| func (a aggregationState) reset() { | ||
| for _, st := range a { | ||
| func (a *aggregationState) reset() { | ||
| for _, st := range a.aggregators { | ||
| st.reset() | ||
| } | ||
| } | ||
|
|
@@ -354,13 +411,16 @@ func isComparable(typ sqltypes.Type) bool { | |
| return false | ||
| } | ||
|
|
||
| func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (aggregationState, []*querypb.Field, error) { | ||
| func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams, env *evalengine.ExpressionEnv, collation collations.ID) (*aggregationState, []*querypb.Field, error) { | ||
| fields = slice.Map(fields, func(from *querypb.Field) *querypb.Field { return from.CloneVT() }) | ||
|
|
||
| agstate := make([]aggregator, len(fields)) | ||
| aggregators := make([]aggregator, len(fields)) | ||
| for _, aggr := range aggregates { | ||
| sourceType := fields[aggr.Col].Type | ||
| targetType := aggr.typ(sourceType) | ||
| var sourceType querypb.Type | ||
| if aggr.Col < len(fields) { | ||
| sourceType = fields[aggr.Col].Type | ||
| } | ||
| targetType := aggr.typ(sourceType, env, collation) | ||
|
|
||
| var ag aggregator | ||
| var distinct = -1 | ||
|
|
@@ -444,22 +504,25 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg | |
| separator: separator, | ||
| } | ||
|
|
||
| case opcode.AggregateConstant: | ||
| ag = &aggregatorConstant{expr: aggr.EExpr} | ||
|
|
||
| default: | ||
| panic("BUG: unexpected Aggregation opcode") | ||
| } | ||
|
|
||
| agstate[aggr.Col] = ag | ||
| aggregators[aggr.Col] = ag | ||
| fields[aggr.Col].Type = targetType | ||
| if aggr.Alias != "" { | ||
| fields[aggr.Col].Name = aggr.Alias | ||
| } | ||
| } | ||
|
|
||
| for i, a := range agstate { | ||
| for i, a := range aggregators { | ||
| if a == nil { | ||
| agstate[i] = &aggregatorScalar{from: i} | ||
| aggregators[i] = &aggregatorScalar{from: i} | ||
| } | ||
| } | ||
|
|
||
| return agstate, fields, nil | ||
| return &aggregationState{aggregators: aggregators, env: env, coll: collation}, fields, nil | ||
| } | ||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess if `EExpr` is set, the value in `Col` is ignored? I'm asking because it says "exactly one should be set", but `0` (the default / empty value for `Col`) is a valid value, and I was confused for a moment about what this means then.