Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Numpy Tensordot Operator (#15349)
Browse files Browse the repository at this point in the history
* implements numpy tensordot

* Fixed bugs and optimized backward operator

* Rewrited tests

* Debuging

* Debuging 0-size input

* Moved axis-reordering from frontend to backend

* Added comments

* integrated forward part

* Add more tests

* Fixed GPU bugs

* Add comments to tensordot

* Add empty lines.

* Change tests

* Remove redundant code

* Change file names

* Add numerical backward test

* Change np.dot for case 5.

* Remove spaces.

* Remove more spaces.

* Add head files.

* Remove spaces in python interface

* Refactored.

* Changed intereface.

* changed GPU test.

* Clean codes.

* Change styles.

* Remove blank lines.

* Add blank lines

* Recover lines.

* Support python 2

* Test Python 2

* Add more tests

* Add error msg

* Change comments.
  • Loading branch information
ckt624 authored and reminisce committed Jul 17, 2019
1 parent a26b201 commit d8d6b3b
Show file tree
Hide file tree
Showing 8 changed files with 1,227 additions and 119 deletions.
82 changes: 81 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,87 @@
'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'eye',
'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan', 'sign', 'log',
'degrees', 'log2', 'rint', 'radians', 'mean', 'reciprocal', 'square', 'arcsin',
'argsort', 'hstack']
'argsort', 'hstack', 'tensordot']


@set_module('mxnet.ndarray.numpy')
def tensordot(a, b, axes=2):
r"""
tensordot(a, b, axes=2)
Compute tensor dot product along specified axes for arrays >= 1-D.
Given two tensors (arrays of dimension greater than or equal to one),
`a` and `b`, and an ndarray object containing two ndarray
objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
elements (components) over the axes specified by ``a_axes`` and
``b_axes``. The third argument can be a single non-negative
integer_like scalar, ``N``; if it is such, then the last ``N``
dimensions of `a` and the first ``N`` dimensions of `b` are summed
over.
Parameters
----------
a, b : ndarray, len(shape) >= 1
Tensors to "dot".
axes : int or (2,) ndarray
* integer_like
If an int N, sum over the last N axes of `a` and the first N axes
of `b` in order. The sizes of the corresponding axes must match.
* (2,) ndarray
Or, a list of axes to be summed over, first sequence applying to `a`,
second to `b`. Both elements ndarray must be of the same length.
See Also
--------
dot, einsum
Notes
-----
Three common use cases are:
* ``axes = 0`` : tensor product :math:`a\otimes b`
* ``axes = 1`` : tensor dot product :math:`a\cdot b`
* ``axes = 2`` : (default) tensor double contraction :math:`a:b`
When `axes` is integer_like, the sequence for evaluation will be: first
the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
Nth axis in `b` last.
When there is more than one axis to sum over - and they are not the last
(first) axes of `a` (`b`) - the argument `axes` should consist of
two sequences of the same length, with the first axis to sum over given
first in both sequences, the second axis second, and so forth.
Examples
--------
>>> a = np.arange(60.).reshape(3,4,5)
>>> b = np.arange(24.).reshape(4,3,2)
>>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
>>> c.shape
(5, 2)
>>> c
array([[ 4400., 4730.],
[ 4532., 4874.],
[ 4664., 5018.],
[ 4796., 5162.],
[ 4928., 5306.]])
"""
if _np.isscalar(axes):
return _npi.tensordot_int_axes(a, b, axes)

if len(axes) != 2:
raise ValueError('Axes must consist of two arrays.')
a_axes_summed, b_axes_summed = axes
if _np.isscalar(a_axes_summed):
a_axes_summed = (a_axes_summed,)
if _np.isscalar(b_axes_summed):
b_axes_summed = (b_axes_summed,)

if len(a_axes_summed) != len(b_axes_summed):
raise ValueError('Axes length mismatch')

return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


@set_module('mxnet.ndarray.numpy')
Expand Down
82 changes: 81 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,87 @@
'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'eye', 'sin', 'cos',
'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan', 'sign', 'log',
'degrees', 'log2', 'rint', 'radians', 'mean', 'reciprocal', 'square', 'arcsin',
'argsort', 'hstack']
'argsort', 'hstack', 'tensordot']


@set_module('mxnet.numpy')
def tensordot(a, b, axes=2):
r"""
tensordot(a, b, axes=2)
Compute tensor dot product along specified axes for arrays >= 1-D.
Given two tensors (arrays of dimension greater than or equal to one),
`a` and `b`, and an ndarray object containing two ndarray
objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
elements (components) over the axes specified by ``a_axes`` and
``b_axes``. The third argument can be a single non-negative
integer_like scalar, ``N``; if it is such, then the last ``N``
dimensions of `a` and the first ``N`` dimensions of `b` are summed
over.
Parameters
----------
a, b : ndarray, len(shape) >= 1
Tensors to "dot".
axes : int or (2,) ndarray
* integer_like
If an int N, sum over the last N axes of `a` and the first N axes
of `b` in order. The sizes of the corresponding axes must match.
* (2,) ndarray
Or, a list of axes to be summed over, first sequence applying to `a`,
second to `b`. Both elements ndarray must be of the same length.
See Also
--------
dot, einsum
Notes
-----
Three common use cases are:
* ``axes = 0`` : tensor product :math:`a\otimes b`
* ``axes = 1`` : tensor dot product :math:`a\cdot b`
* ``axes = 2`` : (default) tensor double contraction :math:`a:b`
When `axes` is integer_like, the sequence for evaluation will be: first
the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
Nth axis in `b` last.
When there is more than one axis to sum over - and they are not the last
(first) axes of `a` (`b`) - the argument `axes` should consist of
two sequences of the same length, with the first axis to sum over given
first in both sequences, the second axis second, and so forth.
Examples
--------
>>> a = np.arange(60.).reshape(3,4,5)
>>> b = np.arange(24.).reshape(4,3,2)
>>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
>>> c.shape
(5, 2)
>>> c
array([[ 4400., 4730.],
[ 4532., 4874.],
[ 4664., 5018.],
[ 4796., 5162.],
[ 4928., 5306.]])
"""
if _np.isscalar(axes):
return _npi.tensordot_int_axes(a, b, axes)

if len(axes) != 2:
raise ValueError('Axes must consist of two arrays.')
a_axes_summed, b_axes_summed = axes
if _np.isscalar(a_axes_summed):
a_axes_summed = (a_axes_summed,)
if _np.isscalar(b_axes_summed):
b_axes_summed = (b_axes_summed,)

if len(a_axes_summed) != len(b_axes_summed):
raise ValueError('Axes length mismatch')

return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


# This function is copied from ndarray.py since pylint
Expand Down
66 changes: 65 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes',
'expand_dims', 'tile', 'linspace', 'eye', 'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt',
'abs', 'exp', 'arctan', 'sign', 'log', 'degrees', 'log2', 'rint', 'radians', 'mean',
'reciprocal', 'square', 'arcsin', 'argsort', 'hstack']
'reciprocal', 'square', 'arcsin', 'argsort', 'hstack', 'tensordot']


def _num_outputs(sym):
Expand Down Expand Up @@ -2294,4 +2294,68 @@ def arcsin(x, out=None, **kwargs):
return _unary_func_helper(x, _npi.arcsin, _np.arcsin, out=out, **kwargs)


@set_module('mxnet.symbol.numpy')
def tensordot(a, b, axes=2):
r"""
tensordot(a, b, axes=2)
Compute tensor dot product along specified axes for arrays >= 1-D.
Given two tensors (arrays of dimension greater than or equal to one),
`a` and `b`, and an ndarray object containing two ndarray
objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
elements (components) over the axes specified by ``a_axes`` and
``b_axes``. The third argument can be a single non-negative
integer_like scalar, ``N``; if it is such, then the last ``N``
dimensions of `a` and the first ``N`` dimensions of `b` are summed
over.
Parameters
----------
a, b : _Symbol
Tensors to "dot".
axes : int or (2,) ndarray
* integer_like
If an int N, sum over the last N axes of `a` and the first N axes
of `b` in order. The sizes of the corresponding axes must match.
* (2,) array_like
Or, a list of axes to be summed over, first sequence applying to `a`,
second to `b`. Both elements array_like must be of the same length.
Notes
-----
Three common use cases are:
* ``axes = 0`` : tensor product :math:`a\otimes b`
* ``axes = 1`` : tensor dot product :math:`a\cdot b`
* ``axes = 2`` : (default) tensor double contraction :math:`a:b`
When `axes` is integer_like, the sequence for evaluation will be: first
the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
Nth axis in `b` last.
When there is more than one axis to sum over - and they are not the last
(first) axes of `a` (`b`) - the argument `axes` should consist of
two sequences of the same length, with the first axis to sum over given
first in both sequences, the second axis second, and so forth.
"""
if _np.isscalar(axes):
return _npi.tensordot_int_axes(a, b, axes)

if len(axes) != 2:
raise ValueError('Axes must consist of two arrays.')
a_axes_summed, b_axes_summed = axes
if _np.isscalar(a_axes_summed):
a_axes_summed = (a_axes_summed,)
if _np.isscalar(b_axes_summed):
b_axes_summed = (b_axes_summed,)

if len(a_axes_summed) != len(b_axes_summed):
raise ValueError('Axes length mismatch')

return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


_set_np_symbol_class(_Symbol)
Loading

0 comments on commit d8d6b3b

Please sign in to comment.