diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 195dfe3..e368ffb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.13" - name: Install dependencies run: | diff --git a/pyproject.toml b/pyproject.toml index 193a22c..53434ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "rapid-mlx" -version = "0.3.12" +version = "0.2.7" description = "Rapid-MLX — AI inference for Apple Silicon. Drop-in OpenAI API, 2-4x faster than Ollama." readme = "README.md" license = {text = "Apache-2.0"} @@ -30,8 +30,9 @@ classifiers = [ dependencies = [ # Core — these are all you need for `rapid-mlx serve ` "mlx>=0.29.0", - "mlx-lm>=0.30.5", - "transformers>=5.0.0", + "mlx-lm>=0.31.0", # 0.31+ required for ArraysCache native batching (hybrid models) + "mlx-vlm>=0.1.0", # VLM support + "transformers>=5.0.0", # mlx-lm 0.30.5+ requires transformers 5.0 (rc3 bug fixed in stable) "tokenizers>=0.19.0", "huggingface-hub>=0.23.0", "numpy>=1.24.0", diff --git a/reports/benchmarks/devstral-24b.json b/reports/benchmarks/devstral-24b.json new file mode 100644 index 0000000..e174cf3 --- /dev/null +++ b/reports/benchmarks/devstral-24b.json @@ -0,0 +1,34 @@ +[ + { + "engine": "Rapid-MLX", + "model": "/Volumes/Extreme SSD/LMStudio-Models/mlx-community/Devstral-Small-2-24B-Instruct-2512-4bit", + "short_decode_tps": { + "mean": 29.546224945131634, + "median": 29.5528121992651, + "min": 29.523008740000495, + "max": 29.562853896129308 + }, + "short_prefill_tps": { + "median": 78.50722284202406 + }, + "long_decode_tps": { + "mean": 29.190642870240584, + "median": 29.196905393551695, + "min": 29.17665464544716, + "max": 29.198368571722895 + }, + "long_prefill_tps": { + "median": 555.7740024656049 + }, + "ttft_cold_s": 0.3981518749933457, + "ttft_cached_s": 0.17480293699918548, + "multi_turn_ttft_cold_s": 0.5294090829993365, + "multi_turn_ttft_cached_s": 0.1781606875010766, + "peak_ram_mb": 13286.703125, + "tool_call_rate": 0.0, + "recovery_rate": 0, + "leak_rate": 0.0, + "vision": true, + "audio": false + } +] \ No newline at end of file diff --git a/reports/benchmarks/glm45-air.json b/reports/benchmarks/glm45-air.json new file mode 100644 index 0000000..ff8885d --- /dev/null +++ b/reports/benchmarks/glm45-air.json @@ -0,0 +1,31 @@ +[ + { + "engine": "Rapid-MLX", + "model": "/Volumes/Extreme SSD/mlx-models/GLM-4.5-Air-MLX-4bit", + "short_decode_tps": { + "mean": 0.27992552863199427, + "median": 0.27974669915489814, + "min": 0.27946756029371045, + "max": 0.2805623264473743 + }, + "long_decode_tps": { + "mean": 16.547349065413208, + "median": 0.11041879751955017, + "min": 0.10995964577211123, + "max": 49.421668752947966 + }, + "long_prefill_tps": { + "median": 720.5478369401173 + }, + "ttft_cold_s": 0.689568749992759, + "ttft_cached_s": 0.13348579200101085, + "multi_turn_ttft_cold_s": 0.5098579159966903, + "multi_turn_ttft_cached_s": 0.13851270850136643, + "peak_ram_mb": 58026.484375, + "tool_call_rate": 1.0, + "recovery_rate": 1.0, + "leak_rate": 0.0, + "vision": true, + "audio": false + } +] \ No newline at end of file diff --git a/reports/benchmarks/gpt-oss-20b.json b/reports/benchmarks/gpt-oss-20b.json index 86b38c7..34f59e3 100644 --- a/reports/benchmarks/gpt-oss-20b.json +++ b/reports/benchmarks/gpt-oss-20b.json @@ -1,34 +1,34 @@ [ { "engine": "Rapid-MLX", - "model": "default", + "model": 
"/Volumes/Extreme SSD/LMStudio-Models/mlx-community/gpt-oss-20b-MXFP4-Q8", "short_decode_tps": { - "mean": 122.94637623722063, - "median": 122.80723229657188, - "min": 121.5261824468642, - "max": 124.50571396822582 + "mean": 59.23904874483382, + "median": 58.460670099788786, + "min": 58.3865845503003, + "max": 60.86989158441239 }, "short_prefill_tps": { - "median": 658.6315031611078 + "median": 180.25974764376866 }, "long_decode_tps": { - "mean": 123.56692922034931, - "median": 123.55530128557712, - "min": 123.47945313209522, - "max": 123.6660332433756 + "mean": 59.21073424542031, + "median": 59.209885676983895, + "min": 59.014062596876016, + "max": 59.40825446240103 }, "long_prefill_tps": { - "median": 1413.1066416100443 + "median": 452.4523802707095 }, - "ttft_cold_s": 0.3050392910372466, - "ttft_cached_s": 0.112084332969971, - "multi_turn_ttft_cold_s": 0.3241620830958709, - "multi_turn_ttft_cached_s": 0.11514252098277211, - "peak_ram_mb": 12061.125, + "ttft_cold_s": 0.43270891599240713, + "ttft_cached_s": 1.639991458003351, + "multi_turn_ttft_cold_s": 0.46583316699252464, + "multi_turn_ttft_cached_s": 0.2966201045055641, + "peak_ram_mb": 12314.140625, "tool_call_rate": 0.0, "recovery_rate": 0, "leak_rate": 0.0, - "vision": true, + "vision": false, "audio": false } ] \ No newline at end of file diff --git a/reports/benchmarks/hermes3-llama31-8b.json b/reports/benchmarks/hermes3-llama31-8b.json index 06ce5ee..d9edcfd 100644 --- a/reports/benchmarks/hermes3-llama31-8b.json +++ b/reports/benchmarks/hermes3-llama31-8b.json @@ -1,30 +1,30 @@ [ { "engine": "Rapid-MLX", - "model": "default", + "model": "/Volumes/Extreme SSD/LMStudio-Models/mlx-community/Hermes-3-Llama-3.1-8B-4bit", "short_decode_tps": { - "mean": 124.08034651519463, - "median": 124.36739745875063, - "min": 123.39537136101775, - "max": 124.4782707258155 + "mean": 123.8564234267494, + "median": 123.42606066686089, + "min": 123.31491031465332, + "max": 124.82829929873398 }, "short_prefill_tps": { - "median": 247.71035750769664 + "median": 190.87991603464098 }, "long_decode_tps": { - "mean": 122.80223356011224, - "median": 122.77578368076514, - "min": 122.77411550738096, - "max": 122.85680149219063 + "mean": 123.22420766220122, + "median": 122.97184960224764, + "min": 122.68944338901483, + "max": 124.0113299953412 }, "long_prefill_tps": { - "median": 1351.9491427231967 + "median": 980.9118874888566 }, - "ttft_cold_s": 0.3690061660017818, - "ttft_cached_s": 0.08006683352869004, - "multi_turn_ttft_cold_s": 0.23345162498299032, - "multi_turn_ttft_cached_s": 0.0775428750202991, - "peak_ram_mb": 4711.265625, + "ttft_cold_s": 0.14250754201202653, + "ttft_cached_s": 0.10385847950237803, + "multi_turn_ttft_cold_s": 0.19274274999042973, + "multi_turn_ttft_cached_s": 0.10551808349555358, + "peak_ram_mb": 4940.90625, "tool_call_rate": 0.0, "recovery_rate": 0.0, "leak_rate": 0.0, diff --git a/reports/benchmarks/llama32-3b.json b/reports/benchmarks/llama32-3b.json index 0b1ed49..0aa1b47 100644 --- a/reports/benchmarks/llama32-3b.json +++ b/reports/benchmarks/llama32-3b.json @@ -1,30 +1,30 @@ [ { "engine": "Rapid-MLX", - "model": "default", + "model": "/Volumes/Extreme SSD/LMStudio-Models/mlx-community/Llama-3.2-3B-Instruct-4bit", "short_decode_tps": { - "mean": 225.29973710147297, - "median": 225.3064908515301, - "min": 224.56576334387896, - "max": 226.02695710900986 + "mean": 226.52675688833313, + "median": 226.5153487509708, + "min": 226.30568834910332, + "max": 226.75923356492524 }, "short_prefill_tps": { - "median": 684.1153251386287 + 
"median": 475.41882123651436 }, "long_decode_tps": { - "mean": 219.96123540145632, - "median": 219.99973600342884, - "min": 219.82171612576647, - "max": 220.06225407517366 + "mean": 220.57800958459168, + "median": 220.6666855788961, + "min": 220.287971451054, + "max": 220.7793717238249 }, "long_prefill_tps": { - "median": 1912.1065873841285 + "median": 1328.9680471540132 }, - "ttft_cold_s": 0.12779508298262954, - "ttft_cached_s": 0.06659756251610816, - "multi_turn_ttft_cold_s": 0.1073445410002023, - "multi_turn_ttft_cached_s": 0.06532658298965544, - "peak_ram_mb": 2120.53125, + "ttft_cold_s": 0.12346550000074785, + "ttft_cached_s": 0.0960028545014211, + "multi_turn_ttft_cold_s": 0.15037445900088642, + "multi_turn_ttft_cached_s": 0.09493745800136821, + "peak_ram_mb": 2348.03125, "tool_call_rate": 0.0, "recovery_rate": 0.0, "leak_rate": 0.0, diff --git a/reports/benchmarks/minimax-m25.json b/reports/benchmarks/minimax-m25.json index e89a765..442268b 100644 --- a/reports/benchmarks/minimax-m25.json +++ b/reports/benchmarks/minimax-m25.json @@ -1,33 +1,33 @@ [ { "engine": "Rapid-MLX", - "model": "default", + "model": "/Volumes/Extreme SSD/mlx-models/MiniMax-M2.5-MLX-4bit", "short_decode_tps": { - "mean": 51.84681788276677, - "median": 51.86400138987982, - "min": 51.76456896445916, - "max": 51.91188329396134 + "mean": 51.67176982233886, + "median": 51.65256027149127, + "min": 51.61833640541185, + "max": 51.74441279011345 }, "short_prefill_tps": { - "median": 373.7141658642875 + "median": 137.99610273516524 }, "long_decode_tps": { - "mean": 50.95070780816328, - "median": 50.97072297393578, - "min": 50.88414682013725, - "max": 50.9972536304168 + "mean": 51.14445303382068, + "median": 51.20566577490106, + "min": 51.00525329304959, + "max": 51.222440033511404 }, "long_prefill_tps": { - "median": 993.181355104207 + "median": 347.7178246515437 }, - "ttft_cold_s": 1.1762665420000076, - "ttft_cached_s": 0.13059816650002176, - "multi_turn_ttft_cold_s": 0.49031412499994076, - "multi_turn_ttft_cached_s": 0.13266239600000063, - "peak_ram_mb": 123113.28125, + "ttft_cold_s": 1.5327279580087634, + "ttft_cached_s": 0.47744062499987194, + "multi_turn_ttft_cold_s": 1.049200916007976, + "multi_turn_ttft_cached_s": 0.4468816875014454, + "peak_ram_mb": 123325.296875, "tool_call_rate": 1.0, "recovery_rate": 1.0, - "leak_rate": 0.8, + "leak_rate": 0.0, "vision": true, "audio": false } diff --git a/reports/benchmarks/mistral-small-24b.json b/reports/benchmarks/mistral-small-24b.json index f1da033..ea273b5 100644 --- a/reports/benchmarks/mistral-small-24b.json +++ b/reports/benchmarks/mistral-small-24b.json @@ -1,30 +1,30 @@ [ { "engine": "Rapid-MLX", - "model": "default", + "model": "/Volumes/Extreme SSD/LMStudio-Models/lmstudio-community/Mistral-Small-3.2-24B-Instruct-2506-MLX-4bit", "short_decode_tps": { - "mean": 48.45506038907022, - "median": 48.444075696160446, - "min": 48.43399621727914, - "max": 48.487109253771074 + "mean": 48.376265707848326, + "median": 48.39510538749555, + "min": 48.28727936639898, + "max": 48.44641236965046 }, "short_prefill_tps": { - "median": 3075.293887102689 + "median": 2439.4643799749797 }, "long_decode_tps": { - "mean": 47.902364420568546, - "median": 47.9061503187498, - "min": 47.85985295159887, - "max": 47.941089991356975 + "mean": 41.24710965899371, + "median": 47.605138559651195, + "min": 28.290110300102153, + "max": 47.84608011722777 }, "long_prefill_tps": { - "median": 4034.624900255464 + "median": 3025.09303781437 }, - "ttft_cold_s": 1.1419446660438552, - "ttft_cached_s": 
0.1070051045389846, - "multi_turn_ttft_cold_s": 0.392231083009392, - "multi_turn_ttft_cached_s": 0.0995690205018036, - "peak_ram_mb": 13007.984375, + "ttft_cold_s": 1.164492041003541, + "ttft_cached_s": 0.13583050000306685, + "multi_turn_ttft_cold_s": 0.5366168750042561, + "multi_turn_ttft_cached_s": 0.18406710399722215, + "peak_ram_mb": 13272.046875, "tool_call_rate": 0.0, "recovery_rate": 0.0, "leak_rate": 0.0, diff --git a/reports/benchmarks/phi4-mini.json b/reports/benchmarks/phi4-mini.json new file mode 100644 index 0000000..3ab90a8 --- /dev/null +++ b/reports/benchmarks/phi4-mini.json @@ -0,0 +1,34 @@ +[ + { + "engine": "Rapid-MLX", + "model": "/Volumes/Extreme SSD/LMStudio-Models/lmstudio-community/Phi-4-mini-reasoning-MLX-4bit", + "short_decode_tps": { + "mean": 174.0167093209013, + "median": 174.0297507031804, + "min": 173.96538436660595, + "max": 174.0549928929175 + }, + "short_prefill_tps": { + "median": 212.18942802486131 + }, + "long_decode_tps": { + "mean": 169.99027159980744, + "median": 169.88826976220724, + "min": 169.87693495786576, + "max": 170.20561007934936 + }, + "long_prefill_tps": { + "median": 840.6429382159101 + }, + "ttft_cold_s": 0.1561205420002807, + "ttft_cached_s": 0.13479174999520183, + "multi_turn_ttft_cold_s": 0.19009308399108704, + "multi_turn_ttft_cached_s": 0.13267095800256357, + "peak_ram_mb": 2651.3125, + "tool_call_rate": 0.0, + "recovery_rate": 0, + "leak_rate": 1.0, + "vision": true, + "audio": false + } +] \ No newline at end of file diff --git a/reports/benchmarks/qwen3-coder-next.json b/reports/benchmarks/qwen3-coder-next.json new file mode 100644 index 0000000..b748417 --- /dev/null +++ b/reports/benchmarks/qwen3-coder-next.json @@ -0,0 +1,34 @@ +[ + { + "engine": "Rapid-MLX", + "model": "/Volumes/Extreme SSD/LMStudio-Models/lmstudio-community/Qwen3-Coder-Next-MLX-4bit", + "short_decode_tps": { + "mean": 74.49753610288838, + "median": 74.53875302231162, + "min": 74.37381199570675, + "max": 74.58004329064677 + }, + "short_prefill_tps": { + "median": 109.45499116352116 + }, + "long_decode_tps": { + "mean": 72.94398697030766, + "median": 72.88246741411913, + "min": 72.71115685641371, + "max": 73.23833664039014 + }, + "long_prefill_tps": { + "median": 395.6298569406463 + }, + "ttft_cold_s": 0.5353582079987973, + "ttft_cached_s": 0.15223635399888735, + "multi_turn_ttft_cold_s": 0.2726411250041565, + "multi_turn_ttft_cached_s": 0.19461137499456527, + "peak_ram_mb": 43416.75, + "tool_call_rate": 1.0, + "recovery_rate": 1.0, + "leak_rate": 0.0, + "vision": true, + "audio": false + } +] \ No newline at end of file diff --git a/reports/benchmarks/qwen35-27b.json b/reports/benchmarks/qwen35-27b.json new file mode 100644 index 0000000..c7275e7 --- /dev/null +++ b/reports/benchmarks/qwen35-27b.json @@ -0,0 +1,34 @@ +[ + { + "engine": "Rapid-MLX", + "model": "/Volumes/Extreme SSD/LMStudio-Models/mlx-community/Qwen3.5-27B-4bit", + "short_decode_tps": { + "mean": 38.85941608644048, + "median": 38.969508540666844, + "min": 38.59073363686412, + "max": 39.01800608179049 + }, + "short_prefill_tps": { + "median": 59.8790443320733 + }, + "long_decode_tps": { + "mean": 38.64150579366772, + "median": 38.648835551974784, + "min": 38.615557996057866, + "max": 38.6601238329705 + }, + "long_prefill_tps": { + "median": 183.93893179093496 + }, + "ttft_cold_s": 0.4288106249878183, + "ttft_cached_s": 0.2768138959945645, + "multi_turn_ttft_cold_s": 0.6875584160006838, + "multi_turn_ttft_cached_s": 0.4295624790029251, + "peak_ram_mb": 15171.09375, + "tool_call_rate": 1.0, + 
"recovery_rate": 1.0, + "leak_rate": 0.0, + "vision": true, + "audio": false + } +] \ No newline at end of file diff --git a/reports/benchmarks/qwen35-35b-a3b.json b/reports/benchmarks/qwen35-35b-a3b.json new file mode 100644 index 0000000..6754b6e --- /dev/null +++ b/reports/benchmarks/qwen35-35b-a3b.json @@ -0,0 +1,34 @@ +[ + { + "engine": "Rapid-MLX", + "model": "/Volumes/Extreme SSD/LMStudio-Models/Qwen3.5-35B-A3B-8bit", + "short_decode_tps": { + "mean": 83.05217030898503, + "median": 83.11205723211323, + "min": 82.92439265366863, + "max": 83.1200610411732 + }, + "short_prefill_tps": { + "median": 97.83125210812838 + }, + "long_decode_tps": { + "mean": 82.07174184821544, + "median": 82.09727904869547, + "min": 81.83012948039176, + "max": 82.2878170155591 + }, + "long_prefill_tps": { + "median": 365.7515127062668 + }, + "ttft_cold_s": 0.4227738329936983, + "ttft_cached_s": 0.19225418750284007, + "multi_turn_ttft_cold_s": 0.29732754100405145, + "multi_turn_ttft_cached_s": 0.22780197899555787, + "peak_ram_mb": 35859.5, + "tool_call_rate": 1.0, + "recovery_rate": 1.0, + "leak_rate": 0.0, + "vision": true, + "audio": false + } +] \ No newline at end of file diff --git a/reports/benchmarks/qwen35-4b.json b/reports/benchmarks/qwen35-4b.json index 994850a..9c79a36 100644 --- a/reports/benchmarks/qwen35-4b.json +++ b/reports/benchmarks/qwen35-4b.json @@ -1,30 +1,30 @@ [ { "engine": "Rapid-MLX", - "model": "default", + "model": "mlx-community/Qwen3.5-4B-MLX-4bit", "short_decode_tps": { - "mean": 12534.648129409241, - "median": 12532.585835006303, - "min": 11085.359976254695, - "max": 13985.998576966726 + "mean": 161.51262202923854, + "median": 161.51928309765657, + "min": 161.469009120634, + "max": 161.54957386942502 }, "short_prefill_tps": { - "median": 15.198847520323636 + "median": 108.55340507429347 }, "long_decode_tps": { - "mean": 9616.363718586874, - "median": 14336.916074158638, - "min": 158.17741769529468, - "max": 14353.99766390669 + "mean": 159.87546201437704, + "median": 160.0963752307641, + "min": 159.04055963059778, + "max": 160.48945118176925 }, "long_prefill_tps": { - "median": 31.442693820655546 + "median": 470.6212943769949 }, - "ttft_cold_s": 1.4518451249459758, - "ttft_cached_s": 1.3812861454789527, - "multi_turn_ttft_cold_s": 0.79640395892784, - "multi_turn_ttft_cached_s": 0.7635571040445939, - "peak_ram_mb": 2694.625, + "ttft_cold_s": 0.20709370900294743, + "ttft_cached_s": 0.1758369999952265, + "multi_turn_ttft_cold_s": 0.23808012499648612, + "multi_turn_ttft_cached_s": 0.19367549999878975, + "peak_ram_mb": 2981.625, "tool_call_rate": 1.0, "recovery_rate": 1.0, "leak_rate": 0.0, diff --git a/reports/benchmarks/qwen35-9b.json b/reports/benchmarks/qwen35-9b.json new file mode 100644 index 0000000..fb38792 --- /dev/null +++ b/reports/benchmarks/qwen35-9b.json @@ -0,0 +1,34 @@ +[ + { + "engine": "Rapid-MLX", + "model": "/Volumes/Extreme SSD/LMStudio-Models/mlx-community/Qwen3.5-9B-4bit", + "short_decode_tps": { + "mean": 97.89498105147021, + "median": 99.84277038711238, + "min": 93.94512015866323, + "max": 99.89705260863504 + }, + "short_prefill_tps": { + "median": 10.78341279848712 + }, + "long_decode_tps": { + "mean": 103.15986603661459, + "median": 103.9518078591003, + "min": 101.21186818236882, + "max": 104.31592206837466 + }, + "long_prefill_tps": { + "median": 21.688568216231342 + }, + "ttft_cold_s": 0.4995118750084657, + "ttft_cached_s": 1.947697354502452, + "multi_turn_ttft_cold_s": 1.1698493330040947, + "multi_turn_ttft_cached_s": 1.0341245004965458, + "peak_ram_mb": 
5540.109375, + "tool_call_rate": 1.0, + "recovery_rate": 1.0, + "leak_rate": 0.0, + "vision": true, + "audio": false + } +] \ No newline at end of file diff --git a/tests/regression_suite.py b/tests/regression_suite.py new file mode 100644 index 0000000..1989520 --- /dev/null +++ b/tests/regression_suite.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python3.12 +"""Comprehensive regression and edge case test suite for Rapid-MLX.""" + +import json +import urllib.request +import urllib.error +import sys + +BASE = "http://localhost:8777" + +def api_call(path, body=None, method="GET"): + """Make an API call, return (status_code, parsed_json_or_None).""" + url = BASE + path + data = json.dumps(body).encode() if body else None + req = urllib.request.Request(url, data=data, headers={"Content-Type": "application/json"}) + if method != "GET" and data is None: + req.method = method + try: + resp = urllib.request.urlopen(req) + return resp.status, json.loads(resp.read()) + except urllib.error.HTTPError as e: + try: + body_text = e.read().decode()[:500] + except: + body_text = "" + return e.code, body_text + +def stream_call(path, body): + """Make a streaming API call, return collected text and all SSE lines.""" + url = BASE + path + data = json.dumps(body).encode() + req = urllib.request.Request(url, data=data, headers={"Content-Type": "application/json"}) + text = "" + lines = [] + with urllib.request.urlopen(req) as resp: + for line in resp: + line = line.decode().strip() + if line.startswith("data:"): + lines.append(line) + if "[DONE]" not in line: + d = json.loads(line[5:].strip()) + delta = d["choices"][0].get("delta", {}) + if "content" in delta: + text += delta["content"] + return text, lines + +def test_1(): + """Stop at newline.""" + print("=" * 60) + print("TEST 1: Stop sequence - newline") + _, r = api_call("/v1/chat/completions", { + "model": "default", + "messages": [{"role": "user", "content": "Say hello then explain python"}], + "stop": ["\n"], + "max_tokens": 100, + "stream": False + }) + content = r["choices"][0]["message"]["content"] + finish = r["choices"][0]["finish_reason"] + has_newline = "\n" in content + print(f" Content: {content!r}") + print(f" Has newline: {has_newline}") + print(f" finish_reason: {finish}") + passed = not has_newline and finish == "stop" + print(f" RESULT: {'PASS' if passed else 'FAIL'}") + return passed + +def test_2(): + """Multiple stop sequences (first match wins).""" + print("=" * 60) + print("TEST 2: Multiple stop sequences") + _, r = api_call("/v1/chat/completions", { + "model": "default", + "messages": [{"role": "user", "content": "Write: Hello World! Goodbye World!"}], + "stop": ["World", "!"], + "max_tokens": 100, + "stream": False + }) + content = r["choices"][0]["message"]["content"] + finish = r["choices"][0]["finish_reason"] + has_world = "World" in content + has_bang = "!" 
in content + print(f" Content: {content!r}") + print(f" Contains 'World': {has_world}") + print(f" Contains '!': {has_bang}") + print(f" finish_reason: {finish}") + passed = not has_world and not has_bang and finish == "stop" + print(f" RESULT: {'PASS' if passed else 'FAIL'}") + return passed + +def test_3(): + """Empty stop sequence array.""" + print("=" * 60) + print("TEST 3: Empty stop sequence array") + code, r = api_call("/v1/chat/completions", { + "model": "default", + "messages": [{"role": "user", "content": "hi"}], + "stop": [], + "max_tokens": 10, + "stream": False + }) + if code == 200: + content = r["choices"][0]["message"]["content"] + print(f" OK: {content[:50]!r}") + passed = len(content) > 0 + else: + print(f" HTTP {code}: {r}") + passed = False + print(f" RESULT: {'PASS' if passed else 'FAIL'}") + return passed + +def test_4(): + """Unicode stop sequences.""" + print("=" * 60) + print("TEST 4: Unicode stop sequences") + _, r = api_call("/v1/chat/completions", { + "model": "default", + "messages": [{"role": "user", "content": "Say 你好世界 then say goodbye"}], + "stop": ["世界"], + "max_tokens": 100, + "stream": False + }) + content = r["choices"][0]["message"]["content"] + has_stop = "世界" in content + print(f" Content: {content!r}") + print(f" Contains '世界': {has_stop}") + print(f" finish_reason: {r['choices'][0]['finish_reason']}") + passed = not has_stop + print(f" RESULT: {'PASS' if passed else 'FAIL'}") + return passed + +def test_5(): + """Streaming stop sequence truncation.""" + print("=" * 60) + print("TEST 5: Streaming stop sequence truncation") + text, lines = stream_call("/v1/chat/completions", { + "model": "default", + "messages": [{"role": "user", "content": "Count: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10"}], + "stop": [", 5"], + "max_tokens": 100, + "stream": True + }) + has_stop = ", 5" in text + print(f" Text: {text!r}") + print(f" Contains ', 5': {has_stop}") + passed = not has_stop + print(f" RESULT: {'PASS' if passed else 'FAIL'}") + return passed + +def test_6(): + """Completions endpoint (/v1/completions).""" + print("=" * 60) + print("TEST 6: Completions endpoint") + code, r = api_call("/v1/completions", { + "model": "default", + "prompt": "def fibonacci(n):\n ", + "max_tokens": 50, + "stop": ["\n\n"], + "temperature": 0 + }) + print(f" HTTP {code}") + if code == 200: + if isinstance(r, dict): + print(f" Response: {json.dumps(r, indent=2)[:300]}") + has_choices = "choices" in r and len(r["choices"]) > 0 + has_text = has_choices and "text" in r["choices"][0] + passed = has_choices and has_text + else: + print(f" Response: {r[:200]}") + passed = False + elif code == 404: + print(f" Endpoint not implemented (404)") + passed = False + else: + print(f" Response: {r[:200] if isinstance(r, str) else r}") + passed = False + print(f" RESULT: {'PASS' if passed else 'FAIL (endpoint may not be implemented)'}") + return passed + +def test_7(): + """Validation rules - all should return 400.""" + print("=" * 60) + print("TEST 7: Validation rules") + cases = [ + ("max_tokens=0", {"model": "default", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 0}), + ("temp=-0.1", {"model": "default", "messages": [{"role": "user", "content": "hi"}], "temperature": -0.1}), + ("temp=2.1", {"model": "default", "messages": [{"role": "user", "content": "hi"}], "temperature": 2.1}), + ("n=2", {"model": "default", "messages": [{"role": "user", "content": "hi"}], "n": 2}), + ("empty messages", {"model": "default", "messages": []}), + ("invalid role", {"model": "default", "messages": 
[{"role": "foo", "content": "hi"}]}), + ] + all_pass = True + for name, body in cases: + code, _ = api_call("/v1/chat/completions", body) + ok = code == 400 + if not ok: + all_pass = False + print(f" {name}: HTTP {code} ({'PASS' if ok else 'FAIL - expected 400'})") + print(f" RESULT: {'PASS' if all_pass else 'FAIL'}") + return all_pass + +def test_8(): + """Health endpoint.""" + print("=" * 60) + print("TEST 8: Health endpoint") + code, r = api_call("/health") + print(f" HTTP {code}") + if code == 200 and isinstance(r, dict): + print(f" {json.dumps(r, indent=2)}") + passed = True + else: + print(f" Response: {r}") + passed = False + print(f" RESULT: {'PASS' if passed else 'FAIL'}") + return passed + +def test_9(): + """Model endpoint format validation.""" + print("=" * 60) + print("TEST 9: Models endpoint format validation") + code, r = api_call("/v1/models") + if code != 200: + print(f" HTTP {code}: {r}") + print(" RESULT: FAIL") + return False + checks = [] + checks.append(("object == 'list'", r.get("object") == "list")) + checks.append(("has data", len(r.get("data", [])) > 0)) + if r.get("data"): + m = r["data"][0] + checks.append(("has id", "id" in m)) + checks.append(("object == 'model'", m.get("object") == "model")) + checks.append(("has created", "created" in m)) + checks.append(("has owned_by", "owned_by" in m)) + print(f" Model: {json.dumps(m, indent=2)}") + all_pass = True + for name, ok in checks: + if not ok: + all_pass = False + print(f" {name}: {'PASS' if ok else 'FAIL'}") + print(f" RESULT: {'PASS' if all_pass else 'FAIL'}") + return all_pass + +def test_10(): + """Streaming usage stats (stream_options).""" + print("=" * 60) + print("TEST 10: Streaming usage stats") + text, lines = stream_call("/v1/chat/completions", { + "model": "default", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 10, + "stream": True, + "stream_options": {"include_usage": True} + }) + print(f" Total SSE data lines: {len(lines)}") + print(f" Last 3 lines:") + for l in lines[-3:]: + print(f" {l[:200]}") + + found_usage = False + for l in reversed(lines): + if "[DONE]" in l: + continue + chunk = json.loads(l[5:].strip()) + if "usage" in chunk and chunk["usage"] is not None: + found_usage = True + print(f" Usage: {chunk['usage']}") + break + print(f" Has usage in final chunk: {found_usage}") + print(f" RESULT: {'PASS' if found_usage else 'FAIL'}") + return found_usage + +if __name__ == "__main__": + results = {} + for i, test_fn in enumerate([test_1, test_2, test_3, test_4, test_5, test_6, test_7, test_8, test_9, test_10], 1): + try: + results[i] = test_fn() + except Exception as e: + print(f" EXCEPTION: {e}") + results[i] = False + print() + + print("=" * 60) + print("SUMMARY") + print("=" * 60) + for i in range(1, 11): + status = "PASS" if results.get(i) else "FAIL" + print(f" Test {i:2d}: {status}") + passed = sum(1 for v in results.values() if v) + total = len(results) + print(f"\n {passed}/{total} tests passed") diff --git a/tests/test_api_models.py b/tests/test_api_models.py index 2b18df3..7cef58b 100644 --- a/tests/test_api_models.py +++ b/tests/test_api_models.py @@ -395,7 +395,7 @@ def test_model_info(self): info = ModelInfo(id="mlx-community/Llama-3.2-3B-Instruct-4bit") assert info.id == "mlx-community/Llama-3.2-3B-Instruct-4bit" assert info.object == "model" - assert info.owned_by == "vllm-mlx" + assert info.owned_by == "rapid-mlx" def test_models_response(self): resp = ModelsResponse( diff --git a/tests/test_mllm_mtp_routing.py b/tests/test_mllm_mtp_routing.py new file mode 
100644 index 0000000..e2394cf --- /dev/null +++ b/tests/test_mllm_mtp_routing.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for MLLM + MTP per-request routing.""" + + +def test_has_media_content_text_only(): + from vllm_mlx.engine.simple import _has_media_content + + assert _has_media_content([{"role": "user", "content": "Hello"}]) is False + + +def test_has_media_content_with_image(): + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's this?"}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,..."}, + }, + ], + } + ] + assert _has_media_content(messages) is True + + +def test_has_media_content_with_video(): + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": "file:///tmp/v.mp4"}} + ], + } + ] + assert _has_media_content(messages) is True + + +def test_has_media_content_empty(): + from vllm_mlx.engine.simple import _has_media_content + + assert _has_media_content([]) is False + + +def test_has_media_content_string_content(): + """String content (not list) should return False.""" + from vllm_mlx.engine.simple import _has_media_content + + assert _has_media_content([{"role": "user", "content": "Just text"}]) is False + + +def test_has_media_content_audio(): + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": "data:audio/wav;base64,..."}} + ], + } + ] + assert _has_media_content(messages) is True + + +def test_has_media_content_multi_turn(): + """Media in earlier turns should still be detected.""" + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Look at this"}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,..."}, + }, + ], + }, + {"role": "assistant", "content": "I see an image."}, + {"role": "user", "content": "Tell me more about it."}, + ] + assert _has_media_content(messages) is True + + +def test_has_media_content_text_list(): + """List content with only text parts should return False.""" + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + {"type": "text", "text": "World"}, + ], + } + ] + assert _has_media_content(messages) is False + + +# --- MLXMultimodalLM extraction method tests --- + +from unittest.mock import MagicMock + + +def test_get_language_model(): + from vllm_mlx.models.mllm import MLXMultimodalLM + + mllm = MagicMock(spec=MLXMultimodalLM) + inner_lm = MagicMock() + mllm.model = MagicMock() + mllm.model.language_model = inner_lm + assert MLXMultimodalLM.get_language_model(mllm) is inner_lm + + +def test_get_tokenizer(): + from vllm_mlx.models.mllm import MLXMultimodalLM + + mllm = MagicMock(spec=MLXMultimodalLM) + inner_tok = MagicMock() + mllm.processor = MagicMock() + mllm.processor.tokenizer = inner_tok + assert MLXMultimodalLM.get_tokenizer(mllm) is inner_tok diff --git a/tests/test_platform.py b/tests/test_platform.py index caeeeb2..61ce434 100644 --- a/tests/test_platform.py +++ b/tests/test_platform.py @@ -27,7 +27,7 @@ def test_is_apple_silicon(): def test_mlx_platform_properties(): """Test MLXPlatform class properties.""" - from vllm_mlx.platform import MLXPlatform + from vllm_mlx.vllm_platform import MLXPlatform 
platform_obj = MLXPlatform() @@ -42,7 +42,7 @@ def test_mlx_platform_properties(): def test_get_device_name(): """Test getting device name.""" - from vllm_mlx.platform import MLXPlatform + from vllm_mlx.vllm_platform import MLXPlatform name = MLXPlatform.get_device_name() assert isinstance(name, str) @@ -51,7 +51,7 @@ def test_get_device_name(): def test_get_device_memory(): """Test getting device memory.""" - from vllm_mlx.platform import MLXPlatform + from vllm_mlx.vllm_platform import MLXPlatform memory = MLXPlatform.get_device_total_memory() assert isinstance(memory, int) @@ -62,7 +62,7 @@ def test_supported_dtypes(): """Test supported dtypes.""" import torch - from vllm_mlx.platform import MLXPlatform + from vllm_mlx.vllm_platform import MLXPlatform platform_obj = MLXPlatform() dtypes = platform_obj.supported_dtypes @@ -84,7 +84,7 @@ def test_plugin_entry_point(): pytest.skip("MLX not installed") result = mlx_platform_plugin() - assert result == "vllm_mlx.platform.MLXPlatform" + assert result == "vllm_mlx.vllm_platform.MLXPlatform" def test_device_info(): diff --git a/tests/test_streaming_pipeline_integration.py b/tests/test_streaming_pipeline_integration.py new file mode 100644 index 0000000..7ddcf39 --- /dev/null +++ b/tests/test_streaming_pipeline_integration.py @@ -0,0 +1,240 @@ +"""Integration test for the Anthropic streaming pipeline. + +Tests the full flow: raw model output → StreamingToolCallFilter → StreamingThinkRouter +→ Anthropic SSE events, verifying block transitions, tool call extraction, and +prompt_tokens tracking work together correctly. +""" + +import json +import unittest + +from vllm_mlx.api.utils import StreamingToolCallFilter, StreamingThinkRouter +from vllm_mlx.server import _emit_content_pieces + + +class TestEmitContentPieces(unittest.TestCase): + """Test the refactored _emit_content_pieces helper.""" + + def test_single_text_block(self): + events, block_type, index = _emit_content_pieces([("text", "hello")], None, 0) + assert len(events) == 2 # block_start + delta + assert block_type == "text" + assert index == 0 + # Verify block_start + start_data = json.loads(events[0].split("data: ")[1]) + assert start_data["type"] == "content_block_start" + assert start_data["content_block"]["type"] == "text" + # Verify delta + delta_data = json.loads(events[1].split("data: ")[1]) + assert delta_data["delta"]["text"] == "hello" + + def test_single_thinking_block(self): + events, block_type, index = _emit_content_pieces( + [("thinking", "reasoning")], None, 0 + ) + assert block_type == "thinking" + delta_data = json.loads(events[1].split("data: ")[1]) + assert delta_data["delta"]["thinking"] == "reasoning" + + def test_transition_thinking_to_text(self): + events, block_type, index = _emit_content_pieces( + [("thinking", "reason"), ("text", "answer")], None, 0 + ) + assert block_type == "text" + assert index == 1 # incremented on block transition + # Should have: start_thinking, delta_thinking, stop_thinking, start_text, delta_text + assert len(events) == 5 + stop_data = json.loads(events[2].split("data: ")[1]) + assert stop_data["type"] == "content_block_stop" + + def test_continues_existing_block(self): + """If current_block_type matches, no start/stop emitted.""" + events, block_type, index = _emit_content_pieces([("text", "more")], "text", 0) + assert len(events) == 1 # just delta, no start + assert block_type == "text" + + def test_empty_pieces(self): + events, block_type, index = _emit_content_pieces([], None, 0) + assert events == [] + assert block_type is None + 
assert index == 0 + + +class TestStreamingPipelineIntegration(unittest.TestCase): + """Integration test for the full streaming pipeline.""" + + def _run_pipeline(self, deltas, start_in_thinking=False): + """Run deltas through tool_filter → think_router → emit, return events.""" + tool_filter = StreamingToolCallFilter() + think_router = StreamingThinkRouter(start_in_thinking=start_in_thinking) + current_block_type = None + block_index = 0 + all_events = [] + accumulated_text = "" + + for delta in deltas: + accumulated_text += delta + filtered = tool_filter.process(delta) + if not filtered: + continue + pieces = think_router.process(filtered) + events, current_block_type, block_index = _emit_content_pieces( + pieces, current_block_type, block_index + ) + all_events.extend(events) + + # Flush + remaining = tool_filter.flush() + if remaining: + pieces = think_router.process(remaining) + events, current_block_type, block_index = _emit_content_pieces( + pieces, current_block_type, block_index + ) + all_events.extend(events) + + flush_pieces = think_router.flush() + if flush_pieces: + events, current_block_type, block_index = _emit_content_pieces( + flush_pieces, current_block_type, block_index + ) + all_events.extend(events) + + # Close final block + if current_block_type is not None: + all_events.append( + f"event: content_block_stop\ndata: " + f"{json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n" + ) + block_index += 1 + + return all_events, accumulated_text, block_index + + def _parse_events(self, events): + """Parse SSE events into structured data.""" + parsed = [] + for event in events: + data_line = event.split("data: ", 1)[1].split("\n")[0] + parsed.append(json.loads(data_line)) + return parsed + + def test_pure_text_response(self): + """Simple text response - one text block.""" + events, _, block_index = self._run_pipeline(["Hello ", "world!"]) + parsed = self._parse_events(events) + + # block_start, 2 deltas, block_stop + types = [p["type"] for p in parsed] + assert types[0] == "content_block_start" + assert parsed[0]["content_block"]["type"] == "text" + assert types[-1] == "content_block_stop" + assert block_index == 1 + + def test_thinking_then_text(self): + """Model thinks then responds.""" + events, _, block_index = self._run_pipeline( + ["Let me think", " about this", "The answer is 42"] + ) + parsed = self._parse_events(events) + + block_starts = [p for p in parsed if p["type"] == "content_block_start"] + assert len(block_starts) == 2 + assert block_starts[0]["content_block"]["type"] == "thinking" + assert block_starts[1]["content_block"]["type"] == "text" + assert block_index == 2 + + def test_start_in_thinking_then_text(self): + """Model starts in thinking mode (template injects ).""" + events, _, _ = self._run_pipeline( + ["reasoning here", "", "The answer"], + start_in_thinking=True, + ) + parsed = self._parse_events(events) + + block_starts = [p for p in parsed if p["type"] == "content_block_start"] + assert len(block_starts) == 2 + assert block_starts[0]["content_block"]["type"] == "thinking" + assert block_starts[1]["content_block"]["type"] == "text" + + def test_text_then_tool_call(self): + """Text followed by tool call - tool markup suppressed from text.""" + events, accumulated, _ = self._run_pipeline( + [ + "I'll search for that. 
", + "", + '', + 'ls /tmp', + "", + "", + ] + ) + parsed = self._parse_events(events) + + # Only text block should appear (tool call is suppressed from streaming) + text_deltas = [ + p + for p in parsed + if p["type"] == "content_block_delta" + and p["delta"].get("type") == "text_delta" + ] + text_content = "".join(d["delta"]["text"] for d in text_deltas) + assert "I'll search for that." in text_content + assert "" not in text_content + + # But accumulated text has the full tool call for parsing + assert "" in accumulated + + def test_thinking_then_tool_call(self): + """Thinking followed by tool call - both properly routed.""" + events, accumulated, _ = self._run_pipeline( + [ + "I need to search", + "", + '', + 'test', + "", + "", + ] + ) + parsed = self._parse_events(events) + + block_starts = [p for p in parsed if p["type"] == "content_block_start"] + # Only thinking block (tool call is suppressed) + assert len(block_starts) == 1 + assert block_starts[0]["content_block"]["type"] == "thinking" + + def test_mixed_thinking_text_and_tool_call(self): + """Full scenario: thinking → text → tool call.""" + events, accumulated, block_index = self._run_pipeline( + [ + "analyzing request", + "Let me help. ", + "", + 'echo hi', + "", + ] + ) + parsed = self._parse_events(events) + + block_starts = [p for p in parsed if p["type"] == "content_block_start"] + # thinking block + text block (tool call suppressed) + assert len(block_starts) == 2 + assert block_starts[0]["content_block"]["type"] == "thinking" + assert block_starts[1]["content_block"]["type"] == "text" + + # Accumulated has everything for post-stream tool parsing + assert "" in accumulated + + def test_block_index_increments_correctly(self): + """Block indices should increment on each transition.""" + events, _, final_index = self._run_pipeline( + ["t1textt2end"] + ) + parsed = self._parse_events(events) + + starts = [p for p in parsed if p["type"] == "content_block_start"] + assert [s["index"] for s in starts] == [0, 1, 2, 3] + assert final_index == 4 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_streaming_think_router.py b/tests/test_streaming_think_router.py new file mode 100644 index 0000000..d5272cd --- /dev/null +++ b/tests/test_streaming_think_router.py @@ -0,0 +1,168 @@ +"""Tests for StreamingThinkRouter - routes blocks to Anthropic thinking content blocks.""" + +import unittest + +from vllm_mlx.api.utils import StreamingThinkRouter + + +class TestStreamingThinkRouter(unittest.TestCase): + """Unit tests for StreamingThinkRouter.""" + + # --- Basic routing --- + + def test_plain_text_routes_as_text(self): + r = StreamingThinkRouter() + assert r.process("Hello world") == [("text", "Hello world")] + + def test_think_block_routes_as_thinking(self): + r = StreamingThinkRouter() + assert r.process("reasoning") == [("thinking", "reasoning")] + + def test_text_then_think_then_text(self): + r = StreamingThinkRouter() + result = r.process("beforemiddleafter") + assert result == [("text", "before"), ("thinking", "middle"), ("text", "after")] + + # --- start_in_thinking mode --- + + def test_start_in_thinking_mode(self): + """When model injects into prompt, output starts in thinking mode.""" + r = StreamingThinkRouter(start_in_thinking=True) + result = r.process("reasoning here") + assert result == [("thinking", "reasoning here")] + + def test_start_in_thinking_then_close(self): + """Thinking closes with , then text follows.""" + r = StreamingThinkRouter(start_in_thinking=True) + result = 
r.process("reasoninganswer") + assert result == [("thinking", "reasoning"), ("text", "answer")] + + def test_start_in_thinking_close_across_deltas(self): + """ split across multiple deltas.""" + r = StreamingThinkRouter(start_in_thinking=True) + p1 = r.process("thinking stuffnow text") + # First delta should hold back partial + assert ("thinking", "thinking stuff") in p1 + # Second delta should transition + all_pieces = p1 + p2 + types = [t for t, _ in all_pieces] + assert "text" in types + + # --- Partial tag handling --- + + def test_partial_open_tag_held_back(self): + """Partial reasoning") + # p1 should emit "Hello " but hold back "answer") + # p1 should emit thinking content but hold back partial + assert ("thinking", "deep thought") in p1 + # p2 should transition to text + assert ("text", "answer") in p2 + + def test_partial_tag_false_alarm(self): + """Partial match that turns out not to be a tag.""" + r = StreamingThinkRouter() + p1 = r.process("Hello ") + # After p2, the held-back "" should emit as text + all_text = "".join(t for bt, t in p1 + p2 if bt == "text") + assert "Hello " == all_text + + # --- Multiple think blocks --- + + def test_multiple_think_blocks(self): + r = StreamingThinkRouter() + result = r.process("firstmiddlesecondend") + assert result == [ + ("thinking", "first"), + ("text", "middle"), + ("thinking", "second"), + ("text", "end"), + ] + + # --- Streaming across deltas --- + + def test_streaming_token_by_token(self): + """Simulate character-by-character streaming.""" + r = StreamingThinkRouter() + text = "abcxyz" + all_pieces = [] + for ch in text: + all_pieces.extend(r.process(ch)) + all_pieces.extend(r.flush()) + thinking = "".join(t for bt, t in all_pieces if bt == "thinking") + text_out = "".join(t for bt, t in all_pieces if bt == "text") + assert thinking == "abc" + assert text_out == "xyz" + + def test_streaming_with_start_in_thinking(self): + """Token-by-token with start_in_thinking.""" + r = StreamingThinkRouter(start_in_thinking=True) + text = "reasoningthe answer" + all_pieces = [] + for ch in text: + all_pieces.extend(r.process(ch)) + all_pieces.extend(r.flush()) + thinking = "".join(t for bt, t in all_pieces if bt == "thinking") + text_out = "".join(t for bt, t in all_pieces if bt == "text") + assert thinking == "reasoning" + assert text_out == "the answer" + + # --- Flush behavior --- + + def test_flush_emits_remaining_text(self): + """Text without partial tags is emitted by process(), flush() is empty.""" + r = StreamingThinkRouter() + pieces = r.process("partial text") + assert pieces == [("text", "partial text")] + assert r.flush() == [] + + def test_flush_emits_remaining_thinking(self): + """Thinking without partial tags is emitted by process(), flush() is empty.""" + r = StreamingThinkRouter(start_in_thinking=True) + pieces = r.process("unfinished thought") + assert pieces == [("thinking", "unfinished thought")] + assert r.flush() == [] + + def test_flush_with_held_back_partial(self): + """Flush should emit held-back partial tag as content.""" + r = StreamingThinkRouter() + r.process("text") + f.process('') + f.process('/tmp/test.txt') + f.process("") + result = f.process("") + assert result == "" + + def test_text_after_tool_call_emits(self): + f = StreamingToolCallFilter() + f.process("content") + assert f.process("After") == "After" + + def test_text_before_and_after_same_delta(self): + f = StreamingToolCallFilter() + result = f.process("Before insideAfter") + assert result == "Before After" + + def test_split_across_deltas(self): + f = 
StreamingToolCallFilter() + r1 = f.process("Before insideAfter") + assert r1 + r2 == "Before After" + + def test_qwen_format_suppressed(self): + f = StreamingToolCallFilter() + result = f.process('Text {"name":"fn"} more') + assert result == "Text more" + + def test_multiple_tool_calls(self): + f = StreamingToolCallFilter() + result = f.process( + "A x" + " B y C" + ) + assert result == "A B C" + + def test_flush_partial_tag_emits(self): + f = StreamingToolCallFilter() + r = f.process("text partial content") + assert f.flush() == "" + + def test_large_tool_call_content(self): + """Simulates a Read tool returning a large file.""" + f = StreamingToolCallFilter() + big = "x" * 10000 + result = f.process(f"Before {big}After") + assert result == "Before After" + + def test_think_tags_not_filtered(self): + f = StreamingToolCallFilter() + result = f.process("reasoning hereanswer") + assert "" in result + assert "reasoning here" in result + + def test_mixed_think_and_tool_call(self): + f = StreamingToolCallFilter() + result = f.process( + "thinking" + "tool stuff" + "final answer" + ) + assert "thinking" in result + assert "tool stuff" not in result + assert "final answer" in result + + def test_gradual_token_by_token(self): + """Simulate token-by-token streaming.""" + f = StreamingToolCallFilter() + parts = [ + "Hello ", + "<", + "mini", + "max:", + "tool_call", + ">", + '', + "", + "", + " world", + ] + result = "" + for part in parts: + result += f.process(part) + result += f.flush() + assert result == "Hello world", f"Got: {result!r}" + + def test_empty_deltas(self): + f = StreamingToolCallFilter() + assert f.process("") == "" + assert f.process("text") == "text" + assert f.process("") == "" + + def test_calling_tool_bracket_suppressed(self): + """Qwen3 bracket-style: [Calling tool: func({...})]\n""" + f = StreamingToolCallFilter() + result = f.process('[Calling tool: search({"q": "test"})]\n') + assert result == "" + + def test_calling_tool_multiline_json(self): + """Multi-line JSON args in bracket-style tool call.""" + f = StreamingToolCallFilter() + r1 = f.process('[Calling tool: search({"q": "test",') + r2 = f.process(' "limit": 5})]\n') + r3 = f.process("After") + assert r1 + r2 + r3 == "After" + + def test_buffer_cap_on_unclosed_block(self): + """Buffer should be capped if tool call block never closes.""" + from vllm_mlx.api.utils import _MAX_TOOL_BUFFER_BYTES + + f = StreamingToolCallFilter() + f.process("") + # Feed data exceeding the cap + chunk = "x" * 10000 + for _ in range(_MAX_TOOL_BUFFER_BYTES // 10000 + 2): + f.process(chunk) + # After cap, filter should have exited the block + assert not f._in_block + # New text should pass through + assert f.process("after") == "after" + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_text_model_from_vlm.py b/tests/test_text_model_from_vlm.py new file mode 100644 index 0000000..037ff81 --- /dev/null +++ b/tests/test_text_model_from_vlm.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for building mlx_lm TextModel from mlx_vlm-loaded weights.""" + +import json +from pathlib import Path + +import pytest + +from vllm_mlx.text_model_from_vlm import build_text_model + +# VLM+MTP model (created by merging mlx-community VLM + our MTP weights) +VLM_MTP_MODEL = Path.home() / "ai-models/mlx_models/Qwen3.5-35B-A3B-VLM-MTP-8bit" + +# Text-only MTP model (no vision tower — can't test VLM loading) +TEXT_MTP_MODEL = Path.home() / "ai-models/mlx_models/Qwen3.5-35B-A3B-8bit" + + +def 
test_build_text_model_no_config(): + """Returns None when model path has no config.json.""" + result = build_text_model(None, "/nonexistent/path") + assert result is None + + +def test_build_text_model_none_vlm(): + """Returns None when vlm_model is None.""" + result = build_text_model(None, TEXT_MTP_MODEL) + assert result is None + + +@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk") +def test_build_text_model_moe(): + """build_text_model creates a TextModel with shared weights and MTP (MoE).""" + import runtime_patches + + runtime_patches.apply() + + from mlx_vlm import load as vlm_load + + vlm_model, processor = vlm_load(str(VLM_MTP_MODEL)) + text_model = build_text_model(vlm_model, VLM_MTP_MODEL) + + assert text_model is not None, "build_text_model returned None" + + # TextModel should have MTP (config has mtp_num_hidden_layers=1) + assert hasattr(text_model, "mtp"), "TextModel missing .mtp attribute" + assert text_model.mtp is not None, "TextModel.mtp is None" + assert hasattr(text_model, "mtp_forward"), "TextModel missing mtp_forward method" + assert hasattr( + text_model, "make_mtp_cache" + ), "TextModel missing make_mtp_cache method" + + # Verify MoE layer exists in MTP + mtp_layer = text_model.mtp.layers[0] + assert hasattr(mtp_layer, "mlp"), "MTP layer missing mlp" + + +@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk") +def test_text_model_mtp_forward(): + """TextModel.mtp_forward returns logits of correct vocab_size shape.""" + import mlx.core as mx + import runtime_patches + + runtime_patches.apply() + + from mlx_vlm import load as vlm_load + + vlm_model, _ = vlm_load(str(VLM_MTP_MODEL)) + text_model = build_text_model(vlm_model, VLM_MTP_MODEL) + + config = json.loads((VLM_MTP_MODEL / "config.json").read_text()) + text_config = config.get("text_config", config) + + mtp_cache = text_model.make_mtp_cache() + assert len(mtp_cache) > 0 + + hidden = mx.zeros((1, 1, text_config["hidden_size"])) + next_ids = mx.array([[0]]) + logits = text_model.mtp_forward(hidden, next_ids, mtp_cache) + + assert ( + logits.shape[-1] == text_config["vocab_size"] + ), f"Expected vocab_size={text_config['vocab_size']}, got {logits.shape[-1]}" + + +@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk") +def test_text_model_return_hidden(): + """TextModel supports return_hidden=True (required by mtp_generate_step).""" + import mlx.core as mx + import runtime_patches + + runtime_patches.apply() + + from mlx_vlm import load as vlm_load + + vlm_model, _ = vlm_load(str(VLM_MTP_MODEL)) + text_model = build_text_model(vlm_model, VLM_MTP_MODEL) + + config = json.loads((VLM_MTP_MODEL / "config.json").read_text()) + text_config = config.get("text_config", config) + + cache = text_model.make_cache() + tokens = mx.array([[1, 2, 3]]) # Dummy token IDs + + # return_hidden=True should return (logits, hidden_states) + result = text_model(tokens, cache=cache, return_hidden=True) + + # Should be a tuple of (logits, hidden) + assert isinstance(result, tuple), f"Expected tuple, got {type(result)}" + logits, hidden = result + assert logits.shape[-1] == text_config["vocab_size"] + assert hidden.shape[-1] == text_config["hidden_size"] + + +@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk") +def test_weight_sharing(): + """Backbone weights are shared (zero-copy) between vlm and TextModel.""" + import mlx.core as mx + import runtime_patches + + runtime_patches.apply() + + from mlx_vlm import load as 
vlm_load + + vlm_model, _ = vlm_load(str(VLM_MTP_MODEL)) + text_model = build_text_model(vlm_model, VLM_MTP_MODEL) + + # Compare a backbone weight reference. + # Layer 0 may be linear_attn (GatedDeltaNet) on MoE models, so find a layer + # with self_attn (full attention layers are at indices 11, 15, 19, 23, 27). + for i in range(len(vlm_model.language_model.model.layers)): + layer = vlm_model.language_model.model.layers[i] + if hasattr(layer, "self_attn"): + vlm_weight = layer.self_attn.q_proj.weight + tm_weight = text_model.model.layers[i].self_attn.q_proj.weight + assert mx.array_equal( + vlm_weight, tm_weight + ), f"Weights at layer {i} should be identical" + break + else: + pytest.fail("No layer with self_attn found") diff --git a/tests/test_video.py b/tests/test_video.py new file mode 100644 index 0000000..b66efbd --- /dev/null +++ b/tests/test_video.py @@ -0,0 +1,359 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for video support in MLLM chat/stream_chat.""" + +from vllm_mlx.models.mllm import ( + FRAME_FACTOR, + MIN_FRAMES, + MLXMultimodalLM, + is_base64_video, + smart_nframes, +) + + +class TestSmartNframes: + """Verify frame count alignment and clamping.""" + + def test_basic_calculation(self): + # 300 frames at 30fps = 10s video, at 2fps target = 20 frames + result = smart_nframes(300, 30.0, target_fps=2.0) + assert result == 20 + assert result % FRAME_FACTOR == 0 + + def test_clamps_to_min(self): + # Very short video: 6 frames at 30fps + result = smart_nframes(6, 30.0, target_fps=2.0) + assert result >= MIN_FRAMES + assert result % FRAME_FACTOR == 0 + + def test_clamps_to_max(self): + # Very long video: 100000 frames + result = smart_nframes(100000, 30.0, target_fps=2.0, max_frames=64) + assert result <= 64 + assert result % FRAME_FACTOR == 0 + + def test_result_always_even(self): + for total in [5, 7, 11, 13, 100, 999]: + result = smart_nframes(total, 30.0) + assert ( + result % FRAME_FACTOR == 0 + ), f"Odd frame count {result} for total={total}" + + +class TestVideoUrlParsing: + """Verify video_url content type extraction from OpenAI messages.""" + + def _make_model(self): + """Create an unloaded model instance for testing.""" + model = MLXMultimodalLM.__new__(MLXMultimodalLM) + model._loaded = False + model._video_native = False + return model + + def _extract_video_inputs(self, messages): + """Use the actual _collect_video_inputs helper.""" + model = self._make_model() + return model._collect_video_inputs(messages) + + def test_video_url_dict_format(self): + messages = [ + { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": {"url": "https://example.com/video.mp4"}, + }, + {"type": "text", "text": "Describe this video"}, + ], + } + ] + result = self._extract_video_inputs(messages) + assert 0 in result + assert result[0] == ["https://example.com/video.mp4"] + + def test_video_url_string_format(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": "https://example.com/video.mp4"}, + {"type": "text", "text": "Describe"}, + ], + } + ] + result = self._extract_video_inputs(messages) + assert result[0] == ["https://example.com/video.mp4"] + + def test_video_type(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "video", "video": "/path/to/video.mp4"}, + {"type": "text", "text": "Describe"}, + ], + } + ] + result = self._extract_video_inputs(messages) + assert result[0] == ["/path/to/video.mp4"] + + def test_no_video(self): + messages = [ + { + "role": "user", + "content": [ + {"type": 
"text", "text": "Hello"}, + ], + } + ] + result = self._extract_video_inputs(messages) + assert len(result) == 0 + + def test_mixed_media(self): + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "https://example.com/img.jpg"}, + }, + { + "type": "video_url", + "video_url": {"url": "https://example.com/vid.mp4"}, + }, + {"type": "text", "text": "Compare"}, + ], + } + ] + result = self._extract_video_inputs(messages) + # Only video extracted, not image + assert result[0] == ["https://example.com/vid.mp4"] + + def test_multi_message_videos(self): + """Videos in different messages should be keyed by message index.""" + messages = [ + { + "role": "user", + "content": [ + {"type": "video", "video": "/path/first.mp4"}, + {"type": "text", "text": "First"}, + ], + }, + {"role": "assistant", "content": "OK"}, + { + "role": "user", + "content": [ + {"type": "video", "video": "/path/second.mp4"}, + {"type": "text", "text": "Second"}, + ], + }, + ] + result = self._extract_video_inputs(messages) + assert result[0] == ["/path/first.mp4"] + assert result[2] == ["/path/second.mp4"] + assert 1 not in result + + def test_multiple_videos_single_message(self): + """Multiple videos in one message should produce a list at that index.""" + messages = [ + { + "role": "user", + "content": [ + {"type": "video", "video": "/path/a.mp4"}, + {"type": "video_url", "video_url": {"url": "/path/b.mp4"}}, + {"type": "text", "text": "Compare these"}, + ], + } + ] + result = self._extract_video_inputs(messages) + assert result[0] == ["/path/a.mp4", "/path/b.mp4"] + + +class TestTranslateMessages: + """Verify OpenAI format to process_vision_info format translation.""" + + def _make_model(self): + """Create an unloaded model instance for testing translation.""" + model = MLXMultimodalLM.__new__(MLXMultimodalLM) + model._loaded = False + model._video_native = True + return model + + def test_text_only_passthrough(self): + model = self._make_model() + messages = [{"role": "user", "content": "Hello"}] + result = model._translate_messages_for_native_video(messages, 2.0, 128) + assert result[0]["content"] == "Hello" + + def test_video_url_translated(self): + import os + import tempfile + + # Create a temp file to act as a "video" + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: + f.write(b"\x00" * 100) + video_path = f.name + + try: + model = self._make_model() + messages = [ + { + "role": "user", + "content": [ + {"type": "video", "video": video_path}, + {"type": "text", "text": "Describe"}, + ], + } + ] + result = model._translate_messages_for_native_video(messages, 2.0, 128) + content = result[0]["content"] + + # Should have video and text items + types = [item["type"] for item in content] + assert "video" in types + assert "text" in types + + # Video item should have fps and max_frames + video_item = next(i for i in content if i["type"] == "video") + assert video_item["fps"] == 2.0 + assert video_item["max_frames"] == 128 + finally: + os.unlink(video_path) + + def test_video_url_type_translated(self): + import os + import tempfile + + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: + f.write(b"\x00" * 100) + video_path = f.name + + try: + model = self._make_model() + messages = [ + { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": {"url": video_path}, + }, + {"type": "text", "text": "Describe"}, + ], + } + ] + result = model._translate_messages_for_native_video(messages, 1.0, 64) + content = 
result[0]["content"] + + types = [item["type"] for item in content] + assert "video" in types + assert "text" in types + + video_item = next(i for i in content if i["type"] == "video") + assert video_item["fps"] == 1.0 + assert video_item["max_frames"] == 64 + finally: + os.unlink(video_path) + + +class TestCollectVideoInputsPydantic: + """Verify _collect_video_inputs handles Pydantic models correctly.""" + + def _make_model(self): + model = MLXMultimodalLM.__new__(MLXMultimodalLM) + model._loaded = False + model._video_native = False + return model + + def test_pydantic_model_dump(self): + """Pydantic v2 objects with model_dump() are handled.""" + + class FakeContent: + def model_dump(self, exclude_none=False): + return {"type": "video", "video": "/path/to/video.mp4"} + + messages = [{"role": "user", "content": [FakeContent()]}] + result = self._make_model()._collect_video_inputs(messages) + assert result[0] == ["/path/to/video.mp4"] + + def test_pydantic_v1_dict(self): + """Pydantic v1 objects with dict() are handled.""" + + class FakeContent: + def dict(self): + return { + "type": "video_url", + "video_url": {"url": "https://example.com/v.mp4"}, + "image_url": None, + } + + messages = [{"role": "user", "content": [FakeContent()]}] + result = self._make_model()._collect_video_inputs(messages) + assert result[0] == ["https://example.com/v.mp4"] + + def test_empty_video_url_skipped(self): + """Empty video URL dicts are skipped.""" + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": ""}}, + ], + } + ] + result = self._make_model()._collect_video_inputs(messages) + assert len(result) == 0 + + +class TestToolForwarding: + """Verify tools are popped from kwargs before native video path.""" + + def test_tools_not_in_kwargs_after_pop(self): + """Ensure tools don't leak into **kwargs for mlx_vlm.generate().""" + model = MLXMultimodalLM.__new__(MLXMultimodalLM) + model._loaded = False + model._video_native = True + + tools = [{"type": "function", "function": {"name": "test"}}] + kwargs = {"tools": tools, "video_fps": 2.0, "video_max_frames": 64} + + # Simulate what chat() does: pop tools before native video branch + video_fps = kwargs.pop("video_fps", 2.0) + video_max_frames = kwargs.pop("video_max_frames", 128) + popped_tools = kwargs.pop("tools", None) + + assert popped_tools == tools + assert "tools" not in kwargs + + def test_generate_native_video_accepts_tools_param(self): + """Verify _generate_native_video signature accepts tools kwarg.""" + import inspect + + sig = inspect.signature(MLXMultimodalLM._generate_native_video) + params = list(sig.parameters.keys()) + assert "tools" in params + + def test_prepare_native_video_inputs_accepts_tools(self): + """Verify preprocessing helper also accepts tools.""" + import inspect + + sig = inspect.signature(MLXMultimodalLM._prepare_native_video_inputs) + params = list(sig.parameters.keys()) + assert "tools" in params + + def test_generate_imports_from_video_generate(self): + """Verify _generate_native_video uses mlx_vlm.video_generate.generate.""" + import inspect + + source = inspect.getsource(MLXMultimodalLM._generate_native_video) + assert "from mlx_vlm.video_generate import generate" in source + + +class TestIsBase64Video: + def test_detects_base64_video(self): + assert is_base64_video("data:video/mp4;base64,AAAA") is True + + def test_rejects_non_video(self): + assert is_base64_video("data:image/jpeg;base64,AAAA") is False + assert is_base64_video("https://example.com/video.mp4") is False diff --git 
a/vllm_mlx/__init__.py b/vllm_mlx/__init__.py
index fd0f9c9..7dc3dc9 100644
--- a/vllm_mlx/__init__.py
+++ b/vllm_mlx/__init__.py
@@ -71,7 +71,7 @@ def __getattr__(name):
 
     # vLLM integration components (require torch)
     if name == "MLXPlatform":
-        from vllm_mlx.platform import MLXPlatform
+        from vllm_mlx.vllm_platform import MLXPlatform
         return MLXPlatform
 
     if name == "MLXWorker":
diff --git a/vllm_mlx/api/__init__.py b/vllm_mlx/api/__init__.py
index 89555c8..666962a 100644
--- a/vllm_mlx/api/__init__.py
+++ b/vllm_mlx/api/__init__.py
@@ -65,6 +65,8 @@ from .utils import (
     MLLM_PATTERNS,
     SPECIAL_TOKENS_PATTERN,
+    StreamingThinkRouter,
+    StreamingToolCallFilter,
     clean_output_text,
     extract_multimodal_content,
     is_mllm_model,
@@ -118,6 +120,8 @@
     "MLLM_PATTERNS",
     "SPECIAL_TOKENS_PATTERN",
     "strip_special_tokens",
+    "StreamingToolCallFilter",
+    "StreamingThinkRouter",
     # Tool calling
     "parse_tool_calls",
     "convert_tools_for_template",
diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py
index dff61df..8b9ee08 100644
--- a/vllm_mlx/api/models.py
+++ b/vllm_mlx/api/models.py
@@ -209,6 +209,10 @@ class ChatCompletionRequest(BaseModel):
     enable_thinking: bool | None = None
     # Number of completions (only n=1 supported)
     n: int | None = None
+    # SpecPrefill: per-request enable/disable (None = server decides)
+    specprefill: bool | None = None
+    # SpecPrefill: per-request keep percentage (0.0-1.0, None = use server default)
+    specprefill_keep_pct: float | None = None
 
 
 class AssistantMessage(BaseModel):
diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py
index e33c348..069e5ea 100644
--- a/vllm_mlx/api/utils.py
+++ b/vllm_mlx/api/utils.py
@@ -3,10 +3,13 @@
 Utility functions for text processing and model detection.
 """
 
+import logging
 import re
 
 from .models import Message
 
+logger = logging.getLogger(__name__)
+
 # =============================================================================
 # Special Token Patterns
 # =============================================================================
@@ -200,6 +203,224 @@ def extract_json_from_response(text: str) -> str:
     return text
 
 
+# =============================================================================
+# Streaming Tool Call Filter
+# =============================================================================
+
+# Safety cap for tool call buffer (bytes). If a tool call block never closes,
+# the buffer is capped to prevent unbounded memory growth. In practice, the
+# buffer is bounded by max_tokens (~100KB at 32768 tokens), but this cap
+# protects against pathological cases.
+_MAX_TOOL_BUFFER_BYTES = 1_048_576  # 1 MB
+
+# Tags that delimit tool call blocks in streaming output.
+# Content inside these tags should be suppressed during streaming because
+# it will be re-emitted as structured tool_use blocks after parsing.
+_TOOL_CALL_TAGS = [
+    ("<tool_call>", "</tool_call>"),
+    ("", ""),
+    (""),
+    ("[TOOL_CALL]", "[/TOOL_CALL]"),
+    ("[Calling tool", "]\n"),  # Qwen3 bracket-style: [Calling tool: func({...})]\n
+]
+
+
+class StreamingToolCallFilter:
+    """Buffer streaming text to suppress tool call markup.
+
+    Tool call XML (e.g. <tool_call>...</tool_call>) arrives
+    split across multiple streaming deltas. This filter detects entry into a
+    tool call block, suppresses all output until the block closes, and emits
+    only non-tool-call text.
+
+    The full unfiltered text is still accumulated separately for tool call
+    parsing at stream end.
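+
+    Example (illustrative): for the deltas "<tool" and "_call>{...}</tool_call> done",
+    only " done" is emitted; the tool call markup itself is suppressed.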
+    """
+
+    def __init__(self):
+        self._buffer = ""
+        self._in_block = False
+        self._close_tag = ""
+        # Longest open tag - used to determine how much buffer to hold back
+        self._max_open_len = max(len(t[0]) for t in _TOOL_CALL_TAGS)
+
+    def process(self, delta: str) -> str:
+        """Process a streaming delta. Returns text to emit (may be empty)."""
+        self._buffer += delta
+
+        if self._in_block:
+            return self._consume_block()
+        else:
+            return self._scan_for_open()
+
+    def _scan_for_open(self) -> str:
+        """Scan buffer for tool call open tags. Emit safe text."""
+        # Check for complete open tags
+        for open_tag, close_tag in _TOOL_CALL_TAGS:
+            idx = self._buffer.find(open_tag)
+            if idx >= 0:
+                # Found an open tag - emit text before it, enter block mode
+                emit = self._buffer[:idx]
+                self._buffer = self._buffer[idx + len(open_tag) :]
+                self._in_block = True
+                self._close_tag = close_tag
+                # Process remainder in case close tag is already in buffer
+                after = self._consume_block()
+                return emit + after
+
+        # No complete open tag found. Check if buffer ends with a partial
+        # match of any open tag - hold that back to avoid emitting a fragment.
+        hold_back = 0
+        for open_tag, _ in _TOOL_CALL_TAGS:
+            for prefix_len in range(min(len(open_tag), len(self._buffer)), 0, -1):
+                if self._buffer.endswith(open_tag[:prefix_len]):
+                    hold_back = max(hold_back, prefix_len)
+                    break
+
+        if hold_back > 0:
+            emit = self._buffer[:-hold_back]
+            self._buffer = self._buffer[-hold_back:]
+            return emit
+
+        # No partial match - safe to emit everything
+        emit = self._buffer
+        self._buffer = ""
+        return emit
+
+    def _consume_block(self) -> str:
+        """Consume content inside a tool call block. Returns empty string
+        unless the block closes and there's text after it."""
+        idx = self._buffer.find(self._close_tag)
+        if idx >= 0:
+            # Block closed - discard content up to and including close tag
+            self._buffer = self._buffer[idx + len(self._close_tag) :]
+            self._in_block = False
+            self._close_tag = ""
+            # Process remainder - might have more text or another tool call
+            if self._buffer:
+                return self._scan_for_open()
+            return ""
+        # Still inside block - suppress everything but cap buffer size
+        if len(self._buffer) > _MAX_TOOL_BUFFER_BYTES:
+            logger.warning(
+                f"Tool call buffer exceeded {_MAX_TOOL_BUFFER_BYTES} bytes, "
+                f"discarding and exiting block"
+            )
+            self._buffer = ""
+            self._in_block = False
+            self._close_tag = ""
+        return ""
+
+    def flush(self) -> str:
+        """Flush remaining buffer at end of stream."""
+        if self._in_block:
+            # Unterminated tool call block - discard
+            self._buffer = ""
+            self._in_block = False
+            return ""
+        emit = self._buffer
+        self._buffer = ""
+        return emit
+
+
+# =============================================================================
+# Streaming Think Block Router
+# =============================================================================
+
+
+class StreamingThinkRouter:
+    """Route <think>...</think> content to separate Anthropic thinking blocks.
+
+    Instead of emitting thinking content as plain text (where it's
+    indistinguishable from the response), this router yields tagged
+    pieces that the streaming handler can emit as proper Anthropic
+    content block types.
+
+    Each call to process() returns a list of (block_type, text) tuples:
+    - ("thinking", text) for content inside <think>...</think>
+    - ("text", text) for content outside think blocks
+
+    Args:
+        start_in_thinking: If True, assume the model starts in thinking
+            mode (e.g. MiniMax adds <think> to the generation prompt,
+            so the tag never appears in the output stream).
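+
+    Example (illustrative):
+        process("<think>plan</think>Hi") returns
+        [("thinking", "plan"), ("text", "Hi")].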
+    """
+
+    def __init__(self, start_in_thinking: bool = False):
+        self._buffer = ""
+        self._in_think = start_in_thinking
+
+    def process(self, delta: str) -> list[tuple[str, str]]:
+        """Process a delta. Returns list of (block_type, text) pieces."""
+        self._buffer += delta
+        pieces = []
+        self._extract_pieces(pieces)
+        return pieces
+
+    def _extract_pieces(self, pieces: list[tuple[str, str]]) -> None:
+        """Extract all complete pieces from the buffer."""
+        while True:
+            if self._in_think:
+                idx = self._buffer.find("</think>")
+                if idx >= 0:
+                    # Emit thinking content, exit think mode
+                    thinking = self._buffer[:idx]
+                    self._buffer = self._buffer[idx + len("</think>") :]
+                    self._in_think = False
+                    if thinking:
+                        pieces.append(("thinking", thinking))
+                    continue  # Process remainder
+                else:
+                    # Check for partial close tag at end
+                    for plen in range(min(len("</think>"), len(self._buffer)), 0, -1):
+                        if self._buffer.endswith("</think>"[:plen]):
+                            # Hold back partial match
+                            emit = self._buffer[:-plen]
+                            self._buffer = self._buffer[-plen:]
+                            if emit:
+                                pieces.append(("thinking", emit))
+                            return
+                    # No partial match - emit all as thinking
+                    if self._buffer:
+                        pieces.append(("thinking", self._buffer))
+                        self._buffer = ""
+                    return
+            else:
+                idx = self._buffer.find("<think>")
+                if idx >= 0:
+                    # Emit text before tag, enter think mode
+                    before = self._buffer[:idx]
+                    self._buffer = self._buffer[idx + len("<think>") :]
+                    self._in_think = True
+                    if before:
+                        pieces.append(("text", before))
+                    continue  # Process remainder
+                else:
+                    # Check for partial open tag at end
+                    for plen in range(min(len("<think>"), len(self._buffer)), 0, -1):
+                        if self._buffer.endswith("<think>"[:plen]):
+                            emit = self._buffer[:-plen]
+                            self._buffer = self._buffer[-plen:]
+                            if emit:
+                                pieces.append(("text", emit))
+                            return
+                    # No partial match - emit all as text
+                    if self._buffer:
+                        pieces.append(("text", self._buffer))
+                        self._buffer = ""
+                    return
+
+    def flush(self) -> list[tuple[str, str]]:
+        """Flush remaining buffer at end of stream."""
+        pieces = []
+        if self._buffer:
+            block_type = "thinking" if self._in_think else "text"
+            pieces.append((block_type, self._buffer))
+        self._buffer = ""
+        self._in_think = False
+        return pieces
+
+
 # =============================================================================
 # Model Detection
 # =============================================================================
@@ -267,9 +488,9 @@ def _content_to_text(content) -> str:
     parts = []
     for item in content:
         if hasattr(item, "model_dump"):
-            item = item.model_dump()
+            item = item.model_dump(exclude_none=True)
         elif hasattr(item, "dict"):
-            item = item.dict()
+            item = {k: v for k, v in item.dict().items() if v is not None}
         if isinstance(item, dict) and item.get("type") == "text":
             parts.append(item.get("text", ""))
     return "\n".join(parts)
@@ -412,9 +633,9 @@ def extract_multimodal_content(
     for item in content:
         # Handle both Pydantic models and dicts
         if hasattr(item, "model_dump"):
-            item = item.model_dump(exclude_none=True)
+            item = item.model_dump(exclude_none=True)
         elif hasattr(item, "dict"):
-            item = item.dict()
+            item = {k: v for k, v in item.dict().items() if v is not None}
 
         item_type = item.get("type", "")
diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index 5cba7dc..74531e1 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -302,6 +302,16 @@ def serve_command(args):
         print(f"Prefix cache: max_entries={args.prefix_cache_size}")
     else:
         print("Mode: Simple (maximum throughput)")
+    if args.enable_mtp:
+        print("MTP: enabled (native speculative decoding)")
+    if args.enable_mtp and getattr(args, "mllm", False):
+        print("MTP + 
MLLM: per-request routing (text-only → MTP, media → MLLM)") + if args.specprefill and args.specprefill_draft_model: + print( + f"SpecPrefill: enabled (draft={args.specprefill_draft_model}, " + f"threshold={args.specprefill_threshold}, " + f"keep={args.specprefill_keep_pct*100:.0f}%)" + ) # Check port availability before loading model (avoid wasting RAM on conflict) import socket @@ -337,6 +347,12 @@ def serve_command(args): cloud_threshold=args.cloud_threshold, cloud_api_base=args.cloud_api_base, cloud_api_key=args.cloud_api_key, + served_model_name=args.served_model_name, + mtp=args.enable_mtp, + specprefill_enabled=args.specprefill, + specprefill_threshold=args.specprefill_threshold, + specprefill_keep_pct=args.specprefill_keep_pct, + specprefill_draft_model=args.specprefill_draft_model, ) except Exception as e: # Show clean error instead of raw traceback @@ -790,6 +806,12 @@ def main(): # Serve command serve_parser = subparsers.add_parser("serve", help="Start OpenAI-compatible server") serve_parser.add_argument("model", type=str, help="Model to serve") + serve_parser.add_argument( + "--served-model-name", + type=str, + default=None, + help="The model name used in the API. If not specified, the model argument is used.", + ) serve_parser.add_argument( "--host", type=str, default="0.0.0.0", help="Host to bind" ) @@ -935,6 +957,44 @@ def main(): help="Skip MTP acceptance check for maximum speed. " "~5-10%% wrong tokens. Best for chat, not for code.", ) + # Prefill step size + serve_parser.add_argument( + "--prefill-step-size", + type=int, + default=2048, + help="Chunk size for prompt prefill processing. Larger values use more memory " + "but can improve prefill throughput. (default: 2048)", + ) + # SpecPrefill (attention-based sparse prefill using draft model) + serve_parser.add_argument( + "--specprefill", + action="store_true", + default=False, + help="Enable SpecPrefill: use a small draft model to score token importance, " + "then sparse-prefill only the important tokens on the target model. " + "Reduces TTFT on long prompts. Requires --specprefill-draft-model.", + ) + serve_parser.add_argument( + "--specprefill-threshold", + type=int, + default=8192, + help="Minimum suffix tokens to trigger SpecPrefill (default: 8192). " + "Shorter prompts use full prefill (scoring overhead > savings).", + ) + serve_parser.add_argument( + "--specprefill-keep-pct", + type=float, + default=0.3, + help="Fraction of tokens to keep during sparse prefill (default: 0.3). " + "Lower = faster prefill but more quality loss.", + ) + serve_parser.add_argument( + "--specprefill-draft-model", + type=str, + default=None, + help="Path to small draft model for SpecPrefill importance scoring. 
" + "Must share the same tokenizer as the target model.", + ) # MCP options serve_parser.add_argument( "--mcp-config", diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index 90e29ae..d7fd283 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -41,9 +41,9 @@ def _extract_media_from_messages(messages: list[dict[str, Any]]) -> tuple: for item in content: # Handle Pydantic models if hasattr(item, "model_dump"): - item = item.model_dump() + item = item.model_dump(exclude_none=True) elif hasattr(item, "dict"): - item = item.dict() + item = {k: v for k, v in item.dict().items() if v is not None} if not isinstance(item, dict): continue diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index f1b54ea..5d10e94 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -28,6 +28,29 @@ GuidedGenerator = None +_MEDIA_TYPES = frozenset( + { + "image_url", + "video_url", + "audio_url", + "image", + "video", + "audio", + } +) + + +def _has_media_content(messages: list) -> bool: + """Check if any message contains media content (images, video, audio).""" + for msg in messages: + content = msg.get("content") + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") in _MEDIA_TYPES: + return True + return False + + class SimpleEngine(BaseEngine): """ Simple engine for direct model calls. @@ -44,9 +67,14 @@ def __init__( force_mllm: bool = False, draft_model: str | None = None, num_draft_tokens: int = 4, + mtp: bool = False, prefill_step_size: int = 2048, kv_bits: int | None = None, kv_group_size: int = 64, + specprefill_enabled: bool = False, + specprefill_threshold: int = 8192, + specprefill_keep_pct: float = 0.3, + specprefill_draft_model: str | None = None, ): """ Initialize the simple engine. 
@@ -58,9 +86,14 @@ def __init__( force_mllm: Force loading as MLLM even if not auto-detected draft_model: Optional draft model path for speculative decoding num_draft_tokens: Number of tokens to generate speculatively per step + mtp: Enable native MTP speculative decoding (model must have MTP head) prefill_step_size: Tokens to process per prefill chunk (default: 2048) kv_bits: KV cache quantization bits (None=no quantization, 4 or 8) kv_group_size: Group size for KV cache quantization (default: 64) + specprefill_enabled: Enable SpecPrefill (attention-based sparse prefill) + specprefill_threshold: Minimum suffix tokens to trigger SpecPrefill + specprefill_keep_pct: Fraction of tokens to keep (default: 0.3) + specprefill_draft_model: Path to small draft model for importance scoring """ self._model_name = model_name self._trust_remote_code = trust_remote_code @@ -68,15 +101,35 @@ def __init__( self._is_mllm = force_mllm or is_mllm_model(model_name) self._draft_model_name = draft_model self._num_draft_tokens = num_draft_tokens + self._mtp = mtp self._prefill_step_size = prefill_step_size self._kv_bits = kv_bits self._kv_group_size = kv_group_size + + # SpecPrefill config + self._specprefill_enabled = specprefill_enabled + self._specprefill_threshold = specprefill_threshold + self._specprefill_keep_pct = specprefill_keep_pct + self._specprefill_draft_model_path = specprefill_draft_model + self._model = None self._loaded = False + # Per-request routing state (MLLM+MTP mode) + self._text_model = None + self._text_tokenizer = None + + # SpecPrefill draft model (loaded at start if enabled) + self._draft_model = None + # Lock to serialize MLX operations (prevents Metal command buffer conflicts) self._generation_lock = asyncio.Lock() + # System prompt KV cache (reduces repeated prefill across requests) + self._system_kv_snapshot = None # List of (keys, values) per backbone layer + self._system_kv_hash = None # Hash of system prefix text + self._system_kv_token_count = 0 # Tokens in cached prefix + @property def model(self): """Get the underlying MLXLanguageModel instance.""" @@ -142,6 +195,7 @@ async def start(self) -> None: trust_remote_code=self._trust_remote_code, draft_model=self._draft_model_name, num_draft_tokens=self._num_draft_tokens, + mtp=self._mtp, prefill_step_size=self._prefill_step_size, kv_bits=self._kv_bits, kv_group_size=self._kv_group_size, @@ -150,17 +204,81 @@ async def start(self) -> None: self._model.load() self._loaded = True + # Build parallel mlx_lm TextModel for text-only MTP routing + if self._is_mllm and self._mtp: + try: + from ..text_model_from_vlm import build_text_model + + self._text_model = build_text_model(self._model.model, self._model_name) + + if ( + self._text_model is not None + and hasattr(self._text_model, "mtp") + and self._text_model.mtp is not None + ): + self._text_tokenizer = self._model.get_tokenizer() + + # Apply Qwen3.5 eos_token fix (matches MLXLanguageModel.load) + if "qwen3" in self._model_name.lower(): + self._text_tokenizer.eos_token = "<|im_end|>" + self._text_tokenizer.eos_token_id = ( + self._text_tokenizer.convert_tokens_to_ids("<|im_end|>") + ) + + logger.info( + "MLLM+MTP routing: text-only → mlx_lm TextModel (MTP=True), " + "media → mlx_vlm" + ) + else: + logger.warning( + "TextModel built but no MTP — text-only requests won't use MTP" + ) + self._text_model = None + + except Exception as e: + logger.error("MLLM+MTP routing setup failed: %s", e) + self._text_model = None + self._text_tokenizer = None + + # Load SpecPrefill draft model 
(small model for importance scoring) + if self._specprefill_enabled and self._specprefill_draft_model_path: + try: + from mlx_lm import load as mlx_lm_load + + self._draft_model, _ = mlx_lm_load(self._specprefill_draft_model_path) + logger.info( + "SpecPrefill: draft model loaded (%s), threshold=%d, keep=%.0f%%", + self._specprefill_draft_model_path, + self._specprefill_threshold, + self._specprefill_keep_pct * 100, + ) + except Exception as e: + logger.error("SpecPrefill: draft model load failed: %s", e) + self._draft_model = None + spec_info = "" if self._draft_model_name and not self._is_mllm: spec_info = f", speculative={self._draft_model_name}" + mtp_info = f", MTP={self._mtp}" if self._mtp else "" + routing = ", routing=per-request" if self._text_model is not None else "" + specprefill_info = ( + ", SpecPrefill=active" if self._draft_model is not None else "" + ) logger.info( - f"SimpleEngine loaded: {self._model_name} (MLLM={self._is_mllm}{spec_info})" + f"SimpleEngine loaded: {self._model_name} " + f"(MLLM={self._is_mllm}{spec_info}{mtp_info}{routing}{specprefill_info})" ) async def stop(self) -> None: """Stop the engine and cleanup resources.""" self._model = None + self._text_model = None + self._text_tokenizer = None + self._draft_model = None self._loaded = False + self._system_kv_snapshot = None + self._system_kv_hash = None + self._system_kv_token_count = 0 logger.info("SimpleEngine stopped") async def generate( @@ -239,6 +357,56 @@ async def stream_generate( if not self._loaded: await self.start() + # Per-request specprefill overrides (from extra_body) + specprefill_override = kwargs.pop("specprefill", None) + specprefill_keep_pct_override = kwargs.pop("specprefill_keep_pct", None) + + # SpecPrefill for non-MLLM models (MLLM+MTP handles it in _stream_generate_text) + if not self._is_mllm and self._draft_model is not None: + use_specprefill = True + if specprefill_override is False: + use_specprefill = False + + if use_specprefill: + tokenizer = self._model.tokenizer + add_special = tokenizer.bos_token is None or not prompt.startswith( + tokenizer.bos_token + ) + tokens_list = tokenizer.encode(prompt, add_special_tokens=add_special) + n_tokens = len(tokens_list) + + # Threshold check (skip when force-enabled via per-request override) + if ( + specprefill_override is not True + and n_tokens <= self._specprefill_threshold + ): + use_specprefill = False + + # Upper bound: cap to avoid draft model OOM + _SPECPREFILL_MAX_TOKENS = 65536 + if use_specprefill and n_tokens > _SPECPREFILL_MAX_TOKENS: + logger.warning( + "SpecPrefill: prompt %d tokens exceeds max %d, " + "falling back to normal path", + n_tokens, + _SPECPREFILL_MAX_TOKENS, + ) + use_specprefill = False + + if use_specprefill: + async for output in self._stream_generate_specprefill( + prompt, + tokens_list, + max_tokens, + temperature, + top_p, + stop=stop, + specprefill_keep_pct=specprefill_keep_pct_override, + **kwargs, + ): + yield output + return + async with self._generation_lock: accumulated_text = "" prompt_tokens = 0 @@ -350,9 +518,42 @@ async def chat( # Convert tools for template if provided template_tools = convert_tools_for_template(tools) if tools else None + # Text-only MTP routing — BEFORE the lock because + # _stream_generate_text() acquires _generation_lock internally. 
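+        # (asyncio.Lock is not reentrant, so acquiring it here as well would deadlock.)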
+ if ( + self._is_mllm + and self._text_model is not None + and not _has_media_content(messages) + ): + logger.info("Text-only request → LLM path (MTP=True) [non-streaming]") + last_chunk = None + async for chunk in self._stream_generate_text( + messages, + max_tokens, + temperature, + top_p, + stop=stop, + tools=template_tools, + **kwargs, + ): + last_chunk = chunk + if last_chunk is not None: + # _stream_generate_text yields accumulated text, not deltas + return GenerationOutput( + text=last_chunk.text, + tokens=[], + prompt_tokens=last_chunk.prompt_tokens, + completion_tokens=last_chunk.completion_tokens, + finish_reason=last_chunk.finish_reason or "stop", + ) + return GenerationOutput( + text="", tokens=[], prompt_tokens=0, + completion_tokens=0, finish_reason="stop", + ) + async with self._generation_lock: if self._is_mllm: - # For MLLM, use the chat method which handles images/videos + # For MLLM with media, use the chat method which handles images/videos # Run in thread pool to allow asyncio timeout to work output = await asyncio.to_thread( self._model.chat, @@ -496,8 +697,29 @@ async def stream_chat( # Convert tools for template template_tools = convert_tools_for_template(tools) if tools else None + # Per-request routing: text-only through mlx_lm with MTP + if ( + self._is_mllm + and self._text_model is not None + and not _has_media_content(messages) + ): + logger.info("Text-only request → LLM path (MTP=True)") + async for chunk in self._stream_generate_text( + messages, + max_tokens, + temperature, + top_p, + stop=stop, + tools=template_tools, + **kwargs, + ): + yield chunk + return + # Build prompt using tokenizer if self._is_mllm: + if self._text_model is not None: + logger.info("Media request → MLLM path") # For MLLM, use stream_chat which yields tokens incrementally. # Must hold the generation lock to prevent concurrent Metal # command buffer conflicts with other generation methods. @@ -563,6 +785,633 @@ async def stream_chat( ): yield output + async def _stream_generate_specprefill( + self, + prompt: str, + tokens: list[int], + max_tokens: int, + temperature: float, + top_p: float, + stop: list[str] | None = None, + specprefill_keep_pct: float | None = None, + **kwargs, + ) -> AsyncIterator[GenerationOutput]: + """SpecPrefill path for non-MTP models (Nemotron, GPT-OSS, etc). + + Scores token importance with the draft model, sparse-prefills the target + model, then generates autoregressively. Falls back to normal generation + on any error. 
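+
+        For example (illustrative figures): with keep_pct=0.3, a 40k-token
+        prompt is sparse-prefilled as roughly 12k tokens on the target model.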
+ """ + import mlx.core as mx + from mlx_lm.models.cache import make_prompt_cache + from mlx_lm.sample_utils import make_sampler + + model = self._model.model + tokenizer = self._model.tokenizer + n_tokens = len(tokens) + + async with self._generation_lock: + + def _run_all(): + try: + return _run_specprefill() + except Exception as e: + logger.error( + "SpecPrefill failed, falling back to normal path: %s", e + ) + return _run_normal() + + def _run_specprefill(): + """Score tokens, sparse prefill, generate autoregressively.""" + import time + from types import SimpleNamespace + + from ..specprefill import ( + cleanup_rope, + score_tokens, + select_chunks, + sparse_prefill, + ) + + cache = make_prompt_cache(model) + + try: + # Phase 1: Score with draft model + t0 = time.monotonic() + importance = score_tokens( + self._draft_model, + tokens, + prefill_step_size=self._prefill_step_size, + ) + t_score = time.monotonic() - t0 + + # Phase 2: Select important chunks + effective_keep = specprefill_keep_pct or self._specprefill_keep_pct + selected = select_chunks(importance, keep_pct=effective_keep) + n_selected = selected.shape[0] + + # Phase 3: Sparse prefill on target model + t0 = time.monotonic() + logits = sparse_prefill( + model, + tokens, + selected, + cache, + step_size=self._prefill_step_size, + ) + t_prefill = time.monotonic() - t0 + + logger.info( + "SpecPrefill: scored %d tokens in %.1fs, " + "sparse prefill %d/%d (keep=%.0f%%) in %.1fs", + n_tokens, + t_score, + n_selected, + n_tokens, + n_selected / n_tokens * 100, + t_prefill, + ) + + # Phase 4: Generate (simple autoregressive, no MTP) + sampler = make_sampler(temp=temperature, top_p=top_p) + eos_id = tokenizer.eos_token_id + y = sampler(logits[:, -1, :]) + mx.eval(y) + + results = [] + generated_ids = [] + prev_decoded = "" + + for _ in range(max_tokens): + tok_id = y.item() + generated_ids.append(tok_id) + + decoded = tokenizer.decode(generated_ids) + new_text = decoded[len(prev_decoded) :] + prev_decoded = decoded + + is_eos = tok_id == eos_id + results.append( + SimpleNamespace( + text=new_text, + finish_reason="stop" if is_eos else None, + ) + ) + + if is_eos: + break + + logits = model(y.reshape(1, -1), cache=cache) + y = sampler(logits[:, -1, :]) + mx.eval(y) + + return results + + finally: + cleanup_rope(model) + + def _run_normal(): + """Fallback: normal generation without specprefill.""" + from types import SimpleNamespace + + results = [] + for chunk in self._model.stream_generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + **kwargs, + ): + new_text = chunk.text if hasattr(chunk, "text") else str(chunk) + results.append( + SimpleNamespace( + text=new_text, + finish_reason=getattr(chunk, "finish_reason", None), + ) + ) + return results + + all_resps = await asyncio.to_thread(_run_all) + + # Yield results as GenerationOutput + accumulated_text = "" + token_count = 0 + finished = False + for i, resp in enumerate(all_resps): + token_count += 1 + new_text = resp.text + accumulated_text += new_text + + is_last = i == len(all_resps) - 1 + finished = is_last or token_count >= max_tokens + + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=n_tokens, + completion_tokens=token_count, + finished=finished, + finish_reason=resp.finish_reason or ("stop" if finished else None), + ) + + if finished: + break + + if not finished: + yield GenerationOutput( + text=accumulated_text, + new_text="", + prompt_tokens=n_tokens, + 
completion_tokens=token_count, + finished=True, + finish_reason="length", + ) + + async def _stream_generate_text( + self, + messages: list[dict[str, Any]], + max_tokens: int, + temperature: float, + top_p: float, + stop: list[str] | None = None, + tools: list | None = None, + **kwargs, + ) -> AsyncIterator[GenerationOutput]: + """Text-only generation via mlx_lm TextModel with MTP. + + Used when MLLM+MTP routing is active and the request has no media. + Runs the full generation in a single thread to maintain Metal safety. + + System prompt KV caching: on the first request, prefills system tokens + and snapshots backbone KV state. Subsequent requests with the same + system prompt restore the snapshot and only prefill the suffix tokens. + """ + import hashlib + import os + + import mlx.core as mx + from mlx_lm import stream_generate as mlx_stream_generate + from mlx_lm.models.cache import make_prompt_cache + from mlx_lm.sample_utils import make_sampler + + # Per-request specprefill overrides (from extra_body) + specprefill_override = kwargs.pop("specprefill", None) + specprefill_keep_pct = kwargs.pop("specprefill_keep_pct", None) + + # Read enable_thinking from env (set by runtime_patches, consistent with MLLM path) + enable_thinking_env = os.environ.get("VLLM_MLX_ENABLE_THINKING", "true") + enable_thinking = enable_thinking_env.lower() in ("true", "1", "yes") + + # Apply chat template for full prompt + template_kwargs = { + "tokenize": False, + "add_generation_prompt": True, + "enable_thinking": enable_thinking, + } + if tools: + template_kwargs["tools"] = tools + + try: + full_prompt = self._text_tokenizer.apply_chat_template( + messages, **template_kwargs + ) + except TypeError: + # Template doesn't accept tools= or enable_thinking= + template_kwargs.pop("tools", None) + template_kwargs.pop("enable_thinking", None) + full_prompt = self._text_tokenizer.apply_chat_template( + messages, **template_kwargs + ) + + # Build sampler + sampler = make_sampler(temp=temperature, top_p=top_p) + max_tokens = max_tokens or 4096 + + # --- System prompt KV caching --- + backbone_cache = None # Backbone-only cache (no MTP), used by both paths + prompt_to_send = full_prompt # Default: send full prompt text + cache_hit = False + system_token_count = 0 + full_token_count = 0 + system_hash = None + system_tokens = None + suffix_tokens = None + full_tokens_list = None + + # Extract system messages for caching + has_system = any(m.get("role") == "system" for m in messages) + + if has_system and self._text_model is not None: + # Find system prefix boundary in full prompt text. + # ChatML format: system section ends where first non-system message begins. + # Works with tools (rendered inside system section by Qwen templates). 
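+            # Illustrative ChatML layout assumed by this boundary search:
+            #   <|im_start|>system ... <|im_end|>
+            #   <|im_start|>user ... <|im_end|>
+            #   <|im_start|>assistant ...
+            # Everything before the first user/assistant marker is the cacheable prefix.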
+ system_prefix_end = -1 + for marker in ("<|im_start|>user\n", "<|im_start|>assistant\n"): + idx = full_prompt.find(marker) + if idx > 0: + system_prefix_end = idx + break + + if system_prefix_end > 0: + system_prefix_text = full_prompt[:system_prefix_end] + system_hash = hashlib.sha256(system_prefix_text.encode()).hexdigest()[ + :16 + ] + + # Tokenize both (matching stream_generate's tokenization logic) + tokenizer = self._text_tokenizer + add_special = tokenizer.bos_token is None or not full_prompt.startswith( + tokenizer.bos_token + ) + full_tokens_list = tokenizer.encode( + full_prompt, add_special_tokens=add_special + ) + full_token_count = len(full_tokens_list) + + system_tokens_list = tokenizer.encode( + system_prefix_text, add_special_tokens=add_special + ) + system_token_count = len(system_tokens_list) + + # Verify system tokens are a proper prefix of full tokens + prefix_valid = ( + len(full_tokens_list) > system_token_count + and full_tokens_list[:system_token_count] == system_tokens_list + ) + + if prefix_valid: + system_tokens = system_tokens_list + suffix_tokens = full_tokens_list[system_token_count:] + + if ( + system_hash == self._system_kv_hash + and self._system_kv_snapshot is not None + and system_token_count == self._system_kv_token_count + ): + # Cache HIT — restore KV state into fresh backbone cache + backbone_cache = make_prompt_cache(self._text_model) + for i, saved_state in enumerate(self._system_kv_snapshot): + backbone_cache[i].state = saved_state + + prompt_to_send = mx.array(suffix_tokens) + cache_hit = True + logger.info( + "System KV cache HIT: reusing %d cached tokens, " + "prefilling %d new tokens (hash=%s)", + system_token_count, + len(suffix_tokens), + system_hash, + ) + else: + # Cache MISS — will prefill system tokens and snapshot + logger.info( + "System KV cache MISS: will prefill %d system tokens, " + "%d suffix tokens (hash=%s)", + system_token_count, + len(suffix_tokens), + system_hash, + ) + else: + logger.debug( + "System KV cache: prefix token validation failed, " + "using full prompt (%d tokens)", + len(full_tokens_list), + ) + system_token_count = 0 + + # Determine if SpecPrefill should be used + # Per-request boolean override: True = force enable, False = force disable + if specprefill_override is False: + use_specprefill = False + elif specprefill_override is True and self._draft_model is not None: + use_specprefill = True # Force enable, skip threshold check + else: + use_specprefill = self._draft_model is not None + + # For specprefill, ensure we have token IDs (not just prompt text) + if use_specprefill and suffix_tokens is None and full_tokens_list is None: + tokenizer = self._text_tokenizer + add_special = tokenizer.bos_token is None or not full_prompt.startswith( + tokenizer.bos_token + ) + full_tokens_list = tokenizer.encode( + full_prompt, add_special_tokens=add_special + ) + full_token_count = len(full_tokens_list) + + # Tokens for specprefill: suffix (if system KV) or full prompt + specprefill_tokens = ( + suffix_tokens if suffix_tokens is not None else full_tokens_list + ) + specprefill_offset = system_token_count if suffix_tokens is not None else 0 + + # Threshold check: only use specprefill on long prompts + # (skipped when per-request boolean forces enable) + if ( + use_specprefill + and specprefill_override is not True + and ( + specprefill_tokens is None + or len(specprefill_tokens) <= self._specprefill_threshold + ) + ): + use_specprefill = False + + # Upper bound: cap specprefill to avoid draft model OOM on very long 
prompts + # 65536 tokens ~ 2GB draft KV cache on Qwen3.5-4B (32KB/token x 8 attn layers) + _SPECPREFILL_MAX_TOKENS = 65536 + if ( + use_specprefill + and specprefill_tokens is not None + and len(specprefill_tokens) > _SPECPREFILL_MAX_TOKENS + ): + logger.warning( + "SpecPrefill: prompt %d tokens exceeds max %d, " + "falling back to normal path", + len(specprefill_tokens), + _SPECPREFILL_MAX_TOKENS, + ) + use_specprefill = False + + # Run under generation lock, all Metal ops in single thread + async with self._generation_lock: + + def _run_all(): + nonlocal backbone_cache, prompt_to_send + + model = self._text_model + + # Cache MISS with valid prefix: prefill system tokens and snapshot + if ( + not cache_hit + and system_token_count > 0 + and system_tokens is not None + and suffix_tokens is not None + ): + mc = make_prompt_cache(model) + sys_arr = mx.array(system_tokens) + + # Prefill system tokens in chunks (matching generate_step) + step = self._prefill_step_size + while sys_arr.size > step: + model(sys_arr[:step][None], cache=mc) + mx.eval([c.state for c in mc]) + sys_arr = sys_arr[step:] + mx.clear_cache() + if sys_arr.size > 0: + model(sys_arr[None], cache=mc) + mx.eval([c.state for c in mc]) + + # Snapshot backbone cache (immutable mx.arrays, safe to reuse) + snapshot = [c.state for c in mc] + mx.eval([s for pair in snapshot for s in pair]) + + self._system_kv_snapshot = snapshot + self._system_kv_hash = system_hash + self._system_kv_token_count = system_token_count + + backbone_cache = mc + prompt_to_send = mx.array(suffix_tokens) + logger.info( + "System KV cache: stored %d-token snapshot (%.1f MB), " + "prefilling %d remaining", + system_token_count, + sum(c.nbytes for c in mc) / 1e6, + len(suffix_tokens), + ) + + # --- SpecPrefill path (with fallback to normal on failure) --- + if use_specprefill: + try: + return _run_specprefill(model, backbone_cache) + except Exception as e: + logger.error( + "SpecPrefill failed, falling back to normal MTP path: %s", + e, + ) + # Discard potentially corrupted cache + backbone_cache = None + prompt_to_send = full_prompt + + # --- Normal path (MTP via mlx_lm stream_generate) --- + prompt_cache = None + if backbone_cache is not None: + # Add MTP cache on top of backbone + if hasattr(model, "make_mtp_cache"): + mtp_cache = model.make_mtp_cache() + prompt_cache = backbone_cache + mtp_cache + else: + prompt_cache = backbone_cache + + results = [] + gen_kwargs = dict( + max_tokens=max_tokens, + sampler=sampler, + mtp=True, + prefill_step_size=self._prefill_step_size, + ) + if prompt_cache is not None: + gen_kwargs["prompt_cache"] = prompt_cache + + for resp in mlx_stream_generate( + model, + self._text_tokenizer, + prompt=prompt_to_send, + **gen_kwargs, + ): + results.append(resp) + return results + + def _run_specprefill(model, bc): + """Score tokens, sparse prefill, generate without MTP.""" + from types import SimpleNamespace + + from ..specprefill import ( + cleanup_rope, + score_tokens, + select_chunks, + sparse_prefill, + ) + + # Create backbone cache if not already from system KV + if bc is None: + bc = make_prompt_cache(model) + + try: + # Phase 1: Score with draft model + import time + + t0 = time.monotonic() + importance = score_tokens( + self._draft_model, + specprefill_tokens, + prefill_step_size=self._prefill_step_size, + ) + t_score = time.monotonic() - t0 + + # Phase 2: Select important chunks + effective_keep = specprefill_keep_pct or self._specprefill_keep_pct + selected = select_chunks(importance, keep_pct=effective_keep) + n_selected 
= selected.shape[0] + n_total = len(specprefill_tokens) + + # Phase 3: Sparse prefill on target model + t0 = time.monotonic() + logits = sparse_prefill( + model, + specprefill_tokens, + selected, + bc, + step_size=self._prefill_step_size, + position_offset=specprefill_offset, + ) + t_prefill = time.monotonic() - t0 + + logger.info( + "SpecPrefill: scored %d tokens in %.1fs, " + "sparse prefill %d/%d (keep=%.0f%%) in %.1fs " + "(offset=%d, effective_keep=%.2f)", + n_total, + t_score, + n_selected, + n_total, + n_selected / n_total * 100, + t_prefill, + specprefill_offset, + effective_keep, + ) + + # Phase 4: Generate (simple autoregressive, no MTP) + eos_id = self._text_tokenizer.eos_token_id + y = sampler(logits[:, -1, :]) + mx.eval(y) + + results = [] + generated_ids = [] + prev_decoded = "" + + for _ in range(max_tokens): + tok_id = y.item() + generated_ids.append(tok_id) + + # Incremental text decode + decoded = self._text_tokenizer.decode(generated_ids) + new_text = decoded[len(prev_decoded) :] + prev_decoded = decoded + + is_eos = tok_id == eos_id + results.append( + SimpleNamespace( + text=new_text, + finish_reason="stop" if is_eos else None, + ) + ) + + if is_eos: + break + + # Next token + logits = model(y.reshape(1, -1), cache=bc) + y = sampler(logits[:, -1, :]) + mx.eval(y) + + return results + + finally: + cleanup_rope(model) + + all_resps = await asyncio.to_thread(_run_all) + + # Yield results as GenerationOutput + accumulated_text = "" + token_count = 0 + finished = False + for i, resp in enumerate(all_resps): + token_count += 1 + new_text = resp.text if hasattr(resp, "text") else str(resp) + accumulated_text += new_text + + # Check stop sequences (mlx_lm doesn't handle these natively) + stop_hit = False + if stop: + for stop_seq in stop: + idx = accumulated_text.find(stop_seq) + if idx != -1: + # Trim both accumulated and new_text so SSE streams + # never emit the stop sequence or anything after it. 
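+                        # Illustrative: accumulated_text="Hi STOP", new_text="STOP",
+                        # stop_seq="STOP" -> idx=3, overshoot=4, so accumulated_text
+                        # becomes "Hi " and new_text becomes "".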
+ overshoot = len(accumulated_text) - idx + accumulated_text = accumulated_text[:idx] + new_text = new_text[: max(0, len(new_text) - overshoot)] + stop_hit = True + break + + is_last = i == len(all_resps) - 1 + finished = stop_hit or is_last or token_count >= max_tokens + + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=full_token_count or 0, + completion_tokens=token_count, + finished=finished, + finish_reason=getattr(resp, "finish_reason", None) + or ("stop" if finished else None), + ) + + if finished: + break + + if not finished: + yield GenerationOutput( + text=accumulated_text, + new_text="", + prompt_tokens=full_token_count or 0, + completion_tokens=token_count, + finished=True, + finish_reason="length", + ) + def get_stats(self) -> dict[str, Any]: """Get engine statistics.""" stats = { @@ -572,6 +1421,29 @@ def get_stats(self) -> dict[str, Any]: "loaded": self._loaded, } + # SpecPrefill stats + if self._draft_model is not None: + stats["specprefill"] = { + "enabled": True, + "draft_model": self._specprefill_draft_model_path, + "threshold": self._specprefill_threshold, + "keep_pct": self._specprefill_keep_pct, + } + + # System KV cache stats + if self._system_kv_snapshot is not None: + cache_bytes = 0 + for entry in self._system_kv_snapshot: + if isinstance(entry, tuple) and len(entry) == 2: + cache_bytes += entry[0].nbytes + entry[1].nbytes + elif isinstance(entry, list): + cache_bytes += sum(a.nbytes for a in entry if a is not None) + stats["system_kv_cache"] = { + "tokens": self._system_kv_token_count, + "hash": self._system_kv_hash, + "memory_mb": round(cache_bytes / 1e6, 1), + } + # Include Metal memory stats try: import mlx.core as mx diff --git a/vllm_mlx/engine_core.py b/vllm_mlx/engine_core.py index f6089dd..92e4d28 100644 --- a/vllm_mlx/engine_core.py +++ b/vllm_mlx/engine_core.py @@ -154,11 +154,17 @@ async def _engine_loop(self) -> None: stream_interval = self.config.stream_interval use_simple_streaming = stream_interval == 1 - # Emergency memory pressure threshold — dynamic based on gpu_memory_utilization - # Uses a 5% gap above the soft limit (capped at 99%) to allow temporary spikes. + # Emergency memory pressure threshold — dynamic based on gpu_memory_utilization. + # Uses Metal's max recommended working set when available, falling back to + # device memory. Applies a 5% gap above the soft limit (capped at 99%). 
_gpu_mem_util = self.config.gpu_memory_utilization try: - _device_mem = mx.device_info().get("memory_size", 200 * 1024 * 1024 * 1024) + _device_info = mx.device_info() + _max_recommended = _device_info.get( + "max_recommended_working_set_size", + _device_info.get("memory_size", 0), + ) + _device_mem = _max_recommended if _max_recommended > 0 else 200 * 1024 * 1024 * 1024 _memory_pressure_threshold = int( _device_mem * min(_gpu_mem_util + 0.05, 0.99) ) @@ -250,7 +256,9 @@ async def _engine_loop(self) -> None: except asyncio.CancelledError: break except Exception as e: - logger.error(f"Engine loop error: {e}") + import traceback + + logger.error(f"Engine loop error: {e}\n{traceback.format_exc()}") await asyncio.sleep(0.1) async def add_request( diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py index 2fec268..dd272aa 100644 --- a/vllm_mlx/mllm_scheduler.py +++ b/vllm_mlx/mllm_scheduler.py @@ -28,6 +28,7 @@ from typing import Any import mlx.core as mx +from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer from .mllm_batch_generator import ( MLLMBatchGenerator, @@ -201,6 +202,9 @@ def __init__( self.request_id_to_uid: dict[str, int] = {} self.uid_to_request_id: dict[int, str] = {} + # Per-request streaming detokenizers for UTF-8-safe incremental decode + self._detokenizer_pool: dict[str, Any] = {} + # Output queues for async streaming self.output_queues: dict[str, asyncio.Queue] = {} @@ -215,6 +219,10 @@ def __init__( self._processing_task: asyncio.Task | None = None self._step_executor = None # ThreadPoolExecutor, created in _process_loop + # Memory management: periodic mx.clear_cache() to free Metal buffer pool + self._step_count = 0 + self._clear_cache_interval = 32 + # Statistics self.num_requests_processed = 0 self.total_prompt_tokens = 0 @@ -383,6 +391,8 @@ def _do_abort_request(self, request_id: str) -> None: self.finished_req_ids.add(request_id) self.requests.pop(request_id, None) + self._detokenizer_pool.pop(request_id, None) + # Do NOT write to output_queues here — this may run on the # executor thread where asyncio.Queue is not safe. Mark for # signaling on the event loop thread via _distribute_outputs. @@ -484,8 +494,21 @@ def _process_batch_responses( request.output_tokens.append(response.token) request.num_output_tokens = len(request.output_tokens) - # Decode the new token - new_text = tokenizer.decode([response.token]) + # Decode the new token using streaming detokenizer (UTF-8 safe). + # Skip stop tokens — they are not content. + if response.finish_reason == "stop": + new_text = "" + else: + if request_id not in self._detokenizer_pool: + if hasattr(tokenizer, "detokenizer"): + detok = tokenizer.detokenizer + else: + detok = NaiveStreamingDetokenizer(tokenizer) + detok.reset() + self._detokenizer_pool[request_id] = detok + detok = self._detokenizer_pool[request_id] + detok.add_token(response.token) + new_text = detok.last_segment # Create output output = RequestOutput( @@ -531,15 +554,22 @@ def _process_batch_responses( output.finish_reason = finish_reason finished_ids.add(request_id) - # Use trimmed output if set by stop-string check, else decode. + # Use trimmed output if set by stop-string check, else + # finalize streaming detokenizer for full output. # Use explicit flag instead of string truthiness — empty string # is a valid trimmed result (stop at position 0). 
if stop_trimmed: output.output_text = request.output_text else: - output.output_text = tokenizer.decode(request.output_tokens) + detok = self._detokenizer_pool.get(request_id) + if detok is not None: + detok.finalize() + output.output_text = detok.text + else: + output.output_text = tokenizer.decode(request.output_tokens) request.output_text = output.output_text request.finish_reason = finish_reason + self._detokenizer_pool.pop(request_id, None) self.total_completion_tokens += request.num_output_tokens self.num_requests_processed += 1 @@ -642,6 +672,18 @@ def _step_no_queue(self) -> MLLMSchedulerOutput: if finished_ids: mx.clear_cache() + # Adaptive periodic cache clear: scale inversely with concurrency + # to prevent Metal buffer pool growth during long generations + active_seqs = len(self.running) + min_interval = max(4, self._clear_cache_interval // 4) + effective_interval = max( + min_interval, self._clear_cache_interval // max(1, active_seqs // 8) + ) + + self._step_count += 1 + if self._step_count % effective_interval == 0: + mx.clear_cache() + # Clear finished tracking for next step self.finished_req_ids = set() @@ -972,6 +1014,7 @@ def reset(self) -> None: self.finished_req_ids.clear() self.request_id_to_uid.clear() self.uid_to_request_id.clear() + self._detokenizer_pool.clear() if self.batch_generator is not None: self.batch_generator.close() diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py index add0981..632d57d 100644 --- a/vllm_mlx/models/llm.py +++ b/vllm_mlx/models/llm.py @@ -62,6 +62,7 @@ def __init__( prefill_step_size: int = 2048, kv_bits: int | None = None, kv_group_size: int = 64, + mtp: bool = False, ): """ Initialize the MLX language model. @@ -75,6 +76,7 @@ def __init__( prefill_step_size: Tokens to process per prefill chunk (default: 2048) kv_bits: KV cache quantization bits (None=no quantization, 4 or 8) kv_group_size: Group size for KV cache quantization (default: 64) + mtp: Enable native MTP speculative decoding (model must have MTP head) """ self.model_name = model_name self.tokenizer_name = tokenizer_name or model_name @@ -84,6 +86,8 @@ def __init__( self.prefill_step_size = prefill_step_size self.kv_bits = kv_bits self.kv_group_size = kv_group_size + self._mtp = mtp + self.model = None self.tokenizer = None self.draft_model = None @@ -585,6 +589,9 @@ def stream_generate( # Create sampler with parameters sampler = self._create_sampler(temperature, top_p) + # Count prompt tokens once upfront + num_prompt_tokens = len(self.tokenizer.encode(prompt)) + token_count = 0 accumulated_text = "" # Use IncrementalDecoder with skip_special_tokens=False to preserve @@ -600,6 +607,10 @@ def stream_generate( "prefill_step_size": self.prefill_step_size, } + # Native MTP speculative decoding + if self._mtp: + gen_kwargs["mtp"] = True + # KV cache quantization reduces memory pressure for long prompts if self.kv_bits is not None: gen_kwargs["kv_bits"] = self.kv_bits diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py index 596245d..1bec767 100644 --- a/vllm_mlx/models/mllm.py +++ b/vllm_mlx/models/mllm.py @@ -466,7 +466,7 @@ def save_base64_image(base64_string: str) -> str: import hashlib # Hash the base64 string to check cache - image_hash = hashlib.md5(base64_string[:1000].encode()).hexdigest() + image_hash = hashlib.md5(base64_string.encode()).hexdigest() # Return cached path if available and file still exists if image_hash in _base64_image_cache: @@ -708,6 +708,7 @@ def __init__( self.processor = None self.config = None self._loaded = False + 
self._video_native = False # Initialize MLLM prefix cache manager (with vision embedding caching) self._cache_manager: MLLMPrefixCacheManager | None = None @@ -729,7 +730,12 @@ def load(self) -> None: self.config = load_config(self.model_name) self._loaded = True + self._video_native = hasattr( + self.model.config, "video_token_id" + ) or hasattr(self.model.config, "video_token_index") logger.info(f"MLLM loaded successfully: {self.model_name}") + if self._video_native: + logger.info("Native video pipeline enabled (temporal 3D conv + M-RoPE)") except ImportError: raise ImportError( @@ -740,6 +746,14 @@ def load(self) -> None: logger.error(f"Failed to load MLLM: {e}") raise + def get_language_model(self): + """Extract the underlying language model for mlx_lm TextModel construction.""" + return self.model.language_model + + def get_tokenizer(self): + """Get the text tokenizer (not the multimodal processor).""" + return self.processor.tokenizer + def _prepare_images(self, images: list) -> list[str]: """Process image inputs and return local file paths.""" processed = [] @@ -785,6 +799,259 @@ def _prepare_video( ) return save_frames_to_temp(frames) + def _collect_video_inputs(self, messages: list[dict]) -> dict[int, list]: + """Collect video inputs from messages, keyed by message index. + + Handles both 'video' and 'video_url' content types, including + Pydantic model conversion. + """ + video_inputs: dict[int, list] = {} + for msg_idx, msg in enumerate(messages): + content = msg.get("content", "") + if not isinstance(content, list): + continue + for item in content: + if hasattr(item, "model_dump"): + item = item.model_dump(exclude_none=True) + elif hasattr(item, "dict"): + item = {k: v for k, v in item.dict().items() if v is not None} + + if not isinstance(item, dict): + continue + item_type = item.get("type", "") + if item_type == "video": + video_inputs.setdefault(msg_idx, []).append( + item.get("video", item.get("url", "")) + ) + elif item_type == "video_url": + vid_url = item.get("video_url", {}) + if isinstance(vid_url, str): + video_inputs.setdefault(msg_idx, []).append(vid_url) + elif isinstance(vid_url, dict): + url = vid_url.get("url", "") + if url: + video_inputs.setdefault(msg_idx, []).append(url) + return video_inputs + + def _prepare_native_video_inputs( + self, + messages: list[dict], + video_fps: float = DEFAULT_FPS, + video_max_frames: int = MAX_FRAMES, + tools: list | None = None, + ) -> tuple[str, dict]: + """Preprocess messages into prompt + generation kwargs for native video. + + Mirrors the preprocessing in mlx_vlm.video_generate.main() so that + upstream improvements are easy to adopt. Returns the formatted prompt + text and a dict of kwargs ready to pass to video_generate.generate(). + + Currently Qwen-family-specific (video_token_id / video_token_index). + """ + import mlx.core as mx + + try: + from mlx_vlm.video_generate import process_vision_info + except ImportError: + raise ImportError( + "mlx_vlm.video_generate is required for native video support. 
" + "Upgrade with: pip install --upgrade mlx-vlm" + ) + + # Translate OpenAI API messages into process_vision_info format + native_messages = self._translate_messages_for_native_video( + messages, video_fps, video_max_frames + ) + + # Use HF processor's chat template (handles timestamp interleaving) + template_kwargs: dict = {} + if tools: + template_kwargs["tools"] = tools + + text = self.processor.apply_chat_template( + native_messages, + tokenize=False, + add_generation_prompt=True, + **template_kwargs, + ) + + # Extract vision inputs via mlx-vlm's process_vision_info + image_inputs, video_inputs, fps_info = process_vision_info( + native_messages, return_video_kwargs=True + ) + + # Process through HF processor to get input_ids, pixel_values, grid_thw + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + + input_ids = mx.array(inputs["input_ids"]) + pixel_values = inputs.get( + "pixel_values_videos", inputs.get("pixel_values", None) + ) + if pixel_values is not None: + pixel_values = mx.array(pixel_values) + mask = mx.array(inputs["attention_mask"]) + + gen_kwargs: dict = {} + if inputs.get("video_grid_thw", None) is not None: + gen_kwargs["video_grid_thw"] = mx.array(inputs["video_grid_thw"]) + if inputs.get("image_grid_thw", None) is not None: + gen_kwargs["image_grid_thw"] = mx.array(inputs["image_grid_thw"]) + + gen_kwargs["input_ids"] = input_ids + gen_kwargs["pixel_values"] = pixel_values + gen_kwargs["mask"] = mask + + grid_thw_info = gen_kwargs.get("video_grid_thw") + logger.info( + f"Native video: {input_ids.size} input tokens, " + f"video_grid_thw={grid_thw_info.tolist() if grid_thw_info is not None else None}" + ) + + return text, gen_kwargs + + def _generate_native_video( + self, + messages: list[dict], + max_tokens: int = 256, + temperature: float = 0.7, + video_fps: float = DEFAULT_FPS, + video_max_frames: int = MAX_FRAMES, + tools: list | None = None, + **kwargs, + ) -> MLLMOutput: + """Generate using native video pipeline (Qwen-family models). + + Delegates preprocessing to _prepare_native_video_inputs and generation + to mlx_vlm.video_generate.generate(), keeping our code aligned with + upstream's video pipeline so improvements are easy to adopt. + """ + try: + from mlx_vlm.video_generate import generate + except ImportError: + raise ImportError( + "mlx_vlm.video_generate is required for native video support. " + "Upgrade with: pip install --upgrade mlx-vlm" + ) + + text, gen_kwargs = self._prepare_native_video_inputs( + messages, video_fps, video_max_frames, tools + ) + gen_kwargs["temperature"] = temperature + + result = generate( + self.model, + self.processor, + prompt=text, + max_tokens=max_tokens, + verbose=False, + **gen_kwargs, + ) + + if hasattr(result, "text"): + return MLLMOutput( + text=result.text, + finish_reason="stop", + prompt_tokens=getattr(result, "prompt_tokens", 0), + completion_tokens=getattr(result, "generation_tokens", 0), + ) + return MLLMOutput(text=str(result), finish_reason="stop") + + def _translate_messages_for_native_video( + self, + messages: list[dict], + video_fps: float, + video_max_frames: int, + ) -> list[dict]: + """Translate OpenAI API format messages to process_vision_info format. + + Converts video_url/video types and resolves URLs/base64 to local paths. + Images are preserved as-is (process_vision_info handles them). 
+ """ + translated = [] + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + + if isinstance(content, str): + translated.append({"role": role, "content": content}) + continue + + if not isinstance(content, list): + translated.append({"role": role, "content": str(content)}) + continue + + new_content = [] + for item in content: + if hasattr(item, "model_dump"): + item = item.model_dump(exclude_none=True) + elif hasattr(item, "dict"): + item = {k: v for k, v in item.dict().items() if v is not None} + + if not isinstance(item, dict): + new_content.append({"type": "text", "text": str(item)}) + continue + + item_type = item.get("type", "") + + if item_type == "text": + new_content.append(item) + + elif item_type == "image_url": + img_url = item.get("image_url", {}) + url = ( + img_url.get("url", img_url) + if isinstance(img_url, dict) + else img_url + ) + # Resolve to local path for process_vision_info + local_path = process_image_input(url) + new_content.append({"type": "image", "image": local_path}) + + elif item_type == "image": + img = item.get("image", item.get("url", "")) + local_path = process_image_input(img) + new_content.append({"type": "image", "image": local_path}) + + elif item_type in ("video", "video_url"): + # Extract video path/URL from various formats + if item_type == "video_url": + vid_url = item.get("video_url", {}) + if isinstance(vid_url, str): + video_source = vid_url + elif isinstance(vid_url, dict): + video_source = vid_url.get("url", "") + else: + continue + else: + video_source = item.get("video", item.get("url", "")) + + if not video_source: + continue + + # Resolve to local path + video_path = process_video_input(video_source) + new_content.append( + { + "type": "video", + "video": video_path, + "fps": video_fps, + "max_frames": video_max_frames, + } + ) + + else: + new_content.append(item) + + translated.append({"role": role, "content": new_content}) + + return translated + def generate( self, prompt: str, @@ -1052,12 +1319,47 @@ def chat( # Extract text and images from messages # Build chat_messages for multi-turn support WITH proper image tokens per message all_image_urls = [] # Raw URLs/paths to process later - videos = [] chat_messages = [] # List of properly formatted messages for chat template logger.info(f"MLLM.chat() called with {len(messages)} messages") - for msg in messages: + # Pop params early so they don't leak into mlx_vlm.generate() + video_fps = kwargs.pop("video_fps", DEFAULT_FPS) + video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES) + tools = kwargs.pop("tools", None) + use_cache = kwargs.pop("use_cache", True) + + # Collect video inputs from messages + _msg_video_inputs = self._collect_video_inputs(messages) + + # Use native video pipeline for supported models + if self._video_native and _msg_video_inputs: + return self._generate_native_video( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + video_fps=video_fps, + video_max_frames=video_max_frames, + tools=tools, + **kwargs, + ) + + # Fallback: extract frames and treat as individual images + _msg_video_frame_counts: dict[int, int] = {} + all_video_frames: list[str] = [] + for msg_idx, vid_inputs in _msg_video_inputs.items(): + total_frames = 0 + for vid_input in vid_inputs: + frames = self._prepare_video( + vid_input, fps=video_fps, max_frames=video_max_frames + ) + all_video_frames.extend(frames) + total_frames += len(frames) + logger.info(f"Added {len(frames)} frames from video: {vid_input}") + 
_msg_video_frame_counts[msg_idx] = total_frames + + # Second pass: build chat messages with image counts that include video frames + for msg_idx, msg in enumerate(messages): role = msg.get("role", "user") content = msg.get("content", "") msg_text = "" # Text content for this message @@ -1099,8 +1401,8 @@ def chat( ) msg_image_count += 1 - elif item_type == "video": - videos.append(item.get("video", item.get("url", ""))) + # Add video frame count to image count for this message + msg_image_count += _msg_video_frame_counts.get(msg_idx, 0) # Build properly structured message for Qwen3-VL-MoE # Format: {"role": "...", "content": [{"type": "image"}, ..., {"type": "text", "text": "..."}]} @@ -1132,16 +1434,8 @@ def chat( all_images = [] if all_image_urls: all_images.extend(self._prepare_images(all_image_urls)) - - # Process videos - video_fps = kwargs.pop("video_fps", DEFAULT_FPS) - video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES) - for video_path in videos: - frames = self._prepare_video( - video_path, fps=video_fps, max_frames=video_max_frames - ) - all_images.extend(frames) - logger.info(f"Added {len(frames)} frames from video: {video_path}") + # Append pre-processed video frames + all_images.extend(all_video_frames) # Apply chat template directly - messages are already properly structured logger.info( @@ -1153,8 +1447,7 @@ def chat( f" Chat msg {i}: role={cm['role']}, content={content_preview}..." ) - # Pop tools so they don't leak into mlx_vlm.generate()/stream_generate() - tools = kwargs.pop("tools", None) + # Build template kwargs for tool definitions (tools already popped above) template_extra_kwargs = {} if tools: template_extra_kwargs["tools"] = tools @@ -1192,7 +1485,8 @@ def chat( from mlx_vlm.models import cache as vlm_cache - use_cache = kwargs.pop("use_cache", True) + # use_cache was already popped near the top of chat() — don't re-pop + # with a default of True, as that would overwrite a caller's False. cache_entry = None prefix_match_len = 0 vision_embeddings = None @@ -1429,26 +1723,64 @@ def stream_chat( # Extract text and images from messages # Build chat_messages for multi-turn support WITH proper image tokens per message all_image_urls = [] # Raw URLs/paths to process later - videos = [] chat_messages = [] # List of properly formatted messages for chat template - for msg in messages: + # Pop params early so they don't leak into mlx_vlm.generate() + video_fps = kwargs.pop("video_fps", DEFAULT_FPS) + video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES) + tools = kwargs.pop("tools", None) + use_cache = kwargs.pop("use_cache", True) + + # Collect video inputs from messages + _msg_video_inputs = self._collect_video_inputs(messages) + + # Use native video pipeline for supported models. + # NOTE: Native video yields a single chunk (not incremental streaming) + # because mlx_vlm.video_generate has no streaming API. The event loop + # is NOT blocked at the server level — SimpleEngine wraps this in + # asyncio.to_thread(). True token-level streaming requires upstream + # mlx-vlm support for video stream_generate. 
+ if self._video_native and _msg_video_inputs: + output = self._generate_native_video( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + video_fps=video_fps, + video_max_frames=video_max_frames, + tools=tools, + **kwargs, + ) + yield output + return + + # Fallback: frames as images + _msg_video_frame_counts: dict[int, int] = {} + all_video_frames: list[str] = [] + for msg_idx, vid_inputs in _msg_video_inputs.items(): + total_frames = 0 + for vid_input in vid_inputs: + frames = self._prepare_video( + vid_input, fps=video_fps, max_frames=video_max_frames + ) + all_video_frames.extend(frames) + total_frames += len(frames) + logger.info(f"Added {len(frames)} frames from video: {vid_input}") + _msg_video_frame_counts[msg_idx] = total_frames + + for msg_idx, msg in enumerate(messages): role = msg.get("role", "user") content = msg.get("content", "") - msg_text = "" # Text content for this message - msg_image_count = 0 # Number of images in THIS message + msg_text = "" + msg_image_count = 0 if isinstance(content, str): msg_text = content elif isinstance(content, list): - # OpenAI multimodal format - extract text and count images for THIS message for item in content: if isinstance(item, str): msg_text += item continue - # Convert Pydantic models to dicts, excluding None fields - # to avoid null keys like image_url: null on text parts if hasattr(item, "model_dump"): item = item.model_dump(exclude_none=True) elif hasattr(item, "dict"): @@ -1474,14 +1806,10 @@ def stream_chat( ) msg_image_count += 1 - elif item_type == "video": - videos.append(item.get("video", item.get("url", ""))) + msg_image_count += _msg_video_frame_counts.get(msg_idx, 0) - # Build properly structured message for Qwen3-VL-MoE - # Format: {"role": "...", "content": [{"type": "image"}, ..., {"type": "text", "text": "..."}]} if msg_text or msg_image_count > 0: if role == "user" and msg_image_count > 0: - # User message WITH images - build content array with image tokens FIRST content_list = [] for _ in range(msg_image_count): content_list.append({"type": "image"}) @@ -1490,10 +1818,8 @@ def stream_chat( ) chat_messages.append({"role": role, "content": content_list}) elif role == "assistant": - # Assistant messages - just text content (not array) chat_messages.append({"role": role, "content": msg_text}) else: - # User/system message WITHOUT images - still use content array format chat_messages.append( { "role": role, @@ -1503,29 +1829,17 @@ def stream_chat( } ) - # Process images all_images = [] if all_image_urls: all_images.extend(self._prepare_images(all_image_urls)) + all_images.extend(all_video_frames) - # Process videos - video_fps = kwargs.pop("video_fps", DEFAULT_FPS) - video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES) - for video_path in videos: - frames = self._prepare_video( - video_path, fps=video_fps, max_frames=video_max_frames - ) - all_images.extend(frames) - - # Apply chat template directly - messages are already properly structured - # Pop tools so they don't leak into mlx_vlm.generate()/stream_generate() - tools = kwargs.pop("tools", None) + # Build template kwargs for tool definitions (tools already popped above) template_extra_kwargs = {} if tools: template_extra_kwargs["tools"] = tools try: - # Use get_chat_template directly since messages are already properly formatted formatted_prompt = get_chat_template( self.processor, chat_messages, @@ -1556,7 +1870,6 @@ def stream_chat( prompt_cache = None cache_hit = False - use_cache = kwargs.pop("use_cache", True) if use_cache and 
self._cache_manager is not None and all_images: prompt_cache, cache_hit = self._cache_manager.fetch_cache( diff --git a/vllm_mlx/patches/qwen3_next_mtp.py b/vllm_mlx/patches/qwen3_next_mtp.py index 427c7b4..6d07b36 100644 --- a/vllm_mlx/patches/qwen3_next_mtp.py +++ b/vllm_mlx/patches/qwen3_next_mtp.py @@ -1,26 +1,186 @@ # SPDX-License-Identifier: Apache-2.0 """ -Runtime MTP (Multi-Token Prediction) validation for Qwen3-Next models. +Runtime MTP (Multi-Token Prediction) support for Qwen3-Next models. Qwen3-Next models may include a built-in MTP head that predicts token n+2 -from hidden states + token n+1. When MTP weights have been added to the -quantized MLX model (via scripts/add_mtp_weights.py), mlx_lm.load() -automatically instantiates the MTP module (model.mtp). - -This module provides a lightweight validation function that checks whether -a loaded model has a working MTP head and logs diagnostic information. -The actual MTP logic lives in: - - mlx_lm/models/qwen3_next.py (Model.__call__ with return_hidden, - Model.mtp_forward, Model.make_mtp_cache) +from hidden states + token n+1. MTP weights are added to the quantized +MLX model via scripts/add_mtp_weights.py. + +Since mlx_lm's qwen3_next.py does NOT define MTP module/methods, this +module provides: + - inject_mtp_support(): dynamically creates MTP module, loads weights, + and monkey-patches the model class with return_hidden, mtp_forward, + and make_mtp_cache + - validate_mtp_support(): checks whether a loaded model has working MTP + +The actual MTP scheduling logic lives in: - vllm_mlx/scheduler.py (_install_mtp, _mtp_step, _mtp_next) """ import logging +from pathlib import Path from typing import Any logger = logging.getLogger(__name__) +def inject_mtp_support(model: Any, model_path, config: dict) -> bool: + """Inject MTP module into a loaded Qwen3-Next model. + + mlx_lm's qwen3_next.py does not define MTP layers, so we: + 1. Create MTP module matching the weight structure + 2. Quantize it to match the base model + 3. Load MTP weights from model-mtp.safetensors + 4. Monkey-patch Model with return_hidden, mtp_forward, make_mtp_cache + + Args: + model: A model loaded via mlx_lm (strict=False, MTP weights ignored) + model_path: Path to model directory (contains model-mtp.safetensors) + config: Parsed config.json dict + + Returns: + True if MTP was successfully injected, False otherwise. 
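+
+    Sketch of the intended call pattern (illustrative; how `model`, `inputs`,
+    `next_token_ids`, and `cache` are obtained is assumed and not shown):
+
+        config = json.loads((Path(model_path) / "config.json").read_text())
+        if inject_mtp_support(model, model_path, config):
+            logits, hidden = model(inputs, cache=cache, return_hidden=True)
+            mtp_cache = model.make_mtp_cache()
+            mtp_logits = model.mtp_forward(hidden, next_token_ids, mtp_cache=mtp_cache)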
+ """ + import mlx.core as mx + import mlx.nn as nn + + num_mtp_layers = config.get("num_nextn_predict_layers", 0) + if num_mtp_layers == 0: + logger.info("[MTP inject] num_nextn_predict_layers=0, skipping") + return False + + model_path = Path(model_path) + mtp_file = model_path / "model-mtp.safetensors" + if not mtp_file.exists(): + logger.warning(f"[MTP inject] model-mtp.safetensors not found in {model_path}") + return False + + args = model.args + + # Import model components + from mlx_lm.models.base import create_attention_mask, create_ssm_mask + from mlx_lm.models.cache import KVCache + from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer + + # --- Step 1: Create MTP module --- + logger.info(f"[MTP inject] Creating MTP module ({num_mtp_layers} layers)") + + class _MTPModule(nn.Module): + def __init__(self, args, n_layers): + super().__init__() + self.pre_fc_norm_hidden = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.pre_fc_norm_embedding = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.fc = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False) + # MTP decoder uses full attention (not linear/delta-net) + fa_idx = args.full_attention_interval - 1 + self.layers = [ + Qwen3NextDecoderLayer(args, layer_idx=fa_idx) for _ in range(n_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + mtp = _MTPModule(args, num_mtp_layers) + + # --- Step 2: Quantize MTP module to match base model --- + quant_config = config.get("quantization", {}) + if quant_config: + bits = quant_config.get("bits", 6) + group_size = quant_config.get("group_size", 64) + + def _mtp_quant_pred(path, module): + # Only quantize Linear layers + if not isinstance(module, nn.Linear): + return False + # fc kept as FP (concat projection) + if path == "fc": + return False + # Gate layers at 8-bit (matching base model) + if path.endswith("mlp.gate") or path.endswith("shared_expert_gate"): + return {"group_size": 64, "bits": 8} + return True + + nn.quantize( + mtp, group_size=group_size, bits=bits, class_predicate=_mtp_quant_pred + ) + logger.info(f"[MTP inject] Quantized MTP: {bits}-bit, group_size={group_size}") + + # --- Step 3: Load MTP weights --- + logger.info(f"[MTP inject] Loading weights from {mtp_file.name}") + raw = mx.load(str(mtp_file)) + mtp_weights = { + k.removeprefix("mtp."): v for k, v in raw.items() if k.startswith("mtp.") + } + mtp.load_weights(list(mtp_weights.items()), strict=False) + mx.eval(mtp.parameters()) + logger.info(f"[MTP inject] Loaded {len(mtp_weights)} MTP weight tensors") + + # --- Step 4: Attach MTP and monkey-patch model class --- + model.mtp = mtp + + original_class = model.__class__ + + class _Qwen3NextMTP(original_class): + """Qwen3-Next with MTP support (injected at runtime).""" + + def __call__( + self, + inputs, + cache=None, + return_hidden: bool = False, + ): + inner = self.model + hidden_states = inner.embed_tokens(inputs) + if cache is None: + cache = [None] * len(inner.layers) + fa_mask = create_attention_mask(hidden_states, cache[inner.fa_idx]) + ssm_mask = create_ssm_mask(hidden_states, cache[inner.ssm_idx]) + for layer, c in zip(inner.layers, cache): + mask = ssm_mask if layer.is_linear else fa_mask + hidden_states = layer(hidden_states, mask=mask, cache=c) + normed = inner.norm(hidden_states) + if self.args.tie_word_embeddings: + out = inner.embed_tokens.as_linear(normed) + else: + out = self.lm_head(normed) + if return_hidden: + return out, hidden_states # pre-norm hidden states + return out + + def 
mtp_forward( + self, + hidden_states, + next_token_ids, + cache=None, + mtp_cache=None, + ): + """Run MTP head: predict token n+2 from hidden states + token n+1.""" + input_embeds = self.model.embed_tokens(next_token_ids) + h = self.mtp.pre_fc_norm_hidden(hidden_states) + e = self.mtp.pre_fc_norm_embedding(input_embeds) + x = self.mtp.fc(mx.concatenate([h, e], axis=-1)) + layer = self.mtp.layers[0] + c = mtp_cache[0] if mtp_cache else None + mask = create_attention_mask(x, c) + x = layer(x, mask=mask, cache=c) + x = self.mtp.norm(x) + if self.args.tie_word_embeddings: + return self.model.embed_tokens.as_linear(x) + return self.lm_head(x) + + def make_mtp_cache(self): + """Create KV cache for MTP layers.""" + if self.mtp is None: + return None + return [KVCache() for _ in self.mtp.layers] + + model.__class__ = _Qwen3NextMTP + logger.info("[MTP inject] Model class patched with MTP support") + return True + + def validate_mtp_support(model: Any) -> bool: """Validate that a loaded model has working MTP support. @@ -41,7 +201,12 @@ def validate_mtp_support(model: Any) -> bool: mtp = getattr(model, "mtp", None) if mtp is None: num_mtp = 0 + # Try model.args (Qwen3-Next) and model.language_model.args (Qwen3.5) args = getattr(model, "args", None) + if args is None: + lm = getattr(model, "language_model", None) + if lm is not None: + args = getattr(lm, "args", None) if args is not None: num_mtp = getattr(args, "num_nextn_predict_layers", 0) if num_mtp > 0: @@ -89,7 +254,7 @@ def validate_mtp_support(model: Any) -> bool: num_layers = getattr(args, "num_nextn_predict_layers", 0) if args else 0 logger.info( "[MTP] Model has working MTP support: " - "%d MTP layer(s), %d predictor decoder layer(s)", + "%s MTP layer(s), %d predictor decoder layer(s)", num_layers, len(mtp_layers), ) diff --git a/vllm_mlx/plugin.py b/vllm_mlx/plugin.py index b3beca4..04d152c 100644 --- a/vllm_mlx/plugin.py +++ b/vllm_mlx/plugin.py @@ -67,7 +67,7 @@ def mlx_platform_plugin() -> str | None: pass logger.info("MLX platform is available on Apple Silicon") - return "vllm_mlx.platform.MLXPlatform" + return "vllm_mlx.vllm_platform.MLXPlatform" def is_mlx_available() -> bool: diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py index 7de1cea..1c3fea2 100644 --- a/vllm_mlx/scheduler.py +++ b/vllm_mlx/scheduler.py @@ -20,6 +20,7 @@ import mlx.core as mx from mlx_lm.generate import BatchGenerator from mlx_lm.sample_utils import make_sampler +from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig from .paged_cache import PagedCacheManager @@ -399,6 +400,7 @@ def _chunked_next(self=batch_gen): # noqa: C901 caches, samplers, logits_processors, + _prompt_checkpoints, ) = zip(*batch_prompts) lengths = [len(p) for p in inputs_raw] max_length = max(lengths) @@ -410,7 +412,9 @@ def _chunked_next(self=batch_gen): # noqa: C901 if not is_cached: padded = _left_pad_prompts(inputs_raw, max_length=max_length) - prompt_cache = _make_cache(self.model, padding) + prompt_cache = _make_cache( + self.model, padding, self.max_kv_size + ) else: last_inputs = mx.array([p[-1:] for p in inputs_raw]) padded = _right_pad_prompts(inputs_raw, max_length=max_length) @@ -783,7 +787,11 @@ def _mtp_step( # (both P and D) for all cache types, then # re-advance with just P for a consistent state. 
for c in prompt_cache: - if hasattr(c, "is_trimmable") and c.is_trimmable(): + if ( + hasattr(c, "is_trimmable") + and c.is_trimmable() + and hasattr(c, "trim") + ): c.trim(2) for _ci, _snap in _rnn_snapshots.items(): prompt_cache[_ci].state = _snap @@ -813,7 +821,11 @@ def _mtp_step( else: # Pure attention model: simple trim(1) is enough. for c in prompt_cache: - if hasattr(c, "is_trimmable") and c.is_trimmable(): + if ( + hasattr(c, "is_trimmable") + and c.is_trimmable() + and hasattr(c, "trim") + ): c.trim(1) if verify_hidden is not None: _skip_state[0] = { @@ -999,6 +1011,9 @@ def __init__( # Detect if tokenizer is a processor (MLLM) and get the actual tokenizer self._actual_tokenizer = self._get_actual_tokenizer(tokenizer) + # Per-request streaming detokenizers for UTF-8-safe incremental decode + self._detokenizer_pool: dict[str, Any] = {} + # Request management - following vLLM's design self.waiting: deque[Request] = deque() # Waiting queue (FCFS) self.running: dict[str, Request] = {} # Running requests by ID @@ -1099,6 +1114,21 @@ def _decode_tokens(self, token_ids: list[int]) -> str: """ return self._actual_tokenizer.decode(token_ids) + def _get_detokenizer(self, request_id: str) -> Any: + """Get or create a streaming detokenizer for a request.""" + if request_id not in self._detokenizer_pool: + if hasattr(self.tokenizer, "detokenizer"): + detok = self.tokenizer.detokenizer + else: + detok = NaiveStreamingDetokenizer(self._actual_tokenizer) + detok.reset() + self._detokenizer_pool[request_id] = detok + return self._detokenizer_pool[request_id] + + def _cleanup_detokenizer(self, request_id: str) -> None: + """Remove the streaming detokenizer for a finished request.""" + self._detokenizer_pool.pop(request_id, None) + def _get_stop_tokens(self) -> set[int]: """Get stop token IDs from tokenizer or processor.""" stop_tokens = set() @@ -1727,6 +1757,7 @@ def _do_abort_request(self, request_id: str) -> bool: if request is not None: request.set_finished(RequestStatus.FINISHED_ABORTED) self.finished_req_ids.add(request_id) + self._cleanup_detokenizer(request_id) # Flush Metal encoders after removing arrays from batch mx.clear_cache() @@ -1938,6 +1969,7 @@ def _process_batch_responses( else: output.output_text = self._decode_tokens(request.output_token_ids) request.output_text = output.output_text + self._cleanup_detokenizer(request_id) # Extract cache for future reuse (critical for agentic multi-turn) if hasattr(response, "prompt_cache"): @@ -2172,6 +2204,7 @@ def _recover_from_generation_error(self) -> set[str]: aborted_ids.add(request_id) self.finished_req_ids.add(request_id) self.running.clear() + self._detokenizer_pool.clear() # Clear UID mappings (batch generator is gone) self.request_id_to_uid.clear() @@ -2459,6 +2492,7 @@ def reset(self) -> None: self.finished_req_ids.clear() self.request_id_to_uid.clear() self.uid_to_request_id.clear() + self._detokenizer_pool.clear() self._close_batch_generator() self._current_sampler_params = None diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 5c58ec2..3722490 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -104,6 +104,8 @@ ) from .api.utils import ( SPECIAL_TOKENS_PATTERN, + StreamingThinkRouter, + StreamingToolCallFilter, clean_output_text, extract_json_from_response, extract_multimodal_content, @@ -127,6 +129,9 @@ _engine: BaseEngine | None = None _model_name: str | None = None _model_alias: str | None = None # Short alias used to start the model (if any) +_model_path: str | None = ( + None # Actual model path (for cache 
dir, not affected by --served-model-name) +) _default_max_tokens: int = 32768 _default_timeout: float = 300.0 # Default request timeout in seconds (5 minutes) _default_temperature: float | None = None # Set via --default-temperature @@ -300,11 +305,14 @@ def _save_prefix_cache_to_disk() -> None: def _get_cache_dir() -> str: - """Get cache persistence directory based on model name.""" - # Use global _model_name which is always a string, set during load_model() - model_name = _model_name if _model_name else "default" + """Get cache persistence directory based on actual model path.""" + # Use _model_path (actual model path) not _model_name (which may be overridden + # by --served-model-name). This ensures cache is shared regardless of served name. + model_name = ( + _model_path if _model_path else (_model_name if _model_name else "default") + ) logger.info( - f"[_get_cache_dir] _model_name={_model_name!r} type={type(_model_name)}" + f"[_get_cache_dir] _model_path={_model_path!r} type={type(_model_path)}" ) # Sanitize model name for filesystem safe_name = str(model_name).replace("/", "--").replace("\\", "--") @@ -478,6 +486,24 @@ def get_engine() -> BaseEngine: return _engine +def _validate_model_name(request_model: str) -> None: + """Validate that the request model name matches the served model or its alias.""" + if not _model_name: + return + # Accept the canonical name, the alias, or the model path + accepted = {_model_name} + if _model_alias: + accepted.add(_model_alias) + if _model_path: + accepted.add(_model_path) + if request_model not in accepted: + raise HTTPException( + status_code=404, + detail=f"The model `{request_model}` does not exist. " + f"Available model: `{_model_name}`", + ) + + def _parse_tool_calls_with_parser( output_text: str, request: ChatCompletionRequest | None = None ) -> tuple[str, list | None]: @@ -688,6 +714,12 @@ def load_model( cloud_threshold: int = 20000, cloud_api_base: str | None = None, cloud_api_key: str | None = None, + served_model_name: str | None = None, + mtp: bool = False, + specprefill_enabled: bool = False, + specprefill_threshold: int = 8192, + specprefill_keep_pct: float = 0.3, + specprefill_draft_model: str = None, ): """ Load a model (auto-detects MLLM vs LLM). 
@@ -706,16 +738,23 @@ def load_model( prefill_step_size: Tokens to process per prefill chunk (default: 2048) kv_bits: KV cache quantization bits (None=no quantization, 4 or 8) kv_group_size: Group size for KV cache quantization (default: 64) + mtp: Enable native MTP speculative decoding (SimpleEngine only) + specprefill_enabled: Enable SpecPrefill (SimpleEngine only) + specprefill_threshold: Minimum suffix tokens to trigger SpecPrefill (default: 8192) + specprefill_keep_pct: Fraction of tokens to keep (default: 0.3) + specprefill_draft_model: Path to small draft model for SpecPrefill scoring """ global \ _engine, \ _model_name, \ + _model_path, \ _default_max_tokens, \ _tool_parser_instance, \ _cloud_router _default_max_tokens = max_tokens - _model_name = model_name + _model_path = model_name + _model_name = served_model_name or model_name # Reset tool parser instance when model is reloaded (tokenizer may change) _tool_parser_instance = None @@ -781,6 +820,11 @@ def load_model( prefill_step_size=prefill_step_size, kv_bits=kv_bits, kv_group_size=kv_group_size, + mtp=mtp, + specprefill_enabled=specprefill_enabled, + specprefill_threshold=specprefill_threshold, + specprefill_keep_pct=specprefill_keep_pct, + specprefill_draft_model=specprefill_draft_model, ) # Start SimpleEngine synchronously (no background loop) # Use new_event_loop() for Python 3.10+ compatibility (get_event_loop() is deprecated) @@ -1580,6 +1624,7 @@ async def _wait_disconnect(): ) async def create_completion(request: CompletionRequest, raw_request: Request): """Create a text completion.""" + _validate_model_name(request.model) engine = get_engine() # Validate model name matches loaded model @@ -1711,6 +1756,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re } ``` """ + _validate_model_name(request.model) engine = get_engine() # Validate messages is non-empty @@ -1826,6 +1872,23 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re ) has_media = bool(images or videos) + if engine.is_mllm and not has_media: + # MLLM extracts media from messages directly, so images/videos are + # always empty. Check message content for video/image types instead. + for msg in request.messages: + content = msg.content if hasattr(msg, "content") else msg.get("content", "") + if isinstance(content, list): + for item in content: + item_type = ( + item.type + if hasattr(item, "type") + else (item.get("type", "") if isinstance(item, dict) else "") + ) + if item_type in ("image_url", "image", "video", "video_url"): + has_media = True + break + if has_media: + break # Normalize "developer" role to "system" for models that don't support it. 
# OpenAI's API uses "developer" as a newer alias for "system", but most @@ -1908,6 +1971,12 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re if request.video_max_frames: chat_kwargs["video_max_frames"] = request.video_max_frames + # SpecPrefill: per-request overrides + if request.specprefill is not None: + chat_kwargs["specprefill"] = request.specprefill + if request.specprefill_keep_pct is not None: + chat_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct + # Add tools if provided if request.tools: chat_kwargs["tools"] = convert_tools_for_template(request.tools) @@ -2249,6 +2318,8 @@ async def create_anthropic_message( body = await request.json() anthropic_request = AnthropicRequest(**body) + _validate_model_name(anthropic_request.model) + # --- Detailed request logging --- n_msgs = len(anthropic_request.messages) total_chars = 0 @@ -2453,6 +2524,59 @@ async def count_anthropic_tokens(request: Request): return {"input_tokens": total_tokens} +def _emit_content_pieces( + pieces: list[tuple[str, str]], + current_block_type: str | None, + block_index: int, +) -> tuple[list[str], str | None, int]: + """Emit Anthropic SSE events for content pieces from the think router. + + Handles block type transitions (thinking <-> text), emitting + content_block_start/stop/delta events as needed. + + Args: + pieces: List of (block_type, text) from StreamingThinkRouter + current_block_type: Current open block type, or None + block_index: Current block index + + Returns: + Tuple of (events, updated_block_type, updated_block_index) + """ + events = [] + for block_type, text in pieces: + if block_type != current_block_type: + # Close previous block if open + if current_block_type is not None: + events.append( + f"event: content_block_stop\ndata: " + f"{json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n" + ) + block_index += 1 + # Start new block + current_block_type = block_type + content_block = ( + {"type": block_type, "text": ""} + if block_type == "text" + else {"type": block_type, "thinking": ""} + ) + events.append( + f"event: content_block_start\ndata: " + f"{json.dumps({'type': 'content_block_start', 'index': block_index, 'content_block': content_block})}\n\n" + ) + # Emit delta + delta_key = "thinking" if block_type == "thinking" else "text" + delta_type = "thinking_delta" if block_type == "thinking" else "text_delta" + delta_event = { + "type": "content_block_delta", + "index": block_index, + "delta": {"type": delta_type, delta_key: text}, + } + events.append( + f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n" + ) + return events, current_block_type, block_index + + async def _stream_anthropic_messages( engine: BaseEngine, openai_request: ChatCompletionRequest, @@ -2506,19 +2630,30 @@ async def _stream_anthropic_messages( } yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n" - # Emit content_block_start for text - content_block_start = { - "type": "content_block_start", - "index": 0, - "content_block": {"type": "text", "text": ""}, - } - yield f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n" - - # Stream content deltas — use reasoning parser to strip think tags + # Stream pipeline: raw text → tool call filter → think router → emit + # - Tool call filter strips tool call markup (emitted as structured blocks later) + # - Think router separates content into Anthropic thinking blocks accumulated_text = "" accumulated_raw = "" + tool_filter = StreamingToolCallFilter() + # Detect if 
the model's chat template injects into the + # generation prompt. If so, the model starts in thinking mode and + # the opening tag never appears in the output stream. + _tokenizer = engine.tokenizer if hasattr(engine, "tokenizer") else None + _chat_template = "" + if _tokenizer and hasattr(_tokenizer, "chat_template"): + _chat_template = _tokenizer.chat_template or "" + _starts_thinking = ( + "" in _chat_template and "add_generation_prompt" in _chat_template + ) + think_router = StreamingThinkRouter(start_in_thinking=_starts_thinking) + prompt_tokens = 0 completion_tokens = 0 + # Track which content blocks we've started + current_block_type = None # "thinking" or "text" + block_index = 0 + # Reset reasoning parser state for this stream if _reasoning_parser: _reasoning_parser.reset_state() @@ -2527,10 +2662,14 @@ async def _stream_anthropic_messages( delta_text = output.new_text # Track token counts + if hasattr(output, "prompt_tokens") and output.prompt_tokens: + prompt_tokens = output.prompt_tokens if hasattr(output, "completion_tokens") and output.completion_tokens: completion_tokens = output.completion_tokens if delta_text: + # Accumulate raw text BEFORE special token cleaning for tool parsing + accumulated_text += delta_text content = None # Use reasoning parser to separate reasoning from content @@ -2552,13 +2691,39 @@ async def _stream_anthropic_messages( content = strip_special_tokens(content) if content: - accumulated_text += content - delta_event = { - "type": "content_block_delta", - "index": 0, - "delta": {"type": "text_delta", "text": content}, - } - yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n" + # Stage 1: strip tool call markup + filtered = tool_filter.process(content) + if not filtered: + continue + # Stage 2: route thinking vs text + pieces = think_router.process(filtered) + events, current_block_type, block_index = _emit_content_pieces( + pieces, current_block_type, block_index + ) + for event in events: + yield event + + # Flush remaining from both filters + remaining = tool_filter.flush() + if remaining: + events, current_block_type, block_index = _emit_content_pieces( + think_router.process(remaining), current_block_type, block_index + ) + for event in events: + yield event + + flush_pieces = think_router.flush() + if flush_pieces: + events, current_block_type, block_index = _emit_content_pieces( + flush_pieces, current_block_type, block_index + ) + for event in events: + yield event + + # Close final content block + if current_block_type is not None: + yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n" + block_index += 1 # Handle reasoning parser finalization (e.g. 
no-tag correction) if _reasoning_parser and accumulated_raw: @@ -2584,13 +2749,10 @@ async def _stream_anthropic_messages( # Check for tool calls in accumulated text _, tool_calls = _parse_tool_calls_with_parser(accumulated_text, openai_request) - # Emit content_block_stop for text block - yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n" - # If there are tool calls, emit tool_use blocks if tool_calls: for i, tc in enumerate(tool_calls): - tool_index = i + 1 + tool_index = block_index + i try: tool_input = json.loads(tc.function.arguments) except (json.JSONDecodeError, AttributeError): @@ -2628,7 +2790,7 @@ async def _stream_anthropic_messages( message_delta = { "type": "message_delta", "delta": {"stop_reason": stop_reason, "stop_sequence": None}, - "usage": {"output_tokens": completion_tokens}, + "usage": {"input_tokens": prompt_tokens, "output_tokens": completion_tokens}, } yield f"event: message_delta\ndata: {json.dumps(message_delta)}\n\n" @@ -2636,7 +2798,7 @@ async def _stream_anthropic_messages( elapsed = time.perf_counter() - start_time tokens_per_sec = completion_tokens / elapsed if elapsed > 0 else 0 logger.info( - f"Anthropic messages (stream): {completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" + f"Anthropic messages (stream): prompt={prompt_tokens} + completion={completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" ) # Emit message_stop diff --git a/vllm_mlx/specprefill.py b/vllm_mlx/specprefill.py new file mode 100644 index 0000000..5ea985f --- /dev/null +++ b/vllm_mlx/specprefill.py @@ -0,0 +1,742 @@ +# SPDX-License-Identifier: Apache-2.0 +"""SpecPrefill: Attention-based sparse prefill for MLX. + +Full pipeline for reducing TTFT on long prompts: + Step 1 (score_tokens): Use a small draft model to identify important tokens + Step 2 (sparse_prefill): Prefill target model with only selected tokens, + preserving original positional encoding via manual RoPE + +Usage: + from specprefill import score_tokens, select_chunks, sparse_prefill, cleanup_rope + + # 1. Score with draft model + importance = score_tokens(draft_model, tokens) + + # 2. Select important token chunks + selected = select_chunks(importance, keep_pct=0.3) + + # 3. Sparse prefill on target model + target_cache = make_prompt_cache(target_model) + logits = sparse_prefill(target_model, tokens, selected, target_cache) + + # 4. Generate normally using target_cache... + + # 5. Cleanup + cleanup_rope(target_model) + +Design notes: + - RoPE is relative: Q_m @ K_p^T depends only on (m - p). Selected keys stored + contiguously in the cache buffer with correct RoPE angles produce correct + attention during decode. + - After sparse prefill of N tokens from a total prompt of M, cache.offset = N + but decode RoPE needs position M. The _OffsetAdjustedRoPE adds (M - N) to + each RoPE offset call, so decode position = N + i + (M - N) = M + i. + - GatedDeltaNet (linear attention) layers process sparse tokens through their + conv/SSM state normally. This is lossy but acceptable per the SpecPrefill + paper — attention layers are the primary long-range mechanism. 
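+
+Worked example of the offset bookkeeping (numbers are illustrative only):
+    M = 8000 prompt tokens, keep_pct = 0.3  ->  N ~= 2400 selected tokens
+    after sparse_prefill:  cache.offset = N = 2400
+    _OffsetAdjustedRoPE adjustment = M - N = 5600
+    decode step i:  RoPE offset = (2400 + i) + 5600 = 8000 + i = M + i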
+ +Reference: arxiv.org/abs/2502.02789 (SpecPrefill: Speculative Prefilling) +""" + +import math + +import mlx.core as mx + +from mlx_lm.models.cache import make_prompt_cache +from mlx_lm.sample_utils import make_sampler + +# =========================================================================== +# Step 1: Token importance scoring (draft model) +# =========================================================================== + + +class _AttentionCapture: + """Wrapper that captures post-RoPE query vectors and delegates to original. + + Installed on attention layers during lookahead decode to capture query + vectors for importance scoring. Supports multiple architectures via + query_extractor callback. + """ + + def __init__(self, original, buf_idx, query_buffer, query_extractor=None): + self._original = original + self._buf_idx = buf_idx + self._query_buffer = query_buffer + self._query_extractor = query_extractor or _qwen35_extract_queries + + def __call__(self, x, mask=None, cache=None): + queries = self._query_extractor(self._original, x, cache) + self._query_buffer[self._buf_idx].append(queries) + return self._original(x, mask=mask, cache=cache) + + def __getattr__(self, name): + return getattr(self._original, name) + + +def _qwen35_extract_queries(attn, x, cache=None): + """Extract post-RoPE queries from Qwen3.5 attention (gate split + q_norm). + + Qwen3.5 q_proj output is 2x wider: [queries, gate]. We split, normalize, + then apply RoPE. + """ + B, L, D = x.shape + q_out = attn.q_proj(x) + queries, _gate = mx.split( + q_out.reshape(B, L, attn.num_attention_heads, -1), 2, axis=-1 + ) + queries = attn.q_norm(queries).transpose(0, 2, 1, 3) + if cache is not None: + queries = attn.rope(queries, offset=cache.offset) + else: + queries = attn.rope(queries) + return queries + + +def _llama_extract_queries(attn, x, cache=None): + """Extract post-RoPE queries from standard transformer attention. + + Standard architecture: q_proj → reshape → RoPE. No gate, no q_norm. + Works for Llama 3.x, Mistral, Gemma, GPT-OSS, and other GQA models. + """ + B, L, D = x.shape + n_heads = getattr( + attn, + "num_attention_heads", + getattr(attn, "n_heads", getattr(attn, "num_heads", None)), + ) + queries = attn.q_proj(x) + queries = queries.reshape(B, L, n_heads, -1).transpose(0, 2, 1, 3) + if cache is not None: + queries = attn.rope(queries, offset=cache.offset) + else: + queries = attn.rope(queries) + return queries + + +def _nemotron_h_extract_queries(attn, x, cache=None): + """Extract queries from Nemotron-H attention (no RoPE, no gate, no q_norm). + + Nemotron-H attention layers have NO positional encoding — RoPE is absent. + Positional modeling comes from Mamba2 layers. Attention is content-based only. + """ + B, L, D = x.shape + queries = attn.q_proj(x).reshape(B, L, attn.num_heads, -1).transpose(0, 2, 1, 3) + # No RoPE to apply — queries are used as-is for content-based scoring + return queries + + +def _patch_attention_for_capture(model, query_buffer, query_extractor=None): + """Replace attention modules on full-attention layers with capture wrappers. + + Supports both `self_attn` (Qwen3.5/Llama/GPT-OSS) and `mixer` + (Nemotron-H block_type="*") attribute conventions. + + Returns (originals, attn_layer_indices) for cleanup. 
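+
+    Typical pairing (mirrors the try/finally in score_tokens below):
+        originals, attn_indices = _patch_attention_for_capture(model, query_buffer)
+        try:
+            ...  # lookahead decode; queries accumulate in query_buffer
+        finally:
+            _unpatch_attention_capture(model, originals)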
+ """ + originals = [] + attn_indices = [] + for layer_idx, layer in _find_attention_layers(model): + buf_idx = len(attn_indices) + attn_indices.append(layer_idx) + orig = _get_attn_module(layer) + _set_attn_module( + layer, + _AttentionCapture( + orig, buf_idx, query_buffer, query_extractor=query_extractor + ), + ) + originals.append((layer_idx, orig)) + return originals, attn_indices + + +def _unpatch_attention_capture(model, originals): + """Restore original attention modules after capture.""" + for layer_idx, orig in originals: + _set_attn_module(model.layers[layer_idx], orig) + + +def _prefill_draft(model, tokens, cache, step_size=2048): + """Prefill prompt tokens into cache. Returns logits from last token.""" + prompt = mx.array(tokens) if not isinstance(tokens, mx.array) else tokens + n = len(tokens) + processed = 0 + while n - processed > 1: + chunk = min(step_size, n - processed - 1) + model(prompt[processed : processed + chunk][None], cache=cache) + mx.eval([c.state for c in cache]) + processed += chunk + mx.clear_cache() + logits = model(prompt[processed:][None], cache=cache) + mx.eval(logits) + return logits + + +def _lookahead_decode(model, first_logits, cache, n_steps, temp=0.6, top_p=0.95): + """Run n_steps autoregressive decode, returning generated token ids. + + Query vectors are captured by the monkey-patched attention layers. + """ + sampler = make_sampler(temp=temp, top_p=top_p) + y = sampler(first_logits[:, -1, :]) + mx.eval(y) + generated = [y.item()] + for _ in range(n_steps): + logits = model(y.reshape(1, -1), cache=cache) + y = sampler(logits[:, -1, :]) + mx.eval(y) + generated.append(y.item()) + return generated + + +def _avg_pool1d(x, kernel_size): + """1D average pooling along last axis via prefix-sum. + + Args: + x: (..., M) input + kernel_size: window size (odd for centered) + + Returns: + (..., M) pooled (same size, zero-padded at edges) + """ + if kernel_size <= 1: + return x + pad = kernel_size // 2 + padded = mx.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad, pad)]) + zeros = mx.zeros(x.shape[:-1] + (1,), dtype=x.dtype) + prefix = mx.concatenate([zeros, mx.cumsum(padded, axis=-1)], axis=-1) + return (prefix[..., kernel_size:] - prefix[..., :-kernel_size]) / kernel_size + + +def _compute_importance( + query_buffer, attn_caches, n_prompt, n_attn_heads, n_kv_heads, pool_kernel=13 +): + """Compute per-token importance from captured queries and cached keys. + + Aggregation (SpecPrefill paper): + 1. softmax(Q @ K^T / sqrt(d)) per head, per layer, per lookahead token + 2. avg_pool1d smoothing + 3. max across (layers × heads) + 4. mean across lookahead tokens + + Returns: (n_prompt,) importance scores. + """ + heads_per_group = n_attn_heads // n_kv_heads + all_scores = [] + + for layer_i, captures in enumerate(query_buffer): + if not captures: + continue + cache = attn_caches[layer_i] + prompt_keys = cache.keys[..., :n_prompt, :] + # Skip layers with windowed/rotating caches that don't span + # the full prompt (e.g., GPT-OSS sliding_attention with 128-token window). + # These lack global context and would produce mismatched score shapes. 
+ if prompt_keys.shape[-2] < n_prompt: + continue + head_dim = prompt_keys.shape[-1] + q_stack = mx.concatenate(captures, axis=2) + if heads_per_group > 1: + expanded_keys = mx.repeat(prompt_keys, heads_per_group, axis=1) + else: + expanded_keys = prompt_keys + scale = head_dim**-0.5 + scores = (q_stack @ expanded_keys.transpose(0, 1, 3, 2)) * scale + weights = mx.softmax(scores.astype(mx.float32), axis=-1) + all_scores.append(weights.squeeze(0)) + + if not all_scores: + raise RuntimeError("No attention scores captured — check model/patching") + + combined = mx.concatenate(all_scores, axis=0) + if pool_kernel and pool_kernel > 1: + combined = _avg_pool1d(combined, pool_kernel) + max_scores = mx.max(combined, axis=0) + importance = mx.mean(max_scores, axis=0) + return importance + + +def score_tokens( + model, + tokens, + n_lookahead=8, + pool_kernel=13, + temp=0.6, + top_p=0.95, + prefill_step_size=2048, + query_extractor=None, +): + """Score token importance using attention-based analysis on a draft model. + + Runs the full scoring pipeline: + 1. Prefill the draft model with all tokens + 2. N lookahead decode steps, capturing query vectors from attention layers + 3. Compute importance: Q_lookahead @ K_prompt^T, aggregated across heads/layers + + The draft model's cache is created internally and discarded after scoring. + + Args: + model: Draft model (small, fast — e.g. 4B) + tokens: list or mx.array of token IDs + n_lookahead: decode steps for query capture (default 8) + pool_kernel: smoothing kernel for avg_pool1d (default 13, 0=disable) + temp: sampling temperature for lookahead (default 0.6) + top_p: top-p for lookahead (default 0.95) + prefill_step_size: chunk size for draft prefill (default 2048) + query_extractor: function(attn, x, cache) → queries tensor. + Default: _qwen35_extract_queries. Use _llama_extract_queries for + standard Llama/Mistral/Gemma models. 
+ + Returns: + importance: (M,) mx.array of per-token importance scores + """ + if isinstance(tokens, mx.array): + tokens = tokens.tolist() + n_prompt = len(tokens) + + # Model topology — detect attribute names across architectures + attn_layers = _find_attention_layers(model) + n_attn_layers = len(attn_layers) + attn_obj = _get_attn_module(attn_layers[0][1]) + # Attribute names vary: num_attention_heads (Qwen3.5), n_heads (Llama), + # num_heads (Nemotron-H) + n_attn_heads = getattr( + attn_obj, + "num_attention_heads", + getattr(attn_obj, "n_heads", getattr(attn_obj, "num_heads", None)), + ) + n_kv_heads = getattr( + attn_obj, "num_key_value_heads", getattr(attn_obj, "n_kv_heads", None) + ) + + # Auto-detect query extractor if not specified + if query_extractor is None: + if hasattr(attn_obj, "q_norm"): + query_extractor = _qwen35_extract_queries + elif not hasattr(attn_obj, "rope"): + # No RoPE attribute → Nemotron-H style (content-based attention) + query_extractor = _nemotron_h_extract_queries + else: + query_extractor = _llama_extract_queries + + # Phase 1: Prefill + cache = make_prompt_cache(model) + logits = _prefill_draft(model, tokens, cache, step_size=prefill_step_size) + + # Phase 2: Lookahead decode with query capture + query_buffer = [[] for _ in range(n_attn_layers)] + patches, attn_indices = _patch_attention_for_capture( + model, query_buffer, query_extractor=query_extractor + ) + try: + _lookahead_decode(model, logits, cache, n_lookahead, temp=temp, top_p=top_p) + mx.eval(query_buffer) + finally: + _unpatch_attention_capture(model, patches) + + # Phase 3: Compute importance + # Map layer indices to cache indices (identity for standard models, + # compacted for Nemotron-H where only M/* layers have cache entries) + layer_to_cache = _build_layer_to_cache_map(model) + attn_caches = [cache[layer_to_cache[i]] for i in attn_indices] + importance = _compute_importance( + query_buffer, + attn_caches, + n_prompt, + n_attn_heads, + n_kv_heads, + pool_kernel=pool_kernel if pool_kernel > 0 else None, + ) + mx.eval(importance) + + # Draft cache is no longer needed — let GC reclaim it + del cache, logits, query_buffer, attn_caches + mx.clear_cache() + + return importance + + +def select_chunks(importance, keep_pct=0.3, chunk_size=32): + """Select top-k% token chunks by average importance. 
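+
+    For example (illustrative numbers): with M = 2048 tokens and
+    chunk_size = 32 there are 64 chunks; keep_pct = 0.3 keeps
+    ceil(64 * 0.3) = 20 chunks, i.e. 640 token indices, returned in
+    ascending order.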
+ + Args: + importance: (M,) per-token importance scores + keep_pct: fraction of chunks to keep (default 0.3) + chunk_size: tokens per chunk (default 32) + + Returns: + sorted mx.array of kept token indices + """ + M = importance.shape[0] + if keep_pct >= 1.0: + return mx.arange(M) + + n_chunks = math.ceil(M / chunk_size) + keep_n = max(1, math.ceil(n_chunks * keep_pct)) + + chunk_scores = [] + for i in range(n_chunks): + start = i * chunk_size + end = min(start + chunk_size, M) + chunk_scores.append(mx.mean(importance[start:end]).item()) + + top_chunks = sorted(range(n_chunks), key=lambda i: chunk_scores[i], reverse=True)[ + :keep_n + ] + top_chunks.sort() + + indices = [] + for ci in top_chunks: + start = ci * chunk_size + end = min(start + chunk_size, M) + indices.extend(range(start, end)) + + return mx.array(indices) + + +# =========================================================================== +# Step 2: Sparse prefill with non-contiguous position IDs (target model) +# =========================================================================== + + +# --------------------------------------------------------------------------- +# Manual RoPE at arbitrary positions +# --------------------------------------------------------------------------- + + +def manual_rope(x, positions, dims, base=10000.0, scale=1.0): + """Apply RoPE at arbitrary (non-contiguous) positions. + + Uses non-traditional (interleaved) layout matching Qwen3.5: + rotates first `dims` dimensions as pairs [0,half), [half,dims), + passes through [dims:] unchanged. + + Args: + x: (B, n_heads, L, head_dim) input tensor + positions: (L,) position indices (can be non-contiguous) + dims: number of dimensions to rotate (head_dim * partial_rotary_factor) + base: RoPE base frequency (default 10000.0) + scale: position scale divisor (default 1.0, higher = compressed positions) + + Returns: + (B, n_heads, L, head_dim) with RoPE applied + """ + half = dims // 2 + inv_freq = 1.0 / (base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)) + scaled_pos = positions.astype(mx.float32) / scale + angles = scaled_pos[:, None] * inv_freq[None, :] # (L, half) + cos_a = mx.cos(angles)[None, None, :, :] # (1, 1, L, half) + sin_a = mx.sin(angles)[None, None, :, :] + x_rot, x_pass = x[..., :dims], x[..., dims:] + x1, x2 = x_rot[..., :half], x_rot[..., half:] + rotated = mx.concatenate( + [x1 * cos_a - x2 * sin_a, x1 * sin_a + x2 * cos_a], axis=-1 + ) + return mx.concatenate([rotated, x_pass], axis=-1) + + +def manual_rope_with_freqs(x, positions, dims, freqs, pre_scale=1.0): + """Apply RoPE at arbitrary positions using pre-computed frequencies. + + For custom RoPE variants (Llama3, Yarn, SuScaled) that store _freqs. + """ + half = dims // 2 + inv_freq = (1.0 / freqs).astype(mx.float32) + angles = positions[:, None].astype(mx.float32) * inv_freq[None, :] + cos_a = mx.cos(angles)[None, None, :, :] + sin_a = mx.sin(angles)[None, None, :, :] + x_rot, x_pass = x[..., :dims], x[..., dims:] + if pre_scale != 1.0: + x_rot = pre_scale * x_rot + x1, x2 = x_rot[..., :half], x_rot[..., half:] + rotated = mx.concatenate( + [x1 * cos_a - x2 * sin_a, x1 * sin_a + x2 * cos_a], axis=-1 + ) + return mx.concatenate([rotated, x_pass], axis=-1) + + +# --------------------------------------------------------------------------- +# RoPE wrappers +# --------------------------------------------------------------------------- + + +class _PositionMappedRoPE: + """Wraps a RoPE module to apply rotation at non-contiguous positions. + + Used during sparse prefill. 
The `offset` parameter from the cache tells us + which slice of the position array to use for the current chunk: + positions = all_positions[(offset - cache_start) : (offset - cache_start) + L] + + When composing with a pre-populated cache (e.g., system KV cache), cache_start + is the initial cache offset so indexing into the position array is correct. + """ + + def __init__(self, original_rope, all_positions, cache_start=0): + self._original = original_rope + self._all_positions = all_positions + self._cache_start = cache_start + self._has_custom_freqs = hasattr(original_rope, "_freqs") + + if self._has_custom_freqs: + self._freqs = original_rope._freqs + self._dims = _get_dims(original_rope) + self._pre_scale = _get_pre_scale(original_rope) + else: + # Standard nn.RoPE: attributes are dims, base, scale (no underscore) + self._dims = original_rope.dims + self._base = original_rope.base + self._scale = original_rope.scale + + def __call__(self, x, offset=0): + L = x.shape[2] + idx = offset - self._cache_start + positions = self._all_positions[idx : idx + L] + if self._has_custom_freqs: + return manual_rope_with_freqs( + x, positions, self._dims, self._freqs, pre_scale=self._pre_scale + ) + return manual_rope(x, positions, self._dims, base=self._base, scale=self._scale) + + +class _OffsetAdjustedRoPE: + """Wraps a RoPE module to add a constant offset for decode after sparse prefill. + + After sparse prefill of N tokens from a prompt of M total tokens: + cache.offset = N + i (i = decode step) + desired RoPE position = M + i + adjustment = M - N + + So: RoPE(x, offset = cache.offset + adjustment) = RoPE(x, M + i) + """ + + def __init__(self, original_rope, adjustment): + self._original = original_rope + self._adjustment = adjustment + + def __call__(self, x, offset=0): + return self._original(x, offset=offset + self._adjustment) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_dims(rope_module): + """Extract rotary dimensions from any RoPE variant.""" + for attr in ("_dims", "dim", "dims"): + if hasattr(rope_module, attr): + return getattr(rope_module, attr) + raise ValueError(f"Cannot determine dims from {type(rope_module)}") + + +def _get_pre_scale(rope_module): + """Extract pre-scale factor from custom RoPE variants (SuScaled, Yarn).""" + if hasattr(rope_module, "mscale"): + return rope_module.mscale + if hasattr(rope_module, "_scale") and hasattr(rope_module, "dim"): + return rope_module._scale + return 1.0 + + +def _find_attention_layers(model): + """Find all full-attention layers across architectures. + + Supports: + - Qwen3.5 / Llama / GPT-OSS: layers with `self_attn` attribute + - Nemotron-H: layers with `block_type == "*"` (attention blocks use `mixer`) + + Returns list of (layer_idx, layer) tuples. 
+ """ + results = [] + for idx, layer in enumerate(model.layers): + if hasattr(layer, "self_attn"): + results.append((idx, layer)) + elif getattr(layer, "block_type", None) == "*": + results.append((idx, layer)) + return results + + +def _get_attn_module(layer): + """Get the attention module from a layer (self_attn or mixer).""" + if hasattr(layer, "self_attn"): + return layer.self_attn + if getattr(layer, "block_type", None) == "*": + return layer.mixer + return None + + +def _set_attn_module(layer, module): + """Set the attention module on a layer (self_attn or mixer).""" + if hasattr(layer, "self_attn"): + layer.self_attn = module + elif getattr(layer, "block_type", None) == "*": + layer.mixer = module + + +def _build_layer_to_cache_map(model): + """Build mapping from model layer index to cache index. + + Standard models (Qwen3.5, Llama, GPT-OSS): one cache entry per layer, + so the mapping is identity (layer_idx → layer_idx). + + Nemotron-H: only M (Mamba2) and * (attention) layers have cache entries. + MLP (-) and MoE (E) layers get no cache. The mapping is compacted. + + Returns dict {layer_idx: cache_idx}. + """ + has_block_type = any(hasattr(layer, "block_type") for layer in model.layers) + if not has_block_type: + # Standard model: identity mapping + return {i: i for i in range(len(model.layers))} + + # Nemotron-H style: count cache entries for M/* layers + layer_to_cache = {} + cache_idx = 0 + for layer_idx, layer in enumerate(model.layers): + bt = getattr(layer, "block_type", None) + if bt in ("M", "*"): + layer_to_cache[layer_idx] = cache_idx + cache_idx += 1 + return layer_to_cache + + +# --------------------------------------------------------------------------- +# Core API — sparse prefill +# --------------------------------------------------------------------------- + + +def sparse_prefill( + model, tokens, selected_indices, cache, step_size=2048, position_offset=0 +): + """Prefill the model cache with selected tokens at their original positions. + + Runs the model forward on only the selected tokens while preserving their + original positional encoding via manual RoPE. After this call, the cache + contains KV entries with correct RoPE positions, and attention layers have + _OffsetAdjustedRoPE installed for correct decode positioning. + + Args: + model: Language model with .layers property (TextModel or VLM Model) + tokens: (M,) all prompt token IDs (mx.array or list) + selected_indices: (N,) sorted indices into tokens to keep (mx.array or list) + cache: list of KVCache/ArraysCache from make_prompt_cache() + step_size: chunk size for processing (default 2048) + position_offset: added to selected_indices for RoPE positions (default 0). + Use when the cache already has tokens from a prior prefill (e.g., + system prompt KV cache with S tokens → position_offset=S). + + Returns: + logits: (1, 1, vocab_size) from the last selected token + + Side effects: + - Populates cache with KV for selected tokens + - Installs _OffsetAdjustedRoPE on attention layers for decode + - Call cleanup_rope(model) after generation to restore original RoPE + """ + if not isinstance(tokens, mx.array): + tokens = mx.array(tokens) + if not isinstance(selected_indices, mx.array): + selected_indices = mx.array(selected_indices) + + M = tokens.shape[0] + + # Detect RotatingKVCache and ensure tail tokens are included. + # Models with sliding window attention (e.g., GPT-OSS) use RotatingKVCache + # which evicts old entries. 
We must include the last `max_size` positions + # so sliding window layers have valid recent context for decode. + max_rotating_size = 0 + for c in cache: + if type(c).__name__ == "RotatingKVCache": + max_rotating_size = max(max_rotating_size, getattr(c, "max_size", 0)) + if max_rotating_size > 0: + tail_start = max(0, M - max_rotating_size) + tail_indices = set(range(tail_start, M)) + existing = set(selected_indices.tolist()) + merged = sorted(existing | tail_indices) + selected_indices = mx.array(merged) + + # RoPE positions: absolute positions accounting for any prefix + selected_positions = selected_indices.astype(mx.int32) + position_offset + selected_tokens = tokens[selected_indices] + N = selected_tokens.shape[0] + + # Determine initial cache offset (non-zero when system KV cache is restored) + attn_layers = _find_attention_layers(model) + layer_to_cache = _build_layer_to_cache_map(model) + first_attn_layer_idx = attn_layers[0][0] + first_attn_cache_idx = layer_to_cache[first_attn_layer_idx] + cache_start = ( + cache[first_attn_cache_idx].offset + if hasattr(cache[first_attn_cache_idx], "offset") + else 0 + ) + + # Check if attention layers use RoPE (Nemotron-H has none) + first_attn = _get_attn_module(attn_layers[0][1]) + has_rope = hasattr(first_attn, "rope") + + # Patch RoPE on attention layers for position-mapped prefill + # (skipped for architectures without RoPE, e.g. Nemotron-H) + original_ropes = {} + if has_rope: + for layer_idx, layer in attn_layers: + attn = _get_attn_module(layer) + original_ropes[layer_idx] = attn.rope + attn.rope = _PositionMappedRoPE( + attn.rope, selected_positions, cache_start=cache_start + ) + + try: + prompt = selected_tokens + n = int(N) + processed = 0 + + while n - processed > 1: + chunk = min(step_size, n - processed - 1) + model(prompt[processed : processed + chunk][None], cache=cache) + mx.eval([c.state for c in cache]) + processed += chunk + mx.clear_cache() + + # Last token → logits + logits = model(prompt[processed:][None], cache=cache) + mx.eval(logits) + + finally: + # Replace position-mapped RoPE with offset-adjusted RoPE for decode. + # Skipped for architectures without RoPE (e.g. Nemotron-H). + # + # Total prompt length = position_offset + M (prefix + current tokens). + # After prefill, cache offset = cache_start + N. + # Decode needs RoPE position = total_len + i, cache gives offset = cache_start + N + i. + # Adjustment = total_len - (cache_start + N) = position_offset + M - cache_start - N. + # When cache_start == position_offset (normal case): adjustment = M - N. + if has_rope: + total_prompt_len = position_offset + M + final_cache_offset = cache_start + N + adjustment = int(total_prompt_len) - int(final_cache_offset) + for layer_idx, layer in attn_layers: + attn = _get_attn_module(layer) + original = original_ropes[layer_idx] + if adjustment > 0: + attn.rope = _OffsetAdjustedRoPE(original, adjustment) + else: + attn.rope = original + + return logits + + +def cleanup_rope(model): + """Restore original RoPE on all attention layers. + + Call this after generation is complete to remove _OffsetAdjustedRoPE + wrappers installed by sparse_prefill(). No-op for architectures + without RoPE (e.g. Nemotron-H). 
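+
+    Typical lifecycle (sketch; assumes `model`, `tokens`, `selected_indices`,
+    and a `cache` from make_prompt_cache(), as described in sparse_prefill):
+
+        logits = sparse_prefill(model, tokens, selected_indices, cache)
+        # ... decode with the populated cache ...
+        cleanup_rope(model)  # restore the original RoPE modules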
+ """ + for _, layer in _find_attention_layers(model): + attn = _get_attn_module(layer) + if attn is None or not hasattr(attn, "rope"): + continue + rope = attn.rope + if isinstance(rope, (_OffsetAdjustedRoPE, _PositionMappedRoPE)): + attn.rope = rope._original diff --git a/vllm_mlx/text_model_from_vlm.py b/vllm_mlx/text_model_from_vlm.py new file mode 100644 index 0000000..3de833b --- /dev/null +++ b/vllm_mlx/text_model_from_vlm.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Construct an mlx_lm TextModel from mlx_vlm-loaded model weights. + +When mlx_vlm loads a model, it strips MTP weights in sanitize(). +This module builds a parallel mlx_lm TextModel that: +1. Shares backbone + lm_head weights with the vlm model (zero-copy) +2. Loads MTP weights from safetensors on disk +3. Provides full mlx_lm API: return_hidden, n_confirmed, mtp_forward, make_mtp_cache +""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Any + +import mlx.core as mx +import mlx.nn as nn +import mlx.utils + +logger = logging.getLogger(__name__) + + +def build_text_model(vlm_model: Any, model_path: str | Path) -> Any | None: + """Build an mlx_lm TextModel from a vlm-loaded model's weights. + + Args: + vlm_model: The mlx_vlm-loaded model (has .language_model attribute) + model_path: Path to the model directory (contains config.json + safetensors) + + Returns: + mlx_lm TextModel with MTP support, or None on failure. + """ + if vlm_model is None: + return None + + model_path = Path(model_path) if model_path else None + if model_path is None or not (model_path / "config.json").exists(): + # model_path may be a Hub repo ID — resolve to local cache + try: + from huggingface_hub import snapshot_download + + model_path = Path(snapshot_download(str(model_path))) + except Exception: + pass + if model_path is None or not (model_path / "config.json").exists(): + return None + + try: + config = json.loads((model_path / "config.json").read_text()) + text_config = config.get("text_config", config) + + # Always import from qwen3_5 — TextModel and TextModelArgs handle both + # dense and MoE natively (MTPDecoderLayer auto-selects SparseMoeBlock + # when args.num_experts > 0). qwen3_5_moe.py does NOT export these. + from mlx_lm.models.qwen3_5 import TextModel, TextModelArgs + + # Build args with proper __post_init__ (handles partial_rotary_factor, + # rope_scaling, head_dim derivation) + args = TextModelArgs.from_dict(text_config) + text_model = TextModel(args) + + # Collect all weights first: backbone from vlm + MTP from safetensors + vlm_lm = vlm_model.language_model + vlm_weights = mlx.utils.tree_flatten(vlm_lm.parameters()) + mtp_weights = _load_mtp_weights(model_path) + + all_weight_names = set(name for name, _ in vlm_weights) + all_weight_names.update(name for name, _ in mtp_weights) + + # Quantize the TextModel skeleton to match source weights. + # Use a predicate that only quantizes layers that have .scales in source. + # This prevents quantizing layers like mtp.fc which are BF16. 
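+        # Illustrative (hypothetical key names): all_weight_names may contain
+        # "model.layers.0.self_attn.q_proj.weight" plus the matching
+        # "model.layers.0.self_attn.q_proj.scales" (quantized source layer),
+        # but only "mtp.fc.weight" (BF16), so the predicate quantizes q_proj
+        # and leaves mtp.fc unquantized.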
+ quantization = text_config.get("quantization", config.get("quantization", None)) + if quantization is not None: + + def _class_predicate(path, module): + if not hasattr(module, "to_quantized"): + return False + return f"{path}.scales" in all_weight_names + + nn.quantize( + text_model, + group_size=quantization.get("group_size", 64), + bits=quantization.get("bits", 8), + class_predicate=_class_predicate, + ) + + # Transfer backbone + lm_head weights from vlm language_model (zero-copy). + # strict=False because TextModel has MTP params that vlm doesn't have yet. + text_model.load_weights(vlm_weights, strict=False) + + logger.info( + "Transferred %d weight arrays from vlm language_model", len(vlm_weights) + ) + + # Load MTP weights from safetensors + if mtp_weights: + text_model.load_weights(mtp_weights, strict=False) + logger.info("Loaded %d MTP weights from safetensors", len(mtp_weights)) + else: + logger.warning("No MTP weights found in %s", model_path.name) + + # Verify MTP is functional + if hasattr(text_model, "mtp") and text_model.mtp is not None: + mx.eval(text_model.mtp.parameters()) + logger.info( + "TextModel built with MTP support (%d layers)", + args.mtp_num_hidden_layers, + ) + else: + logger.info("TextModel built without MTP (mtp_num_hidden_layers=0)") + + return text_model + + except ImportError as e: + logger.error("Cannot import mlx_lm TextModel (need PR #990): %s", e) + return None + except Exception as e: + logger.error("Failed to build TextModel from vlm: %s", e) + return None + + +def _load_mtp_weights(model_path: Path) -> list[tuple[str, mx.array]]: + """Load MTP weights from safetensors, stripping the language_model. prefix. + + mlx_vlm's sanitize() strips mtp.* keys during model loading, + but the weights are still on disk in the safetensors files. + """ + index_file = model_path / "model.safetensors.index.json" + if not index_file.exists(): + return [] + + index = json.loads(index_file.read_text()) + weight_map = index.get("weight_map", {}) + + # Find MTP keys and their shard files + mtp_keys: dict[str, tuple[str, str]] = {} + for key, shard in weight_map.items(): + if ".mtp." in key: + # Strip "language_model." 
prefix to match mlx_lm namespace + clean = ( + key.replace("language_model.", "", 1) + if key.startswith("language_model.") + else key + ) + mtp_keys[key] = (clean, shard) + + if not mtp_keys: + return [] + + # Group by shard to minimize I/O + shards: dict[str, list[tuple[str, str]]] = {} + for orig, (clean, shard) in mtp_keys.items(): + shards.setdefault(shard, []).append((orig, clean)) + + weights = [] + for shard_file, key_pairs in shards.items(): + shard_path = model_path / shard_file + if not shard_path.exists(): + logger.warning("MTP shard not found: %s", shard_file) + continue + shard_data = mx.load(str(shard_path)) + for orig, clean in key_pairs: + if orig in shard_data: + weights.append((clean, shard_data[orig])) + + return weights diff --git a/vllm_mlx/utils/mamba_cache.py b/vllm_mlx/utils/mamba_cache.py index cac35c8..2d9d5f1 100644 --- a/vllm_mlx/utils/mamba_cache.py +++ b/vllm_mlx/utils/mamba_cache.py @@ -11,17 +11,12 @@ import mlx.core as mx -# MambaCache was removed in mlx-lm 0.30.6 - make import conditional +# MambaCache was removed in mlx-lm 0.30.6, fall back to ArraysCache try: from mlx_lm.models.cache import MambaCache - - HAS_MAMBA_CACHE = True except ImportError: - # Fallback for mlx-lm >= 0.30.6 where MambaCache was removed from mlx_lm.models.cache import ArraysCache as MambaCache - HAS_MAMBA_CACHE = False - logger = logging.getLogger(__name__) @@ -41,10 +36,9 @@ def __init__(self, left_padding: list[int] | None = None, size: int = 2): left_padding: Amount of left padding for each sequence in batch size: Number of state arrays (default 2 for Mamba models) """ - if HAS_MAMBA_CACHE: - super().__init__(left_padding=left_padding) - else: - super().__init__(size=size, left_padding=left_padding) + # Always pass size - ArraysCache requires it, and MambaCache + # (if it exists) inherits from ArraysCache + super().__init__(size=size, left_padding=left_padding) self._batch_size = len(left_padding) if left_padding else 0 def extract(self, idx: int) -> MambaCache: @@ -58,10 +52,7 @@ def extract(self, idx: int) -> MambaCache: A new MambaCache with the extracted state """ size = len(self.cache) - if HAS_MAMBA_CACHE: - cache = MambaCache() - else: - cache = MambaCache(size=size) + cache = MambaCache(size=size) # Extract the state arrays for this index cache.cache = [ mx.contiguous(c[idx : idx + 1]) if c is not None else None @@ -207,8 +198,17 @@ def _patched_merge_caches(caches): def ensure_mamba_support(): - """Ensure MambaCache batching support is enabled.""" + """Ensure MambaCache batching support is enabled. + + NOTE: Disabled for mlx-lm >= 0.30.6 where ArraysCache natively supports + all batch operations (extract, merge, filter, prepare). The old patch + replaced ArraysCache with BatchMambaCache, which broke hybrid models + (Qwen3.5) that mix ArraysCache + KVCache layers. 
+ """ global _patched if not _patched: - patch_mlx_lm_for_mamba() + logger.info( + "[MambaCache] Skipping _make_cache patch — " + "mlx-lm ArraysCache has native batching support" + ) _patched = True diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py index be18a50..4cf433a 100644 --- a/vllm_mlx/utils/tokenizer.py +++ b/vllm_mlx/utils/tokenizer.py @@ -51,24 +51,89 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None): return _load_with_tokenizer_fallback(model_name) try: - return load(model_name, tokenizer_config=tokenizer_config) + model, tokenizer = load(model_name, tokenizer_config=tokenizer_config) + return model, tokenizer except ValueError as e: # Fallback for models with non-standard tokenizers if "TokenizersBackend" in str(e) or "Tokenizer class" in str(e): logger.warning(f"Standard tokenizer loading failed, using fallback: {e}") return _load_with_tokenizer_fallback(model_name) - # Fallback for multimodal models loaded as text-only (skip vision weights) + # Fallback for models with extra/missing weights (e.g., vision tower, MTP layers). + # Retry with strict=False to discard extra weights. elif "parameters not in model" in str(e) or ( "Missing" in str(e) and "parameters" in str(e) ): logger.warning( - f"Model has extra/missing parameters (likely VLM weights), retrying with strict=False: {e}" + f"Model has extra/missing parameters (likely VLM / MTP weights), " + f"retrying with strict=False: {e}" ) - return _load_non_strict(model_name, tokenizer_config) + return _load_strict_false(model_name, tokenizer_config) else: raise +def _load_strict_false(model_name: str, tokenizer_config: dict = None): + """Load model with strict=False to discard extra weights (e.g., vision tower, MTP).""" + from mlx_lm.utils import load_model, load_tokenizer + + local_path = Path(model_name) + if local_path.is_dir(): + model_path = local_path + else: + from huggingface_hub import snapshot_download + + model_path = Path(snapshot_download(model_name)) + + model, config = load_model(model_path, strict=False) + tokenizer = load_tokenizer( + model_path, + tokenizer_config or {}, + eos_token_ids=config.get("eos_token_id", None), + ) + # Inject MTP support if model has MTP config + weights + _try_inject_mtp(model, model_path, config) + return model, tokenizer + + +def _try_inject_mtp(model, model_path, config): + """Inject MTP support if model has MTP config + weights.""" + if config.get("num_nextn_predict_layers", 0) > 0: + from ..patches.qwen3_next_mtp import inject_mtp_support + + inject_mtp_support(model, model_path, config) + + +def _try_inject_mtp_post_load(model, model_name): + """Check if MTP weights exist but were stripped by sanitize(), and inject.""" + import json + + from mlx_lm.utils import _download + + model_path = _download(model_name) + config_path = Path(model_path) / "config.json" + if not config_path.exists(): + return + with open(config_path) as f: + config = json.load(f) + # Also check text_config for nested configs + num_mtp = config.get("num_nextn_predict_layers", 0) + if num_mtp == 0: + text_config = config.get("text_config", {}) + num_mtp = text_config.get("num_nextn_predict_layers", 0) + if num_mtp > 0 and getattr(model, "mtp", None) is None: + mtp_file = Path(model_path) / "model-mtp.safetensors" + if mtp_file.exists(): + logger.info( + f"[MTP] Found MTP config (layers={num_mtp}) and weights, injecting..." 
+ ) + _try_inject_mtp(model, model_path, config) + else: + logger.info( + f"[MTP] Config has num_nextn_predict_layers={num_mtp} " + "but model-mtp.safetensors not found, skipping MTP." + ) + + def _load_non_strict(model_name: str, tokenizer_config: dict = None): """Load model with strict=False to skip extra weights (e.g., vision tower).""" from mlx_lm.utils import load_model, load_tokenizer diff --git a/vllm_mlx/platform.py b/vllm_mlx/vllm_platform.py similarity index 100% rename from vllm_mlx/platform.py rename to vllm_mlx/vllm_platform.py