Skip to content

Commit e6483a4

Browse files
Change padding from left to right
1 parent c53b4a9 commit e6483a4

File tree

1 file changed

+10
-30
lines changed

1 file changed

+10
-30
lines changed

keras_nlp/samplers/sampler.py

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
"""Base sampler class."""
1515

1616
import tensorflow as tf
17-
from tensorflow import keras
1817

1918

2019
class Sampler:
@@ -87,30 +86,13 @@ def _validate_token_probability_fn(
8786
"[batch_size, sequence_length, vocab_size]."
8887
)
8988

90-
def _align_and_pad_prompt(self, prompt, max_length, pad_token_id):
91-
"""Align prompt to the right side, and pad to `max_length`."""
92-
longest_prompt_len = tf.reduce_max(prompt.row_lengths())
93-
pad_length = longest_prompt_len - prompt.row_lengths()
94-
95-
prompt = keras.utils.pad_sequences(
96-
prompt.to_list(), maxlen=longest_prompt_len, value=pad_token_id
97-
)
98-
99-
mask = tf.RaggedTensor.from_row_lengths(
100-
tf.zeros(shape=[tf.reduce_sum(pad_length)], dtype=tf.int32),
101-
pad_length,
102-
)
103-
mask = mask.to_tensor(shape=(None, longest_prompt_len), default_value=1)
104-
105-
shape = prompt.shape
106-
extra_space = tf.math.maximum(0, max_length - shape[1])
107-
pad_shape = [shape[0], extra_space]
108-
109-
mask = tf.concat((mask, tf.zeros(pad_shape, tf.int32)), axis=1)
110-
prompt = tf.concat(
111-
(prompt, tf.zeros(pad_shape, prompt.dtype) + pad_token_id), axis=1
89+
def _pad_prompt(self, prompt, max_length, pad_token_id):
90+
"""Pad prompt to `max_length`."""
91+
mask = tf.ones_like(prompt, dtype=tf.bool)
92+
mask = mask.to_tensor(shape=(None, max_length))
93+
prompt = prompt.to_tensor(
94+
shape=(None, max_length), default_value=pad_token_id
11295
)
113-
mask = tf.cast(mask, dtype=tf.bool)
11496
return prompt, mask
11597

11698
def _mask_tokens_after_end_token(
@@ -141,21 +123,19 @@ def __call__(self, token_probability_fn, prompt, max_length):
141123
prompt = tf.RaggedTensor.from_tensor(
142124
prompt, padding=self.pad_token_id
143125
)
144-
longest_prompt_len = tf.reduce_max(prompt.row_lengths())
126+
shortest_prompt_len = tf.reduce_min(prompt.row_lengths())
145127
# Pad prompt to be a dense Tensor of shape [batch_size, max_length].
146128
# This step is required for XLA compatibility because XLA requires a
147129
# static shape, which means we cannot concatenate generated token to
148130
# current prompt.
149-
prompt, mask = self._align_and_pad_prompt(
150-
prompt, max_length, self.pad_token_id
151-
)
131+
prompt, mask = self._pad_prompt(prompt, max_length, self.pad_token_id)
152132
self._validate_token_probability_fn(token_probability_fn, prompt, mask)
153133

154134
# Convert `sample` method to a `tf.function`, and turn on
155135
# `jit_compile` accordingly.
156136
sample = tf.function(self.sample, jit_compile=self.jit_compile)
157137
prompt = sample(
158-
token_probability_fn, prompt, mask, max_length - longest_prompt_len
138+
token_probability_fn, prompt, mask, max_length - shortest_prompt_len
159139
)
160140

161141
# Mask out tokens after `end_token_id`.
@@ -206,7 +186,7 @@ def sample(self, token_probability_fn, prompt, mask, num_steps):
206186
Sampler.__doc__ = Sampler.__doc__.replace(
207187
"{{base_sampler_keyword_args}}", base_sampler_keyword_args
208188
)
209-
Sampler.__doc__ = Sampler.__call__.__doc__.replace(
189+
Sampler.__doc__ = Sampler.__doc__.replace(
210190
"{{call_keyword_docstring}}", call_keyword_docstring
211191
)
212192
Sampler.sample.__doc__ = Sampler.sample.__doc__.replace(

0 commit comments

Comments (0)