Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into add_op_append
Browse files Browse the repository at this point in the history
  • Loading branch information
JiangZhaoh committed Oct 31, 2019
2 parents f7d42b6 + f9baec9 commit 127c498
Show file tree
Hide file tree
Showing 23 changed files with 1,110 additions and 56 deletions.
2 changes: 1 addition & 1 deletion julia/deps/build.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ if Sys.isunix()
nvcc_path = Sys.which("nvcc")
if nvcc_path nothing
@info "Found nvcc: $nvcc_path"
push!(CUDAPATHS, replace(nvcc_path, "bin/nvcc", "lib64"))
push!(CUDAPATHS, replace(nvcc_path, "bin/nvcc" => "lib64"))
end
end

Expand Down
2 changes: 1 addition & 1 deletion julia/src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ end

function get(metric::MSE)
# Delay copy until last possible moment
mse_sum = mapreduce(nda->copy(nda)[1], +, 0.0, metric.mse_sum)
mse_sum = mapreduce(nda->copy(nda)[1], +, metric.mse_sum; init = zero(MX_float))
[(:MSE, mse_sum / metric.n_sample)]
end

Expand Down
6 changes: 3 additions & 3 deletions python/mxnet/contrib/amp/lists/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,9 @@
'arctanh',
'_sparse_arcsin',
'_sparse_arctanh',
'_contrib_MultiBoxDetection',
'_contrib_MultiBoxPrior',
'_contrib_MultiBoxTarget',

# Exponents
'exp',
Expand Down Expand Up @@ -575,9 +578,6 @@
'stack',
'_Maximum',
'_Minimum',
'_contrib_MultiBoxDetection',
'_contrib_MultiBoxPrior',
'_contrib_MultiBoxTarget',
'_contrib_MultiProposal',
'_contrib_PSROIPooling',
'_contrib_Proposal',
Expand Down
85 changes: 82 additions & 3 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate',
'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin',
'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal',
'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'append']
'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff',
'append']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -3058,6 +3059,36 @@ def get_list(arrays):
return _npi.vstack(*arrays)


@set_module('mxnet.ndarray.numpy')
def column_stack(tup):
"""
Stack 1-D arrays as columns into a 2-D array.
Take a sequence of 1-D arrays and stack them as columns
to make a single 2-D array. 2-D arrays are stacked as-is,
just like with `hstack`. 1-D arrays are turned into 2-D columns
first.
Returns
--------
stacked : 2-D array
The array formed by stacking the given arrays.
See Also
--------
stack, hstack, vstack, concatenate
Examples
--------
>>> a = np.array((1,2,3))
>>> b = np.array((2,3,4))
>>> np.column_stack((a,b))
array([[1., 2.],
[2., 3.],
[3., 4.]])
"""
return _npi.column_stack(*tup)


@set_module('mxnet.ndarray.numpy')
def dstack(arrays):
"""
Expand Down Expand Up @@ -5037,3 +5068,51 @@ def may_share_memory(a, b, max_work=None):
- Actually it is same as `shares_memory` in MXNet DeepNumPy
"""
return _npi.share_memory(a, b).item()


def diff(a, n=1, axis=-1, prepend=None, append=None):
r"""
numpy.diff(a, n=1, axis=-1, prepend=<no value>, append=<no value>)
Calculate the n-th discrete difference along the given axis.
Parameters
----------
a : ndarray
Input array
n : int, optional
The number of times values are differenced. If zero, the input is returned as-is.
axis : int, optional
The axis along which the difference is taken, default is the last axis.
prepend, append : ndarray, optional
Not supported yet
Returns
-------
diff : ndarray
The n-th differences.
The shape of the output is the same as a except along axis where the dimension is smaller by n.
The type of the output is the same as the type of the difference between any two elements of a.
Examples
--------
>>> x = np.array([1, 2, 4, 7, 0])
>>> np.diff(x)
array([ 1, 2, 3, -7])
>>> np.diff(x, n=2)
array([ 1, 1, -10])
>>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]])
>>> np.diff(x)
array([[2, 3, 4],
[5, 1, 2]])
>>> np.diff(x, axis=0)
array([[-1, 2, 0, -2]])
Notes
-----
Optional inputs `prepend` and `append` are not supported yet
"""
if (prepend or append):
raise NotImplementedError('prepend and append options are not supported yet')
return _npi.diff(a, n=n, axis=axis)
110 changes: 100 additions & 10 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming',
'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril',
'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
'may_share_memory', 'append']
'may_share_memory', 'diff', 'append']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -577,7 +577,7 @@ def __setitem__(self, key, value):
if not isinstance(key, tuple) or len(key) != 0:
raise IndexError('scalar tensor can only accept `()` as index')
if isinstance(value, numeric_types):
self.full(value)
self._full(value)
elif isinstance(value, ndarray) and value.size == 1:
if value.shape != self.shape:
value = value.reshape(self.shape)
Expand Down Expand Up @@ -1993,15 +1993,20 @@ def array(object, dtype=None, ctx=None):
"""
if ctx is None:
ctx = current_context()
if isinstance(object, ndarray):
if isinstance(object, (ndarray, _np.ndarray)):
dtype = object.dtype if dtype is None else dtype
elif isinstance(object, NDArray):
raise ValueError("If you're trying to create a mxnet.numpy.ndarray "
"from mx.nd.NDArray, please use the zero-copy as_np_ndarray function.")
else:
dtype = _np.float32 if dtype is None else dtype
if not isinstance(object, (ndarray, _np.ndarray)):
try:
object = _np.array(object, dtype=dtype)
except Exception as e:
raise TypeError('{}'.format(str(e)))
if dtype is None:
dtype = object.dtype if hasattr(object, "dtype") else _np.float32
try:
object = _np.array(object, dtype=dtype)
except Exception as e:
# printing out the error raised by official NumPy's array function
# for transparency on users' side
raise TypeError('{}'.format(str(e)))
ret = empty(object.shape, dtype=dtype, ctx=ctx)
if len(object.shape) == 0:
ret[()] = object
Expand Down Expand Up @@ -4940,6 +4945,42 @@ def vstack(arrays, out=None):
return _mx_nd_np.vstack(arrays)


@set_module('mxnet.numpy')
def column_stack(tup):
"""
Stack 1-D arrays as columns into a 2-D array.
Take a sequence of 1-D arrays and stack them as columns
to make a single 2-D array. 2-D arrays are stacked as-is,
just like with `hstack`. 1-D arrays are turned into 2-D columns
first.
Parameters
----------
tup : sequence of 1-D or 2-D arrays.
Arrays to stack. All of them must have the same first dimension.
Returns
--------
stacked : 2-D array
The array formed by stacking the given arrays.
See Also
--------
stack, hstack, vstack, concatenate
Examples
--------
>>> a = np.array((1,2,3))
>>> b = np.array((2,3,4))
>>> np.column_stack((a,b))
array([[1., 2.],
[2., 3.],
[3., 4.]])
"""
return _mx_nd_np.column_stack(tup)


@set_module('mxnet.numpy')
def dstack(arrays):
"""
Expand Down Expand Up @@ -7016,3 +7057,52 @@ def may_share_memory(a, b, max_work=None):
- Actually it is same as `shares_memory` in MXNet DeepNumPy
"""
return _mx_nd_np.may_share_memory(a, b, max_work)


def diff(a, n=1, axis=-1, prepend=None, append=None):
r"""
numpy.diff(a, n=1, axis=-1, prepend=<no value>, append=<no value>)
Calculate the n-th discrete difference along the given axis.
Parameters
----------
a : ndarray
Input array
n : int, optional
The number of times values are differenced. If zero, the input is returned as-is.
axis : int, optional
The axis along which the difference is taken, default is the last axis.
prepend, append : ndarray, optional
Not supported yet
Returns
-------
diff : ndarray
The n-th differences.
The shape of the output is the same as a except along axis where the dimension is smaller by n.
The type of the output is the same as the type of the difference between any two elements of a.
This is the same as the type of a in most cases.
Examples
--------
>>> x = np.array([1, 2, 4, 7, 0])
>>> np.diff(x)
array([ 1, 2, 3, -7])
>>> np.diff(x, n=2)
array([ 1, 1, -10])
>>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]])
>>> np.diff(x)
array([[2, 3, 4],
[5, 1, 2]])
>>> np.diff(x, axis=0)
array([[-1, 2, 0, -2]])
Notes
-----
Optional inputs `prepend` and `append` are not supported yet
"""
if (prepend or append):
raise NotImplementedError('prepend and append options are not supported yet')
return _mx_nd_np.diff(a, n=n, axis=axis)
2 changes: 2 additions & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'var',
'vdot',
'vstack',
'column_stack',
'zeros_like',
'linalg.norm',
'trace',
Expand All @@ -131,6 +132,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'einsum',
'shares_memory',
'may_share_memory',
'diff',
]


Expand Down
Loading

0 comments on commit 127c498

Please sign in to comment.