Commit 3a3e1e4

Fix bug on passing the new config attrs to codegen for tensorrt partition
1 parent 422ae09 · commit 3a3e1e4

3 files changed, +23 -67 lines

python/tvm/relay/op/contrib/tensorrt.py

Lines changed: 5 additions & 12 deletions
@@ -211,11 +211,7 @@ def check_dynamism(args, op_name):
         elif isinstance(arg, Tuple):
             return check_dynamism(arg.fields, op_name)
         else:
-            logger.info(
-                "Arg not supported in TensorRT for %s with type %s",
-                op_name,
-                type(arg),
-            )
+            logger.info("Arg not supported in TensorRT for %s with type %s", op_name, type(arg))
             return True
     return False

@@ -596,8 +592,8 @@ def concatenate_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if concatenate is supported by TensorRT."""

     attrs, args = expr.attrs, expr.args
-    if not is_supported_trt_dtype(args):
-        return False
+    if any([x.dtype not in supported_types for x in args[0].checked_type.fields]):
+        logger.info("Only float16 and float32 inputs are supported for TensorRT.")
     if not get_tensorrt_use_implicit_batch_mode():
         return True
     if int(attrs.axis) == 0:
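The restored gate walks the dtype of every field in the concatenate input tuple. As a standalone illustration, a minimal sketch of the same check; `supported_types` here is an assumption mirroring the module-level collection of TensorRT-supported dtypes used by the annotate functions:

# Hedged sketch of the dtype gate above; `supported_types` is assumed.
supported_types = ("float32", "float16")

def concat_inputs_supported(tuple_type_fields):
    # Every field of the concatenate input tuple must have a supported dtype.
    return all(field.dtype in supported_types for field in tuple_type_fields)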
@@ -987,11 +983,8 @@ def is_valid_subgraph(params, body):
     if len(input_batch_sizes) > 1 and len(set(input_batch_sizes)) != 1:
         logger.info("tensorrt: inputs have different batch sizes")
         return False
-    if (
-        get_tensorrt_remove_no_mac_subgraphs()
-        and not IsComputeIntensiveGraph().is_graph_compute_intensive(body)
-    ):
-        return False
+    if get_tensorrt_remove_no_mac_subgraphs():
+        return IsComputeIntensiveGraph().is_graph_compute_intensive(body)
     return True
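The is_valid_subgraph change above turns remove_no_mac_subgraphs from an extra veto into the final verdict: when the option is set, the subgraph is kept exactly when it is compute intensive. For context, a hedged usage sketch of the partitioning entry point these predicates serve; the keyword set, the (mod, config) return value, and the shapes are assumptions matching the TVM version of this commit:

import tvm
from tvm import relay
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

# Toy float32 module to partition.
x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")
w = relay.var("w", shape=(16, 3, 3, 3), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x, w], relay.nn.conv2d(x, w)))

# partition_for_tensorrt returns the partitioned module plus a config dict;
# use_fp16/use_uint8 are the new options whose forwarding this commit fixes.
mod, config = partition_for_tensorrt(mod, remove_no_mac_subgraphs=True, use_fp16=True)

with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
    lib = relay.build(mod, target="cuda")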

src/relay/backend/contrib/tensorrt/codegen.cc

Lines changed: 2 additions & 0 deletions
@@ -226,6 +226,8 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
     tensorrt_version_attr.emplace_back(tensorrt_version);
     use_implicit_batch_attr.emplace_back(use_implicit_batch);
     max_workspace_size_attr.emplace_back(max_workspace_size);
+    use_fp16_attr.emplace_back(use_fp16);
+    use_uint8_attr.emplace_back(use_uint8);
     node->SetAttr("tensorrt_version", tensorrt_version_attr);
     node->SetAttr("use_implicit_batch", use_implicit_batch_attr);
     node->SetAttr("max_workspace_size", max_workspace_size_attr);
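These two emplace_back calls are the actual fix: the use_fp16_attr and use_uint8_attr arrays existed but were never populated, so the new options presumably reached the JSON graph as empty attributes and were ignored downstream. For illustration only, a Python view of the per-partition option set these node attributes carry; key names mirror the SetAttr calls, and the default values are assumptions taken from partition_for_tensorrt at the time of this commit:

# Hypothetical mirror of the options serialized onto each TensorRT partition.
trt_options = {
    "tensorrt_version": (6, 0, 1),  # assumed fallback when TRT is not linked
    "use_implicit_batch": True,
    "max_workspace_size": 1 << 30,
    "use_fp16": False,   # newly forwarded by this commit
    "use_uint8": False,  # newly forwarded by this commit
}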

tests/python/contrib/test_tensorrt.py

Lines changed: 16 additions & 55 deletions
@@ -33,6 +33,7 @@
 from tvm.contrib import graph_executor, utils
 from tvm.runtime.vm import VirtualMachine

+from tvm.relay import Any, GlobalVar
 from tvm.relay.transform import FirstOrderGradient, InferType
 from tvm.relay.transform.transform import ToMixedPrecision

@@ -88,7 +89,7 @@ def set_func_attr(func, compile_name, symbol_name):
     return func


-def run_and_verify_func(config, target="cuda", run_module=True, data_type="float16"):
+def run_and_verify_func(config, target="cuda", run_module=True, data_type="float32"):
     """Test a Relay func by compiling, running, and comparing TVM and TRT outputs.

     Parameters
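With the default data_type back to float32, float16 coverage comes from converting a graph explicitly rather than from the helper's default. A minimal sketch of such a conversion via the ToMixedPrecision pass imported at the top of this file; treat the exact call pattern as an assumption for this TVM version, with mod any float32 Relay module:

from tvm.relay.transform import InferType
from tvm.relay.transform.transform import ToMixedPrecision

def to_float16(mod):
    # Run type inference first; the mixed-precision rewrite needs typed IR.
    mod = InferType()(mod)
    return ToMixedPrecision("float16")(mod)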
@@ -277,6 +278,9 @@ def test_tensorrt_not_compatible(run_module):
     results = func(x_data)


+@pytest.mark.xfail(
+    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
+)
 def test_tensorrt_serialize_graph_executor(run_module):
     import mxnet as mx
     from mxnet.gluon.model_zoo.vision import get_model
@@ -331,6 +335,9 @@ def load_graph():
     assert_result_dict_holds(result_dict)


+@pytest.mark.xfail(
+    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
+)
 def test_tensorrt_serialize_vm(run_module):
     import mxnet as mx
     from mxnet.gluon.model_zoo.vision import get_model
@@ -473,12 +480,7 @@ def get_graph(x_shape=(1, 8, 8, 32), k_shape=(3, 3, 32, 16)):
         x = relay.var("x", shape=(x_shape), dtype="float32")
         kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
         out = relay.nn.conv2d(
-            x,
-            kernel,
-            channels=16,
-            kernel_size=(3, 3),
-            data_layout="NHWC",
-            kernel_layout="HWIO",
+            x, kernel, channels=16, kernel_size=(3, 3), data_layout="NHWC", kernel_layout="HWIO"
         )
         f = relay.Function([x, kernel], out)
         return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
@@ -602,13 +604,7 @@ def get_graph(
                 count_include_pad=count_include_pad,
             )
         else:
-            out = op(
-                x,
-                pool_size=pool_size,
-                strides=strides,
-                padding=padding,
-                ceil_mode=ceil_mode,
-            )
+            out = op(x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode)
         f = relay.Function([x], out)
         return f, {"x": x_shape}, []

@@ -726,11 +722,7 @@ def get_graph(x_shape, indices_or_sections, axis):

 def test_conv2d_transpose(run_module):
     def get_graph(
-        x_shape=(1, 32, 8, 8),
-        k_shape=(32, 16, 3, 3),
-        groups=1,
-        padding=(0, 0),
-        strides=(1, 1),
+        x_shape=(1, 32, 8, 8), k_shape=(32, 16, 3, 3), groups=1, padding=(0, 0), strides=(1, 1)
     ):
         x = relay.var("x", shape=(x_shape), dtype="float32")
         kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
@@ -1009,24 +1001,10 @@ def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5):
         gamma = relay.var("gamma", shape=(param_shape), dtype="float32")
         beta = relay.var("beta", shape=(param_shape), dtype="float32")
         out = relay.nn.layer_norm(
-            x,
-            gamma=gamma,
-            beta=beta,
-            axis=axis,
-            epsilon=epsilon,
-            center=True,
-            scale=True,
+            x, gamma=gamma, beta=beta, axis=axis, epsilon=epsilon, center=True, scale=True
         )
         f = relay.Function([x, gamma, beta], out)
-        return (
-            f,
-            {
-                "x": x_shape,
-                "beta": param_shape,
-                "gamma": param_shape,
-            },
-            ["beta", "gamma"],
-        )
+        return (f, {"x": x_shape, "beta": param_shape, "gamma": param_shape}, ["beta", "gamma"])

     run_and_verify_func(get_graph((1, 32, 8, 8), (32,)), run_module=run_module)
     run_and_verify_func(
@@ -1170,20 +1148,9 @@ def test_strided_slice(run_module):
     def get_graph(x_shape, begin, end, strides=None, slice_mode="size"):
         x = relay.var("x", shape=(x_shape), dtype="float32")
         if strides:
-            out = relay.strided_slice(
-                x,
-                begin,
-                end,
-                strides,
-                slice_mode=slice_mode,
-            )
+            out = relay.strided_slice(x, begin, end, strides, slice_mode=slice_mode)
         else:
-            out = relay.strided_slice(
-                x,
-                begin,
-                end,
-                slice_mode=slice_mode,
-            )
+            out = relay.strided_slice(x, begin, end, slice_mode=slice_mode)
         f = relay.Function([x], out)
         return f, {"x": x_shape}, []

@@ -1292,13 +1259,7 @@ def get_graph(
                 count_include_pad=count_include_pad,
             )
         else:
-            out = op(
-                x,
-                pool_size=pool_size,
-                strides=strides,
-                padding=padding,
-                ceil_mode=ceil_mode,
-            )
+            out = op(x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode)
         f = relay.Function([x], out)
         return f, {"x": x_shape}, []
