Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
91535bc
fix: handle zero-token MLX CCE inputs
Lyxot May 20, 2026
41d5d41
fix: poison invalid MLX CCE labels
Lyxot May 20, 2026
433b794
Merge remote-tracking branch 'origin/main' into fix/mlx-cce-edge-cases
danielhanchen May 24, 2026
51fc49e
Tighten CCE invalid-target guards and broaden compile coverage for PR…
danielhanchen May 24, 2026
037e63b
Revert scalar NaN broadcast in _poison_invalid_targets for PR #682
danielhanchen May 24, 2026
c72840a
Scrub .github/workflows for staging push (matches staging base)
danielhanchen May 24, 2026
5d8a382
Validate target length against hidden token count in MLX CCE
danielhanchen May 24, 2026
5f700d0
Sync .github/workflows with upstream author branch
danielhanchen May 24, 2026
5f1b305
Guard Metal backward kernel against NaN-poisoned lse for PR #682
danielhanchen May 24, 2026
8e32c6c
Reject non-flat targets up front for PR #682
danielhanchen May 24, 2026
69115d8
Scrub .github/workflows for staging push (matches staging base)
danielhanchen May 24, 2026
0c4aef9
Document NaN-lse precondition for MLX CCE fallback dlogits
danielhanchen May 24, 2026
2b3a2cf
Document poison precondition for MLX CCE forward-finalize kernel
danielhanchen May 24, 2026
fb47ee1
Extend MLX CCE edge-case test coverage
danielhanchen May 24, 2026
94fade7
Sync .github/workflows with upstream author branch
danielhanchen May 24, 2026
40cca34
Validate target dtype before narrowing for PR #682
danielhanchen May 24, 2026
2576726
Backward NaN-lse before ignore_index + drop wrapper pre-cast for PR #682
danielhanchen May 25, 2026
8ead4bd
Allocate distinct empty arrays in zero-token CCE return for PR #682
danielhanchen May 25, 2026
e37b0b0
Drop upstream label int32 narrow in VLM batch collators for PR #682
danielhanchen May 25, 2026
ab7065f
Widen VLM label buffers to int64 before runtime CCE validation for PR…
danielhanchen May 25, 2026
6c021dc
Widen prompt/completion VLM labels to int64 + cast unsigned targets b…
danielhanchen May 25, 2026
9c47913
Widen uint64 + normalize unsigned VLM labels before masking for PR #682
danielhanchen May 25, 2026
483ffd3
Preserve uint64 invalidity through normalize + apply normalization in…
danielhanchen May 25, 2026
135da64
Avoid raw uint64 compare in _normalize_cce_label_dtype + cover VLM no…
danielhanchen May 25, 2026
cb3967b
Preserve encoded uint64 ignore + normalize baseline CE + response-mas…
danielhanchen May 25, 2026
bf9529a
Treat uint64 wraparound as invalid + preserve raw input_ids for label…
danielhanchen May 25, 2026
7db36d5
Drop stale _RAW_INPUT_IDS_FOR_LABELS after VLM token expansion for PR…
danielhanchen May 25, 2026
efb032a
Expand raw label carrier alongside input_ids in VLM expand branches f…
danielhanchen May 26, 2026
76ea813
Merge origin/main into fix/mlx-cce-edge-cases
danielhanchen May 27, 2026
09fdfd9
Tighten verbose comments in mlx CCE edge-case code
danielhanchen May 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 197 additions & 0 deletions tests/test_mlx_cce_target_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# Unsloth Zoo - Utilities for Unsloth
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# SPDX-License-Identifier: AGPL-3.0-or-later

# MLX CCE target-classification coverage that runs on non-Apple-Silicon
# hosts via the simulation shim. The companion file
# tests/test_mlx_runtime_cce_compile.py gates on real Metal and skips
# under the shim, leaving the pure-Python validation, in-vocab
# ignore_index precedence, and logit_softcap fallback paths without
# Linux CI coverage. This file fills those gaps.

from __future__ import annotations

import math

import pytest


@pytest.fixture(autouse=True, scope="module")
def _install_mlx_shim():
try:
from mlx_simulation import simulate_mlx_on_torch
except ImportError:
pytest.skip("mlx_simulation suite not on sys.path", allow_module_level=False)
simulate_mlx_on_torch()


def _expected_valid_loss(vocab_size: int) -> float:
# hidden=ones, weight=ones, vocab=V: each logit is dim, lse=log(V)+dim,
# target_logit=dim, so loss = log(V).
return math.log(float(vocab_size))


# ----------------------------------------------------------------------
# Pure-Python validation: shape, length, zero-token mismatch
# ----------------------------------------------------------------------

def test_runtime_cce_zero_tokens_with_non_empty_targets_raises():
import mlx.core as mx

from unsloth_zoo.mlx.cce import make_chunked_cross_entropy_loss

runtime_cce, _ = make_chunked_cross_entropy_loss(
ignore_index=-100,
chunk_size=16,
)
hidden = mx.zeros((0, 16), dtype=mx.float32)
weight = mx.zeros((32, 16), dtype=mx.float32)
targets = mx.array([0, 1, 2], dtype=mx.int32)

with pytest.raises(ValueError, match="hidden has 0 tokens"):
runtime_cce(hidden, weight, targets)


def test_runtime_cce_rejects_non_flat_targets():
import mlx.core as mx

from unsloth_zoo.mlx.cce import make_chunked_cross_entropy_loss

runtime_cce, _ = make_chunked_cross_entropy_loss(
ignore_index=-100,
chunk_size=16,
)
hidden = mx.zeros((4, 16), dtype=mx.float32)
weight = mx.zeros((32, 16), dtype=mx.float32)
targets_2d = mx.zeros((4, 1), dtype=mx.int32)
targets_scalar = mx.array(0, dtype=mx.int32)

with pytest.raises(ValueError, match="flat 1D vector"):
runtime_cce(hidden, weight, targets_2d)
with pytest.raises(ValueError, match="flat 1D vector"):
runtime_cce(hidden, weight, targets_scalar)


def test_runtime_cce_rejects_target_length_mismatch():
import mlx.core as mx

from unsloth_zoo.mlx.cce import make_chunked_cross_entropy_loss

runtime_cce, _ = make_chunked_cross_entropy_loss(
ignore_index=-100,
chunk_size=16,
)
hidden = mx.zeros((4, 16), dtype=mx.float32)
weight = mx.zeros((32, 16), dtype=mx.float32)
targets_wrong_len = mx.zeros((3,), dtype=mx.int32)

with pytest.raises(ValueError, match="targets length does not match"):
runtime_cce(hidden, weight, targets_wrong_len)


# ----------------------------------------------------------------------
# In-vocab ignore_index must take precedence over invalid classification
# ----------------------------------------------------------------------

@pytest.mark.parametrize("ignore_index", [0, 5, 31])
def test_in_vocab_ignore_index_zero_loss_not_nan(ignore_index):
import mlx.core as mx

from unsloth_zoo.mlx.cce import make_chunked_cross_entropy_loss

runtime_cce, _ = make_chunked_cross_entropy_loss(
ignore_index=ignore_index,
chunk_size=16,
)
vocab_size = 32
hidden = mx.ones((3, 16), dtype=mx.float32)
weight = mx.ones((vocab_size, 16), dtype=mx.float32)
valid_other = (ignore_index + 1) % vocab_size
targets = mx.array(
[0 if ignore_index != 0 else 1, ignore_index, valid_other],
dtype=mx.int32,
)

losses = runtime_cce(hidden, weight, targets)
mx.eval(losses)

assert losses[1].item() == pytest.approx(0.0)
assert not math.isnan(losses[1].item())
assert losses[0].item() == pytest.approx(_expected_valid_loss(vocab_size), rel=1e-5)
assert losses[2].item() == pytest.approx(_expected_valid_loss(vocab_size), rel=1e-5)


def test_in_vocab_ignore_index_does_not_poison_other_rows():
import mlx.core as mx

from unsloth_zoo.mlx.cce import make_chunked_cross_entropy_loss

runtime_cce, _ = make_chunked_cross_entropy_loss(
ignore_index=10,
chunk_size=16,
)
vocab_size = 32
hidden = mx.ones((4, 16), dtype=mx.float32)
weight = mx.ones((vocab_size, 16), dtype=mx.float32)
targets = mx.array([3, 10, 33, 7], dtype=mx.int32)

losses = runtime_cce(hidden, weight, targets)
mx.eval(losses)

assert losses[0].item() == pytest.approx(_expected_valid_loss(vocab_size), rel=1e-5)
assert losses[1].item() == pytest.approx(0.0)
assert math.isnan(losses[2].item())
assert losses[3].item() == pytest.approx(_expected_valid_loss(vocab_size), rel=1e-5)


# ----------------------------------------------------------------------
# logit_softcap > 0 must preserve NaN poisoning for invalid labels
# ----------------------------------------------------------------------

@pytest.mark.parametrize("bad_target", [-1, 32])
def test_softcap_invalid_label_poisons_loss(bad_target):
import mlx.core as mx

from unsloth_zoo.mlx.cce import make_chunked_cross_entropy_loss

runtime_cce, _ = make_chunked_cross_entropy_loss(
ignore_index=-100,
chunk_size=16,
logit_softcap=30.0,
)
hidden = mx.ones((3, 16), dtype=mx.float32)
weight = mx.ones((32, 16), dtype=mx.float32)
targets = mx.array([0, bad_target, -100], dtype=mx.int32)

losses = runtime_cce(hidden, weight, targets)
mx.eval(losses)

assert math.isfinite(losses[0].item())
assert math.isnan(losses[1].item())
assert losses[2].item() == pytest.approx(0.0)


def test_softcap_valid_labels_remain_finite():
import mlx.core as mx

from unsloth_zoo.mlx.cce import make_chunked_cross_entropy_loss

runtime_cce, _ = make_chunked_cross_entropy_loss(
ignore_index=-100,
chunk_size=16,
logit_softcap=30.0,
)
hidden = mx.ones((2, 16), dtype=mx.float32)
weight = mx.ones((32, 16), dtype=mx.float32)
targets = mx.array([0, 1], dtype=mx.int32)

losses = runtime_cce(hidden, weight, targets)
mx.eval(losses)

assert math.isfinite(losses[0].item())
assert math.isfinite(losses[1].item())
Loading