diff --git a/benchmark/python/einsum/benchmark_einsum.py b/benchmark/python/einsum/benchmark_einsum.py new file mode 100644 index 000000000000..3593de2db9e1 --- /dev/null +++ b/benchmark/python/einsum/benchmark_einsum.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import time +import mxnet as mx +from mxnet import np, npx + +def measure_cost(repeat, func_name, *args, **kwargs): + """Measure the average time cost of running a function """ + mx.nd.waitall() + start = time.time() + for _ in range(repeat): + func_name(*args, **kwargs) + mx.nd.waitall() + end = time.time() + diff = end - start + return diff / repeat + + +def test_np_einsum(): + print("Path optimization test:") + # Basic einsum + a = np.ones(64).reshape(2,4,8) + args = ['ijk,ilm,njm,nlk,abc->', a, a, a, a, a] + cost = measure_cost(500, np.einsum, *args) + print("Basic einsum: {} ms".format(cost * 1000)) + + # Optimal einsum (disabled: optimize='optimal' is not supported yet) + # cost = measure_cost(500, np.einsum, *args, optimize='optimal') + # print("Optimal einsum: {} ms".format(cost * 1000)) + + # Greedy einsum + cost = measure_cost(500, np.einsum, *args, optimize=True) + print("Greedy einsum: {} ms".format(cost * 1000)) + + print('Inner Product:') + a = np.ones(6000000) + b = np.ones(6000000) + args = [a, b] + cost = measure_cost(50, np.tensordot, *args, axes=([0],[0])) + print('Tensordot: {} ms'.format(cost * 1000)) + args = ['i, i', a, b] + cost = measure_cost(50, np.einsum, *args, optimize=True) + print('Greedy einsum: {} ms'.format(cost * 1000)) + cost = measure_cost(50, np.einsum, *args) + print('Basic einsum: {} ms'.format(cost * 1000)) + + print('Matrix Product:') + a = np.ones(600000).reshape(200, 3000) + b = np.ones(600000).reshape(3000, 200) + args = [a, b] + cost = measure_cost(50, np.tensordot, *args, axes=([1],[0])) + print('Tensordot: {} ms'.format(cost * 1000)) + args = ['ij, jk', a, b] + cost = measure_cost(50, np.einsum, *args, optimize=True) + print('Greedy einsum: {} ms'.format(cost * 1000)) + cost = measure_cost(50, np.einsum, *args) + print('Basic einsum: {} ms'.format(cost * 1000)) + + +if __name__ == "__main__": + npx.set_np() + test_np_einsum() diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 08bcdb7833fb..ad76e43b2a90 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -38,7 +38,7 @@ 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', - 'hsplit', 'rot90'] + 'hsplit', 'rot90', 'einsum'] @set_module('mxnet.ndarray.numpy') @@ -4488,3 +4488,242 @@ def rot90(m, k=1, axes=(0, 1)): [4., 6.]]]) """ return
_npi.rot90(m, k=k, axes=axes) + + +@set_module('mxnet.ndarray.numpy') +def einsum(*operands, **kwargs): + r""" + einsum(subscripts, *operands, out=None, optimize=False) + + Evaluates the Einstein summation convention on the operands. + + Using the Einstein summation convention, many common multi-dimensional, + linear algebraic array operations can be represented in a simple fashion. + In *implicit* mode `einsum` computes these values. + + In *explicit* mode, `einsum` provides further flexibility to compute + other array operations that might not be considered classical Einstein + summation operations, by disabling, or forcing summation over specified + subscript labels. + + See the notes and examples for clarification. + + Parameters + ---------- + subscripts : str + Specifies the subscripts for summation as comma separated list of + subscript labels. An implicit (classical Einstein summation) + calculation is performed unless the explicit indicator '->' is + included as well as subscript labels of the precise output form. + operands : list of ndarray + These are the arrays for the operation. + out : ndarray, optional + If provided, the calculation is done into this array. + optimize : {False, True}, optional + Controls if intermediate optimization should occur. No optimization + will occur if False. Defaults to False. + + Returns + ------- + output : ndarray + The calculation based on the Einstein summation convention. + + Notes + ----- + The Einstein summation convention can be used to compute + many multi-dimensional, linear algebraic array operations. `einsum` + provides a succinct way of representing these. + + A non-exhaustive list of these operations, + which can be computed by `einsum`, is shown below along with examples: + + * Trace of an array, :py:func:`np.trace`. + * Return a diagonal, :py:func:`np.diag`. + * Array axis summations, :py:func:`np.sum`. + * Transpositions and permutations, :py:func:`np.transpose`. + * Matrix multiplication and dot product, :py:func:`np.matmul` :py:func:`np.dot`. + * Vector inner and outer products, :py:func:`np.inner` :py:func:`np.outer`. + * Broadcasting, element-wise and scalar multiplication, :py:func:`np.multiply`. + * Tensor contractions, :py:func:`np.tensordot`. + + The subscripts string is a comma-separated list of subscript labels, + where each label refers to a dimension of the corresponding operand. + Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)`` + is equivalent to :py:func:`np.inner(a,b) `. If a label + appears only once, it is not summed, so ``np.einsum('i', a)`` produces a + view of ``a`` with no changes. A further example ``np.einsum('ij,jk', a, b)`` + describes traditional matrix multiplication and is equivalent to + :py:func:`np.matmul(a,b) `. Repeated subscript labels in one + operand take the diagonal. For example, ``np.einsum('ii', a)`` is equivalent + to :py:func:`np.trace(a) `. + + In *implicit mode*, the chosen subscripts are important + since the axes of the output are reordered alphabetically. This + means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while + ``np.einsum('ji', a)`` takes its transpose. Additionally, + ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while, + ``np.einsum('ij,jh', a, b)`` returns the transpose of the + multiplication since subscript 'h' precedes subscript 'i'. + + In *explicit mode* the output can be directly controlled by + specifying output subscript labels. This requires the + identifier '->' as well as the list of output subscript labels. 
+ This feature increases the flexibility of the function since + summing can be disabled or forced when required. The call + ``np.einsum('i->', a)`` is like :py:func:`np.sum(a, axis=-1) <np.sum>`, + and ``np.einsum('ii->i', a)`` is like :py:func:`np.diag(a) <np.diag>`. + The difference is that `einsum` does not allow broadcasting by default. + Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the + order of the output subscript labels and therefore returns matrix + multiplication, unlike the example above in implicit mode. + + To enable and control broadcasting, use an ellipsis. Default + NumPy-style broadcasting is done by adding an ellipsis + to the left of each term, like ``np.einsum('...ii->...i', a)``. + To take the trace along the first and last axes, + you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix + product with the left-most indices instead of rightmost, one can do + ``np.einsum('ij...,jk...->ik...', a, b)``. + + When there is only one operand, no axes are summed, and no output + parameter is provided, a view into the operand is returned instead + of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)`` + produces a view. (Note: this MXNet implementation returns a copy instead; + see the differences listed below.) + + The ``optimize`` argument will optimize the contraction order + of an einsum expression. For a contraction with three or more operands this + can greatly increase the computational efficiency at the cost of a larger + memory footprint during computation. + + Typically a 'greedy' algorithm is applied, which empirical tests have shown + to return the optimal path in the majority of cases. The 'optimal' strategy + is not supported for now. + + This function differs from the original `numpy.einsum + <https://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html>`_ in + the following way(s): + + - Does not support the 'optimal' strategy + - Does not support the alternative subscript specification like + ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])`` + - Does not produce a view in any case + + Examples + -------- + >>> a = np.arange(25).reshape(5,5) + >>> b = np.arange(5) + >>> c = np.arange(6).reshape(2,3) + + Trace of a matrix: + + >>> np.einsum('ii', a) + array(60.) + + Extract the diagonal (requires explicit form): + + >>> np.einsum('ii->i', a) + array([ 0., 6., 12., 18., 24.]) + + Sum over an axis (requires explicit form): + + >>> np.einsum('ij->i', a) + array([ 10., 35., 60., 85., 110.]) + >>> np.sum(a, axis=1) + array([ 10., 35., 60., 85., 110.]) + + For higher dimensional arrays summing a single axis can be done with ellipsis: + + >>> np.einsum('...j->...', a) + array([ 10., 35., 60., 85., 110.]) + + Compute a matrix transpose, or reorder any number of axes: + + >>> np.einsum('ji', c) + array([[0., 3.], + [1., 4.], + [2., 5.]]) + >>> np.einsum('ij->ji', c) + array([[0., 3.], + [1., 4.], + [2., 5.]]) + >>> np.transpose(c) + array([[0., 3.], + [1., 4.], + [2., 5.]]) + + Vector inner products: + + >>> np.einsum('i,i', b, b) + array(30.)
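+ + For comparison, the explicit form of the same inner product (an illustrative addition to the upstream examples; the '->' only fixes the output form, so the result is unchanged): + + >>> np.einsum('i,i->', b, b) + array(30.)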
+ + Matrix vector multiplication: + + >>> np.einsum('ij,j', a, b) + array([ 30., 80., 130., 180., 230.]) + >>> np.dot(a, b) + array([ 30., 80., 130., 180., 230.]) + >>> np.einsum('...j,j', a, b) + array([ 30., 80., 130., 180., 230.]) + + Broadcasting and scalar multiplication: + + >>> np.einsum('..., ...', np.array(3), c) + array([[ 0., 3., 6.], + [ 9., 12., 15.]]) + >>> np.einsum(',ij', np.array(3), c) + array([[ 0., 3., 6.], + [ 9., 12., 15.]]) + >>> np.multiply(3, c) + array([[ 0., 3., 6.], + [ 9., 12., 15.]]) + + Vector outer product: + + >>> np.einsum('i,j', np.arange(2)+1, b) + array([[0., 1., 2., 3., 4.], + [0., 2., 4., 6., 8.]]) + + Tensor contraction: + + >>> a = np.arange(60.).reshape(3,4,5) + >>> b = np.arange(24.).reshape(4,3,2) + >>> np.einsum('ijk,jil->kl', a, b) + array([[4400., 4730.], + [4532., 4874.], + [4664., 5018.], + [4796., 5162.], + [4928., 5306.]]) + + Example of ellipsis use: + + >>> a = np.arange(6).reshape((3,2)) + >>> b = np.arange(12).reshape((4,3)) + >>> np.einsum('ki,jk->ij', a, b) + array([[10., 28., 46., 64.], + [13., 40., 67., 94.]]) + >>> np.einsum('ki,...k->i...', a, b) + array([[10., 28., 46., 64.], + [13., 40., 67., 94.]]) + >>> np.einsum('k...,jk', a, b) + array([[10., 28., 46., 64.], + [13., 40., 67., 94.]]) + + Chained array operations. For more complicated contractions, speed ups + might be achieved by repeatedly computing a 'greedy' path. Performance + improvements can be particularly significant with larger arrays: + + >>> a = np.ones(64).reshape(2,4,8) + # Basic `einsum`: ~42.22ms (benchmarked on 3.4GHz Intel Xeon.) + >>> for iteration in range(500): + ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a) + # Greedy `einsum` (faster optimal path approximation): ~0.117ms + >>> for iteration in range(500): + ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=True) + """ + # Grab non-einsum kwargs; do not optimize by default. + optimize_arg = kwargs.pop('optimize', False) + out = kwargs.pop('out', None) + + subscripts = operands[0] + operands = operands[1:] + return _npi.einsum(*operands, subscripts=subscripts, out=out, optimize=int(optimize_arg)) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index aec4f9ef4617..e507b17d68d8 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -55,7 +55,7 @@ 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', - 'greater_equal', 'less_equal', 'hsplit', 'rot90'] + 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -5959,3 +5959,236 @@ def hsplit(ary, indices_or_sections): [array([0., 1.]), array([], dtype=float32), array([2., 3.])] """ return _mx_nd_np.hsplit(ary, indices_or_sections) + + +@set_module('mxnet.numpy') +def einsum(*operands, **kwargs): + r""" + einsum(subscripts, *operands, out=None, optimize=False) + + Evaluates the Einstein summation convention on the operands. + + Using the Einstein summation convention, many common multi-dimensional, + linear algebraic array operations can be represented in a simple fashion. + In *implicit* mode `einsum` computes these values. 
+ + In *explicit* mode, `einsum` provides further flexibility to compute + other array operations that might not be considered classical Einstein + summation operations, by disabling, or forcing summation over specified + subscript labels. + + See the notes and examples for clarification. + + Parameters + ---------- + subscripts : str + Specifies the subscripts for summation as comma separated list of + subscript labels. An implicit (classical Einstein summation) + calculation is performed unless the explicit indicator '->' is + included as well as subscript labels of the precise output form. + operands : list of ndarray + These are the arrays for the operation. + out : ndarray, optional + If provided, the calculation is done into this array. + optimize : {False, True}, optional + Controls if intermediate optimization should occur. No optimization + will occur if False. Defaults to False. + + Returns + ------- + output : ndarray + The calculation based on the Einstein summation convention. + + Notes + ----- + The Einstein summation convention can be used to compute + many multi-dimensional, linear algebraic array operations. `einsum` + provides a succinct way of representing these. + + A non-exhaustive list of these operations, + which can be computed by `einsum`, is shown below along with examples: + + * Trace of an array, :py:func:`np.trace`. + * Return a diagonal, :py:func:`np.diag`. + * Array axis summations, :py:func:`np.sum`. + * Transpositions and permutations, :py:func:`np.transpose`. + * Matrix multiplication and dot product, :py:func:`np.matmul` :py:func:`np.dot`. + * Vector inner and outer products, :py:func:`np.inner` :py:func:`np.outer`. + * Broadcasting, element-wise and scalar multiplication, :py:func:`np.multiply`. + * Tensor contractions, :py:func:`np.tensordot`. + + The subscripts string is a comma-separated list of subscript labels, + where each label refers to a dimension of the corresponding operand. + Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)`` + is equivalent to :py:func:`np.inner(a,b) `. If a label + appears only once, it is not summed, so ``np.einsum('i', a)`` produces a + view of ``a`` with no changes. A further example ``np.einsum('ij,jk', a, b)`` + describes traditional matrix multiplication and is equivalent to + :py:func:`np.matmul(a,b) `. Repeated subscript labels in one + operand take the diagonal. For example, ``np.einsum('ii', a)`` is equivalent + to :py:func:`np.trace(a) `. + + In *implicit mode*, the chosen subscripts are important + since the axes of the output are reordered alphabetically. This + means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while + ``np.einsum('ji', a)`` takes its transpose. Additionally, + ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while, + ``np.einsum('ij,jh', a, b)`` returns the transpose of the + multiplication since subscript 'h' precedes subscript 'i'. + + In *explicit mode* the output can be directly controlled by + specifying output subscript labels. This requires the + identifier '->' as well as the list of output subscript labels. + This feature increases the flexibility of the function since + summing can be disabled or forced when required. The call + ``np.einsum('i->', a)`` is like :py:func:`np.sum(a, axis=-1) `, + and ``np.einsum('ii->i', a)`` is like :py:func:`np.diag(a) `. + The difference is that `einsum` does not allow broadcasting by default. 
+ Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the + order of the output subscript labels and therefore returns matrix + multiplication, unlike the example above in implicit mode. + + To enable and control broadcasting, use an ellipsis. Default + NumPy-style broadcasting is done by adding an ellipsis + to the left of each term, like ``np.einsum('...ii->...i', a)``. + To take the trace along the first and last axes, + you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix + product with the left-most indices instead of rightmost, one can do + ``np.einsum('ij...,jk...->ik...', a, b)``. + + When there is only one operand, no axes are summed, and no output + parameter is provided, a view into the operand is returned instead + of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)`` + produces a view. (Note: this MXNet implementation returns a copy instead; + see the differences listed below.) + + The ``optimize`` argument will optimize the contraction order + of an einsum expression. For a contraction with three or more operands this + can greatly increase the computational efficiency at the cost of a larger + memory footprint during computation. + + Typically a 'greedy' algorithm is applied, which empirical tests have shown + to return the optimal path in the majority of cases. The 'optimal' strategy + is not supported for now. + + This function differs from the original `numpy.einsum + <https://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html>`_ in + the following way(s): + + - Does not support the 'optimal' strategy + - Does not support the alternative subscript specification like + ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])`` + - Does not produce a view in any case + + Examples + -------- + >>> a = np.arange(25).reshape(5,5) + >>> b = np.arange(5) + >>> c = np.arange(6).reshape(2,3) + + Trace of a matrix: + + >>> np.einsum('ii', a) + array(60.) + + Extract the diagonal (requires explicit form): + + >>> np.einsum('ii->i', a) + array([ 0., 6., 12., 18., 24.]) + + Sum over an axis (requires explicit form): + + >>> np.einsum('ij->i', a) + array([ 10., 35., 60., 85., 110.]) + >>> np.sum(a, axis=1) + array([ 10., 35., 60., 85., 110.]) + + For higher dimensional arrays summing a single axis can be done with ellipsis: + + >>> np.einsum('...j->...', a) + array([ 10., 35., 60., 85., 110.]) + + Compute a matrix transpose, or reorder any number of axes: + + >>> np.einsum('ji', c) + array([[0., 3.], + [1., 4.], + [2., 5.]]) + >>> np.einsum('ij->ji', c) + array([[0., 3.], + [1., 4.], + [2., 5.]]) + >>> np.transpose(c) + array([[0., 3.], + [1., 4.], + [2., 5.]]) + + Vector inner products: + + >>> np.einsum('i,i', b, b) + array(30.)
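+ + The ``optimize`` flag described in the Notes can be passed here as well (an illustrative addition; the result is unchanged, only the contraction path may differ): + + >>> np.einsum('i,i', b, b, optimize=True) + array(30.)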
+ + Matrix vector multiplication: + + >>> np.einsum('ij,j', a, b) + array([ 30., 80., 130., 180., 230.]) + >>> np.dot(a, b) + array([ 30., 80., 130., 180., 230.]) + >>> np.einsum('...j,j', a, b) + array([ 30., 80., 130., 180., 230.]) + + Broadcasting and scalar multiplication: + + >>> np.einsum('..., ...', np.array(3), c) + array([[ 0., 3., 6.], + [ 9., 12., 15.]]) + >>> np.einsum(',ij', np.array(3), c) + array([[ 0., 3., 6.], + [ 9., 12., 15.]]) + >>> np.multiply(3, c) + array([[ 0., 3., 6.], + [ 9., 12., 15.]]) + + Vector outer product: + + >>> np.einsum('i,j', np.arange(2)+1, b) + array([[0., 1., 2., 3., 4.], + [0., 2., 4., 6., 8.]]) + + Tensor contraction: + + >>> a = np.arange(60.).reshape(3,4,5) + >>> b = np.arange(24.).reshape(4,3,2) + >>> np.einsum('ijk,jil->kl', a, b) + array([[4400., 4730.], + [4532., 4874.], + [4664., 5018.], + [4796., 5162.], + [4928., 5306.]]) + + Example of ellipsis use: + + >>> a = np.arange(6).reshape((3,2)) + >>> b = np.arange(12).reshape((4,3)) + >>> np.einsum('ki,jk->ij', a, b) + array([[10., 28., 46., 64.], + [13., 40., 67., 94.]]) + >>> np.einsum('ki,...k->i...', a, b) + array([[10., 28., 46., 64.], + [13., 40., 67., 94.]]) + >>> np.einsum('k...,jk', a, b) + array([[10., 28., 46., 64.], + [13., 40., 67., 94.]]) + + Chained array operations. For more complicated contractions, speed ups + might be achieved by repeatedly computing a 'greedy' path. Performance + improvements can be particularly significant with larger arrays: + + >>> a = np.ones(64).reshape(2,4,8) + # Basic `einsum`: ~42.22ms (benchmarked on 3.4GHz Intel Xeon.) + >>> for iteration in range(500): + ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a) + # Greedy `einsum` (faster optimal path approximation): ~0.117ms + >>> for iteration in range(500): + ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=True) + """ + return _mx_nd_np.einsum(*operands, **kwargs) diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 1e1824fe43b4..a241d2687ee4 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -114,7 +114,8 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'var', 'zeros_like', 'meshgrid', - 'outer' + 'outer', + 'einsum' ] diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 5e535ef3ef84..1945c5b0e695 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -40,7 +40,7 @@ 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', - 'less_equal', 'hsplit', 'rot90'] + 'less_equal', 'hsplit', 'rot90', 'einsum'] def _num_outputs(sym): @@ -4428,4 +4428,131 @@ def rot90(m, k=1, axes=(0, 1)): return _npi.rot90(m, k=k, axes=axes) +@set_module('mxnet.symbol.numpy') +def einsum(*operands, **kwargs): + r""" + einsum(subscripts, *operands, out=None, optimize=False) + + Evaluates the Einstein summation convention on the operands. + + Using the Einstein summation convention, many common multi-dimensional, + linear algebraic array operations can be represented in a simple fashion. + In *implicit* mode `einsum` computes these values. 
+ + In *explicit* mode, `einsum` provides further flexibility to compute + other array operations that might not be considered classical Einstein + summation operations, by disabling, or forcing summation over specified + subscript labels. + + See the notes and examples for clarification. + + Parameters + ---------- + subscripts : str + Specifies the subscripts for summation as comma separated list of + subscript labels. An implicit (classical Einstein summation) + calculation is performed unless the explicit indicator '->' is + included as well as subscript labels of the precise output form. + operands : list of _Symbol + These are the arrays for the operation. + out : _Symbol, optional + If provided, the calculation is done into this array. + optimize : {False, True}, optional + Controls if intermediate optimization should occur. No optimization + will occur if False. Defaults to False. + + Returns + ------- + output : _Symbol + The calculation based on the Einstein summation convention. + + Notes + ----- + The Einstein summation convention can be used to compute + many multi-dimensional, linear algebraic array operations. `einsum` + provides a succinct way of representing these. + + A non-exhaustive list of these operations, + which can be computed by `einsum`, is shown below along with examples: + + * Trace of an array, :py:func:`np.trace`. + * Return a diagonal, :py:func:`np.diag`. + * Array axis summations, :py:func:`np.sum`. + * Transpositions and permutations, :py:func:`np.transpose`. + * Matrix multiplication and dot product, :py:func:`np.matmul` :py:func:`np.dot`. + * Vector inner and outer products, :py:func:`np.inner` :py:func:`np.outer`. + * Broadcasting, element-wise and scalar multiplication, :py:func:`np.multiply`. + * Tensor contractions, :py:func:`np.tensordot`. + + The subscripts string is a comma-separated list of subscript labels, + where each label refers to a dimension of the corresponding operand. + Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)`` + is equivalent to :py:func:`np.inner(a,b) `. If a label + appears only once, it is not summed, so ``np.einsum('i', a)`` produces a + view of ``a`` with no changes. A further example ``np.einsum('ij,jk', a, b)`` + describes traditional matrix multiplication and is equivalent to + :py:func:`np.matmul(a,b) `. Repeated subscript labels in one + operand take the diagonal. For example, ``np.einsum('ii', a)`` is equivalent + to :py:func:`np.trace(a) `. + + In *implicit mode*, the chosen subscripts are important + since the axes of the output are reordered alphabetically. This + means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while + ``np.einsum('ji', a)`` takes its transpose. Additionally, + ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while, + ``np.einsum('ij,jh', a, b)`` returns the transpose of the + multiplication since subscript 'h' precedes subscript 'i'. + + In *explicit mode* the output can be directly controlled by + specifying output subscript labels. This requires the + identifier '->' as well as the list of output subscript labels. + This feature increases the flexibility of the function since + summing can be disabled or forced when required. The call + ``np.einsum('i->', a)`` is like :py:func:`np.sum(a, axis=-1) `, + and ``np.einsum('ii->i', a)`` is like :py:func:`np.diag(a) `. + The difference is that `einsum` does not allow broadcasting by default. 
+ Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the + order of the output subscript labels and therefore returns matrix + multiplication, unlike the example above in implicit mode. + + To enable and control broadcasting, use an ellipsis. Default + NumPy-style broadcasting is done by adding an ellipsis + to the left of each term, like ``np.einsum('...ii->...i', a)``. + To take the trace along the first and last axes, + you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix + product with the left-most indices instead of rightmost, one can do + ``np.einsum('ij...,jk...->ik...', a, b)``. + + When there is only one operand, no axes are summed, and no output + parameter is provided, a view into the operand is returned instead + of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)`` + produces a view. (Note: this MXNet implementation returns a copy instead; + see the differences listed below.) + + The ``optimize`` argument will optimize the contraction order + of an einsum expression. For a contraction with three or more operands this + can greatly increase the computational efficiency at the cost of a larger + memory footprint during computation. + + Typically a 'greedy' algorithm is applied, which empirical tests have shown + to return the optimal path in the majority of cases. The 'optimal' strategy + is not supported for now. + + This function differs from the original `numpy.einsum + <https://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html>`_ in + the following way(s): + + - Does not support the 'optimal' strategy + - Does not support the alternative subscript specification like + ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])`` + - Does not produce a view in any case + """ + # Grab non-einsum kwargs; do not optimize by default. + optimize_arg = kwargs.pop('optimize', False) + out = kwargs.pop('out', None) + + subscripts = operands[0] + operands = operands[1:] + return _npi.einsum(*operands, subscripts=subscripts, out=out, optimize=int(optimize_arg)) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index 950db174595e..18b39b532388 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -159,6 +159,42 @@ inline int get_num_threads(const int N) { LOG(FATAL) << "ndim=" << NDim << "too large "; \ } +#define MXNET_NDIM_SWITCH_EX(NDim, ndim, ...) \ + if (NDim == 0) { \ + } else if (NDim == 1) { \ + const int ndim = 1; \ + {__VA_ARGS__} \ + } else if (NDim == 2) { \ + const int ndim = 2; \ + {__VA_ARGS__} \ + } else if (NDim == 3) { \ + const int ndim = 3; \ + {__VA_ARGS__} \ + } else if (NDim == 4) { \ + const int ndim = 4; \ + {__VA_ARGS__} \ + } else if (NDim == 5) { \ + const int ndim = 5; \ + {__VA_ARGS__} \ + } else if (NDim == 6) { \ + const int ndim = 6; \ + {__VA_ARGS__} \ + } else if (NDim == 7) { \ + const int ndim = 7; \ + {__VA_ARGS__} \ + } else if (NDim == 8) { \ + const int ndim = 8; \ + {__VA_ARGS__} \ + } else if (NDim == 9) { \ + const int ndim = 9; \ + {__VA_ARGS__} \ + } else if (NDim == 10) { \ + const int ndim = 10; \ + {__VA_ARGS__} \ + } else { \ + LOG(FATAL) << "ndim=" << NDim << " too large"; \ + } + #define MXNET_NO_INT8_TYPE_SWITCH(type, DType, ...)
\ switch (type) { \ case mshadow::kFloat32: \ diff --git a/src/operator/numpy/np_dot-inl.h b/src/operator/numpy/np_dot-inl.h index a854777c3109..fe63c0b0ec51 100644 --- a/src/operator/numpy/np_dot-inl.h +++ b/src/operator/numpy/np_dot-inl.h @@ -63,7 +63,13 @@ inline void NumpyDotForward(const nnvm::NodeAttrs& attrs, // of a and the 2nd-to-last axis of b const Tuple a_axes_summed({a_shape.ndim() - 1}); const Tuple b_axes_summed({b_shape.ndim() - 2}); - TensordotImpl(a_axes_summed, b_axes_summed, ctx, a, b, out, req); + size_t workspace_size = TensordotWorkspaceSize(a_axes_summed, + b_axes_summed, + a, b, out, + req); + Tensor workspace = ctx.requested[0].get_space_typed( + Shape1(workspace_size), ctx.get_stream()); + TensordotImpl(a_axes_summed, b_axes_summed, ctx, a, b, out, req, workspace); } }); } @@ -98,8 +104,14 @@ inline void NumpyDotBackward(const nnvm::NodeAttrs& attrs, // of a and the 2nd-to-last axis of b const Tuple a_axes_summed({a_shape.ndim() - 1}); const Tuple b_axes_summed({b_shape.ndim() - 2}); + size_t workspace_size = TensordotBackwardWorkspaceSize(a_axes_summed, b_axes_summed, + ograd, a, b, grad_a, + grad_b, req); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), + ctx.get_stream()); TensordotBackwardImpl(a_axes_summed, b_axes_summed, ctx, ograd, a, b, grad_a, - grad_b, req); + grad_b, req, workspace); } }); } diff --git a/src/operator/numpy/np_dot.cc b/src/operator/numpy/np_dot.cc index 627e68877998..6afc896a7720 100644 --- a/src/operator/numpy/np_dot.cc +++ b/src/operator/numpy/np_dot.cc @@ -129,7 +129,7 @@ NNVM_REGISTER_OP(_np_dot) .set_attr("FInferType", ElemwiseType<2, 1>) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; + return std::vector(1, ResourceRequest::kTempSpace); }) .set_attr("FCompute", NumpyDotForward) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_np_dot"}) @@ -142,7 +142,7 @@ NNVM_REGISTER_OP(_backward_np_dot) .set_attr("TIsBackward", true) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; + return std::vector(1, ResourceRequest::kTempSpace); }) .set_attr("FCompute", NumpyDotBackward); diff --git a/src/operator/numpy/np_einsum_op-inl.h b/src/operator/numpy/np_einsum_op-inl.h new file mode 100644 index 000000000000..2145abec682b --- /dev/null +++ b/src/operator/numpy/np_einsum_op-inl.h @@ -0,0 +1,1092 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright (c) 2005-2019, NumPy Developers. + * All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * * Neither the name of the NumPy Developers nor the names of any + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/*! + * \file np_einsum_op-inl.h + * \brief Function definition of numpy-compatible einsum operator + */ + +#ifndef MXNET_OPERATOR_NUMPY_NP_EINSUM_OP_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_EINSUM_OP_INL_H_ + +#include +#include +#include +#include +#include "./np_tensordot_op-inl.h" +#include "./np_einsum_path_op-inl.h" +#include "../../common/static_array.h" +#include "../mxnet_op.h" +#include "../operator_common.h" +#include "../mshadow_op.h" +#include "../elemwise_op_common.h" + +namespace mxnet { +namespace op { + +#define NPY_MAXDIMS 32 +#define NPY_MAXARGS 32 + +inline TShape get_stride(const TShape& shape) { + int ndim = shape.ndim(), prod = 1; + TShape stride = TShape(ndim, -1); + for (int i = ndim - 1; i >= 0; i--) { + stride[i] = shape[i] > 1 ? prod : 0; + prod = prod * shape[i]; + } + return stride; +} + +inline TShape pad(const TShape& shape, int odim) { + int ndim = shape.ndim(); + CHECK_GE(odim, ndim); + TShape ret(odim, 1); + for (int idim = 0; idim < ndim; ++idim) { + ret[idim] = shape[idim]; + } + return ret; +} + +/* + * Parses the subscripts for one operand into an output of 'ndim' + * labels. The resulting 'op_labels' array will have: + * - the ASCII code of the label for the first occurrence of a label; + * - the (negative) offset to the first occurrence of the label for + * repeated labels; + * - zero for broadcast dimensions, if subscripts has an ellipsis. + * For example: + * - subscripts="abbcbc", ndim=6 -> op_labels=[97, 98, -1, 99, -3, -2] + * - subscripts="ab...bc", ndim=6 -> op_labels=[97, 98, 0, 0, -3, 99] + */ +inline int parse_operand_subscripts(const char *subscripts, int length, + int ndim, int iop, char *op_labels, + char *label_counts, int *min_label, int *max_label) { + using namespace mxnet_op; + int i; + int idim = 0; + int ellipsis = -1; + + /* Process all labels for this operand */ + for (i = 0; i < length; ++i) { + int label = subscripts[i]; + + /* A proper label for an axis. 
*/ + if (label > 0 && isalpha(label)) { + /* Check we don't exceed the operator dimensions. */ + CHECK(idim < ndim) + << "einstein sum subscripts string contains " + << "too many subscripts for operand " + << iop; + + op_labels[idim++] = label; + if (label < *min_label) { + *min_label = label; + } + if (label > *max_label) { + *max_label = label; + } + label_counts[label]++; + } else if (label == '.') { + /* The beginning of the ellipsis. */ + /* Check it's a proper ellipsis. */ + CHECK(!(ellipsis != -1 || i + 2 >= length + || subscripts[++i] != '.' || subscripts[++i] != '.')) + << "einstein sum subscripts string contains a " + << "'.' that is not part of an ellipsis ('...') " + << "in operand " + << iop; + + ellipsis = idim; + } else { + CHECK(label == ' ') + << "invalid subscript '" << static_cast(label) + << "' in einstein sum " + << "subscripts string, subscripts must " + << "be letters"; + } + } + + /* No ellipsis found, labels must match dimensions exactly. */ + if (ellipsis == -1) { + CHECK(idim == ndim) + << "operand has more dimensions than subscripts " + << "given in einstein sum, but no '...' ellipsis " + << "provided to broadcast the extra dimensions."; + } else if (idim < ndim) { + /* Ellipsis found, may have to add broadcast dimensions. */ + /* Move labels after ellipsis to the end. */ + for (i = 0; i < idim - ellipsis; ++i) { + op_labels[ndim - i - 1] = op_labels[idim - i - 1]; + } + /* Set all broadcast dimensions to zero. */ + for (i = 0; i < ndim - idim; ++i) { + op_labels[ellipsis + i] = 0; + } + } + + /* + * Find any labels duplicated for this operand, and turn them + * into negative offsets to the axis to merge with. + * + * In C, the char type may be signed or unsigned, but with + * twos complement arithmetic the char is ok either way here, and + * later where it matters the char is cast to a signed char. + */ + for (idim = 0; idim < ndim - 1; ++idim) { + int label = op_labels[idim]; + /* If it is a proper label, find any duplicates of it. */ + if (label > 0) { + /* Search for the next matching label. */ + char *next = reinterpret_cast(memchr(op_labels + idim + 1, label, ndim - idim - 1)); + + while (next != NULL) { + /* The offset from next to op_labels[idim] (negative). */ + *next = static_cast((op_labels + idim) - next); + /* Search for the next matching label. */ + next = reinterpret_cast(memchr(next + 1, label, op_labels + ndim - 1 - next)); + } + } + } + return 0; +} + +/* + * Parses the subscripts for the output operand into an output that + * includes 'ndim_broadcast' unlabeled dimensions, and returns the total + * number of output dimensions, or -1 if there is an error. Similarly + * to parse_operand_subscripts, the 'out_labels' array will have, for + * each dimension: + * - the ASCII code of the corresponding label; + * - zero for broadcast dimensions, if subscripts has an ellipsis. + */ +inline int parse_output_subscripts(const char *subscripts, int length, + int ndim_broadcast, + const char *label_counts, char *out_labels) { + using namespace mxnet_op; + int i, bdim; + int ndim = 0; + int ellipsis = 0; + + /* Process all the output labels. */ + for (i = 0; i < length; ++i) { + int label = subscripts[i]; + + /* A proper label for an axis. */ + if (label > 0 && isalpha(label)) { + /* Check that it doesn't occur again. */ + CHECK(memchr(subscripts + i + 1, label, length - i - 1) == NULL) + << "einstein sum subscripts string includes " + << "output subscript '" << static_cast(label) + << "' multiple times"; + + /* Check that it was used in the inputs. 
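+ * (label_counts was filled in while parsing the operand subscripts, so a zero count means the label never appeared on the input side.)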
*/ + CHECK(label_counts[label] != 0) + << "einstein sum subscripts string included " + << "output subscript '" << static_cast(label) + << "' which never appeared " + << "in an input"; + + /* Check that there is room in out_labels for this label. */ + CHECK(ndim < NPY_MAXDIMS) + << "einstein sum subscripts string contains " + << "too many subscripts in the output"; + + out_labels[ndim++] = label; + } else if (label == '.') { + /* The beginning of the ellipsis. */ + /* Check it is a proper ellipsis. */ + CHECK(!(ellipsis || i + 2 >= length + || subscripts[++i] != '.' || subscripts[++i] != '.')) + << "einstein sum subscripts string " + << "contains a '.' that is not part of " + << "an ellipsis ('...') in the output"; + + /* Check there is room in out_labels for broadcast dims. */ + CHECK(ndim + ndim_broadcast <= NPY_MAXDIMS) + << "einstein sum subscripts string contains " + << "too many subscripts in the output"; + + ellipsis = 1; + for (bdim = 0; bdim < ndim_broadcast; ++bdim) { + out_labels[ndim++] = 0; + } + } else { + CHECK(label == ' ') + << "invalid subscript '" << static_cast(label) + << "' in einstein sum " + << "subscripts string, subscripts must " + << "be letters"; + } + } + + /* If no ellipsis was found there should be no broadcast dimensions. */ + CHECK(!(!ellipsis && ndim_broadcast > 0)) + << "output has more dimensions than subscripts " + << "given in einstein sum, but no '...' ellipsis " + << "provided to broadcast the extra dimensions."; + + return ndim; +} + +inline void get_combined_dims_view(const TBlob& op, int iop, + char *labels, + TShape* newshape, + TShape* newstride) { + using namespace mxnet_op; + int idim, ndim, icombine, combineoffset; + int icombinemap[NPY_MAXDIMS]; + int newdim; + + const TShape& shape = op.shape_; + TShape stride = get_stride(shape); + ndim = op.shape_.ndim(); + newdim = newshape->ndim(); + + /* Initialize the dimensions and strides to zero */ + for (idim = 0; idim < newdim; ++idim) { + (*newshape)[idim] = 0; + (*newstride)[idim] = 0; + } + + /* Copy the dimensions and strides, except when collapsing */ + icombine = 0; + for (idim = 0; idim < ndim; ++idim) { + /* + * The char type may be either signed or unsigned, we + * need it to be signed here. 
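+ * For example, subscripts "ii" yield op_labels = {'i', -1}: the -1 on axis 1 is a negative offset back to axis 0, telling this loop to merge the two axes into a single diagonal dimension.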
+ */ + int label = (signed char)labels[idim]; + /* If this label says to merge axes, get the actual label */ + if (label < 0) { + combineoffset = label; + label = labels[idim+label]; + } else { + combineoffset = 0; + if (icombine != idim) { + labels[icombine] = labels[idim]; + } + icombinemap[idim] = icombine; + } + /* If the label is 0, it's an unlabeled broadcast dimension */ + if (label == 0) { + (*newshape)[icombine] = shape[idim]; + (*newstride)[icombine] = stride[idim]; + } else { + /* Update the combined axis dimensions and strides */ + int i = icombinemap[idim + combineoffset]; + CHECK(!(combineoffset < 0 && (*newshape)[i] != 0 && + (*newshape)[i] != shape[idim])) + << "dimensions in operand " << iop + << " for collapsing index '" << label + << "' don't match (" << static_cast((*newshape)[i]) + << " != " << shape[idim] << ")"; + (*newshape)[i] = shape[idim]; + (*newstride)[i] += stride[idim]; + } + + /* If the label didn't say to combine axes, increment dest i */ + if (combineoffset == 0) { + icombine++; + } + } +} + +inline static int prepare_op_axes(int ndim, int iop, char *labels, + int *axes, int ndim_iter, char *iter_labels) { + using namespace mxnet_op; + int i, label, ibroadcast; + + ibroadcast = ndim-1; + for (i = ndim_iter-1; i >= 0; --i) { + label = iter_labels[i]; + /* + * If it's an unlabeled broadcast dimension, choose + * the next broadcast dimension from the operand. + */ + if (label == 0) { + while (ibroadcast >= 0 && labels[ibroadcast] != 0) { + --ibroadcast; + } + /* + * If we used up all the operand broadcast dimensions, + * extend it with a "newaxis" + */ + if (ibroadcast < 0) { + axes[i] = -1; + } else { + /* Otherwise map to the broadcast axis */ + axes[i] = ibroadcast; + --ibroadcast; + } + } else { + /* It's a labeled dimension, find the matching one */ + char *match = reinterpret_cast(memchr(labels, label, ndim)); + /* If the op doesn't have the label, broadcast it */ + if (match == NULL) { + axes[i] = -1; + } else { + /* Otherwise use it */ + axes[i] = match - labels; + } + } + } + return 0; +} + +struct NumpyEinsumParam: public dmlc::Parameter { + int num_args; + int optimize; + std::string subscripts; + DMLC_DECLARE_PARAMETER(NumpyEinsumParam) { + DMLC_DECLARE_FIELD(num_args) + .set_lower_bound(1) + .describe("Number of input arrays."); + DMLC_DECLARE_FIELD(subscripts) + .set_default("") + .describe("Specifies the subscripts for summation as comma separated list" + " of subscript labels. An implicit (classical Einstein summation) calculation" + " is performed unless the explicit indicator ‘->’ is included as well as" + " subscript labels of the precise output form."); + DMLC_DECLARE_FIELD(optimize) + .set_default(0); + } +}; + +class EinsumOp { + public: + int num_args; + int optimize; + std::string subscripts; + std::shared_ptr tempspace; + std::vector paths; + explicit EinsumOp(int num_args, int optimize, std::string subscripts) { + this->num_args = num_args; + this->optimize = optimize; + this->subscripts = subscripts; + } +}; // class EinsumOp + +template +struct numpy_einsum { + template + MSHADOW_XINLINE static void Map(index_t i, DType* out, + common::StaticArray op, + mshadow::Shape oshape, + mshadow::Shape ostride, + mshadow::Shape reduceshape, + mshadow::Shape reducestride, + mshadow::Shape itershape, + common::StaticArray, + NPY_MAXARGS> iterstride, + int nop, + int iop0, + const DType* out_grad) { + using namespace mxnet_op; + index_t oidx = back ? 
dot(unravel(dot(unravel(i, oshape), ostride), itershape), + iterstride[iop0]) : i; + if (req == kWriteTo) { + out[oidx] = (DType)0; + } + for (int j = 0; j < reduceshape.Size(); j++) { + mshadow::Shape idx = unravel(dot(unravel(j, reduceshape), reducestride) + + dot(unravel(i, oshape), ostride), + itershape); + DType tmp = back ? out_grad[dot(idx, iterstride[nop])] : (DType)1; + for (int iop = 0; iop < nop; ++iop) { + if (iop != iop0) { + index_t k = dot(idx, iterstride[iop]); + tmp = tmp * op[iop][k]; + } + } + out[oidx] = out[oidx] + tmp; + } + } +}; + +template +inline void NumpyEinsumProcess(const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const char *subscripts, int nop, + const OpContext& ctx) { + using namespace mxnet_op; + + /* nop+1 (+1 is for the output) must fit in NPY_MAXARGS */ + CHECK(nop < NPY_MAXARGS) + << "too many operands provided to einstein sum function"; + CHECK(nop >= 1) + << "not enough operands provided to einstein sum function"; + + /* Step 1: Parse the subscripts string into label_counts and op_labels */ + int iop, idim, min_label = 127, max_label = 0; + char label_counts[128], op_labels[NPY_MAXARGS][NPY_MAXDIMS]; + memset(label_counts, 0, sizeof(label_counts)); + for (iop = 0; iop < nop; ++iop) { + int length = static_cast(strcspn(subscripts, ",-")); + + CHECK(!(iop == nop - 1 && subscripts[length] == ',')) + << "more operands provided to einstein sum function " + << "than specified in the subscripts string"; + CHECK(!(iop < nop-1 && subscripts[length] != ',')) + << "fewer operands provided to einstein sum function " + << "than specified in the subscripts string"; + CHECK_GE(parse_operand_subscripts(subscripts, length, + inputs[iop + back].shape_.ndim(), + iop, op_labels[iop], label_counts, + &min_label, &max_label), 0); + + /* Move subscripts to the start of the labels for the next op */ + subscripts += length; + if (iop < nop - 1) { + subscripts++; + } + } + + /* + * Find the number of broadcast dimensions, which is the maximum + * number of labels == 0 in an op_labels array. + */ + int ndim_broadcast = 0; + for (iop = 0; iop < nop; ++iop) { + int count_zeros = 0; + int ndim; + char *labels = op_labels[iop]; + + ndim = inputs[iop + back].shape_.ndim(); + for (idim = 0; idim < ndim; ++idim) { + if (labels[idim] == 0) { + ++count_zeros; + } + } + + if (count_zeros > ndim_broadcast) { + ndim_broadcast = count_zeros; + } + } + + /* + * If there is no output signature, fill output_labels and ndim_output + * using each label that appeared once, in alphabetical order. + */ + int label, ndim_output; + char output_labels[NPY_MAXDIMS]; + if (subscripts[0] == '\0') { + /* If no output was specified, always broadcast left, as usual. */ + for (ndim_output = 0; ndim_output < ndim_broadcast; ++ndim_output) { + output_labels[ndim_output] = 0; + } + for (label = min_label; label <= max_label; ++label) { + if (label_counts[label] == 1) { + CHECK(ndim_output < NPY_MAXDIMS) + << "einstein sum subscript string has too many " + << "distinct labels"; + output_labels[ndim_output++] = label; + } + } + } else { + CHECK(subscripts[0] == '-' && subscripts[1] == '>') + << "einstein sum subscript string does not " + << "contain proper '->' output specified"; + subscripts += 2; + + /* Parse the output subscript string. 
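+ * (parse_output_subscripts returns the number of output dimensions, or -1 on error, and writes label 0 for broadcast dimensions.)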
*/ + ndim_output = parse_output_subscripts(subscripts, strlen(subscripts), + ndim_broadcast, label_counts, + output_labels); + CHECK_GE(ndim_output, 0); + } + + /* + * Step 2: + * Process all the input ops, combining dimensions into their + * diagonal where specified. + */ + std::vector opshape(nop), opstride_true(nop); + for (iop = 0; iop < nop; ++iop) { + char *labels = op_labels[iop]; + int combine, ndim; + + ndim = inputs[iop + back].shape_.ndim(); + + /* + * Check whether any dimensions need to be combined + * + * The char type may be either signed or unsigned, we + * need it to be signed here. + */ + combine = 0; + for (idim = 0; idim < ndim; ++idim) { + if ((signed char)labels[idim] < 0) { + combine++; + } + } + + /* If any dimensions are combined, create a view which combines them */ + if (combine) { + TShape tshape(ndim - combine, -1); + TShape tstride(ndim - combine, -1); + get_combined_dims_view(inputs[iop + back], iop, labels, + &tshape, &tstride); + opshape[iop] = tshape; + opstride_true[iop] = tstride; + } else { + /* No combining needed */ + opshape[iop] = inputs[iop + back].shape_; + opstride_true[iop] = get_stride(opshape[iop]); + } + } + + /* + * Step 3: + * Set up the labels for the iterator (output + combined labels). + * Can just share the output_labels memory, because iter_labels + * is output_labels with some more labels appended. + */ + char *iter_labels = output_labels; + int ndim_iter = ndim_output; + for (label = min_label; label <= max_label; ++label) { + if (label_counts[label] > 0 && + memchr(output_labels, label, ndim_output) == NULL) { + CHECK(ndim_iter < NPY_MAXDIMS) + << "too many subscripts in einsum"; + iter_labels[ndim_iter++] = label; + } + } + + /* Step 4: Set up the op_axes for the iterator */ + TShape itershape(ndim_iter, -1), iterstride_true(ndim_iter, -1); + TShape oshape = back ? 
inputs[0].shape_ : outputs[0].shape_; + TShape ostride_true = get_stride(oshape); + TShape reduceshape, ostride, reducestride; + std::vector iterstride(nop + 1, TShape(ndim_iter, 0)); + std::vector remainshape(nop), opstride(nop), remainstride(nop); + int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS]; + int *op_axes[NPY_MAXARGS]; + + for (iop = 0; iop < nop; ++iop) { + op_axes[iop] = op_axes_arrays[iop]; + CHECK_GE(prepare_op_axes(opshape[iop].ndim(), iop, op_labels[iop], + op_axes[iop], ndim_iter, iter_labels), 0); + for (idim = 0; idim < ndim_iter; idim++) { + if (op_axes[iop][idim] != -1) { + iterstride[iop][idim] = opstride_true[iop][op_axes[iop][idim]]; + if (itershape[idim] != -1) { + if (itershape[idim] == 1) { + itershape[idim] = opshape[iop][op_axes[iop][idim]]; + } + } else { + itershape[idim] = opshape[iop][op_axes[iop][idim]]; + } + } + } + } + for (idim = 0; idim < ndim_output; ++idim) { + iterstride[nop][idim] = ostride_true[idim]; + } + iterstride_true = get_stride(itershape); + reduceshape = TShape(ndim_iter - ndim_output, 0); + for (idim = ndim_output; idim < ndim_iter; ++idim) { + reduceshape[idim - ndim_output] = itershape[idim]; + } + for (iop = 0; iop < nop; iop++) { + std::vector rsh; + for (idim = 0; idim < ndim_iter; idim++) { + if (op_axes_arrays[iop][idim] == -1 || + itershape[idim] != opshape[iop][op_axes_arrays[iop][idim]]) { + rsh.push_back(itershape[idim]); + } + } + remainshape[iop] = TShape(rsh.begin(), rsh.end()); + } + + // calculate stride + ostride = TShape(ndim_output, 0); + for (idim = 0; idim < ndim_output; ++idim) { + ostride[idim] = iterstride_true[idim]; + } + reducestride = TShape(ndim_iter - ndim_output, 0); + for (idim = ndim_output; idim < ndim_iter; ++idim) { + reducestride[idim - ndim_output] = iterstride_true[idim]; + } + for (iop = 0; iop < nop; ++iop) { + opstride[iop] = TShape(opshape[iop].ndim(), 0); + remainstride[iop] = TShape(remainshape[iop].ndim(), 0); + int j = 0; + for (idim = 0; idim < ndim_iter; ++idim) { + if (op_axes_arrays[iop][idim] != -1 && + itershape[idim] == opshape[iop][op_axes_arrays[iop][idim]]) { + opstride[iop][op_axes_arrays[iop][idim]] = iterstride_true[idim]; + } else { + remainstride[iop][j++] = iterstride_true[idim]; + } + } + CHECK_EQ(j, remainstride[iop].ndim()); + } + + // exclude the 0-dim case + if (ndim_iter == 0) { + ndim_iter = 1; + } + itershape = pad(itershape, ndim_iter); + for (iop = 0; iop <= nop; ++iop) { + iterstride[iop] = pad(iterstride[iop], ndim_iter); + } + oshape = pad(oshape, ndim_iter); + ostride = pad(ostride, ndim_iter); + reduceshape = pad(reduceshape, ndim_iter); + reducestride = pad(reducestride, ndim_iter); + for (iop = 0; iop < nop; ++iop) { + opshape[iop] = pad(opshape[iop], ndim_iter); + opstride[iop] = pad(opstride[iop], ndim_iter); + remainshape[iop] = pad(remainshape[iop], ndim_iter); + remainstride[iop] = pad(remainstride[iop], ndim_iter); + } + + if (!back) { + if (oshape.Size() == 0) { + return; + } + const TBlob &out_data = outputs[0]; + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + mxnet::common::StaticArray op; + for (iop = 0; iop < nop; ++iop) { + op[iop] = inputs[iop].dptr(); + } + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + MXNET_NDIM_SWITCH_EX(ndim_iter, dimension, { + mxnet::common::StaticArray, NPY_MAXARGS> iterstride_arr; + for (iop = 0; iop <= nop; ++iop) { + iterstride_arr[iop] = iterstride[iop].get(); + } + Kernel, + xpu>::Launch(ctx.get_stream(), + oshape.Size(), + out_data.dptr(), + op, + oshape.get(), + ostride.get(), + reduceshape.get(), + 
reducestride.get(), + itershape.get(), + iterstride_arr, + nop, + -1, + reinterpret_cast(NULL)); + }) + }) + }) + } else { + if (oshape.Size() == 0) { + for (iop = 0; iop < nop; ++iop) { + const TBlob& out_data = outputs[iop]; + if (opshape[iop].Size() > 0) { + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[iop], req_type, { + if (req_type == kWriteTo) { + out_data.FlatTo1D(ctx.get_stream()) = 0; + } + }) + }) + } + } + return; + } + for (int i = 0; i < nop; ++i) { + const TBlob &out_data = outputs[i]; + const TBlob &out_grad = inputs[0]; + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + mxnet::common::StaticArray op; + for (iop = 0; iop < nop; ++iop) { + op[iop] = inputs[iop + back].dptr(); + } + MXNET_ASSIGN_REQ_SWITCH(req[i], req_type, { + MXNET_NDIM_SWITCH_EX(ndim_iter, dimension, { + mxnet::common::StaticArray, NPY_MAXARGS> iterstride_arr; + for (iop = 0; iop <= nop; ++iop) { + iterstride_arr[iop] = iterstride[iop].get(); + } + Kernel, + xpu>::Launch(ctx.get_stream(), + opshape[i].Size(), + out_data.dptr(), + op, + opshape[i].get(), + opstride[i].get(), + remainshape[i].get(), + remainstride[i].get(), + itershape.get(), + iterstride_arr, + nop, + i, + out_grad.dptr()); + }) + }) + }) + } + } +} + +template +inline void NumpyEinsumForward(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + EinsumOp& state = state_ptr.get_state(); + int num_args = state.num_args; + int optimize = state.optimize; + const char* subscripts = state.subscripts.c_str(); + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), num_args); + CHECK_EQ(outputs.size(), 1U); + if (optimize == 0) { + NumpyEinsumProcess(inputs, req, outputs, subscripts, num_args, ctx); + return; + } + std::vector& paths = state.paths; + std::vector > pos; + std::string string_repr; + paths = einsum_path(state.subscripts, inputs, true, ctx.run_ctx, &pos, &string_repr); + int paths_len = paths.size(), temp_space_size = 0, max_temp_space_size = 0; + std::vector operands(inputs), tmp_operands, temp_space_vec(paths_len - 1); + for (int i = 0; i + 1 < paths_len; ++i) { + temp_space_size += paths[i].oshape.Size(); + } + for (int i = 0; i < paths_len; ++i) { + max_temp_space_size = std::max(max_temp_space_size, static_cast(paths[i].oshape.Size())); + } + temp_space_size += max_temp_space_size; + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + state.tempspace.reset(new NDArray(TShape(Shape1(temp_space_size)), + ctx.run_ctx.ctx, + false, + outputs[0].type_flag_)); + Tensor temp_space = state.tempspace->data().FlatTo1D(); + int begin = max_temp_space_size; + for (int i = 0; i < paths_len - 1; ++i) { + TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size())); + temp_space_vec[i] = tblob.reshape(paths[i].oshape); + begin = begin + paths[i].oshape.Size(); + } + for (int i = 0; i < paths_len; ++i) { + tmp_operands.clear(); + + // We remove inds from right to left + for (const int& p : paths[i].contract_inds) { + tmp_operands.push_back(operands[p]); + operands.erase(operands.begin() + p); + } + bool handle_out = (i == paths_len - 1); + // Call tensordot if still possible + if (paths[i].do_blas) { + // Contract! 
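+ // The two operands popped above are contracted with TensordotImpl. If the step also needs an einsum-style remap (do_einsum) or must honor the caller's req/out (handle_out), the result is staged in max_temp_space and mapped back via blas2einsum_str; otherwise it is written directly into temp_space_vec[i].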
+ if (paths[i].do_einsum || handle_out) { + TBlob max_temp_space = TBlob(temp_space.Slice(0, paths[i].tshape.Size())); + max_temp_space.FlatTo1D(s) = 0; + max_temp_space = max_temp_space.reshape(paths[i].tshape); + size_t tensordot_tempspace_size = + TensordotWorkspaceSize(paths[i].left_pos, + paths[i].right_pos, + tmp_operands[0], + tmp_operands[1], + max_temp_space, + std::vector{OpReqType::kWriteTo}); + Tensor tensordot_tempspace = + ctx.requested[0].get_space_typed(Shape1(tensordot_tempspace_size), s); + TensordotImpl(paths[i].left_pos, + paths[i].right_pos, + ctx, + tmp_operands[0], + tmp_operands[1], + max_temp_space, + std::vector{OpReqType::kWriteTo}, + tensordot_tempspace); + NumpyEinsumProcess(std::vector{max_temp_space}, + handle_out ? req : std::vector{OpReqType::kWriteTo}, + handle_out ? outputs : std::vector{temp_space_vec[i]}, + paths[i].blas2einsum_str.c_str(), + 1, ctx); + } else { + size_t tensordot_tempspace_size = + TensordotWorkspaceSize(paths[i].left_pos, + paths[i].right_pos, + tmp_operands[0], + tmp_operands[1], + temp_space_vec[i], + std::vector{OpReqType::kWriteTo}); + Tensor tensordot_tempspace = ctx.requested[0].get_space_typed( + Shape1(tensordot_tempspace_size), s); + TensordotImpl(paths[i].left_pos, + paths[i].right_pos, + ctx, + tmp_operands[0], + tmp_operands[1], + temp_space_vec[i], + std::vector{OpReqType::kWriteTo}, + tensordot_tempspace); + } + } else { + NumpyEinsumProcess(tmp_operands, + handle_out ? req : std::vector{OpReqType::kWriteTo}, + handle_out ? outputs : std::vector{temp_space_vec[i]}, + paths[i].einsum_str.c_str(), tmp_operands.size(), ctx); + } + if (!handle_out) { + operands.push_back(temp_space_vec[i]); + } + } + }); +} + +template +inline void NumpyEinsumBackward(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow_op; + const EinsumOp& state = state_ptr.get_state(); + int num_args = state.num_args; + int optimize = state.optimize; + const char* subscripts = state.subscripts.c_str(); + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 1 + num_args); + CHECK_EQ(outputs.size(), num_args); + if (optimize == 0) { + NumpyEinsumProcess(inputs, req, outputs, subscripts, num_args, ctx); + return; + } + // calculate temporary space size for temp_grad + const std::vector& paths = state.paths; + int paths_len = paths.size(), temp_space_size = 0, max_temp_space_size = 0; + for (int i = 0; i < paths_len - 1; ++i) { + temp_space_size += paths[i].oshape.Size(); + } + for (int i = 0; i < paths_len; ++i) { + max_temp_space_size = std::max(max_temp_space_size, static_cast(paths[i].oshape.Size())); + } + temp_space_size += max_temp_space_size; + // replay the forward process + std::vector > op_idx(paths_len + 1); + for (int i = 0; i <= paths_len; ++i) { + if (i == 0) { + op_idx[i].reserve(num_args); + for (int j = 0; j < num_args; ++j) { + op_idx[i].push_back(j + 1); + } + } else { + op_idx[i] = op_idx[i - 1]; + // We remove inds from right to left + for (const int& p : paths[i - 1].contract_inds) { + op_idx[i].erase(op_idx[i].begin() + p); + } + op_idx[i].push_back(-static_cast(i - 1)); + } + } + // calculate temporary space size for tensordot + int tensordot_max_tempspace_size = 0; + int begin_tensordot_tempspace = 0; + std::vector temp_inputs, temp_outputs; + std::vector temp_req; + std::vector tensordot_tempspace_size; + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + for (int i = 0; i < paths_len; 
++i) { + temp_inputs.clear(); + temp_outputs.clear(); + temp_req.clear(); + bool handle_out = (i == paths_len - 1); + + if (handle_out) { + temp_inputs.push_back(inputs[0]); + } else { + temp_inputs.push_back(TBlob(reinterpret_cast(NULL), + paths[i].oshape, + xpu::kDevMask)); + } + for (auto p : paths[i].contract_inds) { + int idx = op_idx[i][p]; + if (idx >= 1) { + temp_inputs.push_back(inputs[idx]); + temp_outputs.push_back(outputs[idx - 1]); + temp_req.push_back(req[idx - 1]); + } else { + temp_inputs.push_back(TBlob(reinterpret_cast(NULL), + paths[-idx].oshape, + xpu::kDevMask)); + temp_outputs.push_back(TBlob(reinterpret_cast(NULL), + paths[-idx].oshape, + xpu::kDevMask)); + temp_req.push_back(OpReqType::kWriteTo); + } + } + size_t cur_tensordot_tempspace_size = 0; + if (paths[i].do_blas) { + if (paths[i].do_einsum) { + cur_tensordot_tempspace_size = + TensordotBackwardWorkspaceSize(paths[i].left_pos, + paths[i].right_pos, + TBlob(reinterpret_cast(NULL), + paths[i].tshape, + xpu::kDevMask), + temp_inputs[1], + temp_inputs[2], + temp_outputs[0], + temp_outputs[1], + temp_req); + } else { + cur_tensordot_tempspace_size = + TensordotBackwardWorkspaceSize(paths[i].left_pos, + paths[i].right_pos, + temp_inputs[0], + temp_inputs[1], + temp_inputs[2], + temp_outputs[0], + temp_outputs[1], + temp_req); + } + } + tensordot_tempspace_size.push_back(cur_tensordot_tempspace_size); + tensordot_max_tempspace_size = std::max(tensordot_max_tempspace_size, + static_cast(cur_tensordot_tempspace_size)); + } + begin_tensordot_tempspace = temp_space_size; + temp_space_size += (tensordot_max_tempspace_size + sizeof(DType) - 1) / sizeof(DType); + }); + // allocate temporary space and propagate + std::vector temp_grad(paths_len - 1), temp_data(paths_len - 1); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + // allocate temporary space for gradients of intermediate results + Tensor temp_space = ctx.requested[0].get_space_typed + (Shape1(temp_space_size), s); + int begin = max_temp_space_size; + for (int i = 0; i + 1 < paths_len; ++i) { + TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size())); + temp_grad[i] = tblob.reshape(paths[i].oshape); + begin = begin + paths[i].oshape.Size(); + } + + // reinterpret the ndarray holding intermediate results + Tensor ndarray_space = state.tempspace->data().FlatTo1D(); + begin = max_temp_space_size; + for (int i = 0; i + 1 < paths_len; ++i) { + TBlob tblob = TBlob(ndarray_space.Slice(begin, begin + paths[i].oshape.Size())); + temp_data[i] = tblob.reshape(paths[i].oshape); + begin = begin + paths[i].oshape.Size(); + } + + // go through the paths in reverse order + for (int i = paths_len - 1; i >= 0; i--) { + temp_inputs.clear(); + temp_outputs.clear(); + temp_req.clear(); + bool handle_out = (i == paths_len - 1); + + if (handle_out) { + temp_inputs.push_back(inputs[0]); + } else { + temp_inputs.push_back(temp_grad[i]); + } + for (auto p : paths[i].contract_inds) { + int idx = op_idx[i][p]; + if (idx >= 1) { + temp_inputs.push_back(inputs[idx]); + temp_outputs.push_back(outputs[idx - 1]); + temp_req.push_back(req[idx - 1]); + } else { + temp_inputs.push_back(temp_data[-idx]); + temp_outputs.push_back(temp_grad[-idx]); + temp_req.push_back(OpReqType::kWriteTo); + } + } + if (paths[i].do_blas) { + CHECK_EQ(temp_inputs.size(), 3U); + CHECK_EQ(temp_outputs.size(), 2U); + CHECK_EQ(temp_req.size(), 2U); + Tensor tensordot_tempspace = temp_space.Slice(begin_tensordot_tempspace, + temp_space_size); + Tensor char_tempspace = +
Tensor(reinterpret_cast(tensordot_tempspace.dptr_), + Shape1(tensordot_tempspace_size[i]), + tensordot_tempspace.stream_); + if (paths[i].do_einsum) { + TBlob max_temp_space = TBlob(temp_space.Slice(0, paths[i].tshape.Size())); + max_temp_space = max_temp_space.reshape(paths[i].tshape); + NumpyEinsumProcess(std::vector{temp_inputs[0]}, + std::vector{kWriteTo}, + std::vector{max_temp_space}, + paths[i].einsum2blas_str.c_str(), + 1, ctx); + TensordotBackwardImpl(paths[i].left_pos, paths[i].right_pos, ctx, + max_temp_space, temp_inputs[1], temp_inputs[2], + temp_outputs[0], temp_outputs[1], temp_req, char_tempspace); + } else { + TensordotBackwardImpl(paths[i].left_pos, paths[i].right_pos, ctx, + temp_inputs[0], temp_inputs[1], temp_inputs[2], + temp_outputs[0], temp_outputs[1], temp_req, char_tempspace); + } + } else { + NumpyEinsumProcess(temp_inputs, temp_req, temp_outputs, + paths[i].einsum_str.c_str(), + temp_outputs.size(), + ctx); + } + } + }); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_EINSUM_OP_INL_H_ diff --git a/src/operator/numpy/np_einsum_op.cc b/src/operator/numpy/np_einsum_op.cc new file mode 100644 index 000000000000..4d232b9b7c04 --- /dev/null +++ b/src/operator/numpy/np_einsum_op.cc @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright (c) 2005-2019, NumPy Developers. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * * Neither the name of the NumPy Developers nor the names of any + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/*! + * \file np_einsum_op.cc + * \brief CPU Implementation of numpy-compatible einsum + */ + +#include "./np_einsum_op-inl.h" +#include +#include + +namespace mxnet { +namespace op { + +inline std::vector _parse_einsum_input(std::string subscripts, + const mxnet::ShapeVector& shapes) { + const std::string einsum_symbols = + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + std::bitset einsum_symbols_set; + for (const char& c : einsum_symbols) { + einsum_symbols_set.set(c); + } + + CHECK_NE(shapes.size(), 0U) + << "No input operands"; + + auto end_pos = std::remove(subscripts.begin(), subscripts.end(), ' '); + subscripts.erase(end_pos, subscripts.end()); + + // Ensure all characters are valid + for (const char& c : subscripts) { + if (c == '.' || c == ',' || c == '-' || c == '>') { + continue; + } + CHECK(einsum_symbols_set.test(c)) + << "Character " << c + << " is not a valid symbol."; + } + + // Check for proper "->" + if (subscripts.find('-') != std::string::npos || + subscripts.find('>') != std::string::npos) { + bool invalid = (std::count(subscripts.begin(), subscripts.end(), '-') > 1 || + std::count(subscripts.begin(), subscripts.end(), '>') > 1); + CHECK(!invalid && _count_substring(subscripts, "->") == 1) + << "Subscripts can only contain one '->'."; + } + + // Parse ellipses + if (subscripts.find('.') != std::string::npos) { + std::string used = subscripts; + used.erase(std::remove_if(used.begin(), + used.end(), + [](const char& c){return c == '.' 
|| + c == ',' || + c == '-' || + c == '>';}), + used.end()); + + std::bitset used_set = str2set(used); + std::string ellipse_inds = ""; + for (const char& c : einsum_symbols) { + if (!used_set.test(static_cast(c))) { + ellipse_inds.append(1, c); + } + } + int longest = 0; + std::string input_tmp, output_sub; + std::vector split_subscripts; + bool out_sub; + + if (subscripts.find("->") != std::string::npos) { + std::vector tmp = split(subscripts, "->"); + input_tmp = tmp[0]; + output_sub = tmp[1]; + split_subscripts = split(input_tmp, ","); + out_sub = true; + } else { + split_subscripts = split(subscripts, ","); + out_sub = false; + } + + size_t size_split_subscripts = split_subscripts.size(); + subscripts = ""; + for (size_t i = 0; i < size_split_subscripts; ++i) { + const std::string& sub = split_subscripts[i]; + if (sub.find('.') != std::string::npos) { + CHECK_EQ(std::count(sub.begin(), sub.end(), '.'), 3) + << "Invalid Ellipses"; + CHECK_EQ(_count_substring(sub, "..."), 1) + << "Invalid Ellipses"; + + // Take into account numerical values + int ellipse_count = 0; + if (shapes[i].ndim() == 0) { + ellipse_count = 0; + } else { + ellipse_count = std::max(shapes[i].ndim(), 1); + ellipse_count -= sub.length() - 3; + } + + if (ellipse_count > longest) { + longest = ellipse_count; + } + + CHECK_GE(ellipse_count, 0) + << "Ellipses lengths do not match."; + if (ellipse_count == 0) { + split_subscripts[i].erase(sub.find("..."), 3); + } else { + std::string rep_inds = ellipse_inds.substr(ellipse_inds.length() - ellipse_count); + split_subscripts[i].replace(sub.find("..."), 3, rep_inds); + } + } + subscripts += split_subscripts[i]; + if (i + 1 < size_split_subscripts) { + subscripts += ","; + } + } + std::string out_ellipse; + if (longest == 0) { + out_ellipse = ""; + } else { + out_ellipse = ellipse_inds.substr(ellipse_inds.length() - longest); + } + + if (out_sub) { + output_sub.replace(output_sub.find("..."), 3, out_ellipse); + subscripts += "->" + output_sub; + } else { + // Special care for outputless ellipses + std::bitset out_ellipse_set = str2set(out_ellipse); + std::string tmp_subscripts = subscripts, output_subscript = ""; + size_t len_tmp_subscripts = tmp_subscripts.length(); + std::sort(tmp_subscripts.begin(), tmp_subscripts.end()); + for (size_t i = 0; i < len_tmp_subscripts; ++i) { + const char& c = tmp_subscripts[i]; + if (c == ',') { + continue; + } + CHECK(einsum_symbols_set.test(c)) + << "Character " << c + << " is not a valid symbol."; + if ((i == 0 || tmp_subscripts[i - 1] != c) && + (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c) && + !out_ellipse_set.test(c)) { + output_subscript.append(1, c); + } + } + subscripts += "->" + out_ellipse + output_subscript; + } + } + + // Build output string if does not exist + std::vector ret(2); + if (subscripts.find("->") != std::string::npos) { + ret = split(subscripts, "->"); + } else { + ret[0] = subscripts; + ret[1] = ""; + // Build output subscripts + std::string tmp_subscripts = subscripts; + size_t len_tmp_subscripts = tmp_subscripts.length(); + std::sort(tmp_subscripts.begin(), tmp_subscripts.end()); + for (size_t i = 0; i < len_tmp_subscripts; ++i) { + const char& c = tmp_subscripts[i]; + if (c == ',') { + continue; + } + CHECK(einsum_symbols_set.test(c)) + << "Character " << c + << " is not a valid symbol."; + if ((i == 0 || tmp_subscripts[i - 1] != c) && + (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c)) { + ret[1].append(1, c); + } + } + } + + // Make sure output subscripts are in the input + std::bitset 
input_subscripts_set = str2set(ret[0]); + for (const char& c : ret[1]) { + CHECK(input_subscripts_set.test(c)) + << "Output character " << c + << " did not appear in the input"; + } + + // Make sure the number of operands matches the number of terms + CHECK_EQ(std::count(ret[0].begin(), ret[0].end(), ',') + 1, shapes.size()) + << "Number of einsum subscripts must be equal to the " + << "number of operands."; + + return ret; +} + + +bool NumpyEinsumShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + const NumpyEinsumParam &param = nnvm::get(attrs.parsed); + const std::string& subscripts = param.subscripts; + int num_args = param.num_args; + CHECK_EQ(in_attrs->size(), num_args); + CHECK_EQ(out_attrs->size(), 1U); + for (int i = 0; i < num_args; i++) { + if (!shape_is_known(in_attrs->at(i))) { + return false; + } + } + + // Parsing + std::vector parsed_subscripts = _parse_einsum_input(subscripts, *in_attrs); + + // Build a few useful lists and sets + std::vector input_list = split(parsed_subscripts[0], ","); + size_t isize = input_list.size(); + + // Get length of each unique dimension and ensure all dimensions are correct + dim_t dimension_dict[MAXAXIS]; + memset(dimension_dict, -1, sizeof(dimension_dict)); + for (size_t i = 0; i < isize; ++i) { + const std::string& term = input_list[i]; + const TShape& sh = in_attrs->at(i); + CHECK_EQ(sh.ndim(), term.length()) + << "Einstein sum subscript " << input_list[i] + << " does not contain the " + << "correct number of indices for operand " << i << "."; + size_t len_term = term.length(); + for (size_t j = 0; j < len_term; ++j) { + dim_t dim = sh[j]; + const char& c = term[j]; + + if (dimension_dict[static_cast(c)] != -1) { + // For broadcasting cases we always want the largest dim size + if (dimension_dict[static_cast(c)] == 1) { + dimension_dict[static_cast(c)] = dim; + } + CHECK(dim == 1 || dim == dimension_dict[static_cast(c)]) + << "Size of label '" << c + << "' for operand " << i + << " (" << dimension_dict[static_cast(c)] + << ") does not match previous terms (" + << dim << ")."; + } else { + dimension_dict[static_cast(c)] = dim; + } + } + } + + // Get oshape + const std::string& output_str = parsed_subscripts[1]; + size_t odim = output_str.size(); + TShape oshape(odim, -1); + for (size_t i = 0; i < odim; ++i) { + oshape[i] = dimension_dict[static_cast(output_str[i])]; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); + return shape_is_known(oshape); +} + +OpStatePtr CreateEinsumState(const NodeAttrs& attrs, + Context ctx, + const mxnet::ShapeVector& in_shapes, + const std::vector& in_types) { + const NumpyEinsumParam& param = dmlc::get(attrs.parsed); + return OpStatePtr::Create(param.num_args, param.optimize, param.subscripts); +} + +DMLC_REGISTER_PARAMETER(NumpyEinsumParam); + +NNVM_REGISTER_OP(_npi_einsum) +.describe(R"doc(Evaluates the Einstein summation convention on the operands.)doc" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const NumpyEinsumParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_args); +}) +.set_num_outputs(1) +.set_attr("key_var_num_args", "num_args") +.set_attr("FListInputNames", + [](const nnvm::NodeAttrs& attrs) { + int num_args = dmlc::get(attrs.parsed).num_args; + std::vector ret; + for (int i = 0; i < num_args; i++) { + ret.push_back(std::string("arg") + std::to_string(i)); + } + return ret; +}) +.set_attr("FInferShape", NumpyEinsumShape) +.set_attr("FInferType", ElemwiseType<-1, 1>) +.set_attr("FCreateOpState", CreateEinsumState)
+.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector(1, ResourceRequest::kTempSpace); + }) +.set_attr("FStatefulCompute", NumpyEinsumForward) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_einsum"}) +.add_argument("data", "NDArray-or-Symbol[]", "List of einsum operands") +.add_arguments(NumpyEinsumParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_npi_einsum) +.set_attr_parser(ParamParser) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const NumpyEinsumParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_args + 1); +}) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const NumpyEinsumParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_args); +}) +.set_attr("TIsLayerOpBackward", true) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector(1, ResourceRequest::kTempSpace); + }) +.set_attr("FStatefulCompute", NumpyEinsumBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_einsum_op.cu b/src/operator/numpy/np_einsum_op.cu new file mode 100644 index 000000000000..1f76f2436436 --- /dev/null +++ b/src/operator/numpy/np_einsum_op.cu @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_einsum_op.cu + * \brief GPU Implementation of numpy-compatible einsum + */ + +#include "./np_einsum_op-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_einsum) +.set_attr("FStatefulCompute", NumpyEinsumForward); +NNVM_REGISTER_OP(_backward_npi_einsum) +.set_attr("FStatefulCompute", NumpyEinsumBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_einsum_path_op-inl.h b/src/operator/numpy/np_einsum_path_op-inl.h new file mode 100644 index 000000000000..cebd4e8ce9af --- /dev/null +++ b/src/operator/numpy/np_einsum_path_op-inl.h @@ -0,0 +1,964 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +/* + * Copyright (c) 2005-2019, NumPy Developers. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * * Neither the name of the NumPy Developers nor the names of any + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/*! + * \file np_einsum_path_op-inl.h + * \brief Function definition of numpy-compatible einsum_path operator + */ + +#ifndef MXNET_OPERATOR_NUMPY_NP_EINSUM_PATH_OP_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_EINSUM_PATH_OP_INL_H_ + +#include +#include +#include +#include +#include +#include + +namespace mxnet { +namespace op { + +const int MAXAXIS = 128; + +typedef std::vector > SetVector; + +struct Contraction { + std::bitset new_result; + std::vector > remaining; + std::bitset idx_removed; + std::bitset idx_contract; +}; + +struct Alternative { + int cost[2]; + std::vector positions; + SetVector new_input_sets; +}; + +struct Step { + std::vector contract_inds; + std::bitset idx_removed; + std::string einsum_str, blas2einsum_str, einsum2blas_str; + std::vector input_list; + bool do_blas, do_einsum; + TShape oshape, tshape; + Tuple left_pos, right_pos; +}; + +inline size_t _compute_size_by_dict(const std::string& indices, + const dim_t idx_dict[]) { + size_t ret = 1; + for (const char& c : indices) { + ret *= idx_dict[static_cast(c)]; + } + return ret; +} + +inline size_t _compute_size_by_dict(const std::bitset& indices, + const dim_t idx_dict[]) { + size_t ret = 1; + for (int i = 0; i < MAXAXIS; ++i) { + if (indices[i]) { + ret *= idx_dict[i]; + } + } + return ret; +} + +inline int _flop_count(const std::string& idx_contraction, + bool inner, + int num_terms, + const dim_t size_dictionary[]) { + size_t overall_size = _compute_size_by_dict(idx_contraction, size_dictionary); + int op_factor = std::max(1, num_terms - 1); + if (inner) { + ++op_factor; + } + return overall_size * op_factor; +} + +inline int _flop_count(const std::bitset& idx_contraction, + bool inner, + int num_terms, + const dim_t size_dictionary[]) { + size_t overall_size = _compute_size_by_dict(idx_contraction, size_dictionary); + int op_factor = std::max(1, num_terms - 1); + if (inner) { + ++op_factor; + } + return overall_size * 
op_factor; +} + +inline Contraction _find_contraction(const std::vector& positions, + const SetVector& input_sets, + const std::bitset& output_set) { + Contraction ret; + std::bitset idx_remain(output_set); + size_t size = input_sets.size(); + for (size_t i = 0; i < size; ++i) { + if (std::find(positions.begin(), positions.end(), i) != positions.end()) { + ret.idx_contract |= input_sets[i]; + } else { + ret.remaining.push_back(input_sets[i]); + idx_remain |= input_sets[i]; + } + } + ret.new_result = idx_remain & ret.idx_contract; + ret.idx_removed = (ret.idx_contract & ~ret.new_result); + ret.remaining.push_back(ret.new_result); + + return ret; +} + +inline int _parse_possible_contraction(const std::vector& positions, + const SetVector& input_sets, + const std::bitset& output_set, + const dim_t idx_dict[], + int memory_limit, + int path_cost, + int naive_cost, + Alternative* ret) { + // Find the contraction + Contraction contract = _find_contraction(positions, input_sets, output_set); + + // Sieve the results based on memory_limit + size_t new_size = _compute_size_by_dict(contract.new_result, idx_dict); + if (new_size > static_cast(memory_limit)) { + return -1; + } + + // Build sort tuple + size_t old_sizes = 0; + for (auto p : positions) { + old_sizes += _compute_size_by_dict(input_sets[p], idx_dict); + } + int remove_size = old_sizes - new_size; + + int cost = _flop_count(contract.idx_contract, contract.idx_removed.any(), + positions.size(), idx_dict); + ret->cost[0] = -remove_size; + ret->cost[1] = cost; + + // Sieve based on total cost as well + if (path_cost + cost > naive_cost) { + return -1; + } + + // Add contraction to possible choices + ret->positions = positions; + ret->new_input_sets = contract.remaining; + return 0; +} + +inline void _update_other_results(std::vector* results, + const Alternative& best) { + const std::vector& best_con = best.positions; + int bx = best_con[0], by = best_con[1]; + size_t size = results->size(); + + for (int i = size - 1; i >= 0; --i) { + int x = results->at(i).positions[0], y = results->at(i).positions[1]; + + // Ignore results involving tensors just contracted + if (x == bx || x == by || y == bx || y == by) { + results->erase(results->begin() + i); + continue; + } + + // Update the input_sets + CHECK_GT(by, bx) + << "by must be greater than bx"; + results->at(i).new_input_sets.erase(results->at(i).new_input_sets.begin() + + by - static_cast(by > x) - static_cast(by > y)); + results->at(i).new_input_sets.erase(results->at(i).new_input_sets.begin() + + bx - static_cast(bx > x) - static_cast(bx > y)); + results->at(i).new_input_sets.push_back(best.new_input_sets.back()); + + // Update the position indices + results->at(i).positions[0] = x - static_cast(x > bx) - static_cast(x > by); + results->at(i).positions[1] = y - static_cast(y > bx) - static_cast(y > by); + } +} + +inline std::vector > _greedy_path(const SetVector* input_sets, + const std::bitset& output_set, + const dim_t idx_dict[], + int memory_limit) { + size_t isize = input_sets->size(); + size_t iteration_num = isize; + // Handle trivial cases that leaked through + if (isize == 1) { + return std::vector >{std::vector{0}}; + } else if (isize == 2) { + return std::vector >{std::vector{0, 1}}; + } + + // Build up a naive cost + std::vector range(isize); + for (size_t i = 0; i < isize; ++i) { + range[i] = i; + } + Contraction contract = _find_contraction(range, *input_sets, output_set); + int naive_cost = _flop_count(contract.idx_contract, contract.idx_removed.any(), + isize, idx_dict); + + 
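The greedy search that follows ranks every candidate pairwise contraction by how much intermediate size it removes, breaking ties by FLOP cost, and discards candidates that cannot beat the naive all-at-once cost computed above. Since this header is a port of NumPy's einsum path machinery, the same heuristic can be previewed from Python (operand sizes below are arbitrary, and the printed path is what this function hands back through ret_path):

import numpy as np

a = np.random.rand(10, 15)
b = np.random.rand(15, 20)
c = np.random.rand(20, 5)
path, report = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
print(path)    # likely ['einsum_path', (1, 2), (0, 1)]: contracting b,c first removes more intermediate size
print(report)  # per-step contractions and FLOP counts, analogous to string_repr here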
// Initially iterate over all pairs + std::vector known_contractions; + Alternative best; + int path_cost = 0; + std::vector > ret; + + for (size_t iteration = 0; iteration + 1 < iteration_num; ++iteration) { + if (iteration == 0) { + for (int x = 0; x < static_cast(isize); ++x) { + for (int y = x + 1; y < static_cast(isize); ++y) { + if (!((input_sets->at(x) & input_sets->at(y)).any())) { + continue; + } + Alternative alternative; + int result = _parse_possible_contraction(std::vector{x, y}, + *input_sets, + output_set, + idx_dict, + memory_limit, + path_cost, + naive_cost, + &alternative); + if (result != -1) { + known_contractions.push_back(alternative); + } + } + } + } else { + for (int x = 0; x < static_cast(isize) - 1; ++x) { + int y = isize - 1; + if (!((input_sets->at(x) & input_sets->at(y)).any())) { + continue; + } + Alternative alternative; + int result = _parse_possible_contraction(std::vector{x, y}, + *input_sets, + output_set, + idx_dict, + memory_limit, + path_cost, + naive_cost, + &alternative); + if (result != -1) { + known_contractions.push_back(alternative); + } + } + } + + // If we do not have an inner contraction, rescan pairs including outer products + if (known_contractions.size() == 0) { + // Then check the outer products + for (int x = 0; x < static_cast(isize); ++x) { + for (int y = x + 1; y < static_cast(isize); ++y) { + Alternative alternative; + int result = _parse_possible_contraction(std::vector{x, y}, + *input_sets, + output_set, + idx_dict, + memory_limit, + path_cost, + naive_cost, + &alternative); + if (result != -1) { + known_contractions.push_back(alternative); + } + } + } + + // If we still did not find any remaining contractions, default back to einsum-like behavior + if (known_contractions.size() == 0) { + std::vector range(isize); + for (size_t i = 0; i < isize; ++i) { + range[i] = i; + } + ret.push_back(range); + break; + } + } + + // Pick the candidate with the smallest (cost[0], cost[1]) pair + int best_cost[2], idx = -1; + size_t size = known_contractions.size(); + for (size_t i = 0; i < size; ++i) { + auto x = known_contractions[i]; + if (idx == -1) { + best_cost[0] = x.cost[0]; + best_cost[1] = x.cost[1]; + idx = i; + } else if (x.cost[0] < best_cost[0] || + (x.cost[0] == best_cost[0] && + x.cost[1] < best_cost[1])) { + best_cost[0] = x.cost[0]; + best_cost[1] = x.cost[1]; + idx = i; + } + } + best = known_contractions[idx]; + + // Now propagate as many unused contractions as possible to next iteration + _update_other_results(&known_contractions, best); + + // Next iteration computes only contractions with the new tensor + // All other contractions have been accounted for + input_sets = &best.new_input_sets; + isize = input_sets->size(); + + // Update path and total cost + ret.push_back(best.positions); + path_cost += best.cost[1]; + } + return ret; +} + +inline bool _can_dot(const std::vector& inputs, + const std::bitset& result, + const std::bitset& idx_removed) { + // All `dot` calls remove indices + if (!idx_removed.any()) { + return false; + } + + // BLAS can only handle two operands + if (inputs.size() != 2) { + return false; + } + + const std::string& input_left = inputs[0]; + const std::string& input_right = inputs[1]; + + if (input_left.size() == 0 || input_right.size() == 0) { + return false; + } + + for (int i = 0; i < 2; ++i) { + for (const char& c : inputs[i]) { + // can't deal with repeated indices on same input or more than 2 total + size_t nl = std::count(input_left.begin(), input_left.end(), c); + size_t nr = std::count(input_right.begin(), input_right.end(),
c); + if (nl > 1 || nr > 1 || nl + nr > 2) { + return false; + } + + // can't do implicit summation or dimension collapse e.g. + // "ab,bc->c" (implicitly sum over 'a') + // "ab,ca->ca" (take diagonal of 'a') + if (nl + nr == static_cast(result.test(c)) + 1) { + return false; + } + } + } + + // Build a few temporaries + std::bitset set_left; + std::bitset set_right; + for (const char& c : input_left) { + set_left.set(c); + } + for (const char& c : input_right) { + set_right.set(c); + } + std::bitset keep_left = set_left & ~idx_removed; + std::bitset keep_right = set_right & ~idx_removed; + size_t rs = idx_removed.count(); + + // At this point we are a DOT, GEMV, or GEMM operation + + // Handle inner products + + // DDOT with aligned data + if (input_left == input_right) + return true; + + // DDOT without aligned data (better to use einsum) + if (set_left == set_right) + return false; + + // Handle the 4 possible (aligned) GEMV or GEMM cases + + // GEMM or GEMV no transpose + if (std::equal(input_left.end() - rs, + input_left.end(), + input_right.begin())) { + return true; + } + + // GEMM or GEMV transpose both + if (std::equal(input_left.begin(), + input_left.begin() + rs, + input_right.end() - rs)) { + return true; + } + + // GEMM or GEMV transpose right + if (std::equal(input_left.end() - rs, + input_left.end(), + input_right.end() - rs)) { + return true; + } + + // GEMM or GEMV transpose left + if (std::equal(input_left.begin(), + input_left.begin() + rs, + input_right.begin())) { + return true; + } + + // Einsum is faster than GEMV if we have to copy data + if (!keep_left.any() || !keep_right.any()) { + return false; + } + + // We are a matrix-matrix product, but we need to copy data + return true; +} + + +inline int _count_substring(const std::string& str, + const std::string& sub) { + int count = 0; + std::string::size_type pos = 0; + while ((pos = str.find(sub, pos)) != std::string::npos) { + ++count; + pos += sub.length(); + } + return count; +} + +inline std::bitset str2set(const std::string& str) { + std::bitset ret; + for (const char& c : str) { + ret.set(static_cast(c)); + } + return ret; +} + +inline std::string set2str(const std::bitset& set) { + std::string ret; + for (int i = 0; i < MAXAXIS; ++i) { + if (set.test(i)) { + ret.append(1, static_cast(i)); + } + } + return ret; +} + +inline std::vector split(const std::string& str, + const std::string& sub) { + std::string::size_type pos = 0; + std::string::size_type start = 0; + std::vector ret; + while ((pos = str.find(sub, start)) != std::string::npos) { + ret.push_back(str.substr(start, pos - start)); + start = pos + sub.length(); + } + ret.push_back(str.substr(start)); + return ret; +} + +inline std::vector _parse_einsum_input( + std::string subscripts, + const std::vector& operands) { + const std::string einsum_symbols = + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + std::bitset einsum_symbols_set; + for (const char& c : einsum_symbols) { + einsum_symbols_set.set(c); + } + + CHECK_NE(operands.size(), 0U) + << "No input operands"; + + auto end_pos = std::remove(subscripts.begin(), subscripts.end(), ' '); + subscripts.erase(end_pos, subscripts.end()); + + // Ensure all characters are valid + for (const char& c : subscripts) { + if (c == '.' 
|| c == ',' || c == '-' || c == '>') { + continue; + } + CHECK(einsum_symbols_set.test(c)) + << "Character " << c + << " is not a valid symbol."; + } + + // Check for proper "->" + if (subscripts.find('-') != std::string::npos || + subscripts.find('>') != std::string::npos) { + bool invalid = (std::count(subscripts.begin(), subscripts.end(), '-') > 1 || + std::count(subscripts.begin(), subscripts.end(), '>') > 1); + CHECK(!invalid && _count_substring(subscripts, "->") == 1) + << "Subscripts can only contain one '->'."; + } + + // Parse ellipses + if (subscripts.find('.') != std::string::npos) { + std::string used = subscripts; + used.erase(std::remove_if(used.begin(), + used.end(), + [](const char& c){return c == '.' || + c == ',' || + c == '-' || + c == '>';}), + used.end()); + + std::bitset used_set = str2set(used); + std::string ellipse_inds = ""; + for (const char& c : einsum_symbols) { + if (!used_set.test(static_cast(c))) { + ellipse_inds.append(1, c); + } + } + int longest = 0; + std::string input_tmp, output_sub; + std::vector split_subscripts; + bool out_sub; + + if (subscripts.find("->") != std::string::npos) { + std::vector tmp = split(subscripts, "->"); + input_tmp = tmp[0]; + output_sub = tmp[1]; + split_subscripts = split(input_tmp, ","); + out_sub = true; + } else { + split_subscripts = split(subscripts, ","); + out_sub = false; + } + + size_t size_split_subscripts = split_subscripts.size(); + subscripts = ""; + for (size_t i = 0; i < size_split_subscripts; ++i) { + const std::string& sub = split_subscripts[i]; + if (sub.find('.') != std::string::npos) { + CHECK_EQ(std::count(sub.begin(), sub.end(), '.'), 3) + << "Invalid Ellipses"; + CHECK_EQ(_count_substring(sub, "..."), 1) + << "Invalid Ellipses"; + + // Take into account numerical values + int ellipse_count = 0; + if (operands[i].shape_.ndim() == 0) { + ellipse_count = 0; + } else { + ellipse_count = std::max(operands[i].shape_.ndim(), 1); + ellipse_count -= sub.length() - 3; + } + + if (ellipse_count > longest) { + longest = ellipse_count; + } + + CHECK_GE(ellipse_count, 0) + << "Ellipses lengths do not match."; + if (ellipse_count == 0) { + split_subscripts[i].erase(sub.find("..."), 3); + } else { + std::string rep_inds = ellipse_inds.substr(ellipse_inds.length() - ellipse_count); + split_subscripts[i].replace(sub.find("..."), 3, rep_inds); + } + } + subscripts += split_subscripts[i]; + if (i + 1 < size_split_subscripts) { + subscripts += ","; + } + } + std::string out_ellipse; + if (longest == 0) { + out_ellipse = ""; + } else { + out_ellipse = ellipse_inds.substr(ellipse_inds.length() - longest); + } + + if (out_sub) { + output_sub.replace(output_sub.find("..."), 3, out_ellipse); + subscripts += "->" + output_sub; + } else { + // Special care for outputless ellipses + std::bitset out_ellipse_set = str2set(out_ellipse); + std::string tmp_subscripts = subscripts, output_subscript = ""; + size_t len_tmp_subscripts = tmp_subscripts.length(); + std::sort(tmp_subscripts.begin(), tmp_subscripts.end()); + for (size_t i = 0; i < len_tmp_subscripts; ++i) { + const char& c = tmp_subscripts[i]; + if (c == ',') { + continue; + } + CHECK(einsum_symbols_set.test(c)) + << "Character " << c + << " is not a valid symbol."; + if ((i == 0 || tmp_subscripts[i - 1] != c) && + (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c) && + !out_ellipse_set.test(c)) { + output_subscript.append(1, c); + } + } + subscripts += "->" + out_ellipse + output_subscript; + } + } + + // Build output string if does not exist + std::vector ret(2); + 
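The block below reproduces NumPy's implicit-output rule: the output subscript consists of every label that appears exactly once across the inputs, in sorted order. A compact Python rendering of the same rule (the helper name is made up for illustration):

def implicit_output(input_subscripts):
    labels = input_subscripts.replace(',', '')
    return ''.join(sorted(c for c in set(labels) if labels.count(c) == 1))

assert implicit_output('ij,jk') == 'ik'   # matrix product: 'j' is summed away
assert implicit_output('ii') == ''        # trace: scalar output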
if (subscripts.find("->") != std::string::npos) { + ret = split(subscripts, "->"); + } else { + ret[0] = subscripts; + ret[1] = ""; + // Build output subscripts + std::string tmp_subscripts = subscripts; + size_t len_tmp_subscripts = tmp_subscripts.length(); + std::sort(tmp_subscripts.begin(), tmp_subscripts.end()); + for (size_t i = 0; i < len_tmp_subscripts; ++i) { + const char& c = tmp_subscripts[i]; + if (c == ',') { + continue; + } + CHECK(einsum_symbols_set.test(c)) + << "Character " << c + << " is not a valid symbol."; + if ((i == 0 || tmp_subscripts[i - 1] != c) && + (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c)) { + ret[1].append(1, c); + } + } + } + + // Make sure output subscripts are in the input + std::bitset input_subscripts_set = str2set(ret[0]); + for (const char& c : ret[1]) { + CHECK(input_subscripts_set.test(c)) + << "Output character " << c + << " did not appear in the input"; + } + + // Make sure the number of operands matches the number of terms + CHECK_EQ(std::count(ret[0].begin(), ret[0].end(), ',') + 1, operands.size()) + << "Number of einsum subscripts must be equal to the " + << "number of operands."; + + return ret; +} + +inline bool _tensordot_type_check(int type_flag_, const RunContext& run_ctx) { + return type_flag_ == kFloat32 || type_flag_ == kFloat64 || + (type_flag_ == kFloat16 && run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask); +} + +inline std::vector einsum_path(const std::string& subscripts, + const std::vector& operands, + bool optimize, + const RunContext& run_ctx, + std::vector >* ret_path, + std::string* ret_string_repr) { + // Parsing + std::vector parsed_subscripts = _parse_einsum_input(subscripts, operands); + + // Build a few useful lists and sets + std::vector input_list = split(parsed_subscripts[0], ","); + size_t isize = input_list.size(); + SetVector input_sets; + for (int i = 0; i < static_cast(isize); ++i) { + input_sets.push_back(str2set(input_list[i])); + } + std::bitset output_set = str2set(parsed_subscripts[1]); + std::bitset indices = str2set(parsed_subscripts[0]); + indices.set(',', false); + + // Get length of each unique dimension and ensure all dimensions are correct + dim_t dimension_dict[MAXAXIS]; + SetVector broadcast_indices(isize); + memset(dimension_dict, -1, sizeof(dimension_dict)); + for (size_t i = 0; i < isize; ++i) { + const std::string& term = input_list[i]; + const TShape& sh = operands[i].shape_; + CHECK_EQ(sh.ndim(), term.length()) + << "Einstein sum subscript " << input_list[i] + << " does not contain the " + << "correct number of indices for operand " << i << "."; + size_t len_term = term.length(); + for (size_t j = 0; j < len_term; ++j) { + dim_t dim = sh[j]; + const char& c = term[j]; + // Build out broadcast indices + if (dim == 1) { + broadcast_indices[i].set(c); + } + + if (dimension_dict[static_cast(c)] != -1) { + // For broadcasting cases we always want the largest dim size + if (dimension_dict[static_cast(c)] == 1) { + dimension_dict[static_cast(c)] = dim; + } + CHECK(dim == 1 || dim == dimension_dict[static_cast(c)]) + << "Size of label '" << c + << "' for operand " << i + << " (" << dimension_dict[static_cast(c)] + << ") does not match previous terms (" + << dim << ")."; + } else { + dimension_dict[static_cast(c)] = dim; + } + } + } + + // Compute size of each input array plus the output array + std::vector size_list(isize + 1); + size_t max_size = -1, memory_arg; + for (size_t i = 0; i < isize; ++i) { + size_list[i] = _compute_size_by_dict(input_list[i], dimension_dict); + max_size =
std::max(max_size, size_list[i]); + } + size_list[isize] = _compute_size_by_dict(parsed_subscripts[1], dimension_dict); + max_size = std::max(max_size, size_list[isize]); + memory_arg = max_size; + + // Compute naive cost + // This isn't quite right, need to look into exactly how einsum does this + size_t sum_len_input_sets = 0; + for (auto x : input_sets) { + sum_len_input_sets += x.count(); + } + bool inner_product = (sum_len_input_sets > indices.count()); + int naive_cost = _flop_count(indices, inner_product, isize, dimension_dict); + + // Compute the path + std::vector > path; + if (optimize == false) { + path.push_back(std::vector()); + for (size_t i = 0; i < isize; ++i) { + path[0].push_back(i); + } + } else { + path = _greedy_path(&input_sets, output_set, dimension_dict, memory_arg); + } + + std::vector cost_list; + std::vector scale_list; + int opt_cost = 1; + size_t max_i = 0, max_scale = 0, size_path = path.size(); + std::vector ret(size_path); + size_list.clear(); + + // Build contraction tuple (positions, gemm, einsum_str, remaining) + for (size_t i = 0; i < size_path; ++i) { + // Make sure we remove inds from right to left + std::vector contract_inds = path[i]; + std::sort(contract_inds.begin(), contract_inds.end(), std::greater()); + + Contraction contract = _find_contraction(contract_inds, input_sets, output_set); + input_sets = contract.remaining; + + int cost = _flop_count(contract.idx_contract, + contract.idx_removed.any(), + contract_inds.size(), + dimension_dict); + opt_cost += cost; + cost_list.push_back(cost); + scale_list.push_back(contract.idx_contract.count()); + size_list.push_back(_compute_size_by_dict(contract.new_result, dimension_dict)); + max_i = std::max(max_i, size_list.back()); + max_scale = std::max(max_scale, scale_list.back()); + + std::bitset bcast; + std::vector tmp_inputs; + for (const int& x : contract_inds) { + tmp_inputs.push_back(input_list[x]); + input_list.erase(input_list.begin() + x); + bcast |= broadcast_indices[x]; + broadcast_indices.erase(broadcast_indices.begin() + x); + } + + std::bitset new_bcast_inds = bcast & ~contract.idx_removed; + + // If we're broadcasting, nix blas + bool do_blas; + if ((contract.idx_removed & bcast).any() || + !_tensordot_type_check(operands[0].type_flag_, run_ctx)) { + do_blas = false; + } else { + do_blas = _can_dot(tmp_inputs, contract.new_result, contract.idx_removed); + } + + // Last contraction + std::string idx_result; + if (i + 1 == size_path) { + idx_result = parsed_subscripts[1]; + } else { + idx_result = set2str(contract.new_result); + std::sort(idx_result.begin(), idx_result.end(), + [&dimension_dict](const char& a, const char& b) -> bool { + return dimension_dict[static_cast(a)] < + dimension_dict[static_cast(b)] || + (dimension_dict[static_cast(a)] == + dimension_dict[static_cast(b)] && + a < b); + }); + } + size_t len_idx_result = idx_result.length(); + ret[i].oshape = TShape(len_idx_result, -1); + for (size_t j = 0; j < len_idx_result; ++j) { + ret[i].oshape[j] = dimension_dict[static_cast(idx_result[j])]; + } + + if (do_blas) { + CHECK_EQ(tmp_inputs.size(), 2U) + << "BLAS accepts exactly 2 inputs"; + std::string tensor_result = tmp_inputs[0] + tmp_inputs[1]; + tensor_result.erase(std::remove_if(tensor_result.begin(), + tensor_result.end(), + [&](const char& c) { + return contract.idx_removed.test(static_cast(c));}), + tensor_result.end()); + + // Find indices to contract over + std::vector left_pos, right_pos; + left_pos.reserve(MAXAXIS); + right_pos.reserve(MAXAXIS); + size_t tmp[MAXAXIS] = 
{0}; + size_t length_left_input = tmp_inputs[0].length(); + size_t length_right_input = tmp_inputs[1].length(); + for (size_t j = 0; j < length_right_input; ++j) { + if (contract.idx_removed.test(static_cast(tmp_inputs[1][j]))) { + tmp[static_cast(tmp_inputs[1][j])] = j; + } + } + for (size_t j = 0; j < length_left_input; ++j) { + if (contract.idx_removed.test(static_cast(tmp_inputs[0][j]))) { + left_pos.push_back(static_cast(j)); + right_pos.push_back(static_cast(tmp[static_cast(tmp_inputs[0][j])])); + } + } + // Calculate left_pos and right_pos + ret[i].left_pos = Tuple(left_pos); + ret[i].right_pos = Tuple(right_pos); + // Calculate do_einsum + ret[i].do_einsum = (tensor_result != idx_result); + // Calculate tshape + CHECK_EQ(tensor_result.length(), len_idx_result) + << "tensordot produces dim " << tensor_result.length() + << ", while einsum produces dim " << len_idx_result << "."; + ret[i].tshape = TShape(len_idx_result, -1); + for (size_t j = 0; j < len_idx_result; ++j) { + ret[i].tshape[j] = dimension_dict[static_cast(tensor_result[j])]; + } + // Calculate blas2einsum_str + ret[i].blas2einsum_str = tensor_result + "->" + idx_result; + ret[i].einsum2blas_str = idx_result + "->" + tensor_result; + } + input_list.push_back(idx_result); + broadcast_indices.push_back(new_bcast_inds); + size_t len_tmp_inputs = tmp_inputs.size(); + for (size_t j = 0; j < len_tmp_inputs; ++j) { + ret[i].einsum_str += tmp_inputs[j]; + if (j + 1 < len_tmp_inputs) { + ret[i].einsum_str += ","; + } + } + ret[i].einsum_str += "->" + idx_result; + ret[i].contract_inds = contract_inds; + ret[i].idx_removed = contract.idx_removed; + ret[i].input_list = input_list; + ret[i].do_blas = do_blas; + } + + if (ret_path == NULL || ret_string_repr == NULL) { + return ret; + } + + // Return the path along with a nice string representation + std::string overall_contraction = parsed_subscripts[0] + "->" + parsed_subscripts[1]; + std::string header[3] = {"scaling", "current", "remaining"}; + + double speedup = 1.0 * naive_cost / (1.0 * opt_cost); + std::ostringstream ss; + ss << " Complete contraction: " << overall_contraction << std::endl; + ss << " Naive scaling: " << indices.count() << std::endl; + ss << " Optimized scaling: " << max_scale << std::endl; + ss.precision(3); + ss << " Naive FLOP count: " << std::scientific << naive_cost << std::endl; + ss << " Optimized FLOP count: " << std::scientific << opt_cost << std::endl; + ss << " Theoretical speedup: " << std::scientific << speedup << std::endl; + ss << " Largest intermediate: " << std::scientific << max_i << " elements" << std::endl; + ss << std::string(74, '-') << std::endl; + ss << std::setw(6) << header[0] << " "; + ss << std::setw(24) << header[1] << " "; + ss << std::setw(40) << header[2] << std::endl; + ss << std::string(74, '-'); + + for (size_t i = 0; i < size_path; ++i) { + ss << std::endl; + ss << std::setw(4) << scale_list[i] << " "; + ss << std::setw(24) << ret[i].einsum_str << " "; + std::string remaining_str; + size_t len_input_list = ret[i].input_list.size(); + for (size_t j = 0; j < len_input_list; ++j) { + remaining_str += ret[i].input_list[j]; + if (j + 1 < len_input_list) { + remaining_str += ","; + } + } + remaining_str += "->" + parsed_subscripts[1]; + ss << std::setw(40) << remaining_str; + } + *ret_string_repr = ss.str(); + *ret_path = path; + return ret; +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_EINSUM_PATH_OP_INL_H_ diff --git a/src/operator/numpy/np_tensordot_op-inl.h
b/src/operator/numpy/np_tensordot_op-inl.h index da3891665c4b..1c62527db43f 100644 --- a/src/operator/numpy/np_tensordot_op-inl.h +++ b/src/operator/numpy/np_tensordot_op-inl.h @@ -140,7 +140,7 @@ inline mxnet::TShape GetReorderedShape(const mxnet::TShape& shape, const mxnet:: } /** - * gets matrix dot. Reshapes tensor a as ad1-by-ad2 matrix, tensor b as bd1-by-bd2 matrix, then + * gets matrix dot. Reshapes tensor a as ad1-by-ad2 matrix, tensor b as bd1-by-bd2 matrix, then * calculates matrix dot a * b and stores in tensor out. */ template @@ -205,7 +205,8 @@ void TensordotImpl(const Tuple& a_axes_summed, const TBlob& a, const TBlob& b, const TBlob& out, - const std::vector& req) { + const std::vector& req, + const Tensor& workspace) { if (req[0] == kNullOp) { return; } @@ -264,10 +265,8 @@ void TensordotImpl(const Tuple& a_axes_summed, mxnet::TShape a_temp_shape = GetReorderedShape(a_shape, a_axes); mxnet::TShape b_temp_shape = GetReorderedShape(b_shape, b_axes); - Tensor workspace = ctx.requested[0].get_space_typed - (Shape1(a.Size() + b.Size()), s); DType* a_ptr = reinterpret_cast(workspace.dptr_); - DType* b_ptr = reinterpret_cast(workspace.dptr_ + a.Size()); + DType* b_ptr = reinterpret_cast(workspace.dptr_ + a.Size() * sizeof(DType)); TBlob a_res = TBlob(a_ptr, a_temp_shape, xpu::kDevMask); TBlob b_res = TBlob(b_ptr, b_temp_shape, xpu::kDevMask); @@ -281,6 +280,46 @@ void TensordotImpl(const Tuple& a_axes_summed, }); } +/** + * Calculates workspace size of tensordot. + */ +template +size_t TensordotWorkspaceSize(const Tuple& a_axes_summed, + const Tuple& b_axes_summed, + const TBlob& a, + const TBlob& b, + const TBlob& out, + const std::vector& req) { + if (req[0] == kNullOp) { + return 0U; + } + + if (out.shape_.Size() == 0U) { + return 0U; // zero-size output, no need to launch kernel + } + + const mxnet::TShape& a_shape = a.shape_; + const mxnet::TShape& b_shape = b.shape_; + + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, { + if (a_shape.Size() == 0U || b_shape.Size() == 0U) { + // 0-size input + return 0U; + } else if (a_shape.ndim() == 0 && b_shape.ndim() == 0) { + // Both 0-D scalars, equivalent to multiply + return 0U; + } else if (a_shape.ndim() == 0 || b_shape.ndim() == 0) { + // Either of them is a scalar, just scale by one of them + return 0U; + } else { + // Two tensors of at least 1 dimensions. + return (a.Size() + b.Size()) * sizeof(DType); + } + }); + LOG(FATAL) << "InternalError: cannot reach here"; + return 0U; +} + /** * forward function */ @@ -306,7 +345,10 @@ void TensordotOpForward(const nnvm::NodeAttrs& attrs, ShiftAxes(&a_axes_summed, a_shape.ndim()); ShiftAxes(&b_axes_summed, b_shape.ndim()); - TensordotImpl(a_axes_summed, b_axes_summed, ctx, a, b, out, req); + size_t workspace_size = TensordotWorkspaceSize(a_axes_summed, b_axes_summed, a, b, out, req); + Tensor workspace = ctx.requested[0].get_space_typed( + Shape1(workspace_size), ctx.get_stream()); + TensordotImpl(a_axes_summed, b_axes_summed, ctx, a, b, out, req, workspace); } /** @@ -332,7 +374,8 @@ void TensordotBackwardImpl(const Tuple& a_axes_summed, const TBlob& b, const TBlob& grad_a, const TBlob& grad_b, - const std::vector& req) { + const std::vector& req, + const Tensor& workspace) { mshadow::Stream *s = ctx.get_stream(); const mxnet::TShape& a_shape = a.shape_; @@ -366,12 +409,15 @@ void TensordotBackwardImpl(const Tuple& a_axes_summed, const OpReqType& scalar_req = (a_shape.ndim() == 0) ? 
req[0] : req[1]; ASSIGN_DISPATCH(tensor_grad_, tensor_req, broadcast_scalar(scalar_, tensor_grad_.shape_) * out_grad_); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(out_grad.shape_.Size()), s); - ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * out_grad_); + Tensor dtypespace = + Tensor(reinterpret_cast(workspace.dptr_), + workspace.shape_, + workspace.stride_, + workspace.stream_); + ASSIGN_DISPATCH(dtypespace, kWriteTo, tensor_ * out_grad_); ReduceAxesComputeImpl( - ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); + ctx, {TBlob(dtypespace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_); } else { // Two tensors of at least 1 dimensions. Tuple a_axes_remained; @@ -405,12 +451,13 @@ void TensordotBackwardImpl(const Tuple& a_axes_summed, mxnet::TShape b_temp_shape(GetReorderedShape(b_shape, b_axes)); mxnet::TShape b_T_temp_shape(GetReorderedShape(b_shape, b_T_axes)); - Tensor workspace = ctx.requested[0].get_space_typed - (Shape1((a.Size() + b.Size()) * 2), s); DType* a_ptr = reinterpret_cast(workspace.dptr_); - DType* a_ptr2 = reinterpret_cast(workspace.dptr_ + a.Size()); - DType* b_ptr = reinterpret_cast(workspace.dptr_ + 2 * a.Size()); - DType* b_ptr2 = reinterpret_cast(workspace.dptr_ + 2 * a.Size() + b.Size()); + DType* a_ptr2 = reinterpret_cast(workspace.dptr_ + a.Size() * sizeof(DType)); + DType* b_ptr = reinterpret_cast(workspace.dptr_ + 2 * a.Size() * sizeof(DType)); + DType* b_ptr2 = reinterpret_cast(workspace.dptr_ + + (2 * a.Size() + + b.Size()) * + sizeof(DType)); TBlob a_res = TBlob(a_ptr, a_temp_shape, xpu::kDevMask); TBlob b_res = TBlob(b_ptr, b_temp_shape, xpu::kDevMask); @@ -431,6 +478,39 @@ void TensordotBackwardImpl(const Tuple& a_axes_summed, }); } +/** + * Calculates workspace size of tensordot backward. + */ +template +size_t TensordotBackwardWorkspaceSize(const Tuple& a_axes_summed, + const Tuple& b_axes_summed, + const TBlob& out_grad, + const TBlob& a, + const TBlob& b, + const TBlob& grad_a, + const TBlob& grad_b, + const std::vector& req) { + const mxnet::TShape& a_shape = a.shape_; + const mxnet::TShape& b_shape = b.shape_; + + if ((a_shape.Size() == 0U) || (b_shape.Size() == 0U)) { + return 0U; // zero-size output, no need to launch kernel + } + MSHADOW_REAL_TYPE_SWITCH(out_grad.type_flag_, DType, { + if (a_shape.ndim() == 0 && b_shape.ndim() == 0) { + // Both 0-D scalars, equivalent to multiply + return 0U; + } else if (a_shape.ndim() == 0 || b_shape.ndim() == 0) { + // Either of them is a scalar, just scale by one of them + return out_grad.shape_.Size() * sizeof(DType); + } else { + return (a.Size() + b.Size()) * 2 * sizeof(DType); + } + }); + LOG(FATAL) << "InternalError: cannot reach here"; + return 0U; +} + /** * backward function. 
 /**
  * backward function.
  */
@@ -458,8 +538,14 @@ void TensordotOpBackward(const nnvm::NodeAttrs& attrs,
   ShiftAxes(&a_axes_summed, a_shape.ndim());
   ShiftAxes(&b_axes_summed, b_shape.ndim());
 
+  size_t workspace_size = TensordotBackwardWorkspaceSize<xpu>(a_axes_summed, b_axes_summed,
+                                                              out_grad, a, b, grad_a,
+                                                              grad_b, req);
+  Tensor<xpu, 1, char> workspace =
+      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size),
+                                                     ctx.get_stream<xpu>());
   TensordotBackwardImpl<xpu>(a_axes_summed, b_axes_summed, ctx, out_grad, a, b, grad_a,
-                             grad_b, req);
+                             grad_b, req, workspace);
 }
 
 struct TensordotIntAxesParam : public dmlc::Parameter<TensordotIntAxesParam> {
diff --git a/src/operator/numpy/np_tensordot_op.cc b/src/operator/numpy/np_tensordot_op.cc
index 50c1647e0264..aca45c1652ee 100644
--- a/src/operator/numpy/np_tensordot_op.cc
+++ b/src/operator/numpy/np_tensordot_op.cc
@@ -111,7 +111,7 @@ NNVM_REGISTER_OP(_npi_tensordot)
 .set_attr<nnvm::FInferType>("FInferType", mxnet::op::ElemwiseType<2, 1>)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
-    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+    return std::vector<ResourceRequest>(1, ResourceRequest::kTempSpace);
   })
 .set_attr<FCompute>("FCompute", TensordotOpForward<cpu>)
 .set_attr<nnvm::FGradient>("FGradient", mxnet::op::ElemwiseGradUseIn{"_backward_npi_tensordot"})
@@ -126,7 +126,7 @@ NNVM_REGISTER_OP(_backward_npi_tensordot)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
-    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+    return std::vector<ResourceRequest>(1, ResourceRequest::kTempSpace);
   })
 .set_attr<FCompute>("FCompute", TensordotOpBackward<cpu>);
@@ -211,7 +211,7 @@ NNVM_REGISTER_OP(_npi_tensordot_int_axes)
 .set_attr<nnvm::FInferType>("FInferType", mxnet::op::ElemwiseType<2, 1>)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
-    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+    return std::vector<ResourceRequest>(1, ResourceRequest::kTempSpace);
   })
 .set_attr<FCompute>("FCompute", TensordotIntAxesOpForward<cpu>)
 .set_attr<nnvm::FGradient>("FGradient",
@@ -227,7 +227,7 @@ NNVM_REGISTER_OP(_backward_npi_tensordot_int_axes)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
-    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+    return std::vector<ResourceRequest>(1, ResourceRequest::kTempSpace);
   })
 .set_attr<FCompute>("FCompute", TensordotIntAxesOpBackward<cpu>);
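The registration changes above only swap brace-initialization of the one-element FResourceRequest vector for the equivalent fill constructor; the requested kTempSpace resource is what backs the workspace allocations in the forward and backward paths. A quick smoke check of the wired-up operator, assuming an MXNet build that contains this patch:

    import mxnet as mx
    from mxnet import np, npx
    npx.set_np()

    a = np.ones((2, 3))
    b = np.ones((3, 4))
    a.attach_grad()
    with mx.autograd.record():
        out = np.tensordot(a, b, axes=([1], [0]))  # forward draws on kTempSpace
    out.backward()                                 # backward requests its own workspace
    print(out.shape, a.grad.shape)                 # (2, 4) (2, 3)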
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index ac30f996ba18..f22d42bb678b 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -42,6 +42,78 @@ def get_workloads(name):
     return OpArgMngr._args.get(name, None)
 
 
+def _add_workload_einsum():
+    chars = 'abcdefghij'
+    sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4]
+    size_dict = dict(zip(chars, sizes))
+
+    configs = [
+        ('ij...,j...->ij...', [(2, 3, 4), (3,)]),
+        ('ij...,...j->ij...', [(2, 3, 4), (3,)]),
+        ('ij...,j->ij...', [(2, 3, 4), (3,)]),
+        ('cl, cpx->lpx', [(2, 3), (2, 3, 2731)]),
+        ('aabb->ab', [(5, 5, 5, 5)]),
+        ('mi,mi,mi->m', [(5, 5), (5, 5), (5, 5)]),
+        ('a,ab,abc->abc', None),
+        ('a,b,ab->ab', None),
+        ('ea,fb,gc,hd,abcd->efgh', None),
+        ('ea,fb,abcd,gc,hd->efgh', None),
+        ('abcd,ea,fb,gc,hd->efgh', None),
+        # test_complex
+        ('acdf,jbje,gihb,hfac,gfac,gifabc,hfac', None),
+        ('acdf,jbje,gihb,hfac,gfac,gifabc,hfac', None),
+        ('cd,bdhe,aidb,hgca,gc,hgibcd,hgac', None),
+        ('abhe,hidj,jgba,hiab,gab', None),
+        ('bde,cdh,agdb,hica,ibd,hgicd,hiac', None),
+        ('chd,bde,agbc,hiad,hgc,hgi,hiad', None),
+        ('chd,bde,agbc,hiad,bdi,cgh,agdb', None),
+        ('bdhe,acad,hiab,agac,hibd', None),
+        # test_collapse
+        ('ab,ab,c->', None),
+        ('ab,ab,c->c', None),
+        ('ab,ab,cd,cd->', None),
+        ('ab,ab,cd,cd->ac', None),
+        ('ab,ab,cd,cd->cd', None),
+        ('ab,ab,cd,cd,ef,ef->', None),
+        # test_inner_product
+        ('ab,ab', None),
+        ('ab,ba', None),
+        ('abc,abc', None),
+        ('abc,bac', None),
+        ('abc,cba', None),
+        # test_random_cases
+        ('aab,fa,df,ecc->bde', None),
+        ('ecb,fef,bad,ed->ac', None),
+        ('bcf,bbb,fbf,fc->', None),
+        ('bb,ff,be->e', None),
+        ('bcb,bb,fc,fff->', None),
+        ('fbb,dfd,fc,fc->', None),
+        ('afd,ba,cc,dc->bf', None),
+        ('adb,bc,fa,cfc->d', None),
+        ('bbd,bda,fc,db->acf', None),
+        ('dba,ead,cad->bce', None),
+        ('aef,fbc,dca->bde', None),
+        # test_broadcasting_dot_cases
+        ('ijk,kl,jl', [(1, 5, 4), (4, 6), (5, 6)]),
+        ('ijk,kl,jl,i->i', [(1, 5, 4), (4, 6), (5, 6), (10)]),
+        ('abjk,kl,jl', [(1, 1, 5, 4), (4, 6), (5, 6)]),
+        ('abjk,kl,jl,ab->ab', [(1, 1, 5, 4), (4, 6), (5, 6), (7, 7)]),
+        ('obk,ijk->ioj', [(2, 4, 8), (2, 4, 8)]),
+    ]
+    for optimize in [False, True]:
+        for config in configs:
+            subscripts, args = config
+            if args is None:
+                args = []
+                terms = subscripts.split('->')[0].split(',')
+                for term in terms:
+                    dims = [size_dict[x] for x in term]
+                    args.append(np.random.uniform(size=dims))
+            else:
+                args = [np.random.uniform(size=arg) for arg in args]
+            OpArgMngr.add_workload('einsum', subscripts, *args, optimize=optimize)
+
+
 @use_np
 def _prepare_workloads():
     array_pool = {
@@ -164,6 +236,7 @@ def _prepare_workloads():
     OpArgMngr.add_workload('meshgrid', np.array([1, 2, 3]))
     OpArgMngr.add_workload('meshgrid', np.array([1, 2, 3]), np.array([4, 5, 6, 7]))
     OpArgMngr.add_workload('meshgrid', np.array([1, 2, 3]), np.array([4, 5, 6, 7]), indexing='ij')
+    _add_workload_einsum()
 
     # workloads for array ufunc protocol
     OpArgMngr.add_workload('add', array_pool['4x1'], array_pool['1x2'])
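For configs whose args is None, the generator above derives each operand's shape from its subscript term via size_dict. Expanded by hand for one entry in plain NumPy, to show what a generated workload evaluates:

    import numpy as np

    size_dict = dict(zip('abcdefghij', [2, 3, 4, 5, 4, 3, 2, 6, 5, 4]))

    subscripts = 'ab,ab,c->c'   # one of the test_collapse configs
    terms = subscripts.split('->')[0].split(',')
    args = [np.random.uniform(size=[size_dict[x] for x in term]) for term in terms]
    # shapes: (2, 3), (2, 3), (4,)
    out = np.einsum(subscripts, *args)
    print(out.shape)            # (4,): sum(x * y) scales the third operand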
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 2303c9cee29c..cc98bffe59ba 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3305,6 +3305,137 @@ def hybrid_forward(self, F, a, *args, **kwargs):
     assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
 
 
+@with_seed()
+@use_np
+def test_np_einsum():
+    class TestEinsum(HybridBlock):
+        def __init__(self, subscripts, optimize):
+            super(TestEinsum, self).__init__()
+            self.subscripts = subscripts
+            self.optimize = optimize
+
+        def hybrid_forward(self, F, *operands):
+            return F.np.einsum(self.subscripts, *operands, optimize=self.optimize)
+
+    def dbg(name, data):
+        print('type of {} = {}'.format(name, type(data)))
+        print('shape of {} = {}'.format(name, data.shape))
+        print('{} = {}'.format(name, data))
+
+    configs = [
+        ('ii', [(5, 5)], lambda *args: (_np.eye(5),)),
+        ('ii->i', [(5, 5)], lambda *args: (_np.eye(5),)),
+        ('ij->i', [(5, 5)], lambda *args: (_np.ones((5, 5)),)),
+        ('...j->...', [(5, 5)], lambda *args: (_np.ones((5, 5)),)),
+        ('ji', [(2, 3)], lambda *args: (_np.ones((2, 3)),)),
+        ('ij->ji', [(2, 3)], lambda *args: (_np.ones((2, 3)),)),
+        ('i, i', [(5,), (5,)], lambda *args: (args[1], args[0])),
+        ('ij, j', [(5, 5), (5,)], lambda *args: (_np.tile(args[1][None, :], [5, 1]),
+                                                 args[0].sum(axis=0))),
+        ('...j, j', [(5, 5), (5,)], lambda *args: (_np.tile(args[1][None, :], [5, 1]),
+                                                   _np.sum(args[0], axis=0))),
+        ('..., ...', [(), (2, 3)], lambda *args: (_np.sum(args[1], axis=None),
+                                                  args[0] * _np.ones((2, 3)))),
+        (', ij', [(), (2, 3)], lambda *args: (_np.sum(args[1], axis=None),
+                                              args[0] * _np.ones((2, 3)))),
+        ('i, j', [(2,), (5, )], lambda *args: (_np.sum(args[1], axis=None) * _np.ones(2),
+                                               _np.sum(args[0], axis=None) * _np.ones(5))),
+        ('ijk, jil->kl', [(3, 4, 5), (4, 3, 2)], lambda *args: (_np.tile(_np.transpose(_np.sum(args[1],
+                                                                                               axis=-1))[:, :, None],
+                                                                         [1, 1, 5]),
+                                                                _np.tile(_np.transpose(_np.sum(args[0],
+                                                                                               axis=-1))[:, :, None],
+                                                                         [1, 1, 2]))),
+        ('ii->i', [(3, 3)], lambda *args: (_np.eye(3),)),
+        ('ki, jk->ij', [(3, 2), (4, 3)], lambda *args: (_np.tile(args[1].sum(axis=0)[:, None], [1, 2]),
+                                                        _np.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
+        ('ki, ...k->i...', [(3, 2), (4, 3)], lambda *args: (_np.tile(args[1].sum(axis=0)[:, None], [1, 2]),
+                                                            _np.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
+        ('k..., jk', [(3, 2), (4, 3)], lambda *args: (_np.tile(args[1].sum(axis=0)[:, None], [1, 2]),
+                                                      _np.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
+        ('ij, jk', [(5, 0), (0, 4)], lambda *args: (_np.empty((5, 0)), _np.empty((0, 4)))),
+        (('ij,jk,kl->il'), [(2, 2), (2, 5), (5, 2)], lambda *args: (_np.dot(_np.ones((2, 2)), _np.dot(args[1], args[2]).T),
+                                                                    _np.dot(args[0].T, _np.dot(_np.ones((2, 2)), args[2].T)),
+                                                                    _np.dot(_np.dot(args[0], args[1]).T, _np.ones((2, 2))))),
+        # broadcast bug
+        (('ij, ij -> i'), [(1, 4), (2, 4)], lambda *args: (_np.sum(args[1], axis=0)[None, :],
+                                                           _np.tile(args[0], [2, 1]))),
+    ]
+    dtypes = ['int32', 'float16', 'float32', 'float64']
+    for hybridize in [False, True]:
+        for dtype in dtypes:
+            for config in configs:
+                for optimize in [False, True]:
+                    rtol = 1e-0 if dtype == 'float16' else 1e-3
+                    atol = 1e-1 if dtype == 'float16' else 1e-5
+                    (subscripts, operands, get_grad) = config
+                    test_einsum = TestEinsum(subscripts, optimize)
+                    if hybridize:
+                        test_einsum.hybridize()
+                    x = []
+                    x_np = []
+                    for shape in operands:
+                        x_np.append(_np.array(_np.random.uniform(-10.0, 10.0, shape),
+                                              dtype=dtype))
+                        x.append(np.array(x_np[-1], dtype=dtype))
+                        x[-1].attach_grad()
+                    expected_np = _np.einsum(subscripts, *x_np, optimize=optimize)
+                    with mx.autograd.record():
+                        out_mx = test_einsum(*x)
+                    assert out_mx.shape == expected_np.shape
+                    assert_almost_equal(out_mx.asnumpy(), expected_np, rtol=rtol, atol=atol)
+                    out_mx.backward()
+                    for (iop, op) in enumerate(x):
+                        assert_almost_equal(op.grad.asnumpy(), get_grad(*x_np)[iop], rtol=rtol, atol=atol)
+
+                    # Test imperative once again
+                    for op in x:
+                        op.attach_grad()
+                    with mx.autograd.record():
+                        out_mx = np.einsum(subscripts, *x, optimize=optimize)
+                    out_mx.backward()
+                    expected_np = _np.einsum(subscripts, *x_np, optimize=optimize)
+                    assert_almost_equal(out_mx.asnumpy(), expected_np, rtol=rtol, atol=atol)
+                    for (iop, op) in enumerate(x):
+                        assert_almost_equal(op.grad.asnumpy(), get_grad(*x_np)[iop], rtol=rtol, atol=atol)
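Each get_grad lambda above encodes the analytic gradients the backward pass is checked against. For the ('i, i', ...) config the output is the inner product sum_i a_i * b_i, so each operand's gradient is simply the other operand; a forward-difference check of that in plain NumPy, for illustration only:

    import numpy as _np

    a = _np.random.uniform(-1.0, 1.0, 5)
    b = _np.random.uniform(-1.0, 1.0, 5)
    out = _np.einsum('i, i', a, b)

    eps = 1e-6
    num_grad_a = _np.array([
        (_np.einsum('i, i', a + eps * _np.eye(5)[k], b) - out) / eps
        for k in range(5)])
    assert _np.allclose(num_grad_a, b, atol=1e-4)   # matches the lambda's args[1]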
+
+    configs = [
+        (('ij,jk,kl->il'), [(2, 2), (2, 5), (5, 2)]),
+        (('ea,fb,abcd,gc,hd->efgh'), [(5, 5), (5, 5), (5, 5, 5, 5), (5, 5), (5, 5)]),
+    ]
+    dtypes = ['int32', 'float32', 'float64']
+    for hybridize in [False, True]:
+        for dtype in dtypes:
+            for config in configs:
+                (subscripts, operands) = config
+                rtol = 1e-0 if dtype == 'float16' else 1e-2
+                atol = 1e-1 if dtype == 'float16' else 1e-2
+                grad = []
+                x_np = []
+                for shape in operands:
+                    x_np.append(_np.array(_np.random.uniform(-2.0, 2.0, shape),
+                                          dtype=dtype))
+                for optimize in [False, True]:
+                    x = []
+                    for (iop, op) in enumerate(operands):
+                        x.append(np.array(x_np[iop], dtype=dtype))
+                        x[-1].attach_grad()
+                    test_einsum = TestEinsum(subscripts, optimize)
+                    if hybridize:
+                        test_einsum.hybridize()
+                    expected_np = _np.einsum(subscripts, *x_np, optimize=optimize)
+                    with mx.autograd.record():
+                        out_mx = test_einsum(*x)
+                    assert out_mx.shape == expected_np.shape
+                    assert_almost_equal(out_mx.asnumpy(), expected_np, rtol=rtol, atol=atol)
+                    out_mx.backward()
+                    cur_grad = []
+                    for (iop, op) in enumerate(x):
+                        cur_grad.append(op.grad.asnumpy())
+                    grad.append(cur_grad)
+                for (iop, op) in enumerate(grad[0]):
+                    assert_almost_equal(grad[0][iop], grad[1][iop], rtol=rtol, atol=atol)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
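The second block above asserts that the gradients computed with optimize=False and optimize=True agree. The forward-side counterpart of that invariant in plain NumPy, along with the contraction order the greedy optimizer reports:

    import numpy as np

    ops = [np.random.uniform(-2.0, 2.0, s) for s in [(2, 2), (2, 5), (5, 2)]]
    plain = np.einsum('ij,jk,kl->il', *ops, optimize=False)
    greedy = np.einsum('ij,jk,kl->il', *ops, optimize=True)
    assert np.allclose(plain, greedy)

    # einsum_path exposes the pairwise contraction order the optimizer picks
    path, info = np.einsum_path('ij,jk,kl->il', *ops, optimize='greedy')
    print(path)   # e.g. ['einsum_path', (0, 1), (0, 1)]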