fix(mlx): handle zero-token and invalid labels in CCE #20
…nslothai#682
Three follow-ups from review feedback:
1. _poison_invalid_targets used mx.full(values.shape, NaN, ...) which allocates an O(n) tensor on every forward call, even when the invalid mask is all False. Replaced with a scalar NaN broadcast through mx.where so the normal path costs nothing extra.
2. The zero-token forward early-return now raises ValueError when hidden has zero rows but targets is non-empty. The previous version silently returned empty loss/LSE which masked an upstream shape mismatch instead of surfacing it.
3. The compiled invalid-label regression test now parametrizes over bad_target in [-1, vocab_size]. Negative and out-of-range labels take different lookup paths under mx.compile, so single-sided coverage was insufficient.
Added a matching test for the new ValueError on hidden=0 with non-empty targets.
…ai#682 The scalar mx.array(float('nan'), ...) broadcast got baked into the Metal kernel as the literal token 'nan', which the Metal C++ tokenizer rejects ('use of undeclared identifier nan'). Restoring the original mx.full(values.shape, ...) form keeps the allocation but preserves Metal-kernel compilation. The two-sided compile-test parametrize and the hidden=0/targets!=0 ValueError defensive check remain.
/gemini review
Code Review
This pull request improves the robustness of the MLX chunked cross-entropy loss implementation by introducing comprehensive handling for edge cases. It adds logic to detect and raise errors for shape mismatches, specifically when hidden tokens are missing but targets are present, and ensures that empty inputs return zeroed gradients. Additionally, it implements a mechanism to identify and 'poison' invalid target labels with NaN values to prevent silent errors. The reviewer feedback suggests a performance optimization for NaN generation in eager mode and recommends adding explicit dimension checks to the input validation logic.
def _poison_invalid_targets(values: mx.array, invalid: mx.array) -> mx.array:
    # mx.full produces a real tensor; a 0-d scalar gets baked into the
    # Metal kernel as the literal token `nan` which the Metal C++
    # tokenizer rejects (use of undeclared identifier 'nan').
    return mx.where(
        invalid,
        mx.full(values.shape, float("nan"), dtype=values.dtype),
        values,
    )
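The selection semantics of this helper can be illustrated without MLX. The following is a minimal pure-Python sketch, not the PR's code: `poison_invalid` is a hypothetical stand-in that works on plain lists instead of `mx.array`, and it only mirrors the element-wise behavior of the `mx.where` call above.

```python
import math

def poison_invalid(values, invalid):
    """Hypothetical list-based stand-in for the mx.where call above."""
    if len(values) != len(invalid):
        raise ValueError("values and invalid must have the same length")
    # Keep each value unless its mask entry is True, in which case emit NaN.
    return [math.nan if bad else v for v, bad in zip(values, invalid)]

losses = [0.25, 1.5, 0.75]
invalid = [False, True, False]
poisoned = poison_invalid(losses, invalid)

# NaN propagates through the reduction, so the bad row cannot hide.
total = sum(poisoned)
```

Because NaN propagates through a later sum or mean, a single poisoned row is enough to make the reduced loss NaN rather than a plausible finite number.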
The use of mx.full here is a clever workaround for the Metal compilation issue with literal nan tokens. However, for better performance in the eager path (non-compiled), you might consider using mx.array(float("nan"), dtype=values.dtype) which allows MLX to handle the broadcasting more efficiently without allocating a full tensor of NaNs upfront. If mx.compile still correctly handles the 0-d tensor as a kernel argument rather than inlining it as a literal, this would be a slight optimization.
if targets.shape[0] != 0:
    raise ValueError(
        "MLX CCE: hidden has 0 tokens but targets is non-empty "
        f"(targets.shape={targets.shape})."
    )
This check is excellent for catching shape mismatches early. However, it assumes that hidden and targets are already flattened or have matching leading dimensions. While this is the standard expectation for this internal runtime, adding a check for hidden.ndim == 2 and targets.ndim == 1 (or ensuring they match) could further harden the API against unexpected input shapes.
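The rank check suggested here could look roughly like the sketch below. It is an illustration only: `check_cce_ranks` and its error messages are hypothetical, and it takes plain shape tuples so it can run without MLX. Whether the runtime should accept only `(n, d)` hidden states and `(n,)` targets is an assumption taken from the comment above, not from the PR's code.

```python
def check_cce_ranks(hidden_shape, targets_shape):
    """Hypothetical helper: reject inputs that are not (n, d) and (n,)."""
    if len(hidden_shape) != 2:
        raise ValueError(f"expected 2-D hidden, got shape {tuple(hidden_shape)}")
    if len(targets_shape) != 1:
        raise ValueError(f"expected 1-D targets, got shape {tuple(targets_shape)}")

# Flattened inputs pass silently.
check_cce_ranks((8, 64), (8,))

# An unflattened (batch, seq, dim) hidden state is rejected.
try:
    check_cce_ranks((2, 4, 64), (8,))
    rank_error = None
except ValueError as exc:
    rank_error = str(exc)
```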
/gemini review
Code Review
This pull request enhances the MLX chunked cross-entropy implementation by adding robust handling for zero-token inputs and invalid target labels. Key changes include the introduction of target validity masking and a mechanism to poison loss and log-sum-exp values with NaN for out-of-vocab targets, ensuring that bad data is not silently ignored. Additionally, the code now explicitly handles zero-token scenarios by raising a ValueError on shape mismatches or returning zero gradients during the backward pass. New tests have been added to cover these edge cases for both standard and quantized versions. I have no feedback to provide.
Extend the existing n==0 shape-mismatch guard in _forward_chunked_fused_finalize so any targets.shape[0] != n is surfaced as a ValueError before chunk planning, fallback indexing, or Metal kernel launch. Short non-empty targets previously broadcast to the n-row loss and lse tensors, silently returning wrong-but-finite losses; longer targets reached kernels that index targets[row] for row in [0, n) and produced cryptic backend errors. The new check follows the PR intent of surfacing upstream shape bugs early.
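The extended guard described in this commit message can be sketched as a single row-count comparison. This is a hedged illustration, not the code in `_forward_chunked_fused_finalize`: `check_cce_rows`, the `raises` helper, and the error text are hypothetical, and the function takes integers in place of `hidden.shape[0]` and `targets.shape[0]`.

```python
def check_cce_rows(n, num_targets):
    """Hypothetical guard: n is hidden.shape[0], num_targets is targets.shape[0]."""
    if num_targets != n:
        raise ValueError(
            f"MLX CCE: hidden has {n} tokens but targets has {num_targets}."
        )

def raises(n, num_targets):
    """Return True when the guard rejects the given pair of row counts."""
    try:
        check_cce_rows(n, num_targets)
    except ValueError:
        return True
    return False

results = {
    "empty hidden, non-empty targets": raises(0, 3),
    "short targets": raises(4, 1),
    "long targets": raises(4, 6),
    "matching": raises(4, 4),
    "both empty": raises(0, 0),
}
```

Comparing against `n` directly covers the original `n == 0` case as well as the short and long target cases, so one check runs before any chunk planning or kernel launch.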
/gemini review
Code Review
This pull request enhances the MLX Chunked Cross Entropy (CCE) implementation by introducing validation for input shapes and handling invalid target labels. Key changes include raising ValueErrors for shape mismatches between hidden states and targets, returning zero gradients for empty inputs, and poisoning loss and LSE values with NaN when target indices are negative or exceed the vocabulary size. These improvements are applied to both standard and quantized paths, with comprehensive new tests verifying the behavior across various execution modes, including mx.compile. I have no feedback to provide as there were no review comments.
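The validity rule summarized above (labels must fall in `[0, vocab_size)` unless they equal `ignore_index`) can be sketched in plain Python. The helper name `invalid_target_mask` is hypothetical, and the default `ignore_index` of `-100` is an assumption borrowed from common cross-entropy conventions; the PR text does not state the value it uses.

```python
IGNORE_INDEX = -100  # assumed value; the PR text does not state it

def invalid_target_mask(targets, vocab_size, ignore_index=IGNORE_INDEX):
    """Hypothetical helper: True where a label is out of range and not ignored."""
    return [
        t != ignore_index and not (0 <= t < vocab_size)
        for t in targets
    ]

# Covers both boundaries the compiled regression test parametrizes over:
# -1 (negative) and vocab_size (one past the last valid label).
bad_mask = invalid_target_mask([0, 9, -1, 10, -100], vocab_size=10)
```

A mask like this is what would feed the NaN-poisoning step, so that ignored positions stay untouched while genuinely bad labels are surfaced.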
Fixes pushed to unslothai#682.
Staging mirror of unslothai#682
Original PR: unslothai#682
Author: Lyxot
This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.
Original description
Summary
Hardens the MLX fused CCE runtime against two edge cases:
[0, vocab_size) except ignore_index
Behavior
Before this change:
-1 or target == vocab_size could produce finite-looking losses instead of surfacing bad input.
Fix
NaN loss/LSE so bad labels cannot silently train with plausible loss values.
mx.compile.
Validation
This PR tracks the moving review branch (pr-682-head). Iteration fix commits land here directly. Review-added tests are in a separate PR.
Changed files:
.github/workflows/consolidated-tests-ci.yml
.github/workflows/lint-ci.yml
.github/workflows/mlx-ci.yml
.github/workflows/security-audit.yml
.github/workflows/stale.yml
.github/workflows/studio-export-fix-ci.yml
.github/workflows/wheel-smoke.yml
unsloth_zoo/mlx/cce/runtime_cce.py
tests/test_mlx_runtime_cce_compile.py