From 0fc7f14b7d6141ee8ef32419779012ddbda56b20 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Thu, 28 May 2026 22:18:36 +0800 Subject: [PATCH 1/4] fix(mlx): filter fully masked response samples --- unsloth_zoo/mlx/trainer.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index b202e15c9..f7dfdac03 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -2115,12 +2115,33 @@ def _process_text(text): labels = labels.tolist() return (encoded, labels[0]) + # Filter out samples where all labels are -100 (no valid training signal). + # This can happen when truncation cuts off the response_part entirely, + # e.g. long reasoning/analysis channels in GPT-OSS that exceed max_seq_length. + # Such samples cause NaN loss since cross_entropy(mean) computes 0/0. + def _has_valid_labels(labels): + return any(label != -100 for label in labels) + max_workers = min(4, os.cpu_count() or 1) all_items = [] + n_before_filter = 0 + n_removed = 0 with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: for result in executor.map(_process_text, all_texts): if result is not None: - all_items.append(result) + n_before_filter += 1 + if _has_valid_labels(result[1]): + all_items.append(result) + else: + n_removed += 1 + + if n_removed > 0: + print( + f"Unsloth: Removed {n_removed} out of {n_before_filter} samples " + f"from train_dataset where all labels were -100 " + f"(no response found after truncation). " + f"This prevents NaN loss during training." + ) if not all_items: raise ValueError( From ae977baba8712a65751e593522b86f3a2404eff3 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Fri, 29 May 2026 18:07:03 +0800 Subject: [PATCH 2/4] fix(mlx): preserve hf tokenizers for response masking --- unsloth_zoo/mlx/trainer.py | 50 ++++++++++++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index f7dfdac03..3a3af76d6 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -84,6 +84,48 @@ ) +def _is_hf_tokenizer(tokenizer): + try: + from transformers import PreTrainedTokenizerBase + except Exception: + return False + return isinstance(tokenizer, PreTrainedTokenizerBase) + + +def _resolve_response_mask_tokenizer(tokenizer): + """Return a callable HF tokenizer for the CUDA response-mask helper.""" + for _ in range(3): + if _is_hf_tokenizer(tokenizer): + return tokenizer + + processor_tokenizer = getattr(tokenizer, "tokenizer", None) + if processor_tokenizer is not None and processor_tokenizer is not tokenizer: + tokenizer = processor_tokenizer + continue + + # mlx-lm TokenizerWrapper stores the HF tokenizer under _tokenizer. + # HF fast tokenizers also expose _tokenizer, but that is the low-level + # Rust tokenizer and is not callable like PreTrainedTokenizerBase. + if ( + hasattr(tokenizer, "_detokenizer_class") + and hasattr(tokenizer, "_tokenizer") + ): + wrapped = getattr(tokenizer, "_tokenizer", None) + if wrapped is not None and wrapped is not tokenizer: + tokenizer = wrapped + continue + + break + + if not callable(tokenizer): + raise TypeError( + "Unsloth MLX: train_on_responses_only requires a callable " + "Hugging Face tokenizer or a processor/tokenizer wrapper that " + "contains one." + ) + return tokenizer + + def _normalize_mlx_optimizer_name(name): opt_name = str(name or "adamw").strip().lower() if opt_name not in SUPPORTED_MLX_OPTIMIZERS: @@ -2359,13 +2401,7 @@ def train_on_responses_only( "kwarg or via trainer.tokenizer." ) - # Unwrap to get a callable HF tokenizer. - # mlx-lm: TokenizerWrapper._tokenizer -> HF tokenizer - # VLM processors: processor.tokenizer -> HF tokenizer - if hasattr(_tokenizer, "_tokenizer"): - _tokenizer = _tokenizer._tokenizer - elif hasattr(_tokenizer, "tokenizer"): - _tokenizer = _tokenizer.tokenizer + _tokenizer = _resolve_response_mask_tokenizer(_tokenizer) # Get masking closure from the HF/CUDA implementation mask_fn = _hf_train_on_responses_only( From 5dd5c0017a8549f71a44829effa5fafc0a08d1f0 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Fri, 29 May 2026 18:08:42 +0800 Subject: [PATCH 3/4] fix(mlx): avoid double bos in text batching --- unsloth_zoo/mlx/trainer.py | 3 ++- unsloth_zoo/mlx/utils.py | 18 +++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 3a3af76d6..aeb026ede 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -64,6 +64,7 @@ iterate_vlm_training_batches, normalize_mlx_chat_template, normalize_vlm_processor_chat_template, + encode_mlx_text, _get_vlm_ignore_token_ids, collect_mlx_texts, save_lora_adapters, @@ -2143,7 +2144,7 @@ def _create_labeled_batches(dataset, tokenizer, mask_fn, batch_size, # 2. Tokenize + mask in parallel (HF fast tokenizers are thread-safe; # slow tokenizers degrade gracefully via the GIL) def _process_text(text): - encoded = tokenizer.encode(text) + encoded = encode_mlx_text(tokenizer, text) # Mirror `_prepare_dataset`'s EOS contract; mismatch desyncs labeled vs unlabeled. if append_eos and eos_id is not None and (not encoded or encoded[-1] != eos_id): encoded.append(eos_id) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 3b38ac330..3cf8d173a 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -2085,6 +2085,22 @@ def normalize_vlm_processor_chat_template( ) +def encode_mlx_text(tokenizer, text): + """Tokenize text while mirroring Unsloth's double-BOS guard.""" + add_special_tokens = True + bos_token = getattr(tokenizer, "bos_token", None) + chat_template = getattr(tokenizer, "chat_template", "") or "" + if bos_token is not None and ( + text.startswith(bos_token) or bos_token in chat_template + ): + add_special_tokens = False + + try: + return tokenizer.encode(text, add_special_tokens=add_special_tokens) + except TypeError: + return tokenizer.encode(text) + + def _raise_mlx_chat_template_error(target, *, is_vlm=False): if is_vlm: _raise_vlm_chat_template_error(target) @@ -3113,7 +3129,7 @@ def __init__(self, data, tokenizer, text_key="text", eos_id=None): self._eos_id = eos_id def process(self, item): - encoded = self.tokenizer.encode(item[self.text_key]) + encoded = encode_mlx_text(self.tokenizer, item[self.text_key]) if ( self._eos_id is not None and (not encoded or encoded[-1] != self._eos_id) From 4f70b5f1d7e0a0e323cdcc889439e8addf303fd2 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Fri, 29 May 2026 18:11:16 +0800 Subject: [PATCH 4/4] fix(mlx): handle dict eval datasets --- unsloth_zoo/mlx/trainer.py | 163 +++++++++++++++++++++++-------------- 1 file changed, 104 insertions(+), 59 deletions(-) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index aeb026ede..a9d4a721b 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -830,13 +830,8 @@ def _adafactor_unsupported_parameters(model): unsupported.append((name, tuple(getattr(value, "shape", ())))) return unsupported - def _evaluate(self, eval_batches, loss_fn, is_vlm=False): - """Run evaluation loop. - - Returns: - (avg_loss, perplexity) tuple. - """ - self.model.eval() + def _evaluate_batch_totals(self, eval_batches, loss_fn, is_vlm=False): + """Accumulate weighted loss totals for one flat eval batch stream.""" all_losses = mx.array(0.0) ntokens = mx.array(0) @@ -852,6 +847,32 @@ def _evaluate(self, eval_batches, loss_fn, is_vlm=False): ntokens += ntoks mx.eval(all_losses, ntokens) + return all_losses, ntokens + + def _evaluate(self, eval_batches, loss_fn, is_vlm=False): + """Run evaluation loop. + + Returns: + (avg_loss, perplexity) tuple. + """ + self.model.eval() + if isinstance(eval_batches, dict): + all_losses = mx.array(0.0) + ntokens = mx.array(0) + for split_batches in eval_batches.values(): + split_losses, split_tokens = self._evaluate_batch_totals( + split_batches, loss_fn, is_vlm=is_vlm, + ) + all_losses += split_losses + ntokens += split_tokens + mx.eval(all_losses, ntokens) + if self.stop_requested: + break + else: + all_losses, ntokens = self._evaluate_batch_totals( + eval_batches, loss_fn, is_vlm=is_vlm, + ) + self.model.train() avg_loss = (all_losses / ntokens).item() if ntokens.item() > 0 else 0.0 perplexity = math.exp(min(avg_loss, 100)) @@ -1480,41 +1501,55 @@ def step_fn(batch_data, prev_state, do_update): _labeled_eval = getattr(self, '_eval_batches_labeled', None) if _labeled_eval is not None: eval_batches = _labeled_eval - elif is_vlm: - processor = self._resolve_vlm_processor() - config = getattr(self.model, "_config", {}) - _vlm_mask_fn = getattr(self, '_vlm_response_mask_fn', None) - eval_batches = create_vlm_batches( - dataset=self.eval_dataset, - processor=processor, - config=config, - batch_size=args.per_device_train_batch_size, - max_seq_length=args.max_seq_length, - seed=args.seed, - response_mask_fn=_vlm_mask_fn, - formatting_func=self.formatting_func, - ) else: - eval_batches = create_batches( - dataset=self.eval_dataset, - tokenizer=self.tokenizer, - batch_size=args.per_device_train_batch_size, - max_seq_length=args.max_seq_length, - seed=args.seed, - dataset_text_field=args.dataset_text_field, - formatting_func=self.formatting_func, - chat_template=getattr(args, "chat_template", None), - model_name=getattr(self.model, "_hf_repo", None), - model_type=( - getattr(self.model, "_config", {}).get("model_type") - if isinstance(getattr(self.model, "_config", {}), dict) - else None - ), - append_eos=bool(getattr(args, "append_eos", True)), - ) + def _create_eval_batches(eval_dataset): + """Materialize eval batches for one dataset split.""" + if is_vlm: + processor = self._resolve_vlm_processor() + config = getattr(self.model, "_config", {}) + _vlm_mask_fn = getattr(self, '_vlm_response_mask_fn', None) + return create_vlm_batches( + dataset=eval_dataset, + processor=processor, + config=config, + batch_size=args.per_device_train_batch_size, + max_seq_length=args.max_seq_length, + seed=args.seed, + response_mask_fn=_vlm_mask_fn, + formatting_func=self.formatting_func, + ) + return create_batches( + dataset=eval_dataset, + tokenizer=self.tokenizer, + batch_size=args.per_device_train_batch_size, + max_seq_length=args.max_seq_length, + seed=args.seed, + dataset_text_field=args.dataset_text_field, + formatting_func=self.formatting_func, + chat_template=getattr(args, "chat_template", None), + model_name=getattr(self.model, "_hf_repo", None), + model_type=( + getattr(self.model, "_config", {}).get("model_type") + if isinstance(getattr(self.model, "_config", {}), dict) + else None + ), + append_eos=bool(getattr(args, "append_eos", True)), + ) + + if isinstance(self.eval_dataset, dict): + eval_batches = { + key: _create_eval_batches(value) + for key, value in self.eval_dataset.items() + } + else: + eval_batches = _create_eval_batches(self.eval_dataset) if eval_batches: + eval_batch_count = ( + sum(len(value) for value in eval_batches.values()) + if isinstance(eval_batches, dict) else len(eval_batches) + ) print(f"Unsloth: Eval enabled every {args.eval_steps} steps " - f"({len(eval_batches)} eval batches).") + f"({eval_batch_count} eval batches).") features = [] if is_vlm: @@ -2471,26 +2506,36 @@ def train_on_responses_only( # Process eval dataset too if trainer.eval_dataset is not None: - eval_batches = _create_labeled_batches( - dataset=trainer.eval_dataset, - tokenizer=_tokenizer, - mask_fn=mask_fn, - batch_size=args.per_device_train_batch_size, - max_seq_length=args.max_seq_length, - formatting_func=trainer.formatting_func, - dataset_text_field=args.dataset_text_field, - seed=args.seed, - chat_template=getattr(args, "chat_template", None), - model_name=getattr(trainer.model, "_hf_repo", None), - model_type=( - getattr(trainer.model, "_config", {}).get("model_type") - if isinstance(getattr(trainer.model, "_config", {}), dict) - else None - ), - append_eos=bool(getattr(args, "append_eos", True)), - dataset_order=getattr(args, "dataset_order", "default"), - preserve_dataset_order=bool(getattr(args, "preserve_dataset_order", False)), - ) + def _create_labeled_eval_batches(eval_dataset): + """Build response-masked eval batches for one dataset split.""" + return _create_labeled_batches( + dataset=eval_dataset, + tokenizer=_tokenizer, + mask_fn=mask_fn, + batch_size=args.per_device_train_batch_size, + max_seq_length=args.max_seq_length, + formatting_func=trainer.formatting_func, + dataset_text_field=args.dataset_text_field, + seed=args.seed, + chat_template=getattr(args, "chat_template", None), + model_name=getattr(trainer.model, "_hf_repo", None), + model_type=( + getattr(trainer.model, "_config", {}).get("model_type") + if isinstance(getattr(trainer.model, "_config", {}), dict) + else None + ), + append_eos=bool(getattr(args, "append_eos", True)), + dataset_order=getattr(args, "dataset_order", "default"), + preserve_dataset_order=bool(getattr(args, "preserve_dataset_order", False)), + ) + + if isinstance(trainer.eval_dataset, dict): + eval_batches = { + key: _create_labeled_eval_batches(value) + for key, value in trainer.eval_dataset.items() + } + else: + eval_batches = _create_labeled_eval_batches(trainer.eval_dataset) trainer._eval_batches_labeled = eval_batches print(f"Unsloth: train_on_responses_only enabled "