@@ -93,6 +93,15 @@ def _parse_seed(seed: Optional[int]) -> int:
9393 return seed
9494
9595
96+ def _get_block_default_dtype (block : Block ) -> str :
97+ for i in block .iter_vars :
98+ return i .var .dtype
99+ for buffer_region in list (block .reads ) + list (block .writes ):
100+ for dom in buffer_region .region :
101+ return dom .min .dtype
102+ return "int64"
103+
104+
96105@_register_object ("tir.Schedule" )
97106class Schedule (Object ):
98107 """The user-facing schedule class
@@ -1492,7 +1501,10 @@ def after_reindex_cache_read(a: T.handle, b: T.handle) -> None:
14921501 block = self ._normalize_block_arg (block )
14931502
14941503 if callable (index_map ):
1495- index_map = IndexMap .from_func (index_map )
1504+ index_map = IndexMap .from_func (
1505+ index_map ,
1506+ index_dtype = _get_block_default_dtype (self .get (block )),
1507+ )
14961508 return _ffi_api .ScheduleReindexCacheRead ( # type: ignore # pylint: disable=no-member
14971509 self , block , read_buffer_index , storage_scope , index_map
14981510 )
@@ -1589,7 +1601,10 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
15891601 block = self ._normalize_block_arg (block )
15901602
15911603 if callable (index_map ):
1592- index_map = IndexMap .from_func (index_map )
1604+ index_map = IndexMap .from_func (
1605+ index_map ,
1606+ index_dtype = _get_block_default_dtype (self .get (block )),
1607+ )
15931608 return _ffi_api .ScheduleReindexCacheWrite ( # type: ignore # pylint: disable=no-member
15941609 self , block , write_buffer_index , storage_scope , index_map
15951610 )
@@ -3246,14 +3261,22 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
32463261
32473262 ndim = len (buffer_obj .shape )
32483263 if callable (index_map ):
3249- index_map , axis_separators = IndexMap .from_func_with_separators (index_map , ndim = ndim )
3264+ index_map , axis_separators = IndexMap .from_func_with_separators (
3265+ index_map ,
3266+ ndim = ndim ,
3267+ index_dtype = _get_block_default_dtype (self .get (block )),
3268+ )
32503269 else :
32513270 axis_separators = []
32523271
32533272 if pad_value is None :
32543273 pass
32553274 elif callable (pad_value ):
3256- pad_value = IndexMap .from_func (pad_value , ndim = len (index_map .final_indices ))
3275+ pad_value = IndexMap .from_func (
3276+ pad_value ,
3277+ ndim = len (index_map .final_indices ),
3278+ index_dtype = _get_block_default_dtype (self .get (block )),
3279+ )
32573280 elif not isinstance (pad_value , IndexMap ):
32583281 # Explicitly convert python int/float arguments to the
32593282 # buffer's type. If the default `tvm.runtime.convert`
@@ -3264,7 +3287,9 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
32643287 elif "float" in buffer_obj .dtype and isinstance (pad_value , float ):
32653288 pad_value = FloatImm (buffer_obj .dtype , pad_value )
32663289 pad_value = IndexMap .from_func (
3267- lambda * indices : pad_value , ndim = len (index_map .final_indices )
3290+ lambda * indices : pad_value ,
3291+ ndim = len (index_map .final_indices ),
3292+ index_dtype = _get_block_default_dtype (self .get (block )),
32683293 )
32693294
32703295 buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
@@ -3337,7 +3362,10 @@ def after_transform_block_layout(
33373362 """
33383363 block = self ._normalize_block_arg (block )
33393364 if callable (index_map ):
3340- index_map = IndexMap .from_func (index_map )
3365+ index_map = IndexMap .from_func (
3366+ index_map ,
3367+ index_dtype = _get_block_default_dtype (self .get (block )),
3368+ )
33413369 _ffi_api .ScheduleTransformBlockLayout ( # type: ignore # pylint: disable=no-member
33423370 self , block , index_map
33433371 )
0 commit comments