Skip to content

[tests] fix(mlx): handle zero-token and invalid labels in CCE#27

Closed
danielhanchen wants to merge 14 commits into
mainfrom
pr-682-tests
Closed

[tests] fix(mlx): handle zero-token and invalid labels in CCE#27
danielhanchen wants to merge 14 commits into
mainfrom
pr-682-tests

Conversation

@danielhanchen

Copy link
Copy Markdown
Owner

Staging mirror of unslothai#682

Original PR: unslothai#682
Author: Lyxot

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

Summary

Hardens the MLX fused CCE runtime against two edge cases:

  • zero-token inputs
  • invalid target labels outside [0, vocab_size) except ignore_index

Behavior

Before this change:

  • zero-token inputs could reach the Metal kernel with an empty grid and crash the process.
  • invalid labels such as -1 or target == vocab_size could produce finite-looking losses instead of surfacing bad input.

Fix

  • Return empty loss/LSE tensors for zero-token forward passes before launching Metal kernels.
  • Return zero gradients for zero-token dense and quantized backward paths.
  • Mark invalid-label rows with NaN loss/LSE so bad labels cannot silently train with plausible loss values.
  • Keep invalid-label handling compile-safe under mx.compile.

Validation

python -m pytest tests/test_mlx_runtime_cce_compile.py tests/test_mlx_baseline_loss_parity.py -q

This PR contains test changes only (1 files). Code changes are in the head PR.

Test files:

  • tests/test_mlx_runtime_cce_compile.py

Lyxot and others added 12 commits May 21, 2026 00:30
…nslothai#682

Three follow-ups from review feedback:

1. _poison_invalid_targets used mx.full(values.shape, NaN, ...) which
   allocates an O(n) tensor on every forward call, even when the invalid
   mask is all False. Replaced with a scalar NaN broadcast through
   mx.where so the normal path costs nothing extra.

2. The zero-token forward early-return now raises ValueError when
   hidden has zero rows but targets is non-empty. The previous version
   silently returned empty loss/LSE which masked an upstream shape
   mismatch instead of surfacing it.

3. The compiled invalid-label regression test now parametrizes over
   bad_target in [-1, vocab_size]. Negative and out-of-range labels
   take different lookup paths under mx.compile, so single-sided
   coverage was insufficient. Added a matching test for the new
   ValueError on hidden=0 with non-empty targets.
…ai#682

The scalar mx.array(float('nan'), ...) broadcast got baked into the
Metal kernel as the literal token 'nan', which the Metal C++
tokenizer rejects ('use of undeclared identifier nan'). Restoring
the original mx.full(values.shape, ...) form keeps the allocation
but preserves Metal-kernel compilation. The two-sided compile-test
parametrize and the hidden=0/targets!=0 ValueError defensive check
remain.
Extend the existing n==0 shape-mismatch guard in
_forward_chunked_fused_finalize so any targets.shape[0] != n is surfaced
as a ValueError before chunk planning, fallback indexing, or Metal
kernel launch. Short non-empty targets previously broadcast to the
n-row loss and lse tensors, silently returning wrong-but-finite losses;
longer targets reached kernels that index targets[row] for row in
[0, n) and produced cryptic backend errors. The new check follows the
PR intent of surfacing upstream shape bugs early.
…#682

The Metal dlogits kernel uses fast::exp(capped - lse[row]). When
_poison_invalid_targets sets lse[row] to NaN for invalid labels,
fast:: math is not IEEE 754 strict (MSL spec 6.5.1) and may return
a finite value, leaving the forward loss NaN but the gradient
silently wrong. Add an isnan(lse[row]) guard that emits NaN via
0.0f/0.0f (Metal C++ rejects the literal 'nan' token).

Also tighten the invalid-labels gradient test: previously asserted
only that _stable_norm(grads) is NaN, which is trivially satisfied
because grad_weight is all-NaN whenever any invalid row exists.
Now explicitly checks finite grad_hidden on the valid row, NaN on
the invalid row, and zero on the ignore_index row.
Rank-2 (n, 1) or scalar targets slipped past the length check and crashed
deep inside the kernels with backend-specific errors. Add a 1D guard at
the top of _runtime_cce_loss_and_aux so callers get a clear ValueError.
_fallback_dlogits relies on lse arriving pre-poisoned with NaN for
invalid-label rows (targets outside [0, vocab_size) and not equal to
ignore_index). It does not re-check vocab bounds; NaN propagation
through exp(capped - NaN) = NaN is what produces NaN gradients for
those rows. Add a comment so future callers cannot accidentally pass
unpoisoned lse and silently get finite garbage gradients.
The Metal forward-finalize kernel writes finite lse_out and loss_out
for every row, including out-of-vocab targets it cannot classify
because vocab_size is not a kernel input. The dlogits backward kernel
relies on the Python caller having applied _poison_invalid_targets to
both lse and loss before backward is invoked. Make that invariant
explicit at the kernel definition site so future fused or multi-GPU
paths that pass raw lse_out to the dlogits kernel will not silently
produce finite wrong gradients for invalid labels.
Add backward NaN-propagation gradient tests for the compiled and
quantized invalid-label paths, plus PEP 8 blank lines, to
test_mlx_runtime_cce_compile.py (gated on real Apple Silicon).

Add tests/test_mlx_cce_target_classification.py covering the
pure-Python validation paths, in-vocab ignore_index precedence, and
logit_softcap interaction with invalid labels under the simulation
shim so non-Apple-Silicon CI exercises these branches too.
@danielhanchen

Copy link
Copy Markdown
Owner Author

Fixes pushed to unslothai#682.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants