diff --git a/lib/ecto/query.ex b/lib/ecto/query.ex index 75c95d5b1e..eb1ebf7fa6 100644 --- a/lib/ecto/query.ex +++ b/lib/ecto/query.ex @@ -770,8 +770,8 @@ defmodule Ecto.Query do end @from_join_opts [:as, :prefix, :hints] - @no_binds [:lock, :union, :union_all, :except, :except_all, :intersect, :intersect_all] - @binds [:where, :or_where, :select, :distinct, :order_by, :group_by, :windows] ++ + @no_binds [:union, :union_all, :except, :except_all, :intersect, :intersect_all] + @binds [:lock, :where, :or_where, :select, :distinct, :order_by, :group_by, :windows] ++ [:having, :or_having, :limit, :offset, :preload, :update, :select_merge, :with_ctes] defp from([{type, expr}|t], env, count_bind, quoted, binds) when type in @binds do @@ -1659,8 +1659,8 @@ defmodule Ecto.Query do User |> where(u.id == ^current_user) |> lock("FOR SHARE NOWAIT") """ - defmacro lock(query, expr) do - Builder.Lock.build(query, expr, __CALLER__) + defmacro lock(query, binding \\ [], expr) do + Builder.Lock.build(query, binding, expr, __CALLER__) end @doc ~S""" diff --git a/lib/ecto/query/builder/lock.ex b/lib/ecto/query/builder/lock.ex index 7833474d37..0a934120a1 100644 --- a/lib/ecto/query/builder/lock.ex +++ b/lib/ecto/query/builder/lock.ex @@ -8,16 +8,28 @@ defmodule Ecto.Query.Builder.Lock do @doc """ Escapes the lock code. - iex> escape(quote do: "FOO") + iex> escape(quote(do: "FOO"), [], __ENV__) "FOO" """ - @spec escape(Macro.t) :: Macro.t - def escape(lock) when is_binary(lock), do: lock + @spec escape(Macro.t(), Keyword.t, Macro.Env.t) :: Macro.t() + def escape(lock, _vars, _env) when is_binary(lock), do: lock - def escape(other) do - Builder.error! "`#{Macro.to_string(other)}` is not a valid lock. " <> - "For security reasons, a lock must always be a literal string" + def escape({:fragment, _, [_ | _]} = expr, vars, env) do + {expr, {params, :acc}} = Builder.escape(expr, :any, {[], :acc}, vars, env) + + if params != [] do + Builder.error!("value interpolation is not allowed in :lock") + end + + expr + end + + def escape(other, _, _) do + Builder.error!( + "`#{Macro.to_string(other)}` is not a valid lock. " <> + "For security reasons, a lock must always be a literal string or a fragment" + ) end @doc """ @@ -27,18 +39,20 @@ defmodule Ecto.Query.Builder.Lock do If possible, it does all calculations at compile time to avoid runtime work. """ - @spec build(Macro.t, Macro.t, Macro.Env.t) :: Macro.t - def build(query, expr, env) do - Builder.apply_query(query, __MODULE__, [escape(expr)], env) + @spec build(Macro.t(), Macro.t(), Macro.t(), Macro.Env.t()) :: Macro.t() + def build(query, binding, expr, env) do + {query, binding} = Builder.escape_binding(query, binding, env) + Builder.apply_query(query, __MODULE__, [escape(expr, binding, env)], env) end @doc """ The callback applied by `build/4` to build the query. """ - @spec apply(Ecto.Queryable.t, term) :: Ecto.Query.t + @spec apply(Ecto.Queryable.t(), term) :: Ecto.Query.t() def apply(%Ecto.Query{} = query, value) do %{query | lock: value} end + def apply(query, value) do apply(Ecto.Queryable.to_query(query), value) end diff --git a/test/ecto/query/builder/lock_test.exs b/test/ecto/query/builder/lock_test.exs index 9c680ce0e1..4b9f48a3a9 100644 --- a/test/ecto/query/builder/lock_test.exs +++ b/test/ecto/query/builder/lock_test.exs @@ -15,6 +15,16 @@ defmodule Ecto.Query.Builder.LockTest do end end + test "lock with string" do + query = %Ecto.Query{} |> lock("FOO") + assert query.lock == "FOO" + end + + test "lock with fragment" do + query = "posts" |> lock([p], fragment("update on ?", p)) + assert query.lock == {:fragment, [], [raw: "update on ", expr: {:&, [], [0]}, raw: ""]} + end + test "overrides on duplicated lock" do query = %Ecto.Query{} |> lock("FOO") |> lock("BAR") assert query.lock == "BAR" diff --git a/test/ecto/uuid_test.exs b/test/ecto/uuid_test.exs index e585281816..cd745e325b 100644 --- a/test/ecto/uuid_test.exs +++ b/test/ecto/uuid_test.exs @@ -38,6 +38,6 @@ defmodule Ecto.UUIDTest do end test "generate" do - assert << _::64, ?-, _::32, ?-, _::32, ?-, _::32, ?-, _::96 >> = Ecto.UUID.generate + assert << _::64, ?-, _::32, ?-, _::32, ?-, _::32, ?-, _::96 >> = Ecto.UUID.generate() end end