Skip to content

Commit b729dd4

Browse files
committed
use tasks to handle forward and backward
1 parent 22c1ecb commit b729dd4

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

src/nativeops.jl

+40-3
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,45 @@ immutable NativeOpInfo
4848
c_wrapper_infer = cfunction(_wrapper_infer, Void, (Cint, Ptr{Cint}, Ptr{Ptr{Cuint}}, Ptr{Void}))
4949
const c_wrapper_list = cfunction(_wrapper_list, Void, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void}))
5050

51-
p_f = pointer_from_objref(forward)
52-
p_b = pointer_from_objref(backwards)
51+
cond_forward = Condition()
52+
cond_backward = Condition()
53+
cb_f = Base.SingleAsyncWork(data -> notify(cond_forward))
54+
cb_b = Base.SingleAsyncWork(data -> notify(cond_backward))
55+
56+
r_forward = Ref(_FB(cb_f.handle))
57+
r_backward = Ref(_FB(cb_f.handle))
58+
59+
p_f = convert(Ptr{Void}, r_forward)
60+
p_f = convert(Ptr{Void}, r_backward)
61+
62+
@schedule begin
63+
try
64+
while true
65+
wait(cond_forward)
66+
cond_forward = Condition()
67+
_entry_forward(r_forward[])
68+
end
69+
catch
70+
rethrow()
71+
finally
72+
Base.close(cb_f)
73+
end
74+
end
75+
76+
@schedule begin
77+
try
78+
while true
79+
wait(cond_backward)
80+
cond_backward = Condition()
81+
_entry_backward(r_backward[])
82+
end
83+
catch
84+
rethrow()
85+
finally
86+
Base.close(cb_f)
87+
end
88+
end
89+
5390
new(c_wrapper_fb, c_wrapper_fb, c_wrapper_infer, c_wrapper_list,
5491
c_wrapper_list, p_f, p_b, p_is, p_lo, p_la)
5592
end
@@ -105,7 +142,7 @@ immutable _FB
105142
shapes :: Ptr{Ptr{Cuint}}
106143
tags :: Ptr{Cint}
107144
end
108-
145+
_FB(handle :: Ptr{Void}) = _FP(handle, 0, 0, 0, 0, 0)
109146
@assert isbits(_FB)
110147

111148
function _wrapper_fb(size :: Cint, data :: Ptr{Ptr{Cfloat}}, ndims :: Ptr{Cint}, shapes :: Ptr{Ptr{Cuint}}, tags :: Ptr{Cint}, payload :: Ptr{Void})

0 commit comments

Comments
 (0)