Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 196 additions & 0 deletions tests/test_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for resumable model download with retry/timeout support."""

import os
from pathlib import Path
from unittest.mock import patch

import pytest

from vllm_mlx.utils.download import (
LLM_ALLOW_PATTERNS,
MLLM_ALLOW_PATTERNS,
DownloadConfig,
ensure_model_downloaded,
)


class TestLocalPath:
"""Tests for local path handling."""

def test_local_path_skips_download(self, tmp_path):
"""Existing local directory is returned without downloading."""
with patch("vllm_mlx.utils.download.snapshot_download") as mock_download:
result = ensure_model_downloaded(str(tmp_path))
mock_download.assert_not_called()
assert result == tmp_path


class TestRetryLogic:
"""Tests for download retry behavior."""

def test_retry_on_failure(self):
"""Failed downloads are retried up to max_retries times."""
config = DownloadConfig(max_retries=3, retry_backoff_base=0.01)
fake_path = "/fake/cache/path"

with patch("vllm_mlx.utils.download.snapshot_download") as mock_download:
mock_download.side_effect = [
ConnectionError("timeout"),
ConnectionError("timeout"),
fake_path,
]
result = ensure_model_downloaded("org/model", config=config)
assert result == Path(fake_path)
assert mock_download.call_count == 3

def test_retry_exhaustion(self):
"""RuntimeError is raised after all retries are exhausted."""
config = DownloadConfig(max_retries=2, retry_backoff_base=0.01)

with patch("vllm_mlx.utils.download.snapshot_download") as mock_download:
mock_download.side_effect = ConnectionError("timeout")
with pytest.raises(RuntimeError, match="Failed to download"):
ensure_model_downloaded("org/model", config=config)
assert mock_download.call_count == 2

def test_keyboard_interrupt_not_retried(self):
"""KeyboardInterrupt propagates immediately without retry."""
config = DownloadConfig(max_retries=3, retry_backoff_base=0.01)

with patch("vllm_mlx.utils.download.snapshot_download") as mock_download:
mock_download.side_effect = KeyboardInterrupt()
with pytest.raises(KeyboardInterrupt):
ensure_model_downloaded("org/model", config=config)
assert mock_download.call_count == 1


class TestOfflineMode:
"""Tests for offline mode behavior."""

def test_offline_mode_cached(self):
"""Offline mode finds cached model successfully."""
config = DownloadConfig(offline=True)
fake_path = "/fake/cache/path"

with patch("vllm_mlx.utils.download.snapshot_download") as mock_download:
mock_download.return_value = fake_path
result = ensure_model_downloaded("org/model", config=config)
assert result == Path(fake_path)
mock_download.assert_called_once_with("org/model", local_files_only=True)

def test_offline_mode_missing(self):
"""Offline mode raises clear error when model is not cached."""
config = DownloadConfig(offline=True)

with patch("vllm_mlx.utils.download.snapshot_download") as mock_download:
mock_download.side_effect = Exception("not found locally")
with pytest.raises(RuntimeError, match="not found in local cache"):
ensure_model_downloaded("org/model", config=config)


class TestTimeout:
"""Tests for download timeout configuration."""

def test_hf_timeout_env_set(self):
"""HF_HUB_DOWNLOAD_TIMEOUT env var is set during download."""
config = DownloadConfig(download_timeout=600, max_retries=1)
fake_path = "/fake/cache/path"
captured_timeout = {}

original_env = os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT")

def capture_env(*args, **kwargs):
captured_timeout["value"] = os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT")
return fake_path

with patch("vllm_mlx.utils.download.snapshot_download") as mock_download:
mock_download.side_effect = capture_env
ensure_model_downloaded("org/model", config=config)

assert captured_timeout["value"] == "600"
# Env var should be restored after download
assert os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT") == original_env

def test_hf_timeout_env_restored_on_failure(self):
"""HF_HUB_DOWNLOAD_TIMEOUT is restored even after failure."""
config = DownloadConfig(
download_timeout=999, max_retries=1, retry_backoff_base=0.01
)
original_env = os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT")

with patch("vllm_mlx.utils.download.snapshot_download") as mock_download:
mock_download.side_effect = ConnectionError("fail")
with pytest.raises(RuntimeError):
ensure_model_downloaded("org/model", config=config)

assert os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT") == original_env


class TestAllowPatterns:
"""Tests for LLM vs MLLM download patterns."""

def test_llm_patterns_used_by_default(self):
"""LLM allow patterns are used when is_mllm=False."""
config = DownloadConfig(max_retries=1)
fake_path = "/fake/cache/path"

with patch("vllm_mlx.utils.download.snapshot_download") as mock_download:
mock_download.return_value = fake_path
ensure_model_downloaded("org/model", config=config, is_mllm=False)
mock_download.assert_called_once_with(
"org/model", allow_patterns=LLM_ALLOW_PATTERNS
)

def test_mllm_patterns_used(self):
"""MLLM allow patterns are used when is_mllm=True."""
config = DownloadConfig(max_retries=1)
fake_path = "/fake/cache/path"

with patch("vllm_mlx.utils.download.snapshot_download") as mock_download:
mock_download.return_value = fake_path
ensure_model_downloaded("org/model", config=config, is_mllm=True)
mock_download.assert_called_once_with(
"org/model", allow_patterns=MLLM_ALLOW_PATTERNS
)


class TestCLIDownloadCommand:
"""Tests for CLI download subcommand argument parsing."""

def test_cli_download_command(self):
"""Download subcommand parses arguments correctly."""
import argparse

# We test argparse by calling parse_args directly
# (main() would try to actually run the command)
with patch("sys.argv", ["vllm-mlx", "download", "org/model"]):
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="command")
download_parser = subparsers.add_parser("download")
download_parser.add_argument("model")
download_parser.add_argument("--timeout", type=int, default=300)
download_parser.add_argument("--retries", type=int, default=3)
download_parser.add_argument("--mllm", action="store_true")

args = parser.parse_args(["download", "org/model", "--timeout", "600"])
assert args.command == "download"
assert args.model == "org/model"
assert args.timeout == 600
assert args.retries == 3
assert args.mllm is False

def test_cli_download_mllm_flag(self):
"""Download subcommand parses --mllm flag."""
import argparse

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="command")
download_parser = subparsers.add_parser("download")
download_parser.add_argument("model")
download_parser.add_argument("--timeout", type=int, default=300)
download_parser.add_argument("--retries", type=int, default=3)
download_parser.add_argument("--mllm", action="store_true")

args = parser.parse_args(["download", "org/vl-model", "--mllm"])
assert args.mllm is True
74 changes: 74 additions & 0 deletions vllm_mlx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,21 @@ def serve_command(args):
print(" Reasoning: Use --reasoning-parser to enable")
print("=" * 60)

# Pre-download model with retry/timeout
from .api.utils import is_mllm_model
from .utils.download import DownloadConfig, ensure_model_downloaded

download_config = DownloadConfig(
download_timeout=args.download_timeout,
max_retries=args.download_retries,
offline=getattr(args, "offline", False),
)
ensure_model_downloaded(
args.model,
config=download_config,
is_mllm=is_mllm_model(args.model),
)

print(f"Loading model: {args.model}")
print(f"Default max tokens: {args.max_tokens}")

Expand Down Expand Up @@ -194,6 +209,23 @@ def serve_command(args):
uvicorn.run(app, host=args.host, port=args.port, log_level="info")


def download_command(args):
"""Download a model to local cache without starting a server."""
from .utils.download import DownloadConfig, ensure_model_downloaded

config = DownloadConfig(
download_timeout=args.timeout,
max_retries=args.retries,
)
print(f"Downloading model: {args.model}")
path = ensure_model_downloaded(
args.model,
config=config,
is_mllm=args.mllm,
)
print(f"Model ready at: {path}")


def bench_command(args):
"""Run benchmark."""
import asyncio
Expand Down Expand Up @@ -827,6 +859,24 @@ def main():
default=None,
help="Pre-load an embedding model at startup (e.g. mlx-community/embeddinggemma-300m-6bit)",
)
# Download options
serve_parser.add_argument(
"--download-timeout",
type=int,
default=300,
help="Per-file download timeout in seconds (default: 300)",
)
serve_parser.add_argument(
"--download-retries",
type=int,
default=3,
help="Number of download retry attempts (default: 3)",
)
serve_parser.add_argument(
"--offline",
action="store_true",
help="Offline mode — only use locally cached models",
)
# Bench command
bench_parser = subparsers.add_parser("bench", help="Run benchmark")
bench_parser.add_argument("model", type=str, help="Model to benchmark")
Expand Down Expand Up @@ -962,6 +1012,28 @@ def main():
help="Quantization group size (default: 64)",
)

# Download command
download_parser = subparsers.add_parser(
"download", help="Download a model to local cache without starting a server"
)
download_parser.add_argument("model", type=str, help="Model to download")
download_parser.add_argument(
"--timeout",
type=int,
default=300,
help="Per-file download timeout in seconds (default: 300)",
)
download_parser.add_argument(
"--retries",
type=int,
default=3,
help="Number of retry attempts (default: 3)",
)
download_parser.add_argument(
"--mllm",
action="store_true",
help="Download as multimodal model (broader file patterns)",
)
args = parser.parse_args()

if args.command == "serve":
Expand All @@ -972,6 +1044,8 @@ def main():
bench_detok_command(args)
elif args.command == "bench-kv-cache":
bench_kv_cache_command(args)
elif args.command == "download":
download_command(args)
else:
parser.print_help()
sys.exit(1)
Expand Down
3 changes: 2 additions & 1 deletion vllm_mlx/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""Utility modules for vllm-mlx."""

from .download import DownloadConfig, ensure_model_downloaded
from .tokenizer import load_model_with_fallback

__all__ = ["load_model_with_fallback"]
__all__ = ["DownloadConfig", "ensure_model_downloaded", "load_model_with_fallback"]
Loading
Loading