Skip to content

Commit

Permalink
Make cuda-python 12.3.0 workaround specific to that version.
Browse files Browse the repository at this point in the history
It will crash in cuda-python 12.4.0.
  • Loading branch information
galv committed Feb 22, 2024
1 parent 118c01a commit dc3d1ff
Showing 1 changed file with 11 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# limitations under the License.

import contextlib
from packaging.version import Version

import numpy as np
import torch

try:
from cuda import cuda, cudart, nvrtc
from cuda import __version__ as cuda_python_version

HAVE_CUDA_PYTHON = True
except ImportError:
Expand Down Expand Up @@ -144,13 +146,19 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi
driver_params.conditional.handle = while_loop_conditional_handle
driver_params.conditional.type = cuda.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE
driver_params.conditional.size = 1
# Work around until https://github.com/NVIDIA/cuda-python/issues/55 is fixed
driver_params.conditional.phGraph_out = [cuda.CUgraph()]
if Version(cuda_python_version) == Version("12.3.0"):
# Work around for https://github.com/NVIDIA/cuda-python/issues/55
# Originally, cuda-python version 12.3.0 failed to allocate phGraph_out
# on its own.
# This bug is fixed in cuda-python version 12.4.0. In fact, we can
# no longer write to phGraph_out in cuda-python 12.4.0, so we must
# condition on the version number.
driver_params.conditional.phGraph_out = [cuda.CUgraph()]
(ctx,) = cu_call(cuda.cuCtxGetCurrent())
driver_params.conditional.ctx = ctx

# Use driver API here because of bug in cuda-python runtime API: https://github.com/NVIDIA/cuda-python/issues/55
# TODO: Change call to this after fix goes in:
# TODO: Change call to this after fix goes in (and we bump minimum cuda-python version to 12.4.0):
# node, = cu_call(cudart.cudaGraphAddNode(graph, dependencies, len(dependencies), driver_params))
(node,) = cu_call(cuda.cuGraphAddNode(graph, dependencies, len(dependencies), driver_params))
body_graph = driver_params.conditional.phGraph_out[0]
Expand Down

0 comments on commit dc3d1ff

Please sign in to comment.