Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,32 @@ func TestCompilerSingle(t *testing.T) {
// exercise the path to push sets onto the stack.
result: `FLOAT64(1)`,
},
{
expression: `GREATEST(NULL, '2023-10-24')`,
result: `NULL`,
},
{
expression: `GREATEST(NULL, 1)`,
result: `NULL`,
},
{
expression: `GREATEST(NULL, 1.0)`,
result: `NULL`,
},
{
expression: `GREATEST(NULL, 1.0e0)`,
result: `NULL`,
},
{
expression: `GREATEST(column0, 1.0e0)`,
values: []sqltypes.Value{sqltypes.MakeTrusted(sqltypes.Enum, []byte("foo"))},
// Enum and set are treated as strings in this case.
result: `VARCHAR("foo")`,
},
{
expression: `GREATEST(JSON_OBJECT(), JSON_ARRAY())`,
result: `VARCHAR("{}")`,
},
}

tz, _ := time.LoadLocation("Europe/Madrid")
Expand Down
60 changes: 54 additions & 6 deletions go/vt/vtgate/evalengine/fn_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ func (call *builtinMultiComparison) getMultiComparisonFunc(args []eval) multiCom
timestamp int
date int
time int
json int
)

/*
Expand Down Expand Up @@ -147,7 +148,15 @@ func (call *builtinMultiComparison) getMultiComparisonFunc(args []eval) multiCom
if !arg.isHexOrBitLiteral() {
call.prec = max(call.prec, datetime2.DefaultPrecision)
}
case sqltypes.Geometry:
return func(_ *ExpressionEnv, _ []eval, _, _ int) (eval, error) {
return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "unsupported argument of geometry type for GREATEST/LEAST")
}
}
case *evalSet:
text++
case *evalEnum:
text++
case *evalTemporal:
temporal++
call.prec = max(call.prec, int(arg.prec))
Expand All @@ -161,6 +170,8 @@ func (call *builtinMultiComparison) getMultiComparisonFunc(args []eval) multiCom
case sqltypes.Time:
time++
}
case *evalJSON:
json++
}
}

Expand Down Expand Up @@ -236,6 +247,9 @@ func (call *builtinMultiComparison) getMultiComparisonFunc(args []eval) multiCom
if decimals > 0 {
return compareAllDecimal
}
if json > 0 {
return compareAllText
}
}
panic("unexpected argument type")
}
Expand Down Expand Up @@ -444,7 +458,7 @@ func (call *builtinMultiComparison) eval(env *ExpressionEnv) (eval, error) {
return call.getMultiComparisonFunc(args)(env, args, call.cmp, call.prec)
}

func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype) (ctype, error) {
func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype, jumps []*jump) (ctype, error) {
var ca collationAggregation
var f typeFlag
for _, arg := range args {
Expand All @@ -456,10 +470,11 @@ func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype) (ctype,

tc := ca.result()
c.asm.Fn_MULTICMP_c(len(args), call.cmp < 0, tc)
c.asm.jumpDestination(jumps...)
return ctype{Type: sqltypes.VarChar, Flag: f, Col: tc}, nil
}

func (call *builtinMultiComparison) compile_d(c *compiler, args []ctype) (ctype, error) {
func (call *builtinMultiComparison) compile_d(c *compiler, args []ctype, jumps []*jump) (ctype, error) {
var f typeFlag
var size int32
var scale int32
Expand All @@ -470,6 +485,7 @@ func (call *builtinMultiComparison) compile_d(c *compiler, args []ctype) (ctype,
c.compileToDecimal(tt, len(args)-i)
}
c.asm.Fn_MULTICMP_d(len(args), call.cmp < 0)
c.asm.jumpDestination(jumps...)
return ctype{Type: sqltypes.Decimal, Flag: f, Col: collationNumeric, Size: size, Scale: scale}, nil
}

Expand All @@ -486,6 +502,7 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) {
time int
text int
binary int
json int
args []ctype
nullable bool
prec int
Expand All @@ -500,12 +517,17 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) {
In all other cases, the arguments are compared as binary strings.
*/

for _, expr := range call.Arguments {
jumps := make([]*jump, 0, len(call.Arguments))
for i, expr := range call.Arguments {
tt, err := expr.compile(c)
if err != nil {
return ctype{}, err
}

if tt.nullable() {
jumps = append(jumps, c.compileNullCheckArg(tt, i))
}

args = append(args, tt)

nullable = nullable || tt.nullable()
Expand Down Expand Up @@ -544,6 +566,12 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) {
temporal++
time++
prec = max(prec, int(tt.Size))
case sqltypes.Set, sqltypes.Enum:
text++
case sqltypes.TypeJSON:
json++
case sqltypes.Geometry:
return ctype{}, vterrors.Errorf(vtrpc.Code_INTERNAL, "unsupported argument of geometry type for GREATEST/LEAST")
case sqltypes.Null:
nullable = true
default:
Expand Down Expand Up @@ -579,6 +607,7 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) {
}
}
c.asm.Fn_MULTICMP_temporal(len(args), call.cmp < 0)
c.asm.jumpDestination(jumps...)
return ctype{Type: typ, Flag: f, Col: collationBinary}, nil
} else if temporal > 0 {
var ca collationAggregation
Expand Down Expand Up @@ -608,35 +637,54 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) {
case time > 0:
c.asm.Fn_MULTICMP_temporal_fallback(compareAllTemporalAsString(nil), len(args), call.cmp, prec)
}
c.asm.jumpDestination(jumps...)
return ctype{Type: sqltypes.VarChar, Flag: f, Col: tc}, nil
}
if signed+unsigned == len(args) {
if signed == len(args) {
c.asm.Fn_MULTICMP_i(len(args), call.cmp < 0)
c.asm.jumpDestination(jumps...)
return ctype{Type: sqltypes.Int64, Flag: f, Col: collationNumeric}, nil
}
if unsigned == len(args) {
c.asm.Fn_MULTICMP_u(len(args), call.cmp < 0)
c.asm.jumpDestination(jumps...)
return ctype{Type: sqltypes.Uint64, Flag: f, Col: collationNumeric}, nil
}
return call.compile_d(c, args)
return call.compile_d(c, args, jumps)
}
if binary > 0 || text > 0 {
if text > 0 {
return call.compile_c(c, args)
return call.compile_c(c, args, jumps)
}
c.asm.Fn_MULTICMP_b(len(args), call.cmp < 0)
c.asm.jumpDestination(jumps...)
return ctype{Type: sqltypes.VarBinary, Flag: f, Col: collationBinary}, nil
} else {
if floats > 0 {
for i, tt := range args {
c.compileToFloat(tt, len(args)-i)
}
c.asm.Fn_MULTICMP_f(len(args), call.cmp < 0)
c.asm.jumpDestination(jumps...)
return ctype{Type: sqltypes.Float64, Flag: f, Col: collationNumeric}, nil
}
if decimals > 0 {
return call.compile_d(c, args)
return call.compile_d(c, args, jumps)
}
if json > 0 {
c.asm.Fn_MULTICMP_c(len(args), call.cmp < 0, collationJSON)
c.asm.jumpDestination(jumps...)
return ctype{Type: sqltypes.Text, Flag: f, Col: collationJSON}, nil
}

// The next case only gets hit if we already know at least one of the inputs
// is a static NULL typed value. That means we already have removed all items
// from the stack at this point and the top is a NULL to return.
if nullable {
c.asm.adjustStack(-len(args) + 1)
c.asm.jumpDestination(jumps...)
return ctype{Type: sqltypes.Null, Flag: f, Col: collationBinary}, nil
}
}
return ctype{}, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected argument for GREATEST/LEAST")
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/testcases/cases.go
Original file line number Diff line number Diff line change
Expand Up @@ -1156,7 +1156,7 @@ func StrcmpComparison(yield Query) {

func MultiComparisons(yield Query) {
var numbers = []string{
`0`, `-1`, `1`, `0.0`, `1.0`, `-1.0`, `1.0E0`, `-1.0E0`, `0.0E0`,
`NULL`, `0`, `-1`, `1`, `0.0`, `1.0`, `-1.0`, `1.0E0`, `-1.0E0`, `0.0E0`,
strconv.FormatUint(math.MaxUint64, 10),
strconv.FormatUint(math.MaxInt64, 10),
strconv.FormatInt(math.MinInt64, 10),
Expand Down
Loading