diff --git a/src/transform/make_packed_api.cc b/src/transform/make_packed_api.cc index a72286133..942c652fd 100644 --- a/src/transform/make_packed_api.cc +++ b/src/transform/make_packed_api.cc @@ -324,8 +324,11 @@ PrimFunc MakePackedAPI(PrimFunc func) { record_shape_vars(buf->elem_offset); } - // A visitor that marks a buffer as used when its underlying data var is - // referenced (e.g. BufferLoad/BufferStore or any direct var usage). + // A visitor that records + // - which parameter buffers are used via their data var (load/store/direct), + // - which shape/stride/offset symbols are referenced in the body. + // Shape symbols are not immediately attributed to all carrier buffers here; + // a minimal carrier set is selected after visiting. struct UsedBufferDetector : public StmtExprVisitor { UsedBufferDetector( const std::unordered_map &data2param, @@ -335,26 +338,25 @@ PrimFunc MakePackedAPI(PrimFunc func) { void VisitExpr_(const VarNode *op) override { auto it = data2param.find(op); if (it != data2param.end()) { - used_params.insert(it->second); + used_params_by_data.insert(it->second); } auto it2 = shape2params.find(op); if (it2 != shape2params.end()) { - for (const VarNode *p : it2->second) - used_params.insert(p); + used_shape_vars.insert(op); } StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode *op) override { auto it = data2param.find(op->buffer->data.get()); if (it != data2param.end()) { - used_params.insert(it->second); + used_params_by_data.insert(it->second); } StmtExprVisitor::VisitStmt_(op); } void VisitExpr_(const BufferLoadNode *op) override { auto it = data2param.find(op->buffer->data.get()); if (it != data2param.end()) { - used_params.insert(it->second); + used_params_by_data.insert(it->second); } StmtExprVisitor::VisitExpr_(op); } @@ -362,7 +364,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { const std::unordered_map &data2param; const std::unordered_map> &shape2params; - std::unordered_set used_params; + std::unordered_set used_params_by_data; + std::unordered_set used_shape_vars; }; UsedBufferDetector detector(data_var2param, shape_var2params); @@ -371,7 +374,30 @@ PrimFunc MakePackedAPI(PrimFunc func) { // Build the packed argument handling. While doing so, keep track of whether // each parameter buffer is actually used. Unused input buffers can be // nullable and do not require DLTensor field dereferences. - std::unordered_set used_param_buffers = detector.used_params; + // + // Start from buffers used via data-var (definitely non-NULL), then for each + // referenced shape symbol pick a minimal "carrier" buffer that provides the + // symbol. Prefer carriers that are already used-by-data; otherwise pick one + // arbitrary carrier to ensure the symbol is bound. + std::unordered_set used_param_buffers = + detector.used_params_by_data; + for (const VarNode *sym : detector.used_shape_vars) { + auto it = shape_var2params.find(sym); + if (it == shape_var2params.end()) + continue; + const auto &carriers = it->second; + bool has_used_carrier = false; + for (const VarNode *p : carriers) { + if (used_param_buffers.count(p)) { + has_used_carrier = true; + break; + } + } + if (!has_used_carrier && !carriers.empty()) { + // Choose the first carrier to anchor this symbol. + used_param_buffers.insert(carriers.front()); + } + } for (int i = 0; i < static_cast(func_ptr->params.size()); ++i) { Var param = func_ptr->params[i];