Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Decouple dtype from shape for Random multinomial #15980

Merged
merged 4 commits into from
Aug 25, 2019
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions src/operator/random/sample_multinomial_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,6 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
const mxnet::TShape& ishape = (*in_attrs)[0];
if (!ndim_is_known(ishape)) return false;

MSHADOW_TYPE_SWITCH(param.dtype, DType, {
CHECK_LE(ishape[ishape.ndim() - 1], mxnet::common::MaxIntegerValue<DType>())
<< "'dtype' does not have a sufficient precision to represent the indices of the input array.";
});

if (ishape.ndim() == 1) {
if (param.shape.ndim() > 0) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
Expand Down Expand Up @@ -121,7 +116,7 @@ inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs,

struct SampleMultinomialKernel {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, index_t K, index_t M,
MSHADOW_XINLINE static void Map(index_t i, index_t K, index_t M,
DType* dist, float* uniform, float* cum_table,
IType* out, DType* prob) {
double acc = 0.0;
Expand Down