diff --git a/tests/kernels/moe/test_marlin_block_size_policy.py b/tests/kernels/moe/test_marlin_block_size_policy.py
new file mode 100644
index 000000000000..22cb8f3173c6
--- /dev/null
+++ b/tests/kernels/moe/test_marlin_block_size_policy.py
@@ -0,0 +1,66 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    _choose_marlin_block_size_m,
+)
+from vllm.platforms.interface import DeviceCapability
+from vllm.scalar_type import scalar_types
+
+
+def test_gpt_oss_sm89_small_m_uses_decode_like_block_size() -> None:
+    block_size_m, policy = _choose_marlin_block_size_m(
+        num_tokens=1,  # at or below the 128-token small-M threshold
+        num_experts=32,
+        topk=4,
+        hidden_size=2880,
+        quant_type=scalar_types.float4_e2m1f,
+        input_dtype=None,
+        device_capability=DeviceCapability(8, 9),
+    )
+
+    assert (block_size_m, policy) == (64, "gpt_oss_sm89_decode_like")
+
+
+def test_gpt_oss_sm89_large_m_uses_prefill_block_size() -> None:
+    block_size_m, policy = _choose_marlin_block_size_m(
+        num_tokens=1024,  # above the 128-token small-M threshold
+        num_experts=32,
+        topk=4,
+        hidden_size=2880,
+        quant_type=scalar_types.float4_e2m1f,
+        input_dtype=None,
+        device_capability=DeviceCapability(8, 9),
+    )
+
+    assert (block_size_m, policy) == (32, "gpt_oss_sm89_prefill_like")
+
+
+def test_non_sm89_gpt_oss_shape_uses_generic_policy() -> None:
+    block_size_m, policy = _choose_marlin_block_size_m(
+        num_tokens=1,
+        num_experts=32,
+        topk=4,
+        hidden_size=2880,
+        quant_type=scalar_types.float4_e2m1f,
+        input_dtype=None,
+        device_capability=DeviceCapability(9, 0),  # not (8, 9): no override
+    )
+
+    assert (block_size_m, policy) == (8, "auto")
+
+
+def test_generic_auto_policy_keeps_int8_floor() -> None:
+    block_size_m, policy = _choose_marlin_block_size_m(
+        num_tokens=16,
+        num_experts=32,
+        topk=1,
+        hidden_size=4096,
+        quant_type=scalar_types.uint4,
+        input_dtype=torch.int8,  # 1-byte dtype: generic pick of 8 is raised to 16
+        device_capability=DeviceCapability(8, 9),
+    )
+
+    assert (block_size_m, policy) == (16, "auto")
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 136a8188d6a0..d1300467025d 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -44,8 +44,70 @@
     kNvfp4Static,
 )
 from vllm.platforms import current_platform
+from vllm.platforms.interface import DeviceCapability
 from vllm.scalar_type import ScalarType, scalar_types
 
+GPT_OSS_SM89_MOE_BLOCK_SIZE_M_SMALL_M = 64  # used when num_tokens <= threshold
+GPT_OSS_SM89_MOE_BLOCK_SIZE_M_LARGE_M = 32  # used when num_tokens > threshold
+GPT_OSS_SM89_MOE_SMALL_M_THRESHOLD = 128  # inclusive num_tokens cutoff
+
+
+def _use_gpt_oss_sm89_marlin_block_size_policy(  # gate for the SM89 override
+    *,
+    num_experts: int,
+    topk: int,
+    hidden_size: int,
+    quant_type: ScalarType,
+    device_capability: DeviceCapability | None,
+) -> bool:
+    return (  # True only when every condition below holds
+        device_capability == DeviceCapability(8, 9)  # compute capability 8.9
+        and num_experts == 32
+        and topk == 4
+        and hidden_size == 2880
+        and quant_type == scalar_types.float4_e2m1f
+    )
+
+
+def _choose_marlin_block_size_m(  # returns (block_size_m, policy name)
+    *,
+    num_tokens: int,
+    num_experts: int,
+    topk: int,
+    hidden_size: int,
+    quant_type: ScalarType,
+    input_dtype: torch.dtype | None,
+    device_capability: DeviceCapability | None,
+) -> tuple[int, str]:
+    # GPT-OSS on SM89/L40S benefits from a larger block during tiny-M
+    # decode-like calls and a smaller block during prefill-like calls. Keep
+    # this narrow to the observed GPT-OSS MXFP4 MoE problem shape.
+    if _use_gpt_oss_sm89_marlin_block_size_policy(
+        num_experts=num_experts,
+        topk=topk,
+        hidden_size=hidden_size,
+        quant_type=quant_type,
+        device_capability=device_capability,
+    ):
+        if num_tokens <= GPT_OSS_SM89_MOE_SMALL_M_THRESHOLD:
+            return (
+                GPT_OSS_SM89_MOE_BLOCK_SIZE_M_SMALL_M,
+                "gpt_oss_sm89_decode_like",
+            )
+        return (
+            GPT_OSS_SM89_MOE_BLOCK_SIZE_M_LARGE_M,
+            "gpt_oss_sm89_prefill_like",
+        )
+
+    for block_size_m in [8, 16, 32, 48, 64]:  # ascending candidate sizes
+        if num_tokens * topk / num_experts / block_size_m < 0.9:
+            break  # first size with ratio < 0.9 is kept; 64 if none qualifies
+
+    if input_dtype is not None and input_dtype.itemsize == 1:
+        block_size_m = max(block_size_m, 16)  # floor of 16 for 1-byte input dtypes
+
+    return block_size_m, "auto"  # "auto" marks the generic policy
+
 
 def _fused_marlin_moe(
     hidden_states: torch.Tensor,
@@ -304,14 +366,19 @@ def fused_marlin_moe(
     assert num_bits in [4, 8]
     assert topk_weights.dtype == torch.float32
 
-    # M block size selection logic
-    # TODO: tune this further for specific models
-    for block_size_m in [8, 16, 32, 48, 64]:
-        if M * topk / E / block_size_m < 0.9:
-            break
-
-    if input_dtype is not None and input_dtype.itemsize == 1:
-        block_size_m = max(block_size_m, 16)
+    block_size_m, _ = _choose_marlin_block_size_m(  # policy name is unused here
+        num_tokens=M,
+        num_experts=E,
+        topk=topk,
+        hidden_size=K,
+        quant_type=quant_type,
+        input_dtype=input_dtype,
+        device_capability=(
+            current_platform.get_device_capability()
+            if current_platform.is_cuda()
+            else None  # no capability off CUDA; presumably skips SM89 gate
+        ),
+    )
 
     if global_num_experts == -1:
         global_num_experts = E