
Commit 558b983

1. Support flatting_packing.
2. Update the Mistral function call format.
3. Fix greedy_knapsack, which may cause hiyouga#5443.
4. Avoid wrongly truncating supervised examples (hiyouga#5426).
1 parent 1a3e654 commit 558b983
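
For context, "flatting packing" flattens several short training examples into one sequence with no padding and no block-diagonal attention mask; instead, each example's position_ids restart at zero so the boundaries stay recoverable (see the new collator below). A minimal sketch of the idea, using toy token ids and a hypothetical flatten_pack helper that is not part of this commit:

# Sketch only: `flatten_pack` is a hypothetical name, not part of this commit.
from typing import Dict, List

IGNORE_INDEX = -100  # label id that the loss function ignores

def flatten_pack(examples: List[Dict[str, List[int]]]) -> Dict[str, List[int]]:
    """Concatenate examples into one sequence; restart position_ids per example."""
    input_ids, labels, position_ids = [], [], []
    for ex in examples:
        input_ids += ex["input_ids"]
        # mask the first token of each example so the loss never crosses a boundary
        labels += [IGNORE_INDEX] + ex["labels"][1:]
        position_ids += list(range(len(ex["input_ids"])))  # restart at 0
    return {"input_ids": input_ids, "labels": labels, "position_ids": position_ids}

packed = flatten_pack([
    {"input_ids": [1, 5, 6, 2], "labels": [1, 5, 6, 2]},
    {"input_ids": [1, 7, 2], "labels": [1, 7, 2]},
])
print(packed["position_ids"])  # [0, 1, 2, 3, 0, 1, 2]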

File tree

11 files changed: +224 −115 lines


src/llamafactory/cli.py

−1

@@ -28,7 +28,6 @@
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
 
-
 USAGE = (
     "-" * 70
     + "\n"

src/llamafactory/data/__init__.py

+2 −1

@@ -17,17 +17,18 @@
     MultiModalDataCollatorForSeq2Seq,
     PairwiseDataCollatorWithPadding,
     SFTDataCollatorWith4DAttentionMask,
+    SFTDataCollatorWithFlattingPacking,
 )
 from .data_utils import Role, split_dataset
 from .loader import get_dataset
 from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
 
-
 __all__ = [
     "KTODataCollatorWithPadding",
     "MultiModalDataCollatorForSeq2Seq",
     "PairwiseDataCollatorWithPadding",
     "SFTDataCollatorWith4DAttentionMask",
+    "SFTDataCollatorWithFlattingPacking",
     "Role",
     "split_dataset",
     "get_dataset",

src/llamafactory/data/collator.py

+37 −2

@@ -19,8 +19,7 @@
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
 
 import torch
-from transformers import DataCollatorForSeq2Seq
-
+from transformers import DataCollatorForSeq2Seq, DefaultDataCollator, default_data_collator, PreTrainedTokenizerBase
 
 if TYPE_CHECKING:
     from transformers import ProcessorMixin

@@ -120,6 +119,42 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
         return features
 
 
+@dataclass
+class SFTDataCollatorWithFlattingPacking(DefaultDataCollator):
+    r"""
+    Data collator for flatting packing.
+    """
+
+    tokenizer: PreTrainedTokenizerBase = None
+    label_pad_token_id: int = -100
+    template: Optional["Template"] = None
+    processor: Optional["ProcessorMixin"] = None
+    return_position_ids: bool = True
+
+    def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> Dict[str, "torch.Tensor"]:
+        # TODO: multimodal inputs are not supported yet
+        if return_tensors is None:
+            return_tensors = self.return_tensors
+
+        is_labels_provided = "labels" in features[0]
+        ret = {"input_ids": [], "labels": []}
+        if self.return_position_ids:
+            ret.update({"position_ids": []})
+
+        for instances in features:
+            for input_ids, labels in zip(instances["input_ids"], instances["labels"]):
+                ret["input_ids"] += input_ids
+                if is_labels_provided:
+                    ret["labels"] += [self.label_pad_token_id] + labels[1:]
+                else:
+                    ret["labels"] += [self.label_pad_token_id] + input_ids[1:]
+
+                if self.return_position_ids:
+                    ret["position_ids"] += list(range(len(input_ids)))
+
+        assert len(ret["input_ids"]) == len(ret["labels"])
+
+        features: Dict[str, "torch.Tensor"] = default_data_collator([ret], return_tensors)
+        return features
+
+
 @dataclass
 class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     r"""

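A usage sketch for the new collator: the feature layout (input_ids and labels as lists of per-example id lists) mirrors what the flatting_packing branch of preprocess_packed_supervised_dataset emits; the toy ids are invented, and the tokenizer field is left unset since __call__ does not use it in this commit.

from llamafactory.data.collator import SFTDataCollatorWithFlattingPacking

collator = SFTDataCollatorWithFlattingPacking()  # return_position_ids=True by default
features = [
    {
        "input_ids": [[1, 5, 6, 2], [1, 7, 2]],  # two packed sub-sequences
        "labels": [[1, 5, 6, 2], [1, 7, 2]],
    }
]
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([1, 7])
print(batch["position_ids"][0])  # tensor([0, 1, 2, 3, 0, 1, 2])
print(batch["labels"][0])        # first token of each sub-sequence masked to -100
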
src/llamafactory/data/formatter.py

+45 −1

@@ -23,7 +23,6 @@
 from .data_utils import SLOTS
 from .tool_utils import get_tool_utils
 
-
 if TYPE_CHECKING:
     from .tool_utils import FunctionCall

@@ -129,6 +128,51 @@ def apply(self, **kwargs) -> SLOTS:
         return elements
 
 
+@dataclass
+class MistralFunctionFormatter(Formatter):
+    @override
+    def apply(self, **kwargs) -> SLOTS:
+        content = kwargs.pop("content")
+        functions: List[Tuple[str, str]] = []
+        try:
+            tool_calls = json.loads(content)
+            if not isinstance(tool_calls, list):  # parallel function call
+                tool_calls = [tool_calls]
+
+            for tool_call in tool_calls:
+                functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+
+        except json.JSONDecodeError:
+            functions = []
+
+        elements = []
+        for name, arguments in functions:
+            elements.append(f"""{{"name":"{name}","arguments":{arguments}}}""")
+        elements = ["[TOOL_CALLS] [" + ", ".join(elements) + "]"]
+
+        return elements
+
+
+@dataclass
+class MistralObservationFormatter(Formatter):
+    def __post_init__(self):
+        self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
+
+    @override
+    def apply(self, **kwargs) -> SLOTS:
+        content = kwargs.pop("content")
+        tool_results: List[str]
+        try:
+            tool_results = [json.dumps(result) for result in json.loads(content)]
+        except json.JSONDecodeError:
+            tool_results = []
+
+        elements = []
+        for content in tool_results:
+            elements.append(f"[TOOL_RESULTS] {{\"content\":{content}}}[/TOOL_RESULTS]")
+        return ["".join(elements)]
+
+
 @dataclass
 class ToolFormatter(Formatter):
     def __post_init__(self):

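To make the target format concrete, here is a standalone sketch that mirrors MistralFunctionFormatter.apply on an invented tool call (the JSON payload and function name are illustrative, not from the commit):

import json

content = json.dumps({"name": "get_weather", "arguments": {"city": "Paris"}})

tool_calls = json.loads(content)
if not isinstance(tool_calls, list):  # single call: wrap it like the formatter does
    tool_calls = [tool_calls]

elements = [
    f"""{{"name":"{call['name']}","arguments":{json.dumps(call['arguments'], ensure_ascii=False)}}}"""
    for call in tool_calls
]
print("[TOOL_CALLS] [" + ", ".join(elements) + "]")
# [TOOL_CALLS] [{"name":"get_weather","arguments":{"city": "Paris"}}]
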
src/llamafactory/data/preprocess.py

+5 −3

@@ -22,10 +22,10 @@
     preprocess_packed_supervised_dataset,
     preprocess_supervised_dataset,
     print_supervised_dataset_example,
+    print_flatting_supervised_dataset_example,
 )
 from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
 
-
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin

@@ -78,8 +78,10 @@ def __init__(self, data, **kwargs):
             processor=processor,
             data_args=data_args,
         )
-
-        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
+        if data_args.packing and data_args.flatting_packing:
+            print_function = partial(print_flatting_supervised_dataset_example, tokenizer=tokenizer)
+        else:
+            print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
     elif stage == "rm":
         preprocess_func = partial(
             preprocess_pairwise_dataset,

src/llamafactory/data/processors/processor_utils.py

+6

@@ -28,6 +28,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
     r"""
     An efficient greedy algorithm with binary search for the knapsack problem.
     """
+    # filter out numbers that are larger than the capacity
+    numbers = [number for number in numbers if number <= capacity]
     numbers.sort()  # sort numbers in ascending order for binary search
     knapsacks = []

@@ -43,6 +45,10 @@
             remaining_capacity -= numbers[index]  # update the remaining capacity
             current_knapsack.append(numbers.pop(index))  # add the number to knapsack
 
+        # avoid endless loop
+        if remaining_capacity == capacity:
+            break
+
         knapsacks.append(current_knapsack)
 
     return knapsacks

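A self-contained sketch of why both additions matter, using a linear scan in place of the repo's binary search (search_for_fit); the numbers are invented:

from typing import List

def greedy_knapsack_sketch(numbers: List[int], capacity: int) -> List[List[int]]:
    numbers = [n for n in numbers if n <= capacity]  # new: drop oversized items
    numbers.sort()
    knapsacks = []
    while numbers:
        current, remaining = [], capacity
        while True:
            fits = [i for i, n in enumerate(numbers) if n <= remaining]
            if not fits:
                break  # nothing else fits in this knapsack
            remaining -= numbers[fits[-1]]
            current.append(numbers.pop(fits[-1]))
        if remaining == capacity:  # new: defensive guard against an endless loop
            break
        knapsacks.append(current)
    return knapsacks

# Before this commit, an item longer than the capacity could never be packed,
# so `while numbers` would spin forever (hiyouga#5443).
print(greedy_knapsack_sketch([3, 9, 4, 100], capacity=8))  # [[4, 3]]
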
src/llamafactory/data/processors/supervised.py

+59 −39

@@ -11,23 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import itertools
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
 from .processor_utils import greedy_knapsack, infer_seqlen
 
-
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
     from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template
 
-
 logger = get_logger(__name__)

@@ -48,18 +46,12 @@ def _encode_supervised_example(
     messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
     input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
     encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
-    total_length = len(input_ids) + (1 if template.efficient_eos else 0)
     if mask_history:
         encoded_pairs = encoded_pairs[::-1]  # high priority for last turns
 
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
-        if total_length >= cutoff_len:
-            break
-
-        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
-        source_ids = source_ids[:source_len]
-        target_ids = target_ids[:target_len]
-        total_length += source_len + target_len
+        source_len = len(source_ids)
+        target_len = len(target_ids)
 
         if train_on_prompt:
             source_label = source_ids

@@ -132,13 +124,16 @@ def preprocess_packed_supervised_dataset(
     processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
 ) -> Dict[str, List[Any]]:
-    # TODO: use `position_ids` to achieve packing
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
     valid_num = 0
+    invalid_num = 0
     batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
     lengths = []
     length2indexes = defaultdict(list)
+
+    # reserve one token for the padding token; flatting_packing does not need it
+    num_reserved = 0 if data_args.flatting_packing else 1
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
             logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))

@@ -154,13 +149,13 @@
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
+            cutoff_len=data_args.cutoff_len - num_reserved,
             train_on_prompt=data_args.train_on_prompt,
             mask_history=data_args.mask_history,
         )
         length = len(input_ids)
-        if length > data_args.cutoff_len:
-            logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
+        if length > data_args.cutoff_len - num_reserved:
+            invalid_num += 1
         else:
             lengths.append(length)
             length2indexes[length].append(valid_num)

@@ -170,36 +165,52 @@
             batch_videos.append(examples["_videos"][i] or [])
             valid_num += 1
 
+    if invalid_num > 0:
+        logger.warning(
+            "Dropped {} lengthy examples with length > {}.".format(invalid_num, data_args.cutoff_len - num_reserved)
+        )
+
     model_inputs = defaultdict(list)
-    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
+    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - num_reserved)  # reserved for the padding token
     for knapsack in knapsacks:
         packed_input_ids, packed_attention_masks, packed_labels = [], [], []
         packed_images, packed_videos = [], []
-        for i, length in enumerate(knapsack):
-            index = length2indexes[length].pop()
-            packed_input_ids += batch_input_ids[index]
-            packed_labels += batch_labels[index]
-            packed_images += batch_images[index]
-            packed_videos += batch_videos[index]
-            if data_args.neat_packing:
-                packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
-            else:
-                packed_attention_masks += [1] * len(batch_input_ids[index])
-
-        if len(packed_input_ids) < data_args.cutoff_len:
-            pad_length = data_args.cutoff_len - len(packed_input_ids)
-            packed_input_ids += [tokenizer.pad_token_id] * pad_length
-            packed_labels += [IGNORE_INDEX] * pad_length
-            if data_args.neat_packing:
-                packed_attention_masks += [0] * pad_length
-            else:
-                packed_attention_masks += [1] * pad_length  # more efficient flash_attn
-
-        if len(packed_input_ids) != data_args.cutoff_len:
-            raise ValueError("The length of packed example should be identical to the cutoff length.")
+
+        if data_args.flatting_packing:
+            # keep each example as a separate sub-sequence
+            for i, length in enumerate(knapsack):
+                index = length2indexes[length].pop()
+                packed_input_ids.append(batch_input_ids[index])
+                packed_labels.append(batch_labels[index])
+                packed_images.append(batch_images[index])
+                packed_videos.append(batch_videos[index])
+        else:
+            for i, length in enumerate(knapsack):
+                index = length2indexes[length].pop()
+                packed_input_ids += batch_input_ids[index]
+                packed_labels += batch_labels[index]
+                packed_images += batch_images[index]
+                packed_videos += batch_videos[index]
+                if data_args.neat_packing:
+                    packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
+                else:
+                    packed_attention_masks += [1] * len(batch_input_ids[index])
+
+            # flatting_packing does not need attention masks
+            if len(packed_input_ids) < data_args.cutoff_len:
+                pad_length = data_args.cutoff_len - len(packed_input_ids)
+                packed_input_ids += [tokenizer.pad_token_id] * pad_length
+                packed_labels += [IGNORE_INDEX] * pad_length
+                if data_args.neat_packing:
+                    packed_attention_masks += [0] * pad_length
+                else:
+                    packed_attention_masks += [1] * pad_length  # more efficient flash_attn
+
+            # flatting_packing does not need padding
+            if len(packed_input_ids) != data_args.cutoff_len:
+                raise ValueError("The length of packed example should be identical to the cutoff length.")
+            model_inputs["attention_mask"].append(packed_attention_masks)
 
         model_inputs["input_ids"].append(packed_input_ids)
-        model_inputs["attention_mask"].append(packed_attention_masks)
         model_inputs["labels"].append(packed_labels)
         model_inputs["images"].append(packed_images or None)
         model_inputs["videos"].append(packed_videos or None)

@@ -213,3 +224,12 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
     print("label_ids:\n{}".format(example["labels"]))
     print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
+
+
+def print_flatting_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+    valid_labels = list(filter(lambda x: x != IGNORE_INDEX, itertools.chain(*example["labels"])))
+    input_ids = list(itertools.chain(*example["input_ids"]))
+    print("input_ids:\n{}".format(input_ids))
+    print("inputs:\n{}".format(tokenizer.decode(input_ids, skip_special_tokens=False)))
+    print("label_ids:\n{}".format(list(itertools.chain(*example["labels"]))))
+    print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
