@@ -124,28 +124,17 @@ class BartSeq2SeqLMPreprocessor(BartPreprocessor):
     ```
     """

-    def __init__(
+    def call(
         self,
-        tokenizer,
-        encoder_sequence_length=1024,
-        decoder_sequence_length=1024,
-        **kwargs
+        x,
+        y=None,
+        sample_weight=None,
+        *,
+        encoder_sequence_length=None,
+        decoder_sequence_length=None,
+        # `sequence_length` is an alias for `decoder_sequence_length`
+        sequence_length=None,
     ):
-        # Since we truncate the last token from `decoder_token_ids`, we need to
-        # forcefully set the `decoder_sequence_length` to one greater than the
-        # value passed.
-        super().__init__(
-            tokenizer=tokenizer,
-            encoder_sequence_length=encoder_sequence_length,
-            decoder_sequence_length=decoder_sequence_length + 1,
-            **kwargs
-        )
-
-        # Maintain a private copy of the sequence lengths for config purposes.
-        self._encoder_sequence_length = encoder_sequence_length
-        self._decoder_sequence_length = decoder_sequence_length
-
-    def call(self, x, y=None, sample_weight=None):
         if y is not None or sample_weight is not None:
             logging.warning(
                 "`BartSeq2SeqLMPreprocessor` infers `y` and `sample_weight` "
@@ -154,7 +143,17 @@ def call(self, x, y=None, sample_weight=None):
154143 "These values will be ignored."
155144 )
156145
157- x = super ().call (x )
146+ if encoder_sequence_length is None :
147+ encoder_sequence_length = self .encoder_sequence_length
148+ decoder_sequence_length = decoder_sequence_length or sequence_length
149+ if decoder_sequence_length is None :
150+ decoder_sequence_length = self .decoder_sequence_length
151+
152+ x = super ().call (
153+ x ,
154+ encoder_sequence_length = encoder_sequence_length ,
155+ decoder_sequence_length = decoder_sequence_length + 1 ,
156+ )
158157 decoder_token_ids = x .pop ("decoder_token_ids" )
159158 decoder_padding_mask = x .pop ("decoder_padding_mask" )
160159
@@ -173,6 +172,10 @@ def call(self, x, y=None, sample_weight=None):
     def generate_preprocess(
         self,
         x,
+        *,
+        encoder_sequence_length=None,
+        # `sequence_length` is an alias for `decoder_sequence_length`
+        decoder_sequence_length=None,
         sequence_length=None,
     ):
         """Convert encoder and decoder input strings to integer token inputs for generation.
@@ -190,10 +193,6 @@ def generate_preprocess(
         if not self.built:
             self.build(None)

-        # If `sequence_length` is not provided, we use the default value.
-        if sequence_length is None:
-            sequence_length = self._decoder_sequence_length
-
         if isinstance(x, dict):
             encoder_text = x["encoder_text"]
             decoder_text = x["decoder_text"]
@@ -202,14 +201,21 @@ def generate_preprocess(
             # Initialize empty prompt for the decoder.
             decoder_text = tf.fill((tf.shape(encoder_text)[0],), "")

+        if encoder_sequence_length is None:
+            encoder_sequence_length = self.encoder_sequence_length
+        decoder_sequence_length = decoder_sequence_length or sequence_length
+        if decoder_sequence_length is None:
+            decoder_sequence_length = self.decoder_sequence_length
+
         # Tokenize and pack the encoder inputs.
         # TODO: Remove `[0]` once we have shifted to `MultiSegmentPacker`.
         encoder_text = convert_inputs_to_list_of_tensor_segments(encoder_text)[
             0
         ]
         encoder_token_ids = self.tokenizer(encoder_text)
         encoder_token_ids, encoder_padding_mask = self.encoder_packer(
-            encoder_token_ids
+            encoder_token_ids,
+            sequence_length=encoder_sequence_length,
         )

         # Tokenize and pack the decoder inputs.
@@ -219,7 +225,7 @@ def generate_preprocess(
         decoder_token_ids = self.tokenizer(decoder_text)
         decoder_token_ids, decoder_padding_mask = self.decoder_packer(
             decoder_token_ids,
-            sequence_length=sequence_length,
+            sequence_length=decoder_sequence_length,
             add_end_value=False,
         )

@@ -261,16 +267,6 @@ def generate_postprocess(
         )
         return self.tokenizer.detokenize(decoder_token_ids)

-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "encoder_sequence_length": self._encoder_sequence_length,
-                "decoder_sequence_length": self._decoder_sequence_length,
-            }
-        )
-        return config
-
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
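
Note: with this change, `call` and `generate_preprocess` take per-call `encoder_sequence_length` and `decoder_sequence_length` overrides (with `sequence_length` kept as an alias for the decoder length), falling back to the layer's configured lengths when omitted. A minimal usage sketch, assuming the `bart_base_en` preset name and that keyword arguments pass through to `call` when the preprocessor is invoked directly:

import keras_nlp

preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en"  # preset name assumed for illustration
)

# Override the configured lengths for this call only; passing
# `sequence_length=32` would behave the same as `decoder_sequence_length=32`.
x, y, sample_weight = preprocessor(
    {
        "encoder_text": ["The fox was sleeping."],
        "decoder_text": ["The fox was awake."],
    },
    encoder_sequence_length=64,
    decoder_sequence_length=32,
)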