Numpy Tensordot and Dot Operator (#15820)
* Implements tensordot and dot.

* Change tests.

* Add spaces.

* Reorganize codes.

* Remove np_matrix_op.h
ckt624 authored and haojin2 committed Aug 12, 2019
1 parent 57927a9 commit c3f5eea
Showing 10 changed files with 1,631 additions and 3 deletions.
73 changes: 72 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -25,7 +25,7 @@
from ...context import current_context
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power']
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot']


@set_module('mxnet.ndarray.numpy')
@@ -293,3 +293,74 @@ def power(x1, x2, out=None):
This is a scalar if both x1 and x2 are scalars.
"""
return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out)


@set_module('mxnet.ndarray.numpy')
def tensordot(a, b, axes=2):
r"""
tensordot(a, b, axes=2)
Compute tensor dot product along specified axes for arrays >= 1-D.
Given two tensors (arrays of dimension greater than or equal to one),
`a` and `b`, and an array_like object containing two array_like
objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
elements (components) over the axes specified by ``a_axes`` and
``b_axes``. The third argument can be a single non-negative
integer_like scalar, ``N``; if it is such, then the last ``N``
dimensions of `a` and the first ``N`` dimensions of `b` are summed
over.
Parameters
----------
a, b : ndarray, len(shape) >= 1
Tensors to "dot".
axes : int or (2,) array_like
* integer_like
If an int N, sum over the last N axes of `a` and the first N axes
of `b` in order. The sizes of the corresponding axes must match.
* (2,) array_like
Or, a list of axes to be summed over, first sequence applying to `a`,
second to `b`. Both sequences must be of the same length.
See Also
--------
dot, einsum
Notes
-----
Three common use cases are:
* ``axes = 0`` : tensor product :math:`a\otimes b`
* ``axes = 1`` : tensor dot product :math:`a\cdot b`
* ``axes = 2`` : (default) tensor double contraction :math:`a:b`
When `axes` is integer_like, the sequence for evaluation will be: first
the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
Nth axis in `b` last.
When there is more than one axis to sum over - and they are not the last
(first) axes of `a` (`b`) - the argument `axes` should consist of
two sequences of the same length, with the first axis to sum over given
first in both sequences, the second axis second, and so forth.
Examples
--------
>>> a = np.arange(60.).reshape(3,4,5)
>>> b = np.arange(24.).reshape(4,3,2)
>>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
>>> c.shape
(5, 2)
>>> c
array([[ 4400., 4730.],
[ 4532., 4874.],
[ 4664., 5018.],
[ 4796., 5162.],
[ 4928., 5306.]])
"""
if _np.isscalar(axes):
return _npi.tensordot_int_axes(a, b, axes)

if len(axes) != 2:
raise ValueError('Axes must consist of two arrays.')
a_axes_summed, b_axes_summed = axes
if _np.isscalar(a_axes_summed):
a_axes_summed = (a_axes_summed,)
if _np.isscalar(b_axes_summed):
b_axes_summed = (b_axes_summed,)

if len(a_axes_summed) != len(b_axes_summed):
raise ValueError('Axes length mismatch')

return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)
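
As a quick cross-check of the dispatch above (an integer `axes` versus a pair of axis sequences), the following sketch uses the reference NumPy package, whose `tensordot` semantics this front end mirrors; the shapes are illustrative only:

import numpy as np

a = np.arange(24.).reshape(2, 3, 4)
b = np.arange(12.).reshape(3, 4)
# integer form: contract the last 2 axes of a against the first 2 axes of b
c_int = np.tensordot(a, b, axes=2)
# explicit form: the same contraction spelled out as (a_axes, b_axes)
c_seq = np.tensordot(a, b, axes=([1, 2], [0, 1]))
assert np.allclose(c_int, c_seq)   # both yield shape (2,)
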
59 changes: 58 additions & 1 deletion python/mxnet/numpy/multiarray.py
@@ -44,7 +44,7 @@
from ..ndarray.numpy import _internal as _npi

__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide',
'mod', 'power']
'mod', 'power', 'tensordot']


# This function is copied from ndarray.py since pylint
@@ -1549,3 +1549,60 @@ def power(x1, x2, out=None):
This is a scalar if both x1 and x2 are scalars.
"""
return _mx_nd_np.power(x1, x2, out=out)


@set_module('mxnet.numpy')
def tensordot(a, b, axes=2):
r"""
tensordot(a, b, axes=2)
Compute tensor dot product along specified axes for arrays >= 1-D.
Given two tensors (arrays of dimension greater than or equal to one),
`a` and `b`, and an array_like object containing two array_like
objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
elements (components) over the axes specified by ``a_axes`` and
``b_axes``. The third argument can be a single non-negative
integer_like scalar, ``N``; if it is such, then the last ``N``
dimensions of `a` and the first ``N`` dimensions of `b` are summed
over.
Parameters
----------
a, b : ndarray, len(shape) >= 1
Tensors to "dot".
axes : int or (2,) array_like
* integer_like
If an int N, sum over the last N axes of `a` and the first N axes
of `b` in order. The sizes of the corresponding axes must match.
* (2,) array_like
Or, a list of axes to be summed over, first sequence applying to `a`,
second to `b`. Both sequences must be of the same length.
See Also
--------
dot, einsum
Notes
-----
Three common use cases are:
* ``axes = 0`` : tensor product :math:`a\otimes b`
* ``axes = 1`` : tensor dot product :math:`a\cdot b`
* ``axes = 2`` : (default) tensor double contraction :math:`a:b`
When `axes` is integer_like, the sequence for evaluation will be: first
the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
Nth axis in `b` last.
When there is more than one axis to sum over - and they are not the last
(first) axes of `a` (`b`) - the argument `axes` should consist of
two sequences of the same length, with the first axis to sum over given
first in both sequences, the second axis second, and so forth.
Examples
--------
>>> a = np.arange(60.).reshape(3,4,5)
>>> b = np.arange(24.).reshape(4,3,2)
>>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
>>> c.shape
(5, 2)
>>> c
array([[ 4400., 4730.],
[ 4532., 4874.],
[ 4664., 5018.],
[ 4796., 5162.],
[ 4928., 5306.]])
"""
return _mx_nd_np.tensordot(a, b, axes)
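
The three common cases listed in the Notes can be checked against the reference NumPy package, whose behaviour this function follows (a minimal sketch; the array contents are arbitrary):

import numpy as np

x = np.arange(6.).reshape(2, 3)
y = np.arange(12.).reshape(3, 4)
outer = np.tensordot(x, y, axes=0)    # tensor product, shape (2, 3, 3, 4)
single = np.tensordot(x, y, axes=1)   # ordinary dot product, shape (2, 4)
assert np.allclose(single, x.dot(y))
z = np.arange(24.).reshape(2, 3, 4)
w = np.arange(12.).reshape(3, 4)
double = np.tensordot(z, w, axes=2)   # double contraction over both shared axes, shape (2,)
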
57 changes: 56 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -28,7 +28,7 @@
from .._internal import _set_np_symbol_class
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power']
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot']


def _num_outputs(sym):
@@ -1010,4 +1010,59 @@ def power(x1, x2, out=None):
return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out)


@set_module('mxnet.symbol.numpy')
def tensordot(a, b, axes=2):
r"""
tensordot(a, b, axes=2)
Compute tensor dot product along specified axes for arrays >= 1-D.
Given two tensors (arrays of dimension greater than or equal to one),
`a` and `b`, and an array_like object containing two array_like
objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
elements (components) over the axes specified by ``a_axes`` and
``b_axes``. The third argument can be a single non-negative
integer_like scalar, ``N``; if it is such, then the last ``N``
dimensions of `a` and the first ``N`` dimensions of `b` are summed
over.
Parameters
----------
a, b : _Symbol
Tensors to "dot".
axes : int or (2,) array_like
* integer_like
If an int N, sum over the last N axes of `a` and the first N axes
of `b` in order. The sizes of the corresponding axes must match.
* (2,) array_like
Or, a list of axes to be summed over, first sequence applying to `a`,
second to `b`. Both sequences must be of the same length.
Notes
-----
Three common use cases are:
* ``axes = 0`` : tensor product :math:`a\otimes b`
* ``axes = 1`` : tensor dot product :math:`a\cdot b`
* ``axes = 2`` : (default) tensor double contraction :math:`a:b`
When `axes` is integer_like, the sequence for evaluation will be: first
the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
Nth axis in `b` last.
When there is more than one axis to sum over - and they are not the last
(first) axes of `a` (`b`) - the argument `axes` should consist of
two sequences of the same length, with the first axis to sum over given
first in both sequences, the second axis second, and so forth.
"""
if _np.isscalar(axes):
return _npi.tensordot_int_axes(a, b, axes)

if len(axes) != 2:
raise ValueError('Axes must consist of two arrays.')
a_axes_summed, b_axes_summed = axes
if _np.isscalar(a_axes_summed):
a_axes_summed = (a_axes_summed,)
if _np.isscalar(b_axes_summed):
b_axes_summed = (b_axes_summed,)

if len(a_axes_summed) != len(b_axes_summed):
raise ValueError('Axes length mismatch')

return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


_set_np_symbol_class(_Symbol)
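
The symbol wrapper normalizes `axes` exactly as the ndarray version does: a scalar selects the integer-axes kernel, scalar sequence elements are promoted to 1-tuples, and mismatched lengths raise. A small pure-Python sketch of the accepted forms (values are illustrative only):

import numpy as _np

for axes in (2, (1, 0), ([1, 0], [0, 1])):
    if _np.isscalar(axes):
        print(axes, '-> integer form, handled by the int-axes kernel')
        continue
    a_axes_summed, b_axes_summed = axes
    if _np.isscalar(a_axes_summed):
        a_axes_summed = (a_axes_summed,)
    if _np.isscalar(b_axes_summed):
        b_axes_summed = (b_axes_summed,)
    print(axes, '->', a_axes_summed, 'for a,', b_axes_summed, 'for b')
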
110 changes: 110 additions & 0 deletions src/operator/numpy/np_dot-inl.h
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_dot-inl.h
* \brief Function definition of matrix numpy-compatible dot operator
*/

#ifndef MXNET_OPERATOR_NUMPY_NP_DOT_INL_H_
#define MXNET_OPERATOR_NUMPY_NP_DOT_INL_H_

#include <mxnet/operator_util.h>
#include <vector>
#include "../tensor/dot-inl.h"
#include "../tensor/elemwise_binary_op.h"
#include "../tensor/broadcast_reduce_op.h"
#include "np_tensordot_op-inl.h"

namespace mxnet {
namespace op {

template<typename xpu>
inline void NumpyDotForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;

CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);

const TBlob& a = inputs[0];
const TBlob& b = inputs[1];
const TBlob& out = outputs[0];
const mxnet::TShape a_shape = a.shape_;
const mxnet::TShape b_shape = b.shape_;

MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
if (b_shape.ndim() < 3) {
// Case 1, 2, 3, 4, 5: a is N-D array (N >= 1) and b is vector or matrix, sum product
// over the last axis of a and the first axis of b
TensordotIntAxesImpl<xpu>(1, ctx, a, b, out, req[0]);
} else {
// Case 3, 5.5: a is N-D array and b is M-D array (M > 2), sum product over the last axis
// of a and the 2nd-to-last axis of b
const Tuple<int> a_axes_summed({a_shape.ndim() - 1});
const Tuple<int> b_axes_summed({b_shape.ndim() - 2});
TensordotImpl<xpu>(a_axes_summed, b_axes_summed, ctx, a, b, out, req);
}
});
}
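
The branch above mirrors `numpy.dot` semantics; a reference-NumPy sketch of the two cases (shapes are illustrative only):

import numpy as np

a = np.arange(24.).reshape(2, 3, 4)
m = np.arange(20.).reshape(4, 5)
# b.ndim < 3: sum product over the last axis of a and the first axis of b
assert np.allclose(np.dot(a, m), np.tensordot(a, m, axes=1))
b = np.arange(40.).reshape(5, 4, 2)
# b.ndim >= 3: sum product over the last axis of a and the second-to-last axis of b
assert np.allclose(np.dot(a, b), np.tensordot(a, b, axes=([-1], [-2])))
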

template<typename xpu>
inline void NumpyDotBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow_op;

CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U);

const TBlob& ograd = inputs[0];
const TBlob& a = inputs[1];
const TBlob& b = inputs[2];
const TBlob& grad_a = outputs[0];
const TBlob& grad_b = outputs[1];
const mxnet::TShape a_shape = a.shape_;
const mxnet::TShape b_shape = b.shape_;

MSHADOW_REAL_TYPE_SWITCH(ograd.type_flag_, DType, {
if (b_shape.ndim() < 3) {
// Case 1, 2, 3, 4, 5: a is N-D array (N >= 1) and b is vector or matrix, sum product
// over the last axis of a and the first axis of b
TensordotIntAxesBackwardImpl<xpu>(1, ctx, ograd, a, b, grad_a, grad_b, req);
} else {
// Case 3, 5.5: a is N-D array and b is M-D array (M > 2), sum product over the last axis
// of a and the 2nd-to-last axis of b
const Tuple<int> a_axes_summed({a_shape.ndim() - 1});
const Tuple<int> b_axes_summed({b_shape.ndim() - 2});
TensordotBackwardImpl<xpu>(a_axes_summed, b_axes_summed, ctx, ograd, a, b, grad_a,
grad_b, req);
}
});
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_NP_DOT_INL_H_
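
For the plain matrix case handled by the backward kernel above, the gradients reduce to the usual matrix-calculus identities: with C = A @ B and upstream gradient G, dA = G @ B.T and dB = A.T @ G. A reference-NumPy sketch with a finite-difference spot check (all names are illustrative):

import numpy as np

A = np.random.rand(3, 4)
B = np.random.rand(4, 5)
G = np.ones((3, 5))    # upstream gradient dL/dC for L = (A @ B).sum()
dA = G @ B.T           # analytic gradient of L with respect to A
dB = A.T @ G           # analytic gradient of L with respect to B
eps = 1e-6
A_pert = A.copy()
A_pert[0, 0] += eps
numeric = ((A_pert @ B).sum() - (A @ B).sum()) / eps
assert abs(numeric - dA[0, 0]) < 1e-4   # numeric and analytic gradients agree
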