@@ -84,7 +84,7 @@ def test_gpu_softmax_mn():
8484 "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])" ,
8585 "l5, l6 = sch.split(loop=l3, factors=[None, v4])" ,
8686 'sch.bind(loop=l6, thread_axis="threadIdx.x")' ,
87- "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1 )" ,
87+ "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True )" ,
8888 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")' ,
8989 "l7, l8, l9 = sch.get_loops(block=b0)" ,
9090 "l10, l11 = sch.split(loop=l9, factors=[None, v4])" ,
@@ -97,7 +97,7 @@ def test_gpu_softmax_mn():
9797 "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])" ,
9898 "l5, l6 = sch.split(loop=l3, factors=[None, v4])" ,
9999 'sch.bind(loop=l6, thread_axis="threadIdx.x")' ,
100- "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1 )" ,
100+ "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True )" ,
101101 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")' ,
102102 "l7, l8, l9 = sch.get_loops(block=b0)" ,
103103 "l10, l11 = sch.split(loop=l9, factors=[None, v4])" ,
@@ -111,7 +111,7 @@ def test_gpu_softmax_mn():
111111 "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])" ,
112112 "l6, l7 = sch.split(loop=l4, factors=[None, v5])" ,
113113 'sch.bind(loop=l7, thread_axis="threadIdx.x")' ,
114- "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=1 )" ,
114+ "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True )" ,
115115 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")' ,
116116 "l8, l9, l10 = sch.get_loops(block=b1)" ,
117117 "l11, l12 = sch.split(loop=l10, factors=[None, v5])" ,
@@ -121,7 +121,7 @@ def test_gpu_softmax_mn():
121121 "v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])" ,
122122 "l17, l18 = sch.split(loop=l15, factors=[None, v16])" ,
123123 'sch.bind(loop=l18, thread_axis="threadIdx.x")' ,
124- "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=1 )" ,
124+ "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True )" ,
125125 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")' ,
126126 "l19, l20, l21 = sch.get_loops(block=b0)" ,
127127 "l22, l23 = sch.split(loop=l21, factors=[None, v16])" ,
@@ -161,7 +161,7 @@ def test_gpu_softmax_mn_after_inline():
161161 "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])" ,
162162 "l5, l6 = sch.split(loop=l3, factors=[None, v4])" ,
163163 'sch.bind(loop=l6, thread_axis="threadIdx.x")' ,
164- "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1 )" ,
164+ "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True )" ,
165165 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")' ,
166166 "l7, l8, l9 = sch.get_loops(block=b0)" ,
167167 "l10, l11 = sch.split(loop=l9, factors=[None, v4])" ,
@@ -175,14 +175,14 @@ def test_gpu_softmax_mn_after_inline():
175175 "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])" ,
176176 "l6, l7 = sch.split(loop=l4, factors=[None, v5])" ,
177177 'sch.bind(loop=l7, thread_axis="threadIdx.x")' ,
178- "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=1 )" ,
178+ "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True )" ,
179179 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")' ,
180180 "l8, l9, l10 = sch.get_loops(block=b1)" ,
181181 "l11, l12 = sch.split(loop=l10, factors=[None, v5])" ,
182182 'sch.bind(loop=l12, thread_axis="threadIdx.x")' ,
183183 "b13, b14 = sch.get_consumers(block=b0)" ,
184184 "l15, l16, l17, l18 = sch.get_loops(block=b13)" ,
185- "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=1 )" ,
185+ "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True )" ,
186186 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")' ,
187187 "l19, l20, l21 = sch.get_loops(block=b0)" ,
188188 "l22, l23 = sch.split(loop=l21, factors=[None, v5])" ,
@@ -210,7 +210,7 @@ def test_gpu_batch_norm_bmn():
210210 "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])" ,
211211 "l4, l5 = sch.split(loop=l2, factors=[None, v3])" ,
212212 'sch.bind(loop=l5, thread_axis="threadIdx.x")' ,
213- "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=1 )" ,
213+ "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True )" ,
214214 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")' ,
215215 "l6, l7, l8, l9 = sch.get_loops(block=b0)" ,
216216 "l10 = sch.fuse(l8, l9)" ,
0 commit comments