Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions src/transform/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,11 @@ PrimFunc MakePackedAPI(PrimFunc func) {
record_shape_vars(buf->elem_offset);
}

// A visitor that marks a buffer as used when its underlying data var is
// referenced (e.g. BufferLoad/BufferStore or any direct var usage).
// A visitor that records
// - which parameter buffers are used via their data var (load/store/direct),
// - which shape/stride/offset symbols are referenced in the body.
// Shape symbols are not immediately attributed to all carrier buffers here;
// a minimal carrier set is selected after visiting.
struct UsedBufferDetector : public StmtExprVisitor {
UsedBufferDetector(
const std::unordered_map<const VarNode *, const VarNode *> &data2param,
Expand All @@ -335,34 +338,34 @@ PrimFunc MakePackedAPI(PrimFunc func) {
void VisitExpr_(const VarNode *op) override {
auto it = data2param.find(op);
if (it != data2param.end()) {
used_params.insert(it->second);
used_params_by_data.insert(it->second);
}
auto it2 = shape2params.find(op);
if (it2 != shape2params.end()) {
for (const VarNode *p : it2->second)
used_params.insert(p);
used_shape_vars.insert(op);
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode *op) override {
auto it = data2param.find(op->buffer->data.get());
if (it != data2param.end()) {
used_params.insert(it->second);
used_params_by_data.insert(it->second);
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const BufferLoadNode *op) override {
auto it = data2param.find(op->buffer->data.get());
if (it != data2param.end()) {
used_params.insert(it->second);
used_params_by_data.insert(it->second);
}
StmtExprVisitor::VisitExpr_(op);
}

const std::unordered_map<const VarNode *, const VarNode *> &data2param;
const std::unordered_map<const VarNode *, std::vector<const VarNode *>>
&shape2params;
std::unordered_set<const VarNode *> used_params;
std::unordered_set<const VarNode *> used_params_by_data;
std::unordered_set<const VarNode *> used_shape_vars;
};

UsedBufferDetector detector(data_var2param, shape_var2params);
Expand All @@ -371,7 +374,30 @@ PrimFunc MakePackedAPI(PrimFunc func) {
// Build the packed argument handling. While doing so, keep track of whether
// each parameter buffer is actually used. Unused input buffers can be
// nullable and do not require DLTensor field dereferences.
std::unordered_set<const VarNode *> used_param_buffers = detector.used_params;
//
// Start from buffers used via data-var (definitely non-NULL), then for each
// referenced shape symbol pick a minimal "carrier" buffer that provides the
// symbol. Prefer carriers that are already used-by-data; otherwise pick one
// arbitrary carrier to ensure the symbol is bound.
std::unordered_set<const VarNode *> used_param_buffers =
detector.used_params_by_data;
for (const VarNode *sym : detector.used_shape_vars) {
auto it = shape_var2params.find(sym);
if (it == shape_var2params.end())
continue;
const auto &carriers = it->second;
bool has_used_carrier = false;
for (const VarNode *p : carriers) {
if (used_param_buffers.count(p)) {
has_used_carrier = true;
break;
}
}
if (!has_used_carrier && !carriers.empty()) {
// Choose the first carrier to anchor this symbol.
used_param_buffers.insert(carriers.front());
}
}

for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
Var param = func_ptr->params[i];
Expand Down
Loading