4 changes: 4 additions & 0 deletions keras_nlp/models/__init__.py
@@ -13,6 +13,10 @@
# limitations under the License.

from keras_nlp.models.albert.albert_backbone import AlbertBackbone
from keras_nlp.models.albert.albert_masked_lm import AlbertMaskedLM
from keras_nlp.models.albert.albert_masked_lm_preprocessor import (
AlbertMaskedLMPreprocessor,
)
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
from keras_nlp.models.bart.bart_backbone import BartBackbone
151 changes: 151 additions & 0 deletions keras_nlp/models/albert/albert_masked_lm.py
@@ -0,0 +1,151 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Albert masked lm model."""

import copy

from tensorflow import keras

from keras_nlp.layers.masked_lm_head import MaskedLMHead
from keras_nlp.models.albert.albert_backbone import AlbertBackbone
from keras_nlp.models.albert.albert_backbone import albert_kernel_initializer
from keras_nlp.models.albert.albert_masked_lm_preprocessor import (
AlbertMaskedLMPreprocessor,
)
from keras_nlp.models.albert.albert_presets import backbone_presets
from keras_nlp.models.task import Task
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class AlbertMaskedLM(Task):
"""An end-to-end Albert model for the masked language modeling task.

This model will train Albert on a masked language modeling task.
The model will predict labels for a number of masked tokens in the
input data. For usage of this model with pre-trained weights, see the
`from_preset()` method.

This model can optionally be configured with a `preprocessor` layer, in
which case inputs can be raw string features during `fit()`, `predict()`,
and `evaluate()`. Inputs will be tokenized and dynamically masked during
training and evaluation. This is done by default when creating the model
with `from_preset()`.

Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind. The underlying model is provided by a
third party and subject to a separate license, available
[here](https://github.com/google-research/albert).

Args:
backbone: A `keras_nlp.models.AlbertBackbone` instance.
preprocessor: A `keras_nlp.models.AlbertMaskedLMPreprocessor` or
`None`. If `None`, this model will not apply preprocessing, and
inputs should be preprocessed before calling the model.

Example usage:

Raw string inputs and pretrained backbone.
```python
# Create a dataset with raw string features. Labels are inferred.
features = ["The quick brown fox jumped.", "I forgot my homework."]

# Create an AlbertMaskedLM with a pretrained backbone and further train
# on an MLM task.
masked_lm = keras_nlp.models.AlbertMaskedLM.from_preset(
"albert_base_en_uncased",
)
masked_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
masked_lm.fit(x=features, batch_size=2)
```

Preprocessed inputs and custom backbone.
```python
# Create a preprocessed dataset where 0 is the mask token.
preprocessed_features = {
"token_ids": tf.constant(
[[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8)
),
"padding_mask": tf.constant(
[[1, 1, 1, 1, 1, 1, 1, 1]] * 2, shape=(2, 8)
),
"mask_positions": tf.constant([[2, 4]] * 2, shape=(2, 2))
}
# Labels are the original masked values.
labels = [[3, 5]] * 2

# Randomly initialize an Albert encoder
backbone = keras_nlp.models.AlbertBackbone(
vocabulary_size=50265,
num_layers=12,
num_heads=12,
embedding_dim=128,
hidden_dim=768,
intermediate_dim=3072,
max_sequence_length=12
)
# Create an Albert masked LM and fit the data.
masked_lm = keras_nlp.models.AlbertMaskedLM(
backbone,
preprocessor=None,
)
masked_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
masked_lm.fit(x=preprocessed_features, y=labels, batch_size=2)
```
"""

def __init__(self, backbone, preprocessor=None, **kwargs):
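# Functional inputs: the backbone's inputs plus the positions of the masked
# tokens whose original ids the model should predict.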
inputs = {
**backbone.input,
"mask_positions": keras.Input(
shape=(None,), dtype="int32", name="mask_positions"
),
}

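# Run the backbone, then predict vocabulary logits at each masked position,
# reusing the token embedding weights for the output projection.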
backbone_outputs = backbone(backbone.input)
outputs = MaskedLMHead(
vocabulary_size=backbone.vocabulary_size,
embedding_weights=backbone.token_embedding.embeddings,
intermediate_activation=lambda x: keras.activations.gelu(
x, approximate=True
),
kernel_initializer=albert_kernel_initializer(),
name="mlm_head",
)(backbone_outputs["sequence_output"], inputs["mask_positions"])

super().__init__(
inputs=inputs,
outputs=outputs,
include_preprocessing=preprocessor is not None,
**kwargs
)

self.backbone = backbone
self.preprocessor = preprocessor

@classproperty
def backbone_cls(cls):
return AlbertBackbone

@classproperty
def preprocessor_cls(cls):
return AlbertMaskedLMPreprocessor

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
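
Note (reviewer aid, not part of the diff): `MaskedLMHead` above gathers the backbone's sequence output at each `mask_positions` index and projects it onto the vocabulary, reusing the token embedding matrix as the output projection. A rough sketch of that idea follows; it omits the intermediate dense transform, layer norm, and bias that the real layer also applies.

```python
import tensorflow as tf


def naive_mlm_logits(sequence_output, mask_positions, embedding_matrix):
    """Toy illustration only; `keras_nlp.layers.MaskedLMHead` is the real layer.

    sequence_output:  (batch, seq_len, hidden_dim) backbone output.
    mask_positions:   (batch, num_masks) indices of the masked tokens.
    embedding_matrix: (vocab_size, hidden_dim) shared token embeddings.
    """
    # Pull out the hidden states at the masked positions.
    gathered = tf.gather(sequence_output, mask_positions, batch_dims=1)
    # Project onto the vocabulary with the tied embedding weights:
    # (batch, num_masks, hidden_dim) -> (batch, num_masks, vocab_size).
    return tf.einsum("bmh,vh->bmv", gathered, embedding_matrix)
```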
201 changes: 201 additions & 0 deletions keras_nlp/models/albert/albert_masked_lm_preprocessor.py
@@ -0,0 +1,201 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Albert masked language model preprocessor layer"""

from absl import logging
from tensorflow import keras

from keras_nlp.layers.masked_lm_mask_generator import MaskedLMMaskGenerator
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras.utils.register_keras_serializable(package="keras_nlp")
class AlbertMaskedLMPreprocessor(AlbertPreprocessor):
"""Albert preprocessing for the masked language modeling task.

This preprocessing layer will prepare inputs for a masked language modeling
task. It is primarily intended for use with the
`keras_nlp.models.AlbertMaskedLM` task model. Preprocessing will occur in
multiple steps.

- Tokenize any number of input segments using the `tokenizer`.
- Pack the inputs together with the appropriate `"[CLS]"`, `"[SEP]"` and
`"<pad>"` tokens, i.e., adding a single `"[CLS]"` at the start of the
entire sequence, a `"[SEP]"` between each segment,
and a `"[SEP]"` at the end of the entire sequence.
- Randomly select non-special tokens to mask, controlled by
`mask_selection_rate`.
- Construct a `(x, y, sample_weight)` tuple suitable for training with a
`keras_nlp.models.AlbertMaskedLM` task model.

Args:
tokenizer: A `keras_nlp.models.AlbertTokenizer` instance.
sequence_length: The length of the packed inputs.
mask_selection_rate: The probability an input token will be dynamically
masked.
mask_selection_length: The maximum number of masked tokens supported
by the layer.
mask_token_rate: float, defaults to 0.8. The probability, between 0 and
1, that a token selected for masking is replaced with the mask token.
random_token_rate: float, defaults to 0.1. The probability, between 0
and 1, that a token selected for masking is replaced with a random
token from the vocabulary.
Note: `mask_token_rate + random_token_rate` must be at most 1. With
probability `1 - mask_token_rate - random_token_rate`, a selected
token is left unchanged.
truncate: string. The algorithm to truncate a list of batched segments
to fit within `sequence_length`. The value can be either
`round_robin` or `waterfall`:
- `"round_robin"`: Available space is assigned one token at a
time in a round-robin fashion to the inputs that still need
some, until the limit is reached.
- `"waterfall"`: The allocation of the budget is done using a
"waterfall" algorithm that allocates quota in a
left-to-right manner and fills up the buckets until we run
out of budget. It supports an arbitrary number of segments.

Examples:
```python
# Load the preprocessor from a preset.
preprocessor = keras_nlp.models.AlbertMaskedLMPreprocessor.from_preset(
"albert_base_en_uncased"
)

# Tokenize and mask a single sentence.
sentence = tf.constant("The quick brown fox jumped.")
preprocessor(sentence)

# Tokenize and mask a batch of sentences.
sentences = tf.constant(
["The quick brown fox jumped.", "Call me Ishmael."]
)
preprocessor(sentences)

# Tokenize and mask a dataset of sentences.
features = tf.constant(
["The quick brown fox jumped.", "Call me Ishmael."]
)
ds = tf.data.Dataset.from_tensor_slices((features))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

# Alternatively, you can create a preprocessor from your own vocabulary.
vocab_data = tf.data.Dataset.from_tensor_slices(
["the quick brown fox", "the earth is round"]
)

bytes_io = io.BytesIO()
sentencepiece.SentencePieceTrainer.train(
sentence_iterator=vocab_data.as_numpy_iterator(),
model_writer=bytes_io,
vocab_size=12,
model_type="WORD",
pad_id=0,
unk_id=1,
bos_id=2,
eos_id=3,
pad_piece="<pad>",
unk_piece="<unk>",
bos_piece="[CLS]",
eos_piece="[SEP]",
)

proto = bytes_io.getvalue()

tokenizer = keras_nlp.models.AlbertTokenizer(proto=proto)

preprocessor = keras_nlp.models.AlbertMaskedLMPreprocessor(
tokenizer=tokenizer,
# Simplify the example by masking every available token.
mask_selection_rate=1.0,
mask_token_rate=1.0,
random_token_rate=0.0,
mask_selection_length=5,
sequence_length=12,
)
```
"""

def __init__(
self,
tokenizer,
sequence_length=512,
truncate="round_robin",
mask_selection_rate=0.15,
mask_selection_length=96,
mask_token_rate=0.8,
random_token_rate=0.1,
**kwargs,
):
super().__init__(
tokenizer,
sequence_length=sequence_length,
truncate=truncate,
**kwargs,
)

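# The mask generator randomly selects non-special tokens; each selected token
# is replaced with the mask token, a random token, or left unchanged,
# according to the rates above.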
self.masker = MaskedLMMaskGenerator(
mask_selection_rate=mask_selection_rate,
mask_selection_length=mask_selection_length,
mask_token_rate=mask_token_rate,
random_token_rate=random_token_rate,
vocabulary_size=tokenizer.vocabulary_size(),
mask_token_id=tokenizer.mask_token_id,
unselectable_token_ids=[
tokenizer.cls_token_id,
tokenizer.sep_token_id,
tokenizer.pad_token_id,
],
)

def get_config(self):
config = super().get_config()
config.update(
{
"mask_selection_rate": self.masker.mask_selection_rate,
"mask_selection_length": self.masker.mask_selection_length,
"mask_token_rate": self.masker.mask_token_rate,
"random_token_rate": self.masker.random_token_rate,
}
)
return config

def call(self, x, y=None, sample_weight=None):
if y is not None or sample_weight is not None:
logging.warning(
f"{self.__class__.__name__} generates `y` and `sample_weight` "
"based on your input data, but your data already contains `y` "
"or `sample_weight`. Your `y` and `sample_weight` will be "
"ignored."
)

x = super().call(x)
token_ids, segment_ids, padding_mask = (
x["token_ids"],
x["segment_ids"],
x["padding_mask"],
)
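# Apply dynamic masking, recording the positions and original ids of the
# masked tokens so they can be used as labels and sample weights below.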
masker_outputs = self.masker(token_ids)
x = {
"token_ids": masker_outputs["token_ids"],
"segment_ids": segment_ids,
"padding_mask": padding_mask,
"mask_positions": masker_outputs["mask_positions"],
}
y = masker_outputs["mask_ids"]
sample_weight = masker_outputs["mask_weights"]
return pack_x_y_sample_weight(x, y, sample_weight)
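
Note (reviewer aid, not part of the diff): a hedged sketch of the tuple this layer returns for a single raw string, assuming the `albert_base_en_uncased` preset referenced in the docstring; exact ids and shapes depend on the vocabulary and the constructor arguments.

```python
import keras_nlp

preprocessor = keras_nlp.models.AlbertMaskedLMPreprocessor.from_preset(
    "albert_base_en_uncased"
)
x, y, sample_weight = preprocessor("The quick brown fox jumped.")
# x["token_ids"]:      (sequence_length,) ids, some replaced by the mask token.
# x["segment_ids"]:    (sequence_length,) zeros, since there is one segment.
# x["padding_mask"]:   (sequence_length,) 1 for real tokens, 0 for padding.
# x["mask_positions"]: (mask_selection_length,) indices of the masked tokens.
# y:                   (mask_selection_length,) original ids of the masked tokens.
# sample_weight:       (mask_selection_length,) 1.0 for real masks, 0.0 for padding.
```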