diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 71274420c342..6723854c3e92 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -162,9 +162,12 @@ def replacement(input: torch.Tensor, mat2: torch.Tensor, scale_a, scale_b, "avg", - scatter_dim=0, - out_dtype=self.dtype, - group_name=self.tp.device_group.group_name, + 0, # scatter_dim + self.tp.device_group.group_name, # group_name + None, # bias_node + None, # result_scale_node + self.dtype, # out_dtype + False, # use_fast_accum ) return gemm_rs @@ -274,9 +277,12 @@ def replacement(input: torch.Tensor, mat2: torch.Tensor, scale_a, scale_b, "avg", - scatter_dim=0, - out_dtype=self.dtype, - group_name=self.tp.device_group.group_name, + 0, # scatter_dim + self.tp.device_group.group_name, # group_name + None, # bias_node + None, # result_scale_node + self.dtype, # out_dtype + False, # use_fast_accum ) return gemm_rs