Skip to content

[MoE][GPT-OSS] Add L40S/SM89 Marlin block-size policy#38054

Draft
will-deines wants to merge 1 commit intovllm-project:mainfrom
will-deines:feature/gpt-oss-l40s-moe-policy
Draft

[MoE][GPT-OSS] Add L40S/SM89 Marlin block-size policy#38054
will-deines wants to merge 1 commit intovllm-project:mainfrom
will-deines:feature/gpt-oss-l40s-moe-policy

Conversation

@will-deines
Copy link
Copy Markdown

Purpose

This draft follows up the GPT-OSS L40S attention-policy work with a narrow
Marlin MoE runtime policy for the GPT-OSS 20B shape on SM89 / L40S.

The change keeps the existing generic Marlin block_size_m heuristic as the
default, but adds a model- and device-specific override for the observed
GPT-OSS MXFP4 MoE shape on L40S:

  • choose block_size_m=64 for tiny-M decode-like calls
  • choose block_size_m=32 for larger-M prefill-like calls

The policy is intentionally narrow:

  • DeviceCapability(8, 9) only
  • GPT-OSS 20B MoE shape only (hidden_size=2880, num_experts=32, top_k=4)
  • MXFP4 MoE path only

Everything else keeps the existing generic Marlin auto policy unchanged.

Why This Is Not Duplicating Existing Open PRs

Motivation

Deployed L40S benchmark experiments on GPT-OSS 20B showed that the current
generic Marlin block_size_m heuristic is not the best fit for the GPT-OSS
MoE shape on L40S.

In local deploy sweeps on Modal L40S endpoints:

  • block_size_m=48 consistently regressed and was rejected
  • block_size_m=32 was strong for long-prefill control cases
  • block_size_m=64 was best on the decode-heavy case

This draft encodes that result as a narrow runtime selector instead of
requiring separate deployment variants.

Test Plan

source ~/.venvs/modal-test/bin/activate
pytest tests/kernels/moe/test_marlin_block_size_policy.py \
  tests/v1/attention/test_cuda_attention_backend_policy.py \
  tests/kernels/attention/test_triton_unified_attention_tile_policy.py -q

UV_CACHE_DIR=/tmp/uv-cache UV_TOOL_DIR=/tmp/uv-tools \
  uvx pre-commit run --files \
  vllm/model_executor/layers/fused_moe/fused_marlin_moe.py \
  tests/kernels/moe/test_marlin_block_size_policy.py

Test Result

  • pytest ... -q -> 17 passed
  • file-scoped pre-commit -> passed

Manual validation on deployed Modal L40S GPT-OSS 20B endpoints:

  • block_size_m=48 regressed both the decode-heavy case and the long-prefill
    control versus baseline
  • block_size_m=32 improved the decode-heavy case modestly and preserved a
    strong long-prefill control result
  • block_size_m=64 improved the decode-heavy case more than 32 while
    remaining strongly better than baseline on the long-prefill control

Representative deployed results versus the same baseline:

  • decode-heavy case, concurrency 8
    • b32: -9.89% median per-request total
    • b64: -13.91% median per-request total
  • long-prefill control, concurrency 1
    • b32: -40.08% median per-request total
    • b64: -45.04% median per-request total

This draft uses 64 for tiny-M decode-like calls and 32 for larger-M
prefill-like calls to reflect that deployed sweep.

AI Assistance

AI assistance was used to help implement the selector, write the focused
tests, and analyze the L40S benchmark results. I reviewed every changed line
and ran the commands above.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces model-aware attention backend and Marlin MoE block size policies. It adds logic to prioritize attention backends (Triton, FlashAttention, FlashInfer) for GPT-OSS models with attention sinks based on CUDA device capability (SM8x, SM9x, SM100+). Additionally, it implements a new policy for Marlin MoE block size selection, optimizing for GPT-OSS models on SM89, and introduces a mechanism to use smaller Triton unified attention sink tiles on SM8x GPUs with less shared memory. The review comments suggest improving readability and maintainability by defining magic numbers as named constants in both the Marlin MoE and Triton attention tile size selection logic.

Comment on lines +50 to +69
GPT_OSS_SM89_MOE_BLOCK_SIZE_M_SMALL_M = 64
GPT_OSS_SM89_MOE_BLOCK_SIZE_M_LARGE_M = 32
GPT_OSS_SM89_MOE_SMALL_M_THRESHOLD = 128


def _use_gpt_oss_sm89_marlin_block_size_policy(
*,
num_experts: int,
topk: int,
hidden_size: int,
quant_type: ScalarType,
device_capability: DeviceCapability | None,
) -> bool:
return (
device_capability == DeviceCapability(8, 9)
and num_experts == 32
and topk == 4
and hidden_size == 2880
and quant_type == scalar_types.float4_e2m1f
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

For better readability and maintainability, especially for such a narrow, hardware-specific performance policy, it's good practice to define the magic numbers for the GPT-OSS 20B MoE shape as constants. This makes the code easier to understand and modify in the future.

GPT_OSS_SM89_MOE_BLOCK_SIZE_M_SMALL_M = 64
GPT_OSS_SM89_MOE_BLOCK_SIZE_M_LARGE_M = 32
GPT_OSS_SM89_MOE_SMALL_M_THRESHOLD = 128


# GPT-OSS 20B MoE shape constants
GPT_OSS_20B_MOE_NUM_EXPERTS = 32
GPT_OSS_20B_MOE_TOP_K = 4
GPT_OSS_20B_MOE_HIDDEN_SIZE = 2880


def _use_gpt_oss_sm89_marlin_block_size_policy(
    *,
    num_experts: int,
    topk: int,
    hidden_size: int,
    quant_type: ScalarType,
    device_capability: DeviceCapability | None,
) -> bool:
    return (
        device_capability == DeviceCapability(8, 9)
        and num_experts == GPT_OSS_20B_MOE_NUM_EXPERTS
        and topk == GPT_OSS_20B_MOE_TOP_K
        and hidden_size == GPT_OSS_20B_MOE_HIDDEN_SIZE
        and quant_type == scalar_types.float4_e2m1f
    )

Comment on lines +863 to +877
def _use_small_sm8x_sink_tiles(
device_capability: DeviceCapability | None,
has_sinks: bool,
) -> bool:
"""Prefer smaller sink tiles on Ada/GA10x-class SM8x GPUs.

SM86/SM89 parts have materially less shared memory per SM than SM80,
so the sink-capable unified Triton path benefits from a smaller tile.
"""
return (
has_sinks
and device_capability is not None
and device_capability.major == 8
and device_capability.minor in (6, 9)
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

To improve readability and maintainability, it's better to define the magic numbers for device minor versions as a named constant. This makes the code's intent clearer and easier to update if more device types are added in the future.

Suggested change
def _use_small_sm8x_sink_tiles(
device_capability: DeviceCapability | None,
has_sinks: bool,
) -> bool:
"""Prefer smaller sink tiles on Ada/GA10x-class SM8x GPUs.
SM86/SM89 parts have materially less shared memory per SM than SM80,
so the sink-capable unified Triton path benefits from a smaller tile.
"""
return (
has_sinks
and device_capability is not None
and device_capability.major == 8
and device_capability.minor in (6, 9)
)
_SM8X_DEVICES_WITH_LESS_SHARED_MEM = (6, 9)
def _use_small_sm8x_sink_tiles(
device_capability: DeviceCapability | None,
has_sinks: bool,
) -> bool:
"""Prefer smaller sink tiles on Ada/GA10x-class SM8x GPUs.
SM86/SM89 parts have materially less shared memory per SM than SM80,
so the sink-capable unified Triton path benefits from a smaller tile.
"""
return (
has_sinks
and device_capability is not None
and device_capability.major == 8
and device_capability.minor in _SM8X_DEVICES_WITH_LESS_SHARED_MEM
)

Co-authored-by: OpenAI Codex <noreply@openai.com>

Signed-off-by: Will Deines <will@garr.io>
(cherry picked from commit b43bcfd)
Signed-off-by: Will Deines <will@garr.io>
@will-deines will-deines force-pushed the feature/gpt-oss-l40s-moe-policy branch from b43bcfd to d761c0b Compare March 26, 2026 20:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

gpt-oss Related to GPT-OSS models nvidia v1

Projects

Status: No status
Status: To Triage

Development

Successfully merging this pull request may close these issues.

2 participants