-
Notifications
You must be signed in to change notification settings - Fork 310
Add keras_nlp.samplers #563
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
26fd509
7e4c651
28bcfe1
9757f4d
f7508cb
b658b61
76c430c
bb430dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| # Copyright 2022 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from tensorflow import keras | ||
|
|
||
| from keras_nlp.samplers.greedy import Greedy | ||
|
|
||
|
|
||
| def serialize(sampler): | ||
| return keras.utils.serialize_keras_object(sampler) | ||
|
|
||
|
|
||
| def deserialize(config, custom_objects=None): | ||
| """Return a `Sampler` object from its config.""" | ||
| all_classes = { | ||
| "greedy": Greedy, | ||
| } | ||
| return keras.utils.deserialize_keras_object( | ||
| config, | ||
| module_objects=all_classes, | ||
| custom_objects=custom_objects, | ||
| printable_module_name="samplers", | ||
| ) | ||
|
|
||
|
|
||
| def get(identifier): | ||
| """Retrieve a KerasNLP sampler by the identifier. | ||
|
|
||
| The `identifier` may be the string name of a sampler class or class. | ||
|
|
||
| >>> identifier = 'greedy' | ||
| >>> sampler = keras_nlp.samplers.get(identifier) | ||
|
|
||
| You can also specify `config` of the sampler to this function by passing | ||
| dict containing `class_name` and `config` as an identifier. Also note that | ||
| the `class_name` must map to a `Sampler` class. | ||
|
|
||
| >>> cfg = {'class_name': 'keras_nlp>Greedy', 'config': {}} | ||
| >>> sampler = keras_nlp.samplers.get(cfg) | ||
|
|
||
| In the case that the `identifier` is a class, this method will return a new | ||
| instance of the class by its constructor. | ||
|
|
||
| Args: | ||
| identifier: String or dict that contains the sampler name or | ||
| configurations. | ||
|
|
||
| Returns: | ||
| Sampler instance base on the input identifier. | ||
|
|
||
| Raises: | ||
| ValueError: If the input identifier is not a supported type or in a bad | ||
| format. | ||
| """ | ||
|
|
||
| if identifier is None: | ||
| return None | ||
| if isinstance(identifier, dict): | ||
| return deserialize(identifier) | ||
| elif isinstance(identifier, str): | ||
| if not identifier.islower(): | ||
| raise KeyError( | ||
| "`keras_nlp.samplers.get()` must take a lowercase string " | ||
| f"identifier, but received: {identifier}." | ||
| ) | ||
| return deserialize(identifier) | ||
| elif callable(identifier): | ||
| return identifier | ||
| else: | ||
| raise ValueError( | ||
| "Could not interpret sampler identifier: " + str(identifier) | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,149 @@ | ||
| # Copyright 2022 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Greedy Sampler.""" | ||
|
|
||
| import tensorflow as tf | ||
| from tensorflow import keras | ||
|
|
||
| from keras_nlp.samplers.sampler import Sampler | ||
| from keras_nlp.samplers.sampler import base_sampler_args_docstring | ||
| from keras_nlp.samplers.sampler import call_args_docstring | ||
| from keras_nlp.samplers.sampler import sample_args_docstring | ||
| from keras_nlp.utils.python_utils import format_docstring | ||
|
|
||
|
|
||
| @format_docstring( | ||
| base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring | ||
| ) | ||
| @keras.utils.register_keras_serializable(package="keras_nlp") | ||
| class Greedy(Sampler): | ||
chenmoneygithub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Greedy sampler class. | ||
|
|
||
| This sampler is implemented on greedy search, i.e., always picking up the | ||
| token of the largest probability as the next token. | ||
|
|
||
| Args: | ||
| {{base_sampler_args}} | ||
|
|
||
| Call Args: | ||
| {{call_args}} | ||
|
|
||
| Examples: | ||
| ```python | ||
| BATCH_SIZE = 8 | ||
| VOCAB_SIZE = 10 | ||
| FEATURE_SIZE = 16 | ||
| START_ID = 1 | ||
|
|
||
| # Create a dummy model to predict the next token. | ||
chenmoneygithub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| model = keras.Sequential( | ||
| [ | ||
| keras.Input(shape=[None]), | ||
| keras.layers.Embedding( | ||
| input_dim=VOCAB_SIZE, | ||
| output_dim=FEATURE_SIZE, | ||
| ), | ||
| keras.layers.Dense(VOCAB_SIZE, activation="softmax"), | ||
| ] | ||
| ) | ||
|
|
||
| # Define a function that outputs the next token's probability for each token | ||
| # in the input sequence. | ||
| def token_probability_fn(inputs, mask): | ||
| return model(inputs) | ||
|
|
||
| prompt = tf.fill((BATCH_SIZE, 1), START_ID) | ||
|
|
||
| sampler = keras_nlp.samplers.Greedy() | ||
| # Print the generated sequence (token ids). | ||
| print(sampler(prompt, token_probability_fn, 10)) | ||
| ``` | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| jit_compile=True, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. give that we talked about moving compilation to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we cannot unfortunately, not all of the sampler's |
||
| ): | ||
| super().__init__(jit_compile) | ||
|
|
||
| @format_docstring(sample_args=sample_args_docstring) | ||
| def sample( | ||
| self, prompt, token_probability_fn, mask, num_steps, from_logits=True | ||
| ): | ||
| """Sampling logic implementation. | ||
|
|
||
| Args: | ||
| {{sample_args}} | ||
| """ | ||
| batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] | ||
| max_length = tf.cast(max_length, num_steps.dtype) | ||
| # The index of the last non-padding token in prompt. Since all sequences | ||
| # are aligned to the right side, the index is the same for all. | ||
| current_index = max_length - num_steps | ||
|
|
||
| def one_step(current_index, prompt, mask): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we factor out
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's actually a good idea, the issue is the customization of child We can expose an abstract method |
||
| probs = token_probability_fn(prompt, mask) | ||
| next_token_prob = tf.gather( | ||
| probs, | ||
| tf.repeat(current_index - 1, batch_size), | ||
| axis=1, | ||
| batch_dims=1, | ||
| ) | ||
| next_token = tf.cast( | ||
| tf.argmax(next_token_prob, axis=-1), dtype=prompt.dtype | ||
| ) | ||
| next_token = tf.where( | ||
| mask[:, current_index], prompt[:, current_index], next_token | ||
| ) | ||
| mask = tf.tensor_scatter_nd_update( | ||
| tensor=mask, | ||
| indices=tf.stack( | ||
| ( | ||
| tf.cast( | ||
| tf.range(batch_size), dtype=current_index.dtype | ||
| ), | ||
| tf.repeat(current_index, batch_size), | ||
| ), | ||
| axis=1, | ||
| ), | ||
| updates=tf.repeat(True, batch_size), | ||
| ) | ||
|
|
||
| # Append the next token to current sequence. | ||
| prompt = tf.tensor_scatter_nd_update( | ||
| tensor=prompt, | ||
| indices=tf.stack( | ||
| ( | ||
| tf.cast( | ||
| tf.range(batch_size), dtype=current_index.dtype | ||
| ), | ||
| tf.repeat(current_index, batch_size), | ||
| ), | ||
| axis=1, | ||
| ), | ||
| updates=next_token, | ||
| ) | ||
|
|
||
| current_index = tf.add(current_index, 1) | ||
| return (current_index, prompt, mask) | ||
|
|
||
| # Run a while loop till `max_length` of tokens has been generated. | ||
| current_index, prompt, mask = tf.while_loop( | ||
| cond=lambda current_index, prompt, mask: tf.less( | ||
| current_index, max_length | ||
| ), | ||
| body=one_step, | ||
| loop_vars=(current_index, prompt, mask), | ||
| ) | ||
| return prompt | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,128 @@ | ||
| # Copyright 2022 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Tests for Greedy sampler.""" | ||
|
|
||
| import tensorflow as tf | ||
| from absl.testing import parameterized | ||
| from tensorflow import keras | ||
|
|
||
| from keras_nlp.samplers.greedy import Greedy | ||
|
|
||
|
|
||
| class GreedyTest(tf.test.TestCase, parameterized.TestCase): | ||
| def setUp(self): | ||
| super().setUp() | ||
| self.vocab_size = 10 | ||
| self.feature_size = 16 | ||
|
|
||
| # Create a dummy model to predict the next token. | ||
| model = keras.Sequential( | ||
| [ | ||
| keras.Input(shape=[None]), | ||
| keras.layers.Embedding( | ||
| input_dim=self.vocab_size, | ||
| output_dim=self.feature_size, | ||
| ), | ||
| keras.layers.Dense(self.vocab_size), | ||
| keras.layers.Softmax(), | ||
| ] | ||
| ) | ||
|
|
||
| def token_probability_fn(inputs, mask): | ||
| return model(inputs) | ||
|
|
||
| self.token_probability_fn = token_probability_fn | ||
|
|
||
| self.sampler = Greedy() | ||
|
|
||
| def test_generate_with_1d_prompt(self): | ||
| inputs = tf.constant([1]) | ||
| outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) | ||
| self.assertEqual(outputs.shape, [5]) | ||
|
|
||
| def test_generate_with_2d_prompt(self): | ||
| inputs = tf.constant([[1], [1]]) | ||
| outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) | ||
| self.assertEqual(outputs.shape, [2, 5]) | ||
|
|
||
| def test_generate_with_list_prompt(self): | ||
| inputs = [[1], [1]] | ||
| outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) | ||
| self.assertEqual(outputs.shape, [2, 5]) | ||
|
|
||
| def test_generate_with_ragged_prompt(self): | ||
| max_length = 5 | ||
|
|
||
| def token_probability_fn(inputs, mask): | ||
| # Assert that user function is passed only dense tensors. | ||
| self.assertIsInstance(inputs, tf.Tensor) | ||
| prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) | ||
| return tf.repeat(tf.repeat(prob, 2, axis=0), max_length, axis=1) | ||
|
|
||
| inputs = tf.ragged.constant([[1], [2, 1, 2]]) | ||
| outputs = self.sampler(inputs, token_probability_fn, max_length) | ||
| self.assertEqual(outputs.shape, [2, 5]) | ||
|
|
||
| def test_assert_generation_is_correct(self): | ||
| batch_size = 10 | ||
| max_length = 3 | ||
|
|
||
| def token_probability_fn(inputs, mask): | ||
| prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) | ||
| return tf.repeat( | ||
| tf.repeat(prob, batch_size, axis=0), max_length, axis=1 | ||
| ) | ||
|
|
||
| inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32) | ||
| outputs = self.sampler( | ||
| inputs, token_probability_fn, max_length=max_length | ||
| ) | ||
| self.assertAllEqual( | ||
| outputs, 3 * tf.ones(shape=[batch_size, max_length]) | ||
| ) | ||
|
|
||
| def test_end_token_id(self): | ||
| max_length = 5 | ||
|
|
||
| def token_probability_fn(inputs, mask): | ||
| batch_size = inputs.shape[0] | ||
| prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) | ||
| return tf.repeat( | ||
| tf.repeat(prob, batch_size, axis=0), max_length, axis=1 | ||
| ) | ||
|
|
||
| sampler = Greedy() | ||
| inputs = tf.constant([[0, 1], [1, 2]]) | ||
| outputs = sampler( | ||
| inputs, | ||
| token_probability_fn, | ||
| max_length=max_length, | ||
| end_token_id=2, | ||
| ) | ||
| expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]]) | ||
| self.assertAllEqual(outputs, expected_outputs) | ||
|
|
||
| def test_compare_xla_noxla_results(self): | ||
| inputs = [[1], [1]] | ||
| xla_sampler = Greedy(jit_compile=True) | ||
| outputs_xla = xla_sampler( | ||
| inputs, self.token_probability_fn, max_length=5 | ||
| ) | ||
|
|
||
| xla_sampler = Greedy(jit_compile=False) | ||
| outputs_no_xla = xla_sampler( | ||
| inputs, self.token_probability_fn, max_length=5 | ||
| ) | ||
|
|
||
| self.assertAllEqual(outputs_xla, outputs_no_xla) |
Uh oh!
There was an error while loading. Please reload this page.