diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 742aa9cc072..90c9da600c6 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -775,6 +775,32 @@ func TestCompilerSingle(t *testing.T) { // exercise the path to push sets onto the stack. result: `FLOAT64(1)`, }, + { + expression: `GREATEST(NULL, '2023-10-24')`, + result: `NULL`, + }, + { + expression: `GREATEST(NULL, 1)`, + result: `NULL`, + }, + { + expression: `GREATEST(NULL, 1.0)`, + result: `NULL`, + }, + { + expression: `GREATEST(NULL, 1.0e0)`, + result: `NULL`, + }, + { + expression: `GREATEST(column0, 1.0e0)`, + values: []sqltypes.Value{sqltypes.MakeTrusted(sqltypes.Enum, []byte("foo"))}, + // Enum and set are treated as strings in this case. + result: `VARCHAR("foo")`, + }, + { + expression: `GREATEST(JSON_OBJECT(), JSON_ARRAY())`, + result: `VARCHAR("{}")`, + }, } tz, _ := time.LoadLocation("Europe/Madrid") diff --git a/go/vt/vtgate/evalengine/fn_compare.go b/go/vt/vtgate/evalengine/fn_compare.go index 1084a240bd8..2a6505bbd68 100644 --- a/go/vt/vtgate/evalengine/fn_compare.go +++ b/go/vt/vtgate/evalengine/fn_compare.go @@ -108,6 +108,7 @@ func (call *builtinMultiComparison) getMultiComparisonFunc(args []eval) multiCom timestamp int date int time int + json int ) /* @@ -147,7 +148,15 @@ func (call *builtinMultiComparison) getMultiComparisonFunc(args []eval) multiCom if !arg.isHexOrBitLiteral() { call.prec = max(call.prec, datetime2.DefaultPrecision) } + case sqltypes.Geometry: + return func(_ *ExpressionEnv, _ []eval, _, _ int) (eval, error) { + return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "unsupported argument of geometry type for GREATEST/LEAST") + } } + case *evalSet: + text++ + case *evalEnum: + text++ case *evalTemporal: temporal++ call.prec = max(call.prec, int(arg.prec)) @@ -161,6 +170,8 @@ func (call *builtinMultiComparison) getMultiComparisonFunc(args []eval) multiCom case sqltypes.Time: time++ } + case *evalJSON: + json++ } } @@ -236,6 +247,9 @@ func (call *builtinMultiComparison) getMultiComparisonFunc(args []eval) multiCom if decimals > 0 { return compareAllDecimal } + if json > 0 { + return compareAllText + } } panic("unexpected argument type") } @@ -444,7 +458,7 @@ func (call *builtinMultiComparison) eval(env *ExpressionEnv) (eval, error) { return call.getMultiComparisonFunc(args)(env, args, call.cmp, call.prec) } -func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype) (ctype, error) { +func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype, jumps []*jump) (ctype, error) { var ca collationAggregation var f typeFlag for _, arg := range args { @@ -456,10 +470,11 @@ func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype) (ctype, tc := ca.result() c.asm.Fn_MULTICMP_c(len(args), call.cmp < 0, tc) + c.asm.jumpDestination(jumps...) return ctype{Type: sqltypes.VarChar, Flag: f, Col: tc}, nil } -func (call *builtinMultiComparison) compile_d(c *compiler, args []ctype) (ctype, error) { +func (call *builtinMultiComparison) compile_d(c *compiler, args []ctype, jumps []*jump) (ctype, error) { var f typeFlag var size int32 var scale int32 @@ -470,6 +485,7 @@ func (call *builtinMultiComparison) compile_d(c *compiler, args []ctype) (ctype, c.compileToDecimal(tt, len(args)-i) } c.asm.Fn_MULTICMP_d(len(args), call.cmp < 0) + c.asm.jumpDestination(jumps...) return ctype{Type: sqltypes.Decimal, Flag: f, Col: collationNumeric, Size: size, Scale: scale}, nil } @@ -486,6 +502,7 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { time int text int binary int + json int args []ctype nullable bool prec int @@ -500,12 +517,17 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { In all other cases, the arguments are compared as binary strings. */ - for _, expr := range call.Arguments { + jumps := make([]*jump, 0, len(call.Arguments)) + for i, expr := range call.Arguments { tt, err := expr.compile(c) if err != nil { return ctype{}, err } + if tt.nullable() { + jumps = append(jumps, c.compileNullCheckArg(tt, i)) + } + args = append(args, tt) nullable = nullable || tt.nullable() @@ -544,6 +566,12 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { temporal++ time++ prec = max(prec, int(tt.Size)) + case sqltypes.Set, sqltypes.Enum: + text++ + case sqltypes.TypeJSON: + json++ + case sqltypes.Geometry: + return ctype{}, vterrors.Errorf(vtrpc.Code_INTERNAL, "unsupported argument of geometry type for GREATEST/LEAST") case sqltypes.Null: nullable = true default: @@ -579,6 +607,7 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { } } c.asm.Fn_MULTICMP_temporal(len(args), call.cmp < 0) + c.asm.jumpDestination(jumps...) return ctype{Type: typ, Flag: f, Col: collationBinary}, nil } else if temporal > 0 { var ca collationAggregation @@ -608,24 +637,28 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { case time > 0: c.asm.Fn_MULTICMP_temporal_fallback(compareAllTemporalAsString(nil), len(args), call.cmp, prec) } + c.asm.jumpDestination(jumps...) return ctype{Type: sqltypes.VarChar, Flag: f, Col: tc}, nil } if signed+unsigned == len(args) { if signed == len(args) { c.asm.Fn_MULTICMP_i(len(args), call.cmp < 0) + c.asm.jumpDestination(jumps...) return ctype{Type: sqltypes.Int64, Flag: f, Col: collationNumeric}, nil } if unsigned == len(args) { c.asm.Fn_MULTICMP_u(len(args), call.cmp < 0) + c.asm.jumpDestination(jumps...) return ctype{Type: sqltypes.Uint64, Flag: f, Col: collationNumeric}, nil } - return call.compile_d(c, args) + return call.compile_d(c, args, jumps) } if binary > 0 || text > 0 { if text > 0 { - return call.compile_c(c, args) + return call.compile_c(c, args, jumps) } c.asm.Fn_MULTICMP_b(len(args), call.cmp < 0) + c.asm.jumpDestination(jumps...) return ctype{Type: sqltypes.VarBinary, Flag: f, Col: collationBinary}, nil } else { if floats > 0 { @@ -633,10 +666,25 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { c.compileToFloat(tt, len(args)-i) } c.asm.Fn_MULTICMP_f(len(args), call.cmp < 0) + c.asm.jumpDestination(jumps...) return ctype{Type: sqltypes.Float64, Flag: f, Col: collationNumeric}, nil } if decimals > 0 { - return call.compile_d(c, args) + return call.compile_d(c, args, jumps) + } + if json > 0 { + c.asm.Fn_MULTICMP_c(len(args), call.cmp < 0, collationJSON) + c.asm.jumpDestination(jumps...) + return ctype{Type: sqltypes.Text, Flag: f, Col: collationJSON}, nil + } + + // The next case only gets hit if we already know at least one of the inputs + // is a static NULL typed value. That means we already have removed all items + // from the stack at this point and the top is a NULL to return. + if nullable { + c.asm.adjustStack(-len(args) + 1) + c.asm.jumpDestination(jumps...) + return ctype{Type: sqltypes.Null, Flag: f, Col: collationBinary}, nil } } return ctype{}, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected argument for GREATEST/LEAST") diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 5469873b10e..1f34e6bfe0c 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -1156,7 +1156,7 @@ func StrcmpComparison(yield Query) { func MultiComparisons(yield Query) { var numbers = []string{ - `0`, `-1`, `1`, `0.0`, `1.0`, `-1.0`, `1.0E0`, `-1.0E0`, `0.0E0`, + `NULL`, `0`, `-1`, `1`, `0.0`, `1.0`, `-1.0`, `1.0E0`, `-1.0E0`, `0.0E0`, strconv.FormatUint(math.MaxUint64, 10), strconv.FormatUint(math.MaxInt64, 10), strconv.FormatInt(math.MinInt64, 10),