diff --git a/go/sqltypes/arithmetic.go b/go/sqltypes/arithmetic.go index 3dc65188881..d264415994b 100644 --- a/go/sqltypes/arithmetic.go +++ b/go/sqltypes/arithmetic.go @@ -109,6 +109,37 @@ func Multiply(v1, v2 Value) (Value, error) { return castFromNumeric(lresult, lresult.typ), nil } +// Float Division for MySQL. Replicates behavior of "/" operator +func Divide(v1, v2 Value) (Value, error) { + if v1.IsNull() || v2.IsNull() { + return NULL, nil + } + + lv2AsFloat, err := ToFloat64(v2) + divisorIsZero := lv2AsFloat == 0 + + if divisorIsZero || err != nil { + return NULL, err + } + + lv1, err := newNumeric(v1) + if err != nil { + return NULL, err + } + + lv2, err := newNumeric(v2) + if err != nil { + return NULL, err + } + + lresult, err := divideNumericWithError(lv1, lv2) + if err != nil { + return NULL, err + } + + return castFromNumeric(lresult, lresult.typ), nil +} + // NullsafeAdd adds two Values in a null-safe manner. A null value // is treated as 0. If both values are null, then a null is returned. // If both values are not null, a numeric value is built @@ -470,6 +501,20 @@ func multiplyNumericWithError(v1, v2 numeric) (numeric, error) { panic("unreachable") } +func divideNumericWithError(v1, v2 numeric) (numeric, error) { + switch v1.typ { + case Int64: + return floatDivideAnyWithError(float64(v1.ival), v2) + + case Uint64: + return floatDivideAnyWithError(float64(v1.uval), v2) + + case Float64: + return floatDivideAnyWithError(v1.fval, v2) + } + panic("unreachable") +} + // prioritize reorders the input parameters // to be Float64, Uint64, Int64. func prioritize(v1, v2 numeric) (altv1, altv2 numeric) { @@ -523,6 +568,7 @@ func intTimesIntWithError(v1, v2 int64) (numeric, error) { return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v * %v", v1, v2) } return numeric{typ: Int64, ival: result}, nil + } func intMinusUintWithError(v1 int64, v2 uint64) (numeric, error) { @@ -626,6 +672,24 @@ func floatTimesAny(v1 float64, v2 numeric) numeric { return numeric{typ: Float64, fval: v1 * v2.fval} } +func floatDivideAnyWithError(v1 float64, v2 numeric) (numeric, error) { + switch v2.typ { + case Int64: + v2.fval = float64(v2.ival) + case Uint64: + v2.fval = float64(v2.uval) + } + result := v1 / v2.fval + divisorLessThanOne := v2.fval < 1 + resultMismatch := (v2.fval*result != v1) + + if divisorLessThanOne && resultMismatch { + return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in %v / %v", v1, v2.fval) + } + + return numeric{typ: Float64, fval: v1 / v2.fval}, nil +} + func anyMinusFloat(v1 numeric, v2 float64) numeric { switch v1.typ { case Int64: diff --git a/go/sqltypes/arithmetic_test.go b/go/sqltypes/arithmetic_test.go index fa5a93ddb96..dfda557ef19 100644 --- a/go/sqltypes/arithmetic_test.go +++ b/go/sqltypes/arithmetic_test.go @@ -29,6 +29,112 @@ import ( "vitess.io/vitess/go/vt/vterrors" ) +func TestDivide(t *testing.T) { + tcases := []struct { + v1, v2 Value + out Value + err error + }{{ + //All Nulls + v1: NULL, + v2: NULL, + out: NULL, + }, { + // First value null. + v1: NULL, + v2: NewInt32(1), + out: NULL, + }, { + // Second value null. + v1: NewInt32(1), + v2: NULL, + out: NULL, + }, { + // Second arg 0 + v1: NewInt32(5), + v2: NewInt32(0), + out: NULL, + }, { + // Both arguments zero + v1: NewInt32(0), + v2: NewInt32(0), + out: NULL, + }, { + // case with negative value + v1: NewInt64(-1), + v2: NewInt64(-2), + out: NewFloat64(0.5000), + }, { + // float64 division by zero + v1: NewFloat64(2), + v2: NewFloat64(0), + out: NULL, + }, { + // Lower bound for int64 + v1: NewInt64(math.MinInt64), + v2: NewInt64(1), + out: NewFloat64(math.MinInt64), + }, { + // upper bound for uint64 + v1: NewUint64(math.MaxUint64), + v2: NewUint64(1), + out: NewFloat64(math.MaxUint64), + }, { + // testing for error in types + v1: TestValue(Int64, "1.2"), + v2: NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for error in types + v1: NewInt64(2), + v2: TestValue(Int64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for uint/int + v1: NewUint64(4), + v2: NewInt64(5), + out: NewFloat64(0.8), + }, { + // testing for uint/uint + v1: NewUint64(1), + v2: NewUint64(2), + out: NewFloat64(0.5), + }, { + // testing for float64/int64 + v1: TestValue(Float64, "1.2"), + v2: NewInt64(-2), + out: NewFloat64(-0.6), + }, { + // testing for float64/uint64 + v1: TestValue(Float64, "1.2"), + v2: NewUint64(2), + out: NewFloat64(0.6), + }, { + // testing for overflow of float64 + v1: NewFloat64(math.MaxFloat64), + v2: NewFloat64(0.5), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in 1.7976931348623157e+308 / 0.5"), + }} + + for _, tcase := range tcases { + got, err := Divide(tcase.v1, tcase.v2) + + if !vterrors.Equals(err, tcase.err) { + t.Errorf("%v %v %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err)) + t.Errorf("Divide(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + } + + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("Divide(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + } + } + +} + func TestMultiply(t *testing.T) { tcases := []struct { v1, v2 Value