[MoE/b12x] Add cutlass_prefill_threshold hybrid dispatch to FlashInferB12xExperts#43687
Open
askliar wants to merge 2 commits into
Open
[MoE/b12x] Add cutlass_prefill_threshold hybrid dispatch to FlashInferB12xExperts#43687askliar wants to merge 2 commits into
askliar wants to merge 2 commits into
Conversation
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Contributor
|
Hi @askliar, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
Contributor
|
Hi @askliar, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
Stacked on top of vllm-project#43328. Plumbs FlashInfer's `cutlass_prefill_threshold` kwarg through `FlashInferB12xExperts`, enabling hybrid dispatch where the wrapper routes batches with `num_tokens >= threshold` through `cutlass_fused_moe` (prefill path) and small batches through the b12x kernels (decode path). Key changes: - New env var `VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD` (default 0, pure b12x dispatch). - When threshold > 0, `process_weights_after_loading` clones the CUTLASS-format swizzled FP8 SF + a/g alphas (with the `g_alphas/=a_gs` CUTLASS rescale) BEFORE the in-place B12x rewrite destroys them, and registers them as `nn.Parameter`s on the layer so EPLB rearranges them in lockstep with the live b12x scales. - `_ensure_wrapper` passes the kwarg, gated on `inspect.signature` to remain compatible with older FlashInfer builds (silent skip when threshold is 0, hard error when >0 and the kwarg is missing), and calls `register_cutlass_prefill_weights` once. The FP4 weight bytes are reusable between the b12x and CUTLASS paths — `prepare_nvfp4_moe_layer_for_fi_or_cutlass` produces the same `[w3, w1]` reorder + swizzled SF for both `FLASHINFER_CUTLASS` and `FLASHINFER_B12X` — so only the scales need to be cloned. This PR was prepared with AI assistance. Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Andrii Skliar <askliar@nvidia.com>
580e98c to
f36ab87
Compare
Contributor
|
Hi @askliar, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
Contributor
|
This pull request has merge conflicts that must be resolved before it can be |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Stacked on top of #43328.
Plumbs FlashInfer's
cutlass_prefill_thresholdkwarg throughFlashInferB12xExperts, enabling hybrid dispatch where the wrapper routes batches withnum_tokens >= thresholdthroughcutlass_fused_moe(prefill path) and small batches through the b12x kernels (decode path).Key changes:
VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD(default0, pure b12x dispatch).process_weights_after_loadingclones the CUTLASS-format swizzled FP8 SF + a/g alphas (with theg_alphas /= a_gsCUTLASS rescale) before the in-place B12x rewrite destroys them, and registers them asnn.Parameters on the layer so EPLB rearranges them in lockstep with the live b12x scales._ensure_wrapperpasses the kwarg, gated oninspect.signatureto remain compatible with older FlashInfer builds (silent skip when threshold is0, hard error when>0and the kwarg is missing), and callsregister_cutlass_prefill_weightsonce.The FP4 weight bytes are reusable between the b12x and CUTLASS paths —
prepare_nvfp4_moe_layer_for_fi_or_cutlassproduces the same[w3, w1]reorder + swizzled SF for bothFLASHINFER_CUTLASSandFLASHINFER_B12X— so only the scales need to be cloned.Duplicate-work check
Not a duplicate. No open PR mentions
cutlass_prefill_thresholdor hybrid B12x/CUTLASS dispatch. #43332 (W4A16 supports-check) and #43334 (FP8 MoE backend env) touch the same area but are unrelated features.AI assistance disclosure
This PR was prepared with AI assistance (Claude Code). The submitting human reviewed every changed line.
Test plan
cutlass_prefill_threshold: run an NVFP4 MoE model withVLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD=0and confirm output matches the pre-PR baseline (silent kwarg skip path; pure b12x dispatch).VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD=128: confirm prefill batches route throughcutlass_fused_moeand decode batches through b12x kernels, and outputs match a reference run (FP-close, quality-neutral).=0: confirm normal operation (silent skip).>0: confirm the explicitRuntimeErrorfires at wrapper construction.w*_cutlass_*parameters are permuted in lockstep with the b12x scales.🤖 Generated with Claude Code