[MoE/b12x] Accept W4A16 (kNvfp4Static, None) in FlashInferB12xExperts supports check #43332
Code Review
This pull request updates the flashinfer_b12x_moe expert to support W4A16 NVFP4 checkpoints by allowing None as an activation quantization key in the _supports_quant_scheme method. While this enables the dispatcher to select the backend, feedback points out that the apply method contains a strict assertion requiring an activation scale (a2_gscale), which will likely cause an AssertionError for these models. It is recommended to update the apply method to provide a default scale tensor instead of asserting.
|
Add @meena-at-work and @juhi10071998 as reviewers.
|
Cross-referencing #43341 (@askliar), which also touches
So this PR covers the modelopt-native W4A16 family (e.g. The only mechanical conflict is in the body of

One important note for landing order

The This matters because compressed-tensors So even if #43341 lands first, this PR's Change 2 is still needed end-to-end for the compressed-tensors W4A16 path to actually serve requests.

cc @askliar, happy to coordinate landing order, or if you'd prefer I can rebase Change 2 onto your |
… supports check

`FlashInferB12xExperts._supports_quant_scheme` currently requires the activation key to be `kNvfp4Dynamic`, which makes the dispatcher reject every W4A16 NVFP4 checkpoint (activation_key == None) -- e.g. `nvidia/Qwen3.6-35B-A3B-2.06GB-per-token`. This forces such checkpoints onto Marlin even though the b12x kernel itself is W4A16-compatible.

Per the class docstring: "Input quantization (BF16->FP4) is performed inside the kernel so BF16 hidden states are passed directly." -- i.e. the kernel already handles the BF16-activation case correctly. This change only loosens the metadata gate; no kernel-side changes.

PR vllm-project#42566 ("W4A16 NVFP4 fused MoE + mixed-precision dispatch") only touches quantization/modelopt.py and acknowledges the gap in its own commit message: "their _supports_quant_scheme requires (kNvfp4Static, kNvfp4Dynamic) exactly... only Marlin survives." That PR deliberately routes W4A16 to Marlin as a workaround; this PR is the fix on the b12x side. The two are complementary and can land independently -- once both land, W4A16 NVFP4 prefers b12x (fast path).

Failure mode without this PR:

    ValueError: NvFp4 MoE backend 'FLASHINFER_B12X' does not support the deployment configuration since kernel does not support quantization scheme QuantKey(u8, scale(f8e4m3fn, static, GroupShape(row=1, col=16)), scale2(f32, static, per_tensor), symmetric) x None.

Tested on DGX Spark (GB10, sm_121a) with vllm/vllm-openai:nightly-aarch64 + this PR + the FP8-backend-env companion PR. Model: nvidia/Qwen3.6-35B-A3B-2.06GB-per-token (modelopt-native, mixed NVFP4 + FP8 experts).

aiperf K=3 AL=3.12, BS=1, ISL=2048+32K prefix=34,831, OSL=1024, 60 measured + 10 warmup, 0 errors:

    Output Token Throughput        : 91.00 tok/s
    Output Token Throughput / user : 97.42 tok/s/user
    TTFT                           : 746.81 ms
    ITL                            : 10.27 ms
    Request Latency                : 11,249.37 ms
    MTP acceptance length          : 3.15 (target 3.12)

For reference on the same workload:

    Marlin (current W4A16 fallback)  : OTT 92.26, TTFT 798.72
    b12x on dgx-fork (matched cubin) : OTT 95.15, TTFT 758.90

Without this change b12x rejects the checkpoint at engine init; with it b12x runs and matches/exceeds Marlin on the b12x-fast path.

Signed-off-by: Junhao Shen <junshen@nvidia.com>
…hout activation quant metadata

Addresses review feedback on the preceding commit (supports-check loosening for W4A16).

`FlashInferB12xExperts.apply` previously asserted `self.a2_gscale is not None` unconditionally. For W4A16 NVFP4 checkpoints lacking static activation-quant metadata (e.g. compressed-tensors W4A16-CT layouts), `a2_gscale` is legitimately None and the assert fires at the first forward pass -- strictly worse than the engine-init rejection we just removed at the dispatcher gate.

`process_weights_after_loading` already tolerates `a2_gscale is None` (the `if self.a2_gscale is not None: ...` guard at the top of this file), so the assert is the inconsistency. The b12x kernel performs dynamic per-block FC2-input quantization internally, so a uniform 1.0 scale per expert is semantically equivalent to the bake-in done for static-quant checkpoints. Construct the default in apply() instead of asserting.

Signed-off-by: Junhao Shen <junshen@nvidia.com>
8b7c0a0 to bcc299b
    # quant checkpoints.
    fc2_input_scale = self.a2_gscale
    if fc2_input_scale is None:
        fc2_input_scale = torch.ones(
This is `apply`, which is on the hot path. I think it is better to do any allocation somewhere earlier.
Okay, updated: the allocation is now routed to an earlier stage.
|
@claude review pls |
|
Hi @ECMGit, the pre-commit checks have failed. Please run:

    uv pip install "pre-commit>=4.5.1"
    pre-commit install
    pre-commit run --all-files

Then, commit the changes and push to your branch.
|
384f8f1 to 5d23ea5
|
Add `self._fc2_input_scale`, since `self.a2_gscale` is a read-only attribute.

Signed-off-by: Junhao Shen <junshen@nvidia.com>
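The pattern discussed in this review thread (build the default FC2 input scale once, outside the hot path, and store it in a separate attribute because `a2_gscale` cannot be assigned) can be sketched roughly as follows. This is an illustration only, not the code merged in this PR: `ExpertsSketch` is a hypothetical stand-in class, plain Python lists stand in for tensors, and using `process_weights_after_loading` as the early stage is an assumption, since the thread does not show where the allocation was moved.

```python
from typing import List, Optional


class ExpertsSketch:
    """Hypothetical stand-in for an experts class; not the vLLM implementation."""

    def __init__(self, num_experts: int, a2_gscale: Optional[List[float]]):
        self.num_experts = num_experts
        self._a2_gscale = a2_gscale
        # Separate cache attribute, filled in once during setup.
        self._fc2_input_scale: Optional[List[float]] = None

    @property
    def a2_gscale(self) -> Optional[List[float]]:
        # Read-only: there is no setter, so a default cannot be written back here.
        return self._a2_gscale

    def process_weights_after_loading(self) -> None:
        # Assumed early stage: runs once, so allocating here is cheap.
        if self.a2_gscale is not None:
            self._fc2_input_scale = self.a2_gscale
        else:
            # W4A16 case: no activation-quant metadata, use 1.0 per expert.
            self._fc2_input_scale = [1.0] * self.num_experts

    def apply(self) -> List[float]:
        # Hot path: reuse the cached scale, no allocation and no a2_gscale assert.
        assert self._fc2_input_scale is not None, "run process_weights_after_loading first"
        return self._fc2_input_scale


w4a16 = ExpertsSketch(num_experts=4, a2_gscale=None)
w4a16.process_weights_after_loading()

static = ExpertsSketch(num_experts=2, a2_gscale=[0.5, 0.25])
static.process_weights_after_loading()
```

In this sketch every call to `apply()` returns the same cached object, so the forward pass does not pay for a fresh allocation, and the static-quant case keeps using the checkpoint-provided scale unchanged.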
5d23ea5 to 0114e1d
|
Hi @vadiklyutiy, pre-commit has passed. Let me know if any further changes are needed.
… supports check (vllm-project#43332) Signed-off-by: Junhao Shen <junshen@nvidia.com> Co-authored-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
… supports check (vllm-project#43332) Signed-off-by: Junhao Shen <junshen@nvidia.com> Co-authored-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
… supports check (vllm-project#43332) Signed-off-by: Junhao Shen <junshen@nvidia.com> Co-authored-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Signed-off-by: JisoLya <523420504@qq.com>
Purpose
`FlashInferB12xExperts._supports_quant_scheme` (introduced by PR #40082) currently requires the activation key to be `kNvfp4Dynamic`, which makes the dispatcher reject every W4A16 NVFP4 checkpoint (`activation_key == None`), e.g. `nvidia/Qwen3.6-35B-A3B-2.06GB-per-token`. This forces such checkpoints onto Marlin, even though the b12x kernel itself is W4A16-compatible.

PR #42566 ("W4A16 NVFP4 fused MoE + mixed-precision dispatch") acknowledges this exact gap in its own commit message ("their `_supports_quant_scheme` requires `(kNvfp4Static, kNvfp4Dynamic)` exactly... only Marlin survives") and deliberately routes W4A16 to Marlin as a workaround. This PR is the actual fix on the b12x side, complementary to #42566 and independent of its merge.

Per the b12x class docstring: "Input quantization (BF16→FP4) is performed inside the kernel so BF16 hidden states are passed directly." That is, the kernel already handles the BF16-activation case correctly. This PR only loosens the metadata gate; no kernel-side changes needed.
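As a rough illustration of what loosening the metadata gate means, the sketch below accepts both the existing `(kNvfp4Static, kNvfp4Dynamic)` pair and the W4A16 pair `(kNvfp4Static, None)`. It is not the vLLM source: the `QuantKey` enum and the free function `supports_quant_scheme` are hypothetical stand-ins that only borrow the key names mentioned in this description.

```python
from enum import Enum
from typing import Optional


class QuantKey(Enum):
    """Hypothetical stand-ins for the quantization keys named in this PR."""

    NVFP4_STATIC = "kNvfp4Static"
    NVFP4_DYNAMIC = "kNvfp4Dynamic"


def supports_quant_scheme(
    weight_key: Optional[QuantKey],
    activation_key: Optional[QuantKey],
) -> bool:
    """Return True if the (weight, activation) pair passes the sketched gate."""
    supported = {
        # Existing path: NVFP4 weights with dynamically quantized activations.
        (QuantKey.NVFP4_STATIC, QuantKey.NVFP4_DYNAMIC),
        # Loosened path: W4A16, where the checkpoint carries no activation key.
        (QuantKey.NVFP4_STATIC, None),
    }
    return (weight_key, activation_key) in supported


accepts_existing = supports_quant_scheme(QuantKey.NVFP4_STATIC, QuantKey.NVFP4_DYNAMIC)
accepts_w4a16 = supports_quant_scheme(QuantKey.NVFP4_STATIC, None)
rejects_unquantized = supports_quant_scheme(None, None)
```

Listing the accepted pairs explicitly keeps the existing pair untouched while adding exactly one new combination, which matches the "gate-only" scope described here.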
Change:
Failure mode without this PR:
Risk: None. Gate-only loosening. No kernel changes. No new dependencies. The existing `(kNvfp4Static, kNvfp4Dynamic)` path is preserved verbatim.

Test Plan

Hardware: DGX Spark (GB10, sm_121a)
Container: `vllm/vllm-openai:nightly-aarch64` + this PR + the FP8-backend-env companion PR
Model: `nvidia/Qwen3.6-35B-A3B-2.06GB-per-token` (modelopt-native, mixed NVFP4 + FP8 experts)
Tooling: FlashInfer `0.6.11.post3`, cutlass-dsl trio `4.5.1`

Serve command:

Benchmark (aiperf): K=3 AL=3.12, BS=1 / concurrency=1, ISL=2048 + 32K user-context (total ISL=34,831), OSL=1024, 10 warmup + 60 measured requests, `--use-server-token-count --streaming`.

Test Result
With this PR (b12x dispatcher accepts W4A16):
Without this PR: serve dies at engine init with the `ValueError` shown above before any throughput can be measured.

Reference on the same hardware/workload (sanity comparison):
With this PR, b12x runs cleanly on W4A16 NVFP4 checkpoints and reaches parity with Marlin on the public wheel stack (the small ~4% delta vs the dgx-fork b12x number is attributable to a known `flashinfer-cubin 0.6.11.post2` / `flashinfer 0.6.11.post3` version skew in the public PyPI wheel, unrelated to this PR).
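For readers who want to check the comparison figures, the short calculation below recomputes the relative deltas from the numbers quoted in the first commit message of this PR (output token throughput in tok/s, TTFT and ITL in ms). The helper `pct_delta` and the variable names are illustrative only, and treating per-user throughput as roughly the inverse of ITL is an approximation assumed here rather than something the PR states.

```python
def pct_delta(value: float, reference: float) -> float:
    """Relative difference of value against reference, in percent."""
    return (value - reference) / reference * 100.0


# Figures quoted in the commit message (same workload for all three runs).
ott_pr, ott_marlin, ott_fork = 91.00, 92.26, 95.15
ttft_pr, ttft_marlin = 746.81, 798.72
itl_ms = 10.27

ott_vs_fork = pct_delta(ott_pr, ott_fork)         # about -4.4 %
ott_vs_marlin = pct_delta(ott_pr, ott_marlin)     # about -1.4 %
ttft_vs_marlin = pct_delta(ttft_pr, ttft_marlin)  # about -6.5 %

# With one stream, tokens per second per user is approximately 1000 / ITL(ms).
per_user_from_itl = 1000.0 / itl_ms               # about 97.4 tok/s/user

print(f"OTT vs dgx-fork b12x: {ott_vs_fork:+.1f}%")
print(f"OTT vs Marlin:        {ott_vs_marlin:+.1f}%")
print(f"TTFT vs Marlin:       {ttft_vs_marlin:+.1f}%")
print(f"Per-user from ITL:    {per_user_from_itl:.1f} tok/s")
```

On these numbers the b12x run is about 1.4% below Marlin on throughput and about 6.5% faster on TTFT, while the gap to the dgx-fork build is about 4.4%, in line with the "~4%" quoted above.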