diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py
index 616fc7a86059..f5a7b9cc276b 100644
--- a/tests/v1/e2e/test_kv_sharing_fast_prefill.py
+++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import gc
 import random
 from typing import Optional, Union
 
@@ -10,6 +9,7 @@
 
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationLevel
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration
 from vllm.model_executor.models.registry import ModelRegistry
@@ -18,6 +18,9 @@
 
 from ...utils import fork_new_process_for_each_test
 
+# global seed
+SEED = 42
+
 
 class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
 
@@ -95,8 +98,25 @@ def test_prompts():
     return prompts
 
 
+def cleanup(llm: LLM, compilation_config: CompilationConfig):
+    # hacky: below lines are required to free up memory for the next test
+    # when setting VLLM_ENABLE_V1_MULTIPROCESSING=0, del llm is not sufficient
+    # TODO(sarckk): when enforce_eager=False, memory is not freed:
+    # find out why and re-enable test for enforce_eager=False case
+    llm_engine = llm.llm_engine.engine_core.engine_core
+    model_runner = llm_engine.model_executor.driver_worker.worker.model_runner
+    del model_runner.model
+    del model_runner.kv_caches
+    del compilation_config.static_forward_context
+    compilation_config.static_forward_context = {}
+
+    del llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+
 @fork_new_process_for_each_test
-@pytest.mark.parametrize("enforce_eager", [True, False])
+@pytest.mark.parametrize("enforce_eager", [True])
 def test_kv_sharing_fast_prefill(
     monkeypatch: pytest.MonkeyPatch,
     enforce_eager: bool,
@@ -115,23 +135,28 @@ def test_kv_sharing_fast_prefill(
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
+        # Make scheduling deterministic for reproducibility
+        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+
         llm = LLM(
             model="google/gemma-3n-E2B-it",
             enforce_eager=enforce_eager,
             compilation_config=compilation_config,
+            seed=SEED,
         )
         ref_responses = llm.generate(test_prompts, sampling_params)
 
-        del llm
-        gc.collect()
-        torch.cuda.empty_cache()
+        cleanup(llm, compilation_config)
 
         llm = LLM(model="google/gemma-3n-E2B-it",
                   enforce_eager=enforce_eager,
                   compilation_config=compilation_config,
+                  seed=SEED,
                   kv_sharing_fast_prefill=True)
         optimized_responses = llm.generate(test_prompts, sampling_params)
 
+        cleanup(llm, compilation_config)
+
         misses = 0
         for ref_response, optimized_response in zip(ref_responses,