diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py
index 555b230f2..9dbbd2e05 100644
--- a/vllm_mlx/mllm_scheduler.py
+++ b/vllm_mlx/mllm_scheduler.py
@@ -453,6 +453,27 @@ def _process_batch_responses(
             if request is None:
                 continue

+            # Handle error responses from failed preprocessing
+            if response.finish_reason == "error":
+                output = RequestOutput(
+                    request_id=request_id,
+                    new_token_ids=[],
+                    new_text="",
+                    output_token_ids=[],
+                    prompt_tokens=0,
+                    completion_tokens=0,
+                    finished=True,
+                    finish_reason="error",
+                )
+                request.status = RequestStatus.FINISHED_ABORTED
+                request.output_text = ""
+                request.finish_reason = "error"
+                finished_ids.add(request_id)
+                self.num_requests_processed += 1
+                logger.warning(f"Request {request_id} failed during preprocessing")
+                outputs.append(output)
+                continue
+
             # Append token to request
             request.output_tokens.append(response.token)
             request.num_output_tokens = len(request.output_tokens)
@@ -463,11 +484,10 @@
                 new_text = ""
             else:
                 if request_id not in self._detokenizer_pool:
-                    if hasattr(tokenizer, "detokenizer"):
-                        detok = tokenizer.detokenizer
-                    else:
-                        detok = NaiveStreamingDetokenizer(tokenizer)
-                    detok.reset()
+                    # Always create a new instance per request to avoid shared
+                    # state between concurrent requests (tokenizer.detokenizer
+                    # returns the same object)
+                    detok = NaiveStreamingDetokenizer(tokenizer)
                     self._detokenizer_pool[request_id] = detok
                 detok = self._detokenizer_pool[request_id]
                 detok.add_token(response.token)
@@ -495,7 +515,7 @@
                 finished_ids.add(request_id)

                 # Finalize streaming detokenizer and get full output
-                detok = self._detokenizer_pool.get(request_id)
+                detok = self._detokenizer_pool.pop(request_id, None)
                 if detok is not None:
                     detok.finalize()
                     output.output_text = detok.text
@@ -503,7 +523,6 @@
                     output.output_text = tokenizer.decode(request.output_tokens)
                 request.output_text = output.output_text
                 request.finish_reason = response.finish_reason
-                self._detokenizer_pool.pop(request_id, None)

                 self.total_completion_tokens += request.num_output_tokens
                 self.num_requests_processed += 1
@@ -524,6 +543,9 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
             if request_id in self.running:
                 del self.running[request_id]

+            # Clean up detokenizer to prevent memory leak on abort/timeout
+            self._detokenizer_pool.pop(request_id, None)
+
             # Remove UID mappings
             if request_id in self.request_id_to_uid:
                 uid = self.request_id_to_uid[request_id]
diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py
index ec4684049..7d29400ab 100644
--- a/vllm_mlx/scheduler.py
+++ b/vllm_mlx/scheduler.py
@@ -1094,11 +1094,10 @@ def _decode_tokens(self, token_ids: List[int]) -> str:

     def _get_detokenizer(self, request_id: str) -> Any:
         """Get or create a streaming detokenizer for a request."""
         if request_id not in self._detokenizer_pool:
-            if hasattr(self.tokenizer, "detokenizer"):
-                detok = self.tokenizer.detokenizer
-            else:
-                detok = NaiveStreamingDetokenizer(self._actual_tokenizer)
-            detok.reset()
+            # Always create a new instance per request to avoid shared
+            # state between concurrent requests (tokenizer.detokenizer
+            # returns the same object).
+            detok = NaiveStreamingDetokenizer(self._actual_tokenizer)
             self._detokenizer_pool[request_id] = detok
         return self._detokenizer_pool[request_id]
diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py
index a50883951..eed87b0e9 100644
--- a/vllm_mlx/utils/tokenizer.py
+++ b/vllm_mlx/utils/tokenizer.py
@@ -28,6 +28,56 @@ def _needs_tokenizer_fallback(model_name: str) -> bool:
     return any(pattern.lower() in model_lower for pattern in FALLBACK_MODELS)


+def _needs_strict_false(model_name: str) -> bool:
+    """Check if model needs strict=False loading (VLM models with extra weights).
+
+    VLM models (e.g., Qwen3.5) have vision_tower weights that don't match
+    the text-only model class. Loading with strict=True fails and wastes
+    memory by loading all weights (~100 GB) before raising ValueError.
+    Detect these models up-front to avoid the double-load penalty.
+    """
+    from mlx_lm.utils import _download, load_config
+
+    try:
+        model_path = _download(model_name)
+        config = load_config(model_path)
+    except Exception:
+        return False
+    if "vision_config" in config and "text_config" in config:
+        return True
+    return False
+
+
+def _load_strict_false(model_name: str, tokenizer_config: dict = None):
+    """Load model with strict=False to discard extra weights.
+
+    Handles models with extra parameters that the text-only model class
+    doesn't define (e.g., vision tower weights in VLM models, MTP layers).
+    """
+    import mlx.core as mx
+    from mlx_lm.utils import _download, load_model, load_tokenizer
+
+    model_path = _download(model_name)
+    model, config = load_model(model_path, strict=False)
+
+    from mlx.utils import tree_flatten
+
+    params = tree_flatten(model.parameters())
+    total_params = len(params)
+    zero_params = sum(1 for _, v in params if mx.all(v == 0).item())
+    logger.info(
+        f"[strict=False] Loaded {total_params} parameters, "
+        f"{zero_params} all-zero tensors"
+    )
+
+    tokenizer = load_tokenizer(
+        model_path,
+        tokenizer_config or {},
+        eos_token_ids=config.get("eos_token_id", None),
+    )
+    return model, tokenizer
+
+
 def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
     """
     Load model and tokenizer with fallback for non-standard tokenizers.
@@ -50,6 +100,11 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
         )
         return _load_with_tokenizer_fallback(model_name)

+    # VLM models: skip strict=True attempt to avoid double-loading ~100GB weights
+    if _needs_strict_false(model_name):
+        logger.info(f"Model {model_name} detected as VLM, loading with strict=False")
+        return _load_strict_false(model_name, tokenizer_config)
+
     try:
         model, tokenizer = load(model_name, tokenizer_config=tokenizer_config)
     except ValueError as e:
@@ -59,13 +114,25 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
             return _load_with_tokenizer_fallback(model_name)
         # Fallback for models with extra weights (e.g., vision tower, MTP layers).
         # Retry with strict=False to discard extra weights.
-        if "parameters not in model" in str(e):
+        elif "parameters not in model" in str(e):
             logger.warning(
                 f"Extra parameters found (e.g., vision tower / MTP weights), "
                 f"retrying with strict=False: {e}"
             )
+            # Clear traceback references to free memory from the failed load.
+            # Without this, large models (200GB+) cause OOM during retry.
+            e.__traceback__ = None
+            del e
+            import gc
+
+            gc.collect()
             return _load_strict_false(model_name, tokenizer_config)
-        raise
+        else:
+            raise
+
+    # After successful load, check if MTP weights exist but were stripped
+    _try_inject_mtp_post_load(model, model_name)
+
     return model, tokenizer


 def _load_strict_false(model_name: str, tokenizer_config: dict = None):
@@ -111,7 +178,6 @@ def _try_inject_mtp_post_load(model, model_name):
         return
     with open(config_path) as f:
         config = json.load(f)
-    # Also check text_config for nested configs
     num_mtp = config.get("num_nextn_predict_layers", 0)
     if num_mtp == 0:
         text_config = config.get("text_config", {})
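
For reviewers: a minimal, self-contained sketch of the shared-state bug the detokenizer changes fix. `ToyStreamingDetokenizer` is a hypothetical stand-in that mimics only the `add_token`/`finalize`/`text` surface used above; it is not the project's `NaiveStreamingDetokenizer`, and the token values are arbitrary.

```python
# Hypothetical stand-in with the add_token/finalize/text interface the
# schedulers rely on; a real run would use NaiveStreamingDetokenizer.
class ToyStreamingDetokenizer:
    def __init__(self):
        self._tokens = []

    def add_token(self, token):
        self._tokens.append(token)

    def finalize(self):
        pass

    @property
    def text(self):
        return " ".join(str(t) for t in self._tokens)


def decode_interleaved(make_detok):
    # Tokens from two concurrent requests arrive interleaved, as in one
    # batched scheduler step. Each request must decode only its own tokens.
    pool = {}
    for request_id, token in [("a", 1), ("b", 9), ("a", 2), ("b", 8)]:
        if request_id not in pool:
            pool[request_id] = make_detok()
        pool[request_id].add_token(token)
    return {rid: d.text for rid, d in pool.items()}


shared = ToyStreamingDetokenizer()
print(decode_interleaved(lambda: shared))
# {'a': '1 9 2 8', 'b': '1 9 2 8'}  <- one shared object, streams bleed together
print(decode_interleaved(ToyStreamingDetokenizer))
# {'a': '1 2', 'b': '9 8'}          <- fresh instance per request, correct
```

This is why the old `tokenizer.detokenizer` path was unsafe: it handed every request the same object, so `detok.reset()` or interleaved `add_token` calls from one request corrupted another's stream. Popping the detokenizer on finish and again in `_cleanup_finished` then covers both the normal and the abort/timeout exit paths.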
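Likewise, a small illustration of why the retry path clears `e.__traceback__` before calling `_load_strict_false` again. The 100 MB `bytearray` is a stand-in for partially loaded weights; the mechanism (a live traceback pinning the frames, and thus the locals, of the failed call) is standard CPython behavior, and is fully effective inside the handler on Python 3.11+, where the current exception's traceback lives only on the exception object.

```python
import gc


def load_big():
    buf = bytearray(100 * 1024 * 1024)  # stands in for partially loaded weights
    raise ValueError("parameters not in model")


try:
    load_big()
except ValueError as e:
    # While the exception is alive, its traceback pins every frame of the
    # failed call, including load_big's locals (the 100 MB buffer).
    print("buf" in e.__traceback__.tb_next.tb_frame.f_locals)  # True
    e.__traceback__ = None  # drop the frame chain
    del e                   # drop the exception object itself
    gc.collect()            # reclaim anything held in a reference cycle
    # A retry started here no longer pays for the failed attempt's memory
    # on top of its own allocation.
```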