Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Numpy] add cross product op
Browse files Browse the repository at this point in the history
* implement - register op, infershape

* implement - cross product 2x2

* implement - cross product 2x3 3x2 3x3

* fix - cudaError: misaligned address

* add - test code

* fix - test axis=None

* add - Register _backward_npi_cross

* implement - get_cross_backward in test

* implement - cross backward 3x3 2x3 3x2

* implement - cross backward 2x2

* fix - sanity pylint

* fix - move the python wrapper to keep consistency

* impl - FFI for cross op

* impl - FFI benchmark
  • Loading branch information
Ding authored and Ubuntu committed Apr 16, 2020
1 parent 94f235d commit b72bfb7
Show file tree
Hide file tree
Showing 11 changed files with 2,291 additions and 140 deletions.
1 change: 1 addition & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def prepare_workloads():
OpArgMngr.add_workload("kron", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2'])
OpArgMngr.add_workload("add", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("cross", pool['2'], pool['2'])
OpArgMngr.add_workload("linalg.eig", pool['3x3'])
OpArgMngr.add_workload("linalg.eigh", pool['3x3'])
OpArgMngr.add_workload("linalg.det", pool['3x3'])
Expand Down
201 changes: 155 additions & 46 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'kron',
'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross', 'kron',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'interp',
'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
Expand Down Expand Up @@ -6190,6 +6190,46 @@ def ldexp(x1, x2, out=None, **kwargs):
return _api_internal.ldexp(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
def vdot(a, b):
r"""
Return the dot product of two vectors.
Note that `vdot` handles multidimensional arrays differently than `dot`:
it does *not* perform a matrix product, but flattens input arguments
to 1-D vectors first. Consequently, it should only be used for vectors.
Parameters
----------
a : ndarray
First argument to the dot product.
b : ndarray
Second argument to the dot product.
Returns
-------
output : ndarray
Dot product of `a` and `b`.
See Also
--------
dot : Return the dot product without using the complex conjugate of the
first argument.
Examples
--------
Note that higher-dimensional arrays are flattened!
>>> a = np.array([[1, 4], [5, 6]])
>>> b = np.array([[4, 1], [2, 2]])
>>> np.vdot(a, b)
30
>>> np.vdot(b, a)
30
>>> 1*4 + 4*1 + 5*2 + 6*2
30
"""
return tensordot(a.flatten(), b.flatten(), 1)


@set_module('mxnet.ndarray.numpy')
def inner(a, b):
r"""
Expand Down Expand Up @@ -6297,25 +6337,135 @@ def outer(a, b):
return tensordot(a.flatten(), b.flatten(), 0)


@set_module('mxnet.ndarray.numpy')
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): # pylint: disable=too-many-arguments
"""
Return the cross product of two (arrays of) vectors.
The cross product of `a` and `b` in :math:`R^3` is a vector perpendicular
to both `a` and `b`. If `a` and `b` are arrays of vectors, the vectors
are defined by the last axis of `a` and `b` by default, and these axis
can have dimensions 2 or 3. Where the dimension of either `a` or `b` is
2, the third component of the input vector is assumed to be zero and the
cross product calculated accordingly. In cases where both input vectors
have dimension 2, the z-component of the cross product is returned.
Parameters
----------
a : ndarray
Components of the first vector(s).
b : ndarray
Components of the second vector(s).
axisa : int, optional
Axis of `a` that defines the vector(s). By default, the last axis.
axisb : int, optional
Axis of `b` that defines the vector(s). By default, the last axis.
axisc : int, optional
Axis of `c` containing the cross product vector(s). Ignored if
both input vectors have dimension 2, as the return is scalar.
By default, the last axis.
axis : int, optional
If defined, the axis of `a`, `b` and `c` that defines the vector(s)
and cross product(s). Overrides `axisa`, `axisb` and `axisc`.
Returns
-------
c : ndarray
Vector cross product(s).
Raises
------
ValueError
When the dimension of the vector(s) in `a` and/or `b` does not
equal 2 or 3.
Notes
-----
Supports full broadcasting of the inputs.
Examples
--------
Vector cross-product.
>>> x = np.array([1., 2., 3.])
>>> y = np.array([4., 5., 6.])
>>> np.cross(x, y)
array([-3., 6., -3.])
One vector with dimension 2.
>>> x = np.array([1., 2.])
>>> y = np.array([4., 5., 6.])
>>> np.cross(x, y)
array([12., -6., -3.])
Equivalently:
>>> x = np.array([1., 2., 0.])
>>> y = np.array([4., 5., 6.])
>>> np.cross(x, y)
array([12., -6., -3.])
Both vectors with dimension 2.
>>> x = np.array([1., 2.])
>>> y = np.array([4., 5.])
>>> np.cross(x, y)
array(-3.)
Multiple vector cross-products. Note that the direction of the cross
product vector is defined by the `right-hand rule`.
>>> x = np.array([[1., 2., 3.], [4., 5., 6.]])
>>> y = np.array([[4., 5., 6.], [1., 2., 3.]])
>>> np.cross(x, y)
array([[-3., 6., -3.],
[ 3., -6., 3.]])
The orientation of `c` can be changed using the `axisc` keyword.
>>> np.cross(x, y, axisc=0)
array([[-3., 3.],
[ 6., -6.],
[-3., 3.]])
Change the vector definition of `x` and `y` using `axisa` and `axisb`.
>>> x = np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
>>> y = np.array([[7., 8., 9.], [4., 5., 6.], [1., 2., 3.]])
>>> np.cross(x, y)
array([[ -6., 12., -6.],
[ 0., 0., 0.],
[ 6., -12., 6.]])
>>> np.cross(x, y, axisa=0, axisb=0)
array([[-24., 48., -24.],
[-30., 60., -30.],
[-36., 72., -36.]])
"""
if axis is not None:
axisa, axisb, axisc = (axis,) * 3

if isinstance(a, NDArray) and isinstance(b, NDArray):
return _api_internal.cross(a, b, axisa, axisb, axisc)
else:
raise TypeError("Input data should be NDarray")


@set_module('mxnet.ndarray.numpy')
def kron(a, b):
r"""
Kronecker product of two arrays.
Computes the Kronecker product, a composite array made of blocks of the
second array scaled by the first.
Parameters
----------
a, b : ndarray
Returns
-------
out : ndarray
See Also
--------
outer : The outer product
Notes
-----
The function assumes that the number of dimensions of `a` and `b`
Expand All @@ -6331,7 +6481,6 @@ def kron(a, b):
[[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ],
[ ... ... ],
[ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]
Examples
--------
>>> np.kron([1,10,100], [5,6,7])
Expand All @@ -6342,46 +6491,6 @@ def kron(a, b):
return _api_internal.kron(a, b)


@set_module('mxnet.ndarray.numpy')
def vdot(a, b):
r"""
Return the dot product of two vectors.
Note that `vdot` handles multidimensional arrays differently than `dot`:
it does *not* perform a matrix product, but flattens input arguments
to 1-D vectors first. Consequently, it should only be used for vectors.
Parameters
----------
a : ndarray
First argument to the dot product.
b : ndarray
Second argument to the dot product.
Returns
-------
output : ndarray
Dot product of `a` and `b`.
See Also
--------
dot : Return the dot product without using the complex conjugate of the
first argument.
Examples
--------
Note that higher-dimensional arrays are flattened!
>>> a = np.array([[1, 4], [5, 6]])
>>> b = np.array([[4, 1], [2, 2]])
>>> np.vdot(a, b)
30
>>> np.vdot(b, a)
30
>>> 1*4 + 4*1 + 5*2 + 6*2
30
"""
return tensordot(a.flatten(), b.flatten(), 1)


@set_module('mxnet.ndarray.numpy')
def equal(x1, x2, out=None):
"""
Expand Down
Loading

0 comments on commit b72bfb7

Please sign in to comment.