Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions go/test/endtoend/vtgate/gen4/gen4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,34 +132,28 @@ func TestDistinctAggregationFunc(t *testing.T) {
defer closer()

// insert some data.
utils.Exec(t, mcmp.VtConn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'A')`)
mcmp.Exec(`insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'A')`)

// count on primary vindex
utils.AssertMatches(t, mcmp.VtConn, `select tcol1, count(distinct id) from t2 group by tcol1`,
`[[VARCHAR("A") INT64(3)] [VARCHAR("B") INT64(3)] [VARCHAR("C") INT64(2)]]`)
mcmp.Exec(`select tcol1, count(distinct id) from t2 group by tcol1`)

// count on any column
utils.AssertMatches(t, mcmp.VtConn, `select tcol1, count(distinct tcol2) from t2 group by tcol1`,
`[[VARCHAR("A") INT64(2)] [VARCHAR("B") INT64(2)] [VARCHAR("C") INT64(1)]]`)
mcmp.Exec(`select tcol1, count(distinct tcol2) from t2 group by tcol1`)

// sum of columns
utils.AssertMatches(t, mcmp.VtConn, `select sum(id), sum(tcol1) from t2`,
`[[DECIMAL(36) FLOAT64(0)]]`)
mcmp.Exec(`select sum(id), sum(tcol1) from t2`)

// sum on primary vindex
utils.AssertMatches(t, mcmp.VtConn, `select tcol1, sum(distinct id) from t2 group by tcol1`,
`[[VARCHAR("A") DECIMAL(9)] [VARCHAR("B") DECIMAL(15)] [VARCHAR("C") DECIMAL(12)]]`)
mcmp.Exec(`select tcol1, sum(distinct id) from t2 group by tcol1`)

// sum on any column
utils.AssertMatches(t, mcmp.VtConn, `select tcol1, sum(distinct tcol2) from t2 group by tcol1`,
`[[VARCHAR("A") DECIMAL(0)] [VARCHAR("B") DECIMAL(0)] [VARCHAR("C") DECIMAL(0)]]`)
mcmp.Exec(`select tcol1, sum(distinct tcol2) from t2 group by tcol1`)

// insert more data to get values on sum
utils.Exec(t, mcmp.VtConn, `insert into t2(id, tcol1, tcol2) values (9, 'AA', null),(10, 'AA', '4'),(11, 'AA', '4'),(12, null, '5'),(13, null, '6'),(14, 'BB', '10'),(15, 'BB', '20'),(16, 'BB', 'X')`)
mcmp.Exec(`insert into t2(id, tcol1, tcol2) values (9, 'AA', null),(10, 'AA', '4'),(11, 'AA', '4'),(12, null, '5'),(13, null, '6'),(14, 'BB', '10'),(15, 'BB', '20'),(16, 'BB', 'X')`)

// multi distinct
utils.AssertMatches(t, mcmp.VtConn, `select tcol1, count(distinct tcol2), sum(distinct tcol2) from t2 group by tcol1`,
`[[NULL INT64(2) DECIMAL(11)] [VARCHAR("A") INT64(2) DECIMAL(0)] [VARCHAR("AA") INT64(1) DECIMAL(4)] [VARCHAR("B") INT64(2) DECIMAL(0)] [VARCHAR("BB") INT64(3) DECIMAL(30)] [VARCHAR("C") INT64(1) DECIMAL(0)]]`)
mcmp.Exec(`select tcol1, count(distinct tcol2), sum(distinct tcol2) from t2 group by tcol1`)
}

func TestDistinct(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,17 @@ func TestAggregateTypes(t *testing.T) {
mcmp.Exec("insert into aggr_test(id, val1, val2) values(6,'d',null), (7,'e',null), (8,'E',1)")
mcmp.AssertMatches("select val1, count(distinct val2), count(*) from aggr_test group by val1", `[[VARCHAR("a") INT64(1) INT64(2)] [VARCHAR("b") INT64(1) INT64(1)] [VARCHAR("c") INT64(2) INT64(2)] [VARCHAR("d") INT64(0) INT64(1)] [VARCHAR("e") INT64(1) INT64(2)]]`)
mcmp.AssertMatches("select val1, sum(distinct val2), sum(val2) from aggr_test group by val1", `[[VARCHAR("a") DECIMAL(1) DECIMAL(2)] [VARCHAR("b") DECIMAL(1) DECIMAL(1)] [VARCHAR("c") DECIMAL(7) DECIMAL(7)] [VARCHAR("d") NULL NULL] [VARCHAR("e") DECIMAL(1) DECIMAL(1)]]`)
mcmp.AssertMatches("select val1, count(distinct val2) k, count(*) from aggr_test group by val1 order by k desc, val1", `[[VARCHAR("c") INT64(2) INT64(2)] [VARCHAR("a") INT64(1) INT64(2)] [VARCHAR("b") INT64(1) INT64(1)] [VARCHAR("e") INT64(1) INT64(2)] [VARCHAR("d") INT64(0) INT64(1)]]`)
mcmp.AssertMatches("select val1, count(distinct val2) k, count(*) from aggr_test group by val1 order by k desc, val1 limit 4", `[[VARCHAR("c") INT64(2) INT64(2)] [VARCHAR("a") INT64(1) INT64(2)] [VARCHAR("b") INT64(1) INT64(1)] [VARCHAR("e") INT64(1) INT64(2)]]`)
mcmp.AssertMatches("select /*vt+ PLANNER=gen4 */ val1, count(distinct val2) k, count(*) from aggr_test group by val1 order by k desc, val1", `[[VARCHAR("c") INT64(2) INT64(2)] [VARCHAR("a") INT64(1) INT64(2)] [VARCHAR("b") INT64(1) INT64(1)] [VARCHAR("e") INT64(1) INT64(2)] [VARCHAR("d") INT64(0) INT64(1)]]`)
mcmp.AssertMatches("select /*vt+ PLANNER=gen4 */ val1, count(distinct val2) k, count(*) from aggr_test group by val1 order by k desc, val1 limit 4", `[[VARCHAR("c") INT64(2) INT64(2)] [VARCHAR("a") INT64(1) INT64(2)] [VARCHAR("b") INT64(1) INT64(1)] [VARCHAR("e") INT64(1) INT64(2)]]`)

mcmp.AssertMatches("select ascii(val1) as a, count(*) from aggr_test group by a", `[[INT32(65) INT64(1)] [INT32(69) INT64(1)] [INT32(97) INT64(1)] [INT32(98) INT64(1)] [INT32(99) INT64(2)] [INT32(100) INT64(1)] [INT32(101) INT64(1)]]`)
mcmp.AssertMatches("select ascii(val1) as a, count(*) from aggr_test group by a order by a", `[[INT32(65) INT64(1)] [INT32(69) INT64(1)] [INT32(97) INT64(1)] [INT32(98) INT64(1)] [INT32(99) INT64(2)] [INT32(100) INT64(1)] [INT32(101) INT64(1)]]`)
mcmp.AssertMatches("select ascii(val1) as a, count(*) from aggr_test group by a order by 2, a", `[[INT32(65) INT64(1)] [INT32(69) INT64(1)] [INT32(97) INT64(1)] [INT32(98) INT64(1)] [INT32(100) INT64(1)] [INT32(101) INT64(1)] [INT32(99) INT64(2)]]`)
mcmp.AssertMatches("select /*vt+ PLANNER=gen4 */ ascii(val1) as a, count(*) from aggr_test group by a order by a", `[[INT32(65) INT64(1)] [INT32(69) INT64(1)] [INT32(97) INT64(1)] [INT32(98) INT64(1)] [INT32(99) INT64(2)] [INT32(100) INT64(1)] [INT32(101) INT64(1)]]`)
mcmp.AssertMatches("select /*vt+ PLANNER=gen4 */ ascii(val1) as a, count(*) from aggr_test group by a order by 2, a", `[[INT32(65) INT64(1)] [INT32(69) INT64(1)] [INT32(97) INT64(1)] [INT32(98) INT64(1)] [INT32(100) INT64(1)] [INT32(101) INT64(1)] [INT32(99) INT64(2)]]`)

mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a", `[[VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)] [VARCHAR("d") INT64(1)] [VARCHAR("e") INT64(2)]]`)
mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by a", `[[VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)] [VARCHAR("d") INT64(1)] [VARCHAR("e") INT64(2)]]`)
mcmp.AssertMatches("select val1 as a, count(*) from aggr_test group by a order by 2, a", `[[VARCHAR("b") INT64(1)] [VARCHAR("d") INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("c") INT64(2)] [VARCHAR("e") INT64(2)]]`)
mcmp.AssertMatches("select /*vt+ PLANNER=gen4 */ val1 as a, count(*) from aggr_test group by a order by a", `[[VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)] [VARCHAR("d") INT64(1)] [VARCHAR("e") INT64(2)]]`)
mcmp.AssertMatches("select /*vt+ PLANNER=gen4 */ val1 as a, count(*) from aggr_test group by a order by 2, a", `[[VARCHAR("b") INT64(1)] [VARCHAR("d") INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("c") INT64(2)] [VARCHAR("e") INT64(2)]]`)
mcmp.AssertMatches("select sum(val1) from aggr_test", `[[FLOAT64(0)]]`)
}

func TestGroupBy(t *testing.T) {
Expand Down Expand Up @@ -319,7 +320,7 @@ func TestAggOnTopOfLimit(t *testing.T) {
mcmp.AssertMatchesNoOrder(" select /*vt+ PLANNER=gen4 */ val1, count(val2) from (select val1, val2 from aggr_test limit 8) as x group by val1", `[[NULL INT64(1)] [VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)]]`)

// mysql returns FLOAT64(0), vitess returns DECIMAL(0)
mcmp.AssertMatchesNoCompare(" select /*vt+ PLANNER=gen4 */ count(*), sum(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) FLOAT64(0)]]", "[[INT64(2) DECIMAL(0)]]")
mcmp.AssertMatchesNoCompare(" select /*vt+ PLANNER=gen4 */ count(*), sum(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) FLOAT64(0)]]", "[[INT64(2) FLOAT64(0)]]")
mcmp.AssertMatches(" select /*vt+ PLANNER=gen4 */ count(val1), sum(id) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2) DECIMAL(7)]]")
mcmp.AssertMatches(" select /*vt+ PLANNER=gen4 */ count(*), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(2) DECIMAL(14)]]")
mcmp.AssertMatches(" select /*vt+ PLANNER=gen4 */ count(val1), sum(id) from (select id, val1 from aggr_test where val2 is null limit 2) as x", "[[INT64(1) DECIMAL(14)]]")
Expand Down
288 changes: 288 additions & 0 deletions go/vt/vtgate/engine/aggregations.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
/*
Copyright 2023 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package engine

import (
"fmt"
"strconv"

"google.golang.org/protobuf/proto"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/slices2"
"vitess.io/vitess/go/sqltypes"
binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
. "vitess.io/vitess/go/vt/vtgate/engine/opcode"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)

// AggregateParams specify the parameters for each aggregation.
// It contains the opcode and input column number.
type AggregateParams struct {
Opcode AggregateOpcode
Col int

// These are used only for distinct opcodes.
KeyCol int
WCol int
WAssigned bool
CollationID collations.ID

Alias string `json:",omitempty"`
Expr sqlparser.Expr
Original *sqlparser.AliasedExpr

// This is based on the function passed in the select expression and
// not what we use to aggregate at the engine primitive level.
OrigOpcode AggregateOpcode
}

func (ap *AggregateParams) isDistinct() bool {
return ap.Opcode == AggregateCountDistinct || ap.Opcode == AggregateSumDistinct
}

func (ap *AggregateParams) preProcess() bool {
switch ap.Opcode {
case AggregateCountDistinct, AggregateSumDistinct, AggregateGtid, AggregateCount, AggregateGroupConcat:
return true
default:
return false
}
}

func (ap *AggregateParams) String() string {
keyCol := strconv.Itoa(ap.Col)
if ap.WAssigned {
keyCol = fmt.Sprintf("%s|%d", keyCol, ap.WCol)
}
if ap.CollationID != collations.Unknown {
keyCol += " COLLATE " + ap.CollationID.Get().Name()
}
dispOrigOp := ""
if ap.OrigOpcode != AggregateUnassigned && ap.OrigOpcode != ap.Opcode {
dispOrigOp = "_" + ap.OrigOpcode.String()
}
if ap.Alias != "" {
return fmt.Sprintf("%s%s(%s) AS %s", ap.Opcode.String(), dispOrigOp, keyCol, ap.Alias)
}
return fmt.Sprintf("%s%s(%s)", ap.Opcode.String(), dispOrigOp, keyCol)
}

func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type {
opCode := ap.Opcode
if ap.OrigOpcode != AggregateUnassigned {
opCode = ap.OrigOpcode
}
typ, _ := opCode.Type(&inputType)
return typ
}

func convertRow(
fields []*querypb.Field,
row []sqltypes.Value,
aggregates []*AggregateParams,
) (newRow []sqltypes.Value, curDistincts []sqltypes.Value) {
newRow = append(newRow, row...)
curDistincts = make([]sqltypes.Value, len(aggregates))
for index, aggr := range aggregates {
switch aggr.Opcode {
case AggregateCountStar:
newRow[aggr.Col] = countOne
case AggregateCount:
val := countOne
if row[aggr.Col].IsNull() {
val = countZero
}
newRow[aggr.Col] = val
case AggregateCountDistinct:
curDistincts[index] = findComparableCurrentDistinct(row, aggr)
// Type is int64. Ok to call MakeTrusted.
if row[aggr.KeyCol].IsNull() {
newRow[aggr.Col] = countZero
} else {
newRow[aggr.Col] = countOne
}
case AggregateSum:
if row[aggr.Col].IsNull() {
break
}
var err error
newRow[aggr.Col], err = evalengine.Cast(row[aggr.Col], fields[aggr.Col].Type)
if err != nil {
newRow[aggr.Col] = sumZero
}
case AggregateSumDistinct:
curDistincts[index] = findComparableCurrentDistinct(row, aggr)
var err error
newRow[aggr.Col], err = evalengine.Cast(row[aggr.Col], fields[aggr.Col].Type)
if err != nil {
newRow[aggr.Col] = sumZero
}
case AggregateGtid:
vgtid := &binlogdatapb.VGtid{}
vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{
Keyspace: row[aggr.Col-1].ToString(),
Shard: row[aggr.Col+1].ToString(),
Gtid: row[aggr.Col].ToString(),
})
data, _ := vgtid.MarshalVT()
val, _ := sqltypes.NewValue(sqltypes.VarBinary, data)
newRow[aggr.Col] = val
case AggregateGroupConcat:
if !row[aggr.Col].IsNull() {
newRow[aggr.Col] = sqltypes.MakeTrusted(fields[aggr.Col].Type, []byte(row[aggr.Col].ToString()))
}
}
}
return newRow, curDistincts
}

func merge(
fields []*querypb.Field,
row1, row2 []sqltypes.Value,
curDistincts []sqltypes.Value,
aggregates []*AggregateParams,
) ([]sqltypes.Value, []sqltypes.Value, error) {
result := sqltypes.CopyRow(row1)
for index, aggr := range aggregates {
if aggr.isDistinct() {
if row2[aggr.KeyCol].IsNull() {
continue
}
cmp, err := evalengine.NullsafeCompare(curDistincts[index], row2[aggr.KeyCol], aggr.CollationID)
if err != nil {
return nil, nil, err
}
if cmp == 0 {
continue
}
curDistincts[index] = findComparableCurrentDistinct(row2, aggr)
}

var err error
switch aggr.Opcode {
case AggregateCountStar:
result[aggr.Col], err = evalengine.NullSafeAdd(row1[aggr.Col], countOne, fields[aggr.Col].Type)
case AggregateCount:
val := countOne
if row2[aggr.Col].IsNull() {
val = countZero
}
result[aggr.Col], err = evalengine.NullSafeAdd(row1[aggr.Col], val, fields[aggr.Col].Type)
case AggregateSum:
value := row1[aggr.Col]
v2 := row2[aggr.Col]
if value.IsNull() && v2.IsNull() {
result[aggr.Col] = sqltypes.NULL
break
}
result[aggr.Col], err = evalengine.NullSafeAdd(value, v2, fields[aggr.Col].Type)
case AggregateMin:
result[aggr.Col], err = evalengine.Min(row1[aggr.Col], row2[aggr.Col], aggr.CollationID)
case AggregateMax:
result[aggr.Col], err = evalengine.Max(row1[aggr.Col], row2[aggr.Col], aggr.CollationID)
case AggregateCountDistinct:
result[aggr.Col], err = evalengine.NullSafeAdd(row1[aggr.Col], countOne, fields[aggr.Col].Type)
case AggregateSumDistinct:
result[aggr.Col], err = evalengine.NullSafeAdd(row1[aggr.Col], row2[aggr.Col], fields[aggr.Col].Type)
case AggregateGtid:
vgtid := &binlogdatapb.VGtid{}
rowBytes, err := row1[aggr.Col].ToBytes()
if err != nil {
return nil, nil, err
}
err = vgtid.UnmarshalVT(rowBytes)
if err != nil {
return nil, nil, err
}
vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{
Keyspace: row2[aggr.Col-1].ToString(),
Shard: row2[aggr.Col+1].ToString(),
Gtid: row2[aggr.Col].ToString(),
})
data, _ := vgtid.MarshalVT()
val, _ := sqltypes.NewValue(sqltypes.VarBinary, data)
result[aggr.Col] = val
case AggregateAnyValue:
// we just grab the first value per grouping. no need to do anything more complicated here
case AggregateGroupConcat:
if row2[aggr.Col].IsNull() {
break
}
if result[aggr.Col].IsNull() {
result[aggr.Col] = sqltypes.MakeTrusted(fields[aggr.Col].Type, []byte(row2[aggr.Col].ToString()))
break
}
concat := row1[aggr.Col].ToString() + "," + row2[aggr.Col].ToString()
result[aggr.Col] = sqltypes.MakeTrusted(fields[aggr.Col].Type, []byte(concat))
default:
return nil, nil, fmt.Errorf("BUG: Unexpected opcode: %v", aggr.Opcode)
}
if err != nil {
return nil, nil, err
}
}
return result, curDistincts, nil
}

func convertFinal(current []sqltypes.Value, aggregates []*AggregateParams) ([]sqltypes.Value, error) {
result := sqltypes.CopyRow(current)
for _, aggr := range aggregates {
switch aggr.Opcode {
case AggregateGtid:
vgtid := &binlogdatapb.VGtid{}
currentBytes, err := current[aggr.Col].ToBytes()
if err != nil {
return nil, err
}
err = vgtid.UnmarshalVT(currentBytes)
if err != nil {
return nil, err
}
result[aggr.Col] = sqltypes.NewVarChar(vgtid.String())
}
}
return result, nil
}

func convertFields(fields []*querypb.Field, aggrs []*AggregateParams) []*querypb.Field {
fields = slices2.Map(fields, func(from *querypb.Field) *querypb.Field {
return proto.Clone(from).(*querypb.Field)
})
for _, aggr := range aggrs {
fields[aggr.Col].Type = aggr.typ(fields[aggr.Col].Type)
if aggr.Alias != "" {
fields[aggr.Col].Name = aggr.Alias
}
if aggr.isDistinct() {
// TODO: this should move to plan time
aggr.KeyCol = aggr.Col
}
}
return fields
}

func findComparableCurrentDistinct(row []sqltypes.Value, aggr *AggregateParams) sqltypes.Value {
curDistinct := row[aggr.KeyCol]
if aggr.WAssigned && !curDistinct.IsComparable() {
aggr.KeyCol = aggr.WCol
curDistinct = row[aggr.KeyCol]
}
return curDistinct
}
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading