 aten = torch.ops.aten
 patterns = PatternMatcherPass()

-_micro_pipeline_tp_ag_transpose_mm_enabled = True
+# Configs:
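+#   _ag_transpose_mm_enabled: also match matmuls that consume a permute([1, 0]) of the all-gather result.
+#   _ag_mm_last_dim_enabled: allow all-gather + matmul fusion when the gather dim is the last dim.
+#   _ag_mm_last_dim_no_splitcat_use: skip last-dim fusion when the all-gather result has users other
+#       than the matmuls (the split-cat result would have to be materialized).
+#   _mm_rs_last_dim_enabled: allow matmul + reduce-scatter fusion when the scatter dim is the last dim.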
+_ag_transpose_mm_enabled = False
+_ag_mm_last_dim_enabled = True
+_ag_mm_last_dim_no_splitcat_use = False
+_mm_rs_last_dim_enabled = True

-# Check performance if overhead of decomposition outweights pipeline wins
-_micro_pipeline_tp_ag_mm_last_dim_enabled = True
-_micro_pipeline_tp_ag_mm_last_dim_splitcatuse_enabled = True

-_micro_pipeline_tp_mm_rs_last_dim_enabled = True
+def _is_last_dim(t: torch.Tensor, dim: int) -> bool:
+    return dim == t.ndim - 1 or dim == -1


 def _is_backward(graph: torch.fx.Graph) -> bool:
@@ -617,7 +619,7 @@ def _find_consumer_matmuls(node: torch.fx.Node) -> list[_Matmul]:
             matmul = _ScaledMatmul.from_match([user])
             matmuls.append(matmul)
         elif (
-            _micro_pipeline_tp_ag_transpose_mm_enabled
+            _ag_transpose_mm_enabled
             and user.target == aten.permute.default
             and (user.args[1] == [1, 0] or user.args[1] == [0, 1])
         ):
@@ -762,11 +764,12 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
     if not is_symm_mem_enabled_for_group(group_name):
         return

-    if (
-        not _micro_pipeline_tp_ag_mm_last_dim_enabled
-        and gather_dim == _get_tensor(shard_node).ndim - 1
-    ):
-        return
+    if _is_last_dim(_get_tensor(shard_node), gather_dim):
+        if not _ag_mm_last_dim_enabled:
+            return
+
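+        # Heuristic: skip small shards; for a last-dim gather the decomposition
+        # overhead can outweigh the pipelining win below this (hardcoded) size.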
+        if _get_tensor(shard_node).shape[-1] < 1024:
+            return

     # Find consumer matmuls
     matmuls = _find_consumer_matmuls(ag_res_node)
@@ -784,13 +787,12 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
         return

     if (
-        _micro_pipeline_tp_ag_mm_last_dim_splitcatuse_enabled
-        and gather_dim == _get_tensor(shard_node).ndim - 1
+        _ag_mm_last_dim_no_splitcat_use
+        and _is_last_dim(_get_tensor(shard_node), gather_dim)
         and len(all_gather.res_node.users) > len(matmuls)
     ):
         # The result of ag-split-cat is used not only in matmuls.
         # Then it has to be materialized, which can have overhead.
-        # TODO: find out conditions of strideness when there is no overhead.
         log_strs.append(
             f"fuse_agmm lastdim ag-split-cat {len(all_gather.res_node.users)} used more than num matmuls"
         )
@@ -837,15 +839,16 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
                 matmul.replace_with(new_out_node)
                 matmul.erase()
         else:
-            if "val" in shard_node.meta:
-                restrided = restride_A_shard_for_fused_all_gather_matmul(
-                    _get_tensor(shard_node),
-                    gather_dim,
-                )
-                shard_node = graph.call_function(
-                    inductor_prims.force_stride_order,
-                    args=(shard_node, restrided.stride()),
-                )
+            if not _is_last_dim(_get_tensor(shard_node), gather_dim):
+                if "val" in shard_node.meta:
+                    restrided = restride_A_shard_for_fused_all_gather_matmul(
+                        _get_tensor(shard_node),
+                        gather_dim,
+                    )
+                    shard_node = graph.call_function(
+                        inductor_prims.force_stride_order,
+                        args=(shard_node, restrided.stride()),
+                    )
             fused_node = _insert_fused_all_gather_matmul(
                 graph, matmuls, shard_node, gather_dim, group_name
             )
@@ -1055,11 +1058,14 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch, log_strs) ->
10551058 log_strs .append ("fuse_mmrs not symm mem group" )
10561059 return
10571060
1058- if (
1059- not _micro_pipeline_tp_mm_rs_last_dim_enabled
1060- and orig_scatter_dim == _get_tensor (input_node ).ndim - 1
1061- ):
1062- return
1061+ if _is_last_dim (_get_tensor (input_node ), orig_scatter_dim ):
1062+ if not _mm_rs_last_dim_enabled :
1063+ return
1064+
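+        # Skip when each rank's slice of the last dim is small; below this
+        # (hardcoded) threshold the fused last-dim path may not pay off.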
+        group = torch._C._distributed_c10d._resolve_process_group(group_name)
+        group_size = group.size()
+        if _get_tensor(input_node).shape[-1] // group_size < 1024:
+            return

     # Currently fused_matmul_reduce_scatter doesn't return the matmul result,
     # so we can't apply the fusion if the matmul result is used by multiple
@@ -1113,16 +1119,17 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch, log_strs) ->

     graph = rs_wait_tensor_node.graph
     with graph.inserting_before(rs_wait_tensor_node):
-        # Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter
-        if "val" in matmul.A_node.meta:
-            restrided = restride_A_for_fused_matmul_reduce_scatter(
-                _get_tensor(matmul.A_node),
-                scatter_dim_after_maybe_reshape,
-            )
-            matmul.A_node = graph.call_function(
-                inductor_prims.force_stride_order,
-                args=(matmul.A_node, restrided.stride()),
-            )
+        if not _is_last_dim(_get_tensor(input_node), orig_scatter_dim):
+            # Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter
+            if "val" in matmul.A_node.meta:
+                restrided = restride_A_for_fused_matmul_reduce_scatter(
+                    _get_tensor(matmul.A_node),
+                    scatter_dim_after_maybe_reshape,
+                )
+                matmul.A_node = graph.call_function(
+                    inductor_prims.force_stride_order,
+                    args=(matmul.A_node, restrided.stride()),
+                )

     # Replace matched subgraph with fused matmul reduce scatter node
     fused_node = _insert_fused_matmul_reduce_scatter(
@@ -1295,11 +1302,12 @@ def micro_pipeline_tp_pass(
12951302 "async TP found no matching all-gather/reduce-scatter patterns for fusion"
12961303 )
12971304
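+    # Run all-gather + matmul fusion before matmul + reduce-scatter fusion.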
+    for all_gather in all_gathers:
+        fuse_all_gather_matmul(all_gather, log_strs)
+
     for reduce_scatter in reduce_scatters:
         fuse_matmul_reduce_scatter(reduce_scatter, log_strs)

-    for all_gather in all_gathers:
-        fuse_all_gather_matmul(all_gather, log_strs)
     trace_structured(
         "artifact",
         metadata_fn=lambda: {