Skip to content
7 changes: 7 additions & 0 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8592,6 +8592,13 @@ from typestable`,
Query: "flush engine logs",
Expected: []sql.Row{},
},
// TODO: this is the largest scale decimal we support, but it's not the largest MySQL supports
{
Query: "select round(5e29, -30)",
Expected: []sql.Row{
{1e30},
},
},
}

var KeylessQueries = []QueryTest{
Expand Down
136 changes: 41 additions & 95 deletions sql/expression/function/ceil_round_floor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
package function

import (
"encoding/hex"
"fmt"
"math"
"strconv"

"github.com/shopspring/decimal"

Expand Down Expand Up @@ -242,119 +240,67 @@ func (r *Round) Children() []sql.Expression {

// Eval implements the Expression interface.
func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
xVal, err := r.Left.Eval(ctx, row)
val, err := r.Left.Eval(ctx, row)
if err != nil {
return nil, err
}

if xVal == nil {
if val == nil {
return nil, nil
}

dVal := float64(0)
decType := types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale)
val, _, err = decType.Convert(val)
if err != nil {
// TODO: truncate
return nil, err
}

prec := int32(0)
if r.Right != nil {
var dTemp interface{}
dTemp, err = r.Right.Eval(ctx, row)
var tmp interface{}
tmp, err = r.Right.Eval(ctx, row)
if err != nil {
return nil, err
}

if dTemp != nil {
switch dNum := dTemp.(type) {
case float64:
dVal = float64(int64(dNum))
case float32:
dVal = float64(int64(dNum))
case int64:
dVal = float64(dNum)
case int32:
dVal = float64(dNum)
case int16:
dVal = float64(dNum)
case int8:
dVal = float64(dNum)
case uint64:
dVal = float64(dNum)
case uint32:
dVal = float64(dNum)
case uint16:
dVal = float64(dNum)
case uint8:
dVal = float64(dNum)
case int:
dVal = float64(dNum)
case []byte:
val, err := strconv.ParseUint(hex.EncodeToString(dNum), 16, 64)
if err != nil {
return nil, err
}
dVal = float64(val)
default:
dTemp, _, err = types.Float64.Convert(dTemp)
if err == nil {
dVal = dTemp.(float64)
}
if tmp == nil {
return nil, nil
}

if tmp != nil {
tmp, _, err = types.Int32.Convert(tmp)
if err != nil {
// TODO: truncate
return nil, err
}
if dVal > 30 { // MySQL cuts off at 30 for larger values
dVal = 30
prec = tmp.(int32)
// MySQL cuts off at 30 for larger values
// TODO: these limits are fine only because we can't handle decimals larger than this
if prec > types.DecimalTypeMaxPrecision {
prec = types.DecimalTypeMaxPrecision
}
if prec < -types.DecimalTypeMaxScale {
prec = -types.DecimalTypeMaxScale
}
}
}

if types.IsText(r.Left.Type()) {
xVal, _, err = types.Float64.Convert(xVal)
if err != nil {
return int32(0), nil
}
} else if !types.IsNumber(r.Left.Type()) {
xVal, _, err = types.Float64.Convert(xVal)
if err != nil {
return int32(0), nil
}

xNum := xVal.(float64)
return int32(math.Round(xNum*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil
var res interface{}
tmp := val.(decimal.Decimal).Round(prec)
if types.IsSigned(r.Left.Type()) {
res, _, err = types.Int64.Convert(tmp)
} else if types.IsUnsigned(r.Left.Type()) {
res, _, err = types.Uint64.Convert(tmp)
} else if types.IsFloat(r.Left.Type()) {
res, _, err = types.Float64.Convert(tmp)
} else if types.IsDecimal(r.Left.Type()) {
res = tmp
} else if types.IsTextBlob(r.Left.Type()) {
res, _, err = types.Float64.Convert(tmp)
}

// One way to round to a decimal place is to shift the number up by the desired decimal position, round to the
// nearest integer, and then shift back down.
// For example, we have 5.855 and want to round to 2 decimal places.
// In this case, xNum = 5.855 and dVal = 2
// round(xNum * 10^dVal) / 10^dVal
// round(5.855 * 10^2) / 10^2
// round(5.855 * 100) / 100
// round(585.5) / 100
// 586 / 100
// 5.86
switch xNum := xVal.(type) {
case float64:
return math.Round(xNum*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal), nil
case float32:
return float32(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil
case int64:
return int64(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil
case int32:
return int32(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil
case int16:
return int16(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil
case int8:
return int8(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil
case uint64:
return uint64(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil
case uint32:
return uint32(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil
case uint16:
return uint16(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil
case uint8:
return uint8(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil
case int:
return int(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil
case decimal.Decimal:
return xNum.Round(int32(dVal)), nil
default:
return nil, sql.ErrInvalidType.New(r.Left.Type().String())
}
return res, err
}

// IsNullable implements the Expression interface.
Expand Down
Loading