Skip to content

Commit ac83f3f

Browse files
authored
Implement new DBInterface.executemultiple interface function (#155)
This allows returning multiple resultsets from a single query call; this requires setting `multi_statements` to `true` by default, but I couldn't think of any downside to doing this (vs. the alternative in which a user finds the function, tries it and immediately fails and has to restart their connection by setting the flag). Prepared statements don't support multiple resultsets, and an error is thrown appropriately (from mysql itself). This PR requires JuliaDatabases/DBInterface.jl#26 to be merged and tagged first.
1 parent 2bdf85f commit ac83f3f

File tree

3 files changed

+69
-14
lines changed

3 files changed

+69
-14
lines changed

src/MySQL.jl

+11-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,16 @@ end
4747

4848
function clear!(conn, result::API.MYSQL_RES)
4949
if result.ptr != C_NULL
50-
while API.fetchrow(conn.mysql, result) != C_NULL || API.nextresult(conn.mysql) !== nothing
50+
while true
51+
if API.fetchrow(conn.mysql, result) == C_NULL
52+
if API.moreresults(conn.mysql)
53+
finalize(result)
54+
@assert API.nextresult(conn.mysql) !== nothing
55+
result = API.useresult(conn.mysql)
56+
else
57+
break
58+
end
59+
end
5160
end
5261
finalize(result)
5362
end
@@ -68,7 +77,7 @@ function clientflags(;
6877
compress::Bool=false,
6978
ignore_space::Bool=false,
7079
local_files::Bool=false,
71-
multi_statements::Bool=false,
80+
multi_statements::Bool=true,
7281
multi_results::Bool=false,
7382
kw...
7483
)

src/execute.jl

+39-12
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,22 @@ mutable struct TextCursor{buffered} <: DBInterface.Cursor
99
types::Vector{Type}
1010
lookup::Dict{Symbol, Int}
1111
current_rownumber::Int
12+
current_resultsetnumber::Int
1213
end
1314

1415
struct TextRow{buffered} <: Tables.AbstractRow
1516
cursor::TextCursor{buffered}
1617
row::Ptr{Ptr{UInt8}}
1718
lengths::Vector{Culong}
1819
rownumber::Int
20+
resultsetnumber::Int
1921
end
2022

2123
getcursor(r::TextRow) = getfield(r, :cursor)
2224
getrow(r::TextRow) = getfield(r, :row)
2325
getlengths(r::TextRow) = getfield(r, :lengths)
2426
getrownumber(r::TextRow) = getfield(r, :rownumber)
27+
getresultsetnumber(r::TextRow) = getfield(r, :resultsetnumber)
2528

2629
Tables.columnnames(r::TextRow) = getcursor(r).names
2730

@@ -71,7 +74,7 @@ end
7174
@noinline wrongrow(i) = throw(ArgumentError("row $i is no longer valid; mysql results are forward-only iterators where each row is only valid when iterated"))
7275

7376
function Tables.getcolumn(r::TextRow, ::Type{T}, i::Int, nm::Symbol) where {T}
74-
getrownumber(r) == getcursor(r).current_rownumber || wrongrow(getrownumber(r))
77+
(getrownumber(r) == getcursor(r).current_rownumber && getresultsetnumber(r) == getcursor(r).current_resultsetnumber) || wrongrow(getrownumber(r))
7578
return cast(T, unsafe_load(getrow(r), i), getlengths(r)[i])
7679
end
7780

@@ -91,18 +94,11 @@ function Base.iterate(cursor::TextCursor{buffered}, i=1) where {buffered}
9194
rowptr = API.fetchrow(cursor.conn.mysql, cursor.result)
9295
if rowptr == C_NULL
9396
!buffered && API.errno(cursor.conn.mysql) != 0 && throw(API.Error(cursor.conn.mysql))
94-
if API.nextresult(cursor.conn.mysql) === nothing
95-
finalize(cursor.result)
96-
return nothing
97-
else
98-
# we ***ignore*** additional resultsets for now
99-
clear!(cursor.conn)
100-
return nothing
101-
end
97+
return nothing
10298
end
10399
lengths = API.fetchlengths(cursor.result, cursor.nfields)
104100
cursor.current_rownumber = i
105-
return TextRow(cursor, rowptr, lengths, i), i + 1
101+
return TextRow(cursor, rowptr, lengths, i, cursor.current_resultsetnumber), i + 1
106102
end
107103

108104
"""
@@ -167,5 +163,36 @@ function DBInterface.execute(conn::Connection, sql::AbstractString, params=(); m
167163
error("error with mysql resultset columns")
168164
end
169165
lookup = Dict(x => i for (i, x) in enumerate(names))
170-
return TextCursor{buffered}(conn, sql, nfields, nrows, Core.bitcast(Int64, rows_affected), result, names, types, lookup, 0)
171-
end
166+
return TextCursor{buffered}(conn, sql, nfields, nrows, Core.bitcast(Int64, rows_affected), result, names, types, lookup, 0, 1)
167+
end
168+
169+
struct TextCursors{T}
170+
cursor::TextCursor{T}
171+
end
172+
173+
Base.eltype(c::TextCursors{T}) where {T} = TextCursor{T}
174+
Base.IteratorSize(::Type{<:TextCursors}) = Base.SizeUnknown()
175+
176+
function Base.iterate(cursor::TextCursors{buffered}, first=true) where {buffered}
177+
cursor.cursor.result.ptr == C_NULL && return nothing
178+
if !first
179+
finalize(cursor.cursor.result)
180+
if API.moreresults(cursor.cursor.conn.mysql)
181+
@assert API.nextresult(cursor.cursor.conn.mysql) !== nothing
182+
cursor.cursor.result = buffered ? API.storeresult(cursor.cursor.conn.mysql) : API.useresult(cursor.cursor.conn.mysql)
183+
if buffered
184+
cursor.cursor.nrows = API.numrows(cursor.cursor.result)
185+
end
186+
cursor.cursor.nfields = API.numfields(cursor.cursor.result)
187+
fields = API.fetchfields(cursor.cursor.result, cursor.cursor.nfields)
188+
cursor.cursor.names = [ccall(:jl_symbol_n, Ref{Symbol}, (Ptr{UInt8}, Csize_t), x.name, x.name_length) for x in fields]
189+
cursor.cursor.types = [juliatype(x.field_type, API.notnullable(x), API.isunsigned(x), API.isbinary(x)) for x in fields]
190+
else
191+
return nothing
192+
end
193+
end
194+
return cursor.cursor, false
195+
end
196+
197+
DBInterface.executemultiple(conn::Connection, sql::AbstractString, params=(); kw...) =
198+
TextCursors(DBInterface.execute(conn, sql, params; kw...))

test/runtests.jl

+19
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,22 @@ res = DBInterface.execute(stmt) |> columntable
249249
@test length(res[1]) == 5
250250
res = DBInterface.execute(stmt)
251251
res = DBInterface.execute(stmt)
252+
253+
results = DBInterface.executemultiple(conn, "select * from Employee; select DeptNo, OfficeNo from Employee where OfficeNo IS NOT NULL")
254+
state = iterate(results)
255+
@test state !== nothing
256+
res, st = state
257+
@test !st
258+
@test length(res) == 5
259+
ret = columntable(res)
260+
@test length(ret[1]) == 5
261+
state = iterate(results, st)
262+
@test state !== nothing
263+
res, st = state
264+
@test !st
265+
@test length(res) == 4
266+
ret = columntable(res)
267+
@test length(ret[1]) == 4
268+
269+
# multiple-queries not supported by mysql w/ prepared statements
270+
@test_throws MySQL.API.StmtError DBInterface.prepare(conn, "select * from Employee; select DeptNo, OfficeNo from Employee where OfficeNo IS NOT NULL")

0 commit comments

Comments
 (0)