
Commit 1d7dbf2

zxybazhpfk-beta authored and committed
[MetaSchedule] Bug Fix for Relay Integration (apache#10534)
* Bug fix.
* Fix tune relay script.
* Remove debug info.
* Retest CI.
* Add regression test.
* Remove comments.
1 parent 75a38e8 commit 1d7dbf2

4 files changed, +249 -49 lines
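
For orientation, the regression tests added below exercise the Relay-side MetaSchedule integration path that this commit fixes. A condensed sketch of that flow, assuming database is a tuning database that already holds records for the fused workloads (the full versions appear in the test diff below):

import tvm
from tvm import relay
from tvm.meta_schedule.integration import ApplyHistoryBest

def build_with_history_best(mod, params, target, database):
    # Query the database for the best schedule of each fused TIR workload
    # while Relay lowers and builds the model.
    with ApplyHistoryBest(database):
        with tvm.transform.PassContext(
            opt_level=3,
            config={"relay.backend.use_meta_schedule": True},
        ):
            return relay.build(mod, target=target, params=params)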

src/meta_schedule/integration.cc

Lines changed: 0 additions & 2 deletions

@@ -136,13 +136,11 @@ Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRMod
   if (database->HasWorkload(prim_mod)) {
     Array<TuningRecord> records = database->GetTopK(database->CommitWorkload(prim_mod), 1);
     if (records.size() == 1) {
-      LOG(INFO) << "Applied history best for: " << task_name;
       tir::Schedule sch =
           tir::Schedule::Traced(records[0]->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
                                 /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
       records[0]->trace->ApplyToSchedule(sch, false);
       tir::PrimFunc func = GetOnlyOneFunction<tir::PrimFunc>(sch->mod()).value();
-      LOG(INFO) << "\n" << tir::AsTVMScript(func);
       return func;
     }
   }

src/relay/backend/te_compiler.cc

Lines changed: 7 additions & 3 deletions

@@ -322,9 +322,13 @@ class TECompilerImpl : public TECompilerNode {
     });

     if (value->cached_func->prim_func.defined()) {
-      VLOG(1) << "already have PrimFunc";
-      value->cached_func->funcs->Add(value->cached_func->prim_fn_var,
-                                     value->cached_func->prim_func.value());
+      VLOG(1) << "Lowering PrimFunc";
+      IRModule lowered = tvm::LowerPrimFunc(value->cached_func->prim_func.value(),
+                                            value->cached_func->prim_fn_var->name_hint, false);
+      ICHECK_EQ(lowered->functions.size(), 1);
+      for (const auto& kv : lowered->functions) {
+        value->cached_func->funcs->Add(value->cached_func->prim_fn_var, kv.second);
+      }
     } else {
       // NOTE: array will copy on write.
       Array<te::Tensor> all_args = Array<te::Tensor>(value->cached_func->inputs);
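
The change above lowers the MetaSchedule-provided PrimFunc before registering it, instead of adding the unlowered function directly to the per-target function map. A minimal Python-side sketch of the same idea (not the C++ code path itself): tvm.lower runs the TIR lowering passes on a standalone PrimFunc and returns an IRModule containing a single lowered function, mirroring the ICHECK_EQ above. The add_one kernel is a toy example for illustration only.

import tvm
from tvm.script import tir as T

@T.prim_func
def add_one(a: T.handle, b: T.handle) -> None:
    # Toy elementwise kernel, used only to have something to lower.
    A = T.match_buffer(a, (16,), "float32")
    B = T.match_buffer(b, (16,), "float32")
    for i in T.serial(0, 16):
        with T.block("B"):
            vi = T.axis.spatial(16, i)
            B[vi] = A[vi] + T.float32(1)

# Lowering yields an IRModule with exactly one lowered function, which is the
# form the runtime module map expects.
lowered = tvm.lower(add_one, name="add_one")
assert len(lowered.functions) == 1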

src/relay/backend/te_compiler_cache.cc

Lines changed: 1 addition & 1 deletion

@@ -180,7 +180,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
       }
     }
     if (use_meta_schedule_) {
-      prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
+      prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
       Optional<ObjectRef> opt_mod_or_base_func =
           meta_schedule::MetaScheduleContext::QueryInsideWithScope(
               prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
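
The one-line change above builds the PrimFunc from the concatenation of the subgraph's inputs and outputs rather than from the outputs alone, so the parameter order of the generated TIR matches the order in which Relay passes arguments at runtime; this is what the new test_meta_schedule_te2primfunc_argument_order test checks. A rough Python-side analogue (a sketch, not the C++ change itself) using the public te.create_prim_func API:

import tvm
from tvm import te

A = te.placeholder((16,), name="A")                      # subgraph input
B = te.compute((16,), lambda i: A[i] + 1.0, name="B")    # subgraph output

# Passing inputs followed by outputs pins the generated parameter order to
# (A, B); deriving the inputs from the outputs alone leaves that order to
# graph traversal, which is what caused the argument-order mismatch.
func = te.create_prim_func([A, B])
print(func.params)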

tests/python/unittest/test_meta_schedule_tune_relay.py

Lines changed: 241 additions & 43 deletions
@@ -16,63 +16,98 @@
 # under the License.
 # pylint: disable=missing-docstring
 import logging
+from multiprocessing.sharedctypes import Value
 import tempfile
 from typing import List
-
+from os import path as osp
 import numpy as np
 import pytest
 import tvm
 from tvm import relay
 from tvm.contrib import graph_executor
 from tvm.ir import IRModule
+from tvm.tir.schedule.schedule import Schedule
+from tvm.tir.schedule.trace import Trace
 from tvm.meta_schedule import ReplayTraceConfig
-from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
+from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload, JSONDatabase
+from tvm.meta_schedule.integration import ApplyHistoryBest
 from tvm.meta_schedule.testing.relay_workload import get_network
 from tvm.meta_schedule.tune import tune_relay
 from tvm.meta_schedule.utils import derived_object
 from tvm.target.target import Target
+from tvm.script import tir as T

 logging.basicConfig()
 logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)

+# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
+# fmt: off
+@tvm.script.ir_module
+class tvmgen_default_fused_layout_transform:
+    @T.prim_func
+    def main(
+        placeholder: T.Buffer[(1, 3, 16, 16), "float32"],
+        T_layout_trans: T.Buffer[(1, 1, 16, 16, 3), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        for i0, i1, i2, i3, i4 in T.grid(1, 1, 16, 16, 3):
+            with T.block("T_layout_trans"):
+                ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
+                T.reads(placeholder[ax0, ax1 * 3 + ax4, ax2, ax3])
+                T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
+                T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(
+                    ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16,
+                    placeholder[ax0, ax1 * 3 + ax4, ax2, ax3],
+                    T.float32(0),
+                    dtype="float32",
+                )

-@derived_object
-class DummyDatabase(PyDatabase):
-    def __init__(self):
-        super().__init__()
-        self.records = []
-        self.workload_reg = []
-
-    def has_workload(self, mod: IRModule) -> Workload:
-        for workload in self.workload_reg:
-            if tvm.ir.structural_equal(workload.mod, mod):
-                return True
-        return False
-
-    def commit_tuning_record(self, record: TuningRecord) -> None:
-        self.records.append(record)
-
-    def commit_workload(self, mod: IRModule) -> Workload:
-        for workload in self.workload_reg:
-            if tvm.ir.structural_equal(workload.mod, mod):
-                return workload
-        workload = Workload(mod)
-        self.workload_reg.append(workload)
-        return workload
-
-    def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
-        return list(
-            filter(
-                lambda x: x.workload == workload,
-                sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
-            )
-        )[: int(top_k)]

-    def __len__(self) -> int:
-        return len(self.records)
+@tvm.script.ir_module
+class tvmgen_default_fused_nn_contrib_conv2d_NCHWc:
+    @T.prim_func
+    def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.Buffer[(2, 1, 5, 5, 3, 4), "float32"], conv2d_NCHWc: T.Buffer[(1, 2, 16, 16, 4), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        data_pad = T.alloc_buffer([1, 1, 20, 20, 3], dtype="float32")
+        for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3):
+            with T.block("data_pad"):
+                i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
+                T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1])
+                T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1])
+                data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32")
+        for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5):
+            with T.block("conv2d_NCHWc"):
+                n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7])
+                T.reads(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block], data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block])
+                T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block])
+                T.block_attr({"workload":["conv2d_NCHWc.x86", ["TENSOR", [1, 1, 16, 16, 3], "float32"], ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"], [1, 1], [2, 2, 2, 2], [1, 1], "NCHW3c", "NCHW4c", "float32"]})
+                with T.init():
+                    conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0)
+                conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]
+
+@tvm.script.ir_module
+class tvmgen_default_fused_layout_transform_1:
+    @T.prim_func
+    def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.Buffer[(1, 8, 16, 16), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        for i0, i1, i2, i3 in T.grid(1, 8, 16, 16):
+            with T.block("T_layout_trans"):
+                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4])
+                T.writes(T_layout_trans[ax0, ax1, ax2, ax3])
+                T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32")

-    def print_results(self) -> None:
-        print("\n".join([str(r) for r in self.records]))
+# fmt: on
+# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument


 @pytest.mark.skip("Integration test")
@@ -101,8 +136,7 @@ def test_meta_schedule_tune_relay(
     mod, params, (input_name, _, _) = get_network(name=model_name, input_shape=input_shape)
     target = Target(target)
     with tempfile.TemporaryDirectory() as work_dir:
-        database = DummyDatabase()
-        rt_mod: tvm.runtime.Module = tune_relay(
+        rt_mod1: tvm.runtime.Module = tune_relay(
             mod=mod,
             params=params,
             target=target,
@@ -111,11 +145,173 @@ def test_meta_schedule_tune_relay(
                 num_trials_total=32,
             ),
             work_dir=work_dir,
-            database=database,
+            database=JSONDatabase(
+                osp.join(work_dir, "workload.json"), osp.join(work_dir, "records.json")
+            ),
+        )
+        # Compile without meta-scheduler for correctness check
+        with tvm.transform.PassContext(opt_level=0):
+            rt_mod2 = relay.build(mod, target=target, params=params)
+
+        def get_output(data, lib):
+            module = graph_executor.GraphModule(lib["default"](dev))
+            module.set_input(input_name, data)
+            module.run()
+            return module.get_output(0).numpy()
+
+        # Check correctness
+        actual_output = get_output(data, rt_mod1)
+        expected_output = get_output(data, rt_mod2)
+        assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
+
+
+def test_meta_schedule_te2primfunc_argument_order():
+    @derived_object
+    class TestDummyDatabase(PyDatabase):
+        def __init__(self):
+            super().__init__()
+            self.records = []
+            self.workload_reg = []
+
+        def has_workload(self, mod: IRModule) -> Workload:
+            for workload in self.workload_reg:
+                if tvm.ir.structural_equal(workload.mod, mod):
+                    return True
+            # The database has already put in all correct workloads
+            raise ValueError(
+                "The workload searched for is not in given database!"
+                + " Incorrect TIR was generated from TE subgraph."
+            )
+
+        def commit_tuning_record(self, record: TuningRecord) -> None:
+            self.records.append(record)
+
+        def commit_workload(self, mod: IRModule) -> Workload:
+            for workload in self.workload_reg:
+                if tvm.ir.structural_equal(workload.mod, mod):
+                    return workload
+            workload = Workload(mod)
+            self.workload_reg.append(workload)
+            return workload
+
+        def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
+            return list(
+                filter(
+                    lambda x: x.workload == workload,
+                    sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
+                )
+            )[: int(top_k)]
+
+        def __len__(self) -> int:
+            return len(self.records)
+
+        def print_results(self) -> None:
+            print("\n".join([str(r) for r in self.records]))
+
+    data_shape = (1, 3, 16, 16)
+    weight_shape = (8, 3, 5, 5)
+    data = relay.var("data", relay.TensorType(data_shape, "float32"))
+    weight = relay.var("weight", relay.TensorType(weight_shape, "float32"))
+    y = relay.nn.conv2d(
+        data,
+        weight,
+        padding=(2, 2),
+        kernel_size=(5, 5),
+        kernel_layout="OIHW",
+        out_dtype="float32",
+    )
+    f = relay.Function([data, weight], y)
+    mod = tvm.IRModule.from_expr(f)
+    mod = relay.transform.InferType()(mod)
+
+    data_sample = np.random.rand(*data_shape).astype("float32")
+    weight_sample = np.random.rand(*weight_shape).astype("float32")
+    params = {mod["main"].params[1].name_hint: weight_sample}
+
+    input_name = "data"
+    dev = tvm.cpu()
+    target = Target("llvm --num-cores=16")
+    data = tvm.nd.array(data_sample, dev)
+
+    database = TestDummyDatabase()
+    database.commit_workload(tvmgen_default_fused_layout_transform)
+    database.commit_workload(tvmgen_default_fused_layout_transform_1)
+    database.commit_workload(tvmgen_default_fused_nn_contrib_conv2d_NCHWc)
+
+    with ApplyHistoryBest(database):
+        with tvm.transform.PassContext(
+            opt_level=3,
+            config={"relay.backend.use_meta_schedule": True},
+        ):
+            rt_mod1 = relay.build(mod, target=target, params=params)
+
+    # Compile without meta-scheduler for correctness check
+    with tvm.transform.PassContext(opt_level=0):
+        rt_mod2 = relay.build(mod, target=target, params=params)
+
+    def get_output(data, lib):
+        module = graph_executor.GraphModule(lib["default"](dev))
+        module.set_input(input_name, data)
+        module.run()
+        return module.get_output(0).numpy()
+
+    # Check correctness
+    actual_output = get_output(data, rt_mod1)
+    expected_output = get_output(data, rt_mod2)
+    assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
+
+
+def test_meta_schedule_relay_lowering():
+    data_shape = (1, 3, 16, 16)
+    weight_shape = (8, 3, 5, 5)
+    data = relay.var("data", relay.TensorType(data_shape, "float32"))
+    weight = relay.var("weight", relay.TensorType(weight_shape, "float32"))
+    y = relay.nn.conv2d(
+        data,
+        weight,
+        padding=(2, 2),
+        kernel_size=(5, 5),
+        kernel_layout="OIHW",
+        out_dtype="float32",
+    )
+    f = relay.Function([data, weight], y)
+    mod = tvm.IRModule.from_expr(f)
+    mod = relay.transform.InferType()(mod)
+
+    data_sample = np.random.rand(*data_shape).astype("float32")
+    weight_sample = np.random.rand(*weight_shape).astype("float32")
+    params = {mod["main"].params[1].name_hint: weight_sample}
+
+    input_name = "data"
+    dev = tvm.cpu()
+    target = Target("llvm --num-cores=16")
+    data = tvm.nd.array(data_sample, dev)
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        database = JSONDatabase(
+            osp.join(work_dir, "workload.json"), osp.join(work_dir, "records.json")
         )
+
+        database.commit_tuning_record(
+            TuningRecord(
+                Trace([], {}),
+                [0.0],
+                database.commit_workload(tvmgen_default_fused_nn_contrib_conv2d_NCHWc),
+                target=target,
+                args_info=[],
+            )
+        )
+
+        with ApplyHistoryBest(database):
+            with tvm.transform.PassContext(
+                opt_level=3,
+                config={"relay.backend.use_meta_schedule": True},
+            ):
+                rt_mod1 = relay.build(mod, target=target, params=params)
+
         # Compile without meta-scheduler for correctness check
         with tvm.transform.PassContext(opt_level=0):
-            rt_mod2 = relay.build(mod, target=Target("llvm"), params=params)
+            rt_mod2 = relay.build(mod, target=target, params=params)

         def get_output(data, lib):
             module = graph_executor.GraphModule(lib["default"](dev))
@@ -124,8 +320,8 @@ def get_output(data, lib):
             return module.get_output(0).numpy()

         # Check correctness
-        actual_output = get_output(data, rt_mod)
-        expected_output = get_output(tvm.nd.array(data.numpy(), device=tvm.cpu()), rt_mod2)
+        actual_output = get_output(data, rt_mod1)
+        expected_output = get_output(data, rt_mod2)
         assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)


@@ -136,3 +332,5 @@ def get_output(data, lib):
     test_meta_schedule_tune_relay("mobilenet_v2", [1, 3, 224, 224], "nvidia/geforce-rtx-3070")
     test_meta_schedule_tune_relay("bert_base", [1, 64], "llvm --num-cores=16")
     test_meta_schedule_tune_relay("bert_base", [1, 64], "nvidia/geforce-rtx-3070")
+    test_meta_schedule_te2primfunc_argument_order()
+    test_meta_schedule_relay_lowering()
