Skip to content

Commit 3cc3dc8

Browse files
support represent ramp as index slice in tvmscript
1 parent fa834f6 commit 3cc3dc8

File tree

8 files changed

+157
-98
lines changed

8 files changed

+157
-98
lines changed

python/tvm/script/parser.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,14 @@ def transform_SubscriptAssign(self, node):
631631
f"cannot be indexed by {len(indexes)}-dimensional indices.",
632632
node.params[1].span,
633633
)
634+
635+
def __convert_index(x):
636+
if isinstance(x, Slice):
637+
return x.as_index_expr(self.report_error)
638+
return x
639+
634640
# BufferStore
641+
indexes = [__convert_index(x) for x in indexes]
635642
return tvm.tir.BufferStore(
636643
symbol,
637644
tvm.runtime.convert(rhs, span=rhs_span),
@@ -948,11 +955,18 @@ def f():
948955
)
949956

950957
def transform_Slice(self, node):
    """Index slice visitor.

    Transforms the slice bounds and wraps them, together with the literal
    step and the slice's source span, into a ``Slice``. The step must be a
    constant positive integer; anything else is reported at the step's span.
    """
    start = self.transform(node.start)
    end = self.transform(node.end)

    step_node = node.step
    step_is_positive_int = (
        isinstance(step_node, ast.Constant)
        and isinstance(step_node.value, int)
        and step_node.value > 0
    )
    if not step_is_positive_int:
        self.report_error(
            "Only positive integer step size is supported for slices.", step_node.span
        )

    slice_span = tvm_span_from_synr(node.span)
    return Slice(start, end, step_node.value, slice_span)
956970

957971
def transform_Subscript(self, node):
958972
"""Array access visitor.

python/tvm/script/tir/node.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919

2020
from typing import Optional, Union, List, Callable
2121
import synr
22-
22+
from tvm.arith import Analyzer
2323
from tvm.runtime import ObjectGeneric, convert
24-
from tvm.tir import PrimExpr, Buffer, BufferLoad
25-
from tvm.ir import Span
24+
from tvm.tir import PrimExpr, Buffer, BufferLoad, IntImm, Ramp, BufferRegion
25+
from tvm.ir import Span, Range
2626

2727

2828
class Slice:
@@ -36,24 +36,50 @@ class Slice:
3636
stop : Optional[Union[PrimExpr, int]]
3737
The stop index, None means the Slice is an element-wise index
3838
39+
step : int
40+
The slice step
41+
3942
span : Optional[Span]
4043
The location of the slice in the source.
4144
"""
4245

4346
start: Union[PrimExpr, int]
4447
stop: Optional[Union[PrimExpr, int]]
48+
step: int
4549
span: Optional[Span]
4650

4751
def __init__(
4852
self,
4953
start: Union[PrimExpr, int],
5054
stop: Optional[Union[PrimExpr, int]] = None,
55+
step: int = 1,
5156
span: Optional[Span] = None,
5257
):
5358
self.start = start
5459
self.stop = stop
60+
self.step = step
5561
self.span = span
5662

63+
def as_index_expr(self, report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
    """Helper to create index PrimExpr from slice object

    Parameters
    ----------
    report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
        The error report func

    Returns
    -------
    index : Union[PrimExpr, int]
        ``self.start`` when the slice is a scalar index (``stop`` is None) or
        covers exactly one lane; otherwise ``Ramp(start, step, lanes)`` with
        ``lanes = ceil((stop - start) / step)``.
    """
    if self.stop is None:
        # scalar index
        return self.start
    extent = Analyzer().simplify(self.stop - self.start)
    if not isinstance(extent, (int, IntImm)):
        report_error("Slice's extent should be constant for buffer indices", self.span)
    extent = int(extent)
    if extent < 1:
        # An empty or reversed slice would give a non-positive lane count,
        # which must not be handed to Ramp; report it as a user error instead.
        report_error("Slice's extent should be positive for buffer indices", self.span)
    if self.step < 1:
        report_error("Slice's step should be positive integer", self.span)
    # Round up so a trailing partial stride still contributes one lane.
    lanes = (extent + self.step - 1) // self.step
    if lanes == 1:
        # A single lane is an ordinary scalar access.
        return self.start
    return Ramp(self.start, self.step, lanes, self.span)
82+
5783

5884
class BufferSlice(ObjectGeneric):
5985
"""A generic object for representing general buffer access. Following cases are supported:
@@ -148,13 +174,35 @@ def __str__(self):
148174

149175
def asobject(self) -> BufferLoad:
    """Lower this buffer access into a ``BufferLoad``.

    Every entry of ``self.slices`` is converted with ``as_index_expr``,
    using ``self.report_error`` as the error callback.
    """
    on_error = self.report_error
    load_indices = [item.as_index_expr(on_error) for item in self.slices]
    load = BufferLoad(self.buffer, load_indices, span=self.span)
    return load
157179

180+
def as_buffer_region(self, analyzer: Optional[Analyzer] = None) -> BufferRegion:
    """Construct BufferRegion from BufferSlice

    Parameters
    ----------
    analyzer : Optional[tvm.arith.Analyzer]
        The analyzer for simplifying. If not provided, the method will construct a new one

    Returns
    -------
    buffer_region : BufferRegion
        The constructed BufferRegion.
    """
    # Resolve the default once, rather than re-testing it for every slice.
    if analyzer is None:
        analyzer = Analyzer()

    region: List[Range] = []
    for s in self.slices:
        # Plain Python ints are wrapped as int32 immediates.
        start = s.start if isinstance(s.start, PrimExpr) else IntImm("int32", s.start)
        # An element-wise index (no stop) covers exactly one element.
        extent = IntImm(start.dtype, 1) if s.stop is None else s.stop - s.start
        if isinstance(extent, PrimExpr):
            extent = analyzer.simplify(extent)
        if s.step != 1:
            self.report_error("BufferRegion do not support non-trivial stride", s.span)
        region.append(Range.from_min_extent(start, extent, span=s.span))
    return BufferRegion(self.buffer, region)
205+
158206
def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr:
    """Convert to a ``BufferLoad`` via ``asobject`` and cast it to ``dtype``."""
    buffer_load = self.asobject()
    return buffer_load.astype(dtype, span)
160208

python/tvm/script/tir/scope_handler.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
2727

2828
from .node import BufferSlice
29-
from .utils import buffer_slice_to_region
3029

3130
from ..context_maintainer import ContextMaintainer
3231
from ..registry import register
@@ -327,12 +326,10 @@ def block(name_hint: str = "", span: Optional[Span] = None):
327326

328327
# create block read/write regions
329328
reads: List[BufferRegion] = (
330-
[buffer_slice_to_region(read) for read in block_info.reads]
331-
if block_info.reads
332-
else []
329+
[read.as_buffer_region() for read in block_info.reads] if block_info.reads else []
333330
)
334331
writes: List[BufferRegion] = (
335-
[buffer_slice_to_region(write) for write in block_info.writes]
332+
[write.as_buffer_region() for write in block_info.writes]
336333
if block_info.writes
337334
else []
338335
)

python/tvm/script/tir/special_stmt.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from tvm.tir import IntImm, IterVar, Var
3131

3232
from .node import BufferSlice
33-
from .utils import buffer_slice_to_region
3433

3534
from ..context_maintainer import BlockInfo, ContextMaintainer
3635
from ..registry import register
@@ -168,7 +167,7 @@ def match_buffer(
168167
)
169168
self.context.func_buffer_map[param] = buffer
170169
elif isinstance(param, BufferSlice):
171-
buffer_region = buffer_slice_to_region(param)
170+
buffer_region = param.as_buffer_region()
172171
self.context.current_block_scope().match_buffers.append(
173172
tvm.tir.MatchBufferRegion(buffer, buffer_region)
174173
)

python/tvm/script/tir/utils.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

src/printer/tvmscript_printer.cc

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
265265
Doc PrintRange(const RangeNode* op);
266266
Doc PrintArray(const ArrayNode* op);
267267
Doc PrintBuffer(const BufferNode* op);
268+
Doc PrintBufferIndices(const Array<PrimExpr>& indices);
268269
Doc PrintNonHeaderBufferDeclarations(const Array<Buffer>& aliasing_buffers);
269270
Doc AllocBufferDeclaration(const Buffer& buf);
270271
Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value);
@@ -834,7 +835,7 @@ Doc TVMScriptPrinter::VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_p
834835
if (op->indices.size() == 0) {
835836
doc << Print(op->buffer) << "[()]";
836837
} else {
837-
doc << Print(op->buffer) << Print(op->indices);
838+
doc << Print(op->buffer) << PrintBufferIndices(op->indices);
838839
}
839840
return doc;
840841
}
@@ -1260,7 +1261,7 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) {
12601261
if (op->indices.size() == 0) {
12611262
doc << Print(op->buffer) << "[()] = " << Print(op->value);
12621263
} else {
1263-
doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value);
1264+
doc << Print(op->buffer) << PrintBufferIndices(op->indices) << " = " << Print(op->value);
12641265
}
12651266
return doc;
12661267
}
@@ -1678,6 +1679,30 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
16781679
return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
16791680
}
16801681

1682+
Doc TVMScriptPrinter::PrintBufferIndices(const Array<PrimExpr>& indices) {
  // Emits the subscript list "[i0, i1, ...]". A Ramp index whose stride is an IntImm is
  // written as a python index slice "base:stop" (plus ":stride" when the stride is not 1);
  // every other index goes through the regular Print().
  Doc doc;
  doc << '[';
  const size_t num_indices = indices.size();
  for (size_t pos = 0; pos < num_indices; ++pos) {
    if (pos > 0) {
      doc << ", ";
    }
    PrimExpr index = indices[pos];
    const RampNode* ramp = index.as<RampNode>();
    const IntImmNode* stride_imm = (ramp != nullptr) ? ramp->stride.as<IntImmNode>() : nullptr;
    if (stride_imm == nullptr) {
      // Not a ramp, or a ramp whose stride is not a constant integer.
      doc << Print(index);
      continue;
    }
    // The slice stop is base + lanes * stride.
    PrimExpr stop = ramp->base + ramp->lanes * ramp->stride;
    doc << Print(ramp->base) << ":" << Print(stop);
    if (stride_imm->value != 1) {
      doc << ":" << Print(ramp->stride);
    }
  }
  doc << ']';
  return doc;
}
1705+
16811706
Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(const Array<Buffer>& aliasing_buffers) {
16821707
Doc decls;
16831708
for (const auto& buf_usage : aliasing_buffers) {

tests/python/unittest/test_tvmscript_error_report.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -372,29 +372,6 @@ def test_error_index_type():
372372
check_error(error_bufferslice_index_type, 8)
373373

374374

375-
def error_index_with_stop() -> None:
376-
A = T.alloc_buffer((128, 128), "float32")
377-
for i, j in T.grid(128, 128):
378-
with T.block():
379-
vi, vj = T.axis.remap("SS", [i, j])
380-
A[vi, vj] = A[vi, 1:10] + 1 # error
381-
382-
383-
def error_bufferslice_index_with_stop() -> None:
384-
A = T.alloc_buffer((1,), "int32")
385-
B = T.alloc_buffer((16, 16), "float32")
386-
C = T.alloc_buffer((16, 16), "float32")
387-
for i, j in T.grid(16, 16):
388-
with T.block():
389-
vi, vj = T.axis.remap("SS", [i, j])
390-
C[vi, vj] = B[vi, A[0:1]] # error
391-
392-
393-
def test_error_index_with_stop_slice():
394-
check_error(error_index_with_stop, 6)
395-
check_error(error_bufferslice_index_with_stop, 8)
396-
397-
398375
def special_stmt_except() -> None:
399376
A = T.alloc_buffer("(128, 128)", "float32") # error
400377
T.evaluate(1.0)
@@ -658,5 +635,42 @@ def test_preflattened_buffer_map_offset_factor():
658635
check_error(preflattened_buffer_map_offset_factor_nonint, 3)
659636

660637

638+
# Error fixture: the block's write region is given with strided slices
# (steps 2 and 3); that statement is the one marked "# error" below.
def strided_buffer_region(A: T.handle):
    # do not allow stride in buffer region
    A = T.match_buffer((128, 128), "int32")
    with T.block():
        T.reads([])
        T.writes([A[0:128:2, 0:128:3]])  # error
        T.evaluate(T.call_extern("strided_compute", dtype=""))
654+
655+
656+
def access_reversed_slice(A: T.handle):
657+
# do not allow reversed slice step
658+
A = T.match_buffer((128,), "int32")
659+
A[0:128:-1] = T.broadcast(1, 128)
660+
661+
662+
def access_non_const_slice(A: T.handle):
663+
# do not allow non-constant slice extent
664+
A = T.match_buffer((128,), "int32")
665+
for i in range(4):
666+
T.evaluate(A[0:i:1])
667+
668+
669+
def test_illegal_buffer_slice():
    """Run check_error on each illegal buffer-slice fixture with its expected line."""
    # NOTE(review): every case expects line 3, which is the `T.match_buffer(...)`
    # line of each fixture, while the offending slices sit on later lines (the
    # strided write is marked "# error" three lines further down) -- confirm the
    # errors are raised by the intended statements.
    check_error(strided_buffer_region, 3)
    check_error(access_reversed_slice, 3)
    check_error(access_non_const_slice, 3)
673+
674+
661675
if __name__ == "__main__":
662676
sys.exit(pytest.main([__file__] + sys.argv[1:]))

tests/python/unittest/test_tvmscript_roundtrip.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3272,6 +3272,22 @@ def element_wise(a: T.handle, c: T.handle) -> None:
32723272
return element_wise
32733273

32743274

3275+
def buffer_ramp_access_as_slice_index():
    """Return the ``T.prim_func``-decorated ``buffer_ramp_access`` defined below.

    Its buffer accesses are all written as index slices.
    """

    @T.prim_func
    def buffer_ramp_access(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (128,), "float32")
        B = T.match_buffer(b, (128,), "float32")
        C = T.match_buffer(c, (128,), "float32")
        for i in range(128):
            # slice of extent 1 with an explicit step of 1
            A[i : i + 1 : 1] = i
        for i in range(4):
            # slices of extent 32, once without a step and once with an explicit step of 1
            B[i * 32 : i * 32 + 32] = A[i * 32 : i * 32 + 32 : 1] + T.broadcast(1.0, 32)
        for i in range(4):
            # slices of extent 128 with step 4, combined with a 32-lane broadcast
            C[i : i + 128 : 4] = B[i : i + 128 : 4] + T.broadcast(1.0, 32)

    return buffer_ramp_access
3289+
3290+
32753291
ir_generator = tvm.testing.parameter(
32763292
opt_gemm_normalize,
32773293
opt_gemm_lower,
@@ -3308,6 +3324,7 @@ def element_wise(a: T.handle, c: T.handle) -> None:
33083324
string_annotation_escaping,
33093325
pointer_type,
33103326
buffer_axis_separator,
3327+
buffer_ramp_access_as_slice_index,
33113328
)
33123329

33133330

0 commit comments

Comments
 (0)