Skip to content

Commit 3ad8d9c

Browse files
committed
Squashed commit: AutoTIR
1 parent 8d76075 commit 3ad8d9c

File tree

98 files changed

+7709
-521
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

98 files changed

+7709
-521
lines changed

include/tvm/arith/int_set.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,14 @@ Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>&
169169
* \return An integer set that can cover all the possible values of e.
170170
*/
171171
IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map);
172+
/*!
173+
* \brief Same as EvalSet, but takes Map<Var, IntSet>
174+
*
175+
* \param e The expression to be evaluated.
176+
* \param dom_map The domain of each variable.
177+
* \return An integer set that can cover all the possible values of e.
178+
*/
179+
IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map);
172180
/*!
173181
* \brief Same as EvalSet, but takes unordered_map
174182
*
@@ -177,6 +185,15 @@ IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map);
177185
* \return An integer set that can cover all the possible values of e.
178186
*/
179187
IntSet EvalSet(PrimExpr e, const std::unordered_map<const tir::VarNode*, IntSet>& dom_map);
188+
/*!
189+
* \brief Same as EvalSet, but takes Array<PrimExpr>
190+
*
191+
* \param exprs The expressions to be evaluated.
192+
* \param dom_map The domain of each variable.
193+
* \return An array of integer sets that can cover all the possible values.
194+
*/
195+
Array<IntSet> EvalSet(const Array<PrimExpr>& exprs, const Map<Var, IntSet>& dom_map);
196+
180197
/*!
181198
* \brief Find an symbolic integer set that contains is union over
182199
* all the possible conditional values in dom_map.

include/tvm/meta_schedule/builder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class BuilderInputNode : public runtime::Object {
3232
IRModule mod;
3333
/*! \brief The target to be built for. */
3434
Target target;
35-
/*! \brief The optional parameters used for build */
35+
/*! \brief Parameters for Relay build module. */
3636
Optional<Map<String, runtime::NDArray>> params;
3737

3838
void VisitAttrs(tvm::AttrVisitor* v) {
@@ -55,7 +55,7 @@ class BuilderInput : public runtime::ObjectRef {
5555
* \brief Constructor of BuilderInput.
5656
* \param mod The IRModule to be built.
5757
* \param target The target to be built for.
58-
* \param params The optional parameters used for build
58+
* \param params Parameters for Relay build module.
5959
*/
6060
TVM_DLL explicit BuilderInput(IRModule mod, Target target,
6161
Optional<Map<String, runtime::NDArray>> params = NullOpt);

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class ScheduleRule : public runtime::ObjectRef {
137137
* \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
138138
* - NullOpt on CPU
139139
* - [blockIdx.x, vthread.x, threadIdx.x] on GPU
140+
* \param use_tensor_core Whether to apply tensor core wmma intrinsic for the computation
140141
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
141142
* \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
142143
* NullOpt means disable vectorization
@@ -146,6 +147,7 @@ class ScheduleRule : public runtime::ObjectRef {
146147
*/
147148
TVM_DLL static ScheduleRule MultiLevelTiling(String structure, //
148149
Optional<Array<String>> tile_binds, //
150+
bool use_tensor_core, //
149151
Optional<Integer> max_innermost_factor, //
150152
Optional<Array<Integer>> vector_load_lens, //
151153
Optional<Map<String, ObjectRef>> reuse_read, //

include/tvm/meta_schedule/tune_context.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ class TuneContextNode : public runtime::Object {
8282
v->Visit("rand_state", &rand_state);
8383
v->Visit("num_threads", &num_threads);
8484
v->Visit("is_stopped", &is_stopped);
85+
v->Visit("builder_results", &builder_results);
86+
v->Visit("runner_futures", &runner_futures);
8587
v->Visit("measure_candidates", &measure_candidates);
8688
}
8789

include/tvm/tir/function.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,65 @@ class LinkedParam : public ObjectRef {
187187
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
188188
};
189189

190+
/*! \brief A mapping from multi-dimensional indices to another set of multi-dimensional indices */
191+
class IndexMapNode : public Object {
192+
public:
193+
/*! \brief The source indices */
194+
Array<Var> src_iters;
195+
/*! \brief The target indices */
196+
Array<PrimExpr> tgt_iters;
197+
198+
void VisitAttrs(tvm::AttrVisitor* v) {
199+
v->Visit("src_iters", &src_iters);
200+
v->Visit("tgt_iters", &tgt_iters);
201+
}
202+
203+
/*!
204+
* \brief Take `inputs` as the source indices and return the corresponding target indices.
205+
* \param inputs The source indices.
206+
* \return The target indices.
207+
*/
208+
Array<PrimExpr> Apply(const Array<PrimExpr>& inputs) const;
209+
210+
/*!
211+
* \brief Map a shape to the output space
212+
* \param shape The shape in the source space
213+
* \return The shape in the target space
214+
*/
215+
Array<PrimExpr> MapShape(const Array<PrimExpr>& shape) const;
216+
217+
/*!
218+
* \brief Convert to string representation in Python.
219+
* \return The stringified lambda expression in Python.
220+
*/
221+
String ToPythonString() const;
222+
223+
static constexpr const char* _type_key = "tir.IndexMap";
224+
TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object);
225+
};
226+
227+
/*!
228+
* \brief Managed reference to IndexMapNode.
229+
* \sa IndexMapNode
230+
*/
231+
class IndexMap : public ObjectRef {
232+
public:
233+
/*!
234+
* \brief Constructor.
235+
* \param src_iters The source indices.
236+
* \param tgt_iters The target indices.
237+
*/
238+
explicit IndexMap(Array<Var> src_iters, Array<PrimExpr> tgt_iters);
239+
/*!
240+
* \brief Create an index map from a packed function
241+
* \param ndim The number of dimensions
242+
* \param func The function to be applied
243+
* \return The created index map
244+
*/
245+
static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func);
246+
TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode);
247+
};
248+
190249
/*!
191250
* \brief Tensor intrinsics for tensorization
192251
*/

include/tvm/tir/schedule/schedule.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,11 @@ class ScheduleNode : public runtime::Object {
355355
*/
356356
virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
357357
const String& storage_scope) = 0;
358+
/******** Schedule: Data movement ********/
359+
virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index,
360+
const String& storage_scope) = 0;
361+
virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index,
362+
const String& storage_scope) = 0;
358363
/******** Schedule: Compute location ********/
359364
/*!
360365
* \brief Move a producer block under the specific loop, and regenerate the
@@ -521,6 +526,21 @@ class ScheduleNode : public runtime::Object {
521526
*/
522527
virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0;
523528

529+
/******** Schedule: Layout transformation ********/
530+
/*!
531+
* \brief Apply a transformation represented by IndexMap to buffer
532+
* \details The indices and the access region to the target buffer is transformed by the given
533+
* index_map. The index_map is used to infer the new shape of the buffer. Buffer must be either
534+
* a function parameter, or allocated in a block (it cannot be a buffer subregion created via
535+
* 'match_buffer').
536+
* \param block_rv The block that accesses the target buffer.
537+
* \param buffer_index The index of the buffer in block's read or write region.
538+
* \param is_write_index Whether the buffer_index is the index of the block's write region.
539+
* \param index_map The transformation to apply.
540+
*/
541+
virtual void TransformLayout(const BlockRV& block_rv, int buffer_index, bool is_write_index,
542+
const IndexMap& index_map) = 0;
543+
524544
/******** Schedule: Misc ********/
525545
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
526546
virtual void EnterPostproc() = 0;

include/tvm/tir/stmt.h

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,7 +1224,7 @@ class BlockRealize : public Stmt {
12241224
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode);
12251225
};
12261226

1227-
/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
1227+
/*! \brief namespace of possible attributes in AttrStmt.attr_key */
12281228
namespace attr {
12291229
// The above attr does not pass to ir stage.
12301230
/*! \brief Mark launching extent of thread, used by device API. */
@@ -1361,12 +1361,6 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_
13611361
*/
13621362
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
13631363

1364-
/*! \brief Mark the stage of a statement in the software pipeline */
1365-
constexpr const char* software_pipeline_stage = "software_pipeline_stage";
1366-
1367-
/*! \brief Mark the order of a statement in the software pipeline */
1368-
constexpr const char* software_pipeline_order = "software_pipeline_order";
1369-
13701364
/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
13711365
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";
13721366

@@ -1400,6 +1394,54 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_expl
14001394
/*! \brief Mark auto-unroll setting on the block. */
14011395
constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
14021396

1397+
/*!
1398+
* \brief Mark that the block need to add predicate for block var bounds during lowering
1399+
*/
1400+
constexpr const char* require_block_var_bound_predicate = "require_bound_predicate";
1401+
1402+
/*!
1403+
* \brief Mark that the block should be further rewritten using tensorization.
1404+
*/
1405+
constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
1406+
1407+
/*! \brief Mark that tensor core is enabled in the PrimExpr */
1408+
constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled";
1409+
1410+
/*!
1411+
* \brief Mark a block as generated by cache_read or cache_write block.
1412+
* 0 means cache_read; 1 means cache_write.
1413+
* \sa meta_schedule_cache_type_read
1414+
* \sa meta_schedule_cache_type_write
1415+
*/
1416+
constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";
1417+
1418+
/*! \sa meta_schedule_cache_type */
1419+
constexpr const int meta_schedule_cache_type_read = 0;
1420+
1421+
/*! \sa meta_schedule_cache_type */
1422+
constexpr const int meta_schedule_cache_type_write = 1;
1423+
1424+
/*! \brief Mark the scope of the software pipeline */
1425+
constexpr const char* software_pipeline_scope = "software_pipeline_scope";
1426+
1427+
/*! \brief Mark the stage of a statement in the software pipeline */
1428+
constexpr const char* software_pipeline_stage = "software_pipeline_stage";
1429+
1430+
/*! \brief Mark the order of a statement in the software pipeline */
1431+
constexpr const char* software_pipeline_order = "software_pipeline_order";
1432+
1433+
/*! \brief Mark the stage of the result of the software pipeline lowering. This is used to specify
1434+
* the behavior of nested software pipelines. Should be a 3-tuple consisting of the stage of the
1435+
* prologue, the body, and the epilogue of the software pipeline.
1436+
*/
1437+
constexpr const char* nested_software_pipeline_stage = "nested_software_pipeline_stage";
1438+
1439+
/*! \brief Mark the stage of the result of the software pipeline lowering. This is used to specify
1440+
* the behavior of nested software pipelines. Should be a 3-tuple consisting of the stage of the
1441+
* prologue, the body, and the epilogue of the software pipeline.
1442+
*/
1443+
constexpr const char* nested_software_pipeline_order = "nested_software_pipeline_order";
1444+
14031445
/*!
14041446
* \brief Check if attr_key is a pragma key extension
14051447
* \param attr_key The attr key to be compared

include/tvm/tir/stmt_functor.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,12 @@ inline T Substitute(T input, const std::unordered_map<const VarNode*, PrimExpr>&
406406
* \param stmt_or_expr The ir to be visited.
407407
* \param fvisit The visitor function to be applied. If fvisit returns false, it won't visit the
408408
* children of the node
409+
* \param visit_init_block Whether or not to visit the init block
410+
* children of the node
409411
*/
410412
TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr,
411-
const std::function<bool(const ObjectRef&)>& fvisit);
413+
const std::function<bool(const ObjectRef&)>& fvisit,
414+
bool visit_init_block = true);
412415
} // namespace tir
413416
} // namespace tvm
414417

include/tvm/tir/transform.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,25 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
601601
*/
602602
TVM_DLL Pass InjectSoftwarePipeline();
603603

604+
/*!
605+
* \brief Automatically do memory optimizations for auto copy blocks
606+
* \return The pass.
607+
*/
608+
TVM_DLL Pass LowerAutoCopy();
609+
610+
/*!
611+
* \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
612+
* \return The pass.
613+
*/
614+
TVM_DLL Pass RenormalizeSplitPattern();
615+
616+
/*!
617+
* \brief Narrow the extents of some loops by checking whether some constraints in the block iter
618+
* bound predicates can be directly applied on the loops.
619+
* \return The pass.
620+
*/
621+
TVM_DLL Pass ApplyBlockBoundPredicate();
622+
604623
} // namespace transform
605624
} // namespace tir
606625
} // namespace tvm

include/tvm/tir/var.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ class IterVarNode : public Object {
255255
IterVarType iter_type;
256256
/*!
257257
* \brief additional tag on the iteration variable,
258-
* set this if this is binded already to a known thread tag.
258+
* set this if this is bound already to a known thread tag.
259259
*/
260260
String thread_tag;
261261
/*!

0 commit comments

Comments
 (0)