16 changes: 14 additions & 2 deletions keras_nlp/samplers/beam_sampler_test.py
@@ -125,13 +125,25 @@ def token_probability_fn(inputs, mask):
             prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
             return tf.tile(prob, [batch_size, seq_length, 1])
 
-        max_length = 5
+        max_length = 4
         inputs = tf.constant([[0, 1], [1, 2]])
         outputs = self.sampler(
             inputs,
             token_probability_fn,
             max_length=max_length,
             end_token_id=2,
         )
-        expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
+        # end_token in prompt does not trigger truncation.
+        expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
         self.assertAllEqual(outputs, expected_outputs)
+
+        max_length = 4
+        inputs = tf.constant([[0, 1], [1, 3]])
+        outputs = self.sampler(
+            inputs,
+            token_probability_fn,
+            max_length=max_length,
+            end_token_id=3,
+        )
+        expected_outputs = tf.ragged.constant([[0, 1], [1, 3]])
+        self.assertAllEqual(outputs, expected_outputs)
16 changes: 13 additions & 3 deletions keras_nlp/samplers/greedy_sampler_test.py
@@ -93,15 +93,14 @@ def token_probability_fn(inputs, mask):
         )
 
     def test_end_token_id(self):
-        max_length = 5
-
         def token_probability_fn(inputs, mask):
             batch_size = inputs.shape[0]
             prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
             return tf.repeat(
                 tf.repeat(prob, batch_size, axis=0), max_length, axis=1
             )
 
+        max_length = 4
         sampler = GreedySampler()
         inputs = tf.constant([[0, 1], [1, 2]])
         outputs = sampler(
@@ -110,7 +109,18 @@ def token_probability_fn(inputs, mask):
             max_length=max_length,
             end_token_id=2,
         )
-        expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
+        # end_token in prompt does not trigger truncation.
+        expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
         self.assertAllEqual(outputs, expected_outputs)
+
+        outputs = sampler(
+            inputs,
+            token_probability_fn,
+            max_length=max_length,
+            end_token_id=3,
+        )
+        # Generated end_token will be truncated.
+        expected_outputs = tf.ragged.constant([[0, 1], [1, 2]])
+        self.assertAllEqual(outputs, expected_outputs)
 
     def test_compare_xla_noxla_results(self):
15 changes: 10 additions & 5 deletions keras_nlp/samplers/sampler.py
@@ -189,14 +189,18 @@ def _pad_prompt(self, prompt, max_length):
 
     def _mask_tokens_after_end_token(
         self,
-        prompt,
+        generated_result,
+        original_padding_mask,
         max_length,
         end_token_id,
     ):
         """Helper function to truncate the tokens after the end token."""
-        # Mask out tokens after `end_token_id` is encountered.
+        # Create a tensor with True for each end_token_id.
+        end_tokens = generated_result == end_token_id
+        # Remove all end_token_ids in the original input.
+        end_tokens = end_tokens & (original_padding_mask == tf.constant(False))
         # Find index of first end_token_id.
-        end_indices = tf.math.argmax(prompt == end_token_id, -1)
+        end_indices = tf.math.argmax(end_tokens, -1)
         # Use max_length if no `end_token_id` is found.
         end_indices = tf.where(
             end_indices == 0,
@@ -205,7 +209,7 @@ def _mask_tokens_after_end_token(
         )
         # Truncate out tokens after (including) the end token.
         mask_indices = tf.sequence_mask(end_indices, maxlen=max_length)
-        return tf.ragged.boolean_mask(prompt, mask_indices)
+        return tf.ragged.boolean_mask(generated_result, mask_indices)
 
     def __call__(
         self,
@@ -217,7 +221,6 @@ def __call__(
         from_logits=True,
     ):
         prompt, mask = self._validate_prompt_and_mask(prompt, mask)
-
         input_is_1d = prompt.shape.rank == 1
         if input_is_1d:
             prompt = tf.RaggedTensor.from_tensor(prompt[tf.newaxis, :])
@@ -228,6 +231,7 @@ def __call__(
         # static shape, which means we cannot concatenate generated token to
         # current prompt.
         prompt, mask = self._pad_prompt(prompt, max_length)
+        original_padding_mask = tf.identity(mask)
         self._validate_token_probability_fn(token_probability_fn, prompt, mask)
 
         # Convert `sample` method to a `tf.function` if `self.run_eagerly=False`
@@ -247,6 +251,7 @@
         if end_token_id is not None:
            prompt = self._mask_tokens_after_end_token(
                 prompt,
+                original_padding_mask,
                 max_length,
                 end_token_id,
             )
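The new truncation logic above is easy to check in isolation. Below is a minimal sketch (not part of the PR) that re-runs the body of `_mask_tokens_after_end_token` on the greedy test's values; the `tf.cast` inside `tf.where` is an assumption about the lines GitHub collapsed between the hunks:

import tensorflow as tf

# Values from the greedy sampler test: a batch of two prompts of length 2,
# already generated out to max_length = 4.
generated_result = tf.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
# True where the token came from the original prompt, False where generated.
original_padding_mask = tf.constant(
    [[True, True, False, False], [True, True, False, False]]
)
max_length = 4
end_token_id = 2

# True for each end_token_id in the result.
end_tokens = generated_result == end_token_id
# Remove all end_token_ids in the original input.
end_tokens = end_tokens & (original_padding_mask == tf.constant(False))
# Index of the first *generated* end_token_id; argmax returns 0 if none.
end_indices = tf.math.argmax(end_tokens, -1)
# Use max_length if no end_token_id is found (assumed collapsed lines).
end_indices = tf.where(
    end_indices == 0,
    tf.cast(max_length, end_indices.dtype),
    end_indices,
)
mask_indices = tf.sequence_mask(end_indices, maxlen=max_length)
print(tf.ragged.boolean_mask(generated_result, mask_indices))
# <tf.RaggedTensor [[0, 1, 3, 3], [1, 2, 3, 3]]> -- the end_token in the
# second prompt no longer truncates its row. With end_token_id = 3 the same
# code yields [[0, 1], [1, 2]], matching the updated tests.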
19 changes: 15 additions & 4 deletions keras_nlp/samplers/top_k_sampler_test.py
@@ -144,17 +144,28 @@ def token_probability_fn(inputs, mask):
             prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
             return tf.tile(prob, [batch_size, seq_length, 1])
 
-        max_length = 5
-        inputs = tf.constant([[0, 1], [1, 2]])
         tf.random.set_seed(42)
         sampler = TopKSampler(k=4, seed=42)
+        max_length = 4
+        inputs = tf.constant([[0, 1], [1, 2]])
         outputs = sampler(
             inputs,
             token_probability_fn,
             max_length=max_length,
             end_token_id=2,
             from_logits=False,
         )
-        # Top-k sampling result with seed 42.
-        expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
+        # end_token in prompt does not trigger truncation.
+        expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
         self.assertAllEqual(outputs, expected_outputs)
+
+        outputs = sampler(
+            inputs,
+            token_probability_fn,
+            max_length=max_length,
+            end_token_id=3,
+            from_logits=False,
+        )
+        # Generated end_token will be truncated.
+        expected_outputs = tf.ragged.constant([[0, 1], [1, 2]])
+        self.assertAllEqual(outputs, expected_outputs)
19 changes: 15 additions & 4 deletions keras_nlp/samplers/top_p_sampler_test.py
@@ -144,17 +144,28 @@ def token_probability_fn(inputs, mask):
             prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
             return tf.tile(prob, [batch_size, seq_length, 1])
 
-        max_length = 5
-        inputs = tf.constant([[0, 1], [1, 2]])
         tf.random.set_seed(42)
         sampler = TopPSampler(p=0.1, seed=42)
+        max_length = 4
+        inputs = tf.constant([[0, 1], [1, 2]])
         outputs = sampler(
             inputs,
             token_probability_fn,
             max_length=max_length,
             end_token_id=2,
             from_logits=False,
         )
-        # Top-p sampling result with seed 42.
-        expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
+        # end_token in prompt does not trigger truncation.
+        expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
         self.assertAllEqual(outputs, expected_outputs)
+
+        outputs = sampler(
+            inputs,
+            token_probability_fn,
+            max_length=max_length,
+            end_token_id=3,
+            from_logits=False,
+        )
+        # Generated end_token will be truncated.
+        expected_outputs = tf.ragged.constant([[0, 1], [1, 2]])
+        self.assertAllEqual(outputs, expected_outputs)