From 7f7ae434ee4b0052f40459c0d503061addf15f5f Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Fri, 9 Dec 2022 14:41:21 -0800 Subject: [PATCH 01/28] initial commit --- keras_nlp/samplers/__init__.py | 13 ++ keras_nlp/samplers/greedy_sampler.py | 112 ++++++++++++++ keras_nlp/samplers/greedy_sampler_test.py | 114 ++++++++++++++ keras_nlp/samplers/sampler.py | 178 ++++++++++++++++++++++ 4 files changed, 417 insertions(+) create mode 100644 keras_nlp/samplers/__init__.py create mode 100644 keras_nlp/samplers/greedy_sampler.py create mode 100644 keras_nlp/samplers/greedy_sampler_test.py create mode 100644 keras_nlp/samplers/sampler.py diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py new file mode 100644 index 0000000000..6e4df4e727 --- /dev/null +++ b/keras_nlp/samplers/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py new file mode 100644 index 0000000000..816193ebf3 --- /dev/null +++ b/keras_nlp/samplers/greedy_sampler.py @@ -0,0 +1,112 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Greedy Sampler.""" + +import tensorflow as tf + +from keras_nlp.samplers.sampler import Sampler +from keras_nlp.samplers.sampler import base_sampler_keyword_args +from keras_nlp.samplers.sampler import call_keyword_docstring +from keras_nlp.samplers.sampler import sample_keyword_docstring + + +class GreedySampler(Sampler): + """Greedy Sampler class. + + This sampler is implemented on greedy search, i.e., always picking up the + token of the largest probability as the next token. + + Args: + {{base_sampler_keyword_args}} + + Call Args: + {{call_keyword_args}} + """ + + def __init__( + self, + end_token_id=None, + pad_token_id=0, + jit_compile=True, + ): + super().__init__(end_token_id, pad_token_id, jit_compile) + + def sample(self, token_probability_fn, prompt, mask, num_steps): + """Sampler's logic implementation. 
+ + Args: + {{call_keyword_docstring}} + """ + batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] + max_length = tf.cast(max_length, num_steps.dtype) + length = max_length - num_steps + + def one_step(length, prompt, mask): + + probs = token_probability_fn(prompt, mask) + next_token_prob = tf.gather( + probs, tf.repeat(length - 1, batch_size), axis=1, batch_dims=1 + ) + next_token = tf.cast( + tf.argmax(next_token_prob, axis=-1), dtype=prompt.dtype + ) + next_token = tf.where( + mask[:, length], prompt[:, length], next_token + ) + mask = tf.tensor_scatter_nd_update( + tensor=mask, + indices=tf.stack( + ( + tf.cast(tf.range(batch_size), dtype=length.dtype), + tf.repeat(length, batch_size), + ), + axis=1, + ), + updates=tf.repeat(True, batch_size), + ) + + # Append the next token to current sequence. + prompt = tf.tensor_scatter_nd_update( + tensor=prompt, + indices=tf.stack( + ( + tf.cast(tf.range(batch_size), dtype=length.dtype), + tf.repeat(length, batch_size), + ), + axis=1, + ), + updates=next_token, + ) + + length = tf.add(length, 1) + return (length, prompt, mask) + + # Run a while loop till text of length `max_length` has been generated. + length, prompt, mask = tf.while_loop( + cond=lambda length, prompt, mask: tf.less(length, max_length), + body=one_step, + loop_vars=(length, prompt, mask), + ) + return prompt + + +GreedySampler.__doc__ = GreedySampler.__doc__.replace( + "{{base_sampler_keyword_args}}", base_sampler_keyword_args +) +GreedySampler.__doc__ = GreedySampler.__doc__.replace( + "{{call_keyword_docstring}}", call_keyword_docstring +) +GreedySampler.sample.__doc__ = GreedySampler.sample.__doc__.replace( + "{{sample_keyword_docstring}}", sample_keyword_docstring +) diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_sampler_test.py new file mode 100644 index 0000000000..6e787d0afa --- /dev/null +++ b/keras_nlp/samplers/greedy_sampler_test.py @@ -0,0 +1,114 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Text Generation Utils.""" + +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.samplers.greedy_sampler import GreedySampler + + +class GreedySamplerTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + self.vocab_size = 10 + self.feature_size = 16 + + # Create a dummy model to predict the next token. 
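+        # The stack below maps every position in the input sequence to a
+        # probability distribution over the vocabulary, i.e. an output of
+        # shape [batch, sequence, vocab], which is the contract
+        # `token_probability_fn` must satisfy.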
+ model = keras.Sequential( + [ + keras.Input(shape=[None]), + keras.layers.Embedding( + input_dim=self.vocab_size, + output_dim=self.feature_size, + ), + keras.layers.Dense(self.vocab_size), + keras.layers.Softmax(), + ] + ) + + def token_probability_fn(inputs, mask): + return model(inputs) + + self.token_probability_fn = token_probability_fn + + self.sampler = GreedySampler() + + def test_generate_with_1d_prompt(self): + inputs = tf.constant([1]) + outputs = self.sampler(self.token_probability_fn, inputs, max_length=5) + self.assertEqual(outputs.shape, [5]) + + def test_generate_with_2d_prompt(self): + inputs = tf.constant([[1], [1]]) + outputs = self.sampler(self.token_probability_fn, inputs, max_length=5) + self.assertEqual(outputs.shape, [2, 5]) + + def test_generate_with_list_prompt(self): + inputs = [[1], [1]] + outputs = self.sampler(self.token_probability_fn, inputs, max_length=5) + self.assertEqual(outputs.shape, [2, 5]) + + def test_generate_with_ragged_prompt(self): + max_length = 5 + + def token_probability_fn(inputs, mask): + # Assert that user function is passed only dense tensors. + self.assertIsInstance(inputs, tf.Tensor) + prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) + return tf.repeat(tf.repeat(prob, 2, axis=0), max_length, axis=1) + + inputs = tf.ragged.constant([[1], [2, 1, 2]]) + outputs = self.sampler(token_probability_fn, inputs, max_length) + self.assertEqual(outputs.shape, [2, 5]) + + def test_assert_generation_is_correct(self): + batch_size = 10 + max_length = 3 + + def token_probability_fn(inputs, mask): + prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) + return tf.repeat( + tf.repeat(prob, batch_size, axis=0), max_length, axis=1 + ) + + inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32) + outputs = self.sampler( + token_probability_fn, inputs, max_length=max_length + ) + self.assertAllEqual( + outputs, 3 * tf.ones(shape=[batch_size, max_length]) + ) + + def test_end_token_id(self): + max_length = 5 + + def token_probability_fn(inputs, mask): + batch_size = inputs.shape[0] + prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) + return tf.repeat( + tf.repeat(prob, batch_size, axis=0), max_length, axis=1 + ) + + sampler = GreedySampler(end_token_id=2) + inputs = tf.constant([[0, 1], [1, 2]]) + outputs = sampler( + token_probability_fn, + inputs, + max_length=max_length, + ) + expected_outputs = tf.tile([[3], [0]], [1, max_length - 2]) + expected_outputs = tf.concat([inputs, expected_outputs], axis=1) + self.assertAllEqual(outputs, expected_outputs) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py new file mode 100644 index 0000000000..90e2aae086 --- /dev/null +++ b/keras_nlp/samplers/sampler.py @@ -0,0 +1,178 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base sampler class.""" + +import tensorflow as tf + + +class Sampler: + """Base sampler class. + + This class must be implemented by child class for instantiation. 
+ + Args: + {{base_optimizer_keyword_args}} + + Call Args: + + """ + + def __init__( + self, + end_token_id=None, + pad_token_id=0, + jit_compile=True, + ): + self.end_token_id = end_token_id + self.pad_token_id = pad_token_id + self.jit_compile = jit_compile + + def _validate_prompt(self, prompt): + """Helper function to validate input to text_generation utils.""" + if not isinstance(prompt, (tf.Tensor, tf.RaggedTensor)): + prompt = tf.convert_to_tensor(prompt) + return prompt + + def _validate_token_probability_fn( + self, token_probability_fn, prompt, mask + ): + """Helper function to validate `token_probability_fn` output.""" + test_pred = token_probability_fn(prompt, mask=mask) + if len(test_pred.shape) != 3: + raise ValueError( + "Output of `token_probability_fn` is not a 3D tensor, " + "please provide a function with the output shape " + "[batch_size, sequence_length, vocab_size]." + ) + + def _align_and_pad_prompt(self, prompt, max_length, pad_token_id): + """Align prompt to the right side, and pad to `max_length`.""" + longest_prompt_len = tf.reduce_max(prompt.row_lengths()) + pad_length = longest_prompt_len - prompt.row_lengths() + + prompt = tf.keras.utils.pad_sequences( + prompt.to_list(), maxlen=longest_prompt_len, value=pad_token_id + ) + + mask = tf.RaggedTensor.from_row_lengths( + tf.zeros(shape=[tf.reduce_sum(pad_length)], dtype=tf.int32), + pad_length, + ) + mask = mask.to_tensor(shape=(None, longest_prompt_len), default_value=1) + + shape = prompt.shape + extra_space = tf.math.maximum(0, max_length - shape[1]) + pad_shape = [shape[0], extra_space] + + mask = tf.concat((mask, tf.zeros(pad_shape, tf.int32)), axis=1) + prompt = tf.concat( + (prompt, tf.zeros(pad_shape, prompt.dtype) + pad_token_id), axis=1 + ) + mask = tf.cast(mask, dtype=tf.bool) + return prompt, mask + + def _mask_tokens_after_end_token( + self, prompt, max_length, end_token_id, pad_token_id + ): + """Helper function to mask the tokens after the end token.""" + # Mask out tokens after `end_token_id` is encountered. + # Find index of first end_token_id. + end_indices = tf.math.argmax(prompt == end_token_id, -1) + # Use max_length if no `end_token_id` is found. + end_indices = tf.where( + end_indices == 0, + tf.cast(max_length, dtype=end_indices.dtype), + end_indices, + ) + # Build a mask including end_token and replace tokens after end_token + # with `pad_token_id`. 
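+        # Illustrative walk-through (values assumed): with end_token_id=2,
+        # pad_token_id=0, and a generated row [5, 2, 7, 9], end_indices is
+        # 1, the sequence mask keeps positions 0-1, and the row is
+        # returned as [5, 2, 0, 0].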
+ valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length) + return tf.where(valid_indices, prompt, pad_token_id) + + def __call__(self, token_probability_fn, prompt, max_length): + """Sampling method to be called by users.""" + + prompt = self._validate_prompt(prompt) + + input_is_1d = prompt.shape.rank == 1 + if input_is_1d: + prompt = prompt[tf.newaxis, :] + if isinstance(prompt, tf.Tensor): + prompt = tf.RaggedTensor.from_tensor( + prompt, padding=self.pad_token_id + ) + longest_prompt_len = tf.reduce_max(prompt.row_lengths()) + prompt, mask = self._align_and_pad_prompt( + prompt, max_length, self.pad_token_id + ) + self._validate_token_probability_fn(token_probability_fn, prompt, mask) + + sample = tf.function(self.sample, jit_compile=self.jit_compile) + prompt = sample( + token_probability_fn, prompt, mask, max_length - longest_prompt_len + ) + + if self.end_token_id is not None: + prompt = self._mask_tokens_after_end_token( + prompt, max_length, self.end_token_id, self.pad_token_id + ) + + return tf.squeeze(prompt) if input_is_1d else prompt + + def sample(self, token_probability_fn, prompt, mask, num_steps): + """Sampler's logic implementation. + + Args: + {{sample_keyword_docstring}} + + Returns: + A dense int Tensor, representing the generated text in token id + space. + """ + raise NotImplementedError + + +base_sampler_keyword_args = """ + end_token_id: int, defaults to None. The token marking the end of the + sequence, once encountered the generation is finished for the exact + sequence. If None, every sequence is generated up to `max_length`. + If set, all tokens after encountering `end_token_id` will be + replaced with `pad_token_id`. + pad_token_id: int, defaults to 0. The pad token after `end_token_id` + is received. + jit_compile: bool, defaults to True. If using XLA compilation.""" + +call_keyword_docstring = """ + token_probability_fn: a function that generates the probability of + the next token over the whole vocabulary for each input token. + prompt: a list of integers or an integer Tensor, can be 1D or 2D. The + initial tokens to append generated tokens. + max_length: int. The max length of generated sequence.""" + +sample_keyword_docstring = """ + token_probability_fn: a function that generates the probability of + the next token over the whole vocabulary for each input token. + prompt: a list of integers or an integer Tensor, can be 1D or 2D. The + initial tokens to append generated tokens. + num_steps: int. 
The number of tokens to generate.""" + +Sampler.__doc__ = Sampler.__doc__.replace( + "{{base_sampler_keyword_args}}", base_sampler_keyword_args +) +Sampler.__doc__ = Sampler.__call__.__doc__.replace( + "{{call_keyword_docstring}}", call_keyword_docstring +) +Sampler.sample.__doc__ = Sampler.sample.__doc__.replace( + "{{sample_keyword_docstring}}", sample_keyword_docstring +) From c53b4a912435a8cf8b0f63a54a6451c938de79a7 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Fri, 9 Dec 2022 17:32:00 -0800 Subject: [PATCH 02/28] Add keras_nlp.samplers --- keras_nlp/samplers/__init__.py | 2 + keras_nlp/samplers/greedy_sampler.py | 75 ++++++++++++++++++----- keras_nlp/samplers/greedy_sampler_test.py | 16 ++++- keras_nlp/samplers/sampler.py | 70 ++++++++++++++++----- 4 files changed, 129 insertions(+), 34 deletions(-) diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py index 6e4df4e727..11908fb4dc 100644 --- a/keras_nlp/samplers/__init__.py +++ b/keras_nlp/samplers/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.samplers.greedy_sampler import GreedySampler diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py index 816193ebf3..b57362179c 100644 --- a/keras_nlp/samplers/greedy_sampler.py +++ b/keras_nlp/samplers/greedy_sampler.py @@ -32,6 +32,38 @@ class GreedySampler(Sampler): Call Args: {{call_keyword_args}} + + Examples: + ```python + BATCH_SIZE = 8 + VOCAB_SIZE = 10 + FEATURE_SIZE = 16 + START_ID = 1 + END_ID = 2 + + # Create a dummy model to predict the next token. + model = keras.Sequential( + [ + keras.Input(shape=[None]), + keras.layers.Embedding( + input_dim=VOCAB_SIZE, + output_dim=FEATURE_SIZE, + ), + keras.layers.Dense(VOCAB_SIZE, activation="softmax"), + ] + ) + + # Define a function that outputs the next token's probability for each token + # in the input sequence. + def token_probability_fn(inputs): + return model(inputs) + + prompt = tf.fill((BATCH_SIZE, 1), START_ID) + + sampler = keras_nlp.samplers.GreedySearch(end_token_id=END_ID) + # Print the generated sequence (token ids). + print(sampler(token_probability_fn, prompt, max_length=10)) + ``` """ def __init__( @@ -43,33 +75,40 @@ def __init__( super().__init__(end_token_id, pad_token_id, jit_compile) def sample(self, token_probability_fn, prompt, mask, num_steps): - """Sampler's logic implementation. + """Sampling logic implementation. Args: - {{call_keyword_docstring}} + {{sample_keyword_docstring}} """ batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] max_length = tf.cast(max_length, num_steps.dtype) - length = max_length - num_steps + # The index of the last non-padding token in prompt. Since all sequences + # are aligned to the right side, the index is the same for all. 
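+        # For example (values assumed): with max_length=10 and num_steps=7,
+        # generation starts at current_index=3, immediately after the 3
+        # positions occupied by the longest prompt.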
+ current_index = max_length - num_steps - def one_step(length, prompt, mask): + def one_step(current_index, prompt, mask): probs = token_probability_fn(prompt, mask) next_token_prob = tf.gather( - probs, tf.repeat(length - 1, batch_size), axis=1, batch_dims=1 + probs, + tf.repeat(current_index - 1, batch_size), + axis=1, + batch_dims=1, ) next_token = tf.cast( tf.argmax(next_token_prob, axis=-1), dtype=prompt.dtype ) next_token = tf.where( - mask[:, length], prompt[:, length], next_token + mask[:, current_index], prompt[:, current_index], next_token ) mask = tf.tensor_scatter_nd_update( tensor=mask, indices=tf.stack( ( - tf.cast(tf.range(batch_size), dtype=length.dtype), - tf.repeat(length, batch_size), + tf.cast( + tf.range(batch_size), dtype=current_index.dtype + ), + tf.repeat(current_index, batch_size), ), axis=1, ), @@ -81,22 +120,26 @@ def one_step(length, prompt, mask): tensor=prompt, indices=tf.stack( ( - tf.cast(tf.range(batch_size), dtype=length.dtype), - tf.repeat(length, batch_size), + tf.cast( + tf.range(batch_size), dtype=current_index.dtype + ), + tf.repeat(current_index, batch_size), ), axis=1, ), updates=next_token, ) - length = tf.add(length, 1) - return (length, prompt, mask) + current_index = tf.add(current_index, 1) + return (current_index, prompt, mask) - # Run a while loop till text of length `max_length` has been generated. - length, prompt, mask = tf.while_loop( - cond=lambda length, prompt, mask: tf.less(length, max_length), + # Run a while loop till `max_length` of tokens has been generated. + current_index, prompt, mask = tf.while_loop( + cond=lambda current_index, prompt, mask: tf.less( + current_index, max_length + ), body=one_step, - loop_vars=(length, prompt, mask), + loop_vars=(current_index, prompt, mask), ) return prompt diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_sampler_test.py index 6e787d0afa..8eac81fd71 100644 --- a/keras_nlp/samplers/greedy_sampler_test.py +++ b/keras_nlp/samplers/greedy_sampler_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Text Generation Utils.""" +"""Tests for GreedySampler.""" import tensorflow as tf from absl.testing import parameterized @@ -112,3 +112,17 @@ def token_probability_fn(inputs, mask): expected_outputs = tf.tile([[3], [0]], [1, max_length - 2]) expected_outputs = tf.concat([inputs, expected_outputs], axis=1) self.assertAllEqual(outputs, expected_outputs) + + def test_compare_xla_noxla_results(self): + inputs = [[1], [1]] + xla_sampler = GreedySampler(jit_compile=True) + outputs_xla = xla_sampler( + self.token_probability_fn, inputs, max_length=5 + ) + + xla_sampler = GreedySampler(jit_compile=False) + outputs_no_xla = xla_sampler( + self.token_probability_fn, inputs, max_length=5 + ) + + self.assertAllEqual(outputs_xla, outputs_no_xla) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 90e2aae086..3f997c3792 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -14,18 +14,49 @@ """Base sampler class.""" import tensorflow as tf +from tensorflow import keras class Sampler: """Base sampler class. - This class must be implemented by child class for instantiation. 
- Args: {{base_optimizer_keyword_args}} Call Args: - + {{call_keyword_docstring}} + + Examples: + ```python + BATCH_SIZE = 8 + VOCAB_SIZE = 10 + FEATURE_SIZE = 16 + START_ID = 1 + END_ID = 2 + + # Create a dummy model to predict the next token. + model = keras.Sequential( + [ + keras.Input(shape=[None]), + keras.layers.Embedding( + input_dim=VOCAB_SIZE, + output_dim=FEATURE_SIZE, + ), + keras.layers.Dense(VOCAB_SIZE, activation="softmax"), + ] + ) + + # Define a function that outputs the next token's probability for each token + # in the input sequence. + def token_probability_fn(inputs): + return model(inputs) + + prompt = tf.fill((BATCH_SIZE, 1), START_ID) + + sampler = keras_nlp.samplers.GreedySearch(end_token_id=END_ID) + # Print the generated sequence (token ids). + print(sampler(token_probability_fn, prompt, max_length=10)) + ``` """ def __init__( @@ -39,7 +70,7 @@ def __init__( self.jit_compile = jit_compile def _validate_prompt(self, prompt): - """Helper function to validate input to text_generation utils.""" + """Helper method to validate input prompt.""" if not isinstance(prompt, (tf.Tensor, tf.RaggedTensor)): prompt = tf.convert_to_tensor(prompt) return prompt @@ -47,7 +78,7 @@ def _validate_prompt(self, prompt): def _validate_token_probability_fn( self, token_probability_fn, prompt, mask ): - """Helper function to validate `token_probability_fn` output.""" + """Helper method to validate `token_probability_fn` output.""" test_pred = token_probability_fn(prompt, mask=mask) if len(test_pred.shape) != 3: raise ValueError( @@ -61,7 +92,7 @@ def _align_and_pad_prompt(self, prompt, max_length, pad_token_id): longest_prompt_len = tf.reduce_max(prompt.row_lengths()) pad_length = longest_prompt_len - prompt.row_lengths() - prompt = tf.keras.utils.pad_sequences( + prompt = keras.utils.pad_sequences( prompt.to_list(), maxlen=longest_prompt_len, value=pad_token_id ) @@ -97,12 +128,10 @@ def _mask_tokens_after_end_token( ) # Build a mask including end_token and replace tokens after end_token # with `pad_token_id`. - valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length) - return tf.where(valid_indices, prompt, pad_token_id) + mask_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length) + return tf.where(mask_indices, prompt, pad_token_id) def __call__(self, token_probability_fn, prompt, max_length): - """Sampling method to be called by users.""" - prompt = self._validate_prompt(prompt) input_is_1d = prompt.shape.rank == 1 @@ -113,16 +142,23 @@ def __call__(self, token_probability_fn, prompt, max_length): prompt, padding=self.pad_token_id ) longest_prompt_len = tf.reduce_max(prompt.row_lengths()) + # Pad prompt to be a dense Tensor of shape [batch_size, max_length]. + # This step is required for XLA compatibility because XLA requires a + # static shape, which means we cannot concatenate generated token to + # current prompt. prompt, mask = self._align_and_pad_prompt( prompt, max_length, self.pad_token_id ) self._validate_token_probability_fn(token_probability_fn, prompt, mask) + # Convert `sample` method to a `tf.function`, and turn on + # `jit_compile` accordingly. sample = tf.function(self.sample, jit_compile=self.jit_compile) prompt = sample( token_probability_fn, prompt, mask, max_length - longest_prompt_len ) + # Mask out tokens after `end_token_id`. 
if self.end_token_id is not None: prompt = self._mask_tokens_after_end_token( prompt, max_length, self.end_token_id, self.pad_token_id @@ -131,7 +167,7 @@ def __call__(self, token_probability_fn, prompt, max_length): return tf.squeeze(prompt) if input_is_1d else prompt def sample(self, token_probability_fn, prompt, mask, num_steps): - """Sampler's logic implementation. + """Sampling logic implementation. Args: {{sample_keyword_docstring}} @@ -149,9 +185,9 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): sequence. If None, every sequence is generated up to `max_length`. If set, all tokens after encountering `end_token_id` will be replaced with `pad_token_id`. - pad_token_id: int, defaults to 0. The pad token after `end_token_id` - is received. - jit_compile: bool, defaults to True. If using XLA compilation.""" + pad_token_id: int, defaults to 0. The padding token. + jit_compile: bool, defaults to True. If True, XLA compilation will be used. + """ call_keyword_docstring = """ token_probability_fn: a function that generates the probability of @@ -163,9 +199,9 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): sample_keyword_docstring = """ token_probability_fn: a function that generates the probability of the next token over the whole vocabulary for each input token. - prompt: a list of integers or an integer Tensor, can be 1D or 2D. The - initial tokens to append generated tokens. - num_steps: int. The number of tokens to generate.""" + prompt: a dense int Tensor of shape [batch_size, max_length]. The + placeholder for generated sequence. + num_steps: int. The remaining number of tokens to generate.""" Sampler.__doc__ = Sampler.__doc__.replace( "{{base_sampler_keyword_args}}", base_sampler_keyword_args From e6483a42c32a33ddb86ae4d88baebb1fdfbdf5f0 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Mon, 12 Dec 2022 15:50:16 -0800 Subject: [PATCH 03/28] Change padding to left to right --- keras_nlp/samplers/sampler.py | 40 +++++++++-------------------------- 1 file changed, 10 insertions(+), 30 deletions(-) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 3f997c3792..4e6aca3bb8 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -14,7 +14,6 @@ """Base sampler class.""" import tensorflow as tf -from tensorflow import keras class Sampler: @@ -87,30 +86,13 @@ def _validate_token_probability_fn( "[batch_size, sequence_length, vocab_size]." 
) - def _align_and_pad_prompt(self, prompt, max_length, pad_token_id): - """Align prompt to the right side, and pad to `max_length`.""" - longest_prompt_len = tf.reduce_max(prompt.row_lengths()) - pad_length = longest_prompt_len - prompt.row_lengths() - - prompt = keras.utils.pad_sequences( - prompt.to_list(), maxlen=longest_prompt_len, value=pad_token_id - ) - - mask = tf.RaggedTensor.from_row_lengths( - tf.zeros(shape=[tf.reduce_sum(pad_length)], dtype=tf.int32), - pad_length, - ) - mask = mask.to_tensor(shape=(None, longest_prompt_len), default_value=1) - - shape = prompt.shape - extra_space = tf.math.maximum(0, max_length - shape[1]) - pad_shape = [shape[0], extra_space] - - mask = tf.concat((mask, tf.zeros(pad_shape, tf.int32)), axis=1) - prompt = tf.concat( - (prompt, tf.zeros(pad_shape, prompt.dtype) + pad_token_id), axis=1 + def _pad_prompt(self, prompt, max_length, pad_token_id): + """Pad prompt to `max_length`.""" + mask = tf.ones_like(prompt, dtype=tf.bool) + mask = mask.to_tensor(shape=(None, max_length)) + prompt = prompt.to_tensor( + shape=(None, max_length), default_value=pad_token_id ) - mask = tf.cast(mask, dtype=tf.bool) return prompt, mask def _mask_tokens_after_end_token( @@ -141,21 +123,19 @@ def __call__(self, token_probability_fn, prompt, max_length): prompt = tf.RaggedTensor.from_tensor( prompt, padding=self.pad_token_id ) - longest_prompt_len = tf.reduce_max(prompt.row_lengths()) + shortest_prompt_len = tf.reduce_min(prompt.row_lengths()) # Pad prompt to be a dense Tensor of shape [batch_size, max_length]. # This step is required for XLA compatibility because XLA requires a # static shape, which means we cannot concatenate generated token to # current prompt. - prompt, mask = self._align_and_pad_prompt( - prompt, max_length, self.pad_token_id - ) + prompt, mask = self._pad_prompt(prompt, max_length, self.pad_token_id) self._validate_token_probability_fn(token_probability_fn, prompt, mask) # Convert `sample` method to a `tf.function`, and turn on # `jit_compile` accordingly. sample = tf.function(self.sample, jit_compile=self.jit_compile) prompt = sample( - token_probability_fn, prompt, mask, max_length - longest_prompt_len + token_probability_fn, prompt, mask, max_length - shortest_prompt_len ) # Mask out tokens after `end_token_id`. 
@@ -206,7 +186,7 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): Sampler.__doc__ = Sampler.__doc__.replace( "{{base_sampler_keyword_args}}", base_sampler_keyword_args ) -Sampler.__doc__ = Sampler.__call__.__doc__.replace( +Sampler.__doc__ = Sampler.__doc__.replace( "{{call_keyword_docstring}}", call_keyword_docstring ) Sampler.sample.__doc__ = Sampler.sample.__doc__.replace( From 513121e24a9ef16fc687a671334d1d2b2c1bf53e Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Tue, 20 Dec 2022 15:11:10 -0800 Subject: [PATCH 04/28] more samplers --- keras_nlp/samplers/beam_sampler.py | 156 ++++++++++++++++++++++++++++ keras_nlp/samplers/top_k_sampler | 130 +++++++++++++++++++++++ keras_nlp/samplers/top_p_sampler.py | 128 +++++++++++++++++++++++ 3 files changed, 414 insertions(+) create mode 100644 keras_nlp/samplers/beam_sampler.py create mode 100644 keras_nlp/samplers/top_k_sampler create mode 100644 keras_nlp/samplers/top_p_sampler.py diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py new file mode 100644 index 0000000000..91d94330af --- /dev/null +++ b/keras_nlp/samplers/beam_sampler.py @@ -0,0 +1,156 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Greedy Sampler.""" + +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.samplers.sampler import Sampler +from keras_nlp.samplers.sampler import base_sampler_keyword_args +from keras_nlp.samplers.sampler import call_keyword_docstring +from keras_nlp.samplers.sampler import sample_keyword_docstring + + +class BeamSampler(Sampler): + """Beam Sampler class. + + This sampler implements beam search algorithm. + + Args: + {{base_sampler_keyword_args}} + + Call Args: + {{call_keyword_args}} + """ + + def __init__( + self, + num_beams, + seed=None, + from_logits=False, + end_token_id=None, + pad_token_id=0, + jit_compile=True, + ): + self.num_beams = num_beams + self.seed = seed + self.from_logits = from_logits + super().__init__(end_token_id, pad_token_id, jit_compile) + + def sample(self, token_probability_fn, prompt, mask, num_steps): + """Sampler's logic implementation. + + Args: + {{call_keyword_docstring}} + """ + batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] + max_length = tf.cast(max_length, num_steps.dtype) + length = max_length - num_steps + dummy_preds = self._validate_token_probability_fn( + token_probability_fn, prompt, mask + ) + vocab_size = dummy_preds.shape[-1] + pred_dtype = dummy_preds.dtype + + num_beams = self.num_beams + + # Initialize beam with shape `(batch_size, num_beams, length)`. + beams = tf.repeat(tf.expand_dims(prompt, axis=1), num_beams, axis=1) + # Initialize `beams_prob` with shape `(batch_size, num_beams)`. 
+ beams_prob = tf.zeros([batch_size, 1], dtype=pred_dtype) + beams_prob = tf.concat( + [beams_prob, tf.fill((batch_size, num_beams - 1), pred_dtype.min)], + axis=-1, + ) + + def one_step(beams, beams_prob, length): + truncated_beams = beams[..., :length] + + flattened_beams = tf.reshape( + truncated_beams, shape=[batch_size * num_beams, -1] + ) + preds = token_probability_fn(flattened_beams) + if self.from_logits: + preds = keras.activations.softmax(preds, axis=-1) + # Reshape `preds` to shape `(batch_size, num_beams * vocab_size)`. + preds = tf.reshape(preds, shape=[batch_size, -1]) + + probs = tf.math.log(preds) + tf.repeat( + beams_prob, repeats=vocab_size, axis=1 + ) + + candidate_prob, candidate_indexes = tf.math.top_k( + probs, k=num_beams, sorted=False + ) + candidate_beam_indexes = candidate_indexes // vocab_size + next_token = candidate_indexes % vocab_size + + beams = tf.gather( + beams, candidate_beam_indexes, axis=1, batch_dims=1 + ) + + # Build a new column of updates to scatter into the beam tensor. + next_token = tf.where( + condition=mask[..., length, tf.newaxis], + x=beams[..., length], + y=next_token, + ) + next_token = tf.reshape(next_token, shape=[-1]) + + # Generate `(batch_index, beam_index)` tuples for each beam. + beam_indices = tf.where(tf.ones((batch_size, num_beams), tf.bool)) + beam_indices = tf.cast(beam_indices, dtype=length.dtype) + # Build a tensor of repeated `length` values. + length_indices = tf.fill((batch_size * num_beams, 1), length) + # Concatenate to a triplet of `(batch_index, beam_index, length)`. + indices = tf.concat([beam_indices, length_indices], axis=-1) + + # Update `beams[:, :, length]` with `next_token`. + beams = tf.tensor_scatter_nd_update( + tensor=beams, + indices=indices, + updates=next_token, + ) + + beams_prob = candidate_prob + length = tf.add(length, 1) + + return beams, beams_prob, length + + # Run a while loop till text of length `max_length` has been generated. + beams, beams_prob, length = tf.while_loop( + cond=lambda beams, beams_prob, length: tf.less(length, max_length), + body=one_step, + loop_vars=(beams, beams_prob, length), + ) + + # Get the beam with the maximum probability. + max_indexes = tf.math.argmax(beams_prob, axis=-1) + max_beams = tf.gather( + beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1 + ) + prompt = tf.squeeze(max_beams) + + return prompt + + +BeamSampler.__doc__ = BeamSampler.__doc__.replace( + "{{base_sampler_keyword_args}}", base_sampler_keyword_args +) +BeamSampler.__doc__ = BeamSampler.__doc__.replace( + "{{call_keyword_docstring}}", call_keyword_docstring +) +BeamSampler.sample.__doc__ = BeamSampler.sample.__doc__.replace( + "{{sample_keyword_docstring}}", sample_keyword_docstring +) diff --git a/keras_nlp/samplers/top_k_sampler b/keras_nlp/samplers/top_k_sampler new file mode 100644 index 0000000000..a6f816d4ae --- /dev/null +++ b/keras_nlp/samplers/top_k_sampler @@ -0,0 +1,130 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
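+
+# A minimal standalone sketch (values assumed, not part of the class
+# below) of the top-k filtering idea implemented in this file: keep only
+# the k most likely tokens, then sample among them.
+#
+#   probs = tf.constant([[0.1, 0.4, 0.3, 0.2]])
+#   top_k_probs, top_k_indices = tf.math.top_k(probs, k=2)
+#   # top_k_probs == [[0.4, 0.3]], top_k_indices == [[1, 2]]
+#   sampled = tf.random.categorical(tf.math.log(top_k_probs), 1)
+#   next_token = tf.gather(top_k_indices, sampled, batch_dims=1)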
+"""Greedy Sampler.""" + +import tensorflow as tf + +from keras_nlp.samplers.sampler import Sampler +from keras_nlp.samplers.sampler import base_sampler_keyword_args +from keras_nlp.samplers.sampler import call_keyword_docstring +from keras_nlp.samplers.sampler import sample_keyword_docstring + + +class TopKSampler(Sampler): + """Top-K Sampler class. + + This sampler implements top-k search algorithm. + + Args: + {{base_sampler_keyword_args}} + + Call Args: + {{call_keyword_args}} + """ + + def __init__( + self, + k, + seed=None, + from_logits=False, + end_token_id=None, + pad_token_id=0, + jit_compile=True, + ): + self.k = k + self.seed = seed + self.from_logits = from_logits + super().__init__(end_token_id, pad_token_id, jit_compile) + + def sample(self, token_probability_fn, prompt, mask, num_steps): + """Sampler's logic implementation. + + Args: + {{call_keyword_docstring}} + """ + batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] + max_length = tf.cast(max_length, num_steps.dtype) + length = max_length - num_steps + + def one_step(length, prompt, mask): + probs = token_probability_fn(prompt, mask) + pred = tf.gather( + probs, tf.repeat(length - 1, batch_size), axis=1, batch_dims=1 + ) + if self.from_logits: + pred = keras.activations.softmax(pred, axis=-1) + + # Filter out top-k tokens. + top_k_pred, top_k_indices = tf.math.top_k( + pred, k=self.k, sorted=False + ) + # Sample the next token from the probability distribution. + next_token = tf.random.categorical( + tf.math.log(top_k_pred), 1, seed=self.seed + ) + + # Rearrange to get the next token idx from the original order. + next_token = tf.gather_nd(top_k_indices, next_token, batch_dims=1) + next_token = tf.cast(next_token, dtype=prompt.dtype) + next_token = tf.where( + mask[:, length], prompt[:, length], next_token + ) + + mask = tf.tensor_scatter_nd_update( + tensor=mask, + indices=tf.stack( + ( + tf.cast(tf.range(batch_size), dtype=length.dtype), + tf.repeat(length, batch_size), + ), + axis=1, + ), + updates=tf.repeat(True, batch_size), + ) + + # Append the next token to current sequence. + prompt = tf.tensor_scatter_nd_update( + tensor=prompt, + indices=tf.stack( + ( + tf.cast(tf.range(batch_size), dtype=length.dtype), + tf.repeat(length, batch_size), + ), + axis=1, + ), + updates=next_token, + ) + + length = tf.add(length, 1) + return (length, prompt, mask) + + # Run a while loop till text of length `max_length` has been generated. + length, prompt, mask = tf.while_loop( + cond=lambda length, prompt, mask: tf.less(length, max_length), + body=one_step, + loop_vars=(length, prompt, mask), + ) + + return prompt + + +TopKSampler.__doc__ = TopKSampler.__doc__.replace( + "{{base_sampler_keyword_args}}", base_sampler_keyword_args +) +TopKSampler.__doc__ = TopKSampler.__doc__.replace( + "{{call_keyword_docstring}}", call_keyword_docstring +) +TopKSampler.sample.__doc__ = TopKSampler.sample.__doc__.replace( + "{{sample_keyword_docstring}}", sample_keyword_docstring +) diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py new file mode 100644 index 0000000000..e4caaa11c4 --- /dev/null +++ b/keras_nlp/samplers/top_p_sampler.py @@ -0,0 +1,128 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Top-p Sampler.""" + +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.samplers.sampler import Sampler +from keras_nlp.samplers.sampler import base_sampler_keyword_args +from keras_nlp.samplers.sampler import call_keyword_docstring +from keras_nlp.samplers.sampler import sample_keyword_docstring + + +class TopPSampler(Sampler): + """Top-P Sampler class. + + This sampler implements top-p search algorithm. + + Args: + {{base_sampler_keyword_args}} + + Call Args: + {{call_keyword_args}} + """ + + def __init__( + self, + p, + seed=None, + from_logits=False, + end_token_id=None, + pad_token_id=0, + jit_compile=True, + ): + self.p = p + self.seed = seed + self.from_logits = from_logits + super().__init__(end_token_id, pad_token_id, jit_compile) + + def sample(self, token_probability_fn, prompt, mask, num_steps): + """Sampler's logic implementation. + + Args: + {{call_keyword_docstring}} + """ + batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] + max_length = tf.cast(max_length, num_steps.dtype) + length = max_length - num_steps + + def one_step(length, prompt): + pred = token_probability_fn(prompt[:, :length]) + if self.from_logits: + pred = keras.activations.softmax(pred, axis=-1) + # Sort preds in descending order. + sorted_preds, sorted_indices = tf.math.top_k( + pred, k=pred.shape[1], sorted=True + ) + # Calculate cumulative probability distribution. + cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1) + # Create a mask for the tokens to keep. + keep_mask = cumulative_probs <= self.p + # Shift to include the last token that exceed p. + shifted_keep_mask = tf.concat( + [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1 + ) + # Filter out unmasked tokens and sample from filtered distribution. + probs = tf.where( + shifted_keep_mask, + sorted_preds, + tf.zeros(tf.shape(pred), dtype=sorted_preds.dtype), + ) + sorted_next_token = tf.random.categorical( + tf.math.log(probs), 1, seed=self.seed + ) + next_token = tf.gather_nd( + sorted_indices, sorted_next_token, batch_dims=1 + ) + next_token = tf.cast(next_token, dtype=prompt.dtype) + next_token = tf.where( + mask[:, length], prompt[:, length], next_token + ) + + # Append the next token to current sequence. + prompt = tf.tensor_scatter_nd_update( + tensor=prompt, + indices=tf.stack( + ( + tf.cast(tf.range(batch_size), dtype=length.dtype), + tf.repeat(length, batch_size), + ), + axis=1, + ), + updates=next_token, + ) + + length = tf.add(length, 1) + return (length, prompt) + + # Run a while loop till text of length `max_length` has been generated. 
+ length, prompt = tf.while_loop( + cond=lambda length, _: tf.less(length, max_length), + body=one_step, + loop_vars=(length, prompt), + ) + + return prompt + + +TopPSampler.__doc__ = TopPSampler.__doc__.replace( + "{{base_sampler_keyword_args}}", base_sampler_keyword_args +) +TopPSampler.__doc__ = TopPSampler.__doc__.replace( + "{{call_keyword_docstring}}", call_keyword_docstring +) +TopPSampler.sample.__doc__ = TopPSampler.sample.__doc__.replace( + "{{sample_keyword_docstring}}", sample_keyword_docstring +) From 0eb68f6e95cbea58ee969985175f5244bed0f31a Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Tue, 20 Dec 2022 18:13:03 -0800 Subject: [PATCH 05/28] Add GPT2 text generation stuff --- keras_nlp/models/__init__.py | 5 + keras_nlp/models/gpt2/gpt2_causal_lm.py | 110 ++++++++++++++++++ keras_nlp/models/gpt2/gpt2_preprocessor.py | 102 ++++++++++++++++ keras_nlp/samplers/__init__.py | 3 + .../{top_k_sampler => top_k_sampler.py} | 1 + keras_nlp/samplers/top_p_sampler.py | 24 +++- 6 files changed, 239 insertions(+), 6 deletions(-) create mode 100644 keras_nlp/models/gpt2/gpt2_causal_lm.py create mode 100644 keras_nlp/models/gpt2/gpt2_preprocessor.py rename keras_nlp/samplers/{top_k_sampler => top_k_sampler.py} (99%) diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index bdd6042538..a74f4de811 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -26,6 +26,11 @@ from keras_nlp.models.distil_bert.distil_bert_tokenizer import ( DistilBertTokenizer, ) +from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone +from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM +from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2CausalLMPreprocessor +from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor +from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py new file mode 100644 index 0000000000..7f8bd7f5d8 --- /dev/null +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -0,0 +1,110 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
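+
+# Example usage (illustrative; the preset name below is an assumption and
+# is not defined in this patch):
+#
+#   causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base")
+#   causal_lm.generate("My trip to Yosemite was", max_length=40)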
+"""BERT task specific models and heads.""" + +import copy + +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone +from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2CausalLMPreprocessor +from keras_nlp.models.gpt2.gpt2_presets import backbone_presets +from keras_nlp.samplers.beam_sampler import BeamSampler +from keras_nlp.samplers.greedy_sampler import GreedySampler +from keras_nlp.samplers.top_k_sampler import TopKSampler +from keras_nlp.samplers.top_p_sampler import TopPSampler +from keras_nlp.utils.pipeline_model import PipelineModel +from keras_nlp.utils.python_utils import classproperty + + +@keras.utils.register_keras_serializable(package="keras_nlp") +class GPT2CausalLM(PipelineModel): + def __init__(self, backbone, preprocessor=None, **kwargs): + + inputs = backbone.input + x = backbone(inputs) + x = tf.matmul( + x, + backbone.get_layer("token_embedding").embeddings, + transpose_b=True, + ) + outputs = tf.keras.layers.Softmax()(x) + # Instantiate using Functional API Model constructor + super().__init__( + inputs=inputs, + outputs=outputs, + include_preprocessing=preprocessor is not None, + **kwargs, + ) + + self.preprocessor = preprocessor + self.backbone = backbone + + def preprocess_samples(self, x, y=None, sample_weight=None): + return self.preprocessor(x, y=y, sample_weight=sample_weight) + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) + + @classmethod + def from_preset( + cls, + preset, + load_weights=True, + **kwargs, + ): + if "preprocessor" not in kwargs: + kwargs["preprocessor"] = GPT2CausalLMPreprocessor.from_preset( + preset + ) + + # Check if preset is backbone-only model. + if preset in GPT2Backbone.presets: + backbone = GPT2Backbone.from_preset(preset, load_weights) + return cls(backbone, **kwargs) + + # Otherwise must be one of class presets. + # Currently no classifier-level presets, so we raise ValueError. + if preset not in cls.presets: + raise ValueError( + "`preset` must be one of " + f"""{", ".join(cls.presets)}. Received: {preset}.""" + ) + + def _get_generator(self, identifier): + maps = { + "greedy": GreedySampler(), + "top_k": TopKSampler(k=5, from_logits=False), + "top_p": TopPSampler(p=0.1, from_logits=False), + "beam": BeamSampler(num_beams=5), + } + return maps[identifier] + + def _get_token_probability(self, prompt, mask): + model_inputs = { + "token_ids": prompt, + "padding_mask": mask, + } + probs = self(model_inputs) + return probs + + def generate(self, prompt, max_length, generator="top_k"): + """Pick one method as the default generation algo.""" + if isinstance(generator, str): + generator = self._get_generator(generator) + prompt = self.preprocessor.tokenizer(prompt) + generated = generator(self._get_token_probability, prompt, max_length) + return self.preprocessor.tokenizer.detokenize(generated) diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py new file mode 100644 index 0000000000..e14ef10386 --- /dev/null +++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py @@ -0,0 +1,102 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPT2 preprocessor layer.""" + +import copy + +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.models.gpt2.gpt2_presets import backbone_presets +from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight +from keras_nlp.utils.python_utils import classproperty + + +class GPT2Preprocessor(keras.layers.Layer): + def __init__(self, tokenizer, sequence_length, **kwargs): + + super().__init__(**kwargs) + + self.tokenizer = tokenizer + self.sequence_length = sequence_length + + def call(self, x, y=None, sample_weight=None): + token_ids = self.tokenizer(x) + mask = tf.ones_like(token_ids, dtype=tf.bool) + mask = mask.to_tensor(shape=(None, self.sequence_length)) + token_ids = token_ids.to_tensor(shape=(None, self.sequence_length)) + x = { + "token_ids": token_ids, + "padding_mask": mask, + } + + return pack_x_y_sample_weight(x, y, sample_weight) + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) + + @classmethod + def from_preset( + cls, + preset, + sequence_length=None, + **kwargs, + ): + if preset not in cls.presets: + raise ValueError( + "`preset` must be one of " + f"""{", ".join(cls.presets)}. Received: {preset}.""" + ) + + tokenizer = GPT2Tokenizer.from_preset(preset) + + # Use model's `max_sequence_length` if `sequence_length` unspecified; + # otherwise check that `sequence_length` not too long. + metadata = cls.presets[preset] + max_sequence_length = metadata["config"]["max_sequence_length"] + if sequence_length is not None: + if sequence_length > max_sequence_length: + raise ValueError( + f"`sequence_length` cannot be longer than `{preset}` " + f"preset's `max_sequence_length` of {max_sequence_length}. " + f"Received: {sequence_length}." + ) + else: + sequence_length = max_sequence_length + + return cls( + tokenizer=tokenizer, + sequence_length=sequence_length, + **kwargs, + ) + + +class GPT2CausalLMPreprocessor(GPT2Preprocessor): + def call(self, x, y=None, sample_weight=None): + token_ids = self.tokenizer(x) + mask = tf.ones_like(token_ids, dtype=tf.bool) + mask = mask.to_tensor(shape=(None, self.sequence_length)) + token_ids = token_ids.to_tensor(shape=(None, self.sequence_length)) + x = { + "token_ids": token_ids[:, :-1], + "padding_mask": mask[:, 1:], + } + + y = token_ids[:, 1:] + sample_weight = mask[:, 1:] + + return pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py index 11908fb4dc..e2b500c2e5 100644 --- a/keras_nlp/samplers/__init__.py +++ b/keras_nlp/samplers/__init__.py @@ -12,4 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from keras_nlp.samplers.beam_sampler import BeamSampler from keras_nlp.samplers.greedy_sampler import GreedySampler +from keras_nlp.samplers.top_k_sampler import TopKSampler +from keras_nlp.samplers.top_p_sampler import TopPSampler diff --git a/keras_nlp/samplers/top_k_sampler b/keras_nlp/samplers/top_k_sampler.py similarity index 99% rename from keras_nlp/samplers/top_k_sampler rename to keras_nlp/samplers/top_k_sampler.py index a6f816d4ae..ea8806c46e 100644 --- a/keras_nlp/samplers/top_k_sampler +++ b/keras_nlp/samplers/top_k_sampler.py @@ -14,6 +14,7 @@ """Greedy Sampler.""" import tensorflow as tf +from tensorflow import keras from keras_nlp.samplers.sampler import Sampler from keras_nlp.samplers.sampler import base_sampler_keyword_args diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index e4caaa11c4..9826bc0403 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -58,8 +58,8 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): max_length = tf.cast(max_length, num_steps.dtype) length = max_length - num_steps - def one_step(length, prompt): - pred = token_probability_fn(prompt[:, :length]) + def one_step(length, prompt, mask): + pred = token_probability_fn(prompt[:, :length], mask) if self.from_logits: pred = keras.activations.softmax(pred, axis=-1) # Sort preds in descending order. @@ -91,6 +91,18 @@ def one_step(length, prompt): mask[:, length], prompt[:, length], next_token ) + mask = tf.tensor_scatter_nd_update( + tensor=mask, + indices=tf.stack( + ( + tf.cast(tf.range(batch_size), dtype=length.dtype), + tf.repeat(length, batch_size), + ), + axis=1, + ), + updates=tf.repeat(True, batch_size), + ) + # Append the next token to current sequence. prompt = tf.tensor_scatter_nd_update( tensor=prompt, @@ -105,13 +117,13 @@ def one_step(length, prompt): ) length = tf.add(length, 1) - return (length, prompt) + return (length, prompt, mask) # Run a while loop till text of length `max_length` has been generated. 
- length, prompt = tf.while_loop( - cond=lambda length, _: tf.less(length, max_length), + length, prompt, mask = tf.while_loop( + cond=lambda length, prompt, mask: tf.less(length, max_length), body=one_step, - loop_vars=(length, prompt), + loop_vars=(length, prompt, mask), ) return prompt From fa41d23edb2119aa533f02cc5fe9bf298948ffc5 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Tue, 3 Jan 2023 20:30:37 -0800 Subject: [PATCH 06/28] correct top-p and beam sampler --- keras_nlp/samplers/beam_sampler.py | 46 ++++++++++++++++++++--------- keras_nlp/samplers/top_k_sampler.py | 37 +++++++++++++---------- keras_nlp/samplers/top_p_sampler.py | 7 +++-- 3 files changed, 59 insertions(+), 31 deletions(-) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 91d94330af..4a1eb62e6f 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -57,10 +57,8 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] max_length = tf.cast(max_length, num_steps.dtype) length = max_length - num_steps - dummy_preds = self._validate_token_probability_fn( - token_probability_fn, prompt, mask - ) - vocab_size = dummy_preds.shape[-1] + dummy_preds = token_probability_fn(prompt, mask=mask) + vocab_size = tf.shape(dummy_preds)[-1] pred_dtype = dummy_preds.dtype num_beams = self.num_beams @@ -74,24 +72,30 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): axis=-1, ) - def one_step(beams, beams_prob, length): - truncated_beams = beams[..., :length] + def one_step(beams, beams_prob, length, mask): flattened_beams = tf.reshape( - truncated_beams, shape=[batch_size * num_beams, -1] + beams, shape=[batch_size * num_beams, -1] + ) + repeated_mask = tf.tile(mask, [num_beams, 1]) + probs = token_probability_fn(flattened_beams, repeated_mask) + preds = tf.gather( + probs, + tf.repeat(length - 1, batch_size * num_beams), + axis=1, + batch_dims=1, ) - preds = token_probability_fn(flattened_beams) if self.from_logits: preds = keras.activations.softmax(preds, axis=-1) # Reshape `preds` to shape `(batch_size, num_beams * vocab_size)`. preds = tf.reshape(preds, shape=[batch_size, -1]) - probs = tf.math.log(preds) + tf.repeat( + cum_probs = tf.math.log(preds) + tf.repeat( beams_prob, repeats=vocab_size, axis=1 ) candidate_prob, candidate_indexes = tf.math.top_k( - probs, k=num_beams, sorted=False + cum_probs, k=num_beams, sorted=False ) candidate_beam_indexes = candidate_indexes // vocab_size next_token = candidate_indexes % vocab_size @@ -108,6 +112,18 @@ def one_step(beams, beams_prob, length): ) next_token = tf.reshape(next_token, shape=[-1]) + mask = tf.tensor_scatter_nd_update( + tensor=mask, + indices=tf.stack( + ( + tf.cast(tf.range(batch_size), dtype=length.dtype), + tf.repeat(length, batch_size), + ), + axis=1, + ), + updates=tf.repeat(True, batch_size), + ) + # Generate `(batch_index, beam_index)` tuples for each beam. beam_indices = tf.where(tf.ones((batch_size, num_beams), tf.bool)) beam_indices = tf.cast(beam_indices, dtype=length.dtype) @@ -126,13 +142,15 @@ def one_step(beams, beams_prob, length): beams_prob = candidate_prob length = tf.add(length, 1) - return beams, beams_prob, length + return beams, beams_prob, length, mask # Run a while loop till text of length `max_length` has been generated. 
- beams, beams_prob, length = tf.while_loop( - cond=lambda beams, beams_prob, length: tf.less(length, max_length), + beams, beams_prob, length, mask = tf.while_loop( + cond=lambda beams, beams_prob, length, mask: tf.less( + length, max_length + ), body=one_step, - loop_vars=(beams, beams_prob, length), + loop_vars=(beams, beams_prob, length, mask), ) # Get the beam with the maximum probability. diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py index ea8806c46e..fb2a37f51c 100644 --- a/keras_nlp/samplers/top_k_sampler.py +++ b/keras_nlp/samplers/top_k_sampler.py @@ -49,11 +49,6 @@ def __init__( super().__init__(end_token_id, pad_token_id, jit_compile) def sample(self, token_probability_fn, prompt, mask, num_steps): - """Sampler's logic implementation. - - Args: - {{call_keyword_docstring}} - """ batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] max_length = tf.cast(max_length, num_steps.dtype) length = max_length - num_steps @@ -65,18 +60,30 @@ def one_step(length, prompt, mask): ) if self.from_logits: pred = keras.activations.softmax(pred, axis=-1) - - # Filter out top-k tokens. - top_k_pred, top_k_indices = tf.math.top_k( - pred, k=self.k, sorted=False + # Sort preds in descending order. + sorted_preds, sorted_indices = tf.math.top_k( + pred, k=tf.shape(pred)[1], sorted=True ) - # Sample the next token from the probability distribution. - next_token = tf.random.categorical( - tf.math.log(top_k_pred), 1, seed=self.seed + # Calculate cumulative probability distribution. + cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1) + # Create a mask for the tokens to keep. + keep_mask = cumulative_probs <= self.p + # Shift to include the last token that exceed p. + shifted_keep_mask = tf.concat( + [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1 + ) + # Filter out unmasked tokens and sample from filtered distribution. + probs = tf.where( + shifted_keep_mask, + sorted_preds, + tf.zeros(tf.shape(pred), dtype=sorted_preds.dtype), + ) + sorted_next_token = tf.random.categorical( + tf.math.log(probs), 1, seed=self.seed + ) + next_token = tf.gather_nd( + sorted_indices, sorted_next_token, batch_dims=1 ) - - # Rearrange to get the next token idx from the original order. - next_token = tf.gather_nd(top_k_indices, next_token, batch_dims=1) next_token = tf.cast(next_token, dtype=prompt.dtype) next_token = tf.where( mask[:, length], prompt[:, length], next_token diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index 9826bc0403..366ed7798d 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -59,12 +59,15 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): length = max_length - num_steps def one_step(length, prompt, mask): - pred = token_probability_fn(prompt[:, :length], mask) + probs = token_probability_fn(prompt, mask) + pred = tf.gather( + probs, tf.repeat(length - 1, batch_size), axis=1, batch_dims=1 + ) if self.from_logits: pred = keras.activations.softmax(pred, axis=-1) # Sort preds in descending order. sorted_preds, sorted_indices = tf.math.top_k( - pred, k=pred.shape[1], sorted=True + pred, k=tf.shape(pred)[1], sorted=True ) # Calculate cumulative probability distribution. 
             cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)
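The nucleus filtering that patch 06 installs is easiest to read outside the diff. A minimal sketch of the same pattern, assuming `pred` already holds probabilities and using an illustrative threshold `p`:

```python
import tensorflow as tf

p = 0.9
pred = tf.constant([[0.5, 0.3, 0.1, 0.1]])

# Full descending sort via top_k with k equal to the vocabulary size.
sorted_preds, sorted_indices = tf.math.top_k(
    pred, k=tf.shape(pred)[1], sorted=True
)
cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)

# Keep tokens while the cumulative mass is within p, shifted right so the
# first token that crosses p also survives.
keep_mask = cumulative_probs <= p
keep_mask = tf.concat(
    [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
)

# Zero out everything outside the nucleus and sample in sorted space.
probs = tf.where(keep_mask, sorted_preds, tf.zeros_like(sorted_preds))
sorted_next_token = tf.random.categorical(tf.math.log(probs), 1)
next_token = tf.gather_nd(sorted_indices, sorted_next_token, batch_dims=1)
```

The right-shift of the keep mask is the key design choice: it guarantees the filtered distribution is never empty, even when the single most likely token already exceeds `p`.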
From 7e4c651df9c58f1005ea0bfb85dfb514a6750596 Mon Sep 17 00:00:00 2001
From: Chen Qian
Date: Fri, 9 Dec 2022 17:32:00 -0800
Subject: [PATCH 08/28] Add keras_nlp.samplers

---
 keras_nlp/samplers/__init__.py            |  2 +
 keras_nlp/samplers/greedy_sampler.py      | 75 ++++++++++++++++-----
 keras_nlp/samplers/greedy_sampler_test.py | 16 ++++-
 keras_nlp/samplers/sampler.py             | 70 ++++++++++++++-----
 4 files changed, 129 insertions(+), 34 deletions(-)

diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py
index 6e4df4e727..11908fb4dc 100644
--- a/keras_nlp/samplers/__init__.py
+++ b/keras_nlp/samplers/__init__.py
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from keras_nlp.samplers.greedy_sampler import GreedySampler
diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py
index 816193ebf3..b57362179c 100644
--- a/keras_nlp/samplers/greedy_sampler.py
+++ b/keras_nlp/samplers/greedy_sampler.py
@@ -32,6 +32,38 @@ class GreedySampler(Sampler):
 
     Call Args:
         {{call_keyword_args}}
+
+    Examples:
+    ```python
+    BATCH_SIZE = 8
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
+
+    # Create a dummy model to predict the next token.
+    model = keras.Sequential(
+        [
+            keras.Input(shape=[None]),
+            keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability for each token
+    # in the input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)
+
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
+
+    sampler = keras_nlp.samplers.GreedySearch(end_token_id=END_ID)
+    # Print the generated sequence (token ids).
+    print(sampler(token_probability_fn, prompt, max_length=10))
+    ```
     """
 
     def __init__(
@@ -43,33 +75,40 @@ def __init__(
         super().__init__(end_token_id, pad_token_id, jit_compile)
 
     def sample(self, token_probability_fn, prompt, mask, num_steps):
-        """Sampler's logic implementation.
+        """Sampling logic implementation.
 
         Args:
-            {{call_keyword_docstring}}
+            {{sample_keyword_docstring}}
         """
         batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
         max_length = tf.cast(max_length, num_steps.dtype)
-        length = max_length - num_steps
+        # The index of the last non-padding token in prompt. Since all sequences
+        # are aligned to the right side, the index is the same for all.
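The loop body that follows writes one column per step with `tf.tensor_scatter_nd_update`, since XLA compilation rules out growing the tensor by concatenation. A toy version of that single-column write, with made-up sizes and values:

```python
import tensorflow as tf

batch_size, current_index = 3, 2
prompt = tf.zeros((batch_size, 5), dtype=tf.int32)
next_token = tf.constant([7, 8, 9], dtype=tf.int32)

# One (row, column) pair per batch element, all pointing at the same column.
indices = tf.stack(
    (tf.range(batch_size), tf.fill([batch_size], current_index)), axis=1
)
prompt = tf.tensor_scatter_nd_update(prompt, indices, next_token)
# prompt[:, 2] is now [7, 8, 9]; every other position is unchanged.
```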
+ current_index = max_length - num_steps - def one_step(length, prompt, mask): + def one_step(current_index, prompt, mask): probs = token_probability_fn(prompt, mask) next_token_prob = tf.gather( - probs, tf.repeat(length - 1, batch_size), axis=1, batch_dims=1 + probs, + tf.repeat(current_index - 1, batch_size), + axis=1, + batch_dims=1, ) next_token = tf.cast( tf.argmax(next_token_prob, axis=-1), dtype=prompt.dtype ) next_token = tf.where( - mask[:, length], prompt[:, length], next_token + mask[:, current_index], prompt[:, current_index], next_token ) mask = tf.tensor_scatter_nd_update( tensor=mask, indices=tf.stack( ( - tf.cast(tf.range(batch_size), dtype=length.dtype), - tf.repeat(length, batch_size), + tf.cast( + tf.range(batch_size), dtype=current_index.dtype + ), + tf.repeat(current_index, batch_size), ), axis=1, ), @@ -81,22 +120,26 @@ def one_step(length, prompt, mask): tensor=prompt, indices=tf.stack( ( - tf.cast(tf.range(batch_size), dtype=length.dtype), - tf.repeat(length, batch_size), + tf.cast( + tf.range(batch_size), dtype=current_index.dtype + ), + tf.repeat(current_index, batch_size), ), axis=1, ), updates=next_token, ) - length = tf.add(length, 1) - return (length, prompt, mask) + current_index = tf.add(current_index, 1) + return (current_index, prompt, mask) - # Run a while loop till text of length `max_length` has been generated. - length, prompt, mask = tf.while_loop( - cond=lambda length, prompt, mask: tf.less(length, max_length), + # Run a while loop till `max_length` of tokens has been generated. + current_index, prompt, mask = tf.while_loop( + cond=lambda current_index, prompt, mask: tf.less( + current_index, max_length + ), body=one_step, - loop_vars=(length, prompt, mask), + loop_vars=(current_index, prompt, mask), ) return prompt diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_sampler_test.py index 6e787d0afa..8eac81fd71 100644 --- a/keras_nlp/samplers/greedy_sampler_test.py +++ b/keras_nlp/samplers/greedy_sampler_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Text Generation Utils.""" +"""Tests for GreedySampler.""" import tensorflow as tf from absl.testing import parameterized @@ -112,3 +112,17 @@ def token_probability_fn(inputs, mask): expected_outputs = tf.tile([[3], [0]], [1, max_length - 2]) expected_outputs = tf.concat([inputs, expected_outputs], axis=1) self.assertAllEqual(outputs, expected_outputs) + + def test_compare_xla_noxla_results(self): + inputs = [[1], [1]] + xla_sampler = GreedySampler(jit_compile=True) + outputs_xla = xla_sampler( + self.token_probability_fn, inputs, max_length=5 + ) + + xla_sampler = GreedySampler(jit_compile=False) + outputs_no_xla = xla_sampler( + self.token_probability_fn, inputs, max_length=5 + ) + + self.assertAllEqual(outputs_xla, outputs_no_xla) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 90e2aae086..3f997c3792 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -14,18 +14,49 @@ """Base sampler class.""" import tensorflow as tf +from tensorflow import keras class Sampler: """Base sampler class. - This class must be implemented by child class for instantiation. 
- Args: {{base_optimizer_keyword_args}} Call Args: - + {{call_keyword_docstring}} + + Examples: + ```python + BATCH_SIZE = 8 + VOCAB_SIZE = 10 + FEATURE_SIZE = 16 + START_ID = 1 + END_ID = 2 + + # Create a dummy model to predict the next token. + model = keras.Sequential( + [ + keras.Input(shape=[None]), + keras.layers.Embedding( + input_dim=VOCAB_SIZE, + output_dim=FEATURE_SIZE, + ), + keras.layers.Dense(VOCAB_SIZE, activation="softmax"), + ] + ) + + # Define a function that outputs the next token's probability for each token + # in the input sequence. + def token_probability_fn(inputs): + return model(inputs) + + prompt = tf.fill((BATCH_SIZE, 1), START_ID) + + sampler = keras_nlp.samplers.GreedySearch(end_token_id=END_ID) + # Print the generated sequence (token ids). + print(sampler(token_probability_fn, prompt, max_length=10)) + ``` """ def __init__( @@ -39,7 +70,7 @@ def __init__( self.jit_compile = jit_compile def _validate_prompt(self, prompt): - """Helper function to validate input to text_generation utils.""" + """Helper method to validate input prompt.""" if not isinstance(prompt, (tf.Tensor, tf.RaggedTensor)): prompt = tf.convert_to_tensor(prompt) return prompt @@ -47,7 +78,7 @@ def _validate_prompt(self, prompt): def _validate_token_probability_fn( self, token_probability_fn, prompt, mask ): - """Helper function to validate `token_probability_fn` output.""" + """Helper method to validate `token_probability_fn` output.""" test_pred = token_probability_fn(prompt, mask=mask) if len(test_pred.shape) != 3: raise ValueError( @@ -61,7 +92,7 @@ def _align_and_pad_prompt(self, prompt, max_length, pad_token_id): longest_prompt_len = tf.reduce_max(prompt.row_lengths()) pad_length = longest_prompt_len - prompt.row_lengths() - prompt = tf.keras.utils.pad_sequences( + prompt = keras.utils.pad_sequences( prompt.to_list(), maxlen=longest_prompt_len, value=pad_token_id ) @@ -97,12 +128,10 @@ def _mask_tokens_after_end_token( ) # Build a mask including end_token and replace tokens after end_token # with `pad_token_id`. - valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length) - return tf.where(valid_indices, prompt, pad_token_id) + mask_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length) + return tf.where(mask_indices, prompt, pad_token_id) def __call__(self, token_probability_fn, prompt, max_length): - """Sampling method to be called by users.""" - prompt = self._validate_prompt(prompt) input_is_1d = prompt.shape.rank == 1 @@ -113,16 +142,23 @@ def __call__(self, token_probability_fn, prompt, max_length): prompt, padding=self.pad_token_id ) longest_prompt_len = tf.reduce_max(prompt.row_lengths()) + # Pad prompt to be a dense Tensor of shape [batch_size, max_length]. + # This step is required for XLA compatibility because XLA requires a + # static shape, which means we cannot concatenate generated token to + # current prompt. prompt, mask = self._align_and_pad_prompt( prompt, max_length, self.pad_token_id ) self._validate_token_probability_fn(token_probability_fn, prompt, mask) + # Convert `sample` method to a `tf.function`, and turn on + # `jit_compile` accordingly. sample = tf.function(self.sample, jit_compile=self.jit_compile) prompt = sample( token_probability_fn, prompt, mask, max_length - longest_prompt_len ) + # Mask out tokens after `end_token_id`. 
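The `_mask_tokens_after_end_token` helper invoked below is dense but mechanical. A trace with hypothetical ids (`end_token_id=2`, `pad_token_id=0`) shows the argmax trick and its one blind spot:

```python
import tensorflow as tf

prompt = tf.constant([[5, 2, 9, 9], [5, 5, 5, 5]])
max_length = 4

# argmax over the boolean match returns the FIRST True per row, and 0 when
# there is no match, so index 0 doubles as the "not found" signal. (A row
# whose very first token is the end token is indistinguishable from a row
# with no end token under this scheme.)
end_indices = tf.math.argmax(prompt == 2, axis=-1)
end_indices = tf.where(
    end_indices == 0, tf.cast(max_length, end_indices.dtype), end_indices
)

# Keep everything up to and including the end token, pad the rest.
keep = tf.sequence_mask(end_indices + 1, maxlen=max_length)
print(tf.where(keep, prompt, 0))  # [[5, 2, 0, 0], [5, 5, 5, 5]]
```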
        if self.end_token_id is not None:
            prompt = self._mask_tokens_after_end_token(
                prompt, max_length, self.end_token_id, self.pad_token_id
            )
@@ -131,7 +167,7 @@ def __call__(self, token_probability_fn, prompt, max_length):
         return tf.squeeze(prompt) if input_is_1d else prompt
 
     def sample(self, token_probability_fn, prompt, mask, num_steps):
-        """Sampler's logic implementation.
+        """Sampling logic implementation.
 
         Args:
             {{sample_keyword_docstring}}
@@ -149,9 +185,9 @@ def sample(self, token_probability_fn, prompt, mask, num_steps):
         sequence. If None, every sequence is generated up to `max_length`.
         If set, all tokens after encountering `end_token_id` will be
         replaced with `pad_token_id`.
-    pad_token_id: int, defaults to 0. The pad token after `end_token_id`
-        is received.
-    jit_compile: bool, defaults to True. If using XLA compilation."""
+    pad_token_id: int, defaults to 0. The padding token.
+    jit_compile: bool, defaults to True. If True, XLA compilation will be used.
+    """
 
 call_keyword_docstring = """
     token_probability_fn: a function that generates the probability of
@@ -163,9 +199,9 @@ def sample(self, token_probability_fn, prompt, mask, num_steps):
 sample_keyword_docstring = """
     token_probability_fn: a function that generates the probability of
         the next token over the whole vocabulary for each input token.
-    prompt: a list of integers or an integer Tensor, can be 1D or 2D. The
-        initial tokens to append generated tokens.
-    num_steps: int. The number of tokens to generate."""
+    prompt: a dense int Tensor of shape [batch_size, max_length]. The
+        placeholder for generated sequence.
+    num_steps: int. The remaining number of tokens to generate."""
 
 Sampler.__doc__ = Sampler.__doc__.replace(
     "{{base_sampler_keyword_args}}", base_sampler_keyword_args

From 28bcfe1b8d3875e30f2bf5b23fa68781be0c63b8 Mon Sep 17 00:00:00 2001
From: Chen Qian
Date: Mon, 12 Dec 2022 15:50:16 -0800
Subject: [PATCH 09/28] Change padding from left to right

---
 keras_nlp/samplers/sampler.py | 40 +++++++++--------------------------
 1 file changed, 10 insertions(+), 30 deletions(-)

diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py
index 3f997c3792..4e6aca3bb8 100644
--- a/keras_nlp/samplers/sampler.py
+++ b/keras_nlp/samplers/sampler.py
@@ -14,7 +14,6 @@
 """Base sampler class."""
 
 import tensorflow as tf
-from tensorflow import keras
 
 
 class Sampler:
@@ -87,30 +86,13 @@ def _validate_token_probability_fn(
                 "[batch_size, sequence_length, vocab_size]."
) - def _align_and_pad_prompt(self, prompt, max_length, pad_token_id): - """Align prompt to the right side, and pad to `max_length`.""" - longest_prompt_len = tf.reduce_max(prompt.row_lengths()) - pad_length = longest_prompt_len - prompt.row_lengths() - - prompt = keras.utils.pad_sequences( - prompt.to_list(), maxlen=longest_prompt_len, value=pad_token_id - ) - - mask = tf.RaggedTensor.from_row_lengths( - tf.zeros(shape=[tf.reduce_sum(pad_length)], dtype=tf.int32), - pad_length, - ) - mask = mask.to_tensor(shape=(None, longest_prompt_len), default_value=1) - - shape = prompt.shape - extra_space = tf.math.maximum(0, max_length - shape[1]) - pad_shape = [shape[0], extra_space] - - mask = tf.concat((mask, tf.zeros(pad_shape, tf.int32)), axis=1) - prompt = tf.concat( - (prompt, tf.zeros(pad_shape, prompt.dtype) + pad_token_id), axis=1 + def _pad_prompt(self, prompt, max_length, pad_token_id): + """Pad prompt to `max_length`.""" + mask = tf.ones_like(prompt, dtype=tf.bool) + mask = mask.to_tensor(shape=(None, max_length)) + prompt = prompt.to_tensor( + shape=(None, max_length), default_value=pad_token_id ) - mask = tf.cast(mask, dtype=tf.bool) return prompt, mask def _mask_tokens_after_end_token( @@ -141,21 +123,19 @@ def __call__(self, token_probability_fn, prompt, max_length): prompt = tf.RaggedTensor.from_tensor( prompt, padding=self.pad_token_id ) - longest_prompt_len = tf.reduce_max(prompt.row_lengths()) + shortest_prompt_len = tf.reduce_min(prompt.row_lengths()) # Pad prompt to be a dense Tensor of shape [batch_size, max_length]. # This step is required for XLA compatibility because XLA requires a # static shape, which means we cannot concatenate generated token to # current prompt. - prompt, mask = self._align_and_pad_prompt( - prompt, max_length, self.pad_token_id - ) + prompt, mask = self._pad_prompt(prompt, max_length, self.pad_token_id) self._validate_token_probability_fn(token_probability_fn, prompt, mask) # Convert `sample` method to a `tf.function`, and turn on # `jit_compile` accordingly. sample = tf.function(self.sample, jit_compile=self.jit_compile) prompt = sample( - token_probability_fn, prompt, mask, max_length - longest_prompt_len + token_probability_fn, prompt, mask, max_length - shortest_prompt_len ) # Mask out tokens after `end_token_id`. 
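The new `_pad_prompt` above leans on `RaggedTensor.to_tensor`, which keeps every prompt left-aligned and pads on the right. What it produces for a hypothetical two-row batch:

```python
import tensorflow as tf

prompt = tf.ragged.constant([[3, 4], [5]])
max_length = 5

# True over real tokens, False over padding; same shape as the dense prompt.
mask = tf.ones_like(prompt, dtype=tf.bool).to_tensor(shape=(None, max_length))
dense = prompt.to_tensor(shape=(None, max_length), default_value=0)
# dense -> [[3, 4, 0, 0, 0], [5, 0, 0, 0, 0]]
# mask  -> [[True, True, False, False, False],
#           [True, False, False, False, False]]
```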
@@ -206,7 +186,7 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): Sampler.__doc__ = Sampler.__doc__.replace( "{{base_sampler_keyword_args}}", base_sampler_keyword_args ) -Sampler.__doc__ = Sampler.__call__.__doc__.replace( +Sampler.__doc__ = Sampler.__doc__.replace( "{{call_keyword_docstring}}", call_keyword_docstring ) Sampler.sample.__doc__ = Sampler.sample.__doc__.replace( From 9757f4d78f79da4b548892c78f38beea4fd9df09 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Thu, 5 Jan 2023 12:16:14 -0800 Subject: [PATCH 10/28] Add serialization support, and move some args from constructor to call --- keras_nlp/__init__.py | 1 + keras_nlp/samplers/__init__.py | 68 +++++++++- .../samplers/{greedy_sampler.py => greedy.py} | 24 ++-- ...{greedy_sampler_test.py => greedy_test.py} | 18 +-- keras_nlp/samplers/sampler.py | 123 ++++++++++++------ keras_nlp/samplers/sampler_test.py | 57 ++++++++ 6 files changed, 229 insertions(+), 62 deletions(-) rename keras_nlp/samplers/{greedy_sampler.py => greedy.py} (88%) rename keras_nlp/samplers/{greedy_sampler_test.py => greedy_test.py} (89%) create mode 100644 keras_nlp/samplers/sampler_test.py diff --git a/keras_nlp/__init__.py b/keras_nlp/__init__.py index 73e7e2c593..050a189d6f 100644 --- a/keras_nlp/__init__.py +++ b/keras_nlp/__init__.py @@ -15,6 +15,7 @@ from keras_nlp import layers from keras_nlp import metrics from keras_nlp import models +from keras_nlp import samplers from keras_nlp import tokenizers from keras_nlp import utils diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py index 11908fb4dc..f4841dd85a 100644 --- a/keras_nlp/samplers/__init__.py +++ b/keras_nlp/samplers/__init__.py @@ -12,4 +12,70 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_nlp.samplers.greedy_sampler import GreedySampler +from tensorflow import keras + +from keras_nlp.samplers.greedy import Greedy + + +def serialize(sampler): + return keras.utils.serialize_keras_object(sampler) + + +def deserialize(config, custom_objects=None): + """Return a `Sampler` object from its config.""" + all_classes = { + "greedy": Greedy, + } + if config["class_name"].lower() in all_classes: + config["class_name"] = config["class_name"].lower() + return keras.utils.deserialize_keras_object( + config, + module_objects=all_classes, + custom_objects=custom_objects, + printable_module_name="samplers", + ) + + +def get(identifier): + """Retrieve a KerasNLP sampler by the identifier. + + The `identifier` may be the string name of a sampler class or class. + + >>> identifier = 'greedy' + >>> sampler = keras_nlp.samplers.get(identifier) + + You can also specify `config` of the sampler to this function by passing + dict containing `class_name` and `config` as an identifier. Also note that + the `class_name` must map to a `Sampler` class. + + >>> cfg = {'class_name': 'Greedy', 'config': {}} + >>> sampler = keras_nlp.samplers.get(cfg) + + In the case that the `identifier` is a class, this method will return a new + instance of the class by its constructor. + + Args: + identifier: String or dict that contains the sampler name or + configurations. + + Returns: + Sampler instance base on the input identifier. + + Raises: + ValueError: If the input identifier is not a supported type or in a bad + format. 
+ """ + + if identifier is None: + return None + if isinstance(identifier, dict): + return deserialize(identifier) + elif isinstance(identifier, str): + identifier = {"class_name": str(identifier), "config": {}} + return deserialize(identifier) + elif callable(identifier): + return identifier + else: + raise ValueError( + "Could not interpret sampler identifier: " + str(identifier) + ) diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy.py similarity index 88% rename from keras_nlp/samplers/greedy_sampler.py rename to keras_nlp/samplers/greedy.py index b57362179c..f486a7cca0 100644 --- a/keras_nlp/samplers/greedy_sampler.py +++ b/keras_nlp/samplers/greedy.py @@ -21,8 +21,8 @@ from keras_nlp.samplers.sampler import sample_keyword_docstring -class GreedySampler(Sampler): - """Greedy Sampler class. +class Greedy(Sampler): + """Greedy sampler class. This sampler is implemented on greedy search, i.e., always picking up the token of the largest probability as the next token. @@ -55,26 +55,26 @@ class GreedySampler(Sampler): # Define a function that outputs the next token's probability for each token # in the input sequence. - def token_probability_fn(inputs): + def token_probability_fn(inputs, mask): return model(inputs) prompt = tf.fill((BATCH_SIZE, 1), START_ID) - sampler = keras_nlp.samplers.GreedySearch(end_token_id=END_ID) + sampler = keras_nlp.samplers.Greedy() # Print the generated sequence (token ids). - print(sampler(token_probability_fn, prompt, max_length=10)) + print(sampler(token_probability_fn, prompt, 10, end_token_id=END_ID)) ``` """ def __init__( self, - end_token_id=None, - pad_token_id=0, jit_compile=True, ): - super().__init__(end_token_id, pad_token_id, jit_compile) + super().__init__(jit_compile) - def sample(self, token_probability_fn, prompt, mask, num_steps): + def sample( + self, token_probability_fn, prompt, mask, num_steps, from_logits=True + ): """Sampling logic implementation. Args: @@ -144,12 +144,12 @@ def one_step(current_index, prompt, mask): return prompt -GreedySampler.__doc__ = GreedySampler.__doc__.replace( +Greedy.__doc__ = Greedy.__doc__.replace( "{{base_sampler_keyword_args}}", base_sampler_keyword_args ) -GreedySampler.__doc__ = GreedySampler.__doc__.replace( +Greedy.__doc__ = Greedy.__doc__.replace( "{{call_keyword_docstring}}", call_keyword_docstring ) -GreedySampler.sample.__doc__ = GreedySampler.sample.__doc__.replace( +Greedy.sample.__doc__ = Greedy.sample.__doc__.replace( "{{sample_keyword_docstring}}", sample_keyword_docstring ) diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_test.py similarity index 89% rename from keras_nlp/samplers/greedy_sampler_test.py rename to keras_nlp/samplers/greedy_test.py index 8eac81fd71..0472bdcbc9 100644 --- a/keras_nlp/samplers/greedy_sampler_test.py +++ b/keras_nlp/samplers/greedy_test.py @@ -11,16 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for GreedySampler.""" +"""Tests for Greedy sampler.""" import tensorflow as tf from absl.testing import parameterized from tensorflow import keras -from keras_nlp.samplers.greedy_sampler import GreedySampler +from keras_nlp.samplers.greedy import Greedy -class GreedySamplerTest(tf.test.TestCase, parameterized.TestCase): +class GreedyTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): super().setUp() self.vocab_size = 10 @@ -44,7 +44,7 @@ def token_probability_fn(inputs, mask): self.token_probability_fn = token_probability_fn - self.sampler = GreedySampler() + self.sampler = Greedy() def test_generate_with_1d_prompt(self): inputs = tf.constant([1]) @@ -102,25 +102,25 @@ def token_probability_fn(inputs, mask): tf.repeat(prob, batch_size, axis=0), max_length, axis=1 ) - sampler = GreedySampler(end_token_id=2) + sampler = Greedy() inputs = tf.constant([[0, 1], [1, 2]]) outputs = sampler( token_probability_fn, inputs, max_length=max_length, + end_token_id=2, ) - expected_outputs = tf.tile([[3], [0]], [1, max_length - 2]) - expected_outputs = tf.concat([inputs, expected_outputs], axis=1) + expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]]) self.assertAllEqual(outputs, expected_outputs) def test_compare_xla_noxla_results(self): inputs = [[1], [1]] - xla_sampler = GreedySampler(jit_compile=True) + xla_sampler = Greedy(jit_compile=True) outputs_xla = xla_sampler( self.token_probability_fn, inputs, max_length=5 ) - xla_sampler = GreedySampler(jit_compile=False) + xla_sampler = Greedy(jit_compile=False) outputs_no_xla = xla_sampler( self.token_probability_fn, inputs, max_length=5 ) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 4e6aca3bb8..64ff4bba0c 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -47,32 +47,45 @@ class Sampler: # Define a function that outputs the next token's probability for each token # in the input sequence. - def token_probability_fn(inputs): + def token_probability_fn(inputs, mask): return model(inputs) prompt = tf.fill((BATCH_SIZE, 1), START_ID) - sampler = keras_nlp.samplers.GreedySearch(end_token_id=END_ID) + sampler = keras_nlp.samplers.Greedy() # Print the generated sequence (token ids). - print(sampler(token_probability_fn, prompt, max_length=10)) + print(sampler(token_probability_fn, prompt, 10, end_token_id=END_ID)) ``` """ def __init__( self, - end_token_id=None, - pad_token_id=0, jit_compile=True, ): - self.end_token_id = end_token_id - self.pad_token_id = pad_token_id self.jit_compile = jit_compile - def _validate_prompt(self, prompt): + def _validate_prompt_and_mask(self, prompt, mask): """Helper method to validate input prompt.""" - if not isinstance(prompt, (tf.Tensor, tf.RaggedTensor)): + if not isinstance(prompt, (list, tf.RaggedTensor, tf.Tensor)): + raise ValueError( + "`prompt` must be one of `list`, `tf.RaggedTensor` or " + f"`tf.Tensor`, but received: prompt={type(prompt)}." + ) + + if isinstance(prompt, tf.RaggedTensor): + if mask: + raise ValueError( + "`mask` is only valid when `prompt` is a list or dense " + f"tensor, but received type(prompt)={type(prompt)}." 
+ ) + return prompt, mask + + if isinstance(prompt, list): prompt = tf.convert_to_tensor(prompt) - return prompt + if not mask: + mask = tf.cast(tf.ones_like(prompt), dtype=tf.bool) + prompt = tf.ragged.boolean_mask(prompt, mask) + return prompt, mask def _validate_token_probability_fn( self, token_probability_fn, prompt, mask @@ -86,19 +99,20 @@ def _validate_token_probability_fn( "[batch_size, sequence_length, vocab_size]." ) - def _pad_prompt(self, prompt, max_length, pad_token_id): + def _pad_prompt(self, prompt, max_length): """Pad prompt to `max_length`.""" mask = tf.ones_like(prompt, dtype=tf.bool) mask = mask.to_tensor(shape=(None, max_length)) - prompt = prompt.to_tensor( - shape=(None, max_length), default_value=pad_token_id - ) + prompt = prompt.to_tensor(shape=(None, max_length)) return prompt, mask def _mask_tokens_after_end_token( - self, prompt, max_length, end_token_id, pad_token_id + self, + prompt, + max_length, + end_token_id, ): - """Helper function to mask the tokens after the end token.""" + """Helper function to truncate the tokens after the end token.""" # Mask out tokens after `end_token_id` is encountered. # Find index of first end_token_id. end_indices = tf.math.argmax(prompt == end_token_id, -1) @@ -108,45 +122,59 @@ def _mask_tokens_after_end_token( tf.cast(max_length, dtype=end_indices.dtype), end_indices, ) - # Build a mask including end_token and replace tokens after end_token - # with `pad_token_id`. - mask_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length) - return tf.where(mask_indices, prompt, pad_token_id) + # Truncate out tokens after (including) the end token. + mask_indices = tf.sequence_mask(end_indices, maxlen=max_length) + return tf.ragged.boolean_mask(prompt, mask_indices) - def __call__(self, token_probability_fn, prompt, max_length): - prompt = self._validate_prompt(prompt) + def __call__( + self, + token_probability_fn, + prompt, + max_length, + padding_mask=None, + end_token_id=None, + from_logits=True, + ): + prompt, padding_mask = self._validate_prompt_and_mask( + prompt, padding_mask + ) input_is_1d = prompt.shape.rank == 1 if input_is_1d: - prompt = prompt[tf.newaxis, :] - if isinstance(prompt, tf.Tensor): - prompt = tf.RaggedTensor.from_tensor( - prompt, padding=self.pad_token_id - ) + prompt = tf.RaggedTensor.from_tensor(prompt[tf.newaxis, :]) + shortest_prompt_len = tf.reduce_min(prompt.row_lengths()) # Pad prompt to be a dense Tensor of shape [batch_size, max_length]. # This step is required for XLA compatibility because XLA requires a # static shape, which means we cannot concatenate generated token to # current prompt. - prompt, mask = self._pad_prompt(prompt, max_length, self.pad_token_id) + prompt, mask = self._pad_prompt(prompt, max_length) self._validate_token_probability_fn(token_probability_fn, prompt, mask) # Convert `sample` method to a `tf.function`, and turn on # `jit_compile` accordingly. sample = tf.function(self.sample, jit_compile=self.jit_compile) prompt = sample( - token_probability_fn, prompt, mask, max_length - shortest_prompt_len + token_probability_fn, + prompt, + mask, + max_length - shortest_prompt_len, + from_logits, ) # Mask out tokens after `end_token_id`. 
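Note the behavioral change in `_mask_tokens_after_end_token`: rather than padding past the end token, it now drops the end token and everything after it, so the result is ragged. A sketch with `end_token_id=2`:

```python
import tensorflow as tf

prompt = tf.constant([[5, 2, 9, 9], [5, 5, 5, 5]])
max_length = 4

end_indices = tf.math.argmax(prompt == 2, axis=-1)
end_indices = tf.where(
    end_indices == 0, tf.cast(max_length, end_indices.dtype), end_indices
)

# No `+ 1` here: the end token itself is excluded from the kept span.
keep = tf.sequence_mask(end_indices, maxlen=max_length)
print(tf.ragged.boolean_mask(prompt, keep))  # [[5], [5, 5, 5, 5]]
```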
- if self.end_token_id is not None: + if end_token_id is not None: prompt = self._mask_tokens_after_end_token( - prompt, max_length, self.end_token_id, self.pad_token_id + prompt, + max_length, + end_token_id, ) return tf.squeeze(prompt) if input_is_1d else prompt - def sample(self, token_probability_fn, prompt, mask, num_steps): + def sample( + self, token_probability_fn, prompt, mask, num_steps, from_logits=True + ): """Sampling logic implementation. Args: @@ -158,14 +186,13 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): """ raise NotImplementedError + def get_config(self): + return { + "jit_compile": self.jit_compile, + } + base_sampler_keyword_args = """ - end_token_id: int, defaults to None. The token marking the end of the - sequence, once encountered the generation is finished for the exact - sequence. If None, every sequence is generated up to `max_length`. - If set, all tokens after encountering `end_token_id` will be - replaced with `pad_token_id`. - pad_token_id: int, defaults to 0. The padding token. jit_compile: bool, defaults to True. If True, XLA compilation will be used. """ @@ -174,14 +201,30 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): the next token over the whole vocabulary for each input token. prompt: a list of integers or an integer Tensor, can be 1D or 2D. The initial tokens to append generated tokens. - max_length: int. The max length of generated sequence.""" + max_length: int. The max length of generated sequence. + padding_mask: a tensor, defaults to None. The padding mask of the prompt. + end_token_id: int, defaults to None. The token marking the end of the + sequence, once encountered the generation is finished for the exact + sequence. If None, every sequence is generated up to `max_length`. + If set, all tokens after encountering `end_token_id` will be + replaced with `pad_token_id`. + from_logits: bool, defaults to True. Indicate if the `token_probability_fn` + returns logits. If False, `token_probability_fn` returns probability + distributions. + """ sample_keyword_docstring = """ token_probability_fn: a function that generates the probability of the next token over the whole vocabulary for each input token. prompt: a dense int Tensor of shape [batch_size, max_length]. The placeholder for generated sequence. - num_steps: int. The remaining number of tokens to generate.""" + mask: a dense bool Tensor of shape [batch_size, max_length]. The mask of + prompt. + num_steps: int. The remaining number of tokens to generate. + from_logits: bool, defaults to True. Indicate if the `token_probability_fn` + returns logits. If False, `token_probability_fn` returns probability + distributions. + """ Sampler.__doc__ = Sampler.__doc__.replace( "{{base_sampler_keyword_args}}", base_sampler_keyword_args diff --git a/keras_nlp/samplers/sampler_test.py b/keras_nlp/samplers/sampler_test.py new file mode 100644 index 0000000000..2e4ed4a1a7 --- /dev/null +++ b/keras_nlp/samplers/sampler_test.py @@ -0,0 +1,57 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Sampler classes.""" + +import tensorflow as tf + +import keras_nlp +from keras_nlp.samplers.greedy import Greedy + + +class SamplerTest(tf.test.TestCase): + def test_serialization(self): + sampler = keras_nlp.samplers.Greedy() + config = keras_nlp.samplers.serialize(sampler) + expected_config = { + "class_name": "Greedy", + "config": { + "jit_compile": True, + }, + } + self.assertDictEqual(expected_config, config) + + def test_deserialization(self): + # Test get from string. + identifier = "greedy" + sampler = keras_nlp.samplers.get(identifier) + self.assertIsInstance(sampler, Greedy) + + # Test string is not case-sensitive. + identifier = "Greedy" + sampler = keras_nlp.samplers.get(identifier) + self.assertIsInstance(sampler, Greedy) + + # Test dict identifier. + original_sampler = keras_nlp.samplers.Greedy(jit_compile=False) + config = keras_nlp.samplers.serialize(original_sampler) + restored_sampler = keras_nlp.samplers.get(config) + self.assertDictEqual( + keras_nlp.samplers.serialize(restored_sampler), + keras_nlp.samplers.serialize(original_sampler), + ) + + # Test identifier is already a sampler instance. + original_sampler = keras_nlp.samplers.Greedy(jit_compile=False) + restored_sampler = keras_nlp.samplers.get(original_sampler) + self.assertEqual(original_sampler, restored_sampler) From f7508cb054c2f83db8d1b864075396b980dd0eb4 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Thu, 5 Jan 2023 20:28:40 -0800 Subject: [PATCH 11/28] Add string example --- keras_nlp/samplers/sampler.py | 42 ++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 64ff4bba0c..e7216ab23d 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -25,7 +25,11 @@ class Sampler: Call Args: {{call_keyword_docstring}} + The inputs and outputs of Sampler class are both token ids. + Examples: + + Basic usage: ```python BATCH_SIZE = 8 VOCAB_SIZE = 10 @@ -56,6 +60,42 @@ def token_probability_fn(inputs, mask): # Print the generated sequence (token ids). print(sampler(token_probability_fn, prompt, 10, end_token_id=END_ID)) ``` + + Use with string inputs: + ```python + vocab = ["[UNK]", "[PAD]", "[END]", "the", "quick", "brown", "fox"] + tokenizer = keras_nlp.tokenizers.WordPieceTokenizer( + vocabulary=vocab, + lowercase=True, + ) + FEATURE_SIZE = 16 + VOCAB_SIZE = len(vocab) + # Create a dummy model to predict the next token. + model = keras.Sequential( + [ + keras.Input(shape=[None]), + keras.layers.Embedding( + input_dim=VOCAB_SIZE, + output_dim=FEATURE_SIZE, + ), + keras.layers.Dense(VOCAB_SIZE, activation="softmax"), + ] + ) + # Define a function that outputs the next token's probability for each token + # in the input sequence. 
+ def token_probability_fn(inputs, mask): + return model(inputs) + + prompt = tokenizer("the quick brown fox") + sampler = keras_nlp.samplers.Greedy() + generated = sampler( + token_probability_fn, + prompt, + 10, + end_token_id=tokenizer.token_to_id("[END]") + ) + print(tokenizer.detokenize(generated)) + ``` """ def __init__( @@ -170,7 +210,7 @@ def __call__( end_token_id, ) - return tf.squeeze(prompt) if input_is_1d else prompt + return tf.squeeze(prompt, axis=0) if input_is_1d else prompt def sample( self, token_probability_fn, prompt, mask, num_steps, from_logits=True From b658b61c0ee757d8ae707a08a36f0fcba0d094d6 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Fri, 6 Jan 2023 13:45:51 -0800 Subject: [PATCH 12/28] small changes --- keras_nlp/samplers/greedy.py | 5 +++-- keras_nlp/samplers/sampler.py | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/keras_nlp/samplers/greedy.py b/keras_nlp/samplers/greedy.py index f486a7cca0..a5e3626f14 100644 --- a/keras_nlp/samplers/greedy.py +++ b/keras_nlp/samplers/greedy.py @@ -14,6 +14,7 @@ """Greedy Sampler.""" import tensorflow as tf +from tensorflow import keras from keras_nlp.samplers.sampler import Sampler from keras_nlp.samplers.sampler import base_sampler_keyword_args @@ -21,6 +22,7 @@ from keras_nlp.samplers.sampler import sample_keyword_docstring +@keras.utils.register_keras_serializable(package="keras_nlp") class Greedy(Sampler): """Greedy sampler class. @@ -39,7 +41,6 @@ class Greedy(Sampler): VOCAB_SIZE = 10 FEATURE_SIZE = 16 START_ID = 1 - END_ID = 2 # Create a dummy model to predict the next token. model = keras.Sequential( @@ -62,7 +63,7 @@ def token_probability_fn(inputs, mask): sampler = keras_nlp.samplers.Greedy() # Print the generated sequence (token ids). - print(sampler(token_probability_fn, prompt, 10, end_token_id=END_ID)) + print(sampler(token_probability_fn, prompt, 10)) ``` """ diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index e7216ab23d..aa7c64bdc7 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -14,8 +14,10 @@ """Base sampler class.""" import tensorflow as tf +from tensorflow import keras +@keras.utils.register_keras_serializable(package="keras_nlp") class Sampler: """Base sampler class. From 76c430c755d8f68bf9ccbedab4fdc5fb171758d0 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Mon, 9 Jan 2023 12:08:18 -0800 Subject: [PATCH 13/28] Address comments: fix docstring, remove multicase support --- keras_nlp/samplers/__init__.py | 18 +++--- keras_nlp/samplers/greedy.py | 29 ++++----- keras_nlp/samplers/sampler.py | 97 +++++++++++++++--------------- keras_nlp/samplers/sampler_test.py | 7 +-- 4 files changed, 69 insertions(+), 82 deletions(-) diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py index f4841dd85a..5125a571b0 100644 --- a/keras_nlp/samplers/__init__.py +++ b/keras_nlp/samplers/__init__.py @@ -26,8 +26,6 @@ def deserialize(config, custom_objects=None): all_classes = { "greedy": Greedy, } - if config["class_name"].lower() in all_classes: - config["class_name"] = config["class_name"].lower() return keras.utils.deserialize_keras_object( config, module_objects=all_classes, @@ -55,15 +53,15 @@ def get(identifier): instance of the class by its constructor. Args: - identifier: String or dict that contains the sampler name or - configurations. + identifier: String or dict that contains the sampler name or + configurations. Returns: - Sampler instance base on the input identifier. 
+ Sampler instance base on the input identifier. Raises: - ValueError: If the input identifier is not a supported type or in a bad - format. + ValueError: If the input identifier is not a supported type or in a bad + format. """ if identifier is None: @@ -71,7 +69,11 @@ def get(identifier): if isinstance(identifier, dict): return deserialize(identifier) elif isinstance(identifier, str): - identifier = {"class_name": str(identifier), "config": {}} + if not identifier.islower(): + raise KeyError( + "`keras_nlp.samplers.get()` must take a lowercase string " + f"identifier, but received: {identifier}." + ) return deserialize(identifier) elif callable(identifier): return identifier diff --git a/keras_nlp/samplers/greedy.py b/keras_nlp/samplers/greedy.py index a5e3626f14..60fe732126 100644 --- a/keras_nlp/samplers/greedy.py +++ b/keras_nlp/samplers/greedy.py @@ -17,11 +17,15 @@ from tensorflow import keras from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import base_sampler_keyword_args -from keras_nlp.samplers.sampler import call_keyword_docstring -from keras_nlp.samplers.sampler import sample_keyword_docstring +from keras_nlp.samplers.sampler import base_sampler_args_docstring +from keras_nlp.samplers.sampler import call_args_docstring +from keras_nlp.samplers.sampler import sample_args_docstring +from keras_nlp.utils.python_utils import format_docstring +@format_docstring( + base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring +) @keras.utils.register_keras_serializable(package="keras_nlp") class Greedy(Sampler): """Greedy sampler class. @@ -30,10 +34,10 @@ class Greedy(Sampler): token of the largest probability as the next token. Args: - {{base_sampler_keyword_args}} + {{base_sampler_args}} Call Args: - {{call_keyword_args}} + {{call_args}} Examples: ```python @@ -73,13 +77,14 @@ def __init__( ): super().__init__(jit_compile) + @format_docstring(sample_args=sample_args_docstring) def sample( self, token_probability_fn, prompt, mask, num_steps, from_logits=True ): """Sampling logic implementation. Args: - {{sample_keyword_docstring}} + {{sample_args}} """ batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] max_length = tf.cast(max_length, num_steps.dtype) @@ -88,7 +93,6 @@ def sample( current_index = max_length - num_steps def one_step(current_index, prompt, mask): - probs = token_probability_fn(prompt, mask) next_token_prob = tf.gather( probs, @@ -143,14 +147,3 @@ def one_step(current_index, prompt, mask): loop_vars=(current_index, prompt, mask), ) return prompt - - -Greedy.__doc__ = Greedy.__doc__.replace( - "{{base_sampler_keyword_args}}", base_sampler_keyword_args -) -Greedy.__doc__ = Greedy.__doc__.replace( - "{{call_keyword_docstring}}", call_keyword_docstring -) -Greedy.sample.__doc__ = Greedy.sample.__doc__.replace( - "{{sample_keyword_docstring}}", sample_keyword_docstring -) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index aa7c64bdc7..8f208ff10a 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -16,16 +16,55 @@ import tensorflow as tf from tensorflow import keras +from keras_nlp.utils.python_utils import format_docstring +base_sampler_args_docstring = """ + jit_compile: bool, defaults to True. If True, XLA compilation will be used. + """ + +call_args_docstring = """ + token_probability_fn: a function that generates the probability of + the next token over the whole vocabulary for each input token. 
+ prompt: a list of integers or an integer Tensor, can be 1D or 2D. The + initial tokens to append generated tokens. + max_length: int. The max length of generated sequence. + padding_mask: a tensor, defaults to None. The padding mask of the prompt. + end_token_id: int, defaults to None. The token marking the end of the + sequence, once encountered the generation is finished for the exact + sequence. If None, every sequence is generated up to `max_length`. + If set, all tokens after encountering `end_token_id` will be + replaced with `pad_token_id`. + from_logits: bool, defaults to True. Indicate if the `token_probability_fn` + returns logits. If False, `token_probability_fn` returns probability + distributions. + """ + +sample_args_docstring = """ + token_probability_fn: a function that generates the probability of + the next token over the whole vocabulary for each input token. + prompt: a dense int Tensor of shape [batch_size, max_length]. The + placeholder for generated sequence. + mask: a dense bool Tensor of shape [batch_size, max_length]. The mask of + prompt. + num_steps: int. The remaining number of tokens to generate. + from_logits: bool, defaults to True. Indicate if the `token_probability_fn` + returns logits. If False, `token_probability_fn` returns probability + distributions. + """ + + +@format_docstring( + base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring +) @keras.utils.register_keras_serializable(package="keras_nlp") class Sampler: """Base sampler class. Args: - {{base_optimizer_keyword_args}} + {{base_sampler_args}} Call Args: - {{call_keyword_docstring}} + {{call_args}} The inputs and outputs of Sampler class are both token ids. @@ -39,7 +78,8 @@ class Sampler: START_ID = 1 END_ID = 2 - # Create a dummy model to predict the next token. + # Create a dummy model to predict the next token. Note that the output is + # random without training, here we jsut demo how `samplers` works. model = keras.Sequential( [ keras.Input(shape=[None]), @@ -178,7 +218,8 @@ def __call__( from_logits=True, ): prompt, padding_mask = self._validate_prompt_and_mask( - prompt, padding_mask + prompt, + padding_mask, ) input_is_1d = prompt.shape.rank == 1 @@ -214,13 +255,14 @@ def __call__( return tf.squeeze(prompt, axis=0) if input_is_1d else prompt + @format_docstring(sample_args=sample_args_docstring) def sample( self, token_probability_fn, prompt, mask, num_steps, from_logits=True ): """Sampling logic implementation. Args: - {{sample_keyword_docstring}} + {{sample_args}} Returns: A dense int Tensor, representing the generated text in token id @@ -232,48 +274,3 @@ def get_config(self): return { "jit_compile": self.jit_compile, } - - -base_sampler_keyword_args = """ - jit_compile: bool, defaults to True. If True, XLA compilation will be used. - """ - -call_keyword_docstring = """ - token_probability_fn: a function that generates the probability of - the next token over the whole vocabulary for each input token. - prompt: a list of integers or an integer Tensor, can be 1D or 2D. The - initial tokens to append generated tokens. - max_length: int. The max length of generated sequence. - padding_mask: a tensor, defaults to None. The padding mask of the prompt. - end_token_id: int, defaults to None. The token marking the end of the - sequence, once encountered the generation is finished for the exact - sequence. If None, every sequence is generated up to `max_length`. - If set, all tokens after encountering `end_token_id` will be - replaced with `pad_token_id`. 
- from_logits: bool, defaults to True. Indicate if the `token_probability_fn` - returns logits. If False, `token_probability_fn` returns probability - distributions. - """ - -sample_keyword_docstring = """ - token_probability_fn: a function that generates the probability of - the next token over the whole vocabulary for each input token. - prompt: a dense int Tensor of shape [batch_size, max_length]. The - placeholder for generated sequence. - mask: a dense bool Tensor of shape [batch_size, max_length]. The mask of - prompt. - num_steps: int. The remaining number of tokens to generate. - from_logits: bool, defaults to True. Indicate if the `token_probability_fn` - returns logits. If False, `token_probability_fn` returns probability - distributions. - """ - -Sampler.__doc__ = Sampler.__doc__.replace( - "{{base_sampler_keyword_args}}", base_sampler_keyword_args -) -Sampler.__doc__ = Sampler.__doc__.replace( - "{{call_keyword_docstring}}", call_keyword_docstring -) -Sampler.sample.__doc__ = Sampler.sample.__doc__.replace( - "{{sample_keyword_docstring}}", sample_keyword_docstring -) diff --git a/keras_nlp/samplers/sampler_test.py b/keras_nlp/samplers/sampler_test.py index 2e4ed4a1a7..f88c6ad0ab 100644 --- a/keras_nlp/samplers/sampler_test.py +++ b/keras_nlp/samplers/sampler_test.py @@ -24,7 +24,7 @@ def test_serialization(self): sampler = keras_nlp.samplers.Greedy() config = keras_nlp.samplers.serialize(sampler) expected_config = { - "class_name": "Greedy", + "class_name": "keras_nlp>Greedy", "config": { "jit_compile": True, }, @@ -37,11 +37,6 @@ def test_deserialization(self): sampler = keras_nlp.samplers.get(identifier) self.assertIsInstance(sampler, Greedy) - # Test string is not case-sensitive. - identifier = "Greedy" - sampler = keras_nlp.samplers.get(identifier) - self.assertIsInstance(sampler, Greedy) - # Test dict identifier. original_sampler = keras_nlp.samplers.Greedy(jit_compile=False) config = keras_nlp.samplers.serialize(original_sampler) From bb430dd905b6fcb67b80d2d0b50f55a4ad6b600d Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Mon, 9 Jan 2023 12:36:29 -0800 Subject: [PATCH 14/28] Address comments: move token_probability_fn to the second place --- keras_nlp/samplers/__init__.py | 2 +- keras_nlp/samplers/greedy.py | 4 ++-- keras_nlp/samplers/greedy_test.py | 16 ++++++++-------- keras_nlp/samplers/sampler.py | 27 ++++++++++++--------------- 4 files changed, 23 insertions(+), 26 deletions(-) diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py index 5125a571b0..3f9e6cd8c2 100644 --- a/keras_nlp/samplers/__init__.py +++ b/keras_nlp/samplers/__init__.py @@ -46,7 +46,7 @@ def get(identifier): dict containing `class_name` and `config` as an identifier. Also note that the `class_name` must map to a `Sampler` class. - >>> cfg = {'class_name': 'Greedy', 'config': {}} + >>> cfg = {'class_name': 'keras_nlp>Greedy', 'config': {}} >>> sampler = keras_nlp.samplers.get(cfg) In the case that the `identifier` is a class, this method will return a new diff --git a/keras_nlp/samplers/greedy.py b/keras_nlp/samplers/greedy.py index 60fe732126..398f719b6b 100644 --- a/keras_nlp/samplers/greedy.py +++ b/keras_nlp/samplers/greedy.py @@ -67,7 +67,7 @@ def token_probability_fn(inputs, mask): sampler = keras_nlp.samplers.Greedy() # Print the generated sequence (token ids). 
- print(sampler(token_probability_fn, prompt, 10)) + print(sampler(prompt, token_probability_fn, 10)) ``` """ @@ -79,7 +79,7 @@ def __init__( @format_docstring(sample_args=sample_args_docstring) def sample( - self, token_probability_fn, prompt, mask, num_steps, from_logits=True + self, prompt, token_probability_fn, mask, num_steps, from_logits=True ): """Sampling logic implementation. diff --git a/keras_nlp/samplers/greedy_test.py b/keras_nlp/samplers/greedy_test.py index 0472bdcbc9..bd77b2490f 100644 --- a/keras_nlp/samplers/greedy_test.py +++ b/keras_nlp/samplers/greedy_test.py @@ -48,17 +48,17 @@ def token_probability_fn(inputs, mask): def test_generate_with_1d_prompt(self): inputs = tf.constant([1]) - outputs = self.sampler(self.token_probability_fn, inputs, max_length=5) + outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) self.assertEqual(outputs.shape, [5]) def test_generate_with_2d_prompt(self): inputs = tf.constant([[1], [1]]) - outputs = self.sampler(self.token_probability_fn, inputs, max_length=5) + outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) self.assertEqual(outputs.shape, [2, 5]) def test_generate_with_list_prompt(self): inputs = [[1], [1]] - outputs = self.sampler(self.token_probability_fn, inputs, max_length=5) + outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) self.assertEqual(outputs.shape, [2, 5]) def test_generate_with_ragged_prompt(self): @@ -71,7 +71,7 @@ def token_probability_fn(inputs, mask): return tf.repeat(tf.repeat(prob, 2, axis=0), max_length, axis=1) inputs = tf.ragged.constant([[1], [2, 1, 2]]) - outputs = self.sampler(token_probability_fn, inputs, max_length) + outputs = self.sampler(inputs, token_probability_fn, max_length) self.assertEqual(outputs.shape, [2, 5]) def test_assert_generation_is_correct(self): @@ -86,7 +86,7 @@ def token_probability_fn(inputs, mask): inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32) outputs = self.sampler( - token_probability_fn, inputs, max_length=max_length + inputs, token_probability_fn, max_length=max_length ) self.assertAllEqual( outputs, 3 * tf.ones(shape=[batch_size, max_length]) @@ -105,8 +105,8 @@ def token_probability_fn(inputs, mask): sampler = Greedy() inputs = tf.constant([[0, 1], [1, 2]]) outputs = sampler( - token_probability_fn, inputs, + token_probability_fn, max_length=max_length, end_token_id=2, ) @@ -117,12 +117,12 @@ def test_compare_xla_noxla_results(self): inputs = [[1], [1]] xla_sampler = Greedy(jit_compile=True) outputs_xla = xla_sampler( - self.token_probability_fn, inputs, max_length=5 + inputs, self.token_probability_fn, max_length=5 ) xla_sampler = Greedy(jit_compile=False) outputs_no_xla = xla_sampler( - self.token_probability_fn, inputs, max_length=5 + inputs, self.token_probability_fn, max_length=5 ) self.assertAllEqual(outputs_xla, outputs_no_xla) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 8f208ff10a..c265bf9172 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -23,12 +23,12 @@ """ call_args_docstring = """ - token_probability_fn: a function that generates the probability of - the next token over the whole vocabulary for each input token. prompt: a list of integers or an integer Tensor, can be 1D or 2D. The initial tokens to append generated tokens. + token_probability_fn: a function that generates the probability of + the next token over the whole vocabulary for each input token. max_length: int. The max length of generated sequence. 
- padding_mask: a tensor, defaults to None. The padding mask of the prompt. + mask: a tensor, defaults to None. The padding mask of the prompt. end_token_id: int, defaults to None. The token marking the end of the sequence, once encountered the generation is finished for the exact sequence. If None, every sequence is generated up to `max_length`. @@ -40,10 +40,10 @@ """ sample_args_docstring = """ - token_probability_fn: a function that generates the probability of - the next token over the whole vocabulary for each input token. prompt: a dense int Tensor of shape [batch_size, max_length]. The placeholder for generated sequence. + token_probability_fn: a function that generates the probability of + the next token over the whole vocabulary for each input token. mask: a dense bool Tensor of shape [batch_size, max_length]. The mask of prompt. num_steps: int. The remaining number of tokens to generate. @@ -100,7 +100,7 @@ def token_probability_fn(inputs, mask): sampler = keras_nlp.samplers.Greedy() # Print the generated sequence (token ids). - print(sampler(token_probability_fn, prompt, 10, end_token_id=END_ID)) + print(sampler(prompt, token_probability_fn, 10, end_token_id=END_ID)) ``` Use with string inputs: @@ -131,8 +131,8 @@ def token_probability_fn(inputs, mask): prompt = tokenizer("the quick brown fox") sampler = keras_nlp.samplers.Greedy() generated = sampler( - token_probability_fn, prompt, + token_probability_fn, 10, end_token_id=tokenizer.token_to_id("[END]") ) @@ -210,17 +210,14 @@ def _mask_tokens_after_end_token( def __call__( self, - token_probability_fn, prompt, + token_probability_fn, max_length, - padding_mask=None, + mask=None, end_token_id=None, from_logits=True, ): - prompt, padding_mask = self._validate_prompt_and_mask( - prompt, - padding_mask, - ) + prompt, mask = self._validate_prompt_and_mask(prompt, mask) input_is_1d = prompt.shape.rank == 1 if input_is_1d: @@ -238,8 +235,8 @@ def __call__( # `jit_compile` accordingly. sample = tf.function(self.sample, jit_compile=self.jit_compile) prompt = sample( - token_probability_fn, prompt, + token_probability_fn, mask, max_length - shortest_prompt_len, from_logits, @@ -257,7 +254,7 @@ def __call__( @format_docstring(sample_args=sample_args_docstring) def sample( - self, token_probability_fn, prompt, mask, num_steps, from_logits=True + self, prompt, token_probability_fn, mask, num_steps, from_logits=True ): """Sampling logic implementation. From afd3082be0e3bdd1e48ef5a4e49c3fa25cf30376 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Tue, 10 Jan 2023 13:41:35 -0800 Subject: [PATCH 15/28] some initials --- keras_nlp/samplers/top_k.py | 137 ++++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 keras_nlp/samplers/top_k.py diff --git a/keras_nlp/samplers/top_k.py b/keras_nlp/samplers/top_k.py new file mode 100644 index 0000000000..2f424baed7 --- /dev/null +++ b/keras_nlp/samplers/top_k.py @@ -0,0 +1,137 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
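The `format_docstring` helper imported in the hunks above lives in `keras_nlp.utils.python_utils` and is not included in this patch series. A minimal sketch of what such a decorator could look like, assuming it does nothing more than substitute `{{placeholder}}` markers in `__doc__` (this is a hypothetical reconstruction, not the actual KerasNLP implementation):

```python
def format_docstring(**replacements):
    """Substitute `{{key}}` markers in an object's docstring.

    Hypothetical sketch; the real helper may differ.
    """

    def decorator(obj):
        doc = obj.__doc__
        if doc is not None:
            for key, value in replacements.items():
                # Replace each `{{key}}` marker with its docstring snippet.
                doc = doc.replace("{{" + key + "}}", value)
            obj.__doc__ = doc
        return obj

    return decorator
```

A decorator along these lines would explain why the classes above can keep `{{base_sampler_args}}` and `{{call_args}}` markers in their docstrings instead of the manual `__doc__.replace(...)` calls these commits delete.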
+"""Top-k Sampler.""" + +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.samplers.sampler import Sampler +from keras_nlp.samplers.sampler import base_sampler_args_docstring +from keras_nlp.samplers.sampler import call_args_docstring +from keras_nlp.samplers.sampler import sample_args_docstring +from keras_nlp.utils.python_utils import format_docstring + + +@format_docstring( + base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring +) +@keras.utils.register_keras_serializable(package="keras_nlp") +class TopK(Sampler): + """Top-K Sampler class. + + This sampler implements top-k search algorithm. + + Args: + k: int, the `k` value in top-k. + seed: int, defaults to None. The random seed. + {{base_sampler_args}} + Call Args: + {{call_args}} + """ + + def __init__( + self, + k, + seed=None, + jit_compile=True, + ): + self.k = k + self.seed = seed + super().__init__(jit_compile) + + @format_docstring(sample_args=sample_args_docstring) + def sample( + self, prompt, token_probability_fn, mask, num_steps, from_logits=True + ): + """Sampling logic implementation. + + Args: + {{sample_args}} + """ + batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] + max_length = tf.cast(max_length, num_steps.dtype) + length = max_length - num_steps + + def one_step(length, prompt, mask): + probs = token_probability_fn(prompt, mask) + pred = tf.gather( + probs, tf.repeat(length - 1, batch_size), axis=1, batch_dims=1 + ) + if self.from_logits: + pred = keras.activations.softmax(pred, axis=-1) + # Sort preds in descending order. + sorted_preds, sorted_indices = tf.math.top_k( + pred, k=tf.shape(pred)[1], sorted=True + ) + # Calculate cumulative probability distribution. + cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1) + # Create a mask for the tokens to keep. + keep_mask = cumulative_probs <= self.p + # Shift to include the last token that exceed p. + shifted_keep_mask = tf.concat( + [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1 + ) + # Filter out unmasked tokens and sample from filtered distribution. + probs = tf.where( + shifted_keep_mask, + sorted_preds, + tf.zeros(tf.shape(pred), dtype=sorted_preds.dtype), + ) + sorted_next_token = tf.random.categorical( + tf.math.log(probs), 1, seed=self.seed + ) + next_token = tf.gather_nd( + sorted_indices, sorted_next_token, batch_dims=1 + ) + next_token = tf.cast(next_token, dtype=prompt.dtype) + next_token = tf.where( + mask[:, length], prompt[:, length], next_token + ) + + mask = tf.tensor_scatter_nd_update( + tensor=mask, + indices=tf.stack( + ( + tf.cast(tf.range(batch_size), dtype=length.dtype), + tf.repeat(length, batch_size), + ), + axis=1, + ), + updates=tf.repeat(True, batch_size), + ) + + # Append the next token to current sequence. + prompt = tf.tensor_scatter_nd_update( + tensor=prompt, + indices=tf.stack( + ( + tf.cast(tf.range(batch_size), dtype=length.dtype), + tf.repeat(length, batch_size), + ), + axis=1, + ), + updates=next_token, + ) + + length = tf.add(length, 1) + return (length, prompt, mask) + + # Run a while loop till text of length `max_length` has been generated. 
+ length, prompt, mask = tf.while_loop( + cond=lambda length, prompt, mask: tf.less(length, max_length), + body=one_step, + loop_vars=(length, prompt, mask), + ) + + return prompt From 31ad970e86c59945ba1b3c18fd480d7d24b6b63b Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Thu, 12 Jan 2023 17:11:44 -0800 Subject: [PATCH 16/28] add more sampler class, and a few changes on the base sampler class --- keras_nlp/samplers/__init__.py | 13 +- keras_nlp/samplers/beam_sampler.py | 215 ++++++++++++++++++ keras_nlp/samplers/beam_sampler_test.py | 137 +++++++++++ .../samplers/{greedy.py => greedy_sampler.py} | 7 +- ...{greedy_test.py => greedy_sampler_test.py} | 12 +- keras_nlp/samplers/sampler.py | 18 +- keras_nlp/samplers/sampler_test.py | 12 +- keras_nlp/samplers/top_k_sampler.py | 159 +++++++++++++ keras_nlp/samplers/top_k_sampler_test.py | 160 +++++++++++++ .../samplers/{top_k.py => top_p_sampler.py} | 58 ++++- keras_nlp/samplers/top_p_sampler_test.py | 160 +++++++++++++ 11 files changed, 919 insertions(+), 32 deletions(-) create mode 100644 keras_nlp/samplers/beam_sampler.py create mode 100644 keras_nlp/samplers/beam_sampler_test.py rename keras_nlp/samplers/{greedy.py => greedy_sampler.py} (96%) rename keras_nlp/samplers/{greedy_test.py => greedy_sampler_test.py} (93%) create mode 100644 keras_nlp/samplers/top_k_sampler.py create mode 100644 keras_nlp/samplers/top_k_sampler_test.py rename keras_nlp/samplers/{top_k.py => top_p_sampler.py} (74%) create mode 100644 keras_nlp/samplers/top_p_sampler_test.py diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py index 3f9e6cd8c2..89911a78f4 100644 --- a/keras_nlp/samplers/__init__.py +++ b/keras_nlp/samplers/__init__.py @@ -14,7 +14,11 @@ from tensorflow import keras -from keras_nlp.samplers.greedy import Greedy +from keras_nlp.samplers.beam_sampler import BeamSampler +from keras_nlp.samplers.greedy_sampler import GreedySampler +from keras_nlp.samplers.greedy_sampler import Sampler +from keras_nlp.samplers.top_k_sampler import TopKSampler +from keras_nlp.samplers.top_p_sampler import TopPSampler def serialize(sampler): @@ -24,7 +28,10 @@ def serialize(sampler): def deserialize(config, custom_objects=None): """Return a `Sampler` object from its config.""" all_classes = { - "greedy": Greedy, + "beam": BeamSampler, + "greedy": GreedySampler, + "top_k": TopKSampler, + "top_p": TopPSampler, } return keras.utils.deserialize_keras_object( config, @@ -46,7 +53,7 @@ def get(identifier): dict containing `class_name` and `config` as an identifier. Also note that the `class_name` must map to a `Sampler` class. - >>> cfg = {'class_name': 'keras_nlp>Greedy', 'config': {}} + >>> cfg = {'class_name': 'keras_nlp>GreedySampler', 'config': {}} >>> sampler = keras_nlp.samplers.get(cfg) In the case that the `identifier` is a class, this method will return a new diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py new file mode 100644 index 0000000000..4533186633 --- /dev/null +++ b/keras_nlp/samplers/beam_sampler.py @@ -0,0 +1,215 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Beam Sampler.""" + +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.samplers.sampler import Sampler +from keras_nlp.samplers.sampler import base_sampler_args_docstring +from keras_nlp.samplers.sampler import call_args_docstring +from keras_nlp.samplers.sampler import sample_args_docstring +from keras_nlp.utils.python_utils import format_docstring + + +@format_docstring( + base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring +) +@keras.utils.register_keras_serializable(package="keras_nlp") +class BeamSampler(Sampler): + """Beam Sampler class. + + This sampler implements beam search algorithm. At each time-step, beam + search keeps the beams (sequences) of the top `num_beams` highest + accumulated probabilities, and uses each one of the beams to predict + candidate next tokens. + + Args: + num_beams: int. The number of beams that should be kept at each + time-step. `num_beams` should be strictly positive. + {{base_sampler_args}} + + Call Args: + {{call_args}} + + Examples: + ```python + BATCH_SIZE = 8 + VOCAB_SIZE = 10 + FEATURE_SIZE = 16 + START_ID = 1 + + # Create a dummy model to predict the next token. + model = keras.Sequential( + [ + keras.Input(shape=[None]), + keras.layers.Embedding( + input_dim=VOCAB_SIZE, + output_dim=FEATURE_SIZE, + ), + keras.layers.Dense(VOCAB_SIZE, activation="softmax"), + ] + ) + + # Define a function that outputs the next token's probability for each token + # in the input sequence. + def token_probability_fn(inputs, mask): + return model(inputs) + + prompt = tf.fill((BATCH_SIZE, 1), START_ID) + + sampler = keras_nlp.samplers.BeamSampler(num_beams=3) + # Print the generated sequence (token ids). + print(sampler(prompt, token_probability_fn, 10)) + ``` + """ + + def __init__( + self, + num_beams, + jit_compile=True, + run_eagerly=False, + ): + self.num_beams = num_beams + super().__init__(jit_compile, run_eagerly) + + @format_docstring(sample_args=sample_args_docstring) + def sample( + self, prompt, token_probability_fn, mask, num_steps, from_logits=True + ): + """Sampling logic implementation. + + Args: + {{sample_args}} + """ + batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] + max_length = tf.cast(max_length, num_steps.dtype) + length = max_length - num_steps + dummy_preds = token_probability_fn(prompt, mask=mask) + vocab_size = tf.shape(dummy_preds)[-1] + pred_dtype = dummy_preds.dtype + + num_beams = self.num_beams + + # Initialize beam with shape `(batch_size, num_beams, length)`. + beams = tf.repeat(tf.expand_dims(prompt, axis=1), num_beams, axis=1) + # Initialize `beams_prob` with shape `(batch_size, num_beams)`. + beams_prob = tf.zeros([batch_size, 1], dtype=pred_dtype) + beams_prob = tf.concat( + [beams_prob, tf.fill((batch_size, num_beams - 1), pred_dtype.min)], + axis=-1, + ) + + def one_step(beams, beams_prob, length, mask): + + flattened_beams = tf.reshape( + beams, shape=[batch_size * num_beams, -1] + ) + repeated_mask = tf.tile(mask, [num_beams, 1]) + probs = token_probability_fn(flattened_beams, repeated_mask) + preds = tf.gather( + probs, + tf.repeat(length - 1, batch_size * num_beams), + axis=1, + batch_dims=1, + ) + if from_logits: + preds = keras.activations.softmax(preds, axis=-1) + # Reshape `preds` to shape `(batch_size, num_beams * vocab_size)`. 
+ + preds = tf.reshape(preds, shape=[batch_size, -1]) + + cum_probs = tf.math.log(preds) + tf.repeat( + beams_prob, repeats=vocab_size, axis=1 + ) + + candidate_prob, candidate_indexes = tf.math.top_k( + cum_probs, k=num_beams, sorted=False + ) + + candidate_beam_indexes = candidate_indexes // vocab_size + next_token = candidate_indexes % vocab_size + + beams = tf.gather( + beams, candidate_beam_indexes, axis=1, batch_dims=1 + ) + + # Build a new column of updates to scatter into the beam tensor. + next_token = tf.where( + condition=mask[..., length, tf.newaxis], + x=beams[..., length], + y=next_token, + ) + next_token = tf.reshape(next_token, shape=[-1]) + + mask = tf.tensor_scatter_nd_update( + tensor=mask, + indices=tf.stack( + ( + tf.cast(tf.range(batch_size), dtype=length.dtype), + tf.repeat(length, batch_size), + ), + axis=1, + ), + updates=tf.repeat(True, batch_size), + ) + + # Generate `(batch_index, beam_index)` tuples for each beam. + beam_indices = tf.where(tf.ones((batch_size, num_beams), tf.bool)) + beam_indices = tf.cast(beam_indices, dtype=length.dtype) + # Build a tensor of repeated `length` values. + length_indices = tf.fill((batch_size * num_beams, 1), length) + # Concatenate to a triplet of `(batch_index, beam_index, length)`. + indices = tf.concat([beam_indices, length_indices], axis=-1) + + # Update `beams[:, :, length]` with `next_token`. + beams = tf.tensor_scatter_nd_update( + tensor=beams, + indices=indices, + updates=next_token, + ) + + beams_prob = candidate_prob + + length = tf.add(length, 1) + return beams, beams_prob, length, mask + + # Run a while loop till text of length `max_length` has been generated. + beams, beams_prob, length, mask = tf.while_loop( + cond=lambda beams, beams_prob, length, mask: tf.less( + length, max_length + ), + body=one_step, + loop_vars=[beams, beams_prob, length, mask], + # There is a strange issue that when `batch_size=1`, the first loop + # iteration changes `beams_prob`'s shape from [1, None] to + # [None, None], which does not happen for `batch_size>1`. + # As a workaround, we set shape invariants. + shape_invariants=[ + beams.get_shape(), + tf.TensorShape([None, None]), + length.get_shape(), + mask.get_shape(), + ], + ) + + # Get the beam with the maximum probability. + max_indexes = tf.math.argmax(beams_prob, axis=-1) + max_beams = tf.gather( + beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1 + ) + + prompt = tf.squeeze(max_beams, axis=1) + + return prompt diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py new file mode 100644 index 0000000000..82abc889e3 --- /dev/null +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -0,0 +1,137 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
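The index arithmetic in `one_step` above is the core of the beam update: `top_k` is taken over scores flattened to `num_beams * vocab_size`, so each flat index encodes both the source beam and the chosen token. A small self-contained sketch with made-up numbers (illustrative only, not part of the patch):

```python
import tensorflow as tf

vocab_size = 4
# Log-probabilities for 2 beams over a 4-token vocabulary, flattened to
# [beam0_token0, ..., beam0_token3, beam1_token0, ..., beam1_token3].
cum_probs = tf.constant([[-2.0, -0.5, -3.0, -1.0, -0.7, -4.0, -1.5, -2.5]])

# Keep the best 2 candidates across all beam/token combinations.
_, candidate_indexes = tf.math.top_k(cum_probs, k=2, sorted=False)

# Integer division recovers the source beam; the remainder recovers the
# token id. Here flat indexes {1, 4} decode to (beam 0, token 1) and
# (beam 1, token 0).
candidate_beam_indexes = candidate_indexes // vocab_size
next_token = candidate_indexes % vocab_size
```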
+"""Tests for Beam sampler.""" + +import random + +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.samplers.beam_sampler import BeamSampler +from keras_nlp.samplers.greedy_sampler import GreedySampler + + +class BeamSamplerTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + self.vocab_size = 10 + self.feature_size = 16 + + # Create a dummy model to predict the next token. + model = keras.Sequential( + [ + keras.Input(shape=[None]), + keras.layers.Embedding( + input_dim=self.vocab_size, + output_dim=self.feature_size, + ), + keras.layers.Dense(self.vocab_size), + keras.layers.Softmax(), + ] + ) + + def token_probability_fn(inputs, mask): + return model(inputs) + + self.token_probability_fn = token_probability_fn + self.sampler = BeamSampler(num_beams=2) + + def test_generate_with_1d_prompt(self): + inputs = tf.constant([1]) + outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) + self.assertEqual(outputs.shape, [5]) + + def test_generate_with_2d_prompt(self): + inputs = tf.constant([[1], [1]]) + outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) + self.assertEqual(outputs.shape, [2, 5]) + + def test_generate_with_list_prompt(self): + inputs = [[1], [1]] + outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) + self.assertEqual(outputs.shape, [2, 5]) + + def test_generate_with_ragged_prompt(self): + def token_probability_fn(inputs, mask): + batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1] + prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) + return tf.tile(prob, [batch_size, seq_length, 1]) + + inputs = tf.ragged.constant([[1], [2, 1, 2]]) + outputs = self.sampler(inputs, token_probability_fn, max_length=5) + self.assertEqual(outputs.shape, [2, 5]) + + def test_one_beam_generation(self): + for _ in range(5): + inputs = tf.constant([random.randint(0, 9)]) + beam_sampler = BeamSampler(num_beams=1) + greedy_sampler = GreedySampler() + beam_output = beam_sampler( + inputs, + self.token_probability_fn, + max_length=5, + ) + greedy_output = greedy_sampler( + inputs, + self.token_probability_fn, + max_length=5, + ) + self.assertAllEqual(beam_output, greedy_output) + + @parameterized.named_parameters( + ("xla_graph", True, False), + ("non_xla_graph", False, False), + ("eager", False, True), + ) + def test_assert_generation_is_correct(self, jit_compile, run_eagerly): + def token_probability_fn(inputs, mask): + batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1] + prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) + return tf.tile(prob, [batch_size, seq_length, 1]) + + batch_size = 10 + inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32) + max_length = 3 + for i in range(1, 5): + sampler = BeamSampler( + num_beams=i, + jit_compile=jit_compile, + run_eagerly=run_eagerly, + ) + outputs = sampler( + inputs, + token_probability_fn, + max_length=max_length, + ) + self.assertAllEqual( + outputs, 3 * tf.ones(shape=[batch_size, max_length]) + ) + + def test_end_token_id(self): + def token_probability_fn(inputs, mask): + batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1] + prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) + return tf.tile(prob, [batch_size, seq_length, 1]) + + max_length = 5 + inputs = tf.constant([[0, 1], [1, 2]]) + outputs = self.sampler( + inputs, + token_probability_fn, + max_length=max_length, + end_token_id=2, + ) + expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]]) + 
self.assertAllEqual(outputs, expected_outputs) diff --git a/keras_nlp/samplers/greedy.py b/keras_nlp/samplers/greedy_sampler.py similarity index 96% rename from keras_nlp/samplers/greedy.py rename to keras_nlp/samplers/greedy_sampler.py index 398f719b6b..d83fdc56d8 100644 --- a/keras_nlp/samplers/greedy.py +++ b/keras_nlp/samplers/greedy_sampler.py @@ -27,7 +27,7 @@ base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring ) @keras.utils.register_keras_serializable(package="keras_nlp") -class Greedy(Sampler): +class GreedySampler(Sampler): """Greedy sampler class. This sampler is implemented on greedy search, i.e., always picking up the @@ -65,7 +65,7 @@ def token_probability_fn(inputs, mask): prompt = tf.fill((BATCH_SIZE, 1), START_ID) - sampler = keras_nlp.samplers.Greedy() + sampler = keras_nlp.samplers.GreedySampler() # Print the generated sequence (token ids). print(sampler(prompt, token_probability_fn, 10)) ``` @@ -74,8 +74,9 @@ def token_probability_fn(inputs, mask): def __init__( self, jit_compile=True, + run_eagerly=False, ): - super().__init__(jit_compile) + super().__init__(jit_compile, run_eagerly) @format_docstring(sample_args=sample_args_docstring) def sample( diff --git a/keras_nlp/samplers/greedy_test.py b/keras_nlp/samplers/greedy_sampler_test.py similarity index 93% rename from keras_nlp/samplers/greedy_test.py rename to keras_nlp/samplers/greedy_sampler_test.py index bd77b2490f..b31734d75b 100644 --- a/keras_nlp/samplers/greedy_test.py +++ b/keras_nlp/samplers/greedy_sampler_test.py @@ -17,10 +17,10 @@ from absl.testing import parameterized from tensorflow import keras -from keras_nlp.samplers.greedy import Greedy +from keras_nlp.samplers.greedy_sampler import GreedySampler -class GreedyTest(tf.test.TestCase, parameterized.TestCase): +class GreedySamplerTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): super().setUp() self.vocab_size = 10 @@ -44,7 +44,7 @@ def token_probability_fn(inputs, mask): self.token_probability_fn = token_probability_fn - self.sampler = Greedy() + self.sampler = GreedySampler() def test_generate_with_1d_prompt(self): inputs = tf.constant([1]) @@ -102,7 +102,7 @@ def token_probability_fn(inputs, mask): tf.repeat(prob, batch_size, axis=0), max_length, axis=1 ) - sampler = Greedy() + sampler = GreedySampler() inputs = tf.constant([[0, 1], [1, 2]]) outputs = sampler( inputs, @@ -115,12 +115,12 @@ def token_probability_fn(inputs, mask): def test_compare_xla_noxla_results(self): inputs = [[1], [1]] - xla_sampler = Greedy(jit_compile=True) + xla_sampler = GreedySampler(jit_compile=True) outputs_xla = xla_sampler( inputs, self.token_probability_fn, max_length=5 ) - xla_sampler = Greedy(jit_compile=False) + xla_sampler = GreedySampler(jit_compile=False) outputs_no_xla = xla_sampler( inputs, self.token_probability_fn, max_length=5 ) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index c265bf9172..17137d6684 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -20,6 +20,8 @@ base_sampler_args_docstring = """ jit_compile: bool, defaults to True. If True, XLA compilation will be used. + run_eagerly: bool, defaults to False. If True, the sampler will run in + the eager mode. 
""" call_args_docstring = """ @@ -143,8 +145,16 @@ def token_probability_fn(inputs, mask): def __init__( self, jit_compile=True, + run_eagerly=False, ): + if run_eagerly and jit_compile: + raise ValueError( + "XLA cannot be turned on under eager mode, received " + "`jit_compile=True` and `run_eagerly=True`. Please either set " + "`jit_compile=False` or set `run_eagerly=False`." + ) self.jit_compile = jit_compile + self.run_eagerly = run_eagerly def _validate_prompt_and_mask(self, prompt, mask): """Helper method to validate input prompt.""" @@ -231,9 +241,11 @@ def __call__( prompt, mask = self._pad_prompt(prompt, max_length) self._validate_token_probability_fn(token_probability_fn, prompt, mask) - # Convert `sample` method to a `tf.function`, and turn on - # `jit_compile` accordingly. - sample = tf.function(self.sample, jit_compile=self.jit_compile) + # Convert `sample` method to a `tf.function` if `self.run_eagerly=False` + # , and turn on `jit_compile` accordingly. + sample = self.sample + if not self.run_eagerly: + sample = tf.function(self.sample, jit_compile=self.jit_compile) prompt = sample( prompt, token_probability_fn, diff --git a/keras_nlp/samplers/sampler_test.py b/keras_nlp/samplers/sampler_test.py index f88c6ad0ab..860bcbad3e 100644 --- a/keras_nlp/samplers/sampler_test.py +++ b/keras_nlp/samplers/sampler_test.py @@ -16,15 +16,15 @@ import tensorflow as tf import keras_nlp -from keras_nlp.samplers.greedy import Greedy +from keras_nlp.samplers.greedy_sampler import GreedySampler class SamplerTest(tf.test.TestCase): def test_serialization(self): - sampler = keras_nlp.samplers.Greedy() + sampler = keras_nlp.samplers.GreedySampler() config = keras_nlp.samplers.serialize(sampler) expected_config = { - "class_name": "keras_nlp>Greedy", + "class_name": "keras_nlp>GreedySampler", "config": { "jit_compile": True, }, @@ -35,10 +35,10 @@ def test_deserialization(self): # Test get from string. identifier = "greedy" sampler = keras_nlp.samplers.get(identifier) - self.assertIsInstance(sampler, Greedy) + self.assertIsInstance(sampler, GreedySampler) # Test dict identifier. - original_sampler = keras_nlp.samplers.Greedy(jit_compile=False) + original_sampler = keras_nlp.samplers.GreedySampler(jit_compile=False) config = keras_nlp.samplers.serialize(original_sampler) restored_sampler = keras_nlp.samplers.get(config) self.assertDictEqual( @@ -47,6 +47,6 @@ def test_deserialization(self): ) # Test identifier is already a sampler instance. - original_sampler = keras_nlp.samplers.Greedy(jit_compile=False) + original_sampler = keras_nlp.samplers.GreedySampler(jit_compile=False) restored_sampler = keras_nlp.samplers.get(original_sampler) self.assertEqual(original_sampler, restored_sampler) diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py new file mode 100644 index 0000000000..8224dd2d46 --- /dev/null +++ b/keras_nlp/samplers/top_k_sampler.py @@ -0,0 +1,159 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Top-k Sampler.""" + +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.samplers.sampler import Sampler +from keras_nlp.samplers.sampler import base_sampler_args_docstring +from keras_nlp.samplers.sampler import call_args_docstring +from keras_nlp.samplers.sampler import sample_args_docstring +from keras_nlp.utils.python_utils import format_docstring + + +@format_docstring( + base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring +) +@keras.utils.register_keras_serializable(package="keras_nlp") +class TopKSampler(Sampler): + """Top-K Sampler class. + + This sampler implements top-k search algorithm. Briefly top-k algorithm + randomly selects a token from the tokens of top K probability, with + selection chance determined by the probability. + + Args: + k: int, the `k` value of top-k. + seed: int, defaults to None. The random seed. + {{base_sampler_args}} + + Call Args: + {{call_args}} + + Examples: + ```python + BATCH_SIZE = 8 + VOCAB_SIZE = 10 + FEATURE_SIZE = 16 + START_ID = 1 + + # Create a dummy model to predict the next token. + model = keras.Sequential( + [ + keras.Input(shape=[None]), + keras.layers.Embedding( + input_dim=VOCAB_SIZE, + output_dim=FEATURE_SIZE, + ), + keras.layers.Dense(VOCAB_SIZE, activation="softmax"), + ] + ) + + # Define a function that outputs the next token's probability for each token + # in the input sequence. + def token_probability_fn(inputs, mask): + return model(inputs) + + prompt = tf.fill((BATCH_SIZE, 1), START_ID) + + sampler = keras_nlp.samplers.TopKSampler(k=5) + # Print the generated sequence (token ids). + print(sampler(prompt, token_probability_fn, 10)) + ``` + """ + + def __init__( + self, + k, + seed=None, + jit_compile=True, + run_eagerly=False, + ): + self.k = k + self.seed = seed + super().__init__(jit_compile, run_eagerly) + + @format_docstring(sample_args=sample_args_docstring) + def sample( + self, prompt, token_probability_fn, mask, num_steps, from_logits=True + ): + """Sampling logic implementation. + + Args: + {{sample_args}} + """ + batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] + max_length = tf.cast(max_length, num_steps.dtype) + length = max_length - num_steps + + def one_step(length, prompt, mask): + probs = token_probability_fn(prompt, mask) + pred = tf.gather( + probs, tf.repeat(length - 1, batch_size), axis=1, batch_dims=1 + ) + if from_logits: + pred = keras.activations.softmax(pred, axis=-1) + # Filter out top-k tokens. + top_k_pred, top_k_indices = tf.math.top_k( + pred, k=self.k, sorted=False + ) + # Sample the next token from the probability distribution. + next_token = tf.random.categorical( + tf.math.log(top_k_pred), 1, seed=self.seed + ) + + # Rearrange to get the next token idx from the original order. + next_token = tf.gather_nd(top_k_indices, next_token, batch_dims=1) + next_token = tf.cast(next_token, dtype=prompt.dtype) + next_token = tf.where( + mask[:, length], prompt[:, length], next_token + ) + + mask = tf.tensor_scatter_nd_update( + tensor=mask, + indices=tf.stack( + ( + tf.cast(tf.range(batch_size), dtype=length.dtype), + tf.repeat(length, batch_size), + ), + axis=1, + ), + updates=tf.repeat(True, batch_size), + ) + + # Append the next token to current sequence. 
+ prompt = tf.tensor_scatter_nd_update( + tensor=prompt, + indices=tf.stack( + ( + tf.cast(tf.range(batch_size), dtype=length.dtype), + tf.repeat(length, batch_size), + ), + axis=1, + ), + updates=next_token, + ) + + length = tf.add(length, 1) + return (length, prompt, mask) + + # Run a while loop till text of length `max_length` has been generated. + length, prompt, mask = tf.while_loop( + cond=lambda length, prompt, mask: tf.less(length, max_length), + body=one_step, + loop_vars=(length, prompt, mask), + ) + + return prompt diff --git a/keras_nlp/samplers/top_k_sampler_test.py b/keras_nlp/samplers/top_k_sampler_test.py new file mode 100644 index 0000000000..4cca77f602 --- /dev/null +++ b/keras_nlp/samplers/top_k_sampler_test.py @@ -0,0 +1,160 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Top-K sampler.""" + +import numpy as np +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.samplers.top_k_sampler import TopKSampler + + +class TopKSamplerTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + self.vocab_size = 10 + self.feature_size = 16 + + # Create a dummy model to predict the next token. + model = keras.Sequential( + [ + keras.Input(shape=[None]), + keras.layers.Embedding( + input_dim=self.vocab_size, + output_dim=self.feature_size, + ), + keras.layers.Dense(self.vocab_size), + keras.layers.Softmax(), + ] + ) + + def token_probability_fn(inputs, mask): + return model(inputs) + + self.token_probability_fn = token_probability_fn + self.sampler = TopKSampler(k=2) + + def test_generate_with_1d_prompt(self): + inputs = tf.constant([1]) + + outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) + self.assertEqual(outputs.shape, [5]) + + def test_generate_with_2d_prompt(self): + inputs = tf.constant([[1], [1]]) + outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) + self.assertEqual(outputs.shape, [2, 5]) + + def test_generate_with_list_prompt(self): + inputs = [[1], [1]] + outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) + self.assertEqual(outputs.shape, [2, 5]) + + def test_generate_with_ragged_prompt(self): + def token_probability_fn(inputs, mask): + batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1] + prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) + return tf.tile(prob, [batch_size, seq_length, 1]) + + inputs = tf.ragged.constant([[1], [2, 1, 2]]) + outputs = self.sampler( + inputs, + token_probability_fn, + max_length=5, + from_logits=False, + ) + self.assertEqual(outputs.shape, [2, 5]) + + @parameterized.named_parameters( + ("xla_graph", True, False), + ("non_xla_graph", False, False), + ("eager", False, True), + ) + def test_assert_probability_distribution_generation_is_correct( + self, jit_compile, run_eagerly + ): + def token_probability_fn(inputs, mask): + batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1] + prob = tf.constant([[[0.0, 0.0, 0.0, 
1.0]]])
+            return tf.tile(prob, [batch_size, seq_length, 1])
+
+        batch_size = 10
+        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
+        max_length = 3
+
+        outputs_count = np.array([0, 0, 0, 0])
+        tf.random.set_seed(42)
+        sampler = TopKSampler(
+            k=2,
+            seed=42,
+            run_eagerly=run_eagerly,
+            jit_compile=jit_compile,
+        )
+        for _ in range(8):
+            outputs = sampler(
+                inputs,
+                token_probability_fn,
+                max_length=max_length,
+                from_logits=False,
+            )
+            flatten_predictions = tf.reshape(outputs[:, 1:], [-1])
+            for pred in flatten_predictions:
+                outputs_count[pred] += 1
+        self.assertAllClose(
+            outputs_count / np.sum(outputs_count),
+            [0.0, 0.0, 0.0, 1.0],
+            rtol=0.2,
+        )
+
+    def test_only_choose_from_top_k_tokens(self):
+        # Test that there are only the top-k tokens in the output.
+        def token_probability_fn(inputs, mask):
+            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
+            prob = tf.constant([[[0.4, 0.3, 0.2, 0.1]]])
+            return tf.tile(prob, [batch_size, seq_length, 1])
+
+        # Test that it only samples from top-k tokens.
+        for k in [1, 2, 3]:
+            inputs = tf.constant([[0, 0], [0, 0]])
+            sampler = TopKSampler(k=k)
+            for _ in range(10):
+                outputs = sampler(
+                    inputs,
+                    token_probability_fn,
+                    max_length=5,
+                    from_logits=False,
+                )
+                self.assertAllEqual(outputs < k, tf.ones_like(outputs))
+
+    def test_end_token_id(self):
+        def token_probability_fn(inputs, mask):
+            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
+            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
+            return tf.tile(prob, [batch_size, seq_length, 1])
+
+        max_length = 5
+        inputs = tf.constant([[0, 1], [1, 2]])
+        tf.random.set_seed(42)
+        sampler = TopKSampler(k=4, seed=42)
+        outputs = sampler(
+            inputs,
+            token_probability_fn,
+            max_length=max_length,
+            end_token_id=2,
+            from_logits=False,
+        )
+        # Top-k sampling result with seed 42.
+        expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
+        self.assertAllEqual(outputs, expected_outputs)
diff --git a/keras_nlp/samplers/top_k.py b/keras_nlp/samplers/top_p_sampler.py
similarity index 74%
rename from keras_nlp/samplers/top_k.py
rename to keras_nlp/samplers/top_p_sampler.py
index 2f424baed7..babfd0085e 100644
--- a/keras_nlp/samplers/top_k.py
+++ b/keras_nlp/samplers/top_p_sampler.py
@@ -1,4 +1,4 @@
-# Copyright 2022 The KerasNLP Authors
+# Copyright 2023 The KerasNLP Authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Top-k Sampler."""
+"""Top-p Sampler."""
 
 import tensorflow as tf
 from tensorflow import keras
@@ -27,28 +27,64 @@
     base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring
 )
 @keras.utils.register_keras_serializable(package="keras_nlp")
-class TopK(Sampler):
-    """Top-K Sampler class.
-
-    This sampler implements top-k search algorithm.
+class TopPSampler(Sampler):
+    """Top-P Sampler class.
+    This sampler implements top-p search algorithm. Top-p search selects tokens
+    from the smallest subset of output probabilities that sum to greater than
+    `p`. Put in another way, top-p will first order token predictions by
+    likelihood, and ignore all tokens after the cumulative probability of
+    selected tokens exceeds `p`, then select a token from the remaining tokens.
     Args:
-        k: int, the `k` value in top-k.
+        p: float, the `p` value of top-p.
seed: int, defaults to None. The random seed. {{base_sampler_args}} + Call Args: {{call_args}} + + Examples: + ```python + BATCH_SIZE = 8 + VOCAB_SIZE = 10 + FEATURE_SIZE = 16 + START_ID = 1 + + # Create a dummy model to predict the next token. + model = keras.Sequential( + [ + keras.Input(shape=[None]), + keras.layers.Embedding( + input_dim=VOCAB_SIZE, + output_dim=FEATURE_SIZE, + ), + keras.layers.Dense(VOCAB_SIZE, activation="softmax"), + ] + ) + + # Define a function that outputs the next token's probability for each token + # in the input sequence. + def token_probability_fn(inputs, mask): + return model(inputs) + + prompt = tf.fill((BATCH_SIZE, 1), START_ID) + + sampler = keras_nlp.samplers.TopPSampler(p=0.1) + # Print the generated sequence (token ids). + print(sampler(prompt, token_probability_fn, 10)) + ``` """ def __init__( self, - k, + p, seed=None, jit_compile=True, + run_eagerly=False, ): - self.k = k + self.p = p self.seed = seed - super().__init__(jit_compile) + super().__init__(jit_compile, run_eagerly) @format_docstring(sample_args=sample_args_docstring) def sample( @@ -68,7 +104,7 @@ def one_step(length, prompt, mask): pred = tf.gather( probs, tf.repeat(length - 1, batch_size), axis=1, batch_dims=1 ) - if self.from_logits: + if from_logits: pred = keras.activations.softmax(pred, axis=-1) # Sort preds in descending order. sorted_preds, sorted_indices = tf.math.top_k( diff --git a/keras_nlp/samplers/top_p_sampler_test.py b/keras_nlp/samplers/top_p_sampler_test.py new file mode 100644 index 0000000000..1291de1305 --- /dev/null +++ b/keras_nlp/samplers/top_p_sampler_test.py @@ -0,0 +1,160 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Top-P sampler.""" + +import numpy as np +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.samplers.top_p_sampler import TopPSampler + + +class TopPSamplerTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + self.vocab_size = 10 + self.feature_size = 16 + + # Create a dummy model to predict the next token. 
+        model = keras.Sequential(
+            [
+                keras.Input(shape=[None]),
+                keras.layers.Embedding(
+                    input_dim=self.vocab_size,
+                    output_dim=self.feature_size,
+                ),
+                keras.layers.Dense(self.vocab_size),
+                keras.layers.Softmax(),
+            ]
+        )
+
+        def token_probability_fn(inputs, mask):
+            return model(inputs)
+
+        self.token_probability_fn = token_probability_fn
+        self.sampler = TopPSampler(p=0.1)
+
+    def test_generate_with_1d_prompt(self):
+        inputs = tf.constant([1])
+
+        outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
+        self.assertEqual(outputs.shape, [5])
+
+    def test_generate_with_2d_prompt(self):
+        inputs = tf.constant([[1], [1]])
+        outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
+        self.assertEqual(outputs.shape, [2, 5])
+
+    def test_generate_with_list_prompt(self):
+        inputs = [[1], [1]]
+        outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
+        self.assertEqual(outputs.shape, [2, 5])
+
+    def test_generate_with_ragged_prompt(self):
+        def token_probability_fn(inputs, mask):
+            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
+            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
+            return tf.tile(prob, [batch_size, seq_length, 1])
+
+        inputs = tf.ragged.constant([[1], [2, 1, 2]])
+        outputs = self.sampler(
+            inputs,
+            token_probability_fn,
+            max_length=5,
+            from_logits=False,
+        )
+        self.assertEqual(outputs.shape, [2, 5])
+
+    @parameterized.named_parameters(
+        ("xla_graph", True, False),
+        ("non_xla_graph", False, False),
+        ("eager", False, True),
+    )
+    def test_assert_probability_distribution_generation_is_correct(
+        self, jit_compile, run_eagerly
+    ):
+        def token_probability_fn(inputs, mask):
+            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
+            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
+            return tf.tile(prob, [batch_size, seq_length, 1])
+
+        batch_size = 10
+        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
+        max_length = 3
+
+        outputs_count = np.array([0, 0, 0, 0])
+        tf.random.set_seed(42)
+        sampler = TopPSampler(
+            p=0.1,
+            seed=42,
+            run_eagerly=run_eagerly,
+            jit_compile=jit_compile,
+        )
+        for _ in range(8):
+            outputs = sampler(
+                inputs,
+                token_probability_fn,
+                max_length=max_length,
+                from_logits=False,
+            )
+            flatten_predictions = tf.reshape(outputs[:, 1:], [-1])
+            for pred in flatten_predictions:
+                outputs_count[pred] += 1
+        self.assertAllClose(
+            outputs_count / np.sum(outputs_count),
+            [0.0, 0.0, 0.0, 1.0],
+            rtol=0.2,
+        )
+
+    def test_only_choose_from_top_p_tokens(self):
+        # Test that there are only the top-p tokens in the output.
+        def token_probability_fn(inputs, mask):
+            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
+            prob = tf.constant([[[0.4, 0.3, 0.2, 0.1]]])
+            return tf.tile(prob, [batch_size, seq_length, 1])
+
+        # Test that it only samples from top-p tokens.
+ for i, p in enumerate([0.399, 0.699, 0.899]): + inputs = tf.constant([[0, 0], [0, 0]]) + sampler = TopPSampler(p=p) + for _ in range(10): + outputs = sampler( + inputs, + token_probability_fn, + max_length=5, + from_logits=False, + ) + self.assertAllEqual(outputs <= i, tf.ones_like(outputs)) + + def test_end_token_id(self): + def token_probability_fn(inputs, mask): + batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1] + prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) + return tf.tile(prob, [batch_size, seq_length, 1]) + + max_length = 5 + inputs = tf.constant([[0, 1], [1, 2]]) + tf.random.set_seed(42) + sampler = TopPSampler(p=0.1, seed=42) + outputs = sampler( + inputs, + token_probability_fn, + max_length=max_length, + end_token_id=2, + from_logits=False, + ) + # Top-p sampling result with seed 42. + expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]]) + self.assertAllEqual(outputs, expected_outputs) From 53008005a0d0bb1b7317f2616c68ff5ba29f5b06 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Thu, 12 Jan 2023 17:43:03 -0800 Subject: [PATCH 17/28] dummy --- .../gpt2/gpt2_causal_lm_preprocessor.py | 34 +++++++++++++++++++ keras_nlp/models/gpt2/gpt2_preprocessor.py | 17 ---------- 2 files changed, 34 insertions(+), 17 deletions(-) create mode 100644 keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py new file mode 100644 index 0000000000..a61eaa232f --- /dev/null +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py @@ -0,0 +1,34 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
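The preprocessor added below builds causal language modeling features and labels by shifting the same token buffer: inputs are every token but the last, labels are every token but the first, so label position `i` is the "next token" for input position `i`. A tiny sketch of the shift with made-up token ids (illustrative only):

```python
import tensorflow as tf

token_ids = tf.constant([[5, 8, 2, 9, 0, 0]])  # one padded sequence
mask = tf.constant([[True, True, True, True, False, False]])

features = token_ids[:, :-1]  # [[5, 8, 2, 9, 0]]
labels = token_ids[:, 1:]     # [[8, 2, 9, 0, 0]]
weights = mask[:, 1:]         # padded label positions get zero loss weight
```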
+import tensorflow as tf + +from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight + + +class GPT2CausalLMPreprocessor(GPT2Preprocessor): + def call(self, x, y=None, sample_weight=None): + token_ids = self.tokenizer(x) + mask = tf.ones_like(token_ids, dtype=tf.bool) + mask = mask.to_tensor(shape=(None, self.sequence_length)) + token_ids = token_ids.to_tensor(shape=(None, self.sequence_length)) + x = { + "token_ids": token_ids[:, :-1], + "padding_mask": mask[:, 1:], + } + + y = token_ids[:, 1:] + sample_weight = mask[:, 1:] + + return pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py index e14ef10386..fa0aaa85b0 100644 --- a/keras_nlp/models/gpt2/gpt2_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py @@ -83,20 +83,3 @@ def from_preset( sequence_length=sequence_length, **kwargs, ) - - -class GPT2CausalLMPreprocessor(GPT2Preprocessor): - def call(self, x, y=None, sample_weight=None): - token_ids = self.tokenizer(x) - mask = tf.ones_like(token_ids, dtype=tf.bool) - mask = mask.to_tensor(shape=(None, self.sequence_length)) - token_ids = token_ids.to_tensor(shape=(None, self.sequence_length)) - x = { - "token_ids": token_ids[:, :-1], - "padding_mask": mask[:, 1:], - } - - y = token_ids[:, 1:] - sample_weight = mask[:, 1:] - - return pack_x_y_sample_weight(x, y, sample_weight) From de2ac9cfbb22751638ca97f072f3cd2edf408d62 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Thu, 12 Jan 2023 17:45:04 -0800 Subject: [PATCH 18/28] add some arg defaults --- keras_nlp/samplers/beam_sampler.py | 2 +- keras_nlp/samplers/top_k_sampler.py | 2 +- keras_nlp/samplers/top_p_sampler.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 4533186633..24ac8b8f83 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -77,7 +77,7 @@ def token_probability_fn(inputs, mask): def __init__( self, - num_beams, + num_beams=5, jit_compile=True, run_eagerly=False, ): diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py index 8224dd2d46..8c7d25f617 100644 --- a/keras_nlp/samplers/top_k_sampler.py +++ b/keras_nlp/samplers/top_k_sampler.py @@ -76,7 +76,7 @@ def token_probability_fn(inputs, mask): def __init__( self, - k, + k=5, seed=None, jit_compile=True, run_eagerly=False, diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index babfd0085e..ccb2817035 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -77,7 +77,7 @@ def token_probability_fn(inputs, mask): def __init__( self, - p, + p=0.1, seed=None, jit_compile=True, run_eagerly=False, From 08f3c1eabbd2e058d9d2c13399e9f247d3553313 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Thu, 12 Jan 2023 17:52:54 -0800 Subject: [PATCH 19/28] small fix --- keras_nlp/models/__init__.py | 4 +++- keras_nlp/models/gpt2/gpt2_causal_lm.py | 22 ++++++---------------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index 9c769b9ba4..ead2311014 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -31,7 +31,9 @@ ) from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM -from 
keras_nlp.models.gpt2.gpt2_preprocessor import GPT2CausalLMPreprocessor +from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import ( + GPT2CausalLMPreprocessor, +) from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index 7f8bd7f5d8..72eb3f4170 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -18,13 +18,12 @@ import tensorflow as tf from tensorflow import keras +import keras_nlp from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone -from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2CausalLMPreprocessor +from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import ( + GPT2CausalLMPreprocessor, +) from keras_nlp.models.gpt2.gpt2_presets import backbone_presets -from keras_nlp.samplers.beam_sampler import BeamSampler -from keras_nlp.samplers.greedy_sampler import GreedySampler -from keras_nlp.samplers.top_k_sampler import TopKSampler -from keras_nlp.samplers.top_p_sampler import TopPSampler from keras_nlp.utils.pipeline_model import PipelineModel from keras_nlp.utils.python_utils import classproperty @@ -84,15 +83,6 @@ def from_preset( f"""{", ".join(cls.presets)}. Received: {preset}.""" ) - def _get_generator(self, identifier): - maps = { - "greedy": GreedySampler(), - "top_k": TopKSampler(k=5, from_logits=False), - "top_p": TopPSampler(p=0.1, from_logits=False), - "beam": BeamSampler(num_beams=5), - } - return maps[identifier] - def _get_token_probability(self, prompt, mask): model_inputs = { "token_ids": prompt, @@ -104,7 +94,7 @@ def _get_token_probability(self, prompt, mask): def generate(self, prompt, max_length, generator="top_k"): """Pick one method as the default generation algo.""" if isinstance(generator, str): - generator = self._get_generator(generator) + generator = keras_nlp.samplers.get(generator) prompt = self.preprocessor.tokenizer(prompt) - generated = generator(self._get_token_probability, prompt, max_length) + generated = generator(prompt, self._get_token_probability, max_length) return self.preprocessor.tokenizer.detokenize(generated) From 2b93ad80c6bce77a3bf43aff0fced3fec30a1b80 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Tue, 17 Jan 2023 11:33:37 -0800 Subject: [PATCH 20/28] fix docstring --- keras_nlp/samplers/__init__.py | 2 +- keras_nlp/samplers/beam_sampler.py | 2 +- keras_nlp/samplers/sampler.py | 4 ++++ keras_nlp/samplers/top_k_sampler.py | 2 +- keras_nlp/samplers/top_p_sampler.py | 7 ++++--- 5 files changed, 11 insertions(+), 6 deletions(-) diff --git a/keras_nlp/samplers/__init__.py b/keras_nlp/samplers/__init__.py index 89911a78f4..5cc12f2cf6 100644 --- a/keras_nlp/samplers/__init__.py +++ b/keras_nlp/samplers/__init__.py @@ -16,7 +16,7 @@ from keras_nlp.samplers.beam_sampler import BeamSampler from keras_nlp.samplers.greedy_sampler import GreedySampler -from keras_nlp.samplers.greedy_sampler import Sampler +from keras_nlp.samplers.sampler import Sampler from keras_nlp.samplers.top_k_sampler import TopKSampler from keras_nlp.samplers.top_p_sampler import TopPSampler diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 24ac8b8f83..c0c53b4a80 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -30,7 +30,7 @@ class BeamSampler(Sampler): """Beam Sampler class. 
-    This sampler implements beam search algorithm. At each time-step, beam
+    This sampler implements the beam search algorithm. At each time-step, beam
     search keeps the beams (sequences) of the top `num_beams` highest
     accumulated probabilities, and uses each one of the beams to predict
     candidate next tokens.
diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py
index 17137d6684..fdd4332c7e 100644
--- a/keras_nlp/samplers/sampler.py
+++ b/keras_nlp/samplers/sampler.py
@@ -70,6 +70,10 @@ class Sampler:
 
     The inputs and outputs of Sampler class are both token ids.
 
+    Subclassers should always implement the `sample()` method, which implements
+    the body of the sampling algorithm. Please check the available subclass
+    samplers for examples.
+
     Examples:
 
     Basic usage:
diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py
index 8c7d25f617..eabb2b6de8 100644
--- a/keras_nlp/samplers/top_k_sampler.py
+++ b/keras_nlp/samplers/top_k_sampler.py
@@ -30,7 +30,7 @@
 class TopKSampler(Sampler):
     """Top-K Sampler class.
 
-    This sampler implements top-k search algorithm. Briefly top-k algorithm
+    This sampler implements the top-k search algorithm. Briefly, the top-k algorithm
     randomly selects a token from the tokens of top K probability, with
     selection chance determined by the probability.
 
diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py
index ccb2817035..4477778181 100644
--- a/keras_nlp/samplers/top_p_sampler.py
+++ b/keras_nlp/samplers/top_p_sampler.py
@@ -29,9 +29,10 @@
 @keras.utils.register_keras_serializable(package="keras_nlp")
 class TopPSampler(Sampler):
     """Top-P Sampler class.
-    This sampler implements top-p search algorithm. Top-p search selects tokens
-    from the smallest subset of output probabilities that sum to greater than
-    `p`. Put in another way, top-p will first order token predictions by
+
+    This sampler implements the top-p search algorithm. Top-p search selects
+    tokens from the smallest subset of output probabilities that sum to greater
+    than `p`. Put another way, top-p will first order token predictions by
     likelihood, and ignore all tokens after the cumulative probability of
     selected tokens exceeds `p`, then select a token from the remaining tokens. 
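[Editor's note] The top-p (nucleus) description in the docstring above is easiest to check with concrete numbers. The following is a minimal NumPy sketch of the filtering step the docstring describes; `top_p_filter` and the toy distribution are illustrative assumptions for this note, not keras_nlp library code.

```python
import numpy as np

def top_p_filter(probs, p):
    """Keep the smallest set of top tokens whose cumulative probability exceeds `p`."""
    order = np.argsort(probs)[::-1]      # token ids ordered by likelihood
    cumulative = np.cumsum(probs[order])
    # Number of tokens needed before the cumulative probability exceeds `p`.
    cutoff = np.searchsorted(cumulative, p) + 1
    kept = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[kept] = probs[kept]
    return filtered / filtered.sum()     # renormalize before sampling

probs = np.array([0.5, 0.3, 0.1, 0.1])
print(top_p_filter(probs, p=0.7))        # [0.625 0.375 0.    0.   ]
```

With `p=0.7`, the two most likely tokens (cumulative probability 0.8) survive the cut and the rest are zeroed out, which is exactly the "smallest subset that sums to greater than `p`" behavior described above.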
From 820610374ab96929496497b626014373c5adf1f5 Mon Sep 17 00:00:00 2001
From: Chen Qian
Date: Thu, 19 Jan 2023 12:00:27 -0800
Subject: [PATCH 21/28] some changes

---
 keras_nlp/models/__init__.py                  |   7 +
 keras_nlp/models/gpt2/gpt2_causal_lm.py       |  71 ++++++--
 .../gpt2/gpt2_causal_lm_preprocessor.py       |  67 +++++++-
 keras_nlp/models/gpt2/gpt2_preprocessor.py    | 155 +++++++++++++-----
 keras_nlp/tokenizers/byte_pair_tokenizer.py   |   1 +
 5 files changed, 244 insertions(+), 57 deletions(-)

diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index 0f73bca401..9173802638 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -32,6 +32,13 @@
 from keras_nlp.models.f_net.f_net_backbone import FNetBackbone
 from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor
 from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer
+from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
+from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM
+from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
+    GPT2CausalLMPreprocessor,
+)
+from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
+from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
 from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
 from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
 from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py
index 72eb3f4170..4a30418b4e 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""BERT task specific models and heads."""
+"""GPT2 Causal LM (Language Model)."""
 
 import copy
 
@@ -30,16 +30,40 @@
 @keras.utils.register_keras_serializable(package="keras_nlp")
 class GPT2CausalLM(PipelineModel):
-    def __init__(self, backbone, preprocessor=None, **kwargs):
+    """GPT2 Causal LM task model.
+
+    A causal LM predicts the next token based on the previous tokens, which
+    is how GPT2 is pretrained. Users can finetune `GPT2CausalLM` to generate
+    text similar to a custom dataset. `GPT2CausalLM` also has a public method
+    `generate()`, which generates text based on a given prompt.
+
+    This model can optionally be configured with a `preprocessor` layer, in
+    which case it will automatically apply preprocessing to raw inputs during
+    `fit()`, `predict()`, and `evaluate()`. This is done by default when
+    creating the model with `from_preset()`.
+
+    Disclaimer: Pre-trained models are provided on an "as is" basis, without
+    warranties or conditions of any kind. The underlying model is provided by a
+    third party and subject to a separate license, available
+    [here](https://github.com/openai/gpt-2).
+
+    Args:
+        backbone: A `keras_nlp.models.GPT2Backbone` instance.
+        preprocessor: A `keras_nlp.models.GPT2CausalLMPreprocessor` or `None`.
+            If `None`, this model will not apply preprocessing, and inputs
+            should be preprocessed before calling the model. 
+ + + """ + def __init__(self, backbone, preprocessor=None, **kwargs): inputs = backbone.input x = backbone(inputs) - x = tf.matmul( + outputs = tf.matmul( x, backbone.get_layer("token_embedding").embeddings, transpose_b=True, ) - outputs = tf.keras.layers.Softmax()(x) # Instantiate using Functional API Model constructor super().__init__( inputs=inputs, @@ -88,13 +112,38 @@ def _get_token_probability(self, prompt, mask): "token_ids": prompt, "padding_mask": mask, } - probs = self(model_inputs) - return probs + return self(model_inputs) + + def generate( + self, + prompt, + max_length, + end_token="<|endoftext|>", + sampler="top_k", + ): + """Generate text. - def generate(self, prompt, max_length, generator="top_k"): - """Pick one method as the default generation algo.""" - if isinstance(generator, str): - generator = keras_nlp.samplers.get(generator) + This method generates text based on given `prompt`. Generation will + continue until `max_length` is met, and all generated tokens after + `end_token` will be truncated. + + Args: + prompt: a string, string Tensor or string RaggedTensor. The prompt + text for generation. + max_length: int. The max length of generated sequence. + end_token: string, defaults to "<|endoftext|>", which is the default + end token of GPT2. The token marking the end of the sequence, + tokens generated after the end token will be truncated. + """ + end_token_id = self.preprocessor.tokenizer.token_to_id(end_token) + + if isinstance(sampler, str): + sampler = keras_nlp.samplers.get(sampler) prompt = self.preprocessor.tokenizer(prompt) - generated = generator(prompt, self._get_token_probability, max_length) + generated = sampler( + prompt, + self._get_token_probability, + max_length=max_length, + end_token_id=end_token_id, + ) return self.preprocessor.tokenizer.detokenize(generated) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py index a61eaa232f..49b5f2b4e0 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py @@ -1,4 +1,4 @@ -# Copyright 2022 The KerasNLP Authors +# Copyright 2023 The KerasNLP Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,24 +11,75 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf + +"""GPT2 Causal LM preprocessor layer.""" + from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor from keras_nlp.utils.keras_utils import pack_x_y_sample_weight class GPT2CausalLMPreprocessor(GPT2Preprocessor): + """GPT2 Causal LM preprocessor. + + This class subclasses `keras_nlp.models.GPT2Preprocessor` and keeps most of + its functionality. The only change is `GPT2CausalLMPreprocessor` sets + `y` (label) and `sample_weights` field by shifting the input sequence one + step towards left, and drop the last token as it does not have a successor, + e.g., if the tokenized input is `[1, 2, 3, 0, 0]` with + `padding_mask=[1, 1, 1, 0, 0]`, then after preprocessing, we + will have `x=[1, 2, 3, 0]` and `y=[2, 3, 0, 0]`, with + `padding_mask=[1, 1, 1, 0]` and `sample_weights=[1, 1, 0, 0]`. + + Args: + tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance. + sequence_length: The length of the packed inputs. 
+
+    Examples:
+    ```python
+    # Load the preprocessor from a preset.
+    preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
+        "gpt2_base_en"
+    )
+
+    # Tokenize and pack a single sentence.
+    sentence = tf.constant("league of legends")
+    preprocessor(sentence)
+    # Same output.
+    preprocessor("league of legends")
+
+    # Tokenize a batch of sentences.
+    sentences = tf.constant(["taco tuesday", "gi gi gi gi"])
+    preprocessor(sentences)
+    # Same output.
+    preprocessor(["taco tuesday", "gi gi gi gi"])
+
+    # Map a dataset to preprocess a single sentence.
+    features = tf.constant(
+        [
+            "Avatar 2 is amazing!",
+            "Well, I am not sure.",
+        ]
+    )
+    labels = tf.constant([1, 0])
+    ds = tf.data.Dataset.from_tensor_slices((features, labels))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map a dataset to preprocess unlabeled sentences.
+    ds = tf.data.Dataset.from_tensor_slices(features)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+    """
+
     def call(self, x, y=None, sample_weight=None):
-        token_ids = self.tokenizer(x)
-        mask = tf.ones_like(token_ids, dtype=tf.bool)
-        mask = mask.to_tensor(shape=(None, self.sequence_length))
-        token_ids = token_ids.to_tensor(shape=(None, self.sequence_length))
+
+        x = super().call(x)
+        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
         x = {
             "token_ids": token_ids[:, :-1],
-            "padding_mask": mask[:, 1:],
+            "padding_mask": padding_mask[:, :-1],
         }
-
-        y = token_ids[:, 1:]
-        sample_weight = mask[:, 1:]
+        y = token_ids[:, 1:]
+        sample_weight = padding_mask[:, 1:]
 
         return pack_x_y_sample_weight(x, y, sample_weight)
diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py
index fa0aaa85b0..4bfde949fb 100644
--- a/keras_nlp/models/gpt2/gpt2_preprocessor.py
+++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py
@@ -17,24 +17,134 @@
 import copy
 
 import tensorflow as tf
-from tensorflow import keras
 
 from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
 from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
+from keras_nlp.models.preprocessor import Preprocessor
+from keras_nlp.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
 from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
 from keras_nlp.utils.python_utils import classproperty
 
 
-class GPT2Preprocessor(keras.layers.Layer):
+class GPT2Preprocessor(Preprocessor):
+    """GPT2 preprocessing layer which tokenizes and packs inputs.
+
+    This preprocessing layer will do three things:
+
+    - Tokenize the input using the `tokenizer`.
+    - Add the id of '<|endoftext|>' to the start and end of the tokenized input.
+    - Construct a dictionary with keys `"token_ids"`, `"padding_mask"`, that can
+    be passed directly to a `keras_nlp.models.GPT2Backbone`.
+
+    This layer can be used directly with `tf.data.Dataset.map` to preprocess
+    string data in the `(x, y, sample_weight)` format used by
+    `keras.Model.fit`.
+
+    The call method of this layer accepts three arguments, `x`, `y`, and
+    `sample_weight`. `x` can be a python string or tensor representing a single
+    segment, a list of python strings representing a batch of single segments,
+    or a list of tensors representing multiple segments to be packed together.
+    `y` and `sample_weight` are both optional, can have any format, and will be
+    passed through unaltered. 
+
+    `GPT2Preprocessor` forces the input to have only one segment, as GPT2 is
+    mainly used for generation tasks. For tasks having multi-segment inputs
+    like "glue/mnli", please use a model designed for classification purposes
+    such as BERT or RoBERTa.
+
+    Args:
+        tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
+        sequence_length: The length of the packed inputs.
+
+    Examples:
+    ```python
+    # Load the preprocessor from a preset.
+    preprocessor = keras_nlp.models.GPT2Preprocessor.from_preset("gpt2_base_en")
+
+    # Tokenize and pack a single sentence.
+    sentence = tf.constant("league of legends")
+    preprocessor(sentence)
+    # Same output.
+    preprocessor("league of legends")
+
+    # Tokenize a batch of sentences.
+    sentences = tf.constant(["taco tuesday", "gi gi gi gi"])
+    preprocessor(sentences)
+    # Same output.
+    preprocessor(["taco tuesday", "gi gi gi gi"])
+
+    # Map a dataset to preprocess a single sentence.
+    features = tf.constant(
+        [
+            "Avatar 2 is amazing!",
+            "Well, I am not sure.",
+        ]
+    )
+    labels = tf.constant([1, 0])
+    ds = tf.data.Dataset.from_tensor_slices((features, labels))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map a dataset to preprocess unlabeled sentences.
+    ds = tf.data.Dataset.from_tensor_slices(features)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Alternatively, you can create a preprocessor from your own vocabulary.
+    # The usage is exactly the same as above.
+    vocab = {
+        "<s>": 0,
+        "<pad>": 1,
+        "</s>": 2,
+        "Ġafter": 5,
+        "noon": 6,
+        "Ġsun": 7,
+    }
+    merges = ["Ġ a", "Ġ s", "Ġ n", "e r", "n o", "o n", "Ġs u", "Ġa f", "no on"]
+    merges += ["Ġsu n", "Ġaf t", "Ġaft er"]
+
+    tokenizer = keras_nlp.models.GPT2Tokenizer(
+        vocabulary=vocab,
+        merges=merges,
+    )
+    preprocessor = keras_nlp.models.GPT2Preprocessor(
+        tokenizer=tokenizer,
+        sequence_length=20,
+    )
+    ```
+    """
+
     def __init__(self, tokenizer, sequence_length, **kwargs):
         super().__init__(**kwargs)
-        self.tokenizer = tokenizer
+        self._tokenizer = tokenizer
         self.sequence_length = sequence_length
 
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "sequence_length": self.packer.sequence_length,
+            }
+        )
+        return config
+
     def call(self, x, y=None, sample_weight=None):
-        token_ids = self.tokenizer(x)
+        x = convert_inputs_to_list_of_tensor_segments(x)
+        if len(x) > 1:
+            raise ValueError(
+                "GPT2 requires each input feature to contain only "
+                f"one segment, but received: {len(x)}. If you are using GPT2 "
+                "for a multi-segment classification task, please refer to "
+                "classification models like BERT or RoBERTa."
+            )
+        token_ids = self._tokenizer(x[0])
+        # batch_size = token_ids.nrows()
+        # start_column = tf.fill((batch_size, 1), self._tokenizer.end_token_id)
+        # end_column = tf.fill((batch_size, 1), self._tokenizer.end_token_id)
+        # token_ids = tf.concat([start_column, token_ids, end_column], axis=1)
+
         mask = tf.ones_like(token_ids, dtype=tf.bool)
         mask = mask.to_tensor(shape=(None, self.sequence_length))
         token_ids = token_ids.to_tensor(shape=(None, self.sequence_length))
@@ -49,37 +159,6 @@ def call(self, x, y=None, sample_weight=None):
     def presets(cls):
         return copy.deepcopy(backbone_presets)
 
-    @classmethod
-    def from_preset(
-        cls,
-        preset,
-        sequence_length=None,
-        **kwargs,
-    ):
-        if preset not in cls.presets:
-            raise ValueError(
-                "`preset` must be one of "
-                f"""{", ".join(cls.presets)}. 
Received: {preset}.""" - ) - - tokenizer = GPT2Tokenizer.from_preset(preset) - - # Use model's `max_sequence_length` if `sequence_length` unspecified; - # otherwise check that `sequence_length` not too long. - metadata = cls.presets[preset] - max_sequence_length = metadata["config"]["max_sequence_length"] - if sequence_length is not None: - if sequence_length > max_sequence_length: - raise ValueError( - f"`sequence_length` cannot be longer than `{preset}` " - f"preset's `max_sequence_length` of {max_sequence_length}. " - f"Received: {sequence_length}." - ) - else: - sequence_length = max_sequence_length - - return cls( - tokenizer=tokenizer, - sequence_length=sequence_length, - **kwargs, - ) + @classproperty + def tokenizer_cls(cls): + return GPT2Tokenizer diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index 85639c0098..93a0ba3ce6 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -77,6 +77,7 @@ def remove_strings_from_inputs(tensor, string_to_remove): flatten_indexes = tf.where(non_empty_mask) flatten_result = tf.gather_nd(tensor, flatten_indexes) row_lengths = tf.reduce_sum(tf.cast(non_empty_mask, tf.int64), axis=1) + result = tf.RaggedTensor.from_row_lengths( values=flatten_result, row_lengths=row_lengths, From 9945c1372175266c1f6971357f56bea7cd913bfa Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Thu, 19 Jan 2023 18:06:49 -0800 Subject: [PATCH 22/28] add classes --- keras_nlp/models/gpt2/gpt2_causal_lm.py | 191 +++++++++++++++++- .../gpt2/gpt2_causal_lm_preprocessor.py | 22 +- .../gpt2/gpt2_causal_lm_preprocessor_test.py | 127 ++++++++++++ keras_nlp/models/gpt2/gpt2_causal_lm_test.py | 160 +++++++++++++++ keras_nlp/models/gpt2/gpt2_preprocessor.py | 10 +- .../models/gpt2/gpt2_preprocessor_test.py | 122 +++++++++++ .../roberta/roberta_preprocessor_test.py | 3 + 7 files changed, 618 insertions(+), 17 deletions(-) create mode 100644 keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py create mode 100644 keras_nlp/models/gpt2/gpt2_causal_lm_test.py create mode 100644 keras_nlp/models/gpt2/gpt2_preprocessor_test.py diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index 4a30418b4e..ebc5efb72f 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -28,6 +28,36 @@ from keras_nlp.utils.python_utils import classproperty +# @keras.utils.register_keras_serializable(package="keras_nlp") +class EmbeddingMapping(keras.layers.Layer): + def __init__(self, embedding_layer, name="embedding_mapping", **kwargs): + super().__init__(name=name, **kwargs) + self.embedding_layer = embedding_layer + + def call(self, inputs): + return tf.matmul( + inputs, + self.embedding_layer.embeddings, + transpose_b=True, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "embedding_layer": keras.layers.serialize(self.embedding_layer), + } + ) + return config + + @classmethod + def from_config(cls, config): + config["embedding_layer"] = keras.layers.deserialize( + config["embedding_layer"], + ) + return cls(**config) + + @keras.utils.register_keras_serializable(package="keras_nlp") class GPT2CausalLM(PipelineModel): """GPT2 Causal LM task model. @@ -53,17 +83,126 @@ class GPT2CausalLM(PipelineModel): If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. + Examples: + + Example usage. 
+ ```python + features = { + "token_ids": tf.constant( + [[1, 2, 3, 4, 0, 0]] * 2, shape=(2, 6) + ), + "padding_mask": tf.constant( + [[1, 1, 1, 1, 0, 0]] * 2, shape=(2, 6) + ), + } + labels = tf.constant( + [[2, 3, 4, 0, 0, 0]] * 2, shape=(2, 6) + ) + sample_weights = tf.constant( + [[1, 1, 1, 0, 0, 0]] * 2, shape=(2, 6) + ) + + # Randomly initialize a GPT2 backbone. + backbone = keras_nlp.models.GPT2Backbone( + vocabulary_size=50257, + num_layers=2, + num_heads=2, + hidden_dim=128, + intermediate_dim=256, + max_sequence_length=128, + ) + # Create a `GPT2CausalLM` and fit the data. + gpt2_lm = keras_nlp.models.GPT2CausalLM(backbone, preprocessor=None) + gpt2_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + ) + gpt2_lm.fit( + x=features, + y=labels, + sample_weights=sample_weights, + batch_size=2, + ) + ``` + + Raw string inputs. + ```python + # Create a dataset with raw string features in an `(x, y)` format. + features = [ + "I don't listen to music while coding.", + "But I watch youtube while coding!", + ] + + # Create a `GPT2CausalLM` and fit your data. + gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + gpt2_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + ) + gpt2_lm.fit(x=features, batch_size=2) + ``` + + Raw string inputs with customized preprocessing. + ```python + # Create a dataset with raw string features in an `(x, y)` format. + features = [ + "I don't listen to music while coding.", + "But I watch youtube while coding!", + ] + + # Use a shorter sequence length. + preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset( + "gpt2_base_en", + sequence_length=128, + ) + + # Create a `GPT2CausalLM` and fit your data. + gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset( + "gpt2_base_en", + preprocessor=preprocessor, + ) + gpt2_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + ) + gpt2_lm.fit(x=features, batch_size=2) + ``` + + # Use tf dataset. + ```python + features = [ + "I don't listen to music while coding.", + "But I watch youtube while coding!", + ] + ds = tf.data.Dataset.from_tensor_slices(features) + + # Create a `GPT2CausalLM` and fit your data. + gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset( + "gpt2_base_en", + preprocessor=preprocessor, + ) + gpt2_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + ) + gpt2_lm.fit(x=features, batch_size=2) + ``` + + # Use `generate()` method to generate text. + ```python + gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + gpt2_lm.generate("I want to say", max_length=30) + + # Generate with batched prompts. 
+ gpt2_lm.generate(["This is a", "Where are you"], max_length=30) + ``` """ def __init__(self, backbone, preprocessor=None, **kwargs): inputs = backbone.input x = backbone(inputs) - outputs = tf.matmul( - x, - backbone.get_layer("token_embedding").embeddings, - transpose_b=True, - ) + # embedding_layer = backbone.get_layer("token_embedding") + # embedding_map_layer = EmbeddingMapping(embedding_layer) + # outputs = embedding_map_layer(x) + outputs = x + # Instantiate using Functional API Model constructor super().__init__( inputs=inputs, @@ -72,12 +211,22 @@ def __init__(self, backbone, preprocessor=None, **kwargs): **kwargs, ) - self.preprocessor = preprocessor - self.backbone = backbone + self._backbone = backbone + self._preprocessor = preprocessor def preprocess_samples(self, x, y=None, sample_weight=None): return self.preprocessor(x, y=y, sample_weight=sample_weight) + @property + def backbone(self): + """The associated `keras_nlp.models.RobertaBackbone`.""" + return self._backbone + + @property + def preprocessor(self): + """A `keras_nlp.models.RobertaMaskedLMPreprocessor` for preprocessing inputs.""" + return self._preprocessor + @classproperty def presets(cls): return copy.deepcopy(backbone_presets) @@ -124,7 +273,7 @@ def generate( """Generate text. This method generates text based on given `prompt`. Generation will - continue until `max_length` is met, and all generated tokens after + continue until `max_length` is met, and all tokens generated after `end_token` will be truncated. Args: @@ -134,11 +283,17 @@ def generate( end_token: string, defaults to "<|endoftext|>", which is the default end token of GPT2. The token marking the end of the sequence, tokens generated after the end token will be truncated. + sampler: a string or `keras_nlp.samplers.Sampler` instance. The + sampler to be used for text generation. 
""" end_token_id = self.preprocessor.tokenizer.token_to_id(end_token) if isinstance(sampler, str): sampler = keras_nlp.samplers.get(sampler) + if hasattr(self, "jit_compile"): + sampler.jit_compile = self.jit_compile + if hasattr(self, "run_eagerly"): + sampler.run_eagerly = self.run_eagerly prompt = self.preprocessor.tokenizer(prompt) generated = sampler( prompt, @@ -147,3 +302,23 @@ def generate( end_token_id=end_token_id, ) return self.preprocessor.tokenizer.detokenize(generated) + + def get_config(self): + return { + "backbone": keras.layers.serialize(self.backbone), + "preprocessor": keras.layers.serialize(self.preprocessor), + "name": self.name, + "trainable": self.trainable, + } + + @classmethod + def from_config(cls, config): + if "backbone" in config and isinstance(config["backbone"], dict): + config["backbone"] = keras.layers.deserialize(config["backbone"]) + if "preprocessor" in config and isinstance( + config["preprocessor"], dict + ): + config["preprocessor"] = keras.layers.deserialize( + config["preprocessor"] + ) + return cls(**config) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py index 49b5f2b4e0..cb16962ae0 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py @@ -74,12 +74,20 @@ def call(self, x, y=None, sample_weight=None): x = super().call(x) token_ids, padding_mask = x["token_ids"], x["padding_mask"] - x = { - "token_ids": token_ids[:, :-1], - "padding_mask": padding_mask[:, :-1], - } - - y = token_ids[:, 1:] - sample_weight = padding_mask[:, 1:] + if len(token_ids.shape) == 1: + x = { + "token_ids": token_ids[:-1], + "padding_mask": padding_mask[:-1], + } + y = token_ids[1:] + sample_weight = padding_mask[1:] + else: + x = { + "token_ids": token_ids[:, :-1], + "padding_mask": padding_mask[:, :-1], + } + + y = token_ids[:, 1:] + sample_weight = padding_mask[:, 1:] return pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..7ad2cce988 --- /dev/null +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py @@ -0,0 +1,127 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for GPT2 causal LM preprocessor layer.""" + +import os + +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import ( + GPT2CausalLMPreprocessor, +) +from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer + + +class GPT2CausalLMPreprocessorTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + vocab = { + "<|endoftext|>": 0, + "!": 1, + "air": 2, + "Ġair": 3, + "plane": 4, + "Ġat": 5, + "port": 6, + "Ġkoh": 7, + "li": 8, + "Ġis": 9, + "Ġthe": 10, + "Ġbest": 11, + } + + merges = ["Ġ a", "Ġ t", "Ġ k", "Ġ i", "Ġ b", "a i", "p l", "n e"] + merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"] + merges += ["Ġt h", "ai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"] + merges += ["Ġai r", "Ġa i", "pla ne"] + + self.preprocessor = GPT2CausalLMPreprocessor( + tokenizer=GPT2Tokenizer( + vocabulary=vocab, + merges=merges, + ), + sequence_length=8, + ) + + def test_tokenize_strings(self): + input_data = "airplane at airport" + + x, y, sw = self.preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [2, 4, 5, 3, 6, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0]) + self.assertAllEqual(y, [4, 5, 3, 6, 0, 0, 0]) + self.assertAllEqual(sw, [1, 1, 1, 1, 0, 0, 0]) + + def test_tokenize_list_of_strings(self): + input_data = ["airplane at airport"] * 4 + + x, y, sw = self.preprocessor(input_data) + self.assertAllEqual( + x["token_ids"], + [[2, 4, 5, 3, 6, 0, 0]] * 4, + ) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0]] * 4) + self.assertAllEqual(y, [[4, 5, 3, 6, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0]] * 4) + + def test_tokenize_labeled_batch(self): + x = tf.constant(["airplane at airport"] * 4) + y = tf.constant([1] * 4) + sw = tf.constant([1.0] * 4) + x, y, sw = self.preprocessor(x, y, sw) + self.assertAllEqual( + x["token_ids"], + [[2, 4, 5, 3, 6, 0, 0]] * 4, + ) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0]] * 4) + self.assertAllEqual(y, [[4, 5, 3, 6, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0]] * 4) + + def test_tokenize_labeled_dataset(self): + x = tf.constant(["airplane at airport"] * 4) + y = tf.constant([1] * 4) + sw = tf.constant([1.0] * 4) + ds = tf.data.Dataset.from_tensor_slices((x, y, sw)) + ds = ds.map(self.preprocessor) + x, y, sw = ds.batch(4).take(1).get_single_element() + + self.assertAllEqual( + x["token_ids"], + [[2, 4, 5, 3, 6, 0, 0]] * 4, + ) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0]] * 4) + self.assertAllEqual(y, [[4, 5, 3, 6, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0]] * 4) + + @parameterized.named_parameters( + ("tf_format", "tf", "model"), + ("keras_format", "keras_v3", "model.keras"), + ) + def test_saved_model(self, save_format, filename): + input_data = tf.constant(["airplane at airport"]) + + inputs = keras.Input(dtype="string", shape=()) + outputs, y, sw = self.preprocessor(inputs) + model = keras.Model(inputs, outputs) + + path = os.path.join(self.get_temp_dir(), filename) + model.save(path, save_format=save_format) + + restored_model = keras.models.load_model(path) + self.assertAllEqual( + model(input_data)["token_ids"], + restored_model(input_data)["token_ids"], + ) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py new file mode 100644 index 0000000000..8baf2ebbca --- /dev/null +++ 
b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py @@ -0,0 +1,160 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for GPT2 causal LM model.""" + +import os + +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone +from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM +from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import ( + GPT2CausalLMPreprocessor, +) +from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer + + +class GPT2CausalLMTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + backbone = GPT2Backbone( + vocabulary_size=1000, + num_layers=2, + num_heads=2, + hidden_dim=64, + intermediate_dim=128, + max_sequence_length=128, + ) + vocab = { + "<|endoftext|>": 0, + "!": 1, + "air": 2, + "Ġair": 3, + "plane": 4, + "Ġat": 5, + "port": 6, + "Ġkoh": 7, + "li": 8, + "Ġis": 9, + "Ġthe": 10, + "Ġbest": 11, + } + + merges = ["Ġ a", "Ġ t", "Ġ k", "Ġ i", "Ġ b", "a i", "p l", "n e"] + merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"] + merges += ["Ġt h", "ai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"] + merges += ["Ġai r", "Ġa i", "pla ne"] + + self.preprocessor = GPT2CausalLMPreprocessor( + GPT2Tokenizer(vocabulary=vocab, merges=merges), + sequence_length=8, + ) + self.causal_lm = GPT2CausalLM( + backbone, + preprocessor=self.preprocessor, + ) + self.causal_lm_no_preprocessing = GPT2CausalLM( + backbone, + preprocessor=None, + ) + + self.raw_batch = tf.constant( + [ + " airplane at airport", + " the airplane is the best", + " the best airport", + " kohli is the best", + ] + ) + self.preprocessed_batch = self.preprocessor(self.raw_batch)[0] + self.raw_dataset = tf.data.Dataset.from_tensor_slices( + self.raw_batch + ).batch(2) + self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) + + def test_valid_call_causal_lm(self): + self.causal_lm(self.preprocessed_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_gpt2_causal_lm_predict(self, jit_compile): + self.causal_lm.compile(jit_compile=jit_compile) + self.causal_lm.predict(self.raw_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_gpt2_causal_lm_predict_no_preprocessing(self, jit_compile): + self.causal_lm_no_preprocessing.compile(jit_compile=jit_compile) + self.causal_lm_no_preprocessing.predict(self.preprocessed_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_gpt2_causal_lm_fit(self, jit_compile): + self.causal_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + jit_compile=jit_compile, + ) + self.causal_lm.fit(self.raw_dataset) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_gpt2_causal_lm_fit_no_preprocessing(self, jit_compile): 
+ self.causal_lm_no_preprocessing.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + jit_compile=jit_compile, + ) + self.causal_lm_no_preprocessing.fit(self.preprocessed_dataset) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_gpt2_causal_lm_generate(self, jit_compile): + self.causal_lm_no_preprocessing.compile(jit_compile=jit_compile) + self.causal_lm.generate( + self.raw_batch, + max_length=10, + ) + + # String input + prompt = " airplane" + generated = self.causal_lm.generate( + prompt, + max_length=10, + ) + generated = generated.numpy().decode("utf-8") + self.assertTrue(prompt in generated) + + @parameterized.named_parameters( + ("tf_format", "tf", "model"), + ("keras_format", "keras_v3", "model.keras"), + ) + def test_saving_model(self, save_format, filename): + keras.utils.set_random_seed(42) + model_output = self.causal_lm.predict(self.raw_batch) + save_path = os.path.join(self.get_temp_dir(), filename) + self.causal_lm.save(save_path, save_format) + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, GPT2CausalLM) + + # Check that output matches. + keras.utils.set_random_seed(42) + restored_output = restored_model.predict(self.raw_batch) + self.assertAllClose(model_output, restored_output) diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py index 4bfde949fb..6ea65c29cd 100644 --- a/keras_nlp/models/gpt2/gpt2_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py @@ -125,7 +125,7 @@ def get_config(self): config = super().get_config() config.update( { - "sequence_length": self.packer.sequence_length, + "sequence_length": self.sequence_length, } ) return config @@ -144,10 +144,16 @@ def call(self, x, y=None, sample_weight=None): # start_column = tf.fill((batch_size, 1), self._tokenizer.end_token_id) # end_column = tf.fill((batch_size, 1), self._tokenizer.end_token_id) # token_ids = tf.concat([start_column, token_ids, end_column], axis=1) - + input_is_1d = False + if len(token_ids.shape) == 1: + input_is_1d = True + token_ids = tf.RaggedTensor.from_tensor([token_ids]) mask = tf.ones_like(token_ids, dtype=tf.bool) mask = mask.to_tensor(shape=(None, self.sequence_length)) token_ids = token_ids.to_tensor(shape=(None, self.sequence_length)) + if input_is_1d: + token_ids = tf.squeeze(token_ids, axis=0) + mask = tf.squeeze(mask, axis=0) x = { "token_ids": token_ids, "padding_mask": mask, diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_preprocessor_test.py new file mode 100644 index 0000000000..d06f8f9f45 --- /dev/null +++ b/keras_nlp/models/gpt2/gpt2_preprocessor_test.py @@ -0,0 +1,122 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for GPT2 preprocessor layer.""" + +import os + +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor +from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer + + +class GPT2PreprocessorTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + vocab = { + "<|endoftext|>": 0, + "!": 1, + "air": 2, + "Ġair": 3, + "plane": 4, + "Ġat": 5, + "port": 6, + "Ġkoh": 7, + "li": 8, + "Ġis": 9, + "Ġthe": 10, + "Ġbest": 11, + } + + merges = ["Ġ a", "Ġ t", "Ġ k", "Ġ i", "Ġ b", "a i", "p l", "n e"] + merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"] + merges += ["Ġt h", "ai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"] + merges += ["Ġai r", "Ġa i", "pla ne"] + + self.preprocessor = GPT2Preprocessor( + tokenizer=GPT2Tokenizer( + vocabulary=vocab, + merges=merges, + ), + sequence_length=8, + ) + + def test_tokenize_strings(self): + input_data = "airplane at airport" + + output = self.preprocessor(input_data) + self.assertAllEqual(output["token_ids"], [2, 4, 5, 3, 6, 0, 0, 0]) + self.assertAllEqual(output["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0]) + + def test_tokenize_list_of_strings(self): + input_data = ["airplane at airport"] * 4 + + output = self.preprocessor(input_data) + self.assertAllEqual( + output["token_ids"], + [[2, 4, 5, 3, 6, 0, 0, 0]] * 4, + ) + + self.assertAllEqual( + output["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4 + ) + + def test_tokenize_labeled_batch(self): + x = tf.constant(["airplane at airport"] * 4) + y = tf.constant([1] * 4) + sw = tf.constant([1.0] * 4) + x_out, y_out, sw_out = self.preprocessor(x, y, sw) + self.assertAllEqual(x_out["token_ids"], [[2, 4, 5, 3, 6, 0, 0, 0]] * 4) + self.assertAllEqual( + x_out["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4 + ) + self.assertAllEqual(y_out, y) + self.assertAllEqual(sw_out, sw) + + def test_tokenize_labeled_dataset(self): + x = tf.constant(["airplane at airport"] * 4) + y = tf.constant([1] * 4) + sw = tf.constant([1.0] * 4) + ds = tf.data.Dataset.from_tensor_slices((x, y, sw)) + ds = ds.map(self.preprocessor) + x_out, y_out, sw_out = ds.batch(4).take(1).get_single_element() + + self.assertAllEqual(x_out["token_ids"], [[2, 4, 5, 3, 6, 0, 0, 0]] * 4) + self.assertAllEqual( + x_out["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4 + ) + self.assertAllEqual(y_out, y) + self.assertAllEqual(sw_out, sw) + + @parameterized.named_parameters( + ("tf_format", "tf", "model"), + ("keras_format", "keras_v3", "model.keras"), + ) + def test_saved_model(self, save_format, filename): + input_data = tf.constant(["airplane at airport"]) + + inputs = keras.Input(dtype="string", shape=()) + outputs = self.preprocessor(inputs) + model = keras.Model(inputs, outputs) + + path = os.path.join(self.get_temp_dir(), filename) + model.save(path, save_format=save_format) + + restored_model = keras.models.load_model(path) + self.assertAllEqual( + model(input_data)["token_ids"], + restored_model(input_data)["token_ids"], + ) diff --git a/keras_nlp/models/roberta/roberta_preprocessor_test.py b/keras_nlp/models/roberta/roberta_preprocessor_test.py index 2045408547..481356f340 100644 --- a/keras_nlp/models/roberta/roberta_preprocessor_test.py +++ b/keras_nlp/models/roberta/roberta_preprocessor_test.py @@ -98,6 +98,9 @@ def test_tokenize_labeled_dataset(self): sw = tf.constant([1.0] * 4) ds = tf.data.Dataset.from_tensor_slices((x, y, sw)) ds = ds.map(self.preprocessor) + import pdb + + pdb.set_trace() 
x_out, y_out, sw_out = ds.batch(4).take(1).get_single_element() self.assertAllEqual( x_out["token_ids"], [[0, 3, 4, 5, 3, 6, 2, 1, 1, 1, 1, 1]] * 4 From 4fa8fc5d78b4d2c634d3dbf9809f252ca43d43e5 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Thu, 19 Jan 2023 18:36:30 -0800 Subject: [PATCH 23/28] fix serialization --- keras_nlp/models/gpt2/gpt2_causal_lm.py | 23 +++++++++++++------ .../gpt2/gpt2_causal_lm_preprocessor.py | 2 ++ keras_nlp/models/gpt2/gpt2_preprocessor.py | 17 ++++++-------- .../roberta/roberta_preprocessor_test.py | 3 --- keras_nlp/tokenizers/byte_pair_tokenizer.py | 1 - 5 files changed, 25 insertions(+), 21 deletions(-) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index ebc5efb72f..a1afe8cb24 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -28,8 +28,18 @@ from keras_nlp.utils.python_utils import classproperty -# @keras.utils.register_keras_serializable(package="keras_nlp") +@keras.utils.register_keras_serializable(package="keras_nlp") class EmbeddingMapping(keras.layers.Layer): + """A layer multiplying model outputs by the token embedding. + + This layer is used to map model outputs to logits over all vocab tokens. + It's used in `GPT2CausalLM` to calculate next token's probability. + + Args: + embedding_layer: a `tf.keras.layers.Embedding` instance, the token + embedding layer. + """ + def __init__(self, embedding_layer, name="embedding_mapping", **kwargs): super().__init__(name=name, **kwargs) self.embedding_layer = embedding_layer @@ -198,10 +208,9 @@ class GPT2CausalLM(PipelineModel): def __init__(self, backbone, preprocessor=None, **kwargs): inputs = backbone.input x = backbone(inputs) - # embedding_layer = backbone.get_layer("token_embedding") - # embedding_map_layer = EmbeddingMapping(embedding_layer) - # outputs = embedding_map_layer(x) - outputs = x + embedding_layer = backbone.get_layer("token_embedding") + embedding_map_layer = EmbeddingMapping(embedding_layer) + outputs = embedding_map_layer(x) # Instantiate using Functional API Model constructor super().__init__( @@ -219,12 +228,12 @@ def preprocess_samples(self, x, y=None, sample_weight=None): @property def backbone(self): - """The associated `keras_nlp.models.RobertaBackbone`.""" + """The associated `keras_nlp.models.GPT2Backbone`.""" return self._backbone @property def preprocessor(self): - """A `keras_nlp.models.RobertaMaskedLMPreprocessor` for preprocessing inputs.""" + """A `keras_nlp.models.GPT2CausalLMPreprocessor` for preprocessing.""" return self._preprocessor @classproperty diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py index cb16962ae0..7fbf8b4f63 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py @@ -14,11 +14,13 @@ """GPT2 Causal LM preprocessor layer.""" +from tensorflow import keras from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor from keras_nlp.utils.keras_utils import pack_x_y_sample_weight +@keras.utils.register_keras_serializable(package="keras_nlp") class GPT2CausalLMPreprocessor(GPT2Preprocessor): """GPT2 Causal LM preprocessor. 
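[Editor's note] The "fix serialization" changes above mostly consist of enabling `@keras.utils.register_keras_serializable` on the custom classes. Registration lets Keras rebuild a custom layer from its config when a saved model is loaded, without passing a `custom_objects` dict. Below is a self-contained sketch of the pattern, assuming TF 2.x; the `Scale` layer is a made-up example for this note, not KerasNLP code.

```python
import tensorflow as tf
from tensorflow import keras

@keras.utils.register_keras_serializable(package="demo")
class Scale(keras.layers.Layer):
    """Multiply inputs by a constant factor."""

    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        # Expose `factor` so the layer can be rebuilt from config alone.
        config = super().get_config()
        config.update({"factor": self.factor})
        return config

layer = Scale(factor=3.0)
config = keras.layers.serialize(layer)       # records the registered name
restored = keras.layers.deserialize(config)  # no custom_objects needed
print(restored(tf.constant([1.0, 2.0])))     # tf.Tensor([3. 6.], ...)
```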
diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py index 6ea65c29cd..8638685f88 100644 --- a/keras_nlp/models/gpt2/gpt2_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py @@ -17,6 +17,7 @@ import copy import tensorflow as tf +from tensorflow import keras from keras_nlp.models.gpt2.gpt2_presets import backbone_presets from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer @@ -28,13 +29,13 @@ from keras_nlp.utils.python_utils import classproperty +@keras.utils.register_keras_serializable(package="keras_nlp") class GPT2Preprocessor(Preprocessor): """GPT2 preprocessing layer which tokenizes and packs inputs. - This preprocessing layer will do three things: + This preprocessing layer will do 2 things: - Tokenize the input using the `tokenizer`. - - Add the id of '<|endoftext|>' to the start and end of the tokenized input. - Construct a dictionary with keys `"token_ids"`, `"padding_mask"`, that can be passed directly to a `keras_nlp.models.GPT2Backbone`. @@ -135,23 +136,19 @@ def call(self, x, y=None, sample_weight=None): if len(x) > 1: raise ValueError( "GPT2 requires each input feature to contain only " - f"one segment, but received: {len(x)}. If you are using GPT2 " + f"one segment, but received {len(x)}. If you are using GPT2 " "for a multi-segment classification task, please refer to " "classification models like BERT or RoBERTa." ) token_ids = self._tokenizer(x[0]) - # batch_size = token_ids.nrows() - # start_column = tf.fill((batch_size, 1), self._tokenizer.end_token_id) - # end_column = tf.fill((batch_size, 1), self._tokenizer.end_token_id) - # token_ids = tf.concat([start_column, token_ids, end_column], axis=1) - input_is_1d = False - if len(token_ids.shape) == 1: - input_is_1d = True + input_is_1d = len(token_ids.shape) == 1 + if input_is_1d: token_ids = tf.RaggedTensor.from_tensor([token_ids]) mask = tf.ones_like(token_ids, dtype=tf.bool) mask = mask.to_tensor(shape=(None, self.sequence_length)) token_ids = token_ids.to_tensor(shape=(None, self.sequence_length)) if input_is_1d: + # If the input is a single string, we let the output be a 1D tensor. 
token_ids = tf.squeeze(token_ids, axis=0) mask = tf.squeeze(mask, axis=0) x = { diff --git a/keras_nlp/models/roberta/roberta_preprocessor_test.py b/keras_nlp/models/roberta/roberta_preprocessor_test.py index 481356f340..2045408547 100644 --- a/keras_nlp/models/roberta/roberta_preprocessor_test.py +++ b/keras_nlp/models/roberta/roberta_preprocessor_test.py @@ -98,9 +98,6 @@ def test_tokenize_labeled_dataset(self): sw = tf.constant([1.0] * 4) ds = tf.data.Dataset.from_tensor_slices((x, y, sw)) ds = ds.map(self.preprocessor) - import pdb - - pdb.set_trace() x_out, y_out, sw_out = ds.batch(4).take(1).get_single_element() self.assertAllEqual( x_out["token_ids"], [[0, 3, 4, 5, 3, 6, 2, 1, 1, 1, 1, 1]] * 4 diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index 93a0ba3ce6..85639c0098 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -77,7 +77,6 @@ def remove_strings_from_inputs(tensor, string_to_remove): flatten_indexes = tf.where(non_empty_mask) flatten_result = tf.gather_nd(tensor, flatten_indexes) row_lengths = tf.reduce_sum(tf.cast(non_empty_mask, tf.int64), axis=1) - result = tf.RaggedTensor.from_row_lengths( values=flatten_result, row_lengths=row_lengths, From cb12604e040397308b21dd2c1f9fc397ed485839 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Mon, 23 Jan 2023 15:20:58 -0800 Subject: [PATCH 24/28] fix docstring --- keras_nlp/models/gpt2/gpt2_causal_lm.py | 211 +++++++----------- .../gpt2/gpt2_causal_lm_preprocessor.py | 26 +-- keras_nlp/models/gpt2/gpt2_preprocessor.py | 7 +- keras_nlp/samplers/sampler.py | 4 +- 4 files changed, 95 insertions(+), 153 deletions(-) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index a1afe8cb24..f010c5e288 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -24,12 +24,12 @@ GPT2CausalLMPreprocessor, ) from keras_nlp.models.gpt2.gpt2_presets import backbone_presets -from keras_nlp.utils.pipeline_model import PipelineModel +from keras_nlp.models.task import Task from keras_nlp.utils.python_utils import classproperty @keras.utils.register_keras_serializable(package="keras_nlp") -class EmbeddingMapping(keras.layers.Layer): +class ReverseEmbedding(keras.layers.Layer): """A layer multiplying model outputs by the token embedding. This layer is used to map model outputs to logits over all vocab tokens. @@ -69,7 +69,7 @@ def from_config(cls, config): @keras.utils.register_keras_serializable(package="keras_nlp") -class GPT2CausalLM(PipelineModel): +class GPT2CausalLM(Task): """GPT2 Causal LM task model. Causal LM is predicting the next token based on previous tokens, which is @@ -95,68 +95,56 @@ class GPT2CausalLM(PipelineModel): Examples: - Example usage. + Use `generate()` method to do text generation. ```python - features = { - "token_ids": tf.constant( - [[1, 2, 3, 4, 0, 0]] * 2, shape=(2, 6) - ), - "padding_mask": tf.constant( - [[1, 1, 1, 1, 0, 0]] * 2, shape=(2, 6) - ), - } - labels = tf.constant( - [[2, 3, 4, 0, 0, 0]] * 2, shape=(2, 6) - ) - sample_weights = tf.constant( - [[1, 1, 1, 0, 0, 0]] * 2, shape=(2, 6) - ) + gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + gpt2_lm.generate("I want to say", max_length=30) - # Randomly initialize a GPT2 backbone. 
- backbone = keras_nlp.models.GPT2Backbone( - vocabulary_size=50257, - num_layers=2, - num_heads=2, - hidden_dim=128, - intermediate_dim=256, - max_sequence_length=128, - ) - # Create a `GPT2CausalLM` and fit the data. - gpt2_lm = keras_nlp.models.GPT2CausalLM(backbone, preprocessor=None) - gpt2_lm.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - ) - gpt2_lm.fit( - x=features, - y=labels, - sample_weights=sample_weights, - batch_size=2, - ) + # Generate with batched prompts. + gpt2_lm.generate(["This is a", "Where are you"], max_length=30) + ``` + + Use a custom sampler for text generation. + ```python + gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + + # Use string identifier to set sampler. + gpt2_lm.generate("I want to say", max_length=30, sampler="top_p") + + # Construct a sampler instance. + sampler = keras_nlp.samplers.BeamSampler(num_beams=2) + gpt2_lm.generate("I want to say", max_length=30, sampler=sampler) + ``` + + Load a pretrained `GPT2CausalLM` and get outputs on raw string inputs. + ```python + str_inputs = "You know this is just a test string" + gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + gpt2_lm.predict([str_inputs]) ``` - Raw string inputs. + Load a pretrained GPT2 and fit on a string dataset. ```python - # Create a dataset with raw string features in an `(x, y)` format. features = [ "I don't listen to music while coding.", "But I watch youtube while coding!", ] + ds = tf.data.Dataset.from_tensor_slices(features) # Create a `GPT2CausalLM` and fit your data. - gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset( + "gpt2_base_en", + ) gpt2_lm.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), ) gpt2_lm.fit(x=features, batch_size=2) ``` - Raw string inputs with customized preprocessing. + Load a pretrain `GPT2CausalLM` with custom preprocessor, and predict on + string inputs. ```python - # Create a dataset with raw string features in an `(x, y)` format. - features = [ - "I don't listen to music while coding.", - "But I watch youtube while coding!", - ] + str_inputs = "You know this is still a test string" # Use a shorter sequence length. preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset( @@ -164,43 +152,53 @@ class GPT2CausalLM(PipelineModel): sequence_length=128, ) - # Create a `GPT2CausalLM` and fit your data. + # Create a `GPT2CausalLM`, using pretrained GPT2 and custom preprocessor. gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset( "gpt2_base_en", preprocessor=preprocessor, ) - gpt2_lm.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - ) - gpt2_lm.fit(x=features, batch_size=2) + gpt2_lm.predict([str_inputs]) ``` - # Use tf dataset. + Fit your preprocessed data with randomly initialized GPT2. This is useful + when you want to do data preprocessing inside `tf.data` pipeline. ```python - features = [ - "I don't listen to music while coding.", - "But I watch youtube while coding!", - ] - ds = tf.data.Dataset.from_tensor_slices(features) + # Define preprocessed input. + features = { + "token_ids": tf.constant( + [[1, 2, 3, 4, 0, 0]] * 2, shape=(2, 6) + ), + "padding_mask": tf.constant( + [[1, 1, 1, 1, 0, 0]] * 2, shape=(2, 6) + ), + } + labels = tf.constant( + [[2, 3, 4, 0, 0, 0]] * 2, shape=(2, 6) + ) + sample_weight = tf.constant( + [[1, 1, 1, 0, 0, 0]] * 2, shape=(2, 6) + ) - # Create a `GPT2CausalLM` and fit your data. 
- gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset( - "gpt2_base_en", - preprocessor=preprocessor, + # Randomly initialize a GPT2 backbone. + backbone = keras_nlp.models.GPT2Backbone( + vocabulary_size=50257, + num_layers=2, + num_heads=2, + hidden_dim=128, + intermediate_dim=256, + max_sequence_length=128, ) + # Create a `GPT2CausalLM` without preprocessor and fit the data. + gpt2_lm = keras_nlp.models.GPT2CausalLM(backbone, preprocessor=None) gpt2_lm.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), ) - gpt2_lm.fit(x=features, batch_size=2) - ``` - - # Use `generate()` method to generate text. - ```python - gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - gpt2_lm.generate("I want to say", max_length=30) - - # Generate with batched prompts. - gpt2_lm.generate(["This is a", "Where are you"], max_length=30) + gpt2_lm.fit( + x=features, + y=labels, + sample_weight=sample_weight, + batch_size=2, + ) ``` """ @@ -209,7 +207,7 @@ def __init__(self, backbone, preprocessor=None, **kwargs): inputs = backbone.input x = backbone(inputs) embedding_layer = backbone.get_layer("token_embedding") - embedding_map_layer = EmbeddingMapping(embedding_layer) + embedding_map_layer = ReverseEmbedding(embedding_layer) outputs = embedding_map_layer(x) # Instantiate using Functional API Model constructor @@ -223,47 +221,17 @@ def __init__(self, backbone, preprocessor=None, **kwargs): self._backbone = backbone self._preprocessor = preprocessor - def preprocess_samples(self, x, y=None, sample_weight=None): - return self.preprocessor(x, y=y, sample_weight=sample_weight) - - @property - def backbone(self): - """The associated `keras_nlp.models.GPT2Backbone`.""" - return self._backbone - - @property - def preprocessor(self): - """A `keras_nlp.models.GPT2CausalLMPreprocessor` for preprocessing.""" - return self._preprocessor - @classproperty def presets(cls): return copy.deepcopy(backbone_presets) - @classmethod - def from_preset( - cls, - preset, - load_weights=True, - **kwargs, - ): - if "preprocessor" not in kwargs: - kwargs["preprocessor"] = GPT2CausalLMPreprocessor.from_preset( - preset - ) - - # Check if preset is backbone-only model. - if preset in GPT2Backbone.presets: - backbone = GPT2Backbone.from_preset(preset, load_weights) - return cls(backbone, **kwargs) - - # Otherwise must be one of class presets. - # Currently no classifier-level presets, so we raise ValueError. - if preset not in cls.presets: - raise ValueError( - "`preset` must be one of " - f"""{", ".join(cls.presets)}. 
Received: {preset}.""" - ) + @classproperty + def backbone_cls(cls): + return GPT2Backbone + + @classproperty + def preprocessor_cls(cls): + return GPT2CausalLMPreprocessor def _get_token_probability(self, prompt, mask): model_inputs = { @@ -301,8 +269,7 @@ def generate( sampler = keras_nlp.samplers.get(sampler) if hasattr(self, "jit_compile"): sampler.jit_compile = self.jit_compile - if hasattr(self, "run_eagerly"): - sampler.run_eagerly = self.run_eagerly + sampler.run_eagerly = self.run_eagerly prompt = self.preprocessor.tokenizer(prompt) generated = sampler( prompt, @@ -311,23 +278,3 @@ def generate( end_token_id=end_token_id, ) return self.preprocessor.tokenizer.detokenize(generated) - - def get_config(self): - return { - "backbone": keras.layers.serialize(self.backbone), - "preprocessor": keras.layers.serialize(self.preprocessor), - "name": self.name, - "trainable": self.trainable, - } - - @classmethod - def from_config(cls, config): - if "backbone" in config and isinstance(config["backbone"], dict): - config["backbone"] = keras.layers.deserialize(config["backbone"]) - if "preprocessor" in config and isinstance( - config["preprocessor"], dict - ): - config["preprocessor"] = keras.layers.deserialize( - config["preprocessor"] - ) - return cls(**config) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py index 7fbf8b4f63..915fcef21c 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py @@ -51,10 +51,10 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor): preprocessor("league of legends") # Tokenize a batch of sentences. - sentences = tf.constant(["taco tuesday", "gi gi gi gi"]) + sentences = tf.constant(["taco tuesday", "fish taco please!"]) preprocessor(sentences) # Same output. - preprocessor(["taco tuesday", "gi gi gi gi"]) + preprocessor(["taco tuesday", "fish taco please!"]) # Map a dataset to preprocess a single sentence. 
     features = tf.constant(
@@ -76,20 +76,10 @@ def call(self, x, y=None, sample_weight=None):
         x = super().call(x)
         token_ids, padding_mask = x["token_ids"], x["padding_mask"]

-        if len(token_ids.shape) == 1:
-            x = {
-                "token_ids": token_ids[:-1],
-                "padding_mask": padding_mask[:-1],
-            }
-            y = token_ids[1:]
-            sample_weight = padding_mask[1:]
-        else:
-            x = {
-                "token_ids": token_ids[:, :-1],
-                "padding_mask": padding_mask[:, :-1],
-            }
-
-            y = token_ids[:, 1:]
-            sample_weight = padding_mask[:, 1:]
-
+        x = {
+            "token_ids": token_ids[..., :-1],
+            "padding_mask": padding_mask[..., :-1],
+        }
+        y = token_ids[..., 1:]
+        sample_weight = padding_mask[..., 1:]
         return pack_x_y_sample_weight(x, y, sample_weight)
diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py
index 8638685f88..92b1f5d56a 100644
--- a/keras_nlp/models/gpt2/gpt2_preprocessor.py
+++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py
@@ -115,7 +115,12 @@ class GPT2Preprocessor(Preprocessor):
     ```
     """

-    def __init__(self, tokenizer, sequence_length, **kwargs):
+    def __init__(
+        self,
+        tokenizer,
+        sequence_length,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py
index a646982eb5..5016b9c99e 100644
--- a/keras_nlp/samplers/sampler.py
+++ b/keras_nlp/samplers/sampler.py
@@ -87,7 +87,7 @@ def token_probability_fn(inputs, mask):

     prompt = tf.fill((8, 1), 1)

-    sampler = keras_nlp.samplers.Greedy()
+    sampler = keras_nlp.samplers.GreedySampler()
     # Print the generated sequence (token ids).
     print(sampler(prompt, token_probability_fn, max_length=10, end_token_id=2))
     ```
@@ -118,7 +118,7 @@ def token_probability_fn(inputs, mask):
         return model(inputs)

     prompt = tokenizer("the quick brown fox")
-    sampler = keras_nlp.samplers.Greedy()
+    sampler = keras_nlp.samplers.GreedySampler()
     generated = sampler(
         prompt,
         token_probability_fn,

From f7685ca6b5a2d33ab385570661dca23f033c0bca Mon Sep 17 00:00:00 2001
From: Chen Qian
Date: Tue, 24 Jan 2023 13:55:45 -0800
Subject: [PATCH 25/28] address comments

---
 keras_nlp/models/gpt2/gpt2_causal_lm.py            | 15 +++++----------
 .../models/gpt2/gpt2_causal_lm_preprocessor.py     |  3 ++-
 2 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py
index f010c5e288..58ada27882 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -118,9 +118,8 @@ class GPT2CausalLM(Task):

     Load a pretrained `GPT2CausalLM` and get outputs on raw string inputs.
     ```python
-    str_inputs = "You know this is just a test string"
     gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
-    gpt2_lm.predict([str_inputs])
+    gpt2_lm.predict(["You know this is just a test string"])
     ```

     Load a pretrained GPT2 and fit on a string dataset.
@@ -141,11 +140,9 @@ class GPT2CausalLM(Task):
     gpt2_lm.fit(x=features, batch_size=2)
     ```

-    Load a pretrain `GPT2CausalLM` with custom preprocessor, and predict on
+    Load a pretrained `GPT2CausalLM` with custom preprocessor, and predict on
     string inputs.
     ```python
-    str_inputs = "You know this is still a test string"
-
     # Use a shorter sequence length.
     preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
         "gpt2_base_en",
@@ -157,7 +154,7 @@ class GPT2CausalLM(Task):
         "gpt2_base_en",
         preprocessor=preprocessor,
     )
-    gpt2_lm.predict([str_inputs])
+    gpt2_lm.predict(["You know this is still a test string"])
     ```

     Fit your preprocessed data with randomly initialized GPT2. This is useful
@@ -244,7 +241,6 @@ def generate(
         self,
         prompt,
         max_length,
-        end_token="<|endoftext|>",
         sampler="top_k",
     ):
         """Generate text.
@@ -263,16 +259,15 @@ def generate(
             sampler: a string or `keras_nlp.samplers.Sampler` instance. The
                 sampler to be used for text generation.
         """
-        end_token_id = self.preprocessor.tokenizer.token_to_id(end_token)
+        end_token_id = self.preprocessor.tokenizer.end_token_id

         if isinstance(sampler, str):
             sampler = keras_nlp.samplers.get(sampler)
         if hasattr(self, "jit_compile"):
             sampler.jit_compile = self.jit_compile
         sampler.run_eagerly = self.run_eagerly
-        prompt = self.preprocessor.tokenizer(prompt)
         generated = sampler(
-            prompt,
+            self.preprocessor.tokenizer(prompt),
             self._get_token_probability,
             max_length=max_length,
             end_token_id=end_token_id,
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
index 915fcef21c..f4af06e41c 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
@@ -73,13 +73,14 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor):
     """

     def call(self, x, y=None, sample_weight=None):
-
         x = super().call(x)
         token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+        # The last token does not have a next token, so we truncate it out.
         x = {
             "token_ids": token_ids[..., :-1],
             "padding_mask": padding_mask[..., :-1],
         }
+        # Target `y` will be the next token.
         y = token_ids[..., 1:]
         sample_weight = padding_mask[..., 1:]
         return pack_x_y_sample_weight(x, y, sample_weight)

From 2ed9adb312975b79ce100f63afd62ddc976f7d3a Mon Sep 17 00:00:00 2001
From: Chen Qian
Date: Wed, 25 Jan 2023 14:12:42 -0800
Subject: [PATCH 26/28] one more

---
 keras_nlp/models/gpt2/gpt2_causal_lm.py            | 57 +++----------
 .../models/gpt2/gpt2_causal_lm_preprocessor.py     |  9 +++
 keras_nlp/models/gpt2/gpt2_preprocessor.py         |  2 +-
 keras_nlp/samplers/sampler.py                      |  8 ++-
 4 files changed, 23 insertions(+), 53 deletions(-)

diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py
index 58ada27882..482eedeb7a 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -28,46 +28,6 @@ from keras_nlp.utils.python_utils import classproperty


-@keras.utils.register_keras_serializable(package="keras_nlp")
-class ReverseEmbedding(keras.layers.Layer):
-    """A layer multiplying model outputs by the token embedding.
-
-    This layer is used to map model outputs to logits over all vocab tokens.
-    It's used in `GPT2CausalLM` to calculate next token's probability.
-
-    Args:
-        embedding_layer: a `tf.keras.layers.Embedding` instance, the token
-            embedding layer.
-    """
-
-    def __init__(self, embedding_layer, name="embedding_mapping", **kwargs):
-        super().__init__(name=name, **kwargs)
-        self.embedding_layer = embedding_layer
-
-    def call(self, inputs):
-        return tf.matmul(
-            inputs,
-            self.embedding_layer.embeddings,
-            transpose_b=True,
-        )
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "embedding_layer": keras.layers.serialize(self.embedding_layer),
-            }
-        )
-        return config
-
-    @classmethod
-    def from_config(cls, config):
-        config["embedding_layer"] = keras.layers.deserialize(
-            config["embedding_layer"],
-        )
-        return cls(**config)
-
-
 @keras.utils.register_keras_serializable(package="keras_nlp")
 class GPT2CausalLM(Task):
     """GPT2 Causal LM task model.
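Reviewer note on the hunk above: the dedicated `ReverseEmbedding` layer is deleted here and replaced, in the `__init__` hunk that follows, by an inline `tf.matmul` against the transposed token embedding matrix (weight tying). A minimal, self-contained sketch of the equivalent computation — all dimensions below are toy values for illustration only:

```python
import tensorflow as tf

# Assumed toy dimensions, for illustration only.
batch_size, seq_len, hidden_dim, vocab_size = 2, 8, 128, 50257

# Stand-ins for the backbone output and the token embedding matrix.
hidden_states = tf.random.normal((batch_size, seq_len, hidden_dim))
embeddings = tf.random.normal((vocab_size, hidden_dim))

# Weight-tied output projection: reuse the [vocab, hidden] embedding
# matrix (transposed) to map hidden states to per-position vocab logits,
# instead of training a separate Dense(vocab_size) head.
logits = tf.matmul(hidden_states, embeddings, transpose_b=True)
print(logits.shape)  # (2, 8, 50257)
```

Reusing the embedding matrix keeps the input and output token representations tied and removes the serialization burden the custom layer carried, which is why the class below can be dropped entirely.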
@@ -116,7 +76,7 @@ class GPT2CausalLM(Task):
     gpt2_lm.generate("I want to say", max_length=30, sampler=sampler)
     ```

-    Load a pretrained `GPT2CausalLM` and get outputs on raw string inputs.
+    Map raw strings to language model logit predictions.
     ```python
     gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
     gpt2_lm.predict(["You know this is just a test string"])
@@ -137,7 +97,7 @@ class GPT2CausalLM(Task):
     gpt2_lm.compile(
         loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
     )
-    gpt2_lm.fit(x=features, batch_size=2)
+    gpt2_lm.fit(ds, batch_size=2)
     ```

     Load a pretrained `GPT2CausalLM` with custom preprocessor, and predict on
@@ -203,11 +163,13 @@ class GPT2CausalLM(Task):
     def __init__(self, backbone, preprocessor=None, **kwargs):
         inputs = backbone.input
         x = backbone(inputs)
-        embedding_layer = backbone.get_layer("token_embedding")
-        embedding_map_layer = ReverseEmbedding(embedding_layer)
-        outputs = embedding_map_layer(x)
+        outputs = tf.matmul(
+            x,
+            backbone.token_embedding.embeddings,
+            transpose_b=True,
+        )

-        # Instantiate using Functional API Model constructor
+        # Instantiate using Functional API Model constructor.
         super().__init__(
             inputs=inputs,
             outputs=outputs,
@@ -253,9 +215,6 @@ def generate(
             prompt: a string, string Tensor or string RaggedTensor. The
                 prompt text for generation.
             max_length: int. The max length of generated sequence.
-            end_token: string, defaults to "<|endoftext|>", which is the default
-                end token of GPT2. The token marking the end of the sequence,
-                tokens generated after the end token will be truncated.
             sampler: a string or `keras_nlp.samplers.Sampler` instance. The
                 sampler to be used for text generation.
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
index f4af06e41c..cac75f718b 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
@@ -14,6 +14,7 @@
 """GPT2 Causal LM preprocessor layer."""

+from absl import logging
 from tensorflow import keras

 from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
@@ -73,6 +74,14 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor):
     """

     def call(self, x, y=None, sample_weight=None):
+        if y is not None or sample_weight is not None:
+            logging.warning(
+                "`GPT2CausalLMPreprocessor` generates `y` and `sample_weight` "
+                "based on your input data, but your data already contain `y` "
+                "or `sample_weight`. Your `y` and `sample_weight` will be "
+                "overrided."
+            )
+
         x = super().call(x)
         token_ids, padding_mask = x["token_ids"], x["padding_mask"]
         # The last token does not have a next token, so we truncate it out.
diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py
index 92b1f5d56a..b11bac30f8 100644
--- a/keras_nlp/models/gpt2/gpt2_preprocessor.py
+++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py
@@ -145,7 +145,7 @@ def call(self, x, y=None, sample_weight=None):
                 "for a multi-segment classification task, please refer to "
                 "classification models like BERT or RoBERTa."
             )
-        token_ids = self._tokenizer(x[0])
+        token_ids = self.tokenizer(x[0])
         input_is_1d = len(token_ids.shape) == 1
         if input_is_1d:
             token_ids = tf.RaggedTensor.from_tensor([token_ids])
diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py
index 1d1fed9e12..974c46b627 100644
--- a/keras_nlp/samplers/sampler.py
+++ b/keras_nlp/samplers/sampler.py
@@ -57,9 +57,11 @@ class Sampler:

     The inputs and outputs of Sampler class are both token ids.

-    Subclassers should always implement the `sample()` method, which implements
-    the sampling algorithm body. Please check available subclass samplers for
-    example.
+    Subclassers should always implement the `get_next_token()` method, which
+    gets the next token based on probability distribution over vocab tokens.
+    Please check available subclass samplers for examples. If you need more
+    control over the sampling process, please implement `sample()` method
+    instead, see `BeamSampler` for example.

     Examples:

From f2821b5585620d15beb4453ea62cb55c14970885 Mon Sep 17 00:00:00 2001
From: Chen Qian
Date: Thu, 26 Jan 2023 14:52:23 -0800
Subject: [PATCH 27/28] fix docstring

---
 keras_nlp/models/gpt2/gpt2_causal_lm.py            | 21 ++++++++++++-------
 .../models/gpt2/gpt2_causal_lm_preprocessor.py     | 19 +++++++++--------
 keras_nlp/models/gpt2/gpt2_preprocessor.py         |  2 +-
 keras_nlp/samplers/sampler.py                      |  2 +-
 4 files changed, 25 insertions(+), 19 deletions(-)

diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py
index 482eedeb7a..b1078503e0 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -30,12 +30,13 @@

 @keras.utils.register_keras_serializable(package="keras_nlp")
 class GPT2CausalLM(Task):
-    """GPT2 Causal LM task model.
+    """An end-to-end GPT2 model for causal language modeling.

-    Causal LM is predicting the next token based on previous tokens, which is
-    the way GPT2 gets pretrained. Users can finetune `GPT2CausalLM` to generate
-    text similar to the custom dataset. `GPT2CausalLM` also has a public method
-    `generate()`, which generates text based on given prompt.
+    A causal language model (LM) predicts the next token based on previous
+    tokens, which is the way GPT2 gets pretrained. You can finetune
+    `GPT2CausalLM` to generate text similar to a custom dataset.
+    `GPT2CausalLM` also has a method `generate()`, which generates text
+    based on a given prompt.

     This model can optionally be configured with a `preprocessor` layer, in
     which case it will automatically apply preprocessing to raw inputs during
@@ -163,6 +164,8 @@ class GPT2CausalLM(Task):
     def __init__(self, backbone, preprocessor=None, **kwargs):
         inputs = backbone.input
         x = backbone(inputs)
+        # Use token embedding weights to project from the token representation
+        # to vocabulary logits.
         outputs = tf.matmul(
             x,
             backbone.token_embedding.embeddings,
             transpose_b=True,
         )
@@ -209,7 +212,8 @@ def generate(
         This method generates text based on given `prompt`. Generation will
         continue until `max_length` is met, and all tokens generated after
-        `end_token` will be truncated.
+        `end_token` will be truncated. The sampling approach used can be
+        controlled via the sampler argument.

         Args:
             prompt: a string, string Tensor or string RaggedTensor. The prompt
@@ -220,9 +224,10 @@ def generate(
         """
         end_token_id = self.preprocessor.tokenizer.end_token_id

-        if isinstance(sampler, str):
-            sampler = keras_nlp.samplers.get(sampler)
+        sampler = keras_nlp.samplers.get(sampler)
         if hasattr(self, "jit_compile"):
+            # `jit_compile` is a public property as of TF 2.12. `hasattr` is
+            # for backward compatibility.
             sampler.jit_compile = self.jit_compile
         sampler.run_eagerly = self.run_eagerly
         generated = sampler(
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
index cac75f718b..803d111c45 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
@@ -25,14 +25,15 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor):
     """GPT2 Causal LM preprocessor.

+    This preprocessor is mainly used as the preprocessor for `GPT2CausalLM`.
     This class subclasses `keras_nlp.models.GPT2Preprocessor` and keeps most of
     its functionality. The only change is `GPT2CausalLMPreprocessor` sets
     `y` (label) and `sample_weights` field by shifting the input sequence
     one step towards left, and drop the last token as it does not have a
     successor, e.g., if the tokenized input is `[1, 2, 3, 0, 0]` with
-    `padding_mask=[1, 1, 1, 0, 0]`, then after preprocessing, we
-    will have `x=[1, 2, 3, 0]` and `y=[2, 3, 0, 0]`, with
-    `padding_mask=[1, 1, 1, 0]` and `sample_weights=[1, 1, 0, 0]`.
+    `padding_mask = [1, 1, 1, 0, 0]`, then after preprocessing, we
+    will have `x = [1, 2, 3, 0]` and `y = [2, 3, 0, 0]`, with
+    `padding_mask = [1, 1, 1, 0]` and `sample_weights = [1, 1, 0, 0]`.

     Args:
         tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
@@ -46,16 +47,16 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor):
     )

     # Tokenize and pack a single sentence.
-    sentence = tf.constant("league of legends")
+    sentence = tf.constant("League of legends")
     preprocessor(sentence)
     # Same output.
-    preprocessor("league of legends")
+    preprocessor("League of legends")

     # Tokenize a batch of sentences.
-    sentences = tf.constant(["taco tuesday", "fish taco please!"])
+    sentences = tf.constant(["Taco Tuesday", "Fish taco please!"])
     preprocessor(sentences)
     # Same output.
-    preprocessor(["taco tuesday", "fish taco please!"])
+    preprocessor(["Taco Tuesday", "Fish taco please!"])

     # Map a dataset to preprocess a single sentence.
     features = tf.constant(
@@ -77,9 +78,9 @@ def call(self, x, y=None, sample_weight=None):
         if y is not None or sample_weight is not None:
             logging.warning(
                 "`GPT2CausalLMPreprocessor` generates `y` and `sample_weight` "
-                "based on your input data, but your data already contain `y` "
+                "based on your input data, but your data already contains `y` "
                 "or `sample_weight`. Your `y` and `sample_weight` will be "
-                "overrided."
+                "ignored."
             )
         x = super().call(x)
diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py
index b11bac30f8..5f6090a2d7 100644
--- a/keras_nlp/models/gpt2/gpt2_preprocessor.py
+++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py
@@ -51,7 +51,7 @@ class GPT2Preprocessor(Preprocessor):
     passed through unaltered.

     `GPT2Preprocessor` forces the input to have only one segment, as GPT2 is
-    mainly used for generation tasks.for tasks having multi-segment inputs
+    mainly used for generation tasks. For tasks having multi-segment inputs
     like "glue/mnli", please use a model designed for classification purposes
     such as BERT or RoBERTa.
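The shift-by-one behavior documented in the preprocessor docstring above is easy to verify with plain tensors. A minimal sketch using the toy values from the docstring example:

```python
import tensorflow as tf

# Toy tokenized batch matching the docstring example above.
token_ids = tf.constant([[1, 2, 3, 0, 0]])
padding_mask = tf.constant([[1, 1, 1, 0, 0]])

# Inputs drop the last position (it has no successor token)...
x = {
    "token_ids": token_ids[..., :-1],        # [[1, 2, 3, 0]]
    "padding_mask": padding_mask[..., :-1],  # [[1, 1, 1, 0]]
}
# ...and targets are the inputs shifted one step to the left, with the
# padding mask reused as the per-token sample weight.
y = token_ids[..., 1:]                 # [[2, 3, 0, 0]]
sample_weight = padding_mask[..., 1:]  # [[1, 1, 0, 0]]
```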
diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py
index 974c46b627..5a849b7b49 100644
--- a/keras_nlp/samplers/sampler.py
+++ b/keras_nlp/samplers/sampler.py
@@ -61,7 +61,7 @@ class Sampler:
     gets the next token based on probability distribution over vocab tokens.
     Please check available subclass samplers for examples. If you need more
     control over the sampling process, please implement `sample()` method
-    instead, see `BeamSampler` for example.
+    instead, see `keras_nlp.samplers.BeamSampler` for example.

     Examples:

From 728a47158b31da618e3a2f1400524afa2fadc261 Mon Sep 17 00:00:00 2001
From: Chen Qian
Date: Fri, 27 Jan 2023 13:50:40 -0800
Subject: [PATCH 28/28] minor fix

---
 keras_nlp/samplers/sampler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py
index 5a849b7b49..dde87a6613 100644
--- a/keras_nlp/samplers/sampler.py
+++ b/keras_nlp/samplers/sampler.py
@@ -61,7 +61,7 @@ class Sampler:
     gets the next token based on probability distribution over vocab tokens.
     Please check available subclass samplers for examples. If you need more
     control over the sampling process, please implement `sample()` method
-    instead, see `keras_nlp.samplers.BeamSampler` for example.
+    instead, see `keras_nlp.samplers.BeamSampler` for examples.

     Examples:
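To close with an illustration of the subclassing contract described in the `Sampler` docstring above: simple strategies only need to override `get_next_token()`. The sketch below is hypothetical — the class name is invented, and the exact base-class signature should be verified against `keras_nlp/samplers/sampler.py`:

```python
import tensorflow as tf

from keras_nlp.samplers.sampler import Sampler


class MyGreedySampler(Sampler):
    """Hypothetical sampler: always picks the most probable next token."""

    def get_next_token(self, probabilities):
        # `probabilities` is assumed to be a [batch_size, vocab_size]
        # distribution over the vocabulary for the next position; greedy
        # decoding simply takes the argmax per batch element.
        return tf.argmax(probabilities, axis=-1)
```

Strategies that need full control of the decoding loop (e.g. beam search, which tracks multiple candidate sequences) would override `sample()` instead, as the docstring notes.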