Merged
Commits
33 commits
7f7ae43
initial commit
chenmoneygithub Dec 9, 2022
c53b4a9
Add keras_nlp.samplers
chenmoneygithub Dec 10, 2022
e6483a4
Change padding to left to right
chenmoneygithub Dec 12, 2022
513121e
more samplers
chenmoneygithub Dec 20, 2022
0eb68f6
Add GPT2 text generation stuff
chenmoneygithub Dec 21, 2022
fa41d23
correct top-p and beam sampler
chenmoneygithub Jan 4, 2023
26fd509
initial commit
chenmoneygithub Dec 9, 2022
7e4c651
Add keras_nlp.samplers
chenmoneygithub Dec 10, 2022
28bcfe1
Change padding to left to right
chenmoneygithub Dec 12, 2022
9757f4d
Add serialization support, and move some args from constructor to call
chenmoneygithub Jan 5, 2023
f7508cb
Add string example
chenmoneygithub Jan 6, 2023
b658b61
small changes
chenmoneygithub Jan 6, 2023
76c430c
Address comments: fix docstring, remove multicase support
chenmoneygithub Jan 9, 2023
bb430dd
Address comments: move token_probability_fn to the second place
chenmoneygithub Jan 9, 2023
afd3082
some initials
chenmoneygithub Jan 10, 2023
273a6a5
Merge branch 'master' into text-generation-extend
chenmoneygithub Jan 10, 2023
31ad970
add more sampler class, and a few changes on the base sampler class
chenmoneygithub Jan 13, 2023
331f568
Merge branch 'text-generation-extend' into text-generation-playground
chenmoneygithub Jan 13, 2023
5300800
dummy
chenmoneygithub Jan 13, 2023
de2ac9c
add some arg defaults
chenmoneygithub Jan 13, 2023
42c164f
Merge branch 'text-generation-extend' into text-generation-playground
chenmoneygithub Jan 13, 2023
08f3c1e
small fix
chenmoneygithub Jan 13, 2023
2b93ad8
fix docstring
chenmoneygithub Jan 17, 2023
309d6d4
merge
chenmoneygithub Jan 18, 2023
8206103
some changes
chenmoneygithub Jan 19, 2023
9945c13
add classes
chenmoneygithub Jan 20, 2023
4fa8fc5
fix serialization
chenmoneygithub Jan 20, 2023
cb12604
fix docstring
chenmoneygithub Jan 23, 2023
f7685ca
address comments
chenmoneygithub Jan 24, 2023
3bac2ad
Merge branch 'master' into text-generation-playground
chenmoneygithub Jan 25, 2023
2ed9adb
one more
chenmoneygithub Jan 25, 2023
f2821b5
fix docstring
chenmoneygithub Jan 26, 2023
728a471
minor fix
chenmoneygithub Jan 27, 2023
7 changes: 7 additions & 0 deletions keras_nlp/models/__init__.py
@@ -33,6 +33,13 @@
from keras_nlp.models.f_net.f_net_backbone import FNetBackbone
from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor
from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
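With these exports, the new GPT2 classes become available from the top-level `keras_nlp.models` namespace. A minimal sketch of what that enables, assuming the `gpt2_base_en` preset used throughout this PR:

```python
import keras_nlp

# The task model bundles the backbone, tokenizer, and preprocessor.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# The tokenizer and preprocessor are also exported on their own.
tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset("gpt2_base_en")
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en"
)
```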
239 changes: 239 additions & 0 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -0,0 +1,239 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 Causal LM (Language Model)."""

import copy

import tensorflow as tf
from tensorflow import keras

import keras_nlp
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.models.task import Task
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLM(Task):
"""An end-to-end GPT2 model for causal langauge modeling.

A causal language model (LM) predicts the next token based on previous
tokens the next token based on previous tokens, which is the way GPT2 gets
pretrained. You can finetune `GPT2CausalLM` to generate text similar to
the custom dataset. `GPT2CausalLM` also has a method `generate()`, which
generates text based on given prompt.

This model can optionally be configured with a `preprocessor` layer, in
which case it will automatically apply preprocessing to raw inputs during
`fit()`, `predict()`, and `evaluate()`. This is done by default when
creating the model with `from_preset()`.

Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind. The underlying model is provided by a
third party and subject to a separate license, available
[here](https://github.com/openai/gpt-2).

Args:
backbone: A `keras_nlp.models.GPT2Backbone` instance.
preprocessor: A `keras_nlp.models.GPT2CausalLMPreprocessor` or `None`.
If `None`, this model will not apply preprocessing, and inputs
should be preprocessed before calling the model.

Examples:

Use the `generate()` method to do text generation.
```python
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.generate("I want to say", max_length=30)

# Generate with batched prompts.
gpt2_lm.generate(["This is a", "Where are you"], max_length=30)
```

Use a custom sampler for text generation.
```python
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Use string identifier to set sampler.
gpt2_lm.generate("I want to say", max_length=30, sampler="top_p")

# Construct a sampler instance.
sampler = keras_nlp.samplers.BeamSampler(num_beams=2)
gpt2_lm.generate("I want to say", max_length=30, sampler=sampler)
```

Map raw strings to language model logit predictions.
```python
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.predict(["You know this is just a test string"])
```

Load a pretrained GPT2 and fit on a string dataset.
```python
features = [
"I don't listen to music while coding.",
"But I watch youtube while coding!",
]
ds = tf.data.Dataset.from_tensor_slices(features)

# Create a `GPT2CausalLM` and fit your data.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
"gpt2_base_en",
)
gpt2_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
gpt2_lm.fit(ds, batch_size=2)
```

Load a pretrained `GPT2CausalLM` with a custom preprocessor, and predict on
string inputs.
```python
# Use a shorter sequence length.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
"gpt2_base_en",
sequence_length=128,
)

# Create a `GPT2CausalLM`, using pretrained GPT2 and custom preprocessor.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
"gpt2_base_en",
preprocessor=preprocessor,
)
gpt2_lm.predict(["You know this is still a test string"])
```

Fit preprocessed data with a randomly initialized GPT2. This is useful
when you want to do data preprocessing inside a `tf.data` pipeline.
```python
# Define preprocessed input.
features = {
"token_ids": tf.constant(
[[1, 2, 3, 4, 0, 0]] * 2, shape=(2, 6)
),
"padding_mask": tf.constant(
[[1, 1, 1, 1, 0, 0]] * 2, shape=(2, 6)
),
}
labels = tf.constant(
[[2, 3, 4, 0, 0, 0]] * 2, shape=(2, 6)
)
sample_weight = tf.constant(
[[1, 1, 1, 0, 0, 0]] * 2, shape=(2, 6)
)

# Randomly initialize a GPT2 backbone.
backbone = keras_nlp.models.GPT2Backbone(
vocabulary_size=50257,
num_layers=2,
num_heads=2,
hidden_dim=128,
intermediate_dim=256,
max_sequence_length=128,
)
# Create a `GPT2CausalLM` without preprocessor and fit the data.
gpt2_lm = keras_nlp.models.GPT2CausalLM(backbone, preprocessor=None)
gpt2_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
gpt2_lm.fit(
x=features,
y=labels,
sample_weight=sample_weight,
batch_size=2,
)
```

"""

def __init__(self, backbone, preprocessor=None, **kwargs):
inputs = backbone.input
x = backbone(inputs)
# Use token embedding weights to project from the token representation
# to vocabulary logits.
outputs = tf.matmul(
x,
backbone.token_embedding.embeddings,
transpose_b=True,
)

# Instantiate using Functional API Model constructor.
super().__init__(
inputs=inputs,
outputs=outputs,
include_preprocessing=preprocessor is not None,
**kwargs,
)

self._backbone = backbone
self._preprocessor = preprocessor

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)

@classproperty
def backbone_cls(cls):
return GPT2Backbone

@classproperty
def preprocessor_cls(cls):
return GPT2CausalLMPreprocessor

def _get_token_probability(self, prompt, mask):
model_inputs = {
"token_ids": prompt,
"padding_mask": mask,
}
return self(model_inputs)

def generate(
self,
prompt,
max_length,
sampler="top_k",
Member
just curious, why was top-k chosen as the default?

Contributor Author
It's working well with my finetuning tasks. I feel we may want to later change this default to contrastive search, which is not yet available.
):
"""Generate text.

This method generates text based on the given `prompt`. Generation
continues until `max_length` is met, and all tokens generated after the
end token will be truncated. The sampling approach can be controlled via
the `sampler` argument.

Args:
prompt: a string, string Tensor or string RaggedTensor. The prompt
text for generation.
max_length: int. The maximum length of the generated sequence.
sampler: a string or `keras_nlp.samplers.Sampler` instance. The
sampler to be used for text generation.
"""
end_token_id = self.preprocessor.tokenizer.end_token_id

sampler = keras_nlp.samplers.get(sampler)
if hasattr(self, "jit_compile"):
# `jit_compile` is a public property as of TF 2.12; `hasattr` is for
# backward compatibility.
sampler.jit_compile = self.jit_compile
sampler.run_eagerly = self.run_eagerly
generated = sampler(
self.preprocessor.tokenizer(prompt),
self._get_token_probability,
max_length=max_length,
end_token_id=end_token_id,
)
return self.preprocessor.tokenizer.detokenize(generated)
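For reference, a minimal sketch of the flow `generate()` implements, written out step by step using only the calls that appear in this diff (an illustration, not part of the PR):

```python
import keras_nlp

gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# 1. Resolve a sampler from a string identifier (or pass a
#    `keras_nlp.samplers.Sampler` instance directly).
sampler = keras_nlp.samplers.get("top_k")

# 2. Tokenize the raw prompt and sample new tokens until `max_length`,
#    truncating everything after the end token.
token_ids = sampler(
    gpt2_lm.preprocessor.tokenizer("I want to say"),
    gpt2_lm._get_token_probability,
    max_length=30,
    end_token_id=gpt2_lm.preprocessor.tokenizer.end_token_id,
)

# 3. Detokenize the sampled token ids back to a string.
text = gpt2_lm.preprocessor.tokenizer.detokenize(token_ids)
```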
96 changes: 96 additions & 0 deletions keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
@@ -0,0 +1,96 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT2 Causal LM preprocessor layer."""

from absl import logging
from tensorflow import keras

from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLMPreprocessor(GPT2Preprocessor):
"""GPT2 Causal LM preprocessor.

This preprocessor is primarily used as the preprocessor for
`GPT2CausalLM`. This class subclasses `keras_nlp.models.GPT2Preprocessor`
and keeps most of its functionality. The only change is that
`GPT2CausalLMPreprocessor` sets the `y` (label) and `sample_weight`
fields by shifting the input sequence one step to the left, dropping the
last token as it does not have a successor. For example, if the
tokenized input is `[1, 2, 3, 0, 0]` with `padding_mask = [1, 1, 1, 0, 0]`,
then after preprocessing we will have `x = [1, 2, 3, 0]` and
`y = [2, 3, 0, 0]`, with `padding_mask = [1, 1, 1, 0]` and
`sample_weight = [1, 1, 0, 0]`.

Args:
tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
sequence_length: The length of the packed inputs.

Examples:
```python
# Load the preprocessor from a preset.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
"gpt2_base_en"
)

# Tokenize and pack a single sentence.
sentence = tf.constant("League of legends")
preprocessor(sentence)
# Same output.
preprocessor("League of legends")

# Tokenize a batch of sentences.
sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
preprocessor(sentences)
# Same output.
preprocessor(["Taco tuesday", "Fish taco please!"])

# Map a dataset to preprocess a single sentence.
features = tf.constant(
[
"Avatar 2 is amazing!",
"Well, I am not sure.",
]
)
labels = tf.constant([1, 0])
ds = tf.data.Dataset.from_tensor_slices((features, labels))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

# Map a dataset to preprocess unlabeled sentences.
ds = tf.data.Dataset.from_tensor_slices(features)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
"""

def call(self, x, y=None, sample_weight=None):
if y is not None or sample_weight is not None:
logging.warning(
"`GPT2CausalLMPreprocessor` generates `y` and `sample_weight` "
"based on your input data, but your data already contains `y` "
"or `sample_weight`. Your `y` and `sample_weight` will be "
"ignored."
)

x = super().call(x)
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
# The last token does not have a next token, so we truncate it out.
x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
# Target `y` will be the next token.
y = token_ids[..., 1:]
sample_weight = padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)
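The shift in `call()` above is easy to verify in isolation. A standalone illustration of the same slicing (not part of the PR), using the example from the docstring:

```python
import tensorflow as tf

token_ids = tf.constant([[1, 2, 3, 0, 0]])
padding_mask = tf.constant([[1, 1, 1, 0, 0]])

# The last token has no successor, so it is truncated from the inputs.
x = {
    "token_ids": token_ids[..., :-1],  # [[1, 2, 3, 0]]
    "padding_mask": padding_mask[..., :-1],  # [[1, 1, 1, 0]]
}
# Labels are the inputs shifted one step to the left.
y = token_ids[..., 1:]  # [[2, 3, 0, 0]]
# Padding positions get zero sample weight, so they are ignored in the loss.
sample_weight = padding_mask[..., 1:]  # [[1, 1, 0, 0]]
```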