|
| 1 | +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "lite/kernels/xpu/tile_compute.h" |
| 16 | +#include <vector> |
| 17 | +#include "lite/backends/xpu/xpu_header_sitter.h" |
| 18 | +#include "lite/core/op_registry.h" |
| 19 | + |
| 20 | +namespace paddle { |
| 21 | +namespace lite { |
| 22 | +namespace kernels { |
| 23 | +namespace xpu { |
| 24 | + |
| 25 | +template <typename T, PrecisionType PType> |
| 26 | +void TileCompute<T, PType>::Run() { |
| 27 | + auto& param = this->template Param<param_t>(); |
| 28 | + auto& ctx = this->ctx_->template As<XPUContext>(); |
| 29 | + auto repeat_times = param.repeat_times; |
| 30 | + if (param.RepeatTimes) { |
| 31 | + auto repeat_times_size = param.RepeatTimes->data_size(); |
| 32 | + for (int64_t i = 0; i < repeat_times_size; i++) { |
| 33 | + repeat_times.push_back(param.RepeatTimes->template data<int>()[i]); |
| 34 | + } |
| 35 | + } else if (param.repeat_times_tensor.size() != 0) { |
| 36 | + for (int i = 0; i < param.repeat_times_tensor.size(); i++) { |
| 37 | + auto temp = param.repeat_times_tensor[i]; |
| 38 | + repeat_times.push_back(*(temp->template data<int>())); |
| 39 | + } |
| 40 | + } |
| 41 | + auto in_dims = param.X->dims(); |
| 42 | + auto vec_in_dims = in_dims.Vectorize(); |
| 43 | + // broadcast for vec_in_dims.size() equal to repeat_times.size() |
| 44 | + if (repeat_times.size() < vec_in_dims.size()) { |
| 45 | + int diff = vec_in_dims.size() - repeat_times.size(); |
| 46 | + repeat_times.insert(repeat_times.begin(), diff, 1); |
| 47 | + } else { |
| 48 | + int diff = repeat_times.size() - vec_in_dims.size(); |
| 49 | + vec_in_dims.insert(vec_in_dims.begin(), diff, 1); |
| 50 | + } |
| 51 | + |
| 52 | + std::vector<int> new_in_dims(vec_in_dims.begin(), vec_in_dims.end()); |
| 53 | + std::vector<int> out_dims(param.Out->dims().data().begin(), |
| 54 | + param.Out->dims().data().end()); |
| 55 | + int r = xdnn::broadcast<T>(ctx.GetRawContext(), |
| 56 | + param.X->template data<T>(), |
| 57 | + param.Out->template mutable_data<T>(TARGET(kXPU)), |
| 58 | + new_in_dims, |
| 59 | + out_dims); |
| 60 | + |
| 61 | + CHECK_EQ(r, 0); |
| 62 | +} |
| 63 | + |
| 64 | +} // namespace xpu |
| 65 | +} // namespace kernels |
| 66 | +} // namespace lite |
| 67 | +} // namespace paddle |
| 68 | + |
// Register the float32/NCHW variant of the tile kernel for the XPU target.
using tile_float =
    paddle::lite::kernels::xpu::TileCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(tile, kXPU, kFloat, kNCHW, tile_float, def)
    // Data tensor lives on the XPU device.
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
    // Repeat factors are small int32 tensors kept on the host.
    .BindInput("RepeatTimes",
               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
    .BindInput("repeat_times_tensor",
               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
    .Finalize();
0 commit comments