
Commit f7d51f7

Fix the truncation strategy

Parent: c35325a

File tree

6 files changed: +77 -22 lines


keras_nlp/layers/masked_lm_mask_generator.py

Lines changed: 1 addition & 5 deletions
@@ -147,11 +147,7 @@ def call(self, inputs):
         # convert dense to ragged.
         inputs = tf.RaggedTensor.from_tensor(inputs)

-        (
-            token_ids,
-            mask_positions,
-            mask_ids,
-        ) = tf_text.mask_language_model(
+        (token_ids, mask_positions, mask_ids,) = tf_text.mask_language_model(
             inputs,
             item_selector=self._random_selector,
             mask_values_chooser=self._mask_values_chooser,
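
For context on the call being reformatted above: a minimal sketch of tensorflow_text's mask_language_model op in isolation. The selector and chooser configuration here is illustrative only; the real layer builds these from its own constructor arguments.

import tensorflow as tf
import tensorflow_text as tf_text

# Illustrative setup: select up to 8 items per batch at a 20% rate,
# and never mask token 0 (e.g. padding).
selector = tf_text.RandomItemSelector(
    max_selections_per_batch=8,
    selection_rate=0.2,
    unselectable_ids=[0],
)
# Illustrative vocab of 100 ids, with id 99 used as the mask token.
chooser = tf_text.MaskValuesChooser(vocab_size=100, mask_token=99)

inputs = tf.ragged.constant([[5, 6, 7, 8], [9, 10, 11]])
# Returns the masked sequence, the masked positions, and the original
# ids at those positions -- the three values unpacked in the diff above.
token_ids, mask_positions, mask_ids = tf_text.mask_language_model(
    inputs,
    item_selector=selector,
    mask_values_chooser=chooser,
)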

keras_nlp/samplers/beam_sampler_test.py

Lines changed: 14 additions & 2 deletions
@@ -125,13 +125,25 @@ def token_probability_fn(inputs, mask):
             prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
             return tf.tile(prob, [batch_size, seq_length, 1])

-        max_length = 5
+        max_length = 4
         inputs = tf.constant([[0, 1], [1, 2]])
         outputs = self.sampler(
             inputs,
             token_probability_fn,
             max_length=max_length,
             end_token_id=2,
         )
-        expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
+        # end_token in prompt does not trigger truncation.
+        expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
+        self.assertAllEqual(outputs, expected_outputs)
+
+        max_length = 4
+        inputs = tf.constant([[0, 1], [1, 3]])
+        outputs = self.sampler(
+            inputs,
+            token_probability_fn,
+            max_length=max_length,
+            end_token_id=3,
+        )
+        expected_outputs = tf.ragged.constant([[0, 1], [1, 3]])
         self.assertAllEqual(outputs, expected_outputs)

keras_nlp/samplers/greedy_sampler_test.py

Lines changed: 13 additions & 3 deletions
@@ -93,15 +93,14 @@ def token_probability_fn(inputs, mask):
         )

     def test_end_token_id(self):
-        max_length = 5
-
         def token_probability_fn(inputs, mask):
             batch_size = inputs.shape[0]
             prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
             return tf.repeat(
                 tf.repeat(prob, batch_size, axis=0), max_length, axis=1
             )

+        max_length = 4
         sampler = GreedySampler()
         inputs = tf.constant([[0, 1], [1, 2]])
         outputs = sampler(

@@ -110,7 +109,18 @@ def token_probability_fn(inputs, mask):
             max_length=max_length,
             end_token_id=2,
         )
-        expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
+        # end_token in prompt does not trigger truncation.
+        expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
+        self.assertAllEqual(outputs, expected_outputs)
+
+        outputs = sampler(
+            inputs,
+            token_probability_fn,
+            max_length=max_length,
+            end_token_id=3,
+        )
+        # Generated end_token will be truncated.
+        expected_outputs = tf.ragged.constant([[0, 1], [1, 2]])
         self.assertAllEqual(outputs, expected_outputs)

     def test_compare_xla_noxla_results(self):

keras_nlp/samplers/sampler.py

Lines changed: 20 additions & 4 deletions
@@ -189,14 +189,27 @@ def _pad_prompt(self, prompt, max_length):

     def _mask_tokens_after_end_token(
         self,
-        prompt,
+        generated_result,
+        original_prompt,
         max_length,
         end_token_id,
     ):
         """Helper function to truncate the tokens after the end token."""
+        # Exclude original prompts from being truncated.
+        # Add a big int to tokens in the original prompt to move their ids
+        # over the vocab size. Vocab size should not be as large as 1e9.
+        increment_value = int(1e9)
+        increment = tf.cast(
+            tf.zeros_like(original_prompt) + increment_value,
+            dtype=tf.int32,
+        )
+        if not isinstance(increment, tf.RaggedTensor):
+            increment = tf.RaggedTensor.from_tensor(increment)
+        increment = increment.to_tensor(shape=tf.shape(generated_result))
+        generated_result += increment
         # Mask out tokens after `end_token_id` is encountered.
         # Find index of first end_token_id.
-        end_indices = tf.math.argmax(prompt == end_token_id, -1)
+        end_indices = tf.math.argmax(generated_result == end_token_id, -1)
         # Use max_length if no `end_token_id` is found.
         end_indices = tf.where(
             end_indices == 0,

@@ -205,7 +218,9 @@ def _mask_tokens_after_end_token(
         )
         # Truncate out tokens after (including) the end token.
         mask_indices = tf.sequence_mask(end_indices, maxlen=max_length)
-        return tf.ragged.boolean_mask(prompt, mask_indices)
+        # Revert the increment added earlier.
+        generated_result -= increment
+        return tf.ragged.boolean_mask(generated_result, mask_indices)

     def __call__(
         self,

@@ -217,10 +232,10 @@
         from_logits=True,
     ):
         prompt, mask = self._validate_prompt_and_mask(prompt, mask)
-
         input_is_1d = prompt.shape.rank == 1
         if input_is_1d:
             prompt = tf.RaggedTensor.from_tensor(prompt[tf.newaxis, :])
+        original_prompt = tf.identity(prompt)

         shortest_prompt_len = tf.reduce_min(prompt.row_lengths())
         # Pad prompt to be a dense Tensor of shape [batch_size, max_length].

@@ -247,6 +262,7 @@
         if end_token_id is not None:
             prompt = self._mask_tokens_after_end_token(
                 prompt,
+                original_prompt,
                 max_length,
                 end_token_id,
             )
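
The heart of the fix is the offset trick in _mask_tokens_after_end_token: tokens copied from the original prompt are shifted far above any real vocabulary id, so the argmax that searches for the first end_token_id can only match tokens the sampler generated. Below is a minimal standalone sketch of the same logic, with illustrative names and values (not the library's API).

import tensorflow as tf

max_length = 6
end_token_id = 2
# Row 0 carries end_token_id inside the prompt; row 1 generates it later.
prompt = tf.ragged.constant([[0, 2], [1, 1]])
generated = tf.constant([[0, 2, 3, 3, 3, 3], [1, 1, 3, 2, 3, 3]])

# Shift the prompt region above the vocab so `== end_token_id` can never
# match there. 1e9 plus a token id still fits comfortably in int32.
increment = tf.cast(tf.zeros_like(prompt) + int(1e9), tf.int32)
increment = increment.to_tensor(shape=tf.shape(generated))
shifted = generated + increment

# First match per row; argmax of an all-False row is 0, which is then
# treated as "no end token found" and replaced with max_length.
end_indices = tf.math.argmax(shifted == end_token_id, -1)
end_indices = tf.where(
    end_indices == 0, tf.cast(max_length, end_indices.dtype), end_indices
)
mask = tf.sequence_mask(end_indices, maxlen=max_length)
print(tf.ragged.boolean_mask(generated, mask))
# -> [[0, 2, 3, 3, 3, 3], [1, 1, 3]]

One corner case worth noting: because an argmax of 0 is read as "not found", a genuinely generated end token at position 0 would not truncate under this sketch; in practice position 0 always belongs to the prompt, so the shifted ids make that case unreachable.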

keras_nlp/samplers/top_k_sampler_test.py

Lines changed: 14 additions & 4 deletions
@@ -144,17 +144,27 @@ def token_probability_fn(inputs, mask):
             prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
             return tf.tile(prob, [batch_size, seq_length, 1])

-        max_length = 5
-        inputs = tf.constant([[0, 1], [1, 2]])
         tf.random.set_seed(42)
         sampler = TopKSampler(k=4, seed=42)
+        max_length = 4
+        inputs = tf.constant([[0, 1], [1, 2]])
         outputs = sampler(
             inputs,
             token_probability_fn,
             max_length=max_length,
             end_token_id=2,
             from_logits=False,
         )
-        # Top-k sampling result with seed 42.
-        expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
+        # end_token in prompt does not trigger truncation.
+        expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
+        self.assertAllEqual(outputs, expected_outputs)
+
+        outputs = sampler(
+            inputs,
+            token_probability_fn,
+            max_length=max_length,
+            end_token_id=3,
+        )
+        # Generated end_token will be truncated.
+        expected_outputs = tf.ragged.constant([[0, 1], [1, 2]])
         self.assertAllEqual(outputs, expected_outputs)

keras_nlp/samplers/top_p_sampler_test.py

Lines changed: 15 additions & 4 deletions
@@ -144,17 +144,28 @@ def token_probability_fn(inputs, mask):
             prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
             return tf.tile(prob, [batch_size, seq_length, 1])

-        max_length = 5
-        inputs = tf.constant([[0, 1], [1, 2]])
         tf.random.set_seed(42)
         sampler = TopPSampler(p=0.1, seed=42)
+        max_length = 4
+        inputs = tf.constant([[0, 1], [1, 2]])
         outputs = sampler(
             inputs,
             token_probability_fn,
             max_length=max_length,
             end_token_id=2,
             from_logits=False,
         )
-        # Top-p sampling result with seed 42.
-        expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
+        # end_token in prompt does not trigger truncation.
+        expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
+        self.assertAllEqual(outputs, expected_outputs)
+
+        outputs = sampler(
+            inputs,
+            token_probability_fn,
+            max_length=max_length,
+            end_token_id=3,
+            from_logits=False,
+        )
+        # Generated end_token will be truncated.
+        expected_outputs = tf.ragged.constant([[0, 1], [1, 2]])
         self.assertAllEqual(outputs, expected_outputs)
