Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add update_capture_dependencies flags
Browse files: browse the repository at this point in the history
nimlgen committed Oct 20, 2023

Verified

This commit was signed with the committer’s verified signature.
sparrc Cameron Sparr
1 parent 66d7e36 commit f232a64
Showing 3 changed files with 14 additions and 6 deletions.
4 changes: 2 additions & 2 deletions examples/demo_graph.py
Original file line number Diff line number Diff line change
@@ -34,11 +34,11 @@
func_plus(a_gpu, numpy.int32(2), block=(4, 4, 1), stream=stream_1)
_, _, graph, deps = stream_1.get_capture_info_v2()
first_node = graph.add_kernel_node(b_gpu, numpy.int32(3), block=(4, 4, 1), func=func_plus, dependencies=deps)
-stream_1.update_capture_dependencies([first_node], 1)
+stream_1.update_capture_dependencies([first_node], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)

_, _, graph, deps = stream_1.get_capture_info_v2()
second_node = graph.add_kernel_node(a_gpu, b_gpu, block=(4, 4, 1), func=func_times, dependencies=deps)
-stream_1.update_capture_dependencies([second_node], 1)
+stream_1.update_capture_dependencies([second_node], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)
cuda.memcpy_dtoh_async(result, a_gpu, stream_1)

graph = stream_1.end_capture()
10 changes: 9 additions & 1 deletion src/wrapper/wrap_cudadrv.cpp
Original file line number Diff line number Diff line change
@@ -1277,6 +1277,12 @@ BOOST_PYTHON_MODULE(_driver)
.value("ACTIVE", CU_STREAM_CAPTURE_STATUS_ACTIVE)
.value("INVALIDATED", CU_STREAM_CAPTURE_STATUS_INVALIDATED)
;
#endif
#if CUDAPP_CUDA_VERSION >= 11030
py::enum_<CUstreamUpdateCaptureDependencies_flags>("update_capture_dependencies_flags")
.value("ADD_CAPTURE_DEPENDENCIES", CU_STREAM_ADD_CAPTURE_DEPENDENCIES)
.value("SET_CAPTURE_DEPENDENCIES", CU_STREAM_SET_CAPTURE_DEPENDENCIES)
;
#endif
{
typedef stream cl;
@@ -1294,7 +1300,9 @@ BOOST_PYTHON_MODULE(_driver)
py::return_value_policy<py::manage_new_object>())
.def("get_capture_info_v2", &cl::get_capture_info_v2)
#if CUDAPP_CUDA_VERSION >= 11030
-.def("update_capture_dependencies", &cl::update_capture_dependencies)
+.def("update_capture_dependencies", &cl::update_capture_dependencies,
+(py::arg("dependencies"),
+py::arg("flags") = CU_STREAM_ADD_CAPTURE_DEPENDENCIES))
#endif
#endif
.add_property("handle", &cl::handle_int)
6 changes: 3 additions & 3 deletions test/test_graph.py
Original file line number Diff line number Diff line change
@@ -61,7 +61,7 @@ def test_dynamic_params(self):
assert stat == drv.capture_status.ACTIVE, "Capture should be active"
assert len(deps) == 0, "Nothing on deps"
newnode = x_graph.add_kernel_node(a_gpu, numpy.int32(3), block=(4, 4, 1), func=func_plus, dependencies=deps)
-stream_1.update_capture_dependencies([newnode], 1)
+stream_1.update_capture_dependencies([newnode], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)
drv.memcpy_dtoh_async(result, a_gpu, stream_1) # Capture a copy as well.
graph = stream_1.end_capture()
assert graph == x_graph, "Should be the same"
@@ -110,11 +110,11 @@ def test_many_dynamic_params(self):
assert stat == drv.capture_status.ACTIVE, "Capture should be active"
assert len(deps) == 0, "Nothing on deps"
newnode = x_graph.add_kernel_node(a_gpu, numpy.int32(3), block=(4, 4, 1), func=func_plus, dependencies=deps)
-stream_1.update_capture_dependencies([newnode], 1)
+stream_1.update_capture_dependencies([newnode], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)
_, _, x_graph, deps = stream_1.get_capture_info_v2()
assert deps == [newnode], "Call to update_capture_dependencies should set newnode as the only dep"
newnode2 = x_graph.add_kernel_node(b_gpu, numpy.int32(3), block=(4, 4, 1), func=func_plus, dependencies=deps)
-stream_1.update_capture_dependencies([newnode2], 1)
+stream_1.update_capture_dependencies([newnode2], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)

# Static capture
func_times(a_gpu, b_gpu, block=(4, 4, 1), stream=stream_1)

0 comments on commit f232a64

Please sign in to comment.