diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h index 845ed35cf24f..1a4d5cec52e8 100644 --- a/3rdparty/mshadow/mshadow/base.h +++ b/3rdparty/mshadow/mshadow/base.h @@ -272,7 +272,6 @@ extern "C" { } #include "./half.h" -#include "./half2.h" #include "./bfloat.h" #define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP) \ MSHADOW_XINLINE RTYPE operator OP(mshadow::half::half_t a, mshadow::bfloat::bf16_t b) { \ @@ -387,11 +386,6 @@ struct DataType { #endif }; template<> -struct DataType { - static const int kFlag = kFloat16; - static const int kLanes = 2; -}; -template<> struct DataType { static const int kFlag = kBfloat16; static const int kLanes = 1; @@ -1144,48 +1138,6 @@ struct minimum { } #endif -#define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...) \ - switch (type) { \ - case mshadow::kFloat32: \ - { \ - typedef float DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat64: \ - { \ - typedef double DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat16: \ - { \ - typedef mshadow::half::half2_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kUint8: \ - { \ - typedef uint8_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kInt32: \ - { \ - typedef int32_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kInt64: \ - { \ - typedef int64_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - default: \ - LOG(FATAL) << "Unknown type enum " << type; \ - } - #define MSHADOW_SGL_DBL_TYPE_SWITCH(type, DType, ...) \ switch (type) { \ case mshadow::kFloat32: \ diff --git a/3rdparty/mshadow/mshadow/half2.h b/3rdparty/mshadow/mshadow/half2.h deleted file mode 100644 index cecc5449383c..000000000000 --- a/3rdparty/mshadow/mshadow/half2.h +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2017 by Contributors - * \file half2.h - * \brief definition of vector float16, half2 type. - * - * \author Antti-Pekka Hynninen - */ -#ifndef MSHADOW_HALF2_H_ -#define MSHADOW_HALF2_H_ - -#if (defined(__CUDACC__) && __CUDA_ARCH__ >= 530 && MSHADOW_USE_CUDA && CUDA_VERSION >= 7050) - #define MSHADOW_CUDA_HALF2 1 - #include -#else - #define MSHADOW_CUDA_HALF2 0 -#endif - -#include - -/*! 
\brief namespace for mshadow */ -namespace mshadow { -/* \brief name space for host/device portable half-precision floats */ -namespace half { - -#define MSHADOW_HALF2_ASSIGNOP(AOP, OP) \ - template \ - MSHADOW_XINLINE half2_t operator AOP (const T& a) { \ - return *this = half2_t(*this OP a); /* NOLINT(*)*/ \ - } \ - -class MSHADOW_ALIGNED(4) half2_t { - public: -#if MSHADOW_CUDA_HALF2 - half2 half2_; -#else - half_t half_t2[2]; -#endif - - MSHADOW_XINLINE half2_t() {} - -#if MSHADOW_CUDA_HALF2 - MSHADOW_XINLINE explicit half2_t(half2 a) : half2_(a) {} -#else - MSHADOW_XINLINE explicit half2_t(half_t a, half_t b) { - half_t2[0] = a; - half_t2[1] = b; - } -#endif - - MSHADOW_XINLINE explicit half2_t(int a) { -#if MSHADOW_CUDA_HALF2 - half2_ = __half2half2(__int2half_rz(a)); -#else - half_t2[0] = (half_t)a; - half_t2[1] = (half_t)a; -#endif - } - - MSHADOW_XINLINE half2_t operator+() { - return *this; - } - - MSHADOW_XINLINE half2_t operator-() { -#if MSHADOW_CUDA_HALF2 - return half2_t(__hneg2(half2_)); -#else - return half2_t(-half_t2[0], -half_t2[1]); -#endif - } - - MSHADOW_XINLINE half2_t operator=(const half2_t& a) { -#if MSHADOW_CUDA_HALF2 - half2_ = a.half2_; -#else - half_t2[0] = a.half_t2[0]; - half_t2[1] = a.half_t2[1]; -#endif - return a; - } - - MSHADOW_HALF2_ASSIGNOP(+=, +) - MSHADOW_HALF2_ASSIGNOP(-=, -) - MSHADOW_HALF2_ASSIGNOP(*=, *) - MSHADOW_HALF2_ASSIGNOP(/=, /) -}; - -/*! \brief overloaded + operator for half2_t */ -MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b) { -#if MSHADOW_CUDA_HALF2 - return half2_t(__floats2half2_rn(__low2float(a.half2_) + __low2float(b.half2_), - __high2float(a.half2_) + __high2float(b.half2_))); -#else - return half2_t(a.half_t2[0] + b.half_t2[0], a.half_t2[1] + b.half_t2[1]); -#endif -} -/*! \brief overloaded - operator for half2_t */ -MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b) { -#if MSHADOW_CUDA_HALF2 - return half2_t(__floats2half2_rn(__low2float(a.half2_) - __low2float(b.half2_), - __high2float(a.half2_) - __high2float(b.half2_))); -#else - return half2_t(a.half_t2[0] - b.half_t2[0], a.half_t2[1] - b.half_t2[1]); -#endif -} -/*! \brief overloaded * operator for half2_t */ -MSHADOW_XINLINE half2_t operator*(half2_t a, half2_t b) { -#if MSHADOW_CUDA_HALF2 - return half2_t(__floats2half2_rn(__low2float(a.half2_) * __low2float(b.half2_), - __high2float(a.half2_) * __high2float(b.half2_))); -#else - return half2_t(a.half_t2[0] * b.half_t2[0], a.half_t2[1] * b.half_t2[1]); -#endif -} -/*! \brief overloaded / operator for half2_t */ -MSHADOW_XINLINE half2_t operator/(half2_t a, half2_t b) { -#if MSHADOW_CUDA_HALF2 - return half2_t(__floats2half2_rn(__low2float(a.half2_) / __low2float(b.half2_), - __high2float(a.half2_) / __high2float(b.half2_))); -#else - return half2_t(a.half_t2[0] / b.half_t2[0], a.half_t2[1] / b.half_t2[1]); -#endif -} -/*! \brief overloaded % operator for half2_t */ -MSHADOW_XINLINE half2_t operator%(half2_t a, half2_t b) { -#if MSHADOW_CUDA_HALF2 - return half2_t(__floats2half2_rn(::fmod(__low2float(a.half2_), __low2float(b.half2_)), - ::fmod(__high2float(a.half2_), __high2float(b.half2_)))); -#else - return half2_t(::fmod(a.half_t2[0], b.half_t2[0]), ::fmod(a.half_t2[1], b.half_t2[1])); -#endif -} -/*! 
\brief overloaded == operator for half2_t */ -MSHADOW_XINLINE bool operator==(half2_t a, half2_t b) { -#if MSHADOW_CUDA_HALF2 - return __hbeq2(a.half2_, b.half2_); -#else - return (a.half_t2[0] == b.half_t2[0] && a.half_t2[1] == b.half_t2[1]); -#endif -} - -} // namespace half -} // namespace mshadow -#endif // MSHADOW_HALF2_H_ diff --git a/CMakeLists.txt b/CMakeLists.txt index b494fdd75ae7..814c8c99f65e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,7 +79,6 @@ option(USE_MXNET_LIB_NAMING "Use MXNet library naming conventions." ON) option(USE_GPROF "Compile with gprof (profiling) flag" OFF) option(USE_VTUNE "Enable use of Intel Amplifier XE (VTune)" OFF) # one could set VTUNE_ROOT for search path option(USE_TVM_OP "Enable use of TVM operator build system." OFF) -option(ENABLE_CUDA_RTC "Build with CUDA runtime compilation support" ON) option(BUILD_CPP_EXAMPLES "Build cpp examples" ON) option(INSTALL_EXAMPLES "Install the example source files." OFF) option(USE_SIGNAL_HANDLER "Print stack traces on segfaults." ON) @@ -547,18 +546,11 @@ if(USE_CUDA) string(REPLACE ";" " " CUDA_ARCH_FLAGS_SPACES "${CUDA_ARCH_FLAGS}") - find_package(CUDAToolkit REQUIRED cublas cufft cusolver curand - OPTIONAL_COMPONENTS nvToolsExt nvrtc) + find_package(CUDAToolkit REQUIRED cublas cufft cusolver curand nvrtc cuda_driver + OPTIONAL_COMPONENTS nvToolsExt) - list(APPEND mxnet_LINKER_LIBS CUDA::cudart CUDA::cublas CUDA::cufft CUDA::cusolver CUDA::curand) - if(ENABLE_CUDA_RTC) - if(CUDA_nvrtc_LIBRARY) - list(APPEND mxnet_LINKER_LIBS CUDA::nvrtc cuda) - add_definitions(-DMXNET_ENABLE_CUDA_RTC=1) - else() - message(FATAL_ERROR "ENABLE_CUDA_RTC=ON, but failed to find NVRTC. CMake will exit." ) - endif() - endif() + list(APPEND mxnet_LINKER_LIBS CUDA::cudart CUDA::cublas CUDA::cufft CUDA::cusolver CUDA::curand + CUDA::nvrtc CUDA::cuda_driver) list(APPEND SOURCE ${CUDA}) add_definitions(-DMXNET_USE_CUDA=1) diff --git a/ci/build_windows.py b/ci/build_windows.py index c8d3af515b5a..0a195b50f77a 100755 --- a/ci/build_windows.py +++ b/ci/build_windows.py @@ -61,7 +61,6 @@ class BuildFlavour(Enum): '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=OFF ' '-DUSE_CUDNN=OFF ' - '-DENABLE_CUDA_RTC=OFF ' '-DUSE_OPENCV=ON ' '-DUSE_OPENMP=ON ' '-DUSE_BLAS=open ' @@ -76,7 +75,6 @@ class BuildFlavour(Enum): '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=OFF ' '-DUSE_CUDNN=OFF ' - '-DENABLE_CUDA_RTC=OFF ' '-DUSE_OPENCV=ON ' '-DUSE_OPENMP=ON ' '-DUSE_BLAS=open ' @@ -91,7 +89,6 @@ class BuildFlavour(Enum): '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=OFF ' '-DUSE_CUDNN=OFF ' - '-DENABLE_CUDA_RTC=OFF ' '-DUSE_OPENCV=ON ' '-DUSE_OPENMP=ON ' '-DUSE_BLAS=mkl ' @@ -106,7 +103,6 @@ class BuildFlavour(Enum): '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=OFF ' '-DUSE_CUDNN=OFF ' - '-DENABLE_CUDA_RTC=OFF ' '-DUSE_OPENCV=ON ' '-DUSE_OPENMP=ON ' '-DUSE_BLAS=mkl ' @@ -121,7 +117,6 @@ class BuildFlavour(Enum): '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=ON ' '-DUSE_CUDNN=ON ' - '-DENABLE_CUDA_RTC=ON ' '-DUSE_OPENCV=ON ' '-DUSE_OPENMP=ON ' '-DUSE_BLAS=open ' @@ -136,7 +131,6 @@ class BuildFlavour(Enum): '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=ON ' '-DUSE_CUDNN=ON ' - '-DENABLE_CUDA_RTC=ON ' '-DUSE_OPENCV=ON ' '-DUSE_OPENMP=ON ' '-DUSE_BLAS=open ' diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index c0c1cd28b472..5bdd01e1b62f 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -142,7 +142,6 @@ build_jetson() { -DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE} \ -DUSE_CUDA=ON \ -DMXNET_CUDA_ARCH="5.2" \ - -DENABLE_CUDA_RTC=OFF 
\ -DUSE_OPENCV=OFF \ -DUSE_OPENMP=ON \ -DUSE_LAPACK=OFF \ @@ -670,27 +669,6 @@ build_ubuntu_gpu_cmake() { ninja } -build_ubuntu_gpu_cmake_no_rtc() { - set -ex - cd /work/build - CC=gcc-7 CXX=g++-7 cmake \ - -DUSE_SIGNAL_HANDLER=ON \ - -DUSE_CUDA=ON \ - -DUSE_CUDNN=ON \ - -DUSE_MKL_IF_AVAILABLE=OFF \ - -DUSE_MKLML_MKL=OFF \ - -DUSE_MKLDNN=ON \ - -DUSE_DIST_KVSTORE=ON \ - -DCMAKE_BUILD_TYPE=Release \ - -DMXNET_CUDA_ARCH="$CI_CMAKE_CUDA_ARCH" \ - -DBUILD_CYTHON_MODULES=1 \ - -DENABLE_CUDA_RTC=OFF \ - -G Ninja \ - /work/mxnet - - ninja -} - build_ubuntu_cpu_large_tensor() { set -ex cd /work/build diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index ab2adbfb5346..85257b90c54a 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -258,20 +258,6 @@ def compile_unix_cmake_gpu(lib_name) { }] } -def compile_unix_cmake_gpu_no_rtc(lib_name) { - return ['GPU: CMake CUDA RTC OFF': { - node(NODE_LINUX_CPU) { - ws('workspace/build-cmake-gpu-no-rtc') { - timeout(time: max_time, unit: 'MINUTES') { - utils.init_git() - utils.docker_run('ubuntu_gpu_cu101', 'build_ubuntu_gpu_cmake_no_rtc', false) - utils.pack_lib(lib_name, mx_cmake_lib) - } - } - } - }] -} - def compile_unix_tensorrt_gpu(lib_name) { return ['TensorRT': { node(NODE_LINUX_CPU) { diff --git a/ci/jenkins/Jenkinsfile_unix_gpu b/ci/jenkins/Jenkinsfile_unix_gpu index 1fcdc96f46ac..6fbdc751ea90 100644 --- a/ci/jenkins/Jenkinsfile_unix_gpu +++ b/ci/jenkins/Jenkinsfile_unix_gpu @@ -41,7 +41,6 @@ core_logic: { custom_steps.compile_unix_cmake_gpu('cmake_gpu'), custom_steps.compile_unix_tensorrt_gpu('tensorrt'), custom_steps.compile_unix_int64_gpu('gpu_int64'), - custom_steps.compile_unix_cmake_gpu_no_rtc('gpu_no_rtc'), ]) utils.parallel_stage('Tests', [ diff --git a/config/darwin.cmake b/config/darwin.cmake index a65509f0ba1c..59f031e49f01 100644 --- a/config/darwin.cmake +++ b/config/darwin.cmake @@ -126,5 +126,4 @@ set(USE_INT64_TENSOR_SIZE OFF CACHE BOOL "Use int64_t to represent the total num # Other GPU features set(USE_NCCL "Use NVidia NCCL with CUDA" OFF) set(NCCL_ROOT "" CACHE BOOL "NCCL install path. Supports autodetection.") -set(ENABLE_CUDA_RTC ON CACHE BOOL "Build with CUDA runtime compilation support") set(USE_NVTX ON CACHE BOOL "Build with NVTX support") diff --git a/config/linux.cmake b/config/linux.cmake index 84eecc2e9701..ff338231e277 100644 --- a/config/linux.cmake +++ b/config/linux.cmake @@ -125,5 +125,4 @@ set(USE_INT64_TENSOR_SIZE OFF CACHE BOOL "Use int64_t to represent the total num # Other GPU features set(USE_NCCL "Use NVidia NCCL with CUDA" OFF) set(NCCL_ROOT "" CACHE BOOL "NCCL install path. Supports autodetection.") -set(ENABLE_CUDA_RTC ON CACHE BOOL "Build with CUDA runtime compilation support") set(USE_NVTX ON CACHE BOOL "Build with NVTX support") diff --git a/config/linux_gpu.cmake b/config/linux_gpu.cmake index 0dad43332978..442ac6cb3578 100644 --- a/config/linux_gpu.cmake +++ b/config/linux_gpu.cmake @@ -125,5 +125,4 @@ set(USE_INT64_TENSOR_SIZE OFF CACHE BOOL "Use int64_t to represent the total num # Other GPU features set(USE_NCCL "Use NVidia NCCL with CUDA" OFF) set(NCCL_ROOT "" CACHE BOOL "NCCL install path. 
Supports autodetection.") -set(ENABLE_CUDA_RTC ON CACHE BOOL "Build with CUDA runtime compilation support") set(USE_NVTX ON CACHE BOOL "Build with NVTX support") diff --git a/docs/python_docs/python/tutorials/extend/index.rst b/docs/python_docs/python/tutorials/extend/index.rst index aaa243e4a692..d516b52d4dd6 100644 --- a/docs/python_docs/python/tutorials/extend/index.rst +++ b/docs/python_docs/python/tutorials/extend/index.rst @@ -53,6 +53,12 @@ The following tutorials will help you learn how to customize MXNet. How to create new MXNet operators in MXNet's backend using C++. An example custom quadratic function op. + .. card:: + :title: Using runtime compilation (RTC) to write CUDA kernels in MXNet + :link: /api/faq/using_rtc + + How to write CUDA kernels in MXNet using runtime compilation. + .. toctree:: :hidden: @@ -61,3 +67,4 @@ The following tutorials will help you learn how to customize MXNet. * New Operator Creation New Operator in MXNet Backend + Using RTC for CUDA kernels diff --git a/docs/static_site/src/pages/api/faq/add_op_in_backend.md b/docs/static_site/src/pages/api/faq/add_op_in_backend.md index 19e55ec432cf..f8b8a0d8f8b1 100644 --- a/docs/static_site/src/pages/api/faq/add_op_in_backend.md +++ b/docs/static_site/src/pages/api/faq/add_op_in_backend.md @@ -721,3 +721,4 @@ and ## Additional Resources - [Use TensorInspector to Help Debug Operators](./tensor_inspector_tutorial) +- [Use RTC to write CUDA kernels](./using_rtc) diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md index 65ce0b84202b..b28e27b02700 100644 --- a/docs/static_site/src/pages/api/faq/env_var.md +++ b/docs/static_site/src/pages/api/faq/env_var.md @@ -390,10 +390,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`. - It works in Symbolic execution as well as in Gluon models hybridized with ```static_alloc=True``` option. - Only applies to MXNet that has been compiled with CUDA (```pip install mxnet-cuXX``` or built from source with ```USE_CUDA=1```) and running on GPU. -* MXNET_FUSION_VERBOSE +* MXNET_RTC_VERBOSE - Values: 0(false) or 1(true) ```(default=0)``` - - Only applies to MXNet that has been compiled with CUDA and when ```MXNET_USE_FUSION``` option is enabled. - - If this variable is set, MXNet will print the code for fused operators that it generated. + - Only applies to MXNet that has been compiled with CUDA. + - If this variable is set, MXNet will print the code for operators compiled at runtime. * MXNET_ELIMINATE_COMMON_EXPR - Values: 0(false) or 1(true) ```(default=1)``` diff --git a/docs/static_site/src/pages/api/faq/using_rtc.md b/docs/static_site/src/pages/api/faq/using_rtc.md new file mode 100644 index 000000000000..6a772ee3c7f9 --- /dev/null +++ b/docs/static_site/src/pages/api/faq/using_rtc.md @@ -0,0 +1,465 @@ +--- +layout: page_category +title: Using runtime compilation (RTC) to write CUDA kernels in MXNet +category: faq +faq_c: Extend and Contribute to MXNet +question: How do I implement GPU functions in MXNet using RTC? +permalink: /api/faq/using_rtc +--- + + + + + + + + + + + + + + + + + +# Using runtime compilation (RTC) to write CUDA kernels in MXNet + +## Introduction + +CUDA kernel is a function running on the GPU to perform computation. This tutorial assumes the +reader has a basic knowledge about how to write such kernels. + +There are currently 2 typical ways of writing and launching CUDA kernels in MXNet. 
The first one is
+to use the `Kernel<...>::Launch()` API, which is suitable for simple elementwise operations and
+enables writing only a portion of the kernel, leaving the launch mechanism to MXNet. The
+other one is to write a kernel from scratch and launch it using the `<<<...>>>` method from CUDA.
+Starting from MXNet 2.0, there is a third option - runtime compilation (RTC). This differs from the
+previous methods (which use kernels compiled ahead of time), as it compiles the needed kernels
+while the user script is running.
+
+In this tutorial we will cover the reasons for using RTC instead of the other methods, show how to
+do it, and give tips on what to keep in mind when doing it.
+
+## Why RTC?
+
+### Problems with kernels compiled ahead of time
+
+The use of kernels compiled ahead of time in MXNet leads to a few problems, which unfortunately
+are mostly invisible in any single PR, but grow over the course of many contributions and result in
+serious issues.
+
+In order to understand them, let us look at the typical way kernels are launched in MXNet. This
+example shows a launch of a simple kernel, which takes a single input of type `DType` and produces
+a single output of type `OType`:
+
+```cpp
+MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    Kernel<...>::Launch(s, inputs[0].dptr<DType>(), outputs[0].dptr<OType>());
+  });
+});
+```
+
+This launch mechanism uses the `MSHADOW_TYPE_SWITCH` macro, which produces a version of the kernel
+for every possible type. When such switches are nested (as in the example shown), it
+produces a version of the kernel for every combination of types. This results in a large number of
+kernels being generated.
+
+Another factor that multiplies the number of kernels is that different GPU architectures require
+different compiled binaries. Therefore, for MXNet to support all of them with a single binary, that
+binary needs to contain a copy of each kernel for every architecture.
+
+This proliferation of CUDA kernels in the binary leads to multiple issues. The first problem is the
+size of the MXNet library - each compiled version of the kernel takes some space in the binary,
+which is small on its own but is multiplied by the number of versions (which could reach thousands
+per GPU architecture) and the number of GPU architectures. This increase in size led to multiple
+issues reported with distribution of the MXNet package,
+[building the library](https://github.com/apache/incubator-mxnet/issues/17045) as well as
+[limiting the number of architectures natively
+supported](https://github.com/apache/incubator-mxnet/pull/18205).
+
+The second issue is the "idle" memory consumption of the MXNet library. In order to efficiently
+launch kernels when they are called, the CUDA driver needs to transfer them to GPU memory ahead of
+time. Since it cannot anticipate which kernels will actually be used, all of the kernels are
+transferred when the CUDA context is created on a GPU. This means that, even if a user never uses,
+e.g., a kernel which adds `int8` and `float16` tensors, that kernel still occupies memory on their
+GPU, reducing the amount of memory available for useful work.
+
+The third issue, mostly affecting MXNet developers, is the compilation time of the MXNet library.
+The more kernel versions need to be compiled, the more time and hardware resources are needed.
+
+### RTC to the rescue!
+
+All of the issues mentioned in the previous section are solved when using runtime compilation.
+Using this paradigm, only the kernels actually invoked in the user script are compiled. They do not
+occupy space in the MXNet binary and no unused kernels are stored in users' GPU memory.
+
+RTC also enables more features:
+
+ - using more information about the specific usage of the kernel when compiling it (e.g. the shape
+   information of the inputs) to optimize it better
+ - writing kernels accepting any combination of input and output types
+ - (in the future) fusing more operations into the generated kernels.
+
+## RTC for kernel developers
+
+### Example: unary operators
+
+Let us start with an example of a simple kernel written using RTC: a kernel which performs a unary
+operation (with sigmoid as the concrete example) on its input. It is not a toy example though: it is
+a fully generic kernel, capable of operating on any combination of input and output types, as well
+as applying any unary operator:
+
+```cpp
+struct UnaryRTCCompute {
+  std::string OP;
+
+  void operator()(const nnvm::NodeAttrs& attrs,
+                  const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs);
+};
+
+const char unary_kernel_fwd[] = R"code(
+
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void unary_kernel(const InputType* input,
+                             OutputType* output,
+                             const index_t N) {
+  using IType = AccType<InputType>;
+  using OType = AccType<OutputType>;
+
+  for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+       tid < N;
+       tid += gridDim.x * blockDim.x) {
+    const auto value = IType::from(input[tid]);
+    const auto temp = OP(value);  // enables returning different type
+
+    if (req == OpReqType::kAddTo) {
+      // temp2 may have a wider type than either temp
+      // or OType
+      const auto temp2 = op::add(temp, OType::from(output[tid]));
+      output[tid] = OType::to(temp2);
+    } else {
+      output[tid] = OType::to(temp);
+    }
+  }
+}
+
+)code";
+
+void UnaryRTCCompute::operator()(const nnvm::NodeAttrs& attrs,
+                                 const OpContext& ctx,
+                                 const std::vector<TBlob>& inputs,
+                                 const std::vector<OpReqType>& req,
+                                 const std::vector<TBlob>& outputs) {
+  using namespace mxnet::common::cuda::rtc;
+  if (req[0] == kNullOp) return;
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+
+  const std::string code = std::string("const OpReqType req = ") +
+                           util::to_string(req[0]) +
+                           ";\n"
+                           "#define OP op::" +
+                           OP +
+                           "\n"
+                           "using InputType = " +
+                           common::mshadow_type_info(inputs[0].type_flag_).name +
+                           ";\n"
+                           "using OutputType = " +
+                           common::mshadow_type_info(outputs[0].type_flag_).name +
+                           ";\n";
+
+  std::vector<const void*> args;
+  const index_t size = outputs[0].Size();
+  args.emplace_back(&(inputs[0].dptr_));
+  args.emplace_back(&(outputs[0].dptr_));
+  args.emplace_back(&size);
+
+  auto kernel = get_function(code, "unary_kernel", unary_kernel_fwd,
+                             ctx.run_ctx.get_ctx().dev_id);
+
+  const int n_threads = 512;
+  const index_t n_blocks = (size + n_threads - 1) / n_threads;
+  const int shared_memory_size = 0;
+  launch(kernel, {n_blocks, 1, 1}, {n_threads, 1, 1},
+         shared_memory_size, s, &args);
+}
+
+NNVM_REGISTER_OP(sigmoid)
+.set_attr<FCompute>("FCompute<gpu>", UnaryRTCCompute{"sigmoid"});
+```
+
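+The same functor can be reused for other elementwise operators simply by passing a different
+operator name. For instance, assuming the RTC function headers also provide an `op::relu`
+definition, a hypothetical registration could look like:
+
+```cpp
+NNVM_REGISTER_OP(relu)
+.set_attr<FCompute>("FCompute<gpu>", UnaryRTCCompute{"relu"});
+```
+
+Only the name passed to `UnaryRTCCompute` changes; the kernel source string and the launch logic
+stay exactly the same.
+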
+### Kernels are text...
+
+The main difference when writing kernels using RTC is that the kernel code becomes a text string.
+This means that it is possible to change or compose the code at runtime, as is done here:
+
+```cpp
+  const std::string code = std::string("const OpReqType req = ") +
+                           util::to_string(req[0]) +
+                           ";\n"
+                           "#define OP op::" +
+                           OP +
+                           "\n"
+                           "using InputType = " +
+                           common::mshadow_type_info(inputs[0].type_flag_).name +
+                           ";\n"
+                           "using OutputType = " +
+                           common::mshadow_type_info(outputs[0].type_flag_).name +
+                           ";\n";
+```
+
+where the operation `OP` is also provided as a string in the operator declaration:
+
+```cpp
+NNVM_REGISTER_OP(sigmoid)
+.set_attr<FCompute>("FCompute<gpu>", UnaryRTCCompute{"sigmoid"});
+```
+
+### and do not know MXNet source code
+
+How does the kernel know what operation it should perform? The kernel's source code uses `OP`,
+which shows up in the `code` variable and is equal to `op::sigmoid`. Let us compare this to how the
+same operator is defined for CPU:
+
+```cpp
+MXNET_OPERATOR_REGISTER_UNARY(sigmoid)
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::sigmoid>)
+```
+
+Since the kernel is compiled at runtime, it does not have access to the rest of the MXNet source
+code, including `mshadow_op.h`, which defines `mshadow_op::sigmoid`. This means that we need to
+provide the kernel with definitions of those functions (again, in text string form). Every
+RTC-compiled kernel is prepended with a common header, built from strings found in the
+`src/common/cuda/rtc/` directory. The `src/common/cuda/rtc/forward_functions-inl.h` file contains
+the definition of `op::sigmoid`:
+
+```cpp
+template <typename DType>
+__device__ inline DType sigmoid(const DType val) {
+  if (type_util::has_double_or_integral<DType>::value) {
+    return 1./(1 + ::exp(-val));
+  } else {
+    return 1.f/(1 + expf(-val));
+  }
+}
+```
+
+### Handling of data types
+
+MXNet has support for many datatypes. Some of those datatypes, like `float16`, `int8` or `bool`, are
+useful when storing the results, but in many computations they are too limiting, as they can easily
+overflow in the intermediate stages. That is why in the example we use the `AccType<DType>` class -
+it provides an accumulation type that is potentially larger than the storage type - for example,
+`AccType<float16>::type` is `float32`. It also provides special loading and storing functions:
+`AccType<DType>::from()` and `AccType<DType>::to()`.
+
+One of the features of RTC-enabled kernels is the ability to accommodate any combination of the
+input and output datatypes. Using `auto` as the output type of the intermediate steps helps with
+this, especially since many binary operators return a mixed type:
+
+```cpp
+template <typename DType, typename DType2>
+__device__ inline typename type_util::mixed_type<DType, DType2>::type
+add(const DType a, const DType2 b) {
+  return a + b;
+}
+```
+
+`mixed_type<T, U>::type` is a type capable of storing the value of an operation between the 2 types
+`T` and `U` - e.g. `mixed_type<float, double>::type = float64` and `mixed_type<float, float16>::type
+= float32`.
+
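+To make this more concrete, the sketch below shows what such an accumulation-type helper could look
+like. It is a simplified illustration only, not the actual definition shipped in
+`src/common/cuda/rtc/`:
+
+```cpp
+// Simplified sketch: wider types accumulate in themselves, while
+// half-precision storage is accumulated in float.
+template <typename DType>
+struct AccType {
+  using type = DType;
+  __device__ static inline type from(const DType& val) { return val; }
+  __device__ static inline DType to(type val) { return val; }
+};
+
+template <>
+struct AccType<half> {
+  using type = float;
+  __device__ static inline type from(const half& val) { return __half2float(val); }
+  __device__ static inline half to(type val) { return __float2half(val); }
+};
+```
+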
+### Compiling and launching RTC kernels
+
+The kernel code stored in `unary_kernel_fwd` is generic and relies on multiple names to be defined,
+like `req`, `OP` or `InputType`. This is handled in the specific operator using the kernel by
+defining a set of parameters that will be concatenated with the code during compilation:
+
+```cpp
+  const std::string code = std::string("const OpReqType req = ") +
+                           util::to_string(req[0]) +
+                           ";\n"
+                           "#define OP op::" +
+                           OP +
+                           "\n"
+                           "using InputType = " +
+                           common::mshadow_type_info(inputs[0].type_flag_).name +
+                           ";\n"
+                           "using OutputType = " +
+                           common::mshadow_type_info(outputs[0].type_flag_).name +
+                           ";\n";
+```
+
+In order to compile the kernel, the `mxnet::common::cuda::rtc::get_function` method is used:
+
+```cpp
+  auto kernel = get_function(code, "unary_kernel", unary_kernel_fwd,
+                             ctx.run_ctx.get_ctx().dev_id);
+```
+
+In order to eliminate the overhead of compilation, it uses a cache of kernels, with the key
+being the name of the kernel (`"unary_kernel"` in our case) and the set of parameters (`code` in our
+case). If the kernel is already in the cache, it is returned; otherwise compilation takes place. If
+compilation fails, the full source code is saved to disk and an MXNet error containing the
+compilation log is raised.
+
+To launch the kernel, the `mxnet::common::cuda::rtc::launch` method is used:
+
+```cpp
+  launch(kernel, {n_blocks, 1, 1}, {n_threads, 1, 1},
+         shared_memory_size, s, &args);
+```
+
+It takes the kernel object, grid and block dimensions, the size of dynamic shared memory, the
+stream and the kernel parameters.
+
+## Other features enabled by RTC
+
+### Vectorization
+
+The actual kernel used to apply unary operators in MXNet looks slightly different from the simple
+example shown in the previous section. The differences come from the use of vectorization.
+This means that, instead of reading (or writing) one element at a time, the kernel accesses
+multiple array elements at once. This is beneficial especially when dealing with smaller
+types like `float16` or `int8`. Accessing those small types one by one is inefficient and does not
+saturate the memory bandwidth of the GPU, so using vector accesses improves the achieved memory
+bandwidth.
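+
+To illustrate the general idea independently of MXNet's helper classes, the standalone sketch below
+scales an array of `half` values, loading and storing them two at a time through CUDA's `__half2`
+vector type (the kernel name and the scaling operation are chosen purely for illustration):
+
+```cpp
+#include <cuda_fp16.h>
+
+// Each thread handles pairs of half values: one 32-bit load and one 32-bit
+// store per pair instead of two 16-bit transactions each way.
+__global__ void scale_half2(const __half2* __restrict__ in, __half2* out,
+                            int n_pairs, float alpha) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x;
+       i < n_pairs;
+       i += gridDim.x * blockDim.x) {
+    float2 v = __half22float2(in[i]);   // vectorized load of 2 half values
+    v.x *= alpha;
+    v.y *= alpha;
+    out[i] = __float22half2_rn(v);      // vectorized store of 2 half values
+  }
+}
+```
+
+MXNet's RTC implementation wraps this access pattern in reusable helper classes, as the actual
+unary kernel below shows: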
+ +```cpp + +// excerpt from src/operator/tensor/elemwise_unary_op.h +struct UnaryRTCCompute { + std::string OP; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +// excerpt from src/operator/tensor/elemwise_unary_op.cc +struct unary_kernel_params { + const void *inputs[1]; + void *outputs[1]; +}; + +const char unary_kernel_fwd[] = R"code( + +struct unary_kernel_params { + const void *inputs[1]; + void *outputs[1]; +}; + +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void unary_kernel(const unary_kernel_params params, + const index_t lead_dim, + const index_t other_dim, + const index_t N, + const index_t num_aligned_elements) { + using namespace vector; + VectorizedLoader loader( + reinterpret_cast(params.inputs[0]), N); + VectorizedStorer storer( + reinterpret_cast(params.outputs[0]), N); + + using IType = AccType; + using OType = AccType; + + const index_t M = num_aligned_elements; + + for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + tid < M; + tid += gridDim.x * blockDim.x) { + loader.load(tid, N); + if (req == OpReqType::kAddTo) { + storer.load(tid, N); + } +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const auto input = IType::from(loader.separate()[i]); + const auto temp = OP(input); // enables returning different type + + if (req == OpReqType::kAddTo) { + // temp2 may have a wider type than either temp + // or OType + const auto temp2 = op::add(temp, OType::from(storer.separate()[i])); + storer.separate()[i] = OType::to(temp2); + } else { + storer.separate()[i] = OType::to(temp); + } + } + storer.store(tid, N); + } +} + +)code"; + +void UnaryRTCCompute::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet::common::cuda::rtc; + if (req[0] == kNullOp) return; + mshadow::Stream* s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + + const std::string code = std::string("const OpReqType req = ") + + util::to_string(req[0]) + + ";\n" + "#define OP op::" + + OP + + "\n"; + const int nvec = outputs[0].type_flag_ == mshadow::kFloat64 ? 2 : 4; + + const index_t size = outputs[0].Size(); + unary_kernel_params params = { {inputs[0].dptr_}, + {outputs[0].dptr_} }; + + VectorizedKernelRTCLauncher(code, "unary_kernel", + unary_kernel_fwd, nvec, + size, 1, s, params, + inputs, outputs, + ctx.run_ctx.get_ctx().dev_id); +} + +// excerpt from src/operator/tensor/elemwise_unary_op_basic.cu +NNVM_REGISTER_OP(sigmoid) +.set_attr("FCompute", UnaryRTCCompute{"sigmoid"}); +``` + +RTC implementation in MXNet provides a few useful helper functions and classes, which simplify the +process of writing and launching kernels using vectorization. For accessing the memory using +vectorization, 2 classes are provided, used in this kernel to access input and output array: + +```cpp + VectorizedLoader loader( + reinterpret_cast(params.inputs[0]), N); + VectorizedStorer storer( + reinterpret_cast(params.outputs[0]), N); +``` + +The `loader` object accesses `params.inputs[0]` pointer to array of N elements having type +`InputType0` (which is the name assigned to the type of the first input by the +`VectorizedKernelRTCLauncher`, which is the helper launcher function). It loads `nvec` elements at +a time and has additional `aligned` option, which is also set by the `VectorizedKernelRTCLauncher`. 
+Similarly `storer` object is used to write data of type `OutputType0` to `params.outputs[0]`. + +The kernel using `VectorizedKernelRTCLauncher` needs to have specific parameters: + +```cpp +__global__ void unary_kernel(const unary_kernel_params params, // kernel-specific parameters + const index_t lead_dim, // lead dimension of the tensor + const index_t other_dim, // size of the other dimensions + const index_t N, // total number of elements + const index_t num_aligned_elements) { // number of vector elements in + // lead dimension +``` diff --git a/include/mxnet/libinfo.h b/include/mxnet/libinfo.h index dd7790059de1..1aa6b9ef6b21 100644 --- a/include/mxnet/libinfo.h +++ b/include/mxnet/libinfo.h @@ -70,10 +70,6 @@ #define MXNET_USE_CUSOLVER MSHADOW_USE_CUSOLVER #endif -#ifndef MXNET_ENABLE_CUDA_RTC -#define MXNET_ENABLE_CUDA_RTC 0 -#endif - /*! \brief Error message for using gpu when MXNET_USE_CUDA==0 */ #define MXNET_GPU_NOT_ENABLED_ERROR "GPU is not enabled" @@ -142,7 +138,6 @@ enum : unsigned { CUDA = 0, CUDNN, NCCL, - CUDA_RTC, TENSORRT, // CPU Features / optimizations diff --git a/include/mxnet/rtc.h b/include/mxnet/rtc.h index 76c3064db71a..747c0b5c94ab 100644 --- a/include/mxnet/rtc.h +++ b/include/mxnet/rtc.h @@ -20,7 +20,7 @@ #ifndef MXNET_RTC_H_ #define MXNET_RTC_H_ #include "./base.h" -#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#if MXNET_USE_CUDA #include #include @@ -132,5 +132,5 @@ class CudaModule { } // namespace rtc } // namespace mxnet -#endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#endif // MXNET_USE_CUDA #endif // MXNET_RTC_H_ diff --git a/python/mxnet/contrib/amp/lists/symbol_fp16.py b/python/mxnet/contrib/amp/lists/symbol_fp16.py index 506757307352..5858eb9ff092 100644 --- a/python/mxnet/contrib/amp/lists/symbol_fp16.py +++ b/python/mxnet/contrib/amp/lists/symbol_fp16.py @@ -184,8 +184,6 @@ '_sample_poisson', '_sample_uniform', '_sample_unique_zipfian', - '_scatter_minus_scalar', - '_scatter_plus_scalar', '_scatter_set_nd', '_set_value', '_shuffle', @@ -508,7 +506,6 @@ '_Mul', '_Div', '_div', - '_scatter_elemwise_div', '_Mod', '_Not_Equal', '_Equal', diff --git a/python/mxnet/runtime.py b/python/mxnet/runtime.py index 28525ae65edf..a8cac42fb7c7 100644 --- a/python/mxnet/runtime.py +++ b/python/mxnet/runtime.py @@ -37,7 +37,7 @@ True print(features) - [✖ CUDA, ✖ CUDNN, ✖ NCCL, ✖ CUDA_RTC, ✖ TENSORRT, ✔ CPU_SSE, ✔ CPU_SSE2, ✔ CPU_SSE3, + [✖ CUDA, ✖ CUDNN, ✖ NCCL, ✖ TENSORRT, ✔ CPU_SSE, ✔ CPU_SSE2, ✔ CPU_SSE3, ✔ CPU_SSE4_1, ✔ CPU_SSE4_2, ✖ CPU_SSE4A, ✔ CPU_AVX, ✖ CPU_AVX2, ✔ OPENMP, ✖ SSE, ✔ F16C, ✔ JEMALLOC, ✔ BLAS_OPEN, ✖ BLAS_ATLAS, ✖ BLAS_MKL, ✖ BLAS_APPLE, ✔ LAPACK, ✖ MKLDNN, ✔ OPENCV, ✖ DIST_KVSTORE, ✖ INT64_TENSOR_SIZE, ✔ SIGNAL_HANDLER, ✔ DEBUG, ✖ TVM_OP] diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 23049f1b8867..faa030dcb459 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -3248,24 +3248,24 @@ int MXRtcCudaModuleCreate(const char* source, int num_options, const char** options, int num_exports, const char** exports, CudaModuleHandle *out) { API_BEGIN(); -#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#if MXNET_USE_CUDA std::vector str_opts; for (int i = 0; i < num_options; ++i) str_opts.emplace_back(options[i]); std::vector str_exports; for (int i = 0; i < num_exports; ++i) str_exports.emplace_back(exports[i]); *out = new rtc::CudaModule(source, str_opts, str_exports); #else - LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation."; + LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA 
runtime compilation."; #endif API_END(); } int MXRtcCudaModuleFree(CudaModuleHandle handle) { API_BEGIN(); -#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#if MXNET_USE_CUDA delete reinterpret_cast(handle); #else - LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation."; + LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation."; #endif API_END(); } @@ -3274,7 +3274,7 @@ int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char* name, int num_arg int* is_ndarray, int* is_const, int* arg_types, CudaKernelHandle *out) { API_BEGIN(); -#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#if MXNET_USE_CUDA auto module = reinterpret_cast(handle); std::vector signature; for (int i = 0; i < num_args; ++i) { @@ -3285,17 +3285,17 @@ int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char* name, int num_arg auto kernel = module->GetKernel(name, signature); *out = new std::shared_ptr(kernel); #else - LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation."; + LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation."; #endif API_END(); } int MXRtcCudaKernelFree(CudaKernelHandle handle) { API_BEGIN(); -#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#if MXNET_USE_CUDA delete reinterpret_cast*>(handle); #else - LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation."; + LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation."; #endif API_END(); } @@ -3306,7 +3306,7 @@ int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** args, uint32_t block_dim_y, uint32_t block_dim_z, uint32_t shared_mem) { API_BEGIN(); -#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#if MXNET_USE_CUDA auto kernel = reinterpret_cast*>(handle); const auto& signature = (*kernel)->signature(); std::vector any_args; @@ -3322,7 +3322,7 @@ int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** args, (*kernel)->Launch(Context::GPU(dev_id), any_args, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y, block_dim_z, shared_mem); #else - LOG(FATAL) << "Compile with USE_CUDA=1 and ENABLE_CUDA_RTC=1 to have CUDA runtime compilation."; + LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation."; #endif API_END(); } diff --git a/src/common/cuda/rtc.cc b/src/common/cuda/rtc.cc new file mode 100644 index 000000000000..8f3b3391f5e4 --- /dev/null +++ b/src/common/cuda/rtc.cc @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "mxnet/base.h" + +#if MXNET_USE_CUDA + +#include + +#include +#include +#include +#include +#include + +#include "rtc.h" +#include "rtc/half-inl.h" +#include "rtc/util-inl.h" +#include "rtc/forward_functions-inl.h" +#include "rtc/backward_functions-inl.h" +#include "rtc/vectorization-inl.h" +#include "rtc/special_functions-inl.h" +#include "rtc/reducer-inl.h" +#include "utils.h" + + +namespace mxnet { +namespace common { +namespace cuda { +namespace rtc { + +std::mutex lock; + +namespace util { + +std::string to_string(OpReqType req) { + switch (req) { + case kNullOp: + return "OpReqType::kNullOp"; + case kWriteTo: + case kWriteInplace: + return "OpReqType::kWriteTo"; + case kAddTo: + return "OpReqType::kAddTo"; + } + LOG(FATAL) << "Unrecognized req."; + return ""; +} + +} // namespace util + +namespace { + +// Obtain compilation log from the program. +std::string GetCompileLog(nvrtcProgram program) { + size_t log_size_including_null; + NVRTC_CALL(nvrtcGetProgramLogSize(program, &log_size_including_null)); + std::string log(log_size_including_null - 1, '\0'); + // Room for terminating null character ensured since C++11 + NVRTC_CALL(nvrtcGetProgramLog(program, &log[0])); + return log; +} + +// Obtain compilation result (ptx assembly) from the program. +std::string GetPtx(nvrtcProgram program) { + size_t ptx_size_including_null; + NVRTC_CALL(nvrtcGetPTXSize(program, &ptx_size_including_null)); + std::string ptx(ptx_size_including_null - 1, '\0'); + // Room for terminating null character ensured since C++11 + NVRTC_CALL(nvrtcGetPTX(program, &ptx[0])); + return ptx; +} + +} // namespace + +CUfunction get_function(const std::string ¶meters, + const std::string &kernel_name, + const std::string &code, + int dev_id) { + constexpr int CACHESIZE_WARN_THRESHOLD = 10000; + std::lock_guard l(lock); + // Local class for value type of compile cache + struct KernelInfo { + std::string mangled_name; + std::string ptx; + std::vector functions; + }; + // Maps from the kernel name and parameters to the ptx and jit-compiled CUfunctions. + using KernelCache = std::unordered_map; + // Per-gpu-architecture compiled kernel cache with jit-compiled function for each device context + static std::unordered_map compiled_kernels; + int sm_arch = SMArch(dev_id); + // make null map as needed + KernelCache& compiled_kernels_this_arch = compiled_kernels[sm_arch]; + // make KernelInfo as needed + KernelInfo& kinfo = compiled_kernels_this_arch[parameters + kernel_name]; + if (kinfo.ptx.size() == 0) { + // It's the first time we've seen this kernel, so we need to generate the ptx and mangled_name. + static std::string common_header = + std::string(fp16_support_string) + "\n" + + type_support_string + "\n" + + util_string + "\n" + + special_functions_definitions + '\n' + + vectorization_support_string + "\n" + + function_definitions_util + "\n" + + function_definitions_binary + "\n" + + function_definitions_unary + "\n" + + backward_function_definitions + "\n" + + reducer + "\n"; + std::string code_with_header = common_header + parameters + code; + // If verbose mode, output kernel source, though not including the common header + if (dmlc::GetEnv("MXNET_RTC_VERBOSE", false)) { + LOG(INFO) << "\n" << std::string(80, '-') << "\n" << (parameters + code); + } + if (compiled_kernels_this_arch.size() == CACHESIZE_WARN_THRESHOLD + 1 && + dmlc::GetEnv("MXNET_RTC_SIZE_WARNING", true)) { + LOG(WARNING) << "The number of different compiled kernels exceeds " + << CACHESIZE_WARN_THRESHOLD + << ". 
Set MXNET_RTC_SIZE_WARNING=0 to quiet this warning."; + } + nvrtcProgram program; + NVRTC_CALL(nvrtcCreateProgram(&program, // prog + &code_with_header[0], // buffer + (kernel_name + "_kernel.cu").c_str(), // name + 0, // num headers + nullptr, // headers + nullptr)); // include names + + std::string gpu_arch_arg = "--gpu-architecture=compute_" + std::to_string(sm_arch); + const char *opts[] = {gpu_arch_arg.c_str(), +#if NDEBUG == 0 + "-G", +#endif + "--std=c++14"}; + const std::string kernel_name_demangled = kernel_name; + NVRTC_CALL(nvrtcAddNameExpression(program, (kernel_name_demangled).c_str())); + + nvrtcResult compileResult = nvrtcCompileProgram(program, // prog + sizeof(opts) / sizeof(opts[0]), // num options + opts); // options + static const std::string dump_file = "mxnet_rtc_debug_code.log"; + if (compileResult != NVRTC_SUCCESS) { + std::ofstream f(dump_file); + f << code_with_header; + f.close(); + } + CHECK_EQ(compileResult, NVRTC_SUCCESS) + << "NVRTC Compilation failed.\n" + << "The generated code was stored in " << dump_file << "\n" + << GetCompileLog(program); + + kinfo.ptx = GetPtx(program); + const char *mangled_name; + NVRTC_CALL(nvrtcGetLoweredName(program, + kernel_name_demangled.c_str(), + &mangled_name)); + kinfo.mangled_name = mangled_name; + // Destroy the program. + NVRTC_CALL(nvrtcDestroyProgram(&program)); + } + // Ensure function array is deep enough to index by dev_id + while (kinfo.functions.size() <= static_cast(dev_id)) + kinfo.functions.push_back(static_cast(nullptr)); + // Jit-compile ptx for the device as needed + if (kinfo.functions[dev_id] == static_cast(nullptr)) { + // Make sure driver context is set to the proper device + CUdevice cu_device; + CUcontext context; + CUDA_DRIVER_CALL(cuDeviceGet(&cu_device, dev_id)); + CUDA_DRIVER_CALL(cuDevicePrimaryCtxRetain(&context, cu_device)); + // Jit-compile ptx for the driver's current context + CUmodule module; + +#if NDEBUG == 0 + intptr_t debug_info = 1; + intptr_t line_info = 1; +#else + intptr_t debug_info = 0; + intptr_t line_info = 0; +#endif + + CUjit_option jit_opts[] = {CU_JIT_GENERATE_DEBUG_INFO, CU_JIT_GENERATE_LINE_INFO}; + void* jit_opt_values[] = {reinterpret_cast(debug_info), + reinterpret_cast(line_info)}; + + CUDA_DRIVER_CALL(cuModuleLoadDataEx(&module, kinfo.ptx.c_str(), 2, jit_opts, jit_opt_values)); + CUDA_DRIVER_CALL(cuModuleGetFunction(&kinfo.functions[dev_id], + module, + kinfo.mangled_name.c_str())); + } + return kinfo.functions[dev_id]; +} + +void launch(CUfunction function, + const dim3 grid_dim, + const dim3 block_dim, + unsigned int shared_mem_bytes, + mshadow::Stream *stream, + std::vector *args) { + CHECK(args->size() != 0) << + "Empty argument list passed to a kernel."; + // CUDA_DRIVER_CALL( + CUresult err = cuLaunchKernel(function, // function to launch + grid_dim.x, grid_dim.y, grid_dim.z, // grid dim + block_dim.x, block_dim.y, block_dim.z, // block dim + shared_mem_bytes, // shared memory + mshadow::Stream::GetStream(stream), // stream + const_cast(args->data()), // arguments + nullptr); // ); + if (err != CUDA_SUCCESS) { + const char* error_string; + cuGetErrorString(err, &error_string); + LOG(FATAL) << "cuLaunchKernel failed: " + << err << " " << error_string << ": " + << reinterpret_cast(function) << " " + << "(" << grid_dim.x << ", " << grid_dim.y << ", " << grid_dim.z << ") " + << "(" << block_dim.x << ", " << block_dim.y << ", " << block_dim.z << ") " + << shared_mem_bytes << " " + << args->size(); + } +} + +} // namespace rtc +} // namespace cuda +} // namespace 
common +} // namespace mxnet + +#endif // MXNET_USE_CUDA diff --git a/src/common/cuda/rtc.h b/src/common/cuda/rtc.h new file mode 100644 index 000000000000..126c967a0cb3 --- /dev/null +++ b/src/common/cuda/rtc.h @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file cuda_rtc.h + * \brief Common CUDA utilities for + * runtime compilation. + */ + +#ifndef MXNET_COMMON_CUDA_RTC_H_ +#define MXNET_COMMON_CUDA_RTC_H_ + +#include "mxnet/base.h" +#include "mxnet/op_attr_types.h" + +#if MXNET_USE_CUDA + +#include +#include + +#include +#include +#include + +namespace mxnet { +namespace common { +namespace cuda { +namespace rtc { + +namespace util { + +/*! \brief Convert OpReqType to string. + * \param req to convert + */ +std::string to_string(OpReqType req); + +} // namespace util + +extern std::mutex lock; + +/*! \brief Compile and get the GPU kernel. Uses cache in order to + * eliminate the overhead of compilation. + * \param parameters of the kernel (e.g. values of the template arguments, types used) + * \param kernel_name name of the kernel + * \param code used for compilation of the kernel if not found in cache + * \param dev_id id of the device which the kernel will be launched on + */ +CUfunction get_function(const std::string ¶meters, + const std::string &kernel_name, + const std::string &code, + int dev_id); + +/*! \brief Launch a GPU kernel. + * \param function to launch + * \param grid_dim grid dimensions + * \param block_dim block dimensions + * \param shared_mem_bytes amount of dynamic shared memory needed by the kernel + * \param stream used for launching the kernel + * \param args arguments of the kernel + */ +void launch(CUfunction function, + const dim3 grid_dim, + const dim3 block_dim, + unsigned int shared_mem_bytes, + mshadow::Stream *stream, + std::vector *args); + +} // namespace rtc +} // namespace cuda +} // namespace common +} // namespace mxnet + +#endif // MXNET_USE_CUDA + +#endif // MXNET_COMMON_CUDA_RTC_H_ diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h new file mode 100644 index 000000000000..168dc686e7ad --- /dev/null +++ b/src/common/cuda/rtc/backward_functions-inl.h @@ -0,0 +1,480 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_COMMON_CUDA_RTC_BACKWARD_FUNCTIONS_INL_H_ +#define MXNET_COMMON_CUDA_RTC_BACKWARD_FUNCTIONS_INL_H_ + +#if MXNET_USE_CUDA + +namespace mxnet { +namespace common { +namespace cuda { +namespace rtc { + +const char backward_function_definitions[] = R"code( + +namespace op { + +template +__device__ inline typename type_util::mixed_type::type +backward_relu(const DTypeGrad grad, const DType val) { + if (isnan(val)) return val; + return val > 0 ? grad : 0; +} + +template +__device__ inline typename type_util::mixed_type::type +backward_sigmoid(const DTypeGrad grad, const DType out) { + return grad * out * (1 - out); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_softrelu(const DTypeGrad grad, const DType val) { + const typename type_util::mixed_type::type v = val; + return grad * sigmoid(v); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_softsign(const DTypeGrad grad, const DType val) { + const typename type_util::mixed_type::type v = val; + const auto ap1 = 1 + op::abs(v); + return grad / (ap1 * ap1); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_abs(const DTypeGrad grad, const DType val) { + const typename type_util::mixed_type::type v = val; + return grad * op::sign(v); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_exp(const DTypeGrad grad, const DType val) { + const typename type_util::mixed_type::type v = val; + return grad * op::exp(v); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_expm1(const DTypeGrad grad, const DType val) { + return backward_exp(grad, val); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_log(const DTypeGrad grad, const DType val) { + return grad / val; +} + +template +__device__ inline typename type_util::mixed_type::type +backward_log10(const DTypeGrad grad, const DType val) { + return grad / (val * op::log(static_cast(10))); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_log2(const DTypeGrad grad, const DType val) { + return grad / (val * op::log(static_cast(2))); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_log1p(const DTypeGrad grad, const DType val) { + return grad / (1 + val); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_sin(const DTypeGrad grad, const DType val) { + const typename type_util::mixed_type::type v = val; + return grad * op::cos(v); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_cos(const DTypeGrad grad, const DType val) { + const typename type_util::mixed_type::type v = val; + return -grad * op::sin(v); +} + +// Uses output from tan +template +__device__ inline typename type_util::mixed_type::type +backward_tan(const DTypeGrad grad, const DType out) { + return grad * (out * out + 1); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_arcsin(const DTypeGrad grad, const DType val) { + const typename type_util::mixed_type::type v = val; + return 
grad / op::sqrt(1 - v*v); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_arccos(const DTypeGrad grad, const DType val) { + const typename type_util::mixed_type::type v = val; + return -grad / op::sqrt(1 - v*v); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_arctan(const DTypeGrad grad, const DType val) { + return grad / (1 + val*val); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_degrees(const DTypeGrad grad, const DType /* val */) { + return op::degrees(grad); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_radians(const DTypeGrad grad, const DType /* val */) { + return op::radians(grad); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_sinh(const DTypeGrad grad, const DType val) { + const typename type_util::mixed_type::type v = val; + return grad * op::cosh(v); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_cosh(const DTypeGrad grad, const DType val) { + const typename type_util::mixed_type::type v = val; + return grad * op::sinh(v); +} + +// Uses tanh output +template +__device__ inline typename type_util::mixed_type::type +backward_tanh(const DTypeGrad grad, const DType out) { + return grad * (1 - out * out); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_arcsinh(const DTypeGrad grad, const DType val) { + const typename type_util::mixed_type::type v = val; + return grad / op::sqrt(v * v + 1); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_arccosh(const DTypeGrad grad, const DType val) { + const typename type_util::mixed_type::type v = val; + return grad / op::sqrt(v * v - 1); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_arctanh(const DTypeGrad grad, const DType val) { + return grad / (1 - val * val); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_sqrt(const DTypeGrad grad, const DType out) { + return 0.5 * grad / out; +} + +template +__device__ inline typename type_util::mixed_type::type +backward_rsqrt(const DTypeGrad grad, const DType val) { + const typename type_util::mixed_type::type v = val; + const auto inv = 1 / v; + return -0.5 * grad * op::sqrt(inv) * inv; +} + +template +__device__ inline typename type_util::mixed_type::type +backward_cbrt(const DTypeGrad grad, const DType out) { + return grad / (3.0f * out * out); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_rcbrt(const DTypeGrad grad, const DType val) { + const typename type_util::mixed_type::type v = val; + const auto inv = 1 / v; + return -1.f/3.f * grad * op::cbrt(inv) * inv; +} + +template +__device__ inline typename type_util::mixed_type::type +backward_square(const DTypeGrad grad, const DType val) { + return 2 * val * grad; +} + +template +__device__ inline typename type_util::mixed_type::type +rdiv_grad(const DType val, + const DType2 val2) { + return -val2 / (val * val); +} + +template +__device__ inline typename type_util::mixed_type::type +div_grad(const DType val, + const DType2 val2) { + const typename type_util::mixed_type::type temp = val2; + return op::reciprocal(temp); +} + +template +__device__ inline DType div_rgrad(const DType val, + const DType2 val2) { + return -val / (val2 * val2); +} + +template +__device__ inline DType mod_grad(const DType val, + const DType2 val2) { + if (type_util::is_integral::value) { + return 0; + } else { + 
return 1; + } +} + +template +__device__ inline DType mod_rgrad(const DType val, + const DType2 val2) { + if (type_util::is_integral::value) { + return 0; + } else { + return -op::floor(val / val2); + } +} + +template +__device__ inline DType rmod_grad(const DType val, + const DType2 val2) { + if (type_util::is_integral::value) { + return 0; + } else { + return -op::floor(val2 / val); + } +} + +template +__device__ inline typename type_util::mixed_type::type +power_grad(const DType val, + const DType2 val2) { + return op::power(val, val2 - 1.f) * val2; +} + +template +__device__ inline typename type_util::mixed_type::type +power_rgrad(const DType val, + const DType2 val2) { + const typename type_util::mixed_type::type temp = val; + return op::power(val, val2) * op::log(temp); +} + +template +__device__ inline typename type_util::mixed_type::type +rpower_grad(const DType val, + const DType2 val2) { + const typename type_util::mixed_type::type temp = val2; + return val * op::log(temp); +} + +template +__device__ inline typename type_util::mixed_type::type +hypot_grad_left(const DType val, + const DType2 val2) { + return val / op::hypot(val, val2); +} + +template +__device__ inline typename type_util::mixed_type::type +hypot_grad_right(const DType val, + const DType2 val2) { + return val2 / op::hypot(val, val2); +} + +template +__device__ inline typename type_util::mixed_type::type +copysign_grad(const DType val, + const DType2 val2) { + return (val >= 0 && val2 >= 0) || (val < 0 && val2 < 0) ? 1 : -1; +} + +template +__device__ inline typename type_util::mixed_type::type +arctan2_grad(const DType val, + const DType2 val2) { + return val2 / (val * val + val2 * val2); +} + +template +__device__ inline typename type_util::mixed_type::type +rarctan2_grad(const DType val, + const DType2 val2) { + return val / (val * val + val2 * val2); +} + +template +__device__ inline typename type_util::mixed_type::type +arctan2_rgrad(const DType val, + const DType2 val2) { + return -rarctan2_grad(val, val2); +} + +template +__device__ inline typename type_util::mixed_type::type +ldexp_grad(const DType val, + const DType2 val2) { + return op::power(static_cast(2), val2); +} + +template +__device__ inline typename type_util::mixed_type::type +rldexp_grad(const DType val, + const DType2 val2) { + using mixed_type = typename type_util::mixed_type::type; + return val2 * op::power(static_cast(2), val) * op::log(static_cast(2)); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_clip(const DTypeGrad grad, const DType val, + const float a_min, const float a_max) { + if (val > a_max || val < a_min) { + return 0; + } else { + return grad; + } +} + +template +__device__ inline typename type_util::mixed_type::type +backward_reciprocal(const DTypeGrad grad, const DType val) { + return -grad / (val * val); +} + +template +__device__ inline typename type_util::mixed_type::type +backward_erf(const DTypeGrad grad, const DType val) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type v = val; + constexpr mixed_type my_pi = pi; + return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad; +} + +template +__device__ inline typename type_util::mixed_type::type +backward_erfinv(const DTypeGrad grad, const DType val) { + using mixed_type = typename type_util::mixed_type::type; + constexpr mixed_type my_pi = pi; + const mixed_type g = grad; + const mixed_type v = val; + return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g; +} + +template +__device__ inline typename 
type_util::mixed_type::type +backward_gamma(const DTypeGrad grad, const DType val) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type v = val; + if (type_util::is_same::value) { + return grad * op::gamma(v) * op::special_functions::cephes::psi(v); + } else { + return grad * op::gamma(v) * op::special_functions::cephes::psi(v); + } +} + +template +__device__ inline typename type_util::mixed_type::type +backward_gammaln(const DTypeGrad grad, const DType val) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type v = val; + if (type_util::is_same::value) { + return grad * op::special_functions::cephes::psi(v); + } else { + return grad * op::special_functions::cephes::psi(v); + } +} + +template +__device__ inline typename type_util::mixed_type::type +backward_digamma(const DTypeGrad grad, const DType val) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type v = val; + if (type_util::is_same::value) { + return grad * op::special_functions::trigamma(v); + } else { + return grad * op::special_functions::trigamma(v); + } +} + +template +__device__ inline typename type_util::mixed_type::type +backward_gelu(const DTypeGrad grad, const DType val) { + return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) + + val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f)); +} + +template +__device__ inline DType smooth_l1_grad(const DType val, const DType2 scalar) { + auto bsq = scalar * scalar; + auto ibsq = 1.0f / bsq; + if (val > ibsq) { + return 1; + } else if (val < -ibsq) { + return -1; + } else { + return bsq * val; + } +} + +template +__device__ inline DType2 xelu_grad(const DType val, + const DType2 val2) { + return (val > 0) ? 1 : val2; +} + +template +__device__ inline DType prelu_grad(const DType val, + const DType2 val2) { + return (val > 0) ? 0 : val; +} + +} // namespace op + +)code"; + +} // namespace rtc +} // namespace cuda +} // namespace common +} // namespace mxnet + +#endif // MXNET_USE_CUDA + +#endif // MXNET_COMMON_CUDA_RTC_BACKWARD_FUNCTIONS_INL_H_ diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h new file mode 100644 index 000000000000..14ee83cd0759 --- /dev/null +++ b/src/common/cuda/rtc/forward_functions-inl.h @@ -0,0 +1,917 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef MXNET_COMMON_CUDA_RTC_FORWARD_FUNCTIONS_INL_H_ +#define MXNET_COMMON_CUDA_RTC_FORWARD_FUNCTIONS_INL_H_ + +#if MXNET_USE_CUDA + +namespace mxnet { +namespace common { +namespace cuda { +namespace rtc { + +const char function_definitions_util[] = R"code( + +#define INT_MAX (2147483647) + +namespace op { + +template +struct LoadType { + using Type = DType; +}; + +template <> +struct LoadType { + using Type = float; +}; + +template +__device__ inline typename LoadType::Type load(const DType input) { + return input; +} + +template <> +__device__ inline float load(const half input) { + return __half2float(input); +} + +template +__device__ inline DType1 store(const DType2 input, DType1* ref) { + return input; +} + +template +__device__ inline half store(const DType input, half* ref) { + return __float2half(input); +} + +template +struct Shape { + int x[ndim]; + size_t size; + __device__ inline const int& operator [](const int i) const { + return x[i]; + } + __device__ inline int& operator [](const int i) { + return x[i]; + } + __device__ inline void set(const int def) { + #pragma unroll + for (int i = 0; i < ndim; i++) { + x[i] = def; + } + } +}; + +template <> +struct Shape<0> { + size_t size; +}; + +template +__device__ inline vector::VectorizedStorage load_index(const DType * input, int i, + const Shape &shape) { + using V = vector::VectorizedStorage; + if (i < shape.size) { + const auto* vector_input = reinterpret_cast(input + i); + return V(*vector_input); + } else { + return V({0}); + } +} + +template +__device__ inline vector::VectorizedStorage global_load_index(const DType * input, + int i, const Shape &shape) { + using V = vector::VectorizedStorage; + if (i < shape.size) { + const auto* vector_input = reinterpret_cast(input + i); + return V(__ldg(vector_input)); + } else { + return V({0}); + } +} + +template +__device__ inline vector::VectorizedStorage load_slice(const DType * input, + const Shape& shape, + Shape begin, + Shape end, + int offset) { + int idx[nvec]; + + Shape ref_strides; + Shape strides; + ref_strides[ndim-1] = 1; + strides[ndim-1] = 1; + #pragma unroll + for (int dim = ndim-1; dim >=0; dim--) { + if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim]; + if (end[dim] < 0) end[dim] = shape[dim] + end[dim]; + if (end[dim] == INT_MAX) end[dim] = shape[dim]; + if (dim > 0) { + ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]); + strides[dim-1] = strides[dim] * shape[dim]; + } + } + #pragma unroll + for (int j = 0; j < nvec; j++) { + idx[j] = 0; + int ref_idx = offset + j; + #pragma unroll + for (int dim = 0; dim < ndim; dim++) { + int stride = ref_strides[dim]; + if (shape[dim] > 1) { + idx[j] += (ref_idx / stride + begin[dim]) * strides[dim]; + } + ref_idx = ref_idx % stride; + } + } + vector::VectorizedStorage ret; + #pragma unroll + for (int j = 0; j < nvec; j++) { + ret.scratch_.separate[j] = *(input + idx[j]); + } + return ret; +} + +template +__device__ inline vector::VectorizedStorage fast_load_slice(const DType * input, + const Shape& shape, + Shape begin, + Shape end, + int offset) { + int idx = 0; + + Shape ref_strides; + Shape strides; + ref_strides[ndim-1] = 1; + strides[ndim-1] = 1; + #pragma unroll + for (int dim = ndim-1; dim >=0; dim--) { + if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim]; + if (end[dim] < 0) end[dim] = shape[dim] + end[dim]; + if (end[dim] == INT_MAX) end[dim] = shape[dim]; + if (dim > 0) { + ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]); + strides[dim-1] = strides[dim] * shape[dim]; 
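      // ref_strides are the row-major strides of the sliced (begin:end) view,
      // while strides are the strides of the full input tensor; both are built
      // right to left so that a flat offset into the slice can be remapped
      // below to a flat index into the input.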
+ } + } + int ref_idx = offset; + #pragma unroll + for (int dim = 0; dim < ndim; dim++) { + int stride = ref_strides[dim]; + if (shape[dim] > 1) { + idx += (ref_idx / stride + begin[dim]) * strides[dim]; + } + ref_idx = ref_idx % stride; + } + return global_load_index(input, idx, shape); +} + +template +__device__ inline void store_index(const vector::VectorizedStorage value, int i, + DType * output, const Shape& shape) { + if (i < (shape.size + nvec - 1) / nvec) { + auto vector_output = reinterpret_cast< + typename vector::VectorizedStorage::LType *>(output); + vector_output[i] = value.scratch_.aligned; + } +} + +template +__device__ inline void store_add_index(const vector::VectorizedStorage value, int i, + DType * output, const Shape& shape) { + if (i < (shape.size + nvec - 1) / nvec) { + auto vector_output = reinterpret_cast< + typename vector::VectorizedStorage::LType *>(output); + vector::VectorizedStorage ret(vector_output[i]); + ret += value; + vector_output[i] = ret.scratch_.aligned; + } +} + +} // namespace op +)code"; + +const char function_definitions_binary[] = R"code( +namespace op { + +template +__device__ inline typename type_util::mixed_type::type +add(const DType a, const DType2 b) { + return a + b; +} + +template +__device__ inline typename type_util::mixed_type::type +sub(const DType a, const DType2 b) { + return a - b; +} + +template +__device__ inline typename type_util::mixed_type::type +rsub(const DType a, const DType2 b) { + return b - a; +} + +template +__device__ inline typename type_util::mixed_type::type +mul(const DType a, const DType2 b) { + return a * b; +} + +template +__device__ inline typename type_util::mixed_type::type +div(const DType a, const DType2 b) { + return a / b; +} + +template +__device__ inline typename type_util::mixed_type::type +rdiv(const DType a, const DType2 b) { + return b / a; +} + +#define DEFINE_BINARY_MATH_FUNC(name, double_version, float_version) \ +template \ +__device__ inline typename type_util::mixed_type::type \ +name (const DType a, const DType2 b) { \ + if (type_util::has_double_or_integral::value) { \ + return double_version ((double)a, (double)b); \ + } else { \ + return float_version ((float)a, (float)b); \ + } \ +} + +template +__device__ inline typename type_util::mixed_type::type +power (const DType a, const DType2 b) { + if (type_util::has_double::value) { + return ::pow ((double)a, (double)b); \ + } else { + return ::powf ((float)a, (float)b); + } +} + +template +__device__ inline typename type_util::mixed_type::type +rpow(const DType a, const DType2 b) { + return power(b, a); +} + +template +__device__ inline typename type_util::mixed_type::type +max(const DType a, const DType2 b) { + if (isnan(a)) return a; + return a > b ? a : b; +} + +template +__device__ inline typename type_util::mixed_type::type +fmax(const DType a, const DType2 b) { + if (isnan(b)) return a; + return a > b ? a : b; +} + +template +__device__ inline typename type_util::mixed_type::type +min(const DType a, const DType2 b) { + if (isnan(a)) return a; + return a < b ? a : b; +} + +template +__device__ inline typename type_util::mixed_type::type +fmin(const DType a, const DType2 b) { + if (isnan(b)) return a; + return a < b ? 
a : b; +} + +DEFINE_BINARY_MATH_FUNC(hypot, ::hypot, ::hypotf) + +template +__device__ inline typename type_util::mixed_type::type +mod(const DType a, const DType2 b) { + if (b == 0) { + return 0; + } + const double ad = static_cast(a); + const double bd = static_cast(b); + if (bd < 0) { + if (ad < 0) { + return -::fmod(-ad, -bd); + } else { + return ::fmod(ad, -bd) + + (::fmod(ad, -bd) != 0 ? bd : 0); + } + } else { + if (ad < 0) { + return -::fmod(-ad, bd) + + (::fmod(-ad, bd) != 0 ? bd : 0); + } else { + return ::fmod(ad, bd); + } + } +} + +template +__device__ inline typename type_util::mixed_type::type +fmod(const DType a, const DType2 b) { + if (b == 0) { + return 0; + } + return ::fmod(static_cast(a), static_cast(b)); +} + +template +__device__ inline typename type_util::mixed_type::type +rmod(const DType a, const DType2 b) { + return op::mod(b, a); +} + +template +__device__ inline typename type_util::mixed_type::type +rfmod(const DType a, const DType2 b) { + return op::fmod(b, a); +} + +template +__device__ inline DType equal(const DType a, const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a == real_b ? 1 : 0; +} + +template +__device__ inline DType not_equal(const DType a, const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a != real_b ? 1 : 0; +} + +template +__device__ inline DType greater(const DType a, const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a > real_b ? 1 : 0; +} + +template +__device__ inline DType greater_equal(const DType a, const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a >= real_b ? 1 : 0; +} + +template +__device__ inline DType less(const DType a, const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a < real_b ? 1 : 0; +} + +template +__device__ inline DType less_equal(const DType a, const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a <= real_b ? 1 : 0; +} + +template +__device__ inline bool_t np_equal(const DType a, const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a == real_b ? true : false; +} + +template +__device__ inline bool_t np_not_equal(const DType a, const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a != real_b ? true : false; +} + +template +__device__ inline bool_t np_greater(const DType a, const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a > real_b ? true : false; +} + +template +__device__ inline bool_t np_greater_equal(const DType a, const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a >= real_b ? 
true : false; +} + +template +__device__ inline bool_t np_less(const DType a, const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a < real_b ? true : false; +} + +template +__device__ inline bool_t np_less_equal(const DType a, const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a <= real_b ? true : false; +} + +template +__device__ inline DType logical_and(const DType a, const DType2 b) { + return a && b ? 1 : 0; +} + +template +__device__ inline DType logical_or(const DType a, const DType2 b) { + return a || b ? 1 : 0; +} + +template +__device__ inline DType logical_xor(const DType a, const DType2 b) { + return ((a || b) && !(a && b)) ? 1 : 0; +} + +template +__device__ inline DType copysign(const DType a, const DType2 b) { + return (a >= 0 && b >= 0) || (a < 0 && b < 0) ? a : -a; +} + +template +__device__ inline DType2 rcopysign(const DType a, const DType2 b) { + return copysign(b, a); +} + +template +__device__ inline typename type_util::mixed_type::type +lcm(const DType a, const DType2 b) { + if (type_util::is_integral::value && + type_util::is_integral::value) { + DType A = a; + DType2 B = b; + // minus cases. + if (a < 0) { + A = -a; + } + if (b < 0) { + B = -b; + } + // handle zero-valued cases. + DType c; + if (a == 0 || b == 0) { + c = 0; + } else { + DType tmp; + DType tmp_a = A; + DType tmp_b = B; + if (A < B) { + tmp = A; + A = B; + B = tmp; + } + while (A % B != 0) { + A = A % B; + tmp = A; + A = B; + B = tmp; + } + c = tmp_a / B * tmp_b; + } + return c; + } else { + return 0; + } +} + +template +__device__ inline typename type_util::mixed_type::type bitwise_xor(const DType a, + const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a ^ real_b; +} + +template +__device__ inline typename type_util::mixed_type::type bitwise_or(const DType a, + const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a | real_b; +} + +template +__device__ inline typename type_util::mixed_type::type bitwise_and(const DType a, + const DType2 b) { + using mixed_type = typename type_util::mixed_type::type; + const mixed_type real_a = a; + const mixed_type real_b = b; + return real_a & real_b; +} + +DEFINE_BINARY_MATH_FUNC(arctan2, ::atan2, ::atan2f) + +template +__device__ inline typename type_util::mixed_type::type +rarctan2(const DType a, const DType2 b) { + return arctan2(b, a); +} + +template +__device__ inline typename type_util::mixed_type::type +ldexp(const DType a, const DType2 b) { + if (type_util::has_double_or_integral::value) { + return a * ::pow(2.0, static_cast(b)); + } else { + return a * ::powf(2.0f, static_cast(b)); + } +} + +template +__device__ inline typename type_util::mixed_type::type +rldexp(const DType a, const DType2 b) { + return ldexp(b, a); +} + +#undef DEFINE_BINARY_MATH_FUNC + +template +__device__ inline bool np_logical_and(const DType val, const DType2 val2) { + return (val && val2) ? true : false; +} + +template +__device__ inline bool np_logical_or(const DType val, const DType2 val2) { + return (val || val2) ? true : false; +} + +template +__device__ inline bool np_logical_xor(const DType val, const DType2 val2) { + return ((val || val2) && !(val && val2)) ? 
true : false; +} + +template +__device__ inline DType left(const DType left_val, const DType2 right_val) { + return left_val; +} + +template +__device__ inline DType2 right(const DType left_val, const DType2 right_val) { + return right_val; +} + +} // namespace op +)code"; + +const char function_definitions_unary[] = R"code( +namespace op { + +template +__device__ inline DType identity(const DType val) { + return val; +} + +template +__device__ inline DType negation(const DType val) { + return -val; +} + +template +__device__ inline typename LoadType::Type cast(const DType val) { + return static_cast::Type>(val); +} + +// activations + +template +__device__ inline DType relu(const DType val) { + return (isnan(val) || val > 0) ? val : 0; +} + +template +__device__ inline DType sigmoid(const DType val) { + if (type_util::has_double_or_integral::value) { + return 1./(1 + ::exp(-val)); + } else { + return 1.f/(1 + expf(-val)); + } +} + +template +__device__ inline DType softrelu(const DType val) { + if (type_util::has_double_or_integral::value) { + return ::log(1 + ::exp(val)); + } else { + return logf(1 + expf(val)); + } +} + +template +__device__ inline DType softsign(const DType val) { + if (type_util::has_double_or_integral::value) { + return val / (1 + fabs(val)); + } else { + return val / (1 + fabsf(val)); + } +} + +// exp and log + +#define DEFINE_UNARY_MATH_FUNC(name, double_version, float_version) \ +template \ +__device__ inline DType name (const DType a) { \ + if (type_util::has_double_or_integral::value) { \ + return double_version ((double)a); \ + } else { \ + return float_version (a); \ + } \ +} + +DEFINE_UNARY_MATH_FUNC(exp, ::exp, ::expf) +DEFINE_UNARY_MATH_FUNC(expm1, ::expm1, ::expm1f) +DEFINE_UNARY_MATH_FUNC(log, ::log, ::logf) +DEFINE_UNARY_MATH_FUNC(log10, ::log10, ::log10f) +DEFINE_UNARY_MATH_FUNC(log2, ::log2, ::log2f) +DEFINE_UNARY_MATH_FUNC(log1p, ::log1p, ::log1pf) + +// trigonometric + +constexpr double pi = 3.14159265358979323846; + +template +__device__ inline DType degrees(const DType val) { + if (type_util::has_double_or_integral::value) { + return (val / pi) * 180; + } else { + return (val / static_cast(pi)) * 180.f; + } +} + +template +__device__ inline DType radians(const DType val) { + if (type_util::has_double_or_integral::value) { + return (val / 180.0) * pi; + } else { + return (val / 180.0f) * static_cast(pi); + } +} + +DEFINE_UNARY_MATH_FUNC(sin, ::sin, ::sinf) +DEFINE_UNARY_MATH_FUNC(cos, ::cos, ::cosf) +DEFINE_UNARY_MATH_FUNC(tan, ::tan, ::tanf) +DEFINE_UNARY_MATH_FUNC(arcsin, ::asin, ::asinf) +DEFINE_UNARY_MATH_FUNC(arccos, ::acos, ::acosf) +DEFINE_UNARY_MATH_FUNC(arctan, ::atan, ::atanf) + +DEFINE_UNARY_MATH_FUNC(sinh, ::sinh, ::sinhf) +DEFINE_UNARY_MATH_FUNC(cosh, ::cosh, ::coshf) +DEFINE_UNARY_MATH_FUNC(tanh, ::tanh, ::tanhf) +DEFINE_UNARY_MATH_FUNC(arcsinh, ::asinh, ::asinhf) +DEFINE_UNARY_MATH_FUNC(arccosh, ::acosh, ::acoshf) +DEFINE_UNARY_MATH_FUNC(arctanh, ::atanh, ::atanhf) + +// sqrt + +DEFINE_UNARY_MATH_FUNC(sqrt, ::sqrt, ::sqrtf) +DEFINE_UNARY_MATH_FUNC(rsqrt, ::rsqrt, ::rsqrtf) +DEFINE_UNARY_MATH_FUNC(cbrt, ::cbrt, ::cbrtf) +DEFINE_UNARY_MATH_FUNC(rcbrt, ::rcbrt, ::rcbrtf) + +template +__device__ inline DType square(const DType val) { + return val * val; +} + +template +__device__ inline typename LoadType::Type zero(const DType val, const DTypes... args) { + return 0; +} + +template +__device__ inline typename LoadType::Type zero() { + return 0; +} + +template +__device__ inline typename LoadType::Type one(const DType val, const DTypes... 
args) { + return 1; +} + +template +__device__ inline typename LoadType::Type one() { + return 1; +} + +template +__device__ inline typename LoadType::Type negone(const DType val, const DTypes... args) { + return -1; +} + +template +__device__ inline typename LoadType::Type negone() { + return -1; +} + +template +__device__ inline DType round(const DType val) { + if (type_util::has_double::value) { + return ::round((double)val); + } else if (type_util::is_integral::value) { + return val; + } else { + return ::roundf(val); + } +} + +template +__device__ inline DType floor(const DType val) { + if (type_util::has_double::value) { + return ::floor((double)val); + } else if (type_util::is_integral::value) { + return val; + } else { + return ::floorf(val); + } +} + +template +__device__ inline DType ceil(const DType val) { + if (type_util::has_double::value) { + return ::ceil((double)val); + } else if (type_util::is_integral::value) { + return val; + } else { + return ::ceilf(val); + } +} + +template +__device__ inline DType rint(const DType val) { + if (type_util::has_double::value) { + return ::rint((double)val); + } else if (type_util::is_integral::value) { + return val; + } else { + return ::rintf(val); + } +} + +template +__device__ inline DType fix(const DType val) { + const auto f = floor(val); + const auto c = ceil(val); + return (f > 0 ? f : -f) < (c > 0 ? c : -c) ? f : c; +} + +template +__device__ inline DType trunc(const DType val) { + if (type_util::has_double::value) { + return ::trunc((double)val); + } else if (type_util::is_integral::value) { + return val; + } else { + return ::truncf(val); + } +} + +template +__device__ inline DType clip(const DType val, const float a_min, const float a_max) { + return max(min(val, a_max), a_min); +} + +template +__device__ inline DType sign(const DType val) { + if (val < 0) return -1; + return val > 0 ? 1 : 0; +} + +template +__device__ inline DType reciprocal(const DType val) { + return 1.0f / val; +} + +DEFINE_UNARY_MATH_FUNC(abs, ::fabs, ::fabsf) +DEFINE_UNARY_MATH_FUNC(gamma, ::tgamma, ::tgammaf) +DEFINE_UNARY_MATH_FUNC(gammaln, ::lgamma, ::lgammaf) +DEFINE_UNARY_MATH_FUNC(erf, ::erf, ::erff) +DEFINE_UNARY_MATH_FUNC(erfinv, ::erfinv, ::erfinvf) + +template +__device__ inline DType gelu(const DType val) { + return 0.5f * val * (1.0f + op::erf(val / op::sqrt(2.0f))); +} + +template +__device__ inline DType1 smooth_l1(const DType1 val, const DType2 scalar) { + const auto bsq = scalar * scalar; + const auto ibsq = 1.0f / bsq; + if (val > ibsq) { + return val - 0.5f * ibsq; + } else if (val < -ibsq) { + return -val - 0.5f * ibsq; + } else { + return 0.5f * val * val * bsq; + } +} + +template +__device__ inline DType digamma(const DType val) { + if (type_util::has_double_or_integral::value) { + return special_functions::cephes::psi(val); + } else { + return special_functions::cephes::psi(val); + } +} + +template +__device__ inline DType logical_not(const DType val) { + return val != DType(0) ? 
DType(0) : DType(1); +} + +template +__device__ inline bool_t np_logical_not(const DType val) { + return !static_cast(val); +} + +template +__device__ inline bool isnan(const DType val) { + return util::isnan(val); +} + +template +__device__ inline bool_t isinf(const DType val) { + return util::isinf(val); +} + +template +__device__ inline bool_t isposinf(const DType val) { + return util::isinf(val) && (val > 0); +} + +template +__device__ inline bool_t isneginf(const DType val) { + return util::isinf(val) && (val < 0); +} + +template +__device__ inline bool_t isfinite(const DType val) { + return !op::isnan(val) && !op::isinf(val); +} + +#undef DEFINE_UNARY_MATH_FUNC + +template +__device__ inline DType bitwise_not(const DType a) { + if (type_util::is_same::value) { + return !a; + } else { + return ~static_cast(a); + } +} + +} // namespace op + +)code"; + +} // namespace rtc +} // namespace cuda +} // namespace common +} // namespace mxnet + +#endif // MXNET_USE_CUDA + +#endif // MXNET_COMMON_CUDA_RTC_FORWARD_FUNCTIONS_INL_H_ diff --git a/src/common/cuda/rtc/half-inl.h b/src/common/cuda/rtc/half-inl.h new file mode 100644 index 000000000000..922bc2f25e45 --- /dev/null +++ b/src/common/cuda/rtc/half-inl.h @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef MXNET_COMMON_CUDA_RTC_HALF_INL_H_ +#define MXNET_COMMON_CUDA_RTC_HALF_INL_H_ + +#if MXNET_USE_CUDA + +namespace mxnet { +namespace common { +namespace cuda { +namespace rtc { + +const char fp16_support_string[] = R"code( +struct __align__(2) __half { + __host__ __device__ __half() { } + unsigned short __x; +}; +/* Definitions of intrinsics */ +__device__ inline __half __float2half(const float f) { + __half val; + asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(val.__x) : "f"(f)); + return val; +} +__device__ inline float __half2float(const __half h) { + float val; + asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(h.__x)); + return val; +} + +typedef __half half; + +template +struct AccType { + using type = DType; + + __device__ static inline type from(const DType& val) { + return val; + } + + __device__ static inline DType to(type val) { + return val; + } + +}; + +template<> +struct AccType { + using type = float; + + __device__ static inline type from(const half& val) { + return __half2float(val); + } + + __device__ static inline half to(type val) { + return __float2half(val); + } +}; +)code"; + +} // namespace rtc +} // namespace cuda +} // namespace common +} // namespace mxnet + +#endif // MXNET_USE_CUDA + +#endif // MXNET_COMMON_CUDA_RTC_HALF_INL_H_ diff --git a/src/common/cuda/rtc/reducer-inl.h b/src/common/cuda/rtc/reducer-inl.h new file mode 100644 index 000000000000..93b702788c46 --- /dev/null +++ b/src/common/cuda/rtc/reducer-inl.h @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_COMMON_CUDA_RTC_REDUCER_INL_H_ +#define MXNET_COMMON_CUDA_RTC_REDUCER_INL_H_ + +#if MXNET_USE_CUDA + +namespace mxnet { +namespace common { +namespace cuda { +namespace rtc { + +const char reducer[] = R"code( + +namespace red { + +/*! \brief sum reducer */ +struct sum { + /*! \brief do reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { + dst = op::add(dst, src); + } + + /*! \brief do stable reduction into dst */ + template + __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src, + volatile DType& residual) { + DType y = op::sub(src, residual); + DType t = dst + y; + if (util::isinf(t)) { + residual = 0; + } else { + residual = (t - dst) - y; + } + dst = t; + } + /*! \brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) { + Reduce(dst_val, src_val); + } + /*! 
\brief combine the results of two reducers */ + template + __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual, + volatile DType& src_val, volatile DType& src_residual) { + DType t1 = dst_val + src_val; + if (util::isinf(t1)) { + dst_val = t1; + dst_residual = 0; + } else { + DType e = t1 - dst_val; + DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual; + dst_val = t1 + t2; + dst_residual = t2 - (dst_val - t1); + } + } + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst) {} + /*! \brief finalize reduction result */ + template + __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {} + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv) { + initv = 0; + } + /*! + *\brief set the initial value during reduction + */ + template + __device__ inline static void SetInitValue(DType &initv, DType &residual) { + SetInitValue(initv); + residual = 0; + } +}; +} // namespace red + +)code"; + +} // namespace rtc +} // namespace cuda +} // namespace common +} // namespace mxnet + +#endif // MXNET_USE_CUDA + +#endif // MXNET_COMMON_CUDA_RTC_REDUCER_INL_H_ + diff --git a/src/common/cuda/rtc/special_functions-inl.h b/src/common/cuda/rtc/special_functions-inl.h new file mode 100644 index 000000000000..50f860405ef2 --- /dev/null +++ b/src/common/cuda/rtc/special_functions-inl.h @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_COMMON_CUDA_RTC_SPECIAL_FUNCTIONS_INL_H_ +#define MXNET_COMMON_CUDA_RTC_SPECIAL_FUNCTIONS_INL_H_ + +#include +#include + +namespace mxnet { +namespace common { +namespace cuda { +namespace rtc { + +// This code is based on the Cephes Library availible at http://www.netlib.org/cephes +// The original author, Stephen Moshier, has kindly given permission to use this code +// in mxnet. (See email below). +// +// Date: Tue, 13 Sep 2016 09:28:20 -0400 +// From: Stephen Moshier +// To: Flunkert, Valentin +// Subject: Re: cephes code in mxnet +// +// Hello Valentin, +// +// Thank you for writing. You are welcome to use and modify the Cephes code +// and distribute it under the Apache license. +// +// Good luck with your project, +// Steve Moshier +// +// Cephes Math Library Release 2.2: June, 1992 +// Copyright 1984, 1987, 1992 by Stephen L. 
Moshier +// Direct inquiries to 30 Frost Street, Cambridge, MA 02140 +// +const char special_functions_definitions[] = R"code( +constexpr double DBL_MAX = 1.7976931348623157081e+308; + +namespace op { + +namespace special_functions { + +template +__device__ inline static DType trigamma(DType x); + +template<> +__device__ inline double trigamma(double x) { + double PI(3.14159265358979323846); + double sign = +1; + double result = 0; + if (x < 0.5) { + sign = -1; + const double sin_pi_x = sin(PI * x); + result -= (PI * PI) / (sin_pi_x * sin_pi_x); + x = 1 - x; + } + for (int i = 0; i < 6; ++i) { + result += 1 / (x * x); + x += 1; + } + const double ixx = 1 / (x*x); + result += (1 + 1 / (2*x) + ixx * (1./6 - ixx * (1./30 - ixx * (1./42)))) / x; + return sign * result; +} + +template<> +__device__ inline float trigamma(float x) { + float PI(3.14159265358979323846); + float sign = +1; + float result = 0; + if (x < 0.5f) { + sign = -1; + const float sin_pi_x = sinf(PI * x); + result -= (PI * PI) / (sin_pi_x * sin_pi_x); + x = 1 - x; + } + for (int i = 0; i < 6; ++i) { + result += 1 / (x * x); + x += 1; + } + const float ixx = 1 / (x*x); + result += (1 + 1 / (2*x) + ixx * (1.f/6 - ixx * (1.f/30 - ixx * (1.f/42)))) / x; + return sign * result; +} + +struct cephes { + /* + * Helper to evaluate a polynomial given an array of coefficients. + */ + template + __device__ inline static DType polevl(DType x, const DType coef[], int N) { + DType ans; + DType const *p; + int i; + + p = coef; + ans = *p++; + + i = N; + do { + ans = ans * x + *p++; + } while ( --i ); + + return( ans ); + } + + + /* + * Helper function for psi that handles double/float specific differences + * in the algorithm. + */ + template + __device__ inline static DType psi_helper(DType s); + + /* + * + * Psi (digamma) function + * + * + * SYNOPSIS: + * + * float x, y, psif(); + * + * y = psif( x ); + * + * + * DESCRIPTION: + * + * d - + * psi(x) = -- ln | (x) + * dx + * + * is the logarithmic derivative of the gamma function. + * For integer x, + * n-1 + * - + * psi(n) = -EUL + > 1/k. + * - + * k=1 + * + * This formula is used for 0 < n <= 10. If x is negative, it + * is transformed to a positive argument by the reflection + * formula psi(1-x) = psi(x) + pi cot(pi x). + * For general positive x, the argument is made greater than 10 + * using the recurrence psi(x+1) = psi(x) + 1/x. + * Then the following asymptotic expansion is applied: + * + * inf. B + * - 2k + * psi(x) = log(x) - 1/2x - > ------- + * - 2k + * k=1 2k x + * + * where the B2k are Bernoulli numbers. 
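   * (psi_helper below evaluates this tail as a short polynomial in 1/x^2;
   * the double specialization keeps more terms than the float one.)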
+ * + * ACCURACY: + * Absolute error, relative when |psi| > 1 : + * arithmetic domain # trials peak rms + * IEEE -33,0 30000 8.2e-7 1.2e-7 + * IEEE 0,33 100000 7.3e-7 7.7e-8 + * + * ERROR MESSAGES: + * message condition value returned + * psi singularity x integer <=0 MAXNUMF + */ + template + __device__ inline static DType psi(DType x) { + DType p, q, nz, s, w, y; + int i, n, negative; + + DType EUL(0.57721566490153286061); + DType PI(3.14159265358979323846); + + negative = 0; + nz = 0.0; + + if ( x <= 0.0 ) { + negative = 1; + q = x; + p = ::floor(q); + if ( p == q ) { + return DBL_MAX; + } + /* Remove the zeros of tan(PI x) + * by subtracting the nearest integer from x + */ + nz = q - p; + if ( nz != 0.5 ) { + if ( nz > 0.5 ) { + p += 1.0; + nz = q - p; + } + nz = PI/::tan(PI*nz); + } else { + nz = 0.0; + } + x = 1.0 - x; + } + + /* check for positive integer up to 10 */ + if ( (x <= 10.0) && (x == ::floor(x)) ) { + y = 0.0; + n = x; + for ( i = 1; i < n; i++ ) { + w = i; + y += 1.0/w; + } + y -= EUL; + goto done; + } + + s = x; + w = 0.0; + while ( s < 10.0 ) { + w += 1.0/s; + s += 1.0; + } + + y = psi_helper(s); + + y = logf(s) - (0.5/s) - y - w; + +done: + + if ( negative ) { + y -= nz; + } + + return(y); + } +}; + + +template<> +__device__ inline double cephes::psi_helper(double s) { + double z; + const double A[] = { + 8.33333333333333333333E-2, + -2.10927960927960927961E-2, + 7.57575757575757575758E-3, + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2 + }; + + if ( s < 1.0e17 ) { + z = 1.0/(s * s); + return z * cephes::polevl(z, A, 6); + } else { + return 0.0; + } +} + +template<> +__device__ inline float cephes::psi_helper(float s) { + float z; + const float A[] = { + -4.16666666666666666667E-3f, + 3.96825396825396825397E-3f, + -8.33333333333333333333E-3f, + 8.33333333333333333333E-2f + }; + + if ( s < 1.0e8 ) { + z = 1.0/(s * s); + return z * cephes::polevl(z, A, 3); + } else { + return 0.0; + } +} +} // namespace special_functions +} // namespace op +)code"; + +} // namespace rtc +} // namespace cuda +} // namespace common +} // namespace mxnet + +#endif // MXNET_COMMON_CUDA_RTC_SPECIAL_FUNCTIONS_INL_H_ diff --git a/src/common/cuda/rtc/util-inl.h b/src/common/cuda/rtc/util-inl.h new file mode 100644 index 000000000000..372390fdc117 --- /dev/null +++ b/src/common/cuda/rtc/util-inl.h @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef MXNET_COMMON_CUDA_RTC_UTIL_INL_H_ +#define MXNET_COMMON_CUDA_RTC_UTIL_INL_H_ + +#include + +#if MXNET_USE_CUDA + +namespace mxnet { +namespace common { +namespace cuda { +namespace rtc { + +const char type_support_string[] = R"code( +using float32 = float; +using float64 = double; +using float16 = half; +using uint8 = unsigned char; +using int8 = char; +using int32 = int; +using int64 = long long; + +static_assert(sizeof(float32) == 4, "Size of float32 is expected to be 4B"); +static_assert(sizeof(float64) == 8, "Size of float64 is expected to be 8B"); +static_assert(sizeof(float16) == 2, "Size of float16 is expected to be 2B"); +static_assert(sizeof(uint8) == 1, "Size of uint8 is expected to be 1B"); +static_assert(sizeof(int8) == 1, "Size of int8 is expected to be 1B"); +static_assert(sizeof(int32) == 4, "Size of int32 is expected to be 4B"); +static_assert(sizeof(int64) == 8, "Size of int64 is expected to be 8B"); + +)code" +#if MSHADOW_INT64_TENSOR_SIZE == 1 +"typedef int64 index_t;\n" +#else +"typedef int32 index_t;\n" +#endif +R"code( +// bool and int8 need to be accumulated in index_t +// but bool needs to be treated in the special way +// for ops like bitwise_not +struct bool_t { + index_t value; + + __device__ inline bool_t(const index_t& v) : value(v) {} + __device__ inline bool_t(const volatile index_t& v) : value(v) {} + __device__ inline bool_t() : value(0) {} + + __device__ inline operator index_t() const volatile { return value; } + __device__ inline bool_t& operator= (const index_t& v) { + value = v; + return *this; + } + __device__ inline volatile bool_t& operator= (const index_t& v) volatile { + value = v; + return *this; + } + __device__ inline bool_t& operator= (const volatile index_t& v) { + value = v; + return *this; + } +}; +template<> +struct AccType { + using type = bool_t; + + __device__ static inline type from(const bool& val) { + return val; + } + + __device__ static inline bool to(type val) { + return val; + } +}; + +template<> +struct AccType { + using type = index_t; + + __device__ static inline type from(const int8& val) { + return val; + } + + __device__ static inline int8 to(type val) { + return val; + } +}; + +template<> +struct AccType { + using type = index_t; + + __device__ static inline type from(const uint8& val) { + return val; + } + + __device__ static inline uint8 to(type val) { + return val; + } +}; + +namespace type_util { + +struct false_type { + static constexpr bool value = false; +}; + +struct true_type { + static constexpr bool value = true; +}; + +// is_integral +template struct is_integral : false_type {}; +template <> struct is_integral : true_type {}; +template <> struct is_integral : true_type {}; +template <> struct is_integral : true_type {}; +template <> struct is_integral : true_type {}; +template <> struct is_integral : true_type {}; +template <> struct is_integral : true_type {}; + +// is_unsigned +template struct is_unsigned : false_type {}; +template <> struct is_unsigned : true_type {}; +template <> struct is_unsigned : true_type {}; +template <> struct is_unsigned : true_type {}; + +// is_same +template +struct is_same : false_type {}; +template struct is_same : true_type {}; + +// has_double +template struct has_double : false_type {}; + +template +struct has_double { + static constexpr bool value = is_same::value || + has_double::value; +}; + +// has_double_or_integral +template struct has_double_or_integral : false_type {}; + +template +struct has_double_or_integral { + static constexpr bool value = 
is_same::value || + is_integral::value || + has_double_or_integral::value; +}; + +template +struct enable_if {}; + +template <> +struct enable_if { + using type = void; +}; + +template +struct mixed_type; + +template +struct mixed_type::value>::type> { + using type = float64; +}; + +template +struct mixed_type { + using type = float64; +}; + +template +struct mixed_type::value && + !is_same::value>::type> { + using type = float32; +}; + +template +struct mixed_type::value>::type> { + using type = float32; +}; + +template +struct mixed_type::value || + is_integral::value>::type> { + using type = float16; +}; + +template +struct mixed_type::value>::type> { + using type = float16; +}; + +template +struct mixed_type::value && + is_integral::value && + !is_same::value && + sizeof(T) <= sizeof(U)>::type> { + using type = U; +}; + +template +struct mixed_type::value && + is_integral::value && + !is_same::value && + sizeof(T) < sizeof(U)>::type> { + using type = U; +}; + +template +struct mixed_type::value && + sizeof(T) < sizeof(bool_t)>::type> { + using type = index_t; +}; + +template +struct mixed_type::value && + sizeof(T) < sizeof(bool_t)>::type> { + using type = index_t; +}; + +template +struct mixed_type::value && + sizeof(T) == sizeof(bool_t)>::type> { + using type = T; +}; + +} // namespace type_util +)code"; + +const char util_string[] = R"code( +enum class OpReqType { + kNullOp, + kWriteTo, + kWriteInplace, + kAddTo +}; + +constexpr int kRTCMaxThreadsPerBlock = 512; + +namespace util { + +constexpr int MAX_DIM = 5; + +template +__device__ inline void unravel_dot(const index_t idx, const index_t (&shape)[MAX_DIM], + const index_t (&stridej)[MAX_DIM], const index_t (&stridek)[MAX_DIM], index_t* j, index_t* k) { + *j = 0; + *k = 0; + #pragma unroll + for (index_t i = ndim-1, idx_t = idx; i >=0; --i) { + const auto tmp = idx_t / shape[i]; + const auto coord = idx_t - tmp*shape[i]; + *j += coord*stridej[i]; + *k += coord*stridek[i]; + idx_t = tmp; + } +} + +template +__device__ inline index_t unravel_dot(const index_t idx, const index_t (&shape)[MAX_DIM], + const index_t (&stride)[MAX_DIM]) { + index_t ret = 0; + #pragma unroll + for (index_t i = ndim-1, j = idx; i >=0; --i) { + auto tmp = j / shape[i]; + ret += (j - tmp*shape[i])*stride[i]; + j = tmp; + } + return ret; +} + +template +__device__ inline index_t unravel_ravel(const index_t idx, const index_t (&shape1)[MAX_DIM], + const index_t (&shape2)[MAX_DIM]) { + index_t ret = 0; + index_t total_shape = 1; +#pragma unroll + for (index_t i = ndim-1, j = idx; i >=0; --i) { + if (i != ndim - 1) { + total_shape *= shape2[i + 1]; + } + auto tmp = j / shape1[i]; + const index_t coord = j - tmp*shape1[i]; + ret += total_shape * (shape2[i] > coord) * coord; + j = tmp; + } + return ret; +} + +template +__device__ inline index_t ravel(const index_t (&coord)[ndim], const index_t (&shape)[ndim2]) { + index_t ret = 0; +#pragma unroll + for (int i = 0; i < ndim; ++i) { + ret = ret * shape[i] + (shape[i] > coord[i]) * coord[i]; + } + return ret; +} + +template +__device__ inline void unravel(const index_t idx, + const index_t (&shape)[ndim2], + index_t (&coord)[ndim]) { +#pragma unroll + for (index_t i = ndim-1, j = idx; i >=0; --i) { + auto tmp = j / shape[i]; + coord[i] = j - tmp*shape[i]; + j = tmp; + } +} + +template +__device__ inline bool isinf(volatile const DType &val) { + return false; +} + +template <> +__device__ inline bool isinf(volatile const float &val) { + return ::isinf(val); +} + +template <> +__device__ inline bool isinf(volatile 
const double &val) { + return ::isinf(val); +} + +template <> +__device__ inline bool isinf(volatile const long double &val) { + return ::isinf(val); +} + +template <> +__device__ inline bool isinf(volatile const float16 &val) { + return ::isinf(__half2float(const_cast(val))); +} + +template +__device__ inline bool isnan(volatile const DType &val) { + return false; +} + +template <> +__device__ inline bool isnan(volatile const float &val) { + return ::isnan(val); +} + +template <> +__device__ inline bool isnan(volatile const double &val) { + return ::isnan(val); +} + +template <> +__device__ inline bool isnan(volatile const long double &val) { + return ::isnan(val); +} + +template <> +__device__ inline bool isnan(volatile const float16 &val) { + return ::isnan(__half2float(const_cast(val))); +} + +} // namespace util +)code"; +} // namespace rtc +} // namespace cuda +} // namespace common +} // namespace mxnet + +#endif // MXNET_USE_CUDA + +#endif // MXNET_COMMON_CUDA_RTC_UTIL_INL_H_ diff --git a/src/common/cuda/rtc/vectorization-inl.h b/src/common/cuda/rtc/vectorization-inl.h new file mode 100644 index 000000000000..9868069daf73 --- /dev/null +++ b/src/common/cuda/rtc/vectorization-inl.h @@ -0,0 +1,463 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_COMMON_CUDA_RTC_VECTORIZATION_INL_H_ +#define MXNET_COMMON_CUDA_RTC_VECTORIZATION_INL_H_ + +#include + +#if MXNET_USE_CUDA + +#include +#include +#include +#include + +#include "../rtc.h" +#include "../../utils.h" + +namespace mxnet { +namespace common { +namespace cuda { +namespace rtc { + +const char vectorization_support_string[] = R"code( + +namespace vector { + +template +struct VectorType { + static_assert(size <= 32, "VectorType needs to have size of at most 32B"); +}; + +template <> +struct VectorType<1> { + using type = char; +}; + +template <> +struct VectorType<2> { + using type = short; +}; + + +template <> +struct VectorType<4> { + using type = int; +}; + +template <> +struct VectorType<8> { + using type = long long; +}; + +template <> +struct VectorType<16> { + using type = ulonglong2; +}; + +template <> +struct VectorType<32> { + using type = ulonglong4; +}; + +template +__device__ inline DType add_elem(const DType& x, const DType& y) { + return x + y; +} + +template <> +__device__ inline half add_elem(const half& x, const half& y) { + return __float2half(__half2float(x) + __half2float(y)); +} + +/* \brief Helper class that enables storing multiple values of type DType + as 1 value of type LType. 
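   A single LType load or store therefore moves nvec DType elements in one
   wide memory transaction, which is what the vectorized accessors below
   rely on for coalesced access.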
+*/ +template +class VectorizedStorage { + public: + using LType = typename VectorType::type; + constexpr static int nvec = n; + union vectorized_storage { + LType aligned; + DType separate[nvec]; // NOLINT(*) + + inline __device__ vectorized_storage() {} + inline __device__ ~vectorized_storage() {} + } scratch_; + + inline __device__ VectorizedStorage() {} + inline __device__ VectorizedStorage (const VectorizedStorage& y2) { + scratch_.aligned = y2.scratch_.aligned; + } + inline __device__ VectorizedStorage (const LType &y2) { + scratch_.aligned = y2; + } + inline __device__ VectorizedStorage& operator+=( + const VectorizedStorage& rhs) { + #pragma unroll + for (int i = 0; i < nvec; ++i) { + scratch_.separate[i] = add_elem(scratch_.separate[i], rhs.scratch_.separate[i]); + } + return *this; + } + inline __device__ ~VectorizedStorage() {} +}; + +// Returns const LType is DType is const +template +struct select_const { + using type = LType; +}; + +template +struct select_const { + using type = const LType; +}; + +template +struct remove_const { + using type = DType; +}; + +template +struct remove_const { + using type = DType; +}; + + +/* \brief Helper class that enables accessing multiple values of type DType + as 1 value of type LType. Additional aligned template argument + allows performance optimizations if the pointer and the size of + the allocation is aligned to sizeof(LType) / sizeof(DType) elements. +*/ +template +class VectorizedAccessor { + public: + using StorageType = VectorizedStorage::type, + nvec>; + using LType = typename select_const::type; + StorageType storage_; + + LType* aligned_ptr_; + DType* unaligned_ptr_; + int alignment_; + index_t n_elems_; + + inline __device__ VectorizedAccessor(DType* const ptr, const index_t size) { + unaligned_ptr_ = ptr; + if (aligned) { + alignment_ = 0; + aligned_ptr_ = reinterpret_cast(ptr); + n_elems_ = (size + nvec- 1) / nvec; + } else { + size_t ptr_as_number = reinterpret_cast(ptr); + alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType); + aligned_ptr_ = reinterpret_cast(ptr - alignment_); + n_elems_ = (size + alignment_ + nvec - 1) / nvec; + } + } + + /* \brief Alignment of the input pointer in elements. */ + inline __device__ int alignment() const { + return alignment_; + } + + /* \brief Access to separate elements. */ + inline __device__ DType* separate() { + return storage_.scratch_.separate; + } + + /* \brief Number of aligned elements that span the entire input tensor. */ + inline __device__ index_t num_aligned_elements() const { + return n_elems_; + } + + /* \brief Load values from the input. + \param id Aligned index of the element. + \param N size of the tensor. + */ + inline __device__ void load(const index_t id, const index_t N) { + if (aligned) { + storage_.scratch_.aligned = aligned_ptr_[id]; + } else { + if (id > 0 && id < n_elems_ - 1) { + storage_.scratch_.aligned = aligned_ptr_[id]; + } else { +#pragma unroll + for (int j = 0; j < nvec; ++j) { + DType* ptr = reinterpret_cast(&(aligned_ptr_[id])) + j; + if (reinterpret_cast(ptr) >= reinterpret_cast(unaligned_ptr_) && + reinterpret_cast(ptr) < reinterpret_cast(unaligned_ptr_ + N)) { + storage_.scratch_.separate[j] = *ptr; + } + } + } + } + } +}; + +/* \brief Class used for vectorized read-only access. */ +template +class VectorizedLoader : public VectorizedAccessor { + public: + inline __device__ VectorizedLoader(const DType* ptr, const index_t N) : + VectorizedAccessor(ptr, N) { + } +}; + +/* \brief Class used for vectorized writable access. 
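   It extends VectorizedAccessor with a store() that mirrors load(): whole
   aligned vectors are written in the common case, and only the in-range
   elements are written individually at the edges of the tensor.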
*/ +template +class VectorizedStorer : public VectorizedAccessor { + public: + inline __device__ VectorizedStorer(DType* ptr, const index_t N) : + VectorizedAccessor(ptr, N) { + } + + /* \brief Store values to the output. + \param id Aligned index of the element. + \param N size of the tensor. + */ + inline __device__ void store(const index_t id, const index_t N) { + if (aligned) { + this->aligned_ptr_[id] = this->storage_.scratch_.aligned; + } else { + if (id > 0 && id < this->n_elems_ - 1) { + this->aligned_ptr_[id] = this->storage_.scratch_.aligned; + } else { +#pragma unroll + for (int j = 0; j < nvec; ++j) { + DType* ptr = reinterpret_cast(&(this->aligned_ptr_[id])) + j; + if (reinterpret_cast(ptr) >= reinterpret_cast(this->unaligned_ptr_) && + reinterpret_cast(ptr) < reinterpret_cast(this->unaligned_ptr_ + N)) { + *ptr = this->storage_.scratch_.separate[j]; + } + } + } + } + } +}; + +} // namespace vector + +)code"; + +namespace { + +inline index_t get_num_aligned_elements(const void *ptr, const index_t lead_dim, + const int nvec, const int size) { + size_t ptr_as_number = reinterpret_cast(ptr); + int alignment = (ptr_as_number % (nvec * size)) / size; + return (lead_dim + alignment + nvec - 1) / nvec; +} + +enum class Alignment { + SAME_ALIGNED, // All tensors aligned + SAME_UNALIGNED, // All tensors have the same misalignment + DIFFERENT // Tensors have different alignment +}; + +inline int CalcAlignment(const void *ptr, const int size) { + size_t ptr_as_number = reinterpret_cast(ptr); + return ptr_as_number % size; +} + +/* \brief Check alignment of the inputs and outputs when using vectorized accesses. + \param params Structure containing arrays with inputs' and outputs' pointers + \param lead_dim Leading dimension of the tensors. + \param other_dim The size of the other dimensions of the tensors. + \param nvec Length of the vector. + \param inputs Inputs to the operator. + \param outputs Outputs of the operator. +*/ +template +Alignment CheckAlignment(const Params& params, const index_t lead_dim, + const index_t other_dim, const int nvec, + const std::vector &inputs, + const std::vector &outputs) { + using namespace common; + int align = -1; + + size_t i = 0; + for (const void *ptr : params.inputs) { + if (ptr != nullptr) { + int new_align = CalcAlignment(ptr, + mshadow_type_info(inputs[i].type_flag_).size * nvec); + if (align == -1) { + align = new_align; + } else { + if (align != new_align) { + return Alignment::DIFFERENT; + } + } + } + ++i; + } + + i = 0; + for (const void *ptr : params.outputs) { + if (ptr != nullptr) { + int new_align = CalcAlignment(ptr, + mshadow_type_info(outputs[i].type_flag_).size * nvec); + if (align == -1) { + align = new_align; + } else { + if (align != new_align) { + return Alignment::DIFFERENT; + } + } + } + ++i; + } + + if ((other_dim != 1) && + (lead_dim % nvec != 0)) { + return Alignment::DIFFERENT; + } + + if ((align == 0) && + (lead_dim % nvec == 0)) { + return Alignment::SAME_ALIGNED; + } else { + return Alignment::SAME_UNALIGNED; + } +} + +constexpr int vectorized_kernel_thread_num = 512; + +} // namespace + +/*! \brief Launcher helper for the kernels using vectorization. + * \param parameters of the kernel (e.g. 
values of the template arguments) + * \param kernel_name name of the kernel + * \param code used for compilation of the kernel if not found in cache + * \param nvec length of the vector used for loading/storing data + * \param lead_dim size of leading dimension of the tensors + * \param other_dim maximum of the total size of all the other dimensions of the tensors + * \param s stream used to launch the kernel + * \param inputs to the kernel + * \param outputs of the kernel + * \param dev_id id of the devide which the kernel will be launched on + * \param lead_input_num number of input to use for checking alignment + * (in case only a subset of inputs is used vectorized). + * Default is 0. + */ +template +void VectorizedKernelRTCLauncher(const std::string ¶meters, + const std::string &kernel_name, + const std::string &code, + int nvec, + const index_t lead_dim, + const index_t other_dim, + mshadow::Stream *s, + const Params params, + const std::vector &inputs, + const std::vector &outputs, + const int dev_id, + const int lead_input_num = 0) { + const index_t N = lead_dim * other_dim; + nvec = std::min(nvec, 4); // Use at most 4-wide vectors + if (N != 0) { + auto align = CheckAlignment(params, lead_dim, other_dim, + nvec, inputs, outputs); + std::string kernel_builder; + kernel_builder.reserve(2560); + + // Fill input types + int counter = 0; + for (const auto& input : inputs) { + const auto& type_info = common::mshadow_type_info(input.type_flag_); + kernel_builder += "using InputType"; + kernel_builder += std::to_string(counter); + kernel_builder += " = "; + kernel_builder += type_info.name; + kernel_builder += ";\n"; + ++counter; + } + + // Fill output types + counter = 0; + for (const auto& output : outputs) { + const auto& type_info = common::mshadow_type_info(output.type_flag_); + kernel_builder += "using OutputType"; + kernel_builder += std::to_string(counter); + kernel_builder += " = "; + kernel_builder += type_info.name; + kernel_builder += ";\n"; + ++counter; + } + + switch (align) { + case Alignment::SAME_ALIGNED: + kernel_builder += "const bool aligned = true;\n" + "const int nvec = "; + kernel_builder += std::to_string(nvec); + kernel_builder += ";\n"; + break; + case Alignment::SAME_UNALIGNED: + kernel_builder += "const bool aligned = false;\n" + "const int nvec = "; + kernel_builder += std::to_string(nvec); + kernel_builder += ";\n"; + break; + case Alignment::DIFFERENT: { + // If the pointers are aligned differently we cannot vectorize + kernel_builder += "const bool aligned = true;\n" + "const int nvec = 1;\n"; + nvec = 1; + break; + } + } + + kernel_builder += parameters; + + index_t num_aligned_elements = get_num_aligned_elements( + params.inputs[lead_input_num], + lead_dim, nvec, + common::mshadow_type_info( + inputs[lead_input_num].type_flag_).size); + size_t num_elements = other_dim * num_aligned_elements; + constexpr int threads = vectorized_kernel_thread_num; + constexpr int max_blocks = 65535; + index_t blocks = std::min(static_cast((num_elements + threads - 1) / threads), + max_blocks); + std::vector args = {¶ms, &lead_dim, &other_dim, + &N, &num_aligned_elements}; + auto function = common::cuda::rtc::get_function(kernel_builder, + kernel_name, + code, + dev_id); + + common::cuda::rtc::launch(function, + {static_cast(blocks), 1, 1}, + {threads, 1, 1}, + 0, s, &args); + } +} + + +} // namespace rtc +} // namespace cuda +} // namespace common +} // namespace mxnet + +#endif // MXNET_USE_CUDA + +#endif // MXNET_COMMON_CUDA_RTC_VECTORIZATION_INL_H_ diff --git 
a/src/common/cuda_utils.cc b/src/common/cuda/utils.cc similarity index 99% rename from src/common/cuda_utils.cc rename to src/common/cuda/utils.cc index 893b34e6ff29..b87c39386604 100644 --- a/src/common/cuda_utils.cc +++ b/src/common/cuda/utils.cc @@ -28,7 +28,7 @@ #include -#include "cuda_utils.h" +#include "utils.h" #if MXNET_USE_CUDA diff --git a/src/common/cuda_utils.h b/src/common/cuda/utils.h similarity index 99% rename from src/common/cuda_utils.h rename to src/common/cuda/utils.h index 22ac42c6c67b..e0f0f152f63c 100644 --- a/src/common/cuda_utils.h +++ b/src/common/cuda/utils.h @@ -19,7 +19,7 @@ /*! * Copyright (c) 2015 by Contributors - * \file cuda_utils.h + * \file utils.h * \brief Common CUDA utilities. */ #ifndef MXNET_COMMON_CUDA_UTILS_H_ @@ -168,11 +168,11 @@ inline __device__ bool __is_supported_cuda_architecture() { { \ CUresult e = (func); \ if (e != CUDA_SUCCESS) { \ - char const * err_msg = nullptr; \ + char const * err_msg = nullptr; \ if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \ LOG(FATAL) << "CUDA Driver: Unknown error " << e; \ } else { \ - LOG(FATAL) << "CUDA Driver: " << err_msg; \ + LOG(FATAL) << "CUDA Driver: " << e << " " << err_msg; \ } \ } \ } diff --git a/src/common/rtc.cc b/src/common/rtc.cc index df79ff69ebb7..21d3061e5209 100644 --- a/src/common/rtc.cc +++ b/src/common/rtc.cc @@ -20,10 +20,10 @@ #include #include -#include "../common/cuda_utils.h" +#include "cuda/utils.h" #include "../operator/operator_common.h" -#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#if MXNET_USE_CUDA namespace mxnet { namespace rtc { @@ -186,4 +186,4 @@ void CudaModule::Kernel::Launch( } // namespace rtc } // namespace mxnet -#endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#endif // MXNET_USE_CUDA diff --git a/src/common/utils.cc b/src/common/utils.cc index 032a324c96b0..67f1f3137c9f 100644 --- a/src/common/utils.cc +++ b/src/common/utils.cc @@ -108,5 +108,30 @@ void ExecuteMonOutputCallback( } } +MShadowTypeInfo mshadow_type_info(const int type_flag) { + using namespace mshadow; + switch (type_flag) { + case kFloat32: + return MShadowTypeInfo("float32", sizeof(float)); + case kFloat64: + return MShadowTypeInfo("float64", sizeof(double)); + case kFloat16: + return MShadowTypeInfo("float16", 2, sizeof(float)); + case kUint8: + return MShadowTypeInfo("uint8", sizeof(uint8_t), sizeof(index_t)); + case kInt32: + return MShadowTypeInfo("int32", sizeof(int32_t)); + case kInt8: + return MShadowTypeInfo("int8", sizeof(int8_t), sizeof(index_t)); + case kInt64: + return MShadowTypeInfo("int64", sizeof(int64_t)); + case kBool: + return MShadowTypeInfo("bool", sizeof(bool), sizeof(index_t)); + default: + LOG(FATAL) << "Unknown type flag " << type_flag; + return MShadowTypeInfo("INVALID", 1); + } +} + } // namespace common } // namespace mxnet diff --git a/src/common/utils.h b/src/common/utils.h index aa0cb6b1b454..5582f711ae1f 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -950,6 +950,20 @@ inline int GetDefaultDtype(int dtype) { mshadow::kFloat32; } +struct MShadowTypeInfo { + std::string name; + int size; + int acc_size; + + MShadowTypeInfo(const std::string name, const int size, const int acc_size) : + name(std::move(name)), size(size), acc_size(acc_size) {} + + MShadowTypeInfo(const std::string name, const int size) : + MShadowTypeInfo(name, size, size) {} +}; + +MShadowTypeInfo mshadow_type_info(const int type_flag); + inline bool AlignedMemAlloc(void** ptr, size_t size, size_t alignment) { #if _MSC_VER *ptr = _aligned_malloc(size, alignment); diff 
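
A short sketch of how the MShadowTypeInfo table added above is consumed when building RTC source. The struct is reproduced locally so the snippet stands alone; the float16 entry matches the values returned by mshadow_type_info, while the emitted alias name is just an example.

  #include <iostream>
  #include <string>

  // Local copy of the struct added to src/common/utils.h above.
  struct MShadowTypeInfo {
    std::string name;
    int size;      // storage size in bytes
    int acc_size;  // accumulation type size in bytes
  };

  int main() {
    // mshadow_type_info(mshadow::kFloat16) above returns {"float16", 2, sizeof(float)}.
    const MShadowTypeInfo fp16{"float16", 2, 4};
    // The RTC launcher uses the name to emit type aliases into generated code:
    std::cout << "using InputType0 = " << fp16.name << ";\n";
    // and the sizes to compute vectorization widths and workspace requirements.
    std::cout << fp16.size << " " << fp16.acc_size << "\n";
    return 0;
  }
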
--git a/src/engine/stream_manager.h b/src/engine/stream_manager.h index 42d03e55a275..da1e4bc436ab 100644 --- a/src/engine/stream_manager.h +++ b/src/engine/stream_manager.h @@ -29,7 +29,7 @@ #include #include #include -#include "../common/cuda_utils.h" +#include "../common/cuda/utils.h" namespace mxnet { namespace engine { diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index e62351687083..3eda2c8712f7 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -30,7 +30,7 @@ #include #include #include "./threaded_engine.h" -#include "../common/cuda_utils.h" +#include "../common/cuda/utils.h" namespace mxnet { namespace engine { diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc index 1304594e24a8..dde16bc8fe5d 100644 --- a/src/engine/threaded_engine_pooled.cc +++ b/src/engine/threaded_engine_pooled.cc @@ -33,7 +33,7 @@ #include "./thread_pool.h" #include "./stream_manager.h" #if MXNET_USE_CUDA -#include "../common/cuda_utils.h" +#include "../common/cuda/utils.h" #endif namespace mxnet { diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 0c0c7db98174..eeab47b7e178 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -330,7 +330,7 @@ void OptimizeGraph(nnvm::Graph* full_graph, nnvm::Graph* fwd_graph, nnvm::Graph* size_t num_forward_outputs, const bool inlining) { input_map->resize(full_graph->indexed_graph().input_nodes().size()); std::iota(input_map->begin(), input_map->end(), 0); -#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32) +#if MXNET_USE_CUDA && !defined(_WIN32) if (context.dev_mask() == kGPU && !inlining && dmlc::GetEnv("MXNET_USE_FUSION", true)) { @@ -375,7 +375,7 @@ void OptimizeGraph(nnvm::Graph* full_graph, nnvm::Graph* fwd_graph, nnvm::Graph* dmlc::GetEnv("MXNET_USE_FUSION", false)) { exec::WarnFusionNotSupported(); } -#endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32) +#endif // MXNET_USE_CUDA && !defined(_WIN32) *fwd_graph = nnvm::Graph(); fwd_graph->outputs = std::vector(full_graph->outputs.begin(), diff --git a/src/imperative/pointwise_fusion_pass.cc b/src/imperative/pointwise_fusion_pass.cc index 3203f67e8b68..656a420eb654 100644 --- a/src/imperative/pointwise_fusion_pass.cc +++ b/src/imperative/pointwise_fusion_pass.cc @@ -48,13 +48,13 @@ void WarnFusionNotSupported() { << "Unset env var MXNET_USE_FUSION=1 to quiet this message."; #else LOG(WARNING) << "Omitting dynamic fused op creation- needs MXNet lib built with " - << "USE_CUDA=1 and ENABLE_CUDA_RTC=1. Unset env var MXNET_USE_FUSION=1 " + << "USE_CUDA=1. 
Unset env var MXNET_USE_FUSION=1 " << "to quiet this message."; #endif // defined(_WIN32) } } -#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#if MXNET_USE_CUDA namespace { bool IsFusionCompatible(nnvm::Node* n) { @@ -334,7 +334,7 @@ Graph FusePointwiseBackward(Graph &&g) { ret.outputs = g.outputs; return ret; } -#endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#endif // MXNET_USE_CUDA } // namespace exec } // namespace mxnet diff --git a/src/kvstore/kvstore_nccl.h b/src/kvstore/kvstore_nccl.h index e35f3a3da3fb..09bd880bfd68 100644 --- a/src/kvstore/kvstore_nccl.h +++ b/src/kvstore/kvstore_nccl.h @@ -38,7 +38,7 @@ #include #include "./comm.h" #include "./kvstore_local.h" -#include "../common/cuda_utils.h" +#include "../common/cuda/utils.h" // NCCL v2 introduces NCCL_MAJOR macro for versioning, // so if there is no such macro defined in nccl.h diff --git a/src/libinfo.cc b/src/libinfo.cc index d14aaf5769b2..28ddab200c49 100644 --- a/src/libinfo.cc +++ b/src/libinfo.cc @@ -39,7 +39,6 @@ class FeatureSet { feature_bits.set(CUDA, MXNET_USE_CUDA); feature_bits.set(CUDNN, MXNET_USE_CUDNN); feature_bits.set(NCCL, MXNET_USE_NCCL); - feature_bits.set(CUDA_RTC, MXNET_ENABLE_CUDA_RTC); feature_bits.set(TENSORRT, MXNET_USE_TENSORRT); // Check flags for example with gcc -msse3 -mavx2 -dM -E - < /dev/null | egrep "SSE|AVX" @@ -133,7 +132,6 @@ const std::vector EnumNames::names = { "CUDA", "CUDNN", "NCCL", - "CUDA_RTC", "TENSORRT", "CPU_SSE", "CPU_SSE2", diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu index a3f99c1afdf0..e00b4c3f948e 100644 --- a/src/ndarray/ndarray_function.cu +++ b/src/ndarray/ndarray_function.cu @@ -31,7 +31,7 @@ #include "../operator/tensor/init_op.h" #include "../operator/tensor/util/tensor_util-inl.h" #include "../operator/tensor/util/tensor_util-inl.cuh" -#include "../common/cuda_utils.h" +#include "../common/cuda/utils.h" #include "./ndarray_function.h" #include "./ndarray_function-inl.h" #include "./ndarray_function-inl.cuh" diff --git a/src/operator/bilinear_sampler.cu b/src/operator/bilinear_sampler.cu index e8b1ce68847f..dae14a645fd8 100644 --- a/src/operator/bilinear_sampler.cu +++ b/src/operator/bilinear_sampler.cu @@ -26,7 +26,7 @@ #include "./bilinear_sampler-inl.h" #include -#include "../common/cuda_utils.h" +#include "../common/cuda/utils.h" #if MXNET_USE_CUDNN == 1 #include "./cudnn_bilinear_sampler-inl.h" #endif // MXNET_USE_CUDNN diff --git a/src/operator/contrib/deformable_psroi_pooling.cu b/src/operator/contrib/deformable_psroi_pooling.cu index 62680d1fb8d1..2206b5aa67b3 100644 --- a/src/operator/contrib/deformable_psroi_pooling.cu +++ b/src/operator/contrib/deformable_psroi_pooling.cu @@ -29,7 +29,7 @@ #include #include #include -#include "../../common/cuda_utils.h" +#include "../../common/cuda/utils.h" #include "../mxnet_op.h" #define DeformablePSROIPOOLING_CUDA_CHECK(condition) \ diff --git a/src/operator/contrib/gradient_multiplier_op.cu b/src/operator/contrib/gradient_multiplier_op.cu index 7159cea9805d..f519f0db5f49 100644 --- a/src/operator/contrib/gradient_multiplier_op.cu +++ b/src/operator/contrib/gradient_multiplier_op.cu @@ -34,8 +34,8 @@ NNVM_REGISTER_OP(_contrib_gradientmultiplier) .set_attr("FCompute", UnaryOp::IdentityCompute); NNVM_REGISTER_OP(_contrib_backward_gradientmultiplier) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FComputeEx", BinaryScalarOp::ComputeEx); +.set_attr("FCompute", BinaryScalarRTCCompute{"mul"}) +.set_attr("(FComputeEx", BinaryScalarRTCCompute{"mul"}); } // namespace op } // 
namespace mxnet diff --git a/src/operator/contrib/nn/deformable_im2col.cuh b/src/operator/contrib/nn/deformable_im2col.cuh index 9494fb379faf..8efee9979046 100644 --- a/src/operator/contrib/nn/deformable_im2col.cuh +++ b/src/operator/contrib/nn/deformable_im2col.cuh @@ -67,7 +67,7 @@ #include #include #include "../../mxnet_op.h" -#include "../../../common/cuda_utils.h" +#include "../../../common/cuda/utils.h" diff --git a/src/operator/contrib/nn/modulated_deformable_im2col.cuh b/src/operator/contrib/nn/modulated_deformable_im2col.cuh index 16d9cef46d4e..9673edf813a4 100644 --- a/src/operator/contrib/nn/modulated_deformable_im2col.cuh +++ b/src/operator/contrib/nn/modulated_deformable_im2col.cuh @@ -86,7 +86,7 @@ #include #include #include "../../mxnet_op.h" -#include "../../../common/cuda_utils.h" +#include "../../../common/cuda/utils.h" diff --git a/src/operator/contrib/psroi_pooling.cu b/src/operator/contrib/psroi_pooling.cu index 8765eb95b72e..62ecd4ce8baa 100644 --- a/src/operator/contrib/psroi_pooling.cu +++ b/src/operator/contrib/psroi_pooling.cu @@ -30,7 +30,7 @@ #include #include #include -#include "../../common/cuda_utils.h" +#include "../../common/cuda/utils.h" #include "../mxnet_op.h" #define PSROIPOOLING_CUDA_CHECK(condition) \ diff --git a/src/operator/contrib/stes_op.cu b/src/operator/contrib/stes_op.cu index 85e3ddaf206f..5ce947900899 100644 --- a/src/operator/contrib/stes_op.cu +++ b/src/operator/contrib/stes_op.cu @@ -31,13 +31,13 @@ namespace op { // Round STE NNVM_REGISTER_OP(_contrib_round_ste) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"round"}) +.set_attr("FComputeEx", UnaryRTCCompute{"round"}); // Sign STE NNVM_REGISTER_OP(_contrib_sign_ste) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"sign"}) +.set_attr("FComputeEx", UnaryRTCCompute{"sign"}); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/transformer.cu b/src/operator/contrib/transformer.cu index bcbc18525c09..bfa4993e0d4b 100644 --- a/src/operator/contrib/transformer.cu +++ b/src/operator/contrib/transformer.cu @@ -30,7 +30,7 @@ #include #include "./transformer-inl.h" -#include "../../common/cuda_utils.h" +#include "../../common/cuda/utils.h" namespace mxnet { namespace op { diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h index 0b10f821d8e1..0add7eaa99da 100644 --- a/src/operator/fusion/fused_op-inl.h +++ b/src/operator/fusion/fused_op-inl.h @@ -24,42 +24,12 @@ #include #include -#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#if MXNET_USE_CUDA namespace mxnet { namespace fusion { -const char fp16_support_string[] = R"code( -struct __align__(2) __half { - __host__ __device__ __half() { } - unsigned short __x; -}; -/* Definitions of intrinsics */ -__device__ inline __half __float2half(const float f) { - __half val; - asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(val.__x) : "f"(f)); - return val; -} -__device__ inline float __half2float(const __half h) { - float val; - asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(h.__x)); - return val; -} - -typedef __half half; -)code"; - -const char type_support_string[] = R"code( -using float32 = float; -using float64 = double; -using float16 = half; -using uint8 = unsigned char; -using int8 = char; -using int32 = int; -using int64 = long long; -)code"; - const std::map>> ops_desc = { {"elemwise_add" , {{"op::add(%, %)", "_0", "_1"}}}, {"_plus" , {{"op::add(%, 
%)", "_0", "_1"}}}, @@ -81,6 +51,7 @@ const std::map>> ops_desc = { {"_maximum" , {{"op::max(%, %)", "_0", "_1"}}}, {"_Minimum" , {{"op::min(%, %)", "_0", "_1"}}}, {"_minimum" , {{"op::min(%, %)", "_0", "_1"}}}, + {"_mod" , {{"op::mod(%, %)", "_0", "_1"}}}, {"amp_cast" , {{"op::identity(%)", "_0"}}}, {"_backward_amp_cast" , {{"op::identity(%)", "_0"}}}, {"relu" , {{"op::relu(%)", "_0"}}}, @@ -150,6 +121,8 @@ const std::map>> ops_desc = { {"_rpower_scalar" , {{"op::rpow(%, float(%))", "_0", "scalar"}}}, {"_RPowerScalar" , {{"op::rpow(%, float(%))", "_0", "scalar"}}}, {"_RDivScalar" , {{"op::rdiv(%, float(%))", "_0", "scalar"}}}, + {"_mod_scalar" , {{"op::mod(%, float(%))", "_0", "scalar"}}}, + {"_rmod_scalar" , {{"op::rmod(%, float(%))", "_0", "scalar"}}}, {"Cast" , {{"op::cast<%>(%)", "dtype", "_0"}}}, {"cast" , {{"op::cast<%>(%)", "dtype", "_0"}}}, {"Activation" , {{"op::%(%)", "act_type", "_0"}}}, @@ -159,51 +132,50 @@ const std::map>> ops_desc = { {"negative" , {{"(-%)", "_0"}}}, {"_hypot" , {{"op::hypot(%, %)", "_0", "_1"}}}, {"_hypot_scalar" , {{"op::hypot(%, float(%))", "_0", "scalar"}}}, - {"_backward_relu" , {{"op::backward_relu(%, %)", "_1", "_0"}}}, - {"_backward_sigmoid" , {{"op::backward_sigmoid(%, %)", "_1", "_0"}}}, - {"_backward_expm1" , {{"op::backward_expm1(%, %)", "_1", "_0"}}}, - {"_backward_log" , {{"op::backward_log(%, %)", "_1", "_0"}}}, - {"_backward_log10" , {{"op::backward_log10(%, %)", "_1", "_0"}}}, - {"_backward_log2" , {{"op::backward_log2(%, %)", "_1", "_0"}}}, - {"_backward_log1p" , {{"op::backward_log1p(%, %)", "_1", "_0"}}}, - {"_backward_sin" , {{"op::backward_sin(%, %)", "_1", "_0"}}}, - {"_backward_cos" , {{"op::backward_cos(%, %)", "_1", "_0"}}}, - {"_backward_tan" , {{"op::backward_tan(%, %)", "_1", "_0"}}}, - {"_backward_arcsin" , {{"op::backward_arcsin(%, %)", "_1", "_0"}}}, - {"_backward_arccos" , {{"op::backward_arccos(%, %)", "_1", "_0"}}}, - {"_backward_arctan" , {{"op::backward_arctan(%, %)", "_1", "_0"}}}, - {"_backward_sinh" , {{"op::backward_sinh(%, %)", "_1", "_0"}}}, - {"_backward_cosh" , {{"op::backward_cosh(%, %)", "_1", "_0"}}}, - {"_backward_tanh" , {{"op::backward_tanh(%, %)", "_1", "_0"}}}, - {"_backward_arcsinh" , {{"op::backward_arcsinh(%, %)", "_1", "_0"}}}, - {"_backward_arccosh" , {{"op::backward_arccosh(%, %)", "_1", "_0"}}}, - {"_backward_arctanh" , {{"op::backward_arctanh(%, %)", "_1", "_0"}}}, - {"_backward_sqrt" , {{"op::backward_sqrt(%, %)", "_1", "_0"}}}, - {"_backward_rsqrt" , {{"op::backward_rsqrt(%, %)", "_1", "_0"}}}, - {"_backward_cbrt" , {{"op::backward_cbrt(%, %)", "_1", "_0"}}}, - {"_backward_rcbrt" , {{"op::backward_rcbrt(%, %)", "_1", "_0"}}}, - {"_backward_square" , {{"op::backward_square(%, %)", "_1", "_0"}}}, + {"logical_not" , {{"op::logical_not(%)", "_0"}}}, + {"_backward_relu" , {{"op::backward_relu(%, %)", "_0", "_1"}}}, + {"_backward_sigmoid" , {{"op::backward_sigmoid(%, %)", "_0", "_1"}}}, + {"_backward_expm1" , {{"op::backward_expm1(%, %)", "_0", "_1"}}}, + {"_backward_log" , {{"op::backward_log(%, %)", "_0", "_1"}}}, + {"_backward_log10" , {{"op::backward_log10(%, %)", "_0", "_1"}}}, + {"_backward_log2" , {{"op::backward_log2(%, %)", "_0", "_1"}}}, + {"_backward_log1p" , {{"op::backward_log1p(%, %)", "_0", "_1"}}}, + {"_backward_sin" , {{"op::backward_sin(%, %)", "_0", "_1"}}}, + {"_backward_cos" , {{"op::backward_cos(%, %)", "_0", "_1"}}}, + {"_backward_tan" , {{"op::backward_tan(%, %)", "_0", "_1"}}}, + {"_backward_arcsin" , {{"op::backward_arcsin(%, %)", "_0", "_1"}}}, + {"_backward_arccos" , 
{{"op::backward_arccos(%, %)", "_0", "_1"}}}, + {"_backward_arctan" , {{"op::backward_arctan(%, %)", "_0", "_1"}}}, + {"_backward_sinh" , {{"op::backward_sinh(%, %)", "_0", "_1"}}}, + {"_backward_cosh" , {{"op::backward_cosh(%, %)", "_0", "_1"}}}, + {"_backward_tanh" , {{"op::backward_tanh(%, %)", "_0", "_1"}}}, + {"_backward_arcsinh" , {{"op::backward_arcsinh(%, %)", "_0", "_1"}}}, + {"_backward_arccosh" , {{"op::backward_arccosh(%, %)", "_0", "_1"}}}, + {"_backward_arctanh" , {{"op::backward_arctanh(%, %)", "_0", "_1"}}}, + {"_backward_sqrt" , {{"op::backward_sqrt(%, %)", "_0", "_1"}}}, + {"_backward_rsqrt" , {{"op::backward_rsqrt(%, %)", "_0", "_1"}}}, + {"_backward_cbrt" , {{"op::backward_cbrt(%, %)", "_0", "_1"}}}, + {"_backward_rcbrt" , {{"op::backward_rcbrt(%, %)", "_0", "_1"}}}, + {"_backward_square" , {{"op::backward_square(%, %)", "_0", "_1"}}}, {"_backward_div_scalar" , {{"(% * 1.0f/float(%))", "_0", "scalar"}}}, {"_backward_div_scalar" , {{"(% * 1.0f/float(%))", "_0", "scalar"}}}, - {"_backward_rdiv_scalar" , {{"(-% * float(%) / (% * %))", "_0", - "scalar", "_1", "_1"}}}, + {"_backward_rdiv_scalar" , {{"op::rdiv_grad(%, %) * %", "_1", + "scalar", "_0"}}}, {"_backward_hypot_scalar" , {{"(% * % / op::hypot(%, float(%)))", "_0", "_1", "_1", "scalar"}}}, {"_backward_radians" , {{"op::radians(%)", "_0"}}}, - {"_backward_erf" , {{"op::backward_erf(%, %)", "_1", "_0"}}}, - {"_backward_erfinv" , {{"op::backward_erfinv(%, %)", "_1", "_0"}}}, - {"_backward_reciprocal" , {{"op::backward_reciprocal(%, %)", "_1", "_0"}}}, - {"_backward_abs" , {{"(% * op::sign(%))", "_0", "_1"}}}, + {"_backward_erf" , {{"op::backward_erf(%, %)", "_0", "_1"}}}, + {"_backward_erfinv" , {{"op::backward_erfinv(%, %)", "_0", "_1"}}}, + {"_backward_reciprocal" , {{"op::backward_reciprocal(%, %)", "_0", "_1"}}}, + {"_backward_abs" , {{"(op::backward_abs(%, %))", "_0", "_1"}}}, {"_backward_degrees" , {{"op::degrees(%)", "_0"}}}, - {"_backward_sign" , {{"op::zero(%)", "_0"}}}, - {"_backward_clip" , {{"op::backward_clip(%, %, %, %)", "_1", "_0", + {"_backward_clip" , {{"op::backward_clip(%, %, %, %)", "_0", "_1", "a_min", "a_max"}}}, {"smooth_l1" , {{"op::smooth_l1(%, float(%))", "_0", "scalar"}}}, - {"_backward_smooth_l1" , {{"op::backward_smooth_l1(%, float(%), %)", + {"_backward_smooth_l1" , {{"op::smooth_l1_grad(%, float(%)) * %", "_1", "scalar", "_0"}}}, // TODO(ptredak): arange // TODO(ptredak): LeakyRelu - // TODO(ptredak): mod and rmod {"_backward_sub" , {{"(%)", "_0"}, {"(-(%))", "_0"}}}, {"_backward_mul" , {{"(% * %)", "_0", "_2"}, @@ -229,7 +201,7 @@ const std::map>> LeakyReLU_ops {"gelu" , {{"op::gelu(%)", "_0"}}}, }; const std::map>> LeakyReLU_bwd_ops = { - {"gelu" , {{"op::backward_gelu(%, %)", "_1", "_0"}}}, + {"gelu" , {{"op::backward_gelu(%, %)", "_0", "_1"}}}, }; const std::map slice_ops = { @@ -247,772 +219,6 @@ const std::vector variable_io_ops = { "_backward_cast" }; -const char function_definitions[] = R"code( - -#define INT_MAX (2147483647) - -namespace op { - -template -struct LoadType { - using Type = DType; -}; - -template <> -struct LoadType { - using Type = float; -}; - -template -__device__ inline typename LoadType::Type load(const DType input) { - return input; -} - -template <> -__device__ inline float load(const half input) { - return __half2float(input); -} - -template -__device__ inline DType1 store(const DType2 input, DType1* ref) { - return input; -} - -template -__device__ inline half store(const DType input, half* ref) { - return __float2half(input); -} - -template -struct 
VectorConfig { - static_assert(size >= 4, "VectorConfig needs to have size of at least 4B"); - using IndexType = float; -}; - -template <> -struct VectorConfig<8> { - using IndexType = double; -}; - -template <> -struct VectorConfig<16> { - using IndexType = double2; -}; - -template <> -struct VectorConfig<32> { - using IndexType = double4; -}; - -template -__device__ inline DType add_elem(const DType& x, const DType& y) { - return x + y; -} - -template <> -__device__ inline half add_elem(const half& x, const half& y) { - return __float2half(__half2float(x) + __half2float(y)); -} - -template -union VectorType { - typename VectorConfig::IndexType y; - DType x[nvec]; - __device__ VectorType () {}; - __device__ VectorType (const VectorType& y2) { - y = y2.y; - } - __device__ VectorType (const decltype(y) &y2) { - y = y2; - } - __device__ inline VectorType& operator+=(const VectorType& rhs) { - #pragma unroll - for (int i = 0; i < nvec; ++i) { - x[i] = add_elem(x[i], rhs.x[i]); - } - return *this; - } -}; - -template -struct Shape { - int x[ndim]; - size_t size; - __device__ inline const int& operator [](const int i) const { - return x[i]; - } - __device__ inline int& operator [](const int i) { - return x[i]; - } - __device__ inline void set(const int def) { - #pragma unroll - for (int i = 0; i < ndim; i++) { - x[i] = def; - } - } -}; - -template <> -struct Shape<0> { - size_t size; -}; - -template -__device__ inline VectorType load_index(const DType * input, int i, - const Shape &shape) { - if (i < shape.size) { - const auto* vector_input = reinterpret_cast< - const typename VectorConfig::IndexType *>( - input + i); - VectorType ret = {*vector_input}; - return ret; - } else { - VectorType ret({0}); - return ret; - } -} - -template -__device__ inline VectorType global_load_index(const DType * input, int i, - const Shape &shape) { - if (i < shape.size) { - const auto* vector_input = reinterpret_cast< - const typename VectorConfig::IndexType *>( - input + i); - VectorType ret = {__ldg(vector_input)}; - return ret; - } else { - VectorType ret({0}); - return ret; - } -} - -template -__device__ inline VectorType load_slice(const DType * input, const Shape& shape, - Shape begin, Shape end, - int offset) { - int idx[nvec]; - - Shape ref_strides; - Shape strides; - ref_strides[ndim-1] = 1; - strides[ndim-1] = 1; - #pragma unroll - for (int dim = ndim-1; dim >=0; dim--) { - if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim]; - if (end[dim] < 0) end[dim] = shape[dim] + end[dim]; - if (end[dim] == INT_MAX) end[dim] = shape[dim]; - if (dim > 0) { - ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]); - strides[dim-1] = strides[dim] * shape[dim]; - } - } - #pragma unroll - for (int j = 0; j < nvec; j++) { - idx[j] = 0; - int ref_idx = offset + j; - #pragma unroll - for (int dim = 0; dim < ndim; dim++) { - int stride = ref_strides[dim]; - if (shape[dim] > 1) { - idx[j] += (ref_idx / stride + begin[dim]) * strides[dim]; - } - ref_idx = ref_idx % stride; - } - } - VectorType ret; - #pragma unroll - for (int j = 0; j < nvec; j++) { - ret.x[j] = *(input + idx[j]); - } - return ret; -} - -template -__device__ inline VectorType fast_load_slice(const DType * input, - const Shape& shape, - Shape begin, - Shape end, - int offset) { - int idx = 0; - - Shape ref_strides; - Shape strides; - ref_strides[ndim-1] = 1; - strides[ndim-1] = 1; - #pragma unroll - for (int dim = ndim-1; dim >=0; dim--) { - if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim]; - if (end[dim] < 0) end[dim] = shape[dim] + 
end[dim]; - if (end[dim] == INT_MAX) end[dim] = shape[dim]; - if (dim > 0) { - ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]); - strides[dim-1] = strides[dim] * shape[dim]; - } - } - int ref_idx = offset; - #pragma unroll - for (int dim = 0; dim < ndim; dim++) { - int stride = ref_strides[dim]; - if (shape[dim] > 1) { - idx += (ref_idx / stride + begin[dim]) * strides[dim]; - } - ref_idx = ref_idx % stride; - } - return global_load_index(input, idx, shape); -} - -template -__device__ inline void store_index(const VectorType value, int i, - DType * output, const Shape& shape) { - if (i < (shape.size + nvec - 1) / nvec) { - auto vector_output = reinterpret_cast< - typename VectorConfig::IndexType *>(output); - vector_output[i] = value.y; - } -} - -template -__device__ inline void store_add_index(const VectorType value, int i, - DType * output, const Shape& shape) { - if (i < (shape.size + nvec - 1) / nvec) { - auto vector_output = reinterpret_cast< - typename VectorConfig::IndexType *>(output); - VectorType ret(vector_output[i]); - ret += value; - vector_output[i] = ret.y; - } -} - -template -__device__ inline DType identity(const DType val) { - return val; -} - -template -__device__ inline DType add(const DType a, const DType2 b) { - return a + b; -} - -template -__device__ inline DType sub(const DType a, const DType2 b) { - return a - b; -} - -template -__device__ inline DType mul(const DType a, const DType2 b) { - return a * b; -} - -template -__device__ inline DType div(const DType a, const DType2 b) { - return a / b; -} - -template -__device__ inline DType rdiv(const DType a, const DType2 b) { - return b / a; -} - -template -__device__ inline DType power(const DType a, const DType2 b) { - return powf(a, b); -} - -template -__device__ inline DType rpow(const DType a, const DType2 b) { - return powf(b, a); -} - -template -__device__ inline DType max(const DType a, const DType2 b) { - return a > b ? a : b; -} - -template -__device__ inline DType min(const DType a, const DType2 b) { - return a < b ? a : b; -} - -template -__device__ inline DType hypot(const DType a, const DType2 b) { - return hypotf(a, b); -} - -template -__device__ inline typename LoadType::Type cast(const DType val) { - return static_cast::Type>(val); -} - -// activations - -template -__device__ inline DType relu(const DType val) { - return val > 0 ? val : 0; -} - -const float SQRT_2 = 1.4142135623730950488016887242096; -// compatible with mshadow_op.h version -template -__device__ inline DType gelu(const DType val) { - return DType(0.5f * static_cast(val) * - (1.0f + erf(static_cast(val) / SQRT_2))); -} - -template -__device__ inline DType sigmoid(const DType val) { - return 1.f/(1 + expf(-val)); -} - -template -__device__ inline DType softrelu(const DType val) { - // Avoid overflow of exp for large inputs. - // The threshold 20 is chosen such that softrelu(a) = a - // for a > 20 using floating precision. - return val > 20 ? 
val : logf(1 + expf(val)); -} - -template -__device__ inline DType softsign(const DType val) { - return val / (1 + fabsf(val)); -} - -// exp and log - -template -__device__ inline DType exp(const DType val) { - return expf(val); -} - -template -__device__ inline DType expm1(const DType val) { - return expm1f(val); -} - -template -__device__ inline DType log(const DType val) { - return logf(val); -} - -template -__device__ inline DType log10(const DType val) { - return log10f(val); -} - -template -__device__ inline DType log2(const DType val) { - return log2f(val); -} - -template -__device__ inline DType log1p(const DType val) { - return log1pf(val); -} - -// trigonometric - -constexpr double pi = 3.14159265358979323846; - -template -__device__ inline DType degrees(const DType val) { - return (val / pi) * 180; -} - -template -__device__ inline DType radians(const DType val) { - return (val / 180.0) * pi; -} - -template -__device__ inline DType sin(const DType val) { - return sinf(val); -} - -template -__device__ inline DType cos(const DType val) { - return cosf(val); -} - -template -__device__ inline DType tan(const DType val) { - return tanf(val); -} - -template -__device__ inline DType arcsin(const DType val) { - return asinf(val); -} - -template -__device__ inline DType arccos(const DType val) { - return acosf(val); -} - -template -__device__ inline DType arctan(const DType val) { - return atanf(val); -} - -template -__device__ inline DType sinh(const DType val) { - return sinhf(val); -} - -template -__device__ inline DType cosh(const DType val) { - return coshf(val); -} - -template -__device__ inline DType tanh(const DType val) { - return tanhf(val); -} - -template -__device__ inline DType arcsinh(const DType val) { - return asinhf(val); -} - -template -__device__ inline DType arccosh(const DType val) { - return acoshf(val); -} - -template -__device__ inline DType arctanh(const DType val) { - return atanhf(val); -} - -// sqrt - -template -__device__ inline DType sqrt(const DType val) { - return sqrtf(val); -} - -template -__device__ inline DType rsqrt(const DType val) { - return rsqrtf(val); -} - -template -__device__ inline DType cbrt(const DType val) { - return cbrtf(val); -} - -template -__device__ inline DType rcbrt(const DType val) { - return rcbrtf(val); -} - -template -__device__ inline DType square(const DType val) { - return val * val; -} - -template -__device__ inline typename LoadType::Type zero(const DType val) { - return 0; -} - -template -__device__ inline typename LoadType::Type zero() { - return 0; -} - -template -__device__ inline typename LoadType::Type one(const DType val) { - return 1; -} - -template -__device__ inline typename LoadType::Type one() { - return 1; -} - -template -__device__ inline DType round(const DType val) { - return roundf(val); -} - -template -__device__ inline DType rint(const DType val) { - return rintf(val); -} - -template -__device__ inline DType fix(const DType val) { - const auto floor = floorf(val); - const auto ceil = ceilf(val); - return (floor > 0 ? floor : -floor) < (ceil > 0 ? ceil : -ceil) ? 
floor : ceil; -} - -template -__device__ inline DType floor(const DType val) { - return floorf(val); -} - -template -__device__ inline DType ceil(const DType val) { - return ceilf(val); -} - -template -__device__ inline DType trunc(const DType val) { - return truncf(val); -} - -template -__device__ inline DType clip(const DType val, const float a_min, const float a_max) { - return max(min(val, a_max), a_min); -} - -template -__device__ inline DType sign(const DType val) { - if (val < 0) return -1; - return val > 0 ? 1 : 0; -} - -template -__device__ inline DType reciprocal(const DType val) { - return 1.0f / val; -} - -template -__device__ inline DType abs(const DType val) { - return fabsf(val); -} - -template -__device__ inline DType gamma(const DType val) { - return tgammaf(val); -} - -template -__device__ inline DType gammaln(const DType val) { - return lgammaf(val); -} - -template -__device__ inline DType erf(const DType val) { - return erff(val); -} - -template -__device__ inline DType erfinv(const DType val) { - return erfinvf(val); -} - -template -__device__ inline DType1 smooth_l1(const DType1 val, const DType2 scalar) { - const auto bsq = scalar * scalar; - const auto ibsq = 1.0f / bsq; - if (val > ibsq) { - return val - 0.5f * ibsq; - } else if (val < -ibsq) { - return -val - 0.5f * ibsq; - } else { - return 0.5f * val * val * bsq; - } -} - -} // namespace op - -)code"; - -const char backward_function_definitions[] = R"code( - -namespace op { - -template -__device__ inline DTypeGrad backward_relu(const DType val, const DTypeGrad grad) { - return val > 0 ? grad : 0; -} - -template -__device__ inline DTypeGrad backward_sigmoid(const DType out, const DTypeGrad grad) { - return grad * out * (1 - out); -} - -template -__device__ inline DTypeGrad backward_softrelu(const DType val, const DTypeGrad grad) { - return grad * sigmoid(val); -} - -template -__device__ inline DTypeGrad backward_softsign(const DType val, const DTypeGrad grad) { - const DType ap1 = 1 + fabsf(val); - return grad / (ap1 * ap1); -} - -template -__device__ inline DTypeGrad backward_exp(const DType val, const DTypeGrad grad) { - return grad * expf(val); -} - -template -__device__ inline DTypeGrad backward_expm1(const DType val, const DTypeGrad grad) { - return grad * expf(val); -} - -template -__device__ inline DTypeGrad backward_log(const DType val, const DTypeGrad grad) { - return grad / val; -} - -template -__device__ inline DTypeGrad backward_log10(const DType val, const DTypeGrad grad) { - return grad / (val * logf(10)); -} - -template -__device__ inline DTypeGrad backward_log2(const DType val, const DTypeGrad grad) { - return grad / (val * logf(2)); -} - -template -__device__ inline DTypeGrad backward_log1p(const DType val, const DTypeGrad grad) { - return grad / (1 + val); -} - -template -__device__ inline DTypeGrad backward_sin(const DType val, const DTypeGrad grad) { - return grad * cosf(val); -} - -template -__device__ inline DTypeGrad backward_cos(const DType val, const DTypeGrad grad) { - return -grad * sinf(val); -} - -// Uses output from tan -template -__device__ inline DTypeGrad backward_tan(const DType out, const DTypeGrad grad) { - return grad * (out * out + 1); -} - -template -__device__ inline DTypeGrad backward_arcsin(const DType val, const DTypeGrad grad) { - return grad / sqrtf(1 - val*val); -} - -template -__device__ inline DTypeGrad backward_arccos(const DType val, const DTypeGrad grad) { - return -grad / sqrtf(1 - val*val); -} - -template -__device__ inline DTypeGrad backward_arctan(const DType 
val, const DTypeGrad grad) { - return grad / (1 + val*val); -} - -template -__device__ inline DTypeGrad backward_sinh(const DType val, const DTypeGrad grad) { - return grad * coshf(val); -} - -template -__device__ inline DTypeGrad backward_cosh(const DType val, const DTypeGrad grad) { - return grad * sinhf(val); -} - -// Uses tanh output -template -__device__ inline DTypeGrad backward_tanh(const DType out, const DTypeGrad grad) { - return grad * (1 - out * out); -} - -template -__device__ inline DTypeGrad backward_arcsinh(const DType val, const DTypeGrad grad) { - return grad / sqrtf(val * val + 1); -} - -template -__device__ inline DTypeGrad backward_arccosh(const DType val, const DTypeGrad grad) { - return grad / sqrtf(val * val - 1); -} - -template -__device__ inline DTypeGrad backward_arctanh(const DType val, const DTypeGrad grad) { - return grad / (1 - val * val); -} - -template -__device__ inline DTypeGrad backward_sqrt(const DType out, const DTypeGrad grad) { - return 0.5 * grad / out; -} - -template -__device__ inline DTypeGrad backward_rsqrt(const DType val, const DTypeGrad grad) { - const DType inv = 1 / val; - return -0.5 * grad * sqrtf(inv) * inv; -} - -template -__device__ inline DTypeGrad backward_cbrt(const DType out, const DTypeGrad grad) { - return grad / (3.0f * out * out); -} - -template -__device__ inline DTypeGrad backward_rcbrt(const DType val, const DTypeGrad grad) { - const DType inv = 1 / val; - return -1.f/3.f * grad * cbrtf(inv) * inv; -} - -template -__device__ inline DTypeGrad backward_square(const DType val, const DTypeGrad grad) { - return 2 * val * grad; -} - -template -__device__ inline DTypeGrad backward_clip(const DType val, const DTypeGrad grad, - const float a_min, const float a_max) { - if (val > a_max || val < a_min) { - return 0; - } else { - return grad; - } -} - -template -__device__ inline DTypeGrad backward_reciprocal(const DType val, const DTypeGrad grad) { - return -grad / (val * val); -} - -template -__device__ inline DTypeGrad backward_erf(const DType val, const DTypeGrad grad) { - return 2.0f / sqrt(pi) * exp(-(val*val)) * grad; -} - -template -__device__ inline DTypeGrad backward_erfinv(const DType val, const DTypeGrad grad) { - return 0.5f * sqrt(pi) * exp(val * val) * grad; -} - -template -__device__ inline DTypeGrad backward_smooth_l1(const DType val, const DType2 scalar, - const DTypeGrad grad) { - auto bsq = scalar * scalar; - auto ibsq = 1.0f / bsq; - if (val > ibsq) { - return grad; - } else if (val < -ibsq) { - return -grad; - } else { - return bsq * val * grad; - } -} - -// compatible with mshadow_op.h version -template -__device__ inline DTypeGrad backward_gelu(const DType val, const DTypeGrad grad) { - return grad * DType(0.5f * (1.0f + erf(static_cast(val) / SQRT_2) + - static_cast(val) * backward_erf(static_cast(val) / SQRT_2, 1.0f) / SQRT_2)); -} - -} // namespace op - -)code"; const char kernel_begin[] = R"code( const int tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -1028,6 +234,6 @@ const char kernel_end[] = R"code(} } // namespace mxnet -#endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#endif // MXNET_USE_CUDA #endif // MXNET_OPERATOR_FUSION_FUSED_OP_INL_H_ diff --git a/src/operator/fusion/fused_op.cc b/src/operator/fusion/fused_op.cc index 596f4e7146e0..fafc75d9aa93 100644 --- a/src/operator/fusion/fused_op.cc +++ b/src/operator/fusion/fused_op.cc @@ -23,7 +23,7 @@ #include "../operator_common.h" #include "../../imperative/exec_pass.h" -#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#if MXNET_USE_CUDA namespace mxnet { 
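
The ops_desc entries updated above are %-placeholder templates. Below is a rough, standalone approximation of how such an entry could be expanded into generated source; the real substitution lives in FusedOp::GenerateCode further down and handles more cases, and the variable names here are invented for illustration.

  #include <iostream>
  #include <string>
  #include <vector>

  // desc[0] is the pattern; desc[1..] name what each '%' expands to:
  // "_N" refers to the N-th input variable, anything else is used verbatim here.
  std::string expand(const std::vector<std::string> &desc,
                     const std::vector<std::string> &input_vars) {
    std::string out;
    size_t arg = 1;
    for (const char c : desc[0]) {
      if (c == '%') {
        const std::string &token = desc[arg++];
        out += (token[0] == '_') ? input_vars[std::stoul(token.substr(1))] : token;
      } else {
        out += c;
      }
    }
    return out;
  }

  int main() {
    // With the new argument order, "_0" is the incoming gradient and "_1" the
    // forward tensor it is paired with, so this prints:
    //   op::backward_sigmoid(grad_out, sigmoid_out)
    std::cout << expand({"op::backward_sigmoid(%, %)", "_0", "_1"},
                        {"grad_out", "sigmoid_out"}) << "\n";
    return 0;
  }
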
@@ -302,4 +302,4 @@ NNVM_REGISTER_OP(_FusedOpOutHelper) } // namespace mxnet -#endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#endif // MXNET_USE_CUDA diff --git a/src/operator/fusion/fused_op.cu b/src/operator/fusion/fused_op.cu index fe667946a0c4..3c27e87af2a7 100644 --- a/src/operator/fusion/fused_op.cu +++ b/src/operator/fusion/fused_op.cu @@ -17,12 +17,7 @@ * under the License. */ -// Additional use of MXNET_USE_CUDA is not needed to guard a '.cu' file. -#if MXNET_ENABLE_CUDA_RTC - #include -#include -#include #include #include #include @@ -31,7 +26,8 @@ #include "../operator_common.h" #include "../elemwise_op_common.h" #include "../../imperative/exec_pass.h" -#include "../../common/cuda_utils.h" +#include "../../common/cuda/utils.h" +#include "../../common/cuda/rtc.h" namespace mxnet { @@ -170,30 +166,6 @@ void AddPointerAndShape(const TBlob& data, }); } -// Obtain compilation log from the program. -std::string GetCompileLog(nvrtcProgram program) { - size_t log_size_including_null; - NVRTC_CALL(nvrtcGetProgramLogSize(program, &log_size_including_null)); - // For most std::string implementations, this is probably 1 char bigger than needed. OK though. - std::string log(log_size_including_null, '\0'); - NVRTC_CALL(nvrtcGetProgramLog(program, &log[0])); - // Make sure the string reflects the true size (so minus the null terminator). - log.resize(log_size_including_null - 1); - return log; -} - -// Obtain compilation result (ptx assembly) from the program. -std::string GetPtx(nvrtcProgram program) { - size_t ptx_size_including_null; - NVRTC_CALL(nvrtcGetPTXSize(program, &ptx_size_including_null)); - // For most std::string implementations, this is probably 1 char bigger than needed. OK though. - std::string ptx(ptx_size_including_null, '\0'); - NVRTC_CALL(nvrtcGetPTX(program, &ptx[0])); - // Make sure the string reflects the true size (so minus the null terminator). 
- ptx.resize(ptx_size_including_null - 1); - return ptx; -} - } // namespace std::string FusedOp::GenerateCode(const std::vector &req, @@ -360,7 +332,7 @@ std::string FusedOp::GenerateCode(const std::vector &req, size_t counter = 0; for (const auto& entry : g.outputs()) { std::string var_name = "output" + std::to_string(counter); - code += "op::VectorType vec_" + var_name + ";\n"; ++counter; } @@ -376,7 +348,7 @@ std::string FusedOp::GenerateCode(const std::vector &req, if (source->is_variable()) { if (load_index[i]) { code += "const auto " + var_name + " = op::load(vec_" + - variables[{i, 0}] + ".x[j]);\n"; + variables[{i, 0}] + ".scratch_.separate[j]);\n"; CHECK_EQ(outputs[i], 1); variables[{i, 0}] = var_name; } @@ -398,7 +370,9 @@ std::string FusedOp::GenerateCode(const std::vector &req, } if (fusion::slice_ops.find(op_name) != fusion::slice_ops.end()) { - code += "const auto " + var_name + " = op::load(" + variables[{i, 0}] + ".x[j]);\n"; + code += "const auto " + var_name + + " = op::load(" + variables[{i, 0}] + + ".scratch_.separate[j]);\n"; variables[{i, 0}] = var_name; continue; } @@ -422,17 +396,17 @@ std::string FusedOp::GenerateCode(const std::vector &req, if (op_name == "_backward_Activation") { CHECK_EQ(outputs[i], 1); std::string act_type = node.source->attrs.dict.at("act_type"); - std::string rhs, lhs; - rhs = variables[{node.inputs[0].node_id, node.inputs[0].index}]; + std::string ograd, input; + ograd = variables[{node.inputs[0].node_id, node.inputs[0].index}]; if (act_type == "relu" || act_type == "sigmoid" || act_type == "tanh") { - lhs = variables[{node.inputs[1].node_id, node.inputs[1].index}]; + input = variables[{node.inputs[1].node_id, node.inputs[1].index}]; } else { - lhs = variables[{node.inputs[2].node_id, node.inputs[2].index}]; + input = variables[{node.inputs[2].node_id, node.inputs[2].index}]; } code += "const auto " + var_name + " = op::backward_" + act_type + - "(" + lhs + ", " + rhs + ");\n"; + "(" + ograd + ", " + input + ");\n"; variables[{i, 0}] = var_name; continue; @@ -507,7 +481,7 @@ std::string FusedOp::GenerateCode(const std::vector &req, for (const auto& entry : g.outputs()) { const std::string& var = variables[{entry.node_id, entry.index}]; const auto var_name = "output" + std::to_string(counter); - code += "vec_" + var_name + ".x[j] = op::store("+ var +", " + var_name + ");\n"; + code += "vec_" + var_name + ".scratch_.separate[j] = op::store("+ var +", " + var_name + ");\n"; ++counter; } @@ -595,86 +569,7 @@ std::string FusedOp::GenerateCode(const std::vector &req, CUfunction FusedOp::CompileCode(const std::string &code, const std::string &kernel_name, int dev_id) { - // Guard NVRTC calls - std::lock_guard lock_nvrtc(mutex_); - // Local class for value type of compile cache - struct KernelInfo { - std::string mangled_name; - std::string ptx; - std::vector functions; - }; - // Maps from the cuda source code (minus header) to the ptx and jit-compiled CUfunctions. - using KernelCache = std::map; - // Per-gpu-architecture compiled kernel cache with jit-compiled function for each device context - static std::map compiled_kernels; - int sm_arch = SMArch(dev_id); - KernelCache& compiled_kernels_this_arch = compiled_kernels[sm_arch]; // make null map as needed - KernelInfo& kinfo = compiled_kernels_this_arch[code]; // make KernelInfo as needed - if (kinfo.ptx.size() == 0) { - // It's the first time we've seen this kernel, so we need to generate the ptx and mangled_name. 
- static std::string common_header = - std::string(fusion::fp16_support_string) + "\n" + - fusion::type_support_string + "\n" + - fusion::function_definitions + "\n" + - fusion::backward_function_definitions + "\n"; - std::string code_with_header = common_header + code; - // If verbose mode, output kernel source, though not including the common header - if (dmlc::GetEnv("MXNET_FUSION_VERBOSE", false)) { - LOG(INFO) << "\n" << std::string(80, '-') << "\n" << code; - } - if (compiled_kernels_this_arch.size() == CACHESIZE_WARN_THRESHOLD + 1 && - dmlc::GetEnv("MXNET_FUSION_SIZE_WARNING", true)) { - LOG(WARNING) << "The number of different fused ops exceeds " << CACHESIZE_WARN_THRESHOLD - << ". Set MXNET_FUSION_SIZE_WARNING=0 to quiet this warning."; - } - nvrtcProgram program; - NVRTC_CALL(nvrtcCreateProgram(&program, // prog - &code_with_header[0], // buffer - (kernel_name + "_kernel.cu").c_str(), // name - 0, // num headers - nullptr, // headers - nullptr)); // include names - - std::string gpu_arch_arg = "--gpu-architecture=compute_" + std::to_string(sm_arch); - const char *opts[] = {gpu_arch_arg.c_str(), - "--std=c++14"}; - const std::string kernel_name_demangled = "FusedKernel_" + kernel_name; - NVRTC_CALL(nvrtcAddNameExpression(program, (kernel_name_demangled).c_str())); - - nvrtcResult compileResult = nvrtcCompileProgram(program, // prog - 2, // num options - opts); // options - CHECK_EQ(compileResult, NVRTC_SUCCESS) - << "NVRTC Compilation failed. Please set environment variable MXNET_USE_FUSION to 0.\n" - << GetCompileLog(program); - - kinfo.ptx = GetPtx(program); - const char *mangled_name; - NVRTC_CALL(nvrtcGetLoweredName(program, - kernel_name_demangled.c_str(), - &mangled_name)); - kinfo.mangled_name = mangled_name; - // Destroy the program. 
- NVRTC_CALL(nvrtcDestroyProgram(&program)); - } - // Ensure function array is deep enough to index by dev_id - while (kinfo.functions.size() <= static_cast(dev_id)) - kinfo.functions.push_back(static_cast(nullptr)); - // Jit-compile ptx for the device as needed - if (kinfo.functions[dev_id] == static_cast(nullptr)) { - // Make sure driver context is set to the proper device - CUdevice cu_device; - CUcontext context; - CUDA_DRIVER_CALL(cuDeviceGet(&cu_device, dev_id)); - CUDA_DRIVER_CALL(cuDevicePrimaryCtxRetain(&context, cu_device)); - // Jit-compile ptx for the driver's current context - CUmodule module; - CUDA_DRIVER_CALL(cuModuleLoadData(&module, kinfo.ptx.c_str())); - CUDA_DRIVER_CALL(cuModuleGetFunction(&kinfo.functions[dev_id], - module, - kinfo.mangled_name.c_str())); - } - return kinfo.functions[dev_id]; + return common::cuda::rtc::get_function(code, "FusedKernel_" + kernel_name, "", dev_id); } @@ -779,8 +674,7 @@ void FusedOp::Forward(const nnvm::NodeAttrs& attrs, << ", not expecting switch to device " << dev_id; Stream* s = ctx.get_stream(); - auto stream = Stream::GetStream(s); - std::vector args; + std::vector args; size_t N = 0; for (const auto& output : outputs) { N = std::max(N, output.shape_.Size()); @@ -819,12 +713,10 @@ void FusedOp::Forward(const nnvm::NodeAttrs& attrs, } } } - CUDA_DRIVER_CALL( - cuLaunchKernel(kernel_functions_[kernel_variant], - num_blocks, 1, 1, // grid dim - FusedOp::NTHREADS, 1, 1, // block dim - 0, stream, // shared mem and stream - &(args[0]), 0)); // arguments + common::cuda::rtc::launch(kernel_functions_[kernel_variant], + {num_blocks, 1, 1}, + {FusedOp::NTHREADS, 1, 1}, + 0, s, &args); } void FusedOpForwardGPU(const nnvm::NodeAttrs& attrs, @@ -840,5 +732,3 @@ NNVM_REGISTER_OP(_FusedOp) .set_attr("FCompute", FusedOpForwardGPU); } // namespace mxnet - -#endif // MXNET_ENABLE_CUDA_RTC diff --git a/src/operator/fusion/fused_op.h b/src/operator/fusion/fused_op.h index 3a1db4e2a369..bb13309bc9cb 100644 --- a/src/operator/fusion/fused_op.h +++ b/src/operator/fusion/fused_op.h @@ -28,7 +28,7 @@ #include #include -#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#if MXNET_USE_CUDA namespace mxnet { @@ -58,7 +58,6 @@ struct FusedOpEntry { class FusedOp { public: static const int NTHREADS = 512; - static const int CACHESIZE_WARN_THRESHOLD = 10000; explicit FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config); ~FusedOp() {} @@ -201,6 +200,6 @@ using FusedOpHelperParamPtr = std::shared_ptr; } // namespace mxnet -#endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC +#endif // MXNET_USE_CUDA #endif // MXNET_OPERATOR_FUSION_FUSED_OP_H_ diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h index 3d81cfc0d967..945bd00f74c1 100644 --- a/src/operator/leaky_relu-inl.h +++ b/src/operator/leaky_relu-inl.h @@ -254,13 +254,21 @@ class LeakyReLUOp : public Operator { &new_rshape, &new_oshape) != 0; if (!need_bc) { +#if !defined(__CUDACC__) ElemwiseBinaryOp::BackwardUseIn( nnvm::NodeAttrs(), ctx, {out_grad[leakyrelu::kOut], in_data[leakyrelu::kData], in_data[leakyrelu::kGamma]}, req, in_grad); +#else + ElemwiseBinaryRTCBwdUseIn {"xelu_grad", "prelu_grad"}( + nnvm::NodeAttrs(), ctx, {out_grad[leakyrelu::kOut], + in_data[leakyrelu::kData], + in_data[leakyrelu::kGamma]}, req, in_grad); +#endif // !defined(__CUDACC__) } else { +#if !defined(__CUDACC__) BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, { BinaryBroadcastBackwardUseInImpl( @@ -269,6 +277,16 @@ class LeakyReLUOp : public Operator { in_data[leakyrelu::kGamma]}, req, in_grad, new_lshape, 
new_rshape, new_oshape); }); +#else + std::vector new_in_grad(2); + new_in_grad[leakyrelu::kData] = in_grad[leakyrelu::kData]; + new_in_grad[leakyrelu::kGamma] = in_grad[leakyrelu::kGamma].reshape(gshape); + BinaryBroadcastRTCBackwardUseIn {"xelu_grad", "prelu_grad"}( + nnvm::NodeAttrs(), ctx, {out_grad[leakyrelu::kOut], + in_data[leakyrelu::kData], + in_data[leakyrelu::kGamma]}, + req, new_in_grad); +#endif // !defined(__CUDACC__) } break; } diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h index 6d94f33bc700..83a75ccf19cd 100644 --- a/src/operator/linalg_impl.h +++ b/src/operator/linalg_impl.h @@ -29,7 +29,7 @@ #include -#include "../common/cuda_utils.h" +#include "../common/cuda/utils.h" #include "mxnet_op.h" // Convenience functions. diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 5d20a9ffe9bf..ccc39ab1d8bf 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -746,22 +746,8 @@ MXNET_BINARY_MATH_OP_NC(negone, -1); MXNET_BINARY_MATH_OP(div_grad, 1.0f / math::id(b)); -template<> -MSHADOW_XINLINE mshadow::half::half2_t div_grad::Map - (mshadow::half::half2_t a, - mshadow::half::half2_t b) { - return mshadow::half::half2_t(1) / b; -} - MXNET_BINARY_MATH_OP(div_rgrad, -math::id(a) / math::sqr(b)); -template<> -MSHADOW_XINLINE mshadow::half::half2_t div_rgrad::Map - (mshadow::half::half2_t a, - mshadow::half::half2_t b) { - return -a / (b * b); -} - MXNET_BINARY_MATH_OP(rdiv, math::id(b) / math::id(a)); MXNET_BINARY_MATH_OP(rdiv_grad, -math::id(b) / math::sqr(a)); @@ -774,8 +760,6 @@ MXNET_BINARY_MATH_OP(copysign_rgrad, 0); MXNET_BINARY_MATH_OP(rcopysign, (b >= 0 && a >= 0) || (b < 0 && a < 0) ? b : -b); -MXNET_BINARY_MATH_OP(rcopysign_grad, 0); - struct mod : public mxnet_op::tunable { template MSHADOW_XINLINE static typename enable_if::value, DType>::type @@ -879,13 +863,6 @@ struct rfmod : public mxnet_op::tunable { } }; -template<> -MSHADOW_XINLINE mshadow::half::half2_t mod::Map - (mshadow::half::half2_t a, - mshadow::half::half2_t b) { - return a%b; -} - struct mod_grad : public mxnet_op::tunable { template MSHADOW_XINLINE static DType Map(DType a, DType b) { @@ -907,19 +884,6 @@ MSHADOW_XINLINE mshadow::half::half_t mod_grad::Map mshadow::half::half_t b) { return mshadow::half::half_t(1.0f); } -template<> -MSHADOW_XINLINE mshadow::half::half2_t mod_grad::Map - (mshadow::half::half2_t a, - mshadow::half::half2_t b) { - mshadow::half::half2_t result = mshadow::half::half2_t(); -#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2) - result.half2_ = ::__float2half2_rn(1.0f); -#else - result.half_t2[0] = mshadow::half::half_t(0.0f); - result.half_t2[1] = mshadow::half::half_t(1.0f); -#endif - return result; -} struct mod_rgrad : public mxnet_op::tunable { template @@ -942,19 +906,6 @@ MSHADOW_XINLINE mshadow::half::half_t mod_rgrad::Map mshadow::half::half_t b) { return mshadow::half::half_t(-::floorf(static_cast(a/b))); } -template<> -MSHADOW_XINLINE mshadow::half::half2_t mod_rgrad::Map - (mshadow::half::half2_t a, - mshadow::half::half2_t b) { -#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2) - return mshadow::half::half2_t(__hneg2(::h2floor((a/b).half2_))); -#else - return mshadow::half::half2_t(mshadow::half::half_t(-::floorf( - static_cast(a.half_t2[0]/b.half_t2[0]))), - mshadow::half::half_t(-::floorf( - static_cast(a.half_t2[1]/b.half_t2[1])))); -#endif -} struct rmod : public mxnet_op::tunable { template @@ -991,13 +942,6 @@ struct rmod : public mxnet_op::tunable { } }; -template<> -MSHADOW_XINLINE mshadow::half::half2_t 
rmod::Map - (mshadow::half::half2_t a, - mshadow::half::half2_t b) { - return b%a; -} - struct rmod_grad { template MSHADOW_XINLINE static DType Map(DType a, DType b) { @@ -1019,19 +963,6 @@ MSHADOW_XINLINE mshadow::half::half_t rmod_grad::Map mshadow::half::half_t b) { return mshadow::half::half_t(-::floorf(static_cast(b/a))); } -template<> -MSHADOW_XINLINE mshadow::half::half2_t rmod_grad::Map - (mshadow::half::half2_t a, - mshadow::half::half2_t b) { -#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2) - return mshadow::half::half2_t(::__hneg2(::h2floor((b/a).half2_))); -#else - return mshadow::half::half2_t(mshadow::half::half_t(-::floorf( - static_cast(b.half_t2[0]/a.half_t2[0]))), - mshadow::half::half_t(-::floorf( - static_cast(b.half_t2[1]/a.half_t2[1])))); -#endif -} struct clip : public mxnet_op::tunable { template diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index 8b7a38be3986..81cb4493798d 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -35,7 +35,7 @@ #include "../engine/openmp.h" #ifdef __CUDACC__ -#include "../common/cuda_utils.h" +#include "../common/cuda/utils.h" #endif // __CUDACC__ namespace mxnet { diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 72e4a76a26d4..894bfcd489f3 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -27,7 +27,7 @@ #include #include #include "batch_norm-inl.h" -#include "../../common/cuda_utils.h" +#include "../../common/cuda/utils.h" #define WRITE_DATA_FLAG 1 @@ -44,7 +44,6 @@ #include "./cudnn/cudnn_batch_norm-inl.h" #endif -#include "../../common/cuda_utils.h" #include "../../../include/mxnet/tensor_blob.h" using namespace mxnet; diff --git a/src/operator/nn/cudnn/cudnn_activation-inl.h b/src/operator/nn/cudnn/cudnn_activation-inl.h index 186274b2f1e1..5ad0da3d5dea 100644 --- a/src/operator/nn/cudnn/cudnn_activation-inl.h +++ b/src/operator/nn/cudnn/cudnn_activation-inl.h @@ -29,7 +29,7 @@ #include #include #include "../activation-inl.h" -#include "../../../common/cuda_utils.h" +#include "../../../common/cuda/utils.h" namespace mxnet { namespace op { diff --git a/src/operator/nn/cudnn/cudnn_algoreg-inl.h b/src/operator/nn/cudnn/cudnn_algoreg-inl.h index f7e01e214719..00939cfd8679 100644 --- a/src/operator/nn/cudnn/cudnn_algoreg-inl.h +++ b/src/operator/nn/cudnn/cudnn_algoreg-inl.h @@ -32,7 +32,7 @@ #include #include #include -#include "../../../common/cuda_utils.h" +#include "../../../common/cuda/utils.h" #include "../convolution-inl.h" #include "../deconvolution-inl.h" namespace mxnet { diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index c5beb8a9c575..4ab3dfa6ed9f 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -33,7 +33,7 @@ #include #include "../convolution-inl.h" #include "./cudnn_algoreg-inl.h" -#include "../../../common/cuda_utils.h" +#include "../../../common/cuda/utils.h" namespace mxnet { namespace op { diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index 4f025113a45e..5c4b11f0148a 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ -33,7 +33,7 @@ #include #include "../deconvolution-inl.h" #include "./cudnn_algoreg-inl.h" -#include "../../../common/cuda_utils.h" +#include "../../../common/cuda/utils.h" namespace mxnet { namespace op { diff --git 
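
The leaky_relu change above introduces the dispatch idiom that recurs through the rest of the patch: headers compiled by both the host compiler and nvcc pick the templated CPU path or the string-named RTC path via __CUDACC__. A stripped-down, self-contained illustration of the idiom follows; the function names are placeholders, not MXNet APIs.

  #include <cstdio>

  void backward_templated_cpu() { std::printf("compile-time templated kernel\n"); }
  void backward_rtc_gpu()       { std::printf("NVRTC-compiled kernel named by a string\n"); }

  void Backward() {
  #if !defined(__CUDACC__)
    backward_templated_cpu();   // translation unit built by the host compiler (.cc)
  #else
    backward_rtc_gpu();         // translation unit built by nvcc (.cu)
  #endif
  }

  int main() {
    Backward();
    return 0;
  }

group_norm-inl.h below applies the same pattern to its BinaryBroadcastRTCCompute and ElemwiseBinaryRTCCompute calls.
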
a/src/operator/nn/depthwise_convolution-inl.h b/src/operator/nn/depthwise_convolution-inl.h index 9db2650491a8..cd2cbd1c0788 100644 --- a/src/operator/nn/depthwise_convolution-inl.h +++ b/src/operator/nn/depthwise_convolution-inl.h @@ -27,7 +27,7 @@ #include #include #include "./convolution-inl.h" -#include "../../common/cuda_utils.h" +#include "../../common/cuda/utils.h" #if MXNET_USE_CUDA #include diff --git a/src/operator/nn/depthwise_convolution_tf.cuh b/src/operator/nn/depthwise_convolution_tf.cuh index e59d8986b895..bb91ea9fb050 100644 --- a/src/operator/nn/depthwise_convolution_tf.cuh +++ b/src/operator/nn/depthwise_convolution_tf.cuh @@ -26,7 +26,7 @@ */ #ifndef MXNET_OPERATOR_NN_DEPTHWISE_CONVOLUTION_TF_CUH_ #define MXNET_OPERATOR_NN_DEPTHWISE_CONVOLUTION_TF_CUH_ -#include "../../common/cuda_utils.h" +#include "../../common/cuda/utils.h" #include "../mxnet_op.h" namespace tf { diff --git a/src/operator/nn/group_norm-inl.h b/src/operator/nn/group_norm-inl.h index 143e2168d113..da30192231c7 100644 --- a/src/operator/nn/group_norm-inl.h +++ b/src/operator/nn/group_norm-inl.h @@ -115,10 +115,9 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs, size_t workspace_size = 0; MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, { - BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - workspace_size = - broadcast::ReduceWorkspaceSize(s, red_dst_shape, req[0], red_src_shape); - }); + workspace_size = + broadcast::ReduceWorkspaceSize(s, red_dst_shape, req[0], + red_src_shape, sizeof(DType)); }); workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); @@ -139,9 +138,15 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs, const TBlob& output_grp = outputs[groupnorm::kOut].reshape(temp_data_shape); // Calculate data = data - mean +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {data_grp, mean_grp}, {kWriteTo}, {output_grp}); +#else + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {data_grp, mean_grp}, + {kWriteTo}, {output_grp}); +#endif // !defined(__CUDACC__) // Calculate std const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape); @@ -156,9 +161,15 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs, }); // Calculate data = data / std +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {output_grp, std_grp}, {kWriteTo}, {output_grp}); +#else + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {output_grp, std_grp}, + {kWriteTo}, {output_grp}); +#endif // !defined(__CUDACC__) const TBlob& output = outputs[groupnorm::kOut]; mxnet::TShape new_param_shape(data_shape.ndim(), 1); @@ -167,6 +178,7 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs, const TBlob& gamma = inputs[groupnorm::kGamma].reshape(new_param_shape); const TBlob& beta = inputs[groupnorm::kBeta].reshape(new_param_shape); +#if !defined(__CUDACC__) // Calculate data = data * gamma BinaryBroadcastCompute(attrs, ctx, {output, gamma}, @@ -175,6 +187,16 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs, BinaryBroadcastCompute(attrs, ctx, {output, beta}, {kWriteTo}, {output}); +#else + // Calculate data = data * gamma + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, + {output, gamma}, + {kWriteTo}, {output}); + // Calculate data = data + beta + BinaryBroadcastRTCCompute {"add"}(attrs, ctx, + {output, beta}, + {kWriteTo}, {output}); +#endif // !defined(__CUDACC__) } /* @@ -250,18 +272,16 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, // There are two types of reduction workloads: reduce over axis and reduce exclude axis // We take the maximum of the workspace 
sizes required by these workloads. // Also, we explicitly set the req_type=kAddto in case we want to use it. - BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - reduce_workspace_size = - std::max(reduce_workspace_size, - broadcast::ReduceWorkspaceSize(s, red_dst_shape, - kAddTo, red_src_shape)); - }); - BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { - reduce_workspace_size = - std::max(reduce_workspace_size, - broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo, - red_exclude_src_shape)); - }); + reduce_workspace_size = + std::max(reduce_workspace_size, + broadcast::ReduceWorkspaceSize(s, red_dst_shape, + kAddTo, red_src_shape, + sizeof(DType))); + reduce_workspace_size = + std::max(reduce_workspace_size, + broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo, + red_exclude_src_shape, + sizeof(DType))); }); workspace = ctx.requested[0].get_space_typed( Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s); @@ -273,12 +293,21 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2, mean_.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id()); // Compute normalized_data = (data - mean) / std +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {data_, mean_}, {kWriteTo}, {normalized_data}); BinaryBroadcastCompute(attrs, ctx, {normalized_data, std_}, {kWriteTo}, {normalized_data}); +#else + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {data_, mean_}, + {kWriteTo}, {normalized_data}); + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {normalized_data, std_}, + {kWriteTo}, {normalized_data}); +#endif // !defined(__CUDACC__) // Calculate grad_beta if (req[2] != kNullOp) { MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, { @@ -290,8 +319,13 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, }); } // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) +#if !defined(__CUDACC__) ElemwiseBinaryOp::Compute(attrs, ctx, {normalized_data, ograd}, {kWriteTo}, {ograd_mult}); +#else + ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd}, + {kWriteTo}, {ograd_mult}); +#endif // !defined(__CUDACC__) if (req[1] != kNullOp) { MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { @@ -308,12 +342,23 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, // + normalized_data * (-mean(normalized_data * ograd_mult, axis)) if (req[0] != kNullOp) { const TBlob output_ = outputs[0].reshape(data_.shape_); +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {inputs[0], gamma}, - {kWriteTo}, {ograd_mult.reshape(data.shape_)}); + {kWriteTo}, + {ograd_mult.reshape(data.shape_)}); BinaryBroadcastCompute(attrs, ctx, {ograd_mult, std_}, {kWriteTo}, {ograd_mult}); +#else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, + {inputs[0], gamma}, + {kWriteTo}, + {ograd_mult.reshape(data.shape_)}); + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {ograd_mult, std_}, + {kWriteTo}, {ograd_mult}); +#endif // !defined(__CUDACC__) MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { broadcast::Reduce( @@ -323,11 +368,19 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, Tensor red_out_tensor = red_out.FlatTo1D(s); red_out_tensor /= scalar(N); }); +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {ograd_mult, red_out}, {req[0]}, {output_}); ElemwiseBinaryOp::Compute(attrs, ctx, {ograd_mult, 
normalized_data}, {kWriteTo}, {ograd_mult}); +#else + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {ograd_mult, red_out}, + {req[0]}, {output_}); + ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data}, + {kWriteTo}, {ograd_mult}); +#endif // !defined(__CUDACC__) MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { broadcast::Reduce( @@ -337,9 +390,15 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, Tensor red_out_tensor = red_out.FlatTo1D(s); red_out_tensor /= scalar(-N); }); +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {normalized_data, red_out}, {kAddTo}, {output_}); +#else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, + {normalized_data, red_out}, + {kAddTo}, {output_}); +#endif // !defined(__CUDACC__) } } diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h index 8dcaeb3d6510..b440e3d96952 100644 --- a/src/operator/nn/layer_norm-inl.h +++ b/src/operator/nn/layer_norm-inl.h @@ -108,10 +108,9 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, Tensor workspace; size_t workspace_size = 0; MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - workspace_size = - broadcast::ReduceWorkspaceSize(s, mean_data.shape_, req[0], in_data.shape_); - }); + workspace_size = + broadcast::ReduceWorkspaceSize(s, mean_data.shape_, req[0], + in_data.shape_, sizeof(DType)); }); workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); @@ -137,9 +136,15 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, }); }); // Calculate data = data - mean +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {inputs[0], outputs[layernorm::kMean]}, {kWriteTo}, {outputs[0]}); +#else + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {inputs[0], outputs[layernorm::kMean]}, + {kWriteTo}, {outputs[0]}); +#endif // !defined(__CUDACC__) // Calculate std const TBlob centered_out = outputs[0].reshape(red_src_shape); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { @@ -156,6 +161,7 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, + scalar(param.eps)); }); }); +#if !defined(__CUDACC__) // Calculate data = data / std BinaryBroadcastCompute(attrs, ctx, {outputs[0], outputs[layernorm::kStd]}, @@ -168,6 +174,20 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, BinaryBroadcastCompute(attrs, ctx, {outputs[0], beta}, {kWriteTo}, {outputs[0]}); +#else + // Calculate data = data / std + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {outputs[0], outputs[layernorm::kStd]}, + {kWriteTo}, {outputs[0]}); + // Calculate data = data * gamma + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, + {outputs[0], gamma}, + {kWriteTo}, {outputs[0]}); + // Calculate data = data + beta + BinaryBroadcastRTCCompute {"add"}(attrs, ctx, + {outputs[0], beta}, + {kWriteTo}, {outputs[0]}); +#endif // !defined(__CUDACC__) } template @@ -230,18 +250,16 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, // There are two types of reduction workloads: reduce over axis and reduce exclude axis // We take the maximum of the workspace sizes required by these workloads. // Also, we explicitly set the req_type=kAddto in case we want to use it. 
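The hunk above and the layer_norm-inl.h hunk that follows apply the same change: BROADCAST_NDIM_SWITCH-wrapped, template-specialized workspace queries are replaced by a single runtime call that receives the element size (sizeof(DType)), and one scratch buffer is sized as the maximum of the reduce-over-axis and reduce-exclude-axis requirements. The sketch below is a minimal, self-contained illustration of that sizing pattern only; workspace_bytes_for and the shapes are hypothetical stand-ins, not the broadcast::ReduceWorkspaceSize API.

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Hypothetical stand-in for a reduce-workspace query: the element size is a
    // runtime argument instead of a template parameter, so a single non-templated
    // call site can serve every dtype (the sizeof(DType) arguments in the hunks).
    std::size_t workspace_bytes_for(const std::vector<int>& dst_shape,
                                    const std::vector<int>& src_shape,
                                    std::size_t element_size) {
      std::size_t n = 1;
      for (int d : src_shape) n *= static_cast<std::size_t>(d);
      (void)dst_shape;          // a real query would inspect both shapes
      return n * element_size;  // crude upper bound, enough for the sketch
    }

    int main() {
      const std::vector<int> red_dst{8, 1},    red_src{8, 128};   // reduce over axis
      const std::vector<int> excl_dst{1, 128}, excl_src{8, 128};  // reduce exclude axis

      // One scratch buffer serves both reduction workloads, so its size is the
      // maximum of the two requirements, mirroring the std::max(...) calls above.
      std::size_t ws = 0;
      ws = std::max(ws, workspace_bytes_for(red_dst, red_src, sizeof(float)));
      ws = std::max(ws, workspace_bytes_for(excl_dst, excl_src, sizeof(float)));
      std::printf("allocate %zu workspace bytes once, reuse for both reductions\n", ws);
      return 0;
    }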
- BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - reduce_workspace_size = - std::max(reduce_workspace_size, - broadcast::ReduceWorkspaceSize(s, red_dst_shape, - kAddTo, red_src_shape)); - }); - BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { - reduce_workspace_size = - std::max(reduce_workspace_size, - broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo, - red_exclude_src_shape)); - }); + reduce_workspace_size = + std::max(reduce_workspace_size, + broadcast::ReduceWorkspaceSize(s, red_dst_shape, + kAddTo, red_src_shape, + sizeof(DType))); + reduce_workspace_size = + std::max(reduce_workspace_size, + broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo, + red_exclude_src_shape, + sizeof(DType))); }); workspace = ctx.requested[0].get_space_typed( Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s); @@ -252,12 +270,21 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2, mean.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id()); // Compute normalized_data = (data - mean) / std +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {data, mean}, {kWriteTo}, {normalized_data}); BinaryBroadcastCompute(attrs, ctx, {normalized_data, std}, {kWriteTo}, {normalized_data}); +#else + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {data, mean}, + {kWriteTo}, {normalized_data}); + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {normalized_data, std}, + {kWriteTo}, {normalized_data}); +#endif // !defined(__CUDACC__) // Calculate grad_beta bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true); if (req[2] != kNullOp) { @@ -276,8 +303,13 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, }); } // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) +#if !defined(__CUDACC__) ElemwiseBinaryOp::Compute(attrs, ctx, {normalized_data, ograd}, {kWriteTo}, {ograd_mult}); +#else + ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd}, + {kWriteTo}, {ograd_mult}); +#endif // !defined(__CUDACC__) if (req[1] != kNullOp) { MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { @@ -298,12 +330,21 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, // grad_data = ograd_mult - mean(ograd_mult, axis) // + normalized_data * (-mean(normalized_data * ograd_mult, axis)) if (req[0] != kNullOp) { +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {ograd, gamma}, {kWriteTo}, {ograd_mult}); BinaryBroadcastCompute(attrs, ctx, {ograd_mult, std}, {kWriteTo}, {ograd_mult}); +#else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, + {ograd, gamma}, + {kWriteTo}, {ograd_mult}); + BinaryBroadcastRTCCompute {"div"}(attrs, ctx, + {ograd_mult, std}, + {kWriteTo}, {ograd_mult}); +#endif // !defined(__CUDACC__) MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { if (safe_acc) { @@ -319,11 +360,19 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, Tensor red_out_tensor = red_out.FlatTo1D(s); red_out_tensor /= scalar(channel_size); }); +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {ograd_mult, red_out}, {req[0]}, {outputs[0]}); ElemwiseBinaryOp::Compute(attrs, ctx, {ograd_mult, normalized_data}, {kWriteTo}, {ograd_mult}); +#else + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, + {ograd_mult, red_out}, + {req[0]}, {outputs[0]}); + ElemwiseBinaryRTCCompute 
{"mul"}(attrs, ctx, {ograd_mult, normalized_data}, + {kWriteTo}, {ograd_mult}); +#endif // !defined(__CUDACC__) MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { if (safe_acc) { @@ -339,9 +388,15 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, Tensor red_out_tensor = red_out.FlatTo1D(s); red_out_tensor /= scalar(- channel_size); }); +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, {normalized_data, red_out}, {kAddTo}, {outputs[0]}); +#else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, + {normalized_data, red_out}, + {kAddTo}, {outputs[0]}); +#endif // !defined(__CUDACC__) } } diff --git a/src/operator/nn/pool.cuh b/src/operator/nn/pool.cuh index e771b3681573..92d4e43d51ea 100644 --- a/src/operator/nn/pool.cuh +++ b/src/operator/nn/pool.cuh @@ -83,7 +83,7 @@ #include "./pool_utils.h" #include "../mxnet_op.h" #include "../mshadow_op.h" -#include "../../common/cuda_utils.h" +#include "../../common/cuda/utils.h" namespace mxnet { namespace op { diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index e34cc263183e..7806eaf90811 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -34,7 +34,7 @@ #include "../mxnet_op.h" #include "../operator_common.h" #include "../tensor/broadcast_reduce_op.h" -#include "../../common/cuda_utils.h" +#include "../../common/cuda/utils.h" namespace mxnet { namespace op { diff --git a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh index 357ce6cd31d5..d4374edc9828 100644 --- a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh +++ b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh @@ -285,7 +285,7 @@ __global__ void reduce_kernel_M1_wr(const int N, const bool addto, template void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const OpReqType req, const TBlob& big, const Tensor& workspace, - const ReduceImplConfig& config, + const ReduceImplConfig& config, Reducer* reducer = nullptr) { bool need_clean = !reducer; reducer = reducer ? reducer : new Reducer(); @@ -310,13 +310,13 @@ void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const OpReqT const int by = (config.kernel_1.do_transpose) ? 
config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; - const bool do_unroll = ( config.M / (by*config.Mnext) >= config.unroll_reduce ); - KERNEL_UNROLL_SWITCH(do_unroll, ReduceImplConfig::unroll_reduce, UNROLL, { + const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce ); + KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, { reduce_kernel_wr <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( config.N, config.M, addto, big.dptr(), small_dptr, big.shape_.get(), - small.shape_.get(), config.rshape, config.rstride, config.Mnext, - config.kernel_1.do_transpose, reducer); + small.shape_.get(), config.rshape.get(), config.rstride.get(), + config.Mnext, config.kernel_1.do_transpose, reducer); }); MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_wr); @@ -335,7 +335,7 @@ void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const OpReqT template void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const TBlob& rhs, const OpReqType req, const TBlob& big, const Tensor& workspace, - const ReduceImplConfig& config, Reducer* reducer = nullptr) { + const ReduceImplConfig& config, Reducer* reducer = nullptr) { bool need_clean = !reducer; reducer = reducer ? reducer : new Reducer(); if (config.M == 1) { @@ -360,8 +360,8 @@ void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const TBlob& const int by = (config.kernel_1.do_transpose) ? config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; - const bool do_unroll = ( config.M / (by*config.Mnext) >= config.unroll_reduce ); - KERNEL_UNROLL_SWITCH(do_unroll, ReduceImplConfig::unroll_reduce, UNROLL, { + const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce ); + KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, { reduce_kernel_wr <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( config.N, config.M, addto, big.dptr(), lhs.dptr(), rhs.dptr(), @@ -393,14 +393,13 @@ void ReduceWithReducer(Stream *s, const TBlob& small, const OpReqType req, cudaStream_t stream = Stream::GetStream(s); bool need_clean = !reducer; reducer = reducer ? 
reducer : new Reducer(); - ReduceImplConfig config = - ConfigureReduceImpl(small.shape_, big.shape_, nullptr, nullptr); + ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType)); if (safe_acc) { MXNET_ACC_TYPE_SWITCH(mshadow::DataType::kFlag, DataType, AType, { typedef typename std::conditional::type AccType; MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { typedef typename std::conditional::type OutType; - config = ConfigureReduceImpl(small.shape_, big.shape_, nullptr, nullptr); + config = ReduceImplConfig(small.shape_, big.shape_, nullptr, nullptr, sizeof(AccType)); ReduceImplWithReducer( stream, small, req, big, workspace, config, reducer); }); diff --git a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h index 2b5970d4f4ae..0226df45f960 100644 --- a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h +++ b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h @@ -31,6 +31,10 @@ namespace mxnet { namespace op { namespace broadcast { using namespace mshadow; +using mxnet_op::unravel; +using mxnet_op::ravel; +using mxnet_op::dot; +using mxnet_op::unravel_dot; template MSHADOW_XINLINE void seq_reduce_assign_wr(const index_t idx, const size_t M, const bool addto, diff --git a/src/operator/numpy/linalg/broadcast_reduce_op_customized.h b/src/operator/numpy/linalg/broadcast_reduce_op_customized.h index 25f66d04f663..8e1c0b3db18d 100644 --- a/src/operator/numpy/linalg/broadcast_reduce_op_customized.h +++ b/src/operator/numpy/linalg/broadcast_reduce_op_customized.h @@ -51,8 +51,8 @@ void ReduceAxesComputeImplWithReducer(const OpContext& ctx, const TBlob in_data = inputs[0].reshape(src_shape); const TBlob out_data = outputs[0].reshape(dst_shape); BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - size_t workspace_size = broadcast::ReduceWorkspaceSize( - s, out_data.shape_, req[0], in_data.shape_); + size_t workspace_size = broadcast::ReduceWorkspaceSize( + s, out_data.shape_, req[0], in_data.shape_, sizeof(OType)); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); broadcast::ReduceWithReducer( diff --git a/src/operator/numpy/linalg/np_matrix_rank-inl.h b/src/operator/numpy/linalg/np_matrix_rank-inl.h index 8ccecb57db11..9dba245aff0b 100644 --- a/src/operator/numpy/linalg/np_matrix_rank-inl.h +++ b/src/operator/numpy/linalg/np_matrix_rank-inl.h @@ -359,6 +359,14 @@ void MatrixRankNoneTolForward(const nnvm::NodeAttrs& attrs, MatrixRankNoneTolForwardImpl(a, rank, attrs, ctx, req); } +// Windows has issues with #ifdefs inside MSHADOW_TYPE_SWITCH +#ifndef __CUDACC__ +#define NP_LINALG_MATRIX_RANK_BROADCAST(OP, RTCOP) \ + mxnet::op::BinaryBroadcastCompute +#else +#define NP_LINALG_MATRIX_RANK_BROADCAST(OP, RTCOP) mxnet::op::BinaryBroadcastRTCCompute {#RTCOP} +#endif + template void MatrixRankForwardImpl(const TBlob& a, const TBlob& tol, @@ -410,9 +418,9 @@ void MatrixRankForwardImpl(const TBlob& a, if (new_tol_data.dptr() != tol.dptr()) { Copy(new_tol_data.FlatTo1D(s), tol.FlatTo1D(s), s); } - mxnet::op::BinaryBroadcastCompute(attrs, ctx, - {s_data, new_tol_data}, - {kWriteTo}, {broadcast_data}); + NP_LINALG_MATRIX_RANK_BROADCAST(gt, greater)(attrs, ctx, + {s_data, new_tol_data}, + {kWriteTo}, {broadcast_data}); // Step5: Calculate rank. 
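The NP_LINALG_MATRIX_RANK_BROADCAST helper above (and the analogous NP_LINALG_PINV_BROADCAST helper further down) exists because preprocessor directives placed inside a macro invocation, here inside MSHADOW_TYPE_SWITCH, are not guaranteed to work and are what the "Windows has issues" comment refers to. The CPU/GPU choice is therefore hoisted into a macro whose definition differs under __CUDACC__ and which is #undef'd after use. Below is a minimal standalone sketch of the same trick; PICK_BROADCAST, cpu_impl and GpuRTC are invented names, not MXNet symbols.

    #include <cstdio>

    // Hypothetical stand-ins: a templated CPU kernel and an RTC-style GPU functor
    // that is constructed from the kernel's name as a string.
    template <typename OP>
    void cpu_impl(const char* tag) { std::printf("cpu path: %s\n", tag); }

    struct GpuRTC {
      const char* name;
      void operator()(const char* tag) const { std::printf("gpu path (%s): %s\n", name, tag); }
    };

    struct greater {};  // hypothetical op tag

    // Preprocessor directives inside a macro argument are not portable, so the
    // CPU/GPU branch lives in this macro's definition and the call site stays
    // branch-free, just like the NP_LINALG_*_BROADCAST helpers in the diff.
    #ifndef __CUDACC__
    #define PICK_BROADCAST(OP, RTCOP) cpu_impl<OP>
    #else
    #define PICK_BROADCAST(OP, RTCOP) GpuRTC{#RTCOP}
    #endif

    void step5_calculate_rank() {
      PICK_BROADCAST(greater, greater)("s > tol");  // same call site for both builds
    }

    #undef PICK_BROADCAST  // keep the helper local, as the diff does after use

    int main() {
      step5_calculate_rank();
      return 0;
    }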
const int b_ndim = broadcast_shape.ndim(); const int data_size = broadcast_data.size(b_ndim - 1); @@ -425,6 +433,8 @@ void MatrixRankForwardImpl(const TBlob& a, }); } +#undef NP_LINALG_MATRIX_RANK_BROADCAST + template void MatrixRankForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/numpy/linalg/np_pinv-inl.h b/src/operator/numpy/linalg/np_pinv-inl.h index b3b8e0c76c64..c6163617de68 100644 --- a/src/operator/numpy/linalg/np_pinv-inl.h +++ b/src/operator/numpy/linalg/np_pinv-inl.h @@ -464,6 +464,14 @@ inline mxnet::TShape GetTransAxis(const mxnet::TShape& in_shape) { return mxnet::TShape(trans_axis.begin(), trans_axis.end()); } +// Windows has issues with #ifdefs inside MSHADOW_TYPE_SWITCH +#ifndef __CUDACC__ +#define NP_LINALG_PINV_BROADCAST(OP, RTCOP) \ + mxnet::op::BinaryBroadcastCompute +#else +#define NP_LINALG_PINV_BROADCAST(OP, RTCOP) mxnet::op::BinaryBroadcastRTCCompute {#RTCOP} +#endif + template void PinvOpForwardImpl(const TBlob& a, const TBlob& rcond, @@ -553,13 +561,13 @@ void PinvOpForwardImpl(const TBlob& a, s, S.size(0), Smax.dptr_, S.dptr_, S.size(1), S.stride_); // Step3: Calculate Cutoff. std::vector temp_req({kWriteTo}); - mxnet::op::BinaryBroadcastCompute(attrs, ctx, - {rcond_data, smax_data}, - temp_req, {cutoff_data}); + NP_LINALG_PINV_BROADCAST(mul, mul)(attrs, ctx, + {rcond_data, smax_data}, + temp_req, {cutoff_data}); // Step4: Calculte Large. - mxnet::op::BinaryBroadcastCompute(attrs, ctx, - {s_data, cutoff_data}, - temp_req, {large_data}); + NP_LINALG_PINV_BROADCAST(gt, greater)(attrs, ctx, + {s_data, cutoff_data}, + temp_req, {large_data}); // Step5: Discard small singular values. mxnet_op::Kernel::Launch( s, s_data.Size(), s_data.dptr(), large_data.dptr()); @@ -573,8 +581,8 @@ void PinvOpForwardImpl(const TBlob& a, } s_data = s_data.reshape(s_shape_newaxis); u_data = ut_data.reshape(ut_shape); - mxnet::op::BinaryBroadcastCompute(attrs, ctx, {s_data, ut_data}, - temp_req, {u_data}); + NP_LINALG_PINV_BROADCAST(mul, mul)(attrs, ctx, {s_data, ut_data}, + temp_req, {u_data}); gemm2::op(vt_data.FlatToKD(s), u_data.FlatToKD(s), pinv_a.FlatToKD(s), @@ -712,8 +720,8 @@ void PinvScalarRcondOpForwardImpl(const TBlob& a, } s_data = s_data.reshape(s_shape_newaxis); u_data = ut_data.reshape(ut_shape); - mxnet::op::BinaryBroadcastCompute(attrs, ctx, {s_data, ut_data}, - {kWriteTo}, {u_data}); + NP_LINALG_PINV_BROADCAST(mul, mul)(attrs, ctx, {s_data, ut_data}, + {kWriteTo}, {u_data}); gemm2::op(vt_data.FlatToKD(s), u_data.FlatToKD(s), pinv_a.FlatToKD(s), @@ -722,6 +730,8 @@ void PinvScalarRcondOpForwardImpl(const TBlob& a, }); } +#undef NP_LINALG_PINV_BROADCAST + template void PinvScalarRcondOpForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index 6b59ac0d8621..3b505b788ae9 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -777,6 +777,13 @@ struct avg_grad_w_1D_kernel { } }; +// Windows has issues with #ifdefs inside MSHADOW_TYPE_SWITCH +#ifndef __CUDACC__ +#define NP_BROADCAST_REDUCE_OP_BROADCAST(OP) BinaryBroadcastCompute +#else +#define NP_BROADCAST_REDUCE_OP_BROADCAST(OP) BinaryBroadcastRTCCompute {#OP} +#endif + template void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -820,10 +827,8 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, TShape src_shape, dst_shape; BroadcastReduceShapeCompact(data.shape_, 
small1, &src_shape, &dst_shape); size_t workspace_size = 0; - MXNET_NDIM_SWITCH(dst_shape.ndim(), NDim, { - workspace_size = broadcast::ReduceWorkspaceSize( - s, dst_shape, {kWriteTo}, src_shape); - }); + workspace_size = broadcast::ReduceWorkspaceSize( + s, dst_shape, {kWriteTo}, src_shape, sizeof(DType)); size_t temp_mem_size = temp_data_size + temp_sum_size + workspace_size; Tensor temp_mem = ctx.requested[0].get_space_typed(Shape1(temp_mem_size), s); @@ -834,7 +839,7 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, // Compute weighted data TBlob wa = TBlob(temp_data_ptr, data.shape_, xpu::kDevMask); - BinaryBroadcastCompute( + NP_BROADCAST_REDUCE_OP_BROADCAST(mul)( attrs, ctx, {data, weights}, {kWriteTo}, {wa}); // Compute sum of weighted data @@ -852,7 +857,7 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, ctx, {weights}, {kWriteTo}, {scl}, workspace, w_src_shape, w_dst_shape); // Compute avg and assign output - BinaryBroadcastCompute( + NP_BROADCAST_REDUCE_OP_BROADCAST(div)( attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)}); } else { // Compute and assign the derivatives of a and weights @@ -897,6 +902,8 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs, }); } +#undef NP_BROADCAST_REDUCE_OP_BROADCAST + template void NumpyWeightedAverageForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -993,10 +1000,8 @@ void NumpyMomentsForward(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { // Get workspace and temp space for data - mean size_t workspace_size = 0; - BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - workspace_size = broadcast::ReduceWorkspaceSize( - s, dst_shape, req[0], src_shape);; - }); + workspace_size = broadcast::ReduceWorkspaceSize( + s, dst_shape, req[0], src_shape, sizeof(DType)); size_t temp_data_size = data.shape_.Size() * sizeof(DType); size_t temp_mem_size = temp_data_size + workspace_size; Tensor temp_mem = diff --git a/src/operator/numpy/np_cross-inl.h b/src/operator/numpy/np_cross-inl.h index c2092bbfec23..23a3a3326f5f 100644 --- a/src/operator/numpy/np_cross-inl.h +++ b/src/operator/numpy/np_cross-inl.h @@ -390,8 +390,13 @@ struct NumpyCrossForwardImpl { mxnet_op::Kernel::Launch(s, bw_data.Size(), b_data.dptr(), bw_data.dptr(), b_data.size(b_ndim - 1), b_index_vec[i], b_data.Size()); +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, { aw_data, bw_data }, { kWriteTo }, { cw_data_vec[idx] }); +#else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, { aw_data, bw_data }, + { kWriteTo }, { cw_data_vec[idx] }); +#endif // !defined(__CUDACC__) MXNET_ASSIGN_REQ_SWITCH(req_vec[i], req_type, { mxnet_op::Kernel, xpu>::Launch(s, cw_data_vec[idx].Size(), cw_data_vec[idx].dptr(), @@ -493,18 +498,30 @@ struct NumpyCrossForwardImpl { mxnet_op::Kernel::Launch(s, bw_data.Size(), b_data.dptr(), bw_data.dptr(), b_data.size(b_ndim - 1), 1, b_data.Size()); +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, { aw_data, bw_data }, { req[0] }, { c }); +#else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, { aw_data, bw_data }, + { req[0] }, { c }); +#endif // !defined(__CUDACC__) mxnet_op::Kernel::Launch(s, aw_data.Size(), a_data.dptr(), aw_data.dptr(), a_data.size(a_ndim - 1), 1, a_data.Size()); mxnet_op::Kernel::Launch(s, bw_data.Size(), b_data.dptr(), bw_data.dptr(), b_data.size(b_ndim - 1), 0, b_data.Size()); +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, { aw_data, bw_data }, { kWriteTo }, { cw_data }); BinaryBroadcastCompute(attrs, ctx, { 
c, cw_data }, { kWriteTo }, { c }); +#else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, { aw_data, bw_data }, + { kWriteTo }, { cw_data }); + BinaryBroadcastRTCCompute {"sub"}(attrs, ctx, { c, cw_data }, + { kWriteTo }, { c }); +#endif // !defined(__CUDACC__) } }; @@ -659,10 +676,9 @@ struct ReduceImplWrap { size_t ws_reduce = 0U; std::vector reduce_axis = GetReduceAxis(out_move_shape, in_move_shape); if (reduce_axis.empty() || req == kNullOp) { return 0U; } - SUM_NDIM_SWITCH(out_shape.ndim(), NDim, { - ws_reduce = broadcast::ReduceWorkspaceSize(ctx.get_stream(), - out_shape, req, in_shape); - }); + ws_reduce = broadcast::ReduceWorkspaceSize(ctx.get_stream(), + out_shape, req, in_shape, + sizeof(DType)); return ws_reduce; } @@ -1196,8 +1212,13 @@ struct NumpyCrossBackwardImpl { b_move_data.size(b_ndim - 1), 1, b_move_data.Size()); // cw_data = grad_c_move * b_move_data[..., 1]. +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, { grad_c, bw_data }, { kWriteTo }, { cw_data }); +#else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, { grad_c, bw_data }, + { kWriteTo }, { cw_data }); +#endif // !defined(__CUDACC__) // Copy cw_data to grad_move_data[..., 0]. mxnet_op::Kernel, xpu>::Launch(s, cw_data.Size(), cw_data.dptr(), @@ -1211,8 +1232,13 @@ struct NumpyCrossBackwardImpl { b_move_data.size(b_ndim - 1), 0, b_move_data.Size()); // cw_data = grad_c_move * b_move_data[..., 0]. +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, { grad_c, bw_data }, { kWriteTo }, { cw_data }); +#else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, { grad_c, bw_data }, + { kWriteTo }, { cw_data }); +#endif // !defined(__CUDACC__) // Copy -cw_data to grad_move_data[..., 1]. mxnet_op::Kernel, xpu>::Launch(s, cw_data.Size(), cw_data.dptr(), @@ -1257,8 +1283,13 @@ struct NumpyCrossBackwardImpl { a_move_data.size(a_ndim - 1), 1, a_move_data.Size()); // cw_data = grad_c_move * a_move_data[..., 1]. +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, { grad_c, aw_data }, { kWriteTo }, { cw_data }); +#else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, { grad_c, aw_data }, + { kWriteTo }, { cw_data }); +#endif // !defined(__CUDACC__) // Copy -cw_data to grad_move_data[..., 0]. mxnet_op::Kernel, xpu>::Launch(s, cw_data.Size(), cw_data.dptr(), @@ -1272,8 +1303,13 @@ struct NumpyCrossBackwardImpl { a_move_data.size(a_ndim - 1), 0, a_move_data.Size()); // cw_data = grad_c_move * a_move_data[..., 0]. +#if !defined(__CUDACC__) BinaryBroadcastCompute(attrs, ctx, { grad_c, aw_data }, { kWriteTo }, { cw_data }); +#else + BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, { grad_c, aw_data }, + { kWriteTo }, { cw_data }); +#endif // !defined(__CUDACC__) // Copy cw_data to grad_move_data[..., 1]. 
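The same pattern recurs throughout np_cross-inl.h and the normalization operators above: when the file is compiled by nvcc (__CUDACC__ defined), the templated broadcast kernel is swapped for an RTC functor that carries the operation as a string and generates the kernel at runtime, while the host-only build keeps the templated path. A self-contained sketch of that dispatch follows; TemplatedCompute and RTCCompute are hypothetical stand-ins rather than the MXNet signatures.

    #include <cstdio>
    #include <string>
    #include <vector>

    // Hypothetical compile-time-specialized compute (the templated CPU path).
    template <typename OP>
    void TemplatedCompute(const std::vector<float>& a, const std::vector<float>& b,
                          std::vector<float>* out) {
      for (std::size_t i = 0; i < out->size(); ++i) (*out)[i] = OP::Map(a[i], b[i]);
    }

    struct mul { static float Map(float a, float b) { return a * b; } };

    // Hypothetical runtime-compiled path: the operation travels as a string and
    // the kernel source would be generated and compiled on first use.
    struct RTCCompute {
      std::string op;
      void operator()(const std::vector<float>& a, const std::vector<float>& b,
                      std::vector<float>* out) const {
        std::printf("would JIT-compile and launch kernel '%s'\n", op.c_str());
        (void)a; (void)b; (void)out;
      }
    };

    void compute(const std::vector<float>& a, const std::vector<float>& b,
                 std::vector<float>* out) {
    #if !defined(__CUDACC__)
      TemplatedCompute<mul>(a, b, out);  // host-only build: templated kernel
    #else
      RTCCompute{"mul"}(a, b, out);      // nvcc build: runtime-compiled kernel
    #endif
    }

    int main() {
      std::vector<float> a{1, 2, 3}, b{4, 5, 6}, c(3);
      compute(a, b, &c);
      return 0;
    }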
mxnet_op::Kernel, xpu>::Launch(s, cw_data.Size(), cw_data.dptr(), diff --git a/src/operator/numpy/np_diff-inl.h b/src/operator/numpy/np_diff-inl.h index 8a8bc558962a..3d80e2d941c8 100644 --- a/src/operator/numpy/np_diff-inl.h +++ b/src/operator/numpy/np_diff-inl.h @@ -73,7 +73,7 @@ struct diff_forward { const int stride, const mshadow::Shape oshape, const mshadow::Shape ishape) { - using namespace broadcast; + using namespace mxnet_op; // j represent the memory index of the corresponding input entry int j = ravel(unravel(i, oshape), ishape); @@ -145,7 +145,7 @@ struct diff_backward { const int stride, const int axis, const mshadow::Shape oshape, const mshadow::Shape ishape) { - using namespace broadcast; + using namespace mxnet_op; if (n == 0) { igrad[i] = ograd[i]; return; diff --git a/src/operator/numpy/np_elemwise_broadcast_logic_op.cu b/src/operator/numpy/np_elemwise_broadcast_logic_op.cu index e27535d9f4f5..90c5bb465a73 100644 --- a/src/operator/numpy/np_elemwise_broadcast_logic_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_logic_op.cu @@ -34,11 +34,11 @@ namespace op { #define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(name) \ NNVM_REGISTER_OP(_npi_##name) \ - .set_attr("FCompute", BinaryBroadcastComputeLogic) + .set_attr("FCompute", BinaryBroadcastRTCCompute{"np_" #name}) #define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(name) \ NNVM_REGISTER_OP(_npi_##name##_scalar) \ - .set_attr("FCompute", BinaryScalarOp::ComputeLogic) + .set_attr("FCompute", BinaryScalarRTCCompute{"np_" #name}) MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(equal); MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(not_equal); diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index a2927cda61ff..a6f85a8bc219 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -29,78 +29,58 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_add) -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"add"}); NNVM_REGISTER_OP(_backward_npi_broadcast_add) -.set_attr("FCompute", NumpyBinaryBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"one", "one"}); NNVM_REGISTER_OP(_npi_subtract) -.set_attr( - "FCompute", - NumpyBinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"sub"}); NNVM_REGISTER_OP(_backward_npi_broadcast_sub) -.set_attr("FCompute", NumpyBinaryBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"one", "negone"}); NNVM_REGISTER_OP(_npi_multiply) -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"mul"}); NNVM_REGISTER_OP(_backward_npi_broadcast_mul) -.set_attr("FCompute", NumpyBinaryBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"right", "left"}); NNVM_REGISTER_OP(_npi_mod) -.set_attr( - "FCompute", - NumpyBinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"mod"}); NNVM_REGISTER_OP(_backward_npi_broadcast_mod) -.set_attr("FCompute", NumpyBinaryBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"mod_grad", "mod_rgrad"}); NNVM_REGISTER_OP(_npi_power) -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"power"}); NNVM_REGISTER_OP(_backward_npi_broadcast_power) -.set_attr("FCompute", NumpyBinaryBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"power_grad", 
"power_rgrad"}); NNVM_REGISTER_OP(_npi_add_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"add"}); NNVM_REGISTER_OP(_npi_subtract_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"sub"}); NNVM_REGISTER_OP(_npi_rsubtract_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"rsub"}); NNVM_REGISTER_OP(_npi_multiply_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"mul"}); NNVM_REGISTER_OP(_npi_mod_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"mod"}); NNVM_REGISTER_OP(_npi_rmod_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"rmod"}); NNVM_REGISTER_OP(_npi_power_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"power"}); NNVM_REGISTER_OP(_npi_rpower_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"rpow"}); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h index ce53eb6a3872..1fa58908a113 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.h +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -412,12 +412,10 @@ void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, { if (need_bc) { - BROADCAST_NDIM_SWITCH(new_oshape.ndim(), ndim, { - workspace_size_l = ReduceWorkspaceSize( - s, new_lshape, req[0], new_oshape, new_lshape, new_rshape); - workspace_size_r = ReduceWorkspaceSize( - s, new_rshape, req[1], new_oshape, new_lshape, new_rshape); - }); + workspace_size_l = ReduceWorkspaceSize( + s, new_lshape, req[0], new_oshape, new_lshape, new_rshape, sizeof(OType)); + workspace_size_r = ReduceWorkspaceSize( + s, new_rshape, req[1], new_oshape, new_lshape, new_rshape, sizeof(OType)); } size_t workspace_size = std::max(workspace_size_l, workspace_size_r); size_t cast_tensor_size = tensor_size * sizeof(OType); diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc index ce7f59a5520f..90a48d4aee9f 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc @@ -201,16 +201,12 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar) MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rcopysign_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rcopysign_scalar"}); +.set_attr("FGradient", MakeZeroGradNodes); MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_copysign_scalar) .set_attr("FCompute", BinaryScalarOp::Backward); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar) -.set_attr("FCompute", - BinaryScalarOp::Backward); - inline bool Arctan2OpType(const nnvm::NodeAttrs& attrs, std::vector* in_attrs, std::vector* out_attrs) { diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu index 33c77e08e408..b1d7e71bf17d 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu @@ -29,92 +29,89 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_copysign) 
-.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"copysign"}); NNVM_REGISTER_OP(_npi_lcm) -.set_attr("FCompute", BinaryBroadcastIntCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"lcm"}); NNVM_REGISTER_OP(_npi_bitwise_and) -.set_attr("FCompute", BinaryBroadcastIntCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"bitwise_and"}); NNVM_REGISTER_OP(_npi_bitwise_xor) -.set_attr("FCompute", BinaryBroadcastIntCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"bitwise_xor"}); NNVM_REGISTER_OP(_npi_bitwise_or) -.set_attr("FCompute", BinaryBroadcastIntCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"bitwise_or"}); NNVM_REGISTER_OP(_backward_npi_copysign) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"copysign_grad", + "zero"}); NNVM_REGISTER_OP(_npi_arctan2) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"arctan2"}); NNVM_REGISTER_OP(_backward_npi_arctan2) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"arctan2_grad", + "arctan2_rgrad"}); + NNVM_REGISTER_OP(_npi_hypot) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"hypot"}); NNVM_REGISTER_OP(_backward_npi_hypot) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"hypot_grad_left", + "hypot_grad_right"}); NNVM_REGISTER_OP(_npi_copysign_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"copysign"}); NNVM_REGISTER_OP(_npi_rcopysign_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"rcopysign"}); NNVM_REGISTER_OP(_backward_npi_copysign_scalar) .set_attr("FCompute", - BinaryScalarOp::Backward); - -NNVM_REGISTER_OP(_backward_npi_rcopysign_scalar) -.set_attr("FCompute", - BinaryScalarOp::Backward); + BinaryScalarRTCBackward{"copysign_grad"}); NNVM_REGISTER_OP(_npi_arctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"arctan2"}); NNVM_REGISTER_OP(_backward_npi_arctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarRTCBackward{"arctan2_grad"}); NNVM_REGISTER_OP(_npi_rarctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"rarctan2"}); NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarRTCBackward{"rarctan2_grad"}); NNVM_REGISTER_OP(_npi_lcm_scalar) -.set_attr("FCompute", BinaryScalarOp::ComputeInt); +.set_attr("FCompute", BinaryScalarRTCCompute{"lcm"}); NNVM_REGISTER_OP(_npi_bitwise_and_scalar) -.set_attr("FCompute", BinaryScalarOp::ComputeInt); +.set_attr("FCompute", BinaryScalarRTCCompute{"bitwise_and"}); NNVM_REGISTER_OP(_npi_bitwise_xor_scalar) -.set_attr("FCompute", BinaryScalarOp::ComputeInt); +.set_attr("FCompute", BinaryScalarRTCCompute{"bitwise_xor"}); NNVM_REGISTER_OP(_npi_bitwise_or_scalar) -.set_attr("FCompute", BinaryScalarOp::ComputeInt); +.set_attr("FCompute", BinaryScalarRTCCompute{"bitwise_or"}); NNVM_REGISTER_OP(_npi_ldexp) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"ldexp"}); NNVM_REGISTER_OP(_npi_ldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); 
+.set_attr("FCompute", BinaryScalarRTCCompute{"ldexp"}); NNVM_REGISTER_OP(_npi_rldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"rldexp"}); NNVM_REGISTER_OP(_backward_npi_ldexp) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"ldexp_grad", + "ldexp_rgrad"}); NNVM_REGISTER_OP(_backward_npi_ldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarRTCBackward{"ldexp_grad"}); NNVM_REGISTER_OP(_backward_npi_rldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarRTCBackward{"rldexp_grad"}); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cu b/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cu index fa2f3bf080c7..93d2cf18350e 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cu @@ -29,49 +29,48 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_fmax) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"fmax"}); NNVM_REGISTER_OP(_backward_npi_fmax) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"greater_equal", "less"}); NNVM_REGISTER_OP(_npi_fmax_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"fmax"}); NNVM_REGISTER_OP(_backward_npi_fmax_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarRTCBackward{"greater_equal"}); NNVM_REGISTER_OP(_npi_fmin) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"fmin"}); NNVM_REGISTER_OP(_backward_npi_fmin) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"less_equal", + "greater"}); NNVM_REGISTER_OP(_npi_fmin_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"fmin"}); NNVM_REGISTER_OP(_backward_npi_fmin_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarRTCBackward{"less_equal"}); NNVM_REGISTER_OP(_npi_fmod) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"fmod"}); NNVM_REGISTER_OP(_backward_npi_fmod) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"mod_grad", + "mod_rgrad"}); NNVM_REGISTER_OP(_npi_fmod_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"fmod"}); NNVM_REGISTER_OP(_backward_npi_fmod_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarRTCBackward{"mod_grad"}); NNVM_REGISTER_OP(_npi_rfmod_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"rfmod"}); NNVM_REGISTER_OP(_backward_npi_rfmod_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarRTCBackward{"rmod_grad"}); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cc b/src/operator/numpy/np_elemwise_unary_op_basic.cc index c375b3a2036d..548f61874614 100644 --- a/src/operator/numpy/np_elemwise_unary_op_basic.cc +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cc @@ -228,7 +228,7 @@ The sign 
function returns -1 if x < 0, 0 if x==0, 1 if x > 0. Example:: sign([-2, 0, 3]) = [-1, 0, 1] )code" ADD_FILELINE) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_sign"}); +.set_attr("FGradient", MakeZeroGradNodes); // rint MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY(_npi_rint, "x", mshadow_op::rint) diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cu b/src/operator/numpy/np_elemwise_unary_op_basic.cu index a872006336e4..d11a4addf97c 100644 --- a/src/operator/numpy/np_elemwise_unary_op_basic.cu +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cu @@ -27,108 +27,94 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_npx_relu) -.set_attr("FCompute", UnaryOp::Compute); +#define MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(__name$, __kernel$) \ + NNVM_REGISTER_OP(__name$) \ + .set_attr("FCompute", UnaryRTCCompute{#__kernel$}) -NNVM_REGISTER_OP(_npx_sigmoid) -.set_attr("FCompute", UnaryOp::Compute); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npx_relu, relu); + +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npx_sigmoid, sigmoid); NNVM_REGISTER_OP(_npi_copy) .set_attr("FCompute", UnaryOp::IdentityCompute); -#define MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(__name$, __kernel$) \ - NNVM_REGISTER_OP(__name$) \ - .set_attr("FCompute", UnaryOp::Compute) - -#define MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(__name$, __kernel$) \ - NNVM_REGISTER_OP(__name$) \ - .set_attr("FCompute", UnaryOp::ComputeMixedType) - -MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_negative, mshadow_op::negation); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_negative, negation); -MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_reciprocal, mshadow_op::reciprocal); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_reciprocal, reciprocal); -MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_absolute, mshadow_op::abs); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_absolute, abs); -MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sign, mshadow_op::sign); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sign, sign); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_rint, mshadow_op::rint); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_rint, rint); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_ceil, mshadow_op::ceil); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_ceil, ceil); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_floor, mshadow_op::floor); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_floor, floor); -NNVM_REGISTER_OP(_npi_bitwise_not) -.set_attr("FCompute", UnaryOp::ComputeInt); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_bitwise_not, bitwise_not); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_trunc, mshadow_op::trunc); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_trunc, trunc); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_fix, mshadow_op::fix); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_fix, fix); -MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_square, mshadow_op::square); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_square, square); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_sqrt, mshadow_op::square_root); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sqrt, sqrt); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_cbrt, mshadow_op::cube_root); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_cbrt, cbrt); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_exp, mshadow_op::exp); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_exp, exp); -NNVM_REGISTER_OP(_npi_log) -.set_attr("FCompute", UnaryOp::ComputeMixedType); 
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_log, log); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_log10, mshadow_op::log10); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_log10, log10); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_log2, mshadow_op::log2); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_log2, log2); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_log1p, mshadow_op::log1p); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_log1p, log1p); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_expm1, mshadow_op::expm1); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_expm1, expm1); -NNVM_REGISTER_OP(_npi_logical_not) -.set_attr("FCompute", UnaryOp::ComputeLogic); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_logical_not, np_logical_not); -NNVM_REGISTER_OP(_npi_isnan) -.set_attr("FCompute", UnaryOp::ComputeLogic); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_isnan, isnan); -NNVM_REGISTER_OP(_npi_isinf) -.set_attr("FCompute", UnaryOp::ComputeLogic); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_isinf, isinf); -NNVM_REGISTER_OP(_npi_isposinf) -.set_attr("FCompute", UnaryOp::ComputeLogic); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_isposinf, isposinf); -NNVM_REGISTER_OP(_npi_isneginf) -.set_attr("FCompute", UnaryOp::ComputeLogic); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_isneginf, isneginf); -NNVM_REGISTER_OP(_npi_isfinite) -.set_attr("FCompute", UnaryOp::ComputeLogic); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_isfinite, isfinite); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_sin, mshadow_op::sin); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sin, sin); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_cos, mshadow_op::cos); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_cos, cos); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_tan, mshadow_op::tan); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_tan, tan); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_arcsin, mshadow_op::arcsin); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_arcsin, arcsin); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_arccos, mshadow_op::arccos); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_arccos, arccos); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_arctan, mshadow_op::arctan); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_arctan, arctan); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_degrees, mshadow_op::degrees); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_degrees, degrees); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_radians, mshadow_op::radians); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_radians, radians); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_sinh, mshadow_op::sinh); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sinh, sinh); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_cosh, mshadow_op::cosh); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_cosh, cosh); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_tanh, mshadow_op::tanh); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_tanh, tanh); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_arcsinh, mshadow_op::arcsinh); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_arcsinh, arcsinh); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_arccosh, mshadow_op::arccosh); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_arccosh, arccosh); -MXNET_OPERATOR_REGISTER_NUMPY_MIXED_TYPE_UNARY_GPU(_npi_arctanh, mshadow_op::arctanh); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_arctanh, arctanh); 
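The rewritten MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU macro above takes the kernel as a bare identifier and stringizes it with the preprocessor's # operator, so every GPU unary op binds to UnaryRTCCompute with the kernel's name as a string and the separate mixed-type variant of the macro is no longer needed. A small standalone sketch of that stringizing registration; UnaryRTC and REGISTER_UNARY_GPU are invented names.

    #include <cstdio>
    #include <string>

    // Hypothetical stand-in for an RTC-backed unary functor: it only needs the
    // kernel's name as a string, so no per-operator template instantiation.
    struct UnaryRTC {
      std::string kernel;
      void operator()(float x) const {
        std::printf("launch '%s' on %f\n", kernel.c_str(), x);
      }
    };

    // The '#' operator turns the identifier into a string literal, mirroring the
    // #__kernel$ expansion in the rewritten registration macro above.
    #define REGISTER_UNARY_GPU(name, kernel) \
      const UnaryRTC name{#kernel}

    REGISTER_UNARY_GPU(op_sin, sin);      // expands to: const UnaryRTC op_sin{"sin"};
    REGISTER_UNARY_GPU(op_log1p, log1p);  // expands to: const UnaryRTC op_log1p{"log1p"};

    int main() {
      op_sin(0.5f);
      op_log1p(0.5f);
      return 0;
    }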
NNVM_REGISTER_OP(_npi_around) .set_attr("FCompute", AroundOpForward); @@ -140,93 +126,70 @@ NNVM_REGISTER_OP(_npi_backward_nan_to_num) .set_attr("FCompute", NumpyNanToNumOpBackward); NNVM_REGISTER_OP(_backward_npi_exp) -.set_attr("FCompute", - ElemwiseBinaryOp::MixedUnaryBackwardUseInOutCompute); +.set_attr("FCompute", UnaryBwdInOutRTCCompute{"mul"}); NNVM_REGISTER_OP(_backward_npi_log) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_log"}); NNVM_REGISTER_OP(_backward_npi_log10) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< -gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_log10"}); NNVM_REGISTER_OP(_backward_npi_log2) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_log2"}); NNVM_REGISTER_OP(_backward_npi_log1p) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_log1p"}); NNVM_REGISTER_OP(_backward_npi_expm1) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_expm1"}); NNVM_REGISTER_OP(_backward_npi_sqrt) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInOutCompute< - gpu, unary_bwd >); +.set_attr("FCompute", UnaryBwdInOutRTCCompute{"backward_sqrt"}); NNVM_REGISTER_OP(_backward_npi_cbrt) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInOutCompute< - gpu, unary_bwd >); +.set_attr("FCompute", UnaryBwdInOutRTCCompute{"backward_cbrt"}); NNVM_REGISTER_OP(_backward_npi_sin) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_sin"}); NNVM_REGISTER_OP(_backward_npi_cos) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_cos"}); NNVM_REGISTER_OP(_backward_npi_tan) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInOutCompute< - gpu, unary_bwd >); +.set_attr("FCompute", UnaryBwdInOutRTCCompute{"backward_tan"}); NNVM_REGISTER_OP(_backward_npi_arcsin) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_arcsin"}); NNVM_REGISTER_OP(_backward_npi_arccos) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_arccos"}); NNVM_REGISTER_OP(_backward_npi_arctan) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_arctan"}); NNVM_REGISTER_OP(_backward_npi_degrees) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_degrees"}); NNVM_REGISTER_OP(_backward_npi_radians) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_radians"}); NNVM_REGISTER_OP(_backward_npi_cosh) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_cosh"}); 
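The backward registrations in this block split into two groups by which tensor the derivative needs. Kernels such as backward_log, backward_sin or backward_arctan read the forward input (f'(x) = 1/x, cos x, 1/(1+x^2)) and go through ElemwiseBinaryRTCCompute, while exp, sqrt, cbrt, tan and tanh route through UnaryBwdInOutRTCCompute because their derivatives are cheapest in terms of the forward output y (exp' = y, sqrt' = 1/(2y), tan' = 1 + y^2, tanh' = 1 - y^2). A small host-side sketch of the two chain-rule flavours; the function names are illustrative, not the RTC kernel sources.

    #include <cmath>
    #include <cstdio>

    // Input flavour: dL/dx = dL/dy * f'(x), with f'(x) expressed via the input.
    float backward_log_like(float grad, float x)  { return grad * (1.0f / x); }
    float backward_sin_like(float grad, float x)  { return grad * std::cos(x); }

    // Output flavour: f'(x) rewritten in terms of y = f(x), so the backward
    // kernel only needs the forward output (the "InOut" registrations).
    float backward_exp_like(float grad, float y)  { return grad * y; }
    float backward_sqrt_like(float grad, float y) { return grad * 0.5f / y; }
    float backward_tanh_like(float grad, float y) { return grad * (1.0f - y * y); }

    int main() {
      const float x = 0.7f, grad = 1.0f;
      std::printf("d log/dx  = %f\n", backward_log_like(grad, x));
      std::printf("d sin/dx  = %f\n", backward_sin_like(grad, x));
      std::printf("d exp/dx  = %f\n", backward_exp_like(grad, std::exp(x)));
      std::printf("d sqrt/dx = %f\n", backward_sqrt_like(grad, std::sqrt(x)));
      std::printf("d tanh/dx = %f\n", backward_tanh_like(grad, std::tanh(x)));
      return 0;
    }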
NNVM_REGISTER_OP(_backward_npi_sinh) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_sinh"}); NNVM_REGISTER_OP(_backward_npi_tanh) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInOutCompute< - gpu, unary_bwd >); +.set_attr("FCompute", UnaryBwdInOutRTCCompute{"backward_tanh"}); NNVM_REGISTER_OP(_backward_npi_arcsinh) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_arcsinh"}); NNVM_REGISTER_OP(_backward_npi_arccosh) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_arccosh"}); NNVM_REGISTER_OP(_backward_npi_arctanh) -.set_attr("FCompute", ElemwiseBinaryOp::MixedUnaryBackwardUseInCompute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_arctanh"}); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_polynomial_op.cu b/src/operator/numpy/np_polynomial_op.cu index 31f284b7a2a8..3c4655b2da22 100644 --- a/src/operator/numpy/np_polynomial_op.cu +++ b/src/operator/numpy/np_polynomial_op.cu @@ -23,7 +23,7 @@ */ #include "np_polynomial_op-inl.h" -#include "../../common/cuda_utils.h" +#include "../../common/cuda/utils.h" namespace mxnet { namespace op { diff --git a/src/operator/numpy/np_true_divide.cu b/src/operator/numpy/np_true_divide.cu index c8eccfe140b4..757fa0d9e8a2 100644 --- a/src/operator/numpy/np_true_divide.cu +++ b/src/operator/numpy/np_true_divide.cu @@ -32,8 +32,7 @@ NNVM_REGISTER_OP(_npi_true_divide) .set_attr("FCompute", TrueDivideBroadcastCompute); NNVM_REGISTER_OP(_backward_npi_broadcast_div) -.set_attr("FCompute", NumpyBinaryBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"div_grad", "div_rgrad"}); NNVM_REGISTER_OP(_npi_true_divide_scalar) .set_attr("FCompute", TrueDivideScalarCompute); diff --git a/src/operator/numpy/np_where_op-inl.h b/src/operator/numpy/np_where_op-inl.h index 3233f785d246..10ec081b2a8f 100644 --- a/src/operator/numpy/np_where_op-inl.h +++ b/src/operator/numpy/np_where_op-inl.h @@ -225,10 +225,10 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, Tensor workspace; size_t ws_size = 0; if (ograd.shape_ != dx.shape_ || ograd.shape_ != dy.shape_) { - size_t ws_size1 = broadcast::ReduceWorkspaceSize( - s, expanded_lshape, req[0], expanded_oshape); - size_t ws_size2 = broadcast::ReduceWorkspaceSize( - s, expanded_rshape, req[1], expanded_oshape); + size_t ws_size1 = broadcast::ReduceWorkspaceSize( + s, expanded_lshape, req[0], expanded_oshape, sizeof(DType)); + size_t ws_size2 = broadcast::ReduceWorkspaceSize( + s, expanded_rshape, req[1], expanded_oshape, sizeof(DType)); ws_size = std::max(ws_size1, ws_size2); } // process left output @@ -366,8 +366,8 @@ inline void NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs, Tensor workspace; size_t ws_size = 0; if (ograd.shape_ != dx.shape_) { - ws_size = broadcast::ReduceWorkspaceSize( - s, expanded_lshape, req[0], expanded_oshape); + ws_size = broadcast::ReduceWorkspaceSize(s, expanded_lshape, req[0], + expanded_oshape, sizeof(DType)); } // If lscalar, then process right output, `is_left` should be false if (ograd.shape_ == dx.shape_) { diff --git a/src/operator/numpy/random/np_exponential_op.h b/src/operator/numpy/random/np_exponential_op.h index 36d29ff842e3..203430dd5879 100644 --- 
a/src/operator/numpy/random/np_exponential_op.h +++ b/src/operator/numpy/random/np_exponential_op.h @@ -171,7 +171,7 @@ inline void ExponentialReparamBackwardImpl(const OpContext& ctx, const TBlob samples = inputs[3].reshape(new_oshape); const TBlob noise = inputs[4].reshape(new_oshape); size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); + ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); Reduce( diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h index e1d031fb9a0b..57d46ff5cf51 100644 --- a/src/operator/numpy/random/np_gamma_op.h +++ b/src/operator/numpy/random/np_gamma_op.h @@ -420,7 +420,7 @@ inline void GammaReparamBackwardImpl(const OpContext& ctx, const TBlob alpha = inputs[1].reshape(new_ishape); TBlob samples = inputs[2].reshape(new_oshape); size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); + ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); // Convert samples to standard gamma Kernel, xpu>::Launch( s, samples.Size(), samples.dptr(), samples.dptr(), DType(scale)); diff --git a/src/operator/numpy/random/np_location_scale_op.h b/src/operator/numpy/random/np_location_scale_op.h index 00c89c149c5c..49bcbc3d0413 100644 --- a/src/operator/numpy/random/np_location_scale_op.h +++ b/src/operator/numpy/random/np_location_scale_op.h @@ -296,10 +296,10 @@ inline void LocationScaleReparamBackwardImpl(const OpContext& ctx, const TBlob rhs = inputs[3].reshape(new_rshape); const TBlob samples = inputs[4].reshape(new_oshape); const TBlob noise = inputs[5].reshape(new_oshape); - size_t workspace_size_l = ReduceWorkspaceSize( - s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_); - size_t workspace_size_r = ReduceWorkspaceSize( - s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_); + size_t workspace_size_l = ReduceWorkspaceSize( + s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); + size_t workspace_size_r = ReduceWorkspaceSize( + s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); size_t workspace_size = std::max(workspace_size_l, workspace_size_r); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); @@ -329,7 +329,7 @@ inline void ScalarLocationScaleReparamBackwardImpl(const OpContext& ctx, const TBlob samples = inputs[3].reshape(new_oshape); const TBlob noise = inputs[4].reshape(new_oshape); size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); + ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); if (loc_is_tensor) { diff --git a/src/operator/numpy/random/np_normal_op.h b/src/operator/numpy/random/np_normal_op.h index d81f3d38f3a3..332200dc6cd2 100644 --- a/src/operator/numpy/random/np_normal_op.h +++ b/src/operator/numpy/random/np_normal_op.h @@ -261,10 +261,10 @@ inline void NormalReparamBackwardImpl(const OpContext& ctx, const TBlob rhs = inputs[3].reshape(new_rshape); const TBlob samples = inputs[4].reshape(new_oshape); const TBlob noise = inputs[5].reshape(new_oshape); - size_t workspace_size_l = ReduceWorkspaceSize( - s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_); - size_t workspace_size_r = ReduceWorkspaceSize( - s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_); + size_t 
workspace_size_l = ReduceWorkspaceSize( + s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); + size_t workspace_size_r = ReduceWorkspaceSize( + s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); size_t workspace_size = std::max(workspace_size_l, workspace_size_r); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); @@ -294,7 +294,7 @@ inline void ScalarNormalReparamBackwardImpl(const OpContext& ctx, const TBlob samples = inputs[3].reshape(new_oshape); const TBlob noise = inputs[4].reshape(new_oshape); size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); + ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); if (loc_is_tensor) { diff --git a/src/operator/numpy/random/np_pareto_op.h b/src/operator/numpy/random/np_pareto_op.h index a8a5d7f411c0..af0e6c568187 100644 --- a/src/operator/numpy/random/np_pareto_op.h +++ b/src/operator/numpy/random/np_pareto_op.h @@ -174,7 +174,7 @@ inline void ScalarParetoReparamBackwardImpl(const OpContext& ctx, const TBlob samples = inputs[3].reshape(new_oshape); const TBlob noise = inputs[4].reshape(new_oshape); size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); + ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); Reduce( diff --git a/src/operator/numpy/random/np_rayleigh_op.h b/src/operator/numpy/random/np_rayleigh_op.h index 3444f3b74af5..0bbaf5d7158b 100644 --- a/src/operator/numpy/random/np_rayleigh_op.h +++ b/src/operator/numpy/random/np_rayleigh_op.h @@ -172,7 +172,7 @@ inline void ScalarRayleighReparamBackwardImpl(const OpContext& ctx, const TBlob samples = inputs[3].reshape(new_oshape); const TBlob noise = inputs[4].reshape(new_oshape); size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); + ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); Reduce( diff --git a/src/operator/numpy/random/np_weibull_op.h b/src/operator/numpy/random/np_weibull_op.h index ff4c40ae8db5..74aeeff9f1fc 100644 --- a/src/operator/numpy/random/np_weibull_op.h +++ b/src/operator/numpy/random/np_weibull_op.h @@ -174,7 +174,7 @@ inline void ScalarWeibullReparamBackwardImpl(const OpContext& ctx, const TBlob samples = inputs[3].reshape(new_oshape); const TBlob noise = inputs[4].reshape(new_oshape); size_t workspace_size = - ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); + ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType)); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); Reduce( diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index ccfebf597f67..31c666307768 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -39,7 +39,7 @@ #include #include #include -#include "../common/cuda_utils.h" +#include "../common/cuda/utils.h" #include "../common/utils.h" namespace mxnet { diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 4c66f00b14d6..9af336499d5c 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -297,7 +297,6 @@ IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_cube_root_grad); 
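// Illustration (not part of the patch above): the recurring edit in the random-op
// backward implementations appears to swap the templated helper
//   broadcast::ReduceWorkspaceSize<NDim, DType>(s, small, req, big)
// for a plain-function overload that takes the element size at run time. A minimal
// caller sketch under that assumption, mirroring the call sites shown above (igrad,
// ograd, req and DType stand in for whatever the surrounding operator uses):
//
//   size_t ws_bytes = broadcast::ReduceWorkspaceSize(
//       s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
//   mshadow::Tensor<xpu, 1, char> workspace =
//       ctx.requested[0].get_space_typed<xpu, 1, char>(mshadow::Shape1(ws_bytes), s);
//
// With this change only the byte size of the reduction type affects the returned
// workspace size; ndim no longer selects a separate template instantiation.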
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::abs); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sign); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign); // NOLINT() -IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign_grad); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::round); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::floor); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::trunc); // NOLINT() @@ -368,7 +367,6 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::copysign); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rcopysign); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_rgrad); // NOLINT() -IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rcopysign_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctan2); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rarctan2); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctan2_grad); // NOLINT() diff --git a/src/operator/pad.cu b/src/operator/pad.cu index 643e62db722a..8d82ba337fdd 100644 --- a/src/operator/pad.cu +++ b/src/operator/pad.cu @@ -25,7 +25,7 @@ */ #include #include "./pad-inl.h" -#include "../common/cuda_utils.h" +#include "../common/cuda/utils.h" namespace mshadow { namespace cuda { diff --git a/src/operator/quantization/quantization_utils.h b/src/operator/quantization/quantization_utils.h index 5230576ce594..2c5c1ebe1fd3 100644 --- a/src/operator/quantization/quantization_utils.h +++ b/src/operator/quantization/quantization_utils.h @@ -184,7 +184,7 @@ inline size_t ConfigReduce(mshadow::Stream* s, CHECK_EQ(src_shape->ndim(), NDim); CHECK_EQ(dst_shape->ndim(), NDim); - return broadcast::ReduceWorkspaceSize(s, *dst_shape, kWriteTo, *src_shape); + return broadcast::ReduceWorkspaceSize(s, *dst_shape, kWriteTo, *src_shape, sizeof(DType)); } enum QuantizeOutType { kAuto = 0, kInt8, kUint8 }; diff --git a/src/operator/random/pdf_op.h b/src/operator/random/pdf_op.h index ee15e993c430..57bddfc2b1fe 100644 --- a/src/operator/random/pdf_op.h +++ b/src/operator/random/pdf_op.h @@ -588,8 +588,8 @@ void PdfOpBackward(const nnvm::NodeAttrs& attrs, const TShape src_shape(Shape2(N, outputs[0].Size() / N)), dst_shape(Shape2(N, 1)); // Inputs to PdfOpBackward: grad, samples, parm1, parm2, pdf. MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - const size_t red_work_size(broadcast::ReduceWorkspaceSize<2, DType>( - s, dst_shape, kAddTo, src_shape)); + const size_t red_work_size(broadcast::ReduceWorkspaceSize( + s, dst_shape, kAddTo, src_shape, sizeof(DType))); const size_t tmp_size(outputs[0].Size() * pnum * sizeof(DType) + red_work_size); Tensor tmp_space = ctx.requested[0].get_space_typed(Shape1(tmp_size), s); diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh index 379443dc1688..c7a7c478cbb3 100644 --- a/src/operator/tensor/broadcast_reduce-inl.cuh +++ b/src/operator/tensor/broadcast_reduce-inl.cuh @@ -18,60 +18,16 @@ */ /*! 
- * Copyright (c) 2015-2017 by Contributors + * Copyright (c) 2015-2020 by Contributors * \file broadcast_reduce-inl.cuh * \brief CUDA implementations for binary broadcast and reduce - * \author Antti-Pekka Hynninen + * \author Antti-Pekka Hynninen, Przemyslaw Tredak */ #ifndef MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_ #define MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_ using namespace mshadow::cuda; -template -__launch_bounds__(kMaxThreadsPerBlock) -__global__ void binary_broadcast_kernel(const int N, const bool addto, - const DType* __restrict lhs, - const DType* __restrict rhs, DType *out, - const Shape lstride, const Shape rstride, - const Shape oshape) { - for (int idx = blockIdx.x * blockDim.x * unroll + threadIdx.x; idx < N; - idx += blockDim.x * gridDim.x * unroll) - { - int j[unroll]; - int k[unroll]; - DType val[unroll]; - #pragma unroll - for (int i=0;i < unroll;i++) { - unravel_dot(idx + i*blockDim.x, oshape, lstride, rstride, &j[i], &k[i]); - val[i] = OP::Map(lhs[j[i]], rhs[k[i]]); - } - #pragma unroll - for (int i=0;i < unroll;i++) { - if (idx + i*blockDim.x < N) assign(&out[idx + i*blockDim.x], addto, val[i]); - } - - } -} - -template -void BinaryBroadcastComputeImpl(Stream *s, const OpReqType req, - const TBlob& lhs, const TBlob& rhs, const TBlob& out) { - if (req == kNullOp) return; - cudaStream_t stream = Stream::GetStream(s); - int N = out.shape_.Size(); - const int warpSize = 32; - const int unroll = 2; - int nthread = std::min(kMaxThreadsPerBlock, ((N + warpSize - 1)/warpSize)*warpSize ); - int ngrid = std::min(kBaseGridNum, (N + nthread*unroll - 1) / (nthread*unroll)); - Shape lstride = calc_stride(lhs.shape_.get()); - Shape rstride = calc_stride(rhs.shape_.get()); - binary_broadcast_kernel<<>>( - N, req == kAddTo, lhs.dptr(), rhs.dptr(), out.dptr(), lstride, rstride, - out.shape_.get()); -} - -const int nthread_reduce = kMaxThreadsPerBlock; template __launch_bounds__(nthread_reduce) __global__ void reduce_kernel(const int N, const int M, const bool addto, @@ -92,8 +48,8 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext); for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { int idx = idx0 + tidx; - Shape coord = unravel(idx, small_shape); - int idx_big0 = ravel(coord, big_shape0); + Shape coord = mxnet_op::unravel(idx, small_shape); + int idx_big0 = mxnet_op::ravel(coord, big_shape0); AType val, residual; Reducer::SetInitValue(val, residual); @@ -102,7 +58,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, int idx_big[unroll]; #pragma unroll for (int u=0;u < unroll;u++) { - idx_big[u] = idx_big0 + unravel_dot(k + u*by, big_shape, big_stride); + idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride); } DType tmp[unroll]; #pragma unroll @@ -175,10 +131,10 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext); for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { int idx = idx0 + tidx; - Shape coord = unravel(idx, small_shape); - int idx_big0 = ravel(coord, big_shape0); - int idx_lhs0 = ravel(coord, lhs_shape0); - int idx_rhs0 = ravel(coord, rhs_shape0); + Shape coord = mxnet_op::unravel(idx, small_shape); + int idx_big0 = mxnet_op::ravel(coord, big_shape0); + int idx_lhs0 = mxnet_op::ravel(coord, lhs_shape0); + int idx_rhs0 = mxnet_op::ravel(coord, rhs_shape0); DType val, residual; 
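// Illustration (not part of the patch above): the unravel()/ravel() helpers that this
// hunk re-qualifies as mxnet_op::unravel / mxnet_op::ravel perform the usual row-major
// flat-index <-> coordinate conversion, with ravel() ignoring broadcast dimensions
// (a dimension of size 1 contributes nothing to the flat index). A self-contained
// sketch of the same index math, for a fixed ndim and for illustration only:
//
//   #include <array>
//   #include <cstdint>
//   using index_t = std::int64_t;
//
//   template <int ndim>
//   std::array<index_t, ndim> unravel(index_t idx, const std::array<index_t, ndim>& shape) {
//     std::array<index_t, ndim> coord{};
//     for (int i = ndim - 1; i >= 0; --i) {   // peel off the fastest-moving dimension first
//       coord[i] = idx % shape[i];
//       idx /= shape[i];
//     }
//     return coord;
//   }
//
//   template <int ndim>
//   index_t ravel(const std::array<index_t, ndim>& coord,
//                 const std::array<index_t, ndim>& shape) {
//     index_t ret = 0;
//     for (int i = 0; i < ndim; ++i)          // size-1 (broadcast) dims are skipped
//       ret = ret * shape[i] + (shape[i] > 1 ? coord[i] : 0);
//     return ret;
//   }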
Reducer::SetInitValue(val, residual); @@ -189,9 +145,9 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, int idx_rhs[unroll]; #pragma unroll for (int u=0;u < unroll;u++) { - idx_big[u] = idx_big0 + unravel_dot(k + u*by, big_shape, big_stride); - idx_lhs[u] = idx_lhs0 + unravel_dot(k + u*by, lhs_shape, lhs_stride); - idx_rhs[u] = idx_rhs0 + unravel_dot(k + u*by, rhs_shape, rhs_stride); + idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride); + idx_lhs[u] = idx_lhs0 + mxnet_op::unravel_dot(k + u*by, lhs_shape, lhs_stride); + idx_rhs[u] = idx_rhs0 + mxnet_op::unravel_dot(k + u*by, rhs_shape, rhs_stride); } DType tmp[unroll]; #pragma unroll @@ -267,8 +223,8 @@ __global__ void reduce_kernel_M1(const int N, const bool addto, const DType* __restrict big, OType *small, const Shape bshape, const Shape sshape) { for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - Shape coord = unravel(idx, sshape); - int j = ravel(coord, bshape); + Shape coord = mxnet_op::unravel(idx, sshape); + int j = mxnet_op::ravel(coord, bshape); AType val, residual; Reducer::SetInitValue(val, residual); Reducer::Reduce(val, AType(OP::Map(big[j])), residual); @@ -289,10 +245,10 @@ __global__ void reduce_kernel_M1(const int N, const bool addto, const Shape rhs_shape, const Shape small_shape) { for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { - Shape coord = unravel(idx, small_shape); - int idx_big = ravel(coord, big_shape); - int idx_lhs = ravel(coord, lhs_shape); - int idx_rhs = ravel(coord, rhs_shape); + Shape coord = mxnet_op::unravel(idx, small_shape); + int idx_big = mxnet_op::ravel(coord, big_shape); + int idx_lhs = mxnet_op::ravel(coord, lhs_shape); + int idx_rhs = mxnet_op::ravel(coord, rhs_shape); DType val, residual; Reducer::SetInitValue(val, residual); Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual); @@ -301,215 +257,6 @@ __global__ void reduce_kernel_M1(const int N, const bool addto, } } -// Returns the stride with which the fastest dimension is moving. -// Used to detect memory access scatter. -template -MSHADOW_XINLINE int fastest_stride(const Shape& small, const Shape& big, - const Shape& big_stride) { - for (int i = ndim-1; i >= 0; --i) { - if (big[i] != 1) { - return (small[i] == big[i]) ? 1 : big_stride[i]; - } - } - return 1; -} - -// Returns a/b integer division rounded up -template -Type ceil_idiv(const Type a, const Type b) { - return (a + b - 1)/b; -} - -// Configuration for ReduceImpl() -template -struct ReduceImplConfig { - static const int warpSize = 32; - static const int unroll_reduce = 2; - static const int maxLoopPerTB = 64; - int N; - int M; - int Mnext; - struct { - dim3 blockDim; - dim3 gridDim; - int shMemSize; - bool do_transpose; - } kernel_1; - struct { - int blockSize; - int gridSize; - } kernel_2; - size_t workspace_size; - - Shape rshape, rstride; - Shape lhs_shape, lhs_stride; - Shape rhs_shape, rhs_stride; -}; - -static inline uint64_t calc_num_load(const int X, const int Y, const int* strides) { - const int warpSize = ReduceImplConfig<1>::warpSize; - // Number of full warps - uint64_t num_full_warp = X / warpSize; - // Length of the partial warp i.e. 
number of threads that are performing loads - uint64_t len_part_warp = X % warpSize; - - uint64_t num_load_full = (std::min(warpSize, strides[0]) + - std::min(warpSize, strides[1]) + - std::min(warpSize, strides[2]))*num_full_warp; - - uint64_t num_load_part = - (std::min(len_part_warp, ceil_idiv(len_part_warp*strides[0], warpSize)) + - std::min(len_part_warp, ceil_idiv(len_part_warp*strides[1], warpSize)) + - std::min(len_part_warp, ceil_idiv(len_part_warp*strides[2], warpSize)))* - (len_part_warp != 0); - - uint64_t num_load = (num_load_full + num_load_part)*(uint64_t)Y; - return num_load; -} - -template -ReduceImplConfig ConfigureReduceImpl(const mxnet::TShape& small, - const mxnet::TShape& big, - const mxnet::TShape* lhs, - const mxnet::TShape* rhs) { - ReduceImplConfig config; - - diff(small.get(), big.get(), &config.rshape, &config.rstride); - config.N = small.Size(); - config.M = config.rshape.Size(); - - bool multiOp = false; - if (lhs != nullptr) { - CHECK_NOTNULL(rhs); - diff(small.get(), lhs->get(), &config.lhs_shape, - &config.lhs_stride); - diff(small.get(), rhs->get(), &config.rhs_shape, - &config.rhs_stride); - multiOp = true; - } - - config.workspace_size = 0; - - if (config.M == 1) { - config.kernel_1.blockDim.x = kMaxThreadsPerBlock; - config.kernel_1.gridDim.x = std::min((unsigned int)kBaseGridNum, - (config.N + config.kernel_1.blockDim.x - 1)/config.kernel_1.blockDim.x); - } else { - - int reduce_strides[3]; - reduce_strides[0] = fastest_stride(small.get(), big.get(), - big.get()); - reduce_strides[1] = (multiOp) ? fastest_stride(small.get(), - lhs->get(), lhs->get()) : 1; - reduce_strides[2] = (multiOp) ? fastest_stride(small.get(), - rhs->get(), rhs->get()) : 1; - - int reduce_strides_transp[3]; - reduce_strides_transp[0] = fastest_stride(small.get(), config.rshape, - config.rstride); - reduce_strides_transp[1] = (multiOp) ? - fastest_stride(small.get(), config.lhs_shape, config.lhs_stride) : 1; - reduce_strides_transp[2] = (multiOp) ? 
- fastest_stride(small.get(), config.rhs_shape, config.rhs_stride) : 1; - - uint64_t num_load = calc_num_load(config.N, config.M, reduce_strides); - uint64_t num_load_transp = calc_num_load(config.M, config.N, reduce_strides_transp); - - config.Mnext = 1; - config.kernel_1.do_transpose = (num_load > num_load_transp); - - config.kernel_1.blockDim.x = 0; - config.kernel_1.blockDim.y = 0; - - if (config.kernel_1.do_transpose) { - // Fastest thread ID goes through M - // Loop over N has step size config.kernel_1.blockDim.y - if (config.N < 8) { - config.kernel_1.blockDim.y = 1; - } else if (config.N < 256) { - config.kernel_1.blockDim.y = 4; - } else { - if (config.M < 8) { - config.kernel_1.blockDim.x = 1; - } else if (config.M < 256) { - config.kernel_1.blockDim.x = 4; - } else { - config.kernel_1.blockDim.x = config.warpSize; - } - } - } else { - // Fastest thread ID goes through N - // Loop over M has step size config.kernel_1.blockDim.y - if (config.M < 8) { - config.kernel_1.blockDim.y = 1; - } else if (config.M < 256) { - config.kernel_1.blockDim.y = 4; - } else { - if (config.N < 8) { - config.kernel_1.blockDim.x = 1; - } else if (config.N < 256) { - config.kernel_1.blockDim.x = 4; - } else { - config.kernel_1.blockDim.x = config.warpSize; - } - } - } - - if (config.kernel_1.blockDim.x == 0 && config.kernel_1.blockDim.y == 0) { - LOG(FATAL) << "Unable to set blockDim"; - } else if (config.kernel_1.blockDim.x == 0) { - config.kernel_1.blockDim.x = nthread_reduce / config.kernel_1.blockDim.y; - } else if (config.kernel_1.blockDim.y == 0) { - config.kernel_1.blockDim.y = nthread_reduce / config.kernel_1.blockDim.x; - } - - if (config.kernel_1.do_transpose) { - // Fastest thread ID goes through M - config.kernel_1.gridDim.x = std::min((unsigned int)kBaseGridNum, - ceil_idiv(config.N, config.kernel_1.blockDim.y)); - config.kernel_1.gridDim.y = std::min(kBaseGridNum, config.Mnext); - int by = config.kernel_1.blockDim.y; - if (config.kernel_1.blockDim.y % config.warpSize == 0) { - // Fix shared memory bank conflict - by++; - } - config.kernel_1.shMemSize = (config.kernel_1.blockDim.x > 1) ? - config.kernel_1.blockDim.x*by*sizeof(DType) * 2 : 0; - // Maximum number of times we want TB to loop in M - // Max size of M-block each TB can handle - int maxMblock = config.kernel_1.blockDim.x*config.maxLoopPerTB; - config.Mnext = (config.M + maxMblock - 1) / maxMblock; - } else { - // Fastest thread ID goes through N - config.kernel_1.gridDim.x = std::min((unsigned int)kBaseGridNum, - ceil_idiv(config.N, config.kernel_1.blockDim.x)); - config.kernel_1.gridDim.y = std::min(kBaseGridNum, config.Mnext); - config.kernel_1.shMemSize = (config.kernel_1.blockDim.y > 1) ? 
- config.kernel_1.blockDim.x*config.kernel_1.blockDim.y*sizeof(DType) * 2 : 0; - // Maximum number of times we want TB to loop in M - // Max size of M-block each TB can handle - int maxMblock = config.kernel_1.blockDim.y*config.maxLoopPerTB; - config.Mnext = (config.M + maxMblock - 1) / maxMblock; - } - - if (config.Mnext > 1) { - // small_dptr[] is N*Mnext*sizeof(DType) bytes - config.workspace_size += config.N*config.Mnext*sizeof(double); - // Set gridDim.y to Mnext - config.kernel_1.gridDim.y = std::min(kBaseGridNum, config.Mnext); - } - - if (config.Mnext > 1) { - config.kernel_2.blockSize = kMaxThreadsPerBlock; - config.kernel_2.gridSize = std::min((int)kBaseGridNum, - (config.N + config.kernel_2.blockSize - 1)/config.kernel_2.blockSize ); - } - - } - - return config; -} - #define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \ if (do_unroll) { \ const int unrollVar = unrollAmount; \ @@ -522,7 +269,7 @@ ReduceImplConfig ConfigureReduceImpl(const mxnet::TShape& small, template void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req, const TBlob& big, const Tensor& workspace, - const ReduceImplConfig& config) { + const ReduceImplConfig& config) { if (config.M == 1) { reduce_kernel_M1 <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>( @@ -544,13 +291,13 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req, const int by = (config.kernel_1.do_transpose) ? config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; - const bool do_unroll = ( config.M / (by*config.Mnext) >= config.unroll_reduce ); - KERNEL_UNROLL_SWITCH(do_unroll, ReduceImplConfig::unroll_reduce, UNROLL, { + const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce ); + KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, { reduce_kernel <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( config.N, config.M, addto, big.dptr(), small_dptr, big.shape_.get(), - small.shape_.get(), config.rshape, config.rstride, config.Mnext, - config.kernel_1.do_transpose); + small.shape_.get(), config.rshape.get(), config.rstride.get(), + config.Mnext, config.kernel_1.do_transpose); }); MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel); @@ -566,7 +313,7 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req, template void ReduceImpl(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const TBlob& rhs, const OpReqType req, const TBlob& big, const Tensor& workspace, - const ReduceImplConfig& config) { + const ReduceImplConfig& config) { if (config.M == 1) { reduce_kernel_M1 <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>( @@ -589,14 +336,15 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const const int by = (config.kernel_1.do_transpose) ? 
config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; - const bool do_unroll = ( config.M / (by*config.Mnext) >= config.unroll_reduce ); - KERNEL_UNROLL_SWITCH(do_unroll, ReduceImplConfig::unroll_reduce, UNROLL, { + const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce ); + KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, { reduce_kernel <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( config.N, config.M, addto, big.dptr(), lhs.dptr(), rhs.dptr(), small_dptr, big.shape_.get(), lhs.shape_.get(), - rhs.shape_.get(), small.shape_.get(), config.rshape, config.lhs_shape, - config.rhs_shape, config.rstride, config.lhs_stride, config.rhs_stride, config.Mnext, + rhs.shape_.get(), small.shape_.get(), config.rshape.get(), + config.lhs_shape.get(), config.rhs_shape.get(), config.rstride.get(), + config.lhs_stride.get(), config.rhs_stride.get(), config.Mnext, config.kernel_1.do_transpose); MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel); }); @@ -617,14 +365,14 @@ void Reduce(Stream *s, const TBlob& small, const OpReqType req, const Tensor& workspace, const TBlob& big) { if (req == kNullOp) return; cudaStream_t stream = Stream::GetStream(s); - ReduceImplConfig config = - ConfigureReduceImpl(small.shape_, big.shape_, nullptr, nullptr); + ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType)); if (safe_acc) { MXNET_ACC_TYPE_SWITCH(mshadow::DataType::kFlag, DataType, AType, { typedef typename std::conditional::type AccType; MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { typedef typename std::conditional::type OutType; - config = ConfigureReduceImpl(small.shape_, big.shape_, nullptr, nullptr); + config = ReduceImplConfig(small.shape_, big.shape_, nullptr, nullptr, + sizeof(AccType)); ReduceImpl( stream, small, req, big, workspace, config); }); @@ -639,8 +387,7 @@ void ReduceBool(Stream *s, const TBlob& small, const OpReqType req, const Tensor& workspace, const TBlob& big) { if (req == kNullOp) return; cudaStream_t stream = Stream::GetStream(s); - ReduceImplConfig config = - ConfigureReduceImpl(small.shape_, big.shape_, nullptr, nullptr); + ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType)); ReduceImpl(stream, small, req, big, workspace, config); } @@ -654,25 +401,8 @@ void Reduce(Stream *s, const TBlob& small, const OpReqType req, const TBlob& lhs, const TBlob& rhs) { if (req == kNullOp) return; cudaStream_t stream = Stream::GetStream(s); - ReduceImplConfig config = - ConfigureReduceImpl(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_); + ReduceImplConfig config(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_, sizeof(DType)); ReduceImpl(stream, small, lhs, rhs, req, big, workspace, config); } -template -size_t ReduceWorkspaceSize(Stream *s, const mxnet::TShape& small, const OpReqType req, - const mxnet::TShape& big) { - if (req == kNullOp) return 0; - ReduceImplConfig config = ConfigureReduceImpl(small, big, nullptr, nullptr); - return config.workspace_size; -} - -template -size_t ReduceWorkspaceSize(Stream *s, const mxnet::TShape& small, const OpReqType req, - const mxnet::TShape& big, const mxnet::TShape& lhs, const mxnet::TShape& rhs) { - if (req == kNullOp) return 0; - ReduceImplConfig config = ConfigureReduceImpl(small, big, &lhs, &rhs); - return config.workspace_size; -} - #endif //MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_ diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h index 841fbcd28a68..ad3bd2a2bec9 100644 
--- a/src/operator/tensor/broadcast_reduce-inl.h +++ b/src/operator/tensor/broadcast_reduce-inl.h @@ -31,27 +31,181 @@ #include #include #include "../mshadow_op.h" +#include "../mxnet_op.h" #include "../operator_common.h" namespace mxnet { namespace op { +namespace mxnet_op { +template +struct binary_broadcast_kernel { + /*! \brief Map function for binary_broadcast_kernel */ + template + MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, + const Shape &lstride, const Shape &rstride, + const Shape &oshape, IType *lhs, IType *rhs, + DType *out) { + Shape coord = unravel(base, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto ridx = static_cast(dot(coord, rstride)); + KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx])); + // starts from 1 to avoid extra inc at end of loop + for (index_t i = 1; i < length; ++i) { + inc(&coord, oshape, &lidx, lstride, &ridx, rstride); + // When tuning, don't actually run the op, since it's not going to be tuned against + // the actual op we'll eventually be using + KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx])); + } + } + + /*! \brief Map function for binary_broadcast_kernel */ + template + MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, + const Shape &lstride, const Shape &rstride, + const Shape &oshape, LType *lhs, RType *rhs, + OType *out) { + Shape coord = unravel(base, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto ridx = static_cast(dot(coord, rstride)); + KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx])); + // starts from 1 to avoid extra inc at end of loop + for (index_t i = 1; i < length; ++i) { + inc(&coord, oshape, &lidx, lstride, &ridx, rstride); + // When tuning, don't actually run the op, since it's not going to be tuned against + // the actual op we'll eventually be using + KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx])); + } + } + + /*! \brief Map function for binary_broadcast_kernel */ + template + MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, + const Shape &lstride, const Shape &rstride, + const Shape &oshape, IType lhs, IType *rhs, + DType *out) { + Shape coord = unravel(base, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto ridx = static_cast(dot(coord, rstride)); + KERNEL_ASSIGN(out[base], req, OP::Map(lhs, rhs[ridx])); + // starts from 1 to avoid extra inc at end of loop + for (index_t i = 1; i < length; ++i) { + inc(&coord, oshape, &lidx, lstride, &ridx, rstride); + // When tuning, don't actually run the op, since it's not going to be tuned against + // the actual op we'll eventually be using + KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs, rhs[ridx])); + } + } + + /*! 
\brief Map function for binary_broadcast_kernel */ + /* used for mixed type binary ops */ + template::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, + const Shape &lstride, const Shape &rstride, + const Shape &oshape, IType *lhs, DType *rhs, + DType *out) { + Shape coord = unravel(base, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto ridx = static_cast(dot(coord, rstride)); + KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx])); + // starts from 1 to avoid extra inc at end of loop + for (index_t i = 1; i < length; ++i) { + inc(&coord, oshape, &lidx, lstride, &ridx, rstride); + // When tuning, don't actually run the op, since it's not going to be tuned against + // the actual op we'll eventually be using + KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx])); + } + } + + /*! \brief Map function for binary_broadcast_kernel */ + /* used for mixed type binary ops */ + template::value && + !std::is_pointer::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, + const Shape &lstride, const Shape &rstride, + const Shape &oshape, IType lhs, DType *rhs, + DType *out) { + Shape coord = unravel(base, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto ridx = static_cast(dot(coord, rstride)); + KERNEL_ASSIGN(out[base], req, OP::Map(lhs, rhs[ridx])); + // starts from 1 to avoid extra inc at end of loop + for (index_t i = 1; i < length; ++i) { + inc(&coord, oshape, &lidx, lstride, &ridx, rstride); + // When tuning, don't actually run the op, since it's not going to be tuned against + // the actual op we'll eventually be using + KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs, rhs[ridx])); + } + } +}; + +template +struct csr_dns_csr_broadcast_kernel { + /*! + * \brief Map function for broadcast between csr and 1D vector + * \param row global thread id/assigned row id + * \param csr_data ptr to data buffer of csr matrix + * \param csr_indices ptr to indices buffer of csr matrix + * \param csr_indptr ptr to indptr buffer of csr matrix + * \param dns ptr to data buffer of the dense vector + * \param out ptr to the data buffer of the result csr matrix + */ + template + MSHADOW_XINLINE static void Map(index_t row, const DType *csr_data, const CType *csr_indices, + const RType *csr_indptr, const DType *dns, DType *out) { + const nnvm::dim_t curr_row_i = csr_indptr[row]; + const nnvm::dim_t next_row_i = csr_indptr[row + 1]; + for (nnvm::dim_t iter = curr_row_i; iter < next_row_i; iter++) { + KERNEL_ASSIGN(out[iter], req, OP::Map(csr_data[iter], + (col_vec)? dns[row] : dns[csr_indices[iter]])); + } + } + + /*! 
+ * \brief Map function for broadcast between csr and a scalar + * \param i global thread id + * \param csr_data ptr to data buffer of csr matrix + * \param scalar_ptr ptr to data buffer of the scalar tensor, only the 0-th element is used + * \param out ptr to the data buffer of output csr matrix + * \param nnz number of non-zero elements in input csr matrix + */ + template + MSHADOW_XINLINE static void Map(index_t i, const DType *csr_data, const DType* scalar_ptr, + DType *out, const nnvm::dim_t nnz) { + const DType scale = scalar_ptr[0]; + if (i < nnz) { + KERNEL_ASSIGN(out[i], req, OP::Map(csr_data[i], scale)); + } + } +}; + +template +struct csr_dns_map_kernel { + template + MSHADOW_XINLINE static void Map(index_t row, const DType *csr_data, const CType *csr_indices, + const RType *csr_indptr, DType *out, const nnvm::dim_t num_rows, + const nnvm::dim_t num_cols) { + if (row < num_rows) { + const nnvm::dim_t curr_row_i = csr_indptr[row]; + const nnvm::dim_t next_row_i = csr_indptr[row + 1]; + for (nnvm::dim_t iter = curr_row_i; iter < next_row_i; iter++) { + const nnvm::dim_t target = row * num_cols + csr_indices[iter]; + KERNEL_ASSIGN(out[target], req, + reverse ? OP::Map(out[target], csr_data[iter]) : + OP::Map(csr_data[iter], out[target])); + } + } + } +}; + +} // namespace mxnet_op + namespace broadcast { using namespace mshadow; const int MAX_DIM = 5; -template -MSHADOW_XINLINE Shape calc_stride(const Shape& shape) { - Shape stride; - index_t cumprod = 1; - #pragma unroll - for (int i = ndim - 1; i >= 0; --i) { - stride[i] = (shape[i] > 1) ? cumprod : 0; - cumprod *= shape[i]; - } - return stride; -} - template MSHADOW_XINLINE void unravel_dot(const index_t idx, const Shape& shape, const Shape& stridej, const Shape& stridek, index_t* j, index_t* k) { @@ -67,28 +221,6 @@ MSHADOW_XINLINE void unravel_dot(const index_t idx, const Shape& shape, } } -template -MSHADOW_XINLINE Shape unravel(const index_t idx, const Shape& shape) { - Shape ret; - #pragma unroll - for (index_t i = ndim-1, j = idx; i >=0; --i) { - auto tmp = j / shape[i]; - ret[i] = j - tmp*shape[i]; - j = tmp; - } - return ret; -} - -template -MSHADOW_XINLINE index_t ravel(const Shape& coord, const Shape& shape) { - index_t ret = 0; - #pragma unroll - for (index_t i = 0; i < ndim; ++i) { - ret = ret * shape[i] + (shape[i] > 1) * coord[i]; - } - return ret; -} - template MSHADOW_XINLINE int diff(const Shape& small, const Shape& big, @@ -114,28 +246,6 @@ MSHADOW_XINLINE int diff(const Shape& small, return mdim; } -template -MSHADOW_XINLINE index_t unravel_dot(const index_t idx, const Shape& shape, - const Shape& stride) { - index_t ret = 0; - #pragma unroll - for (index_t i = ndim-1, j = idx; i >=0; --i) { - auto tmp = j / shape[i]; - ret += (j - tmp*shape[i])*stride[i]; - j = tmp; - } - return ret; -} - -template -MSHADOW_XINLINE index_t dot(const Shape& coord, const Shape& stride) { - index_t ret = 0; - #pragma unroll - for (int i = 0; i < ndim; ++i) - ret += coord[i] * stride[i]; - return ret; -} - template MSHADOW_XINLINE void assign(DType* dst, const bool addto, const DType src) { if (addto) { @@ -151,9 +261,9 @@ MSHADOW_XINLINE void binary_broadcast_assign(const index_t idx, const bool addto const DType* __restrict rhs, DType* out, const Shape& lshape, const Shape& rshape, const Shape& oshape) { - const Shape coord = unravel(idx, oshape); - const index_t j = ravel(coord, lshape); - const index_t k = ravel(coord, rshape); + const Shape coord = mxnet_op::unravel(idx, oshape); + const index_t j = mxnet_op::ravel(coord, 
lshape); + const index_t k = mxnet_op::ravel(coord, rshape); assign(&out[idx], addto, OP::Map(lhs[j], rhs[k])); } @@ -162,40 +272,44 @@ MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const const DType* __restrict big, OType *small, const Shape& bshape, const Shape& sshape, const Shape& rshape, const Shape& rstride) { - Shape coord = unravel(idx, sshape); - index_t j = ravel(coord, bshape); + Shape coord = mxnet_op::unravel(idx, sshape); + index_t j = mxnet_op::ravel(coord, bshape); AType val, residual; Reducer::SetInitValue(val, residual); for (size_t k = 0; k < M; ++k) { - coord = unravel(k, rshape); - Reducer::Reduce(val, AType(OP::Map(big[j + dot(coord, rstride)])), residual); + coord = mxnet_op::unravel(k, rshape); + Reducer::Reduce(val, AType(OP::Map(big[j + mxnet_op::dot(coord, rstride)])), residual); } Reducer::Finalize(val, residual); assign(&small[idx], addto, OType(val)); } -#ifdef __CUDACC__ -#include "broadcast_reduce-inl.cuh" - -#else +namespace { -template -void binary_broadcast_compute(const size_t N, const bool addto, const DType *lhs, - const DType *rhs, DType *out, const Shape lshape, - const Shape rshape, const Shape oshape) { - for (size_t idx = 0; idx < N; ++idx) { - binary_broadcast_assign(idx, addto, lhs, rhs, out, lshape, rshape, oshape); +// Returns the stride with which the fastest dimension is moving. +// Used to detect memory access scatter. +inline int fastest_stride(const TShape &small, const TShape &big, + const TShape &big_stride) { + const int ndim = small.ndim(); + for (int i = ndim-1; i >= 0; --i) { + if (big[i] != 1) { + return (small[i] == big[i]) ? 1 : big_stride[i]; + } } + return 1; } +} // namespace + template void BinaryBroadcastComputeImpl(Stream *s, const OpReqType req, const TBlob& lhs, const TBlob& rhs, const TBlob& out) { - if (req == kNullOp) return; - size_t N = out.shape_.Size(); - binary_broadcast_compute(N, req == kAddTo, lhs.dptr(), rhs.dptr(), - out.dptr(), lhs.shape_.get(), rhs.shape_.get(), - out.shape_.get()); + mshadow::Shape oshape = out.shape_.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(lhs.shape_.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(rhs.shape_.get()); + mxnet_op::Kernel, cpu>:: + template LaunchEx(s, out.shape_.Size(), req, lstride, rstride, oshape, + lhs.dptr(), rhs.dptr(), out.dptr()); } template @@ -220,8 +334,8 @@ void seq_reduce_compute_extra_mem(const size_t N, const size_t M, const bool add const index_t* ws_dptr) { #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (index_t idx = 0; idx < static_cast(N); ++idx) { - Shape coord = unravel(idx, sshape); - index_t j = ravel(coord, bshape); + Shape coord = mxnet_op::unravel(idx, sshape); + index_t j = mxnet_op::ravel(coord, bshape); DType val, residual; Reducer::SetInitValue(val, residual); for (size_t k = 0; k < M; ++k) { @@ -278,8 +392,8 @@ void ReduceWithExtraMem(Stream* s, const TBlob& small, const OpReqType req, size_t N = small.shape_.Size(), M = rshape.Size(); #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (index_t k = 0; k < static_cast(M); k++) { - Shape coord = unravel(k, rshape); - ws_dptr[k] = dot(coord, rstride); + Shape coord = mxnet_op::unravel(k, rshape); + ws_dptr[k] = mxnet_op::dot(coord, rstride); } seq_reduce_compute_extra_mem( @@ -287,19 +401,263 @@ void ReduceWithExtraMem(Stream* s, const TBlob& small, const OpReqType req, small.shape_.get(), rshape, rstride, ws_dptr); } -template -size_t 
ReduceWorkspaceSize(Stream *s, const mxnet::TShape& small, const OpReqType req, - const mxnet::TShape& big) { +inline size_t ReduceWorkspaceSize(Stream *s, const mxnet::TShape& small, const OpReqType req, + const mxnet::TShape& big, const int type_size) { return 0; } -template -size_t ReduceWorkspaceSize(Stream *s, const mxnet::TShape& small, const OpReqType req, - const mxnet::TShape& big, const mxnet::TShape& lhs, - const mxnet::TShape& rhs) { +inline size_t ReduceWorkspaceSize(Stream *s, const mxnet::TShape& small, const OpReqType req, + const mxnet::TShape& big, const mxnet::TShape& lhs, + const mxnet::TShape& rhs, const int type_size) { return 0; } +#if MXNET_USE_CUDA + +namespace { + +constexpr int warpSize = 32; +constexpr int unroll_reduce = 2; + +// Returns a/b integer division rounded up +template +Type ceil_idiv(const Type a, const Type b) { + return (a + b - 1)/b; +} + +uint64_t calc_num_load(const int X, const int Y, const int* strides) { + // Number of full warps + uint64_t num_full_warp = X / warpSize; + // Length of the partial warp i.e. number of threads that are performing loads + uint64_t len_part_warp = X % warpSize; + + uint64_t num_load_full = (std::min(warpSize, strides[0]) + + std::min(warpSize, strides[1]) + + std::min(warpSize, strides[2]))*num_full_warp; + + uint64_t num_load_part = + (std::min(len_part_warp, ceil_idiv(len_part_warp*strides[0], warpSize)) + + std::min(len_part_warp, ceil_idiv(len_part_warp*strides[1], warpSize)) + + std::min(len_part_warp, ceil_idiv(len_part_warp*strides[2], warpSize)))* + (len_part_warp != 0); + + uint64_t num_load = (num_load_full + num_load_part)*(uint64_t)Y; + return num_load; +} + +inline int diff(const TShape& small, const TShape& big, + TShape* dims, TShape* stride) { + int ndim = small.ndim(); + int mdim = 0; + #pragma unroll + for (int i = 0; i < ndim; ++i) { + mdim += small[i] != big[i]; + (*dims)[i] = (*stride)[i] = 1; + } + + index_t s = 1; + #pragma unroll + for (int i = ndim - 1, j = mdim; i >= 0; --i) { + if (small[i] != big[i]) { + --j; + (*stride)[j] = s; + (*dims)[j] = big[i]; + } + s *= big[i]; + } + return mdim; +} + +constexpr int nthread_reduce = 512; +constexpr index_t kBaseGridNum = 1024; + +} // namespace + +// Configuration for ReduceImpl() +struct ReduceImplConfig { + index_t N; + index_t M; + index_t Mnext; + struct { + dim3 blockDim; + dim3 gridDim; + int shMemSize; + bool do_transpose; + } kernel_1; + struct { + int blockSize; + int gridSize; + } kernel_2; + size_t workspace_size; + + TShape rshape, rstride; + TShape lhs_shape, lhs_stride; + TShape rhs_shape, rhs_stride; + + inline ReduceImplConfig(const ::mxnet::TShape& small, const ::mxnet::TShape& big, + const ::mxnet::TShape* lhs, + const ::mxnet::TShape* rhs, + const size_t type_size) : + rshape(small.ndim(), 1), rstride(small.ndim(), 1), + lhs_shape(small.ndim(), 1), lhs_stride(small.ndim(), 1), + rhs_shape(small.ndim(), 1), rhs_stride(small.ndim(), 1) { + constexpr int maxLoopPerTB = 64; + int ndim = small.ndim(); + + diff(small, big, &rshape, &rstride); + N = small.Size(); + + M = rshape[0]; + for (int i = 1; i < ndim; ++i) { + M *= rshape[i]; + } + + bool multiOp = false; + if (lhs != nullptr) { + CHECK_NOTNULL(rhs); + diff(small, *lhs, &lhs_shape, &lhs_stride); + diff(small, *rhs, &rhs_shape, &rhs_stride); + multiOp = true; + } + + workspace_size = 0; + kernel_1.shMemSize = 0; + kernel_1.do_transpose = false; + + if (M == 1) { + kernel_1.blockDim.x = nthread_reduce; + kernel_1.gridDim.x = std::min(kBaseGridNum, + static_cast((N + 
kernel_1.blockDim.x - 1)/kernel_1.blockDim.x)); + } else { + int reduce_strides[3]; + reduce_strides[0] = fastest_stride(small, big, big); + reduce_strides[1] = (multiOp) ? fastest_stride(small, *lhs, *lhs) : 1; + reduce_strides[2] = (multiOp) ? fastest_stride(small, *rhs, *rhs) : 1; + + int reduce_strides_transp[3]; + reduce_strides_transp[0] = fastest_stride(small, rshape, rstride); + reduce_strides_transp[1] = (multiOp) ? + fastest_stride(small, lhs_shape, lhs_stride) : 1; + reduce_strides_transp[2] = (multiOp) ? + fastest_stride(small, rhs_shape, rhs_stride) : 1; + + uint64_t num_load = calc_num_load(N, M, reduce_strides); + uint64_t num_load_transp = calc_num_load(M, N, reduce_strides_transp); + + Mnext = 1; + kernel_1.do_transpose = (num_load > num_load_transp); + + kernel_1.blockDim.x = 0; + kernel_1.blockDim.y = 0; + + if (kernel_1.do_transpose) { + // Fastest thread ID goes through M + // Loop over N has step size kernel_1.blockDim.y + if (N < 8) { + kernel_1.blockDim.y = 1; + } else if (N < 256) { + kernel_1.blockDim.y = 4; + } else { + if (M < 8) { + kernel_1.blockDim.x = 1; + } else if (M < 256) { + kernel_1.blockDim.x = 4; + } else { + kernel_1.blockDim.x = warpSize; + } + } + } else { + // Fastest thread ID goes through N + // Loop over M has step size kernel_1.blockDim.y + if (M < 8) { + kernel_1.blockDim.y = 1; + } else if (M < 256) { + kernel_1.blockDim.y = 4; + } else { + if (N < 8) { + kernel_1.blockDim.x = 1; + } else if (N < 256) { + kernel_1.blockDim.x = 4; + } else { + kernel_1.blockDim.x = warpSize; + } + } + } + + if (kernel_1.blockDim.x == 0 && kernel_1.blockDim.y == 0) { + LOG(FATAL) << "Unable to set blockDim"; + } else if (kernel_1.blockDim.x == 0) { + kernel_1.blockDim.x = nthread_reduce / kernel_1.blockDim.y; + } else if (kernel_1.blockDim.y == 0) { + kernel_1.blockDim.y = nthread_reduce / kernel_1.blockDim.x; + } + + if (kernel_1.do_transpose) { + // Fastest thread ID goes through M + kernel_1.gridDim.x = std::min((unsigned int)kBaseGridNum, + ceil_idiv(N, kernel_1.blockDim.y)); + kernel_1.gridDim.y = std::min(kBaseGridNum, Mnext); + int by = kernel_1.blockDim.y; + if (kernel_1.blockDim.y % warpSize == 0) { + // Fix shared memory bank conflict + by++; + } + kernel_1.shMemSize = (kernel_1.blockDim.x > 1) ? + kernel_1.blockDim.x*by*type_size * 2 : 0; + // Maximum number of times we want TB to loop in M + // Max size of M-block each TB can handle + int maxMblock = kernel_1.blockDim.x*maxLoopPerTB; + Mnext = (M + maxMblock - 1) / maxMblock; + } else { + // Fastest thread ID goes through N + kernel_1.gridDim.x = std::min((unsigned int)kBaseGridNum, + ceil_idiv(N, kernel_1.blockDim.x)); + kernel_1.gridDim.y = std::min(kBaseGridNum, Mnext); + kernel_1.shMemSize = (kernel_1.blockDim.y > 1) ? 
+ kernel_1.blockDim.x*kernel_1.blockDim.y*type_size * 2 : 0; + // Maximum number of times we want TB to loop in M + // Max size of M-block each TB can handle + int maxMblock = kernel_1.blockDim.y*maxLoopPerTB; + Mnext = (M + maxMblock - 1) / maxMblock; + } + + if (Mnext > 1) { + // small_dptr[] is N*Mnext*type_size bytes + workspace_size += N*Mnext*sizeof(double); + // Set gridDim.y to Mnext + kernel_1.gridDim.y = std::min(kBaseGridNum, Mnext); + } + + if (Mnext > 1) { + kernel_2.blockSize = nthread_reduce; + kernel_2.gridSize = std::min(kBaseGridNum, + static_cast((N + kernel_2.blockSize - 1)/kernel_2.blockSize)); + } + } + } +}; + +inline size_t ReduceWorkspaceSize(Stream *s, const ::mxnet::TShape& small, const OpReqType req, + const ::mxnet::TShape& big, const int type_size) { + if (req == kNullOp) return 0; + ReduceImplConfig config(small, big, nullptr, nullptr, type_size); + return config.workspace_size; +} + +inline size_t ReduceWorkspaceSize(Stream *s, const ::mxnet::TShape& small, const OpReqType req, + const ::mxnet::TShape& big, const ::mxnet::TShape& lhs, + const ::mxnet::TShape& rhs, const int type_size) { + if (req == kNullOp) return 0; + ReduceImplConfig config(small, big, &lhs, &rhs, type_size); + return config.workspace_size; +} + +#ifdef __CUDACC__ +#include "broadcast_reduce-inl.cuh" +#endif + +#endif // MXNET_USE_CUDA + template MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const bool addto, const DType* __restrict big, const DType* __restrict lhs, @@ -310,21 +668,21 @@ MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const const Shape& lhs_shape, const Shape& rhs_shape, const Shape& rstride, const Shape& lhs_stride, const Shape& rhs_stride) { - Shape coord = unravel(idx, small_shape); - const index_t idx_big0 = ravel(coord, big_shape); - const index_t idx_lhs0 = ravel(coord, lhs_shape0); - const index_t idx_rhs0 = ravel(coord, rhs_shape0); + Shape coord = mxnet_op::unravel(idx, small_shape); + const index_t idx_big0 = mxnet_op::ravel(coord, big_shape); + const index_t idx_lhs0 = mxnet_op::ravel(coord, lhs_shape0); + const index_t idx_rhs0 = mxnet_op::ravel(coord, rhs_shape0); DType val, residual; Reducer::SetInitValue(val, residual); for (size_t k = 0; k < M; ++k) { - Shape coord_big = unravel(k, rshape); - index_t idx_big = idx_big0 + dot(coord_big, rstride); + Shape coord_big = mxnet_op::unravel(k, rshape); + index_t idx_big = idx_big0 + mxnet_op::dot(coord_big, rstride); - Shape coord_lhs = unravel(k, lhs_shape); - index_t idx_lhs = idx_lhs0 + dot(coord_lhs, lhs_stride); + Shape coord_lhs = mxnet_op::unravel(k, lhs_shape); + index_t idx_lhs = idx_lhs0 + mxnet_op::dot(coord_lhs, lhs_stride); - Shape coord_rhs = unravel(k, rhs_shape); - index_t idx_rhs = idx_rhs0 + dot(coord_rhs, rhs_stride); + Shape coord_rhs = mxnet_op::unravel(k, rhs_shape); + index_t idx_rhs = idx_rhs0 + mxnet_op::dot(coord_rhs, rhs_stride); Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual); } @@ -374,7 +732,31 @@ void Reduce(Stream *s, const TBlob& small, const OpReqType req, lhs.shape_.get(), rhs.shape_.get()); } +#if MXNET_USE_CUDA + +void RTCReduce(const OpContext& ctx, + const TBlob& small, + const OpReqType req, + const Tensor& workspace, + const TBlob& big, + const std::string& reducer, + int ndim, + const std::string& OP); + +void RTCReduce(const OpContext& ctx, + const TBlob& small, + const OpReqType req, + const Tensor& workspace, + const TBlob& big, + const TBlob &lhs, + const TBlob &rhs, + 
const std::string& reducer, + int ndim, + const std::string& OP1, + const std::string& OP2); + #endif + } // namespace broadcast } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 371fcee9a47f..ea9aa6a7cf65 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -634,8 +634,8 @@ void ReduceAxesComputeImpl(const OpContext& ctx, const TBlob in_data = inputs[0].reshape(src_shape); const TBlob out_data = outputs[0].reshape(dst_shape); BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - size_t workspace_size = broadcast::ReduceWorkspaceSize( - s, out_data.shape_, req[0], in_data.shape_); + size_t workspace_size = broadcast::ReduceWorkspaceSize( + s, out_data.shape_, req[0], in_data.shape_, sizeof(OType)); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); broadcast::Reduce( @@ -667,8 +667,8 @@ void ReduceAxesComputeBoolImpl(const OpContext& ctx, const TBlob in_data = inputs[0].reshape(src_shape); const TBlob out_data = outputs[0].reshape(dst_shape); BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - size_t workspace_size = broadcast::ReduceWorkspaceSize( - s, out_data.shape_, req[0], in_data.shape_); + size_t workspace_size = broadcast::ReduceWorkspaceSize( + s, out_data.shape_, req[0], in_data.shape_, sizeof(OType)); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); broadcast::ReduceBool( @@ -1633,7 +1633,7 @@ struct pick { const IType *idx, index_t M, int stride, mshadow::Shape bshape, mshadow::Shape sshape) { - using namespace broadcast; + using namespace mxnet_op; index_t j = static_cast(idx[i]); if (clip) { if (j <= 0) j = 0; @@ -1655,7 +1655,7 @@ struct pick_grad { const IType *idx, index_t M, int stride, mshadow::Shape bshape, mshadow::Shape sshape) { - using namespace broadcast; + using namespace mxnet_op; index_t j = static_cast(idx[i]); if (clip) { if (j <= 0) j = 0; diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h index c42590fd0848..e8fec3081c0a 100644 --- a/src/operator/tensor/cast_storage-inl.h +++ b/src/operator/tensor/cast_storage-inl.h @@ -30,7 +30,7 @@ #include #include "../mxnet_op.h" #include "../operator_common.h" -#include "../../src/operator/tensor/init_op.h" +#include "./init_op.h" #ifdef __CUDACC__ #include "./cast_storage-inl.cuh" #endif // __CUDACC__ diff --git a/src/operator/tensor/elemwise_binary_broadcast_op-inl.cuh b/src/operator/tensor/elemwise_binary_broadcast_op-inl.cuh deleted file mode 100644 index d65e12aef86e..000000000000 --- a/src/operator/tensor/elemwise_binary_broadcast_op-inl.cuh +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -/*! - * \file elemwise_binary_broadcast_op-inl.cuh - * \brief CUDA specific Function definition of elementwise binary broadcast operators - */ -#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_CUH_ -#define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_CUH_ -#include -#include -#include -#include -#include -#include -#include "broadcast_reduce-inl.h" -namespace mxnet { -namespace op { -template -inline typename std::enable_if::value, void>::type -BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace broadcast; - mxnet::TShape new_lshape, new_rshape, new_oshape; - int ndim = BinaryBroadcastShapeCompact(outputs[0].shape_, outputs[1].shape_, inputs[0].shape_, - &new_lshape, &new_rshape, &new_oshape); - if (!ndim) { - ElemwiseBinaryOp::BackwardUseNone(attrs, ctx, inputs, req, outputs); - } else { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - Stream *s = ctx.get_stream(); - const TBlob lhs = outputs[0].reshape(new_lshape); - const TBlob rhs = outputs[1].reshape(new_rshape); - const TBlob out = inputs[0].reshape(new_oshape); - BROADCAST_NDIM_SWITCH(ndim, NDim, { - // Request temporary storage - size_t workspace_size = new_oshape.Size(); - Tensor workspace = - ctx.requested[0].get_space_typed( - Shape1(workspace_size * sizeof(index_t)), s); - if (out.shape_.Size() != 0) { - Reduce(s, lhs, req[0], workspace, out); - Reduce(s, rhs, req[1], workspace, out); - } else { - using namespace mxnet_op; - if (lhs.shape_.Size() != 0) { - MSHADOW_TYPE_SWITCH(lhs.type_flag_, LType, { - Kernel::Launch(s, lhs.shape_.Size(), lhs.dptr()); - }); - } - if (rhs.shape_.Size() != 0) { - MSHADOW_TYPE_SWITCH(rhs.type_flag_, RType, { - Kernel::Launch(s, rhs.shape_.Size(), rhs.dptr()); - }); - } - } - }); - }); - } -} -} // namespace op -} // namespace mxnet -#endif diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.cc b/src/operator/tensor/elemwise_binary_broadcast_op.cc new file mode 100644 index 000000000000..2f9832a173f6 --- /dev/null +++ b/src/operator/tensor/elemwise_binary_broadcast_op.cc @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include + +#if MXNET_USE_CUDA +#include +#endif // MXNET_USE_CUDA + +#include "broadcast_reduce-inl.h" +#include "elemwise_binary_broadcast_op.h" + +#if MXNET_USE_CUDA +#include "../../common/cuda/rtc/vectorization-inl.h" +#include "../../common/cuda/rtc.h" +#endif // MXNET_USE_CUDA + +namespace mxnet { +namespace op { + +#if MXNET_USE_CUDA + +struct binary_broadcast_params { + const void* inputs[2]; + void* outputs[1]; + index_t stride[2][broadcast::MAX_DIM]; + index_t oshape[broadcast::MAX_DIM]; + index_t size[2]; +}; + +const char broadcast_kernel_fwd[] = R"code( +struct binary_broadcast_params { + const void* inputs[2]; + void* outputs[1]; + index_t stride[2][util::MAX_DIM]; + index_t oshape[util::MAX_DIM]; + index_t size[2]; +}; + +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void binary_broadcast_kernel( + const binary_broadcast_params param, + const index_t lead_dim, + const index_t other_dim, + const index_t N, + const index_t num_aligned_elements) { + using namespace vector; + const index_t M = num_aligned_elements * other_dim; + + VectorizedLoader lloader( + reinterpret_cast(param.inputs[0]), param.size[0]); + VectorizedLoader rloader( + reinterpret_cast(param.inputs[1]), param.size[1]); + + using IType0 = AccType; + using IType1 = AccType; + using OType = AccType; + + + for (index_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < M; + idx += gridDim.x * blockDim.x) { + OutputType0 * current_output_pointer; + index_t output_size; + index_t output_idx; + if (aligned) { + // Simplified case + index_t lindex, rindex; + util::unravel_dot(idx * nvec, param.oshape, + param.stride[0], param.stride[1], + &lindex, &rindex); + lloader.load(lindex / nvec, param.size[0]); + rloader.load(rindex / nvec, param.size[1]); + current_output_pointer = reinterpret_cast(param.outputs[0]); + output_size = N; + output_idx = idx; + } else { + const index_t row = idx / num_aligned_elements; + const index_t lead_dim_idx = idx - row * num_aligned_elements; + + index_t lindex, rindex; + const index_t original_idx = max(lead_dim_idx * nvec - lloader.alignment(), + static_cast(0)) + + row * lead_dim; + util::unravel_dot(original_idx, param.oshape, + param.stride[0], param.stride[1], + &lindex, &rindex); + lloader.load((lindex + lloader.alignment()) / nvec, param.size[0]); + rloader.load((rindex + lloader.alignment()) / nvec, param.size[1]); + current_output_pointer = reinterpret_cast(param.outputs[0]) + row * lead_dim; + output_size = lead_dim; + output_idx = lead_dim_idx; + } + VectorizedStorer storer(current_output_pointer, output_size); + + if (req == OpReqType::kAddTo) { + storer.load(output_idx, output_size); + } +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const auto temp = OP(IType0::from(lloader.separate()[i]), + IType1::from(rloader.separate()[i])); + + if (req == OpReqType::kAddTo) { + const auto temp2 = op::add(temp, OType::from(storer.separate()[i])); + storer.separate()[i] = OType::to(temp2); + } else { + storer.separate()[i] = OType::to(temp); + } + } + storer.store(output_idx, output_size); + } +} +)code"; + +const char single_side_broadcast_kernel_fwd[] = R"code( +struct binary_broadcast_params { + const void* inputs[2]; + void* outputs[1]; + index_t stride[2][util::MAX_DIM]; + index_t oshape[util::MAX_DIM]; + index_t size[2]; +}; + +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void single_side_binary_broadcast_kernel( + const binary_broadcast_params param, + const index_t lead_dim, + const index_t other_dim, + const index_t N, + const index_t 
num_aligned_elements) { + using namespace vector; + const index_t M = num_aligned_elements * other_dim; + constexpr int other_side = 1 - side; + + VectorizedLoader lloader( + reinterpret_cast(param.inputs[side]), param.size[side]); + + using IType = AccType; + using IType2 = AccType; + using OType = AccType; + + + for (index_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < M; + idx += gridDim.x * blockDim.x) { + index_t original_idx; + OutputType0 * current_output_pointer; + index_t output_size; + index_t output_idx; + if (aligned) { + // Simplified case + original_idx = idx * nvec; + const index_t lindex = util::unravel_dot(original_idx, param.oshape, + param.stride[side]); + lloader.load(lindex / nvec, param.size[side]); + current_output_pointer = reinterpret_cast(param.outputs[0]); + output_size = N; + output_idx = idx; + } else { + const index_t row = idx / num_aligned_elements; + const index_t lead_dim_idx = idx - row * num_aligned_elements; + original_idx = lead_dim_idx * nvec - + lloader.alignment() + row * lead_dim; + const index_t original_idx_clamped = max(lead_dim_idx * nvec - lloader.alignment(), + static_cast(0)) + + row * lead_dim; + const index_t lindex = util::unravel_dot(original_idx_clamped, param.oshape, + param.stride[side]); + lloader.load((lindex + lloader.alignment()) / nvec, param.size[side]); + current_output_pointer = reinterpret_cast(param.outputs[0]) + row * lead_dim; + output_size = lead_dim; + output_idx = lead_dim_idx; + } + VectorizedStorer storer(current_output_pointer, output_size); + + if (req == OpReqType::kAddTo) { + storer.load(output_idx, output_size); + } +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const index_t rindex = min(max(util::unravel_dot(original_idx + i, + param.oshape, + param.stride[other_side]), + static_cast(0)), + param.size[other_side] - 1); + const auto rinput = IType2::from( + reinterpret_cast(param.inputs[other_side]) + [rindex]); + + typename OType::type temp; + if (side == 0) { + // Left side is vectorized + temp = OP(IType::from(lloader.separate()[i]), + rinput); + } else { + // Right side is vectorized + temp = OP(rinput, + IType::from(lloader.separate()[i])); + } + + if (req == OpReqType::kAddTo) { + const auto temp2 = op::add(temp, OType::from(storer.separate()[i])); + storer.separate()[i] = OType::to(temp2); + } else { + storer.separate()[i] = OType::to(temp); + } + } + storer.store(output_idx, output_size); + } +} +)code"; +namespace { + +std::vector calc_stride(const mxnet::TShape& shape, int ndim) { + CHECK_EQ(ndim, shape.ndim()); + std::vector stride(ndim); + index_t cumprod = 1; + for (int i = shape.ndim() - 1; i >= 0; --i) { + stride[i] = (shape[i] > 1) ? 
cumprod : 0; + cumprod *= shape[i]; + } + return stride; +} + +} // namespace + +void BinaryBroadcastRTCCompute::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet::common::cuda::rtc; + if (outputs[0].shape_.Size() == 0U) return; + if (req[0] == kNullOp) return; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + mxnet::TShape new_lshape, new_rshape, new_oshape; + int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, + &new_lshape, &new_rshape, &new_oshape); + // Pad the ndim + BROADCAST_NDIM_SWITCH(ndim, NDim, { + if (ndim != 0) { + ndim = NDim; + } + }); + + if (!ndim) { + ElemwiseBinaryRTCCompute {OP}(attrs, ctx, inputs, req, outputs); + } else { + mshadow::Stream *s = ctx.get_stream(); + const TBlob& lhs = inputs[0].reshape(new_lshape); + const TBlob& rhs = inputs[1].reshape(new_rshape); + const TBlob& output = outputs[0].reshape(new_oshape); + + const auto& lstride = calc_stride(lhs.shape_, ndim); + const auto& rstride = calc_stride(rhs.shape_, ndim); + + size_t output_type_size = common::mshadow_type_info(outputs[0].type_flag_).size; + const int nvec = output_type_size <= sizeof(uint64_t) + ? (sizeof(uint64_t) / output_type_size) + : 1; + binary_broadcast_params params{}; + params.inputs[0] = lhs.dptr_; + params.inputs[1] = rhs.dptr_; + params.outputs[0] = output.dptr_; + for (int i = 0; i < ndim; ++i) { + params.stride[0][i] = lstride[i]; + params.stride[1][i] = rstride[i]; + params.oshape[i] = new_oshape[i]; + } + params.size[0] = lhs.shape_.Size(); + params.size[1] = rhs.shape_.Size(); + + index_t lead_dim = 1; + for (int i = ndim - 1; i >= 0; --i) { + /* Find the first non-1 dimension + to check the alignment + */ + if (params.oshape[i] != 1) { + lead_dim = params.oshape[i]; + break; + } + } + const index_t other_dim = output.shape_.Size() / lead_dim; + + int first_different = -1; + int common_shape = 1; + for (int i = ndim - 1; i >= 0; --i) { + if (params.stride[0][i] == params.stride[1][i]) { + common_shape *= params.oshape[i]; + } else { + first_different = i; + break; + } + } + + int lead_input_num = 0; + std::string code = std::string("const OpReqType req = ") + + util::to_string(req[0]) + + ";\n" + "#define OP op::" + + OP + + "\n" + "const int ndim = " + + std::to_string(ndim) + + ";\n"; + if (common_shape != 1) { + VectorizedKernelRTCLauncher(code, "binary_broadcast_kernel", + broadcast_kernel_fwd, nvec, + lead_dim, other_dim, s, params, + inputs, outputs, + ctx.run_ctx.get_ctx().dev_id, + lead_input_num); + } else { + if (params.stride[0][first_different] == 0) { + lead_input_num = 1; + code += "const int side = 1;\n" + "using DType = InputType1;\n" + "using DType2 = InputType0;\n"; + } else { + code += "const int side = 0;\n" + "using DType = InputType0;\n" + "using DType2 = InputType1;\n"; + } + VectorizedKernelRTCLauncher(code, "single_side_binary_broadcast_kernel", + single_side_broadcast_kernel_fwd, nvec, + lead_dim, other_dim, s, params, + inputs, outputs, + ctx.run_ctx.get_ctx().dev_id, + lead_input_num); + } + } +} + +void BinaryBroadcastRTCBackwardUseNone::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 2U); + mxnet::TShape new_lshape, new_rshape, new_oshape; + int ndim = 
BinaryBroadcastShapeCompact(outputs[0].shape_, outputs[1].shape_, inputs[0].shape_, + &new_lshape, &new_rshape, &new_oshape); + if (!ndim) { + ElemwiseBinaryRTCBwdUseNone {LOP, ROP}(attrs, ctx, inputs, req, outputs); + } else { + Stream *s = ctx.get_stream(); + const TBlob lhs = outputs[0].reshape(new_lshape); + const TBlob rhs = outputs[1].reshape(new_rshape); + const TBlob out = inputs[0].reshape(new_oshape); + BROADCAST_NDIM_SWITCH(ndim, NDim, { + // Request temporary storage + size_t workspace_size = new_oshape.Size(); + Tensor workspace = + ctx.requested[0].get_space_typed( + Shape1(workspace_size * sizeof(index_t)), s); + if (out.shape_.Size() != 0) { + broadcast::RTCReduce(ctx, lhs, req[0], + workspace, out, + "red::sum", NDim, LOP); + broadcast::RTCReduce(ctx, rhs, req[1], + workspace, out, + "red::sum", NDim, ROP); + } else { + using namespace common::cuda::rtc::util; + if (lhs.shape_.Size() != 0) { + cudaMemsetAsync(lhs.dptr_, 0, + lhs.shape_.Size() * common::mshadow_type_info(lhs.type_flag_).size, + Stream::GetStream(s)); + } + if (rhs.shape_.Size() != 0) { + cudaMemsetAsync(rhs.dptr_, 0, + rhs.shape_.Size() * common::mshadow_type_info(rhs.type_flag_).size, + Stream::GetStream(s)); + } + } + }); + } +} + +void BinaryBroadcastRTCBackwardUseIn::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + // skip kernel launch for zero-size tensors + if (inputs[0].shape_.Size() == 0U) { + return; + } + mxnet::TShape new_lshape, new_rshape, new_oshape; + const bool need_bc = BinaryBroadcastShapeCompact(outputs[0].shape_, + outputs[1].shape_, inputs[0].shape_, + &new_lshape, &new_rshape, &new_oshape) != 0; + if (!need_bc) { + ElemwiseBinaryRTCBwdUseIn {LOP, ROP}(attrs, ctx, inputs, req, outputs); + } else { + BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, { + using namespace mshadow; + Stream *s = ctx.get_stream(); + const TBlob lgrad = outputs[0].reshape(new_lshape); + const TBlob rgrad = outputs[1].reshape(new_rshape); + const TBlob ograd = inputs[0].reshape(new_oshape); + const TBlob lhs = inputs[1].reshape(new_lshape); + const TBlob rhs = inputs[2].reshape(new_rshape); + size_t workspace_size_l = broadcast::ReduceWorkspaceSize( + s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, + rhs.shape_, common::mshadow_type_info(outputs[0].type_flag_).size); + size_t workspace_size_r = broadcast::ReduceWorkspaceSize( + s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, + rhs.shape_, common::mshadow_type_info(outputs[1].type_flag_).size); + size_t workspace_size = std::max(workspace_size_l, workspace_size_r); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + if (req[0] != kNullOp) { + broadcast::RTCReduce(ctx, lgrad, req[0], workspace, + ograd, lhs, rhs, "red::sum", NDim, + "mul", LOP); + } + if (req[1] != kNullOp) { + broadcast::RTCReduce(ctx, rgrad, req[1], workspace, + ograd, lhs, rhs, "red::sum", NDim, + "mul", ROP); + } + }); + } +} + +#endif // MXNET_USE_CUDA + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index ca83bdb01e37..e3ba92ddd0ff 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -183,174 +183,10 @@ inline int BinaryBroadcastShapeCompact(const mxnet::TShape& lshape, const mxnet: } else { LOG(FATAL) << "Too 
many broadcast dimensions with operands " << lshape << " " << rshape; } + return j; } -namespace mxnet_op { -template -struct binary_broadcast_kernel { - /*! \brief Map function for binary_broadcast_kernel */ - template - MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, - const Shape &lstride, const Shape &rstride, - const Shape &oshape, IType *lhs, IType *rhs, - DType *out) { - Shape coord = unravel(base, oshape); - auto lidx = static_cast(dot(coord, lstride)); - auto ridx = static_cast(dot(coord, rstride)); - KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx])); - // starts from 1 to avoid extra inc at end of loop - for (index_t i = 1; i < length; ++i) { - inc(&coord, oshape, &lidx, lstride, &ridx, rstride); - // When tuning, don't actually run the op, since it's not going to be tuned against - // the actual op we'll eventually be using - KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx])); - } - } - - /*! \brief Map function for binary_broadcast_kernel */ - template - MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, - const Shape &lstride, const Shape &rstride, - const Shape &oshape, LType *lhs, RType *rhs, - OType *out) { - Shape coord = unravel(base, oshape); - auto lidx = static_cast(dot(coord, lstride)); - auto ridx = static_cast(dot(coord, rstride)); - KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx])); - // starts from 1 to avoid extra inc at end of loop - for (index_t i = 1; i < length; ++i) { - inc(&coord, oshape, &lidx, lstride, &ridx, rstride); - // When tuning, don't actually run the op, since it's not going to be tuned against - // the actual op we'll eventually be using - KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx])); - } - } - - /*! \brief Map function for binary_broadcast_kernel */ - template - MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, - const Shape &lstride, const Shape &rstride, - const Shape &oshape, IType lhs, IType *rhs, - DType *out) { - Shape coord = unravel(base, oshape); - auto lidx = static_cast(dot(coord, lstride)); - auto ridx = static_cast(dot(coord, rstride)); - KERNEL_ASSIGN(out[base], req, OP::Map(lhs, rhs[ridx])); - // starts from 1 to avoid extra inc at end of loop - for (index_t i = 1; i < length; ++i) { - inc(&coord, oshape, &lidx, lstride, &ridx, rstride); - // When tuning, don't actually run the op, since it's not going to be tuned against - // the actual op we'll eventually be using - KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs, rhs[ridx])); - } - } - - /*! \brief Map function for binary_broadcast_kernel */ - /* used for mixed type binary ops */ - template::value, int>::type = 0> - MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, - const Shape &lstride, const Shape &rstride, - const Shape &oshape, IType *lhs, DType *rhs, - DType *out) { - Shape coord = unravel(base, oshape); - auto lidx = static_cast(dot(coord, lstride)); - auto ridx = static_cast(dot(coord, rstride)); - KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx])); - // starts from 1 to avoid extra inc at end of loop - for (index_t i = 1; i < length; ++i) { - inc(&coord, oshape, &lidx, lstride, &ridx, rstride); - // When tuning, don't actually run the op, since it's not going to be tuned against - // the actual op we'll eventually be using - KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx])); - } - } - - /*! 
\brief Map function for binary_broadcast_kernel */ - /* used for mixed type binary ops */ - template::value && - !std::is_pointer::value, int>::type = 0> - MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, - const Shape &lstride, const Shape &rstride, - const Shape &oshape, IType lhs, DType *rhs, - DType *out) { - Shape coord = unravel(base, oshape); - auto lidx = static_cast(dot(coord, lstride)); - auto ridx = static_cast(dot(coord, rstride)); - KERNEL_ASSIGN(out[base], req, OP::Map(lhs, rhs[ridx])); - // starts from 1 to avoid extra inc at end of loop - for (index_t i = 1; i < length; ++i) { - inc(&coord, oshape, &lidx, lstride, &ridx, rstride); - // When tuning, don't actually run the op, since it's not going to be tuned against - // the actual op we'll eventually be using - KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs, rhs[ridx])); - } - } -}; - -template -struct csr_dns_csr_broadcast_kernel { - /*! - * \brief Map function for broadcast between csr and 1D vector - * \param row global thread id/assigned row id - * \param csr_data ptr to data buffer of csr matrix - * \param csr_indices ptr to indices buffer of csr matrix - * \param csr_indptr ptr to indptr buffer of csr matrix - * \param dns ptr to data buffer of the dense vector - * \param out ptr to the data buffer of the result csr matrix - */ - template - MSHADOW_XINLINE static void Map(index_t row, const DType *csr_data, const CType *csr_indices, - const RType *csr_indptr, const DType *dns, DType *out) { - const nnvm::dim_t curr_row_i = csr_indptr[row]; - const nnvm::dim_t next_row_i = csr_indptr[row + 1]; - for (nnvm::dim_t iter = curr_row_i; iter < next_row_i; iter++) { - KERNEL_ASSIGN(out[iter], req, OP::Map(csr_data[iter], - (col_vec)? dns[row] : dns[csr_indices[iter]])); - } - } - - /*! - * \brief Map function for broadcast between csr and a scalar - * \param i global thread id - * \param csr_data ptr to data buffer of csr matrix - * \param scalar_ptr ptr to data buffer of the scalar tensor, only the 0-th element is used - * \param out ptr to the data buffer of output csr matrix - * \param nnz number of non-zero elements in input csr matrix - */ - template - MSHADOW_XINLINE static void Map(index_t i, const DType *csr_data, const DType* scalar_ptr, - DType *out, const nnvm::dim_t nnz) { - const DType scale = scalar_ptr[0]; - if (i < nnz) { - KERNEL_ASSIGN(out[i], req, OP::Map(csr_data[i], scale)); - } - } -}; - -template -struct csr_dns_map_kernel { - template - MSHADOW_XINLINE static void Map(index_t row, const DType *csr_data, const CType *csr_indices, - const RType *csr_indptr, DType *out, const nnvm::dim_t num_rows, - const nnvm::dim_t num_cols) { - if (row < num_rows) { - const nnvm::dim_t curr_row_i = csr_indptr[row]; - const nnvm::dim_t next_row_i = csr_indptr[row + 1]; - for (nnvm::dim_t iter = curr_row_i; iter < next_row_i; iter++) { - const nnvm::dim_t target = row * num_cols + csr_indices[iter]; - KERNEL_ASSIGN(out[target], req, - reverse ? 
OP::Map(out[target], csr_data[iter]) : - OP::Map(csr_data[iter], out[target])); - } - } - } -}; - -} // namespace mxnet_op - template void BinaryBroadcastIntCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -402,17 +238,51 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, } MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); - mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); + broadcast::BinaryBroadcastComputeImpl(s, req[0], + inputs[0].reshape(new_lshape), + inputs[1].reshape(new_rshape), + outputs[0].reshape(new_oshape)); }); }); } } +#if MXNET_USE_CUDA + +struct BinaryBroadcastRTCCompute { + std::string OP; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +struct BinaryBroadcastRTCBackwardUseNone { + std::string LOP; + std::string ROP; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +struct BinaryBroadcastRTCBackwardUseIn { + std::string LOP; + std::string ROP; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +#endif // MXNET_USE_CUDA + template void BinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -715,14 +585,6 @@ BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs, } } -template -inline typename std::enable_if::value, void>::type -BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs); - template void BinaryBroadcastBackwardUseInImplWithWorkspace(const OpContext& ctx, const std::vector& inputs, @@ -766,10 +628,10 @@ inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx, const TBlob ograd = inputs[0].reshape(new_oshape); const TBlob lhs = inputs[1].reshape(new_lshape); const TBlob rhs = inputs[2].reshape(new_rshape); - size_t workspace_size_l = ReduceWorkspaceSize( - s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_); - size_t workspace_size_r = ReduceWorkspaceSize( - s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_); + size_t workspace_size_l = ReduceWorkspaceSize( + s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); + size_t workspace_size_r = ReduceWorkspaceSize( + s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType)); size_t workspace_size = std::max(workspace_size_l, workspace_size_r); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); @@ -824,7 +686,4 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs, } // namespace op } // namespace mxnet -#ifdef __CUDACC__ -#include "./elemwise_binary_broadcast_op-inl.cuh" -#endif #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_H_ diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu index aa0850ac5bbf..adc1dbb12cb9 100644 --- 
a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu +++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu @@ -29,43 +29,38 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(broadcast_add) -.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FCompute", BinaryBroadcastRTCCompute{"add"}) .set_attr("FComputeEx", BinaryBroadcastComputeDenseEx); NNVM_REGISTER_OP(_backward_broadcast_add) -.set_attr("FCompute", BinaryBroadcastBackwardUseNone); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseNone{"identity", "identity"}); NNVM_REGISTER_OP(broadcast_sub) -.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FCompute", BinaryBroadcastRTCCompute{"sub"}) .set_attr("FComputeEx", BinaryBroadcastComputeDenseEx); NNVM_REGISTER_OP(_backward_broadcast_sub) -.set_attr("FCompute", BinaryBroadcastBackwardUseNone); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseNone{"identity", "negation"}); NNVM_REGISTER_OP(broadcast_mul) -.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FCompute", BinaryBroadcastRTCCompute{"mul"}) .set_attr("FComputeEx", BinaryBroadcastComputeSparseEx); NNVM_REGISTER_OP(_backward_broadcast_mul) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"right", "left"}); NNVM_REGISTER_OP(broadcast_div) -.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FCompute", BinaryBroadcastRTCCompute{"div"}) .set_attr("FComputeEx", BinaryBroadcastComputeSparseEx); NNVM_REGISTER_OP(_backward_broadcast_div) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"div_grad", "div_rgrad"}); NNVM_REGISTER_OP(broadcast_mod) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"mod"}); NNVM_REGISTER_OP(_backward_broadcast_mod) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"mod_grad", "mod_rgrad"}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_extended.cu b/src/operator/tensor/elemwise_binary_broadcast_op_extended.cu index e8e79f726b65..042a4da2b688 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op_extended.cu +++ b/src/operator/tensor/elemwise_binary_broadcast_op_extended.cu @@ -29,32 +29,29 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(broadcast_power) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"power"}); NNVM_REGISTER_OP(_backward_broadcast_power) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"power_grad", "power_rgrad"}); NNVM_REGISTER_OP(broadcast_maximum) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"max"}); NNVM_REGISTER_OP(_backward_broadcast_maximum) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"greater_equal", "less"}); NNVM_REGISTER_OP(broadcast_minimum) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"min"}); NNVM_REGISTER_OP(_backward_broadcast_minimum) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"less_equal", "greater"}); NNVM_REGISTER_OP(broadcast_hypot) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"hypot"}); NNVM_REGISTER_OP(_backward_broadcast_hypot) 
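The registration convention used across these .cu files is worth spelling out: the forward RTC functor takes the name of an elementwise op from the RTC op:: namespace, while BinaryBroadcastRTCBackwardUseIn takes the pair of partial-derivative ops that the generated reduction multiplies with the incoming gradient and sums over the broadcast axes (left gradient first, then right), as in the RTCReduce calls above. A minimal sketch of wiring up a hypothetical operator this way — broadcast_foo and its gradient op names are placeholders, not part of this patch, and the full set_attr spelling is assumed from the usual NNVM form:

NNVM_REGISTER_OP(broadcast_foo)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastRTCCompute{"foo"});

NNVM_REGISTER_OP(_backward_broadcast_foo)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastRTCBackwardUseIn{"foo_grad_left",
                                                                     "foo_grad_right"});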
-.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", BinaryBroadcastRTCBackwardUseIn{"hypot_grad_left", + "hypot_grad_right"}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cu b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cu index 4bec07b7096a..bd2f50a23566 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cu +++ b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cu @@ -30,31 +30,31 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(broadcast_equal) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"equal"}); NNVM_REGISTER_OP(broadcast_not_equal) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"not_equal"}); NNVM_REGISTER_OP(broadcast_greater) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"greater"}); NNVM_REGISTER_OP(broadcast_greater_equal) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"greater_equal"}); NNVM_REGISTER_OP(broadcast_lesser) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"less"}); NNVM_REGISTER_OP(broadcast_lesser_equal) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"less_equal"}); NNVM_REGISTER_OP(broadcast_logical_and) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"logical_and"}); NNVM_REGISTER_OP(broadcast_logical_or) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"logical_or"}); NNVM_REGISTER_OP(broadcast_logical_xor) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr("FCompute", BinaryBroadcastRTCCompute{"logical_xor"}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_op.cc b/src/operator/tensor/elemwise_binary_op.cc index ea2466259494..c4fb5e1f5a39 100644 --- a/src/operator/tensor/elemwise_binary_op.cc +++ b/src/operator/tensor/elemwise_binary_op.cc @@ -25,6 +25,10 @@ #include "./elemwise_binary_op.h" +#if MXNET_USE_CUDA +#include "../../common/cuda/rtc/vectorization-inl.h" +#include "../../common/cuda/rtc.h" +#endif // MXNET_USE_CUDA namespace mxnet { namespace op { @@ -70,11 +74,6 @@ bool ElemwiseBinaryOp::BackwardUseInStorageType(const nnvm::NodeAttrs& attrs, const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask; const auto dispatch_ex = invalid_ctx ? 
DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx; - const int ograd_stype = in_attrs->at(0); - const int lhs_stype = in_attrs->at(1); - const int rhs_stype = in_attrs->at(2); - int& lhs_grad_stype = out_attrs->at(0); - int& rhs_grad_stype = out_attrs->at(1); if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) { dispatched = storage_type_assign(out_attrs, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); @@ -92,5 +91,347 @@ bool ElemwiseBinaryOp::BackwardUseInStorageType(const nnvm::NodeAttrs& attrs, return dispatched; } +#if MXNET_USE_CUDA + +struct binary_kernel_params { + const void *inputs[3]; + void *outputs[2]; +}; + +const char binary_kernel_fwd[] = R"code( + +struct binary_kernel_params { + const void *inputs[3]; + void *outputs[2]; +}; + +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void binary_kernel(const binary_kernel_params params, + const index_t lead_dim, + const index_t other_dim, + const index_t N, + const index_t num_aligned_elements) { + using namespace vector; + VectorizedLoader loader0( + reinterpret_cast(params.inputs[0]), N); + VectorizedLoader loader1( + reinterpret_cast(params.inputs[1]), N); + VectorizedStorer storer( + reinterpret_cast(params.outputs[0]), N); + + using IType0 = AccType; + using IType1 = AccType; + using OType = AccType; + + const index_t M = num_aligned_elements; + + for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + tid < M; + tid += gridDim.x * blockDim.x) { + loader0.load(tid, N); + loader1.load(tid, N); + if (req == OpReqType::kAddTo) { + storer.load(tid, N); + } +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const auto input0 = IType0::from(loader0.separate()[i]); + const auto input1 = IType1::from(loader1.separate()[i]); + const auto temp = OP(input0, input1); // enables returning different type + + if (req == OpReqType::kAddTo) { + // temp2 may have a wider type than either temp + // or OType + const auto temp2 = op::add(temp, OType::from(storer.separate()[i])); + storer.separate()[i] = OType::to(temp2); + } else { + storer.separate()[i] = OType::to(temp); + } + } + storer.store(tid, N); + } +} + +)code"; + +void ElemwiseBinaryRTCCompute::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet::common::cuda::rtc; + if (req[0] == kNullOp) return; + mshadow::Stream* s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + + std::string code = "const OpReqType req = "; + code += util::to_string(req[0]); + code += ";\n" + "#define OP op::"; + code += OP; + code += "\n"; + const int nvec = outputs[0].type_flag_ == mshadow::kFloat64 ? 
2 : 4; + + const index_t size = outputs[0].Size(); + binary_kernel_params params = { {inputs[0].dptr_, inputs[1].dptr_, nullptr}, + {outputs[0].dptr_, nullptr} }; + + VectorizedKernelRTCLauncher(code, "binary_kernel", + binary_kernel_fwd, nvec, + size, 1, s, params, + inputs, outputs, + ctx.run_ctx.get_ctx().dev_id); +} + +const char binary_kernel_bwd_use_none[] = R"code( + +struct binary_kernel_params { + const void *inputs[3]; + void *outputs[2]; +}; + +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void binary_kernel_bwd(const binary_kernel_params params, + const index_t lead_dim, + const index_t other_dim, + const index_t N, + const index_t num_aligned_elements) { + using namespace vector; + VectorizedLoader loader( + reinterpret_cast(params.inputs[0]), N); + VectorizedStorer lstorer( + reinterpret_cast(params.outputs[0]), N); + VectorizedStorer rstorer( + reinterpret_cast(params.outputs[1]), N); + + using IType = AccType; + using OType0 = AccType; + using OType1 = AccType; + + const index_t M = num_aligned_elements; + + for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + tid < M; + tid += gridDim.x * blockDim.x) { + loader.load(tid, N); + if (lreq == OpReqType::kAddTo) { + lstorer.load(tid, N); + } + if (rreq == OpReqType::kAddTo) { + rstorer.load(tid, N); + } +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const auto input = IType::from(loader.separate()[i]); + if (write_left_output) { + const auto temp = LOP(input); + if (lreq == OpReqType::kAddTo) { + // temp2 may have a wider type than either temp + // or OType + const auto temp2 = op::add(temp, OType0::from(lstorer.separate()[i])); + lstorer.separate()[i] = OType0::to(temp2); + } else { + lstorer.separate()[i] = OType0::to(temp); + } + } + if (write_right_output) { + const auto temp = ROP(input); + if (rreq == OpReqType::kAddTo) { + // temp2 may have a wider type than either temp + // or OType + const auto temp2 = op::add(temp, OType1::from(rstorer.separate()[i])); + rstorer.separate()[i] = OType1::to(temp2); + } else { + rstorer.separate()[i] = OType1::to(temp); + } + } + } + if (write_left_output) { + lstorer.store(tid, N); + } + if (write_right_output) { + rstorer.store(tid, N); + } + } +} +)code"; + +void ElemwiseBinaryRTCBwdUseNone::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet::common::cuda::rtc; + if (req[0] == kNullOp && req[1] == kNullOp) return; + mshadow::Stream* s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 2U); + + bool write_left_output = req[0] != kNullOp && + (req[0] != kWriteInplace || + (req[0] == kWriteInplace && LOP != "identity")); + + bool write_right_output = req[1] != kNullOp && + (req[1] != kWriteInplace || + (req[1] == kWriteInplace && LOP != "identity")); + + const std::string code = std::string("const OpReqType lreq = ") + + util::to_string(req[0]) + + ";\n" + "const OpReqType rreq = " + + util::to_string(req[1]) + + ";\n" + "#define ROP op::" + + ROP + + "\n" + "#define LOP op::" + + LOP + + "\n" + "const bool write_left_output = " + + std::to_string(write_left_output) + + ";\n" + "const bool write_right_output = " + + std::to_string(write_right_output) + + ";\n"; + const int nvec = outputs[0].type_flag_ == mshadow::kFloat64 ? 
2 : 4; + + const index_t size = outputs[0].Size(); + binary_kernel_params params = { {inputs[0].dptr_, nullptr, nullptr}, + {outputs[0].dptr_, outputs[1].dptr_} }; + + VectorizedKernelRTCLauncher(code, "binary_kernel_bwd", + binary_kernel_bwd_use_none, nvec, + size, 1, s, params, + inputs, outputs, + ctx.run_ctx.get_ctx().dev_id); +} + +const char binary_kernel_bwd_use_in[] = R"code( + +struct binary_kernel_params { + const void *inputs[3]; + void *outputs[2]; +}; + +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void binary_kernel_bwd(const binary_kernel_params params, + const index_t lead_dim, + const index_t other_dim, + const index_t N, + const index_t num_aligned_elements) { + using namespace vector; + VectorizedLoader ograd_loader( + reinterpret_cast(params.inputs[0]), N); + VectorizedLoader linput_loader( + reinterpret_cast(params.inputs[1]), N); + VectorizedLoader rinput_loader( + reinterpret_cast(params.inputs[2]), N); + + VectorizedStorer lstorer( + reinterpret_cast(params.outputs[0]), N); + VectorizedStorer rstorer( + reinterpret_cast(params.outputs[1]), N); + + using IType0 = AccType; + using IType1 = AccType; + using IType2 = AccType; + using OType0 = AccType; + using OType1 = AccType; + + + const index_t M = num_aligned_elements; + + for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + tid < M; + tid += gridDim.x * blockDim.x) { + ograd_loader.load(tid, N); + linput_loader.load(tid, N); + rinput_loader.load(tid, N); + if (lreq == OpReqType::kAddTo) { + lstorer.load(tid, N); + } + if (rreq == OpReqType::kAddTo) { + rstorer.load(tid, N); + } +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const auto ograd = IType0::from(ograd_loader.separate()[i]); + const auto linput = IType1::from(linput_loader.separate()[i]); + const auto rinput = IType2::from(rinput_loader.separate()[i]); + + if (lreq != OpReqType::kNullOp) { + const auto temp = op::mul(ograd, LOP(linput, rinput)); + if (lreq == OpReqType::kAddTo) { + const auto temp2 = op::add(temp, OType0::from(lstorer.separate()[i])); + lstorer.separate()[i] = OType0::to(temp2); + } else { + lstorer.separate()[i] = OType0::to(temp); + } + } + + if (rreq != OpReqType::kNullOp) { + const auto temp = op::mul(ograd, ROP(linput, rinput)); + if (rreq == OpReqType::kAddTo) { + const auto temp2 = op::add(temp, OType1::from(rstorer.separate()[i])); + rstorer.separate()[i] = OType1::to(temp2); + } else { + rstorer.separate()[i] = OType1::to(temp); + } + } + } + if (lreq != OpReqType::kNullOp) { + lstorer.store(tid, N); + } + if (rreq != OpReqType::kNullOp) { + rstorer.store(tid, N); + } + } +} +)code"; + +void ElemwiseBinaryRTCBwdUseIn::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet::common::cuda::rtc; + if (req[0] == kNullOp && req[1] == kNullOp) return; + mshadow::Stream* s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + + const std::string code = std::string("const OpReqType lreq = ") + + util::to_string(req[0]) + + ";\n" + "const OpReqType rreq = " + + util::to_string(req[1]) + + ";\n" + "#define ROP op::" + + ROP + + "\n" + "#define LOP op::" + + LOP + + "\n"; + // Using 64 bit loads to reduce register pressure + size_t output_type_size = common::mshadow_type_info(outputs[0].type_flag_).size; + const int nvec = output_type_size <= sizeof(uint64_t) + ? 
(sizeof(uint64_t) / output_type_size) + : 1; + + const index_t size = outputs[0].Size(); + binary_kernel_params params = { {inputs[0].dptr_, inputs[1].dptr_, inputs[2].dptr_}, + {outputs[0].dptr_, outputs[1].dptr_} }; + + VectorizedKernelRTCLauncher(code, "binary_kernel_bwd", + binary_kernel_bwd_use_in, nvec, + size, 1, s, params, + inputs, outputs, + ctx.run_ctx.get_ctx().dev_id); +} + + +#endif // MXNET_USE_CUDA + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index 7094e1e7367c..dc44dda73822 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -38,6 +38,7 @@ #include "elemwise_unary_op.h" #include "../../common/utils.h" #include "./init_op.h" +#include "../operator_common.h" namespace mxnet { namespace op { @@ -106,61 +107,67 @@ class ElemwiseBinaryOp : public OpBase { } private: - template + template static void BackwardUseNone_(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, + mshadow::Stream* s, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - using namespace mxnet_op; - Stream *s = ctx.get_stream(); - const int size = static_cast((outputs[0].Size() + DataType::kLanes - 1) - / DataType::kLanes); - const DType *ograd_dptr = inputs[0].dptr(); - if (std::is_same::value && req[0] == kWriteInplace) { - CHECK_EQ(ograd_dptr, outputs[0].dptr()); - } else if (req[0] != kNullOp) { - DType *lgrad_dptr = outputs[0].dptr(); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - Kernel, xpu>::Launch(s, size, lgrad_dptr, ograd_dptr); - }); - } - if (std::is_same::value && req[1] == kWriteInplace) { - CHECK_EQ(ograd_dptr, outputs[1].dptr()); - } else if (req[1] != kNullOp) { - DType *rgrad_dptr = outputs[1].dptr(); - MXNET_ASSIGN_REQ_SWITCH(req[1], Req, { - Kernel, xpu>::Launch(s, size, rgrad_dptr, ograd_dptr); - }); - } + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + using namespace mxnet_op; + const int size = static_cast((outputs[0].Size() + DataType::kLanes - 1) + / DataType::kLanes); + const DType *ograd_dptr = inputs[0].dptr(); + if (std::is_same::value && req[0] == kWriteInplace) { + CHECK_EQ(ograd_dptr, outputs[0].dptr()); + } else if (req[0] != kNullOp) { + DType *lgrad_dptr = outputs[0].dptr(); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + Kernel, cpu>::Launch(s, size, lgrad_dptr, ograd_dptr); + }); + } + if (std::is_same::value && req[1] == kWriteInplace) { + CHECK_EQ(ograd_dptr, outputs[1].dptr()); + } else if (req[1] != kNullOp) { + DType *rgrad_dptr = outputs[1].dptr(); + MXNET_ASSIGN_REQ_SWITCH(req[1], Req, { + Kernel, cpu>::Launch(s, size, rgrad_dptr, ograd_dptr); + }); + } + }); } - template + template static void BackwardUseIn_(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, + mshadow::Stream* s, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - DCHECK_EQ(outputs.size(), 2U); - DCHECK_EQ(inputs.size(), 3U); - mxnet_op::Stream *s = ctx.get_stream(); - const DType *ograd_dptr = inputs[0].dptr(); - const DType *lhs_dptr = inputs[1].dptr(); - const DType *rhs_dptr = inputs[2].dptr(); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - const int size = static_cast( - (outputs[0].Size() + mxnet_op::DataType::kLanes - 1) - / mxnet_op::DataType::kLanes); - DType * lgrad_dptr = outputs[0].dptr(); - mxnet_op::Kernel, Req>, xpu>::Launch( - s, size, lgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);}); - MXNET_ASSIGN_REQ_SWITCH(req[1], Req, { - const int size = static_cast( - (outputs[1].Size() + 
mxnet_op::DataType::kLanes - 1) - / mxnet_op::DataType::kLanes); - DType * rgrad_dptr = outputs[1].dptr(); - mxnet_op::Kernel, Req>, xpu>::Launch( - s, size, rgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);}); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + DCHECK_EQ(outputs.size(), 2U); + DCHECK_EQ(inputs.size(), 3U); + const DType *ograd_dptr = inputs[0].dptr(); + const DType *lhs_dptr = inputs[1].dptr(); + const DType *rhs_dptr = inputs[2].dptr(); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + const int size = static_cast( + (outputs[0].Size() + mxnet_op::DataType::kLanes - 1) + / mxnet_op::DataType::kLanes); + DType * lgrad_dptr = outputs[0].dptr(); + mxnet_op::Kernel< + mxnet_op::op_with_req, Req>, cpu>::Launch( + s, size, lgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr); + }); + MXNET_ASSIGN_REQ_SWITCH(req[1], Req, { + const int size = static_cast( + (outputs[1].Size() + mxnet_op::DataType::kLanes - 1) + / mxnet_op::DataType::kLanes); + DType * rgrad_dptr = outputs[1].dptr(); + mxnet_op::Kernel< + mxnet_op::op_with_req, Req>, cpu>::Launch( + s, size, rgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr); + }); + }); } template< @@ -479,7 +486,7 @@ class ElemwiseBinaryOp : public OpBase { const std::vector &outputs) { using namespace mxnet_op; if (req[0] == kNullOp) return; - Stream *s = ctx.get_stream(); + mshadow::Stream *s = ctx.get_stream(); CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); if (outputs[0].type_flag_ == mshadow::kBool) { @@ -607,30 +614,6 @@ template }); } - template - static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - using namespace mxnet_op; - if (req[0] == kNullOp) return; - Stream *s = ctx.get_stream(); - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, { - const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) - + DataType::kLanes - 1) / DataType::kLanes; - if (size != 0) { - Kernel, xpu>::Launch(s, size, - outputs[0].dptr(), - inputs[0].dptr(), inputs[1].dptr()); - } - }); - }); - } - template static void ComputeEx(const nnvm::NodeAttrs &attrs, const OpContext &ctx, @@ -727,20 +710,8 @@ template const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - BackwardUseNone_(attrs, ctx, inputs, req, outputs); - }); - } - - template - static inline void BackwardUseNoneWithHalf2(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, { - BackwardUseNone_(attrs, ctx, inputs, req, outputs); - }); + mshadow::Stream *s = ctx.get_stream(); + BackwardUseNone_(attrs, s, inputs, req, outputs); } template @@ -784,21 +755,10 @@ template const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - BackwardUseIn_(attrs, ctx, inputs, req, outputs); - }); + mshadow::Stream *s = ctx.get_stream(); + BackwardUseIn_(attrs, s, inputs, req, outputs); } - template - static inline void BackwardUseInWithHalf2(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, { - BackwardUseIn_(attrs, 
ctx, inputs, req, outputs); - }); - } template static inline void BackwardUseInEx(const nnvm::NodeAttrs &attrs, @@ -809,7 +769,6 @@ template using namespace common; CHECK_EQ(inputs.size(), 3U); CHECK_EQ(outputs.size(), 2U); // lhs input grad, rhs input grad - const auto out_grad_stype = inputs[0].storage_type(); const auto lhs_grad_stype = outputs[0].storage_type(); const auto rhs_grad_stype = outputs[1].storage_type(); if (ContainsOnlyStorage(inputs, kRowSparseStorage) && @@ -888,6 +847,43 @@ template [](const NodeAttrs& attrs) { \ return std::vector{ResourceRequest::kTempSpace};}) +#if MXNET_USE_CUDA + +struct ElemwiseBinaryRTCCompute { + std::string OP; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +struct ElemwiseBinaryRTCBwdUseNone { + std::string LOP; + std::string ROP; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +struct ElemwiseBinaryRTCBwdUseIn { + std::string LOP; + std::string ROP; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +#endif + } // namespace op } // namespace mxnet + #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_ diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu index 16d7fc1ad72b..cb0da7554bc3 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cu +++ b/src/operator/tensor/elemwise_binary_op_basic.cu @@ -26,6 +26,7 @@ #include "./elemwise_binary_op.h" #include "./elemwise_binary_op-inl.h" #include "./indexing_op.h" +#include "../../common/cuda/rtc.h" namespace mxnet { namespace op { @@ -218,52 +219,47 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream *s, } NNVM_REGISTER_OP(elemwise_add) -.set_attr("FCompute", ElemwiseBinaryOp::ComputeWithHalf2) +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"add"}) .set_attr("FComputeEx", ElemwiseBinaryOp::ComputeEx); NNVM_REGISTER_OP(_grad_add) -.set_attr("FCompute", ElemwiseBinaryOp::ComputeWithHalf2); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"add"}); NNVM_REGISTER_OP(_backward_add) .set_attr("FCompute", - ElemwiseBinaryOp::BackwardUseNoneWithHalf2); + ElemwiseBinaryRTCBwdUseNone{"identity", "identity"}); NNVM_REGISTER_OP(elemwise_sub) -.set_attr("FCompute", ElemwiseBinaryOp::ComputeWithHalf2< - gpu, op::mshadow_op::minus>) +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"sub"}) .set_attr("FComputeEx", ElemwiseBinaryOp::ComputeEx); NNVM_REGISTER_OP(_backward_sub) .set_attr("FCompute", - ElemwiseBinaryOp::BackwardUseNoneWithHalf2); + ElemwiseBinaryRTCBwdUseNone{"identity", "negation"}); NNVM_REGISTER_OP(elemwise_mul) -.set_attr("FCompute", ElemwiseBinaryOp::ComputeWithHalf2) +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"mul"}) .set_attr("FComputeEx", ElemwiseBinaryOp::ComputeDnsLRValueEx); NNVM_REGISTER_OP(_backward_mul) .set_attr("FCompute", - ElemwiseBinaryOp::BackwardUseInWithHalf2); + ElemwiseBinaryRTCBwdUseIn{"right", "left"}); NNVM_REGISTER_OP(elemwise_div) .set_attr("FCompute", - ElemwiseBinaryOp::ElemwiseBinaryOp::ComputeWithHalf2); + ElemwiseBinaryRTCCompute{"div"}); NNVM_REGISTER_OP(_backward_div) .set_attr("FCompute", - ElemwiseBinaryOp::BackwardUseInWithHalf2); + ElemwiseBinaryRTCBwdUseIn{"div_grad", "div_rgrad"}); NNVM_REGISTER_OP(_mod) -.set_attr("FCompute", 
ElemwiseBinaryOp::ComputeWithHalf2); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"mod"}); NNVM_REGISTER_OP(_backward_mod) .set_attr("FCompute", - ElemwiseBinaryOp::BackwardUseInWithHalf2); + ElemwiseBinaryRTCBwdUseIn{"mod_grad", "mod_rgrad"}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_op_extended.cu b/src/operator/tensor/elemwise_binary_op_extended.cu index 0ae6ac966a2b..9d568a404f3d 100644 --- a/src/operator/tensor/elemwise_binary_op_extended.cu +++ b/src/operator/tensor/elemwise_binary_op_extended.cu @@ -22,38 +22,34 @@ * \file elemwise_binary_op_extended.cu * \brief GPU Implementation of binary function. */ -#include "./elemwise_unary_op.h" #include "./elemwise_binary_op.h" namespace mxnet { namespace op { NNVM_REGISTER_OP(_power) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"power"}); NNVM_REGISTER_OP(_backward_power) -.set_attr("FCompute", ElemwiseBinaryOp::BackwardUseIn); +.set_attr("FCompute", ElemwiseBinaryRTCBwdUseIn{"power_grad", "power_rgrad"}); NNVM_REGISTER_OP(_maximum) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"max"}); NNVM_REGISTER_OP(_backward_maximum) -.set_attr("FCompute", ElemwiseBinaryOp::BackwardUseIn); +.set_attr("FCompute", ElemwiseBinaryRTCBwdUseIn{"greater_equal", "less"}); NNVM_REGISTER_OP(_minimum) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"min"}); NNVM_REGISTER_OP(_backward_minimum) -.set_attr("FCompute", ElemwiseBinaryOp::BackwardUseIn); +.set_attr("FCompute", ElemwiseBinaryRTCBwdUseIn{"less_equal", "greater"}); NNVM_REGISTER_OP(_hypot) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"hypot"}); NNVM_REGISTER_OP(_backward_hypot) -.set_attr("FCompute", ElemwiseBinaryOp::BackwardUseIn); +.set_attr("FCompute", ElemwiseBinaryRTCBwdUseIn{"hypot_grad_left", + "hypot_grad_right"}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_op_logic.cu b/src/operator/tensor/elemwise_binary_op_logic.cu index e36e6971148f..8ef84130c5e5 100644 --- a/src/operator/tensor/elemwise_binary_op_logic.cu +++ b/src/operator/tensor/elemwise_binary_op_logic.cu @@ -22,37 +22,36 @@ * \file elemwise_binary_op_logic.cu * \brief GPU Implementation of unary function. 
*/ -#include "./elemwise_unary_op.h" #include "./elemwise_binary_op.h" namespace mxnet { namespace op { NNVM_REGISTER_OP(_equal) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"equal"}); NNVM_REGISTER_OP(_not_equal) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"not_equal"}); NNVM_REGISTER_OP(_greater) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"greater"}); NNVM_REGISTER_OP(_greater_equal) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"greater_equal"}); NNVM_REGISTER_OP(_lesser) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"less"}); NNVM_REGISTER_OP(_lesser_equal) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"less_equal"}); NNVM_REGISTER_OP(_logical_and) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"logical_and"}); NNVM_REGISTER_OP(_logical_or) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"logical_or"}); NNVM_REGISTER_OP(_logical_xor) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"logical_xor"}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_scalar_op.cc b/src/operator/tensor/elemwise_binary_scalar_op.cc new file mode 100644 index 000000000000..f09bf21cceb4 --- /dev/null +++ b/src/operator/tensor/elemwise_binary_scalar_op.cc @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "elemwise_binary_scalar_op.h" + +#if MXNET_USE_CUDA +#include "../../common/cuda/rtc/vectorization-inl.h" +#include "../../common/cuda/rtc.h" +#endif // MXNET_USE_CUDA + +namespace mxnet { +namespace op { + +#if MXNET_USE_CUDA + +struct binary_scalar_kernel_params { + const void *inputs[2]; + void *outputs[1]; + double scalar; +}; + +const char binary_scalar_kernel_fwd[] = R"code( + +struct binary_scalar_kernel_params { + const void *inputs[2]; + void *outputs[1]; + double scalar; +}; + +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void binary_scalar_kernel(const binary_scalar_kernel_params params, + const index_t lead_dim, + const index_t other_dim, + const index_t N, + const index_t num_aligned_elements) { + using namespace vector; + VectorizedLoader loader( + reinterpret_cast(params.inputs[0]), N); + VectorizedStorer storer( + reinterpret_cast(params.outputs[0]), N); + + using IType = AccType; + using OType = AccType; + + const index_t M = num_aligned_elements; + + for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + tid < M; + tid += gridDim.x * blockDim.x) { + loader.load(tid, N); + if (req == OpReqType::kAddTo) { + storer.load(tid, N); + } +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const auto input = IType::from(loader.separate()[i]); + // enables returning different type + const auto temp = OP(input, + static_cast::type> + (params.scalar)); + + if (req == OpReqType::kAddTo) { + // temp2 may have a wider type than either temp + // or OType + const auto temp2 = op::add(temp, OType::from(storer.separate()[i])); + storer.separate()[i] = OType::to(temp2); + } else { + storer.separate()[i] = OType::to(temp); + } + } + storer.store(tid, N); + } +} + +)code"; + +void BinaryScalarRTCCompute::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet::common::cuda::rtc; + if (req[0] == kNullOp) return; + mshadow::Stream* s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; + + const std::string code = std::string("const OpReqType req = ") + + util::to_string(req[0]) + + ";\n" + + "#define OP op::" + + OP + + "\n"; + const int nvec = common::mshadow_type_info(outputs[0].type_flag_).size == 8 ? 
2 : 4; + + const index_t size = outputs[0].Size(); + binary_scalar_kernel_params params = { {inputs[0].dptr_, nullptr}, + {outputs[0].dptr_}, + alpha }; + + VectorizedKernelRTCLauncher(code, "binary_scalar_kernel", + binary_scalar_kernel_fwd, nvec, + size, 1, s, params, + inputs, outputs, + ctx.run_ctx.get_ctx().dev_id); +} + +void BinaryScalarRTCCompute::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (req[0] == kNullOp) { + return; + } + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + InitStorageGeometry<1, 1>(attrs, inputs, outputs); + CHECK_NE(outputs[0].storage_type(), kDefaultStorage) + << "This function works only for sparse types."; + CHECK_EQ(inputs[0].storage_type(), outputs[0].storage_type()) + << "The storage type of both inputs and outputs needs to be the same."; + AllocateGeometry(&outputs[0], req[0], &inputs[0]); + CopyGeometryBlobs(ctx.get_stream(), &outputs[0], req[0], inputs[0]); + outputs[0].CheckAndAllocData(inputs[0].storage_shape()); + if (inputs[0].storage_shape().Size()) { + std::vector in_blobs, out_blobs; + in_blobs.reserve(inputs.size()); + out_blobs.reserve(outputs.size()); + for (auto &input : inputs) { + in_blobs.emplace_back(input.data()); + } + for (auto &output : outputs) { + out_blobs.emplace_back(output.data()); + } + this->operator()(attrs, ctx, in_blobs, req, out_blobs); + } +} + +const char binary_scalar_kernel_bwd[] = R"code( + +struct binary_scalar_kernel_params { + const void *inputs[2]; + void *outputs[1]; + double scalar; +}; + +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void binary_scalar_kernel_bwd(const binary_scalar_kernel_params params, + const index_t lead_dim, + const index_t other_dim, + const index_t N, + const index_t num_aligned_elements) { + using namespace vector; + VectorizedLoader ograd_loader( + reinterpret_cast(params.inputs[0]), N); + VectorizedLoader input_loader( + reinterpret_cast(params.inputs[1]), N); + VectorizedStorer storer( + reinterpret_cast(params.outputs[0]), N); + + using GType = AccType; + using IType = AccType; + using OType = AccType; + + const index_t M = num_aligned_elements; + + for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + tid < M; + tid += gridDim.x * blockDim.x) { + ograd_loader.load(tid, N); + input_loader.load(tid, N); + if (req == OpReqType::kAddTo) { + storer.load(tid, N); + } +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const auto ograd = GType::from(ograd_loader.separate()[i]); + const auto input = IType::from(input_loader.separate()[i]); + // enables returning different type + const auto temp = op::mul(ograd, + OP(input, + static_cast + ::type>(params.scalar))); + + if (req == OpReqType::kAddTo) { + // temp2 may have a wider type than either temp + // or OType + const auto temp2 = op::add(temp, OType::from(storer.separate()[i])); + storer.separate()[i] = OType::to(temp2); + } else { + storer.separate()[i] = OType::to(temp); + } + } + storer.store(tid, N); + } +} + +)code"; + +void BinaryScalarRTCBackward::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet::common::cuda::rtc; + if (req[0] == kNullOp) return; + mshadow::Stream* s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; + + const 
std::string code = std::string("const OpReqType req = ") + + util::to_string(req[0]) + + ";\n" + "#define OP op::" + + OP + + "\n"; + const int nvec = outputs[0].type_flag_ == mshadow::kFloat64 ? 2 : 4; + + const index_t size = outputs[0].Size(); + binary_scalar_kernel_params params = { {inputs[0].dptr_, inputs[1].dptr_}, + {outputs[0].dptr_}, + alpha }; + + VectorizedKernelRTCLauncher(code, "binary_scalar_kernel_bwd", + binary_scalar_kernel_bwd, nvec, + size, 1, s, params, + inputs, outputs, + ctx.run_ctx.get_ctx().dev_id); +} + +#endif // MXNET_USE_CUDA + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h index c09e41867f46..a6fdf1e7572d 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op.h +++ b/src/operator/tensor/elemwise_binary_scalar_op.h @@ -266,17 +266,17 @@ class BinaryScalarOp : public UnaryOp { } public: - template - static void Compute(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + template + static void Compute_(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + mshadow::Stream* s, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { DCHECK_EQ(inputs.size(), 1); DCHECK_EQ(outputs.size(), 1); using namespace mshadow; using namespace mshadow::expr; - Stream *s = ctx.get_stream(); TBlob temp_tblob; const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); bool scalar_is_int = param.is_int; @@ -284,20 +284,30 @@ class BinaryScalarOp : public UnaryOp { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { if ((common::is_int(inputs[0].type_flag_) && !scalar_is_int) || (inputs[0].type_flag_ == kBool)) { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(Shape1(inputs[0].Size()), s); + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(inputs[0].Size()), s); temp_tblob = TBlob(temp_tensor); - CastCompute(attrs, ctx, {inputs[0]}, {kWriteTo}, {temp_tblob}); + CastCompute(attrs, ctx, {inputs[0]}, {kWriteTo}, {temp_tblob}); } else { temp_tblob = inputs[0]; } MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - mxnet_op::Kernel, xpu>::Launch( + mxnet_op::Kernel, cpu>::Launch( s, inputs[0].Size(), outputs[0].dptr(), temp_tblob.dptr(), DType(alpha)); }); }); } + template + static void Compute(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + mshadow::Stream *s = ctx.get_stream(); + Compute_(attrs, ctx, s, inputs, req, outputs); + } + template static void ComputeInt(const nnvm::NodeAttrs &attrs, const OpContext &ctx, @@ -401,27 +411,38 @@ class BinaryScalarOp : public UnaryOp { } } - template - static void Backward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + template + static void Backward_(const nnvm::NodeAttrs &attrs, + mshadow::Stream* s, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { using namespace mshadow; using namespace mshadow::expr; - Stream *s = ctx.get_stream(); const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); const double alpha = param.scalar; MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet::op::mxnet_op::Kernel, Req>, xpu>:: + mxnet::op::mxnet_op::backward_grad_tuned, Req>, cpu>:: Launch(s, inputs[0].Size(), outputs[0].dptr(), 
inputs[0].dptr(), inputs[1].dptr(), DType(alpha)); }); }); } + + template + static void Backward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + Backward_(attrs, s, inputs, req, outputs); + } }; #define MXNET_OPERATOR_REGISTER_BINARY_SCALAR(name) \ @@ -442,6 +463,38 @@ class BinaryScalarOp : public UnaryOp { .add_argument("data", "NDArray-or-Symbol", "source input") \ .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +#if MXNET_USE_CUDA + +struct BinaryScalarRTCCompute { + std::string OP; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +struct BinaryScalarRTCBackward { + std::string OP; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +#endif + } // namespace op } // namespace mxnet + + #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_H_ diff --git a/src/operator/tensor/elemwise_binary_scalar_op_basic.cu b/src/operator/tensor/elemwise_binary_scalar_op_basic.cu index 3c839205683a..9635e83b4453 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_basic.cu +++ b/src/operator/tensor/elemwise_binary_scalar_op_basic.cu @@ -29,50 +29,47 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_plus_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"add"}); NNVM_REGISTER_OP(_minus_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"sub"}); NNVM_REGISTER_OP(_rminus_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"rsub"}); NNVM_REGISTER_OP(_mul_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FComputeEx", BinaryScalarOp::ComputeEx); +.set_attr("FCompute", BinaryScalarRTCCompute{"mul"}) +.set_attr("FComputeEx", BinaryScalarRTCCompute{"mul"}); NNVM_REGISTER_OP(_backward_mul_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FComputeEx", BinaryScalarOp::ComputeEx); +.set_attr("FCompute", BinaryScalarRTCCompute{"mul"}) +.set_attr("FComputeEx", BinaryScalarRTCCompute{"mul"}); NNVM_REGISTER_OP(_div_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FComputeEx", BinaryScalarOp::ComputeEx); +.set_attr("FCompute", BinaryScalarRTCCompute{"div"}) +.set_attr("FComputeEx", BinaryScalarRTCCompute{"div"}); NNVM_REGISTER_OP(_backward_div_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FComputeEx", BinaryScalarOp::ComputeEx); +.set_attr("FCompute", BinaryScalarRTCCompute{"div"}) +.set_attr("FComputeEx", BinaryScalarRTCCompute{"div"}); NNVM_REGISTER_OP(_rdiv_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"rdiv"}); NNVM_REGISTER_OP(_backward_rdiv_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarRTCBackward{"rdiv_grad"}); NNVM_REGISTER_OP(_mod_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"mod"}); NNVM_REGISTER_OP(_backward_mod_scalar) 
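// How these string-named functors reach the GPU (a sketch based on the forward path in
// elemwise_binary_scalar_op.cc above; the exact preamble text is an assumption): each
// BinaryScalarRTCCompute / BinaryScalarRTCBackward instance stores only the RTC op name.
// When invoked, it prepends a small header to the kernel source string and passes both to
// VectorizedKernelRTCLauncher for NVRTC compilation, roughly
//
//   const OpReqType req = OpReqType::kWriteTo;   // substituted per call from req[0]
//   #define OP op::mod                           // spliced from the stored op name
//
// so binary_scalar_kernel evaluates OP(input, scalar) element-wise through the vectorized
// loader/storer, and kAddTo accumulation is handled by the req branch inside the kernel
// rather than by a separate template instantiation.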
-.set_attr("FCompute", BinaryScalarOp::Backward< - gpu, mshadow_op::mod_grad>); +.set_attr("FCompute", BinaryScalarRTCBackward{"mod_grad"}); NNVM_REGISTER_OP(_rmod_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"rmod"}); NNVM_REGISTER_OP(_backward_rmod_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward< - gpu, mshadow_op::rmod_grad>); +.set_attr("FCompute", BinaryScalarRTCBackward{"rmod_grad"}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_scalar_op_extended.cu b/src/operator/tensor/elemwise_binary_scalar_op_extended.cu index 2bd52d7b9d7c..c662dc2d2923 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_extended.cu +++ b/src/operator/tensor/elemwise_binary_scalar_op_extended.cu @@ -29,45 +29,40 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_maximum_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"max"}); NNVM_REGISTER_OP(_backward_maximum_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarRTCBackward{"greater_equal"}); NNVM_REGISTER_OP(_minimum_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"min"}); NNVM_REGISTER_OP(_backward_minimum_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarRTCBackward{"less_equal"}); NNVM_REGISTER_OP(_power_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"power"}); NNVM_REGISTER_OP(_backward_power_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward< - gpu, mshadow_op::power_grad>); +.set_attr("FCompute", BinaryScalarRTCBackward{"power_grad"}); NNVM_REGISTER_OP(_rpower_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"rpow"}); NNVM_REGISTER_OP(_backward_rpower_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward< - gpu, mshadow_op::rpower_grad>); +.set_attr("FCompute", BinaryScalarRTCBackward{"rpower_grad"}); NNVM_REGISTER_OP(_hypot_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"hypot"}); NNVM_REGISTER_OP(_backward_hypot_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward< - gpu, mshadow_op::hypot_grad_left>); +.set_attr("FCompute", BinaryScalarRTCBackward{"hypot_grad_left"}); NNVM_REGISTER_OP(smooth_l1) -.set_attr("FCompute", BinaryScalarOp::Compute< - gpu, mshadow_op::smooth_l1_loss>); +.set_attr("FCompute", BinaryScalarRTCCompute{"smooth_l1"}); NNVM_REGISTER_OP(_backward_smooth_l1) -.set_attr("FCompute", BinaryScalarOp::Backward< - gpu, mshadow_op::smooth_l1_gradient>); +.set_attr("FCompute", BinaryScalarRTCBackward{"smooth_l1_grad"}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_scalar_op_logic.cu b/src/operator/tensor/elemwise_binary_scalar_op_logic.cu index 6c393e0719a5..70ef26b0a5e5 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_logic.cu +++ b/src/operator/tensor/elemwise_binary_scalar_op_logic.cu @@ -28,37 +28,37 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_equal_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FComputeEx", BinaryScalarOp::LogicComputeEx); +.set_attr("FCompute", BinaryScalarRTCCompute{"equal"}) +.set_attr("FComputeEx", BinaryScalarRTCCompute{"equal"}); NNVM_REGISTER_OP(_not_equal_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FComputeEx", 
BinaryScalarOp::LogicComputeEx); +.set_attr("FCompute", BinaryScalarRTCCompute{"not_equal"}) +.set_attr("FComputeEx", BinaryScalarRTCCompute{"not_equal"}); NNVM_REGISTER_OP(_greater_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FComputeEx", BinaryScalarOp::LogicComputeEx); +.set_attr("FCompute", BinaryScalarRTCCompute{"greater"}) +.set_attr("FComputeEx", BinaryScalarRTCCompute{"greater"}); NNVM_REGISTER_OP(_greater_equal_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FComputeEx", BinaryScalarOp::LogicComputeEx); +.set_attr("FCompute", BinaryScalarRTCCompute{"greater_equal"}) +.set_attr("FComputeEx", BinaryScalarRTCCompute{"greater_equal"}); NNVM_REGISTER_OP(_lesser_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FComputeEx", BinaryScalarOp::LogicComputeEx); +.set_attr("FCompute", BinaryScalarRTCCompute{"less"}) +.set_attr("FComputeEx", BinaryScalarRTCCompute{"less"}); NNVM_REGISTER_OP(_lesser_equal_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FComputeEx", BinaryScalarOp::LogicComputeEx); +.set_attr("FCompute", BinaryScalarRTCCompute{"less_equal"}) +.set_attr("FComputeEx", BinaryScalarRTCCompute{"less_equal"}); NNVM_REGISTER_OP(_logical_and_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"logical_and"}); NNVM_REGISTER_OP(_logical_or_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"logical_or"}); NNVM_REGISTER_OP(_logical_xor_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarRTCCompute{"logical_xor"}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_scatter_op.cc b/src/operator/tensor/elemwise_scatter_op.cc deleted file mode 100644 index 41f22b057a53..000000000000 --- a/src/operator/tensor/elemwise_scatter_op.cc +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file elemwise_scatter_op.cc - * \brief CPU implementation of elementwise scatter operators - */ -#include "./elemwise_binary_op-inl.h" -#include "./elemwise_binary_scalar_op.h" -#include "./elemwise_scatter_op.h" - -namespace mxnet { -namespace op { - -static bool StorageTypeRspOrDenseOutput(const NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - const auto lhs_stype = static_cast((*in_attrs)[0]); - if (common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) - && common::ContainsOnlyStorage(*out_attrs, kDefaultStorage)) { - if (storage_type_assign(&out_attrs[0], kDefaultStorage, - dispatch_mode, DispatchMode::kFCompute)) { - return true; - } - } - if (lhs_stype == kRowSparseStorage) { - if (storage_type_assign(&out_attrs[0], kRowSparseStorage, - dispatch_mode, - DispatchMode::kFComputeEx)) { - return true; - } - } - return dispatch_fallback(out_attrs, dispatch_mode); -} - -static bool StorageTypeScatteredScalarOp(const NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector* in_attrs, - std::vector* out_attrs) { - // Supports kDefaultStorage, kRowSparseStorage and kCSRStorage - const auto stype = static_cast((*in_attrs)[0]); - if (storage_type_assign(out_attrs, - stype, - dispatch_mode, - stype == kDefaultStorage ? DispatchMode::kFCompute - : DispatchMode::kFComputeEx)) { - return true; - } - return dispatch_fallback(out_attrs, dispatch_mode); -} - -/*! \brief _scatter_elemwise_div */ -MXNET_OPERATOR_REGISTER_BINARY(_scatter_elemwise_div) -.set_attr("FCompute", ElemwiseScatterBinaryOp::Compute) -.set_attr("FComputeEx", ElemwiseScatterBinaryOp::ComputeEx< - cpu, op::mshadow_op::div>) -.describe(R"code(Divides arguments element-wise. If the left-hand-side input is 'row_sparse', then -only the values which exist in the left-hand sparse array are computed. The 'missing' values -are ignored. - -The storage type of ``_scatter_elemwise_div`` output depends on storage types of inputs - -- _scatter_elemwise_div(row_sparse, row_sparse) = row_sparse -- _scatter_elemwise_div(row_sparse, dense) = row_sparse -- _scatter_elemwise_div(row_sparse, csr) = row_sparse -- otherwise, ``_scatter_elemwise_div`` behaves exactly like elemwise_div and generates output -with default storage - -)code") -.set_attr("FInferStorageType", StorageTypeRspOrDenseOutput) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("THasDeterministicOutput", true) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_div"}); - -/*! \brief _scatter_plus_scalar */ -MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_scatter_plus_scalar) -.describe(R"code(Adds a scalar to a tensor element-wise. If the left-hand-side input is -'row_sparse' or 'csr', then only the values which exist in the left-hand sparse array are computed. -The 'missing' values are ignored. 
- -The storage type of ``_scatter_plus_scalar`` output depends on storage types of inputs - -- _scatter_plus_scalar(row_sparse, scalar) = row_sparse -- _scatter_plus_scalar(csr, scalar) = csr -- otherwise, ``_scatter_plus_scalar`` behaves exactly like _plus_scalar and generates output -with default storage - -)code") -.set_attr("FInferStorageType", StorageTypeScatteredScalarOp) -.set_attr("FCompute", - ElemwiseScatterBinaryScalarOp::Compute) -.set_attr("FComputeEx", - ElemwiseScatterBinaryScalarOp::ComputeEx) -.set_attr("FGradient", ElemwiseGradUseNone{"_copy"}); - -/*! \brief _scatter_minus_scalar */ -MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_scatter_minus_scalar) - .describe(R"code(Subtracts a scalar to a tensor element-wise. If the left-hand-side input is -'row_sparse' or 'csr', then only the values which exist in the left-hand sparse array are computed. -The 'missing' values are ignored. - -The storage type of ``_scatter_minus_scalar`` output depends on storage types of inputs - -- _scatter_minus_scalar(row_sparse, scalar) = row_sparse -- _scatter_minus_scalar(csr, scalar) = csr -- otherwise, ``_scatter_minus_scalar`` behaves exactly like _minus_scalar and generates output -with default storage - -)code") -.set_attr("FInferStorageType", StorageTypeScatteredScalarOp) -.set_attr("FCompute", - ElemwiseScatterBinaryScalarOp::Compute) -.set_attr("FComputeEx", - ElemwiseScatterBinaryScalarOp::ComputeEx) -.set_attr("FGradient", ElemwiseGradUseNone{"_copy"}); - -} // namespace op -} // namespace mxnet diff --git a/src/operator/tensor/elemwise_scatter_op.cu b/src/operator/tensor/elemwise_scatter_op.cu deleted file mode 100644 index 913aa9512193..000000000000 --- a/src/operator/tensor/elemwise_scatter_op.cu +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - - /*! 
- * \file elemwise_scatter_op.cu - * \brief GPU implementation of elementwise scatter operators - */ -#include "./elemwise_binary_scalar_op.h" -#include "./elemwise_scatter_op.h" - -namespace mxnet { -namespace op { - -NNVM_REGISTER_OP(_scatter_elemwise_div) -.set_attr("FCompute", ElemwiseScatterBinaryOp::Compute) -.set_attr("FComputeEx", ElemwiseScatterBinaryOp::ComputeEx); - -NNVM_REGISTER_OP(_scatter_plus_scalar) -.set_attr("FCompute", - ElemwiseScatterBinaryScalarOp::Compute) -.set_attr("FComputeEx", - ElemwiseScatterBinaryScalarOp::ComputeEx); - -NNVM_REGISTER_OP(_scatter_minus_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FComputeEx", BinaryScalarOp::ComputeEx); - -} // namespace op -} // namespace mxnet - diff --git a/src/operator/tensor/elemwise_scatter_op.h b/src/operator/tensor/elemwise_scatter_op.h deleted file mode 100644 index 0e52a86da8bf..000000000000 --- a/src/operator/tensor/elemwise_scatter_op.h +++ /dev/null @@ -1,318 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file elemwise_scatter_op.h - * \brief Function definition of elementwise scatter operators - */ -#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_SCATTER_OP_H_ -#define MXNET_OPERATOR_TENSOR_ELEMWISE_SCATTER_OP_H_ - -#include -#include "./elemwise_binary_op.h" -#include "./elemwise_binary_scalar_op.h" -#include "sparse_retain-inl.h" -#include "cast_storage-inl.h" - -namespace mxnet { -namespace op { - -/*! - * \brief Shared helper functions for scatter ops - */ -class ScatterOpBase { - /*! \brief Protected in order to prevent widespread use. Scatter ops is a special case */ - protected: - /*! - * \brief For some situations, we need to do the computation as dense and then use - * sparse-retain to strip out the portions we aren't interested in. - * \note If your operastor uses this function, it must request kTempStorage - * \tparam xpu gpu or cpu - * \tparam Function Function to call with dense inputs and outputs - * \param attrs Operator attributes - * \param ctx Operator context - * \param inputs Input NDArrays - * \param req Operation request - * \param outputs Output NDArrays - * \param function - */ - template - static void ComputeAsDense(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs, - Function function) { - std::vector output_converted; - std::vector input_data, output_data; - std::vector other_inputs, other_outputs; - other_inputs.reserve(inputs.size()); - input_data.reserve(inputs.size()); - output_data.reserve(outputs.size()); - other_outputs.reserve(outputs.size()); - output_converted.reserve(outputs.size()); - // Inputs... 
- for (const NDArray& nd : inputs) { - if (nd.storage_type() != kDefaultStorage) { - NDArray in(nd.shape(), ctx.run_ctx.get_ctx()); - CastStorageComputeEx(attrs, ctx, { nd }, req, { in }); - other_inputs.push_back(in); - input_data.push_back(in.data()); - } else { - input_data.push_back(nd.data()); - } - } - - // Outputs... - for (const NDArray& nd : outputs) { - if (nd.storage_type() != kDefaultStorage) { - NDArray out(nd.shape(), ctx.run_ctx.get_ctx()); - CastStorageComputeEx(attrs, ctx, { nd }, req, { out }); - other_outputs.push_back(out); - output_data.push_back(out.data()); - output_converted.push_back(true); - } else { - other_outputs.push_back(nd); - output_data.push_back(nd.data()); - output_converted.push_back(false); - } - } - - // Call the function - function(attrs, ctx, input_data, req, output_data); - - // Convert output(s) back if necessary - for (size_t i = 0, n = outputs.size(); i < n; ++i) { - if (output_converted[i]) { - CastStorageComputeEx(attrs, - ctx, - { other_outputs[i] }, - req, - { outputs[i] }); - } - } - } - - /*! - * \brief Execute the supplied function/operation, followed by a sparse retain operation - * of the lhs argument's rows only (row indices) - * \tparam xpu gpu or cpu - * \tparam Function Function type call to wrap and return sparse-retained output - * \param attrs Operator attributes - * \param ctx Operator context - * \param inputs Input NDArrays - * \param req Operation request - * \param outputs Output NDArrays - * \param pre_retain Whether to call SparseRetain before calling the given function - * \param function Function call to wrap and return sparse-retained output - */ - template - static void ScatterWrap(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs, - bool pre_retain, - Function function) { - CHECK_EQ(outputs.size(), 1U); - if (inputs[0].storage_type() == kRowSparseStorage - && outputs[0].storage_type() == kRowSparseStorage) { - if (pre_retain && inputs[1].storage_type() == kRowSparseStorage) { - // Retain only rhs rows which have same row as lhs input - NDArray retained_input(outputs[0].storage_type(), outputs[0].shape(), outputs[0].ctx()); - SparseRetainOpForwardEx(attrs, ctx, - { inputs[1], inputs[0].aux_ndarray(rowsparse::kIdx) }, - req, - {retained_input}); - CHECK(retained_input.storage_initialized()); - // Perform the operation - function(attrs, ctx, {inputs[0], retained_input}, req, outputs); - // Sanity check - DCHECK_LE(outputs[0].aux_shape(rowsparse::kIdx).Size(), - inputs[0].aux_shape(rowsparse::kIdx).Size()); - } else { - // Perform the operation as usual - NDArray temp_out(outputs[0].storage_type(), outputs[0].shape(), outputs[0].ctx()); - function(attrs, ctx, inputs, req, { temp_out }); - CHECK(temp_out.storage_initialized()); - CHECK_EQ(temp_out.storage_type(), kRowSparseStorage); - // Sparse-retain the output based upon lhs-input sparsity - const NDArray indices(inputs[0].aux_data(rowsparse::kIdx), inputs[0].ctx().dev_id); - SparseRetainOpForwardEx(attrs, ctx, { temp_out, indices }, - req, outputs); - DCHECK_LE(outputs[0].aux_shape(rowsparse::kIdx).Size(), - inputs[0].aux_shape(rowsparse::kIdx).Size()); - } - } else { - function(attrs, ctx, inputs, req, outputs); - } - } -}; - -/*! \brief Scatter elemwise binary op handlers */ -class ElemwiseScatterBinaryOp : public ElemwiseBinaryOp, - public ScatterOpBase { - /*! 
\brief CPU version, RspRsp knows how to do an efficient scatter, - * otherwise retain rhs + normal op */ - template - static void ComputeEx_(mshadow::Stream *stream, - const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - // row_sparse-op-row_sparse or row_sparse-op-default can call RspRsp - const NDArrayStorageType input0_stype = inputs[0].storage_type(); - const NDArrayStorageType input1_stype = inputs[1].storage_type(); - if (input0_stype == kRowSparseStorage - && (input1_stype == kRowSparseStorage || input1_stype == kDefaultStorage) - && outputs[0].storage_type() == kRowSparseStorage) { - mshadow::Stream *s = ctx.get_stream(); - RspRspOp(s, attrs, ctx, inputs[0], inputs[1], req[0], outputs[0], - false, true, false, true); - CHECK_EQ(inputs[0].aux_shape(rowsparse::kIdx).Size(), - outputs[0].aux_shape(rowsparse::kIdx).Size()); - } else { - ScatterWrap(attrs, ctx, inputs, req, - outputs, true, [input0_stype, input1_stype](const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - if ((input0_stype == kCSRStorage || input1_stype == kCSRStorage) - && input0_stype != input1_stype) { - // Fallback to dense + retain - ComputeAsDense(attrs, ctx, inputs, req, - outputs, ElemwiseBinaryOp::Compute); - } else { - // Normal operation + retain - ElemwiseBinaryOp::ComputeEx(attrs, ctx, inputs, req, outputs); - } - }); - } - } - -#ifdef __CUDACC__ - /*! \brief GPU version, fallback op + retain */ - template - static void ComputeEx_(mshadow::Stream *stream, - const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - ScatterWrap(attrs, ctx, inputs, req, - outputs, false, [](const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - ComputeAsDense(attrs, ctx, inputs, req, outputs, ElemwiseBinaryOp::Compute); - }); - } -#endif // #ifdef __CUDACC__ - - public: - /*! \brief General compute for operations which include sparse tensors */ - template - static void ComputeEx(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - DCHECK_EQ(inputs.size(), 2U); - DCHECK_EQ(outputs.size(), 1U); - ComputeEx_(ctx.get_stream(), attrs, ctx, inputs, req, outputs); - } -}; - -/*! \brief Scatter elemwise binary scalar op handlers */ -class ElemwiseScatterBinaryScalarOp : public BinaryScalarOp, - public ScatterOpBase { - /*! \brief CPU version, retain rhs + normal op */ - template - static void ComputeEx_(mshadow::Stream *stream, - const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - ScatterWrap(attrs, ctx, inputs, req, - outputs, true, [](const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - // Normal operation + retain - BinaryScalarOp::ComputeEx(attrs, ctx, inputs, req, outputs); - }); - } - -#ifdef __CUDACC__ - /*! 
\brief GPU version, fallback op + retain */ - template - static void ComputeEx_(mshadow::Stream *stream, - const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - CHECK_NE(inputs[0].storage_type(), kDefaultStorage); - if (outputs[0].storage_type() == inputs[0].storage_type()) { - BinaryScalarOp::ComputeEx(attrs, ctx, inputs, req, outputs); - } else { - ScatterWrap(attrs, ctx, inputs, req, - outputs, false, [](const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - // Fallback to dense + retain - ComputeAsDense(attrs, ctx, inputs, req, outputs, BinaryScalarOp::Compute); - }); - } - } -#endif // __CUDACC__ - - public: - using BinaryScalarOp::Compute; - - /*! \brief General compute for operations which include sparse tensors */ - template - static void ComputeEx(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - DCHECK_EQ(inputs.size(), 1U); - DCHECK_EQ(outputs.size(), 1U); - CHECK_NE(inputs[0].storage_type(), kDefaultStorage); - if (inputs[0].storage_type() == kRowSparseStorage - && outputs[0].storage_type() == kRowSparseStorage) { - UnaryOp::MapToFCompute(attrs, ctx, inputs, req, outputs, Compute); - } else { - ComputeEx_(ctx.get_stream(), attrs, ctx, inputs, req, outputs); - } - } -}; - -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_TENSOR_ELEMWISE_SCATTER_OP_H_ diff --git a/src/operator/tensor/elemwise_sum.cu b/src/operator/tensor/elemwise_sum.cu index f9a248214e85..acee34fb35bc 100644 --- a/src/operator/tensor/elemwise_sum.cu +++ b/src/operator/tensor/elemwise_sum.cu @@ -24,10 +24,136 @@ */ #include "./elemwise_sum.h" #include "../../ndarray/ndarray_function.h" +#include "../../common/cuda/rtc.h" +#include "../../common/cuda/rtc/vectorization-inl.h" namespace mxnet { namespace op { +namespace { + +constexpr size_t num_inputs_per_kernel = 4; + +struct elementwise_sum_params { + int num_inputs; + const void* inputs[num_inputs_per_kernel]; + void* outputs[1]; +}; + +const char elementwise_sum_kernel[] = R"code( +constexpr size_t num_inputs_per_kernel = 4; + +struct elementwise_sum_params { + int num_inputs; + const void* inputs[num_inputs_per_kernel]; + void* outputs[1]; +}; + +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void elementwise_sum_kernel( + const elementwise_sum_params params, + const index_t lead_dim, + const index_t other_dim, + const index_t N, + const index_t num_aligned_elements) { + using namespace vector; + VectorizedStorer storer( + reinterpret_cast(params.outputs[0]), N); + + using IType = AccType; + using OType = AccType; + + const index_t M = num_aligned_elements; + + for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + tid < M; + tid += gridDim.x * blockDim.x) { + typename OType::type temp[nvec]; + if (req == OpReqType::kAddTo) { + storer.load(tid, N); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + temp[i] = OType::from(storer.separate()[i]); + } + } else { +#pragma unroll + for (int i = 0; i < nvec; ++i) { + temp[i] = 0; + } + } +#pragma unroll + for (int i = 0; i < num_inputs_per_kernel; ++i) { + if (i < params.num_inputs) { + VectorizedLoader loader( + reinterpret_cast(params.inputs[i]), N); + loader.load(tid, N); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + temp[i] += IType::from(loader.separate()[i]); + } + } + } +#pragma unroll + for (int i = 0; i < 
nvec; ++i) { + storer.separate()[i] = OType::to(temp[i]); + } + + storer.store(tid, N); + } +} +)code"; + +void VectorizedElementwiseSum(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet::common::cuda::rtc; + mshadow::Stream *s = ctx.get_stream(); + if (req[0] == kNullOp) return; + CHECK_EQ(outputs.size(), 1U); + size_t output_type_size = common::mshadow_type_info(outputs[0].type_flag_).size; + const int nvec = output_type_size <= sizeof(uint2) + ? (sizeof(uint2) / output_type_size) + : 1; + const index_t size = inputs[0].Size(); + for (size_t i = 0; i < inputs.size(); i += num_inputs_per_kernel) { + if (i == 0) { + const std::string code = std::string("const OpReqType req = ") + + util::to_string(req[0]) + + ";\n"; + elementwise_sum_params params{}; + params.num_inputs = std::min(num_inputs_per_kernel, inputs.size() - i); + for (int j = 0; j < params.num_inputs; ++j) { + params.inputs[j] = inputs[i + j].dptr_; + } + params.outputs[0] = outputs[0].dptr_; + VectorizedKernelRTCLauncher(code, "elementwise_sum_kernel", + elementwise_sum_kernel, nvec, + size, 1, s, params, + inputs, outputs, + ctx.run_ctx.get_ctx().dev_id); + } else { + /* During subsequent launches we need to + accumulate into the previous outputs + */ + const std::string code = "const OpReqType req = OpReqType::kAddTo;\n"; + elementwise_sum_params params{}; + params.num_inputs = std::min(num_inputs_per_kernel, inputs.size() - i); + for (int j = 0; j < params.num_inputs; ++j) { + params.inputs[j] = inputs[i + j].dptr_; + } + params.outputs[0] = outputs[0].dptr_; + const std::vector new_inputs(inputs.begin() + i, inputs.end()); + VectorizedKernelRTCLauncher(code, "elementwise_sum_kernel", + elementwise_sum_kernel, nvec, + size, 1, s, params, + new_inputs, outputs, + ctx.run_ctx.get_ctx().dev_id); + } + } +} + void ElementWiseSumComputeExGPU(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -51,8 +177,10 @@ void ElementWiseSumComputeExGPU(const nnvm::NodeAttrs& attrs, } } +} // namespace + NNVM_REGISTER_OP(add_n) -.set_attr("FCompute", ElementWiseSumComputeWithHalf2) +.set_attr("FCompute", VectorizedElementwiseSum) .set_attr("FComputeEx", ElementWiseSumComputeExGPU); } // namespace op diff --git a/src/operator/tensor/elemwise_sum.h b/src/operator/tensor/elemwise_sum.h index 259c80ddddac..d40ab4de0f0f 100644 --- a/src/operator/tensor/elemwise_sum.h +++ b/src/operator/tensor/elemwise_sum.h @@ -113,18 +113,6 @@ void ElementWiseSumCompute(const nnvm::NodeAttrs& attrs, }); } -template -void ElementWiseSumComputeWithHalf2(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - CHECK_EQ(outputs.size(), 1U); - MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, { - ElementWiseSumCompute_(attrs, ctx, inputs, req, outputs); - }); -} - } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_SUM_H_ diff --git a/src/operator/tensor/elemwise_unary_op.cc b/src/operator/tensor/elemwise_unary_op.cc new file mode 100644 index 000000000000..df51c7ba2d12 --- /dev/null +++ b/src/operator/tensor/elemwise_unary_op.cc @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include "elemwise_unary_op.h" +#include "elemwise_binary_op.h" + +#if MXNET_USE_CUDA +#include "../../common/cuda/rtc/vectorization-inl.h" +#include "../../common/cuda/rtc.h" +#endif // MXNET_USE_CUDA + +namespace mxnet { +namespace op { + +#if MXNET_USE_CUDA + +struct unary_kernel_params { + const void *inputs[1]; + void *outputs[1]; +}; + +const char unary_kernel_fwd[] = R"code( + +struct unary_kernel_params { + const void *inputs[1]; + void *outputs[1]; +}; + +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void unary_kernel(const unary_kernel_params params, + const index_t lead_dim, + const index_t other_dim, + const index_t N, + const index_t num_aligned_elements) { + using namespace vector; + VectorizedLoader loader( + reinterpret_cast(params.inputs[0]), N); + VectorizedStorer storer( + reinterpret_cast(params.outputs[0]), N); + + using IType = AccType; + using OType = AccType; + + const index_t M = num_aligned_elements; + + for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + tid < M; + tid += gridDim.x * blockDim.x) { + loader.load(tid, N); + if (req == OpReqType::kAddTo) { + storer.load(tid, N); + } +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const auto input = IType::from(loader.separate()[i]); + const auto temp = OP(input); // enables returning different type + + if (req == OpReqType::kAddTo) { + // temp2 may have a wider type than either temp + // or OType + const auto temp2 = op::add(temp, OType::from(storer.separate()[i])); + storer.separate()[i] = OType::to(temp2); + } else { + storer.separate()[i] = OType::to(temp); + } + } + storer.store(tid, N); + } +} + +)code"; + +void UnaryRTCCompute::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet::common::cuda::rtc; + if (req[0] == kNullOp) return; + mshadow::Stream* s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + + const std::string code = std::string("const OpReqType req = ") + + util::to_string(req[0]) + + ";\n" + "#define OP op::" + + OP + + "\n"; + const int nvec = outputs[0].type_flag_ == mshadow::kFloat64 ? 
2 : 4; + + const index_t size = outputs[0].Size(); + unary_kernel_params params = { {inputs[0].dptr_}, + {outputs[0].dptr_} }; + + VectorizedKernelRTCLauncher(code, "unary_kernel", + unary_kernel_fwd, nvec, + size, 1, s, params, + inputs, outputs, + ctx.run_ctx.get_ctx().dev_id); +} + +void UnaryRTCCompute::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (req[0] == kNullOp) { + return; + } + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + InitStorageGeometry<1, 1>(attrs, inputs, outputs); + CHECK_NE(outputs[0].storage_type(), kDefaultStorage) + << "This function works only for sparse types."; + CHECK_EQ(inputs[0].storage_type(), outputs[0].storage_type()) + << "The storage type of both inputs and outputs needs to be the same."; + AllocateGeometry(&outputs[0], req[0], &inputs[0]); + CopyGeometryBlobs(ctx.get_stream(), &outputs[0], req[0], inputs[0]); + outputs[0].CheckAndAllocData(inputs[0].storage_shape()); + if (inputs[0].storage_shape().Size()) { + std::vector in_blobs, out_blobs; + in_blobs.reserve(inputs.size()); + out_blobs.reserve(outputs.size()); + for (auto &input : inputs) { + in_blobs.emplace_back(input.data()); + } + for (auto &output : outputs) { + out_blobs.emplace_back(output.data()); + } + this->operator()(attrs, ctx, in_blobs, req, out_blobs); + } +} + +void UnaryBwdInOutRTCCompute::operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + ElemwiseBinaryRTCCompute {OP} (attrs, ctx, {inputs[0], inputs[2]}, req, outputs); +} + +#endif // MXNET_USE_CUDA + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index f2148b559c0c..1f0610d63b62 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -31,6 +31,7 @@ #include #include #include +#include #include "./cast_storage-inl.h" #include "../mshadow_op.h" #include "../mxnet_op.h" @@ -45,69 +46,117 @@ namespace mxnet { namespace op { -class OpBase { - protected: - /*! \brief simple kernel to set to a scalar value of arbitrary type */ - template - using set_to_scalar = mxnet_op::op_with_req; +namespace { - /*! \brief Copy blob data */ - template - static void inline CopyBlob(mshadow::Stream *s, - const TBlob *dest_blob, - const OpReqType reqi, - const TBlob& src_blob) { - CHECK_EQ(src_blob.type_flag_, dest_blob->type_flag_); - CHECK_EQ(src_blob.shape_, dest_blob->shape_); - MSHADOW_TYPE_SWITCH(src_blob.type_flag_, DType, { - // Check if the pointers are the same (in-place operation needs no copy) - if (reqi != kNullOp && src_blob.dptr() != dest_blob->dptr()) { - mshadow::Copy(dest_blob->FlatTo1D(s), src_blob.FlatTo1D(s), s); +/*! 
\brief Infer the output storage geometry + * \return boolean signifying whether the proper storage geometry was initialized + */ +template +bool InitStorageGeometry(const nnvm::NodeAttrs& attrs, + const std::vector& inputs, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), static_cast(n_in)) + << " in operator " << attrs.name; + CHECK_EQ(outputs.size(), static_cast(n_out)) + << " in operator " << attrs.name; + static_assert(n_in > 0 && n_out > 0, "Invalid input and/or output count values"); + const mxnet::TShape& isshape = inputs[0].storage_shape(); + if (!shape_is_none(isshape)) { + NDArray *output = nullptr; + for (size_t i = 0, n = inputs.size(); i < n; ++i) { + const NDArray &input = inputs[i]; + if (i < n_out) { + output = const_cast(&outputs[i]); } - }); + CHECK_EQ(output->shape(), inputs[i].shape()); + CHECK_EQ(output->storage_type(), input.storage_type()); + CHECK_EQ(output->aux_shapes().size(), input.aux_shapes().size()); + mxnet::ShapeVector aux_shapes; + const size_t aux_shape_count = input.aux_shapes().size(); + aux_shapes.reserve(aux_shape_count); + for (size_t j = 0; j < aux_shape_count; ++j) { + aux_shapes.emplace_back(input.aux_shape(j)); + } + output->CheckAndAlloc(aux_shapes); + DCHECK_EQ(output->storage_shape(), input.storage_shape()); + } + return true; + } + if (isshape.ndim() > 0 && !isshape.Size() + && inputs[0].storage_type() != kDefaultStorage) { + return true; // 0% density + } else { + CHECK(false); // implement when necessary } + return false; +} - /*! \brief Allocate geometry-related blob data for sparse tensors - * \param dest Destination sparse NDArray - * \param clone_from sparse NDArray from which to clone storage attributes - */ - static void AllocateGeometry(const NDArray *dest, - const OpReqType req, - const NDArray* clone_from = nullptr) { - if (req != kNullOp) { - if (clone_from) { - const mxnet::TShape& ishape = clone_from->storage_shape(); - dest->CheckAndAllocData(ishape); - CHECK_EQ(dest->storage_type(), clone_from->storage_type()); - for (size_t i = 0, n = clone_from->aux_shapes().size(); i < n; ++i) { - dest->CheckAndAllocAuxData(i, clone_from->aux_shape(i)); - } - DCHECK_EQ(dest->aux_shapes().size(), clone_from->aux_shapes().size()); - } else { - for (size_t i = 0, n = dest->aux_shapes().size(); i < n; ++i) { - dest->CheckAndAllocAuxData(i, dest->aux_shape(i)); - } - dest->CheckAndAllocData(dest->storage_shape()); +/*! \brief Copy blob data */ +template +void inline CopyBlob(mshadow::Stream *s, + const TBlob *dest_blob, + const OpReqType reqi, + const TBlob& src_blob) { + CHECK_EQ(src_blob.type_flag_, dest_blob->type_flag_); + CHECK_EQ(src_blob.shape_, dest_blob->shape_); + MSHADOW_TYPE_SWITCH(src_blob.type_flag_, DType, { + // Check if the pointers are the same (in-place operation needs no copy) + if (reqi != kNullOp && src_blob.dptr() != dest_blob->dptr()) { + mshadow::Copy(dest_blob->FlatTo1D(s), src_blob.FlatTo1D(s), s); + } + }); +} + +/*! 
\brief Allocate geometry-related blob data for sparse tensors + * \param dest Destination sparse NDArray + * \param clone_from sparse NDArray from which to clone storage attributes + */ +void inline AllocateGeometry(const NDArray *dest, + const OpReqType req, + const NDArray* clone_from = nullptr) { + if (req != kNullOp) { + if (clone_from) { + const mxnet::TShape& ishape = clone_from->storage_shape(); + dest->CheckAndAllocData(ishape); + CHECK_EQ(dest->storage_type(), clone_from->storage_type()); + for (size_t i = 0, n = clone_from->aux_shapes().size(); i < n; ++i) { + dest->CheckAndAllocAuxData(i, clone_from->aux_shape(i)); } + DCHECK_EQ(dest->aux_shapes().size(), clone_from->aux_shapes().size()); + } else { + for (size_t i = 0, n = dest->aux_shapes().size(); i < n; ++i) { + dest->CheckAndAllocAuxData(i, dest->aux_shape(i)); + } + dest->CheckAndAllocData(dest->storage_shape()); } } +} - /*! \brief Copy the geometry-related blobs (row sparse indexes, etc.) */ - template - static inline void CopyGeometryBlobs(mshadow::Stream *s, - const NDArray *dest, - const OpReqType reqi, - const NDArray &src) { - CHECK_EQ(src.aux_shapes().size(), dest->aux_shapes().size()); - // My assumption is that the geometry blobs are not large enough to justify an omp loop here, - // since the thread synchronization calls for each fork will take longer - // than copying a few floats - for (size_t i = 0, n = src.aux_shapes().size(); i < n; ++i) { - const TBlob src_blob = src.aux_data(i); - const TBlob dest_blob = dest->aux_data(i); - CopyBlob(s, &dest_blob, reqi, src_blob); - } +/*! \brief Copy the geometry-related blobs (row sparse indexes, etc.) */ +template +inline void CopyGeometryBlobs(mshadow::Stream *s, + const NDArray *dest, + const OpReqType reqi, + const NDArray &src) { + CHECK_EQ(src.aux_shapes().size(), dest->aux_shapes().size()); + // My assumption is that the geometry blobs are not large enough to justify an omp loop here, + // since the thread synchronization calls for each fork will take longer + // than copying a few floats + for (size_t i = 0, n = src.aux_shapes().size(); i < n; ++i) { + const TBlob src_blob = src.aux_data(i); + const TBlob dest_blob = dest->aux_data(i); + CopyBlob(s, &dest_blob, reqi, src_blob); } +} + +} // namespace + +class OpBase { + protected: + /*! \brief simple kernel to set to a scalar value of arbitrary type */ + template + using set_to_scalar = mxnet_op::op_with_req; + /*! \brief Generic copy NDArray */ template @@ -172,49 +221,6 @@ class OpBase { /*! \brief Unary operator class */ class UnaryOp : public OpBase { - /*! 
\brief Infer the output storage geometry - * \return boolean signifying whether the proper storage geometry was initialized - */ - template - static bool InitStorageGeometry(const nnvm::NodeAttrs& attrs, - const std::vector& inputs, - const std::vector& outputs) { - CHECK_EQ(inputs.size(), static_cast(n_in)) - << " in operator " << attrs.name; - CHECK_EQ(outputs.size(), static_cast(n_out)) - << " in operator " << attrs.name; - static_assert(n_in > 0 && n_out > 0, "Invalid input and/or output count values"); - const mxnet::TShape& isshape = inputs[0].storage_shape(); - if (!shape_is_none(isshape)) { - NDArray *output = nullptr; - for (size_t i = 0, n = inputs.size(); i < n; ++i) { - const NDArray &input = inputs[i]; - if (i < n_out) { - output = const_cast(&outputs[i]); - } - CHECK_EQ(output->shape(), inputs[i].shape()); - CHECK_EQ(output->storage_type(), input.storage_type()); - CHECK_EQ(output->aux_shapes().size(), input.aux_shapes().size()); - mxnet::ShapeVector aux_shapes; - const size_t aux_shape_count = input.aux_shapes().size(); - aux_shapes.reserve(aux_shape_count); - for (size_t j = 0; j < aux_shape_count; ++j) { - aux_shapes.emplace_back(input.aux_shape(j)); - } - output->CheckAndAlloc(aux_shapes); - DCHECK_EQ(output->storage_shape(), input.storage_shape()); - } - return true; - } - if (isshape.ndim() > 0 && !isshape.Size() - && inputs[0].storage_type() != kDefaultStorage) { - return true; // 0% density - } else { - CHECK(false); // implement when necessary - } - return false; - } - public: /*! \brief Map NDArray vectors to TBlob vectors and pass to compute function */ template @@ -224,7 +230,7 @@ class UnaryOp : public OpBase { const std::vector &req, const std::vector &outputs, FComputer computer) { - UnaryOp::template InitStorageGeometry<1, 1>(attrs, inputs, outputs); + InitStorageGeometry<1, 1>(attrs, inputs, outputs); CHECK_EQ(inputs.size(), outputs.size()); // need to figure out what to do for binary type CHECK_NE(outputs[0].storage_type(), kDefaultStorage); CHECK_EQ(inputs[0].storage_type(), outputs[0].storage_type()); @@ -236,23 +242,32 @@ class UnaryOp : public OpBase { } } - template - static void Compute(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - mshadow::Stream *s = ctx.get_stream(); + template + static void Compute_(const nnvm::NodeAttrs& attrs, + mshadow::Stream* s, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { if (inputs[0].Size() != 0) { - mxnet_op::Kernel, xpu>::Launch( + mxnet_op::Kernel, cpu>::Launch( s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr()); } }); }); } + template + static void Compute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + mshadow::Stream *s = ctx.get_stream(); + Compute_(attrs, s, inputs, req, outputs); + } + template static void ComputeMixedType(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -319,7 +334,8 @@ class UnaryOp : public OpBase { const std::vector& outputs) { CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); - CHECK_NE(inputs[0].storage_type(), kDefaultStorage); + CHECK_NE(inputs[0].storage_type(), kDefaultStorage) + << "Operation requires a sparse input storage type"; CHECK_NE(outputs[0].storage_type(), kDefaultStorage) << "Operation requires a sparse output storage type"; if 
(inputs[0].storage_shape().Size()) { @@ -360,7 +376,7 @@ class UnaryOp : public OpBase { CHECK_EQ(outputs.size(), 1U) << "Invalid output, only one output is allowed"; CHECK_NE(inputs[0].storage_type(), kDefaultStorage) - << "Operation requires a sparse output storage type"; + << "Operation requires a sparse input storage type"; CHECK_NE(outputs[0].storage_type(), kDefaultStorage) << "Operation requires a sparse output storage type"; if (inputs[0].storage_shape().Size()) { @@ -369,23 +385,6 @@ class UnaryOp : public OpBase { } #endif - template - static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - using namespace mshadow; - using namespace mxnet_op; - Stream *s = ctx.get_stream(); - CHECK_EQ(inputs.size(), 1U); - CHECK_EQ(outputs.size(), 1U); - MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, { - Kernel::Launch(s, outputs[0].Size(), - outputs[0].dptr(), inputs[0].dptr()); - }); - } - template static void IdentityCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -915,6 +914,36 @@ void NumpyNanToNumOpBackward(const nnvm::NodeAttrs& attrs, MXNET_OPERATOR_REGISTER_UNARY(__name$) \ .set_attr("FCompute<" #__xpu$ ">", UnaryOp::Compute<__xpu$, __kernel$>) +#if MXNET_USE_CUDA + +struct UnaryRTCCompute { + std::string OP; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +struct UnaryBwdInOutRTCCompute { + std::string OP; + + void operator()(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +#endif // MXNET_USE_CUDA + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index c00aea7e8af4..de8044368157 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -757,9 +757,7 @@ The storage type of ``sign`` output depends upon the input storage type: - sign(csr) = csr )code" ADD_FILELINE) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_sign"}); - -MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sign, unary_bwd); +.set_attr("FGradient", MakeZeroGradNodes); // round MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(round, cpu, mshadow_op::round) diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu index 8fef6f5f7b38..074f7ac69a26 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cu +++ b/src/operator/tensor/elemwise_unary_op_basic.cu @@ -22,23 +22,22 @@ * \brief GPU Implementation of unary functions. 
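 * The registrations below dispatch through the runtime-compiled UnaryRTCCompute and
 * ElemwiseBinaryRTCCompute wrappers rather than the templated UnaryOp::Compute /
 * ElemwiseBinaryOp::Compute instantiations.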
*/ #include "./elemwise_binary_op.h" +#include "./elemwise_unary_op.h" namespace mxnet { namespace op { NNVM_REGISTER_OP(relu) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"relu"}) +.set_attr("FComputeEx", UnaryRTCCompute{"relu"}); NNVM_REGISTER_OP(_backward_relu) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd>); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_relu"}); NNVM_REGISTER_OP(sigmoid) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"sigmoid"}); NNVM_REGISTER_OP(_backward_sigmoid) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd>); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_sigmoid"}); NNVM_REGISTER_OP(hard_sigmoid) .set_attr("FCompute", HardSigmoidForward); @@ -48,27 +47,26 @@ NNVM_REGISTER_OP(_backward_hard_sigmoid) // softsign NNVM_REGISTER_OP(softsign) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"softsign"}); NNVM_REGISTER_OP(_backward_softsign) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd>); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_softsign"}); // erf NNVM_REGISTER_OP(erf) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"erf"}); NNVM_REGISTER_OP(_backward_erf) .set_attr("FCompute", - ElemwiseBinaryOp::Compute>); + ElemwiseBinaryRTCCompute{"backward_erf"}); // erfinv NNVM_REGISTER_OP(erfinv) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"erfinv"}); NNVM_REGISTER_OP(_backward_erfinv) .set_attr("FCompute", - ElemwiseBinaryOp::Compute>); + ElemwiseBinaryRTCCompute{"backward_erfinv"}); // copy NNVM_REGISTER_OP(_copy) @@ -151,83 +149,76 @@ NNVM_REGISTER_OP(_backward_cast) // negative NNVM_REGISTER_OP(negative) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"negation"}) +.set_attr("FComputeEx", UnaryRTCCompute{"negation"}); // abs NNVM_REGISTER_OP(abs) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"abs"}) +.set_attr("FComputeEx", UnaryRTCCompute{"abs"}); NNVM_REGISTER_OP(_backward_abs) -.set_attr("FCompute", ElemwiseBinaryOp::Compute >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_abs"}); // sign NNVM_REGISTER_OP(sign) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); - -NNVM_REGISTER_OP(_backward_sign) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", UnaryRTCCompute{"sign"}) +.set_attr("FComputeEx", UnaryRTCCompute{"sign"}); // round NNVM_REGISTER_OP(round) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"round"}) +.set_attr("FComputeEx", UnaryRTCCompute{"round"}); // ceil NNVM_REGISTER_OP(ceil) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"ceil"}) +.set_attr("FComputeEx", UnaryRTCCompute{"ceil"}); // floor NNVM_REGISTER_OP(floor) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"floor"}) +.set_attr("FComputeEx", UnaryRTCCompute{"floor"}); // trunc NNVM_REGISTER_OP(trunc) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", 
UnaryRTCCompute{"trunc"}) +.set_attr("FComputeEx", UnaryRTCCompute{"trunc"}); // rint NNVM_REGISTER_OP(rint) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"rint"}) +.set_attr("FComputeEx", UnaryRTCCompute{"rint"}); // fix NNVM_REGISTER_OP(fix) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"fix"}) +.set_attr("FComputeEx", UnaryRTCCompute{"fix"}); // gamma NNVM_REGISTER_OP(gamma) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"gamma"}); NNVM_REGISTER_OP(_backward_gamma) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_gamma"}); // gammaln NNVM_REGISTER_OP(gammaln) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"gammaln"}); NNVM_REGISTER_OP(_backward_gammaln) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_gammaln"}); // digamma NNVM_REGISTER_OP(digamma) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"digamma"}); NNVM_REGISTER_OP(_backward_digamma) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_digamma"}); // logical not NNVM_REGISTER_OP(logical_not) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"logical_not"}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_unary_op_logexp.cu b/src/operator/tensor/elemwise_unary_op_logexp.cu index febc1914feb7..e0f0d69cac11 100644 --- a/src/operator/tensor/elemwise_unary_op_logexp.cu +++ b/src/operator/tensor/elemwise_unary_op_logexp.cu @@ -28,49 +28,44 @@ namespace op { // exp NNVM_REGISTER_OP(exp) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"exp"}); // log NNVM_REGISTER_OP(log) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"log"}); // log10 NNVM_REGISTER_OP(log10) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"log10"}); // log2 NNVM_REGISTER_OP(log2) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"log2"}); NNVM_REGISTER_OP(_backward_log) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_log"}); NNVM_REGISTER_OP(_backward_log10) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_log10"}); NNVM_REGISTER_OP(_backward_log2) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_log2"}); // log1p NNVM_REGISTER_OP(log1p) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"log1p"}) +.set_attr("FComputeEx", UnaryRTCCompute{"log1p"}); NNVM_REGISTER_OP(_backward_log1p) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_log1p"}); // expm1 NNVM_REGISTER_OP(expm1) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"expm1"}) +.set_attr("FComputeEx", UnaryRTCCompute{"expm1"}); NNVM_REGISTER_OP(_backward_expm1) 
-.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_expm1"}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_unary_op_pow.cu b/src/operator/tensor/elemwise_unary_op_pow.cu index 4dbdf349cdb0..c05627724738 100644 --- a/src/operator/tensor/elemwise_unary_op_pow.cu +++ b/src/operator/tensor/elemwise_unary_op_pow.cu @@ -22,61 +22,58 @@ * \brief GPU Implementation of power (x^k for fixed k) functions. */ #include "./elemwise_binary_op.h" +#include "./elemwise_unary_op.h" namespace mxnet { namespace op { // reciprocal NNVM_REGISTER_OP(reciprocal) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"reciprocal"}); NNVM_REGISTER_OP(_backward_reciprocal) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_reciprocal"}); // square NNVM_REGISTER_OP(square) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"square"}) +.set_attr("FComputeEx", UnaryRTCCompute{"square"}); NNVM_REGISTER_OP(_backward_square) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_square"}); // sqrt NNVM_REGISTER_OP(sqrt) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"sqrt"}) +.set_attr("FComputeEx", UnaryRTCCompute{"sqrt"}); NNVM_REGISTER_OP(_backward_sqrt) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_sqrt"}); // rsqrt NNVM_REGISTER_OP(rsqrt) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"rsqrt"}); NNVM_REGISTER_OP(_backward_rsqrt) .set_attr("FCompute", - ElemwiseBinaryOp::Compute >); + ElemwiseBinaryRTCCompute{"backward_rsqrt"}); // cbrt NNVM_REGISTER_OP(cbrt) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"cbrt"}) +.set_attr("FComputeEx", UnaryRTCCompute{"cbrt"}); NNVM_REGISTER_OP(_backward_cbrt) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_cbrt"}); // rcbrt NNVM_REGISTER_OP(rcbrt) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"rcbrt"}); NNVM_REGISTER_OP(_backward_rcbrt) .set_attr("FCompute", - ElemwiseBinaryOp::Compute >); + ElemwiseBinaryRTCCompute{"backward_rcbrt"}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_unary_op_trig.cu b/src/operator/tensor/elemwise_unary_op_trig.cu index 8e28b9c609fa..8adf6b6ea260 100644 --- a/src/operator/tensor/elemwise_unary_op_trig.cu +++ b/src/operator/tensor/elemwise_unary_op_trig.cu @@ -22,131 +22,118 @@ * \brief GPU Implementation of unary trigonometric function. 
*/ #include "./elemwise_binary_op.h" +#include "./elemwise_unary_op.h" namespace mxnet { namespace op { // sin NNVM_REGISTER_OP(sin) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"sin"}) +.set_attr("FComputeEx", UnaryRTCCompute{"sin"}); NNVM_REGISTER_OP(_backward_sin) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_sin"}); // cos NNVM_REGISTER_OP(cos) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"cos"}); NNVM_REGISTER_OP(_backward_cos) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_cos"}); // tan NNVM_REGISTER_OP(tan) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"tan"}) +.set_attr("FComputeEx", UnaryRTCCompute{"tan"}); NNVM_REGISTER_OP(_backward_tan) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_tan"}); // arcsin NNVM_REGISTER_OP(arcsin) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"arcsin"}) +.set_attr("FComputeEx", UnaryRTCCompute{"arcsin"}); NNVM_REGISTER_OP(_backward_arcsin) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_arcsin"}); // arccos NNVM_REGISTER_OP(arccos) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"arccos"}); NNVM_REGISTER_OP(_backward_arccos) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_arccos"}); // arctan NNVM_REGISTER_OP(arctan) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"arctan"}) +.set_attr("FComputeEx", UnaryRTCCompute{"arctan"}); NNVM_REGISTER_OP(_backward_arctan) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_arctan"}); // degrees NNVM_REGISTER_OP(degrees) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"degrees"}) +.set_attr("FComputeEx", UnaryRTCCompute{"degrees"}); NNVM_REGISTER_OP(_backward_degrees) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_degrees"}); // radians NNVM_REGISTER_OP(radians) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"radians"}) +.set_attr("FComputeEx", UnaryRTCCompute{"radians"}); NNVM_REGISTER_OP(_backward_radians) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_radians"}); // cosh NNVM_REGISTER_OP(cosh) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"cosh"}); NNVM_REGISTER_OP(_backward_cosh) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_cosh"}); // sinh NNVM_REGISTER_OP(sinh) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"sinh"}) +.set_attr("FComputeEx", 
UnaryRTCCompute{"sinh"}); NNVM_REGISTER_OP(_backward_sinh) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_sinh"}); // tanh NNVM_REGISTER_OP(tanh) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"tanh"}) +.set_attr("FComputeEx", UnaryRTCCompute{"tanh"}); NNVM_REGISTER_OP(_backward_tanh) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_tanh"}); // arcsinh NNVM_REGISTER_OP(arcsinh) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"arcsinh"}) +.set_attr("FComputeEx", UnaryRTCCompute{"arcsinh"}); NNVM_REGISTER_OP(_backward_arcsinh) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_arcsinh"}); // arccosh NNVM_REGISTER_OP(arccosh) -.set_attr("FCompute", UnaryOp::Compute); +.set_attr("FCompute", UnaryRTCCompute{"arccosh"}); NNVM_REGISTER_OP(_backward_arccosh) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_arccosh"}); // arctanh NNVM_REGISTER_OP(arctanh) -.set_attr("FCompute", UnaryOp::Compute) -.set_attr("FComputeEx", UnaryOp::ComputeEx); +.set_attr("FCompute", UnaryRTCCompute{"arctanh"}) +.set_attr("FComputeEx", UnaryRTCCompute{"arctanh"}); NNVM_REGISTER_OP(_backward_arctanh) -.set_attr("FCompute", ElemwiseBinaryOp::Compute< - gpu, unary_bwd >); +.set_attr("FCompute", ElemwiseBinaryRTCCompute{"backward_arctanh"}); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/pseudo2DTranspose_op-inl.cuh b/src/operator/tensor/pseudo2DTranspose_op-inl.cuh index 5898c0bcf07c..c89fe2e2b959 100644 --- a/src/operator/tensor/pseudo2DTranspose_op-inl.cuh +++ b/src/operator/tensor/pseudo2DTranspose_op-inl.cuh @@ -32,7 +32,7 @@ #include #include #include -#include "../../common/cuda_utils.h" +#include "../../common/cuda/utils.h" namespace mxnet { diff --git a/src/operator/tensor/reduce_rtc.cc b/src/operator/tensor/reduce_rtc.cc new file mode 100644 index 000000000000..9e2d6d3f2a53 --- /dev/null +++ b/src/operator/tensor/reduce_rtc.cc @@ -0,0 +1,524 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "broadcast_reduce-inl.h" +#include "elemwise_unary_op.h" + +#if MXNET_USE_CUDA +#include "../../common/cuda/rtc.h" +#endif // MXNET_USE_CUDA + +using namespace mshadow; + +namespace mxnet { +namespace op { +namespace broadcast { + +#if MXNET_USE_CUDA + +namespace { + +struct reduce_kernel_params { + index_t big_shape[MAX_DIM]; + index_t small_shape[MAX_DIM]; + index_t lhs_shape0[MAX_DIM]; + index_t rhs_shape0[MAX_DIM]; + index_t rshape[MAX_DIM]; + index_t rstride[MAX_DIM]; + index_t lhs_stride[MAX_DIM]; + index_t rhs_stride[MAX_DIM]; + index_t lhs_shape[MAX_DIM]; + index_t rhs_shape[MAX_DIM]; +}; + +const char reduce_function_code[] = R"code( +#define FUNC OP(IType0::from(big[idx_big[u]])) +)code"; + +const char reduce_function_use_input_code[] = R"code( +#define FUNC OP1(IType0::from(big[idx_big[u]]), \ + OP2(IType1::from(lhs[idx_lhs[u]]), \ + IType2::from(rhs[idx_rhs[u]]))) +)code"; + +const char reduce_kernel_code[] = R"code( +struct reduce_kernel_params { + index_t big_shape[util::MAX_DIM]; + index_t small_shape[util::MAX_DIM]; + index_t lhs_shape0[util::MAX_DIM]; + index_t rhs_shape0[util::MAX_DIM]; + index_t rshape[util::MAX_DIM]; + index_t rstride[util::MAX_DIM]; + index_t lhs_stride[util::MAX_DIM]; + index_t rhs_stride[util::MAX_DIM]; + index_t lhs_shape[util::MAX_DIM]; + index_t rhs_shape[util::MAX_DIM]; +}; + +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void reduce_kernel(const int N, const int M, const bool addto, + const InputType0* __restrict big, + const InputType1* __restrict lhs, + const InputType2* __restrict rhs, + OutputType0 *small, + const reduce_kernel_params params, + const int Mnext) { + extern __shared__ char shTileChar[]; + using IType0 = AccType; + using IType1 = AccType; + using IType2 = AccType; + using OType = AccType; + using AType = typename IType0::type; + AType* shTile = (AType*)(shTileChar); + const int tid = threadIdx.x + threadIdx.y*blockDim.x; + const int bx = (do_transpose) ? blockDim.y : blockDim.x; + const int by = (do_transpose) ? blockDim.x : blockDim.y; + const int tidx = (do_transpose) ? tid / by : threadIdx.x; + const int tidy = (do_transpose) ? 
tid % by : threadIdx.y; + for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) { + // This TB handles M range [Mstart, ...., Mend - 1] + const index_t Mstart = (index_t)((int64)M*(int64)m0/(int64)Mnext); + const index_t Mend = (index_t)((int64)M*(int64)(m0 + 1)/(int64)Mnext); + for (index_t idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) { + int idx = idx0 + tidx; + index_t coord[ndim]; + util::unravel(idx, params.small_shape, coord); + index_t idx_big0, idx_lhs0, idx_rhs0; + idx_big0 = util::ravel(coord, params.big_shape); + if (use_input) { + idx_lhs0 = util::ravel(coord, params.lhs_shape0); + idx_rhs0 = util::ravel(coord, params.rhs_shape0); + } + + AType val, residual; + REDUCER::SetInitValue(val, residual); + if (idx < N) { + for (index_t k = tidy + Mstart; k < Mend; k += by*UNROLL) { + index_t idx_big[UNROLL]; + index_t idx_lhs[UNROLL]; + index_t idx_rhs[UNROLL]; + #pragma unroll + for (int u=0;u < UNROLL;u++) { + idx_big[u] = idx_big0 + util::unravel_dot(k + u*by, params.rshape, + params.rstride); + if (use_input) { + idx_lhs[u] = idx_lhs0 + util::unravel_dot(k + u*by, params.lhs_shape, + params.lhs_stride); + idx_rhs[u] = idx_rhs0 + util::unravel_dot(k + u*by, params.rhs_shape, + params.rhs_stride); + } + } + typename OType::type tmp[UNROLL]; + #pragma unroll + for (int u=0;u < UNROLL;u++) { + if (k + u*by < Mend) { + tmp[u] = FUNC; + } + } + #pragma unroll + for (int u=0;u < UNROLL;u++) { + if (k + u*by < Mend) REDUCER::Reduce(val, tmp[u], residual); + } + } + } + + // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0 + if (by > 1) { + // Fix bx to avoid bank conflicts. Assumes warpSize number of banks + const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx; + const int it0 = tidx + tidy*fbx; + shTile[it0 * 2] = val; + shTile[it0 * 2 + 1] = residual; + __syncthreads(); + for (int t=1;t < by;t <<= 1) { + AType tmp, tmp_residual; + REDUCER::SetInitValue(tmp, tmp_residual); + if (tidy + t < by) { + tmp = shTile[(it0 + t*fbx) * 2]; + tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; + } + __syncthreads(); + REDUCER::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); + __syncthreads(); + } + if (idx < N && tidy == 0) { + REDUCER::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); + if (addto) { + small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]), + shTile[tidx * 2])); + } else { + small[idx + m0 * N] = OType::to(shTile[tidx * 2]); + } + } + } else { + if (idx < N) { + REDUCER::Finalize(val, residual); + if (addto) { + small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]), + val)); + } else { + small[idx + m0 * N] = OType::to(val); + } + } + } + } + } +} +)code"; + +const char reduce_lines_kernel_code[] = R"code( +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void reduce_lines_kernel(const index_t N, const index_t M, + const index_t small_in_stride, + const OutputType0* __restrict small_in, + OutputType0 *small_out) { + using OType = AccType; + for (index_t idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { + typename OType::type val, residual; + REDUCER::SetInitValue(val, residual); + for (int k = 0; k < M; k++) { + REDUCER::Reduce(val, + OType::from(reinterpret_cast(small_in)[idx + k*small_in_stride]), + residual); + } + + if (idx < N) { + REDUCER::Finalize(val, residual); + if (req == OpReqType::kAddTo) { + small_out[idx] = OType::to(op::add(OType::from(small_out[idx]), val)); + } else { + small_out[idx] = OType::to(val); + } + } + + } 
+}
+)code";
+
+void RTCReduceImpl(Stream *s, const TBlob& small, const bool addto,
+                   const TBlob& big, const Tensor& workspace,
+                   const ReduceImplConfig& config, const int ndim,
+                   const std::string &common_code, int dev_id,
+                   const TBlob *lhs = nullptr, const TBlob *rhs = nullptr) {
+  using namespace common::cuda::rtc;
+  void* small_dptr = small.dptr_;
+  bool first_kernel_addto = addto;
+  if (config.Mnext > 1) {
+    // small_dptr[] is N*Mnext*sizeof(DType) bytes
+    small_dptr = workspace.dptr_;
+    first_kernel_addto = false;
+    // Check that the workspace is contiguous
+    CHECK_EQ(workspace.CheckContiguous(), true);
+    // Check that we have enough storage
+    CHECK_GE(workspace.size(0), config.workspace_size);
+  }
+
+  const int by = (config.kernel_1.do_transpose) ?
+    config.kernel_1.blockDim.x : config.kernel_1.blockDim.y;
+  const bool do_unroll = (config.M / (by*config.Mnext) >= unroll_reduce);
+  std::string code = common_code +
+    "#define UNROLL " +
+    (do_unroll ? std::to_string(unroll_reduce) : "1") +
+    "\n"
+    "const bool do_transpose = " +
+    (config.kernel_1.do_transpose ? "true" : "false") +
+    ";\n"
+    "using InputType0 = " +
+    common::mshadow_type_info(big.type_flag_).name +
+    ";\n"
+    "using OutputType0 = " +
+    common::mshadow_type_info(small.type_flag_).name +
+    ";\n"
+    "using InputType1 = " +
+    ((lhs != nullptr)
+     ? common::mshadow_type_info(lhs->type_flag_).name
+     : "float32") +
+    ";\n"
+    "using InputType2 = " +
+    ((rhs != nullptr)
+     ? common::mshadow_type_info(rhs->type_flag_).name
+     : "float32") +
+    ";\n";
+  if (lhs != nullptr) {
+    code += "const bool use_input = true;";
+  } else {
+    code += "const bool use_input = false;";
+  }
+
+  reduce_kernel_params param {};
+  for (int i = 0; i < ndim; ++i) {
+    param.big_shape[i] = big.shape_[i];
+    param.small_shape[i] = small.shape_[i];
+    param.rshape[i] = config.rshape[i];
+    param.rstride[i] = config.rstride[i];
+    if (lhs != nullptr) {
+      param.lhs_shape0[i] = lhs->shape_[i];
+      param.rhs_shape0[i] = rhs->shape_[i];
+      param.lhs_shape[i] = config.lhs_shape[i];
+      param.rhs_shape[i] = config.rhs_shape[i];
+      param.lhs_stride[i] = config.lhs_stride[i];
+      param.rhs_stride[i] = config.rhs_stride[i];
+    }
+  }
+
+  void *null_ptr = nullptr;
+  std::vector args;
+  args.emplace_back(&config.N);
+  args.emplace_back(&config.M);
+  args.emplace_back(&first_kernel_addto);
+  args.emplace_back(&big.dptr_);
+  if (lhs != nullptr) {
+    args.emplace_back(&(lhs->dptr_));
+    args.emplace_back(&(rhs->dptr_));
+  } else {
+    args.emplace_back(&(null_ptr));
+    args.emplace_back(&(null_ptr));
+  }
+  args.emplace_back(&small_dptr);
+  args.emplace_back(&param);
+  args.emplace_back(&config.Mnext);
+
+  const auto &function_code = (lhs == nullptr)
+                              ?
reduce_function_code + : reduce_function_use_input_code; + auto reduce_kernel_func = get_function(code + function_code, + "reduce_kernel", + reduce_kernel_code, + dev_id); + launch(reduce_kernel_func, config.kernel_1.gridDim, + config.kernel_1.blockDim, + config.kernel_1.shMemSize, s, &args); + + if (config.Mnext > 1) { + args.resize(0); + args.emplace_back(&config.N); + args.emplace_back(&config.Mnext); + args.emplace_back(&config.N); + args.emplace_back(&small_dptr); + args.emplace_back(&small.dptr_); + + auto reduce_lines_kernel_func = get_function(code, + "reduce_lines_kernel", + reduce_lines_kernel_code, + dev_id); + launch(reduce_lines_kernel_func, config.kernel_2.gridSize, + config.kernel_2.blockSize, 0, s, &args); + } +} + +struct reduce_kernel_M1_params { + index_t big_shape[MAX_DIM]; + index_t lhs_shape[MAX_DIM]; + index_t rhs_shape[MAX_DIM]; + index_t small_shape[MAX_DIM]; +}; + +const char reduce_kernel_M1_code[] = R"code( +struct reduce_kernel_M1_params { + index_t big_shape[util::MAX_DIM]; + index_t lhs_shape[util::MAX_DIM]; + index_t rhs_shape[util::MAX_DIM]; + index_t small_shape[util::MAX_DIM]; +}; + +__launch_bounds__(kRTCMaxThreadsPerBlock) +__global__ void reduce_kernel_M1(const int N, + const InputType0* __restrict big, + const InputType1* __restrict lhs, + const InputType2* __restrict rhs, + OutputType0 *small, + const reduce_kernel_M1_params params) { + using IType0 = AccType; + using IType1 = AccType; + using IType2 = AccType; + using OType = AccType; + for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { + index_t coord[ndim]; + util::unravel(idx, params.small_shape, coord); + index_t idx_big[1]; + idx_big[0] = util::ravel(coord, params.big_shape); + index_t idx_lhs[1], idx_rhs[1]; + if (use_input) { + idx_lhs[0] = util::ravel(coord, params.lhs_shape); + idx_rhs[0] = util::ravel(coord, params.rhs_shape); + } + typename OType::type val, residual; + REDUCER::SetInitValue(val, residual); + const int u = 0; + REDUCER::Reduce(val, FUNC, residual); + REDUCER::Finalize(val, residual); + if (req == OpReqType::kAddTo) { + const auto temp = op::add(val, OType::from(small[idx])); + small[idx] = OType::to(temp); + } else { + small[idx] = OType::to(val); + } + } +} +)code"; + +void RTCReduceM1Impl(Stream *s, const TBlob &small, const TBlob &big, + const TBlob *lhs, const TBlob *rhs, + const ReduceImplConfig &config, const int ndim, + const std::string &common_code, int dev_id) { + using namespace common::cuda::rtc; + + std::string code = common_code + + "using InputType0 = " + + common::mshadow_type_info(big.type_flag_).name + + ";\n" + "using InputType1 = " + + ((lhs != nullptr) + ? common::mshadow_type_info(lhs->type_flag_).name + : "float32") + + ";\n" + "using InputType2 = " + + ((rhs != nullptr) + ? 
common::mshadow_type_info(rhs->type_flag_).name
+     : "float32") +
+    ";\n"
+    "using OutputType0 = " +
+    common::mshadow_type_info(small.type_flag_).name +
+    ";\n";
+  if (lhs != nullptr) {
+    code += "const bool use_input = true;";
+  } else {
+    code += "const bool use_input = false;";
+  }
+
+  reduce_kernel_M1_params param {};
+  for (int i = 0; i < ndim; ++i) {
+    param.big_shape[i] = big.shape_[i];
+    param.small_shape[i] = small.shape_[i];
+    if (lhs != nullptr) {
+      param.lhs_shape[i] = lhs->shape_[i];
+      param.rhs_shape[i] = rhs->shape_[i];
+    }
+  }
+
+  void *null_ptr = nullptr;
+  std::vector args;
+  args.emplace_back(&config.N);
+  args.emplace_back(&big.dptr_);
+  if (lhs != nullptr) {
+    args.emplace_back(&(lhs->dptr_));
+    args.emplace_back(&(rhs->dptr_));
+  } else {
+    args.emplace_back(&(null_ptr));
+    args.emplace_back(&(null_ptr));
+  }
+  args.emplace_back(&small.dptr_);
+  args.emplace_back(&param);
+
+  const auto &function_code = (lhs == nullptr)
+                              ? reduce_function_code
+                              : reduce_function_use_input_code;
+  auto reduce_kernel_M1_func = get_function(code + function_code,
+                                            "reduce_kernel_M1",
+                                            reduce_kernel_M1_code,
+                                            dev_id);
+  launch(reduce_kernel_M1_func, config.kernel_1.gridDim,
+         config.kernel_1.blockDim,
+         0, s, &args);
+}
+
+}  // namespace
+
+void RTCReduce(const OpContext& ctx,
+               const TBlob& small,
+               const OpReqType req,
+               const Tensor& workspace,
+               const TBlob& big,
+               const std::string& reducer,
+               int ndim,
+               const std::string& OP) {
+  using namespace mxnet::common::cuda::rtc;
+  if (req == kNullOp) return;
+  Stream *s = ctx.get_stream();
+  size_t big_type_size = common::mshadow_type_info(big.type_flag_).acc_size;
+  size_t small_type_size = common::mshadow_type_info(small.type_flag_).acc_size;
+  size_t type_size = std::max(big_type_size, small_type_size);
+  ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, type_size);
+  std::string common_code = std::string("const OpReqType req = ") +
+    util::to_string(req) +
+    ";\n"
+    "#define OP op::" +
+    OP +
+    "\n"
+    "#define REDUCER " +
+    reducer +
+    "\n"
+    "const int ndim = " +
+    std::to_string(ndim) +
+    ";\n";
+  if (config.M == 1) {
+    RTCReduceM1Impl(s, small, big, nullptr, nullptr, config,
+                    ndim, common_code, ctx.run_ctx.ctx.dev_id);
+  } else {
+    RTCReduceImpl(s, small, req == kAddTo, big, workspace, config,
+                  ndim, common_code, ctx.run_ctx.ctx.dev_id);
+  }
+}
+
+void RTCReduce(const OpContext& ctx,
+               const TBlob& small,
+               const OpReqType req,
+               const Tensor& workspace,
+               const TBlob& big,
+               const TBlob &lhs,
+               const TBlob &rhs,
+               const std::string& reducer,
+               int ndim,
+               const std::string& OP1,
+               const std::string& OP2) {
+  using namespace mxnet::common::cuda::rtc;
+  if (req == kNullOp) return;
+  Stream *s = ctx.get_stream();
+  size_t big_type_size = common::mshadow_type_info(big.type_flag_).acc_size;
+  size_t small_type_size = common::mshadow_type_info(small.type_flag_).acc_size;
+  size_t type_size = std::max(big_type_size, small_type_size);
+  ReduceImplConfig config(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_, type_size);
+  std::string common_code = std::string("const OpReqType req = ") +
+    util::to_string(req) +
+    ";\n"
+    "#define OP1 op::" +
+    OP1 +
+    "\n"
+    "#define OP2 op::" +
+    OP2 +
+    "\n"
+    "#define REDUCER " +
+    reducer +
+    "\n"
+    "const int ndim = " +
+    std::to_string(ndim) +
+    ";\n";
+  if (config.M == 1) {
+    RTCReduceM1Impl(s, small, big, &lhs, &rhs, config, ndim, common_code, ctx.run_ctx.ctx.dev_id);
+  } else {
+    RTCReduceImpl(s, small, req == kAddTo, big, workspace, config,
+                  ndim, common_code,
ctx.run_ctx.ctx.dev_id, &lhs, &rhs); + } +} + +#endif // MXNET_USE_CUDA + +} // namespace broadcast +} // namespace op +} // namespace mxnet diff --git a/src/profiler/profiler.cc b/src/profiler/profiler.cc index 080d0454faff..107c171107d6 100644 --- a/src/profiler/profiler.cc +++ b/src/profiler/profiler.cc @@ -31,7 +31,7 @@ #include "./profiler.h" #if MXNET_USE_CUDA -#include "../common/cuda_utils.h" +#include "../common/cuda/utils.h" #endif #if defined(_MSC_VER) && _MSC_VER <= 1800 diff --git a/src/profiler/storage_profiler.cc b/src/profiler/storage_profiler.cc index edb16cf32337..5bbfa5917ea9 100644 --- a/src/profiler/storage_profiler.cc +++ b/src/profiler/storage_profiler.cc @@ -27,7 +27,7 @@ #include #include "./profiler.h" #include "../common/utils.h" -#include "../common/cuda_utils.h" +#include "../common/cuda/utils.h" namespace mxnet { namespace profiler { diff --git a/src/resource.cc b/src/resource.cc index 28e24e5c6984..f4f9da2f041a 100644 --- a/src/resource.cc +++ b/src/resource.cc @@ -34,7 +34,7 @@ #include #include "./common/lazy_alloc_array.h" #include "./common/utils.h" -#include "./common/cuda_utils.h" +#include "./common/cuda/utils.h" #include "./profiler/storage_profiler.h" namespace mxnet { diff --git a/src/storage/storage_manager_helpers.h b/src/storage/storage_manager_helpers.h index 1fccb5a08f45..14f9ea7727fc 100644 --- a/src/storage/storage_manager_helpers.h +++ b/src/storage/storage_manager_helpers.h @@ -22,7 +22,7 @@ #if MXNET_USE_CUDA #include -#include "../common/cuda_utils.h" +#include "../common/cuda/utils.h" #include "../profiler/storage_profiler.h" typedef mxnet::common::cuda::DeviceStore CudaDeviceStore; #endif // MXNET_USE_CUDA diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py index 1f261adcebac..6b98130503d5 100644 --- a/tests/python/gpu/test_fusion.py +++ b/tests/python/gpu/test_fusion.py @@ -110,6 +110,7 @@ def check_unary_ops(): 'gammaln', 'erf', 'negative', + 'logical_not', ] def announce_check(op_name): diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 93d3e7085148..a45f973a92fa 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -37,7 +37,7 @@ import random from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf from mxnet.numpy_op_signature import _get_builtin_op -from mxnet.test_utils import is_op_runnable, has_tvm_ops +from mxnet.test_utils import is_op_runnable, has_tvm_ops, rand_shape_2d from mxnet.operator import get_all_registered_operators @@ -10209,6 +10209,7 @@ def hybrid_forward(self, F, a, *args, **kwargs): assert same(mx_out.asnumpy(), np_out) +@with_seed() @use_np def test_npx_stop_gradient(): class TestStopGradient(HybridBlock): @@ -10236,3 +10237,81 @@ def hybrid_forward(self, F, a): elif grad_req == 'add': assert_almost_equal(new_grad, old_grad + 1) + +@with_seed() +@use_np +def test_np_elementwise_ops_on_misaligned_input(): + a = np.array([1,2,3,4], dtype='float16') + b = np.array([1,2,3,4], dtype='float16') + + c = a[1:3] + d = b[1:3] + # Note: testing just elemwise_add since all elemwise_ops + # share the implementation + c[:] = c + d + mx.nd.waitall() + + a = np.array([1,2,3,4], dtype='float16') + b = np.array([1,2,3,4], dtype='float16') + + c = a[0:3] + d = b[0:3] + c[:] = c + d + mx.nd.waitall() + assert a[3] == 4.0 + + +@with_seed() +@use_np +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64']) +@pytest.mark.parametrize('lead_dim', [2, 3, 4, 6, 10]) 
+@pytest.mark.parametrize('both_ways', [False, True]) +def test_np_broadcast_ops_on_misaligned_input(dtype, lead_dim, both_ways): + shape = list(rand_shape_2d()) + [lead_dim] + small_shape = [shape[0], 1, lead_dim] + if both_ways: + # Broadcast in both ways [1, K, L] x [M, 1, L] + big_shape = [1, shape[1], lead_dim] + else: + big_shape = shape + size = _np.product(shape) + small_size = _np.product(small_shape) + big_size = _np.product(big_shape) + a = np.arange(5000) + b = np.arange(5000) + e = np.arange(5000) + c = a[1:big_size + 1].reshape(tuple(big_shape)) + d = b[1:small_size + 1].reshape(tuple(small_shape)) + f = e[1:size + 1].reshape(tuple(shape)) + f[:] = c + d + expected = c.asnumpy() + d.asnumpy() + mx.nd.waitall() + assert_almost_equal(f, expected) + + +@with_seed() +@use_np +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64']) +@pytest.mark.parametrize('lead_dim', [2, 3, 4, 6, 10]) +@pytest.mark.parametrize('both_ways', [False, True]) +def test_np_broadcast_ops_on_misaligned_input_oneside(dtype, lead_dim, both_ways): + shape = list(rand_shape_2d()) + [lead_dim] + small_shape = [shape[0], shape[1], 1] + if both_ways: + # Broadcast in both ways [1, K, L] x [M, 1, 1] + big_shape = [1, shape[1], lead_dim] + else: + big_shape = shape + size = _np.product(shape) + small_size = _np.product(small_shape) + big_size = _np.product(big_shape) + a = np.arange(5000) + b = np.arange(5000) + e = np.arange(5000) + c = a[1:big_size + 1].reshape(tuple(big_shape)) + d = b[1:small_size + 1].reshape(tuple(small_shape)) + f = e[1:size + 1].reshape(tuple(shape)) + f[:] = c + d + expected = c.asnumpy() + d.asnumpy() + mx.nd.waitall() + assert_almost_equal(f, expected) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index a44ba327b3a1..6732336ff60c 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -9337,4 +9337,76 @@ def test_elemwise_sum_for_gradient_accumulation(): assert stored_grad['write'] == stored_grad['add'] assert stored_grad['write'] == 2 * nrepeat +@with_seed() +def test_elementwise_ops_on_misaligned_input(): + a = mx.nd.array([1,2,3,4], dtype='float16') + b = mx.nd.array([1,2,3,4], dtype='float16') + + c = a[1:3] + d = b[1:3] + # Note: testing just elemwise_add since all elemwise_ops + # share the implementation + mx.nd.elemwise_add(c, d, out=c) + mx.nd.waitall() + + a = mx.nd.array([1,2,3,4], dtype='float16') + b = mx.nd.array([1,2,3,4], dtype='float16') + + c = a[0:3] + d = b[0:3] + mx.nd.elemwise_add(c, d, out=c) + mx.nd.waitall() + assert a[3].asscalar() == 4.0 + +@with_seed() +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64']) +@pytest.mark.parametrize('lead_dim', [2, 3, 4, 6, 10]) +@pytest.mark.parametrize('both_ways', [False, True]) +def test_broadcast_ops_on_misaligned_input(dtype, lead_dim, both_ways): + shape = list(rand_shape_2d()) + [lead_dim] + small_shape = [shape[0], 1, lead_dim] + if both_ways: + # Broadcast in both ways [1, K, L] x [M, 1, L] + big_shape = [1, shape[1], lead_dim] + else: + big_shape = shape + size = np.product(shape) + small_size = np.product(small_shape) + big_size = np.product(big_shape) + a = mx.nd.arange(5000) + b = mx.nd.arange(5000) + e = mx.nd.arange(5000) + c = a[1:big_size + 1].reshape(big_shape) + d = b[1:small_size + 1].reshape(small_shape) + f = e[1:size + 1].reshape(shape) + mx.nd.broadcast_add(c, d, out=f) + expected = c.asnumpy() + d.asnumpy() + mx.nd.waitall() + assert_almost_equal(f, expected) + +@with_seed() 
+@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64']) +@pytest.mark.parametrize('lead_dim', [2, 3, 4, 6, 10]) +@pytest.mark.parametrize('both_ways', [False, True]) +def test_broadcast_ops_on_misaligned_input_oneside(dtype, lead_dim, both_ways): + shape = list(rand_shape_2d()) + [lead_dim] + small_shape = [shape[0], shape[1], 1] + if both_ways: + # Broadcast in both ways [1, K, L] x [M, 1, 1] + big_shape = [1, shape[1], lead_dim] + else: + big_shape = shape + size = np.product(shape) + small_size = np.product(small_shape) + big_size = np.product(big_shape) + a = mx.nd.arange(5000) + b = mx.nd.arange(5000) + e = mx.nd.arange(5000) + c = a[1:big_size + 1].reshape(big_shape) + d = b[1:small_size + 1].reshape(small_shape) + f = e[1:size + 1].reshape(shape) + mx.nd.broadcast_add(c, d, out=f) + expected = c.asnumpy() + d.asnumpy() + mx.nd.waitall() + assert_almost_equal(f, expected) diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 8bc086e14e52..970dd640f0c6 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -1853,152 +1853,6 @@ def check_broadcast_div(mx_lhs, mx_rhs, np_lhs, np_rhs, dtype): check_broadcast_mul(mx_lhs, mx_rhs, np_lhs, np_rhs, np.float32) check_broadcast_div(mx_lhs, mx_rhs, np_lhs, np_rhs, np.float32) -@with_seed() -def test_scatter_ops(): - def csr_get_seen_points(name, csr_array, verbose=False): - """Get a unique list of points int he CSR array as well as a - corresponding parallel list of points and values""" - seen_points = set() - seen_point_list = list() - values = list() - row_count = csr_array.shape[0] - row_pointers = csr_array.indptr.asnumpy() - col_indexes = csr_array.indices.asnumpy() - data = csr_array.data.asnumpy() - for row in range(row_count): - start_pos = row_pointers[row] - end_pos = row_pointers[row + 1] - for col_index in range(start_pos, end_pos): - col = col_indexes[col_index] - val = data[col_index] - if verbose is True: - print("{}: (row, col = ({}, {}) = {}".format(name, row, col, val)) - seen_points.add((row, col)) - seen_point_list.append((row, col)) - values.append(val) - return seen_points, values, seen_point_list - - def check_scatter_ops(name, shape, lhs_stype, rhs_stype, forward_mxnet_call, forward_numpy_call, - density=0.25, rhs_is_scalar=False, verbose=False): - lhs = mx.symbol.Variable('lhs', stype=lhs_stype) - if rhs_is_scalar is False: - rhs = mx.symbol.Variable('rhs', stype=rhs_stype) - - if verbose is True: - print(name) - - if lhs_stype != 'default': - lhs_nd = create_sparse_array_zd( - shape, lhs_stype, density=density, - rsp_indices=gen_rsp_random_indices( - shape, - density=density, - force_indices=[(shape[0]/2)] # force at least one overlap - )) - else: - lhs_nd = rand_ndarray(shape, 'default') - - if rhs_is_scalar is False: - if rhs_stype != 'default': - rhs_nd = create_sparse_array_zd( - shape, rhs_stype, density=density, - rsp_indices=gen_rsp_random_indices( - shape, - density=density, - force_indices=[(shape[0]/2)] # force at least one overlap - )) - else: - rhs_nd = rand_ndarray(shape, 'default') - else: - rhs_nd = 9 - rhs = rhs_nd - - lhs_np = lhs_nd.asnumpy() - rhs_np = rhs_nd if rhs_is_scalar is True else rhs_nd.asnumpy() - - if verbose is True: - print("lhs = {}".format(lhs_np)) - print("rhs = {}".format(rhs_np)) - - out_np = forward_numpy_call(lhs_np, rhs_np) - - if verbose is True: - print("Numpy: out_np = {}".format(out_np)) - - location = {'lhs': lhs_nd, 'rhs': rhs_nd} - - out = 
forward_mxnet_call(lhs, rhs) - exe_test = out._bind(default_context(), args=location) - exe_test.forward(is_train=False) - out_nd = exe_test.outputs[0] - - if verbose is True: - print("Sym: out_nd = {}".format(out_nd.asnumpy())) - - # For row_sparse, check that rows only exist for rows that are - # either int lhs or rhs, and if they exist, they should equal - # the numpy values - if lhs_stype == 'default': - almost_equal(out_nd.asnumpy(), out_np, equal_nan=True) - elif lhs_stype == 'row_sparse': - seen_rows = set() - indices = lhs_nd.indices.asnumpy() - for i in range(len(indices)): - seen_rows.add(indices[i]) - assert len(out_nd.indices.asnumpy()) == len(seen_rows) - out_nd_np = out_nd.asnumpy() - for row in seen_rows: - row_nd = out_nd_np[row] - row_np = out_np[row] - almost_equal(row_nd, row_np, equal_nan=True) - elif lhs_stype == 'csr' and rhs_is_scalar is False: - almost_equal(out_nd.asnumpy(), out_np, equal_nan=True) - else: - assert rhs_is_scalar - lhs_seen_points, _, _ = csr_get_seen_points("lhs", lhs_nd, verbose) - if rhs_is_scalar is False: - rhs_seen_points, _, _ = csr_get_seen_points("rhs", rhs_nd, verbose) - else: - rhs_seen_points = set() - input_seen_points = lhs_seen_points.union(rhs_seen_points) - out_seen_pounts, out_values, seen_point_list = csr_get_seen_points("out_nd", out_nd, verbose) - # Some may have been zero - assert len(out_seen_pounts) <= len(input_seen_points) - out_nd_np = out_nd.asnumpy() - val_index = 0 - for row_col in seen_point_list: - row = row_col[0] - col = row_col[1] - val = out_values[val_index] - val_np = out_nd_np[row, col] - almost_equal(val, val_np, equal_nan=True) - val_index += 1 - - shape = (10, 5) - - for lhs_stype in ['row_sparse', 'default', 'csr']: - for rhs_stype in ['row_sparse', 'default', 'csr']: - print("op: {}, lhs_stype: {}, rhs_stype: {}".format('_scatter_elemwise_div', - lhs_stype, rhs_stype)) - check_scatter_ops('_scatter_elemwise_div', shape, lhs_stype, rhs_stype, - lambda l, r: mx.sym._internal._scatter_elemwise_div(l, r), - lambda l, r: l / r, - verbose=False) - - for lhs_stype in ['row_sparse', 'default', 'csr']: - print("op: {}, lhs_stype: {}".format('_scatter_plus', lhs_stype)) - check_scatter_ops('_scatter_plus', shape, lhs_stype, 'scalar', - lambda l, r: mx.sym._internal._scatter_plus_scalar(l, r), - lambda l, r: l + r, - rhs_is_scalar=True, verbose=False) - - print("op: {}, lhs_stype: {}".format('_scatter_minus', lhs_stype)) - check_scatter_ops('_scatter_minus', shape, lhs_stype, 'scalar', - lambda l, r: mx.sym._internal._scatter_minus_scalar(l, r), - lambda l, r: l + r, - rhs_is_scalar=True, verbose=False, density=0.5) - - @with_seed() def test_batchnorm_fallback(): # same test as test_operator.test_batchnorm_training, but tests fallback logic of batchnorm
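
Note on the pattern used above: the GPU registrations in this change swap compile-time template instantiations for RTC functors that carry the operation as a string (UnaryRTCCompute{"relu"}, ElemwiseBinaryRTCCompute{"backward_relu"}), and reduce_rtc.cc builds its kernels the same way, by splicing the reducer, operator and type names into a source string that is compiled on first use. The sketch below only illustrates that general compile-at-runtime pattern with plain NVRTC and the CUDA driver API; it is not MXNet's common::cuda::rtc::get_function/launch implementation, the file name, kernel name and the choice of sqrtf as the spliced-in op are invented for the example, and it assumes a CUDA toolkit and device are available.

// rtc_sketch.cc -- minimal NVRTC compile-and-launch sketch.
// Build (one possibility): g++ rtc_sketch.cc -I$CUDA_HOME/include -lnvrtc -lcuda
// Error checking is omitted for brevity.
#include <cuda.h>
#include <nvrtc.h>
#include <cstdio>
#include <string>
#include <vector>

int main() {
  // The operation is chosen at runtime and spliced into the kernel source,
  // mirroring the "#define OP op::..." headers assembled by the RTC paths.
  const std::string op = "sqrtf";
  const std::string src =
    "extern \"C\" __global__ void unary_kernel(const float* in, float* out, int n) {\n"
    "  int i = blockIdx.x * blockDim.x + threadIdx.x;\n"
    "  if (i < n) out[i] = " + op + "(in[i]);\n"
    "}\n";

  // Compile the generated source to PTX with NVRTC.
  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, src.c_str(), "unary.cu", 0, nullptr, nullptr);
  nvrtcCompileProgram(prog, 0, nullptr);
  size_t ptx_size = 0;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::string ptx(ptx_size, '\0');
  nvrtcGetPTX(prog, &ptx[0]);
  nvrtcDestroyProgram(&prog);

  // Load the PTX and look up the kernel through the driver API.
  cuInit(0);
  CUdevice dev;  cuDeviceGet(&dev, 0);
  CUcontext ctx; cuCtxCreate(&ctx, 0, dev);
  CUmodule mod;  cuModuleLoadData(&mod, ptx.c_str());
  CUfunction fn; cuModuleGetFunction(&fn, mod, "unary_kernel");

  // Launch on a small buffer; arguments are passed by address, analogous to
  // the vector of pointers assembled in RTCReduceImpl above.
  int n = 4;
  std::vector<float> h_in = {1.f, 4.f, 9.f, 16.f}, h_out(n);
  CUdeviceptr d_in, d_out;
  cuMemAlloc(&d_in, n * sizeof(float));
  cuMemAlloc(&d_out, n * sizeof(float));
  cuMemcpyHtoD(d_in, h_in.data(), n * sizeof(float));
  void* args[] = {&d_in, &d_out, &n};
  cuLaunchKernel(fn, 1, 1, 1, 128, 1, 1, 0, nullptr, args, nullptr);
  cuCtxSynchronize();
  cuMemcpyDtoH(h_out.data(), d_out, n * sizeof(float));
  for (float v : h_out) std::printf("%g\n", v);  // expect 1 2 3 4

  cuMemFree(d_in);
  cuMemFree(d_out);
  cuModuleUnload(mod);
  cuCtxDestroy(ctx);
  return 0;
}

The trade-off this pattern buys is that only the kernels actually requested at runtime get compiled, at the cost of a one-time compilation step per unique kernel, instead of instantiating every (operator, dtype) combination at build time.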