@@ -124,28 +124,14 @@ class BartSeq2SeqLMPreprocessor(BartPreprocessor):
     ```
     """
 
-    def __init__(
+    def call(
         self,
-        tokenizer,
-        encoder_sequence_length=1024,
-        decoder_sequence_length=1024,
-        **kwargs
+        x,
+        y=None,
+        sample_weight=None,
+        encoder_sequence_length=None,
+        decoder_sequence_length=None,
     ):
-        # Since we truncate the last token from `decoder_token_ids`, we need to
-        # forcefully set the `decoder_sequence_length` to one greater than the
-        # value passed.
-        super().__init__(
-            tokenizer=tokenizer,
-            encoder_sequence_length=encoder_sequence_length,
-            decoder_sequence_length=decoder_sequence_length + 1,
-            **kwargs
-        )
-
-        # Maintain a private copy of the sequence lengths for config purposes.
-        self._encoder_sequence_length = encoder_sequence_length
-        self._decoder_sequence_length = decoder_sequence_length
-
-    def call(self, x, y=None, sample_weight=None):
         if y is not None or sample_weight is not None:
             logging.warning(
                 "`BartSeq2SeqLMPreprocessor` infers `y` and `sample_weight` "
@@ -154,7 +140,16 @@ def call(self, x, y=None, sample_weight=None):
                 "These values will be ignored."
             )
 
-        x = super().call(x)
+        if encoder_sequence_length is None:
+            encoder_sequence_length = self.encoder_sequence_length
+        if decoder_sequence_length is None:
+            decoder_sequence_length = self.decoder_sequence_length
+
+        x = super().call(
+            x,
+            encoder_sequence_length=encoder_sequence_length,
+            decoder_sequence_length=decoder_sequence_length + 1,
+        )
         decoder_token_ids = x.pop("decoder_token_ids")
         decoder_padding_mask = x.pop("decoder_padding_mask")
 
@@ -173,7 +168,8 @@ def call(self, x, y=None, sample_weight=None):
     def generate_preprocess(
         self,
         x,
-        sequence_length=None,
+        encoder_sequence_length=None,
+        decoder_sequence_length=None,
     ):
         """Convert encoder and decoder input strings to integer token inputs for generation.
 
@@ -190,10 +186,6 @@ def generate_preprocess(
         if not self.built:
             self.build(None)
 
-        # If `sequence_length` is not provided, we use the default value.
-        if sequence_length is None:
-            sequence_length = self._decoder_sequence_length
-
         if isinstance(x, dict):
             encoder_text = x["encoder_text"]
             decoder_text = x["decoder_text"]
@@ -209,7 +201,8 @@ def generate_preprocess(
         ]
         encoder_token_ids = self.tokenizer(encoder_text)
         encoder_token_ids, encoder_padding_mask = self.encoder_packer(
-            encoder_token_ids
+            encoder_token_ids,
+            sequence_length=encoder_sequence_length,
         )
 
         # Tokenize and pack the decoder inputs.
@@ -219,7 +212,7 @@
         decoder_token_ids = self.tokenizer(decoder_text)
         decoder_token_ids, decoder_padding_mask = self.decoder_packer(
             decoder_token_ids,
-            sequence_length=sequence_length,
+            sequence_length=decoder_sequence_length,
             add_end_value=False,
         )
 
@@ -261,16 +254,6 @@ def generate_postprocess(
         )
         return self.tokenizer.detokenize(decoder_token_ids)
 
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "encoder_sequence_length": self._encoder_sequence_length,
-                "decoder_sequence_length": self._decoder_sequence_length,
-            }
-        )
-        return config
-
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
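
For reference, a minimal usage sketch of the reworked API (not part of the diff; it assumes `keras_nlp` is installed, that the `bart_base_en` preset is available, and that `from_preset` forwards keyword arguments to the constructor). It shows the new per-call `encoder_sequence_length` / `decoder_sequence_length` overrides, which fall back to the layer's configured defaults when omitted:

import keras_nlp

# Defaults now come from the base `BartPreprocessor` constructor, since the
# subclass no longer overrides `__init__` (preset name is illustrative).
preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en",
    encoder_sequence_length=512,
    decoder_sequence_length=128,
)

# Training-style call: `y` and `sample_weight` are inferred from the decoder
# inputs, and the configured lengths apply unless overridden per call.
x, y, sample_weight = preprocessor(
    {
        "encoder_text": "The fox was sleeping.",
        "decoder_text": "The fox was awake.",
    }
)

# Generation-style preprocessing: encoder and decoder lengths are now
# independent arguments instead of a single `sequence_length`.
generation_inputs = preprocessor.generate_preprocess(
    {
        "encoder_text": "The fox was sleeping.",
        "decoder_text": "The fox",
    },
    encoder_sequence_length=512,
    decoder_sequence_length=64,
)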