Skip to content

Commit f7e7528

Browse files
committed
[Fix] Buffer slicing using idnex dtype as extent
1 parent fe01c5a commit f7e7528

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

python/tvm/tir/buffer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def offset_of(self, indices):
179179

180180
def __getitem__(self, indices):
181181
from ..arith import Analyzer # pylint: disable=import-outside-toplevel
182-
from .expr import BufferLoad, Ramp # pylint: disable=import-outside-toplevel
182+
from .expr import BufferLoad, Ramp, const # pylint: disable=import-outside-toplevel
183183
from .stmt import BufferRegion # pylint: disable=import-outside-toplevel
184184

185185
if not isinstance(indices, (tuple, list)):
@@ -195,7 +195,11 @@ def __getitem__(self, indices):
195195
stop = self.shape[i] if index.stop is None else index.stop
196196
region.append(Range.from_min_extent(start, analyzer.simplify(stop - start)))
197197
else:
198-
region.append(Range.from_min_extent(index, 1))
198+
region.append(
199+
Range.from_min_extent(
200+
index, const(1, index.dtype) if isinstance(index, PrimExpr) else 1
201+
)
202+
)
199203
return BufferRegion(self, region)
200204
else:
201205
expr_indices = []

tests/python/unittest/test_tvmscript_regression.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy
1818

1919
import tvm
20+
import tvm.testing
2021
from tvm.script import tir as T
2122

2223

@@ -73,9 +74,18 @@ def func_ref():
7374
tvm.ir.assert_structural_equal(test_case, func_ref)
7475

7576

77+
78+
def test_tir_buffer_region_extent_correct_dtype():
79+
@T.prim_func
80+
def func(A: T.Buffer[(T.int64(16), T.int64(1)), "float32"]):
81+
for i in T.grid(T.int64(16)):
82+
with T.block("block"):
83+
vi = T.axis.remap("S", [i])
84+
T.reads(A[vi, T.int64(0) : T.int64(1)])
85+
T.evaluate(0)
86+
87+
assert func.body.block.body.body.block.reads[0].region[0].extent.dtype == "int64"
88+
89+
7690
if __name__ == "__main__":
77-
a = numpy.zeros((10, 10), dtype="int8")
78-
test_multi_element_array_in_outmost_namespace()
79-
test_different_dtype_assignment_to_var()
80-
b = 1
81-
test_var_capturing_order()
91+
tvm.testing.main()

0 commit comments

Comments
 (0)