Commit 64ab31e

[UnitTest][Metal] Parametrize allreduce GPU tests (#15749)
* [UnitTest][Metal] Parametrize allreduce GPU tests

  As a first step toward addressing the Metal codegen errors that required the reversion in #15725, this commit parametrizes the unit tests for `allreduce`. While the tests are parametrized with `@tvm.testing.parametrize_targets("cuda", "metal")`, the automatic `tvm.testing.requires_metal` marker inserted for the metal parametrization will cause them to be skipped if the Metal runtime is unavailable, which includes the current CI.

* Updated the filename and the device used when testing on metal
1 parent 67df20f commit 64ab31e
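
A minimal sketch of the pattern adopted here (the test name `test_example` and its body are hypothetical, not part of this commit): `tvm.testing.parametrize_targets` expands the test once per listed target and, as the commit message notes, attaches the matching `requires_*` marker, so the "metal" case is skipped automatically on machines without a Metal runtime, such as the current CI.

import tvm
import tvm.testing


@tvm.testing.parametrize_targets("cuda", "metal")
def test_example(target, dev):
    # `target` is the target string and `dev` the corresponding device;
    # the body only runs for targets whose runtime is present, otherwise
    # the auto-inserted requires_* marker skips that parametrization.
    assert dev.exist


if __name__ == "__main__":
    tvm.testing.main()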

1 file changed: 57 additions, 52 deletions


tests/python/unittest/test_allreduce_cuda.py renamed to tests/python/unittest/test_allreduce.py

Lines changed: 57 additions & 52 deletions
@@ -46,61 +46,66 @@ def reduce_max(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32)
             B[vi, vj, vk] = T.max(B[vi, vj, vk], A[vi, vj, vk, vl])
 
 
-@tvm.testing.requires_gpu
-@tvm.testing.requires_cuda
-def test_allreduce_cuda():
-    def check_sum(d1: int, d2: int, d3: int):
-        _, _, _d1, _d2, _d3 = reduce.params
-        mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
-        sch = tvm.tir.Schedule(mod)
-        blk = sch.get_block("reduce")
-        i, j, k, l = sch.get_loops(blk)
-        sch.bind(i, "blockIdx.x")
-        sch.bind(j, "threadIdx.z")
-        sch.bind(k, "threadIdx.y")
-        sch.bind(l, "threadIdx.x")
-        f = tvm.build(sch.mod["main"], target="cuda")
-
-        # prepare input and output array
-        a_np = np.random.rand(1, d1, d2, d3).astype("float32")
-        b_np = a_np.sum(axis=-1).astype("float32")
-        a = tvm.nd.array(a_np, tvm.cuda(0))
-        b = tvm.nd.array(np.zeros_like(b_np), tvm.cuda(0))
-
-        # launch kernel
-        f(a, b)
-        tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
-
-    def check_max(d1: int, d2: int, d3: int):
-        _, _, _d1, _d2, _d3 = reduce_max.params
-        mod = reduce_max.specialize({_d1: d1, _d2: d2, _d3: d3})
-        sch = tvm.tir.Schedule(mod)
-        blk = sch.get_block("reduce")
-        i, j, k, l = sch.get_loops(blk)
-        sch.bind(i, "blockIdx.x")
-        sch.bind(j, "threadIdx.z")
-        sch.bind(k, "threadIdx.y")
-        sch.bind(l, "threadIdx.x")
-        f = tvm.build(sch.mod["main"], target="cuda")
-
-        # prepare input and output array
-        a_np = -np.random.rand(1, d1, d2, d3).astype("float32")
-        b_np = a_np.max(axis=-1).astype("float32")
-        a = tvm.nd.array(a_np, tvm.cuda(0))
-        b = tvm.nd.array(np.zeros_like(b_np), tvm.cuda(0))
-
-        # launch kernel
-        f(a, b)
-        tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
-
+def generate_param_sets():
     for d1 in range(1, 5):
         for d2 in range(1, 5):
             for d3 in [2, 4, 8, 12, 16, 32, 48, 64, 100, 128, 201, 256, 512, 1024]:
-                if d1 * d2 * d3 > 1024:
-                    continue
-                check_sum(d1, d2, d3)
-                check_max(d1, d2, d3)
+                if d1 * d2 * d3 < 1024:
+                    yield (d1, d2, d3)
+
+
+dims = tvm.testing.parameter(*generate_param_sets())
+
+
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_allreduce_sum(dims, target, dev):
+    d1, d2, d3 = dims
+    _, _, _d1, _d2, _d3 = reduce.params
+    mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
+    sch = tvm.tir.Schedule(mod)
+    blk = sch.get_block("reduce")
+    i, j, k, l = sch.get_loops(blk)
+    sch.bind(i, "blockIdx.x")
+    sch.bind(j, "threadIdx.z")
+    sch.bind(k, "threadIdx.y")
+    sch.bind(l, "threadIdx.x")
+    f = tvm.build(sch.mod["main"], target=target)
+
+    # prepare input and output array
+    a_np = np.random.rand(1, d1, d2, d3).astype("float32")
+    b_np = a_np.sum(axis=-1).astype("float32")
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(np.zeros_like(b_np), dev)
+
+    # launch kernel
+    f(a, b)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
+
+
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_allreduce_max(dims, target, dev):
+    d1, d2, d3 = dims
+    _, _, _d1, _d2, _d3 = reduce_max.params
+    mod = reduce_max.specialize({_d1: d1, _d2: d2, _d3: d3})
+    sch = tvm.tir.Schedule(mod)
+    blk = sch.get_block("reduce")
+    i, j, k, l = sch.get_loops(blk)
+    sch.bind(i, "blockIdx.x")
+    sch.bind(j, "threadIdx.z")
+    sch.bind(k, "threadIdx.y")
+    sch.bind(l, "threadIdx.x")
+    f = tvm.build(sch.mod["main"], target=target)
+
+    # prepare input and output array
+    a_np = -np.random.rand(1, d1, d2, d3).astype("float32")
+    b_np = a_np.max(axis=-1).astype("float32")
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(np.zeros_like(b_np), dev)
+
+    # launch kernel
+    f(a, b)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
 
 
 if __name__ == "__main__":
-    test_allreduce_cuda()
+    tvm.testing.main()
