@@ -46,61 +46,73 @@ def reduce_max(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32)
             B[vi, vj, vk] = T.max(B[vi, vj, vk], A[vi, vj, vk, vl])


-@tvm.testing.requires_gpu
-@tvm.testing.requires_cuda
-def test_allreduce_cuda():
-    def check_sum(d1: int, d2: int, d3: int):
-        _, _, _d1, _d2, _d3 = reduce.params
-        mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
-        sch = tvm.tir.Schedule(mod)
-        blk = sch.get_block("reduce")
-        i, j, k, l = sch.get_loops(blk)
-        sch.bind(i, "blockIdx.x")
-        sch.bind(j, "threadIdx.z")
-        sch.bind(k, "threadIdx.y")
-        sch.bind(l, "threadIdx.x")
-        f = tvm.build(sch.mod["main"], target="cuda")
-
-        # prepare input and output array
-        a_np = np.random.rand(1, d1, d2, d3).astype("float32")
-        b_np = a_np.sum(axis=-1).astype("float32")
-        a = tvm.nd.array(a_np, tvm.cuda(0))
-        b = tvm.nd.array(np.zeros_like(b_np), tvm.cuda(0))
-
-        # launch kernel
-        f(a, b)
-        tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
-
-    def check_max(d1: int, d2: int, d3: int):
-        _, _, _d1, _d2, _d3 = reduce_max.params
-        mod = reduce_max.specialize({_d1: d1, _d2: d2, _d3: d3})
-        sch = tvm.tir.Schedule(mod)
-        blk = sch.get_block("reduce")
-        i, j, k, l = sch.get_loops(blk)
-        sch.bind(i, "blockIdx.x")
-        sch.bind(j, "threadIdx.z")
-        sch.bind(k, "threadIdx.y")
-        sch.bind(l, "threadIdx.x")
-        f = tvm.build(sch.mod["main"], target="cuda")
-
-        # prepare input and output array
-        a_np = -np.random.rand(1, d1, d2, d3).astype("float32")
-        b_np = a_np.max(axis=-1).astype("float32")
-        a = tvm.nd.array(a_np, tvm.cuda(0))
-        b = tvm.nd.array(np.zeros_like(b_np), tvm.cuda(0))
-
-        # launch kernel
-        f(a, b)
-        tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
-
+def generate_param_sets():
     for d1 in range(1, 5):
         for d2 in range(1, 5):
             for d3 in [2, 4, 8, 12, 16, 32, 48, 64, 100, 128, 201, 256, 512, 1024]:
-                if d1 * d2 * d3 > 1024:
-                    continue
-                check_sum(d1, d2, d3)
-                check_max(d1, d2, d3)
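+                # only yield shapes whose flattened thread count (d1 * d2 * d3)
+                # stays under the 1024-thread block limit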
+                if d1 * d2 * d3 < 1024:
+                    yield (d1, d2, d3)
+
+
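+# register every (d1, d2, d3) shape as a test parameter via tvm.testing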
+dims = tvm.testing.parameter(*generate_param_sets())
+
+
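+# run against both CUDA and Metal; targets that are not enabled are skipped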
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_allreduce_sum(dims, target, dev):
+    d1, d2, d3 = dims
+    _, _, _d1, _d2, _d3 = reduce.params
+    mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
+    sch = tvm.tir.Schedule(mod)
+    blk = sch.get_block("reduce")
+    i, j, k, l = sch.get_loops(blk)
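+    # map the spatial loops to block/thread axes; binding the reduction loop l
+    # to threadIdx.x makes TVM lower it to a cross-thread allreduce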
+    sch.bind(i, "blockIdx.x")
+    sch.bind(j, "threadIdx.z")
+    sch.bind(k, "threadIdx.y")
+    sch.bind(l, "threadIdx.x")
+    f = tvm.build(sch.mod["main"], target=target)
+
+    # prepare input and output array
+    a_np = np.random.rand(1, d1, d2, d3).astype("float32")
+    b_np = a_np.sum(axis=-1).astype("float32")
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(np.zeros_like(b_np), dev)
+
+    # launch kernel
+    f(a, b)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
+
+
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_allreduce_max(dims, target, dev):
+    d1, d2, d3 = dims
+    _, _, _d1, _d2, _d3 = reduce_max.params
+    mod = reduce_max.specialize({_d1: d1, _d2: d2, _d3: d3})
+    sch = tvm.tir.Schedule(mod)
+    blk = sch.get_block("reduce")
+    i, j, k, l = sch.get_loops(blk)
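+    # same binding scheme as the sum test: the reduction loop l becomes a cross-thread allreduce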
+    sch.bind(i, "blockIdx.x")
+    sch.bind(j, "threadIdx.z")
+    sch.bind(k, "threadIdx.y")
+    sch.bind(l, "threadIdx.x")
+    f = tvm.build(sch.mod["main"], target=target)
+
+    # prepare input and output array
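+    # negate the inputs so the expected maxima are negative and the
+    # zero-initialized output buffer cannot match by accident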
+    a_np = -np.random.rand(1, d1, d2, d3).astype("float32")
+    b_np = a_np.max(axis=-1).astype("float32")
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(np.zeros_like(b_np), dev)
+
+    # launch kernel
+    f(a, b)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)


 if __name__ == "__main__":
-    test_allreduce_cuda()
+    tvm.testing.main()