diff --git a/paddle/fluid/operators/multinomial_op.cu b/paddle/fluid/operators/multinomial_op.cu index 2d97111709a0f..1e52cf36f69c8 100644 --- a/paddle/fluid/operators/multinomial_op.cu +++ b/paddle/fluid/operators/multinomial_op.cu @@ -33,18 +33,22 @@ namespace operators { template __global__ void NormalizeProbability(T* norm_probs, const T* in_data, - T* sum_rows) { + T* sum_rows, int64_t num_distributions, + int64_t num_categories) { int id = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; - PADDLE_ENFORCE( - in_data[id] >= 0.0, - "The input of multinomial distribution should be >= 0, but got %f.", - in_data[id]); - PADDLE_ENFORCE(sum_rows[blockIdx.y] > 0.0, - "The sum of one multinomial distribution probability should " - "be > 0, but got %f.", - sum_rows[blockIdx.y]); - norm_probs[id] = in_data[id] / sum_rows[blockIdx.y]; + if (id < num_distributions * num_categories) { + PADDLE_ENFORCE( + in_data[id] >= 0.0, + "The input of multinomial distribution should be >= 0, but got %f.", + in_data[id]); + int64_t row_id = id / num_categories; + PADDLE_ENFORCE(sum_rows[row_id] > 0.0, + "The sum of one multinomial distribution probability should " + "be > 0, but got %f.", + sum_rows[row_id]); + norm_probs[id] = in_data[id] / sum_rows[row_id]; + } } template @@ -52,12 +56,10 @@ __global__ void GetCumulativeProbs(T* norm_probs_data, int64_t num_distributions, int64_t num_categories, T* cumulative_probs) { - for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) { - thrust::inclusive_scan(thrust::device, - norm_probs_data + id * num_categories, - norm_probs_data + (id + 1) * num_categories, - cumulative_probs + id * num_categories); - } + int id = blockIdx.x; + thrust::inclusive_scan(thrust::device, norm_probs_data + id * num_categories, + norm_probs_data + (id + 1) * num_categories, + cumulative_probs + id * num_categories); } template @@ -108,23 +110,19 @@ __global__ void sampleMultinomialWithReplacement( // use binary search to get the selected category sample id. // let cumulative_probs[id-1] < rng_data < cumulative_probs[id]. - int idx = threadIdx.x + blockIdx.x * blockDim.x + - blockIdx.y * gridDim.x * blockDim.x; - // for every distribution - for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) { - // for every sample - for (int sample = blockIdx.x * blockDim.x + threadIdx.x; - sample < num_samples; sample += blockDim.x * gridDim.x) { - T rng_number = rng_data[sample + dist * num_samples]; - - // Find the bucket that a uniform random number lies in - int selected_category = binarySearchFunctor( - cumulative_probs + dist * num_categories, - norm_probs_data + dist * num_categories, num_categories, rng_number); - - out_data[sample + dist * num_samples] = selected_category; - } + int dist = blockIdx.y; + // for every sample + int sample = blockIdx.x * blockDim.x + threadIdx.x; + if (sample < num_samples) { + T rng_number = rng_data[sample + dist * num_samples]; + + // Find the bucket that a uniform random number lies in + int selected_category = binarySearchFunctor( + cumulative_probs + dist * num_categories, + norm_probs_data + dist * num_categories, num_categories, rng_number); + + out_data[sample + dist * num_samples] = selected_category; } } @@ -215,10 +213,11 @@ class MultinomialOpKernel // number of threads in a block is min(num_categories, 512) dim3 block_norm(num_categories < 512 ? num_categories : 512); - dim3 grid_norm((num_categories - 1) / block_norm.x + 1, num_distributions); + dim3 grid_norm((num_distributions * num_categories - 1) / block_norm.x + 1); NormalizeProbability< T><<>>( - norm_probs_data, in_data, sum_rows_data); + norm_probs_data, in_data, sum_rows_data, num_distributions, + num_categories); // Get cumulative probability of each distribution. It's the same function // of diff --git a/python/paddle/fluid/tests/unittests/test_multinomial_op.py b/python/paddle/fluid/tests/unittests/test_multinomial_op.py index 957c06eca89c3..cdb89bb964055 100644 --- a/python/paddle/fluid/tests/unittests/test_multinomial_op.py +++ b/python/paddle/fluid/tests/unittests/test_multinomial_op.py @@ -141,6 +141,14 @@ def test_dygraph3(self): "replacement is False. categories can't be sampled repeatedly") paddle.enable_static() + def test_dygraph4(self): + paddle.disable_static() + logits = -1 * paddle.ones([2800]) + # Categorical.sample API will call multinomial op with replacement=True + cat = paddle.distribution.Categorical(logits.exp()) + cat.sample([1]) + paddle.enable_static() + def test_static(self): paddle.enable_static() startup_program = fluid.Program()