Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Numpy] add cross product op
Browse files Browse the repository at this point in the history
* implement - register op, infershape

* implement - cross product 2x2

* implement - cross product 2x3 3x2 3x3

* fix - cudaError: misaligned address

* add - test code

* fix - test axis=None

* add - Register _backward_npi_cross

* implement - get_cross_backward in test

* implement - cross backward 3x3 2x3 3x2

* implement - cross backward 2x2

* fix - sanity pylint
  • Loading branch information
Ding authored and Ubuntu committed Feb 23, 2020
1 parent d2be9a6 commit 34b4449
Show file tree
Hide file tree
Showing 9 changed files with 2,112 additions and 6 deletions.
116 changes: 115 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'around', 'round',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
'diff', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
Expand Down Expand Up @@ -7490,3 +7490,117 @@ def pad(x, pad_width, mode='constant', **kwargs): # pylint: disable=too-many-arg
raise ValueError("unsupported stat_length '{}'".format(values))
return _npi.pad(x, pad_width, mode='minimum')
return _npi.pad(x, pad_width, mode='constant', constant_value=0)


@set_module('mxnet.ndarray.numpy')
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None, **kwargs): # pylint: disable=too-many-arguments
"""
Return the cross product of two (arrays of) vectors.
The cross product of `a` and `b` in :math:`R^3` is a vector perpendicular
to both `a` and `b`. If `a` and `b` are arrays of vectors, the vectors
are defined by the last axis of `a` and `b` by default, and these axis
can have dimensions 2 or 3. Where the dimension of either `a` or `b` is
2, the third component of the input vector is assumed to be zero and the
cross product calculated accordingly. In cases where both input vectors
have dimension 2, the z-component of the cross product is returned.
Parameters
----------
a : ndarray
Components of the first vector(s).
b : ndarray
Components of the second vector(s).
axisa : int, optional
Axis of `a` that defines the vector(s). By default, the last axis.
axisb : int, optional
Axis of `b` that defines the vector(s). By default, the last axis.
axisc : int, optional
Axis of `c` containing the cross product vector(s). Ignored if
both input vectors have dimension 2, as the return is scalar.
By default, the last axis.
axis : int, optional
If defined, the axis of `a`, `b` and `c` that defines the vector(s)
and cross product(s). Overrides `axisa`, `axisb` and `axisc`.
Returns
-------
c : ndarray
Vector cross product(s).
Raises
------
ValueError
When the dimension of the vector(s) in `a` and/or `b` does not
equal 2 or 3.
Notes
-----
Supports full broadcasting of the inputs.
Examples
--------
Vector cross-product.
>>> x = np.array([1., 2., 3.])
>>> y = np.array([4., 5., 6.])
>>> np.cross(x, y)
array([-3., 6., -3.])
One vector with dimension 2.
>>> x = np.array([1., 2.])
>>> y = np.array([4., 5., 6.])
>>> np.cross(x, y)
array([12., -6., -3.])
Equivalently:
>>> x = np.array([1., 2., 0.])
>>> y = np.array([4., 5., 6.])
>>> np.cross(x, y)
array([12., -6., -3.])
Both vectors with dimension 2.
>>> x = np.array([1., 2.])
>>> y = np.array([4., 5.])
>>> np.cross(x, y)
array(-3.)
Multiple vector cross-products. Note that the direction of the cross
product vector is defined by the `right-hand rule`.
>>> x = np.array([[1., 2., 3.], [4., 5., 6.]])
>>> y = np.array([[4., 5., 6.], [1., 2., 3.]])
>>> np.cross(x, y)
array([[-3., 6., -3.],
[ 3., -6., 3.]])
The orientation of `c` can be changed using the `axisc` keyword.
>>> np.cross(x, y, axisc=0)
array([[-3., 3.],
[ 6., -6.],
[-3., 3.]])
Change the vector definition of `x` and `y` using `axisa` and `axisb`.
>>> x = np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
>>> y = np.array([[7., 8., 9.], [4., 5., 6.], [1., 2., 3.]])
>>> np.cross(x, y)
array([[ -6., 12., -6.],
[ 0., 0., 0.],
[ 6., -12., 6.]])
>>> np.cross(x, y, axisa=0, axisb=0)
array([[-24., 48., -24.],
[-30., 60., -30.],
[-36., 72., -36.]])
"""
if axis is not None:
axisa, axisb, axisc = (axis,) * 3

if isinstance(a, NDArray) and isinstance(b, NDArray):
return _npi.cross(a, b, axisa, axisb, axisc)
else:
raise TypeError("Input data should be NDarray")
117 changes: 113 additions & 4 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,11 @@
'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman',
'flip', 'flipud', 'fliplr', 'around', 'round', 'arctan2', 'hypot',
'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal',
'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum', 'true_divide', 'nonzero',
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'matmul',
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount', 'pad']
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
'diff', 'resize', 'matmul', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'polyval', 'where', 'bincount', 'pad']

__all__ += fallback.__all__

Expand Down Expand Up @@ -9550,3 +9551,111 @@ def pad(x, pad_width=None, mode="constant", **kwargs): # pylint: disable=too-man
[10, 10, 10, 10, 10, 10, 10]])
"""
return _mx_nd_np.pad(x, pad_width, mode, **kwargs)


@set_module('mxnet.numpy')
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None, **kwargs): # pylint: disable=too-many-arguments
"""
Return the cross product of two (arrays of) vectors.
The cross product of `a` and `b` in :math:`R^3` is a vector perpendicular
to both `a` and `b`. If `a` and `b` are arrays of vectors, the vectors
are defined by the last axis of `a` and `b` by default, and these axes
can have dimensions 2 or 3. Where the dimension of either `a` or `b` is
2, the third component of the input vector is assumed to be zero and the
cross product calculated accordingly. In cases where both input vectors
have dimension 2, the z-component of the cross product is returned.
Parameters
----------
a : ndarray
Components of the first vector(s).
b : ndarray
Components of the second vector(s).
axisa : int, optional
Axis of `a` that defines the vector(s). By default, the last axis.
axisb : int, optional
Axis of `b` that defines the vector(s). By default, the last axis.
axisc : int, optional
Axis of `c` containing the cross product vector(s). Ignored if
both input vectors have dimension 2, as the return is scalar.
By default, the last axis.
axis : int, optional
If defined, the axis of `a`, `b` and `c` that defines the vector(s)
and cross product(s). Overrides `axisa`, `axisb` and `axisc`.
Returns
-------
c : ndarray
Vector cross product(s).
Raises
------
ValueError
When the dimension of the vector(s) in `a` and/or `b` does not
equal 2 or 3.
Notes
-----
Supports full broadcasting of the inputs.
Examples
--------
Vector cross-product.
>>> x = np.array([1., 2., 3.])
>>> y = np.array([4., 5., 6.])
>>> np.cross(x, y)
array([-3., 6., -3.])
One vector with dimension 2.
>>> x = np.array([1., 2.])
>>> y = np.array([4., 5., 6.])
>>> np.cross(x, y)
array([12., -6., -3.])
Equivalently:
>>> x = np.array([1., 2., 0.])
>>> y = np.array([4., 5., 6.])
>>> np.cross(x, y)
array([12., -6., -3.])
Both vectors with dimension 2.
>>> x = np.array([1., 2.])
>>> y = np.array([4., 5.])
>>> np.cross(x, y)
array(-3.)
Multiple vector cross-products. Note that the direction of the cross
product vector is defined by the `right-hand rule`.
>>> x = np.array([[1., 2., 3.], [4., 5., 6.]])
>>> y = np.array([[4., 5., 6.], [1., 2., 3.]])
>>> np.cross(x, y)
array([[-3., 6., -3.],
[ 3., -6., 3.]])
The orientation of `c` can be changed using the `axisc` keyword.
>>> np.cross(x, y, axisc=0)
array([[-3., 3.],
[ 6., -6.],
[-3., 3.]])
Change the vector definition of `x` and `y` using `axisa` and `axisb`.
>>> x = np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
>>> y = np.array([[7., 8., 9.], [4., 5., 6.], [1., 2., 3.]])
>>> np.cross(x, y)
array([[ -6., 12., -6.],
[ 0., 0., 0.],
[ 6., -12., 6.]])
>>> np.cross(x, y, axisa=0, axisb=0)
array([[-24., 48., -24.],
[-30., 60., -30.],
[-36., 72., -36.]])
"""
return _mx_nd_np.cross(a, b, axisa=axisa, axisb=axisb, axisc=axisc, axis=axis)
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'isneginf',
'isinf',
'pad',
'cross',
]


Expand Down
54 changes: 53 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'around', 'round',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff',
'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
Expand Down Expand Up @@ -6619,4 +6619,56 @@ def pad(x, pad_width, mode='constant', **kwargs): # pylint: disable=too-many-arg
return _npi.pad(x, pad_width, mode='constant', constant_value=0)


@set_module('mxnet.symbol.numpy')
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None, **kwargs): # pylint: disable=too-many-arguments
"""
Return the cross product of two (arrays of) vectors.
The cross product of `a` and `b` in :math:`R^3` is a vector perpendicular
to both `a` and `b`. If `a` and `b` are arrays of vectors, the vectors
are defined by the last axis of `a` and `b` by default, and these axes
can have dimensions 2 or 3. Where the dimension of either `a` or `b` is
2, the third component of the input vector is assumed to be zero and the
cross product calculated accordingly. In cases where both input vectors
have dimension 2, the z-component of the cross product is returned.
Parameters
----------
a : _Symbol
Components of the first vector(s).
b : _Symbol
Components of the second vector(s).
axisa : int, optional
Axis of `a` that defines the vector(s). By default, the last axis.
axisb : int, optional
Axis of `b` that defines the vector(s). By default, the last axis.
axisc : int, optional
Axis of `c` containing the cross product vector(s). Ignored if
both input vectors have dimension 2, as the return is scalar.
By default, the last axis.
axis : int, optional
If defined, the axis of `a`, `b` and `c` that defines the vector(s)
and cross product(s). Overrides `axisa`, `axisb` and `axisc`.
Returns
-------
c : _Symbol
Vector cross product(s).
Raises
------
ValueError
When the dimension of the vector(s) in `a` and/or `b` does not
equal 2 or 3.
Notes
-----
Supports full broadcasting of the inputs.
"""
if axis is not None:
axisa, axisb, axisc = (axis,) * 3

return _npi.cross(a, b, axisa, axisb, axisc)


_set_np_symbol_class(_Symbol)
Loading

0 comments on commit 34b4449

Please sign in to comment.