-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
Add NVFP4 all-gather GEMM fusion for AsyncTP #41882
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
cead7b6
2d922c5
7699df0
916a281
b1d8bfc
a612c3e
ffda91d
a47fa75
4df4172
b3fdf4f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -74,6 +74,36 @@ def _flashinfer_scaled_mm_out( | |||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _flashinfer_fp4_mm_out( | ||||||||||||||||||
| A: torch.Tensor, | ||||||||||||||||||
| B: torch.Tensor, | ||||||||||||||||||
| *, | ||||||||||||||||||
| scale_a: torch.Tensor, | ||||||||||||||||||
| scale_b: torch.Tensor, | ||||||||||||||||||
| out: torch.Tensor, | ||||||||||||||||||
| alpha: torch.Tensor, | ||||||||||||||||||
| out_dtype: torch.dtype | None = None, | ||||||||||||||||||
| use_8x4_sf_layout: bool = False, | ||||||||||||||||||
| backend: str = "cutlass", | ||||||||||||||||||
| ) -> None: | ||||||||||||||||||
| from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm_out | ||||||||||||||||||
|
|
||||||||||||||||||
| assert A.ndim == 2 and B.ndim == 2 and out.ndim == 2, ( | ||||||||||||||||||
| "FlashInfer FP4 symm_mem adapter expects 2D inputs and output" | ||||||||||||||||||
| ) | ||||||||||||||||||
| flashinfer_scaled_fp4_mm_out( | ||||||||||||||||||
| A, | ||||||||||||||||||
| B, | ||||||||||||||||||
| scale_a, | ||||||||||||||||||
| scale_b, | ||||||||||||||||||
| alpha, | ||||||||||||||||||
| out=out, | ||||||||||||||||||
| out_dtype=out_dtype or out.dtype, | ||||||||||||||||||
| use_8x4_sf_layout=use_8x4_sf_layout, | ||||||||||||||||||
| backend=backend, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def fused_flashinfer_scaled_matmul_reduce_scatter_fake( | ||||||||||||||||||
| A: torch.Tensor, | ||||||||||||||||||
| B: torch.Tensor, | ||||||||||||||||||
|
|
@@ -197,6 +227,90 @@ def fused_all_gather_flashinfer_scaled_matmul( | |||||||||||||||||
| return outputs[0] | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def fused_all_gather_flashinfer_fp4_matmul_fake( | ||||||||||||||||||
| A_shard: torch.Tensor, | ||||||||||||||||||
| B: torch.Tensor, | ||||||||||||||||||
| A_scale_shard: torch.Tensor, | ||||||||||||||||||
| B_scale: torch.Tensor, | ||||||||||||||||||
| alpha: torch.Tensor, | ||||||||||||||||||
| gather_dim: int, | ||||||||||||||||||
| group_name: str, | ||||||||||||||||||
| out_dtype: torch.dtype | None = None, | ||||||||||||||||||
| view_a_scale_as_fp8: bool = False, | ||||||||||||||||||
| use_8x4_sf_layout: bool = False, | ||||||||||||||||||
| backend: str = "cutlass", | ||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||
| world_size = c10d._resolve_process_group(group_name).size() | ||||||||||||||||||
| output_shape = list(A_shard.shape) | ||||||||||||||||||
| output_shape[gather_dim] *= world_size | ||||||||||||||||||
| output_shape[-1] = B.shape[1] | ||||||||||||||||||
| return torch.empty( | ||||||||||||||||||
| output_shape, | ||||||||||||||||||
| dtype=out_dtype or torch.bfloat16, | ||||||||||||||||||
| device=A_shard.device, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def fused_all_gather_flashinfer_fp4_matmul( | ||||||||||||||||||
| A_shard: torch.Tensor, | ||||||||||||||||||
| B: torch.Tensor, | ||||||||||||||||||
| A_scale_shard: torch.Tensor, | ||||||||||||||||||
| B_scale: torch.Tensor, | ||||||||||||||||||
| alpha: torch.Tensor, | ||||||||||||||||||
| gather_dim: int, | ||||||||||||||||||
| group_name: str, | ||||||||||||||||||
| out_dtype: torch.dtype | None = None, | ||||||||||||||||||
| view_a_scale_as_fp8: bool = False, | ||||||||||||||||||
| use_8x4_sf_layout: bool = False, | ||||||||||||||||||
| backend: str = "cutlass", | ||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||
| assert gather_dim == 0, ( | ||||||||||||||||||
| "FlashInfer FP4 symm_mem adapter currently only supports gather_dim=0" | ||||||||||||||||||
| ) | ||||||||||||||||||
| assert A_shard.ndim == 2 and A_scale_shard.ndim == 2 and B.ndim == 2, ( | ||||||||||||||||||
| "FlashInfer FP4 symm_mem adapter expects 2D inputs" | ||||||||||||||||||
| ) | ||||||||||||||||||
| if view_a_scale_as_fp8: | ||||||||||||||||||
| A_scale_shard = A_scale_shard.view(torch.float8_e4m3fn) | ||||||||||||||||||
|
|
||||||||||||||||||
| group = c10d._resolve_process_group(group_name) | ||||||||||||||||||
| world_size = group.size() | ||||||||||||||||||
| output = A_shard.new_empty( | ||||||||||||||||||
| A_shard.shape[0] * world_size, | ||||||||||||||||||
| B.shape[1], | ||||||||||||||||||
| dtype=out_dtype or torch.bfloat16, | ||||||||||||||||||
| ) | ||||||||||||||||||
| output_shards = output.chunk(world_size) | ||||||||||||||||||
|
|
||||||||||||||||||
| A = A_shard.new_empty(A_shard.shape[0] * world_size, A_shard.shape[1]) | ||||||||||||||||||
| A_scale = A_scale_shard.new_empty( | ||||||||||||||||||
|
Comment on lines
+285
to
+286
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The intermediate buffers |
||||||||||||||||||
| A_scale_shard.shape[0] * world_size, | ||||||||||||||||||
| A_scale_shard.shape[1], | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| def fp4_shard_consumer(shards: list[torch.Tensor], rank: int) -> None: | ||||||||||||||||||
| _flashinfer_fp4_mm_out( | ||||||||||||||||||
| shards[0], | ||||||||||||||||||
| B, | ||||||||||||||||||
| scale_a=shards[1], | ||||||||||||||||||
| scale_b=B_scale, | ||||||||||||||||||
| alpha=alpha, | ||||||||||||||||||
| out=output_shards[rank], | ||||||||||||||||||
| out_dtype=out_dtype, | ||||||||||||||||||
| use_8x4_sf_layout=use_8x4_sf_layout, | ||||||||||||||||||
| backend=backend, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| torch.distributed._symmetric_memory._pipelined_multi_all_gather_and_consume( | ||||||||||||||||||
| [A_shard, A_scale_shard], | ||||||||||||||||||
| fp4_shard_consumer, | ||||||||||||||||||
| [A, A_scale], | ||||||||||||||||||
| group_name, | ||||||||||||||||||
| False, | ||||||||||||||||||
| ) | ||||||||||||||||||
| return output | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| direct_register_custom_op( | ||||||||||||||||||
| op_name="fused_flashinfer_scaled_matmul_reduce_scatter", | ||||||||||||||||||
| op_func=fused_flashinfer_scaled_matmul_reduce_scatter, | ||||||||||||||||||
|
|
@@ -209,6 +323,12 @@ def fused_all_gather_flashinfer_scaled_matmul( | |||||||||||||||||
| fake_impl=fused_all_gather_flashinfer_scaled_matmul_fake, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| direct_register_custom_op( | ||||||||||||||||||
| op_name="fused_all_gather_flashinfer_fp4_matmul", | ||||||||||||||||||
| op_func=fused_all_gather_flashinfer_fp4_matmul, | ||||||||||||||||||
| fake_impl=fused_all_gather_flashinfer_fp4_matmul_fake, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class BasePattern: | ||||||||||||||||||
| def __init__(self, dtype: torch.dtype, device: str | None) -> None: | ||||||||||||||||||
|
|
@@ -682,6 +802,101 @@ def _replacement( | |||||||||||||||||
| return _replacement | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class FlashInferAllGatherFP4Pattern( | ||||||||||||||||||
| BasePattern, VllmPatternReplacement[..., torch.Tensor] | ||||||||||||||||||
| ): | ||||||||||||||||||
| def __init__( | ||||||||||||||||||
| self, | ||||||||||||||||||
| dtype: torch.dtype, | ||||||||||||||||||
| device: str | None, | ||||||||||||||||||
| backend: str, | ||||||||||||||||||
| use_8x4_sf_layout: bool, | ||||||||||||||||||
| a_scale_view: str, | ||||||||||||||||||
| ) -> None: | ||||||||||||||||||
| super().__init__(dtype, device) | ||||||||||||||||||
| self.backend = backend | ||||||||||||||||||
| self.use_8x4_sf_layout = use_8x4_sf_layout | ||||||||||||||||||
| self.a_scale_view = a_scale_view | ||||||||||||||||||
|
|
||||||||||||||||||
| def get_inputs(self) -> list[torch.Tensor]: | ||||||||||||||||||
| a_shard_2d = torch.empty([8, 8], device=self.device, dtype=torch.uint8) | ||||||||||||||||||
| b_2d = torch.empty([8, 16], device=self.device, dtype=torch.uint8) | ||||||||||||||||||
| a_scale_shard = torch.empty([128, 4], device=self.device, dtype=torch.int32) | ||||||||||||||||||
| b_scale = torch.empty([4, 128], device=self.device, dtype=torch.uint8) | ||||||||||||||||||
| alpha = torch.empty([], device=self.device, dtype=torch.float32) | ||||||||||||||||||
| return [ | ||||||||||||||||||
| a_shard_2d, | ||||||||||||||||||
| b_2d, | ||||||||||||||||||
| a_scale_shard, | ||||||||||||||||||
| b_scale, | ||||||||||||||||||
| alpha, | ||||||||||||||||||
| ] | ||||||||||||||||||
|
|
||||||||||||||||||
| @property | ||||||||||||||||||
| def pattern(self) -> Callable[..., torch.Tensor]: | ||||||||||||||||||
| def _pattern( | ||||||||||||||||||
| a_shard_2d: torch.Tensor, | ||||||||||||||||||
| b_2d: torch.Tensor, | ||||||||||||||||||
| a_scale_shard: torch.Tensor, | ||||||||||||||||||
| b_scale: torch.Tensor, | ||||||||||||||||||
| alpha: torch.Tensor, | ||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||
| all_gather_a = torch.ops.vllm.all_gather.default( | ||||||||||||||||||
| a_shard_2d, | ||||||||||||||||||
| dim=0, | ||||||||||||||||||
| world_size=self.tp_size, | ||||||||||||||||||
| group_name=self.tp.unique_name, | ||||||||||||||||||
| ) | ||||||||||||||||||
| all_gather_a_scale = torch.ops.vllm.all_gather.default( | ||||||||||||||||||
| a_scale_shard, | ||||||||||||||||||
| dim=0, | ||||||||||||||||||
| world_size=self.tp_size, | ||||||||||||||||||
| group_name=self.tp.unique_name, | ||||||||||||||||||
| ) | ||||||||||||||||||
| a_scale = all_gather_a_scale | ||||||||||||||||||
| if self.a_scale_view in ("float8", "float8_uint8"): | ||||||||||||||||||
| a_scale = torch.ops.aten.view.dtype(a_scale, torch.float8_e4m3fn) | ||||||||||||||||||
| if self.a_scale_view in ("uint8", "float8_uint8"): | ||||||||||||||||||
| a_scale = torch.ops.aten.view.dtype(a_scale, torch.uint8) | ||||||||||||||||||
|
Comment on lines
+857
to
+860
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The double view logic for
Suggested change
|
||||||||||||||||||
| return torch.ops.vllm.flashinfer_mm_fp4.default( | ||||||||||||||||||
| all_gather_a, | ||||||||||||||||||
| b_2d, | ||||||||||||||||||
| a_scale, | ||||||||||||||||||
| b_scale, | ||||||||||||||||||
| alpha, | ||||||||||||||||||
| self.dtype, | ||||||||||||||||||
| self.use_8x4_sf_layout, | ||||||||||||||||||
| self.backend, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| return _pattern | ||||||||||||||||||
|
|
||||||||||||||||||
| @property | ||||||||||||||||||
| def replacement(self) -> Callable[..., torch.Tensor]: | ||||||||||||||||||
| def _replacement( | ||||||||||||||||||
| a_shard_2d: torch.Tensor, | ||||||||||||||||||
| b_2d: torch.Tensor, | ||||||||||||||||||
| a_scale_shard: torch.Tensor, | ||||||||||||||||||
| b_scale: torch.Tensor, | ||||||||||||||||||
| alpha: torch.Tensor, | ||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||
| return torch.ops.vllm.fused_all_gather_flashinfer_fp4_matmul.default( | ||||||||||||||||||
| a_shard_2d, | ||||||||||||||||||
| b_2d, | ||||||||||||||||||
| a_scale_shard, | ||||||||||||||||||
| b_scale, | ||||||||||||||||||
| alpha, | ||||||||||||||||||
| 0, | ||||||||||||||||||
| self.tp.device_group.group_name, | ||||||||||||||||||
| self.dtype, | ||||||||||||||||||
| self.a_scale_view in ("float8", "float8_uint8"), | ||||||||||||||||||
| self.use_8x4_sf_layout, | ||||||||||||||||||
| self.backend, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| return _replacement | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class AsyncTPPass(VllmFusionPatternMatcherPass): | ||||||||||||||||||
| @enable_fake_mode | ||||||||||||||||||
| def __init__(self, config: VllmConfig) -> None: | ||||||||||||||||||
|
|
@@ -718,6 +933,33 @@ def __init__(self, config: VllmConfig) -> None: | |||||||||||||||||
| self.register( | ||||||||||||||||||
| FlashInferBMMFP8ReduceScatterPattern(self.model_dtype, self.device) | ||||||||||||||||||
| ) | ||||||||||||||||||
| if hasattr(torch.ops.vllm, "flashinfer_mm_fp4"): | ||||||||||||||||||
| for backend in ("cutlass", "cudnn"): | ||||||||||||||||||
| for a_scale_view in ("float8_uint8", "uint8"): | ||||||||||||||||||
| self.register( | ||||||||||||||||||
| FlashInferAllGatherFP4Pattern( | ||||||||||||||||||
| self.model_dtype, | ||||||||||||||||||
| self.device, | ||||||||||||||||||
| backend, | ||||||||||||||||||
| use_8x4_sf_layout=False, | ||||||||||||||||||
| a_scale_view=a_scale_view, | ||||||||||||||||||
| ) | ||||||||||||||||||
| ) | ||||||||||||||||||
| for use_8x4_sf_layout in (False, True): | ||||||||||||||||||
| for a_scale_view in ("float8",): | ||||||||||||||||||
| self.register( | ||||||||||||||||||
| FlashInferAllGatherFP4Pattern( | ||||||||||||||||||
| self.model_dtype, | ||||||||||||||||||
| self.device, | ||||||||||||||||||
| "trtllm", | ||||||||||||||||||
| use_8x4_sf_layout=use_8x4_sf_layout, | ||||||||||||||||||
| a_scale_view=a_scale_view, | ||||||||||||||||||
| ) | ||||||||||||||||||
| ) | ||||||||||||||||||
| # NVFP4 activation scales are block/group scales, not FP8 | ||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wait, thinking about this again, isn't reduce scatter trivial? Inputs are already column-parallel across ranks, so each rank has the appropriate scales and inputs only. Output is full size but it's activations only (and partial numerically), so reduction is needed but only on the output, no scale comms need to be involved. Am I missing something? |
||||||||||||||||||
| # row-wise scales. Register only the all-gather path until the | ||||||||||||||||||
| # reduce-scatter side has a dedicated NVFP4 scale-sharding | ||||||||||||||||||
| # implementation. | ||||||||||||||||||
|
|
||||||||||||||||||
| self.dump_patterns(config, self.pm_pass) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's set this on llama-fp4 model directly?