Skip to content

Commit f480eb2

Browse files
cyx-6junrushao
authored andcommitted
stmt methods (apache#47)
* `stmt` methods 0 * `stmt` methods 1 * `stmt` methods 2 * `stmt` methods 3 * `stmt` methods 4 * add `T.while` method * `stmt` methods without `with` * `IfFrame`, `ThenFrame`, `ElseFrame` as replacement * apply code review suggestions 0 * apply code review suggestions 1 * apply code review suggestions 2 * apply code review suggestions
1 parent 0cbb858 commit f480eb2

File tree

7 files changed

+828
-9
lines changed

7 files changed

+828
-9
lines changed

python/tvm/script/builder/tir/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,20 @@
4040
prim_func,
4141
)
4242
from .var import Buffer
43+
from .stmt import (
44+
Assert,
45+
let,
46+
allocate,
47+
allocate_const,
48+
launch_thread,
49+
realize,
50+
attr,
51+
while_,
52+
if_,
53+
then_,
54+
else_,
55+
env_thread,
56+
buffer_store,
57+
prefetch,
58+
evaluate,
59+
)
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""TVM Script TIR For Frame"""
18+
import numpy as np
19+
from typing import List, Union
20+
21+
from tvm._ffi import register_object as _register_object
22+
from tvm.tir import Buffer, IterVar, PrimExpr, Var, BufferRegion, Stmt, StringImm
23+
from tvm.ir import Type, Range
24+
from tvm.runtime import ndarray as nd, Object
25+
26+
from . import _ffi_api
27+
from .. import _ffi_api as _base_ffi_api
28+
from .base import TIRFrame
29+
30+
31+
@_register_object("script.builder.tir.AssertFrame")
32+
class AssertFrame(TIRFrame):
33+
...
34+
35+
36+
@_register_object("script.builder.tir.LetFrame")
37+
class LetFrame(TIRFrame):
38+
...
39+
40+
41+
@_register_object("script.builder.tir.AllocateFrame")
42+
class AllocateFrame(TIRFrame):
43+
def __enter__(self) -> Buffer:
44+
_base_ffi_api.FrameEnter(self) # pylint: disable=no-member # type: ignore
45+
return self.buffer
46+
47+
48+
@_register_object("script.builder.tir.AllocateConstFrame")
49+
class AllocateConstFrame(TIRFrame):
50+
def __enter__(self) -> Buffer:
51+
_base_ffi_api.FrameEnter(self) # pylint: disable=no-member # type: ignore
52+
return self.buffer
53+
54+
55+
@_register_object("script.builder.tir.LaunchThreadFrame")
56+
class LaunchThreadFrame(TIRFrame):
57+
...
58+
59+
60+
@_register_object("script.builder.tir.RealizeFrame")
61+
class RealizeFrame(TIRFrame):
62+
...
63+
64+
65+
@_register_object("script.builder.tir.AttrFrame")
66+
class AttrFrame(TIRFrame):
67+
...
68+
69+
70+
@_register_object("script.builder.tir.WhileFrame")
71+
class WhileFrame(TIRFrame):
72+
...
73+
74+
75+
@_register_object("script.builder.tir.IfFrame")
76+
class IfFrame(TIRFrame):
77+
...
78+
79+
80+
@_register_object("script.builder.tir.ThenFrame")
81+
class ThenFrame(TIRFrame):
82+
...
83+
84+
85+
@_register_object("script.builder.tir.ElseFrame")
86+
class ElseFrame(TIRFrame):
87+
...
88+
89+
90+
def Assert(condition: PrimExpr, message: str) -> AssertFrame:
91+
return _ffi_api.AssertFrame(condition, message) # pylint: disable=no-member # type: ignore
92+
93+
94+
def let(var: Var, value: PrimExpr) -> LetFrame:
95+
return _ffi_api.LetFrame(var, value) # pylint: disable=no-member # type: ignore
96+
97+
98+
def allocate(
99+
extents: List[PrimExpr],
100+
dtype: str,
101+
storage_scope: str = "",
102+
condition: PrimExpr = None,
103+
annotations=None,
104+
) -> AllocateFrame:
105+
return _ffi_api.AllocateFrame(
106+
extents, dtype, storage_scope, condition, annotations
107+
) # pylint: disable=no-member # type: ignore
108+
109+
110+
def allocate_const(data: List[PrimExpr], dtype: str, extents: List[PrimExpr]) -> AllocateConstFrame:
111+
return _ffi_api.AllocateConstFrame(
112+
nd.array(np.asarray(data, dtype)), dtype, extents
113+
) # pylint: disable=no-member # type: ignore
114+
115+
116+
def launch_thread(iter_var: IterVar, extent: PrimExpr) -> LaunchThreadFrame:
117+
return _ffi_api.LaunchThreadFrame(iter_var, extent) # pylint: disable=no-member # type: ignore
118+
119+
120+
def realize(
121+
buffer_slice: BufferRegion, storage_scope: str, condition: PrimExpr = True
122+
) -> RealizeFrame:
123+
return _ffi_api.RealizeFrame(
124+
buffer_slice, storage_scope, condition
125+
) # pylint: disable=no-member # type: ignore
126+
127+
128+
def attr(node: Object, attr_key: str, value: Union[PrimExpr, str]) -> AttrFrame:
129+
if isinstance(value, str):
130+
value = StringImm(value)
131+
return _ffi_api.AttrFrame(node, attr_key, value) # pylint: disable=no-member # type: ignore
132+
133+
134+
def while_(condition: PrimExpr) -> WhileFrame:
135+
return _ffi_api.WhileFrame(condition) # pylint: disable=no-member # type: ignore
136+
137+
138+
def if_(condition: PrimExpr) -> IfFrame:
139+
return _ffi_api.IfFrame(condition) # pylint: disable=no-member # type: ignore
140+
141+
142+
def then_() -> ThenFrame:
143+
return _ffi_api.ThenFrame() # pylint: disable=no-member # type: ignore
144+
145+
146+
def else_() -> ElseFrame:
147+
return _ffi_api.ElseFrame() # pylint: disable=no-member # type: ignore
148+
149+
150+
def env_thread(thread_tag: str) -> IterVar:
151+
return _ffi_api.EnvThread(thread_tag) # pylint: disable=no-member # type: ignore
152+
153+
154+
def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[PrimExpr]) -> None:
155+
return _ffi_api.BufferStore(buffer, value, indices) # pylint: disable=no-member # type: ignore
156+
157+
158+
def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None:
159+
return _ffi_api.Prefetch(buffer, indices) # pylint: disable=no-member # type: ignore
160+
161+
162+
def evaluate(value: PrimExpr) -> None:
163+
return _ffi_api.Evaluate(value) # pylint: disable=no-member # type: ignore

src/script/builder/tir/block_frame.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ void BlockFrameNode::ExitWithScope() {
5252
Block block = Block(iter_vars, reads, writes, name, AsStmt(stmts), init, alloc_buffers,
5353
match_buffers, annotations);
5454
if (no_realize) {
55-
CHECK(iter_values.empty()) << "ValueError: Block bindings are not allowed when `no_realize=True`";
55+
CHECK(iter_values.empty())
56+
<< "ValueError: Block bindings are not allowed when `no_realize=True`";
5657
CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`";
5758
AddToParent(block);
5859
} else {
@@ -68,7 +69,7 @@ BlockInitFrame Init() {
6869
void BlockInitFrameNode::EnterWithScope() {
6970
BlockFrame frame = FindBlockFrame("T.init");
7071
if (frame->init.defined()) {
71-
LOG(FATAL) << "Duplicate block init declaration";
72+
LOG(FATAL) << "ValueError: Duplicate block init declaration";
7273
}
7374
TIRFrameNode::EnterWithScope();
7475
}
@@ -92,7 +93,7 @@ BlockFrame FindBlockFrame(const String& method) {
9293
void Where(PrimExpr predicate) {
9394
BlockFrame frame = FindBlockFrame("T.where");
9495
if (frame->predicate.defined()) {
95-
LOG(FATAL) << "Duplicate block predicate declaration, previous one is "
96+
LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is "
9697
<< frame->predicate.value();
9798
}
9899
frame->predicate = predicate;
@@ -102,7 +103,7 @@ void Reads(Array<ObjectRef> buffer_slices) {
102103
using namespace tvm::tir;
103104
BlockFrame frame = FindBlockFrame("T.reads");
104105
if (!frame->reads.empty()) {
105-
LOG(FATAL) << "Duplicate read region declaration, previous one is " << frame->reads;
106+
LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads;
106107
}
107108
for (const ObjectRef& obj : buffer_slices) {
108109
if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
@@ -119,7 +120,8 @@ void Writes(Array<ObjectRef> buffer_slices) {
119120
using namespace tvm::tir;
120121
BlockFrame frame = FindBlockFrame("T.writes");
121122
if (!frame->writes.empty()) {
122-
LOG(FATAL) << "Duplicate write region declaration, previous one is " << frame->writes;
123+
LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is "
124+
<< frame->writes;
123125
}
124126
for (const ObjectRef& obj : buffer_slices) {
125127
if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
@@ -135,7 +137,7 @@ void Writes(Array<ObjectRef> buffer_slices) {
135137
void BlockAttrs(Map<String, ObjectRef> attrs) {
136138
BlockFrame frame = FindBlockFrame("T.block_attr");
137139
if (!frame->annotations.empty()) {
138-
LOG(FATAL) << "Duplicate block annotations, previous one is " << frame->annotations;
140+
LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations;
139141
}
140142
frame->annotations = attrs;
141143
}

src/script/builder/tir/prim_func_frame.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) {
116116
void FuncName(String name) {
117117
PrimFuncFrame frame = FindPrimFuncFrame("T.func_name");
118118
if (frame->name.defined()) {
119-
LOG(FATAL) << "Duplicate prim func name, previous one is " << frame->name.value();
119+
LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is " << frame->name.value();
120120
}
121121
frame->name = name;
122122
}
@@ -125,15 +125,16 @@ void FuncAttrs(Map<String, ObjectRef> attrs) {
125125
using namespace tvm::tir;
126126
PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr");
127127
if (!frame->attrs.empty()) {
128-
LOG(FATAL) << "Duplicate prim func annotations, previous one is " << frame->attrs;
128+
LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is " << frame->attrs;
129129
}
130130
frame->attrs = attrs;
131131
}
132132

133133
tvm::Type FuncRet(tvm::Type ret_type) {
134134
PrimFuncFrame frame = FindPrimFuncFrame("T.ret_type");
135135
if (frame->ret_type.defined()) {
136-
LOG(FATAL) << "Duplicate prim func return type, previous one is " << frame->ret_type.value();
136+
LOG(FATAL) << "ValueError: Duplicate prim func return type, previous one is "
137+
<< frame->ret_type.value();
137138
}
138139
frame->ret_type = ret_type;
139140
return ret_type;

0 commit comments

Comments
 (0)