Skip to content

Commit dada198

Browse files
committed
Complete run_small_preset test for electra
1 parent 4be8d50 commit dada198

File tree

4 files changed

+18
-15
lines changed

4 files changed

+18
-15
lines changed

keras_nlp/models/electra/electra_backbone_test.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_saved_model(self):
5959
def test_smallest_preset(self):
6060
self.run_preset_test(
6161
cls=ElectraBackbone,
62-
preset="electra-small-generator",
62+
preset="electra_small_discriminator_en",
6363
input_data={
6464
"token_ids": ops.array([[101, 1996, 4248, 102]], dtype="int32"),
6565
"segment_ids": ops.zeros((1, 4), dtype="int32"),
@@ -70,10 +70,13 @@ def test_smallest_preset(self):
7070
"pooled_output": (1, 256),
7171
},
7272
# The forward pass from a preset should be stable!
73-
# TODO: Add sequence and pooled output trimmed to 5 tokens.
7473
expected_partial_output={
75-
"sequence_output": (ops.array()),
76-
"pooled_output": (ops.array()),
74+
"sequence_output": (
75+
ops.array([0.32287, 0.18754, -0.22272, -0.24177, 1.18977])
76+
),
77+
"pooled_output": (
78+
ops.array([-0.02974, 0.23383, 0.08430, -0.19471, 0.14822])
79+
),
7780
},
7881
)
7982

keras_nlp/models/electra/electra_preprocessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ class ElectraPreprocessor(Preprocessor):
3838
2. Pack the inputs together using a `keras_nlp.layers.MultiSegmentPacker`
3939
with the appropriate `"[CLS]"`, `"[SEP]"` and `"[PAD]"` tokens.
4040
3. Construct a dictionary with keys `"token_ids"` and `"padding_mask"`,
41-
that can be passed directly to a DistilBERT model.
41+
that can be passed directly to an ELECTRA model.
4242
4343
This layer can be used directly with `tf.data.Dataset.map` to preprocess
4444
string data in the `(x, y, sample_weight)` format used by
4545
`keras.Model.fit`.
4646
4747
Args:
48-
tokenizer: A `keras_nlp.models.DistilBertTokenizer` instance.
48+
tokenizer: A `keras_nlp.models.ElectraTokenizer` instance.
4949
sequence_length: The length of the packed inputs.
5050
truncate: string. The algorithm to truncate a list of batched segments
5151
to fit within `sequence_length`. The value can be either

keras_nlp/models/electra/electra_presets.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
2121
"This is base discriminator model with 12 layers."
2222
),
23-
"params": "109482240",
23+
"params": 109482240,
2424
"official_name": "ELECTRA",
2525
"path": "electra",
26-
"model_card": "https://huggingface.co/google/electra-base-discriminator",
26+
"model_card": "https://github.com/google-research/electra",
2727
},
2828
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_base_discriminator_en/1",
2929
},
@@ -33,10 +33,10 @@
3333
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
3434
"This is small discriminator model with 12 layers."
3535
),
36-
"params": "13,548,800",
36+
"params": 13548800,
3737
"official_name": "ELECTRA",
3838
"path": "electra",
39-
"model_card": "https://huggingface.co/google/electra-small-discriminator",
39+
"model_card": "https://github.com/google-research/electra",
4040
},
4141
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_small_discriminator_en/1",
4242
},
@@ -46,10 +46,10 @@
4646
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
4747
"This is small generator model with 12 layers."
4848
),
49-
"params": "13548800",
49+
"params": 13548800,
5050
"official_name": "ELECTRA",
5151
"path": "electra",
52-
"model_card": "https://huggingface.co/google/electra-small-generator",
52+
"model_card": "https://github.com/google-research/electra",
5353
},
5454
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_small_generator_en/1",
5555
},
@@ -59,10 +59,10 @@
5959
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
6060
"This is base generator model with 12 layers."
6161
),
62-
"params": "33576960",
62+
"params": 33576960,
6363
"official_name": "ELECTRA",
6464
"path": "electra",
65-
"model_card": "https://huggingface.co/google/electra-base-generator",
65+
"model_card": "https://github.com/google-research/electra",
6666
},
6767
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_base_generator_en/1",
6868
},

keras_nlp/models/electra/electra_tokenizer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_errors_missing_special_tokens(self):
4747
def test_smallest_preset(self):
4848
self.run_preset_test(
4949
cls=ElectraTokenizer,
50-
preset="distil_bert_base_en_uncased",
50+
preset="electra_base_discriminator_en",
5151
input_data=["The quick brown fox."],
5252
expected_output=[[1996, 4248, 2829, 4419, 1012]],
5353
)

0 commit comments

Comments
 (0)