[WIP] Add support for CUDA Graphs. #343

Draft · wants to merge 3 commits into main
66 changes: 66 additions & 0 deletions examples/demo_graph.py
@@ -0,0 +1,66 @@
# Demonstrates capturing work from two streams into a CUDA graph and replaying it.
import pycuda.driver as cuda
import pycuda.autoinit # noqa
from pycuda.compiler import SourceModule

mod = SourceModule("""
__global__ void plus(float *a, int num)
{
    int idx = threadIdx.x + threadIdx.y*4;
    a[idx] += num;
}

__global__ void times(float *a, float *b)
{
    int idx = threadIdx.x + threadIdx.y*4;
    a[idx] *= b[idx];
}
""")
func_plus = mod.get_function("plus")
func_times = mod.get_function("times")

import numpy
a = numpy.zeros((4, 4)).astype(numpy.float32)
a_gpu = cuda.mem_alloc_like(a)
b = numpy.zeros((4, 4)).astype(numpy.float32)
b_gpu = cuda.mem_alloc_like(b)
result = numpy.zeros_like(b)

# begin graph capture, pull stream_2 into it as a dependency
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cross-stream-dependencies
stream_1 = cuda.Stream()
stream_2 = cuda.Stream()
stream_1.begin_capture()
event_init = cuda.Event()
event_a = cuda.Event()
event_b = cuda.Event()

event_init.record(stream_1)
stream_2.wait_for_event(event_init)

cuda.memcpy_htod_async(a_gpu, a, stream_1)
func_plus(a_gpu, numpy.int32(2), block=(4, 4, 1), stream=stream_1)
event_a.record(stream_1)

cuda.memcpy_htod_async(b_gpu, b, stream_2)
func_plus(b_gpu, numpy.int32(3), block=(4, 4, 1), stream=stream_2)
event_b.record(stream_2)

stream_1.wait_for_event(event_a)
stream_1.wait_for_event(event_b)
func_times(a_gpu, b_gpu, block=(4, 4, 1), stream=stream_1)
cuda.memcpy_dtoh_async(result, a_gpu, stream_1)

graph = stream_1.end_capture()
graph.debug_dot_print("test.dot") # print dotfile of graph
instance = graph.instance()

# using a separate stream to launch the graph; this is not strictly necessary
stream_graph = cuda.Stream()
instance.launch(stream_graph)
# wait for the asynchronous copy into `result` to finish before reading it on the host
stream_graph.synchronize()

print("original arrays:")
print(a)
print(b)
print("(0+2)x(0+3) = 6, using a kernel graph of 3 kernels:")
print(result)
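The wrapper changes below also expose the driver's capture-mode enum as `capture_mode`, so a capture can be limited to the calling thread instead of the default global mode. A minimal single-stream sketch of that usage, assuming a build of this branch; the `scale` kernel, array size, and loop count are illustrative only:

```python
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # noqa
from pycuda.compiler import SourceModule

mod = SourceModule("""
__global__ void scale(float *x, float f)
{
    x[threadIdx.x] *= f;
}
""")
scale = mod.get_function("scale")

x = np.arange(32, dtype=np.float32)
x_gpu = cuda.mem_alloc_like(x)
out = np.empty_like(x)

s = cuda.Stream()
# THREAD_LOCAL keeps work issued by other host threads out of this capture
s.begin_capture(cuda.capture_mode.THREAD_LOCAL)
cuda.memcpy_htod_async(x_gpu, x, s)
scale(x_gpu, np.float32(2.0), block=(32, 1, 1), stream=s)
cuda.memcpy_dtoh_async(out, x_gpu, s)
graph = s.end_capture()
instance = graph.instance()

# the instantiated graph can be relaunched; each launch re-runs
# copy-in, kernel, copy-out, so `out` holds 2*x after every launch
for _ in range(3):
    instance.launch(s)
s.synchronize()
print(out)
```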
65 changes: 63 additions & 2 deletions src/cpp/cuda.hpp
@@ -159,9 +159,9 @@ typedef Py_ssize_t PYCUDA_BUFFER_SIZE_T;
<< std::endl; \
}
#define CUDAPP_CATCH_CLEANUP_ON_DEAD_CONTEXT(TYPE) \
catch (pycuda::cannot_activate_out_of_thread_context) \
catch (pycuda::cannot_activate_out_of_thread_context const&) \
{ } \
catch (pycuda::cannot_activate_dead_context) \
catch (pycuda::cannot_activate_dead_context const&) \
{ \
/* PyErr_Warn( \
PyExc_UserWarning, #TYPE " in dead context was implicitly cleaned up");*/ \
@@ -990,6 +990,7 @@ namespace pycuda

// {{{ stream
class event;
class graph;

class stream : public boost::noncopyable, public context_dependent
{
@@ -1022,6 +1023,10 @@
#if CUDAPP_CUDA_VERSION >= 3020
void wait_for_event(const event &evt);
#endif
#if CUDAPP_CUDA_VERSION >= 10000
void begin_capture(CUstreamCaptureMode mode);
graph *end_capture();
#endif

bool is_done() const
{
@@ -1042,6 +1047,62 @@

// }}}

  // {{{ graph
#if CUDAPP_CUDA_VERSION >= 10000
  class graph_exec : public boost::noncopyable, public context_dependent
  {
    private:
      CUgraphExec m_exec;

    public:
      graph_exec(CUgraphExec exec)
        : m_exec(exec)
      { }

      void launch(py::object stream_py)
      {
        PYCUDA_PARSE_STREAM_PY;
        CUDAPP_CALL_GUARDED(cuGraphLaunch, (m_exec, s_handle))
      }
  };

  class graph : public boost::noncopyable, public context_dependent
  {
    private:
      CUgraph m_graph;

    public:
      graph(CUgraph graph)
        : m_graph(graph)
      { }

      graph_exec *instance()
      {
        CUgraphExec instance;
        CUDAPP_CALL_GUARDED(cuGraphInstantiate, (&instance, m_graph, NULL, NULL, 0))
        return new graph_exec(instance);
      }

      void debug_dot_print(std::string path)
      {
        CUDAPP_CALL_GUARDED(cuGraphDebugDotPrint, (m_graph, path.c_str(), 0))
      }
  };

  inline void stream::begin_capture(CUstreamCaptureMode mode = CU_STREAM_CAPTURE_MODE_GLOBAL)
  {
    CUDAPP_CALL_GUARDED(cuStreamBeginCapture, (m_stream, mode));
  }

  inline graph *stream::end_capture()
  {
    CUgraph result;
    CUDAPP_CALL_GUARDED(cuStreamEndCapture, (m_stream, &result))
    return new graph(result);
  }
#endif
  // }}}

// {{{ array
class array : public boost::noncopyable, public context_dependent
{
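`graph::debug_dot_print` forwards to `cuGraphDebugDotPrint` with flags 0, writing a Graphviz description of the captured graph to the given path. A small sketch of rendering that output, assuming the Graphviz `dot` binary is installed; the file names are arbitrary and `graph` is the object returned by `Stream.end_capture()` in the demo above:

```python
import subprocess

# write a DOT description of the captured graph, then render it to PNG
graph.debug_dot_print("capture.dot")
subprocess.run(["dot", "-Tpng", "capture.dot", "-o", "capture.png"], check=True)
```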
35 changes: 35 additions & 0 deletions src/wrapper/wrap_cudadrv.cpp
@@ -1196,6 +1196,13 @@ BOOST_PYTHON_MODULE(_driver)
// }}}

// {{{ stream
#if CUDAPP_CUDA_VERSION >= 10000
  py::enum_<CUstreamCaptureMode>("capture_mode")
    .value("GLOBAL", CU_STREAM_CAPTURE_MODE_GLOBAL)
    .value("THREAD_LOCAL", CU_STREAM_CAPTURE_MODE_THREAD_LOCAL)
    .value("RELAXED", CU_STREAM_CAPTURE_MODE_RELAXED)
    ;
#endif
  {
    typedef stream cl;
    py::class_<cl, boost::noncopyable, shared_ptr<cl> >
@@ -1204,12 +1211,40 @@
      .DEF_SIMPLE_METHOD(is_done)
#if CUDAPP_CUDA_VERSION >= 3020
      .DEF_SIMPLE_METHOD(wait_for_event)
#endif
#if CUDAPP_CUDA_VERSION >= 10000
      .def("begin_capture", &cl::begin_capture,
          py::arg("capture_mode") = CU_STREAM_CAPTURE_MODE_GLOBAL)
      .def("end_capture", &cl::end_capture,
          py::return_value_policy<py::manage_new_object>())
#endif
      .add_property("handle", &cl::handle_int)
      ;
  }
// }}}

  // {{{ graph
#if CUDAPP_CUDA_VERSION >= 10000
  {
    typedef graph_exec cl;
    py::class_<cl, boost::noncopyable>("GraphExec", py::no_init)
      .def("launch", &cl::launch,
          py::arg("stream")=py::object())
      ;
  }

  {
    typedef graph cl;
    py::class_<cl, boost::noncopyable>("Graph", py::no_init)
      .def("instance", &cl::instance,
          py::return_value_policy<py::manage_new_object>())
      .DEF_SIMPLE_METHOD(debug_dot_print)
      ;
  }
#endif
  // }}}

// {{{ module
{
typedef module cl;
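`GraphExec.launch` is bound with a `stream` argument defaulting to None, which the existing `PYCUDA_PARSE_STREAM_PY` macro maps to the default (null) stream. A short sketch continuing from the demo above, where `instance` and `stream_graph` were created; the calls themselves are illustrative:

```python
# launch on the default (null) stream by omitting the argument ...
instance.launch()
# ... or on an explicit stream, as in the demo
instance.launch(stream_graph)
stream_graph.synchronize()
```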