From b1e7563a50c7eae4a74435145042c6716863d860 Mon Sep 17 00:00:00 2001
From: ChaiBapchya
Date: Thu, 11 Oct 2018 13:59:07 -0700
Subject: [PATCH 01/30] randint operator add along with add optional tag to params

---
 include/mxnet/random_generator.h        |  13 +-
 python/mxnet/ndarray/random.py          | 155 ++++++++++++++++--------
 python/mxnet/symbol/random.py           | 109 ++++++++++++-----
 src/operator/contrib/quadratic_op-inl.h |   2 +-
 src/operator/random/sample_op.cc        |  18 +++
 src/operator/random/sample_op.h         |  61 +++++++++-
 src/operator/random/sampler.h           |  27 +++++
 tests/python/unittest/test_random.py    |  18 +++
 8 files changed, 317 insertions(+), 86 deletions(-)

diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h
index 6e37efd40598..8b289129e591 100644
--- a/include/mxnet/random_generator.h
+++ b/include/mxnet/random_generator.h
@@ -71,6 +71,11 @@ class RandGenerator<cpu, DType> {
       return dist_uniform(*engine_);
     }
 
+    MSHADOW_XINLINE FType discrete_uniform(const int lower, const int upper) {
+      std::uniform_int_distribution dist_discrete_uniform(lower, upper);
+      return dist_discrete_uniform(*engine_);
+    }
+
     MSHADOW_XINLINE FType normal() {
       std::normal_distribution<FType> dist_normal;
       return dist_normal(*engine_);
     }
@@ -145,7 +150,13 @@ class RandGenerator<gpu, DType> {
     return curand_normal(&state_);
   }
 
- private:
+  template<typename IType>
+  MSHADOW_XINLINE FType discrete_uniform(const IType *lower, const IType *upper) {
+    std::uniform_int_distribution dist_discrete_uniform(*lower, *upper);
+    return dist_discrete_uniform(*engine_);
+  }
+
+  private:
   RandGenerator<gpu, DType> *global_gen_;
   int global_state_idx_;
   curandStatePhilox4_32_10_t state_;
diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py
index 1e941f79aa1c..12d0c222c202 100644
--- a/python/mxnet/ndarray/random.py
+++ b/python/mxnet/ndarray/random.py
@@ -59,23 +59,23 @@ def uniform(low=0, high=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwarg
 
     Parameters
     ----------
-    low : float or NDArray
+    low : float or NDArray, optional
         Lower boundary of the output interval. All values generated will be
         greater than or equal to low. The default value is 0.
-    high : float or NDArray
+    high : float or NDArray, optional
         Upper boundary of the output interval. All values generated will be
         less than high. The default value is 1.0.
-    shape : int or tuple of ints
+    shape : int or tuple of ints, optional
         The number of samples to draw. If shape is, e.g., `(m, n)` and `low` and
         `high` are scalars, output shape will be `(m, n)`. If `low` and `high`
         are NDArrays with shape, e.g., `(x, y)`, then output will have shape
         `(x, y, m, n)`, where `m*n` samples are drawn for each `[low, high)` pair.
-    dtype : {'float16','float32', 'float64'}
+    dtype : {'float16','float32', 'float64'}, optional
         Data type of output samples. Default is 'float32'
-    ctx : Context
+    ctx : Context, optional
         Device context of output. Default is current context. Overridden by
         `low.context` when `low` is an NDArray.
-    out : NDArray
+    out : NDArray, optional
         Store output to an existing NDArray.
 
 
@@ -111,21 +111,21 @@ def normal(loc=0, scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwarg
 
     Parameters
     ----------
-    loc : float or NDArray
+    loc : float or NDArray, optional
         Mean (centre) of the distribution.
-    scale : float or NDArray
+    scale : float or NDArray, optional
         Standard deviation (spread or width) of the distribution.
-    shape : int or tuple of ints
+    shape : int or tuple of ints, optional
         The number of samples to draw. If shape is, e.g., `(m, n)` and `loc` and
         `scale` are scalars, output shape will be `(m, n)`.
If `loc` and `scale` are NDArrays with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[loc, scale)` pair. - dtype : {'float16','float32', 'float64'} + dtype : {'float16','float32', 'float64'}, optional Data type of output samples. Default is 'float32' - ctx : Context + ctx : Context, optional Device context of output. Default is current context. Overridden by `loc.context` when `loc` is an NDArray. - out : NDArray + out : NDArray, optional Store output to an existing NDArray. @@ -152,7 +152,7 @@ def normal(loc=0, scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwarg [loc, scale], shape, dtype, ctx, out, kwargs) -def randn(*shape, **kwargs): +def randn(loc=0, scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs): """Draw random samples from a normal (Gaussian) distribution. Samples are distributed according to a normal distribution parametrized @@ -161,21 +161,21 @@ def randn(*shape, **kwargs): Parameters ---------- - loc : float or NDArray + loc : float or NDArray, optional Mean (centre) of the distribution. - scale : float or NDArray + scale : float or NDArray, optional Standard deviation (spread or width) of the distribution. - shape : int or tuple of ints + shape : int or tuple of ints, optional The number of samples to draw. If shape is, e.g., `(m, n)` and `loc` and `scale` are scalars, output shape will be `(m, n)`. If `loc` and `scale` are NDArrays with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[loc, scale)` pair. - dtype : {'float16','float32', 'float64'} + dtype : {'float16','float32', 'float64'}, optional Data type of output samples. Default is 'float32' - ctx : Context + ctx : Context, optional Device context of output. Default is current context. Overridden by `loc.context` when `loc` is an NDArray. - out : NDArray + out : NDArray, optional Store output to an existing NDArray. @@ -212,19 +212,19 @@ def poisson(lam=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs): Parameters ---------- - lam : float or NDArray + lam : float or NDArray, optional Expectation of interval, should be >= 0. - shape : int or tuple of ints + shape : int or tuple of ints, optional The number of samples to draw. If shape is, e.g., `(m, n)` and `lam` is a scalar, output shape will be `(m, n)`. If `lam` is an NDArray with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each entry in `lam`. - dtype : {'float16','float32', 'float64'} + dtype : {'float16','float32', 'float64'}, optional Data type of output samples. Default is 'float32' - ctx : Context + ctx : Context, optional Device context of output. Default is current context. Overridden by `lam.context` when `lam` is an NDArray. - out : NDArray + out : NDArray, optional Store output to an existing NDArray. @@ -259,19 +259,19 @@ def exponential(scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs) Parameters ---------- - scale : float or NDArray + scale : float or NDArray, optional The scale parameter, \beta = 1/\lambda. - shape : int or tuple of ints + shape : int or tuple of ints, optional The number of samples to draw. If shape is, e.g., `(m, n)` and `scale` is a scalar, output shape will be `(m, n)`. If `scale` is an NDArray with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each entry in `scale`. 
- dtype : {'float16','float32', 'float64'} + dtype : {'float16','float32', 'float64'}, optional Data type of output samples. Default is 'float32' - ctx : Context + ctx : Context, optional Device context of output. Default is current context. Overridden by `scale.context` when `scale` is an NDArray. - out : NDArray + out : NDArray, optional Store output to an existing NDArray. @@ -302,22 +302,22 @@ def gamma(alpha=1, beta=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwarg Parameters ---------- - alpha : float or NDArray + alpha : float or NDArray, optional The shape of the gamma distribution. Should be greater than zero. - beta : float or NDArray + beta : float or NDArray, optional The scale of the gamma distribution. Should be greater than zero. Default is equal to 1. - shape : int or tuple of ints + shape : int or tuple of ints, optional The number of samples to draw. If shape is, e.g., `(m, n)` and `alpha` and `beta` are scalars, output shape will be `(m, n)`. If `alpha` and `beta` are NDArrays with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[alpha, beta)` pair. - dtype : {'float16','float32', 'float64'} + dtype : {'float16','float32', 'float64'}, optional Data type of output samples. Default is 'float32' - ctx : Context + ctx : Context, optional Device context of output. Default is current context. Overridden by `alpha.context` when `alpha` is an NDArray. - out : NDArray + out : NDArray, optional Store output to an existing NDArray. @@ -352,21 +352,21 @@ def negative_binomial(k=1, p=1, shape=_Null, dtype=_Null, ctx=None, Parameters ---------- - k : float or NDArray + k : float or NDArray, optional Limit of unsuccessful experiments, > 0. - p : float or NDArray + p : float or NDArray, optional Failure probability in each experiment, >= 0 and <=1. - shape : int or tuple of ints + shape : int or tuple of ints, optional The number of samples to draw. If shape is, e.g., `(m, n)` and `k` and `p` are scalars, output shape will be `(m, n)`. If `k` and `p` are NDArrays with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[k, p)` pair. - dtype : {'float16','float32', 'float64'} + dtype : {'float16','float32', 'float64'}, optional Data type of output samples. Default is 'float32' - ctx : Context + ctx : Context, optional Device context of output. Default is current context. Overridden by `k.context` when `k` is an NDArray. - out : NDArray + out : NDArray, optional Store output to an existing NDArray. @@ -403,21 +403,21 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, ctx=N Parameters ---------- - mu : float or NDArray + mu : float or NDArray, optional Mean of the negative binomial distribution. - alpha : float or NDArray + alpha : float or NDArray, optional Alpha (dispersion) parameter of the negative binomial distribution. - shape : int or tuple of ints + shape : int or tuple of ints, optional The number of samples to draw. If shape is, e.g., `(m, n)` and `mu` and `alpha` are scalars, output shape will be `(m, n)`. If `mu` and `alpha` are NDArrays with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[mu, alpha)` pair. - dtype : {'float16','float32', 'float64'} + dtype : {'float16','float32', 'float64'}, optional Data type of output samples. Default is 'float32' - ctx : Context + ctx : Context, optional Device context of output. Default is current context. 
Overridden by `mu.context` when `mu` is an NDArray.
-    out : NDArray
+    out : NDArray, optional
         Store output to an existing NDArray.
 
 
@@ -455,17 +455,17 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kw
         `k` is the number of possible outcomes of each multinomial distribution.
         For example, data with shape `(m, n, k)` specifies `m*n` multinomial
         distributions each with `k` possible outcomes.
-    shape : int or tuple of ints
+    shape : int or tuple of ints, optional
        The number of samples to draw from each distribution. If shape is empty
        one sample will be drawn from each distribution.
-    get_prob : bool
+    get_prob : bool, optional
        If true, a second array containing log likelihood of the drawn
        samples will also be returned.
        This is usually used for reinforcement learning, where you can provide
        reward as head gradient w.r.t. this array to estimate gradient.
-    out : NDArray
+    out : NDArray, optional
        Store output to an existing NDArray.
-    dtype : str or numpy.dtype
+    dtype : str or numpy.dtype, optional
        Data type of the sample output array. The default is int32.
        Note that the data type of the log likelihood array is the same with that
        of `data`.
@@ -500,7 +500,7 @@ def shuffle(data, **kwargs):
     ----------
     data : NDArray
         Input data array.
-    out : NDArray
+    out : NDArray, optional
         Array to store the result.
 
     Examples
@@ -518,3 +518,54 @@ def shuffle(data, **kwargs):
     """
     return _internal._shuffle(data, **kwargs)
+
+
+def randint(low=0, high=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs):
+    """Draw random samples from a discrete uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval *[low, high)*
+    (includes *low*, but excludes *high*).
+
+    Parameters
+    ----------
+    low : float or NDArray, optional
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low. The default value is 0.
+    high : float or NDArray, optional
+        Upper boundary of the output interval. All values generated will be
+        less than high. The default value is 1.
+    shape : int or tuple of ints, optional
+        The number of samples to draw. If shape is, e.g., `(m, n)` and `low` and
+        `high` are scalars, output shape will be `(m, n)`. If `low` and `high`
+        are NDArrays with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, n)`, where `m*n` samples are drawn for each `[low, high)` pair.
+    dtype : {'uint8','int32','int8', 'int64'}, optional
+        Data type of output samples. Default is 'int32'
+    ctx : Context, optional
+        Device context of output. Default is current context. Overridden by
+        `low.context` when `low` is an NDArray.
+    out : NDArray, optional
+        Store output to an existing NDArray.
+
+
+    Examples
+    --------
+    >>> mx.nd.random.randint()
+    [ 0]
+    >>> mx.nd.random.randint(-10, 2, ctx=mx.gpu(0))
+    [ -8]
+
+    >>> mx.nd.random.randint(-10, 10, shape=(2,))
+    [ -5 4]
+
+    >>> low = mx.nd.array([1,2,3])
+    >>> high = mx.nd.array([2,3,4])
+    >>> mx.nd.random.randint(low, high, shape=2)
+    [[ 1 1]
+     [ 2 2]
+     [ 3 3]]
+
+    """
+    return _random_helper(_internal._random_randint, _internal._sample_uniform,
+                          [low, high], shape, dtype, ctx, out, kwargs)
diff --git a/python/mxnet/symbol/random.py b/python/mxnet/symbol/random.py
index e9abe9c4a187..1be8d6aac1bd 100644
--- a/python/mxnet/symbol/random.py
+++ b/python/mxnet/symbol/random.py
@@ -53,18 +53,18 @@ def uniform(low=0, high=1, shape=_Null, dtype=_Null, **kwargs):
 
     Parameters
     ----------
-    low : float or Symbol
+    low : float or Symbol, optional
        Lower boundary of the output interval.
All values generated will be greater than or equal to low. The default value is 0. - high : float or Symbol + high : float or Symbol, optional Upper boundary of the output interval. All values generated will be less than high. The default value is 1.0. - shape : int or tuple of ints + shape : int or tuple of ints, optional The number of samples to draw. If shape is, e.g., `(m, n)` and `low` and `high` are scalars, output shape will be `(m, n)`. If `low` and `high` are Symbols with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[low, high)` pair. - dtype : {'float16','float32', 'float64'} + dtype : {'float16','float32', 'float64'}, optional Data type of output samples. Default is 'float32' """ return _random_helper(_internal._random_uniform, _internal._sample_uniform, @@ -80,18 +80,45 @@ def normal(loc=0, scale=1, shape=_Null, dtype=_Null, **kwargs): Parameters ---------- - loc : float or Symbol + loc : float or Symbol, optional Mean (centre) of the distribution. - scale : float or Symbol + scale : float or Symbol, optional Standard deviation (spread or width) of the distribution. - shape : int or tuple of ints + shape : int or tuple of ints, optional The number of samples to draw. If shape is, e.g., `(m, n)` and `loc` and `scale` are scalars, output shape will be `(m, n)`. If `loc` and `scale` are Symbols with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[loc, scale)` pair. - dtype : {'float16','float32', 'float64'} + dtype : {'float16','float32', 'float64'}, optional + Data type of output samples. Default is 'float32' + """ + return _random_helper(_internal._random_normal, _internal._sample_normal, + [loc, scale], shape, dtype, kwargs) + + +def randn(loc=0, scale=1, shape=_Null, dtype=_Null, **kwargs): + """Draw random samples from a normal (Gaussian) distribution. + + Samples are distributed according to a normal distribution parametrized + by *loc* (mean) and *scale* (standard deviation). + + + Parameters + ---------- + loc : float or NDArray, optional + Mean (centre) of the distribution. + scale : float or NDArray, optional + Standard deviation (spread or width) of the distribution. + shape : int or tuple of ints, optional + The number of samples to draw. If shape is, e.g., `(m, n)` and `loc` and + `scale` are scalars, output shape will be `(m, n)`. If `loc` and `scale` + are NDArrays with shape, e.g., `(x, y)`, then output will have shape + `(x, y, m, n)`, where `m*n` samples are drawn for each `[loc, scale)` pair. + dtype : {'float16','float32', 'float64'}, optional Data type of output samples. Default is 'float32' """ + assert isinstance(loc, (int, float)) + assert isinstance(scale, (int, float)) return _random_helper(_internal._random_normal, _internal._sample_normal, [loc, scale], shape, dtype, kwargs) @@ -104,14 +131,14 @@ def poisson(lam=1, shape=_Null, dtype=_Null, **kwargs): Parameters ---------- - lam : float or Symbol + lam : float or Symbol, optional Expectation of interval, should be >= 0. - shape : int or tuple of ints + shape : int or tuple of ints, optional The number of samples to draw. If shape is, e.g., `(m, n)` and `lam` is a scalar, output shape will be `(m, n)`. If `lam` is an Symbol with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each entry in `lam`. - dtype : {'float16','float32', 'float64'} + dtype : {'float16','float32', 'float64'}, optional Data type of output samples. 
Default is 'float32' """ return _random_helper(_internal._random_poisson, _internal._sample_poisson, @@ -130,14 +157,14 @@ def exponential(scale=1, shape=_Null, dtype=_Null, **kwargs): Parameters ---------- - scale : float or Symbol + scale : float or Symbol, optional The scale parameter, \beta = 1/\lambda. - shape : int or tuple of ints + shape : int or tuple of ints, optional The number of samples to draw. If shape is, e.g., `(m, n)` and `scale` is a scalar, output shape will be `(m, n)`. If `scale` is an Symbol with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each entry in `scale`. - dtype : {'float16','float32', 'float64'} + dtype : {'float16','float32', 'float64'}, optional Data type of output samples. Default is 'float32' """ return _random_helper(_internal._random_exponential, _internal._sample_exponential, @@ -152,17 +179,17 @@ def gamma(alpha=1, beta=1, shape=_Null, dtype=_Null, **kwargs): Parameters ---------- - alpha : float or Symbol + alpha : float or Symbol, optional The shape of the gamma distribution. Should be greater than zero. - beta : float or Symbol + beta : float or Symbol, optional The scale of the gamma distribution. Should be greater than zero. Default is equal to 1. - shape : int or tuple of ints + shape : int or tuple of ints, optional The number of samples to draw. If shape is, e.g., `(m, n)` and `alpha` and `beta` are scalars, output shape will be `(m, n)`. If `alpha` and `beta` are Symbols with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[alpha, beta)` pair. - dtype : {'float16','float32', 'float64'} + dtype : {'float16','float32', 'float64'}, optional Data type of output samples. Default is 'float32' """ return _random_helper(_internal._random_gamma, _internal._sample_gamma, @@ -179,16 +206,16 @@ def negative_binomial(k=1, p=1, shape=_Null, dtype=_Null, **kwargs): Parameters ---------- - k : float or Symbol + k : float or Symbol, optional Limit of unsuccessful experiments, > 0. - p : float or Symbol + p : float or Symbol, optional Failure probability in each experiment, >= 0 and <=1. shape : int or tuple of ints The number of samples to draw. If shape is, e.g., `(m, n)` and `k` and `p` are scalars, output shape will be `(m, n)`. If `k` and `p` are Symbols with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[k, p)` pair. - dtype : {'float16','float32', 'float64'} + dtype : {'float16','float32', 'float64'}, optional Data type of output samples. Default is 'float32' """ return _random_helper(_internal._random_negative_binomial, @@ -207,16 +234,16 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, **kwa Parameters ---------- - mu : float or Symbol + mu : float or Symbol, optional Mean of the negative binomial distribution. - alpha : float or Symbol + alpha : float or Symbol, optional Alpha (dispersion) parameter of the negative binomial distribution. - shape : int or tuple of ints + shape : int or tuple of ints, optional The number of samples to draw. If shape is, e.g., `(m, n)` and `mu` and `alpha` are scalars, output shape will be `(m, n)`. If `mu` and `alpha` are Symbols with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[mu, alpha)` pair. - dtype : {'float16','float32', 'float64'} + dtype : {'float16','float32', 'float64'}, optional Data type of output samples. 
Default is 'float32' """ return _random_helper(_internal._random_generalized_negative_binomial, @@ -237,15 +264,15 @@ def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs): `k` is the number of possible outcomes of each multinomial distribution. For example, data with shape `(m, n, k)` specifies `m*n` multinomial distributions each with `k` possible outcomes. - shape : int or tuple of ints + shape : int or tuple of ints, optional The number of samples to draw from each distribution. If shape is empty one sample will be drawn from each distribution. - get_prob : bool + get_prob : bool, optional If true, a second array containing log likelihood of the drawn samples will also be returned. This is usually used for reinforcement learning, where you can provide reward as head gradient w.r.t. this array to estimate gradient. - dtype : str or numpy.dtype + dtype : str or numpy.dtype, optional Data type of the sample output array. The default is int32. Note that the data type of the log likelihood array is the same with that of `data`. """ @@ -281,3 +308,29 @@ def shuffle(data, **kwargs): """ return _internal._shuffle(data, **kwargs) + + +def randint(low=0, high=1, shape=_Null, dtype=_Null, **kwargs): + """Draw random samples from a discrete uniform distribution. + + Samples are uniformly distributed over the half-open interval *[low, high)* + (includes *low*, but excludes *high*). + + Parameters + ---------- + low : float or NDArray, optional + Lower boundary of the output interval. All values generated will be + greater than or equal to low. The default value is 0. + high : float or NDArray, optional + Upper boundary of the output interval. All values generated will be + less than high. The default value is 1. + shape : int or tuple of ints, optional + The number of samples to draw. If shape is, e.g., `(m, n)` and `low` and + `high` are scalars, output shape will be `(m, n)`. If `low` and `high` + are NDArrays with shape, e.g., `(x, y)`, then output will have shape + `(x, y, m, n)`, where `m*n` samples are drawn for each `[low, high)` pair. + dtype : {'uint8','int32','int8', 'int64'}, optional + Data type of output samples. Default is 'int32' + """ + return _random_helper(_internal._random_randint, _internal._sample_uniform, + [low, high], shape, dtype, kwargs) diff --git a/src/operator/contrib/quadratic_op-inl.h b/src/operator/contrib/quadratic_op-inl.h index 71cb76a7b565..a6fa260f10e8 100644 --- a/src/operator/contrib/quadratic_op-inl.h +++ b/src/operator/contrib/quadratic_op-inl.h @@ -20,7 +20,7 @@ /*! * \file quad_function-inl.h * \brief Operator implementing quadratic function. - * For using as an exmaple in the tutorial of adding operators + * For using as an example in the tutorial of adding operators * in MXNet backend. */ #ifndef MXNET_OPERATOR_CONTRIB_QUADRATIC_OP_INL_H_ diff --git a/src/operator/random/sample_op.cc b/src/operator/random/sample_op.cc index a2b332456fb5..e11c8404ab52 100644 --- a/src/operator/random/sample_op.cc +++ b/src/operator/random/sample_op.cc @@ -171,5 +171,23 @@ Example:: .set_attr("FCompute", Sample_>) .set_attr("FComputeEx", SampleEx_>); +MXNET_OPERATOR_REGISTER_SAMPLE(_random_randint, SampleRandIntParam) +.add_alias("random_randint") +.describe(R"code(Draw random samples from a discrete uniform distribution. + +Samples are uniformly distributed over the half-open interval *[low, high)* +(includes *low*, but excludes *high*). 
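+
+On CPU each value is drawn with the C++ standard library's
+``uniform_int_distribution`` (see ``discrete_uniform`` in random_generator.h).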
+
+Example::
+
+   randint(low=0, high=5, shape=(2,2)) = [[ 0, 2],
+                                          [ 3, 1]]
+
+)code" ADD_FILELINE)
+.set_attr<FInferStorageType>("FInferStorageType", InitStorageType<SampleRandIntParam>)
+.set_attr<FCompute>("FCompute", Sample_<cpu, SampleRandIntParam>)
+.set_attr<FComputeEx>("FComputeEx", SampleEx_<cpu, SampleRandIntParam>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h
index a81b41a09af7..d5d4b879028a 100644
--- a/src/operator/random/sample_op.h
+++ b/src/operator/random/sample_op.h
@@ -235,6 +235,36 @@ struct SampleGenNegBinomialParam : public dmlc::Parameter<SampleGenNegBinomialParam> {
+struct SampleRandIntParam : public dmlc::Parameter<SampleRandIntParam> {
+  int low;
+  int high;
+  TShape shape;
+  std::string ctx;
+  int dtype;
+  DMLC_DECLARE_PARAMETER(SampleRandIntParam) {
+    DMLC_DECLARE_FIELD(low).set_default(0)
+    .describe("Lower bound of the distribution.");
+    DMLC_DECLARE_FIELD(high).set_default(1)
+    .describe("Upper bound of the distribution.");
+    DMLC_DECLARE_FIELD(shape)
+    .set_default(TShape())
+    .describe("Shape of the output.");
+    DMLC_DECLARE_FIELD(ctx)
+    .set_default("")
+    .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
+              " Only used for imperative calls.");
+    DMLC_DECLARE_FIELD(dtype)
+    .add_enum("None", -1)
+    .add_enum("uint8", kUint8)
+    .add_enum("int32", kInt32)
+    .add_enum("int8", kInt8)
+    .add_enum("int64", kInt64)
+    .set_default(-1)
+    .describe("DType of the output in case this can't be inferred. "
+              "Defaults to int32 if not defined (dtype=None).");
+  }
+};
+
 using FSampleCompute = std::function<void(const nnvm::NodeAttrs& attrs,
   }
 };
 
+template<typename xpu>
+struct SampleMaster<xpu, SampleRandIntParam> {
+  static void op(const nnvm::NodeAttrs& attrs,
+                 const OpContext& ctx,
+                 const OpReqType& req,
+                 TBlob* outputs) {
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    const SampleRandIntParam& param = nnvm::get<SampleRandIntParam>(attrs.parsed);
+    CHECK_GE(param.high, param.low) << "low must be less or equal to high in uniform distribution";
+    Tensor<xpu, 1, int> low, high;
+    GetSamplingTempData<xpu, int>(param.low, param.high, ctx,
+                                  &low, &high);
+    RandIntSampler<xpu> sampler;
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
+      Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
+      sampler.Sample(low, high, out, pgen, s);
+    });
+  }
+};
+
 template<typename xpu>
 void SampleComputeEx_(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
@@ -468,7 +519,7 @@ inline bool SampleOpType(const nnvm::NodeAttrs& attrs,
   int dtype = -1;
   int dtype_out = (*out_type)[0];
   if (dtype_out != -1) {
-    // Output type can be inferred, use it and make sure it
+    // Output type can be inferred, use it and make sure it matches
     dtype = dtype_out;
     if (param.dtype != -1) {
       // dtype given in args, check that it matches the output type
     }
   }
   bool dtype_ok = (dtype == kFloat16) || (dtype == kFloat32) ||
-                  (dtype == kFloat64);
-  CHECK_EQ(dtype_ok, true) << "Output type must be float16, float32, or float64: dtype is "
+                  (dtype == kFloat64) || (dtype == kUint8) || (dtype == kInt32) ||
+                  (dtype == kInt8) || (dtype == kInt64);
+  CHECK_EQ(dtype_ok, true) << "Output type must be float16, float32, float64, uint8, "
+                              "int32, int8, or int64: dtype is "
                            << dtype_out << " vs " << kFloat16 << " or " << kFloat32 << " or "
-                           << kFloat64;
+                           << kFloat64 << " or " << kUint8 << " or " << kInt32 << " or " << kInt8 << " or " << kInt64;
   TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
   return true;
 }
diff --git a/src/operator/random/sampler.h b/src/operator/random/sampler.h
index 44f80ab56254..dec67da09d88 100644
--- a/src/operator/random/sampler.h
+++ b/src/operator/random/sampler.h
@@ -92,6 +92,33 @@ struct UniformSampler {
   }
 };
 
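+// nBatch = ceil(nSample / nParm) is the number of draws that share one
+// (lower, upper) pair; i / nBatch selects that pair for output element i.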
+template<typename xpu>
+struct SampleRandIntKernel {
+  template<typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
+                                  const int N, const int step,
+                                  index_t nParm, index_t nSample,
+                                  const IType *lower, const IType *upper, OType *out) {
+    RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);
+      out[i] = OType(genImpl.discrete_uniform(lower[i / nBatch], upper[i / nBatch]));
+    });
+  }
+};
+
+template<typename xpu>
+struct RandIntSampler {
+  template<typename IType, typename OType>
+  MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& lower,
+                                   const Tensor<xpu, 1, IType>& upper,
+                                   const Tensor<xpu, 1, OType>& out,
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {
+    LaunchRNG<SampleRandIntKernel<xpu>, xpu>(s, pgen, out.size(0), lower.size(0), out.size(0),
+                                             lower.dptr_, upper.dptr_, out.dptr_);
+  }
+};
+
 template<typename xpu>
 struct SampleNormalKernel {
   template<typename IType, typename OType>
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index 4310658ae0bf..02df39549846 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -767,6 +767,24 @@ def testLarge(data, repeat):
     testLarge(mx.nd.arange(0, 100000).reshape((10, 10000)), 10)
     testLarge(mx.nd.arange(0, 100000).reshape((10000, 10)), 10)
 
+@with_seed()
+def test_randint():
+    dtypes = ['uint8','int32','int8', 'int64']
+    for dtype in dtypes:
+        params = {
+            'low': -1,
+            'high': 3,
+            'shape' : (500, 500),
+            'dtype' : dtype,
+            'ctx' : mx.context.current_context()
+        }
+        mx.random.seed(128)
+        ret1 = mx.nd.random.randint(**params).asnumpy()
+        mx.random.seed(128)
+        ret2 = mx.nd.random.randint(**params).asnumpy()
+        assert same(ret1, ret2), \
+            "ndarray test: `%s` should give the same result with the same seed" % dtype
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

From 89ad2ee34ae30ef82cb331474d191823ea0f0252 Mon Sep 17 00:00:00 2001
From: ChaiBapchya
Date: Thu, 11 Oct 2018 14:30:56 -0700
Subject: [PATCH 02/30] register param

---
 src/operator/random/sample_op.cc | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/operator/random/sample_op.cc b/src/operator/random/sample_op.cc
index e11c8404ab52..dd7b3e4240f0 100644
--- a/src/operator/random/sample_op.cc
+++ b/src/operator/random/sample_op.cc
@@ -36,6 +36,7 @@ DMLC_REGISTER_PARAMETER(SampleExponentialParam);
 DMLC_REGISTER_PARAMETER(SamplePoissonParam);
 DMLC_REGISTER_PARAMETER(SampleNegBinomialParam);
 DMLC_REGISTER_PARAMETER(SampleGenNegBinomialParam);
+DMLC_REGISTER_PARAMETER(SampleRandIntParam);
 
 #define MXNET_OPERATOR_REGISTER_SAMPLE(name, ParamType) \
   NNVM_REGISTER_OP(name) \

From 3dabb6dfe7d017594b9ecd78f8793c3f136b23e5 Mon Sep 17 00:00:00 2001
From: ChaiBapchya
Date: Thu, 11 Oct 2018 14:35:25 -0700
Subject: [PATCH 03/30] lint space issue

---
 include/mxnet/random_generator.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h
index 8b289129e591..300fac271f8d 100644
--- a/include/mxnet/random_generator.h
+++ b/include/mxnet/random_generator.h
@@ -156,7 +156,7 @@ class RandGenerator<gpu, DType> {
     return dist_discrete_uniform(*engine_);
   }
 
-  private:
+ private:
   RandGenerator<gpu, DType> *global_gen_;
   int global_state_idx_;
   curandStatePhilox4_32_10_t state_;

From 40adfdec76eacbf05c4ebafe389adc80f7f9d5b1 Mon Sep 17 00:00:00 2001
From: ChaiBapchya
Date: Thu, 11 Oct 2018 16:35:44 -0700
Subject: [PATCH 04/30] randn issue fix

---
 python/mxnet/ndarray/random.py       | 5 -----
 tests/python/unittest/test_random.py | 6 +++---
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py
index 12d0c222c202..a0489861d893 100644
---
a/python/mxnet/ndarray/random.py +++ b/python/mxnet/ndarray/random.py @@ -193,11 +193,6 @@ def randn(loc=0, scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs [5.357444 5.7793283 3.9896927]] """ - loc = kwargs.pop('loc', 0) - scale = kwargs.pop('scale', 1) - dtype = kwargs.pop('dtype', _Null) - ctx = kwargs.pop('ctx', None) - out = kwargs.pop('out', None) assert isinstance(loc, (int, float)) assert isinstance(scale, (int, float)) return _random_helper(_internal._random_normal, _internal._sample_normal, diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index 02df39549846..c8fa68950793 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -133,9 +133,9 @@ def check_with_device(device, dtype): params = symbdic['params'].copy() params.update(shape=shape, dtype=dtype, ctx=device) args = () - if name == 'randn': - params.pop('shape') # randn does not accept shape param - args = shape + # if name == 'randn': + # params.pop('shape') # randn does not accept shape param + # args = shape mx.random.seed(128) ret1 = ndop(*args, **params).asnumpy() mx.random.seed(128) From 9827e1e5321850d0cdc6a8f2cb44bc4dd39b1354 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Fri, 12 Oct 2018 18:03:34 -0700 Subject: [PATCH 05/30] uniform_int_distribution doesn't support int8, uint8 fix --- src/operator/random/sample_op.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h index d5d4b879028a..0e4a1532adf0 100644 --- a/src/operator/random/sample_op.h +++ b/src/operator/random/sample_op.h @@ -255,9 +255,7 @@ struct SampleRandIntParam : public dmlc::Parameter { " Only used for imperative calls."); DMLC_DECLARE_FIELD(dtype) .add_enum("None", -1) - .add_enum("uint8", kUint8) .add_enum("int32", kInt32) - .add_enum("int8", kInt8) .add_enum("int64", kInt64) .set_default(-1) .describe("DType of the output in case this can't be inferred. 
" From 3363e56b612ae6a001b756b7961b6da9adf34297 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Fri, 12 Oct 2018 18:22:26 -0700 Subject: [PATCH 06/30] dtype ftype --- include/mxnet/random_generator.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h index 300fac271f8d..aacf3de2c8c0 100644 --- a/include/mxnet/random_generator.h +++ b/include/mxnet/random_generator.h @@ -72,7 +72,7 @@ class RandGenerator { } MSHADOW_XINLINE FType discrete_uniform(const int lower, const int upper) { - std::uniform_int_distribution dist_discrete_uniform(lower, upper); + std::uniform_int_distribution dist_discrete_uniform(lower, upper); return dist_discrete_uniform(*engine_); } @@ -152,7 +152,7 @@ class RandGenerator { template MSHADOW_XINLINE FType discrete_uniform(const IType *lower, const IType *upper) { - std::uniform_int_distribution dist_discrete_uniform(*lower, *upper); + std::uniform_int_distribution dist_discrete_uniform(*lower, *upper); return dist_discrete_uniform(*engine_); } From df14294fe1126a56dd871e97d76351ce05afaca4 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Fri, 12 Oct 2018 18:52:01 -0700 Subject: [PATCH 07/30] ftype to dtype - invalid template arg --- include/mxnet/random_generator.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h index aacf3de2c8c0..c47c9a4a9fce 100644 --- a/include/mxnet/random_generator.h +++ b/include/mxnet/random_generator.h @@ -151,8 +151,8 @@ class RandGenerator { } template - MSHADOW_XINLINE FType discrete_uniform(const IType *lower, const IType *upper) { - std::uniform_int_distribution dist_discrete_uniform(*lower, *upper); + MSHADOW_XINLINE DType discrete_uniform(const IType *lower, const IType *upper) { + std::uniform_int_distribution dist_discrete_uniform(*lower, *upper); return dist_discrete_uniform(*engine_); } From 2e24870c6b9579e636536e5e425fb01872474231 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Mon, 15 Oct 2018 12:00:21 -0700 Subject: [PATCH 08/30] fix template arg issue --- include/mxnet/random_generator.h | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h index c47c9a4a9fce..653be3a3f950 100644 --- a/include/mxnet/random_generator.h +++ b/include/mxnet/random_generator.h @@ -71,8 +71,11 @@ class RandGenerator { return dist_uniform(*engine_); } - MSHADOW_XINLINE FType discrete_uniform(const int lower, const int upper) { - std::uniform_int_distribution dist_discrete_uniform(lower, upper); + MSHADOW_XINLINE DType discrete_uniform(const int lower, const int upper) { + typedef typename std::conditional::value, + std::uniform_int_distribution, + std::uniform_real_distribution>::type GType; + GType dist_discrete_uniform(lower, upper); return dist_discrete_uniform(*engine_); } @@ -150,12 +153,6 @@ class RandGenerator { return curand_normal(&state_); } - template - MSHADOW_XINLINE DType discrete_uniform(const IType *lower, const IType *upper) { - std::uniform_int_distribution dist_discrete_uniform(*lower, *upper); - return dist_discrete_uniform(*engine_); - } - private: RandGenerator *global_gen_; int global_state_idx_; From 693622c8880b8f5504d438d1d61e3c7f11920da7 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Mon, 15 Oct 2018 14:55:33 -0700 Subject: [PATCH 09/30] test with int dtype for windows --- include/mxnet/random_generator.h | 11 ++++++----- 1 file changed, 6 
insertions(+), 5 deletions(-) diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h index 653be3a3f950..0ef437296133 100644 --- a/include/mxnet/random_generator.h +++ b/include/mxnet/random_generator.h @@ -71,11 +71,12 @@ class RandGenerator { return dist_uniform(*engine_); } - MSHADOW_XINLINE DType discrete_uniform(const int lower, const int upper) { - typedef typename std::conditional::value, - std::uniform_int_distribution, - std::uniform_real_distribution>::type GType; - GType dist_discrete_uniform(lower, upper); + MSHADOW_XINLINE int discrete_uniform(const int lower, const int upper) { + // typedef typename std::conditional::value, + // std::uniform_int_distribution, + // std::uniform_real_distribution>::type GType; + // GType dist_discrete_uniform(lower, upper); + std::uniform_int_distribution dist_discrete_uniform(lower, upper); return dist_discrete_uniform(*engine_); } From c135e3af72b6d652327112b32c443091a2bab9ac Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 16 Oct 2018 11:32:44 -0700 Subject: [PATCH 10/30] removed int8,uint8 from test --- tests/python/unittest/test_random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index c8fa68950793..b19e8830c346 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -769,7 +769,7 @@ def testLarge(data, repeat): @with_seed() def test_randint(): - dtypes = ['uint8','int32','int8', 'int64'] + dtypes = ['int32', 'int64'] for dtype in dtypes: params = { 'low': -1, From d106ff3d84dafe02dbe8369fdc6d98ce7b66e1ec Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 16 Oct 2018 14:23:02 -0700 Subject: [PATCH 11/30] gpu implementation --- include/mxnet/random_generator.h | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h index 0ef437296133..89cb9f82fb16 100644 --- a/include/mxnet/random_generator.h +++ b/include/mxnet/random_generator.h @@ -154,6 +154,15 @@ class RandGenerator { return curand_normal(&state_); } + MSHADOW_XINLINE int discrete_uniform(const int lower, const int upper) { + // typedef typename std::conditional::value, + // std::uniform_int_distribution, + // std::uniform_real_distribution>::type GType; + // GType dist_discrete_uniform(lower, upper); + std::uniform_int_distribution dist_discrete_uniform(lower, upper); + return dist_discrete_uniform(*engine_); + } + private: RandGenerator *global_gen_; int global_state_idx_; From ef9e1f1687a862ca1440dfd7584170576c9a8e65 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 16 Oct 2018 15:17:13 -0700 Subject: [PATCH 12/30] gpu engine state diff --- include/mxnet/random_generator.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h index 89cb9f82fb16..cb7d9c52422a 100644 --- a/include/mxnet/random_generator.h +++ b/include/mxnet/random_generator.h @@ -87,7 +87,7 @@ class RandGenerator { private: std::mt19937 *engine_; - }; + }; // class RandGenerator::Impl static void AllocState(RandGenerator *inst) { inst->states_ = new std::mt19937[kNumRandomStates]; @@ -160,7 +160,7 @@ class RandGenerator { // std::uniform_real_distribution>::type GType; // GType dist_discrete_uniform(lower, upper); std::uniform_int_distribution dist_discrete_uniform(lower, upper); - return dist_discrete_uniform(*engine_); + return dist_discrete_uniform(&state_); } private: From 
ad574dcd59a9ae0bca5c2a4b7d09ad806b7419f8 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 16 Oct 2018 16:23:13 -0700 Subject: [PATCH 13/30] removed gpu support --- include/mxnet/random_generator.h | 9 --------- tests/python/unittest/test_random.py | 2 +- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h index cb7d9c52422a..20f8fa462c40 100644 --- a/include/mxnet/random_generator.h +++ b/include/mxnet/random_generator.h @@ -154,15 +154,6 @@ class RandGenerator { return curand_normal(&state_); } - MSHADOW_XINLINE int discrete_uniform(const int lower, const int upper) { - // typedef typename std::conditional::value, - // std::uniform_int_distribution, - // std::uniform_real_distribution>::type GType; - // GType dist_discrete_uniform(lower, upper); - std::uniform_int_distribution dist_discrete_uniform(lower, upper); - return dist_discrete_uniform(&state_); - } - private: RandGenerator *global_gen_; int global_state_idx_; diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index b19e8830c346..392e223d44aa 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -776,7 +776,7 @@ def test_randint(): 'high': 3, 'shape' : (500, 500), 'dtype' : dtype, - 'ctx' : mx.context.current_context() + 'ctx' : mx.context.cpu() } mx.random.seed(128) ret1 = mx.nd.random.randint(**params).asnumpy() From e3f6afc2446c50630c6b36942f7111ad57b790a2 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 16 Oct 2018 19:02:22 -0700 Subject: [PATCH 14/30] empty commit From 70cb9af4501ead3e2a7350a113f114e1266d6ae1 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 17 Oct 2018 10:59:23 -0700 Subject: [PATCH 15/30] temporary fix : batchnorm flaky test skip --- tests/python/unittest/test_gluon.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index a6932d252dd2..e8ef704c4ec2 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -2007,6 +2007,7 @@ def hybrid_forward(self, F, x): @with_seed() +@unittest.skip('Flaky test: https://github.com/apache/incubator-mxnet/issues/12767') def test_slice_batchnorm_reshape_batchnorm(): class Net(gluon.HybridBlock): def __init__(self, shape, slice, **kwargs): From bf47cde69a55bf777f5108ef4dc4a2d486965496 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Thu, 18 Oct 2018 12:54:32 -0700 Subject: [PATCH 16/30] removed randn symbol specific code since other PR is on it --- python/mxnet/symbol/random.py | 27 --------------------------- tests/python/unittest/test_random.py | 6 +++--- 2 files changed, 3 insertions(+), 30 deletions(-) diff --git a/python/mxnet/symbol/random.py b/python/mxnet/symbol/random.py index 1be8d6aac1bd..3e2667b1db46 100644 --- a/python/mxnet/symbol/random.py +++ b/python/mxnet/symbol/random.py @@ -96,33 +96,6 @@ def normal(loc=0, scale=1, shape=_Null, dtype=_Null, **kwargs): [loc, scale], shape, dtype, kwargs) -def randn(loc=0, scale=1, shape=_Null, dtype=_Null, **kwargs): - """Draw random samples from a normal (Gaussian) distribution. - - Samples are distributed according to a normal distribution parametrized - by *loc* (mean) and *scale* (standard deviation). - - - Parameters - ---------- - loc : float or NDArray, optional - Mean (centre) of the distribution. - scale : float or NDArray, optional - Standard deviation (spread or width) of the distribution. 
- shape : int or tuple of ints, optional - The number of samples to draw. If shape is, e.g., `(m, n)` and `loc` and - `scale` are scalars, output shape will be `(m, n)`. If `loc` and `scale` - are NDArrays with shape, e.g., `(x, y)`, then output will have shape - `(x, y, m, n)`, where `m*n` samples are drawn for each `[loc, scale)` pair. - dtype : {'float16','float32', 'float64'}, optional - Data type of output samples. Default is 'float32' - """ - assert isinstance(loc, (int, float)) - assert isinstance(scale, (int, float)) - return _random_helper(_internal._random_normal, _internal._sample_normal, - [loc, scale], shape, dtype, kwargs) - - def poisson(lam=1, shape=_Null, dtype=_Null, **kwargs): """Draw random samples from a Poisson distribution. diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index 392e223d44aa..5a29b8fe5e03 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -133,9 +133,9 @@ def check_with_device(device, dtype): params = symbdic['params'].copy() params.update(shape=shape, dtype=dtype, ctx=device) args = () - # if name == 'randn': - # params.pop('shape') # randn does not accept shape param - # args = shape + if name == 'randn': + params.pop('shape') # randn does not accept shape param + args = shape mx.random.seed(128) ret1 = ndop(*args, **params).asnumpy() mx.random.seed(128) From c3abb3a626986da03481862c0a93162727eb09aa Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Thu, 18 Oct 2018 14:35:03 -0700 Subject: [PATCH 17/30] revert ndarray/randn for compatibility --- python/mxnet/ndarray/random.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py index a0489861d893..15f47ff05233 100644 --- a/python/mxnet/ndarray/random.py +++ b/python/mxnet/ndarray/random.py @@ -152,7 +152,7 @@ def normal(loc=0, scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwarg [loc, scale], shape, dtype, ctx, out, kwargs) -def randn(loc=0, scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs): +def randn(*shape, **kwargs): """Draw random samples from a normal (Gaussian) distribution. Samples are distributed according to a normal distribution parametrized @@ -161,21 +161,21 @@ def randn(loc=0, scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs Parameters ---------- - loc : float or NDArray, optional + loc : float or NDArray Mean (centre) of the distribution. - scale : float or NDArray, optional + scale : float or NDArray Standard deviation (spread or width) of the distribution. - shape : int or tuple of ints, optional + shape : int or tuple of ints The number of samples to draw. If shape is, e.g., `(m, n)` and `loc` and `scale` are scalars, output shape will be `(m, n)`. If `loc` and `scale` are NDArrays with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[loc, scale)` pair. - dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16','float32', 'float64'} Data type of output samples. Default is 'float32' - ctx : Context, optional + ctx : Context Device context of output. Default is current context. Overridden by `loc.context` when `loc` is an NDArray. - out : NDArray, optional + out : NDArray Store output to an existing NDArray. 
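A quick sketch of the calling convention this revert preserves (illustrative
only; `randn` takes its shape positionally, while `normal` takes a `shape`
keyword, per the docstrings above):

    mx.nd.random.randn(2, 3, loc=5, scale=2)            # shape (2, 3)
    mx.nd.random.normal(loc=5, scale=2, shape=(2, 3))   # equivalent draw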
@@ -193,6 +193,11 @@ def randn(loc=0, scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs [5.357444 5.7793283 3.9896927]] """ + loc = kwargs.pop('loc', 0) + scale = kwargs.pop('scale', 1) + dtype = kwargs.pop('dtype', _Null) + ctx = kwargs.pop('ctx', None) + out = kwargs.pop('out', None) assert isinstance(loc, (int, float)) assert isinstance(scale, (int, float)) return _random_helper(_internal._random_normal, _internal._sample_normal, From 4bef3af3e61b994fcd3dc795678d5d38ebd54f00 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Fri, 19 Oct 2018 15:51:24 -0700 Subject: [PATCH 18/30] added unit test for checking extremes and uniform distribution for sufficient samples --- tests/python/unittest/test_random.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index 5a29b8fe5e03..7979b4fd7476 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -785,6 +785,32 @@ def test_randint(): assert same(ret1, ret2), \ "ndarray test: `%s` should give the same result with the same seed" +@with_seed() +def test_randint_extremes(): + a = mx.nd.random.randint(dtype='int64', low=50000000, high=50000010) + assert a>50000000 and a<50000010 + +@with_seed() +def test_randint_generator(): + ctx = mx.context.cpu() + low = 50000000 + high = 50001000 + for dtype in ['int64']: + print("ctx=%s, dtype=%s, Low=%g, High=%g:" % (ctx, dtype, low, high)) + scale = high - low + buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=scale), 5) + # Quantize bucket boundaries to reflect the actual dtype and adjust probs accordingly + buckets = np.array(buckets, dtype=dtype).tolist() + probs = [(buckets[i][1] - buckets[i][0])/scale for i in range(5)] + generator_mx = lambda x: mx.nd.random.randint(low, high, shape=x, ctx=ctx, dtype=dtype).asnumpy() + verify_generator(generator=generator_mx, buckets=buckets, probs=probs) + generator_mx_same_seed = \ + lambda x: np.concatenate( + [mx.nd.random.randint(low, high, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy() + for _ in range(10)]) + verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs) + + if __name__ == '__main__': import nose nose.runmodule() From c201fb3f973d38965a29eddc6a257d4984a64297 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Mon, 29 Oct 2018 15:28:03 -0700 Subject: [PATCH 19/30] increased the high val --- tests/python/unittest/test_random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index 7979b4fd7476..ca7913355c08 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -794,7 +794,7 @@ def test_randint_extremes(): def test_randint_generator(): ctx = mx.context.cpu() low = 50000000 - high = 50001000 + high = 50010000 for dtype in ['int64']: print("ctx=%s, dtype=%s, Low=%g, High=%g:" % (ctx, dtype, low, high)) scale = high - low From 851130396b46c60a0ac0649a457abdea772ccb77 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Sat, 3 Nov 2018 11:47:20 -0700 Subject: [PATCH 20/30] int32 to int64 support, indentation fix, check for optype correctly based on type of random function --- include/mxnet/random_generator.h | 4 +- python/mxnet/ndarray/random.py | 26 ++++++------- python/mxnet/symbol/random.py | 24 ++++++------ src/operator/random/sample_op.cc | 50 ++++++++++++++----------- src/operator/random/sample_op.h | 56 
++++++++++++++++++++++------ tests/python/unittest/test_random.py | 2 +- 6 files changed, 98 insertions(+), 64 deletions(-) diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h index 20f8fa462c40..cc1e7591b239 100644 --- a/include/mxnet/random_generator.h +++ b/include/mxnet/random_generator.h @@ -71,12 +71,12 @@ class RandGenerator { return dist_uniform(*engine_); } - MSHADOW_XINLINE int discrete_uniform(const int lower, const int upper) { + MSHADOW_XINLINE int64_t discrete_uniform(const int64_t lower, const int64_t upper) { // typedef typename std::conditional::value, // std::uniform_int_distribution, // std::uniform_real_distribution>::type GType; // GType dist_discrete_uniform(lower, upper); - std::uniform_int_distribution dist_discrete_uniform(lower, upper); + std::uniform_int_distribution dist_discrete_uniform(lower, upper); return dist_discrete_uniform(*engine_); } diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py index 15f47ff05233..a05ac75c5242 100644 --- a/python/mxnet/ndarray/random.py +++ b/python/mxnet/ndarray/random.py @@ -70,7 +70,7 @@ def uniform(low=0, high=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwarg `high` are scalars, output shape will be `(m, n)`. If `low` and `high` are NDArrays with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[low, high)` pair. - dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' ctx : Context, optional Device context of output. Default is current context. Overridden by @@ -120,7 +120,7 @@ def normal(loc=0, scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwarg `scale` are scalars, output shape will be `(m, n)`. If `loc` and `scale` are NDArrays with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[loc, scale)` pair. - dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' ctx : Context, optional Device context of output. Default is current context. Overridden by @@ -170,7 +170,7 @@ def randn(*shape, **kwargs): `scale` are scalars, output shape will be `(m, n)`. If `loc` and `scale` are NDArrays with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[loc, scale)` pair. - dtype : {'float16','float32', 'float64'} + dtype : {'float16', 'float32', 'float64'} Data type of output samples. Default is 'float32' ctx : Context Device context of output. Default is current context. Overridden by @@ -219,7 +219,7 @@ def poisson(lam=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs): a scalar, output shape will be `(m, n)`. If `lam` is an NDArray with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each entry in `lam`. - dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' ctx : Context, optional Device context of output. Default is current context. Overridden by @@ -266,7 +266,7 @@ def exponential(scale=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs) a scalar, output shape will be `(m, n)`. If `scale` is an NDArray with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each entry in `scale`. 
- dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' ctx : Context, optional Device context of output. Default is current context. Overridden by @@ -312,7 +312,7 @@ def gamma(alpha=1, beta=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwarg `beta` are scalars, output shape will be `(m, n)`. If `alpha` and `beta` are NDArrays with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[alpha, beta)` pair. - dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' ctx : Context, optional Device context of output. Default is current context. Overridden by @@ -361,7 +361,7 @@ def negative_binomial(k=1, p=1, shape=_Null, dtype=_Null, ctx=None, `p` are scalars, output shape will be `(m, n)`. If `k` and `p` are NDArrays with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[k, p)` pair. - dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' ctx : Context, optional Device context of output. Default is current context. Overridden by @@ -412,7 +412,7 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, ctx=N `alpha` are scalars, output shape will be `(m, n)`. If `mu` and `alpha` are NDArrays with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[mu, alpha)` pair. - dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' ctx : Context, optional Device context of output. Default is current context. Overridden by @@ -528,18 +528,16 @@ def randint(low=0, high=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwarg Parameters ---------- - low : float or NDArray, optional + low : int, optional Lower boundary of the output interval. All values generated will be greater than or equal to low. The default value is 0. - high : float or NDArray, optional + high : int, optional Upper boundary of the output interval. All values generated will be less than high. The default value is 1. shape : int or tuple of ints, optional The number of samples to draw. If shape is, e.g., `(m, n)` and `low` and - `high` are scalars, output shape will be `(m, n)`. If `low` and `high` - are NDArrays with shape, e.g., `(x, y)`, then output will have shape - `(x, y, m, n)`, where `m*n` samples are drawn for each `[low, high)` pair. - dtype : {'uint8','int32','int8', 'int64'}, optional + `high` are scalars, output shape will be `(m, n)`. + dtype : {'int32', 'int64'}, optional Data type of output samples. Default is 'int32' ctx : Context, optional Device context of output. Default is current context. Overridden by diff --git a/python/mxnet/symbol/random.py b/python/mxnet/symbol/random.py index 3e2667b1db46..0a3b0ff39508 100644 --- a/python/mxnet/symbol/random.py +++ b/python/mxnet/symbol/random.py @@ -64,7 +64,7 @@ def uniform(low=0, high=1, shape=_Null, dtype=_Null, **kwargs): `high` are scalars, output shape will be `(m, n)`. If `low` and `high` are Symbols with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[low, high)` pair. 
- dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' """ return _random_helper(_internal._random_uniform, _internal._sample_uniform, @@ -89,7 +89,7 @@ def normal(loc=0, scale=1, shape=_Null, dtype=_Null, **kwargs): `scale` are scalars, output shape will be `(m, n)`. If `loc` and `scale` are Symbols with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[loc, scale)` pair. - dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' """ return _random_helper(_internal._random_normal, _internal._sample_normal, @@ -111,7 +111,7 @@ def poisson(lam=1, shape=_Null, dtype=_Null, **kwargs): a scalar, output shape will be `(m, n)`. If `lam` is an Symbol with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each entry in `lam`. - dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' """ return _random_helper(_internal._random_poisson, _internal._sample_poisson, @@ -137,7 +137,7 @@ def exponential(scale=1, shape=_Null, dtype=_Null, **kwargs): a scalar, output shape will be `(m, n)`. If `scale` is an Symbol with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each entry in `scale`. - dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' """ return _random_helper(_internal._random_exponential, _internal._sample_exponential, @@ -162,7 +162,7 @@ def gamma(alpha=1, beta=1, shape=_Null, dtype=_Null, **kwargs): `beta` are scalars, output shape will be `(m, n)`. If `alpha` and `beta` are Symbols with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[alpha, beta)` pair. - dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' """ return _random_helper(_internal._random_gamma, _internal._sample_gamma, @@ -188,7 +188,7 @@ def negative_binomial(k=1, p=1, shape=_Null, dtype=_Null, **kwargs): `p` are scalars, output shape will be `(m, n)`. If `k` and `p` are Symbols with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[k, p)` pair. - dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' """ return _random_helper(_internal._random_negative_binomial, @@ -216,7 +216,7 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, **kwa `alpha` are scalars, output shape will be `(m, n)`. If `mu` and `alpha` are Symbols with shape, e.g., `(x, y)`, then output will have shape `(x, y, m, n)`, where `m*n` samples are drawn for each `[mu, alpha)` pair. - dtype : {'float16','float32', 'float64'}, optional + dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. 
Default is 'float32' """ return _random_helper(_internal._random_generalized_negative_binomial, @@ -291,18 +291,16 @@ def randint(low=0, high=1, shape=_Null, dtype=_Null, **kwargs): Parameters ---------- - low : float or NDArray, optional + low : int, optional Lower boundary of the output interval. All values generated will be greater than or equal to low. The default value is 0. - high : float or NDArray, optional + high : int, optional Upper boundary of the output interval. All values generated will be less than high. The default value is 1. shape : int or tuple of ints, optional The number of samples to draw. If shape is, e.g., `(m, n)` and `low` and - `high` are scalars, output shape will be `(m, n)`. If `low` and `high` - are NDArrays with shape, e.g., `(x, y)`, then output will have shape - `(x, y, m, n)`, where `m*n` samples are drawn for each `[low, high)` pair. - dtype : {'uint8','int32','int8', 'int64'}, optional + `high` are scalars, output shape will be `(m, n)`. + dtype : {'int32', 'int64'}, optional Data type of output samples. Default is 'int32' """ return _random_helper(_internal._random_randint, _internal._sample_uniform, diff --git a/src/operator/random/sample_op.cc b/src/operator/random/sample_op.cc index 2158a7ad1f1f..66bee7185ce9 100644 --- a/src/operator/random/sample_op.cc +++ b/src/operator/random/sample_op.cc @@ -52,7 +52,6 @@ DMLC_REGISTER_PARAMETER(SampleGenNegBinomialLikeParam); .set_num_outputs(1) \ .set_attr_parser(ParamParser) \ .set_attr("FInferShape", InitShape) \ - .set_attr("FInferType", SampleOpType) \ .set_attr("FResourceRequest", SampleResource) \ .add_arguments(ParamType::__FIELDS__()) \ .set_attr("FInferStorageType", InitStorageType) \ @@ -93,7 +92,8 @@ Example:: uniform(low=0, high=1, shape=(2,2)) = [[ 0.60276335, 0.85794562], [ 0.54488319, 0.84725171]] -)code" ADD_FILELINE); +)code" ADD_FILELINE) +.set_attr("FInferType", SampleOpType); // Add "normal" alias for backward compatibility MXNET_OPERATOR_REGISTER_SAMPLE(_random_normal, SampleNormalParam) @@ -110,7 +110,8 @@ Example:: normal(loc=0, scale=1, shape=(2,2)) = [[ 1.89171135, -1.16881478], [-1.23474145, 1.55807114]] -)code" ADD_FILELINE); +)code" ADD_FILELINE) +.set_attr("FInferType", SampleOpType); MXNET_OPERATOR_REGISTER_SAMPLE(_random_gamma, SampleGammaParam) .add_alias("random_gamma") @@ -122,7 +123,8 @@ Example:: gamma(alpha=9, beta=0.5, shape=(2,2)) = [[ 7.10486984, 3.37695289], [ 3.91697288, 3.65933681]] -)code" ADD_FILELINE); +)code" ADD_FILELINE) +.set_attr("FInferType", SampleOpType); MXNET_OPERATOR_REGISTER_SAMPLE(_random_exponential, SampleExponentialParam) .add_alias("random_exponential") @@ -134,7 +136,8 @@ Example:: exponential(lam=4, shape=(2,2)) = [[ 0.0097189 , 0.08999364], [ 0.04146638, 0.31715935]] -)code" ADD_FILELINE); +)code" ADD_FILELINE) +.set_attr("FInferType", SampleOpType); MXNET_OPERATOR_REGISTER_SAMPLE(_random_poisson, SamplePoissonParam) .add_alias("random_poisson") @@ -147,7 +150,8 @@ Example:: poisson(lam=4, shape=(2,2)) = [[ 5., 2.], [ 4., 6.]] -)code" ADD_FILELINE); +)code" ADD_FILELINE +).set_attr("FInferType", SampleOpType); MXNET_OPERATOR_REGISTER_SAMPLE(_random_negative_binomial, SampleNegBinomialParam) .add_alias("random_negative_binomial") @@ -161,7 +165,8 @@ Example:: negative_binomial(k=3, p=0.4, shape=(2,2)) = [[ 4., 7.], [ 2., 5.]] -)code" ADD_FILELINE); +)code" ADD_FILELINE) +.set_attr("FInferType", SampleOpType); MXNET_OPERATOR_REGISTER_SAMPLE(_random_generalized_negative_binomial, SampleGenNegBinomialParam) 
.add_alias("random_generalized_negative_binomial") @@ -176,8 +181,23 @@ Example:: generalized_negative_binomial(mu=2.0, alpha=0.3, shape=(2,2)) = [[ 2., 1.], [ 6., 4.]] -)code" ADD_FILELINE); +)code" ADD_FILELINE) +.set_attr("FInferType", SampleOpType); +MXNET_OPERATOR_REGISTER_SAMPLE(_random_randint, SampleRandIntParam) +.add_alias("random_randint") +.describe(R"code(Draw random samples from a discrete uniform distribution. + +Samples are uniformly distributed over the half-open interval *[low, high)* +(includes *low*, but excludes *high*). + +Example:: + + randint(low=0, high=5, shape=(2,2)) = [[ 0, 2], + [ 3, 1]] + +)code" ADD_FILELINE) +.set_attr("FInferType", RandIntOpType); // *_like operators @@ -269,19 +289,5 @@ Example:: [ 6., 4.]] )code" ADD_FILELINE); -MXNET_OPERATOR_REGISTER_SAMPLE(_random_randint, SampleRandIntParam) -.add_alias("random_randint") -.describe(R"code(Draw random samples from a discrete uniform distribution. - -Samples are uniformly distributed over the half-open interval *[low, high)* -(includes *low*, but excludes *high*). - -Example:: - - randint(low=0, high=5, shape=(2,2)) = [[ 0, 2], - [ 3, 1]] - -)code" ADD_FILELINE); - } // namespace op } // namespace mxnet diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h index 510452ef8cf1..c0d5e1c0348f 100644 --- a/src/operator/random/sample_op.h +++ b/src/operator/random/sample_op.h @@ -79,8 +79,8 @@ struct GenNegBinomialParam { }; struct RandIntParam { - int low; - int high; + int64_t low; + int64_t high; }; struct SampleUniformParam : public dmlc::Parameter, @@ -505,14 +505,14 @@ static inline void gen_neg_binomial_op(const nnvm::NodeAttrs& attrs, template static inline void rand_int_op(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const OpReqType& req, - TBlob* outputs) { + const OpContext& ctx, + const OpReqType& req, + TBlob* outputs) { Stream *s = ctx.get_stream(); const SampleRandIntParam& param = nnvm::get(attrs.parsed); CHECK_GE(param.high, param.low) << "low must be less or equal to high in uniform distribution"; - Tensor low, high; - GetSamplingTempData(param.low, param.high, ctx, + Tensor low, high; + GetSamplingTempData(param.low, param.high, ctx, &low, &high); RandIntSampler sampler; MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { @@ -750,12 +750,44 @@ inline bool SampleOpType(const nnvm::NodeAttrs& attrs, } } bool dtype_ok = (dtype == kFloat16) || (dtype == kFloat32) || - (dtype == kFloat64) || (dtype == kUint8) || (dtype == kInt32) || - (dtype == kInt8) || (dtype == kInt64); - CHECK_EQ(dtype_ok, true) << "Output type must be float16, float32, float64, uint8, " - "int32, int8, or int64: dtype is " + (dtype == kFloat64); + CHECK_EQ(dtype_ok, true) << "Output type must be float16, float32, float64: dtype is " << dtype_out << " vs " << kFloat16 << " or " << kFloat32 << " or " - << kFloat64 << " or " << kUint8 << " or " << kInt32 << " or " << kInt8 << " or " << kInt64; + << kFloat64; + TYPE_ASSIGN_CHECK(*out_type, 0, dtype); + return true; +} + +template +inline bool RandIntOpType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + const ParamType& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_type->size(), 0); + CHECK_EQ(out_type->size(), 1); + int dtype = -1; + int dtype_out = (*out_type)[0]; + if (dtype_out != -1) { + // Output type can be inferred, use it and make sure it matches + dtype = dtype_out; + if (param.dtype != -1) { + // dtype given in args, check that it matches the output type + CHECK_EQ(dtype_out, param.dtype) << 
"Output type does not match requested type: " + << dtype_out << " vs " << param.dtype; + } + } else { + // Output type can't be inferred + if (param.dtype != -1) { + // Use dtype given in args + dtype = param.dtype; + } else { + // Use default + dtype = kFloat32; + } + } + bool dtype_ok = (dtype == kInt32) || (dtype == kInt64); + CHECK_EQ(dtype_ok, true) << "Output type must be int32, int64: dtype is " + << dtype_out << " vs " << kInt32 << " or " << kInt64; TYPE_ASSIGN_CHECK(*out_type, 0, dtype); return true; } diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index 4a0c981460e5..66153d1b8bc8 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -872,7 +872,7 @@ def test_randint_generator(): ctx = mx.context.cpu() low = 50000000 high = 50010000 - for dtype in ['int64']: + for dtype in ['int32', 'int64']: print("ctx=%s, dtype=%s, Low=%g, High=%g:" % (ctx, dtype, low, high)) scale = high - low buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=scale), 5) From d7aa9bd6e2d9f1400c052f62c66085e9c0b31ad2 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 6 Nov 2018 11:08:19 -0800 Subject: [PATCH 21/30] gpu support, revert finfertype using template specialization, remove defaults, prints, test other low high val --- include/mxnet/random_generator.h | 8 ++----- python/mxnet/ndarray/random.py | 10 ++++---- python/mxnet/symbol/random.py | 10 ++++---- src/operator/random/sample_op.cc | 25 +++++++------------ src/operator/random/sample_op.cu | 1 + src/operator/random/sample_op.h | 16 ++++++++----- tests/python/unittest/test_random.py | 36 ++++++++++------------------ 7 files changed, 47 insertions(+), 59 deletions(-) diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h index cc1e7591b239..8ed55f0c70e4 100644 --- a/include/mxnet/random_generator.h +++ b/include/mxnet/random_generator.h @@ -71,12 +71,8 @@ class RandGenerator { return dist_uniform(*engine_); } - MSHADOW_XINLINE int64_t discrete_uniform(const int64_t lower, const int64_t upper) { - // typedef typename std::conditional::value, - // std::uniform_int_distribution, - // std::uniform_real_distribution>::type GType; - // GType dist_discrete_uniform(lower, upper); - std::uniform_int_distribution dist_discrete_uniform(lower, upper); + MSHADOW_XINLINE DType discrete_uniform(const int64_t lower, const int64_t upper) { + std::uniform_int_distribution dist_discrete_uniform(lower, upper); return dist_discrete_uniform(*engine_); } diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py index a05ac75c5242..67555118daf1 100644 --- a/python/mxnet/ndarray/random.py +++ b/python/mxnet/ndarray/random.py @@ -520,17 +520,19 @@ def shuffle(data, **kwargs): return _internal._shuffle(data, **kwargs) -def randint(low=0, high=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs): +def randint(low, high=None, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs): """Draw random samples from a discrete uniform distribution. Samples are uniformly distributed over the half-open interval *[low, high)* - (includes *low*, but excludes *high*). + (includes *low*, but excludes *high*). If high is None (the default), then results + are from *[0, low)*. Parameters ---------- - low : int, optional + low : int, required Lower boundary of the output interval. All values generated will be - greater than or equal to low. The default value is 0. + greater than or equal to low. 
If high is None, this parameter is treated as + the exclusive upper bound and low is set to 0. high : int, optional Upper boundary of the output interval. All values generated will be less than high. The default value is 1. diff --git a/python/mxnet/symbol/random.py b/python/mxnet/symbol/random.py index 0a3b0ff39508..887ef851d32f 100644 --- a/python/mxnet/symbol/random.py +++ b/python/mxnet/symbol/random.py @@ -283,17 +283,19 @@ def shuffle(data, **kwargs): return _internal._shuffle(data, **kwargs) -def randint(low=0, high=1, shape=_Null, dtype=_Null, **kwargs): +def randint(low, high=None, shape=_Null, dtype=_Null, **kwargs): """Draw random samples from a discrete uniform distribution. Samples are uniformly distributed over the half-open interval *[low, high)* - (includes *low*, but excludes *high*). + (includes *low*, but excludes *high*). If high is None (the default), then results + are from *[0, low)*. Parameters ---------- - low : int, optional + low : int, required Lower boundary of the output interval. All values generated will be - greater than or equal to low. The default value is 0. + greater than or equal to low. If high is None, this parameter is treated as + the exclusive upper bound and low is set to 0. high : int, optional Upper boundary of the output interval. All values generated will be less than high. The default value is 1. diff --git a/src/operator/random/sample_op.cc b/src/operator/random/sample_op.cc index 66bee7185ce9..fdbf64e92006 100644 --- a/src/operator/random/sample_op.cc +++ b/src/operator/random/sample_op.cc @@ -52,6 +52,7 @@ DMLC_REGISTER_PARAMETER(SampleGenNegBinomialLikeParam); .set_num_outputs(1) \ .set_attr_parser(ParamParser) \ .set_attr("FInferShape", InitShape) \ + .set_attr("FInferType", SampleOpType) \ .set_attr("FResourceRequest", SampleResource) \ .add_arguments(ParamType::__FIELDS__()) \ .set_attr("FInferStorageType", InitStorageType) \ @@ -92,8 +93,7 @@ Example:: uniform(low=0, high=1, shape=(2,2)) = [[ 0.60276335, 0.85794562], [ 0.54488319, 0.84725171]] -)code" ADD_FILELINE) -.set_attr("FInferType", SampleOpType); +)code" ADD_FILELINE); // Add "normal" alias for backward compatibility MXNET_OPERATOR_REGISTER_SAMPLE(_random_normal, SampleNormalParam) @@ -110,8 +110,7 @@ Example:: normal(loc=0, scale=1, shape=(2,2)) = [[ 1.89171135, -1.16881478], [-1.23474145, 1.55807114]] -)code" ADD_FILELINE) -.set_attr("FInferType", SampleOpType); +)code" ADD_FILELINE); MXNET_OPERATOR_REGISTER_SAMPLE(_random_gamma, SampleGammaParam) .add_alias("random_gamma") @@ -123,8 +122,7 @@ Example:: gamma(alpha=9, beta=0.5, shape=(2,2)) = [[ 7.10486984, 3.37695289], [ 3.91697288, 3.65933681]] -)code" ADD_FILELINE) -.set_attr("FInferType", SampleOpType); +)code" ADD_FILELINE); MXNET_OPERATOR_REGISTER_SAMPLE(_random_exponential, SampleExponentialParam) .add_alias("random_exponential") @@ -136,8 +134,7 @@ Example:: exponential(lam=4, shape=(2,2)) = [[ 0.0097189 , 0.08999364], [ 0.04146638, 0.31715935]] -)code" ADD_FILELINE) -.set_attr("FInferType", SampleOpType); +)code" ADD_FILELINE); MXNET_OPERATOR_REGISTER_SAMPLE(_random_poisson, SamplePoissonParam) .add_alias("random_poisson") @@ -150,8 +147,7 @@ Example:: poisson(lam=4, shape=(2,2)) = [[ 5., 2.], [ 4., 6.]] -)code" ADD_FILELINE -).set_attr("FInferType", SampleOpType); +)code" ADD_FILELINE); MXNET_OPERATOR_REGISTER_SAMPLE(_random_negative_binomial, SampleNegBinomialParam) .add_alias("random_negative_binomial") @@ -165,8 +161,7 @@ Example:: negative_binomial(k=3, p=0.4, shape=(2,2)) = [[ 4., 7.], [ 2., 5.]] -)code" ADD_FILELINE)
-.set_attr("FInferType", SampleOpType); +)code" ADD_FILELINE); MXNET_OPERATOR_REGISTER_SAMPLE(_random_generalized_negative_binomial, SampleGenNegBinomialParam) .add_alias("random_generalized_negative_binomial") @@ -181,8 +176,7 @@ Example:: generalized_negative_binomial(mu=2.0, alpha=0.3, shape=(2,2)) = [[ 2., 1.], [ 6., 4.]] -)code" ADD_FILELINE) -.set_attr("FInferType", SampleOpType); +)code" ADD_FILELINE); MXNET_OPERATOR_REGISTER_SAMPLE(_random_randint, SampleRandIntParam) .add_alias("random_randint") @@ -196,8 +190,7 @@ Example:: randint(low=0, high=5, shape=(2,2)) = [[ 0, 2], [ 3, 1]] -)code" ADD_FILELINE) -.set_attr("FInferType", RandIntOpType); +)code" ADD_FILELINE); // *_like operators diff --git a/src/operator/random/sample_op.cu b/src/operator/random/sample_op.cu index 55c04a989a01..016ce6d487a7 100644 --- a/src/operator/random/sample_op.cu +++ b/src/operator/random/sample_op.cu @@ -39,6 +39,7 @@ MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_exponential, SampleExponentialParam) MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_poisson, SamplePoissonParam) MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_negative_binomial, SampleNegBinomialParam) MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_generalized_negative_binomial, SampleGenNegBinomialParam) +MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_int, SampleRandIntParam) MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_uniform_like, SampleUniformLikeParam) MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_normal_like, SampleNormalLikeParam) MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_gamma_like, SampleGammaLikeParam) diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h index c0d5e1c0348f..cca0d6b6eacb 100644 --- a/src/operator/random/sample_op.h +++ b/src/operator/random/sample_op.h @@ -257,9 +257,9 @@ struct SampleGenNegBinomialParam : public dmlc::Parameter, RandIntParam, SampleOpParam { DMLC_DECLARE_PARAMETER(SampleRandIntParam) { - DMLC_DECLARE_FIELD(low).set_default(0) + DMLC_DECLARE_FIELD(low) .describe("Lower bound of the distribution."); - DMLC_DECLARE_FIELD(high).set_default(1) + DMLC_DECLARE_FIELD(high) .describe("Upper bound of the distribution."); DMLC_DECLARE_FIELD(shape) .set_default(TShape()) @@ -508,8 +508,12 @@ static inline void rand_int_op(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const OpReqType& req, TBlob* outputs) { - Stream *s = ctx.get_stream(); + Stream *s = ctx.get_stream(); const SampleRandIntParam& param = nnvm::get(attrs.parsed); + if(param.high==None){ + param.high = param.low; + param.low = 0; + } CHECK_GE(param.high, param.low) << "low must be less or equal to high in uniform distribution"; Tensor low, high; GetSamplingTempData(param.low, param.high, ctx, @@ -758,11 +762,11 @@ inline bool SampleOpType(const nnvm::NodeAttrs& attrs, return true; } -template -inline bool RandIntOpType(const nnvm::NodeAttrs& attrs, +template<> +inline bool SampleOpType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { - const ParamType& param = nnvm::get(attrs.parsed); + const SampleRandIntParam& param = nnvm::get(attrs.parsed); CHECK_EQ(in_type->size(), 0); CHECK_EQ(out_type->size(), 1); int dtype = -1; diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index 66153d1b8bc8..36159458d817 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -505,7 +505,6 @@ def test_normal_generator(): num_buckets = 5 for dtype in ['float16', 'float32', 'float64']: for mu, sigma in [(0.0, 1.0), (1.0, 5.0)]: - print("ctx=%s, dtype=%s, 
Mu=%g, Sigma=%g:" % (ctx, dtype, mu, sigma)) buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.norm.ppf(x, mu, sigma), num_buckets) # Quantize bucket boundaries to reflect the actual dtype and adjust probs accordingly buckets = np.array(buckets, dtype=dtype).tolist() @@ -526,7 +525,6 @@ def test_uniform_generator(): ctx = mx.context.current_context() for dtype in ['float16', 'float32', 'float64']: for low, high in [(-1.0, 1.0), (1.0, 3.0)]: - print("ctx=%s, dtype=%s, Low=%g, High=%g:" % (ctx, dtype, low, high)) scale = high - low buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=scale), 5) # Quantize bucket boundaries to reflect the actual dtype and adjust probs accordingly @@ -546,7 +544,6 @@ def test_gamma_generator(): ctx = mx.context.current_context() for dtype in ['float16', 'float32', 'float64']: for kappa, theta in [(0.5, 1.0), (1.0, 5.0)]: - print("ctx=%s, dtype=%s, Shape=%g, Scale=%g:" % (ctx, dtype, kappa, theta)) buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.gamma.ppf(x, a=kappa, loc=0, scale=theta), 5) generator_mx = lambda x: mx.nd.random.gamma(kappa, theta, shape=x, ctx=ctx, dtype=dtype).asnumpy() verify_generator(generator=generator_mx, buckets=buckets, probs=probs, success_rate=success_rate) @@ -561,7 +558,6 @@ def test_exponential_generator(): ctx = mx.context.current_context() for dtype in ['float16', 'float32', 'float64']: for scale in [0.1, 1.0]: - print("ctx=%s, dtype=%s, Scale=%g:" % (ctx, dtype, scale)) buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.expon.ppf(x, loc=0, scale=scale), 5) generator_mx = lambda x: mx.nd.random.exponential(scale, shape=x, ctx=ctx, dtype=dtype).asnumpy() verify_generator(generator=generator_mx, buckets=buckets, probs=probs) @@ -576,7 +572,6 @@ def test_poisson_generator(): ctx = mx.context.current_context() for dtype in ['float16', 'float32', 'float64']: for lam in [1, 10]: - print("ctx=%s, dtype=%s, Lambda=%d:" % (ctx, dtype, lam)) buckets = [(-1.0, lam - 0.5), (lam - 0.5, 2 * lam + 0.5), (2 * lam + 0.5, np.inf)] probs = [ss.poisson.cdf(bucket[1], lam) - ss.poisson.cdf(bucket[0], lam) for bucket in buckets] generator_mx = lambda x: mx.nd.random.poisson(lam, shape=x, ctx=ctx, dtype=dtype).asnumpy() @@ -593,7 +588,6 @@ def test_negative_binomial_generator(): for dtype in ['float16', 'float32', 'float64']: success_num = 2 success_prob = 0.2 - print("ctx=%s, dtype=%s, Success Num=%d:, Success Prob=%g" % (ctx, dtype, success_num, success_prob)) buckets = [(-1.0, 2.5), (2.5, 5.5), (5.5, 8.5), (8.5, np.inf)] probs = [ss.nbinom.cdf(bucket[1], success_num, success_prob) - ss.nbinom.cdf(bucket[0], success_num, success_prob) for bucket in buckets] @@ -606,7 +600,6 @@ def test_negative_binomial_generator(): for _ in range(10)]) verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs) # Also test the Gamm-Poisson Mixture - print('Gamm-Poisson Mixture Test:') alpha = 1.0 / success_num mu = (1.0 - success_prob) / success_prob / alpha generator_mx = lambda x: mx.nd.random.generalized_negative_binomial(mu, alpha, @@ -643,7 +636,6 @@ def quantize_probs(probs, dtype): trials = 5 buckets = list(range(6)) for dtype in ['float16', 'float32', 'float64']: - print("ctx=%s, dtype=%s" %(ctx, dtype)) quantized_probs = quantize_probs(probs, dtype) generator_mx = lambda x: mx.nd.random.multinomial(data=mx.nd.array(quantized_probs, ctx=ctx, dtype=dtype), shape=x).asnumpy() @@ -870,22 +862,20 @@ def test_randint_extremes(): @with_seed() def test_randint_generator(): ctx = 
mx.context.cpu() - low = 50000000 - high = 50010000 for dtype in ['int32', 'int64']: - print("ctx=%s, dtype=%s, Low=%g, High=%g:" % (ctx, dtype, low, high)) - scale = high - low - buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=scale), 5) - # Quantize bucket boundaries to reflect the actual dtype and adjust probs accordingly - buckets = np.array(buckets, dtype=dtype).tolist() - probs = [(buckets[i][1] - buckets[i][0]) / float(scale) for i in range(5)] - generator_mx = lambda x: mx.nd.random.randint(low, high, shape=x, ctx=ctx, dtype=dtype).asnumpy() - verify_generator(generator=generator_mx, buckets=buckets, probs=probs) - generator_mx_same_seed = \ - lambda x: np.concatenate( - [mx.nd.random.randint(low, high, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy() - for _ in range(10)]) - verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs) + for low, high in [(50000000, 50001000),(-50000000,-990000),(-500,199),(-2147483647,2147483647)]: + scale = high - low + buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=scale), 5) + # Quantize bucket boundaries to reflect the actual dtype and adjust probs accordingly + buckets = np.array(buckets, dtype=dtype).tolist() + probs = [(buckets[i][1] - buckets[i][0]) / float(scale) for i in range(5)] + generator_mx = lambda x: mx.nd.random.randint(low, high, shape=x, ctx=ctx, dtype=dtype).asnumpy() + verify_generator(generator=generator_mx, buckets=buckets, probs=probs) + generator_mx_same_seed = \ + lambda x: np.concatenate( + [mx.nd.random.randint(low, high, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy() + for _ in range(10)]) + verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs) if __name__ == '__main__': From 21d53568d7ec16e03d64b2edd963d57e46229d79 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 7 Nov 2018 17:54:43 -0800 Subject: [PATCH 22/30] fix for invalid template arg by checking for int32,int64 --- include/mxnet/random_generator.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h index 231f47289889..0c4f392f6e7d 100644 --- a/include/mxnet/random_generator.h +++ b/include/mxnet/random_generator.h @@ -73,7 +73,11 @@ class RandGenerator { } MSHADOW_XINLINE IType discrete_uniform(const int64_t lower, const int64_t upper) { - std::uniform_int_distribution dist_discrete_uniform(lower, upper); + typedef typename std::conditional, + std::uniform_int_distribution>::type GType; + GType dist_discrete_uniform(lower, upper); return dist_discrete_uniform(*engine_); } From 1254fa71172ede45392e53c68d080927bd9c96c0 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Thu, 8 Nov 2018 12:16:11 -0800 Subject: [PATCH 23/30] gpu randint in random_generator --- include/mxnet/random_generator.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h index 0c4f392f6e7d..7a5fadb5fe26 100644 --- a/include/mxnet/random_generator.h +++ b/include/mxnet/random_generator.h @@ -153,6 +153,10 @@ class RandGenerator { MSHADOW_FORCE_INLINE __device__ float normal() { return curand_normal(&state_); } + + MSHADOW_FORCE_INLINE __device__ int discrete_uniform() { + return curand(&state_); + } private: From 274366b61ba56ee3125712c7e1460aa20010c57f Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Mon, 12 Nov 2018 12:22:45 -0800 Subject: [PATCH 24/30] sample_uniform issue and param, removed old flaky test skip
line --- src/operator/random/sample_op.cc | 2 +- tests/python/unittest/test_gluon.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/operator/random/sample_op.cc b/src/operator/random/sample_op.cc index fdbf64e92006..b065615e1fb1 100644 --- a/src/operator/random/sample_op.cc +++ b/src/operator/random/sample_op.cc @@ -52,7 +52,7 @@ DMLC_REGISTER_PARAMETER(SampleGenNegBinomialLikeParam); .set_num_outputs(1) \ .set_attr_parser(ParamParser) \ .set_attr("FInferShape", InitShape) \ - .set_attr("FInferType", SampleOpType) \ + .set_attr("FInferType", SampleOpType) \ .set_attr("FResourceRequest", SampleResource) \ .add_arguments(ParamType::__FIELDS__()) \ .set_attr("FInferStorageType", InitStorageType) \ diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index e3c1f8ba9145..3049674821c9 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -2007,7 +2007,6 @@ def hybrid_forward(self, F, x): @with_seed() -@unittest.skip('Flaky test: https://github.com/apache/incubator-mxnet/issues/12767') def test_slice_batchnorm_reshape_batchnorm(): class Net(gluon.HybridBlock): def __init__(self, shape, slice, **kwargs): From e7e622c25f50ad0588ee9107e903279aea102ed3 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 13 Nov 2018 15:57:26 -0800 Subject: [PATCH 25/30] replaced discrete_uniform function by rand_int64 for consistency --- include/mxnet/random_generator.h | 35 +++++++++++--------------------- src/operator/random/sampler.h | 7 ++++++- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h index db7990c7a000..7296305b96df 100644 --- a/include/mxnet/random_generator.h +++ b/include/mxnet/random_generator.h @@ -65,6 +65,10 @@ class RandGenerator { MSHADOW_XINLINE int rand() { return engine_->operator()(); } + MSHADOW_XINLINE int64_t rand_int64() { + return static_cast(engine_->operator()() << 31) + engine_->operator()(); + } + MSHADOW_XINLINE FType uniform() { typedef typename std::conditional::value, std::uniform_int_distribution, @@ -73,15 +77,6 @@ class RandGenerator { return dist_uniform(*engine_); } - MSHADOW_XINLINE IType discrete_uniform(const int64_t lower, const int64_t upper) { - typedef typename std::conditional, - std::uniform_int_distribution>::type GType; - GType dist_discrete_uniform(lower, upper); - return dist_discrete_uniform(*engine_); - } - MSHADOW_XINLINE FType normal() { std::normal_distribution dist_normal; return dist_normal(*engine_); @@ -148,6 +143,10 @@ class RandGenerator { return curand(&state_); } + MSHADOW_FORCE_INLINE __device__ int64_t rand_int64() { + return static_cast(curand(&state_) << 31) + curand(&state_); + } + MSHADOW_FORCE_INLINE __device__ float uniform() { return static_cast(1.0) - curand_uniform(&state_); } @@ -156,13 +155,6 @@ class RandGenerator { return curand_normal(&state_); } - MSHADOW_FORCE_INLINE __device__ int discrete_uniform(const int64_t lower, const int64_t upper) { - float randu_f = curand_uniform(&state_); - randu_f *= (upper-lower+0.999999); - randu_f += lower; - return static_cast(trunc(randu_f)); - } - private: RandGenerator *global_gen_; int global_state_idx_; @@ -207,6 +199,10 @@ class RandGenerator { return curand(&state_); } + MSHADOW_FORCE_INLINE __device__ int64_t rand_int64() { + return static_cast(curand(&state_) << 31) + curand(&state_); + } + MSHADOW_FORCE_INLINE __device__ double uniform() { return static_cast(1.0) - curand_uniform_double(&state_); } @@ -215,13 
+211,6 @@ class RandGenerator { return curand_normal_double(&state_); } - MSHADOW_FORCE_INLINE __device__ int discrete_uniform(const int64_t lower, const int64_t upper) { - float randu_f = curand_uniform(&state_); - randu_f *= (upper-lower+0.999999); - randu_f += lower; - return static_cast(trunc(randu_f)); - } - private: RandGenerator *global_gen_; int global_state_idx_; diff --git a/src/operator/random/sampler.h b/src/operator/random/sampler.h index dec67da09d88..3e9147635a83 100644 --- a/src/operator/random/sampler.h +++ b/src/operator/random/sampler.h @@ -101,7 +101,12 @@ struct SampleRandIntKernel { const IType *lower, const IType *upper, OType *out) { RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, { index_t nBatch(1 + (nSample - 1) / nParm); - out[i] = OType(genImpl.discrete_uniform(lower[i / nBatch], upper[i / nBatch])); + if (sizeof(IType) == sizeof(int64_t)) + out[i] = OType(lower[i / nBatch] + + (upper[i / nBatch] - lower[i / nBatch]) * genImpl.rand_int64()); + else + out[i] = OType(lower[i / nBatch] + + (upper[i / nBatch] - lower[i / nBatch]) * genImpl.rand()); }); } }; From ce3849d6f630a6cd6bd3c4b256c3e37c10b321f3 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 14 Nov 2018 11:38:31 -0800 Subject: [PATCH 26/30] formula update and removed itype --- include/mxnet/random_generator.h | 2 -- python/mxnet/symbol/random.py | 2 +- src/operator/random/sampler.h | 4 ++-- tests/python/unittest/test_random.py | 2 +- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h index 7296305b96df..e7b419309cb7 100644 --- a/include/mxnet/random_generator.h +++ b/include/mxnet/random_generator.h @@ -55,8 +55,6 @@ class RandGenerator { public: typedef typename std::conditional::value, DType, double>::type FType; - typedef typename std::conditional::value, - DType, int>::type IType; explicit Impl(RandGenerator *gen, int state_idx) : engine_(gen->states_ + state_idx) {} diff --git a/python/mxnet/symbol/random.py b/python/mxnet/symbol/random.py index 53c2a48f4311..c5940ac96a50 100644 --- a/python/mxnet/symbol/random.py +++ b/python/mxnet/symbol/random.py @@ -303,5 +303,5 @@ def randint(low, high, shape=_Null, dtype=_Null, **kwargs): dtype : {'int32', 'int64'}, optional Data type of output samples. 
Default is 'int32' """ - return _random_helper(_internal._random_randint, _internal._sample_uniform, + return _random_helper(_internal._random_randint, None, [low, high], shape, dtype, kwargs) diff --git a/src/operator/random/sampler.h b/src/operator/random/sampler.h index 3e9147635a83..ca764e706c64 100644 --- a/src/operator/random/sampler.h +++ b/src/operator/random/sampler.h @@ -103,10 +103,10 @@ struct SampleRandIntKernel { index_t nBatch(1 + (nSample - 1) / nParm); if (sizeof(IType) == sizeof(int64_t)) out[i] = OType(lower[i / nBatch] + - (upper[i / nBatch] - lower[i / nBatch]) * genImpl.rand_int64()); + genImpl.rand_int64() % (upper[i / nBatch] - lower[i / nBatch])); else out[i] = OType(lower[i / nBatch] + - (upper[i / nBatch] - lower[i / nBatch]) * genImpl.rand()); + genImpl.rand() % (upper[i / nBatch] - lower[i / nBatch])); }); } }; diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index 36159458d817..7f1281b544b3 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -863,7 +863,7 @@ def test_randint_generator(): ctx = mx.context.cpu() for dtype in ['int32', 'int64']: - for low, high in [(50000000, 50001000),(-50000000,-990000),(-500,199),(-2147483647,2147483647)]: + for low, high in [(50000000, 50001000),(-50000000,-9900),(-500,199),(-2147483647,2147483647)]: scale = high - low buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=scale), 5) From 494e41661ba6f6682d3d54b61eec5ba1243205fb Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Mon, 19 Nov 2018 15:03:37 -0800 Subject: [PATCH 27/30] change ctx to include gpu, randint sample_op.cu typo --- src/operator/random/sample_op.cu | 2 +- tests/python/unittest/test_random.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/operator/random/sample_op.cu b/src/operator/random/sample_op.cu index 016ce6d487a7..39ab55afc081 100644 --- a/src/operator/random/sample_op.cu +++ b/src/operator/random/sample_op.cu @@ -39,7 +39,7 @@ MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_exponential, SampleExponentialParam) MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_poisson, SamplePoissonParam) MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_negative_binomial, SampleNegBinomialParam) MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_generalized_negative_binomial, SampleGenNegBinomialParam) -MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_int, SampleRandIntParam) +MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_randint, SampleRandIntParam) MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_uniform_like, SampleUniformLikeParam) MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_normal_like, SampleNormalLikeParam) MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_gamma_like, SampleGammaLikeParam) diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index 7f1281b544b3..e7c853219936 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -845,7 +845,7 @@ def test_randint(): 'high': 3, 'shape' : (500, 500), 'dtype' : dtype, - 'ctx' : mx.context.cpu() + 'ctx' : mx.context.current_context() } mx.random.seed(128) ret1 = mx.nd.random.randint(**params).asnumpy() @@ -856,12 +856,12 @@ def test_randint_extremes(): - a = mx.nd.random.randint(dtype='int64', low=50000000,
high=50000010, ctx=mx.context.current_context()) assert a>=50000000 and a<=50000010 @with_seed() def test_randint_generator(): - ctx = mx.context.cpu() + ctx = mx.context.current_context() for dtype in ['int32', 'int64']: for low, high in [(50000000, 50001000),(-50000000,-9900),(-500,199),(-2147483647,2147483647)]: scale = high - low From ccebcecd15f13de172437e3d4c42c978481cc1a2 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 20 Nov 2018 11:10:20 -0800 Subject: [PATCH 28/30] trigger ci From f583ac0b2ba193c813eca8912985488ffc652c65 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 21 Nov 2018 13:18:55 -0800 Subject: [PATCH 29/30] doc fix, check fix, whitespace remove --- python/mxnet/ndarray/random.py | 7 ------- src/operator/random/sample_op.h | 6 +++--- tests/python/unittest/test_random.py | 2 +- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py index 24d666b2974b..fc8be571e2e3 100644 --- a/python/mxnet/ndarray/random.py +++ b/python/mxnet/ndarray/random.py @@ -557,13 +557,6 @@ def randint(low, high, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs): >>> mx.nd.random.randint(-10, 10, shape=(2,)) [ -5 4] - >>> low = mx.nd.array([1,2,3]) - >>> high = mx.nd.array([2,3,4]) - >>> mx.nd.random.randint(low, high, shape=2) - [[ 1 1] - [ 2 2] - [ 3 3]] - """ return _random_helper(_internal._random_randint, None, [low, high], shape, dtype, ctx, out, kwargs) diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h index 5108ed0704b6..b12dfafbcfc8 100644 --- a/src/operator/random/sample_op.h +++ b/src/operator/random/sample_op.h @@ -751,7 +751,7 @@ inline bool SampleOpType(const nnvm::NodeAttrs& attrs, } bool dtype_ok = (dtype == kFloat16) || (dtype == kFloat32) || (dtype == kFloat64); - CHECK_EQ(dtype_ok, true) << "Output type must be float16, float32, float64: dtype is " + CHECK(dtype_ok) << "Output type must be float16, float32, float64: dtype is " << dtype_out << " vs " << kFloat16 << " or " << kFloat32 << " or " << kFloat64; TYPE_ASSIGN_CHECK(*out_type, 0, dtype); @@ -782,11 +782,11 @@ inline bool SampleOpType(const nnvm::NodeAttrs& attrs, dtype = param.dtype; } else { // Use default - dtype = kFloat32; + dtype = kInt32; } } bool dtype_ok = (dtype == kInt32) || (dtype == kInt64); - CHECK_EQ(dtype_ok, true) << "Output type must be int32, int64: dtype is " + CHECK(dtype_ok) << "Output type must be int32, int64: dtype is " << dtype_out << " vs " << kInt32 << " or " << kInt64; TYPE_ASSIGN_CHECK(*out_type, 0, dtype); return true; diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index e7c853219936..6c6b0b1ad6dc 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -865,7 +865,7 @@ def test_randint_generator(): for dtype in ['int32', 'int64']: for low, high in [(50000000, 50001000),(-50000000,-9900),(-500,199),(-2147483647,2147483647)]: scale = high - low - buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=scale), 5) + buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=scale), 5) # Quantize bucket boundaries to reflect the actual dtype and adjust probs accordingly buckets = np.array(buckets, dtype=dtype).tolist() probs = [(buckets[i][1] - buckets[i][0]) / float(scale) for i in range(5)] From e3be157723e6ea334315002ed1092cbfa6662161 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 27 Nov 2018 10:39:10 -0800 Subject: [PATCH 30/30] 
added the without dtype testcase --- tests/python/unittest/test_random.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index 6c6b0b1ad6dc..3436e9a9e80e 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -877,6 +877,10 @@ def test_randint_generator(): for _ in range(10)]) verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs) +@with_seed() +def test_randint_without_dtype(): + a = mx.nd.random.randint(low=50000000, high=50000010, ctx=mx.context.current_context()) + assert a.dtype == np.int32 if __name__ == '__main__': import nose
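
A note on the high=None handling introduced in PATCH 21: C++ has no counterpart of Python's None, so the normalization has to happen in the Python frontend before `_random_randint` is invoked. A minimal sketch of that normalization, for illustration only (the helper name `_normalize_bounds` is hypothetical and appears nowhere in the patches)::

    def _normalize_bounds(low, high=None):
        # Mirror numpy.random.randint: randint(x) is shorthand for randint(0, x).
        if high is None:
            low, high = 0, low
        # Slightly stricter than the backend's CHECK_GE(high, low): an empty
        # range would make the modulo in the sampling kernel divide by zero.
        if low >= high:
            raise ValueError("low must be strictly less than high")
        return low, high

With this in place, randint(5) draws from [0, 5) and randint(-10, 10) draws from [-10, 10).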
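
PATCH 25 builds the 64-bit draw used for int64 outputs out of two 32-bit engine calls, combined as (hi << 31) + lo. A self-contained sketch of the same composition, with random.getrandbits standing in for the mshadow engine (an assumption made purely for illustration)::

    import random

    def rand32():
        # Stand-in for engine_->operator()(): a non-negative 31-bit draw.
        return random.getrandbits(31)

    def rand_int64():
        # Mirrors the patch: shift the first draw up and add the second,
        # covering a 62-bit range of non-negative values.
        return (rand32() << 31) + rand32()

Keeping each draw non-negative before the shift is what makes the composition well defined; the same ordering concern applies to the static_cast in the C++ version.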
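
PATCH 26 replaces the multiplicative formula in SampleRandIntKernel with the usual modulo construction, out = lower + rand % (upper - lower). A quick NumPy sanity check of that formula over one of the ranges test_randint_generator exercises (illustrative only, not part of the test suite)::

    import numpy as np

    rng = np.random.RandomState(0)
    low, high = -500, 199
    # 100k non-negative 31-bit draws stand in for genImpl.rand().
    raw = rng.randint(0, 2**31, size=100000, dtype=np.int64)
    samples = low + raw % (high - low)
    assert samples.min() >= low and samples.max() < high

Plain modulo is slightly biased whenever (high - low) does not evenly divide the generator's range, but for the spans used in the tests the effect is far below what the five-bucket check in verify_generator can detect.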
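
Taken together, the series gives randint the following user-facing behavior (assuming the patches as shown, with int32 as the default dtype after the PATCH 29 fix)::

    import mxnet as mx

    mx.random.seed(42)
    a = mx.nd.random.randint(-10, 10, shape=(2, 2))                  # dtype defaults to int32
    b = mx.nd.random.randint(-10, 10, shape=(2, 2), dtype='int64')
    print(a.dtype, b.dtype)

Any non-integer dtype is rejected up front by the specialized type inference, which admits only int32 and int64.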