Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add support for CUDA Graphs. #343

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add support for CUDA Graphs.
gfokkema committed Jan 15, 2022

Verified

This commit was signed with the committer’s verified signature.
sparrc Cameron Sparr
commit 57ca7e8879e2d1ce73d57232d81af883722eea85
61 changes: 61 additions & 0 deletions src/cpp/cuda.hpp
Original file line number Diff line number Diff line change
@@ -990,6 +990,7 @@ namespace pycuda

// {{{ stream
class event;
class graph;

class stream : public boost::noncopyable, public context_dependent
{
@@ -1022,6 +1023,10 @@ namespace pycuda
#if CUDAPP_CUDA_VERSION >= 3020
void wait_for_event(const event &evt);
#endif
#if CUDAPP_CUDA_VERSION >= 10000
void begin_capture(CUstreamCaptureMode mode);
graph *end_capture();
#endif

bool is_done() const
{
@@ -1042,6 +1047,62 @@ namespace pycuda

// }}}

// {{{ graph
#if CUDAPP_CUDA_VERSION >= 10000
class graph_exec : public boost::noncopyable, public context_dependent
{
private:
CUgraphExec m_exec;

public:
graph_exec(CUgraphExec exec)
: m_exec(exec)
{ }

void launch(py::object stream_py)
{
PYCUDA_PARSE_STREAM_PY;
CUDAPP_CALL_GUARDED(cuGraphLaunch, (m_exec, s_handle))
}
};

class graph : public boost::noncopyable, public context_dependent
{
private:
CUgraph m_graph;

public:
graph(CUgraph graph)
: m_graph(graph)
{ }

graph_exec *instance()
{
CUgraphExec instance;
CUDAPP_CALL_GUARDED(cuGraphInstantiate, (&instance, m_graph, NULL, NULL, 0))
return new graph_exec(instance);
}

void debug_dot_print(std::string path)
{
CUDAPP_CALL_GUARDED(cuGraphDebugDotPrint, (m_graph, path.c_str(), 0))
}
};

inline void stream::begin_capture(CUstreamCaptureMode mode = CU_STREAM_CAPTURE_MODE_GLOBAL)
{
CUDAPP_CALL_GUARDED(cuStreamBeginCapture, (m_stream, mode));
}

inline graph *stream::end_capture()
{
CUgraph result;
CUDAPP_CALL_GUARDED(cuStreamEndCapture, (m_stream, &result))
return new graph(result);
}
#endif
// }}}

// {{{ array
class array : public boost::noncopyable, public context_dependent
{
35 changes: 35 additions & 0 deletions src/wrapper/wrap_cudadrv.cpp
Original file line number Diff line number Diff line change
@@ -1196,6 +1196,13 @@ BOOST_PYTHON_MODULE(_driver)
// }}}

// {{{ stream
#if CUDAPP_CUDA_VERSION >= 10000
py::enum_<CUstreamCaptureMode>("capture_mode")
.value("GLOBAL", CU_STREAM_CAPTURE_MODE_GLOBAL)
.value("THREAD_LOCAL", CU_STREAM_CAPTURE_MODE_THREAD_LOCAL)
.value("RELAXED", CU_STREAM_CAPTURE_MODE_RELAXED)
;
#endif
{
typedef stream cl;
py::class_<cl, boost::noncopyable, shared_ptr<cl> >
@@ -1204,12 +1211,40 @@ BOOST_PYTHON_MODULE(_driver)
.DEF_SIMPLE_METHOD(is_done)
#if CUDAPP_CUDA_VERSION >= 3020
.DEF_SIMPLE_METHOD(wait_for_event)
#endif
#if CUDAPP_CUDA_VERSION >= 10000
.def("begin_capture", &cl::begin_capture,
py::arg("capture_mode") = CU_STREAM_CAPTURE_MODE_GLOBAL)
.def("end_capture", &cl::end_capture,
py::return_value_policy<py::manage_new_object>())
#endif
.add_property("handle", &cl::handle_int)
;
}
// }}}

// {{{ graph
#if CUDAPP_CUDA_VERSION >= 10000
;
{
typedef graph_exec cl;
py::class_<cl, boost::noncopyable>("GraphExec", py::no_init)
.def("launch", &cl::launch,
py::arg("stream")=py::object())
;
}

{
typedef graph cl;
py::class_<cl, boost::noncopyable>("Graph", py::no_init)
.def("instance", &cl::instance,
py::return_value_policy<py::manage_new_object>())
.DEF_SIMPLE_METHOD(debug_dot_print)
;
}
#endif
// }}}

// {{{ module
{
typedef module cl;