Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras_nlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from keras_nlp import layers
from keras_nlp import metrics
from keras_nlp import models
from keras_nlp import samplers
from keras_nlp import tokenizers
from keras_nlp import utils

Expand Down
83 changes: 83 additions & 0 deletions keras_nlp/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorflow import keras

from keras_nlp.samplers.greedy import Greedy


def serialize(sampler):
    """Return a serialized config dict for the given `Sampler` instance.

    Thin wrapper over `keras.utils.serialize_keras_object`; the result can be
    round-tripped through `deserialize`.
    """
    return keras.utils.serialize_keras_object(sampler)


def deserialize(config, custom_objects=None):
    """Return a `Sampler` object from its config.

    Args:
        config: A string name (e.g. `"greedy"`) or a serialization dict with
            `class_name` and `config` keys.
        custom_objects: Optional dict mapping names to custom classes used
            during deserialization.
    """
    # Registry of built-in samplers addressable by their lowercase string name.
    module_objects = {
        "greedy": Greedy,
    }
    return keras.utils.deserialize_keras_object(
        config,
        module_objects=module_objects,
        custom_objects=custom_objects,
        printable_module_name="samplers",
    )


def get(identifier):
    """Retrieve a KerasNLP sampler by the identifier.

    The `identifier` may be the string name of a sampler class or class.

    >>> identifier = 'greedy'
    >>> sampler = keras_nlp.samplers.get(identifier)

    You can also specify `config` of the sampler to this function by passing
    dict containing `class_name` and `config` as an identifier. Also note that
    the `class_name` must map to a `Sampler` class.

    >>> cfg = {'class_name': 'keras_nlp>Greedy', 'config': {}}
    >>> sampler = keras_nlp.samplers.get(cfg)

    In the case that the `identifier` is a callable (e.g. a `Sampler` class
    or instance), it is returned unchanged; callers are responsible for
    instantiating a class identifier themselves.

    Args:
        identifier: String or dict that contains the sampler name or
            configurations.

    Returns:
        Sampler instance base on the input identifier.

    Raises:
        KeyError: If a string identifier is not all lowercase.
        ValueError: If the input identifier is not a supported type or in a bad
            format.
    """

    if identifier is None:
        return None
    if isinstance(identifier, dict):
        return deserialize(identifier)
    elif isinstance(identifier, str):
        # String identifiers are canonically lowercase (e.g. "greedy");
        # reject other casings early with a clear error.
        if not identifier.islower():
            raise KeyError(
                "`keras_nlp.samplers.get()` must take a lowercase string "
                f"identifier, but received: {identifier}."
            )
        return deserialize(identifier)
    elif callable(identifier):
        # Classes and instances are passed through as-is, not instantiated.
        return identifier
    else:
        raise ValueError(
            "Could not interpret sampler identifier: " + str(identifier)
        )
149 changes: 149 additions & 0 deletions keras_nlp/samplers/greedy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Greedy Sampler."""

import tensorflow as tf
from tensorflow import keras

from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.sampler import base_sampler_args_docstring
from keras_nlp.samplers.sampler import call_args_docstring
from keras_nlp.samplers.sampler import sample_args_docstring
from keras_nlp.utils.python_utils import format_docstring


@format_docstring(
    base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring
)
@keras.utils.register_keras_serializable(package="keras_nlp")
class Greedy(Sampler):
    """Greedy sampler class.

    This sampler is implemented on greedy search, i.e., always picking up the
    token of the largest probability as the next token.

    Args:
        {{base_sampler_args}}

    Call Args:
        {{call_args}}

    Examples:
    ```python
    BATCH_SIZE = 8
    VOCAB_SIZE = 10
    FEATURE_SIZE = 16
    START_ID = 1

    # Create a dummy model to predict the next token.
    model = keras.Sequential(
        [
            keras.Input(shape=[None]),
            keras.layers.Embedding(
                input_dim=VOCAB_SIZE,
                output_dim=FEATURE_SIZE,
            ),
            keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
        ]
    )

    # Define a function that outputs the next token's probability for each token
    # in the input sequence.
    def token_probability_fn(inputs, mask):
        return model(inputs)

    prompt = tf.fill((BATCH_SIZE, 1), START_ID)

    sampler = keras_nlp.samplers.Greedy()
    # Print the generated sequence (token ids).
    print(sampler(prompt, token_probability_fn, 10))
    ```
    """

    def __init__(
        self,
        jit_compile=True,
    ):
        # NOTE: `jit_compile` stays configurable here because only `sample()`
        # is XLA-compatible. The prompt preprocessing in the base `__call__`
        # changes tensor shapes and therefore cannot run under XLA.
        super().__init__(jit_compile)

    @format_docstring(sample_args=sample_args_docstring)
    def sample(
        self, prompt, token_probability_fn, mask, num_steps, from_logits=True
    ):
        """Sampling logic implementation.

        Args:
            {{sample_args}}
        """
        batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
        max_length = tf.cast(max_length, num_steps.dtype)
        # The index of the last non-padding token in prompt. Since all sequences
        # are aligned to the right side, the index is the same for all.
        current_index = max_length - num_steps

        def one_step(current_index, prompt, mask):
            # Run the model on the full sequence, then gather only the
            # distribution predicted at the position before `current_index`.
            probs = token_probability_fn(prompt, mask)
            next_token_prob = tf.gather(
                probs,
                tf.repeat(current_index - 1, batch_size),
                axis=1,
                batch_dims=1,
            )
            # Greedy choice: always take the highest-probability token.
            next_token = tf.cast(
                tf.argmax(next_token_prob, axis=-1), dtype=prompt.dtype
            )
            # Preserve tokens the user supplied in the prompt (mask is True
            # at positions that already hold real tokens).
            next_token = tf.where(
                mask[:, current_index], prompt[:, current_index], next_token
            )
            # Mark position `current_index` as filled for every sequence.
            mask = tf.tensor_scatter_nd_update(
                tensor=mask,
                indices=tf.stack(
                    (
                        tf.cast(
                            tf.range(batch_size), dtype=current_index.dtype
                        ),
                        tf.repeat(current_index, batch_size),
                    ),
                    axis=1,
                ),
                updates=tf.repeat(True, batch_size),
            )

            # Append the next token to current sequence.
            prompt = tf.tensor_scatter_nd_update(
                tensor=prompt,
                indices=tf.stack(
                    (
                        tf.cast(
                            tf.range(batch_size), dtype=current_index.dtype
                        ),
                        tf.repeat(current_index, batch_size),
                    ),
                    axis=1,
                ),
                updates=next_token,
            )

            current_index = tf.add(current_index, 1)
            return (current_index, prompt, mask)

        # Run a while loop till `max_length` of tokens has been generated.
        current_index, prompt, mask = tf.while_loop(
            cond=lambda current_index, prompt, mask: tf.less(
                current_index, max_length
            ),
            body=one_step,
            loop_vars=(current_index, prompt, mask),
        )
        return prompt
128 changes: 128 additions & 0 deletions keras_nlp/samplers/greedy_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Greedy sampler."""

import tensorflow as tf
from absl.testing import parameterized
from tensorflow import keras

from keras_nlp.samplers.greedy import Greedy


class GreedyTest(tf.test.TestCase, parameterized.TestCase):
    def setUp(self):
        super().setUp()
        self.vocab_size = 10
        self.feature_size = 16

        # A tiny untrained network standing in for a real language model: it
        # maps token ids to a probability distribution over the vocabulary.
        dummy_model = keras.Sequential(
            [
                keras.Input(shape=[None]),
                keras.layers.Embedding(
                    input_dim=self.vocab_size,
                    output_dim=self.feature_size,
                ),
                keras.layers.Dense(self.vocab_size),
                keras.layers.Softmax(),
            ]
        )

        # The probability function ignores `mask` and just runs the model.
        self.token_probability_fn = lambda inputs, mask: dummy_model(inputs)
        self.sampler = Greedy()

    def test_generate_with_1d_prompt(self):
        prompt = tf.constant([1])
        output = self.sampler(prompt, self.token_probability_fn, max_length=5)
        self.assertEqual(output.shape, [5])

    def test_generate_with_2d_prompt(self):
        prompt = tf.constant([[1], [1]])
        output = self.sampler(prompt, self.token_probability_fn, max_length=5)
        self.assertEqual(output.shape, [2, 5])

    def test_generate_with_list_prompt(self):
        prompt = [[1], [1]]
        output = self.sampler(prompt, self.token_probability_fn, max_length=5)
        self.assertEqual(output.shape, [2, 5])

    def test_generate_with_ragged_prompt(self):
        max_length = 5

        def token_probability_fn(inputs, mask):
            # The sampler must densify ragged prompts before invoking the
            # user-provided probability function.
            self.assertIsInstance(inputs, tf.Tensor)
            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
            return tf.repeat(tf.repeat(prob, 2, axis=0), max_length, axis=1)

        prompt = tf.ragged.constant([[1], [2, 1, 2]])
        output = self.sampler(prompt, token_probability_fn, max_length)
        self.assertEqual(output.shape, [2, 5])

    def test_assert_generation_is_correct(self):
        batch_size = 10
        max_length = 3

        # Token 3 always has probability 1, so greedy generation must emit
        # only 3s after the prompt.
        def token_probability_fn(inputs, mask):
            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
            return tf.repeat(
                tf.repeat(prob, batch_size, axis=0), max_length, axis=1
            )

        prompt = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
        output = self.sampler(
            prompt, token_probability_fn, max_length=max_length
        )
        self.assertAllEqual(
            output, 3 * tf.ones(shape=[batch_size, max_length])
        )

    def test_end_token_id(self):
        max_length = 5

        def token_probability_fn(inputs, mask):
            batch_size = inputs.shape[0]
            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
            return tf.repeat(
                tf.repeat(prob, batch_size, axis=0), max_length, axis=1
            )

        # The second prompt already ends with the end token (2), so its
        # output should be truncated right before it.
        sampler = Greedy()
        prompt = tf.constant([[0, 1], [1, 2]])
        output = sampler(
            prompt,
            token_probability_fn,
            max_length=max_length,
            end_token_id=2,
        )
        expected = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
        self.assertAllEqual(output, expected)

    def test_compare_xla_noxla_results(self):
        # XLA-compiled and eager sampling must agree token-for-token.
        prompt = [[1], [1]]
        outputs_xla = Greedy(jit_compile=True)(
            prompt, self.token_probability_fn, max_length=5
        )
        outputs_no_xla = Greedy(jit_compile=False)(
            prompt, self.token_probability_fn, max_length=5
        )
        self.assertAllEqual(outputs_xla, outputs_no_xla)
Loading