fix(mlx): handle zero-token and invalid labels in CCE #24

Closed
danielhanchen wants to merge 13 commits into main from pr-682-head

@danielhanchen
Owner

Staging mirror of unslothai#682

Original PR: unslothai#682
Author: Lyxot

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

Summary

Hardens the MLX fused CCE runtime against two edge cases:

  • zero-token inputs
  • invalid target labels: values outside [0, vocab_size) other than ignore_index

Behavior

Before this change:

  • zero-token inputs could reach the Metal kernel with an empty grid and crash the process.
  • invalid labels such as -1 or target == vocab_size could produce finite-looking losses instead of surfacing bad input.

Fix

  • Return empty loss/LSE tensors for zero-token forward passes before launching Metal kernels.
  • Return zero gradients for zero-token dense and quantized backward paths.
  • Mark invalid-label rows with NaN loss/LSE so bad labels cannot silently train with plausible loss values.
  • Keep invalid-label handling compile-safe under mx.compile.
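The invalid-label handling in the list above can be sketched in plain Python. This is a minimal illustration, not the PR's code: the function names echo the `_target_validity_masks` and `_poison_invalid_targets` helpers mentioned later in this thread, but their bodies, the list-based inputs, and the default `ignore_index=-100` are assumptions made for the example. The real implementation operates on MLX arrays.

```python
import math


def target_validity_masks(targets, vocab_size, ignore_index=-100):
    """Return (valid, invalid) boolean lists for a flat list of labels.

    valid   -> label is in [0, vocab_size) and is not ignore_index
    invalid -> label is outside [0, vocab_size) and is not ignore_index
    Rows equal to ignore_index are neither valid nor invalid.
    """
    in_range = [0 <= t < vocab_size for t in targets]
    valid = [ok and t != ignore_index for ok, t in zip(in_range, targets)]
    invalid = [(not ok) and t != ignore_index for ok, t in zip(in_range, targets)]
    return valid, invalid


def poison_invalid(values, invalid):
    """Replace entries on invalid rows with NaN; other rows are untouched."""
    return [math.nan if bad else v for v, bad in zip(values, invalid)]


# vocab_size=5, so labels 5 and -1 are out of range; -100 is ignored.
targets = [2, -100, 5, -1]
valid, invalid = target_validity_masks(targets, vocab_size=5)
loss = poison_invalid([0.7, 0.0, 1.3, 0.9], invalid)
```

Poisoning with NaN rather than raising keeps the check free of data-dependent control flow, which is one plausible reading of why the PR calls the approach compile-safe.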

Validation

python -m pytest tests/test_mlx_runtime_cce_compile.py tests/test_mlx_baseline_loss_parity.py -q

This PR tracks the moving review branch (pr-682-head). Iteration fix commits land here directly. Review-added tests are in a separate PR.

Changed files:

  • .github/workflows/consolidated-tests-ci.yml
  • .github/workflows/lint-ci.yml
  • .github/workflows/mlx-ci.yml
  • .github/workflows/security-audit.yml
  • .github/workflows/stale.yml
  • .github/workflows/studio-export-fix-ci.yml
  • .github/workflows/wheel-smoke.yml
  • unsloth_zoo/mlx/cce/runtime_cce.py
  • tests/test_mlx_runtime_cce_compile.py

Lyxot and others added 11 commits May 21, 2026 00:30
…nslothai#682

Three follow-ups from review feedback:

1. _poison_invalid_targets used mx.full(values.shape, NaN, ...) which
   allocates an O(n) tensor on every forward call, even when the invalid
   mask is all False. Replaced with a scalar NaN broadcast through
   mx.where so the normal path costs nothing extra.

2. The zero-token forward early-return now raises ValueError when
   hidden has zero rows but targets is non-empty. The previous version
   silently returned empty loss/LSE which masked an upstream shape
   mismatch instead of surfacing it.

3. The compiled invalid-label regression test now parametrizes over
   bad_target in [-1, vocab_size]. Negative and out-of-range labels
   take different lookup paths under mx.compile, so single-sided
   coverage was insufficient. Added a matching test for the new
   ValueError on hidden=0 with non-empty targets.

…ai#682

The scalar mx.array(float('nan'), ...) broadcast got baked into the
Metal kernel as the literal token 'nan', which the Metal C++
tokenizer rejects ('use of undeclared identifier nan'). Restoring
the original mx.full(values.shape, ...) form keeps the allocation
but preserves Metal-kernel compilation. The two-sided compile-test
parametrize and the hidden=0/targets!=0 ValueError defensive check
remain.

Extend the existing n==0 shape-mismatch guard in
_forward_chunked_fused_finalize so any targets.shape[0] != n is surfaced
as a ValueError before chunk planning, fallback indexing, or Metal
kernel launch. Short non-empty targets previously broadcast to the
n-row loss and lse tensors, silently returning wrong-but-finite losses;
longer targets reached kernels that index targets[row] for row in
[0, n) and produced cryptic backend errors. The new check follows the
PR intent of surfacing upstream shape bugs early.
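The row-count guard described in the commit messages above might look roughly like the following. It is a sketch under stated assumptions: `check_token_shapes` is a hypothetical name, it takes a row count and a plain sequence instead of MLX arrays, and the error text is illustrative.

```python
def check_token_shapes(n_hidden_rows, targets):
    """Validate that targets has exactly one label per hidden row.

    Raises ValueError on any mismatch, including the case where hidden has
    zero rows but targets is non-empty. Returns True when the caller should
    take the zero-token early-return path, False otherwise.
    """
    n_targets = len(targets)
    if n_targets != n_hidden_rows:
        raise ValueError(
            f"targets has {n_targets} rows but hidden has {n_hidden_rows}"
        )
    return n_hidden_rows == 0
```

A caller would run this before chunk planning or any kernel launch, and return empty loss and LSE tensors when it yields True.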

…#682

The Metal dlogits kernel uses fast::exp(capped - lse[row]). When
_poison_invalid_targets sets lse[row] to NaN for invalid labels,
fast:: math is not IEEE 754 strict (MSL spec 6.5.1) and may return
a finite value, leaving the forward loss NaN but the gradient
silently wrong. Add an isnan(lse[row]) guard that emits NaN via
0.0f/0.0f (Metal C++ rejects the literal 'nan' token).

Also tighten the invalid-labels gradient test: previously asserted
only that _stable_norm(grads) is NaN, which is trivially satisfied
because grad_weight is all-NaN whenever any invalid row exists.
Now explicitly checks finite grad_hidden on the valid row, NaN on
the invalid row, and zero on the ignore_index row.
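To make the per-row expectations concrete, here is a plain-Python sketch of a single-row dlogits computation with an explicit NaN guard, followed by the three checks the tightened test describes. `dlogits_row` is a hypothetical helper written for this illustration; it is not the Metal kernel, and the default `ignore_index=-100` is an assumption. In Python, `math.exp(x - nan)` already yields NaN, so the guard is redundant here and only mirrors the kernel-side check.

```python
import math


def dlogits_row(logits, lse, target, ignore_index=-100):
    """Gradient of one row's cross-entropy loss with respect to its logits.

    lse is assumed to arrive already poisoned with NaN for invalid labels.
    """
    if target == ignore_index:
        return [0.0] * len(logits)        # ignored rows contribute nothing
    if math.isnan(lse):
        return [math.nan] * len(logits)   # explicit guard for poisoned rows
    # softmax(logits) minus the one-hot target
    return [math.exp(x - lse) - (1.0 if j == target else 0.0)
            for j, x in enumerate(logits)]


rows = [
    ([0.0, 0.0], math.log(2.0), 0),      # valid label
    ([0.0, 0.0], math.nan, 2),           # invalid label, lse poisoned upstream
    ([0.0, 0.0], math.log(2.0), -100),   # ignore_index
]
grads = [dlogits_row(*row) for row in rows]

assert all(math.isfinite(g) for g in grads[0])  # valid row: finite
assert all(math.isnan(g) for g in grads[1])     # invalid row: NaN
assert grads[2] == [0.0, 0.0]                   # ignored row: zero
```

Checking rows individually matters because a single norm over all gradients is NaN as soon as any one row is NaN, so it cannot distinguish a correct result from one where the valid rows are also corrupted.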

Rank-2 (n, 1) or scalar targets slipped past the length check and crashed
deep inside the kernels with backend-specific errors. Add a 1D guard at
the top of _runtime_cce_loss_and_aux so callers get a clear ValueError.

@danielhanchen
Owner Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request enhances the MLX chunked cross-entropy loss implementation by adding input validation for target shapes and robust handling for zero-token inputs and invalid labels. It introduces target validity masks and NaN poisoning to ensure correct error propagation in both CPU and Metal paths, preventing silent failures or crashes. The review feedback suggests centralizing the mask generation logic within the forward pass to improve code maintainability and cleanliness.


  lse = running_max + mx.log(running_sum_exp + 1e-9)
- valid = targets != ignore_index
+ valid, invalid = _target_validity_masks(targets, vocab_size, ignore_index)

low

The call to _target_validity_masks is duplicated in both the fallback and Metal paths within _forward_chunked_fused_finalize. While the current implementation is correct and efficient (as it is called outside the main loops), moving this calculation to the beginning of the function (e.g., after the input validation at line 518) would improve code cleanliness and maintainability by centralizing the mask generation logic.

_fallback_dlogits relies on lse arriving pre-poisoned with NaN for
invalid-label rows (targets outside [0, vocab_size) and not equal to
ignore_index). It does not re-check vocab bounds; NaN propagation
through exp(capped - NaN) = NaN is what produces NaN gradients for
those rows. Add a comment so future callers cannot accidentally pass
unpoisoned lse and silently get finite garbage gradients.

@danielhanchen
Owner Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request improves the robustness of the MLX Chunked Cross Entropy (CCE) implementation by adding validation for target shapes and token counts, and handling zero-token inputs in both forward and backward passes. It also introduces NaN poisoning for out-of-vocabulary labels to ensure correct gradient behavior, including specific handling within the Metal kernel. A comprehensive set of tests was added to cover these edge cases for standard, compiled, and quantized versions. I have no feedback to provide as no review comments were submitted.

@danielhanchen
Owner Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request enhances the MLX chunked cross-entropy implementation by adding robust input validation and handling for edge cases. Specifically, it introduces checks for target tensor rank and length, handles zero-token inputs to prevent runtime errors, and implements a mechanism to inject NaNs for invalid labels (out-of-vocab) in both the forward and backward passes. Comprehensive tests were added to verify these changes for both standard and quantized linear layers. Review feedback focuses on adopting more idiomatic MLX/NumPy patterns for array property checks and refactoring mask generation to avoid logic duplication.

vocab_size = weight_compute.shape[0]
# targets must be a flat 1D vector; rank-2 inputs like (n, 1) would slip
# past the length check and explode later inside kernels.
if len(targets.shape) != 1:

low

For consistency with MLX and NumPy idioms, it's generally preferred to use .ndim instead of len(.shape) to check the rank of an array.

Suggested change
- if len(targets.shape) != 1:
+ if targets.ndim != 1:

)
if n == 0:
    # surface upstream shape mismatch instead of silently dropping labels.
    if targets.shape[0] != 0:

low

Using len(targets) is slightly more idiomatic than targets.shape[0] for 1D arrays in Python/MLX.

Suggested change
- if targets.shape[0] != 0:
+ if len(targets) != 0:

        )
    empty = mx.zeros((0,), dtype=mx.float32)
    return empty, empty
if targets.shape[0] != n:

low

Using len(targets) is slightly more idiomatic than targets.shape[0] for 1D arrays.

Suggested change
- if targets.shape[0] != n:
+ if len(targets) != n:


  lse = running_max + mx.log(running_sum_exp + 1e-9)
- valid = targets != ignore_index
+ valid, invalid = _target_validity_masks(targets, vocab_size, ignore_index)

low

Consider moving the _target_validity_masks call up before the fallback/fused branches (e.g., after line 518). This would allow you to compute the masks once and reuse them in both paths, making the code slightly cleaner and avoiding duplication of the logic.

The Metal forward-finalize kernel writes finite lse_out and loss_out
for every row, including out-of-vocab targets it cannot classify
because vocab_size is not a kernel input. The dlogits backward kernel
relies on the Python caller having applied _poison_invalid_targets to
both lse and loss before backward is invoked. Make that invariant
explicit at the kernel definition site so future fused or multi-GPU
paths that pass raw lse_out to the dlogits kernel will not silently
produce finite wrong gradients for invalid labels.

@danielhanchen
Owner Author

Fixes pushed to unslothai#682.
