From 26fd5099fbadbbde7f56cec17d15b5d876e9a3ea Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Fri, 9 Dec 2022 14:41:21 -0800 Subject: [PATCH 1/8] initial commit --- keras_nlp/samplers/__init__.py | 13 ++ keras_nlp/samplers/greedy_sampler.py | 112 ++++++++++++++ keras_nlp/samplers/greedy_sampler_test.py | 114 ++++++++++++++ keras_nlp/samplers/sampler.py | 178 ++++++++++++++++++++++ 4 files changed, 417 insertions(+) create mode 100644 keras_nlp/samplers/__init__.py create mode 100644 keras_nlp/samplers/greedy_sampler.py create mode 100644 keras_nlp/samplers/greedy_sampler_test.py create mode 100644 keras_nlp/samplers/sampler.py diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py new file mode 100644 index 0000000000..6e4df4e727 --- /dev/null +++ b/keras_nlp/samplers/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py new file mode 100644 index 0000000000..816193ebf3 --- /dev/null +++ b/keras_nlp/samplers/greedy_sampler.py @@ -0,0 +1,112 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Greedy Sampler.""" + +import tensorflow as tf + +from keras_nlp.samplers.sampler import Sampler +from keras_nlp.samplers.sampler import base_sampler_keyword_args +from keras_nlp.samplers.sampler import call_keyword_docstring +from keras_nlp.samplers.sampler import sample_keyword_docstring + + +class GreedySampler(Sampler): + """Greedy Sampler class. + + This sampler is implemented on greedy search, i.e., always picking up the + token of the largest probability as the next token. + + Args: + {{base_sampler_keyword_args}} + + Call Args: + {{call_keyword_args}} + """ + + def __init__( + self, + end_token_id=None, + pad_token_id=0, + jit_compile=True, + ): + super().__init__(end_token_id, pad_token_id, jit_compile) + + def sample(self, token_probability_fn, prompt, mask, num_steps): + """Sampler's logic implementation. 
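+        At each step, the token with the highest predicted probability is
+        picked as the next token (greedy search).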
+ + Args: + {{call_keyword_docstring}} + """ + batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] + max_length = tf.cast(max_length, num_steps.dtype) + length = max_length - num_steps + + def one_step(length, prompt, mask): + + probs = token_probability_fn(prompt, mask) + next_token_prob = tf.gather( + probs, tf.repeat(length - 1, batch_size), axis=1, batch_dims=1 + ) + next_token = tf.cast( + tf.argmax(next_token_prob, axis=-1), dtype=prompt.dtype + ) + next_token = tf.where( + mask[:, length], prompt[:, length], next_token + ) + mask = tf.tensor_scatter_nd_update( + tensor=mask, + indices=tf.stack( + ( + tf.cast(tf.range(batch_size), dtype=length.dtype), + tf.repeat(length, batch_size), + ), + axis=1, + ), + updates=tf.repeat(True, batch_size), + ) + + # Append the next token to current sequence. + prompt = tf.tensor_scatter_nd_update( + tensor=prompt, + indices=tf.stack( + ( + tf.cast(tf.range(batch_size), dtype=length.dtype), + tf.repeat(length, batch_size), + ), + axis=1, + ), + updates=next_token, + ) + + length = tf.add(length, 1) + return (length, prompt, mask) + + # Run a while loop till text of length `max_length` has been generated. + length, prompt, mask = tf.while_loop( + cond=lambda length, prompt, mask: tf.less(length, max_length), + body=one_step, + loop_vars=(length, prompt, mask), + ) + return prompt + + +GreedySampler.__doc__ = GreedySampler.__doc__.replace( + "{{base_sampler_keyword_args}}", base_sampler_keyword_args +) +GreedySampler.__doc__ = GreedySampler.__doc__.replace( + "{{call_keyword_docstring}}", call_keyword_docstring +) +GreedySampler.sample.__doc__ = GreedySampler.sample.__doc__.replace( + "{{sample_keyword_docstring}}", sample_keyword_docstring +) diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_sampler_test.py new file mode 100644 index 0000000000..6e787d0afa --- /dev/null +++ b/keras_nlp/samplers/greedy_sampler_test.py @@ -0,0 +1,114 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Text Generation Utils.""" + +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.samplers.greedy_sampler import GreedySampler + + +class GreedySamplerTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + self.vocab_size = 10 + self.feature_size = 16 + + # Create a dummy model to predict the next token. 
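+        # The model outputs a softmax distribution over the vocabulary at
+        # every position, i.e. shape [batch_size, sequence_length, vocab_size].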
+ model = keras.Sequential( + [ + keras.Input(shape=[None]), + keras.layers.Embedding( + input_dim=self.vocab_size, + output_dim=self.feature_size, + ), + keras.layers.Dense(self.vocab_size), + keras.layers.Softmax(), + ] + ) + + def token_probability_fn(inputs, mask): + return model(inputs) + + self.token_probability_fn = token_probability_fn + + self.sampler = GreedySampler() + + def test_generate_with_1d_prompt(self): + inputs = tf.constant([1]) + outputs = self.sampler(self.token_probability_fn, inputs, max_length=5) + self.assertEqual(outputs.shape, [5]) + + def test_generate_with_2d_prompt(self): + inputs = tf.constant([[1], [1]]) + outputs = self.sampler(self.token_probability_fn, inputs, max_length=5) + self.assertEqual(outputs.shape, [2, 5]) + + def test_generate_with_list_prompt(self): + inputs = [[1], [1]] + outputs = self.sampler(self.token_probability_fn, inputs, max_length=5) + self.assertEqual(outputs.shape, [2, 5]) + + def test_generate_with_ragged_prompt(self): + max_length = 5 + + def token_probability_fn(inputs, mask): + # Assert that user function is passed only dense tensors. + self.assertIsInstance(inputs, tf.Tensor) + prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) + return tf.repeat(tf.repeat(prob, 2, axis=0), max_length, axis=1) + + inputs = tf.ragged.constant([[1], [2, 1, 2]]) + outputs = self.sampler(token_probability_fn, inputs, max_length) + self.assertEqual(outputs.shape, [2, 5]) + + def test_assert_generation_is_correct(self): + batch_size = 10 + max_length = 3 + + def token_probability_fn(inputs, mask): + prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) + return tf.repeat( + tf.repeat(prob, batch_size, axis=0), max_length, axis=1 + ) + + inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32) + outputs = self.sampler( + token_probability_fn, inputs, max_length=max_length + ) + self.assertAllEqual( + outputs, 3 * tf.ones(shape=[batch_size, max_length]) + ) + + def test_end_token_id(self): + max_length = 5 + + def token_probability_fn(inputs, mask): + batch_size = inputs.shape[0] + prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) + return tf.repeat( + tf.repeat(prob, batch_size, axis=0), max_length, axis=1 + ) + + sampler = GreedySampler(end_token_id=2) + inputs = tf.constant([[0, 1], [1, 2]]) + outputs = sampler( + token_probability_fn, + inputs, + max_length=max_length, + ) + expected_outputs = tf.tile([[3], [0]], [1, max_length - 2]) + expected_outputs = tf.concat([inputs, expected_outputs], axis=1) + self.assertAllEqual(outputs, expected_outputs) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py new file mode 100644 index 0000000000..90e2aae086 --- /dev/null +++ b/keras_nlp/samplers/sampler.py @@ -0,0 +1,178 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base sampler class.""" + +import tensorflow as tf + + +class Sampler: + """Base sampler class. + + This class must be implemented by child class for instantiation. 
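+    A subclass overrides `sample()` with the decoding logic, while the base
+    `__call__` handles prompt validation, padding, and end-token masking.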
+ + Args: + {{base_optimizer_keyword_args}} + + Call Args: + + """ + + def __init__( + self, + end_token_id=None, + pad_token_id=0, + jit_compile=True, + ): + self.end_token_id = end_token_id + self.pad_token_id = pad_token_id + self.jit_compile = jit_compile + + def _validate_prompt(self, prompt): + """Helper function to validate input to text_generation utils.""" + if not isinstance(prompt, (tf.Tensor, tf.RaggedTensor)): + prompt = tf.convert_to_tensor(prompt) + return prompt + + def _validate_token_probability_fn( + self, token_probability_fn, prompt, mask + ): + """Helper function to validate `token_probability_fn` output.""" + test_pred = token_probability_fn(prompt, mask=mask) + if len(test_pred.shape) != 3: + raise ValueError( + "Output of `token_probability_fn` is not a 3D tensor, " + "please provide a function with the output shape " + "[batch_size, sequence_length, vocab_size]." + ) + + def _align_and_pad_prompt(self, prompt, max_length, pad_token_id): + """Align prompt to the right side, and pad to `max_length`.""" + longest_prompt_len = tf.reduce_max(prompt.row_lengths()) + pad_length = longest_prompt_len - prompt.row_lengths() + + prompt = tf.keras.utils.pad_sequences( + prompt.to_list(), maxlen=longest_prompt_len, value=pad_token_id + ) + + mask = tf.RaggedTensor.from_row_lengths( + tf.zeros(shape=[tf.reduce_sum(pad_length)], dtype=tf.int32), + pad_length, + ) + mask = mask.to_tensor(shape=(None, longest_prompt_len), default_value=1) + + shape = prompt.shape + extra_space = tf.math.maximum(0, max_length - shape[1]) + pad_shape = [shape[0], extra_space] + + mask = tf.concat((mask, tf.zeros(pad_shape, tf.int32)), axis=1) + prompt = tf.concat( + (prompt, tf.zeros(pad_shape, prompt.dtype) + pad_token_id), axis=1 + ) + mask = tf.cast(mask, dtype=tf.bool) + return prompt, mask + + def _mask_tokens_after_end_token( + self, prompt, max_length, end_token_id, pad_token_id + ): + """Helper function to mask the tokens after the end token.""" + # Mask out tokens after `end_token_id` is encountered. + # Find index of first end_token_id. + end_indices = tf.math.argmax(prompt == end_token_id, -1) + # Use max_length if no `end_token_id` is found. + end_indices = tf.where( + end_indices == 0, + tf.cast(max_length, dtype=end_indices.dtype), + end_indices, + ) + # Build a mask including end_token and replace tokens after end_token + # with `pad_token_id`. 
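+        # `end_indices + 1` keeps the end token itself inside the valid
+        # region; every position after it is overwritten with `pad_token_id`.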
+ valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length) + return tf.where(valid_indices, prompt, pad_token_id) + + def __call__(self, token_probability_fn, prompt, max_length): + """Sampling method to be called by users.""" + + prompt = self._validate_prompt(prompt) + + input_is_1d = prompt.shape.rank == 1 + if input_is_1d: + prompt = prompt[tf.newaxis, :] + if isinstance(prompt, tf.Tensor): + prompt = tf.RaggedTensor.from_tensor( + prompt, padding=self.pad_token_id + ) + longest_prompt_len = tf.reduce_max(prompt.row_lengths()) + prompt, mask = self._align_and_pad_prompt( + prompt, max_length, self.pad_token_id + ) + self._validate_token_probability_fn(token_probability_fn, prompt, mask) + + sample = tf.function(self.sample, jit_compile=self.jit_compile) + prompt = sample( + token_probability_fn, prompt, mask, max_length - longest_prompt_len + ) + + if self.end_token_id is not None: + prompt = self._mask_tokens_after_end_token( + prompt, max_length, self.end_token_id, self.pad_token_id + ) + + return tf.squeeze(prompt) if input_is_1d else prompt + + def sample(self, token_probability_fn, prompt, mask, num_steps): + """Sampler's logic implementation. + + Args: + {{sample_keyword_docstring}} + + Returns: + A dense int Tensor, representing the generated text in token id + space. + """ + raise NotImplementedError + + +base_sampler_keyword_args = """ + end_token_id: int, defaults to None. The token marking the end of the + sequence, once encountered the generation is finished for the exact + sequence. If None, every sequence is generated up to `max_length`. + If set, all tokens after encountering `end_token_id` will be + replaced with `pad_token_id`. + pad_token_id: int, defaults to 0. The pad token after `end_token_id` + is received. + jit_compile: bool, defaults to True. If using XLA compilation.""" + +call_keyword_docstring = """ + token_probability_fn: a function that generates the probability of + the next token over the whole vocabulary for each input token. + prompt: a list of integers or an integer Tensor, can be 1D or 2D. The + initial tokens to append generated tokens. + max_length: int. The max length of generated sequence.""" + +sample_keyword_docstring = """ + token_probability_fn: a function that generates the probability of + the next token over the whole vocabulary for each input token. + prompt: a list of integers or an integer Tensor, can be 1D or 2D. The + initial tokens to append generated tokens. + num_steps: int. 
The number of tokens to generate.""" + +Sampler.__doc__ = Sampler.__doc__.replace( + "{{base_sampler_keyword_args}}", base_sampler_keyword_args +) +Sampler.__doc__ = Sampler.__call__.__doc__.replace( + "{{call_keyword_docstring}}", call_keyword_docstring +) +Sampler.sample.__doc__ = Sampler.sample.__doc__.replace( + "{{sample_keyword_docstring}}", sample_keyword_docstring +) From 7e4c651df9c58f1005ea0bfb85dfb514a6750596 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Fri, 9 Dec 2022 17:32:00 -0800 Subject: [PATCH 2/8] Add keras_nlp.samplers --- keras_nlp/samplers/__init__.py | 2 + keras_nlp/samplers/greedy_sampler.py | 75 ++++++++++++++++++----- keras_nlp/samplers/greedy_sampler_test.py | 16 ++++- keras_nlp/samplers/sampler.py | 70 ++++++++++++++++----- 4 files changed, 129 insertions(+), 34 deletions(-) diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py index 6e4df4e727..11908fb4dc 100644 --- a/keras_nlp/samplers/__init__.py +++ b/keras_nlp/samplers/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.samplers.greedy_sampler import GreedySampler diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py index 816193ebf3..b57362179c 100644 --- a/keras_nlp/samplers/greedy_sampler.py +++ b/keras_nlp/samplers/greedy_sampler.py @@ -32,6 +32,38 @@ class GreedySampler(Sampler): Call Args: {{call_keyword_args}} + + Examples: + ```python + BATCH_SIZE = 8 + VOCAB_SIZE = 10 + FEATURE_SIZE = 16 + START_ID = 1 + END_ID = 2 + + # Create a dummy model to predict the next token. + model = keras.Sequential( + [ + keras.Input(shape=[None]), + keras.layers.Embedding( + input_dim=VOCAB_SIZE, + output_dim=FEATURE_SIZE, + ), + keras.layers.Dense(VOCAB_SIZE, activation="softmax"), + ] + ) + + # Define a function that outputs the next token's probability for each token + # in the input sequence. + def token_probability_fn(inputs): + return model(inputs) + + prompt = tf.fill((BATCH_SIZE, 1), START_ID) + + sampler = keras_nlp.samplers.GreedySearch(end_token_id=END_ID) + # Print the generated sequence (token ids). + print(sampler(token_probability_fn, prompt, max_length=10)) + ``` """ def __init__( @@ -43,33 +75,40 @@ def __init__( super().__init__(end_token_id, pad_token_id, jit_compile) def sample(self, token_probability_fn, prompt, mask, num_steps): - """Sampler's logic implementation. + """Sampling logic implementation. Args: - {{call_keyword_docstring}} + {{sample_keyword_docstring}} """ batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] max_length = tf.cast(max_length, num_steps.dtype) - length = max_length - num_steps + # The index of the last non-padding token in prompt. Since all sequences + # are aligned to the right side, the index is the same for all. 
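+        # For example, with `max_length=10` and `num_steps=7`, generation
+        # starts writing at index 3.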
+ current_index = max_length - num_steps - def one_step(length, prompt, mask): + def one_step(current_index, prompt, mask): probs = token_probability_fn(prompt, mask) next_token_prob = tf.gather( - probs, tf.repeat(length - 1, batch_size), axis=1, batch_dims=1 + probs, + tf.repeat(current_index - 1, batch_size), + axis=1, + batch_dims=1, ) next_token = tf.cast( tf.argmax(next_token_prob, axis=-1), dtype=prompt.dtype ) next_token = tf.where( - mask[:, length], prompt[:, length], next_token + mask[:, current_index], prompt[:, current_index], next_token ) mask = tf.tensor_scatter_nd_update( tensor=mask, indices=tf.stack( ( - tf.cast(tf.range(batch_size), dtype=length.dtype), - tf.repeat(length, batch_size), + tf.cast( + tf.range(batch_size), dtype=current_index.dtype + ), + tf.repeat(current_index, batch_size), ), axis=1, ), @@ -81,22 +120,26 @@ def one_step(length, prompt, mask): tensor=prompt, indices=tf.stack( ( - tf.cast(tf.range(batch_size), dtype=length.dtype), - tf.repeat(length, batch_size), + tf.cast( + tf.range(batch_size), dtype=current_index.dtype + ), + tf.repeat(current_index, batch_size), ), axis=1, ), updates=next_token, ) - length = tf.add(length, 1) - return (length, prompt, mask) + current_index = tf.add(current_index, 1) + return (current_index, prompt, mask) - # Run a while loop till text of length `max_length` has been generated. - length, prompt, mask = tf.while_loop( - cond=lambda length, prompt, mask: tf.less(length, max_length), + # Run a while loop till `max_length` of tokens has been generated. + current_index, prompt, mask = tf.while_loop( + cond=lambda current_index, prompt, mask: tf.less( + current_index, max_length + ), body=one_step, - loop_vars=(length, prompt, mask), + loop_vars=(current_index, prompt, mask), ) return prompt diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_sampler_test.py index 6e787d0afa..8eac81fd71 100644 --- a/keras_nlp/samplers/greedy_sampler_test.py +++ b/keras_nlp/samplers/greedy_sampler_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Text Generation Utils.""" +"""Tests for GreedySampler.""" import tensorflow as tf from absl.testing import parameterized @@ -112,3 +112,17 @@ def token_probability_fn(inputs, mask): expected_outputs = tf.tile([[3], [0]], [1, max_length - 2]) expected_outputs = tf.concat([inputs, expected_outputs], axis=1) self.assertAllEqual(outputs, expected_outputs) + + def test_compare_xla_noxla_results(self): + inputs = [[1], [1]] + xla_sampler = GreedySampler(jit_compile=True) + outputs_xla = xla_sampler( + self.token_probability_fn, inputs, max_length=5 + ) + + xla_sampler = GreedySampler(jit_compile=False) + outputs_no_xla = xla_sampler( + self.token_probability_fn, inputs, max_length=5 + ) + + self.assertAllEqual(outputs_xla, outputs_no_xla) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 90e2aae086..3f997c3792 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -14,18 +14,49 @@ """Base sampler class.""" import tensorflow as tf +from tensorflow import keras class Sampler: """Base sampler class. - This class must be implemented by child class for instantiation. 
- Args: {{base_optimizer_keyword_args}} Call Args: - + {{call_keyword_docstring}} + + Examples: + ```python + BATCH_SIZE = 8 + VOCAB_SIZE = 10 + FEATURE_SIZE = 16 + START_ID = 1 + END_ID = 2 + + # Create a dummy model to predict the next token. + model = keras.Sequential( + [ + keras.Input(shape=[None]), + keras.layers.Embedding( + input_dim=VOCAB_SIZE, + output_dim=FEATURE_SIZE, + ), + keras.layers.Dense(VOCAB_SIZE, activation="softmax"), + ] + ) + + # Define a function that outputs the next token's probability for each token + # in the input sequence. + def token_probability_fn(inputs): + return model(inputs) + + prompt = tf.fill((BATCH_SIZE, 1), START_ID) + + sampler = keras_nlp.samplers.GreedySearch(end_token_id=END_ID) + # Print the generated sequence (token ids). + print(sampler(token_probability_fn, prompt, max_length=10)) + ``` """ def __init__( @@ -39,7 +70,7 @@ def __init__( self.jit_compile = jit_compile def _validate_prompt(self, prompt): - """Helper function to validate input to text_generation utils.""" + """Helper method to validate input prompt.""" if not isinstance(prompt, (tf.Tensor, tf.RaggedTensor)): prompt = tf.convert_to_tensor(prompt) return prompt @@ -47,7 +78,7 @@ def _validate_prompt(self, prompt): def _validate_token_probability_fn( self, token_probability_fn, prompt, mask ): - """Helper function to validate `token_probability_fn` output.""" + """Helper method to validate `token_probability_fn` output.""" test_pred = token_probability_fn(prompt, mask=mask) if len(test_pred.shape) != 3: raise ValueError( @@ -61,7 +92,7 @@ def _align_and_pad_prompt(self, prompt, max_length, pad_token_id): longest_prompt_len = tf.reduce_max(prompt.row_lengths()) pad_length = longest_prompt_len - prompt.row_lengths() - prompt = tf.keras.utils.pad_sequences( + prompt = keras.utils.pad_sequences( prompt.to_list(), maxlen=longest_prompt_len, value=pad_token_id ) @@ -97,12 +128,10 @@ def _mask_tokens_after_end_token( ) # Build a mask including end_token and replace tokens after end_token # with `pad_token_id`. - valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length) - return tf.where(valid_indices, prompt, pad_token_id) + mask_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length) + return tf.where(mask_indices, prompt, pad_token_id) def __call__(self, token_probability_fn, prompt, max_length): - """Sampling method to be called by users.""" - prompt = self._validate_prompt(prompt) input_is_1d = prompt.shape.rank == 1 @@ -113,16 +142,23 @@ def __call__(self, token_probability_fn, prompt, max_length): prompt, padding=self.pad_token_id ) longest_prompt_len = tf.reduce_max(prompt.row_lengths()) + # Pad prompt to be a dense Tensor of shape [batch_size, max_length]. + # This step is required for XLA compatibility because XLA requires a + # static shape, which means we cannot concatenate generated token to + # current prompt. prompt, mask = self._align_and_pad_prompt( prompt, max_length, self.pad_token_id ) self._validate_token_probability_fn(token_probability_fn, prompt, mask) + # Convert `sample` method to a `tf.function`, and turn on + # `jit_compile` accordingly. sample = tf.function(self.sample, jit_compile=self.jit_compile) prompt = sample( token_probability_fn, prompt, mask, max_length - longest_prompt_len ) + # Mask out tokens after `end_token_id`. 
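+        # Tokens past the first `end_token_id` are overwritten with
+        # `pad_token_id`, keeping the output dense.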
if self.end_token_id is not None: prompt = self._mask_tokens_after_end_token( prompt, max_length, self.end_token_id, self.pad_token_id @@ -131,7 +167,7 @@ def __call__(self, token_probability_fn, prompt, max_length): return tf.squeeze(prompt) if input_is_1d else prompt def sample(self, token_probability_fn, prompt, mask, num_steps): - """Sampler's logic implementation. + """Sampling logic implementation. Args: {{sample_keyword_docstring}} @@ -149,9 +185,9 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): sequence. If None, every sequence is generated up to `max_length`. If set, all tokens after encountering `end_token_id` will be replaced with `pad_token_id`. - pad_token_id: int, defaults to 0. The pad token after `end_token_id` - is received. - jit_compile: bool, defaults to True. If using XLA compilation.""" + pad_token_id: int, defaults to 0. The padding token. + jit_compile: bool, defaults to True. If True, XLA compilation will be used. + """ call_keyword_docstring = """ token_probability_fn: a function that generates the probability of @@ -163,9 +199,9 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): sample_keyword_docstring = """ token_probability_fn: a function that generates the probability of the next token over the whole vocabulary for each input token. - prompt: a list of integers or an integer Tensor, can be 1D or 2D. The - initial tokens to append generated tokens. - num_steps: int. The number of tokens to generate.""" + prompt: a dense int Tensor of shape [batch_size, max_length]. The + placeholder for generated sequence. + num_steps: int. The remaining number of tokens to generate.""" Sampler.__doc__ = Sampler.__doc__.replace( "{{base_sampler_keyword_args}}", base_sampler_keyword_args From 28bcfe1b8d3875e30f2bf5b23fa68781be0c63b8 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Mon, 12 Dec 2022 15:50:16 -0800 Subject: [PATCH 3/8] Change padding to left to right --- keras_nlp/samplers/sampler.py | 40 +++++++++-------------------------- 1 file changed, 10 insertions(+), 30 deletions(-) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 3f997c3792..4e6aca3bb8 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -14,7 +14,6 @@ """Base sampler class.""" import tensorflow as tf -from tensorflow import keras class Sampler: @@ -87,30 +86,13 @@ def _validate_token_probability_fn( "[batch_size, sequence_length, vocab_size]." 
) - def _align_and_pad_prompt(self, prompt, max_length, pad_token_id): - """Align prompt to the right side, and pad to `max_length`.""" - longest_prompt_len = tf.reduce_max(prompt.row_lengths()) - pad_length = longest_prompt_len - prompt.row_lengths() - - prompt = keras.utils.pad_sequences( - prompt.to_list(), maxlen=longest_prompt_len, value=pad_token_id - ) - - mask = tf.RaggedTensor.from_row_lengths( - tf.zeros(shape=[tf.reduce_sum(pad_length)], dtype=tf.int32), - pad_length, - ) - mask = mask.to_tensor(shape=(None, longest_prompt_len), default_value=1) - - shape = prompt.shape - extra_space = tf.math.maximum(0, max_length - shape[1]) - pad_shape = [shape[0], extra_space] - - mask = tf.concat((mask, tf.zeros(pad_shape, tf.int32)), axis=1) - prompt = tf.concat( - (prompt, tf.zeros(pad_shape, prompt.dtype) + pad_token_id), axis=1 + def _pad_prompt(self, prompt, max_length, pad_token_id): + """Pad prompt to `max_length`.""" + mask = tf.ones_like(prompt, dtype=tf.bool) + mask = mask.to_tensor(shape=(None, max_length)) + prompt = prompt.to_tensor( + shape=(None, max_length), default_value=pad_token_id ) - mask = tf.cast(mask, dtype=tf.bool) return prompt, mask def _mask_tokens_after_end_token( @@ -141,21 +123,19 @@ def __call__(self, token_probability_fn, prompt, max_length): prompt = tf.RaggedTensor.from_tensor( prompt, padding=self.pad_token_id ) - longest_prompt_len = tf.reduce_max(prompt.row_lengths()) + shortest_prompt_len = tf.reduce_min(prompt.row_lengths()) # Pad prompt to be a dense Tensor of shape [batch_size, max_length]. # This step is required for XLA compatibility because XLA requires a # static shape, which means we cannot concatenate generated token to # current prompt. - prompt, mask = self._align_and_pad_prompt( - prompt, max_length, self.pad_token_id - ) + prompt, mask = self._pad_prompt(prompt, max_length, self.pad_token_id) self._validate_token_probability_fn(token_probability_fn, prompt, mask) # Convert `sample` method to a `tf.function`, and turn on # `jit_compile` accordingly. sample = tf.function(self.sample, jit_compile=self.jit_compile) prompt = sample( - token_probability_fn, prompt, mask, max_length - longest_prompt_len + token_probability_fn, prompt, mask, max_length - shortest_prompt_len ) # Mask out tokens after `end_token_id`. 
@@ -206,7 +186,7 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): Sampler.__doc__ = Sampler.__doc__.replace( "{{base_sampler_keyword_args}}", base_sampler_keyword_args ) -Sampler.__doc__ = Sampler.__call__.__doc__.replace( +Sampler.__doc__ = Sampler.__doc__.replace( "{{call_keyword_docstring}}", call_keyword_docstring ) Sampler.sample.__doc__ = Sampler.sample.__doc__.replace( From 9757f4d78f79da4b548892c78f38beea4fd9df09 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Thu, 5 Jan 2023 12:16:14 -0800 Subject: [PATCH 4/8] Add serialization support, and move some args from constructor to call --- keras_nlp/__init__.py | 1 + keras_nlp/samplers/__init__.py | 68 +++++++++- .../samplers/{greedy_sampler.py => greedy.py} | 24 ++-- ...{greedy_sampler_test.py => greedy_test.py} | 18 +-- keras_nlp/samplers/sampler.py | 123 ++++++++++++------ keras_nlp/samplers/sampler_test.py | 57 ++++++++ 6 files changed, 229 insertions(+), 62 deletions(-) rename keras_nlp/samplers/{greedy_sampler.py => greedy.py} (88%) rename keras_nlp/samplers/{greedy_sampler_test.py => greedy_test.py} (89%) create mode 100644 keras_nlp/samplers/sampler_test.py diff --git a/keras_nlp/__init__.py b/keras_nlp/__init__.py index 73e7e2c593..050a189d6f 100644 --- a/keras_nlp/__init__.py +++ b/keras_nlp/__init__.py @@ -15,6 +15,7 @@ from keras_nlp import layers from keras_nlp import metrics from keras_nlp import models +from keras_nlp import samplers from keras_nlp import tokenizers from keras_nlp import utils diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py index 11908fb4dc..f4841dd85a 100644 --- a/keras_nlp/samplers/__init__.py +++ b/keras_nlp/samplers/__init__.py @@ -12,4 +12,70 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_nlp.samplers.greedy_sampler import GreedySampler +from tensorflow import keras + +from keras_nlp.samplers.greedy import Greedy + + +def serialize(sampler): + return keras.utils.serialize_keras_object(sampler) + + +def deserialize(config, custom_objects=None): + """Return a `Sampler` object from its config.""" + all_classes = { + "greedy": Greedy, + } + if config["class_name"].lower() in all_classes: + config["class_name"] = config["class_name"].lower() + return keras.utils.deserialize_keras_object( + config, + module_objects=all_classes, + custom_objects=custom_objects, + printable_module_name="samplers", + ) + + +def get(identifier): + """Retrieve a KerasNLP sampler by the identifier. + + The `identifier` may be the string name of a sampler class or class. + + >>> identifier = 'greedy' + >>> sampler = keras_nlp.samplers.get(identifier) + + You can also specify `config` of the sampler to this function by passing + dict containing `class_name` and `config` as an identifier. Also note that + the `class_name` must map to a `Sampler` class. + + >>> cfg = {'class_name': 'Greedy', 'config': {}} + >>> sampler = keras_nlp.samplers.get(cfg) + + In the case that the `identifier` is a class, this method will return a new + instance of the class by its constructor. + + Args: + identifier: String or dict that contains the sampler name or + configurations. + + Returns: + Sampler instance base on the input identifier. + + Raises: + ValueError: If the input identifier is not a supported type or in a bad + format. 
+ """ + + if identifier is None: + return None + if isinstance(identifier, dict): + return deserialize(identifier) + elif isinstance(identifier, str): + identifier = {"class_name": str(identifier), "config": {}} + return deserialize(identifier) + elif callable(identifier): + return identifier + else: + raise ValueError( + "Could not interpret sampler identifier: " + str(identifier) + ) diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy.py similarity index 88% rename from keras_nlp/samplers/greedy_sampler.py rename to keras_nlp/samplers/greedy.py index b57362179c..f486a7cca0 100644 --- a/keras_nlp/samplers/greedy_sampler.py +++ b/keras_nlp/samplers/greedy.py @@ -21,8 +21,8 @@ from keras_nlp.samplers.sampler import sample_keyword_docstring -class GreedySampler(Sampler): - """Greedy Sampler class. +class Greedy(Sampler): + """Greedy sampler class. This sampler is implemented on greedy search, i.e., always picking up the token of the largest probability as the next token. @@ -55,26 +55,26 @@ class GreedySampler(Sampler): # Define a function that outputs the next token's probability for each token # in the input sequence. - def token_probability_fn(inputs): + def token_probability_fn(inputs, mask): return model(inputs) prompt = tf.fill((BATCH_SIZE, 1), START_ID) - sampler = keras_nlp.samplers.GreedySearch(end_token_id=END_ID) + sampler = keras_nlp.samplers.Greedy() # Print the generated sequence (token ids). - print(sampler(token_probability_fn, prompt, max_length=10)) + print(sampler(token_probability_fn, prompt, 10, end_token_id=END_ID)) ``` """ def __init__( self, - end_token_id=None, - pad_token_id=0, jit_compile=True, ): - super().__init__(end_token_id, pad_token_id, jit_compile) + super().__init__(jit_compile) - def sample(self, token_probability_fn, prompt, mask, num_steps): + def sample( + self, token_probability_fn, prompt, mask, num_steps, from_logits=True + ): """Sampling logic implementation. Args: @@ -144,12 +144,12 @@ def one_step(current_index, prompt, mask): return prompt -GreedySampler.__doc__ = GreedySampler.__doc__.replace( +Greedy.__doc__ = Greedy.__doc__.replace( "{{base_sampler_keyword_args}}", base_sampler_keyword_args ) -GreedySampler.__doc__ = GreedySampler.__doc__.replace( +Greedy.__doc__ = Greedy.__doc__.replace( "{{call_keyword_docstring}}", call_keyword_docstring ) -GreedySampler.sample.__doc__ = GreedySampler.sample.__doc__.replace( +Greedy.sample.__doc__ = Greedy.sample.__doc__.replace( "{{sample_keyword_docstring}}", sample_keyword_docstring ) diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_test.py similarity index 89% rename from keras_nlp/samplers/greedy_sampler_test.py rename to keras_nlp/samplers/greedy_test.py index 8eac81fd71..0472bdcbc9 100644 --- a/keras_nlp/samplers/greedy_sampler_test.py +++ b/keras_nlp/samplers/greedy_test.py @@ -11,16 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for GreedySampler.""" +"""Tests for Greedy sampler.""" import tensorflow as tf from absl.testing import parameterized from tensorflow import keras -from keras_nlp.samplers.greedy_sampler import GreedySampler +from keras_nlp.samplers.greedy import Greedy -class GreedySamplerTest(tf.test.TestCase, parameterized.TestCase): +class GreedyTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): super().setUp() self.vocab_size = 10 @@ -44,7 +44,7 @@ def token_probability_fn(inputs, mask): self.token_probability_fn = token_probability_fn - self.sampler = GreedySampler() + self.sampler = Greedy() def test_generate_with_1d_prompt(self): inputs = tf.constant([1]) @@ -102,25 +102,25 @@ def token_probability_fn(inputs, mask): tf.repeat(prob, batch_size, axis=0), max_length, axis=1 ) - sampler = GreedySampler(end_token_id=2) + sampler = Greedy() inputs = tf.constant([[0, 1], [1, 2]]) outputs = sampler( token_probability_fn, inputs, max_length=max_length, + end_token_id=2, ) - expected_outputs = tf.tile([[3], [0]], [1, max_length - 2]) - expected_outputs = tf.concat([inputs, expected_outputs], axis=1) + expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]]) self.assertAllEqual(outputs, expected_outputs) def test_compare_xla_noxla_results(self): inputs = [[1], [1]] - xla_sampler = GreedySampler(jit_compile=True) + xla_sampler = Greedy(jit_compile=True) outputs_xla = xla_sampler( self.token_probability_fn, inputs, max_length=5 ) - xla_sampler = GreedySampler(jit_compile=False) + xla_sampler = Greedy(jit_compile=False) outputs_no_xla = xla_sampler( self.token_probability_fn, inputs, max_length=5 ) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 4e6aca3bb8..64ff4bba0c 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -47,32 +47,45 @@ class Sampler: # Define a function that outputs the next token's probability for each token # in the input sequence. - def token_probability_fn(inputs): + def token_probability_fn(inputs, mask): return model(inputs) prompt = tf.fill((BATCH_SIZE, 1), START_ID) - sampler = keras_nlp.samplers.GreedySearch(end_token_id=END_ID) + sampler = keras_nlp.samplers.Greedy() # Print the generated sequence (token ids). - print(sampler(token_probability_fn, prompt, max_length=10)) + print(sampler(token_probability_fn, prompt, 10, end_token_id=END_ID)) ``` """ def __init__( self, - end_token_id=None, - pad_token_id=0, jit_compile=True, ): - self.end_token_id = end_token_id - self.pad_token_id = pad_token_id self.jit_compile = jit_compile - def _validate_prompt(self, prompt): + def _validate_prompt_and_mask(self, prompt, mask): """Helper method to validate input prompt.""" - if not isinstance(prompt, (tf.Tensor, tf.RaggedTensor)): + if not isinstance(prompt, (list, tf.RaggedTensor, tf.Tensor)): + raise ValueError( + "`prompt` must be one of `list`, `tf.RaggedTensor` or " + f"`tf.Tensor`, but received: prompt={type(prompt)}." + ) + + if isinstance(prompt, tf.RaggedTensor): + if mask: + raise ValueError( + "`mask` is only valid when `prompt` is a list or dense " + f"tensor, but received type(prompt)={type(prompt)}." 
+ ) + return prompt, mask + + if isinstance(prompt, list): prompt = tf.convert_to_tensor(prompt) - return prompt + if not mask: + mask = tf.cast(tf.ones_like(prompt), dtype=tf.bool) + prompt = tf.ragged.boolean_mask(prompt, mask) + return prompt, mask def _validate_token_probability_fn( self, token_probability_fn, prompt, mask @@ -86,19 +99,20 @@ def _validate_token_probability_fn( "[batch_size, sequence_length, vocab_size]." ) - def _pad_prompt(self, prompt, max_length, pad_token_id): + def _pad_prompt(self, prompt, max_length): """Pad prompt to `max_length`.""" mask = tf.ones_like(prompt, dtype=tf.bool) mask = mask.to_tensor(shape=(None, max_length)) - prompt = prompt.to_tensor( - shape=(None, max_length), default_value=pad_token_id - ) + prompt = prompt.to_tensor(shape=(None, max_length)) return prompt, mask def _mask_tokens_after_end_token( - self, prompt, max_length, end_token_id, pad_token_id + self, + prompt, + max_length, + end_token_id, ): - """Helper function to mask the tokens after the end token.""" + """Helper function to truncate the tokens after the end token.""" # Mask out tokens after `end_token_id` is encountered. # Find index of first end_token_id. end_indices = tf.math.argmax(prompt == end_token_id, -1) @@ -108,45 +122,59 @@ def _mask_tokens_after_end_token( tf.cast(max_length, dtype=end_indices.dtype), end_indices, ) - # Build a mask including end_token and replace tokens after end_token - # with `pad_token_id`. - mask_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length) - return tf.where(mask_indices, prompt, pad_token_id) + # Truncate out tokens after (including) the end token. + mask_indices = tf.sequence_mask(end_indices, maxlen=max_length) + return tf.ragged.boolean_mask(prompt, mask_indices) - def __call__(self, token_probability_fn, prompt, max_length): - prompt = self._validate_prompt(prompt) + def __call__( + self, + token_probability_fn, + prompt, + max_length, + padding_mask=None, + end_token_id=None, + from_logits=True, + ): + prompt, padding_mask = self._validate_prompt_and_mask( + prompt, padding_mask + ) input_is_1d = prompt.shape.rank == 1 if input_is_1d: - prompt = prompt[tf.newaxis, :] - if isinstance(prompt, tf.Tensor): - prompt = tf.RaggedTensor.from_tensor( - prompt, padding=self.pad_token_id - ) + prompt = tf.RaggedTensor.from_tensor(prompt[tf.newaxis, :]) + shortest_prompt_len = tf.reduce_min(prompt.row_lengths()) # Pad prompt to be a dense Tensor of shape [batch_size, max_length]. # This step is required for XLA compatibility because XLA requires a # static shape, which means we cannot concatenate generated token to # current prompt. - prompt, mask = self._pad_prompt(prompt, max_length, self.pad_token_id) + prompt, mask = self._pad_prompt(prompt, max_length) self._validate_token_probability_fn(token_probability_fn, prompt, mask) # Convert `sample` method to a `tf.function`, and turn on # `jit_compile` accordingly. sample = tf.function(self.sample, jit_compile=self.jit_compile) prompt = sample( - token_probability_fn, prompt, mask, max_length - shortest_prompt_len + token_probability_fn, + prompt, + mask, + max_length - shortest_prompt_len, + from_logits, ) # Mask out tokens after `end_token_id`. 
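+        # `_mask_tokens_after_end_token` now truncates at the end token, so
+        # the result may be a `tf.RaggedTensor` rather than a padded dense one.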
- if self.end_token_id is not None: + if end_token_id is not None: prompt = self._mask_tokens_after_end_token( - prompt, max_length, self.end_token_id, self.pad_token_id + prompt, + max_length, + end_token_id, ) return tf.squeeze(prompt) if input_is_1d else prompt - def sample(self, token_probability_fn, prompt, mask, num_steps): + def sample( + self, token_probability_fn, prompt, mask, num_steps, from_logits=True + ): """Sampling logic implementation. Args: @@ -158,14 +186,13 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): """ raise NotImplementedError + def get_config(self): + return { + "jit_compile": self.jit_compile, + } + base_sampler_keyword_args = """ - end_token_id: int, defaults to None. The token marking the end of the - sequence, once encountered the generation is finished for the exact - sequence. If None, every sequence is generated up to `max_length`. - If set, all tokens after encountering `end_token_id` will be - replaced with `pad_token_id`. - pad_token_id: int, defaults to 0. The padding token. jit_compile: bool, defaults to True. If True, XLA compilation will be used. """ @@ -174,14 +201,30 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): the next token over the whole vocabulary for each input token. prompt: a list of integers or an integer Tensor, can be 1D or 2D. The initial tokens to append generated tokens. - max_length: int. The max length of generated sequence.""" + max_length: int. The max length of generated sequence. + padding_mask: a tensor, defaults to None. The padding mask of the prompt. + end_token_id: int, defaults to None. The token marking the end of the + sequence, once encountered the generation is finished for the exact + sequence. If None, every sequence is generated up to `max_length`. + If set, all tokens after encountering `end_token_id` will be + replaced with `pad_token_id`. + from_logits: bool, defaults to True. Indicate if the `token_probability_fn` + returns logits. If False, `token_probability_fn` returns probability + distributions. + """ sample_keyword_docstring = """ token_probability_fn: a function that generates the probability of the next token over the whole vocabulary for each input token. prompt: a dense int Tensor of shape [batch_size, max_length]. The placeholder for generated sequence. - num_steps: int. The remaining number of tokens to generate.""" + mask: a dense bool Tensor of shape [batch_size, max_length]. The mask of + prompt. + num_steps: int. The remaining number of tokens to generate. + from_logits: bool, defaults to True. Indicate if the `token_probability_fn` + returns logits. If False, `token_probability_fn` returns probability + distributions. + """ Sampler.__doc__ = Sampler.__doc__.replace( "{{base_sampler_keyword_args}}", base_sampler_keyword_args diff --git a/keras_nlp/samplers/sampler_test.py b/keras_nlp/samplers/sampler_test.py new file mode 100644 index 0000000000..2e4ed4a1a7 --- /dev/null +++ b/keras_nlp/samplers/sampler_test.py @@ -0,0 +1,57 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Sampler classes.""" + +import tensorflow as tf + +import keras_nlp +from keras_nlp.samplers.greedy import Greedy + + +class SamplerTest(tf.test.TestCase): + def test_serialization(self): + sampler = keras_nlp.samplers.Greedy() + config = keras_nlp.samplers.serialize(sampler) + expected_config = { + "class_name": "Greedy", + "config": { + "jit_compile": True, + }, + } + self.assertDictEqual(expected_config, config) + + def test_deserialization(self): + # Test get from string. + identifier = "greedy" + sampler = keras_nlp.samplers.get(identifier) + self.assertIsInstance(sampler, Greedy) + + # Test string is not case-sensitive. + identifier = "Greedy" + sampler = keras_nlp.samplers.get(identifier) + self.assertIsInstance(sampler, Greedy) + + # Test dict identifier. + original_sampler = keras_nlp.samplers.Greedy(jit_compile=False) + config = keras_nlp.samplers.serialize(original_sampler) + restored_sampler = keras_nlp.samplers.get(config) + self.assertDictEqual( + keras_nlp.samplers.serialize(restored_sampler), + keras_nlp.samplers.serialize(original_sampler), + ) + + # Test identifier is already a sampler instance. + original_sampler = keras_nlp.samplers.Greedy(jit_compile=False) + restored_sampler = keras_nlp.samplers.get(original_sampler) + self.assertEqual(original_sampler, restored_sampler) From f7508cb054c2f83db8d1b864075396b980dd0eb4 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Thu, 5 Jan 2023 20:28:40 -0800 Subject: [PATCH 5/8] Add string example --- keras_nlp/samplers/sampler.py | 42 ++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 64ff4bba0c..e7216ab23d 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -25,7 +25,11 @@ class Sampler: Call Args: {{call_keyword_docstring}} + The inputs and outputs of Sampler class are both token ids. + Examples: + + Basic usage: ```python BATCH_SIZE = 8 VOCAB_SIZE = 10 @@ -56,6 +60,42 @@ def token_probability_fn(inputs, mask): # Print the generated sequence (token ids). print(sampler(token_probability_fn, prompt, 10, end_token_id=END_ID)) ``` + + Use with string inputs: + ```python + vocab = ["[UNK]", "[PAD]", "[END]", "the", "quick", "brown", "fox"] + tokenizer = keras_nlp.tokenizers.WordPieceTokenizer( + vocabulary=vocab, + lowercase=True, + ) + FEATURE_SIZE = 16 + VOCAB_SIZE = len(vocab) + # Create a dummy model to predict the next token. + model = keras.Sequential( + [ + keras.Input(shape=[None]), + keras.layers.Embedding( + input_dim=VOCAB_SIZE, + output_dim=FEATURE_SIZE, + ), + keras.layers.Dense(VOCAB_SIZE, activation="softmax"), + ] + ) + # Define a function that outputs the next token's probability for each token + # in the input sequence. 
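+    # The dummy model ignores the mask; a trained model would typically use
+    # it to exclude padding positions.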
+ def token_probability_fn(inputs, mask): + return model(inputs) + + prompt = tokenizer("the quick brown fox") + sampler = keras_nlp.samplers.Greedy() + generated = sampler( + token_probability_fn, + prompt, + 10, + end_token_id=tokenizer.token_to_id("[END]") + ) + print(tokenizer.detokenize(generated)) + ``` """ def __init__( @@ -170,7 +210,7 @@ def __call__( end_token_id, ) - return tf.squeeze(prompt) if input_is_1d else prompt + return tf.squeeze(prompt, axis=0) if input_is_1d else prompt def sample( self, token_probability_fn, prompt, mask, num_steps, from_logits=True From b658b61c0ee757d8ae707a08a36f0fcba0d094d6 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Fri, 6 Jan 2023 13:45:51 -0800 Subject: [PATCH 6/8] small changes --- keras_nlp/samplers/greedy.py | 5 +++-- keras_nlp/samplers/sampler.py | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/keras_nlp/samplers/greedy.py b/keras_nlp/samplers/greedy.py index f486a7cca0..a5e3626f14 100644 --- a/keras_nlp/samplers/greedy.py +++ b/keras_nlp/samplers/greedy.py @@ -14,6 +14,7 @@ """Greedy Sampler.""" import tensorflow as tf +from tensorflow import keras from keras_nlp.samplers.sampler import Sampler from keras_nlp.samplers.sampler import base_sampler_keyword_args @@ -21,6 +22,7 @@ from keras_nlp.samplers.sampler import sample_keyword_docstring +@keras.utils.register_keras_serializable(package="keras_nlp") class Greedy(Sampler): """Greedy sampler class. @@ -39,7 +41,6 @@ class Greedy(Sampler): VOCAB_SIZE = 10 FEATURE_SIZE = 16 START_ID = 1 - END_ID = 2 # Create a dummy model to predict the next token. model = keras.Sequential( @@ -62,7 +63,7 @@ def token_probability_fn(inputs, mask): sampler = keras_nlp.samplers.Greedy() # Print the generated sequence (token ids). - print(sampler(token_probability_fn, prompt, 10, end_token_id=END_ID)) + print(sampler(token_probability_fn, prompt, 10)) ``` """ diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index e7216ab23d..aa7c64bdc7 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -14,8 +14,10 @@ """Base sampler class.""" import tensorflow as tf +from tensorflow import keras +@keras.utils.register_keras_serializable(package="keras_nlp") class Sampler: """Base sampler class. From 76c430c755d8f68bf9ccbedab4fdc5fb171758d0 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Mon, 9 Jan 2023 12:08:18 -0800 Subject: [PATCH 7/8] Address comments: fix docstring, remove multicase support --- keras_nlp/samplers/__init__.py | 18 +++--- keras_nlp/samplers/greedy.py | 29 ++++----- keras_nlp/samplers/sampler.py | 97 +++++++++++++++--------------- keras_nlp/samplers/sampler_test.py | 7 +-- 4 files changed, 69 insertions(+), 82 deletions(-) diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py index f4841dd85a..5125a571b0 100644 --- a/keras_nlp/samplers/__init__.py +++ b/keras_nlp/samplers/__init__.py @@ -26,8 +26,6 @@ def deserialize(config, custom_objects=None): all_classes = { "greedy": Greedy, } - if config["class_name"].lower() in all_classes: - config["class_name"] = config["class_name"].lower() return keras.utils.deserialize_keras_object( config, module_objects=all_classes, @@ -55,15 +53,15 @@ def get(identifier): instance of the class by its constructor. Args: - identifier: String or dict that contains the sampler name or - configurations. + identifier: String or dict that contains the sampler name or + configurations. Returns: - Sampler instance base on the input identifier. 
+ Sampler instance base on the input identifier. Raises: - ValueError: If the input identifier is not a supported type or in a bad - format. + ValueError: If the input identifier is not a supported type or in a bad + format. """ if identifier is None: @@ -71,7 +69,11 @@ def get(identifier): if isinstance(identifier, dict): return deserialize(identifier) elif isinstance(identifier, str): - identifier = {"class_name": str(identifier), "config": {}} + if not identifier.islower(): + raise KeyError( + "`keras_nlp.samplers.get()` must take a lowercase string " + f"identifier, but received: {identifier}." + ) return deserialize(identifier) elif callable(identifier): return identifier diff --git a/keras_nlp/samplers/greedy.py b/keras_nlp/samplers/greedy.py index a5e3626f14..60fe732126 100644 --- a/keras_nlp/samplers/greedy.py +++ b/keras_nlp/samplers/greedy.py @@ -17,11 +17,15 @@ from tensorflow import keras from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import base_sampler_keyword_args -from keras_nlp.samplers.sampler import call_keyword_docstring -from keras_nlp.samplers.sampler import sample_keyword_docstring +from keras_nlp.samplers.sampler import base_sampler_args_docstring +from keras_nlp.samplers.sampler import call_args_docstring +from keras_nlp.samplers.sampler import sample_args_docstring +from keras_nlp.utils.python_utils import format_docstring +@format_docstring( + base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring +) @keras.utils.register_keras_serializable(package="keras_nlp") class Greedy(Sampler): """Greedy sampler class. @@ -30,10 +34,10 @@ class Greedy(Sampler): token of the largest probability as the next token. Args: - {{base_sampler_keyword_args}} + {{base_sampler_args}} Call Args: - {{call_keyword_args}} + {{call_args}} Examples: ```python @@ -73,13 +77,14 @@ def __init__( ): super().__init__(jit_compile) + @format_docstring(sample_args=sample_args_docstring) def sample( self, token_probability_fn, prompt, mask, num_steps, from_logits=True ): """Sampling logic implementation. Args: - {{sample_keyword_docstring}} + {{sample_args}} """ batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] max_length = tf.cast(max_length, num_steps.dtype) @@ -88,7 +93,6 @@ def sample( current_index = max_length - num_steps def one_step(current_index, prompt, mask): - probs = token_probability_fn(prompt, mask) next_token_prob = tf.gather( probs, @@ -143,14 +147,3 @@ def one_step(current_index, prompt, mask): loop_vars=(current_index, prompt, mask), ) return prompt - - -Greedy.__doc__ = Greedy.__doc__.replace( - "{{base_sampler_keyword_args}}", base_sampler_keyword_args -) -Greedy.__doc__ = Greedy.__doc__.replace( - "{{call_keyword_docstring}}", call_keyword_docstring -) -Greedy.sample.__doc__ = Greedy.sample.__doc__.replace( - "{{sample_keyword_docstring}}", sample_keyword_docstring -) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index aa7c64bdc7..8f208ff10a 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -16,16 +16,55 @@ import tensorflow as tf from tensorflow import keras +from keras_nlp.utils.python_utils import format_docstring +base_sampler_args_docstring = """ + jit_compile: bool, defaults to True. If True, XLA compilation will be used. + """ + +call_args_docstring = """ + token_probability_fn: a function that generates the probability of + the next token over the whole vocabulary for each input token. 
+    prompt: a list of integers or an integer Tensor, can be 1D or 2D. The
+        initial tokens onto which generated tokens are appended.
+    max_length: int. The max length of the generated sequence.
+    padding_mask: a tensor, defaults to None. The padding mask of the prompt.
+    end_token_id: int, defaults to None. The token marking the end of a
+        sequence; once it is encountered, generation finishes for that
+        sequence. If None, every sequence is generated up to `max_length`.
+        If set, all tokens from the first `end_token_id` onwards are
+        truncated from the output.
- from_logits: bool, defaults to True. Indicate if the `token_probability_fn` - returns logits. If False, `token_probability_fn` returns probability - distributions. - """ - -sample_keyword_docstring = """ - token_probability_fn: a function that generates the probability of - the next token over the whole vocabulary for each input token. - prompt: a dense int Tensor of shape [batch_size, max_length]. The - placeholder for generated sequence. - mask: a dense bool Tensor of shape [batch_size, max_length]. The mask of - prompt. - num_steps: int. The remaining number of tokens to generate. - from_logits: bool, defaults to True. Indicate if the `token_probability_fn` - returns logits. If False, `token_probability_fn` returns probability - distributions. - """ - -Sampler.__doc__ = Sampler.__doc__.replace( - "{{base_sampler_keyword_args}}", base_sampler_keyword_args -) -Sampler.__doc__ = Sampler.__doc__.replace( - "{{call_keyword_docstring}}", call_keyword_docstring -) -Sampler.sample.__doc__ = Sampler.sample.__doc__.replace( - "{{sample_keyword_docstring}}", sample_keyword_docstring -) diff --git a/keras_nlp/samplers/sampler_test.py b/keras_nlp/samplers/sampler_test.py index 2e4ed4a1a7..f88c6ad0ab 100644 --- a/keras_nlp/samplers/sampler_test.py +++ b/keras_nlp/samplers/sampler_test.py @@ -24,7 +24,7 @@ def test_serialization(self): sampler = keras_nlp.samplers.Greedy() config = keras_nlp.samplers.serialize(sampler) expected_config = { - "class_name": "Greedy", + "class_name": "keras_nlp>Greedy", "config": { "jit_compile": True, }, @@ -37,11 +37,6 @@ def test_deserialization(self): sampler = keras_nlp.samplers.get(identifier) self.assertIsInstance(sampler, Greedy) - # Test string is not case-sensitive. - identifier = "Greedy" - sampler = keras_nlp.samplers.get(identifier) - self.assertIsInstance(sampler, Greedy) - # Test dict identifier. original_sampler = keras_nlp.samplers.Greedy(jit_compile=False) config = keras_nlp.samplers.serialize(original_sampler) From bb430dd905b6fcb67b80d2d0b50f55a4ad6b600d Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Mon, 9 Jan 2023 12:36:29 -0800 Subject: [PATCH 8/8] Address comments: move token_probability_fn to the second place --- keras_nlp/samplers/__init__.py | 2 +- keras_nlp/samplers/greedy.py | 4 ++-- keras_nlp/samplers/greedy_test.py | 16 ++++++++-------- keras_nlp/samplers/sampler.py | 27 ++++++++++++--------------- 4 files changed, 23 insertions(+), 26 deletions(-) diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py index 5125a571b0..3f9e6cd8c2 100644 --- a/keras_nlp/samplers/__init__.py +++ b/keras_nlp/samplers/__init__.py @@ -46,7 +46,7 @@ def get(identifier): dict containing `class_name` and `config` as an identifier. Also note that the `class_name` must map to a `Sampler` class. - >>> cfg = {'class_name': 'Greedy', 'config': {}} + >>> cfg = {'class_name': 'keras_nlp>Greedy', 'config': {}} >>> sampler = keras_nlp.samplers.get(cfg) In the case that the `identifier` is a class, this method will return a new diff --git a/keras_nlp/samplers/greedy.py b/keras_nlp/samplers/greedy.py index 60fe732126..398f719b6b 100644 --- a/keras_nlp/samplers/greedy.py +++ b/keras_nlp/samplers/greedy.py @@ -67,7 +67,7 @@ def token_probability_fn(inputs, mask): sampler = keras_nlp.samplers.Greedy() # Print the generated sequence (token ids). 
- print(sampler(token_probability_fn, prompt, 10)) + print(sampler(prompt, token_probability_fn, 10)) ``` """ @@ -79,7 +79,7 @@ def __init__( @format_docstring(sample_args=sample_args_docstring) def sample( - self, token_probability_fn, prompt, mask, num_steps, from_logits=True + self, prompt, token_probability_fn, mask, num_steps, from_logits=True ): """Sampling logic implementation. diff --git a/keras_nlp/samplers/greedy_test.py b/keras_nlp/samplers/greedy_test.py index 0472bdcbc9..bd77b2490f 100644 --- a/keras_nlp/samplers/greedy_test.py +++ b/keras_nlp/samplers/greedy_test.py @@ -48,17 +48,17 @@ def token_probability_fn(inputs, mask): def test_generate_with_1d_prompt(self): inputs = tf.constant([1]) - outputs = self.sampler(self.token_probability_fn, inputs, max_length=5) + outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) self.assertEqual(outputs.shape, [5]) def test_generate_with_2d_prompt(self): inputs = tf.constant([[1], [1]]) - outputs = self.sampler(self.token_probability_fn, inputs, max_length=5) + outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) self.assertEqual(outputs.shape, [2, 5]) def test_generate_with_list_prompt(self): inputs = [[1], [1]] - outputs = self.sampler(self.token_probability_fn, inputs, max_length=5) + outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) self.assertEqual(outputs.shape, [2, 5]) def test_generate_with_ragged_prompt(self): @@ -71,7 +71,7 @@ def token_probability_fn(inputs, mask): return tf.repeat(tf.repeat(prob, 2, axis=0), max_length, axis=1) inputs = tf.ragged.constant([[1], [2, 1, 2]]) - outputs = self.sampler(token_probability_fn, inputs, max_length) + outputs = self.sampler(inputs, token_probability_fn, max_length) self.assertEqual(outputs.shape, [2, 5]) def test_assert_generation_is_correct(self): @@ -86,7 +86,7 @@ def token_probability_fn(inputs, mask): inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32) outputs = self.sampler( - token_probability_fn, inputs, max_length=max_length + inputs, token_probability_fn, max_length=max_length ) self.assertAllEqual( outputs, 3 * tf.ones(shape=[batch_size, max_length]) @@ -105,8 +105,8 @@ def token_probability_fn(inputs, mask): sampler = Greedy() inputs = tf.constant([[0, 1], [1, 2]]) outputs = sampler( - token_probability_fn, inputs, + token_probability_fn, max_length=max_length, end_token_id=2, ) @@ -117,12 +117,12 @@ def test_compare_xla_noxla_results(self): inputs = [[1], [1]] xla_sampler = Greedy(jit_compile=True) outputs_xla = xla_sampler( - self.token_probability_fn, inputs, max_length=5 + inputs, self.token_probability_fn, max_length=5 ) xla_sampler = Greedy(jit_compile=False) outputs_no_xla = xla_sampler( - self.token_probability_fn, inputs, max_length=5 + inputs, self.token_probability_fn, max_length=5 ) self.assertAllEqual(outputs_xla, outputs_no_xla) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 8f208ff10a..c265bf9172 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -23,12 +23,12 @@ """ call_args_docstring = """ - token_probability_fn: a function that generates the probability of - the next token over the whole vocabulary for each input token. prompt: a list of integers or an integer Tensor, can be 1D or 2D. The initial tokens to append generated tokens. + token_probability_fn: a function that generates the probability of + the next token over the whole vocabulary for each input token. max_length: int. The max length of generated sequence. 
- padding_mask: a tensor, defaults to None. The padding mask of the prompt. + mask: a tensor, defaults to None. The padding mask of the prompt. end_token_id: int, defaults to None. The token marking the end of the sequence, once encountered the generation is finished for the exact sequence. If None, every sequence is generated up to `max_length`. @@ -40,10 +40,10 @@ """ sample_args_docstring = """ - token_probability_fn: a function that generates the probability of - the next token over the whole vocabulary for each input token. prompt: a dense int Tensor of shape [batch_size, max_length]. The placeholder for generated sequence. + token_probability_fn: a function that generates the probability of + the next token over the whole vocabulary for each input token. mask: a dense bool Tensor of shape [batch_size, max_length]. The mask of prompt. num_steps: int. The remaining number of tokens to generate. @@ -100,7 +100,7 @@ def token_probability_fn(inputs, mask): sampler = keras_nlp.samplers.Greedy() # Print the generated sequence (token ids). - print(sampler(token_probability_fn, prompt, 10, end_token_id=END_ID)) + print(sampler(prompt, token_probability_fn, 10, end_token_id=END_ID)) ``` Use with string inputs: @@ -131,8 +131,8 @@ def token_probability_fn(inputs, mask): prompt = tokenizer("the quick brown fox") sampler = keras_nlp.samplers.Greedy() generated = sampler( - token_probability_fn, prompt, + token_probability_fn, 10, end_token_id=tokenizer.token_to_id("[END]") ) @@ -210,17 +210,14 @@ def _mask_tokens_after_end_token( def __call__( self, - token_probability_fn, prompt, + token_probability_fn, max_length, - padding_mask=None, + mask=None, end_token_id=None, from_logits=True, ): - prompt, padding_mask = self._validate_prompt_and_mask( - prompt, - padding_mask, - ) + prompt, mask = self._validate_prompt_and_mask(prompt, mask) input_is_1d = prompt.shape.rank == 1 if input_is_1d: @@ -238,8 +235,8 @@ def __call__( # `jit_compile` accordingly. sample = tf.function(self.sample, jit_compile=self.jit_compile) prompt = sample( - token_probability_fn, prompt, + token_probability_fn, mask, max_length - shortest_prompt_len, from_logits, @@ -257,7 +254,7 @@ def __call__( @format_docstring(sample_args=sample_args_docstring) def sample( - self, token_probability_fn, prompt, mask, num_steps, from_logits=True + self, prompt, token_probability_fn, mask, num_steps, from_logits=True ): """Sampling logic implementation.
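
Note: after PATCH 8/8, `prompt` is the first positional argument to both `__call__` and `sample`, with `token_probability_fn` second, as the updated tests show. For orientation, a minimal end-to-end sketch of the reordered API follows; the uniform `token_probability_fn` and the vocabulary size of 10 are illustrative stand-ins for a real model, not part of the patches:

```python
import tensorflow as tf

import keras_nlp

VOCAB_SIZE = 10  # Assumed toy vocabulary size, for illustration only.


def token_probability_fn(inputs, mask):
    # Uniform next-token distribution over the vocabulary -- a stand-in
    # for a trained model. Greedy decoding over it always picks token 0.
    batch_size = tf.shape(inputs)[0]
    length = tf.shape(inputs)[1]
    return tf.ones([batch_size, length, VOCAB_SIZE]) / VOCAB_SIZE


# Two sequences, each starting with token id 1.
prompt = tf.constant([[1], [1]])

sampler = keras_nlp.samplers.Greedy()
# `prompt` first, then `token_probability_fn`, per the argument reorder above.
outputs = sampler(prompt, token_probability_fn, max_length=5)
print(outputs)  # A dense int Tensor of shape [2, 5].
```

Since `jit_compile=True` by default, `sample` runs under `tf.function(..., jit_compile=True)`; `test_compare_xla_noxla_results` above asserts that the XLA and non-XLA paths produce identical outputs.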