Skip to content

[MoE/b12x] Add cutlass_prefill_threshold hybrid dispatch to FlashInferB12xExperts#43687

Open
askliar wants to merge 2 commits into
vllm-project:mainfrom
askliar:askliar/b12x-cutlass-prefill-threshold
Open

[MoE/b12x] Add cutlass_prefill_threshold hybrid dispatch to FlashInferB12xExperts#43687
askliar wants to merge 2 commits into
vllm-project:mainfrom
askliar:askliar/b12x-cutlass-prefill-threshold

Conversation

@askliar
Copy link
Copy Markdown
Contributor

@askliar askliar commented May 26, 2026

Summary

Stacked on top of #43328.

Plumbs FlashInfer's cutlass_prefill_threshold kwarg through FlashInferB12xExperts, enabling hybrid dispatch where the wrapper routes batches with num_tokens >= threshold through cutlass_fused_moe (prefill path) and small batches through the b12x kernels (decode path).

Key changes:

  • New env var VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD (default 0, pure b12x dispatch).
  • When threshold > 0, process_weights_after_loading clones the CUTLASS-format swizzled FP8 SF + a/g alphas (with the g_alphas /= a_gs CUTLASS rescale) before the in-place B12x rewrite destroys them, and registers them as nn.Parameters on the layer so EPLB rearranges them in lockstep with the live b12x scales.
  • _ensure_wrapper passes the kwarg, gated on inspect.signature to remain compatible with older FlashInfer builds (silent skip when threshold is 0, hard error when >0 and the kwarg is missing), and calls register_cutlass_prefill_weights once.

The FP4 weight bytes are reusable between the b12x and CUTLASS paths — prepare_nvfp4_moe_layer_for_fi_or_cutlass produces the same [w3, w1] reorder + swizzled SF for both FLASHINFER_CUTLASS and FLASHINFER_B12X — so only the scales need to be cloned.

Duplicate-work check

Not a duplicate. No open PR mentions cutlass_prefill_threshold or hybrid B12x/CUTLASS dispatch. #43332 (W4A16 supports-check) and #43334 (FP8 MoE backend env) touch the same area but are unrelated features.

AI assistance disclosure

This PR was prepared with AI assistance (Claude Code). The submitting human reviewed every changed line.

Test plan

  • On SM120/121 with a FlashInfer build that exposes cutlass_prefill_threshold: run an NVFP4 MoE model with VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD=0 and confirm output matches the pre-PR baseline (silent kwarg skip path; pure b12x dispatch).
  • Same setup with VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD=128: confirm prefill batches route through cutlass_fused_moe and decode batches through b12x kernels, and outputs match a reference run (FP-close, quality-neutral).
  • With an older FlashInfer that lacks the kwarg and threshold =0: confirm normal operation (silent skip).
  • With an older FlashInfer that lacks the kwarg and threshold >0: confirm the explicit RuntimeError fires at wrapper construction.
  • EPLB rearrangement test: confirm the registered w*_cutlass_* parameters are permuted in lockstep with the b12x scales.

🤖 Generated with Claude Code

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 26, 2026

Hi @askliar, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 26, 2026

Hi @askliar, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Stacked on top of vllm-project#43328.

Plumbs FlashInfer's `cutlass_prefill_threshold` kwarg through
`FlashInferB12xExperts`, enabling hybrid dispatch where the wrapper
routes batches with `num_tokens >= threshold` through `cutlass_fused_moe`
(prefill path) and small batches through the b12x kernels (decode path).

Key changes:
- New env var `VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD` (default 0,
  pure b12x dispatch).
- When threshold > 0, `process_weights_after_loading` clones the
  CUTLASS-format swizzled FP8 SF + a/g alphas (with the `g_alphas/=a_gs`
  CUTLASS rescale) BEFORE the in-place B12x rewrite destroys them, and
  registers them as `nn.Parameter`s on the layer so EPLB rearranges them
  in lockstep with the live b12x scales.
- `_ensure_wrapper` passes the kwarg, gated on `inspect.signature` to
  remain compatible with older FlashInfer builds (silent skip when
  threshold is 0, hard error when >0 and the kwarg is missing), and
  calls `register_cutlass_prefill_weights` once.

The FP4 weight bytes are reusable between the b12x and CUTLASS paths —
`prepare_nvfp4_moe_layer_for_fi_or_cutlass` produces the same `[w3, w1]`
reorder + swizzled SF for both `FLASHINFER_CUTLASS` and `FLASHINFER_B12X`
— so only the scales need to be cloned.

This PR was prepared with AI assistance.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
@askliar askliar force-pushed the askliar/b12x-cutlass-prefill-threshold branch from 580e98c to f36ab87 Compare May 26, 2026 19:44
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 26, 2026

Hi @askliar, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Jun 4, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @askliar.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 4, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

1 participant