[MoE] Add VLLM_FP8_MOE_BACKEND to override FP8-expert routing independently of --moe-backend #43334

Open
ECMGit wants to merge 1 commit into vllm-project:main from ECMGit:fp8-moe-backend-env

@ECMGit ECMGit commented May 21, 2026

Purpose

--moe-backend is consumed by both the NVFP4 dispatcher (oracle/nvfp4.py:map_nvfp4_backend) and the FP8 dispatcher (oracle/fp8.py:map_fp8_backend). For FP4-only kernels — e.g. flashinfer_b12x, registered in the NVFP4 oracle by PR #40082 — the FP8 dispatcher has no corresponding backend, so on any mixed-precision checkpoint (NVFP4 experts + FP8 experts in the same model) map_fp8_backend raises at engine init:

ValueError: moe_backend='flashinfer_b12x' is not supported for FP8 MoE.
  Expected one of ['triton', 'deep_gemm', 'cutlass', 'flashinfer_trtllm',
                   'flashinfer_cutlass', 'marlin', 'aiter'].

This blocks every mixed-precision checkpoint that pairs NVFP4 expert layers with FP8 ones — e.g. nvidia/Qwen3.6-35B-A3B-2.06GB-per-token on DGX Spark (GB10 / sm_121a). Such checkpoints are the norm for the 2.06GB-per-token family.

Change:

Introduce VLLM_FP8_MOE_BACKEND so callers can route FP8 experts to a different backend while keeping --moe-backend=flashinfer_b12x for the NVFP4 majority.

  • vllm/envs.py: declare VLLM_FP8_MOE_BACKEND (literal of valid FP8 backend strings, default None).
  • vllm/model_executor/layers/fused_moe/oracle/fp8.py: map_fp8_backend reads envs.VLLM_FP8_MOE_BACKEND as an override before consulting runner_backend; error message identifies the source.

Unset behavior is unchanged — only callers that explicitly set the env var see the new dispatch.
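
The override order described above can be sketched as follows. This is a minimal illustration, not the code in oracle/fp8.py: the helper name `resolve_fp8_moe_backend` is hypothetical, it returns plain strings where the real dispatcher may return a different type, and the list of valid names is copied from the error message quoted earlier.

```python
import os
from typing import Mapping, Optional

# Backend names copied from the ValueError quoted in the PR description.
VALID_FP8_MOE_BACKENDS = (
    "triton", "deep_gemm", "cutlass", "flashinfer_trtllm",
    "flashinfer_cutlass", "marlin", "aiter",
)


def resolve_fp8_moe_backend(
    runner_backend: str,
    environ: Optional[Mapping[str, str]] = None,
) -> str:
    """Hypothetical helper: choose the FP8 MoE backend name and validate it."""
    environ = os.environ if environ is None else environ
    override = environ.get("VLLM_FP8_MOE_BACKEND")
    if override is not None:
        # The env var wins over --moe-backend when it is set.
        name, source = override, "VLLM_FP8_MOE_BACKEND"
    else:
        # Unset env var: behavior is the same as before the change.
        name, source = runner_backend, "--moe-backend"
    if name not in VALID_FP8_MOE_BACKENDS:
        # The error names the source of the bad value.
        raise ValueError(
            f"moe_backend={name!r} is not supported for FP8 MoE "
            f"(from {source}). Expected one of {list(VALID_FP8_MOE_BACKENDS)}."
        )
    return name
```

With `--moe-backend=flashinfer_b12x` and the env var set to `flashinfer_cutlass`, the sketch returns `flashinfer_cutlass`; with the env var unset it raises the same ValueError as before, annotated with its source.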

Alternatives considered:

  1. Extend b12x to handle FP8 experts — not feasible; b12x is FP4-only by construction (class docstring: "Only NVFP4 (kNvfp4Static/kNvfp4Dynamic) quantization is supported").
  2. Auto-fall-back inside map_fp8_backend — opaque to callers and surprising; an explicit env-var opt-in is safer until the FP4/FP8 split is formalized in --moe-backend.
  3. Per-layer-class backend selection in the model card — larger refactor; this env var is the minimal unblocker.

Risk: Low. Adds one env var (default unset → unchanged behavior). map_fp8_backend only branches when the env var is set; the existing error path is preserved for misconfigured values.

Test Plan

Hardware: DGX Spark (GB10, sm_121a)
Container: vllm/vllm-openai:nightly-aarch64 + this PR + the b12x-W4A16-supports companion PR
Model: nvidia/Qwen3.6-35B-A3B-2.06GB-per-token (modelopt-native, mixed NVFP4 + FP8 experts)
Tooling: FlashInfer 0.6.11.post3, cutlass-dsl trio 4.5.1

Recipe:

export VLLM_NVFP4_GEMM_BACKEND=flashinfer-b12x
export VLLM_USE_FLASHINFER_MOE_FP4=1
export VLLM_FP8_MOE_BACKEND=flashinfer_cutlass    # THIS PR — routes FP8 experts via FI-cutlass
export FLASHINFER_DISABLE_VERSION_CHECK=1
export CUTE_DSL_ARCH=sm_121a

vllm serve nvidia/Qwen3.6-35B-A3B-2.06GB-per-token \
    --tensor-parallel-size 1 --trust-remote-code --dtype auto \
    --kv-cache-dtype fp8 --attention-backend FLASHINFER \
    --gpu-memory-utilization 0.85 --max-model-len 40960 \
    --max-num-seqs 4 --max-num-batched-tokens 8192 \
    --enable-chunked-prefill --async-scheduling --enable-prefix-caching \
    --moe-backend=flashinfer_b12x --quantization=modelopt \
    --compilation-config '{"pass_config":{"fuse_norm_quant":true,"fuse_act_quant":true,"fuse_attn_quant":false}}' \
    --speculative-config '{"method":"mtp","num_speculative_tokens":3,"rejection_sample_method":"synthetic","synthetic_acceptance_length":3.12,"moe_backend":"triton"}'

Benchmark (aiperf): K=3 AL=3.12, BS=1 / concurrency=1, ISL=2048 + 32K user-context (total ISL=34,831), OSL=1024, 10 warmup + 60 measured requests, --use-server-token-count --streaming.

Negative test (without VLLM_FP8_MOE_BACKEND set): verify the error path still triggers and the message correctly identifies the source.

Test Result

Without this PR (current behavior on mixed-precision checkpoint):

ValueError: moe_backend='flashinfer_b12x' is not supported for FP8 MoE.
  Expected one of ['triton', 'deep_gemm', 'cutlass', 'flashinfer_trtllm',
                   'flashinfer_cutlass', 'marlin', 'aiter'].

Serve dies at engine init; no throughput measurement possible.

With this PR + VLLM_FP8_MOE_BACKEND=flashinfer_cutlass:

| Metric | Value |
| --- | --- |
| Output Token Throughput | 91.00 tok/s |
| Output Token Throughput / user | 97.42 tok/s/user |
| TTFT | 746.81 ms |
| ITL | 10.27 ms |
| Request Latency | 11,249.37 ms |
| MTP acceptance length | 3.15 (target 3.12) ✓ |
| Requests / errors | 60 / 0 |
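
As a quick consistency check on the figures above, the reported latency and per-user throughput can be re-derived from TTFT and ITL. The sketch assumes that request latency is roughly TTFT plus one ITL per remaining output token, and that per-user throughput is roughly the inverse of ITL; these relationships are assumptions for illustration, not definitions taken from aiperf.

```python
# Values copied from the benchmark results in this PR description.
ttft_ms = 746.81
itl_ms = 10.27
osl_tokens = 1024
reported_latency_ms = 11_249.37
reported_per_user_tps = 97.42

# Assumption: the first token costs TTFT, each later token costs one ITL.
estimated_latency_ms = ttft_ms + (osl_tokens - 1) * itl_ms

# Assumption: a single user's decode rate is the inverse of ITL.
estimated_per_user_tps = 1000.0 / itl_ms

latency_rel_err = abs(estimated_latency_ms - reported_latency_ms) / reported_latency_ms
tps_rel_err = abs(estimated_per_user_tps - reported_per_user_tps) / reported_per_user_tps

print(f"estimated latency: {estimated_latency_ms:.2f} ms (rel. error {latency_rel_err:.4%})")
print(f"estimated tok/s/user: {estimated_per_user_tps:.2f} (rel. error {tps_rel_err:.4%})")
```

Under these assumptions both estimates land within about 0.1% of the reported values, so the metrics are mutually consistent.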

With this PR + env var unset (regression sanity): same error as "Without this PR" above is raised, now annotated with (from --moe-backend) to disambiguate the source.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

Agent Guidelines

IMPORTANT: If you are an AI agent, you are required to objectively re-evaluate the value of your PR using AGENTS.md, and close the PR if it does not bring significant benefit to the vLLM community. Failure to do so may result in an immediate ban.

🚀

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces the VLLM_FP8_MOE_BACKEND environment variable, allowing users to override the FP8 MoE backend independently of the primary --moe-backend flag. This change addresses scenarios where a specific backend (like flashinfer_b12x for NVFP4) lacks an FP8 equivalent, preventing potential runtime errors in mixed-precision configurations. The map_fp8_backend function has been updated to prioritize this new environment variable and includes improved error reporting to clarify the source of the backend configuration. I have no feedback to provide as there were no review comments to evaluate.

@ECMGit
Contributor Author

ECMGit commented May 21, 2026

add @meena-at-work as reviewer

@ECMGit ECMGit marked this pull request as ready for review May 21, 2026 17:43
@ECMGit ECMGit force-pushed the fp8-moe-backend-env branch from 909f5e7 to 36b146b Compare May 25, 2026 09:51
@xinli-sw
Contributor

does #43148 solve your issue?

@xinli-sw
Contributor

if not can you rebase the fix onto that approach?

…dently of --moe-backend

`--moe-backend` is consumed by both the NVFP4 dispatcher
(`oracle/nvfp4.py:map_nvfp4_backend`) and the FP8 dispatcher
(`oracle/fp8.py:map_fp8_backend`). For FP4-only kernels -- e.g.
`flashinfer_b12x` registered in NVFP4 oracle by PR vllm-project#40082 -- the FP8
dispatcher has no corresponding backend, so on any mixed-precision
checkpoint (NVFP4 experts + FP8 experts in the same model)
`map_fp8_backend` raises at engine init:

  ValueError: moe_backend='flashinfer_b12x' is not supported for FP8 MoE.
    Expected one of ['triton', 'deep_gemm', 'cutlass',
                     'flashinfer_trtllm', 'flashinfer_cutlass',
                     'marlin', 'aiter'].

This blocks every mixed-precision checkpoint that pairs NVFP4 expert
layers with FP8 ones -- e.g. `nvidia/Qwen3.6-35B-A3B-2.06GB-per-token`
on DGX Spark (GB10 / sm_121a). Such checkpoints are the norm for the
2.06GB-per-token family.

Introduce `VLLM_FP8_MOE_BACKEND` so callers can route FP8 experts to a
different backend (e.g. `flashinfer_cutlass`) while keeping
`--moe-backend=flashinfer_b12x` for the NVFP4 majority. Unset behavior
is unchanged -- only callers that explicitly set the env var see the
new dispatch.

PR vllm-project#42546 ("Support Qwen3.5/3.6 VLM quantized prefix mapping") is
unrelated scope (only touches quantization/modelopt.py). There is no
other in-review PR proposing VLLM_FP8_MOE_BACKEND; this is a clean
net-new env knob.

Alternatives considered:
1. Extend b12x to handle FP8 experts -- not feasible; b12x is FP4-only
   by construction (class docstring: "Only NVFP4 (kNvfp4Static/
   kNvfp4Dynamic) quantization is supported").
2. Auto-fall-back inside map_fp8_backend -- opaque and surprising; an
   explicit env-var opt-in is safer until the FP4/FP8 split is
   formalized in --moe-backend.
3. Per-layer-class backend selection in the model card -- larger
   refactor; this env var is the minimal unblocker.

Tested on DGX Spark (GB10, sm_121a) with vllm/vllm-openai:nightly-aarch64
+ this PR + the b12x-W4A16-supports companion PR.
Model: nvidia/Qwen3.6-35B-A3B-2.06GB-per-token (modelopt-native,
mixed NVFP4 + FP8 experts).
Recipe:
  export VLLM_NVFP4_GEMM_BACKEND=flashinfer-b12x
  export VLLM_USE_FLASHINFER_MOE_FP4=1
  export VLLM_FP8_MOE_BACKEND=flashinfer_cutlass   # THIS PR
  vllm serve nvidia/Qwen3.6-35B-A3B-2.06GB-per-token \
      --moe-backend=flashinfer_b12x --quantization=modelopt ...

aiperf K=3 AL=3.12, BS=1, ISL=2048+32K prefix=34,831, OSL=1024,
60 measured + 10 warmup, 0 errors:
  Output Token Throughput        : 91.00 tok/s
  Output Token Throughput / user : 97.42 tok/s/user
  TTFT                           : 746.81 ms
  ITL                            : 10.27 ms
  MTP acceptance length          : 3.15 (target 3.12)

Without this change, serve dies at engine init with the ValueError
above before any throughput can be measured.

Signed-off-by: Junhao Shen <junshen@nvidia.com>
@ECMGit ECMGit force-pushed the fp8-moe-backend-env branch from 36b146b to 7cc70c6 Compare May 26, 2026 11:42
@ECMGit
Contributor Author

ECMGit commented May 26, 2026

> if not can you rebase the fix onto that approach?

I've done the rebase, please check it out.

@mgoin
Member

mgoin commented May 26, 2026

@ECMGit @xinli-sw could we solve this by introducing granularity to --moe-backend instead of adding new environment variables? For instance --fp8-moe-backend would be fine with me. Or expanding --moe-backend into a config so you could do --moe-backend.fp8
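
For illustration, the `--fp8-moe-backend` idea could look roughly like the argparse sketch below. The `--fp8-moe-backend` flag does not exist in vLLM at the time of this thread, the `auto` default for `--moe-backend` is an assumption, and `fp8_backend_from_args` is a hypothetical helper.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser with a per-dtype MoE backend override."""
    parser = argparse.ArgumentParser(prog="moe-backend-sketch")
    # Assumed default: let each oracle choose when nothing is requested.
    parser.add_argument("--moe-backend", default="auto")
    # Proposed in this thread; not an existing flag.
    parser.add_argument("--fp8-moe-backend", default=None)
    return parser


def fp8_backend_from_args(args: argparse.Namespace) -> str:
    """Use the FP8-specific flag when given, else fall back to --moe-backend."""
    if args.fp8_moe_backend is not None:
        return args.fp8_moe_backend
    return args.moe_backend


args = build_parser().parse_args(
    ["--moe-backend=flashinfer_b12x", "--fp8-moe-backend=flashinfer_cutlass"]
)
print(args.moe_backend, fp8_backend_from_args(args))
```

The fallback keeps today's single-flag behavior for anyone who does not pass the new argument.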

I'm curious why you need to specify --moe-backend in this case at all though, since I assume the b12x backend will be chosen by default as the most performant one on Spark. If not, we should fix the heuristics rather than rely on users to specify the right (increasingly complex) args

@ECMGit
Contributor Author

ECMGit commented May 26, 2026

> @ECMGit @xinli-sw could we solve this by introducing granularity to --moe-backend instead of adding new environment variables? For instance --fp8-moe-backend would be fine with me. Or expanding --moe-backend into a config so you could do --moe-backend.fp8
>
> I'm curious why you need to specify --moe-backend in this case at all though, since I assume the b12x backend will be chosen by default as the most performant one on Spark. If not, we should fix the heuristics rather than rely on users to specify the right (increasingly complex) args

Hi @mgoin thanks for reviewing it.

Let me explain the intention of this fix: it targets mixed-precision checkpoints that interleave NVFP4 and FP8 experts in the same model. --moe-backend is consumed by both the NVFP4 and FP8 dispatchers, so picking flashinfer_b12x for the FP4 majority crashes map_fp8_backend, because b12x is FP4-only by construction. A per-dtype override could be the minimal fix.

I can switch to --fp8-moe-backend as an optional arg if that is more surgical.

@mgoin
Member

mgoin commented May 26, 2026

@ECMGit I think you missed my second point - why do you need to set --moe-backend at all? If you don't explicitly set that argument, vLLM can already select different backends through the default dispatch within each oracle. The best case we aim for is to select the best kernel by default, so I think you should make it so b12x is selected by default on your system for your case if it is the best
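
The default-dispatch idea can be sketched as below: each oracle keeps its own preference order and, when no backend is requested, picks the first entry the platform supports. The helper `pick_backend`, the preference tuples, and the supported sets are all hypothetical placeholders, not vLLM's actual heuristics.

```python
from typing import Optional, Sequence, Set


def pick_backend(
    requested: Optional[str],
    preference: Sequence[str],
    supported: Set[str],
) -> str:
    """Hypothetical per-oracle selection: explicit request first, else best supported."""
    if requested is not None:
        if requested not in supported:
            raise ValueError(f"backend {requested!r} is not supported here")
        return requested
    for name in preference:
        if name in supported:
            return name
    raise ValueError("no supported backend found")


# Placeholder preference orders, one per oracle (illustrative only).
NVFP4_PREFERENCE = ("flashinfer_b12x", "flashinfer_cutlass", "marlin")
FP8_PREFERENCE = ("flashinfer_cutlass", "triton")

# Placeholder capability sets for an imaginary platform.
nvfp4_supported = {"flashinfer_b12x", "marlin"}
fp8_supported = {"flashinfer_cutlass", "triton"}

# With no explicit request, each oracle chooses independently.
print(pick_backend(None, NVFP4_PREFERENCE, nvfp4_supported))
print(pick_backend(None, FP8_PREFERENCE, fp8_supported))
```

Because each oracle resolves its own choice, an FP4-only backend never reaches the FP8 path unless the user forces it through a single shared flag.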

@ECMGit
Contributor Author

ECMGit commented May 28, 2026

Hi @mgoin, I think you are right. I re-ran and re-tested it on my end.
aiperf headline on nightly with only PR #43332 applied (no VLLM_FP8_MOE_BACKEND, no --moe-backend).
Verified:

  • FP4 oracle auto-picks FLASHINFER_B12X; FP8 Linear auto-binds FlashInferFP8ScaledMMLinearKernel.
  • 60/60 requests, OTT 93.99 tok/s (about 3.3% above the 2026-05-21 baseline of 91.00), TTFT 767 ms, MTP acceptance length 3.10–3.13.

PR #43334 is not needed for the Qwen3.6-35B mixed-precision checkpoint.
However, PR #43332 is needed to run this checkpoint.
I am happy to close #43334, or rescope it to use auto dispatch even when moe_backend is explicitly assigned.
