From cc91bde879c1dacbeafd9617700a9f239e4c0b78 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 13 Nov 2019 07:04:28 +0000 Subject: [PATCH 01/11] det test --- python/mxnet/ndarray/numpy/linalg.py | 112 ++++++++++++++++++++++++- python/mxnet/numpy/linalg.py | 112 ++++++++++++++++++++++++- python/mxnet/symbol/numpy/linalg.py | 112 ++++++++++++++++++++++++- src/operator/tensor/la_op.cc | 2 + tests/python/unittest/test_numpy_op.py | 88 +++++++++++++++++++ 5 files changed, 423 insertions(+), 3 deletions(-) diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index d443d68cda0d..87e0e470380a 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -21,7 +21,7 @@ from . import _op as _mx_nd_np from . import _internal as _npi -__all__ = ['norm', 'svd', 'cholesky', 'inv'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet'] def norm(x, ord=None, axis=None, keepdims=False): @@ -242,3 +242,113 @@ def inv(a): [ 0.75000006, -0.25000003]]]) """ return _npi.inv(a) + + +def det(a): + r""" + Compute the determinant of an array. + + Parameters + ---------- + a : (..., M, M) ndarray + Input array to compute determinants for. + + Returns + ------- + det : (...) ndarray + Determinant of `a`. + + See Also + -------- + slogdet : Another way to represent the determinant, more suitable + for large matrices where underflow/overflow may occur. + + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + The determinant is computed via LU factorization using the LAPACK + routine z/dgetrf. + + Examples + -------- + The determinant of a 2-D array [[a, b], [c, d]] is ad - bc: + >>> a = np.array([[1, 2], [3, 4]]) + >>> np.linalg.det(a) + -2.0 + + Computing determinants for a stack of matrices: + >>> a = np.array([ [[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]] ]) + >>> a.shape + (3, 2, 2) + + >>> np.linalg.det(a) + array([-2., -3., -8.]) + """ + return _npi.det(a) + + +def slogdet(a): + r""" + Compute the sign and (natural) logarithm of the determinant of an array. + If an array has a very small or very large determinant, then a call to + `det` may overflow or underflow. This routine is more robust against such + issues, because it computes the logarithm of the determinant rather than + the determinant itself. + + Parameters + ---------- + a : (..., M, M) ndarray + Input array, has to be a square 2-D array. + + Returns + ------- + sign : (...) ndarray + A number representing the sign of the determinant. For a real matrix, + this is 1, 0, or -1. + logdet : (...) array_like + The natural log of the absolute value of the determinant. + If the determinant is zero, then `sign` will be 0 and `logdet` will be + -Inf. In all cases, the determinant is equal to ``sign * np.exp(logdet)``. + + See Also + -------- + det + + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + The determinant is computed via LU factorization using the LAPACK + routine z/dgetrf. 
+ + Examples + -------- + The determinant of a 2-D array ``[[a, b], [c, d]]`` is ``ad - bc``: + >>> a = np.array([[1, 2], [3, 4]]) + >>> (sign, logdet) = np.linalg.slogdet(a) + >>> (sign, logdet) + (-1., 0.69314718055994529) + + >>> sign * np.exp(logdet) + -2.0 + + Computing log-determinants for a stack of matrices: + >>> a = np.array([ [[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]] ]) + >>> a.shape + (3, 2, 2) + + >>> sign, logdet = np.linalg.slogdet(a) + >>> (sign, logdet) + (array([-1., -1., -1.]), array([ 0.69314718, 1.09861229, 2.07944154])) + + >>> sign * np.exp(logdet) + array([-2., -3., -8.]) + + This routine succeeds where ordinary `det` does not: + >>> np.linalg.det(np.eye(500) * 0.1) + 0.0 + >>> np.linalg.slogdet(np.eye(500) * 0.1) + (1., -1151.2925464970228) + """ + return _npi.slogdet(a) diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index b0552ee9f319..0f910a86ffbc 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -20,7 +20,7 @@ from __future__ import absolute_import from ..ndarray import numpy as _mx_nd_np -__all__ = ['norm', 'svd', 'cholesky', 'inv'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet'] def norm(x, ord=None, axis=None, keepdims=False): @@ -260,3 +260,113 @@ def inv(a): [ 0.75000006, -0.25000003]]]) """ return _mx_nd_np.linalg.inv(a) + + +def det(a): + r""" + Compute the determinant of an array. + + Parameters + ---------- + a : (..., M, M) ndarray + Input array to compute determinants for. + + Returns + ------- + det : (...) ndarray + Determinant of `a`. + + See Also + -------- + slogdet : Another way to represent the determinant, more suitable + for large matrices where underflow/overflow may occur. + + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + The determinant is computed via LU factorization using the LAPACK + routine z/dgetrf. + + Examples + -------- + The determinant of a 2-D array [[a, b], [c, d]] is ad - bc: + >>> a = np.array([[1, 2], [3, 4]]) + >>> np.linalg.det(a) + -2.0 + + Computing determinants for a stack of matrices: + >>> a = np.array([ [[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]] ]) + >>> a.shape + (3, 2, 2) + + >>> np.linalg.det(a) + array([-2., -3., -8.]) + """ + return _mx_nd_np.linalg.det(a) + + +def slogdet(a): + r""" + Compute the sign and (natural) logarithm of the determinant of an array. + If an array has a very small or very large determinant, then a call to + `det` may overflow or underflow. This routine is more robust against such + issues, because it computes the logarithm of the determinant rather than + the determinant itself. + + Parameters + ---------- + a : (..., M, M) ndarray + Input array, has to be a square 2-D array. + + Returns + ------- + sign : (...) ndarray + A number representing the sign of the determinant. For a real matrix, + this is 1, 0, or -1. + logdet : (...) array_like + The natural log of the absolute value of the determinant. + If the determinant is zero, then `sign` will be 0 and `logdet` will be + -Inf. In all cases, the determinant is equal to ``sign * np.exp(logdet)``. + + See Also + -------- + det + + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + The determinant is computed via LU factorization using the LAPACK + routine z/dgetrf. 
+ + Examples + -------- + The determinant of a 2-D array ``[[a, b], [c, d]]`` is ``ad - bc``: + >>> a = np.array([[1, 2], [3, 4]]) + >>> (sign, logdet) = np.linalg.slogdet(a) + >>> (sign, logdet) + (-1., 0.69314718055994529) + + >>> sign * np.exp(logdet) + -2.0 + + Computing log-determinants for a stack of matrices: + >>> a = np.array([ [[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]] ]) + >>> a.shape + (3, 2, 2) + + >>> sign, logdet = np.linalg.slogdet(a) + >>> (sign, logdet) + (array([-1., -1., -1.]), array([ 0.69314718, 1.09861229, 2.07944154])) + + >>> sign * np.exp(logdet) + array([-2., -3., -8.]) + + This routine succeeds where ordinary `det` does not: + >>> np.linalg.det(np.eye(500) * 0.1) + 0.0 + >>> np.linalg.slogdet(np.eye(500) * 0.1) + (1., -1151.2925464970228) + """ + return _mx_nd_np.linalg.slogdet(a) diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py index 3c8118550559..06af0c90d3b2 100644 --- a/python/mxnet/symbol/numpy/linalg.py +++ b/python/mxnet/symbol/numpy/linalg.py @@ -22,7 +22,7 @@ from . import _op as _mx_sym_np from . import _internal as _npi -__all__ = ['norm', 'svd', 'cholesky', 'inv'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet'] def norm(x, ord=None, axis=None, keepdims=False): @@ -229,3 +229,113 @@ def inv(a): [ 0.75000006, -0.25000003]]]) """ return _npi.inv(a) + + +def det(a): + r""" + Compute the determinant of an array. + + Parameters + ---------- + a : (..., M, M) ndarray + Input array to compute determinants for. + + Returns + ------- + det : (...) ndarray + Determinant of `a`. + + See Also + -------- + slogdet : Another way to represent the determinant, more suitable + for large matrices where underflow/overflow may occur. + + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + The determinant is computed via LU factorization using the LAPACK + routine z/dgetrf. + + Examples + -------- + The determinant of a 2-D array [[a, b], [c, d]] is ad - bc: + >>> a = np.array([[1, 2], [3, 4]]) + >>> np.linalg.det(a) + -2.0 + + Computing determinants for a stack of matrices: + >>> a = np.array([ [[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]] ]) + >>> a.shape + (3, 2, 2) + + >>> np.linalg.det(a) + array([-2., -3., -8.]) + """ + return _npi.det(a) + + +def slogdet(a): + r""" + Compute the sign and (natural) logarithm of the determinant of an array. + If an array has a very small or very large determinant, then a call to + `det` may overflow or underflow. This routine is more robust against such + issues, because it computes the logarithm of the determinant rather than + the determinant itself. + + Parameters + ---------- + a : (..., M, M) ndarray + Input array, has to be a square 2-D array. + + Returns + ------- + sign : (...) ndarray + A number representing the sign of the determinant. For a real matrix, + this is 1, 0, or -1. + logdet : (...) array_like + The natural log of the absolute value of the determinant. + If the determinant is zero, then `sign` will be 0 and `logdet` will be + -Inf. In all cases, the determinant is equal to ``sign * np.exp(logdet)``. + + See Also + -------- + det + + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + The determinant is computed via LU factorization using the LAPACK + routine z/dgetrf. 
+ + Examples + -------- + The determinant of a 2-D array ``[[a, b], [c, d]]`` is ``ad - bc``: + >>> a = np.array([[1, 2], [3, 4]]) + >>> (sign, logdet) = np.linalg.slogdet(a) + >>> (sign, logdet) + (-1., 0.69314718055994529) + + >>> sign * np.exp(logdet) + -2.0 + + Computing log-determinants for a stack of matrices: + >>> a = np.array([ [[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]] ]) + >>> a.shape + (3, 2, 2) + + >>> sign, logdet = np.linalg.slogdet(a) + >>> (sign, logdet) + (array([-1., -1., -1.]), array([ 0.69314718, 1.09861229, 2.07944154])) + + >>> sign * np.exp(logdet) + array([-2., -3., -8.]) + + This routine succeeds where ordinary `det` does not: + >>> np.linalg.det(np.eye(500) * 0.1) + 0.0 + >>> np.linalg.slogdet(np.eye(500) * 0.1) + (1., -1151.2925464970228) + """ + return _npi.slogdet(a) diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 8f9da23711a0..8407307bdd6d 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -945,6 +945,7 @@ NNVM_REGISTER_OP(_backward_linalg_inverse) NNVM_REGISTER_OP(_linalg_det) .add_alias("linalg_det") +.add_alias("_npi_det") .describe(R"code(Compute the determinant of a matrix. Input is a tensor *A* of dimension *n >= 2*. @@ -997,6 +998,7 @@ NNVM_REGISTER_OP(_backward_linalg_det) NNVM_REGISTER_OP(_linalg_slogdet) .add_alias("linalg_slogdet") +.add_alias("_npi_slogdet") .describe(R"code(Compute the sign and log of the determinant of a matrix. Input is a tensor *A* of dimension *n >= 2*. diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 9aabdfd4cabc..f36d259601cb 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3222,6 +3222,94 @@ def check_inv(A_inv, data_np): check_inv(A_inv, data_np) +def test_np_linalg_det(): + class TestDet(HybridBlock): + def __init__(self): + super(TestDet, self).__init__() + + def hybrid_forward(self, F, a): + return F.np.linalg.det(a) + + # test non zero size input + tensor_shapes = [ + (5, 5), + (3, 3, 3), + (2, 2, 2, 2, 2), + (1, 1) + ] + + for hybridize in [True, False]: + for shape in tensor_shapes: + for dtype in [_np.float32, _np.float64]: + a_shape = (1,) + shape + test_det = TestDet() + if hybridize: + test_det.hybridize() + a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray() + a.attach_grad() + + np_out = _np.linalg.det(a.asnumpy()) + with mx.autograd.record(): + mx_out = test_det(a) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1) + mx_out.backward() + + # Test imperative once again + mx_out = np.linalg.det(a) + np_out = _np.linalg.det(a.asnumpy()) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1) + + # test numeric gradient + a_sym = mx.sym.Variable("a").as_np_ndarray() + mx_sym = mx.sym.np.linalg.det(a_sym).as_nd_ndarray() + check_numeric_gradient(mx_sym, [a.as_nd_ndarray()], + rtol=1e-1, atol=1e-1, dtype=dtype) + +""" +@with_seed() +@use_np +def test_np_linalg_slogdet(): + class TestSlogdet(HybridBlock): + def __init__(self): + super(TestSlogdet, self).__init__() + + def hybrid_forward(self, F, a): + return F.np.linalg.slogdet(a) + + # test non zero size input + tensor_shapes = [ + (5, 5), + (3, 3, 3), + (2, 2, 2, 2, 2), + (1, 1) + ] + + for hybridize in [True, False]: + for a_shape in tensor_shapes: + for dtype in [_np.float32, _np.float64]: + test_slogdet = TestSlogdet() + if hybridize: + test_slogdet.hybridize() + a = rand_ndarray(shape=a_shape, 
dtype=dtype).as_np_ndarray() + a.attach_grad() + + np_out = _np.linalg.slogdet(a.asnumpy()) + with mx.autograd.record(): + mx_out = test_slogdet(a) + assert mx_out[0].shape == np_out[0].shape + assert mx_out[1].shape == np_out[1].shape + assert_almost_equal(mx_out[0].asnumpy(), np_out[0], rtol = 1e-1, atol = 1e-1) + assert_almost_equal(mx_out[1].asnumpy(), np_out[1], rtol = 1e-1, atol = 1e-1) + mx_out[1].backward() + + # Test imperative once again + mx_out = np.linalg.slogdet(a) + np_out = _np.linalg.slogdet(a.asnumpy()) + assert_almost_equal(mx_out[0].asnumpy(), np_out[0], rtol=1e-1, atol=1e-1) + assert_almost_equal(mx_out[1].asnumpy(), np_out[1], rtol=1e-1, atol=1e-1) +""" + @with_seed() @use_np def test_np_vstack(): From c10a4cb9f2e32085f2d1ab009efe5edcf72c37b8 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 13 Nov 2019 11:37:36 +0000 Subject: [PATCH 02/11] slogdet test --- python/mxnet/ndarray/numpy/linalg.py | 10 +++++----- python/mxnet/numpy/linalg.py | 10 +++++----- python/mxnet/symbol/numpy/linalg.py | 10 +++++----- src/operator/tensor/la_op.h | 7 ++++++- tests/python/unittest/test_numpy_op.py | 4 ++-- 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index 87e0e470380a..74ba41f22979 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -295,12 +295,12 @@ def slogdet(a): `det` may overflow or underflow. This routine is more robust against such issues, because it computes the logarithm of the determinant rather than the determinant itself. - + Parameters ---------- a : (..., M, M) ndarray Input array, has to be a square 2-D array. - + Returns ------- sign : (...) ndarray @@ -310,18 +310,18 @@ def slogdet(a): The natural log of the absolute value of the determinant. If the determinant is zero, then `sign` will be 0 and `logdet` will be -Inf. In all cases, the determinant is equal to ``sign * np.exp(logdet)``. - + See Also -------- det - + Notes ----- Broadcasting rules apply, see the `numpy.linalg` documentation for details. The determinant is computed via LU factorization using the LAPACK routine z/dgetrf. - + Examples -------- The determinant of a 2-D array ``[[a, b], [c, d]]`` is ``ad - bc``: diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index 0f910a86ffbc..fbe3631eb6e6 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -313,12 +313,12 @@ def slogdet(a): `det` may overflow or underflow. This routine is more robust against such issues, because it computes the logarithm of the determinant rather than the determinant itself. - + Parameters ---------- a : (..., M, M) ndarray Input array, has to be a square 2-D array. - + Returns ------- sign : (...) ndarray @@ -328,18 +328,18 @@ def slogdet(a): The natural log of the absolute value of the determinant. If the determinant is zero, then `sign` will be 0 and `logdet` will be -Inf. In all cases, the determinant is equal to ``sign * np.exp(logdet)``. - + See Also -------- det - + Notes ----- Broadcasting rules apply, see the `numpy.linalg` documentation for details. The determinant is computed via LU factorization using the LAPACK routine z/dgetrf. 
- + Examples -------- The determinant of a 2-D array ``[[a, b], [c, d]]`` is ``ad - bc``: diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py index 06af0c90d3b2..cf33777b2637 100644 --- a/python/mxnet/symbol/numpy/linalg.py +++ b/python/mxnet/symbol/numpy/linalg.py @@ -282,12 +282,12 @@ def slogdet(a): `det` may overflow or underflow. This routine is more robust against such issues, because it computes the logarithm of the determinant rather than the determinant itself. - + Parameters ---------- a : (..., M, M) ndarray Input array, has to be a square 2-D array. - + Returns ------- sign : (...) ndarray @@ -297,18 +297,18 @@ def slogdet(a): The natural log of the absolute value of the determinant. If the determinant is zero, then `sign` will be 0 and `logdet` will be -Inf. In all cases, the determinant is equal to ``sign * np.exp(logdet)``. - + See Also -------- det - + Notes ----- Broadcasting rules apply, see the `numpy.linalg` documentation for details. The determinant is computed via LU factorization using the LAPACK routine z/dgetrf. - + Examples -------- The determinant of a 2-D array ``[[a, b], [c, d]]`` is ``ad - bc``: diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index e024693e3819..bb56dc52b594 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -26,6 +26,7 @@ #define MXNET_OPERATOR_TENSOR_LA_OP_H_ #include +#include #include #include #include "../mshadow_op.h" @@ -428,7 +429,11 @@ inline bool DetShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in[ndim-2], in[ndim-1]) << "Input A's last two dimension must be equal"; mxnet::TShape out; if (ndim == 2) { - out = mxnet::TShape(1, 1); + if (Imperative::Get()->is_np_shape()) { + out = mxnet::TShape(0, 1); + } else { + out = mxnet::TShape(1, 1); + } } else { out = mxnet::TShape(in.begin(), in.end() - 2); } diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index f36d259601cb..d112bb107d2c 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3266,7 +3266,7 @@ def hybrid_forward(self, F, a): check_numeric_gradient(mx_sym, [a.as_nd_ndarray()], rtol=1e-1, atol=1e-1, dtype=dtype) -""" + @with_seed() @use_np def test_np_linalg_slogdet(): @@ -3308,7 +3308,7 @@ def hybrid_forward(self, F, a): np_out = _np.linalg.slogdet(a.asnumpy()) assert_almost_equal(mx_out[0].asnumpy(), np_out[0], rtol=1e-1, atol=1e-1) assert_almost_equal(mx_out[1].asnumpy(), np_out[1], rtol=1e-1, atol=1e-1) -""" + @with_seed() @use_np From 2ab582a0402a56b01e79630c8d2ae900e3ba7d7b Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 14 Nov 2019 08:19:30 +0000 Subject: [PATCH 03/11] add more tests --- src/operator/tensor/la_op-inl.h | 12 +++ src/operator/tensor/la_op.h | 3 + .../unittest/test_numpy_interoperability.py | 12 +++ tests/python/unittest/test_numpy_op.py | 95 ++++++++++--------- 4 files changed, 77 insertions(+), 45 deletions(-) diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index 35edbbd42565..d580cced4ec5 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -502,6 +502,9 @@ struct det { static void op(const Tensor& A, const Tensor& det, const Tensor& LU, const Tensor& pivot, const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + if (A.shape_.Size() == 0U) { + return; + } Stream *s = ctx.get_stream(); Tensor sign = ctx.requested[0] .get_space_typed(det.shape_, s); @@ -524,6 +527,9 @@ struct slogdet { const Tensor& logabsdet, const 
Tensor& LU, const Tensor& pivot, const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + if (A.shape_.Size() == 0U) { + return; + } Stream *s = ctx.get_stream(); Copy(LU, A, s); linalg_batch_getrf(LU, pivot, false, s); @@ -921,6 +927,9 @@ struct det_backward { using namespace mshadow; using namespace mshadow::expr; using namespace mxnet_op; + if (dA.shape_.Size() == 0U) { + return; + } // compute inverse(A) and stores it to LU linalg_batch_det_backward_helper(LU, pivot, det, dA, DType(0), ctx); const_cast&>(dA) = broadcast_to(reshape(det * ddet, \ @@ -949,6 +958,9 @@ struct slogdet_backward { using namespace mshadow; using namespace mshadow::expr; using namespace mxnet_op; + if (dA.shape_.Size() == 0U) { + return; + } // compute inverse(A) and stores it to LU linalg_batch_det_backward_helper(LU, pivot, logabsdet, dA, DType(-INFINITY), ctx); const_cast&>(dA) = broadcast_to(reshape(dlogabsdet, \ diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index bb56dc52b594..09fcca1bb26c 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -434,6 +434,9 @@ inline bool DetShape(const nnvm::NodeAttrs& attrs, } else { out = mxnet::TShape(1, 1); } + if (in.Size() == 0U) { + out = mxnet::TShape(0, -1); + } } else { out = mxnet::TShape(in.begin(), in.end() - 2); } diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 6d9c63f9f857..600c62b2d6f5 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -278,6 +278,16 @@ def _add_workload_linalg_inv(): OpArgMngr.add_workload('linalg.inv', np.array(_np.ones((0, 1, 1)), dtype=np.float64)) +def _add_workload_linalg_det(): + OpArgMngr.add_workload('linalg.det', np.array(_np.ones((2, 2)), dtype=np.float32)) + OpArgMngr.add_workload('linalg.det', np.array(_np.ones((0, 1, 1)), dtype=np.float64)) + + +def _add_workload_linalg_slogdet(): + OpArgMngr.add_workload('linalg.slogdet', np.array(_np.ones((2, 2)), dtype=np.float32)) + OpArgMngr.add_workload('linalg.slogdet', np.array(_np.ones((0, 1, 1)), dtype=np.float64)) + + def _add_workload_trace(): OpArgMngr.add_workload('trace', np.random.uniform(size=(4, 1))) OpArgMngr.add_workload('trace', np.random.uniform(size=(3, 2))) @@ -1222,6 +1232,8 @@ def _prepare_workloads(): _add_workload_linalg_norm() _add_workload_linalg_cholesky() _add_workload_linalg_inv() + _add_workload_linalg_det() + _add_workload_linalg_slogdet() _add_workload_trace() _add_workload_tril() _add_workload_outer() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index d112bb107d2c..0c3fa062a39f 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3222,6 +3222,8 @@ def check_inv(A_inv, data_np): check_inv(A_inv, data_np) +@with_seed() +@use_np def test_np_linalg_det(): class TestDet(HybridBlock): def __init__(self): @@ -3233,38 +3235,39 @@ def hybrid_forward(self, F, a): # test non zero size input tensor_shapes = [ (5, 5), + (0, 2, 2), + (2, 0, 2, 2), (3, 3, 3), (2, 2, 2, 2, 2), (1, 1) ] + types = [_np.float32, _np.float64] - for hybridize in [True, False]: - for shape in tensor_shapes: - for dtype in [_np.float32, _np.float64]: - a_shape = (1,) + shape - test_det = TestDet() - if hybridize: - test_det.hybridize() - a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray() - a.attach_grad() + for hybridize, dtype, shape in itertools.product([True, False], types, tensor_shapes): 
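+        # sweep every (hybridize, dtype, shape) combination; itertools.product flattens the former nested loops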
+ a_shape = (1,) + shape + test_det = TestDet() + if hybridize: + test_det.hybridize() + a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray() + a.attach_grad() - np_out = _np.linalg.det(a.asnumpy()) - with mx.autograd.record(): - mx_out = test_det(a) - assert mx_out.shape == np_out.shape - assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1) - mx_out.backward() + np_out = _np.linalg.det(a.asnumpy()) + with mx.autograd.record(): + mx_out = test_det(a) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1) + #mx_out.backward() - # Test imperative once again - mx_out = np.linalg.det(a) - np_out = _np.linalg.det(a.asnumpy()) - assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1) + # Test imperative once again + mx_out = np.linalg.det(a) + np_out = _np.linalg.det(a.asnumpy()) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1) - # test numeric gradient - a_sym = mx.sym.Variable("a").as_np_ndarray() - mx_sym = mx.sym.np.linalg.det(a_sym).as_nd_ndarray() - check_numeric_gradient(mx_sym, [a.as_nd_ndarray()], - rtol=1e-1, atol=1e-1, dtype=dtype) + # test numeric gradient + a_sym = mx.sym.Variable("a").as_np_ndarray() + mx_sym = mx.sym.np.linalg.det(a_sym).as_nd_ndarray() + check_numeric_gradient(mx_sym, [a.as_nd_ndarray()], + rtol=1e-1, atol=1e-1, dtype=dtype) @with_seed() @@ -3281,33 +3284,35 @@ def hybrid_forward(self, F, a): tensor_shapes = [ (5, 5), (3, 3, 3), + (0, 2, 2), + (2, 0, 2, 2), (2, 2, 2, 2, 2), (1, 1) ] - for hybridize in [True, False]: - for a_shape in tensor_shapes: - for dtype in [_np.float32, _np.float64]: - test_slogdet = TestSlogdet() - if hybridize: - test_slogdet.hybridize() - a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray() - a.attach_grad() + types = [_np.float32, _np.float64] - np_out = _np.linalg.slogdet(a.asnumpy()) - with mx.autograd.record(): - mx_out = test_slogdet(a) - assert mx_out[0].shape == np_out[0].shape - assert mx_out[1].shape == np_out[1].shape - assert_almost_equal(mx_out[0].asnumpy(), np_out[0], rtol = 1e-1, atol = 1e-1) - assert_almost_equal(mx_out[1].asnumpy(), np_out[1], rtol = 1e-1, atol = 1e-1) - mx_out[1].backward() + for hybridize, a_shape, dtype in itertools.product([True, False], tensor_shapes, types): + test_slogdet = TestSlogdet() + if hybridize: + test_slogdet.hybridize() + a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray() + a.attach_grad() - # Test imperative once again - mx_out = np.linalg.slogdet(a) - np_out = _np.linalg.slogdet(a.asnumpy()) - assert_almost_equal(mx_out[0].asnumpy(), np_out[0], rtol=1e-1, atol=1e-1) - assert_almost_equal(mx_out[1].asnumpy(), np_out[1], rtol=1e-1, atol=1e-1) + np_out = _np.linalg.slogdet(a.asnumpy()) + with mx.autograd.record(): + mx_out = test_slogdet(a) + assert mx_out[0].shape == np_out[0].shape + assert mx_out[1].shape == np_out[1].shape + assert_almost_equal(mx_out[0].asnumpy(), np_out[0], rtol = 1e-1, atol = 1e-1) + assert_almost_equal(mx_out[1].asnumpy(), np_out[1], rtol = 1e-1, atol = 1e-1) + mx_out[1].backward() + + # Test imperative once again + mx_out = np.linalg.slogdet(a) + np_out = _np.linalg.slogdet(a.asnumpy()) + assert_almost_equal(mx_out[0].asnumpy(), np_out[0], rtol=1e-1, atol=1e-1) + assert_almost_equal(mx_out[1].asnumpy(), np_out[1], rtol=1e-1, atol=1e-1) @with_seed() From 6d02b534c352bb05f4d417f4d527be093751e23f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 15 Nov 2019 02:58:20 +0000 Subject: [PATCH 04/11] interoperability test --- 
src/operator/tensor/la_op-inl.h | 3 +-- tests/python/unittest/test_numpy_op.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index d580cced4ec5..b57c6bd08299 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -502,9 +502,8 @@ struct det { static void op(const Tensor& A, const Tensor& det, const Tensor& LU, const Tensor& pivot, const OpContext& ctx, const nnvm::NodeAttrs& attrs) { - if (A.shape_.Size() == 0U) { + if (A.shape_.Size() == 0U) return; - } Stream *s = ctx.get_stream(); Tensor sign = ctx.requested[0] .get_space_typed(det.shape_, s); diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index a60f36484b54..7802a2b7c379 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3265,29 +3265,32 @@ def hybrid_forward(self, F, a): # test non zero size input tensor_shapes = [ - (5, 5), - (0, 2, 2), (2, 0, 2, 2), + (5, 5), + (0, 2, 2, 2), (3, 3, 3), + (0, 2, 2), (2, 2, 2, 2, 2), - (1, 1) + (1, 1), ] types = [_np.float32, _np.float64] - for hybridize, dtype, shape in itertools.product([True, False], types, tensor_shapes): + #for hybridize, dtype, shape in itertools.product([True, False], types, tensor_shapes): + hybridize = False + dtype = _np.float32 + for shape in tensor_shapes: a_shape = (1,) + shape test_det = TestDet() if hybridize: test_det.hybridize() a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray() a.attach_grad() - np_out = _np.linalg.det(a.asnumpy()) with mx.autograd.record(): mx_out = test_det(a) assert mx_out.shape == np_out.shape assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1) - #mx_out.backward() + mx_out.backward() # Test imperative once again mx_out = np.linalg.det(a) @@ -3297,8 +3300,9 @@ def hybrid_forward(self, F, a): # test numeric gradient a_sym = mx.sym.Variable("a").as_np_ndarray() mx_sym = mx.sym.np.linalg.det(a_sym).as_nd_ndarray() - check_numeric_gradient(mx_sym, [a.as_nd_ndarray()], - rtol=1e-1, atol=1e-1, dtype=dtype) + if 0 not in shape: + check_numeric_gradient(mx_sym, [a.as_nd_ndarray()], + rtol=1e-1, atol=1e-1, dtype=dtype) @with_seed() From b0ed6c9b5df48ab94f404224f46f3e6506ea0341 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 15 Nov 2019 03:03:12 +0000 Subject: [PATCH 05/11] fix tests --- tests/python/unittest/test_numpy_op.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 7802a2b7c379..6954809788e0 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3275,10 +3275,7 @@ def hybrid_forward(self, F, a): ] types = [_np.float32, _np.float64] - #for hybridize, dtype, shape in itertools.product([True, False], types, tensor_shapes): - hybridize = False - dtype = _np.float32 - for shape in tensor_shapes: + for hybridize, dtype, shape in itertools.product([True, False], types, tensor_shapes): a_shape = (1,) + shape test_det = TestDet() if hybridize: From d23a1e339f11d1601dfe83b749219bfd5959dc79 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 16 Nov 2019 08:30:27 +0000 Subject: [PATCH 06/11] beautify --- src/operator/tensor/la_op-inl.h | 3 ++- src/operator/tensor/la_op.h | 5 +---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index b57c6bd08299..d580cced4ec5 100644 --- 
a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -502,8 +502,9 @@ struct det { static void op(const Tensor& A, const Tensor& det, const Tensor& LU, const Tensor& pivot, const OpContext& ctx, const nnvm::NodeAttrs& attrs) { - if (A.shape_.Size() == 0U) + if (A.shape_.Size() == 0U) { return; + } Stream *s = ctx.get_stream(); Tensor sign = ctx.requested[0] .get_space_typed(det.shape_, s); diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index 09fcca1bb26c..b5716744fcdd 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -429,14 +429,11 @@ inline bool DetShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in[ndim-2], in[ndim-1]) << "Input A's last two dimension must be equal"; mxnet::TShape out; if (ndim == 2) { - if (Imperative::Get()->is_np_shape()) { + if (Imperative::Get()->is_np_shape() || in.Size() == 0U) { out = mxnet::TShape(0, 1); } else { out = mxnet::TShape(1, 1); } - if (in.Size() == 0U) { - out = mxnet::TShape(0, -1); - } } else { out = mxnet::TShape(in.begin(), in.end() - 2); } From 0c3443a6b0ab549f9c306bfc62ba5bd30e6ce82b Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Nov 2019 09:04:18 +0000 Subject: [PATCH 07/11] fix slogdet --- src/operator/tensor/la_op.h | 7 ++-- tests/python/unittest/test_numpy_op.py | 47 ++++++++++++++++++++------ 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index b5716744fcdd..320338f42da0 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -901,19 +901,22 @@ void LaOpDetBackward(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { using namespace mshadow; + if (outputs[0].shape_.Size() == 0U) { + return; + } Stream *s = ctx.get_stream(); CHECK_EQ(inputs.size(), onum + 3); CHECK_EQ(outputs.size(), 1); MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { std::vector tspace(outputs); - for ( int i = 0; i < onum; ++i ) { + for ( int i = 0; i < outputs.size(); ++i ) { if ( req[i] == kAddTo ) { tspace[i].dptr_ = ctx.requested[0] .get_space_typed(Shape1(outputs[i].Size()), s).dptr_; } } LaOpDetBackwardCaller::op(inputs, tspace, attrs, ctx); - for ( int i = 0; i < onum; ++i ) { + for ( int i = 0; i < outputs.size(); ++i ) { if ( req[i] == kAddTo ) { Tensor out = outputs[i].FlatTo1D(s); out += tspace[i].FlatTo1D(s); diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 0aef8d9e7a68..5d57110f1dcf 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3329,20 +3329,23 @@ def hybrid_forward(self, F, a): (1, 1), ] types = [_np.float32, _np.float64] + grad_reqs = ['write', 'add', 'null'] - for hybridize, dtype, shape in itertools.product([True, False], types, tensor_shapes): + for hybridize, dtype, shape, grad_req in itertools.product([True, False], types, tensor_shapes, grad_reqs): a_shape = (1,) + shape test_det = TestDet() if hybridize: test_det.hybridize() a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray() - a.attach_grad() + a.attach_grad(grad_req) np_out = _np.linalg.det(a.asnumpy()) with mx.autograd.record(): mx_out = test_det(a) assert mx_out.shape == np_out.shape assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1) - mx_out.backward() + if grad_req != 'null': + print(shape, grad_req) + mx_out.backward() # Test imperative once again mx_out = np.linalg.det(a) @@ -3352,7 +3355,7 @@ def hybrid_forward(self, F, a): # test numeric gradient 
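         # check_numeric_gradient perturbs the input and compares the registered backward pass against finite differences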
a_sym = mx.sym.Variable("a").as_np_ndarray() mx_sym = mx.sym.np.linalg.det(a_sym).as_nd_ndarray() - if 0 not in shape: + if 0 not in shape and grad_req != 'null': check_numeric_gradient(mx_sym, [a.as_nd_ndarray()], rtol=1e-1, atol=1e-1, dtype=dtype) @@ -3369,31 +3372,34 @@ def hybrid_forward(self, F, a): # test non zero size input tensor_shapes = [ + (2, 0, 2, 2), (5, 5), + (0, 2, 2, 2), (3, 3, 3), (0, 2, 2), - (2, 0, 2, 2), (2, 2, 2, 2, 2), - (1, 1) + (1, 1), ] types = [_np.float32, _np.float64] + grad_reqs = ['write', 'add', 'null'] - for hybridize, a_shape, dtype in itertools.product([True, False], tensor_shapes, types): + for hybridize, a_shape, dtype, grad_req in itertools.product([True, False], tensor_shapes, types, grad_reqs): test_slogdet = TestSlogdet() if hybridize: test_slogdet.hybridize() a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray() - a.attach_grad() + a.attach_grad(grad_req) np_out = _np.linalg.slogdet(a.asnumpy()) with mx.autograd.record(): mx_out = test_slogdet(a) assert mx_out[0].shape == np_out[0].shape assert mx_out[1].shape == np_out[1].shape - assert_almost_equal(mx_out[0].asnumpy(), np_out[0], rtol = 1e-1, atol = 1e-1) - assert_almost_equal(mx_out[1].asnumpy(), np_out[1], rtol = 1e-1, atol = 1e-1) - mx_out[1].backward() + assert_almost_equal(mx_out[0].asnumpy(), np_out[0], rtol=1e-1, atol=1e-1) + assert_almost_equal(mx_out[1].asnumpy(), np_out[1], rtol=1e-1, atol=1e-1) + if grad_req != 'null': + mx_out[1].backward() # Test imperative once again mx_out = np.linalg.slogdet(a) @@ -3401,6 +3407,13 @@ def hybrid_forward(self, F, a): assert_almost_equal(mx_out[0].asnumpy(), np_out[0], rtol=1e-1, atol=1e-1) assert_almost_equal(mx_out[1].asnumpy(), np_out[1], rtol=1e-1, atol=1e-1) + # test numeric gradient + a_sym = mx.sym.Variable("a").as_np_ndarray() + mx_sym = mx.sym.np.linalg.det(a_sym).as_nd_ndarray() + if 0 not in a_shape and grad_req != 'null': + check_numeric_gradient(mx_sym, [a.as_nd_ndarray()], + rtol=1e-1, atol=1e-1, dtype=dtype) + @with_seed() @use_np @@ -4686,6 +4699,18 @@ def hybrid_forward(self, F, a): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) + +@with_seed() +@use_np +def test_np_linalg_debug(): + #test_np_linalg_cholesky() + test_np_linalg_det() + test_np_linalg_inv() + test_np_linalg_norm() + test_np_linalg_slogdet() + test_np_linalg_svd() + + if __name__ == '__main__': import nose nose.runmodule() From 3fdd588753671bb66eaa1cd96728cdfaf860fc62 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Nov 2019 09:16:21 +0000 Subject: [PATCH 08/11] remove slogdet numeric gradient test because it doesn't have that in the first place --- tests/python/unittest/test_numpy_op.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 309e8aa1cd60..810035833bb7 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3467,13 +3467,6 @@ def hybrid_forward(self, F, a): assert_almost_equal(mx_out[0].asnumpy(), np_out[0], rtol=1e-1, atol=1e-1) assert_almost_equal(mx_out[1].asnumpy(), np_out[1], rtol=1e-1, atol=1e-1) - # test numeric gradient - a_sym = mx.sym.Variable("a").as_np_ndarray() - mx_sym = mx.sym.np.linalg.det(a_sym).as_nd_ndarray() - if 0 not in a_shape and grad_req != 'null': - check_numeric_gradient(mx_sym, [a.as_nd_ndarray()], - rtol=1e-1, atol=1e-1, dtype=dtype) - @with_seed() @use_np From f58c7b5b7a04dc9e6f307d8d6807b17089900108 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: 
Wed, 20 Nov 2019 12:45:03 +0000 Subject: [PATCH 09/11] why CI's dead --- tests/python/unittest/test_numpy_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 810035833bb7..a65a913d8f57 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3381,7 +3381,7 @@ def hybrid_forward(self, F, a): # test non zero size input tensor_shapes = [ (2, 0, 2, 2), - (5, 5), + (4, 4), (0, 2, 2, 2), (3, 3, 3), (0, 2, 2), From 8d268fff66184a8b8b51bb2c46ce327e8f879ea1 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 21 Nov 2019 05:03:05 +0000 Subject: [PATCH 10/11] CI please --- tests/python/unittest/test_numpy_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index c953f601a652..9f9da587b6ba 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3440,7 +3440,7 @@ def hybrid_forward(self, F, a): (5, 5), (0, 2, 2, 2), (3, 3, 3), - (0, 2, 2), + (0, 3, 3), (2, 2, 2, 2, 2), (1, 1), ] From f8202003c482d058f85b5729752bf938c57e3f40 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 22 Nov 2019 04:12:20 +0000 Subject: [PATCH 11/11] size_t --- src/operator/tensor/la_op.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index 320338f42da0..5fe7a92e2a12 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -909,14 +909,14 @@ void LaOpDetBackward(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 1); MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { std::vector tspace(outputs); - for ( int i = 0; i < outputs.size(); ++i ) { + for ( size_t i = 0; i < outputs.size(); ++i ) { if ( req[i] == kAddTo ) { tspace[i].dptr_ = ctx.requested[0] .get_space_typed(Shape1(outputs[i].Size()), s).dptr_; } } LaOpDetBackwardCaller::op(inputs, tspace, attrs, ctx); - for ( int i = 0; i < outputs.size(); ++i ) { + for ( size_t i = 0; i < outputs.size(); ++i ) { if ( req[i] == kAddTo ) { Tensor out = outputs[i].FlatTo1D(s); out += tspace[i].FlatTo1D(s);
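
A minimal end-to-end sketch of the API this series adds, for reference. This is not part of the patches themselves; it assumes an MXNet build that includes this series, and `npx.set_np()` is assumed necessary so that 2-D inputs yield 0-dim scalar outputs, per the `DetShape` change in patch 02:

    import mxnet as mx
    from mxnet import np, npx
    npx.set_np()  # NumPy-compatible shape semantics (0-dim scalar tensors)

    a = np.array([[1., 2.], [3., 4.]])

    # determinant, same convention as numpy.linalg.det
    d = np.linalg.det(a)                      # -2.0

    # sign and log|det|, robust where det itself would over/underflow
    sign, logabsdet = np.linalg.slogdet(a)
    assert abs(float(sign * np.exp(logabsdet)) - float(d)) < 1e-5

    # gradient through det follows Jacobi's formula,
    #   d det(A) / dA = det(A) * inv(A).T,
    # which is the quantity the det_backward struct above produces
    a.attach_grad()
    with mx.autograd.record():
        out = np.linalg.det(a)
    out.backward()
    expected = float(d) * np.linalg.inv(a).T
    # a.grad should match `expected` up to numerical tolerance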