Skip to content

Commit

Permalink
【Cherry-pick PR 36511】fix out_of_range bug of multinomial op's cuda kernel (#36511) (#36808)
Browse files Browse the repository at this point in the history

Cherry-pick PR #36511
  • Loading branch information
pangyoki authored Oct 28, 2021
1 parent e3db65d commit d8ffb26
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 34 deletions.
67 changes: 33 additions & 34 deletions paddle/fluid/operators/multinomial_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,33 @@ namespace operators {

// NormalizeProbability: divides each input value by the sum of its row and
// stores the quotient in norm_probs, after enforcing that the input value is
// >= 0 and that the row sum is > 0.
//
// NOTE(review): this text is a diff rendering whose +/- markers were stripped,
// so two versions of the signature tail and of the body appear back to back.
// The comments below mark which lines appear to belong to which version.
template <typename T>
__global__ void NormalizeProbability(T* norm_probs, const T* in_data,
// Older signature tail: no extent parameters.
T* sum_rows) {
// Newer signature tail: adds the two extents used by the bounds check below.
T* sum_rows, int64_t num_distributions,
int64_t num_categories) {
// Flat element index built from the x and y block coordinates (both versions).
// NOTE(review): id is a 32-bit int while the extents are int64_t - confirm the
// element count always fits.
int id = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;
// Older body: no guard on id; the row sum is selected with blockIdx.y.
PADDLE_ENFORCE(
in_data[id] >= 0.0,
"The input of multinomial distribution should be >= 0, but got %f.",
in_data[id]);
PADDLE_ENFORCE(sum_rows[blockIdx.y] > 0.0,
"The sum of one multinomial distribution probability should "
"be > 0, but got %f.",
sum_rows[blockIdx.y]);
norm_probs[id] = in_data[id] / sum_rows[blockIdx.y];
// Newer body: the same two checks and the same division, but only for
// id < num_distributions * num_categories, with the row derived from the
// flat index instead of blockIdx.y.
if (id < num_distributions * num_categories) {
PADDLE_ENFORCE(
in_data[id] >= 0.0,
"The input of multinomial distribution should be >= 0, but got %f.",
in_data[id]);
// Row that element id belongs to.
int64_t row_id = id / num_categories;
PADDLE_ENFORCE(sum_rows[row_id] > 0.0,
"The sum of one multinomial distribution probability should "
"be > 0, but got %f.",
sum_rows[row_id]);
norm_probs[id] = in_data[id] / sum_rows[row_id];
}
}

// GetCumulativeProbs: for a row of num_categories values in norm_probs_data,
// writes the inclusive prefix sum of that row to the same offsets in
// cumulative_probs, using thrust::inclusive_scan with thrust::device.
//
// NOTE(review): this text is a diff rendering whose +/- markers were stripped,
// so two versions of the body appear back to back. The comments below mark
// which lines appear to belong to which version.
template <typename T>
__global__ void GetCumulativeProbs(T* norm_probs_data,
int64_t num_distributions,
int64_t num_categories,
T* cumulative_probs) {
// Older body: rows are visited starting at blockIdx.x in steps of gridDim.x,
// bounded by num_distributions.
for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) {
thrust::inclusive_scan(thrust::device,
norm_probs_data + id * num_categories,
norm_probs_data + (id + 1) * num_categories,
cumulative_probs + id * num_categories);
}
// Newer body: the row index is blockIdx.x and exactly one row is scanned.
// NOTE(review): id is not compared against num_distributions here, so this
// presumably relies on the launch using num_distributions blocks in x - the
// launch site is not visible in this excerpt; confirm.
int id = blockIdx.x;
thrust::inclusive_scan(thrust::device, norm_probs_data + id * num_categories,
norm_probs_data + (id + 1) * num_categories,
cumulative_probs + id * num_categories);
}

template <typename T>
Expand Down Expand Up @@ -108,23 +110,19 @@ __global__ void sampleMultinomialWithReplacement(
// use binary search to get the selected category sample id.
// let cumulative_probs[id-1] < rng_data < cumulative_probs[id].

int idx = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;

// for every distribution
for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
// for every sample
for (int sample = blockIdx.x * blockDim.x + threadIdx.x;
sample < num_samples; sample += blockDim.x * gridDim.x) {
T rng_number = rng_data[sample + dist * num_samples];

// Find the bucket that a uniform random number lies in
int selected_category = binarySearchFunctor<T>(
cumulative_probs + dist * num_categories,
norm_probs_data + dist * num_categories, num_categories, rng_number);

out_data[sample + dist * num_samples] = selected_category;
}
int dist = blockIdx.y;
// for every sample
int sample = blockIdx.x * blockDim.x + threadIdx.x;
if (sample < num_samples) {
T rng_number = rng_data[sample + dist * num_samples];

// Find the bucket that a uniform random number lies in
int selected_category = binarySearchFunctor<T>(
cumulative_probs + dist * num_categories,
norm_probs_data + dist * num_categories, num_categories, rng_number);

out_data[sample + dist * num_samples] = selected_category;
}
}

Expand Down Expand Up @@ -215,10 +213,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>

// number of threads in a block is min(num_categories, 512)
dim3 block_norm(num_categories < 512 ? num_categories : 512);
dim3 grid_norm((num_categories - 1) / block_norm.x + 1, num_distributions);
dim3 grid_norm((num_distributions * num_categories - 1) / block_norm.x + 1);
NormalizeProbability<
T><<<grid_norm, block_norm, 0, ctx.cuda_device_context().stream()>>>(
norm_probs_data, in_data, sum_rows_data);
norm_probs_data, in_data, sum_rows_data, num_distributions,
num_categories);

// Get cumulative probability of each distribution. It's the same function
// of
Expand Down
8 changes: 8 additions & 0 deletions python/paddle/fluid/tests/unittests/test_multinomial_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ def test_dygraph3(self):
"replacement is False. categories can't be sampled repeatedly")
paddle.enable_static()

def test_dygraph4(self):
    """Smoke test: sample once from a 2800-category Categorical in dygraph mode.

    The test only checks that sampling completes; it makes no assertion
    about the sampled values.
    """
    paddle.disable_static()
    try:
        logits = -1 * paddle.ones([2800])
        # Categorical.sample API will call multinomial op with replacement=True
        cat = paddle.distribution.Categorical(logits.exp())
        cat.sample([1])
    finally:
        # Restore static mode even if sampling raises, so a failure here does
        # not leak dygraph mode into the tests that run afterwards.
        paddle.enable_static()

def test_static(self):
paddle.enable_static()
startup_program = fluid.Program()
Expand Down

0 comments on commit d8ffb26

Please sign in to comment.