@@ -33,7 +33,7 @@
from nemo.utils.enum import PrettyStrEnum

try:
-from cuda import cudart
+from cuda.bindings import runtime as cudart

HAVE_CUDA_PYTHON = True
except ImportError:
@@ -462,9 +462,9 @@ def _create_process_batch_kernel(cls):
"""
kernel_string = r"""\
typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;

extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);

extern "C" __global__
void loop_conditional(cudaGraphConditionalHandle handle, const bool *active_mask_any)
{
@@ -573,7 +573,7 @@ def _full_graph_compile(self):
torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
):
self._before_process_batch()
-capture_status, _, graph, _, _ = cu_call(
+capture_status, _, graph, _, _, _ = cu_call(
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
)

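Note: two mechanical changes repeat across every file in this diff. First, the flat `from cuda import cudart` (and `cuda`, `nvrtc`) modules are deprecated in recent cuda-python releases in favor of the namespaced `cuda.bindings` layout: `driver`, `runtime`, and `nvrtc`. Second, `cudaStreamGetCaptureInfo` now returns one extra element, so every unpacking site gains a sixth placeholder. A minimal sketch of both patterns, assuming cuda-python >= 12.3 is installed; the name `edge_data` for the new slot is a guess, not the upstream name:

import torch
from nemo.core.utils.cuda_python_utils import cu_call

try:
    # Namespaced bindings replace the deprecated top-level aliases.
    from cuda.bindings import runtime as cudart

    HAVE_CUDA_PYTHON = True
except ImportError:
    HAVE_CUDA_PYTHON = False

if HAVE_CUDA_PYTHON and torch.cuda.is_available():
    stream = torch.cuda.current_stream().cuda_stream
    # Six outputs instead of five: the extra slot presumably carries the
    # capture edge data introduced alongside cudaStreamGetCaptureInfo_v3.
    capture_status, _, graph, dependencies, edge_data, _ = cu_call(
        cudart.cudaStreamGetCaptureInfo(stream)
    )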
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -38,7 +38,7 @@
from nemo.utils.enum import PrettyStrEnum

try:
-from cuda import cudart
+from cuda.bindings import runtime as cudart

HAVE_CUDA_PYTHON = True
except ImportError:
@@ -852,7 +852,7 @@ def _graph_reinitialize(self, logits, logits_len):
):
self._before_loop()

-capture_status, _, graph, _, _ = cu_call(
+capture_status, _, graph, _, _, _ = cu_call(
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
)
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive
@@ -17,7 +17,7 @@
import torch

try:
-from cuda import cudart
+from cuda.bindings import runtime as cudart

HAVE_CUDA_PYTHON = True
except ImportError:
@@ -205,7 +205,7 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
# Get max sequence length
self.max_out_len_t = self.encoder_output_length.max()

-capture_status, _, graph, _, _ = cu_call(
+capture_status, _, graph, _, _, _ = cu_call(
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.device).cuda_stream)
)
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive
@@ -38,7 +38,7 @@
from nemo.utils.enum import PrettyStrEnum

try:
-from cuda import cudart
+from cuda.bindings import runtime as cudart

HAVE_CUDA_PYTHON = True
except ImportError:
@@ -726,9 +726,9 @@ def _create_loop_body_kernel(cls):
"""
kernel_string = r"""\
typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;

extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);

extern "C" __global__
void loop_conditional(cudaGraphConditionalHandle handle, const bool *active_mask_any)
{
@@ -893,7 +893,7 @@ def _full_graph_compile(self):
torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
):
self._before_loop()
-capture_status, _, graph, _, _ = cu_call(
+capture_status, _, graph, _, _, _ = cu_call(
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
)

@@ -38,7 +38,7 @@
from nemo.utils.enum import PrettyStrEnum

try:
-from cuda import cudart
+from cuda.bindings import runtime as cudart

HAVE_CUDA_PYTHON = True
except ImportError:
@@ -805,9 +805,9 @@ def _create_loop_body_kernel(cls):
"""
kernel_string = r"""\
typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;

extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);

extern "C" __global__
void loop_conditional(cudaGraphConditionalHandle handle, const bool *active_mask_any)
{
@@ -974,7 +974,7 @@ def _full_graph_compile(self):
torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
):
self._before_loop()
-capture_status, _, graph, _, _ = cu_call(
+capture_status, _, graph, _, _, _ = cu_call(
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
)

@@ -32,7 +32,7 @@
from nemo.core.utils.cuda_python_utils import cu_call, run_nvrtc, with_conditional_node

try:
-from cuda import cudart
+from cuda.bindings import runtime as cudart

HAVE_CUDA_PYTHON = True
except ImportError:
@@ -767,9 +767,9 @@ def _create_outer_while_loop_kernel(cls):
"""
kernel_string = r"""\
typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;

extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);

extern "C" __global__
void outer_label_looping_conditional(cudaGraphConditionalHandle handle, const bool *active_mask_any)
{
@@ -786,9 +786,9 @@ def _create_inner_while_loop_kernel(cls):
"""
kernel_string = r"""\
typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;

extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);

extern "C" __global__
void inner_find_non_blank_conditional(cudaGraphConditionalHandle handle, const bool *advance_mask_any)
{
@@ -943,7 +943,7 @@ def _full_graph_compile(self):
):
self._before_outer_loop()

-capture_status, _, graph, _, _ = cu_call(
+capture_status, _, graph, _, _, _ = cu_call(
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
)
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive
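Note: each of these decoders drives its loops the same way: a tiny conditional kernel is compiled at runtime with `run_nvrtc`, and `with_conditional_node` wraps the captured loop body in a CUDA-graph conditional (while) node that the kernel re-arms on every iteration. A hypothetical capture-time usage; `kernel_string`, `args`, `handle`, `device`, and `loop_body` stand in for the decoder's real state, and the exact signature of the helper is truncated in the hunks below:

from nemo.core.utils.cuda_python_utils import run_nvrtc, with_conditional_node

# Compile the conditional kernel once; reuse it for every capture.
kernel = run_nvrtc(kernel_string, b"outer_label_looping_conditional", b"while_conditional.cu")
with with_conditional_node(kernel, args, handle, device=device):
    loop_body()  # captured once, replayed while the conditional stays non-zero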
@@ -32,7 +32,7 @@
from nemo.core.utils.cuda_python_utils import cu_call, run_nvrtc, with_conditional_node

try:
-from cuda import cudart
+from cuda.bindings import runtime as cudart

HAVE_CUDA_PYTHON = True
except ImportError:
@@ -847,9 +847,9 @@ def _create_outer_while_loop_kernel(cls):
"""
kernel_string = r"""\
typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;

extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);

extern "C" __global__
void outer_label_looping_conditional(cudaGraphConditionalHandle handle, const bool *active_mask_any)
{
@@ -866,9 +866,9 @@ def _create_inner_while_loop_kernel(cls):
"""
kernel_string = r"""\
typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;

extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);

extern "C" __global__
void inner_find_non_blank_conditional(cudaGraphConditionalHandle handle, const bool *advance_mask_any)
{
@@ -1026,7 +1026,7 @@ def _full_graph_compile(self):
):
self._before_outer_loop()

-capture_status, _, graph, _, _ = cu_call(
+capture_status, _, graph, _, _, _ = cu_call(
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
)
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive
25 changes: 15 additions & 10 deletions nemo/core/utils/cuda_python_utils.py
@@ -27,11 +27,11 @@ def check_cuda_python_cuda_graphs_conditional_nodes_supported():
raise EnvironmentError("CUDA is not available")

try:
-from cuda import cuda
+from cuda.bindings import driver as cuda
except ImportError:
raise ModuleNotFoundError("No `cuda-python` module. Please do `pip install cuda-python>=12.3`")

-from cuda import __version__ as cuda_python_version
+from cuda.bindings import __version__ as cuda_python_version

if Version(cuda_python_version) < Version("12.3.0"):
raise ImportError(f"Found cuda-python {cuda_python_version}, but at least version 12.3.0 is needed.")
@@ -72,7 +72,9 @@ def assert_drv(err):
"""
Throws an exception if the return value of a cuda-python call is not success.
"""
-from cuda import cuda, cudart, nvrtc
+from cuda.bindings import driver as cuda
+from cuda.bindings import nvrtc
+from cuda.bindings import runtime as cudart

if isinstance(err, cuda.CUresult):
if err != cuda.CUresult.CUDA_SUCCESS:
@@ -92,7 +94,7 @@ def cu_call(f_call_out):
Makes calls to cuda-python's functions inside cuda.cuda more python by throwing an exception
if they return a status which is not cudaSuccess
"""
-from cuda import cudart
+from cuda.bindings import runtime as cudart

error, *others = f_call_out
if error != cudart.cudaError_t.cudaSuccess:
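For reference, `cu_call` splits the status code off the front of the result tuple and raises on anything other than `cudaSuccess`, returning the remaining outputs. A hypothetical usage, assuming cuda-python and a CUDA device are present:

from cuda.bindings import runtime as cudart
from nemo.core.utils.cuda_python_utils import cu_call

# cudaStreamCreate returns (err, stream); cu_call checks and strips err,
# leaving just the new stream handle to unpack.
(stream,) = cu_call(cudart.cudaStreamCreate())
cu_call(cudart.cudaStreamDestroy(stream))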
@@ -111,10 +113,11 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi
to decide both whether to enter the loop, and also whether to
execute the next iteration of the loop).
"""
-from cuda import __version__ as cuda_python_version
-from cuda import cuda, cudart
+from cuda.bindings import __version__ as cuda_python_version
+from cuda.bindings import driver as cuda
+from cuda.bindings import runtime as cudart

-capture_status, _, graph, _, _ = cu_call(
+capture_status, _, graph, _, _, _ = cu_call(
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=device).cuda_stream)
)
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive
@@ -133,7 +136,7 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi
0,
)

-capture_status, _, graph, dependencies, _ = cu_call(
+capture_status, _, graph, dependencies, _, _ = cu_call(
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=device).cuda_stream)
)
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive
@@ -157,13 +160,14 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi
# Use driver API here because of bug in cuda-python runtime API: https://github.com/NVIDIA/cuda-python/issues/55
# TODO: Change call to this after fix goes in (and we bump minimum cuda-python version to 12.4.0):
# node, = cu_call(cudart.cudaGraphAddNode(graph, dependencies, len(dependencies), driver_params))
-(node,) = cu_call(cuda.cuGraphAddNode(graph, dependencies, len(dependencies), driver_params))
+(node,) = cu_call(cuda.cuGraphAddNode(graph, dependencies, None, len(dependencies), driver_params))
body_graph = driver_params.conditional.phGraph_out[0]

cu_call(
cudart.cudaStreamUpdateCaptureDependencies(
torch.cuda.current_stream(device=device).cuda_stream,
[node],
+None,
1,
cudart.cudaStreamUpdateCaptureDependenciesFlags.cudaStreamSetCaptureDependencies,
)
@@ -194,7 +198,8 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi


def run_nvrtc(kernel_string: str, kernel_name: bytes, program_name: bytes):
-from cuda import cuda, nvrtc
+from cuda.bindings import driver as cuda
+from cuda.bindings import nvrtc

err, prog = nvrtc.nvrtcCreateProgram(str.encode(kernel_string), program_name, 0, [], [])
assert_drv(err)
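For context, `run_nvrtc` is cut off above; after `nvrtcCreateProgram` it presumably continues with the standard NVRTC sequence of compiling the program and extracting PTX. A minimal standalone sketch of that sequence, assuming a trivial kernel and no extra compile options:

from cuda.bindings import nvrtc
from nemo.core.utils.cuda_python_utils import assert_drv

kernel = r'extern "C" __global__ void noop() {}'
err, prog = nvrtc.nvrtcCreateProgram(str.encode(kernel), b"noop.cu", 0, [], [])
assert_drv(err)
(err,) = nvrtc.nvrtcCompileProgram(prog, 0, [])
assert_drv(err)
err, ptx_size = nvrtc.nvrtcGetPTXSize(prog)
assert_drv(err)
ptx = b" " * ptx_size  # buffer the PTX is copied into
(err,) = nvrtc.nvrtcGetPTX(prog, ptx)
assert_drv(err)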
37 changes: 16 additions & 21 deletions nemo/core/utils/numba_utils.py
@@ -110,7 +110,7 @@ def numba_cuda_is_supported(min_version: str) -> bool:
"""
Tests if an appropriate version of numba is installed, and if it is,
if cuda is supported properly within it.

Args:
min_version: The minimum version of numba that is required.

@@ -127,25 +127,20 @@ def numba_cuda_is_supported(min_version: str) -> bool:
if module_available is True:
from numba import cuda

-# this method first arrived in 0.53, and that's the minimum version required
-if hasattr(cuda, 'is_supported_version'):
-try:
-cuda_available = cuda.is_available()
-if cuda_available:
-cuda_compatible = cuda.is_supported_version()
-else:
-cuda_compatible = False
-
-if is_numba_compat_strict():
-return cuda_available and cuda_compatible
-else:
-return cuda_available
-
-except OSError:
-# dlopen(libcudart.dylib) might fail if CUDA was never installed in the first place.
-return False
-else:
-# assume cuda is supported, but it may fail due to CUDA incompatibility
+try:
+cuda_available = cuda.is_available()
+if cuda_available:
+cuda_compatible = cuda.cudadrv.runtime.get_version()[0] == 13
+else:
+cuda_compatible = False
+
+if is_numba_compat_strict():
+return cuda_available and cuda_compatible
+else:
+return cuda_available
+
+except Exception:
+# dlopen(libcudart.dylib) might fail if CUDA was never installed in the first place.
+return False

else:
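The replacement drops the `hasattr(cuda, 'is_supported_version')` probe, which targets an API that the split-out numba-cuda package no longer carries, and instead treats numba as CUDA-compatible only when it reports a CUDA 13 runtime. A hypothetical standalone version of that check:

from numba import cuda

def numba_sees_cuda_13() -> bool:
    # get_version() reports the runtime as a (major, minor) tuple; the new
    # requirements pin numba-cuda's cu13 extra, hence the major == 13 test.
    try:
        if not cuda.is_available():
            return False
        major, _minor = cuda.cudadrv.runtime.get_version()
        return major == 13
    except Exception:
        # Probing can fail outright (e.g. dlopen of libcudart) when CUDA was
        # never installed; mirror the code above and report unsupported.
        return False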
@@ -188,7 +183,7 @@ def is_numba_cuda_fp16_supported(return_reason: bool = False) -> Union[bool, Tup
def skip_numba_cuda_test_if_unsupported(min_version: str):
"""
Helper method to skip pytest test case if numba cuda is not supported.

Args:
min_version: The minimum version of numba that is required.
"""
5 changes: 3 additions & 2 deletions requirements/requirements.txt
@@ -1,6 +1,7 @@
fsspec==2024.12.0
huggingface_hub>=0.24
-numba
+numba ; platform_system == 'Darwin'
+numba-cuda[cu13]>=0.20.0 ; platform_system != 'Darwin'
numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
numpy>=1.22
onnx>=1.7.0
@@ -15,4 +16,4 @@ text-unidecode
torch
tqdm>=4.41.0
wget
-wrapt
\ No newline at end of file
+wrapt
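The requirements change keeps plain `numba` only on macOS, where there is no CUDA toolchain to target, and moves every other platform to the `numba-cuda` package with its CUDA 13 extra, equivalent to:

pip install "numba-cuda[cu13]>=0.20.0"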