diff --git a/op_builder/builder.py b/op_builder/builder.py index 384129a9a941..acdc721a3022 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -110,6 +110,7 @@ class OpBuilder(ABC): def __init__(self, name): self.name = name self.jit_mode = False + self.error_log = None @abstractmethod def absolute_name(self): @@ -439,6 +440,7 @@ def command_exists(self, cmd): return valid def warning(self, msg): + self.error_log = f"{msg}" print(f"{WARNING} {msg}") def deepspeed_src_path(self, code_path): @@ -471,7 +473,7 @@ def load(self, verbose=True): def jit_load(self, verbose=True): if not self.is_compatible(verbose): raise RuntimeError( - f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue." + f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue. {self.error_log}" ) try: import ninja # noqa: F401 diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py index 2f05230dbada..e9df633174f3 100755 --- a/op_builder/transformer_inference.py +++ b/op_builder/transformer_inference.py @@ -1,4 +1,6 @@ -from .builder import CUDAOpBuilder +import torch +from packaging import version +from .builder import CUDAOpBuilder, installed_cuda_version class InferenceBuilder(CUDAOpBuilder): @@ -12,6 +14,19 @@ def __init__(self, name=None): def absolute_name(self): return f'deepspeed.ops.transformer.inference.{self.NAME}_op' + def is_compatible(self, verbose=True): + cuda_okay = True + if not self.is_rocm_pytorch() and torch.cuda.is_available(): + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = version.parse(torch.version.cuda).major + cuda_capability = torch.cuda.get_device_properties(0).major + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + self.warning( + "On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + def sources(self): return [ 'csrc/transformer/inference/csrc/pt_binding.cpp',