|
16 | 16 | * specific language governing permissions and limitations |
17 | 17 | * under the License. |
18 | 18 | */ |
| 19 | +#include <tvm/tir/expr.h> |
| 20 | + |
19 | 21 | #include "../utils.h" |
20 | 22 |
|
21 | 23 | namespace tvm { |
@@ -297,6 +299,93 @@ void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, |
297 | 299 | self->Replace(alloc_site_sref, new_block, block_reuse_map); |
298 | 300 | } |
299 | 301 |
|
| 302 | +/*! |
| 303 | + * \brief A helper mutator which recursively mutates the old buffer's data type, inserts data type |
| 304 | + * conversions, and collecte the block sref reuse information for the following replacement. |
| 305 | + */ |
| 306 | +class DTypeMutator : private ReplaceBufferMutator { |
| 307 | + public: |
| 308 | + /*! |
| 309 | + * \param allocate_site The block where `old_buffer` was allocated. |
| 310 | + * \param old_buffer The old buffer |
| 311 | + * \param target_dtype The data type to be set |
| 312 | + * \param block_sref_reuse The block sref reuse map to be updated |
| 313 | + * \return The new block after the mutation |
| 314 | + */ |
| 315 | + static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, const DataType& dtype, |
| 316 | + Map<Block, Block>* block_sref_reuse) { |
| 317 | + Buffer new_buffer = WithDType(old_buffer, dtype); |
| 318 | + DTypeMutator mutator(old_buffer, new_buffer, dtype, block_sref_reuse); |
| 319 | + Stmt new_block = mutator.VisitStmt(allocate_site); |
| 320 | + return Downcast<Block>(new_block); |
| 321 | + } |
| 322 | + |
| 323 | + private: |
| 324 | + DTypeMutator(const Buffer& old_buffer, Buffer new_buffer, const DataType& dtype, |
| 325 | + Map<Block, Block>* block_sref_reuse) |
| 326 | + : ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse), |
| 327 | + src_dtype_(old_buffer->dtype), |
| 328 | + tgt_dtype_(dtype) {} |
| 329 | + |
| 330 | + MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final { |
| 331 | + auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get()); |
| 332 | + if (it != buffer_var_map_.end()) { |
| 333 | + Buffer new_target_buffer = WithDType(match_buffer->buffer, it->second->dtype); |
| 334 | + buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer; |
| 335 | + return MatchBufferRegion(new_target_buffer, |
| 336 | + BufferRegion(it->second, match_buffer->source->region)); |
| 337 | + } else { |
| 338 | + return match_buffer; |
| 339 | + } |
| 340 | + } |
| 341 | + |
| 342 | + Stmt VisitStmt_(const BufferStoreNode* op) final { |
| 343 | + BufferStore node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
| 344 | + auto it = buffer_var_map_.find(node->buffer->data.get()); |
| 345 | + if (it != buffer_var_map_.end()) { |
| 346 | + node.CopyOnWrite()->buffer = it->second; |
| 347 | + node.CopyOnWrite()->value = Cast(tgt_dtype_, node->value); |
| 348 | + } |
| 349 | + return node; |
| 350 | + } |
| 351 | + |
| 352 | + PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
| 353 | + BufferLoad node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
| 354 | + auto it = buffer_var_map_.find(node->buffer->data.get()); |
| 355 | + if (it != buffer_var_map_.end()) { |
| 356 | + return Cast(src_dtype_, BufferLoad(it->second, node->indices)); |
| 357 | + } |
| 358 | + return node; |
| 359 | + } |
| 360 | + |
| 361 | + DataType src_dtype_, tgt_dtype_; |
| 362 | +}; |
| 363 | + |
| 364 | +void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index, |
| 365 | + const String& dtype) { |
| 366 | + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
| 367 | + Buffer buffer = |
| 368 | + GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, BufferIndexType::kWrite); |
| 369 | + DataType target_dtype(runtime::String2DLDataType(dtype)); |
| 370 | + |
| 371 | + // Step 1. If `dtype` equals the original data type, just return. |
| 372 | + if (buffer->dtype == target_dtype) { |
| 373 | + return; |
| 374 | + } |
| 375 | + |
| 376 | + // Step 2. Get the allocation site of the target buffer. |
| 377 | + StmtSRef alloc_site_sref = |
| 378 | + NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer); |
| 379 | + const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref); |
| 380 | + |
| 381 | + // Step 3. Recursively replace old buffer to a new buffer, where the new buffer has the given |
| 382 | + // dtype, and insert data type conversions. |
| 383 | + Map<Block, Block> block_reuse_map; |
| 384 | + Block new_block = |
| 385 | + DTypeMutator::Mutate(GetRef<Block>(alloc_site), buffer, target_dtype, &block_reuse_map); |
| 386 | + self->Replace(alloc_site_sref, new_block, block_reuse_map); |
| 387 | +} |
| 388 | + |
300 | 389 | /******** InstructionKind Registration ********/ |
301 | 390 |
|
302 | 391 | struct StorageAlignTraits : public UnpackedInstTraits<StorageAlignTraits> { |
@@ -356,8 +445,36 @@ struct SetScopeTraits : public UnpackedInstTraits<SetScopeTraits> { |
356 | 445 | friend struct ::tvm::tir::UnpackedInstTraits; |
357 | 446 | }; |
358 | 447 |
|
| 448 | +struct UnsafeSetDTypeTraits : public UnpackedInstTraits<UnsafeSetDTypeTraits> { |
| 449 | + static constexpr const char* kName = "UnsafeSetDType"; |
| 450 | + static constexpr bool kIsPure = false; |
| 451 | + |
| 452 | + private: |
| 453 | + static constexpr size_t kNumInputs = 1; |
| 454 | + static constexpr size_t kNumAttrs = 2; |
| 455 | + static constexpr size_t kNumDecisions = 0; |
| 456 | + |
| 457 | + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, |
| 458 | + String dtype) { |
| 459 | + return sch->UnsafeSetDType(block_rv, buffer_index->value, dtype); |
| 460 | + } |
| 461 | + |
| 462 | + static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index, |
| 463 | + String dtype) { |
| 464 | + PythonAPICall py("unsafe_set_dtype"); |
| 465 | + py.Input("block", block_rv); |
| 466 | + py.Input("buffer_index", buffer_index); |
| 467 | + py.Input("dtype", dtype); |
| 468 | + return py.Str(); |
| 469 | + } |
| 470 | + |
| 471 | + template <typename> |
| 472 | + friend struct ::tvm::tir::UnpackedInstTraits; |
| 473 | +}; |
| 474 | + |
359 | 475 | TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits); |
360 | 476 | TVM_REGISTER_INST_KIND_TRAITS(SetScopeTraits); |
| 477 | +TVM_REGISTER_INST_KIND_TRAITS(UnsafeSetDTypeTraits); |
361 | 478 |
|
362 | 479 | } // namespace tir |
363 | 480 | } // namespace tvm |
0 commit comments