Skip to content

Commit

Permalink
adapt to NDArrayOp
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Jan 7, 2016
1 parent 15539c7 commit e492b57
Showing 1 changed file with 101 additions and 33 deletions.
134 changes: 101 additions & 33 deletions src/nativeops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,14 @@ end
=#
need_top_grad(:: Operator) = true

#=doc
.. function:: declare_backward_dependency(op :: Operator, out_grad, in_data, out_data)
=#
function declare_backward_dependency(:: Operator, out_grad, in_data, out_data)
end

###
# NativeOpInfo mirrors the struct in include/mxnet/c_api.h and consists of five function
# NDArrayOpInfo mirrors the struct in include/mxnet/c_api.h and consists of five function
# pointers that work as callbacks. Each p_... ia a opaque pointer that contains the
# necessary information to call the right function.
#
Expand All @@ -61,27 +67,31 @@ need_top_grad(:: Operator) = true
#
# Todo: Cleanup tasks.
###
immutable NativeOpInfo
immutable NDArrayOpInfo
forward :: Ptr{Void}
backward :: Ptr{Void}
infer_shape :: Ptr{Void}
list_outputs :: Ptr{Void}
list_arguments :: Ptr{Void}
declare_backward_dependecy :: Ptr{Void}

p_forward :: Ptr{Void}
p_backward :: Ptr{Void}
p_infer_shape :: Ptr{Void}
p_list_outputs :: Ptr{Void}
p_list_arguments :: Ptr{Void}
p_declare_backward_dependency :: Ptr{Void}

function NativeOpInfo(op :: Operator)
# infer_shape, list_args, list_outputs are called directly and use dynamic dispatch,
# for finding the correct operator.
p_is, p_la, p_la = pointer_from_objref(op)
p_is, p_la, p_la p_dbd = pointer_from_objref(op)

c_wrapper_fb = cfunction(_wrapper_fb, Void, (Cint, Ptr{Ptr{Cfloat}}, Ptr{Cint}, Ptr{Ptr{Cuint}}, Ptr{Cint}, Ptr{Void}))
c_wrapper_fb = cfunction(_wrapper_fb, Void, (Cint, Ptr{Ptr{Void}}, Ptr{Cint}, Ptr{Void}))
c_wrapper_infer = cfunction(_wrapper_infer, Void, (Cint, Ptr{Cint}, Ptr{Ptr{Cuint}}, Ptr{Void}))
c_wrapper_list_outputs = cfunction(_wrapper_list_outputs, Void, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void}))
c_wrapper_list_arguments = cfunction(_wrapper_list_arguments, Void, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void}))
c_wrapper_declare_backward_dependency = cfunction(_wrapper_declare_backward_dependency, Cbool, Bool, (Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Ptr{Cint}}, Ptr{Void}))

# Setting up for handling backward/forward. Each function has a condition for a task
# to wait on and a libuv callback that notifies that condition, to handle the call.
Expand Down Expand Up @@ -129,43 +139,70 @@ immutable NativeOpInfo
end

new(c_wrapper_fb, c_wrapper_fb, c_wrapper_infer, c_wrapper_list,
c_wrapper_list, p_f, p_b, p_is, p_lo, p_la)
c_wrapper_list, c_wrapper_declare_backward_dependency,
p_f, p_b, p_is, p_lo, p_la, p_dpd)
end
end

###
# Infer and list are called in sync.
###
function _wrapper_infer(size :: Cint, ndims :: Ptr{Cint}, shapes :: Ptr{Ptr{Cuint}}, _op :: Ptr{Void})
op = unsafe_pointer_to_objref(_op) :: Operator
try
op = unsafe_pointer_to_objref(_op) :: Operator

n_in = length(list_arguments(op))
n_out = length(list_outputs(op))
@assert size == n_in + n_out
n_in = length(list_arguments(op))
n_out = length(list_outputs(op))
@assert size == n_in + n_out

shapes = [[tensor_shapes[i][j] for j in 1:tensor_dims[i]] for i in 1:n_in]]
shapes = [[tensor_shapes[i][j] for j in 1:tensor_dims[i]] for i in 1:n_in]

ishape, oshape = infer_shape(op, shapes)
@assert length(ishape) == n_in
@assert length(oshape) == n_out
ishape, oshape = infer_shape(op, shapes)
@assert length(ishape) == n_in
@assert length(oshape) == n_out

rshape = cat(ishape, oshape)
unsafe_store!(shapes, rshapes)
return nothing
rshape = cat(ishape, oshape)
unsafe_store!(shapes, rshapes)
catch
return false
end
return true
end

function _wrapper_list_arguments(data :: Ptr{Ptr{Cstring}}, _op :: Ptr{Void})
op = unsafe_pointer_to_objref(_op) :: Operator
arguments = list_arguments(op)
unsafe_store!(data, arguments)
return nothing
try
op = unsafe_pointer_to_objref(_op) :: Operator
arguments = list_arguments(op)
unsafe_store!(data, arguments)
catch
return false
end
return true
end

function _wrapper_list_outputs(data :: Ptr{Ptr{Cstring}}, _op :: Ptr{Void})
op = unsafe_pointer_to_objref(_op) :: Operator
outputs = list_outputs(op)
unsafe_store!(data, outputs)
return nothing
try
op = unsafe_pointer_to_objref(_op) :: Operator
outputs = list_outputs(op)
unsafe_store!(data, outputs)
catch
return false
end
return true
end

function _wrapper_declare_backward_dependency(out_grad :: Ptr{Cint},
in_data :: Ptr{Cint},
out_data :: Ptr{Cint},
num_dep :: Ptr{Cint},
deps :: Ptr{Ptr{Cint}},
_op :: Ptr{Void})
try
op = unsafe_pointer_to_objref(_op) :: Operator
catch
return false
end
return true
end

##
Expand All @@ -178,36 +215,67 @@ end
immutable _FB
handle :: Ptr{Void}
size :: Cint
data :: Ptr{Ptr{Cfloat}}
ndims :: Ptr{Cint}
shapes :: Ptr{Ptr{Cuint}}
data :: Ptr{Ptr{Void}}
tags :: Ptr{Cint}
end
_FB(handle :: Ptr{Void}) = _FP(handle, 0, 0, 0, 0, 0)
_FB(handle :: Ptr{Void}) = _FP(handle, 0, 0, 0)
@assert isbits(_FB)

# This function is called async and because the Julia runtime is not thread safe, we are
# very limited in the things we can do. Using a immutable that is a bitstype we can pass,
# return values to the handling tasks.
function _wrapper_fb(size :: Cint, data :: Ptr{Ptr{Cfloat}}, ndims :: Ptr{Cint}, shapes :: Ptr{Ptr{Cuint}}, tags :: Ptr{Cint}, payload :: Ptr{Void})
function _wrapper_fb(size :: Cint, data :: Ptr{Ptr{Void}}, tags :: Ptr{Cint}, payload :: Ptr{Void})
# Load the libuv async handle
ptr = convert(Ptr{_FB}, payload)
handle = unsafe_load(ptr, 1).handle

# Create result
val = _FB(handle, size, data, ndims, shapes, tags)
val = _FB(handle, size, data, tags)
unsafe_store!(ptr, val, 1)

ccall(:uv_async_send, Void, (Ptr{Void},), handle)
nothing
return true # Better solution?
end

# Todo: handle the c callback and call the correct function
function _entry_forward(:: Operator, payload :: _FB)
function _entry_forward(op :: Operator, payload :: _FB)
num_ndarray = payload.size
ndarraies = pointer_to_array(payload.data, num_ndarray, false)
tags = pointer_to_array(payload.tags, num_ndarray, false)

tensors = [[] for i in 1:4]

# Tags are zero-based
for i in 1:num_ndarray
if tags[i] == 1
#tensors[tags[i]+1].append(NDArray(cast(ndarraies[i], NDArrayHandle),
# writable=True))
else
#tensors[tags[i]+1].append(NDArray(cast(ndarraies[i], NDArrayHandle),
# writable=False))
end
end
forward(op, tensors[1], tensors[2])
end

# Todo: handle the c callback and call the correct function
function _entry_backward(:: Operator, payload :: _FB)
function _entry_backward(op :: Operator, payload :: _FB)
num_ndarray = payload.size
ndarraies = pointer_to_array(payload.data, num_ndarray, false)
tags = pointer_to_array(payload.tags, num_ndarray, false)

tensors = [[] for i in 1:4]

for i in 1:num_ndarray
if tags[i] == 2
#tensors[tags[i]+1].append(NDArray(cast(ndarraies[i], NDArrayHandle),
# writable=True))
else
#tensors[tags[i]+1].append(NDArray(cast(ndarraies[i], NDArrayHandle),
# writable=False))
end
end
backward(op, tensors[1], tensors[2], tensors[3], tensors[4])
end

# pstring = bytestring("0x", hex(reinterpret(UInt, pointer_from_objref(info))))
Expand Down

0 comments on commit e492b57

Please sign in to comment.