Skip to content

Commit

Permalink
Allow Enumerable for query args instead of Array (#207)
Browse files Browse the repository at this point in the history
  • Loading branch information
lwakefield committed Jul 3, 2024
1 parent 3eaac85 commit 532ae07
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 26 deletions.
6 changes: 3 additions & 3 deletions src/db/pool_statement.cr
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ module DB
end

# See `QueryMethods#exec`
def exec(*args_, args : Array? = nil) : ExecResult
def exec(*args_, args : Enumerable? = nil) : ExecResult
statement_with_retry &.exec(*args_, args: args)
end

Expand All @@ -25,12 +25,12 @@ module DB
end

# See `QueryMethods#query`
def query(*args_, args : Array? = nil) : ResultSet
def query(*args_, args : Enumerable? = nil) : ResultSet
statement_with_retry &.query(*args_, args: args)
end

# See `QueryMethods#scalar`
def scalar(*args_, args : Array? = nil)
def scalar(*args_, args : Enumerable? = nil)
statement_with_retry &.scalar(*args_, args: args)
end

Expand Down
34 changes: 17 additions & 17 deletions src/db/query_methods.cr
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ module DB
# result = db.query "select name from contacts where id = ?", args: [10]
# ```
#
def query(query, *args_, args : Array? = nil)
def query(query, *args_, args : Enumerable? = nil)
build(query).query(*args_, args: args)
end

Expand All @@ -56,7 +56,7 @@ module DB
# end
# end
# ```
def query(query, *args_, args : Array? = nil)
def query(query, *args_, args : Enumerable? = nil)
# CHECK build(query).query(*args, &block)
rs = query(query, *args_, args: args)
yield rs ensure rs.close
Expand All @@ -73,7 +73,7 @@ module DB
# ```
# name = db.query_one "select name from contacts where id = ?", 18, &.read(String)
# ```
def query_one(query, *args_, args : Array? = nil, &block : ResultSet -> U) : U forall U
def query_one(query, *args_, args : Enumerable? = nil, &block : ResultSet -> U) : U forall U
query(query, *args_, args: args) do |rs|
raise DB::NoResultsError.new("no results") unless rs.move_next

Expand All @@ -92,7 +92,7 @@ module DB
# ```
# db.query_one "select name, age from contacts where id = ?", 1, as: {String, Int32}
# ```
def query_one(query, *args_, args : Array? = nil, as types : Tuple)
def query_one(query, *args_, args : Enumerable? = nil, as types : Tuple)
query_one(query, *args_, args: args) do |rs|
rs.read(*types)
end
Expand All @@ -108,7 +108,7 @@ module DB
# ```
# db.query_one "select name, age from contacts where id = ?", 1, as: {name: String, age: Int32}
# ```
def query_one(query, *args_, args : Array? = nil, as types : NamedTuple)
def query_one(query, *args_, args : Enumerable? = nil, as types : NamedTuple)
query_one(query, *args_, args: args) do |rs|
rs.read(**types)
end
Expand All @@ -123,7 +123,7 @@ module DB
# ```
# db.query_one "select name from contacts where id = ?", 1, as: String
# ```
def query_one(query, *args_, args : Array? = nil, as type : Class)
def query_one(query, *args_, args : Enumerable? = nil, as type : Class)
query_one(query, *args_, args: args) do |rs|
rs.read(type)
end
Expand All @@ -141,7 +141,7 @@ module DB
# name = db.query_one? "select name from contacts where id = ?", 18, &.read(String)
# typeof(name) # => String | Nil
# ```
def query_one?(query, *args_, args : Array? = nil, &block : ResultSet -> U) : U? forall U
def query_one?(query, *args_, args : Enumerable? = nil, &block : ResultSet -> U) : U? forall U
query(query, *args_, args: args) do |rs|
return nil unless rs.move_next

Expand All @@ -162,7 +162,7 @@ module DB
# result = db.query_one? "select name, age from contacts where id = ?", 1, as: {String, Int32}
# typeof(result) # => Tuple(String, Int32) | Nil
# ```
def query_one?(query, *args_, args : Array? = nil, as types : Tuple)
def query_one?(query, *args_, args : Enumerable? = nil, as types : Tuple)
query_one?(query, *args_, args: args) do |rs|
rs.read(*types)
end
Expand All @@ -180,7 +180,7 @@ module DB
# result = db.query_one? "select name, age from contacts where id = ?", 1, as: {age: String, name: Int32}
# typeof(result) # => NamedTuple(age: String, name: Int32) | Nil
# ```
def query_one?(query, *args_, args : Array? = nil, as types : NamedTuple)
def query_one?(query, *args_, args : Enumerable? = nil, as types : NamedTuple)
query_one?(query, *args_, args: args) do |rs|
rs.read(**types)
end
Expand All @@ -197,7 +197,7 @@ module DB
# name = db.query_one? "select name from contacts where id = ?", 1, as: String
# typeof(name) # => String?
# ```
def query_one?(query, *args_, args : Array? = nil, as type : Class)
def query_one?(query, *args_, args : Enumerable? = nil, as type : Class)
query_one?(query, *args_, args: args) do |rs|
rs.read(type)
end
Expand All @@ -209,7 +209,7 @@ module DB
# ```
# names = db.query_all "select name from contacts", &.read(String)
# ```
def query_all(query, *args_, args : Array? = nil, &block : ResultSet -> U) : Array(U) forall U
def query_all(query, *args_, args : Enumerable? = nil, &block : ResultSet -> U) : Array(U) forall U
ary = [] of U
query_each(query, *args_, args: args) do |rs|
ary.push(yield rs)
Expand All @@ -223,7 +223,7 @@ module DB
# ```
# contacts = db.query_all "select name, age from contacts", as: {String, Int32}
# ```
def query_all(query, *args_, args : Array? = nil, as types : Tuple)
def query_all(query, *args_, args : Enumerable? = nil, as types : Tuple)
query_all(query, *args_, args: args) do |rs|
rs.read(*types)
end
Expand All @@ -236,7 +236,7 @@ module DB
# ```
# contacts = db.query_all "select name, age from contacts", as: {name: String, age: Int32}
# ```
def query_all(query, *args_, args : Array? = nil, as types : NamedTuple)
def query_all(query, *args_, args : Enumerable? = nil, as types : NamedTuple)
query_all(query, *args_, args: args) do |rs|
rs.read(**types)
end
Expand All @@ -248,7 +248,7 @@ module DB
# ```
# names = db.query_all "select name from contacts", as: String
# ```
def query_all(query, *args_, args : Array? = nil, as type : Class)
def query_all(query, *args_, args : Enumerable? = nil, as type : Class)
query_all(query, *args_, args: args) do |rs|
rs.read(type)
end
Expand All @@ -262,7 +262,7 @@ module DB
# puts rs.read(String)
# end
# ```
def query_each(query, *args_, args : Array? = nil)
def query_each(query, *args_, args : Enumerable? = nil)
query(query, *args_, args: args) do |rs|
rs.each do
yield rs
Expand All @@ -271,7 +271,7 @@ module DB
end

# Performs the `query` and returns an `ExecResult`
def exec(query, *args_, args : Array? = nil)
def exec(query, *args_, args : Enumerable? = nil)
build(query).exec(*args_, args: args)
end

Expand All @@ -280,7 +280,7 @@ module DB
# ```
# puts db.scalar("SELECT MAX(name)").as(String) # => (a String)
# ```
def scalar(query, *args_, args : Array? = nil)
def scalar(query, *args_, args : Enumerable? = nil)
build(query).scalar(*args_, args: args)
end
end
Expand Down
12 changes: 6 additions & 6 deletions src/db/statement.cr
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module DB
# and for connection pool statements.
module StatementMethods
# See `QueryMethods#scalar`
def scalar(*args_, args : Array? = nil)
def scalar(*args_, args : Enumerable? = nil)
query(*args_, args: args) do |rs|
rs.each do
return rs.read
Expand All @@ -14,20 +14,20 @@ module DB
end

# See `QueryMethods#query`
def query(*args_, args : Array? = nil)
def query(*args_, args : Enumerable? = nil)
rs = query(*args_, args: args)
yield rs ensure rs.close
end

# See `QueryMethods#exec`
abstract def exec : ExecResult
# See `QueryMethods#exec`
abstract def exec(*args_, args : Array? = nil) : ExecResult
abstract def exec(*args_, args : Enumerable? = nil) : ExecResult

# See `QueryMethods#query`
abstract def query : ResultSet
# See `QueryMethods#query`
abstract def query(*args_, args : Array? = nil) : ResultSet
abstract def query(*args_, args : Enumerable? = nil) : ResultSet
end

# Represents a query in a `Connection`.
Expand Down Expand Up @@ -74,7 +74,7 @@ module DB
end

# See `QueryMethods#exec`
def exec(*args_, args : Array? = nil) : DB::ExecResult
def exec(*args_, args : Enumerable? = nil) : DB::ExecResult
perform_exec_and_release(EnumerableConcat.build(args_, args))
end

Expand All @@ -84,7 +84,7 @@ module DB
end

# See `QueryMethods#query`
def query(*args_, args : Array? = nil) : DB::ResultSet
def query(*args_, args : Enumerable? = nil) : DB::ResultSet
perform_query_with_rescue(EnumerableConcat.build(args_, args))
end

Expand Down

0 comments on commit 532ae07

Please sign in to comment.