diff --git a/benchmark/python/ffi/benchmark_ffi.py b/benchmark/python/ffi/benchmark_ffi.py index 818e5a621aeb..42b9fcc16d7f 100644 --- a/benchmark/python/ffi/benchmark_ffi.py +++ b/benchmark/python/ffi/benchmark_ffi.py @@ -66,6 +66,7 @@ def prepare_workloads(): OpArgMngr.add_workload("average", pool['2x2'], weights=pool['2'], axis=1, returned=True) OpArgMngr.add_workload("histogram", pool['2x2'], bins=10, range=(0.0, 10.0)) OpArgMngr.add_workload("add", pool['2x2'], pool['2x2']) + OpArgMngr.add_workload("cross", pool['2'], pool['2']) OpArgMngr.add_workload("linalg.eig", pool['3x3']) OpArgMngr.add_workload("linalg.eigh", pool['3x3']) OpArgMngr.add_workload("linalg.det", pool['3x3']) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 06b8d5f844e9..a311dc1b80fc 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -43,7 +43,7 @@ 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', - 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'kron', + 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross', 'kron', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'interp', 'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', @@ -6282,6 +6282,46 @@ def ldexp(x1, x2, out=None, **kwargs): return _api_internal.ldexp(x1, x2, out) +@set_module('mxnet.ndarray.numpy') +def vdot(a, b): + r""" + Return the dot product of two vectors. + Note that `vdot` handles multidimensional arrays differently than `dot`: + it does *not* perform a matrix product, but flattens input arguments + to 1-D vectors first. Consequently, it should only be used for vectors. + + Parameters + ---------- + a : ndarray + First argument to the dot product. + b : ndarray + Second argument to the dot product. + + Returns + ------- + output : ndarray + Dot product of `a` and `b`. + + See Also + -------- + dot : Return the dot product without using the complex conjugate of the + first argument. + + Examples + -------- + Note that higher-dimensional arrays are flattened! + >>> a = np.array([[1, 4], [5, 6]]) + >>> b = np.array([[4, 1], [2, 2]]) + >>> np.vdot(a, b) + 30 + >>> np.vdot(b, a) + 30 + >>> 1*4 + 4*1 + 5*2 + 6*2 + 30 + """ + return tensordot(a.flatten(), b.flatten(), 1) + + @set_module('mxnet.ndarray.numpy') def inner(a, b): r""" @@ -6389,25 +6429,135 @@ def outer(a, b): return tensordot(a.reshape_view((-1, )), b.reshape_view((-1, )), 0) +@set_module('mxnet.ndarray.numpy') +def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): # pylint: disable=too-many-arguments + """ + Return the cross product of two (arrays of) vectors. + + The cross product of `a` and `b` in :math:`R^3` is a vector perpendicular + to both `a` and `b`. If `a` and `b` are arrays of vectors, the vectors + are defined by the last axis of `a` and `b` by default, and these axis + can have dimensions 2 or 3. Where the dimension of either `a` or `b` is + 2, the third component of the input vector is assumed to be zero and the + cross product calculated accordingly. In cases where both input vectors + have dimension 2, the z-component of the cross product is returned. + + Parameters + ---------- + a : ndarray + Components of the first vector(s). + b : ndarray + Components of the second vector(s). + axisa : int, optional + Axis of `a` that defines the vector(s). By default, the last axis. + axisb : int, optional + Axis of `b` that defines the vector(s). By default, the last axis. + axisc : int, optional + Axis of `c` containing the cross product vector(s). Ignored if + both input vectors have dimension 2, as the return is scalar. + By default, the last axis. + axis : int, optional + If defined, the axis of `a`, `b` and `c` that defines the vector(s) + and cross product(s). Overrides `axisa`, `axisb` and `axisc`. + + Returns + ------- + c : ndarray + Vector cross product(s). + + Raises + ------ + ValueError + When the dimension of the vector(s) in `a` and/or `b` does not + equal 2 or 3. + + Notes + ----- + Supports full broadcasting of the inputs. + + Examples + -------- + Vector cross-product. + + >>> x = np.array([1., 2., 3.]) + >>> y = np.array([4., 5., 6.]) + >>> np.cross(x, y) + array([-3., 6., -3.]) + + One vector with dimension 2. + + >>> x = np.array([1., 2.]) + >>> y = np.array([4., 5., 6.]) + >>> np.cross(x, y) + array([12., -6., -3.]) + + Equivalently: + + >>> x = np.array([1., 2., 0.]) + >>> y = np.array([4., 5., 6.]) + >>> np.cross(x, y) + array([12., -6., -3.]) + + Both vectors with dimension 2. + + >>> x = np.array([1., 2.]) + >>> y = np.array([4., 5.]) + >>> np.cross(x, y) + array(-3.) + + Multiple vector cross-products. Note that the direction of the cross + product vector is defined by the `right-hand rule`. + + >>> x = np.array([[1., 2., 3.], [4., 5., 6.]]) + >>> y = np.array([[4., 5., 6.], [1., 2., 3.]]) + >>> np.cross(x, y) + array([[-3., 6., -3.], + [ 3., -6., 3.]]) + + The orientation of `c` can be changed using the `axisc` keyword. + + >>> np.cross(x, y, axisc=0) + array([[-3., 3.], + [ 6., -6.], + [-3., 3.]]) + + Change the vector definition of `x` and `y` using `axisa` and `axisb`. + + >>> x = np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]) + >>> y = np.array([[7., 8., 9.], [4., 5., 6.], [1., 2., 3.]]) + >>> np.cross(x, y) + array([[ -6., 12., -6.], + [ 0., 0., 0.], + [ 6., -12., 6.]]) + >>> np.cross(x, y, axisa=0, axisb=0) + array([[-24., 48., -24.], + [-30., 60., -30.], + [-36., 72., -36.]]) + """ + if axis is not None: + axisa, axisb, axisc = (axis,) * 3 + + if isinstance(a, NDArray) and isinstance(b, NDArray): + return _api_internal.cross(a, b, axisa, axisb, axisc) + else: + raise TypeError("Input data should be NDarray") + + @set_module('mxnet.ndarray.numpy') def kron(a, b): r""" Kronecker product of two arrays. Computes the Kronecker product, a composite array made of blocks of the second array scaled by the first. - Parameters ---------- a, b : ndarray - Returns ------- out : ndarray - See Also -------- outer : The outer product - Notes ----- The function assumes that the number of dimensions of `a` and `b` @@ -6423,7 +6573,6 @@ def kron(a, b): [[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ], [ ... ... ], [ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]] - Examples -------- >>> np.kron([1,10,100], [5,6,7]) @@ -6434,46 +6583,6 @@ def kron(a, b): return _api_internal.kron(a, b) -@set_module('mxnet.ndarray.numpy') -def vdot(a, b): - r""" - Return the dot product of two vectors. - Note that `vdot` handles multidimensional arrays differently than `dot`: - it does *not* perform a matrix product, but flattens input arguments - to 1-D vectors first. Consequently, it should only be used for vectors. - - Parameters - ---------- - a : ndarray - First argument to the dot product. - b : ndarray - Second argument to the dot product. - - Returns - ------- - output : ndarray - Dot product of `a` and `b`. - - See Also - -------- - dot : Return the dot product without using the complex conjugate of the - first argument. - - Examples - -------- - Note that higher-dimensional arrays are flattened! - >>> a = np.array([[1, 4], [5, 6]]) - >>> b = np.array([[4, 1], [2, 2]]) - >>> np.vdot(a, b) - 30 - >>> np.vdot(b, a) - 30 - >>> 1*4 + 4*1 + 5*2 + 6*2 - 30 - """ - return tensordot(a.flatten(), b.flatten(), 1) - - @set_module('mxnet.ndarray.numpy') def equal(x1, x2, out=None): """ diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 99a67444f6a8..4b89848fe98d 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -68,8 +68,8 @@ 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', - 'unique', 'lcm', 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'kron', - 'equal', 'not_equal', 'interp', + 'unique', 'lcm', 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', + 'cross', 'kron', 'equal', 'not_equal', 'interp', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount', @@ -8118,6 +8118,47 @@ def ldexp(x1, x2, out=None, **kwargs): return _mx_nd_np.ldexp(x1, x2, out) +@set_module('mxnet.numpy') +def vdot(a, b): + r""" + Return the dot product of two vectors. + Note that `vdot` handles multidimensional arrays differently than `dot`: + it does *not* perform a matrix product, but flattens input arguments + to 1-D vectors first. Consequently, it should only be used for vectors. + + Parameters + ---------- + a : ndarray + First argument to the dot product. + b : ndarray + Second argument to the dot product. + + Returns + ------- + output : ndarray + Dot product of `a` and `b`. + + See Also + -------- + dot : Return the dot product without using the complex conjugate of the + first argument. + + Examples + -------- + Note that higher-dimensional arrays are flattened! + + >>> a = np.array([[1, 4], [5, 6]]) + >>> b = np.array([[4, 1], [2, 2]]) + >>> np.vdot(a, b) + array(30.) + >>> np.vdot(b, a) + array(30.) + >>> 1*4 + 4*1 + 5*2 + 6*2 + 30 + """ + return tensordot(a.flatten(), b.flatten(), 1) + + @set_module('mxnet.numpy') def inner(a, b): r"""Inner product of two arrays. @@ -8230,25 +8271,129 @@ def outer(a, b): return tensordot(a.flatten(), b.flatten(), 0) +@set_module('mxnet.numpy') +def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): # pylint: disable=too-many-arguments + """ + Return the cross product of two (arrays of) vectors. + + The cross product of `a` and `b` in :math:`R^3` is a vector perpendicular + to both `a` and `b`. If `a` and `b` are arrays of vectors, the vectors + are defined by the last axis of `a` and `b` by default, and these axes + can have dimensions 2 or 3. Where the dimension of either `a` or `b` is + 2, the third component of the input vector is assumed to be zero and the + cross product calculated accordingly. In cases where both input vectors + have dimension 2, the z-component of the cross product is returned. + + Parameters + ---------- + a : ndarray + Components of the first vector(s). + b : ndarray + Components of the second vector(s). + axisa : int, optional + Axis of `a` that defines the vector(s). By default, the last axis. + axisb : int, optional + Axis of `b` that defines the vector(s). By default, the last axis. + axisc : int, optional + Axis of `c` containing the cross product vector(s). Ignored if + both input vectors have dimension 2, as the return is scalar. + By default, the last axis. + axis : int, optional + If defined, the axis of `a`, `b` and `c` that defines the vector(s) + and cross product(s). Overrides `axisa`, `axisb` and `axisc`. + + Returns + ------- + c : ndarray + Vector cross product(s). + + Raises + ------ + ValueError + When the dimension of the vector(s) in `a` and/or `b` does not + equal 2 or 3. + + Notes + ----- + Supports full broadcasting of the inputs. + + Examples + -------- + Vector cross-product. + + >>> x = np.array([1., 2., 3.]) + >>> y = np.array([4., 5., 6.]) + >>> np.cross(x, y) + array([-3., 6., -3.]) + + One vector with dimension 2. + + >>> x = np.array([1., 2.]) + >>> y = np.array([4., 5., 6.]) + >>> np.cross(x, y) + array([12., -6., -3.]) + + Equivalently: + + >>> x = np.array([1., 2., 0.]) + >>> y = np.array([4., 5., 6.]) + >>> np.cross(x, y) + array([12., -6., -3.]) + + Both vectors with dimension 2. + + >>> x = np.array([1., 2.]) + >>> y = np.array([4., 5.]) + >>> np.cross(x, y) + array(-3.) + + Multiple vector cross-products. Note that the direction of the cross + product vector is defined by the `right-hand rule`. + + >>> x = np.array([[1., 2., 3.], [4., 5., 6.]]) + >>> y = np.array([[4., 5., 6.], [1., 2., 3.]]) + >>> np.cross(x, y) + array([[-3., 6., -3.], + [ 3., -6., 3.]]) + + The orientation of `c` can be changed using the `axisc` keyword. + + >>> np.cross(x, y, axisc=0) + array([[-3., 3.], + [ 6., -6.], + [-3., 3.]]) + + Change the vector definition of `x` and `y` using `axisa` and `axisb`. + + >>> x = np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]) + >>> y = np.array([[7., 8., 9.], [4., 5., 6.], [1., 2., 3.]]) + >>> np.cross(x, y) + array([[ -6., 12., -6.], + [ 0., 0., 0.], + [ 6., -12., 6.]]) + >>> np.cross(x, y, axisa=0, axisb=0) + array([[-24., 48., -24.], + [-30., 60., -30.], + [-36., 72., -36.]]) + """ + return _mx_nd_np.cross(a, b, axisa=axisa, axisb=axisb, axisc=axisc, axis=axis) + + @set_module('mxnet.numpy') def kron(a, b): r""" Kronecker product of two arrays. Computes the Kronecker product, a composite array made of blocks of the second array scaled by the first. - Parameters ---------- a, b : ndarray - Returns ------- out : ndarray - See Also -------- outer : The outer product - Notes ----- The function assumes that the number of dimensions of `a` and `b` @@ -8264,7 +8409,6 @@ def kron(a, b): [[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ], [ ... ... ], [ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]] - Examples -------- >>> np.kron([1,10,100], [5,6,7]) @@ -8275,47 +8419,6 @@ def kron(a, b): return _mx_nd_np.kron(a, b) -@set_module('mxnet.numpy') -def vdot(a, b): - r""" - Return the dot product of two vectors. - Note that `vdot` handles multidimensional arrays differently than `dot`: - it does *not* perform a matrix product, but flattens input arguments - to 1-D vectors first. Consequently, it should only be used for vectors. - - Parameters - ---------- - a : ndarray - First argument to the dot product. - b : ndarray - Second argument to the dot product. - - Returns - ------- - output : ndarray - Dot product of `a` and `b`. - - See Also - -------- - dot : Return the dot product without using the complex conjugate of the - first argument. - - Examples - -------- - Note that higher-dimensional arrays are flattened! - - >>> a = np.array([[1, 4], [5, 6]]) - >>> b = np.array([[4, 1], [2, 2]]) - >>> np.vdot(a, b) - array(30.) - >>> np.vdot(b, a) - array(30.) - >>> 1*4 + 4*1 + 5*2 + 6*2 - 30 - """ - return tensordot(a.flatten(), b.flatten(), 1) - - @set_module('mxnet.numpy') def equal(x1, x2, out=None): """ diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index e693a00ea1a5..26257cb79f2f 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -194,6 +194,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'isneginf', 'isinf', 'pad', + 'cross', ] diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 595186518fc2..76beac98fe05 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -49,7 +49,7 @@ 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'interp', - 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'kron', + 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross', 'kron', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', @@ -5690,6 +5690,46 @@ def ldexp(x1, x2, out=None, **kwargs): return _ufunc_helper(x1, x2, _npi.ldexp, _np.ldexp, _npi.ldexp_scalar, _npi.rldexp_scalar, out) +@set_module('mxnet.symbol.numpy') +def vdot(a, b): + r""" + Return the dot product of two vectors. + Note that `vdot` handles multidimensional arrays differently than `dot`: + it does *not* perform a matrix product, but flattens input arguments + to 1-D vectors first. Consequently, it should only be used for vectors. + + Parameters + ---------- + a : _Symbol + First argument to the dot product. + b : _Symbol + Second argument to the dot product. + + Returns + ------- + output : _Symbol + Dot product of `a` and `b`. + + See Also + -------- + dot : Return the dot product without using the complex conjugate of the + first argument. + + Examples + -------- + Note that higher-dimensional arrays are flattened! + >>> a = np.array([[1, 4], [5, 6]]) + >>> b = np.array([[4, 1], [2, 2]]) + >>> np.vdot(a, b) + 30 + >>> np.vdot(b, a) + 30 + >>> 1*4 + 4*1 + 5*2 + 6*2 + 30 + """ + return tensordot(a.flatten(), b.flatten(), 1) + + @set_module('mxnet.symbol.numpy') def inner(a, b): r"""Inner product of two arrays. @@ -5798,6 +5838,58 @@ def outer(a, b): return tensordot(a.flatten(), b.flatten(), 0) +@set_module('mxnet.symbol.numpy') +def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): # pylint: disable=too-many-arguments + """ + Return the cross product of two (arrays of) vectors. + + The cross product of `a` and `b` in :math:`R^3` is a vector perpendicular + to both `a` and `b`. If `a` and `b` are arrays of vectors, the vectors + are defined by the last axis of `a` and `b` by default, and these axes + can have dimensions 2 or 3. Where the dimension of either `a` or `b` is + 2, the third component of the input vector is assumed to be zero and the + cross product calculated accordingly. In cases where both input vectors + have dimension 2, the z-component of the cross product is returned. + + Parameters + ---------- + a : _Symbol + Components of the first vector(s). + b : _Symbol + Components of the second vector(s). + axisa : int, optional + Axis of `a` that defines the vector(s). By default, the last axis. + axisb : int, optional + Axis of `b` that defines the vector(s). By default, the last axis. + axisc : int, optional + Axis of `c` containing the cross product vector(s). Ignored if + both input vectors have dimension 2, as the return is scalar. + By default, the last axis. + axis : int, optional + If defined, the axis of `a`, `b` and `c` that defines the vector(s) + and cross product(s). Overrides `axisa`, `axisb` and `axisc`. + + Returns + ------- + c : _Symbol + Vector cross product(s). + + Raises + ------ + ValueError + When the dimension of the vector(s) in `a` and/or `b` does not + equal 2 or 3. + + Notes + ----- + Supports full broadcasting of the inputs. + """ + if axis is not None: + axisa, axisb, axisc = (axis,) * 3 + + return _npi.cross(a, b, axisa, axisb, axisc) + + @set_module('mxnet.symbol.numpy') def kron(a, b): r""" @@ -5805,19 +5897,15 @@ def kron(a, b): Kronecker product of two arrays. Computes the Kronecker product, a composite array made of blocks of the second array scaled by the first. - Parameters ---------- a, b : ndarray - Returns ------- out : ndarray - See Also -------- outer : The outer product - Notes ----- The function assumes that the number of dimensions of `a` and `b` @@ -5833,7 +5921,6 @@ def kron(a, b): [[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ], [ ... ... ], [ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]] - Examples -------- >>> np.kron([1,10,100], [5,6,7]) @@ -5844,46 +5931,6 @@ def kron(a, b): return _npi.kron(a, b) -@set_module('mxnet.symbol.numpy') -def vdot(a, b): - r""" - Return the dot product of two vectors. - Note that `vdot` handles multidimensional arrays differently than `dot`: - it does *not* perform a matrix product, but flattens input arguments - to 1-D vectors first. Consequently, it should only be used for vectors. - - Parameters - ---------- - a : _Symbol - First argument to the dot product. - b : _Symbol - Second argument to the dot product. - - Returns - ------- - output : _Symbol - Dot product of `a` and `b`. - - See Also - -------- - dot : Return the dot product without using the complex conjugate of the - first argument. - - Examples - -------- - Note that higher-dimensional arrays are flattened! - >>> a = np.array([[1, 4], [5, 6]]) - >>> b = np.array([[4, 1], [2, 2]]) - >>> np.vdot(a, b) - 30 - >>> np.vdot(b, a) - 30 - >>> 1*4 + 4*1 + 5*2 + 6*2 - 30 - """ - return tensordot(a.flatten(), b.flatten(), 1) - - @set_module('mxnet.symbol.numpy') def equal(x1, x2, out=None): """ diff --git a/src/api/operator/numpy/np_cross.cc b/src/api/operator/numpy/np_cross.cc new file mode 100644 index 000000000000..0dd4644cad59 --- /dev/null +++ b/src/api/operator/numpy/np_cross.cc @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_cross.cc + * \brief Implementation of the API of functions in src/operator/numpy/np_cross.cc + */ +#include +#include +#include "../utils.h" +#include "../../../operator/numpy/np_cross-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.cross") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npi_cross"); + op::NumpyCrossParam param; + param.axisa = args[2].operator int(); + param.axisb = args[3].operator int(); + param.axisc = args[4].operator int(); + attrs.op = op; + attrs.parsed = param; + SetAttrDict(&attrs); + int num_inputs = 2; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +}); + +} // namespace mxnet diff --git a/src/operator/numpy/np_cross-inl.h b/src/operator/numpy/np_cross-inl.h new file mode 100644 index 000000000000..c2092bbfec23 --- /dev/null +++ b/src/operator/numpy/np_cross-inl.h @@ -0,0 +1,1387 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file np_cross-inl.h + * \brief Function definition of cross product of two (arrays of) vectors + */ + +#ifndef MXNET_OPERATOR_NUMPY_NP_CROSS_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_CROSS_INL_H_ + +#include +#include +#include +#include +#include "../mshadow_op.h" +#include "../operator_common.h" +#include "../elemwise_op_common.h" +#include "../tensor/broadcast_reduce_op.h" +#include "../tensor/elemwise_binary_broadcast_op.h" +#include "../tensor/matrix_op-inl.h" + +namespace mxnet { +namespace op { + +using namespace mshadow; + +struct NumpyCrossParam : public dmlc::Parameter { + int axisa, axisb, axisc; + DMLC_DECLARE_PARAMETER(NumpyCrossParam) { + DMLC_DECLARE_FIELD(axisa) + .set_default(-1) + .describe("Axis of `a` that defines the vector(s). By default, the last axis."); + DMLC_DECLARE_FIELD(axisb) + .set_default(-1) + .describe("Axis of `b` that defines the vector(s). By default, the last axis."); + DMLC_DECLARE_FIELD(axisc) + .set_default(-1) + .describe("Axis of `c` containing the cross product vector(s)." + "Ignored if both input vectors have dimension 2, as the return is scalar." + "By default, the last axis."); + } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream axisa_s, axisb_s, axisc_s; + axisa_s << axisa; + axisb_s << axisb; + axisc_s << axisc; + (*dict)["axisa"] = axisa_s.str(); + (*dict)["axisb"] = axisb_s.str(); + (*dict)["axisc"] = axisc_s.str(); + } +}; + +#define SUM_NDIM_SWITCH(ndim, NDim, ...) \ + if (ndim == 1) { \ + const int NDim = 1; \ + {__VA_ARGS__} \ + } else if (ndim == 2) { \ + const int NDim = 2; \ + {__VA_ARGS__} \ + } else if (ndim == 3) { \ + const int NDim = 3; \ + {__VA_ARGS__} \ + } else if (ndim == 4) { \ + const int NDim = 4; \ + {__VA_ARGS__} \ + } else if (ndim <= broadcast::MAX_DIM) { \ + const int NDim = broadcast::MAX_DIM; \ + {__VA_ARGS__} \ + } else { \ + LOG(FATAL) << "NDim too large "; \ + } + +struct CrossInAssign { + template + MSHADOW_XINLINE static void Map(int i, const DType *in_ptr, DType *out_ptr, + const int stride, const int index, const int msize) { + if (index < stride && i * stride + index < msize) { + out_ptr[i] = in_ptr[i * stride + index]; + } + } +}; + +template +struct CrossOutAssign { + template + MSHADOW_XINLINE static void Map(int i, const DType *in_ptr, DType *out_ptr, + const int positive, const int stride, + const int index, const int msize) { + if (index < stride && i * stride + index < msize) { + KERNEL_ASSIGN(out_ptr[i * stride + index], req, positive == 1 ? in_ptr[i] : -in_ptr[i]); + } + } +}; + +template +struct ResAssign { + template + MSHADOW_XINLINE static void Map(int i, const DType *in_data, DType *out_data) { + KERNEL_ASSIGN(out_data[i], req, in_data[i]); + } +}; + +struct DeleteAssign { + template + MSHADOW_XINLINE static void Map(int i, const DType *in_data, DType *out_data, + const int in_stride, const int out_stride) { + const DType *in_ptr = in_data + i * in_stride; + DType *out_ptr = out_data + i * out_stride; + if (in_stride == out_stride + 1) { + for (int idx = 0; idx < out_stride; ++idx) { + out_ptr[idx] = in_ptr[idx]; + } + } + } +}; + +// Get moveaxis index. +inline mxnet::Tuple GetMoveaxisIndex(const int& source, + const int& destination, + const mxnet::TShape& shape) { + const int ndim = shape.ndim(); + const int src_axis = CheckAxis(source, ndim); + const int dest_axis = CheckAxis(destination, ndim); + std::vector moveaxis_index_vec; + for (int i = 0; i < ndim; ++i) { + if (i != src_axis) { moveaxis_index_vec.push_back(i); } + } + moveaxis_index_vec.insert(moveaxis_index_vec.begin() + dest_axis, src_axis); + return mxnet::Tuple(moveaxis_index_vec); +} + +// Get moveaxis shape. +inline mxnet::TShape GetMoveaxisShape(const Tuple& moveaxis_index, + const mxnet::TShape& org_shape) { + const int ndim = org_shape.ndim(); + if (ndim == 0) { return mxnet::TShape(0, 0); } + CHECK_EQ(moveaxis_index.ndim(), org_shape.ndim()) << "moveaxis index dismatch original shape."; + std::vector moveaxis_shape_vec(ndim, -1); + for (int i = 0; i < ndim; ++i) { + moveaxis_shape_vec[i] = org_shape[moveaxis_index[i]]; + } + return mxnet::TShape(moveaxis_shape_vec.begin(), moveaxis_shape_vec.end()); +} + +// Get or check broadcast shape for cross product. +inline void GetOrCheckLRShape(const nnvm::NodeAttrs& attrs, + const mxnet::TShape& a_moveaxis_shape, + const mxnet::TShape& b_moveaxis_shape, + mxnet::TShape *c_shape_ptr = nullptr) { + const int a_ndim = a_moveaxis_shape.ndim(); + const int b_ndim = b_moveaxis_shape.ndim(); + mxnet::TShape a_cutoff_shape(a_ndim - 1, -1); + mxnet::TShape b_cutoff_shape(b_ndim - 1, -1); + for (int i = 0; i < a_ndim - 1; ++i) { + a_cutoff_shape[i] = a_moveaxis_shape[i]; + } + for (int i = 0; i < b_ndim - 1; ++i) { + b_cutoff_shape[i] = b_moveaxis_shape[i]; + } + mxnet::ShapeVector in_shape_vec({ a_cutoff_shape, b_cutoff_shape}); + mxnet::ShapeVector out_shape_vec({ mxnet::TShape() }); + mxnet::op::BinaryBroadcastShape(attrs, &in_shape_vec, &out_shape_vec); + if (c_shape_ptr && (a_moveaxis_shape[a_ndim - 1] == 3 || b_moveaxis_shape[b_ndim - 1] == 3)) { + mxnet::TShape c_shape(out_shape_vec[0].ndim() + 1, -1); + for (int i = 0; i < c_shape.ndim() - 1; ++i) { + c_shape[i] = out_shape_vec[0][i]; + } + c_shape[c_shape.ndim() - 1] = 3; + *c_shape_ptr = c_shape; + } else { + *c_shape_ptr = out_shape_vec[0]; + } +} + +// Get data[..., 0] shape. +inline mxnet::TShape GetCutoffShape(const mxnet::TShape& shape) { + if (shape.ndim() == 0 || !ndim_is_known(shape)) { return mxnet::TShape(0, 0); } + mxnet::TShape cutoff_shape(shape.ndim() - 1, -1); + for (int i = 0; i < shape.ndim() - 1; ++i) { cutoff_shape[i] = shape[i]; } + return cutoff_shape; +} + +template +inline size_t AligndWorkspaceSize(const size_t& offset, + const size_t& add_size) { + size_t m32 = 32; + size_t wks = offset + add_size; + return ((wks * sizeof(DType) + m32 - 1) / m32 * m32 + sizeof(DType) - 1) / sizeof(DType); +} + +// Calculate workspace size for numpy cross forward op. +template +inline size_t NumpyCrossWorkspaceSize(const mxnet::TShape& a_moveaxis_shape, + const mxnet::TShape& b_moveaxis_shape, + const mxnet::TShape& c_moveaxis_shape, + const int& a_axis, + const int& b_axis, + const OpContext& ctx, + const std::vector& req) { + if (kNullOp == req[0]) { return 0U; } + // Zero-size input, no need to launch kernel + if (0U == a_moveaxis_shape.Size() || 0U == b_moveaxis_shape.Size()) { return 0U; } + + size_t workspace_size = 0; + const int a_ndim = a_moveaxis_shape.ndim(); + const int b_ndim = b_moveaxis_shape.ndim(); + const int c_ndim = c_moveaxis_shape.ndim(); + + if (ctx.run_ctx.get_ctx().dev_mask() == cpu::kDevMask) { + if (a_moveaxis_shape[a_ndim - 1] == 2 && b_moveaxis_shape[b_ndim - 1] == 2) { + // Case 1: a.shape[-1] == 2 and b.shape[-1] == 2, param.axisc is ignored. + workspace_size += a_moveaxis_shape.ProdShape(0, a_ndim - 1); + workspace_size += b_moveaxis_shape.ProdShape(0, b_ndim - 1); + workspace_size += c_moveaxis_shape.Size(); + } else { + // Case 2, 3, 4: a.shape[-1] == 3 or b.shape[-1] == 3, param.axisc is not ignored. + workspace_size += a_moveaxis_shape.ProdShape(0, a_ndim - 1); + workspace_size += b_moveaxis_shape.ProdShape(0, b_ndim - 1); + workspace_size += c_moveaxis_shape.Size(); + workspace_size += 3 * c_moveaxis_shape.ProdShape(0, c_ndim - 1); + } + if (a_axis != a_ndim -1 || b_axis != b_ndim - 1) { + workspace_size += a_moveaxis_shape.Size(); + workspace_size += b_moveaxis_shape.Size(); + } + } else { + if (a_moveaxis_shape[a_ndim - 1] == 2 && b_moveaxis_shape[b_ndim - 1] == 2) { + // Case 1: a.shape[-1] == 2 and b.shape[-1] == 2, param.axisc is ignored. + workspace_size = AligndWorkspaceSize(workspace_size, + a_moveaxis_shape.ProdShape(0, a_ndim - 1)); + workspace_size = AligndWorkspaceSize(workspace_size, + b_moveaxis_shape.ProdShape(0, b_ndim - 1)); + workspace_size = AligndWorkspaceSize(workspace_size, + c_moveaxis_shape.Size()); + } else { + // Case 2, 3, 4: a.shape[-1] == 3 or b.shape[-1] == 3, param.axisc is not ignored. + workspace_size = AligndWorkspaceSize(workspace_size, + a_moveaxis_shape.ProdShape(0, a_ndim - 1)); + workspace_size = AligndWorkspaceSize(workspace_size, + b_moveaxis_shape.ProdShape(0, b_ndim - 1)); + for (int i = 0; i < 3; ++i) { + workspace_size = AligndWorkspaceSize(workspace_size, + c_moveaxis_shape.ProdShape(0, c_ndim - 1)); + } + workspace_size = AligndWorkspaceSize(workspace_size, + c_moveaxis_shape.Size()); + } + if (a_axis != a_ndim -1 || b_axis != b_ndim - 1) { + workspace_size = AligndWorkspaceSize(workspace_size, a_moveaxis_shape.Size()); + workspace_size = AligndWorkspaceSize(workspace_size, b_moveaxis_shape.Size()); + } + } + return workspace_size; +} + +template +struct NumpyCrossForwardImpl { + static void op(const TBlob& a, const TBlob& b, const TBlob& c, + const std::vector >& moveaxis_index_vec, + const std::vector& moveaxis_shape_vec, + const int a_axis, const int b_axis, const int c_axis, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& req, + const Tensor& workspace) { + CHECK(a_dim == 3 || b_dim == 3) + << "no specialized NumpyCrossOp defined for template parameters."; + Stream *s = ctx.get_stream(); + const Tuple& a_moveaxis_index = moveaxis_index_vec[0]; + const Tuple& b_moveaxis_index = moveaxis_index_vec[1]; + const mxnet::TShape& a_moveaxis_shape = moveaxis_shape_vec[0]; + const mxnet::TShape& b_moveaxis_shape = moveaxis_shape_vec[1]; + const mxnet::TShape& c_moveaxis_shape = moveaxis_shape_vec[2]; + const int a_ndim = a_moveaxis_shape.ndim(); + const int b_ndim = b_moveaxis_shape.ndim(); + const int c_ndim = b_moveaxis_shape.ndim(); + CHECK_EQ(c_moveaxis_shape[c_ndim - 1], 3) + << "no specialized NumpyCrossOp defined for template parameters."; + + TBlob aw_data, bw_data, c_data, cw_data, a_data, b_data; + std::vector cw_data_vec; + if (ctx.run_ctx.get_ctx().dev_mask() == cpu::kDevMask) { + // Allocate workspace in cpu, no need to align address. + DType *aw_ptr = workspace.dptr_; + DType *bw_ptr = aw_ptr + a_moveaxis_shape.ProdShape(0, a_ndim - 1); + DType *cw_ptr = bw_ptr + b_moveaxis_shape.ProdShape(0, b_ndim - 1); + DType *c_ptr = cw_ptr + 3 * c_moveaxis_shape.ProdShape(0, c_ndim - 1); + a_data = a; + b_data = b; + if (a_axis != a_ndim -1 || b_axis != b_ndim - 1) { + DType *a_ptr = c_ptr + c_moveaxis_shape.Size(); + DType *b_ptr = a_ptr + a_moveaxis_shape.Size(); + a_data = TBlob(a_ptr, a_moveaxis_shape, a.dev_mask(), a.dev_id()); + b_data = TBlob(b_ptr, b_moveaxis_shape, b.dev_mask(), b.dev_id()); + TransposeImpl(ctx.run_ctx, a, a_data, + mxnet::TShape(a_moveaxis_index.begin(), a_moveaxis_index.end())); + TransposeImpl(ctx.run_ctx, b, b_data, + mxnet::TShape(b_moveaxis_index.begin(), b_moveaxis_index.end())); + } + aw_data = TBlob(aw_ptr, GetCutoffShape(a_moveaxis_shape), a.dev_mask(), a.dev_id()); + bw_data = TBlob(bw_ptr, GetCutoffShape(b_moveaxis_shape), b.dev_mask(), b.dev_id()); + cw_data = TBlob(cw_ptr, c_moveaxis_shape, c.dev_mask(), c.dev_id()); + c_data = TBlob(c_ptr, c_moveaxis_shape, c.dev_mask(), c.dev_id()); + for (int i = 0; i < 3; ++i) { + cw_data_vec.push_back(TBlob(cw_ptr + i * c_moveaxis_shape.ProdShape(0, c_ndim - 1), + GetCutoffShape(c_moveaxis_shape), c.dev_mask(), c.dev_id())); + } + } else { + // Allocate workspace in gpu, need to align address. + size_t offset = 0; + aw_data = TBlob(workspace.dptr_ + offset, GetCutoffShape(a_moveaxis_shape), + a.dev_mask(), a.dev_id()); + offset = AligndWorkspaceSize(offset, aw_data.shape_.Size()); + + bw_data = TBlob(workspace.dptr_ + offset, GetCutoffShape(b_moveaxis_shape), + b.dev_mask(), b.dev_id()); + offset = AligndWorkspaceSize(offset, bw_data.shape_.Size()); + + cw_data = TBlob(workspace.dptr_ + offset, c_moveaxis_shape, c.dev_mask(), c.dev_id()); + for (int i = 0; i < 3; ++i) { + cw_data_vec.push_back(TBlob(workspace.dptr_ + offset, + GetCutoffShape(c_moveaxis_shape), c.dev_mask(), c.dev_id())); + offset = AligndWorkspaceSize(offset, cw_data_vec[i].shape_.Size()); + } + c_data = TBlob(workspace.dptr_ + offset, c_moveaxis_shape, c.dev_mask(), c.dev_id()); + offset = AligndWorkspaceSize(offset, c_data.shape_.Size()); + + a_data = a; + b_data = b; + if (a_axis != a_ndim -1 || b_axis != b_ndim - 1) { + a_data = TBlob(workspace.dptr_ + offset, a_moveaxis_shape, + a.dev_mask(), a.dev_id()); + offset = AligndWorkspaceSize(offset, a_data.shape_.Size()); + + b_data = TBlob(workspace.dptr_ + offset, b_moveaxis_shape, + b.dev_mask(), b.dev_id()); + offset = AligndWorkspaceSize(offset, b_data.shape_.Size()); + + TransposeImpl(ctx.run_ctx, a, a_data, + mxnet::TShape(a_moveaxis_index.begin(), a_moveaxis_index.end())); + TransposeImpl(ctx.run_ctx, b, b_data, + mxnet::TShape(b_moveaxis_index.begin(), b_moveaxis_index.end())); + } + } + std::vector positive_vec; + std::vector a_index_vec, b_index_vec, c_index_vec; + std::vector req_vec; + if (a_dim == 2 && b_dim == 3) { + a_index_vec = {1, 0, 0, 1}; + b_index_vec = {2, 2, 1, 0}; + c_index_vec = {0, 1, 2, 2}; + positive_vec = {1, 0, 1, 0}; + req_vec = { kWriteTo, kWriteTo, kWriteTo, kAddTo }; + } else if (a_dim == 3 && b_dim == 2) { + a_index_vec = {2, 2, 0, 1}; + b_index_vec = {1, 0, 1, 0}; + c_index_vec = {0, 1, 2, 2}; + positive_vec = {0, 1, 1, 0}; + req_vec = { kWriteTo, kWriteTo, kWriteTo, kAddTo }; + } else { + a_index_vec = {1, 2, 2, 0, 0, 1}; + b_index_vec = {2, 1, 0, 2, 1, 0}; + c_index_vec = {0, 0, 1, 1, 2, 2}; + positive_vec = {1, 0, 1, 0, 1, 0}; + req_vec = { kWriteTo, kAddTo, kWriteTo, kAddTo, kWriteTo, kAddTo}; + } + for (size_t i = 0; i < a_index_vec.size(); ++i) { + int idx = c_index_vec[i]; + mxnet_op::Kernel::Launch(s, aw_data.Size(), a_data.dptr(), + aw_data.dptr(), a_data.size(a_ndim - 1), + a_index_vec[i], a_data.Size()); + mxnet_op::Kernel::Launch(s, bw_data.Size(), b_data.dptr(), + bw_data.dptr(), b_data.size(b_ndim - 1), + b_index_vec[i], b_data.Size()); + BinaryBroadcastCompute(attrs, ctx, { aw_data, bw_data }, + { kWriteTo }, { cw_data_vec[idx] }); + MXNET_ASSIGN_REQ_SWITCH(req_vec[i], req_type, { + mxnet_op::Kernel, xpu>::Launch(s, cw_data_vec[idx].Size(), + cw_data_vec[idx].dptr(), + c_data.dptr(), + positive_vec[i], + c_data.size(c_ndim - 1), + idx, c_data.Size()); + }); + } + cw_data = cw_data.reshape(c.shape_); + const DType *res_ptr = c_data.dptr(); + if (c_axis != c_ndim -1) { + const Tuple c_axis_index = GetMoveaxisIndex(-1, c_axis, c_moveaxis_shape); + TransposeImpl(ctx.run_ctx, c_data, cw_data, + mxnet::TShape(c_axis_index.begin(), c_axis_index.end())); + res_ptr = cw_data.dptr(); + } + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, c.Size(), res_ptr, c.dptr()); + }); + } +}; + +template +struct NumpyCrossForwardImpl { + static void op(const TBlob& a, const TBlob& b, const TBlob& c, + const std::vector >& moveaxis_index_vec, + const std::vector& moveaxis_shape_vec, + const int a_axis, const int b_axis, const int c_axis, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& req, + const Tensor& workspace) { + Stream *s = ctx.get_stream(); + const Tuple& a_moveaxis_index = moveaxis_index_vec[0]; + const Tuple& b_moveaxis_index = moveaxis_index_vec[1]; + const mxnet::TShape& a_moveaxis_shape = moveaxis_shape_vec[0]; + const mxnet::TShape& b_moveaxis_shape = moveaxis_shape_vec[1]; + const mxnet::TShape& c_shape = c.shape_; + const int a_ndim = a_moveaxis_shape.ndim(); + const int b_ndim = b_moveaxis_shape.ndim(); + + TBlob aw_data, bw_data, cw_data, a_data, b_data; + if (ctx.run_ctx.get_ctx().dev_mask() == cpu::kDevMask) { + // Allocate workspace in cpu, no need to align address. + DType *aw_ptr = workspace.dptr_; + DType *bw_ptr = aw_ptr + a_moveaxis_shape.ProdShape(0, a_ndim - 1); + DType *cw_ptr = bw_ptr + b_moveaxis_shape.ProdShape(0, b_ndim - 1); + aw_data = TBlob(aw_ptr, GetCutoffShape(a_moveaxis_shape), a.dev_mask(), a.dev_id()); + bw_data = TBlob(bw_ptr, GetCutoffShape(b_moveaxis_shape), b.dev_mask(), b.dev_id()); + cw_data = TBlob(cw_ptr, c_shape, c.dev_mask(), c.dev_id()); + a_data = a; + b_data = b; + if (a_axis != a_ndim -1 || b_axis != b_ndim - 1) { + DType *a_ptr = cw_ptr + c_shape.Size(); + DType *b_ptr = a_ptr + a_moveaxis_shape.Size(); + a_data = TBlob(a_ptr, a_moveaxis_shape, a.dev_mask(), a.dev_id()); + b_data = TBlob(b_ptr, b_moveaxis_shape, b.dev_mask(), b.dev_id()); + TransposeImpl(ctx.run_ctx, a, a_data, + mxnet::TShape(a_moveaxis_index.begin(), a_moveaxis_index.end())); + TransposeImpl(ctx.run_ctx, b, b_data, + mxnet::TShape(b_moveaxis_index.begin(), b_moveaxis_index.end())); + } + } else { + // Allocate workspace in cpu, need to align address. + size_t offset = 0; + aw_data = TBlob(workspace.dptr_ + offset, GetCutoffShape(a_moveaxis_shape), + a.dev_mask(), a.dev_id()); + offset = AligndWorkspaceSize(offset, aw_data.shape_.Size()); + + bw_data = TBlob(workspace.dptr_ + offset, GetCutoffShape(b_moveaxis_shape), + b.dev_mask(), b.dev_id()); + offset = AligndWorkspaceSize(offset, bw_data.shape_.Size()); + + cw_data = TBlob(workspace.dptr_ + offset, c_shape, + c.dev_mask(), c.dev_id()); + offset = AligndWorkspaceSize(offset, cw_data.shape_.Size()); + a_data = a; + b_data = b; + if (a_axis != a_ndim -1 || b_axis != b_ndim - 1) { + a_data = TBlob(workspace.dptr_ + offset, a_moveaxis_shape, + a.dev_mask(), a.dev_id()); + offset = AligndWorkspaceSize(offset, a_data.shape_.Size()); + + b_data = TBlob(workspace.dptr_ + offset, b_moveaxis_shape, + b.dev_mask(), b.dev_id()); + offset = AligndWorkspaceSize(offset, b_data.shape_.Size()); + + TransposeImpl(ctx.run_ctx, a, a_data, + mxnet::TShape(a_moveaxis_index.begin(), a_moveaxis_index.end())); + TransposeImpl(ctx.run_ctx, b, b_data, + mxnet::TShape(b_moveaxis_index.begin(), b_moveaxis_index.end())); + } + } + mxnet_op::Kernel::Launch(s, aw_data.Size(), a_data.dptr(), + aw_data.dptr(), a_data.size(a_ndim - 1), + 0, a_data.Size()); + mxnet_op::Kernel::Launch(s, bw_data.Size(), b_data.dptr(), + bw_data.dptr(), b_data.size(b_ndim - 1), + 1, b_data.Size()); + BinaryBroadcastCompute(attrs, ctx, { aw_data, bw_data }, + { req[0] }, { c }); + mxnet_op::Kernel::Launch(s, aw_data.Size(), a_data.dptr(), + aw_data.dptr(), a_data.size(a_ndim - 1), + 1, a_data.Size()); + mxnet_op::Kernel::Launch(s, bw_data.Size(), b_data.dptr(), + bw_data.dptr(), b_data.size(b_ndim - 1), + 0, b_data.Size()); + BinaryBroadcastCompute(attrs, ctx, { aw_data, bw_data }, + { kWriteTo }, { cw_data }); + BinaryBroadcastCompute(attrs, ctx, { c, cw_data }, + { kWriteTo }, { c }); + } +}; + +template +void NumpyCrossForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + + Stream *s = ctx.get_stream(); + const TBlob& a = inputs[0]; + const TBlob& b = inputs[1]; + const TBlob& c = outputs[0]; + + if (kNullOp == req[0]) { return; } + // Zero-size output, no need to launch kernel + if (0U == a.Size() || 0U == b.Size()) { return; } + + const NumpyCrossParam& param = nnvm::get(attrs.parsed); + const mxnet::TShape& a_shape = a.shape_; + const mxnet::TShape& b_shape = b.shape_; + const mxnet::TShape& c_shape = c.shape_; + const int a_ndim = a_shape.ndim(); + const int b_ndim = b_shape.ndim(); + const int c_ndim = c_shape.ndim(); + Tuple a_moveaxis_index = GetMoveaxisIndex(param.axisa, -1, a_shape); + Tuple b_moveaxis_index = GetMoveaxisIndex(param.axisb, -1, b_shape); + Tuple c_moveaxis_index = GetMoveaxisIndex(param.axisc, -1, c_shape); + mxnet::TShape a_moveaxis_shape = GetMoveaxisShape(a_moveaxis_index, a_shape); + mxnet::TShape b_moveaxis_shape = GetMoveaxisShape(b_moveaxis_index, b_shape); + mxnet::TShape c_moveaxis_shape = GetMoveaxisShape(c_moveaxis_index, c_shape); + const int a_axis = CheckAxis(param.axisa, a_ndim); + const int b_axis = CheckAxis(param.axisb, b_ndim); + const int c_axis = CheckAxis(param.axisc, c_ndim); + const std::vector shape_vec({ a_moveaxis_shape, b_moveaxis_shape, + c_moveaxis_shape }); + const std::vector > index_vec({ a_moveaxis_index, b_moveaxis_index, + c_moveaxis_index }); + + MSHADOW_SGL_DBL_TYPE_SWITCH(c.type_flag_, DType, { + // Calculate workspace. + size_t workspace_size = NumpyCrossWorkspaceSize(a_moveaxis_shape, + b_moveaxis_shape, + c_moveaxis_shape, + a_axis, b_axis, ctx, req); + Tensor workspace = ctx.requested[0].get_space_typed( + Shape1(workspace_size), s); + + if (a_moveaxis_shape[a_ndim - 1] == 2) { + if (b_moveaxis_shape[b_ndim - 1] == 2) { + // Case 1: a.shape[-1] == 2 and b.shape[-1] == 2, param.axisc is ignored. + NumpyCrossForwardImpl::op(a, b, c, index_vec, shape_vec, + a_axis, b_axis, c_axis, + attrs, ctx, req, workspace); + } else { + // Case 2: a.shape[-1] == 2 and b.shape[-1] == 3, param.axisc is not ignored. + NumpyCrossForwardImpl::op(a, b, c, index_vec, shape_vec, + a_axis, b_axis, c_axis, + attrs, ctx, req, workspace); + } + } else { + if (b_moveaxis_shape[b_ndim - 1] == 2) { + // Case 3: a.shape[-1] == 3 and b.shape[-1] == 2, param.axisc is not ignored. + NumpyCrossForwardImpl::op(a, b, c, index_vec, shape_vec, + a_axis, b_axis, c_axis, + attrs, ctx, req, workspace); + } else { + // Case 4: a.shape[-1] == 3 and b.shape[-1] == 3, param.axisc is not ignored. + NumpyCrossForwardImpl::op(a, b, c, index_vec, shape_vec, + a_axis, b_axis, c_axis, + attrs, ctx, req, workspace); + } + } + }); +} + +inline bool CheckUseBroadcast(const mxnet::TShape& a_move_shape, + const mxnet::TShape& b_move_shape) { + return !(GetCutoffShape(a_move_shape) == GetCutoffShape(b_move_shape)); +} + +inline mxnet::TShape GetOriShape(const mxnet::TShape& move_shape, + const int axis) { + Tuple origin_index = GetMoveaxisIndex(-1, axis, move_shape); + return GetMoveaxisShape(origin_index, move_shape); +} + +inline std::vector GetReduceAxis(const mxnet::TShape& move_shape, + const mxnet::TShape& broad_move_shape) { + std::vector axis_idx; + if (move_shape.ndim() == broad_move_shape.ndim() || + move_shape.ndim() == broad_move_shape.ndim() + 1) { + for (int i = 0; i < move_shape.ndim() - 1; ++i) { + if (move_shape[i] != broad_move_shape[i]) { axis_idx.push_back(i); } + } + } + return axis_idx; +} + +template +inline void CrossImplWrap(const std::vector& inputs, + const std::vector& outputs, + const std::vector& axises, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const OpReqType& req, + const Tensor& workspace) { + using namespace mshadow; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(axises.size(), 3U); + const TBlob& a = inputs[0]; + const TBlob& b = inputs[1]; + const TBlob& c = outputs[0]; + if (kNullOp == req) { return; } + if (0U == a.Size() || 0U == b.Size()) { return; } + + const int a_axis = CheckAxis(axises[0], a.ndim()); + const int b_axis = CheckAxis(axises[1], b.ndim()); + const int c_axis = CheckAxis(axises[2], c.ndim()); + Tuple a_move_index = GetMoveaxisIndex(a_axis, -1, a.shape_); + Tuple b_move_index = GetMoveaxisIndex(b_axis, -1, b.shape_); + Tuple c_move_index = GetMoveaxisIndex(c_axis, -1, c.shape_); + mxnet::TShape a_move_shape = GetMoveaxisShape(a_move_index, a.shape_); + mxnet::TShape b_move_shape = GetMoveaxisShape(b_move_index, b.shape_); + mxnet::TShape c_move_shape = GetMoveaxisShape(c_move_index, c.shape_); + // Check workspace size. + size_t workspace_size = NumpyCrossWorkspaceSize(a_move_shape, b_move_shape, + c_move_shape, a_axis, b_axis, + ctx, { req }); + CHECK_GE(workspace.MSize(), workspace_size) + << "Not enough working space size for cross product(should >= " << workspace_size << ")"; + NumpyCrossForwardImpl::op(a, b, c, + { a_move_index, b_move_index}, + { a_move_shape, b_move_shape, c_move_shape}, + a_axis, b_axis, c_axis, + attrs, ctx, { req }, workspace); +} + +template +struct ReduceImplWrap { + static size_t wks(const mxnet::TShape& out_shape, + const mxnet::TShape& out_move_shape, + const mxnet::TShape& in_shape, + const mxnet::TShape& in_move_shape, + const OpContext& ctx, const OpReqType& req) { + size_t ws_reduce = 0U; + std::vector reduce_axis = GetReduceAxis(out_move_shape, in_move_shape); + if (reduce_axis.empty() || req == kNullOp) { return 0U; } + SUM_NDIM_SWITCH(out_shape.ndim(), NDim, { + ws_reduce = broadcast::ReduceWorkspaceSize(ctx.get_stream(), + out_shape, req, in_shape); + }); + return ws_reduce; + } + + static void op(const TBlob& work_in, + const TBlob& work_out, + const TBlob& out_data, + const OpContext& ctx, + const OpReqType& out_req, + const Tensor workspace_tensor) { + Stream *s = ctx.get_stream(); + // Reduce work_in to work_out. + SUM_NDIM_SWITCH(work_out.ndim(), NDim, { + op::broadcast::Reduce( + s, work_out, kWriteTo, workspace_tensor, work_in); + }); + // Copy work_out to out_data. + MXNET_ASSIGN_REQ_SWITCH(out_req, req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, out_data.Size(), work_out.dptr(), out_data.dptr()); + }); + } +}; + +template +struct NumpyCrossBackwardImpl { + static std::vector + CrossBwkWorkspaceSize(const bool use_broadcast, + const mxnet::TShape& a_shape, + const mxnet::TShape& b_shape, + const mxnet::TShape& c_shape, + const mxnet::TShape& a_move_shape, + const mxnet::TShape& b_move_shape, + const mxnet::TShape& c_move_shape, + const int& a_axis, const int& b_axis, const int& c_axis, + const OpContext& ctx, + const std::vector& req) { + CHECK((a_dim == 2 && b_dim == 3) || (a_dim == 3 && b_dim == 2)) + << "no specialized NumpyCrossOp defined for template parameters."; + std::vector workspace_size(3, 0U); + if (use_broadcast) { + size_t ws_a = 0U, ws_b = 0U, rws_a = 0U, rws_b = 0U; + size_t ws1 = NumpyCrossWorkspaceSize(b_move_shape, c_move_shape, c_move_shape, + b_axis, c_axis, ctx, { kWriteTo }); + size_t ws2 = NumpyCrossWorkspaceSize(b_move_shape, c_move_shape, a_move_shape, + b_axis, c_axis, ctx, { kWriteTo }); + ws_a = std::max(ws1, ws2); + size_t ws3 = NumpyCrossWorkspaceSize(c_move_shape, a_move_shape, c_move_shape, + c_axis, a_axis, ctx, { kWriteTo }); + size_t ws4 = NumpyCrossWorkspaceSize(c_move_shape, a_move_shape, b_move_shape, + c_axis, a_axis, ctx, { kWriteTo }); + ws_b = std::max(ws3, ws4); + // Get delete result shape. + mxnet::TShape c_move_dshape = c_move_shape; + c_move_dshape[c_move_shape.ndim() - 1] = 2; + if (a_dim == 2) { + mxnet::TShape c_dshape = GetOriShape(c_move_dshape, a_axis); + // Calculate workspace size used in sum(grad_a). + size_t rws1 = ReduceImplWrap::wks(c_dshape, c_move_dshape, + c_shape, c_move_shape, ctx, req[0]); + // Calculate workspace size used in sum(grad_b). + size_t rws2 = ReduceImplWrap::wks(b_shape, b_move_shape, + c_shape, c_move_shape, ctx, req[1]); + rws_a = std::max(rws1, rws2); + } + if (b_dim == 2) { + mxnet::TShape c_dshape = GetOriShape(c_move_dshape, b_axis); + // Calculate workspace size used in sum(grad_a). + size_t rws1 = ReduceImplWrap::wks(a_shape, a_move_shape, + c_shape, c_move_shape, ctx, req[0]); + // Calculate workspace size used in sum(grad_b). + size_t rws2 = ReduceImplWrap::wks(c_dshape, c_move_dshape, + c_shape, c_move_shape, ctx, req[1]); + rws_b = std::max(rws1, rws2); + } + size_t rws = (std::max(rws_a, rws_b) + sizeof(DType) - 1) / sizeof(DType); + workspace_size[0] += std::max(ws_a, ws_b); // For cross product workspace. + workspace_size[1] += c_move_shape.Size(); // For cross product result. + workspace_size[1] += c_move_shape.Size(); // For delete result shape. + workspace_size[2] += rws; // For reduce workspace. + } else { + mxnet::TShape a_moveaxis_shape = (a_dim == 2 ? c_move_shape : a_move_shape); + mxnet::TShape b_moveaxis_shape = (b_dim == 2 ? c_move_shape : b_move_shape); + size_t ws1 = NumpyCrossWorkspaceSize(b_moveaxis_shape, c_move_shape, + a_moveaxis_shape, b_axis, c_axis, ctx, + { kWriteTo }); + size_t ws2 = NumpyCrossWorkspaceSize(c_move_shape, a_moveaxis_shape, + b_moveaxis_shape, c_axis, a_axis, ctx, + { kWriteTo }); + workspace_size[0] += std::max(ws1, ws2); // For cross product workspace. + if (a_dim == 2 && b_dim == 3) { + workspace_size[1] += a_moveaxis_shape.Size(); // For cross product result. + workspace_size[1] += a_move_shape.Size(); // For delete kernel result. + } + if (a_dim == 3 && b_dim == 2) { + workspace_size[1] += b_moveaxis_shape.Size(); // For cross product result. + workspace_size[1] += b_move_shape.Size(); // For delete kernel result. + } + } + return workspace_size; + } + + static void op(const bool use_broadcast, + const TBlob& grad_c, const TBlob& a, const TBlob& b, + const TBlob& grad_a, const TBlob& grad_b, + const std::vector& moveaxis_shape_vec, + const int a_axis, const int b_axis, const int c_axis, + const nnvm::NodeAttrs &attrs, + const OpContext& ctx, + const std::vector& req) { + CHECK((a_dim == 2 && b_dim == 3) || (a_dim == 3 && b_dim == 2)) + << "no specialized NumpyCrossOp defined for template parameters."; + Stream *s = ctx.get_stream(); + const mxnet::TShape& a_move_shp = moveaxis_shape_vec[0]; + const mxnet::TShape& b_move_shp = moveaxis_shape_vec[1]; + const mxnet::TShape& c_move_shp = moveaxis_shape_vec[2]; + std::vector a_reduce_axis = GetReduceAxis(a_move_shp, c_move_shp); + std::vector b_reduce_axis = GetReduceAxis(b_move_shp, c_move_shp); + const int c_ndim = c_move_shp.ndim(); + // Get delete result shape. + mxnet::TShape c_move_dshp = c_move_shp; + c_move_dshp[c_move_shp.ndim() - 1] = 2; + std::vector wk_size = CrossBwkWorkspaceSize(use_broadcast, + a.shape_, b.shape_, grad_c.shape_, + a_move_shp, b_move_shp, c_move_shp, + a_axis, b_axis, c_axis, ctx, req); + Tensor workspace = ctx.requested[0].get_space_typed( + Shape1(wk_size[0] + wk_size[1] + wk_size[2]), s); + if (use_broadcast) { + // Use broadcast in forward, need reduce in backward. + DType *w0_ptr = workspace.dptr_; + DType *w1_ptr = w0_ptr + wk_size[0]; + DType *w2_ptr = w1_ptr + c_move_shp.Size(); + char *w3_ptr = reinterpret_cast(w2_ptr + c_move_shp.Size()); + TBlob w0_data(w0_ptr, Shape1(wk_size[0]), grad_c.dev_mask(), grad_c.dev_id()); + Tensor w3_tensor(w3_ptr, Shape1(wk_size[2] * sizeof(DType)), s); + if (a_dim == 2) { // a_dim == 2, b_dim == 3 + TBlob w1_data(w1_ptr, c_move_shp, grad_c.dev_mask(), grad_c.dev_id()); + TBlob w2_data(w2_ptr, c_move_dshp, grad_c.dev_mask(), grad_c.dev_id()); + // Calculate grad_a = cross(b, grad_c). + CrossImplWrap({ b, grad_c }, { w1_data }, { b_axis, c_axis, -1 }, + attrs, ctx, kWriteTo, w0_data.get(s)); + // Copy w1_data to w2_data with delete. + mxnet_op::Kernel::Launch(s, c_move_dshp.ProdShape(0, c_ndim - 1), + w1_data.dptr(), + w2_data.dptr(), 3, 2); + // Transpose w2_data to w1_data. + if (a_axis != grad_a.ndim() - 1) { + const Tuple axis_idx = GetMoveaxisIndex(-1, a_axis, c_move_dshp); + mxnet::TShape c_dshp = GetMoveaxisShape(axis_idx, c_move_dshp); + w1_data = TBlob(w1_ptr, c_dshp, grad_c.dev_mask(), grad_c.dev_id()); + TransposeImpl(ctx.run_ctx, w2_data, w1_data, + mxnet::TShape(axis_idx.begin(), axis_idx.end())); + w2_data = TBlob(w2_ptr, grad_a.shape_, grad_c.dev_mask(), grad_c.dev_id()); + } else { + // If no transpose, exchange the pointer. + w1_data = TBlob(w2_ptr, c_move_dshp, grad_c.dev_mask(), grad_c.dev_id()); + w2_data = TBlob(w1_ptr, grad_a.shape_, grad_c.dev_mask(), grad_c.dev_id()); + } + // Reduce w1_data to w2_data. + if (a_reduce_axis.empty()) { + // No need Reduce w1_data, Copy w1_data to grad_a. + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, grad_a.Size(), w1_data.dptr(), grad_a.dptr()); + }); + } else { + // Need Reduce w1_data to w2_data and Copy w2_data to grad_a. + ReduceImplWrap::op(w1_data, w2_data, grad_a, ctx, req[0], w3_tensor); + } + // Calculate grad_b = cross(grad_c, a). + if (b_reduce_axis.empty()) { + CrossImplWrap({ grad_c, a }, { grad_b }, { c_axis, a_axis, b_axis }, + attrs, ctx, req[1], w0_data.get(s)); + } else { + mxnet::TShape c_shp = GetOriShape(c_move_shp, b_axis); + w1_data = TBlob(w1_ptr, c_shp, grad_c.dev_mask(), grad_c.dev_id()); + w2_data = TBlob(w2_ptr, grad_b.shape_, grad_c.dev_mask(), grad_c.dev_id()); + CrossImplWrap({ grad_c, a }, { w1_data }, { c_axis, a_axis, b_axis }, + attrs, ctx, req[1], w0_data.get(s)); + // Need Reduce w1_data to w2_data and Copy w2_data to grad_b. + ReduceImplWrap::op(w1_data, w2_data, grad_b, ctx, req[1], w3_tensor); + } + } // End of a_dim == 2 + if (b_dim == 2) { // a_dim == 3, b_dim == 2 + TBlob w1_data(w1_ptr, c_move_shp, grad_c.dev_mask(), grad_c.dev_id()); + TBlob w2_data(w2_ptr, c_move_dshp, grad_c.dev_mask(), grad_c.dev_id()); + // Calculate grad_b = cross(grad_c, a). + CrossImplWrap({ grad_c, a }, { w1_data }, { c_axis, a_axis, -1 }, + attrs, ctx, kWriteTo, w0_data.get(s)); + // Copy w1_data to w2_data with delete. + mxnet_op::Kernel::Launch(s, c_move_dshp.ProdShape(0, c_ndim - 1), + w1_data.dptr(), + w2_data.dptr(), 3, 2); + // Transpose w2_data to w1_data. + if (b_axis != grad_b.ndim() - 1) { + const Tuple axis_idx = GetMoveaxisIndex(-1, b_axis, c_move_dshp); + mxnet::TShape c_dshp = GetMoveaxisShape(axis_idx, c_move_dshp); + w1_data = TBlob(w1_ptr, c_dshp, grad_c.dev_mask(), grad_c.dev_id()); + TransposeImpl(ctx.run_ctx, w2_data, w1_data, + mxnet::TShape(axis_idx.begin(), axis_idx.end())); + w2_data = TBlob(w2_ptr, grad_b.shape_, grad_c.dev_mask(), grad_c.dev_id()); + } else { + // If no transpose, exchange the pointer. + w1_data = TBlob(w2_ptr, c_move_dshp, grad_c.dev_mask(), grad_c.dev_id()); + w2_data = TBlob(w1_ptr, grad_b.shape_, grad_c.dev_mask(), grad_c.dev_id()); + } + // Reduce w1_data to w2_data. + if (b_reduce_axis.empty()) { + // No need Reduce w1_data, Copy w1_data to grad_b. + MXNET_ASSIGN_REQ_SWITCH(req[1], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, grad_b.Size(), w1_data.dptr(), grad_b.dptr()); + }); + } else { + // Need Reduce w1_data to w2_data and Copy w2_data to grad_a. + ReduceImplWrap::op(w1_data, w2_data, grad_b, ctx, req[1], w3_tensor); + } + // Calculate grad_a = cross(b, grad_c). + if (a_reduce_axis.empty()) { + CrossImplWrap({ b, grad_c }, { grad_a }, { b_axis, c_axis, a_axis }, + attrs, ctx, req[0], workspace); + } else { + mxnet::TShape c_shp = GetOriShape(c_move_shp, a_axis); + w1_data = TBlob(w1_ptr, c_shp, grad_c.dev_mask(), grad_c.dev_id()); + w2_data = TBlob(w2_ptr, grad_a.shape_, grad_c.dev_mask(), grad_c.dev_id()); + CrossImplWrap({ b, grad_c }, { w1_data }, { b_axis, c_axis, a_axis }, + attrs, ctx, req[0], w0_data.get(s)); + // Need Reduce w1_data to w2_data and Copy w2_data to grad_b. + ReduceImplWrap::op(w1_data, w2_data, grad_a, ctx, req[0], w3_tensor); + } + } // End of b_dim == 3 + } else { + // No use broadcast in forward, not need reduce in backward. + DType *w0_ptr = workspace.dptr_; + DType *w1_ptr = w0_ptr + wk_size[0]; + DType *w2_ptr = w1_ptr + c_move_shp.Size(); + TBlob w0_data(w0_ptr, Shape1(wk_size[0]), grad_c.dev_mask(), grad_c.dev_id()); + if (a_dim == 2) { // a_dim == 2, b_dim == 3 + TBlob w1_data(w1_ptr, c_move_shp, grad_c.dev_mask(), grad_c.dev_id()); + TBlob w2_data(w2_ptr, a_move_shp, grad_c.dev_mask(), grad_c.dev_id()); + // Calculate w1_data = cross(b, grad_c). + CrossImplWrap({ b, grad_c }, { w1_data }, { b_axis, c_axis, -1 }, + attrs, ctx, kWriteTo, w0_data.get(s)); + // Calculate grad_b = cross(grad_c, a). + CrossImplWrap({ grad_c, a }, { grad_b }, { c_axis, a_axis, b_axis }, + attrs, ctx, req[1], w0_data.get(s)); + // Copy w1_data to w2_data with delete. + mxnet_op::Kernel::Launch(s, a_move_shp.ProdShape(0, a.ndim() - 1), + w1_data.dptr(), + w2_data.dptr(), 3, 2); + DType *res_ptr = w2_data.dptr(); + if (a_axis != grad_a.ndim() - 1) { + // Transpose w2_data to w1_data. + const Tuple grad_a_axis_idx = GetMoveaxisIndex(-1, a_axis, a_move_shp); + w1_data = TBlob(w1_ptr, grad_a.shape_, grad_c.dev_mask(), grad_c.dev_id()); + TransposeImpl(ctx.run_ctx, w2_data, w1_data, + mxnet::TShape(grad_a_axis_idx.begin(), grad_a_axis_idx.end())); + res_ptr = w1_data.dptr(); + } + // Copy w1_data to grad_a. + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, grad_a.Size(), res_ptr, grad_a.dptr()); + }); + } // End of a_dim == 2 + if (b_dim == 2) { // a_dim == 3, b_dim == 2 + TBlob w1_data(w1_ptr, c_move_shp, grad_c.dev_mask(), grad_c.dev_id()); + TBlob w2_data(w2_ptr, b_move_shp, grad_c.dev_mask(), grad_c.dev_id()); + // Calculate grad_a = cross(b, grad_c). + CrossImplWrap({ b, grad_c }, { grad_a }, { b_axis, c_axis, a_axis }, + attrs, ctx, req[0], w0_data.get(s)); + // Calculate w1_data = cross(grad_c, a). + CrossImplWrap({ grad_c, a }, { w1_data }, { c_axis, a_axis, -1 }, + attrs, ctx, kWriteTo, w0_data.get(s)); + // Copy w1_data to w2_data with delete. + mxnet_op::Kernel::Launch(s, b_move_shp.ProdShape(0, b.ndim() - 1), + w1_data.dptr(), + w2_data.dptr(), 3, 2); + DType *res_ptr = w2_data.dptr(); + if (b_axis != grad_b.ndim() - 1) { + // Transpose w2_data to w1_data. + const Tuple grad_b_axis_idx = GetMoveaxisIndex(-1, b_axis, b_move_shp); + w1_data = TBlob(w1_ptr, grad_b.shape_, grad_c.dev_mask(), grad_c.dev_id()); + TransposeImpl(ctx.run_ctx, w2_data, w1_data, + mxnet::TShape(grad_b_axis_idx.begin(), grad_b_axis_idx.end())); + res_ptr = w1_data.dptr(); + } + // Copy w1_data to grad_b. + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, grad_b.Size(), res_ptr, grad_b.dptr()); + }); + } // End of b_dim == 2 + } // End of use_broadcast + } +}; + +template +struct NumpyCrossBackwardImpl { + static std::vector + CrossBwkWorkspaceSize(const bool use_broadcast, + const mxnet::TShape& a_shape, + const mxnet::TShape& b_shape, + const mxnet::TShape& c_shape, + const mxnet::TShape& a_move_shape, + const mxnet::TShape& b_move_shape, + const mxnet::TShape& c_move_shape, + const int& a_axis, const int& b_axis, const int& c_axis, + const OpContext& ctx, + const std::vector& req) { + std::vector workspace_size(3, 0U); + if (use_broadcast) { + // Calculate workspace size used in cross(b_move, grad_c_move). + size_t ws1 = NumpyCrossWorkspaceSize(b_move_shape, c_move_shape, c_move_shape, + b_axis, c_axis, ctx, { kWriteTo }); + // Calculate workspace size used in cross(grad_c_move, a_move). + size_t ws2 = NumpyCrossWorkspaceSize(c_move_shape, a_move_shape, c_move_shape, + c_axis, a_axis, ctx, { kWriteTo }); + // Calculate workspace size used in sum(grad_a). + size_t rws1 = ReduceImplWrap::wks(a_shape, a_move_shape, + c_shape, c_move_shape, ctx, req[0]); + // Calculate workspace size used in sum(grad_b). + size_t rws2 = ReduceImplWrap::wks(b_shape, b_move_shape, + c_shape, c_move_shape, ctx, req[1]); + // For cross product workspace. + workspace_size[0] += std::max(ws1, ws2); + // For reduce workspace. + workspace_size[1] += (std::max(rws1, rws2) + sizeof(DType) - 1) / sizeof(DType); + // For cross result and reduce result. + workspace_size[2] += c_move_shape.Size(); + workspace_size[2] += std::max(a_move_shape.Size(), b_move_shape.Size()); + } else { + size_t ws1 = NumpyCrossWorkspaceSize(b_move_shape, c_move_shape, a_move_shape, + b_axis, c_axis, ctx, { req[0] }); + size_t ws2 = NumpyCrossWorkspaceSize(c_move_shape, a_move_shape, b_move_shape, + c_axis, a_axis, ctx, { req[1] }); + workspace_size[0] += std::max(ws1, ws2); // For cross product workspace. + } + return workspace_size; + } + + static void op(const bool use_broadcast, + const TBlob& grad_c, const TBlob& a, const TBlob& b, + const TBlob& grad_a, const TBlob& grad_b, + const std::vector& moveaxis_shape_vec, + const int a_axis, const int b_axis, const int c_axis, + const nnvm::NodeAttrs &attrs, + const OpContext& ctx, + const std::vector& req) { + Stream *s = ctx.get_stream(); + const mxnet::TShape& a_move_shp = moveaxis_shape_vec[0]; + const mxnet::TShape& b_move_shp = moveaxis_shape_vec[1]; + const mxnet::TShape& c_move_shp = moveaxis_shape_vec[2]; + std::vector wk_size = CrossBwkWorkspaceSize(use_broadcast, + a.shape_, b.shape_, grad_c.shape_, + a_move_shp, b_move_shp, c_move_shp, + a_axis, b_axis, c_axis, ctx, req); + Tensor workspace = ctx.requested[0].get_space_typed( + Shape1(wk_size[0] + wk_size[1] + wk_size[2]), s); + if (use_broadcast) { + // Use broadcast in forward, need reduce in backward. + std::vector a_reduce_axis = GetReduceAxis(a_move_shp, c_move_shp); + std::vector b_reduce_axis = GetReduceAxis(b_move_shp, c_move_shp); + // Allocate workspace. + DType *w0_ptr = workspace.dptr_; + char *w1_ptr = reinterpret_cast(w0_ptr + wk_size[0]); + TBlob w0_data(w0_ptr, Shape1(wk_size[0]), grad_c.dev_mask(), grad_c.dev_id()); + Tensor w1_tensor(w1_ptr, Shape1(wk_size[1] * sizeof(DType)), s); + if (a_reduce_axis.empty()) { + // Calculate grad_a = cross(b, grad_c). + CrossImplWrap({ b, grad_c }, { grad_a }, { b_axis, c_axis, a_axis }, + attrs, ctx, req[0], w0_data.get(s)); + } else { + mxnet::TShape c_shp = GetOriShape(c_move_shp, a_axis); + DType *w2_ptr = w0_ptr + wk_size[0] + wk_size[1]; + DType *w3_ptr = w2_ptr + c_move_shp.Size(); + TBlob w2_data(w2_ptr, c_shp, grad_c.dev_mask(), grad_c.dev_id()); + TBlob w3_data(w3_ptr, grad_a.shape_, grad_c.dev_mask(), grad_c.dev_id()); + // Calculate w2_data = cross(b, grad_c). + CrossImplWrap({ b, grad_c }, { w2_data }, { b_axis, c_axis, a_axis }, + attrs, ctx, kWriteTo, w0_data.get(s)); + // Reduce w2_data to w3_data and Copy w3_data to grad_a. + ReduceImplWrap::op(w2_data, w3_data, grad_a, ctx, req[0], w1_tensor); + } + if (b_reduce_axis.empty()) { + // Calculate grad_b = cross(grad_c, a). + CrossImplWrap({ grad_c, a }, { grad_b }, { c_axis, a_axis, b_axis }, + attrs, ctx, req[1], w0_data.get(s)); + } else { + mxnet::TShape c_shp = GetOriShape(c_move_shp, b_axis); + DType *w2_ptr = w0_ptr + wk_size[0] + wk_size[1]; + DType *w3_ptr = w2_ptr + c_move_shp.Size(); + TBlob w2_data(w2_ptr, c_shp, grad_c.dev_mask(), grad_c.dev_id()); + TBlob w3_data(w3_ptr, grad_b.shape_, grad_c.dev_mask(), grad_c.dev_id()); + // Calculate w2_data = cross(grad_c, a). + CrossImplWrap({ grad_c, a }, { w2_data }, { c_axis, a_axis, b_axis }, + attrs, ctx, kWriteTo, w0_data.get(s)); + // Reduce w2_data to w3_data and Copy w3_data to grad_b. + ReduceImplWrap::op(w2_data, w3_data, grad_b, ctx, req[1], w1_tensor); + } + } else { + CrossImplWrap({ b, grad_c }, { grad_a }, { b_axis, c_axis, a_axis }, + attrs, ctx, req[0], workspace); + CrossImplWrap({ grad_c, a }, { grad_b }, { c_axis, a_axis, b_axis }, + attrs, ctx, req[1], workspace); + } + } +}; + +template +struct NumpyCrossBackwardImpl { + static std::vector + CrossBwkWorkspaceSize(const bool use_broadcast, + const mxnet::TShape& a_shape, + const mxnet::TShape& b_shape, + const mxnet::TShape& c_shape, + const mxnet::TShape& a_move_shape, + const mxnet::TShape& b_move_shape, + const mxnet::TShape& c_move_shape, + const int& a_axis, const int& b_axis, const int& c_axis, + const OpContext& ctx, + const std::vector& req) { + std::vector workspace_size(3, 0U); + const int a_ndim = a_move_shape.ndim(); + const int b_ndim = b_move_shape.ndim(); + const int c_ndim = c_move_shape.ndim(); + mxnet::TShape grad_move_shape(c_ndim + 1, 2); + for (int i = 0; i < c_ndim; ++i) { grad_move_shape[i] = c_move_shape[i]; } + + workspace_size[0] += grad_move_shape.Size(); // For grad_a_move or grad_b_move. + workspace_size[0] += a_move_shape.ProdShape(0, a_ndim - 1); // For a_move work data. + workspace_size[0] += b_move_shape.ProdShape(0, b_ndim - 1); // For b_move work data. + workspace_size[0] = // For c_move work data. + AligndWorkspaceSize(workspace_size[0], c_move_shape.Size()); + + if (a_axis != a_ndim -1 || b_axis != b_ndim - 1) { + if (ctx.run_ctx.get_ctx().dev_mask() == cpu::kDevMask) { + workspace_size[1] += a_move_shape.Size(); // For a_move size. + workspace_size[1] += b_move_shape.Size(); // For b_move size. + workspace_size[1] += grad_move_shape.Size(); // For grad_a_move or grad_b_move trans. + } else { + workspace_size[1] = AligndWorkspaceSize(workspace_size[1], a_move_shape.Size()); + workspace_size[1] = AligndWorkspaceSize(workspace_size[1], b_move_shape.Size()); + workspace_size[1] = AligndWorkspaceSize(workspace_size[1], grad_move_shape.Size()); + } + } + if (use_broadcast) { + mxnet::TShape grad_a_dshape = GetOriShape(grad_move_shape, a_axis); + mxnet::TShape grad_b_dshape = GetOriShape(grad_move_shape, b_axis); + size_t rws1 = ReduceImplWrap::wks(a_shape, a_move_shape, + grad_a_dshape, grad_move_shape, ctx, req[0]); + size_t rws2 = ReduceImplWrap::wks(b_shape, b_move_shape, + grad_b_dshape, grad_move_shape, ctx, req[1]); + size_t rws = (std::max(rws1, rws2) + sizeof(DType) - 1) / sizeof(DType); + workspace_size[2] += std::max(a_shape.Size(), b_shape.Size()); // For reduce result. + workspace_size[2] += rws; // For reduce workspace. + } + return workspace_size; + } + + static void op(const bool use_broadcast, + const TBlob& grad_c, const TBlob& a, const TBlob& b, + const TBlob& grad_a, const TBlob& grad_b, + const std::vector& moveaxis_shape_vec, + const int a_axis, const int b_axis, const int c_axis, + const nnvm::NodeAttrs &attrs, + const OpContext& ctx, + const std::vector& req) { + Stream *s = ctx.get_stream(); + Tuple a_move_idx = GetMoveaxisIndex(a_axis, -1, a.shape_); + Tuple b_move_idx = GetMoveaxisIndex(b_axis, -1, b.shape_); + const mxnet::TShape& a_move_shp = moveaxis_shape_vec[0]; + const mxnet::TShape& b_move_shp = moveaxis_shape_vec[1]; + const mxnet::TShape& c_move_shp = grad_c.shape_; + std::vector a_reduce_axis = GetReduceAxis(a_move_shp, c_move_shp); + std::vector b_reduce_axis = GetReduceAxis(b_move_shp, c_move_shp); + const int a_ndim = a_move_shp.ndim(); + const int b_ndim = b_move_shp.ndim(); + const int c_ndim = c_move_shp.ndim(); + mxnet::TShape grad_move_shp(c_ndim + 1, 2); + for (int i = 0; i < c_ndim; ++i) { grad_move_shp[i] = c_move_shp[i]; } + // Calculate workspace size. + std::vector wk_size = CrossBwkWorkspaceSize(use_broadcast, + a.shape_, b.shape_, grad_c.shape_, + a_move_shp, b_move_shp, c_move_shp, + a_axis, b_axis, c_axis, ctx, req); + Tensor workspace = ctx.requested[0].get_space_typed( + Shape1(wk_size[0] + wk_size[1] + wk_size[2]), s); + + // Allocate workspace in cpu, no need to align address. + DType *grad_ptr = workspace.dptr_; + DType *aw_ptr = grad_ptr + grad_move_shp.Size(); + DType *bw_ptr = aw_ptr + a_move_shp.ProdShape(0, a_ndim - 1); + DType *cw_ptr = bw_ptr + b_move_shp.ProdShape(0, b_ndim - 1); + TBlob grad_move_data(grad_ptr, grad_move_shp, grad_c.dev_mask(), grad_c.dev_id()); + TBlob aw_data(aw_ptr, GetCutoffShape(a_move_shp), grad_c.dev_mask(), grad_c.dev_id()); + TBlob bw_data(bw_ptr, GetCutoffShape(b_move_shp), grad_c.dev_mask(), grad_c.dev_id()); + TBlob cw_data(cw_ptr, c_move_shp, grad_c.dev_mask(), grad_c.dev_id()); + TBlob a_move_data = a; + TBlob b_move_data = b; + TBlob grad_data = grad_move_data; + size_t offset = 0; + if (a_axis != a_ndim -1 || b_axis != b_ndim - 1) { + if (ctx.run_ctx.get_ctx().dev_mask() == cpu::kDevMask) { + DType *a_ptr = workspace.dptr_ + wk_size[0]; + DType *b_ptr = a_ptr + a_move_shp.Size(); + a_move_data = TBlob(a_ptr, a_move_shp, a.dev_mask(), a.dev_id()); + b_move_data = TBlob(b_ptr, b_move_shp, b.dev_mask(), b.dev_id()); + TransposeImpl(ctx.run_ctx, a, a_move_data, + mxnet::TShape(a_move_idx.begin(), a_move_idx.end())); + TransposeImpl(ctx.run_ctx, b, b_move_data, + mxnet::TShape(b_move_idx.begin(), b_move_idx.end())); + } else { + DType *w1_ptr = workspace.dptr_ + wk_size[0]; + a_move_data = TBlob(w1_ptr + offset, a_move_shp, a.dev_mask(), a.dev_id()); + offset = AligndWorkspaceSize(offset, a_move_shp.Size()); + + b_move_data = TBlob(w1_ptr + offset, b_move_shp, a.dev_mask(), a.dev_id()); + offset = AligndWorkspaceSize(offset, a_move_shp.Size()); + + TransposeImpl(ctx.run_ctx, a, a_move_data, + mxnet::TShape(a_move_idx.begin(), a_move_idx.end())); + TransposeImpl(ctx.run_ctx, b, b_move_data, + mxnet::TShape(b_move_idx.begin(), b_move_idx.end())); + } + } + // Copy b_move_data[..., 1] to bw_data. + mxnet_op::Kernel::Launch(s, bw_data.Size(), + b_move_data.dptr(), + bw_data.dptr(), + b_move_data.size(b_ndim - 1), + 1, b_move_data.Size()); + // cw_data = grad_c_move * b_move_data[..., 1]. + BinaryBroadcastCompute(attrs, ctx, { grad_c, bw_data }, + { kWriteTo }, { cw_data }); + // Copy cw_data to grad_move_data[..., 0]. + mxnet_op::Kernel, xpu>::Launch(s, cw_data.Size(), + cw_data.dptr(), + grad_move_data.dptr(), + true, grad_move_data.size(c_ndim), + 0, grad_move_data.Size()); + // Copy b_move_data[..., 0] to bw_data. + mxnet_op::Kernel::Launch(s, bw_data.Size(), + b_move_data.dptr(), + bw_data.dptr(), + b_move_data.size(b_ndim - 1), + 0, b_move_data.Size()); + // cw_data = grad_c_move * b_move_data[..., 0]. + BinaryBroadcastCompute(attrs, ctx, { grad_c, bw_data }, + { kWriteTo }, { cw_data }); + // Copy -cw_data to grad_move_data[..., 1]. + mxnet_op::Kernel, xpu>::Launch(s, cw_data.Size(), + cw_data.dptr(), + grad_move_data.dptr(), + false, grad_move_data.size(c_ndim), + 1, grad_move_data.Size()); + // Transpose grad_move_data according to a_axis. + grad_data = grad_move_data; + if (a_axis != a_ndim - 1) { + mxnet::TShape grad_shp = GetOriShape(grad_move_shp, a_axis); + if (ctx.run_ctx.get_ctx().dev_mask() == cpu::kDevMask) { + DType *grad_ptr = workspace.dptr_ + wk_size[0] + a_move_shp.Size() + b_move_shp.Size(); + grad_data = TBlob(grad_ptr, grad_shp, grad_c.dev_mask(), grad_c.dev_id()); + } else { + DType *w1_ptr = workspace.dptr_ + wk_size[0]; + grad_data = TBlob(w1_ptr + offset, grad_shp, grad_c.dev_mask(), grad_c.dev_id()); + offset = AligndWorkspaceSize(offset, grad_shp.Size()); + } + const Tuple axis_idx = GetMoveaxisIndex(-1, a_axis, grad_move_shp); + TransposeImpl(ctx.run_ctx, grad_move_data, grad_data, + mxnet::TShape(axis_idx.begin(), axis_idx.end())); + } + if (!a_reduce_axis.empty()) { + size_t interval = std::max(grad_a.Size(), grad_b.Size()); + DType *grad_delete_ptr = workspace.dptr_ + wk_size[0] + wk_size[1]; + char *dw_ptr = reinterpret_cast(grad_delete_ptr + interval); + TBlob grad_delete_data(grad_delete_ptr, grad_a.shape_, grad_c.dev_mask(), grad_c.dev_id()); + Tensor dw_tensor(dw_ptr, Shape1((wk_size[2] - interval) * sizeof(DType)), s); + // Reduce grad_data to grad_delete_data and copy to grad_a. + ReduceImplWrap::op(grad_data, grad_delete_data, grad_a, ctx, req[0], dw_tensor); + } else { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, grad_a.Size(), grad_data.dptr(), grad_a.dptr()); + }); + } + + // Copy a_move_data[..., 1] to aw_data. + mxnet_op::Kernel::Launch(s, aw_data.Size(), + a_move_data.dptr(), + aw_data.dptr(), + a_move_data.size(a_ndim - 1), + 1, a_move_data.Size()); + // cw_data = grad_c_move * a_move_data[..., 1]. + BinaryBroadcastCompute(attrs, ctx, { grad_c, aw_data }, + { kWriteTo }, { cw_data }); + // Copy -cw_data to grad_move_data[..., 0]. + mxnet_op::Kernel, xpu>::Launch(s, cw_data.Size(), + cw_data.dptr(), + grad_move_data.dptr(), + false, grad_move_data.size(c_ndim), + 0, grad_move_data.Size()); + // Copy a_move_data[..., 0] to aw_data. + mxnet_op::Kernel::Launch(s, aw_data.Size(), + a_move_data.dptr(), + aw_data.dptr(), + a_move_data.size(a_ndim - 1), + 0, a_move_data.Size()); + // cw_data = grad_c_move * a_move_data[..., 0]. + BinaryBroadcastCompute(attrs, ctx, { grad_c, aw_data }, + { kWriteTo }, { cw_data }); + // Copy cw_data to grad_move_data[..., 1]. + mxnet_op::Kernel, xpu>::Launch(s, cw_data.Size(), + cw_data.dptr(), + grad_move_data.dptr(), + true, grad_move_data.size(c_ndim), + 1, grad_move_data.Size()); + // Transpose grad_move_data according to b_axis. + grad_data = grad_move_data; + if (b_axis != b_ndim - 1) { + mxnet::TShape grad_shp = GetOriShape(grad_move_shp, b_axis); + if (ctx.run_ctx.get_ctx().dev_mask() == cpu::kDevMask) { + DType *grad_ptr = workspace.dptr_ + wk_size[0] + a_move_shp.Size() + b_move_shp.Size(); + grad_data = TBlob(grad_ptr, grad_shp, grad_c.dev_mask(), grad_c.dev_id()); + } else { + DType *w1_ptr = workspace.dptr_ + wk_size[0]; + grad_data = TBlob(w1_ptr + offset, grad_shp, grad_c.dev_mask(), grad_c.dev_id()); + offset = AligndWorkspaceSize(offset, grad_shp.Size()); + } + const Tuple axis_idx = GetMoveaxisIndex(-1, b_axis, grad_move_shp); + TransposeImpl(ctx.run_ctx, grad_move_data, grad_data, + mxnet::TShape(axis_idx.begin(), axis_idx.end())); + } + if (!b_reduce_axis.empty()) { + size_t interval = std::max(grad_a.Size(), grad_b.Size()); + DType *grad_delete_ptr = workspace.dptr_ + wk_size[0] + wk_size[1]; + char *dw_ptr = reinterpret_cast(grad_delete_ptr + interval); + TBlob grad_delete_data(grad_delete_ptr, grad_b.shape_, grad_c.dev_mask(), grad_c.dev_id()); + Tensor dw_tensor(dw_ptr, Shape1((wk_size[2] - interval) * sizeof(DType)), s); + // Reduce grad_data to grad_delete_data and copy to grad_a. + ReduceImplWrap::op(grad_data, grad_delete_data, grad_b, ctx, req[1], dw_tensor); + } else { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, grad_b.Size(), grad_data.dptr(), grad_b.dptr()); + }); + } + } +}; + +template +void NumpyCrossBackward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + const TBlob& grad_c = inputs[0]; + const TBlob& a = inputs[1]; + const TBlob& b = inputs[2]; + const TBlob& grad_a = outputs[0]; + const TBlob& grad_b = outputs[1]; + + if (kNullOp == req[0] && kNullOp == req[1]) { return; } + // Zero-size output, no need to launch kernel + if (0U == grad_c.Size()) { return; } + + const mxnet::TShape& a_shape = a.shape_; + const mxnet::TShape& b_shape = b.shape_; + const mxnet::TShape& c_shape = grad_c.shape_; + const int a_ndim = a_shape.ndim(); + const int b_ndim = b_shape.ndim(); + const int c_ndim = c_shape.ndim(); + const NumpyCrossParam& param = nnvm::get(attrs.parsed); + Tuple a_moveaxis_index = GetMoveaxisIndex(param.axisa, -1, a_shape); + Tuple b_moveaxis_index = GetMoveaxisIndex(param.axisb, -1, b_shape); + Tuple c_moveaxis_index = GetMoveaxisIndex(param.axisc, -1, c_shape); + mxnet::TShape a_moveaxis_shape = GetMoveaxisShape(a_moveaxis_index, a_shape); + mxnet::TShape b_moveaxis_shape = GetMoveaxisShape(b_moveaxis_index, b_shape); + mxnet::TShape c_moveaxis_shape = GetMoveaxisShape(c_moveaxis_index, c_shape); + const int a_axis = CheckAxis(param.axisa, a_ndim); + const int b_axis = CheckAxis(param.axisb, b_ndim); + const int c_axis = CheckAxis(param.axisc, c_ndim); + std::vector move_shp_vec({ a_moveaxis_shape, b_moveaxis_shape, c_moveaxis_shape }); + + MSHADOW_SGL_DBL_TYPE_SWITCH(grad_c.type_flag_, DType, { + bool use_broadcast = CheckUseBroadcast(a_moveaxis_shape, b_moveaxis_shape); + if (a_moveaxis_shape[a_ndim - 1] == 2) { + if (b_moveaxis_shape[b_ndim - 1] == 2) { + // Case 1: a.shape[-1] == 2 and b.shape[-1] == 2, param.axisc is ignored. + NumpyCrossBackwardImpl::op(use_broadcast, grad_c, a, b, grad_a, grad_b, + move_shp_vec, a_axis, b_axis, c_axis, + attrs, ctx, req); + } else { + // Case 2: a.shape[-1] == 2 and b.shape[-1] == 3, param.axisc is not ignored. + NumpyCrossBackwardImpl::op(use_broadcast, grad_c, a, b, grad_a, grad_b, + move_shp_vec, a_axis, b_axis, c_axis, + attrs, ctx, req); + } + } else { + if (b_moveaxis_shape[b_ndim - 1] == 2) { + // Case 3: a.shape[-1] == 3 and b.shape[-1] == 2, param.axisc is not ignored. + NumpyCrossBackwardImpl::op(use_broadcast, grad_c, a, b, grad_a, grad_b, + move_shp_vec, a_axis, b_axis, c_axis, + attrs, ctx, req); + } else { + // Case 4: a.shape[-1] == 3 and b.shape[-1] == 3, param.axisc is not ignored. + NumpyCrossBackwardImpl::op(use_broadcast, grad_c, a, b, grad_a, grad_b, + move_shp_vec, a_axis, b_axis, c_axis, + attrs, ctx, req); + } + } + }); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_CROSS_INL_H_ diff --git a/src/operator/numpy/np_cross.cc b/src/operator/numpy/np_cross.cc new file mode 100644 index 000000000000..cdf81cc3edfd --- /dev/null +++ b/src/operator/numpy/np_cross.cc @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file np_cross.cc + * \brief CPU Implementation of numpy-compatible cross + */ + +#include "./np_cross-inl.h" + +namespace mxnet { +namespace op { + +inline bool NumpyCrossShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const mxnet::TShape& a_shape = in_attrs->at(0); + const mxnet::TShape& b_shape = in_attrs->at(1); + if (!ndim_is_known(a_shape) || !ndim_is_known(b_shape)) { + return false; + } + + if (shape_is_known(a_shape) && shape_is_known(b_shape)) { + const NumpyCrossParam& param = nnvm::get(attrs.parsed); + const int a_ndim = a_shape.ndim(); + const int b_ndim = b_shape.ndim(); + CHECK_GE(a_ndim, 1) << "Array must be at least one-dimensional"; + CHECK_GE(b_ndim, 1) << "Array must be at least one-dimensional"; + CHECK_LE(a_ndim, broadcast::MAX_DIM) + << "cross product support at most " << broadcast::MAX_DIM << " dimensions"; + CHECK_LE(b_ndim, broadcast::MAX_DIM) + << "cross product support at most " << broadcast::MAX_DIM << " dimensions"; + + const Tuple a_moveaxis_index = GetMoveaxisIndex(param.axisa, -1, a_shape); + const Tuple b_moveaxis_index = GetMoveaxisIndex(param.axisb, -1, b_shape); + const mxnet::TShape a_moveaxis_shape = GetMoveaxisShape(a_moveaxis_index, a_shape); + const mxnet::TShape b_moveaxis_shape = GetMoveaxisShape(b_moveaxis_index, b_shape); + + CHECK(a_moveaxis_shape[a_ndim - 1] == 2 || a_moveaxis_shape[a_ndim - 1] == 3) + << "incompatible dimensions for cross product and axis should have dimensions 2 or 3."; + CHECK(b_moveaxis_shape[b_ndim - 1] == 2 || b_moveaxis_shape[b_ndim - 1] == 3) + << "incompatible dimensions for cross product and axis should have dimensions 2 or 3."; + + if (a_ndim == 1 && b_ndim == 1) { + if (a_moveaxis_shape[a_ndim - 1] == 2 && b_moveaxis_shape[b_ndim - 1] == 2) { + // Both 1-D arrays with dim = 2, cross product of vectors. + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(0, 0)); + } else { + // Both 1-D arrays with at least one dim = 3, cross product of vectors. + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, 3)); + } + } else { + mxnet::TShape c_shape; + GetOrCheckLRShape(attrs, a_moveaxis_shape, b_moveaxis_shape, &c_shape); + if (a_moveaxis_shape[a_ndim - 1] == 2 && b_moveaxis_shape[b_ndim - 1] == 2) { + // At least one N-D arrays and both dim = 2, param.axisc is ignored. + SHAPE_ASSIGN_CHECK(*out_attrs, 0, c_shape); + } else { + // At least one N-D arrays and at least one dim = 3, param.axisc not ignored. + // Check axisc is within bounds. + const Tuple c_moveaxis_index = GetMoveaxisIndex(-1, param.axisc, c_shape); + const mxnet::TShape c_moveaxis_shape = GetMoveaxisShape(c_moveaxis_index, c_shape); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, c_moveaxis_shape); + } + } + } + return shape_is_known(*in_attrs) && shape_is_known(*out_attrs); +} + +DMLC_REGISTER_PARAMETER(NumpyCrossParam); + +NNVM_REGISTER_OP(_npi_cross) +.set_attr_parser(ParamParser) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"a", "b"}; +}) +.set_attr("FInferShape", NumpyCrossShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs){ + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("THasDeterministicOutput", true) +.set_attr("FCompute", NumpyCrossForward) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_cross"}) +.add_argument("a", "NDArray-or-Symbol", "First vector") +.add_argument("b", "NDArray-or-Symbol", "Second vector") +.add_arguments(NumpyCrossParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_npi_cross) +.set_attr_parser(ParamParser) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) { + return std::vector(1, ResourceRequest::kTempSpace); +}) +.set_attr("FCompute", NumpyCrossBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_cross.cu b/src/operator/numpy/np_cross.cu new file mode 100644 index 000000000000..50b2a2e7dd2c --- /dev/null +++ b/src/operator/numpy/np_cross.cu @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file np_bicount_op.cu + * \brief GPU Implementation of numpy-compatible cross + */ + +#include "./np_cross-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_cross) +.set_attr("FCompute", NumpyCrossForward); + +NNVM_REGISTER_OP(_backward_npi_cross) +.set_attr("FCompute", NumpyCrossBackward); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 080fb03a7158..5e2dd40371c4 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -132,6 +132,39 @@ def _add_workload_bincount(): OpArgMngr.add_workload('bincount', y2, minlength=5) +def _add_workload_cross(): + shapes = [ + # (a_shape, b_shape, (a_axis, b_axis, c_axis)) + ((2,), (2,), (-1, -1, -1)), + ((1, 2), (1, 2), (-1, -1, -1)), + ((2, 5, 4, 3), (5, 2, 4, 3), (0, 1, 2)), + ((2, 5, 1, 3), (1, 2, 4, 3), (0, 1, 2)), + + ((2,), (3,), (-1, -1, -1)), + ((1, 2,), (1, 3,), (-1, -1, -1)), + ((6, 2, 5, 4), (6, 5, 3, 4), (1, 2, 0)), + ((6, 2, 1, 4), (1, 5, 3, 4), (1, 2, 0)), + + ((3,), (2,), (-1, -1, -1)), + ((1, 3,), (1, 2,), (-1, -1, -1)), + ((6, 3, 5, 4), (6, 5, 2, 4), (1, 2, 0)), + ((6, 3, 1, 4), (1, 5, 2, 4), (1, 2, 0)), + + ((3,), (3,), (-1, -1, -1)), + ((1, 3,), (1, 3,), (-1, -1, -1)), + ((6, 3, 5, 4), (6, 5, 3, 4), (1, 2, 0)), + ((6, 3, 1, 4), (1, 5, 3, 4), (1, 2, 0)), + ] + dtypes = [np.float32, np.float64] + for shape, dtype in itertools.product(shapes, dtypes): + a_shape, b_shape, (a_axis, b_axis, c_axis) = shape + a_np = _np.random.uniform(-10., 10., size=a_shape) + b_np = _np.random.uniform(-10., 10., size=b_shape) + a = np.array(a_np, dtype=dtype) + b = np.array(b_np, dtype=dtype) + OpArgMngr.add_workload('cross', a, b, axisa=a_axis, axisb=b_axis, axisc=c_axis) + + def _add_workload_diag(): def get_mat(n): data = _np.arange(n) @@ -2867,6 +2900,7 @@ def _prepare_workloads(): _add_workload_clip() _add_workload_concatenate(array_pool) _add_workload_copy() + _add_workload_cross() _add_workload_cumsum() _add_workload_ravel() _add_workload_unravel_index() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 3b88dae57d49..4f1a8b5572ad 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -8996,6 +8996,266 @@ def hybrid_forward(self, F, x, *args, **kwargs): assert ret.asnumpy().shape == expected_ret.shape +@with_seed() +@use_np +def test_np_cross(): + class TestNumpyCross(HybridBlock): + def __init__(self, axisa=-1, axisb=-1, axisc=-1, axis=None): + super(TestNumpyCross, self).__init__() + self._axisa = axisa + self._axisb = axisb + self._axisc = axisc + self._axis = axis + + def hybrid_forward(self, F, a, b): + return F.np.cross(a, b, self._axisa, self._axisb, self._axisc, self._axis) + + def check_np_cross(x, a_np, b_np, axises): + try: + if axises is None: + x_expected = _np.cross(a_np, b_np) + elif len(axises) == 4: + (a_axis, b_axis, c_axis, axis,) = axises + x_expected = _np.cross(a_np, b_np, axisa=a_axis, axisb=b_axis, axisc=c_axis, axis=axis) + else: + (a_axis, b_axis, c_axis,) = axises + x_expected = _np.cross(a_np, b_np, axisa=a_axis, axisb=b_axis, axisc=c_axis) + except Exception as e: + print("a:", a_np) + print("a shape:", a_np.shape) + print("b:", b_np) + print("b shape:", b_np.shape) + print(e) + else: + assert x.shape == x_expected.shape + assert_almost_equal(x.asnumpy(), x_expected, rtol=rtol, atol=atol) + + def check_not_use_broadcast(a_np, b_np, axises): + a_shape = a_np.shape + b_shape = b_np.shape + if axises is None: + return a_shape[:-1] == b_shape[:-1] + elif len(axises) == 4: + axis = axises[3] + a_moveaxis_shape = _np.moveaxis(a_np, axis, -1).shape + b_moveaxis_shape = _np.moveaxis(b_np, axis, -1).shape + return a_moveaxis_shape[:-1] == b_moveaxis_shape[:-1] + else: + a_axis = axises[0] + b_axis = axises[1] + a_moveaxis_shape = _np.moveaxis(a_np, a_axis, -1).shape + b_moveaxis_shape = _np.moveaxis(b_np, b_axis, -1).shape + return a_moveaxis_shape[:-1] == b_moveaxis_shape[:-1] + + # calculate dL = gradC * dC + def cal_dL(grad_c_move, dc_move): + num = int(_np.prod(dc_move.shape)) + grad_c_move_1d = grad_c_move.reshape((num,)) + dc_move_1d = dc_move.reshape((num,)) + dL = _np.inner(grad_c_move_1d, dc_move_1d) + return dL + + # get reduced axis index + def get_reduce_axis(shape, broad_shape): + axis = list() + length = len(broad_shape) if len(shape) == len(broad_shape) + 1 else len(broad_shape) - 1 + for i in range(length): + if shape[i] != broad_shape[i]: + axis.append(i) + return tuple(axis) if len(axis) > 0 else None + + # get grad_a and grad_b + def get_cross_backward(a, b, axises): + if axises == None: + a_axis, b_axis, c_axis = (-1,) * 3 + elif len(axises) == 4: + a_axis, b_axis, c_axis = (axises[-1],) * 3 + else: + (a_axis, b_axis, c_axis) = axises + c = _np.cross(a, b, axisa=a_axis, axisb=b_axis, axisc=c_axis) + c_move = _np.moveaxis(c, c_axis, -1) if a.shape[a_axis] == 3 or b.shape[b_axis] == 3 else c + grad_c_move = _np.ones(shape=c_move.shape, dtype=c_move.dtype) + a_move = _np.moveaxis(a, a_axis, -1) + b_move = _np.moveaxis(b, b_axis, -1) + da_move = _np.random.uniform(-1., 1., size=a_move.shape) + db_move = _np.random.uniform(-1., 1., size=b_move.shape) + # dC = dA x B + A x dB + dc_move = _np.cross(da_move, b_move) + _np.cross(a_move, db_move) + # dL1 = Tr(grad_C.T * dC) = dL/dCi * dCi + dL1 = cal_dL(grad_c_move, dc_move) + # check cross backward. + if a.shape[a_axis] == 2 and b.shape[b_axis] == 2: + # Case 1: a.shape[-1] == 2 and b.shape[-1] == 2, param.axisc is ignored. + shape = grad_c_move.shape if grad_c_move.ndim != 0 else (1,) + grad_a_move = _np.empty(shape, dtype=a_move.dtype) + grad_b_move = _np.empty(shape, dtype=b_move.dtype) + grad_a_move = _np.expand_dims(grad_a_move, -1).repeat(2, axis=-1) + grad_b_move = _np.expand_dims(grad_b_move, -1).repeat(2, axis=-1) + a_move_0 = a_move[..., 0] + a_move_1 = a_move[..., 1] + b_move_0 = b_move[..., 0] + b_move_1 = b_move[..., 1] + grad_a_move_0 = grad_c_move * b_move_1 + grad_a_move_1 = grad_c_move * b_move_0 + if grad_a_move_1.ndim == 0: + grad_a_move_1 = -grad_a_move_1 + else: + _np.negative(grad_a_move_1, out=grad_a_move_1) + grad_b_move_0 = grad_c_move * a_move_1 + grad_b_move_1 = grad_c_move * a_move_0 + if grad_b_move_0.ndim == 0: + grad_b_move_0 = -grad_b_move_0 + else: + _np.negative(grad_b_move_0, out=grad_b_move_0) + grad_a_move[..., 0] = grad_a_move_0 + grad_a_move[..., 1] = grad_a_move_1 + grad_b_move[..., 0] = grad_b_move_0 + grad_b_move[..., 1] = grad_b_move_1 + else: + # Case 4: a.shape[-1] == 3 and b.shape[-1] == 3, param.axisc is not ignored. + grad_a_move = _np.cross(b_move, grad_c_move) + grad_b_move = _np.cross(grad_c_move, a_move) + if a.shape[a_axis] == 2: + # Case 2: a.shape[-1] == 2 and b.shape[-1] == 3, param.axisc is not ignored. + grad_a_move = _np.delete(grad_a_move, obj=-1, axis=-1) + if b.shape[b_axis] == 2: + # Case 3: a.shape[-1] == 3 and b.shape[-1] == 2, param.axisc is not ignored. + grad_b_move = _np.delete(grad_b_move, obj=-1, axis=-1) + + if not check_not_use_broadcast(a, b, axises): + a_broad_axis = get_reduce_axis(a_move.shape, c_move.shape) + b_broad_axis = get_reduce_axis(b_move.shape, c_move.shape) + if a_broad_axis is not None: + grad_a_move_reduce = _np.ones_like(a_move) + grad_a_move_reduce = _np.sum(grad_a_move, axis=a_broad_axis, out=grad_a_move_reduce, keepdims=True) + grad_a_move = grad_a_move_reduce + if b_broad_axis is not None: + grad_b_move_reduce = _np.ones_like(b_move) + grad_b_move_reduce = _np.sum(grad_b_move, axis=b_broad_axis, out=grad_b_move_reduce, keepdims=True) + grad_b_move = grad_b_move_reduce + # dL2 = dL/dAi * dAi + dL/dBi * dBi + dL2 = cal_dL(grad_a_move, da_move) + cal_dL(grad_b_move, db_move) + assert_almost_equal(dL1, dL2, rtol=rtol, atol=atol) + # move working axis + return _np.moveaxis(grad_a_move, -1, a_axis), _np.moveaxis(grad_b_move, -1, b_axis) + + shapes = [ + # - (a_shape, b_shape, (a_axis, b_axis, c_axis)) + # - 2 x 2 + ((2,), (2,), (-1, -1, -1)), + ((1, 2), (1, 2), (-1, -1, -1)), + ((1, 2), (2, 2), (-1, -1, -1)), + ((2, 2), (1, 2), (-1, -1, -1)), + ((2, 2), (2, 2), (-1, -1, -1)), + ((1, 2), (2, 2), (-1, 0, -1)), + ((2, 2), (1, 2), (0, -1, -1)), + ((2, 2), (2, 2), (0, 0, -1)), + ((2, 2), (2, 2), (0, 0, 0)), + ((5, 4, 3, 2), (5, 4, 3, 2), (-1, -1, -1)), + ((1, 4, 3, 2), (5, 1, 3, 2), (-1, -1, -1)), + ((5, 4, 3, 2), (5, 4, 3, 2), (-1, -1, 0)), + ((2, 5, 4, 3), (5, 2, 4, 3), (0, 1, 2)), + ((2, 5, 1, 3), (1, 2, 4, 3), (0, 1, 2)), + # - 2 x 3 + ((2,), (3,), (-1, -1, -1)), + ((1, 2,), (1, 3,), (-1, -1, -1)), + ((2, 2,), (2, 3,), (0, -1, 0)), + ((1, 2,), (2, 3,), (-1, -1, -1)), + ((2, 2,), (1, 3,), (-1, -1, -1)), + ((2, 1,), (3, 4,), (0, 0, 0)), + ((2, 1, 3), (4, 3, 1), (0, 1, 2)), + ((6, 5, 4, 2), (6, 5, 4, 3), (-1, -1, -1)), + ((2, 6, 5, 4), (6, 5, 4, 3), (0, -1, 2)), + ((2, 6, 5, 4), (6, 3, 5, 4), (0, 1, 2)), + ((6, 2, 5, 4), (6, 5, 3, 4), (1, 2, 0)), + ((6, 2, 1, 4), (1, 5, 3, 4), (1, 2, 0)), + # - 3 x 2 + ((3,), (2,), (-1, -1, -1)), + ((1, 3,), (1, 2,), (-1, -1, -1)), + ((2, 3,), (2, 2,), (-1, 0, 0)), + ((2, 3,), (1, 2,), (-1, -1, -1)), + ((2, 3,), (1, 2,), (-1, -1, -1)), + ((3, 4, 4), (1, 1, 2,), (0, -1, 0)), + ((3, 4, 4), (1, 2, 1,), (0, 1, 2)), + ((6, 5, 4, 3), (6, 5, 4, 2), (-1, -1, -1)), + ((3, 6, 5, 4), (6, 5, 4, 2), (0, -1, 2)), + ((3, 6, 5, 4), (6, 2, 5, 4), (0, 1, 2)), + ((6, 3, 5, 4), (6, 5, 2, 4), (1, 2, 0)), + ((6, 3, 1, 4), (1, 5, 2, 4), (1, 2, 0)), + # - 3 x 3 + ((3,), (3,), (-1, -1, -1)), + ((1, 3,), (1, 3,), (-1, -1, -1)), + ((2, 3,), (3, 2,), (-1, 0, 0)), + ((1, 3,), (3, 2,), (-1, 0, 0)), + ((1, 3,), (3, 4,), (-1, 0, 0)), + ((1, 1, 3,), (3, 2, 2), (-1, 0, 0)), + ((1, 1, 2, 3,), (3, 2, 2, 2), (-1, 0, 0)), + ((6, 5, 4, 3), (6, 5, 4, 3), (-1, -1, -1)), + ((3, 6, 5, 4), (6, 5, 4, 3), (0, -1, 2)), + ((3, 6, 5, 4), (6, 3, 5, 4), (0, 1, 2)), + ((6, 3, 5, 4), (6, 5, 3, 4), (1, 2, 0)), + ((6, 3, 1, 4), (1, 5, 3, 4), (1, 2, -1)), + + # - (a_shape, b_shape, None) + ((2,), (2,), None), + ((2,), (3,), None), + ((3,), (2,), None), + ((3,), (3,), None), + ((5, 4, 3, 2), (5, 4, 3, 2), None), + ((6, 5, 4, 2), (6, 5, 4, 3), None), + ((6, 5, 4, 3), (6, 5, 4, 2), None), + ((6, 5, 4, 3), (6, 5, 4, 3), None), + ((1, 4, 3, 2), (5, 1, 3, 2), None), + ((6, 1, 4, 2), (6, 5, 1, 3), None), + ((6, 5, 1, 3), (1, 5, 4, 2), None), + ((1, 5, 4, 3), (6, 5, 1, 3), None), + + # - (a_shape, b_shape, (a_axis, b_axis, c_axis, axis)) + ((2, 5, 4, 3), (2, 5, 4, 3), (-1, -1, -1, 0,)), + ((6, 2, 5, 4), (6, 3, 5, 4), (-1, -1, -1, 1,)), + ((6, 5, 3, 4), (6, 5, 2, 4), (-1, -1, -1, 2,)), + ((6, 5, 4, 3), (6, 5, 4, 3), (-1, -1, -1, 3,)), + ] + dtypes = [np.float32, np.float64] + for hybridize in [True, False]: + for shape, dtype in itertools.product(shapes, dtypes): + rtol = 1e-3 + atol = 1e-5 + a_shape, b_shape, axises = shape + if axises is None: + a_axis, b_axis, c_axis = (-1,) * 3 + test_numpy_cross = TestNumpyCross() + elif len(axises) == 4: + (a_axis, b_axis, c_axis, axis,) = axises + test_numpy_cross = TestNumpyCross(axisa=a_axis, axisb=b_axis, axisc=c_axis, axis=axis) + else: + (a_axis, b_axis, c_axis,) = axises + test_numpy_cross = TestNumpyCross(axisa=a_axis, axisb=b_axis, axisc=c_axis) + if hybridize: + test_numpy_cross.hybridize() + a_np = _np.random.uniform(-10., 10., size=a_shape) + b_np = _np.random.uniform(-10., 10., size=b_shape) + a = np.array(a_np, dtype=dtype) + b = np.array(b_np, dtype=dtype) + a.attach_grad() + b.attach_grad() + + # check cross validity + with mx.autograd.record(): + mx_out = test_numpy_cross(a, b) + check_np_cross(mx_out, a.asnumpy(), b.asnumpy(), axises) + + # check cross backward + mx.autograd.backward(mx_out) + grad_a_expected, grad_b_expected = get_cross_backward(a.asnumpy(), b.asnumpy(), axises) + assert_almost_equal(a.grad.asnumpy(), grad_a_expected, rtol=rtol, atol=atol) + assert_almost_equal(b.grad.asnumpy(), grad_b_expected, rtol=rtol, atol=atol) + + # check imperative once again + mx_out = test_numpy_cross(a, b) + check_np_cross(mx_out, a.asnumpy(), b.asnumpy(), axises) + + @with_seed() @use_np def test_np_rollaxis(): diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index 2a15e3407862..d3545ce86638 100755 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -712,9 +712,10 @@ def test_sparse_ftrl(): if (dtype == np.float16 and ('multi_precision' not in kwarg or not kwarg['multi_precision'])): continue + rtol, atol = (1e-3, 1e-3) if dtype is np.float16 else (1e-4, 1e-4) compare_optimizer(opt1(**kwarg), opt2(**kwarg), shapes, dtype, w_stype='row_sparse', g_stype='row_sparse', - rtol=1e-4, atol=1e-4) + rtol=rtol, atol=atol) @with_seed()