@@ -43,37 +43,38 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer {
   explicit AsyncDMALowerer(bool dma_bypass_cache, arith::Analyzer* analyzer)
       : IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {}
 
-  Stmt VisitStmt_(const ForNode* loop) final {
-    // if for loop is not within async_commit_queue_scope
-    if (!async_queue_id_.has_value()) {
-      return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
-    }
-
-    // if for loop is not a memcpy of a contiguous region
-    std::optional<tvm::tir::MemCpyDetails> mem_copy = IdentifyMemCpy(GetRef<For>(loop), analyzer_);
-    if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 ||
-        mem_copy->source->region.size() != 1) {
-      LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access";
-    }
-
-    // now that we are about to perform the `copy` transform
-    // save queue ID for inspection in `wait` transform
-    // and, increment the number of DMA copies in the group
-    queue_ids_.insert(async_queue_id_.value());
-    dmas_in_group_++;
-
-    tvm::PrimExpr src_min = mem_copy->source->region[0]->min;
-    tvm::PrimExpr dst_min = mem_copy->dest->region[0]->min;
-    tvm::PrimExpr dst_extent = mem_copy->dest->region[0]->extent;
-
-    auto src = BufferLoad(mem_copy->source->buffer, {src_min});
-    auto dst = BufferLoad(mem_copy->dest->buffer, {dst_min});
-    return Evaluate(
-        Call(DataType::Int(32), builtin::dma_copy(),
-             {async_queue_id_.value(), Call(DataType::Handle(), builtin::address_of(), {dst}),
-              Call(DataType::Handle(), builtin::address_of(), {src}),
-              dst_extent * src->dtype.bytes(), dma_bypass_cache_}));
-  }
+  // TODO: split async DMA lowering support between the CUDA and Hexagon backends
+  // Stmt VisitStmt_(const ForNode* loop) final {
+  //   // if for loop is not within async_commit_queue_scope
+  //   if (!async_queue_id_.has_value()) {
+  //     return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
+  //   }
+
+  //   // if for loop is not a memcpy of a contiguous region
+  //   std::optional<tvm::tir::MemCpyDetails> mem_copy = IdentifyMemCpy(GetRef<For>(loop), analyzer_);
+  //   if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 ||
+  //       mem_copy->source->region.size() != 1) {
+  //     LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access";
+  //   }
+
+  //   // now that we are about to perform the `copy` transform
+  //   // save queue ID for inspection in `wait` transform
+  //   // and, increment the number of DMA copies in the group
+  //   queue_ids_.insert(async_queue_id_.value());
+  //   dmas_in_group_++;
+
+  //   tvm::PrimExpr src_min = mem_copy->source->region[0]->min;
+  //   tvm::PrimExpr dst_min = mem_copy->dest->region[0]->min;
+  //   tvm::PrimExpr dst_extent = mem_copy->dest->region[0]->extent;
+
+  //   auto src = BufferLoad(mem_copy->source->buffer, {src_min});
+  //   auto dst = BufferLoad(mem_copy->dest->buffer, {dst_min});
+  //   return Evaluate(
+  //       Call(DataType::Int(32), builtin::dma_copy(),
+  //            {async_queue_id_.value(), Call(DataType::Handle(), builtin::address_of(), {dst}),
+  //             Call(DataType::Handle(), builtin::address_of(), {src}),
+  //             dst_extent * src->dtype.bytes(), dma_bypass_cache_}));
+  // }
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     // populate analyzer knowledge of loop iterators
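For context, a minimal before/after sketch of the transform the now-disabled ForNode visitor performed. This is illustrative pseudo-TIR: the buffer names, loop extent, element type, and queue id are assumptions, not taken from this commit; only the dma_copy argument order (queue id, destination address, source address, size in bytes, cache-bypass flag) mirrors the Call construction in the removed code.

    // Before (a contiguous copy loop inside "async_commit_queue_scope", queue id 0):
    //   for (i, 0, 128) {
    //     dst[i] = src[i]   // assume int8 elements, unit stride
    //   }
    //
    // After (sketch): the loop collapses into one dma_copy intrinsic call;
    // the byte count is extent * element size, i.e. dst_extent * src->dtype.bytes():
    //   @tir.dma_copy(0, @tir.address_of(dst[0]), @tir.address_of(src[0]), 128 * 1, dma_bypass_cache)

Per the TODO added above, this path is disabled pending a split of async DMA lowering between the CUDA and Hexagon backends.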