
Commit 4d9a9b7

fix serialization
Parent: 9945c13

File tree: 4 files changed (+25, -20 lines)


keras_nlp/models/gpt2/gpt2_causal_lm.py

Lines changed: 16 additions & 7 deletions
@@ -28,8 +28,18 @@
 from keras_nlp.utils.python_utils import classproperty
 
 
-# @keras.utils.register_keras_serializable(package="keras_nlp")
+@keras.utils.register_keras_serializable(package="keras_nlp")
 class EmbeddingMapping(keras.layers.Layer):
+    """A layer multiplying model outputs by the token embedding.
+
+    This layer is used to map model outputs to logits over all vocab tokens.
+    It's used in `GPT2CausalLM` to calculate next token's probability.
+
+    Args:
+        embedding_layer: a `tf.keras.layers.Embedding` instance, the token
+            embedding layer.
+    """
+
     def __init__(self, embedding_layer, name="embedding_mapping", **kwargs):
         super().__init__(name=name, **kwargs)
         self.embedding_layer = embedding_layer
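
Note: re-enabling this decorator is the heart of the "fix serialization" commit. `keras.utils.register_keras_serializable` records the class in Keras' serialization registry (here under `keras_nlp>EmbeddingMapping`), so `keras.models.load_model` can resolve the layer by name without a `custom_objects` dict. A minimal sketch of such a registered layer follows; the `call` body is an assumption inferred from the docstring, since this diff only shows `__init__`:

```python
import tensorflow as tf
from tensorflow import keras


@keras.utils.register_keras_serializable(package="keras_nlp_demo")
class EmbeddingMapping(keras.layers.Layer):
    """Maps hidden states to vocab logits via the tied token embedding."""

    def __init__(self, embedding_layer, name="embedding_mapping", **kwargs):
        super().__init__(name=name, **kwargs)
        self.embedding_layer = embedding_layer

    def call(self, inputs):
        # Assumed behavior: multiply by the transposed embedding matrix,
        # mapping (batch, seq, hidden) -> (batch, seq, vocab).
        return tf.matmul(
            inputs, self.embedding_layer.embeddings, transpose_b=True
        )
```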
@@ -198,10 +208,9 @@ class GPT2CausalLM(PipelineModel):
     def __init__(self, backbone, preprocessor=None, **kwargs):
         inputs = backbone.input
         x = backbone(inputs)
-        # embedding_layer = backbone.get_layer("token_embedding")
-        # embedding_map_layer = EmbeddingMapping(embedding_layer)
-        # outputs = embedding_map_layer(x)
-        outputs = x
+        embedding_layer = backbone.get_layer("token_embedding")
+        embedding_map_layer = EmbeddingMapping(embedding_layer)
+        outputs = embedding_map_layer(x)
 
         # Instantiate using Functional API Model constructor
         super().__init__(
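
With the block above restored, the functional model ends in vocab-sized logits rather than raw hidden states. A runnable toy version of the wiring, with made-up dimensions (vocab 100, hidden 16) standing in for GPT-2's real ones:

```python
import tensorflow as tf
from tensorflow import keras

embedding = keras.layers.Embedding(100, 16, name="token_embedding")
token_ids = keras.Input(shape=(None,), dtype="int32")
hidden = embedding(token_ids)  # stand-in for the backbone's final hidden states

# Tie the output projection to the embedding weights, as the diff does
# through EmbeddingMapping.
logits = tf.matmul(hidden, embedding.embeddings, transpose_b=True)
model = keras.Model(token_ids, logits)

print(model(tf.constant([[1, 2, 3]])).shape)  # (1, 3, 100)
```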
@@ -219,12 +228,12 @@ def preprocess_samples(self, x, y=None, sample_weight=None):
 
     @property
     def backbone(self):
-        """The associated `keras_nlp.models.RobertaBackbone`."""
+        """The associated `keras_nlp.models.GPT2Backbone`."""
         return self._backbone
 
     @property
     def preprocessor(self):
-        """A `keras_nlp.models.RobertaMaskedLMPreprocessor` for preprocessing inputs."""
+        """A `keras_nlp.models.GPT2CausalLMPreprocessor` for preprocessing."""
         return self._preprocessor
 
     @classproperty

keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py

Lines changed: 2 additions & 0 deletions
@@ -14,11 +14,13 @@
 
 """GPT2 Causal LM preprocessor layer."""
 
+from tensorflow import keras
 
 from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
 from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
 
 
+@keras.utils.register_keras_serializable(package="keras_nlp")
 class GPT2CausalLMPreprocessor(GPT2Preprocessor):
     """GPT2 Causal LM preprocessor.
 
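Registering the preprocessor classes serves the same purpose as registering `EmbeddingMapping` above: anything saved with a `GPT2CausalLMPreprocessor` attached can be rebuilt by name at load time. A generic illustration of the registry round-trip, using a toy layer rather than the real preprocessor:

```python
import tensorflow as tf
from tensorflow import keras


@keras.utils.register_keras_serializable(package="demo")
class Doubler(keras.layers.Layer):
    def call(self, inputs):
        return inputs * 2


# serialize() records the registered name "demo>Doubler", so deserialize()
# can rebuild the layer without a custom_objects dict.
config = keras.layers.serialize(Doubler())
restored = keras.layers.deserialize(config)
print(restored(tf.constant([1.0])))  # [2.0]
```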
keras_nlp/models/gpt2/gpt2_preprocessor.py

Lines changed: 7 additions & 10 deletions
@@ -17,6 +17,7 @@
 import copy
 
 import tensorflow as tf
+from tensorflow import keras
 
 from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
 from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
@@ -28,13 +29,13 @@
 from keras_nlp.utils.python_utils import classproperty
 
 
+@keras.utils.register_keras_serializable(package="keras_nlp")
 class GPT2Preprocessor(Preprocessor):
     """GPT2 preprocessing layer which tokenizes and packs inputs.
 
-    This preprocessing layer will do three things:
+    This preprocessing layer will do 2 things:
 
     - Tokenize the input using the `tokenizer`.
-    - Add the id of '<|endoftext|>' to the start and end of the tokenized input.
     - Construct a dictionary with keys `"token_ids"`, `"padding_mask"`, that can
       be passed directly to a `keras_nlp.models.GPT2Backbone`.
@@ -135,23 +136,19 @@ def call(self, x, y=None, sample_weight=None):
         if len(x) > 1:
             raise ValueError(
                 "GPT2 requires each input feature to contain only "
-                f"one segment, but received: {len(x)}. If you are using GPT2 "
+                f"one segment, but received {len(x)}. If you are using GPT2 "
                 "for a multi-segment classification task, please refer to "
                 "classification models like BERT or RoBERTa."
             )
         token_ids = self._tokenizer(x[0])
-        # batch_size = token_ids.nrows()
-        # start_column = tf.fill((batch_size, 1), self._tokenizer.end_token_id)
-        # end_column = tf.fill((batch_size, 1), self._tokenizer.end_token_id)
-        # token_ids = tf.concat([start_column, token_ids, end_column], axis=1)
-        input_is_1d = False
-        if len(token_ids.shape) == 1:
-            input_is_1d = True
+        input_is_1d = len(token_ids.shape) == 1
+        if input_is_1d:
             token_ids = tf.RaggedTensor.from_tensor([token_ids])
         mask = tf.ones_like(token_ids, dtype=tf.bool)
         mask = mask.to_tensor(shape=(None, self.sequence_length))
         token_ids = token_ids.to_tensor(shape=(None, self.sequence_length))
         if input_is_1d:
+            # If the input is a single string, we let the output be a 1D tensor.
             token_ids = tf.squeeze(token_ids, axis=0)
             mask = tf.squeeze(mask, axis=0)
         x = {
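
The refactor also highlights the padding scheme this method relies on: ragged token batches are densified to a fixed width with `to_tensor(shape=(None, sequence_length))`, and a single-string input is temporarily promoted to a batch of one and squeezed back afterwards. A standalone sketch of that behavior, with a made-up `sequence_length` of 6:

```python
import tensorflow as tf

sequence_length = 6  # stand-in for self.sequence_length
token_ids = tf.ragged.constant([[5, 8, 9], [3, 1]])

mask = tf.ones_like(token_ids, dtype=tf.bool)
# Pad every row out to sequence_length; default padding is 0 / False.
mask = mask.to_tensor(shape=(None, sequence_length))
token_ids = token_ids.to_tensor(shape=(None, sequence_length))

print(token_ids.numpy())  # [[5 8 9 0 0 0], [3 1 0 0 0 0]]
print(mask.numpy())       # [[T T T F F F], [T T F F F F]]
```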

keras_nlp/models/roberta/roberta_preprocessor_test.py

Lines changed: 0 additions & 3 deletions
@@ -98,9 +98,6 @@ def test_tokenize_labeled_dataset(self):
         sw = tf.constant([1.0] * 4)
         ds = tf.data.Dataset.from_tensor_slices((x, y, sw))
         ds = ds.map(self.preprocessor)
-        import pdb
-
-        pdb.set_trace()
         x_out, y_out, sw_out = ds.batch(4).take(1).get_single_element()
         self.assertAllEqual(
             x_out["token_ids"], [[0, 3, 4, 5, 3, 6, 2, 1, 1, 1, 1, 1]] * 4
