Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
fp8_gemm_nt,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
Expand Down Expand Up @@ -247,7 +248,6 @@ def __init__(
self.act_quant_group_shape = act_quant_group_shape
self.is_deep_gemm_supported = is_deep_gemm_supported()
self.is_hopper = current_platform.is_device_capability(90)
self.is_blackwell = current_platform.is_device_capability_family(100)
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()

# Get the correct blockscale mul and input quant operations.
Expand Down Expand Up @@ -303,7 +303,7 @@ def _run_deepgemm(
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
if self.use_deep_gemm_e8m0 and self.is_blackwell:
if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
input_2d,
group_size=self.act_quant_group_shape.col,
Expand Down
34 changes: 27 additions & 7 deletions vllm/utils/deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,35 @@ class DeepGemmQuantScaleFMT(Enum):
# element contains 4 scale values.
UE8M0 = 2

@staticmethod
def from_oracle() -> "DeepGemmQuantScaleFMT":
if not is_deep_gemm_e8m0_used():
return DeepGemmQuantScaleFMT.FLOAT32
return (
DeepGemmQuantScaleFMT.UE8M0
@classmethod
def init_oracle_cache(cls) -> None:
"""Initialize the oracle decision and store it in the class cache"""
cached = getattr(cls, "_oracle_cache", None)
if cached is not None:
return

use_e8m0 = (
envs.VLLM_USE_DEEP_GEMM_E8M0
and is_deep_gemm_supported()
and (_fp8_gemm_nt_impl is not None)
)
if not use_e8m0:
cls._oracle_cache = cls.FLOAT32 # type: ignore
return

cls._oracle_cache = ( # type: ignore
cls.UE8M0
if current_platform.is_device_capability_family(100)
else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
else cls.FLOAT32_CEIL_UE8M0
)

@classmethod
def from_oracle(cls) -> "DeepGemmQuantScaleFMT":
    """Return the scale format previously stored by ``init_oracle_cache``.

    The value is read from the class attribute ``_oracle_cache``; this
    method never computes the decision itself.

    Raises:
        AssertionError: if ``_oracle_cache`` is missing or ``None`` on the
            class, i.e. the cache has not been populated yet.
    """
    cached = getattr(cls, "_oracle_cache", None)
    if cached is None:
        # Raise explicitly rather than using `assert`: assert statements are
        # stripped under `python -O`, which would let None escape to callers
        # that compare the result against enum members.
        raise AssertionError("DeepGemmQuantScaleFMT oracle cache not initialized")
    return cached


@functools.cache
def is_deep_gemm_supported() -> bool:
Expand Down Expand Up @@ -149,6 +168,7 @@ def _lazy_init() -> None:
_transform_sf_into_required_layout_impl = getattr(
_dg, "transform_sf_into_required_layout", None
)
DeepGemmQuantScaleFMT.init_oracle_cache()


def get_num_sms() -> int:
Expand Down