[B12x] W4A16 NVFP4 support + Nemotron-3.5 / Qwen3.5 fixes #43333

Open
askliar wants to merge 4 commits into vllm-project:main from askliar:feat/add_b12x_w4a16_support

askliar commented May 21, 2026

Contributor

Summary

To be merged after #43328

Adds W4A16 (NVFP4) to the SM12x FlashInfer B12x MoE path, plus model-side fixes for running Nemotron-H 3.5 and Qwen3.5-MoE under modelopt and compressed-tensors checkpoints. Follow-up to #40082 (W4A4 SM12x). Bumps FlashInfer to 0.6.11.post3.

Changes

FlashInferB12xExperts

  • W4A16 path via activation_precision (auto-detected from quant_config.a1_gscale); accepts compressed-tensors NVFP4 key shape.
  • source_format forwarded to B12xMoEWrapper (via call-stack inspection — TODO: plumb through FusedMoEQuantConfig).
  • MMA-layout views cached in process_weights_after_loading.

Modelopt / LM-head wiring

  • ModelOptMixedPrecisionConfig.get_quant_method: handles Qwen VLM nested-prefix LM heads and the language_model.model. / model.language_model. prefix swap; routes ParallelLMHead through FP8/NVFP4 methods.
  • Nemotron-H / Qwen3.5 *ForCausalLM + MTP pass quant_config to ParallelLMHead.
  • prepare_fp4_layer_for_marlin falls back to torch.get_default_dtype() when params_dtype is absent.
  • VocabParallelEmbedding.weight_loader reshapes scalar FP4 scales instead of asserting.

MTP × compressed-tensors

  • Extend compressed_tensors_config.ignore with per-expert MTP linears (BF16 in released checkpoints).

Misc

  • prepare_nvfp4_moe_layer_for_fi_or_cutlass updates intermediate_size_per_partition after padding.

Test plan

  • pytest tests/kernels/moe/test_flashinfer_b12x_moe.py -v on SM120 / SM121.
  • E2E: Nemotron-H 3.5 (W4A4 + W4A16), Qwen3.5-MoE VLM (modelopt).

Andrii Skliar added 3 commits May 21, 2026 17:40

gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request updates FlashInfer to version 0.6.11.post3 and refactors the B12x MoE implementation to utilize the new B12xMoEWrapper, introducing support for ReLU2 activation and W4A16 quantization schemes. It also enhances ModelOpt quantization for Qwen and Nemotron models by improving LM head handling and excluding specific MTP layers from compression. Review feedback identifies several critical issues: the VLLM_FLASHINFER_B12X_ACTIVATION_PRECISION environment variable is currently ignored, the use of inspect.stack() for source format detection is considered fragile, and the copy_ operation in the MoE apply method introduces a memory overhead and performance regression. Furthermore, the reliance on torch.get_default_dtype() in Marlin utilities could cause assertion failures if the default is set to float32.

Comment on lines +83 to +85
        self.activation_precision = (
            "fp4" if quant_config.a1_gscale is not None else "bf16"
        )

high

The environment variable VLLM_FLASHINFER_B12X_ACTIVATION_PRECISION (defined in vllm/envs.py) is currently ignored in this implementation. The activation_precision is hardcoded based on the presence of a1_gscale, which prevents users from overriding this setting for debugging or performance tuning. The logic should be updated to respect the environment variable while falling back to safe defaults.
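
A minimal sketch of the precedence the comment above asks for, under stated assumptions: the helper name resolve_activation_precision, the accepted values ("fp4", "bf16", and "auto" or unset for auto-detection), and the silent fallback to bf16 when no activation scale exists are all hypothetical and are not taken from vllm/envs.py.

```python
import os
from typing import Mapping, Optional

ENV_VAR = "VLLM_FLASHINFER_B12X_ACTIVATION_PRECISION"
# Assumed set of legal override values; the real set is whatever vllm/envs.py defines.
_VALID = ("fp4", "bf16")


def resolve_activation_precision(
    a1_gscale: Optional[object],
    env: Mapping[str, str] = os.environ,
) -> str:
    """Return "fp4" or "bf16", letting the environment override auto-detection."""
    # Auto-detection mirrors the PR: an activation global scale implies FP4 activations.
    detected = "fp4" if a1_gscale is not None else "bf16"

    override = env.get(ENV_VAR, "").strip().lower()
    if override in ("", "auto"):
        return detected
    if override not in _VALID:
        raise ValueError(f"{ENV_VAR} must be one of {_VALID} or 'auto', got {override!r}")
    if override == "fp4" and a1_gscale is None:
        # Assumption: FP4 activations cannot be used without a scale, so fall back.
        return "bf16"
    return override
```

Passing the environment as a mapping keeps the helper testable without mutating os.environ; the constructor would then assign self.activation_precision from its return value.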

Comment on lines +96 to +116
def _detect_source_format() -> str:
    """Walk the constructor's call stack to find the parent quant-method
    class and map it to a FlashInfer ``source_format`` string.

    ``make_nvfp4_moe_kernel`` instantiates the experts class from the
    parent method's ``create_weights`` (compressed-tensors) or equivalent
    (modelopt) — so the parent ``self`` is reachable in an outer frame.
    Fall back to "modelopt" if no recognized parent is found.
    """
    import inspect

    for frame_info in inspect.stack():
        parent = frame_info.frame.f_locals.get("self")
        if parent is None:
            continue
        cls_name = type(parent).__name__
        if "CompressedTensors" in cls_name:
            return "compressed_tensors"
        if "ModelOpt" in cls_name:
            return "modelopt"
    return "modelopt"

high

Using inspect.stack() to determine the source format is fragile and introduces significant maintainability overhead. It relies on the internal call stack structure, which can easily break if the instantiation logic is refactored or wrapped. This information should be passed explicitly through the constructor or via the quant_config. Given the existing TODO in the PR description, this should be prioritized to avoid technical debt.
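
The explicit alternative could look roughly like the sketch below. Every class here (MoEQuantInfo, ExpertsSketch, and the two *MethodSketch classes) is a hypothetical stand-in invented for illustration; the actual change would add a field to FusedMoEQuantConfig or a constructor argument on FlashInferB12xExperts.

```python
from dataclasses import dataclass

_SOURCE_FORMATS = ("modelopt", "compressed_tensors")


@dataclass(frozen=True)
class MoEQuantInfo:
    """Hypothetical stand-in for the field a quant config would carry."""

    source_format: str = "modelopt"


class ExpertsSketch:
    """Hypothetical experts class that is told its format instead of inferring it."""

    def __init__(self, quant_info: MoEQuantInfo) -> None:
        if quant_info.source_format not in _SOURCE_FORMATS:
            raise ValueError(f"unknown source_format: {quant_info.source_format!r}")
        self.source_format = quant_info.source_format


class CompressedTensorsMethodSketch:
    """Hypothetical quant method: it knows its own format and states it."""

    def make_experts(self) -> ExpertsSketch:
        return ExpertsSketch(MoEQuantInfo(source_format="compressed_tensors"))


class ModelOptMethodSketch:
    """Hypothetical quant method for modelopt checkpoints."""

    def make_experts(self) -> ExpertsSketch:
        return ExpertsSketch(MoEQuantInfo(source_format="modelopt"))
```

Because the value travels with the config, wrapping or refactoring the instantiation path does not change the result, and an unrecognized format fails loudly instead of defaulting to "modelopt".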

            token_selected_experts=topk_ids.to(torch.int32),
            token_final_scales=topk_weights,
        )
        output.copy_(result)

high

This copy_ operation introduces an extra memory copy and a new tensor allocation for every forward pass, which is a performance regression compared to the previous functional API that accepted an output buffer directly. If B12xMoEWrapper.run does not support an out parameter, it is highly recommended to update the FlashInfer wrapper to support in-place operations to maintain optimal performance for MoE models.

Comment on lines +222 to +223
    if param_dtype is None:
        param_dtype = torch.get_default_dtype()

high

Relying on torch.get_default_dtype() as a fallback is risky here. If the default dtype is torch.float32 (the standard PyTorch default), the subsequent call to nvfp4_marlin_process_global_scale will trigger an assertion failure (line 137), as it only supports half and bfloat16. It is safer to attempt to retrieve the dtype from the layer's weights or use a more appropriate fallback that aligns with the supported dtypes of the Marlin kernel.
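
A safer fallback order might be sketched as follows. The helper pick_marlin_dtype is hypothetical, and plain strings stand in for torch dtypes so the sketch runs without PyTorch; in the real code the supported set would be torch.float16 and torch.bfloat16.

```python
from typing import Hashable, Optional, Sequence


def pick_marlin_dtype(
    params_dtype: Optional[Hashable],
    weight_dtype: Optional[Hashable],
    supported: Sequence[Hashable],
    fallback: Hashable,
) -> Hashable:
    """Choose a dtype the kernel supports rather than trusting a global default."""
    if fallback not in supported:
        raise ValueError("fallback must itself be a supported dtype")
    if params_dtype is not None:
        # An explicit but unsupported dtype is a caller error, not a reason to guess.
        if params_dtype not in supported:
            raise ValueError(f"unsupported params_dtype: {params_dtype!r}")
        return params_dtype
    # Packed FP4 weights are typically stored as integers, so this often misses.
    if weight_dtype is not None and weight_dtype in supported:
        return weight_dtype
    return fallback


SUPPORTED = ("float16", "bfloat16")  # stand-ins for torch.float16 / torch.bfloat16
chosen = pick_marlin_dtype(None, "uint8", SUPPORTED, fallback="bfloat16")
```

With this ordering a float32 process default never reaches the assertion in nvfp4_marlin_process_global_scale, because the default is not consulted at all.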

mergify Bot commented May 21, 2026

Contributor

Hi @askliar, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

mergify Bot commented May 26, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @askliar.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
