[B12x] W4A16 NVFP4 support + Nemotron-3.5 / Qwen3.5 fixes#43333
[B12x] W4A16 NVFP4 support + Nemotron-3.5 / Qwen3.5 fixes#43333askliar wants to merge 4 commits into
Conversation
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
There was a problem hiding this comment.
Code Review
This pull request updates FlashInfer to version 0.6.11.post3 and refactors the B12x MoE implementation to utilize the new B12xMoEWrapper, introducing support for ReLU2 activation and W4A16 quantization schemes. It also enhances ModelOpt quantization for Qwen and Nemotron models by improving LM head handling and excluding specific MTP layers from compression. Review feedback identifies several critical issues: the VLLM_FLASHINFER_B12X_ACTIVATION_PRECISION environment variable is currently ignored, the use of inspect.stack() for source format detection is considered fragile, and the copy_ operation in the MoE apply method introduces a memory overhead and performance regression. Furthermore, the reliance on torch.get_default_dtype() in Marlin utilities could cause assertion failures if the default is set to float32.
| self.activation_precision = ( | ||
| "fp4" if quant_config.a1_gscale is not None else "bf16" | ||
| ) |
There was a problem hiding this comment.
The environment variable VLLM_FLASHINFER_B12X_ACTIVATION_PRECISION (defined in vllm/envs.py) is currently ignored in this implementation. The activation_precision is hardcoded based on the presence of a1_gscale, which prevents users from overriding this setting for debugging or performance tuning. The logic should be updated to respect the environment variable while falling back to safe defaults.
| def _detect_source_format() -> str: | ||
| """Walk the constructor's call stack to find the parent quant-method | ||
| class and map it to a FlashInfer ``source_format`` string. | ||
|
|
||
| ``make_nvfp4_moe_kernel`` instantiates the experts class from the | ||
| parent method's ``create_weights`` (compressed-tensors) or equivalent | ||
| (modelopt) — so the parent ``self`` is reachable in an outer frame. | ||
| Fall back to "modelopt" if no recognized parent is found. | ||
| """ | ||
| import inspect | ||
|
|
||
| for frame_info in inspect.stack(): | ||
| parent = frame_info.frame.f_locals.get("self") | ||
| if parent is None: | ||
| continue | ||
| cls_name = type(parent).__name__ | ||
| if "CompressedTensors" in cls_name: | ||
| return "compressed_tensors" | ||
| if "ModelOpt" in cls_name: | ||
| return "modelopt" | ||
| return "modelopt" |
There was a problem hiding this comment.
Using inspect.stack() to determine the source format is fragile and introduces significant maintainability overhead. It relies on the internal call stack structure, which can easily break if the instantiation logic is refactored or wrapped. This information should be passed explicitly through the constructor or via the quant_config. Given the existing TODO in the PR description, this should be prioritized to avoid technical debt.
| token_selected_experts=topk_ids.to(torch.int32), | ||
| token_final_scales=topk_weights, | ||
| ) | ||
| output.copy_(result) |
There was a problem hiding this comment.
This copy_ operation introduces an extra memory copy and a new tensor allocation for every forward pass, which is a performance regression compared to the previous functional API that accepted an output buffer directly. If B12xMoEWrapper.run does not support an out parameter, it is highly recommended to update the FlashInfer wrapper to support in-place operations to maintain optimal performance for MoE models.
| if param_dtype is None: | ||
| param_dtype = torch.get_default_dtype() |
There was a problem hiding this comment.
Relying on torch.get_default_dtype() as a fallback is risky here. If the default dtype is torch.float32 (the standard PyTorch default), the subsequent call to nvfp4_marlin_process_global_scale will trigger an assertion failure (line 137), as it only supports half and bfloat16. It is safer to attempt to retrieve the dtype from the layer's weights or use a more appropriate fallback that aligns with the supported dtypes of the Marlin kernel.
…bles in `envs.py`
|
Hi @askliar, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
|
This pull request has merge conflicts that must be resolved before it can be |
Summary
To be merged after #43328
Adds W4A16 (NVFP4) to the SM12x FlashInfer B12x MoE path, plus model-side fixes for running Nemotron-H 3.5 and Qwen3.5-MoE under modelopt and compressed-tensors checkpoints. Follow-up to #40082 (W4A4 SM12x). Bumps FlashInfer to
0.6.11.post3.Changes
FlashInferB12xExpertsactivation_precision(auto-detected fromquant_config.a1_gscale); accepts compressed-tensors NVFP4 key shape.source_formatforwarded toB12xMoEWrapper(via call-stack inspection — TODO: plumb throughFusedMoEQuantConfig).process_weights_after_loading.Modelopt / LM-head wiring
ModelOptMixedPrecisionConfig.get_quant_method: handles Qwen VLM nested-prefix LM heads andlanguage_model.model.↔model.language_model.swap; routesParallelLMHeadthrough FP8/NVFP4 methods.*ForCausalLM+ MTP passquant_configtoParallelLMHead.prepare_fp4_layer_for_marlinfalls back totorch.get_default_dtype()whenparams_dtypeis absent.VocabParallelEmbedding.weight_loaderreshapes scalar FP4 scales instead of asserting.MTP × compressed-tensors
compressed_tensors_config.ignorewith per-expert MTP linears (BF16 in released checkpoints).Misc
prepare_nvfp4_moe_layer_for_fi_or_cutlassupdatesintermediate_size_per_partitionafter padding.Test plan
pytest tests/kernels/moe/test_flashinfer_b12x_moe.py -von SM120 / SM121.