Skip to content

Commit 440e2ed

Browse files
committed
[BugFix][Runtime] Fix Incorrect node information
1 parent 49ed544 commit 440e2ed

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

python/tvm/contrib/debugger/debug_result.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,21 +73,25 @@ def _update_graph_json(self):
7373
"""update the nodes_list with name, shape and data type,
7474
for temporarily storing the output.
7575
"""
76-
nodes_len = len(self._nodes_list)
77-
for i in range(nodes_len):
78-
node = self._nodes_list[i]
76+
eid = 0
77+
for node in self._nodes_list:
7978
input_list = []
80-
for input_node in node["inputs"]:
81-
input_list.append(self._nodes_list[input_node[0]]["name"])
82-
node["inputs"] = input_list
83-
dtype = str("type: " + self._dtype_list[1][i])
84-
if "attrs" not in node:
79+
if node["op"] == "null":
8580
node["attrs"] = {}
8681
node["op"] = "param"
87-
else:
82+
num_outputs = 1
83+
elif node["op"] == "tvm_op":
84+
for input_node in node["inputs"]:
85+
input_list.append(self._nodes_list[input_node[0]]["name"])
8886
node["op"] = node["attrs"]["func_name"]
87+
num_outputs = int(node["attrs"]["num_outputs"])
88+
else:
89+
raise ValueError("")
90+
node["inputs"] = input_list
91+
dtype = str("type: " + self._dtype_list[1][eid])
8992
node["attrs"].update({"T": dtype})
90-
node["shape"] = self._shapes_list[1][i]
93+
node["shape"] = self._shapes_list[1][eid]
94+
eid += num_outputs
9195

9296
def _cleanup_tensors(self):
9397
"""Remove the tensor dump file (graph wont be removed)"""

0 commit comments

Comments
 (0)