Commit f812c39
Added ElectraBackbone
1 parent f77762b commit f812c39

1 file changed: 198 additions, 0 deletions

@@ -0,0 +1,198 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder
from keras_nlp.models.backbone import Backbone
from keras_nlp.utils.python_utils import classproperty


def electra_kernel_initializer(stddev=0.02):
    return keras.initializers.TruncatedNormal(stddev=stddev)


@keras_nlp_export("keras_nlp.models.ElectraBackbone")
class ElectraBackbone(Backbone):
    """An ELECTRA encoder network.

    This network implements a bi-directional Transformer-based encoder as
    described in ["ELECTRA: Pre-training Text Encoders as Discriminators
    Rather Than Generators"](https://arxiv.org/abs/2003.10555). It includes
    the embedding lookups and transformer layers, but not the masked language
    model or classification task networks.

    The default constructor gives a fully customizable, randomly initialized
    ELECTRA encoder with any number of layers, heads, and embedding
    dimensions. To load preset architectures and weights, use the
    `from_preset()` constructor.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind. The underlying model is provided by
    a third party and subject to a separate license, available
    [here](https://huggingface.co/docs/transformers/model_doc/electra#overview).
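
    Example usage, as a minimal sketch with a small, randomly initialized
    configuration (the shapes and values below are illustrative dummy data,
    not a pretrained preset):
    ```python
    import numpy as np
    import keras_nlp

    input_data = {
        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]]),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    }

    # Randomly initialized ELECTRA encoder with a custom config.
    model = keras_nlp.models.ElectraBackbone(
        vocabulary_size=1000,
        num_layers=2,
        num_heads=2,
        embedding_size=32,
        hidden_size=32,
        intermediate_dim=64,
        max_sequence_length=128,
    )
    # Returns a dict with "sequence_output" and "pooled_output" tensors.
    outputs = model(input_data)
    ```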
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        embedding_size,
        hidden_size,
        intermediate_dim,
        dropout=0.1,
        max_sequence_length=512,
        num_segments=2,
        **kwargs
    ):
        # Index of classification token in the vocabulary
        cls_token_index = 0
        # Inputs
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        segment_id_input = keras.Input(
            shape=(None,), dtype="int32", name="segment_ids"
        )
        padding_mask = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        # Embed tokens, positions, and segment ids.
        token_embedding_layer = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=embedding_size,
            embeddings_initializer=electra_kernel_initializer(),
            name="token_embedding",
        )
        token_embedding = token_embedding_layer(token_id_input)
        position_embedding = PositionEmbedding(
            sequence_length=max_sequence_length,
            initializer=electra_kernel_initializer(),
            name="position_embedding",
        )(token_embedding)
        segment_embedding = keras.layers.Embedding(
            input_dim=num_segments,
            output_dim=embedding_size,
            embeddings_initializer=electra_kernel_initializer(),
            name="segment_embedding",
        )(segment_id_input)

        # Add all embeddings together.
        x = keras.layers.Add()(
            (token_embedding, position_embedding, segment_embedding)
        )
        # Layer normalization
        x = keras.layers.LayerNormalization(
            name="embeddings_layer_norm",
            axis=-1,
            epsilon=1e-12,
            dtype="float32",
        )(x)
        # Dropout
        x = keras.layers.Dropout(
            dropout,
            name="embeddings_dropout",
        )(x)
        # Project to hidden dim
        if hidden_size != embedding_size:
            x = keras.layers.Dense(
                hidden_size,
                kernel_initializer=electra_kernel_initializer(),
                name="embedding_projection",
            )(x)

        # Apply successive transformer encoder blocks.
        for i in range(num_layers):
            x = TransformerEncoder(
                num_heads=num_heads,
                intermediate_dim=intermediate_dim,
                activation="gelu",
                dropout=dropout,
                layer_norm_epsilon=1e-12,
                kernel_initializer=electra_kernel_initializer(),
                name=f"transformer_layer_{i}",
            )(x, padding_mask=padding_mask)

        sequence_output = x
        # Pool the [CLS] token representation through a dense tanh layer.
        x = keras.layers.Dense(
            hidden_size,
            kernel_initializer=electra_kernel_initializer(),
            activation="tanh",
            name="pooled_dense",
        )(x)
        pooled_output = x[:, cls_token_index, :]

        # Instantiate using Functional API Model constructor
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "segment_ids": segment_id_input,
                "padding_mask": padding_mask,
            },
            outputs={
                "sequence_output": sequence_output,
                "pooled_output": pooled_output,
            },
            **kwargs,
        )

        # All references to self below this line
        self.vocab_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length
        self.num_segments = num_segments
        self.cls_token_index = cls_token_index
        self.token_embedding = token_embedding_layer

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocab_size,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_size": self.hidden_size,
                "embedding_size": self.embedding_size,
                "intermediate_dim": self.intermediate_dim,
                "dropout": self.dropout,
                "max_sequence_length": self.max_sequence_length,
                "num_segments": self.num_segments,
            }
        )
        return config
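
Below the diff, a minimal usage sketch of the new backbone, assuming it is exported as `keras_nlp.models.ElectraBackbone` per the `@keras_nlp_export` decorator above. The small dimensions are illustrative, not an official ELECTRA preset; they pick an embedding size smaller than the hidden size so that the optional `embedding_projection` layer is built.

import numpy as np
from keras_nlp.models import ElectraBackbone

# Illustrative (non-preset) dimensions. Because embedding_size != hidden_size,
# the "embedding_projection" Dense layer is created before the encoder stack.
backbone = ElectraBackbone(
    vocabulary_size=100,
    num_layers=2,
    num_heads=2,
    embedding_size=16,
    hidden_size=32,
    intermediate_dim=64,
    max_sequence_length=64,
)
assert backbone.get_layer("embedding_projection") is not None

# Dummy batch of two sequences of length 10.
batch = {
    "token_ids": np.random.randint(0, 100, size=(2, 10)).astype("int32"),
    "segment_ids": np.zeros((2, 10), dtype="int32"),
    "padding_mask": np.ones((2, 10), dtype="int32"),
}
outputs = backbone(batch)
print(outputs["sequence_output"].shape)  # (2, 10, 32)
print(outputs["pooled_output"].shape)  # (2, 32)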

0 commit comments