diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index d7c06e76c182..caa9ba1de67c 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -25,7 +25,7 @@ from ...context import current_context from . import _internal as _npi -__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power'] +__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot'] @set_module('mxnet.ndarray.numpy') @@ -293,3 +293,74 @@ def power(x1, x2, out=None): This is a scalar if both x1 and x2 are scalars. """ return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out) + + +@set_module('mxnet.ndarray.numpy') +def tensordot(a, b, axes=2): + r""" + tensordot(a, b, axes=2) + Compute tensor dot product along specified axes for arrays >= 1-D. + Given two tensors (arrays of dimension greater than or equal to one), + `a` and `b`, and an ndarray object containing two ndarray + objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s + elements (components) over the axes specified by ``a_axes`` and + ``b_axes``. The third argument can be a single non-negative + integer_like scalar, ``N``; if it is such, then the last ``N`` + dimensions of `a` and the first ``N`` dimensions of `b` are summed + over. + Parameters + ---------- + a, b : ndarray, len(shape) >= 1 + Tensors to "dot". + axes : int or (2,) ndarray + * integer_like + If an int N, sum over the last N axes of `a` and the first N axes + of `b` in order. The sizes of the corresponding axes must match. + * (2,) ndarray + Or, a list of axes to be summed over, first sequence applying to `a`, + second to `b`. Both elements ndarray must be of the same length. + See Also + -------- + dot, einsum + Notes + ----- + Three common use cases are: + * ``axes = 0`` : tensor product :math:`a\otimes b` + * ``axes = 1`` : tensor dot product :math:`a\cdot b` + * ``axes = 2`` : (default) tensor double contraction :math:`a:b` + When `axes` is integer_like, the sequence for evaluation will be: first + the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and + Nth axis in `b` last. + When there is more than one axis to sum over - and they are not the last + (first) axes of `a` (`b`) - the argument `axes` should consist of + two sequences of the same length, with the first axis to sum over given + first in both sequences, the second axis second, and so forth. + Examples + -------- + >>> a = np.arange(60.).reshape(3,4,5) + >>> b = np.arange(24.).reshape(4,3,2) + >>> c = np.tensordot(a,b, axes=([1,0],[0,1])) + >>> c.shape + (5, 2) + >>> c + array([[ 4400., 4730.], + [ 4532., 4874.], + [ 4664., 5018.], + [ 4796., 5162.], + [ 4928., 5306.]]) + """ + if _np.isscalar(axes): + return _npi.tensordot_int_axes(a, b, axes) + + if len(axes) != 2: + raise ValueError('Axes must consist of two arrays.') + a_axes_summed, b_axes_summed = axes + if _np.isscalar(a_axes_summed): + a_axes_summed = (a_axes_summed,) + if _np.isscalar(b_axes_summed): + b_axes_summed = (b_axes_summed,) + + if len(a_axes_summed) != len(b_axes_summed): + raise ValueError('Axes length mismatch') + + return _npi.tensordot(a, b, a_axes_summed, b_axes_summed) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 9e0c52dbfd68..f4b6e733c495 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -44,7 +44,7 @@ from ..ndarray.numpy import _internal as _npi __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', - 'mod', 'power'] + 'mod', 'power', 'tensordot'] # This function is copied from ndarray.py since pylint @@ -1549,3 +1549,60 @@ def power(x1, x2, out=None): This is a scalar if both x1 and x2 are scalars. """ return _mx_nd_np.power(x1, x2, out=out) + + +@set_module('mxnet.numpy') +def tensordot(a, b, axes=2): + r""" + tensordot(a, b, axes=2) + Compute tensor dot product along specified axes for arrays >= 1-D. + Given two tensors (arrays of dimension greater than or equal to one), + `a` and `b`, and an ndarray object containing two ndarray + objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s + elements (components) over the axes specified by ``a_axes`` and + ``b_axes``. The third argument can be a single non-negative + integer_like scalar, ``N``; if it is such, then the last ``N`` + dimensions of `a` and the first ``N`` dimensions of `b` are summed + over. + Parameters + ---------- + a, b : ndarray, len(shape) >= 1 + Tensors to "dot". + axes : int or (2,) ndarray + * integer_like + If an int N, sum over the last N axes of `a` and the first N axes + of `b` in order. The sizes of the corresponding axes must match. + * (2,) ndarray + Or, a list of axes to be summed over, first sequence applying to `a`, + second to `b`. Both elements ndarray must be of the same length. + See Also + -------- + dot, einsum + Notes + ----- + Three common use cases are: + * ``axes = 0`` : tensor product :math:`a\otimes b` + * ``axes = 1`` : tensor dot product :math:`a\cdot b` + * ``axes = 2`` : (default) tensor double contraction :math:`a:b` + When `axes` is integer_like, the sequence for evaluation will be: first + the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and + Nth axis in `b` last. + When there is more than one axis to sum over - and they are not the last + (first) axes of `a` (`b`) - the argument `axes` should consist of + two sequences of the same length, with the first axis to sum over given + first in both sequences, the second axis second, and so forth. + Examples + -------- + >>> a = np.arange(60.).reshape(3,4,5) + >>> b = np.arange(24.).reshape(4,3,2) + >>> c = np.tensordot(a,b, axes=([1,0],[0,1])) + >>> c.shape + (5, 2) + >>> c + array([[ 4400., 4730.], + [ 4532., 4874.], + [ 4664., 5018.], + [ 4796., 5162.], + [ 4928., 5306.]]) + """ + return _mx_nd_np.tensordot(a, b, axes) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 616f3066d98d..65db429e8ba7 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -28,7 +28,7 @@ from .._internal import _set_np_symbol_class from . import _internal as _npi -__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power'] +__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot'] def _num_outputs(sym): @@ -1010,4 +1010,59 @@ def power(x1, x2, out=None): return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out) +@set_module('mxnet.symbol.numpy') +def tensordot(a, b, axes=2): + r""" + tensordot(a, b, axes=2) + Compute tensor dot product along specified axes for arrays >= 1-D. + Given two tensors (arrays of dimension greater than or equal to one), + `a` and `b`, and an ndarray object containing two ndarray + objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s + elements (components) over the axes specified by ``a_axes`` and + ``b_axes``. The third argument can be a single non-negative + integer_like scalar, ``N``; if it is such, then the last ``N`` + dimensions of `a` and the first ``N`` dimensions of `b` are summed + over. + Parameters + ---------- + a, b : _Symbol + Tensors to "dot". + axes : int or (2,) ndarray + * integer_like + If an int N, sum over the last N axes of `a` and the first N axes + of `b` in order. The sizes of the corresponding axes must match. + * (2,) array_like + Or, a list of axes to be summed over, first sequence applying to `a`, + second to `b`. Both elements array_like must be of the same length. + Notes + ----- + Three common use cases are: + * ``axes = 0`` : tensor product :math:`a\otimes b` + * ``axes = 1`` : tensor dot product :math:`a\cdot b` + * ``axes = 2`` : (default) tensor double contraction :math:`a:b` + When `axes` is integer_like, the sequence for evaluation will be: first + the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and + Nth axis in `b` last. + When there is more than one axis to sum over - and they are not the last + (first) axes of `a` (`b`) - the argument `axes` should consist of + two sequences of the same length, with the first axis to sum over given + first in both sequences, the second axis second, and so forth. + """ + if _np.isscalar(axes): + return _npi.tensordot_int_axes(a, b, axes) + + if len(axes) != 2: + raise ValueError('Axes must consist of two arrays.') + a_axes_summed, b_axes_summed = axes + if _np.isscalar(a_axes_summed): + a_axes_summed = (a_axes_summed,) + if _np.isscalar(b_axes_summed): + b_axes_summed = (b_axes_summed,) + + if len(a_axes_summed) != len(b_axes_summed): + raise ValueError('Axes length mismatch') + + return _npi.tensordot(a, b, a_axes_summed, b_axes_summed) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_dot-inl.h b/src/operator/numpy/np_dot-inl.h new file mode 100644 index 000000000000..a854777c3109 --- /dev/null +++ b/src/operator/numpy/np_dot-inl.h @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_dot-inl.h + * \brief Function definition of matrix numpy-compatible dot operator + */ + +#ifndef MXNET_OPERATOR_NUMPY_NP_DOT_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_DOT_INL_H_ + +#include +#include +#include "../tensor/dot-inl.h" +#include "../tensor/elemwise_binary_op.h" +#include "../tensor/broadcast_reduce_op.h" +#include "np_tensordot_op-inl.h" + +namespace mxnet { +namespace op { + +template +inline void NumpyDotForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + + const TBlob& a = inputs[0]; + const TBlob& b = inputs[1]; + const TBlob& out = outputs[0]; + const mxnet::TShape a_shape = a.shape_; + const mxnet::TShape b_shape = b.shape_; + + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, { + if (b_shape.ndim() < 3) { + // Case 1, 2, 3, 4, 5: a is N-D array (N >= 1) and b is vector or matrix, sum product + // over the last axis of a and the first axis of b + TensordotIntAxesImpl(1, ctx, a, b, out, req[0]); + } else { + // Case 3, 5.5: a is N-D array and b is M-D array (M > 2), sum product over the last axis + // of a and the 2nd-to-last axis of b + const Tuple a_axes_summed({a_shape.ndim() - 1}); + const Tuple b_axes_summed({b_shape.ndim() - 2}); + TensordotImpl(a_axes_summed, b_axes_summed, ctx, a, b, out, req); + } + }); +} + +template +inline void NumpyDotBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow_op; + + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + + const TBlob& ograd = inputs[0]; + const TBlob& a = inputs[1]; + const TBlob& b = inputs[2]; + const TBlob& grad_a = outputs[0]; + const TBlob& grad_b = outputs[1]; + const mxnet::TShape a_shape = a.shape_; + const mxnet::TShape b_shape = b.shape_; + + MSHADOW_REAL_TYPE_SWITCH(ograd.type_flag_, DType, { + if (b_shape.ndim() < 3) { + // Case 1, 2, 3, 4, 5: a is N-D array (N >= 1) and b is vector or matrix, sum product + // over the last axis of a and the first axis of b + TensordotIntAxesBackwardImpl(1, ctx, ograd, a, b, grad_a, grad_b, req); + } else { + // Case 3, 5.5: a is N-D array and b is M-D array (M > 2), sum product over the last axis + // of a and the 2nd-to-last axis of b + const Tuple a_axes_summed({a_shape.ndim() - 1}); + const Tuple b_axes_summed({b_shape.ndim() - 2}); + TensordotBackwardImpl(a_axes_summed, b_axes_summed, ctx, ograd, a, b, grad_a, + grad_b, req); + } + }); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_DOT_INL_H_ diff --git a/src/operator/numpy/np_dot.cc b/src/operator/numpy/np_dot.cc new file mode 100644 index 000000000000..627e68877998 --- /dev/null +++ b/src/operator/numpy/np_dot.cc @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_dot.cc + * \brief CPU Implementation of numpy-compatible dot + */ + +#include "./np_dot-inl.h" + +namespace mxnet { +namespace op { + +inline bool NumpyDotShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + + const mxnet::TShape& a_shape = in_attrs->at(0); + const mxnet::TShape& b_shape = in_attrs->at(1); + + if (!ndim_is_known(a_shape) || !ndim_is_known(b_shape)) { + return false; + } + + if (a_shape.ndim() == 1 && b_shape.ndim() == 1) { + // Case 1: both 1-D arrays, inner product of vectors + SHAPE_ASSIGN_CHECK(*in_attrs, 0, in_attrs->at(1)); + SHAPE_ASSIGN_CHECK(*in_attrs, 1, in_attrs->at(0)); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(0, 0)); + } else if (a_shape.ndim() == 2 && b_shape.ndim() == 2) { + // Case 2: both 2-D arrays, matrix multiplication + mxnet::TShape tmp_shape(2, -1); + tmp_shape[1] = b_shape[0]; + SHAPE_ASSIGN_CHECK(*in_attrs, 0, tmp_shape); + + tmp_shape[0] = a_shape[1]; + tmp_shape[1] = -1; + SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_shape); + + tmp_shape[0] = a_shape[0]; + tmp_shape[1] = b_shape[1]; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, tmp_shape); + } else if (a_shape.ndim() == 0 || b_shape.ndim() == 0) { + // Case 3 + 3.5: either of them is a scalar, just scale by one of them + mxnet::TShape oshape = (a_shape.ndim() == 0) ? b_shape : a_shape; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); + } else if (b_shape.ndim() == 1) { + // Case 4: a is N-D array and b is 1-D array, sum product over the last axis + TShape tmp_shape(a_shape.ndim(), -1); + tmp_shape[a_shape.ndim() - 1] = b_shape[0]; + SHAPE_ASSIGN_CHECK(*in_attrs, 0, tmp_shape); + + tmp_shape = TShape(1, -1); + tmp_shape[0] = a_shape[a_shape.ndim() - 1]; + SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_shape); + + mxnet::TShape out_shape(a_shape.ndim() - 1, -1); + for (int i = 0; i < a_shape.ndim() - 1; ++i) { + out_shape[i] = a_shape[i]; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape); + } else { + // Case 5: a is N-D array and b is M-D array, sum product over the last axis + // of a and the 2nd-to-last axis of b + TShape tmp_shape(a_shape.ndim(), -1); + tmp_shape[a_shape.ndim() - 1] = b_shape[b_shape.ndim() - 2]; + SHAPE_ASSIGN_CHECK(*in_attrs, 0, tmp_shape); + + tmp_shape = TShape(b_shape.ndim(), -1); + tmp_shape[b_shape.ndim() - 2] = a_shape[a_shape.ndim() - 1]; + SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_shape); + + tmp_shape = TShape(a_shape.ndim() + b_shape.ndim() - 2, -1); + for (int i = 0; i < a_shape.ndim() - 1; ++i) { + tmp_shape[i] = a_shape[i]; + } + for (int i = 0; i < b_shape.ndim() - 2; ++i) { + tmp_shape[i + a_shape.ndim() - 1] = b_shape[i]; + } + tmp_shape[tmp_shape.ndim() - 1] = b_shape[b_shape.ndim() - 1]; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, tmp_shape); + } + return shape_is_known(*in_attrs) && shape_is_known(*out_attrs); +} + +NNVM_REGISTER_OP(_np_dot) +.describe(R"doc(Dot product of two arrays. Specifically, + +- If both a and b are 1-D arrays, it is inner product of vectors. + +- If both a and b are 2-D arrays, it is matrix multiplication. + +- If either a or b is 0-D (scalar), it is equivalent to multiply and using numpy.multiply(a, b) or a * b is preferred. + +- If a is an N-D array and b is a 1-D array, it is a sum product over the last axis of a and b. + +- If a is an N-D array and b is an M-D array (where M>=2), it is a sum product over the last axis of a and the second-to-last axis of b: + + Example :: + + dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m]) + +)doc" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a", "b"}; + }) +.set_attr("FInferShape", NumpyDotShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyDotForward) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_np_dot"}) +.add_argument("a", "NDArray-or-Symbol", "First input") +.add_argument("b", "NDArray-or-Symbol", "Second input"); + +NNVM_REGISTER_OP(_backward_np_dot) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyDotBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_dot.cu b/src/operator/numpy/np_dot.cu new file mode 100644 index 000000000000..9a9c69aa98e5 --- /dev/null +++ b/src/operator/numpy/np_dot.cu @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_dot.cu + * \brief GPU Implementation of numpy-compatible dot + */ + +#include "./np_dot-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_np_dot) +.set_attr("FCompute", NumpyDotForward); + +NNVM_REGISTER_OP(_backward_np_dot) +.set_attr("FCompute", NumpyDotBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_tensordot_op-inl.h b/src/operator/numpy/np_tensordot_op-inl.h new file mode 100644 index 000000000000..da3891665c4b --- /dev/null +++ b/src/operator/numpy/np_tensordot_op-inl.h @@ -0,0 +1,688 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_tensordot_op-inl.h + * \brief CPU Implementation of numpy-compatible tensordot + */ +#ifndef MXNET_OPERATOR_NUMPY_NP_TENSORDOT_OP_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_TENSORDOT_OP_INL_H_ + +#include +#include "../tensor/matrix_op-inl.h" + +namespace mxnet { +namespace op { + +using namespace mshadow; + +struct TensordotParam : public dmlc::Parameter { + mxnet::Tuple a_axes_summed, b_axes_summed; + DMLC_DECLARE_PARAMETER(TensordotParam) { + DMLC_DECLARE_FIELD(a_axes_summed); + DMLC_DECLARE_FIELD(b_axes_summed); + } +}; + +/** + * deals with negative axes. + */ +inline void ShiftAxes(Tuple* axes_summed, const int ndim) { + for (auto& i : *axes_summed) { + i = (i + ndim) % ndim; + } +} + +/** + * Gets matrix dimensions of a and b after transpose and reshape. + */ +inline void GetMatrixDimensions(int* ad1, + int* ad2, + int* bd1, + int* bd2, + const mxnet::Tuple& a_axes_remained, + const mxnet::Tuple& a_axes_summed, + const mxnet::Tuple& b_axes_remained, + const mxnet::Tuple& b_axes_summed, + const mxnet::TShape& a_shape, + const mxnet::TShape& b_shape) { + *ad1 = 1; + *ad2 = 1; + *bd1 = 1; + *bd2 = 1; + + for (int i = 0; i < a_axes_remained.ndim(); i++) { + *ad1 *= a_shape[a_axes_remained[i]]; + } + for (int i = 0; i < a_axes_summed.ndim(); i++) { + *ad2 *= a_shape[a_axes_summed[i]]; + } + for (int i = 0; i < b_axes_summed.ndim(); i++) { + *bd1 *= b_shape[b_axes_summed[i]]; + } + for (int i = 0; i < b_axes_remained.ndim(); i++) { + *bd2 *= b_shape[b_axes_remained[i]]; + } +} + +/** + * gets new axes of a and b after transpose and reshape. + */ +inline void GetReorderedAxes(const mxnet::Tuple& a_axes_summed, + mxnet::Tuple* a_axes_remained, + mxnet::Tuple* a_axes, + const mxnet::Tuple& b_axes_summed, + mxnet::Tuple* b_axes_remained, + mxnet::Tuple* b_axes, + const mxnet::TShape& a_shape, + const mxnet::TShape& b_shape) { + std::vector a_axes_remained_vector; + for (int i = 0; i < a_shape.ndim(); i++) { + a_axes_remained_vector.push_back(i); + } + for (auto& i : a_axes_summed) { + a_axes_remained_vector.erase(std::find(a_axes_remained_vector.begin(), + a_axes_remained_vector.end(), i)); + } + *a_axes_remained = mxnet::Tuple(a_axes_remained_vector); + + std::vector a_axes_vector(a_axes_remained_vector); + for (auto& i : a_axes_summed) { + a_axes_vector.push_back(i); + } + *a_axes = mxnet::Tuple(a_axes_vector); + + std::vector b_axes_remained_vector; + for (int i = 0; i < b_shape.ndim(); i++) { + b_axes_remained_vector.push_back(i); + } + for (auto& i : b_axes_summed) { + b_axes_remained_vector.erase(std::find(b_axes_remained_vector.begin(), + b_axes_remained_vector.end(), i)); + } + *b_axes_remained = mxnet::Tuple(b_axes_remained_vector); + + std::vector b_axes_vector; + for (auto& i : b_axes_summed) { + b_axes_vector.push_back(i); + } + for (auto& i : b_axes_remained_vector) { + b_axes_vector.push_back(i); + } + *b_axes = mxnet::Tuple(b_axes_vector); +} + +/** + * gets shapes of a and b after transpose and reshape. + */ +inline mxnet::TShape GetReorderedShape(const mxnet::TShape& shape, const mxnet::Tuple& axes) { + mxnet::TShape new_shape(shape); + for (int i = 0; i < axes.ndim(); i++) { + new_shape[i] = shape[axes[i]]; + } + return new_shape; +} + +/** + * gets matrix dot. Reshapes tensor a as ad1-by-ad2 matrix, tensor b as bd1-by-bd2 matrix, then + * calculates matrix dot a * b and stores in tensor out. + */ +template +void MatrixDot(const OpContext& ctx, + const TBlob& a, + const TBlob& b, + const TBlob& out, + const OpReqType req, + const int ad1, + const int ad2, + const int bd1, + const int bd2, + const bool aT = false, + const bool bT = false) { + using namespace mshadow; + using namespace mshadow_op; + + Stream *s = ctx.get_stream(); + + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, { + Tensor a_tensor = a.get_with_shape(Shape2(ad1, ad2), s); + Tensor b_tensor = b.get_with_shape(Shape2(bd1, bd2), s); + + if (aT && bT) { + CHECK_EQ(ad1, bd2); + Tensor out_tensor = out.get_with_shape(Shape2(ad2, bd1), s); + ASSIGN_DISPATCH(out_tensor, req, dot(a_tensor.T(), b_tensor.T())); + } else if (aT && !bT) { + CHECK_EQ(ad1, bd1); + Tensor out_tensor = out.get_with_shape(Shape2(ad2, bd2), s); + ASSIGN_DISPATCH(out_tensor, req, dot(a_tensor.T(), b_tensor)); + } else if (!aT && bT) { + CHECK_EQ(ad2, bd2); + Tensor out_tensor = out.get_with_shape(Shape2(ad1, bd1), s); + ASSIGN_DISPATCH(out_tensor, req, dot(a_tensor, b_tensor.T())); + } else { + CHECK_EQ(ad2, bd1); + Tensor out_tensor = out.get_with_shape(Shape2(ad1, bd2), s); + ASSIGN_DISPATCH(out_tensor, req, dot(a_tensor, b_tensor)); + } + }); +} + +/** + * Scalar multiply. + */ +template +struct scalar_mul_kernel { + template + MSHADOW_XINLINE static void Map(int i, DType *out, const DType* tensor, const DType *scalar) { + KERNEL_ASSIGN(out[i], req, tensor[i] * scalar[0]); + } +}; + +/** + * Calculates tensordot. + */ +template +void TensordotImpl(const Tuple& a_axes_summed, + const Tuple& b_axes_summed, + const OpContext& ctx, + const TBlob& a, + const TBlob& b, + const TBlob& out, + const std::vector& req) { + if (req[0] == kNullOp) { + return; + } + + if (out.shape_.Size() == 0U) { + return; // zero-size output, no need to launch kernel + } + + const mxnet::TShape& a_shape = a.shape_; + const mxnet::TShape& b_shape = b.shape_; + + mshadow::Stream *s = ctx.get_stream(); + CHECK_EQ(out.type_flag_, a.type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(out.type_flag_, b.type_flag_) + << "Binary function only support input/output with the same type"; + CHECK(out.type_flag_ == kFloat32 || out.type_flag_ == kFloat64 || + (out.type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask)) + << "Tensordot only supports float32/float64 for CPU, and float16/float32/float64 for GPU"; + + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, { + if (a_shape.Size() == 0U || b_shape.Size() == 0U) { + // 0-size input + if (req[0] != kAddTo) { + Tensor out_data = out.get_with_shape( + Shape1(out.shape_.Size()), s); + out_data = static_cast(0); + } + } else if (a_shape.ndim() == 0 && b_shape.ndim() == 0) { + // Both 0-D scalars, equivalent to multiply + Tensor a_data = a.get_with_shape(Shape1(1), s); + Tensor b_data = b.get_with_shape(Shape1(1), s); + Tensor out_data = out.get_with_shape(Shape1(1), s); + ASSIGN_DISPATCH(out_data, req[0], a_data * b_data); + } else if (a_shape.ndim() == 0 || b_shape.ndim() == 0) { + // Either of them is a scalar, just scale by one of them + const DType* tensor = (a_shape.ndim() == 0) ? b.dptr() : a.dptr(); + const DType* scalar = (a_shape.ndim() == 0) ? a.dptr() : b.dptr(); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + mxnet_op::Kernel, xpu>::Launch( + s, out.Size(), out.dptr(), tensor, scalar); + }); + } else { + // Two tensors of at least 1 dimensions. + Tuple a_axes_remained; + Tuple b_axes_remained; + Tuple a_axes; + Tuple b_axes; + GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained, + &b_axes, a_shape, b_shape); + + int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1; + GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed, + b_axes_remained, b_axes_summed, a_shape, b_shape); + + mxnet::TShape a_temp_shape = GetReorderedShape(a_shape, a_axes); + mxnet::TShape b_temp_shape = GetReorderedShape(b_shape, b_axes); + + Tensor workspace = ctx.requested[0].get_space_typed + (Shape1(a.Size() + b.Size()), s); + DType* a_ptr = reinterpret_cast(workspace.dptr_); + DType* b_ptr = reinterpret_cast(workspace.dptr_ + a.Size()); + TBlob a_res = TBlob(a_ptr, a_temp_shape, xpu::kDevMask); + TBlob b_res = TBlob(b_ptr, b_temp_shape, xpu::kDevMask); + + mxnet::op::TransposeImpl(ctx.run_ctx, a, a_res, + mxnet::TShape(a_axes.begin(), a_axes.end())); + mxnet::op::TransposeImpl(ctx.run_ctx, b, b_res, + mxnet::TShape(b_axes.begin(), b_axes.end())); + + MatrixDot(ctx, a_res, b_res, out, req[0], ad1, ad2, bd1, bd2); + } + }); +} + +/** + * forward function + */ +template +void TensordotOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + + const TBlob& a = inputs[0]; + const TBlob& b = inputs[1]; + const TBlob& out = outputs[0]; + const mxnet::TShape a_shape = a.shape_; + const mxnet::TShape b_shape = b.shape_; + + const TensordotParam& param = nnvm::get(attrs.parsed); + Tuple a_axes_summed = param.a_axes_summed; + Tuple b_axes_summed = param.b_axes_summed; + ShiftAxes(&a_axes_summed, a_shape.ndim()); + ShiftAxes(&b_axes_summed, b_shape.ndim()); + + TensordotImpl(a_axes_summed, b_axes_summed, ctx, a, b, out, req); +} + +/** + * gets shapes for inverse transpose. + */ +inline mxnet::TShape GetReverseShape(const mxnet::Tuple& shape) { + mxnet::TShape shape2(shape.begin(), shape.end()); + for (int i = 0; i < shape.ndim(); i++) { + shape2[shape[i]] = i; + } + return shape2; +} + +/** + * calculates tensordot derivative. + */ +template +void TensordotBackwardImpl(const Tuple& a_axes_summed, + const Tuple& b_axes_summed, + const OpContext& ctx, + const TBlob& out_grad, + const TBlob& a, + const TBlob& b, + const TBlob& grad_a, + const TBlob& grad_b, + const std::vector& req) { + mshadow::Stream *s = ctx.get_stream(); + + const mxnet::TShape& a_shape = a.shape_; + const mxnet::TShape& b_shape = b.shape_; + + if ((a_shape.Size() == 0U) || (b_shape.Size() == 0U)) { + return; // zero-size output, no need to launch kernel + } + MSHADOW_REAL_TYPE_SWITCH(out_grad.type_flag_, DType, { + if (a_shape.ndim() == 0 && b_shape.ndim() == 0) { + // Both 0-D scalars, equivalent to multiply + Tensor out_grad_data = out_grad.get_with_shape(Shape1(1), s); + Tensor a_data = a.get_with_shape(Shape1(1), s); + Tensor b_data = b.get_with_shape(Shape1(1), s); + Tensor grad_a_data = grad_a.get_with_shape(Shape1(1), s); + Tensor grad_b_data = grad_b.get_with_shape(Shape1(1), s); + ASSIGN_DISPATCH(grad_a_data, req[0], b_data * out_grad_data); + ASSIGN_DISPATCH(grad_b_data, req[1], a_data * out_grad_data); + } else if (a_shape.ndim() == 0 || b_shape.ndim() == 0) { + // Either of them is a scalar, just scale by one of them + const TBlob& tensor = (a_shape.ndim() == 0) ? b : a; + const TBlob& tensor_grad = (a_shape.ndim() == 0) ? grad_b : grad_a; + const TBlob& scalar = (a_shape.ndim() == 0) ? a : b; + const TBlob& scalar_grad = (a_shape.ndim() == 0) ? grad_a : grad_b; + Tensor scalar_ = scalar.get_with_shape(Shape1(1), s); + Tensor scalar_grad_ = scalar_grad.get_with_shape(Shape1(1), s); + Tensor tensor_ = tensor.FlatTo1D(s); + Tensor tensor_grad_ = tensor_grad.FlatTo1D(s); + Tensor out_grad_ = out_grad.FlatTo1D(s); + const OpReqType& tensor_req = (a_shape.ndim() == 0) ? req[1] : req[0]; + const OpReqType& scalar_req = (a_shape.ndim() == 0) ? req[0] : req[1]; + ASSIGN_DISPATCH(tensor_grad_, tensor_req, + broadcast_scalar(scalar_, tensor_grad_.shape_) * out_grad_); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(out_grad.shape_.Size()), s); + ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * out_grad_); + + ReduceAxesComputeImpl( + ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); + } else { + // Two tensors of at least 1 dimensions. + Tuple a_axes_remained; + Tuple b_axes_remained; + Tuple a_axes; + Tuple b_axes; + GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained, + &b_axes, a_shape, b_shape); + + int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1; + GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed, + b_axes_remained, b_axes_summed, a_shape, b_shape); + + std::vector a_T_axes; + for (int i = 0; i < a_axes_summed.ndim(); i++) { + a_T_axes.push_back(a_axes_summed[i]); + } + for (int i = 0; i < a_axes_remained.ndim(); i++) { + a_T_axes.push_back(a_axes_remained[i]); + } + mxnet::TShape a_temp_shape(GetReorderedShape(a_shape, a_axes)); + mxnet::TShape a_T_temp_shape(GetReorderedShape(a_shape, a_T_axes)); + + std::vector b_T_axes; + for (int i = 0; i < b_axes_remained.ndim(); i++) { + b_T_axes.push_back(b_axes_remained[i]); + } + for (int i = 0; i < b_axes_summed.ndim(); i++) { + b_T_axes.push_back(b_axes_summed[i]); + } + mxnet::TShape b_temp_shape(GetReorderedShape(b_shape, b_axes)); + mxnet::TShape b_T_temp_shape(GetReorderedShape(b_shape, b_T_axes)); + + Tensor workspace = ctx.requested[0].get_space_typed + (Shape1((a.Size() + b.Size()) * 2), s); + DType* a_ptr = reinterpret_cast(workspace.dptr_); + DType* a_ptr2 = reinterpret_cast(workspace.dptr_ + a.Size()); + DType* b_ptr = reinterpret_cast(workspace.dptr_ + 2 * a.Size()); + DType* b_ptr2 = reinterpret_cast(workspace.dptr_ + 2 * a.Size() + b.Size()); + + TBlob a_res = TBlob(a_ptr, a_temp_shape, xpu::kDevMask); + TBlob b_res = TBlob(b_ptr, b_temp_shape, xpu::kDevMask); + TBlob a_res2 = TBlob(a_ptr2, a_T_temp_shape, xpu::kDevMask); + TBlob b_res2 = TBlob(b_ptr2, b_T_temp_shape, xpu::kDevMask); + + mxnet::op::TransposeImpl(ctx.run_ctx, a, a_res2, + mxnet::TShape(a_T_axes.begin(), a_T_axes.end())); + mxnet::op::TransposeImpl(ctx.run_ctx, b, b_res2, + mxnet::TShape(b_T_axes.begin(), b_T_axes.end())); + + MatrixDot(ctx, a_res2, out_grad, b_res, req[1], ad2, ad1, ad1, bd2); + MatrixDot(ctx, out_grad, b_res2, a_res, req[0], ad1, bd2, bd2, bd1); + + mxnet::op::TransposeImpl(ctx.run_ctx, a_res, grad_a, GetReverseShape(a_axes)); + mxnet::op::TransposeImpl(ctx.run_ctx, b_res, grad_b, GetReverseShape(b_axes)); + } + }); +} + +/** + * backward function. + */ +template +void TensordotOpBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + + const TBlob& out_grad = inputs[0]; + const TBlob& a = inputs[1]; + const TBlob& b = inputs[2]; + const TBlob& grad_a = outputs[0]; + const TBlob& grad_b = outputs[1]; + const mxnet::TShape a_shape = a.shape_; + const mxnet::TShape b_shape = b.shape_; + + const TensordotParam& param = nnvm::get(attrs.parsed); + Tuple a_axes_summed = param.a_axes_summed; + Tuple b_axes_summed = param.b_axes_summed; + ShiftAxes(&a_axes_summed, a_shape.ndim()); + ShiftAxes(&b_axes_summed, b_shape.ndim()); + + TensordotBackwardImpl(a_axes_summed, b_axes_summed, ctx, out_grad, a, b, grad_a, + grad_b, req); +} + +struct TensordotIntAxesParam : public dmlc::Parameter { + int axes; + DMLC_DECLARE_PARAMETER(TensordotIntAxesParam) { + DMLC_DECLARE_FIELD(axes); + } +}; + +/** + * gets summed axes of a and b from parameter axes. + */ +inline void GetSummedAxes(mxnet::Tuple* a_axes_summed_ptr, + mxnet::Tuple* b_axes_summed_ptr, + const int axes, + const mxnet::TShape& a_shape) { + std::vector a_axes_summed_vector; + for (int i = 0; i < axes; i++) { + a_axes_summed_vector.push_back(a_shape.ndim() - axes + i); + } + *a_axes_summed_ptr = mxnet::Tuple(a_axes_summed_vector); + + std::vector b_axes_summed_vector; + for (int i = 0; i < axes; i++) { + b_axes_summed_vector.push_back(i); + } + *b_axes_summed_ptr = mxnet::Tuple(b_axes_summed_vector); +} + +/** + * Calculates tensordot. + */ +template +void TensordotIntAxesImpl(const int axes, + const OpContext& ctx, + const TBlob& a, + const TBlob& b, + const TBlob& out, + const OpReqType req) { + if (req == kNullOp) { + return; + } + + if (out.shape_.Size() == 0U) { + return; // zero-size output, no need to launch kernel + } + + const mxnet::TShape& a_shape = a.shape_; + const mxnet::TShape& b_shape = b.shape_; + + mshadow::Stream *s = ctx.get_stream(); + CHECK_EQ(out.type_flag_, a.type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(out.type_flag_, b.type_flag_) + << "Binary function only support input/output with the same type"; + CHECK(out.type_flag_ == kFloat32 || out.type_flag_ == kFloat64 || + (out.type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask)) + << "Tensordot only supports float32/float64 for CPU, and float16/float32/float64 for GPU"; + + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, { + if (a_shape.Size() == 0U || b_shape.Size() == 0U) { + // 0-size input + if (req != kAddTo) { + Tensor out_data = out.get_with_shape( + Shape1(out.shape_.Size()), s); + out_data = static_cast(0); + } + } else if (a_shape.ndim() == 0 && b_shape.ndim() == 0) { + // Both 0-D scalars, equivalent to multiply + Tensor a_data = a.get_with_shape(Shape1(1), s); + Tensor b_data = b.get_with_shape(Shape1(1), s); + Tensor out_data = out.get_with_shape(Shape1(1), s); + ASSIGN_DISPATCH(out_data, req, a_data * b_data); + } else if (a_shape.ndim() == 0 || b_shape.ndim() == 0) { + // Either of them is a scalar, just scale by one of them + const DType* tensor = (a_shape.ndim() == 0) ? b.dptr() : a.dptr(); + const DType* scalar = (a_shape.ndim() == 0) ? a.dptr() : b.dptr(); + MXNET_ASSIGN_REQ_SWITCH(req, Req, { + mxnet_op::Kernel, xpu>::Launch( + s, out.Size(), out.dptr(), tensor, scalar); + }); + } else { + // Two tensors of at least 1 dimensions. + Tuple a_axes_summed; + Tuple b_axes_summed; + GetSummedAxes(&a_axes_summed, &b_axes_summed, axes, a_shape); + + Tuple a_axes_remained; + Tuple b_axes_remained; + Tuple a_axes; + Tuple b_axes; + GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained, + &b_axes, a_shape, b_shape); + + int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1; + GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed, + b_axes_remained, b_axes_summed, a_shape, b_shape); + MatrixDot(ctx, a, b, out, req, ad1, ad2, bd1, bd2); + } + }); +} + +/** + * forward function + */ +template +void TensordotIntAxesOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + + const TBlob& a = inputs[0]; + const TBlob& b = inputs[1]; + const TBlob& out = outputs[0]; + + const TensordotIntAxesParam& param = nnvm::get(attrs.parsed); + const int axes = param.axes; + + TensordotIntAxesImpl(axes, ctx, a, b, out, req[0]); +} + +template +void TensordotIntAxesBackwardImpl(const int axes, + const OpContext& ctx, + const TBlob& out_grad, + const TBlob& a, + const TBlob& b, + const TBlob& grad_a, + const TBlob& grad_b, + const std::vector& req) { + const mxnet::TShape& a_shape = a.shape_; + const mxnet::TShape& b_shape = b.shape_; + + if ((a_shape.Size() == 0U) || (b_shape.Size() == 0U)) { + return; // zero-size output, no need to launch kernel + } + + mshadow::Stream *s = ctx.get_stream(); + + MSHADOW_REAL_TYPE_SWITCH(out_grad.type_flag_, DType, { + if (a_shape.ndim() == 0 && b_shape.ndim() == 0) { + // Both 0-D scalars, equivalent to multiply + Tensor out_grad_data = out_grad.get_with_shape(Shape1(1), s); + Tensor a_data = a.get_with_shape(Shape1(1), s); + Tensor b_data = b.get_with_shape(Shape1(1), s); + Tensor grad_a_data = grad_a.get_with_shape(Shape1(1), s); + Tensor grad_b_data = grad_b.get_with_shape(Shape1(1), s); + ASSIGN_DISPATCH(grad_a_data, req[0], b_data * out_grad_data); + ASSIGN_DISPATCH(grad_b_data, req[1], a_data * out_grad_data); + } else if (a_shape.ndim() == 0 || b_shape.ndim() == 0) { + // Either of them is a scalar, just scale by one of them + const TBlob& tensor = (a_shape.ndim() == 0) ? b : a; + const TBlob& tensor_grad = (a_shape.ndim() == 0) ? grad_b : grad_a; + const TBlob& scalar = (a_shape.ndim() == 0) ? a : b; + const TBlob& scalar_grad = (a_shape.ndim() == 0) ? grad_a : grad_b; + Tensor scalar_ = scalar.get_with_shape(Shape1(1), s); + Tensor scalar_grad_ = scalar_grad.get_with_shape(Shape1(1), s); + Tensor tensor_ = tensor.FlatTo1D(s); + Tensor tensor_grad_ = tensor_grad.FlatTo1D(s); + Tensor out_grad_ = out_grad.FlatTo1D(s); + const OpReqType& tensor_req = (a_shape.ndim() == 0) ? req[1] : req[0]; + const OpReqType& scalar_req = (a_shape.ndim() == 0) ? req[0] : req[1]; + ASSIGN_DISPATCH(tensor_grad_, tensor_req, + broadcast_scalar(scalar_, tensor_grad_.shape_) * out_grad_); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(out_grad.shape_.Size()), s); + ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * out_grad_); + + ReduceAxesComputeImpl( + ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); + } else { + // Two tensors of at least 1 dimensions. + Tuple a_axes_summed; + Tuple b_axes_summed; + GetSummedAxes(&a_axes_summed, &b_axes_summed, axes, a_shape); + + Tuple a_axes_remained; + Tuple b_axes_remained; + Tuple a_axes; + Tuple b_axes; + GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained, + &b_axes, a_shape, b_shape); + + int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1; + GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed, + b_axes_remained, b_axes_summed, a_shape, b_shape); + + MatrixDot(ctx, a, out_grad, grad_b, req[1], ad1, ad2, ad1, bd2, true, false); + MatrixDot(ctx, out_grad, b, grad_a, req[0], ad1, bd2, bd1, bd2, false, true); + } + }); +} + +/** + * backward function. + */ +template +void TensordotIntAxesOpBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + + const TBlob& out_grad = inputs[0]; + const TBlob& a = inputs[1]; + const TBlob& b = inputs[2]; + const TBlob& grad_a = outputs[0]; + const TBlob& grad_b = outputs[1]; + + const TensordotIntAxesParam& param = nnvm::get(attrs.parsed); + const int axes = param.axes; + + TensordotIntAxesBackwardImpl(axes, ctx, out_grad, a, b, grad_a, grad_b, req); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_TENSORDOT_OP_INL_H_ diff --git a/src/operator/numpy/np_tensordot_op.cc b/src/operator/numpy/np_tensordot_op.cc new file mode 100644 index 000000000000..50c1647e0264 --- /dev/null +++ b/src/operator/numpy/np_tensordot_op.cc @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_tensordot_op.cc + * \brief CPU Implementation of numpy-compatible tensordot + */ + +#include +#include "np_tensordot_op-inl.h" + +namespace mxnet { +namespace op { + +bool TensordotOpShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + + const mxnet::TShape& a_shape = in_attrs->at(0); + const mxnet::TShape& b_shape = in_attrs->at(1); + + if (!ndim_is_known(a_shape) || !ndim_is_known(b_shape)) { + return false; + } + + if (a_shape.ndim() == 0) { + // a is scalar + SHAPE_ASSIGN_CHECK(*out_attrs, 0, b_shape); + SHAPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); + } else if (b_shape.ndim() == 0) { + // b is scalar + SHAPE_ASSIGN_CHECK(*out_attrs, 0, a_shape); + SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + } else { + // Two tensors of at least 1 dimensions. + const TensordotParam& param = nnvm::get(attrs.parsed); + Tuple a_axes_summed = param.a_axes_summed; + Tuple b_axes_summed = param.b_axes_summed; + ShiftAxes(&a_axes_summed, a_shape.ndim()); + ShiftAxes(&b_axes_summed, b_shape.ndim()); + + Tuple a_axes_remained; + Tuple b_axes_remained; + Tuple a_axes; + Tuple b_axes; + GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained, + &b_axes, a_shape, b_shape); + + CHECK_EQ(a_axes_summed.ndim(), b_axes_summed.ndim()); + + mxnet::TShape out_shape(a_axes_remained.ndim() + b_axes_remained.ndim(), -1); + for (int i = 0; i < a_axes_remained.ndim(); i++) { + out_shape[i] = a_shape[a_axes_remained[i]]; + } + for (int i = 0; i < b_axes_remained.ndim(); i++) { + out_shape[a_axes_remained.ndim() + i] = b_shape[b_axes_remained[i]]; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape); + + mxnet::TShape tem_shape1(a_axes.ndim(), -1); + for (int i = 0; i < a_axes_remained.ndim(); i++) { + tem_shape1[a_axes_remained[i]] = out_shape[i]; + } + for (int i = 0; i < a_axes_summed.ndim(); i++) { + tem_shape1[a_axes_summed[i]] = b_shape[b_axes_summed[i]]; + } + SHAPE_ASSIGN_CHECK(*in_attrs, 0, tem_shape1); + + mxnet::TShape tem_shape2(b_axes.ndim(), -1); + for (int i = 0; i < b_axes_remained.ndim(); i++) { + tem_shape2[b_axes_remained[i]] = out_shape[a_axes_remained.ndim() + i]; + } + for (int i = 0; i < b_axes_summed.ndim(); i++) { + tem_shape2[b_axes_summed[i]] = a_shape[a_axes_summed[i]]; + } + SHAPE_ASSIGN_CHECK(*in_attrs, 1, tem_shape2); + } + + return shape_is_known(*in_attrs) && shape_is_known(*out_attrs); +} + +DMLC_REGISTER_PARAMETER(TensordotParam); + +NNVM_REGISTER_OP(_npi_tensordot) +.set_attr_parser(mxnet::op::ParamParser) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a", "b"}; + }) +.set_attr("FInferShape", TensordotOpShape) +.set_attr("FInferType", mxnet::op::ElemwiseType<2, 1>) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", TensordotOpForward) +.set_attr("FGradient", mxnet::op::ElemwiseGradUseIn{"_backward_npi_tensordot"}) +.add_argument("a", "NDArray-or-Symbol", "First input") +.add_argument("b", "NDArray-or-Symbol", "Second input") +.add_arguments(TensordotParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_npi_tensordot) +.set_attr_parser(mxnet::op::ParamParser) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", TensordotOpBackward); + +bool TensordotIntAxesOpShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + + const mxnet::TShape& a_shape = in_attrs->at(0); + const mxnet::TShape& b_shape = in_attrs->at(1); + + if (!ndim_is_known(a_shape) || !ndim_is_known(b_shape)) { + return false; + } + + if (a_shape.ndim() == 0) { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, b_shape); + SHAPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); + } else if (b_shape.ndim() == 0) { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, a_shape); + SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + } else { + const TensordotIntAxesParam& param = nnvm::get(attrs.parsed); + const int& axes = param.axes; + + Tuple a_axes_summed; + Tuple b_axes_summed; + GetSummedAxes(&a_axes_summed, &b_axes_summed, axes, a_shape); + + Tuple a_axes_remained; + Tuple b_axes_remained; + Tuple a_axes; + Tuple b_axes; + GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained, + &b_axes, a_shape, b_shape); + + CHECK_EQ(a_axes_summed.ndim(), b_axes_summed.ndim()); + + mxnet::TShape out_shape(a_axes_remained.ndim() + b_axes_remained.ndim(), -1); + for (int i = 0; i < a_axes_remained.ndim(); i++) { + out_shape[i] = a_shape[a_axes_remained[i]]; + } + for (int i = 0; i < b_axes_remained.ndim(); i++) { + out_shape[a_axes_remained.ndim() + i] = b_shape[b_axes_remained[i]]; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape); + + mxnet::TShape tem_shape1(a_axes.ndim(), -1); + for (int i = 0; i < a_axes_remained.ndim(); i++) { + tem_shape1[a_axes_remained[i]] = out_shape[i]; + } + for (int i = 0; i < a_axes_summed.ndim(); i++) { + tem_shape1[a_axes_summed[i]] = b_shape[b_axes_summed[i]]; + } + SHAPE_ASSIGN_CHECK(*in_attrs, 0, tem_shape1); + + mxnet::TShape tem_shape2(b_axes.ndim(), -1); + for (int i = 0; i < b_axes_remained.ndim(); i++) { + tem_shape2[b_axes_remained[i]] = out_shape[a_axes_remained.ndim() + i]; + } + for (int i = 0; i < b_axes_summed.ndim(); i++) { + tem_shape2[b_axes_summed[i]] = a_shape[a_axes_summed[i]]; + } + SHAPE_ASSIGN_CHECK(*in_attrs, 1, tem_shape2); + } + + return shape_is_known(*in_attrs) && shape_is_known(*out_attrs); +} + +DMLC_REGISTER_PARAMETER(TensordotIntAxesParam); + +NNVM_REGISTER_OP(_npi_tensordot_int_axes) +.set_attr_parser(mxnet::op::ParamParser) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a", "b"}; + }) +.set_attr("FInferShape", TensordotIntAxesOpShape) +.set_attr("FInferType", mxnet::op::ElemwiseType<2, 1>) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", TensordotIntAxesOpForward) +.set_attr("FGradient", + mxnet::op::ElemwiseGradUseIn{"_backward_npi_tensordot_int_axes"}) +.add_argument("a", "NDArray-or-Symbol", "First input") +.add_argument("b", "NDArray-or-Symbol", "Second input") +.add_arguments(TensordotIntAxesParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_npi_tensordot_int_axes) +.set_attr_parser(mxnet::op::ParamParser) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", TensordotIntAxesOpBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_tensordot_op.cu b/src/operator/numpy/np_tensordot_op.cu new file mode 100644 index 000000000000..e1d8a0b85e6e --- /dev/null +++ b/src/operator/numpy/np_tensordot_op.cu @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.ΓΈ + */ + +/*! + * \file np_tensordot_inplace.cu + * \brief GPU Implementation of numpy-compatible tensordot + */ + +#include "np_tensordot_op-inl.h" +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_tensordot) +.set_attr("FCompute", TensordotOpForward); + +NNVM_REGISTER_OP(_backward_npi_tensordot) +.set_attr("FCompute", TensordotOpBackward); + +NNVM_REGISTER_OP(_npi_tensordot_int_axes) +.set_attr("FCompute", TensordotIntAxesOpForward); + +NNVM_REGISTER_OP(_backward_npi_tensordot_int_axes) +.set_attr("FCompute", TensordotIntAxesOpBackward); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index b179f67e6128..0498566e5df3 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -30,6 +30,189 @@ import collections +@with_seed() +@use_np +def test_np_tensordot(): + class TestTensordot(HybridBlock): + def __init__(self, axes): + super(TestTensordot, self).__init__() + self._axes = axes + + def hybrid_forward(self, F, a, b): + return F.np.tensordot(a, b, self._axes) + + def tensordot_backward(a, b, axes=2): + if (a.ndim < 1) or (b.ndim < 1): + raise ValueError('An input is zero-dim') + + if _np.isscalar(axes): + a_axes_summed = [i + a.ndim - axes for i in range(axes)] + b_axes_summed = [i for i in range(axes)] + else: + if len(axes) != 2: + raise ValueError('Axes must consist of two arrays.') + a_axes_summed, b_axes_summed = axes + if _np.isscalar(a_axes_summed): + a_axes_summed = a_axes_summed, + if _np.isscalar(b_axes_summed): + b_axes_summed = b_axes_summed, + + for i in range(len(a_axes_summed)): + a_axes_summed[i] = (a_axes_summed[i] + a.ndim) % a.ndim + + for i in range(len(b_axes_summed)): + b_axes_summed[i] = (b_axes_summed[i] + b.ndim) % b.ndim + + if len(a_axes_summed) != len(b_axes_summed): + raise ValueError('Axes length mismatch') + + a_axes_remained = [] + for i in range(a.ndim): + if not (i in a_axes_summed): + a_axes_remained.append(i) + a_axes = a_axes_remained[:] + a_axes_summed[:] + + b_axes_remained = [] + for i in range(b.ndim): + if not (i in b_axes_summed): + b_axes_remained.append(i) + b_axes = b_axes_summed[:] + b_axes_remained[:] + + ad1 = _np.prod([a.shape[i] for i in a_axes_remained]) if len(a_axes_remained) > 0 else 1 + ad2 = _np.prod([a.shape[i] for i in a_axes_summed]) if len(a_axes_summed) > 0 else 1 + bd1 = _np.prod([b.shape[i] for i in b_axes_summed]) if len(b_axes_summed) > 0 else 1 + bd2 = _np.prod([b.shape[i] for i in b_axes_remained]) if len(b_axes_remained) > 0 else 1 + + out_grad = _np.ones((ad1, bd2)) + + new_a = _np.transpose(a, a_axes) + new_a_shape = new_a.shape[:] + new_a = new_a.reshape((ad1, ad2)) + new_b = _np.transpose(b, b_axes) + new_b_shape = new_b.shape[:] + new_b = new_b.reshape((bd1, bd2)) + + reverse_a_axes = [0 for i in a_axes] + for i in range(len(a_axes)): + reverse_a_axes[a_axes[i]] = i + + reverse_b_axes = [0 for i in b_axes] + for i in range(len(b_axes)): + reverse_b_axes[b_axes[i]] = i + + grad_b = _np.dot(new_a.T, out_grad).reshape(new_b_shape) + grad_b = _np.transpose(grad_b, reverse_b_axes) + grad_a = _np.dot(out_grad, new_b.T).reshape(new_a_shape) + grad_a = _np.transpose(grad_a, reverse_a_axes) + + return [grad_a, grad_b] + + # test non zero size input + tensor_shapes = [ + ((3, 5), (5, 4), 1), # (a_shape, b_shape, axes) + ((3,), (3,), 1), + ((3, 4, 5, 3, 2), (5, 3, 2, 1, 2), 3), + ((3, 5, 4, 3, 2), (2, 3, 5, 1, 2), [[1, 3, 4], [2, 1, 0]]), + ((3, 5, 4), (5, 4, 3), [[1, 0, 2], [0, 2, 1]]), + ((3, 5, 4), (5, 3, 4), [[2, 0], [-1, -2]]), + ((2, 2), (2, 2), 2), + ((3, 5, 4), (5, ), [[-2], [0]]), + ((3, 5, 4), (5, ), [[1], [0]]), + ((2,), (2, 3), 1), + ((3,), (3,), 0), + ((2,), (2, 3), 0), + ((3, 5, 4), (5, ), 0), + ((2, 3, 4), (4, 3, 2), [[], []]), + ((3, 0), (0, 5), 1), + ((3, 0), (0, 4), [[1], [0]]), + ((0, 3), (3, 5), 1), + ((0, 3), (5, 0), [[0], [1]]) + ] + + for hybridize in [True, False]: + for a_shape, b_shape, axes in tensor_shapes: + for dtype in [_np.float32, _np.float64]: + test_tensordot = TestTensordot(axes) + if hybridize: + test_tensordot.hybridize() + a = rand_ndarray(shape = a_shape, dtype = dtype).as_np_ndarray() + b = rand_ndarray(shape = b_shape, dtype = dtype).as_np_ndarray() + a.attach_grad() + b.attach_grad() + + np_out = _np.tensordot(a.asnumpy(), b.asnumpy(), axes) + with mx.autograd.record(): + mx_out = test_tensordot(a, b) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol = 1e-3, atol = 1e-5) + mx_out.backward() + np_backward = tensordot_backward(a.asnumpy(), b.asnumpy(), axes) + assert_almost_equal(a.grad.asnumpy(), np_backward[0], rtol = 1e-3, atol=1e-5) + assert_almost_equal(b.grad.asnumpy(), np_backward[1], rtol = 1e-3, atol=1e-5) + + # Test imperative once again + mx_out = np.tensordot(a, b, axes) + np_out = _np.tensordot(a.asnumpy(), b.asnumpy(), axes) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + # test numeric gradient + if (_np.prod(a_shape) > 0 and _np.prod(b_shape) > 0): + a_sym = mx.sym.Variable("a").as_np_ndarray() + b_sym = mx.sym.Variable("b").as_np_ndarray() + mx_sym = mx.sym.np.tensordot(a_sym, b_sym, axes).as_nd_ndarray() + check_numeric_gradient(mx_sym, [a.as_nd_ndarray(), b.as_nd_ndarray()], + rtol=1e-1, atol=1e-1, dtype = dtype) + + +@with_seed() +@use_np +def test_np_dot(): + shapes = [ + ((3, 0), (0, 4)), + ((3,), (3,)), # Case 1 + ((3, 4), (4, 5)), # Case 2 + ((), ()), # Case 3 + ((3, 4, 5), ()), # Case 3.5.1 + ((), (3, 4, 5)), # Case 3.5.2 + ((3, 4, 5), (5, )), # Case 4 + ((3, 4, 5), (5, 2)), # Case 5 + ((5,), (5, 2)), + ((3, 5, 4), (5, 4, 3)), + ((3, 4), (5, 4, 3)), + ((4,), (5, 4, 3)) + ] + + eps = 1e-3 + + for shape_a, shape_b in shapes: + np_a = _np.random.uniform(-1.0, 1.0, shape_a) + np_a[abs(np_a) < eps] = 2 * eps + np_b = _np.random.uniform(-1.0, 1.0, shape_b) + np_b[abs(np_b) < eps] = 2 * eps + a = mx.nd.array(np_a) + b = mx.nd.array(np_b) + np_res = _np.dot(np_a, np_b) + mx_res = np.dot(a.as_np_ndarray(), b.as_np_ndarray()) + assert mx_res.shape == np_res.shape + assert_almost_equal(np_res, mx_res.asnumpy(), rtol=1e-5, atol=1e-5) + mx_a = mx.sym.Variable("a") + mx_b = mx.sym.Variable("b") + mx_sym = mx.sym.np.dot(mx_a.as_np_ndarray(), mx_b.as_np_ndarray()).as_nd_ndarray() + if (len(shape_a) > 0 and len(shape_b) > 0 and _np.prod(shape_a) > 0 and _np.prod(shape_b) > 0): + check_numeric_gradient(mx_sym, {"a": a, "b": b}, numeric_eps=eps, rtol=1e-2, atol=1e-3) + + bad_shapes = [((4, 5), (2, 3)), ((3, 4, 5), (6, ))] + + for shape_a, shape_b in bad_shapes: + a = mx.nd.array(random.random()) if len(shape_a) == 0 else rand_ndarray(shape_a) + b = mx.nd.array(random.random()) if len(shape_b) == 0 else rand_ndarray(shape_b) + try: + mx_res = np.dot(a.as_np_ndarray(), b.as_np_ndarray()) + except mx.base.MXNetError: + continue + assert False + + @with_seed() @use_np def test_np_sum():