From d06be8c683600bea9d695af61c41ae7639c3705a Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Fri, 2 Oct 2020 15:48:21 -0700 Subject: [PATCH] Backport PRs in v1.7.x missing from v1.x to v1.8.x (#19262) * * Fix einsum gradient (#18482) * [v1.7.x] Backport PRs of numpy features (#18653) * add zero grad for npi_unique (#18080) * fix np.clip scalar input case (#17788) * fix true_divide (#18393) Co-authored-by: Hao Jin Co-authored-by: Xi Wang * [v1.7.x] backport mixed type binary ops to v1.7.x (#18649) * Fix Windows GPU CI (#17962) Update Windows CI to use VS 2019 and enable x64 bit toolchain. Previously we are using an older 32 bit toolchain causing OOM errors during linking. Switching to x64 bit toolchain on the older VS version previously used by the CI was attempted in #17912 and did not work. Update to Cuda 10.2 as it is required by VS 2019. Switch to ninja-build on Windows to speed up build as ninja-build is now preinstalled. Remove logic to install cmake 3.16 on every PR as cmake 3.17 is now preinstalled. Add build retrials due to cuda thrust + VS2019 flakyness. Co-authored-by: vexilligera * backport mixed type Co-authored-by: Leonard Lausen Co-authored-by: vexilligera * revise activations (#18700) * [v1.6] Fix the monitor_callback invalid issue during calibration with variable input shapes (#18632) (#18703) * Fix the monitor_callback invalid issue during calibration with variable input shapes * retrigger CI * Add UT for monitor check and disable codecov Co-authored-by: Tao Lv * Fail build_windows.py if all retries failed (#18177) * Update to thrust 1.9.8 on Windows (#18218) * Update to thrust 1.9.8 on Windows * Remove debug logic * Re-enable build retries on MSVC (#18230) Updating thrust alone did not help. Similar issues (though less often) still occur with updated thrust, and also with nvidia cub. Tracked upstream at https://github.com/thrust/thrust/issues/1090 Co-authored-by: Ke Han <38852697+hanke580@users.noreply.github.com> Co-authored-by: Xingjian Shi Co-authored-by: Hao Jin Co-authored-by: Xi Wang Co-authored-by: Yijun Chen Co-authored-by: vexilligera Co-authored-by: ciyong Co-authored-by: Tao Lv --- .codecov.yml | 3 + .gitignore | 4 + CMakeLists.txt | 2 +- ci/build_windows.py | 93 ++++++++---- include/mxnet/imperative.h | 11 ++ python/mxnet/gluon/nn/activations.py | 18 ++- python/mxnet/numpy/multiarray.py | 5 + python/mxnet/symbol/numpy/_symbol.py | 8 +- src/common/utils.h | 19 +++ .../contrib/gradient_multiplier_op.cc | 6 +- src/operator/mshadow_op.h | 83 ++++++++++- src/operator/mxnet_op.h | 9 +- src/operator/numpy/np_einsum_op-inl.h | 4 +- .../numpy/np_elemwise_broadcast_logic_op.cc | 18 +-- .../numpy/np_elemwise_broadcast_op.cc | 136 +++++++++--------- .../numpy/np_elemwise_broadcast_op.cu | 41 +++--- src/operator/numpy/np_elemwise_broadcast_op.h | 112 ++++++--------- .../np_elemwise_broadcast_op_extended.cc | 68 ++++----- src/operator/numpy/np_matrix_op-inl.h | 4 +- src/operator/numpy/np_true_divide-inl.h | 127 ++-------------- src/operator/numpy/np_true_divide.cc | 40 +++--- src/operator/numpy/np_true_divide.cu | 4 + src/operator/numpy/np_unique_op.cc | 1 + src/operator/operator_tune.cc | 2 + .../tensor/elemwise_binary_broadcast_op.h | 4 +- .../tensor/elemwise_binary_scalar_op.h | 123 ++++++++++++---- .../tensor/elemwise_binary_scalar_op_basic.cc | 49 ++++--- .../elemwise_binary_scalar_op_extended.cc | 36 ++--- .../tensor/elemwise_binary_scalar_op_logic.cc | 3 +- tests/python/unittest/test_numpy_gluon.py | 50 +++++++ tests/python/unittest/test_numpy_op.py | 135 +++++++++++------ tests/python/unittest/test_symbol.py | 2 - 32 files changed, 700 insertions(+), 520 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index 97624c21b8fa..70037e6483ff 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -4,6 +4,9 @@ codecov: require_ci_to_pass: yes coverage: + status: + project: off + patch: off precision: 2 round: down range: "70...100" diff --git a/.gitignore b/.gitignore index c50d1ec99b9f..9fafdb13cb7e 100644 --- a/.gitignore +++ b/.gitignore @@ -121,6 +121,10 @@ cmake_install.cmake # Mac OS X .DS_Store +# Windows +windows_package.7z +windows_package + #Notebook Automated Test !tests/nightly/test_tutorial_config.txt !tests/nightly/TestNotebook diff --git a/CMakeLists.txt b/CMakeLists.txt index 8955551dfeb2..0fa8c9c51af5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -169,7 +169,7 @@ if(MSVC) add_definitions(-DDMLC_STRICT_CXX11) add_definitions(-DNOMINMAX) set(CMAKE_C_FLAGS "/MP") - set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} /bigobj") + set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} ${CMAKE_CXX_FLAGS} /bigobj") else() include(CheckCXXCompilerFlag) if(USE_CXX14_IF_AVAILABLE) diff --git a/ci/build_windows.py b/ci/build_windows.py index 2590d211c671..c8d3af515b5a 100755 --- a/ci/build_windows.py +++ b/ci/build_windows.py @@ -31,15 +31,18 @@ import tempfile import time import zipfile +import requests from distutils.dir_util import copy_tree from enum import Enum -from subprocess import check_call +from subprocess import check_call, call from util import * KNOWN_VCVARS = { + # https://gitlab.kitware.com/cmake/cmake/issues/18920 'VS 2015': r'C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\bin\x86_amd64\vcvarsx86_amd64.bat', - 'VS 2017': r'C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvarsx86_amd64.bat' + 'VS 2017': r'C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvarsx86_amd64.bat', + 'VS 2019': r'C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvars64.bat', } @@ -54,6 +57,8 @@ class BuildFlavour(Enum): CMAKE_FLAGS = { 'WIN_CPU': ( + '-DCMAKE_C_COMPILER=cl ' + '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=OFF ' '-DUSE_CUDNN=OFF ' '-DENABLE_CUDA_RTC=OFF ' @@ -67,6 +72,8 @@ class BuildFlavour(Enum): '-DCMAKE_BUILD_TYPE=Release') , 'WIN_CPU_MKLDNN': ( + '-DCMAKE_C_COMPILER=cl ' + '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=OFF ' '-DUSE_CUDNN=OFF ' '-DENABLE_CUDA_RTC=OFF ' @@ -80,6 +87,8 @@ class BuildFlavour(Enum): '-DCMAKE_BUILD_TYPE=Release') , 'WIN_CPU_MKLDNN_MKL': ( + '-DCMAKE_C_COMPILER=cl ' + '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=OFF ' '-DUSE_CUDNN=OFF ' '-DENABLE_CUDA_RTC=OFF ' @@ -93,6 +102,8 @@ class BuildFlavour(Enum): '-DCMAKE_BUILD_TYPE=Release') , 'WIN_CPU_MKL': ( + '-DCMAKE_C_COMPILER=cl ' + '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=OFF ' '-DUSE_CUDNN=OFF ' '-DENABLE_CUDA_RTC=OFF ' @@ -106,6 +117,8 @@ class BuildFlavour(Enum): '-DCMAKE_BUILD_TYPE=Release') , 'WIN_GPU': ( + '-DCMAKE_C_COMPILER=cl ' + '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=ON ' '-DUSE_CUDNN=ON ' '-DENABLE_CUDA_RTC=ON ' @@ -115,11 +128,12 @@ class BuildFlavour(Enum): '-DUSE_LAPACK=ON ' '-DUSE_DIST_KVSTORE=OFF ' '-DMXNET_CUDA_ARCH="5.2" ' - '-DCMAKE_CXX_FLAGS="/FS /MD /O2 /Ob2" ' '-DUSE_MKL_IF_AVAILABLE=OFF ' '-DCMAKE_BUILD_TYPE=Release') , 'WIN_GPU_MKLDNN': ( + '-DCMAKE_C_COMPILER=cl ' + '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=ON ' '-DUSE_CUDNN=ON ' '-DENABLE_CUDA_RTC=ON ' @@ -130,7 +144,6 @@ class BuildFlavour(Enum): '-DUSE_DIST_KVSTORE=OFF ' '-DMXNET_CUDA_ARCH="5.2" ' '-DUSE_MKLDNN=ON ' - '-DCMAKE_CXX_FLAGS="/FS /MD /O2 /Ob2" ' '-DCMAKE_BUILD_TYPE=Release') } @@ -140,39 +153,65 @@ def windows_build(args): logging.info("Using vcvars environment:\n{}".format(args.vcvars)) path = args.output - os.makedirs(path, exist_ok=True) mxnet_root = get_mxnet_root() logging.info("Found MXNet root: {}".format(mxnet_root)) - url = 'https://github.com/Kitware/CMake/releases/download/v3.16.1/cmake-3.16.1-win64-x64.zip' - with tempfile.TemporaryDirectory() as tmpdir: - cmake_file_path = download_file(url, tmpdir) - with zipfile.ZipFile(cmake_file_path, 'r') as zip_ref: - # Create $tmpdir\cmake-3.16.1-win64-x64\bin\cmake.exe - zip_ref.extractall(tmpdir) + if 'GPU' in args.flavour: + # Get Thrust version to be shipped in Cuda 11, due to flakyness of + # older Thrust versions with MSVC 19 compiler + with remember_cwd(): + tmpdirname = tempfile.mkdtemp() + os.chdir(tmpdirname) + r = requests.get('https://github.com/thrust/thrust/archive/1.9.8.zip', allow_redirects=True) + with open('thrust.zip', 'wb') as f: + f.write(r.content) + with zipfile.ZipFile('thrust.zip', 'r') as zip_ref: + zip_ref.extractall('.') + thrust_path = os.path.join(tmpdirname, "thrust-1.9.8") + + + # cuda thrust / CUB + VS 2019 is flaky: try multiple times if fail + MAXIMUM_TRY = 5 + build_try = 0 + + while build_try < MAXIMUM_TRY: + if os.path.exists(path): + shutil.rmtree(path) + os.makedirs(path, exist_ok=True) with remember_cwd(): os.chdir(path) - cmd = "\"{}\" && {} -G \"NMake Makefiles JOM\" {} {}".format( - args.vcvars, - os.path.join(tmpdir, 'cmake-3.16.1-win64-x64', 'bin', 'cmake.exe'), - CMAKE_FLAGS[args.flavour], mxnet_root) + env = os.environ.copy() + if 'GPU' in args.flavour: + env["CXXFLAGS"] = '/FS /MD /O2 /Ob2 /I {}'.format(thrust_path) + env["CUDAFLAGS"] = '-I {}'.format(thrust_path) + cmd = "\"{}\" && cmake -GNinja {} {}".format(args.vcvars, + CMAKE_FLAGS[args.flavour], + mxnet_root) logging.info("Generating project with CMake:\n{}".format(cmd)) - check_call(cmd, shell=True) + check_call(cmd, shell=True, env=env) - cmd = "\"{}\" && jom".format(args.vcvars) - logging.info("Building with jom:\n{}".format(cmd)) + cmd = "\"{}\" && ninja".format(args.vcvars) + logging.info("Building:\n{}".format(cmd)) t0 = int(time.time()) - check_call(cmd, shell=True) + ret = call(cmd, shell=True) + - logging.info( - "Build flavour: {} complete in directory: \"{}\"".format( - args.flavour, os.path.abspath(path))) - logging.info("Build took {}".format( - datetime.timedelta(seconds=int(time.time() - t0)))) - windows_package(args) + if ret != 0: + build_try += 1 + logging.info("{} build(s) have failed".format(build_try)) + else: + logging.info("Build flavour: {} complete in directory: \"{}\"".format(args.flavour, os.path.abspath(path))) + logging.info("Build took {}".format(datetime.timedelta(seconds=int(time.time() - t0)))) + break + + if ret == 0: + windows_package(args) + else: + logging.info("Build failed") + sys.exit(1) def windows_package(args): @@ -233,7 +272,7 @@ def main(): parser.add_argument("--vcvars", help="vcvars batch file location, typically inside vs studio install dir", - default=KNOWN_VCVARS['VS 2015'], + default=KNOWN_VCVARS['VS 2019'], type=str) parser.add_argument("--arch", @@ -258,7 +297,7 @@ def main(): if 'OpenCV_DIR' not in os.environ: os.environ["OpenCV_DIR"] = "C:\\Program Files\\OpenCV-v3.4.1\\build" if 'CUDA_PATH' not in os.environ: - os.environ["CUDA_PATH"] = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v9.2" + os.environ["CUDA_PATH"] = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.2" if 'MKL_ROOT' not in os.environ: os.environ["MKL_ROOT"] = "C:\\Program Files (x86)\\IntelSWTools\\compilers_and_libraries\\windows\\mkl" windows_build(args) diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h index 6a367b3ccef5..783ad26803ed 100644 --- a/include/mxnet/imperative.h +++ b/include/mxnet/imperative.h @@ -117,6 +117,16 @@ class Imperative { } return is_np_shape_thread_local_ ? 1 : 0; } + + /*! \brief return current numpy default dtype compatibility status. + * */ + bool is_np_default_dtype() const { + if (is_np_default_dtype_global_) { + return true; + } + return false; + } + /*! \brief specify numpy compatibility off, thread local on or global on. */ bool set_is_np_shape(int is_np_shape) { NumpyShape flag = static_cast(is_np_shape); @@ -215,6 +225,7 @@ class Imperative { static MX_THREAD_LOCAL bool is_np_shape_thread_local_; #endif bool is_np_shape_global_{false}; + bool is_np_default_dtype_global_{false}; /*! \brief node count used for naming */ std::atomic node_count_{0}; /*! \brief variable count used for naming */ diff --git a/python/mxnet/gluon/nn/activations.py b/python/mxnet/gluon/nn/activations.py index 1b9ce91dd2aa..3cccc851e39b 100644 --- a/python/mxnet/gluon/nn/activations.py +++ b/python/mxnet/gluon/nn/activations.py @@ -139,7 +139,8 @@ def __init__(self, alpha_initializer=initializer.Constant(0.25), init=alpha_initializer) def hybrid_forward(self, F, x, alpha): - return F.LeakyReLU(x, gamma=alpha, act_type='prelu', name='fwd') + leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU + return leaky_relu(x, gamma=alpha, act_type='prelu', name='fwd') class ELU(HybridBlock): @@ -167,7 +168,8 @@ def __init__(self, alpha=1.0, **kwargs): self._alpha = alpha def hybrid_forward(self, F, x): - return F.LeakyReLU(x, act_type='elu', slope=self._alpha) + leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU + return leaky_relu(x, act_type='elu', slope=self._alpha) class SELU(HybridBlock): @@ -187,7 +189,9 @@ def __init__(self, **kwargs): super(SELU, self).__init__(**kwargs) def hybrid_forward(self, F, x): - return F.LeakyReLU(x, act_type='selu', name='fwd') + leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU + return leaky_relu(x, act_type='selu', name='fwd') + class GELU(HybridBlock): r""" @@ -206,7 +210,8 @@ def __init__(self, **kwargs): super(GELU, self).__init__(**kwargs) def hybrid_forward(self, F, x): - return F.LeakyReLU(x, act_type='gelu', name='fwd') + leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU + return leaky_relu(x, act_type='gelu', name='fwd') class Swish(HybridBlock): @@ -232,4 +237,7 @@ def __init__(self, beta=1.0, **kwargs): self._beta = beta def hybrid_forward(self, F, x): - return x * F.sigmoid(self._beta * x, name='fwd') + if is_np_array(): + return x * F.npx.sigmoid(self._beta * x) + else: + return x * F.sigmoid(self._beta * x, name='fwd') diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index fceaaf3a282f..9a803d48b5b3 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -6174,6 +6174,11 @@ def clip(a, a_min, a_max, out=None): >>> np.clip(a, 3, 6, out=a) array([3., 3., 3., 3., 4., 5., 6., 6., 6., 6.], dtype=float32) """ + from numbers import Number + if isinstance(a, Number): + # In case input is a scalar, the computation would fall back to native numpy. + # The value returned would be a python scalar. + return _np.clip(a, a_min, a_max, out=None) return _mx_nd_np.clip(a, a_min, a_max, out=out) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index da8d5d738f46..5f81a73a4b75 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -1529,13 +1529,15 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou if isinstance(rhs, numeric_types): return fn_scalar(lhs, rhs, out=out) else: + is_int = isinstance(rhs, integer_types) if rfn_scalar is None: # commutative function - return lfn_scalar(rhs, float(lhs), out=out) + return lfn_scalar(rhs, scalar=float(lhs), is_int=is_int, out=out) else: - return rfn_scalar(rhs, float(lhs), out=out) + return rfn_scalar(rhs, scalar=float(lhs), is_int=is_int, out=out) elif isinstance(rhs, numeric_types): - return lfn_scalar(lhs, float(rhs), out=out) + is_int = isinstance(rhs, integer_types) + return lfn_scalar(lhs, scalar=float(rhs), is_int=is_int, out=out) elif isinstance(rhs, Symbol): return fn_array(lhs, rhs, out=out) else: diff --git a/src/common/utils.h b/src/common/utils.h index 44d4fc3e8772..227830708ff6 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -874,6 +875,11 @@ inline bool is_float(const int dtype) { return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16; } +inline bool is_int(const int dtype) { + return dtype == mshadow::kUint8 || dtype == mshadow::kInt8 || + dtype == mshadow::kInt32 || dtype == mshadow::kInt64; +} + inline int get_more_precise_type(const int type1, const int type2) { if (type1 == type2) return type1; if (is_float(type1) && is_float(type2)) { @@ -910,6 +916,19 @@ inline int np_binary_out_infer_type(const int type1, const int type2) { return get_more_precise_type(type1, type2); } +inline int GetDefaultDtype() { + return Imperative::Get()->is_np_default_dtype() ? + mshadow::kFloat64 : + mshadow::kFloat32; +} + +inline int GetDefaultDtype(int dtype) { + if (dtype != -1) return dtype; + return Imperative::Get()->is_np_default_dtype() ? + mshadow::kFloat64 : + mshadow::kFloat32; +} + } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/operator/contrib/gradient_multiplier_op.cc b/src/operator/contrib/gradient_multiplier_op.cc index 0a49ec1c36b3..5221d89a6056 100644 --- a/src/operator/contrib/gradient_multiplier_op.cc +++ b/src/operator/contrib/gradient_multiplier_op.cc @@ -77,9 +77,7 @@ In forward pass it acts as an identity transform. During backpropagation it multiplies the gradient from the subsequent level by a scalar factor lambda and passes it to the preceding layer. )code" ADD_FILELINE) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>) .set_attr("FCompute", UnaryOp::IdentityCompute) .set_attr("FComputeEx", UnaryOp::IdentityComputeEx) @@ -88,7 +86,7 @@ the preceding layer. [](const NodeAttrs& attrs){ return std::vector{true}; }) -.add_argument("scalar", "float", "lambda multiplier"); +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()); MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_contrib_backward_gradientmultiplier) .set_attr("TIsBackward", true) diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 2d4d49254676..61d299d39a40 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -148,7 +148,6 @@ struct true_divide : public mxnet_op::tunable { return static_cast(a) / static_cast(b); } -#ifndef _WIN32 template::value, int>::type = 0> MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { @@ -166,7 +165,6 @@ struct true_divide : public mxnet_op::tunable { MSHADOW_XINLINE static double Map(DType a, double b) { return static_cast(a) / b; } -#endif }; struct rtrue_divide : public mxnet_op::tunable { @@ -182,7 +180,6 @@ struct rtrue_divide : public mxnet_op::tunable { return static_cast(b) / static_cast(a); } -#ifndef _WIN32 template::value, int>::type = 0> MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { @@ -200,14 +197,12 @@ struct rtrue_divide : public mxnet_op::tunable { MSHADOW_XINLINE static double Map(DType a, double b) { return b / static_cast(a); } -#endif }; MXNET_BINARY_MATH_OP_NC(left, a); MXNET_BINARY_MATH_OP_NC(right, b); -#ifndef _WIN32 struct mixed_plus { template::value, int>::type = 0> @@ -345,8 +340,12 @@ struct mixed_rpower { return static_cast(math::pow(b, a)); } }; -#endif +#pragma GCC diagnostic push +#if __GNUC__ >= 7 +#pragma GCC diagnostic ignored "-Wint-in-bool-context" +#pragma GCC diagnostic ignored "-Wbool-compare" +#endif MXNET_BINARY_MATH_OP_NC_WITH_BOOL(mul, a * b); MXNET_BINARY_MATH_OP_NC_WITH_BOOL(div, a / b); @@ -575,7 +574,6 @@ MXNET_BINARY_MATH_OP(rpower, math::pow(b, a)); MXNET_BINARY_MATH_OP(rpower_grad, math::id(a) * math::log(b)); MXNET_BINARY_MATH_OP(arctan2, math::atan2(a, b)); - MXNET_BINARY_MATH_OP(arctan2_grad, math::id(b) / (math::id(a * a + b * b))); MXNET_BINARY_MATH_OP(arctan2_rgrad, -math::id(a) / (math::id(a * a + b * b))); @@ -728,6 +726,10 @@ MXNET_BINARY_MATH_OP_NC(minus_sign, a - b > DType(0) ? DType(1) : -DType(1)); MXNET_BINARY_MATH_OP(rminus, b - a); +MXNET_BINARY_MATH_OP_NC(posone, 1); + +MXNET_BINARY_MATH_OP_NC(negone, -1); + MXNET_BINARY_MATH_OP(div_grad, 1.0f / math::id(b)); template<> @@ -795,6 +797,73 @@ struct mod : public mxnet_op::tunable { } }; +struct mixed_mod { + template::value, int>::type = 0> + MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { + return mod::Map(static_cast(a), b); + } + + template::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static float Map(DType a, float b) { + return mod::Map(static_cast(a), b); + } + + template::value || + std::is_same::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static double Map(DType a, double b) { + return mod::Map(static_cast(a), b); + } +}; + +struct mixed_rmod { + template::value, int>::type = 0> + MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { + return mod::Map(b, static_cast(a)); + } + + template::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static float Map(DType a, float b) { + return mod::Map(b, static_cast(a)); + } + + template::value || + std::is_same::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static double Map(DType a, double b) { + return mod::Map(b, static_cast(a)); + } +}; + +struct fmod : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (b == DType(0)) { + return DType(0); + } else { + return DType(::fmod(static_cast(a), static_cast(b))); + } + } +}; + +struct rfmod : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (a == DType(0)) { + return DType(0); + } else { + return DType(::fmod(static_cast(b), static_cast(a))); + } + } +}; template<> MSHADOW_XINLINE mshadow::half::half2_t mod::Map diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index f7dbce270c87..0c2092f55d41 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -849,7 +849,13 @@ struct op_with_req { KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value)); } -#ifndef _WIN32 + /*! \brief input is two tensors with different type and with a boolean output tensor */ + template::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t i, bool *out, const LType *lhs, const RType *rhs) { + KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i])); + } + /*! \brief inputs are two tensors with a half_t output tensor */ template::value, int>::type = 0> @@ -903,7 +909,6 @@ struct op_with_req { MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double value) { KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value)); } -#endif /*! \brief inputs are two tensors with a float output tensor */ template& inputs, int j = 0; for (idim = 0; idim < ndim_iter; ++idim) { if (op_axes_arrays[i][idim] == -1 || - opshape[i][op_axes_arrays[i][idim]] == 1) { + (iop != nop && opshape[i][op_axes_arrays[i][idim]] == 1 && + op_axes_arrays[iop][idim] != -1 && + opshape[iop][op_axes_arrays[iop][idim]] != 1)) { remainstride[iop][j++] = iterstride[iop][idim]; } else { opstride[iop][op_axes_arrays[i][idim]] = iterstride[iop][idim]; diff --git a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc index 74db52d33f03..a54c7cac51a3 100644 --- a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc @@ -206,7 +206,8 @@ struct TVMBinaryBroadcastScalarCompute { // scalar param type_codes[1] = kDLFloat; - values[1].v_float64 = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + values[1].v_float64 = param.scalar; // output tensor type_codes[2] = kTVMDLTensorHandle; @@ -225,9 +226,7 @@ struct TVMBinaryBroadcastScalarCompute { NNVM_REGISTER_OP(_npi_##name##_scalar) \ .set_num_inputs(1) \ .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); \ - }) \ + .set_attr_parser(ParamParser) \ .set_attr("FListInputNames", \ [](const NodeAttrs& attrs) { \ return std::vector{"data"}; \ @@ -240,7 +239,7 @@ struct TVMBinaryBroadcastScalarCompute { }) \ .set_attr("FGradient", MakeZeroGradNodes) \ .add_argument("data", "NDArray-or-Symbol", "First input to the function") \ - .add_argument("scalar", "float", "scalar input") + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(equal); MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(not_equal); @@ -285,9 +284,12 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_equal); #else -#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \ - NNVM_REGISTER_OP(_npi_##name##_scalar) \ - .set_attr("FCompute", BinaryScalarOp::ComputeLogic) +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \ + NNVM_REGISTER_OP(_npi_##name##_scalar) \ + .set_attr("FCompute", BinaryScalarOp::ComputeLogic) \ + .set_attr("FResourceRequest", [](const NodeAttrs& n) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) #endif // MXNET_USE_TVM_OP diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index ae285caa9094..ae6697c0b23e 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -23,27 +23,26 @@ * \brief CPU Implementation of basic functions for elementwise numpy binary broadcast operator. */ -#include #include "./np_elemwise_broadcast_op.h" namespace mxnet { namespace op { -#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ - NNVM_REGISTER_OP(name) \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); \ - }) \ - .set_attr("FInferShape", ElemwiseShape<1, 1>) \ - .set_attr("FInferType", NumpyBinaryScalarType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}}; \ - }) \ - .add_argument("data", "NDArray-or-Symbol", "source input") \ - .add_argument("scalar", "float", "scalar input") +DMLC_REGISTER_PARAMETER(NumpyBinaryScalarParam); + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser(ParamParser) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarType) \ + .set_attr("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("data", "NDArray-or-Symbol", "source input") \ + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs, std::vector* in_attrs, @@ -61,7 +60,6 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs, return true; } -#ifndef _WIN32 #define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \ NNVM_REGISTER_OP(name) \ .set_num_inputs(2) \ @@ -76,68 +74,62 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs, [](const NodeAttrs& attrs){ \ return std::vector >{{0, 0}, {1, 0}}; \ }) \ - .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ - .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") -#else -#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \ - NNVM_REGISTER_OP(name) \ - .set_num_inputs(2) \ - .set_num_outputs(1) \ - .set_attr("FListInputNames", \ + .set_attr("FResourceRequest", \ [](const NodeAttrs& attrs) { \ - return std::vector{"lhs", "rhs"}; \ - }) \ - .set_attr("FInferShape", BinaryBroadcastShape) \ - .set_attr("FInferType", NumpyBinaryMixedPrecisionType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}, {1, 0}}; \ + return std::vector{ResourceRequest::kTempSpace}; \ }) \ - .set_attr("FResourceRequest", \ - [](const NodeAttrs& attrs) { \ - return std::vector{ResourceRequest::kTempSpace}; \ - }) \ .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") -#endif MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool) -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool) -#endif -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"}); +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_add"}); + +NNVM_REGISTER_OP(_backward_npi_broadcast_add) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}, {0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyBinaryBackwardUseIn); MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastCompute) -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastCompute) -#endif -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"}); +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_sub"}); + +NNVM_REGISTER_OP(_backward_npi_broadcast_sub) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}, {0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyBinaryBackwardUseIn); MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool) -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool) -#endif .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"}); NNVM_REGISTER_OP(_backward_npi_broadcast_mul) @@ -155,21 +147,33 @@ NNVM_REGISTER_OP(_backward_npi_broadcast_mul) .set_attr("FCompute", NumpyBinaryBackwardUseIn); -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod) -.set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"}); +MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_mod) +.set_attr( + "FCompute", + NumpyBinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mod"}); + +NNVM_REGISTER_OP(_backward_npi_broadcast_mod) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyBinaryBackwardUseIn); MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_power) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool) -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool) -#endif .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_power"}); NNVM_REGISTER_OP(_backward_npi_broadcast_power) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index 1e0130494469..a2927cda61ff 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -29,59 +29,50 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_add) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool); -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool); -#endif + +NNVM_REGISTER_OP(_backward_npi_broadcast_add) +.set_attr("FCompute", NumpyBinaryBackwardUseIn); NNVM_REGISTER_OP(_npi_subtract) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastCompute); -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastCompute); -#endif + +NNVM_REGISTER_OP(_backward_npi_broadcast_sub) +.set_attr("FCompute", NumpyBinaryBackwardUseIn); NNVM_REGISTER_OP(_npi_multiply) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool); -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool); -#endif NNVM_REGISTER_OP(_backward_npi_broadcast_mul) .set_attr("FCompute", NumpyBinaryBackwardUseIn); NNVM_REGISTER_OP(_npi_mod) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr( + "FCompute", + NumpyBinaryBroadcastCompute); + +NNVM_REGISTER_OP(_backward_npi_broadcast_mod) +.set_attr("FCompute", NumpyBinaryBackwardUseIn); NNVM_REGISTER_OP(_npi_power) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool); -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool); -#endif NNVM_REGISTER_OP(_backward_npi_broadcast_power) .set_attr("FCompute", NumpyBinaryBackwardUseIn* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - return in_attrs->at(0) != -1; -} - inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) { LOG(FATAL) << "Operator " << op_name << " does not support combination of " << mshadow::dtype_string(dtype1) << " with " << mshadow::dtype_string(dtype2) << " yet..."; } -#ifndef _WIN32 template void MixedAllRealBinaryElemwiseCompute(const std::string& op_name, const OpContext& ctx, @@ -153,7 +142,6 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs, const TBlob& lhs = inputs[0]; const TBlob& rhs = inputs[1]; const TBlob& out = outputs[0]; - if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { if (lhs.type_flag_ == out.type_flag_) { MixedAllRealBinaryElemwiseCompute(attrs.op->name, ctx, lhs, rhs, out, req[0]); @@ -227,13 +215,9 @@ void MixedAllRealBinaryBroadcastCompute(const std::string& op_name, } }); } -#endif -#ifndef _WIN32 + template -#else -template -#endif void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -248,11 +232,9 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, const TBlob& rhs = inputs[1]; const TBlob& out = outputs[0]; -#ifndef _WIN32 mxnet::TShape new_lshape, new_rshape, new_oshape; int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_, &new_lshape, &new_rshape, &new_oshape); - if (!ndim) { MixedBinaryElemwiseCompute(attrs, ctx, inputs, req, outputs); } else { @@ -290,47 +272,34 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, }); } }); + } else if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) { + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } } else { PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); } } -#else - mshadow::Stream *s = ctx.get_stream(); - if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { - TBlob temp_tblob; - // one is float, the other is bool - CHECK((out.type_flag_ == lhs.type_flag_) || (out.type_flag_ == rhs.type_flag_)) - << "This case out type should be same as the float type"; - if (lhs.type_flag_ == out.type_flag_) { - MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); - BinaryBroadcastCompute( - attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); - } else { - MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); - BinaryBroadcastCompute( - attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); - } - } else { - PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); - } -#endif } -#ifndef _WIN32 template -#else -template -#endif void NumpyBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -352,18 +321,10 @@ void NumpyBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, return; } -#ifndef _WIN32 MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); -#else - MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); -#endif } -#ifndef _WIN32 template -#else -template -#endif void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -384,12 +345,31 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, BinaryBroadcastComputeWithBool(attrs, ctx, inputs, req, outputs); return; } - -#ifndef _WIN32 + if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) { + Stream *s = ctx.get_stream(); + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } + return; + } MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); -#else - MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); -#endif } template diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc index 52d681885d82..60b721e4854d 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc @@ -30,21 +30,19 @@ namespace mxnet { namespace op { -#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ - NNVM_REGISTER_OP(name) \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); \ - }) \ - .set_attr("FInferShape", ElemwiseShape<1, 1>) \ - .set_attr("FInferType", NumpyBinaryScalarType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}}; \ - }) \ - .add_argument("data", "NDArray-or-Symbol", "source input") \ - .add_argument("scalar", "float", "scalar input") +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser(ParamParser) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarType) \ + .set_attr("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("data", "NDArray-or-Symbol", "source input") \ + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign) .describe(R"code()code" ADD_FILELINE) @@ -87,9 +85,7 @@ NNVM_REGISTER_OP(_npi_lcm) NNVM_REGISTER_OP(_npi_lcm_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseIntType<1, 1>) .set_attr("FInplaceOption", @@ -98,7 +94,7 @@ NNVM_REGISTER_OP(_npi_lcm_scalar) }) .set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) .set_attr("FCompute", BinaryScalarOp::Compute); NNVM_REGISTER_OP(_npi_bitwise_and) @@ -122,9 +118,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and) NNVM_REGISTER_OP(_npi_bitwise_and_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = std::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseIntType<1, 1>) .set_attr("FInplaceOption", @@ -133,7 +127,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and_scalar) }) .set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) .set_attr("FCompute", BinaryScalarOp::ComputeInt); NNVM_REGISTER_OP(_npi_bitwise_xor) @@ -175,9 +169,7 @@ NNVM_REGISTER_OP(_npi_bitwise_or) NNVM_REGISTER_OP(_npi_bitwise_xor_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseIntType<1, 1>) .set_attr("FInplaceOption", @@ -186,15 +178,13 @@ NNVM_REGISTER_OP(_npi_bitwise_xor_scalar) }) .set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) .set_attr("FCompute", BinaryScalarOp::ComputeInt); NNVM_REGISTER_OP(_npi_bitwise_or_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseIntType<1, 1>) .set_attr("FInplaceOption", @@ -203,7 +193,7 @@ NNVM_REGISTER_OP(_npi_bitwise_or_scalar) }) .set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) .set_attr("FCompute", BinaryScalarOp::ComputeInt); MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar) @@ -275,14 +265,14 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"}); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); @@ -363,13 +353,13 @@ NNVM_REGISTER_OP(_backward_npi_ldexp) mshadow_op::ldexp_rgrad>); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); } // namespace op diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 593a698dc93a..a64fb4db12b6 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -917,7 +917,7 @@ void NumpyConcatenateForward(const nnvm::NodeAttrs& attrs, ConcatParam cparam; cparam.num_args = param.num_args; cparam.dim = param.axis.has_value() ? param.axis.value() : 0; - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { ConcatOp op; op.Init(cparam); op.Forward(ctx, data, req, outputs); @@ -950,7 +950,7 @@ void NumpyConcatenateBackward(const nnvm::NodeAttrs& attrs, ConcatParam cparam; cparam.num_args = param.num_args; cparam.dim = param.axis.has_value() ? param.axis.value() : 0; - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { ConcatOp op; op.Init(cparam); op.Backward(ctx, inputs[0], req, data); diff --git a/src/operator/numpy/np_true_divide-inl.h b/src/operator/numpy/np_true_divide-inl.h index e7a1c193d97f..d2dbc82b3e9b 100644 --- a/src/operator/numpy/np_true_divide-inl.h +++ b/src/operator/numpy/np_true_divide-inl.h @@ -29,6 +29,7 @@ #include #include "../../common/utils.h" #include "../tensor/elemwise_binary_broadcast_op.h" +#include "../numpy/np_elemwise_broadcast_op.h" namespace mxnet { namespace op { @@ -46,7 +47,8 @@ void TrueDivideScalarCompute(const nnvm::NodeAttrs &attrs, using namespace mxnet_op; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; const TBlob& data = inputs[0]; const TBlob& out = outputs[0]; if (out.type_flag_ == data.type_flag_) { @@ -57,8 +59,7 @@ void TrueDivideScalarCompute(const nnvm::NodeAttrs &attrs, }); }); } else { -#ifndef _WIN32 - CHECK(out.type_flag_ == mshadow::kFloat32 || out.type_flag_ == mshadow::kFloat64) + CHECK_EQ(out.type_flag_, mxnet::common::GetDefaultDtype()) << "true_divide only supports float32 and float64" " output when input's dtype is " << type_string(inputs[0].type_flag_); @@ -71,13 +72,6 @@ void TrueDivideScalarCompute(const nnvm::NodeAttrs &attrs, }); }); }); -#else - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(data.Size()), s); - TBlob temp_tblob(temp_tensor); - CastCompute(attrs, ctx, {data}, {kWriteTo}, {temp_tblob}); - TrueDivideScalarCompute(attrs, ctx, {temp_tblob}, req, outputs); -#endif } } @@ -119,12 +113,10 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs &attrs, }); } } else { -#ifndef _WIN32 - // Non-windows case: no usage of temporary space // Case when types of the 2 input tensors are different if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { // both lhs and rhs are float types, output type is the more precise one - LOG(ERROR) << "not implemented yet..."; + LOG(FATAL) << "not implemented yet..."; } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { // one is float type, the other is integer type, the output type should be the same as float CHECK_EQ(out.type_flag_, @@ -153,46 +145,8 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs &attrs, } } else { // lhs is integer type, rhs is integer type, output type should be float - LOG(ERROR) << "not implemented yet..."; + LOG(FATAL) << "not implemented yet..."; } -#else - // Windows case: using temp space for casting the type - // Case when types of the 2 input tensors are different - if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { - // both lhs and rhs are float types, output type is the more precise one - LOG(ERROR) << "not implemented yet..."; - } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { - // lhs is float type, rhs is integer type, the output type should be the same as lhs - CHECK_EQ(out.type_flag_, - common::is_float(lhs.type_flag_) ? lhs.type_flag_ : rhs.type_flag_) - << "This case out type should be same as the float type"; - TBlob temp_tblob; - if (common::is_float(lhs.type_flag_)) { - // lhs is the float one - MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(rhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); - TrueDivideElemwiseCompute( - attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); - } else { - // rhs is the float one - MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(lhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); - TrueDivideElemwiseCompute( - attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); - } - } else { - // lhs is integer type, rhs is integer type, output type should be float - LOG(ERROR) << "not implemented yet..."; - } -#endif } } @@ -216,7 +170,6 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs, const TBlob& lhs = inputs[0]; const TBlob& rhs = inputs[1]; const TBlob& out = outputs[0]; -#ifndef _WIN32 BROADCAST_NDIM_SWITCH(ndim, NDim, { mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = calc_stride(new_lshape.get()); @@ -244,7 +197,7 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs, } else { if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { // lhs and rhs have different float types, the output is the more precise one - LOG(ERROR) << "not implemented yet..."; + LOG(FATAL) << "not implemented yet..."; } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { // one of lhs and rhs is float, the output is the same type as the float one if (common::is_float(lhs.type_flag_)) { @@ -272,74 +225,10 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs, } } else { // lhs and rhs have different integer types, the output is float type - LOG(ERROR) << "not implemented yet..."; + LOG(FATAL) << "not implemented yet..."; } } }); -#else - if (lhs.type_flag_ == rhs.type_flag_) { - BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = calc_stride(new_lshape.get()); - mshadow::Shape rstride = calc_stride(new_rshape.get()); - // When the both inputs have the same data types - if (common::is_float(lhs.type_flag_)) { - // If both inputs are the same float types, output is the same float type - MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, DType, { - Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - lhs.dptr(), rhs.dptr(), out.dptr()); - }); - } else { - CHECK_EQ(out.type_flag_, mshadow::kFloat32) - << "true_divide only supports float32 output when input's dtype is " - << type_string(lhs.type_flag_); - MXNET_INT_TYPE_SWITCH(lhs.type_flag_, DType, { - // If both inputs are the same integer types, output is float type - Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - lhs.dptr(), rhs.dptr(), out.dptr()); - }); - } - }); - } else { - if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { - // lhs and rhs have different float types, the output is the more precise one - LOG(ERROR) << "not implemented yet..."; - } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { - // one of lhs and rhs is float, the output is the same type as the float one - TBlob temp_tblob; - if (common::is_float(lhs.type_flag_)) { - // lhs is float type, output will be the same float type - CHECK_EQ(lhs.type_flag_, out.type_flag_) - << "lhs should have the same type as out, infer type broken?"; - MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(rhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); - TrueDivideBroadcastCompute( - attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); - } else { - // rhs is float type, output will be the same float type - CHECK_EQ(rhs.type_flag_, out.type_flag_) - << "rhs should have the same type as out, infer type broken?"; - MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(lhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); - TrueDivideBroadcastCompute( - attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); - } - } else { - // lhs and rhs have different integer types, the output is float type - LOG(ERROR) << "not implemented yet..."; - } - } -#endif } } diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc index 6edfb4dd0901..2f3bf7d5dfb6 100644 --- a/src/operator/numpy/np_true_divide.cc +++ b/src/operator/numpy/np_true_divide.cc @@ -74,46 +74,46 @@ NNVM_REGISTER_OP(_npi_true_divide) [](const NodeAttrs& attrs){ return std::vector >{{0, 0}, {1, 0}}; }) -#ifdef _WIN32 +.set_attr("FCompute", TrueDivideBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_div"}) +.add_argument("lhs", "NDArray-or-Symbol", "Dividend array") +.add_argument("rhs", "NDArray-or-Symbol", "Divisor array"); + + +NNVM_REGISTER_OP(_backward_npi_broadcast_div) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 1}}; + }) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) -#endif -.set_attr("FCompute", TrueDivideBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"}) -.add_argument("lhs", "NDArray-or-Symbol", "Dividend array") -.add_argument("rhs", "NDArray-or-Symbol", "Divisor array"); +.set_attr("FCompute", NumpyBinaryBackwardUseIn); NNVM_REGISTER_OP(_npi_true_divide_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", TrueDivideType<1>) .set_attr("FInplaceOption", [](const NodeAttrs& attrs) { return std::vector >{{0, 0}}; }) -#ifdef _WIN32 -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -#endif .set_attr("FCompute", TrueDivideScalarCompute) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_div_scalar"}) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "float", "scalar input"); +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()); NNVM_REGISTER_OP(_npi_rtrue_divide_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", TrueDivideType<1>) .set_attr("FInplaceOption", @@ -129,7 +129,7 @@ NNVM_REGISTER_OP(_npi_rtrue_divide_scalar) .set_attr("FCompute", TrueDivideScalarCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_rdiv_scalar"}) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "float", "scalar input"); +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_true_divide.cu b/src/operator/numpy/np_true_divide.cu index 7211f4a0a006..c8eccfe140b4 100644 --- a/src/operator/numpy/np_true_divide.cu +++ b/src/operator/numpy/np_true_divide.cu @@ -31,6 +31,10 @@ namespace op { NNVM_REGISTER_OP(_npi_true_divide) .set_attr("FCompute", TrueDivideBroadcastCompute); +NNVM_REGISTER_OP(_backward_npi_broadcast_div) +.set_attr("FCompute", NumpyBinaryBackwardUseIn); + NNVM_REGISTER_OP(_npi_true_divide_scalar) .set_attr("FCompute", TrueDivideScalarCompute); diff --git a/src/operator/numpy/np_unique_op.cc b/src/operator/numpy/np_unique_op.cc index 2f57733a72b2..7a299cdd5221 100644 --- a/src/operator/numpy/np_unique_op.cc +++ b/src/operator/numpy/np_unique_op.cc @@ -375,6 +375,7 @@ NNVM_REGISTER_OP(_npi_unique) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "The input array") .add_arguments(NumpyUniqueParam::__FIELDS__()); diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 0cc0dc92f884..de6efbfe5758 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -421,6 +421,8 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rldexp); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_rgrad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rldexp_grad); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::posone); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::negone); // NOLINT() /*! * \brief Tuner objects, *not* automatically generated */ diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index ffd0f123070a..774c87afaf3c 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -227,8 +227,8 @@ struct binary_broadcast_kernel { } } -#ifndef _WIN32 /*! \brief Map function for binary_broadcast_kernel */ + /* used for mixed type binary ops */ template::value, int>::type = 0> MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, @@ -249,6 +249,7 @@ struct binary_broadcast_kernel { } /*! \brief Map function for binary_broadcast_kernel */ + /* used for mixed type binary ops */ template::value && !std::is_pointer::value, int>::type = 0> @@ -268,7 +269,6 @@ struct binary_broadcast_kernel { KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs, rhs[ridx])); } } -#endif }; template diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h index 4eaaff09d83d..7d65c89f7b0f 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op.h +++ b/src/operator/tensor/elemwise_binary_scalar_op.h @@ -29,6 +29,7 @@ #include #include #include +#include #include "../mshadow_op.h" #include "../elemwise_op_common.h" #include "elemwise_unary_op.h" @@ -36,6 +37,45 @@ namespace mxnet { namespace op { +struct NumpyBinaryScalarParam : public dmlc::Parameter { + double scalar; + bool is_int; + DMLC_DECLARE_PARAMETER(NumpyBinaryScalarParam) { + DMLC_DECLARE_FIELD(scalar) + .set_default(1) + .describe("Scalar input value"); + DMLC_DECLARE_FIELD(is_int) + .set_default(true) + .describe("Indicate whether scalar input is int type"); + } + + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream scalar_s, is_int_s; + scalar_s << scalar; + is_int_s << is_int; + (*dict)["scalar"] = scalar_s.str(); + (*dict)["is_int"] = is_int_s.str(); + } +}; + +inline bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + bool scalar_is_int = param.is_int; + if (common::is_int(in_attrs->at(0)) && !scalar_is_int) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat64); + } else if (in_attrs->at(0) == mshadow::kBool) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, scalar_is_int ? mshadow::kInt64 : mshadow::kFloat64); + } else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + } + return out_attrs->at(0) != -1; +} + class BinaryScalarOp : public UnaryOp { /*! \brief Tensor operation against a scalar with a dense result */ template @@ -45,7 +85,8 @@ class BinaryScalarOp : public UnaryOp { const NDArray &input, const OpReqType req, const NDArray &output) { - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; CHECK_EQ(output.shape(), input.shape()); const int64_t row_count = output.shape()[0]; const int64_t items_per_row = output.shape().Size() / row_count; @@ -137,7 +178,8 @@ class BinaryScalarOp : public UnaryOp { const NDArray &output) { CHECK_EQ(output.shape(), input.shape()); - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; const DType dense_fill_val = OP::Map(DType(0), DType(alpha)); const TBlob column_indexes = input.aux_data(csr::kIdx); const size_t item_count = column_indexes.Size(); @@ -236,11 +278,23 @@ class BinaryScalarOp : public UnaryOp { using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - const double alpha = nnvm::get(attrs.parsed); + TBlob temp_tblob; + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + bool scalar_is_int = param.is_int; + const double alpha = param.scalar; MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + if ((common::is_int(inputs[0].type_flag_) && !scalar_is_int) || + (inputs[0].type_flag_ == kBool)) { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(inputs[0].Size()), s); + temp_tblob = TBlob(temp_tensor); + CastCompute(attrs, ctx, {inputs[0]}, {kWriteTo}, {temp_tblob}); + } else { + temp_tblob = inputs[0]; + } MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet_op::Kernel, xpu>::Launch( - s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), DType(alpha)); + s, inputs[0].Size(), outputs[0].dptr(), temp_tblob.dptr(), DType(alpha)); }); }); } @@ -256,7 +310,8 @@ class BinaryScalarOp : public UnaryOp { using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet_op::Kernel, xpu>::Launch( @@ -276,12 +331,23 @@ class BinaryScalarOp : public UnaryOp { using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - const double alpha = nnvm::get(attrs.parsed); - MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - mxnet_op::Kernel, xpu>::Launch( - s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), DType(alpha)); - }); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + bool scalar_is_int = param.is_int; + const double alpha = param.scalar; + TBlob temp_tblob; + if (common::is_int(inputs[0].type_flag_) && !scalar_is_int) { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(inputs[0].Size()), s); + temp_tblob = TBlob(temp_tensor); + CastCompute(attrs, ctx, {inputs[0]}, {kWriteTo}, {temp_tblob}); + } else { + temp_tblob = inputs[0]; + } + MSHADOW_TYPE_SWITCH_WITH_BOOL(temp_tblob.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + mxnet_op::Kernel, xpu>::Launch( + s, inputs[0].Size(), outputs[0].dptr(), temp_tblob.dptr(), DType(alpha)); + }); }); } @@ -345,7 +411,8 @@ class BinaryScalarOp : public UnaryOp { using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet::op::mxnet_op::Kernelparsed = dmlc::stod(attrs->dict["scalar"]); \ - }) \ - .set_attr("FInferShape", ElemwiseShape<1, 1>) \ - .set_attr("FInferType", ElemwiseType<1, 1>) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}}; \ - }) \ - .add_argument("data", "NDArray-or-Symbol", "source input") \ - .add_argument("scalar", "float", "scalar input") +#define MXNET_OPERATOR_REGISTER_BINARY_SCALAR(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser(ParamParser) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("data", "NDArray-or-Symbol", "source input") \ + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc index 13014b35c2fe..dc11593255e7 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc @@ -28,22 +28,24 @@ #include "./elemwise_binary_scalar_op.h" #define MXNET_OPERATOR_REGISTER_BINARY_WITH_SCALAR_SUPPORT_WITH_DENSE_RESULT(name) \ - NNVM_REGISTER_OP(name) \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); \ - }) \ - .set_attr("FInferShape", ElemwiseShape<1, 1>) \ - .set_attr("FInferType", ElemwiseType<1, 1>) \ - .set_attr("FInferStorageType", \ - BinaryScalarStorageTypeWithDenseResultStorageType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}}; \ - }) \ - .add_argument("data", "NDArray-or-Symbol", "source input") \ - .add_argument("scalar", "float", "scalar input") + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser(ParamParser) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarType) \ + .set_attr("FInferStorageType", \ + BinaryScalarStorageTypeWithDenseResultStorageType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("data", "NDArray-or-Symbol", "source input") \ + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) namespace mxnet { namespace op { @@ -65,7 +67,8 @@ static bool BinaryScalarStorageTypeWithDenseResultStorageType(const NodeAttrs& a const NDArrayStorageType instype = static_cast(in_attrs->at(0)); const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx; - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; if (common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) { dispatched = storage_type_assign(&out_attrs[0], kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); @@ -189,8 +192,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_rdiv_scalar) .add_alias("_RDivScalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_rdiv_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward< cpu, mshadow_op::rdiv_grad>); @@ -200,8 +203,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_mod_scalar) .add_alias("_ModScalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_mod_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward< cpu, mshadow_op::mod_grad>); @@ -211,8 +214,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_rmod_scalar) .add_alias("_RModScalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_rmod_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward< cpu, mshadow_op::rmod_grad>); diff --git a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc index 7dd8cf41c59e..c08cdcdf0dc8 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc @@ -36,8 +36,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_maximum_scalar) .add_alias("_npi_maximum_scalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_maximum_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_minimum_scalar) @@ -47,8 +47,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_minimum_scalar) .add_alias("_npi_minimum_scalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_minimum_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_power_scalar) @@ -57,8 +57,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_power_scalar) .add_alias("_PowerScalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_power_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward< cpu, mshadow_op::power_grad>); @@ -69,8 +69,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_rpower_scalar) .add_alias("_RPowerScalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_rpower_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward< cpu, mshadow_op::rpower_grad>); @@ -82,8 +82,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_hypot_scalar) .add_alias("_HypotScalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_hypot_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward< cpu, mshadow_op::hypot_grad_left>); @@ -109,13 +109,7 @@ Example:: )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - if (attrs->dict.find("scalar") != attrs->dict.end()) { - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); - } else { - attrs->parsed = 1.0; - } - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseType<1, 1>) .set_attr("FInplaceOption", @@ -128,13 +122,7 @@ Example:: .set_attr("FGradient", ElemwiseGradUseIn{ "_backward_smooth_l1" }); MXNET_OPERATOR_REGISTER_BINARY(_backward_smooth_l1) - .set_attr_parser([](NodeAttrs *attrs) { - if (attrs->dict.find("scalar") != attrs->dict.end()) { - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); - } else { - attrs->parsed = 1.0; - } -}) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); diff --git a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc index 17e76153ebb2..0594997877d3 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc @@ -46,7 +46,8 @@ static bool BinaryScalarLogicStorageType(const nnvm::NodeAttrs& attrs, const auto in_stype = in_attrs->at(0); auto &out_stype = out_attrs->at(0); bool dispatched = false; - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; bool is_sparse = OP::Map(static_cast(0), alpha) == 0; if (!dispatched && in_stype == kDefaultStorage) { // dns -> dns diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py index 204fde6bd2bd..836bcd1fc11e 100644 --- a/tests/python/unittest/test_numpy_gluon.py +++ b/tests/python/unittest/test_numpy_gluon.py @@ -25,6 +25,7 @@ import mxnet as mx from mxnet import gluon, autograd, np from mxnet.test_utils import use_np, assert_almost_equal, check_gluon_hybridize_consistency +from mxnet.gluon import nn from common import with_seed import random @@ -422,6 +423,55 @@ def hybrid_forward(self, F, valid_length): assert mx.test_utils.same(out1.asnumpy(), out2.asnumpy()) +@with_seed() +@use_np +def test_activations_leakyrelu(): + # Currently, all the activation tests, we will just test for runnable. + act_layer = nn.LeakyReLU(0.1) + out = act_layer(mx.np.random.uniform(size=(10,))) + out.asnumpy() + + +@with_seed() +@use_np +def test_activations_prelu(): + act_layer = nn.PReLU() + act_layer.initialize() + out = act_layer(mx.np.random.uniform(size=(10,))) + out.asnumpy() + + +@with_seed() +@use_np +def test_activations_elu(): + act_layer = nn.ELU(1.0) + out = act_layer(mx.np.random.uniform(size=(10,))) + out.asnumpy() + + +@with_seed() +@use_np +def test_activations_selu(): + act_layer = nn.SELU() + out = act_layer(mx.np.random.uniform(size=(10,))) + out.asnumpy() + + +@with_seed() +@use_np +def test_activations_gelu(): + act_layer = nn.GELU() + out = act_layer(mx.np.random.uniform(size=(10,))) + out.asnumpy() + + +@with_seed() +@use_np +def test_activations_swish(): + act_layer = nn.Swish() + out = act_layer(mx.np.random.uniform(size=(10,))) + out.asnumpy() + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 32c8521afd51..20da12b12f48 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2548,8 +2548,9 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): assert y.shape == np_out.shape assert_almost_equal(y.asnumpy(), np_out.astype(y.dtype), rtol=rtol, atol=atol, use_broadcast=False, equal_nan=True) - if lgrad: + if (ltype in itypes) and (rtype in itypes): + continue y.backward() if ltype not in itypes: assert_almost_equal(mx_test_x1.grad.asnumpy(), @@ -2573,10 +2574,13 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): use_broadcast=False, equal_nan=True) funcs = { - 'add': (-1.0, 1.0, None, None), - 'subtract': (-1.0, 1.0, None, None), + 'add': (-1.0, 1.0, lambda y, x1, x2: _np.ones(y.shape), + lambda y, x1, x2: _np.ones(y.shape)), + 'subtract': (-1.0, 1.0, lambda y, x1, x2: _np.ones(y.shape), + lambda y, x1, x2: _np.ones(y.shape) * -1), 'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape), lambda y, x1, x2: _np.broadcast_to(x1, y.shape)), + 'mod': (1.0, 5.0, None, None), 'power': (1.0, 3.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2, lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)), } @@ -2590,8 +2594,6 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): ((2, 3), ()), ((), (2, 3))] - itypes = [np.bool, np.int8, np.int32, np.int64] - ftypes = [np.float16, np.float32, np.float64] for func, func_data in funcs.items(): low, high, lgrad, rgrad = func_data for lshape, rshape in shape_pairs: @@ -2604,6 +2606,13 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): continue check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2) + if func == 'subtract' or func == 'mod': + continue + for type1, type2 in itertools.product(itypes, itypes): + if type1 == type2: + continue + check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2) + @with_seed() @use_np @@ -3145,52 +3154,56 @@ def get_new_shape(shape, axis): shape_lst[axis] = random.randint(0, 3) return tuple(shape_lst) - for shape in [(0, 0), (2, 3), (2, 1, 3)]: - for hybridize in [True, False]: - for axis in [0, 1, None]: - for grad_req in ['write', 'add', 'null']: - # test gluon - test_concat = TestConcat(axis=axis) - if hybridize: - test_concat.hybridize() + shapes = [(0, 0), (2, 3), (2, 1, 3)] + hybridizes = [True, False] + axes = [0, 1, None] + grad_reqs = ['write', 'add', 'null'] + dtypes = [np.float32, np.float64, np.bool] + combinations = itertools.product(shapes, hybridizes, axes, grad_reqs, dtypes) - grad_req_c = grad_req - grad_req_d = grad_req - if grad_req == 'null': - ide = random.randint(0, 2) - grad_req_c = 'write' if ide == 0 else 'add' - grad_req_c = 'write' if ide == 1 else 'add' + for shape, hybridize, axis, grad_req, dtype in combinations: + # test gluon + test_concat = TestConcat(axis=axis) + if hybridize: + test_concat.hybridize() - a = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() - a.attach_grad(grad_req) - b = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() - b.attach_grad(grad_req) - c = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() - c.attach_grad(grad_req_c) - d = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() - d.attach_grad(grad_req_d) - expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis) + grad_req_c = grad_req + grad_req_d = grad_req + if grad_req == 'null': + ide = random.randint(0, 2) + grad_req_c = 'write' if ide == 0 else 'add' + grad_req_c = 'write' if ide == 1 else 'add' - with mx.autograd.record(): - y = test_concat(a, b, c, d) + a = np.random.uniform(-1.0, 1.0, size=get_new_shape(shape, axis)).astype(dtype) + a.attach_grad(grad_req) + b = np.random.uniform(-1.0, 1.0, size=get_new_shape(shape, axis)).astype(dtype) + b.attach_grad(grad_req) + c = np.random.uniform(-1.0, 1.0, size=get_new_shape(shape, axis)).astype(dtype) + c.attach_grad(grad_req_c) + d = np.random.uniform(-1.0, 1.0, size=get_new_shape(shape, axis)).astype(dtype) + d.attach_grad(grad_req_d) + expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis) - assert y.shape == expected_ret.shape - assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5) + with mx.autograd.record(): + y = test_concat(a, b, c, d) - y.backward() - if grad_req != 'null': - assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5) - if grad_req != 'null': - assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5) - if grad_req_c != 'null': - assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5) - if grad_req_d != 'null': - assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5) + assert y.shape == expected_ret.shape + assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5) - # test imperative - mx_out = np.concatenate([a, b, c, d], axis=axis) - np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis) - assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + y.backward() + if grad_req != 'null': + assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5) + if grad_req != 'null': + assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5) + if grad_req_c != 'null': + assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5) + if grad_req_d != 'null': + assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5) + + # test imperative + mx_out = np.concatenate([a, b, c, d], axis=axis) + np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) @with_seed() @@ -3704,6 +3717,16 @@ def __init__(self, a_min=None, a_max=None): def hybrid_forward(self, F, x): return x.clip(self._a_min, self._a_max) + + # Test scalar case + for _, a_min, a_max, throw_exception in workloads: + a = _np.random.uniform() # A scalar + if throw_exception: + # No need to test the exception case here. + continue + mx_ret = np.clip(a, a_min, a_max) + np_ret = _np.clip(a, a_min, a_max) + assert_almost_equal(mx_ret, np_ret, atol=1e-4, rtol=1e-3, use_broadcast=False) for shape, a_min, a_max, throw_exception in workloads: for dtype in dtypes: @@ -6605,7 +6628,7 @@ def hybrid_forward(self, F, a): ((5, 3, 4), True, True, True, 1), ] for dtype in ['float32', 'float64', 'int8', 'uint8', 'int32', 'int64']: - for hybridize in [False]: + for hybridize in [False, True]: for config in configs: test_unique = TestUnique(*config[1:]) if hybridize: @@ -6998,6 +7021,26 @@ def dbg(name, data): # broadcast bug ('ij, ij -> i', [(1, 4), (2, 4)], lambda *args: (_np.sum(args[1], axis=0)[None, :], _np.tile(args[0], [2, 1]))), + # one dimensim bug + ('...ij, ...jk -> ...ik', [(1, 4), (4, 2)], lambda *args: (args[1].sum(axis=1)[None, :], + _np.tile(args[0].sum(axis=0)[: ,None], [1, 2]))), + ('...ij, ...jk -> ...ik', [(2, 4), (4, 2)], lambda *args: (_np.tile(args[1].sum(axis=1)[None, :], [2, 1]), + _np.tile(args[0].sum(axis=0)[: ,None], [1, 2]))), + ('...ij, ...jk -> ...ik', [(3, 2, 1, 4), (3, 2, 4, 2)], lambda *args: ( + args[1].sum(axis=3)[:, :, None, :], + _np.tile(args[0].sum(axis=2)[:, :, :, None], [1, 1, 1, 2]))), + ('...ij, ...ik -> ...jk', [(1, 1, 1, 4), (1, 1, 1, 3)], lambda *args: ( + _np.tile(args[1].sum(axis=3)[:, :, :, None], [1, 1, 1, 4]), + _np.tile(args[0].sum(axis=3)[:, :, : ,None], [1, 1, 1, 3]))), + ('...ij, ...jc -> ...ic', [(1, 1, 5, 3), (1, 1, 3, 2)], lambda *args: ( + _np.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]), + _np.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))), + ('...ij, ...jc -> ...ic', [(1, 2, 5, 4), (1, 2, 4, 2)], lambda *args: ( + _np.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]), + _np.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))), + ('...ij, ...jc -> ...ic', [(2, 1, 5, 4), (2, 1, 4, 2)], lambda *args: ( + _np.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]), + _np.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))), # issue #16576 # commented due to long running time # ('abiz,abjz->abij', [(64, 8, 128, 512), (64, 8, 128, 512)], lambda *args: (_np.matmul(_np.ones((64, 8, 128, 128)), args[1]), diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index a54bcf3e16a9..2aabfbf9b24b 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -440,9 +440,7 @@ def check_cse_on_symbol(sym, expected_savings, check_data, **kwargs): arr2 = mx.random.uniform(shape=shape) arr3 = mx.random.uniform(shape=shape) - check_cse_on_symbol((a+5) + (a+5), expected_savings=1, check_data=True, a=arr1, b=arr2) check_cse_on_symbol((a+1) + (a+2), expected_savings=0, check_data=True, a=arr1, b=arr2) - check_cse_on_symbol((1+a) + (a+1), expected_savings=1, check_data=True, a=arr1, b=arr2) check_cse_on_symbol((a+b) + (a+b), expected_savings=1, check_data=True, a=arr1, b=arr2) check_cse_on_symbol(((a+b)+c) +((a+b)+c), expected_savings=2, check_data=True, a=arr1, b=arr2, c=arr3)