fix(mlx): handle zero-token and invalid labels in CCE #682
Conversation
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR hardens the MLX runtime chunked cross-entropy implementation for edge cases by adding explicit handling for zero-token inputs and by poisoning invalid (out-of-vocab) targets with NaNs to surface data issues quickly.
Changes:
- Add helpers to compute target validity masks and poison invalid positions with NaNs.
- Add early returns for zero-token inputs in forward finalize and VJP paths.
- Add tests covering zero-token behavior and invalid-label poisoning (including compiled + quantized variants).
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| unsloth_zoo/mlx/cce/runtime_cce.py | Adds invalid-target detection/NaN poisoning and zero-token early returns in forward/VJP. |
| tests/test_mlx_runtime_cce_compile.py | Adds regression tests for zero-token outputs/gradients and NaN poisoning for invalid labels. |
Code Review
This pull request introduces robust handling for zero-token inputs and out-of-vocabulary target labels within the MLX runtime chunked cross-entropy implementation. Specifically, it adds logic to return empty losses and zero gradients for zero-token scenarios and ensures that invalid targets correctly result in NaN values for both loss and log-sum-exp. Comprehensive tests have been added to verify these behaviors in both standard and quantized paths. Feedback focuses on optimizing memory efficiency by using scalar broadcasting in mx.where instead of mx.full and consolidating conditional logic in the fallback path for better performance and readability.
    def _poison_invalid_targets(values: mx.array, invalid: mx.array) -> mx.array:
        return mx.where(
            invalid,
            mx.full(values.shape, float("nan"), dtype=values.dtype),
            values,
        )
Using mx.full to create a NaN array of the same shape as values is inefficient as it triggers a full allocation. MLX's mx.where supports scalar broadcasting, which is more memory-efficient and faster.
Suggested change
    -def _poison_invalid_targets(values: mx.array, invalid: mx.array) -> mx.array:
    -    return mx.where(
    -        invalid,
    -        mx.full(values.shape, float("nan"), dtype=values.dtype),
    -        values,
    -    )
    +def _poison_invalid_targets(values: mx.array, invalid: mx.array) -> mx.array:
    +    return mx.where(
    +        invalid,
    +        mx.array(float("nan"), dtype=values.dtype),
    +        values,
    +    )
    valid, invalid = _target_validity_masks(targets, vocab_size, ignore_index)
    loss = mx.where(valid, lse - target_logit, mx.zeros_like(lse))
    loss = _poison_invalid_targets(loss, invalid)
    lse = _poison_invalid_targets(lse, invalid)
In the fallback path, the logic for setting the loss to zero for ignored tokens and NaN for invalid tokens can be consolidated into a single nested mx.where call. This avoids multiple passes over the data and is more idiomatic. Additionally, ensure that logit scaling and softcapping transformations are applied to the logits before they are passed to this loss calculation to maintain consistency with standard execution paths.
Suggested change
    -valid, invalid = _target_validity_masks(targets, vocab_size, ignore_index)
    -loss = mx.where(valid, lse - target_logit, mx.zeros_like(lse))
    -loss = _poison_invalid_targets(loss, invalid)
    -lse = _poison_invalid_targets(lse, invalid)
    +valid, invalid = _target_validity_masks(targets, vocab_size, ignore_index)
    +loss = mx.where(valid, lse - target_logit, mx.where(invalid, float("nan"), 0.0))
    +lse = _poison_invalid_targets(lse, invalid)
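The three-way split discussed in this thread (valid rows get `lse - target_logit`, ignored rows get zero, invalid rows get NaN) can be sketched in plain Python. This is only an illustrative stand-in, not the runtime code: `reference_ce` is a hypothetical name, it works on Python lists rather than `mx.array`, and it assumes the LSE of an ignored row stays finite.

```python
import math

def reference_ce(logits, targets, ignore_index=-100):
    """Per-row (loss, lse) for a list of logit rows; a plain-Python stand-in."""
    vocab_size = len(logits[0]) if logits else 0
    losses, lses = [], []
    for row, target in zip(logits, targets):
        peak = max(row)
        # Numerically stable log-sum-exp of the row.
        lse = peak + math.log(sum(math.exp(x - peak) for x in row))
        if target == ignore_index:
            # Ignored row: contributes zero loss, LSE stays finite (assumption).
            losses.append(0.0)
            lses.append(lse)
        elif 0 <= target < vocab_size:
            # Valid row: standard cross-entropy.
            losses.append(lse - row[target])
            lses.append(lse)
        else:
            # Invalid label: poison both outputs so the bad input is visible.
            losses.append(math.nan)
            lses.append(math.nan)
    return losses, lses

logits = [[2.0, 0.5, -1.0] for _ in range(4)]
targets = [0, -100, 3, -1]  # valid, ignored, == vocab_size, negative
losses, lses = reference_ce(logits, targets)
```

With these inputs only the first row carries a finite positive loss, the ignored row is exactly zero, and both out-of-range labels come back as NaN in `losses` and `lses`.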
References
- When implementing code paths that return materialized logits (e.g., using UNSLOTH_RETURN_LOGITS=1), ensure that logit scaling (multiply/divide) and softcapping transformations are applied to the logits before they are passed to the loss function to maintain consistency with standard execution paths.
The seven upstream workflows (consolidated-tests-ci, lint-ci, mlx-ci, security-audit, stale, studio-export-fix-ci, wheel-smoke) would fire on every push and PR-event to this throwaway staging branch and burn runner minutes that have nothing to do with validating MLX PRs unslothai#679, unslothai#682, unslothai#692. Keep only the three mlx-pr-* workflows on this branch. They stay in upstream main / origin/main untouched -- this deletion is scoped to the staging branch only.
…nslothai#682
Three follow-ups from review feedback:
1. _poison_invalid_targets used mx.full(values.shape, NaN, ...) which allocates an O(n) tensor on every forward call, even when the invalid mask is all False. Replaced with a scalar NaN broadcast through mx.where so the normal path costs nothing extra.
2. The zero-token forward early-return now raises ValueError when hidden has zero rows but targets is non-empty. The previous version silently returned empty loss/LSE which masked an upstream shape mismatch instead of surfacing it.
3. The compiled invalid-label regression test now parametrizes over bad_target in [-1, vocab_size]. Negative and out-of-range labels take different lookup paths under mx.compile, so single-sided coverage was insufficient.
Added a matching test for the new ValueError on hidden=0 with non-empty targets.
…ai#682 The scalar mx.array(float('nan'), ...) broadcast got baked into the Metal kernel as the literal token 'nan', which the Metal C++ tokenizer rejects ('use of undeclared identifier nan'). Restoring the original mx.full(values.shape, ...) form keeps the allocation but preserves Metal-kernel compilation. The two-sided compile-test parametrize and the hidden=0/targets!=0 ValueError defensive check remain.
Extend the existing n==0 shape-mismatch guard in _forward_chunked_fused_finalize so any targets.shape[0] != n is surfaced as a ValueError before chunk planning, fallback indexing, or Metal kernel launch. Short non-empty targets previously broadcast to the n-row loss and lse tensors, silently returning wrong-but-finite losses; longer targets reached kernels that index targets[row] for row in [0, n) and produced cryptic backend errors. The new check follows the PR intent of surfacing upstream shape bugs early.
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Auto-review verdict: Changes requested
Reason: Deterministic gate: accumulated test(s) regressed after post_sim_commit at iteration 2 (1 regression event(s) total)
…#682 The Metal dlogits kernel uses fast::exp(capped - lse[row]). When _poison_invalid_targets sets lse[row] to NaN for invalid labels, fast:: math is not IEEE 754 strict (MSL spec 6.5.1) and may return a finite value, leaving the forward loss NaN but the gradient silently wrong. Add an isnan(lse[row]) guard that emits NaN via 0.0f/0.0f (Metal C++ rejects the literal 'nan' token). Also tighten the invalid-labels gradient test: previously asserted only that _stable_norm(grads) is NaN, which is trivially satisfied because grad_weight is all-NaN whenever any invalid row exists. Now explicitly checks finite grad_hidden on the valid row, NaN on the invalid row, and zero on the ignore_index row.
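To see why the norm-only assertion was too weak, consider a minimal sketch with made-up gradient values. The names `l2_norm` and `classify_rows` are hypothetical helpers (the real test uses `_stable_norm`), and plain lists stand in for `mx.array`.

```python
import math

def l2_norm(tensors):
    """L2 norm over every element of every tensor (each a list of rows)."""
    return math.sqrt(sum(x * x for t in tensors for row in t for x in row))

def classify_rows(rows):
    """Label each row as 'nan', 'zero', 'finite', or 'mixed'."""
    kinds = []
    for row in rows:
        if all(math.isnan(x) for x in row):
            kinds.append("nan")
        elif all(x == 0.0 for x in row):
            kinds.append("zero")
        elif all(math.isfinite(x) for x in row):
            kinds.append("finite")
        else:
            kinds.append("mixed")
    return kinds

# grad_weight is NaN everywhere as soon as one invalid row exists.
grad_weight = [[math.nan, math.nan]]
# Rows: valid label, invalid label, ignore_index.
good_grad_hidden = [[0.25, -0.5], [math.nan, math.nan], [0.0, 0.0]]
# A buggy backward pass that left finite garbage on the invalid row.
buggy_grad_hidden = [[0.25, -0.5], [0.1, 0.2], [0.0, 0.0]]
expected = ["finite", "nan", "zero"]
```

The combined norm is NaN for both the good and the buggy gradients, so it cannot tell them apart; only the per-row classification catches the finite values on the invalid row.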
Rank-2 (n, 1) or scalar targets slipped past the length check and crashed deep inside the kernels with backend-specific errors. Add a 1D guard at the top of _runtime_cce_loss_and_aux so callers get a clear ValueError.
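The two upfront guards described in these commits (targets must be 1D, and its length must equal the number of hidden rows) might look roughly like the sketch below. `check_cce_shapes` is a hypothetical helper operating on shape tuples; the actual checks live in `_runtime_cce_loss_and_aux` and `_forward_chunked_fused_finalize` and may be worded differently.

```python
def check_cce_shapes(hidden_shape, targets_shape):
    """Raise ValueError early instead of failing deep inside a kernel."""
    if len(targets_shape) != 1:
        # Catches rank-2 (n, 1) targets and scalar () targets alike.
        raise ValueError(f"targets must be 1D, got shape {targets_shape}")
    n = hidden_shape[0]
    if targets_shape[0] != n:
        # Covers n == 0 with non-empty targets, and short or long targets.
        raise ValueError(
            f"targets has {targets_shape[0]} rows but hidden has {n}"
        )
    return n
```

A matching zero-token input such as `check_cce_shapes((0, 8), (0,))` passes and returns 0, while `(4, 8)` against `(4, 1)`, `(3,)`, or `()` raises.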
_fallback_dlogits relies on lse arriving pre-poisoned with NaN for invalid-label rows (targets outside [0, vocab_size) and not equal to ignore_index). It does not re-check vocab bounds; NaN propagation through exp(capped - NaN) = NaN is what produces NaN gradients for those rows. Add a comment so future callers cannot accidentally pass unpoisoned lse and silently get finite garbage gradients.
The Metal forward-finalize kernel writes finite lse_out and loss_out for every row, including out-of-vocab targets it cannot classify because vocab_size is not a kernel input. The dlogits backward kernel relies on the Python caller having applied _poison_invalid_targets to both lse and loss before backward is invoked. Make that invariant explicit at the kernel definition site so future fused or multi-GPU paths that pass raw lse_out to the dlogits kernel will not silently produce finite wrong gradients for invalid labels.
Auto-review verdict: Approved
Hardens unsloth_zoo/mlx/cce/runtime_cce.py against (a) zero-token inputs that crash the Metal kernel via an empty grid and (b) out-of-vocab labels that previously produced finite-garbage losses and gradients; adds upfront shape validation, NaN poisoning of loss/lse for invalid labels in both fallback and Metal-finalize paths, an isnan guard in the dlogits backward kernel, and 7+ tests covering eager, compiled, and quantized paths.
Reason: All real review findings fixed; no remaining bugs; PR correctly addresses two concrete bugs in MLX CCE.
Wider integer labels (int64 / uint32) outside the int32 range could wrap through targets.astype(mx.int32) into a valid class id or the ignore_index, sneaking past the new invalid-label NaN poisoning. Compute the validity masks from the original dtype first, then narrow for the kernels and indexing. Add a parametrized regression covering 2**32 plus/minus offsets and their negatives.
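The wrap-around hazard is easy to reproduce with plain integer arithmetic. In this sketch `narrow_to_int32` emulates a two's-complement int32 cast and `classify` is a hypothetical stand-in for the validity masks; the vocabulary size is an arbitrary illustrative value.

```python
def narrow_to_int32(value):
    """Two's-complement wrap, as an int32 cast of a wider integer behaves."""
    return (value + 2**31) % 2**32 - 2**31

def classify(label, vocab_size, ignore_index=-100):
    """Classify one label as 'ignored', 'valid', or 'invalid'."""
    if label == ignore_index:
        return "ignored"
    if 0 <= label < vocab_size:
        return "valid"
    return "invalid"

VOCAB = 32000  # illustrative vocabulary size
wide_labels = [2**32 + 5, 2**32 - 100, -(2**32) + 7]

# Validate on the original values first: all three are out of vocab.
before = [classify(x, VOCAB) for x in wide_labels]
# Validate after narrowing: the same labels now look legitimate.
after = [classify(narrow_to_int32(x), VOCAB) for x in wide_labels]
```

Here `before` is invalid for every label, while `after` reports a valid class id, an ignored row, and another valid class id, which is why the masks have to be computed from the original dtype.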
…nslothai#682 A wide invalid label that narrows to ignore_index after astype(mx.int32) (e.g. 2**32 - 100 -> -100 with ignore_index=-100) used to produce a NaN forward loss but a zero backward gradient. Move the NaN-lse guard before the ignore_index check in both the Metal dlogits kernel and the Python fallback so the gradient also poisons. Refuse the implicit zeroing path on rows whose lse is NaN. Also drop the .astype(mx.int32) pre-cast from the four CCE wrapper call sites (text dense / quantized; VLM dense / quantized). Casting before the runtime narrowed wide labels to valid class ids or ignore_index, slipping past the new original-dtype validation in _forward_chunked_fused_finalize. The runtime now performs that narrow internally only after the validity masks are built from the original dtype. New regression: test_runtime_cce_int64_wrap_to_ignore_index_poisons_gradients.
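The guard ordering matters because a poisoned row may also look like an ignored row after narrowing. A rough single-row sketch of the fallback gradient, under the assumption that `lse` arrives pre-poisoned, is shown below; `dlogits_row` is a hypothetical name and omits softcapping and scaling.

```python
import math

def dlogits_row(logits_row, lse, target, ignore_index=-100):
    """Gradient of the loss w.r.t. one row of logits (softmax minus one-hot)."""
    # NaN-LSE guard comes first so a poisoned row can never be zeroed out.
    if math.isnan(lse):
        return [math.nan] * len(logits_row)
    if target == ignore_index:
        return [0.0] * len(logits_row)
    grad = [math.exp(x - lse) for x in logits_row]
    grad[target] -= 1.0
    return grad

row = [1.0, 2.0]
finite_lse = math.log(math.exp(1.0) + math.exp(2.0))

poisoned = dlogits_row(row, math.nan, -100)  # wide label that wrapped to -100
ignored = dlogits_row(row, finite_lse, -100)
valid = dlogits_row(row, finite_lse, 1)
```

With the checks in the opposite order, `poisoned` would come back as zeros, matching the mismatch the commit describes (NaN forward loss, zero backward gradient).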
…hai#682 The zero-token guard previously returned `return empty, empty`, aliasing the same mx.array as both loss and lse. Today's downstream wrapper only does `losses + lse * 0.0` so aliasing is harmless, but any future caller that mutates one slot or relies on independent identity (e.g. VJP cotangent indexing) would surprise. Return two distinct mx.zeros((0,)) allocations instead. One-line change, no behavior change for current callers.
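The aliasing concern can be illustrated with ordinary Python lists, which here only stand in for the two returned arrays; both function names are hypothetical.

```python
def zero_token_outputs_aliased():
    """Return the same object twice, as `return empty, empty` did."""
    empty = []
    return empty, empty

def zero_token_outputs_distinct():
    """Return two independent objects for loss and lse."""
    return [], []

loss_a, lse_a = zero_token_outputs_aliased()
loss_a.append(1.0)  # a mutation of one slot shows up in the other

loss_d, lse_d = zero_token_outputs_distinct()
loss_d.append(1.0)  # the other slot is unaffected
```

After the mutation `lse_a` also contains `1.0`, whereas `lse_d` is still empty.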
…hai#682 The runtime CCE now validates label dtype/range before its own int32 narrow so wide invalid labels (e.g. 2**32 - 100 wrapping to -100 after narrow) get NaN-poisoned instead of silently treated as ignore_index. The direct call-site .astype(mx.int32) pre-casts in utils.py CCE wrappers were removed earlier for the same reason. Two upstream sites still pre-narrowed: _to_mx_vlm_batch (utils.py:2168) and _collate_vlm_prompt_completion_batch's final labels cast (utils.py:2237). With those still narrowing, wide invalid VLM labels silently became -100 before reaching the validity check, regressing the PR's fix on the VLM path. Drop the narrows; the runtime CCE owns narrowing now.
…nslothai#682 _expand_image_token_sequences (utils.py:928) and _expand_token_runs (utils.py:1000) allocated padded_labels as np.int32. When a response-masked VLM batch contains a wide invalid label (e.g. 2**32 - 100 from a tokenizer that emits int64), the int32 allocation OverflowErrors before the runtime CCE validity check has a chance to reject the label. Allocate padded_labels as np.int64 so wide values reach the CCE runtime intact. The runtime narrows after validating, so we no longer lose the dtype/range information that triggers NaN-poisoning.
…efore validation for PR unslothai#682
Round-10 review caught two asymmetric-fix paths the previous int64 widening missed:
* `_collate_vlm_prompt_completion_batch` still built `labels_np` with `dtype=np.int32` after `_to_mx_vlm_batch()` narrowed the model inputs. A wide-int label (e.g. `2**32 - 100` under a uint32 tokenizer) was discarded before reaching runtime CCE's validity classifier. Now build labels from `combined_inputs["input_ids"]` as `np.int64` before the batch narrow, mirroring the rest of the PR's "validate before int32 narrow" contract.
* `_target_validity_masks` compared `targets >= 0` directly, which the torch-backed MLX simulation rejects for `uint16`/`uint32`/`uint64` ("ge_cpu" not implemented for UInt32). Cast unsigned dtypes to int64 for the comparison so wide-int wrap labels classify as out-of-vocab instead of crashing the validity step. Defensive no-op on real Apple-Metal MLX.
…slothai#682
Round-11 review caught three remaining asymmetric-fix surfaces left open by the R10 unsigned-label fixes:
* `_target_validity_masks` now widens `mx.uint64` via float64 (or float32 if the build lacks float64) before the validity comparisons. int64 would wrap large unsigned values like `2**64 - 100` around into negative ints that could collide with `ignore_index=-100`; float losslessly covers the meaningful 0..2**53 range and saturates above that into the out-of-vocab bucket. Bypasses the `ignore_index != target` check entirely when ignore_index is negative since unsigned values cannot reach it anyway.
* `_to_mx_vlm_batch` now widens unsigned `labels` (uint8/16/32/64) to signed int64 via the shared `_normalize_cce_label_dtype` helper. The PR previously stopped narrowing labels to mx.int32 so runtime CCE could classify wide invalids, but VLM masking helpers downstream inject `-100` via `mx.where`, which crashes on unsigned tensors under the torch-backed MLX simulation ("where_cpu" not implemented for UInt32). int64 keeps wide-range visibility AND lets masking inject the signed sentinel cleanly.
* `_mask_image_tokens` and `_mask_prompt_tokens` apply the same normalization defensively and use `mx.array(-100, dtype=targets.dtype)` so any caller-shaped target dtype (signed int8/16, float, etc.) gets a matched-dtype sentinel instead of relying on implicit Python-int broadcasting that some MLX shims reject.
… text CCE branches for PR unslothai#682
Round-15 review caught a regression I introduced in R11 plus two asymmetric-fix gaps the prior rounds missed:
* `_normalize_cce_label_dtype` cast `mx.uint64` straight through `astype(mx.int64)`. Values >= 2**63 wrap to negatives; specifically `2**64 - 100` lands on exactly `-100` and silently collides with `ignore_index`, defeating the runtime CCE validity check the rest of the PR is built around. Detect uint64 explicitly and route any value above `(1 << 63) - 1` to a known out-of-vocab sentinel via `mx.where`; uint8/16/32 still fit losslessly so a direct cast is fine.
* Both quantized and dense text CCE branches in `make_chunked_cross_entropy_loss` still injected the `-100` sentinel via raw `mx.where(mask, targets, -100)`. The VLM masking helpers were R11-normalized to widen unsigned labels first; mirror the same contract in the text path so unsigned tokenizer labels don't crash the masking step on the torch-backed shim and don't bypass runtime CCE's validity classifier for wide invalids.
* `_collate_vlm_prompt_completion_batch` built `labels_np` via `np.asarray(..., dtype=np.int64)`. The same uint64 wrap issue applies at the numpy layer: `np.uint64(2**64 - 100)` casts to int64 -100, hitting ignore_index before the runtime ever sees the value. Apply the same saturating-to-sentinel logic in numpy when the source dtype is uint64; non-uint64 paths still use the lossless int64 cast.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 483ffd3af6
    _validation_dtype = getattr(mx, "float64", mx.float32)
    targets_for_validation = targets.astype(_validation_dtype)
    in_vocab = (targets_for_validation >= 0.0) & (
Avoid float64 casts in uint64 label validation path
When targets.dtype == mx.uint64, this branch promotes labels to mx.float64 before computing in_vocab. On MLX, float64 arrays are CPU-only, so with the default GPU stream this can raise at runtime before the CCE kernels run. That means uint64-labeled batches (the exact edge case this patch tries to support) can fail outright on real MLX devices instead of being classified and NaN-poisoned.
…-label / np-expand / runtime float64 paths for PR unslothai#682
Fixes the four asymmetric-fix gaps flagged in round 16 review:
1. unsloth_zoo/mlx/utils.py: _normalize_cce_label_dtype previously did `mx.where(labels > max_i63, sentinel, labels_i64)` on raw uint64, which crashed the torch-backed MLX shim with `NotImplementedError: "gt_cpu" not implemented for 'UInt64'`. Cast to int64 first and detect overflow with `labels_i64 < 0`; any uint64 value above 2**63-1 wraps negative after the cast and is routed to a positive out-of-vocab sentinel so runtime CCE NaN-poisons the row instead of colliding with ignore_index.
2. unsloth_zoo/mlx/utils.py: _vlm_cce_forward(labels=None) used to apply `mx.where(length_mask, targets, -100)` directly on unsigned input_ids. Normalize through _normalize_cce_label_dtype first and inject the ignore sentinel via `mx.array(-100, dtype=targets.dtype)` so the path matches every other masking site.
3. unsloth_zoo/mlx/utils.py: add _normalize_numpy_cce_labels helper and apply it inside _expand_image_token_sequences and _expand_token_runs so wide np.uint64 labels (e.g. 2**64-100) do not OverflowError when `.tolist()` values are packed into the int64 padded label buffer.
4. unsloth_zoo/mlx/cce/runtime_cce.py: _target_validity_masks no longer requests `getattr(mx, "float64", mx.float32)` for uint64 validation. MLX float64 arrays are CPU-only per the public docs, so the prior branch would have broken on real Apple-Silicon/Metal even though it sidestepped the simulation crash. float32 is sufficient to separate normal vocabulary labels from uint64 overflow sentinels for the validity check.
…k labels for PR unslothai#682
Fixes four asymmetric-fix gaps from round 17 review:
1. unsloth_zoo/mlx/utils.py + unsloth_zoo/mlx/cce/runtime_cce.py: collators that pre-encode ignore_index as uint64(2**64-100) wrap to exactly int64(-100) after the cast. Round 16's overflow routing then sent every wrapped-negative value (including encoded ignore) to the invalid out-of-vocab sentinel and runtime CCE NaN-poisoned the row. Detect `labels_i64 == -100` (and the numpy analogue `labels_np == np.uint64(2**64-100)`) BEFORE routing overflow to the sentinel so legitimate ignore rows continue to be skipped downstream. Also drop the float32-based uint64 validation in `_target_validity_masks` (float32 cannot exactly represent labels above 2**24, so labels near vocab boundaries above that were misclassified). Validate through signed int64 with explicit overflow detection instead. Comment in the helper updated to document the actual sentinel contract (1 << 62, not iinfo max).
2. unsloth_zoo/mlx/utils.py: `_apply_response_mask_to_vlm_batch` built `mx.array(labels_list)` and injected an untyped `mx.array(-100)` directly. Custom `response_mask_fn` results carrying unsigned label dtypes then crashed the torch-backed MLX shim with NotImplementedError "where_cpu" not implemented for UInt*. Route labels through `_normalize_numpy_cce_labels` + `_normalize_cce_label_dtype` and inject a dtype-matched ignore sentinel.
3. unsloth_zoo/mlx/utils.py: text baseline CE path (`use_cce=False`) reached `mx.where(targets == -100, 0, targets)` directly on raw unsigned labels, leaving the asymmetric-fix gap that runtime CCE already covered. Normalize the labels through `_normalize_cce_label_dtype` and use a dtype-matched ignore scalar in the safe-targets `mx.where`.
4. unsloth_zoo/mlx/utils.py: the VLM baseline CE fallback had the same gap on both the labels-aware and labels=None branches. Mirror the same normalization + dtype-matched ignore sentinel.
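The float32 limitation mentioned here can be checked with the standard library alone. The sketch below round-trips an integer through IEEE 754 single precision; `to_float32` is a hypothetical helper and the vocabulary size is chosen purely to sit on the 2**24 boundary.

```python
import struct

def to_float32(value):
    """Round-trip a number through IEEE 754 single precision."""
    return struct.unpack("f", struct.pack("f", float(value)))[0]

VOCAB = 2**24 + 1      # illustrative vocabulary size just past 2**24
bad_label = VOCAB      # target == vocab_size is out of range

# Exact integer comparison rejects the label.
exact_in_vocab = 0 <= bad_label < VOCAB
# float32 rounds 2**24 + 1 down to 2**24, so the label appears in range.
float32_in_vocab = 0.0 <= to_float32(bad_label) < VOCAB
```

The exact check is false while the float32 check is true, which is the kind of boundary misclassification the commit cites as the reason for validating through signed int64 instead.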
…s-free + response-mask paths for PR unslothai#682
Fixes three asymmetric-fix gaps from round 18 review:
1. unsloth_zoo/mlx/utils.py + unsloth_zoo/mlx/cce/runtime_cce.py: round 17 preserved `uint64(2**64-100)` (which wraps to int64(-100) after cast) as the canonical ignore_index. That was wrong: a raw uint64 value above 2**63-1 is an out-of-range unsigned label, not an intentional encoded ignore. Treat ALL uint64 overflow values as invalid sentinels (1 << 62) so runtime CCE NaN-poisons them instead of silently dropping them as ignored rows. Applies symmetrically in `_normalize_cce_label_dtype`, `_normalize_numpy_cce_labels`, and `_target_validity_masks`.
2. unsloth_zoo/mlx/utils.py: labels-free VLM CCE / baseline paths used to derive targets from `input_ids[:, 1:]` AFTER `_to_mx_vlm_batch()` narrowed `input_ids` to mx.int32. A wide invalid id such as `np.uint32(2**32-100)` then wrapped to -100 in the narrowed view and was silently treated as ignore_index instead of being NaN-poisoned. Preserve the raw (pre-narrow) input_ids under a private `_unsloth_raw_input_ids_for_labels` key in `_to_mx_vlm_batch` when no explicit `labels` are present, and let both `make_vlm_baseline_loss_fn` and `_vlm_cce_forward` consume that raw copy before falling back to the narrowed view.
3. unsloth_zoo/mlx/utils.py: `_apply_response_mask_to_vlm_batch` built labels by calling `.tolist()` on the narrowed `batch_dict["input_ids"]`, so the masking closure never saw wide invalid ids. Consume the same `_unsloth_raw_input_ids_for_labels` carrier first, falling back to the narrowed view only when no raw copy was preserved. Both `_vlm_cce_forward` and `make_vlm_baseline_loss_fn` also strip the private key from the backbone forward kwargs so the model never receives it.
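The final uint64 contract (every value above 2**63-1 becomes the invalid sentinel `1 << 62`, with no special case for the value that wraps to -100) can be sketched on plain integers. Both function names are hypothetical; the real helpers operate on arrays.

```python
INVALID_SENTINEL = 1 << 62  # sentinel value named in the commit message

def uint64_to_int64(value):
    """Reinterpret an unsigned 64-bit value as signed, wrapping on overflow."""
    return value - (1 << 64) if value >= (1 << 63) else value

def normalize_uint64_label(value):
    """Cast to int64, then route any wrapped (negative) result to the sentinel."""
    as_i64 = uint64_to_int64(value)
    return INVALID_SENTINEL if as_i64 < 0 else as_i64

wrapped = uint64_to_int64(2**64 - 100)            # collides with ignore_index
normalized = normalize_uint64_label(2**64 - 100)  # stays visibly invalid
```

Because the sentinel is positive and far larger than any realistic vocabulary, a downstream range check classifies it as out of vocab rather than as an ignored row.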
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: bf9529a3d5
    target_source = batch_dict.get(_RAW_INPUT_IDS_FOR_LABELS, input_ids)
    targets = _normalize_cce_label_dtype(target_source[:, 1:])
Derive VLM targets from expanded input_ids
When labels are absent, this branch always prefers _RAW_INPUT_IDS_FOR_LABELS, but _prepare_vlm_batch_for_compile can expand batch_dict["input_ids"] for multi_modality, phi4-siglip, and phi4mm without updating that raw copy (see the _expand_* rewrites around _prepare_vlm_batch_for_compile). That leaves targets shorter/different than the actual forward sequence, and _align_logits_with_labels then truncates logits to the stale target length, silently dropping supervision for valid trailing tokens. In those model types this can corrupt training loss/gradients rather than just handling invalid-label edge cases.
…nslothai#682 _to_mx_vlm_batch stores the pre-narrow input_ids under _RAW_INPUT_IDS_FOR_LABELS when no labels are present, so labels-free / response-mask loss paths can derive targets from the original processor output before the int32 narrow. For multi_modality, phi4-siglip/phi4_siglip, and phi4mm, _prepare_vlm_batch_for_compile then expands input_ids/attention_mask to a longer sequence (image- or audio-token runs). With labels=None the carrier was left unchanged and went stale: target_source = batch_dict.get(_RAW_INPUT_IDS_FOR_LABELS, input_ids) handed back the shorter pre-expansion ids while hidden states were computed from the expanded sequence, silently misaligning targets against the longer hidden-state sequence (no NaN poisoning, finite losses). Pop the carrier in each expand branch when labels is None so the loss paths fall back to the freshly expanded input_ids via the get(..., input_ids) default. qwen/glm/phi3_v branches do not expand input_ids and keep the carrier intact.
…or PR unslothai#682 R21 popped _RAW_INPUT_IDS_FOR_LABELS in the labels-None expand branches (multi_modality / phi4-siglip / phi4mm) to avoid a shape mismatch against the expanded input_ids. That stopped the silent label/data corruption but threw away the wide-invalid-label detection the carrier was introduced for: a raw np.uint32(2**32 - 100) that narrows to -100 in input_ids.astype(int32) was then quietly treated as ignore_index by the labels-free CCE path instead of being NaN-poisoned. The labels-present branches had the mirror issue: when _collate_vlm_prompt_completion_batch stashes the carrier before adding labels, the carrier was left at pre-expansion length and crashed _apply_response_mask_to_vlm_batch with a shape mismatch the moment a user combined prompt/completion VLM batches with a response_mask_fn. Pass the raw carrier into the existing expand helpers via their labels= slot when labels are absent, then bind the expanded carrier back into the batch under _RAW_INPUT_IDS_FOR_LABELS so the loss path keeps wide-invalid visibility at the original positions. When labels are present the carrier is redundant; pop it so any stale pre-expansion copy cannot reach the response-mask path. Applies symmetrically to multi_modality, phi4-siglip / phi4_siglip, and phi4mm.
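The idea of expanding the raw carrier in lockstep with `input_ids` can be sketched as follows. `expand_runs` is a hypothetical, simplified stand-in for the real expand helpers, and it assumes a label is simply repeated across the run that replaces its token; the token id and run length are made up.

```python
def expand_runs(ids, token_id, run_length, labels=None):
    """Replace each `token_id` with `run_length` copies; expand labels in step."""
    out_ids, out_labels = [], []
    for pos, tok in enumerate(ids):
        repeat = run_length if tok == token_id else 1
        out_ids.extend([tok] * repeat)
        if labels is not None:
            # Assumption: the label at this position is repeated with the run.
            out_labels.extend([labels[pos]] * repeat)
    return out_ids, (out_labels if labels is not None else None)

IMAGE_TOKEN = 99                      # hypothetical placeholder token id
narrowed_ids = [1, IMAGE_TOKEN, -100]         # wide id already wrapped to -100
raw_carrier = [1, IMAGE_TOKEN, 2**32 - 100]   # pre-narrow copy keeps the wide id

# Pass the carrier through the labels slot so it is expanded alongside the ids.
expanded_ids, expanded_carrier = expand_runs(
    narrowed_ids, IMAGE_TOKEN, 3, labels=raw_carrier
)
```

The expanded carrier ends up the same length as the expanded ids and still holds the wide value at its shifted position, so the loss path can reject it.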
Comments-only refactor: drop pure narration and compress multi-line WHY explanations to a single line per intent. No behavior change.
…-export-adapters
# Conflicts:
#   unsloth_zoo/mlx/loader.py
#   unsloth_zoo/mlx/trainer.py
#   unsloth_zoo/mlx/utils.py
Merges 5 main-side mlx fixes (unslothai#673 zero-token CCE, unslothai#679 + unslothai#692 LoRA save metadata, unslothai#682 invalid label NaN-poisoning, unslothai#688 tool mask). All 13 conflict regions in unsloth_zoo/mlx/utils.py resolved to keep PR unslothai#684's behavior where it conflicts on semantics:
- half-open `<` length mask (PR unslothai#684 fix) wins over main's inclusive `<=`
- `if labels is None` branch preserved (PR unslothai#684 generality) alongside main's `_normalize_cce_label_dtype` dtype widening
- `_get_image_token_ids` legacy wrapper kept alongside main's new `_normalize_cce_label_dtype` / `_normalize_numpy_cce_labels`
- `_mask_label_token_ids` calls `_normalize_cce_label_dtype` first so image masking honors main's uint-widening contract
- HEAD's `_expand_token_replacements` dropped; main's three-function split (`_normalize_numpy_cce_labels` + `_expand_image_token_sequences` + `_expand_token_runs`) is canonical; duplicate HEAD wrappers removed
- `_collate_vlm_prompt_completion_batch` reads back the masked labels in int64 so image + attention masking survives without narrowing
- prompt-completion VLM collator routes through `_apply_vlm_label_masks` after dtype normalisation so ignore_token_ids and wide invalid ids both reach runtime CCE intact
- `_to_mx_vlm_batch` uses main's `_normalize_cce_label_dtype` for labels while keeping PR unslothai#684's token_type_ids / mm_token_type_ids handling
- `_unsloth_*` prefix filter preserved so the new collated_position_ids flag and main's raw-input-ids carrier both get stripped
152 MLX tests pass post-merge.
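The difference between the half-open `<` and inclusive `<=` length masks named in the merge notes is a classic off-by-one. A minimal sketch, with a hypothetical `length_mask` helper and made-up sizes:

```python
def length_mask(seq_len, length, inclusive=False):
    """True for positions that count as real tokens in a padded sequence."""
    if inclusive:
        # Inclusive form: also keeps position == length, one past the end.
        return [pos <= length for pos in range(seq_len)]
    # Half-open form: keeps exactly positions 0 .. length - 1.
    return [pos < length for pos in range(seq_len)]

half_open = length_mask(seq_len=6, length=4)
inclusive = length_mask(seq_len=6, length=4, inclusive=True)
```

For a sequence of four real tokens padded to six, the half-open mask keeps four positions and the inclusive mask keeps five, so the latter lets one padding position contribute to the loss.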
Summary
Hardens the MLX fused CCE runtime against two edge cases:
- zero-token inputs
- target labels outside `[0, vocab_size)` except `ignore_index`
Behavior
Before this change:
- Labels such as `-1` or `target == vocab_size` could produce finite-looking losses instead of surfacing bad input.
Fix
- Invalid labels now yield `NaN` loss/LSE so bad labels cannot silently train with plausible loss values.
- … `mx.compile`.
Validation