|
14 | 14 | """Base sampler class.""" |
15 | 15 |
|
16 | 16 | import tensorflow as tf |
17 | | -from tensorflow import keras |
18 | 17 |
|
19 | 18 |
|
20 | 19 | class Sampler: |
@@ -87,30 +86,13 @@ def _validate_token_probability_fn( |
87 | 86 | "[batch_size, sequence_length, vocab_size]." |
88 | 87 | ) |
89 | 88 |
|
90 | | - def _align_and_pad_prompt(self, prompt, max_length, pad_token_id): |
91 | | - """Align prompt to the right side, and pad to `max_length`.""" |
92 | | - longest_prompt_len = tf.reduce_max(prompt.row_lengths()) |
93 | | - pad_length = longest_prompt_len - prompt.row_lengths() |
94 | | - |
95 | | - prompt = keras.utils.pad_sequences( |
96 | | - prompt.to_list(), maxlen=longest_prompt_len, value=pad_token_id |
97 | | - ) |
98 | | - |
99 | | - mask = tf.RaggedTensor.from_row_lengths( |
100 | | - tf.zeros(shape=[tf.reduce_sum(pad_length)], dtype=tf.int32), |
101 | | - pad_length, |
102 | | - ) |
103 | | - mask = mask.to_tensor(shape=(None, longest_prompt_len), default_value=1) |
104 | | - |
105 | | - shape = prompt.shape |
106 | | - extra_space = tf.math.maximum(0, max_length - shape[1]) |
107 | | - pad_shape = [shape[0], extra_space] |
108 | | - |
109 | | - mask = tf.concat((mask, tf.zeros(pad_shape, tf.int32)), axis=1) |
110 | | - prompt = tf.concat( |
111 | | - (prompt, tf.zeros(pad_shape, prompt.dtype) + pad_token_id), axis=1 |
| 89 | + def _pad_prompt(self, prompt, max_length, pad_token_id): |
| 90 | + """Pad prompt to `max_length`.""" |
| 91 | + mask = tf.ones_like(prompt, dtype=tf.bool) |
| 92 | + mask = mask.to_tensor(shape=(None, max_length)) |
| 93 | + prompt = prompt.to_tensor( |
| 94 | + shape=(None, max_length), default_value=pad_token_id |
112 | 95 | ) |
113 | | - mask = tf.cast(mask, dtype=tf.bool) |
114 | 96 | return prompt, mask |
115 | 97 |
|
116 | 98 | def _mask_tokens_after_end_token( |
@@ -141,21 +123,19 @@ def __call__(self, token_probability_fn, prompt, max_length): |
141 | 123 | prompt = tf.RaggedTensor.from_tensor( |
142 | 124 | prompt, padding=self.pad_token_id |
143 | 125 | ) |
144 | | - longest_prompt_len = tf.reduce_max(prompt.row_lengths()) |
| 126 | + shortest_prompt_len = tf.reduce_min(prompt.row_lengths()) |
145 | 127 | # Pad prompt to be a dense Tensor of shape [batch_size, max_length]. |
146 | 128 | # This step is required for XLA compatibility because XLA requires a |
147 | 129 | # static shape, which means we cannot concatenate generated token to |
148 | 130 | # current prompt. |
149 | | - prompt, mask = self._align_and_pad_prompt( |
150 | | - prompt, max_length, self.pad_token_id |
151 | | - ) |
| 131 | + prompt, mask = self._pad_prompt(prompt, max_length, self.pad_token_id) |
152 | 132 | self._validate_token_probability_fn(token_probability_fn, prompt, mask) |
153 | 133 |
|
154 | 134 | # Convert `sample` method to a `tf.function`, and turn on |
155 | 135 | # `jit_compile` accordingly. |
156 | 136 | sample = tf.function(self.sample, jit_compile=self.jit_compile) |
157 | 137 | prompt = sample( |
158 | | - token_probability_fn, prompt, mask, max_length - longest_prompt_len |
| 138 | + token_probability_fn, prompt, mask, max_length - shortest_prompt_len |
159 | 139 | ) |
160 | 140 |
|
161 | 141 | # Mask out tokens after `end_token_id`. |
@@ -206,7 +186,7 @@ def sample(self, token_probability_fn, prompt, mask, num_steps): |
206 | 186 | Sampler.__doc__ = Sampler.__doc__.replace( |
207 | 187 | "{{base_sampler_keyword_args}}", base_sampler_keyword_args |
208 | 188 | ) |
209 | | -Sampler.__doc__ = Sampler.__call__.__doc__.replace( |
| 189 | +Sampler.__doc__ = Sampler.__doc__.replace( |
210 | 190 | "{{call_keyword_docstring}}", call_keyword_docstring |
211 | 191 | ) |
212 | 192 | Sampler.sample.__doc__ = Sampler.sample.__doc__.replace( |
|
0 commit comments