Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions go/test/endtoend/vtgate/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,27 @@ func TestFilterAfterLeftJoin(t *testing.T) {
utils.AssertMatches(t, conn, query, `[[INT64(1) INT64(10)]]`)
}

func TestFilterWithINAfterLeftJoin(t *testing.T) {
conn, closer := start(t)
defer closer()

utils.Exec(t, conn, "insert into t1 (id1,id2) values (1, 10)")
utils.Exec(t, conn, "insert into t1 (id1,id2) values (2, 3)")
utils.Exec(t, conn, "insert into t1 (id1,id2) values (3, 2)")
utils.Exec(t, conn, "insert into t1 (id1,id2) values (4, 5)")

query := "select a.id1, b.id3 from t1 as a left outer join t2 as b on a.id2 = b.id4 WHERE a.id2 = 10 AND (b.id3 IS NULL OR b.id3 IN (1))"
utils.AssertMatches(t, conn, query, `[[INT64(1) NULL]]`)

utils.Exec(t, conn, "insert into t2 (id3,id4) values (1, 10)")

query = "select a.id1, b.id3 from t1 as a left outer join t2 as b on a.id2 = b.id4 WHERE a.id2 = 10 AND (b.id3 IS NULL OR b.id3 IN (1))"
utils.AssertMatches(t, conn, query, `[[INT64(1) INT64(1)]]`)

query = "select a.id1, b.id3 from t1 as a left outer join t2 as b on a.id2 = b.id4 WHERE a.id2 = 10 AND (b.id3 IS NULL OR (b.id3, b.id4) IN ((1, 10)))"
utils.AssertMatches(t, conn, query, `[[INT64(1) INT64(1)]]`)
}

func TestDescribeVindex(t *testing.T) {
conn, closer := start(t)
defer closer()
Expand Down
22 changes: 22 additions & 0 deletions go/vt/vtgate/evalengine/compiler_asm_push.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,28 @@ func (asm *assembler) PushBVar_u(key string) {
}, "PUSH UINT64(:%q)", key)
}

func push_tuple(env *ExpressionEnv, values []*querypb.Value) int {
env.vm.stack[env.vm.sp], env.vm.err = newEvalTuple(values, env.collationEnv.DefaultConnectionCharset())
if env.vm.err != nil {
return 0
}
env.vm.sp++
return 1
}

func (asm *assembler) PushBVar_tuple(key string) {
asm.adjustStack(1)

asm.emit(func(env *ExpressionEnv) int {
var bvar *querypb.BindVariable
bvar, env.vm.err = env.lookupBindVar(key)
if env.vm.err != nil {
return 0
}
return push_tuple(env, bvar.Values)
}, "PUSH TUPLE(:%q)", key)
}

func (asm *assembler) PushLiteral(lit eval) error {
asm.adjustStack(1)

Expand Down
18 changes: 18 additions & 0 deletions go/vt/vtgate/evalengine/eval_tuple.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ limitations under the License.
package evalengine

import (
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
)

type evalTuple struct {
Expand All @@ -26,6 +28,22 @@ type evalTuple struct {

var _ eval = (*evalTuple)(nil)

func newEvalTuple(values []*querypb.Value, collation collations.ID) (*evalTuple, error) {
evals := make([]eval, 0, len(values))

for _, value := range values {
val := sqltypes.ProtoToValue(value)

e, err := valueToEval(val, typedCoercionCollation(val.Type(), collations.CollationForType(val.Type(), collation)), nil)
if err != nil {
return nil, err
}
evals = append(evals, e)
}

return &evalTuple{t: evals}, nil
}

func (e *evalTuple) ToRawBytes() []byte {
var vals []sqltypes.Value
for _, e2 := range e.t {
Expand Down
2 changes: 2 additions & 0 deletions go/vt/vtgate/evalengine/expr_bvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ func (bvar *BindVariable) compile(c *compiler) (ctype, error) {
c.asm.PushBVar_time(bvar.Key)
case tt == sqltypes.Vector:
c.asm.PushBVar_vector(bvar.Key)
case tt == sqltypes.Tuple:
c.asm.PushBVar_tuple(bvar.Key)
default:
return ctype{}, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Type is not supported: %s", tt)
}
Expand Down
13 changes: 12 additions & 1 deletion go/vt/vtgate/evalengine/expr_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,18 @@ func (expr *InExpr) compile(c *compiler) (ctype, error) {

return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | (nullableFlags(lhs.Flag) | (rt.Flag & flagNullable))}, nil
case *BindVariable:
return ctype{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "rhs of an In operation should be a tuple")

if rhs.Type != sqltypes.Tuple {
return ctype{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "rhs of an In operation should be a tuple")
}

rt, err := rhs.compile(c)
if err != nil {
return ctype{}, err
}

c.asm.In_slow(c.env.CollationEnv(), expr.Negate)
return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | (nullableFlags(lhs.Flag) | (rt.Flag & flagNullable))}, nil
default:
panic("unreachable")
}
Expand Down
17 changes: 17 additions & 0 deletions go/vt/vtgate/evalengine/translate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,15 @@ func TestEvaluate(t *testing.T) {
}, {
expression: "(1,2) in ((1,null), (2,3))",
expected: NULL,
}, {
expression: "1 IN ::tuple_bind_variable",
expected: True,
}, {
expression: "3 IN ::tuple_bind_variable",
expected: True,
}, {
expression: "4 IN ::tuple_bind_variable",
expected: False,
}, {
expression: "(1,(1,2,3),(1,(1,2),4),2) = (1,(1,2,3),(1,(1,2),4),2)",
expected: True,
Expand Down Expand Up @@ -319,6 +328,14 @@ func TestEvaluate(t *testing.T) {
"uint32_bind_variable": sqltypes.Uint32BindVariable(21),
"uint64_bind_variable": sqltypes.Uint64BindVariable(22),
"float_bind_variable": sqltypes.Float64BindVariable(2.2),
"tuple_bind_variable": {
Type: sqltypes.Tuple,
Values: []*querypb.Value{
{Type: sqltypes.Int64, Value: []byte("1")},
{Type: sqltypes.Int64, Value: []byte("2")},
{Type: sqltypes.Int64, Value: []byte("3")},
},
},
}, NewEmptyVCursor(venv, time.Local))

// When
Expand Down
Loading