diff --git a/go/sqltypes/arithmetic.go b/go/sqltypes/arithmetic.go index cde014116eb..3dc65188881 100644 --- a/go/sqltypes/arithmetic.go +++ b/go/sqltypes/arithmetic.go @@ -87,6 +87,28 @@ func Subtract(v1, v2 Value) (Value, error) { return castFromNumeric(lresult, lresult.typ), nil } +// Multiply takes two values and multiplies it together +func Multiply(v1, v2 Value) (Value, error) { + if v1.IsNull() || v2.IsNull() { + return NULL, nil + } + + lv1, err := newNumeric(v1) + if err != nil { + return NULL, err + } + lv2, err := newNumeric(v2) + if err != nil { + return NULL, err + } + lresult, err := multiplyNumericWithError(lv1, lv2) + if err != nil { + return NULL, err + } + + return castFromNumeric(lresult, lresult.typ), nil +} + // NullsafeAdd adds two Values in a null-safe manner. A null value // is treated as 0. If both values are null, then a null is returned. // If both values are not null, a numeric value is built @@ -430,6 +452,24 @@ func subtractNumericWithError(v1, v2 numeric) (numeric, error) { panic("unreachable") } +func multiplyNumericWithError(v1, v2 numeric) (numeric, error) { + v1, v2 = prioritize(v1, v2) + switch v1.typ { + case Int64: + return intTimesIntWithError(v1.ival, v2.ival) + case Uint64: + switch v2.typ { + case Int64: + return uintTimesIntWithError(v1.uval, v2.ival) + case Uint64: + return uintTimesUintWithError(v1.uval, v2.uval) + } + case Float64: + return floatTimesAny(v1.fval, v2), nil + } + panic("unreachable") +} + // prioritize reorders the input parameters // to be Float64, Uint64, Int64. func prioritize(v1, v2 numeric) (altv1, altv2 numeric) { @@ -477,6 +517,14 @@ func intMinusIntWithError(v1, v2 int64) (numeric, error) { return numeric{typ: Int64, ival: result}, nil } +func intTimesIntWithError(v1, v2 int64) (numeric, error) { + result := v1 * v2 + if v1 != 0 && result/v1 != v2 { + return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v * %v", v1, v2) + } + return numeric{typ: Int64, ival: result}, nil +} + func intMinusUintWithError(v1 int64, v2 uint64) (numeric, error) { if v1 < 0 || v1 < int64(v2) { return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) @@ -508,6 +556,13 @@ func uintMinusIntWithError(v1 uint64, v2 int64) (numeric, error) { return uintMinusUintWithError(v1, uint64(v2)) } +func uintTimesIntWithError(v1 uint64, v2 int64) (numeric, error) { + if v2 < 0 || int64(v1) < 0 { + return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) + } + return uintTimesUintWithError(v1, uint64(v2)) +} + func uintPlusUint(v1, v2 uint64) numeric { result := v1 + v2 if result < v2 { @@ -529,6 +584,15 @@ func uintMinusUintWithError(v1, v2 uint64) (numeric, error) { if v2 > v1 { return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) } + + return numeric{typ: Uint64, uval: result}, nil +} + +func uintTimesUintWithError(v1, v2 uint64) (numeric, error) { + result := v1 * v2 + if result < v2 || result < v1 { + return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) + } return numeric{typ: Uint64, uval: result}, nil } @@ -552,6 +616,16 @@ func floatMinusAny(v1 float64, v2 numeric) numeric { return numeric{typ: Float64, fval: v1 - v2.fval} } +func floatTimesAny(v1 float64, v2 numeric) numeric { + switch v2.typ { + case Int64: + v2.fval = float64(v2.ival) + case Uint64: + v2.fval = float64(v2.uval) + } + return numeric{typ: Float64, fval: v1 * v2.fval} +} + func anyMinusFloat(v1 numeric, v2 float64) numeric { switch v1.typ { case Int64: diff --git a/go/sqltypes/arithmetic_test.go b/go/sqltypes/arithmetic_test.go index 227b20c7d4c..fa5a93ddb96 100644 --- a/go/sqltypes/arithmetic_test.go +++ b/go/sqltypes/arithmetic_test.go @@ -29,6 +29,105 @@ import ( "vitess.io/vitess/go/vt/vterrors" ) +func TestMultiply(t *testing.T) { + tcases := []struct { + v1, v2 Value + out Value + err error + }{{ + //All Nulls + v1: NULL, + v2: NULL, + out: NULL, + }, { + // First value null. + v1: NewInt32(1), + v2: NULL, + out: NULL, + }, { + // Second value null. + v1: NULL, + v2: NewInt32(1), + out: NULL, + }, { + // case with negative value + v1: NewInt64(-1), + v2: NewInt64(-2), + out: NewInt64(2), + }, { + // testing for int64 overflow with min negative value + v1: NewInt64(math.MinInt64), + v2: NewInt64(1), + out: NewInt64(math.MinInt64), + }, { + // testing for error in types + v1: TestValue(Int64, "1.2"), + v2: NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for error in types + v1: NewInt64(2), + v2: TestValue(Int64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for uint*int + v1: NewUint64(4), + v2: NewInt64(5), + out: NewUint64(20), + }, { + // testing for uint*uint + v1: NewUint64(1), + v2: NewUint64(2), + out: NewUint64(2), + }, { + // testing for float64*int64 + v1: TestValue(Float64, "1.2"), + v2: NewInt64(-2), + out: NewFloat64(-2.4), + }, { + // testing for float64*uint64 + v1: TestValue(Float64, "1.2"), + v2: NewUint64(2), + out: NewFloat64(2.4), + }, { + // testing for overflow of int64 + v1: NewInt64(math.MaxInt64), + v2: NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in 9223372036854775807 * 2"), + }, { + // testing for underflow of uint64*max.uint64 + v1: NewInt64(2), + v2: NewUint64(math.MaxUint64), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 * 2"), + }, { + v1: NewUint64(math.MaxUint64), + v2: NewUint64(1), + out: NewUint64(math.MaxUint64), + }, { + //Checking whether maxInt value can be passed as uint value + v1: NewUint64(math.MaxInt64), + v2: NewInt64(3), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 9223372036854775807 * 3"), + }} + + for _, tcase := range tcases { + + got, err := Multiply(tcase.v1, tcase.v2) + + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Multiply(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("Multiply(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + } + } + +} + func TestSubtract(t *testing.T) { tcases := []struct { v1, v2 Value