Skip to content

[tests] fix(mlx): handle zero-token and invalid labels in CCE#21

Closed
danielhanchen wants to merge 7 commits into
mainfrom
pr-682-tests
Closed

[tests] fix(mlx): handle zero-token and invalid labels in CCE#21
danielhanchen wants to merge 7 commits into
mainfrom
pr-682-tests

Conversation

@danielhanchen

Copy link
Copy Markdown
Owner

Staging mirror of unslothai#682

Original PR: unslothai#682
Author: Lyxot

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

Summary

Hardens the MLX fused CCE runtime against two edge cases:

  • zero-token inputs
  • invalid target labels outside [0, vocab_size) except ignore_index

Behavior

Before this change:

  • zero-token inputs could reach the Metal kernel with an empty grid and crash the process.
  • invalid labels such as -1 or target == vocab_size could produce finite-looking losses instead of surfacing bad input.

Fix

  • Return empty loss/LSE tensors for zero-token forward passes before launching Metal kernels.
  • Return zero gradients for zero-token dense and quantized backward paths.
  • Mark invalid-label rows with NaN loss/LSE so bad labels cannot silently train with plausible loss values.
  • Keep invalid-label handling compile-safe under mx.compile.

Validation

python -m pytest tests/test_mlx_runtime_cce_compile.py tests/test_mlx_baseline_loss_parity.py -q

This PR contains test changes only (1 files). Code changes are in the head PR.

Test files:

  • tests/test_mlx_runtime_cce_compile.py

Lyxot and others added 7 commits May 21, 2026 00:30
…nslothai#682

Three follow-ups from review feedback:

1. _poison_invalid_targets used mx.full(values.shape, NaN, ...) which
   allocates an O(n) tensor on every forward call, even when the invalid
   mask is all False. Replaced with a scalar NaN broadcast through
   mx.where so the normal path costs nothing extra.

2. The zero-token forward early-return now raises ValueError when
   hidden has zero rows but targets is non-empty. The previous version
   silently returned empty loss/LSE which masked an upstream shape
   mismatch instead of surfacing it.

3. The compiled invalid-label regression test now parametrizes over
   bad_target in [-1, vocab_size]. Negative and out-of-range labels
   take different lookup paths under mx.compile, so single-sided
   coverage was insufficient. Added a matching test for the new
   ValueError on hidden=0 with non-empty targets.
…ai#682

The scalar mx.array(float('nan'), ...) broadcast got baked into the
Metal kernel as the literal token 'nan', which the Metal C++
tokenizer rejects ('use of undeclared identifier nan'). Restoring
the original mx.full(values.shape, ...) form keeps the allocation
but preserves Metal-kernel compilation. The two-sided compile-test
parametrize and the hidden=0/targets!=0 ValueError defensive check
remain.
Extend the existing n==0 shape-mismatch guard in
_forward_chunked_fused_finalize so any targets.shape[0] != n is surfaced
as a ValueError before chunk planning, fallback indexing, or Metal
kernel launch. Short non-empty targets previously broadcast to the
n-row loss and lse tensors, silently returning wrong-but-finite losses;
longer targets reached kernels that index targets[row] for row in
[0, n) and produced cryptic backend errors. The new check follows the
PR intent of surfacing upstream shape bugs early.
@danielhanchen

Copy link
Copy Markdown
Owner Author

Fixes pushed to unslothai#682.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants