Skip to content

Commit e3a82a6

Browse files
Add GPT2 text generation stuff
1 parent 513121e commit e3a82a6

File tree

6 files changed

+237
-4
lines changed

6 files changed

+237
-4
lines changed

keras_nlp/models/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@
2626
from keras_nlp.models.distil_bert.distil_bert_tokenizer import (
2727
DistilBertTokenizer,
2828
)
29+
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
30+
from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM
31+
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2CausalLMPreprocessor
32+
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
33+
from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
2934
from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
3035
from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
3136
from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright 2022 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""GPT2 causal LM model and text generation."""
15+
16+
import copy
17+
18+
import tensorflow as tf
19+
from tensorflow import keras
20+
21+
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
22+
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2CausalLMPreprocessor
23+
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
24+
from keras_nlp.samplers.beam_sampler import BeamSampler
25+
from keras_nlp.samplers.greedy_sampler import GreedySampler
26+
from keras_nlp.samplers.top_k_sampler import TopKSampler
27+
from keras_nlp.samplers.top_p_sampler import TopPSampler
28+
from keras_nlp.utils.pipeline_model import PipelineModel
29+
from keras_nlp.utils.python_utils import classproperty
30+
31+
32+
@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLM(PipelineModel):
    """GPT-2 causal language model.

    Wraps a `GPT2Backbone` with a language-model head: the backbone's
    hidden states are projected back onto the token embedding matrix
    (weight tying via `transpose_b`) and normalized with a softmax,
    producing per-position next-token probabilities.

    Args:
        backbone: A `GPT2Backbone` instance.
        preprocessor: Optional preprocessor (e.g.
            `GPT2CausalLMPreprocessor`). When set, raw inputs passed to
            the model are preprocessed in-pipeline.
    """

    def __init__(self, backbone, preprocessor=None, **kwargs):
        inputs = backbone.input
        x = backbone(inputs)
        # Weight tying: reuse the token embedding matrix as the output
        # projection instead of training a separate dense layer.
        x = tf.matmul(
            x,
            backbone.get_layer("token_embedding").embeddings,
            transpose_b=True,
        )
        # Use `keras.layers` for consistency with the file-wide
        # `from tensorflow import keras` import (was `tf.keras.layers`).
        outputs = keras.layers.Softmax()(x)
        # Instantiate using Functional API Model constructor.
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            include_preprocessing=preprocessor is not None,
            **kwargs,
        )
        self.preprocessor = preprocessor
        self.backbone = backbone

    def preprocess_samples(self, x, y=None, sample_weight=None):
        """Run the attached preprocessor on a batch of samples."""
        return self.preprocessor(x, y=y, sample_weight=sample_weight)

    @classproperty
    def presets(cls):
        """Dict of preset names to preset metadata."""
        return copy.deepcopy(backbone_presets)

    @classmethod
    def from_preset(
        cls,
        preset,
        load_weights=True,
        **kwargs,
    ):
        """Instantiate a `GPT2CausalLM` from a named preset.

        Args:
            preset: Name of the preset (one of `cls.presets`).
            load_weights: Whether to load pretrained weights.

        Raises:
            ValueError: If `preset` is not a known preset name.
        """
        if "preprocessor" not in kwargs:
            kwargs["preprocessor"] = GPT2CausalLMPreprocessor.from_preset(
                preset
            )

        # Check if preset is backbone-only model.
        if preset in GPT2Backbone.presets:
            backbone = GPT2Backbone.from_preset(preset, load_weights)
            return cls(backbone, **kwargs)

        # Otherwise must be one of class presets.
        # Currently no classifier-level presets, so we raise ValueError.
        if preset not in cls.presets:
            raise ValueError(
                "`preset` must be one of "
                f"""{", ".join(cls.presets)}. Received: {preset}."""
            )

    def _get_generator(self, identifier):
        # Map a string identifier to a concrete sampler instance.
        # NOTE(review): sampler hyperparameters (k=5, p=0.1, num_beams=5)
        # are hard-coded here; pass a sampler instance to `generate` to
        # customize them.
        maps = {
            "greedy": GreedySampler(),
            "top_k": TopKSampler(k=5, from_logits=False),
            "top_p": TopPSampler(p=0.1, from_logits=False),
            "beam": BeamSampler(num_beams=5),
        }
        return maps[identifier]

    def _get_token_probability(self, prompt, mask):
        # Next-token probability callback handed to the samplers.
        model_inputs = {
            "token_ids": prompt,
            "padding_mask": mask,
        }
        return self(model_inputs)

    def generate(self, prompt, max_length, generator="top_k"):
        """Generate text continuing `prompt` up to `max_length` tokens.

        Args:
            prompt: Raw string(s) to continue.
            max_length: Total token length of the generated sequence.
            generator: A sampler instance, or one of the string
                identifiers `"greedy"`, `"top_k"`, `"top_p"`, `"beam"`.
        """
        if isinstance(generator, str):
            generator = self._get_generator(generator)
        prompt = self.preprocessor.tokenizer(prompt)
        generated = generator(self._get_token_probability, prompt, max_length)
        return self.preprocessor.tokenizer.detokenize(generated)
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2022 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""GPT2 preprocessor layer."""
16+
17+
import copy
18+
19+
import tensorflow as tf
20+
from tensorflow import keras
21+
22+
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
23+
from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
24+
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
25+
from keras_nlp.utils.python_utils import classproperty
26+
27+
28+
class GPT2Preprocessor(keras.layers.Layer):
    """Preprocessing layer for GPT-2.

    Tokenizes raw string inputs and densifies them to a fixed
    `sequence_length`, producing `token_ids` and a boolean
    `padding_mask` (True at real-token positions).
    """

    def __init__(self, tokenizer, sequence_length, **kwargs):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length

    def call(self, x, y=None, sample_weight=None):
        ragged_ids = self.tokenizer(x)
        # Mark every produced token position as valid before padding.
        ragged_mask = tf.ones_like(ragged_ids, dtype=tf.bool)
        dense_shape = (None, self.sequence_length)
        padding_mask = ragged_mask.to_tensor(shape=dense_shape)
        token_ids = ragged_ids.to_tensor(shape=dense_shape)
        features = {
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }
        return pack_x_y_sample_weight(features, y, sample_weight)

    @classproperty
    def presets(cls):
        """Dict of preset names to preset metadata."""
        return copy.deepcopy(backbone_presets)

    @classmethod
    def from_preset(
        cls,
        preset,
        sequence_length=None,
        **kwargs,
    ):
        """Build a `GPT2Preprocessor` from a named preset.

        Raises:
            ValueError: If `preset` is unknown, or if `sequence_length`
                exceeds the preset's `max_sequence_length`.
        """
        if preset not in cls.presets:
            raise ValueError(
                "`preset` must be one of "
                f"""{", ".join(cls.presets)}. Received: {preset}."""
            )

        tokenizer = GPT2Tokenizer.from_preset(preset)

        # Use model's `max_sequence_length` if `sequence_length` unspecified;
        # otherwise check that `sequence_length` not too long.
        metadata = cls.presets[preset]
        max_sequence_length = metadata["config"]["max_sequence_length"]
        if sequence_length is None:
            sequence_length = max_sequence_length
        elif sequence_length > max_sequence_length:
            raise ValueError(
                f"`sequence_length` cannot be longer than `{preset}` "
                f"preset's `max_sequence_length` of {max_sequence_length}. "
                f"Received: {sequence_length}."
            )

        return cls(
            tokenizer=tokenizer,
            sequence_length=sequence_length,
            **kwargs,
        )
86+
87+
88+
class GPT2CausalLMPreprocessor(GPT2Preprocessor):
    """Preprocessor for GPT-2 causal language modeling.

    Produces shifted (features, labels, sample_weights): the model
    receives tokens `[0:-1]` as input and is trained to predict tokens
    `[1:]`.
    """

    def call(self, x, y=None, sample_weight=None):
        token_ids = self.tokenizer(x)
        mask = tf.ones_like(token_ids, dtype=tf.bool)
        mask = mask.to_tensor(shape=(None, self.sequence_length))
        token_ids = token_ids.to_tensor(shape=(None, self.sequence_length))
        x = {
            "token_ids": token_ids[:, :-1],
            # Fix: the padding mask must align with the *input* tokens
            # (positions [0:-1]). The original used `mask[:, 1:]`, which
            # is offset by one position relative to `token_ids`.
            "padding_mask": mask[:, :-1],
        }
        # Labels are the input sequence shifted left by one token.
        y = token_ids[:, 1:]
        # Weight each label by whether it is a real (non-padding) token.
        sample_weight = mask[:, 1:]
        return pack_x_y_sample_weight(x, y, sample_weight)

keras_nlp/samplers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from keras_nlp.samplers.beam_sampler import BeamSampler
1516
from keras_nlp.samplers.greedy_sampler import GreedySampler
17+
from keras_nlp.samplers.top_k_sampler import TopKSampler
18+
from keras_nlp.samplers.top_p_sampler import TopPSampler

keras_nlp/samplers/top_k_sampler renamed to keras_nlp/samplers/top_k_sampler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""Top-k Sampler."""
1515

1616
import tensorflow as tf
17+
from tensorflow import keras
1718

1819
from keras_nlp.samplers.sampler import Sampler
1920
from keras_nlp.samplers.sampler import base_sampler_keyword_args

keras_nlp/samplers/top_p_sampler.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def sample(self, token_probability_fn, prompt, mask, num_steps):
5858
max_length = tf.cast(max_length, num_steps.dtype)
5959
length = max_length - num_steps
6060

61-
def one_step(length, prompt):
62-
pred = token_probability_fn(prompt[:, :length])
61+
def one_step(length, prompt, mask):
62+
pred = token_probability_fn(prompt[:, :length], mask)
6363
if self.from_logits:
6464
pred = keras.activations.softmax(pred, axis=-1)
6565
# Sort preds in descending order.
@@ -91,6 +91,18 @@ def one_step(length, prompt):
9191
mask[:, length], prompt[:, length], next_token
9292
)
9393

94+
mask = tf.tensor_scatter_nd_update(
95+
tensor=mask,
96+
indices=tf.stack(
97+
(
98+
tf.cast(tf.range(batch_size), dtype=length.dtype),
99+
tf.repeat(length, batch_size),
100+
),
101+
axis=1,
102+
),
103+
updates=tf.repeat(True, batch_size),
104+
)
105+
94106
# Append the next token to current sequence.
95107
prompt = tf.tensor_scatter_nd_update(
96108
tensor=prompt,
@@ -105,13 +117,13 @@ def one_step(length, prompt):
105117
)
106118

107119
length = tf.add(length, 1)
108-
return (length, prompt)
120+
return (length, prompt, mask)
109121

110122
# Run a while loop till text of length `max_length` has been generated.
111123
length, prompt = tf.while_loop(
112124
cond=lambda length, _: tf.less(length, max_length),
113125
body=one_step,
114-
loop_vars=(length, prompt),
126+
loop_vars=(length, prompt, mask),
115127
)
116128

117129
return prompt

0 commit comments

Comments
 (0)