GPT2 Text Generation APIs #592
Changes from 32 commits
@@ -0,0 +1,239 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 Causal LM (Language Model)."""

import copy

import tensorflow as tf
from tensorflow import keras

import keras_nlp
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
    GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.models.task import Task
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLM(Task):
    """An end-to-end GPT2 model for causal language modeling.

    A causal language model (LM) predicts the next token based on previous
    tokens, which is how GPT2 is pretrained. You can finetune `GPT2CausalLM`
    to generate text similar to a custom dataset. `GPT2CausalLM` also has a
    `generate()` method, which generates text based on a given prompt.

    This model can optionally be configured with a `preprocessor` layer, in
    which case it will automatically apply preprocessing to raw inputs during
    `fit()`, `predict()`, and `evaluate()`. This is done by default when
    creating the model with `from_preset()`.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind. The underlying model is provided by
    a third party and subject to a separate license, available
    [here](https://github.com/openai/gpt-2).

    Args:
        backbone: A `keras_nlp.models.GPT2Backbone` instance.
        preprocessor: A `keras_nlp.models.GPT2CausalLMPreprocessor` or `None`.
            If `None`, this model will not apply preprocessing, and inputs
            should be preprocessed before calling the model.

    Examples:

    Use the `generate()` method to do text generation.
    ```python
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
    gpt2_lm.generate("I want to say", max_length=30)

    # Generate with batched prompts.
    gpt2_lm.generate(["This is a", "Where are you"], max_length=30)
    ```

    Use a custom sampler for text generation.
    ```python
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

    # Use a string identifier to set the sampler.
    gpt2_lm.generate("I want to say", max_length=30, sampler="top_p")

    # Construct a sampler instance.
    sampler = keras_nlp.samplers.BeamSampler(num_beams=2)
    gpt2_lm.generate("I want to say", max_length=30, sampler=sampler)
    ```

    Map raw strings to language model logit predictions.
    ```python
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
    gpt2_lm.predict(["You know this is just a test string"])
    ```

    Load a pretrained GPT2 and fit on a string dataset.
    ```python
    features = [
        "I don't listen to music while coding.",
        "But I watch youtube while coding!",
    ]
    ds = tf.data.Dataset.from_tensor_slices(features).batch(2)

    # Create a `GPT2CausalLM` and fit your data.
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
        "gpt2_base_en",
    )
    gpt2_lm.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    gpt2_lm.fit(ds)
    ```

    Load a pretrained `GPT2CausalLM` with a custom preprocessor, and predict
    on string inputs.
    ```python
    # Use a shorter sequence length.
    preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
        "gpt2_base_en",
        sequence_length=128,
    )

    # Create a `GPT2CausalLM`, using pretrained GPT2 and custom preprocessor.
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
        "gpt2_base_en",
        preprocessor=preprocessor,
    )
    gpt2_lm.predict(["You know this is still a test string"])
    ```

    Fit your preprocessed data with a randomly initialized GPT2. This is
    useful when you want to do data preprocessing inside a `tf.data`
    pipeline.
    ```python
    # Define preprocessed input.
    features = {
        "token_ids": tf.constant(
            [[1, 2, 3, 4, 0, 0]] * 2, shape=(2, 6)
        ),
        "padding_mask": tf.constant(
            [[1, 1, 1, 1, 0, 0]] * 2, shape=(2, 6)
        ),
    }
    labels = tf.constant(
        [[2, 3, 4, 0, 0, 0]] * 2, shape=(2, 6)
    )
    sample_weight = tf.constant(
        [[1, 1, 1, 0, 0, 0]] * 2, shape=(2, 6)
    )

    # Randomly initialize a GPT2 backbone.
    backbone = keras_nlp.models.GPT2Backbone(
        vocabulary_size=50257,
        num_layers=2,
        num_heads=2,
        hidden_dim=128,
        intermediate_dim=256,
        max_sequence_length=128,
    )
    # Create a `GPT2CausalLM` without preprocessor and fit the data.
    gpt2_lm = keras_nlp.models.GPT2CausalLM(backbone, preprocessor=None)
    gpt2_lm.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    gpt2_lm.fit(
        x=features,
        y=labels,
        sample_weight=sample_weight,
        batch_size=2,
    )
    ```
    """

    def __init__(self, backbone, preprocessor=None, **kwargs):
        inputs = backbone.input
        x = backbone(inputs)
        # Use token embedding weights to project from the token representation
        # to vocabulary logits.
        outputs = tf.matmul(
            x,
            backbone.token_embedding.embeddings,
            transpose_b=True,
        )

        # Instantiate using Functional API Model constructor.
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            include_preprocessing=preprocessor is not None,
            **kwargs,
        )

        self._backbone = backbone
        self._preprocessor = preprocessor

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)

    @classproperty
    def backbone_cls(cls):
        return GPT2Backbone

    @classproperty
    def preprocessor_cls(cls):
        return GPT2CausalLMPreprocessor

    def _get_token_probability(self, prompt, mask):
        model_inputs = {
            "token_ids": prompt,
            "padding_mask": mask,
        }
        return self(model_inputs)

    def generate(
        self,
        prompt,
        max_length,
        sampler="top_k",

Member: Just curious, why was top-k chosen as the default?

Contributor (Author): It's working well with my finetuning tasks. I feel we may want to later change this default to contrastive search, which is not yet available.

    ):
| """Generate text. | ||
|
|
||
| This method generates text based on given `prompt`. Generation will | ||
| continue until `max_length` is met, and all tokens generated after | ||
| `end_token` will be truncated. The sampling approach used can be | ||
| controlled via the sampler argument. | ||
|
|
||
| Args: | ||
| prompt: a string, string Tensor or string RaggedTensor. The prompt | ||
| text for generation. | ||
| max_length: int. The max length of generated sequence. | ||
| sampler: a string or `keras_nlp.samplers.Sampler` instance. The | ||
| sampler to be used for text generation. | ||
| """ | ||
| end_token_id = self.preprocessor.tokenizer.end_token_id | ||
|
|
||
| sampler = keras_nlp.samplers.get(sampler) | ||
| if hasattr(self, "jit_compile"): | ||
mattdangerw marked this conversation as resolved.
Show resolved
Hide resolved
chenmoneygithub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # `jit_compile` is a public property as of tf 2.12. hasattr is for | ||
| # backward compat. | ||
| sampler.jit_compile = self.jit_compile | ||
| sampler.run_eagerly = self.run_eagerly | ||
| generated = sampler( | ||
| self.preprocessor.tokenizer(prompt), | ||
| self._get_token_probability, | ||
| max_length=max_length, | ||
| end_token_id=end_token_id, | ||
| ) | ||
| return self.preprocessor.tokenizer.detokenize(generated) | ||
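
Putting the pieces above together, here is a minimal end-to-end sketch of the API this file adds: finetune on a toy string dataset, then sample continuations with both the default `"top_k"` identifier and an explicit sampler instance. The dataset strings, `epochs=1`, and `TopPSampler(p=0.9)` are illustrative assumptions, not values from the diff, and the sketch assumes the `gpt2_base_en` preset weights can be downloaded.

```python
import tensorflow as tf
from tensorflow import keras

import keras_nlp

# Load the pretrained causal LM with its default preprocessor.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# A toy string dataset; the preprocessor tokenizes and packs it, and
# derives `y`/`sample_weight` by the shift described in the next file.
features = ["Write better docstrings.", "Ship the sampler API."]
ds = tf.data.Dataset.from_tensor_slices(features).batch(2)

gpt2_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
gpt2_lm.fit(ds, epochs=1)

# Default sampling: "top_k" is resolved via `keras_nlp.samplers.get`.
print(gpt2_lm.generate("Write better", max_length=20))

# Or pass a sampler instance for finer control (p=0.9 is illustrative).
print(
    gpt2_lm.generate(
        "Write better",
        max_length=20,
        sampler=keras_nlp.samplers.TopPSampler(p=0.9),
    )
)
```
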
@@ -0,0 +1,96 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT2 Causal LM preprocessor layer."""

from absl import logging
from tensorflow import keras

from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLMPreprocessor(GPT2Preprocessor):
    """GPT2 Causal LM preprocessor.

    This preprocessor is mainly used as the preprocessor for `GPT2CausalLM`.
    It subclasses `keras_nlp.models.GPT2Preprocessor` and keeps most of its
    functionality. The only change is that `GPT2CausalLMPreprocessor` sets
    the `y` (label) and `sample_weight` fields by shifting the input
    sequence one step to the left, and drops the last token as it does not
    have a successor. For example, if the tokenized input is
    `[1, 2, 3, 0, 0]` with `padding_mask = [1, 1, 1, 0, 0]`, then after
    preprocessing we will have `x = [1, 2, 3, 0]` and `y = [2, 3, 0, 0]`,
    with `padding_mask = [1, 1, 1, 0]` and `sample_weight = [1, 1, 0, 0]`.

    Args:
        tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
        sequence_length: The length of the packed inputs.

    Examples:
    ```python
    # Load the preprocessor from a preset.
    preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
        "gpt2_base_en"
    )

    # Tokenize and pack a single sentence.
    sentence = tf.constant("League of legends")
    preprocessor(sentence)
    # Same output.
    preprocessor("League of legends")

    # Tokenize a batch of sentences.
    sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
    preprocessor(sentences)
    # Same output.
    preprocessor(["Taco tuesday", "Fish taco please!"])

    # Map a dataset to preprocess labeled sentences.
    features = tf.constant(
        [
            "Avatar 2 is amazing!",
            "Well, I am not sure.",
        ]
    )
    labels = tf.constant([1, 0])
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map a dataset to preprocess unlabeled sentences.
    ds = tf.data.Dataset.from_tensor_slices(features)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
    ```
    """

    def call(self, x, y=None, sample_weight=None):
        if y is not None or sample_weight is not None:
            logging.warning(
                "`GPT2CausalLMPreprocessor` generates `y` and `sample_weight` "
                "based on your input data, but your data already contains `y` "
                "or `sample_weight`. Your `y` and `sample_weight` will be "
                "ignored."
            )

        x = super().call(x)
        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
        # The last token does not have a next token, so we truncate it out.
        x = {
            "token_ids": token_ids[..., :-1],
            "padding_mask": padding_mask[..., :-1],
        }
        # Target `y` will be the next token.
        y = token_ids[..., 1:]
        sample_weight = padding_mask[..., 1:]
        return pack_x_y_sample_weight(x, y, sample_weight)
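
Since the label-shifting in `call()` is the core of this preprocessor, here is a minimal standalone sketch of the same slicing, using plain `tf` ops on the docstring's toy example; it requires no preset download, and the tensor values are taken directly from the docstring above.

```python
import tensorflow as tf

# Packed batch from the docstring example: one sequence, padded to length 5.
token_ids = tf.constant([[1, 2, 3, 0, 0]])
padding_mask = tf.constant([[1, 1, 1, 0, 0]])

# Drop the last position from the inputs: it has no successor to predict.
x = {
    "token_ids": token_ids[..., :-1],        # [[1, 2, 3, 0]]
    "padding_mask": padding_mask[..., :-1],  # [[1, 1, 1, 0]]
}
# Labels are the inputs shifted one step to the left; padded positions get
# zero sample weight so they do not contribute to the loss.
y = token_ids[..., 1:]                 # [[2, 3, 0, 0]]
sample_weight = padding_mask[..., 1:]  # [[1, 1, 0, 0]]
```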