@@ -625,11 +625,6 @@ def _fused_all_gather_matmul_last_gather_dim_impl(
     def unflatten(t: torch.Tensor) -> torch.Tensor:
         return t.view(*leading_dims, -1)
 
-    A_out_leading_dims = list(A_shard.shape[:-1])
-
-    def unflatten_A_out(t: torch.Tensor) -> torch.Tensor:
-        return t.view(*A_out_leading_dims, -1)
-
     A_flat_out = A_shard_flat.new_empty(
         A_shard_flat.shape[0] * group.size(),
         A_shard_flat.shape[1],
@@ -645,19 +640,17 @@ def unflatten_A_out(t: torch.Tensor) -> torch.Tensor:
         for B, out_dtype in zip(Bs, out_dtypes)
     ]
 
-    # Additional allocation for partials output,
-    # That will be reduced into output.
-    output_partials = [torch.empty_like(out) for out in outputs]
-
     first = True
 
     def default_consumer(shard: torch.Tensor, rank: int) -> None:
         nonlocal first
         for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)):
-            out = outputs[idx] if first else output_partials[idx]
-            mm_out_op(shard, B_shards[idx][rank], **kwargs, out=out)
-            if not first:
-                outputs[idx] += output_partials[idx]
+            out = outputs[idx]
+            if first:
+                torch.ops.aten.mm.out(shard, B_shards[idx][rank], **kwargs, out=out)
+            else:
+                out.addmm_(shard, B_shards[idx][rank])
+
         first = False
 
     _pipelined_all_gather_and_consume_last_dim(
@@ -672,7 +665,7 @@ def default_consumer(shard: torch.Tensor, rank: int) -> None:
     # This path is inefficient and will be filtered out at passes stage
     # Added only for completness.
     A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1)
-    ret_A = unflatten_A_out(A_split_cat_out_flat)
+    ret_A = unflatten(A_split_cat_out_flat)
 
     return ret_A, [unflatten(output) for output in outputs]
 
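For context on the change to default_consumer: the PR drops the output_partials buffers and accumulates directly into the preallocated outputs, writing the first shard's product with an mm out-variant and folding later shards in with Tensor.addmm_. Below is a minimal standalone sketch of that accumulation pattern; the tensor names and shapes are hypothetical and not taken from the PR.

import torch

# Hypothetical shapes: three gathered A shards, one B matrix, one preallocated output.
shards = [torch.randn(4, 8) for _ in range(3)]
B = torch.randn(8, 5)
out = torch.empty(4, 5)

for i, shard in enumerate(shards):
    if i == 0:
        # First shard overwrites the output buffer (analogous to torch.ops.aten.mm.out).
        torch.mm(shard, B, out=out)
    else:
        # Later shards accumulate in place: out = out + shard @ B.
        out.addmm_(shard, B)

# The in-place accumulation matches summing the per-shard products explicitly.
torch.testing.assert_close(out, sum(shard @ B for shard in shards))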