# SPDX-License-Identifier: Apache-2.0
"""
Tests for the repetition detector in scheduler.py.
"""

from vllm_mlx.scheduler import _detect_repetition


class TestDetectRepetition:
    """Tests for _detect_repetition function."""

    def test_no_repetition_normal_tokens(self):
        tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        assert _detect_repetition(tokens) is False

    def test_single_token_repetition(self):
        """8 identical tokens should trigger detection."""
        tokens = [0] * 8
        assert _detect_repetition(tokens) is True

    def test_single_token_not_enough(self):
        """7 identical tokens should NOT trigger (min_repeat=8)."""
        tokens = [0] * 7
        assert _detect_repetition(tokens) is False

    def test_single_token_at_tail(self):
        """Repetition at the tail of a longer sequence."""
        tokens = [1, 2, 3, 4, 5] + [0] * 8
        assert _detect_repetition(tokens) is True

    def test_two_token_pattern(self):
        """Pattern of length 2 repeated 6 times = 12 tokens."""
        tokens = [1, 2] * 6  # [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2]
        assert _detect_repetition(tokens) is True

    def test_two_token_not_enough_repeats(self):
        """Pattern of length 2 repeated only 5 times should NOT trigger."""
        tokens = [1, 2] * 5
        assert _detect_repetition(tokens) is False

    def test_three_token_pattern(self):
        """Pattern of length 3 repeated 6 times = 18 tokens."""
        tokens = [10, 20, 30] * 6
        assert _detect_repetition(tokens) is True

    def test_four_token_pattern(self):
        """Pattern of length 4 repeated 6 times = 24 tokens."""
        tokens = [1, 2, 3, 4] * 6
        assert _detect_repetition(tokens) is True

    def test_pattern_at_tail(self):
        """Pattern repetition at the tail after normal tokens."""
        prefix = [100, 200, 300, 400, 500]
        repeating = [7, 8] * 6
        tokens = prefix + repeating
        assert _detect_repetition(tokens) is True

    def test_empty_tokens(self):
        assert _detect_repetition([]) is False

    def test_short_tokens(self):
        assert _detect_repetition([1, 2, 3]) is False

    def test_almost_repetition(self):
        """Almost repeating but one token is different."""
        tokens = [0, 0, 0, 0, 0, 0, 0, 1]
        assert _detect_repetition(tokens) is False

    def test_custom_min_repeat(self):
        """Custom min_repeat parameter."""
        tokens = [5, 5, 5, 5]
        assert _detect_repetition(tokens, min_repeat=4) is True
        assert _detect_repetition(tokens, min_repeat=5) is False

    def test_mixed_no_false_positive(self):
        """Varied tokens should not trigger."""
        tokens = list(range(32))
        assert _detect_repetition(tokens) is False

    def test_realistic_degenerate_output(self):
        """Simulate realistic degenerate model output (token ID 0 = padding)."""
        # Model starts generating, then degenerates.
        normal = [15234, 8821, 3309, 44, 2847]
        degenerate = [0] * 10
        tokens = normal + degenerate
        assert _detect_repetition(tokens) is True


# NOTE(review): in the original collapsed diff, this span's tail also held the
# header of the vllm_mlx/scheduler.py hunk plus the signature and docstring
# opening of `_detect_repetition`; the function's body continued on the
# following original lines of the patch.
+ """ + if len(recent_tokens) < min_repeat: + return False + # Check single-token repetition (e.g., "0 0 0 0 0 0 0 0") + tail = recent_tokens[-min_repeat:] + if len(set(tail)) == 1: + return True + # Check short sequence repetition (e.g., "ab ab ab ab ab ab") + for seq_len in (2, 3, 4): + repeats_needed = 6 + check_len = seq_len * repeats_needed + if len(recent_tokens) < check_len: + continue + tail = recent_tokens[-check_len:] + pattern = tail[:seq_len] + if all(tail[i] == pattern[i % seq_len] for i in range(check_len)): + return True + return False + + class SchedulingPolicy(Enum): """Scheduling policy for request ordering.""" @@ -150,6 +179,9 @@ def _install_chunked_prefill( # Partial prefill state (None when no prefill in progress) batch_gen._partial = None + # Repetition detection: track recent tokens per UID (ring buffer of last 32) + _repetition_buffers: Dict[int, list] = {} + # Monkey-patch _process_prompts to capture prompt-only cache state. # At the point where _process_prompts returns, the Batch cache contains # the exact prompt-only state: all prompt tokens have been processed @@ -201,17 +233,36 @@ def _generation_step(self=batch_gen): cache_out = None num_tok += 1 batch.num_tokens[e] = num_tok + + # Track recent tokens for repetition detection + buf = _repetition_buffers.get(uid) + if buf is None: + buf = [] + _repetition_buffers[uid] = buf + buf.append(t) + # Keep only last 32 tokens (ring buffer) + if len(buf) > 32: + del buf[: len(buf) - 32] + if t in self.stop_tokens: finish_reason = "stop" end_idx.append(e) elif num_tok >= max_tok: finish_reason = "length" end_idx.append(e) + elif _detect_repetition(buf): + finish_reason = "stop" + end_idx.append(e) + logger.info( + f"[repetition_detector] uid={uid} stopped after {num_tok} tokens" + ) else: finish_reason = None keep_idx.append(e) if finish_reason is not None: cache_out = batch.extract_cache(e) + # Clean up repetition buffer on finish + _repetition_buffers.pop(uid, None) responses.append( 
self.Response(uid, t, logprobs[e], finish_reason, cache_out) ) @@ -471,6 +522,9 @@ def _patched_remove(uids_to_remove, _self=batch_gen): ) _self._partial = None mx.clear_cache() # flush Metal encoders after dropping partial state + # Clean up repetition buffers for removed UIDs + for uid in uids_to_remove: + _repetition_buffers.pop(uid, None) _orig_remove(uids_to_remove) batch_gen._next = _chunked_next