@@ -46,61 +46,66 @@ def reduce_max(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32)
             B[vi, vj, vk] = T.max(B[vi, vj, vk], A[vi, vj, vk, vl])
 
 
-@tvm.testing.requires_gpu
-@tvm.testing.requires_cuda
-def test_allreduce_cuda():
-    def check_sum(d1: int, d2: int, d3: int):
-        _, _, _d1, _d2, _d3 = reduce.params
-        mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
-        sch = tvm.tir.Schedule(mod)
-        blk = sch.get_block("reduce")
-        i, j, k, l = sch.get_loops(blk)
-        sch.bind(i, "blockIdx.x")
-        sch.bind(j, "threadIdx.z")
-        sch.bind(k, "threadIdx.y")
-        sch.bind(l, "threadIdx.x")
-        f = tvm.build(sch.mod["main"], target="cuda")
-
-        # prepare input and output array
-        a_np = np.random.rand(1, d1, d2, d3).astype("float32")
-        b_np = a_np.sum(axis=-1).astype("float32")
-        a = tvm.nd.array(a_np, tvm.cuda(0))
-        b = tvm.nd.array(np.zeros_like(b_np), tvm.cuda(0))
-
-        # launch kernel
-        f(a, b)
-        tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
-
-    def check_max(d1: int, d2: int, d3: int):
-        _, _, _d1, _d2, _d3 = reduce_max.params
-        mod = reduce_max.specialize({_d1: d1, _d2: d2, _d3: d3})
-        sch = tvm.tir.Schedule(mod)
-        blk = sch.get_block("reduce")
-        i, j, k, l = sch.get_loops(blk)
-        sch.bind(i, "blockIdx.x")
-        sch.bind(j, "threadIdx.z")
-        sch.bind(k, "threadIdx.y")
-        sch.bind(l, "threadIdx.x")
-        f = tvm.build(sch.mod["main"], target="cuda")
-
-        # prepare input and output array
-        a_np = -np.random.rand(1, d1, d2, d3).astype("float32")
-        b_np = a_np.max(axis=-1).astype("float32")
-        a = tvm.nd.array(a_np, tvm.cuda(0))
-        b = tvm.nd.array(np.zeros_like(b_np), tvm.cuda(0))
-
-        # launch kernel
-        f(a, b)
-        tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
-
+def generate_param_sets():
     for d1 in range(1, 5):
         for d2 in range(1, 5):
             for d3 in [2, 4, 8, 12, 16, 32, 48, 64, 100, 128, 201, 256, 512, 1024]:
-                if d1 * d2 * d3 > 1024:
-                    continue
-                check_sum(d1, d2, d3)
-                check_max(d1, d2, d3)
+                if d1 * d2 * d3 < 1024:
+                    yield (d1, d2, d3)
+
+
+dims = tvm.testing.parameter(*generate_param_sets())
+
+
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_allreduce_sum(dims, target, dev):
+    d1, d2, d3 = dims
+    _, _, _d1, _d2, _d3 = reduce.params
+    mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
+    sch = tvm.tir.Schedule(mod)
+    blk = sch.get_block("reduce")
+    i, j, k, l = sch.get_loops(blk)
+    sch.bind(i, "blockIdx.x")
+    sch.bind(j, "threadIdx.z")
+    sch.bind(k, "threadIdx.y")
+    sch.bind(l, "threadIdx.x")
+    f = tvm.build(sch.mod["main"], target=target)
+
+    # prepare input and output array
+    a_np = np.random.rand(1, d1, d2, d3).astype("float32")
+    b_np = a_np.sum(axis=-1).astype("float32")
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(np.zeros_like(b_np), dev)
+
+    # launch kernel
+    f(a, b)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
+
+
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_allreduce_max(dims, target, dev):
+    d1, d2, d3 = dims
+    _, _, _d1, _d2, _d3 = reduce_max.params
+    mod = reduce_max.specialize({_d1: d1, _d2: d2, _d3: d3})
+    sch = tvm.tir.Schedule(mod)
+    blk = sch.get_block("reduce")
+    i, j, k, l = sch.get_loops(blk)
+    sch.bind(i, "blockIdx.x")
+    sch.bind(j, "threadIdx.z")
+    sch.bind(k, "threadIdx.y")
+    sch.bind(l, "threadIdx.x")
+    f = tvm.build(sch.mod["main"], target=target)
+
+    # prepare input and output array
+    a_np = -np.random.rand(1, d1, d2, d3).astype("float32")
+    b_np = a_np.max(axis=-1).astype("float32")
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(np.zeros_like(b_np), dev)
+
+    # launch kernel
+    f(a, b)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
 
 
 if __name__ == "__main__":
-    test_allreduce_cuda()
+    tvm.testing.main()
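
For context when reading the hunk: the `reduce` and `reduce_max` prim funcs being specialized are defined earlier in the file and are not shown here. Below is a minimal self-contained sketch of the flow the new tests exercise, assuming a sum-reduction prim func with the same signature as `reduce`; the definition is a plausible reconstruction for illustration, not the file's actual code.

import numpy as np
import tvm
import tvm.testing
from tvm.script import tir as T


# Assumed shape of `reduce`: an all-reduce over the last axis of a
# [1, d1, d2, d3] buffer into a [1, d1, d2] buffer.
@T.prim_func
def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None:
    A = T.match_buffer(a, [1, d1, d2, d3])
    B = T.match_buffer(b, [1, d1, d2])
    for i, j, k, l in T.grid(1, d1, d2, d3):
        with T.block("reduce"):
            vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
            with T.init():
                B[vi, vj, vk] = T.float32(0)
            B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]


# Specialize the symbolic extents, bind every loop to a GPU thread axis
# (binding the reduction loop `l` to threadIdx.x is what exercises the
# cross-thread allreduce lowering), build, and check against numpy.
d1, d2, d3 = 2, 3, 32
_, _, _d1, _d2, _d3 = reduce.params
mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
sch = tvm.tir.Schedule(mod)
i, j, k, l = sch.get_loops(sch.get_block("reduce"))
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.z")
sch.bind(k, "threadIdx.y")
sch.bind(l, "threadIdx.x")
f = tvm.build(sch.mod["main"], target="cuda")

a_np = np.random.rand(1, d1, d2, d3).astype("float32")
a = tvm.nd.array(a_np, tvm.cuda(0))
b = tvm.nd.array(np.zeros((1, d1, d2), "float32"), tvm.cuda(0))
f(a, b)
tvm.testing.assert_allclose(b.numpy(), a_np.sum(axis=-1), rtol=1e-6, atol=1e-6)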
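
On the test-collection mechanics this diff relies on: `tvm.testing.parameter(...)` registers a pytest fixture that runs a test once per value, and `@tvm.testing.parametrize_targets(...)` additionally multiplies the test by each listed target, supplying `target` and `dev` and skipping targets that are not enabled or whose device is absent. A toy sketch of the same pattern, under those assumptions:

import tvm.testing

# Fixture: any test that takes an `n` argument runs once per value.
n = tvm.testing.parameter(8, 64, 256)


@tvm.testing.parametrize_targets("cuda", "metal")
def test_example(n, target, dev):
    # One case per (n, target) pair; `dev` is the device for `target`.
    assert dev.exist


if __name__ == "__main__":
    tvm.testing.main()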