diff --git a/src/nativeops.jl b/src/nativeops.jl index 709692a33..3948832f7 100644 --- a/src/nativeops.jl +++ b/src/nativeops.jl @@ -53,8 +53,20 @@ need_top_grad(:: Operator) = true #=doc .. function:: declare_backward_dependency(op :: Operator, out_grad, in_data, out_data) + + Declare dependencies of this operator for backward pass. + + Return value needs to be an integer array. =# -function declare_backward_dependency(:: Operator, out_grad, in_data, out_data) +function declare_backward_dependency(op :: Operator, out_grad, in_data, out_data) + deps = Int[] + if need_top_grad(op) + append!(deps, out_grad) + end + append!(deps, in_data) + append!(deps, out_data) + + return deps end ### @@ -191,14 +203,24 @@ function _wrapper_list_outputs(data :: Ptr{Ptr{Cstring}}, _op :: Ptr{Void}) return true end -function _wrapper_declare_backward_dependency(out_grad :: Ptr{Cint}, - in_data :: Ptr{Cint}, - out_data :: Ptr{Cint}, +function _wrapper_declare_backward_dependency(_out_grad :: Ptr{Cint}, + _in_data :: Ptr{Cint}, + _out_data :: Ptr{Cint}, num_dep :: Ptr{Cint}, deps :: Ptr{Ptr{Cint}}, _op :: Ptr{Void}) try op = unsafe_pointer_to_objref(_op) :: Operator + + out_grad = pointer_to_array(_out_grad, length(list_outputs(op)), false) + in_data = pointer_to_array(_in_data, length(list_arguments(op)), false) + out_data = pointer_to_array(_out_data, length(list_outputs(op)), false) + + rdeps = convert(Array{Cint}, declare_backward_dependency(out_grad, in_data, out_data)) + + unsafe_store!(num_dep, length(rdeps), 1) + r_rdeps = Ref(rdeps) # Lifetime? + unsafe_store!(deps, convert(Ptr{Cint}, r_rdeps), 1) catch return false end