Skip to content

Commit ccd7a5b

Browse files
wbn03newway
authored and committed
[XPU] fixed the bug of tile op in large input and add XPU implementation. (PaddlePaddle#9102)
1 parent 8f09eb2 commit ccd7a5b

File tree

6 files changed

+123
-3
lines changed

6 files changed

+123
-3
lines changed

lite/kernels/host/tile_compute.cc

+4-3
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,10 @@ void TileCompute<T, PType>::Run() {
8585
int dst_stride = in_stride[i + 1] * right;
8686
for (int m = 0; m < num; m++) {
8787
for (int j = 0; j < bcast_dims[i]; j++) {
88-
std::memcpy(tmp_dst + j * dst_stride / bcast_dims[i] + m * dst_stride,
89-
tmp_src + m * dst_stride / bcast_dims[i],
90-
dst_stride / bcast_dims[i] * sizeof(T));
88+
std::memcpy(
89+
tmp_dst + j * (dst_stride / bcast_dims[i]) + m * dst_stride,
90+
tmp_src + m * (dst_stride / bcast_dims[i]),
91+
dst_stride / bcast_dims[i] * sizeof(T));
9192
}
9293
}
9394
tmp_src_tensor.CopyDataFrom(tmp_dst_tensor);

lite/kernels/xpu/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ add_kernel(gru_compute_xpu XPU basic SRCS gru_compute.cc)
3030
add_kernel(gru_unit_compute_xpu XPU basic SRCS gru_unit_compute.cc)
3131
add_kernel(stack_compute_xpu XPU basic SRCS stack_compute.cc)
3232
add_kernel(slice_compute_xpu XPU basic SRCS slice_compute.cc)
33+
add_kernel(tile_compute_xpu XPU basic SRCS tile_compute.cc)
3334
add_kernel(cast_compute_xpu XPU basic SRCS cast_compute.cc)
3435
add_kernel(sequence_topk_avg_pooling_compute_xpu XPU basic SRCS sequence_topk_avg_pooling_compute.cc)
3536
add_kernel(concat_compute_xpu XPU basic SRCS concat_compute.cc)

lite/kernels/xpu/tile_compute.cc

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/kernels/xpu/tile_compute.h"
16+
#include <vector>
17+
#include "lite/backends/xpu/xpu_header_sitter.h"
18+
#include "lite/core/op_registry.h"
19+
20+
namespace paddle {
21+
namespace lite {
22+
namespace kernels {
23+
namespace xpu {
24+
25+
template <typename T, PrecisionType PType>
26+
void TileCompute<T, PType>::Run() {
27+
auto& param = this->template Param<param_t>();
28+
auto& ctx = this->ctx_->template As<XPUContext>();
29+
auto repeat_times = param.repeat_times;
30+
if (param.RepeatTimes) {
31+
auto repeat_times_size = param.RepeatTimes->data_size();
32+
for (int64_t i = 0; i < repeat_times_size; i++) {
33+
repeat_times.push_back(param.RepeatTimes->template data<int>()[i]);
34+
}
35+
} else if (param.repeat_times_tensor.size() != 0) {
36+
for (int i = 0; i < param.repeat_times_tensor.size(); i++) {
37+
auto temp = param.repeat_times_tensor[i];
38+
repeat_times.push_back(*(temp->template data<int>()));
39+
}
40+
}
41+
auto in_dims = param.X->dims();
42+
auto vec_in_dims = in_dims.Vectorize();
43+
// broadcast for vec_in_dims.size() equal to repeat_times.size()
44+
if (repeat_times.size() < vec_in_dims.size()) {
45+
int diff = vec_in_dims.size() - repeat_times.size();
46+
repeat_times.insert(repeat_times.begin(), diff, 1);
47+
} else {
48+
int diff = repeat_times.size() - vec_in_dims.size();
49+
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
50+
}
51+
52+
std::vector<int> new_in_dims(vec_in_dims.begin(), vec_in_dims.end());
53+
std::vector<int> out_dims(param.Out->dims().data().begin(),
54+
param.Out->dims().data().end());
55+
int r = xdnn::broadcast<T>(ctx.GetRawContext(),
56+
param.X->template data<T>(),
57+
param.Out->template mutable_data<T>(TARGET(kXPU)),
58+
new_in_dims,
59+
out_dims);
60+
61+
CHECK_EQ(r, 0);
62+
}
63+
64+
} // namespace xpu
65+
} // namespace kernels
66+
} // namespace lite
67+
} // namespace paddle
68+
69+
using tile_float =
70+
paddle::lite::kernels::xpu::TileCompute<float, PRECISION(kFloat)>;
71+
REGISTER_LITE_KERNEL(tile, kXPU, kFloat, kNCHW, tile_float, def)
72+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
73+
.BindInput("RepeatTimes",
74+
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
75+
.BindInput("repeat_times_tensor",
76+
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
77+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
78+
.Finalize();

lite/kernels/xpu/tile_compute.h

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include "lite/core/kernel.h"
17+
18+
namespace paddle {
19+
namespace lite {
20+
namespace kernels {
21+
namespace xpu {
22+
23+
template <typename T, PrecisionType PType>
24+
class TileCompute : public KernelLite<TARGET(kXPU), PType> {
25+
public:
26+
using param_t = operators::TileParam;
27+
28+
virtual void Run();
29+
30+
virtual ~TileCompute() = default;
31+
};
32+
33+
} // namespace xpu
34+
} // namespace kernels
35+
} // namespace lite
36+
} // namespace paddle

lite/operators/tile_op.cc

+1
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ bool TileOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
118118
} else if (opdesc.HasInput("repeat_times_tensor") &&
119119
(opdesc.Input("repeat_times_tensor").size() != 0)) {
120120
auto temp = opdesc.Input("repeat_times_tensor");
121+
param_.repeat_times_tensor.clear();
121122
for (auto var : temp) {
122123
param_.repeat_times_tensor.push_back(
123124
scope->FindVar(var)->GetMutable<lite::Tensor>());

lite/tests/kernels/tile_compute_test.cc

+3
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,9 @@ TEST(tile, precision) {
199199
#else
200200
return;
201201
#endif
202+
#elif defined(LITE_WITH_XPU)
203+
place = TARGET(kXPU);
204+
alias = "def";
202205
#elif defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
203206
place = TARGET(kHost);
204207
#else

0 commit comments

Comments
 (0)