This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add gelu fuse ops #18082

Merged
6 commits merged on Apr 17, 2020
Changes from 4 commits
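
Not part of the diff — a minimal usage sketch of the pattern this change makes fusable. It assumes a GPU build with pointwise fusion enabled (MXNET_USE_FUSION=1) and the 1.x-style symbol/executor API (simple_bind); the gelu pattern itself mirrors the new test added below.

import mxnet as mx

# Elementwise add followed by the erf-based GELU exposed through
# LeakyReLU(act_type='gelu'); with fusion enabled, the pass can now
# compile this into a single generated pointwise kernel.
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
sym = mx.sym.LeakyReLU(a + b, act_type='gelu')

exe = sym.simple_bind(ctx=mx.gpu(0), a=(2, 3), b=(2, 3))
out = exe.forward(is_train=False,
                  a=mx.nd.random.uniform(shape=(2, 3), ctx=mx.gpu(0)),
                  b=mx.nd.random.uniform(shape=(2, 3), ctx=mx.gpu(0)))
print(out[0].asnumpy())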
14 changes: 14 additions & 0 deletions src/executor/pointwise_fusion_pass.cc
@@ -71,6 +71,20 @@ namespace {
op_name) !=
variable_io_ops.end())
return true;
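// LeakyReLU covers several activations behind a single op; only the
// act_types listed in LeakyReLU_ops / LeakyReLU_bwd_ops (currently just
// "gelu") have fused implementations, so fusion is gated on the attribute.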
if (op_name == "LeakyReLU") {
std::string act_type = n->attrs.dict.at("act_type");
if (LeakyReLU_ops.count(act_type))
return true;
else
return false;
}
if (op_name == "_backward_LeakyReLU") {
std::string act_type = n->attrs.dict.at("act_type");
if (LeakyReLU_bwd_ops.count(act_type))
return true;
else
return false;
}
return false;
}

22 changes: 22 additions & 0 deletions src/operator/fusion/fused_op-inl.h
@@ -224,6 +224,15 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
{"(% * % / op::hypot(%, %))", "_0", "_2", "_1", "_2"}}}
};

// LeakyReLU ops: based on "act_type" attribute
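// Each entry follows the ops_desc convention above: "%" placeholders are
// substituted with the generated variable names of the node's inputs,
// referenced by index as "_0", "_1", ...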
const std::map<std::string, std::vector<std::vector<std::string>>> LeakyReLU_ops = {
{"gelu" , {{"op::gelu(%)", "_0"}}},
};
const std::map<std::string, std::vector<std::vector<std::string>>> LeakyReLU_bwd_ops = {
{"gelu" , {{"op::backward_gelu(%, %)", "_1", "_0"}}},
};


const std::map<std::string, std::string> slice_ops = {
{"slice_axis" , ""},
{"slice" , ""},
@@ -543,6 +552,13 @@ __device__ inline DType relu(const DType val) {
return val > 0 ? val : 0;
}

__constant__ const float SQRT_2 = 1.4142135623730950488016887242096;
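// Erf-based (exact) GELU: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).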
template <typename DType>
__device__ inline DType gelu(const DType val) {
return DType(0.5f * static_cast<float>(val) *
(1.0f + erf(static_cast<float>(val) / SQRT_2)));
}

template <typename DType>
__device__ inline DType sigmoid(const DType val) {
return 1.f/(1 + expf(-val));
@@ -987,6 +1003,12 @@ __device__ inline DTypeGrad backward_smooth_l1(const DType val, const DType2 sca
}
}

template <typename DType, typename DTypeGrad>
__device__ inline DTypeGrad backward_gelu(const DType val, const DTypeGrad grad) {
return grad * DType(0.5f * (1.0f + erf(static_cast<float>(val) / SQRT_2) +
static_cast<float>(val) * backward_erf(static_cast<float>(val) / SQRT_2, 1.0f) / SQRT_2));
}
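// The expression above is d/dx gelu(x) = 0.5 * (1 + erf(x / sqrt(2))) +
// x * exp(-x^2 / 2) / sqrt(2 * pi); the exp term is obtained through
// backward_erf(val / SQRT_2, 1.0f) / SQRT_2.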

} // namespace op

)code";
36 changes: 36 additions & 0 deletions src/operator/fusion/fused_op.cu
@@ -460,6 +460,42 @@ std::string FusedOp::GenerateCode(const std::vector<OpReqType> &req,
continue;
}

// LeakyReLU, look for act_type
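// Expand the descriptor registered for this act_type into generated source,
// emitting one temporary variable per output, mirroring the generic ops_desc
// handling above.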
if (op_name == "LeakyReLU") {
std::string act_type = node.source->attrs.dict.at("act_type");
if (fusion::LeakyReLU_ops.find(act_type) != fusion::LeakyReLU_ops.end()) {
const std::vector<std::vector<std::string>>& op_descs =
fusion::LeakyReLU_ops.at(act_type);
CHECK_EQ(outputs[i], op_descs.size());
size_t count = 0;
for (const auto& op_desc : op_descs) {
var_name = "temp" + std::to_string(temp_name_counter++);
const std::string& fmt = ParseOpDescription(op_desc, variables, node);
code += "const auto " + var_name + " = " + fmt + ";\n";
variables[{i, count}] = var_name;
++count;
}
continue;
}
}
if (op_name == "_backward_LeakyReLU") {
std::string act_type = node.source->attrs.dict.at("act_type");
if (fusion::LeakyReLU_bwd_ops.find(act_type) != fusion::LeakyReLU_bwd_ops.end()) {
const std::vector<std::vector<std::string>>& op_descs =
fusion::LeakyReLU_bwd_ops.at(act_type);
CHECK_EQ(outputs[i], op_descs.size());
size_t count = 0;
for (const auto& op_desc : op_descs) {
var_name = "temp" + std::to_string(temp_name_counter++);
const std::string& fmt = ParseOpDescription(op_desc, variables, node);
code += "const auto " + var_name + " = " + fmt + ";\n";
variables[{i, count}] = var_name;
++count;
}
continue;
}
}

LOG(FATAL) << "Unrecognized op " + op_name;
}
} else {
13 changes: 13 additions & 0 deletions tests/python/gpu/test_fusion.py
@@ -230,11 +230,24 @@ def check_other_ops():
arr2 = mx.random.uniform(shape=(2,2,2,3))
check_fused_symbol(mx.sym.broadcast_like(a, b, lhs_axes=[0], rhs_axes=[0]), a=arr1, b=arr2)

def check_leakyrelu_ops():
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
shape = rand_shape_2d()
arr1 = mx.random.uniform(shape=shape)
arr2 = mx.random.uniform(shape=shape)

# Testing gelu
print("Checking fusion of LeakyReLU:gelu")
check_fused_symbol(mx.sym.LeakyReLU(a+b, act_type='gelu'), a=arr1, b=arr2)
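# check_fused_symbol (defined earlier in this file) is expected to run the
# symbol with fusion enabled and disabled and compare the results.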


@with_seed()
def test_fusion():
check_unary_ops()
check_binary_ops()
check_other_ops()
check_leakyrelu_ops()

@with_seed()
def test_fusion_compiler_cache():