set atol=rtol=1e-1 in test_np_linalg_cholesky
fix test_np_linalg_tensorinv
fix bug in test_np_linalg_tensorinv
commit tensorinv src
Ubuntu committed Dec 6, 2019
1 parent 8dd7051 · commit 868b224
Showing 9 changed files with 636 additions and 5 deletions.
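The commit title refers to relaxing the comparison tolerances in test_np_linalg_cholesky. The test file itself is among the changed files not loaded in this view, so the following is only a minimal, hypothetical sketch of such a check; mx_out and np_out stand in for the MXNet result and the NumPy reference:

from mxnet.test_utils import assert_almost_equal

# Hypothetical comparison with the relaxed tolerances from the commit title;
# mx_out is the MXNet result, np_out the NumPy reference result.
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1)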
57 changes: 56 additions & 1 deletion python/mxnet/ndarray/numpy/linalg.py
@@ -21,7 +21,7 @@
from . import _op as _mx_nd_np
from . import _internal as _npi

-__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet']
+__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'tensorinv']


def norm(x, ord=None, axis=None, keepdims=False):
@@ -352,3 +352,58 @@ def slogdet(a):
(1., -1151.2925464970228)
"""
return _npi.slogdet(a)


def tensorinv(a, ind=2):
r"""
Compute the 'inverse' of an N-dimensional array.
The result is an inverse for `a` relative to the tensordot operation
``tensordot(a, b, ind)``, i. e., up to floating-point accuracy,
``tensordot(tensorinv(a), a, ind)`` is the "identity" tensor for the
tensordot operation.
Parameters
----------
a : array_like
Tensor to 'invert'. Its shape must be 'square', i. e.,
``prod(a.shape[:ind]) == prod(a.shape[ind:])``.
ind : int, optional
Number of first indices that are involved in the inverse sum.
Must be a positive integer, default is 2.
Returns
-------
b : ndarray
`a`'s tensordot inverse, shape ``a.shape[ind:] + a.shape[:ind]``.
Raises
------
MXNetError
If `a` is singular or not 'square' (in the above sense).
See Also
--------
tensordot, tensorsolve
Examples
--------
>>> a = np.eye(4*6)
>>> a.shape = (4, 6, 8, 3)
>>> ainv = np.linalg.tensorinv(a, ind=2)
>>> ainv.shape
(8, 3, 4, 6)
>>> b = np.random.randn(4, 6)
>>> np.allclose(np.tensordot(ainv, b), np.linalg.tensorsolve(a, b))
True
>>> a = np.eye(4*6)
>>> a.shape = (24, 8, 3)
>>> ainv = np.linalg.tensorinv(a, ind=1)
>>> ainv.shape
(8, 3, 24)
>>> b = np.random.randn(24)
>>> np.allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))
True
"""
return _npi.tensorinv(a, ind)
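For reference, the documented behavior is equivalent to a flatten-invert-reshape construction; a minimal sketch in plain NumPy (assuming a 'square' input, as required above):

import numpy as np

def tensorinv_ref(a, ind=2):
    # Flatten the first `ind` axes into rows and the remaining axes into
    # columns, invert the resulting square matrix, then reshape so the
    # output has shape a.shape[ind:] + a.shape[:ind].
    prod = int(np.prod(a.shape[:ind]))  # equals prod(a.shape[ind:]) for 'square' a
    return np.linalg.inv(a.reshape(prod, -1)).reshape(a.shape[ind:] + a.shape[:ind])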
57 changes: 56 additions & 1 deletion python/mxnet/numpy/linalg.py
@@ -20,7 +20,7 @@
from __future__ import absolute_import
from ..ndarray import numpy as _mx_nd_np

-__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet']
+__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'tensorinv']


def norm(x, ord=None, axis=None, keepdims=False):
@@ -370,3 +370,58 @@ def slogdet(a):
(1., -1151.2925464970228)
"""
return _mx_nd_np.linalg.slogdet(a)


def tensorinv(a, ind=2):
r"""
Compute the 'inverse' of an N-dimensional array.
The result is an inverse for `a` relative to the tensordot operation
``tensordot(a, b, ind)``, i. e., up to floating-point accuracy,
``tensordot(tensorinv(a), a, ind)`` is the "identity" tensor for the
tensordot operation.
Parameters
----------
a : array_like
Tensor to 'invert'. Its shape must be 'square', i. e.,
``prod(a.shape[:ind]) == prod(a.shape[ind:])``.
ind : int, optional
Number of first indices that are involved in the inverse sum.
Must be a positive integer, default is 2.
Returns
-------
b : ndarray
`a`'s tensordot inverse, shape ``a.shape[ind:] + a.shape[:ind]``.
Raises
------
MXNetError
If `a` is singular or not 'square' (in the above sense).
See Also
--------
tensordot, tensorsolve
Examples
--------
>>> a = np.eye(4*6)
>>> a.shape = (4, 6, 8, 3)
>>> ainv = np.linalg.tensorinv(a, ind=2)
>>> ainv.shape
(8, 3, 4, 6)
>>> b = np.random.randn(4, 6)
>>> np.allclose(np.tensordot(ainv, b), np.linalg.tensorsolve(a, b))
True
>>> a = np.eye(4*6)
>>> a.shape = (24, 8, 3)
>>> ainv = np.linalg.tensorinv(a, ind=1)
>>> ainv.shape
(8, 3, 24)
>>> b = np.random.randn(24)
>>> np.allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))
True
"""
return _mx_nd_np.linalg.tensorinv(a, ind)
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
@@ -130,6 +130,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'linalg.norm',
'linalg.cholesky',
'linalg.inv',
'linalg.tensorinv',
'shape',
'trace',
'tril',
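Registering 'linalg.tensorinv' here lets official NumPy's np.linalg.tensorinv dispatch to MXNet's implementation through the __array_function__ protocol when it is handed an MXNet ndarray. A short sketch of the intended effect (assumes NumPy >= 1.17 and MXNet's NumPy semantics enabled):

import numpy as onp          # official NumPy
from mxnet import np, npx    # MXNet's NumPy-compatible interface
npx.set_np()

a = np.eye(4).reshape(2, 2, 2, 2)
ainv = onp.linalg.tensorinv(a, ind=2)  # dispatched to MXNet via __array_function__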
57 changes: 56 additions & 1 deletion python/mxnet/symbol/numpy/linalg.py
@@ -22,7 +22,7 @@
from . import _op as _mx_sym_np
from . import _internal as _npi

-__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet']
+__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'tensorinv']


def norm(x, ord=None, axis=None, keepdims=False):
@@ -339,3 +339,58 @@ def slogdet(a):
(1., -1151.2925464970228)
"""
return _npi.slogdet(a)


def tensorinv(a, ind=2):
r"""
Compute the 'inverse' of an N-dimensional array.
The result is an inverse for `a` relative to the tensordot operation
``tensordot(a, b, ind)``, i. e., up to floating-point accuracy,
``tensordot(tensorinv(a), a, ind)`` is the "identity" tensor for the
tensordot operation.
Parameters
----------
a : array_like
Tensor to 'invert'. Its shape must be 'square', i. e.,
``prod(a.shape[:ind]) == prod(a.shape[ind:])``.
ind : int, optional
Number of first indices that are involved in the inverse sum.
Must be a positive integer, default is 2.
Returns
-------
b : ndarray
`a`'s tensordot inverse, shape ``a.shape[ind:] + a.shape[:ind]``.
Raises
------
MXNetError
If `a` is singular or not 'square' (in the above sense).
See Also
--------
tensordot, tensorsolve
Examples
--------
>>> a = np.eye(4*6)
>>> a.shape = (4, 6, 8, 3)
>>> ainv = np.linalg.tensorinv(a, ind=2)
>>> ainv.shape
(8, 3, 4, 6)
>>> b = np.random.randn(4, 6)
>>> np.allclose(np.tensordot(ainv, b), np.linalg.tensorsolve(a, b))
True
>>> a = np.eye(4*6)
>>> a.shape = (24, 8, 3)
>>> ainv = np.linalg.tensorinv(a, ind=1)
>>> ainv.shape
(8, 3, 24)
>>> b = np.random.randn(24)
>>> np.allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))
True
"""
return _npi.tensorinv(a, ind)
171 changes: 171 additions & 0 deletions src/operator/numpy/linalg/np_tensorinv-inl.h
@@ -0,0 +1,171 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_tensorinv-inl.h
* \brief Placeholder for tensor inverse
*/
#ifndef MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORINV_INL_H_
#define MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORINV_INL_H_

#include <mxnet/operator_util.h>
#include <vector>
#include "../../operator_common.h"
#include "../../mshadow_op.h"
#include "../../tensor/la_op.h"
#include "../../tensor/la_op-inl.h"

namespace mxnet {
namespace op {

using namespace mshadow;

struct TensorinvParam : public dmlc::Parameter<TensorinvParam> {
int ind;
DMLC_DECLARE_PARAMETER(TensorinvParam) {
DMLC_DECLARE_FIELD(ind)
.set_default(2)
.describe("Number of first indices that are involved in the inverse sum.");
}
};

template<typename xpu>
void TensorinvOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);

mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const mxnet::TBlob& a_tblob = inputs[0];
const mxnet::TBlob& inv_a_tblob = outputs[0];
const mxnet::TShape& a_shape = a_tblob.shape_;
CHECK_EQ(inv_a_tblob.type_flag_, a_tblob.type_flag_)
<< "Binary function only support input/output with the same type";
MSHADOW_SGL_DBL_TYPE_SWITCH(
outputs[0].type_flag_,
OType, {
const int ind = nnvm::get<TensorinvParam>(attrs.parsed).ind;
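// Flatten `a` into a matrix of shape (prod_front, prod_back), where
// prod_front = prod(a.shape[:ind]) and prod_back = prod(a.shape[ind:]);
// for a valid ('square') input the two products are equal, so the tensor
// inverse reduces to an ordinary matrix inverse on the flattened view.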
dim_t prod_front = 1, prod_back = 1;
if (ind < a_shape.ndim()) {
for (int i = 0; i < ind; ++i) {
prod_front *= a_shape[i];
}
for (int i = ind; i < a_shape.ndim(); ++i) {
prod_back *= a_shape[i];
}
} else {
for (int i = 0; i < a_shape.ndim(); ++i) {
prod_front *= a_shape[i];
}
}
Tensor<xpu, 3, OType> A =
a_tblob.get_with_shape<xpu, 3, OType>(Shape3(1, prod_back, prod_front), s);
Tensor<xpu, 3, OType> inv_A =
inv_a_tblob.get_with_shape<xpu, 3, OType>(Shape3(1, prod_back, prod_front), s);
inverse::op(A, inv_A, ctx, attrs);
});
}

template<typename xpu>
void TensorinvOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);

mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
// const int axes = nnvm::get<TensorinvParam>(attrs.parsed).ind;
const TBlob& out_grad = inputs[0];
const TBlob& inv_a = inputs[1];
const TBlob& grad_a = outputs[0];
const TShape& inv_a_shape = inv_a.shape_;
MSHADOW_SGL_DBL_TYPE_SWITCH(
outputs[0].type_flag_,
OType, {
const int axes = nnvm::get<TensorinvParam>(attrs.parsed).ind;
CHECK_LE(inv_a_shape.ndim(), 6U)
<< "tensorinv backward only supports tensors of dimension <= 6";
if (axes < inv_a_shape.ndim()) {
const int axes1 = inv_a_shape.ndim() - axes, axes2 = axes;
TShape inv_a_transpose_shape(inv_a_shape.ndim(), -1);
for (int i = 0; i < axes; ++i) {
inv_a_transpose_shape[i] = inv_a_shape[i + inv_a_shape.ndim() - axes];
}
for (int i = axes; i < inv_a_shape.ndim(); ++i) {
inv_a_transpose_shape[i] = inv_a_shape[i - axes];
}
TShape temp_shape(2 * axes, -1);
for (int i = 0; i < axes; ++i) {
temp_shape[i] = inv_a_transpose_shape[i];
temp_shape[i + axes] = inv_a_transpose_shape[i];
}
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_shape.Size() * sizeof(OType)),
ctx.get_stream<xpu>());
TBlob temp_tblob =
TBlob(reinterpret_cast<OType*>(workspace.dptr_), temp_shape, xpu::kDevMask);
dim_t a1 = 1, a2 = 1;
for (int i = 0; i < axes2; ++i) {
a1 *= inv_a_transpose_shape[i];
}
for (int i = 0; i < axes1; ++i) {
a2 *= inv_a_shape[i];
}
Tensor<xpu, 3, OType> inv_a_tensor =
inv_a.get_with_shape<xpu, 3, OType>(Shape3(1, a2, a1), s);
Tensor<xpu, 3, OType> out_grad_tensor =
out_grad.get_with_shape<xpu, 3, OType>(Shape3(1, a2, a1), s);
Tensor<xpu, 3, OType> temp_tensor =
temp_tblob.get_with_shape<xpu, 3, OType>(Shape3(1, a1, a1), s);
Tensor<xpu, 3, OType> grad_a_tensor =
grad_a.get_with_shape<xpu, 3, OType>(Shape3(1, a1, a2), s);
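// Gradient of a matrix inverse: dL/dA = -inv(A)^T * dL/dinv(A) * inv(A)^T,
// computed below as two GEMMs through the temporary buffer.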
gemm2::op(inv_a_tensor, out_grad_tensor, temp_tensor, OType(1), true, false, s);
gemm2::op(temp_tensor, inv_a_tensor, grad_a_tensor, OType(-1), false, true, s);
} else { // axes >= inv_a_shape.ndim()
dim_t a = 1;
for (int i = 0; i < inv_a_shape.ndim(); ++i) {
a *= inv_a_shape[i];
}
// If ind >= a.ndim(), every axis is flattened to the front, so the
// 'square' requirement forces the total size to be 1.
CHECK_EQ(a, 1U)
<< "a shape must be square, i.e., prod(a.shape[:ind]) == prod(a.shape[ind:]).";
Tensor<xpu, 1, OType> inv_a_tensor =
inv_a.get_with_shape<xpu, 1, OType>(Shape1(1), s);
Tensor<xpu, 1, OType> out_grad_tensor =
out_grad.get_with_shape<xpu, 1, OType>(Shape1(1), s);
Tensor<xpu, 1, OType> grad_a_tensor =
grad_a.get_with_shape<xpu, 1, OType>(Shape1(1), s);
ASSIGN_DISPATCH(grad_a_tensor, kWriteTo,
OType(-1) * inv_a_tensor * out_grad_tensor * inv_a_tensor);
}
});
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORINV_INL_H_
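The two gemm2 calls in TensorinvOpBackward implement the standard matrix-inverse gradient, dL/dA = -inv(A)^T * G * inv(A)^T for upstream gradient G. A quick numerical sanity check of that formula in plain NumPy (a sketch, not part of the commit):

import numpy as np

n = 3
a = np.random.randn(n, n) + n * np.eye(n)   # well-conditioned test matrix
inv_a = np.linalg.inv(a)
g = np.random.randn(n, n)                   # upstream gradient w.r.t. inv(a)
analytic = -inv_a.T @ g @ inv_a.T           # formula used by the backward pass

eps = 1e-6
numeric = np.zeros_like(a)
for i in range(n):
    for j in range(n):
        da = np.zeros_like(a)
        da[i, j] = eps
        numeric[i, j] = np.sum(g * (np.linalg.inv(a + da) - inv_a)) / eps
print(np.allclose(analytic, numeric, atol=1e-4))  # expected: True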