diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index 992112139842..760383d9d209 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -28,6 +28,20 @@ from tvm.relay.expr_functor import ExprMutator, ExprVisitor logger = logging.getLogger("TensorRT") +supported_types = ["float32", "float16"] + + +def is_supported_trt_dtype(args): + """Check if the TensorRT BYOC support input tensor dtype. + Returns + ------- + ret: bool + True if supported, False if not. + """ + if any([x.checked_type.dtype in supported_types for x in args]): + logger.info("Only float32 and float16 inputs are supported for TensorRT BYOC.") + return True + return False def is_tensorrt_runtime_enabled(): @@ -87,6 +101,8 @@ def partition_for_tensorrt( use_implicit_batch=True, remove_no_mac_subgraphs=False, max_workspace_size=1 << 30, + use_fp16=False, + use_uint8=False, ): """Partition the graph greedily offloading supported operators to TensorRT. @@ -110,6 +126,13 @@ def partition_for_tensorrt( max_workspace_size : Optional[int] How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation. See TensorRT documentation for more info. + use_fp16: Optional[bool] + Allows, TRT to automatically convert FP32 inputs to FP16. Also, it is required to be enabled + if FP16 inputs tensors and weights are used. + Note that TensorRT will still choose a higher-precision kernel if it results in overall + lower runtime, or if no low-precision implementation exists. + use_uint8: Optional[bool] + Allows, TRT to automatically convert FP32 inputs to UINT8. Returns ------- mod_and_config : Tuple[Module, Dict[str, Any]] @@ -120,6 +143,8 @@ def partition_for_tensorrt( "use_implicit_batch": use_implicit_batch, "max_workspace_size": max_workspace_size, "remove_no_mac_subgraphs": remove_no_mac_subgraphs, + "use_fp16": use_fp16, + "use_uint8": use_uint8, } if version: assert isinstance(version, tuple) and len(version) == 3 @@ -186,11 +211,7 @@ def check_dynamism(args, op_name): elif isinstance(arg, Tuple): return check_dynamism(arg.fields, op_name) else: - logger.info( - "Arg not supported in TensorRT for %s with type %s", - op_name, - type(arg), - ) + logger.info("Arg not supported in TensorRT for %s with type %s", op_name, type(arg)) return True return False @@ -200,10 +221,9 @@ def _register_external_op_helper_with_checker(op_name, checker): def _func_wrapper(expr): attrs, args = expr.attrs, expr.args # ops with dynamic shapes are offloaded to VM - if check_dynamism(args, op_name): + if not is_supported_trt_dtype(args): return False - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if check_dynamism(args, op_name): return False if op_name == "multiply": shapes = [ @@ -315,7 +335,8 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable """Check if add is supported by TensorRT.""" args = expr.args - + if not is_supported_trt_dtype(args): + return False shapes = [ [int(x) if not isinstance(x, tvm.tir.expr.Any) else -1 for x in arg.checked_type.shape] for arg in args @@ -325,9 +346,6 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable if get_tensorrt_use_implicit_batch_mode() and any([len(shape) < 1 for shape in shapes]): return False - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") - return False if ( not get_tensorrt_use_implicit_batch_mode() and 
(isinstance(args[0], Constant) or isinstance(args[1], Constant)) @@ -347,8 +365,7 @@ def batch_norm_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.batch_norm is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if len(args[0].checked_type.shape) == 5 and get_tensorrt_version() < (6, 0, 1): logger.info("nn.batch_norm: TensorRT 6.0.1 or higher is required for rank 5 inputs.") @@ -367,8 +384,7 @@ def softmax_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.softmax is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0: logger.info("nn.softmax: can't modify batch dimension.") @@ -381,8 +397,7 @@ def conv1d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.conv1d is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if attrs.data_layout != "NCW": logger.info("nn.conv1d: data_layout is %s but must be NCW.", attrs.data_layout) @@ -398,8 +413,7 @@ def conv2d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.conv2d is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if attrs.data_layout != "NCHW": logger.info("nn.conv2d: data_layout is %s but must be NCHW.", attrs.data_layout) @@ -418,8 +432,7 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable """Check if dense is supported by TensorRT.""" args = expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False input_rank = len(args[0].checked_type.shape) weight_rank = len(args[1].checked_type.shape) @@ -436,8 +449,8 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable def batch_matmul_annotate_fn(expr): """Check if dense is supported by TensorRT.""" - if any([x.checked_type.dtype != "float32" for x in expr.args]): - logger.info("Only float32 inputs are supported for TensorRT.") + args = expr.args + if not is_supported_trt_dtype(args): return False if get_tensorrt_use_implicit_batch_mode() and len(expr.args[0].checked_type.shape) != len( expr.args[1].checked_type.shape @@ -451,8 +464,8 @@ def batch_matmul_annotate_fn(expr): def layer_norm_annotate_fn(expr): """Check if dense is supported by TensorRT.""" - if any([x.checked_type.dtype != "float32" for x in expr.args]): - logger.info("Only float32 inputs are supported for TensorRT.") + args = expr.args + if not is_supported_trt_dtype(args): return False if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0: logger.info("nn.layer_norm: requires use_implict_batch=False.") @@ -465,8 +478,7 @@ def bias_add_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.bias_add is supported by TensorRT.""" args = expr.args - if 
any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False input_rank = len(args[0].checked_type.shape) if input_rank not in (2, 3, 4): @@ -480,8 +492,7 @@ def max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.max_pool2d is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if attrs.layout != "NCHW": logger.info("nn.max_pool2d: layout is %s but must be NCHW.", attrs.layout) @@ -497,8 +508,7 @@ def avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.avg_pool2d is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if attrs.layout != "NCHW": logger.info("nn.avg_pool2d: layout is %d but must be NCHW.", attrs.layout) @@ -527,8 +537,7 @@ def global_max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.global_max_pool2d is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if attrs.layout != "NCHW": logger.info("nn.global_max_pool2d: layout is %s but must be NCHW.", attrs.layout) @@ -541,8 +550,7 @@ def global_avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.global_avg_pool2d is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if attrs.layout != "NCHW": logger.info("nn.global_avg_pool2d: layout is %s but must be NCHW.", attrs.layout) @@ -555,8 +563,7 @@ def expand_dims_annotate_fn(expr): # pylint: disable=unused-variable """Check if expand_dims is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0: logger.info("expand_dims: can't modify batch dimension.") @@ -569,8 +576,7 @@ def squeeze_annotate_fn(expr): # pylint: disable=unused-variable """Check if squeeze is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if not attrs.axis: logger.info("squeeze: must explicitly set axis.") @@ -586,9 +592,8 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable """Check if concatenate is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.dtype != "float32" for x in args[0].checked_type.fields]): - logger.info("Only float32 inputs are supported for TensorRT.") - return False + if any([x.dtype not in supported_types for x in args[0].checked_type.fields]): + logger.info("Only float16 and float32 inputs are supported for TensorRT.") if not 
get_tensorrt_use_implicit_batch_mode(): return True if int(attrs.axis) == 0: @@ -606,8 +611,8 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable def split_annotate_fn(expr): """Check if split is supported by TensorRT.""" - if any([x.checked_type.dtype != "float32" for x in expr.args]): - logger.info("Only float32 inputs are supported for TensorRT.") + args = expr.args + if not is_supported_trt_dtype(args): return False if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0: logger.info("split: can't modify batch dimension.") @@ -620,8 +625,7 @@ def conv2d_transpose_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.conv2d_transpose is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if attrs.data_layout != "NCHW": logger.info("nn.conv2d_transpose: data_layout is %s but must be NCHW.", attrs.data_layout) @@ -645,8 +649,7 @@ def transpose_annotate_fn(expr): # pylint: disable=unused-variable """Check if transpose is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if get_tensorrt_use_implicit_batch_mode() and int(attrs.axes[0]) != 0: logger.info("transpose: can't modify batch dimension.") @@ -659,8 +662,7 @@ def layout_transform_annotate_fn(expr): # pylint: disable=unused-variable """Check if layout_transform is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if (attrs.src_layout, attrs.dst_layout) not in [ ("NCHW", "NHWC"), @@ -679,8 +681,7 @@ def layout_transform_annotate_fn(expr): # pylint: disable=unused-variable def reshape_annotate_fn(expr): # pylint: disable=unused-variable """Check if reshape is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if args[0].checked_type.dtype != "float32": - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if any([x < -1 for x in map(int, attrs.newshape)]): logger.info("reshape: new shape dims must be explicit.") @@ -737,12 +738,11 @@ def pad_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.pad is supported by TensorRT.""" attrs, args = expr.attrs, expr.args + if not is_supported_trt_dtype(args): + return False pad_value = args[1] assert isinstance(pad_value, relay.Constant) pad_value = pad_value.data.numpy().item() - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") - return False if attrs.pad_mode != "constant": logger.info("nn.pad: pad mode is %s but must be constant.", attrs.pad_mode) return False @@ -766,8 +766,7 @@ def strided_slice_annotate_fn(expr): # pylint: disable=unused-variable """Check if strided_slice is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if args[0].checked_type.dtype != "float32": - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if not trt_version_annotate_fn((5, 1, 5))(attrs, args, "strided_slice"): return False @@ -814,8 +813,7 @@ def 
adaptive_max_pool2d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.adaptive_max_pool2d is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]): logger.info("nn.adaptive_max_pool2d: output size must be (1, 1).") @@ -828,8 +826,7 @@ def adaptive_avg_pool2d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.adaptive_avg_pool2d is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]): logger.info("nn.adaptive_avg_pool2d: output size must be (1, 1).") @@ -842,8 +839,7 @@ def conv3d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.conv3d is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d"): return False @@ -864,8 +860,7 @@ def max_pool_3d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.max_pool3d is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.max_pool3d"): return False @@ -880,8 +875,7 @@ def avg_pool_3d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.avg_pool3d is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.avg_pool3d"): return False @@ -896,8 +890,7 @@ def conv3d_transpose_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.conv3d_transpose is supported by TensorRT.""" attrs, args = expr.attrs, expr.args - if any([x.checked_type.dtype != "float32" for x in args]): - logger.info("Only float32 inputs are supported for TensorRT.") + if not is_supported_trt_dtype(args): return False if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d_transpose"): return False @@ -990,11 +983,8 @@ def is_valid_subgraph(params, body): if len(input_batch_sizes) > 1 and len(set(input_batch_sizes)) != 1: logger.info("tensorrt: inputs have different batch sizes") return False - if ( - get_tensorrt_remove_no_mac_subgraphs() - and not IsComputeIntensiveGraph().is_graph_compute_intensive(body) - ): - return False + if get_tensorrt_remove_no_mac_subgraphs(): + return IsComputeIntensiveGraph().is_graph_compute_intensive(body) return True diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index d83a9003229c..431be8ed3dc3 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -46,6 +46,8 @@ 
struct TensorRTCompilerConfigNode : public tvm::AttrsNodetensorrt_version[2])}; std::vector use_implicit_batch = {std::to_string(cfg.value()->use_implicit_batch)}; std::vector max_workspace_size = {std::to_string(cfg.value()->max_workspace_size)}; - std::vector tensorrt_version_attr, use_implicit_batch_attr, max_workspace_size_attr; + std::vector use_fp16 = {std::to_string(cfg.value()->use_fp16)}; + std::vector use_uint8 = {std::to_string(cfg.value()->use_uint8)}; + std::vector tensorrt_version_attr, use_implicit_batch_attr, max_workspace_size_attr, + use_fp16_attr, use_uint8_attr; tensorrt_version_attr.emplace_back(tensorrt_version); use_implicit_batch_attr.emplace_back(use_implicit_batch); max_workspace_size_attr.emplace_back(max_workspace_size); + use_fp16_attr.emplace_back(use_fp16); + use_uint8_attr.emplace_back(use_uint8); node->SetAttr("tensorrt_version", tensorrt_version_attr); node->SetAttr("use_implicit_batch", use_implicit_batch_attr); node->SetAttr("max_workspace_size", max_workspace_size_attr); + node->SetAttr("use_fp16", use_fp16_attr); + node->SetAttr("use_uint8", use_uint8_attr); } }; diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc index c60928e95db4..4f196265b51b 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -85,8 +85,13 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode& shape.erase(shape.begin()); } nvinfer1::Dims dims = VectorToTrtDims(shape); - ICHECK(TypeMatch(dtypes[i], kDLFloat, 32)) << "Only FP32 inputs are supported."; - auto input_tensor = network_->addInput(name.c_str(), nvinfer1::DataType::kFLOAT, dims); + ICHECK((dtypes[i].bits != 16 || dtypes[i].bits != 32)) + << "Invalid input Tensor type. Float16 and Float32 are supported"; + + auto tensor_dtype = + (dtypes[i].bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT; + + auto input_tensor = network_->addInput(name.c_str(), tensor_dtype, dims); node_output_map_[nid].push_back(TensorRTOpInput(input_tensor)); network_input_names_.push_back(name); entry_id_map_[name] = entry_id + i; @@ -141,8 +146,6 @@ void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) { } params.inputs.push_back(input); } - ICHECK(converter->variable_input_count || converter->input_types.size() == params.inputs.size()) - << "Op expected a different number of inputs."; // Convert op to TRT. converter->Convert(¶ms); @@ -150,6 +153,11 @@ void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) { // Get outputs. node_output_map_[nid] = {}; for (auto out : params.outputs) { + auto out_type = params.inputs.at(1).weight.type == params.inputs.at(0).tensor->getType() + ? params.inputs.at(0).tensor->getType() + : params.inputs.at(1).weight.type; + out->setType(out_type); + node_output_map_[nid].push_back(TensorRTOpInput(out)); } } @@ -205,18 +213,17 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() { nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr, DLDeviceType src_device) { ICHECK_EQ(dptr->device.device_type, src_device); - ICHECK(static_cast(dptr->dtype.code) == kDLFloat || - static_cast(dptr->dtype.code) == kDLInt); - const auto trt_dtype = static_cast(dptr->dtype.code) == kDLFloat - ? nvinfer1::DataType::kFLOAT - : nvinfer1::DataType::kINT32; + ICHECK((dptr->dtype.bits != 16 || dptr->dtype.bits != 32)) + << "Invalid input Tensor type. 
Float16 and Float32 are supported"; + const auto trt_dtype = (static_cast(dptr->dtype.bits) == 16) ? nvinfer1::DataType::kHALF + : nvinfer1::DataType::kFLOAT; + const size_t weight_bytes = GetDataSize(*dptr); nvinfer1::Weights weight{trt_dtype, nullptr, 0}; size_t count = 1; for (tvm_index_t i = 0; i < dptr->ndim; ++i) { count *= dptr->shape[i]; } - ICHECK_EQ(count * 4, weight_bytes); weight.count = count; weight.values = new float[count]; ICHECK_EQ(TVMArrayCopyToBytes(const_cast(dptr), const_cast(weight.values), @@ -250,7 +257,7 @@ void TensorRTBuilder::CleanUp() { #endif builder_->destroy(); for (auto weight : trt_weights_) { - if (weight.type == nvinfer1::DataType::kFLOAT) { + if (weight.type == nvinfer1::DataType::kFLOAT || weight.type == nvinfer1::DataType::kHALF) { delete[] static_cast(weight.values); } else { delete[] static_cast(weight.values); diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.h b/src/runtime/contrib/tensorrt/tensorrt_builder.h index bf74630bce7f..13a118340e11 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.h +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.h @@ -68,7 +68,7 @@ class TensorRTBuilder { * \param logger TensorRT logger to use for errors and warnings. * \param max_workspace_size Workspace size parameter for TensorRT engine build phase. * \param use_implicit_batch Whether to use implicit batch mode (default) - * \param use_fp16 Whether to use implicit batch mode (default) + * \param use_fp16 Whether to automatically convert a model to fp16 * \param batch_size If use_implicit_batch, */ TensorRTBuilder(TensorRTLogger* logger, const std::vector& data_entry, diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index a27fe1114af9..2c5f293bc431 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -49,6 +49,7 @@ nvinfer1::ITensor* TensorRTOpConverter::Reshape(TensorRTOpConverterParams* param auto layer = params->network->addShuffle(*input); ICHECK(layer != nullptr); layer->setReshapeDimensions(VectorToTrtDims(new_shape)); + layer->setOutputType(0, input->getType()); return layer->getOutput(0); } @@ -99,7 +100,8 @@ nvinfer1::ITensor* TensorRTOpConverter::CreateScalar( std::fill_n(dims.d, dims.nbDims, 1); float* values = new float[1]; values[0] = value; - nvinfer1::Weights weights{nvinfer1::DataType::kFLOAT, static_cast(values), 1}; + const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type; + nvinfer1::Weights weights{weight_type, static_cast(values), 1}; params->trt_weights->push_back(weights); return params->network->addConstant(dims, weights)->getOutput(0); } @@ -252,7 +254,9 @@ class Conv1DOpConverter : public TensorRTOpConverter { input_tensor = shuffle_layer->getOutput(0); const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], 1); - nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type; + + nvinfer1::Weights bias{weight_type, nullptr, 0}; auto conv_layer = params->network->addConvolution(*input_tensor, channels, kernel_size, params->inputs.at(1).weight, bias); @@ -313,7 +317,8 @@ class Conv2DOpConverter : public TensorRTOpConverter { #endif const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], weight_shape[3]); - nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type; + nvinfer1::Weights bias{weight_type, nullptr, 0}; auto conv_layer = 
params->network->addConvolution(*input_tensor, channels, kernel_size, params->inputs.at(1).weight, bias); ICHECK(conv_layer != nullptr); @@ -361,7 +366,8 @@ class Conv3DOpConverter : public TensorRTOpConverter { const int num_outputs = std::stoi(params->node.GetAttr>("channels")[0]); const auto kernel_size = nvinfer1::Dims3(weight_shape[2], weight_shape[3], weight_shape[4]); - nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type; + nvinfer1::Weights bias{weight_type, nullptr, 0}; auto conv_layer = params->network->addConvolutionNd(*input_tensor, num_outputs, kernel_size, params->inputs.at(1).weight, bias); ICHECK(conv_layer != nullptr); @@ -404,7 +410,8 @@ class DenseOpConverter : public TensorRTOpConverter { // Weights are in KC format. ICHECK_EQ(params->inputs.at(1).weight_shape.size(), 2); const int num_units = params->inputs.at(1).weight_shape[0]; - nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type; + nvinfer1::Weights bias{weight_type, nullptr, 0}; nvinfer1::IFullyConnectedLayer* fc_layer = params->network->addFullyConnected( *input_tensor, num_units, params->inputs.at(1).weight, bias); ICHECK(fc_layer != nullptr); @@ -466,12 +473,15 @@ class BatchNormOpConverter : public TensorRTOpConverter { } void* weight_scale_ptr = new float[gamma.count]; - nvinfer1::Weights weight_scale{nvinfer1::DataType::kFLOAT, weight_scale_ptr, gamma.count}; + const nvinfer1::DataType weight_type_scale = params->inputs.at(1).weight.type; + nvinfer1::Weights weight_scale{weight_type_scale, weight_scale_ptr, gamma.count}; params->trt_weights->push_back(weight_scale); void* weight_shift_ptr = new float[gamma.count]; - nvinfer1::Weights weight_shift{nvinfer1::DataType::kFLOAT, weight_shift_ptr, gamma.count}; + const nvinfer1::DataType weight_type_shift = params->inputs.at(2).weight.type; + nvinfer1::Weights weight_shift{weight_type_shift, weight_shift_ptr, gamma.count}; params->trt_weights->push_back(weight_shift); - nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, 0}; + const nvinfer1::DataType weight_type_power = params->inputs.at(3).weight.type; + nvinfer1::Weights power{weight_type_power, nullptr, 0}; // fill in the content of weights for the Scale layer const float* gamma_ptr = reinterpret_cast(gamma.values); @@ -911,8 +921,10 @@ class BiasAddOpConverter : public TensorRTOpConverter { input_tensor = Reshape(params, input_tensor, new_shape); } - nvinfer1::Weights shift{nvinfer1::DataType::kFLOAT, nullptr, 0}; - nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, 0}; + const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type; + + nvinfer1::Weights shift{weight_type, nullptr, 0}; + nvinfer1::Weights power{weight_type, nullptr, 0}; nvinfer1::IScaleLayer* scale_layer = params->network->addScale( *input_tensor, nvinfer1::ScaleMode::kCHANNEL, params->inputs.at(1).weight, shift, power); ICHECK(scale_layer != nullptr); @@ -962,7 +974,8 @@ class Conv2DTransposeOpConverter : public TensorRTOpConverter { const int num_outputs = std::stoi(params->node.GetAttr>("channels")[0]); const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], weight_shape[3]); - nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type; + nvinfer1::Weights bias{weight_type, nullptr, 0}; auto deconv_layer = params->network->addDeconvolution(*input_tensor, 
num_outputs, kernel_size, params->inputs.at(1).weight, bias); ICHECK(deconv_layer != nullptr); @@ -1020,7 +1033,8 @@ class Conv3DTransposeOpConverter : public TensorRTOpConverter { const int num_outputs = std::stoi(params->node.GetAttr>("channels")[0]); const auto kernel_size = nvinfer1::Dims3(weight_shape[2], weight_shape[3], weight_shape[4]); - nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type; + nvinfer1::Weights bias{weight_type, nullptr, 0}; auto deconv_layer = params->network->addDeconvolutionNd(*input_tensor, num_outputs, kernel_size, params->inputs.at(1).weight, bias); ICHECK(deconv_layer != nullptr); diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.h b/src/runtime/contrib/tensorrt/tensorrt_ops.h index e9871d42146c..b71dec00c9be 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.h +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.h @@ -76,7 +76,7 @@ struct TensorRTOpInput { std::vector weight_shape; explicit TensorRTOpInput(nvinfer1::ITensor* tensor) - : tensor(tensor), weight({nvinfer1::DataType::kFLOAT, nullptr, 0}), type(kTensor) {} + : tensor(tensor), weight({tensor->getType(), nullptr, 0}), type(kTensor) {} TensorRTOpInput(nvinfer1::Weights weight, const std::vector& shape) : tensor(nullptr), weight(weight), type(kWeight), weight_shape(shape) {} }; diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index a5779f739dac..3f4fa9da9820 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -72,7 +72,8 @@ class TensorRTRuntime : public JSONRuntimeBase { use_implicit_batch_(true), max_workspace_size_(size_t(1) << 30), max_batch_size_(-1), - multi_engine_mode_(false) { + multi_engine_mode_(false), + use_fp16_(false) { const bool use_int8 = dmlc::GetEnv("TVM_TENSORRT_USE_INT8", false); multi_engine_mode_ = dmlc::GetEnv("TVM_TENSORRT_MULTI_ENGINE", false); num_calibration_batches_remaining_ = dmlc::GetEnv("TENSORRT_NUM_CALI_INT8", 0); @@ -304,7 +305,7 @@ class TensorRTRuntime : public JSONRuntimeBase { } void BuildEngineFromJson(int batch_size) { - const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false); + const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false) || use_fp16_; TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_, use_implicit_batch_, use_fp16, batch_size, calibrator_.get()); for (size_t i = 0; i < input_nodes_.size(); ++i) { @@ -492,6 +493,9 @@ class TensorRTRuntime : public JSONRuntimeBase { * encountered. Multi-engine mode should give better performance, at a cost of higher memory usage * and more time spent building engines. */ bool multi_engine_mode_; + + /*! \brief Use auto-conversion to fp16 */ + bool use_fp16_; }; runtime::Module TensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 81e3cc068905..607b222bc91d 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -14,26 +14,36 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import tvm.testing +from curses import tparm +from unittest import result import numpy as np import time import pytest import itertools +import pdb + import tvm +from tvm.relay.op.contrib.bnns import dtype_is_supported import tvm.relay.testing from tvm import relay, runtime from tvm.relay.op.contrib import tensorrt from tvm.contrib import graph_executor, utils from tvm.runtime.vm import VirtualMachine -from tvm.relay import Any, GlobalVar, transform + +from tvm.relay import Any, GlobalVar +from tvm.relay.transform import FirstOrderGradient, InferType +from tvm.relay.transform.transform import ToMixedPrecision + from tvm.relay.expr_functor import ExprVisitor from typing import Dict, Tuple, Union from tvm.contrib.download import download from tvm.relay.op.contrib import tensorrt -import tvm.testing +SUPPORTED_DTYPES = ["float16", "float32"] has_tensorrt_codegen = pytest.mark.skipif( not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available" @@ -60,12 +70,15 @@ def vmobj_to_list(o): raise RuntimeError("Unknown object type: %s" % type(o)) -def assert_result_dict_holds(result_dict): +def assert_result_dict_holds(result_dict, dtype="float16"): for k1, k2 in itertools.combinations(result_dict, 2): res1 = vmobj_to_list(result_dict[k1]) res2 = vmobj_to_list(result_dict[k2]) for r1, r2 in zip(res1, res2): - tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3) + if dtype == "float16": + tvm.testing.assert_allclose(r1, r2, rtol=1e-1, atol=1e-1) + else: + tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3) def set_func_attr(func, compile_name, symbol_name): @@ -76,7 +89,7 @@ def set_func_attr(func, compile_name, symbol_name): return func -def run_and_verify_func(config, target="cuda", run_module=True): +def run_and_verify_func(config, target="cuda", run_module=True, data_type="float32"): """Test a Relay func by compiling, running, and comparing TVM and TRT outputs. Parameters @@ -88,40 +101,49 @@ def run_and_verify_func(config, target="cuda", run_module=True): run_module: bool If True, the built module will be run after being compiled. 
+ + data_type: str + Check between single and double floating precision """ f, input_shapes, is_param = config - params = {x: np.random.uniform(-1, 1, input_shapes[x]).astype(np.float32) for x in is_param} + params = { + x: np.random.uniform(-1, 1, input_shapes[x]).astype(dtype=data_type) for x in is_param + } input_dict = { - k: np.random.uniform(-1, 1, v).astype(np.float32) + k: np.random.uniform(-1, 1, v).astype(dtype=data_type) for k, v in input_shapes.items() if k not in is_param } dev = tvm.device(target) result_dict = dict() - for mode in ["graph", "vm"]: - for use_trt in [False, True]: - mod = tvm.IRModule() - mod["main"] = f - result_key = mode + ("_trt" if use_trt else "") - if use_trt: - mod, config = tensorrt.partition_for_tensorrt(mod, params) - with tvm.transform.PassContext( - opt_level=3, config={"relay.ext.tensorrt.options": config} - ): - func = relay.create_executor( - mode, mod=mod, device=dev, target=target - ).evaluate() - else: - with tvm.transform.PassContext(opt_level=3): - func = relay.create_executor( - mode, mod=mod, device=dev, target=target - ).evaluate() - if run_module: - result_dict[result_key] = func(**input_dict, **params) + for mode in ["vm", "graph"]: + for mode in ["graph"]: + for use_trt in [True, False]: + mod = tvm.IRModule() + mod["main"] = f + result_key = mode + ("_trt" if use_trt else "") + if use_trt: + mod = relay.transform.InferType()(mod) + mod, config = tensorrt.partition_for_tensorrt(mod, params) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + func = relay.create_executor( + mode, mod=mod, device=dev, target=target + ).evaluate() + else: + mod = relay.transform.InferType()(mod) + with tvm.transform.PassContext(opt_level=3): + func = relay.create_executor( + mode, mod=mod, device=dev, target=target + ).evaluate() - if run_module: - assert_result_dict_holds(result_dict) + if run_module: + result_dict[result_key] = func(**input_dict, **params) + + if run_module: + assert_result_dict_holds(result_dict, data_type) def run_and_verify_model(model, run_module): @@ -174,45 +196,47 @@ def compile_and_run(mod, params, i_data, mode="vm", use_trt=True): def test_tensorrt_simple(run_module): - dtype = "float32" - xshape = (1, 3, 2, 2) - yshape = (1, 3, 1, 1) - zshape = (1, 1, 1, 1) - x = relay.var("x", shape=(xshape), dtype=dtype) - y = relay.var("y", shape=(yshape), dtype=dtype) - z = relay.var("z", shape=(zshape), dtype=dtype) - w = z * (x + y) - out = relay.nn.relu(w) - f = relay.Function([x, y, z], out) - - x_data = np.random.uniform(-1, 1, xshape).astype(dtype) - y_data = np.random.uniform(-1, 1, yshape).astype(dtype) - z_data = np.random.uniform(-1, 1, zshape).astype(dtype) + for dtype in SUPPORTED_DTYPES: + xshape = (1, 3, 2, 2) + yshape = (1, 3, 1, 1) + zshape = (1, 1, 1, 1) + x = relay.var("x", shape=(xshape), dtype=dtype) + y = relay.var("y", shape=(yshape), dtype=dtype) + z = relay.var("z", shape=(zshape), dtype=dtype) + w = z * (x + y) + out = relay.nn.relu(w) + f = relay.Function([x, y, z], out) + x_data = np.random.uniform(-1, 1, xshape).astype(dtype) + y_data = np.random.uniform(-1, 1, yshape).astype(dtype) + z_data = np.random.uniform(-1, 1, zshape).astype(dtype) - result_dict = dict() - for mode in ["vm", "graph"]: - for use_trt in [True, False]: - mod = tvm.IRModule() - mod["main"] = f - result_key = mode + ("_trt" if use_trt else "") - if use_trt: - mod, config = tensorrt.partition_for_tensorrt(mod) - with tvm.transform.PassContext( - opt_level=3, 
config={"relay.ext.tensorrt.options": config} - ): - func = relay.create_executor( - mode, mod=mod, device=tvm.cuda(0), target="cuda" - ).evaluate() - else: - with tvm.transform.PassContext(opt_level=3): - func = relay.create_executor( - mode, mod=mod, device=tvm.cuda(0), target="cuda" - ).evaluate() - if run_module: - result_dict[result_key] = func(x_data, y_data, z_data) + result_dict = dict() + for mode in ["vm", "graph"]: + for use_trt in [False, True]: + mod = tvm.IRModule() + mod["main"] = f + result_key = mode + ("_trt" if use_trt else "") + if use_trt: + mod = relay.transform.InferType()(mod) + mod, config = tensorrt.partition_for_tensorrt(mod) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() + else: + mod = relay.transform.InferType()(mod) + with tvm.transform.PassContext(opt_level=3): + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() + if run_module: + result_dict[result_key] = func(x_data, y_data, z_data) - if run_module: - assert_result_dict_holds(result_dict) + print(result_dict) + if run_module: + assert_result_dict_holds(result_dict) def test_tensorrt_simple_cpu_io(run_module): @@ -254,6 +278,9 @@ def test_tensorrt_not_compatible(run_module): results = func(x_data) +@pytest.mark.xfail( + reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901") +) def test_tensorrt_serialize_graph_executor(run_module): import mxnet as mx from mxnet.gluon.model_zoo.vision import get_model @@ -308,6 +335,9 @@ def load_graph(): assert_result_dict_holds(result_dict) +@pytest.mark.xfail( + reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901") +) def test_tensorrt_serialize_vm(run_module): import mxnet as mx from mxnet.gluon.model_zoo.vision import get_model @@ -364,9 +394,10 @@ def get_graph( strides=(1), dilation=(1), channels=None, + d_type="float16", ): - x = relay.var("x", shape=(x_shape), dtype="float32") - kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + x = relay.var("x", shape=(x_shape), dtype=d_type) + kernel = relay.var("kernel", shape=(k_shape), dtype=d_type) out = relay.nn.conv1d( x, kernel, @@ -376,11 +407,15 @@ def get_graph( strides=strides, dilation=dilation, channels=channels, + out_dtype="float16", ) f = relay.Function([x, kernel], out) return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] - run_and_verify_func(get_graph(channels=10), run_module=run_module) + for d_type in ["float16"]: + run_and_verify_func( + get_graph(channels=10, d_type=d_type), run_module=run_module, data_type=d_type + ) def test_conv2d(run_module): @@ -392,9 +427,10 @@ def get_graph( strides=(1, 1), dilation=(1, 1), channels=None, + data_type="float16", ): - x = relay.var("x", shape=(x_shape), dtype="float32") - kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + x = relay.var("x", shape=(x_shape), dtype=data_type) + kernel = relay.var("kernel", shape=(k_shape), dtype=data_type) out = relay.nn.conv2d( x, kernel, @@ -404,6 +440,7 @@ def get_graph( strides=strides, dilation=dilation, channels=channels, + out_dtype=data_type, ) f = relay.Function([x, kernel], out) return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] @@ -421,12 +458,21 @@ def get_graph( dilation=dilation, ), run_module=run_module, + data_type="float16", ) run_and_verify_func( - get_graph((1, 3, 16, 16), (3, 8, 7, 7), 3, [2, 2, 3, 3], [2, 
2], [1, 1], 24), + get_graph( + (1, 3, 16, 16), (3, 8, 7, 7), 3, [2, 2, 3, 3], [2, 2], [1, 1], 24, data_type="float16" + ), + run_module=run_module, + data_type="float16", + ) + + run_and_verify_func( + get_graph((1, 3, 16, 16), (1, 3, 1, 1), channels=1, data_type="float32"), run_module=run_module, + data_type="float32", ) - run_and_verify_func(get_graph((1, 3, 16, 16), (1, 3, 1, 1), channels=1), run_module=run_module) def test_conv2d_nhwc(run_module): @@ -434,12 +480,7 @@ def get_graph(x_shape=(1, 8, 8, 32), k_shape=(3, 3, 32, 16)): x = relay.var("x", shape=(x_shape), dtype="float32") kernel = relay.var("kernel", shape=(k_shape), dtype="float32") out = relay.nn.conv2d( - x, - kernel, - channels=16, - kernel_size=(3, 3), - data_layout="NHWC", - kernel_layout="HWIO", + x, kernel, channels=16, kernel_size=(3, 3), data_layout="NHWC", kernel_layout="HWIO" ) f = relay.Function([x, kernel], out) return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] @@ -455,9 +496,10 @@ def get_graph( padding=(0, 0), strides=(1, 1), dilation=(1, 1), + data_type="float16", ): - x = relay.var("x", shape=(x_shape), dtype="float32") - kernel = relay.const(np.ones(k_shape).astype("float32")) + x = relay.var("x", shape=(x_shape), dtype=data_type) + kernel = relay.const(np.ones(k_shape).astype(dtype=data_type)) out = relay.nn.conv2d( x, kernel, @@ -471,7 +513,8 @@ def get_graph( f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph(), run_module=run_module) + for tp in ["float16"]: + run_and_verify_func(get_graph(data_type=tp), run_module=run_module, data_type=tp) def test_conv2d_weights_transposed(run_module): @@ -489,16 +532,17 @@ def get_graph(x_shape=(1, 32, 9, 9), k_shape=(3, 3, 32, 16), order=(3, 2, 0, 1)) def test_dense(run_module): - def get_graph(x_shape=(1, 16), k_shape=(32, 16)): - x = relay.var("x", shape=(x_shape), dtype="float32") - kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + def get_graph(x_shape=(1, 16), k_shape=(32, 16), dtp="float16"): + x = relay.var("x", shape=(x_shape), dtype=dtp) + kernel = relay.var("kernel", shape=(k_shape), dtype=dtp) # Dense requires constant weights in TensorRT, so the weights are transposed by us. 
out = relay.nn.dense(x, kernel, units=k_shape[0]) f = relay.Function([x, kernel], out) return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] - run_and_verify_func(get_graph(), run_module=run_module) - run_and_verify_func(get_graph(k_shape=(1, 16)), run_module=run_module) + for tp in ["float32"]: + run_and_verify_func(get_graph(dtp=tp), run_module=run_module, data_type=tp) + run_and_verify_func(get_graph(k_shape=(1, 16), dtp=tp), run_module=run_module, data_type=tp) def test_batch_matmul(run_module): @@ -560,13 +604,7 @@ def get_graph( count_include_pad=count_include_pad, ) else: - out = op( - x, - pool_size=pool_size, - strides=strides, - padding=padding, - ceil_mode=ceil_mode, - ) + out = op(x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode) f = relay.Function([x], out) return f, {"x": x_shape}, [] @@ -616,13 +654,14 @@ def get_graph(op, x_shape=(1, 3, 32, 32)): def test_batch_flatten(run_module): - def get_graph(x_shape=(1, 3, 4, 6)): - x = relay.var("x", shape=(x_shape), dtype="float32") + def get_graph(x_shape=(1, 3, 4, 6), data_type="float16"): + x = relay.var("x", shape=(x_shape), dtype=data_type) out = relay.nn.batch_flatten(x) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph(), run_module=run_module) + for dtp in ["float16", "float32"]: + run_and_verify_func(get_graph(data_type=dtp), run_module=run_module, data_type=dtp) def test_expand_dims(run_module): @@ -636,14 +675,19 @@ def get_graph(x_shape=(1, 3), axis=1, num_newaxis=1): def test_squeeze(run_module): - def get_graph(x_shape, axis): - x = relay.var("x", shape=(x_shape), dtype="float32") + def get_graph(x_shape, axis, dtype): + x = relay.var("x", shape=(x_shape), dtype=dtype) out = relay.squeeze(x, axis=axis) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph((1, 5, 1, 1), (2, 3)), run_module=run_module) - run_and_verify_func(get_graph((1, 3, 1), (-1,)), run_module=run_module) + for dtype in SUPPORTED_DTYPES: + run_and_verify_func( + get_graph((1, 5, 1, 1), (2, 3), dtype=dtype), run_module=run_module, data_type=dtype + ) + run_and_verify_func( + get_graph((1, 3, 1), (-1,), dtype=dtype), run_module=run_module, data_type=dtype + ) def test_concatenate(run_module): @@ -678,11 +722,7 @@ def get_graph(x_shape, indices_or_sections, axis): def test_conv2d_transpose(run_module): def get_graph( - x_shape=(1, 32, 8, 8), - k_shape=(32, 16, 3, 3), - groups=1, - padding=(0, 0), - strides=(1, 1), + x_shape=(1, 32, 8, 8), k_shape=(32, 16, 3, 3), groups=1, padding=(0, 0), strides=(1, 1) ): x = relay.var("x", shape=(x_shape), dtype="float32") kernel = relay.var("kernel", shape=(k_shape), dtype="float32") @@ -705,7 +745,7 @@ def get_graph( def test_reshape(run_module): def get_graph(x_shape, new_shape): - x = relay.var("x", shape=(x_shape), dtype="float32") + x = relay.var("x", shape=(x_shape), dtype="float16") out = relay.reshape(x, new_shape) f = relay.Function([x], out) return f, {"x": x_shape}, [] @@ -836,6 +876,17 @@ def get_graph(x_shape=(1, 16)): f = relay.Function([x], out) return f, {"x": x_shape}, [] + run_and_verify_func(get_graph(), run_module=run_module, data_type="float32") + + +def test_float_const16(run_module): + def get_graph(x_shape=(1, 16)): + x = relay.var("x", shape=(x_shape), dtype="float16") + beta = relay.const(1, dtype="float16") + out = relay.multiply(x, beta) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + run_and_verify_func(get_graph(), run_module=run_module) @@ -861,17 +912,44 @@ def 
get_graph(x_shape, pad_width): ) +def test_add(run_module): + def get_graph(x_shape): + x = relay.var("x", shape=(x_shape), dtype="float16") + y = relay.var("y", shape=(x_shape), dtype="float16") + out = relay.add(x, y) + f = relay.Function([x, y], out) + return f, {"x": x_shape, "y": x_shape}, [] + + run_and_verify_func(get_graph((1, 1000)), run_module=run_module, data_type="float16") + + def test_softmax(run_module): - def get_graph(x_shape, axis): - x = relay.var("x", shape=(x_shape), dtype="float32") + def get_graph(x_shape, axis, data_type="float32"): + x = relay.var("x", shape=(x_shape), dtype=data_type) out = relay.nn.softmax(x, axis=axis) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph((1, 1000), axis=1), run_module=run_module) - run_and_verify_func(get_graph((1, 1000), axis=-1), run_module=run_module) - run_and_verify_func(get_graph((1, 3, 4), axis=-2), run_module=run_module) - run_and_verify_func(get_graph((1, 3, 4), axis=1), run_module=run_module) + run_and_verify_func( + get_graph((1, 1000), axis=1, data_type="float32"), + run_module=run_module, + data_type="float32", + ) + run_and_verify_func( + get_graph((1, 1000), axis=-1, data_type="float32"), + run_module=run_module, + data_type="float32", + ) + run_and_verify_func( + get_graph((1, 3, 4), axis=-2, data_type="float16"), + run_module=run_module, + data_type="float16", + ) + run_and_verify_func( + get_graph((1, 3, 4), axis=1, data_type="float16"), + run_module=run_module, + data_type="float16", + ) def test_batch_norm(run_module): @@ -923,24 +1001,10 @@ def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5): gamma = relay.var("gamma", shape=(param_shape), dtype="float32") beta = relay.var("beta", shape=(param_shape), dtype="float32") out = relay.nn.layer_norm( - x, - gamma=gamma, - beta=beta, - axis=axis, - epsilon=epsilon, - center=True, - scale=True, + x, gamma=gamma, beta=beta, axis=axis, epsilon=epsilon, center=True, scale=True ) f = relay.Function([x, gamma, beta], out) - return ( - f, - { - "x": x_shape, - "beta": param_shape, - "gamma": param_shape, - }, - ["beta", "gamma"], - ) + return (f, {"x": x_shape, "beta": param_shape, "gamma": param_shape}, ["beta", "gamma"]) run_and_verify_func(get_graph((1, 32, 8, 8), (32,)), run_module=run_module) run_and_verify_func( @@ -977,91 +1041,116 @@ def get_graph(op, x_shape=(1, 8, 3, 3)): def test_clip(run_module): def get_graph(x_shape=(1, 8, 3, 3)): - x = relay.var("x", shape=(x_shape), dtype="float32") + x = relay.var("x", shape=(x_shape), dtype="float16") out = relay.clip(x, a_min=-0.2, a_max=0.4) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph(), run_module=run_module) + run_and_verify_func(get_graph(), run_module=run_module, data_type="float16") + + +def test_relu(run_module): + def get_graph(x_shape=(1, 8, 3, 4)): + x = relay.var("x", shape=(x_shape), dtype="float16") + out = relay.nn.relu(x) + f = relay.Function([x], out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph(), run_module=run_module, data_type="float16") def test_leaky_relu(run_module): - def get_graph(x_shape=(1, 8, 3, 3)): - x = relay.var("x", shape=(x_shape), dtype="float32") + def get_graph(x_shape=(1, 8, 3, 4)): + x = relay.var("x", shape=(x_shape), dtype="float16") out = relay.nn.leaky_relu(x, alpha=0.1) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph(), run_module=run_module) + run_and_verify_func(get_graph(), run_module=run_module, data_type="float16") def 
test_binary(run_module): - def get_graph(op, x_shape, y_shape, y_is_const=False): - x = relay.var("x", shape=(x_shape), dtype="float32") + def get_graph(op, x_shape, y_shape, y_is_const=False, d_type="float16"): + x = relay.var("x", shape=(x_shape), dtype=d_type) if y_is_const: - y = relay.const(np.ones(y_shape).astype("float32")) + y = relay.const(np.ones(y_shape).astype(d_type)) out = op(x, y) f = relay.Function([x], out) return f, {"x": x_shape}, [] - y = relay.var("y", shape=(y_shape), dtype="float32") + y = relay.var("y", shape=(y_shape), dtype=d_type) out = op(x, y) f = relay.Function([x, y], out) return f, {"x": x_shape, "y": y_shape}, [] for op in [relay.add, relay.subtract, relay.multiply, relay.divide, relay.power]: - for y_is_const in [True, False]: - run_and_verify_func( - get_graph(op, (1, 8, 3, 3), (1, 8, 3, 3), y_is_const), run_module=run_module - ) - run_and_verify_func( - get_graph(op, (1, 8, 1, 3), (1, 8, 3, 1), y_is_const), run_module=run_module - ) - run_and_verify_func(get_graph(op, (1, 10), (10,), y_is_const), run_module=run_module) - run_and_verify_func( - get_graph(op, (1, 1, 1, 10), (10,), y_is_const), run_module=run_module - ) - run_and_verify_func(get_graph(op, (1, 1, 1), (3,), y_is_const), run_module=run_module) + for d_type in SUPPORTED_DTYPES: + for y_is_const in [True, False]: + run_and_verify_func( + get_graph(op, (1, 8, 3, 3), (1, 8, 3, 3), y_is_const, d_type), + run_module=run_module, + data_type=d_type, + ) + run_and_verify_func( + get_graph(op, (1, 8, 1, 3), (1, 8, 3, 1), y_is_const, d_type), + run_module=run_module, + data_type=d_type, + ) + run_and_verify_func( + get_graph(op, (1, 10), (10,), y_is_const, d_type), + run_module=run_module, + data_type=d_type, + ) + run_and_verify_func( + get_graph(op, (1, 1, 1, 10), (10,), y_is_const, d_type), + run_module=run_module, + data_type=d_type, + ) + run_and_verify_func( + get_graph(op, (1, 1, 1), (3,), y_is_const, d_type), + run_module=run_module, + data_type=d_type, + ) def test_reduce(run_module): - def get_graph(op, x_shape=(1, 2, 3, 4), axis=(2, 3), keepdims=False): - x = relay.var("x", shape=(x_shape), dtype="float32") + def get_graph(op, x_shape=(1, 2, 3, 4), axis=(2, 3), keepdims=False, d_type="float32"): + x = relay.var("x", shape=(x_shape), dtype=d_type) out = op(x, axis=axis, keepdims=keepdims) f = relay.Function([x], out) return f, {"x": x_shape}, [] - for op in [relay.sum, relay.prod, relay.max, relay.min, relay.mean]: - for keepdims in [True, False]: - run_and_verify_func(get_graph(op, axis=(1), keepdims=keepdims), run_module=run_module) - run_and_verify_func( - get_graph(op, axis=(2, 3), keepdims=keepdims), run_module=run_module - ) - run_and_verify_func( - get_graph(op, axis=(1, 2), keepdims=keepdims), run_module=run_module - ) - run_and_verify_func( - get_graph(op, axis=(1, 2, 3), keepdims=keepdims), run_module=run_module - ) + for type in SUPPORTED_DTYPES: + for op in [relay.sum, relay.prod, relay.max, relay.min, relay.mean]: + for keepdims in [True, False]: + run_and_verify_func( + get_graph(op, axis=(1), keepdims=keepdims, d_type=type), + run_module=run_module, + data_type=type, + ) + run_and_verify_func( + get_graph(op, axis=(2, 3), keepdims=keepdims, d_type=type), + run_module=run_module, + data_type=type, + ) + run_and_verify_func( + get_graph(op, axis=(1, 2), keepdims=keepdims, d_type=type), + run_module=run_module, + data_type=type, + ) + run_and_verify_func( + get_graph(op, axis=(1, 2, 3), keepdims=keepdims, d_type=type), + run_module=run_module, + data_type=type, + ) def 
test_strided_slice(run_module): def get_graph(x_shape, begin, end, strides=None, slice_mode="size"): x = relay.var("x", shape=(x_shape), dtype="float32") if strides: - out = relay.strided_slice( - x, - begin, - end, - strides, - slice_mode=slice_mode, - ) + out = relay.strided_slice(x, begin, end, strides, slice_mode=slice_mode) else: - out = relay.strided_slice( - x, - begin, - end, - slice_mode=slice_mode, - ) + out = relay.strided_slice(x, begin, end, slice_mode=slice_mode) f = relay.Function([x], out) return f, {"x": x_shape}, [] @@ -1088,27 +1177,37 @@ def get_graph(x_shape, begin, end, strides=None, slice_mode="size"): def test_adaptive_pool2d(run_module): - def get_graph(op, x_shape=(1, 3, 32, 32), out_size=(1, 1)): - x = relay.var("x", shape=(x_shape), dtype="float32") + def get_graph(op, x_shape=(1, 3, 32, 32), out_size=(1, 1), data_type="float16"): + x = relay.var("x", shape=(x_shape), dtype=data_type) out = op(x, out_size) f = relay.Function([x], out) return f, {"x": x_shape}, [] - run_and_verify_func(get_graph(relay.nn.adaptive_max_pool2d), run_module=run_module) - run_and_verify_func(get_graph(relay.nn.adaptive_avg_pool2d), run_module=run_module) + for type in SUPPORTED_DTYPES: + run_and_verify_func( + get_graph(relay.nn.adaptive_max_pool2d, data_type=type), + run_module=run_module, + data_type=type, + ) + run_and_verify_func( + get_graph(relay.nn.adaptive_avg_pool2d, data_type=type), + run_module=run_module, + data_type=type, + ) def test_multiple_outputs(run_module): - def get_graph(): - x = relay.var("x", shape=(1, 3), dtype="float32") - y = relay.var("y", shape=(1, 3), dtype="float32") + def get_graph(d_type="float16"): + x = relay.var("x", shape=(1, 3), dtype=d_type) + y = relay.var("y", shape=(1, 3), dtype=d_type) z = relay.add(x, y) w = relay.add(z, y) out = relay.Tuple((z, w)) f = relay.Function([x, y], out) return f, {"x": (1, 3), "y": (1, 3)}, [] - run_and_verify_func(get_graph(), run_module=run_module) + for type in SUPPORTED_DTYPES: + run_and_verify_func(get_graph(d_type=type), run_module=run_module, data_type=type) def test_conv3d(run_module): @@ -1160,13 +1259,7 @@ def get_graph( count_include_pad=count_include_pad, ) else: - out = op( - x, - pool_size=pool_size, - strides=strides, - padding=padding, - ceil_mode=ceil_mode, - ) + out = op(x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode) f = relay.Function([x], out) return f, {"x": x_shape}, [] @@ -1482,7 +1575,8 @@ def get_maskrcnn_input(in_size: int) -> np.ndarray: # Descending sort by scores and get the high confidence indices pt_indices = np.argsort(-1 * out[1].numpy())[:num_high_confidence_boxes] - tol = [1e-1, 5e-3, 1e-5, 4e-1] # [Box Tol, Score Tol, Label Tol, Mask Tol] + # [Box Tol, Score Tol, Label Tol, Mask Tol] + tol = [1e-1, 5e-3, 1e-5, 4e-1] # Because of certain ops, there are certain minor differences in TVM outputs and PT outputs, # This means that the tolerance can't be 1e-4 or 1e-5 throughout. The ideal way to get around # this is to test it on an entire dataset and compare mAP with the original model.
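Usage sketch (editor's addition, not part of the patch): the snippet below shows how the new `use_fp16` flag on `partition_for_tensorrt` might be exercised end to end, mirroring the `run_and_verify_func` / `test_tensorrt_simple` flow from the tests in this diff. The single-conv2d network, its shapes, and the variable names are hypothetical; only `partition_for_tensorrt`, the `relay.ext.tensorrt.options` config key, and the executor calls are taken from the patch itself.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import tensorrt

# Hypothetical fp16 network: a single conv2d, analogous to the conv2d tests above.
dtype = "float16"
x = relay.var("x", shape=(1, 3, 16, 16), dtype=dtype)
kernel = relay.const(np.ones((8, 3, 3, 3)).astype(dtype))
out = relay.nn.conv2d(x, kernel, channels=8, kernel_size=(3, 3), out_dtype=dtype)
mod = tvm.IRModule.from_expr(relay.Function([x], out))
mod = relay.transform.InferType()(mod)

# use_fp16=True is the option added by this patch; it is carried to the TensorRT
# codegen through the "relay.ext.tensorrt.options" config used below.
mod, config = tensorrt.partition_for_tensorrt(mod, use_fp16=True)

with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
    func = relay.create_executor("graph", mod=mod, device=tvm.cuda(0), target="cuda").evaluate()

x_data = np.random.uniform(-1, 1, (1, 3, 16, 16)).astype(dtype)
print(func(x_data))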