# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 Causal LM (Language Model)."""

import copy

import tensorflow as tf
from tensorflow import keras

import keras_nlp
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
    GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.models.task import Task
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLM(Task):
    """An end-to-end GPT2 model for causal language modeling.

    A causal language model (LM) predicts the next token based on previous
    tokens, which is how GPT2 is pretrained. You can finetune `GPT2CausalLM`
    to generate text similar to a custom dataset. `GPT2CausalLM` also has a
    `generate()` method, which generates text based on a given prompt.

    This model can optionally be configured with a `preprocessor` layer, in
    which case it will automatically apply preprocessing to raw inputs during
    `fit()`, `predict()`, and `evaluate()`. This is done by default when
    creating the model with `from_preset()`.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind. The underlying model is provided by
    a third party and subject to a separate license, available
    [here](https://github.com/openai/gpt-2).

    Args:
        backbone: A `keras_nlp.models.GPT2Backbone` instance.
        preprocessor: A `keras_nlp.models.GPT2CausalLMPreprocessor` or `None`.
            If `None`, this model will not apply preprocessing, and inputs
            should be preprocessed before calling the model.

    Examples:

    Use the `generate()` method to do text generation.
    ```python
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
    gpt2_lm.generate("I want to say", max_length=30)

    # Generate with batched prompts.
    gpt2_lm.generate(["This is a", "Where are you"], max_length=30)
    ```

    Use a custom sampler for text generation.
    ```python
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

    # Use a string identifier to set the sampler.
    gpt2_lm.generate("I want to say", max_length=30, sampler="top_p")

    # Construct a sampler instance.
    sampler = keras_nlp.samplers.BeamSampler(num_beams=2)
    gpt2_lm.generate("I want to say", max_length=30, sampler=sampler)
    ```

    Map raw strings to language model logit predictions.
    ```python
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
    gpt2_lm.predict(["You know this is just a test string"])
    ```

    Load a pretrained GPT2 and fit on a string dataset.
    ```python
    features = [
        "I don't listen to music while coding.",
        "But I watch youtube while coding!",
    ]
    ds = tf.data.Dataset.from_tensor_slices(features)

    # Create a `GPT2CausalLM` and fit your data.
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
        "gpt2_base_en",
    )
    gpt2_lm.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    gpt2_lm.fit(ds, batch_size=2)
    ```

    Load a pretrained `GPT2CausalLM` with a custom preprocessor, and predict
    on string inputs.
    ```python
    # Use a shorter sequence length.
    preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
        "gpt2_base_en",
        sequence_length=128,
    )

    # Create a `GPT2CausalLM`, using pretrained GPT2 and the custom
    # preprocessor.
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
        "gpt2_base_en",
        preprocessor=preprocessor,
    )
    gpt2_lm.predict(["You know this is still a test string"])
    ```

    Fit a randomly initialized GPT2 on preprocessed data. This is useful when
    you want to do data preprocessing inside a `tf.data` pipeline.
    ```python
    # Define preprocessed input.
    features = {
        "token_ids": tf.constant(
            [[1, 2, 3, 4, 0, 0]] * 2, shape=(2, 6)
        ),
        "padding_mask": tf.constant(
            [[1, 1, 1, 1, 0, 0]] * 2, shape=(2, 6)
        ),
    }
    labels = tf.constant(
        [[2, 3, 4, 0, 0, 0]] * 2, shape=(2, 6)
    )
    sample_weight = tf.constant(
        [[1, 1, 1, 0, 0, 0]] * 2, shape=(2, 6)
    )

    # Randomly initialize a GPT2 backbone.
    backbone = keras_nlp.models.GPT2Backbone(
        vocabulary_size=50257,
        num_layers=2,
        num_heads=2,
        hidden_dim=128,
        intermediate_dim=256,
        max_sequence_length=128,
    )
    # Create a `GPT2CausalLM` without a preprocessor and fit the data.
    gpt2_lm = keras_nlp.models.GPT2CausalLM(backbone, preprocessor=None)
    gpt2_lm.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    gpt2_lm.fit(
        x=features,
        y=labels,
        sample_weight=sample_weight,
        batch_size=2,
    )
    ```
    """

    def __init__(self, backbone, preprocessor=None, **kwargs):
        inputs = backbone.input
        x = backbone(inputs)
        # Use token embedding weights to project from the token representation
        # to vocabulary logits.
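        # Tying the output projection to the input embedding matrix is a
        # standard GPT2 design choice; it avoids a separate
        # `hidden_dim x vocab_size` output weight matrix.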
        outputs = tf.matmul(
            x,
            backbone.token_embedding.embeddings,
            transpose_b=True,
        )

        # Instantiate using Functional API Model constructor.
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            include_preprocessing=preprocessor is not None,
            **kwargs,
        )

        self._backbone = backbone
        self._preprocessor = preprocessor

    @classproperty
    def presets(cls):
        """Dictionary of preset names and configurations."""
        return copy.deepcopy(backbone_presets)

    @classproperty
    def backbone_cls(cls):
        return GPT2Backbone

    @classproperty
    def preprocessor_cls(cls):
        return GPT2CausalLMPreprocessor

    def _get_token_probability(self, prompt, mask):
        """Compute vocabulary logits for a batch of prompt token ids."""
        model_inputs = {
            "token_ids": prompt,
            "padding_mask": mask,
        }
        return self(model_inputs)

    def generate(
        self,
        prompt,
        max_length,
        sampler="top_k",
    ):
        """Generate text.

        This method generates text based on the given `prompt`. Generation
        will continue until `max_length` is met, and all tokens generated
        after `end_token` will be truncated. The sampling approach used can
        be controlled via the `sampler` argument.

        Args:
            prompt: a string, string Tensor or string RaggedTensor. The
                prompt text for generation.
            max_length: int. The max length of the generated sequence.
            sampler: a string or `keras_nlp.samplers.Sampler` instance. The
                sampler to be used for text generation.
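
        Example (mirroring the class-level usage above):
        ```python
        gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
        gpt2_lm.generate("I want to say", max_length=30, sampler="top_p")
        ```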
        """
        end_token_id = self.preprocessor.tokenizer.end_token_id

        sampler = keras_nlp.samplers.get(sampler)
        if hasattr(self, "jit_compile"):
            # `jit_compile` is a public property as of TF 2.12; the `hasattr`
            # check is for backward compatibility.
            sampler.jit_compile = self.jit_compile
        sampler.run_eagerly = self.run_eagerly
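        # Tokenize the prompt and sample token ids autoregressively, scoring
        # each step with the model's logits via `_get_token_probability`;
        # sampling stops once `max_length` is reached.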
        generated = sampler(
            self.preprocessor.tokenizer(prompt),
            self._get_token_probability,
            max_length=max_length,
            end_token_id=end_token_id,
        )
        return self.preprocessor.tokenizer.detokenize(generated)