Skip to content

Commit df17c27

Browse files
committed
Address review comments
1 parent eaca5d5 commit df17c27

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

python/tvm/driver/tvmc/transform.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,12 @@ def convert_graph_layout(mod, desired_layouts, ops=None):
125125
if ops is None:
126126
ops = ["nn.conv2d", "nn.conv2d_transpose", "qnn.conv2d"]
127127

128-
assert isinstance(desired_layouts, list) and len(desired_layouts) > 0
128+
if not isinstance(desired_layouts, list):
129+
# For backwards compatibility
130+
assert isinstance(desired_layouts, str)
131+
desired_layouts = [desired_layouts]
132+
133+
assert len(desired_layouts) > 0
129134

130135
if len(desired_layouts) != len(ops):
131136
if len(desired_layouts) != 1:
@@ -231,6 +236,9 @@ def generate_transform_args(parser):
231236
"--desired-layout",
232237
nargs="+",
233238
help="Change the data/kernel layout of the graph. (i.e. NCHW or NHWC:HWIO)",
239+
"This option can be provided multiple times to specify per-operator layouts, "
240+
"e.g. '--desired-layout NHWC:HWIO' (Apply same layout for every operator)."
241+
"e.g. '--desired-layout-ops nn.conv2d nn.avg_pool2d --desired-layout NCHW NHWC'."
234242
)
235243
parser.add_argument(
236244
"--desired-layout-ops",

tests/python/driver/tvmc/test_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_layout_transform_convert_kernel_layout_pass_args(relay_conv2d, monkeypa
8080
a non-default kernel layout is provided.
8181
"""
8282
desired_layout = "NHWC:HWIO"
83-
desired_layout_ops = ["nn.nonv2d"]
83+
desired_layout_ops = ["nn.conv2d"]
8484

8585
mock_convert_layout = MagicMock()
8686
mock_convert_layout.return_value = relay.transform.ConvertLayout({})

0 commit comments

Comments
 (0)