Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 117 additions & 13 deletions go/sqltypes/arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package sqltypes
import (
"bytes"
"fmt"
"math"

"strconv"

Expand All @@ -28,9 +27,6 @@ import (
"vitess.io/vitess/go/vt/vterrors"
)

// TODO(sougou): change these functions to be more permissive.
// Most string to number conversions should quietly convert to 0.

// numeric represents a numeric value extracted from
// a Value, used for arithmetic operations.
type numeric struct {
Expand All @@ -50,8 +46,14 @@ func Add(v1, v2 Value) (Value, error) {
}

lv1, err := newNumeric(v1)
if err != nil {
return NULL, err
}

lv2, err := newNumeric(v2)
if err != nil {
return NULL, err
}

lresult, err := addNumericWithError(lv1, lv2)
if err != nil {
Expand All @@ -61,6 +63,30 @@ func Add(v1, v2 Value) (Value, error) {
return castFromNumeric(lresult, lresult.typ), nil
}

// Subtract takes two values and subtracts them
func Subtract(v1, v2 Value) (Value, error) {
if v1.IsNull() || v2.IsNull() {
return NULL, nil
}

lv1, err := newNumeric(v1)
if err != nil {
return NULL, err
}

lv2, err := newNumeric(v2)
if err != nil {
return NULL, err
}

lresult, err := subtractNumericWithError(lv1, lv2)
if err != nil {
return NULL, err
}

return castFromNumeric(lresult, lresult.typ), nil
}

// NullsafeAdd adds two Values in a null-safe manner. A null value
// is treated as 0. If both values are null, then a null is returned.
// If both values are not null, a numeric value is built
Expand Down Expand Up @@ -243,7 +269,10 @@ func ToInt64(v Value) (int64, error) {

// ToFloat64 converts Value to float64.
func ToFloat64(v Value) (float64, error) {
num, _ := newNumeric(v)
num, err := newNumeric(v)
if err != nil {
return 0, err
}
switch num.typ {
case Int64:
return float64(num.ival), nil
Expand Down Expand Up @@ -373,7 +402,32 @@ func addNumericWithError(v1, v2 numeric) (numeric, error) {
return floatPlusAny(v1.fval, v2), nil
}
panic("unreachable")
}

func subtractNumericWithError(v1, v2 numeric) (numeric, error) {
switch v1.typ {
case Int64:
switch v2.typ {
case Int64:
return intMinusIntWithError(v1.ival, v2.ival)
case Uint64:
return intMinusUintWithError(v1.ival, v2.uval)
case Float64:
return anyMinusFloat(v1, v2.fval), nil
}
case Uint64:
switch v2.typ {
case Int64:
return uintMinusIntWithError(v1.uval, v2.ival)
case Uint64:
return uintMinusUintWithError(v1.uval, v2.uval)
case Float64:
return anyMinusFloat(v1, v2.fval), nil
}
case Float64:
return floatMinusAny(v1.fval, v2), nil
}
panic("unreachable")
}

// prioritize reorders the input parameters
Expand All @@ -388,7 +442,6 @@ func prioritize(v1, v2 numeric) (altv1, altv2 numeric) {
if v2.typ == Float64 {
return v2, v1
}

}
return v1, v2
}
Expand All @@ -415,36 +468,67 @@ func intPlusIntWithError(v1, v2 int64) (numeric, error) {
return numeric{typ: Int64, ival: result}, nil
}

func intMinusIntWithError(v1, v2 int64) (numeric, error) {
result := v1 - v2

if (result < v1) != (v2 > 0) {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v - %v", v1, v2)
}
return numeric{typ: Int64, ival: result}, nil
}

func intMinusUintWithError(v1 int64, v2 uint64) (numeric, error) {
if v1 < 0 || v1 < int64(v2) {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2)
}
return uintMinusUintWithError(uint64(v1), v2)
}

func uintPlusInt(v1 uint64, v2 int64) numeric {
return uintPlusUint(v1, uint64(v2))
}

func uintPlusIntWithError(v1 uint64, v2 int64) (numeric, error) {
if v2 >= math.MaxInt64 && v1 > 0 {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2)
if v2 < 0 && v1 < uint64(v2) {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2)
}

//convert to int -> uint is because for numeric operators (such as + or -)
//where one of the operands is an unsigned integer, the result is unsigned by default.
// convert to int -> uint is because for numeric operators (such as + or -)
// where one of the operands is an unsigned integer, the result is unsigned by default.
return uintPlusUintWithError(v1, uint64(v2))
}

func uintMinusIntWithError(v1 uint64, v2 int64) (numeric, error) {
if int64(v1) < v2 && v2 > 0 {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2)
}
// uint - (- int) = uint + int
if v2 < 0 {
return uintPlusIntWithError(v1, -v2)
}
return uintMinusUintWithError(v1, uint64(v2))
}

func uintPlusUint(v1, v2 uint64) numeric {
result := v1 + v2
if result < v2 {
return numeric{typ: Float64, fval: float64(v1) + float64(v2)}

}
return numeric{typ: Uint64, uval: result}
}

func uintPlusUintWithError(v1, v2 uint64) (numeric, error) {
result := v1 + v2

if result < v2 {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2)
}
return numeric{typ: Uint64, uval: result}, nil
}

func uintMinusUintWithError(v1, v2 uint64) (numeric, error) {
result := v1 - v2
if v2 > v1 {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2)
}
return numeric{typ: Uint64, uval: result}, nil
}

Expand All @@ -458,6 +542,26 @@ func floatPlusAny(v1 float64, v2 numeric) numeric {
return numeric{typ: Float64, fval: v1 + v2.fval}
}

func floatMinusAny(v1 float64, v2 numeric) numeric {
switch v2.typ {
case Int64:
v2.fval = float64(v2.ival)
case Uint64:
v2.fval = float64(v2.uval)
}
return numeric{typ: Float64, fval: v1 - v2.fval}
}

func anyMinusFloat(v1 numeric, v2 float64) numeric {
switch v1.typ {
case Int64:
v1.fval = float64(v1.ival)
case Uint64:
v1.fval = float64(v1.uval)
}
return numeric{typ: Float64, fval: v1.fval - v2}
}

func castFromNumeric(v numeric, resultType querypb.Type) Value {
switch {
case IsSigned(resultType):
Expand Down
Loading