Skip to content

Commit 2a14346

Browse files
committed
[API] fix paddle.cross with big tensor
1 parent 3283385 commit 2a14346

File tree

4 files changed

+40
-39
lines changed

4 files changed

+40
-39
lines changed

paddle/phi/kernels/cpu/cross_grad_kernel.cc

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -74,13 +74,13 @@ void CrossGradKernel(const Context &dev_ctx,
7474
"But received: Input(X/Y).dims() == [%s].",
7575
input_x_dims));
7676
}
77-
auto outer_loops = 1;
78-
for (auto i = 0; i < dim; i++) {
79-
outer_loops *= static_cast<int>(input_x_dims[i]);
77+
int64_t outer_loops = 1;
78+
for (int64_t i = 0; i < dim; i++) {
79+
outer_loops *= input_x_dims[i];
8080
}
81-
auto slice_size = 1;
82-
for (auto i = dim + 1; i < input_x_dims.size(); i++) {
83-
slice_size *= static_cast<int>(input_x_dims[i]);
81+
int64_t slice_size = 1;
82+
for (int64_t i = dim + 1; i < input_x_dims.size(); i++) {
83+
slice_size *= input_x_dims[i];
8484
}
8585

8686
int64_t numel = x.numel();
@@ -111,12 +111,12 @@ void CrossGradKernel(const Context &dev_ctx,
111111
dev_ctx.template Alloc<T>(output_x_grad);
112112
dev_ctx.template Alloc<T>(output_y_grad);
113113

114-
for (auto i = 0; i < outer_loops; i++) {
115-
for (auto j = 0; j < 3; j++) {
116-
auto dst_pos = (3 * i + j) * slice_size;
117-
auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
118-
auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;
119-
for (auto k = 0; k < slice_size; k++) {
114+
for (int64_t i = 0; i < outer_loops; i++) {
115+
for (int64_t j = 0; j < 3; j++) {
116+
int64_t dst_pos = (3 * i + j) * slice_size;
117+
int64_t in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
118+
int64_t in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;
119+
for (int64_t k = 0; k < slice_size; k++) {
120120
out_dx_vec[dst_pos + k] =
121121
input_dout_vec[in_pos2 + k] * input_y_vec[in_pos1 + k] -
122122
input_dout_vec[in_pos1 + k] * input_y_vec[in_pos2 + k];

paddle/phi/kernels/cpu/cross_kernel.cc

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -75,13 +75,13 @@ void CrossKernel(const Context& dev_ctx,
7575
dev_ctx.template Alloc<T>(output);
7676
return;
7777
}
78-
auto outer_loops = 1;
78+
int64_t outer_loops = 1;
7979
for (auto i = 0; i < dim; i++) {
80-
outer_loops *= static_cast<int>(input_x_dims[i]);
80+
outer_loops *= input_x_dims[i];
8181
}
82-
auto slice_size = 1;
82+
int64_t slice_size = 1;
8383
for (auto i = dim + 1; i < input_x_dims.size(); i++) {
84-
slice_size *= static_cast<int>(input_x_dims[i]);
84+
slice_size *= input_x_dims[i];
8585
}
8686

8787
std::vector<T> input_x_vec, input_y_vec;
@@ -91,13 +91,13 @@ void CrossKernel(const Context& dev_ctx,
9191

9292
dev_ctx.template Alloc<T>(output);
9393

94-
for (auto i = 0; i < outer_loops; i++) {
95-
for (auto j = 0; j < 3; j++) {
96-
auto dst_pos = (3 * i + j) * slice_size;
97-
auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
98-
auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;
94+
for (int64_t i = 0; i < outer_loops; i++) {
95+
for (int64_t j = 0; j < 3; j++) {
96+
int64_t dst_pos = (3 * i + j) * slice_size;
97+
int64_t in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
98+
int64_t in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;
9999

100-
for (auto k = 0; k < slice_size; k++) {
100+
for (int64_t k = 0; k < slice_size; k++) {
101101
out_vec[dst_pos + k] =
102102
input_x_vec[in_pos1 + k] * input_y_vec[in_pos2 + k] -
103103
input_x_vec[in_pos2 + k] * input_y_vec[in_pos1 + k];

paddle/phi/kernels/gpu/cross_grad_kernel.cu

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -25,20 +25,21 @@
2525
namespace phi {
2626

2727
template <typename T>
28-
__global__ void CrossGrad(const T* x,
29-
const T* y,
30-
const T* out,
31-
T* out_dx,
32-
T* out_dy,
33-
const int64_t stride,
34-
const int64_t N,
35-
phi::funcs::IndexCalculator<int> index_calculator) {
28+
__global__ void CrossGrad(
29+
const T* x,
30+
const T* y,
31+
const T* out,
32+
T* out_dx,
33+
T* out_dy,
34+
const int64_t stride,
35+
const int64_t N,
36+
phi::funcs::IndexCalculator<int64_t> index_calculator) {
3637
CUDA_KERNEL_LOOP(i, N) {
3738
int64_t offset = index_calculator(i);
3839

39-
auto pos0 = offset + 0 * stride;
40-
auto pos1 = offset + 1 * stride;
41-
auto pos2 = offset + 2 * stride;
40+
int64_t pos0 = offset + 0 * stride;
41+
int64_t pos1 = offset + 1 * stride;
42+
int64_t pos2 = offset + 2 * stride;
4243

4344
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
4445

@@ -168,7 +169,7 @@ void CrossGradKernel(const Context& dev_ctx,
168169
const auto* input_out_grad_data = input_out_grad.data<T>();
169170
auto* output_x_grad_data = dev_ctx.template Alloc<T>(x_grad);
170171
auto* output_y_grad_data = dev_ctx.template Alloc<T>(y_grad);
171-
auto index_calculator = phi::funcs::IndexCalculator<int>(
172+
auto index_calculator = phi::funcs::IndexCalculator<int64_t>(
172173
merged_dims.size() - 1, cal_dims, left_strides, full_strides);
173174

174175
backends::gpu::GpuLaunchConfig config =

paddle/phi/kernels/gpu/cross_kernel.cu

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -29,13 +29,13 @@ __global__ void Cross(const T* x,
2929
T* out,
3030
const int64_t stride,
3131
const int64_t N,
32-
phi::funcs::IndexCalculator<int> index_calculator) {
32+
phi::funcs::IndexCalculator<int64_t> index_calculator) {
3333
CUDA_KERNEL_LOOP(i, N) {
3434
int64_t offset = index_calculator(i);
3535

36-
auto pos0 = offset + 0 * stride;
37-
auto pos1 = offset + 1 * stride;
38-
auto pos2 = offset + 2 * stride;
36+
int64_t pos0 = offset + 0 * stride;
37+
int64_t pos1 = offset + 1 * stride;
38+
int64_t pos2 = offset + 2 * stride;
3939

4040
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
4141

@@ -149,7 +149,7 @@ void CrossKernel(const Context& dev_ctx,
149149
const auto* input_x_data = input_x.data<T>();
150150
const auto* input_y_data = input_y.data<T>();
151151
auto* out_data = dev_ctx.template Alloc<T>(out);
152-
auto index_calculator = phi::funcs::IndexCalculator<int>(
152+
auto index_calculator = phi::funcs::IndexCalculator<int64_t>(
153153
merged_dims.size() - 1, cal_dims, left_strides, full_strides);
154154

155155
int64_t numel = x.numel();

0 commit comments

Comments
 (0)