diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index e429d3f4bb7..d978f037463 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -80,6 +80,9 @@ func TestAggrWithLimit(t *testing.T) { mcmp.Exec(fmt.Sprintf("insert into aggr_test(id, val1, val2) values(%d, 'a', %d)", i, r)) } mcmp.Exec("select val2, count(*) from aggr_test group by val2 order by count(*), val2 limit 10") + if utils.BinaryIsAtLeastAtVersion(23, "vtgate") { + mcmp.Exec("SELECT 1 AS `id`, COUNT(*) FROM (SELECT `id` FROM aggr_test WHERE val1 = 1 LIMIT 100) `t`") + } } func TestAggregateTypes(t *testing.T) { diff --git a/go/vt/sqlparser/analyzer.go b/go/vt/sqlparser/analyzer.go index c3a9b744c17..4315b146faa 100644 --- a/go/vt/sqlparser/analyzer.go +++ b/go/vt/sqlparser/analyzer.go @@ -19,6 +19,7 @@ package sqlparser // analyzer.go contains utility analysis functions. import ( + "errors" "fmt" "strings" "unicode" @@ -379,6 +380,20 @@ func IsColName(node Expr) bool { return ok } +var errNotStatic = errors.New("not static") + +// IsConstant returns true if the Expr can be evaluated without input or access to tables. +func IsConstant(node Expr) bool { + err := Walk(func(node SQLNode) (kontinue bool, err error) { + switch node.(type) { + case *ColName, *Subquery: + return false, errNotStatic + } + return true, nil + }, node) + return err == nil +} + // IsValue returns true if the Expr is a string, integral or value arg. // NULL is not considered to be a value. func IsValue(node Expr) bool { diff --git a/go/vt/vtgate/engine/aggregations.go b/go/vt/vtgate/engine/aggregations.go index fa0342155a8..f003d647b72 100644 --- a/go/vt/vtgate/engine/aggregations.go +++ b/go/vt/vtgate/engine/aggregations.go @@ -35,7 +35,12 @@ import ( // It contains the opcode and input column number. 
type AggregateParams struct { Opcode opcode.AggregateOpcode - Col int + + // Input source specification - exactly one of these should be set: + // Col: Column index for simple column references (e.g., SUM(column_name)) + // EExpr: Evaluated expression for literals, parameters + Col int + EExpr evalengine.Expr // These are used only for distinct opcodes. KeyCol int @@ -53,15 +58,26 @@ type AggregateParams struct { CollationEnv *collations.Environment } -func NewAggregateParam(opcode opcode.AggregateOpcode, col int, alias string, collationEnv *collations.Environment) *AggregateParams { +// NewAggregateParam creates a new aggregate param +func NewAggregateParam( + oc opcode.AggregateOpcode, + col int, + expr evalengine.Expr, + alias string, + collationEnv *collations.Environment, +) *AggregateParams { + if expr != nil && oc != opcode.AggregateConstant { + panic(vterrors.VT13001("expr should be nil")) + } out := &AggregateParams{ - Opcode: opcode, + Opcode: oc, Col: col, + EExpr: expr, Alias: alias, WCol: -1, CollationEnv: collationEnv, } - if opcode.NeedsComparableValues() { + if oc.NeedsComparableValues() { out.KeyCol = col } return out @@ -73,6 +89,9 @@ func (ap *AggregateParams) WAssigned() bool { func (ap *AggregateParams) String() string { keyCol := strconv.Itoa(ap.Col) + if ap.EExpr != nil { + keyCol = sqlparser.String(ap.EExpr) + } if ap.WAssigned() { keyCol = fmt.Sprintf("%s|%d", keyCol, ap.WCol) } @@ -89,7 +108,14 @@ func (ap *AggregateParams) String() string { return fmt.Sprintf("%s%s(%s)", ap.Opcode.String(), dispOrigOp, keyCol) } -func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type { +func (ap *AggregateParams) typ(inputType querypb.Type, env *evalengine.ExpressionEnv, collID collations.ID) querypb.Type { + if ap.EExpr != nil { + value, err := eval(env, ap.EExpr, collID) + if err != nil { + return sqltypes.Unknown + } + return value.Type() + } if ap.OrigOpcode != opcode.AggregateUnassigned { return ap.OrigOpcode.SQLType(inputType) } @@ 
-98,7 +124,7 @@ func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type { type aggregator interface { add(row []sqltypes.Value) error - finish() sqltypes.Value + finish(env *evalengine.ExpressionEnv, coll collations.ID) (sqltypes.Value, error) reset() } @@ -151,8 +177,8 @@ func (a *aggregatorCount) add(row []sqltypes.Value) error { return nil } -func (a *aggregatorCount) finish() sqltypes.Value { - return sqltypes.NewInt64(a.n) +func (a *aggregatorCount) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) { + return sqltypes.NewInt64(a.n), nil } func (a *aggregatorCount) reset() { @@ -164,13 +190,13 @@ type aggregatorCountStar struct { n int64 } -func (a *aggregatorCountStar) add(_ []sqltypes.Value) error { +func (a *aggregatorCountStar) add([]sqltypes.Value) error { a.n++ return nil } -func (a *aggregatorCountStar) finish() sqltypes.Value { - return sqltypes.NewInt64(a.n) +func (a *aggregatorCountStar) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) { + return sqltypes.NewInt64(a.n), nil } func (a *aggregatorCountStar) reset() { @@ -198,8 +224,8 @@ func (a *aggregatorMax) add(row []sqltypes.Value) (err error) { return a.minmax.Max(row[a.from]) } -func (a *aggregatorMinMax) finish() sqltypes.Value { - return a.minmax.Result() +func (a *aggregatorMinMax) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) { + return a.minmax.Result(), nil } func (a *aggregatorMinMax) reset() { @@ -222,8 +248,8 @@ func (a *aggregatorSum) add(row []sqltypes.Value) error { return a.sum.Add(row[a.from]) } -func (a *aggregatorSum) finish() sqltypes.Value { - return a.sum.Result() +func (a *aggregatorSum) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) { + return a.sum.Result(), nil } func (a *aggregatorSum) reset() { @@ -232,28 +258,51 @@ func (a *aggregatorSum) reset() { } type aggregatorScalar struct { - from int - current sqltypes.Value - init bool + from int + current 
sqltypes.Value + hasValue bool } func (a *aggregatorScalar) add(row []sqltypes.Value) error { - if !a.init { + if !a.hasValue { a.current = row[a.from] - a.init = true + a.hasValue = true } return nil } -func (a *aggregatorScalar) finish() sqltypes.Value { - return a.current +func (a *aggregatorScalar) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) { + return a.current, nil } func (a *aggregatorScalar) reset() { a.current = sqltypes.NULL - a.init = false + a.hasValue = false +} + +type aggregatorConstant struct { + expr evalengine.Expr +} + +func (*aggregatorConstant) add([]sqltypes.Value) error { + return nil } +func (a *aggregatorConstant) finish(env *evalengine.ExpressionEnv, coll collations.ID) (sqltypes.Value, error) { + return eval(env, a.expr, coll) +} + +func eval(env *evalengine.ExpressionEnv, eexpr evalengine.Expr, coll collations.ID) (sqltypes.Value, error) { + v, err := env.Evaluate(eexpr) + if err != nil { + return sqltypes.Value{}, err + } + + return v.Value(coll), nil +} + +func (*aggregatorConstant) reset() {} + type aggregatorGroupConcat struct { from int type_ sqltypes.Type @@ -275,11 +324,11 @@ func (a *aggregatorGroupConcat) add(row []sqltypes.Value) error { return nil } -func (a *aggregatorGroupConcat) finish() sqltypes.Value { +func (a *aggregatorGroupConcat) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) { if a.n == 0 { - return sqltypes.NULL + return sqltypes.NULL, nil } - return sqltypes.MakeTrusted(a.type_, a.concat) + return sqltypes.MakeTrusted(a.type_, a.concat), nil } func (a *aggregatorGroupConcat) reset() { @@ -301,19 +350,23 @@ func (a *aggregatorGtid) add(row []sqltypes.Value) error { return nil } -func (a *aggregatorGtid) finish() sqltypes.Value { +func (a *aggregatorGtid) finish(*evalengine.ExpressionEnv, collations.ID) (sqltypes.Value, error) { gtid := binlogdatapb.VGtid{ShardGtids: a.shards} - return sqltypes.NewVarChar(gtid.String()) + return 
sqltypes.NewVarChar(gtid.String()), nil } func (a *aggregatorGtid) reset() { a.shards = a.shards[:0] // safe to reuse because only the serialized form of a.shards is returned } -type aggregationState []aggregator +type aggregationState struct { + env *evalengine.ExpressionEnv + aggregators []aggregator + coll collations.ID +} -func (a aggregationState) add(row []sqltypes.Value) error { - for _, st := range a { +func (a *aggregationState) add(row []sqltypes.Value) error { + for _, st := range a.aggregators { if err := st.add(row); err != nil { return err } @@ -321,16 +374,20 @@ func (a aggregationState) add(row []sqltypes.Value) error { return nil } -func (a aggregationState) finish() (row []sqltypes.Value) { - row = make([]sqltypes.Value, 0, len(a)) - for _, st := range a { - row = append(row, st.finish()) +func (a *aggregationState) finish() ([]sqltypes.Value, error) { + row := make([]sqltypes.Value, 0, len(a.aggregators)) + for _, st := range a.aggregators { + v, err := st.finish(a.env, a.coll) + if err != nil { + return nil, err + } + row = append(row, v) } - return + return row, nil } -func (a aggregationState) reset() { - for _, st := range a { +func (a *aggregationState) reset() { + for _, st := range a.aggregators { st.reset() } } @@ -354,13 +411,16 @@ func isComparable(typ sqltypes.Type) bool { return false } -func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (aggregationState, []*querypb.Field, error) { +func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams, env *evalengine.ExpressionEnv, collation collations.ID) (*aggregationState, []*querypb.Field, error) { fields = slice.Map(fields, func(from *querypb.Field) *querypb.Field { return from.CloneVT() }) - agstate := make([]aggregator, len(fields)) + aggregators := make([]aggregator, len(fields)) for _, aggr := range aggregates { - sourceType := fields[aggr.Col].Type - targetType := aggr.typ(sourceType) + var sourceType querypb.Type + if aggr.Col < len(fields) { 
+ sourceType = fields[aggr.Col].Type + } + targetType := aggr.typ(sourceType, env, collation) var ag aggregator var distinct = -1 @@ -444,22 +504,25 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg separator: separator, } + case opcode.AggregateConstant: + ag = &aggregatorConstant{expr: aggr.EExpr} + default: panic("BUG: unexpected Aggregation opcode") } - agstate[aggr.Col] = ag + aggregators[aggr.Col] = ag fields[aggr.Col].Type = targetType if aggr.Alias != "" { fields[aggr.Col].Name = aggr.Alias } } - for i, a := range agstate { + for i, a := range aggregators { if a == nil { - agstate[i] = &aggregatorScalar{from: i} + aggregators[i] = &aggregatorScalar{from: i} } } - return agstate, fields, nil + return &aggregationState{aggregators: aggregators, env: env, coll: collation}, fields, nil } diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 840f221b1fd..19f4d5a0519 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -29,7 +29,11 @@ func (cached *AggregateParams) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(112) + size += int64(128) + } + // field EExpr vitess.io/vitess/go/vt/vtgate/evalengine.Expr + if cc, ok := cached.EExpr.(cachedObject); ok { + size += cc.CachedSize(true) } // field Type vitess.io/vitess/go/vt/vtgate/evalengine.Type size += cached.Type.CachedSize(false) diff --git a/go/vt/vtgate/engine/opcode/constants.go b/go/vt/vtgate/engine/opcode/constants.go index 28c09de0fd6..511d85a3d98 100644 --- a/go/vt/vtgate/engine/opcode/constants.go +++ b/go/vt/vtgate/engine/opcode/constants.go @@ -77,8 +77,9 @@ const ( AggregateCountStar AggregateGroupConcat AggregateAvg - AggregateUDF // This is an opcode used to represent UDFs - _NumOfOpCodes // This line must be last of the opcodes! 
+ AggregateUDF // This is an opcode used to represent UDFs + AggregateConstant // This is an opcode used to represent constants that are not grouped + _NumOfOpCodes // This line must be last of the opcodes! ) // SupportedAggregates maps the list of supported aggregate @@ -97,6 +98,7 @@ var SupportedAggregates = map[string]AggregateOpcode{ "count_star": AggregateCountStar, "any_value": AggregateAnyValue, "group_concat": AggregateGroupConcat, + "constant_aggr": AggregateConstant, } var AggregateName = map[AggregateOpcode]string{ @@ -111,6 +113,7 @@ var AggregateName = map[AggregateOpcode]string{ AggregateGroupConcat: "group_concat", AggregateAnyValue: "any_value", AggregateAvg: "avg", + AggregateConstant: "constant_aggr", } func (code AggregateOpcode) String() string { @@ -127,7 +130,7 @@ func (code AggregateOpcode) MarshalJSON() ([]byte, error) { return ([]byte)(fmt.Sprintf("\"%s\"", code.String())), nil } -// Type returns the opcode return sql type, and a bool telling is we are sure about this type or not +// SQLType returns the opcode return sql type, and a bool telling is we are sure about this type or not func (code AggregateOpcode) SQLType(typ querypb.Type) querypb.Type { switch code { case AggregateUnassigned: @@ -154,7 +157,7 @@ func (code AggregateOpcode) SQLType(typ querypb.Type) querypb.Type { return sqltypes.Int64 case AggregateGtid: return sqltypes.VarChar - case AggregateUDF: + case AggregateUDF, AggregateConstant: // TODO: we can probably figure out the type here return sqltypes.Unknown default: panic(code.String()) // we have a unit test checking we never reach here diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index 324e531c4dd..2c2ffc290e6 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -139,6 +139,7 @@ func (oa *OrderedAggregate) execute(ctx context.Context, vcursor VCursor, bindVa bindVars, true, /*wantFields - we need the input fields 
types to correctly calculate the output types*/ ) + env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) if err != nil { return nil, err } @@ -146,7 +147,7 @@ func (oa *OrderedAggregate) execute(ctx context.Context, vcursor VCursor, bindVa return oa.executeGroupBy(result) } - agg, fields, err := newAggregation(result.Fields, oa.Aggregates) + agg, fields, err := newAggregation(result.Fields, oa.Aggregates, env, vcursor.ConnCollation()) if err != nil { return nil, err } @@ -166,7 +167,11 @@ func (oa *OrderedAggregate) execute(ctx context.Context, vcursor VCursor, bindVa } if nextGroup { - out.Rows = append(out.Rows, agg.finish()) + values, err := agg.finish() + if err != nil { + return nil, err + } + out.Rows = append(out.Rows, values) agg.reset() } @@ -176,7 +181,11 @@ func (oa *OrderedAggregate) execute(ctx context.Context, vcursor VCursor, bindVa } if currentKey != nil { - out.Rows = append(out.Rows, agg.finish()) + values, err := agg.finish() + if err != nil { + return nil, err + } + out.Rows = append(out.Rows, values) } return out, nil @@ -238,12 +247,13 @@ func (oa *OrderedAggregate) TryStreamExecute(ctx context.Context, vcursor VCurso if len(oa.Aggregates) == 0 { return oa.executeStreamGroupBy(ctx, vcursor, bindVars, callback) } + env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) cb := func(qr *sqltypes.Result) error { return callback(qr.Truncate(oa.TruncateColumnCount)) } - var agg aggregationState + var agg *aggregationState var fields []*querypb.Field var currentKey []sqltypes.Value @@ -251,7 +261,7 @@ func (oa *OrderedAggregate) TryStreamExecute(ctx context.Context, vcursor VCurso var err error if agg == nil && len(qr.Fields) != 0 { - agg, fields, err = newAggregation(qr.Fields, oa.Aggregates) + agg, fields, err = newAggregation(qr.Fields, oa.Aggregates, env, vcursor.ConnCollation()) if err != nil { return err } @@ -271,7 +281,11 @@ func (oa *OrderedAggregate) TryStreamExecute(ctx context.Context, vcursor VCurso if nextGroup { // this is a 
new grouping. let's yield the old one, and start a new - if err := cb(&sqltypes.Result{Rows: [][]sqltypes.Value{agg.finish()}}); err != nil { + values, err := agg.finish() + if err != nil { + return err + } + if err := cb(&sqltypes.Result{Rows: [][]sqltypes.Value{values}}); err != nil { return err } @@ -292,7 +306,11 @@ func (oa *OrderedAggregate) TryStreamExecute(ctx context.Context, vcursor VCurso } if currentKey != nil { - if err := cb(&sqltypes.Result{Rows: [][]sqltypes.Value{agg.finish()}}); err != nil { + values, err := agg.finish() + if err != nil { + return err + } + if err := cb(&sqltypes.Result{Rows: [][]sqltypes.Value{values}}); err != nil { return err } } @@ -305,8 +323,8 @@ func (oa *OrderedAggregate) GetFields(ctx context.Context, vcursor VCursor, bind if err != nil { return nil, err } - - _, fields, err := newAggregation(qr.Fields, oa.Aggregates) + env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) + _, fields, err := newAggregation(qr.Fields, oa.Aggregates, env, vcursor.ConnCollation()) if err != nil { return nil, err } diff --git a/go/vt/vtgate/engine/ordered_aggregate_test.go b/go/vt/vtgate/engine/ordered_aggregate_test.go index c601654bced..30713496327 100644 --- a/go/vt/vtgate/engine/ordered_aggregate_test.go +++ b/go/vt/vtgate/engine/ordered_aggregate_test.go @@ -53,7 +53,7 @@ func TestOrderedAggregateExecute(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "", collations.MySQL8())}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, nil, "", collations.MySQL8())}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -85,7 +85,7 @@ func TestOrderedAggregateExecuteTruncate(t *testing.T) { )}, } - aggr := NewAggregateParam(AggregateSum, 1, "", collations.MySQL8()) + aggr := NewAggregateParam(AggregateSum, 1, nil, "", collations.MySQL8()) aggr.OrigOpcode = AggregateCountStar oa := &OrderedAggregate{ @@ -125,7 +125,7 @@ func TestMinMaxFailsCorrectly(t 
*testing.T) { )}, } - aggr := NewAggregateParam(AggregateMax, 0, "", collations.MySQL8()) + aggr := NewAggregateParam(AggregateMax, 0, nil, "", collations.MySQL8()) aggr.WCol = 1 oa := &ScalarAggregate{ Aggregates: []*AggregateParams{aggr}, @@ -154,7 +154,7 @@ func TestOrderedAggregateStreamExecute(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "", collations.MySQL8())}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, nil, "", collations.MySQL8())}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -193,7 +193,7 @@ func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "", collations.MySQL8())}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, nil, "", collations.MySQL8())}, GroupByKeys: []*GroupByParams{{KeyCol: 2}}, TruncateColumnCount: 2, Input: fp, @@ -231,7 +231,7 @@ func TestOrderedAggregateGetFields(t *testing.T) { oa := &OrderedAggregate{Input: fp} - got, err := oa.GetFields(context.Background(), nil, nil) + got, err := oa.GetFields(context.Background(), &noopVCursor{}, nil) assert.NoError(t, err) assert.Equal(t, got, input) } @@ -296,8 +296,8 @@ func TestOrderedAggregateExecuteCountDistinct(t *testing.T) { )}, } - aggr1 := NewAggregateParam(AggregateCountDistinct, 1, "count(distinct col2)", collations.MySQL8()) - aggr2 := NewAggregateParam(AggregateSum, 2, "", collations.MySQL8()) + aggr1 := NewAggregateParam(AggregateCountDistinct, 1, nil, "count(distinct col2)", collations.MySQL8()) + aggr2 := NewAggregateParam(AggregateSum, 2, nil, "", collations.MySQL8()) aggr2.OrigOpcode = AggregateCountStar oa := &OrderedAggregate{ Aggregates: []*AggregateParams{aggr1, aggr2}, @@ -365,12 +365,12 @@ func TestOrderedAggregateStreamCountDistinct(t *testing.T) { )}, } - aggr2 := NewAggregateParam(AggregateSum, 2, "", collations.MySQL8()) + aggr2 := 
NewAggregateParam(AggregateSum, 2, nil, "", collations.MySQL8()) aggr2.OrigOpcode = AggregateCountDistinct oa := &OrderedAggregate{ Aggregates: []*AggregateParams{ - NewAggregateParam(AggregateCountDistinct, 1, "count(distinct col2)", collations.MySQL8()), + NewAggregateParam(AggregateCountDistinct, 1, nil, "count(distinct col2)", collations.MySQL8()), aggr2}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, @@ -451,8 +451,8 @@ func TestOrderedAggregateSumDistinctGood(t *testing.T) { oa := &OrderedAggregate{ Aggregates: []*AggregateParams{ - NewAggregateParam(AggregateSumDistinct, 1, "sum(distinct col2)", collations.MySQL8()), - NewAggregateParam(AggregateSum, 2, "", collations.MySQL8()), + NewAggregateParam(AggregateSumDistinct, 1, nil, "sum(distinct col2)", collations.MySQL8()), + NewAggregateParam(AggregateSum, 2, nil, "", collations.MySQL8()), }, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, @@ -495,7 +495,7 @@ func TestOrderedAggregateSumDistinctTolerateError(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{NewAggregateParam(AggregateSumDistinct, 1, "sum(distinct col2)", collations.MySQL8())}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSumDistinct, 1, nil, "sum(distinct col2)", collations.MySQL8())}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -527,7 +527,7 @@ func TestOrderedAggregateKeysFail(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "", collations.MySQL8())}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, nil, "", collations.MySQL8())}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -557,7 +557,7 @@ func TestOrderedAggregateMergeFail(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "", collations.MySQL8())}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, nil, "", collations.MySQL8())}, GroupByKeys: 
[]*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -618,7 +618,7 @@ func TestOrderedAggregateExecuteGtid(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{NewAggregateParam(AggregateGtid, 1, "vgtid", collations.MySQL8())}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateGtid, 1, nil, "vgtid", collations.MySQL8())}, TruncateColumnCount: 2, Input: fp, } @@ -651,7 +651,7 @@ func TestCountDistinctOnVarchar(t *testing.T) { )}, } - aggr := NewAggregateParam(AggregateCountDistinct, 1, "count(distinct c2)", collations.MySQL8()) + aggr := NewAggregateParam(AggregateCountDistinct, 1, nil, "count(distinct c2)", collations.MySQL8()) aggr.WCol = 2 oa := &OrderedAggregate{ Aggregates: []*AggregateParams{aggr}, @@ -711,7 +711,7 @@ func TestCountDistinctOnVarcharWithNulls(t *testing.T) { )}, } - aggr := NewAggregateParam(AggregateCountDistinct, 1, "count(distinct c2)", collations.MySQL8()) + aggr := NewAggregateParam(AggregateCountDistinct, 1, nil, "count(distinct c2)", collations.MySQL8()) aggr.WCol = 2 oa := &OrderedAggregate{ Aggregates: []*AggregateParams{aggr}, @@ -773,7 +773,7 @@ func TestSumDistinctOnVarcharWithNulls(t *testing.T) { )}, } - aggr := NewAggregateParam(AggregateSumDistinct, 1, "sum(distinct c2)", collations.MySQL8()) + aggr := NewAggregateParam(AggregateSumDistinct, 1, nil, "sum(distinct c2)", collations.MySQL8()) aggr.WCol = 2 oa := &OrderedAggregate{ Aggregates: []*AggregateParams{aggr}, @@ -839,8 +839,8 @@ func TestMultiDistinct(t *testing.T) { oa := &OrderedAggregate{ Aggregates: []*AggregateParams{ - NewAggregateParam(AggregateCountDistinct, 1, "count(distinct c2)", collations.MySQL8()), - NewAggregateParam(AggregateSumDistinct, 2, "sum(distinct c3)", collations.MySQL8()), + NewAggregateParam(AggregateCountDistinct, 1, nil, "count(distinct c2)", collations.MySQL8()), + NewAggregateParam(AggregateSumDistinct, 2, nil, "sum(distinct c3)", collations.MySQL8()), }, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, @@ 
-898,7 +898,7 @@ func TestOrderedAggregateCollate(t *testing.T) { collationEnv := collations.MySQL8() collationID, _ := collationEnv.LookupID("utf8mb4_0900_ai_ci") oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "", collationEnv)}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, nil, "", collationEnv)}, GroupByKeys: []*GroupByParams{{KeyCol: 0, Type: evalengine.NewType(sqltypes.Unknown, collationID)}}, Input: fp, } @@ -937,7 +937,7 @@ func TestOrderedAggregateCollateAS(t *testing.T) { collationEnv := collations.MySQL8() collationID, _ := collationEnv.LookupID("utf8mb4_0900_as_ci") oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "", collationEnv)}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, nil, "", collationEnv)}, GroupByKeys: []*GroupByParams{{KeyCol: 0, Type: evalengine.NewType(sqltypes.Unknown, collationID)}}, Input: fp, } @@ -978,7 +978,7 @@ func TestOrderedAggregateCollateKS(t *testing.T) { collationEnv := collations.MySQL8() collationID, _ := collationEnv.LookupID("utf8mb4_ja_0900_as_cs_ks") oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "", collationEnv)}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, nil, "", collationEnv)}, GroupByKeys: []*GroupByParams{{KeyCol: 0, Type: evalengine.NewType(sqltypes.Unknown, collationID)}}, Input: fp, } @@ -1059,7 +1059,7 @@ func TestGroupConcatWithAggrOnEngine(t *testing.T) { for _, tcase := range tcases { t.Run(tcase.name, func(t *testing.T) { fp := &fakePrimitive{results: []*sqltypes.Result{tcase.inputResult}} - agp := NewAggregateParam(AggregateGroupConcat, 1, "group_concat(c2)", collations.MySQL8()) + agp := NewAggregateParam(AggregateGroupConcat, 1, nil, "group_concat(c2)", collations.MySQL8()) agp.Func = &sqlparser.GroupConcatExpr{Separator: ","} oa := &OrderedAggregate{ Aggregates: []*AggregateParams{agp}, @@ -1140,7 
+1140,7 @@ func TestGroupConcat(t *testing.T) { for _, tcase := range tcases { t.Run(tcase.name, func(t *testing.T) { fp := &fakePrimitive{results: []*sqltypes.Result{tcase.inputResult}} - agp := NewAggregateParam(AggregateGroupConcat, 1, "", collations.MySQL8()) + agp := NewAggregateParam(AggregateGroupConcat, 1, nil, "", collations.MySQL8()) agp.Func = &sqlparser.GroupConcatExpr{Separator: ","} oa := &OrderedAggregate{ Aggregates: []*AggregateParams{agp}, diff --git a/go/vt/vtgate/engine/scalar_aggregation.go b/go/vt/vtgate/engine/scalar_aggregation.go index e33204f5c58..a3270f33e8e 100644 --- a/go/vt/vtgate/engine/scalar_aggregation.go +++ b/go/vt/vtgate/engine/scalar_aggregation.go @@ -20,6 +20,8 @@ import ( "context" "sync" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -63,8 +65,9 @@ func (sa *ScalarAggregate) GetFields(ctx context.Context, vcursor VCursor, bindV if err != nil { return nil, err } + env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) - _, fields, err := newAggregation(qr.Fields, sa.Aggregates) + _, fields, err := newAggregation(qr.Fields, sa.Aggregates, env, vcursor.ConnCollation()) if err != nil { return nil, err } @@ -84,8 +87,9 @@ func (sa *ScalarAggregate) TryExecute(ctx context.Context, vcursor VCursor, bind if err != nil { return nil, err } + env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) - agg, fields, err := newAggregation(result.Fields, sa.Aggregates) + agg, fields, err := newAggregation(result.Fields, sa.Aggregates, env, vcursor.ConnCollation()) if err != nil { return nil, err } @@ -96,9 +100,13 @@ func (sa *ScalarAggregate) TryExecute(ctx context.Context, vcursor VCursor, bind } } + values, err := agg.finish() + if err != nil { + return nil, err + } out := &sqltypes.Result{ Fields: fields, - Rows: [][]sqltypes.Value{agg.finish()}, + Rows: [][]sqltypes.Value{values}, } return out.Truncate(sa.TruncateColumnCount), nil } @@ -108,9 
+116,10 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor cb := func(qr *sqltypes.Result) error { return callback(qr.Truncate(sa.TruncateColumnCount)) } + env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) var mu sync.Mutex - var agg aggregationState + var agg *aggregationState var fields []*querypb.Field fieldsSent := !wantfields @@ -123,7 +132,7 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor if agg == nil && len(result.Fields) != 0 { var err error - agg, fields, err = newAggregation(result.Fields, sa.Aggregates) + agg, fields, err = newAggregation(result.Fields, sa.Aggregates, env, vcursor.ConnCollation()) if err != nil { return err } @@ -146,7 +155,11 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor return err } - return cb(&sqltypes.Result{Rows: [][]sqltypes.Value{agg.finish()}}) + values, err := agg.finish() + if err != nil { + return err + } + return cb(&sqltypes.Result{Rows: [][]sqltypes.Value{values}}) } // Inputs implements the Primitive interface diff --git a/go/vt/vtgate/engine/scalar_aggregation_test.go b/go/vt/vtgate/engine/scalar_aggregation_test.go index 6fa0c8aecb8..777a2f628b4 100644 --- a/go/vt/vtgate/engine/scalar_aggregation_test.go +++ b/go/vt/vtgate/engine/scalar_aggregation_test.go @@ -276,8 +276,8 @@ func TestScalarDistinctAggrOnEngine(t *testing.T) { oa := &ScalarAggregate{ Aggregates: []*AggregateParams{ - NewAggregateParam(AggregateCountDistinct, 0, "count(distinct value)", collations.MySQL8()), - NewAggregateParam(AggregateSumDistinct, 1, "sum(distinct value)", collations.MySQL8()), + NewAggregateParam(AggregateCountDistinct, 0, nil, "count(distinct value)", collations.MySQL8()), + NewAggregateParam(AggregateSumDistinct, 1, nil, "sum(distinct value)", collations.MySQL8()), }, Input: fp, } @@ -314,9 +314,9 @@ func TestScalarDistinctPushedDown(t *testing.T) { "8|90", )}} - countAggr := NewAggregateParam(AggregateSum, 0, 
"count(distinct value)", collations.MySQL8()) + countAggr := NewAggregateParam(AggregateSum, 0, nil, "count(distinct value)", collations.MySQL8()) countAggr.OrigOpcode = AggregateCountDistinct - sumAggr := NewAggregateParam(AggregateSum, 1, "sum(distinct value)", collations.MySQL8()) + sumAggr := NewAggregateParam(AggregateSum, 1, nil, "sum(distinct value)", collations.MySQL8()) sumAggr.OrigOpcode = AggregateSumDistinct oa := &ScalarAggregate{ Aggregates: []*AggregateParams{ diff --git a/go/vt/vtgate/evalengine/eval_result.go b/go/vt/vtgate/evalengine/eval_result.go index 5c1973d8eb1..161c4ac5ca9 100644 --- a/go/vt/vtgate/evalengine/eval_result.go +++ b/go/vt/vtgate/evalengine/eval_result.go @@ -41,8 +41,13 @@ func (er EvalResult) Value(id collations.ID) sqltypes.Value { return evalToSQLValue(er.v) } - dst, err := charset.Convert(nil, colldata.Lookup(id).Charset(), str.bytes, colldata.Lookup(str.col.Collation).Charset()) - if err != nil { + lookup := colldata.Lookup(id) + var err error + var dst []byte + if lookup != nil { + dst, err = charset.Convert(nil, lookup.Charset(), str.bytes, colldata.Lookup(str.col.Collation).Charset()) + } + if lookup == nil || err != nil { // If we can't convert, we just return what we have, but it's going // to be invalidly encoded. 
Should normally never happen as only utf8mb4 // is really supported for the connection character set anyway and all diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 2a26f114a1d..b86df7ca97b 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -342,9 +342,23 @@ func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggrega case opcode.AggregateUDF: message := fmt.Sprintf("Aggregate UDF '%s' must be pushed down to MySQL", sqlparser.String(aggr.Original.Expr)) return nil, vterrors.VT12001(message) + case opcode.AggregateConstant: + // For AnyValue aggregations (literals, parameters), translate to evalengine + // This allows evaluation even when no input rows are present (empty result sets) + cfg := &evalengine.Config{ + Collation: ctx.VSchema.ConnCollation(), + Environment: ctx.VSchema.Environment(), + ResolveColumn: func(name *sqlparser.ColName) (int, error) { return aggr.ColOffset, nil }, + } + expr, err := evalengine.Translate(aggr.Original.Expr, cfg) + if err != nil { + return nil, err + } + aggregates = append(aggregates, engine.NewAggregateParam(aggr.OpCode, aggr.ColOffset, expr, aggr.Alias, ctx.VSchema.Environment().CollationEnv())) + continue } - aggrParam := engine.NewAggregateParam(aggr.OpCode, aggr.ColOffset, aggr.Alias, ctx.VSchema.Environment().CollationEnv()) + aggrParam := engine.NewAggregateParam(aggr.OpCode, aggr.ColOffset, nil, aggr.Alias, ctx.VSchema.Environment().CollationEnv()) aggrParam.Func = aggr.Func if gcFunc, isGc := aggrParam.Func.(*sqlparser.GroupConcatExpr); isGc && gcFunc.Separator == "" { gcFunc.Separator = sqlparser.GroupConcatDefaultSeparator diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing_helper.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing_helper.go index 2fbf5c32311..fbd829f16de 100644 --- 
a/go/vt/vtgate/planbuilder/operators/aggregation_pushing_helper.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing_helper.go @@ -130,7 +130,7 @@ func (ab *aggBuilder) handleAggr(ctx *plancontext.PlanningContext, aggr Aggr) er return nil case opcode.AggregateCount, opcode.AggregateSum: return ab.handleAggrWithCountStarMultiplier(ctx, aggr) - case opcode.AggregateMax, opcode.AggregateMin, opcode.AggregateAnyValue: + case opcode.AggregateMax, opcode.AggregateMin, opcode.AggregateAnyValue, opcode.AggregateConstant: return ab.handlePushThroughAggregation(ctx, aggr) case opcode.AggregateGroupConcat: f := aggr.Func.(*sqlparser.GroupConcatExpr) diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index 1a96e81b66e..f0e0b5bab27 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -79,6 +79,16 @@ func (a *Aggregator) AddPredicate(_ *plancontext.PlanningContext, expr sqlparser return newFilter(a, expr) } +// createNonGroupingAggr creates the appropriate aggregation for a non-grouping, non-aggregation column +// If the expression is constant, it returns AggregateConstant, otherwise AggregateAnyValue +func createNonGroupingAggr(expr *sqlparser.AliasedExpr) Aggr { + if sqlparser.IsConstant(expr.Expr) { + return NewAggr(opcode.AggregateConstant, nil, expr, expr.ColumnName()) + } else { + return NewAggr(opcode.AggregateAnyValue, nil, expr, expr.ColumnName()) + } +} + func (a *Aggregator) addColumnWithoutPushing(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int { offset := len(a.Columns) a.Columns = append(a.Columns, expr) @@ -94,12 +104,12 @@ func (a *Aggregator) addColumnWithoutPushing(ctx *plancontext.PlanningContext, e aggr = createAggrFromAggrFunc(e, expr) case *sqlparser.FuncExpr: if ctx.IsAggr(e) { - aggr = NewAggr(opcode.AggregateUDF, nil, expr, expr.As.String()) - } else { - aggr = 
NewAggr(opcode.AggregateAnyValue, nil, expr, expr.As.String()) + aggr = NewAggr(opcode.AggregateUDF, nil, expr, expr.ColumnName()) } - default: - aggr = NewAggr(opcode.AggregateAnyValue, nil, expr, expr.As.String()) + } + + if aggr.Alias == "" { + aggr = createNonGroupingAggr(expr) } aggr.ColOffset = offset a.Aggregations = append(a.Aggregations, aggr) @@ -176,7 +186,7 @@ func (a *Aggregator) AddColumn(ctx *plancontext.PlanningContext, reuse bool, gro } if !groupBy { - aggr := NewAggr(opcode.AggregateAnyValue, nil, ae, ae.As.String()) + aggr := createNonGroupingAggr(ae) aggr.ColOffset = len(a.Columns) a.Aggregations = append(a.Aggregations, aggr) } @@ -417,7 +427,7 @@ func (aggr Aggr) setPushColumn(exprs []sqlparser.Expr) { func (aggr Aggr) getPushColumn() sqlparser.Expr { switch aggr.OpCode { - case opcode.AggregateAnyValue: + case opcode.AggregateAnyValue, opcode.AggregateConstant: return aggr.Original.Expr case opcode.AggregateCountStar: return sqlparser.NewIntLiteral("1") @@ -436,14 +446,14 @@ func (aggr Aggr) getPushColumn() sqlparser.Expr { func (aggr Aggr) getPushColumnExprs() []sqlparser.Expr { switch aggr.OpCode { - case opcode.AggregateAnyValue: + case opcode.AggregateAnyValue, opcode.AggregateConstant: return []sqlparser.Expr{aggr.Original.Expr} case opcode.AggregateCountStar: return []sqlparser.Expr{sqlparser.NewIntLiteral("1")} - case opcode.AggregateUDF: - // AggregateUDFs can't be evaluated on the vtgate. So either we are able to push everything down, or we will have to fail the query. 
- return nil default: + if aggr.Func == nil { + return nil + } return aggr.Func.GetArgs() } } diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 073e320fff1..961eaeeb21e 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -404,19 +404,19 @@ func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningConte // Here we go over the expressions we are returning. Since we know we are aggregating, // all expressions have to be either grouping expressions or aggregate expressions. // If we find an expression that is neither, we treat is as a special aggregation function AggrRandom - for _, expr := range qp.SelectExprs { - aliasedExpr, err := expr.GetAliasedExpr() + for _, selectExpr := range qp.SelectExprs { + aliasedExpr, err := selectExpr.GetAliasedExpr() if err != nil { panic(err) } - if !ctx.ContainsAggr(expr.Col) { - getExpr, err := expr.GetExpr() + if !ctx.ContainsAggr(selectExpr.Col) { + getExpr, err := selectExpr.GetExpr() if err != nil { panic(err) } if !qp.isExprInGroupByExprs(ctx, getExpr) { - aggr := NewAggr(opcode.AggregateAnyValue, nil, aliasedExpr, aliasedExpr.ColumnName()) + aggr := createNonGroupingAggr(aliasedExpr) out = append(out, aggr) } continue @@ -462,7 +462,7 @@ func (qp *QueryProjection) extractAggr( return true } if !qp.isExprInGroupByExprs(ctx, ex) { - aggr := NewAggr(opcode.AggregateAnyValue, nil, aeWrap(ex), "") + aggr := createNonGroupingAggr(aeWrap(ex)) addAggr(aggr) } return false diff --git a/go/vt/vtgate/planbuilder/show.go b/go/vt/vtgate/planbuilder/show.go index 6195cb16a90..1adcfeae7e4 100644 --- a/go/vt/vtgate/planbuilder/show.go +++ b/go/vt/vtgate/planbuilder/show.go @@ -571,7 +571,7 @@ func buildShowVGtidPlan(show *sqlparser.ShowBasic, vschema plancontext.VSchema) } return &engine.OrderedAggregate{ Aggregates: []*engine.AggregateParams{ - 
engine.NewAggregateParam(popcode.AggregateGtid, 1, "global vgtid_executed", vschema.Environment().CollationEnv()), + engine.NewAggregateParam(popcode.AggregateGtid, 1, nil, "global vgtid_executed", vschema.Environment().CollationEnv()), }, TruncateColumnCount: 2, Input: send, diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index a1b53981128..911749a932d 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -2432,7 +2432,7 @@ { "OperatorType": "Aggregate", "Variant": "Scalar", - "Aggregates": "any_value(0) AS 1, sum_count(1) AS count(id)", + "Aggregates": "constant_aggr(1) AS 1, sum_count(1) AS count(id)", "Inputs": [ { "OperatorType": "VindexLookup", @@ -2556,7 +2556,7 @@ { "OperatorType": "Aggregate", "Variant": "Scalar", - "Aggregates": "any_value(0) AS 1, sum_count(1) AS count(id)", + "Aggregates": "constant_aggr(1) AS 1, sum_count(1) AS count(id)", "Inputs": [ { "OperatorType": "Route", @@ -3972,7 +3972,7 @@ { "OperatorType": "Aggregate", "Variant": "Scalar", - "Aggregates": "any_value(0) AS id, sum_count_star(1) AS count(*), any_value(2)", + "Aggregates": "any_value(0) AS id, sum_count_star(1) AS count(*), constant_aggr(1) AS 1", "Inputs": [ { "OperatorType": "Route", @@ -5684,7 +5684,7 @@ { "OperatorType": "Aggregate", "Variant": "Scalar", - "Aggregates": "any_value(0), sum_count_star(1) AS count(*)", + "Aggregates": "constant_aggr(1) AS 1, sum_count_star(1) AS count(*)", "Inputs": [ { "OperatorType": "Route", diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index c2fb8885f51..1885d2be7c9 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -409,7 +409,7 @@ { "OperatorType": "Aggregate", "Variant": "Scalar", - "Aggregates": "any_value(0) AS id, sum_count_star(1) AS count(*), any_value(2)", + 
"Aggregates": "any_value(0) AS id, sum_count_star(1) AS count(*), constant_aggr(1) AS 1", "Inputs": [ { "OperatorType": "Route", @@ -2068,7 +2068,7 @@ { "OperatorType": "Aggregate", "Variant": "Scalar", - "Aggregates": "any_value(0), sum(1) AS sum(num)", + "Aggregates": "constant_aggr(1000) AS 1000, sum(1) AS sum(num)", "Inputs": [ { "OperatorType": "Projection", diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index 34aa538fafb..f56596ea424 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -753,7 +753,7 @@ "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", - "Aggregates": "any_value(0) AS 1", + "Aggregates": "constant_aggr(1) AS 1", "GroupBy": "1, 4", "ResultColumns": 1, "Inputs": [ diff --git a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json index 60e6857d53d..900697420f0 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json @@ -1573,7 +1573,7 @@ { "OperatorType": "Aggregate", "Variant": "Scalar", - "Aggregates": "sum(0) AS sum(ps_supplycost * ps_availqty), any_value(1)", + "Aggregates": "sum(0) AS sum(ps_supplycost * ps_availqty), constant_aggr(0.00001000000) AS 0.00001000000", "Inputs": [ { "OperatorType": "Projection", @@ -1992,7 +1992,7 @@ { "OperatorType": "Aggregate", "Variant": "Scalar", - "Aggregates": "any_value(0), sum(1) AS sum(case when p_type like 'PROMO%' then l_extendedprice * (1 - l_discount) else 0 end), sum(2) AS sum(l_extendedprice * (1 - l_discount))", + "Aggregates": "constant_aggr(100.00) AS 100.00, sum(1) AS sum(case when p_type like 'PROMO%' then l_extendedprice * (1 - l_discount) else 0 end), sum(2) AS sum(l_extendedprice * (1 - l_discount))", "Inputs": [ { "OperatorType": "Projection", diff --git a/go/vt/vtgate/planbuilder/testdata/union_cases.json 
b/go/vt/vtgate/planbuilder/testdata/union_cases.json index 60105355fea..0cf8defd671 100644 --- a/go/vt/vtgate/planbuilder/testdata/union_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/union_cases.json @@ -1969,7 +1969,7 @@ { "OperatorType": "Aggregate", "Variant": "Ordered", - "Aggregates": "any_value(0) AS type, any_value(1) AS id", + "Aggregates": "constant_aggr('a') AS type, constant_aggr(0) AS id", "GroupBy": "2 COLLATE utf8mb4_0900_ai_ci", "ResultColumns": 2, "Inputs": [ diff --git a/go/vt/vttablet/tabletmanager/vdiff/table_plan.go b/go/vt/vttablet/tabletmanager/vdiff/table_plan.go index 49f2713b58a..880cd010cb1 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/table_plan.go +++ b/go/vt/vttablet/tabletmanager/vdiff/table_plan.go @@ -122,6 +122,7 @@ func (td *tableDiffer) buildTablePlan(dbClient binlogplayer.DBClient, dbName str aggregates = append(aggregates, engine.NewAggregateParam( /*opcode*/ opcode.AggregateSum, /*offset*/ sourceSelect.GetColumnCount()-1, + nil, /*alias*/ "", collationEnv), ) } diff --git a/go/vt/vttablet/tabletmanager/vdiff/workflow_differ_test.go b/go/vt/vttablet/tabletmanager/vdiff/workflow_differ_test.go index 97dd4406b3b..5b0b35f9fc0 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/workflow_differ_test.go +++ b/go/vt/vttablet/tabletmanager/vdiff/workflow_differ_test.go @@ -686,8 +686,8 @@ func TestBuildPlanSuccess(t *testing.T) { Direction: sqlparser.AscOrder, }}, aggregates: []*engine.AggregateParams{ - engine.NewAggregateParam(opcode.AggregateSum, 2, "", collations.MySQL8()), - engine.NewAggregateParam(opcode.AggregateSum, 3, "", collations.MySQL8()), + engine.NewAggregateParam(opcode.AggregateSum, 2, nil, "", collations.MySQL8()), + engine.NewAggregateParam(opcode.AggregateSum, 3, nil, "", collations.MySQL8()), }, }, }, { diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 95fd12ec6f4..22a806564a4 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -27,6 +27,8 @@ import ( "sync" "time" + 
"vitess.io/vitess/go/mysql/config" + "google.golang.org/protobuf/encoding/prototext" "vitess.io/vitess/go/mysql/replication" @@ -147,8 +149,7 @@ type tableDiffer struct { sourcePrimitive engine.Primitive targetPrimitive engine.Primitive - collationEnv *collations.Environment - parser *sqlparser.Parser + env *vtenv.Environment } // shardStreamer streams rows from one shard. This works for @@ -667,9 +668,8 @@ func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, quer return nil, fmt.Errorf("unexpected: %v", sqlparser.String(statement)) } td := &tableDiffer{ - targetTable: table.Name, - collationEnv: df.env.CollationEnv(), - parser: df.env.Parser(), + targetTable: table.Name, + env: df.env, } sourceSelect := &sqlparser.Select{} targetSelect := &sqlparser.Select{} @@ -710,6 +710,7 @@ func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, quer aggregates = append(aggregates, engine.NewAggregateParam( /*opcode*/ opcode.AggregateSum, /*offset*/ sourceSelect.GetColumnCount()-1, + nil, /*alias*/ "", df.env.CollationEnv())) } } @@ -772,7 +773,7 @@ func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, quer if len(aggregates) != 0 { td.sourcePrimitive = &engine.OrderedAggregate{ Aggregates: aggregates, - GroupByKeys: pkColsToGroupByParams(td.pkCols, td.collationEnv), + GroupByKeys: pkColsToGroupByParams(td.pkCols, td.env.CollationEnv()), Input: td.sourcePrimitive, } } @@ -1093,12 +1094,12 @@ type primitiveExecutor struct { err error } -func newPrimitiveExecutor(ctx context.Context, prim engine.Primitive) *primitiveExecutor { +func newPrimitiveExecutor(ctx context.Context, prim engine.Primitive, env *vtenv.Environment) *primitiveExecutor { pe := &primitiveExecutor{ prim: prim, resultch: make(chan *sqltypes.Result, 1), } - vcursor := &contextVCursor{} + vcursor := newVCursor(env) go func() { defer close(pe.resultch) pe.err = vcursor.StreamExecutePrimitive(ctx, pe.prim, make(map[string]*querypb.BindVariable), true, 
func(qr *sqltypes.Result) error { @@ -1180,8 +1181,8 @@ func humanInt(n int64) string { // nolint // tableDiffer func (td *tableDiffer) diff(ctx context.Context, rowsToCompare *int64, debug, onlyPks bool, maxExtraRowsToCompare int) (*DiffReport, error) { - sourceExecutor := newPrimitiveExecutor(ctx, td.sourcePrimitive) - targetExecutor := newPrimitiveExecutor(ctx, td.targetPrimitive) + sourceExecutor := newPrimitiveExecutor(ctx, td.sourcePrimitive, td.env) + targetExecutor := newPrimitiveExecutor(ctx, td.targetPrimitive, td.env) dr := &DiffReport{} var sourceRow, targetRow []sqltypes.Value var err error @@ -1321,7 +1322,7 @@ func (td *tableDiffer) compare(sourceRow, targetRow []sqltypes.Value, cols []com if col.collation == collations.Unknown { collationID = collations.CollationBinaryID } - c, err = evalengine.NullsafeCompare(sourceRow[compareIndex], targetRow[compareIndex], td.collationEnv, collationID, col.values) + c, err = evalengine.NullsafeCompare(sourceRow[compareIndex], targetRow[compareIndex], td.env.CollationEnv(), collationID, col.values) if err != nil { return 0, err } @@ -1335,7 +1336,7 @@ func (td *tableDiffer) compare(sourceRow, targetRow []sqltypes.Value, cols []com func (td *tableDiffer) genRowDiff(queryStmt string, row []sqltypes.Value, debug, onlyPks bool) (*RowDiff, error) { drp := &RowDiff{} drp.Row = make(map[string]sqltypes.Value) - statement, err := td.parser.Parse(queryStmt) + statement, err := td.env.Parser().Parse(queryStmt) if err != nil { return nil, err } @@ -1404,6 +1405,11 @@ func (td *tableDiffer) genDebugQueryDiff(sel *sqlparser.Select, row []sqltypes.V // contextVCursor satisfies VCursor interface type contextVCursor struct { engine.VCursor + env *vtenv.Environment +} + +func newVCursor(env *vtenv.Environment) *contextVCursor { + return &contextVCursor{env: env} } func (vc *contextVCursor) ConnCollation() collations.ID { @@ -1418,6 +1424,18 @@ func (vc *contextVCursor) StreamExecutePrimitive(ctx context.Context, primitive return 
primitive.TryStreamExecute(ctx, vc, bindVars, wantfields, callback) } +func (vc *contextVCursor) TimeZone() *time.Location { + return time.Local +} + +func (vc *contextVCursor) SQLMode() string { + return config.DefaultSQLMode +} + +func (vc *contextVCursor) Environment() *vtenv.Environment { + return vc.env +} + // ----------------------------------------------------------------- // Utility functions diff --git a/go/vt/wrangler/vdiff_test.go b/go/vt/wrangler/vdiff_test.go index 4bbaed35c3f..1c041584c3f 100644 --- a/go/vt/wrangler/vdiff_test.go +++ b/go/vt/wrangler/vdiff_test.go @@ -38,8 +38,8 @@ import ( ) func TestVDiffPlanSuccess(t *testing.T) { + env := vtenv.NewTestEnv() collationEnv := collations.MySQL8() - parser := sqlparser.NewTestParser() schm := &tabletmanagerdatapb.SchemaDefinition{ TableDefinitions: []*tabletmanagerdatapb.TableDefinition{{ Name: "t1", @@ -100,8 +100,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { input: &binlogdatapb.Rule{ @@ -119,8 +118,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { input: &binlogdatapb.Rule{ @@ -138,8 +136,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, 
+ env: env, }, }, { input: &binlogdatapb.Rule{ @@ -157,8 +154,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{1}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { input: &binlogdatapb.Rule{ @@ -176,8 +172,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { // non-pk text column. @@ -196,8 +191,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { // non-pk text column, different order. @@ -216,8 +210,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{1}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { // pk text column. 
@@ -236,8 +229,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { // pk text column, different order. @@ -256,8 +248,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{1}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { // text column as expression. @@ -276,8 +267,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{1}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{1, collationEnv.DefaultConnectionCharset(), nil, false}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { input: &binlogdatapb.Rule{ @@ -294,8 +284,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0, 1}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}, {1, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { // in_keyrange @@ -314,8 +303,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, 
[]compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { // in_keyrange on RHS of AND. @@ -335,8 +323,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { // in_keyrange on LHS of AND. @@ -356,8 +343,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { // in_keyrange on cascaded AND expression @@ -377,8 +363,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { // in_keyrange parenthesized @@ -398,8 +383,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { // group by @@ -418,8 +402,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, 
nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { // aggregations @@ -438,15 +421,14 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0}, sourcePrimitive: &engine.OrderedAggregate{ Aggregates: []*engine.AggregateParams{ - engine.NewAggregateParam(opcode.AggregateSum, 2, "", collationEnv), - engine.NewAggregateParam(opcode.AggregateSum, 3, "", collationEnv), + engine.NewAggregateParam(opcode.AggregateSum, 2, nil, "", collationEnv), + engine.NewAggregateParam(opcode.AggregateSum, 3, nil, "", collationEnv), }, GroupByKeys: []*engine.GroupByParams{{KeyCol: 0, WeightStringCol: -1, CollationEnv: collations.MySQL8()}}, Input: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), }, targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }, { input: &binlogdatapb.Rule{ @@ -464,8 +446,7 @@ func TestVDiffPlanSuccess(t *testing.T) { selectPks: []int{0}, sourcePrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, collations.Unknown, nil, true}}, collationEnv), - collationEnv: collationEnv, - parser: parser, + env: env, }, }}