Enable B12x backend for non-gated MoEs (like Nemotron) #43328

Open
askliar wants to merge 1 commit into vllm-project:main from askliar:askliar/b12x-wrapper-on-pr40082

askliar commented May 21, 2026

Summary

Stacked on top of #40082.

This PR refines the FlashInfer B12x MoE integration by switching the SM12x MoE path to FlashInfer's B12xMoEWrapper API and adding ReLU2 / non-gated MoE coverage.

Key changes:

  • Use B12xMoEWrapper for FlashInferB12xExperts
  • Keep BF16 hidden states as unquantized inputs; B12x handles FP4 activation quantization internally
  • Support both SiLU gated MoE and ReLU2 non-gated MoE
  • Add ReLU2 test coverage
  • Update the test helper to allow non-default activation and is_act_and_mul
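
The two activation modes listed above can be illustrated with a small pure-Python sketch. It assumes the common conventions: a gated layer (`is_act_and_mul=True`) splits the first projection's output into a gate half and an up half and returns `silu(gate) * up`, while a non-gated ReLU2 layer applies `relu(x) ** 2` elementwise. The function names and the ordering of the two halves are assumptions made for illustration, not taken from vLLM or FlashInfer.

```python
import math


def silu(x: float) -> float:
    """SiLU (swish): x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def silu_and_mul(fc1_out: list[float]) -> list[float]:
    """Gated activation: output width is half the input width.

    Assumption for this sketch: the first half is the gate and the second
    half is the up projection. Real kernels may use the opposite layout.
    """
    if len(fc1_out) % 2:
        raise ValueError("gated activation needs an even-sized input")
    half = len(fc1_out) // 2
    gate, up = fc1_out[:half], fc1_out[half:]
    return [silu(g) * u for g, u in zip(gate, up)]


def relu2(fc1_out: list[float]) -> list[float]:
    """Non-gated activation: squared ReLU, output width equals input width."""
    return [max(v, 0.0) ** 2 for v in fc1_out]
```

One practical consequence is that the first projection of a non-gated layer only needs to produce the intermediate width once, whereas a gated layer produces it twice.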

Duplicate-work check

This is not a duplicate of #40082; it is an incremental change stacked on top of it.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>

gemini-code-assist Bot left a comment

Code Review

This pull request updates the FlashInfer B12x MoE implementation to use the B12xMoEWrapper and adds support for the ReLU2 activation function. Key changes include lazy wrapper initialization, updated weight processing in tests, and support for non-gated MoE configurations. Review feedback identifies a missing Any import that would cause a runtime error and suggests optimizing the apply method by passing the output buffer directly to the wrapper's run call to avoid unnecessary tensor allocations.
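
The lazy wrapper initialization mentioned in the summary can be sketched roughly as follows. This is a generic pattern under stated assumptions, not the PR's code: `LazyExperts`, `make_wrapper`, and the toy `apply` signature are hypothetical, and only the idea of building the wrapper on first use is taken from the discussion.

```python
from typing import Any, Callable, Optional


class LazyExperts:
    """Hypothetical experts class that builds its kernel wrapper on first use."""

    def __init__(self, make_wrapper: Callable[[], Any]) -> None:
        self._make_wrapper = make_wrapper
        # Nothing is built at construction time.
        self._wrapper: Optional[Any] = None

    def _ensure_wrapper(self) -> Any:
        # Build once, then reuse the cached wrapper on every later call.
        if self._wrapper is None:
            self._wrapper = self._make_wrapper()
        return self._wrapper

    def apply(self, x: int) -> int:
        return self._ensure_wrapper()(x)


calls: list[str] = []


def make_wrapper() -> Callable[[int], int]:
    """Toy factory that records how often it is invoked."""
    calls.append("built")
    return lambda x: x + 1


experts = LazyExperts(make_wrapper)
first = experts.apply(1)
second = experts.apply(2)
```

After both calls the factory has run exactly once, which is the property that makes the pattern attractive when construction is expensive or depends on state that is only known at the first forward pass.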

@@ -20,7 +20,6 @@
)
from vllm.platforms import current_platform

high

The Any type hint is used in the __init__ method (line 82), but it has not been imported from the typing module. This will cause a NameError at runtime when the class is instantiated. Please add the missing import.

from typing import Any
from vllm.platforms import current_platform

Comment on lines +254 to +266
result = self._wrapper.run(
x=hidden_states,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights,
w1_weight=w1,
w1_weight_sf=self.w1_sf_mma,
w1_alpha=self.g1_alphas,
fc2_input_scale=self.a2_gscale,
w2_weight=w2,
w2_weight_sf=self.w2_sf_mma,
w2_alpha=self.g2_alphas,
num_experts=global_num_experts,
top_k=top_k,
num_local_experts=self.num_local_experts,
output_dtype=self.out_dtype,
output=output,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights,
)
output.copy_(result)

high

The B12xMoEWrapper.run method supports an out parameter, which allows the kernel to write results directly into the provided buffer. Using out=output avoids an extra tensor allocation inside run and a subsequent copy_ operation, which is significantly more efficient for the inference hot path.

Suggested change
self._wrapper.run(
x=hidden_states,
w1_weight=w1,
w1_weight_sf=self.w1_sf_mma,
w1_alpha=self.g1_alphas,
fc2_input_scale=self.a2_gscale,
w2_weight=w2,
w2_weight_sf=self.w2_sf_mma,
w2_alpha=self.g2_alphas,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights,
out=output,
)
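
The trade-off behind this suggestion can be shown with a toy stand-in. The `run` function below is hypothetical and only mimics the shape of an API with an optional `out` buffer; whether `B12xMoEWrapper.run` behaves this way is the reviewer's claim and is not verified here.

```python
from typing import Optional


def run(x: list[float], out: Optional[list[float]] = None) -> list[float]:
    """Hypothetical stand-in for a kernel launch: doubles every element.

    If `out` is given, results are written into it in place and it is
    returned; otherwise a fresh list is allocated.
    """
    if out is None:
        out = [0.0] * len(x)  # extra allocation on every call
    if len(out) != len(x):
        raise ValueError("out buffer has the wrong size")
    for i, v in enumerate(x):
        out[i] = 2.0 * v
    return out


x = [1.0, 2.0, 3.0]
output = [0.0] * len(x)  # caller-owned buffer

# Pattern flagged in the review: allocate inside run(), then copy.
result = run(x)
output[:] = result

# Suggested pattern: write straight into the caller's buffer.
returned = run(x, out=output)
```

In the second form the returned object is the caller's buffer itself, so no temporary is created and no copy is needed.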

mergify Bot commented May 21, 2026

Hi @askliar, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

askliar pushed a commit to askliar/vllm that referenced this pull request May 26, 2026
Stacked on top of vllm-project#43328.

Plumbs FlashInfer's `cutlass_prefill_threshold` kwarg through
`FlashInferB12xExperts`, enabling hybrid dispatch where the wrapper
routes batches with `num_tokens >= threshold` through `cutlass_fused_moe`
(prefill path) and small batches through the b12x kernels (decode path).
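
The dispatch rule described above can be written as a short sketch. The function name and the string labels are illustrative, and the explicit `> 0` guard is an assumption about how a threshold of 0 keeps everything on the b12x path; the actual routing happens inside the FlashInfer wrapper.

```python
def select_moe_path(num_tokens: int, cutlass_prefill_threshold: int) -> str:
    """Return which kernel family a batch would use under the rule above.

    A threshold of 0 (or less) is treated as "disabled": every batch goes
    to the b12x kernels.
    """
    if cutlass_prefill_threshold > 0 and num_tokens >= cutlass_prefill_threshold:
        return "cutlass_fused_moe"  # large batches: prefill path
    return "b12x"  # small batches: decode path
```

With a threshold of 256, for example, a batch of 8 tokens stays on b12x while a batch of 256 or more tokens is routed to `cutlass_fused_moe`.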

Key changes:
- New env var `VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD` (default 0,
  pure b12x dispatch).
- When threshold > 0, `process_weights_after_loading` clones the
  CUTLASS-format swizzled FP8 SF + a/g alphas (with the `g_alphas/=a_gs`
  CUTLASS rescale) BEFORE the in-place B12x rewrite destroys them, and
  registers them as `nn.Parameter`s on the layer so EPLB rearranges them
  in lockstep with the live b12x scales.
- `_ensure_wrapper` passes the kwarg, gated on `inspect.signature` to
  remain compatible with older FlashInfer builds (silent skip when
  threshold is 0, hard error when >0 and the kwarg is missing), and
  calls `register_cutlass_prefill_weights` once.
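
The `inspect.signature` gate described in the last item might look roughly like this. It is a minimal sketch, not the commit's code: `build_wrapper_kwargs`, `OldWrapper`, and `NewWrapper` are hypothetical stand-ins, and only the kwarg name `cutlass_prefill_threshold` and the skip-or-fail behaviour come from the description above.

```python
import inspect
from typing import Any, Callable


def build_wrapper_kwargs(wrapper_ctor: Callable[..., Any], threshold: int) -> dict[str, Any]:
    """Return the extra constructor kwargs, depending on what the build accepts."""
    params = inspect.signature(wrapper_ctor).parameters
    supported = "cutlass_prefill_threshold" in params
    if threshold <= 0:
        # Feature disabled: pass nothing, so older builds keep working.
        return {}
    if not supported:
        # Feature requested but unavailable: fail loudly instead of ignoring it.
        raise RuntimeError(
            "cutlass_prefill_threshold > 0 requires a build whose wrapper "
            "accepts the 'cutlass_prefill_threshold' kwarg"
        )
    return {"cutlass_prefill_threshold": threshold}


class OldWrapper:
    """Hypothetical older wrapper without the kwarg."""

    def __init__(self, num_experts: int) -> None:
        self.num_experts = num_experts


class NewWrapper:
    """Hypothetical newer wrapper that accepts the kwarg."""

    def __init__(self, num_experts: int, cutlass_prefill_threshold: int = 0) -> None:
        self.num_experts = num_experts
        self.cutlass_prefill_threshold = cutlass_prefill_threshold


wrapper = NewWrapper(num_experts=8, **build_wrapper_kwargs(NewWrapper, 128))
```

Checking the signature rather than a version string keeps the gate tied to the capability itself, although it would not detect a constructor that accepts the kwarg only through `**kwargs`.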

The FP4 weight bytes are reusable between the b12x and CUTLASS paths —
`prepare_nvfp4_moe_layer_for_fi_or_cutlass` produces the same `[w3, w1]`
reorder + swizzled SF for both `FLASHINFER_CUTLASS` and `FLASHINFER_B12X`
— so only the scales need to be cloned.

This PR was prepared with AI assistance.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Andrii Skliar <askliar@nvidia.com>

mergify Bot commented Jun 4, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @askliar.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify Bot added the needs-rebase label Jun 4, 2026