Commit 10d451e

Fix the sampler truncation strategy (#713)
* Fix the truncation strategy
* address comments
* fix naming
1 parent 6aba751 commit 10d451e

File tree

5 files changed (+67, -18 lines)


keras_nlp/samplers/beam_sampler_test.py

Lines changed: 14 additions & 2 deletions
@@ -125,13 +125,25 @@ def token_probability_fn(inputs, mask):
             prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
             return tf.tile(prob, [batch_size, seq_length, 1])
 
-        max_length = 5
+        max_length = 4
         inputs = tf.constant([[0, 1], [1, 2]])
         outputs = self.sampler(
             inputs,
             token_probability_fn,
             max_length=max_length,
             end_token_id=2,
         )
-        expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
+        # end_token in prompt does not trigger truncation.
+        expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
+        self.assertAllEqual(outputs, expected_outputs)
+
+        max_length = 4
+        inputs = tf.constant([[0, 1], [1, 3]])
+        outputs = self.sampler(
+            inputs,
+            token_probability_fn,
+            max_length=max_length,
+            end_token_id=3,
+        )
+        expected_outputs = tf.ragged.constant([[0, 1], [1, 3]])
         self.assertAllEqual(outputs, expected_outputs)

keras_nlp/samplers/greedy_sampler_test.py

Lines changed: 13 additions & 3 deletions
@@ -93,15 +93,14 @@ def token_probability_fn(inputs, mask):
         )
 
     def test_end_token_id(self):
-        max_length = 5
-
         def token_probability_fn(inputs, mask):
             batch_size = inputs.shape[0]
             prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
             return tf.repeat(
                 tf.repeat(prob, batch_size, axis=0), max_length, axis=1
             )
 
+        max_length = 4
         sampler = GreedySampler()
         inputs = tf.constant([[0, 1], [1, 2]])
         outputs = sampler(
@@ -110,7 +109,18 @@ def token_probability_fn(inputs, mask):
             max_length=max_length,
             end_token_id=2,
         )
-        expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
+        # end_token in prompt does not trigger truncation.
+        expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
+        self.assertAllEqual(outputs, expected_outputs)
+
+        outputs = sampler(
+            inputs,
+            token_probability_fn,
+            max_length=max_length,
+            end_token_id=3,
+        )
+        # Generated end_token will be truncated.
+        expected_outputs = tf.ragged.constant([[0, 1], [1, 2]])
         self.assertAllEqual(outputs, expected_outputs)
 
     def test_compare_xla_noxla_results(self):

keras_nlp/samplers/sampler.py

Lines changed: 10 additions & 5 deletions
@@ -189,14 +189,18 @@ def _pad_prompt(self, prompt, max_length):
 
     def _mask_tokens_after_end_token(
         self,
-        prompt,
+        generated_result,
+        original_padding_mask,
         max_length,
         end_token_id,
     ):
         """Helper function to truncate the tokens after the end token."""
-        # Mask out tokens after `end_token_id` is encountered.
+        # Create a tensor with True for each end_token_id.
+        end_tokens = generated_result == end_token_id
+        # Remove all end_token_ids in the original input.
+        end_tokens = end_tokens & (original_padding_mask == tf.constant(False))
         # Find index of first end_token_id.
-        end_indices = tf.math.argmax(prompt == end_token_id, -1)
+        end_indices = tf.math.argmax(end_tokens, -1)
         # Use max_length if no `end_token_id` is found.
         end_indices = tf.where(
             end_indices == 0,
@@ -205,7 +209,7 @@ def _mask_tokens_after_end_token(
         )
         # Truncate out tokens after (including) the end token.
         mask_indices = tf.sequence_mask(end_indices, maxlen=max_length)
-        return tf.ragged.boolean_mask(prompt, mask_indices)
+        return tf.ragged.boolean_mask(generated_result, mask_indices)
 
     def __call__(
         self,
@@ -217,7 +221,6 @@ def __call__(
         from_logits=True,
     ):
         prompt, mask = self._validate_prompt_and_mask(prompt, mask)
-
         input_is_1d = prompt.shape.rank == 1
         if input_is_1d:
             prompt = tf.RaggedTensor.from_tensor(prompt[tf.newaxis, :])
@@ -228,6 +231,7 @@ def __call__(
         # static shape, which means we cannot concatenate generated token to
         # current prompt.
         prompt, mask = self._pad_prompt(prompt, max_length)
+        original_padding_mask = tf.identity(mask)
         self._validate_token_probability_fn(token_probability_fn, prompt, mask)
 
         # Convert `sample` method to a `tf.function` if `self.run_eagerly=False`
@@ -247,6 +251,7 @@ def __call__(
         if end_token_id is not None:
             prompt = self._mask_tokens_after_end_token(
                 prompt,
+                original_padding_mask,
                 max_length,
                 end_token_id,
             )
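To see the new strategy end to end, here is a minimal standalone sketch of the masking logic above, run eagerly on the tensors from the greedy test. It is illustrative only: the `original_padding_mask` values and the true-branch of `tf.where` (elided from the hunk above) are assumptions; the remaining ops mirror `_mask_tokens_after_end_token`.

import tensorflow as tf

# Prompts [[0, 1], [1, 2]] padded to max_length=4, with token 3
# generated at every step, as in the greedy test.
max_length = 4
end_token_id = 2
generated_result = tf.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
# True wherever the token came from the original prompt. In the sampler
# this is a copy (tf.identity) of the padding mask taken before sampling.
original_padding_mask = tf.constant(
    [[True, True, False, False], [True, True, False, False]]
)

# Create a tensor with True for each end_token_id in the result.
end_tokens = generated_result == end_token_id
# Drop end_token_ids that were already in the original input.
end_tokens = end_tokens & (original_padding_mask == tf.constant(False))
# Find the index of the first remaining end_token_id per row.
end_indices = tf.math.argmax(end_tokens, -1)
# Fall back to max_length when no end_token_id was generated
# (the fallback value here is an assumption; the hunk elides that line).
end_indices = tf.where(
    end_indices == 0, tf.cast(max_length, end_indices.dtype), end_indices
)
# Keep everything before (excluding) the first generated end token.
mask_indices = tf.sequence_mask(end_indices, maxlen=max_length)
print(tf.ragged.boolean_mask(generated_result, mask_indices))
# <tf.RaggedTensor [[0, 1, 3, 3], [1, 2, 3, 3]]> -- the prompt's `2` survives.

Rerun the same sketch with end_token_id = 3 and end_tokens becomes True at the first generated position, so both rows truncate back to just the prompt, matching the expected values in the tests.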

keras_nlp/samplers/top_k_sampler_test.py

Lines changed: 15 additions & 4 deletions
@@ -144,17 +144,28 @@ def token_probability_fn(inputs, mask):
             prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
             return tf.tile(prob, [batch_size, seq_length, 1])
 
-        max_length = 5
-        inputs = tf.constant([[0, 1], [1, 2]])
         tf.random.set_seed(42)
         sampler = TopKSampler(k=4, seed=42)
+        max_length = 4
+        inputs = tf.constant([[0, 1], [1, 2]])
         outputs = sampler(
             inputs,
             token_probability_fn,
             max_length=max_length,
             end_token_id=2,
             from_logits=False,
         )
-        # Top-k sampling result with seed 42.
-        expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
+        # end_token in prompt does not trigger truncation.
+        expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
+        self.assertAllEqual(outputs, expected_outputs)
+
+        outputs = sampler(
+            inputs,
+            token_probability_fn,
+            max_length=max_length,
+            end_token_id=3,
+            from_logits=False,
+        )
+        # Generated end_token will be truncated.
+        expected_outputs = tf.ragged.constant([[0, 1], [1, 2]])
         self.assertAllEqual(outputs, expected_outputs)

keras_nlp/samplers/top_p_sampler_test.py

Lines changed: 15 additions & 4 deletions
@@ -144,17 +144,28 @@ def token_probability_fn(inputs, mask):
             prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
             return tf.tile(prob, [batch_size, seq_length, 1])
 
-        max_length = 5
-        inputs = tf.constant([[0, 1], [1, 2]])
         tf.random.set_seed(42)
         sampler = TopPSampler(p=0.1, seed=42)
+        max_length = 4
+        inputs = tf.constant([[0, 1], [1, 2]])
         outputs = sampler(
             inputs,
             token_probability_fn,
             max_length=max_length,
             end_token_id=2,
             from_logits=False,
         )
-        # Top-p sampling result with seed 42.
-        expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
+        # end_token in prompt does not trigger truncation.
+        expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
+        self.assertAllEqual(outputs, expected_outputs)
+
+        outputs = sampler(
+            inputs,
+            token_probability_fn,
+            max_length=max_length,
+            end_token_id=3,
+            from_logits=False,
+        )
+        # Generated end_token will be truncated.
+        expected_outputs = tf.ragged.constant([[0, 1], [1, 2]])
         self.assertAllEqual(outputs, expected_outputs)
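All four test files exercise the same two cases: an end token sitting inside the prompt (kept) versus an end token produced during generation (truncated). A condensed usage sketch of the greedy variant follows; the import path is an assumption based on the test file's location at this commit, while the expected values are taken directly from the diffs above.

import tensorflow as tf
from keras_nlp.samplers.greedy_sampler import GreedySampler  # assumed path

max_length = 4

def token_probability_fn(inputs, mask):
    # Put all probability mass on token 3 at every position.
    batch_size = inputs.shape[0]
    prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
    return tf.repeat(tf.repeat(prob, batch_size, axis=0), max_length, axis=1)

sampler = GreedySampler()
inputs = tf.constant([[0, 1], [1, 2]])

# end_token_id=2 occurs only inside the prompt, so nothing is truncated.
print(sampler(inputs, token_probability_fn, max_length=max_length, end_token_id=2))
# <tf.RaggedTensor [[0, 1, 3, 3], [1, 2, 3, 3]]>

# end_token_id=3 is generated at the first new position, so both rows
# truncate back to the prompt.
print(sampler(inputs, token_probability_fn, max_length=max_length, end_token_id=3))
# <tf.RaggedTensor [[0, 1], [1, 2]]>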
