diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 428af92b2b3a..2f31a4dd5e48 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -34,6 +34,7 @@ from huggingface_hub import Repository, list_repo_files from requests import HTTPError +from .activations_tf import get_tf_activation from .configuration_utils import PretrainedConfig from .dynamic_module_utils import custom_object_save from .file_utils import ( @@ -1957,9 +1958,11 @@ def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, ** num_classes, kernel_initializer=get_initializer(initializer_range), name="summary" ) - self.has_activation = hasattr(config, "summary_activation") and config.summary_activation == "tanh" - if self.has_activation: - self.activation = tf.keras.activations.tanh + self.has_activation = False + activation_string = getattr(config, "summary_activation", None) + if activation_string is not None: + self.has_activation = True + self.activation = get_tf_activation(activation_string) self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0 if self.has_first_dropout: