
Add NVFP4 all-gather GEMM fusion for AsyncTP#41882

Merged
ProExpertProg merged 10 commits into
vllm-project:mainfrom
baonudesifeizhai:fixnvfp4asynctp
May 10, 2026

Conversation

@baonudesifeizhai

@baonudesifeizhai baonudesifeizhai commented May 7, 2026

Contributor

Purpose

Related: #27893

This PR wires the NVFP4 FlashInfer all-gather + GEMM path into AsyncTP.

It adds NVFP4 coverage for SP + AsyncTP by fusing:

all_gather(fp4 activation) + all_gather(group scales) + flashinfer_mm_fp4

into:

fused_all_gather_flashinfer_fp4_matmul

The reduce-scatter side is intentionally not enabled for NVFP4 in this PR.
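
To make the fusion above concrete, here is a minimal pure-Python sketch of why gathering the activation shards and their group scales and then multiplying gives the same result as multiplying each shard as it arrives. It is an illustration under assumptions, not the FlashInfer or vLLM implementation: plain floats stand in for packed FP4 codes, the group size is 2 instead of a real NVFP4 group, and the helper names (`dequant`, `unfused`, `fused`) are hypothetical.

```python
def matmul(a, b):
    """Naive (M x K) @ (K x N) matrix product on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def dequant(codes, scales, group):
    """Multiply each code by the scale of the K-group it belongs to."""
    return [
        [c * s_row[j // group] for j, c in enumerate(c_row)]
        for c_row, s_row in zip(codes, scales)
    ]


def unfused(shards, scale_shards, b, group):
    """all_gather(activation) + all_gather(scales), then one GEMM."""
    gathered = [row for shard in shards for row in shard]
    gathered_scales = [row for shard in scale_shards for row in shard]
    return matmul(dequant(gathered, gathered_scales, group), b)


def fused(shards, scale_shards, b, group):
    """Multiply each rank's chunk independently and stack the row blocks.

    In a real pipelined kernel, the multiply for one chunk would overlap
    with the transfer of the next chunk; here the loop is sequential.
    """
    out = []
    for shard, scale_shard in zip(shards, scale_shards):
        out.extend(matmul(dequant(shard, scale_shard, group), b))
    return out


# Two ranks, one activation row each, K = 4, group size 2 (toy values).
shards = [[[1.0, 2.0, 0.5, 4.0]], [[6.0, -1.0, 3.0, 0.0]]]
scale_shards = [[[2.0, 0.5]], [[1.0, 4.0]]]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]

reference = unfused(shards, scale_shards, weight, group=2)
overlapped = fused(shards, scale_shards, weight, group=2)
print(reference)
print(overlapped == reference)
```

Because the activation is sharded along rows, each rank's scale rows travel with its activation rows, which is why the all-gather side needs no re-slicing of scales.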

PyTorch Gap

PyTorch does not currently provide an NVFP4-aware fused GEMM + reduce-scatter path.

The existing symmetric-memory helpers are designed around generic matmul / FP8-style scaling and do not handle NVFP4 block/group scales, scale swizzling, or layout-aware sharding. Reusing the FP8 helper would require incorrectly slicing NVFP4 scales, so this PR keeps scope to the all-gather path only.
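
As a rough illustration of why group scales constrain sharding, the sketch below quantizes one row to the FP4 (E2M1) magnitude grid with one scale per 16-element group, and only allows slices along K that fall on group boundaries. This is a simplified model under assumptions: it keeps scales as Python floats, ignores the FP8 encoding of scales, any per-tensor scale, and scale swizzling, and the helper names are hypothetical rather than taken from this PR.

```python
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
GROUP = 16  # elements along K that share one scale in this sketch


def quantize_row(row, group=GROUP):
    """Return (codes, scales): codes lie on the E2M1 grid, one scale per group."""
    if len(row) % group:
        raise ValueError("row length must be a multiple of the group size")
    codes, scales = [], []
    for start in range(0, len(row), group):
        block = row[start:start + group]
        amax = max(abs(x) for x in block)
        scale = amax / 6.0 if amax > 0 else 1.0  # map the largest value to 6.0
        scales.append(scale)
        for x in block:
            mag = min(E2M1_MAGNITUDES, key=lambda v: abs(abs(x) / scale - v))
            codes.append(mag if x >= 0 else -mag)
    return codes, scales


def dequantize_row(codes, scales, group=GROUP):
    """Rescale each code by the scale of its group."""
    return [c * scales[i // group] for i, c in enumerate(codes)]


def slice_k(codes, scales, start, end, group=GROUP):
    """Slice along K, taking the matching scales; reject unaligned slices."""
    if start % group or end % group:
        raise ValueError("K slice must align to group boundaries")
    return codes[start:end], scales[start // group:end // group]


row = [6.0, -3.0, 1.5, 0.5] * 4 + [12.0, -8.0, 2.0, 0.0] * 4
codes, scales = quantize_row(row)
print(scales)
tail_codes, tail_scales = slice_k(codes, scales, 16, 32)
print(dequantize_row(tail_codes, tail_scales) == row[16:32])
```

A helper that treats the scale as a single per-tensor or per-row value has no equivalent of the `start // group` indexing, which is the kind of mismatch the paragraph above describes.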

Test Plan

CUDA_VISIBLE_DEVICES=4,5 .venv/bin/python -m pytest \
  tests/compile/fusions_e2e/test_tp2_async_tp.py::test_tp2_async_tp_nvfp4_fusions \
  -v -s

Result: passed

Test Result

tp=4  

Input Length | Output Throughput no_sp (tok/s) | Output Throughput on (tok/s) | Output Gain | Total Throughput no_sp (tok/s) | Total Throughput on (tok/s) | Total Gain
-- | -- | -- | -- | -- | -- | --
1024 | 823.22 | 830.51 | +0.89% | 7408.96 | 7474.58 | +0.89%
2048 | 747.13 | 762.30 | +2.03% | 12701.14 | 12959.09 | +2.03%
4096 | 614.05 | 643.85 | +4.85% | 20263.75 | 21246.99 | +4.85%
8192 | 440.99 | 482.72 | +9.46% | 28664.43 | 31376.75 | +9.46%
16384 | 273.97 | 307.62 | +12.28% | 35342.53 | 39682.62 | +12.28%
32768 | 146.12 | 165.90 | +13.54% | 37551.79 | 42636.69 | +13.54%
---
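The gain columns in the table above can be recomputed from the output-throughput columns as `(on - no_sp) / no_sp`. The short script below does that with the numbers copied from the table; the function name `gain_percent` is only illustrative.

```python
# (input length, output throughput no_sp, output throughput on), in tok/s
ROWS = [
    (1024, 823.22, 830.51),
    (2048, 747.13, 762.30),
    (4096, 614.05, 643.85),
    (8192, 440.99, 482.72),
    (16384, 273.97, 307.62),
    (32768, 146.12, 165.90),
]


def gain_percent(baseline, candidate):
    """Relative change of candidate over baseline, in percent."""
    return (candidate - baseline) / baseline * 100.0


gains = {length: f"{gain_percent(off, on):+.2f}%" for length, off, on in ROWS}
for length, gain in gains.items():
    print(f"{length:>6}: {gain}")
```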
Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Signed-off-by: roG0d <baonudesifeizhai@gmail.com>

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces support for NVFP4 (Blackwell) fusions, specifically for AsyncTP and sequence parallelism, by adding new custom operators and fusion patterns. Feedback highlights a performance optimization opportunity to use symmetric memory for intermediate buffers in the AsyncTP path and suggests simplifying redundant tensor view logic in the fusion patterns.

Comment on lines +285 to +286
A = A_shard.new_empty(A_shard.shape[0] * world_size, A_shard.shape[1])
A_scale = A_scale_shard.new_empty(

Contributor

high

The intermediate buffers A and A_scale are allocated using new_empty on every call to this custom op. For AsyncTP to be effective, these buffers should ideally be allocated in symmetric memory to avoid unnecessary copies during the all-gather operation. Furthermore, constant allocation of large buffers in the hot path can lead to significant performance overhead. Consider using torch.ops.symm_mem.empty_symm_mem or a similar mechanism to ensure these buffers are symmetric and potentially cached.
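
One way to read the caching part of this suggestion is a small buffer cache keyed by shape and dtype, so the allocator runs once per distinct buffer instead of on every call. The sketch below is pure Python and hypothetical: `BufferCache` and `fake_alloc` are not part of vLLM or PyTorch, and in real use the injected allocator would be whatever symmetric-memory allocation call the project chooses.

```python
class BufferCache:
    """Reuse buffers across calls, keyed by (shape, dtype)."""

    def __init__(self, alloc):
        self._alloc = alloc      # callable: (shape, dtype) -> buffer
        self._buffers = {}

    def get(self, shape, dtype):
        key = (tuple(shape), dtype)
        if key not in self._buffers:
            self._buffers[key] = self._alloc(key[0], dtype)
        return self._buffers[key]


alloc_calls = []


def fake_alloc(shape, dtype):
    """Stand-in allocator that records each call and returns a flat list."""
    alloc_calls.append((shape, dtype))
    size = 1
    for dim in shape:
        size *= dim
    return [0] * size


cache = BufferCache(fake_alloc)
world_size, rows_per_rank, cols = 4, 8, 16

first = cache.get((rows_per_rank * world_size, cols), "uint8")
second = cache.get((rows_per_rank * world_size, cols), "uint8")
print(first is second, len(alloc_calls))
```

The trade-off is that a cached buffer is overwritten by the next call with the same key, so its contents must be consumed before then.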

Comment on lines +857 to +860
if self.a_scale_view in ("float8", "float8_uint8"):
    a_scale = torch.ops.aten.view.dtype(a_scale, torch.float8_e4m3fn)
if self.a_scale_view in ("uint8", "float8_uint8"):
    a_scale = torch.ops.aten.view.dtype(a_scale, torch.uint8)

Contributor

high

The double view logic for float8_uint8 is redundant. If the goal is to obtain a uint8 tensor, viewing directly as uint8 is sufficient regardless of whether it was previously viewed as float8. This simplifies the pattern and avoids unnecessary operations in the graph.

Suggested change
- if self.a_scale_view in ("float8", "float8_uint8"):
-     a_scale = torch.ops.aten.view.dtype(a_scale, torch.float8_e4m3fn)
- if self.a_scale_view in ("uint8", "float8_uint8"):
-     a_scale = torch.ops.aten.view.dtype(a_scale, torch.uint8)
+ if self.a_scale_view in ("float8", "float8_uint8"):
+     a_scale = torch.ops.aten.view.dtype(a_scale, torch.float8_e4m3fn)
+ elif self.a_scale_view == "uint8":
+     a_scale = torch.ops.aten.view.dtype(a_scale, torch.uint8)

@ProExpertProg ProExpertProg left a comment

Collaborator

Nice work, just a few nits! Also, can we add some test cases to the SP and AsyncTP correctness CI jobs (should be in e2e_correctness CI)?

inductor_graph_partition: bool,
run_e2e_fusion_test,
):
# NVFP4 currently wires the all-gather + GEMM path only. The generic

Collaborator

Let's set this on llama-fp4 model directly?

a_scale_view=a_scale_view,
)
)
# NVFP4 activation scales are block/group scales, not FP8

Collaborator

Wait, thinking about this again, isn't reduce scatter trivial? Inputs are already column-parallel across ranks, so each rank has the appropriate scales and inputs only. Output is full size but it's activations only (and partial numerically), so reduction is needed but only on the output, no scale comms need to be involved.

Am I missing something?
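
The arithmetic behind this question can be sketched in pure Python: when the GEMM input is sharded along K, each rank multiplies its own K-slice using only its local group scales, and only the partial outputs are summed and scattered. This is an illustration of the reasoning in the comment, not a statement about how hard the fused kernel path is to wire up; floats stand in for FP4 codes, the group size is 2, and all helper names are hypothetical.

```python
def matmul(a, b):
    """Naive (M x K) @ (K x N) matrix product on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def dequant(codes, scales, group):
    """Multiply each code by the scale of the K-group it belongs to."""
    return [
        [c * s_row[j // group] for j, c in enumerate(c_row)]
        for c_row, s_row in zip(codes, scales)
    ]


def reduce_scatter(partials):
    """Sum the per-rank partial outputs, then give rank r the r-th row block."""
    world_size = len(partials)
    rows = len(partials[0])
    summed = [
        [sum(p[i][j] for p in partials) for j in range(len(partials[0][0]))]
        for i in range(rows)
    ]
    per_rank = rows // world_size
    return [summed[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


GROUP = 2  # toy group size
# Full problem: 2 rows, K = 4, N = 2.
codes = [[1.0, 2.0, 0.5, 4.0], [6.0, -1.0, 3.0, 0.0]]
scales = [[2.0, 0.5], [1.0, 4.0]]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
full = matmul(dequant(codes, scales, GROUP), weight)

# Each rank owns one K-group: its codes, its scales, and its weight rows.
partials = []
for rank in range(2):
    k0, k1 = rank * GROUP, (rank + 1) * GROUP
    local_codes = [row[k0:k1] for row in codes]
    local_scales = [[row[rank]] for row in scales]
    partials.append(matmul(dequant(local_codes, local_scales, GROUP), weight[k0:k1]))

scattered = reduce_scatter(partials)
print(full)
print(scattered)
```

No scale values cross ranks in this sketch; the only communicated data is the partial output.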

logger = init_logger(__name__)

if hasattr(torch.ops._C, "scaled_fp4_quant"):
    STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.out

Collaborator

Lol I don't think this is static vs dynamic, these are just the overloads

@baonudesifeizhai

Contributor Author

no_sp:

ns', 'num_concurrent': 8, 'tokenized_requests': False}), gen_kwargs: ({'temperature': 0}), limit: 100.0, num_fewshot: 5, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.92|±  |0.0273|
|     |       |strict-match    |     5|exact_match|↑  | 0.64|±  |0.0482|

on: https://paste.ubuntu.com/p/t6Z7wccr65/

 
local-completions ({'model': '/root/zdj/models/Llama-3.3-70B-Instruct-NVFP4', 'base_url': 'http://127.0.0.1:24081/v1/completions', 'num_concurrent': 8, 'tokenized_requests': False}), gen_kwargs: ({'temperature': 0}), limit: 100.0, num_fewshot: 5, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.94|±  |0.0239|
|     |       |strict-match    |     5|exact_match|↑  | 0.65|±  |0.0479|

Signed-off-by: roG0d <baonudesifeizhai@gmail.com>
@baonudesifeizhai

Contributor Author
CUDA_VISIBLE_DEVICES=4,5 \
VLLM_NVFP4_GEMM_BACKEND=flashinfer-cutlass \
.venv/bin/python -m pytest -s -vv \
  tests/compile/fusions_e2e/test_tp2_async_tp.py::test_tp2_async_tp_nvfp4_fusions \
  --tb=short
CUDA_VISIBLE_DEVICES=4,5 \
.venv/bin/python -m pytest -s -vv \
  tests/compile/correctness_e2e/test_async_tp.py::test_async_tp_pass_nvfp4_correctness \
  --tb=short
CUDA_VISIBLE_DEVICES=4,5 \
.venv/bin/python -m pytest -s -vv \
  tests/compile/correctness_e2e/test_sequence_parallel.py::test_tp_sp_nvfp4_generation \
  --tb=short

all passed

@ProExpertProg ProExpertProg added the ready ONLY add when PR is ready to merge/full CI is needed label May 9, 2026
@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA May 9, 2026
Signed-off-by: roG0d <baonudesifeizhai@gmail.com>
Signed-off-by: roG0d <baonudesifeizhai@gmail.com>
@mergify mergify Bot added the llama Related to Llama models label May 9, 2026
@ProExpertProg ProExpertProg enabled auto-merge (squash) May 9, 2026 23:09
@ProExpertProg ProExpertProg merged commit bc5fdc1 into vllm-project:main May 10, 2026
73 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA May 10, 2026
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request May 11, 2026
Signed-off-by: roG0d <baonudesifeizhai@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
weifang231 pushed a commit to weifang231/eb-vllm that referenced this pull request May 13, 2026
Signed-off-by: roG0d <baonudesifeizhai@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
mfylcek pushed a commit to mfylcek/vllm that referenced this pull request May 19, 2026
Signed-off-by: roG0d <baonudesifeizhai@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
Signed-off-by: roG0d <baonudesifeizhai@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
Signed-off-by: roG0d <baonudesifeizhai@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
knight0528 pushed a commit to knight0528/vllm that referenced this pull request Jun 8, 2026
Signed-off-by: roG0d <baonudesifeizhai@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>

Labels

llama Related to Llama models nvidia ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done


2 participants