
Commit e36ceb0

Author: Michalis Papapdimitriou (committed)
FP16 support for TRT
1 parent d101c50 commit e36ceb0

File tree

6 files changed: +442 −324 lines


python/tvm/relay/op/contrib/tensorrt.py

Lines changed: 14 additions & 104 deletions
@@ -202,9 +202,6 @@ def _func_wrapper(expr):
         # ops with dynamic shapes are offloaded to VM
         if check_dynamism(args, op_name):
             return False
-        if any([x.checked_type.dtype != "float32" for x in args]):
-            logger.info("Only float32 inputs are supported for TensorRT.")
-            return False
         if op_name == "multiply":
             shapes = [
                 [
@@ -325,9 +322,6 @@ def add_annotate_fn(expr):  # pylint: disable=unused-variable
     if get_tensorrt_use_implicit_batch_mode() and any([len(shape) < 1 for shape in shapes]):
         return False

-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     if (
         not get_tensorrt_use_implicit_batch_mode()
         and (isinstance(args[0], Constant) or isinstance(args[1], Constant))
@@ -347,9 +341,6 @@ def batch_norm_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.batch_norm is supported by TensorRT."""

     attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     if len(args[0].checked_type.shape) == 5 and get_tensorrt_version() < (6, 0, 1):
         logger.info("nn.batch_norm: TensorRT 6.0.1 or higher is required for rank 5 inputs.")
         return False
@@ -366,10 +357,7 @@ def batch_norm_annotate_fn(expr):  # pylint: disable=unused-variable
 def softmax_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.softmax is supported by TensorRT."""

-    attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
+    attrs = expr.attrs
     if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
         logger.info("nn.softmax: can't modify batch dimension.")
         return False
@@ -380,10 +368,7 @@ def softmax_annotate_fn(expr):  # pylint: disable=unused-variable
 def conv1d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv1d is supported by TensorRT."""

-    attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
+    attrs = expr.attrs
     if attrs.data_layout != "NCW":
         logger.info("nn.conv1d: data_layout is %s but must be NCW.", attrs.data_layout)
         return False
@@ -397,10 +382,7 @@ def conv1d_annotate_fn(expr):  # pylint: disable=unused-variable
 def conv2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv2d is supported by TensorRT."""

-    attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
+    attrs = expr.attrs
     if attrs.data_layout != "NCHW":
         logger.info("nn.conv2d: data_layout is %s but must be NCHW.", attrs.data_layout)
         return False
@@ -418,9 +400,6 @@ def dense_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if dense is supported by TensorRT."""

     args = expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     input_rank = len(args[0].checked_type.shape)
     weight_rank = len(args[1].checked_type.shape)
     if input_rank not in (2, 3, 4):
@@ -436,9 +415,6 @@ def dense_annotate_fn(expr):  # pylint: disable=unused-variable
 def batch_matmul_annotate_fn(expr):
     """Check if dense is supported by TensorRT."""

-    if any([x.checked_type.dtype != "float32" for x in expr.args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     if get_tensorrt_use_implicit_batch_mode() and len(expr.args[0].checked_type.shape) != len(
         expr.args[1].checked_type.shape
     ):
@@ -451,9 +427,6 @@ def batch_matmul_annotate_fn(expr):
 def layer_norm_annotate_fn(expr):
     """Check if dense is supported by TensorRT."""

-    if any([x.checked_type.dtype != "float32" for x in expr.args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
         logger.info("nn.layer_norm: requires use_implict_batch=False.")
         return False
@@ -465,9 +438,6 @@ def bias_add_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.bias_add is supported by TensorRT."""

     args = expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     input_rank = len(args[0].checked_type.shape)
     if input_rank not in (2, 3, 4):
         logger.info("nn.bias_add: input rank is %d but must be 2, 3 or 4.", input_rank)
@@ -479,10 +449,7 @@ def bias_add_annotate_fn(expr):  # pylint: disable=unused-variable
 def max_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.max_pool2d is supported by TensorRT."""

-    attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
+    attrs = expr.attrs
     if attrs.layout != "NCHW":
         logger.info("nn.max_pool2d: layout is %s but must be NCHW.", attrs.layout)
         return False
@@ -496,10 +463,7 @@ def max_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
 def avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.avg_pool2d is supported by TensorRT."""

-    attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
+    attrs = expr.attrs
     if attrs.layout != "NCHW":
         logger.info("nn.avg_pool2d: layout is %d but must be NCHW.", attrs.layout)
         return False
@@ -526,10 +490,7 @@ def avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
 def global_max_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.global_max_pool2d is supported by TensorRT."""

-    attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
+    attrs = expr.attrs
     if attrs.layout != "NCHW":
         logger.info("nn.global_max_pool2d: layout is %s but must be NCHW.", attrs.layout)
         return False
@@ -540,10 +501,7 @@ def global_max_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
 def global_avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.global_avg_pool2d is supported by TensorRT."""

-    attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
+    attrs = expr.attrs
     if attrs.layout != "NCHW":
         logger.info("nn.global_avg_pool2d: layout is %s but must be NCHW.", attrs.layout)
         return False
@@ -554,10 +512,7 @@ def global_avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
 def expand_dims_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if expand_dims is supported by TensorRT."""

-    attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
+    attrs = expr.attrs
     if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
         logger.info("expand_dims: can't modify batch dimension.")
         return False
@@ -568,10 +523,7 @@ def squeeze_annotate_fn(expr):  # pylint: disable=unused-variable
 def squeeze_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if squeeze is supported by TensorRT."""

-    attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
+    attrs = expr.attrs
     if not attrs.axis:
         logger.info("squeeze: must explicitly set axis.")
         return False
@@ -586,9 +538,6 @@ def concatenate_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if concatenate is supported by TensorRT."""

     attrs, args = expr.attrs, expr.args
-    if any([x.dtype != "float32" for x in args[0].checked_type.fields]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     if not get_tensorrt_use_implicit_batch_mode():
         return True
     if int(attrs.axis) == 0:
@@ -606,9 +555,6 @@ def concatenate_annotate_fn(expr):  # pylint: disable=unused-variable
 def split_annotate_fn(expr):
     """Check if split is supported by TensorRT."""

-    if any([x.checked_type.dtype != "float32" for x in expr.args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
         logger.info("split: can't modify batch dimension.")
         return False
@@ -619,10 +565,7 @@ def split_annotate_fn(expr):
 def conv2d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv2d_transpose is supported by TensorRT."""

-    attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
+    attrs = expr.attrs
     if attrs.data_layout != "NCHW":
         logger.info("nn.conv2d_transpose: data_layout is %s but must be NCHW.", attrs.data_layout)
         return False
@@ -644,10 +587,7 @@ def conv2d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
 def transpose_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if transpose is supported by TensorRT."""

-    attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
+    attrs = expr.attrs
     if get_tensorrt_use_implicit_batch_mode() and int(attrs.axes[0]) != 0:
         logger.info("transpose: can't modify batch dimension.")
         return False
@@ -658,10 +598,7 @@ def transpose_annotate_fn(expr):  # pylint: disable=unused-variable
 def layout_transform_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if layout_transform is supported by TensorRT."""

-    attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
+    attrs = expr.attrs
     if (attrs.src_layout, attrs.dst_layout) not in [
         ("NCHW", "NHWC"),
         ("NHWC", "NCHW"),
@@ -679,9 +616,6 @@ def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
 def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if reshape is supported by TensorRT."""
     attrs, args = expr.attrs, expr.args
-    if args[0].checked_type.dtype != "float32":
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     if any([x < -1 for x in map(int, attrs.newshape)]):
         logger.info("reshape: new shape dims must be explicit.")
         return False
@@ -740,9 +674,6 @@ def pad_annotate_fn(expr):  # pylint: disable=unused-variable
     pad_value = args[1]
     assert isinstance(pad_value, relay.Constant)
     pad_value = pad_value.data.numpy().item()
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     if attrs.pad_mode != "constant":
         logger.info("nn.pad: pad mode is %s but must be constant.", attrs.pad_mode)
         return False
@@ -766,9 +697,6 @@ def strided_slice_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if strided_slice is supported by TensorRT."""

     attrs, args = expr.attrs, expr.args
-    if args[0].checked_type.dtype != "float32":
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     if not trt_version_annotate_fn((5, 1, 5))(attrs, args, "strided_slice"):
         return False
     if get_tensorrt_use_implicit_batch_mode():
@@ -813,10 +741,7 @@ def strided_slice_annotate_fn(expr):  # pylint: disable=unused-variable
 def adaptive_max_pool2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.adaptive_max_pool2d is supported by TensorRT."""

-    attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
+    attrs = expr.attrs
     if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
         logger.info("nn.adaptive_max_pool2d: output size must be (1, 1).")
         return False
@@ -827,10 +752,7 @@ def adaptive_max_pool2d_annotate_fn(expr):  # pylint: disable=unused-variable
 def adaptive_avg_pool2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.adaptive_avg_pool2d is supported by TensorRT."""

-    attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
+    attrs = expr.attrs
     if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
         logger.info("nn.adaptive_avg_pool2d: output size must be (1, 1).")
         return False
@@ -842,9 +764,6 @@ def conv3d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv3d is supported by TensorRT."""

     attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d"):
         return False
     if attrs.data_layout != "NCDHW":
@@ -864,9 +783,6 @@ def max_pool_3d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.max_pool3d is supported by TensorRT."""

     attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.max_pool3d"):
         return False
     if attrs.layout != "NCDHW":
@@ -880,9 +796,6 @@ def avg_pool_3d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.avg_pool3d is supported by TensorRT."""

     attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.avg_pool3d"):
         return False
     if attrs.layout != "NCDHW":
@@ -896,9 +809,6 @@ def conv3d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv3d_transpose is supported by TensorRT."""

     attrs, args = expr.attrs, expr.args
-    if any([x.checked_type.dtype != "float32" for x in args]):
-        logger.info("Only float32 inputs are supported for TensorRT.")
-        return False
     if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d_transpose"):
         return False
     if attrs.data_layout != "NCDHW":

src/runtime/contrib/tensorrt/tensorrt_builder.cc

Lines changed: 14 additions & 9 deletions
@@ -85,8 +85,10 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode&
       shape.erase(shape.begin());
     }
     nvinfer1::Dims dims = VectorToTrtDims(shape);
-    ICHECK(TypeMatch(dtypes[i], kDLFloat, 32)) << "Only FP32 inputs are supported.";
-    auto input_tensor = network_->addInput(name.c_str(), nvinfer1::DataType::kFLOAT, dims);
+    auto tensor_dtype =
+        (dtypes[i].bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
+
+    auto input_tensor = network_->addInput(name.c_str(), tensor_dtype, dims);
     node_output_map_[nid].push_back(TensorRTOpInput(input_tensor));
     network_input_names_.push_back(name);
     entry_id_map_[name] = entry_id + i;
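Note: the new input path derives the TensorRT tensor type purely from the dtype bit width, so anything that is not 16-bit silently falls back to kFLOAT. A stricter mapping would also dispatch on the DLPack type code and reject what TensorRT cannot represent. A minimal sketch, not part of this commit (helper name hypothetical):

    #include <stdexcept>
    #include <dlpack/dlpack.h>
    #include "NvInfer.h"

    // Hypothetical helper: map a DLPack dtype onto TensorRT's DataType,
    // rejecting unsupported combinations instead of defaulting to FP32.
    nvinfer1::DataType DLDataTypeToTrt(DLDataType dtype) {
      if (dtype.code == kDLFloat && dtype.bits == 32) return nvinfer1::DataType::kFLOAT;
      if (dtype.code == kDLFloat && dtype.bits == 16) return nvinfer1::DataType::kHALF;
      if (dtype.code == kDLInt && dtype.bits == 32) return nvinfer1::DataType::kINT32;
      throw std::runtime_error("unsupported dtype for TensorRT input");
    }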
@@ -139,6 +141,7 @@ void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) {
                    << " requires weights but got a tensor.";
       }
     }
+    VLOG(1) << "INT " << input.type;
     params.inputs.push_back(input);
   }
   ICHECK(converter->variable_input_count || converter->input_types.size() == params.inputs.size())
@@ -150,6 +153,10 @@ void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) {
   // Get outputs.
   node_output_map_[nid] = {};
   for (auto out : params.outputs) {
+    // out->setType(params.inputs.at(1).weight.type);
+    // out->setType(nvinfer1::DataType::kFLOAT);
+    out->setType(nvinfer1::DataType::kHALF);
+
     node_output_map_[nid].push_back(TensorRTOpInput(out));
   }
 }
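Note: hard-coding every layer output to kHALF (the commented-out alternatives left in place suggest this is still being experimented with) pins the whole network to FP16. If the intent is merely to allow FP16 kernels where the hardware is fast at them, the usual route is the builder-config flag rather than per-tensor types; a sketch assuming the TensorRT 6+ IBuilderConfig API:

    #include "NvInfer.h"

    // Sketch: opt the whole engine into FP16 kernel selection instead of
    // forcing each tensor's type by hand (assumes TensorRT 6 or newer).
    void EnableFp16(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config) {
      if (builder->platformHasFastFp16()) {
        config->setFlag(nvinfer1::BuilderFlag::kFP16);
      }
    }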
@@ -205,18 +212,16 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
 nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr,
                                                         DLDeviceType src_device) {
   ICHECK_EQ(dptr->device.device_type, src_device);
-  ICHECK(static_cast<int>(dptr->dtype.code) == kDLFloat ||
-         static_cast<int>(dptr->dtype.code) == kDLInt);
-  const auto trt_dtype = static_cast<int>(dptr->dtype.code) == kDLFloat
-                             ? nvinfer1::DataType::kFLOAT
-                             : nvinfer1::DataType::kINT32;
+
+  const auto trt_dtype = (static_cast<int>(dptr->dtype.bits) == 16) ? nvinfer1::DataType::kHALF
+                                                                    : nvinfer1::DataType::kFLOAT;
+
   const size_t weight_bytes = GetDataSize(*dptr);
   nvinfer1::Weights weight{trt_dtype, nullptr, 0};
   size_t count = 1;
   for (tvm_index_t i = 0; i < dptr->ndim; ++i) {
     count *= dptr->shape[i];
   }
-  ICHECK_EQ(count * 4, weight_bytes);
   weight.count = count;
   weight.values = new float[count];
   ICHECK_EQ(TVMArrayCopyToBytes(const_cast<DLTensor*>(dptr), const_cast<void*>(weight.values),
@@ -250,7 +255,7 @@ void TensorRTBuilder::CleanUp() {
 #endif
   builder_->destroy();
   for (auto weight : trt_weights_) {
-    if (weight.type == nvinfer1::DataType::kFLOAT) {
+    if (static_cast<int>(weight.type) <= 1) {
       delete[] static_cast<const float*>(weight.values);
     } else {
       delete[] static_cast<const uint16_t*>(weight.values);
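Note: the static_cast<int>(weight.type) <= 1 test leans on the nvinfer1::DataType enum layout, in which kFLOAT is 0 and kHALF is 1, so it matches exactly the two float-typed weight buffers (both of which GetDLTensorAsWeights above allocates with new float[]). An equivalent, more explicit spelling:

    if (weight.type == nvinfer1::DataType::kFLOAT || weight.type == nvinfer1::DataType::kHALF) {
      delete[] static_cast<const float*>(weight.values);
    } else {
      delete[] static_cast<const uint16_t*>(weight.values);
    }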
