@@ -43,7 +43,7 @@ def setUp(self):
self.input_data = ["the quick brown fox"]

def test_preprocessor_basics(self):
self.run_preprocessing_layer_test(
self.run_preprocessor_test(
cls=AlbertMaskedLMPreprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
13 changes: 12 additions & 1 deletion keras_nlp/models/albert/albert_preprocessor.py
@@ -158,9 +158,9 @@ def __init__(
):
super().__init__(**kwargs)
self.tokenizer = tokenizer
self.packer = None
self.truncate = truncate
self.sequence_length = sequence_length
self.packer = None

def build(self, input_shape):
# Defer packer creation to `build()` so that we can be sure tokenizer
@@ -195,6 +195,17 @@ def call(self, x, y=None, sample_weight=None):
}
return pack_x_y_sample_weight(x, y, sample_weight)

@property
def sequence_length(self):
"""The padded length of model input sequences."""
return self._sequence_length

@sequence_length.setter
def sequence_length(self, value):
self._sequence_length = value
if self.packer is not None:
self.packer.sequence_length = value

@classproperty
def tokenizer_cls(cls):
return AlbertTokenizer
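The hunk above turns `sequence_length` into a property whose setter forwards new values to the packer once it exists; it also explains why `self.packer = None` now has to be assigned before `self.sequence_length = sequence_length` in `__init__`: the setter reads `self.packer`, so that attribute must already be defined. A minimal, self-contained sketch of the pattern (the `DemoPacker`/`DemoPreprocessor` names are illustrative stand-ins, not keras-nlp classes):

```python
class DemoPacker:
    """Stand-in for the real packer layer that is created in `build()`."""

    def __init__(self, sequence_length):
        self.sequence_length = sequence_length


class DemoPreprocessor:
    def __init__(self, sequence_length=512):
        self.packer = None  # must exist before the property setter runs
        self.sequence_length = sequence_length

    def build(self):
        # Packer creation is deferred, so it picks up the current length.
        self.packer = DemoPacker(self.sequence_length)

    @property
    def sequence_length(self):
        """The padded length of model input sequences."""
        return self._sequence_length

    @sequence_length.setter
    def sequence_length(self, value):
        self._sequence_length = value
        # Keep an already-built packer in sync with the new value.
        if self.packer is not None:
            self.packer.sequence_length = value


preprocessor = DemoPreprocessor(sequence_length=128)
preprocessor.build()
preprocessor.sequence_length = 64
assert preprocessor.packer.sequence_length == 64  # setter propagated the change
```

The same property/setter pair is added to `BertPreprocessor` and `BloomPreprocessor` further down in this diff.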
2 changes: 1 addition & 1 deletion keras_nlp/models/albert/albert_preprocessor_test.py
@@ -40,7 +40,7 @@ def setUp(self):
)

def test_preprocessor_basics(self):
self.run_preprocessing_layer_test(
self.run_preprocessor_test(
cls=AlbertPreprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
59 changes: 54 additions & 5 deletions keras_nlp/models/bart/bart_preprocessor.py
@@ -140,10 +140,10 @@ def __init__(
):
super().__init__(**kwargs)
self.tokenizer = tokenizer
self.encoder_sequence_length = encoder_sequence_length
self.decoder_sequence_length = decoder_sequence_length
self.encoder_packer = None
self.decoder_packer = None
self.encoder_sequence_length = encoder_sequence_length
self.decoder_sequence_length = decoder_sequence_length

def build(self, input_shape):
# Defer packer creation to `build()` so that we can be sure tokenizer
@@ -174,7 +174,17 @@ def build(self, input_shape):
)
self.built = True

def call(self, x, y=None, sample_weight=None):
def call(
self,
x,
y=None,
sample_weight=None,
*,
encoder_sequence_length=None,
decoder_sequence_length=None,
# `sequence_length` is an alias for `decoder_sequence_length`
sequence_length=None,
):
if not (
isinstance(x, dict)
and all(k in x for k in ("encoder_text", "decoder_text"))
@@ -184,6 +194,12 @@ def call(self, x, y=None, sample_weight=None):
f' and `"decoder_text"`. Received x={x}.'
)

if encoder_sequence_length is None:
encoder_sequence_length = self.encoder_sequence_length
decoder_sequence_length = decoder_sequence_length or sequence_length
if decoder_sequence_length is None:
decoder_sequence_length = self.decoder_sequence_length

encoder_text = x["encoder_text"]
decoder_text = x["decoder_text"]

@@ -199,12 +215,14 @@

encoder_inputs = self.tokenizer(encoder_text[0])
encoder_token_ids, encoder_padding_mask = self.encoder_packer(
encoder_inputs
encoder_inputs,
sequence_length=encoder_sequence_length,
)

decoder_inputs = self.tokenizer(decoder_text[0])
decoder_token_ids, decoder_padding_mask = self.decoder_packer(
decoder_inputs
decoder_inputs,
sequence_length=decoder_sequence_length,
)

x = {
@@ -226,6 +244,37 @@ def get_config(self):
)
return config

@property
def encoder_sequence_length(self):
"""The padded length of encoder input sequences."""
return self._encoder_sequence_length

@encoder_sequence_length.setter
def encoder_sequence_length(self, value):
self._encoder_sequence_length = value
if self.encoder_packer is not None:
self.encoder_packer.sequence_length = value

@property
def decoder_sequence_length(self):
"""The padded length of decoder input sequences."""
return self._decoder_sequence_length

@decoder_sequence_length.setter
def decoder_sequence_length(self, value):
self._decoder_sequence_length = value
if self.decoder_packer is not None:
self.decoder_packer.sequence_length = value

@property
def sequence_length(self):
"""Alias for `decoder_sequence_length`."""
return self.decoder_sequence_length

@sequence_length.setter
def sequence_length(self, value):
self.decoder_sequence_length = value

@classproperty
def tokenizer_cls(cls):
return BartTokenizer
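`BartPreprocessor.call()` now accepts per-call `encoder_sequence_length` and `decoder_sequence_length` overrides (with `sequence_length` kept as an alias for the decoder length), falling back to the layer's configured values when they are omitted; the new properties above follow the same setter pattern as the single-length preprocessors, with `sequence_length` simply forwarding to `decoder_sequence_length`. A small sketch of the fallback resolution, using illustrative defaults of 1024:

```python
def resolve_lengths(
    layer_encoder_length,
    layer_decoder_length,
    encoder_sequence_length=None,
    decoder_sequence_length=None,
    sequence_length=None,  # alias for `decoder_sequence_length`
):
    """Mirror of the fallback logic added to `BartPreprocessor.call()`."""
    if encoder_sequence_length is None:
        encoder_sequence_length = layer_encoder_length
    decoder_sequence_length = decoder_sequence_length or sequence_length
    if decoder_sequence_length is None:
        decoder_sequence_length = layer_decoder_length
    return encoder_sequence_length, decoder_sequence_length


# No overrides: fall back to the layer's configured lengths.
assert resolve_lengths(1024, 1024) == (1024, 1024)
# `sequence_length` only affects the decoder side.
assert resolve_lengths(1024, 1024, sequence_length=256) == (1024, 256)
# An explicit `decoder_sequence_length` wins over the alias.
assert resolve_lengths(
    1024, 1024, decoder_sequence_length=128, sequence_length=256
) == (1024, 128)
```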
3 changes: 2 additions & 1 deletion keras_nlp/models/bart/bart_preprocessor_test.py
@@ -46,7 +46,7 @@ def setUp(self):
)

def test_preprocessor_basics(self):
self.run_preprocessing_layer_test(
self.run_preprocessor_test(
cls=BartPreprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
@@ -60,6 +60,7 @@ def test_preprocessor_basics(self):
[1], # Pass through labels.
[1.0], # Pass through sample_weights.
),
token_id_key="decoder_token_ids",
)

def test_error_multi_segment_input(self):
70 changes: 33 additions & 37 deletions keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py
@@ -124,28 +124,17 @@ class BartSeq2SeqLMPreprocessor(BartPreprocessor):
```
"""

def __init__(
def call(
self,
tokenizer,
encoder_sequence_length=1024,
decoder_sequence_length=1024,
**kwargs
x,
y=None,
sample_weight=None,
*,
encoder_sequence_length=None,
decoder_sequence_length=None,
# `sequence_length` is an alias for `decoder_sequence_length`
sequence_length=None,
):
# Since we truncate the last token from `decoder_token_ids`, we need to
# forcefully set the `decoder_sequence_length` to one greater than the
# value passed.
super().__init__(
tokenizer=tokenizer,
encoder_sequence_length=encoder_sequence_length,
decoder_sequence_length=decoder_sequence_length + 1,
**kwargs
)

# Maintain a private copy of the sequence lengths for config purposes.
self._encoder_sequence_length = encoder_sequence_length
self._decoder_sequence_length = decoder_sequence_length

def call(self, x, y=None, sample_weight=None):
if y is not None or sample_weight is not None:
logging.warning(
"`BartSeq2SeqLMPreprocessor` infers `y` and `sample_weight` "
@@ -154,7 +143,17 @@ def call(self, x, y=None, sample_weight=None):
"These values will be ignored."
)

x = super().call(x)
if encoder_sequence_length is None:
encoder_sequence_length = self.encoder_sequence_length
decoder_sequence_length = decoder_sequence_length or sequence_length
if decoder_sequence_length is None:
decoder_sequence_length = self.decoder_sequence_length

x = super().call(
x,
encoder_sequence_length=encoder_sequence_length,
decoder_sequence_length=decoder_sequence_length + 1,
)
decoder_token_ids = x.pop("decoder_token_ids")
decoder_padding_mask = x.pop("decoder_padding_mask")

@@ -173,6 +172,10 @@ def generate_preprocess(
def generate_preprocess(
self,
x,
*,
encoder_sequence_length=None,
# `sequence_length` is an alias for `decoder_sequence_length`
decoder_sequence_length=None,
sequence_length=None,
):
"""Convert encoder and decoder input strings to integer token inputs for generation.
@@ -190,10 +193,6 @@
if not self.built:
self.build(None)

# If `sequence_length` is not provided, we use the default value.
if sequence_length is None:
sequence_length = self._decoder_sequence_length

if isinstance(x, dict):
encoder_text = x["encoder_text"]
decoder_text = x["decoder_text"]
@@ -202,14 +201,21 @@
# Initialize empty prompt for the decoder.
decoder_text = tf.fill((tf.shape(encoder_text)[0],), "")

if encoder_sequence_length is None:
encoder_sequence_length = self.encoder_sequence_length
decoder_sequence_length = decoder_sequence_length or sequence_length
if decoder_sequence_length is None:
decoder_sequence_length = self.decoder_sequence_length

# Tokenize and pack the encoder inputs.
# TODO: Remove `[0]` once we have shifted to `MultiSegmentPacker`.
encoder_text = convert_inputs_to_list_of_tensor_segments(encoder_text)[
0
]
encoder_token_ids = self.tokenizer(encoder_text)
encoder_token_ids, encoder_padding_mask = self.encoder_packer(
encoder_token_ids
encoder_token_ids,
sequence_length=encoder_sequence_length,
)

# Tokenize and pack the decoder inputs.
@@ -219,7 +225,7 @@
decoder_token_ids = self.tokenizer(decoder_text)
decoder_token_ids, decoder_padding_mask = self.decoder_packer(
decoder_token_ids,
sequence_length=sequence_length,
sequence_length=decoder_sequence_length,
add_end_value=False,
)

@@ -261,16 +267,6 @@ def generate_postprocess(
)
return self.tokenizer.detokenize(decoder_token_ids)

def get_config(self):
config = super().get_config()
config.update(
{
"encoder_sequence_length": self._encoder_sequence_length,
"decoder_sequence_length": self._decoder_sequence_length,
}
)
return config

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
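With the length bookkeeping handled per call, `BartSeq2SeqLMPreprocessor` no longer needs its own `__init__`/`get_config` overrides: `call()` simply asks the parent to pack the decoder to `decoder_sequence_length + 1`, then splits the result into inputs and next-token labels of the requested length. A toy illustration of that split (token values are made up):

```python
import numpy as np

decoder_sequence_length = 4

# What the parent preprocessor would hand back for one sequence, packed to
# decoder_sequence_length + 1 ids (values here are arbitrary).
packed_ids = np.array([0, 5, 7, 9, 2])

decoder_token_ids = packed_ids[..., :-1]  # model inputs:        [0, 5, 7, 9]
labels = packed_ids[..., 1:]              # shifted next tokens: [5, 7, 9, 2]

assert decoder_token_ids.shape[-1] == decoder_sequence_length
assert labels.shape[-1] == decoder_sequence_length
```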
3 changes: 2 additions & 1 deletion keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py
@@ -45,7 +45,7 @@ def setUp(self):
)

def test_preprocessor_basics(self):
self.run_preprocessing_layer_test(
self.run_preprocessor_test(
cls=BartSeq2SeqLMPreprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
@@ -59,6 +59,7 @@ def test_preprocessor_basics(self):
[[0, 4, 5, 4, 7, 2, 1, 1]],
[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]],
),
token_id_key="decoder_token_ids",
)

def test_generate_preprocess(self):
2 changes: 1 addition & 1 deletion keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py
@@ -39,7 +39,7 @@ def setUp(self):
self.input_data = ["the quick brown fox"]

def test_preprocessor_basics(self):
self.run_preprocessing_layer_test(
self.run_preprocessor_test(
cls=BertMaskedLMPreprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
13 changes: 12 additions & 1 deletion keras_nlp/models/bert/bert_preprocessor.py
@@ -139,9 +139,9 @@ def __init__(
):
super().__init__(**kwargs)
self.tokenizer = tokenizer
self.packer = None
self.sequence_length = sequence_length
self.truncate = truncate
self.packer = None

def build(self, input_shape):
# Defer packer creation to `build()` so that we can be sure tokenizer
@@ -176,6 +176,17 @@ def get_config(self):
)
return config

@property
def sequence_length(self):
"""The padded length of model input sequences."""
return self._sequence_length

@sequence_length.setter
def sequence_length(self, value):
self._sequence_length = value
if self.packer is not None:
self.packer.sequence_length = value

@classproperty
def tokenizer_cls(cls):
return BertTokenizer
2 changes: 1 addition & 1 deletion keras_nlp/models/bert/bert_preprocessor_test.py
@@ -36,7 +36,7 @@ def setUp(self):
)

def test_preprocessor_basics(self):
self.run_preprocessing_layer_test(
self.run_preprocessor_test(
cls=BertPreprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
@@ -40,7 +40,7 @@ def setUp(self):
self.input_data = ["airplane at airport"]

def test_preprocessor_basics(self):
self.run_preprocessing_layer_test(
self.run_preprocessor_test(
cls=BloomCausalLMPreprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
13 changes: 12 additions & 1 deletion keras_nlp/models/bloom/bloom_preprocessor.py
@@ -116,8 +116,8 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)

self.tokenizer = tokenizer
self.packer = None
self.sequence_length = sequence_length
self.add_start_token = add_start_token
self.add_end_token = add_end_token
@@ -173,6 +173,17 @@ def get_config(self):
)
return config

@property
def sequence_length(self):
"""The padded length of model input sequences."""
return self._sequence_length

@sequence_length.setter
def sequence_length(self, value):
self._sequence_length = value
if self.packer is not None:
self.packer.sequence_length = value

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
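End to end, the user-facing effect of these changes is that a preprocessor's sequence length can be adjusted after construction (or per call, for the BART variants) instead of only at `__init__` time. A hedged usage sketch (the preset id below is a placeholder; substitute any preset listed by `BloomPreprocessor.presets`):

```python
import keras_nlp

# Placeholder preset id, used only for illustration.
preprocessor = keras_nlp.models.BloomPreprocessor.from_preset(
    "bloom_560m_multi",
    sequence_length=128,
)

# After this PR the length is a proper property, so changing it also updates
# the underlying packer that pads/truncates the tokenized inputs.
preprocessor.sequence_length = 64
features = preprocessor(["a quick brown fox"])
```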