Skip to content

Commit 37f6aa0

Browse files
authored
[MetaSchedule] Fix tensorcore winograd task extraction (#13625)
* [MetaSchedule] Fix tensorcore winograd task extraction * add test * fixed target
1 parent 7674ea8 commit 37f6aa0

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

python/tvm/relay/op/strategy/cuda.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,8 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
261261
)
262262
if (
263263
target.kind.name == "cuda"
264+
and not is_auto_scheduler_enabled()
265+
and not is_meta_schedule_enabled()
264266
and nvcc.have_tensorcore(target=target)
265267
and (
266268
(N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0)

tests/python/unittest/test_meta_schedule_relay_integration.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,24 @@ def test_meta_schedule_integration_extract_from_resnet():
108108
assert t.task_name in expected_task_names, t.task_name
109109

110110

111+
@requires_torch
112+
def test_task_extraction_winograd_tensorcore():
113+
mod, params, _ = get_network(name="resnet_50", input_shape=[16, 3, 224, 224])
114+
seq = tvm.transform.Sequential(
115+
[
116+
relay.transform.ToMixedPrecision("float16"),
117+
relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "HWIO"]}),
118+
]
119+
)
120+
with tvm.transform.PassContext(opt_level=3):
121+
mod = seq(mod)
122+
123+
target = tvm.target.Target("nvidia/geforce-rtx-3070")
124+
extracted_tasks = ms.relay_integration.extract_tasks(mod, target=target, params=params)
125+
126+
assert len([t for t in extracted_tasks if "winograd" in t.task_name]) == 4
127+
128+
111129
@requires_torch
112130
def test_task_extraction_anchor_block():
113131
mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])

0 commit comments

Comments
 (0)