Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions tests/test_repetition_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
"""
Tests for the repetition detector in scheduler.py.
"""

from vllm_mlx.scheduler import _detect_repetition


class TestDetectRepetition:
    """Unit tests covering the `_detect_repetition` heuristic."""

    def test_no_repetition_normal_tokens(self):
        """A strictly increasing sequence must never be flagged."""
        history = list(range(1, 11))
        assert _detect_repetition(history) is False

    def test_single_token_repetition(self):
        """Eight copies of one token ID is exactly the trigger threshold."""
        assert _detect_repetition([0] * 8) is True

    def test_single_token_not_enough(self):
        """One token short of the default threshold (min_repeat=8) stays clean."""
        assert _detect_repetition([0] * 7) is False

    def test_single_token_at_tail(self):
        """Only the tail matters: a varied prefix does not mask a run."""
        history = [1, 2, 3, 4, 5] + [0] * 8
        assert _detect_repetition(history) is True

    def test_two_token_pattern(self):
        """A 2-token cycle repeated 6 times (12 tokens) is flagged."""
        cycle = [1, 2]
        assert _detect_repetition(cycle * 6) is True

    def test_two_token_not_enough_repeats(self):
        """Five repeats of a 2-token cycle is below the threshold."""
        cycle = [1, 2]
        assert _detect_repetition(cycle * 5) is False

    def test_three_token_pattern(self):
        """A 3-token cycle repeated 6 times (18 tokens) is flagged."""
        cycle = [10, 20, 30]
        assert _detect_repetition(cycle * 6) is True

    def test_four_token_pattern(self):
        """A 4-token cycle repeated 6 times (24 tokens) is flagged."""
        cycle = [1, 2, 3, 4]
        assert _detect_repetition(cycle * 6) is True

    def test_pattern_at_tail(self):
        """A repeating cycle after normal tokens is still detected."""
        history = [100, 200, 300, 400, 500]
        history += [7, 8] * 6
        assert _detect_repetition(history) is True

    def test_empty_tokens(self):
        """An empty buffer can never be degenerate."""
        assert _detect_repetition([]) is False

    def test_short_tokens(self):
        """Buffers shorter than min_repeat are never flagged."""
        assert _detect_repetition([1, 2, 3]) is False

    def test_almost_repetition(self):
        """A run broken by a single different final token is not flagged."""
        history = [0] * 7 + [1]
        assert _detect_repetition(history) is False

    def test_custom_min_repeat(self):
        """The min_repeat keyword lowers / raises the run threshold."""
        run = [5, 5, 5, 5]
        assert _detect_repetition(run, min_repeat=4) is True
        assert _detect_repetition(run, min_repeat=5) is False

    def test_mixed_no_false_positive(self):
        """32 distinct tokens must not trip any heuristic."""
        assert _detect_repetition(list(range(32))) is False

    def test_realistic_degenerate_output(self):
        """Simulate realistic degenerate model output (token ID 0 = padding)."""
        # A few plausible tokens, then the model collapses to padding.
        history = [15234, 8821, 3309, 44, 2847] + [0] * 10
        assert _detect_repetition(history) is True
54 changes: 54 additions & 0 deletions vllm_mlx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,35 @@
]


def _detect_repetition(recent_tokens: list[int], min_repeat: int = 8) -> bool:
"""Detect degenerate repetition in recent token history.

Args:
recent_tokens: Ring buffer of recently generated token IDs.
min_repeat: Minimum number of identical tokens to trigger detection.

Returns:
True if degenerate repetition is detected.
"""
if len(recent_tokens) < min_repeat:
return False
# Check single-token repetition (e.g., "0 0 0 0 0 0 0 0")
tail = recent_tokens[-min_repeat:]
if len(set(tail)) == 1:
return True
# Check short sequence repetition (e.g., "ab ab ab ab ab ab")
for seq_len in (2, 3, 4):
repeats_needed = 6
check_len = seq_len * repeats_needed
if len(recent_tokens) < check_len:
continue
tail = recent_tokens[-check_len:]
pattern = tail[:seq_len]
if all(tail[i] == pattern[i % seq_len] for i in range(check_len)):
return True
return False


class SchedulingPolicy(Enum):
"""Scheduling policy for request ordering."""

Expand Down Expand Up @@ -150,6 +179,9 @@ def _install_chunked_prefill(
# Partial prefill state (None when no prefill in progress)
batch_gen._partial = None

# Repetition detection: track recent tokens per UID (ring buffer of last 32)
_repetition_buffers: Dict[int, list] = {}

# Monkey-patch _process_prompts to capture prompt-only cache state.
# At the point where _process_prompts returns, the Batch cache contains
# the exact prompt-only state: all prompt tokens have been processed
Expand Down Expand Up @@ -201,17 +233,36 @@ def _generation_step(self=batch_gen):
cache_out = None
num_tok += 1
batch.num_tokens[e] = num_tok

# Track recent tokens for repetition detection
buf = _repetition_buffers.get(uid)
if buf is None:
buf = []
_repetition_buffers[uid] = buf
buf.append(t)
# Keep only last 32 tokens (ring buffer)
if len(buf) > 32:
del buf[: len(buf) - 32]

if t in self.stop_tokens:
finish_reason = "stop"
end_idx.append(e)
elif num_tok >= max_tok:
finish_reason = "length"
end_idx.append(e)
elif _detect_repetition(buf):
finish_reason = "stop"
end_idx.append(e)
logger.info(
f"[repetition_detector] uid={uid} stopped after {num_tok} tokens"
)
else:
finish_reason = None
keep_idx.append(e)
if finish_reason is not None:
cache_out = batch.extract_cache(e)
# Clean up repetition buffer on finish
_repetition_buffers.pop(uid, None)
responses.append(
self.Response(uid, t, logprobs[e], finish_reason, cache_out)
)
Expand Down Expand Up @@ -471,6 +522,9 @@ def _patched_remove(uids_to_remove, _self=batch_gen):
)
_self._partial = None
mx.clear_cache() # flush Metal encoders after dropping partial state
# Clean up repetition buffers for removed UIDs
for uid in uids_to_remove:
_repetition_buffers.pop(uid, None)
_orig_remove(uids_to_remove)

batch_gen._next = _chunked_next
Expand Down
Loading