diff --git a/src/executor/pointwise_fusion_pass.cc b/src/executor/pointwise_fusion_pass.cc
index 5db9706b4f99..3203f67e8b68 100644
--- a/src/executor/pointwise_fusion_pass.cc
+++ b/src/executor/pointwise_fusion_pass.cc
@@ -71,6 +71,20 @@ namespace {
                   op_name) != variable_io_ops.end())
       return true;
+    if (op_name == "LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_ops.count(act_type))
+        return true;
+      else
+        return false;
+    }
+    if (op_name == "_backward_LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_bwd_ops.count(act_type))
+        return true;
+      else
+        return false;
+    }
     return false;
   }
diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h
index 7373cd07400a..005ea4d48390 100644
--- a/src/operator/fusion/fused_op-inl.h
+++ b/src/operator/fusion/fused_op-inl.h
@@ -224,6 +224,14 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
                         {"(% * % / op::hypot(%, %))", "_0", "_2", "_1", "_2"}}}
 };
 
+// LeakyReLU ops: based on "act_type" attribute
+const std::map<std::string, std::vector<std::vector<std::string>>> LeakyReLU_ops = {
+  {"gelu" , {{"op::gelu(%)", "_0"}}},
+};
+const std::map<std::string, std::vector<std::vector<std::string>>> LeakyReLU_bwd_ops = {
+  {"gelu" , {{"op::backward_gelu(%, %)", "_1", "_0"}}},
+};
+
 const std::map<std::string, std::string> slice_ops = {
   {"slice_axis" , ""},
   {"slice" , ""},
@@ -543,6 +551,14 @@ __device__ inline DType relu(const DType val) {
   return val > 0 ? val : 0;
 }
 
+const float SQRT_2 = 1.4142135623730950488016887242096;
+// compatible with mshadow_op.h version
+template <typename DType>
+__device__ inline DType gelu(const DType val) {
+  return DType(0.5f * static_cast<float>(val) *
+               (1.0f + erf(static_cast<float>(val) / SQRT_2)));
+}
+
 template <typename DType>
 __device__ inline DType sigmoid(const DType val) {
   return 1.f/(1 + expf(-val));
 }
@@ -984,6 +1000,13 @@ __device__ inline DTypeGrad backward_smooth_l1(const DType val, const DType2 sca
   }
 }
 
+// compatible with mshadow_op.h version
+template <typename DType, typename DTypeGrad>
+__device__ inline DTypeGrad backward_gelu(const DType val, const DTypeGrad grad) {
+  return grad * DType(0.5f * (1.0f + erf(static_cast<float>(val) / SQRT_2) +
+         static_cast<float>(val) * backward_erf(static_cast<float>(val) / SQRT_2, 1.0f) / SQRT_2));
+}
+
 }  // namespace op
 
 )code";
diff --git a/src/operator/fusion/fused_op.cu b/src/operator/fusion/fused_op.cu
index 544dd0221c17..f883c5c4b726 100644
--- a/src/operator/fusion/fused_op.cu
+++ b/src/operator/fusion/fused_op.cu
@@ -453,6 +453,42 @@ std::string FusedOp::GenerateCode(const std::vector<OpReqType> &req,
           continue;
         }
 
+        // LeakyReLU, look for act_type
+        if (op_name == "LeakyReLU") {
+          std::string act_type = node.source->attrs.dict.at("act_type");
+          if (fusion::LeakyReLU_ops.find(act_type) != fusion::LeakyReLU_ops.end()) {
+            const std::vector<std::vector<std::string>>& op_descs =
+                fusion::LeakyReLU_ops.at(act_type);
+            CHECK_EQ(outputs[i], op_descs.size());
+            size_t count = 0;
+            for (const auto& op_desc : op_descs) {
+              var_name = "temp" + std::to_string(temp_name_counter++);
+              const std::string& fmt = ParseOpDescription(op_desc, variables, node);
+              code += "const auto " + var_name + " = " + fmt + ";\n";
+              variables[{i, count}] = var_name;
+              ++count;
+            }
+            continue;
+          }
+        }
+        if (op_name == "_backward_LeakyReLU") {
+          std::string act_type = node.source->attrs.dict.at("act_type");
+          if (fusion::LeakyReLU_bwd_ops.find(act_type) != fusion::LeakyReLU_bwd_ops.end()) {
+            const std::vector<std::vector<std::string>>& op_descs =
+                fusion::LeakyReLU_bwd_ops.at(act_type);
+            CHECK_EQ(outputs[i], op_descs.size());
+            size_t count = 0;
+            for (const auto& op_desc : op_descs) {
+              var_name = "temp" + std::to_string(temp_name_counter++);
+              const std::string& fmt = ParseOpDescription(op_desc, variables, node);
+              code += "const auto " + var_name + " = " + fmt + ";\n";
+              variables[{i, count}] = var_name;
+              ++count;
+            }
+            continue;
+          }
+        }
+
         LOG(FATAL) << "Unrecognized op " + op_name;
       }
     } else {
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 9a37c0e844a0..f69d50c0d53f 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -213,11 +213,24 @@ def check_other_ops():
     arr2 = mx.random.uniform(shape=(2,2,2,3))
     check_fused_symbol(mx.sym.broadcast_like(a, b, lhs_axes=[0], rhs_axes=[0]), a=arr1, b=arr2)
 
+def check_leakyrelu_ops():
+    a = mx.sym.Variable('a')
+    b = mx.sym.Variable('b')
+    shape = rand_shape_2d()
+    arr1 = mx.random.uniform(shape=shape)
+    arr2 = mx.random.uniform(shape=shape)
+
+    # Testing gelu
+    print("Checking fusion of LeakyReLU:gelu")
+    check_fused_symbol(mx.sym.LeakyReLU(a+b, act_type='gelu'), a=arr1, b=arr2)
+
+
 @with_seed()
 def test_fusion():
     check_unary_ops()
     check_binary_ops()
     check_other_ops()
+    check_leakyrelu_ops()
 
 @with_seed()
 def test_fusion_compiler_cache():
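
A minimal standalone sketch, not part of the patch, of the math that the generated op::gelu and op::backward_gelu device strings implement. It assumes only NumPy/SciPy and that backward_erf(z, 1.0f) in the fused-op header evaluates the erf derivative 2/sqrt(pi)*exp(-z*z); the helper names below are illustrative, not MXNet APIs.

# Reference check of the GELU forward/backward formulas used in the fused kernel strings.
# Hypothetical helpers; only numpy and scipy are assumed to be installed.
import numpy as np
from scipy.special import erf

SQRT_2 = np.sqrt(2.0)

def gelu_ref(x):
    # Same expression as op::gelu: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + erf(x / SQRT_2))

def backward_gelu_ref(x, grad):
    # Same expression as op::backward_gelu, assuming backward_erf(z, 1) is the
    # erf derivative 2/sqrt(pi) * exp(-z*z), evaluated here at z = x / sqrt(2).
    derf = 2.0 / np.sqrt(np.pi) * np.exp(-(x / SQRT_2) ** 2)
    return grad * 0.5 * (1.0 + erf(x / SQRT_2) + x * derf / SQRT_2)

x = np.random.uniform(-3.0, 3.0, size=(2, 3))
eps = 1e-5
numeric = (gelu_ref(x + eps) - gelu_ref(x - eps)) / (2.0 * eps)
assert np.allclose(backward_gelu_ref(x, np.ones_like(x)), numeric, atol=1e-6)

The finite-difference check passing confirms that the gradient expression wired into LeakyReLU_bwd_ops is the analytic derivative of the forward expression in LeakyReLU_ops.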