Conversation

@angelayi (Contributor) commented Oct 9, 2025

Purpose

The signature of symm_mem::fused_scaled_matmul_reduce_scatter was changed in pytorch/pytorch#149247 to additionally take the arguments output_shape, orig_scatter_dim, and scatter_dim_after_maybe_reshape. vLLM's async_tp pass has a replacement graph containing this op but still uses the old signature (https://github.com/vllm-project/vllm/blob/main/vllm/compilation/collective_fusion.py#L172), causing an error.
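
For reference, a rough sketch of the old vs. new call shapes (a sketch only: argument order and names are taken from this description and the diff below, the tensors and group_name are placeholders, and a symmetric-memory process group must already be set up for the op to actually run):

# Old signature (PyTorch < 2.8): a single scatter_dim argument.
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
    input, mat2, scale_a, scale_b, "avg",
    scatter_dim=0,
    out_dtype=out_dtype,
    group_name=group_name,
)

# New signature (PyTorch >= 2.8.0.dev): scatter_dim is replaced by
# orig_scatter_dim / scatter_dim_after_maybe_reshape, and an explicit
# output_shape is required.
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
    input, mat2, scale_a, scale_b, "avg",
    orig_scatter_dim=0,
    scatter_dim_after_maybe_reshape=0,
    output_shape=[input.shape[0], mat2.shape[1]],  # as computed in the diff below
    out_dtype=out_dtype,
    group_name=group_name,
)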

Although there is a test case, it looks like it's not actually run on CI?


Test Plan

pytest tests/compile/test_async_tp.py -k test_async_tp_pass_replace passes locally.

cc @ProExpertProg @cascade812

@gemini-code-assist (bot) left a comment

Code Review

This pull request correctly addresses a breaking change in the fused_scaled_matmul_reduce_scatter op signature from a newer PyTorch version by adding version-conditional logic. My main feedback is to refactor the duplicated code blocks into a shared helper function to improve maintainability.

Comment on lines +172 to +204
if is_torch_equal_or_newer("2.8.0.dev"):
    # TODO: This fails in the dynamic shapes case because the shapes
    # get specialized
    output_shape = (
        torch.ops.aten.sym_size.int(input, 0),
        torch.ops.aten.sym_size.int(mat2, 1),
    )
    gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        input,
        mat2,
        scale_a,
        scale_b,
        "avg",
        orig_scatter_dim=0,
        scatter_dim_after_maybe_reshape=0,
        output_shape=output_shape,
        out_dtype=self.dtype,
        group_name=self.tp.device_group.group_name,
    )
else:
    # For older versions, use the old signature
    gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        input,
        mat2,
        scale_a,
        scale_b,
        "avg",
        scatter_dim=0,
        out_dtype=self.dtype,
        group_name=self.tp.device_group.group_name,
    )

return gemm_rs

Severity: high

This version-checking logic for fused_scaled_matmul_reduce_scatter is duplicated in CutlassScaledMMReduceScatterPattern.replacement (lines 320-353). To improve maintainability and reduce redundancy, consider extracting this logic into a shared helper function or method. This would centralize the call to the op, making future updates easier.
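
A minimal sketch of such a helper, assuming a module-level function in collective_fusion.py (the name is illustrative, not part of this PR; it mirrors the replacement code above, including the output_shape that the Codex comments below flag):

def _fused_scaled_mm_reduce_scatter(input, mat2, scale_a, scale_b,
                                    out_dtype, group_name):
    # Dispatch to the fused_scaled_matmul_reduce_scatter signature that
    # matches the installed torch version, using the same
    # is_torch_equal_or_newer helper this file already imports.
    if is_torch_equal_or_newer("2.8.0.dev"):
        output_shape = (
            torch.ops.aten.sym_size.int(input, 0),
            torch.ops.aten.sym_size.int(mat2, 1),
        )
        return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
            input, mat2, scale_a, scale_b, "avg",
            orig_scatter_dim=0,
            scatter_dim_after_maybe_reshape=0,
            output_shape=output_shape,
            out_dtype=out_dtype,
            group_name=group_name,
        )
    # Older torch versions: the original single scatter_dim argument.
    return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        input, mat2, scale_a, scale_b, "avg",
        scatter_dim=0,
        out_dtype=out_dtype,
        group_name=group_name,
    )

Both replacement methods could then reduce to a single line, e.g. gemm_rs = _fused_scaled_mm_reduce_scatter(input, mat2, scale_a, scale_b, self.dtype, self.tp.device_group.group_name).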

@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +172 to +188
if is_torch_equal_or_newer("2.8.0.dev"):
    # TODO: This fails in the dynamic shapes case because the shapes
    # get specialized
    output_shape = (
        torch.ops.aten.sym_size.int(input, 0),
        torch.ops.aten.sym_size.int(mat2, 1),
    )
    gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        input,
        mat2,
        scale_a,
        scale_b,
        "avg",
        orig_scatter_dim=0,
        scatter_dim_after_maybe_reshape=0,
        output_shape=output_shape,
        out_dtype=self.dtype,


P1: Compute reduce-scatter output_shape per shard

When targeting the new fused_scaled_matmul_reduce_scatter signature, output_shape is derived directly from input and mat2 without accounting for the tensor-parallel world size. In the unfused graph this matmul output is immediately reduced along dim‑0, so each rank ultimately sees a first dimension of scaled_mm.size(0) // tp_world_size. Passing the pre‑scatter size (input.shape[0]) will request the wrong shape from the fused op and either misallocate or fail once torch 2.8 executes this branch. output_shape[0] should reflect the reduce-scatter result (divide by self.tp_size).
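
A hedged sketch of the change this comment is asking for (assuming self.tp_size holds the tensor-parallel world size, as the wording above implies):

# output_shape should describe the per-rank result of the reduce-scatter,
# so the scattered dimension (dim 0 here) is divided by the TP world size.
output_shape = (
    torch.ops.aten.sym_size.int(input, 0) // self.tp_size,
    torch.ops.aten.sym_size.int(mat2, 1),
)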


Comment on lines +320 to +337
if is_torch_equal_or_newer("2.8.0.dev"):
    # TODO: This fails in the dynamic shapes case because the shapes
    # get specialized
    output_shape = (
        torch.ops.aten.sym_size.int(input, 0),
        torch.ops.aten.sym_size.int(mat2, 1),
    )

    gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        input,
        mat2,
        scale_a,
        scale_b,
        "avg",
        orig_scatter_dim=0,
        scatter_dim_after_maybe_reshape=0,
        output_shape=output_shape,
        out_dtype=self.dtype,


P1: Cutlass fused path also uses full matmul size for output_shape

The cutlass variant builds output_shape with the original matmul dimensions even though the fused operator returns the reduce-scatter shard. On a torch version ≥2.8 this means output_shape[0] stays the full input length instead of the per‑rank length (input.shape[0] // tp_world_size), leading to shape mismatches or execution failures when the new signature is exercised. The output shape passed to the op must match the size after scattering, not before.


@angelayi (Author) commented Oct 9, 2025

Closing, as #26038 already fixes it.

@angelayi closed this Oct 9, 2025