Skip to content

Commit 630cdd5

Browse files
committed
Partially comment out the AsyncDMALowerer pass. This pass is not friendly to the CUDA backend.
1 parent 08b366d commit 630cdd5

File tree

1 file changed

+32
-31
lines changed

1 file changed

+32
-31
lines changed

src/tir/transforms/lower_async_dma.cc

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -43,37 +43,38 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer {
4343
explicit AsyncDMALowerer(bool dma_bypass_cache, arith::Analyzer* analyzer)
4444
: IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {}
4545

46-
Stmt VisitStmt_(const ForNode* loop) final {
47-
// if for loop is not within async_commit_queue_scope
48-
if (!async_queue_id_.has_value()) {
49-
return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
50-
}
51-
52-
// if for loop is not a memcpy of a contiguous region
53-
std::optional<tvm::tir::MemCpyDetails> mem_copy = IdentifyMemCpy(GetRef<For>(loop), analyzer_);
54-
if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 ||
55-
mem_copy->source->region.size() != 1) {
56-
LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access";
57-
}
58-
59-
// now that we are about to perform the `copy` transform
60-
// save queue ID for inspection in `wait` transform
61-
// and, increment the number of DMA copies in the group
62-
queue_ids_.insert(async_queue_id_.value());
63-
dmas_in_group_++;
64-
65-
tvm::PrimExpr src_min = mem_copy->source->region[0]->min;
66-
tvm::PrimExpr dst_min = mem_copy->dest->region[0]->min;
67-
tvm::PrimExpr dst_extent = mem_copy->dest->region[0]->extent;
68-
69-
auto src = BufferLoad(mem_copy->source->buffer, {src_min});
70-
auto dst = BufferLoad(mem_copy->dest->buffer, {dst_min});
71-
return Evaluate(
72-
Call(DataType::Int(32), builtin::dma_copy(),
73-
{async_queue_id_.value(), Call(DataType::Handle(), builtin::address_of(), {dst}),
74-
Call(DataType::Handle(), builtin::address_of(), {src}),
75-
dst_extent * src->dtype.bytes(), dma_bypass_cache_}));
76-
}
46+
// TODO: split the async DMA lowering support between the CUDA and Hexagon backends
47+
// Stmt VisitStmt_(const ForNode* loop) final {
48+
// // if for loop is not within async_commit_queue_scope
49+
// if (!async_queue_id_.has_value()) {
50+
// return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
51+
// }
52+
53+
// // if for loop is not a memcpy of a contiguous region
54+
// std::optional<tvm::tir::MemCpyDetails> mem_copy = IdentifyMemCpy(GetRef<For>(loop), analyzer_);
55+
// if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 ||
56+
// mem_copy->source->region.size() != 1) {
57+
// LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access";
58+
// }
59+
60+
// // now that we are about to perform the `copy` transform
61+
// // save queue ID for inspection in `wait` transform
62+
// // and, increment the number of DMA copies in the group
63+
// queue_ids_.insert(async_queue_id_.value());
64+
// dmas_in_group_++;
65+
66+
// tvm::PrimExpr src_min = mem_copy->source->region[0]->min;
67+
// tvm::PrimExpr dst_min = mem_copy->dest->region[0]->min;
68+
// tvm::PrimExpr dst_extent = mem_copy->dest->region[0]->extent;
69+
70+
// auto src = BufferLoad(mem_copy->source->buffer, {src_min});
71+
// auto dst = BufferLoad(mem_copy->dest->buffer, {dst_min});
72+
// return Evaluate(
73+
// Call(DataType::Int(32), builtin::dma_copy(),
74+
// {async_queue_id_.value(), Call(DataType::Handle(), builtin::address_of(), {dst}),
75+
// Call(DataType::Handle(), builtin::address_of(), {src}),
76+
// dst_extent * src->dtype.bytes(), dma_bypass_cache_}));
77+
// }
7778

7879
Stmt VisitStmt_(const AttrStmtNode* op) final {
7980
// populate analyzer knowledge of loop iterators

0 commit comments

Comments
 (0)