Enable B12x backend for non-gated MoEs (like Nemotron) #43328
Conversation
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
There was a problem hiding this comment.
Code Review
This pull request updates the FlashInfer B12x MoE implementation to use the B12xMoEWrapper and adds support for the ReLU2 activation function. Key changes include lazy wrapper initialization, updated weight processing in tests, and support for non-gated MoE configurations. Review feedback identifies a missing Any import that would cause a runtime error and suggests optimizing the apply method by passing the output buffer directly to the wrapper's run call to avoid unnecessary tensor allocations.
| @@ -20,7 +20,6 @@ | |||
| ) | |||
| from vllm.platforms import current_platform | |||
There was a problem hiding this comment.
| result = self._wrapper.run( | ||
| x=hidden_states, | ||
| token_selected_experts=topk_ids.to(torch.int32), | ||
| token_final_scales=topk_weights, | ||
| w1_weight=w1, | ||
| w1_weight_sf=self.w1_sf_mma, | ||
| w1_alpha=self.g1_alphas, | ||
| fc2_input_scale=self.a2_gscale, | ||
| w2_weight=w2, | ||
| w2_weight_sf=self.w2_sf_mma, | ||
| w2_alpha=self.g2_alphas, | ||
| num_experts=global_num_experts, | ||
| top_k=top_k, | ||
| num_local_experts=self.num_local_experts, | ||
| output_dtype=self.out_dtype, | ||
| output=output, | ||
| token_selected_experts=topk_ids.to(torch.int32), | ||
| token_final_scales=topk_weights, | ||
| ) | ||
| output.copy_(result) |
There was a problem hiding this comment.
The B12xMoEWrapper.run method supports an out parameter, which allows the kernel to write results directly into the provided buffer. Using out=output avoids an extra tensor allocation inside run and a subsequent copy_ operation, which is significantly more efficient for the inference hot path.
| result = self._wrapper.run( | |
| x=hidden_states, | |
| token_selected_experts=topk_ids.to(torch.int32), | |
| token_final_scales=topk_weights, | |
| w1_weight=w1, | |
| w1_weight_sf=self.w1_sf_mma, | |
| w1_alpha=self.g1_alphas, | |
| fc2_input_scale=self.a2_gscale, | |
| w2_weight=w2, | |
| w2_weight_sf=self.w2_sf_mma, | |
| w2_alpha=self.g2_alphas, | |
| num_experts=global_num_experts, | |
| top_k=top_k, | |
| num_local_experts=self.num_local_experts, | |
| output_dtype=self.out_dtype, | |
| output=output, | |
| token_selected_experts=topk_ids.to(torch.int32), | |
| token_final_scales=topk_weights, | |
| ) | |
| output.copy_(result) | |
| self._wrapper.run( | |
| x=hidden_states, | |
| w1_weight=w1, | |
| w1_weight_sf=self.w1_sf_mma, | |
| w1_alpha=self.g1_alphas, | |
| fc2_input_scale=self.a2_gscale, | |
| w2_weight=w2, | |
| w2_weight_sf=self.w2_sf_mma, | |
| w2_alpha=self.g2_alphas, | |
| token_selected_experts=topk_ids.to(torch.int32), | |
| token_final_scales=topk_weights, | |
| out=output, | |
| ) |
|
Hi @askliar, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
Stacked on top of vllm-project#43328. Plumbs FlashInfer's `cutlass_prefill_threshold` kwarg through `FlashInferB12xExperts`, enabling hybrid dispatch where the wrapper routes batches with `num_tokens >= threshold` through `cutlass_fused_moe` (prefill path) and small batches through the b12x kernels (decode path). Key changes: - New env var `VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD` (default 0, pure b12x dispatch). - When threshold > 0, `process_weights_after_loading` clones the CUTLASS-format swizzled FP8 SF + a/g alphas (with the `g_alphas/=a_gs` CUTLASS rescale) BEFORE the in-place B12x rewrite destroys them, and registers them as `nn.Parameter`s on the layer so EPLB rearranges them in lockstep with the live b12x scales. - `_ensure_wrapper` passes the kwarg, gated on `inspect.signature` to remain compatible with older FlashInfer builds (silent skip when threshold is 0, hard error when >0 and the kwarg is missing), and calls `register_cutlass_prefill_weights` once. The FP4 weight bytes are reusable between the b12x and CUTLASS paths — `prepare_nvfp4_moe_layer_for_fi_or_cutlass` produces the same `[w3, w1]` reorder + swizzled SF for both `FLASHINFER_CUTLASS` and `FLASHINFER_B12X` — so only the scales need to be cloned. This PR was prepared with AI assistance. Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Andrii Skliar <askliar@nvidia.com>
|
This pull request has merge conflicts that must be resolved before it can be |
Summary
Stacked on top of #40082.
This PR refines the FlashInfer B12x MoE integration by switching the SM12x MoE path to FlashInfer's
B12xMoEWrapperAPI and adding ReLU2 / non-gated MoE coverage.Key changes:
B12xMoEWrapperforFlashInferB12xExpertsis_act_and_mulDuplicate-work check
This is intentionally not a duplicate of #40082: it is an incremental stacked change on top of #40082.