diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index e5f2c6c02089..6ee37982a124 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -31,6 +31,7 @@ import zipfile import json from contextlib import contextmanager +from collections import OrderedDict import numpy as np import numpy.testing as npt import numpy.random as rnd @@ -109,6 +110,16 @@ def random_sample(population, k): return population_copy[0:k] +def _sorted_items(d): + """Return (key, value) pairs of dict 'd' in a deterministic order (sorted by key).""" + return sorted(d.items(), key=lambda t: t[0]) + + +def _sorted_dict(d): + """Return ordered dictionary containing items ordered by their keys.""" + return OrderedDict(_sorted_items(d)) + + def _validate_csr_generation_inputs(num_rows, num_cols, density, distribution="uniform"): """Validates inputs for csr generation helper functions @@ -482,9 +493,10 @@ def find_max_violation(a, b, rtol=None, atol=None): """Finds and returns the location of maximum violation.""" rtol = get_rtol(rtol) atol = get_atol(atol) - diff = np.abs(a-b) + # 'smart' absdiff that considers inf's as equals (to match np.allclose) + absdiff = np.where(np.equal(a, b), 0, np.abs(a-b)) tol = atol + rtol*np.abs(b) - violation = diff/(tol+1e-20) + violation = absdiff/(tol+1e-20) loc = np.argmax(violation) idx = np.unravel_index(loc, violation.shape) return idx, np.max(violation) @@ -500,40 +512,122 @@ def same(a, b): """ return np.array_equal(a, b) -def almost_equal(a, b, rtol=None, atol=None, equal_nan=False, use_broadcast=True): - """Test if two numpy arrays are almost equal.""" - # pylint: disable=unexpected-keyword-arg - if (not use_broadcast) and a.shape != b.shape: + +def checkShapes(a, b): + if a.shape != b.shape: msg = npt.build_err_msg([a, b], err_msg="a.shape = {} and b.shape = {} are not equal" .format(str(a.shape), str(b.shape))) raise AssertionError(msg) + + +def almost_equal(a, b, rtol=None, atol=None, equal_nan=False, use_broadcast=True): + """Test if two numpy arrays are almost equal.""" + # pylint: disable=unexpected-keyword-arg + if not use_broadcast: + checkShapes(a, b) + return np.allclose(a, b, rtol=get_rtol(rtol), atol=get_atol(atol), equal_nan=equal_nan) # pylint: enable=unexpected-keyword-arg -def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan=False, use_broadcast=True): + +def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan=False, + use_broadcast=True, mismatches=(10, 10)): """Test that two numpy arrays are almost equal. Raise exception message if not. Parameters ---------- - a : np.ndarray - b : np.ndarray - threshold : None or float - The checking threshold. Default threshold will be used if set to ``None``. + a : np.ndarray or mx.nd.array + b : np.ndarray or mx.nd.array + rtol : None or float + The relative threshold. Default threshold will be used if set to ``None``. + atol : None or float + The absolute threshold. Default threshold will be used if set to ``None``. 
+ names : tuple of names, optional + The names used in error message when an exception occurs + equal_nan : boolean, optional + The flag determining how to treat NAN values in comparison + mismatches : tuple of mismatches + Maximum number of mismatches to be printed (mismatches[0]) and determine (mismatches[1]) """ + if not use_broadcast: + checkShapes(a, b) + rtol = get_rtol(rtol) atol = get_atol(atol) - if almost_equal(a, b, rtol, atol, equal_nan=equal_nan, use_broadcast=use_broadcast): - return + use_np_allclose = isinstance(a, np.ndarray) and isinstance(b, np.ndarray) + if not use_np_allclose: + if not (hasattr(a, 'context') and hasattr(b, 'context') and a.context == b.context and a.dtype == b.dtype): + use_np_allclose = True + if isinstance(a, mx.nd.NDArray): + a = a.asnumpy() + if isinstance(b, mx.nd.NDArray): + b = b.asnumpy() + + if use_np_allclose: + if almost_equal(a, b, rtol, atol, equal_nan=equal_nan): + return + else: + output = mx.nd.contrib.allclose(a, b, rtol, atol, equal_nan) + if output.asnumpy() == 1: + return + + a = a.asnumpy() + b = b.asnumpy() + + def locationError(a, b, index, names, maxError=False): + """Create element mismatch comment + + Parameters + ---------- + a, b : compared np.ndarray's + index : tuple of coordinate arrays + Location of violation + names : tuple of names + The names of compared arrays. + maxError: boolean, optional + Flag indicating that maximum error is reporting. + """ + maximum = "maximum " if maxError else "" + return "Location of %serror: %s, %s=%.8f, %s=%.8f" \ + % (maximum, str(index), names[0], a[index], names[1], b[index]) + index, rel = find_max_violation(a, b, rtol, atol) + indexErr = index + relErr = rel + + print('\n*** Maximum errors for vector of size {}: rtol={}, atol={}\n'.format(a.size, rtol, atol)) + aTmp = a.copy() + bTmp = b.copy() + i = 1 + while i <= a.size: + if i <= mismatches[0]: + print("%3d: Error %f %s" %(i, rel, locationError(a, b, index, names))) + + aTmp[index] = bTmp[index] = 0 + if almost_equal(aTmp, bTmp, rtol, atol, equal_nan=equal_nan): + break + + i += 1 + if i <= mismatches[1] or mismatches[1] <= 0: + index, rel = find_max_violation(aTmp, bTmp, rtol, atol) + else: + break + + mismatchDegree = "at least " if mismatches[1] > 0 and i > mismatches[1] else "" + errMsg = "Error %f exceeds tolerance rtol=%e, atol=%e (mismatch %s%f%%).\n%s" % \ + (relErr, rtol, atol, mismatchDegree, 100*i/a.size, \ + locationError(a, b, indexErr, names, maxError=True)) np.set_printoptions(threshold=4, suppress=True) - msg = npt.build_err_msg([a, b], - err_msg="Error %f exceeds tolerance rtol=%f, atol=%f. " - " Location of maximum error:%s, a=%f, b=%f" - % (rel, rtol, atol, str(index), a[index], b[index]), - names=names) + msg = npt.build_err_msg([a, b], err_msg=errMsg) + raise AssertionError(msg) + +def assert_allclose(a, b, rtol=1e-07, atol=0, equal_nan=True): + assert_almost_equal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + def assert_almost_equal_with_err(a, b, rtol=None, atol=None, etol=None, names=('a', 'b'), equal_nan=False): """Test that two numpy arrays are almost equal within given error rate. Raise exception message if not. 
@@ -554,7 +648,6 @@ def assert_almost_equal_with_err(a, b, rtol=None, atol=None, etol=None, names=(' equals = np.isclose(a, b, rtol=rtol, atol=atol) err = 1 - np.count_nonzero(equals) / equals.size if err > etol: - #if True: index, rel = find_max_violation(a, b, rtol, atol) np.set_printoptions(threshold=4, suppress=True) msg = npt.build_err_msg([a, b], @@ -684,7 +777,7 @@ def simple_forward(sym, ctx=None, is_train=False, **inputs): def _parse_location(sym, location, ctx, dtype=default_dtype()): - """Parses the given location to a dictionary. + """Parses the given location to a ordered dictionary. Arguments of the provided op `sym` are used as dictionary keys and elements of `location` are used as values. @@ -740,7 +833,7 @@ def _parse_location(sym, location, ctx, dtype=default_dtype()): location = {k: v for k, v in zip(sym.list_arguments(), location)} location = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) \ if isinstance(v, np.ndarray) else v for k, v in location.items()} - return location + return _sorted_dict(location) def _parse_aux_states(sym, aux_states, ctx, dtype=default_dtype()): @@ -1177,7 +1270,8 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= if isinstance(expected, (list, tuple)): expected = {k:v for k, v in zip(sym.list_arguments(), expected)} - args_grad_npy = {k:np.random.normal(size=v.shape) for k, v in expected.items()} + # Dirty the output buffer deterministically, for reproducibility. + args_grad_npy = {k:np.random.normal(size=v.shape) for k, v in _sorted_items(expected)} args_grad_data = {} for k, v in args_grad_npy.items(): nd = mx.nd.array(v, ctx=ctx, dtype=expected[k].dtype if dtype == "asnumpy" else dtype) @@ -1313,6 +1407,15 @@ def check_speed(sym, location=None, ctx=None, N=20, grad_req=None, typ="whole", else: raise ValueError('typ can only be "whole" or "forward".') + +def get_tolerance(rtol, ctx): + if 'atol' in ctx: + return ctx['atol'] + if 'atol_mult' in ctx: + return ctx['atol_mult'] * rtol + return rtol + + def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', arg_params=None, aux_params=None, tol=None, raise_on_err=True, ground_truth=None, equal_nan=False, @@ -1431,12 +1534,15 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', for i, exe in enumerate(exe_list): if i == max_idx: continue + + rtol = tol[dtypes[i]] + atol = get_tolerance(rtol, ctx_list[i]) for name, arr in zip(output_names, exe.outputs): - gtarr = gt[name].astype(dtypes[i]).asnumpy() - arr = arr.asnumpy() + # Previously, the cast was to dtypes[i], but symbol may be mixed-precision, + # so casting the ground truth to the actual output type seems more correct. 
+ gtarr = gt[name].astype(arr.dtype) try: - assert_almost_equal(arr, gtarr, rtol=tol[dtypes[i]], atol=tol[dtypes[i]], - equal_nan=equal_nan) + assert_almost_equal(arr, gtarr, rtol=rtol, atol=atol, equal_nan=equal_nan) except AssertionError as e: print('Predict Err: ctx %d vs ctx %d at %s'%(i, max_idx, name)) traceback.print_exc() @@ -1454,16 +1560,20 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write', for i, exe in enumerate(exe_list): if i == max_idx: continue + + rtol = tol[dtypes[i]] + atol = get_tolerance(rtol, ctx_list[i]) curr = zip(output_names + arg_names, exe.outputs + exe.grad_arrays) for name, arr in curr: if gt[name] is None: assert arr is None continue - gtarr = gt[name].astype(dtypes[i]).asnumpy() - arr = arr.asnumpy() + + # Previous cast was to dtypes[i], but symbol may be mixed-precision, + # so casting the ground truth to the actual output type seems more correct. + gtarr = gt[name].astype(arr.dtype) try: - assert_almost_equal(arr, gtarr, rtol=tol[dtypes[i]], atol=tol[dtypes[i]], - equal_nan=equal_nan) + assert_almost_equal(arr, gtarr, rtol=rtol, atol=atol, equal_nan=equal_nan) except AssertionError as e: print('Train Err: ctx %d vs ctx %d at %s'%(i, max_idx, name)) traceback.print_exc() @@ -1694,7 +1804,7 @@ def get_mnist_iterator(batch_size, input_shape, num_parts=1, part_index=0): """ get_mnist_ubyte() - flat = not bool(len(input_shape) == 3) + flat = len(input_shape) != 3 train_dataiter = mx.io.MNISTIter( image="data/train-images-idx3-ubyte", @@ -2134,12 +2244,14 @@ def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, suc def compare_ndarray_tuple(t1, t2, rtol=None, atol=None): """Compare ndarray tuple.""" - if t1 is not None and t2 is not None: - if isinstance(t1, tuple): - for s1, s2 in zip(t1, t2): - compare_ndarray_tuple(s1, s2, rtol, atol) - else: - assert_almost_equal(t1.asnumpy(), t2.asnumpy(), rtol=rtol, atol=atol) + if t1 is None or t2 is None: + return + + if isinstance(t1, tuple): + for s1, s2 in zip(t1, t2): + compare_ndarray_tuple(s1, s2, rtol, atol) + else: + assert_almost_equal(t1, t2, rtol=rtol, atol=atol) def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default', @@ -2171,7 +2283,7 @@ def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='defa opt2.update_multi_precision(0, w2, g2, state2) if compare_states: compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol) - assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol) + assert_almost_equal(w1, w2, rtol=rtol, atol=atol) def same_symbol_structure(sym1, sym2): diff --git a/src/operator/contrib/allclose_op-inl.h b/src/operator/contrib/allclose_op-inl.h new file mode 100644 index 000000000000..a858450f0007 --- /dev/null +++ b/src/operator/contrib/allclose_op-inl.h @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file allclose-inl.h + * \brief Operator implementing numpy.allclose function. + * \author Andrei Ivanov + */ +#ifndef MXNET_OPERATOR_CONTRIB_ALLCLOSE_OP_INL_H_ +#define MXNET_OPERATOR_CONTRIB_ALLCLOSE_OP_INL_H_ + +#include +#include +#include "../mshadow_op.h" +#include "../mxnet_op.h" +#include "../operator_common.h" +#include "../elemwise_op_common.h" +#include "../tensor/init_op.h" + +namespace mxnet { +namespace op { + +// Intermediate and Output data types could be integers OR unsigned characters +#define USE_INTEGER 0 +#if USE_INTEGER + #define INTERM_DATA_TYPE int32_t + #define OUT_DATA_TYPE mshadow::kInt32 +#else + #define INTERM_DATA_TYPE uint8_t + #define OUT_DATA_TYPE mshadow::kUint8 +#endif + +struct AllCloseParam : public dmlc::Parameter { + float rtol, atol; + bool equal_nan; + DMLC_DECLARE_PARAMETER(AllCloseParam) { + DMLC_DECLARE_FIELD(rtol) + .set_default(1e-05) + .describe("Relative tolerance."); + DMLC_DECLARE_FIELD(atol) + .set_default(1e-08) + .describe("Absolute tolerance."); + DMLC_DECLARE_FIELD(equal_nan) + .set_default(true) + .describe("Whether to compare NaN’s as equal. If True, NaN’s in A will be considered equal " + "to NaN’s in B in the output array."); + } +}; + +inline bool AllCloseShape(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U) << "Input:[array1, array2]"; + CHECK_EQ(out_attrs->size(), 1U); + + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(0, -1)); + return in_attrs->at(0) == in_attrs->at(1); +} + +inline bool AllCloseType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + + // The output will be boolean stored as an OUT_DATA_TYPE format + TYPE_ASSIGN_CHECK(*out_attrs, 0, OUT_DATA_TYPE); + return (*out_attrs)[0] != -1; +} + +using namespace mshadow_op::isnan_typed; + +template +struct allclose_forward { + template + MSHADOW_XINLINE static void Map(int i, INTERM_DATA_TYPE *out_data, + const DType *in_a, const DType *in_b, + const float rtol, const float atol, bool equal_nan) { + const DType a = in_a[i], b = in_b[i]; + bool val; + if (IsNan(a) || IsNan(b)) + val = equal_nan && IsNan(a) == IsNan(b); + else + val = a == b || (a > b? a - b : b - a) <= atol + (b > 0? rtol * b : (-rtol) * b); + + KERNEL_ASSIGN(out_data[i], req, val? 1 : 0); + } +}; + +template +size_t GetAdditionalMemoryLogical(mshadow::Stream *s, const int num_items); + +template +INTERM_DATA_TYPE *GetAdditionalMemoryLogical(const OpContext& ctx, + int num_items, size_t *pExtraStorageBytes) { +// Get length of the additional memory (which is used only by DeviceReduce::Min(...) 
on gpu) + *pExtraStorageBytes = GetAdditionalMemoryLogical(ctx.get_stream(), num_items); + const size_t workspace_total_bytes_ = num_items * sizeof(INTERM_DATA_TYPE) + *pExtraStorageBytes; + mshadow::Tensor workspace = + ctx.requested[0].get_space_typed( + mshadow::Shape1(workspace_total_bytes_), ctx.get_stream()); + + return reinterpret_cast(workspace.dptr_); +} + +template +void GetResultLogical(mshadow::Stream *s, INTERM_DATA_TYPE *workMem, size_t extraStorageBytes, + int num_items, INTERM_DATA_TYPE *outPntr); + +template +void AllClose(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + + const TBlob& in0 = inputs[0]; + const TBlob& in1 = inputs[1]; + const int num_items = in0.Size(); + + size_t extraStorageBytes; + auto workspaceMem = GetAdditionalMemoryLogical(ctx, num_items, &extraStorageBytes); + auto s = ctx.get_stream(); + const AllCloseParam& param = nnvm::get(attrs.parsed); + using namespace mxnet_op; + MSHADOW_TYPE_SWITCH(in0.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch( + s, num_items, workspaceMem, in0.dptr(), in1.dptr(), + param.rtol, param.atol, param.equal_nan); + }); + }); + + auto *pOut = outputs[0].dptr(); + GetResultLogical(s, workspaceMem, extraStorageBytes, num_items, pOut); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CONTRIB_ALLCLOSE_OP_INL_H_ diff --git a/src/operator/contrib/allclose_op.cc b/src/operator/contrib/allclose_op.cc new file mode 100644 index 000000000000..6b301ad6519f --- /dev/null +++ b/src/operator/contrib/allclose_op.cc @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file allclose_op.cc + * \brief CPU Implementation of allclose op + * \author Andrei Ivanov + */ +#include "./allclose_op-inl.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(AllCloseParam); + +NNVM_REGISTER_OP(_contrib_allclose) +.describe(R"code(This operators implements the numpy.allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False) + +.. 
math:: + + f(x) = |a−b|≤atol+rtol|b| + +where +:math:`a, b` are the input tensors of equal types an shapes +:math:`atol, rtol` the values of absolute and relative tolerance (by default, rtol=1e-05, atol=1e-08) + +Examples:: + + a = [1e10, 1e-7], + b = [1.00001e10, 1e-8] + y = allclose(a, b) + y = False + + a = [1e10, 1e-8], + b = [1.00001e10, 1e-9] + y = allclose(a, b) + y = True + +)code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a", "b"}; + }) +.set_attr("FInferShape", AllCloseShape) +.set_attr("FInferType", AllCloseType) +.set_attr("FCompute", AllClose) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.add_argument("a", "NDArray-or-Symbol", "Input array a") +.add_argument("b", "NDArray-or-Symbol", "Input array b") +.add_arguments(AllCloseParam::__FIELDS__()); + +template<> +size_t GetAdditionalMemoryLogical(mshadow::Stream *s, const int num_items) { + return 0; +} + +template<> +void GetResultLogical(mshadow::Stream *s, INTERM_DATA_TYPE *workMem, + size_t extraStorageBytes, int num_items, INTERM_DATA_TYPE *outPntr) { + while (num_items-- > 0 && workMem[num_items]) {} + outPntr[0] = num_items >= 0? 0 : 1; +} + +} // namespace op +} // namespace mxnet diff --git a/src/operator/contrib/allclose_op.cu b/src/operator/contrib/allclose_op.cu new file mode 100644 index 000000000000..f923ab060813 --- /dev/null +++ b/src/operator/contrib/allclose_op.cu @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file allclose_op.cu + * \brief GPU Implementation of allclose op + * \author Andrei Ivanov + */ +#include "./allclose_op-inl.h" +#include + +namespace mxnet { +namespace op { + +template +size_t GetAdditionalMemory(mshadow::Stream *s, const int num_items) { + T *d_in = nullptr; + T *d_out = nullptr; + size_t temp_storage_bytes = 0; + cudaStream_t stream = mshadow::Stream::GetStream(s); + cub::DeviceReduce::Min(nullptr, temp_storage_bytes, d_in, d_out, num_items, stream); + return temp_storage_bytes; +} + +template<> +size_t GetAdditionalMemoryLogical(mshadow::Stream *s, const int num_items) { + return GetAdditionalMemory(s, num_items); +} + +template<> +void GetResultLogical(mshadow::Stream *s, INTERM_DATA_TYPE *workMem, + size_t extraStorageBytes, int num_items, INTERM_DATA_TYPE *outPntr) { + cudaStream_t stream = mshadow::Stream::GetStream(s); + cub::DeviceReduce::Min(workMem + num_items, extraStorageBytes, + workMem, outPntr, num_items, stream); +} + +NNVM_REGISTER_OP(_contrib_allclose) +.set_attr("FCompute", AllClose); + +} // namespace op +} // namespace mxnet diff --git a/tests/python-pytest/onnx/mxnet_export_test.py b/tests/python-pytest/onnx/mxnet_export_test.py index 6c81198a8bca..90e92cccee06 100644 --- a/tests/python-pytest/onnx/mxnet_export_test.py +++ b/tests/python-pytest/onnx/mxnet_export_test.py @@ -74,7 +74,7 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params= # Confirm network outputs are the same imported_net_output = _force_list(imported_net(data)) for out, imp_out in zip(output, imported_net_output): - mx.test_utils.assert_almost_equal(out.asnumpy(), imp_out.asnumpy()) + mx.test_utils.assert_almost_equal(out, imp_out) class TestExport(unittest.TestCase): diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index fc650294a538..48fff7810c2e 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -20,26 +20,17 @@ import os import tempfile import time -import multiprocessing as mp -import unittest -import random import mxnet as mx +import multiprocessing as mp +from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal, rand_ndarray import mxnet.ndarray as nd import numpy as np -import unittest import math -from nose.tools import assert_raises -from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal -from mxnet.base import MXNetError from mxnet import autograd -from numpy.testing import assert_allclose -from mxnet.test_utils import rand_ndarray - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) -from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied -from common import run_in_spawned_process +from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied, run_in_spawned_process from test_gluon import * from test_loss import * from test_gluon_rnn import * @@ -60,9 +51,9 @@ def check_rnn_layer(layer): co, cs = layer(x, states) # atol of 1e-6 required, as exposed by seed 2124685726 - assert_almost_equal(go.asnumpy(), co.asnumpy(), rtol=1e-2, atol=1e-6) + assert_almost_equal(go, co, rtol=1e-2, atol=1e-6) for g, c in zip(gs, cs): - assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6) + assert_almost_equal(g, c, rtol=1e-2, atol=1e-6) @with_seed() @@ -79,9 +70,9 @@ def check_rnn_layer_w_rand_inputs(layer): states = layer.begin_state(16) co, cs = layer(x, 
states) - assert_almost_equal(go.asnumpy(), co.asnumpy(), rtol=1e-2, atol=1e-6) + assert_almost_equal(go, co, rtol=1e-2, atol=1e-6) for g, c in zip(gs, cs): - assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6) + assert_almost_equal(g, c, rtol=1e-2, atol=1e-6) @with_seed() @@ -117,21 +108,19 @@ def test_lstmp(): layer_output = lstm_layer(lstm_input.copy()) cell_output = lstm_cell.unroll(seq_len, lstm_input.copy(), layout='TNC', merge_outputs=True)[0] - assert_almost_equal(layer_output.asnumpy(), - cell_output.asnumpy(), rtol=rtol, atol=atol) + + assert_almost_equal(layer_output, cell_output, rtol=rtol, atol=atol) layer_output.backward() cell_output.backward() for k, v in weights.items(): layer_grad = layer_params['lstm0_l0_' + k].grad() cell_grad = cell_params['lstm0_l0_' + k].grad() print('checking gradient for {}'.format('lstm0_l0_' + k)) - assert_almost_equal(layer_grad.asnumpy(), cell_grad.asnumpy(), - rtol=rtol, atol=atol) + assert_almost_equal(layer_grad, cell_grad, rtol=rtol, atol=atol) check_rnn_layer_forward(gluon.rnn.LSTM( 10, 2, projection_size=5), mx.nd.ones((8, 3, 20)), ctx=ctx) check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones( (8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], ctx=ctx) - check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, projection_size=5), mx.nd.ones((8, 3, 20)), run_only=True, ctx=ctx) check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, projection_size=5), @@ -223,7 +212,8 @@ def forward(self, inpt): 'r0', 'r0l0')].set_data(weights[k]) data = mx.random.uniform(shape=(11, 10, in_size)) - assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy()) + mx.test_utils.assert_allclose(net(data), ref_net(data), rtol=1e-6) + def check_layer_bidirectional_varseqlen(size, in_size): @@ -336,8 +326,7 @@ def test_gluon_ctc_consistency(): l_gpu = loss(gpu_data, gpu_label) l_gpu.backward() - assert_almost_equal(cpu_data.grad.asnumpy(), - gpu_data.grad.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal(cpu_data.grad, gpu_data.grad, atol=1e-3, rtol=1e-3) @with_seed() @@ -351,9 +340,88 @@ def test_global_norm_clip_multi_device(): assert norm == 5.0 else: assert norm.asscalar() == 5.0 - assert_almost_equal(x1.asnumpy(), np.ones((3, 3)) / 5) - assert_almost_equal(x2.asnumpy(), np.ones((4, 4)) / 5) + assert_almost_equal(x1, np.ones((3, 3)) / 5) + assert_almost_equal(x2, np.ones((4, 4)) / 5) + + +def _check_batchnorm_result(input, num_devices=1, cuda=False): + from mxnet.gluon.utils import split_and_load + def _find_bn(module): + if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module + elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module.module + + raise RuntimeError('BN not found') + + def _syncParameters(bn1, bn2, ctx): + ctx = input.context + bn2.gamma.set_data(bn1.gamma.data(ctx)) + bn2.beta.set_data(bn1.beta.data(ctx)) + bn2.running_mean.set_data(bn1.running_mean.data(ctx)) + bn2.running_var.set_data(bn1.running_var.data(ctx)) + + input1 = input.copy() + input2 = input.copy() + + if cuda: + input1 = input.as_in_context(mx.gpu(0)) + ctx_list = [mx.gpu(i) for i in range(num_devices)] + else: + ctx_list = [mx.cpu(0) for _ in range(num_devices)] + + nch = input.shape[1] + bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) + bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices) + bn1.initialize(ctx=ctx_list[0]) + bn2.initialize(ctx=ctx_list) + + # 
using the same values for gamma and beta + #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) + + input1.attach_grad() + inputs2 = split_and_load(input2, ctx_list, batch_axis=0) + for xi in inputs2: + xi.attach_grad() + + with mx.autograd.record(): + output1 = bn1(input1) + output2 = [bn2(xi) for xi in inputs2] + loss1 = (output1 ** 2).sum() + loss2 = [(output ** 2).sum() for output in output2] + mx.autograd.backward(loss1) + mx.autograd.backward(loss2) + + output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) + # assert forwarding + assert_almost_equal(input1, input2, atol=1e-3, rtol=1e-3) + assert_almost_equal(output1, output2, atol=1e-3, rtol=1e-3) + assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]), + _find_bn(bn2).running_mean.data(ctx_list[0]), + atol=1e-3, rtol=1e-3) + assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]), + _find_bn(bn2).running_var.data(ctx_list[0]), + atol=1e-3, rtol=1e-3) + input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) + assert_almost_equal(input1.grad, input2grad, atol=1e-3, rtol=1e-3) + +@with_seed() +def test_sync_batchnorm(): + def get_num_devices(): + for i in range(100): + try: + mx.nd.zeros((1,), ctx=mx.gpu(i)) + except: + return i + # no need to use SyncBN with 1 gpu + if get_num_devices() < 2: + return + ndev = 2 + # check with unsync version + for i in range(10): + _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), + num_devices=ndev, cuda=True) @with_seed() def test_symbol_block_fp16(): @@ -468,10 +536,7 @@ def get_net(num_ops): time_per_iteration.value = (time.time() - start) / num_iterations - -@with_seed() -@unittest.skip('skippping temporarily, tracked by https://github.com/apache/incubator-mxnet/issues/14970') -def test_bulking(): +def _test_bulking(test_bulking_func): # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training) test_cases = [(0, 0, True), (1, 1, True), (15, 15, False), (15, 0, True), (0, 15, True), (15, 15, True)] @@ -480,7 +545,8 @@ def test_bulking(): for seg_sizes in test_cases: # Create shared variable to return measured time from test process time_per_iteration = mp.Manager().Value('d', 0.0) - if not run_in_spawned_process(_test_bulking_in_process, + + if not run_in_spawned_process(test_bulking_func, {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': seg_sizes[0], 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': seg_sizes[1], 'MXNET_EXEC_BULK_EXEC_TRAIN': seg_sizes[2]}, @@ -492,8 +558,7 @@ def test_bulking(): '\n runtime of (fwd,bwd,enable) op seg setting ({},{},{}) =\t{:.1f} msec'.format( seg_sizes[0], seg_sizes[1], seg_sizes[2], 1000.0 * times[seg_sizes]) - fastest_non_bulked_time = min( - times[(0, 0, True)], times[(1, 1, True)], times[(15, 15, False)]) + fastest_non_bulked_time = min(times[(0, 0, True)], times[(1, 1, True)], times[(15, 15, False)]) slowest_half_bulked_time = max(times[(0, 15, True)], times[(15, 0, True)]) fastest_half_bulked_time = min(times[(0, 15, True)], times[(15, 0, True)]) fully_bulked_time = times[(15, 15, True)] @@ -509,6 +574,11 @@ def test_bulking(): 'The fully-bulked exec time is slower than a half-bulked time by {} secs! 
{}' \ .format(fully_bulked_time - fastest_half_bulked_time, times_str) +@with_seed() +@unittest.skip('skippping temporarily, tracked by https://github.com/apache/incubator-mxnet/issues/14970') +def test_bulking_gluon_gpu(): + _test_bulking(_test_bulking_in_process) + if __name__ == '__main__': import nose diff --git a/tests/python/gpu/test_gluon_model_zoo_gpu.py b/tests/python/gpu/test_gluon_model_zoo_gpu.py index d4f6f31a30e7..6f559db62808 100644 --- a/tests/python/gpu/test_gluon_model_zoo_gpu.py +++ b/tests/python/gpu/test_gluon_model_zoo_gpu.py @@ -81,16 +81,17 @@ def test_inference(): gpu_param = gpu_params.get(k) gpu_param.set_data(cpu_param.data().as_in_context(mx.gpu())) + cpu_data = mx.nd.array(data, ctx=mx.cpu()) for i in range(5): # Run inference. with autograd.record(train_mode=False): - cpu_out = cpu_model(mx.nd.array(data, ctx=mx.cpu())) + cpu_out = cpu_model(cpu_data) gpu_out = gpu_model(gpu_data) - out = cpu_out.asnumpy() - max_val = np.max(np.abs(out)) + + max_val = np.max(np.abs(cpu_out.asnumpy())) gpu_max_val = np.max(np.abs(gpu_out.asnumpy())) eprint(model_name + ": CPU " + str(max_val) + ", GPU " + str(gpu_max_val)) - assert_almost_equal(out / max_val, gpu_out.asnumpy() / max_val, rtol=1e-3, atol=1e-3) + assert_almost_equal(cpu_out / max_val, gpu_out / gpu_max_val, rtol=1e-3, atol=1e-3) def get_nn_model(name): if "densenet" in name: @@ -161,7 +162,7 @@ def test_training(): max_val = np.max(np.abs(cpu_out.asnumpy())) gpu_max_val = np.max(np.abs(gpu_out.asnumpy())) eprint(model_name + ": CPU " + str(max_val) + ", GPU " + str(gpu_max_val)) - assert_almost_equal(cpu_out.asnumpy() / max_val, gpu_out.asnumpy() / max_val, rtol=1e-3, atol=1e-3) + assert_almost_equal(cpu_out / max_val, gpu_out / max_val, rtol=1e-3, atol=1e-3) cpu_loss.backward() gpu_loss.backward() cpu_trainer.step(batch_size) @@ -177,8 +178,7 @@ def test_training(): k = k.replace(cpu_params.prefix, '') cpu_param = cpu_params.get(k) gpu_param = gpu_params.get(k) - assert_almost_equal(cpu_param.data().asnumpy(), gpu_param.data().asnumpy(), - rtol=1e-3, atol=1e-3) + assert_almost_equal(cpu_param.data(), gpu_param.data(), rtol=1e-3, atol=1e-3) if __name__ == '__main__': import nose diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 0f1cd93755c3..b79b08219221 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -20,15 +20,13 @@ import os import time import multiprocessing as mp -import unittest import mxnet as mx import numpy as np import unittest from nose.tools import assert_raises -from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal +from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal, assert_allclose from mxnet.base import MXNetError from mxnet import autograd -from numpy.testing import assert_allclose curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) @@ -46,6 +44,7 @@ from test_sparse_operator import * from test_ndarray import * from test_subgraph_op import * +from test_gluon_gpu import _test_bulking from test_contrib_operator import test_multibox_target_op from test_tvm_op import * from test_library_loading import * @@ -230,7 +229,7 @@ def check_fft(shape): a[i,j,:,p+1] = out2[i,j+out1[0].shape[1],:,k] p = p+2 - assert_almost_equal(a, out1[0],rtol=1e-3, atol=1e-5) + assert_almost_equal(a, out1[0], rtol=1e-3, atol=1e-5) # backward if len(shape) == 2: @@ -244,7 +243,7 
@@ def check_fft(shape): for exe in exe_list: exe.backward([out_grad]) a = np.fft.ifft(out_grad_complex, n=None, axis=-1, norm=None) - assert_almost_equal(a.real, exe.grad_arrays[0].asnumpy()/shape[1],rtol=1e-3, atol=1e-5) + assert_almost_equal(a.real, exe.grad_arrays[0]/shape[1],rtol=1e-3, atol=1e-5) if len(shape) == 4: out_grad = mx.nd.empty(out1[0].shape) @@ -257,7 +256,7 @@ def check_fft(shape): for exe in exe_list: exe.backward([out_grad]) a = np.fft.ifft(out_grad_complex, n=None, axis=-1, norm=None) - assert_almost_equal(a.real, exe.grad_arrays[0].asnumpy()/shape[3],rtol=1e-3, atol=1e-5) + assert_almost_equal(a.real, exe.grad_arrays[0]/shape[3],rtol=1e-3, atol=1e-5) @with_seed() def test_fft(): @@ -1682,7 +1681,7 @@ def check_rnn_consistency(cell1, cell2): mod1.forward(batch, is_train=False) mod2.forward(batch, is_train=False) - assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=1e-2, atol=1e-4) + mx.test_utils.assert_allclose(mod1.get_outputs()[0], mod2.get_outputs()[0], rtol=1e-2, atol=1e-4) @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') @@ -1716,7 +1715,7 @@ def test_lstm_forget_bias(): bias_name = next(x for x in args if x.endswith('f_bias')) expected_bias = forget_bias * np.ones(10, ) - assert_allclose(args[bias_name].asnumpy(), expected_bias) + mx.test_utils.assert_allclose(args[bias_name], expected_bias) @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') @@ -2008,9 +2007,9 @@ def check_rnn_layer(layer): co, cs = layer(x, states) # atol of 1e-6 required, as exposed by seed 2124685726 - assert_almost_equal(go.asnumpy(), co.asnumpy(), rtol=1e-2, atol=1e-6) + assert_almost_equal(go, co, rtol=1e-2, atol=1e-6) for g, c in zip(gs, cs): - assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6) + assert_almost_equal(g, c, rtol=1e-2, atol=1e-6) def check_rnn_layer_w_rand_inputs(layer): layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)]) @@ -2025,9 +2024,9 @@ def check_rnn_layer_w_rand_inputs(layer): states = layer.begin_state(16) co, cs = layer(x, states) - assert_almost_equal(go.asnumpy(), co.asnumpy(), rtol=1e-2, atol=1e-6) + assert_almost_equal(go, co, rtol=1e-2, atol=1e-6) for g, c in zip(gs, cs): - assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6) + assert_almost_equal(g, c, rtol=1e-2, atol=1e-6) @with_seed() def test_sequence_reverse(): @@ -2092,7 +2091,7 @@ def test_cross_device_autograd(): y.backward() - dx = x.grad.asnumpy() + dx = x.grad.copy() x.grad[:] = 0 with mx.autograd.record(): @@ -2101,7 +2100,7 @@ def test_cross_device_autograd(): y = mx.nd.tanh(y) y.backward() - assert_almost_equal(dx, x.grad.asnumpy()) + assert_almost_equal(dx, x.grad) @with_seed() def test_multi_proposal_op(): @@ -2285,12 +2284,11 @@ def test_softmax_activation(): with mx.autograd.record(): gpu_y = mx.nd.SoftmaxActivation(data = gpu_a) cpu_y = mx.nd.SoftmaxActivation(data = cpu_a) - assert_almost_equal(cpu_y.asnumpy(), gpu_y.asnumpy(), atol = 1e-3, rtol = 1e-3) + assert_almost_equal(cpu_y, gpu_y, atol = 1e-3, rtol = 1e-3) gpu_y.backward() cpu_y.backward() - assert_almost_equal(cpu_a.grad.asnumpy(), gpu_a.grad.asnumpy(), - atol = 1e-3, rtol = 1e-3) + assert_almost_equal(cpu_a.grad, gpu_a.grad, atol = 1e-3, rtol = 1e-3) @with_seed() @@ -2320,13 +2318,13 @@ def test_bilinear_sampler_versions(): exe.arg_dict['data'][:] = test_data exe.arg_dict['grid'][:] = test_grid exe.forward(is_train=True) - assert_almost_equal(exe_list[ref_idx].outputs[0].asnumpy(), exe.outputs[0].asnumpy(), 
rtol=1e-3, atol=1e-5) + mx.test_utils.assert_almost_equal(exe_list[ref_idx].outputs[0], exe.outputs[0], rtol=1e-3, atol=1e-5) out_grad = np.random.uniform(low=-0.01, high=0.01,size=data_shape[:2] + grid_shape[2:]).astype(np.float32) for exe in exe_list: exe.backward(mx.nd.array(out_grad)) - assert_almost_equal(exe.grad_dict['data'].asnumpy(), exe_list[ref_idx].grad_dict['data'].asnumpy(), rtol=1e-3, atol=1e-5) - assert_almost_equal(exe.grad_dict['grid'].asnumpy(), exe_list[ref_idx].grad_dict['grid'].asnumpy(), rtol=1e-3, atol=1e-5) + assert_almost_equal(exe.grad_dict['data'], exe_list[ref_idx].grad_dict['data'], rtol=1e-3, atol=1e-5) + assert_almost_equal(exe.grad_dict['grid'], exe_list[ref_idx].grad_dict['grid'], rtol=1e-3, atol=1e-5) data_grad = exe_list[ref_idx].grad_dict['data'].asnumpy() grid_grad = exe_list[ref_idx].grad_dict['grid'].asnumpy() @@ -2345,10 +2343,10 @@ def test_bilinear_sampler_versions(): exe.grad_dict['grid'][:] = grid_initial_grad exe.forward(is_train=True) exe.backward(mx.nd.array(out_grad)) - assert_almost_equal(exe.grad_dict['data'].asnumpy(), exe_list[ref_idx].grad_dict['data'].asnumpy(), rtol=1e-3, atol=1e-5) - assert_almost_equal(exe.grad_dict['grid'].asnumpy(), exe_list[ref_idx].grad_dict['grid'].asnumpy(), rtol=1e-3, atol=1e-5) - assert_almost_equal(exe_list[ref_idx].grad_dict['data'].asnumpy(), data_grad + data_initial_grad, rtol=1e-3, atol=1e-5) - assert_almost_equal(exe_list[ref_idx].grad_dict['grid'].asnumpy(), grid_grad + grid_initial_grad, rtol=1e-3, atol=1e-5) + assert_almost_equal(exe.grad_dict['data'], exe_list[ref_idx].grad_dict['data'], rtol=1e-3, atol=1e-5) + assert_almost_equal(exe.grad_dict['grid'], exe_list[ref_idx].grad_dict['grid'], rtol=1e-3, atol=1e-5) + assert_almost_equal(exe_list[ref_idx].grad_dict['data'], data_grad + data_initial_grad, rtol=1e-3, atol=1e-5) + assert_almost_equal(exe_list[ref_idx].grad_dict['grid'], grid_grad + grid_initial_grad, rtol=1e-3, atol=1e-5) for req_dict in [{'data' : 'null', 'grid' : 'write'}, {'data' : 'write', 'grid' : 'null'}]: # Mixture of kWriteTo and kNullOp @@ -2362,9 +2360,9 @@ def test_bilinear_sampler_versions(): exe.forward(is_train=True) exe.backward(mx.nd.array(out_grad)) if req_dict['data'] is 'write': - assert_almost_equal(exe.grad_dict['data'].asnumpy(), exe_list[ref_idx].grad_dict['data'].asnumpy(), rtol=1e-3, atol=1e-5) + assert_almost_equal(exe.grad_dict['data'], exe_list[ref_idx].grad_dict['data'], rtol=1e-3, atol=1e-5) if req_dict['grid'] is 'write': - assert_almost_equal(exe.grad_dict['grid'].asnumpy(), exe_list[ref_idx].grad_dict['grid'].asnumpy(), rtol=1e-3, atol=1e-5) + assert_almost_equal(exe.grad_dict['grid'], exe_list[ref_idx].grad_dict['grid'], rtol=1e-3, atol=1e-5) # isolated execution bulking test function to be invoked with different env var settings @@ -2394,7 +2392,12 @@ def _test_bulking_in_process(seed, time_per_iteration): dx.wait_to_read() time_per_iteration.value = (time.time() - start) / num_iterations + @with_seed() +def test_bulking_operator_gpu(): + _test_bulking(_test_bulking_in_process) + + @unittest.skip('skippping temporarily, tracked by https://github.com/apache/incubator-mxnet/issues/14970') def test_bulking(): # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training) @@ -2433,6 +2436,10 @@ def test_bulking(): .format(fully_bulked_time - fastest_half_bulked_time, times_str) +@with_seed() +def test_allclose_function_gpu(): + allclose_function([mx.cpu(), mx.gpu(0)]) + def test_context_num_gpus(): # Test that num_gpus reports at 
least one GPU, as the test is run on a GPU host. assert mx.context.num_gpus() > 0 diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py index 2ffe3eaa233d..f88c0a888320 100644 --- a/tests/python/mkl/test_mkldnn.py +++ b/tests/python/mkl/test_mkldnn.py @@ -161,8 +161,8 @@ def hybrid_forward(self, F, x, *args, **kwargs): with mx.autograd.record(): out2 = net(x) out2.backward() - mx.test_utils.assert_almost_equal(dx1.asnumpy(), x.grad.asnumpy(), rtol=1e-5, atol=1e-6) - mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6) + assert_almost_equal(dx1, x.grad, rtol=1e-5, atol=1e-6) + assert_almost_equal(out1, out2, rtol=1e-5, atol=1e-6) @with_seed() @@ -195,8 +195,8 @@ def hybrid_forward(self, F, x, *args, **kwargs): with mx.autograd.record(): out2 = net(x) out2.backward() - mx.test_utils.assert_almost_equal(dx1.asnumpy(), x.grad.asnumpy(), rtol=1e-5, atol=1e-6) - mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6) + assert_almost_equal(dx1, x.grad, rtol=1e-5, atol=1e-6) + assert_almost_equal(out1, out2, rtol=1e-5, atol=1e-6) @with_seed() @@ -229,8 +229,8 @@ def hybrid_forward(self, F, x, *args, **kwargs): with mx.autograd.record(): out2 = net(x) out2.backward() - mx.test_utils.assert_almost_equal(dx1.asnumpy(), x.grad.asnumpy(), rtol=1e-5, atol=1e-6) - mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6) + assert_almost_equal(dx1, x.grad, rtol=1e-5, atol=1e-6) + assert_almost_equal(out1, out2, rtol=1e-5, atol=1e-6) @with_seed() diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py index 1e0555900f17..123db085e817 100644 --- a/tests/python/unittest/test_gluon_contrib.py +++ b/tests/python/unittest/test_gluon_contrib.py @@ -24,10 +24,9 @@ from mxnet.gluon.contrib.nn import ( Concurrent, HybridConcurrent, Identity, SparseEmbedding, PixelShuffle1D, PixelShuffle2D, PixelShuffle3D) -from mxnet.test_utils import almost_equal, default_context, assert_almost_equal +from mxnet.test_utils import almost_equal, default_context, assert_almost_equal, assert_allclose from common import setup_module, with_seed, teardown import numpy as np -from numpy.testing import assert_allclose def check_rnn_cell(cell, prefix, in_shape=(10, 50), out_shape=(10, 100), begin_state=None): @@ -112,6 +111,7 @@ def test_conv_fill_shape(): check_rnn_forward(cell, mx.nd.ones((8, 3, 5, 7))) assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1] + @with_seed() def test_lstmp(): nhid = 100 @@ -193,8 +193,7 @@ def test_concurrent(): def test_identity(): model = Identity() x = mx.nd.random.uniform(shape=(128, 33, 64)) - mx.test_utils.assert_almost_equal(model(x).asnumpy(), - x.asnumpy()) + assert_almost_equal(model(x), x) @with_seed() def test_sparse_embedding(): @@ -219,7 +218,7 @@ def test_pixelshuffle1d(): y = layer(x) assert y.shape == shape_after assert_allclose( - y.asnumpy(), + y, [[[0, 3, 1, 4, 2, 5], [6, 9, 7, 10, 8, 11]]] ) @@ -242,7 +241,7 @@ def test_pixelshuffle2d(): # - Increasing the block index adds an offset of 1 # - Increasing the channel index adds an offset of `nx * up_x * ny * up_y` assert_allclose( - y.asnumpy(), + y, [[[[ 0, 6, 12, 1, 7, 13, 2, 8, 14], [18, 24, 30, 19, 25, 31, 20, 26, 32], [ 3, 9, 15, 4, 10, 16, 5, 11, 17], @@ -273,7 +272,7 @@ def test_pixelshuffle3d(): # column index by 1, e.g. 
the block [[[ 0, 24]], [[48, 72]]] # - Increasing the block index adds an offset of 1 assert_allclose( - y.asnumpy(), + y, [[[[[ 0, 24, 1, 25, 2, 26, 3, 27], [ 4, 28, 5, 29, 6, 30, 7, 31], [ 8, 32, 9, 33, 10, 34, 11, 35]], @@ -382,19 +381,17 @@ def check_unroll(cell_type, num_states, layout): trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03}) with mx.autograd.record(): res2, states2 = layer(rnn_data, states, valid_length) - assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001) + assert_almost_equal(res1, res2, rtol=0.001, atol=0.0001) assert len(states1) == len(states2) for i in range(len(states1)): - assert_almost_equal(states1[i].asnumpy(), states2[i].asnumpy(), - rtol=0.001, atol=0.0001) + assert_almost_equal(states1[i], states2[i], rtol=0.001, atol=0.0001) res2.backward() trainer.step(batch_size) for key, val in params1.items(): weight1 = val.data() weight2 = params2[key].data() - assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(), - rtol=0.001, atol=0.0001) + assert_almost_equal(weight1, weight2, rtol=0.001, atol=0.0001) @with_seed() diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py index 3b9b46b16f93..af7c4838d5a9 100644 --- a/tests/python/unittest/test_loss.py +++ b/tests/python/unittest/test_loss.py @@ -49,10 +49,10 @@ def test_loss_ndarray(): loss = gluon.loss.SoftmaxCrossEntropyLoss() L = loss(output, label).asnumpy() - mx.test_utils.assert_almost_equal(L, np.array([ 2.12692809, 0.04858733])) + assert_almost_equal(L, np.array([ 2.12692809, 0.04858733])) L = loss(output, label, weighting).asnumpy() - mx.test_utils.assert_almost_equal(L, np.array([ 1.06346405, 0.04858733])) + assert_almost_equal(L, np.array([ 1.06346405, 0.04858733])) def get_net(num_hidden, flatten=True): @@ -126,8 +126,8 @@ def test_logistic_loss_equal_bce(): loss_bce = gluon.loss.SigmoidBCELoss(from_sigmoid=False) data = mx.random.uniform(-10, 10, shape=(N, 1)) label = mx.nd.round(mx.random.uniform(0, 1, shape=(N, 1))) - assert_almost_equal(loss_binary(data, label).asnumpy(), loss_bce(data, label).asnumpy(), atol=1e-6) - assert_almost_equal(loss_signed(data, 2 * label - 1).asnumpy(), loss_bce(data, label).asnumpy(), atol=1e-6) + assert_almost_equal(loss_binary(data, label), loss_bce(data, label), atol=1e-6) + assert_almost_equal(loss_signed(data, 2 * label - 1), loss_bce(data, label), atol=1e-6) @with_seed() def test_kl_loss(): @@ -186,27 +186,27 @@ def test_l1_loss(): def test_ctc_loss(): loss = gluon.loss.CTCLoss() l = loss(mx.nd.ones((2,20,4)), mx.nd.array([[1,0,-1,-1],[2,1,1,-1]])) - mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741])) + assert_almost_equal(l, np.array([18.82820702, 16.50581741])) loss = gluon.loss.CTCLoss(layout='TNC') l = loss(mx.nd.ones((20,2,4)), mx.nd.array([[1,0,-1,-1],[2,1,1,-1]])) - mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741])) + assert_almost_equal(l, np.array([18.82820702, 16.50581741])) loss = gluon.loss.CTCLoss(layout='TNC', label_layout='TN') l = loss(mx.nd.ones((20,2,4)), mx.nd.array([[1,0,-1,-1],[2,1,1,-1]]).T) - mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741])) + assert_almost_equal(l, np.array([18.82820702, 16.50581741])) loss = gluon.loss.CTCLoss() l = loss(mx.nd.ones((2,20,4)), mx.nd.array([[2,1,2,2],[3,2,2,2]]), None, mx.nd.array([2,3])) - mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741])) + assert_almost_equal(l, np.array([18.82820702, 16.50581741])) loss 
= gluon.loss.CTCLoss() l = loss(mx.nd.ones((2,25,4)), mx.nd.array([[2,1,-1,-1],[3,2,2,-1]]), mx.nd.array([20,20])) - mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741])) + assert_almost_equal(l, np.array([18.82820702, 16.50581741])) loss = gluon.loss.CTCLoss() l = loss(mx.nd.ones((2,25,4)), mx.nd.array([[2,1,3,3],[3,2,2,3]]), mx.nd.array([20,20]), mx.nd.array([2,3])) - mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741])) + assert_almost_equal(l, np.array([18.82820702, 16.50581741])) @with_seed() @@ -367,12 +367,14 @@ def test_cosine_loss(): assert_almost_equal(loss.asnumpy(), numpy_loss.asnumpy(), rtol=1e-3, atol=1e-5) def test_poisson_nllloss(): - pred = mx.nd.random.normal(shape=(3, 4)) + shape=(3, 4) + not_axis0 = tuple(range(1, len(shape))) + pred = mx.nd.random.normal(shape=shape) min_pred = mx.nd.min(pred) #This is necessary to ensure only positive random values are generated for prediction, # to avoid ivalid log calculation pred[:] = pred + mx.nd.abs(min_pred) - target = mx.nd.random.normal(shape=(3, 4)) + target = mx.nd.random.normal(shape=shape) min_target = mx.nd.min(target) #This is necessary to ensure only positive random values are generated for prediction, # to avoid ivalid log calculation @@ -396,8 +398,9 @@ def test_poisson_nllloss(): assert_almost_equal(np_loss_no_logits, loss_no_logits.asscalar()) #3) Testing for Sterling approximation - np_pred = np.random.uniform(1, 5, (2, 3)) - np_target = np.random.uniform(1, 5, (2, 3)) + shape=(2, 3) + np_pred = np.random.uniform(1, 5, shape) + np_target = np.random.uniform(1, 5, shape) np_compute_full = np.mean((np_pred - np_target * np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target)-\ np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1))) Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True) diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 7be59df6efda..bee4bff0f7c0 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -41,6 +41,10 @@ def check_with_uniform(uf, arg_shapes, dim=None, npuf=None, rmin=-10, type_list= assert dim shape = tuple(np.random.randint(1, int(1000**(1.0/dim)), size=dim)) arg_shapes = [shape] * arg_shapes + + if npuf is None: + npuf = uf + for dtype in type_list: ndarray_arg = [] numpy_arg = [] @@ -50,10 +54,7 @@ def check_with_uniform(uf, arg_shapes, dim=None, npuf=None, rmin=-10, type_list= ndarray_arg.append(narr) numpy_arg.append(npy) out1 = uf(*ndarray_arg) - if npuf is None: - out2 = uf(*numpy_arg).astype(dtype) - else: - out2 = npuf(*numpy_arg).astype(dtype) + out2 = npuf(*numpy_arg).astype(dtype) assert out1.shape == out2.shape if isinstance(out1, mx.nd.NDArray): @@ -1712,13 +1713,13 @@ def l2norm(input_data, axis=0, keepdims=False): np_arr, i, keep_dims) mx_out = mx.nd.norm(mx_arr, ord=ord, axis=i, keepdims=keep_dims) assert npy_out.shape == mx_out.shape - mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy()) + assert_almost_equal(npy_out, mx_out) if (i < 3): npy_out = l1norm(np_arr, (i, i + 1), keep_dims) if ord == 1 else l2norm( np_arr, (i, i + 1), keep_dims) mx_out = mx.nd.norm(mx_arr, ord=ord, axis=(i, i + 1), keepdims=keep_dims) assert npy_out.shape == mx_out.shape - mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy()) + assert_almost_equal(npy_out, mx_out) @with_seed() @@ -1732,7 +1733,7 @@ def test_dlpack(): for dtype in [np.float32, np.int32]: for shape in [(3, 4, 5, 6), (2, 
10), (15,)]: a = mx.nd.random.uniform(shape = shape) - a_np = a.asnumpy() + a_np = a.copy() pack = a.to_dlpack_for_read() b = mx.nd.from_dlpack(pack) @@ -1750,14 +1751,10 @@ def test_dlpack(): del a, pack, pack2, pack3, pack4 - b_np = b.asnumpy() - c_np = c.asnumpy() - d_np = d.asnumpy() - e_np = e.asnumpy() - mx.test_utils.assert_almost_equal(a_np, b_np) - mx.test_utils.assert_almost_equal(a_np, c_np) - mx.test_utils.assert_almost_equal(a_np, d_np) - mx.test_utils.assert_almost_equal(a_np, e_np) + assert_almost_equal(a_np, b) + assert_almost_equal(a_np, c) + assert_almost_equal(a_np, d) + assert_almost_equal(a_np, e) @with_seed() def test_ndarray_is_inf(): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 74cc4203df1d..e5f8909dcc58 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -363,9 +363,9 @@ def check_elementwise_sum_with_shape(shape, n): exec1 = out.bind(default_context(), args=arr, args_grad=arr_grad) - out1 = exec1.outputs[0].asnumpy() + exec1.forward(is_train=True) - out1 = exec1.outputs[0].asnumpy() + out1 = exec1.outputs[0] out = sum(a.asnumpy() for a in arr) assert_almost_equal(out, out1, rtol=1e-5, atol=1e-5) @@ -374,7 +374,7 @@ def check_elementwise_sum_with_shape(shape, n): # backward exec1.backward([out_grad]) for a in arr_grad: - assert_almost_equal(a.asnumpy(), out_grad.asnumpy(), rtol=1e-5, atol=1e-5) + assert_almost_equal(a, out_grad, rtol=1e-5, atol=1e-5) @with_seed() @@ -419,7 +419,7 @@ def check_concat_with_shape(shapes, dimension, skip_second): exec1.forward(is_train=True) out1 = exec1.outputs[0] ret = np.concatenate([narray.asnumpy() for narray in arr], axis=dimension) - assert_almost_equal(out1.asnumpy(), ret) + assert_almost_equal(out1, ret) # backward out1.copyto(out_grad) out_grad[:] += 1 @@ -429,7 +429,7 @@ def check_concat_with_shape(shapes, dimension, skip_second): if not skip_second or name != 'arg1': grad = dict_grad[name] np_grad = arr_np[i] - assert_almost_equal(grad.asnumpy(), np_grad + 1) + assert_almost_equal(grad, np_grad + 1) @with_seed() @@ -514,18 +514,17 @@ def check_slice_channel(data_ndim, axis, num_outputs, squeeze_axis): gt = data_npy.take(np.arange(i * shape[axis]/num_outputs, (i+1) * shape[axis]/num_outputs).astype(np.int), axis=axis) if squeeze_axis: - - assert_almost_equal(outputs[i].asnumpy(), gt.reshape(outputs[i].shape)) + assert_almost_equal(outputs[i], gt.reshape(outputs[i].shape)) else: - assert_almost_equal(outputs[i].asnumpy(), gt) + assert_almost_equal(outputs[i], gt) # test backward exe.backward(out_grads=[mx.nd.array(ele, ctx=default_context()) for ele in out_grads_npy]) if squeeze_axis: - assert_almost_equal(exe.grad_arrays[0].asnumpy(), + assert_almost_equal(exe.grad_arrays[0], np.concatenate([np.expand_dims(ele, axis=axis) for ele in out_grads_npy], axis=axis)) else: - assert_almost_equal(exe.grad_arrays[0].asnumpy(), + assert_almost_equal(exe.grad_arrays[0], np.concatenate(out_grads_npy, axis=axis)) check_slice_channel(data_ndim=2, axis=1, num_outputs=3, squeeze_axis=True) check_slice_channel(data_ndim=4, axis=2, num_outputs=3, squeeze_axis=False) @@ -557,8 +556,8 @@ def check_regression(symbol, forward, backward, shape, stype='default', densitie out_exec.backward() np_out = forward(arr_data.asnumpy()) out_grad = backward(np_out, arr_label.asnumpy().reshape(np_out.shape)) / shape[1] - assert_almost_equal(out_exec.outputs[0].asnumpy(), np_out, atol=atol) - assert_almost_equal(grad_map["data"].asnumpy(), out_grad, atol=atol) + 
assert_almost_equal(out_exec.outputs[0], np_out, atol=atol) + assert_almost_equal(grad_map["data"], out_grad, atol=atol) shape = (50, 30) @@ -681,7 +680,7 @@ def check_softmax_with_shape(shape, xpu, preserve_shape=False): atol = 1e-6 assert_almost_equal(out, np_softmax(x.asnumpy()), rtol=rtol, atol=atol) exec1.backward() - assert_almost_equal(grad.asnumpy(), np_softmax(x.asnumpy()) - l.asnumpy(), rtol=rtol, atol=atol) + assert_almost_equal(grad, np_softmax(x.asnumpy()) - l.asnumpy(), rtol=rtol, atol=atol) def test_python_op(): @@ -694,9 +693,9 @@ def test_python_op(): dy = mx.ndarray.ones((10)) exec1 = s.bind(default_context(), args=[x], args_grad = {'X': dx}) exec1.forward(is_train=True) - assert_almost_equal(x.asnumpy(), exec1.outputs[0].asnumpy()) + assert_almost_equal(x, exec1.outputs[0]) exec1.backward(dy) - assert_almost_equal(dy.asnumpy(), dx.asnumpy()) + assert_almost_equal(dy, dx) def test_swapaxes(): @@ -710,7 +709,7 @@ def test_swapaxes(): swap = mx.symbol.SwapAxis(data=swap0, dim1=1, dim2=2) exe_c = swap.bind(default_context(), args=[arr_data]) exe_c.forward(is_train=True) - out = exe_c.outputs[0].asnumpy() + out = exe_c.outputs[0] swap0_ = np.swapaxes(data_tmp, 0, 2) swap_ = np.swapaxes(swap0_, 1, 2) @@ -1021,7 +1020,7 @@ def test_shape_array(): exe.backward([yg]) yo = exe.outputs[0].asnumpy() same(yo, ya) - assert_almost_equal(xg.asnumpy(), np.zeros_like(xg.asnumpy())) + assert_almost_equal(xg, np.zeros_like(xg.asnumpy())) @with_seed() def test_size_array(): @@ -1039,7 +1038,7 @@ def test_size_array(): exe.backward([yg]) yo = exe.outputs[0].asnumpy() same(yo, ya) - assert_almost_equal(xg.asnumpy(), np.zeros_like(xg.asnumpy())) + assert_almost_equal(xg, np.zeros_like(xg.asnumpy())) @with_seed() def test_hard_sigmoid(): @@ -1098,7 +1097,7 @@ def _inner_test(forward_gt, logic_sym, x_shape, y_shape, test_scalar=True): x_npy = np.random.randint(0, 4, size=x_shape).astype(np.float32) y_npy = np.random.randint(0, 4, size=y_shape).astype(np.float32) exe = z.simple_bind(ctx=default_context(), x=x_shape, y=y_shape) - mx_out = exe.forward(is_train=True, x=x_npy, y=y_npy)[0].asnumpy() + mx_out = exe.forward(is_train=True, x=x_npy, y=y_npy)[0] assert_almost_equal(mx_out, forward_gt(x_npy, y_npy)) exe.backward() if test_scalar: @@ -1106,8 +1105,8 @@ def _inner_test(forward_gt, logic_sym, x_shape, y_shape, test_scalar=True): z_rscalar = logic_sym(x, 1) exe_lscalar = z_lscalar.simple_bind(ctx=default_context(), y=y_shape) exe_rscalar = z_rscalar.simple_bind(ctx=default_context(), x=x_shape) - mx_lscalar_out = exe_lscalar.forward(is_train=True, y=y_npy)[0].asnumpy() - mx_rscalar_out = exe_rscalar.forward(is_train=True, x=x_npy)[0].asnumpy() + mx_lscalar_out = exe_lscalar.forward(is_train=True, y=y_npy)[0] + mx_rscalar_out = exe_rscalar.forward(is_train=True, x=x_npy)[0] assert_almost_equal(mx_lscalar_out, forward_gt(1, y_npy)) assert_almost_equal(mx_rscalar_out, forward_gt(x_npy, 1)) exe_lscalar.backward() @@ -1154,12 +1153,12 @@ def reference(a, dtype): xa = np.random.randint(-2, 2, size=shape).astype(np.float32) mx_xa = mx.nd.array(xa) mx_out = mx.nd.logical_not(mx_xa) - assert_almost_equal(mx_out.asnumpy(), reference(xa, dtype=xa.dtype)) + assert_almost_equal(mx_out, reference(xa, dtype=xa.dtype)) x = mx.sym.Variable('x') y = mx.sym.logical_not(data=x) exe = y.simple_bind(ctx=default_context(), x=shape) sym_out = exe.forward(is_train=True, x=mx_xa)[0] - assert_almost_equal(sym_out.asnumpy(), reference(xa, dtype=xa.dtype)) + assert_almost_equal(sym_out, reference(xa, dtype=xa.dtype)) 
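# Note on the assert_almost_equal conversions in these test hunks: a minimal usage
# sketch, assuming a build that includes the test_utils.py changes and the new
# _contrib_allclose operator introduced earlier in this diff. It illustrates why the
# explicit .asnumpy() calls can be dropped: when both operands are NDArrays on the
# same context with the same dtype, the comparison goes through mx.nd.contrib.allclose;
# otherwise the helper falls back to numpy. Tolerance values here are illustrative only.
import mxnet as mx
from mxnet.test_utils import assert_almost_equal

# Two NDArrays on the same context and dtype: compared via the new
# mx.nd.contrib.allclose operator, with no host round-trip.
a = mx.nd.random.uniform(shape=(3, 4))
b = a + 1e-7
assert_almost_equal(a, b, rtol=1e-3, atol=1e-5)

# Mixed NDArray/numpy operands still work; the NDArray side is converted
# and the numpy allclose path is used instead.
assert_almost_equal(a, b.asnumpy(), rtol=1e-3, atol=1e-5)

# The operator can also be called directly; its single-element result
# holds 1 when every element is within tolerance.
ok = mx.nd.contrib.allclose(a, b, rtol=1e-3, atol=1e-5)
print(ok.asnumpy())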
@with_seed() @@ -1184,13 +1183,13 @@ def test_embedding(): # Non-zero atol required, as exposed by seed 781663739 rtol = 1e-5 atol = 1e-5 - assert_almost_equal(exe_test.outputs[0].asnumpy(), np.dot(np_onehot, np_weight), rtol=rtol, atol=atol) + assert_almost_equal(exe_test.outputs[0], np.dot(np_onehot, np_weight), rtol=rtol, atol=atol) # backward np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape) grad = mx.nd.zeros(np_grad.shape) grad[:] = np_grad exe_test.backward([grad]) - assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, np_grad), rtol=rtol, atol=atol) + assert_almost_equal(grad_map["embed_weight"], np.dot(np_onehot.T, np_grad), rtol=rtol, atol=atol) # check ops handle duplicate input correctly. @@ -1208,9 +1207,9 @@ def test_binary_op_duplicate_input(): square = data * data exe_square = square.bind(default_context(), args=[arr_data], args_grad=[arr_grad]) exe_square.forward(is_train=True) - assert_almost_equal(exe_square.outputs[0].asnumpy(), data_tmp * data_tmp) + assert_almost_equal(exe_square.outputs[0], data_tmp * data_tmp) exe_square.backward(out_grad) - assert_almost_equal(arr_grad.asnumpy(), 2.0 * data_tmp) + assert_almost_equal(arr_grad, 2.0 * data_tmp) @with_seed() @@ -1226,7 +1225,7 @@ def test_sign(): test = mx.sym.sign(data) exe_test = test.bind(default_context(), args=[arr_data], args_grad=[arr_grad]) exe_test.forward(is_train=True) - out = exe_test.outputs[0].asnumpy() + out = exe_test.outputs[0] npout = np.sign(data_tmp) assert_almost_equal(out, npout) @@ -1235,7 +1234,7 @@ def test_sign(): npout_grad = out_grad.asnumpy() npout_grad = 0; exe_test.backward(out_grad) - assert_almost_equal(arr_grad.asnumpy(), npout_grad) + assert_almost_equal(arr_grad, npout_grad) @with_seed() @@ -1251,7 +1250,7 @@ def test_round_ceil_floor(): test = mx.sym.round(data) + mx.sym.ceil(data) + mx.sym.floor(data) exe_test = test.bind(default_context(), args=[arr_data]) exe_test.forward(is_train=True) - out = exe_test.outputs[0].asnumpy() + out = exe_test.outputs[0] npout = np.round(data_tmp) + np.ceil(data_tmp) + np.floor(data_tmp) assert_almost_equal(out, npout) @@ -1265,7 +1264,7 @@ def test_trunc(): exe_test = test.bind(default_context(), args=[arr_data]) exe_test.forward(is_train=True) - out = exe_test.outputs[0].asnumpy() + out = exe_test.outputs[0] # 'trunc' is sensitive to the precision of the calculation. Force numpy to match mxnet's float32. 
# Repro issue with seed 1660190454 npout = np.trunc(np.float32(data_tmp)) @@ -1286,16 +1285,16 @@ def test_rsqrt_cos_sin(): test = mx.sym.rsqrt(data) + mx.sym.cos(data) + mx.sym.sin(data) exe_test = test.bind(default_context(), args=[arr_data], args_grad=[arr_grad]) exe_test.forward(is_train=True) - out = exe_test.outputs[0].asnumpy() + out = exe_test.outputs[0] npout = 1/ np.sqrt(data_tmp) + np.cos(data_tmp) + np.sin(data_tmp) assert_almost_equal(out, npout) out_grad = mx.nd.empty(shape) - out_grad[:] = 2; + out_grad[:] = 2 npout_grad = out_grad.asnumpy() npout_grad = npout_grad * -(1.0 / (2.0 * data_tmp * np.sqrt(data_tmp))) + npout_grad * -1 * np.sin(data_tmp) + npout_grad * np.cos(data_tmp) exe_test.backward(out_grad) - assert_almost_equal(arr_grad.asnumpy(), npout_grad) + assert_almost_equal(arr_grad, npout_grad) @with_seed() @@ -1317,7 +1316,7 @@ def test_maximum_minimum(): test = mx.sym.maximum(data1,data2) + mx.sym.minimum(data1,data2) exe_test = test.bind(default_context(), args=[arr_data1,arr_data2], args_grad=[arr_grad1,arr_grad2]) exe_test.forward(is_train=True) - out = exe_test.outputs[0].asnumpy() + out = exe_test.outputs[0] npout = np.maximum(data_tmp1,data_tmp2) + np.minimum(data_tmp1,data_tmp2) assert_almost_equal(out, npout) @@ -1332,8 +1331,8 @@ def test_maximum_minimum(): npout_grad1 = npout_grad * mask1 + npout_grad * mask2 npout_grad2 = (npout_grad - npout_grad * mask1) + (npout_grad - npout_grad * mask2) - assert_almost_equal(arr_grad1.asnumpy(), npout_grad1) - assert_almost_equal(arr_grad2.asnumpy(), npout_grad2) + assert_almost_equal(arr_grad1, npout_grad1) + assert_almost_equal(arr_grad2, npout_grad2) @with_seed() @@ -1349,7 +1348,7 @@ def test_maximum_minimum_scalar(): test = mx.sym.maximum(data1,3) + mx.sym.maximum(9,data1) + mx.sym.minimum(5,data1) + mx.sym.minimum(data1,4) exe_test = test.bind(default_context(), args=[arr_data1], args_grad=[arr_grad1]) exe_test.forward(is_train=True) - out = exe_test.outputs[0].asnumpy() + out = exe_test.outputs[0] npout = np.maximum(data_tmp1,3) + np.maximum(9,data_tmp1) + np.minimum(5,data_tmp1) + np.minimum(data_tmp1,4) assert_almost_equal(out, npout) @@ -1365,7 +1364,7 @@ def test_maximum_minimum_scalar(): mask4 = (data_tmp1 < 4).astype('float') npout_grad1 = npout_grad * mask1 + (npout_grad - npout_grad * mask2) + (npout_grad - npout_grad * mask3) + npout_grad * mask4 - assert_almost_equal(arr_grad1.asnumpy(), npout_grad1) + assert_almost_equal(arr_grad1, npout_grad1) @with_seed() @@ -1381,7 +1380,7 @@ def test_abs(): test = mx.sym.abs(data) exe_test = test.bind(default_context(), args=[arr_data], args_grad=[arr_grad]) exe_test.forward(is_train=True) - out = exe_test.outputs[0].asnumpy() + out = exe_test.outputs[0] npout = abs(data_tmp) assert_almost_equal(out, npout) @@ -1390,7 +1389,7 @@ def test_abs(): npout_grad = out_grad.asnumpy() npout_grad = npout_grad * np.sign(data_tmp) exe_test.backward(out_grad) - assert_almost_equal(arr_grad.asnumpy(), npout_grad) + assert_almost_equal(arr_grad, npout_grad) def check_deconvolution_forward_backward(input_shape, num_filter, kernel, stride, pad): @@ -1422,9 +1421,9 @@ def check_deconvolution_forward_backward(input_shape, num_filter, kernel, stride exe = deconv.bind(default_context(), args=args, args_grad=args_grad) exe.forward(is_train=True) - out = exe.outputs[0].asnumpy() + out = exe.outputs[0] exe.backward(out_grad) - assert_almost_equal(out, args_grad[0].asnumpy(), rtol=1E-3, atol=1e-3) + assert_almost_equal(out, args_grad[0], rtol=1E-3, atol=1e-3) args_grad_addto_npy = 
[np.random.normal(size=s) for s in arg_shapes] args_grad_addto = [mx.nd.array(ele) for ele in args_grad_addto_npy] @@ -1481,7 +1480,7 @@ def check_deconvolution_gradient(input_shape, num_filter, pad): exe_deconv.forward(is_train=True) deconv_out_grad = conv_data[:] exe_deconv.backward(deconv_out_grad) - assert_almost_equal(conv_args_grad[1].asnumpy(), deconv_args_grad[1].asnumpy(), rtol=1e-3, atol=1e-2) + assert_almost_equal(conv_args_grad[1], deconv_args_grad[1], rtol=1e-3, atol=1e-2) # Test AddTo exe_deconv_addto = deconv.bind(default_context(), args=deconv_args, args_grad=deconv_addto_args_grad, @@ -2069,7 +2068,7 @@ def test_depthwise_convolution(): exe2.backward(exe2.outputs[0]) for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays): - np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3, atol=1e-3) + assert_allclose(arr1, arr2, rtol=1e-3, atol=1e-3) @with_seed() @@ -2850,9 +2849,9 @@ def test_broadcasting_ele(sym_bcast): args_grad={'a': grad_nd}) net.forward(is_train=True) assert (net.outputs[0].shape == target_shape).all() - assert_almost_equal(net.outputs[0].asnumpy(), groundtruth, rtol=1e-4) + assert_almost_equal(net.outputs[0], groundtruth, rtol=1e-4) net.backward(out_grads=mx.nd.array(outgrad_npy)) - assert_almost_equal(grad_nd.asnumpy(), grad_groundtruth, rtol=1e-4) + assert_almost_equal(grad_nd, grad_groundtruth, rtol=1e-4) test_broadcasting_ele(sym_bcast_axis) test_broadcasting_ele(sym_bcast_to) test_broadcasting_ele(sym_bcast_to_with_zero) @@ -3120,13 +3119,13 @@ def test_stn(): grad_grad = [mx.nd.zeros(shape, ctx=dev) for shape in arg_shapes] exe = stn.bind(dev, args=args, args_grad=grad_grad) exe.forward(is_train=True) - out = exe.outputs[0].asnumpy() + out = exe.outputs[0] # check forward assert_almost_equal(out, args['data'].asnumpy()[:, :, h//4:h-h//4, w//4:w-w//4], rtol=1e-2, atol=1e-4) out_grad = mx.nd.ones(out.shape, ctx=dev) exe.backward([out_grad]) # check backward - assert_almost_equal(out_grad.asnumpy(), grad_grad[0].asnumpy()[:, :, h//4:h-h//4, w//4:w-w//4], rtol=1e-2, atol=1e-4) + assert_almost_equal(out_grad, grad_grad[0].asnumpy()[:, :, h//4:h-h//4, w//4:w-w//4], rtol=1e-2, atol=1e-4) def test_stn_valid_sampling(): @@ -3188,6 +3187,7 @@ def test_dot(): # Test normal dot. for ndim in ndims: for data_type in dtypes: + tol = 1e-2 if data_type == 'float16' else 1e-3 for m in range(1, 5): for k in range(1, 5): if ndim == 1 and k != 1: @@ -3212,16 +3212,10 @@ def test_dot(): c = mx.sym.dot(a, b) exe = c.simple_bind(ctx=ctx, a=a_npy.shape, b=b_npy.shape) outputs = exe.forward(is_train=True, a=a_npy, b=b_npy) - assert_almost_equal(outputs[0].asnumpy(), c_npy, - rtol=1e-2 if data_type == 'float16' else 1e-3, - atol=1e-2 if data_type == 'float16' else 1e-3) + assert_almost_equal(outputs[0], c_npy, rtol=tol, atol=tol) exe.backward(out_grads=[mx.nd.array(ograd_npy, mx.cpu()).astype(data_type)]) - assert_almost_equal(exe.grad_dict['a'].asnumpy(), agrad_npy, - rtol=1e-2 if data_type == 'float16' else 1e-3, - atol=1e-2 if data_type == 'float16' else 1e-3) - assert_almost_equal(exe.grad_dict['b'].asnumpy(), bgrad_npy, - rtol=1e-2 if data_type == 'float16' else 1e-3, - atol=1e-2 if data_type == 'float16' else 1e-3) + assert_almost_equal(exe.grad_dict['a'], agrad_npy, rtol=tol, atol=tol) + assert_almost_equal(exe.grad_dict['b'], bgrad_npy, rtol=tol, atol=tol) # Test dot with transpose flag using gradient checker. 
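The test_dot hunk above also hoists the dtype-dependent thresholds into a single tol variable instead of repeating the conditional expression in every assertion. A small illustrative sketch of that pattern (shapes, dtypes and values are made up):

import numpy as np
from mxnet.test_utils import assert_almost_equal

for data_type in ['float16', 'float32', 'float64']:
    tol = 1e-2 if data_type == 'float16' else 1e-3   # looser tolerance for half precision
    a = np.random.randn(4, 5).astype(data_type)
    assert_almost_equal(a, a.copy(), rtol=tol, atol=tol)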
def dot_sym(data_type): @@ -3258,8 +3252,9 @@ def dot_sym_xT_yT(data_type): @with_seed() def test_batch_dot(): + ctx = default_context() dtypes = ['float32', 'float64'] - if default_context().device_type == 'gpu': + if ctx.device_type == 'gpu': dtypes += ['float16'] for data_type in dtypes: @@ -3297,30 +3292,30 @@ def test_batch_dot(): b_npy = np.transpose(b_npy, axes=(0, 2, 1)) bgrad_npy = np.transpose(bgrad_npy, axes=(0, 2, 1)) b_init_grad_npy = np.transpose(b_init_grad_npy, axes=(0, 2, 1)) - exe = c.simple_bind(ctx=default_context(), + exe = c.simple_bind(ctx=ctx, a=a_npy.shape, b=b_npy.shape, grad_req='write') - exe_add = c.simple_bind(ctx=default_context(), + exe_add = c.simple_bind(ctx=ctx, a=a_npy.shape, b=b_npy.shape, grad_req='add') exe_add.grad_dict['a'][:] = a_init_grad_npy exe_add.grad_dict['b'][:] = b_init_grad_npy outputs = exe.forward(is_train=True, a=a_npy, b=b_npy) - assert_almost_equal(outputs[0].asnumpy(), c_npy, + assert_almost_equal(outputs[0], c_npy, rtol=1e-2 if data_type == 'float16' else 1e-3, atol=1e-2 if data_type == 'float16' else 1e-4) exe.backward(out_grads=[mx.nd.array(ograd_npy, ctx=exe._ctx)]) - assert_almost_equal(exe.grad_dict['a'].asnumpy(), agrad_npy, + assert_almost_equal(exe.grad_dict['a'], agrad_npy, rtol=1e-2 if data_type == 'float16' else 1e-3, atol=1e-2 if data_type == 'float16' else 1e-4) - assert_almost_equal(exe.grad_dict['b'].asnumpy(), bgrad_npy, + assert_almost_equal(exe.grad_dict['b'], bgrad_npy, rtol=1e-2 if data_type == 'float16' else 1e-3, atol=1e-2 if data_type == 'float16' else 1e-4) exe_add.forward(is_train=True, a=a_npy, b=b_npy) exe_add.backward(out_grads=[mx.nd.array(ograd_npy, ctx=exe._ctx)]) - assert_almost_equal(exe_add.grad_dict['a'].asnumpy(), + assert_almost_equal(exe_add.grad_dict['a'], agrad_npy + a_init_grad_npy, rtol=1e-2 if data_type == 'float16' else 1e-3, atol=1e-2 if data_type == 'float16' else 1e-4) - assert_almost_equal(exe_add.grad_dict['b'].asnumpy(), + assert_almost_equal(exe_add.grad_dict['b'], bgrad_npy + b_init_grad_npy, rtol=1e-2 if data_type == 'float16' else 1e-3, atol=1e-2 if data_type == 'float16' else 1e-4) @@ -3452,7 +3447,7 @@ def unittest_correlation(data_shape,kernel_size,max_displacement,stride1,stride2 forward_result,tmp1,tmp2 = correlation_forward(img1,img2,pad_size,kernel_size,stride1,stride2,max_displacement,is_multiply) # forward error - assert_almost_equal(exe1.outputs[0].asnumpy(), forward_result, rtol=1e-4, atol=1e-4) + assert_almost_equal(exe1.outputs[0], forward_result, rtol=1e-4, atol=1e-4) # out_grad a = np.ones(forward_result.shape) @@ -3463,8 +3458,8 @@ def unittest_correlation(data_shape,kernel_size,max_displacement,stride1,stride2 grad1,grad2 = correlation_backward(a,tmp1,tmp2,img1,img2,pad_size,kernel_size,stride1,stride2,max_displacement,is_multiply) # backward error - assert_almost_equal(exe1.grad_dict['img1'].asnumpy(), grad1, rtol=1e-3, atol=1e-4) - assert_almost_equal(exe1.grad_dict['img2'].asnumpy(), grad2, rtol=1e-3, atol=1e-4) + assert_almost_equal(exe1.grad_dict['img1'], grad1, rtol=1e-3, atol=1e-4) + assert_almost_equal(exe1.grad_dict['img2'], grad2, rtol=1e-3, atol=1e-4) @with_seed() @@ -3522,7 +3517,7 @@ def test_support_vector_machine_l1_svm(): exec1 = Y.bind(xpu, args = [x, l], args_grad = {'X': grad}) exec1.forward(is_train=True) - assert_almost_equal(x_np, exec1.outputs[0].asnumpy()) + assert_almost_equal(x_np, exec1.outputs[0]) exec1.backward() @@ -3530,7 +3525,7 @@ def test_support_vector_machine_l1_svm(): l_mask = np.array(l_mask, dtype=np.float32)*2 -1 grad_np = 
(-1) * l_mask * np.greater(1 - l_mask * x_np, 0) - assert_almost_equal(grad_np, grad.asnumpy()) + assert_almost_equal(grad_np, grad) @with_seed() @@ -3553,7 +3548,7 @@ def test_support_vector_machine_l2_svm(): exec1 = Y.bind(xpu, args = [x, l], args_grad = {'X': grad}) exec1.forward(is_train=True) - assert_almost_equal(x_np, exec1.outputs[0].asnumpy()) + assert_almost_equal(x_np, exec1.outputs[0]) exec1.backward() @@ -3561,7 +3556,7 @@ def test_support_vector_machine_l2_svm(): l_mask = np.array(l_mask, dtype=np.float32)*2 -1 grad_np = (-2)*l_mask*np.maximum(1-l_mask*x_np,0) grad_np = grad_np.astype(np.float32) - assert_almost_equal(grad_np, grad.asnumpy()) + assert_almost_equal(grad_np, grad) # Seed set because the test is not robust enough to operate on random data @@ -3595,7 +3590,7 @@ def check_pad_with_shape(shape, xpu, pad_width, mode, dtype="float64"): grad = mx.nd.empty(shape, ctx = xpu, dtype=dtype) exec1 = Y.bind(xpu, args = [x], args_grad = {'X': grad}) exec1.forward(is_train=True) - out = exec1.outputs[0].asnumpy() + out = exec1.outputs[0] # compare numpy + mxnet assert_almost_equal(out, np_out) # grad check @@ -3652,7 +3647,7 @@ def check_instance_norm_with_shape(shape, xpu): np_out = np_instance_norm(x.asnumpy(), gamma.asnumpy(), beta.asnumpy(), eps) exec1 = Y.bind(xpu, args = {'X':x, 'G':gamma, 'B':beta}) exec1.forward(is_train=False) - out = exec1.outputs[0].asnumpy() + out = exec1.outputs[0] assert_almost_equal(out, np_out, rtol=1e-4, atol=1e-4) check_numeric_gradient(Y, {'X':x.asnumpy(), 'G':gamma.asnumpy(), 'B':beta.asnumpy()}, numeric_eps=1e-2, rtol=1e-2, atol=1e-2) @@ -3694,7 +3689,7 @@ def check_l2_normalization(in_shape, mode, dtype, norm_eps=1e-10): exe = out.simple_bind(ctx=ctx, data=in_data.shape) output = exe.forward(is_train=True, data=in_data) # compare numpy + mxnet - assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-2 if dtype is 'float16' else 1e-5, atol=1e-5) + assert_almost_equal(exe.outputs[0], np_out, rtol=1e-2 if dtype is 'float16' else 1e-5, atol=1e-5) # check gradient check_numeric_gradient(out, [in_data], numeric_eps=1e-3, rtol=1e-2, atol=5e-3) @@ -3758,7 +3753,7 @@ def npy_layer_norm_grad(data, gamma, out_grad, axis, eps): exe.arg_dict['beta'][:] = beta out_nd = exe.forward()[0] out = npy_layer_norm(data, gamma, beta, axis, eps) - assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps, forward_check_eps) + assert_almost_equal(out, out_nd, forward_check_eps, forward_check_eps) if finite_grad_check: for req in ['write', 'add']: @@ -4098,7 +4093,7 @@ def mathematical_core_binary(name, test = forward_mxnet_call(data1, data2) exe_test = test.bind(default_context(), args=[arr_data1, arr_data2], args_grad=[arr_grad1, arr_grad2]) exe_test.forward(is_train=True) - out = exe_test.outputs[0].asnumpy() + out = exe_test.outputs[0] npout = forward_numpy_call(data_tmp1, data_tmp2) assert_almost_equal(out, npout) @@ -4111,8 +4106,6 @@ def mathematical_core_binary(name, npout_grad1 = npout_grad * backward_numpy_call1(data_tmp1, data_tmp2) npout_grad2 = npout_grad * backward_numpy_call2(data_tmp1, data_tmp2) - arr_grad1 = arr_grad1.asnumpy() - arr_grad2 = arr_grad2.asnumpy() assert_almost_equal(arr_grad1, npout_grad1) assert_almost_equal(arr_grad2, npout_grad2) @@ -4130,7 +4123,7 @@ def mathematical_core(name, forward_mxnet_call, forward_numpy_call, backward_num test = forward_mxnet_call(data) exe_test = test.bind(default_context(), args=[arr_data], args_grad=[arr_grad]) exe_test.forward(is_train=True) - out = exe_test.outputs[0].asnumpy() + out = 
exe_test.outputs[0] npout = forward_numpy_call(data_tmp) assert_almost_equal(out, npout) @@ -4140,10 +4133,6 @@ def mathematical_core(name, forward_mxnet_call, forward_numpy_call, backward_num temp = backward_numpy_call(data_tmp) npout_grad = npout_grad * temp exe_test.backward(out_grad) - arr_grad = arr_grad.asnumpy() - # print(name) - # print(arr_grad) - # print(npout_grad) assert_almost_equal(arr_grad, npout_grad) @@ -4182,7 +4171,7 @@ def rounding(name, forward_mxnet_call, forward_numpy_call, data_init=5., grad_in test = forward_mxnet_call(data) exe_test = test.bind(default_context(), args=[arr_data]) exe_test.forward(is_train=True) - out = exe_test.outputs[0].asnumpy() + out = exe_test.outputs[0] npout = forward_numpy_call(data_tmp) assert_almost_equal(out, npout) @@ -4320,7 +4309,7 @@ def test_basic_val_init(sym_func, np_func, shape, dtype): x = sym_func(shape=shape, dtype=dtype) exe = x.bind(default_context(), args=[], args_grad=[]) exe.forward(is_train=True) - assert_almost_equal(exe.outputs[0].asnumpy(), np_func(shape=shape, dtype=dtype)) + assert_almost_equal(exe.outputs[0], np_func(shape=shape, dtype=dtype)) assert exe.outputs[0].asnumpy().dtype == dtype def test_arange(): @@ -4337,14 +4326,14 @@ def test_arange(): repeats = random.choice([1, 3]) np_out = np.repeat(np.arange(*config, dtype=dtype), repeats) nd_out = mx.nd.arange(*config, repeat=repeats, dtype=dtype) - assert_almost_equal(np_out, nd_out.asnumpy()) + assert_almost_equal(np_out, nd_out) def test_arange_inferstop(): s = mx.sym.arange(start=0, stop=None, infer_range=True) s = mx.sym.elemwise_add(s, mx.sym.zeros(shape=[5])) exe = s.bind(ctx=mx.cpu(), args={}) exe.forward() - assert_almost_equal(exe.outputs[0].asnumpy(), np.array([0,1,2,3,4])) + assert_almost_equal(exe.outputs[0], np.array([0,1,2,3,4])) def test_arange_like(): shape_list = [(10,), (10, 20), (10, 20, 30), (10, 20, 30, 40)] @@ -4514,7 +4503,7 @@ def test_blockgrad(): exe = b.simple_bind(ctx=default_context(), a=(10, 10)) a_npy = np.random.rand(10, 10) exe.forward(is_train=True, a=a_npy) - assert_almost_equal(exe.outputs[0].asnumpy(), a_npy) + assert_almost_equal(exe.outputs[0], a_npy) exe.backward() # No error if BlockGrad works @@ -4579,7 +4568,7 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode, out_of_range=True): # Did not raise exception assert False, "did not raise %s" % MXNetError.__name__ - assert_almost_equal(exe.outputs[0].asnumpy(), np.take(data_real, idx_real, axis=axis, mode=mode)) + assert_almost_equal(exe.outputs[0], np.take(data_real, idx_real, axis=axis, mode=mode)) for i in np.nditer(idx_real): if mode == 'clip': @@ -4587,7 +4576,7 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode, out_of_range=True): grad_helper(grad_in, axis, i) exe.backward([mx.nd.array(grad_out)]) - assert_almost_equal(exe.grad_dict['a'].asnumpy(), grad_in) + assert_almost_equal(exe.grad_dict['a'], grad_in) def check_autograd_req(): row_len = 2 @@ -4612,7 +4601,7 @@ def check_autograd_req(): x = x.sum() x.backward() - assert_almost_equal(np.ones(sc.grad.shape), sc.grad.asnumpy()) + assert_almost_equal(np.ones(sc.grad.shape), sc.grad) for mode in ['clip', 'wrap', 'raise']: for data_ndim in range(1, 5): @@ -4658,7 +4647,7 @@ def test_grid_generator(): tmp[1] = -1.0 + (np.arange(target_shape[0]*target_shape[1]) // target_shape[1]) * (2.0 / (target_shape[0]-1)) tmp[2] = 1 grad_est = np.dot(out_grad[0].reshape(2,target_shape[0]*target_shape[1]),tmp.T).reshape(1,6) - assert_almost_equal(exe.grad_dict['affine'].asnumpy(), grad_est, rtol=1e-3, 
atol=1e-5) + assert_almost_equal(exe.grad_dict['affine'], grad_est, rtol=1e-3, atol=1e-5) # check addto exe = grid.simple_bind(ctx=default_context(), affine=(1,6), grad_req='add') grid_grad_npy = np.random.normal(size=exe.grad_dict['affine'].shape) @@ -4666,7 +4655,7 @@ def test_grid_generator(): exe.arg_dict['affine'][:] = np.array([[1.0, 0, 0, 0, 1.0, 0]]) exe.forward(is_train=True) exe.backward(mx.nd.array(out_grad)) - assert_almost_equal(exe.grad_dict['affine'].asnumpy(), grad_est + grid_grad_npy, rtol=1e-2, atol=1e-5) + assert_almost_equal(exe.grad_dict['affine'], grad_est + grid_grad_npy, rtol=1e-2, atol=1e-5) # transform_type = warp test_case = [(12,21),(4,3),(6,12)] @@ -4689,7 +4678,7 @@ def test_grid_generator(): grad_est = np.zeros((1,2)+target_shape) grad_est[0,0] = out_grad[0,0] / ((target_shape[1]-1.0) / 2.0) grad_est[0,1] = out_grad[0,1] / ((target_shape[0]-1.0) / 2.0) - assert_almost_equal(exe.grad_dict['flow'].asnumpy(), grad_est, rtol=1e-3) + assert_almost_equal(exe.grad_dict['flow'], grad_est, rtol=1e-3) # check addto exe_add = grid.simple_bind(ctx=default_context(), flow=(1, 2) + target_shape, grad_req='add') flow_grad_npy = np.random.normal(size=exe_add.grad_dict['flow'].shape) @@ -4697,7 +4686,7 @@ def test_grid_generator(): exe_add.grad_dict['flow'][:] = flow_grad_npy exe_add.forward(is_train=True) exe_add.backward(mx.nd.array(out_grad)) - assert_almost_equal(exe_add.grad_dict['flow'].asnumpy(), grad_est + flow_grad_npy, rtol=1e-3, atol=1e-5) + assert_almost_equal(exe_add.grad_dict['flow'], grad_est + flow_grad_npy, rtol=1e-3, atol=1e-5) @with_seed() @@ -4708,7 +4697,7 @@ def test_index2d(): data = mx.random.uniform(-1, 1, shape=(n, m), ctx=default_context()) x = mx.nd.array(np.random.randint(0, m, size=n), ctx=default_context(), dtype='int32') r = mx.nd.batch_take(data, x) - assert_almost_equal(r.asnumpy(), data.asnumpy()[np.arange(n), x.asnumpy()]) + assert_almost_equal(r, data.asnumpy()[np.arange(n), x.asnumpy()]) @with_seed() @@ -4724,8 +4713,8 @@ def test_cast(): exe.arg_arrays[0][:] = X exe.forward(is_train=True) exe.backward(mx.nd.array(X, dtype=dsttype, ctx=default_context())) - assert_almost_equal(exe.outputs[0].asnumpy(), X.astype(srctype).astype(dsttype), rtol=1e-3, atol=1e-5) - assert_almost_equal(exe.grad_arrays[0].asnumpy(), X.astype(dsttype).astype(srctype), rtol=1e-3, atol=1e-5) + assert_almost_equal(exe.outputs[0], X.astype(srctype).astype(dsttype), rtol=1e-3, atol=1e-5) + assert_almost_equal(exe.grad_arrays[0], X.astype(dsttype).astype(srctype), rtol=1e-3, atol=1e-5) def get_cast_op_data(): FP16_FRACTION_BITS = 10 @@ -4853,12 +4842,12 @@ def test_repeat_forward(): a = np.random.random_sample(size=shape) aa = np.repeat(a, repeats) b = mx.nd.array(a, ctx=default_context()) - bb = mx.nd.repeat(b, repeats).asnumpy() + bb = mx.nd.repeat(b, repeats) assert_almost_equal(aa, bb) for axis in range(0, ndim): aa = np.repeat(a, repeats, axis) - bb = mx.nd.repeat(b, repeats, axis).asnumpy() + bb = mx.nd.repeat(b, repeats, axis) assert_almost_equal(aa, bb) def test_repeat_backward(axis): @@ -4896,7 +4885,7 @@ def test_repeat_backward(axis): else: raise RuntimeError("Invalid axis value") - assert_almost_equal(expected_grad, arr_grad.asnumpy(), rtol=1e-3) + assert_almost_equal(expected_grad, arr_grad, rtol=1e-3) def test_repeat_numeric_gradient(): data = mx.sym.Variable('data') @@ -4993,7 +4982,7 @@ def test_tile_backward(): for j in range(shape[1]): expected_grad[i][j] += sum(sum(npout_grad[i:(n1 * reps1):reps1, j:(n2 * reps2):reps2])) - 
assert_almost_equal(expected_grad, arr_grad.asnumpy(), rtol=1e-3) + assert_almost_equal(expected_grad, arr_grad, rtol=1e-3) def test_tile_numeric_gradient(): data = mx.sym.Variable('data') @@ -5309,8 +5298,7 @@ def softmax_forward(input_data, true_output): exec1 = out1.bind(default_context(), args={'data': input_data}) exec1.forward()[0].wait_to_read() ndarr = exec1.outputs[0][0][0][0] - nparr = ndarr.asnumpy() - assert_almost_equal(nparr, true_output, rtol=1e-5, atol=1e-5) + assert_almost_equal(ndarr, true_output, rtol=1e-5, atol=1e-5) softmax_forward(mx.nd.array([[[[-1e30,-1e30]]]]), np.array([1.0,1.0])) softmax_forward(mx.nd.array([[[[1e30,1e30]]]]), np.array([1.0,1.0])) @@ -5332,14 +5320,10 @@ def check_dtypes_almost_equal(op_name, with mx.autograd.record(): dtype_softmax = op(dtype_input, axis=-1, dtype=odtype) ref_softmax = op(ref_input, axis=-1, dtype=odtype) - dtype_softmax_np = dtype_softmax.asnumpy() - ref_softmax_np = ref_softmax.asnumpy() - assert_almost_equal(dtype_softmax_np, ref_softmax_np, rtol=rtol, atol=atol) + assert_almost_equal(dtype_softmax, ref_softmax, rtol=rtol, atol=atol) dtype_softmax.backward() ref_softmax.backward() - dtype_grad_np = dtype_input.grad.asnumpy() - ref_grad_np = ref_input.grad.asnumpy() - assert_almost_equal(dtype_grad_np, ref_grad_np, rtol=grad_rtol, atol=grad_atol) + assert_almost_equal(dtype_input.grad, ref_input.grad, rtol=grad_rtol, atol=grad_atol) import sys is_windows = sys.platform.startswith('win') @@ -5439,48 +5423,28 @@ def test_pick_helper(index_type=np.int32): test_pick_helper(np.float32) -def check_ctc_loss(acts, labels, loss_truth): +def check_ctc_loss(acts, labels, loss_truth, contrib=False): in_var = mx.sym.Variable('input') labels_var = mx.sym.Variable('labels') - ctc = mx.sym.ctc_loss(in_var, labels_var) - acts_nd = mx.nd.array(acts, ctx=default_context()) - labels_nd = mx.nd.array(labels, ctx=default_context()) - exe = ctc.bind(ctx=default_context(), args=[acts_nd, labels_nd]) - # test forward with grad calc - exe.forward(is_train=True) - outTest = exe.outputs[0] - # test forward without grad calc - exe.forward(is_train=False) - outTrain = exe.outputs[0] - # make sure losses calculated with both modes are the same - assert_almost_equal(outTest.asnumpy(), outTrain.asnumpy()) - - # test against ground truth, if available - if loss_truth is not None: - assert_almost_equal(outTest.asnumpy(), loss_truth) - # test grad - check_numeric_gradient(ctc, [acts, labels], grad_nodes=['input'], rtol=0.05, atol=1e-3) - -# check contrib operator for backward compatibility -def check_contrib_ctc_loss(acts, labels, loss_truth): - in_var = mx.sym.Variable('input') - labels_var = mx.sym.Variable('labels') - ctc = mx.sym.contrib.ctc_loss(in_var, labels_var) + if contrib: + ctc = mx.sym.contrib.ctc_loss(in_var, labels_var) + else: + ctc = mx.sym.ctc_loss(in_var, labels_var) acts_nd = mx.nd.array(acts, ctx=default_context()) labels_nd = mx.nd.array(labels, ctx=default_context()) exe = ctc.bind(ctx=default_context(), args=[acts_nd, labels_nd]) # test forward with grad calc exe.forward(is_train=True) - outTest = exe.outputs[0] + outTest = exe.outputs[0].copy() # test forward without grad calc exe.forward(is_train=False) outTrain = exe.outputs[0] # make sure losses calculated with both modes are the same - assert_almost_equal(outTest.asnumpy(), outTrain.asnumpy()) + assert_almost_equal(outTest, outTrain) # test against ground truth, if available if loss_truth is not None: - assert_almost_equal(outTest.asnumpy(), loss_truth) + assert_almost_equal(outTest, 
loss_truth) # test grad check_numeric_gradient(ctc, [acts, labels], grad_nodes=['input'], rtol=0.05, atol=1e-3) @@ -5494,8 +5458,9 @@ def test_ctc_loss(): dtype=np.float32) labels = np.array([[2, 3, 0], [2, 3, 0]]) true_loss = np.array([4.04789, 4.04789], dtype=np.float32) # from Torch - check_ctc_loss(acts, labels, true_loss) - check_contrib_ctc_loss(acts, labels, true_loss) + for contrib in [False, True]: + check_ctc_loss(acts, labels, true_loss, contrib=contrib) + # Test 2: acts2 = np.array([ @@ -5504,14 +5469,14 @@ def test_ctc_loss(): [[-15, -14, -13, -12, -11], [-15, -14.2, -13.5, -12.2, -11.22]]], dtype=np.float32) labels2 = np.array([[2, 3, 1], [2, 0, 0]], dtype=np.float32) true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch - check_ctc_loss(acts2, labels2, true_loss) - check_contrib_ctc_loss(acts2, labels2, true_loss) + for contrib in [False, True]: + check_ctc_loss(acts2, labels2, true_loss, contrib=contrib) # Test 3: check use integer type as label labels3 = np.array([[2, 3, 1], [2, 0, 0]], dtype=np.int32) true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch - check_ctc_loss(acts2, labels3, true_loss) - check_contrib_ctc_loss(acts2, labels3, true_loss) + for contrib in [False, True]: + check_ctc_loss(acts2, labels3, true_loss, contrib=contrib) @with_seed() def test_ctc_loss_with_large_classes(): @@ -5531,90 +5496,11 @@ def test_ctc_loss_with_large_classes(): nd_label = mx.nd.array(label) loss = mx.nd.ctc_loss(data=nd_data, label=nd_label) expected_loss = np.array([688.02826, 145.34462]) - assert_almost_equal(loss.asnumpy(), expected_loss) + assert_almost_equal(loss, expected_loss) @with_seed() def test_ctc_loss_grad(): - def check_ctc_loss_grad(blank_label): # from tf - vocab_size = 5 - max_label_len = 5 - padding_mask = -1+ (blank_label=='first') - - targets_0 = [0, 1, 2, 1, 0] - loss_log_prob_0 = -3.34211 - input_prob_matrix_0 = np.asarray( - [[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553], - [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436], - [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688], - [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533], - [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]], - dtype=np.float32) - gradient_log_prob_0 = np.asarray( - [[-0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553], - [0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436], - [0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688], - [0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533], - [-0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]], - dtype=np.float32) - - targets_1 = [0, 1, 1, 0] - loss_log_prob_1 = -5.42262 - input_prob_matrix_1 = np.asarray( - [[0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508], - [0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549], - [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456], - [0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345], - [0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]], - dtype=np.float32) - gradient_log_prob_1 = np.asarray( - [[-0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508], - [0.24082, -0.602467, 0.0557226, 0.0546814, 0.0557528, 0.19549], - [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, -0.797544], - [0.280884, -0.570478, 0.0326593, 0.0339046, 0.0326856, 0.190345], - [-0.576714, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]], - dtype=np.float32) - - inputs = [ - 
np.vstack( - [input_prob_matrix_0[t, :], input_prob_matrix_1[t, :]]) - for t in range(5) - ] + 2 * [np.nan * np.ones((2, vocab_size+1), np.float32)] - inputs = np.log(np.asarray(inputs, dtype=np.float32)) - - grad_truth = np.array([ - np.vstack( - [gradient_log_prob_0[t, :], gradient_log_prob_1[t, :]]) - for t in range(5) - ] + 2 * [np.zeros((2, vocab_size+1), np.float32)]) - - if blank_label == 'first': - inputs = np.roll(inputs, 1, axis=2) - grad_truth = np.roll(grad_truth, 1, axis=2) - - labels = (np.asarray([x + [padding_mask]*(max_label_len-len(x)) - for x in [targets_0, targets_1]])+(blank_label == 'first')) - - seq_lens = np.array([5, 5], dtype=np.int32) - label_lens = np.array([5, 4], dtype=np.int32) - loss_truth = np.array([-loss_log_prob_0, -loss_log_prob_1], np.float32) - - with default_context(): - data = mx.nd.array(inputs) - label = mx.nd.array(labels) - data.attach_grad() - with mx.autograd.record(): - l = mx.ndarray.CTCLoss(data, label, - use_data_lengths=True, - use_label_lengths=True, - data_lengths=mx.nd.array(seq_lens), - label_lengths=mx.nd.array(label_lens), - blank_label=blank_label) - l.backward() - assert_almost_equal(l.asnumpy(), loss_truth, atol=1e-5, rtol=1e-5) - assert_almost_equal(data.grad.asnumpy(), grad_truth, atol=1e-5, rtol=1e-5) - - # check contrib operator for backward compatibility - def check_contrib_ctc_loss_grad(blank_label): # from tf + def check_ctc_loss_grad(blank_label, contrib=False): # from tf vocab_size = 5 max_label_len = 5 padding_mask = -1+ (blank_label=='first') @@ -5682,22 +5568,28 @@ def check_contrib_ctc_loss_grad(blank_label): # from tf label = mx.nd.array(labels) data.attach_grad() with mx.autograd.record(): - l = mx.contrib.ndarray.CTCLoss(data, label, - use_data_lengths=True, - use_label_lengths=True, - data_lengths=mx.nd.array(seq_lens), - label_lengths=mx.nd.array(label_lens), - blank_label=blank_label) + if contrib: + l = mx.contrib.ndarray.CTCLoss(data, label, + use_data_lengths=True, + use_label_lengths=True, + data_lengths=mx.nd.array(seq_lens), + label_lengths=mx.nd.array(label_lens), + blank_label=blank_label) + else: + l = mx.ndarray.CTCLoss(data, label, + use_data_lengths=True, + use_label_lengths=True, + data_lengths=mx.nd.array(seq_lens), + label_lengths=mx.nd.array(label_lens), + blank_label=blank_label) l.backward() - assert_almost_equal(l.asnumpy(), loss_truth, atol=1e-5, rtol=1e-5) - assert_almost_equal(data.grad.asnumpy(), grad_truth, atol=1e-5, rtol=1e-5) + assert_almost_equal(l, loss_truth, atol=1e-5, rtol=1e-5) + assert_almost_equal(data.grad, grad_truth, atol=1e-5, rtol=1e-5) - check_ctc_loss_grad('first') - check_ctc_loss_grad('last') - check_contrib_ctc_loss_grad('first') - check_contrib_ctc_loss_grad('last') - + for contrib in [False, True]: + for label in ['first', 'last']: + check_ctc_loss_grad(label, contrib=contrib) @with_seed() def test_quantization_op(): @@ -5936,8 +5828,8 @@ def create_operator(self, ctx, shapes, dtypes): expected_grad = 2 * x2 rtol = 1e-4 atol = 1e-6 - assert_almost_equal(output.asnumpy(), expected_output.asnumpy(), rtol=rtol, atol=atol) - assert_almost_equal(x2.grad.asnumpy(), expected_grad.asnumpy(), rtol=rtol, atol=atol) + assert_almost_equal(output, expected_output, rtol=rtol, atol=atol) + assert_almost_equal(x2.grad, expected_grad, rtol=rtol, atol=atol) # test for backward compatibility, i.e. 
the correctness of default implementation of @@ -5974,8 +5866,8 @@ def create_operator(self, ctx, shapes, dtypes): with mx.autograd.record(): y = mx.nd.Custom(lhs, rhs, name='mult', op_type='mult') y.backward() - assert_almost_equal(rhs.asnumpy(), lhs.grad.asnumpy(), rtol=rtol, atol=atol) - assert_almost_equal(lhs.asnumpy(), rhs.grad.asnumpy(), rtol=rtol, atol=atol) + assert_almost_equal(rhs, lhs.grad, rtol=rtol, atol=atol) + assert_almost_equal(lhs, rhs.grad, rtol=rtol, atol=atol) class MultNoGrad(mx.operator.CustomOp): def forward(self, is_train, req, in_data, out_data, aux): @@ -6008,8 +5900,8 @@ def infer_storage_type_backward(self, ograd_stype, in_stype, out_stype, igrad_st with mx.autograd.record(): y2 = mx.nd.Custom(lhs, rhs, name="mult_no_grad", op_type="mult_no_grad") y2.backward() - assert_almost_equal(rhs.asnumpy(), lhs.grad.asnumpy(), rtol=rtol, atol=atol) - assert_almost_equal(lhs.asnumpy(), rhs.grad.asnumpy(), rtol=rtol, atol=atol) + assert_almost_equal(rhs, lhs.grad, rtol=rtol, atol=atol) + assert_almost_equal(lhs, rhs.grad, rtol=rtol, atol=atol) class NoInputOp(mx.operator.CustomOp): def __init__(self, length, depth): @@ -6046,7 +5938,7 @@ def create_operator(self, ctx, shapes, dtypes): with mx.autograd.record(): x = mx.nd.Custom(length=10, depth=10, op_type="no_input_op") - assert_almost_equal(x.asnumpy(), np.ones(shape=(10, 10), dtype=np.float32)) + assert_almost_equal(x, np.ones(shape=(10, 10), dtype=np.float32)) @with_seed() @@ -6523,95 +6415,83 @@ def test_laop(): dtype = np.float64 rtol_fw = 1e-7 atol_fw = 1e-9 - num_eps = 1e-6 + num_eps = 2e-6 rtol_bw = 1e-5 - atol_bw = 1e-6 + atol_bw = 1e-5 # enable numerical checking of gradients grad_check = 1 data1 = mx.symbol.Variable('data1') data2 = mx.symbol.Variable('data2') - data3 = mx.symbol.Variable('data3') - check_fw = lambda sym, location, expected :\ - check_symbolic_forward(sym, location, expected, rtol=rtol_fw, - atol=atol_fw, dtype=dtype) - check_grad = lambda sym, location:\ - check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=rtol_bw, - atol=atol_bw, dtype=dtype) rep_3x = lambda a, m, n :\ np.reshape(np.tile(np.array(a).flatten(), 3), (3, 1, m, n)) + def check_fw_grad(sym, location, expected): + check_symbolic_forward(sym, location, expected, rtol=rtol_fw, + atol=atol_fw, dtype=dtype) + if grad_check == 1: + check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=rtol_bw, + atol=atol_bw, dtype=dtype) + + matrix = np.array([[9., 3., -6., 12.], + [3., 26., -7., -11.], + [-6., -7., 9., 7.], + [12., -11., 7., 65.]]) + trian = np.array([[3., 0., 0., 0.], + [1., 5., 0., 0.], + [-2., -1., 2., 0.], + [4., -3., 6., 2.]]) + pow = np.array([[2., 1., 1., 1.], + [1., 4., 1., 1.], + [1., 1., 8., 1.], + [1., 1., 1., 16.]]) + inv = np.array([[8.95/3., 0.05/3., 2.65, -2.5/3.], + [0.05/3., 0.05, 0.05, 0.], + [2.65, 0.05, 2.5, -0.75], + [-2.5/3., 0., -0.75, 0.25]]) + ident = np.eye(4) + shape = (4, 4, 1, 1) + ones = mx.nd.ones(shape).asnumpy() + for lower in [True, False]: upper = not lower # Tests with trivial 1x1 matrices. 
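The test_laop refactor above folds the separate forward and numeric-gradient checks into a single check_fw_grad helper. A self-contained sketch of the same idea, exercised on a simple sqrt symbol rather than the linalg operators (tolerances and epsilons here are illustrative, not the values used in the test):

import numpy as np
import mxnet as mx
from mxnet.test_utils import check_symbolic_forward, check_numeric_gradient

def check_fw_grad(sym, location, expected, rtol=1e-5, atol=1e-6, num_eps=1e-4):
    # forward result is compared against the expected numpy arrays ...
    check_symbolic_forward(sym, location, expected, rtol=rtol, atol=atol)
    # ... and the backward pass is verified by finite differences at the same call site
    check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=1e-2, atol=1e-3)

x = mx.sym.Variable('x')
data = np.random.uniform(1, 10, (3, 3))
check_fw_grad(mx.sym.sqrt(x), [data], [np.sqrt(data)])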
- shape = (4, 4, 1, 1) data_in = np.random.uniform(1, 10, shape) # test potrf # Note: Have to symmetrize input, for gradient test to work res_potrf = np.sqrt(data_in) test_potrf = mx.sym.linalg.potrf(data1, lower=lower) - check_fw(test_potrf, [data_in], [res_potrf]) - if grad_check == 1: - check_grad(test_potrf, [data_in]) + check_fw_grad(test_potrf, [data_in], [res_potrf]) # test potri - ones = mx.nd.ones(shape).asnumpy() res_potri = np.divide(ones, data_in * data_in) test_potri = mx.sym.linalg.potri(data1, lower=lower) - check_fw(test_potri, [data_in], [res_potri]) - if grad_check == 1: - check_grad(test_potri, [data_in]) + check_fw_grad(test_potri, [data_in], [res_potri]) # test trsm trian_in = data_in * 7. test_trsm = mx.sym.linalg.trsm(data1, data2, alpha=7., lower=lower) - check_fw(test_trsm, [trian_in, data_in], [ones]) - if grad_check == 1: - check_grad(test_trsm, [trian_in,data_in]) + check_fw_grad(test_trsm, [trian_in, data_in], [ones]) # test trmm trian_in = np.divide(ones, trian_in) test_trmm = mx.sym.linalg.trmm(data1, data2, alpha=7., transpose=True, rightside=True, lower=lower) - check_fw(test_trmm, [trian_in, data_in], [ones]) - if grad_check == 1: - check_grad(test_trmm, [trian_in, data_in]) + check_fw_grad(test_trmm, [trian_in, data_in], [ones]) # test sumlogdiag res_sumlogdiag = np.reshape(np.log(data_in), (4, 4)) test_sumlogdiag = mx.sym.linalg.sumlogdiag(data1) - check_fw(test_sumlogdiag, [data_in], [res_sumlogdiag]) - if grad_check == 1: - check_grad(test_sumlogdiag, [data_in]) + check_fw_grad(test_sumlogdiag, [data_in], [res_sumlogdiag]) # more elaborate example of Cholesky factorization - matrix = np.array([[9., 3., -6., 12.], - [3., 26., -7., -11.], - [-6., -7., 9., 7.], - [12., -11., 7., 65.]]) - trian = np.array([[3., 0., 0., 0.], - [1., 5., 0., 0.], - [-2., -1., 2., 0.], - [4., -3., 6., 2.]]) - pow = np.array([[2., 1., 1., 1.], - [1., 4., 1., 1.], - [1., 1., 8., 1.], - [1., 1., 1., 16.]]) - inv = np.array([[8.95/3., 0.05/3., 2.65, -2.5/3.], - [0.05/3., 0.05, 0.05, 0.], - [2.65, 0.05, 2.5, -0.75], - [-2.5/3., 0., -0.75, 0.25]]) - ident = np.eye(4) - low_trian = trian - if not lower: + if upper: trian = np.transpose(trian) # test potrf test_potrf = mx.sym.linalg.potrf(_make_symm_symbol(data1, ndims=4), lower=lower) a = rep_3x(matrix, 4, 4) r = rep_3x(trian, 4, 4) - check_fw(test_potrf, [a], [r]) - if grad_check == 1: - check_grad(test_potrf, [a]) + check_fw_grad(test_potrf, [a], [r]) #test potri data1_ltri = _make_triangle_symm( @@ -6619,77 +6499,54 @@ def test_laop(): test_potri = mx.sym.linalg.potri(data1_ltri, lower=lower) a = rep_3x(trian, 4, 4) r = rep_3x(inv, 4, 4) - check_fw(test_potri, [a], [r]) - if grad_check == 1: - check_grad(test_potri, [a]) + check_fw_grad(test_potri, [a], [r]) # test trsm test_trsm = mx.sym.linalg.trsm(data1_ltri, data2, alpha=7., transpose=upper, lower=lower) - a = rep_3x(trian, 4, 4) b = rep_3x(matrix, 4, 4) r = rep_3x(7. * np.transpose(low_trian), 4, 4) - check_fw(test_trsm, [a, b], [r]) - if grad_check == 1: - check_grad(test_trsm, [a, b]) + check_fw_grad(test_trsm, [a, b], [r]) test_trsm2 = mx.sym.linalg.trsm( data1_ltri, data2, alpha=-2., rightside=True, transpose=lower, lower=lower) r = rep_3x(-2. 
* low_trian, 4, 4) - check_fw(test_trsm2, [a, b], [r]) - if grad_check == 1: - check_grad(test_trsm2, [a, b]) + check_fw_grad(test_trsm2, [a, b], [r]) test_trsm3 = mx.sym.linalg.trsm( data1_ltri, data2, alpha=0.5, transpose=lower, lower=lower) b = rep_3x(np.transpose(low_trian), 4, 4) r = rep_3x(0.5 * ident, 4, 4) - check_fw(test_trsm3, [a, b], [r]) - if grad_check == 1: - check_grad(test_trsm3, [a, b]) + check_fw_grad(test_trsm3, [a, b], [r]) test_trsm4 = mx.sym.linalg.trsm( data1_ltri, data2, alpha=-0.5, rightside=True, transpose=upper, lower=lower) b = rep_3x(low_trian, 4, 4) r = rep_3x(-0.5 * ident, 4, 4) - check_fw(test_trsm4, [a, b], [r]) - if grad_check == 1: - check_grad(test_trsm4, [a, b]) + check_fw_grad(test_trsm4, [a, b], [r]) # test trmm test_trmm = mx.sym.linalg.trmm( data1_ltri, data2, alpha=7., transpose=True, rightside=True, lower=lower) - a = rep_3x(trian, 4, 4) - b = rep_3x(matrix, 4, 4) + a = [a, rep_3x(matrix, 4, 4)] r = rep_3x(7. * np.dot(matrix, trian.T), 4, 4) - check_fw(test_trmm, [a, b], [r]) - if grad_check == 1: - check_grad(test_trmm, [a, b]) + check_fw_grad(test_trmm, a, [r]) test_trmm2 = mx.sym.linalg.trmm(data1_ltri, data2, alpha=-2., lower=lower) r = rep_3x(-2. * np.dot(trian, matrix), 4, 4) - check_fw(test_trmm2, [a, b], [r]) - if grad_check == 1: - check_grad(test_trmm2, [a, b]) + check_fw_grad(test_trmm2, a, [r]) test_trmm3 = mx.sym.linalg.trmm(data1_ltri, data2, rightside=True, lower=lower) r = rep_3x(np.dot(matrix, trian), 4, 4) - check_fw(test_trmm3, [a, b], [r]) - if grad_check == 1: - check_grad(test_trmm3, [a, b]) + check_fw_grad(test_trmm3, a, [r]) test_trmm4 = mx.sym.linalg.trmm( data1_ltri, data2, alpha=1.2, transpose=True, lower=lower) r = rep_3x(1.2 * np.dot(trian.T, matrix), 4, 4) - check_fw(test_trmm4, [a, b], [r]) - if grad_check == 1: - check_grad(test_trmm4, [a, b]) + check_fw_grad(test_trmm4, a, [r]) - # test sumlogdiag - a = rep_3x(pow, 4, 4) - r = np.reshape(np.tile(10. * np.log(np.array([2.])), 3), (3,)) - check_fw(test_sumlogdiag, [a], [r]) - if grad_check == 1: - check_grad(test_sumlogdiag, [a]) + # test sumlogdiag + r = np.reshape(np.tile(10. 
* np.log(np.array([2.])), 3), (3,)) + check_fw_grad(test_sumlogdiag, [rep_3x(pow, 4, 4)], [r]) # Tests for operators linalg.syrk, linalg.gelqf @@ -7894,12 +7751,12 @@ def py_bilinear_resize_backward(x, incoming_grads, mode='size'): def check_bilinear_resize_op(shape, height, width): x = mx.nd.random.uniform(shape=shape) y = mx.nd.contrib.BilinearResize2D(x, height=height, width=width) - assert_almost_equal(y.asnumpy(), py_bilinear_resize(x.asnumpy(), height, width)) + assert_almost_equal(y, py_bilinear_resize(x.asnumpy(), height, width)) x_scale = width / shape[-1] y_scale = height / shape[-2] y = mx.nd.contrib.BilinearResize2D(x, scale_height=y_scale, scale_width=x_scale) - assert_almost_equal(y.asnumpy(), py_bilinear_resize(x.asnumpy(), height, width)) + assert_almost_equal(y, py_bilinear_resize(x.asnumpy(), height, width)) def check_bilinear_resize_modes_op(shape, scale_height=None, scale_width=None, shape_1=None, mode=None): x = mx.nd.random.uniform(shape=shape) original_h = shape[2] @@ -8115,6 +7972,7 @@ def f(x, a, b, c): data = mx.symbol.Variable('data') quad_sym = mx.sym.contrib.quadratic(data=data, a=a, b=b, c=c) for dtype in [np.float16, np.float32, np.float64]: + tol = 1e-2 if dtype is np.float16 else 1e-5 for ndim in range(1, 6): shape = rand_shape_nd(ndim, 5) data_np = np.random.randn(*shape).astype(dtype) @@ -8123,21 +7981,92 @@ def f(x, a, b, c): # check imperative forward output = mx.nd.contrib.quadratic(mx.nd.array(data_np), a=a, b=b, c=c) - assert_almost_equal(output.asnumpy(),expected, - rtol=1e-2 if dtype is np.float16 else 1e-5, - atol=1e-2 if dtype is np.float16 else 1e-5) + assert_almost_equal(output, expected, rtol=tol, atol=tol) # check forward - check_symbolic_forward(quad_sym, [data_np], [expected], - rtol=1e-2 if dtype is np.float16 else 1e-5, - atol=1e-2 if dtype is np.float16 else 1e-5) + check_symbolic_forward(quad_sym, [data_np], [expected], rtol=tol, atol=tol) # check backward check_symbolic_backward(quad_sym, [data_np], [np.ones(expected.shape)], - [backward_expected], - rtol=1e-2 if dtype is np.float16 else 1e-5, - atol=1e-2 if dtype is np.float16 else 1e-5) + [backward_expected], rtol=tol, atol=tol) # check backward using finite difference check_numeric_gradient(quad_sym, [data_np], atol=0.001) +def allclose_function(contexts): + def getRandom(base, percent = 1.): + return base * (1 + percent * (2 * np.random.random_sample() - 1.) 
/ 100) + + title = 'exp' + for ctx in contexts: + title += ' cpu' if ctx == mx.cpu() else ' gpu' + + title += ' nElem shape' + num_ctx = len(contexts) + result = [False, False] + for dtype in [np.float16, np.float32, np.float64]: + rtol = getRandom(1e-2 if dtype is np.float16 else 1e-5) + atol = getRandom(1e-4 if dtype is np.float16 else 1e-7) + print('\nnumpy.{}: atol = {} rtol = {}'.format(dtype.__name__, atol, rtol)) + print(title) + for ndim in range(1, 10): + shape = rand_shape_nd(ndim, 8) + a_np = np.random.randn(*shape).astype(dtype) + b_np = (a_np + np.random.randn(*shape).astype(dtype) / 10000000).astype(dtype) + expected = np.allclose(a_np, b_np, rtol, atol) + + for n, ctx in enumerate(contexts): + a_ctx = mx.nd.array(a_np, dtype = dtype, ctx=ctx) + b_ctx = mx.nd.array(b_np, dtype = dtype, ctx=ctx) + output = mx.nd.contrib.allclose(a_ctx, b_ctx, rtol=rtol, atol=atol) + result[n] = output.asnumpy() == 1 + if expected != result[n]: + # Preparing the output of elements of the array, which are considered as "not close" AND + # corresponding elements of comparison CPU/GPU/Python vectors, which are considered as "close" + v_ctx = 'CPU' if ctx == mx.cpu() else 'GPU' + if expected: + v_cmp = 'Python' + a_b = a_ctx.asnumpy() + b_b = b_ctx.asnumpy() + a_g = np.asarray(a_np) + b_g = np.asarray(b_np) + + else: + v_cmp = v_ctx + v_ctx = 'Python' + a_b = np.asarray(a_np) + b_b = np.asarray(b_np) + a_g = a_ctx.asnumpy() + b_g = b_ctx.asnumpy() + + print('\n *** Violations found on %s, but not on %s side ***' % (v_ctx, v_cmp)) + frmt = " a[{0:d}]: b[{0:d}]:" \ + " abs(a[{0:d}]-b[{0:d}]) - atol + rtol*abs(b[{0:d}]):" + + # Define the indices of all violations and corresponding values of coordinates + bad_indexes = np.abs(a_b - b_b) >= atol + rtol * abs(b_b) + a_values = [a_b[bad_indexes], a_g[bad_indexes]] + b_values = [b_b[bad_indexes], b_g[bad_indexes]] + idx = np.asarray(np.where(bad_indexes == True)) + idx = idx.reshape(1, idx.size) + idx_flat = np.asarray(np.where(bad_indexes.flatten() == True)).flatten() + for i in range(len(a_values[0])): + flat_idx = idx_flat[i] + print('{}: index = {} flat_index = {}'.format('%4d'%i, idx[i], flat_idx)) + print(frmt.format(flat_idx)) + for j in range(2): + diff = np.abs(a_values[j][i]-b_values[j][i]) - atol + rtol*abs(b_values[j][i]) + print('{}: {} {} {}'.format('%6s'%v_ctx, a_values[j][i], b_values[j][i], diff)) + + + if num_ctx == 1: + print(' {0:d} {1:d} {2:10d} {3:}'.format(expected, result[0], np.prod(shape), shape)) + else: + print(' {0:d} {1:d} {2:d} {3:10d} {4:}'.format(expected, result[0], result[1], np.prod(shape), shape)) + + if expected != result[0] or num_ctx > 1 and expected != result[1]: + assert False + +@with_seed() +def test_allclose_function(): + allclose_function([default_context()]) @with_seed() def test_histogram(): @@ -8153,12 +8082,12 @@ def f(x, bins=10, range=None): bin_range = (-2.5, 2.5) mx_histo1, mx_bins1 = mx.nd.histogram(x, bins=bin_cnt, range=bin_range) np_histo1, np_bins1 = f(x.asnumpy(), bins=bin_cnt, range=bin_range) - assert_almost_equal(mx_bins1.asnumpy(), np_bins1) - assert_almost_equal(mx_histo1.asnumpy(), np_histo1, rtol=1e-3, atol=1e-5) + assert_almost_equal(mx_bins1, np_bins1) + assert_almost_equal(mx_histo1, np_histo1, rtol=1e-3, atol=1e-5) mx_histo2, mx_bins2 = mx.nd.histogram(x, bins=mx_bins) np_histo2, np_bins2 = f(x.asnumpy(), bins=np_bins) - assert_almost_equal(mx_histo2.asnumpy(), np_histo2, rtol=1e-3, atol=1e-5) - assert_almost_equal(mx_bins2.asnumpy(), np_bins2, rtol=1e-3, atol=1e-5) + 
assert_almost_equal(mx_histo2, np_histo2, rtol=1e-3, atol=1e-5) + assert_almost_equal(mx_bins2, np_bins2, rtol=1e-3, atol=1e-5) data = mx.sym.Variable("data") @@ -8520,9 +8449,9 @@ def test_roi_align_value(sampling_ratio=0, position_sensitive=False): spatial_scale, sampling_ratio, position_sensitive, dy.asnumpy()) - assert_almost_equal(output.asnumpy(), real_output, atol=1e-3) - assert_almost_equal(data.grad.asnumpy(), dx, atol=1e-3) - assert_almost_equal(rois.grad.asnumpy(), drois, atol=1e-3) + assert_almost_equal(output, real_output, atol=1e-3) + assert_almost_equal(data.grad, dx, atol=1e-3) + assert_almost_equal(rois.grad, drois, atol=1e-3) # modified from test_roipooling() def test_roi_align_autograd(sampling_ratio=0): @@ -8752,24 +8681,8 @@ def test_diag(): a_np = np.random.random((h, w)).astype(np.float32) a = mx.nd.array(a_np).astype('float32') - # k == 0 - r = mx.nd.diag(a) - assert_almost_equal(r.asnumpy(), np.diag(a_np)) - - # k == 1 - k = 1 - r = mx.nd.diag(a, k=k) - assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k)) - - # k == -1 - k = -1 - r = mx.nd.diag(a, k=k) - assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k)) - - # random k - k = np.random.randint(-min(h,w) + 1, min(h,w)) - r = mx.nd.diag(a, k=k) - assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k)) + for k in [0, 1, -1, np.random.randint(-min(h,w) + 1, min(h,w))]: + assert_almost_equal(mx.nd.diag(a, k=k), np.diag(a_np, k=k)) # invalid k k = max(h,w) + 1 @@ -8797,9 +8710,7 @@ def test_diag(): # k is random k = np.random.randint(-d,d) - r = mx.nd.diag(a, k=k) - - assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k)) + assert_almost_equal(mx.nd.diag(a, k=k), np.diag(a_np, k=k)) # Test 2d backward, k=0 data = mx.sym.Variable('data') @@ -8826,19 +8737,19 @@ def test_diag(): # k = 0, axis1=0, axis2=1 r = mx.nd.diag(data=a, k=0, axis1=0, axis2=1) - assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=0, axis1=0, axis2=1)) + assert_almost_equal(r, np.diagonal(a_np, offset=0, axis1=0, axis2=1)) # k = 1, axis1=1, axis2=0 r = mx.nd.diag(data=a, k=1, axis1=1, axis2=0) - assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=1, axis1=1, axis2=0)) + assert_almost_equal(r, np.diagonal(a_np, offset=1, axis1=1, axis2=0)) # k = -1 axis1=1, axis3=3 r = mx.nd.diag(data=a, k=-1, axis1=1, axis2=3) - assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=-1, axis1=1, axis2=3)) + assert_almost_equal(r, np.diagonal(a_np, offset=-1, axis1=1, axis2=3)) # k = 2, axis1=-2, axis2=0 r = mx.nd.diag(data=a, k=2, axis1=-2, axis2=0) - assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=2, axis1=-2, axis2=0)) + assert_almost_equal(r, np.diagonal(a_np, offset=2, axis1=-2, axis2=0)) # Test 4d backward, k=0, axis1=3, axis2=0 data = mx.sym.Variable('data') @@ -8880,7 +8791,7 @@ def f(x, blocksize): data_np = data.asnumpy() expected = f(data_np, block) output = mx.nd.depth_to_space(data, block) - assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3) + assert_almost_equal(output, expected, atol=1e-3, rtol=1e-3) shape_out = (n, c // (block ** 2), h * block, w * block) data = mx.sym.Variable('data') @@ -8931,7 +8842,7 @@ def f(x, blocksize): data_np = data.asnumpy() expected = f(data_np, block) output = mx.nd.space_to_depth(data, block) - assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3) + assert_almost_equal(output, expected, atol=1e-3, rtol=1e-3) shape_out = (n, c * (block ** 2), h // block, w // block) data = mx.sym.Variable('data') diff --git a/tests/python/unittest/test_random.py 
b/tests/python/unittest/test_random.py index fe276685bfe3..efcf16dc78da 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -19,7 +19,7 @@ import math import itertools import mxnet as mx -from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, retry +from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, retry, assert_almost_equal import numpy as np import random as rnd from common import setup_module, with_seed, random_seed, teardown @@ -583,14 +583,14 @@ def test_sample_multinomial(): prob = prob.reshape((1, prob.shape[0])) for i in range(x.shape[0]): freq = np.bincount(y[i,:].astype('int32'), minlength=5)/np.float32(samples)*x[i,:].sum() - mx.test_utils.assert_almost_equal(freq, x[i], rtol=0.20, atol=1e-1) + assert_almost_equal(freq, x[i], rtol=0.20, atol=1e-1) rprob = x[i][y[i].astype('int32')]/x[i].sum() - mx.test_utils.assert_almost_equal(np.log(rprob), prob.asnumpy()[i], atol=1e-5) + assert_almost_equal(np.log(rprob), prob.asnumpy()[i], atol=1e-5) real_dx = np.zeros((5,)) for j in range(samples): real_dx[int(y[i][j])] += 5.0 / rprob[j] - mx.test_utils.assert_almost_equal(real_dx, dx[i, :], rtol=1e-4, atol=1e-5) + assert_almost_equal(real_dx, dx[i, :], rtol=1e-4, atol=1e-5) for dtype in ['uint8', 'float16', 'float32']: # Bound check for the output data types. 'int32' and 'float64' require large memory so are skipped. x = mx.nd.zeros(2 ** 25) # Larger than the max integer in float32 without precision loss. @@ -883,8 +883,8 @@ def compute_expected_prob(): # test ndarray true_classes = mx.nd.random.uniform(0, range_max, shape=(num_true,)).astype('int32') sampled_classes, exp_cnt_true, exp_cnt_sampled = mx.nd.contrib.rand_zipfian(true_classes, num_sampled, range_max) - mx.test_utils.assert_almost_equal(exp_cnt_sampled.asnumpy(), exp_cnt[sampled_classes].asnumpy(), rtol=1e-1, atol=1e-2) - mx.test_utils.assert_almost_equal(exp_cnt_true.asnumpy(), exp_cnt[true_classes].asnumpy(), rtol=1e-1, atol=1e-2) + assert_almost_equal(exp_cnt_sampled, exp_cnt[sampled_classes], rtol=1e-1, atol=1e-2) + assert_almost_equal(exp_cnt_true, exp_cnt[true_classes], rtol=1e-1, atol=1e-2) # test symbol true_classes_var = mx.sym.var('true_classes') @@ -893,8 +893,8 @@ def compute_expected_prob(): executor = outputs.bind(mx.context.current_context(), {'true_classes' : true_classes}) executor.forward() sampled_classes, exp_cnt_true, exp_cnt_sampled = executor.outputs - mx.test_utils.assert_almost_equal(exp_cnt_sampled.asnumpy(), exp_cnt[sampled_classes].asnumpy(), rtol=1e-1, atol=1e-2) - mx.test_utils.assert_almost_equal(exp_cnt_true.asnumpy(), exp_cnt[true_classes].asnumpy(), rtol=1e-1, atol=1e-2) + assert_almost_equal(exp_cnt_sampled, exp_cnt[sampled_classes], rtol=1e-1, atol=1e-2) + assert_almost_equal(exp_cnt_true, exp_cnt[true_classes], rtol=1e-1, atol=1e-2) # Issue #10277 (https://github.com/apache/incubator-mxnet/issues/10277) discusses this test. 
 @with_seed()
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 91194c562a57..4c4e3dbdfc51 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -732,11 +732,9 @@ def check_sparse_mathematical_core(name, stype,
 
         assert arr_grad.stype == expected_grad_result_type
 
-        arr_grad = arr_grad.asnumpy()
-
         if verbose is True:
             print(name)
-            print("arr_grad", arr_grad)
+            print("arr_grad", arr_grad.asnumpy())
             print("input_grad", input_grad)
 
         assert_almost_equal(arr_grad, input_grad, equal_nan=True)
diff --git a/tests/python/unittest/test_subgraph.py b/tests/python/unittest/test_subgraph.py
index 4c13f9c70dfc..3da125a946bc 100644
--- a/tests/python/unittest/test_subgraph.py
+++ b/tests/python/unittest/test_subgraph.py
@@ -20,15 +20,8 @@
 import numpy as np
 import mxnet as mx
 import copy
-import math
-import ctypes
-import random
-import itertools
-from numpy.testing import assert_allclose, assert_array_equal
 from mxnet.test_utils import *
-from mxnet.base import py_str, MXNetError, _as_list, SymbolHandle, check_call, _LIB, c_handle_array, mx_uint
 from common import setup_module, with_seed, teardown
-import unittest
 from mxnet.gluon.model_zoo.vision import get_model
 
 def make_subgraph(subg, *args):
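The allclose_function test added above drives the mx.nd.contrib.allclose operator and cross-checks it against numpy.allclose. A minimal stand-alone comparison, independent of the test harness (tolerances here are illustrative):

import numpy as np
import mxnet as mx

a = np.random.randn(3, 4).astype(np.float32)
b = (a + np.random.randn(3, 4).astype(np.float32) / 1e7).astype(np.float32)
rtol, atol = 1e-5, 1e-7
# The operator returns a one-element NDArray holding 1 when every element is close, else 0.
out = mx.nd.contrib.allclose(mx.nd.array(a), mx.nd.array(b), rtol=rtol, atol=atol)
assert bool(out.asnumpy() == 1) == np.allclose(a, b, rtol=rtol, atol=atol)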