Commit 31e9d39

Add a settable property for sequence_length
For all preprocessors, we add a settable `sequence_length` property. This removes a very annoying gotcha with preprocessing, where setting the sequence length after build was not reflected in the actual output length.
1 parent 0de100b commit 31e9d39
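
As a rough sketch of the behavior this change enables (the preset name, lengths, and inputs below are illustrative, not taken from the diff):

```python
import keras_nlp

# Illustrative preset and lengths.
preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
    "bert_base_en_uncased",
    sequence_length=128,
)
x = preprocessor(["the quick brown fox"])  # builds the layer; token_ids padded to 128

# Previously this assignment was silently ignored once the internal packer was
# built. With the settable property it is forwarded to `packer.sequence_length`,
# so the next call pads/truncates to the new length.
preprocessor.sequence_length = 64
x = preprocessor(["the quick brown fox"])  # token_ids padded to 64
```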

43 files changed: +320 -84 lines

keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def setUp(self):
         self.input_data = ["the quick brown fox"]
 
     def test_preprocessor_basics(self):
-        self.run_preprocessing_layer_test(
+        self.run_preprocessor_test(
             cls=AlbertMaskedLMPreprocessor,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,

keras_nlp/models/albert/albert_preprocessor.py

Lines changed: 12 additions & 1 deletion
@@ -158,9 +158,9 @@ def __init__(
     ):
         super().__init__(**kwargs)
         self.tokenizer = tokenizer
+        self.packer = None
         self.truncate = truncate
         self.sequence_length = sequence_length
-        self.packer = None
 
     def build(self, input_shape):
         # Defer packer creation to `build()` so that we can be sure tokenizer
@@ -195,6 +195,17 @@ def call(self, x, y=None, sample_weight=None):
         }
         return pack_x_y_sample_weight(x, y, sample_weight)
 
+    @property
+    def sequence_length(self):
+        """The padded length of model input sequences."""
+        return self._sequence_length
+
+    @sequence_length.setter
+    def sequence_length(self, value):
+        self._sequence_length = value
+        if self.packer is not None:
+            self.packer.sequence_length = value
+
     @classproperty
     def tokenizer_cls(cls):
         return AlbertTokenizer

keras_nlp/models/albert/albert_preprocessor_test.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def setUp(self):
         )
 
     def test_preprocessor_basics(self):
-        self.run_preprocessing_layer_test(
+        self.run_preprocessor_test(
             cls=AlbertPreprocessor,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,

keras_nlp/models/bart/bart_preprocessor.py

Lines changed: 54 additions & 5 deletions
@@ -140,10 +140,10 @@ def __init__(
     ):
         super().__init__(**kwargs)
         self.tokenizer = tokenizer
-        self.encoder_sequence_length = encoder_sequence_length
-        self.decoder_sequence_length = decoder_sequence_length
         self.encoder_packer = None
         self.decoder_packer = None
+        self.encoder_sequence_length = encoder_sequence_length
+        self.decoder_sequence_length = decoder_sequence_length
 
     def build(self, input_shape):
         # Defer packer creation to `build()` so that we can be sure tokenizer
@@ -174,7 +174,17 @@ def build(self, input_shape):
         )
         self.built = True
 
-    def call(self, x, y=None, sample_weight=None):
+    def call(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        *,
+        encoder_sequence_length=None,
+        decoder_sequence_length=None,
+        # `sequence_length` is an alias for `decoder_sequence_length`
+        sequence_length=None,
+    ):
         if not (
             isinstance(x, dict)
             and all(k in x for k in ("encoder_text", "decoder_text"))
@@ -184,6 +194,12 @@ def call(self, x, y=None, sample_weight=None):
                 f' and `"decoder_text"`. Received x={x}.'
             )
 
+        if encoder_sequence_length is None:
+            encoder_sequence_length = self.encoder_sequence_length
+        decoder_sequence_length = decoder_sequence_length or sequence_length
+        if decoder_sequence_length is None:
+            decoder_sequence_length = self.decoder_sequence_length
+
         encoder_text = x["encoder_text"]
         decoder_text = x["decoder_text"]
 
@@ -199,12 +215,14 @@
 
         encoder_inputs = self.tokenizer(encoder_text[0])
         encoder_token_ids, encoder_padding_mask = self.encoder_packer(
-            encoder_inputs
+            encoder_inputs,
+            sequence_length=encoder_sequence_length,
         )
 
         decoder_inputs = self.tokenizer(decoder_text[0])
         decoder_token_ids, decoder_padding_mask = self.decoder_packer(
-            decoder_inputs
+            decoder_inputs,
+            sequence_length=decoder_sequence_length,
         )
 
         x = {
@@ -226,6 +244,37 @@ def get_config(self):
         )
         return config
 
+    @property
+    def encoder_sequence_length(self):
+        """The padded length of encoder input sequences."""
+        return self._encoder_sequence_length
+
+    @encoder_sequence_length.setter
+    def encoder_sequence_length(self, value):
+        self._encoder_sequence_length = value
+        if self.encoder_packer is not None:
+            self.encoder_packer.sequence_length = value
+
+    @property
+    def decoder_sequence_length(self):
+        """The padded length of decoder input sequences."""
+        return self._decoder_sequence_length
+
+    @decoder_sequence_length.setter
+    def decoder_sequence_length(self, value):
+        self._decoder_sequence_length = value
+        if self.decoder_packer is not None:
+            self.decoder_packer.sequence_length = value
+
+    @property
+    def sequence_length(self):
+        """Alias for `decoder_sequence_length`."""
+        return self.decoder_sequence_length
+
+    @sequence_length.setter
+    def sequence_length(self, value):
+        self.decoder_sequence_length = value
+
     @classproperty
     def tokenizer_cls(cls):
         return BartTokenizer
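
A minimal usage sketch of the reworked `BartPreprocessor` follows; the instance, lengths, and strings are illustrative, and it assumes the keyword-only arguments in the new `call` signature are forwarded when the layer is called:

```python
# Settable properties forward to the underlying packers, even after build.
preprocessor.encoder_sequence_length = 512
preprocessor.decoder_sequence_length = 128

features = {
    "encoder_text": ["The fox was sleeping."],
    "decoder_text": ["The fox was awake."],
}

# Per-call overrides; `sequence_length` is an alias for `decoder_sequence_length`.
x = preprocessor(
    features,
    encoder_sequence_length=256,
    sequence_length=64,
)
```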

keras_nlp/models/bart/bart_preprocessor_test.py

Lines changed: 2 additions & 1 deletion
@@ -46,7 +46,7 @@ def setUp(self):
         )
 
     def test_preprocessor_basics(self):
-        self.run_preprocessing_layer_test(
+        self.run_preprocessor_test(
             cls=BartPreprocessor,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
@@ -60,6 +60,7 @@ def test_preprocessor_basics(self):
                 [1],  # Pass through labels.
                 [1.0],  # Pass through sample_weights.
             ),
+            token_id_key="decoder_token_ids",
         )
 
     def test_error_multi_segment_input(self):

keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py

Lines changed: 33 additions & 37 deletions
@@ -124,28 +124,17 @@ class BartSeq2SeqLMPreprocessor(BartPreprocessor):
     ```
     """
 
-    def __init__(
+    def call(
         self,
-        tokenizer,
-        encoder_sequence_length=1024,
-        decoder_sequence_length=1024,
-        **kwargs
+        x,
+        y=None,
+        sample_weight=None,
+        *,
+        encoder_sequence_length=None,
+        decoder_sequence_length=None,
+        # `sequence_length` is an alias for `decoder_sequence_length`
+        sequence_length=None,
     ):
-        # Since we truncate the last token from `decoder_token_ids`, we need to
-        # forcefully set the `decoder_sequence_length` to one greater than the
-        # value passed.
-        super().__init__(
-            tokenizer=tokenizer,
-            encoder_sequence_length=encoder_sequence_length,
-            decoder_sequence_length=decoder_sequence_length + 1,
-            **kwargs
-        )
-
-        # Maintain a private copy of the sequence lengths for config purposes.
-        self._encoder_sequence_length = encoder_sequence_length
-        self._decoder_sequence_length = decoder_sequence_length
-
-    def call(self, x, y=None, sample_weight=None):
         if y is not None or sample_weight is not None:
             logging.warning(
                 "`BartSeq2SeqLMPreprocessor` infers `y` and `sample_weight` "
@@ -154,7 +143,17 @@ def call(self, x, y=None, sample_weight=None):
                 "These values will be ignored."
             )
 
-        x = super().call(x)
+        if encoder_sequence_length is None:
+            encoder_sequence_length = self.encoder_sequence_length
+        decoder_sequence_length = decoder_sequence_length or sequence_length
+        if decoder_sequence_length is None:
+            decoder_sequence_length = self.decoder_sequence_length
+
+        x = super().call(
+            x,
+            encoder_sequence_length=encoder_sequence_length,
+            decoder_sequence_length=decoder_sequence_length + 1,
+        )
         decoder_token_ids = x.pop("decoder_token_ids")
         decoder_padding_mask = x.pop("decoder_padding_mask")
 
@@ -173,6 +172,10 @@ def call(self, x, y=None, sample_weight=None):
     def generate_preprocess(
         self,
         x,
+        *,
+        encoder_sequence_length=None,
+        # `sequence_length` is an alias for `decoder_sequence_length`
+        decoder_sequence_length=None,
         sequence_length=None,
     ):
         """Convert encoder and decoder input strings to integer token inputs for generation.
@@ -190,10 +193,6 @@
         if not self.built:
             self.build(None)
 
-        # If `sequence_length` is not provided, we use the default value.
-        if sequence_length is None:
-            sequence_length = self._decoder_sequence_length
-
         if isinstance(x, dict):
             encoder_text = x["encoder_text"]
             decoder_text = x["decoder_text"]
@@ -202,14 +201,21 @@
             # Initialize empty prompt for the decoder.
             decoder_text = tf.fill((tf.shape(encoder_text)[0],), "")
 
+        if encoder_sequence_length is None:
+            encoder_sequence_length = self.encoder_sequence_length
+        decoder_sequence_length = decoder_sequence_length or sequence_length
+        if decoder_sequence_length is None:
+            decoder_sequence_length = self.decoder_sequence_length
+
         # Tokenize and pack the encoder inputs.
         # TODO: Remove `[0]` once we have shifted to `MultiSegmentPacker`.
         encoder_text = convert_inputs_to_list_of_tensor_segments(encoder_text)[
            0
         ]
         encoder_token_ids = self.tokenizer(encoder_text)
         encoder_token_ids, encoder_padding_mask = self.encoder_packer(
-            encoder_token_ids
+            encoder_token_ids,
+            sequence_length=encoder_sequence_length,
         )
 
         # Tokenize and pack the decoder inputs.
@@ -219,7 +225,7 @@
         decoder_token_ids = self.tokenizer(decoder_text)
         decoder_token_ids, decoder_padding_mask = self.decoder_packer(
             decoder_token_ids,
-            sequence_length=sequence_length,
+            sequence_length=decoder_sequence_length,
             add_end_value=False,
         )
 
@@ -261,16 +267,6 @@ def generate_postprocess(
         )
         return self.tokenizer.detokenize(decoder_token_ids)
 
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "encoder_sequence_length": self._encoder_sequence_length,
-                "decoder_sequence_length": self._decoder_sequence_length,
-            }
-        )
-        return config
-
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
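
With the lengths now living on the settable base-class properties, the subclass no longer needs an `__init__` override, private length copies, or a custom `get_config`; the one-token shift is applied per call via `decoder_sequence_length + 1`. A rough sketch of the per-call override in generation (the instance and inputs are illustrative):

```python
# `sequence_length` remains an alias for `decoder_sequence_length`.
x = preprocessor.generate_preprocess(
    {
        "encoder_text": ["The fox was sleeping."],
        "decoder_text": ["The fox"],
    },
    encoder_sequence_length=256,
    sequence_length=64,
)
# x holds padded encoder/decoder token ids and the matching padding masks.
```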

keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py

Lines changed: 2 additions & 1 deletion
@@ -45,7 +45,7 @@ def setUp(self):
         )
 
     def test_preprocessor_basics(self):
-        self.run_preprocessing_layer_test(
+        self.run_preprocessor_test(
             cls=BartSeq2SeqLMPreprocessor,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
@@ -59,6 +59,7 @@ def test_preprocessor_basics(self):
                 [[0, 4, 5, 4, 7, 2, 1, 1]],
                 [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]],
             ),
+            token_id_key="decoder_token_ids",
         )
 
     def test_generate_preprocess(self):

keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ def setUp(self):
         self.input_data = ["the quick brown fox"]
 
     def test_preprocessor_basics(self):
-        self.run_preprocessing_layer_test(
+        self.run_preprocessor_test(
             cls=BertMaskedLMPreprocessor,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,

keras_nlp/models/bert/bert_preprocessor.py

Lines changed: 12 additions & 1 deletion
@@ -139,9 +139,9 @@ def __init__(
     ):
         super().__init__(**kwargs)
         self.tokenizer = tokenizer
+        self.packer = None
         self.sequence_length = sequence_length
         self.truncate = truncate
-        self.packer = None
 
     def build(self, input_shape):
         # Defer packer creation to `build()` so that we can be sure tokenizer
@@ -176,6 +176,17 @@ def get_config(self):
         )
         return config
 
+    @property
+    def sequence_length(self):
+        """The padded length of model input sequences."""
+        return self._sequence_length
+
+    @sequence_length.setter
+    def sequence_length(self, value):
+        self._sequence_length = value
+        if self.packer is not None:
+            self.packer.sequence_length = value
+
     @classproperty
     def tokenizer_cls(cls):
         return BertTokenizer

keras_nlp/models/bert/bert_preprocessor_test.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def setUp(self):
         )
 
     def test_preprocessor_basics(self):
-        self.run_preprocessing_layer_test(
+        self.run_preprocessor_test(
             cls=BertPreprocessor,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
