 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import itertools
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
 from .processor_utils import greedy_knapsack, infer_seqlen

-
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template

-
 logger = get_logger(__name__)
@@ -53,13 +51,16 @@ def _encode_supervised_example(
     encoded_pairs = encoded_pairs[::-1]  # high priority for last turns

     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
-        if total_length >= cutoff_len:
+        if total_length >= cutoff_len and cutoff_len > 0:
             break

-        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
-        source_ids = source_ids[:source_len]
-        target_ids = target_ids[:target_len]
-        total_length += source_len + target_len
+        if cutoff_len > 0:
+            source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
+            source_ids = source_ids[:source_len]
+            target_ids = target_ids[:target_len]
+            total_length += source_len + target_len
+        else:
+            source_len, target_len = len(source_ids), len(target_ids)

         if train_on_prompt:
             source_label = source_ids
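With `cutoff_len=0` now acting as a "no truncation" sentinel, `infer_seqlen` is only consulted when a positive budget remains. That helper lives in `processor_utils` and is not shown in this diff; below is a minimal sketch of the contract it must satisfy, assuming a proportional split of the remaining budget (the repository's exact policy may differ):

```python
from typing import Tuple

def infer_seqlen_sketch(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    # Hypothetical stand-in for processor_utils.infer_seqlen: if the pair fits,
    # keep it whole; otherwise split the budget proportionally to the inputs.
    if source_len + target_len <= cutoff_len:
        return source_len, target_len
    new_source_len = max(1, cutoff_len * source_len // (source_len + target_len))
    new_target_len = max(0, cutoff_len - new_source_len)
    return new_source_len, new_target_len
```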
@@ -112,7 +113,7 @@ def preprocess_supervised_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            cutoff_len=data_args.cutoff_len,
+            cutoff_len=data_args.cutoff_len if data_args.allow_truncation else 0,
             train_on_prompt=data_args.train_on_prompt,
             mask_history=data_args.mask_history,
         )
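Both this hunk and the packed variant below reference `data_args.allow_truncation` and `data_args.flat_packing`, which do not appear in the shown code; the diff presumes matching fields on `DataArguments`. A hypothetical sketch of those declarations (field names taken from this diff, defaults assumed):

```python
from dataclasses import dataclass, field

@dataclass
class DataArguments:
    # hypothetical excerpt: only the fields this diff relies on
    cutoff_len: int = field(default=1024)
    train_on_prompt: bool = field(default=False)
    mask_history: bool = field(default=False)
    neat_packing: bool = field(default=False)
    allow_truncation: bool = field(default=False)  # assumed new flag: permit truncating overlong examples
    flat_packing: bool = field(default=False)      # assumed new flag: pack without padding or attention masks
```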
@@ -132,13 +133,16 @@ def preprocess_packed_supervised_dataset(
     processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
 ) -> Dict[str, List[Any]]:
-    # TODO: use `position_ids` to achieve packing
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
     valid_num = 0
+    invalid_num = 0
     batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
     lengths = []
     length2indexes = defaultdict(list)
+
+    # reserve one position for the padding token; flat packing needs none
+    num_reserved = 0 if data_args.flat_packing else 1
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
             logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
@@ -154,13 +158,13 @@ def preprocess_packed_supervised_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
+            cutoff_len=data_args.cutoff_len - num_reserved if data_args.allow_truncation else 0,
             train_on_prompt=data_args.train_on_prompt,
             mask_history=data_args.mask_history,
         )
         length = len(input_ids)
-        if length > data_args.cutoff_len:
-            logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
+        if length > data_args.cutoff_len - num_reserved:
+            invalid_num += 1
         else:
             lengths.append(length)
             length2indexes[length].append(valid_num)
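For concreteness (numbers assumed for illustration): with `cutoff_len=4096`, padded packing reserves one slot for the pad token, so an encoded example of 4096 tokens lands in the `invalid_num` branch, while flat packing (`num_reserved=0`) accepts it. Note that when `allow_truncation` is enabled the encoder has already trimmed every example to the budget, so the drop branch fires mainly when truncation is disallowed.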
@@ -170,36 +174,52 @@ def preprocess_packed_supervised_dataset(
             batch_videos.append(examples["_videos"][i] or [])
             valid_num += 1

+    if invalid_num > 0:
+        logger.warning(
+            "Dropped {} lengthy examples with length > {}.".format(invalid_num, data_args.cutoff_len - num_reserved)
+        )
+
     model_inputs = defaultdict(list)
-    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
+    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - num_reserved)  # reserve space for the padding token (0 when flat packing)
     for knapsack in knapsacks:
         packed_input_ids, packed_attention_masks, packed_labels = [], [], []
         packed_images, packed_videos = [], []
-        for i, length in enumerate(knapsack):
-            index = length2indexes[length].pop()
-            packed_input_ids += batch_input_ids[index]
-            packed_labels += batch_labels[index]
-            packed_images += batch_images[index]
-            packed_videos += batch_videos[index]
-            if data_args.neat_packing:
-                packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
-            else:
-                packed_attention_masks += [1] * len(batch_input_ids[index])
-
-        if len(packed_input_ids) < data_args.cutoff_len:
-            pad_length = data_args.cutoff_len - len(packed_input_ids)
-            packed_input_ids += [tokenizer.pad_token_id] * pad_length
-            packed_labels += [IGNORE_INDEX] * pad_length
-            if data_args.neat_packing:
-                packed_attention_masks += [0] * pad_length
-            else:
-                packed_attention_masks += [1] * pad_length  # more efficient flash_attn
-
-        if len(packed_input_ids) != data_args.cutoff_len:
-            raise ValueError("The length of packed example should be identical to the cutoff length.")
+
+        if data_args.flat_packing:
+            for i, length in enumerate(knapsack):
+                index = length2indexes[length].pop()
+                packed_input_ids.append(batch_input_ids[index])
+                packed_labels.append(batch_labels[index])
+                packed_images.append(batch_images[index])
+                packed_videos.append(batch_videos[index])
+        else:
+            for i, length in enumerate(knapsack):
+                index = length2indexes[length].pop()
+                packed_input_ids += batch_input_ids[index]
+                packed_labels += batch_labels[index]
+                packed_images += batch_images[index]
+                packed_videos += batch_videos[index]
+                if data_args.neat_packing:
+                    packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
+                else:
+                    packed_attention_masks += [1] * len(batch_input_ids[index])
+
+            # flat packing needs no attention mask, so padding applies only here
+            if len(packed_input_ids) < data_args.cutoff_len:
+                pad_length = data_args.cutoff_len - len(packed_input_ids)
+                packed_input_ids += [tokenizer.pad_token_id] * pad_length
+                packed_labels += [IGNORE_INDEX] * pad_length
+                if data_args.neat_packing:
+                    packed_attention_masks += [0] * pad_length
+                else:
+                    packed_attention_masks += [1] * pad_length  # more efficient flash_attn
+
+            # flat packing needs no padding to the cutoff length either
+            if len(packed_input_ids) != data_args.cutoff_len:
+                raise ValueError("The length of packed example should be identical to the cutoff length.")
+            model_inputs["attention_mask"].append(packed_attention_masks)

         model_inputs["input_ids"].append(packed_input_ids)
-        model_inputs["attention_mask"].append(packed_attention_masks)
         model_inputs["labels"].append(packed_labels)
         model_inputs["images"].append(packed_images or None)
         model_inputs["videos"].append(packed_videos or None)
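`greedy_knapsack` is imported from `processor_utils` and not shown here; from how its output is consumed (each knapsack is a list of lengths summing to at most the capacity, with `length2indexes` mapping lengths back to example indices), it behaves like a greedy bin packer. A minimal sketch under a first-fit-decreasing assumption (not necessarily the repository's exact algorithm):

```python
from typing import List

def greedy_knapsack_sketch(lengths: List[int], capacity: int) -> List[List[int]]:
    # Place each length (largest first) into the first bin with enough room;
    # every returned bin sums to at most `capacity`.
    bins: List[List[int]] = []
    for length in sorted(lengths, reverse=True):
        for bin_ in bins:
            if sum(bin_) + length <= capacity:
                bin_.append(length)
                break
        else:  # no bin had room for this length
            bins.append([length])
    return bins
```

The two packing modes then diverge: `neat_packing` writes a per-token segment id (1, 2, 3, ...) into the attention mask so a downstream collator can rebuild a block-diagonal mask that blocks cross-example attention, whereas `flat_packing` keeps each example as its own sublist and emits no attention mask at all, leaving sequence boundaries to downstream variable-length attention handling.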
@@ -213,3 +233,12 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
     print("label_ids:\n{}".format(example["labels"]))
     print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
+
+
+def print_flatting_supervised_dataset_example(example: Dict[str, List[List[int]]], tokenizer: "PreTrainedTokenizer") -> None:
+    valid_labels = list(filter(lambda x: x != IGNORE_INDEX, itertools.chain(*example["labels"])))
+    input_ids = list(itertools.chain(*example["input_ids"]))
+    print("input_ids:\n{}".format(input_ids))
+    print("inputs:\n{}".format(tokenizer.decode(input_ids, skip_special_tokens=False)))
+    print("label_ids:\n{}".format(list(itertools.chain(*example["labels"]))))
+    print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
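A usage sketch for the new helper with a toy flat-packed example (token ids invented for illustration; `IGNORE_INDEX` is -100):

```python
# toy flat-packed example: two sequences kept as separate sublists
example = {
    "input_ids": [[1, 5, 6, 2], [1, 7, 8, 9, 2]],
    "labels": [[-100, -100, 6, 2], [-100, 7, 8, 9, 2]],
}
# print_flatting_supervised_dataset_example(example, tokenizer) flattens both
# fields with itertools.chain(*...) before printing, so its output matches the
# format of print_supervised_dataset_example on an unpacked example.
```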