
Commit 7cab73b

1. Support flat_packing.
2. Fix greedy_knapsack; its bug may cause #5443.
3. Avoid wrongly truncating supervised examples.

1 parent 1a3e654 commit 7cab73b

7 files changed: +155 -54 lines

src/llamafactory/data/__init__.py (+2 -1)

@@ -17,17 +17,18 @@
     MultiModalDataCollatorForSeq2Seq,
     PairwiseDataCollatorWithPadding,
     SFTDataCollatorWith4DAttentionMask,
+    SFTDataCollatorWithFlattingPacking,
 )
 from .data_utils import Role, split_dataset
 from .loader import get_dataset
 from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
 
-
 __all__ = [
     "KTODataCollatorWithPadding",
     "MultiModalDataCollatorForSeq2Seq",
     "PairwiseDataCollatorWithPadding",
     "SFTDataCollatorWith4DAttentionMask",
+    "SFTDataCollatorWithFlattingPacking",
     "Role",
     "split_dataset",
     "get_dataset",

src/llamafactory/data/collator.py (+37 -2)

@@ -19,8 +19,7 @@
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
 
 import torch
-from transformers import DataCollatorForSeq2Seq
-
+from transformers import DataCollatorForSeq2Seq, DefaultDataCollator, default_data_collator, PreTrainedTokenizerBase
 
 if TYPE_CHECKING:
     from transformers import ProcessorMixin
@@ -120,6 +119,42 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
         return features
 
 
+@dataclass
+class SFTDataCollatorWithFlattingPacking(DefaultDataCollator):
+    r"""
+    Data collator for flatting packing.
+    """
+
+    tokenizer: PreTrainedTokenizerBase = None
+    label_pad_token_id: int = -100
+    template: Optional["Template"] = None
+    processor: Optional["ProcessorMixin"] = None
+    return_position_ids: bool = True
+
+    def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> Dict[str, "torch.Tensor"]:
+        # TODO: multimodal inputs are not supported yet
+        if return_tensors is None:
+            return_tensors = self.return_tensors
+        is_labels_provided = "labels" in features[0]
+        ret = {"input_ids": [], "labels": []}
+        if self.return_position_ids:
+            ret.update({"position_ids": []})
+        for instances in features:
+            for input_ids, labels in zip(instances["input_ids"], instances["labels"]):
+                ret["input_ids"] += input_ids
+                if is_labels_provided:
+                    ret["labels"] += [self.label_pad_token_id] + labels[1:]
+                else:
+                    ret["labels"] += [self.label_pad_token_id] + input_ids[1:]
+                if self.return_position_ids:
+                    ret["position_ids"] += list(range(len(input_ids)))
+
+        assert len(ret["input_ids"]) == len(ret["labels"])
+
+        features: Dict[str, "torch.Tensor"] = default_data_collator([ret], return_tensors)
+        return features
+
+
 @dataclass
 class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     r"""
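
For intuition, a minimal standalone sketch (toy token ids, not the repository code) of what the new collator does with one packed feature: the per-example sequences are concatenated into a single row, the first label of every sequence is masked with the pad label id, and position_ids restart at zero at each sequence boundary.

IGNORE = -100

def flatten_batch(features):
    # mirrors the flattening loop above on plain Python lists
    ret = {"input_ids": [], "labels": [], "position_ids": []}
    for feature in features:
        for input_ids, labels in zip(feature["input_ids"], feature["labels"]):
            ret["input_ids"] += input_ids
            ret["labels"] += [IGNORE] + labels[1:]  # mask the first token of each sequence
            ret["position_ids"] += list(range(len(input_ids)))  # positions restart per sequence
    return ret

batch = [{"input_ids": [[1, 2, 3], [4, 5]], "labels": [[1, 2, 3], [4, 5]]}]
print(flatten_batch(batch))
# {'input_ids': [1, 2, 3, 4, 5], 'labels': [-100, 2, 3, -100, 5], 'position_ids': [0, 1, 2, 0, 1]}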

src/llamafactory/data/preprocess.py (+5 -3)

@@ -22,10 +22,10 @@
     preprocess_packed_supervised_dataset,
     preprocess_supervised_dataset,
     print_supervised_dataset_example,
+    print_flatting_supervised_dataset_example,
 )
 from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
 
-
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin
 
@@ -78,8 +78,10 @@ def __init__(self, data, **kwargs):
                 processor=processor,
                 data_args=data_args,
             )
-
-        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
+        if data_args.packing and data_args.flat_packing:
+            print_function = partial(print_flatting_supervised_dataset_example, tokenizer=tokenizer)
+        else:
+            print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
     elif stage == "rm":
         preprocess_func = partial(
             preprocess_pairwise_dataset,

src/llamafactory/data/processors/processor_utils.py (+6)

@@ -28,6 +28,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
     r"""
     An efficient greedy algorithm with binary search for the knapsack problem.
     """
+    # filter out numbers that are larger than the capacity
+    numbers = [number for number in numbers if number <= capacity]
     numbers.sort()  # sort numbers in ascending order for binary search
     knapsacks = []
 
@@ -43,6 +45,10 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
             remaining_capacity -= numbers[index]  # update the remaining capacity
             current_knapsack.append(numbers.pop(index))  # add the number to knapsack
 
+        # avoid endless loop
+        if remaining_capacity == capacity:
+            break
+
         knapsacks.append(current_knapsack)
 
     return knapsacks
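
A small usage sketch of the fix (assuming LLaMA-Factory is installed so the helper can be imported); the exact grouping depends on the greedy heuristic, but the key point is that over-long lengths are now filtered out instead of looping forever.

from llamafactory.data.processors.processor_utils import greedy_knapsack

# 12 exceeds the capacity: previously it could never be placed, so the outer loop
# never emptied the list and could spin forever (the failure mode behind #5443).
# With the filter it is simply dropped and the remaining lengths are packed.
print(greedy_knapsack([3, 5, 7, 12], capacity=8))  # e.g. [[7], [5, 3]]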

src/llamafactory/data/processors/supervised.py (+66 -37)

@@ -11,23 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import itertools
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
 from .processor_utils import greedy_knapsack, infer_seqlen
 
-
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
     from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template
 
-
 logger = get_logger(__name__)
 
 
@@ -53,13 +51,16 @@ def _encode_supervised_example(
         encoded_pairs = encoded_pairs[::-1]  # high priority for last turns
 
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
-        if total_length >= cutoff_len:
+        if total_length >= cutoff_len and cutoff_len > 0:
            break
 
-        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
-        source_ids = source_ids[:source_len]
-        target_ids = target_ids[:target_len]
-        total_length += source_len + target_len
+        if cutoff_len > 0:
+            source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
+            source_ids = source_ids[:source_len]
+            target_ids = target_ids[:target_len]
+            total_length += source_len + target_len
+        else:
+            source_len, target_len = len(source_ids), len(target_ids)
 
         if train_on_prompt:
             source_label = source_ids
@@ -112,7 +113,7 @@ def preprocess_supervised_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            cutoff_len=data_args.cutoff_len,
+            cutoff_len=data_args.cutoff_len if data_args.allow_truncation else 0,
             train_on_prompt=data_args.train_on_prompt,
             mask_history=data_args.mask_history,
         )
@@ -132,13 +133,16 @@ def preprocess_packed_supervised_dataset(
     processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
 ) -> Dict[str, List[Any]]:
-    # TODO: use `position_ids` to achieve packing
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
     valid_num = 0
+    invalid_num = 0
     batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
     lengths = []
     length2indexes = defaultdict(list)
+
+    # reserve one token for padding; flat packing does not need it
+    num_reserved = 0 if data_args.flat_packing else 1
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
             logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
@@ -154,13 +158,13 @@ def preprocess_packed_supervised_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
+            cutoff_len=data_args.cutoff_len - num_reserved if data_args.allow_truncation else 0,
            train_on_prompt=data_args.train_on_prompt,
             mask_history=data_args.mask_history,
         )
         length = len(input_ids)
-        if length > data_args.cutoff_len:
-            logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
+        if length > data_args.cutoff_len - num_reserved:
+            invalid_num += 1
         else:
             lengths.append(length)
             length2indexes[length].append(valid_num)
@@ -170,36 +174,52 @@ def preprocess_packed_supervised_dataset(
             batch_videos.append(examples["_videos"][i] or [])
             valid_num += 1
 
+    if invalid_num > 0:
+        logger.warning(
+            "Dropped {} lengthy examples with length > {}.".format(invalid_num, data_args.cutoff_len - num_reserved)
+        )
+
     model_inputs = defaultdict(list)
-    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
+    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - num_reserved)  # reserved for the padding token
     for knapsack in knapsacks:
         packed_input_ids, packed_attention_masks, packed_labels = [], [], []
         packed_images, packed_videos = [], []
-        for i, length in enumerate(knapsack):
-            index = length2indexes[length].pop()
-            packed_input_ids += batch_input_ids[index]
-            packed_labels += batch_labels[index]
-            packed_images += batch_images[index]
-            packed_videos += batch_videos[index]
-            if data_args.neat_packing:
-                packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
-            else:
-                packed_attention_masks += [1] * len(batch_input_ids[index])
-
-        if len(packed_input_ids) < data_args.cutoff_len:
-            pad_length = data_args.cutoff_len - len(packed_input_ids)
-            packed_input_ids += [tokenizer.pad_token_id] * pad_length
-            packed_labels += [IGNORE_INDEX] * pad_length
-            if data_args.neat_packing:
-                packed_attention_masks += [0] * pad_length
-            else:
-                packed_attention_masks += [1] * pad_length  # more efficient flash_attn
-
-        if len(packed_input_ids) != data_args.cutoff_len:
-            raise ValueError("The length of packed example should be identical to the cutoff length.")
+
+        if data_args.flat_packing:
+            for i, length in enumerate(knapsack):
+                index = length2indexes[length].pop()
+                packed_input_ids.append(batch_input_ids[index])
+                packed_labels.append(batch_labels[index])
+                packed_images.append(batch_images[index])
+                packed_videos.append(batch_videos[index])
+        else:
+            for i, length in enumerate(knapsack):
+                index = length2indexes[length].pop()
+                packed_input_ids += batch_input_ids[index]
+                packed_labels += batch_labels[index]
+                packed_images += batch_images[index]
+                packed_videos += batch_videos[index]
+                if data_args.neat_packing:
+                    packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
+                else:
+                    packed_attention_masks += [1] * len(batch_input_ids[index])
+
+            # flat packing does not need attention masks
+            if len(packed_input_ids) < data_args.cutoff_len:
+                pad_length = data_args.cutoff_len - len(packed_input_ids)
+                packed_input_ids += [tokenizer.pad_token_id] * pad_length
+                packed_labels += [IGNORE_INDEX] * pad_length
+                if data_args.neat_packing:
+                    packed_attention_masks += [0] * pad_length
+                else:
+                    packed_attention_masks += [1] * pad_length  # more efficient flash_attn
+
+            # flat packing does not need padding
+            if len(packed_input_ids) != data_args.cutoff_len:
+                raise ValueError("The length of packed example should be identical to the cutoff length.")
+            model_inputs["attention_mask"].append(packed_attention_masks)
 
         model_inputs["input_ids"].append(packed_input_ids)
-        model_inputs["attention_mask"].append(packed_attention_masks)
         model_inputs["labels"].append(packed_labels)
         model_inputs["images"].append(packed_images or None)
         model_inputs["videos"].append(packed_videos or None)
@@ -213,3 +233,12 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
     print("label_ids:\n{}".format(example["labels"]))
     print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
+
+
+def print_flatting_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+    valid_labels = list(filter(lambda x: x != IGNORE_INDEX, itertools.chain(*example["labels"])))
+    input_ids = list(itertools.chain(*example["input_ids"]))
+    print("input_ids:\n{}".format(input_ids))
+    print("inputs:\n{}".format(tokenizer.decode(input_ids, skip_special_tokens=False)))
+    print("label_ids:\n{}".format(list(itertools.chain(*example["labels"]))))
+    print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
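
To make the two branches above concrete, a schematic example (hypothetical token ids, cutoff_len = 6, one knapsack holding sequences of length 3 and 2): the default path emits one padded row per pack with an attention mask, while flat_packing keeps the pack as a list of per-example sequences and leaves flattening, masking, and position ids to the new collator.

PAD, IGNORE_INDEX = 0, -100

# default packing: concatenate, pad to cutoff_len, build an attention mask
padded_pack = {
    "input_ids": [11, 12, 13, 21, 22, PAD],
    "labels": [IGNORE_INDEX, 12, 13, IGNORE_INDEX, 22, IGNORE_INDEX],  # schematic labels
    "attention_mask": [1, 1, 1, 2, 2, 0],  # neat_packing style: one id per example, 0 for padding
}

# flat_packing: no padding and no attention mask; the per-example structure is preserved
flat_pack = {
    "input_ids": [[11, 12, 13], [21, 22]],
    "labels": [[IGNORE_INDEX, 12, 13], [IGNORE_INDEX, 22]],  # schematic labels
}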

src/llamafactory/hparams/data_args.py (+11)

@@ -105,6 +105,14 @@ class DataArguments:
         default=False,
         metadata={"help": "Enable sequence packing without cross-attention."},
     )
+    flat_packing: bool = field(
+        default=False,
+        metadata={"help": "Enable sequence packing with flattening (requires FlashAttention-2)."},
+    )
+    allow_truncation: bool = field(
+        default=False,
+        metadata={"help": "Allow truncation when processing supervised examples."},
+    )
     tool_format: Optional[str] = field(
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
@@ -148,3 +156,6 @@ def split_arg(arg):
 
         if self.mask_history and self.train_on_prompt:
             raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
+
+        if self.neat_packing and self.flat_packing:
+            raise ValueError("`neat_packing` is incompatible with `flat_packing`.")
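
A quick sketch of how the new fields combine (assumes the package is importable; the same flags can be set in a YAML config or on the CLI as usual):

from llamafactory.hparams import DataArguments

data_args = DataArguments(packing=True, flat_packing=True, allow_truncation=True)
print(data_args.flat_packing, data_args.allow_truncation)  # True True

# neat_packing and flat_packing are mutually exclusive; this would raise the new ValueError:
# DataArguments(packing=True, neat_packing=True, flat_packing=True)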

src/llamafactory/train/sft/workflow.py (+28 -11)

@@ -17,21 +17,24 @@
 
 from typing import TYPE_CHECKING, List, Optional
 
-from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
+from ...data import SFTDataCollatorWith4DAttentionMask, SFTDataCollatorWithFlattingPacking, get_dataset, \
+    get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
+from ...extras.logging import get_logger
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
 from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
 from .trainer import CustomSeq2SeqTrainer
 
-
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
 
     from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 
+logger = get_logger(__name__)
+
 
 def run_sft(
     model_args: "ModelArguments",
@@ -50,15 +53,29 @@ def run_sft(
     if getattr(model, "is_quantized", False) and not training_args.do_train:
         setattr(model, "_hf_peft_config_loaded", True)  # hack here: make model compatible with prediction
 
-    data_collator = SFTDataCollatorWith4DAttentionMask(
-        template=template,
-        pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
-        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
-        block_diag_attn=model_args.block_diag_attn,
-        attn_implementation=getattr(model.config, "_attn_implementation", None),
-        compute_dtype=model_args.compute_dtype,
-        **tokenizer_module,
-    )
+    if (
+        data_args.packing
+        and data_args.flat_packing
+        and getattr(model.config, "_attn_implementation", None) != "flash_attention_2"
+    ):
+        logger.warning("`flat_packing` only supports `flash_attention_2` and may cause OOM with other backends!")
+
+    if data_args.packing and data_args.flat_packing:
+        data_collator = SFTDataCollatorWithFlattingPacking(
+            template=template,
+            label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+            **tokenizer_module,
+        )
+    else:
+        data_collator = SFTDataCollatorWith4DAttentionMask(
+            template=template,
+            pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
+            label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+            block_diag_attn=model_args.block_diag_attn,
+            attn_implementation=getattr(model.config, "_attn_implementation", None),
+            compute_dtype=model_args.compute_dtype,
+            **tokenizer_module,
+        )
 
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
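
The warning above exists because flattened batches carry sequence boundaries only in position_ids, and (in current transformers) only the FlashAttention-2 code path uses those boundaries to keep packed sequences from attending to each other. A hedged sketch of the same check written as a hypothetical helper:

def supports_flat_packing(model) -> bool:
    # hypothetical helper mirroring the check added in run_sft:
    # transformers records the selected attention backend on the model config
    return getattr(model.config, "_attn_implementation", None) == "flash_attention_2"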
