3131
3232#include < fstream>
3333#include < numeric>
34+ #include < regex>
3435#include < sstream>
3536
3637#include " ../../utils.h"
@@ -439,6 +440,23 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
439440 using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
440441 using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
441442
443+ std::map<std::string, std::string> op_map{
444+ {" bias" , " add" },
445+ {" relu" , " nn.relu" },
446+ {" tanh" , " tanh" },
447+ {" sigmoid" , " sigmoid" },
448+ };
449+
450+ std::vector<std::string> ParsingOpList (std::string op, std::string pattern_name) {
451+ std::vector<std::string> op_list = {" nn." + op};
452+ for (auto & t : op_map) {
453+ if (pattern_name.find (t.first ) != std::string::npos) {
454+ op_list.push_back (t.second );
455+ }
456+ }
457+ return op_list;
458+ }
459+
442460 public:
443461 DNNLJSONSerializer (const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {}
444462
@@ -453,28 +471,29 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
453471 ICHECK (comp.defined ()) << " DNNL JSON runtime only supports composite functions." ;
454472 name = comp.value ();
455473
456- if (name == " dnnl.conv2d_bias_relu" ) {
457- call = GetRootCall (fn->body .as <CallNode>(), 2 , {" nn.conv2d" , " add" , " nn.relu" });
458- } else if (name == " dnnl.conv2d_bias_tanh" ) {
459- call = GetRootCall (fn->body .as <CallNode>(), 2 , {" nn.conv2d" , " add" , " tanh" });
460- ICHECK (call->op .as <OpNode>()) << " Not op node" ;
461- } else if (name == " dnnl.conv2d_bias_sigmoid" ) {
462- call = GetRootCall (fn->body .as <CallNode>(), 2 , {" nn.conv2d" , " add" , " sigmoid" });
474+ if (name.find (" dnnl.conv2d_transpose" ) != std::string::npos) {
475+ std::vector<std::string> op_list = ParsingOpList (" conv2d_transpose" , name);
476+ call = GetRootCall (fn->body .as <CallNode>(), op_list.size () - 1 , op_list);
463477 ICHECK (call->op .as <OpNode>()) << " Not op node" ;
464- } else if (name == " dnnl.conv2d_bias" ) {
465- call = GetRootCall (fn->body .as <CallNode>(), 1 , {" nn.conv2d" , " add" });
478+ } else if (name.find (" dnnl.conv3d_transpose" ) != std::string::npos) {
479+ std::vector<std::string> op_list = ParsingOpList (" conv3d_transpose" , name);
480+ call = GetRootCall (fn->body .as <CallNode>(), op_list.size () - 1 , op_list);
466481 ICHECK (call->op .as <OpNode>()) << " Not op node" ;
467- } else if (name == " dnnl.conv2d_relu" ) {
468- call = GetRootCall (fn->body .as <CallNode>(), 1 , {" nn.conv2d" , " nn.relu" });
482+ } else if (name.find (" dnnl.conv1d" ) != std::string::npos) {
483+ std::vector<std::string> op_list = ParsingOpList (" conv1d" , name);
484+ call = GetRootCall (fn->body .as <CallNode>(), op_list.size () - 1 , op_list);
469485 ICHECK (call->op .as <OpNode>()) << " Not op node" ;
470- } else if (name == " dnnl.conv2d_tanh" ) {
471- call = GetRootCall (fn->body .as <CallNode>(), 1 , {" nn.conv2d" , " tanh" });
486+ } else if (name.find (" dnnl.conv2d" ) != std::string::npos) {
487+ std::vector<std::string> op_list = ParsingOpList (" conv2d" , name);
488+ call = GetRootCall (fn->body .as <CallNode>(), op_list.size () - 1 , op_list);
472489 ICHECK (call->op .as <OpNode>()) << " Not op node" ;
473- } else if (name == " dnnl.conv2d_sigmoid" ) {
474- call = GetRootCall (fn->body .as <CallNode>(), 1 , {" nn.conv2d" , " sigmoid" });
490+ } else if (name.find (" dnnl.conv3d" ) != std::string::npos) {
491+ std::vector<std::string> op_list = ParsingOpList (" conv3d" , name);
492+ call = GetRootCall (fn->body .as <CallNode>(), op_list.size () - 1 , op_list);
475493 ICHECK (call->op .as <OpNode>()) << " Not op node" ;
476- } else if (name == " dnnl.dense_bias" ) {
477- call = GetRootCall (fn->body .as <CallNode>(), 1 , {" nn.dense" , " add" });
494+ } else if (name.find (" dnnl.dense" ) != std::string::npos) {
495+ std::vector<std::string> op_list = ParsingOpList (" dense" , name);
496+ call = GetRootCall (fn->body .as <CallNode>(), op_list.size () - 1 , op_list);
478497 ICHECK (call->op .as <OpNode>()) << " Not op node" ;
479498 } else {
480499 LOG (FATAL) << " Unrecognized DNNL pattern: " << name;
0 commit comments