[tests] fix(mlx): handle zero-token and invalid labels in CCE#27
Closed
danielhanchen wants to merge 14 commits into
Closed
[tests] fix(mlx): handle zero-token and invalid labels in CCE#27danielhanchen wants to merge 14 commits into
danielhanchen wants to merge 14 commits into
Conversation
…nslothai#682 Three follow-ups from review feedback: 1. _poison_invalid_targets used mx.full(values.shape, NaN, ...) which allocates an O(n) tensor on every forward call, even when the invalid mask is all False. Replaced with a scalar NaN broadcast through mx.where so the normal path costs nothing extra. 2. The zero-token forward early-return now raises ValueError when hidden has zero rows but targets is non-empty. The previous version silently returned empty loss/LSE which masked an upstream shape mismatch instead of surfacing it. 3. The compiled invalid-label regression test now parametrizes over bad_target in [-1, vocab_size]. Negative and out-of-range labels take different lookup paths under mx.compile, so single-sided coverage was insufficient. Added a matching test for the new ValueError on hidden=0 with non-empty targets.
…ai#682 The scalar mx.array(float('nan'), ...) broadcast got baked into the Metal kernel as the literal token 'nan', which the Metal C++ tokenizer rejects ('use of undeclared identifier nan'). Restoring the original mx.full(values.shape, ...) form keeps the allocation but preserves Metal-kernel compilation. The two-sided compile-test parametrize and the hidden=0/targets!=0 ValueError defensive check remain.
Extend the existing n==0 shape-mismatch guard in _forward_chunked_fused_finalize so any targets.shape[0] != n is surfaced as a ValueError before chunk planning, fallback indexing, or Metal kernel launch. Short non-empty targets previously broadcast to the n-row loss and lse tensors, silently returning wrong-but-finite losses; longer targets reached kernels that index targets[row] for row in [0, n) and produced cryptic backend errors. The new check follows the PR intent of surfacing upstream shape bugs early.
…#682 The Metal dlogits kernel uses fast::exp(capped - lse[row]). When _poison_invalid_targets sets lse[row] to NaN for invalid labels, fast:: math is not IEEE 754 strict (MSL spec 6.5.1) and may return a finite value, leaving the forward loss NaN but the gradient silently wrong. Add an isnan(lse[row]) guard that emits NaN via 0.0f/0.0f (Metal C++ rejects the literal 'nan' token). Also tighten the invalid-labels gradient test: previously asserted only that _stable_norm(grads) is NaN, which is trivially satisfied because grad_weight is all-NaN whenever any invalid row exists. Now explicitly checks finite grad_hidden on the valid row, NaN on the invalid row, and zero on the ignore_index row.
Rank-2 (n, 1) or scalar targets slipped past the length check and crashed deep inside the kernels with backend-specific errors. Add a 1D guard at the top of _runtime_cce_loss_and_aux so callers get a clear ValueError.
_fallback_dlogits relies on lse arriving pre-poisoned with NaN for invalid-label rows (targets outside [0, vocab_size) and not equal to ignore_index). It does not re-check vocab bounds; NaN propagation through exp(capped - NaN) = NaN is what produces NaN gradients for those rows. Add a comment so future callers cannot accidentally pass unpoisoned lse and silently get finite garbage gradients.
bd4f833 to
0c4aef9
Compare
The Metal forward-finalize kernel writes finite lse_out and loss_out for every row, including out-of-vocab targets it cannot classify because vocab_size is not a kernel input. The dlogits backward kernel relies on the Python caller having applied _poison_invalid_targets to both lse and loss before backward is invoked. Make that invariant explicit at the kernel definition site so future fused or multi-GPU paths that pass raw lse_out to the dlogits kernel will not silently produce finite wrong gradients for invalid labels.
Add backward NaN-propagation gradient tests for the compiled and quantized invalid-label paths, plus PEP 8 blank lines, to test_mlx_runtime_cce_compile.py (gated on real Apple Silicon). Add tests/test_mlx_cce_target_classification.py covering the pure-Python validation paths, in-vocab ignore_index precedence, and logit_softcap interaction with invalid labels under the simulation shim so non-Apple-Silicon CI exercises these branches too.
Owner
Author
|
Fixes pushed to unslothai#682. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Staging mirror of unslothai#682
Original PR: unslothai#682
Author: Lyxot
This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.
Original description
Summary
Hardens the MLX fused CCE runtime against two edge cases:
[0, vocab_size)exceptignore_indexBehavior
Before this change:
-1ortarget == vocab_sizecould produce finite-looking losses instead of surfacing bad input.Fix
NaNloss/LSE so bad labels cannot silently train with plausible loss values.mx.compile.Validation
This PR contains test changes only (1 files). Code changes are in the head PR.
Test files:
tests/test_mlx_runtime_cce_compile.py