diff --git a/setup.py b/setup.py index 0aa666e7b679..2a34555a361c 100644 --- a/setup.py +++ b/setup.py @@ -56,6 +56,11 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"] # Validate the NVCC CUDA version. +if CUDA_HOME is None: + raise ValueError( + "Could not find CUDA_HOME variable, " + "please check your CUDA installation." + ) nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) if nvcc_cuda_version < Version("11.0"): raise RuntimeError("CUDA 11.0 or higher is required to build the package.")