Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions go/sqltypes/arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,37 @@ func Multiply(v1, v2 Value) (Value, error) {
return castFromNumeric(lresult, lresult.typ), nil
}

// Float Division for MySQL. Replicates behavior of "/" operator
func Divide(v1, v2 Value) (Value, error) {
if v1.IsNull() || v2.IsNull() {
return NULL, nil
}

lv2AsFloat, err := ToFloat64(v2)
divisorIsZero := lv2AsFloat == 0

if divisorIsZero || err != nil {
return NULL, err
}

lv1, err := newNumeric(v1)
if err != nil {
return NULL, err
}

lv2, err := newNumeric(v2)
if err != nil {
return NULL, err
}

lresult, err := divideNumericWithError(lv1, lv2)
if err != nil {
return NULL, err
}

return castFromNumeric(lresult, lresult.typ), nil
}

// NullsafeAdd adds two Values in a null-safe manner. A null value
// is treated as 0. If both values are null, then a null is returned.
// If both values are not null, a numeric value is built
Expand Down Expand Up @@ -470,6 +501,20 @@ func multiplyNumericWithError(v1, v2 numeric) (numeric, error) {
panic("unreachable")
}

func divideNumericWithError(v1, v2 numeric) (numeric, error) {
switch v1.typ {
case Int64:
return floatDivideAnyWithError(float64(v1.ival), v2)

case Uint64:
return floatDivideAnyWithError(float64(v1.uval), v2)

case Float64:
return floatDivideAnyWithError(v1.fval, v2)
}
panic("unreachable")
}

// prioritize reorders the input parameters
// to be Float64, Uint64, Int64.
func prioritize(v1, v2 numeric) (altv1, altv2 numeric) {
Expand Down Expand Up @@ -523,6 +568,7 @@ func intTimesIntWithError(v1, v2 int64) (numeric, error) {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v * %v", v1, v2)
}
return numeric{typ: Int64, ival: result}, nil

}

func intMinusUintWithError(v1 int64, v2 uint64) (numeric, error) {
Expand Down Expand Up @@ -626,6 +672,24 @@ func floatTimesAny(v1 float64, v2 numeric) numeric {
return numeric{typ: Float64, fval: v1 * v2.fval}
}

func floatDivideAnyWithError(v1 float64, v2 numeric) (numeric, error) {
switch v2.typ {
case Int64:
v2.fval = float64(v2.ival)
case Uint64:
v2.fval = float64(v2.uval)
}
result := v1 / v2.fval
divisorLessThanOne := v2.fval < 1
resultMismatch := (v2.fval*result != v1)

if divisorLessThanOne && resultMismatch {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in %v / %v", v1, v2.fval)
}

return numeric{typ: Float64, fval: v1 / v2.fval}, nil
}

func anyMinusFloat(v1 numeric, v2 float64) numeric {
switch v1.typ {
case Int64:
Expand Down
106 changes: 106 additions & 0 deletions go/sqltypes/arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,112 @@ import (
"vitess.io/vitess/go/vt/vterrors"
)

func TestDivide(t *testing.T) {
tcases := []struct {
v1, v2 Value
out Value
err error
}{{
//All Nulls
v1: NULL,
v2: NULL,
out: NULL,
}, {
// First value null.
v1: NULL,
v2: NewInt32(1),
out: NULL,
}, {
// Second value null.
v1: NewInt32(1),
v2: NULL,
out: NULL,
}, {
// Second arg 0
v1: NewInt32(5),
v2: NewInt32(0),
out: NULL,
}, {
// Both arguments zero
v1: NewInt32(0),
v2: NewInt32(0),
out: NULL,
}, {
// case with negative value
v1: NewInt64(-1),
v2: NewInt64(-2),
out: NewFloat64(0.5000),
}, {
// float64 division by zero
v1: NewFloat64(2),
v2: NewFloat64(0),
out: NULL,
}, {
// Lower bound for int64
v1: NewInt64(math.MinInt64),
v2: NewInt64(1),
out: NewFloat64(math.MinInt64),
}, {
// upper bound for uint64
v1: NewUint64(math.MaxUint64),
v2: NewUint64(1),
out: NewFloat64(math.MaxUint64),
}, {
// testing for error in types
v1: TestValue(Int64, "1.2"),
v2: NewInt64(2),
err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"),
}, {
// testing for error in types
v1: NewInt64(2),
v2: TestValue(Int64, "1.2"),
err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"),
}, {
// testing for uint/int
v1: NewUint64(4),
v2: NewInt64(5),
out: NewFloat64(0.8),
}, {
// testing for uint/uint
v1: NewUint64(1),
v2: NewUint64(2),
out: NewFloat64(0.5),
}, {
// testing for float64/int64
v1: TestValue(Float64, "1.2"),
v2: NewInt64(-2),
out: NewFloat64(-0.6),
}, {
// testing for float64/uint64
v1: TestValue(Float64, "1.2"),
v2: NewUint64(2),
out: NewFloat64(0.6),
}, {
// testing for overflow of float64
v1: NewFloat64(math.MaxFloat64),
v2: NewFloat64(0.5),
err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in 1.7976931348623157e+308 / 0.5"),
}}

for _, tcase := range tcases {
got, err := Divide(tcase.v1, tcase.v2)

if !vterrors.Equals(err, tcase.err) {
t.Errorf("%v %v %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err))
t.Errorf("Divide(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err))
}

if tcase.err != nil {
continue
}

if !reflect.DeepEqual(got, tcase.out) {
t.Errorf("Divide(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out))
}
}

}

func TestMultiply(t *testing.T) {
tcases := []struct {
v1, v2 Value
Expand Down