Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 30 additions & 22 deletions vllm/utils/deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,30 +45,34 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
return None


if not has_deep_gemm():
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_per_block_cast_impl: Callable[..., Any] | None = None
else:
_dg = importlib.import_module("deep_gemm") # type: ignore

_fp8_gemm_nt_impl = _resolve_symbol(
_dg,
"fp8_gemm_nt",
"gemm_fp8_fp8_bf16_nt",
)
# Module-level cache of the resolved DeepGEMM callables. Every slot starts
# as None so that importing this module does not import `deep_gemm`;
# `_lazy_init()` fills in the GEMM slots on first use, and the public
# wrappers below test each slot for None before dispatching to it.
# NOTE(review): `_per_block_cast_impl` is presumably populated by
# `_lazy_init()` as well -- confirm it is listed in that function's
# `global` statement, otherwise the assignment would only bind a local.
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_per_block_cast_impl: Callable[..., Any] | None = None


def _lazy_init() -> None:
"""Import deep_gemm and resolve symbols on first use."""
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl

# fast path
if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
or _grouped_masked_impl is not None):
return

if not has_deep_gemm():
return

_dg = importlib.import_module("deep_gemm")

_fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
"gemm_fp8_fp8_bf16_nt")
_grouped_impl = _resolve_symbol(
_dg,
"m_grouped_fp8_gemm_nt_contiguous",
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
)
_dg, "m_grouped_fp8_gemm_nt_contiguous",
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous")
_grouped_masked_impl = _resolve_symbol(
_dg,
"fp8_m_grouped_gemm_nt_masked",
"m_grouped_gemm_fp8_fp8_bf16_nt_masked",
)

_dg, "fp8_m_grouped_gemm_nt_masked",
"m_grouped_gemm_fp8_fp8_bf16_nt_masked")
# Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
try:
_math_mod = importlib.import_module(
Expand All @@ -80,24 +84,28 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:


def fp8_gemm_nt(*args, **kwargs):
    """Forward the call to the resolved DeepGEMM `fp8_gemm_nt` kernel.

    Symbol resolution is triggered via `_lazy_init()` before dispatching.
    All positional and keyword arguments are passed through unchanged.
    When no implementation has been resolved, the call is delegated to
    `_missing` with the same arguments instead.
    """
    _lazy_init()
    # Read the slot only after initialisation has had a chance to fill it.
    impl = _fp8_gemm_nt_impl
    target = _missing if impl is None else impl
    return target(*args, **kwargs)


def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Forward the call to the resolved contiguous grouped FP8 GEMM kernel.

    Symbol resolution is triggered via `_lazy_init()` before dispatching.
    All positional and keyword arguments are passed through unchanged.
    When no implementation has been resolved, the call is delegated to
    `_missing` with the same arguments instead.
    """
    _lazy_init()
    # Read the slot only after initialisation has had a chance to fill it.
    impl = _grouped_impl
    target = _missing if impl is None else impl
    return target(*args, **kwargs)


def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Forward the call to the resolved masked grouped FP8 GEMM kernel.

    Symbol resolution is triggered via `_lazy_init()` before dispatching.
    All positional and keyword arguments are passed through unchanged.
    When no implementation has been resolved, the call is delegated to
    `_missing` with the same arguments instead.
    """
    _lazy_init()
    # Read the slot only after initialisation has had a chance to fill it.
    impl = _grouped_masked_impl
    target = _missing if impl is None else impl
    return target(*args, **kwargs)


def per_block_cast_to_fp8(x, *args, **kwargs):
_lazy_init()
if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
return _per_block_cast_impl(x, use_ue8m0=True)
# TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
Expand Down