@@ -17,6 +17,7 @@
 import copy
 
 import tensorflow as tf
+from tensorflow import keras
 
 from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
 from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
@@ -28,13 +29,13 @@
 from keras_nlp.utils.python_utils import classproperty
 
 
+@keras.utils.register_keras_serializable(package="keras_nlp")
 class GPT2Preprocessor(Preprocessor):
     """GPT2 preprocessing layer which tokenizes and packs inputs.
 
-    This preprocessing layer will do three things:
+    This preprocessing layer will do two things:
 
     - Tokenize the input using the `tokenizer`.
-    - Add the id of '<|endoftext|>' to the start and end of the tokenized input.
     - Construct a dictionary with keys `"token_ids"`, `"padding_mask"`, that can
       be passed directly to a `keras_nlp.models.GPT2Backbone`.
 
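The new `@keras.utils.register_keras_serializable` decorator registers the class under the `keras_nlp` package, so Keras can rebuild the layer from a serialized config without a `custom_objects` mapping. Below is a minimal sketch of the round-trip this enables; the preset name is illustrative, any valid GPT2 preset would do:

```python
from tensorflow import keras

from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor

# Illustrative preset name.
preprocessor = GPT2Preprocessor.from_preset("gpt2_base_en")

# Serialize to a plain config dict and rebuild the layer. Because the
# class is registered under "keras_nlp", no custom_objects are needed.
config = keras.layers.serialize(preprocessor)
restored = keras.layers.deserialize(config)
```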
@@ -135,23 +136,19 @@ def call(self, x, y=None, sample_weight=None):
         if len(x) > 1:
             raise ValueError(
                 "GPT2 requires each input feature to contain only "
-                f"one segment, but received: {len(x)}. If you are using GPT2 "
+                f"one segment, but received {len(x)}. If you are using GPT2 "
                 "for a multi-segment classification task, please refer to "
                 "classification models like BERT or RoBERTa."
             )
         token_ids = self._tokenizer(x[0])
-        # batch_size = token_ids.nrows()
-        # start_column = tf.fill((batch_size, 1), self._tokenizer.end_token_id)
-        # end_column = tf.fill((batch_size, 1), self._tokenizer.end_token_id)
-        # token_ids = tf.concat([start_column, token_ids, end_column], axis=1)
-        input_is_1d = False
-        if len(token_ids.shape) == 1:
-            input_is_1d = True
+        input_is_1d = len(token_ids.shape) == 1
+        if input_is_1d:
             token_ids = tf.RaggedTensor.from_tensor([token_ids])
         mask = tf.ones_like(token_ids, dtype=tf.bool)
         mask = mask.to_tensor(shape=(None, self.sequence_length))
         token_ids = token_ids.to_tensor(shape=(None, self.sequence_length))
         if input_is_1d:
+            # If the input is a single string, we let the output be a 1D tensor.
             token_ids = tf.squeeze(token_ids, axis=0)
             mask = tf.squeeze(mask, axis=0)
         x = {
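The rewritten block relies on `tf.RaggedTensor.to_tensor` to right-pad every row to `sequence_length`, with the boolean mask marking which positions hold real tokens. A standalone sketch of that trick, assuming a ragged batch of token ids and a sequence length of 6:

```python
import tensorflow as tf

sequence_length = 6
token_ids = tf.ragged.constant([[1, 2, 3], [4, 5]])

# Ones over the real tokens; `to_tensor` fills the padded positions
# with the default value (False for bool, 0 for int).
mask = tf.ones_like(token_ids, dtype=tf.bool)
mask = mask.to_tensor(shape=(None, sequence_length))
token_ids = token_ids.to_tensor(shape=(None, sequence_length))

print(token_ids.numpy())
# [[1 2 3 0 0 0]
#  [4 5 0 0 0 0]]
print(mask.numpy())
# [[ True  True  True False False False]
#  [ True  True False False False False]]
```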