@@ -60,8 +60,8 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
6060 your operation.
6161 """
6262
63- batch_size = prod (data .shape [:- 1 ])
64- scan_axis_size = data .shape [- 1 ]
63+ batch_size = cast ( prod (data .shape [:- 1 ]), "int32" )
64+ scan_axis_size = cast ( data .shape [- 1 ], "int32" )
6565
6666 ib = tvm .tir .ir_builder .create ()
6767
@@ -105,7 +105,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
105105 # Up Sweep of exclusive scan
106106 lim = ceil_log2 (scan_axis_size )
107107
108- with ib .for_range (0 , cast (lim , "int64 " ), dtype = "int64 " ) as l2_width :
108+ with ib .for_range (0 , cast (lim , "int32 " ), dtype = "int32 " ) as l2_width :
109109 width = 2 << l2_width
110110
111111 with ib .new_scope ():
@@ -121,9 +121,9 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
121121
122122 by = te .thread_axis ("blockIdx.y" )
123123 ib .scope_attr (by , "thread_extent" , nthread_by )
124- start = ib .allocate ("int64 " , (1 ,), name = "start" , scope = "local" )
125- middle = ib .allocate ("int64 " , (1 ,), name = "middle" , scope = "local" )
126- end = ib .allocate ("int64 " , (1 ,), name = "end" , scope = "local" )
124+ start = ib .allocate ("int32 " , (1 ,), name = "start" , scope = "local" )
125+ middle = ib .allocate ("int32 " , (1 ,), name = "middle" , scope = "local" )
126+ end = ib .allocate ("int32 " , (1 ,), name = "end" , scope = "local" )
127127 start [0 ] = width * tid
128128 with ib .if_scope (start [0 ] < scan_axis_size ):
129129 middle [0 ] = start [0 ] + tvm .tir .indexdiv (width , 2 )
@@ -143,7 +143,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
143143 reduction [bx ] = output [(bx + 1 ) * scan_axis_size - 1 ]
144144 output [(bx + 1 ) * scan_axis_size - 1 ] = cast (identity_value , out_dtype )
145145
146- with ib .for_range (0 , cast (lim , "int64 " ), dtype = "int64 " ) as l2_width :
146+ with ib .for_range (0 , cast (lim , "int32 " ), dtype = "int32 " ) as l2_width :
147147 width = 2 << (lim - l2_width - 1 )
148148
149149 with ib .new_scope ():
@@ -159,9 +159,9 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
159159
160160 by = te .thread_axis ("blockIdx.y" )
161161 ib .scope_attr (by , "thread_extent" , nthread_by )
162- start = ib .allocate ("int64 " , (1 ,), name = "start" , scope = "local" )
163- middle = ib .allocate ("int64 " , (1 ,), name = "middle" , scope = "local" )
164- end = ib .allocate ("int64 " , (1 ,), name = "end" , scope = "local" )
162+ start = ib .allocate ("int32 " , (1 ,), name = "start" , scope = "local" )
163+ middle = ib .allocate ("int32 " , (1 ,), name = "middle" , scope = "local" )
164+ end = ib .allocate ("int32 " , (1 ,), name = "end" , scope = "local" )
165165 tmp = ib .allocate (out_dtype , (1 ,), name = "end" , scope = "local" )
166166 start [0 ] = width * tid
167167 with ib .if_scope (tvm .tir .all (start [0 ] < scan_axis_size )):
@@ -206,8 +206,8 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output, binop=tvm.tir.generi
206206 ex_scan_output = expand_dims (ex_scan_output , axis = 0 )
207207
208208 def ir (data , data_ex_scan , reduction ):
209- batch_size = prod (data .shape [:- 1 ])
210- scan_axis_size = data .shape [- 1 ]
209+ batch_size = cast ( prod (data .shape [:- 1 ]), "int32" )
210+ scan_axis_size = cast ( data .shape [- 1 ], "int32" )
211211
212212 ib = tvm .tir .ir_builder .create ()
213213
0 commit comments