fix(mlx): handle zero-token and invalid labels in CCE#24
fix(mlx): handle zero-token and invalid labels in CCE#24danielhanchen wants to merge 13 commits into
Conversation
…nslothai#682 Three follow-ups from review feedback: 1. _poison_invalid_targets used mx.full(values.shape, NaN, ...) which allocates an O(n) tensor on every forward call, even when the invalid mask is all False. Replaced with a scalar NaN broadcast through mx.where so the normal path costs nothing extra. 2. The zero-token forward early-return now raises ValueError when hidden has zero rows but targets is non-empty. The previous version silently returned empty loss/LSE which masked an upstream shape mismatch instead of surfacing it. 3. The compiled invalid-label regression test now parametrizes over bad_target in [-1, vocab_size]. Negative and out-of-range labels take different lookup paths under mx.compile, so single-sided coverage was insufficient. Added a matching test for the new ValueError on hidden=0 with non-empty targets.
…ai#682 The scalar mx.array(float('nan'), ...) broadcast got baked into the Metal kernel as the literal token 'nan', which the Metal C++ tokenizer rejects ('use of undeclared identifier nan'). Restoring the original mx.full(values.shape, ...) form keeps the allocation but preserves Metal-kernel compilation. The two-sided compile-test parametrize and the hidden=0/targets!=0 ValueError defensive check remain.
Extend the existing n==0 shape-mismatch guard in _forward_chunked_fused_finalize so any targets.shape[0] != n is surfaced as a ValueError before chunk planning, fallback indexing, or Metal kernel launch. Short non-empty targets previously broadcast to the n-row loss and lse tensors, silently returning wrong-but-finite losses; longer targets reached kernels that index targets[row] for row in [0, n) and produced cryptic backend errors. The new check follows the PR intent of surfacing upstream shape bugs early.
…#682 The Metal dlogits kernel uses fast::exp(capped - lse[row]). When _poison_invalid_targets sets lse[row] to NaN for invalid labels, fast:: math is not IEEE 754 strict (MSL spec 6.5.1) and may return a finite value, leaving the forward loss NaN but the gradient silently wrong. Add an isnan(lse[row]) guard that emits NaN via 0.0f/0.0f (Metal C++ rejects the literal 'nan' token). Also tighten the invalid-labels gradient test: previously asserted only that _stable_norm(grads) is NaN, which is trivially satisfied because grad_weight is all-NaN whenever any invalid row exists. Now explicitly checks finite grad_hidden on the valid row, NaN on the invalid row, and zero on the ignore_index row.
Rank-2 (n, 1) or scalar targets slipped past the length check and crashed deep inside the kernels with backend-specific errors. Add a 1D guard at the top of _runtime_cce_loss_and_aux so callers get a clear ValueError.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request enhances the MLX chunked cross-entropy loss implementation by adding input validation for target shapes and robust handling for zero-token inputs and invalid labels. It introduces target validity masks and NaN poisoning to ensure correct error propagation in both CPU and Metal paths, preventing silent failures or crashes. The review feedback suggests centralizing the mask generation logic within the forward pass to improve code maintainability and cleanliness.
|
|
||
| lse = running_max + mx.log(running_sum_exp + 1e-9) | ||
| valid = targets != ignore_index | ||
| valid, invalid = _target_validity_masks(targets, vocab_size, ignore_index) |
There was a problem hiding this comment.
The call to _target_validity_masks is duplicated in both the fallback and Metal paths within _forward_chunked_fused_finalize. While the current implementation is correct and efficient (as it is called outside the main loops), moving this calculation to the beginning of the function (e.g., after the input validation at line 518) would improve code cleanliness and maintainability by centralizing the mask generation logic.
_fallback_dlogits relies on lse arriving pre-poisoned with NaN for invalid-label rows (targets outside [0, vocab_size) and not equal to ignore_index). It does not re-check vocab bounds; NaN propagation through exp(capped - NaN) = NaN is what produces NaN gradients for those rows. Add a comment so future callers cannot accidentally pass unpoisoned lse and silently get finite garbage gradients.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request improves the robustness of the MLX Chunked Cross Entropy (CCE) implementation by adding validation for target shapes and token counts, and handling zero-token inputs in both forward and backward passes. It also introduces NaN poisoning for out-of-vocabulary labels to ensure correct gradient behavior, including specific handling within the Metal kernel. A comprehensive set of tests was added to cover these edge cases for standard, compiled, and quantized versions. I have no feedback to provide as no review comments were submitted.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request enhances the MLX chunked cross-entropy implementation by adding robust input validation and handling for edge cases. Specifically, it introduces checks for target tensor rank and length, handles zero-token inputs to prevent runtime errors, and implements a mechanism to inject NaNs for invalid labels (out-of-vocab) in both the forward and backward passes. Comprehensive tests were added to verify these changes for both standard and quantized linear layers. Review feedback focuses on adopting more idiomatic MLX/NumPy patterns for array property checks and refactoring mask generation to avoid logic duplication.
| vocab_size = weight_compute.shape[0] | ||
| # targets must be a flat 1D vector; rank-2 inputs like (n, 1) would slip | ||
| # past the length check and explode later inside kernels. | ||
| if len(targets.shape) != 1: |
| ) | ||
| if n == 0: | ||
| # surface upstream shape mismatch instead of silently dropping labels. | ||
| if targets.shape[0] != 0: |
| ) | ||
| empty = mx.zeros((0,), dtype=mx.float32) | ||
| return empty, empty | ||
| if targets.shape[0] != n: |
|
|
||
| lse = running_max + mx.log(running_sum_exp + 1e-9) | ||
| valid = targets != ignore_index | ||
| valid, invalid = _target_validity_masks(targets, vocab_size, ignore_index) |
The Metal forward-finalize kernel writes finite lse_out and loss_out for every row, including out-of-vocab targets it cannot classify because vocab_size is not a kernel input. The dlogits backward kernel relies on the Python caller having applied _poison_invalid_targets to both lse and loss before backward is invoked. Make that invariant explicit at the kernel definition site so future fused or multi-GPU paths that pass raw lse_out to the dlogits kernel will not silently produce finite wrong gradients for invalid labels.
|
Fixes pushed to unslothai#682. |
Staging mirror of unslothai#682
Original PR: unslothai#682
Author: Lyxot
This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.
Original description
Summary
Hardens the MLX fused CCE runtime against two edge cases:
[0, vocab_size)exceptignore_indexBehavior
Before this change:
-1ortarget == vocab_sizecould produce finite-looking losses instead of surfacing bad input.Fix
NaNloss/LSE so bad labels cannot silently train with plausible loss values.mx.compile.Validation
This PR tracks the moving review branch (pr-682-head). Iteration fix commits land here directly. Review-added tests are in a separate PR.
Changed files:
.github/workflows/consolidated-tests-ci.yml.github/workflows/lint-ci.yml.github/workflows/mlx-ci.yml.github/workflows/security-audit.yml.github/workflows/stale.yml.github/workflows/studio-export-fix-ci.yml.github/workflows/wheel-smoke.ymlunsloth_zoo/mlx/cce/runtime_cce.pytests/test_mlx_runtime_cce_compile.py