Merged
Commits
37 commits
f812c39
Added ElectraBackbone
pranavvp16 Oct 29, 2023
879020a
Merge branch 'keras-team:master' into electra
pranavvp16 Oct 29, 2023
c2aa9bd
Added backbone tests for ELECTRA
pranavvp16 Oct 31, 2023
79df89f
Fix config
pranavvp16 Nov 3, 2023
7bc3697
Add model import to __init__
pranavvp16 Nov 8, 2023
b7bcfcf
add electra tokenizer
pranavvp16 Dec 8, 2023
8d9dd15
add tests for tokenizer
pranavvp16 Dec 8, 2023
273075a
add __init__ file
pranavvp16 Dec 8, 2023
bfbf648
add tokenizer and backbone to models __init__
pranavvp16 Dec 8, 2023
a79deb1
Merge branch 'master' into electra
pranavvp16 Dec 8, 2023
538d938
Fix Failing tokenization test
pranavvp16 Dec 9, 2023
eb8baa5
Merge remote-tracking branch 'origin/electra' into electra
pranavvp16 Dec 9, 2023
b3f81d5
Merge branch 'keras-team:master' into electra
pranavvp16 Dec 16, 2023
47c9119
Add example on usage of the tokenizer with custom vocabulary
pranavvp16 Dec 16, 2023
ec9f683
Merge branch 'keras-team:master' into electra
pranavvp16 Dec 26, 2023
e3bad73
Add conversion script to convert weights from checkpoint
pranavvp16 Jan 1, 2024
148913d
Add electra preprocessor
pranavvp16 Jan 1, 2024
06dfae9
Add presets and tests
pranavvp16 Jan 3, 2024
3b72d15
Add presets config with model weights
pranavvp16 Jan 3, 2024
fcdcbbb
Add checkpoint conversion script
pranavvp16 Jan 3, 2024
d025883
Name conversion for electra models
pranavvp16 Jan 21, 2024
97b94ee
Update naming conventions according to preset names
pranavvp16 Jan 21, 2024
316a15a
Merge branch 'master' into electra
pranavvp16 Jan 21, 2024
b52d8b5
Fix failing tokenizer tests
pranavvp16 Jan 21, 2024
2e038eb
Merge branch 'keras-team:master' into electra
pranavvp16 Feb 4, 2024
e256609
Update checkpoint conversion script according to kaggle
pranavvp16 Feb 5, 2024
33e9fb1
Add validate function
pranavvp16 Feb 5, 2024
5775fad
Merge branch 'keras-team:master' into electra
pranavvp16 Feb 20, 2024
2b70228
Kaggle preset
pranavvp16 Feb 20, 2024
b9d93e0
update preset link
pranavvp16 Feb 20, 2024
f53b9db
Add electra presets
pranavvp16 Mar 5, 2024
4be8d50
Merge branch 'keras-team:master' into electra
pranavvp16 Mar 18, 2024
b268e26
Complete run_small_preset test for electra
pranavvp16 Mar 18, 2024
0411151
Add large variations of electra in presets
pranavvp16 Mar 23, 2024
fa9a2f2
Merge remote-tracking branch 'origin/master' into electra
pranavvp16 Mar 23, 2024
0bb7b64
Fix case issues with electra presets
mattdangerw Mar 26, 2024
c49e4ac
Fix format
mattdangerw Mar 26, 2024
1 change: 1 addition & 0 deletions keras_nlp/models/__init__.py
@@ -72,6 +72,7 @@
DistilBertTokenizer,
)
from keras_nlp.models.electra.electra_backbone import ElectraBackbone
from keras_nlp.models.electra.electra_preprocessor import ElectraPreprocessor
from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer
from keras_nlp.models.f_net.f_net_backbone import FNetBackbone
from keras_nlp.models.f_net.f_net_classifier import FNetClassifier
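For context, a minimal sketch of the top-level imports this one-line hunk exposes, assuming this branch of keras-nlp is installed:

```python
# Minimal sketch of the symbols exported from keras_nlp.models after this
# change; assumes this branch of keras-nlp is installed.
from keras_nlp.models import (
    ElectraBackbone,
    ElectraPreprocessor,
    ElectraTokenizer,
)
```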
20 changes: 18 additions & 2 deletions keras_nlp/models/electra/electra_backbone.py
@@ -12,13 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.electra.electra_presets import backbone_presets
from keras_nlp.utils.keras_utils import gelu_approximate
from keras_nlp.utils.python_utils import classproperty


def electra_kernel_initializer(stddev=0.02):
@@ -36,8 +40,9 @@ class ElectraBackbone(Backbone):
or classification task networks.

The default constructor gives a fully customizable, randomly initialized
Electra encoder with any number of layers, heads, and embedding
dimensions.
ELECTRA encoder with any number of layers, heads, and embedding
dimensions. To load preset architectures and weights, use the
`from_preset()` constructor.

Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind. The underlying model is provided by a
@@ -70,6 +75,13 @@ class ElectraBackbone(Backbone):
"segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]]),
"padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
}

# Pre-trained ELECTRA encoder.
model = keras_nlp.models.ElectraBackbone.from_preset(
"electra_base_discriminator_en"
)
model(input_data)

# Randomly initialized Electra encoder
backbone = keras_nlp.models.ElectraBackbone(
vocabulary_size=1000,
@@ -234,3 +246,7 @@ def get_config(self):
}
)
return config

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
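The new `presets` classproperty is what registers the checkpoints below with `from_preset()`. A hedged sketch of the expected usage, with the preset name and output shapes taken from the backbone test in this PR (it assumes the Kaggle handles are live):

```python
import numpy as np
import keras_nlp

# Preset names registered via the classproperty above.
print(list(keras_nlp.models.ElectraBackbone.presets.keys()))

# Load architecture + weights for one preset and run a forward pass.
backbone = keras_nlp.models.ElectraBackbone.from_preset(
    "electra_small_discriminator_en"
)
outputs = backbone(
    {
        "token_ids": np.array([[101, 1996, 4248, 102]], dtype="int32"),
        "segment_ids": np.zeros((1, 4), dtype="int32"),
        "padding_mask": np.ones((1, 4), dtype="int32"),
    }
)
print(outputs["sequence_output"].shape)  # (1, 4, 256) for the small preset
print(outputs["pooled_output"].shape)    # (1, 256)
```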
34 changes: 34 additions & 0 deletions keras_nlp/models/electra/electra_backbone_test.py
@@ -54,3 +54,37 @@ def test_saved_model(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)

@pytest.mark.large
def test_smallest_preset(self):
self.run_preset_test(
cls=ElectraBackbone,
preset="electra_small_discriminator_en",
input_data={
"token_ids": ops.array([[101, 1996, 4248, 102]], dtype="int32"),
"segment_ids": ops.zeros((1, 4), dtype="int32"),
"padding_mask": ops.ones((1, 4), dtype="int32"),
},
expected_output_shape={
"sequence_output": (1, 4, 256),
"pooled_output": (1, 256),
},
# The forward pass from a preset should be stable!
expected_partial_output={
"sequence_output": (
ops.array([0.32287, 0.18754, -0.22272, -0.24177, 1.18977])
),
"pooled_output": (
ops.array([-0.02974, 0.23383, 0.08430, -0.19471, 0.14822])
),
},
)

@pytest.mark.extra_large
def test_all_presets(self):
for preset in ElectraBackbone.presets:
self.run_preset_test(
cls=ElectraBackbone,
preset=preset,
input_data=self.input_data,
)
163 changes: 163 additions & 0 deletions keras_nlp/models/electra/electra_preprocessor.py
@@ -0,0 +1,163 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.preprocessing.multi_segment_packer import (
MultiSegmentPacker,
)
from keras_nlp.models.electra.electra_presets import backbone_presets
from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import (
convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
from keras_nlp.utils.python_utils import classproperty


@keras_nlp_export("keras_nlp.models.ElectraPreprocessor")
class ElectraPreprocessor(Preprocessor):
"""A ELECTRA preprocessing layer which tokenizes and packs inputs.

This preprocessing layer will do three things:

1. Tokenize any number of input segments using the `tokenizer`.
2. Pack the inputs together using a `keras_nlp.layers.MultiSegmentPacker`
with the appropriate `"[CLS]"`, `"[SEP]"` and `"[PAD]"` tokens.
3. Construct a dictionary with keys `"token_ids"`, `"segment_ids"` and
`"padding_mask"` that can be passed directly to an ELECTRA model.

This layer can be used directly with `tf.data.Dataset.map` to preprocess
string data in the `(x, y, sample_weight)` format used by
`keras.Model.fit`.

Args:
tokenizer: A `keras_nlp.models.ElectraTokenizer` instance.
sequence_length: The length of the packed inputs.
truncate: string. The algorithm to truncate a list of batched segments
to fit within `sequence_length`. The value can be either
`"round_robin"` or `"waterfall"`:
- `"round_robin"`: Available space is assigned one token at a
time in a round-robin fashion to the inputs that still need
some, until the limit is reached.
- `"waterfall"`: The allocation of the budget is done using a
"waterfall" algorithm that allocates quota in a
left-to-right manner and fills up the buckets until we run
out of budget. It supports an arbitrary number of segments.

Call arguments:
x: A tensor of single string sequences, or a tuple of multiple
tensor sequences to be packed together. Inputs may be batched or
unbatched. For single sequences, raw python inputs will be converted
to tensors. For multiple sequences, pass tensors directly.
y: Any label data. Will be passed through unaltered.
sample_weight: Any label weight data. Will be passed through unaltered.

Examples:

Directly calling the layer on data.
```python
preprocessor = keras_nlp.models.ElectraPreprocessor.from_preset(
"electra_base_discriminator_en"
)
preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])

# Custom vocabulary.
vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
vocab += ["The", "quick", "brown", "fox", "jumped", "."]
tokenizer = keras_nlp.models.ElectraTokenizer(vocabulary=vocab)
preprocessor = keras_nlp.models.ElectraPreprocessor(tokenizer)
preprocessor("The quick brown fox jumped.")
```

Mapping with `tf.data.Dataset`.
```python
preprocessor = keras_nlp.models.ElectraPreprocessor.from_preset(
"electra_base_discriminator_en"
)

first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
second = tf.constant(["The fox tripped.", "Oh look, a whale."])
label = tf.constant([1, 1])
# Map labeled single sentences.
ds = tf.data.Dataset.from_tensor_slices((first, label))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)


# Map unlabeled single sentences.
ds = tf.data.Dataset.from_tensor_slices(first)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

# Map labeled sentence pairs.
ds = tf.data.Dataset.from_tensor_slices(((first, second), label))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
# Map unlabeled sentence pairs.
ds = tf.data.Dataset.from_tensor_slices((first, second))

# Watch out for tf.data's default unpacking of tuples here!
# Best to invoke the `preprocessor` directly in this case.
ds = ds.map(
lambda first, second: preprocessor(x=(first, second)),
num_parallel_calls=tf.data.AUTOTUNE,
)
```
"""

def __init__(
self,
tokenizer,
sequence_length=512,
truncate="round_robin",
**kwargs,
):
super().__init__(**kwargs)
self.tokenizer = tokenizer
self.packer = MultiSegmentPacker(
start_value=self.tokenizer.cls_token_id,
end_value=self.tokenizer.sep_token_id,
pad_value=self.tokenizer.pad_token_id,
truncate=truncate,
sequence_length=sequence_length,
)

def get_config(self):
config = super().get_config()
config.update(
{
"sequence_length": self.packer.sequence_length,
"truncate": self.packer.truncate,
}
)
return config

def call(self, x, y=None, sample_weight=None):
x = convert_inputs_to_list_of_tensor_segments(x)
x = [self.tokenizer(segment) for segment in x]
token_ids, segment_ids = self.packer(x)
x = {
"token_ids": token_ids,
"segment_ids": segment_ids,
"padding_mask": token_ids != self.tokenizer.pad_token_id,
}
return pack_x_y_sample_weight(x, y, sample_weight)

@classproperty
def tokenizer_cls(cls):
return ElectraTokenizer

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
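To make the review easier, here is a small sketch of what `call()` returns for a sentence pair, reusing the toy vocabulary from the tests below; the vocabulary and `sequence_length` are illustrative, not part of this PR:

```python
import tensorflow as tf

from keras_nlp.models.electra.electra_preprocessor import ElectraPreprocessor
from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer

# Toy vocabulary, mirroring electra_preprocessor_test.py.
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
vocab += ["the", "quick", "brown", "fox"]
tokenizer = ElectraTokenizer(vocabulary=vocab)
preprocessor = ElectraPreprocessor(tokenizer, sequence_length=10)

# Sentence pairs are packed as [CLS] seg1 [SEP] seg2 [SEP], then padded.
features = preprocessor(
    (tf.constant(["the quick brown fox"]), tf.constant(["the fox"]))
)
print(sorted(features.keys()))  # ['padding_mask', 'segment_ids', 'token_ids']
print(features["segment_ids"])  # 0s for the first segment, 1s for the second
```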
67 changes: 67 additions & 0 deletions keras_nlp/models/electra/electra_preprocessor_test.py
@@ -0,0 +1,67 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from keras_nlp.models.electra.electra_preprocessor import ElectraPreprocessor
from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer
from keras_nlp.tests.test_case import TestCase


class ElectraPreprocessorTest(TestCase):
def setUp(self):
self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
self.vocab += ["THE", "QUICK", "BROWN", "FOX"]
self.vocab += ["the", "quick", "brown", "fox"]
self.tokenizer = ElectraTokenizer(vocabulary=self.vocab)
self.init_kwargs = {
"tokenizer": self.tokenizer,
"sequence_length": 8,
}
self.input_data = (
["THE QUICK BROWN FOX."],
[1], # Pass through labels.
[1.0], # Pass through sample_weights.
)

def test_preprocessor_basics(self):
self.run_preprocessing_layer_test(
cls=ElectraPreprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output=(
{
"token_ids": [[2, 5, 6, 7, 8, 1, 3, 0]],
"segment_ids": [[0, 0, 0, 0, 0, 0, 0, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
},
[1], # Pass through labels.
[1.0], # Pass through sample_weights.
),
)

def test_errors_for_2d_list_input(self):
preprocessor = ElectraPreprocessor(**self.init_kwargs)
ambiguous_input = [["one", "two"], ["three", "four"]]
with self.assertRaises(ValueError):
preprocessor(ambiguous_input)

@pytest.mark.extra_large
def test_all_presets(self):
for preset in ElectraPreprocessor.presets:
self.run_preset_test(
cls=ElectraPreprocessor,
preset=preset,
input_data=self.input_data,
)
69 changes: 69 additions & 0 deletions keras_nlp/models/electra/electra_presets.py
@@ -0,0 +1,69 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ELECTRA model preset configurations."""

backbone_presets = {
"electra_base_discriminator_en": {
"metadata": {
"description": (
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
"This is base discriminator model with 12 layers."
),
"params": 109482240,
"official_name": "ELECTRA",
"path": "electra",
"model_card": "https://github.com/google-research/electra",
},
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_base_discriminator_en/1",
Member:
I don't see anything at https://www.kaggle.com/models/pranavprajapati16/electra.

You should now have the ability to make models public, can you do so? Or is the actual model here? https://www.kaggle.com/models/pranavprajapati16/electra_base_discriminator_en (in which case these links are still wrong).

Let me know where to get the proper assets and I will copy them to the Keras org.

Contributor Author:
Sorry, the model was private; I just made it public: https://www.kaggle.com/models/pranavprajapati16/electra

Member:
Thanks! Uploading now! I can just patch the new links into this PR and land. I'll ping here if I run into any issues.

},
"electra_small_discriminator_en": {
"metadata": {
"description": (
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
"This is small discriminator model with 12 layers."
),
"params": 13548800,
"official_name": "ELECTRA",
"path": "electra",
"model_card": "https://github.com/google-research/electra",
},
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_small_discriminator_en/1",
},
"electra_small_generator_en": {
"metadata": {
"description": (
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
"This is small generator model with 12 layers."
),
"params": 13548800,
"official_name": "ELECTRA",
"path": "electra",
"model_card": "https://github.com/google-research/electra",
},
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_small_generator_en/1",
},
"electra_base_generator_en": {
"metadata": {
"description": (
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
"This is base generator model with 12 layers."
),
"params": 33576960,
"official_name": "ELECTRA",
"path": "electra",
"model_card": "https://github.com/google-research/electra",
},
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_base_generator_en/1",
},
}
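Since each preset carries its parameter count and Kaggle handle, a quick way to sanity-check this table is to iterate it directly; a small sketch, assuming the module is importable from this branch:

```python
from keras_nlp.models.electra.electra_presets import backbone_presets

# Print one line per preset: name, parameter count, and source handle.
for name, preset in backbone_presets.items():
    params = preset["metadata"]["params"]
    print(f"{name}: {params / 1e6:.1f}M params -> {preset['kaggle_handle']}")
```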