Skip to content

Commit

Permalink
evalengine: More cleanup (#12573)
Browse files Browse the repository at this point in the history
* evalengine: fix multi-comparisons

Signed-off-by: Vicent Marti <[email protected]>

* evalengine: add a limit for REPEAT

Signed-off-by: Vicent Marti <[email protected]>

* evalengine: Additional fixes for edge cases

This solves edge cases around decimal promotion which follows a
conversion to int64 (not uint64) and then is casted to uint64 for all
bitwise operations.

Also adds a bunch of tests for the hex() function and implemented
conversions for non string like types.

It also found more bugs in MySQL CASE handling, yay.

Signed-off-by: Dirkjan Bussink <[email protected]>

---------

Signed-off-by: Vicent Marti <[email protected]>
Signed-off-by: Dirkjan Bussink <[email protected]>
Co-authored-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
vmg and dbussink authored Mar 9, 2023
1 parent a24e22c commit 365458e
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 55 deletions.
13 changes: 11 additions & 2 deletions go/vt/vtgate/evalengine/eval_numeric.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,13 @@ func (e *evalDecimal) negate() evalNumeric {

func (e *evalDecimal) toInt64() *evalInt64 {
dec := e.dec.Round(0)
i, _ := dec.Int64()
i, valid := dec.Int64()
if !valid {
if dec.Sign() < 0 {
return newEvalInt64(math.MinInt64)
}
return newEvalInt64(math.MaxInt64)
}
return newEvalInt64(i)
}

Expand All @@ -333,6 +339,9 @@ func (e *evalDecimal) toUint64() *evalUint64 {
return newEvalUint64(uint64(i))
}

u, _ := dec.Uint64()
u, valid := dec.Uint64()
if !valid {
return newEvalUint64(math.MaxUint64)
}
return newEvalUint64(u)
}
20 changes: 10 additions & 10 deletions go/vt/vtgate/evalengine/expr_bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ func (b *BitwiseNotExpr) eval(env *ExpressionEnv) (eval, error) {
return newEvalBinary(out), nil
}

eu := evalToNumeric(e).toUint64()
return newEvalUint64(^eu.u), nil
eu := evalToNumeric(e).toInt64()
return newEvalUint64(^uint64(eu.i)), nil
}

func (b *BitwiseNotExpr) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) {
Expand Down Expand Up @@ -201,9 +201,9 @@ func (bit *BitwiseExpr) eval(env *ExpressionEnv) (eval, error) {
}
}

lu := evalToNumeric(l).toUint64()
ru := evalToNumeric(r).toUint64()
return newEvalUint64(op.numeric(lu.u, ru.u)), nil
lu := evalToNumeric(l).toInt64()
ru := evalToNumeric(r).toInt64()
return newEvalUint64(op.numeric(uint64(lu.i), uint64(ru.i))), nil

case opBitShift:
/*
Expand All @@ -213,12 +213,12 @@ func (bit *BitwiseExpr) eval(env *ExpressionEnv) (eval, error) {
unsigned 64-bit integer as necessary.
*/
if l, ok := l.(*evalBytes); ok && l.isBinary() && !l.isHexOrBitLiteral() {
ru := evalToNumeric(r).toUint64()
return newEvalBinary(op.binary(l.bytes, ru.u)), nil
ru := evalToNumeric(r).toInt64()
return newEvalBinary(op.binary(l.bytes, uint64(ru.i))), nil
}
lu := evalToNumeric(l).toUint64()
ru := evalToNumeric(r).toUint64()
return newEvalUint64(op.numeric(lu.u, ru.u)), nil
lu := evalToNumeric(l).toInt64()
ru := evalToNumeric(r).toInt64()
return newEvalUint64(op.numeric(uint64(lu.i), uint64(ru.i))), nil

default:
panic("unexpected bit operation")
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/fn_bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ func (call *builtinBitCount) eval(env *ExpressionEnv) (eval, error) {
count += bits.OnesCount8(b)
}
} else {
u := evalToNumeric(arg).toUint64()
count = bits.OnesCount64(u.u)
u := evalToNumeric(arg).toInt64()
count = bits.OnesCount64(uint64(u.i))
}
return newEvalInt64(int64(count)), nil
}
Expand Down
84 changes: 50 additions & 34 deletions go/vt/vtgate/evalengine/fn_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package evalengine

import (
"bytes"
"math"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/mysql/collations/charset"
Expand Down Expand Up @@ -65,11 +64,12 @@ func (b *builtinCoalesce) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) {

func getMultiComparisonFunc(args []eval) multiComparisonFunc {
var (
integers int
floats int
decimals int
text int
binary int
integersI int
integersU int
floats int
decimals int
text int
binary int
)

/*
Expand All @@ -90,13 +90,9 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc {

switch arg := arg.(type) {
case *evalInt64:
integers++
integersI++
case *evalUint64:
if arg.u > math.MaxInt64 {
decimals++
} else {
integers++
}
integersU++
case *evalFloat:
floats++
case *evalDecimal:
Expand All @@ -111,8 +107,14 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc {
}
}

if integers == len(args) {
return compareAllInteger
if integersI+integersU == len(args) {
if integersI == len(args) {
return compareAllInteger_i
}
if integersU == len(args) {
return compareAllInteger_u
}
return compareAllDecimal
}
if binary > 0 || text > 0 {
if text > 0 {
Expand All @@ -132,15 +134,26 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc {
panic("unexpected argument type")
}

func compareAllInteger(args []eval, cmp int) (eval, error) {
var candidateI = args[0].(*evalInt64).i
func compareAllInteger_u(args []eval, cmp int) (eval, error) {
x := args[0].(*evalUint64)
for _, arg := range args[1:] {
thisI := arg.(*evalInt64).i
if (cmp < 0) == (thisI < candidateI) {
candidateI = thisI
y := arg.(*evalUint64)
if (cmp < 0) == (y.u < x.u) {
x = y
}
}
return &evalInt64{candidateI}, nil
return x, nil
}

func compareAllInteger_i(args []eval, cmp int) (eval, error) {
x := args[0].(*evalInt64)
for _, arg := range args[1:] {
y := arg.(*evalInt64)
if (cmp < 0) == (y.i < x.i) {
x = y
}
}
return x, nil
}

func compareAllFloat(args []eval, cmp int) (eval, error) {
Expand Down Expand Up @@ -243,12 +256,13 @@ func (call *builtinMultiComparison) eval(env *ExpressionEnv) (eval, error) {

func (call *builtinMultiComparison) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) {
var (
integers int
floats int
decimals int
text int
binary int
flags typeFlag
integersI int
integersU int
floats int
decimals int
text int
binary int
flags typeFlag
)

for _, expr := range call.Arguments {
Expand All @@ -257,13 +271,9 @@ func (call *builtinMultiComparison) typeof(env *ExpressionEnv) (sqltypes.Type, t

switch tt {
case sqltypes.Int8, sqltypes.Int16, sqltypes.Int32, sqltypes.Int64:
integers++
integersI++
case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint32, sqltypes.Uint64:
if f&flagIntegerOvf != 0 {
decimals++
} else {
integers++
}
integersU++
case sqltypes.Float32, sqltypes.Float64:
floats++
case sqltypes.Decimal:
Expand All @@ -278,8 +288,14 @@ func (call *builtinMultiComparison) typeof(env *ExpressionEnv) (sqltypes.Type, t
if flags&flagNull != 0 {
return sqltypes.Null, flags
}
if integers == len(call.Arguments) {
return sqltypes.Int64, flags
if integersI+integersU == len(call.Arguments) {
if integersI == len(call.Arguments) {
return sqltypes.Int64, flags
}
if integersU == len(call.Arguments) {
return sqltypes.Uint64, flags
}
return sqltypes.Decimal, flags
}
if binary > 0 || text > 0 {
if text > 0 {
Expand Down
15 changes: 9 additions & 6 deletions go/vt/vtgate/evalengine/fn_hex.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import (
"math/bits"

"vitess.io/vitess/go/sqltypes"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
)

type builtinHex struct {
Expand All @@ -44,16 +42,21 @@ func (call *builtinHex) eval(env *ExpressionEnv) (eval, error) {
case *evalBytes:
encoded = hexEncodeBytes(arg.bytes)
case evalNumeric:
encoded = hexEncodeUint(arg.toUint64().u)
encoded = hexEncodeUint(uint64(arg.toInt64().i))
default:
return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Unsupported HEX argument: %s", arg.SQLType())
encoded = hexEncodeBytes(arg.ToRawBytes())
}
if arg.SQLType() == sqltypes.Blob || arg.SQLType() == sqltypes.TypeJSON {
return newEvalRaw(sqltypes.Text, encoded, env.collation()), nil
}

return newEvalText(encoded, env.collation()), nil
}

func (call *builtinHex) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) {
_, f := call.Arguments[0].typeof(env)
tt, f := call.Arguments[0].typeof(env)
if tt == sqltypes.Blob || tt == sqltypes.TypeJSON {
return sqltypes.Text, f
}
return sqltypes.VarChar, f
}

Expand Down
29 changes: 29 additions & 0 deletions go/vt/vtgate/evalengine/fn_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,20 @@ func (call *builtinASCII) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) {
return sqltypes.Int64, f
}

// maxRepeatLength is the maximum number of times a string can be repeated.
// This is based on how MySQL behaves here. The maximum value in MySQL is
// actually based on `max_allowed_packet`. The value here is the maximum
// for `max_allowed_packet` with 1GB.
// See https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html#sysvar_max_allowed_packet
//
// Practically though, this is really whacky anyway.
// There's 3 possible states:
// - `<= max_allowed_packet` and actual packet generated is `<= max_allowed_packet`. It works
// - `<= max_allowed_packet` but the actual packet generated is `> max_allowed_packet` so it fails with an
// error: `ERROR 2020 (HY000): Got packet bigger than 'max_allowed_packet' bytes` and the client gets disconnected.
// - `> max_allowed_packet`, no error and returns `NULL`.
const maxRepeatLength = 1073741824

type builtinRepeat struct {
CallExpr
}
Expand All @@ -211,9 +225,24 @@ func (call *builtinRepeat) eval(env *ExpressionEnv) (eval, error) {
if repeat < 0 {
repeat = 0
}
if !checkMaxLength(int64(len(text.bytes)), repeat) {
return nil, nil
}

return newEvalText(bytes.Repeat(text.bytes, int(repeat)), text.col), nil
}

func checkMaxLength(len, repeat int64) bool {
if repeat <= 0 {
return true
}
if len*repeat/repeat != len {
// we have an overflow, can't be a valid length.
return false
}
return len*repeat <= maxRepeatLength
}

func (call *builtinRepeat) typeof(env *ExpressionEnv) (sqltypes.Type, typeFlag) {
_, f1 := call.Arguments[0].typeof(env)
// typecheck the right-hand argument but ignore its flags
Expand Down
28 changes: 27 additions & 1 deletion go/vt/vtgate/evalengine/testcases/cases.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ type FnLength struct{ defaultEnv }
type FnBitLength struct{ defaultEnv }
type FnAscii struct{ defaultEnv }
type FnRepeat struct{ defaultEnv }
type FnHex struct{ defaultEnv }

var Cases = []TestCase{
JSONExtract{},
Expand Down Expand Up @@ -114,6 +115,7 @@ var Cases = []TestCase{
FnBitLength{},
FnAscii{},
FnRepeat{},
FnHex{},
}

func (JSONPathOperations) Test(yield Iterator) {
Expand Down Expand Up @@ -248,6 +250,12 @@ func comparisonSkip(a, b string) bool {
if b == "-1" && a == "18446744073709551615" {
return true
}
if a == "9223372036854775808" && b == "-9223372036854775808" {
return true
}
if a == "-9223372036854775808" && b == "9223372036854775808" {
return true
}
return false
}

Expand Down Expand Up @@ -550,6 +558,10 @@ func (MultiComparisons) Test(yield Iterator) {
strconv.FormatUint(math.MaxUint64, 10),
strconv.FormatUint(math.MaxInt64, 10),
strconv.FormatInt(math.MinInt64, 10),
`CAST(0 AS UNSIGNED)`,
`CAST(1 AS UNSIGNED)`,
`CAST(2 AS UNSIGNED)`,
`CAST(420 AS UNSIGNED)`,
`'foobar'`, `'FOOBAR'`,
`"0"`, `"-1"`, `"1"`,
`_utf8mb4 'foobar'`, `_utf8mb4 'FOOBAR'`,
Expand Down Expand Up @@ -730,10 +742,24 @@ func (FnAscii) Test(yield Iterator) {
}

func (FnRepeat) Test(yield Iterator) {
counts := []string{"-1", "1.2", "3"}
counts := []string{"-1", "1.2", "3", "1073741825"}
for _, str := range inputStrings {
for _, cnt := range counts {
yield(fmt.Sprintf("repeat(%s, %s)", str, cnt), nil)
}
}
}

func (FnHex) Test(yield Iterator) {
for _, str := range inputStrings {
yield(fmt.Sprintf("hex(%s)", str), nil)
}

for _, str := range inputConversions {
yield(fmt.Sprintf("hex(%s)", str), nil)
}

for _, str := range inputBitwise {
yield(fmt.Sprintf("hex(%s)", str), nil)
}
}
3 changes: 3 additions & 0 deletions go/vt/vtgate/evalengine/testcases/inputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ var inputBitwise = []string{
"0.0e0", "1.0e0", "255.0", "1.5e0", "-1.5e0", "1.1e0", "-1e0", "-255e0", "7e0", "9e0", "13e0",
strconv.FormatUint(math.MaxUint64, 10),
strconv.FormatUint(math.MaxInt64, 10),
strconv.FormatUint(math.MaxInt64+1, 10),
strconv.FormatInt(math.MinInt64, 10),
"18446744073709551616",
"-9223372036854775809",
`"foobar"`, `"foobar1234"`, `"0"`, "0x1", "-0x1", "X'ff'", "X'00'",
`"1abcd"`, "NULL", `_binary "foobar"`, `_binary "foobar1234"`,
"64", "'64'", "_binary '64'", "X'40'", "_binary X'40'",
Expand Down

0 comments on commit 365458e

Please sign in to comment.