Skip to content

Commit

Permalink
planner: unify the behavior of prepare/execute limit to mysql (#40360)
Browse files Browse the repository at this point in the history
ref #40219
  • Loading branch information
fzzf678 authored Jan 10, 2023
1 parent b912237 commit e2a14ce
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 10 deletions.
26 changes: 23 additions & 3 deletions executor/seqtest/prepared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,11 @@ func TestPreparedLimitOffset(t *testing.T) {
r.Check(testkit.Rows("2"))

tk.MustExec(`set @a=1.1`)
r = tk.MustQuery(`execute stmt_test_1 using @a, @b;`)
r.Check(testkit.Rows("2"))
_, err := tk.Exec(`execute stmt_test_1 using @a, @b;`)
require.True(t, plannercore.ErrWrongArguments.Equal(err))

tk.MustExec(`set @c="-1"`)
_, err := tk.Exec("execute stmt_test_1 using @c, @c")
_, err = tk.Exec("execute stmt_test_1 using @c, @c")
require.True(t, plannercore.ErrWrongArguments.Equal(err))

stmtID, _, _, err := tk.Session().PrepareStmt("select id from prepare_test limit ?")
Expand Down Expand Up @@ -767,3 +767,23 @@ func TestPreparedIssue17419(t *testing.T) {
// _, ok := tk1.Session().ShowProcess().Plan.(*plannercore.Execute)
// require.True(t, ok)
}

func TestLimitUnsupportedCase(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int, key(a))")
tk.MustExec("prepare stmt from 'select * from t limit ?'")

tk.MustExec("set @a = 1.2")
tk.MustGetErrMsg("execute stmt using @a", "[planner:1210]Incorrect arguments to LIMIT")
tk.MustExec("set @a = 1.")
tk.MustGetErrMsg("execute stmt using @a", "[planner:1210]Incorrect arguments to LIMIT")
tk.MustExec("set @a = '0'")
tk.MustGetErrMsg("execute stmt using @a", "[planner:1210]Incorrect arguments to LIMIT")
tk.MustExec("set @a = '1'")
tk.MustGetErrMsg("execute stmt using @a", "[planner:1210]Incorrect arguments to LIMIT")
tk.MustExec("set @a = 1_2")
tk.MustGetErrMsg("execute stmt using @a", "[planner:1210]Incorrect arguments to LIMIT")
}
34 changes: 27 additions & 7 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func (a *aggOrderByResolver) Enter(inNode ast.Node) (ast.Node, bool) {
a.exprDepth++
if n, ok := inNode.(*driver.ParamMarkerExpr); ok {
if a.exprDepth == 1 {
_, isNull, isExpectedType := getUintFromNode(a.ctx, n)
_, isNull, isExpectedType := getUintFromNode(a.ctx, n, false)
// For constant uint expression in top level, it should be treated as position expression.
if !isNull && isExpectedType {
return expression.ConstructPositionExpr(n), true
Expand Down Expand Up @@ -2005,7 +2005,7 @@ CheckReferenced:
// getUintFromNode gets uint64 value from ast.Node.
// For ordinary statement, node should be uint64 constant value.
// For prepared statement, node is string. We should convert it to uint64.
func getUintFromNode(ctx sessionctx.Context, n ast.Node) (uVal uint64, isNull bool, isExpectedType bool) {
func getUintFromNode(ctx sessionctx.Context, n ast.Node, mustInt64orUint64 bool) (uVal uint64, isNull bool, isExpectedType bool) {
var val interface{}
switch v := n.(type) {
case *driver.ValueExpr:
Expand All @@ -2014,6 +2014,11 @@ func getUintFromNode(ctx sessionctx.Context, n ast.Node) (uVal uint64, isNull bo
if !v.InExecute {
return 0, false, true
}
if mustInt64orUint64 {
if expected := checkParamTypeInt64orUint64(v); !expected {
return 0, false, false
}
}
param, err := expression.ParamMarkerExpression(ctx, v, false)
if err != nil {
return 0, false, false
Expand Down Expand Up @@ -2047,17 +2052,32 @@ func getUintFromNode(ctx sessionctx.Context, n ast.Node) (uVal uint64, isNull bo
return 0, false, false
}

// check param type for plan cache limit, only allow int64 and uint64 now
// eg: set @a = 1;
func checkParamTypeInt64orUint64(param *driver.ParamMarkerExpr) bool {
val := param.GetValue()
switch v := val.(type) {
case int64:
if v >= 0 {
return true
}
case uint64:
return true
}
return false
}

func extractLimitCountOffset(ctx sessionctx.Context, limit *ast.Limit) (count uint64,
offset uint64, err error) {
var isExpectedType bool
if limit.Count != nil {
count, _, isExpectedType = getUintFromNode(ctx, limit.Count)
count, _, isExpectedType = getUintFromNode(ctx, limit.Count, true)
if !isExpectedType {
return 0, 0, ErrWrongArguments.GenWithStackByArgs("LIMIT")
}
}
if limit.Offset != nil {
offset, _, isExpectedType = getUintFromNode(ctx, limit.Offset)
offset, _, isExpectedType = getUintFromNode(ctx, limit.Offset, true)
if !isExpectedType {
return 0, 0, ErrWrongArguments.GenWithStackByArgs("LIMIT")
}
Expand Down Expand Up @@ -2838,7 +2858,7 @@ func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) {
case *driver.ParamMarkerExpr:
g.isParam = true
if g.exprDepth == 1 {
_, isNull, isExpectedType := getUintFromNode(g.ctx, n)
_, isNull, isExpectedType := getUintFromNode(g.ctx, n, false)
// For constant uint expression in top level, it should be treated as position expression.
if !isNull && isExpectedType {
return expression.ConstructPositionExpr(n), true
Expand Down Expand Up @@ -6203,7 +6223,7 @@ func (b *PlanBuilder) buildWindowFunctionFrameBound(_ context.Context, spec *ast
if bound.Type == ast.CurrentRow {
return bound, nil
}
numRows, _, _ := getUintFromNode(b.ctx, boundClause.Expr)
numRows, _, _ := getUintFromNode(b.ctx, boundClause.Expr, false)
bound.Num = numRows
return bound, nil
}
Expand Down Expand Up @@ -6519,7 +6539,7 @@ func (b *PlanBuilder) checkOriginWindowFrameBound(bound *ast.FrameBound, spec *a
if bound.Unit != ast.TimeUnitInvalid {
return ErrWindowRowsIntervalUse.GenWithStackByArgs(getWindowName(spec.Name.O))
}
_, isNull, isExpectedType := getUintFromNode(b.ctx, bound.Expr)
_, isNull, isExpectedType := getUintFromNode(b.ctx, bound.Expr, false)
if isNull || !isExpectedType {
return ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O))
}
Expand Down

0 comments on commit e2a14ce

Please sign in to comment.