Skip to content

Commit

Permalink
Custom reimplementation of Base.SingleAsyncWork
Browse files Browse the repository at this point in the history
The custom version of SingelAsyncWork is needed to pass the wiating
condition around + the data for the callback without an anonymous
function.
  • Loading branch information
vchuravy committed Nov 6, 2015
1 parent 53970ba commit 00e24c6
Showing 1 changed file with 94 additions and 14 deletions.
108 changes: 94 additions & 14 deletions src/nativeops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,40 +34,120 @@ end
# Each _wrapper_... is the entry point for the c function call that converts the opaque
# pointer into the right julia function.
###
abstract _MXNET_DATA

immutable _Async{T <: _MXNET_DATA}
data :: T

handle :: Ptr{Void}
cb :: Function
cond :: Condition

function _Async(data :: T, cb::Function)
this = new(data, Libc.malloc(Base._sizeof_uv_async), cb, Condition())
Base.associate_julia_struct(this.handle, this)
Base.preserve_handle(this)

async_cb = cfunction(cb, Void, (Ptr{Void},))

err = ccall(:uv_async_init,Cint,(Ptr{Void},Ptr{Void},Ptr{Void}),Base.eventloop(),this.handle,async_cb::Ptr{Void})

this
end
end

Base._uv_hook_close(t::_Async) = (uv.handle = C_NULL; unpreserve_handle(uv); nothing)

immutable _FB <: _MXNET_DATA
size :: Cint
data :: Ptr{Ptr{Cfloat}}
ndims :: Ptr{Cint}
shapes :: Ptr{Ptr{Cuint}}
tags :: Ptr{Cint}
end

function _wrapper_fb(size :: Cint, data :: Ptr{Ptr{Cfloat}}, ndims :: Ptr{Cint}, shapes :: Ptr{Ptr{Cuint}}, tags :: Ptr{Cint}, jf :: Ptr{Void})
julia_function = unsafe_pointer_to_objref(jf) :: Function
julia_function(Int(size), data, ndims, shapes, tags)
entry = unsafe_pointer_to_objref(jf) :: Function
cb_data = _FB(size, data, ndims, shapes, tags)
work = _Async{_FB}(cb_data, entry)
ccall(:uv_async_send, Void, (Ptr{Void},), work.handle)
wait(work.cond)
return nothing
end

immutable _INFER <: _MXNET_DATA
size :: Cint
ndims :: Ptr{Cint}
shapes :: Ptr{Ptr{Cuint}}
end

function _wrapper_infer(size :: Cint, ndims :: Ptr{Cint}, shapes :: Ptr{Ptr{Cuint}}, jf :: Ptr{Void})
julia_function = unsafe_pointer_to_objref(jf) :: Function
julia_function(Int(size), ndims, shapes)
entry = unsafe_pointer_to_objref(jf) :: Function
cb_data = _INFER(size, ndims, shapes)
work = _Async{_INFER}(cb_data, entry)
ccall(:uv_async_send, Void, (Ptr{Void},), work.handle)
wait(work.cond)
return nothing
end

immutable _LIST <: _MXNET_DATA
result :: Ptr{Ptr{Ptr{Cchar}}}
end

function _wrapper_list(data :: Ptr{Ptr{Ptr{Cchar}}}, jf :: Ptr{Void})
julia_function = unsafe_pointer_to_objref(jf) :: Function
julia_function(data)
entry = unsafe_pointer_to_objref(jf) :: Function
cb_data = _LIST(data)
work = _Async{_LIST}(cb_data, entry)
ccall(:uv_async_send, Void, (Ptr{Void},), work.handle)
wait(work.cond)
return nothing
end

###
# Test entry functions
# Entry functions
# These functions are now executed in the main Julia thread.
###

function list_entry(a :: Ptr{Ptr{Ptr{Cchar}}})
data = ByteString[""]
ref = Base.cconvert(Ptr{Ptr{Cchar}}, data)
ptr = Base.unsafe_convert(Ptr{Ptr{Cchar}}, ref)
unsafe_store!(a, ptr,1)
function list_entry(handle :: Ptr{Void})
work = Base.@handle_as handle _Async{_LIST}
try
data = ByteString[""]
ref = Base.cconvert(Ptr{Ptr{Cchar}}, data)
ptr = Base.unsafe_convert(Ptr{Ptr{Cchar}}, ref)
unsafe_store!(work.data.result, ptr,1)
catch
finally
notify(work.cond)
end
nothing
end

function fb_entry(num_tensor :: Cint, in :: Ptr{Ptr{Cfloat}}, x :: Ptr{Cint}, y :: Ptr{Cuint}, z :: Ptr{Cint})
function fb_entry(handle :: Ptr{Void})
work = Base.@handle_as handle _Async{_FB}
try
data = work.data
# do data conversion
# call appropriate julia function
# store result
catch
finally
notify(work.cond)
end
nothing
end

function infer_entry(num_tensor :: Cint, tensor_dims :: Ptr{Cint}, tensor_shapes :: Ptr{Ptr{Cuint}})
function infer_entry(handle :: Ptr{Void})
work = Base.@handle_as handle _Async{_INFER}
try
data = work.data
# do data conversion
# call appropriate julia function
# store result
catch
finally
notify(work.cond)
end
nothing
end

create_info() = NativeOpInfo(fb_entry, fb_entry, infer_entry, list_entry, list_entry)
Expand Down

0 comments on commit 00e24c6

Please sign in to comment.