Skip to content

Commit b5f1dab

Browse files
d-smirnovGiuseppe Rossini
andauthored
[TIR] Tir constants integration into compilation pipeline (#8509)
* [TIR] Introduce tir.allocate_const to TIR This PR is adding non-scalar constant representation in TIR. This is used to express constants (i.e., parameters) in the TIR instead of bypassing the TIR as it's done until now. Change-Id: Id3afc4d7197260cb43ecde60f05ccbce3fc42430 Co-authored-by: Giuseppe Rossini <[email protected]> Change-Id: Id4a09a637c9c1fd7d49989c6c10f474a78569e18 * [TIR] Integrate tir constant nodes in compilation pipeline This PR integrates tir.allocate_const to the compilation pipeline to support --link-params. Change-Id: Ic8d0cb75d596299fcae7078b304598afbf0c5494 Co-authored-by: Giuseppe Rossini <[email protected]> Change-Id: Id98cc682bbfacfe75c4d8b260fd41658f1f196b2 * [TIR] tir.const extraction This commit tries to implement an amendment to tir.constant RFC with centralized storage of constant data within the IRModule Please note that data and irmod_storage_idx are not mutual exclisive further more the irmod_storage_idx is valid only immediatly after prim func addition to the mod or after update within the mod. If prim func is out of the the module scope then the index become meangless. irmod_storage_idx also is not used in calculation of hash function of the tir.constant node. Change-Id: I40742ed580468b0252ea3fec02184cba65e20871 * unit test fixed Change-Id: Ied2186554d4cbad44b2346216c8be92449e55732 * cmsis-nn codegen fix Now handled case when params of the functions came as constants Change-Id: I5874e182e34ef94e23048eaf3c61b01a56d91131 * Fixes for unittests Change-Id: I5b82ee3f80337155706b5470973f494a301b5d90 * Rebasing tests fixes Change-Id: I94ac87907081bab53c1dd1ab2db106ae057b4b19 * Linter: added method param description Change-Id: I2f8c4c8d244b74c794abaa6079c46cc593ffcbdb * Printing removal fix This patch removes forgotten print in fuse_ops Change-Id: I4bb5934f3b4cd5fde19d36a8e3319aae136bce8a * Bugfix Fixed concurrent map update bug here Change-Id: Ifec3bf5030086d9079b9e493096f17dfd82297ec * Reworked logic for not to introduce empty constant list to modue attrs Change-Id: I082c85b3b4b70c218f0d714f5613ef6e178bd020 * Added support for tir builtin::tvm_access_ptr This fixed unit tests for tests/python/integration/test_arm_mprofile_dsp.py Change-Id: I10919f301ef9ddc3fd87f0e1a8414e9a52fc7938 * Unit test fix Fixes unit tests in torch frontend Change-Id: I6c179834f93dd202605d1ce5a7f07d987b9dc469 * Addressed requested changes Addressed changes requested upstream Change-Id: I741e52b89eb285732c23b1ac7ff277e757a088c3 * Namespace usage changed to conform earlier C++ standard Change-Id: I1b29238cfe2a6bedb525f4f823a3a540f631d836 * Bugfix Change-Id: I57a44b714b307278a243817ec2864e53ad31366b * updated IRModuleNode::ExtractPrimFuncConstants Updated IRModuleNode::ExtractPrimFuncConstants as per request upstream. Change-Id: I35db0145fb5827efd0445ce665d0c99465274016 * Minor changes typo fixd renamed ExtractPrimFuncConstants to ExtractConstants removed getters/setters from FuseMutator and added parametrized constructor Change-Id: Ib2326805781779b88c963a8642ff683c8755956e * Moved LinkedParam/LinkedParamNode Moved LinkedParam/LinkedParamNode from tvm::tir namespace to tvm namespace Change-Id: Ie3f0303bd4f7890c6d680268c91f2051977bc7f4 * Addressed upstream comments Changed BindParams argument to Array<NDArray> Removed 'name' argument from te.const Switched to in-depth comparision of NDArrays in constant de-duplication Removed extra final comma from NDArrayToTIR Changed return type of ConstantAllocationSize to int64_t Made link_param a tvm.testing.parameter for test_fuse_take and test_fuse_gather_nd Change-Id: I4285099cc63756aa5ebe91a5bd207d4135499b41 * Removed unnecessary forward declaration +linter Change-Id: I2a6c0d1f97773aeb1ae3f458da252a22079ccdb1 * Constant extractor now is a separate pass Change-Id: Ia4adca9d3315b26fbdc006ef7c115900c081e303 * Added forgotten file + unit test fix Change-Id: Ice305f4fefd13fe95e97574e6d63ffeb664621df * Changed to IRModule pass Refactored ExtractPrimFuncConstants to IRModule pass. deDup -> DeDup Refactored logic of Applicator supplementary class Change-Id: I6c120d175eb6790ba90f176c4f856bde8f0c7c94 * bugfix after rebasing Change-Id: Ie3ee6ea2479476a30f486baef74f20070f117942 * -v -> -vv to have more debug information Change-Id: I12c63731663b9c9ea574b9ed5cb17311ba3cf701 Co-authored-by: Giuseppe Rossini <[email protected]>
1 parent 5956125 commit b5f1dab

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+1221
-338
lines changed

include/tvm/ir/module.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,44 @@
4040
#include <vector>
4141

4242
namespace tvm {
43+
/*!
44+
* \brief Describes one parameter that should be linked into the generated module.
45+
*
46+
* When parameters are to be linked in with generated code (i.e. on target_host-compatible
47+
* backends), Relay attaches instances of this object to a global TIR function. Code-generators
48+
* use the information contained in this node to include the parameter data in the generated
49+
* module.
50+
*/
51+
class LinkedParamNode : public Object {
52+
public:
53+
/*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
54+
int64_t id;
55+
56+
/*! \brief Parameter data which should get linked into the final module. */
57+
::tvm::runtime::NDArray param;
58+
59+
void VisitAttrs(tvm::AttrVisitor* v) {
60+
v->Visit("id", &id);
61+
v->Visit("param", &param);
62+
}
63+
64+
static constexpr const char* _type_key = "tir.LinkedParam";
65+
TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
66+
};
67+
68+
/*!
69+
* \brief Managed reference to LinkedParamNode.
70+
*/
71+
class LinkedParam : public ObjectRef {
72+
public:
73+
TVM_DLL LinkedParam(int64_t id, tvm::runtime::NDArray param);
74+
75+
TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
76+
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
77+
};
78+
4379
class IRModule;
80+
4481
/*!
4582
* \brief IRModule that holds functions and type definitions.
4683
*
@@ -504,6 +541,11 @@ constexpr const char* kRuntime = "runtime";
504541
*/
505542
constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools";
506543

544+
/*
545+
* \brief Module attribute for tir constants
546+
*/
547+
constexpr const char* kConstantsArray = "Constants";
548+
507549
} // namespace attr
508550
} // namespace tvm
509551
#endif // TVM_IR_MODULE_H_

include/tvm/node/structural_hash.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
#include <tvm/node/functor.h>
2727
#include <tvm/runtime/data_type.h>
28+
#include <tvm/runtime/ndarray.h>
2829

2930
#include <functional>
3031
#include <string>
@@ -199,5 +200,13 @@ class SHashReducer {
199200
bool map_free_vars_;
200201
};
201202

203+
class SEqualReducer;
204+
struct NDArrayContainerTrait {
205+
static constexpr const std::nullptr_t VisitAttrs = nullptr;
206+
static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce);
207+
static bool SEqualReduce(const runtime::NDArray::Container* lhs,
208+
const runtime::NDArray::Container* rhs, SEqualReducer equal);
209+
};
210+
202211
} // namespace tvm
203212
#endif // TVM_NODE_STRUCTURAL_HASH_H_

include/tvm/relay/executor.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ class ExecutorNode : public Object {
113113
}
114114

115115
static constexpr const char* _type_key = "Executor";
116+
static constexpr const bool _type_has_method_sequal_reduce = true;
117+
static constexpr const bool _type_has_method_shash_reduce = true;
116118
TVM_DECLARE_FINAL_OBJECT_INFO(ExecutorNode, Object);
117119
};
118120

@@ -122,8 +124,6 @@ class ExecutorNode : public Object {
122124
*/
123125
class Executor : public ObjectRef {
124126
public:
125-
Executor() = default;
126-
127127
/*!
128128
* \brief Create a new Executor object using the registry
129129
* \throws Error if name is not registered
@@ -147,7 +147,8 @@ class Executor : public ObjectRef {
147147
TVM_DLL static Map<String, String> ListExecutorOptions(const String& name);
148148

149149
/*! \brief specify container node */
150-
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Executor, ObjectRef, ExecutorNode);
150+
TVM_DEFINE_OBJECT_REF_METHODS(Executor, ObjectRef, ExecutorNode);
151+
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExecutorNode)
151152

152153
private:
153154
/*!

include/tvm/relay/interpreter.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,12 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
184184
* \param import_set Already imported external modules.
185185
* \param device The device on which all primitives will be executed.
186186
* \param target The compiler target flag for compiling primitives.
187+
* \param attrs Attributes for the expression to be evaluated with
187188
* @return The object representing the result.
188189
*/
189190
ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
190-
std::unordered_set<String> import_set, Device device, Target target);
191+
std::unordered_set<String> import_set, Device device, Target target,
192+
Map<String, ObjectRef> attrs = {});
191193

192194
} // namespace relay
193195
} // namespace tvm

include/tvm/relay/runtime.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ class RuntimeNode : public Object {
105105
}
106106

107107
static constexpr const char* _type_key = "Runtime";
108+
static constexpr const bool _type_has_method_sequal_reduce = true;
109+
static constexpr const bool _type_has_method_shash_reduce = true;
108110
TVM_DECLARE_FINAL_OBJECT_INFO(RuntimeNode, Object);
109111
};
110112

include/tvm/tir/function.h

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -151,42 +151,6 @@ class PrimFunc : public BaseFunc {
151151
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
152152
};
153153

154-
/*!
155-
* \brief Describes one parameter that should be linked into the generated module.
156-
*
157-
* When parameters are to be linked in with generated code (i.e. on target_host-compatible
158-
* backends), Relay attaches instances of this object to a global TIR function. Code-generators
159-
* use the information contained in this node to include the parameter data in the generated
160-
* module.
161-
*/
162-
class LinkedParamNode : public Object {
163-
public:
164-
/*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
165-
int64_t id;
166-
167-
/*! \brief Parameter data which should get linked into the final module. */
168-
::tvm::runtime::NDArray param;
169-
170-
void VisitAttrs(tvm::AttrVisitor* v) {
171-
v->Visit("id", &id);
172-
v->Visit("param", &param);
173-
}
174-
175-
static constexpr const char* _type_key = "tir.LinkedParam";
176-
TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
177-
};
178-
179-
/*!
180-
* \brief Managed reference to LinkedParamNode.
181-
*/
182-
class LinkedParam : public ObjectRef {
183-
public:
184-
TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param);
185-
186-
TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
187-
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
188-
};
189-
190154
/*!
191155
* \brief Tensor intrinsics for tensorization
192156
*/
@@ -239,7 +203,7 @@ class TensorIntrin : public ObjectRef {
239203
TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode)
240204
};
241205

242-
/*!
206+
/*
243207
* \brief Specialize parameters of PrimFunc.
244208
* \param func The PrimFunc to be specialized.
245209
* \param param_map The mapping from function params to the instance.

include/tvm/tir/stmt.h

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,16 +559,18 @@ class AllocateNode : public StmtNode {
559559
* Otherwise return 0.
560560
* \return The result.
561561
*/
562-
int32_t constant_allocation_size() const { return constant_allocation_size(extents); }
562+
int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); }
563563
/*!
564564
* \brief If the buffer size is constant, return the size.
565565
* Otherwise return 0.
566566
* \param extents The extents of the buffer.
567567
* \return The result.
568568
*/
569-
TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>& extents);
569+
TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
570570

571571
static constexpr const char* _type_key = "tir.Allocate";
572+
static constexpr const bool _type_has_method_sequal_reduce = true;
573+
static constexpr const bool _type_has_method_shash_reduce = true;
572574
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
573575
};
574576

@@ -585,6 +587,96 @@ class Allocate : public Stmt {
585587
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
586588
};
587589

590+
/*!
591+
* \brief Allocate a buffer that can be used in body.
592+
*/
593+
class AllocateConstNode : public StmtNode {
594+
public:
595+
/*! \brief The buffer variable. */
596+
Var buffer_var;
597+
/*! \brief The optional data associated to the constant.
598+
*/
599+
Optional<runtime::NDArray> data;
600+
/*! \brief If the PrimFunc containing the Stmt is added to IRModule,
601+
this is an optional index to indicate the index within
602+
"Constants" attribute, that is a Array<NDArray> of IRModule.
603+
*/
604+
Optional<Integer> irmod_storage_idx;
605+
/*! \brief The type of the buffer. */
606+
DataType dtype;
607+
/*! \brief The extents of the buffer. */
608+
Array<PrimExpr> extents;
609+
/*! \brief The body to be executed. */
610+
Stmt body;
611+
/*!
612+
* \brief Additional annotations about the allocation.
613+
*
614+
* These annotations can be used as auxiliary hint
615+
* to future transformations.
616+
*/
617+
Map<String, ObjectRef> annotations;
618+
619+
void VisitAttrs(AttrVisitor* v) {
620+
v->Visit("buffer_var", &buffer_var);
621+
v->Visit("data", &data);
622+
v->Visit("irmod_storage_idx", &irmod_storage_idx);
623+
v->Visit("dtype", &dtype);
624+
v->Visit("extents", &extents);
625+
v->Visit("body", &body);
626+
v->Visit("annotations", &annotations);
627+
v->Visit("span", &span);
628+
}
629+
630+
bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const {
631+
return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
632+
equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body) &&
633+
equal(annotations, other->annotations);
634+
}
635+
636+
void SHashReduce(SHashReducer hash_reduce) const {
637+
hash_reduce.DefHash(buffer_var);
638+
hash_reduce(dtype);
639+
hash_reduce(extents);
640+
hash_reduce(body);
641+
hash_reduce(annotations);
642+
hash_reduce(data);
643+
}
644+
645+
/*!
646+
* \brief If the buffer size is constant, return the size.
647+
* Otherwise return 0.
648+
* \return The result.
649+
*/
650+
int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); }
651+
/*!
652+
* \brief If the buffer size is constant, return the size.
653+
* Otherwise return 0.
654+
* \param extents The extents of the buffer.
655+
* \return The result.
656+
*/
657+
TVM_DLL static int64_t ConstantAllocationSize(const Array<PrimExpr>& extents);
658+
659+
static constexpr const char* _type_key = "tir.AllocateConst";
660+
static constexpr const bool _type_has_method_sequal_reduce = true;
661+
static constexpr const bool _type_has_method_shash_reduce = true;
662+
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode);
663+
};
664+
665+
/*!
666+
* \brief Managed reference to AllocateConstNode.
667+
* \sa AllocateConstNode
668+
*/
669+
class AllocateConst : public Stmt {
670+
public:
671+
/* The constructor to create a IRNode with constant data
672+
* depending on the type of ObjectRef, it will either
673+
* create AllocateConstNode with irmod_storage_idx or data
674+
*/
675+
TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
676+
ObjectRef data_or_idx, Stmt body, Span span = Span());
677+
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
678+
};
679+
588680
/*!
589681
* \brief The container of seq statement.
590682
* Represent a sequence of statements.

include/tvm/tir/stmt_functor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
8787
virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
8888
virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
8989
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
90+
virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9091
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9192
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9293
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
@@ -113,6 +114,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
113114
IR_STMT_FUNCTOR_DISPATCH(ForNode);
114115
IR_STMT_FUNCTOR_DISPATCH(WhileNode);
115116
IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
117+
IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode);
116118
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
117119
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
118120
IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
@@ -155,6 +157,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
155157
void VisitStmt_(const ForNode* op) override;
156158
void VisitStmt_(const WhileNode* op) override;
157159
void VisitStmt_(const AllocateNode* op) override;
160+
void VisitStmt_(const AllocateConstNode* op) override;
158161
void VisitStmt_(const StoreNode* op) override;
159162
void VisitStmt_(const BufferStoreNode* op) override;
160163
void VisitStmt_(const BufferRealizeNode* op) override;
@@ -255,6 +258,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
255258
Stmt VisitStmt_(const ForNode* op) override;
256259
Stmt VisitStmt_(const WhileNode* op) override;
257260
Stmt VisitStmt_(const AllocateNode* op) override;
261+
Stmt VisitStmt_(const AllocateConstNode* op) override;
258262
Stmt VisitStmt_(const StoreNode* op) override;
259263
Stmt VisitStmt_(const BufferStoreNode* op) override;
260264
Stmt VisitStmt_(const BufferRealizeNode* op) override;

include/tvm/tir/transform.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <tvm/tir/function.h>
3030

3131
#include <string>
32+
#include <vector>
3233

3334
namespace tvm {
3435
namespace tir {
@@ -601,6 +602,15 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
601602
*/
602603
TVM_DLL Pass InjectSoftwarePipeline();
603604

605+
TVM_DLL Pass BindParams(const Array<runtime::NDArray>& constants);
606+
607+
/*!
608+
* \brief Pass to collect tir non-scalar constants into module's 'Constants' attribute.
609+
*
610+
* \return The pass.
611+
*/
612+
TVM_DLL Pass ExtractPrimFuncConstants();
613+
604614
} // namespace transform
605615
} // namespace tir
606616
} // namespace tvm

0 commit comments

Comments
 (0)