diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index 1e7daa6571..c9a3aa47c9 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -90,18 +90,41 @@ func (m *sumBuffer) PerformSum(ctx *sql.Context, v interface{}) { v = val } } - switch n := v.(type) { + case float64: + if m.isnil { + m.sum = float64(0) + m.isnil = false + } + switch sum := m.sum.(type) { + case float64: + case decimal.Decimal: + m.sum, _ = sum.Float64() + default: + var err error + m.sum, _, err = types.Float64.Convert(ctx, sum) + if err != nil { + m.sum = float64(0) + } + } + m.sum = m.sum.(float64) + n case decimal.Decimal: if m.isnil { m.sum = decimal.NewFromInt(0) m.isnil = false } - if sum, ok := m.sum.(decimal.Decimal); ok { - m.sum = sum.Add(n) - } else { - m.sum = decimal.NewFromFloat(m.sum.(float64)).Add(n) + switch sum := m.sum.(type) { + case decimal.Decimal: + case float64: + m.sum = decimal.NewFromFloat(sum) + default: + var err error + m.sum, _, err = types.InternalDecimalType.Convert(ctx, sum) + if err != nil { + m.sum = decimal.NewFromInt(0) + } } + m.sum = m.sum.(decimal.Decimal).Add(n) default: val, _, err := types.Float64.Convert(ctx, n) if err != nil { @@ -111,11 +134,17 @@ func (m *sumBuffer) PerformSum(ctx *sql.Context, v interface{}) { m.sum = float64(0) m.isnil = false } - sum, _, err := types.Float64.Convert(ctx, m.sum) - if err != nil { - sum = float64(0) + switch sum := m.sum.(type) { + case float64: + case decimal.Decimal: + m.sum, _ = sum.Float64() + default: + sum, _, err = types.Float64.Convert(ctx, sum) + if err != nil { + sum = float64(0) + } } - m.sum = sum.(float64) + val.(float64) + m.sum = m.sum.(float64) + val.(float64) } }