Skip to content

Commit dfaf496

Browse files
committed
dedup tasks
1 parent e49d500 commit dfaf496

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/relay/backend/task_extraction.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Consta
4646
auto opt_mod = seq(std::move(mod));
4747

4848
Array<ExtractedTask> tasks;
49-
PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks](const Expr& exp) {
49+
std::unordered_set<tec::CCacheKey> cache_;
50+
PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache_](const Expr& exp) {
5051
if (exp->IsInstance<FunctionNode>()) {
5152
Function relay_func = Downcast<Function>(exp);
52-
if (relay_func->HasNonzeroAttr(attr::kPrimitive)) {
53+
tec::CCacheKey cache_key(relay_func, target);
54+
if (relay_func->HasNonzeroAttr(attr::kPrimitive) && cache_.find(cache_key) == cache_.end()) {
5355
Array<te::Tensor> outputs;
5456
std::string fused_name;
5557
std::tie(outputs, fused_name) =
@@ -59,6 +61,7 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Consta
5961
auto relay_mod = IRModule({{prim_fn_var, relay_func}});
6062
auto tir_mod = IRModule({{prim_fn_var, prim_func}});
6163
tasks.push_back(ExtractedTask(prim_fn_var->name_hint, relay_mod, target, {tir_mod}));
64+
cache_.insert(cache_key);
6265
}
6366
}
6467
});

src/relay/backend/te_compiler_cache.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
757757
std::pair<Array<te::Tensor>, std::string> LowerTECompute(const Function& source_func, Target target,
758758
bool return_inputs) {
759759
LowerToTECompute lower_te_compute(target);
760-
auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name; });
760+
auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name;});
761761
// Following ScheduleBuilder, remove placeholder ops from outputs.
762762
tvm::Array<te::Tensor> tensor_outs;
763763
for (const auto& tensor : outputs) {

0 commit comments

Comments
 (0)