cpu: rnn: make activation function non-templated
dzarukin committed Mar 22, 2024
1 parent 6f5621a commit e92c404
Showing 2 changed files with 31 additions and 57 deletions.
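
For context, the refactoring pattern this commit applies can be summarized with a small standalone sketch (simplified stand-in enums and math for illustration only — this is not the oneDNN code, and the clipping parameter is omitted): instead of one template specialization per (activation algorithm, propagation kind) pair, selected once in the dispatcher and stored behind a function pointer, a single non-templated function switches on kinds passed at run time.

// Illustrative sketch only: simplified stand-ins for alg_kind_t / prop_kind_t
// and for the math helpers used in ref_postgemm_rnn.cpp.
#include <cassert>
#include <cmath>

enum class alg_kind_t { eltwise_relu, eltwise_tanh, eltwise_logistic };
enum class prop_kind_t { forward, backward };

// After the change: one function, dispatch on runtime arguments.
float activation(alg_kind_t alg, prop_kind_t prop, float s, float alpha) {
    if (prop == prop_kind_t::forward) {
        switch (alg) {
            case alg_kind_t::eltwise_relu: return s > 0.f ? s : alpha * s;
            case alg_kind_t::eltwise_tanh: return std::tanh(s);
            case alg_kind_t::eltwise_logistic:
                return 1.f / (1.f + std::exp(-s));
        }
    } else { // backward: derivatives written in terms of the forward output
        switch (alg) {
            case alg_kind_t::eltwise_relu: return s > 0.f ? 1.f : alpha;
            case alg_kind_t::eltwise_tanh: return 1.f - s * s;
            case alg_kind_t::eltwise_logistic: return s * (1.f - s);
        }
    }
    assert(!"unsupported combination");
    return NAN;
}

int main() {
    // Callers pass the kinds at call time instead of baking them into a
    // template instantiation chosen in the dispatcher constructor.
    float y = activation(alg_kind_t::eltwise_tanh, prop_kind_t::forward, 0.5f, 0.f);
    return y > 0.f ? 0 : 1;
}

Moving the dispatch from compile time to run time trades a small per-call branch for fewer template instantiations and a simpler dispatcher interface, which is what the two diffs below implement.
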
23 changes: 3 additions & 20 deletions src/cpu/rnn/postgemm_dispatcher.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2023 Intel Corporation
* Copyright 2019-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -45,8 +45,8 @@ namespace dnnl {
namespace impl {
namespace cpu {

template <alg_kind_t alg_kind, prop_kind_t prop_kind>
float activation(float s, float alpha, float cliping);
float activation(alg_kind_t alg_kind, prop_kind_t prop_kind, float s,
float alpha, float cliping);

template <prop_kind_t aprop, impl::data_type_t src_type,
impl::data_type_t scratch_type, impl::data_type_t acc_type>
@@ -88,22 +88,6 @@ struct rnn_postgemm_dispatcher {
break;
case alg_kind::vanilla_rnn:
postgemm_func = &class_name::rnn_postgemm;
switch (pd->activation_kind()) {
case alg_kind::eltwise_relu:
activation_func
= &activation<alg_kind::eltwise_relu, aprop>;
break;
case alg_kind::eltwise_tanh:
activation_func
= &activation<alg_kind::eltwise_tanh, aprop>;
break;
case alg_kind::eltwise_logistic:
activation_func
= &activation<alg_kind::eltwise_logistic,
aprop>;
break;
default: assert(!"Unsupported activation function"); break;
}
break;
case alg_kind::vanilla_gru:
case alg_kind::vanilla_augru:
@@ -233,7 +217,6 @@ struct rnn_postgemm_dispatcher {
}

protected:
float (*activation_func)(float s, float alpha, float cliping);
virtual rnn_postgemm_sig(rnn_postgemm) = 0;
virtual rnn_postgemm_sig(lstm_postgemm) = 0;
virtual rnn_postgemm_sig(lstm_projection_postgemm) = 0;
65 changes: 28 additions & 37 deletions src/cpu/rnn/ref_postgemm_rnn.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2023 Intel Corporation
* Copyright 2018-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -31,40 +31,29 @@ using namespace dnnl::impl::utils;
using namespace dnnl::impl::math;
using namespace rnn_utils;

template <>
float activation<alg_kind::eltwise_relu, prop_kind::forward>(
float s, float alpha, float cliping) {
return relu_fwd<float>(s, alpha);
}

template <>
float activation<alg_kind::eltwise_relu, prop_kind::backward>(
float s, float alpha, float cliping) {
return relu_bwd<float>(s, alpha);
}

template <>
float activation<alg_kind::eltwise_tanh, prop_kind::forward>(
float s, float alpha, float cliping) {
return tanh_fwd<float>(s);
}

template <>
float activation<alg_kind::eltwise_tanh, prop_kind::backward>(
float s, float alpha, float cliping) {
return one_m_square<float>(s);
}

template <>
float activation<alg_kind::eltwise_logistic, prop_kind::forward>(
float s, float alpha, float cliping) {
return logistic_fwd<float>(s);
}

template <>
float activation<alg_kind::eltwise_logistic, prop_kind::backward>(
float s, float alpha, float cliping) {
return x_m_square<float>(s);
float activation(alg_kind_t alg_kind, prop_kind_t prop_kind, float s,
float alpha, float cliping) {
using namespace dnnl::impl::alg_kind;

if (prop_kind == prop_kind::forward
|| prop_kind == prop_kind::forward_inference) {
switch (alg_kind) {
case eltwise_relu: return relu_fwd<float>(s, alpha);
case eltwise_tanh: return tanh_fwd<float>(s);
case eltwise_logistic: return logistic_fwd<float>(s);
default: assert(!"unsupported algorithm");
}
} else if (prop_kind == prop_kind::backward) {
switch (alg_kind) {
case eltwise_relu: return relu_bwd<float>(s, alpha);
case eltwise_tanh: return one_m_square<float>(s);
case eltwise_logistic: return x_m_square<float>(s);
default: assert(!"unsupported algorithm");
}
} else {
assert(!"unsupported propagation kind");
}
return NAN;
}

constexpr float linear(float s, float alpha, float clipping) {
@@ -118,7 +107,8 @@ rnn_postgemm_sig(
(rnn_postgemm_fwd_t<src_type, scratch_type, acc_type>::rnn_postgemm)) {
const float *scales = this->pd_->attr()->rnn_tparams_.scales_;
const auto act_f = [this](float a, float alpha, float clipping) {
return gates_t(this->activation_func(a, alpha, clipping));
return gates_t(activation(this->pd_->activation_kind(),
this->pd_->get_prop_kind(), a, alpha, clipping));
};
const auto linear_f = [](float a, float alpha, float clipping) {
return gates_t(linear(a, alpha, clipping));
@@ -178,7 +168,8 @@ rnn_postgemm_sig(
(rnn_postgemm_bwd_t<src_type, scratch_type, acc_type>::rnn_postgemm)) {
const float *scales = this->pd_->attr()->rnn_tparams_.scales_;
const auto act_f = [this](float a, float alpha, float clipping) {
return this->activation_func(a, alpha, 0);
return activation(this->pd_->activation_kind(),
this->pd_->get_prop_kind(), a, alpha, 0);
};
const auto linear_f = [](float a, float alpha, float clipping) {
return linear(a, alpha, 0);
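
A minimal sketch of the dispatcher-side consequence (hypothetical, simplified class and member names, not the oneDNN types): with the free function taking the algorithm and propagation kinds at call time, the dispatcher no longer needs an activation_func pointer selected in its constructor.

#include <cmath>

enum class alg_kind_t { eltwise_tanh /* ... */ };
enum class prop_kind_t { forward /* ... */ };

// Stand-in for the non-templated activation above; the real switch is omitted.
float activation(alg_kind_t, prop_kind_t, float s, float /*alpha*/, float /*clipping*/) {
    return std::tanh(s);
}

// Hypothetical, simplified dispatcher: the kinds are read at call time, so
// there is no activation_func member and no constructor-time switch.
struct rnn_postgemm_sketch_t {
    alg_kind_t activation_kind;
    prop_kind_t prop_kind;
    float apply(float s, float alpha, float clipping) const {
        return activation(activation_kind, prop_kind, s, alpha, clipping);
    }
};

int main() {
    rnn_postgemm_sketch_t pg {alg_kind_t::eltwise_tanh, prop_kind_t::forward};
    return pg.apply(0.5f, 0.f, 0.f) > 0.f ? 0 : 1;
}
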
