Skip to content

Commit 94d645b

Browse files
YuchenJinmikepapadim
authored andcommitted
Relax IRBuilder (#4)
* Add initial IRBuilder. * Add function output to irbuilder; update based on new AST. * Add call method; clean up bindings * Add test. * Add multifuction test * Move implementation to C++; infer shape and type * update op python hook * More tests and bug fix * Add comments. * Update shape/type inference. * Restructure code; add python type hint. * Cleanup code. * Rebase; address comments. * Add call intrinsic. * nits. * Remove call op. * Migrate scope to C++ using tvm::With. * Address naming. * Add GetBlocks API. * Unify EmitOutput APIs; add more comments. * Remove shape and type deduction code. * Also remove the shape/type attr interface. * Address comments. * Differentiate global and local function. * Reset counter after building func/block. * Rebase. * Remove shape infer builtin. * Return from void function as empty tuple. Co-authored-by: Michalis Papadimitriou <[email protected]>
1 parent 83e5d23 commit 94d645b

File tree

15 files changed

+789
-11
lines changed

15 files changed

+789
-11
lines changed

include/tvm/ir/type.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ class TupleType : public Type {
396396
inline Type VoidType() { return TupleType::Empty(); }
397397

398398
/*!
399-
* \brief Check whether the tyep represents void.
399+
* \brief Check whether the type represents void.
400400
* \return The check result.
401401
*/
402402
inline bool IsVoidType(const Type& type) {

include/tvm/relax/ir_builder.h

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relax/ir_builder.h
22+
* \brief The utility for constructing Relax AST.
23+
*/
24+
#ifndef TVM_RELAX_IR_BUILDER_H_
25+
#define TVM_RELAX_IR_BUILDER_H_
26+
27+
#include <tvm/ir/expr.h>
28+
#include <tvm/relax/expr.h>
29+
#include <tvm/relay/expr.h>
30+
#include <tvm/runtime/object.h>
31+
#include <tvm/runtime/registry.h>
32+
#include <tvm/support/with.h>
33+
34+
namespace tvm {
35+
namespace relax {
36+
37+
using relay::Call;
38+
39+
class IRBuilder;
40+
41+
/*!
42+
* \brief The state of Relax function node being built.
43+
*/
44+
struct RelaxFunction {
45+
/*! \brief The function name. */
46+
Optional<GlobalVar> func_name = NullOpt;
47+
/*! \brief The function parameters. */
48+
Array<Var> params;
49+
/*! \brief The bindings in the function. */
50+
std::vector<Binding> bindings;
51+
/*! \brief The binding blocks in the function. */
52+
std::vector<BindingBlock> binding_blocks;
53+
/*! \brief The return of the function. */
54+
Expr ret = Tuple();
55+
/*! \brief The FunctionNode being built. */
56+
Function func;
57+
};
58+
59+
/*!
60+
* \brief A builder that provides APIs to build Relax AST.
61+
*/
62+
class IRBuilderNode : public Object {
63+
public:
64+
/*!
65+
* \brief Fill the function name and parameters.
66+
*/
67+
void FillFuncNameParam(const Array<Var>& params, const std::string& func_name);
68+
/*!
69+
* \brief Build a function node.
70+
*/
71+
void BuildFunction();
72+
/*!
73+
* \brief Build a binding block.
74+
*/
75+
void BuildBlock();
76+
/*!
77+
* \brief Emit a call node.
78+
* \param call The CallNode to be emitted.
79+
* \return The variable being created and binded to \p call.
80+
*/
81+
Var Emit(const Call& call);
82+
/*!
83+
* \brief Generate an output for the current dataflow block or function.
84+
* \param output The output variable of the block/function.
85+
* \return The variable being binded to \p ouput.
86+
*/
87+
Var EmitOutput(const Expr& output);
88+
/*!
89+
* \brief Get the function being built.
90+
*/
91+
Function Get();
92+
/*!
93+
* \brief Get binding blocks being built.
94+
*/
95+
std::vector<BindingBlock> GetBlocks();
96+
/*!
97+
* \brief Create a IRBuilder.
98+
* \return The created IRBuilder.
99+
*/
100+
TVM_DLL static IRBuilder Create();
101+
102+
void VisitAttrs(AttrVisitor* v) {}
103+
104+
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
105+
static constexpr const char* _type_key = "relax.IRBuilder";
106+
TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, Object);
107+
108+
private:
109+
/*! \brief The state of the function currently being built. */
110+
RelaxFunction func;
111+
/*! \brief A flag tracking if currently inside a dataflow block or not. */
112+
bool is_dataflow = false;
113+
/*! \brief A global variable counter for naming global variables. */
114+
int global_var_counter = 0;
115+
/*! \brief A dataflow variable counter for naming dataflow variables. */
116+
int dataflow_var_counter = 0;
117+
};
118+
119+
class IRBuilder : public ObjectRef {
120+
public:
121+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode);
122+
};
123+
124+
/*! \brief Auxiliary scope for building Relax function node,
125+
* similar to python's with syntax.
126+
*
127+
* \code
128+
* {
129+
* With<FunctionScope> scope(ir_builder);
130+
* // build function node.
131+
* }
132+
*/
133+
class FunctionScopeNode : public Object {
134+
public:
135+
IRBuilder ir_builder;
136+
void VisitAttrs(AttrVisitor* v) { v->Visit("ir_builder", &ir_builder); }
137+
138+
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
139+
static constexpr const char* _type_key = "relax.FunctionScope";
140+
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionScopeNode, Object);
141+
};
142+
143+
class FunctionScope : public ObjectRef {
144+
public:
145+
TVM_DLL FunctionScope(IRBuilder ib);
146+
TVM_DEFINE_OBJECT_REF_METHODS(FunctionScope, ObjectRef, FunctionScopeNode);
147+
class Internal;
148+
149+
private:
150+
// Classes to get the Python `with` like syntax.
151+
friend class Internal;
152+
friend class With<FunctionScope>;
153+
// The entry of a function scope.
154+
TVM_DLL void EnterWithScope();
155+
// The exit of a function scope.
156+
TVM_DLL void ExitWithScope();
157+
};
158+
159+
/*! \brief Auxiliary scope for building Relax dataflow block,
160+
* similar to python's with syntax.
161+
*
162+
* \code
163+
* {
164+
* With<DataflowScope> scope(ir_builder);
165+
* // build dataflow block.
166+
* }
167+
*/
168+
class DataflowScopeNode : public Object {
169+
public:
170+
IRBuilder ir_builder;
171+
void VisitAttrs(AttrVisitor* v) { v->Visit("ir_builder", &ir_builder); }
172+
173+
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
174+
static constexpr const char* _type_key = "relax.DataflowScope";
175+
TVM_DECLARE_FINAL_OBJECT_INFO(DataflowScopeNode, Object);
176+
};
177+
178+
class DataflowScope : public ObjectRef {
179+
public:
180+
TVM_DLL DataflowScope(IRBuilder ib);
181+
TVM_DEFINE_OBJECT_REF_METHODS(DataflowScope, ObjectRef, DataflowScopeNode);
182+
class Internal;
183+
184+
private:
185+
// Classes to get the Python `with` like syntax.
186+
friend class Internal;
187+
friend class With<DataflowScope>;
188+
// The entry of a dataflow scope.
189+
TVM_DLL void EnterWithScope();
190+
// The exit of a dataflow scope.
191+
TVM_DLL void ExitWithScope();
192+
};
193+
194+
} // namespace relax
195+
} // namespace tvm
196+
197+
#endif // TVM_RELAX_IR_BUILDER_H_

include/tvm/relax/type.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@ namespace relax {
3939

4040
class ShapeTypeNode : public TypeNode {
4141
public:
42-
43-
void VisitAttrs(tvm::AttrVisitor* v) {
44-
}
42+
void VisitAttrs(tvm::AttrVisitor* v) {}
4543

4644
bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const {
4745
return true;
@@ -64,10 +62,9 @@ class ShapeType : public Type {
6462
const ShapeTypeNode* get() const {
6563
return operator->();
6664
}
67-
using ContainerType = ShapeTypeNode;
65+
using ContainerType = ShapeTypeNode;
6866
};
6967

70-
7168
class DynTensorTypeNode : public BaseTensorTypeNode {
7269
public:
7370
/*!
@@ -92,6 +89,10 @@ class DynTensorTypeNode : public BaseTensorTypeNode {
9289
hash_reduce(dtype);
9390
}
9491

92+
inline bool IsUnknownRank() const { return rank == -1; }
93+
94+
inline bool IsUnknownDtype() const { return dtype.is_void(); }
95+
9596
static constexpr const char* _type_key = "relax.DynTensorType";
9697
TVM_DECLARE_FINAL_OBJECT_INFO(DynTensorTypeNode, BaseTensorTypeNode);
9798
};

python/tvm/relax/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from . import ty
2121
from . import vm
2222
from . import op
23+
from . import ir_builder
24+
from . import op
2325

2426

2527
# Expr
@@ -56,3 +58,6 @@
5658

5759
# Operator
5860
from .op.base import call_dps
61+
62+
# IRBuilder
63+
IRBuilder = ir_builder.IRBuilder

0 commit comments

Comments
 (0)