
Commit 357dd7e

[asynctp] Optimize agmm lastdim via addmm_
stack-info: PR: #190, branch: IvanKobzarev/stack/7
1 parent 0df823d commit 357dd7e

File tree

2 files changed: +56 -58 lines changed


autoparallel/asynctp.py

Lines changed: 48 additions & 40 deletions
@@ -33,13 +33,15 @@
 aten = torch.ops.aten
 patterns = PatternMatcherPass()
 
-_micro_pipeline_tp_ag_transpose_mm_enabled = True
+# Configs:
+_ag_transpose_mm_enabled = False
+_ag_mm_last_dim_enabled = True
+_ag_mm_last_dim_no_splitcat_use = False
+_mm_rs_last_dim_enabled = True
 
-# Check performance if overhead of decomposition outweights pipeline wins
-_micro_pipeline_tp_ag_mm_last_dim_enabled = True
-_micro_pipeline_tp_ag_mm_last_dim_splitcatuse_enabled = True
 
-_micro_pipeline_tp_mm_rs_last_dim_enabled = True
+def _is_last_dim(t: torch.Tensor, dim: int) -> bool:
+    return dim == t.ndim - 1 or dim == -1
 
 
 def _is_backward(graph: torch.fx.Graph) -> bool:
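
Note on the new _is_last_dim helper above: it treats both dim == -1 and dim == t.ndim - 1 as the last dimension, so callers can pass either form. A minimal standalone sketch (the assertions are illustrative, not part of the commit):

import torch

def _is_last_dim(t: torch.Tensor, dim: int) -> bool:
    return dim == t.ndim - 1 or dim == -1

t = torch.empty(2, 3, 4)
assert _is_last_dim(t, 2)      # explicit last dim
assert _is_last_dim(t, -1)     # negative indexing counts as well
assert not _is_last_dim(t, 0)  # leading dims do not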
@@ -617,7 +619,7 @@ def _find_consumer_matmuls(node: torch.fx.Node) -> list[_Matmul]:
             matmul = _ScaledMatmul.from_match([user])
             matmuls.append(matmul)
         elif (
-            _micro_pipeline_tp_ag_transpose_mm_enabled
+            _ag_transpose_mm_enabled
             and user.target == aten.permute.default
             and (user.args[1] == [1, 0] or user.args[1] == [0, 1])
         ):
@@ -762,11 +764,12 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
     if not is_symm_mem_enabled_for_group(group_name):
         return
 
-    if (
-        not _micro_pipeline_tp_ag_mm_last_dim_enabled
-        and gather_dim == _get_tensor(shard_node).ndim - 1
-    ):
-        return
+    if _is_last_dim(_get_tensor(shard_node), gather_dim):
+        if not _ag_mm_last_dim_enabled:
+            return
+
+        if _get_tensor(shard_node).shape[-1] < 1024:
+            return
 
     # Find consumer matmuls
     matmuls = _find_consumer_matmuls(ag_res_node)
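
The gate above only applies to the last-gather-dim path: fusion is skipped when it is disabled via _ag_mm_last_dim_enabled, or when the shard's last dimension is narrower than 1024 and the decomposition overhead is unlikely to pay off. In isolation the check amounts to something like this sketch (the helper name and the standalone form are assumptions for illustration; the pass itself works on FX nodes via _get_tensor):

import torch

def _should_skip_ag_mm_last_dim(shard: torch.Tensor, gather_dim: int,
                                enabled: bool = True, min_last_dim: int = 1024) -> bool:
    # Mirror of the gating added in fuse_all_gather_matmul (illustrative only).
    if gather_dim == shard.ndim - 1 or gather_dim == -1:
        if not enabled:
            return True
        if shard.shape[-1] < min_last_dim:
            return True
    return False

assert _should_skip_ag_mm_last_dim(torch.empty(8, 512), gather_dim=-1)       # too narrow, skip fusion
assert not _should_skip_ag_mm_last_dim(torch.empty(8, 2048), gather_dim=-1)  # wide enough, try to fuse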
@@ -784,13 +787,12 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
         return
 
     if (
-        _micro_pipeline_tp_ag_mm_last_dim_splitcatuse_enabled
-        and gather_dim == _get_tensor(shard_node).ndim - 1
+        _ag_mm_last_dim_no_splitcat_use
+        and _is_last_dim(_get_tensor(shard_node), gather_dim)
         and len(all_gather.res_node.users) > len(matmuls)
     ):
         # The result of ag-split-cat is used not only in matmuls.
         # Then it has to be materialized, which can have overhead.
-        # TODO: find out conditions of strideness when there is no overhead.
         log_strs.append(
             f"fuse_agmm lastdim ag-split-cat {len(all_gather.res_node.users)} used more than num matmuls"
         )
@@ -837,15 +839,16 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
         matmul.replace_with(new_out_node)
         matmul.erase()
     else:
-        if "val" in shard_node.meta:
-            restrided = restride_A_shard_for_fused_all_gather_matmul(
-                _get_tensor(shard_node),
-                gather_dim,
-            )
-            shard_node = graph.call_function(
-                inductor_prims.force_stride_order,
-                args=(shard_node, restrided.stride()),
-            )
+        if not _is_last_dim(_get_tensor(shard_node), gather_dim):
+            if "val" in shard_node.meta:
+                restrided = restride_A_shard_for_fused_all_gather_matmul(
+                    _get_tensor(shard_node),
+                    gather_dim,
+                )
+                shard_node = graph.call_function(
+                    inductor_prims.force_stride_order,
+                    args=(shard_node, restrided.stride()),
+                )
         fused_node = _insert_fused_all_gather_matmul(
             graph, matmuls, shard_node, gather_dim, group_name
         )
@@ -1055,11 +1058,14 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch, log_strs) ->
         log_strs.append("fuse_mmrs not symm mem group")
         return
 
-    if (
-        not _micro_pipeline_tp_mm_rs_last_dim_enabled
-        and orig_scatter_dim == _get_tensor(input_node).ndim - 1
-    ):
-        return
+    if _is_last_dim(_get_tensor(input_node), orig_scatter_dim):
+        if not _mm_rs_last_dim_enabled:
+            return
+
+        group = torch._C._distributed_c10d._resolve_process_group(group_name)
+        group_size = group.size()
+        if _get_tensor(input_node).shape[-1] // group_size < 1024:
+            return
 
     # Currently fused_matmul_reduce_scatter doesn't return the matmul result,
     # so we can't apply the fusion if the matmul result is used by multiple
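
The reduce-scatter path gets an analogous gate, except the width check is applied to the per-rank slice: after a last-dim reduce-scatter each rank keeps shape[-1] // group_size columns. A standalone sketch under the same assumptions (hypothetical helper name; group_size is passed in instead of resolving the process group):

import torch

def _should_skip_mm_rs_last_dim(inp: torch.Tensor, scatter_dim: int, group_size: int,
                                enabled: bool = True, min_last_dim: int = 1024) -> bool:
    # Mirror of the gating added in fuse_matmul_reduce_scatter (illustrative only).
    if scatter_dim == inp.ndim - 1 or scatter_dim == -1:
        if not enabled:
            return True
        if inp.shape[-1] // group_size < min_last_dim:
            return True
    return False

assert _should_skip_mm_rs_last_dim(torch.empty(8, 4096), scatter_dim=-1, group_size=8)       # 512 per rank, skip
assert not _should_skip_mm_rs_last_dim(torch.empty(8, 16384), scatter_dim=-1, group_size=8)  # 2048 per rank, fuse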
@@ -1113,16 +1119,17 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch, log_strs) ->
 
     graph = rs_wait_tensor_node.graph
     with graph.inserting_before(rs_wait_tensor_node):
-        # Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter
-        if "val" in matmul.A_node.meta:
-            restrided = restride_A_for_fused_matmul_reduce_scatter(
-                _get_tensor(matmul.A_node),
-                scatter_dim_after_maybe_reshape,
-            )
-            matmul.A_node = graph.call_function(
-                inductor_prims.force_stride_order,
-                args=(matmul.A_node, restrided.stride()),
-            )
+        if not _is_last_dim(_get_tensor(input_node), orig_scatter_dim):
+            # Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter
+            if "val" in matmul.A_node.meta:
+                restrided = restride_A_for_fused_matmul_reduce_scatter(
+                    _get_tensor(matmul.A_node),
+                    scatter_dim_after_maybe_reshape,
+                )
+                matmul.A_node = graph.call_function(
+                    inductor_prims.force_stride_order,
+                    args=(matmul.A_node, restrided.stride()),
+                )
 
         # Replace matched subgraph with fused matmul reduce scatter node
         fused_node = _insert_fused_matmul_reduce_scatter(
@@ -1295,11 +1302,12 @@ def micro_pipeline_tp_pass(
             "async TP found no matching all-gather/reduce-scatter patterns for fusion"
         )
 
+    for all_gather in all_gathers:
+        fuse_all_gather_matmul(all_gather, log_strs)
+
     for reduce_scatter in reduce_scatters:
         fuse_matmul_reduce_scatter(reduce_scatter, log_strs)
 
-    for all_gather in all_gathers:
-        fuse_all_gather_matmul(all_gather, log_strs)
     trace_structured(
         "artifact",
         metadata_fn=lambda: {

autoparallel/asynctp_ops.py

Lines changed: 8 additions & 18 deletions
@@ -418,9 +418,6 @@ def _fused_all_gather_matmul_impl(
     group = c10d._resolve_process_group(group_name)
 
     if gather_dim == A_shard.ndim - 1:
-        # Implementation for gathering on last dimension of matmul (N)
-        # A_shard splitted column wise
-        # A_shard: [A0, A1, ... , Ags]
         return _fused_all_gather_matmul_last_gather_dim_impl(
             mm_out_op,
             A_shard,
@@ -625,11 +622,6 @@ def _fused_all_gather_matmul_last_gather_dim_impl(
     def unflatten(t: torch.Tensor) -> torch.Tensor:
         return t.view(*leading_dims, -1)
 
-    A_out_leading_dims = list(A_shard.shape[:-1])
-
-    def unflatten_A_out(t: torch.Tensor) -> torch.Tensor:
-        return t.view(*A_out_leading_dims, -1)
-
     A_flat_out = A_shard_flat.new_empty(
         A_shard_flat.shape[0] * group.size(),
         A_shard_flat.shape[1],
@@ -645,19 +637,17 @@ def unflatten_A_out(t: torch.Tensor) -> torch.Tensor:
         for B, out_dtype in zip(Bs, out_dtypes)
     ]
 
-    # Additional allocation for partials output,
-    # That will be reduced into output.
-    output_partials = [torch.empty_like(out) for out in outputs]
-
     first = True
 
     def default_consumer(shard: torch.Tensor, rank: int) -> None:
         nonlocal first
         for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)):
-            out = outputs[idx] if first else output_partials[idx]
-            mm_out_op(shard, B_shards[idx][rank], **kwargs, out=out)
-            if not first:
-                outputs[idx] += output_partials[idx]
+            out = outputs[idx]
+            if first:
+                torch.ops.aten.mm.out(shard, B_shards[idx][rank], **kwargs, out=out)
+            else:
+                out.addmm_(shard, B_shards[idx][rank])
+
         first = False
 
     _pipelined_all_gather_and_consume_last_dim(
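
This hunk is the optimization named in the commit title: instead of writing every partial product after the first into a freshly allocated scratch buffer and then adding it into the output, the consumer now writes the first partial with mm.out and accumulates the rest in place with addmm_, dropping the output_partials allocation entirely. A self-contained sketch of the equivalence (shapes, the plain Python loop, and the assert are illustrative; the real code runs inside the pipelined all-gather consumer):

import torch

A_shards = [torch.randn(16, 32) for _ in range(4)]  # per-rank column blocks of A (gathered on last dim)
B_shards = [torch.randn(32, 64) for _ in range(4)]  # matching row blocks of B

# Old scheme: mm into a scratch buffer, then reduce into the output.
out_ref = torch.zeros(16, 64)
for a, b in zip(A_shards, B_shards):
    partial = torch.empty(16, 64)
    torch.ops.aten.mm.out(a, b, out=partial)
    out_ref += partial

# New scheme: the first shard writes the output directly, later shards accumulate in place.
out = torch.empty(16, 64)
for i, (a, b) in enumerate(zip(A_shards, B_shards)):
    if i == 0:
        torch.ops.aten.mm.out(a, b, out=out)
    else:
        out.addmm_(a, b)  # out += a @ b, no extra allocation

torch.testing.assert_close(out, out_ref)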
@@ -672,7 +662,7 @@ def default_consumer(shard: torch.Tensor, rank: int) -> None:
     # This path is inefficient and will be filtered out at passes stage
     # Added only for completness.
     A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1)
-    ret_A = unflatten_A_out(A_split_cat_out_flat)
+    ret_A = unflatten(A_split_cat_out_flat)
 
     return ret_A, [unflatten(output) for output in outputs]
 
@@ -1134,7 +1124,7 @@ def _fused_matmul_reduce_scatter_impl(
     out_shape = [*A.shape[:-1], B.shape[1]]
     out_shape[scatter_dim] //= group.size()
 
-    if scatter_dim == A.ndim - 1:
+    if scatter_dim == A.ndim - 1 or scatter_dim == -1:
         B_shards = B.chunk(group.size(), dim=B.ndim - 1)
         A_flat = A.flatten(0, -2)
 