Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 171 additions & 68 deletions unsloth_zoo/mlx/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
iterate_vlm_training_batches,
normalize_mlx_chat_template,
normalize_vlm_processor_chat_template,
encode_mlx_text,
_get_vlm_ignore_token_ids,
collect_mlx_texts,
save_lora_adapters,
Expand All @@ -84,6 +85,48 @@
)


def _is_hf_tokenizer(tokenizer):
try:
from transformers import PreTrainedTokenizerBase
except Exception:
return False
return isinstance(tokenizer, PreTrainedTokenizerBase)


def _resolve_response_mask_tokenizer(tokenizer):
    """Return a callable HF tokenizer for the CUDA response-mask helper.

    Peels at most three wrapper layers off ``tokenizer``. At each layer the
    current object is returned as soon as ``_is_hf_tokenizer`` accepts it.
    Otherwise the next layer is taken from ``.tokenizer`` (processor-style
    wrappers) or, for objects that carry both ``_detokenizer_class`` and
    ``_tokenizer``, from ``._tokenizer``.

    Raises:
        TypeError: if the object left after unwrapping is not callable.
    """
    candidate = tokenizer
    for _hop in range(3):
        if _is_hf_tokenizer(candidate):
            return candidate

        # Processor-style wrappers expose the real tokenizer as `.tokenizer`.
        inner = getattr(candidate, "tokenizer", None)
        if inner is None or inner is candidate:
            # mlx-lm TokenizerWrapper stores the HF tokenizer under _tokenizer.
            # HF fast tokenizers also expose _tokenizer, but that is the
            # low-level Rust tokenizer and is not callable like
            # PreTrainedTokenizerBase, so `_detokenizer_class` is required
            # as well before following `_tokenizer`.
            looks_like_mlx_wrapper = (
                hasattr(candidate, "_detokenizer_class")
                and hasattr(candidate, "_tokenizer")
            )
            inner = (
                getattr(candidate, "_tokenizer", None)
                if looks_like_mlx_wrapper
                else None
            )

        if inner is None or inner is candidate:
            # Nothing further to unwrap.
            break
        candidate = inner

    if callable(candidate):
        return candidate
    raise TypeError(
        "Unsloth MLX: train_on_responses_only requires a callable "
        "Hugging Face tokenizer or a processor/tokenizer wrapper that "
        "contains one."
    )


def _normalize_mlx_optimizer_name(name):
opt_name = str(name or "adamw").strip().lower()
if opt_name not in SUPPORTED_MLX_OPTIMIZERS:
Expand Down Expand Up @@ -787,13 +830,8 @@ def _adafactor_unsupported_parameters(model):
unsupported.append((name, tuple(getattr(value, "shape", ()))))
return unsupported

def _evaluate(self, eval_batches, loss_fn, is_vlm=False):
"""Run evaluation loop.

Returns:
(avg_loss, perplexity) tuple.
"""
self.model.eval()
def _evaluate_batch_totals(self, eval_batches, loss_fn, is_vlm=False):
"""Accumulate weighted loss totals for one flat eval batch stream."""
all_losses = mx.array(0.0)
ntokens = mx.array(0)

Expand All @@ -809,6 +847,32 @@ def _evaluate(self, eval_batches, loss_fn, is_vlm=False):
ntokens += ntoks
mx.eval(all_losses, ntokens)

return all_losses, ntokens

def _evaluate(self, eval_batches, loss_fn, is_vlm=False):
"""Run evaluation loop.

Returns:
(avg_loss, perplexity) tuple.
"""
self.model.eval()
if isinstance(eval_batches, dict):
all_losses = mx.array(0.0)
ntokens = mx.array(0)
for split_batches in eval_batches.values():
split_losses, split_tokens = self._evaluate_batch_totals(
split_batches, loss_fn, is_vlm=is_vlm,
)
all_losses += split_losses
ntokens += split_tokens
mx.eval(all_losses, ntokens)
if self.stop_requested:
break
else:
all_losses, ntokens = self._evaluate_batch_totals(
eval_batches, loss_fn, is_vlm=is_vlm,
)

self.model.train()
avg_loss = (all_losses / ntokens).item() if ntokens.item() > 0 else 0.0
perplexity = math.exp(min(avg_loss, 100))
Expand Down Expand Up @@ -1437,41 +1501,55 @@ def step_fn(batch_data, prev_state, do_update):
_labeled_eval = getattr(self, '_eval_batches_labeled', None)
if _labeled_eval is not None:
eval_batches = _labeled_eval
elif is_vlm:
processor = self._resolve_vlm_processor()
config = getattr(self.model, "_config", {})
_vlm_mask_fn = getattr(self, '_vlm_response_mask_fn', None)
eval_batches = create_vlm_batches(
dataset=self.eval_dataset,
processor=processor,
config=config,
batch_size=args.per_device_train_batch_size,
max_seq_length=args.max_seq_length,
seed=args.seed,
response_mask_fn=_vlm_mask_fn,
formatting_func=self.formatting_func,
)
else:
eval_batches = create_batches(
dataset=self.eval_dataset,
tokenizer=self.tokenizer,
batch_size=args.per_device_train_batch_size,
max_seq_length=args.max_seq_length,
seed=args.seed,
dataset_text_field=args.dataset_text_field,
formatting_func=self.formatting_func,
chat_template=getattr(args, "chat_template", None),
model_name=getattr(self.model, "_hf_repo", None),
model_type=(
getattr(self.model, "_config", {}).get("model_type")
if isinstance(getattr(self.model, "_config", {}), dict)
else None
),
append_eos=bool(getattr(args, "append_eos", True)),
)
def _create_eval_batches(eval_dataset):
    """Build the evaluation batches for a single dataset split.

    Dispatches on ``is_vlm`` from the enclosing scope: text models go
    through ``create_batches`` and vision-language models through
    ``create_vlm_batches``. Both paths take the batch size, maximum
    sequence length and seed from ``args``.
    """
    if not is_vlm:
        model_config = getattr(self.model, "_config", {})
        # Only a dict-style config can supply a model type.
        if isinstance(model_config, dict):
            text_model_type = model_config.get("model_type")
        else:
            text_model_type = None
        return create_batches(
            dataset=eval_dataset,
            tokenizer=self.tokenizer,
            batch_size=args.per_device_train_batch_size,
            max_seq_length=args.max_seq_length,
            seed=args.seed,
            dataset_text_field=args.dataset_text_field,
            formatting_func=self.formatting_func,
            chat_template=getattr(args, "chat_template", None),
            model_name=getattr(self.model, "_hf_repo", None),
            model_type=text_model_type,
            append_eos=bool(getattr(args, "append_eos", True)),
        )

    return create_vlm_batches(
        dataset=eval_dataset,
        processor=self._resolve_vlm_processor(),
        config=getattr(self.model, "_config", {}),
        batch_size=args.per_device_train_batch_size,
        max_seq_length=args.max_seq_length,
        seed=args.seed,
        response_mask_fn=getattr(self, '_vlm_response_mask_fn', None),
        formatting_func=self.formatting_func,
    )

if isinstance(self.eval_dataset, dict):
eval_batches = {
key: _create_eval_batches(value)
for key, value in self.eval_dataset.items()
}
else:
eval_batches = _create_eval_batches(self.eval_dataset)
if eval_batches:
eval_batch_count = (
sum(len(value) for value in eval_batches.values())
if isinstance(eval_batches, dict) else len(eval_batches)
)
print(f"Unsloth: Eval enabled every {args.eval_steps} steps "
f"({len(eval_batches)} eval batches).")
f"({eval_batch_count} eval batches).")

features = []
if is_vlm:
Expand Down Expand Up @@ -2101,7 +2179,7 @@ def _create_labeled_batches(dataset, tokenizer, mask_fn, batch_size,
# 2. Tokenize + mask in parallel (HF fast tokenizers are thread-safe;
# slow tokenizers degrade gracefully via the GIL)
def _process_text(text):
encoded = tokenizer.encode(text)
encoded = encode_mlx_text(tokenizer, text)
# Mirror `_prepare_dataset`'s EOS contract; mismatch desyncs labeled vs unlabeled.
if append_eos and eos_id is not None and (not encoded or encoded[-1] != eos_id):
encoded.append(eos_id)
Expand All @@ -2115,12 +2193,33 @@ def _process_text(text):
labels = labels.tolist()
return (encoded, labels[0])

# Filter out samples where all labels are -100 (no valid training signal).
# This can happen when truncation cuts off the response_part entirely,
# e.g. long reasoning/analysis channels in GPT-OSS that exceed max_seq_length.
# Such samples cause NaN loss since cross_entropy(mean) computes 0/0.
def _has_valid_labels(labels):
return any(label != -100 for label in labels)

max_workers = min(4, os.cpu_count() or 1)
all_items = []
n_before_filter = 0
n_removed = 0
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
for result in executor.map(_process_text, all_texts):
if result is not None:
all_items.append(result)
n_before_filter += 1
if _has_valid_labels(result[1]):
all_items.append(result)
else:
n_removed += 1

if n_removed > 0:
print(
f"Unsloth: Removed {n_removed} out of {n_before_filter} samples "
f"from train_dataset where all labels were -100 "
f"(no response found after truncation). "
f"This prevents NaN loss during training."
)

if not all_items:
raise ValueError(
Expand Down Expand Up @@ -2338,13 +2437,7 @@ def train_on_responses_only(
"kwarg or via trainer.tokenizer."
)

# Unwrap to get a callable HF tokenizer.
# mlx-lm: TokenizerWrapper._tokenizer -> HF tokenizer
# VLM processors: processor.tokenizer -> HF tokenizer
if hasattr(_tokenizer, "_tokenizer"):
_tokenizer = _tokenizer._tokenizer
elif hasattr(_tokenizer, "tokenizer"):
_tokenizer = _tokenizer.tokenizer
_tokenizer = _resolve_response_mask_tokenizer(_tokenizer)

# Get masking closure from the HF/CUDA implementation
mask_fn = _hf_train_on_responses_only(
Expand Down Expand Up @@ -2413,26 +2506,36 @@ def train_on_responses_only(

# Process eval dataset too
if trainer.eval_dataset is not None:
eval_batches = _create_labeled_batches(
dataset=trainer.eval_dataset,
tokenizer=_tokenizer,
mask_fn=mask_fn,
batch_size=args.per_device_train_batch_size,
max_seq_length=args.max_seq_length,
formatting_func=trainer.formatting_func,
dataset_text_field=args.dataset_text_field,
seed=args.seed,
chat_template=getattr(args, "chat_template", None),
model_name=getattr(trainer.model, "_hf_repo", None),
model_type=(
getattr(trainer.model, "_config", {}).get("model_type")
if isinstance(getattr(trainer.model, "_config", {}), dict)
else None
),
append_eos=bool(getattr(args, "append_eos", True)),
dataset_order=getattr(args, "dataset_order", "default"),
preserve_dataset_order=bool(getattr(args, "preserve_dataset_order", False)),
)
def _create_labeled_eval_batches(eval_dataset):
    """Build response-masked eval batches for one dataset split.

    Forwards ``eval_dataset`` to ``_create_labeled_batches`` together with
    the tokenizer, mask function, trainer and ``args`` settings captured
    from the enclosing scope.
    """
    model = trainer.model
    model_config = getattr(model, "_config", {})
    # Only a dict-style config can supply a model type.
    if isinstance(model_config, dict):
        resolved_model_type = model_config.get("model_type")
    else:
        resolved_model_type = None

    return _create_labeled_batches(
        dataset=eval_dataset,
        tokenizer=_tokenizer,
        mask_fn=mask_fn,
        batch_size=args.per_device_train_batch_size,
        max_seq_length=args.max_seq_length,
        formatting_func=trainer.formatting_func,
        dataset_text_field=args.dataset_text_field,
        seed=args.seed,
        chat_template=getattr(args, "chat_template", None),
        model_name=getattr(model, "_hf_repo", None),
        model_type=resolved_model_type,
        append_eos=bool(getattr(args, "append_eos", True)),
        dataset_order=getattr(args, "dataset_order", "default"),
        preserve_dataset_order=bool(
            getattr(args, "preserve_dataset_order", False)
        ),
    )

if isinstance(trainer.eval_dataset, dict):
eval_batches = {
key: _create_labeled_eval_batches(value)
for key, value in trainer.eval_dataset.items()
}
else:
eval_batches = _create_labeled_eval_batches(trainer.eval_dataset)
trainer._eval_batches_labeled = eval_batches

print(f"Unsloth: train_on_responses_only enabled "
Expand Down
18 changes: 17 additions & 1 deletion unsloth_zoo/mlx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2085,6 +2085,22 @@ def normalize_vlm_processor_chat_template(
)


def encode_mlx_text(tokenizer, text):
    """Tokenize ``text`` while mirroring Unsloth's double-BOS guard.

    Special tokens are requested from ``tokenizer.encode`` unless a BOS
    token is already accounted for, i.e. the tokenizer has a non-empty
    ``bos_token`` and either ``text`` starts with it or it appears in the
    tokenizer's ``chat_template``.

    Args:
        tokenizer: Object exposing ``encode(text, ...)`` and, optionally,
            ``bos_token`` and ``chat_template`` attributes.
        text: The string to tokenize.

    Returns:
        Whatever ``tokenizer.encode`` returns for ``text``.
    """
    bos_token = getattr(tokenizer, "bos_token", None)
    chat_template = getattr(tokenizer, "chat_template", "") or ""

    # An empty BOS string must count as "no BOS": ``text.startswith("")`` and
    # ``"" in chat_template`` are always true, which would otherwise switch
    # special tokens off for every input.
    bos_already_present = bool(bos_token) and (
        text.startswith(bos_token) or bos_token in chat_template
    )
    add_special_tokens = not bos_already_present

    try:
        return tokenizer.encode(text, add_special_tokens=add_special_tokens)
    except TypeError:
        # Presumably for encoders whose ``encode`` does not accept the
        # ``add_special_tokens`` keyword; retry with the bare call.
        return tokenizer.encode(text)


def _raise_mlx_chat_template_error(target, *, is_vlm=False):
if is_vlm:
_raise_vlm_chat_template_error(target)
Expand Down Expand Up @@ -3113,7 +3129,7 @@ def __init__(self, data, tokenizer, text_key="text", eos_id=None):
self._eos_id = eos_id

def process(self, item):
encoded = self.tokenizer.encode(item[self.text_key])
encoded = encode_mlx_text(self.tokenizer, item[self.text_key])
if (
self._eos_id is not None
and (not encoded or encoded[-1] != self._eos_id)
Expand Down
Loading