diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 01fd9f9a1c8e..87651c3fd05e 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -19,7 +19,7 @@
 )
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
 
 from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
@@ -169,16 +169,37 @@ def replacement(
             scale_a: torch.Tensor,
             scale_b: torch.Tensor,
         ) -> torch.Tensor:
-            gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
-                input,
-                mat2,
-                scale_a,
-                scale_b,
-                "avg",
-                scatter_dim=0,
-                out_dtype=self.dtype,
-                group_name=self.tp.device_group.group_name,
-            )
+            if is_torch_equal_or_newer("2.8.0.dev"):
+                # TODO: This fails in the dynamic shapes case because the shapes
+                # get specialized
+                output_shape = (
+                    torch.ops.aten.sym_size.int(input, 0),
+                    torch.ops.aten.sym_size.int(mat2, 1),
+                )
+                gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
+                    input,
+                    mat2,
+                    scale_a,
+                    scale_b,
+                    "avg",
+                    orig_scatter_dim=0,
+                    scatter_dim_after_maybe_reshape=0,
+                    output_shape=output_shape,
+                    out_dtype=self.dtype,
+                    group_name=self.tp.device_group.group_name,
+                )
+            else:
+                # For older versions, use the old signature
+                gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
+                    input,
+                    mat2,
+                    scale_a,
+                    scale_b,
+                    "avg",
+                    scatter_dim=0,
+                    out_dtype=self.dtype,
+                    group_name=self.tp.device_group.group_name,
+                )
 
             return gemm_rs
 
@@ -296,16 +317,38 @@ def replacement(
             scale_b: torch.Tensor,
             cutlass_mm_output: torch.Tensor,
         ) -> torch.Tensor:
-            gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
-                input,
-                mat2,
-                scale_a,
-                scale_b,
-                "avg",
-                scatter_dim=0,
-                out_dtype=self.dtype,
-                group_name=self.tp.device_group.group_name,
-            )
+            if is_torch_equal_or_newer("2.8.0.dev"):
+                # TODO: This fails in the dynamic shapes case because the shapes
+                # get specialized
+                output_shape = (
+                    torch.ops.aten.sym_size.int(input, 0),
+                    torch.ops.aten.sym_size.int(mat2, 1),
+                )
+
+                gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
+                    input,
+                    mat2,
+                    scale_a,
+                    scale_b,
+                    "avg",
+                    orig_scatter_dim=0,
+                    scatter_dim_after_maybe_reshape=0,
+                    output_shape=output_shape,
+                    out_dtype=self.dtype,
+                    group_name=self.tp.device_group.group_name,
+                )
+            else:
+                # For older versions, use the old signature
+                gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
+                    input,
+                    mat2,
+                    scale_a,
+                    scale_b,
+                    "avg",
+                    scatter_dim=0,
+                    out_dtype=self.dtype,
+                    group_name=self.tp.device_group.group_name,
+                )
 
             return gemm_rs
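
Note: the version-gated call appears twice in this diff with identical structure. A minimal sketch of one way to factor the gating out of both replacement functions is shown below; it only reuses names from the patch (torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter, is_torch_equal_or_newer, and the argument names), while the helper name itself is hypothetical and not part of the change.

    import torch

    from vllm.utils import is_torch_equal_or_newer


    def _scaled_mm_reduce_scatter_kwargs(
        input: torch.Tensor,
        mat2: torch.Tensor,
        out_dtype: torch.dtype,
        group_name: str,
    ) -> dict:
        # Hypothetical helper: build the version-dependent kwargs once so both
        # replacement functions can share them.
        kwargs = {"out_dtype": out_dtype, "group_name": group_name}
        if is_torch_equal_or_newer("2.8.0.dev"):
            # torch >= 2.8 splits the scatter dim into the original and
            # post-reshape dims and takes the output shape explicitly.
            kwargs.update(
                orig_scatter_dim=0,
                scatter_dim_after_maybe_reshape=0,
                output_shape=(
                    torch.ops.aten.sym_size.int(input, 0),
                    torch.ops.aten.sym_size.int(mat2, 1),
                ),
            )
        else:
            # Older torch keeps the single scatter_dim argument.
            kwargs["scatter_dim"] = 0
        return kwargs

Each replacement could then call torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(input, mat2, scale_a, scale_b, "avg", **kwargs); the patch as written keeps the two branches inline instead.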