24 changes: 14 additions & 10 deletions python/tvm/contrib/debugger/debug_result.py
@@ -73,21 +73,25 @@ def _update_graph_json(self):
         """update the nodes_list with name, shape and data type,
         for temporarily storing the output.
         """
-        nodes_len = len(self._nodes_list)
-        for i in range(nodes_len):
-            node = self._nodes_list[i]
+        eid = 0
+        for node in self._nodes_list:
             input_list = []
-            for input_node in node["inputs"]:
-                input_list.append(self._nodes_list[input_node[0]]["name"])
-            node["inputs"] = input_list
-            dtype = str("type: " + self._dtype_list[1][i])
-            if "attrs" not in node:
+            if node["op"] == "null":
                 node["attrs"] = {}
                 node["op"] = "param"
-            else:
+                num_outputs = 1
+            elif node["op"] == "tvm_op":
+                for input_node in node["inputs"]:
+                    input_list.append(self._nodes_list[input_node[0]]["name"])
                 node["op"] = node["attrs"]["func_name"]
+                num_outputs = int(node["attrs"]["num_outputs"])
+            else:
+                raise ValueError("")
+            node["inputs"] = input_list
+            dtype = str("type: " + self._dtype_list[1][eid])
             node["attrs"].update({"T": dtype})
-            node["shape"] = self._shapes_list[1][i]
+            node["shape"] = self._shapes_list[1][eid]
+            eid += num_outputs
 
     def _cleanup_tensors(self):
         """Remove the tensor dump file (graph wont be removed)"""
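Note on the fix: in the graph executor's JSON, shapes and dtypes (self._shapes_list[1], self._dtype_list[1]) are stored per output entry, not per node, so a node with N outputs occupies N consecutive slots. Indexing those lists with the node index i, as the old code did, drifts out of sync as soon as any node has more than one output; the new running eid counter advances by num_outputs instead. A minimal standalone sketch of that bookkeeping (the node and shape data below are made up for illustration):

    # Each node's outputs occupy consecutive "entry" slots in the shape table.
    nodes = [
        {"name": "x", "num_outputs": 1},      # entry 0
        {"name": "split", "num_outputs": 4},  # entries 1-4
        {"name": "p0", "num_outputs": 1},     # entry 5
    ]
    shapes = [
        [1, 3, 48, 16],                                                 # x
        [1, 3, 12, 16], [1, 3, 4, 16], [1, 3, 16, 16], [1, 3, 16, 16],  # split outputs
        [3, 3, 1, 1],                                                   # p0
    ]

    eid = 0
    for i, node in enumerate(nodes):
        # Wrong: shapes[i] would hand p0 (i == 2) a split output's shape.
        # Right: index by the running entry id.
        node["shape"] = shapes[eid]
        eid += node["num_outputs"]

    assert nodes[2]["shape"] == [3, 3, 1, 1]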
26 changes: 25 additions & 1 deletion tests/python/unittest/test_runtime_graph_debug.py
@@ -29,7 +29,7 @@
 from tvm._ffi.base import TVMError
 from tvm.contrib import utils
 from tvm.contrib.debugger import debug_executor
-
+from tvm import relay
 
 # Constants for creating simple graphs, fixtures to avoid free globals
 @pytest.fixture
Expand Down Expand Up @@ -275,5 +275,29 @@ def test_run_single_node(graph, n, A, myadd):
mod.run_individual_node(2)


@tvm.testing.requires_llvm
def test_multiple_output():
x = relay.var("x", shape=(1, 3, 48, 16), dtype="float32")
t = relay.split(x, [12, 16, 32], 2).astuple()
x0 = relay.TupleGetItem(t, 0)
x1 = relay.TupleGetItem(t, 1)
x2 = relay.TupleGetItem(t, 2)
x3 = relay.TupleGetItem(t, 3)
p0 = relay.const(np.random.uniform(-1, 1, (3, 3, 1, 1)).astype("float32"))
y = relay.nn.conv2d(x2, p0, kernel_size=(1, 1), kernel_layout="OIHW", out_dtype="float32") + x3

func = relay.Function([x], relay.Tuple([x0, x1, y]))
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
target = tvm.target.Target("llvm")
device = tvm.cpu()
lib = relay.build(mod, target=target)
m = debug_executor.GraphModuleDebug(
lib["debug_create"]("default", device), [device], lib.get_graph_json(), None
)
nodes = m.debug_datum.get_graph_nodes()
assert nodes[2]["shape"] == [3, 3, 1, 1]


if __name__ == "__main__":
tvm.testing.main()
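As a follow-up check, the per-node metadata the debugger records can be dumped directly from debug_datum (a sketch reusing the m from the test above; exact node order depends on how relay.build lowers the module):

    for idx, node in enumerate(m.debug_datum.get_graph_nodes()):
        print(idx, node["op"], node["name"], node["shape"], node["attrs"]["T"])

For this graph the nodes are roughly [x (param), split (4 outputs), p0 (param), fused conv2d+add], so p0's shape lives at entry id 5; with the old node-index lookup it was read from slot 2, one of the split outputs, which is exactly what the new assertion guards against.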