Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cudax/include/cuda/experimental/__stf/graph/graph_ctx.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,11 @@ public:
size_t nnodes;

cuda_safe_call(cudaGraphGetNodes(*g, nullptr, &nnodes));
#if _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphGetEdges(*g, nullptr, nullptr, nullptr, &nedges));
#else // _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphGetEdges(*g, nullptr, nullptr, &nedges));
#endif // _CCCL_CTK_AT_LEAST(13, 0)

auto& state = this->state();

Expand Down Expand Up @@ -499,7 +503,11 @@ public:
cuda_safe_call(cudaGraphGetNodes(g, nullptr, &numNodes));

size_t numEdges;
#if _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphGetEdges(g, nullptr, nullptr, nullptr, &numEdges));
#else // _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphGetEdges(g, nullptr, nullptr, &numEdges));
#endif // _CCCL_CTK_AT_LEAST(13, 0)

cuuint64_t mem_attr;
cuda_safe_call(cudaDeviceGetGraphMemAttribute(0, cudaGraphMemAttrUsedMemHigh, &mem_attr));
Expand Down Expand Up @@ -617,7 +625,11 @@ private:
size_t nnodes;

cuda_safe_call(cudaGraphGetNodes(g, nullptr, &nnodes));
#if _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphGetEdges(g, nullptr, nullptr, nullptr, &nedges));
#else // _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphGetEdges(g, nullptr, nullptr, &nedges));
#endif // _CCCL_CTK_AT_LEAST(13, 0)

cudaGraphExec_t local_exec_graph = nullptr;

Expand Down
19 changes: 19 additions & 0 deletions cudax/include/cuda/experimental/__stf/graph/graph_task.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -137,20 +137,34 @@ public:
#ifndef NDEBUG
// Ensure the node does not have dependencies yet
size_t num_deps;
# if _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphNodeGetDependencies(node, nullptr, nullptr, &num_deps));
# else // _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphNodeGetDependencies(node, nullptr, &num_deps));
# endif // _CCCL_CTK_AT_LEAST(13, 0)
assert(num_deps == 0);

// Ensure there are no output dependencies either (or we could not
// add input dependencies later)
size_t num_deps_out;
# if _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphNodeGetDependentNodes(node, nullptr, nullptr, &num_deps_out));
# else // _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphNodeGetDependentNodes(node, nullptr, &num_deps_out));
# endif // _CCCL_CTK_AT_LEAST(13, 0)
assert(num_deps_out == 0);
#endif

// Repeat node as many times as there are input dependencies
::std::vector<cudaGraphNode_t> out_array(ready_dependencies.size(), node);
#if _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphAddDependencies(
ctx_graph, ready_dependencies.data(), out_array.data(), nullptr, ready_dependencies.size()));
#else // _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphAddDependencies(
ctx_graph, ready_dependencies.data(), out_array.data(), ready_dependencies.size()));
#endif // _CCCL_CTK_AT_LEAST(13, 0)

auto gnp = reserved::graph_event(node, stage, ctx_graph);
gnp->set_symbol(ctx, "done " + get_symbol());
/* This node is now the output dependency of the task */
Expand All @@ -161,8 +175,13 @@ public:
{
// First node depends on ready_dependencies
::std::vector<cudaGraphNode_t> out_array(ready_dependencies.size(), chained_task_nodes[0]);
#if _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphAddDependencies(
ctx_graph, ready_dependencies.data(), out_array.data(), nullptr, ready_dependencies.size()));
#else // _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(
cudaGraphAddDependencies(ctx_graph, ready_dependencies.data(), out_array.data(), ready_dependencies.size()));
#endif // _CCCL_CTK_AT_LEAST(13, 0)

// Overall the task depends on the completion of the last node
auto gnp = reserved::graph_event(chained_task_nodes.back(), stage, ctx_graph);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,11 @@ private:
kernel_descs[i].launch_in_graph(chain[i], g);
if (i > 0)
{
#if _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphAddDependencies(g, &chain[i - 1], &chain[i], nullptr, 1));
#else // _CCCL_CTK_AT_LEAST(13, 0)
cuda_safe_call(cudaGraphAddDependencies(g, &chain[i - 1], &chain[i], 1));
#endif // _CCCL_CTK_AT_LEAST(13, 0)
}
}
}
Expand Down