Skip to content

Commit f53b9db

Browse files
committed
Add electra presets
1 parent b9d93e0 commit f53b9db

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

keras_nlp/models/electra/electra_presets.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,54 @@
1616
backbone_presets = {
1717
"electra_base_discriminator_en": {
1818
"metadata": {
19-
"description": ("ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
20-
"This is base discriminator model with 12 layers."
21-
),
19+
"description": (
20+
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
21+
"This is base discriminator model with 12 layers."
22+
),
2223
"params": "109482240",
2324
"official_name": "ELECTRA",
2425
"path": "electra",
25-
"model_card": "https://github.com/google-research/electra/blob/master/README.md"
26+
"model_card": "https://huggingface.co/google/electra-base-discriminator",
2627
},
27-
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_base_discriminator_en/2"
28-
}
29-
}
28+
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_base_discriminator_en/1",
29+
},
30+
"electra_small_discriminator_en": {
31+
"metadata": {
32+
"description": (
33+
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
34+
"This is small discriminator model with 12 layers."
35+
),
36+
"params": "13,548,800",
37+
"official_name": "ELECTRA",
38+
"path": "electra",
39+
"model_card": "https://huggingface.co/google/electra-small-discriminator",
40+
},
41+
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_small_discriminator_en/1",
42+
},
43+
"electra_small_generator_en": {
44+
"metadata": {
45+
"description": (
46+
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
47+
"This is small generator model with 12 layers."
48+
),
49+
"params": "13548800",
50+
"official_name": "ELECTRA",
51+
"path": "electra",
52+
"model_card": "https://huggingface.co/google/electra-small-generator",
53+
},
54+
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_small_generator_en/1",
55+
},
56+
"electra_base_generator_en": {
57+
"metadata": {
58+
"description": (
59+
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
60+
"This is base generator model with 12 layers."
61+
),
62+
"params": "33576960",
63+
"official_name": "ELECTRA",
64+
"path": "electra",
65+
"model_card": "https://huggingface.co/google/electra-base-generator",
66+
},
67+
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_base_generator_en/1",
68+
},
69+
}

tools/checkpoint_conversion/convert_electra_checkpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
FLAGS = flags.FLAGS
4545
flags.DEFINE_string(
4646
"preset",
47-
"electra_small_generator_en",
47+
"electra_base_discriminator_en",
4848
f'Must be one of {",".join(PRESET_MAP)}',
4949
)
5050
flags.mark_flag_as_required("preset")

0 commit comments

Comments
 (0)