From 936b8644f15651f700d6f6a4cf1b0ceae3df8f14 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 26 May 2022 14:13:23 -0700 Subject: [PATCH 1/2] [TIR] Add schedule primitive TransformBlockLayout --- include/tvm/tir/schedule/schedule.h | 11 + python/tvm/tir/schedule/schedule.py | 61 ++++ src/tir/schedule/analysis.h | 10 + src/tir/schedule/analysis/analysis.cc | 27 ++ src/tir/schedule/concrete_schedule.cc | 8 + src/tir/schedule/concrete_schedule.h | 1 + src/tir/schedule/primitive.h | 12 + .../primitive/layout_transformation.cc | 303 ++++++++++++++++++ .../schedule/primitive/loop_transformation.cc | 29 +- src/tir/schedule/schedule.cc | 2 + src/tir/schedule/traced_schedule.cc | 10 + src/tir/schedule/traced_schedule.h | 1 + src/tir/schedule/transform.cc | 31 ++ src/tir/schedule/transform.h | 39 +++ .../test_tir_schedule_transform_layout.py | 113 +++++++ 15 files changed, 632 insertions(+), 26 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 18e15d1670f1..767dcae49dae 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -545,6 +545,17 @@ class ScheduleNode : public runtime::Object { virtual void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) = 0; + /*! + * \brief Apply a transformation represented by IndexMap to block + * \details The block iters and the block body are transformed by the given index_map. + * Outer loops corresponding to each new block iter are regenerated. + * The index_map is required to be bijective affine since we need its inverse mapping. + * \param self The state of the schedule + * \param block_sref The block sref that refers to the block to be transformed + * \param affine_index_map The transformation to apply. + */ + virtual void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) = 0; + /*! * \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read * or write index diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index dc687b1eaef1..6c8326390d50 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2286,6 +2286,67 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> self, block, buffer_index, buffer_index_type_enum, axis_separators ) + @type_checked + def transform_block_layout( + self, + block: BlockRV, + index_map: Union[IndexMap, Callable], + ) -> None: + """Apply a transformation represented by IndexMap to block + + Parameters + ---------- + block_rv : BlockRV + The block to be transformed + + index_map : Union[IndexMap, Callable] + The transformation to apply. + + Examples + -------- + + Before transform_block_layout, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_transform_block_layout( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"] + ) -> None: + for i, j in T.grid(16, 16): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do transform_block_layout: + + .. code-block:: python + + sch = tir.Schedule(before_transform_block_layout) + sch.transform_block_layout(sch.get_block("B"), lambda i, j: (i * 16 + j,)) + print(sch.mod["main"].script()) + + After applying transform_block_layout, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_transform_block_layout( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"] + ) -> None: + for i in range(256): + with T.block("B"): + vi, = T.axis.remap("S", [i]) + B[vi // 16, vi % 16] = A[vi // 16, vi % 16] * 2.0 + """ + if callable(index_map): + index_map = IndexMap.from_func(index_map) + _ffi_api.ScheduleTransformBlockLayout( # type: ignore # pylint: disable=no-member + self, block, index_map + ) + @type_checked def set_axis_separator( self, diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index c9c3d72ae0b5..7d84763ff85c 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -277,6 +277,16 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, std::unordered_set* data_par_vars, std::unordered_set* reduce_vars); +/******** Loop properties ********/ +/*! + * \brief Check the loop starts with zero. + * \param self The schedule state + * \param loop_sref The StmtSRef that points to the loop to be checked + * \param analyzer The arithmetic analyzer + * \throw ScheduleError If the loop doesn't starts with zero. + */ +void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, arith::Analyzer* analyzer); + /******** Block-loop relation ********/ /*! diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 4777ee2657b3..0659b3531894 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -686,6 +686,33 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, return has_block_vars_of_other_types; } +/******** Loop properties ********/ + +void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, arith::Analyzer* analyzer) { + class LoopNotStartWithZeroError : public ScheduleError { + public: + explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The primitive only supports loop starting with 0"; + } + + String DetailRenderTemplate() const final { + return "The loop {0} does not start with 0, which is not supported"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; + }; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + if (!analyzer->CanProve(loop->min == 0)) { + throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); + } +} + /******** Block-loop relation ********/ Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 7b953220f22c..8066d85a8e7d 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -693,6 +693,14 @@ void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_i TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_); } +void ConcreteScheduleNode::TransformBlockLayout(const BlockRV& block_rv, + const IndexMap& index_map) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::TransformBlockLayout(state_, this->GetSRef(block_rv), index_map); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("transform_block_layout", this->error_render_level_); +} + void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const Array& axis_separators) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 9293aa349300..8e83aac2ce82 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -134,6 +134,7 @@ class ConcreteScheduleNode : public ScheduleNode { /******** Schedule: Layout transformation ********/ void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) override; + void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const Array& axis_separators) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index d55b89693421..0926fbe0e220 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -442,6 +442,18 @@ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map); +/*! + * \brief Apply a transformation represented by IndexMap to block + * \details The block iters and the block body are transformed by the given index_map. + * Outer loops corresponding to each new block iter are regenerated. + * The index_map is required to be bijective affine since we need its inverse mapping. + * \param self The state of the schedule + * \param block_sref The block sref that refers to the block to be transformed + * \param affine_index_map The transformation to apply. + */ +TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, + const IndexMap& index_map); + /******** Schedule: Misc ********/ } // namespace tir diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index cf95665ee828..be93207f7a20 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -192,6 +192,268 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ self->Replace(scope_sref, new_scope_block, block_sref_reuse); } +/*! + * \brief Detect the block iter type assoicated with the expression + * + * This function collects block iters in the expression and check if the block iters have the same + * iter type. The detected iter type is the iter type of the block iters in the expression + * if they have the same iter type, otherwise the detected iter type will be kOpaque. + * + * \param expr The expression + * \param block_iter_type_map The mapping from block iter to iter type + * \return The detected block iter type + */ +IterVarType DetectNewBlockIterType( + const PrimExpr& expr, + const std::unordered_map& block_iter_type_map) { + IterVarType result{kOpaque}; + bool found = false; + PostOrderVisit(expr, [&](const ObjectRef& obj) { + if (const VarNode* var = obj.as()) { + auto it = block_iter_type_map.find(var); + if (it != block_iter_type_map.end()) { + if (!found) { + found = true; + result = it->second; + } else if (result != it->second) { + result = kOpaque; + return false; + } + } + } + return true; + }); + return result; +} + +class NotBijectiveAffineIndexMapError : public ScheduleError { + public: + NotBijectiveAffineIndexMapError(IRModule mod, IndexMap index_map) + : mod_(std::move(mod)), index_map_(std::move(index_map)) {} + String FastErrorString() const final { + return "ScheduleError: The index map is not bijective affine."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The index map " << index_map_->ToPythonString() << " is not bijective affine."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + + Array LocationsOfInterest() const final { return {}; } + + private: + IRModule mod_; + IndexMap index_map_; +}; + +class IndexMapNotApplicableToBlockIterError : public ScheduleError { + public: + static void Check(const IRModule mod, const Block& block, const IndexMap& index_map) { + if (index_map->initial_indices.size() != block->iter_vars.size()) { + throw IndexMapNotApplicableToBlockIterError(mod, block, index_map); + } + } + explicit IndexMapNotApplicableToBlockIterError(IRModule mod, Block block, IndexMap index_map) + : mod_(std::move(mod)), block_(std::move(block)), index_map_(std::move(index_map)) {} + + String FastErrorString() const final { + return "ScheduleError: The index map can't be applied to block iters because the number of " + "parameters mismatch."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The index map " << index_map_->ToPythonString() + << " can't be applied to block iters of {0} because the number of parameters mismatch. " + "Expected: " + << index_map_->initial_indices.size() << ", actual: " << block_->iter_vars.size(); + return os.str(); + } + + IRModule mod() const final { return mod_; } + + Array LocationsOfInterest() const final { return {}; } + + private: + IRModule mod_; + Block block_; + IndexMap index_map_; + +}; + +class NotTrivialBindingError : public ScheduleError { + public: + explicit NotTrivialBindingError(IRModule mod, Block block) + : mod_(std::move(mod)), block_(std::move(block)) {} + + static void CheckBlockHasTrivialBinding(const IRModule& mod, const BlockRealize& block_realize, + std::unordered_set outer_loop_vars) { + // Step 2: Check all the binding values are loops vars + for (const PrimExpr& iter_value : block_realize->iter_values) { + const VarNode* loop_var = iter_value.as(); + if (!loop_var || !outer_loop_vars.count(loop_var)) { + throw NotTrivialBindingError(mod, block_realize->block); + } + } + } + + String FastErrorString() const final { + return "ScheduleError: The binding values of the block are not variables of outer loops."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The binding values of the {0} are not variables of outer loops."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; +}; + +class OpaqueNewIterTypeError : public ScheduleError { + public: + explicit OpaqueNewIterTypeError(IRModule mod, Block block, PrimExpr iter_value) + : mod_(std::move(mod)), block_(std::move(block)), iter_value_(std::move(iter_value)) {} + + String FastErrorString() const final { + return "ScheduleError: Cannot detect the new block iter type because it contains more than one " + "type of original iter vars."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "Cannot detect the block iter type for new iter value " << PrettyPrint(iter_value_) + << " in {0} because it contains more than one type of original iter vars."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + PrimExpr iter_value_; +}; + +void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, + const IndexMap& index_map) { + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + const Block& block = GetRef(block_ptr); + arith::Analyzer analyzer; + + // Step 1: Collect outer loops and loop vars + Array loops = GetLoops(block_sref); // outer loops of the block + std::unordered_set loop_vars; // loop vars of the outer loops + for (const StmtSRef& loop_sref : loops) { + CheckLoopStartsWithZero(self, loop_sref, &analyzer); + loop_vars.emplace(loop_sref->StmtAs()->loop_var.get()); + } + + // Step 2: Check the all outer loops have a single child and the block bindings are trivial (all + // binding values are loop vars) + StmtSRef scope_sref{nullptr}; // the scope statement for replacement + if (!loops.empty()) { + scope_sref = loops.front(); + CheckGetSingleChildBlockRealizeOnSRefTree(self, loops.front()); + } else { + scope_sref = block_sref; + } + + BlockRealize block_realize = GetBlockRealize(self, block_sref); + NotTrivialBindingError::CheckBlockHasTrivialBinding(self->mod, block_realize, loop_vars); + + // Step 3: Collect information of block iter vars + Array block_vars; // iter_var->var of each block iter + Map block_iter_dom; // domain of block iter + std::unordered_map block_iter_type; // iter type of block iter + + Array + block_iter_range_array; // array of block iter extents in the same order as block iters + for (const auto& iter_var : block->iter_vars) { + block_vars.push_back(iter_var->var); + block_iter_dom.Set(iter_var->var, iter_var->dom); + block_iter_type[iter_var->var.get()] = iter_var->iter_type; + ICHECK(is_zero(iter_var->dom->min)); + block_iter_range_array.push_back(iter_var->dom->extent); + } + + // Step 4: Apply the IndexMap to block iters. + IndexMapNotApplicableToBlockIterError::Check(self->mod, block, index_map); + Array transformed_block_iters = index_map->MapIndices(block_vars); + Array new_block_iter_range = index_map->MapShape(block_iter_range_array); + + auto iter_map = arith::DetectIterMap( + /*indices=*/transformed_block_iters, /*input_iters=*/block_iter_dom, /*predicate=*/Bool(true), + /*require_bijective=*/true, &analyzer, /*simplify_trivial_iterators=*/true); + if (iter_map.empty()) { + throw NotBijectiveAffineIndexMapError(self->mod, index_map); + } + + // Step 5: Create the new block after transformation. + + // Step 5.1: Create new block iters. After applying the IndexMap f to block iters ax_0, ..., ax_n, + // create block iter each expression in f(ax_0, ..., ax_n). + Array new_block_iters; // new block iters + Array new_block_vars; // iter_var->var of new block iters + for (size_t i = 0; i < index_map->final_indices.size(); ++i) { + Var new_block_var{"bv" + std::to_string(i), DataType::Int(32)}; + new_block_vars.push_back(new_block_var); + IterVarType iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type); + if (iter_type == kOpaque) { + throw OpaqueNewIterTypeError(self->mod, GetRef(block_ptr), transformed_block_iters[i]); + } + new_block_iters.push_back(IterVar(/*dom=*/Range::FromMinExtent(0, new_block_iter_range[i]), + /*var=*/std::move(new_block_var), /*iter_type=*/iter_type)); + } + + // Step 5.2: Update the block body. Use the inverse map f^{-1} to replace the original block iters + // in the body. + + auto inverse_map = arith::InverseAffineIterMap(iter_map, new_block_vars); + // Trivial block iters will be simplified in DetectIterMap, they should be mapped to constant zero. + for (const auto& iter_var : block_ptr->iter_vars) { + if (inverse_map.find(iter_var->var) == inverse_map.end()) { + ICHECK(is_one(iter_var->dom->extent)); + inverse_map.Set(iter_var->var, 0); + } + } + + Block new_block = Downcast(Substitute(GetRef(block_ptr), inverse_map)); + new_block.CopyOnWrite()->iter_vars = new_block_iters; + new_block = Downcast(BlockBufferAccessSimplifier::Simplify(new_block, &analyzer)); + + // Step 5.3: Create outer loops for each new block iter. + + // Make new loop vars + Array new_loop_vars; + for (int i = 0; i < static_cast(new_block_iters.size()); ++i) { + new_loop_vars.push_back(Var("ax" + std::to_string(i), DataType::Int(32))); + } + + // Make new block realize + BlockRealizeNode* new_block_realize = block_realize.CopyOnWrite(); + new_block_realize->iter_values = new_loop_vars; + new_block_realize->block = new_block; + + // Generate outer loops + Stmt body = GetRef(new_block_realize); + for (int i = static_cast(new_loop_vars.size()) - 1; i >= 0; --i) { + body = For(Downcast(new_loop_vars[i]), 0, new_block_iter_range[i], ForKind::kSerial, std::move(body)); + } + + // Step 6: Do the actual replacement + self->Replace(scope_sref, body, {{block, new_block}}); +} + class BufferAxisSeparatorMutator : private ReplaceBufferMutator { public: static Block Mutate(const Block& scope_block, const Buffer& old_buffer, Buffer new_buffer, @@ -270,6 +532,7 @@ void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer // Step 4: Replace the scope block with the new block self->Replace(scope_sref, new_scope_block, block_sref_reuse); } + /******** InstructionKind Registration ********/ struct TransformLayoutTraits : public UnpackedInstTraits { @@ -324,6 +587,45 @@ struct TransformLayoutTraits : public UnpackedInstTraits friend struct ::tvm::tir::UnpackedInstTraits; }; +struct TransformBlockLayoutTraits : public UnpackedInstTraits { + static constexpr const char* kName = "TransformBlockLayout"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap index_map) { + return sch->TransformBlockLayout(block_rv, index_map); + } + + static String UnpackedAsPython(Array outputs, String block_rv, IndexMap index_map) { + PythonAPICall py("transform_block_layout"); + py.Input("block", block_rv); + py.Input("index_map", index_map->ToPythonString()); + return py.Str(); + } + + public: + static ObjectRef AttrsAsJSON(const Array& attrs) { + Array attrs_record; + attrs_record.reserve(kNumAttrs); + attrs_record.push_back(String(::tvm::SaveJSON(attrs[0]))); + return std::move(attrs_record); + } + + static Array AttrsFromJSON(const ObjectRef& attrs_record_) { + Array attrs_record = Downcast>(attrs_record_); + Array attrs; + attrs.push_back(::tvm::LoadJSON(Downcast(attrs_record[0]))); + return attrs; + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + struct SetAxisSeparatorTraits : public UnpackedInstTraits { static constexpr const char* kName = "SetAxisSeparator"; static constexpr bool kIsPure = false; @@ -359,6 +661,7 @@ struct SetAxisSeparatorTraits : public UnpackedInstTraits LocationsOfInterest() const final { return {loop_}; } - - IRModule mod_; - For loop_; -}; - class NotSingleInferFactorError : public ScheduleError { public: explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} @@ -407,10 +388,8 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, } // Currently, loops not starting with 0 are not supported arith::Analyzer analyzer; - if (!analyzer.CanProve(loop->min == 0)) { - throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); - } - // Step 2. Replace all occurrences of the original loop var with new variables + CheckLoopStartsWithZero(self, loop_sref, &analyzer); + int n = factors.size(); PrimExpr substitute_value = 0; std::vector new_loop_vars; @@ -482,9 +461,7 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { } outer_loop_sref = sref; outer_loop = loop; - if (!analyzer.CanProve(loop->min == 0)) { - throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); - } + CheckLoopStartsWithZero(self, sref, &analyzer); const VarNode* used_var = nullptr; auto f_contain = [&outer_loop_vars, &used_var](const VarNode* var) { if (outer_loop_vars.count(var)) { diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 8dc0c52111cc..fb884ce77f7b 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -233,6 +233,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") return self->TransformLayout(block_rv, buffer_index, static_cast(buffer_index_type), index_map); }); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout") + .set_body_method(&ScheduleNode::TransformBlockLayout); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, const Array& axis_separators) { diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 865b6f378468..8156480a4516 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -442,6 +442,16 @@ void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_ind /*outputs=*/{})); } +void TracedScheduleNode::TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) { + ConcreteScheduleNode::TransformBlockLayout(block_rv, index_map); + static const InstructionKind& kind = InstructionKind::Get("TransformBlockLayout"); + trace_->Append( + /*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{index_map}, + /*outputs=*/{})); +} + void TracedScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const Array& axis_separators) { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 12c076d886cd..d1860be9512d 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -98,6 +98,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { /******** Schedule: Layout transformation ********/ void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) override; + void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const Array& axis_separators) final; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 6c4f3e1b7af0..79802ecd65db 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -280,5 +280,36 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin); +/******** BlockBufferAccessSimplifier ********/ +void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array* old_access_regions) { + auto fmutate = [this](const BufferRegion& buffer_region) { + std::vector new_buffer_region; + for (const auto& range : buffer_region->region) { + new_buffer_region.push_back(Range::FromMinExtent(analyzer_->Simplify(range->min), + analyzer_->Simplify(range->extent))); + } + return BufferRegion(buffer_region->buffer, new_buffer_region); + }; + (*old_access_regions).MutateByApply(fmutate); +} + +Stmt BlockBufferAccessSimplifier::VisitStmt_(const BlockNode* op) { + Block block = Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); + auto* n = block.CopyOnWrite(); + SimplifyAccessRegion(&n->reads); + SimplifyAccessRegion(&n->writes); + return std::move(block); +} + +Stmt BlockBufferAccessSimplifier::VisitStmt_(const BufferStoreNode* op) { + auto node = Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); +} + +PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const BufferLoadNode* op) { + auto node = Downcast(arith::IRMutatorWithAnalyzer::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 52e27350d466..192d44d9e9ad 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -26,6 +26,7 @@ #include #include +#include "../../arith/ir_mutator_with_analyzer.h" #include "../ir/functor_common.h" namespace tvm { @@ -172,6 +173,44 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, const String& intrin_name); +/******** Block mutation ********/ + +/*! + * \brief Simplifier for indices of buffer access and block buffer access regions. + */ +class BlockBufferAccessSimplifier : public arith::IRMutatorWithAnalyzer { + public: + /*! + * \brief Simplify indices of buffer access and block buffer access regions in the statement + * \param stmt The statement to be simplified + * \param analyzer The arithmetic analyzer + * \return The simplified statement + */ + static Stmt Simplify(const Stmt& stmt, arith::Analyzer* analyzer) { + BlockBufferAccessSimplifier simplifier(analyzer); + return simplifier(stmt); + } + + private: + explicit BlockBufferAccessSimplifier(arith::Analyzer* analyzer) + : IRMutatorWithAnalyzer(analyzer) {} + + using IRMutatorWithAnalyzer::VisitExpr_; + using IRMutatorWithAnalyzer::VisitStmt_; + + void SimplifyAccessRegion(Array* old_access_regions); + Stmt VisitStmt_(const BlockNode* op) final; + Stmt VisitStmt_(const BufferStoreNode* op) final; + PrimExpr VisitExpr_(const BufferLoadNode* op) final; + + template + Node VisitBufferAccess(Node node) { + node.CopyOnWrite()->indices.MutateByApply( + [this](const PrimExpr& expr) { return analyzer_->Simplify(expr); }); + return node; + } +}; + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 699eaf1236ac..e184bc3f627c 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -91,6 +91,83 @@ def two_elementwise_transformed_output_buffer( C[vi // 16, vj // 16, vi % 16, vj % 16] = B[vi, vj] + 1.0 +@T.prim_func +def elementwise(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]) -> None: + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + +@T.prim_func +def elementwise_transformed(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]) -> None: + for i in range(16384): + with T.block("B"): + vi, = T.axis.remap("S", [i]) + B[vi // 128, vi % 128] = A[vi // 128, vi % 128] * 2.0 + + +@T.prim_func +def conv2d_nhwc( + Input: T.Buffer[(1, 224, 224, 3), "float32"], + Weight: T.Buffer[(7, 7, 3, 64), "float32"], + Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"], +) -> None: + PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + ((((i1_1 >= 3) and (i1_1 < 227)) and (i2_1 >= 3)) and (i2_1 < 227)), + Input[i0_1, (i1_1 - 3), (i2_1 - 3), i3_1], + T.float32(0), + dtype="float32", + ) + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3): + with T.block("conv2d_nhwc"): + n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + with T.init(): + Conv2d_nhwc[n, h, w, co] = T.float32(0) + Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + ( + PadInput[n, ((h * 2) + rh), ((w * 2) + rw), ((T.floordiv(co, 64) * 3) + rc)] + * Weight[rh, rw, rc, co] + ) + + +@T.prim_func +def conv2d_nhwc_transformed( + Input: T.Buffer[(1, 224, 224, 3), "float32"], + Weight: T.Buffer[(7, 7, 3, 64), "float32"], + Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"], +) -> None: + PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1]) + T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + i1_1 >= 3 and i1_1 < 227 and i2_1 >= 3 and i2_1 < 227, + Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1], + T.float32(0), + dtype="float32", + ) + for ax0, ax_1, ax_2 in T.grid(12544, 64, 147): + with T.block("conv2d_nhwc"): + bv0, bv1, bv2 = T.axis.remap("SSR", [ax0, ax_1, ax_2]) + T.reads( + PadInput[0, bv0 // 112 * 2 + bv2 // 21, bv0 % 112 * 2 + bv2 % 21 // 3, bv2 % 3], + Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1], + ) + T.writes(Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1]) + with T.init(): + Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = T.float32(0) + Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = ( + Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] + + PadInput[0, bv0 // 112 * 2 + bv2 // 21, bv0 % 112 * 2 + bv2 % 21 // 3, bv2 % 3] + * Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1] + ) + # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks # fmt: on @@ -218,5 +295,41 @@ def summation_3d_split( tvm.ir.assert_structural_equal(summation_3d_split, sch.mod["main"]) +def test_transform_block_layout_basic(): + sch = tir.Schedule(elementwise, debug_mask="all") + block = sch.get_block("B") + sch.transform_block_layout(block, lambda i, j: (i * 128 + j,)) + tvm.ir.assert_structural_equal(elementwise_transformed, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_transform_block_layout_conv2d_nhwc(): + sch = tir.Schedule(conv2d_nhwc, debug_mask="all") + block = sch.get_block("conv2d_nhwc") + sch.transform_block_layout( + block, + lambda n, h, w, co, rh, rw, rc: (n * 112 * 112 + h * 112 + w, co, rh * 7 * 3 + rw * 3 + rc), + ) + tvm.ir.assert_structural_equal(conv2d_nhwc_transformed, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc) + + +def test_transform_block_layout_fail_non_affine(): + sch = tir.Schedule(elementwise, debug_mask="all") + block = sch.get_block("B") + with pytest.raises(tir.ScheduleError): + sch.transform_block_layout(block, lambda i, j: (i + j,)) + + +def test_transform_block_layout_fail_mixed_iter_type(): + sch = tir.Schedule(conv2d_nhwc, debug_mask="all") + block = sch.get_block("conv2d_nhwc") + with pytest.raises(tir.ScheduleError): + sch.transform_block_layout( + block, + lambda n, h, w, co, rh, rw, rc: (n * 112 * 112 + h * 112 + w, co * 7 + rh, rw * 3 + rc), + ) + + if __name__ == "__main__": tvm.testing.main() From a82836df39dd271cd19cd99abda096653eb9f779 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 27 May 2022 11:14:39 -0700 Subject: [PATCH 2/2] fixup! [TIR] Add schedule primitive TransformBlockLayout Fix doc --- include/tvm/tir/schedule/schedule.h | 5 ++--- python/tvm/tir/schedule/schedule.py | 2 +- src/tir/schedule/analysis.h | 3 ++- src/tir/schedule/analysis/analysis.cc | 8 +++++--- src/tir/schedule/primitive.h | 2 +- src/tir/schedule/primitive/layout_transformation.cc | 9 +++++---- 6 files changed, 16 insertions(+), 13 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 767dcae49dae..48014280a558 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -550,9 +550,8 @@ class ScheduleNode : public runtime::Object { * \details The block iters and the block body are transformed by the given index_map. * Outer loops corresponding to each new block iter are regenerated. * The index_map is required to be bijective affine since we need its inverse mapping. - * \param self The state of the schedule - * \param block_sref The block sref that refers to the block to be transformed - * \param affine_index_map The transformation to apply. + * \param block_rv The block to be transformed + * \param index_map The transformation to apply. */ virtual void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) = 0; diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 6c8326390d50..f86228848b9d 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2296,7 +2296,7 @@ def transform_block_layout( Parameters ---------- - block_rv : BlockRV + block : BlockRV The block to be transformed index_map : Union[IndexMap, Callable] diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 7d84763ff85c..0574cfefadb6 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -285,7 +285,8 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, * \param analyzer The arithmetic analyzer * \throw ScheduleError If the loop doesn't starts with zero. */ -void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, arith::Analyzer* analyzer); +void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, + arith::Analyzer* analyzer); /******** Block-loop relation ********/ diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 0659b3531894..c4719015daa4 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -688,10 +688,12 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, /******** Loop properties ********/ -void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, arith::Analyzer* analyzer) { +void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, + arith::Analyzer* analyzer) { class LoopNotStartWithZeroError : public ScheduleError { - public: - explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + public: + explicit LoopNotStartWithZeroError(IRModule mod, For loop) + : mod_(mod), loop_(std::move(loop)) {} String FastErrorString() const final { return "ScheduleError: The primitive only supports loop starting with 0"; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 0926fbe0e220..50dedf71ff52 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -449,7 +449,7 @@ TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int * The index_map is required to be bijective affine since we need its inverse mapping. * \param self The state of the schedule * \param block_sref The block sref that refers to the block to be transformed - * \param affine_index_map The transformation to apply. + * \param index_map The transformation to apply. */ TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, const IndexMap& index_map); diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index be93207f7a20..6da796fc955f 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -281,7 +281,6 @@ class IndexMapNotApplicableToBlockIterError : public ScheduleError { IRModule mod_; Block block_; IndexMap index_map_; - }; class NotTrivialBindingError : public ScheduleError { @@ -405,7 +404,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, Array new_block_iters; // new block iters Array new_block_vars; // iter_var->var of new block iters for (size_t i = 0; i < index_map->final_indices.size(); ++i) { - Var new_block_var{"bv" + std::to_string(i), DataType::Int(32)}; + Var new_block_var{"v" + std::to_string(i), DataType::Int(32)}; new_block_vars.push_back(new_block_var); IterVarType iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type); if (iter_type == kOpaque) { @@ -419,7 +418,8 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, // in the body. auto inverse_map = arith::InverseAffineIterMap(iter_map, new_block_vars); - // Trivial block iters will be simplified in DetectIterMap, they should be mapped to constant zero. + // Trivial block iters will be simplified in DetectIterMap, they should be mapped to constant + // zero. for (const auto& iter_var : block_ptr->iter_vars) { if (inverse_map.find(iter_var->var) == inverse_map.end()) { ICHECK(is_one(iter_var->dom->extent)); @@ -447,7 +447,8 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, // Generate outer loops Stmt body = GetRef(new_block_realize); for (int i = static_cast(new_loop_vars.size()) - 1; i >= 0; --i) { - body = For(Downcast(new_loop_vars[i]), 0, new_block_iter_range[i], ForKind::kSerial, std::move(body)); + body = For(Downcast(new_loop_vars[i]), 0, new_block_iter_range[i], ForKind::kSerial, + std::move(body)); } // Step 6: Do the actual replacement