Fix fused_scaled_matmul_reduce_scatter callsite #26506
Conversation
Signed-off-by: angelayi <[email protected]>
Code Review
This pull request correctly addresses a breaking change in the fused_scaled_matmul_reduce_scatter op signature from a newer PyTorch version by adding version-conditional logic. My main feedback is to refactor the duplicated code blocks into a shared helper function to improve maintainability.
if is_torch_equal_or_newer("2.8.0.dev"):
    # TODO: This fails in the dynamic shapes case because the shapes
    # get specialized
    output_shape = (
        torch.ops.aten.sym_size.int(input, 0),
        torch.ops.aten.sym_size.int(mat2, 1),
    )
    gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        input,
        mat2,
        scale_a,
        scale_b,
        "avg",
        orig_scatter_dim=0,
        scatter_dim_after_maybe_reshape=0,
        output_shape=output_shape,
        out_dtype=self.dtype,
        group_name=self.tp.device_group.group_name,
    )
else:
    # For older versions, use the old signature
    gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        input,
        mat2,
        scale_a,
        scale_b,
        "avg",
        scatter_dim=0,
        out_dtype=self.dtype,
        group_name=self.tp.device_group.group_name,
    )

return gemm_rs
This version-checking logic for fused_scaled_matmul_reduce_scatter is duplicated in CutlassScaledMMReduceScatterPattern.replacement (lines 320-353). To improve maintainability and reduce redundancy, consider extracting this logic into a shared helper function or method. This would centralize the call to the op, making future updates easier.
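For illustration, here is a minimal sketch of the kind of shared helper being suggested. The helper name is hypothetical, and the `vllm.utils` import path for `is_torch_equal_or_newer` is an assumption; treat this as a sketch rather than the actual vLLM code.

```python
import torch

from vllm.utils import is_torch_equal_or_newer  # assumed import path


def _fused_scaled_matmul_reduce_scatter(  # hypothetical helper name
    input: torch.Tensor,
    mat2: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    group_name: str,
) -> torch.Tensor:
    if is_torch_equal_or_newer("2.8.0.dev"):
        # New signature (pytorch/pytorch#149247): explicit scatter dims plus
        # an output_shape hint.
        output_shape = (
            torch.ops.aten.sym_size.int(input, 0),
            torch.ops.aten.sym_size.int(mat2, 1),
        )
        return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
            input,
            mat2,
            scale_a,
            scale_b,
            "avg",
            orig_scatter_dim=0,
            scatter_dim_after_maybe_reshape=0,
            output_shape=output_shape,
            out_dtype=out_dtype,
            group_name=group_name,
        )
    # Older torch: the op takes a single scatter_dim argument.
    return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        input,
        mat2,
        scale_a,
        scale_b,
        "avg",
        scatter_dim=0,
        out_dtype=out_dtype,
        group_name=group_name,
    )
```

Both `replacement` methods could then reduce to a single call such as `_fused_scaled_matmul_reduce_scatter(input, mat2, scale_a, scale_b, self.dtype, self.tp.device_group.group_name)`.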
💡 Codex Review
Here are some automated review suggestions for this pull request.
if is_torch_equal_or_newer("2.8.0.dev"):
    # TODO: This fails in the dynamic shapes case because the shapes
    # get specialized
    output_shape = (
        torch.ops.aten.sym_size.int(input, 0),
        torch.ops.aten.sym_size.int(mat2, 1),
    )
    gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        input,
        mat2,
        scale_a,
        scale_b,
        "avg",
        orig_scatter_dim=0,
        scatter_dim_after_maybe_reshape=0,
        output_shape=output_shape,
        out_dtype=self.dtype,
Compute reduce-scatter output_shape per shard
When targeting the new fused_scaled_matmul_reduce_scatter signature, output_shape is derived directly from input and mat2 without accounting for the tensor-parallel world size. In the unfused graph this matmul output is immediately reduced along dim‑0, so each rank ultimately sees a first dimension of scaled_mm.size(0) // tp_world_size. Passing the pre‑scatter size (input.shape[0]) will request the wrong shape from the fused op and either misallocate or fail once torch 2.8 executes this branch. output_shape[0] should reflect the reduce-scatter result (divide by self.tp_size).
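A minimal sketch of the suggested per-shard shape computation, written as a standalone helper; the helper name and the explicit `tp_size` parameter are assumptions for illustration, not code from this PR.

```python
import torch


def _reduce_scatter_output_shape(  # hypothetical helper name
    input: torch.Tensor, mat2: torch.Tensor, tp_size: int
) -> tuple[int, int]:
    # The fused op returns the per-rank shard, so the scatter dimension
    # (dim 0 here) must be divided by the tensor-parallel world size.
    return (
        torch.ops.aten.sym_size.int(input, 0) // tp_size,
        torch.ops.aten.sym_size.int(mat2, 1),
    )
```

The same per-shard shape would apply to the cutlass variant discussed below, since both patterns consume the fused op's output at the post-scatter size.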
if is_torch_equal_or_newer("2.8.0.dev"):
    # TODO: This fails in the dynamic shapes case because the shapes
    # get specialized
    output_shape = (
        torch.ops.aten.sym_size.int(input, 0),
        torch.ops.aten.sym_size.int(mat2, 1),
    )
    gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        input,
        mat2,
        scale_a,
        scale_b,
        "avg",
        orig_scatter_dim=0,
        scatter_dim_after_maybe_reshape=0,
        output_shape=output_shape,
        out_dtype=self.dtype,
Cutlass fused path also uses full matmul size for output_shape
The cutlass variant builds output_shape with the original matmul dimensions even though the fused operator returns the reduce-scatter shard. On a torch version ≥2.8 this means output_shape[0] stays the full input length instead of the per‑rank length (input.shape[0] // tp_world_size), leading to shape mismatches or execution failures when the new signature is exercised. The output shape passed to the op must match the size after scattering, not before.
Closing as #26038 already fixes it.
Purpose
The signature of `symm_mem::fused_scaled_matmul_reduce_scatter` was changed in pytorch/pytorch#149247 to additionally take the arguments `output_shape`, `orig_scatter_dim`, and `scatter_dim_after_maybe_reshape`. vLLM's async_tp pass has a replacement graph containing this op but uses the old signature (https://github.com/vllm-project/vllm/blob/main/vllm/compilation/collective_fusion.py#L172), causing an error. Although there is a test case, it looks like it is not actually run on CI?
Test Plan
`pytest tests/compile/test_async_tp.py -k test_async_tp_pass_replace` passes locally.

cc @ProExpertProg @cascade812