From 681a37021e85dfabd34f9f6fa8e795d220ebd21f Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Fri, 2 Dec 2022 15:45:40 +0800 Subject: [PATCH 01/13] add onehot trt converter --- .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/tensorrt/convert/CMakeLists.txt | 1 + .../inference/tensorrt/convert/one_hot_op.cc | 115 ++++++++++++++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 9 ++ .../unittests/ir/inference/auto_scan_test.py | 4 +- 5 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/one_hot_op.cc diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 1c27c008d8ca7..d1d2951ad84ee 100755 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2268,6 +2268,7 @@ USE_TRT_CONVERTER(conv2d_transpose); USE_TRT_CONVERTER(leaky_relu); USE_TRT_CONVERTER(shuffle_channel); USE_TRT_CONVERTER(where); +USE_TRT_CONVERTER(one_hot); USE_TRT_CONVERTER(swish); USE_TRT_CONVERTER(silu); USE_TRT_CONVERTER(group_norm); diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index b796cf1c2a230..47f4e152503f8 100755 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -27,6 +27,7 @@ list( shuffle_channel_op.cc fill_any_like_op.cc where_op.cc + one_hot_op.cc swish_op.cc silu_op.cc instance_norm_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc new file mode 100644 index 0000000000000..106f71222f5cc --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc @@ -0,0 +1,115 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * OneHot Op + */ +class OneHotOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(3) << "convert a fluid one_hot op to tensorrt one_hot layer"; + + framework::OpDesc op_desc(op, nullptr); + + const auto indices_tensor = engine_->GetITensor(op_desc.Input("X").front()); + const nvinfer1::ITensor* values_tensor; + const nvinfer1::ITensor* depth_tensor; + const int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")); + const bool allow_out_of_range = + PADDLE_GET_CONST(int, op_desc.GetAttr("allow_out_of_range")); + PADDLE_ENFORCE_EQ(allow_out_of_range, + false, + platform::errors::InvalidArgument( + "Errors occurs in Paddle-TRT one_hot op, " + "allow_out_of_range is not supported")); + // const int axis = ; + + nvinfer1::Dims trt_values_tensor_shape; + trt_values_tensor_shape.nbDims = 1; + trt_values_tensor_shape.d[0] = 2; + + if (dtype == 2) { // int + const int values_data[2] = {0, 1}; + values_tensor = AddConstantLayer( + values_data, trt_values_tensor_shape, "values_tensor"); + } else if (dtype == 3) { // int64 + const int64_t values_data[2] = {0, 1}; + values_tensor = AddConstantLayer( + values_data, trt_values_tensor_shape, "values_tensor"); + } else if (dtype == 5) { // float + const float values_data[2] = {0.0f, 1.0f}; + values_tensor = AddConstantLayer( + values_data, trt_values_tensor_shape, "values_tensor"); + } + + nvinfer1::Dims indices_dims = indices_tensor->getDimensions(); + auto depth_name = op_desc.Input("depth_tensor"); + if (depth_name.size() == 0) { + const int depth = PADDLE_GET_CONST(int, op_desc.GetAttr("depth")); + PADDLE_ENFORCE_GT(depth, + 0, + platform::errors::InvalidArgument( + "Errors occurs in Paddle-TRT one_hot op, " + "axis must bigger than zero")); + + int32_t last_dim = 1; + int32_t length = 1; + for (int32_t i = 0; i < indices_dims.nbDims; i++) { + last_dim = indices_dims.d[i]; + length *= last_dim; + } + if (last_dim == 1) { + indices_dims.nbDims--; + } + const int* depth_data = new int[length](); + depth_tensor = + AddConstantLayer(depth_data, indices_dims, "values_tensor"); + } else { + depth_tensor = engine_->GetITensor(depth_name.front()); + } + auto layer = TRT_ENGINE_ADD_LAYER(engine_, + OneHot, + *indices_tensor, + *values_tensor, + *depth_tensor, + indices_dims.nbDims); + + auto output_name = op_desc.Output("Out").front(); + RreplenishLayerAndOutput(layer, "one_hot", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(one_hot, OneHotOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 17fb2f0aa6d09..f87cc70f0c03a 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1742,6 +1742,13 @@ struct SimpleOpTypeSetTeller : public Teller { } } + if (op_type == "one_hot") { +#if !IS_TRT_VERSION_GE(8510) + VLOG(3) << "one_hot is not supported when TensorRT < 8.5.1"; + return false; +#endif + } + if (op_type == "skip_layernorm") { if (!with_dynamic_shape) { VLOG(3) << "the skip_layernorm does not support static shape yet"; @@ -2391,6 +2398,7 @@ struct SimpleOpTypeSetTeller : public Teller { "fc", "shuffle_channel", "where", + "one_hot", "swish", "silu", "celu", @@ -2523,6 +2531,7 @@ struct SimpleOpTypeSetTeller : public Teller { "fc", "shuffle_channel", "where", + "one_hot", "swish", "silu", "celu", diff --git a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py index b561822f1af92..d3089d60b6fca 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py @@ -105,7 +105,9 @@ def sample_program_configs(self): raise NotImplementedError @abc.abstractmethod - def sample_predictor_configs(self): + def sample_predictor_configs( + self, program_config + ) -> (paddle_infer.Config, List[int], float): raise NotImplementedError @abc.abstractmethod From f6c918df5aafdacae25d5359fe6140ad32323fbf Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Fri, 2 Dec 2022 16:16:28 +0800 Subject: [PATCH 02/13] add unitest --- .../ir/inference/test_trt_convert_one_hot.py | 186 ++++++++++++++++++ 1 file changed, 186 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py new file mode 100644 index 0000000000000..c37f9a4040dcc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py @@ -0,0 +1,186 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial +from typing import List + +import numpy as np +from program_config import ProgramConfig, TensorConfig +from trt_layer_auto_scan_test import TrtLayerAutoScanTest + +import paddle.inference as paddle_infer + + +class TrtConvertOneHotTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + ver = paddle_infer.get_trt_compile_version() + if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8510: + return False + return True + + def sample_program_configs(self): + self.trt_param.workspace_size = 1073741824 + + def generate_indices(dims, batch): + if dims == 1: + return np.random.randint(0, 10, (batch,), dtype=np.int) + elif dims == 2: + return np.random.randint(0, 10, (batch, 4), dtype=np.int) + elif dims == 3: + return np.random.randint(0, 10, (batch, 4, 6), dtype=np.int) + else: + return np.random.randint(0, 10, (batch, 4, 6, 8), dtype=np.int) + + def generate_depth(dims, batch): + if dims == 1: + return np.ones((batch,), dtype=np.int) * 10 + elif dims == 2: + return np.ones((batch, 4), dtype=np.int) * 10 + elif dims == 3: + return np.ones((batch, 4, 6), dtype=np.int) * 10 + else: + return np.ones((batch, 4, 6, 8), dtype=np.int) * 10 + + for dims in [1, 2, 3, 4]: + for batch in [1, 2]: + self.dims = dims + dics = [{"dtype": 2, "depth": 10}, {}] + ops_config = [ + { + "op_type": "one_hot", + "op_inputs": { + "X": ["input_x_data"], + "depth_tensor": ["input_depth_data"], + }, + "op_outputs": {"Out": ["output_data"]}, + "op_attrs": dics[0], + "outputs_dtype": {"output_data": np.int}, + }, + ] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "indices_tensor": TensorConfig( + data_gen=partial(generate_indices, dims, batch) + ), + "depth_tensor": TensorConfig( + data_gen=partial(generate_depth, dims, batch) + ), + }, + outputs=["output_data"], + ) + + yield program_config + + def sample_predictor_configs( + self, program_config + ) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + if self.dims == 1: + self.dynamic_shape.min_input_shape = { + "input_x_data": [1], + "input_depth_data": [1], + } + self.dynamic_shape.max_input_shape = { + "input_x_data": [2], + "input_depth_data": [2], + } + self.dynamic_shape.opt_input_shape = { + "input_x_data": [1], + "input_depth_data": [1], + } + elif self.dims == 2: + self.dynamic_shape.min_input_shape = { + "input_x_data": [1, 4], + "input_depth_data": [1, 4], + } + self.dynamic_shape.max_input_shape = { + "input_x_data": [2, 4], + "input_depth_data": [2, 4], + } + self.dynamic_shape.opt_input_shape = { + "input_x_data": [1, 4], + "input_depth_data": [1, 4], + } + elif self.dims == 3: + self.dynamic_shape.min_input_shape = { + "input_x_data": [1, 4, 6], + "input_depth_data": [1, 4, 6], + } + self.dynamic_shape.max_input_shape = { + "input_x_data": [2, 4, 6], + "input_depth_data": [2, 4, 6], + } + self.dynamic_shape.opt_input_shape = { + "input_x_data": [1, 4, 6], + "input_depth_data": [1, 4, 6], + } + elif self.dims == 4: + self.dynamic_shape.min_input_shape = { + "input_x_data": [1, 4, 6, 8], + "input_depth_data": [1, 4, 6, 8], + } + self.dynamic_shape.max_input_shape = { + "input_x_data": [2, 4, 6, 8], + "input_depth_data": [2, 4, 6, 8], + } + self.dynamic_shape.opt_input_shape = { + "input_x_data": [1, 4, 6, 8], + "input_depth_data": [1, 4, 6, 8], + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + if not dynamic_shape: + return 0, 6 + return 1, 4 + + attrs = [op.attrs for op in program_config.ops] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False + ), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False + ), 1e-5 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), 1e-5 + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main() From 12eb774c88db00add26e8828be80b8e036b03c02 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Fri, 2 Dec 2022 19:39:09 +0800 Subject: [PATCH 03/13] fix bug --- paddle/fluid/inference/tensorrt/convert/one_hot_op.cc | 6 +++++- paddle/fluid/inference/tensorrt/op_teller.cc | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc index 106f71222f5cc..7cb88a0804571 100644 --- a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc @@ -36,6 +36,7 @@ class OneHotOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { +#if IS_TRT_VERSION_GE(8150) VLOG(3) << "convert a fluid one_hot op to tensorrt one_hot layer"; framework::OpDesc op_desc(op, nullptr); @@ -92,7 +93,7 @@ class OneHotOpConverter : public OpConverter { } const int* depth_data = new int[length](); depth_tensor = - AddConstantLayer(depth_data, indices_dims, "values_tensor"); + AddConstantLayer(depth_data, indices_dims, "values_tensor"); } else { depth_tensor = engine_->GetITensor(depth_name.front()); } @@ -105,6 +106,9 @@ class OneHotOpConverter : public OpConverter { auto output_name = op_desc.Output("Out").front(); RreplenishLayerAndOutput(layer, "one_hot", {output_name}, test_mode); +#else + VLOG(3) << "one_hot is not supported when TensorRT < 8.5.1"; +#endif } }; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index f87cc70f0c03a..0a4584859353a 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1743,7 +1743,7 @@ struct SimpleOpTypeSetTeller : public Teller { } if (op_type == "one_hot") { -#if !IS_TRT_VERSION_GE(8510) +#if IS_TRT_VERSION_LT(8510) VLOG(3) << "one_hot is not supported when TensorRT < 8.5.1"; return false; #endif From db6104c4c50e0caabc73bc439a05e227aff457a9 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Sat, 3 Dec 2022 14:15:22 +0800 Subject: [PATCH 04/13] opt code --- .../inference/tensorrt/convert/one_hot_op.cc | 58 +++++-------------- paddle/fluid/inference/tensorrt/op_teller.cc | 19 ++++++ 2 files changed, 32 insertions(+), 45 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc index 7cb88a0804571..17cd4757a1516 100644 --- a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc @@ -36,79 +36,47 @@ class OneHotOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { -#if IS_TRT_VERSION_GE(8150) VLOG(3) << "convert a fluid one_hot op to tensorrt one_hot layer"; - framework::OpDesc op_desc(op, nullptr); const auto indices_tensor = engine_->GetITensor(op_desc.Input("X").front()); const nvinfer1::ITensor* values_tensor; const nvinfer1::ITensor* depth_tensor; - const int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")); - const bool allow_out_of_range = - PADDLE_GET_CONST(int, op_desc.GetAttr("allow_out_of_range")); - PADDLE_ENFORCE_EQ(allow_out_of_range, - false, - platform::errors::InvalidArgument( - "Errors occurs in Paddle-TRT one_hot op, " - "allow_out_of_range is not supported")); - // const int axis = ; nvinfer1::Dims trt_values_tensor_shape; trt_values_tensor_shape.nbDims = 1; trt_values_tensor_shape.d[0] = 2; - if (dtype == 2) { // int - const int values_data[2] = {0, 1}; - values_tensor = AddConstantLayer( - values_data, trt_values_tensor_shape, "values_tensor"); - } else if (dtype == 3) { // int64 - const int64_t values_data[2] = {0, 1}; - values_tensor = AddConstantLayer( - values_data, trt_values_tensor_shape, "values_tensor"); + if (dtype == 2 || dtype == 3) { // int, int64 + const std::vector values_data = {0, 1}; + values_tensor = Add1DConstantLayer(values_data, "values_tensor"); + if (dtype == 3) { // int64 + VLOG(3) << "trt not support int64, so it is converted to int32."; + } } else if (dtype == 5) { // float - const float values_data[2] = {0.0f, 1.0f}; - values_tensor = AddConstantLayer( - values_data, trt_values_tensor_shape, "values_tensor"); + const std::vector values_data = {0.0f, 1.0f}; + values_tensor = Add1DConstantLayer(values_data, "values_tensor"); } nvinfer1::Dims indices_dims = indices_tensor->getDimensions(); auto depth_name = op_desc.Input("depth_tensor"); if (depth_name.size() == 0) { const int depth = PADDLE_GET_CONST(int, op_desc.GetAttr("depth")); - PADDLE_ENFORCE_GT(depth, - 0, - platform::errors::InvalidArgument( - "Errors occurs in Paddle-TRT one_hot op, " - "axis must bigger than zero")); - - int32_t last_dim = 1; int32_t length = 1; for (int32_t i = 0; i < indices_dims.nbDims; i++) { - last_dim = indices_dims.d[i]; - length *= last_dim; - } - if (last_dim == 1) { - indices_dims.nbDims--; + length *= indices_dims.d[i]; } - const int* depth_data = new int[length](); + const std::vector depth_data(length, depth); depth_tensor = - AddConstantLayer(depth_data, indices_dims, "values_tensor"); + Add1DConstantLayer(depth_data, indices_dims, "values_tensor"); } else { depth_tensor = engine_->GetITensor(depth_name.front()); } - auto layer = TRT_ENGINE_ADD_LAYER(engine_, - OneHot, - *indices_tensor, - *values_tensor, - *depth_tensor, - indices_dims.nbDims); + auto layer = TRT_ENGINE_ADD_LAYER( + engine_, OneHot, *indices_tensor, *values_tensor, *depth_tensor, -1); auto output_name = op_desc.Output("Out").front(); RreplenishLayerAndOutput(layer, "one_hot", {output_name}, test_mode); -#else - VLOG(3) << "one_hot is not supported when TensorRT < 8.5.1"; -#endif } }; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 0a4584859353a..74ea9e2782e7b 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1747,6 +1747,25 @@ struct SimpleOpTypeSetTeller : public Teller { VLOG(3) << "one_hot is not supported when TensorRT < 8.5.1"; return false; #endif + if (desc.HasAttr("allow_out_of_range")) { + VLOG(3) << "allow_out_of_range one_hot op is not supported now."; + if (PADDLE_GET_CONST(bool, desc.GetAttr("allow_out_of_range"))) + return false; + } + if (desc.HasAttr("dtype")) { + const int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype")); + if (dtype != 2 && dtype != 3 && dtype != 5) { + VLOG(3) << "one_hot op only support int32, int64, float."; + return false; + } + } + if (desc.HasAttr("depth")) { + const int depth = PADDLE_GET_CONST(int, desc.GetAttr("depth")); + if (depth <= 0) { + VLOG(3) << "depth only support positive in one_hot op."; + return false; + } + } } if (op_type == "skip_layernorm") { From dfbb9d7149fdfe6927c8e2b7ca71bb129670af71 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Sat, 3 Dec 2022 16:21:51 +0800 Subject: [PATCH 05/13] fix bug --- paddle/fluid/inference/tensorrt/convert/one_hot_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc index 17cd4757a1516..1119070f14647 100644 --- a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc @@ -42,6 +42,7 @@ class OneHotOpConverter : public OpConverter { const auto indices_tensor = engine_->GetITensor(op_desc.Input("X").front()); const nvinfer1::ITensor* values_tensor; const nvinfer1::ITensor* depth_tensor; + const int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")); nvinfer1::Dims trt_values_tensor_shape; trt_values_tensor_shape.nbDims = 1; @@ -67,8 +68,7 @@ class OneHotOpConverter : public OpConverter { length *= indices_dims.d[i]; } const std::vector depth_data(length, depth); - depth_tensor = - Add1DConstantLayer(depth_data, indices_dims, "values_tensor"); + depth_tensor = Add1DConstantLayer(depth_data, "values_tensor"); } else { depth_tensor = engine_->GetITensor(depth_name.front()); } From ca2c8cb68d816e371cedb4d40e936e228de2282c Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Sun, 4 Dec 2022 17:37:49 +0800 Subject: [PATCH 06/13] fix depth_tensor --- .../inference/tensorrt/convert/one_hot_op.cc | 28 +++++++++---------- .../unittests/ir/inference/auto_scan_test.py | 4 +-- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc index 1119070f14647..a2172f7be0c7f 100644 --- a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc @@ -43,34 +43,34 @@ class OneHotOpConverter : public OpConverter { const nvinfer1::ITensor* values_tensor; const nvinfer1::ITensor* depth_tensor; const int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")); - - nvinfer1::Dims trt_values_tensor_shape; - trt_values_tensor_shape.nbDims = 1; - trt_values_tensor_shape.d[0] = 2; - if (dtype == 2 || dtype == 3) { // int, int64 const std::vector values_data = {0, 1}; values_tensor = Add1DConstantLayer(values_data, "values_tensor"); if (dtype == 3) { // int64 VLOG(3) << "trt not support int64, so it is converted to int32."; } - } else if (dtype == 5) { // float + } else if (dtype == 5 || dtype == 6) { // float const std::vector values_data = {0.0f, 1.0f}; values_tensor = Add1DConstantLayer(values_data, "values_tensor"); + if (dtype == 6) { // int64 + VLOG(3) << "trt not support float64, so it is converted to float32."; + } } - nvinfer1::Dims indices_dims = indices_tensor->getDimensions(); auto depth_name = op_desc.Input("depth_tensor"); if (depth_name.size() == 0) { const int depth = PADDLE_GET_CONST(int, op_desc.GetAttr("depth")); - int32_t length = 1; - for (int32_t i = 0; i < indices_dims.nbDims; i++) { - length *= indices_dims.d[i]; - } - const std::vector depth_data(length, depth); - depth_tensor = Add1DConstantLayer(depth_data, "values_tensor"); + depth_tensor = Add1DConstantLayer(depth, "depth_tensor", true); } else { - depth_tensor = engine_->GetITensor(depth_name.front()); + nvinfer1::Dims depth_dims; + depth_dims.nbDims = 0; + const nvinfer1::ITensor* depth_tensor_paddle = + engine_->GetITensor(depth_name.front()); + auto shuffle_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *depth_tensor_paddle); + shuffle_layer->setReshapeDimensions(depth_dims); + depth_tensor = shuffle_layer->getOutput(0); + depth_tensor->setName(depth_tensor_paddle->getName()); } auto layer = TRT_ENGINE_ADD_LAYER( engine_, OneHot, *indices_tensor, *values_tensor, *depth_tensor, -1); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py index d3089d60b6fca..b561822f1af92 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py @@ -105,9 +105,7 @@ def sample_program_configs(self): raise NotImplementedError @abc.abstractmethod - def sample_predictor_configs( - self, program_config - ) -> (paddle_infer.Config, List[int], float): + def sample_predictor_configs(self): raise NotImplementedError @abc.abstractmethod From 48c0ae66ccee5150ac9f5f283a34ce094e572b43 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Sun, 4 Dec 2022 18:31:44 +0800 Subject: [PATCH 07/13] fix unitest --- paddle/fluid/inference/tensorrt/op_teller.cc | 4 +++ .../ir/inference/test_trt_convert_one_hot.py | 33 +++++-------------- 2 files changed, 12 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 74ea9e2782e7b..1b273e938479c 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1747,6 +1747,10 @@ struct SimpleOpTypeSetTeller : public Teller { VLOG(3) << "one_hot is not supported when TensorRT < 8.5.1"; return false; #endif + if (!with_dynamic_shape) { + VLOG(3) << "the one_hot op does not support static shape yet"; + return false; + } if (desc.HasAttr("allow_out_of_range")) { VLOG(3) << "allow_out_of_range one_hot op is not supported now."; if (PADDLE_GET_CONST(bool, desc.GetAttr("allow_out_of_range"))) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py index c37f9a4040dcc..46fa4353adb3d 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py @@ -45,13 +45,7 @@ def generate_indices(dims, batch): def generate_depth(dims, batch): if dims == 1: - return np.ones((batch,), dtype=np.int) * 10 - elif dims == 2: - return np.ones((batch, 4), dtype=np.int) * 10 - elif dims == 3: - return np.ones((batch, 4, 6), dtype=np.int) * 10 - else: - return np.ones((batch, 4, 6, 8), dtype=np.int) * 10 + return np.ones((1,), dtype=np.int) * 10 for dims in [1, 2, 3, 4]: for batch in [1, 2]: @@ -73,14 +67,15 @@ def generate_depth(dims, batch): program_config = ProgramConfig( ops=ops, - weights={}, + weights={ + "depth_tensor": TensorConfig( + data_gen=partial(generate_depth, dims, batch) + ), + }, inputs={ "indices_tensor": TensorConfig( data_gen=partial(generate_indices, dims, batch) ), - "depth_tensor": TensorConfig( - data_gen=partial(generate_depth, dims, batch) - ), }, outputs=["output_data"], ) @@ -94,54 +89,42 @@ def generate_dynamic_shape(attrs): if self.dims == 1: self.dynamic_shape.min_input_shape = { "input_x_data": [1], - "input_depth_data": [1], } self.dynamic_shape.max_input_shape = { "input_x_data": [2], - "input_depth_data": [2], } self.dynamic_shape.opt_input_shape = { "input_x_data": [1], - "input_depth_data": [1], } elif self.dims == 2: self.dynamic_shape.min_input_shape = { "input_x_data": [1, 4], - "input_depth_data": [1, 4], } self.dynamic_shape.max_input_shape = { "input_x_data": [2, 4], - "input_depth_data": [2, 4], } self.dynamic_shape.opt_input_shape = { "input_x_data": [1, 4], - "input_depth_data": [1, 4], } elif self.dims == 3: self.dynamic_shape.min_input_shape = { "input_x_data": [1, 4, 6], - "input_depth_data": [1, 4, 6], } self.dynamic_shape.max_input_shape = { "input_x_data": [2, 4, 6], - "input_depth_data": [2, 4, 6], } self.dynamic_shape.opt_input_shape = { "input_x_data": [1, 4, 6], - "input_depth_data": [1, 4, 6], } elif self.dims == 4: self.dynamic_shape.min_input_shape = { "input_x_data": [1, 4, 6, 8], - "input_depth_data": [1, 4, 6, 8], } self.dynamic_shape.max_input_shape = { "input_x_data": [2, 4, 6, 8], - "input_depth_data": [2, 4, 6, 8], } self.dynamic_shape.opt_input_shape = { "input_x_data": [1, 4, 6, 8], - "input_depth_data": [1, 4, 6, 8], } def clear_dynamic_shape(): @@ -151,8 +134,8 @@ def clear_dynamic_shape(): def generate_trt_nodes_num(attrs, dynamic_shape): if not dynamic_shape: - return 0, 6 - return 1, 4 + return 0, 3 + return 1, 2 attrs = [op.attrs for op in program_config.ops] From 9059fec5c7032acec86d7f5570cbf6f022f084d8 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Sun, 4 Dec 2022 19:36:32 +0800 Subject: [PATCH 08/13] fix bug --- paddle/fluid/inference/tensorrt/convert/one_hot_op.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc index a2172f7be0c7f..cb5afb491fee5 100644 --- a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc @@ -36,6 +36,7 @@ class OneHotOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { +#if IS_TRT_VERSION_GE(8510) VLOG(3) << "convert a fluid one_hot op to tensorrt one_hot layer"; framework::OpDesc op_desc(op, nullptr); @@ -64,19 +65,22 @@ class OneHotOpConverter : public OpConverter { } else { nvinfer1::Dims depth_dims; depth_dims.nbDims = 0; - const nvinfer1::ITensor* depth_tensor_paddle = + nvinfer1::ITensor* depth_tensor_paddle = engine_->GetITensor(depth_name.front()); auto shuffle_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *depth_tensor_paddle); shuffle_layer->setReshapeDimensions(depth_dims); + shuffle_layer->getOutput(0)->setName(depth_tensor_paddle->getName()); depth_tensor = shuffle_layer->getOutput(0); - depth_tensor->setName(depth_tensor_paddle->getName()); } auto layer = TRT_ENGINE_ADD_LAYER( engine_, OneHot, *indices_tensor, *values_tensor, *depth_tensor, -1); auto output_name = op_desc.Output("Out").front(); RreplenishLayerAndOutput(layer, "one_hot", {output_name}, test_mode); +#else + VLOG(3) << "one_hot is not supported when TensorRT < 8.5.1"; +#endif } }; From baa89a7ca2378cfd8e2e6890588527caabcf3c76 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Mon, 5 Dec 2022 11:59:57 +0800 Subject: [PATCH 09/13] fix unitest --- .../ir/inference/test_trt_convert_one_hot.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py index 46fa4353adb3d..60e654bb95e5e 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_one_hot.py @@ -34,23 +34,22 @@ def sample_program_configs(self): self.trt_param.workspace_size = 1073741824 def generate_indices(dims, batch): - if dims == 1: - return np.random.randint(0, 10, (batch,), dtype=np.int) - elif dims == 2: - return np.random.randint(0, 10, (batch, 4), dtype=np.int) + if dims == 2: + return np.random.randint(0, 10, (batch, 4), dtype=np.int32) elif dims == 3: - return np.random.randint(0, 10, (batch, 4, 6), dtype=np.int) + return np.random.randint(0, 10, (batch, 4, 6), dtype=np.int32) else: - return np.random.randint(0, 10, (batch, 4, 6, 8), dtype=np.int) + return np.random.randint( + 0, 10, (batch, 4, 6, 8), dtype=np.int32 + ) def generate_depth(dims, batch): - if dims == 1: - return np.ones((1,), dtype=np.int) * 10 + return np.ones((1,), dtype=np.int32) * 10 - for dims in [1, 2, 3, 4]: + for dims in [2, 3, 4]: for batch in [1, 2]: self.dims = dims - dics = [{"dtype": 2, "depth": 10}, {}] + dics = [{"dtype": 5, "depth": 10}, {}] ops_config = [ { "op_type": "one_hot", From be73de8fc2bc80c0df711cae4e7588bdcad8db0c Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Tue, 6 Dec 2022 00:09:15 +0800 Subject: [PATCH 10/13] fix bug --- paddle/fluid/inference/api/analysis_predictor.cc | 1 + paddle/fluid/inference/tensorrt/convert/one_hot_op.cc | 5 +++-- paddle/fluid/inference/tensorrt/op_teller.cc | 7 ------- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index d1d2951ad84ee..68590fff37b61 100755 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2269,6 +2269,7 @@ USE_TRT_CONVERTER(leaky_relu); USE_TRT_CONVERTER(shuffle_channel); USE_TRT_CONVERTER(where); USE_TRT_CONVERTER(one_hot); +USE_TRT_CONVERTER(one_hot_v2); USE_TRT_CONVERTER(swish); USE_TRT_CONVERTER(silu); USE_TRT_CONVERTER(group_norm); diff --git a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc index cb5afb491fee5..f1ea2fcc482a1 100644 --- a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc @@ -41,8 +41,8 @@ class OneHotOpConverter : public OpConverter { framework::OpDesc op_desc(op, nullptr); const auto indices_tensor = engine_->GetITensor(op_desc.Input("X").front()); - const nvinfer1::ITensor* values_tensor; - const nvinfer1::ITensor* depth_tensor; + nvinfer1::ITensor* values_tensor; + nvinfer1::ITensor* depth_tensor; const int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")); if (dtype == 2 || dtype == 3) { // int, int64 const std::vector values_data = {0, 1}; @@ -89,3 +89,4 @@ class OneHotOpConverter : public OpConverter { } // namespace paddle REGISTER_TRT_OP_CONVERTER(one_hot, OneHotOpConverter); +REGISTER_TRT_OP_CONVERTER(one_hot_v2, OneHotOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 1b273e938479c..a5884985b4f89 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1763,13 +1763,6 @@ struct SimpleOpTypeSetTeller : public Teller { return false; } } - if (desc.HasAttr("depth")) { - const int depth = PADDLE_GET_CONST(int, desc.GetAttr("depth")); - if (depth <= 0) { - VLOG(3) << "depth only support positive in one_hot op."; - return false; - } - } } if (op_type == "skip_layernorm") { From 9eaa050b69f0c144b24695627fb18bacee169198 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Tue, 6 Dec 2022 21:34:12 +0800 Subject: [PATCH 11/13] fix bug --- paddle/fluid/inference/tensorrt/op_teller.cc | 22 +++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index a5884985b4f89..0099a3a0e5fd3 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1742,24 +1742,36 @@ struct SimpleOpTypeSetTeller : public Teller { } } - if (op_type == "one_hot") { + if (op_type == "one_hot" || op_type == "one_hot_v2") { #if IS_TRT_VERSION_LT(8510) - VLOG(3) << "one_hot is not supported when TensorRT < 8.5.1"; + VLOG(3) << "one_hot/one_hot_v2 is not supported when TensorRT < 8.5.1"; return false; #endif if (!with_dynamic_shape) { - VLOG(3) << "the one_hot op does not support static shape yet"; + VLOG(3) + << "the one_hot/one_hot_v2 op does not support static shape yet"; return false; } if (desc.HasAttr("allow_out_of_range")) { - VLOG(3) << "allow_out_of_range one_hot op is not supported now."; + VLOG(3) + << "allow_out_of_range one_hot/one_hot_v2 op is not supported now."; if (PADDLE_GET_CONST(bool, desc.GetAttr("allow_out_of_range"))) return false; } if (desc.HasAttr("dtype")) { const int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype")); if (dtype != 2 && dtype != 3 && dtype != 5) { - VLOG(3) << "one_hot op only support int32, int64, float."; + VLOG(3) << "one_hot/one_hot_v2 op only support int32, int64, float."; + return false; + } + } + if (op_desc.Input("depth_tensor").size() != 0) { + return true; + } + if (desc.HasAttr("depth")) { + const int depth = PADDLE_GET_CONST(int, desc.GetAttr("depth")); + if (depth <= 0) { + VLOG(3) << "depth only support positive in one_hot/one_hot_v2 op."; return false; } } From d96f615e6e2e75269d59e4b45e49c6403e312a25 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Wed, 7 Dec 2022 15:04:54 +0800 Subject: [PATCH 12/13] fix bug --- paddle/fluid/inference/tensorrt/op_teller.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 0099a3a0e5fd3..45b6abbd77e01 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1765,9 +1765,13 @@ struct SimpleOpTypeSetTeller : public Teller { return false; } } - if (op_desc.Input("depth_tensor").size() != 0) { - return true; + auto one_hot_inputs = desc.Inputs(); + if (one_hot_inputs.find("depth_tensor") != one_hot_inputs.end()) { + if (op_desc.Input("depth_tensor").size() != 0) { + return true; + } } + if (desc.HasAttr("depth")) { const int depth = PADDLE_GET_CONST(int, desc.GetAttr("depth")); if (depth <= 0) { @@ -2427,6 +2431,7 @@ struct SimpleOpTypeSetTeller : public Teller { "shuffle_channel", "where", "one_hot", + "one_hot_v2", "swish", "silu", "celu", @@ -2560,6 +2565,7 @@ struct SimpleOpTypeSetTeller : public Teller { "shuffle_channel", "where", "one_hot", + "one_hot_v2", "swish", "silu", "celu", From 7e41019fcc5daeceda3c6e9227b96bb7c0d8a8bb Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Wed, 7 Dec 2022 16:33:45 +0800 Subject: [PATCH 13/13] fix bug --- paddle/fluid/inference/tensorrt/op_teller.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 45b6abbd77e01..aba57215939ca 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1767,7 +1767,7 @@ struct SimpleOpTypeSetTeller : public Teller { } auto one_hot_inputs = desc.Inputs(); if (one_hot_inputs.find("depth_tensor") != one_hot_inputs.end()) { - if (op_desc.Input("depth_tensor").size() != 0) { + if (desc.Input("depth_tensor").size() != 0) { return true; } }