diff --git a/numba_cuda/numba/cuda/__init__.py b/numba_cuda/numba/cuda/__init__.py index 37a4475c8..a17b5b186 100644 --- a/numba_cuda/numba/cuda/__init__.py +++ b/numba_cuda/numba/cuda/__init__.py @@ -3,23 +3,21 @@ from numba.core import config from .utils import _readenv -# Enable pynvjitlink if the environment variables NUMBA_CUDA_ENABLE_PYNVJITLINK -# or CUDA_ENABLE_PYNVJITLINK are set, or if the pynvjitlink module is found. If -# explicitly disabled, do not use pynvjitlink, even if present in the env. -_pynvjitlink_enabled_in_env = _readenv( - "NUMBA_CUDA_ENABLE_PYNVJITLINK", bool, None -) -_pynvjitlink_enabled_in_cfg = getattr(config, "CUDA_ENABLE_PYNVJITLINK", None) - -if _pynvjitlink_enabled_in_env is not None: - ENABLE_PYNVJITLINK = _pynvjitlink_enabled_in_env -elif _pynvjitlink_enabled_in_cfg is not None: - ENABLE_PYNVJITLINK = _pynvjitlink_enabled_in_cfg -else: - ENABLE_PYNVJITLINK = importlib.util.find_spec("pynvjitlink") is not None - -if not hasattr(config, "CUDA_ENABLE_PYNVJITLINK"): - config.CUDA_ENABLE_PYNVJITLINK = ENABLE_PYNVJITLINK +# Enable pynvjitlink based on the following precedence: +# 1. Config setting "CUDA_ENABLE_PYNVJITLINK" (highest priority) +# 2. Environment variable "NUMBA_CUDA_ENABLE_PYNVJITLINK" +# 3. Auto-detection of pynvjitlink module (lowest priority) +if getattr(config, "CUDA_ENABLE_PYNVJITLINK", None) is None: + if ( + _pynvjitlink_enabled_in_env := _readenv( + "NUMBA_CUDA_ENABLE_PYNVJITLINK", bool, None + ) + ) is not None: + config.CUDA_ENABLE_PYNVJITLINK = _pynvjitlink_enabled_in_env + else: + config.CUDA_ENABLE_PYNVJITLINK = ( + importlib.util.find_spec("pynvjitlink") is not None + ) # Upstream numba sets CUDA_USE_NVIDIA_BINDING to 0 by default, so it always # exists. Override, but not if explicitly set to 0 in the envioronment.