Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions go/vt/vtgate/engine/ordered_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ package engine
import (
"fmt"

"vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"

"vitess.io/vitess/go/sqltypes"

querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -179,6 +182,17 @@ func (oa *OrderedAggregate) execute(vcursor VCursor, bindVars map[string]*queryp
out.Rows = append(out.Rows, current)
current, curDistinct = oa.convertRow(row)
}

if len(result.Rows) == 0 && len(oa.Keys) == 0 {
// When doing aggregation without grouping keys, we need to produce a single row containing zero-value for the
// different aggregation functions
row, err := oa.createEmptyRow()
if err != nil {
return nil, err
}
out.Rows = append(out.Rows, row)
}

if current != nil {
out.Rows = append(out.Rows, current)
}
Expand Down Expand Up @@ -345,3 +359,26 @@ func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes
}
return result, curDistinct, nil
}

// creates the empty row for the case when we are missing grouping keys and have empty input table
func (oa *OrderedAggregate) createEmptyRow() ([]sqltypes.Value, error) {
out := make([]sqltypes.Value, len(oa.Aggregates))
for i, aggr := range oa.Aggregates {
value, err := createEmptyValueFor(aggr.Opcode)
if err != nil {
return nil, err
}
out[i] = value
}
return out, nil
}

func createEmptyValueFor(opcode AggregateOpcode) (sqltypes.Value, error) {
switch opcode {
case AggregateCountDistinct:
return sqltypes.NULL, nil
Copy link
Copy Markdown
Collaborator

@mpawliszyn mpawliszyn Aug 21, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Crap @systay I think these are reversed no? We want count to always return non null.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is a PR to fix it:
#5121

case AggregateSumDistinct:
return sumZero, nil
}
return sqltypes.NULL, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "unknown aggregation %v", opcode)
}
149 changes: 76 additions & 73 deletions go/vt/vtgate/engine/ordered_aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ package engine

import (
"errors"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
"vitess.io/vitess/go/sqltypes"
)

func TestOrderedAggregateExecute(t *testing.T) {
assert := assert.New(t)
fields := sqltypes.MakeTestFields(
"col|count(*)",
"varbinary|decimal",
Expand All @@ -50,22 +51,19 @@ func TestOrderedAggregateExecute(t *testing.T) {
}

result, err := oa.Execute(nil, nil, false)
if err != nil {
t.Error(err)
}
assert.NoError(err)

wantResult := sqltypes.MakeTestResult(
fields,
"a|2",
"b|2",
"c|7",
)
if !reflect.DeepEqual(result, wantResult) {
t.Errorf("oa.Execute:\n%v, want\n%v", result, wantResult)
}
assert.Equal(wantResult, result)
}

func TestOrderedAggregateExecuteTruncate(t *testing.T) {
assert := assert.New(t)
fp := &fakePrimitive{
results: []*sqltypes.Result{sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
Expand All @@ -91,9 +89,7 @@ func TestOrderedAggregateExecuteTruncate(t *testing.T) {
}

result, err := oa.Execute(nil, nil, false)
if err != nil {
t.Error(err)
}
assert.NoError(err)

wantResult := sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
Expand All @@ -104,12 +100,11 @@ func TestOrderedAggregateExecuteTruncate(t *testing.T) {
"b|2",
"C|7",
)
if !reflect.DeepEqual(result, wantResult) {
t.Errorf("oa.Execute:\n%v, want\n%v", result, wantResult)
}
assert.Equal(wantResult, result)
}

func TestOrderedAggregateStreamExecute(t *testing.T) {
assert := assert.New(t)
fields := sqltypes.MakeTestFields(
"col|count(*)",
"varbinary|decimal",
Expand Down Expand Up @@ -139,9 +134,7 @@ func TestOrderedAggregateStreamExecute(t *testing.T) {
results = append(results, qr)
return nil
})
if err != nil {
t.Error(err)
}
assert.NoError(err)

wantResults := sqltypes.MakeTestStreamingResults(
fields,
Expand All @@ -151,12 +144,11 @@ func TestOrderedAggregateStreamExecute(t *testing.T) {
"---",
"c|7",
)
if !reflect.DeepEqual(results, wantResults) {
t.Errorf("oa.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults))
}
assert.Equal(wantResults, results)
}

func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) {
assert := assert.New(t)
fp := &fakePrimitive{
results: []*sqltypes.Result{sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
Expand Down Expand Up @@ -186,9 +178,7 @@ func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) {
results = append(results, qr)
return nil
})
if err != nil {
t.Error(err)
}
assert.NoError(err)

wantResults := sqltypes.MakeTestStreamingResults(
sqltypes.MakeTestFields(
Expand All @@ -201,32 +191,28 @@ func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) {
"---",
"C|7",
)
if !reflect.DeepEqual(results, wantResults) {
t.Errorf("oa.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults))
}
assert.Equal(wantResults, results)
}

func TestOrderedAggregateGetFields(t *testing.T) {
result := sqltypes.MakeTestResult(
assert := assert.New(t)
input := sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col|count(*)",
"varbinary|decimal",
),
)
fp := &fakePrimitive{results: []*sqltypes.Result{result}}
fp := &fakePrimitive{results: []*sqltypes.Result{input}}

oa := &OrderedAggregate{Input: fp}

got, err := oa.GetFields(nil, nil)
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(got, result) {
t.Errorf("oa.GetFields:\n%v, want\n%v", got, result)
}
assert.NoError(err)
assert.Equal(got, input)
}

func TestOrderedAggregateGetFieldsTruncate(t *testing.T) {
assert := assert.New(t)
result := sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col|count(*)|weight_string(col)",
Expand All @@ -241,18 +227,14 @@ func TestOrderedAggregateGetFieldsTruncate(t *testing.T) {
}

got, err := oa.GetFields(nil, nil)
if err != nil {
t.Error(err)
}
assert.NoError(err)
wantResult := sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col|count(*)",
"varchar|decimal",
),
)
if !reflect.DeepEqual(got, wantResult) {
t.Errorf("oa.GetFields:\n%v, want\n%v", got, wantResult)
}
assert.Equal(wantResult, got)
}

func TestOrderedAggregateInputFail(t *testing.T) {
Expand All @@ -277,6 +259,7 @@ func TestOrderedAggregateInputFail(t *testing.T) {
}

func TestOrderedAggregateExecuteCountDistinct(t *testing.T) {
assert := assert.New(t)
fp := &fakePrimitive{
results: []*sqltypes.Result{sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
Expand Down Expand Up @@ -331,9 +314,7 @@ func TestOrderedAggregateExecuteCountDistinct(t *testing.T) {
}

result, err := oa.Execute(nil, nil, false)
if err != nil {
t.Error(err)
}
assert.NoError(err)

wantResult := sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
Expand All @@ -350,12 +331,11 @@ func TestOrderedAggregateExecuteCountDistinct(t *testing.T) {
"h|3|4",
"i|2|2",
)
if !reflect.DeepEqual(result, wantResult) {
t.Errorf("oa.Execute:\n%v, want\n%v", result, wantResult)
}
assert.Equal(wantResult, result)
}

func TestOrderedAggregateStreamCountDistinct(t *testing.T) {
assert := assert.New(t)
fp := &fakePrimitive{
results: []*sqltypes.Result{sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
Expand Down Expand Up @@ -414,9 +394,7 @@ func TestOrderedAggregateStreamCountDistinct(t *testing.T) {
results = append(results, qr)
return nil
})
if err != nil {
t.Error(err)
}
assert.NoError(err)

wantResults := sqltypes.MakeTestStreamingResults(
sqltypes.MakeTestFields(
Expand All @@ -441,12 +419,11 @@ func TestOrderedAggregateStreamCountDistinct(t *testing.T) {
"-----",
"i|2|2",
)
if !reflect.DeepEqual(results, wantResults) {
t.Errorf("oa.Execute:\n%v, want\n%v", results, wantResults)
}
assert.Equal(wantResults, results)
}

func TestOrderedAggregateSumDistinctGood(t *testing.T) {
assert := assert.New(t)
fp := &fakePrimitive{
results: []*sqltypes.Result{sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
Expand Down Expand Up @@ -501,9 +478,7 @@ func TestOrderedAggregateSumDistinctGood(t *testing.T) {
}

result, err := oa.Execute(nil, nil, false)
if err != nil {
t.Error(err)
}
assert.NoError(err)

wantResult := sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
Expand All @@ -520,12 +495,11 @@ func TestOrderedAggregateSumDistinctGood(t *testing.T) {
"h|6|4",
"i|7|2",
)
if !reflect.DeepEqual(result, wantResult) {
t.Errorf("oa.Execute:\n%v, want\n%v", result, wantResult)
}
assert.Equal(wantResult, result)
}

func TestOrderedAggregateSumDistinctTolerateError(t *testing.T) {
assert := assert.New(t)
fp := &fakePrimitive{
results: []*sqltypes.Result{sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
Expand All @@ -550,9 +524,7 @@ func TestOrderedAggregateSumDistinctTolerateError(t *testing.T) {
}

result, err := oa.Execute(nil, nil, false)
if err != nil {
t.Error(err)
}
assert.NoError(err)

wantResult := sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
Expand All @@ -561,9 +533,7 @@ func TestOrderedAggregateSumDistinctTolerateError(t *testing.T) {
),
"a|1",
)
if !reflect.DeepEqual(result, wantResult) {
t.Errorf("oa.Execute:\n%v, want\n%v", result, wantResult)
}
assert.Equal(wantResult, result)
}

func TestOrderedAggregateKeysFail(t *testing.T) {
Expand Down Expand Up @@ -633,6 +603,7 @@ func TestOrderedAggregateMergeFail(t *testing.T) {
}

func TestMerge(t *testing.T) {
assert := assert.New(t)
oa := &OrderedAggregate{
Aggregates: []AggregateParams{{
Opcode: AggregateCount,
Expand All @@ -658,20 +629,52 @@ func TestMerge(t *testing.T) {
)

merged, _, err := oa.merge(fields, r.Rows[0], r.Rows[1], sqltypes.NULL)
if err != nil {
t.Error(err)
}
assert.NoError(err)
want := sqltypes.MakeTestResult(fields, "1|5|6|2|bc").Rows[0]
if !reflect.DeepEqual(merged, want) {
t.Errorf("oa.merge(row1, row2): %v, want %v", merged, want)
}
assert.Equal(want, merged)

// swap and retry
merged, _, err = oa.merge(fields, r.Rows[1], r.Rows[0], sqltypes.NULL)
if err != nil {
t.Error(err)
assert.NoError(err)
assert.Equal(want, merged)
}

func TestNoInputAndNoGroupingKeys(t *testing.T) {
assert := assert.New(t)
fp := &fakePrimitive{
results: []*sqltypes.Result{sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2",
"int64|int64",
),
// Empty input table
)},
}
if !reflect.DeepEqual(merged, want) {
t.Errorf("oa.merge(row1, row2): %v, want %v", merged, want)

oa := &OrderedAggregate{
HasDistinct: true,
Aggregates: []AggregateParams{{
Opcode: AggregateCountDistinct,
Col: 0,
Alias: "count(distinct col2)",
}, {
Opcode: AggregateSumDistinct,
Col: 1,
Alias: "sum(distinct col2)",
}},
Keys: []int{},
Input: fp,
}

result, err := oa.Execute(nil, nil, false)
assert.NoError(err)

wantResult := sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"count(distinct col2)|sum(distinct col2)",
"int64|decimal",
),
"null|0",
)
assert.Equal(wantResult, result)
}
Loading