diff --git a/go/test/endtoend/vtgate/misc_test.go b/go/test/endtoend/vtgate/misc_test.go index df71069e50c..b68732e7008 100644 --- a/go/test/endtoend/vtgate/misc_test.go +++ b/go/test/endtoend/vtgate/misc_test.go @@ -746,6 +746,27 @@ func TestFilterAfterLeftJoin(t *testing.T) { utils.AssertMatches(t, conn, query, `[[INT64(1) INT64(10)]]`) } +func TestFilterWithINAfterLeftJoin(t *testing.T) { + conn, closer := start(t) + defer closer() + + utils.Exec(t, conn, "insert into t1 (id1,id2) values (1, 10)") + utils.Exec(t, conn, "insert into t1 (id1,id2) values (2, 3)") + utils.Exec(t, conn, "insert into t1 (id1,id2) values (3, 2)") + utils.Exec(t, conn, "insert into t1 (id1,id2) values (4, 5)") + + query := "select a.id1, b.id3 from t1 as a left outer join t2 as b on a.id2 = b.id4 WHERE a.id2 = 10 AND (b.id3 IS NULL OR b.id3 IN (1))" + utils.AssertMatches(t, conn, query, `[[INT64(1) NULL]]`) + + utils.Exec(t, conn, "insert into t2 (id3,id4) values (1, 10)") + + query = "select a.id1, b.id3 from t1 as a left outer join t2 as b on a.id2 = b.id4 WHERE a.id2 = 10 AND (b.id3 IS NULL OR b.id3 IN (1))" + utils.AssertMatches(t, conn, query, `[[INT64(1) INT64(1)]]`) + + query = "select a.id1, b.id3 from t1 as a left outer join t2 as b on a.id2 = b.id4 WHERE a.id2 = 10 AND (b.id3 IS NULL OR (b.id3, b.id4) IN ((1, 10)))" + utils.AssertMatches(t, conn, query, `[[INT64(1) INT64(1)]]`) +} + func TestDescribeVindex(t *testing.T) { conn, closer := start(t) defer closer() diff --git a/go/vt/vtgate/evalengine/compiler_asm_push.go b/go/vt/vtgate/evalengine/compiler_asm_push.go index 404c8870f87..aad8879c765 100644 --- a/go/vt/vtgate/evalengine/compiler_asm_push.go +++ b/go/vt/vtgate/evalengine/compiler_asm_push.go @@ -525,6 +525,28 @@ func (asm *assembler) PushBVar_u(key string) { }, "PUSH UINT64(:%q)", key) } +func push_tuple(env *ExpressionEnv, values []*querypb.Value) int { + env.vm.stack[env.vm.sp], env.vm.err = newEvalTuple(values, env.collationEnv.DefaultConnectionCharset()) + if env.vm.err != nil { + return 0 + } + env.vm.sp++ + return 1 +} + +func (asm *assembler) PushBVar_tuple(key string) { + asm.adjustStack(1) + + asm.emit(func(env *ExpressionEnv) int { + var bvar *querypb.BindVariable + bvar, env.vm.err = env.lookupBindVar(key) + if env.vm.err != nil { + return 0 + } + return push_tuple(env, bvar.Values) + }, "PUSH TUPLE(:%q)", key) +} + func (asm *assembler) PushLiteral(lit eval) error { asm.adjustStack(1) diff --git a/go/vt/vtgate/evalengine/eval_tuple.go b/go/vt/vtgate/evalengine/eval_tuple.go index 1faff68e155..7d665fda36a 100644 --- a/go/vt/vtgate/evalengine/eval_tuple.go +++ b/go/vt/vtgate/evalengine/eval_tuple.go @@ -17,7 +17,9 @@ limitations under the License. package evalengine import ( + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" ) type evalTuple struct { @@ -26,6 +28,22 @@ type evalTuple struct { var _ eval = (*evalTuple)(nil) +func newEvalTuple(values []*querypb.Value, collation collations.ID) (*evalTuple, error) { + evals := make([]eval, 0, len(values)) + + for _, value := range values { + val := sqltypes.ProtoToValue(value) + + e, err := valueToEval(val, typedCoercionCollation(val.Type(), collations.CollationForType(val.Type(), collation)), nil) + if err != nil { + return nil, err + } + evals = append(evals, e) + } + + return &evalTuple{t: evals}, nil +} + func (e *evalTuple) ToRawBytes() []byte { var vals []sqltypes.Value for _, e2 := range e.t { diff --git a/go/vt/vtgate/evalengine/expr_bvar.go b/go/vt/vtgate/evalengine/expr_bvar.go index 81dc3432d76..23e55f77308 100644 --- a/go/vt/vtgate/evalengine/expr_bvar.go +++ b/go/vt/vtgate/evalengine/expr_bvar.go @@ -170,6 +170,8 @@ func (bvar *BindVariable) compile(c *compiler) (ctype, error) { c.asm.PushBVar_time(bvar.Key) case tt == sqltypes.Vector: c.asm.PushBVar_vector(bvar.Key) + case tt == sqltypes.Tuple: + c.asm.PushBVar_tuple(bvar.Key) default: return ctype{}, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Type is not supported: %s", tt) } diff --git a/go/vt/vtgate/evalengine/expr_compare.go b/go/vt/vtgate/evalengine/expr_compare.go index 6e6c888ecf6..0250306c2ff 100644 --- a/go/vt/vtgate/evalengine/expr_compare.go +++ b/go/vt/vtgate/evalengine/expr_compare.go @@ -580,7 +580,18 @@ func (expr *InExpr) compile(c *compiler) (ctype, error) { return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | (nullableFlags(lhs.Flag) | (rt.Flag & flagNullable))}, nil case *BindVariable: - return ctype{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "rhs of an In operation should be a tuple") + + if rhs.Type != sqltypes.Tuple { + return ctype{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "rhs of an In operation should be a tuple") + } + + rt, err := rhs.compile(c) + if err != nil { + return ctype{}, err + } + + c.asm.In_slow(c.env.CollationEnv(), expr.Negate) + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | (nullableFlags(lhs.Flag) | (rt.Flag & flagNullable))}, nil default: panic("unreachable") } diff --git a/go/vt/vtgate/evalengine/translate_test.go b/go/vt/vtgate/evalengine/translate_test.go index 52b717fe2c6..c14ee4fa8b0 100644 --- a/go/vt/vtgate/evalengine/translate_test.go +++ b/go/vt/vtgate/evalengine/translate_test.go @@ -237,6 +237,15 @@ func TestEvaluate(t *testing.T) { }, { expression: "(1,2) in ((1,null), (2,3))", expected: NULL, + }, { + expression: "1 IN ::tuple_bind_variable", + expected: True, + }, { + expression: "3 IN ::tuple_bind_variable", + expected: True, + }, { + expression: "4 IN ::tuple_bind_variable", + expected: False, }, { expression: "(1,(1,2,3),(1,(1,2),4),2) = (1,(1,2,3),(1,(1,2),4),2)", expected: True, @@ -319,6 +328,14 @@ func TestEvaluate(t *testing.T) { "uint32_bind_variable": sqltypes.Uint32BindVariable(21), "uint64_bind_variable": sqltypes.Uint64BindVariable(22), "float_bind_variable": sqltypes.Float64BindVariable(2.2), + "tuple_bind_variable": { + Type: sqltypes.Tuple, + Values: []*querypb.Value{ + {Type: sqltypes.Int64, Value: []byte("1")}, + {Type: sqltypes.Int64, Value: []byte("2")}, + {Type: sqltypes.Int64, Value: []byte("3")}, + }, + }, }, NewEmptyVCursor(venv, time.Local)) // When