diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 1c27c008d8ca7..68590fff37b61 100755
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -2268,6 +2268,8 @@ USE_TRT_CONVERTER(conv2d_transpose);
 USE_TRT_CONVERTER(leaky_relu);
 USE_TRT_CONVERTER(shuffle_channel);
 USE_TRT_CONVERTER(where);
+USE_TRT_CONVERTER(one_hot);
+USE_TRT_CONVERTER(one_hot_v2);
 USE_TRT_CONVERTER(swish);
 USE_TRT_CONVERTER(silu);
 USE_TRT_CONVERTER(group_norm);
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index b796cf1c2a230..47f4e152503f8 100755
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -27,6 +27,7 @@ list(
   shuffle_channel_op.cc
   fill_any_like_op.cc
   where_op.cc
+  one_hot_op.cc
   swish_op.cc
   silu_op.cc
   instance_norm_op.cc
diff --git a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc
new file mode 100644
index 0000000000000..f1ea2fcc482a1
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc
@@ -0,0 +1,92 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace framework {
+class Scope;
+
+namespace proto {
+class OpDesc;
+}  // namespace proto
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * OneHot Op
+ */
+class OneHotOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+#if IS_TRT_VERSION_GE(8510)
+    VLOG(3) << "convert a one_hot op to a TensorRT OneHot layer";
+    framework::OpDesc op_desc(op, nullptr);
+
+    const auto indices_tensor = engine_->GetITensor(op_desc.Input("X").front());
+    nvinfer1::ITensor* values_tensor;
+    nvinfer1::ITensor* depth_tensor;
+    const int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));
+    if (dtype == 2 || dtype == 3) {  // int32, int64
+      const std::vector<int> values_data = {0, 1};
+      values_tensor = Add1DConstantLayer(values_data, "values_tensor");
+      if (dtype == 3) {  // int64
+        VLOG(3) << "TRT does not support int64, so it is converted to int32.";
+      }
+    } else if (dtype == 5 || dtype == 6) {  // float32, float64
+      const std::vector<float> values_data = {0.0f, 1.0f};
+      values_tensor = Add1DConstantLayer(values_data, "values_tensor");
+      if (dtype == 6) {  // float64
+        VLOG(3) << "TRT does not support float64, so it is converted to "
+                   "float32.";
+      }
+    }
+
+    auto depth_name = op_desc.Input("depth_tensor");
+    if (depth_name.size() == 0) {
+      const int depth = PADDLE_GET_CONST(int, op_desc.GetAttr("depth"));
+      depth_tensor = Add1DConstantLayer(depth, "depth_tensor", true);
+    } else {
+      nvinfer1::Dims depth_dims;
+      depth_dims.nbDims = 0;
+      nvinfer1::ITensor* depth_tensor_paddle =
+          engine_->GetITensor(depth_name.front());
+      auto shuffle_layer =
+          TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *depth_tensor_paddle);
+      shuffle_layer->setReshapeDimensions(depth_dims);
+      shuffle_layer->getOutput(0)->setName(depth_tensor_paddle->getName());
+      depth_tensor = shuffle_layer->getOutput(0);
+    }
+    auto layer = TRT_ENGINE_ADD_LAYER(
+        engine_, OneHot, *indices_tensor, *values_tensor, *depth_tensor, -1);
+
+    auto output_name = op_desc.Output("Out").front();
+    RreplenishLayerAndOutput(layer, "one_hot", {output_name}, test_mode);
+#else
+    VLOG(3) << "one_hot is not supported when TensorRT < 8.5.1";
+#endif
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(one_hot, OneHotOpConverter);
+REGISTER_TRT_OP_CONVERTER(one_hot_v2, OneHotOpConverter);
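Note: the converter above maps one_hot/one_hot_v2 onto TensorRT's OneHot layer (TRT >= 8.5.1) with the encoding axis hard-coded to -1, an off/on values tensor of {0, 1}, and the depth taken either from the `depth` attribute or from the optional `depth_tensor` input reshaped to a 0-D tensor, since the layer expects a scalar depth. As a rough reference for the intended semantics (not part of the patch; `one_hot_ref` is an illustrative name), a minimal NumPy sketch:

    import numpy as np

    def one_hot_ref(indices, depth, on_value=1, off_value=0):
        # Emulate the OneHot layer with axis=-1: the output shape is
        # indices.shape + (depth,), filled with off_value except for a
        # single on_value at each index position.
        out = np.full(indices.shape + (depth,), off_value, dtype=np.int32)
        np.put_along_axis(out, indices[..., None], on_value, axis=-1)
        return out

    # one_hot_ref(np.array([1, 3], dtype=np.int32), depth=4) ->
    # [[0, 1, 0, 0],
    #  [0, 0, 0, 1]]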
"one_hot/one_hot_v2 op only support int32, int64, float."; + return false; + } + } + auto one_hot_inputs = desc.Inputs(); + if (one_hot_inputs.find("depth_tensor") != one_hot_inputs.end()) { + if (desc.Input("depth_tensor").size() != 0) { + return true; + } + } + + if (desc.HasAttr("depth")) { + const int depth = PADDLE_GET_CONST(int, desc.GetAttr("depth")); + if (depth <= 0) { + VLOG(3) << "depth only support positive in one_hot/one_hot_v2 op."; + return false; + } + } + } + if (op_type == "skip_layernorm") { if (!with_dynamic_shape) { VLOG(3) << "the skip_layernorm does not support static shape yet"; @@ -2391,6 +2430,8 @@ struct SimpleOpTypeSetTeller : public Teller { "fc", "shuffle_channel", "where", + "one_hot", + "one_hot_v2", "swish", "silu", "celu", @@ -2523,6 +2564,8 @@ struct SimpleOpTypeSetTeller : public Teller { "fc", "shuffle_channel", "where", + "one_hot", + "one_hot_v2", "swish", "silu", "celu", diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py new file mode 100644 index 0000000000000..60e654bb95e5e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py @@ -0,0 +1,168 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py
new file mode 100644
index 0000000000000..60e654bb95e5e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from functools import partial
+from typing import List
+
+import numpy as np
+from program_config import ProgramConfig, TensorConfig
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest
+
+import paddle.inference as paddle_infer
+
+
+class TrtConvertOneHotTest(TrtLayerAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        ver = paddle_infer.get_trt_compile_version()
+        if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8510:
+            return False
+        return True
+
+    def sample_program_configs(self):
+        self.trt_param.workspace_size = 1073741824
+
+        def generate_indices(dims, batch):
+            if dims == 2:
+                return np.random.randint(0, 10, (batch, 4), dtype=np.int32)
+            elif dims == 3:
+                return np.random.randint(0, 10, (batch, 4, 6), dtype=np.int32)
+            else:
+                return np.random.randint(
+                    0, 10, (batch, 4, 6, 8), dtype=np.int32
+                )
+
+        def generate_depth(dims, batch):
+            return np.ones((1,), dtype=np.int32) * 10
+
+        for dims in [2, 3, 4]:
+            for batch in [1, 2]:
+                self.dims = dims
+                dics = [{"dtype": 5, "depth": 10}, {}]
+                ops_config = [
+                    {
+                        "op_type": "one_hot",
+                        "op_inputs": {
+                            "X": ["input_x_data"],
+                            "depth_tensor": ["input_depth_data"],
+                        },
+                        "op_outputs": {"Out": ["output_data"]},
+                        "op_attrs": dics[0],
+                        "outputs_dtype": {"output_data": np.float32},
+                    },
+                ]
+                ops = self.generate_op_config(ops_config)
+
+                program_config = ProgramConfig(
+                    ops=ops,
+                    weights={
+                        "input_depth_data": TensorConfig(
+                            data_gen=partial(generate_depth, dims, batch)
+                        ),
+                    },
+                    inputs={
+                        "input_x_data": TensorConfig(
+                            data_gen=partial(generate_indices, dims, batch)
+                        ),
+                    },
+                    outputs=["output_data"],
+                )
+
+                yield program_config
+
+    def sample_predictor_configs(
+        self, program_config
+    ) -> (paddle_infer.Config, List[int], float):
+        def generate_dynamic_shape(attrs):
+            if self.dims == 1:
+                self.dynamic_shape.min_input_shape = {
+                    "input_x_data": [1],
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "input_x_data": [2],
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "input_x_data": [1],
+                }
+            elif self.dims == 2:
+                self.dynamic_shape.min_input_shape = {
+                    "input_x_data": [1, 4],
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "input_x_data": [2, 4],
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "input_x_data": [1, 4],
+                }
+            elif self.dims == 3:
+                self.dynamic_shape.min_input_shape = {
+                    "input_x_data": [1, 4, 6],
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "input_x_data": [2, 4, 6],
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "input_x_data": [1, 4, 6],
+                }
+            elif self.dims == 4:
+                self.dynamic_shape.min_input_shape = {
+                    "input_x_data": [1, 4, 6, 8],
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "input_x_data": [2, 4, 6, 8],
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "input_x_data": [1, 4, 6, 8],
+                }
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        def generate_trt_nodes_num(attrs, dynamic_shape):
+            if not dynamic_shape:
+                return 0, 3
+            return 1, 2
+
+        attrs = [op.attrs for op in program_config.ops]
+
+        # for static_shape
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+
+        # for dynamic_shape
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+
+    def test(self):
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()
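Note: for completeness, a hedged sketch of how the new path would be exercised from user code. It assumes a saved inference model at the placeholder paths below that contains a one_hot/one_hot_v2 op fed by an input variable named `x`; matching the op_teller change, the op is only offloaded when TensorRT dynamic shape is enabled:

    import paddle.inference as paddle_infer

    # Placeholder model paths; any saved inference model with a
    # one_hot/one_hot_v2 op would do.
    config = paddle_infer.Config("./model.pdmodel", "./model.pdiparams")
    config.enable_use_gpu(256, 0)  # 256 MB initial pool on GPU 0
    config.enable_tensorrt_engine(
        workspace_size=1 << 30,
        max_batch_size=1,
        min_subgraph_size=1,
        precision_mode=paddle_infer.PrecisionType.Float32,
        use_static=False,
        use_calib_mode=False,
    )
    # one_hot is only converted in dynamic-shape mode; the shapes here
    # mirror the 2-D case from the unit test above ("x" is an assumed
    # input name).
    config.set_trt_dynamic_shape_info(
        {"x": [1, 4]}, {"x": [2, 4]}, {"x": [1, 4]}
    )
    predictor = paddle_infer.create_predictor(config)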