Skip to content

Commit d8c1ad1

Browse files
committed
[API] isnan isinf isfinite support bigtensor (PaddlePaddle#72517)
* isnan isinf isfinite support bigtensor
* refine
1 parent f179a6a commit d8c1ad1

File tree

1 file changed

+89
-74
lines changed

1 file changed

+89
-74
lines changed

paddle/phi/kernels/impl/isfinite_kernel_impl.h

Lines changed: 89 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ struct IsfiniteFunctor<
7878
const DenseTensor& in,
7979
DenseTensor* output) {
8080
auto* out_data = ctx.template Alloc<bool>(output);
81-
auto num = in.numel();
82-
for (int i = 0; i < num; i++) {
81+
int64_t num = in.numel();
82+
for (int64_t i = 0; i < num; i++) {
8383
out_data[i] = true;
8484
}
8585
}
@@ -95,8 +95,8 @@ struct IsfiniteFunctor<
9595
DenseTensor* output) {
9696
auto* in_a = in.data<T>();
9797
auto* out_data = ctx.template Alloc<bool>(output);
98-
auto num = in.numel();
99-
for (int i = 0; i < num; i++) {
98+
int64_t num = in.numel();
99+
for (int64_t i = 0; i < num; i++) {
100100
const T& a = in_a[i];
101101
out_data[i] = std::isfinite(a);
102102
}
@@ -113,8 +113,8 @@ struct IsfiniteFunctor<
113113
DenseTensor* output) {
114114
auto* in_a = in.data<T>();
115115
auto* out_data = ctx.template Alloc<bool>(output);
116-
auto num = in.numel();
117-
for (int i = 0; i < num; i++) {
116+
int64_t num = in.numel();
117+
for (int64_t i = 0; i < num; i++) {
118118
const T& a = in_a[i];
119119
out_data[i] = phi::dtype::isfinite(a);
120120
}
@@ -131,8 +131,8 @@ struct IsfiniteFunctor<
131131
DenseTensor* output) {
132132
auto* in_a = in.data<T>();
133133
auto* out_data = ctx.template Alloc<bool>(output);
134-
auto num = in.numel();
135-
for (int i = 0; i < num; i++) {
134+
int64_t num = in.numel();
135+
for (int64_t i = 0; i < num; i++) {
136136
const T& a = in_a[i];
137137
out_data[i] = std::isfinite(a.real) && std::isfinite(a.imag);
138138
}
@@ -157,8 +157,8 @@ struct IsnanFunctor<
157157
const DenseTensor& in,
158158
DenseTensor* output) {
159159
auto* out_data = ctx.template Alloc<bool>(output);
160-
auto num = in.numel();
161-
for (int i = 0; i < num; i++) {
160+
int64_t num = in.numel();
161+
for (int64_t i = 0; i < num; i++) {
162162
out_data[i] = false;
163163
}
164164
}
@@ -174,8 +174,8 @@ struct IsnanFunctor<
174174
DenseTensor* output) {
175175
auto* in_a = in.data<T>();
176176
auto* out_data = ctx.template Alloc<bool>(output);
177-
auto num = in.numel();
178-
for (int i = 0; i < num; i++) {
177+
int64_t num = in.numel();
178+
for (int64_t i = 0; i < num; i++) {
179179
const T& a = in_a[i];
180180
out_data[i] = std::isnan(a);
181181
}
@@ -191,8 +191,8 @@ struct IsnanFunctor<phi::CPUContext,
191191
DenseTensor* output) {
192192
auto* in_a = in.data<T>();
193193
auto* out_data = ctx.template Alloc<bool>(output);
194-
auto num = in.numel();
195-
for (int i = 0; i < num; i++) {
194+
int64_t num = in.numel();
195+
for (int64_t i = 0; i < num; i++) {
196196
const T& a = in_a[i];
197197
out_data[i] = phi::dtype::isnan(a);
198198
}
@@ -209,8 +209,8 @@ struct IsnanFunctor<
209209
DenseTensor* output) {
210210
auto* in_a = in.data<T>();
211211
auto* out_data = ctx.template Alloc<bool>(output);
212-
auto num = in.numel();
213-
for (int i = 0; i < num; i++) {
212+
int64_t num = in.numel();
213+
for (int64_t i = 0; i < num; i++) {
214214
const T& a = in_a[i];
215215
out_data[i] = std::isnan(a.real) || std::isnan(a.imag);
216216
}
@@ -236,7 +236,7 @@ struct IsinfFunctor<
236236
DenseTensor* output) {
237237
auto* out_data = ctx.template Alloc<bool>(output);
238238
auto num = in.numel();
239-
for (int i = 0; i < num; i++) {
239+
for (int64_t i = 0; i < num; i++) {
240240
out_data[i] = false;
241241
}
242242
}
@@ -252,8 +252,8 @@ struct IsinfFunctor<
252252
DenseTensor* output) {
253253
auto* in_a = in.data<T>();
254254
auto* out_data = ctx.template Alloc<bool>(output);
255-
auto num = in.numel();
256-
for (int i = 0; i < num; i++) {
255+
int64_t num = in.numel();
256+
for (int64_t i = 0; i < num; i++) {
257257
const T& a = in_a[i];
258258
out_data[i] = std::isinf(a);
259259
}
@@ -269,8 +269,8 @@ struct IsinfFunctor<phi::CPUContext,
269269
DenseTensor* output) {
270270
auto* in_a = in.data<T>();
271271
auto* out_data = ctx.template Alloc<bool>(output);
272-
auto num = in.numel();
273-
for (int i = 0; i < num; i++) {
272+
int64_t num = in.numel();
273+
for (int64_t i = 0; i < num; i++) {
274274
const T& a = in_a[i];
275275
out_data[i] = phi::dtype::isinf(a);
276276
}
@@ -287,8 +287,8 @@ struct IsinfFunctor<
287287
DenseTensor* output) {
288288
auto* in_a = in.data<T>();
289289
auto* out_data = ctx.template Alloc<bool>(output);
290-
auto num = in.numel();
291-
for (int i = 0; i < num; i++) {
290+
int64_t num = in.numel();
291+
for (int64_t i = 0; i < num; i++) {
292292
const T& a = in_a[i];
293293
out_data[i] = std::isinf(a.real) || std::isinf(a.imag);
294294
}
@@ -297,117 +297,117 @@ struct IsinfFunctor<
297297

298298
#if defined(__NVCC__) || defined(__HIPCC__)
299299
/* IsfiniteFunctor */
300-
template <typename T>
300+
template <typename T, typename IndexType>
301301
__global__ void IsfiniteCUDAKernel(
302302
const T* in_data,
303-
int num,
303+
IndexType num,
304304
bool* out_data,
305305
typename std::enable_if<std::is_floating_point<T>::value>::type* = 0) {
306-
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
307-
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
306+
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
307+
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
308308
const T& a = in_data[i];
309309
out_data[i] = isfinite(a);
310310
}
311311
}
312312

313-
template <typename T>
313+
template <typename T, typename IndexType>
314314
__global__ void IsfiniteCUDAKernel(
315315
const T* in_data,
316-
int num,
316+
IndexType num,
317317
bool* out_data,
318318
typename std::enable_if<std::is_integral<T>::value>::type* = 0) {
319-
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
320-
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
319+
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
320+
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
321321
out_data[i] = true;
322322
}
323323
}
324324

325-
template <typename T>
325+
template <typename T, typename IndexType>
326326
__global__ void IsfiniteCUDAKernel(
327327
const T* in_data,
328-
int num,
328+
IndexType num,
329329
bool* out_data,
330330
typename std::enable_if<is_complex64_or_complex128<T>::value>::type* = 0) {
331-
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
332-
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
331+
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
332+
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
333333
const T& a = in_data[i];
334334
out_data[i] = isfinite(a.real) && isfinite(a.imag);
335335
}
336336
}
337337

338338
/* IsnanFunctor */
339-
template <typename T>
339+
template <typename T, typename IndexType>
340340
__global__ void IsnanCUDAKernel(
341341
const T* in_data,
342-
int num,
342+
IndexType num,
343343
bool* out_data,
344344
typename std::enable_if<std::is_floating_point<T>::value>::type* = 0) {
345-
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
346-
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
345+
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
346+
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
347347
const T& a = in_data[i];
348348
out_data[i] = isnan(a);
349349
}
350350
}
351351

352-
template <typename T>
352+
template <typename T, typename IndexType>
353353
__global__ void IsnanCUDAKernel(
354354
const T* in_data,
355-
int num,
355+
IndexType num,
356356
bool* out_data,
357357
typename std::enable_if<std::is_integral<T>::value>::type* = 0) {
358-
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
359-
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
358+
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
359+
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
360360
out_data[i] = false;
361361
}
362362
}
363363

364-
template <typename T>
364+
template <typename T, typename IndexType>
365365
__global__ void IsnanCUDAKernel(
366366
const T* in_data,
367-
int num,
367+
IndexType num,
368368
bool* out_data,
369369
typename std::enable_if<is_complex64_or_complex128<T>::value>::type* = 0) {
370-
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
371-
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
370+
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
371+
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
372372
const T& a = in_data[i];
373373
out_data[i] = isnan(a.real) || isnan(a.imag);
374374
}
375375
}
376376

377377
/* IsinfFunctor */
378-
template <typename T>
378+
template <typename T, typename IndexType>
379379
__global__ void IsinfCUDAKernel(
380380
const T* in_data,
381-
int num,
381+
IndexType num,
382382
bool* out_data,
383383
typename std::enable_if<std::is_floating_point<T>::value>::type* = 0) {
384-
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
385-
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
384+
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
385+
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
386386
const T& a = in_data[i];
387387
out_data[i] = isinf(a);
388388
}
389389
}
390390

391-
template <typename T>
391+
template <typename T, typename IndexType>
392392
__global__ void IsinfCUDAKernel(
393393
const T* in_data,
394-
int num,
394+
IndexType num,
395395
bool* out_data,
396396
typename std::enable_if<std::is_integral<T>::value>::type* = 0) {
397-
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
398-
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
397+
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
398+
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
399399
out_data[i] = false;
400400
}
401401
}
402402

403-
template <typename T>
403+
template <typename T, typename IndexType>
404404
__global__ void IsinfCUDAKernel(
405405
const T* in_data,
406-
int num,
406+
IndexType num,
407407
bool* out_data,
408408
typename std::enable_if<is_complex64_or_complex128<T>::value>::type* = 0) {
409-
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
410-
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
409+
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
410+
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
411411
const T& a = in_data[i];
412412
out_data[i] = isinf(a.real) || isinf(a.imag);
413413
}
@@ -418,14 +418,19 @@ struct IsfiniteFunctor<phi::GPUContext, T> {
418418
void operator()(const phi::GPUContext& dev_ctx,
419419
const DenseTensor& in,
420420
DenseTensor* output) {
421-
int num = in.numel();
421+
int64_t num = in.numel();
422422
const T* in_data = in.data<T>();
423423
bool* out_data = dev_ctx.template Alloc<bool>(output);
424-
int block = 1024;
425-
int grid = (block - 1 + num) / block;
424+
int64_t block = 1024;
425+
int64_t grid = (block - 1 + num) / block;
426426
grid = (grid > block) ? block : grid;
427-
IsfiniteCUDAKernel<T>
428-
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
427+
if (num + block * grid + 1 > std::numeric_limits<unsigned int>::max()) {
428+
IsfiniteCUDAKernel<T, int64_t>
429+
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
430+
} else {
431+
IsfiniteCUDAKernel<T, unsigned int>
432+
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
433+
}
429434
}
430435
};
431436

@@ -434,14 +439,19 @@ struct IsnanFunctor<phi::GPUContext, T> {
434439
void operator()(const phi::GPUContext& dev_ctx,
435440
const DenseTensor& in,
436441
DenseTensor* output) {
437-
int num = in.numel();
442+
int64_t num = in.numel();
438443
const T* in_data = in.data<T>();
439444
bool* out_data = dev_ctx.template Alloc<bool>(output);
440-
int block = 1024;
441-
int grid = (block - 1 + num) / block;
445+
int64_t block = 1024;
446+
int64_t grid = (block - 1 + num) / block;
442447
grid = (grid > block) ? block : grid;
443-
IsnanCUDAKernel<T>
444-
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
448+
if (num + block * grid + 1 > std::numeric_limits<unsigned int>::max()) {
449+
IsnanCUDAKernel<T, int64_t>
450+
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
451+
} else {
452+
IsnanCUDAKernel<T, unsigned int>
453+
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
454+
}
445455
}
446456
};
447457

@@ -450,14 +460,19 @@ struct IsinfFunctor<phi::GPUContext, T> {
450460
void operator()(const phi::GPUContext& dev_ctx,
451461
const DenseTensor& in,
452462
DenseTensor* output) {
453-
int num = in.numel();
463+
int64_t num = in.numel();
454464
const T* in_data = in.data<T>();
455465
bool* out_data = dev_ctx.template Alloc<bool>(output);
456-
int block = 1024;
457-
int grid = (block - 1 + num) / block;
466+
int64_t block = 1024;
467+
int64_t grid = (block - 1 + num) / block;
458468
grid = (grid > block) ? block : grid;
459-
IsinfCUDAKernel<T>
460-
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
469+
if (num + block * grid + 1 > std::numeric_limits<unsigned int>::max()) {
470+
IsinfCUDAKernel<T, int64_t>
471+
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
472+
} else {
473+
IsinfCUDAKernel<T, unsigned int>
474+
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
475+
}
461476
}
462477
};
463478
#endif

0 commit comments

Comments (0)