Skip to content

Commit

Permalink
add unittest (#36511)
Browse files Browse the repository at this point in the history
  • Loading branch information
pangyoki committed Oct 27, 2021
1 parent 7cb7535 commit e0874eb
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 34 deletions.
67 changes: 33 additions & 34 deletions paddle/fluid/operators/multinomial_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,33 @@ namespace operators {

// Divides every element of `in_data` by the sum of its row and stores the
// result in `norm_probs`, so that each row becomes a probability distribution.
//
// Parameters:
//   norm_probs         output, num_distributions * num_categories elements.
//   in_data            input weights, same element count as `norm_probs`.
//   sum_rows           per-row sums, indexed by row (flat index / num_categories).
//   num_distributions  number of rows.
//   num_categories     number of elements per row.
//
// Each thread handles at most one element. The flat index is derived from
// threadIdx.x, blockIdx.x and blockIdx.y, and threads whose index falls past
// the last element do nothing, so the grid may be over-provisioned.
// Device-side checks reject negative inputs and non-positive row sums.
template <typename T>
__global__ void NormalizeProbability(T* norm_probs, const T* in_data,
                                     T* sum_rows, int64_t num_distributions,
                                     int64_t num_categories) {
  // Compute the flat index in 64-bit arithmetic: the built-in index variables
  // are 32-bit unsigned, and their products can wrap for large inputs.
  const int64_t block_size = static_cast<int64_t>(blockDim.x);
  const int64_t id = static_cast<int64_t>(threadIdx.x) +
                     static_cast<int64_t>(blockIdx.x) * block_size +
                     static_cast<int64_t>(blockIdx.y) *
                         static_cast<int64_t>(gridDim.x) * block_size;
  if (id < num_distributions * num_categories) {
    PADDLE_ENFORCE(
        in_data[id] >= 0.0,
        "The input of multinomial distribution should be >= 0, but got %f.",
        in_data[id]);
    // Row that owns this element; used to look up the row's sum.
    const int64_t row_id = id / num_categories;
    PADDLE_ENFORCE(sum_rows[row_id] > 0.0,
                   "The sum of one multinomial distribution probability should "
                   "be > 0, but got %f.",
                   sum_rows[row_id]);
    norm_probs[id] = in_data[id] / sum_rows[row_id];
  }
}

// Writes the inclusive prefix sum of each row of `norm_probs_data` into the
// matching row of `cumulative_probs`.
//
// Parameters:
//   norm_probs_data    input, num_distributions rows of num_categories values.
//   num_distributions  number of rows.
//   num_categories     number of elements per row.
//   cumulative_probs   output, same layout as `norm_probs_data`.
//
// The row is selected by blockIdx.x only, so every thread of a block runs the
// same scan on the same row.
// NOTE(review): the launch site is not visible here; presumably the kernel is
// launched with one thread per block and one block per row -- confirm.
template <typename T>
__global__ void GetCumulativeProbs(T* norm_probs_data,
                                   int64_t num_distributions,
                                   int64_t num_categories,
                                   T* cumulative_probs) {
  const int64_t id = static_cast<int64_t>(blockIdx.x);
  // Blocks beyond the last row have nothing to scan.
  if (id >= num_distributions) {
    return;
  }
  const int64_t row_offset = id * num_categories;
  thrust::inclusive_scan(thrust::device, norm_probs_data + row_offset,
                         norm_probs_data + row_offset + num_categories,
                         cumulative_probs + row_offset);
}

template <typename T>
Expand Down Expand Up @@ -108,23 +110,19 @@ __global__ void sampleMultinomialWithReplacement(
// use binary search to get the selected category sample id.
// let cumulative_probs[id-1] < rng_data < cumulative_probs[id].

int idx = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;

// for every distribution
for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
// for every sample
for (int sample = blockIdx.x * blockDim.x + threadIdx.x;
sample < num_samples; sample += blockDim.x * gridDim.x) {
T rng_number = rng_data[sample + dist * num_samples];

// Find the bucket that a uniform random number lies in
int selected_category = binarySearchFunctor<T>(
cumulative_probs + dist * num_categories,
norm_probs_data + dist * num_categories, num_categories, rng_number);

out_data[sample + dist * num_samples] = selected_category;
}
int dist = blockIdx.y;
// for every sample
int sample = blockIdx.x * blockDim.x + threadIdx.x;
if (sample < num_samples) {
T rng_number = rng_data[sample + dist * num_samples];

// Find the bucket that a uniform random number lies in
int selected_category = binarySearchFunctor<T>(
cumulative_probs + dist * num_categories,
norm_probs_data + dist * num_categories, num_categories, rng_number);

out_data[sample + dist * num_samples] = selected_category;
}
}

Expand Down Expand Up @@ -215,10 +213,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>

// number of threads in a block is min(num_categories, 512)
dim3 block_norm(num_categories < 512 ? num_categories : 512);
dim3 grid_norm((num_categories - 1) / block_norm.x + 1, num_distributions);
dim3 grid_norm((num_distributions * num_categories - 1) / block_norm.x + 1);
NormalizeProbability<
T><<<grid_norm, block_norm, 0, ctx.cuda_device_context().stream()>>>(
norm_probs_data, in_data, sum_rows_data);
norm_probs_data, in_data, sum_rows_data, num_distributions,
num_categories);

// Get cumulative probability of each distribution. It's the same function
// of
Expand Down
8 changes: 8 additions & 0 deletions python/paddle/fluid/tests/unittests/test_multinomial_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ def test_dygraph3(self):
"replacement is False. categories can't be sampled repeatedly")
paddle.enable_static()

def test_dygraph4(self):
    """Smoke-test sampling from a 2800-category distribution in dygraph mode.

    The test only checks that sampling completes without raising; it does
    not inspect the sampled values.
    """
    paddle.disable_static()
    try:
        logits = -1 * paddle.ones([2800])
        # Categorical.sample API will call multinomial op with replacement=True
        cat = paddle.distribution.Categorical(logits.exp())
        cat.sample([1])
    finally:
        # Restore static mode even if sampling fails, so later tests are not
        # left running in dygraph mode.
        paddle.enable_static()

def test_static(self):
paddle.enable_static()
startup_program = fluid.Program()
Expand Down

1 comment on commit e0874eb

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulations! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.