diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py index d634e7922fae..d2e2344e253f 100644 --- a/python/mxnet/gluon/loss.py +++ b/python/mxnet/gluon/loss.py @@ -29,7 +29,6 @@ from .. import ndarray from ..base import numeric_types from .block import HybridBlock -from .utils import _adapt_np_array from ..util import is_np_array @@ -188,7 +187,6 @@ class L1Loss(Loss): def __init__(self, weight=None, batch_axis=0, **kwargs): super(L1Loss, self).__init__(weight, batch_axis, **kwargs) - @_adapt_np_array def hybrid_forward(self, F, pred, label, sample_weight=None): label = _reshape_like(F, label, pred) loss = F.abs(label - pred) diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py index 87d6e89a4d99..85967424331c 100644 --- a/python/mxnet/gluon/nn/basic_layers.py +++ b/python/mxnet/gluon/nn/basic_layers.py @@ -25,7 +25,7 @@ from .activations import Activation from ..block import Block, HybridBlock -from ..utils import _indent, _adapt_np_array +from ..utils import _indent from ... import nd, sym from ...util import is_np_array @@ -521,7 +521,6 @@ def __init__(self, axis=1, epsilon=1e-5, center=True, scale=False, shape=(in_channels,), init=beta_initializer, allow_deferred_init=True) - @_adapt_np_array def hybrid_forward(self, F, x, gamma, beta): if self._axis == 1: return F.InstanceNorm(x, gamma, beta, @@ -706,7 +705,6 @@ def __init__(self, function, prefix=None): "Unrecognized function in lambda: {} of type {}" .format(function, type(function))) - @_adapt_np_array def hybrid_forward(self, F, x, *args): return self._func(F, x, *args) diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 2822c7019a28..b8e5b2688429 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -40,7 +40,7 @@ class requests_failed_to_import(object): import numpy as np from .. import ndarray -from ..util import is_np_shape, is_np_array, wraps_safely +from ..util import is_np_shape, is_np_array from .. import numpy as _mx_np # pylint: disable=reimported @@ -484,53 +484,3 @@ def _check_all_np_ndarrays(out): for i in out: _check_all_np_ndarrays(i) # pylint: enable=no-else-raise - - -def _to_classic_arrays(*args, **kwargs): - """Convert arrays to classic arrays. This is used in a Gluon layer for converting - inputs of np arrays to classic arrays so that the layer built with legacy ops can still - be used in np_array semantics.""" - from ..numpy import ndarray as np_ndarray - from ..symbol.numpy import _Symbol as np_symbol - num_inputs = len(args) - assert num_inputs != 0 - if not is_np_array(): - return args, kwargs - in_arrs = [arr if arr is None else arr.as_nd_ndarray() for arr in args] - new_kwargs = {} - for k, v in kwargs.items(): - if isinstance(v, (np_ndarray, np_symbol)): - new_kwargs[k] = v.as_nd_ndarray() - else: - new_kwargs[k] = v - return in_arrs, new_kwargs - - -def _to_np_arrays(*args): - """Convert arrays to np arrays. This is used in a Gluon layer for converting - outputs of classic arrays to np arrays so that the layer built with legacy ops can still - be used in np_array semantics.""" - num_outputs = len(args) - assert num_outputs != 0 - if not is_np_array(): - return args[0] if num_outputs == 1 else args - out = [arr.as_np_ndarray() for arr in args] - return out[0] if num_outputs == 1 else out - - -# TODO(junwu): This is a temp solution for allowing basic layers -# implemented using legacy ops to accept np.ndarrays as inputs and return -# np.ndarrays as outputs. We should remove it after changing all the layers -# to use np ops in np_array semantics in the future. -def _adapt_np_array(func): - @wraps_safely(func) - def _with_np_array(*args, **kwargs): - assert len(args) > 2, "expect at least three arguments in args" - if is_np_array(): - input_args, kwargs = _to_classic_arrays(*args[2:], **kwargs) - input_args = list(args[0:2]) + list(input_args) - out = func(*input_args, **kwargs) - return _to_np_arrays(out) - else: - return func(*args, **kwargs) - return _with_np_array diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py index 6e89c004f6a4..4c26f59b980b 100644 --- a/python/mxnet/numpy_extension/__init__.py +++ b/python/mxnet/numpy_extension/__init__.py @@ -25,10 +25,7 @@ from . import _register from ._op import * # pylint: disable=wildcard-import from ..context import * # pylint: disable=wildcard-import -# TODO(junwu): revisit what functions should be exposed to users -from ..util import use_np_shape, np_shape, is_np_shape -from ..util import use_np_array, np_array, is_np_array -from ..util import set_np, use_np, reset_np +from ..util import is_np_shape, is_np_array, set_np, reset_np from ..ndarray import waitall from .utils import * # pylint: disable=wildcard-import from .random import * # pylint: disable=wildcard-import diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 0dcb54b625cd..7ecfd58da7b8 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -49,6 +49,7 @@ from .ndarray import array from .symbol import Symbol from .symbol.numpy import _Symbol as np_symbol +from .util import use_np # pylint: disable=unused-import def default_context(): diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc index dc262feda8ac..fc1abe7896e0 100644 --- a/src/operator/numpy/np_init_op.cc +++ b/src/operator/numpy/np_init_op.cc @@ -115,7 +115,7 @@ NNVM_REGISTER_OP(_npi_arange) .set_attr_parser(RangeParamParser) .set_attr("FInferShape", NumpyRangeShape) .set_attr("FInferType", InitType) -.set_attr("FCompute", RangeCompute) +.set_attr("FCompute", RangeCompute) .add_arguments(RangeParam::__FIELDS__()); NNVM_REGISTER_OP(_npi_eye) diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu index 68d168182e5e..7f0d587a55de 100644 --- a/src/operator/numpy/np_init_op.cu +++ b/src/operator/numpy/np_init_op.cu @@ -41,7 +41,7 @@ NNVM_REGISTER_OP(_np_ones_like) .set_attr("FCompute", FillCompute); NNVM_REGISTER_OP(_npi_arange) -.set_attr("FCompute", RangeCompute); +.set_attr("FCompute", RangeCompute); NNVM_REGISTER_OP(_npi_eye) .set_attr("FCompute", NumpyEyeFill); diff --git a/tests/python/unittest/test_contrib_amp.py b/tests/python/unittest/test_contrib_amp.py deleted file mode 100644 index ef3a6d81fb48..000000000000 --- a/tests/python/unittest/test_contrib_amp.py +++ /dev/null @@ -1,86 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import mxnet as mx -import warnings -import collections -import ctypes -import mxnet.contrib.amp as amp - - -def test_amp_coverage(): - conditional = [item[0] for item in amp.lists.symbol.CONDITIONAL_FP32_FUNCS] - - # Check for duplicates - for a in [amp.lists.symbol.FP16_FUNCS, - amp.lists.symbol.FP16_FP32_FUNCS, - amp.lists.symbol.FP32_FUNCS, - amp.lists.symbol.WIDEST_TYPE_CASTS, - conditional]: - ret = [item for item, count in collections.Counter(a).items() if count > 1] - assert ret == [], "Elements " + str(ret) + " are duplicated in the AMP lists." - - t = [] - for a in [amp.lists.symbol.FP16_FUNCS, - amp.lists.symbol.FP16_FP32_FUNCS, - amp.lists.symbol.FP32_FUNCS, - amp.lists.symbol.WIDEST_TYPE_CASTS, - conditional]: - t += a - ret = [item for item, count in collections.Counter(t).items() if count > 1] - assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list." - - # Check the coverage - py_str = lambda x: x.decode('utf-8') - - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - - mx.base._LIB.MXListAllOpNames(ctypes.byref(size), - ctypes.byref(plist)) - op_names = [] - for i in range(size.value): - s = py_str(plist[i]) - if not s.startswith("_backward") \ - and not s.startswith("_contrib_backward_"): - op_names.append(s) - - ret1 = set(op_names) - set(t) - - if ret1 != set(): - warnings.warn("Operators " + str(ret1) + " do not exist in AMP lists (in " - "python/mxnet/contrib/amp/lists/symbol.py) - please add them. " - """Please follow these guidelines for choosing a proper list: - - if your operator is not to be used in a computational graph - (e.g. image manipulation operators, optimizers) or does not have - inputs, put it in FP16_FP32_FUNCS list, - - if your operator requires FP32 inputs or is not safe to use with lower - precision, put it in FP32_FUNCS list, - - if your operator supports both FP32 and lower precision, has - multiple inputs and expects all inputs to be of the same - type, put it in WIDEST_TYPE_CASTS list, - - if your operator supports both FP32 and lower precision and has - either a single input or supports inputs of different type, - put it in FP16_FP32_FUNCS list, - - if your operator is both safe to use in lower precision and - it is highly beneficial to use it in lower precision, then - put it in FP16_FUNCS (this is unlikely for new operators) - - If you are not sure which list to choose, FP32_FUNCS is the - safest option""") - -if __name__ == '__main__': - test_amp_coverage() diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py index b4db7bfc4ab0..1821f8d68427 100644 --- a/tests/python/unittest/test_numpy_gluon.py +++ b/tests/python/unittest/test_numpy_gluon.py @@ -20,7 +20,8 @@ from __future__ import division import mxnet as mx -from mxnet import gluon, autograd, np, npx +from mxnet import gluon, autograd, np +from mxnet.test_utils import use_np def test_create_np_param(): @@ -45,7 +46,7 @@ def __init__(self): def hybrid_forward(self, F, x, w): return F.dot(x, w) - @npx.use_np + @use_np class TestBlock2(gluon.HybridBlock): def __init__(self): super(TestBlock2, self).__init__() @@ -62,7 +63,7 @@ def hybrid_forward(self, F, x, w): check_block_params(x.as_np_ndarray(), TestBlock2, True, np.ndarray) -@npx.use_np +@use_np def test_optimizer_with_np_ndarrays(): class LinearRegression(gluon.HybridBlock): def __init__(self, num_input_dim=0, num_hidden_dim=100, num_output_dim=10): diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 887bb9a7916f..080a662980cf 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -23,12 +23,12 @@ import mxnet as mx from mxnet import np, npx, autograd from mxnet.gluon import HybridBlock -from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, assert_exception +from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, assert_exception, use_np from common import with_seed, TemporaryDirectory @with_seed() -@npx.use_np_shape +@use_np def test_array_creation(): dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None] objects = [ @@ -53,7 +53,7 @@ def test_array_creation(): @with_seed() -@npx.use_np_shape +@use_np def test_zeros(): # test np.zeros in Gluon class TestZeros(HybridBlock): @@ -101,7 +101,7 @@ def check_zero_array_creation(shape, dtype): @with_seed() -@npx.use_np_shape +@use_np def test_ones(): # test np.ones in Gluon class TestOnes(HybridBlock): @@ -167,7 +167,7 @@ def test_ndarray_binary_element_wise_ops(): def get_np_ret(x1, x2, op): return np_op_map[op](x1, x2) - @npx.use_np_shape + @use_np class TestBinaryElementWiseOp(HybridBlock): def __init__(self, op, scalar=None, reverse=False): super(TestBinaryElementWiseOp, self).__init__() @@ -235,7 +235,7 @@ def hybrid_forward(self, F, x, *args): print(self._op) assert False - @npx.use_np_shape + @use_np def check_binary_op_result(shape1, shape2, op, dtype=None): if shape1 is None: mx_input1 = abs(_np.random.uniform()) + 1 @@ -305,7 +305,7 @@ def check_binary_op_result(shape1, shape2, op, dtype=None): @with_seed() def test_hybrid_block_multiple_outputs(): - @npx.use_np_shape + @use_np class TestAllNumpyOutputs(HybridBlock): def hybrid_forward(self, F, x, *args, **kwargs): return F.npx.relu(x), F.np.sum(x) @@ -325,7 +325,7 @@ def hybrid_forward(self, F, x, *args, **kwargs): assert type(out1) is expected_out_type assert type(out2) is expected_out_type - @npx.use_np_array + @use_np class TestMixedTypeOutputsFailure(HybridBlock): def hybrid_forward(self, F, x, *args, **kwargs): return F.relu(x.as_nd_ndarray()), F.np.sum(x) @@ -337,7 +337,7 @@ def hybrid_forward(self, F, x, *args, **kwargs): @with_seed() -@npx.use_np_shape +@use_np def test_grad_ndarray_type(): data = np.array(2, dtype=_np.float32) data.attach_grad() @@ -375,7 +375,7 @@ def test_np_ndarray_copy(): @with_seed() -@npx.use_np_shape +@use_np def test_np_ndarray_indexing(): def test_getitem(np_array, index): """`is_scalar` indicates whether we should expect a scalar for the result. @@ -627,7 +627,7 @@ def convert(num): @with_seed() -@npx.use_np +@use_np def test_np_save_load_ndarrays(): shapes = [(2, 0, 1), (0,), (), (), (0, 4), (), (3, 0, 0, 0), (2, 1), (0, 5, 0), (4, 5, 6), (0, 0, 0)] array_list = [_np.random.randint(0, 10, size=shape) for shape in shapes] @@ -671,7 +671,7 @@ def test_np_save_load_ndarrays(): @retry(5) @with_seed() -@npx.use_np_shape +@use_np def test_np_multinomial(): pvals_list = [[0.0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0.0]] sizes = [None, (), (3,), (2, 5, 7), (4, 9)] diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index cd323e202ed4..8a89b91fff73 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -24,20 +24,20 @@ from mxnet.gluon import HybridBlock from mxnet.base import MXNetError from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray -from mxnet.test_utils import check_numeric_gradient +from mxnet.test_utils import check_numeric_gradient, use_np from common import assertRaises, with_seed import random import collections @with_seed() -@npx.use_np_shape +@use_np def test_np_tensordot(): class TestTensordot(HybridBlock): def __init__(self, axes): super(TestTensordot, self).__init__() self._axes = axes - + def hybrid_forward(self, F, a, b): return F.np.tensordot(a, b, self._axes) @@ -180,7 +180,7 @@ def tensordot_backward(a, b, axes=2): @with_seed() -@npx.use_np_shape +@use_np def test_np_sum(): class TestSum(HybridBlock): def __init__(self, axis=None, dtype=None, keepdims=False): @@ -242,7 +242,7 @@ def is_int(dtype): @with_seed() -@npx.use_np_shape +@use_np def test_np_dot(): shapes = [ ((3, 0), (0, 4)), @@ -290,9 +290,8 @@ def test_np_dot(): @with_seed() -@npx.use_np_shape +@use_np def test_np_mean(): - @npx.use_np_shape class TestMean(HybridBlock): def __init__(self, axis=None, dtype=None, keepdims=False): super(TestMean, self).__init__() @@ -355,9 +354,8 @@ def is_int(dtype): @with_seed() -@npx.use_np_shape +@use_np def test_np_max(): - @npx.use_np_shape class TestMax(HybridBlock): def __init__(self, axis=None, keepdims=False): super(TestMax, self).__init__() @@ -444,7 +442,7 @@ def _test_np_max_exception(shape, dim): @with_seed() -@npx.use_np_shape +@use_np def test_np_transpose(): # TODO(junwu): Add more test cases data = mx.sym.var('a').as_np_ndarray() @@ -474,7 +472,7 @@ def test_np_transpose(): @with_seed() -@npx.use_np_shape +@use_np def test_npx_relu(): # TODO(junwu): Add more test cases data = mx.sym.var('data').as_np_ndarray() @@ -490,7 +488,7 @@ def test_npx_relu(): @with_seed() -@npx.use_np_shape +@use_np def test_npx_sigmoid(): # TODO(junwu): Add more test cases data = mx.sym.var('data').as_np_ndarray() @@ -506,7 +504,7 @@ def test_npx_sigmoid(): @with_seed() -@npx.use_np_shape +@use_np def test_np_reshape(): # TODO(junwu): Add more test cases data = mx.sym.var('a').as_np_ndarray() @@ -524,7 +522,7 @@ def test_np_reshape(): @with_seed() -@npx.use_np_shape +@use_np def test_np_maximum(): # TODO(junwu): Add more test cases x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray() @@ -545,7 +543,7 @@ def check_maximum(x1, x2): @with_seed() -@npx.use_np_shape +@use_np def test_np_minimum(): # TODO(junwu): Add more test cases x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray() @@ -566,10 +564,9 @@ def check_minimum(x1, x2): @with_seed() -@npx.use_np_shape +@use_np def test_np_unary_funcs(): def check_unary_func(func, ref_grad, shape, low, high): - @npx.use_np_shape class TestUnary(HybridBlock): def __init__(self, func): super(TestUnary, self).__init__() @@ -641,9 +638,8 @@ def hybrid_forward(self, F, a, *args, **kwargs): @with_seed() -@npx.use_np_shape +@use_np def test_np_stack(): - @npx.use_np_shape class TestStack(HybridBlock): def __init__(self, axis=None): super(TestStack, self).__init__() @@ -694,7 +690,7 @@ def hybrid_forward(self, F, a, *args): @with_seed() -@npx.use_np_shape +@use_np def test_np_random(): shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None] dtypes = ['float16', 'float32', 'float64'] @@ -710,7 +706,6 @@ def test_np_random(): expected_shape = () if shape is None else (shape,) assert out.shape == expected_shape - @npx.use_np class TestRandom(HybridBlock): def __init__(self, shape, op_name): super(TestRandom, self).__init__() @@ -737,7 +732,7 @@ def hybrid_forward(self, F, x): @with_seed() -@npx.use_np_shape +@use_np def test_np_arange(): configs = [ (1, 10, 2), @@ -772,7 +767,6 @@ def test_np_arange(): np_ret = _np.arange(config, dtype=dtype) assert same(mx_ret.asnumpy(), np_ret) - @npx.use_np class TestRange(HybridBlock): def __init__(self, start, stop=None, step=None, dtype=None): super(TestRange, self).__init__() @@ -801,7 +795,7 @@ def hybrid_forward(self, F, x): @with_seed() -@npx.use_np_shape +@use_np def test_np_linspace(): configs = [ (0.0, 1.0, 10), @@ -835,7 +829,7 @@ def test_np_linspace(): # check linspace equivalent to arange for test_index in range(1000): assert_almost_equal(mx.np.linspace(0, test_index, test_index + 1).asnumpy(), mx.np.arange(test_index + 1).asnumpy()) - @npx.use_np + @use_np class TestLinspace(HybridBlock): def __init__(self, start, stop, num=50, endpoint=None, retstep=False, dtype=None, axis=0): super(TestLinspace, self).__init__() @@ -871,7 +865,7 @@ def hybrid_forward(self, F, x): @with_seed() -@npx.use_np_shape +@use_np def test_np_eye(): configs = [ 4, @@ -910,7 +904,7 @@ def test_np_eye(): assertRaises(MXNetError, np.eye, *config) else: assertRaises(MXNetError, np.eye, config) - @npx.use_np + @use_np class TestEye(HybridBlock): def __init__(self, N, M=None, k=0, dtype=None): super(TestEye, self).__init__() @@ -939,7 +933,7 @@ def hybrid_forward(self, F, x): @with_seed() -@npx.use_np_shape +@use_np def test_np_argmax(): workloads = [ ((), 0, False), @@ -956,7 +950,7 @@ def test_np_argmax(): ] dtypes = ['float16', 'float32', 'float64'] - @npx.use_np + @use_np class TestArgMax(HybridBlock): def __init__(self, axis=None): super(TestArgMax, self).__init__() @@ -1001,9 +995,9 @@ def hybrid_forward(self, F, x): @with_seed() -@npx.use_np_shape +@use_np def test_np_argsort(): - @npx.use_np_shape + @use_np class TestArgsort(HybridBlock): def __init__(self, axis=-1): super(TestArgsort, self).__init__() @@ -1042,9 +1036,9 @@ def hybrid_forward(self, F, a): @with_seed() -@npx.use_np_shape +@use_np def test_np_linalg_norm(): - @npx.use_np + @use_np class TestLinalgNorm(HybridBlock): def __init__(self, ord=None, axis=None, keepdims=False): super(TestLinalgNorm, self).__init__() @@ -1073,7 +1067,7 @@ def hybrid_forward(self, F, x): @with_seed() -@npx.use_np_shape +@use_np def test_np_concat(): class TestConcat(HybridBlock): def __init__(self, axis=None): @@ -1124,12 +1118,12 @@ def get_new_shape(shape, axis): @with_seed() -@npx.use_np_shape +@use_np def test_np_hstack(): class TestHStack(HybridBlock): def __init__(self): super(TestHStack, self).__init__() - + def hybrid_forward(self, F, a, *args): return F.np.hstack([a] + list(args)) @@ -1189,7 +1183,7 @@ def get_new_shape(shape): @with_seed() -@npx.use_np_shape +@use_np def test_np_swapaxes(): config = [((0, 1, 2), 0, 1), ((0, 1, 2), -1, -2), @@ -1221,7 +1215,7 @@ def hybrid_forward(self, F, x): @with_seed() -@npx.use_np_shape +@use_np def test_np_squeeze(): config = [((), None), ((), -1), @@ -1255,7 +1249,7 @@ def hybrid_forward(self, F, x): @with_seed() -@npx.use_np_shape +@use_np def test_np_split(): class TestSplit(HybridBlock): def __init__(self, indices_or_sections, axis=None): @@ -1308,12 +1302,12 @@ def get_indices(axis_size): @with_seed() -@npx.use_np_shape +@use_np def test_np_cumsum(): def np_cumsum_backward(ograd, axis=None, dtype=None): return _np.flip(_np.cumsum(_np.flip(ograd, axis=axis), axis=axis, dtype=dtype), axis=axis) - @npx.use_np_shape + @use_np class TestCumsum(HybridBlock): def __init__(self, axis=None, dtype=None): super(TestCumsum, self).__init__() @@ -1350,7 +1344,7 @@ def hybrid_forward(self, F, a): @with_seed() -@npx.use_np_shape +@use_np def test_np_tile(): config = [ ((), ()), @@ -1391,7 +1385,7 @@ def hybrid_forward(self, F, x): @with_seed() -@npx.use_np_shape +@use_np def test_np_prod(): class TestProd(HybridBlock): def __init__(self, axis=None, dtype=None, keepdims=False): @@ -1443,7 +1437,7 @@ def hybrid_forward(self, F, a, *args, **kwargs): @with_seed() -@npx.use_np +@use_np def test_np_flatten(): # TODO(junwu): Add more test cases shapes = [(), (2, 0, 1), (3, 4, 5), 6] @@ -1456,7 +1450,7 @@ def test_np_flatten(): @with_seed() -@npx.use_np +@use_np def test_np_broadcast_to(): # TODO(junwu): Add more test cases and backward test shapes = [(1, 2, 3, 4, 5), (1, 0, 3, 4, 5)] @@ -1469,7 +1463,7 @@ def test_np_broadcast_to(): @with_seed() -@npx.use_np +@use_np def test_np_meshgrid(): nx, ny = (4, 5) x = np.linspace(0, 1, nx) @@ -1484,14 +1478,14 @@ def test_np_meshgrid(): @with_seed() -@npx.use_np +@use_np def test_np_broadcast_arrays(): # TODO(junwu): Add test pass @with_seed() -@npx.use_np +@use_np def test_np_trace(): class TestTrace(HybridBlock): def __init__(self, axis1, axis2, offset): @@ -1499,10 +1493,10 @@ def __init__(self, axis1, axis2, offset): self._axis1 = axis1 self._axis2 = axis2 self._offset = offset - + def hybrid_forward(self, F, data): return F.np.trace(data, axis1=self._axis1, axis2=self._axis2, offset=self._offset) - + def g(data, axis1, axis2, offset): idx = _np.indices(data.shape) ret = _np.zeros_like(data)