@@ -46,10 +46,12 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Consta
4646 auto opt_mod = seq (std::move (mod));
4747
4848 Array<ExtractedTask> tasks;
49- PostOrderVisit (opt_mod->Lookup (" main" ), [target, &tasks](const Expr& exp) {
49+ std::unordered_set<tec::CCacheKey> cache_;
50+ PostOrderVisit (opt_mod->Lookup (" main" ), [target, &tasks, &cache_](const Expr& exp) {
5051 if (exp->IsInstance <FunctionNode>()) {
5152 Function relay_func = Downcast<Function>(exp);
52- if (relay_func->HasNonzeroAttr (attr::kPrimitive )) {
53+ tec::CCacheKey cache_key (relay_func, target);
54+ if (relay_func->HasNonzeroAttr (attr::kPrimitive ) && cache_.find (cache_key) == cache_.end ()) {
5355 Array<te::Tensor> outputs;
5456 std::string fused_name;
5557 std::tie (outputs, fused_name) =
@@ -59,6 +61,7 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Consta
5961 auto relay_mod = IRModule ({{prim_fn_var, relay_func}});
6062 auto tir_mod = IRModule ({{prim_fn_var, prim_func}});
6163 tasks.push_back (ExtractedTask (prim_fn_var->name_hint , relay_mod, target, {tir_mod}));
64+ cache_.insert (cache_key);
6265 }
6366 }
6467 });
0 commit comments