
Commit 582bb30

Add a settable property for sequence_length
For all preprocessors, we add a settable `sequence_length` property. This removes a long-standing gotcha with preprocessing: if you set the sequence length after the layer was built, the change was not reflected in the actual output length.
1 parent 0de100b commit 582bb30
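
A minimal usage sketch of the fixed behavior (the preset name is illustrative; any built preprocessor works the same way):

```python
import keras_nlp

# Illustrative preset name.
preprocessor = keras_nlp.models.BertPreprocessor.from_preset("bert_base_en_uncased")

x = preprocessor(["the quick brown fox"])  # builds the layer and its packer

# Before this change, assigning a new length after build was silently ignored.
# With the settable property, the packer is updated and the output length follows.
preprocessor.sequence_length = 64
x = preprocessor(["the quick brown fox"])
assert x["token_ids"].shape[-1] == 64
```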

44 files changed: +308 additions, −86 deletions


keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py

Lines changed: 2 additions & 1 deletion

@@ -16,6 +16,7 @@
 
 import pytest
 
+from keras_nlp.backend import ops
 from keras_nlp.models.albert.albert_masked_lm_preprocessor import (
     AlbertMaskedLMPreprocessor,
 )
@@ -43,7 +44,7 @@ def setUp(self):
         self.input_data = ["the quick brown fox"]
 
     def test_preprocessor_basics(self):
-        self.run_preprocessing_layer_test(
+        self.run_preprocessor_test(
             cls=AlbertMaskedLMPreprocessor,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,

keras_nlp/models/albert/albert_preprocessor.py

Lines changed: 12 additions & 1 deletion

@@ -158,9 +158,9 @@ def __init__(
     ):
         super().__init__(**kwargs)
         self.tokenizer = tokenizer
+        self.packer = None
         self.truncate = truncate
         self.sequence_length = sequence_length
-        self.packer = None
 
     def build(self, input_shape):
         # Defer packer creation to `build()` so that we can be sure tokenizer
@@ -195,6 +195,17 @@ def call(self, x, y=None, sample_weight=None):
         }
         return pack_x_y_sample_weight(x, y, sample_weight)
 
+    @property
+    def sequence_length(self):
+        """The padded length of model input sequences."""
+        return self._sequence_length
+
+    @sequence_length.setter
+    def sequence_length(self, value):
+        self._sequence_length = value
+        if self.packer is not None:
+            self.packer.sequence_length = value
+
     @classproperty
     def tokenizer_cls(cls):
         return AlbertTokenizer
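
Note the reordering in `__init__`: `self.packer = None` must now be assigned before `self.sequence_length = sequence_length`, because the latter assignment is routed through the new setter, which checks `self.packer`. A standalone sketch of the failure mode the reordering avoids (the class here is illustrative, not the real layer):

```python
class BadOrder:
    """Illustrative only: the setter runs before `self.packer` exists."""

    def __init__(self):
        self.sequence_length = 8  # invokes the property setter below
        self.packer = None

    @property
    def sequence_length(self):
        return self._sequence_length

    @sequence_length.setter
    def sequence_length(self, value):
        self._sequence_length = value
        if self.packer is not None:  # AttributeError: `packer` not set yet
            self.packer.sequence_length = value


BadOrder()  # raises AttributeError at construction time
```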

keras_nlp/models/albert/albert_preprocessor_test.py

Lines changed: 2 additions & 1 deletion

@@ -16,6 +16,7 @@
 
 import pytest
 
+from keras_nlp.backend import ops
 from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
 from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
 from keras_nlp.tests.test_case import TestCase
@@ -40,7 +41,7 @@ def setUp(self):
         )
 
     def test_preprocessor_basics(self):
-        self.run_preprocessing_layer_test(
+        self.run_preprocessor_test(
             cls=AlbertPreprocessor,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,

keras_nlp/models/bart/bart_preprocessor.py

Lines changed: 50 additions & 5 deletions

@@ -140,10 +140,10 @@ def __init__(
     ):
         super().__init__(**kwargs)
         self.tokenizer = tokenizer
-        self.encoder_sequence_length = encoder_sequence_length
-        self.decoder_sequence_length = decoder_sequence_length
         self.encoder_packer = None
         self.decoder_packer = None
+        self.encoder_sequence_length = encoder_sequence_length
+        self.decoder_sequence_length = decoder_sequence_length
 
     def build(self, input_shape):
         # Defer packer creation to `build()` so that we can be sure tokenizer
@@ -174,7 +174,14 @@ def build(self, input_shape):
         )
         self.built = True
 
-    def call(self, x, y=None, sample_weight=None):
+    def call(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        encoder_sequence_length=None,
+        decoder_sequence_length=None,
+    ):
         if not (
             isinstance(x, dict)
             and all(k in x for k in ("encoder_text", "decoder_text"))
@@ -184,6 +191,11 @@ def call(self, x, y=None, sample_weight=None):
                 f' and `"decoder_text"`. Received x={x}.'
             )
 
+        if encoder_sequence_length is None:
+            encoder_sequence_length = self.encoder_sequence_length
+        if decoder_sequence_length is None:
+            decoder_sequence_length = self.decoder_sequence_length
+
         encoder_text = x["encoder_text"]
         decoder_text = x["decoder_text"]
 
@@ -199,12 +211,14 @@ def call(self, x, y=None, sample_weight=None):
 
         encoder_inputs = self.tokenizer(encoder_text[0])
         encoder_token_ids, encoder_padding_mask = self.encoder_packer(
-            encoder_inputs
+            encoder_inputs,
+            sequence_length=encoder_sequence_length,
         )
 
         decoder_inputs = self.tokenizer(decoder_text[0])
         decoder_token_ids, decoder_padding_mask = self.decoder_packer(
-            decoder_inputs
+            decoder_inputs,
+            sequence_length=decoder_sequence_length,
         )
 
         x = {
@@ -226,6 +240,37 @@ def get_config(self):
         )
         return config
 
+    @property
+    def encoder_sequence_length(self):
+        """The padded length of encoder input sequences."""
+        return self._encoder_sequence_length
+
+    @encoder_sequence_length.setter
+    def encoder_sequence_length(self, value):
+        self._encoder_sequence_length = value
+        if self.encoder_packer is not None:
+            self.encoder_packer.sequence_length = value
+
+    @property
+    def decoder_sequence_length(self):
+        """The padded length of decoder input sequences."""
+        return self._decoder_sequence_length
+
+    @decoder_sequence_length.setter
+    def decoder_sequence_length(self, value):
+        self._decoder_sequence_length = value
+        if self.decoder_packer is not None:
+            self.decoder_packer.sequence_length = value
+
+    @property
+    def sequence_length(self):
+        """Alias for `decoder_sequence_length`."""
+        return self.decoder_sequence_length
+
+    @sequence_length.setter
+    def sequence_length(self, value):
+        self.decoder_sequence_length = value
+
     @classproperty
     def tokenizer_cls(cls):
         return BartTokenizer
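
For the BART preprocessor there are two independent lengths plus a `sequence_length` alias that targets the decoder, and `call()` now takes optional per-call overrides. A hedged usage sketch (tokenizer construction elided; names follow the diff above):

```python
preprocessor = BartPreprocessor(
    tokenizer=tokenizer,  # assumed to be an already-constructed BartTokenizer
    encoder_sequence_length=128,
    decoder_sequence_length=64,
)

# Both lengths are settable after build; the alias writes to the decoder length.
preprocessor.encoder_sequence_length = 512
preprocessor.sequence_length = 256  # same as decoder_sequence_length = 256

x = preprocessor({"encoder_text": ["a news article"], "decoder_text": ["a summary"]})
# x["encoder_token_ids"].shape[-1] == 512
# x["decoder_token_ids"].shape[-1] == 256
```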

keras_nlp/models/bart/bart_preprocessor_test.py

Lines changed: 3 additions & 1 deletion

@@ -15,6 +15,7 @@
 import pytest
 import tensorflow as tf
 
+from keras_nlp.backend import ops
 from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor
 from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
 from keras_nlp.tests.test_case import TestCase
@@ -46,7 +47,7 @@ def setUp(self):
         )
 
     def test_preprocessor_basics(self):
-        self.run_preprocessing_layer_test(
+        self.run_preprocessor_test(
             cls=BartPreprocessor,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
@@ -60,6 +61,7 @@ def test_preprocessor_basics(self):
                 [1],  # Pass through labels.
                 [1.0],  # Pass through sample_weights.
             ),
+            token_id_key="decoder_token_ids",
         )
 
     def test_error_multi_segment_input(self):

keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py

Lines changed: 21 additions & 38 deletions

@@ -124,28 +124,14 @@ class BartSeq2SeqLMPreprocessor(BartPreprocessor):
     ```
     """
 
-    def __init__(
+    def call(
         self,
-        tokenizer,
-        encoder_sequence_length=1024,
-        decoder_sequence_length=1024,
-        **kwargs
+        x,
+        y=None,
+        sample_weight=None,
+        encoder_sequence_length=None,
+        decoder_sequence_length=None,
     ):
-        # Since we truncate the last token from `decoder_token_ids`, we need to
-        # forcefully set the `decoder_sequence_length` to one greater than the
-        # value passed.
-        super().__init__(
-            tokenizer=tokenizer,
-            encoder_sequence_length=encoder_sequence_length,
-            decoder_sequence_length=decoder_sequence_length + 1,
-            **kwargs
-        )
-
-        # Maintain a private copy of the sequence lengths for config purposes.
-        self._encoder_sequence_length = encoder_sequence_length
-        self._decoder_sequence_length = decoder_sequence_length
-
-    def call(self, x, y=None, sample_weight=None):
         if y is not None or sample_weight is not None:
             logging.warning(
                 "`BartSeq2SeqLMPreprocessor` infers `y` and `sample_weight` "
@@ -154,7 +140,16 @@ def call(self, x, y=None, sample_weight=None):
                 "These values will be ignored."
             )
 
-        x = super().call(x)
+        if encoder_sequence_length is None:
+            encoder_sequence_length = self.encoder_sequence_length
+        if decoder_sequence_length is None:
+            decoder_sequence_length = self.decoder_sequence_length
+
+        x = super().call(
+            x,
+            encoder_sequence_length=encoder_sequence_length,
+            decoder_sequence_length=decoder_sequence_length + 1,
+        )
         decoder_token_ids = x.pop("decoder_token_ids")
         decoder_padding_mask = x.pop("decoder_padding_mask")
 
@@ -173,7 +168,8 @@ def call(self, x, y=None, sample_weight=None):
     def generate_preprocess(
         self,
         x,
-        sequence_length=None,
+        encoder_sequence_length=None,
+        decoder_sequence_length=None,
     ):
         """Convert encoder and decoder input strings to integer token inputs for generation.
 
@@ -190,10 +186,6 @@ def generate_preprocess(
         if not self.built:
             self.build(None)
 
-        # If `sequence_length` is not provided, we use the default value.
-        if sequence_length is None:
-            sequence_length = self._decoder_sequence_length
-
         if isinstance(x, dict):
             encoder_text = x["encoder_text"]
             decoder_text = x["decoder_text"]
@@ -209,7 +201,8 @@ def generate_preprocess(
         ]
         encoder_token_ids = self.tokenizer(encoder_text)
         encoder_token_ids, encoder_padding_mask = self.encoder_packer(
-            encoder_token_ids
+            encoder_token_ids,
+            sequence_length=encoder_sequence_length,
         )
 
         # Tokenize and pack the decoder inputs.
@@ -219,7 +212,7 @@ def generate_preprocess(
         decoder_token_ids = self.tokenizer(decoder_text)
         decoder_token_ids, decoder_padding_mask = self.decoder_packer(
             decoder_token_ids,
-            sequence_length=sequence_length,
+            sequence_length=decoder_sequence_length,
             add_end_value=False,
         )
 
@@ -261,16 +254,6 @@ def generate_postprocess(
         )
         return self.tokenizer.detokenize(decoder_token_ids)
 
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "encoder_sequence_length": self._encoder_sequence_length,
-                "decoder_sequence_length": self._decoder_sequence_length,
-            }
-        )
-        return config
-
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
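
The `decoder_sequence_length + 1` passed to `super().call()` replaces the old constructor trick: the decoder sequence is packed one token longer than requested so it can be split into inputs and next-token labels of the requested length. A toy illustration of that shift (made-up token ids, not the library's actual slicing code):

```python
decoder_sequence_length = 4

# Packed to decoder_sequence_length + 1 = 5.
decoder_token_ids = [0, 4, 5, 4, 2]

features = decoder_token_ids[:-1]  # [0, 4, 5, 4] -> fed back in as "decoder_token_ids"
labels = decoder_token_ids[1:]     # [4, 5, 4, 2] -> next-token targets (y)
```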

keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py

Lines changed: 2 additions & 1 deletion

@@ -45,7 +45,7 @@ def setUp(self):
         )
 
     def test_preprocessor_basics(self):
-        self.run_preprocessing_layer_test(
+        self.run_preprocessor_test(
             cls=BartSeq2SeqLMPreprocessor,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
@@ -59,6 +59,7 @@ def test_preprocessor_basics(self):
                 [[0, 4, 5, 4, 7, 2, 1, 1]],
                 [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]],
             ),
+            token_id_key="decoder_token_ids",
         )
 
     def test_generate_preprocess(self):

keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py

Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ def setUp(self):
         self.input_data = ["the quick brown fox"]
 
     def test_preprocessor_basics(self):
-        self.run_preprocessing_layer_test(
+        self.run_preprocessor_test(
             cls=BertMaskedLMPreprocessor,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,

keras_nlp/models/bert/bert_preprocessor.py

Lines changed: 12 additions & 1 deletion

@@ -139,9 +139,9 @@ def __init__(
     ):
         super().__init__(**kwargs)
         self.tokenizer = tokenizer
+        self.packer = None
         self.sequence_length = sequence_length
         self.truncate = truncate
-        self.packer = None
 
     def build(self, input_shape):
         # Defer packer creation to `build()` so that we can be sure tokenizer
@@ -176,6 +176,17 @@ def get_config(self):
         )
         return config
 
+    @property
+    def sequence_length(self):
+        """The padded length of model input sequences."""
+        return self._sequence_length
+
+    @sequence_length.setter
+    def sequence_length(self, value):
+        self._sequence_length = value
+        if self.packer is not None:
+            self.packer.sequence_length = value
+
     @classproperty
     def tokenizer_cls(cls):
         return BertTokenizer

keras_nlp/models/bert/bert_preprocessor_test.py

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ def setUp(self):
         )
 
     def test_preprocessor_basics(self):
-        self.run_preprocessing_layer_test(
+        self.run_preprocessor_test(
             cls=BertPreprocessor,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
