Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add new change
Browse files Browse the repository at this point in the history
  • Loading branch information
JiangZhaoh committed Jan 13, 2020
1 parent 8fb4f1f commit 43be279
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 8 deletions.
2 changes: 0 additions & 2 deletions python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,6 @@ def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None):
"""
from ...numpy import ndarray as np_ndarray
input_type = (isinstance(shape, np_ndarray), isinstance(scale, np_ndarray))
if dtype is None:
dtype = 'float32'
if ctx is None:
ctx = current_context()
if out is not None:
Expand Down
2 changes: 0 additions & 2 deletions python/mxnet/symbol/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,6 @@ def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None):
"""
from ._symbol import _Symbol as np_symbol
input_type = (isinstance(shape, np_symbol), isinstance(scale, np_symbol))
if dtype is None:
dtype = 'float32'
if ctx is None:
ctx = current_context()
if out is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/operator/numpy/random/np_gamma_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ inline bool NumpyGammaOpType(const nnvm::NodeAttrs& attrs,
if (otype != -1) {
(*out_attrs)[0] = otype;
} else {
(*out_attrs)[0] = mshadow::kFloat32;
(*out_attrs)[0] = GetDefaultDtype(param.dtype);
}
return true;
}
Expand Down
8 changes: 5 additions & 3 deletions src/operator/numpy/random/np_gamma_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ struct NumpyGammaParam : public dmlc::Parameter<NumpyGammaParam> {
.describe("Context of output, in format [xpu|xpu|xpu_pinned](n)."
" Only used for imperative calls.");
DMLC_DECLARE_FIELD(dtype)
.add_enum("None", -1)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("float16", mshadow::kFloat16)
.set_default(mshadow::kFloat32)
.describe("DType of the output in case this can't be inferred. "
"Defaults to float32 if not defined (dtype=None).");
.set_default(-1)
.describe("DType of the output in case this can't be inferred."
"Defaults to float64 or float32 if not defined (dtype=None),"
"which depends on your current default dtype.");
}
};

Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_numpy_default_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def get_workloads(name):
'hanning',
'hamming',
'blackman',
'gamma',
'random.uniform',
'random.normal',
'true_divide'
Expand Down Expand Up @@ -119,6 +120,10 @@ def _add_dtype_workload_hamming():
DtypeOpArgMngr.add_workload('hamming', 3)


def _add_dtype_workload_gamma():
DtypeOpArgMngr.add_workload('gamma', 3)


def _add_dtype_workload_blackman():
DtypeOpArgMngr.add_workload('blackman', 3)

Expand Down Expand Up @@ -151,6 +156,7 @@ def _prepare_workloads():
_add_dtype_workload_hanning()
_add_dtype_workload_hamming()
_add_dtype_workload_blackman()
_add_dtype_workload_gamma()
_add_dtype_workload_random_uniform()
_add_dtype_workload_random_normal()
_add_dtype_workload_true_divide()
Expand Down

0 comments on commit 43be279

Please sign in to comment.