diff --git a/nemo/collections/asr/parts/submodules/ctc_batched_beam_decoding.py b/nemo/collections/asr/parts/submodules/ctc_batched_beam_decoding.py
index 3df27cb03cdf..cdcbd4e418f7 100644
--- a/nemo/collections/asr/parts/submodules/ctc_batched_beam_decoding.py
+++ b/nemo/collections/asr/parts/submodules/ctc_batched_beam_decoding.py
@@ -33,7 +33,7 @@ from nemo.utils.enum import PrettyStrEnum
 
 try:
-    from cuda import cudart
+    from cuda.bindings import runtime as cudart
 
     HAVE_CUDA_PYTHON = True
 except ImportError:
@@ -462,9 +462,9 @@ def _create_process_batch_kernel(cls):
         """
         kernel_string = r"""\
     typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
-    
+
     extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);
-    
+
     extern "C" __global__ void loop_conditional(cudaGraphConditionalHandle handle, const bool *active_mask_any)
     {
@@ -573,7 +573,7 @@ def _full_graph_compile(self):
             torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
         ):
             self._before_process_batch()
-            capture_status, _, graph, _, _ = cu_call(
+            capture_status, _, graph, _, _, _ = cu_call(
                 cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
             )
diff --git a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
index 051cc8876919..0ed10bfe6ce1 100644
--- a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
+++ b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -38,7 +38,7 @@ from nemo.utils.enum import PrettyStrEnum
 
 try:
-    from cuda import cudart
+    from cuda.bindings import runtime as cudart
 
     HAVE_CUDA_PYTHON = True
 except ImportError:
@@ -852,7 +852,7 @@ def _graph_reinitialize(self, logits, logits_len):
         ):
             self._before_loop()
 
-            capture_status, _, graph, _, _ = cu_call(
+            capture_status, _, graph, _, _, _ = cu_call(
                 cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
             )
             assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive
diff --git a/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py
index d4ed2c0bbec4..dff6c2fafd37 100644
--- a/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py
+++ b/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py
@@ -17,7 +17,7 @@
 import torch
 
 try:
-    from cuda import cudart
+    from cuda.bindings import runtime as cudart
 
     HAVE_CUDA_PYTHON = True
 except ImportError:
@@ -205,7 +205,7 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
         # Get max sequence length
         self.max_out_len_t = self.encoder_output_length.max()
 
-        capture_status, _, graph, _, _ = cu_call(
+        capture_status, _, graph, _, _, _ = cu_call(
            cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.device).cuda_stream)
         )
         assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive
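Note: the same one-line import change repeats across every decoder module in this PR. A minimal sketch of the pattern, for reference — `cuda.bindings.runtime` is the real cuda-python module path, while the `cudaDriverGetVersion` call is just an example API use I am adding here, not taken from the patch:

```python
# cuda-python 13 ships its low-level bindings under `cuda.bindings`;
# the old top-level alias (`from cuda import cudart`) no longer imports there.
try:
    from cuda.bindings import runtime as cudart

    HAVE_CUDA_PYTHON = True
except ImportError:
    HAVE_CUDA_PYTHON = False

if HAVE_CUDA_PYTHON:
    # Example runtime-API call: cuda-python returns (error_code, result).
    err, driver_version = cudart.cudaDriverGetVersion()
    assert err == cudart.cudaError_t.cudaSuccess
    print(f"CUDA driver version: {driver_version}")
```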
diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py
index b9df336e88a3..9faa83cde7e8 100644
--- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py
+++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py
@@ -38,7 +38,7 @@ from nemo.utils.enum import PrettyStrEnum
 
 try:
-    from cuda import cudart
+    from cuda.bindings import runtime as cudart
 
     HAVE_CUDA_PYTHON = True
 except ImportError:
@@ -726,9 +726,9 @@ def _create_loop_body_kernel(cls):
         """
         kernel_string = r"""\
     typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
-    
+
     extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);
-    
+
     extern "C" __global__ void loop_conditional(cudaGraphConditionalHandle handle, const bool *active_mask_any)
     {
@@ -893,7 +893,7 @@ def _full_graph_compile(self):
             torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
         ):
             self._before_loop()
-            capture_status, _, graph, _, _ = cu_call(
+            capture_status, _, graph, _, _, _ = cu_call(
                 cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
             )
diff --git a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py
index e4338c2420df..c30d80853e81 100644
--- a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py
+++ b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py
@@ -38,7 +38,7 @@ from nemo.utils.enum import PrettyStrEnum
 
 try:
-    from cuda import cudart
+    from cuda.bindings import runtime as cudart
 
     HAVE_CUDA_PYTHON = True
 except ImportError:
@@ -805,9 +805,9 @@ def _create_loop_body_kernel(cls):
         """
         kernel_string = r"""\
     typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
-    
+
     extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);
-    
+
     extern "C" __global__ void loop_conditional(cudaGraphConditionalHandle handle, const bool *active_mask_any)
     {
@@ -974,7 +974,7 @@ def _full_graph_compile(self):
             torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
         ):
             self._before_loop()
-            capture_status, _, graph, _, _ = cu_call(
+            capture_status, _, graph, _, _, _ = cu_call(
                 cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
            )
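Note: the kernel-string hunks above are whitespace-only (trailing spaces trimmed), and the hunk context cuts off at the opening brace. For orientation, a sketch of the complete embedded NVRTC string these modules compile — the one-line body shown here is my reconstruction of how these conditional kernels set the graph-conditional handle from a device-side flag, hedged rather than quoted from the patch:

```python
# Compiled at runtime via run_nvrtc(); the kernel copies a device-side bool
# into the CUDA-graph conditional handle that drives the while-loop node.
kernel_string = r"""\
typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;

extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);

extern "C" __global__ void loop_conditional(cudaGraphConditionalHandle handle, const bool *active_mask_any)
{
    cudaGraphSetConditional(handle, *active_mask_any);
}
"""
```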
diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py
index a933a802563a..31677e16e2cf 100644
--- a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py
+++ b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py
@@ -32,7 +32,7 @@ from nemo.core.utils.cuda_python_utils import cu_call, run_nvrtc, with_conditional_node
 
 try:
-    from cuda import cudart
+    from cuda.bindings import runtime as cudart
 
     HAVE_CUDA_PYTHON = True
 except ImportError:
@@ -767,9 +767,9 @@ def _create_outer_while_loop_kernel(cls):
         """
         kernel_string = r"""\
     typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
-    
+
     extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);
-    
+
     extern "C" __global__ void outer_label_looping_conditional(cudaGraphConditionalHandle handle, const bool *active_mask_any)
     {
@@ -786,9 +786,9 @@ def _create_inner_while_loop_kernel(cls):
         """
         kernel_string = r"""\
     typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
-    
+
     extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);
-    
+
     extern "C" __global__ void inner_find_non_blank_conditional(cudaGraphConditionalHandle handle, const bool *advance_mask_any)
     {
@@ -943,7 +943,7 @@ def _full_graph_compile(self):
         ):
             self._before_outer_loop()
 
-            capture_status, _, graph, _, _ = cu_call(
+            capture_status, _, graph, _, _, _ = cu_call(
                 cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
             )
             assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive
diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py b/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py
index 12e9ec9d98c2..7b48e3137df8 100644
--- a/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py
+++ b/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py
@@ -32,7 +32,7 @@ from nemo.core.utils.cuda_python_utils import cu_call, run_nvrtc, with_conditional_node
 
 try:
-    from cuda import cudart
+    from cuda.bindings import runtime as cudart
 
     HAVE_CUDA_PYTHON = True
 except ImportError:
@@ -847,9 +847,9 @@ def _create_outer_while_loop_kernel(cls):
         """
         kernel_string = r"""\
     typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
-    
+
     extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);
-    
+
     extern "C" __global__ void outer_label_looping_conditional(cudaGraphConditionalHandle handle, const bool *active_mask_any)
     {
@@ -866,9 +866,9 @@ def _create_inner_while_loop_kernel(cls):
         """
         kernel_string = r"""\
     typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
-    
+
     extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);
-    
+
     extern "C" __global__ void inner_find_non_blank_conditional(cudaGraphConditionalHandle handle, const bool *advance_mask_any)
     {
@@ -1026,7 +1026,7 @@ def _full_graph_compile(self):
         ):
             self._before_outer_loop()
 
-            capture_status, _, graph, _, _ = cu_call(
+            capture_status, _, graph, _, _, _ = cu_call(
                 cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
             )
             assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive
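Note: every `capture_status, _, graph, _, _, _` change in this PR has the same cause — under the CUDA 13 bindings, `cudaStreamGetCaptureInfo` returns one additional field (per-dependency edge data, following the `_v3` variant of the runtime API), so `cu_call` yields six values instead of five. A sketch of the raw call; the field names here are my labels inferred from the unpacking in the patch, and the call is safe outside capture, where it simply reports an inactive status:

```python
import torch
from cuda.bindings import runtime as cudart

stream = torch.cuda.current_stream()
# Tuple layout (inferred): (err, status, capture_id, graph, dependencies,
# edge_data, num_dependencies) — edge_data is the slot new in CUDA 13.
err, status, capture_id, graph, deps, edge_data, num_deps = cudart.cudaStreamGetCaptureInfo(
    stream.cuda_stream
)
assert err == cudart.cudaError_t.cudaSuccess
print(status)  # cudaStreamCaptureStatusNone unless called during graph capture
```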
""" - from cuda import cuda, cudart, nvrtc + from cuda.bindings import driver as cuda + from cuda.bindings import nvrtc + from cuda.bindings import runtime as cudart if isinstance(err, cuda.CUresult): if err != cuda.CUresult.CUDA_SUCCESS: @@ -92,7 +94,7 @@ def cu_call(f_call_out): Makes calls to cuda-python's functions inside cuda.cuda more python by throwing an exception if they return a status which is not cudaSuccess """ - from cuda import cudart + from cuda.bindings import runtime as cudart error, *others = f_call_out if error != cudart.cudaError_t.cudaSuccess: @@ -111,10 +113,11 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi to decide both whether to enter the loop, and also whether to execute the next iteration of the loop). """ - from cuda import __version__ as cuda_python_version - from cuda import cuda, cudart + from cuda.bindings import __version__ as cuda_python_version + from cuda.bindings import driver as cuda + from cuda.bindings import runtime as cudart - capture_status, _, graph, _, _ = cu_call( + capture_status, _, graph, _, _, _ = cu_call( cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=device).cuda_stream) ) assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive @@ -133,7 +136,7 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi 0, ) - capture_status, _, graph, dependencies, _ = cu_call( + capture_status, _, graph, dependencies, _, _ = cu_call( cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=device).cuda_stream) ) assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive @@ -157,13 +160,14 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi # Use driver API here because of bug in cuda-python runtime API: https://github.com/NVIDIA/cuda-python/issues/55 # TODO: Change call to this after fix goes in (and we bump minimum cuda-python version to 12.4.0): # node, = cu_call(cudart.cudaGraphAddNode(graph, dependencies, len(dependencies), driver_params)) - (node,) = cu_call(cuda.cuGraphAddNode(graph, dependencies, len(dependencies), driver_params)) + (node,) = cu_call(cuda.cuGraphAddNode(graph, dependencies, None, len(dependencies), driver_params)) body_graph = driver_params.conditional.phGraph_out[0] cu_call( cudart.cudaStreamUpdateCaptureDependencies( torch.cuda.current_stream(device=device).cuda_stream, [node], + None, 1, cudart.cudaStreamUpdateCaptureDependenciesFlags.cudaStreamSetCaptureDependencies, ) @@ -194,7 +198,8 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi def run_nvrtc(kernel_string: str, kernel_name: bytes, program_name: bytes): - from cuda import cuda, nvrtc + from cuda.bindings import driver as cuda + from cuda.bindings import nvrtc err, prog = nvrtc.nvrtcCreateProgram(str.encode(kernel_string), program_name, 0, [], []) assert_drv(err) diff --git a/nemo/core/utils/numba_utils.py b/nemo/core/utils/numba_utils.py index 9117b2ea1010..077f2f9f6c51 100644 --- a/nemo/core/utils/numba_utils.py +++ b/nemo/core/utils/numba_utils.py @@ -110,7 +110,7 @@ def numba_cuda_is_supported(min_version: str) -> bool: """ Tests if an appropriate version of numba is installed, and if it is, if cuda is supported properly within it. - + Args: min_version: The minimum version of numba that is required. 
diff --git a/nemo/core/utils/numba_utils.py b/nemo/core/utils/numba_utils.py
index 9117b2ea1010..077f2f9f6c51 100644
--- a/nemo/core/utils/numba_utils.py
+++ b/nemo/core/utils/numba_utils.py
@@ -110,7 +110,7 @@ def numba_cuda_is_supported(min_version: str) -> bool:
     """
     Tests if an appropriate version of numba is installed, and if it is, if cuda is supported properly within it.
-    
+
     Args:
         min_version: The minimum version of numba that is required.
@@ -127,25 +127,20 @@ def numba_cuda_is_supported(min_version: str) -> bool:
     if module_available is True:
         from numba import cuda
 
-        # this method first arrived in 0.53, and that's the minimum version required
-        if hasattr(cuda, 'is_supported_version'):
-            try:
-                cuda_available = cuda.is_available()
-                if cuda_available:
-                    cuda_compatible = cuda.is_supported_version()
-                else:
-                    cuda_compatible = False
-
-                if is_numba_compat_strict():
-                    return cuda_available and cuda_compatible
-                else:
-                    return cuda_available
-
-            except OSError:
-                # dlopen(libcudart.dylib) might fail if CUDA was never installed in the first place.
-                return False
-        else:
-            # assume cuda is supported, but it may fail due to CUDA incompatibility
+        try:
+            cuda_available = cuda.is_available()
+            if cuda_available:
+                cuda_compatible = cuda.cudadrv.runtime.get_version()[0] == 13
+            else:
+                cuda_compatible = False
+
+            if is_numba_compat_strict():
+                return cuda_available and cuda_compatible
+            else:
+                return cuda_available
+
+        except Exception:
+            # dlopen(libcudart.dylib) might fail if CUDA was never installed in the first place.
             return False
 
     else:
@@ -188,7 +183,7 @@ def is_numba_cuda_fp16_supported(return_reason: bool = False) -> Union[bool, Tup
 def skip_numba_cuda_test_if_unsupported(min_version: str):
     """
     Helper method to skip pytest test case if numba cuda is not supported.
-    
+
     Args:
         min_version: The minimum version of numba that is required.
     """
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index bdf67e7691c9..d2f8aae72344 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -1,6 +1,7 @@
 fsspec==2024.12.0
 huggingface_hub>=0.24
-numba
+numba ; platform_system == 'Darwin'
+numba-cuda[cu13]>=0.20.0 ; platform_system != 'Darwin'
 numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
 numpy>=1.22
 onnx>=1.7.0
@@ -15,4 +16,4 @@ text-unidecode
 torch
 tqdm>=4.41.0
 wget
-wrapt
\ No newline at end of file
+wrapt
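Note: the rewritten check drops the `hasattr(cuda, 'is_supported_version')` probe (that method has existed since numba 0.53) and instead pins compatibility to a CUDA 13 runtime, matching the new `numba-cuda[cu13]` requirement for non-macOS platforms. A minimal standalone version of the new logic; the helper name is hypothetical, and `cuda.cudadrv.runtime.get_version()` is the same call the patch uses:

```python
from numba import cuda

def numba_cuda13_supported() -> bool:
    """Hypothetical helper: True if numba can see a CUDA 13 runtime."""
    try:
        if not cuda.is_available():
            return False
        # get_version() returns the bound CUDA runtime version as (major, minor).
        major, _minor = cuda.cudadrv.runtime.get_version()
        return major == 13
    except Exception:
        # dlopen(libcudart) may fail if CUDA was never installed in the first place.
        return False

print(numba_cuda13_supported())
```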