From fe3439afca1fa5977dbb242b64fe99306f8f4a9a Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Tue, 10 Dec 2019 12:54:33 -0800 Subject: [PATCH] Workaround problem with fusion in CUDA 9 (#17028) (#17035) --- src/operator/fusion/fused_op-inl.h | 232 +++++++++++++++-------------- src/operator/fusion/fused_op.cu | 5 +- 2 files changed, 123 insertions(+), 114 deletions(-) diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h index e86ce7682ad8..7373cd07400a 100644 --- a/src/operator/fusion/fused_op-inl.h +++ b/src/operator/fusion/fused_op-inl.h @@ -256,22 +256,22 @@ struct LoadType { }; template -inline typename LoadType::Type load(const DType input) { +__device__ inline typename LoadType::Type load(const DType input) { return input; } template <> -inline float load(const half input) { +__device__ inline float load(const half input) { return __half2float(input); } template -inline DType1 store(const DType2 input, DType1* ref) { +__device__ inline DType1 store(const DType2 input, DType1* ref) { return input; } template -inline half store(const DType input, half* ref) { +__device__ inline half store(const DType input, half* ref) { return __float2half(input); } @@ -297,12 +297,12 @@ struct VectorConfig<32> { }; template -inline DType add_elem(const DType& x, const DType& y) { +__device__ inline DType add_elem(const DType& x, const DType& y) { return x + y; } template <> -inline half add_elem(const half& x, const half& y) { +__device__ inline half add_elem(const half& x, const half& y) { return __float2half(__half2float(x) + __half2float(y)); } @@ -310,14 +310,14 @@ template union VectorType { typename VectorConfig::IndexType y; DType x[nvec]; - VectorType () {}; - VectorType (const VectorType& y2) { + __device__ VectorType () {}; + __device__ VectorType (const VectorType& y2) { y = y2.y; } - VectorType (const decltype(y) &y2) { + __device__ VectorType (const decltype(y) &y2) { y = y2; } - inline VectorType& operator+=(const VectorType& rhs) { + __device__ inline VectorType& operator+=(const VectorType& rhs) { #pragma unroll for (int i = 0; i < nvec; ++i) { x[i] = add_elem(x[i], rhs.x[i]); @@ -330,13 +330,13 @@ template struct Shape { int x[ndim]; size_t size; - inline const int& operator [](const int i) const { + __device__ inline const int& operator [](const int i) const { return x[i]; } - inline int& operator [](const int i) { + __device__ inline int& operator [](const int i) { return x[i]; } - inline void set(const int def) { + __device__ inline void set(const int def) { #pragma unroll for (int i = 0; i < ndim; i++) { x[i] = def; @@ -350,7 +350,8 @@ struct Shape<0> { }; template -inline VectorType load_index(const DType * input, int i, const Shape &shape) { +__device__ inline VectorType load_index(const DType * input, int i, + const Shape &shape) { if (i < shape.size) { const auto* vector_input = reinterpret_cast< const typename VectorConfig::IndexType *>( @@ -364,7 +365,8 @@ inline VectorType load_index(const DType * input, int i, const Shap } template -inline VectorType global_load_index(const DType * input, int i, const Shape &shape) { +__device__ inline VectorType global_load_index(const DType * input, int i, + const Shape &shape) { if (i < shape.size) { const auto* vector_input = reinterpret_cast< const typename VectorConfig::IndexType *>( @@ -378,7 +380,9 @@ inline VectorType global_load_index(const DType * input, int i, con } template -inline VectorType load_slice(const DType * input, const Shape& shape, Shape begin, Shape end, int offset) { +__device__ inline VectorType load_slice(const DType * input, const Shape& shape, + Shape begin, Shape end, + int offset) { int idx[nvec]; Shape ref_strides; @@ -417,7 +421,11 @@ inline VectorType load_slice(const DType * input, const Shape } template -inline VectorType fast_load_slice(const DType * input, const Shape& shape, Shape begin, Shape end, int offset) { +__device__ inline VectorType fast_load_slice(const DType * input, + const Shape& shape, + Shape begin, + Shape end, + int offset) { int idx = 0; Shape ref_strides; @@ -447,7 +455,7 @@ inline VectorType fast_load_slice(const DType * input, const Shape< } template -inline void store_index(const VectorType value, int i, +__device__ inline void store_index(const VectorType value, int i, DType * output, const Shape& shape) { if (i < (shape.size + nvec - 1) / nvec) { auto vector_output = reinterpret_cast< @@ -457,7 +465,7 @@ inline void store_index(const VectorType value, int i, } template -inline void store_add_index(const VectorType value, int i, +__device__ inline void store_add_index(const VectorType value, int i, DType * output, const Shape& shape) { if (i < (shape.size + nvec - 1) / nvec) { auto vector_output = reinterpret_cast< @@ -469,116 +477,116 @@ inline void store_add_index(const VectorType value, int i, } template -inline DType identity(const DType val) { +__device__ inline DType identity(const DType val) { return val; } template -inline DType add(const DType a, const DType2 b) { +__device__ inline DType add(const DType a, const DType2 b) { return a + b; } template -inline DType sub(const DType a, const DType2 b) { +__device__ inline DType sub(const DType a, const DType2 b) { return a - b; } template -inline DType mul(const DType a, const DType2 b) { +__device__ inline DType mul(const DType a, const DType2 b) { return a * b; } template -inline DType div(const DType a, const DType2 b) { +__device__ inline DType div(const DType a, const DType2 b) { return a / b; } template -inline DType rdiv(const DType a, const DType2 b) { +__device__ inline DType rdiv(const DType a, const DType2 b) { return b / a; } template -inline DType power(const DType a, const DType2 b) { +__device__ inline DType power(const DType a, const DType2 b) { return powf(a, b); } template -inline DType rpow(const DType a, const DType2 b) { +__device__ inline DType rpow(const DType a, const DType2 b) { return powf(b, a); } template -inline DType max(const DType a, const DType2 b) { +__device__ inline DType max(const DType a, const DType2 b) { return a > b ? a : b; } template -inline DType min(const DType a, const DType2 b) { +__device__ inline DType min(const DType a, const DType2 b) { return a < b ? a : b; } template -inline DType hypot(const DType a, const DType2 b) { +__device__ inline DType hypot(const DType a, const DType2 b) { return hypotf(a, b); } template -inline typename LoadType::Type cast(const DType val) { +__device__ inline typename LoadType::Type cast(const DType val) { return static_cast::Type>(val); } // activations template -inline DType relu(const DType val) { +__device__ inline DType relu(const DType val) { return val > 0 ? val : 0; } template -inline DType sigmoid(const DType val) { +__device__ inline DType sigmoid(const DType val) { return 1.f/(1 + expf(-val)); } template -inline DType softrelu(const DType val) { +__device__ inline DType softrelu(const DType val) { return logf(1 + expf(val)); } template -inline DType softsign(const DType val) { +__device__ inline DType softsign(const DType val) { return val / (1 + fabsf(val)); } // exp and log template -inline DType exp(const DType val) { +__device__ inline DType exp(const DType val) { return expf(val); } template -inline DType expm1(const DType val) { +__device__ inline DType expm1(const DType val) { return expm1f(val); } template -inline DType log(const DType val) { +__device__ inline DType log(const DType val) { return logf(val); } template -inline DType log10(const DType val) { +__device__ inline DType log10(const DType val) { return log10f(val); } template -inline DType log2(const DType val) { +__device__ inline DType log2(const DType val) { return log2f(val); } template -inline DType log1p(const DType val) { +__device__ inline DType log1p(const DType val) { return log1pf(val); } @@ -587,197 +595,197 @@ inline DType log1p(const DType val) { constexpr double pi = 3.14159265358979323846; template -inline DType degrees(const DType val) { +__device__ inline DType degrees(const DType val) { return (val / pi) * 180; } template -inline DType radians(const DType val) { +__device__ inline DType radians(const DType val) { return (val / 180.0) * pi; } template -inline DType sin(const DType val) { +__device__ inline DType sin(const DType val) { return sinf(val); } template -inline DType cos(const DType val) { +__device__ inline DType cos(const DType val) { return cosf(val); } template -inline DType tan(const DType val) { +__device__ inline DType tan(const DType val) { return tanf(val); } template -inline DType arcsin(const DType val) { +__device__ inline DType arcsin(const DType val) { return asinf(val); } template -inline DType arccos(const DType val) { +__device__ inline DType arccos(const DType val) { return acosf(val); } template -inline DType arctan(const DType val) { +__device__ inline DType arctan(const DType val) { return atanf(val); } template -inline DType sinh(const DType val) { +__device__ inline DType sinh(const DType val) { return sinhf(val); } template -inline DType cosh(const DType val) { +__device__ inline DType cosh(const DType val) { return coshf(val); } template -inline DType tanh(const DType val) { +__device__ inline DType tanh(const DType val) { return tanhf(val); } template -inline DType arcsinh(const DType val) { +__device__ inline DType arcsinh(const DType val) { return asinhf(val); } template -inline DType arccosh(const DType val) { +__device__ inline DType arccosh(const DType val) { return acoshf(val); } template -inline DType arctanh(const DType val) { +__device__ inline DType arctanh(const DType val) { return atanhf(val); } // sqrt template -inline DType sqrt(const DType val) { +__device__ inline DType sqrt(const DType val) { return sqrtf(val); } template -inline DType rsqrt(const DType val) { +__device__ inline DType rsqrt(const DType val) { return rsqrtf(val); } template -inline DType cbrt(const DType val) { +__device__ inline DType cbrt(const DType val) { return cbrtf(val); } template -inline DType rcbrt(const DType val) { +__device__ inline DType rcbrt(const DType val) { return rcbrtf(val); } template -inline DType square(const DType val) { +__device__ inline DType square(const DType val) { return val * val; } template -inline typename LoadType::Type zero(const DType val) { +__device__ inline typename LoadType::Type zero(const DType val) { return 0; } template -inline typename LoadType::Type zero() { +__device__ inline typename LoadType::Type zero() { return 0; } template -inline typename LoadType::Type one(const DType val) { +__device__ inline typename LoadType::Type one(const DType val) { return 1; } template -inline typename LoadType::Type one() { +__device__ inline typename LoadType::Type one() { return 1; } template -inline DType round(const DType val) { +__device__ inline DType round(const DType val) { return roundf(val); } template -inline DType rint(const DType val) { +__device__ inline DType rint(const DType val) { return rintf(val); } template -inline DType fix(const DType val) { +__device__ inline DType fix(const DType val) { const auto floor = floorf(val); const auto ceil = ceilf(val); return (floor > 0 ? floor : -floor) < (ceil > 0 ? ceil : -ceil) ? floor : ceil; } template -inline DType floor(const DType val) { +__device__ inline DType floor(const DType val) { return floorf(val); } template -inline DType ceil(const DType val) { +__device__ inline DType ceil(const DType val) { return ceilf(val); } template -inline DType trunc(const DType val) { +__device__ inline DType trunc(const DType val) { return truncf(val); } template -inline DType clip(const DType val, const float a_min, const float a_max) { +__device__ inline DType clip(const DType val, const float a_min, const float a_max) { return max(min(val, a_max), a_min); } template -inline DType sign(const DType val) { +__device__ inline DType sign(const DType val) { if (val < 0) return -1; return val > 0 ? 1 : 0; } template -inline DType reciprocal(const DType val) { +__device__ inline DType reciprocal(const DType val) { return 1.0f / val; } template -inline DType abs(const DType val) { +__device__ inline DType abs(const DType val) { return fabsf(val); } template -inline DType gamma(const DType val) { +__device__ inline DType gamma(const DType val) { return tgammaf(val); } template -inline DType gammaln(const DType val) { +__device__ inline DType gammaln(const DType val) { return lgammaf(val); } template -inline DType erf(const DType val) { +__device__ inline DType erf(const DType val) { return erff(val); } template -inline DType erfinv(const DType val) { +__device__ inline DType erfinv(const DType val) { return erfinvf(val); } template -inline DType1 smooth_l1(const DType1 val, const DType2 scalar) { +__device__ inline DType1 smooth_l1(const DType1 val, const DType2 scalar) { const auto bsq = scalar * scalar; const auto ibsq = 1.0f / bsq; if (val > ibsq) { @@ -798,147 +806,148 @@ const char backward_function_definitions[] = R"code( namespace op { template -inline DTypeGrad backward_relu(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_relu(const DType val, const DTypeGrad grad) { return val > 0 ? grad : 0; } template -inline DTypeGrad backward_sigmoid(const DType out, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_sigmoid(const DType out, const DTypeGrad grad) { return grad * out * (1 - out); } template -inline DTypeGrad backward_softrelu(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_softrelu(const DType val, const DTypeGrad grad) { return grad * sigmoid(val); } template -inline DTypeGrad backward_softsign(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_softsign(const DType val, const DTypeGrad grad) { const DType ap1 = 1 + fabsf(val); return grad / (ap1 * ap1); } template -inline DTypeGrad backward_exp(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_exp(const DType val, const DTypeGrad grad) { return grad * expf(val); } template -inline DTypeGrad backward_expm1(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_expm1(const DType val, const DTypeGrad grad) { return grad * expf(val); } template -inline DTypeGrad backward_log(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_log(const DType val, const DTypeGrad grad) { return grad / val; } template -inline DTypeGrad backward_log10(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_log10(const DType val, const DTypeGrad grad) { return grad / (val * logf(10)); } template -inline DTypeGrad backward_log2(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_log2(const DType val, const DTypeGrad grad) { return grad / (val * logf(2)); } template -inline DTypeGrad backward_log1p(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_log1p(const DType val, const DTypeGrad grad) { return grad / (1 + val); } template -inline DTypeGrad backward_sin(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_sin(const DType val, const DTypeGrad grad) { return grad * cosf(val); } template -inline DTypeGrad backward_cos(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_cos(const DType val, const DTypeGrad grad) { return -grad * sinf(val); } // Uses output from tan template -inline DTypeGrad backward_tan(const DType out, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_tan(const DType out, const DTypeGrad grad) { return grad * (out * out + 1); } template -inline DTypeGrad backward_arcsin(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_arcsin(const DType val, const DTypeGrad grad) { return grad / sqrtf(1 - val*val); } template -inline DTypeGrad backward_arccos(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_arccos(const DType val, const DTypeGrad grad) { return -grad / sqrtf(1 - val*val); } template -inline DTypeGrad backward_arctan(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_arctan(const DType val, const DTypeGrad grad) { return grad / (1 + val*val); } template -inline DTypeGrad backward_sinh(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_sinh(const DType val, const DTypeGrad grad) { return grad * coshf(val); } template -inline DTypeGrad backward_cosh(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_cosh(const DType val, const DTypeGrad grad) { return grad * sinhf(val); } // Uses tanh output template -inline DTypeGrad backward_tanh(const DType out, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_tanh(const DType out, const DTypeGrad grad) { return grad * (1 - out * out); } template -inline DTypeGrad backward_arcsinh(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_arcsinh(const DType val, const DTypeGrad grad) { return grad / sqrtf(val * val + 1); } template -inline DTypeGrad backward_arccosh(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_arccosh(const DType val, const DTypeGrad grad) { return grad / sqrtf(val * val - 1); } template -inline DTypeGrad backward_arctanh(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_arctanh(const DType val, const DTypeGrad grad) { return grad / (1 - val * val); } template -inline DTypeGrad backward_sqrt(const DType out, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_sqrt(const DType out, const DTypeGrad grad) { return 0.5 * grad / out; } template -inline DTypeGrad backward_rsqrt(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_rsqrt(const DType val, const DTypeGrad grad) { const DType inv = 1 / val; return -0.5 * grad * sqrtf(inv) * inv; } template -inline DTypeGrad backward_cbrt(const DType out, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_cbrt(const DType out, const DTypeGrad grad) { return grad / (3.0f * out * out); } template -inline DTypeGrad backward_rcbrt(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_rcbrt(const DType val, const DTypeGrad grad) { const DType inv = 1 / val; return -1.f/3.f * grad * cbrtf(inv) * inv; } template -inline DTypeGrad backward_square(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_square(const DType val, const DTypeGrad grad) { return 2 * val * grad; } template -inline DTypeGrad backward_clip(const DType val, const DTypeGrad grad, const float a_min, const float a_max) { +__device__ inline DTypeGrad backward_clip(const DType val, const DTypeGrad grad, + const float a_min, const float a_max) { if (val > a_max || val < a_min) { return 0; } else { @@ -947,22 +956,23 @@ inline DTypeGrad backward_clip(const DType val, const DTypeGrad grad, const floa } template -inline DTypeGrad backward_reciprocal(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_reciprocal(const DType val, const DTypeGrad grad) { return -grad / (val * val); } template -inline DTypeGrad backward_erf(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_erf(const DType val, const DTypeGrad grad) { return 2.0f / sqrt(pi) * exp(-(val*val)) * grad; } template -inline DTypeGrad backward_erfinv(const DType val, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_erfinv(const DType val, const DTypeGrad grad) { return 0.5f * sqrt(pi) * exp(val * val) * grad; } template -inline DTypeGrad backward_smooth_l1(const DType val, const DType2 scalar, const DTypeGrad grad) { +__device__ inline DTypeGrad backward_smooth_l1(const DType val, const DType2 scalar, + const DTypeGrad grad) { auto bsq = scalar * scalar; auto ibsq = 1.0f / bsq; if (val > ibsq) { diff --git a/src/operator/fusion/fused_op.cu b/src/operator/fusion/fused_op.cu index 34360544cd11..c8a888301abd 100644 --- a/src/operator/fusion/fused_op.cu +++ b/src/operator/fusion/fused_op.cu @@ -594,13 +594,12 @@ CUfunction FusedOp::CompileCode(const std::string &code, std::string gpu_arch_arg = "--gpu-architecture=compute_" + std::to_string(sm_arch); const char *opts[] = {gpu_arch_arg.c_str(), - "--std=c++11", - "-default-device"}; + "--std=c++11"}; const std::string kernel_name_demangled = "FusedKernel_" + kernel_name; NVRTC_CALL(nvrtcAddNameExpression(program, (kernel_name_demangled).c_str())); nvrtcResult compileResult = nvrtcCompileProgram(program, // prog - 3, // num options + 2, // num options opts); // options CHECK_EQ(compileResult, NVRTC_SUCCESS) << "NVRTC Compilation failed. Please set environment variable MXNET_USE_FUSION to 0.\n"