[tests] fix(mlx): handle zero-token and invalid labels in CCE#21
Closed
danielhanchen wants to merge 7 commits into
Closed
[tests] fix(mlx): handle zero-token and invalid labels in CCE#21danielhanchen wants to merge 7 commits into
danielhanchen wants to merge 7 commits into
Conversation
…nslothai#682 Three follow-ups from review feedback: 1. _poison_invalid_targets used mx.full(values.shape, NaN, ...) which allocates an O(n) tensor on every forward call, even when the invalid mask is all False. Replaced with a scalar NaN broadcast through mx.where so the normal path costs nothing extra. 2. The zero-token forward early-return now raises ValueError when hidden has zero rows but targets is non-empty. The previous version silently returned empty loss/LSE which masked an upstream shape mismatch instead of surfacing it. 3. The compiled invalid-label regression test now parametrizes over bad_target in [-1, vocab_size]. Negative and out-of-range labels take different lookup paths under mx.compile, so single-sided coverage was insufficient. Added a matching test for the new ValueError on hidden=0 with non-empty targets.
…ai#682 The scalar mx.array(float('nan'), ...) broadcast got baked into the Metal kernel as the literal token 'nan', which the Metal C++ tokenizer rejects ('use of undeclared identifier nan'). Restoring the original mx.full(values.shape, ...) form keeps the allocation but preserves Metal-kernel compilation. The two-sided compile-test parametrize and the hidden=0/targets!=0 ValueError defensive check remain.
Extend the existing n==0 shape-mismatch guard in _forward_chunked_fused_finalize so any targets.shape[0] != n is surfaced as a ValueError before chunk planning, fallback indexing, or Metal kernel launch. Short non-empty targets previously broadcast to the n-row loss and lse tensors, silently returning wrong-but-finite losses; longer targets reached kernels that index targets[row] for row in [0, n) and produced cryptic backend errors. The new check follows the PR intent of surfacing upstream shape bugs early.
fad1b06 to
5d8a382
Compare
Owner
Author
|
Fixes pushed to unslothai#682. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Staging mirror of unslothai#682
Original PR: unslothai#682
Author: Lyxot
This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.
Original description
Summary
Hardens the MLX fused CCE runtime against two edge cases:
[0, vocab_size)exceptignore_indexBehavior
Before this change:
-1ortarget == vocab_sizecould produce finite-looking losses instead of surfacing bad input.Fix
NaNloss/LSE so bad labels cannot silently train with plausible loss values.mx.compile.Validation
This PR contains test changes only (1 files). Code changes are in the head PR.
Test files:
tests/test_mlx_runtime_cce_compile.py