
fix(mlx): handle zero-token and invalid labels in CCE #682

Merged
danielhanchen merged 30 commits into unslothai:main from Lyxot:fix/mlx-cce-edge-cases
May 27, 2026

@Lyxot Lyxot commented May 20, 2026

Summary

Hardens the MLX fused CCE runtime against two edge cases:

  • zero-token inputs
  • invalid target labels outside [0, vocab_size) except ignore_index

Behavior

Before this change:

  • zero-token inputs could reach the Metal kernel with an empty grid and crash the process.
  • invalid labels such as -1 or target == vocab_size could produce finite-looking losses instead of surfacing bad input.

Fix

  • Return empty loss/LSE tensors for zero-token forward passes before launching Metal kernels.
  • Return zero gradients for zero-token dense and quantized backward paths.
  • Mark invalid-label rows with NaN loss/LSE so bad labels cannot silently train with plausible loss values.
  • Keep invalid-label handling compile-safe under mx.compile.
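
As a rough illustration of the behavior listed above, the row classification can be sketched in plain Python. This is not the MLX implementation: the function name `classify_and_poison` is hypothetical, lists stand in for `mx.array`, and the per-row `lse` and `target_logit` values are assumed to be computed elsewhere.

```python
import math

def classify_and_poison(targets, lse, target_logit, vocab_size, ignore_index=-100):
    """Return per-row (loss, lse) lists, with NaN on rows whose label is invalid.

    Hypothetical stand-in for the runtime logic; plain lists replace mx.array.
    """
    losses, poisoned_lse = [], []
    for t, row_lse, row_logit in zip(targets, lse, target_logit):
        if t == ignore_index:
            # Ignored rows contribute zero loss and keep their LSE.
            losses.append(0.0)
            poisoned_lse.append(row_lse)
        elif 0 <= t < vocab_size:
            # Valid rows use the usual cross-entropy: lse - logit[target].
            losses.append(row_lse - row_logit)
            poisoned_lse.append(row_lse)
        else:
            # Invalid rows are poisoned so they cannot train silently.
            losses.append(math.nan)
            poisoned_lse.append(math.nan)
    return losses, poisoned_lse
```

With zero tokens the loop never runs, so the sketch returns two empty lists, which mirrors the empty loss/LSE result described for zero-token forward passes.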

Validation

python -m pytest tests/test_mlx_runtime_cce_compile.py tests/test_mlx_baseline_loss_parity.py -q

Copilot AI review requested due to automatic review settings May 20, 2026 16:52

Copilot AI left a comment

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR hardens the MLX runtime chunked cross-entropy implementation for edge cases by adding explicit handling for zero-token inputs and by poisoning invalid (out-of-vocab) targets with NaNs to surface data issues quickly.

Changes:

  • Add helpers to compute target validity masks and poison invalid positions with NaNs.
  • Add early returns for zero-token inputs in forward finalize and VJP paths.
  • Add tests covering zero-token behavior and invalid-label poisoning (including compiled + quantized variants).

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| unsloth_zoo/mlx/cce/runtime_cce.py | Adds invalid-target detection/NaN poisoning and zero-token early returns in forward/VJP. |
| tests/test_mlx_runtime_cce_compile.py | Adds regression tests for zero-token outputs/gradients and NaN poisoning for invalid labels. |

Comment thread unsloth_zoo/mlx/cce/runtime_cce.py

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces robust handling for zero-token inputs and out-of-vocabulary target labels within the MLX runtime chunked cross-entropy implementation. Specifically, it adds logic to return empty losses and zero gradients for zero-token scenarios and ensures that invalid targets correctly result in NaN values for both loss and log-sum-exp. Comprehensive tests have been added to verify these behaviors in both standard and quantized paths. Feedback focuses on optimizing memory efficiency by using scalar broadcasting in mx.where instead of mx.full and consolidating conditional logic in the fallback path for better performance and readability.

Comment on lines +128 to +133
def _poison_invalid_targets(values: mx.array, invalid: mx.array) -> mx.array:
    return mx.where(
        invalid,
        mx.full(values.shape, float("nan"), dtype=values.dtype),
        values,
    )

medium

Using mx.full to create a NaN array of the same shape as values is inefficient as it triggers a full allocation. MLX's mx.where supports scalar broadcasting, which is more memory-efficient and faster.

Suggested change
-def _poison_invalid_targets(values: mx.array, invalid: mx.array) -> mx.array:
-    return mx.where(
-        invalid,
-        mx.full(values.shape, float("nan"), dtype=values.dtype),
-        values,
-    )
+def _poison_invalid_targets(values: mx.array, invalid: mx.array) -> mx.array:
+    return mx.where(
+        invalid,
+        mx.array(float("nan"), dtype=values.dtype),
+        values,
+    )

Comment thread unsloth_zoo/mlx/cce/runtime_cce.py Outdated
Comment on lines +530 to +533
valid, invalid = _target_validity_masks(targets, vocab_size, ignore_index)
loss = mx.where(valid, lse - target_logit, mx.zeros_like(lse))
loss = _poison_invalid_targets(loss, invalid)
lse = _poison_invalid_targets(lse, invalid)

medium

In the fallback path, the logic for setting the loss to zero for ignored tokens and NaN for invalid tokens can be consolidated into a single nested mx.where call. This avoids multiple passes over the data and is more idiomatic. Additionally, ensure that logit scaling and softcapping transformations are applied to the logits before they are passed to this loss calculation to maintain consistency with standard execution paths.

Suggested change
-valid, invalid = _target_validity_masks(targets, vocab_size, ignore_index)
-loss = mx.where(valid, lse - target_logit, mx.zeros_like(lse))
-loss = _poison_invalid_targets(loss, invalid)
-lse = _poison_invalid_targets(lse, invalid)
+valid, invalid = _target_validity_masks(targets, vocab_size, ignore_index)
+loss = mx.where(valid, lse - target_logit, mx.where(invalid, float("nan"), 0.0))
+lse = _poison_invalid_targets(lse, invalid)
References
  1. When implementing code paths that return materialized logits (e.g., using UNSLOTH_RETURN_LOGITS=1), ensure that logit scaling (multiply/divide) and softcapping transformations are applied to the logits before they are passed to the loss function to maintain consistency with standard execution paths.

danielhanchen added a commit to danielhanchen/unsloth-zoo-staging-1 that referenced this pull request May 24, 2026
The seven upstream workflows (consolidated-tests-ci, lint-ci, mlx-ci,
security-audit, stale, studio-export-fix-ci, wheel-smoke) would fire on
every push and PR-event to this throwaway staging branch and burn runner
minutes that have nothing to do with validating MLX PRs unslothai#679, unslothai#682, unslothai#692.

Keep only the three mlx-pr-* workflows on this branch. They stay in
upstream main / origin/main untouched -- this deletion is scoped to the
staging branch only.
…nslothai#682

Three follow-ups from review feedback:

1. _poison_invalid_targets used mx.full(values.shape, NaN, ...) which
   allocates an O(n) tensor on every forward call, even when the invalid
   mask is all False. Replaced with a scalar NaN broadcast through
   mx.where so the normal path costs nothing extra.

2. The zero-token forward early-return now raises ValueError when
   hidden has zero rows but targets is non-empty. The previous version
   silently returned empty loss/LSE which masked an upstream shape
   mismatch instead of surfacing it.

3. The compiled invalid-label regression test now parametrizes over
   bad_target in [-1, vocab_size]. Negative and out-of-range labels
   take different lookup paths under mx.compile, so single-sided
   coverage was insufficient. Added a matching test for the new
   ValueError on hidden=0 with non-empty targets.
…ai#682

The scalar mx.array(float('nan'), ...) broadcast got baked into the
Metal kernel as the literal token 'nan', which the Metal C++
tokenizer rejects ('use of undeclared identifier nan'). Restoring
the original mx.full(values.shape, ...) form keeps the allocation
but preserves Metal-kernel compilation. The two-sided compile-test
parametrize and the hidden=0/targets!=0 ValueError defensive check
remain.

Extend the existing n==0 shape-mismatch guard in
_forward_chunked_fused_finalize so any targets.shape[0] != n is surfaced
as a ValueError before chunk planning, fallback indexing, or Metal
kernel launch. Short non-empty targets previously broadcast to the
n-row loss and lse tensors, silently returning wrong-but-finite losses;
longer targets reached kernels that index targets[row] for row in
[0, n) and produced cryptic backend errors. The new check follows the
PR intent of surfacing upstream shape bugs early.
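
A minimal sketch of the guard described in these commits, assuming only row counts are available. The helper name `check_cce_shapes` and its return convention are hypothetical; the real check lives in `_forward_chunked_fused_finalize` and operates on `mx.array` shapes.

```python
def check_cce_shapes(n_rows, n_targets):
    """Raise on any row-count mismatch; return True when the zero-token path applies.

    Hypothetical helper: n_rows stands for hidden.shape[0], n_targets for
    targets.shape[0].
    """
    if n_targets != n_rows:
        # Covers hidden=0 with non-empty targets as well as short or long targets.
        raise ValueError(
            f"targets has {n_targets} rows but hidden has {n_rows} rows"
        )
    # With matching shapes, zero rows means no kernel should be launched.
    return n_rows == 0
```

Running the mismatch check before the zero-token early return means an empty `hidden` paired with non-empty `targets` is reported as an error instead of being returned as an empty loss.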
@danielhanchen danielhanchen added auto-review-failed Auto-review rejected the PR and removed auto-reviewing Auto-review in progress labels May 24, 2026
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

@danielhanchen

Auto-review verdict: Changes requested

Reason: Deterministic gate: accumulated test(s) regressed after post_sim_commit at iteration 2 (1 regression event(s) total)

…#682

The Metal dlogits kernel uses fast::exp(capped - lse[row]). When
_poison_invalid_targets sets lse[row] to NaN for invalid labels,
fast:: math is not IEEE 754 strict (MSL spec 6.5.1) and may return
a finite value, leaving the forward loss NaN but the gradient
silently wrong. Add an isnan(lse[row]) guard that emits NaN via
0.0f/0.0f (Metal C++ rejects the literal 'nan' token).

Also tighten the invalid-labels gradient test: previously asserted
only that _stable_norm(grads) is NaN, which is trivially satisfied
because grad_weight is all-NaN whenever any invalid row exists.
Now explicitly checks finite grad_hidden on the valid row, NaN on
the invalid row, and zero on the ignore_index row.
@danielhanchen danielhanchen removed the auto-review-failed Auto-review rejected the PR label May 24, 2026
Rank-2 (n, 1) or scalar targets slipped past the length check and crashed
deep inside the kernels with backend-specific errors. Add a 1D guard at
the top of _runtime_cce_loss_and_aux so callers get a clear ValueError.

_fallback_dlogits relies on lse arriving pre-poisoned with NaN for
invalid-label rows (targets outside [0, vocab_size) and not equal to
ignore_index). It does not re-check vocab bounds; NaN propagation
through exp(capped - NaN) = NaN is what produces NaN gradients for
those rows. Add a comment so future callers cannot accidentally pass
unpoisoned lse and silently get finite garbage gradients.

The Metal forward-finalize kernel writes finite lse_out and loss_out
for every row, including out-of-vocab targets it cannot classify
because vocab_size is not a kernel input. The dlogits backward kernel
relies on the Python caller having applied _poison_invalid_targets to
both lse and loss before backward is invoked. Make that invariant
explicit at the kernel definition site so future fused or multi-GPU
paths that pass raw lse_out to the dlogits kernel will not silently
produce finite wrong gradients for invalid labels.
@danielhanchen danielhanchen added auto-approved Auto-review approved the PR and removed auto-reviewing Auto-review in progress labels May 24, 2026
@danielhanchen

Auto-review verdict: Approved

Hardens unsloth_zoo/mlx/cce/runtime_cce.py against (a) zero-token inputs that crash the Metal kernel via an empty grid and (b) out-of-vocab labels that previously produced finite-garbage losses and gradients; adds upfront shape validation, NaN poisoning of loss/lse for invalid labels in both fallback and Metal-finalize paths, an isnan guard in the dlogits backward kernel, and 7+ tests covering eager, compiled, and quantized paths.

Reason: All real review findings fixed; no remaining bugs; PR correctly addresses two concrete bugs in MLX CCE.

Wider integer labels (int64 / uint32) outside the int32 range could
wrap through targets.astype(mx.int32) into a valid class id or the
ignore_index, sneaking past the new invalid-label NaN poisoning.
Compute the validity masks from the original dtype first, then narrow
for the kernels and indexing. Add a parametrized regression covering
2**32 plus/minus offsets and their negatives.
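
The wrap described here can be reproduced with the standard library. In this sketch `ctypes.c_int32` stands in for `targets.astype(mx.int32)`, and `is_valid_label` is a hypothetical helper, not a function from the PR.

```python
import ctypes

def narrow_to_int32(value):
    """Keep the low 32 bits of an integer and read them as a signed int32."""
    return ctypes.c_int32(value).value

def is_valid_label(value, vocab_size, ignore_index=-100):
    """Hypothetical validity rule: ignore_index or a class id in [0, vocab_size)."""
    return value == ignore_index or 0 <= value < vocab_size

wide_label = 2**32 - 100
# Validating after the narrow accepts the label as ignore_index (the bug).
accepted_after_narrow = is_valid_label(narrow_to_int32(wide_label), vocab_size=32000)
# Validating the original value first rejects it (the fix).
accepted_before_narrow = is_valid_label(wide_label, vocab_size=32000)
```

The vocabulary size of 32000 is an arbitrary value chosen for the example.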
…nslothai#682

A wide invalid label that narrows to ignore_index after astype(mx.int32)
(e.g. 2**32 - 100 -> -100 with ignore_index=-100) used to produce a NaN
forward loss but a zero backward gradient. Move the NaN-lse guard before
the ignore_index check in both the Metal dlogits kernel and the Python
fallback so the gradient also poisons. Refuse the implicit zeroing path
on rows whose lse is NaN.
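
The guard ordering can be sketched for a single row in plain Python. This is an illustration under assumptions, not the Metal kernel or `_fallback_dlogits`: the function name `dlogits_row` is hypothetical, softcapping and scaling are omitted, and the gradient is the standard softmax-minus-one-hot form.

```python
import math

def dlogits_row(logits, target, lse, ignore_index=-100):
    """Per-row gradient of cross-entropy w.r.t. logits, with the NaN guard first."""
    if math.isnan(lse):
        # Poisoned row: emit NaN even if the narrowed target equals ignore_index.
        return [math.nan] * len(logits)
    if target == ignore_index:
        # Genuinely ignored row: zero gradient.
        return [0.0] * len(logits)
    # softmax(logits) - one_hot(target), using the precomputed log-sum-exp.
    return [
        math.exp(x - lse) - (1.0 if j == target else 0.0)
        for j, x in enumerate(logits)
    ]
```

If the two checks were swapped, a row whose wide label narrowed to `ignore_index` would take the zero-gradient branch despite its NaN `lse`, which is the forward/backward mismatch the commit describes.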

Also drop the .astype(mx.int32) pre-cast from the four CCE wrapper call
sites (text dense / quantized; VLM dense / quantized). Casting before the
runtime narrowed wide labels to valid class ids or ignore_index, slipping
past the new original-dtype validation in _forward_chunked_fused_finalize.
The runtime now performs that narrow internally only after the validity
masks are built from the original dtype.

New regression: test_runtime_cce_int64_wrap_to_ignore_index_poisons_gradients.
…hai#682

The zero-token guard previously returned `return empty, empty`, aliasing
the same mx.array as both loss and lse. Today's downstream wrapper only
does `losses + lse * 0.0`, so aliasing is harmless, but any future caller
that mutates one slot or relies on independent identity (e.g. VJP
cotangent indexing) could misbehave. Return two distinct mx.zeros((0,))
allocations instead. One-line change, no behavior change for current
callers.
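
The aliasing concern can be shown with ordinary Python objects. Lists stand in for `mx.array` here, and both function names are hypothetical.

```python
def empty_outputs_aliased():
    """Return the same empty object twice, as in `return empty, empty`."""
    empty = []
    return empty, empty

def empty_outputs_distinct():
    """Return two separately allocated empty objects."""
    return [], []

loss_a, lse_a = empty_outputs_aliased()
loss_a.append(1.0)  # mutating one slot is visible through the other
loss_d, lse_d = empty_outputs_distinct()
loss_d.append(1.0)  # the second slot is unaffected
```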
…hai#682

The runtime CCE now validates label dtype/range before its own int32
narrow so wide invalid labels (e.g. 2**32 - 100 wrapping to -100 after
narrow) get NaN-poisoned instead of silently treated as ignore_index.
The direct call-site .astype(mx.int32) pre-casts in utils.py CCE wrappers
were removed earlier for the same reason.

Two upstream sites still pre-narrowed: _to_mx_vlm_batch (utils.py:2168)
and _collate_vlm_prompt_completion_batch's final labels cast
(utils.py:2237). With those still narrowing, wide invalid VLM labels
silently became -100 before reaching the validity check, regressing the
PR's fix on the VLM path. Drop the narrows; the runtime CCE owns
narrowing now.
…nslothai#682

_expand_image_token_sequences (utils.py:928) and _expand_token_runs
(utils.py:1000) allocated padded_labels as np.int32. When a response-
masked VLM batch contains a wide invalid label (e.g. 2**32 - 100 from
a tokenizer that emits int64), the int32 allocation raises OverflowError
before the runtime CCE validity check has a chance to reject the label.

Allocate padded_labels as np.int64 so wide values reach the CCE
runtime intact. The runtime narrows after validating, so we no longer
lose the dtype/range information that triggers NaN-poisoning.
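
The overflow can be imitated with the standard-library `array` module, used here as a stand-in for the NumPy buffers named in the commit. The sketch assumes the platform's `"i"` typecode is a 4-byte signed integer, which is typical but not guaranteed.

```python
from array import array

wide_label = 2**32 - 100

# A 32-bit signed buffer cannot hold the wide label and refuses it outright.
try:
    array("i", [wide_label])
    overflowed = False
except OverflowError:
    overflowed = True

# A 64-bit signed buffer keeps the value intact for later validation.
padded_labels = array("q", [wide_label])
```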
…efore validation for PR unslothai#682

Round-10 review caught two asymmetric-fix paths the previous int64 widening
missed:

* `_collate_vlm_prompt_completion_batch` still built `labels_np` with
  `dtype=np.int32` after `_to_mx_vlm_batch()` narrowed the model inputs.
  A wide-int label (e.g. `2**32 - 100` under a uint32 tokenizer) was
  discarded before reaching runtime CCE's validity classifier. Now build
  labels from `combined_inputs["input_ids"]` as `np.int64` before the
  batch narrow, mirroring the rest of the PR's "validate before int32
  narrow" contract.
* `_target_validity_masks` compared `targets >= 0` directly, which the
  torch-backed MLX simulation rejects for `uint16`/`uint32`/`uint64`
  ("ge_cpu" not implemented for UInt32). Cast unsigned dtypes to int64
  for the comparison so wide-int wrap labels classify as out-of-vocab
  instead of crashing the validity step. Defensive no-op on real
  Apple-Metal MLX.

…slothai#682

Round-11 review caught three remaining asymmetric-fix surfaces left open
by the R10 unsigned-label fixes:

* `_target_validity_masks` now widens `mx.uint64` via float64 (or float32
  if the build lacks float64) before the validity comparisons. int64
  would wrap large unsigned values like `2**64 - 100` into negative
  ints that could collide with `ignore_index=-100`; float losslessly
  covers the meaningful 0..2**53 range and saturates above that into
  the out-of-vocab bucket. Bypasses the `ignore_index != target` check
  entirely when ignore_index is negative since unsigned values cannot
  reach it anyway.
* `_to_mx_vlm_batch` now widens unsigned `labels` (uint8/16/32/64) to
  signed int64 via the shared `_normalize_cce_label_dtype` helper. The
  PR previously stopped narrowing labels to mx.int32 so runtime CCE
  could classify wide invalids, but VLM masking helpers downstream
  inject `-100` via `mx.where`, which crashes on unsigned tensors under
  the torch-backed MLX simulation ("where_cpu" not implemented for
  UInt32). int64 keeps wide-range visibility AND lets masking inject
  the signed sentinel cleanly.
* `_mask_image_tokens` and `_mask_prompt_tokens` apply the same
  normalization defensively and use `mx.array(-100, dtype=targets.dtype)`
  so any caller-shaped target dtype (signed int8/16, float, etc.) gets
  a matched-dtype sentinel instead of relying on implicit Python-int
  broadcasting that some MLX shims reject.

… text CCE branches for PR unslothai#682

Round-15 review caught a regression I introduced in R11 plus two
asymmetric-fix gaps the prior rounds missed:

* `_normalize_cce_label_dtype` cast `mx.uint64` straight through
  `astype(mx.int64)`. Values >= 2**63 wrap to negatives; specifically
  `2**64 - 100` lands on exactly `-100` and silently collides with
  `ignore_index`, defeating the runtime CCE validity check the rest of
  the PR is built around. Detect uint64 explicitly and route any value
  above `(1 << 63) - 1` to a known out-of-vocab sentinel via
  `mx.where`; uint8/16/32 still fit losslessly so a direct cast is fine.
* Both quantized and dense text CCE branches in
  `make_chunked_cross_entropy_loss` still injected the `-100` sentinel
  via raw `mx.where(mask, targets, -100)`. The VLM masking helpers were
  R11-normalized to widen unsigned labels first; mirror the same
  contract in the text path so unsigned tokenizer labels don't crash
  the masking step on the torch-backed shim and don't bypass runtime
  CCE's validity classifier for wide invalids.
* `_collate_vlm_prompt_completion_batch` built `labels_np` via
  `np.asarray(..., dtype=np.int64)`. The same uint64 wrap issue applies
  at the numpy layer: `np.uint64(2**64 - 100)` casts to int64 -100,
  hitting ignore_index before the runtime ever sees the value. Apply
  the same saturating-to-sentinel logic in numpy when the source dtype
  is uint64; non-uint64 paths still use the lossless int64 cast.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 483ffd3af6

Comment thread unsloth_zoo/mlx/cce/runtime_cce.py Outdated
Comment on lines +147 to +149
_validation_dtype = getattr(mx, "float64", mx.float32)
targets_for_validation = targets.astype(_validation_dtype)
in_vocab = (targets_for_validation >= 0.0) & (

P1: Avoid float64 casts in uint64 label validation path

When targets.dtype == mx.uint64, this branch promotes labels to mx.float64 before computing in_vocab. On MLX, float64 arrays are CPU-only, so with the default GPU stream this can raise at runtime before the CCE kernels run. That means uint64-labeled batches (the exact edge case this patch tries to support) can fail outright on real MLX devices instead of being classified and NaN-poisoned.

…-label / np-expand / runtime float64 paths for PR unslothai#682

Fixes the four asymmetric-fix gaps flagged in round 16 review:

1. unsloth_zoo/mlx/utils.py: _normalize_cce_label_dtype previously did
   `mx.where(labels > max_i63, sentinel, labels_i64)` on raw uint64,
   which crashed the torch-backed MLX shim with
   `NotImplementedError: "gt_cpu" not implemented for 'UInt64'`. Cast
   to int64 first and detect overflow with `labels_i64 < 0`; any
   uint64 value above 2**63-1 wraps negative after the cast and is
   routed to a positive out-of-vocab sentinel so runtime CCE NaN-
   poisons the row instead of colliding with ignore_index.

2. unsloth_zoo/mlx/utils.py: _vlm_cce_forward(labels=None) used to
   apply `mx.where(length_mask, targets, -100)` directly on unsigned
   input_ids. Normalize through _normalize_cce_label_dtype first and
   inject the ignore sentinel via `mx.array(-100, dtype=targets.dtype)`
   so the path matches every other masking site.

3. unsloth_zoo/mlx/utils.py: add _normalize_numpy_cce_labels helper and
   apply it inside _expand_image_token_sequences and _expand_token_runs
   so wide np.uint64 labels (e.g. 2**64-100) do not OverflowError when
   `.tolist()` values are packed into the int64 padded label buffer.

4. unsloth_zoo/mlx/cce/runtime_cce.py: _target_validity_masks no
   longer requests `getattr(mx, "float64", mx.float32)` for uint64
   validation. MLX float64 arrays are CPU-only per the public docs,
   so the prior branch would have broken on real Apple-Silicon/Metal
   even though it sidestepped the simulation crash. float32 is
   sufficient to separate normal vocabulary labels from uint64
   overflow sentinels for the validity check.
…k labels for PR unslothai#682

Fixes four asymmetric-fix gaps from round 17 review:

1. unsloth_zoo/mlx/utils.py + unsloth_zoo/mlx/cce/runtime_cce.py:
   collators that pre-encode ignore_index as uint64(2**64-100) wrap to
   exactly int64(-100) after the cast. Round 16's overflow routing
   then sent every wrapped-negative value (including encoded ignore)
   to the invalid out-of-vocab sentinel and runtime CCE NaN-poisoned
   the row. Detect `labels_i64 == -100` (and the numpy analogue
   `labels_np == np.uint64(2**64-100)`) BEFORE routing overflow to
   the sentinel so legitimate ignore rows continue to be skipped
   downstream. Also drop the float32-based uint64 validation in
   `_target_validity_masks` (float32 cannot exactly represent labels
   above 2**24, so labels near vocab boundaries above that were
   misclassified). Validate through signed int64 with explicit
   overflow detection instead. Comment in the helper updated to
   document the actual sentinel contract (1 << 62, not iinfo max).

2. unsloth_zoo/mlx/utils.py: `_apply_response_mask_to_vlm_batch`
   built `mx.array(labels_list)` and injected an untyped `mx.array(-100)`
   directly. Custom `response_mask_fn` results carrying unsigned label
   dtypes then crashed the torch-backed MLX shim with
   NotImplementedError "where_cpu" not implemented for UInt*. Route
   labels through `_normalize_numpy_cce_labels` + `_normalize_cce_label_dtype`
   and inject a dtype-matched ignore sentinel.

3. unsloth_zoo/mlx/utils.py: text baseline CE path (`use_cce=False`)
   reached `mx.where(targets == -100, 0, targets)` directly on raw
   unsigned labels, leaving the asymmetric-fix gap that runtime CCE
   already covered. Normalize the labels through `_normalize_cce_label_dtype`
   and use a dtype-matched ignore scalar in the safe-targets `mx.where`.

4. unsloth_zoo/mlx/utils.py: the VLM baseline CE fallback had the
   same gap on both the labels-aware and labels=None branches.
   Mirror the same normalization + dtype-matched ignore sentinel.
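
The float32 precision limit mentioned in item 1 can be checked with the standard library. In this sketch `struct` with the `"f"` format stands in for an MLX cast to `mx.float32`; the helper name is hypothetical.

```python
import struct

def roundtrip_float32(value):
    """Round an integer through IEEE 754 single precision and back to int."""
    packed = struct.pack("f", float(value))
    return int(struct.unpack("f", packed)[0])

exact = roundtrip_float32(2**24)        # still representable
rounded = roundtrip_float32(2**24 + 1)  # collapses onto a neighboring value
```

Because adjacent integers above 2**24 can collapse to the same float32, a comparison against `vocab_size` made in float32 may land on the wrong side of the boundary, which is why the commit validates through signed int64 instead.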
…s-free + response-mask paths for PR unslothai#682

Fixes three asymmetric-fix gaps from round 18 review:

1. unsloth_zoo/mlx/utils.py + unsloth_zoo/mlx/cce/runtime_cce.py:
   round 17 preserved `uint64(2**64-100)` (which wraps to int64(-100)
   after cast) as the canonical ignore_index. That was wrong: a raw
   uint64 value above 2**63-1 is an out-of-range unsigned label, not
   an intentional encoded ignore. Treat ALL uint64 overflow values as
   invalid sentinels (1 << 62) so runtime CCE NaN-poisons them instead
   of silently dropping them as ignored rows. Applies symmetrically in
   `_normalize_cce_label_dtype`, `_normalize_numpy_cce_labels`, and
   `_target_validity_masks`.

2. unsloth_zoo/mlx/utils.py: labels-free VLM CCE / baseline paths
   used to derive targets from `input_ids[:, 1:]` AFTER
   `_to_mx_vlm_batch()` narrowed `input_ids` to mx.int32. A wide
   invalid id such as `np.uint32(2**32-100)` then wrapped to -100 in
   the narrowed view and was silently treated as ignore_index instead
   of being NaN-poisoned. Preserve the raw (pre-narrow) input_ids
   under a private `_unsloth_raw_input_ids_for_labels` key in
   `_to_mx_vlm_batch` when no explicit `labels` are present, and let
   both `make_vlm_baseline_loss_fn` and `_vlm_cce_forward` consume
   that raw copy before falling back to the narrowed view.

3. unsloth_zoo/mlx/utils.py: `_apply_response_mask_to_vlm_batch`
   built labels by calling `.tolist()` on the narrowed
   `batch_dict["input_ids"]`, so the masking closure never saw wide
   invalid ids. Consume the same `_unsloth_raw_input_ids_for_labels`
   carrier first, falling back to the narrowed view only when no raw
   copy was preserved. Both `_vlm_cce_forward` and
   `make_vlm_baseline_loss_fn` also strip the private key from the
   backbone forward kwargs so the model never receives it.
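
The uint64 rule from item 1 can be sketched on Python integers. `ctypes.c_int64` stands in for the cast to int64, the sentinel value `1 << 62` comes from the commit message, and the function name is hypothetical.

```python
import ctypes

INVALID_SENTINEL = 1 << 62  # out-of-vocab sentinel named in the commit message

def normalize_uint64_label(value):
    """Map a uint64 label to int64, routing overflow to the invalid sentinel."""
    as_int64 = ctypes.c_int64(value).value  # values above 2**63 - 1 wrap negative
    if as_int64 < 0:
        # Every wrapped value is invalid, including the one that lands on -100.
        return INVALID_SENTINEL
    return as_int64

wrapped = ctypes.c_int64(2**64 - 100).value
normalized = normalize_uint64_label(2**64 - 100)
```

Since the sentinel is positive and far larger than any realistic vocabulary, the downstream validity check classifies it as out-of-vocab and poisons the row.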

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: bf9529a3d5

Comment thread unsloth_zoo/mlx/utils.py
Comment on lines +910 to +911
target_source = batch_dict.get(_RAW_INPUT_IDS_FOR_LABELS, input_ids)
targets = _normalize_cce_label_dtype(target_source[:, 1:])

P1: Derive VLM targets from expanded input_ids

When labels are absent, this branch always prefers _RAW_INPUT_IDS_FOR_LABELS, but _prepare_vlm_batch_for_compile can expand batch_dict["input_ids"] for multi_modality, phi4-siglip, and phi4mm without updating that raw copy (see the _expand_* rewrites around _prepare_vlm_batch_for_compile). That leaves targets shorter/different than the actual forward sequence, and _align_logits_with_labels then truncates logits to the stale target length, silently dropping supervision for valid trailing tokens. In those model types this can corrupt training loss/gradients rather than just handling invalid-label edge cases.

…nslothai#682

_to_mx_vlm_batch stores the pre-narrow input_ids under _RAW_INPUT_IDS_FOR_LABELS
when no labels are present, so labels-free / response-mask loss paths can derive
targets from the original processor output before the int32 narrow. For
multi_modality, phi4-siglip/phi4_siglip, and phi4mm,
_prepare_vlm_batch_for_compile then expands input_ids/attention_mask to a longer
sequence (image- or audio-token runs). With labels=None the carrier was left
unchanged and went stale: target_source = batch_dict.get(_RAW_INPUT_IDS_FOR_LABELS,
input_ids) handed back the shorter pre-expansion ids while hidden states were
computed from the expanded sequence, silently misaligning targets against the
longer hidden-state sequence (no NaN poisoning, finite losses).

Pop the carrier in each expand branch when labels is None so the loss paths
fall back to the freshly expanded input_ids via the get(..., input_ids)
default. qwen/glm/phi3_v branches do not expand input_ids and keep the carrier
intact.
danielhanchen added a commit to danielhanchen/unsloth-zoo-staging-1 that referenced this pull request May 25, 2026
…nslothai#682

…or PR unslothai#682

R21 popped _RAW_INPUT_IDS_FOR_LABELS in the labels-None expand branches
(multi_modality / phi4-siglip / phi4mm) to avoid a shape mismatch against the
expanded input_ids. That stopped the silent label/data corruption but threw
away the wide-invalid-label detection the carrier was introduced for: a raw
np.uint32(2**32 - 100) that narrows to -100 in input_ids.astype(int32) was
then quietly treated as ignore_index by the labels-free CCE path instead of
being NaN-poisoned. The labels-present branches had the mirror issue: when
_collate_vlm_prompt_completion_batch stashes the carrier before adding labels,
the carrier was left at pre-expansion length and crashed
_apply_response_mask_to_vlm_batch with a shape mismatch the moment a user
combined prompt/completion VLM batches with a response_mask_fn.

Pass the raw carrier into the existing expand helpers via their labels= slot
when labels are absent, then bind the expanded carrier back into the batch
under _RAW_INPUT_IDS_FOR_LABELS so the loss path keeps wide-invalid visibility
at the original positions. When labels are present the carrier is redundant;
pop it so any stale pre-expansion copy cannot reach the response-mask path.
Applies symmetrically to multi_modality, phi4-siglip / phi4_siglip, and phi4mm.
danielhanchen added a commit to danielhanchen/unsloth-zoo-staging-1 that referenced this pull request May 26, 2026
…or PR unslothai#682

Comments-only refactor: drop pure narration and compress multi-line WHY
explanations to a single line per intent. No behavior change.
@danielhanchen danielhanchen merged commit eb26d3a into unslothai:main May 27, 2026
danielhanchen added a commit to Lyxot/unsloth-zoo that referenced this pull request May 27, 2026
…-export-adapters

# Conflicts:
#	unsloth_zoo/mlx/loader.py
#	unsloth_zoo/mlx/trainer.py
#	unsloth_zoo/mlx/utils.py
danielhanchen pushed a commit to mmathew23/unsloth-zoo that referenced this pull request May 27, 2026
Merges 5 main-side mlx fixes (unslothai#673 zero-token CCE, unslothai#679 + unslothai#692 LoRA save
metadata, unslothai#682 invalid label NaN-poisoning, unslothai#688 tool mask). All 13
conflict regions in unsloth_zoo/mlx/utils.py resolved to keep PR unslothai#684's
behavior where it conflicts on semantics:

  - half-open `<` length mask (PR unslothai#684 fix) wins over main's inclusive `<=`
  - `if labels is None` branch preserved (PR unslothai#684 generality) alongside
    main's `_normalize_cce_label_dtype` dtype widening
  - `_get_image_token_ids` legacy wrapper kept alongside main's new
    `_normalize_cce_label_dtype` / `_normalize_numpy_cce_labels`
  - `_mask_label_token_ids` calls `_normalize_cce_label_dtype` first so
    image masking honors main's uint-widening contract
  - HEAD's `_expand_token_replacements` dropped; main's three-function
    split (`_normalize_numpy_cce_labels` + `_expand_image_token_sequences`
    + `_expand_token_runs`) is canonical; duplicate HEAD wrappers removed
  - `_collate_vlm_prompt_completion_batch` reads back the masked labels
    in int64 so image + attention masking survives without narrowing
  - prompt-completion VLM collator routes through `_apply_vlm_label_masks`
    after dtype normalisation so ignore_token_ids and wide invalid ids
    both reach runtime CCE intact
  - `_to_mx_vlm_batch` uses main's `_normalize_cce_label_dtype` for labels
    while keeping PR unslothai#684's token_type_ids / mm_token_type_ids handling
  - `_unsloth_*` prefix filter preserved so the new collated_position_ids
    flag and main's raw-input-ids carrier both get stripped

152 MLX tests pass post-merge.