Skip to content

Commit 16f8822

Browse files
authored
[Transform][Relax] Handle is_group argument in IPC AllReduce (#17201)
* [Transform][Relax] Handle `is_group` argument in IPC AllReduce

  The `relax.transform.IPCAllReduceRewrite` pass rewrites calls to `"runtime.disco.allreduce"` to instead call an optimized `"runtime.disco.cuda_ipc.custom_allreduce"` version. When the legalization of `R.ccl.allreduce` was updated in #17180 to provide an `in_group` argument, the `IPCAllReduceRewrite` pass was not updated. This commit updates `IPCAllReduceRewrite` to handle the additional `in_group` argument.

* lint fix

* lint fix
1 parent 9e88018 commit 16f8822

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

python/tvm/relax/transform/ipc_allreduce_rewrite.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-re
9797
# Return if the call is not a summation all-reduce.
9898
return
9999

100-
assert len(call.args) == 3
101-
allreduce_input = call.args[0]
100+
assert len(call.args) == 4
101+
allreduce_input, _strategy, _ingroup, allreduce_output = call.args
102102
alloc_tensor = self.alloc_map.get(allreduce_input, None)
103103
if alloc_tensor is None or alloc_tensor.args[3].value != "global":
104104
# Return if the allocation of all-reduce input is not recorded,
@@ -113,9 +113,13 @@ def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-re
113113
alloc_tensor.args[2],
114114
relax.StringImm("ipc_memory"),
115115
)
116+
116117
self.binding_replacement_map[call] = relax.Call(
117118
relax.ExternFunc("runtime.disco.cuda_ipc.custom_allreduce"),
118-
args=[call.args[0], relax.PrimValue(self.allreduce_strategy), call.args[2]],
119+
# The "cuda_ipc.custom_allreduce" implementation does not
120+
# yet support num_groups>1, and therefore does not use the
121+
# `in_group` argument.
122+
[allreduce_input, relax.PrimValue(self.allreduce_strategy), allreduce_output],
119123
)
120124

121125

tests/python/relax/test_transform_ipc_allreduce_rewrite.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def main(shape: R.Shape(["m", "n"])): # type: ignore
3737
alloc1: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor( # type: ignore
3838
R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("global")
3939
)
40-
_: R.Object = R.call_packed("runtime.disco.allreduce", lv1, R.shape([0]), alloc1)
40+
_: R.Object = R.call_packed(
41+
"runtime.disco.allreduce", lv1, R.shape([0]), R.prim_value(True), alloc1
42+
)
4143
return alloc1
4244

4345
@I.ir_module
@@ -85,7 +87,9 @@ def main(shape: R.Shape(["m", "n"])): # type: ignore
8587
alloc1: R.Tensor((m * n,), dtype="float16") = R.builtin.alloc_tensor( # type: ignore
8688
R.shape([m * n]), R.dtype("float16"), R.prim_value(0), R.str("global")
8789
)
88-
_: R.Object = R.call_packed("runtime.disco.allreduce", lv1, R.shape([0]), alloc1)
90+
_: R.Object = R.call_packed(
91+
"runtime.disco.allreduce", lv1, R.shape([0]), R.prim_value(False), alloc1
92+
)
8993
return alloc1
9094

9195
@I.ir_module
@@ -137,7 +141,9 @@ def main(shape: R.Shape(["m", "n"])): # type: ignore
137141
alloc1: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor( # type: ignore
138142
R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("global")
139143
)
140-
_: R.Object = R.call_packed("runtime.disco.allreduce", lv1, R.shape([1]), alloc1)
144+
_: R.Object = R.call_packed(
145+
"runtime.disco.allreduce", lv1, R.shape([1]), R.prim_value(True), alloc1
146+
)
141147
return alloc1
142148

143149
allreduce_strategy = 1
@@ -146,6 +152,4 @@ def main(shape: R.Shape(["m", "n"])): # type: ignore
146152

147153

148154
if __name__ == "__main__":
149-
test_ipc_allreduce_rewrite()
150-
test_ipc_allreduce_spread_along_reshape()
151-
test_ipc_allreduce_skip_reducer_other_than_sum()
155+
tvm.testing.main()

0 commit comments

Comments
 (0)