
Commit 392f3a9

[API] fix paddle.cross with big tensor
1 parent f07ca44 commit 392f3a9

5 files changed (+140 −86 lines)
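The fix is the same across all five files: loop bounds and element offsets for paddle.cross were computed in 32-bit int, which wraps once a slice product or the element count of a big tensor exceeds INT32_MAX. Here is a minimal sketch of the failure mode, assuming a hypothetical shape (3, 2048, 1024, 1024) that is not taken from the commit:

```cpp
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // Hypothetical "big tensor" shape for paddle.cross along axis 0:
  // (3, 2048, 1024, 1024). slice_size = 2048 * 1024 * 1024 = 2^31,
  // which is one past INT32_MAX.
  const int64_t dims[] = {3, 2048, 1024, 1024};

  int64_t slice_size = 1;
  for (int i = 1; i < 4; i++) slice_size *= dims[i];

  std::cout << "slice_size           = " << slice_size << "\n";  // 2147483648
  std::cout << "INT32_MAX            = "
            << std::numeric_limits<int32_t>::max() << "\n";      // 2147483647
  // Truncating to 32 bits, as the old static_cast<int> accumulation
  // effectively did, corrupts the index:
  std::cout << "truncated to int32_t = "
            << static_cast<int32_t>(slice_size) << "\n";         // -2147483648
}
```

The CPU kernels fix this by widening the index accumulators to int64_t; the GPU kernels additionally template the index type so that small tensors keep the cheaper 32-bit indexing.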

paddle/phi/kernels/cpu/cross_grad_kernel.cc

Lines changed: 12 additions & 12 deletions
```diff
@@ -74,13 +74,13 @@ void CrossGradKernel(const Context &dev_ctx,
                         "But received: Input(X/Y).dims() == [%s].",
                         input_x_dims));
   }
-  auto outer_loops = 1;
-  for (auto i = 0; i < dim; i++) {
-    outer_loops *= static_cast<int>(input_x_dims[i]);
+  int64_t outer_loops = 1;
+  for (int i = 0; i < dim; i++) {
+    outer_loops *= input_x_dims[i];
   }
-  auto slice_size = 1;
-  for (auto i = dim + 1; i < input_x_dims.size(); i++) {
-    slice_size *= static_cast<int>(input_x_dims[i]);
+  int64_t slice_size = 1;
+  for (int i = dim + 1; i < input_x_dims.size(); i++) {
+    slice_size *= input_x_dims[i];
   }

   int64_t numel = x.numel();
@@ -111,12 +111,12 @@ void CrossGradKernel(const Context &dev_ctx,
   dev_ctx.template Alloc<T>(output_x_grad);
   dev_ctx.template Alloc<T>(output_y_grad);

-  for (auto i = 0; i < outer_loops; i++) {
-    for (auto j = 0; j < 3; j++) {
-      auto dst_pos = (3 * i + j) * slice_size;
-      auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
-      auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;
-      for (auto k = 0; k < slice_size; k++) {
+  for (int64_t i = 0; i < outer_loops; i++) {
+    for (int64_t j = 0; j < 3; j++) {
+      int64_t dst_pos = (3 * i + j) * slice_size;
+      int64_t in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
+      int64_t in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;
+      for (int64_t k = 0; k < slice_size; k++) {
         out_dx_vec[dst_pos + k] =
             input_dout_vec[in_pos2 + k] * input_y_vec[in_pos1 + k] -
             input_dout_vec[in_pos1 + k] * input_y_vec[in_pos2 + k];
```
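The loop above addresses elements at (3 * i + j) * slice_size + k, so the widened int64_t locals are what keep dst_pos, in_pos1, and in_pos2 from wrapping. A standalone sketch of the same skeleton, simplified from the kernel, with plain vectors standing in for tensor buffers:

```cpp
#include <cstdint>
#include <vector>

// Simplified sketch of the CrossGradKernel CPU loop above (dx = dout x y),
// with every index held in int64_t so the products cannot wrap.
template <typename T>
void CrossGradCpuSketch(const std::vector<T>& dout, const std::vector<T>& y,
                        std::vector<T>& dx, int64_t outer_loops,
                        int64_t slice_size) {
  for (int64_t i = 0; i < outer_loops; i++) {
    for (int64_t j = 0; j < 3; j++) {
      // 64-bit products stay valid even when
      // outer_loops * 3 * slice_size exceeds INT32_MAX.
      int64_t dst_pos = (3 * i + j) * slice_size;
      int64_t in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
      int64_t in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;
      for (int64_t k = 0; k < slice_size; k++) {
        dx[dst_pos + k] = dout[in_pos2 + k] * y[in_pos1 + k] -
                          dout[in_pos1 + k] * y[in_pos2 + k];
      }
    }
  }
}
```

The forward kernel in cross_kernel.cc below receives the identical widening.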

paddle/phi/kernels/cpu/cross_kernel.cc

Lines changed: 10 additions & 10 deletions
```diff
@@ -75,13 +75,13 @@ void CrossKernel(const Context& dev_ctx,
     dev_ctx.template Alloc<T>(output);
     return;
   }
-  auto outer_loops = 1;
+  int64_t outer_loops = 1;
   for (auto i = 0; i < dim; i++) {
-    outer_loops *= static_cast<int>(input_x_dims[i]);
+    outer_loops *= input_x_dims[i];
   }
-  auto slice_size = 1;
+  int64_t slice_size = 1;
   for (auto i = dim + 1; i < input_x_dims.size(); i++) {
-    slice_size *= static_cast<int>(input_x_dims[i]);
+    slice_size *= input_x_dims[i];
   }

   std::vector<T> input_x_vec, input_y_vec;
@@ -91,13 +91,13 @@ void CrossKernel(const Context& dev_ctx,

   dev_ctx.template Alloc<T>(output);

-  for (auto i = 0; i < outer_loops; i++) {
-    for (auto j = 0; j < 3; j++) {
-      auto dst_pos = (3 * i + j) * slice_size;
-      auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
-      auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;
+  for (int64_t i = 0; i < outer_loops; i++) {
+    for (int64_t j = 0; j < 3; j++) {
+      int64_t dst_pos = (3 * i + j) * slice_size;
+      int64_t in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
+      int64_t in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;

-      for (auto k = 0; k < slice_size; k++) {
+      for (int64_t k = 0; k < slice_size; k++) {
         out_vec[dst_pos + k] =
             input_x_vec[in_pos1 + k] * input_y_vec[in_pos2 + k] -
             input_x_vec[in_pos2 + k] * input_y_vec[in_pos1 + k];
```

paddle/phi/kernels/gpu/class_center_sample_kernel.cu

Lines changed: 6 additions & 4 deletions
```diff
@@ -41,12 +41,14 @@ namespace cub = hipcub;
 #include "paddle/phi/core/kernel_registry.h"

 namespace phi {
-#define CUDA_KERNEL_LOOP(i, n)                               \
-  for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x,    \
-               step = blockDim.x * gridDim.x;                \
-       i < (n);                                              \
+#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type)              \
+  for (index_type i = blockIdx.x * blockDim.x + threadIdx.x, \
+                  step = blockDim.x * gridDim.x;             \
+       i < (n);                                              \
        i += step)

+#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int32_t)
+
 static constexpr int kNumCUDAThreads = 512;
 static constexpr int kNumMaximumNumBlocks = 4096;
```

paddle/phi/kernels/gpu/cross_grad_kernel.cu

Lines changed: 77 additions & 40 deletions
```diff
@@ -24,21 +24,22 @@

 namespace phi {

-template <typename T>
-__global__ void CrossGrad(const T* x,
-                          const T* y,
-                          const T* out,
-                          T* out_dx,
-                          T* out_dy,
-                          const int64_t stride,
-                          const int64_t N,
-                          phi::funcs::IndexCalculator<int> index_calculator) {
-  CUDA_KERNEL_LOOP(i, N) {
-    int64_t offset = index_calculator(i);
-
-    auto pos0 = offset + 0 * stride;
-    auto pos1 = offset + 1 * stride;
-    auto pos2 = offset + 2 * stride;
+template <typename T, typename IndexType>
+__global__ void CrossGrad(
+    const T* x,
+    const T* y,
+    const T* out,
+    T* out_dx,
+    T* out_dy,
+    const IndexType stride,
+    const IndexType N,
+    phi::funcs::IndexCalculator<IndexType> index_calculator) {
+  CUDA_KERNEL_LOOP_TYPE(i, N, IndexType) {
+    IndexType offset = index_calculator(i);
+
+    IndexType pos0 = offset + 0 * stride;
+    IndexType pos1 = offset + 1 * stride;
+    IndexType pos2 = offset + 2 * stride;

     using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

@@ -168,11 +169,10 @@ void CrossGradKernel(const Context& dev_ctx,
   const auto* input_out_grad_data = input_out_grad.data<T>();
   auto* output_x_grad_data = dev_ctx.template Alloc<T>(x_grad);
   auto* output_y_grad_data = dev_ctx.template Alloc<T>(y_grad);
-  auto index_calculator = phi::funcs::IndexCalculator<int>(
-      merged_dims.size() - 1, cal_dims, left_strides, full_strides);

   backends::gpu::GpuLaunchConfig config =
       backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel / 3);
+  constexpr int64_t int_max = std::numeric_limits<int>::max();
   if (IsComplexType(x.dtype())) {
     DenseTensor x_conj, y_conj;
     DenseTensorMeta meta_xy(x.dtype(), x.dims());
@@ -189,30 +189,67 @@ void CrossGradKernel(const Context& dev_ctx,
         input_y_data, numel, input_y_conj_data);
     for_range(functor_x);
     for_range(functor_y);
-
-    CrossGrad<<<config.block_per_grid,
-                config.thread_per_block,
-                0,
-                dev_ctx.stream()>>>(input_x_conj_data,
-                                    input_y_conj_data,
-                                    input_out_grad_data,
-                                    output_x_grad_data,
-                                    output_y_grad_data,
-                                    full_strides[merge_axis],
-                                    numel / 3,
-                                    index_calculator);
+    if (full_strides[merge_axis] * 2 > int_max || numel / 3 > int_max) {
+      auto index_calculator = phi::funcs::IndexCalculator<int64_t>(
+          merged_dims.size() - 1, cal_dims, left_strides, full_strides);
+      CrossGrad<<<config.block_per_grid,
+                  config.thread_per_block,
+                  0,
+                  dev_ctx.stream()>>>(input_x_conj_data,
+                                      input_y_conj_data,
+                                      input_out_grad_data,
+                                      output_x_grad_data,
+                                      output_y_grad_data,
+                                      full_strides[merge_axis],
+                                      numel / 3,
+                                      index_calculator);
+    } else {
+      auto index_calculator = phi::funcs::IndexCalculator<int32_t>(
+          merged_dims.size() - 1, cal_dims, left_strides, full_strides);
+      CrossGrad<<<config.block_per_grid,
+                  config.thread_per_block,
+                  0,
+                  dev_ctx.stream()>>>(
+          input_x_conj_data,
+          input_y_conj_data,
+          input_out_grad_data,
+          output_x_grad_data,
+          output_y_grad_data,
+          static_cast<int32_t>(full_strides[merge_axis]),
+          static_cast<int32_t>(numel / 3),
+          index_calculator);
+    }
   } else {
-    CrossGrad<<<config.block_per_grid,
-                config.thread_per_block,
-                0,
-                dev_ctx.stream()>>>(input_x_data,
-                                    input_y_data,
-                                    input_out_grad_data,
-                                    output_x_grad_data,
-                                    output_y_grad_data,
-                                    full_strides[merge_axis],
-                                    numel / 3,
-                                    index_calculator);
+    if (full_strides[merge_axis] * 2 > int_max || numel / 3 > int_max) {
+      auto index_calculator = phi::funcs::IndexCalculator<int64_t>(
+          merged_dims.size() - 1, cal_dims, left_strides, full_strides);
+      CrossGrad<<<config.block_per_grid,
+                  config.thread_per_block,
+                  0,
+                  dev_ctx.stream()>>>(input_x_data,
+                                      input_y_data,
+                                      input_out_grad_data,
+                                      output_x_grad_data,
+                                      output_y_grad_data,
+                                      full_strides[merge_axis],
+                                      numel / 3,
+                                      index_calculator);
+    } else {
+      auto index_calculator = phi::funcs::IndexCalculator<int32_t>(
+          merged_dims.size() - 1, cal_dims, left_strides, full_strides);
+      CrossGrad<<<config.block_per_grid,
+                  config.thread_per_block,
+                  0,
+                  dev_ctx.stream()>>>(
+          input_x_data,
+          input_y_data,
+          input_out_grad_data,
+          output_x_grad_data,
+          output_y_grad_data,
+          static_cast<int32_t>(full_strides[merge_axis]),
+          static_cast<int32_t>(numel / 3),
+          index_calculator);
+    }
   }
 }
 }  // namespace phi
```
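Both GPU launch sites now choose the index width at runtime, and the guard mirrors what the kernel actually computes: the largest offset touched is offset + 2 * stride and the grid-stride loop runs to N = numel / 3, so both must fit the chosen type. A condensed sketch of the dispatch pattern, assuming a hypothetical CrossToy kernel over a contiguous (3, n) layout with the IndexCalculator elided:

```cuda
#include <cstdint>
#include <limits>

// Toy kernel: one cross product per thread, with the three components
// stride apart in memory (stride = n for a contiguous (3, n) layout).
// The real kernels also map i through a phi::funcs::IndexCalculator.
template <typename T, typename IndexType>
__global__ void CrossToy(const T* x, const T* y, T* out, IndexType stride,
                         IndexType n) {
  for (IndexType i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    IndexType pos0 = i + 0 * stride;
    IndexType pos1 = i + 1 * stride;
    IndexType pos2 = i + 2 * stride;  // largest index touched: i + 2 * stride
    out[pos0] = x[pos1] * y[pos2] - x[pos2] * y[pos1];
    out[pos1] = x[pos2] * y[pos0] - x[pos0] * y[pos2];
    out[pos2] = x[pos0] * y[pos1] - x[pos1] * y[pos0];
  }
}

// Launch-time dispatch with the same guard as the diff.
template <typename T>
void LaunchCrossToy(const T* x, const T* y, T* out, int64_t stride, int64_t n,
                    dim3 grid, dim3 block, cudaStream_t stream) {
  constexpr int64_t int_max = std::numeric_limits<int32_t>::max();
  if (stride * 2 > int_max || n > int_max) {
    CrossToy<T, int64_t><<<grid, block, 0, stream>>>(x, y, out, stride, n);
  } else {
    // 32-bit indices keep address arithmetic cheaper on the common path.
    CrossToy<T, int32_t><<<grid, block, 0, stream>>>(
        x, y, out, static_cast<int32_t>(stride), static_cast<int32_t>(n));
  }
}
```

Under this guard the 64-bit path only engages once numel reaches 3 × (INT32_MAX + 1) = 6,442,450,944 elements or the component stride reaches 2^30; everything smaller stays on the 32-bit fast path.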

paddle/phi/kernels/gpu/cross_kernel.cu

Lines changed: 35 additions & 20 deletions
```diff
@@ -23,19 +23,19 @@

 namespace phi {

-template <typename T>
+template <typename T, typename IndexType>
 __global__ void Cross(const T* x,
                       const T* y,
                       T* out,
-                      const int64_t stride,
-                      const int64_t N,
-                      phi::funcs::IndexCalculator<int> index_calculator) {
-  CUDA_KERNEL_LOOP(i, N) {
-    int64_t offset = index_calculator(i);
+                      const IndexType stride,
+                      const IndexType N,
+                      phi::funcs::IndexCalculator<IndexType> index_calculator) {
+  CUDA_KERNEL_LOOP_TYPE(i, N, IndexType) {
+    IndexType offset = index_calculator(i);

-    auto pos0 = offset + 0 * stride;
-    auto pos1 = offset + 1 * stride;
-    auto pos2 = offset + 2 * stride;
+    IndexType pos0 = offset + 0 * stride;
+    IndexType pos1 = offset + 1 * stride;
+    IndexType pos2 = offset + 2 * stride;

     using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

@@ -149,22 +149,37 @@ void CrossKernel(const Context& dev_ctx,
   const auto* input_x_data = input_x.data<T>();
   const auto* input_y_data = input_y.data<T>();
   auto* out_data = dev_ctx.template Alloc<T>(out);
-  auto index_calculator = phi::funcs::IndexCalculator<int>(
-      merged_dims.size() - 1, cal_dims, left_strides, full_strides);

   int64_t numel = x.numel();
   backends::gpu::GpuLaunchConfig config =
       backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel / 3);

-  Cross<<<config.block_per_grid,
-          config.thread_per_block,
-          0,
-          dev_ctx.stream()>>>(input_x_data,
-                              input_y_data,
-                              out_data,
-                              full_strides[merge_axis],
-                              numel / 3,
-                              index_calculator);
+  constexpr int64_t int_max = std::numeric_limits<int>::max();
+  if (full_strides[merge_axis] * 2 > int_max || numel / 3 > int_max) {
+    auto index_calculator = phi::funcs::IndexCalculator<int64_t>(
+        merged_dims.size() - 1, cal_dims, left_strides, full_strides);
+    Cross<<<config.block_per_grid,
+            config.thread_per_block,
+            0,
+            dev_ctx.stream()>>>(input_x_data,
+                                input_y_data,
+                                out_data,
+                                full_strides[merge_axis],
+                                numel / 3,
+                                index_calculator);
+  } else {
+    auto index_calculator = phi::funcs::IndexCalculator<int32_t>(
+        merged_dims.size() - 1, cal_dims, left_strides, full_strides);
+    Cross<<<config.block_per_grid,
+            config.thread_per_block,
+            0,
+            dev_ctx.stream()>>>(input_x_data,
+                                input_y_data,
+                                out_data,
+                                static_cast<int32_t>(full_strides[merge_axis]),
+                                static_cast<int32_t>(numel / 3),
+                                index_calculator);
+  }
 }
 }  // namespace phi
```
