diff --git a/lite/core/optimizer/mir/__xpu__static_kernel_pick_pass.cc b/lite/core/optimizer/mir/__xpu__static_kernel_pick_pass.cc
index 038e7e22678..d55b9aad45c 100644
--- a/lite/core/optimizer/mir/__xpu__static_kernel_pick_pass.cc
+++ b/lite/core/optimizer/mir/__xpu__static_kernel_pick_pass.cc
@@ -44,8 +44,8 @@ void XPUStaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // Collect input data precision for each node in the graph
 #ifdef LITE_WITH_XPU
   DicideUseFP16Optimizer(graph);
+  GetXPUDeviceType();
   if (xpu_use_fp16_optimizer_) {
-    GetXPUDeviceType();
     for (auto& node : graph->StmtTopologicalOrder()) {
       if (!node->IsStmt()) continue;
       if (xpu_special_op_.count(node->AsStmt().op_type())) {
@@ -235,6 +235,12 @@ void XPUStaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 #ifdef LITE_WITH_XPU
 void XPUStaticKernelPickPass::DicideUseFP16Optimizer(
     const std::unique_ptr<SSAGraph>& graph) {
+  if (GetStringFromEnv("XPUForceUseFP16", "false") == "true") {
+    xpu_use_fp16_optimizer_ = false;
+    VLOG(2) << "XPU force use data precision: FP16 ";
+    return;
+  }
+
   if (graph->valid_places()[0].precision == PrecisionType::kFP16) {
     xpu_use_fp16_optimizer_ = true;
     VLOG(2) << "XPU auto use data precision: FP16/FP32/INT16 ";
diff --git a/lite/core/optimizer/mir/__xpu__static_kernel_pick_pass.h b/lite/core/optimizer/mir/__xpu__static_kernel_pick_pass.h
index af9fa0435ac..38f786b5216 100644
--- a/lite/core/optimizer/mir/__xpu__static_kernel_pick_pass.h
+++ b/lite/core/optimizer/mir/__xpu__static_kernel_pick_pass.h
@@ -95,7 +95,11 @@ class XPUStaticKernelPickPass : public mir::StmtPass {
       if (kernel_pick_factors_.IsPrecisionConsidered() &&
           (place.precision == kernel.precision() ||
            kernel.precision() == PRECISION(kAny) ||
-           place.precision == PRECISION(kAny))) {
+           place.precision == PRECISION(kAny) ||
+           // an FP16 place may also pick an FP32 kernel precision
+           (xpu_use_fp16_optimizer_ &&
+            kernel.precision() == PRECISION(kFloat) &&
+            place.precision == PRECISION(kFP16)))) {
         // score skipped, if kernel is int8, but op is not int8
         if (!(kernel.precision() == PRECISION(kInt8) &&
               !instruct.op_info()->HasAttr("enable_int8"))) {
@@ -294,8 +298,9 @@ class XPUStaticKernelPickPass : public mir::StmtPass {
 
  private:
   core::KernelPickFactor kernel_pick_factors_;
-#ifdef LITE_WITH_XPU
+  bool xpu_use_fp16_optimizer_{false};
+#ifdef LITE_WITH_XPU
   // TODO(quwei): add more ops
   const std::set<std::string> PRECISION_INT31_OP_{"__xpu__fc"};
   const std::set<std::string> PRECISION_INT8_OP_{"__xpu__fc"};
@@ -314,7 +319,15 @@ class XPUStaticKernelPickPass : public mir::StmtPass {
       "gather",
       "pool2d",
       "concat",
-      "calib"};
+      "calib",
+      "relu",
+      "tanh",
+      "sigmoid",
+      "leaky_relu",
+      "conv2d_transpose",
+      "elementwise_mul",
+      "elementwise_add",
+      "reduce_mean"};
   const std::set<std::string> xpu_inplace_op_{"reshape",
                                               "reshape2",
                                               "flatten",
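Note on the pick-pass change: GetXPUDeviceType() now runs whether or not the FP16 optimizer is enabled, and the precision score in the header no longer rejects an FP32 kernel for an FP16 place. A standalone sketch of the relaxed predicate (the function name and free-standing form are illustrative, not part of the patch; the enum values mirror lite's PRECISION(...) macro):

    // Sketch of the precision check after this patch; it mirrors the
    // condition inside XPUStaticKernelPickPass but is not patch code.
    bool PrecisionCompatible(PrecisionType place_p,
                             PrecisionType kernel_p,
                             bool xpu_use_fp16_optimizer) {
      return place_p == kernel_p || kernel_p == PrecisionType::kAny ||
             place_p == PrecisionType::kAny ||
             // an FP16 place may still fall back to an FP32 kernel
             (xpu_use_fp16_optimizer && kernel_p == PrecisionType::kFloat &&
              place_p == PrecisionType::kFP16);
    }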
diff --git a/lite/kernels/xpu/activation_compute.cc b/lite/kernels/xpu/activation_compute.cc
index 867acb68205..bb92854f0b8 100644
--- a/lite/kernels/xpu/activation_compute.cc
+++ b/lite/kernels/xpu/activation_compute.cc
@@ -21,13 +21,14 @@ namespace lite {
 namespace kernels {
 namespace xpu {
 
-void ReluCompute::Run() {
+template <typename InType, PrecisionType PType>
+void ReluCompute<InType, PType>::Run() {
   auto& param = this->template Param<param_t>();
   auto& ctx = this->ctx_->template As<XPUContext>();
 
   int r = xdnn::relu(ctx.GetRawContext(),
-                     param.X->data<float>(),
-                     param.Out->mutable_data<float>(TARGET(kXPU)),
+                     param.X->template data<InType>(),
+                     param.Out->template mutable_data<InType>(TARGET(kXPU)),
                      param.X->numel());
   CHECK_EQ(r, 0);
 }
@@ -54,24 +55,26 @@ void GeluCompute::Run() {
   CHECK_EQ(r, 0);
 }
 
-void TanhCompute::Run() {
+template <typename InType, PrecisionType PType>
+void TanhCompute<InType, PType>::Run() {
   auto& param = this->template Param<param_t>();
   auto& ctx = this->ctx_->template As<XPUContext>();
 
   int r = xdnn::tanh(ctx.GetRawContext(),
-                     param.X->data<float>(),
-                     param.Out->mutable_data<float>(TARGET(kXPU)),
+                     param.X->template data<InType>(),
+                     param.Out->template mutable_data<InType>(TARGET(kXPU)),
                      param.X->numel());
   CHECK_EQ(r, 0);
 }
 
-void SigmoidCompute::Run() {
+template <typename InType, PrecisionType PType>
+void SigmoidCompute<InType, PType>::Run() {
   auto& param = this->template Param<param_t>();
   auto& ctx = this->ctx_->template As<XPUContext>();
 
   int r = xdnn::sigmoid(ctx.GetRawContext(),
-                        param.X->data<float>(),
-                        param.Out->mutable_data<float>(TARGET(kXPU)),
+                        param.X->template data<InType>(),
+                        param.Out->template mutable_data<InType>(TARGET(kXPU)),
                         param.X->numel());
   CHECK_EQ(r, 0);
 }
@@ -205,13 +208,14 @@ void HardSigmoidCompute::Run() {
   CHECK_EQ(r, 0);
 }
 
-void LeakyReluCompute::Run() {
+template <typename InType, PrecisionType PType>
+void LeakyReluCompute<InType, PType>::Run() {
   auto& param = this->template Param<param_t>();
   auto& ctx = this->ctx_->template As<XPUContext>();
 
-  int r = xdnn::leaky_relu(ctx.GetRawContext(),
-                           param.X->data<float>(),
-                           param.Out->mutable_data<float>(TARGET(kXPU)),
+  int r = xdnn::leaky_relu(ctx.GetRawContext(),
+                           param.X->template data<InType>(),
+                           param.Out->template mutable_data<InType>(TARGET(kXPU)),
                            param.X->numel(),
                            param.Leaky_relu_alpha);
   CHECK_EQ(r, 0);
@@ -274,12 +277,20 @@ void PReluCompute::Run() {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_KERNEL(
-    relu, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::ReluCompute, def)
+using reluFP32 =
+    paddle::lite::kernels::xpu::ReluCompute<float, PRECISION(kFloat)>;
+using reluFP16 =
+    paddle::lite::kernels::xpu::ReluCompute<float16, PRECISION(kFP16)>;
+REGISTER_LITE_KERNEL(relu, kXPU, kFloat, kNCHW, reluFP32, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
+REGISTER_LITE_KERNEL(relu, kXPU, kFP16, kNCHW, reluFP16, reluFP16)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(
     relu6, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::Relu6Compute, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
@@ -292,21 +303,31 @@ REGISTER_LITE_KERNEL(
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(
-    tanh, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::TanhCompute, def)
+using tanhFP32 =
+    paddle::lite::kernels::xpu::TanhCompute<float, PRECISION(kFloat)>;
+using tanhFP16 =
+    paddle::lite::kernels::xpu::TanhCompute<float16, PRECISION(kFP16)>;
+REGISTER_LITE_KERNEL(tanh, kXPU, kFloat, kNCHW, tanhFP32, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
+REGISTER_LITE_KERNEL(tanh, kXPU, kFP16, kNCHW, tanhFP16, tanhFP16)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .Finalize();
 
-REGISTER_LITE_KERNEL(sigmoid,
-                     kXPU,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::xpu::SigmoidCompute,
-                     def)
+using sigmoidFP32 =
+    paddle::lite::kernels::xpu::SigmoidCompute<float, PRECISION(kFloat)>;
+using sigmoidFP16 =
+    paddle::lite::kernels::xpu::SigmoidCompute<float16, PRECISION(kFP16)>;
+REGISTER_LITE_KERNEL(sigmoid, kXPU, kFloat, kNCHW, sigmoidFP32, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
+REGISTER_LITE_KERNEL(sigmoid, kXPU, kFP16, kNCHW, sigmoidFP16, sigmoidFP16)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .Finalize();
 
 REGISTER_LITE_KERNEL(
     abs, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::AbsCompute, def)
@@ -386,16 +407,21 @@ REGISTER_LITE_KERNEL(hard_swish,
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(leaky_relu,
-                     kXPU,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::xpu::LeakyReluCompute,
-                     def)
+using leaky_reluFP32 =
+    paddle::lite::kernels::xpu::LeakyReluCompute<float, PRECISION(kFloat)>;
+using leaky_reluFP16 =
+    paddle::lite::kernels::xpu::LeakyReluCompute<float16, PRECISION(kFP16)>;
+REGISTER_LITE_KERNEL(leaky_relu, kXPU, kFloat, kNCHW, leaky_reluFP32, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
 
+REGISTER_LITE_KERNEL(
+    leaky_relu, kXPU, kFP16, kNCHW, leaky_reluFP16, leaky_reluFP16)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(softsign,
                      kXPU,
                      kFloat,
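Each activation converted above gains an <InType, PType> template pair and is instantiated twice; the FP16 instance is registered under kFP16, with its alias reused as the kernel name. For illustration, the same treatment applied to relu6 would look as follows (hypothetical: Relu6Compute is not templated by this patch, this only shows the instantiate-twice-and-register pattern):

    // Hypothetical only: assumes Relu6Compute were templated like ReluCompute.
    using relu6FP16 =
        paddle::lite::kernels::xpu::Relu6Compute<float16, PRECISION(kFP16)>;
    REGISTER_LITE_KERNEL(relu6, kXPU, kFP16, kNCHW, relu6FP16, relu6FP16)
        .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
        .BindOutput("Out",
                    {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
        .Finalize();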
diff --git a/lite/kernels/xpu/activation_compute.h b/lite/kernels/xpu/activation_compute.h
index 057d527ef89..ab47e5ed580 100644
--- a/lite/kernels/xpu/activation_compute.h
+++ b/lite/kernels/xpu/activation_compute.h
@@ -20,7 +20,8 @@ namespace lite {
 namespace kernels {
 namespace xpu {
 
-class ReluCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
+template <typename InType, PrecisionType PType>
+class ReluCompute : public KernelLite<TARGET(kXPU), PType> {
  public:
   using param_t = operators::ActivationParam;
 
@@ -47,7 +48,8 @@ class GeluCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
   virtual ~GeluCompute() = default;
 };
 
-class TanhCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
+template <typename InType, PrecisionType PType>
+class TanhCompute : public KernelLite<TARGET(kXPU), PType> {
  public:
   using param_t = operators::ActivationParam;
 
@@ -56,7 +58,8 @@ class TanhCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
   virtual ~TanhCompute() = default;
 };
 
-class SigmoidCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
+template <typename InType, PrecisionType PType>
+class SigmoidCompute : public KernelLite<TARGET(kXPU), PType> {
  public:
   using param_t = operators::ActivationParam;
 
@@ -164,7 +167,8 @@ class HardSigmoidCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
   virtual ~HardSigmoidCompute() = default;
 };
 
-class LeakyReluCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
+template <typename InType, PrecisionType PType>
+class LeakyReluCompute : public KernelLite<TARGET(kXPU), PType> {
  public:
   using param_t = operators::ActivationParam;
 
diff --git a/lite/kernels/xpu/conv2d_transpose_compute.cc b/lite/kernels/xpu/conv2d_transpose_compute.cc
index 7949b193c56..0ec8532b4bc 100644
--- a/lite/kernels/xpu/conv2d_transpose_compute.cc
+++ b/lite/kernels/xpu/conv2d_transpose_compute.cc
@@ -22,6 +22,23 @@ namespace lite {
 namespace kernels {
 namespace xpu {
 
+template <>
+void Conv2dTransposeCompute<PRECISION(kFloat)>::PrepareForRun() {
+  int cur_dev_idx = 0;
+
+  XPU_CALL(xpu_current_device(&cur_dev_idx));
+  XPU_CALL(xpu_device_get_attr(&cur_dev_attr_, XPUATTR_MODEL, cur_dev_idx));
+  if (cur_dev_attr_ <= 1) {
+    VLOG(4) << "Current XPU device : XPU1";
+  } else if (cur_dev_attr_ >= 2 && cur_dev_attr_ <= 299) {
+    VLOG(4) << "Current XPU device : XPU2";
+  } else if (cur_dev_attr_ >= 300 && cur_dev_attr_ <= 599) {
+    VLOG(4) << "Current XPU device : XPU3";
+  } else {
+    VLOG(4) << "invalid XPU device";
+  }
+}
+
 template <>
 void Conv2dTransposeCompute<PRECISION(kFloat)>::Run() {
   auto& param = this->template Param<param_t>();
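PrepareForRun() executes once when the kernel is created, caching XPUATTR_MODEL in cur_dev_attr_ so that Run() can branch without querying the driver again. The ranges map the raw model attribute to a device generation; condensed as a sketch (the constants come from the VLOG branches above, not from a documented driver API):

    // Sketch: the generation mapping implied by PrepareForRun() above.
    int XPUGeneration(uint64_t model_attr) {
      if (model_attr <= 1) return 1;                         // XPU1
      if (model_attr >= 2 && model_attr <= 299) return 2;    // XPU2
      if (model_attr >= 300 && model_attr <= 599) return 3;  // XPU3
      return -1;                                             // unknown device
    }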
@@ -37,27 +54,53 @@ void Conv2dTransposeCompute<PRECISION(kFloat)>::Run() {
   auto dilations = *param.dilations;
 
   if (param.output_padding.empty()) {
-    int ret = xdnn::conv2d_transpose<float, float, float, int16_t>(
-        ctx.GetRawContext(),
-        param.x->data<float>(),
-        param.filter->data<float>(),
-        param.output->mutable_data<float>(TARGET(kXPU)),
-        in_dims[0],
-        in_dims[1],
-        in_dims[2],
-        in_dims[3],
-        out_dims[1],
-        std::vector<int>{static_cast<int>(w_dims[2]),
-                         static_cast<int>(w_dims[3])},
-        strides,
-        paddings,
-        dilations,
-        groups,
-        nullptr,
-        nullptr,
-        nullptr,
-        true);
-    CHECK_EQ(ret, 0);
+    if (cur_dev_attr_ <= 1) {
+      int ret = xdnn::conv2d_transpose<float, float, float, int16_t>(
+          ctx.GetRawContext(),
+          param.x->data<float>(),
+          param.filter->data<float>(),
+          param.output->mutable_data<float>(TARGET(kXPU)),
+          in_dims[0],
+          in_dims[1],
+          in_dims[2],
+          in_dims[3],
+          out_dims[1],
+          std::vector<int>{static_cast<int>(w_dims[2]),
+                           static_cast<int>(w_dims[3])},
+          strides,
+          paddings,
+          dilations,
+          groups,
+          nullptr,
+          nullptr,
+          nullptr,
+          true);
+      CHECK_EQ(ret, 0);
+    } else {
+      int ret = xdnn::conv2d_transpose_fusion<float, float, float, int16_t>(
+          ctx.GetRawContext(),
+          param.x->data<float>(),
+          param.filter->data<float>(),
+          param.output->mutable_data<float>(TARGET(kXPU)),
+          in_dims[0],
+          in_dims[1],
+          in_dims[2],
+          in_dims[3],
+          out_dims[1],
+          std::vector<int>{static_cast<int>(w_dims[2]),
+                           static_cast<int>(w_dims[3])},
+          strides,
+          paddings,
+          dilations,
+          groups,
+          nullptr,
+          nullptr,
+          nullptr,
+          nullptr,
+          xdnn::Activation_t::LINEAR,
+          true);
+      CHECK_EQ(ret, 0);
+    }
   } else {
     int n = in_dims[0];
     int yc = in_dims[1];
diff --git a/lite/kernels/xpu/conv2d_transpose_compute.h b/lite/kernels/xpu/conv2d_transpose_compute.h
index 5a3d8714fd4..6e779fc42ad 100644
--- a/lite/kernels/xpu/conv2d_transpose_compute.h
+++ b/lite/kernels/xpu/conv2d_transpose_compute.h
@@ -28,9 +28,11 @@ class Conv2dTransposeCompute : public KernelLite<TARGET(kXPU), PType> {
  public:
   using param_t = operators::ConvParam;
 
+  void PrepareForRun() override;
   void Run() override;
 
   virtual ~Conv2dTransposeCompute() = default;
+  uint64_t cur_dev_attr_ = 0;
 };
 
 }  // namespace xpu
diff --git a/lite/kernels/xpu/elementwise_compute.cc b/lite/kernels/xpu/elementwise_compute.cc
index aaf1c913209..4b8e0e158c5 100644
--- a/lite/kernels/xpu/elementwise_compute.cc
+++ b/lite/kernels/xpu/elementwise_compute.cc
@@ -132,10 +132,15 @@ void ElementwiseCompute<T, Functor>::Run() {
 namespace xpu = paddle::lite::kernels::xpu;
 
 using AddFloat32 = xpu::ElementwiseCompute<float, xdnn::plus<float>>;
+using AddFloat16 = xpu::ElementwiseCompute<float16, xdnn::plus<float16>>;
 using AddInt32 = xpu::ElementwiseCompute<int, xdnn::plus<int>>;
 using AddInt64 = xpu::ElementwiseCompute<int64_t, xdnn::plus<int64_t>>;
+
 using SubFloat32 = xpu::ElementwiseCompute<float, xdnn::sub<float>>;
+
 using MulFloat32 = xpu::ElementwiseCompute<float, xdnn::mul<float>>;
+using MulFloat16 = xpu::ElementwiseCompute<float16, xdnn::mul<float16>>;
+
 using MulInt64 = xpu::ElementwiseCompute<int64_t, xdnn::mul<int64_t>>;
 using DivFloat32 = xpu::ElementwiseCompute<float, xdnn::div<float>>;
 using MaxFloat32 = xpu::ElementwiseCompute<float, xdnn::max<float>>;
@@ -147,6 +152,13 @@ REGISTER_LITE_KERNEL(elementwise_add, kXPU, kFloat, kNCHW, AddFloat32, def)
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
 
+REGISTER_LITE_KERNEL(
+    elementwise_add, kXPU, kFloat, kNCHW, AddFloat16, DISABLE_XPU1_AddFloat16)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(elementwise_add, kXPU, kFloat, kNCHW, AddInt32, int32)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
@@ -171,6 +183,13 @@ REGISTER_LITE_KERNEL(elementwise_mul, kXPU, kFloat, kNCHW, MulFloat32, def)
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
 
+REGISTER_LITE_KERNEL(
+    elementwise_mul, kXPU, kFloat, kNCHW, MulFloat16, DISABLE_XPU1_MulFloat16)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(elementwise_mul, kXPU, kFloat, kNCHW, MulInt64, int64)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
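On XPU1 the original xdnn::conv2d_transpose path is kept, while newer generations route to xdnn::conv2d_transpose_fusion with an identity (LINEAR) activation, the fused API available from XPU2 onward. For the elementwise kernels, FP16 support is a second instantiation of the same functor template; the DISABLE_XPU1_ alias prefix suggests these variants are meant to be filtered out on XPU1 hardware, though the filtering mechanism itself is outside this diff. Extending the alias pattern to elementwise_sub would look like the following hypothetical lines (not part of the patch, and assuming an xdnn::sub functor mirroring xdnn::plus):

    // Hypothetical extension, not in this patch:
    using SubFloat16 = xpu::ElementwiseCompute<float16, xdnn::sub<float16>>;
    REGISTER_LITE_KERNEL(
        elementwise_sub, kXPU, kFloat, kNCHW, SubFloat16, DISABLE_XPU1_SubFloat16)
        .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
        .BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
        .BindOutput("Out",
                    {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
        .Finalize();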
diff --git a/lite/kernels/xpu/reduce_compute.cc b/lite/kernels/xpu/reduce_compute.cc
index da2477d48ba..8563ec4d601 100644
--- a/lite/kernels/xpu/reduce_compute.cc
+++ b/lite/kernels/xpu/reduce_compute.cc
@@ -154,6 +154,8 @@ using ReduceAll = xpu::ReduceCompute<bool, xdnn::reduce_all<bool>>;
 using ReduceAny = xpu::ReduceCompute<bool, xdnn::reduce_any<bool>>;
 using ReduceMeanFloat32 =
     xpu::ReduceCompute<float, xdnn::reduce_mean<float>>;
+using ReduceMeanFloat16 =
+    xpu::ReduceCompute<float16, xdnn::reduce_mean<float16>>;
 using ReduceSumFloat32 =
     xpu::ReduceCompute<float, xdnn::reduce_sum<float>>;
 using ReduceProdFloat32 =
@@ -178,6 +180,16 @@ REGISTER_LITE_KERNEL(reduce_mean, kXPU, kFloat, kNCHW, ReduceMeanFloat32, def)
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
 
+REGISTER_LITE_KERNEL(reduce_mean,
+                     kXPU,
+                     kFloat,
+                     kNCHW,
+                     ReduceMeanFloat16,
+                     DISABLE_XPU1_ReduceMeanFloat16)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(reduce_sum, kXPU, kFloat, kNCHW, ReduceSumFloat32, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
diff --git a/lite/kernels/xpu/reshape_compute.cc b/lite/kernels/xpu/reshape_compute.cc
index 78359443991..c82e367e9eb 100644
--- a/lite/kernels/xpu/reshape_compute.cc
+++ b/lite/kernels/xpu/reshape_compute.cc
@@ -69,6 +69,21 @@ REGISTER_LITE_KERNEL(reshape2,
     .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
     .Finalize();
 
+REGISTER_LITE_KERNEL(reshape2,
+                     kXPU,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::xpu::ReshapeCompute,
+                     float16)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindInput("ShapeTensor",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindInput("Shape",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(reshape2,
                      kXPU,
                      kFloat,
@@ -113,6 +128,20 @@ REGISTER_LITE_KERNEL(reshape,
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
 
+REGISTER_LITE_KERNEL(reshape,
+                     kXPU,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::xpu::ReshapeCompute,
+                     float16)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindInput("ShapeTensor",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindInput("Shape",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(flatten,
                      kXPU,
                      kFloat,
@@ -125,6 +154,18 @@ REGISTER_LITE_KERNEL(flatten,
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
 
+REGISTER_LITE_KERNEL(flatten,
+                     kXPU,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::xpu::ReshapeCompute,
+                     float16)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindInput("Shape",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(flatten2,
                      kXPU,
                      kFloat,
@@ -137,3 +178,16 @@ REGISTER_LITE_KERNEL(flatten2,
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(flatten2,
+                     kXPU,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::xpu::ReshapeCompute,
+                     float16)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindInput("Shape",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
+    .Finalize();
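reshape, reshape2, flatten, and flatten2 all reuse the single non-templated ReshapeCompute: these ops only adjust shape metadata and move bytes, so the element type never matters, which is presumably why the FP16 variants stay in the kFloat kernel slot and differ only in the tensor types they bind (these ops are also listed in xpu_inplace_op_ in the pick pass above). A minimal illustration of why such kernels are precision-agnostic (illustration only, not PaddleLite code):

    #include <cstring>

    // A reshape-style op copies or aliases the buffer; it never interprets
    // element values, so FP32 and FP16 tensors take the same path.
    void ReshapeBytes(const void* in, void* out, size_t bytes) {
      std::memcpy(out, in, bytes);
    }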
diff --git a/lite/kernels/xpu/transpose_compute.cc b/lite/kernels/xpu/transpose_compute.cc
index d1c9553ba71..19441de2849 100644
--- a/lite/kernels/xpu/transpose_compute.cc
+++ b/lite/kernels/xpu/transpose_compute.cc
@@ -75,6 +75,18 @@ REGISTER_LITE_KERNEL(transpose2,
     .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
     .Finalize();
 
+REGISTER_LITE_KERNEL(transpose2,
+                     kXPU,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::xpu::TransposeCompute,
+                     def_int32)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
+    .BindOutput("XShape",
+                {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(transpose2,
                      kXPU,
                      kFloat,
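End to end, the FP16 path is taken when the graph's first valid place is kFP16 (see DicideUseFP16Optimizer above), or forced via the XPUForceUseFP16 environment variable. A rough usage sketch with the public CxxConfig API; the model path and the exact place list here are illustrative, not prescribed by the patch:

    #include "paddle_api.h"  // paddle::lite_api

    paddle::lite_api::CxxConfig config;
    config.set_model_dir("./mobilenet_v1");  // hypothetical model path
    config.set_valid_places({
        paddle::lite_api::Place{TARGET(kXPU), PRECISION(kFP16)},
        paddle::lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
        paddle::lite_api::Place{TARGET(kHost), PRECISION(kFloat)},
    });
    auto predictor =
        paddle::lite_api::CreatePaddlePredictor<paddle::lite_api::CxxConfig>(
            config);

With kFP16 first, the pick pass enables xpu_use_fp16_optimizer_ and may mix FP16 and FP32 kernels per op, falling back to FP32 where no FP16 kernel is registered.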