Add NVFP4 all-gather GEMM fusion for AsyncTP #41882
Signed-off-by: roG0d <baonudesifeizhai@gmail.com>
Code Review
This pull request introduces support for NVFP4 (Blackwell) fusions, specifically for AsyncTP and sequence parallelism, by adding new custom operators and fusion patterns. Feedback highlights a performance optimization opportunity to use symmetric memory for intermediate buffers in the AsyncTP path and suggests simplifying redundant tensor view logic in the fusion patterns.
    A = A_shard.new_empty(A_shard.shape[0] * world_size, A_shard.shape[1])
    A_scale = A_scale_shard.new_empty(
The intermediate buffers A and A_scale are allocated using new_empty on every call to this custom op. For AsyncTP to be effective, these buffers should ideally be allocated in symmetric memory to avoid unnecessary copies during the all-gather operation. Furthermore, constant allocation of large buffers in the hot path can lead to significant performance overhead. Consider using torch.ops.symm_mem.empty_symm_mem or a similar mechanism to ensure these buffers are symmetric and potentially cached.
    if self.a_scale_view in ("float8", "float8_uint8"):
        a_scale = torch.ops.aten.view.dtype(a_scale, torch.float8_e4m3fn)
    if self.a_scale_view in ("uint8", "float8_uint8"):
        a_scale = torch.ops.aten.view.dtype(a_scale, torch.uint8)
The double view logic for float8_uint8 is redundant. If the goal is to obtain a uint8 tensor, viewing directly as uint8 is sufficient regardless of whether it was previously viewed as float8. This simplifies the pattern and avoids unnecessary operations in the graph.
Suggested change:
    if self.a_scale_view in ("float8", "float8_uint8"):
        a_scale = torch.ops.aten.view.dtype(a_scale, torch.float8_e4m3fn)
    elif self.a_scale_view == "uint8":
        a_scale = torch.ops.aten.view.dtype(a_scale, torch.uint8)
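To compare the two forms of the branch logic quoted above, a minimal sketch can model each dtype view as a label appended to a list instead of a real `torch.ops.aten.view.dtype` call. The function names and the list-of-labels representation are illustrative assumptions, not code from the PR.

```python
def views_original(a_scale_view: str) -> list[str]:
    """Model of the two independent `if` checks: each match appends one view."""
    applied = []
    if a_scale_view in ("float8", "float8_uint8"):
        applied.append("float8_e4m3fn")
    if a_scale_view in ("uint8", "float8_uint8"):
        applied.append("uint8")
    return applied


def views_suggested(a_scale_view: str) -> list[str]:
    """Model of the `if` / `elif` form: at most one view is applied."""
    applied = []
    if a_scale_view in ("float8", "float8_uint8"):
        applied.append("float8_e4m3fn")
    elif a_scale_view == "uint8":
        applied.append("uint8")
    return applied


# Print the sequence of views each form applies for every mode.
for mode in ("float8", "uint8", "float8_uint8"):
    print(mode, views_original(mode), views_suggested(mode))
```

Under this model the two forms agree for `float8` and `uint8` but differ for `float8_uint8`: the two-`if` form ends on a `uint8` view, while the `if`/`elif` form stops at `float8_e4m3fn`. Whether that difference matters depends on what the traced graph actually contains, which this sketch does not check.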
ProExpertProg left a comment
Nice work, just a few nits! Also, can we add some test cases to the SP and AsyncTP correctness CI jobs (should be in e2e_correctness CI)?
    inductor_graph_partition: bool,
    run_e2e_fusion_test,
):
    # NVFP4 currently wires the all-gather + GEMM path only. The generic
Let's set this on llama-fp4 model directly?
            a_scale_view=a_scale_view,
        )
    )
    # NVFP4 activation scales are block/group scales, not FP8
Wait, thinking about this again, isn't reduce scatter trivial? Inputs are already column-parallel across ranks, so each rank has the appropriate scales and inputs only. Output is full size but it's activations only (and partial numerically), so reduction is needed but only on the output, no scale comms need to be involved.
Am I missing something?
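The arithmetic behind this question can be checked on a toy example. The sketch below is plain Python, not the PR's code: it assumes two ranks, a block size of 2 along K (chosen only for readability), small integers standing in for quantized data, and hypothetical helpers `dequant` and `matmul`. Each rank holds only its K-slice of the inputs and the matching scales, computes a full-size partial output, and the only communication is the reduction over outputs.

```python
BLOCK = 2  # toy block size along K; an assumption for readability


def dequant(q, scales):
    """Multiply each K-block of every row by that block's scale."""
    return [
        [v * row_s[k // BLOCK] for k, v in enumerate(row_q)]
        for row_q, row_s in zip(q, scales)
    ]


def matmul(a, b):
    """Dense matmul on nested lists: a is [M][K], b is [K][N]."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


# Quantized activation [M=2][K=4] with one scale per row per K-block.
a_q = [[1, 2, 3, 4], [5, 6, 7, 8]]
a_s = [[2, 3], [4, 5]]
# Weight [K=4][N=2], kept unquantized to keep the toy small.
w = [[1, 0], [0, 1], [2, 1], [1, 2]]

reference = matmul(dequant(a_q, a_s), w)

# Shard along K across two ranks: each rank keeps its own blocks and scales.
world_size, k_per_rank = 2, 2
partials = []
for rank in range(world_size):
    lo, hi = rank * k_per_rank, (rank + 1) * k_per_rank
    a_q_local = [row[lo:hi] for row in a_q]
    a_s_local = [row[lo // BLOCK : hi // BLOCK] for row in a_s]
    w_local = w[lo:hi]
    partials.append(matmul(dequant(a_q_local, a_s_local), w_local))

# Reduce: sum the full-size partial outputs elementwise.
reduced = [
    [sum(p[i][j] for p in partials) for j in range(len(reference[0]))]
    for i in range(len(reference))
]
# Scatter: rank r keeps row block r of the reduced output.
scattered = [reduced[r : r + 1] for r in range(world_size)]

assert reduced == reference
print(scattered)
```

This only illustrates the unswizzled arithmetic; it says nothing about padded or swizzled scale layouts, which a real kernel may require.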
logger = init_logger(__name__)


if hasattr(torch.ops._C, "scaled_fp4_quant"):
    STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.out
Lol I don't think this is static vs dynamic, these are just the overloads
no_sp: on: https://paste.ubuntu.com/p/t6Z7wccr65/
all passed
Signed-off-by: roG0d <baonudesifeizhai@gmail.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Signed-off-by: roG0d <baonudesifeizhai@gmail.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Purpose
#27893
wires the NVFP4 FlashInfer all-gather + GEMM path into AsyncTP.
It adds NVFP4 coverage for SP + AsyncTP by fusing:
    all_gather(fp4 activation) + all_gather(group scales) + flashinfer_mm_fp4
into:
    fused_all_gather_flashinfer_fp4_matmul
The reduce-scatter side is intentionally not enabled for NVFP4 in this PR.
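As a rough illustration of what the unfused pattern computes, the sketch below simulates it in plain Python. All names are hypothetical stand-ins (none are vLLM or FlashInfer APIs), integers stand in for FP4 data, and a single scale per row stands in for the group scales. It also shows the general idea behind overlapping an all-gather with a GEMM: the matmul can run shard by shard and still match the gather-then-multiply result. Whether the PR's fused op is implemented this way is not shown here.

```python
def all_gather_rows(shards):
    """Stand-in for all_gather along dim 0: concatenate per-rank row shards."""
    return [row for shard in shards for row in shard]


def scaled_matmul(a_q, a_scale, w):
    """Toy scaled GEMM: multiply each activation row by w, then by its scale."""
    return [
        [s * sum(x * w[k][j] for k, x in enumerate(row)) for j in range(len(w[0]))]
        for row, s in zip(a_q, a_scale)
    ]


# Two ranks, each holding one row of activations and that row's scale.
a_shards = [[[1, 2]], [[3, 4]]]
scale_shards = [[2], [3]]
w = [[1, 0], [1, 1]]

# Unfused: gather activations, gather scales, then run one GEMM.
unfused = scaled_matmul(all_gather_rows(a_shards), all_gather_rows(scale_shards), w)

# Shard by shard: run the GEMM on each shard separately, then concatenate.
chunked = all_gather_rows(
    [scaled_matmul(a, s, w) for a, s in zip(a_shards, scale_shards)]
)

assert unfused == chunked
print(unfused)
```

Because the activation and its scales are sharded along the same row dimension in this toy, each shard's GEMM needs only that shard's scales.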
PyTorch Gap
PyTorch does not currently provide an NVFP4-aware fused GEMM + reduce-scatter path.
The existing symmetric-memory helpers are designed around generic matmul / FP8-style scaling and do not handle NVFP4 block/group scales, scale swizzling, or layout-aware sharding. Reusing the FP8 helper would require incorrectly slicing NVFP4 scales, so this PR keeps scope to the all-gather path only.
Test Plan
CUDA_VISIBLE_DEVICES=4,5 .venv/bin/python -m pytest \
  tests/compile/fusions_e2e/test_tp2_async_tp.py::test_tp2_async_tp_nvfp4_fusions \
  -v -s
passed
Test Result