11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
-
14
+ import itertools
15
15
from collections import defaultdict
16
16
from typing import TYPE_CHECKING , Any , Dict , List , Optional , Sequence , Tuple
17
17
18
18
from ...extras .constants import IGNORE_INDEX
19
19
from ...extras .logging import get_logger
20
20
from .processor_utils import greedy_knapsack , infer_seqlen
21
21
22
-
23
22
if TYPE_CHECKING :
24
23
from transformers import PreTrainedTokenizer , ProcessorMixin
25
24
26
25
from ...hparams import DataArguments
27
26
from ..mm_plugin import ImageInput , VideoInput
28
27
from ..template import Template
29
28
30
-
31
29
logger = get_logger (__name__ )
32
30
33
31
@@ -48,18 +46,12 @@ def _encode_supervised_example(
48
46
messages = template .mm_plugin .process_messages (prompt + response , images , videos , processor )
49
47
input_ids , labels = template .mm_plugin .process_token_ids ([], [], images , videos , tokenizer , processor )
50
48
encoded_pairs = template .encode_multiturn (tokenizer , messages , system , tools )
51
- total_length = len (input_ids ) + (1 if template .efficient_eos else 0 )
52
49
if mask_history :
53
50
encoded_pairs = encoded_pairs [::- 1 ] # high priority for last turns
54
51
55
52
for turn_idx , (source_ids , target_ids ) in enumerate (encoded_pairs ):
56
- if total_length >= cutoff_len :
57
- break
58
-
59
- source_len , target_len = infer_seqlen (len (source_ids ), len (target_ids ), cutoff_len - total_length )
60
- source_ids = source_ids [:source_len ]
61
- target_ids = target_ids [:target_len ]
62
- total_length += source_len + target_len
53
+ source_len = len (source_ids )
54
+ target_len = len (target_ids )
63
55
64
56
if train_on_prompt :
65
57
source_label = source_ids
@@ -132,13 +124,16 @@ def preprocess_packed_supervised_dataset(
132
124
processor : Optional ["ProcessorMixin" ],
133
125
data_args : "DataArguments" ,
134
126
) -> Dict [str , List [Any ]]:
135
- # TODO: use `position_ids` to achieve packing
136
127
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
137
128
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
138
129
valid_num = 0
130
+ invalid_num = 0
139
131
batch_input_ids , batch_labels , batch_images , batch_videos = [], [], [], []
140
132
lengths = []
141
133
length2indexes = defaultdict (list )
134
+
135
+ # reserved for the padding token / flatting_packing don't need
136
+ num_reserved = 0 if data_args .flatting_packing else 1
142
137
for i in range (len (examples ["_prompt" ])):
143
138
if len (examples ["_prompt" ][i ]) % 2 != 1 or len (examples ["_response" ][i ]) != 1 :
144
139
logger .warning ("Dropped invalid example: {}" .format (examples ["_prompt" ][i ] + examples ["_response" ][i ]))
@@ -154,13 +149,13 @@ def preprocess_packed_supervised_dataset(
154
149
template = template ,
155
150
tokenizer = tokenizer ,
156
151
processor = processor ,
157
- cutoff_len = data_args .cutoff_len - 1 , # reserved for the padding token
152
+ cutoff_len = data_args .cutoff_len - num_reserved ,
158
153
train_on_prompt = data_args .train_on_prompt ,
159
154
mask_history = data_args .mask_history ,
160
155
)
161
156
length = len (input_ids )
162
- if length > data_args .cutoff_len :
163
- logger . warning ( "Dropped lengthy example with length {} > {}." . format ( length , data_args . cutoff_len ))
157
+ if length > data_args .cutoff_len - num_reserved :
158
+ invalid_num += 1
164
159
else :
165
160
lengths .append (length )
166
161
length2indexes [length ].append (valid_num )
@@ -170,36 +165,52 @@ def preprocess_packed_supervised_dataset(
170
165
batch_videos .append (examples ["_videos" ][i ] or [])
171
166
valid_num += 1
172
167
168
+ if invalid_num > 0 :
169
+ logger .warning (
170
+ "Dropped lengthy {} example with length > {}." .format (invalid_num , data_args .cutoff_len - num_reserved )
171
+ )
172
+
173
173
model_inputs = defaultdict (list )
174
- knapsacks = greedy_knapsack (lengths , data_args .cutoff_len - 1 ) # reserved for the padding token
174
+ knapsacks = greedy_knapsack (lengths , data_args .cutoff_len - num_reserved ) # reserved for the padding token
175
175
for knapsack in knapsacks :
176
176
packed_input_ids , packed_attention_masks , packed_labels = [], [], []
177
177
packed_images , packed_videos = [], []
178
- for i , length in enumerate (knapsack ):
179
- index = length2indexes [length ].pop ()
180
- packed_input_ids += batch_input_ids [index ]
181
- packed_labels += batch_labels [index ]
182
- packed_images += batch_images [index ]
183
- packed_videos += batch_videos [index ]
184
- if data_args .neat_packing :
185
- packed_attention_masks += [i + 1 ] * len (batch_input_ids [index ]) # start from 1
186
- else :
187
- packed_attention_masks += [1 ] * len (batch_input_ids [index ])
188
-
189
- if len (packed_input_ids ) < data_args .cutoff_len :
190
- pad_length = data_args .cutoff_len - len (packed_input_ids )
191
- packed_input_ids += [tokenizer .pad_token_id ] * pad_length
192
- packed_labels += [IGNORE_INDEX ] * pad_length
193
- if data_args .neat_packing :
194
- packed_attention_masks += [0 ] * pad_length
195
- else :
196
- packed_attention_masks += [1 ] * pad_length # more efficient flash_attn
197
-
198
- if len (packed_input_ids ) != data_args .cutoff_len :
199
- raise ValueError ("The length of packed example should be identical to the cutoff length." )
178
+
179
+ if data_args .flatting_packing :
180
+ for i , length in enumerate (knapsack ):
181
+ index = length2indexes [length ].pop ()
182
+ packed_input_ids .append (batch_input_ids [index ])
183
+ packed_labels .append (batch_labels [index ])
184
+ packed_images .append (batch_images [index ])
185
+ packed_videos .append (batch_videos [index ])
186
+ else :
187
+ for i , length in enumerate (knapsack ):
188
+ index = length2indexes [length ].pop ()
189
+ packed_input_ids += batch_input_ids [index ]
190
+ packed_labels += batch_labels [index ]
191
+ packed_images += batch_images [index ]
192
+ packed_videos += batch_videos [index ]
193
+ if data_args .neat_packing :
194
+ packed_attention_masks += [i + 1 ] * len (batch_input_ids [index ]) # start from 1
195
+ else :
196
+ packed_attention_masks += [1 ] * len (batch_input_ids [index ])
197
+
198
+ # flatting_packing don't need attention masks
199
+ if len (packed_input_ids ) < data_args .cutoff_len :
200
+ pad_length = data_args .cutoff_len - len (packed_input_ids )
201
+ packed_input_ids += [tokenizer .pad_token_id ] * pad_length
202
+ packed_labels += [IGNORE_INDEX ] * pad_length
203
+ if data_args .neat_packing :
204
+ packed_attention_masks += [0 ] * pad_length
205
+ else :
206
+ packed_attention_masks += [1 ] * pad_length # more efficient flash_attn
207
+
208
+ # flatting packing don't need pad
209
+ if len (packed_input_ids ) != data_args .cutoff_len :
210
+ raise ValueError ("The length of packed example should be identical to the cutoff length." )
211
+ model_inputs ["attention_mask" ].append (packed_attention_masks )
200
212
201
213
model_inputs ["input_ids" ].append (packed_input_ids )
202
- model_inputs ["attention_mask" ].append (packed_attention_masks )
203
214
model_inputs ["labels" ].append (packed_labels )
204
215
model_inputs ["images" ].append (packed_images or None )
205
216
model_inputs ["videos" ].append (packed_videos or None )
@@ -213,3 +224,12 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
213
224
print ("inputs:\n {}" .format (tokenizer .decode (example ["input_ids" ], skip_special_tokens = False )))
214
225
print ("label_ids:\n {}" .format (example ["labels" ]))
215
226
print ("labels:\n {}" .format (tokenizer .decode (valid_labels , skip_special_tokens = False )))
227
+
228
+
229
def print_flatting_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    """Print one packed ("flatting") supervised example for human inspection.

    Unlike the plain variant, ``example["input_ids"]`` and ``example["labels"]``
    each hold a list of per-sequence token lists (one entry per packed sample),
    so they are flattened with ``itertools.chain`` before decoding/printing.
    """
    # Flatten once and reuse; the original chained `example["labels"]` twice.
    label_ids = list(itertools.chain(*example["labels"]))
    input_ids = list(itertools.chain(*example["input_ids"]))
    # Drop the ignored (masked) positions so only trainable labels are decoded.
    valid_labels = [x for x in label_ids if x != IGNORE_INDEX]
    print("input_ids:\n{}".format(input_ids))
    print("inputs:\n{}".format(tokenizer.decode(input_ids, skip_special_tokens=False)))
    print("label_ids:\n{}".format(label_ids))
    # Bug fix: `skip_special_tokens=False` was previously placed outside the
    # `tokenizer.decode(...)` call, making it a (silently ignored) argument to
    # `str.format`. It belongs to `decode`, matching the other print calls.
    print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
0 commit comments