Skip to content

Commit

Permalink
fix parse aten::xxx bug and graphviz bug #5
Browse files Browse the repository at this point in the history
  • Loading branch information
xiayouran committed Aug 8, 2023
1 parent 1099d09 commit d164582
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions visu_tvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def codegen(self):
graph.render(filename=self.save_name, format='svg', cleanup=True)

def get_node_args(self, output_node, body_node):
pattern1 = re.compile(r'(%[a-zA-Z]*([\d._a-z]*\d*)*|meta\[relay\.Constant]\[\d*])')
pattern1 = re.compile(r'(%[a-zA-Z]*:*([\d._a-z]*\d*)*|meta\[relay\.Constant]\[\d*])')
pattern2 = re.compile(r'(%[a-z]?\d+\.\d+)')

if '(%' not in body_node:
Expand Down Expand Up @@ -159,10 +159,21 @@ def parse_node(self):
if not args_list and not index:
continue

for i in range(len(args_list)):
arg_str = args_list[i]
if '::' in arg_str:
# fix the 'colons in node identifiers' bug
# https://github.com/xflr6/graphviz/issues/53
args_list[i] = arg_str.replace('::', '--')

self.nodes[info[0]] = IRNode(name=info[0], label=info[1][:index], inputs=args_list)
for n in args_list:
if not self.nodes.get(n, ''):
self.nodes[n] = IRNode(name=n, label=n, color='white')
if '--' in n:
n_ = n.replace('--', '::')
self.nodes[n] = IRNode(name=n_, label=n_, color='white')
else:
self.nodes[n] = IRNode(name=n, label=n, color='white')

# 在图的末尾添加一个空节点
# TODO Multi-output bug
Expand Down Expand Up @@ -499,4 +510,4 @@ def split_fn_op(self):
pnodes[output_str if output_str else '}'] = PNode(name=output_str if output_str else '}', type='op',
inputs=args, body=match_op[0])

return pnodes
return pnodes

0 comments on commit d164582

Please sign in to comment.