Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion op_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class OpBuilder(ABC):
def __init__(self, name):
    """Base initializer for an op builder.

    Args:
        name: identifier of the op this builder compiles/loads.
    """
    self.name = name
    # presumably toggled when the op is JIT-compiled rather than pre-built
    # — confirm against jit_load()
    self.jit_mode = False
    # Most recent warning text; jit_load() appends it to its RuntimeError
    # so incompatibility reasons reach the user.
    self.error_log = None

@abstractmethod
def absolute_name(self):
Expand Down Expand Up @@ -439,6 +440,7 @@ def command_exists(self, cmd):
return valid

def warning(self, msg):
    """Print a build warning and remember it for later diagnostics.

    The message is stashed on ``self.error_log`` so that a subsequent
    compatibility failure (e.g. in jit_load) can surface it in its error.
    """
    # Record the most recent warning before printing it.
    text = f"{msg}"
    self.error_log = text
    print(f"{WARNING} {msg}")

def deepspeed_src_path(self, code_path):
Expand Down Expand Up @@ -471,7 +473,7 @@ def load(self, verbose=True):
def jit_load(self, verbose=True):
if not self.is_compatible(verbose):
raise RuntimeError(
f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue."
f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue. {self.error_log}"
)
try:
import ninja # noqa: F401
Expand Down
17 changes: 16 additions & 1 deletion op_builder/transformer_inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .builder import CUDAOpBuilder
import torch
from packaging import version
from .builder import CUDAOpBuilder, installed_cuda_version


class InferenceBuilder(CUDAOpBuilder):
Expand All @@ -12,6 +14,19 @@ def __init__(self, name=None):
def absolute_name(self):
    """Return the fully-qualified module path of the built inference op."""
    # NAME is a class-level identifier declared elsewhere in this class.
    return 'deepspeed.ops.transformer.inference.{}_op'.format(self.NAME)

def is_compatible(self, verbose=True):
    """Check build compatibility, adding an Ampere/CUDA-11 constraint.

    On NVIDIA devices with compute capability >= 8 (Ampere or newer), both
    the installed CUDA toolkit and the CUDA version torch was built against
    must be 11+; otherwise a warning is recorded and the op is reported
    incompatible. ROCm and CPU-only environments skip the CUDA check.

    Args:
        verbose: forwarded to the base-class compatibility check.

    Returns:
        bool: True when the base check passes and the CUDA constraint holds.
    """
    cuda_okay = True
    # Only inspect CUDA versions on a genuine NVIDIA/CUDA setup.
    if not self.is_rocm_pytorch() and torch.cuda.is_available():
        sys_cuda_major, _ = installed_cuda_version()
        torch_cuda_major = version.parse(torch.version.cuda).major
        device_capability = torch.cuda.get_device_properties(0).major
        needs_cuda11 = device_capability >= 8
        if needs_cuda11 and (torch_cuda_major < 11 or sys_cuda_major < 11):
            self.warning("On Ampere and higher architectures please use CUDA 11+")
            cuda_okay = False
    return super().is_compatible(verbose) and cuda_okay

def sources(self):
return [
'csrc/transformer/inference/csrc/pt_binding.cpp',
Expand Down