Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 50 additions & 50 deletions data/reason_tool_use_demo_50.jsonl

Large diffs are not rendered by default.

13 changes: 5 additions & 8 deletions src/llamafactory/model/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict

import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
Expand All @@ -25,14 +26,11 @@
AutoProcessor,
AutoTokenizer,
)
from packaging import version
from torch import nn
from trl import AutoModelForCausalLMWithValueHead
import warnings

from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from ..extras.packages import _get_package_version
from ..extras.packages import is_torch_version_greater_than
from .adapter import init_adapter
from .model_utils.ktransformers import load_kt_pretrained_model
from .model_utils.liger_kernel import apply_liger_kernel
Expand Down Expand Up @@ -206,11 +204,10 @@ def load_model(
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")

# Conv3D is not recommended when using torch 2.9.x
torch_version = _get_package_version("torch")
if version.parse("2.9.0") <= torch_version < version.parse("2.10.0"):
if any(isinstance(m, nn.Conv3d) for m in model.modules()):
if is_torch_version_greater_than("2.9.0") and not is_torch_version_greater_than("2.10.0"):
if any(isinstance(m, torch.nn.Conv3d) for m in model.modules()):
raise ValueError(
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
"This combination is known to cause severe performance regression. "
Expand Down
2 changes: 0 additions & 2 deletions src/llamafactory/train/sft/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,13 @@ def __init__(

self.compute_loss_func = dft_loss_func


elif finetuning_args.use_eaft_loss:
from ..trainer_utils import eaft_loss_func

self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
)


if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args)

Expand Down
32 changes: 20 additions & 12 deletions src/llamafactory/train/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,9 @@ def get_batch_logps(
return logps, valid_length


def dft_loss_func(outputs, labels, num_items_in_batch=None):
def dft_loss_func(
outputs: "torch.Tensor", labels: "torch.Tensor", num_items_in_batch: Optional["torch.Tensor"] = None
):
logits = outputs.get("logits")
if logits is None:
return outputs.get("loss", torch.tensor(0.0))
Expand All @@ -652,11 +654,11 @@ def dft_loss_func(outputs, labels, num_items_in_batch=None):


def _dft_cross_entropy(
source: torch.Tensor,
target: torch.Tensor,
num_items_in_batch: Optional[torch.Tensor] = None,
source: "torch.Tensor",
target: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
ignore_index: int = -100,
) -> torch.Tensor:
) -> "torch.Tensor":
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
valid_mask = target != ignore_index
if not valid_mask.any():
Expand All @@ -679,7 +681,12 @@ def _dft_cross_entropy(
return loss


def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
def eaft_loss_func(
outputs: "torch.Tensor",
labels: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
alpha: float = 1.0,
) -> "torch.Tensor":
logits = outputs.get("logits")
if logits is None:
return outputs.get("loss", torch.tensor(0.0))
Expand All @@ -697,12 +704,12 @@ def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):


def _eaft_cross_entropy(
source: torch.Tensor,
target: torch.Tensor,
num_items_in_batch: Optional[torch.Tensor] = None,
source: "torch.Tensor",
target: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
alpha: float = 1.0,
ignore_index: int = -100,
) -> torch.Tensor:
) -> "torch.Tensor":
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
valid_mask = target != ignore_index
if not valid_mask.any():
Expand All @@ -712,13 +719,13 @@ def _eaft_cross_entropy(

with torch.no_grad():
source_detached = source[valid_mask].detach()

topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
log_probs_topk = topk_val - logsumexp_topk
probs_topk = torch.exp(log_probs_topk)
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)

entropy_term = entropy_approx / 3.0
adaptive_weight = torch.pow(entropy_term, alpha)

Expand All @@ -731,6 +738,7 @@ def _eaft_cross_entropy(
loss = total_loss / num_items_in_batch
else:
loss = weighted_losses.mean()

return loss


Expand Down
2 changes: 1 addition & 1 deletion src/llamafactory/v1/config/model_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class ModelArguments:
metadata={"help": "Path to the model or model identifier from Hugging Face."},
)
template: str = field(
default="chatml",
default="qwen3_nothink",
metadata={"help": "Template for the model."},
)
trust_remote_code: bool = field(
Expand Down
180 changes: 20 additions & 160 deletions src/llamafactory/v1/core/utils/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re

from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.types import Message, ModelInput, Processor


def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""

tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))

return ""


def render_chatml_messages(
processor: Processor,
messages: list[Message],
Expand All @@ -52,123 +26,38 @@ def render_chatml_messages(
) -> ModelInput:
"""Apply chatml template to messages and convert them to model input.

See https://huggingface.co/spaces/huggingfacejs/chat-template-playground
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
"""
tokenizer = get_tokenizer(processor)
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")

temp_str += "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)

temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")

if not isinstance(tools, list):
tools = [tools]

for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)

temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n"
for content in messages[0]["content"]:

for message in messages:
temp_str = "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")

temp_str += "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)

temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)

for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")

temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"

try:
tool_call = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)

else:
raise ValueError(f"Unsupported content type: {content['type']}")

temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"

temp_str += "\n<tool_response>\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")

temp_str += "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"

temp_weight = message.get("loss_weight", 0.0)

temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
temp_weight = message.get("loss_weight", 1.0 if message["role"] == "assistant" else 0.0)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))

if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0

temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
temp_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([0.0] * len(temp_ids))
labels.extend([IGNORE_INDEX] * len(temp_ids))

attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
attention_mask=[1] * len(input_ids),
labels=labels,
loss_weights=loss_weights,
)
Expand All @@ -183,36 +72,7 @@ def parse_chatml_message(generated_text: str) -> Message:
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})

tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")

content.append({"type": "tool_call", "value": tag_value.strip()})

last_end = end

if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})

return Message(role="assistant", content=content)
return Message(role="assistant", content=[{"type": "text", "value": generated_text}])


class Renderer:
Expand Down
2 changes: 1 addition & 1 deletion src/llamafactory/v1/plugins/data_plugins/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
)

if tools:
return {"messages": messages, "extra_info": json.dumps({"tools": tools})}
return {"messages": messages, "tools": json.dumps(tools)}
else:
return {"messages": messages}

Expand Down
Loading