36 changes: 29 additions & 7 deletions vllm_mlx/mllm_scheduler.py
@@ -453,6 +453,27 @@ def _process_batch_responses(
             if request is None:
                 continue

+            # Handle error responses from failed preprocessing
+            if response.finish_reason == "error":
+                output = RequestOutput(
+                    request_id=request_id,
+                    new_token_ids=[],
+                    new_text="",
+                    output_token_ids=[],
+                    prompt_tokens=0,
+                    completion_tokens=0,
+                    finished=True,
+                    finish_reason="error",
+                )
+                request.status = RequestStatus.FINISHED_ABORTED
+                request.output_text = ""
+                request.finish_reason = "error"
+                finished_ids.add(request_id)
+                self.num_requests_processed += 1
+                logger.warning(f"Request {request_id} failed during preprocessing")
+                outputs.append(output)
+                continue
+
             # Append token to request
             request.output_tokens.append(response.token)
             request.num_output_tokens = len(request.output_tokens)
@@ -463,11 +484,10 @@ def _process_batch_responses(
                 new_text = ""
             else:
                 if request_id not in self._detokenizer_pool:
-                    if hasattr(tokenizer, "detokenizer"):
-                        detok = tokenizer.detokenizer
-                    else:
-                        detok = NaiveStreamingDetokenizer(tokenizer)
-                    detok.reset()
+                    # Always create a new instance per request to avoid shared
+                    # state between concurrent requests (tokenizer.detokenizer
+                    # returns the same object)
+                    detok = NaiveStreamingDetokenizer(tokenizer)
                     self._detokenizer_pool[request_id] = detok
                 detok = self._detokenizer_pool[request_id]
                 detok.add_token(response.token)
@@ -495,15 +515,14 @@ def _process_batch_responses(
                 finished_ids.add(request_id)

                 # Finalize streaming detokenizer and get full output
-                detok = self._detokenizer_pool.get(request_id)
+                detok = self._detokenizer_pool.pop(request_id, None)
                 if detok is not None:
                     detok.finalize()
                     output.output_text = detok.text
                 else:
                     output.output_text = tokenizer.decode(request.output_tokens)
                 request.output_text = output.output_text
                 request.finish_reason = response.finish_reason
-                self._detokenizer_pool.pop(request_id, None)

                 self.total_completion_tokens += request.num_output_tokens
                 self.num_requests_processed += 1
@@ -524,6 +543,9 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
             if request_id in self.running:
                 del self.running[request_id]

+            # Clean up detokenizer to prevent memory leak on abort/timeout
+            self._detokenizer_pool.pop(request_id, None)
+
             # Remove UID mappings
             if request_id in self.request_id_to_uid:
                 uid = self.request_id_to_uid[request_id]
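
Note: the detokenizer change above is easiest to see with a toy model of the old and new behavior. The sketch below is illustrative only; ToyDetokenizer is a hypothetical stand-in for mlx_lm's streaming detokenizer, mirroring just the add_token/.text surface used in this diff. Sharing one detokenizer across concurrent requests interleaves their tokens into a single buffer, while one instance per request_id keeps streams isolated.

# Hypothetical stand-in for the streaming detokenizer; only add_token/.text
# mirror the real object used in this diff.
class ToyDetokenizer:
    def __init__(self):
        self.tokens = []

    def add_token(self, tok):
        self.tokens.append(tok)

    @property
    def text(self):
        return " ".join(self.tokens)

# Old behavior: tokenizer.detokenizer returned one shared object.
shared = ToyDetokenizer()
shared.add_token("Hello")    # token from request A
shared.add_token("Bonjour")  # token from request B lands in A's buffer
print(shared.text)           # "Hello Bonjour": corrupted stream

# New behavior: one instance per request_id in _detokenizer_pool.
pool = {rid: ToyDetokenizer() for rid in ("req-A", "req-B")}
pool["req-A"].add_token("Hello")
pool["req-B"].add_token("Bonjour")
print(pool["req-A"].text)    # "Hello": streams stay isolated
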
9 changes: 4 additions & 5 deletions vllm_mlx/scheduler.py
@@ -1094,11 +1094,10 @@ def _decode_tokens(self, token_ids: List[int]) -> str:
     def _get_detokenizer(self, request_id: str) -> Any:
         """Get or create a streaming detokenizer for a request."""
         if request_id not in self._detokenizer_pool:
-            if hasattr(self.tokenizer, "detokenizer"):
-                detok = self.tokenizer.detokenizer
-            else:
-                detok = NaiveStreamingDetokenizer(self._actual_tokenizer)
-            detok.reset()
+            # Always create a new instance per request to avoid shared
+            # state between concurrent requests (tokenizer.detokenizer
+            # returns the same object).
+            detok = NaiveStreamingDetokenizer(self._actual_tokenizer)
             self._detokenizer_pool[request_id] = detok
         return self._detokenizer_pool[request_id]
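
Note: with this PR, a pool entry now has one creation point and two removal points: pop on normal finish (in _process_batch_responses) and pop in _cleanup_finished for abort/timeout. A minimal sketch of that lifecycle, with hypothetical helper names standing in for the scheduler methods:

detokenizer_pool = {}

def get_detokenizer(request_id, make_detok):
    # Lazy per-request creation; never a shared singleton.
    if request_id not in detokenizer_pool:
        detokenizer_pool[request_id] = make_detok()
    return detokenizer_pool[request_id]

def on_finish(request_id):
    # pop() retrieves and removes in one step, so a finished request
    # cannot leave a stale entry behind (the .get() to .pop() change).
    return detokenizer_pool.pop(request_id, None)

def on_abort_or_timeout(request_id):
    # Requests that never finish cleanly must also drop their entry,
    # or the pool grows without bound (the _cleanup_finished change).
    detokenizer_pool.pop(request_id, None)
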
72 changes: 69 additions & 3 deletions vllm_mlx/utils/tokenizer.py
@@ -28,6 +28,56 @@ def _needs_tokenizer_fallback(model_name: str) -> bool:
     return any(pattern.lower() in model_lower for pattern in FALLBACK_MODELS)


+def _needs_strict_false(model_name: str) -> bool:
+    """Check if model needs strict=False loading (VLM models with extra weights).
+
+    VLM models (e.g., Qwen3.5) have vision_tower weights that don't match
+    the text-only model class. Loading with strict=True fails and wastes
+    memory by loading all weights (~100 GB) before raising ValueError.
+    Detect these models up-front to avoid the double-load penalty.
+    """
+    from mlx_lm.utils import _download, load_config
+
+    try:
+        model_path = _download(model_name)
+        config = load_config(model_path)
+    except Exception:
+        return False
+    if "vision_config" in config and "text_config" in config:
+        return True
+    return False
+
+
+def _load_strict_false(model_name: str, tokenizer_config: dict = None):
+    """Load model with strict=False to discard extra weights.
+
+    Handles models with extra parameters that the text-only model class
+    doesn't define (e.g., vision tower weights in VLM models, MTP layers).
+    """
+    import mlx.core as mx
+    from mlx_lm.utils import _download, load_model, load_tokenizer
+
+    model_path = _download(model_name)
+    model, config = load_model(model_path, strict=False)
+
+    from mlx.utils import tree_flatten
+
+    params = tree_flatten(model.parameters())
+    total_params = len(params)
+    zero_params = sum(1 for _, v in params if mx.all(v == 0).item())
+    logger.info(
+        f"[strict=False] Loaded {total_params} parameters, "
+        f"{zero_params} all-zero tensors"
+    )
+
+    tokenizer = load_tokenizer(
+        model_path,
+        tokenizer_config or {},
+        eos_token_ids=config.get("eos_token_id", None),
+    )
+    return model, tokenizer
+
+
 def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
     """
     Load model and tokenizer with fallback for non-standard tokenizers.
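
Note: the probe added above keys off config.json layout rather than loading any weights. The toy check below shows the shape it relies on; the config dicts are made-up examples, not real model configs. Multimodal checkpoints nest vision_config and text_config sub-dicts, while text-only configs keep a flat layout.

def looks_like_vlm(config: dict) -> bool:
    # Same test as _needs_strict_false, minus the download step.
    return "vision_config" in config and "text_config" in config

vlm_config = {"text_config": {"hidden_size": 4096},
              "vision_config": {"image_size": 448}}
text_only_config = {"hidden_size": 4096, "num_hidden_layers": 32}

assert looks_like_vlm(vlm_config)
assert not looks_like_vlm(text_only_config)
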
@@ -50,6 +100,11 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
         )
         return _load_with_tokenizer_fallback(model_name)

+    # VLM models: skip strict=True attempt to avoid double-loading ~100GB weights
+    if _needs_strict_false(model_name):
+        logger.info(f"Model {model_name} detected as VLM, loading with strict=False")
+        return _load_strict_false(model_name, tokenizer_config)
+
     try:
         model, tokenizer = load(model_name, tokenizer_config=tokenizer_config)
     except ValueError as e:
@@ -59,13 +114,25 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
             return _load_with_tokenizer_fallback(model_name)
         # Fallback for models with extra weights (e.g., vision tower, MTP layers).
         # Retry with strict=False to discard extra weights.
-        if "parameters not in model" in str(e):
+        elif "parameters not in model" in str(e):
             logger.warning(
                 f"Extra parameters found (e.g., vision tower / MTP weights), "
                 f"retrying with strict=False: {e}"
            )
+            # Clear traceback references to free memory from the failed load.
+            # Without this, large models (200GB+) cause OOM during retry.
+            e.__traceback__ = None
+            del e
+            import gc
+
+            gc.collect()
             return _load_strict_false(model_name, tokenizer_config)
-        raise
+        else:
+            raise
+
+    # After successful load, check if MTP weights exist but were stripped
+    _try_inject_mtp_post_load(model, model_name)
     return model, tokenizer


 def _load_strict_false(model_name: str, tokenizer_config: dict = None):
@@ -111,7 +178,6 @@ def _try_inject_mtp_post_load(model, model_name):
         return
     with open(config_path) as f:
         config = json.load(f)
-    # Also check text_config for nested configs
     num_mtp = config.get("num_nextn_predict_layers", 0)
     if num_mtp == 0:
         text_config = config.get("text_config", {})
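
Note: the e.__traceback__ = None / del e / gc.collect() sequence works because a traceback keeps every frame it passed through alive, and those frames keep their locals alive, including a partially materialized weight buffer. A self-contained sketch of the mechanism, where Weights is a stand-in for the large allocation; the final check holds on CPython 3.11+, where the handled exception's traceback is reachable only through the exception object:

import gc
import weakref

class Weights:  # stand-in for ~100 GB of partially loaded arrays
    pass

refs = {}

def failing_load():
    big_buffer = Weights()
    refs["w"] = weakref.ref(big_buffer)  # lets us observe when it is freed
    raise ValueError("parameters not in model")

try:
    failing_load()
except ValueError as e:
    # The retry runs *inside* this handler, so without the next lines the
    # traceback would pin failing_load's frame (and big_buffer) throughout.
    e.__traceback__ = None
    del e
    gc.collect()
    print("weights freed before retry:", refs["w"]() is None)  # True on 3.11+
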