@@ -73,21 +73,25 @@ def _update_graph_json(self):
7373 """update the nodes_list with name, shape and data type,
7474 for temporarily storing the output.
7575 """
76- nodes_len = len (self ._nodes_list )
77- for i in range (nodes_len ):
78- node = self ._nodes_list [i ]
76+ eid = 0
77+ for node in self ._nodes_list :
7978 input_list = []
80- for input_node in node ["inputs" ]:
81- input_list .append (self ._nodes_list [input_node [0 ]]["name" ])
82- node ["inputs" ] = input_list
83- dtype = str ("type: " + self ._dtype_list [1 ][i ])
84- if "attrs" not in node :
79+ if node ["op" ] == "null" :
8580 node ["attrs" ] = {}
8681 node ["op" ] = "param"
87- else :
82+ num_outputs = 1
83+ elif node ["op" ] == "tvm_op" :
84+ for input_node in node ["inputs" ]:
85+ input_list .append (self ._nodes_list [input_node [0 ]]["name" ])
8886 node ["op" ] = node ["attrs" ]["func_name" ]
87+ num_outputs = int (node ["attrs" ]["num_outputs" ])
88+ else :
89+ raise ValueError ("" )
90+ node ["inputs" ] = input_list
91+ dtype = str ("type: " + self ._dtype_list [1 ][eid ])
8992 node ["attrs" ].update ({"T" : dtype })
90- node ["shape" ] = self ._shapes_list [1 ][i ]
93+ node ["shape" ] = self ._shapes_list [1 ][eid ]
94+ eid += num_outputs
9195
9296 def _cleanup_tensors (self ):
9397 """Remove the tensor dump file (graph wont be removed)"""
0 commit comments