diff --git a/conda/environments/all_cuda-129_arch-aarch64.yaml b/conda/environments/all_cuda-129_arch-aarch64.yaml index d94acac0a..7f054feab 100644 --- a/conda/environments/all_cuda-129_arch-aarch64.yaml +++ b/conda/environments/all_cuda-129_arch-aarch64.yaml @@ -23,7 +23,7 @@ dependencies: - librmm==25.8.*,>=0.0.0a0 - libtool - ninja -- numba>=0.59.1,<0.62.0a0 +- numba-cuda>=0.16.0,<0.17.0a0 - numpy>=1.23,<3.0a0 - pip - pkg-config diff --git a/conda/environments/all_cuda-129_arch-x86_64.yaml b/conda/environments/all_cuda-129_arch-x86_64.yaml index 492ba0adb..596256f2a 100644 --- a/conda/environments/all_cuda-129_arch-x86_64.yaml +++ b/conda/environments/all_cuda-129_arch-x86_64.yaml @@ -23,7 +23,7 @@ dependencies: - librmm==25.8.*,>=0.0.0a0 - libtool - ninja -- numba>=0.59.1,<0.62.0a0 +- numba-cuda>=0.16.0,<0.17.0a0 - numpy>=1.23,<3.0a0 - pip - pkg-config diff --git a/dependencies.yaml b/dependencies.yaml index fabc069ac..fa937eb32 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -289,7 +289,7 @@ dependencies: common: - output_types: [conda, requirements, pyproject] packages: - - &numba numba>=0.59.1,<0.62.0a0 + - &numba-cuda numba-cuda>=0.16.0,<0.17.0a0 - rapids-dask-dependency==25.8.*,>=0.0.0a0 test_cpp: common: @@ -301,7 +301,7 @@ dependencies: - output_types: [conda, requirements, pyproject] packages: - cloudpickle - - *numba + - *numba-cuda - pytest==7.* - pytest-asyncio - pytest-rerunfailures diff --git a/python/distributed-ucxx/distributed_ucxx/ucxx.py b/python/distributed-ucxx/distributed_ucxx/ucxx.py index 723fe8483..92a72795a 100644 --- a/python/distributed-ucxx/distributed_ucxx/ucxx.py +++ b/python/distributed-ucxx/distributed_ucxx/ucxx.py @@ -18,6 +18,7 @@ import struct import weakref from collections.abc import Awaitable, Callable, Collection +from enum import Enum from typing import TYPE_CHECKING, Any from unittest.mock import patch @@ -86,13 +87,18 @@ def _warn_cuda_context_wrong_device( ) -def synchronize_stream(stream=0): +class CudaStream(Enum): + Default = 0 + + +def synchronize_stream(stream: CudaStream = CudaStream.Default): import numba.cuda - ctx = numba.cuda.current_context() - cu_stream = numba.cuda.driver.drvapi.cu_stream(stream) - stream = numba.cuda.driver.Stream(ctx, cu_stream, None) - stream.synchronize() + if stream == CudaStream.Default: + numba_stream = numba.cuda.default_stream() + else: + raise ValueError("Unsupported stream") + numba_stream.synchronize() class gc_disabled: @@ -467,7 +473,7 @@ async def write( try: if multi_buffer is True: if any(hasattr(f, "__cuda_array_interface__") for f in frames): - synchronize_stream(0) + synchronize_stream(CudaStream.Default) close = [struct.pack("?", False)] await self.ep.send_multi(close + frames) @@ -502,7 +508,7 @@ async def write( # non-blocking CUDA streams. Note this is only sufficient if the memory # being sent is not currently in use on non-blocking CUDA streams. if any(cuda_send_frames): - synchronize_stream(0) + synchronize_stream(CudaStream.Default) for each_frame in send_frames: await self.ep.send(each_frame) @@ -581,7 +587,7 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): # synchronize the default stream before starting receiving to ensure # buffers have been allocated if any(cuda_recv_frames): - synchronize_stream(0) + synchronize_stream(CudaStream.Default) try: for each_frame in recv_frames: diff --git a/python/distributed-ucxx/pyproject.toml b/python/distributed-ucxx/pyproject.toml index 0f96d0bd0..84e86b6c6 100644 --- a/python/distributed-ucxx/pyproject.toml +++ b/python/distributed-ucxx/pyproject.toml @@ -18,7 +18,7 @@ authors = [ license = { text = "BSD-3-Clause" } requires-python = ">=3.10" dependencies = [ - "numba>=0.59.1,<0.62.0a0", + "numba-cuda>=0.16.0,<0.17.0a0", "rapids-dask-dependency==25.8.*,>=0.0.0a0", "ucxx==0.45.*,>=0.0.0a0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. diff --git a/python/ucxx/pyproject.toml b/python/ucxx/pyproject.toml index 3a5c92986..f570c853f 100644 --- a/python/ucxx/pyproject.toml +++ b/python/ucxx/pyproject.toml @@ -41,7 +41,7 @@ test = [ "cloudpickle", "cudf==25.8.*,>=0.0.0a0", "cupy-cuda12x>=12.0.0", - "numba>=0.59.1,<0.62.0a0", + "numba-cuda>=0.16.0,<0.17.0a0", "pytest-asyncio", "pytest-rerunfailures", "pytest==7.*",