diff --git a/go/sqltypes/arithmetic.go b/go/sqltypes/arithmetic.go index 516c2c9f220..f5c18dd43a1 100644 --- a/go/sqltypes/arithmetic.go +++ b/go/sqltypes/arithmetic.go @@ -19,6 +19,8 @@ package sqltypes import ( "bytes" "fmt" + "math" + "strconv" querypb "vitess.io/vitess/go/vt/proto/query" @@ -40,6 +42,25 @@ type numeric struct { var zeroBytes = []byte("0") +// Add adds two values together +// if v1 or v2 is null, then it returns null +func Add(v1, v2 Value) (Value, error) { + if v1.IsNull() || v2.IsNull() { + return NULL, nil + } + + lv1, err := newNumeric(v1) + + lv2, err := newNumeric(v2) + + lresult, err := addNumericWithError(lv1, lv2) + if err != nil { + return NULL, err + } + + return castFromNumeric(lresult, lresult.typ), nil +} + // NullsafeAdd adds two Values in a null-safe manner. A null value // is treated as 0. If both values are null, then a null is returned. // If both values are not null, a numeric value is built @@ -51,7 +72,7 @@ var zeroBytes = []byte("0") // addition, if one of the input types was Decimal, then // a Decimal is built. Otherwise, the final type of the // result is preserved. -func NullsafeAdd(v1, v2 Value, resultType querypb.Type) (Value, error) { +func NullsafeAdd(v1, v2 Value, resultType querypb.Type) Value { if v1.IsNull() { v1 = MakeTrusted(resultType, zeroBytes) } @@ -61,16 +82,14 @@ func NullsafeAdd(v1, v2 Value, resultType querypb.Type) (Value, error) { lv1, err := newNumeric(v1) if err != nil { - return NULL, err + return NULL } lv2, err := newNumeric(v2) if err != nil { - return NULL, err - } - lresult, err := addNumeric(lv1, lv2) - if err != nil { - return NULL, err + return NULL } + lresult := addNumeric(lv1, lv2) + return castFromNumeric(lresult, resultType) } @@ -224,10 +243,7 @@ func ToInt64(v Value) (int64, error) { // ToFloat64 converts Value to float64. func ToFloat64(v Value) (float64, error) { - num, err := newNumeric(v) - if err != nil { - return 0, err - } + num, _ := newNumeric(v) switch num.typ { case Int64: return float64(num.ival), nil @@ -292,7 +308,7 @@ func newNumeric(v Value) (numeric, error) { if fval, err := strconv.ParseFloat(str, 64); err == nil { return numeric{fval: fval, typ: Float64}, nil } - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", str) + return numeric{ival: 0, typ: Int64}, nil } // newIntegralNumeric parses a value and produces an Int64 or Uint64. @@ -323,22 +339,41 @@ func newIntegralNumeric(v Value) (numeric, error) { return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", str) } -func addNumeric(v1, v2 numeric) (numeric, error) { +func addNumeric(v1, v2 numeric) numeric { v1, v2 = prioritize(v1, v2) switch v1.typ { case Int64: - return intPlusInt(v1.ival, v2.ival), nil + return intPlusInt(v1.ival, v2.ival) case Uint64: switch v2.typ { case Int64: return uintPlusInt(v1.uval, v2.ival) case Uint64: - return uintPlusUint(v1.uval, v2.uval), nil + return uintPlusUint(v1.uval, v2.uval) + } + case Float64: + return floatPlusAny(v1.fval, v2) + } + panic("unreachable") +} + +func addNumericWithError(v1, v2 numeric) (numeric, error) { + v1, v2 = prioritize(v1, v2) + switch v1.typ { + case Int64: + return intPlusIntWithError(v1.ival, v2.ival) + case Uint64: + switch v2.typ { + case Int64: + return uintPlusIntWithError(v1.uval, v2.ival) + case Uint64: + return uintPlusUintWithError(v1.uval, v2.uval) } case Float64: return floatPlusAny(v1.fval, v2), nil } panic("unreachable") + } // prioritize reorders the input parameters @@ -353,6 +388,7 @@ func prioritize(v1, v2 numeric) (altv1, altv2 numeric) { if v2.typ == Float64 { return v2, v1 } + } return v1, v2 } @@ -371,21 +407,47 @@ overflow: return numeric{typ: Float64, fval: float64(v1) + float64(v2)} } -func uintPlusInt(v1 uint64, v2 int64) (numeric, error) { - if v2 < 0 { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "cannot add a negative number to an unsigned integer: %d, %d", v1, v2) +func intPlusIntWithError(v1, v2 int64) (numeric, error) { + result := v1 + v2 + if (result > v1) != (v2 > 0) { + return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2) } - return uintPlusUint(v1, uint64(v2)), nil + return numeric{typ: Int64, ival: result}, nil +} + +func uintPlusInt(v1 uint64, v2 int64) numeric { + return uintPlusUint(v1, uint64(v2)) +} + +func uintPlusIntWithError(v1 uint64, v2 int64) (numeric, error) { + if v2 >= math.MaxInt64 && v1 > 0 { + return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2) + } + + //convert to int -> uint is because for numeric operators (such as + or -) + //where one of the operands is an unsigned integer, the result is unsigned by default. + return uintPlusUintWithError(v1, uint64(v2)) } func uintPlusUint(v1, v2 uint64) numeric { result := v1 + v2 if result < v2 { return numeric{typ: Float64, fval: float64(v1) + float64(v2)} + } return numeric{typ: Uint64, uval: result} } +func uintPlusUintWithError(v1, v2 uint64) (numeric, error) { + result := v1 + v2 + + if result < v2 { + return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) + } + + return numeric{typ: Uint64, uval: result}, nil +} + func floatPlusAny(v1 float64, v2 numeric) numeric { switch v2.typ { case Int64: @@ -396,37 +458,43 @@ func floatPlusAny(v1 float64, v2 numeric) numeric { return numeric{typ: Float64, fval: v1 + v2.fval} } -func castFromNumeric(v numeric, resultType querypb.Type) (Value, error) { +func castFromNumeric(v numeric, resultType querypb.Type) Value { switch { case IsSigned(resultType): switch v.typ { case Int64: - return MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10)), nil - case Uint64, Float64: - return NULL, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion: %v to %v", v.typ, resultType) + return MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10)) + case Uint64: + return MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.uval), 10)) + case Float64: + return MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.fval), 10)) + } case IsUnsigned(resultType): switch v.typ { case Uint64: - return MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10)), nil - case Int64, Float64: - return NULL, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion: %v to %v", v.typ, resultType) + return MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10)) + case Int64: + return MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.ival), 10)) + case Float64: + return MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.fval), 10)) + } case IsFloat(resultType) || resultType == Decimal: switch v.typ { case Int64: - return MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10)), nil + return MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10)) case Uint64: - return MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10)), nil + return MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10)) case Float64: format := byte('g') if resultType == Decimal { format = 'f' } - return MakeTrusted(resultType, strconv.AppendFloat(nil, v.fval, format, -1, 64)), nil + return MakeTrusted(resultType, strconv.AppendFloat(nil, v.fval, format, -1, 64)) } } - return NULL, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion to non-numeric: %v", resultType) + return NULL } func compareNumeric(v1, v2 numeric) int { diff --git a/go/sqltypes/arithmetic_test.go b/go/sqltypes/arithmetic_test.go index 458a5944dd9..81aeba0cf30 100644 --- a/go/sqltypes/arithmetic_test.go +++ b/go/sqltypes/arithmetic_test.go @@ -19,6 +19,7 @@ package sqltypes import ( "encoding/binary" "fmt" + "math" "reflect" "strconv" "testing" @@ -29,6 +30,111 @@ import ( ) func TestAdd(t *testing.T) { + tcases := []struct { + v1, v2 Value + out Value + err error + }{{ + + //All Nulls + v1: NULL, + v2: NULL, + out: NULL, + }, { + // First value null. + v1: NewInt32(1), + v2: NULL, + out: NULL, + }, { + // Second value null. + v1: NULL, + v2: NewInt32(1), + out: NULL, + }, { + + // case with negatives + v1: NewInt64(-1), + v2: NewInt64(-2), + out: NewInt64(-3), + }, { + + // testing for overflow int64 + v1: NewInt64(math.MaxInt64), + v2: NewUint64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in 2 + 9223372036854775807"), + }, { + + v1: NewInt64(-2), + v2: NewUint64(1), + out: NewUint64(math.MaxUint64), + }, { + + v1: NewInt64(math.MaxInt64), + v2: NewInt64(-2), + out: NewInt64(9223372036854775805), + }, { + // Normal case + v1: NewUint64(1), + v2: NewUint64(2), + out: NewUint64(3), + }, { + // testing for overflow uint64 + v1: NewUint64(math.MaxUint64), + v2: NewUint64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2"), + }, { + + // int64 underflow + v1: NewInt64(math.MinInt64), + v2: NewInt64(-2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in -9223372036854775808 + -2"), + }, { + + // checking int64 max value can be returned + v1: NewInt64(math.MaxInt64), + v2: NewUint64(0), + out: NewUint64(9223372036854775807), + }, { + + // testing whether uint64 max value can be returned + v1: NewUint64(math.MaxUint64), + v2: NewInt64(0), + out: NewUint64(math.MaxUint64), + }, { + + v1: NewUint64(math.MaxInt64), + v2: NewInt64(1), + out: NewUint64(9223372036854775808), + }, { + + v1: NewUint64(1), + v2: TestValue(VarChar, "c"), + out: NewUint64(1), + }, { + v1: NewUint64(1), + v2: TestValue(VarChar, "1.2"), + out: NewFloat64(2.2), + }} + + for _, tcase := range tcases { + + got, err := Add(tcase.v1, tcase.v2) + + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Add(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("Addition(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + } + } + +} + +func TestNullsafeAdd(t *testing.T) { tcases := []struct { v1, v2 Value out Value @@ -67,21 +173,15 @@ func TestAdd(t *testing.T) { // Make sure underlying error is returned while adding. v1: NewInt64(-1), v2: NewUint64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "cannot add a negative number to an unsigned integer: 2, -1"), + out: NewInt64(-9223372036854775808), }, { // Make sure underlying error is returned while converting. v1: NewFloat64(1), v2: NewFloat64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion: FLOAT64 to INT64"), + out: NewInt64(3), }} for _, tcase := range tcases { - got, err := NullsafeAdd(tcase.v1, tcase.v2, Int64) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Add(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } + got := NullsafeAdd(tcase.v1, tcase.v2, Int64) if !reflect.DeepEqual(got, tcase.out) { t.Errorf("Add(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) @@ -346,7 +446,7 @@ func TestToFloat64(t *testing.T) { err error }{{ v: TestValue(VarChar, "abcd"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), + out: 0, }, { v: NewInt64(1), out: 1, @@ -515,7 +615,7 @@ func TestNewNumeric(t *testing.T) { err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseFloat: parsing \"abcd\": invalid syntax"), }, { v: TestValue(VarChar, "abcd"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), + out: numeric{typ: Float64, fval: 0}, }} for _, tcase := range tcases { got, err := newNumeric(tcase.v) @@ -623,7 +723,7 @@ func TestAddNumeric(t *testing.T) { }, { v1: numeric{typ: Int64, ival: -1}, v2: numeric{typ: Uint64, uval: 2}, - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "cannot add a negative number to an unsigned integer: 2, -1"), + out: numeric{typ: Float64, fval: 18446744073709551617}, }, { // Uint64 overflow. v1: numeric{typ: Uint64, uval: 18446744073709551615}, @@ -631,13 +731,7 @@ func TestAddNumeric(t *testing.T) { out: numeric{typ: Float64, fval: 18446744073709551617}, }} for _, tcase := range tcases { - got, err := addNumeric(tcase.v1, tcase.v2) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("addNumeric(%v, %v) error: %v, want %v", tcase.v1, tcase.v2, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } + got := addNumeric(tcase.v1, tcase.v2) if got != tcase.out { t.Errorf("addNumeric(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) @@ -705,15 +799,15 @@ func TestCastFromNumeric(t *testing.T) { }, { typ: Int64, v: numeric{typ: Uint64, uval: 1}, - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion: UINT64 to INT64"), + out: NewInt64(1), }, { typ: Int64, v: numeric{typ: Float64, fval: 1.2e-16}, - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion: FLOAT64 to INT64"), + out: NewInt64(0), }, { typ: Uint64, v: numeric{typ: Int64, ival: 1}, - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion: INT64 to UINT64"), + out: NewUint64(1), }, { typ: Uint64, v: numeric{typ: Uint64, uval: 1}, @@ -721,7 +815,7 @@ func TestCastFromNumeric(t *testing.T) { }, { typ: Uint64, v: numeric{typ: Float64, fval: 1.2e-16}, - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion: FLOAT64 to UINT64"), + out: NewUint64(0), }, { typ: Float64, v: numeric{typ: Int64, ival: 1}, @@ -753,13 +847,7 @@ func TestCastFromNumeric(t *testing.T) { err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion to non-numeric: VARBINARY"), }} for _, tcase := range tcases { - got, err := castFromNumeric(tcase.v, tcase.typ) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("castFromNumeric(%v, %v) error: %v, want %v", tcase.v, tcase.typ, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } + got := castFromNumeric(tcase.v, tcase.typ) if !reflect.DeepEqual(got, tcase.out) { t.Errorf("castFromNumeric(%v, %v): %v, want %v", tcase.v, tcase.typ, printValue(got), printValue(tcase.out)) @@ -1021,7 +1109,7 @@ func BenchmarkAddActual(b *testing.B) { v1 := MakeTrusted(Int64, []byte("1")) v2 := MakeTrusted(Int64, []byte("12")) for i := 0; i < b.N; i++ { - v1, _ = NullsafeAdd(v1, v2, Int64) + v1 = NullsafeAdd(v1, v2, Int64) } } diff --git a/go/sqltypes/value.go b/go/sqltypes/value.go index 39c91287e38..e53419154b5 100644 --- a/go/sqltypes/value.go +++ b/go/sqltypes/value.go @@ -89,9 +89,11 @@ func NewValue(typ querypb.Type, val []byte) (v Value, err error) { // comments. Other packages can also use the function to create // VarBinary or VarChar values. func MakeTrusted(typ querypb.Type, val []byte) Value { + if typ == Null { return NULL } + return Value{typ: typ, val: val} } diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index fa541903cbb..e44f552d5b4 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -341,15 +341,15 @@ func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes var err error switch aggr.Opcode { case AggregateCount, AggregateSum: - result[aggr.Col], err = sqltypes.NullsafeAdd(row1[aggr.Col], row2[aggr.Col], fields[aggr.Col].Type) + result[aggr.Col] = sqltypes.NullsafeAdd(row1[aggr.Col], row2[aggr.Col], fields[aggr.Col].Type) case AggregateMin: result[aggr.Col], err = sqltypes.Min(row1[aggr.Col], row2[aggr.Col]) case AggregateMax: result[aggr.Col], err = sqltypes.Max(row1[aggr.Col], row2[aggr.Col]) case AggregateCountDistinct: - result[aggr.Col], err = sqltypes.NullsafeAdd(row1[aggr.Col], countOne, opcodeType[aggr.Opcode]) + result[aggr.Col] = sqltypes.NullsafeAdd(row1[aggr.Col], countOne, opcodeType[aggr.Opcode]) case AggregateSumDistinct: - result[aggr.Col], err = sqltypes.NullsafeAdd(row1[aggr.Col], row2[aggr.Col], opcodeType[aggr.Opcode]) + result[aggr.Col] = sqltypes.NullsafeAdd(row1[aggr.Col], row2[aggr.Col], opcodeType[aggr.Opcode]) default: return nil, sqltypes.NULL, fmt.Errorf("BUG: Unexpected opcode: %v", aggr.Opcode) } diff --git a/go/vt/vtgate/engine/ordered_aggregate_test.go b/go/vt/vtgate/engine/ordered_aggregate_test.go index 5706761f031..547504a3b2a 100644 --- a/go/vt/vtgate/engine/ordered_aggregate_test.go +++ b/go/vt/vtgate/engine/ordered_aggregate_test.go @@ -18,10 +18,12 @@ package engine import ( "errors" + "reflect" "testing" "github.com/stretchr/testify/assert" "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" ) func TestOrderedAggregateExecute(t *testing.T) { @@ -591,15 +593,40 @@ func TestOrderedAggregateMergeFail(t *testing.T) { Input: fp, } - want := "could not parse value: 'b'" - if _, err := oa.Execute(nil, nil, false); err == nil || err.Error() != want { - t.Errorf("oa.Execute(): %v, want %s", err, want) + result := &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "col", + Type: querypb.Type_VARBINARY, + }, + { + Name: "count(*)", + Type: querypb.Type_DECIMAL, + }, + }, + Rows: [][]sqltypes.Value{ + { + sqltypes.MakeTrusted(querypb.Type_VARBINARY, []byte("a")), + sqltypes.MakeTrusted(querypb.Type_DECIMAL, []byte("1")), + }, + }, + RowsAffected: 1, + } + + res, err := oa.Execute(nil, nil, false) + if err != nil { + t.Errorf("oa.Execute() failed: %v", err) + } + + if !reflect.DeepEqual(res, result) { + t.Fatalf("Found mismatched values: want %v, got %v", result, res) } fp.rewind() - if err := oa.StreamExecute(nil, nil, false, func(_ *sqltypes.Result) error { return nil }); err == nil || err.Error() != want { - t.Errorf("oa.StreamExecute(): %v, want %s", err, want) + if err := oa.StreamExecute(nil, nil, false, func(_ *sqltypes.Result) error { return nil }); err != nil { + t.Errorf("oa.StreamExecute(): %v", err) } + } func TestMerge(t *testing.T) {