Skip to content

Commit

Permalink
Merge pull request #87 from JuliaGPU/vc/fix_callback
Browse files Browse the repository at this point in the history
Implements  event_notify using safe operations
  • Loading branch information
vchuravy committed Nov 17, 2015
2 parents e05be14 + 267266e commit 712ba87
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 22 deletions.
55 changes: 42 additions & 13 deletions src/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,27 @@ function Base.show(io::IO, ctx::Context)
print(io, "OpenCL.Context(@$ptr_address on $devs_str)")
end

immutable _CtxErr
handle :: Ptr{Void}
err_info :: Ptr{Cchar}
priv_info :: Ptr{Void}
cb :: Csize_t
end

function ctx_notify_err(err_info::Ptr{Cchar}, priv_info::Ptr{Void},
cb::Csize_t, julia_func::Ptr{Void})
err = bytestring(err_info)
private = bytestring(Base.unsafe_convert(Ptr{Cchar}, err_info))
callback = unsafe_pointer_to_objref(julia_func)::Function
callback(err, private)::Ptr{Void}
ptr = convert(Ptr{_CtxErr}, payload)
handle = unsafe_load(ptr, 1).handle

val = _CtxErr(handle, err_info, priv_info, cb)
unsafe_store!(ptr, val, 1)

ccall(:uv_async_send, Void, (Ptr{Void},), handle)
nothing
end


const ctx_callback_ptr = cfunction(ctx_notify_err, Ptr{Void},
const ctx_callback_ptr = cfunction(ctx_notify_err, Void,
(Ptr{Cchar}, Ptr{Void}, Csize_t, Ptr{Void}))

function raise_context_error(error_info, private_info)
Expand All @@ -48,7 +59,7 @@ end

function Context(devs::Vector{Device};
properties=nothing,
callback::Union{Void,Function}=nothing)
callback::Union{Function, Void} = nothing)
if isempty(devs)
ArgumentError("No devices specified for context")
end
Expand All @@ -57,22 +68,40 @@ function Context(devs::Vector{Device};
else
ctx_properties = C_NULL
end
if callback !== nothing
ctx_user_data = callback
else
ctx_user_data = raise_context_error
end

n_devices = length(devs)
device_ids = Array(CL_device_id, n_devices)
for (i, d) in enumerate(devs)
device_ids[i] = d.id
end
err_code = Array(CL_int, 1)

cond = Condition()
cb = Base.SingleAsyncWork(data -> notify(cond))
ctx_user_data = Ref(_CtxErr(cb.handle, 0, 0, 0))

err_code = Ref{CL_int}()
ctx_id = api.clCreateContext(ctx_properties, n_devices, device_ids,
ctx_callback_ptr, ctx_user_data, err_code)
if err_code[1] != CL_SUCCESS
if err_code[] != CL_SUCCESS
throw(CLError(err_code[1]))
end

true_callback = callback == nothing ? raise_context_error : callback :: Function

@async begin
try
Base.wait(cond)
err = ctx_user_data[]
error_info = bytestring(err.error_info)
private_info = bytestring(convert(Ptr{Cchar}, err.private_info))
true_callback(error_info, private_info)
catch
rethrow()
finally
Base.close(cb)
end
end

return Context(ctx_id)
end

Expand Down
51 changes: 42 additions & 9 deletions src/event.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,56 @@ Base.getindex(evt::CLEvent, evt_info::Symbol) = info(evt, evt_info)
end
end

function event_notify(evt_id::CL_event, status::CL_int, julia_func::Ptr{Void})
# Obtain the Function object from the opaque pointer
callback = unsafe_pointer_to_objref(julia_func)::Function
immutable _EventCB
handle :: Ptr{Void}
evt_id :: CL_event
status :: CL_int
end

function event_notify(evt_id::CL_event, status::CL_int, payload::Ptr{Void})
ptr = convert(Ptr{_EventCB}, payload)
handle = unsafe_load(ptr, 1).handle

# In order to callback into the Julia thread create an AsyncWork package.
cb_packaged = Base.SingleAsyncWork(data -> callback(evt_id, status))
val = _EventCB(handle, evt_id, status)
unsafe_store!(ptr, val, 1)

# Use uv_async_send to notify the main thread
ccall(:uv_async_send, Void, (Ptr{Void},), cb_packaged.handle)
ccall(:uv_async_send, Void, (Ptr{Void},), handle)
nothing
end

const event_notify_ptr = cfunction(event_notify, Void,
(CL_event, CL_int, Ptr{Void}))
function preserve_callback(evt :: CLEvent, cb, ptr)
evt._cbs[cb] = 0
push!(evt._memory, ptr)
end


function add_callback(evt::CLEvent, callback::Function)
@check api.clSetEventCallback(evt.id, CL_COMPLETE, event_notify_ptr, callback)
event_notify_ptr = cfunction(event_notify, Void,
(CL_event, CL_int, Ptr{Void}))

# The uv_callback is going to notify a task that,
# then executes the real callback.
cond = Condition()
cb = Base.SingleAsyncWork(data -> notify(cond))

# Storing the results of our c_callback needs to be
# isbits && isimmutable
r_ecb = Ref(_EventCB(cb.handle, 0, 0))

@check api.clSetEventCallback(evt.id, CL_COMPLETE, event_notify_ptr, r_ecb)

@async begin
try
Base.wait(cond)
ecb = r_ecb[]
callback(ecb.evt_id, ecb.status)
catch
rethrow()
finally
Base.close(cb)
end
end
end

function wait(evt::CLEvent)
Expand Down

0 comments on commit 712ba87

Please sign in to comment.