Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions keras_nlp/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
from keras_nlp.models.distil_bert.distil_bert_classifier import (
DistilBertClassifier,
)
from keras_nlp.models.distil_bert.distil_bert_masked_lm import (
DistilBertMaskedLM,
)
from keras_nlp.models.distil_bert.distil_bert_masked_lm_preprocessor import (
DistilBertMaskedLMPreprocessor,
)
from keras_nlp.models.distil_bert.distil_bert_preprocessor import (
DistilBertPreprocessor,
)
Expand Down
155 changes: 155 additions & 0 deletions keras_nlp/models/distil_bert/distil_bert_masked_lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DistilBERT masked lm model."""

import copy

from tensorflow import keras

from keras_nlp.layers.masked_lm_head import MaskedLMHead
from keras_nlp.models.distil_bert.distil_bert_backbone import DistilBertBackbone
from keras_nlp.models.distil_bert.distil_bert_backbone import (
distilbert_kernel_initializer,
)
from keras_nlp.models.distil_bert.distil_bert_masked_lm_preprocessor import (
DistilBertMaskedLMPreprocessor,
)
from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets
from keras_nlp.models.task import Task
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class DistilBertMaskedLM(Task):
"""An end-to-end DistilBERT model for the masked language modeling task.

This model will train DistilBERT on a masked language modeling task.
The model will predict labels for a number of masked tokens in the
input data. For usage of this model with pre-trained weights, see the
`from_preset()` method.

This model can optionally be configured with a `preprocessor` layer, in
which case inputs can be raw string features during `fit()`, `predict()`,
and `evaluate()`. Inputs will be tokenized and dynamically masked during
training and evaluation. This is done by default when creating the model
with `from_preset()`.

Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind. The underlying model is provided by a
third party and subject to a separate license, available
[here](https://github.com/facebookresearch/fairseq).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update this to the right disclaimer for the model, distilbert is huggingface


Args:
backbone: A `keras_nlp.models.DistilBertBackbone` instance.
preprocessor: A `keras_nlp.models.DistilBertMaskedLMPreprocessor` or
`None`. If `None`, this model will not apply preprocessing, and
inputs should be preprocessed before calling the model.

Example usage:

Raw string inputs and pretrained backbone.
```python
# Create a dataset with raw string features. Labels are inferred.
features = ["The quick brown fox jumped.", "I forgot my homework."]

# Create a DistilBertMaskedLM with a pretrained backbone and further train
# on an MLM task.
masked_lm = keras_nlp.models.DistilBertMaskedLM.from_preset(
"distil_bert_base_en",
)
masked_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
masked_lm.fit(x=features, batch_size=2)
```

Preprocessed inputs and custom backbone.
```python
# Create a preprocessed dataset where 0 is the mask token.
preprocessed_features = {
"token_ids": tf.constant(
[[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8)
),
"padding_mask": tf.constant(
[[1, 1, 1, 1, 1, 1, 1, 1]] * 2, shape=(2, 8)
),
"mask_positions": tf.constant([[2, 4]] * 2, shape=(2, 2))
}
# Labels are the original masked values.
labels = [[3, 5]] * 2

# Randomly initialize a DistilBERT encoder
backbone = keras_nlp.models.DistilBertBackbone(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a reminder, make sure that the examples actually run correctly. Thanks !

vocabulary_size=50265,
num_layers=12,
num_heads=12,
hidden_dim=768,
intermediate_dim=3072,
max_sequence_length=12
)
# Create a DistilBERT masked_lm and fit the data.
masked_lm = keras_nlp.models.DistilBertMaskedLM(
backbone,
preprocessor=None,
)
masked_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
masked_lm.fit(x=preprocessed_features, y=labels, batch_size=2)
```
"""

def __init__(
self,
backbone,
preprocessor=None,
**kwargs,
):
inputs = {
**backbone.input,
"mask_positions": keras.Input(
shape=(None,), dtype="int32", name="mask_positions"
),
}
backbone_outputs = backbone(backbone.input)
outputs = MaskedLMHead(
vocabulary_size=backbone.vocabulary_size,
embedding_weights=backbone.token_embedding.embeddings,
intermediate_activation="gelu",
kernel_initializer=distilbert_kernel_initializer(),
name="mlm_head",
)(backbone_outputs, inputs["mask_positions"])

# Instantiate using Functional API Model constructor
super().__init__(
inputs=inputs,
outputs=outputs,
include_preprocessing=preprocessor is not None,
**kwargs,
)
# All references to `self` below this line
self.backbone = backbone
self.preprocessor = preprocessor

@classproperty
def backbone_cls(cls):
return DistilBertBackbone

@classproperty
def preprocessor_cls(cls):
return DistilBertMaskedLMPreprocessor

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
177 changes: 177 additions & 0 deletions keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DistilBERT masked language model preprocessor layer."""

from absl import logging
from tensorflow import keras

from keras_nlp.layers.masked_lm_mask_generator import MaskedLMMaskGenerator
from keras_nlp.models.distil_bert.distil_bert_preprocessor import (
DistilBertPreprocessor,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras.utils.register_keras_serializable(package="keras_nlp")
class DistilBertMaskedLMPreprocessor(DistilBertPreprocessor):
"""DistilBERT preprocessing for the masked language modeling task.

This preprocessing layer will prepare inputs for a masked language modeling
task. It is primarily intended for use with the
`keras_nlp.models.DistilBertMaskedLM` task model. Preprocessing will occur in
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix line length

multiple steps.

- Tokenize any number of input segments using the `tokenizer`.
- Pack the inputs together with the appropriate `"<s>"`, `"</s>"` and
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are not the special tokens used by DistilBert, this will need updating.

`"<pad>"` tokens, i.e., adding a single `"<s>"` at the start of the
entire sequence, `"</s></s>"` between each segment,
and a `"</s>"` at the end of the entire sequence.
- Randomly select non-special tokens to mask, controlled by
`mask_selection_rate`.
- Construct a `(x, y, sample_weight)` tuple suitable for training with a
`keras_nlp.models.DistilBertMaskedLM` task model.

Args:
tokenizer: A `keras_nlp.models.DistilBertTokenizer` instance.
sequence_length: The length of the packed inputs.
mask_selection_rate: The probability an input token will be dynamically
masked.
mask_selection_length: The maximum number of masked tokens supported
by the layer.
mask_token_rate: float, defaults to 0.8. `mask_token_rate` must be
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is prexising I know, but let's remove these defaults for this arg and below (we don't do this for other args in this list)

between 0 and 1 which indicates how often the mask_token is
substituted for tokens selected for masking.
random_token_rate: float, defaults to 0.1. `random_token_rate` must be
between 0 and 1 which indicates how often a random token is
substituted for tokens selected for masking. Default is 0.1.
Note: mask_token_rate + random_token_rate <= 1, and for
(1 - mask_token_rate - random_token_rate), the token will not be
changed.
truncate: string. The algorithm to truncate a list of batched segments
to fit within `sequence_length`. The value can be either
`round_robin` or `waterfall`:
- `"round_robin"`: Available space is assigned one token at a
time in a round-robin fashion to the inputs that still need
some, until the limit is reached.
- `"waterfall"`: The allocation of the budget is done using a
"waterfall" algorithm that allocates quota in a
left-to-right manner and fills up the buckets until we run
out of budget. It supports an arbitrary number of segments.

Examples:
```python
# Load the preprocessor from a preset.
preprocessor = keras_nlp.models.DistilBertMaskedLMPreprocessor.from_preset(
"distil_bert_base_en"
)

# Tokenize and mask a single sentence.
sentence = tf.constant("The quick brown fox jumped.")
preprocessor(sentence)

# Tokenize and mask a batch of sentences.
sentences = tf.constant(
["The quick brown fox jumped.", "Call me Ishmael."]
)
preprocessor(sentences)

# Tokenize and mask a dataset of sentences.
features = tf.constant(
["The quick brown fox jumped.", "Call me Ishmael."]
)
ds = tf.data.Dataset.from_tensor_slices((features))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

# Alternatively, you can create a preprocessor from your own vocabulary.
# The usage is exactly the same as above.
vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<mask>": 3}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not the right format for the vocabulary for this model, will need updating.

vocab = {**vocab, "a": 4, "Ġquick": 5, "Ġfox": 6}
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"]
tokenizer = keras_nlp.models.DistilBertTokenizer(
vocabulary=vocab,
merges=merges,
)
preprocessor = keras_nlp.models.DistilBertMaskedLMPreprocessor(
tokenizer=tokenizer,
sequence_length=8,
)
preprocessor("a quick fox")
```
"""

def __init__(
self,
tokenizer,
sequence_length=512,
truncate="round_robin",
mask_selection_rate=0.15,
mask_selection_length=96,
mask_token_rate=0.8,
random_token_rate=0.1,
**kwargs,
):
super().__init__(
tokenizer,
sequence_length=sequence_length,
truncate=truncate,
**kwargs,
)

self.masker = MaskedLMMaskGenerator(
mask_selection_rate=mask_selection_rate,
mask_selection_length=mask_selection_length,
mask_token_rate=mask_token_rate,
random_token_rate=random_token_rate,
vocabulary_size=tokenizer.vocabulary_size(),
mask_token_id=tokenizer.mask_token_id,
unselectable_token_ids=[
tokenizer.cls_token_id,
tokenizer.sep_token_id,
tokenizer.pad_token_id,
],
)

def get_config(self):
config = super().get_config()
config.update(
{
"mask_selection_rate": self.masker.mask_selection_rate,
"mask_selection_length": self.masker.mask_selection_length,
"mask_token_rate": self.masker.mask_token_rate,
"random_token_rate": self.masker.random_token_rate,
}
)
return config

def call(self, x, y=None, sample_weight=None):
if y is not None or sample_weight is not None:
logging.warning(
f"{self.__class__.__name__} generates `y` and `sample_weight` "
"based on your input data, but your data already contains `y` "
"or `sample_weight`. Your `y` and `sample_weight` will be "
"ignored."
)

x = super().call(x)
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
masker_outputs = self.masker(token_ids)
x = {
"token_ids": masker_outputs["token_ids"],
"padding_mask": padding_mask,
"mask_positions": masker_outputs["mask_positions"],
}
y = masker_outputs["mask_ids"]
sample_weight = masker_outputs["mask_weights"]
return pack_x_y_sample_weight(x, y, sample_weight)
Loading