# See the License for the specific language governing permissions and
# limitations under the License.

-import copy
-
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
-from keras_nlp.layers.modeling.token_and_position_embedding import (
-    PositionEmbedding, ReversibleEmbedding
-)
+from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
+from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder
from keras_nlp.models.backbone import Backbone
-from keras_nlp.utils.python_utils import classproperty


def electra_kernel_initializer(stddev=0.02):
    return keras.initializers.TruncatedNormal(stddev=stddev)

+
@keras_nlp_export("keras_nlp.models.ElectraBackbone")
class ElectraBackbone(Backbone):
    """An Electra encoder network.
@@ -46,20 +43,59 @@ class ElectraBackbone(Backbone):
    warranties or conditions of any kind. The underlying model is provided by a
    third party and subject to a separate license, available
    [here](https://huggingface.co/docs/transformers/model_doc/electra#overview).
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_heads: int. The number of attention heads for each transformer.
+            The hidden size must be divisible by the number of attention heads.
+        hidden_size: int. The size of the transformer encoding and pooler layers.
+        embedding_size: int. The size of the token embeddings.
+        intermediate_size: int. The output dimension of the first Dense layer in
+            a two-layer feedforward network for each transformer.
+        dropout: float. Dropout probability for the Transformer encoder.
+        max_sequence_length: int. The maximum sequence length that this encoder
+            can consume. If None, `max_sequence_length` uses the value from
+            sequence length. This determines the variable shape for positional
+            embeddings.
+        num_segments: int. The number of types that the "segment_ids" input can
+            take.
+
+    Examples:
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]]),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+    # Randomly initialized Electra encoder
+    backbone = keras_nlp.models.ElectraBackbone(
+        vocabulary_size=1000,
+        num_layers=2,
+        num_heads=2,
+        hidden_size=32,
+        embedding_size=32,
+        intermediate_size=64,
+        dropout=0.1,
+        max_sequence_length=512,
+    )
+    # Returns sequence and pooled outputs.
+    sequence_output, pooled_output = backbone(input_data)
+    ```
    """

    def __init__(
-        self,
-        vocabulary_size,
-        num_layers,
-        num_heads,
-        embedding_size,
-        hidden_size,
-        intermediate_dim,
-        dropout=0.1,
-        max_sequence_length=512,
-        num_segments=2,
-        **kwargs
+        self,
+        vocabulary_size,
+        num_layers,
+        num_heads,
+        embedding_size,
+        hidden_size,
+        intermediate_size,
+        dropout=0.1,
+        max_sequence_length=512,
+        num_segments=2,
+        **kwargs,
    ):
        # Index of classification token in the vocabulary
        cls_token_index = 0
@@ -83,14 +119,12 @@ def __init__(
        )
        token_embedding = token_embedding_layer(token_id_input)
        position_embedding = PositionEmbedding(
-            input_dim=max_sequence_length,
-            output_dim=embedding_size,
-            merge_mode="add",
-            embeddings_initializer=electra_kernel_initializer(),
+            initializer=electra_kernel_initializer(),
+            sequence_length=max_sequence_length,
            name="position_embedding",
        )(token_embedding)
        segment_embedding = keras.layers.Embedding(
-            input_dim=max_sequence_length,
+            input_dim=num_segments,
            output_dim=embedding_size,
            embeddings_initializer=electra_kernel_initializer(),
            name="segment_embedding",
@@ -124,7 +158,7 @@ def __init__(
        for i in range(num_layers):
            x = TransformerEncoder(
                num_heads=num_heads,
-                intermediate_dim=intermediate_dim,
+                intermediate_dim=intermediate_size,
                activation="gelu",
                dropout=dropout,
                layer_norm_epsilon=1e-12,
@@ -161,7 +195,7 @@ def __init__(
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
-        self.intermediate_dim = intermediate_dim
+        self.intermediate_dim = intermediate_size
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length
        self.num_segments = num_segments
@@ -186,13 +220,3 @@ def get_config(self):
            }
        )
        return config
-
-
-
-
-
-
-
-
-
-
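
For orientation, here is a minimal standalone sketch of the embedding stack this commit touches: `ReversibleEmbedding` for tokens, the new-style `PositionEmbedding(initializer=..., sequence_length=...)` call, and a segment embedding whose `input_dim` is now `num_segments`. The dummy shapes, the eager calls, and the final manual sum are illustrative assumptions, not code from this PR; the real `__init__` goes on to normalize, apply dropout, and stack `TransformerEncoder` layers.

```python
import numpy as np

from keras_nlp.backend import keras
from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding

# Illustrative sizes matching the docstring example above.
vocabulary_size, embedding_size = 1000, 32
max_sequence_length, num_segments = 512, 2
initializer = keras.initializers.TruncatedNormal(stddev=0.02)

token_ids = np.ones((1, 12), dtype="int32")
segment_ids = np.zeros((1, 12), dtype="int32")

# Token embeddings; ReversibleEmbedding lets the same weights be reused
# later to project hidden states back to vocabulary logits.
token_embedding = ReversibleEmbedding(
    input_dim=vocabulary_size,
    output_dim=embedding_size,
    embeddings_initializer=initializer,
)(token_ids)

# New-style PositionEmbedding call: pass `sequence_length` and an
# `initializer`, then apply the layer to the embedded tokens.
position_embedding = PositionEmbedding(
    initializer=initializer,
    sequence_length=max_sequence_length,
)(token_embedding)

# Segment embeddings: `input_dim` is the number of segment types
# (num_segments), not the sequence length.
segment_embedding = keras.layers.Embedding(
    input_dim=num_segments,
    output_dim=embedding_size,
    embeddings_initializer=initializer,
)(segment_ids)

# Sum the three streams to form the encoder input.
embeddings = token_embedding + position_embedding + segment_embedding
```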