diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 693cca05d9ce..414c28df4364 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -348,8 +348,6 @@ def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None): """ from ...numpy import ndarray as np_ndarray input_type = (isinstance(shape, np_ndarray), isinstance(scale, np_ndarray)) - if dtype is None: - dtype = 'float32' if ctx is None: ctx = current_context() if out is not None: diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 73866ae90bf0..84b8bf4fbc14 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -320,8 +320,6 @@ def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None): """ from ._symbol import _Symbol as np_symbol input_type = (isinstance(shape, np_symbol), isinstance(scale, np_symbol)) - if dtype is None: - dtype = 'float32' if ctx is None: ctx = current_context() if out is not None: diff --git a/src/operator/numpy/random/np_gamma_op.cc b/src/operator/numpy/random/np_gamma_op.cc index 72e337b1642b..210843fc2087 100644 --- a/src/operator/numpy/random/np_gamma_op.cc +++ b/src/operator/numpy/random/np_gamma_op.cc @@ -38,7 +38,7 @@ inline bool NumpyGammaOpType(const nnvm::NodeAttrs& attrs, if (otype != -1) { (*out_attrs)[0] = otype; } else { - (*out_attrs)[0] = mshadow::kFloat32; + (*out_attrs)[0] = GetDefaultDtype(param.dtype); } return true; } diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h index 83e8f1f5242b..b4a2c98db6aa 100644 --- a/src/operator/numpy/random/np_gamma_op.h +++ b/src/operator/numpy/random/np_gamma_op.h @@ -62,12 +62,14 @@ struct NumpyGammaParam : public dmlc::Parameter { .describe("Context of output, in format [xpu|xpu|xpu_pinned](n)." " Only used for imperative calls."); DMLC_DECLARE_FIELD(dtype) + .add_enum("None", -1) .add_enum("float32", mshadow::kFloat32) .add_enum("float64", mshadow::kFloat64) .add_enum("float16", mshadow::kFloat16) - .set_default(mshadow::kFloat32) - .describe("DType of the output in case this can't be inferred. " - "Defaults to float32 if not defined (dtype=None)."); + .set_default(-1) + .describe("DType of the output in case this can't be inferred." + "Defaults to float64 or float32 if not defined (dtype=None)," + "which depends on your current default dtype."); } }; diff --git a/tests/python/unittest/test_numpy_default_dtype.py b/tests/python/unittest/test_numpy_default_dtype.py index 68a07e932bfd..760127800bbd 100644 --- a/tests/python/unittest/test_numpy_default_dtype.py +++ b/tests/python/unittest/test_numpy_default_dtype.py @@ -56,6 +56,7 @@ def get_workloads(name): 'hanning', 'hamming', 'blackman', + 'gamma', 'random.uniform', 'random.normal', 'true_divide' @@ -119,6 +120,10 @@ def _add_dtype_workload_hamming(): DtypeOpArgMngr.add_workload('hamming', 3) +def _add_dtype_workload_gamma(): + DtypeOpArgMngr.add_workload('gamma', 3) + + def _add_dtype_workload_blackman(): DtypeOpArgMngr.add_workload('blackman', 3) @@ -151,6 +156,7 @@ def _prepare_workloads(): _add_dtype_workload_hanning() _add_dtype_workload_hamming() _add_dtype_workload_blackman() + _add_dtype_workload_gamma() _add_dtype_workload_random_uniform() _add_dtype_workload_random_normal() _add_dtype_workload_true_divide()