
Commit ba65197

Authored by sunggg, Siyuan Feng, spectrometerHBH, jinhongyii, and MasterJH5574
[MetaSchedule][M4b] Testcases for TensorRT builder/runner (#10055)
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
1 parent 0fb5ae2 commit ba65197

3 files changed: +282 −0 lines changed

python/tvm/meta_schedule/testing/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,3 +17,4 @@
 """Testing utilities in meta schedule"""
 from .local_rpc import LocalRPC
 from .relay_workload import get_network
+from .byoc_trt import relay_build_with_tensorrt
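
With this re-export in place, callers can pull the TensorRT build function straight from the testing namespace, exactly as the new test module below does:

from tvm.meta_schedule.testing import relay_build_with_tensorrt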
python/tvm/meta_schedule/testing/byoc_trt.py (new file; the path follows from the .byoc_trt import above)

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TensorRT-MetaSchedule integration"""
# pylint: disable=import-outside-toplevel

import tvm
from tvm.ir import IRModule
from tvm.runtime import Module
from tvm.target import Target


def relay_build_with_tensorrt(
    mod: IRModule,
    target: Target,
    params: dict,
) -> Module:
    """Build a Relay IRModule with TensorRT BYOC.

    Parameters
    ----------
    mod : IRModule
        The Relay IRModule to build.
    target : Target
        The target to build the module for.
    params : Dict[str, NDArray]
        The parameter dict to build the module with.

    Returns
    -------
    mod : runtime.Module
        The built module.
    """
    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

    assert isinstance(target, Target)
    # Partition the graph so that TensorRT-supported subgraphs are offloaded to the BYOC codegen.
    mod, config = partition_for_tensorrt(mod, params)
    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
        # NOTE: the target argument is only validated; the build below fixes the
        # device to CUDA with an LLVM host.
        result = tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params)
    assert isinstance(result, Module)
    return result
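
This function is meant to be passed to a LocalBuilder as its f_build hook. A minimal sketch of that usage, assuming a Relay IRModule mod and its params are already constructed (this mirrors the wiring in the test module below):

from tvm.meta_schedule.builder import BuilderInput, LocalBuilder
from tvm.meta_schedule.testing import relay_build_with_tensorrt
from tvm.target import Target

# Plug the TensorRT-aware build function into a LocalBuilder.
builder = LocalBuilder(f_build=relay_build_with_tensorrt)
# Build a single IRModule for CUDA with an LLVM host.
(builder_result,) = builder.build([BuilderInput(mod, Target("cuda", host="llvm"), params)])
assert builder_result.error_msg is None  # on success, artifact_path points at the built module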
Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test MetaSchedule builder/runner with TensorRT BYOC"""
import itertools
import sys
from typing import List

import pytest

import tvm
import tvm.contrib.graph_executor
import tvm.testing
from tvm import relay
from tvm.meta_schedule.arg_info import TensorInfo
from tvm.meta_schedule.builder import BuilderInput, BuilderResult, LocalBuilder
from tvm.meta_schedule.runner import (
    EvaluatorConfig,
    LocalRunner,
    RunnerInput,
)
from tvm.meta_schedule.testing import get_network
from tvm.relay import testing
from tvm.relay.op.contrib import tensorrt
from tvm.runtime import Module
from tvm.target import Target
from tvm.tir import FloatImm

has_tensorrt_codegen = pytest.mark.skipif(
    not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available"
)
has_tensorrt_runtime = pytest.mark.skipif(
    not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available"
)


# conv2d+relu network
def get_conv2d_relu(
    data_shape,
    out_channels,
    kernel_size,
    strides,
    padding,
    dilation,
    groups,
    data_layout,
    kernel_layout,
    dtype,
):
    data = relay.var("data", relay.TensorType(data_shape, dtype))
    weight = relay.var("weight")

    net = relay.nn.conv2d(
        data=data,
        weight=weight,  # conv kernel
        strides=strides,
        padding=padding,
        dilation=dilation,
        groups=groups,
        channels=out_channels,
        kernel_size=kernel_size,
        data_layout=data_layout,
        kernel_layout=kernel_layout,
    )
    net = relay.add(net, net)
    net = relay.nn.relu(net)

    inputs = relay.analysis.free_vars(net)
    return relay.Function(inputs, net)


def verify_meta_schedule_with_tensorrt(
    mod, params, data_shape, use_meta_sched: bool = True, use_trt: bool = True, mode: str = "vm"
):
    if use_meta_sched:
        # With meta_schedule
        dev = "cuda"

        # Build
        if use_trt:
            from tvm.meta_schedule.testing import relay_build_with_tensorrt

            builder = LocalBuilder(f_build=relay_build_with_tensorrt)
        else:

            def relay_build_without_tensorrt(
                mod: Module,
                target: Target,
                params: dict,
            ) -> List[BuilderResult]:
                return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params)

            builder = LocalBuilder(f_build=relay_build_without_tensorrt)

        builder_input = BuilderInput(mod, Target(dev, host="llvm"), params)

        (builder_result,) = builder.build([builder_input])
        assert builder_result.error_msg is None
        assert builder_result.artifact_path is not None

        # Run
        evaluator_config = EvaluatorConfig(
            number=5,
            repeat=2,
            min_repeat_ms=0,
            enable_cpu_cache_flush=False,
        )

        runner_input = RunnerInput(
            builder_result.artifact_path, "cuda", [TensorInfo("float32", data_shape)]
        )

        def eval_func(rt_mod, device, evaluator_config, repeated_args):
            rt_mod = tvm.contrib.graph_executor.GraphModule(rt_mod["default"](device))

            evaluator = rt_mod.module.time_evaluator(
                func_name="run",
                dev=device,
                number=evaluator_config.number,
                repeat=evaluator_config.repeat,
                min_repeat_ms=evaluator_config.min_repeat_ms,
                f_preproc="cache_flush_cpu_non_first_arg"
                if evaluator_config.enable_cpu_cache_flush
                else "",
            )
            repeated_costs: List[List[float]] = []
            for args in repeated_args:
                profile_result = evaluator(*args)
                repeated_costs.append(profile_result.results)

            costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
            return costs

        runner = LocalRunner(
            evaluator_config=evaluator_config,
            f_run_evaluator=eval_func,
        )

        # Run the module
        (runner_future,) = runner.run([runner_input])
        runner_result = runner_future.result()
        assert runner_result is not None
        assert runner_result.run_secs is not None
        assert runner_result.error_msg is None

        for result in runner_result.run_secs:
            if isinstance(result, FloatImm):
                result = result.value
            assert isinstance(result, float)
            assert result >= 0.0

    else:
        # Without meta_schedule: go through the regular Relay executor.
        # The executor is only constructed here to verify that compilation succeeds.
        if use_trt:
            mod, config = tensorrt.partition_for_tensorrt(mod)
            with tvm.transform.PassContext(
                opt_level=3, config={"relay.ext.tensorrt.options": config}
            ):
                func = relay.create_executor(
                    mode, mod=mod, device=tvm.cuda(0), target="cuda"
                ).evaluate()
        else:
            with tvm.transform.PassContext(opt_level=3):
                func = relay.create_executor(
                    mode, mod=mod, device=tvm.cuda(0), target="cuda", params=params
                ).evaluate()


@tvm.testing.requires_cuda
@has_tensorrt_codegen
@has_tensorrt_runtime
def test_conv2d_relu():
    data_shape = (1, 1280, 14, 14)
    out_channels = 256
    kernel_size, strides, padding, dilation, groups = (1, 1), (1, 1), (0, 0, 0, 0), (1, 1), 1
    data_layout, kernel_layout = "NCHW", "OIHW"
    dtype = "float32"

    f = get_conv2d_relu(
        data_shape,
        out_channels,
        kernel_size,
        strides,
        padding,
        dilation,
        groups,
        data_layout,
        kernel_layout,
        dtype,
    )

    mod, params = testing.create_workload(f)
    verify_meta_schedule_with_tensorrt(mod, params, data_shape)


@tvm.testing.requires_cuda
@has_tensorrt_codegen
@has_tensorrt_runtime
@pytest.mark.parametrize(
    "model_name",
    ["resnet-50", "mobilenet"],
)
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("use_meta_sched", [True])
@pytest.mark.parametrize("use_trt", [True, False])
def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool):
    mod, params, input_shape, output_shape = get_network(name=model_name, batch_size=batch_size)
    verify_meta_schedule_with_tensorrt(
        mod, params, input_shape, use_meta_sched=use_meta_sched, use_trt=use_trt, mode="vm"
    )


if __name__ == "__main__":
    sys.exit(pytest.main([__file__] + sys.argv[1:]))
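
Because of the __main__ guard, the module can be invoked directly and any extra arguments are forwarded to pytest. The committed file path is not shown in this diff, so the path below is a hypothetical placeholder:

python test_meta_schedule_byoc_tensorrt.py -v  # hypothetical file name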
