@@ -92,17 +92,11 @@ class TransformLayoutPlanner : private StmtExprVisitor {
9292 std::variant<ProloguePlan, ReplacementPlan, EpiloguePlan, NoPaddingRequired>;
9393
9494 static TransformPlan Plan (Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map,
95+ IndexMap inverse, PrimExpr padding_predicate,
9596 Optional<IndexMap> pad_value) {
9697 ICHECK (!pad_value.defined () || pad_value.value ()->final_indices .size () == 1 )
9798 << " Internal error: Should be caught by ScheduleError checks prior to this point" ;
9899 TransformLayoutPlanner visitor (old_buffer);
99- auto [inverse, padding_predicate] = [&]() {
100- Array<Range> region;
101- for (const auto & dim : old_buffer->shape ) {
102- region.push_back (Range::FromMinExtent (make_zero (dim.dtype ()), dim));
103- }
104- return index_map.NonSurjectiveInverse (region);
105- }();
106100 visitor (block);
107101 return visitor.Finalize (new_buffer, index_map, inverse, padding_predicate, pad_value);
108102 }
@@ -754,18 +748,17 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
754748 * \param old_buffer The target buffer before transformation
755749 * \param new_buffer The new buffer after transformation
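 * \param index_map The transformation applied to the buffer
 * \param opt_inverse The inverse of index_map; must be defined whenever pad_value is defined
 * \param padding_predicate The predicate forwarded to the planner together with the inverse
 * \param pad_value The value to be used for padding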
758751 * \return The new AST rooting at the original parent scope and the map from the old block to the
759752 * new block
760753 */
761- static std::pair<Stmt, Map<Block, Block>> Rewrite (const Block& scope_stmt,
762- const Buffer& old_buffer ,
763- const Buffer& new_buffer ,
764- const IndexMap& index_map,
765- const Optional<IndexMap>& pad_value) {
766- auto plan = pad_value. defined () ? TransformLayoutPlanner::Plan ( scope_stmt, old_buffer,
767- new_buffer, index_map , pad_value)
768- : TransformLayoutPlanner::NoPaddingRequired{} ;
754+ static std::pair<Stmt, Map<Block, Block>> Rewrite (
755+ const Block& scope_stmt, const Buffer& old_buffer, const Buffer& new_buffer ,
756+ const IndexMap& index_map, const Optional<IndexMap>& opt_inverse ,
757+ const PrimExpr& padding_predicate, const Optional< IndexMap>& pad_value) {
758+ auto plan = pad_value. defined () ? TransformLayoutPlanner::Plan (
759+ scope_stmt, old_buffer, new_buffer, index_map ,
760+ opt_inverse. value (), padding_predicate , pad_value)
761+ : TransformLayoutPlanner::NoPaddingRequired () ;
769762
770763 arith::Analyzer analyzer;
771764 TransformLayoutRewriter rewriter (old_buffer, new_buffer, index_map, plan, &analyzer);
@@ -1058,6 +1051,40 @@ class TransformationPaddingExpressionError : public ScheduleError {
10581051 BufferLoad illegal_load_;
10591052};
10601053
1054+ class TransformationIntroducesPaddingError : public ScheduleError {
1055+ public:
1056+ TransformationIntroducesPaddingError (IRModule mod, Buffer buffer, IndexMap index_map,
1057+ PrimExpr padding_predicate)
1058+ : mod_(std::move(mod)),
1059+ buffer_ (std::move(buffer)),
1060+ index_map_(std::move(index_map)),
1061+ padding_predicate_(std::move(padding_predicate)) {}
1062+
1063+ String FastErrorString () const final {
1064+ std::ostringstream ss;
1065+ ss << " ScheduleError: Transformation would introduce padding at " << padding_predicate_ << " ." ;
1066+ return ss.str ();
1067+ }
1068+
1069+ String DetailRenderTemplate () const final {
1070+ auto new_shape = index_map_->MapShape (buffer_->shape );
1071+ std::ostringstream os;
1072+ os << " The transformation " << index_map_ << " applied on buffer " << buffer_->name
1073+ << " of shape " << buffer_->shape << " would result in shape " << new_shape
1074+ << " . However, this would introduce padding wherever " << padding_predicate_ << " is true." ;
1075+ return os.str ();
1076+ }
1077+
1078+ IRModule mod () const final { return mod_; }
1079+ Array<ObjectRef> LocationsOfInterest () const final { return {}; }
1080+
1081+ private:
1082+ IRModule mod_;
1083+ Buffer buffer_;
1084+ IndexMap index_map_;
1085+ PrimExpr padding_predicate_;
1086+ };
1087+
10611088// Make the dtypes of indices in IndexMap be the same as the dtype of the buffer shape, to avoid
10621089// dtype-mismatch issues later.
10631090IndexMap LegalizeIndexMapDType (const IndexMap& index_map, const Array<PrimExpr>& args) {
@@ -1094,7 +1121,7 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array<PrimExpr>&
10941121
10951122void TransformLayout (ScheduleState self, const StmtSRef& block_sref, int buffer_index,
10961123 BufferIndexType buffer_index_type, const IndexMap& index_map_orig,
1097- const Optional<IndexMap>& pad_value) {
1124+ const Optional<IndexMap>& pad_value, bool assume_injective_transform ) {
10981125 // Step 1: Input handling and error checking
10991126 const BlockNode* block_ptr = TVM_SREF_TO_BLOCK (block_sref);
11001127 Buffer old_buffer =
@@ -1122,14 +1149,32 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
11221149 : GetScopeRoot (self, block_sref, /* require_stage_pipeline=*/ false );
11231150 const BlockNode* scope_block = TVM_SREF_TO_BLOCK (scope_sref);
11241151
1152+ Optional<IndexMap> opt_inverse = NullOpt;
1153+ PrimExpr padding_predicate = Bool (true );
1154+ if (!assume_injective_transform) {
1155+ std::tie (opt_inverse, padding_predicate) = [&]() {
1156+ Array<Range> region;
1157+ for (const auto & dim : old_buffer->shape ) {
1158+ region.push_back (Range::FromMinExtent (make_zero (dim.dtype ()), dim));
1159+ }
1160+ return index_map.NonSurjectiveInverse (region);
1161+ }();
1162+ }
1163+
1164+ bool has_padding = !is_zero (padding_predicate);
1165+ if (has_padding && !pad_value.defined ()) {
1166+ throw TransformationIntroducesPaddingError (self->mod , old_buffer, index_map, padding_predicate);
1167+ }
1168+
11251169 // Step 2: Infer the shape of the new buffer
11261170 Buffer new_buffer = old_buffer;
11271171 new_buffer.CopyOnWrite ()->shape = index_map->MapShape (old_buffer->shape );
11281172
11291173 // Step 3: Rewrite BufferLoad/BufferStore access indices, block read/write regions, and block
11301174 // alloc_buffers.
1131- auto [new_stmt, block_sref_reuse] = TransformLayoutRewriter::Rewrite (
1132- GetRef<Block>(scope_block), old_buffer, new_buffer, index_map, pad_value);
1175+ auto [new_stmt, block_sref_reuse] =
1176+ TransformLayoutRewriter::Rewrite (GetRef<Block>(scope_block), old_buffer, new_buffer,
1177+ index_map, opt_inverse, padding_predicate, pad_value);
11331178 Block new_scope_block = Downcast<Block>(new_stmt);
11341179
11351180 // Step 4: Rewrite buffer_map of the PrimFunc if necessary.
@@ -1472,20 +1517,21 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
14721517
14731518 private:
14741519 static constexpr size_t kNumInputs = 2 ;
1475- static constexpr size_t kNumAttrs = 3 ;
1520+ static constexpr size_t kNumAttrs = 4 ;
14761521 static constexpr size_t kNumDecisions = 0 ;
14771522
14781523 static void UnpackedApplyToSchedule (Schedule sch, BlockRV block_rv, IndexMap index_map,
14791524 Integer buffer_index, Integer buffer_index_type,
1480- Optional<IndexMap> pad_value) {
1525+ Optional<IndexMap> pad_value,
1526+ Bool assume_injective_transform) {
14811527 return sch->TransformLayout (block_rv, buffer_index.IntValue (),
14821528 static_cast <BufferIndexType>(buffer_index_type->value ), index_map,
1483- pad_value);
1529+ pad_value, assume_injective_transform. operator bool () );
14841530 }
14851531
14861532 static String UnpackedAsPython (Array<String> outputs, String block_rv, IndexMap index_map,
14871533 Integer buffer_index, Integer buffer_index_type,
1488- Optional<IndexMap> pad_value) {
1534+ Optional<IndexMap> pad_value, Bool assume_injective_transform ) {
14891535 PythonAPICall py (" transform_layout" );
14901536 py.Input (" block" , block_rv);
14911537
@@ -1495,6 +1541,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
14951541 py.Input (" buffer" , os.str ());
14961542 py.Input (" index_map" , index_map->ToPythonString ());
14971543 py.Input (" pad_value" , pad_value ? pad_value.value ()->ToPythonString () : " None" );
1544+ py.Input (" assume_injective_transform" , assume_injective_transform.operator bool ());
14981545
14991546 return py.Str ();
15001547 }
@@ -1510,6 +1557,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
15101557 } else {
15111558 attrs_record.push_back (attrs[2 ]);
15121559 }
1560+ attrs_record.push_back (attrs[3 ]);
15131561 return std::move (attrs_record);
15141562 }
15151563
@@ -1523,6 +1571,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
15231571 } else {
15241572 attrs.push_back (attrs_record[2 ]);
15251573 }
1574+ attrs.push_back (attrs_record[3 ]);
15261575 return attrs;
15271576 }
15281577
0 commit comments