
Commit 5737da9

GPT2 Text Generation APIs (#592)
* initial commit
* Add keras_nlp.samplers
* Change padding to left to right
* more samplers
* Add GPT2 text generation stuff
* correct top-p and beam sampler
* initial commit
* Add keras_nlp.samplers
* Change padding to left to right
* Add serialization support, and move some args from constructor to call
* Add string example
* small changes
* Address comments: fix docstring, remove multicase support
* Address comments: move token_probability_fn to the second place
* some initials
* add more sampler class, and a few changes on the base sampler class
* dummy
* add some arg defaults
* small fix
* fix docstring
* some changes
* add classes
* fix serialization
* fix docstring
* address comments
* one more
* fix docstring
* minor fix
1 parent 9c5f850 commit 5737da9

File tree

8 files changed: +931, -2 lines changed


keras_nlp/models/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -33,6 +33,13 @@
 from keras_nlp.models.f_net.f_net_backbone import FNetBackbone
 from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor
 from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer
+from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
+from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM
+from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
+    GPT2CausalLMPreprocessor,
+)
+from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
+from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
 from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
 from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
 from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
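With these exports in place, the new GPT2 classes resolve directly from the `keras_nlp.models` namespace. A quick usage sketch of the import path (an illustration, not part of the diff):

```python
# The GPT2 classes added by this commit are exposed under keras_nlp.models.
from keras_nlp.models import (
    GPT2Backbone,
    GPT2CausalLM,
    GPT2CausalLMPreprocessor,
    GPT2Preprocessor,
    GPT2Tokenizer,
)
```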
keras_nlp/models/gpt2/gpt2_causal_lm.py

Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 Causal LM (Language Model)."""

import copy

import tensorflow as tf
from tensorflow import keras

import keras_nlp
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
    GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.models.task import Task
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLM(Task):
"""An end-to-end GPT2 model for causal langauge modeling.
34+
35+
A causal language model (LM) predicts the next token based on previous
36+
tokens the next token based on previous tokens, which is the way GPT2 gets
37+
pretrained. You can finetune `GPT2CausalLM` to generate text similar to
38+
the custom dataset. `GPT2CausalLM` also has a method `generate()`, which
39+
generates text based on given prompt.
40+
41+
    This model can optionally be configured with a `preprocessor` layer, in
    which case it will automatically apply preprocessing to raw inputs during
    `fit()`, `predict()`, and `evaluate()`. This is done by default when
    creating the model with `from_preset()`.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind. The underlying model is provided by a
    third party and subject to a separate license, available
    [here](https://github.com/openai/gpt-2).

    Args:
        backbone: A `keras_nlp.models.GPT2Backbone` instance.
        preprocessor: A `keras_nlp.models.GPT2CausalLMPreprocessor` or `None`.
            If `None`, this model will not apply preprocessing, and inputs
            should be preprocessed before calling the model.

    Examples:

    Use the `generate()` method to do text generation.
    ```python
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
    gpt2_lm.generate("I want to say", max_length=30)

    # Generate with batched prompts.
    gpt2_lm.generate(["This is a", "Where are you"], max_length=30)
    ```

    Use a custom sampler for text generation.
    ```python
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

    # Use a string identifier to set the sampler.
    gpt2_lm.generate("I want to say", max_length=30, sampler="top_p")

    # Construct a sampler instance.
    sampler = keras_nlp.samplers.BeamSampler(num_beams=2)
    gpt2_lm.generate("I want to say", max_length=30, sampler=sampler)
    ```

    Map raw strings to language model logit predictions.
    ```python
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
    gpt2_lm.predict(["You know this is just a test string"])
    ```

    Load a pretrained GPT2 and fit on a string dataset.
    ```python
    features = [
        "I don't listen to music while coding.",
        "But I watch youtube while coding!",
    ]
    ds = tf.data.Dataset.from_tensor_slices(features)

    # Create a `GPT2CausalLM` and fit your data.
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
        "gpt2_base_en",
    )
    gpt2_lm.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    gpt2_lm.fit(ds, batch_size=2)
    ```

    Load a pretrained `GPT2CausalLM` with a custom preprocessor, and predict
    on string inputs.
    ```python
    # Use a shorter sequence length.
    preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
        "gpt2_base_en",
        sequence_length=128,
    )

    # Create a `GPT2CausalLM`, using the pretrained GPT2 and custom preprocessor.
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
        "gpt2_base_en",
        preprocessor=preprocessor,
    )
    gpt2_lm.predict(["You know this is still a test string"])
    ```

    Fit your preprocessed data with a randomly initialized GPT2. This is
    useful when you want to do data preprocessing inside a `tf.data` pipeline.
    ```python
    # Define preprocessed input.
    features = {
        "token_ids": tf.constant(
            [[1, 2, 3, 4, 0, 0]] * 2, shape=(2, 6)
        ),
        "padding_mask": tf.constant(
            [[1, 1, 1, 1, 0, 0]] * 2, shape=(2, 6)
        ),
    }
    labels = tf.constant(
        [[2, 3, 4, 0, 0, 0]] * 2, shape=(2, 6)
    )
    sample_weight = tf.constant(
        [[1, 1, 1, 0, 0, 0]] * 2, shape=(2, 6)
    )

    # Randomly initialize a GPT2 backbone.
    backbone = keras_nlp.models.GPT2Backbone(
        vocabulary_size=50257,
        num_layers=2,
        num_heads=2,
        hidden_dim=128,
        intermediate_dim=256,
        max_sequence_length=128,
    )
    # Create a `GPT2CausalLM` without a preprocessor and fit the data.
    gpt2_lm = keras_nlp.models.GPT2CausalLM(backbone, preprocessor=None)
    gpt2_lm.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    gpt2_lm.fit(
        x=features,
        y=labels,
        sample_weight=sample_weight,
        batch_size=2,
    )
    ```
    """

    def __init__(self, backbone, preprocessor=None, **kwargs):
        inputs = backbone.input
        x = backbone(inputs)
        # Use token embedding weights to project from the token representation
        # to vocabulary logits.
        outputs = tf.matmul(
            x,
            backbone.token_embedding.embeddings,
            transpose_b=True,
        )

        # Instantiate using Functional API Model constructor.
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            include_preprocessing=preprocessor is not None,
            **kwargs,
        )

        self._backbone = backbone
        self._preprocessor = preprocessor

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)

    @classproperty
    def backbone_cls(cls):
        return GPT2Backbone

    @classproperty
    def preprocessor_cls(cls):
        return GPT2CausalLMPreprocessor

    def _get_token_probability(self, prompt, mask):
        model_inputs = {
            "token_ids": prompt,
            "padding_mask": mask,
        }
        return self(model_inputs)

    def generate(
        self,
        prompt,
        max_length,
        sampler="top_k",
    ):
        """Generate text.

        This method generates text based on the given `prompt`. Generation
        will continue until `max_length` is met, and all tokens generated
        after `end_token` will be truncated. The sampling approach used can
        be controlled via the `sampler` argument.

        Args:
            prompt: a string, string Tensor or string RaggedTensor. The prompt
                text for generation.
            max_length: int. The max length of the generated sequence.
            sampler: a string or `keras_nlp.samplers.Sampler` instance. The
                sampler to be used for text generation.
        """
        end_token_id = self.preprocessor.tokenizer.end_token_id

        sampler = keras_nlp.samplers.get(sampler)
        if hasattr(self, "jit_compile"):
            # `jit_compile` is a public property as of tf 2.12. hasattr is for
            # backward compat.
            sampler.jit_compile = self.jit_compile
        sampler.run_eagerly = self.run_eagerly
        generated = sampler(
            self.preprocessor.tokenizer(prompt),
            self._get_token_probability,
            max_length=max_length,
            end_token_id=end_token_id,
        )
        return self.preprocessor.tokenizer.detokenize(generated)
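For context on the contract `generate()` relies on: the sampler receives the tokenized prompt plus a token-probability callback (`_get_token_probability` above) that maps `(token_ids, padding_mask)` to per-position vocabulary logits, then repeatedly appends a sampled next token until `max_length` is reached. The sketch below illustrates that callback contract with a standalone greedy-decoding loop; it is an illustration only, not the `keras_nlp.samplers` implementation, and the `token_probability_fn`, `greedy_generate`, and `VOCAB_SIZE` names are assumptions made up for the example.

```python
import tensorflow as tf

VOCAB_SIZE = 8  # toy vocabulary size, for illustration only


def token_probability_fn(token_ids, padding_mask):
    # Stand-in for `GPT2CausalLM._get_token_probability`: returns logits of
    # shape (batch, sequence_length, vocab_size). Random here; in practice
    # these come from the model.
    batch = tf.shape(token_ids)[0]
    length = tf.shape(token_ids)[1]
    return tf.random.uniform((batch, length, VOCAB_SIZE))


def greedy_generate(prompt_ids, max_length):
    # prompt_ids: (batch, prompt_length) int tensor of tokenized prompts.
    token_ids = prompt_ids
    for _ in range(max_length - prompt_ids.shape[1]):
        mask = tf.ones_like(token_ids, dtype=tf.bool)
        logits = token_probability_fn(token_ids, mask)
        # Greedily take the most likely token at the last position and append it.
        next_token = tf.argmax(
            logits[:, -1, :], axis=-1, output_type=token_ids.dtype
        )
        token_ids = tf.concat([token_ids, next_token[:, None]], axis=-1)
    return token_ids


print(greedy_generate(tf.constant([[3, 5, 1]]), max_length=8))
```

The real samplers additionally honor the padding mask and stop decoding at `end_token_id`, as shown in the `generate()` call above.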
keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT2 Causal LM preprocessor layer."""

from absl import logging
from tensorflow import keras

from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLMPreprocessor(GPT2Preprocessor):
"""GPT2 Causal LM preprocessor.
27+
28+
This preprocessor is majorly used as the preprocesor for `GPT2CausalLM`.
29+
This class subclasses `keras_nlp.models.GPT2Preprocessor` and keeps most of
30+
its functionality. The only change is `GPT2CausalLMPreprocessor` sets
31+
`y` (label) and `sample_weights` field by shifting the input sequence one
32+
step towards left, and drop the last token as it does not have a successor,
33+
e.g., if the tokenized input is `[1, 2, 3, 0, 0]` with
34+
`padding_mask = [1, 1, 1, 0, 0]`, then after preprocessing, we
35+
will have `x = [1, 2, 3, 0]` and `y = [2, 3, 0, 0]`, with
36+
`padding_mask = [1, 1, 1, 0]` and `sample_weights = [1, 1, 0, 0]`.
37+
38+
    Args:
        tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
        sequence_length: The length of the packed inputs.

    Examples:
    ```python
    # Load the preprocessor from a preset.
    preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
        "gpt2_base_en"
    )

    # Tokenize and pack a single sentence.
    sentence = tf.constant("League of legends")
    preprocessor(sentence)
    # Same output.
    preprocessor("League of legends")

    # Tokenize a batch of sentences.
    sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
    preprocessor(sentences)
    # Same output.
    preprocessor(["Taco tuesday", "Fish taco please!"])

    # Map a dataset to preprocess a single sentence.
    features = tf.constant(
        [
            "Avatar 2 is amazing!",
            "Well, I am not sure.",
        ]
    )
    labels = tf.constant([1, 0])
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map a dataset to preprocess unlabeled sentences.
    ds = tf.data.Dataset.from_tensor_slices(features)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
    ```
    """

    def call(self, x, y=None, sample_weight=None):
        if y is not None or sample_weight is not None:
            logging.warning(
                "`GPT2CausalLMPreprocessor` generates `y` and `sample_weight` "
                "based on your input data, but your data already contains `y` "
                "or `sample_weight`. Your `y` and `sample_weight` will be "
                "ignored."
            )

        x = super().call(x)
        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
        # The last token does not have a next token, so we truncate it out.
        x = {
            "token_ids": token_ids[..., :-1],
            "padding_mask": padding_mask[..., :-1],
        }
        # Target `y` will be the next token.
        y = token_ids[..., 1:]
        sample_weight = padding_mask[..., 1:]
        return pack_x_y_sample_weight(x, y, sample_weight)
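To make the label construction concrete, here is a small standalone sketch (plain TensorFlow, mirroring the slicing in `call()` above) applied to the example from the class docstring:

```python
import tensorflow as tf

token_ids = tf.constant([[1, 2, 3, 0, 0]])
padding_mask = tf.constant([[1, 1, 1, 0, 0]])

# Inputs drop the last token, which has no successor to predict.
x = {
    "token_ids": token_ids[..., :-1],        # [[1, 2, 3, 0]]
    "padding_mask": padding_mask[..., :-1],  # [[1, 1, 1, 0]]
}
# Labels are the token ids shifted one step to the left.
y = token_ids[..., 1:]                       # [[2, 3, 0, 0]]
# Padded target positions get zero weight in the loss.
sample_weight = padding_mask[..., 1:]        # [[1, 1, 0, 0]]
```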
