From 0e22c960d9b1fdbb22cb0e865baa2962b0a9fb9a Mon Sep 17 00:00:00 2001 From: ran Date: Thu, 5 Mar 2026 13:08:49 -0600 Subject: [PATCH 1/7] add deterministic test Signed-off-by: ran --- tests/test_paged_deterministic.py | 114 ++++++++++++++++++++++++++++++ tools/gen_golden.py | 48 +++++++++++++ 2 files changed, 162 insertions(+) create mode 100644 tests/test_paged_deterministic.py create mode 100644 tools/gen_golden.py diff --git a/tests/test_paged_deterministic.py b/tests/test_paged_deterministic.py new file mode 100644 index 00000000..477d2304 --- /dev/null +++ b/tests/test_paged_deterministic.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Deterministic smoke test: vLLM offline inference with golden token comparison. + +Golden token IDs were generated on the main branch using vLLM offline inference +with temperature=0 (greedy decoding) on Qwen/Qwen3-0.6B. + +Findings from golden generation (main branch, HF paged-attention kernel): +- The HF kernel paged KV path produces correct, coherent output. +- 4/5 prompts are identical to the MLX inline cache path. +- 1/5 ("The capital of France is") diverges at token 5 — both continuations + are valid English ("France is also the capital" vs "Italy is Rome. The"). + Likely caused by floating-point non-determinism in the attention kernel + where top-2 logits are very close. + +The assert accepts EITHER golden set (mlx-cache or paged-cache) and prints +which path matched. + +Run (paged KV path, the default): + python -m pytest tests/test_paged_deterministic.py -v -s + +To test the MLX inline cache path instead, change the env vars below. +""" + +from __future__ import annotations + +import os + +import pytest + +# Default: test the paged KV cache path through vLLM offline inference. +# Change to "0" to test the MLX inline cache path. +os.environ.setdefault("VLLM_METAL_USE_PAGED_ATTENTION", "1") +os.environ.setdefault("VLLM_METAL_MEMORY_FRACTION", "0.2") +os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + +MODEL_NAME = "Qwen/Qwen3-0.6B" +MAX_TOKENS = 10 + +PROMPTS = [ + "The capital of France is", + "The weather today is not", + "One plus one equals", + "The largest planet in our solar system is", + "Water boils at a temperature of", +] + +# fmt: off +# Golden token IDs from MLX inline cache (default path), greedy decoding. +# Generated on main branch via: VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden.py +GOLDEN_MLX = { + "The capital of France is": [12095, 13, 576, 6722, 315, 9625, 374, 1083, 279, 6722], + "The weather today is not": [1661, 13, 576, 9315, 374, 220, 17, 15, 12348, 13], + "One plus one equals": [825, 11, 825, 5519, 825, 16819, 1378, 13, 2055, 11], + "The largest planet in our solar system is": [1112, 30, 362, 13, 43562, 425, 13, 48976, 356, 13], + "Water boils at a temperature of": [220, 16, 15, 15, 30937, 13, 3555, 374, 279, 9315], +} + +# Golden token IDs from paged KV cache (HF kernel on main branch), greedy decoding. +# Generated on main branch via: VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.3 \ +# VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden.py +GOLDEN_PAGED = { + "The capital of France is": [12095, 13, 576, 6722, 315, 15344, 374, 21718, 13, 576], + "The weather today is not": [1661, 13, 576, 9315, 374, 220, 17, 15, 12348, 13], + "One plus one equals": [825, 11, 825, 5519, 825, 16819, 1378, 13, 2055, 11], + "The largest planet in our solar system is": [1112, 30, 362, 13, 43562, 425, 13, 48976, 356, 13], + "Water boils at a temperature of": [220, 16, 15, 15, 30937, 13, 3555, 374, 279, 9315], +} +# fmt: on + + +from vllm import LLM, SamplingParams + + +@pytest.fixture(scope="module") +def vllm_outputs(): + """Run vLLM offline inference once for all prompts.""" + llm = LLM(model=MODEL_NAME, max_model_len=512) + sp = SamplingParams(temperature=0, max_tokens=MAX_TOKENS) + outputs = llm.generate(PROMPTS, sp) + return {o.prompt: o for o in outputs} + + +class TestPagedDeterministic: + @pytest.mark.slow + @pytest.mark.parametrize("prompt", PROMPTS) + def test_generate_matches_golden(self, vllm_outputs, prompt): + output = vllm_outputs[prompt] + token_ids = list(output.outputs[0].token_ids) + text = output.outputs[0].text + + mlx_expected = GOLDEN_MLX[prompt] + paged_expected = GOLDEN_PAGED[prompt] + + mlx_match = token_ids == mlx_expected + paged_match = token_ids == paged_expected + + print(f"\n prompt: {prompt!r}") + print(f" output: {text!r}") + print(f" ids: {token_ids}") + if mlx_match: + print(" result: MATCHED mlx-cache golden") + elif paged_match: + print(" result: MATCHED paged-cache golden") + else: + print(f" result: NO MATCH") + print(f" expected (mlx): {mlx_expected}") + print(f" expected (paged): {paged_expected}") + + assert mlx_match or paged_match, ( + f"Output for {prompt!r} matched neither golden set.\n" + f"Got: {token_ids}\n" + f"Expected (mlx): {mlx_expected}\n" + f"Expected (pgd): {paged_expected}" + ) diff --git a/tools/gen_golden.py b/tools/gen_golden.py new file mode 100644 index 00000000..9ab62c78 --- /dev/null +++ b/tools/gen_golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +"""Generate golden token IDs for the e2e smoke test via vLLM offline inference. + +Usage: + # MLX inline cache (default): + VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden.py + + # Paged KV cache: + VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_ENABLE_V1_MULTIPROCESSING=0 \ + python tools/gen_golden.py +""" + +import os + +os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + +from vllm import LLM, SamplingParams + +MODEL = "Qwen/Qwen3-0.6B" +MAX_TOKENS = 10 + +PROMPTS = [ + "The capital of France is", + "The weather today is not", + "One plus one equals", + "The largest planet in our solar system is", + "Water boils at a temperature of", +] + +if __name__ == "__main__": + paged = os.environ.get("VLLM_METAL_USE_PAGED_ATTENTION", "0") == "1" + label = "PAGED" if paged else "MLX" + print(f"\n--- Generating golden values for {label} path ---\n") + + llm = LLM(model=MODEL, max_model_len=512) + sp = SamplingParams(temperature=0, max_tokens=MAX_TOKENS) + outputs = llm.generate(PROMPTS, sp) + + print(f"\nGOLDEN_{label} = {{") + for o in outputs: + prompt = o.prompt + ids = list(o.outputs[0].token_ids) + text = o.outputs[0].text + pad = 45 - len(prompt) + print(f" {prompt!r}:{' ' * pad}{ids},") + print(f" # → {text!r}") + print("}") From 482e6b2748578ca5cc13e0679d95a7661c2094ba Mon Sep 17 00:00:00 2001 From: ran Date: Thu, 5 Mar 2026 13:22:52 -0600 Subject: [PATCH 2/7] fix linter Signed-off-by: ran --- tests/test_paged_deterministic.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/test_paged_deterministic.py b/tests/test_paged_deterministic.py index 477d2304..7c85c94a 100644 --- a/tests/test_paged_deterministic.py +++ b/tests/test_paged_deterministic.py @@ -25,14 +25,15 @@ import os -import pytest - # Default: test the paged KV cache path through vLLM offline inference. # Change to "0" to test the MLX inline cache path. os.environ.setdefault("VLLM_METAL_USE_PAGED_ATTENTION", "1") os.environ.setdefault("VLLM_METAL_MEMORY_FRACTION", "0.2") os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0") +import pytest +from vllm import LLM, SamplingParams + MODEL_NAME = "Qwen/Qwen3-0.6B" MAX_TOKENS = 10 @@ -68,9 +69,6 @@ # fmt: on -from vllm import LLM, SamplingParams - - @pytest.fixture(scope="module") def vllm_outputs(): """Run vLLM offline inference once for all prompts.""" @@ -102,7 +100,7 @@ def test_generate_matches_golden(self, vllm_outputs, prompt): elif paged_match: print(" result: MATCHED paged-cache golden") else: - print(f" result: NO MATCH") + print(" result: NO MATCH") print(f" expected (mlx): {mlx_expected}") print(f" expected (paged): {paged_expected}") From c01378e828914156ed4f148ad95af94787e51232 Mon Sep 17 00:00:00 2001 From: ran Date: Thu, 5 Mar 2026 22:01:56 -0600 Subject: [PATCH 3/7] set seq 1 to ensure deterministic Signed-off-by: ran --- tests/test_paged_deterministic.py | 51 ++++++++++++++----- ...{gen_golden.py => gen_golden_token_ids.py} | 16 ++++-- 2 files changed, 48 insertions(+), 19 deletions(-) rename tools/{gen_golden.py => gen_golden_token_ids.py} (63%) diff --git a/tests/test_paged_deterministic.py b/tests/test_paged_deterministic.py index 7c85c94a..ed43e8bc 100644 --- a/tests/test_paged_deterministic.py +++ b/tests/test_paged_deterministic.py @@ -2,7 +2,8 @@ """Deterministic smoke test: vLLM offline inference with golden token comparison. Golden token IDs were generated on the main branch using vLLM offline inference -with temperature=0 (greedy decoding) on Qwen/Qwen3-0.6B. +with temperature=0 (greedy decoding) on Qwen/Qwen3-0.6B, running one sequence +at a time (max_num_seqs=1) to avoid batch-invariance issues on Metal. Findings from golden generation (main branch, HF paged-attention kernel): - The HF kernel paged KV path produces correct, coherent output. @@ -18,19 +19,16 @@ Run (paged KV path, the default): python -m pytest tests/test_paged_deterministic.py -v -s -To test the MLX inline cache path instead, change the env vars below. +To test the MLX inline cache path instead, pass env vars explicitly: + VLLM_METAL_USE_PAGED_ATTENTION=0 VLLM_METAL_MEMORY_FRACTION=auto \ + python -m pytest tests/test_paged_deterministic.py -v -s + +Note: MLX requires VLLM_METAL_MEMORY_FRACTION=auto (numeric fractions are +only valid for the paged attention path). """ from __future__ import annotations -import os - -# Default: test the paged KV cache path through vLLM offline inference. -# Change to "0" to test the MLX inline cache path. -os.environ.setdefault("VLLM_METAL_USE_PAGED_ATTENTION", "1") -os.environ.setdefault("VLLM_METAL_MEMORY_FRACTION", "0.2") -os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - import pytest from vllm import LLM, SamplingParams @@ -47,7 +45,7 @@ # fmt: off # Golden token IDs from MLX inline cache (default path), greedy decoding. -# Generated on main branch via: VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden.py +# Generated on main branch via: VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids.py GOLDEN_MLX = { "The capital of France is": [12095, 13, 576, 6722, 315, 9625, 374, 1083, 279, 6722], "The weather today is not": [1661, 13, 576, 9315, 374, 220, 17, 15, 12348, 13], @@ -58,7 +56,7 @@ # Golden token IDs from paged KV cache (HF kernel on main branch), greedy decoding. # Generated on main branch via: VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.3 \ -# VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden.py +# VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids.py GOLDEN_PAGED = { "The capital of France is": [12095, 13, 576, 6722, 315, 15344, 374, 21718, 13, 576], "The weather today is not": [1661, 13, 576, 9315, 374, 220, 17, 15, 12348, 13], @@ -69,10 +67,35 @@ # fmt: on +@pytest.fixture(scope="session") +def _monkeypatch_session(): + """Session-scoped monkeypatch (pytest only provides function-scoped).""" + from _pytest.monkeypatch import MonkeyPatch + + mp = MonkeyPatch() + yield mp + mp.undo() + + +@pytest.fixture(autouse=True, scope="session") +def _set_env(_monkeypatch_session): + """Set env vars for the paged KV cache path. + + Uses monkeypatch so env changes are automatically reverted after the + session, avoiding side effects on other tests. + """ + _monkeypatch_session.setenv("VLLM_METAL_USE_PAGED_ATTENTION", "1") + _monkeypatch_session.setenv("VLLM_METAL_MEMORY_FRACTION", "0.2") + _monkeypatch_session.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + @pytest.fixture(scope="module") def vllm_outputs(): - """Run vLLM offline inference once for all prompts.""" - llm = LLM(model=MODEL_NAME, max_model_len=512) + """Run vLLM offline inference once for all prompts. + + Uses max_num_seqs=1 to avoid batch-invariance non-determinism on Metal. + """ + llm = LLM(model=MODEL_NAME, max_model_len=512, max_num_seqs=1) sp = SamplingParams(temperature=0, max_tokens=MAX_TOKENS) outputs = llm.generate(PROMPTS, sp) return {o.prompt: o for o in outputs} diff --git a/tools/gen_golden.py b/tools/gen_golden_token_ids.py similarity index 63% rename from tools/gen_golden.py rename to tools/gen_golden_token_ids.py index 9ab62c78..33f9ffe7 100644 --- a/tools/gen_golden.py +++ b/tools/gen_golden_token_ids.py @@ -1,14 +1,20 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: Apache-2.0 -"""Generate golden token IDs for the e2e smoke test via vLLM offline inference. +"""Generate golden token IDs for the deterministic smoke test. + +Runs vLLM offline inference (greedy, max_num_seqs=1) and prints golden +token-ID dicts to paste into test_paged_deterministic.py. Usage: # MLX inline cache (default): - VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden.py + VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids.py # Paged KV cache: - VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_ENABLE_V1_MULTIPROCESSING=0 \ - python tools/gen_golden.py + VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.3 \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids.py + +Note: MLX path requires VLLM_METAL_MEMORY_FRACTION=auto (the default). + Numeric fractions are only valid for the paged attention path. """ import os @@ -33,7 +39,7 @@ label = "PAGED" if paged else "MLX" print(f"\n--- Generating golden values for {label} path ---\n") - llm = LLM(model=MODEL, max_model_len=512) + llm = LLM(model=MODEL, max_model_len=512, max_num_seqs=1) sp = SamplingParams(temperature=0, max_tokens=MAX_TOKENS) outputs = llm.generate(PROMPTS, sp) From e56dde45ad36e28fdeabdf931fce927d61364e58 Mon Sep 17 00:00:00 2001 From: ran Date: Thu, 5 Mar 2026 22:11:31 -0600 Subject: [PATCH 4/7] more descriptive name Signed-off-by: ran --- ...en_token_ids.py => gen_golden_token_ids_for_deterministics.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tools/{gen_golden_token_ids.py => gen_golden_token_ids_for_deterministics.py} (100%) diff --git a/tools/gen_golden_token_ids.py b/tools/gen_golden_token_ids_for_deterministics.py similarity index 100% rename from tools/gen_golden_token_ids.py rename to tools/gen_golden_token_ids_for_deterministics.py From 76929f8f3c7d458f976c215c8f22a340a805f003 Mon Sep 17 00:00:00 2001 From: ran Date: Thu, 5 Mar 2026 22:23:32 -0600 Subject: [PATCH 5/7] use a more elegant way to handle env var Signed-off-by: ran --- tests/test_paged_deterministic.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/tests/test_paged_deterministic.py b/tests/test_paged_deterministic.py index ed43e8bc..85b6d558 100644 --- a/tests/test_paged_deterministic.py +++ b/tests/test_paged_deterministic.py @@ -67,26 +67,18 @@ # fmt: on -@pytest.fixture(scope="session") -def _monkeypatch_session(): - """Session-scoped monkeypatch (pytest only provides function-scoped).""" - from _pytest.monkeypatch import MonkeyPatch - - mp = MonkeyPatch() - yield mp - mp.undo() - - -@pytest.fixture(autouse=True, scope="session") -def _set_env(_monkeypatch_session): +@pytest.fixture(autouse=True, scope="module") +def _set_env(): """Set env vars for the paged KV cache path. - Uses monkeypatch so env changes are automatically reverted after the - session, avoiding side effects on other tests. + Uses MonkeyPatch.context() so env changes are automatically reverted + after the module, avoiding side effects on other tests. """ - _monkeypatch_session.setenv("VLLM_METAL_USE_PAGED_ATTENTION", "1") - _monkeypatch_session.setenv("VLLM_METAL_MEMORY_FRACTION", "0.2") - _monkeypatch_session.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + with pytest.MonkeyPatch.context() as mp: + mp.setenv("VLLM_METAL_USE_PAGED_ATTENTION", "1") + mp.setenv("VLLM_METAL_MEMORY_FRACTION", "0.2") + mp.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + yield @pytest.fixture(scope="module") From 1ddc50d2af882e87ed8f67381ce0c1ccd6e91c1c Mon Sep 17 00:00:00 2001 From: Yuan Lik Xun Date: Fri, 6 Mar 2026 13:29:00 +0800 Subject: [PATCH 6/7] Fix deterministic test docs and env defaults Signed-off-by: Yuan Lik Xun --- tests/test_paged_deterministic.py | 28 +++++++++++++++---- ...gen_golden_token_ids_for_deterministics.py | 4 +-- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/tests/test_paged_deterministic.py b/tests/test_paged_deterministic.py index 85b6d558..e356249a 100644 --- a/tests/test_paged_deterministic.py +++ b/tests/test_paged_deterministic.py @@ -29,6 +29,8 @@ from __future__ import annotations +import os + import pytest from vllm import LLM, SamplingParams @@ -45,7 +47,7 @@ # fmt: off # Golden token IDs from MLX inline cache (default path), greedy decoding. -# Generated on main branch via: VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids.py +# Generated on main branch via: VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids_for_deterministics.py GOLDEN_MLX = { "The capital of France is": [12095, 13, 576, 6722, 315, 9625, 374, 1083, 279, 6722], "The weather today is not": [1661, 13, 576, 9315, 374, 220, 17, 15, 12348, 13], @@ -56,7 +58,7 @@ # Golden token IDs from paged KV cache (HF kernel on main branch), greedy decoding. # Generated on main branch via: VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.3 \ -# VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids.py +# VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids_for_deterministics.py GOLDEN_PAGED = { "The capital of France is": [12095, 13, 576, 6722, 315, 15344, 374, 21718, 13, 576], "The weather today is not": [1661, 13, 576, 9315, 374, 220, 17, 15, 12348, 13], @@ -69,15 +71,31 @@ @pytest.fixture(autouse=True, scope="module") def _set_env(): - """Set env vars for the paged KV cache path. + """Set default env vars for this test. Uses MonkeyPatch.context() so env changes are automatically reverted after the module, avoiding side effects on other tests. + + Defaults to the paged KV cache path to ensure the test actually exercises + the paged attention kernel, but respects any env vars already set by the + user (e.g. to run the MLX path). """ with pytest.MonkeyPatch.context() as mp: - mp.setenv("VLLM_METAL_USE_PAGED_ATTENTION", "1") - mp.setenv("VLLM_METAL_MEMORY_FRACTION", "0.2") mp.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + # Default to paged attention, but allow the caller to override. + use_paged = os.environ.get("VLLM_METAL_USE_PAGED_ATTENTION") + if use_paged is None: + mp.setenv("VLLM_METAL_USE_PAGED_ATTENTION", "1") + use_paged = "1" + + # Set a sensible default memory setting for the selected path, unless + # the caller has already specified one. + if os.environ.get("VLLM_METAL_MEMORY_FRACTION") is None: + if use_paged == "1": + mp.setenv("VLLM_METAL_MEMORY_FRACTION", "0.2") + else: + mp.setenv("VLLM_METAL_MEMORY_FRACTION", "auto") yield diff --git a/tools/gen_golden_token_ids_for_deterministics.py b/tools/gen_golden_token_ids_for_deterministics.py index 33f9ffe7..1043ffe5 100644 --- a/tools/gen_golden_token_ids_for_deterministics.py +++ b/tools/gen_golden_token_ids_for_deterministics.py @@ -7,11 +7,11 @@ Usage: # MLX inline cache (default): - VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids.py + VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids_for_deterministics.py # Paged KV cache: VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.3 \ - VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids.py + VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids_for_deterministics.py Note: MLX path requires VLLM_METAL_MEMORY_FRACTION=auto (the default). Numeric fractions are only valid for the paged attention path. From 8c357bc4734c1df89d0d4ac87c85704f7fc94a91 Mon Sep 17 00:00:00 2001 From: Yuan Lik Xun Date: Fri, 6 Mar 2026 13:35:58 +0800 Subject: [PATCH 7/7] Refactor env defaults Signed-off-by: Yuan Lik Xun --- tests/test_paged_deterministic.py | 39 ++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/tests/test_paged_deterministic.py b/tests/test_paged_deterministic.py index e356249a..0b1d3420 100644 --- a/tests/test_paged_deterministic.py +++ b/tests/test_paged_deterministic.py @@ -36,6 +36,9 @@ MODEL_NAME = "Qwen/Qwen3-0.6B" MAX_TOKENS = 10 +DEFAULT_USE_PAGED_ATTENTION = "1" +DEFAULT_PAGED_MEMORY_FRACTION = "0.2" +DEFAULT_MLX_MEMORY_FRACTION = "auto" PROMPTS = [ "The capital of France is", @@ -69,6 +72,15 @@ # fmt: on +def _setenv_default(mp: pytest.MonkeyPatch, key: str, default: str) -> str: + """Set an env var only when absent and return the effective value.""" + value = os.environ.get(key) + if value is None: + mp.setenv(key, default) + return default + return value + + @pytest.fixture(autouse=True, scope="module") def _set_env(): """Set default env vars for this test. @@ -83,19 +95,20 @@ def _set_env(): with pytest.MonkeyPatch.context() as mp: mp.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - # Default to paged attention, but allow the caller to override. - use_paged = os.environ.get("VLLM_METAL_USE_PAGED_ATTENTION") - if use_paged is None: - mp.setenv("VLLM_METAL_USE_PAGED_ATTENTION", "1") - use_paged = "1" - - # Set a sensible default memory setting for the selected path, unless - # the caller has already specified one. - if os.environ.get("VLLM_METAL_MEMORY_FRACTION") is None: - if use_paged == "1": - mp.setenv("VLLM_METAL_MEMORY_FRACTION", "0.2") - else: - mp.setenv("VLLM_METAL_MEMORY_FRACTION", "auto") + # Default to paged attention, but allow explicit caller override. + use_paged = _setenv_default( + mp, + "VLLM_METAL_USE_PAGED_ATTENTION", + DEFAULT_USE_PAGED_ATTENTION, + ) + + # Choose a path-specific memory default, while preserving caller override. + memory_default = ( + DEFAULT_PAGED_MEMORY_FRACTION + if use_paged == "1" + else DEFAULT_MLX_MEMORY_FRACTION + ) + _setenv_default(mp, "VLLM_METAL_MEMORY_FRACTION", memory_default) yield