This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Numpy Tensordot and Dot Operator #15820

Merged: 5 commits, Aug 12, 2019
Changes from 3 commits
73 changes: 72 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -25,7 +25,78 @@
from ...context import current_context
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power']
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot']


@set_module('mxnet.ndarray.numpy')
def tensordot(a, b, axes=2):
r"""
tensordot(a, b, axes=2)
Compute tensor dot product along specified axes for arrays >= 1-D.
Given two tensors (arrays of dimension greater than or equal to one),
`a` and `b`, and an ndarray object containing two ndarray
objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
elements (components) over the axes specified by ``a_axes`` and
``b_axes``. The third argument can be a single non-negative
integer_like scalar, ``N``; if it is such, then the last ``N``
dimensions of `a` and the first ``N`` dimensions of `b` are summed
over.
Parameters
----------
a, b : ndarray, len(shape) >= 1
Tensors to "dot".
axes : int or (2,) ndarray
* integer_like
If an int N, sum over the last N axes of `a` and the first N axes
of `b` in order. The sizes of the corresponding axes must match.
* (2,) ndarray
Or, a list of axes to be summed over, first sequence applying to `a`,
second to `b`. Both elements ndarray must be of the same length.
See Also
--------
dot, einsum
Notes
-----
Three common use cases are:
* ``axes = 0`` : tensor product :math:`a\otimes b`
* ``axes = 1`` : tensor dot product :math:`a\cdot b`
* ``axes = 2`` : (default) tensor double contraction :math:`a:b`
When `axes` is integer_like, the sequence for evaluation will be: first
the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
Nth axis in `b` last.
When there is more than one axis to sum over - and they are not the last
(first) axes of `a` (`b`) - the argument `axes` should consist of
two sequences of the same length, with the first axis to sum over given
first in both sequences, the second axis second, and so forth.
Examples
--------
>>> a = np.arange(60.).reshape(3,4,5)
>>> b = np.arange(24.).reshape(4,3,2)
>>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
>>> c.shape
(5, 2)
>>> c
array([[ 4400., 4730.],
[ 4532., 4874.],
[ 4664., 5018.],
[ 4796., 5162.],
[ 4928., 5306.]])
"""
if _np.isscalar(axes):
return _npi.tensordot_int_axes(a, b, axes)

if len(axes) != 2:
raise ValueError('Axes must consist of two arrays.')
a_axes_summed, b_axes_summed = axes
if _np.isscalar(a_axes_summed):
a_axes_summed = (a_axes_summed,)
if _np.isscalar(b_axes_summed):
b_axes_summed = (b_axes_summed,)

if len(a_axes_summed) != len(b_axes_summed):
raise ValueError('Axes length mismatch')

return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


@set_module('mxnet.ndarray.numpy')
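A minimal usage sketch of the front-end function added above (a sketch only, assuming an MXNet build that ships the `mxnet.numpy` module; shapes follow the docstring example):

```python
# Sketch only: assumes an MXNet build exposing the mxnet.numpy module.
from mxnet import numpy as np

a = np.arange(60.0).reshape(3, 4, 5)
b = np.arange(24.0).reshape(4, 3, 2)

# Pairwise form: contract axis 1 of `a` with axis 0 of `b`,
# and axis 0 of `a` with axis 1 of `b`.
c = np.tensordot(a, b, axes=([1, 0], [0, 1]))
print(c.shape)  # (5, 2)

# Integer form: axes=1 contracts the last axis of the first argument
# with the first axis of the second (the contracted sizes must match).
x = np.ones((2, 3))
y = np.ones((3, 4))
print(np.tensordot(x, y, axes=1).shape)  # (2, 4)
```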
73 changes: 72 additions & 1 deletion python/mxnet/numpy/multiarray.py
@@ -44,7 +44,78 @@
from ..ndarray.numpy import _internal as _npi

__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide',
'mod', 'power']
'mod', 'power', 'tensordot']


@set_module('mxnet.numpy')
def tensordot(a, b, axes=2):
r"""
tensordot(a, b, axes=2)
Compute tensor dot product along specified axes for arrays >= 1-D.
Given two tensors (arrays of dimension greater than or equal to one),
`a` and `b`, and an ndarray object containing two ndarray
objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
elements (components) over the axes specified by ``a_axes`` and
``b_axes``. The third argument can be a single non-negative
integer_like scalar, ``N``; if it is such, then the last ``N``
dimensions of `a` and the first ``N`` dimensions of `b` are summed
over.
Parameters
----------
a, b : ndarray, len(shape) >= 1
Tensors to "dot".
axes : int or (2,) ndarray
* integer_like
If an int N, sum over the last N axes of `a` and the first N axes
of `b` in order. The sizes of the corresponding axes must match.
* (2,) ndarray
Or, a list of axes to be summed over, first sequence applying to `a`,
second to `b`. Both elements ndarray must be of the same length.
See Also
--------
dot, einsum
Notes
-----
Three common use cases are:
* ``axes = 0`` : tensor product :math:`a\otimes b`
* ``axes = 1`` : tensor dot product :math:`a\cdot b`
* ``axes = 2`` : (default) tensor double contraction :math:`a:b`
When `axes` is integer_like, the sequence for evaluation will be: first
the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
Nth axis in `b` last.
When there is more than one axis to sum over - and they are not the last
(first) axes of `a` (`b`) - the argument `axes` should consist of
two sequences of the same length, with the first axis to sum over given
first in both sequences, the second axis second, and so forth.
Examples
--------
>>> a = np.arange(60.).reshape(3,4,5)
>>> b = np.arange(24.).reshape(4,3,2)
>>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
>>> c.shape
(5, 2)
>>> c
array([[ 4400., 4730.],
[ 4532., 4874.],
[ 4664., 5018.],
[ 4796., 5162.],
[ 4928., 5306.]])
"""
if _np.isscalar(axes):
return _npi.tensordot_int_axes(a, b, axes)

if len(axes) != 2:
raise ValueError('Axes must consist of two arrays.')
a_axes_summed, b_axes_summed = axes
if _np.isscalar(a_axes_summed):
a_axes_summed = (a_axes_summed,)
if _np.isscalar(b_axes_summed):
b_axes_summed = (b_axes_summed,)

if len(a_axes_summed) != len(b_axes_summed):
raise ValueError('Axes length mismatch')

return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


# This function is copied from ndarray.py since pylint
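The `mxnet.numpy.tensordot` front end above mirrors the ndarray version. To make the `axes` semantics concrete, here is an illustration using official NumPy only (nothing MXNet-specific is assumed): the integer form `axes=2` performs the same contraction as the explicit pairwise form.

```python
# Illustration of the axes semantics described in the docstring, using
# official NumPy; the mxnet.numpy front end is meant to mirror this behaviour.
import numpy as onp

a = onp.random.rand(2, 3, 4)
b = onp.random.rand(3, 4, 5)

# axes=2 (the default) contracts the last two axes of `a` with the
# first two axes of `b`, in order: (3, 4) against (3, 4).
full = onp.tensordot(a, b, axes=2)

# The same contraction written in the explicit pairwise form.
pairwise = onp.tensordot(a, b, axes=([1, 2], [0, 1]))

assert full.shape == pairwise.shape == (2, 5)
assert onp.allclose(full, pairwise)
```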
57 changes: 56 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -28,7 +28,7 @@
from .._internal import _set_np_symbol_class
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power']
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot']


def _num_outputs(sym):
@@ -1010,4 +1010,59 @@ def power(x1, x2, out=None):
return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out)


@set_module('mxnet.symbol.numpy')
def tensordot(a, b, axes=2):
r"""
tensordot(a, b, axes=2)
Compute tensor dot product along specified axes for arrays >= 1-D.
Given two tensors (arrays of dimension greater than or equal to one),
`a` and `b`, and an ndarray object containing two ndarray
objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
elements (components) over the axes specified by ``a_axes`` and
``b_axes``. The third argument can be a single non-negative
integer_like scalar, ``N``; if it is such, then the last ``N``
dimensions of `a` and the first ``N`` dimensions of `b` are summed
over.
Parameters
----------
a, b : _Symbol
Tensors to "dot".
axes : int or (2,) ndarray
* integer_like
If an int N, sum over the last N axes of `a` and the first N axes
of `b` in order. The sizes of the corresponding axes must match.
* (2,) array_like
Or, a list of axes to be summed over, first sequence applying to `a`,
second to `b`. Both elements array_like must be of the same length.
Notes
-----
Three common use cases are:
* ``axes = 0`` : tensor product :math:`a\otimes b`
* ``axes = 1`` : tensor dot product :math:`a\cdot b`
* ``axes = 2`` : (default) tensor double contraction :math:`a:b`
When `axes` is integer_like, the sequence for evaluation will be: first
the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
Nth axis in `b` last.
When there is more than one axis to sum over - and they are not the last
(first) axes of `a` (`b`) - the argument `axes` should consist of
two sequences of the same length, with the first axis to sum over given
first in both sequences, the second axis second, and so forth.
"""
if _np.isscalar(axes):
return _npi.tensordot_int_axes(a, b, axes)

if len(axes) != 2:
raise ValueError('Axes must consist of two arrays.')
a_axes_summed, b_axes_summed = axes
if _np.isscalar(a_axes_summed):
a_axes_summed = (a_axes_summed,)
if _np.isscalar(b_axes_summed):
b_axes_summed = (b_axes_summed,)

if len(a_axes_summed) != len(b_axes_summed):
raise ValueError('Axes length mismatch')

return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


_set_np_symbol_class(_Symbol)
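The symbolic front end documents the three common `axes` cases in its Notes but carries no Examples section. The sketch below illustrates those three cases with official NumPy; the symbolic operator is expected to follow the same contraction rules.

```python
# The three "common use cases" from the Notes, illustrated with official NumPy.
import numpy as onp

a = onp.arange(6.0).reshape(2, 3)
b = onp.arange(12.0).reshape(3, 4)

outer = onp.tensordot(a, b, axes=0)   # tensor product, shape (2, 3, 3, 4)
dot = onp.tensordot(a, b, axes=1)     # ordinary matrix product, shape (2, 4)

c = onp.arange(12.0).reshape(3, 4)
double = onp.tensordot(b, c, axes=2)  # double contraction: sum(b * c), a scalar

assert outer.shape == (2, 3, 3, 4)
assert onp.allclose(dot, a @ b)
assert onp.isclose(double, (b * c).sum())
```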
110 changes: 110 additions & 0 deletions src/operator/numpy/np_dot-inl.h
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_dot-inl.h
* \brief Function definition of matrix numpy-compatible dot operator
*/

#ifndef MXNET_OPERATOR_NUMPY_NP_DOT_INL_H_
#define MXNET_OPERATOR_NUMPY_NP_DOT_INL_H_

#include <mxnet/operator_util.h>
#include <vector>
#include "../tensor/dot-inl.h"
#include "../tensor/elemwise_binary_op.h"
#include "../tensor/broadcast_reduce_op.h"
#include "np_tensordot_op-inl.h"

namespace mxnet {
namespace op {

template<typename xpu>
inline void NumpyDotForward(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
                            const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;

  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);

  const TBlob& a = inputs[0];
  const TBlob& b = inputs[1];
  const TBlob& out = outputs[0];
  const mxnet::TShape a_shape = a.shape_;
  const mxnet::TShape b_shape = b.shape_;

  MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
    if (b_shape.ndim() < 3) {
      // Case 1, 2, 3, 4, 5: a is N-D array (N >= 1) and b is vector or matrix, sum product
      // over the last axis of a and the first axis of b
      TensordotIntAxesImpl<xpu>(1, ctx, a, b, out, req[0]);
    } else {
      // Case 3, 5.5: a is N-D array and b is M-D array (M > 2), sum product over the last axis
      // of a and the 2nd-to-last axis of b
      const Tuple<int> a_axes_summed({a_shape.ndim() - 1});
      const Tuple<int> b_axes_summed({b_shape.ndim() - 2});
      TensordotImpl<xpu>(a_axes_summed, b_axes_summed, ctx, a, b, out, req);
    }
  });
}

template<typename xpu>
inline void NumpyDotBackward(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx,
                             const std::vector<TBlob>& inputs,
                             const std::vector<OpReqType>& req,
                             const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow_op;

  CHECK_EQ(inputs.size(), 3U);
  CHECK_EQ(outputs.size(), 2U);

  const TBlob& ograd = inputs[0];
  const TBlob& a = inputs[1];
  const TBlob& b = inputs[2];
  const TBlob& grad_a = outputs[0];
  const TBlob& grad_b = outputs[1];
  const mxnet::TShape a_shape = a.shape_;
  const mxnet::TShape b_shape = b.shape_;

  MSHADOW_REAL_TYPE_SWITCH(ograd.type_flag_, DType, {
    if (b_shape.ndim() < 3) {
      // Case 1, 2, 3, 4, 5: a is N-D array (N >= 1) and b is vector or matrix, sum product
      // over the last axis of a and the first axis of b
      TensordotIntAxesBackwardImpl<xpu>(1, ctx, ograd, a, b, grad_a, grad_b, req);
    } else {
      // Case 3, 5.5: a is N-D array and b is M-D array (M > 2), sum product over the last axis
      // of a and the 2nd-to-last axis of b
      const Tuple<int> a_axes_summed({a_shape.ndim() - 1});
      const Tuple<int> b_axes_summed({b_shape.ndim() - 2});
      TensordotBackwardImpl<xpu>(a_axes_summed, b_axes_summed, ctx, ograd, a, b, grad_a,
                                 grad_b, req);
    }
  });
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_NP_DOT_INL_H_
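As a reading aid for the dispatch in `NumpyDotForward`/`NumpyDotBackward` above, the same rule phrased with official NumPy (a sketch, not part of the PR): `dot` contracts the last axis of `a` with the first axis of `b` when `b` is a vector or matrix, and with the second-to-last axis of `b` when `b` has more than two dimensions.

```python
# Official NumPy expression of the dispatch rule used by np_dot-inl.h.
import numpy as onp

a = onp.random.rand(2, 3)

# b.ndim < 3: contract the last axis of a with the first axis of b.
b = onp.random.rand(3, 4)
assert onp.allclose(onp.dot(a, b),
                    onp.tensordot(a, b, axes=([-1], [0])))

# b.ndim >= 3: contract the last axis of a with the second-to-last axis of b.
bm = onp.random.rand(5, 3, 6)
assert onp.allclose(onp.dot(a, bm),
                    onp.tensordot(a, bm, axes=([-1], [-2])))
```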