85 changes: 64 additions & 21 deletions vllm/compilation/collective_fusion.py
@@ -19,7 +19,7 @@
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer

from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
@@ -169,16 +169,37 @@ def replacement(
             scale_a: torch.Tensor,
             scale_b: torch.Tensor,
         ) -> torch.Tensor:
-            gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
-                input,
-                mat2,
-                scale_a,
-                scale_b,
-                "avg",
-                scatter_dim=0,
-                out_dtype=self.dtype,
-                group_name=self.tp.device_group.group_name,
-            )
+            if is_torch_equal_or_newer("2.8.0.dev"):
+                # TODO: This fails in the dynamic shapes case because the
+                # shapes get specialized
+                output_shape = (
+                    torch.ops.aten.sym_size.int(input, 0),
+                    torch.ops.aten.sym_size.int(mat2, 1),
+                )
+                gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
+                    input,
+                    mat2,
+                    scale_a,
+                    scale_b,
+                    "avg",
+                    orig_scatter_dim=0,
+                    scatter_dim_after_maybe_reshape=0,
+                    output_shape=output_shape,
+                    out_dtype=self.dtype,
Comment on lines +172 to +188 (P1): Compute reduce-scatter output_shape per shard

When targeting the new fused_scaled_matmul_reduce_scatter signature, output_shape is derived directly from input and mat2 without accounting for the tensor-parallel world size. In the unfused graph this matmul output is immediately reduce-scattered along dim 0, so each rank ultimately sees a first dimension of scaled_mm.size(0) // tp_world_size. Passing the pre-scatter size (input.shape[0]) requests the wrong shape from the fused op and will either misallocate or fail once torch 2.8 executes this branch. output_shape[0] should reflect the reduce-scatter result (divide by self.tp_size).

+                    group_name=self.tp.device_group.group_name,
+                )
+            else:
+                # For older versions, use the old signature
+                gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
+                    input,
+                    mat2,
+                    scale_a,
+                    scale_b,
+                    "avg",
+                    scatter_dim=0,
+                    out_dtype=self.dtype,
+                    group_name=self.tp.device_group.group_name,
+                )
 
             return gemm_rs
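The per-shard fix described in the P1 review comment above can be sketched as a small pure helper (hypothetical name `per_shard_output_shape`; assumes dim 0 divides evenly across ranks, as the reduce-scatter requires):

```python
def per_shard_output_shape(m: int, n: int, tp_world_size: int) -> tuple[int, int]:
    """First dimension of the fused op's output is the per-rank shard:
    a reduce-scatter along dim 0 leaves each rank m // tp_world_size rows."""
    assert m % tp_world_size == 0, "dim 0 must divide evenly across ranks"
    return (m // tp_world_size, n)
```

In the pattern itself this corresponds to dividing the first `torch.ops.aten.sym_size.int(input, 0)` term by `self.tp_size` when building `output_shape`.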
Comment on lines +172 to 204 (Contributor, severity: high)

This version-checking logic for fused_scaled_matmul_reduce_scatter is duplicated in CutlassScaledMMReduceScatterPattern.replacement (lines 320-353). To improve maintainability and reduce redundancy, consider extracting it into a shared helper function or method; this would centralize the call to the op and make future updates easier.

@@ -296,16 +317,38 @@ def replacement(
             scale_b: torch.Tensor,
             cutlass_mm_output: torch.Tensor,
         ) -> torch.Tensor:
-            gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
-                input,
-                mat2,
-                scale_a,
-                scale_b,
-                "avg",
-                scatter_dim=0,
-                out_dtype=self.dtype,
-                group_name=self.tp.device_group.group_name,
-            )
+            if is_torch_equal_or_newer("2.8.0.dev"):
+                # TODO: This fails in the dynamic shapes case because the
+                # shapes get specialized
+                output_shape = (
+                    torch.ops.aten.sym_size.int(input, 0),
+                    torch.ops.aten.sym_size.int(mat2, 1),
+                )
+
+                gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
+                    input,
+                    mat2,
+                    scale_a,
+                    scale_b,
+                    "avg",
+                    orig_scatter_dim=0,
+                    scatter_dim_after_maybe_reshape=0,
+                    output_shape=output_shape,
+                    out_dtype=self.dtype,
Comment on lines +320 to +337 (P1): Cutlass fused path also uses the full matmul size for output_shape

The cutlass variant builds output_shape from the original matmul dimensions even though the fused operator returns the reduce-scatter shard. On torch >= 2.8 this means output_shape[0] stays the full input length instead of the per-rank length (input.shape[0] // tp_world_size), leading to shape mismatches or execution failures when the new signature is exercised. The output shape passed to the op must match the size after scattering, not before.

+                    group_name=self.tp.device_group.group_name,
+                )
+            else:
+                # For older versions, use the old signature
+                gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
+                    input,
+                    mat2,
+                    scale_a,
+                    scale_b,
+                    "avg",
+                    scatter_dim=0,
+                    out_dtype=self.dtype,
+                    group_name=self.tp.device_group.group_name,
+                )
 
             return gemm_rs
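The shared-helper refactor suggested by the duplication comment could be sketched as a function that builds the version-dependent keyword arguments in one place (hypothetical names throughout; `new_signature` stands in for the `is_torch_equal_or_newer("2.8.0.dev")` check, and the per-shard `output_shape` division reflects the P1 comments above):

```python
def fused_gemm_rs_kwargs(m: int, n: int, tp_size: int, new_signature: bool) -> dict:
    """Version-dependent kwargs for fused_scaled_matmul_reduce_scatter:
    a sketch of the shared helper both replacement() bodies could call."""
    if new_signature:  # torch >= 2.8.0.dev
        return {
            "orig_scatter_dim": 0,
            "scatter_dim_after_maybe_reshape": 0,
            # per-rank rows after the reduce-scatter, not the full matmul size
            "output_shape": (m // tp_size, n),
        }
    # legacy signature
    return {"scatter_dim": 0}
```

Each pattern would then merge these with the common positional arguments (`input`, `mat2`, the scales, `"avg"`, `out_dtype`, `group_name`), so a future signature change touches only one site.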
