Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions vllm_mlx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def serve_command(args):
stream_interval=args.stream_interval if args.continuous_batching else 1,
max_tokens=args.max_tokens,
force_mllm=args.mllm,
repetition_detector=getattr(args, "repetition_detector", False),
)

# Start server
Expand Down Expand Up @@ -807,6 +808,11 @@ def main():
action="store_true",
help="Force load model as multimodal (vision) even if name doesn't match auto-detection patterns",
)
serve_parser.add_argument(
"--repetition-detector",
action="store_true",
help="Detect and stop degenerate repeating token loops during generation",
)
# Generation defaults
serve_parser.add_argument(
"--default-temperature",
Expand Down
27 changes: 24 additions & 3 deletions vllm_mlx/engine/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
trust_remote_code: bool = True,
enable_cache: bool = True,
force_mllm: bool = False,
repetition_detector: bool = False,
):
"""
Initialize the simple engine.
Expand All @@ -41,11 +42,13 @@ def __init__(
trust_remote_code: Whether to trust remote code
enable_cache: Enable VLM cache for multimodal models
force_mllm: Force loading as MLLM even if not auto-detected
repetition_detector: Enable detection of degenerate repeat loops
"""
self._model_name = model_name
self._trust_remote_code = trust_remote_code
self._enable_cache = enable_cache
self._is_mllm = force_mllm or is_mllm_model(model_name)
self._use_repetition_detector = repetition_detector

self._model = None
self._loaded = False
Expand Down Expand Up @@ -186,6 +189,12 @@ async def stream_generate(
completion_tokens = 0
finished = False

rep_det = None
if self._use_repetition_detector:
from ..repetition_detector import RepetitionDetector

rep_det = RepetitionDetector()

for chunk in self._model.stream_generate(
prompt=prompt,
max_tokens=max_tokens,
Expand All @@ -203,9 +212,21 @@ async def stream_generate(
new_text = chunk.text if hasattr(chunk, "text") else str(chunk)
accumulated_text += new_text

finished = (
getattr(chunk, "finished", False) or completion_tokens >= max_tokens
)
# Check for degenerate repetition loops
if rep_det is not None:
token_id = getattr(chunk, "token", hash(new_text) & 0xFFFFFFFF)
if rep_det.check(token_id):
logger.warning(
"Repetition loop detected at token %d, stopping",
completion_tokens,
)
finished = True

if not finished:
finished = (
getattr(chunk, "finished", False)
or completion_tokens >= max_tokens
)
finish_reason = None
if finished:
finish_reason = getattr(chunk, "finish_reason", "stop")
Expand Down
73 changes: 73 additions & 0 deletions vllm_mlx/repetition_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Detect degenerate repeating token patterns during generation."""


class RepetitionDetector:
    """Sliding-window detector for repeating token sequences.

    Every ``check_interval`` tokens, the trailing ``window`` tokens are
    scanned for a pattern of length 2-``max_pattern`` that occurs at
    least ``min_repeats`` times back-to-back at the end of the stream.

    Usage::

        det = RepetitionDetector()
        for token_id in generate():
            if det.check(token_id):
                break  # degenerate loop detected
    """

    def __init__(
        self,
        window: int = 200,
        max_pattern: int = 50,
        min_repeats: int = 3,
        check_interval: int = 20,
    ):
        self.window = window
        self.max_pattern = max_pattern
        self.min_repeats = min_repeats
        self.check_interval = check_interval
        self._tokens: list[int] = []  # trailing window of recorded token ids
        self._count = 0  # total tokens seen; drives the periodic scan

    def check(self, token_id: int) -> bool:
        """Record a token and return True if a repetition loop is detected."""
        self._tokens.append(token_id)
        self._count += 1

        # Trim the buffer down to the sliding window.
        overflow = len(self._tokens) - self.window
        if overflow > 0:
            self._tokens = self._tokens[overflow:]

        # Run the (comparatively costly) scan only every check_interval tokens.
        return self._count % self.check_interval == 0 and self._is_repeating()

    def _is_repeating(self) -> bool:
        toks = self._tokens
        n = len(toks)
        # Shortest detectable loop: a length-2 pattern seen min_repeats times.
        if n < 2 * self.min_repeats:
            return False

        longest = min(self.max_pattern, n // self.min_repeats)
        for length in range(2, longest + 1):
            tail = toks[-length:]
            # Walk backwards in pattern-sized steps, counting exact matches
            # of the tail pattern; stop at the first mismatch.
            seen = 1
            start = n - 2 * length
            while start >= 0 and toks[start : start + length] == tail:
                seen += 1
                if seen >= self.min_repeats:
                    return True
                start -= length

        return False

    def reset(self):
        """Clear state for a new generation."""
        del self._tokens[:]
        self._count = 0
7 changes: 6 additions & 1 deletion vllm_mlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ def load_model(
stream_interval: int = 1,
max_tokens: int = 32768,
force_mllm: bool = False,
repetition_detector: bool = False,
):
"""
Load a model (auto-detects MLLM vs LLM).
Expand Down Expand Up @@ -502,7 +503,11 @@ def load_model(
logger.info(f"Model loaded (batched mode): {model_name}")
else:
logger.info(f"Loading model with SimpleEngine: {model_name}")
_engine = SimpleEngine(model_name=model_name, force_mllm=force_mllm)
_engine = SimpleEngine(
model_name=model_name,
force_mllm=force_mllm,
repetition_detector=repetition_detector,
)
# Start SimpleEngine synchronously (no background loop)
# Use new_event_loop() for Python 3.10+ compatibility (get_event_loop() is deprecated)
loop = asyncio.new_event_loop()
Expand Down
Loading