Skip to content

Commit

Permalink
handle declare backward dependecy
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Jan 7, 2016
1 parent 34f8ea4 commit f223c96
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions src/nativeops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,20 @@ need_top_grad(:: Operator) = true

#=doc
.. function:: declare_backward_dependency(op :: Operator, out_grad, in_data, out_data)
Declare dependencies of this operator for backward pass.
Return value needs to be an integer array.
=#
function declare_backward_dependency(:: Operator, out_grad, in_data, out_data)
function declare_backward_dependency(op :: Operator, out_grad, in_data, out_data)
deps = Int[]
if need_top_grad(op)
append!(deps, out_grad)
end
append!(deps, in_data)
append!(deps, out_data)

return deps
end

###
Expand Down Expand Up @@ -191,14 +203,24 @@ function _wrapper_list_outputs(data :: Ptr{Ptr{Cstring}}, _op :: Ptr{Void})
return true
end

function _wrapper_declare_backward_dependency(out_grad :: Ptr{Cint},
in_data :: Ptr{Cint},
out_data :: Ptr{Cint},
function _wrapper_declare_backward_dependency(_out_grad :: Ptr{Cint},
_in_data :: Ptr{Cint},
_out_data :: Ptr{Cint},
num_dep :: Ptr{Cint},
deps :: Ptr{Ptr{Cint}},
_op :: Ptr{Void})
try
op = unsafe_pointer_to_objref(_op) :: Operator

out_grad = pointer_to_array(_out_grad, length(list_outputs(op)), false)
in_data = pointer_to_array(_in_data, length(list_arguments(op)), false)
out_data = pointer_to_array(_out_data, length(list_outputs(op)), false)

rdeps = convert(Array{Cint}, declare_backward_dependency(out_grad, in_data, out_data))

unsafe_store!(num_dep, length(rdeps), 1)
r_rdeps = Ref(rdeps) # Lifetime?
unsafe_store!(deps, convert(Ptr{Cint}, r_rdeps), 1)
catch
return false
end
Expand Down

0 comments on commit f223c96

Please sign in to comment.