Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions paddle/phi/kernels/cpu/bincount_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ template <typename Context, typename T, typename InputT>
void BincountInner(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& weights,
int minlength,
int64_t minlength,
DenseTensor* out) {
const DenseTensor* input = &x;
DenseTensor* output = out;
Expand All @@ -48,7 +48,7 @@ void BincountInner(const Context& dev_ctx,
int64_t output_size = static_cast<int64_t>(*std::max_element(
input_data, input_data + input_numel)) +
1L;
output_size = std::max(output_size, static_cast<int64_t>(minlength));
output_size = std::max(output_size, minlength);

phi::DDim out_dim{output_size};
output->Resize(out_dim);
Expand Down Expand Up @@ -89,7 +89,7 @@ void BincountKernel(const Context& dev_ctx,
const paddle::optional<DenseTensor>& weights,
const Scalar& minlength,
DenseTensor* out) {
int int_minlength = minlength.to<int>();
int64_t int_minlength = minlength.to<int64_t>();
PADDLE_ENFORCE_GE(int_minlength,
0,
common::errors::InvalidArgument(
Expand Down
106 changes: 77 additions & 29 deletions paddle/phi/kernels/gpu/bincount_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,64 @@ namespace phi {

using phi::PADDLE_CUDA_NUM_THREADS;

inline int GET_BLOCKS(const int N) {
inline int64_t GET_BLOCKS(const int64_t N) {
return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
}

template <typename T>
__global__ void KernelReduceMinMax(const T* input,
int64_t numel,
T* min_out,
T* max_out) {
__shared__ T smin[PADDLE_CUDA_NUM_THREADS];
__shared__ T smax[PADDLE_CUDA_NUM_THREADS];
int tid = threadIdx.x;
int64_t global_thread_id =
static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

T local_min = std::numeric_limits<T>::max();
T local_max = std::numeric_limits<T>::lowest();

for (int64_t i = global_thread_id; i < numel; i += stride) {
T val = input[i];
local_min = min(local_min, val);
local_max = max(local_max, val);
}

smin[tid] = local_min;
smax[tid] = local_max;
__syncthreads();

for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
if (tid < offset) {
smin[tid] = min(smin[tid], smin[tid + offset]);
smax[tid] = max(smax[tid], smax[tid + offset]);
}
__syncthreads();
}

if (tid == 0) {
phi::CudaAtomicMin(min_out, smin[0]);
phi::CudaAtomicMax(max_out, smax[0]);
}
}

template <typename T, typename InputT, typename OutT>
__global__ void KernelBincount(const InputT* input,
const int total_elements,
const int64_t total_elements,
const bool has_weights,
const T* weights,
OutT* output) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < total_elements) {
int64_t global_tid =
static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
for (int64_t i = global_tid; i < total_elements; i += stride) {
InputT index = input[i];
if (!has_weights) {
phi::CudaAtomicAdd(&output[input[tid]], 1L);
phi::CudaAtomicAdd(&output[index], 1L);
} else {
phi::CudaAtomicAdd(&output[input[tid]], static_cast<OutT>(weights[tid]));
phi::CudaAtomicAdd(&output[index], static_cast<OutT>(weights[i]));
}
}
}
Expand All @@ -48,39 +90,45 @@ template <typename Context, typename T, typename InputT>
void BincountCUDAInner(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& weights,
int minlength,
int64_t minlength,
DenseTensor* out) {
const DenseTensor* input = &x;
DenseTensor* output = out;
const InputT* input_data = input->data<InputT>();

const int input_numel = input->numel();
int64_t input_numel = static_cast<int64_t>(input->numel());

if (input_data == nullptr) {
phi::DDim out_dim{0};
output->Resize(out_dim);
dev_ctx.template Alloc<T>(output);
return;
}
auto input_x = EigenVector<InputT>::Flatten(*input);
DenseTensor input_min_t, input_max_t;
input_max_t.Resize({1});
auto* input_max_data = dev_ctx.template Alloc<InputT>(&input_max_t);
input_min_t.Resize({1});
auto* input_min_data = dev_ctx.template Alloc<InputT>(&input_min_t);

auto input_max_scala = EigenScalar<InputT>::From(input_max_t);
auto input_min_scala = EigenScalar<InputT>::From(input_min_t);
DenseTensor input_min_max_cpu;
input_min_max_cpu.Resize({2});
auto* input_min_max_cpu_data =
dev_ctx.template HostAlloc<InputT>(&input_min_max_cpu);
input_min_max_cpu.data<InputT>()[0] = std::numeric_limits<InputT>::max();
input_min_max_cpu.data<InputT>()[1] = std::numeric_limits<InputT>::lowest();

DenseTensor input_min_max_t;
input_min_max_t.Resize({2});
auto* input_min_max_data = dev_ctx.template Alloc<InputT>(&input_min_max_t);

phi::Copy(
dev_ctx, input_min_max_cpu, dev_ctx.GetPlace(), true, &input_min_max_t);

auto* place = dev_ctx.eigen_device();
input_max_scala.device(*place) = input_x.maximum();
input_min_scala.device(*place) = input_x.minimum();
int64_t max_grid_x = dev_ctx.GetCUDAMaxGridDimSize()[0];
int64_t num_blocks = std::min(GET_BLOCKS(input_numel), max_grid_x);
KernelReduceMinMax<InputT>
<<<num_blocks, PADDLE_CUDA_NUM_THREADS, 0, dev_ctx.stream()>>>(
input_data, input_numel, input_min_max_data, input_min_max_data + 1);

DenseTensor input_min_cpu, input_max_cpu;
phi::Copy(dev_ctx, input_min_t, phi::CPUPlace(), true, &input_min_cpu);
phi::Copy(dev_ctx, input_max_t, phi::CPUPlace(), true, &input_max_cpu);
phi::Copy(
dev_ctx, input_min_max_t, phi::CPUPlace(), true, &input_min_max_cpu);

InputT input_min = input_min_cpu.data<InputT>()[0];
InputT input_min = input_min_max_cpu.data<InputT>()[0];

PADDLE_ENFORCE_GE(
input_min,
Expand All @@ -89,9 +137,9 @@ void BincountCUDAInner(const Context& dev_ctx,
"The elements in input tensor must be non-negative ints"));

int64_t output_size =
static_cast<int64_t>(input_max_cpu.data<InputT>()[0]) + 1L;
static_cast<int64_t>(input_min_max_cpu.data<InputT>()[1]) + 1L;

output_size = std::max(output_size, static_cast<int64_t>(minlength));
output_size = std::max(output_size, minlength);
phi::DDim out_dim{output_size};
output->Resize(out_dim);

Expand All @@ -106,7 +154,7 @@ void BincountCUDAInner(const Context& dev_ctx,
dev_ctx, output, static_cast<int64_t>(0));

KernelBincount<T, InputT, int64_t>
<<<GET_BLOCKS(input_numel), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
<<<num_blocks, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
input_data, input_numel, has_weights, weights_data, output_data);
} else {
if (weights->dtype() == DataType::FLOAT32) {
Expand All @@ -115,14 +163,14 @@ void BincountCUDAInner(const Context& dev_ctx,
dev_ctx, output, static_cast<float>(0));

KernelBincount<T, InputT, float>
<<<GET_BLOCKS(input_numel), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
<<<num_blocks, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
input_data, input_numel, has_weights, weights_data, output_data);
} else {
double* output_data = dev_ctx.template Alloc<double>(output);
phi::funcs::SetConstant<Context, double>()(
dev_ctx, output, static_cast<double>(0));
KernelBincount<T, InputT, double>
<<<GET_BLOCKS(input_numel), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
<<<num_blocks, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
input_data, input_numel, has_weights, weights_data, output_data);
}
}
Expand All @@ -134,7 +182,7 @@ void BincountKernel(const Context& dev_ctx,
const paddle::optional<DenseTensor>& weights,
const Scalar& minlength,
DenseTensor* out) {
int int_minlength = minlength.to<int>();
int64_t int_minlength = minlength.to<int64_t>();
PADDLE_ENFORCE_GE(int_minlength,
0,
common::errors::InvalidArgument(
Expand Down
12 changes: 12 additions & 0 deletions test/legacy_test/test_bincount_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ def test_dygraph(self):
msg='bincount output is wrong, out =' + str(actual.numpy()),
)

def test_dygraph_cpu(self):
with base.dygraph.guard():
paddle.device.set_device('cpu')
inputs_np = np.array([0, 1, 1, 3, 2, 1, 7]).astype(np.int64)
inputs = paddle.to_tensor(inputs_np)
actual = paddle.bincount(inputs)
expected = np.bincount(inputs)
self.assertTrue(
(actual.numpy() == expected).all(),
msg='bincount output is wrong, out =' + str(actual.numpy()),
)


class TestBincountOpError(unittest.TestCase):
"""Test bincount op error."""
Expand Down
Loading