diff --git a/op_builder/sparse_attn.py b/op_builder/sparse_attn.py index 4c716f859970..c3fa5624b25e 100644 --- a/op_builder/sparse_attn.py +++ b/op_builder/sparse_attn.py @@ -25,6 +25,18 @@ def is_compatible(self): command_status = list(map(self.command_exists, required_commands)) deps_compatible = all(command_status) + # torch-cpu will not have a cuda version + if torch.version.cuda is None: + cuda_compatible = False + self.warning(f"{self.NAME} cuda is not available from torch") + else: + major, minor = torch.version.cuda.split('.')[:2] + cuda_compatible = int(major) == 10 and int(minor) >= 1 + if not cuda_compatible: + self.warning( + f"{self.NAME} requires CUDA version 10.1+, does not currently support >=11 or <10.1" + ) + TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) torch_compatible = TORCH_MAJOR == 1 and TORCH_MINOR >= 5 @@ -33,4 +45,5 @@ def is_compatible(self): f'{self.NAME} requires a torch version >= 1.5 but detected {TORCH_MAJOR}.{TORCH_MINOR}' ) - return super().is_compatible() and deps_compatible and torch_compatible + return super().is_compatible( + ) and deps_compatible and torch_compatible and cuda_compatible