# --- keras_nlp/__init__.py (existing file; one import added) ---
# from keras_nlp import layers
# from keras_nlp import metrics
# from keras_nlp import models
# from keras_nlp import samplers  # <- line added by this change
# from keras_nlp import tokenizers
# from keras_nlp import utils

# --- keras_nlp/samplers/__init__.py (new file) ---
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorflow import keras

from keras_nlp.samplers.greedy import Greedy


def serialize(sampler):
    """Serialize a sampler with `keras.utils.serialize_keras_object`.

    Args:
        sampler: the sampler object to serialize.

    Returns:
        The serialized form produced by Keras, e.g.
        `{"class_name": "keras_nlp>Greedy", "config": {"jit_compile": True}}`
        for a default `Greedy` sampler.
    """
    return keras.utils.serialize_keras_object(sampler)


def deserialize(config, custom_objects=None):
    """Return a `Sampler` object from its config.

    Args:
        config: a sampler name (e.g. `"greedy"`) or a dict with `class_name`
            and `config` entries, as produced by `serialize()`.
        custom_objects: optional mapping forwarded to
            `keras.utils.deserialize_keras_object` for resolving names that
            are not built-in samplers.

    Returns:
        The object built by `keras.utils.deserialize_keras_object`.
    """
    # Built-in samplers that can be looked up by their lowercase name.
    all_classes = {
        "greedy": Greedy,
    }
    return keras.utils.deserialize_keras_object(
        config,
        module_objects=all_classes,
        custom_objects=custom_objects,
        printable_module_name="samplers",
    )


def get(identifier):
    """Retrieve a KerasNLP sampler by the identifier.

    The `identifier` may be the lowercase string name of a sampler class, a
    config dict, a sampler class, or an already constructed sampler.

    >>> identifier = 'greedy'
    >>> sampler = keras_nlp.samplers.get(identifier)

    You can also specify `config` of the sampler to this function by passing
    dict containing `class_name` and `config` as an identifier. Also note that
    the `class_name` must map to a `Sampler` class.

    >>> cfg = {'class_name': 'keras_nlp>Greedy', 'config': {}}
    >>> sampler = keras_nlp.samplers.get(cfg)

    In the case that the `identifier` is a class, this method will return a new
    instance of the class by its constructor (called without arguments). Any
    other callable, such as an existing sampler instance, is returned
    unchanged.

    Args:
        identifier: `None`, a string, a dict, a class or a callable that
            identifies the sampler.

    Returns:
        Sampler instance based on the input identifier, or `None` when
        `identifier` is `None`.

    Raises:
        KeyError: If `identifier` is a string that is not all lowercase.
        ValueError: If the input identifier is not a supported type or in a bad
            format.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        return deserialize(identifier)
    if isinstance(identifier, str):
        if not identifier.islower():
            raise KeyError(
                "`keras_nlp.samplers.get()` must take a lowercase string "
                f"identifier, but received: {identifier}."
            )
        return deserialize(identifier)
    if isinstance(identifier, type):
        # Classes are callable too, so this check must come before the generic
        # `callable` branch; otherwise the class object itself would be
        # returned instead of the documented new instance.
        return identifier()
    if callable(identifier):
        # Already a usable sampler (sampler instances define `__call__`).
        return identifier
    raise ValueError(f"Could not interpret sampler identifier: {identifier}")


# --- keras_nlp/samplers/greedy.py (new file) ---
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Greedy Sampler."""

import tensorflow as tf
from tensorflow import keras

from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.sampler import base_sampler_args_docstring
from keras_nlp.samplers.sampler import call_args_docstring
from keras_nlp.samplers.sampler import sample_args_docstring
from keras_nlp.utils.python_utils import format_docstring


@format_docstring(
    base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring
)
@keras.utils.register_keras_serializable(package="keras_nlp")
class Greedy(Sampler):
    """Greedy sampler class.

    This sampler implements greedy search, i.e., it always picks the token
    with the largest probability as the next token.

    Args:
        {{base_sampler_args}}

    Call Args:
        {{call_args}}

    Examples:
    ```python
    BATCH_SIZE = 8
    VOCAB_SIZE = 10
    FEATURE_SIZE = 16
    START_ID = 1

    # Create a dummy model to predict the next token.
    model = keras.Sequential(
        [
            keras.Input(shape=[None]),
            keras.layers.Embedding(
                input_dim=VOCAB_SIZE,
                output_dim=FEATURE_SIZE,
            ),
            keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
        ]
    )

    # Define a function that outputs the next token's probability for each token
    # in the input sequence.
    def token_probability_fn(inputs, mask):
        return model(inputs)

    prompt = tf.fill((BATCH_SIZE, 1), START_ID)

    sampler = keras_nlp.samplers.Greedy()
    # Print the generated sequence (token ids).
    print(sampler(prompt, token_probability_fn, 10))
    ```
    """

    def __init__(
        self,
        jit_compile=True,
    ):
        super().__init__(jit_compile)

    @format_docstring(sample_args=sample_args_docstring)
    def sample(
        self, prompt, token_probability_fn, mask, num_steps, from_logits=True
    ):
        """Sampling logic implementation.

        Fills `prompt` one position at a time, from index
        `max_length - num_steps` up to the last index, with the arg-max token
        predicted from the previous position. Positions whose `mask` entry is
        already True keep their original token.

        `from_logits` is accepted for interface compatibility but is not read
        here: the arg-max token is the same for logits and for probabilities.

        Args:
            {{sample_args}}

        Returns:
            A dense int Tensor with the same shape as `prompt`, holding the
            prompt tokens followed by the generated tokens.
        """
        batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
        max_length = tf.cast(max_length, num_steps.dtype)
        # Index of the first position to fill. `Sampler.__call__` passes
        # `num_steps = max_length - shortest_prompt_len`, so in that case this
        # is the length of the shortest prompt in the batch.
        current_index = max_length - num_steps

        def one_step(current_index, prompt, mask):
            probs = token_probability_fn(prompt, mask)
            # Prediction made at the previous position, one row per sequence.
            next_token_prob = tf.gather(
                probs,
                tf.repeat(current_index - 1, batch_size),
                axis=1,
                batch_dims=1,
            )
            next_token = tf.cast(
                tf.argmax(next_token_prob, axis=-1), dtype=prompt.dtype
            )
            # Sequences that already hold a token at `current_index` (longer
            # prompts) keep it instead of the generated one.
            next_token = tf.where(
                mask[:, current_index], prompt[:, current_index], next_token
            )

            # `(row, current_index)` pairs addressing one cell per sequence;
            # shared by the mask update and the token update below.
            indices = tf.stack(
                (
                    tf.cast(tf.range(batch_size), dtype=current_index.dtype),
                    tf.repeat(current_index, batch_size),
                ),
                axis=1,
            )
            # Mark the position as filled.
            mask = tf.tensor_scatter_nd_update(
                tensor=mask,
                indices=indices,
                updates=tf.repeat(True, batch_size),
            )
            # Write the next token into the current sequence.
            prompt = tf.tensor_scatter_nd_update(
                tensor=prompt,
                indices=indices,
                updates=next_token,
            )

            current_index = tf.add(current_index, 1)
            return (current_index, prompt, mask)

        # Run a while loop till `max_length` of tokens has been generated.
        current_index, prompt, mask = tf.while_loop(
            cond=lambda current_index, prompt, mask: tf.less(
                current_index, max_length
            ),
            body=one_step,
            loop_vars=(current_index, prompt, mask),
        )
        return prompt


# --- keras_nlp/samplers/greedy_test.py (new file) ---
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Greedy sampler."""

import tensorflow as tf
from absl.testing import parameterized
from tensorflow import keras

from keras_nlp.samplers.greedy import Greedy


class GreedyTest(tf.test.TestCase, parameterized.TestCase):
    """Exercises `Greedy` through its `__call__` entry point."""

    def setUp(self):
        super().setUp()
        self.vocab_size = 10
        self.feature_size = 16

        # Untrained toy network; only the output shape matters for most tests.
        embedding = keras.layers.Embedding(
            input_dim=self.vocab_size,
            output_dim=self.feature_size,
        )
        model = keras.Sequential(
            [
                keras.Input(shape=[None]),
                embedding,
                keras.layers.Dense(self.vocab_size),
                keras.layers.Softmax(),
            ]
        )

        def token_probability_fn(inputs, mask):
            return model(inputs)

        self.token_probability_fn = token_probability_fn
        self.sampler = Greedy()

    def test_generate_with_1d_prompt(self):
        """A 1D prompt yields a 1D output of length `max_length`."""
        prompt = tf.constant([1])
        generated = self.sampler(
            prompt, self.token_probability_fn, max_length=5
        )
        self.assertEqual(generated.shape, [5])

    def test_generate_with_2d_prompt(self):
        """A batched dense prompt yields a `[batch, max_length]` output."""
        prompt = tf.constant([[1], [1]])
        generated = self.sampler(
            prompt, self.token_probability_fn, max_length=5
        )
        self.assertEqual(generated.shape, [2, 5])

    def test_generate_with_list_prompt(self):
        """A nested Python list is accepted as a prompt."""
        prompt = [[1], [1]]
        generated = self.sampler(
            prompt, self.token_probability_fn, max_length=5
        )
        self.assertEqual(generated.shape, [2, 5])

    def test_generate_with_ragged_prompt(self):
        """A ragged prompt is densified before the user function sees it."""
        max_length = 5

        def token_probability_fn(inputs, mask):
            # Assert that user function is passed only dense tensors.
            self.assertIsInstance(inputs, tf.Tensor)
            one_hot = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
            return tf.tile(one_hot, [2, max_length, 1])

        prompt = tf.ragged.constant([[1], [2, 1, 2]])
        generated = self.sampler(prompt, token_probability_fn, max_length)
        self.assertEqual(generated.shape, [2, 5])

    def test_assert_generation_is_correct(self):
        """With all probability mass on token 3, only token 3 is generated."""
        batch_size = 10
        max_length = 3

        def token_probability_fn(inputs, mask):
            one_hot = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
            return tf.tile(one_hot, [batch_size, max_length, 1])

        prompt = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
        generated = self.sampler(
            prompt, token_probability_fn, max_length=max_length
        )
        expected = 3 * tf.ones(shape=[batch_size, max_length])
        self.assertAllEqual(generated, expected)

    def test_end_token_id(self):
        """Sequences are cut at the first `end_token_id`."""
        max_length = 5

        def token_probability_fn(inputs, mask):
            batch_size = inputs.shape[0]
            one_hot = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
            return tf.tile(one_hot, [batch_size, max_length, 1])

        sampler = Greedy()
        prompt = tf.constant([[0, 1], [1, 2]])
        generated = sampler(
            prompt,
            token_probability_fn,
            max_length=max_length,
            end_token_id=2,
        )
        expected = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
        self.assertAllEqual(generated, expected)

    def test_compare_xla_noxla_results(self):
        """XLA-compiled and non-compiled sampling agree."""
        prompt = [[1], [1]]

        xla_sampler = Greedy(jit_compile=True)
        outputs_xla = xla_sampler(
            prompt, self.token_probability_fn, max_length=5
        )

        no_xla_sampler = Greedy(jit_compile=False)
        outputs_no_xla = no_xla_sampler(
            prompt, self.token_probability_fn, max_length=5
        )

        self.assertAllEqual(outputs_xla, outputs_no_xla)


# --- keras_nlp/samplers/sampler.py (new file) ---
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the
# "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base sampler class."""

import tensorflow as tf
from tensorflow import keras

from keras_nlp.utils.python_utils import format_docstring

base_sampler_args_docstring = """
    jit_compile: bool, defaults to True. If True, XLA compilation will be used.
    """

call_args_docstring = """
    prompt: a list of integers or an integer Tensor, can be 1D or 2D. The
        initial tokens to append generated tokens.
    token_probability_fn: a function that generates the probability of
        the next token over the whole vocabulary for each input token.
    max_length: int. The max length of generated sequence.
    mask: a tensor, defaults to None. The padding mask of the prompt. Only
        valid when `prompt` is a list or a dense tensor.
    end_token_id: int, defaults to None. The token marking the end of the
        sequence, once encountered the generation is finished for the exact
        sequence. If None, every sequence is generated up to `max_length`.
        If set, each sequence is truncated right before its first
        `end_token_id` and the result is returned as a `tf.RaggedTensor`.
    from_logits: bool, defaults to True. Indicate if the `token_probability_fn`
        returns logits. If False, `token_probability_fn` returns probability
        distributions.
    """

sample_args_docstring = """
    prompt: a dense int Tensor of shape [batch_size, max_length]. The
        placeholder for generated sequence.
    token_probability_fn: a function that generates the probability of
        the next token over the whole vocabulary for each input token.
    mask: a dense bool Tensor of shape [batch_size, max_length]. The mask of
        prompt.
    num_steps: int. The remaining number of tokens to generate.
    from_logits: bool, defaults to True. Indicate if the `token_probability_fn`
        returns logits. If False, `token_probability_fn` returns probability
        distributions.
    """


@format_docstring(
    base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring
)
@keras.utils.register_keras_serializable(package="keras_nlp")
class Sampler:
    """Base sampler class.

    Args:
        {{base_sampler_args}}

    Call Args:
        {{call_args}}

    The inputs and outputs of Sampler class are both token ids.

    Examples:

    Basic usage:
    ```python
    BATCH_SIZE = 8
    VOCAB_SIZE = 10
    FEATURE_SIZE = 16
    START_ID = 1
    END_ID = 2

    # Create a dummy model to predict the next token. Note that the output is
    # random without training, here we just demo how `samplers` works.
    model = keras.Sequential(
        [
            keras.Input(shape=[None]),
            keras.layers.Embedding(
                input_dim=VOCAB_SIZE,
                output_dim=FEATURE_SIZE,
            ),
            keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
        ]
    )

    # Define a function that outputs the next token's probability for each token
    # in the input sequence.
    def token_probability_fn(inputs, mask):
        return model(inputs)

    prompt = tf.fill((BATCH_SIZE, 1), START_ID)

    sampler = keras_nlp.samplers.Greedy()
    # Print the generated sequence (token ids).
    print(sampler(prompt, token_probability_fn, 10, end_token_id=END_ID))
    ```

    Use with string inputs:
    ```python
    vocab = ["[UNK]", "[PAD]", "[END]", "the", "quick", "brown", "fox"]
    tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
        vocabulary=vocab,
        lowercase=True,
    )
    FEATURE_SIZE = 16
    VOCAB_SIZE = len(vocab)
    # Create a dummy model to predict the next token.
    model = keras.Sequential(
        [
            keras.Input(shape=[None]),
            keras.layers.Embedding(
                input_dim=VOCAB_SIZE,
                output_dim=FEATURE_SIZE,
            ),
            keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
        ]
    )
    # Define a function that outputs the next token's probability for each token
    # in the input sequence.
    def token_probability_fn(inputs, mask):
        return model(inputs)

    prompt = tokenizer("the quick brown fox")
    sampler = keras_nlp.samplers.Greedy()
    generated = sampler(
        prompt,
        token_probability_fn,
        10,
        end_token_id=tokenizer.token_to_id("[END]")
    )
    print(tokenizer.detokenize(generated))
    ```
    """

    def __init__(
        self,
        jit_compile=True,
    ):
        self.jit_compile = jit_compile

    def _validate_prompt_and_mask(self, prompt, mask):
        """Helper method to validate input prompt.

        Args:
            prompt: a list, `tf.Tensor` or `tf.RaggedTensor` of token ids.
            mask: optional mask for a list or dense `prompt`; must be None
                when `prompt` is a `tf.RaggedTensor`.

        Returns:
            A `(prompt, mask)` tuple. A ragged `prompt` is returned unchanged
            together with the given `mask` (None); a list or dense `prompt`
            is passed through `tf.ragged.boolean_mask` with `mask`, which
            defaults to all True.

        Raises:
            ValueError: if `prompt` has an unsupported type, or if `mask` is
                given together with a ragged `prompt`.
        """
        if not isinstance(prompt, (list, tf.RaggedTensor, tf.Tensor)):
            raise ValueError(
                "`prompt` must be one of `list`, `tf.RaggedTensor` or "
                f"`tf.Tensor`, but received: prompt={type(prompt)}."
            )

        # `mask` is compared against None explicitly: taking the truth value
        # of a tensor with more than one element raises an error.
        if isinstance(prompt, tf.RaggedTensor):
            if mask is not None:
                raise ValueError(
                    "`mask` is only valid when `prompt` is a list or dense "
                    f"tensor, but received type(prompt)={type(prompt)}."
                )
            return prompt, mask

        if isinstance(prompt, list):
            prompt = tf.convert_to_tensor(prompt)
        if mask is None:
            mask = tf.cast(tf.ones_like(prompt), dtype=tf.bool)
        prompt = tf.ragged.boolean_mask(prompt, mask)
        return prompt, mask

    def _validate_token_probability_fn(
        self, token_probability_fn, prompt, mask
    ):
        """Helper method to validate `token_probability_fn` output.

        Calls the function once on the padded prompt and checks that the
        result is a rank-3 tensor.

        Raises:
            ValueError: if the output does not have exactly 3 dimensions.
        """
        test_pred = token_probability_fn(prompt, mask=mask)
        if len(test_pred.shape) != 3:
            raise ValueError(
                "Output of `token_probability_fn` is not a 3D tensor, "
                "please provide a function with the output shape "
                "[batch_size, sequence_length, vocab_size]."
            )

    def _pad_prompt(self, prompt, max_length):
        """Pad prompt to `max_length`.

        Args:
            prompt: a `tf.RaggedTensor` of token ids.
            max_length: int, the width of the dense output.

        Returns:
            A `(prompt, mask)` tuple of dense tensors with `max_length`
            columns. `mask` is built from an all-True ragged tensor shaped
            like `prompt`, so it marks the positions holding prompt tokens.
        """
        mask = tf.ones_like(prompt, dtype=tf.bool)
        mask = mask.to_tensor(shape=(None, max_length))
        prompt = prompt.to_tensor(shape=(None, max_length))
        return prompt, mask

    def _mask_tokens_after_end_token(
        self,
        prompt,
        max_length,
        end_token_id,
    ):
        """Helper function to truncate the tokens after the end token.

        Args:
            prompt: a dense int Tensor of shape [batch_size, max_length].
            max_length: int, the number of columns of `prompt`.
            end_token_id: int, the token that terminates a sequence.

        Returns:
            A `tf.RaggedTensor` in which every row is cut right before its
            first `end_token_id`; rows without an end token are kept whole.
        """
        is_end_token = tf.equal(prompt, end_token_id)
        # `argmax` yields 0 both when the end token sits at index 0 and when
        # it is absent, so presence has to be tracked separately.
        has_end_token = tf.reduce_any(is_end_token, axis=-1)
        # Find index of first end_token_id.
        end_indices = tf.math.argmax(is_end_token, -1)
        # Use max_length if no `end_token_id` is found.
        end_indices = tf.where(
            has_end_token,
            end_indices,
            tf.cast(max_length, dtype=end_indices.dtype),
        )
        # Truncate out tokens after (including) the end token.
        mask_indices = tf.sequence_mask(end_indices, maxlen=max_length)
        return tf.ragged.boolean_mask(prompt, mask_indices)

    def __call__(
        self,
        prompt,
        token_probability_fn,
        max_length,
        mask=None,
        end_token_id=None,
        from_logits=True,
    ):
        """Generate token ids; see the class docstring for the arguments."""
        prompt, mask = self._validate_prompt_and_mask(prompt, mask)

        # A 1D prompt is handled as a batch of one and squeezed on return.
        input_is_1d = prompt.shape.rank == 1
        if input_is_1d:
            prompt = tf.RaggedTensor.from_tensor(prompt[tf.newaxis, :])

        shortest_prompt_len = tf.reduce_min(prompt.row_lengths())
        # Pad prompt to be a dense Tensor of shape [batch_size, max_length].
        # This step is required for XLA compatibility because XLA requires a
        # static shape, which means we cannot concatenate generated token to
        # current prompt.
        prompt, mask = self._pad_prompt(prompt, max_length)
        self._validate_token_probability_fn(token_probability_fn, prompt, mask)

        # Convert `sample` method to a `tf.function`, and turn on
        # `jit_compile` accordingly.
        sample = tf.function(self.sample, jit_compile=self.jit_compile)
        prompt = sample(
            prompt,
            token_probability_fn,
            mask,
            max_length - shortest_prompt_len,
            from_logits,
        )

        # Mask out tokens after `end_token_id`.
        if end_token_id is not None:
            prompt = self._mask_tokens_after_end_token(
                prompt,
                max_length,
                end_token_id,
            )

        return tf.squeeze(prompt, axis=0) if input_is_1d else prompt

    @format_docstring(sample_args=sample_args_docstring)
    def sample(
        self, prompt, token_probability_fn, mask, num_steps, from_logits=True
    ):
        """Sampling logic implementation.

        Subclasses must override this method; the base implementation only
        raises `NotImplementedError`.

        Args:
            {{sample_args}}

        Returns:
            A dense int Tensor, representing the generated text in token id
            space.
        """
        raise NotImplementedError

    def get_config(self):
        """Return the constructor arguments as a dict."""
        return {
            "jit_compile": self.jit_compile,
        }


# --- keras_nlp/samplers/sampler_test.py (new file) ---
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Sampler classes."""

import tensorflow as tf

import keras_nlp
from keras_nlp.samplers.greedy import Greedy


class SamplerTest(tf.test.TestCase):
    """Round-trip tests for `keras_nlp.samplers` serialize/get."""

    def test_serialization(self):
        sampler = keras_nlp.samplers.Greedy()
        config = keras_nlp.samplers.serialize(sampler)
        expected_config = {
            "class_name": "keras_nlp>Greedy",
            "config": {
                "jit_compile": True,
            },
        }
        self.assertDictEqual(expected_config, config)

    def test_deserialization(self):
        # Test get from string.
        identifier = "greedy"
        sampler = keras_nlp.samplers.get(identifier)
        self.assertIsInstance(sampler, Greedy)

        # Test dict identifier.
        original_sampler = keras_nlp.samplers.Greedy(jit_compile=False)
        config = keras_nlp.samplers.serialize(original_sampler)
        restored_sampler = keras_nlp.samplers.get(config)
        self.assertDictEqual(
            keras_nlp.samplers.serialize(restored_sampler),
            keras_nlp.samplers.serialize(original_sampler),
        )

        # Test identifier is already a sampler instance.
        original_sampler = keras_nlp.samplers.Greedy(jit_compile=False)
        restored_sampler = keras_nlp.samplers.get(original_sampler)
        self.assertEqual(original_sampler, restored_sampler)