Commit 4bad153

Rebase to pass CI and reflect suggestions
1 parent: e8e0b5c

File tree: 3 files changed (+22, -25 lines)


python/tvm/meta_schedule/testing/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,3 +17,4 @@
 """Testing utilities in meta schedule"""
 from .local_rpc import LocalRPC
 from .relay_workload import get_network
+from .byoc_trt import relay_build_with_tensorrt
python/tvm/meta_schedule/testing/byoc_trt.py

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+import tvm
+from tvm.runtime import Module
+from tvm.target import Target
+from tvm.meta_schedule.builder import BuilderInput, LocalBuilder, BuilderResult
+from typing import List
+
+
+def relay_build_with_tensorrt(
+    mod: Module,
+    target: Target,
+    params: dict,
+) -> List[BuilderResult]:
+    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
+
+    mod, config = partition_for_tensorrt(mod, params)
+    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
+        return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params)
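For orientation, a minimal usage sketch (not part of this commit): the new helper is passed to a LocalBuilder as its build function, exactly as the test change below does; wrapping a module in BuilderInput with a CUDA target is an assumption added here for illustration.

# Usage sketch only. `mod` and `params` (a Relay workload and its parameters)
# are assumed to come from elsewhere, e.g. tvm.relay.testing.
from tvm.meta_schedule.builder import BuilderInput, LocalBuilder
from tvm.meta_schedule.testing import relay_build_with_tensorrt
from tvm.target import Target

builder = LocalBuilder(f_build=relay_build_with_tensorrt)
# Assumed invocation: build one candidate for a CUDA target.
# results = builder.build([BuilderInput(mod, Target("cuda"), params)])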

tests/python/unittest/test_meta_schedule_byoc_tensorrt.py

Lines changed: 4 additions & 25 deletions
@@ -25,13 +25,8 @@
 from tvm.relay import testing
 from tvm.relay.op.contrib import tensorrt
 import numpy as np
-from typing import List, Tuple
-
-# from tvm import script
-# from tvm._ffi import register_func
-# from tvm.runtime import Module
+from typing import List
 from tvm._ffi import register_func
-from tvm.relay.testing.init import Initializer
 from tvm.target import Target
 from tvm.runtime import Module
 from tvm.meta_schedule.arg_info import TensorInfo
@@ -94,25 +89,11 @@ def verify_meta_schedule_with_tensorrt(
 ):
     if use_meta_sched:
         # With meta_schedule
-        dev = "nvidia/geforce-rtx-2080"
+        dev = "cuda"
 
         # Build
         if use_trt:
-
-            def relay_build_with_tensorrt(
-                mod: Module,
-                target: Target,
-                params: dict,
-            ) -> List[BuilderResult]:
-                from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
-
-                mod, config = partition_for_tensorrt(mod, params)
-                with tvm.transform.PassContext(
-                    opt_level=3, config={"relay.ext.tensorrt.options": config}
-                ):
-                    return tvm.relay.build_module._build_module_no_factory(
-                        mod, "cuda", "llvm", params
-                    )
+            from tvm.meta_schedule.testing import relay_build_with_tensorrt
 
             builder = LocalBuilder(f_build=relay_build_with_tensorrt)
         else:
@@ -122,7 +103,6 @@ def relay_build_without_tensorrt(
                 target: Target,
                 params: dict,
             ) -> List[BuilderResult]:
-                # @Sung: Weird. Cannot pass keyword arg
                 return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params)
 
             builder = LocalBuilder(f_build=relay_build_without_tensorrt)
@@ -235,7 +215,7 @@ def test_conv2d_relu():
     "model_name",
     ["resnet-50", "mobilenet"],
 )
-@pytest.mark.parametrize("batch_size", [1, 8, 16])
+@pytest.mark.parametrize("batch_size", [1])
 @pytest.mark.parametrize("use_meta_sched", [True])
 @pytest.mark.parametrize("use_trt", [True, False])
 def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool):
@@ -246,6 +226,5 @@ def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool):
 )
 
 
-# @sunggg: memory verification error at test_relay_model("resnet-50", 1, use_meta_sched=False, use_trt=True)
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
