Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 98 additions & 30 deletions go/sqltypes/arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package sqltypes
import (
"bytes"
"fmt"
"math"

"strconv"

querypb "vitess.io/vitess/go/vt/proto/query"
Expand All @@ -40,6 +42,25 @@ type numeric struct {

var zeroBytes = []byte("0")

// Add adds two values together
// if v1 or v2 is null, then it returns null
func Add(v1, v2 Value) (Value, error) {
if v1.IsNull() || v2.IsNull() {
return NULL, nil
}

lv1, err := newNumeric(v1)

lv2, err := newNumeric(v2)

lresult, err := addNumericWithError(lv1, lv2)
if err != nil {
return NULL, err
}

return castFromNumeric(lresult, lresult.typ), nil
}

// NullsafeAdd adds two Values in a null-safe manner. A null value
// is treated as 0. If both values are null, then a null is returned.
// If both values are not null, a numeric value is built
Expand All @@ -51,7 +72,7 @@ var zeroBytes = []byte("0")
// addition, if one of the input types was Decimal, then
// a Decimal is built. Otherwise, the final type of the
// result is preserved.
func NullsafeAdd(v1, v2 Value, resultType querypb.Type) (Value, error) {
func NullsafeAdd(v1, v2 Value, resultType querypb.Type) Value {
if v1.IsNull() {
v1 = MakeTrusted(resultType, zeroBytes)
}
Expand All @@ -61,16 +82,14 @@ func NullsafeAdd(v1, v2 Value, resultType querypb.Type) (Value, error) {

lv1, err := newNumeric(v1)
if err != nil {
return NULL, err
return NULL
}
lv2, err := newNumeric(v2)
if err != nil {
return NULL, err
}
lresult, err := addNumeric(lv1, lv2)
if err != nil {
return NULL, err
return NULL
}
lresult := addNumeric(lv1, lv2)

return castFromNumeric(lresult, resultType)
}

Expand Down Expand Up @@ -224,10 +243,7 @@ func ToInt64(v Value) (int64, error) {

// ToFloat64 converts Value to float64.
func ToFloat64(v Value) (float64, error) {
num, err := newNumeric(v)
if err != nil {
return 0, err
}
num, _ := newNumeric(v)
switch num.typ {
case Int64:
return float64(num.ival), nil
Expand Down Expand Up @@ -292,7 +308,7 @@ func newNumeric(v Value) (numeric, error) {
if fval, err := strconv.ParseFloat(str, 64); err == nil {
return numeric{fval: fval, typ: Float64}, nil
}
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", str)
return numeric{ival: 0, typ: Int64}, nil
}

// newIntegralNumeric parses a value and produces an Int64 or Uint64.
Expand Down Expand Up @@ -323,22 +339,41 @@ func newIntegralNumeric(v Value) (numeric, error) {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", str)
}

func addNumeric(v1, v2 numeric) (numeric, error) {
func addNumeric(v1, v2 numeric) numeric {
v1, v2 = prioritize(v1, v2)
switch v1.typ {
case Int64:
return intPlusInt(v1.ival, v2.ival), nil
return intPlusInt(v1.ival, v2.ival)
case Uint64:
switch v2.typ {
case Int64:
return uintPlusInt(v1.uval, v2.ival)
case Uint64:
return uintPlusUint(v1.uval, v2.uval), nil
return uintPlusUint(v1.uval, v2.uval)
}
case Float64:
return floatPlusAny(v1.fval, v2)
}
panic("unreachable")
}

func addNumericWithError(v1, v2 numeric) (numeric, error) {
v1, v2 = prioritize(v1, v2)
switch v1.typ {
case Int64:
return intPlusIntWithError(v1.ival, v2.ival)
case Uint64:
switch v2.typ {
case Int64:
return uintPlusIntWithError(v1.uval, v2.ival)
case Uint64:
return uintPlusUintWithError(v1.uval, v2.uval)
}
case Float64:
return floatPlusAny(v1.fval, v2), nil
}
panic("unreachable")

}

// prioritize reorders the input parameters
Expand All @@ -353,6 +388,7 @@ func prioritize(v1, v2 numeric) (altv1, altv2 numeric) {
if v2.typ == Float64 {
return v2, v1
}

}
return v1, v2
}
Expand All @@ -371,21 +407,47 @@ overflow:
return numeric{typ: Float64, fval: float64(v1) + float64(v2)}
}

func uintPlusInt(v1 uint64, v2 int64) (numeric, error) {
if v2 < 0 {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "cannot add a negative number to an unsigned integer: %d, %d", v1, v2)
func intPlusIntWithError(v1, v2 int64) (numeric, error) {
result := v1 + v2
if (result > v1) != (v2 > 0) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as a general practice, we should prefer clarify over compactness. A clearer way to write this if would have been:
if (v1 > 0 && v2 > 0 && result < 0) || (v1 < 0 && v2 < 0 && result > 0)

return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2)
}
return uintPlusUint(v1, uint64(v2)), nil
return numeric{typ: Int64, ival: result}, nil
}

func uintPlusInt(v1 uint64, v2 int64) numeric {
return uintPlusUint(v1, uint64(v2))
}

func uintPlusIntWithError(v1 uint64, v2 int64) (numeric, error) {
if v2 >= math.MaxInt64 && v1 > 0 {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2)
}

//convert to int -> uint is because for numeric operators (such as + or -)
//where one of the operands is an unsigned integer, the result is unsigned by default.
return uintPlusUintWithError(v1, uint64(v2))
}

func uintPlusUint(v1, v2 uint64) numeric {
result := v1 + v2
if result < v2 {
return numeric{typ: Float64, fval: float64(v1) + float64(v2)}

}
return numeric{typ: Uint64, uval: result}
}

func uintPlusUintWithError(v1, v2 uint64) (numeric, error) {
result := v1 + v2

if result < v2 {
return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2)
}

return numeric{typ: Uint64, uval: result}, nil
}

func floatPlusAny(v1 float64, v2 numeric) numeric {
switch v2.typ {
case Int64:
Expand All @@ -396,37 +458,43 @@ func floatPlusAny(v1 float64, v2 numeric) numeric {
return numeric{typ: Float64, fval: v1 + v2.fval}
}

func castFromNumeric(v numeric, resultType querypb.Type) (Value, error) {
func castFromNumeric(v numeric, resultType querypb.Type) Value {
switch {
case IsSigned(resultType):
switch v.typ {
case Int64:
return MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10)), nil
case Uint64, Float64:
return NULL, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion: %v to %v", v.typ, resultType)
return MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10))
case Uint64:
return MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.uval), 10))
case Float64:
return MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.fval), 10))

}
case IsUnsigned(resultType):
switch v.typ {
case Uint64:
return MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10)), nil
case Int64, Float64:
return NULL, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion: %v to %v", v.typ, resultType)
return MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10))
case Int64:
return MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.ival), 10))
case Float64:
return MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.fval), 10))

}
case IsFloat(resultType) || resultType == Decimal:
switch v.typ {
case Int64:
return MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10)), nil
return MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10))
case Uint64:
return MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10)), nil
return MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10))
case Float64:
format := byte('g')
if resultType == Decimal {
format = 'f'
}
return MakeTrusted(resultType, strconv.AppendFloat(nil, v.fval, format, -1, 64)), nil
return MakeTrusted(resultType, strconv.AppendFloat(nil, v.fval, format, -1, 64))
}
}
return NULL, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion to non-numeric: %v", resultType)
return NULL
}

func compareNumeric(v1, v2 numeric) int {
Expand Down
Loading