This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add gelu fuse ops (#18082) #18092

Merged
merged 1 commit on Apr 18, 2020
14 changes: 14 additions & 0 deletions src/executor/pointwise_fusion_pass.cc
@@ -71,6 +71,20 @@ namespace {
op_name) !=
variable_io_ops.end())
return true;
if (op_name == "LeakyReLU") {
std::string act_type = n->attrs.dict.at("act_type");
if (LeakyReLU_ops.count(act_type))
return true;
else
return false;
}
if (op_name == "_backward_LeakyReLU") {
std::string act_type = n->attrs.dict.at("act_type");
if (LeakyReLU_bwd_ops.count(act_type))
return true;
else
return false;
}
return false;
}

23 changes: 23 additions & 0 deletions src/operator/fusion/fused_op-inl.h
@@ -224,6 +224,14 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
{"(% * % / op::hypot(%, %))", "_0", "_2", "_1", "_2"}}}
};

// LeakyReLU ops: based on "act_type" attribute
const std::map<std::string, std::vector<std::vector<std::string>>> LeakyReLU_ops = {
{"gelu" , {{"op::gelu(%)", "_0"}}},
};
const std::map<std::string, std::vector<std::vector<std::string>>> LeakyReLU_bwd_ops = {
{"gelu" , {{"op::backward_gelu(%, %)", "_1", "_0"}}},
};

const std::map<std::string, std::string> slice_ops = {
{"slice_axis" , ""},
{"slice" , ""},
@@ -543,6 +551,14 @@ __device__ inline DType relu(const DType val) {
return val > 0 ? val : 0;
}

const float SQRT_2 = 1.4142135623730950488016887242096;
// compatible with mshadow_op.h version
template <typename DType>
__device__ inline DType gelu(const DType val) {
return DType(0.5f * static_cast<float>(val) *
(1.0f + erf(static_cast<float>(val) / SQRT_2)));
}

template <typename DType>
__device__ inline DType sigmoid(const DType val) {
return 1.f/(1 + expf(-val));
@@ -984,6 +1000,13 @@ __device__ inline DTypeGrad backward_smooth_l1(const DType val, const DType2 sca
}
}

// compatible with mshadow_op.h version
template <typename DType, typename DTypeGrad>
__device__ inline DTypeGrad backward_gelu(const DType val, const DTypeGrad grad) {
return grad * DType(0.5f * (1.0f + erf(static_cast<float>(val) / SQRT_2) +
static_cast<float>(val) * backward_erf(static_cast<float>(val) / SQRT_2, 1.0f) / SQRT_2));
}

} // namespace op

)code";
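For reference, the erf-based GELU that these device functions implement (kept compatible with the mshadow_op.h versions, as the comments note) and its derivative are, taking backward_erf(y, g) to be g \cdot \tfrac{2}{\sqrt{\pi}} e^{-y^{2}}:

\mathrm{gelu}(x) = \tfrac{1}{2}\, x \left(1 + \operatorname{erf}\left(\tfrac{x}{\sqrt{2}}\right)\right),
\qquad
\frac{d}{dx}\,\mathrm{gelu}(x) = \tfrac{1}{2}\left(1 + \operatorname{erf}\left(\tfrac{x}{\sqrt{2}}\right)\right) + \frac{x}{\sqrt{2\pi}}\, e^{-x^{2}/2}.

backward_gelu above is exactly the incoming gradient grad multiplied by this derivative evaluated at val.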
36 changes: 36 additions & 0 deletions src/operator/fusion/fused_op.cu
@@ -453,6 +453,42 @@ std::string FusedOp::GenerateCode(const std::vector<OpReqType> &req,
continue;
}

// LeakyReLU, look for act_type
if (op_name == "LeakyReLU") {
std::string act_type = node.source->attrs.dict.at("act_type");
const auto op_descs_it = fusion::LeakyReLU_ops.find(act_type);  // check support before dereferencing so unsupported act_types reach LOG(FATAL) below
if (op_descs_it != fusion::LeakyReLU_ops.end()) {
const std::vector<std::vector<std::string>>& op_descs = op_descs_it->second;
CHECK_EQ(outputs[i], op_descs.size());
size_t count = 0;
for (const auto& op_desc : op_descs) {
var_name = "temp" + std::to_string(temp_name_counter++);
const std::string& fmt = ParseOpDescription(op_desc, variables, node);
code += "const auto " + var_name + " = " + fmt + ";\n";
variables[{i, count}] = var_name;
++count;
}
continue;
}
}
if (op_name == "_backward_LeakyReLU") {
std::string act_type = node.source->attrs.dict.at("act_type");
const auto op_descs_it = fusion::LeakyReLU_bwd_ops.find(act_type);  // note: look up and compare against the bwd map, not LeakyReLU_ops
if (op_descs_it != fusion::LeakyReLU_bwd_ops.end()) {
const std::vector<std::vector<std::string>>& op_descs = op_descs_it->second;
CHECK_EQ(outputs[i], op_descs.size());
size_t count = 0;
for (const auto& op_desc : op_descs) {
var_name = "temp" + std::to_string(temp_name_counter++);
const std::string& fmt = ParseOpDescription(op_desc, variables, node);
code += "const auto " + var_name + " = " + fmt + ";\n";
variables[{i, count}] = var_name;
++count;
}
continue;
}
}

LOG(FATAL) << "Unrecognized op " + op_name;
}
} else {
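To make the op-description format concrete, here is a minimal, self-contained sketch of the substitution ParseOpDescription performs; it is not the real implementation (which takes the variables map and the node), and FormatOpDesc, temp0, and grad0 are invented names for illustration only. Each '%' in the format string is replaced by the variable already assigned to the input index named by the trailing "_k" entries.

#include <iostream>
#include <string>
#include <vector>

// Expand an op description such as {"op::gelu(%)", "_0"}: each '%' is replaced
// by the variable currently bound to the input index named by the next "_k".
std::string FormatOpDesc(const std::vector<std::string>& op_desc,
                         const std::vector<std::string>& input_vars) {
  const std::string& fmt = op_desc[0];
  std::string out;
  size_t next = 1;  // op_desc[1..] hold the input references "_0", "_1", ...
  for (char c : fmt) {
    if (c != '%') { out += c; continue; }
    const int k = std::stoi(op_desc[next++].substr(1));  // "_k" -> k
    out += input_vars[k];
  }
  return out;
}

int main() {
  // Forward gelu: the LeakyReLU node's single input lives in "temp0".
  std::cout << FormatOpDesc({"op::gelu(%)", "_0"}, {"temp0"}) << "\n";
  // Backward gelu: assuming input 0 of _backward_LeakyReLU is the incoming
  // gradient and input 1 the forward data, the "_1", "_0" ordering expands
  // to backward_gelu(data, grad).
  std::cout << FormatOpDesc({"op::backward_gelu(%, %)", "_1", "_0"},
                            {"grad0", "temp0"}) << "\n";
  return 0;
}

This prints op::gelu(temp0) and op::backward_gelu(temp0, grad0); GenerateCode then wraps each expansion in a const auto tempN = ...; assignment and records tempN in variables so later fused ops can consume it, as the loops above do.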
13 changes: 13 additions & 0 deletions tests/python/gpu/test_fusion.py
@@ -213,11 +213,24 @@ def check_other_ops():
arr2 = mx.random.uniform(shape=(2,2,2,3))
check_fused_symbol(mx.sym.broadcast_like(a, b, lhs_axes=[0], rhs_axes=[0]), a=arr1, b=arr2)

def check_leakyrelu_ops():
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
shape = rand_shape_2d()
arr1 = mx.random.uniform(shape=shape)
arr2 = mx.random.uniform(shape=shape)

# Testing gelu
print("Checking fusion of LeakyReLU:gelu")
check_fused_symbol(mx.sym.LeakyReLU(a+b, act_type='gelu'), a=arr1, b=arr2)


@with_seed()
def test_fusion():
check_unary_ops()
check_binary_ops()
check_other_ops()
check_leakyrelu_ops()

@with_seed()
def test_fusion_compiler_cache():