Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions tests/kernels/moe/test_marlin_block_size_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
_choose_marlin_block_size_m,
)
from vllm.platforms.interface import DeviceCapability
from vllm.scalar_type import scalar_types


def test_gpt_oss_sm89_small_m_uses_decode_like_block_size() -> None:
block_size_m, policy = _choose_marlin_block_size_m(
num_tokens=1,
num_experts=32,
topk=4,
hidden_size=2880,
quant_type=scalar_types.float4_e2m1f,
input_dtype=None,
device_capability=DeviceCapability(8, 9),
)

assert (block_size_m, policy) == (64, "gpt_oss_sm89_decode_like")


def test_gpt_oss_sm89_large_m_uses_prefill_block_size() -> None:
block_size_m, policy = _choose_marlin_block_size_m(
num_tokens=1024,
num_experts=32,
topk=4,
hidden_size=2880,
quant_type=scalar_types.float4_e2m1f,
input_dtype=None,
device_capability=DeviceCapability(8, 9),
)

assert (block_size_m, policy) == (32, "gpt_oss_sm89_prefill_like")


def test_non_sm89_gpt_oss_shape_uses_generic_policy() -> None:
block_size_m, policy = _choose_marlin_block_size_m(
num_tokens=1,
num_experts=32,
topk=4,
hidden_size=2880,
quant_type=scalar_types.float4_e2m1f,
input_dtype=None,
device_capability=DeviceCapability(9, 0),
)

assert (block_size_m, policy) == (8, "auto")


def test_generic_auto_policy_keeps_int8_floor() -> None:
block_size_m, policy = _choose_marlin_block_size_m(
num_tokens=16,
num_experts=32,
topk=1,
hidden_size=4096,
quant_type=scalar_types.uint4,
input_dtype=torch.int8,
device_capability=DeviceCapability(8, 9),
)

assert (block_size_m, policy) == (16, "auto")
83 changes: 75 additions & 8 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,70 @@
kNvfp4Static,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.scalar_type import ScalarType, scalar_types

GPT_OSS_SM89_MOE_BLOCK_SIZE_M_SMALL_M = 64
GPT_OSS_SM89_MOE_BLOCK_SIZE_M_LARGE_M = 32
GPT_OSS_SM89_MOE_SMALL_M_THRESHOLD = 128


def _use_gpt_oss_sm89_marlin_block_size_policy(
*,
num_experts: int,
topk: int,
hidden_size: int,
quant_type: ScalarType,
device_capability: DeviceCapability | None,
) -> bool:
return (
device_capability == DeviceCapability(8, 9)
and num_experts == 32
and topk == 4
and hidden_size == 2880
and quant_type == scalar_types.float4_e2m1f
)
Comment on lines +50 to +69
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

For better readability and maintainability, especially for such a narrow, hardware-specific performance policy, it's good practice to define the magic numbers for the GPT-OSS 20B MoE shape as constants. This makes the code easier to understand and modify in the future.

GPT_OSS_SM89_MOE_BLOCK_SIZE_M_SMALL_M = 64
GPT_OSS_SM89_MOE_BLOCK_SIZE_M_LARGE_M = 32
GPT_OSS_SM89_MOE_SMALL_M_THRESHOLD = 128


# GPT-OSS 20B MoE shape constants
GPT_OSS_20B_MOE_NUM_EXPERTS = 32
GPT_OSS_20B_MOE_TOP_K = 4
GPT_OSS_20B_MOE_HIDDEN_SIZE = 2880


def _use_gpt_oss_sm89_marlin_block_size_policy(
    *,
    num_experts: int,
    topk: int,
    hidden_size: int,
    quant_type: ScalarType,
    device_capability: DeviceCapability | None,
) -> bool:
    return (
        device_capability == DeviceCapability(8, 9)
        and num_experts == GPT_OSS_20B_MOE_NUM_EXPERTS
        and topk == GPT_OSS_20B_MOE_TOP_K
        and hidden_size == GPT_OSS_20B_MOE_HIDDEN_SIZE
        and quant_type == scalar_types.float4_e2m1f
    )



def _choose_marlin_block_size_m(
*,
num_tokens: int,
num_experts: int,
topk: int,
hidden_size: int,
quant_type: ScalarType,
input_dtype: torch.dtype | None,
device_capability: DeviceCapability | None,
) -> tuple[int, str]:
# GPT-OSS on SM89/L40S benefits from a larger block during tiny-M
# decode-like calls and a smaller block during prefill-like calls. Keep
# this narrow to the observed GPT-OSS MXFP4 MoE problem shape.
if _use_gpt_oss_sm89_marlin_block_size_policy(
num_experts=num_experts,
topk=topk,
hidden_size=hidden_size,
quant_type=quant_type,
device_capability=device_capability,
):
if num_tokens <= GPT_OSS_SM89_MOE_SMALL_M_THRESHOLD:
return (
GPT_OSS_SM89_MOE_BLOCK_SIZE_M_SMALL_M,
"gpt_oss_sm89_decode_like",
)
return (
GPT_OSS_SM89_MOE_BLOCK_SIZE_M_LARGE_M,
"gpt_oss_sm89_prefill_like",
)

for block_size_m in [8, 16, 32, 48, 64]:
if num_tokens * topk / num_experts / block_size_m < 0.9:
break

if input_dtype is not None and input_dtype.itemsize == 1:
block_size_m = max(block_size_m, 16)

return block_size_m, "auto"


def _fused_marlin_moe(
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -304,14 +366,19 @@ def fused_marlin_moe(
assert num_bits in [4, 8]
assert topk_weights.dtype == torch.float32

# M block size selection logic
# TODO: tune this further for specific models
for block_size_m in [8, 16, 32, 48, 64]:
if M * topk / E / block_size_m < 0.9:
break

if input_dtype is not None and input_dtype.itemsize == 1:
block_size_m = max(block_size_m, 16)
block_size_m, _ = _choose_marlin_block_size_m(
num_tokens=M,
num_experts=E,
topk=topk,
hidden_size=K,
quant_type=quant_type,
input_dtype=input_dtype,
device_capability=(
current_platform.get_device_capability()
if current_platform.is_cuda()
else None
),
)

if global_num_experts == -1:
global_num_experts = E
Expand Down
Loading