From 31e9d397617320300d54ca86a5f14e24fc14b23a Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Wed, 14 Feb 2024 00:32:57 -0800 Subject: [PATCH] Add a settable property for sequence_length For all preprocessors, we add a settable sequence length property. This removes a very annoying gotcha with preprocessing, where if you set the preprocessing length after build, it would not be reflected in the actual output length. --- .../albert_masked_lm_preprocessor_test.py | 2 +- .../models/albert/albert_preprocessor.py | 13 +++- .../models/albert/albert_preprocessor_test.py | 2 +- keras_nlp/models/bart/bart_preprocessor.py | 59 ++++++++++++++-- .../models/bart/bart_preprocessor_test.py | 3 +- .../bart/bart_seq_2_seq_lm_preprocessor.py | 70 +++++++++---------- .../bart_seq_2_seq_lm_preprocessor_test.py | 3 +- .../bert/bert_masked_lm_preprocessor_test.py | 2 +- keras_nlp/models/bert/bert_preprocessor.py | 13 +++- .../models/bert/bert_preprocessor_test.py | 2 +- .../bloom_causal_lm_preprocessor_test.py | 2 +- keras_nlp/models/bloom/bloom_preprocessor.py | 13 +++- .../models/bloom/bloom_preprocessor_test.py | 2 +- .../deberta_v3_masked_lm_preprocessor_test.py | 2 +- .../deberta_v3/deberta_v3_preprocessor.py | 13 +++- .../deberta_v3_preprocessor_test.py | 2 +- ...distil_bert_masked_lm_preprocessor_test.py | 2 +- .../distil_bert/distil_bert_preprocessor.py | 12 ++++ .../distil_bert_preprocessor_test.py | 2 +- .../f_net_masked_lm_preprocessor_test.py | 2 +- keras_nlp/models/f_net/f_net_preprocessor.py | 13 +++- .../models/f_net/f_net_preprocessor_test.py | 2 +- .../gpt2/gpt2_causal_lm_preprocessor_test.py | 2 +- keras_nlp/models/gpt2/gpt2_preprocessor.py | 13 +++- .../models/gpt2/gpt2_preprocessor_test.py | 2 +- .../gpt_neo_x_causal_lm_preprocessor_test.py | 2 +- .../gpt_neo_x/gpt_neo_x_preprocessor.py | 14 +++- .../gpt_neo_x/gpt_neo_x_preprocessor_test.py | 2 +- .../mistral_causal_lm_preprocessor_test.py | 2 +- .../models/mistral/mistral_preprocessor.py | 17 ++++- .../mistral/mistral_preprocessor_test.py | 2 +- .../opt/opt_causal_lm_preprocessor_test.py | 2 +- keras_nlp/models/opt/opt_preprocessor.py | 13 +++- keras_nlp/models/opt/opt_preprocessor_test.py | 2 +- .../roberta_masked_lm_preprocessor_test.py | 2 +- .../models/roberta/roberta_preprocessor.py | 13 +++- .../roberta/roberta_preprocessor_test.py | 2 +- .../models/whisper/whisper_preprocessor.py | 22 +++++- .../whisper/whisper_preprocessor_test.py | 3 +- ...xlm_roberta_masked_lm_preprocessor_test.py | 2 +- .../xlm_roberta/xlm_roberta_preprocessor.py | 13 +++- .../xlm_roberta_preprocessor_test.py | 2 +- keras_nlp/tests/test_case.py | 36 ++++++++++ 43 files changed, 320 insertions(+), 84 deletions(-) diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py index 79d3a36bbb..b9bf693c17 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py @@ -43,7 +43,7 @@ def setUp(self): self.input_data = ["the quick brown fox"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=AlbertMaskedLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/albert/albert_preprocessor.py b/keras_nlp/models/albert/albert_preprocessor.py index 5d5628a729..19f4bd9a7b 100644 --- a/keras_nlp/models/albert/albert_preprocessor.py +++ b/keras_nlp/models/albert/albert_preprocessor.py @@ -158,9 +158,9 @@ def __init__( 
): super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.truncate = truncate self.sequence_length = sequence_length - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -195,6 +195,17 @@ def call(self, x, y=None, sample_weight=None): } return pack_x_y_sample_weight(x, y, sample_weight) + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return AlbertTokenizer diff --git a/keras_nlp/models/albert/albert_preprocessor_test.py b/keras_nlp/models/albert/albert_preprocessor_test.py index 7d6fb4cfd4..ad5da8a47b 100644 --- a/keras_nlp/models/albert/albert_preprocessor_test.py +++ b/keras_nlp/models/albert/albert_preprocessor_test.py @@ -40,7 +40,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=AlbertPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/bart/bart_preprocessor.py b/keras_nlp/models/bart/bart_preprocessor.py index ffe2148839..3310b1e532 100644 --- a/keras_nlp/models/bart/bart_preprocessor.py +++ b/keras_nlp/models/bart/bart_preprocessor.py @@ -140,10 +140,10 @@ def __init__( ): super().__init__(**kwargs) self.tokenizer = tokenizer - self.encoder_sequence_length = encoder_sequence_length - self.decoder_sequence_length = decoder_sequence_length self.encoder_packer = None self.decoder_packer = None + self.encoder_sequence_length = encoder_sequence_length + self.decoder_sequence_length = decoder_sequence_length def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -174,7 +174,17 @@ def build(self, input_shape): ) self.built = True - def call(self, x, y=None, sample_weight=None): + def call( + self, + x, + y=None, + sample_weight=None, + *, + encoder_sequence_length=None, + decoder_sequence_length=None, + # `sequence_length` is an alias for `decoder_sequence_length` + sequence_length=None, + ): if not ( isinstance(x, dict) and all(k in x for k in ("encoder_text", "decoder_text")) @@ -184,6 +194,12 @@ def call(self, x, y=None, sample_weight=None): f' and `"decoder_text"`. Received x={x}.' 
) + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + encoder_text = x["encoder_text"] decoder_text = x["decoder_text"] @@ -199,12 +215,14 @@ def call(self, x, y=None, sample_weight=None): encoder_inputs = self.tokenizer(encoder_text[0]) encoder_token_ids, encoder_padding_mask = self.encoder_packer( - encoder_inputs + encoder_inputs, + sequence_length=encoder_sequence_length, ) decoder_inputs = self.tokenizer(decoder_text[0]) decoder_token_ids, decoder_padding_mask = self.decoder_packer( - decoder_inputs + decoder_inputs, + sequence_length=decoder_sequence_length, ) x = { @@ -226,6 +244,37 @@ def get_config(self): ) return config + @property + def encoder_sequence_length(self): + """The padded length of encoder input sequences.""" + return self._encoder_sequence_length + + @encoder_sequence_length.setter + def encoder_sequence_length(self, value): + self._encoder_sequence_length = value + if self.encoder_packer is not None: + self.encoder_packer.sequence_length = value + + @property + def decoder_sequence_length(self): + """The padded length of decoder input sequences.""" + return self._decoder_sequence_length + + @decoder_sequence_length.setter + def decoder_sequence_length(self, value): + self._decoder_sequence_length = value + if self.decoder_packer is not None: + self.decoder_packer.sequence_length = value + + @property + def sequence_length(self): + """Alias for `decoder_sequence_length`.""" + return self.decoder_sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self.decoder_sequence_length = value + @classproperty def tokenizer_cls(cls): return BartTokenizer diff --git a/keras_nlp/models/bart/bart_preprocessor_test.py b/keras_nlp/models/bart/bart_preprocessor_test.py index 23cb7cae79..7872e35efa 100644 --- a/keras_nlp/models/bart/bart_preprocessor_test.py +++ b/keras_nlp/models/bart/bart_preprocessor_test.py @@ -46,7 +46,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=BartPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, @@ -60,6 +60,7 @@ def test_preprocessor_basics(self): [1], # Pass through labels. [1.0], # Pass through sample_weights. ), + token_id_key="decoder_token_ids", ) def test_error_multi_segment_input(self): diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py index 3d398d29d1..1c72e6e935 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py @@ -124,28 +124,17 @@ class BartSeq2SeqLMPreprocessor(BartPreprocessor): ``` """ - def __init__( + def call( self, - tokenizer, - encoder_sequence_length=1024, - decoder_sequence_length=1024, - **kwargs + x, + y=None, + sample_weight=None, + *, + encoder_sequence_length=None, + decoder_sequence_length=None, + # `sequence_length` is an alias for `decoder_sequence_length` + sequence_length=None, ): - # Since we truncate the last token from `decoder_token_ids`, we need to - # forcefully set the `decoder_sequence_length` to one greater than the - # value passed. 
- super().__init__( - tokenizer=tokenizer, - encoder_sequence_length=encoder_sequence_length, - decoder_sequence_length=decoder_sequence_length + 1, - **kwargs - ) - - # Maintain a private copy of the sequence lengths for config purposes. - self._encoder_sequence_length = encoder_sequence_length - self._decoder_sequence_length = decoder_sequence_length - - def call(self, x, y=None, sample_weight=None): if y is not None or sample_weight is not None: logging.warning( "`BartSeq2SeqLMPreprocessor` infers `y` and `sample_weight` " @@ -154,7 +143,17 @@ def call(self, x, y=None, sample_weight=None): "These values will be ignored." ) - x = super().call(x) + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + + x = super().call( + x, + encoder_sequence_length=encoder_sequence_length, + decoder_sequence_length=decoder_sequence_length + 1, + ) decoder_token_ids = x.pop("decoder_token_ids") decoder_padding_mask = x.pop("decoder_padding_mask") @@ -173,6 +172,10 @@ def call(self, x, y=None, sample_weight=None): def generate_preprocess( self, x, + *, + encoder_sequence_length=None, + # `sequence_length` is an alias for `decoder_sequence_length` + decoder_sequence_length=None, sequence_length=None, ): """Convert encoder and decoder input strings to integer token inputs for generation. @@ -190,10 +193,6 @@ def generate_preprocess( if not self.built: self.build(None) - # If `sequence_length` is not provided, we use the default value. - if sequence_length is None: - sequence_length = self._decoder_sequence_length - if isinstance(x, dict): encoder_text = x["encoder_text"] decoder_text = x["decoder_text"] @@ -202,6 +201,12 @@ def generate_preprocess( # Initialize empty prompt for the decoder. decoder_text = tf.fill((tf.shape(encoder_text)[0],), "") + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + # Tokenize and pack the encoder inputs. # TODO: Remove `[0]` once we have shifted to `MultiSegmentPacker`. encoder_text = convert_inputs_to_list_of_tensor_segments(encoder_text)[ @@ -209,7 +214,8 @@ def generate_preprocess( ] encoder_token_ids = self.tokenizer(encoder_text) encoder_token_ids, encoder_padding_mask = self.encoder_packer( - encoder_token_ids + encoder_token_ids, + sequence_length=encoder_sequence_length, ) # Tokenize and pack the decoder inputs. 
@@ -219,7 +225,7 @@ def generate_preprocess( decoder_token_ids = self.tokenizer(decoder_text) decoder_token_ids, decoder_padding_mask = self.decoder_packer( decoder_token_ids, - sequence_length=sequence_length, + sequence_length=decoder_sequence_length, add_end_value=False, ) @@ -261,16 +267,6 @@ def generate_postprocess( ) return self.tokenizer.detokenize(decoder_token_ids) - def get_config(self): - config = super().get_config() - config.update( - { - "encoder_sequence_length": self._encoder_sequence_length, - "decoder_sequence_length": self._decoder_sequence_length, - } - ) - return config - @classproperty def presets(cls): return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py index 33fbd5fc3a..2f40e69722 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py @@ -45,7 +45,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=BartSeq2SeqLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, @@ -59,6 +59,7 @@ def test_preprocessor_basics(self): [[0, 4, 5, 4, 7, 2, 1, 1]], [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]], ), + token_id_key="decoder_token_ids", ) def test_generate_preprocess(self): diff --git a/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py b/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py index ff58962215..479d9e879b 100644 --- a/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py @@ -39,7 +39,7 @@ def setUp(self): self.input_data = ["the quick brown fox"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=BertMaskedLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/bert/bert_preprocessor.py b/keras_nlp/models/bert/bert_preprocessor.py index bad38f22a5..02f5a45985 100644 --- a/keras_nlp/models/bert/bert_preprocessor.py +++ b/keras_nlp/models/bert/bert_preprocessor.py @@ -139,9 +139,9 @@ def __init__( ): super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.sequence_length = sequence_length self.truncate = truncate - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -176,6 +176,17 @@ def get_config(self): ) return config + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return BertTokenizer diff --git a/keras_nlp/models/bert/bert_preprocessor_test.py b/keras_nlp/models/bert/bert_preprocessor_test.py index 6d1e5fee57..c109d1006d 100644 --- a/keras_nlp/models/bert/bert_preprocessor_test.py +++ b/keras_nlp/models/bert/bert_preprocessor_test.py @@ -36,7 +36,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=BertPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/bloom/bloom_causal_lm_preprocessor_test.py b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor_test.py index 6caf8fddcf..a281519340 
100644 --- a/keras_nlp/models/bloom/bloom_causal_lm_preprocessor_test.py +++ b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor_test.py @@ -40,7 +40,7 @@ def setUp(self): self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=BloomCausalLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/bloom/bloom_preprocessor.py b/keras_nlp/models/bloom/bloom_preprocessor.py index d57ccd2414..734c9f4bf8 100644 --- a/keras_nlp/models/bloom/bloom_preprocessor.py +++ b/keras_nlp/models/bloom/bloom_preprocessor.py @@ -116,8 +116,8 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.tokenizer = tokenizer + self.packer = None self.sequence_length = sequence_length self.add_start_token = add_start_token self.add_end_token = add_end_token @@ -173,6 +173,17 @@ def get_config(self): ) return config + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def presets(cls): return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bloom/bloom_preprocessor_test.py b/keras_nlp/models/bloom/bloom_preprocessor_test.py index bb80006396..938113ef4b 100644 --- a/keras_nlp/models/bloom/bloom_preprocessor_test.py +++ b/keras_nlp/models/bloom/bloom_preprocessor_test.py @@ -38,7 +38,7 @@ def setUp(self): self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=BloomPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py index 217980ea59..f041a6f7ff 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py @@ -43,7 +43,7 @@ def setUp(self): self.input_data = ["the quick brown fox"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=DebertaV3MaskedLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py index 93f4fbbd22..88fa08fd70 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py @@ -156,9 +156,9 @@ def __init__( ): super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.truncate = truncate self.sequence_length = sequence_length - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -192,6 +192,17 @@ def call(self, x, y=None, sample_weight=None): } return pack_x_y_sample_weight(x, y, sample_weight) + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return DebertaV3Tokenizer diff --git 
a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py index a50022f3c7..a9e2a59c29 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py @@ -42,7 +42,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=DebertaV3Preprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py index b01b1da8ac..85ee5bdd43 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py @@ -41,7 +41,7 @@ def setUp(self): self.input_data = ["the quick brown fox"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=DistilBertMaskedLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/distil_bert/distil_bert_preprocessor.py b/keras_nlp/models/distil_bert/distil_bert_preprocessor.py index 107275f80a..63f4e3637b 100644 --- a/keras_nlp/models/distil_bert/distil_bert_preprocessor.py +++ b/keras_nlp/models/distil_bert/distil_bert_preprocessor.py @@ -127,6 +127,7 @@ def __init__( ): super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.sequence_length = sequence_length self.truncate = truncate @@ -162,6 +163,17 @@ def get_config(self): ) return config + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return DistilBertTokenizer diff --git a/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py b/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py index 22d69c88dc..f58b42cd39 100644 --- a/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py @@ -40,7 +40,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=DistilBertPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py b/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py index 5f72081a0d..7d2ecc0f17 100644 --- a/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py @@ -41,7 +41,7 @@ def setUp(self): self.input_data = ["the quick brown fox"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=FNetMaskedLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/f_net/f_net_preprocessor.py b/keras_nlp/models/f_net/f_net_preprocessor.py index 296493c930..b4cb5836bb 100644 --- a/keras_nlp/models/f_net/f_net_preprocessor.py +++ b/keras_nlp/models/f_net/f_net_preprocessor.py @@ -129,9 +129,9 @@ def __init__( ): super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.truncate = truncate self.sequence_length = sequence_length - 
self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -165,6 +165,17 @@ def call(self, x, y=None, sample_weight=None): } return pack_x_y_sample_weight(x, y, sample_weight) + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return FNetTokenizer diff --git a/keras_nlp/models/f_net/f_net_preprocessor_test.py b/keras_nlp/models/f_net/f_net_preprocessor_test.py index f67737c828..c9096ac59f 100644 --- a/keras_nlp/models/f_net/f_net_preprocessor_test.py +++ b/keras_nlp/models/f_net/f_net_preprocessor_test.py @@ -38,7 +38,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=FNetPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py index 400273b792..0623d983a9 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py @@ -40,7 +40,7 @@ def setUp(self): self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=GPT2CausalLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py index 29182f77b6..82be34776f 100644 --- a/keras_nlp/models/gpt2/gpt2_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py @@ -118,8 +118,8 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.tokenizer = tokenizer + self.packer = None self.sequence_length = sequence_length self.add_start_token = add_start_token self.add_end_token = add_end_token @@ -175,6 +175,17 @@ def get_config(self): ) return config + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def presets(cls): return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_preprocessor_test.py index d7dcd261ed..35129c200d 100644 --- a/keras_nlp/models/gpt2/gpt2_preprocessor_test.py +++ b/keras_nlp/models/gpt2/gpt2_preprocessor_test.py @@ -38,7 +38,7 @@ def setUp(self): self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=GPT2Preprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py index f5a7c57421..e873c38c79 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py @@ -40,7 +40,7 @@ def setUp(self): self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + 
self.run_preprocessor_test( cls=GPTNeoXCausalLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py index 1db4fe4c9b..8dc374332b 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py @@ -74,12 +74,11 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.tokenizer = tokenizer + self.packer = None self.sequence_length = sequence_length self.add_start_token = add_start_token self.add_end_token = add_end_token - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -132,6 +131,17 @@ def get_config(self): ) return config + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return GPTNeoXTokenizer diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py index c87329af4a..92ea191596 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py @@ -38,7 +38,7 @@ def setUp(self): self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=GPTNeoXPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/mistral/mistral_causal_lm_preprocessor_test.py b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor_test.py index 420995016b..80b55c02e5 100644 --- a/keras_nlp/models/mistral/mistral_causal_lm_preprocessor_test.py +++ b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor_test.py @@ -36,7 +36,7 @@ def setUp(self): self.input_data = (["the quick brown fox"],) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=MistralCausalLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/mistral/mistral_preprocessor.py b/keras_nlp/models/mistral/mistral_preprocessor.py index d5d838303e..b96afb5ed8 100644 --- a/keras_nlp/models/mistral/mistral_preprocessor.py +++ b/keras_nlp/models/mistral/mistral_preprocessor.py @@ -121,15 +121,15 @@ def __init__( ): super().__init__(**kwargs) self.tokenizer = tokenizer - self.add_start_token = add_start_token - self.add_end_token = add_end_token - self.sequence_length = sequence_length self.packer = StartEndPacker( start_value=self.tokenizer.start_token_id, end_value=self.tokenizer.end_token_id, sequence_length=sequence_length, return_padding_mask=True, ) + self.add_start_token = add_start_token + self.add_end_token = add_end_token + self.sequence_length = sequence_length def get_config(self): config = super().get_config() @@ -170,6 +170,17 @@ def call( } return pack_x_y_sample_weight(x, y, sample_weight) + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return 
MistralTokenizer diff --git a/keras_nlp/models/mistral/mistral_preprocessor_test.py b/keras_nlp/models/mistral/mistral_preprocessor_test.py index 40528fd4e8..47c0bc542c 100644 --- a/keras_nlp/models/mistral/mistral_preprocessor_test.py +++ b/keras_nlp/models/mistral/mistral_preprocessor_test.py @@ -38,7 +38,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=MistralPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py b/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py index 9ba6851d4b..e04436f092 100644 --- a/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py +++ b/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py @@ -39,7 +39,7 @@ def setUp(self): self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=OPTCausalLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/opt/opt_preprocessor.py b/keras_nlp/models/opt/opt_preprocessor.py index cdca904870..8f52bb67e6 100644 --- a/keras_nlp/models/opt/opt_preprocessor.py +++ b/keras_nlp/models/opt/opt_preprocessor.py @@ -120,10 +120,10 @@ def __init__( super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.sequence_length = sequence_length self.add_start_token = add_start_token self.add_end_token = add_end_token - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -176,6 +176,17 @@ def call( } return pack_x_y_sample_weight(x, y, sample_weight) + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def presets(cls): return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/opt/opt_preprocessor_test.py b/keras_nlp/models/opt/opt_preprocessor_test.py index b80c409b92..901efc7bee 100644 --- a/keras_nlp/models/opt/opt_preprocessor_test.py +++ b/keras_nlp/models/opt/opt_preprocessor_test.py @@ -37,7 +37,7 @@ def setUp(self): self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=OPTPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py b/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py index ae762079e2..a842e99f5d 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py @@ -44,7 +44,7 @@ def setUp(self): self.input_data = [" airplane airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=RobertaMaskedLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/roberta/roberta_preprocessor.py b/keras_nlp/models/roberta/roberta_preprocessor.py index 556561d17c..57a421590f 100644 --- a/keras_nlp/models/roberta/roberta_preprocessor.py +++ b/keras_nlp/models/roberta/roberta_preprocessor.py @@ -143,9 +143,9 @@ def __init__( super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None 
self.truncate = truncate self.sequence_length = sequence_length - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -180,6 +180,17 @@ def get_config(self): ) return config + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return RobertaTokenizer diff --git a/keras_nlp/models/roberta/roberta_preprocessor_test.py b/keras_nlp/models/roberta/roberta_preprocessor_test.py index 5e7ad77514..699742ea08 100644 --- a/keras_nlp/models/roberta/roberta_preprocessor_test.py +++ b/keras_nlp/models/roberta/roberta_preprocessor_test.py @@ -41,7 +41,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=RobertaPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/whisper/whisper_preprocessor.py b/keras_nlp/models/whisper/whisper_preprocessor.py index abcff0d770..c21705a481 100644 --- a/keras_nlp/models/whisper/whisper_preprocessor.py +++ b/keras_nlp/models/whisper/whisper_preprocessor.py @@ -169,11 +169,11 @@ def __init__( audio_feature_extractor = WhisperAudioFeatureExtractor() self.audio_feature_extractor = audio_feature_extractor self.tokenizer = tokenizer + self.decoder_packer = None self.decoder_sequence_length = decoder_sequence_length self.language = language self.task = task self.no_timestamps = no_timestamps - self.decoder_packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -307,6 +307,26 @@ def from_config(cls, config): return cls(**config) + @property + def decoder_sequence_length(self): + """The padded length of decoder input sequences.""" + return self._decoder_sequence_length + + @decoder_sequence_length.setter + def decoder_sequence_length(self, value): + self._decoder_sequence_length = value + if self.decoder_packer is not None: + self.decoder_packer.sequence_length = value + + @property + def sequence_length(self): + """Alias for `decoder_sequence_length`.""" + return self.decoder_sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self.decoder_sequence_length = value + @classproperty def audio_feature_extractor_cls(cls): return WhisperAudioFeatureExtractor diff --git a/keras_nlp/models/whisper/whisper_preprocessor_test.py b/keras_nlp/models/whisper/whisper_preprocessor_test.py index 6837dc8bfa..8517a6c102 100644 --- a/keras_nlp/models/whisper/whisper_preprocessor_test.py +++ b/keras_nlp/models/whisper/whisper_preprocessor_test.py @@ -66,10 +66,11 @@ def setUp(self): } def test_feature_extractor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=WhisperPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, + token_id_key="decoder_token_ids", ) def test_sequence_length_override(self): diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py index c1bfc7242a..6d77e71319 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py @@ -45,7 +45,7 @@ def setUp(self): self.input_data = 
["the quick brown fox"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=XLMRobertaMaskedLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py index 23b48073f7..c94f5f2421 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py @@ -156,9 +156,9 @@ def __init__( super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.truncate = truncate self.sequence_length = sequence_length - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -193,6 +193,17 @@ def call(self, x, y=None, sample_weight=None): } return pack_x_y_sample_weight(x, y, sample_weight) + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return XLMRobertaTokenizer diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py index 3c3bbf2612..85c76fa282 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py @@ -44,7 +44,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=XLMRobertaPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/tests/test_case.py b/keras_nlp/tests/test_case.py index 50e7733513..0541ae6451 100644 --- a/keras_nlp/tests/test_case.py +++ b/keras_nlp/tests/test_case.py @@ -230,6 +230,42 @@ def run_preprocessing_layer_test( if expected_output: self.assertAllClose(output, expected_output) + def run_preprocessor_test( + self, + cls, + init_kwargs, + input_data, + expected_output=None, + expected_detokenize_output=None, + token_id_key="token_ids", + ): + """Run basic tests for a Model Preprocessor layer.""" + self.run_preprocessing_layer_test( + cls, + init_kwargs, + input_data, + expected_output=expected_output, + expected_detokenize_output=expected_detokenize_output, + ) + + layer = cls(**self.init_kwargs) + if isinstance(input_data, tuple): + output = layer(*input_data) + else: + output = layer(input_data) + output, _, _ = keras.utils.unpack_x_y_sample_weight(output) + shape = ops.shape(output[token_id_key]) + self.assertEqual(shape[-1], layer.sequence_length) + # Update the sequence length. + layer.sequence_length = 17 + if isinstance(input_data, tuple): + output = layer(*input_data) + else: + output = layer(input_data) + output, _, _ = keras.utils.unpack_x_y_sample_weight(output) + shape = ops.shape(output[token_id_key]) + self.assertEqual(shape[-1], 17) + def run_serialization_test(self, instance): """Check idempotency of serialize/deserialize.