163 changes: 89 additions & 74 deletions paddle/phi/kernels/impl/isfinite_kernel_impl.h
@@ -78,8 +78,8 @@ struct IsfiniteFunctor<
const DenseTensor& in,
DenseTensor* output) {
auto* out_data = ctx.template Alloc<bool>(output);
auto num = in.numel();
for (int i = 0; i < num; i++) {
int64_t num = in.numel();
for (int64_t i = 0; i < num; i++) {
out_data[i] = true;
}
}
@@ -95,8 +95,8 @@ struct IsfiniteFunctor<
DenseTensor* output) {
auto* in_a = in.data<T>();
auto* out_data = ctx.template Alloc<bool>(output);
auto num = in.numel();
for (int i = 0; i < num; i++) {
int64_t num = in.numel();
for (int64_t i = 0; i < num; i++) {
const T& a = in_a[i];
out_data[i] = std::isfinite(a);
}
@@ -113,8 +113,8 @@ struct IsfiniteFunctor<
DenseTensor* output) {
auto* in_a = in.data<T>();
auto* out_data = ctx.template Alloc<bool>(output);
auto num = in.numel();
for (int i = 0; i < num; i++) {
int64_t num = in.numel();
for (int64_t i = 0; i < num; i++) {
const T& a = in_a[i];
out_data[i] = phi::dtype::isfinite(a);
}
@@ -131,8 +131,8 @@ struct IsfiniteFunctor<
DenseTensor* output) {
auto* in_a = in.data<T>();
auto* out_data = ctx.template Alloc<bool>(output);
auto num = in.numel();
for (int i = 0; i < num; i++) {
int64_t num = in.numel();
for (int64_t i = 0; i < num; i++) {
const T& a = in_a[i];
out_data[i] = std::isfinite(a.real) && std::isfinite(a.imag);
}
@@ -157,8 +157,8 @@ struct IsnanFunctor<
const DenseTensor& in,
DenseTensor* output) {
auto* out_data = ctx.template Alloc<bool>(output);
auto num = in.numel();
for (int i = 0; i < num; i++) {
int64_t num = in.numel();
for (int64_t i = 0; i < num; i++) {
out_data[i] = false;
}
}
@@ -174,8 +174,8 @@ struct IsnanFunctor<
DenseTensor* output) {
auto* in_a = in.data<T>();
auto* out_data = ctx.template Alloc<bool>(output);
auto num = in.numel();
for (int i = 0; i < num; i++) {
int64_t num = in.numel();
for (int64_t i = 0; i < num; i++) {
const T& a = in_a[i];
out_data[i] = std::isnan(a);
}
@@ -191,8 +191,8 @@ struct IsnanFunctor<phi::CPUContext,
DenseTensor* output) {
auto* in_a = in.data<T>();
auto* out_data = ctx.template Alloc<bool>(output);
auto num = in.numel();
for (int i = 0; i < num; i++) {
int64_t num = in.numel();
for (int64_t i = 0; i < num; i++) {
const T& a = in_a[i];
out_data[i] = phi::dtype::isnan(a);
}
@@ -209,8 +209,8 @@ struct IsnanFunctor<
DenseTensor* output) {
auto* in_a = in.data<T>();
auto* out_data = ctx.template Alloc<bool>(output);
auto num = in.numel();
for (int i = 0; i < num; i++) {
int64_t num = in.numel();
for (int64_t i = 0; i < num; i++) {
const T& a = in_a[i];
out_data[i] = std::isnan(a.real) || std::isnan(a.imag);
}
@@ -236,7 +236,7 @@ struct IsinfFunctor<
DenseTensor* output) {
auto* out_data = ctx.template Alloc<bool>(output);
auto num = in.numel();
for (int i = 0; i < num; i++) {
for (int64_t i = 0; i < num; i++) {
out_data[i] = false;
}
}
@@ -252,8 +252,8 @@ struct IsinfFunctor<
DenseTensor* output) {
auto* in_a = in.data<T>();
auto* out_data = ctx.template Alloc<bool>(output);
auto num = in.numel();
for (int i = 0; i < num; i++) {
int64_t num = in.numel();
for (int64_t i = 0; i < num; i++) {
const T& a = in_a[i];
out_data[i] = std::isinf(a);
}
@@ -269,8 +269,8 @@ struct IsinfFunctor<phi::CPUContext,
DenseTensor* output) {
auto* in_a = in.data<T>();
auto* out_data = ctx.template Alloc<bool>(output);
auto num = in.numel();
for (int i = 0; i < num; i++) {
int64_t num = in.numel();
for (int64_t i = 0; i < num; i++) {
const T& a = in_a[i];
out_data[i] = phi::dtype::isinf(a);
}
@@ -287,8 +287,8 @@ struct IsinfFunctor<
DenseTensor* output) {
auto* in_a = in.data<T>();
auto* out_data = ctx.template Alloc<bool>(output);
auto num = in.numel();
for (int i = 0; i < num; i++) {
int64_t num = in.numel();
for (int64_t i = 0; i < num; i++) {
const T& a = in_a[i];
out_data[i] = std::isinf(a.real) || std::isinf(a.imag);
}
@@ -297,117 +297,117 @@ struct IsinfFunctor<

#if defined(__NVCC__) || defined(__HIPCC__)
/* IsfiniteFunctor */
template <typename T>
template <typename T, typename IndexType>
__global__ void IsfiniteCUDAKernel(
const T* in_data,
int num,
IndexType num,
bool* out_data,
typename std::enable_if<std::is_floating_point<T>::value>::type* = 0) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
const T& a = in_data[i];
out_data[i] = isfinite(a);
}
}

template <typename T>
template <typename T, typename IndexType>
__global__ void IsfiniteCUDAKernel(
const T* in_data,
int num,
IndexType num,
bool* out_data,
typename std::enable_if<std::is_integral<T>::value>::type* = 0) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
out_data[i] = true;
}
}

template <typename T>
template <typename T, typename IndexType>
__global__ void IsfiniteCUDAKernel(
const T* in_data,
int num,
IndexType num,
bool* out_data,
typename std::enable_if<is_complex64_or_complex128<T>::value>::type* = 0) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
const T& a = in_data[i];
out_data[i] = isfinite(a.real) && isfinite(a.imag);
}
}

/* IsnanFunctor */
template <typename T>
template <typename T, typename IndexType>
__global__ void IsnanCUDAKernel(
const T* in_data,
int num,
IndexType num,
bool* out_data,
typename std::enable_if<std::is_floating_point<T>::value>::type* = 0) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
const T& a = in_data[i];
out_data[i] = isnan(a);
}
}

template <typename T>
template <typename T, typename IndexType>
__global__ void IsnanCUDAKernel(
const T* in_data,
int num,
IndexType num,
bool* out_data,
typename std::enable_if<std::is_integral<T>::value>::type* = 0) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
out_data[i] = false;
}
}

template <typename T>
template <typename T, typename IndexType>
__global__ void IsnanCUDAKernel(
const T* in_data,
int num,
IndexType num,
bool* out_data,
typename std::enable_if<is_complex64_or_complex128<T>::value>::type* = 0) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
const T& a = in_data[i];
out_data[i] = isnan(a.real) || isnan(a.imag);
}
}

/* IsinfFunctor */
template <typename T>
template <typename T, typename IndexType>
__global__ void IsinfCUDAKernel(
const T* in_data,
int num,
IndexType num,
bool* out_data,
typename std::enable_if<std::is_floating_point<T>::value>::type* = 0) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
const T& a = in_data[i];
out_data[i] = isinf(a);
}
}

template <typename T>
template <typename T, typename IndexType>
__global__ void IsinfCUDAKernel(
const T* in_data,
int num,
IndexType num,
bool* out_data,
typename std::enable_if<std::is_integral<T>::value>::type* = 0) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
out_data[i] = false;
}
}

template <typename T>
template <typename T, typename IndexType>
__global__ void IsinfCUDAKernel(
const T* in_data,
int num,
IndexType num,
bool* out_data,
typename std::enable_if<is_complex64_or_complex128<T>::value>::type* = 0) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
const T& a = in_data[i];
out_data[i] = isinf(a.real) || isinf(a.imag);
}
@@ -418,14 +418,19 @@ struct IsfiniteFunctor<phi::GPUContext, T> {
void operator()(const phi::GPUContext& dev_ctx,
const DenseTensor& in,
DenseTensor* output) {
int num = in.numel();
int64_t num = in.numel();
const T* in_data = in.data<T>();
bool* out_data = dev_ctx.template Alloc<bool>(output);
int block = 1024;
int grid = (block - 1 + num) / block;
int64_t block = 1024;
int64_t grid = (block - 1 + num) / block;
grid = (grid > block) ? block : grid;
IsfiniteCUDAKernel<T>
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
if (num + block * grid + 1 > std::numeric_limits<unsigned int>::max()) {
IsfiniteCUDAKernel<T, int64_t>
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
} else {
IsfiniteCUDAKernel<T, unsigned int>
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
}
}
};

@@ -434,14 +439,19 @@ struct IsnanFunctor<phi::GPUContext, T> {
void operator()(const phi::GPUContext& dev_ctx,
const DenseTensor& in,
DenseTensor* output) {
int num = in.numel();
int64_t num = in.numel();
const T* in_data = in.data<T>();
bool* out_data = dev_ctx.template Alloc<bool>(output);
int block = 1024;
int grid = (block - 1 + num) / block;
int64_t block = 1024;
int64_t grid = (block - 1 + num) / block;
grid = (grid > block) ? block : grid;
IsnanCUDAKernel<T>
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
if (num + block * grid + 1 > std::numeric_limits<unsigned int>::max()) {
IsnanCUDAKernel<T, int64_t>
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
} else {
IsnanCUDAKernel<T, unsigned int>
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
}
}
};

@@ -450,14 +460,19 @@ struct IsinfFunctor<phi::GPUContext, T> {
void operator()(const phi::GPUContext& dev_ctx,
const DenseTensor& in,
DenseTensor* output) {
int num = in.numel();
int64_t num = in.numel();
const T* in_data = in.data<T>();
bool* out_data = dev_ctx.template Alloc<bool>(output);
int block = 1024;
int grid = (block - 1 + num) / block;
int64_t block = 1024;
int64_t grid = (block - 1 + num) / block;
grid = (grid > block) ? block : grid;
IsinfCUDAKernel<T>
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
if (num + block * grid + 1 > std::numeric_limits<unsigned int>::max()) {
IsinfCUDAKernel<T, int64_t>
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
} else {
IsinfCUDAKernel<T, unsigned int>
<<<grid, block, 0, dev_ctx.stream()>>>(in_data, num, out_data);
}
}
};
#endif
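
A note on the launch-site dispatch added to the GPU functors above: each CUDA kernel is now templated on an IndexType, and the functor falls back to 64-bit indexing only when a 32-bit index plus one grid stride could overflow unsigned int. The sketch below is a minimal, self-contained illustration of that pattern under the same block/grid capping; the kernel and launcher names (FillTrueKernel, LaunchFillTrue) are illustrative placeholders, not part of this patch.

#include <cstdint>
#include <cuda_runtime.h>
#include <limits>

// Grid-stride kernel templated on the index type (placeholder body).
template <typename T, typename IndexType>
__global__ void FillTrueKernel(const T* in_data, IndexType num, bool* out_data) {
  IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
  for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
    out_data[i] = true;  // a real kernel would test isfinite/isnan/isinf on in_data[i]
  }
}

template <typename T>
void LaunchFillTrue(const T* in_data, int64_t num, bool* out_data,
                    cudaStream_t stream) {
  int64_t block = 1024;
  int64_t grid = (num + block - 1) / block;
  grid = (grid > block) ? block : grid;  // cap the grid, as the functors above do
  // The "+ block * grid + 1" margin guarantees that i + stride cannot wrap an
  // unsigned int inside the grid-stride loop, not merely that num itself fits.
  if (num + block * grid + 1 > std::numeric_limits<unsigned int>::max()) {
    FillTrueKernel<T, int64_t>
        <<<grid, block, 0, stream>>>(in_data, num, out_data);
  } else {
    FillTrueKernel<T, unsigned int>
        <<<grid, block, 0, stream>>>(in_data, num, out_data);
  }
}

The same check appears in each of the three GPU functors: the unsigned int instantiation keeps index arithmetic in 32 bits for the common case, while the int64_t instantiation handles tensors with more than about 4 billion elements.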