diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index 9e652727ec8..db4eb1df219 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -255,21 +255,21 @@ func (oa *OrderedAggregate) StreamExecute(vcursor VCursor, bindVars map[string]* return nil } -func (oa *OrderedAggregate) convertFields(fields []*querypb.Field) (newFields []*querypb.Field) { +func (oa *OrderedAggregate) convertFields(fields []*querypb.Field) []*querypb.Field { if !oa.HasDistinct { return fields } - newFields = append(newFields, fields...) + for _, aggr := range oa.Aggregates { if !aggr.isDistinct() { continue } - newFields[aggr.Col] = &querypb.Field{ + fields[aggr.Col] = &querypb.Field{ Name: aggr.Alias, Type: opcodeType[aggr.Opcode], } } - return newFields + return fields } func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes.Value, curDistinct sqltypes.Value) { @@ -375,10 +375,17 @@ func (oa *OrderedAggregate) createEmptyRow() ([]sqltypes.Value, error) { func createEmptyValueFor(opcode AggregateOpcode) (sqltypes.Value, error) { switch opcode { - case AggregateCountDistinct: + case + AggregateCountDistinct, + AggregateCount: return countZero, nil - case AggregateSumDistinct: + case + AggregateSumDistinct, + AggregateSum, + AggregateMin, + AggregateMax: return sqltypes.NULL, nil + } return sqltypes.NULL, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "unknown aggregation %v", opcode) } diff --git a/go/vt/vtgate/engine/ordered_aggregate_test.go b/go/vt/vtgate/engine/ordered_aggregate_test.go index d4201fad66c..9028b83d11a 100644 --- a/go/vt/vtgate/engine/ordered_aggregate_test.go +++ b/go/vt/vtgate/engine/ordered_aggregate_test.go @@ -666,42 +666,79 @@ func TestMerge(t *testing.T) { assert.Equal(want, merged) } -func TestNoInputAndNoGroupingKeys(t *testing.T) { - assert := assert.New(t) - fp := &fakePrimitive{ - results: []*sqltypes.Result{sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col1|col2", - "int64|int64", - ), - // Empty input table - )}, +func TestNoInputAndNoGroupingKeys(outer *testing.T) { + testCases := []struct { + name string + opcode AggregateOpcode + expectedVal string + expectedTyp string + }{{ + "count(distinct col1)", + AggregateCountDistinct, + "0", + "int64", + }, { + "col1", + AggregateCount, + "0", + "int64", + }, { + "sum(distinct col1)", + AggregateSumDistinct, + "null", + "decimal", + }, { + "col1", + AggregateSum, + "null", + "int64", + }, { + "col1", + AggregateMax, + "null", + "int64", + }, { + "col1", + AggregateMin, + "null", + "int64", + }} + + for _, test := range testCases { + outer.Run(test.name, func(t *testing.T) { + assert := assert.New(t) + fp := &fakePrimitive{ + results: []*sqltypes.Result{sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1", + "int64", + ), + // Empty input table + )}, + } + + oa := &OrderedAggregate{ + HasDistinct: true, + Aggregates: []AggregateParams{{ + Opcode: test.opcode, + Col: 0, + Alias: test.name, + }}, + Keys: []int{}, + Input: fp, + } + + result, err := oa.Execute(nil, nil, false) + assert.NoError(err) + + wantResult := sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + test.name, + test.expectedTyp, + ), + test.expectedVal, + ) + assert.Equal(wantResult, result) + }) } - - oa := &OrderedAggregate{ - HasDistinct: true, - Aggregates: []AggregateParams{{ - Opcode: AggregateCountDistinct, - Col: 0, - Alias: "count(distinct col2)", - }, { - Opcode: AggregateSumDistinct, - Col: 1, - Alias: "sum(distinct col2)", - }}, - Keys: []int{}, - Input: fp, - } - - result, err := oa.Execute(nil, nil, false) - assert.NoError(err) - - wantResult := sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "count(distinct col2)|sum(distinct col2)", - "int64|decimal", - ), - "0|null", - ) - assert.Equal(wantResult, result) }