GPT2 Text Generation APIs #592
keras_nlp/models/gpt2/gpt2_causal_lm.py (new file)
@@ -0,0 +1,275 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 Causal LM (Language Model)."""

import copy

import tensorflow as tf
from tensorflow import keras

import keras_nlp
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
    GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.models.task import Task
from keras_nlp.utils.python_utils import classproperty
@keras.utils.register_keras_serializable(package="keras_nlp")
class ReverseEmbedding(keras.layers.Layer):
    """A layer multiplying model outputs by the token embedding.

    This layer maps model outputs to logits over all vocab tokens. It is
    used in `GPT2CausalLM` to compute the next token's probability.

    Args:
        embedding_layer: a `tf.keras.layers.Embedding` instance, the token
            embedding layer.
    """

    def __init__(self, embedding_layer, name="embedding_mapping", **kwargs):
        super().__init__(name=name, **kwargs)
        self.embedding_layer = embedding_layer

    def call(self, inputs):
        # Project hidden states onto the transposed embedding matrix, giving
        # one logit per vocabulary token.
        return tf.matmul(
            inputs,
            self.embedding_layer.embeddings,
            transpose_b=True,
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embedding_layer": keras.layers.serialize(self.embedding_layer),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        config["embedding_layer"] = keras.layers.deserialize(
            config["embedding_layer"],
        )
        return cls(**config)
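As an aside, `ReverseEmbedding` is the standard weight-tying trick: the matrix that embeds token ids is reused, transposed, as the output projection. A minimal standalone sketch of the idea (the dimensions here are illustrative, not tied to any preset):

```python
import tensorflow as tf
from tensorflow import keras

vocab_size, hidden_dim = 50257, 768
embedding = keras.layers.Embedding(vocab_size, hidden_dim)
embedding.build((None,))  # creates the (vocab_size, hidden_dim) weight

hidden_states = tf.random.normal([2, 10, hidden_dim])  # (batch, seq, hidden)
# Reuse the embedding matrix, transposed, as the output projection.
logits = tf.matmul(hidden_states, embedding.embeddings, transpose_b=True)
print(logits.shape)  # (2, 10, 50257): one score per vocabulary token
```

Tying the input and output embeddings this way avoids a second `(hidden_dim, vocab_size)` weight matrix, and matches what GPT2 itself does.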
@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLM(Task):
    """GPT2 Causal LM task model.

    A causal LM predicts the next token based on the previous tokens, which
    is how GPT2 is pretrained. Users can finetune `GPT2CausalLM` to generate
    text similar to a custom dataset. `GPT2CausalLM` also has a public
    method `generate()`, which generates text based on a given prompt.

    This model can optionally be configured with a `preprocessor` layer, in
    which case it will automatically apply preprocessing to raw inputs during
    `fit()`, `predict()`, and `evaluate()`. This is done by default when
    creating the model with `from_preset()`.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind. The underlying model is provided by
    a third party and subject to a separate license, available
    [here](https://github.com/openai/gpt-2).

    Args:
        backbone: A `keras_nlp.models.GPT2Backbone` instance.
        preprocessor: A `keras_nlp.models.GPT2CausalLMPreprocessor` or
            `None`. If `None`, this model will not apply preprocessing, and
            inputs should be preprocessed before calling the model.

    Examples:

    Use the `generate()` method to do text generation.
    ```python
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
    gpt2_lm.generate("I want to say", max_length=30)

    # Generate with batched prompts.
    gpt2_lm.generate(["This is a", "Where are you"], max_length=30)
    ```

    Use a custom sampler for text generation.
    ```python
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

    # Use a string identifier to set the sampler.
    gpt2_lm.generate("I want to say", max_length=30, sampler="top_p")

    # Construct a sampler instance.
    sampler = keras_nlp.samplers.BeamSampler(num_beams=2)
    gpt2_lm.generate("I want to say", max_length=30, sampler=sampler)
    ```

    Load a pretrained `GPT2CausalLM` and get outputs on raw string inputs.
    ```python
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
    gpt2_lm.predict(["You know this is just a test string"])
    ```

    Load a pretrained GPT2 and fit on a string dataset.
    ```python
    features = [
        "I don't listen to music while coding.",
        "But I watch youtube while coding!",
    ]
    ds = tf.data.Dataset.from_tensor_slices(features)

    # Create a `GPT2CausalLM` and fit your data.
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
        "gpt2_base_en",
    )
    gpt2_lm.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    gpt2_lm.fit(ds.batch(2))
    ```

    Load a pretrained `GPT2CausalLM` with a custom preprocessor, and predict
    on string inputs.
    ```python
    # Use a shorter sequence length.
    preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
        "gpt2_base_en",
        sequence_length=128,
    )

    # Create a `GPT2CausalLM`, using pretrained GPT2 and the custom
    # preprocessor.
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
        "gpt2_base_en",
        preprocessor=preprocessor,
    )
    gpt2_lm.predict(["You know this is still a test string"])
    ```

    Fit preprocessed data with a randomly initialized GPT2. This is useful
    when you want to do data preprocessing inside a `tf.data` pipeline.
    ```python
    # Define preprocessed input.
    features = {
        "token_ids": tf.constant(
            [[1, 2, 3, 4, 0, 0]] * 2, shape=(2, 6)
        ),
        "padding_mask": tf.constant(
            [[1, 1, 1, 1, 0, 0]] * 2, shape=(2, 6)
        ),
    }
    labels = tf.constant(
        [[2, 3, 4, 0, 0, 0]] * 2, shape=(2, 6)
    )
    sample_weight = tf.constant(
        [[1, 1, 1, 0, 0, 0]] * 2, shape=(2, 6)
    )

    # Randomly initialize a GPT2 backbone.
    backbone = keras_nlp.models.GPT2Backbone(
        vocabulary_size=50257,
        num_layers=2,
        num_heads=2,
        hidden_dim=128,
        intermediate_dim=256,
        max_sequence_length=128,
    )
    # Create a `GPT2CausalLM` without a preprocessor and fit the data.
    gpt2_lm = keras_nlp.models.GPT2CausalLM(backbone, preprocessor=None)
    gpt2_lm.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    gpt2_lm.fit(
        x=features,
        y=labels,
        sample_weight=sample_weight,
        batch_size=2,
    )
    ```
    """

    def __init__(self, backbone, preprocessor=None, **kwargs):
        inputs = backbone.input
        x = backbone(inputs)
        embedding_layer = backbone.get_layer("token_embedding")
        embedding_map_layer = ReverseEmbedding(embedding_layer)
        outputs = embedding_map_layer(x)

        # Instantiate using the Functional API Model constructor.
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            include_preprocessing=preprocessor is not None,
            **kwargs,
        )

        self._backbone = backbone
        self._preprocessor = preprocessor

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)

    @classproperty
    def backbone_cls(cls):
        return GPT2Backbone

    @classproperty
    def preprocessor_cls(cls):
        return GPT2CausalLMPreprocessor

    def _get_token_probability(self, prompt, mask):
        model_inputs = {
            "token_ids": prompt,
            "padding_mask": mask,
        }
        return self(model_inputs)

    def generate(
        self,
        prompt,
        max_length,
        sampler="top_k",
Member: Just curious, why was top-k chosen as the default?

Contributor (Author): It's working well with my finetuning tasks. I feel we will want to change this default to contrastive search later, but that is not yet available.
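For context, top-k sampling restricts each decoding step to the k highest-scoring tokens and samples among them. A minimal illustrative sketch of one step (not the KerasNLP sampler implementation; the batch size, vocab size, and `k` below are made up):

```python
import tensorflow as tf

def top_k_step(next_token_logits, k=5):
    # Keep only the k highest-scoring tokens for this step.
    top_logits, top_indices = tf.math.top_k(next_token_logits, k=k)
    # Sample one candidate from the renormalized top-k distribution.
    sampled = tf.random.categorical(top_logits, num_samples=1)
    # Map the sampled position back to a vocabulary id.
    return tf.gather(top_indices, sampled, batch_dims=1)

logits = tf.random.normal([2, 50257])  # (batch, vocab) scores for next token
next_token_ids = top_k_step(logits, k=5)  # shape (2, 1)
```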
    ):
        """Generate text.

        This method generates text based on a given `prompt`. Generation
        will continue until `max_length` is met, and all tokens generated
        after the end token of GPT2 ("<|endoftext|>") will be truncated.

        Args:
            prompt: a string, string Tensor or string RaggedTensor. The
                prompt text for generation.
            max_length: int. The max length of the generated sequence.
            sampler: a string or `keras_nlp.samplers.Sampler` instance. The
                sampler to be used for text generation.
        """
        end_token_id = self.preprocessor.tokenizer.end_token_id

        if isinstance(sampler, str):
            sampler = keras_nlp.samplers.get(sampler)
        if hasattr(self, "jit_compile"):
            sampler.jit_compile = self.jit_compile
        sampler.run_eagerly = self.run_eagerly
        generated = sampler(
            self.preprocessor.tokenizer(prompt),
            self._get_token_probability,
            max_length=max_length,
            end_token_id=end_token_id,
        )
        return self.preprocessor.tokenizer.detokenize(generated)
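Note that `_get_token_probability` simply calls the model on a dict of preprocessed tensors, so the sampler receives one row of logits per position. A quick way to see the output shape, reusing the toy backbone configuration from the docstring example (a sketch, not part of the diff):

```python
import tensorflow as tf
import keras_nlp

# Small, randomly initialized backbone with the docstring's toy dimensions.
backbone = keras_nlp.models.GPT2Backbone(
    vocabulary_size=50257,
    num_layers=2,
    num_heads=2,
    hidden_dim=128,
    intermediate_dim=256,
    max_sequence_length=128,
)
gpt2_lm = keras_nlp.models.GPT2CausalLM(backbone, preprocessor=None)

batch = {
    "token_ids": tf.constant([[1, 2, 3, 4, 0, 0]] * 2),
    "padding_mask": tf.constant([[1, 1, 1, 1, 0, 0]] * 2),
}
logits = gpt2_lm(batch)  # per-position scores over the full vocabulary
print(logits.shape)  # (2, 6, 50257)
```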
keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py (new file)
@@ -0,0 +1,86 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT2 Causal LM preprocessor layer."""

from tensorflow import keras

from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLMPreprocessor(GPT2Preprocessor):
    """GPT2 Causal LM preprocessor.

    This class subclasses `keras_nlp.models.GPT2Preprocessor` and keeps most
    of its functionality. The only change is that `GPT2CausalLMPreprocessor`
    sets the `y` (label) and `sample_weight` fields by shifting the input
    sequence one step to the left, dropping the last token as it does not
    have a successor. For example, if the tokenized input is
    `[1, 2, 3, 0, 0]` with `padding_mask=[1, 1, 1, 0, 0]`, then after
    preprocessing we will have `x=[1, 2, 3, 0]` and `y=[2, 3, 0, 0]`, with
    `padding_mask=[1, 1, 1, 0]` and `sample_weight=[1, 1, 0, 0]`.

    Args:
        tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
        sequence_length: The length of the packed inputs.

    Examples:
    ```python
    # Load the preprocessor from a preset.
    preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
        "gpt2_base_en"
    )

    # Tokenize and pack a single sentence.
    sentence = tf.constant("league of legends")
    preprocessor(sentence)
    # Same output.
    preprocessor("league of legends")

    # Tokenize a batch of sentences.
    sentences = tf.constant(["taco tuesday", "fish taco please!"])
    preprocessor(sentences)
    # Same output.
    preprocessor(["taco tuesday", "fish taco please!"])

    # Map a dataset to preprocess labeled sentences.
    features = tf.constant(
        [
            "Avatar 2 is amazing!",
            "Well, I am not sure.",
        ]
    )
    labels = tf.constant([1, 0])
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map a dataset to preprocess unlabeled sentences.
    ds = tf.data.Dataset.from_tensor_slices(features)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
    ```
    """

    def call(self, x, y=None, sample_weight=None):
        x = super().call(x)
        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
        # The last token does not have a next token, so we truncate it out.
        x = {
            "token_ids": token_ids[..., :-1],
            "padding_mask": padding_mask[..., :-1],
        }
        # Target `y` will be the next token.
        y = token_ids[..., 1:]
        sample_weight = padding_mask[..., 1:]
        return pack_x_y_sample_weight(x, y, sample_weight)
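To make the shifting concrete, here is the same slicing that `call()` performs, applied to the toy batch from the docstring (a standalone sketch):

```python
import tensorflow as tf

token_ids = tf.constant([[1, 2, 3, 0, 0]])
padding_mask = tf.constant([[1, 1, 1, 0, 0]])

x = {
    "token_ids": token_ids[..., :-1],        # [[1, 2, 3, 0]]
    "padding_mask": padding_mask[..., :-1],  # [[1, 1, 1, 0]]
}
y = token_ids[..., 1:]                       # [[2, 3, 0, 0]]
sample_weight = padding_mask[..., 1:]        # [[1, 1, 0, 0]]
```

Positions whose target would be a padding token get a `sample_weight` of 0, so the loss never trains the model to predict padding.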