Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/transformers/integrations/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@

from ..utils import logging
from ..utils.generic import GeneralInterface
from ..utils.import_utils import is_torch_available, is_torch_less_or_equal, is_torchdynamo_compiling
from ..utils.import_utils import (
is_torch_available,
is_torch_greater_or_equal,
is_torch_less_or_equal,
is_torchdynamo_compiling,
)


if is_torch_available():
Expand Down Expand Up @@ -275,6 +280,20 @@ def _can_use_grouped_mm(input: torch.Tensor, weight: torch.Tensor, offs: torch.T
# issue: https://github.com/pytorch/pytorch/issues/172440
return False

# On CUDA, `grouped_mm` availability also depends on the GPU compute capability:
# `torch.nn.functional.grouped_mm` (torch>=2.10) and `torch._grouped_mm` in torch>=2.9 support SM80+,
# whereas `torch._grouped_mm` in older torch versions requires SM90+.
if weight.device.type == "cuda":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should add a small comment here for clarification

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

if hasattr(torch.nn.functional, "grouped_mm"):
return torch.cuda.get_device_capability(weight.device) >= (8, 0)
if hasattr(torch, "_grouped_mm"):
if is_torch_greater_or_equal("2.9", accept_dev=True):
return torch.cuda.get_device_capability(weight.device) >= (8, 0)
else:
return torch.cuda.get_device_capability(weight.device) >= (9, 0)

return False

return hasattr(torch.nn.functional, "grouped_mm") or hasattr(torch, "_grouped_mm")


Expand Down
Loading