diff --git a/.buildkite/scripts/run_in_docker.sh b/.buildkite/scripts/run_in_docker.sh index 85f2027215..6eb3ca9092 100755 --- a/.buildkite/scripts/run_in_docker.sh +++ b/.buildkite/scripts/run_in_docker.sh @@ -103,7 +103,6 @@ exec docker run \ -e MODEL_IMPL_TYPE="$MODEL_IMPL_TYPE" \ -e HF_TOKEN="$HF_TOKEN" \ -e VLLM_XLA_CACHE_PATH="$DOCKER_HF_HOME/.cache/jax_cache" \ - -e VLLM_USE_V1=1 \ -e VLLM_XLA_CHECK_RECOMPILATION=1 \ ${QUANTIZATION:+-e QUANTIZATION="$QUANTIZATION"} \ ${NEW_MODEL_DESIGN:+-e NEW_MODEL_DESIGN="$NEW_MODEL_DESIGN"} \ diff --git a/scripts/vllm/integration/test_accuracy.py b/scripts/vllm/integration/test_accuracy.py index 2c4e666f8c..421196a0e2 100644 --- a/scripts/vllm/integration/test_accuracy.py +++ b/scripts/vllm/integration/test_accuracy.py @@ -70,9 +70,7 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch, elif tp_size < 1 or tp_size > 8: raise ValueError - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - + with monkeypatch.context() as _: more_args = None if current_platform.is_tpu(): more_args = "max_model_len=2048,max_num_seqs=64" @@ -104,9 +102,7 @@ def test_lm_eval_accuracy_v1_engine_fp8_kv_cache( elif tp_size < 1 or tp_size > 8: raise ValueError - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - + with monkeypatch.context() as _: more_args = None if current_platform.is_tpu(): more_args = "max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8" diff --git a/tests/executors/test_ray_distributed_executor.py b/tests/executors/test_ray_distributed_executor.py index a0bfbbaaf6..b50d3f537a 100644 --- a/tests/executors/test_ray_distributed_executor.py +++ b/tests/executors/test_ray_distributed_executor.py @@ -61,7 +61,6 @@ def test_init_executor_basic_flow(self, mock_wait_until_pg_ready, # --- Setup mocks --- mock_envs.VLLM_USE_RAY_COMPILED_DAG = True mock_envs.VLLM_USE_RAY_SPMD_WORKER = True - mock_envs.VLLM_USE_V1 = True mock_envs.VLLM_RAY_BUNDLE_INDICES = "" mock_platform.ray_device_key = "TPU" diff --git a/tpu_inference/executors/ray_distributed_executor.py b/tpu_inference/executors/ray_distributed_executor.py index f1f055eece..423ffa23e7 100644 --- a/tpu_inference/executors/ray_distributed_executor.py +++ b/tpu_inference/executors/ray_distributed_executor.py @@ -97,7 +97,6 @@ def _init_executor(self) -> None: self.input_encoder = msgspec.msgpack.Encoder(enc_hook=_encode_hook) self.output_decoder = msgspec.msgpack.Decoder( Optional[List[SamplerOutput]]) - self.use_v1 = envs.VLLM_USE_V1 self.pp_locks: Optional[List[asyncio.Lock]] = None diff --git a/tpu_inference/mock/vllm_envs.py b/tpu_inference/mock/vllm_envs.py index 82084d1fc5..1a938002a6 100644 --- a/tpu_inference/mock/vllm_envs.py +++ b/tpu_inference/mock/vllm_envs.py @@ -91,7 +91,6 @@ VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: list[str] = [] - VLLM_USE_V1: bool = True VLLM_ROCM_USE_AITER: bool = False VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False VLLM_ROCM_USE_AITER_LINEAR: bool = True @@ -725,10 +724,6 @@ def get_vllm_port() -> Optional[int]: lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[ "VLLM_DISABLED_KERNELS"].split(","), - # If set, use the V1 code path. - "VLLM_USE_V1": - lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), - # Disable aiter ops unless specifically enabled. # Acts as a parent switch to enable the rest of the other operations. "VLLM_ROCM_USE_AITER": @@ -1180,15 +1175,6 @@ def is_set(name: str): raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -def set_vllm_use_v1(use_v1: bool): - if is_set("VLLM_USE_V1"): - raise ValueError( - "Should not call set_vllm_use_v1() if VLLM_USE_V1 is set " - "explicitly by the user. Please raise this as a Github " - "Issue and explicitly set VLLM_USE_V1=0 or 1.") - os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0" - - def compute_hash() -> str: """ WARNING: Whenever a new key is added to this environment diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py index 0f47ed60fb..c19d673526 100644 --- a/tpu_inference/platforms/tpu_platform.py +++ b/tpu_inference/platforms/tpu_platform.py @@ -88,7 +88,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @classmethod def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return not vllm_envs.VLLM_USE_V1 + return False @classmethod def get_punica_wrapper(cls) -> str: @@ -119,8 +119,6 @@ def _initialize_sharding_config(cls, vllm_config: VllmConfig) -> None: @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - if not vllm_envs.VLLM_USE_V1: - raise RuntimeError("VLLM_USE_V1=1 must be set for JAX backend.") if vllm_envs.VLLM_TPU_USING_PATHWAYS: assert not vllm_envs.VLLM_ENABLE_V1_MULTIPROCESSING, ( @@ -165,22 +163,19 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: vllm_config.model_config.dtype = j2t_dtype( vllm_config.model_config.dtype.dtype) - if vllm_envs.VLLM_USE_V1: - # TODO(cuiq): remove this dependency. - from vllm.v1.attention.backends.pallas import \ - PallasAttentionBackend - cache_config.block_size = PallasAttentionBackend.get_page_size( - vllm_config) # type: ignore[assignment] - min_page_size = PallasAttentionBackend.get_min_page_size( - vllm_config) - if min_page_size > cache_config.block_size: - logger.warning( - "Increase the page size from %s to %s to make sure there's" - "no SMEM OOM", - cache_config.block_size, - min_page_size, - ) - cache_config.block_size = min_page_size # type: ignore[assignment] + # TODO(cuiq): remove this dependency. + from vllm.v1.attention.backends.pallas import PallasAttentionBackend + cache_config.block_size = PallasAttentionBackend.get_page_size( + vllm_config) # type: ignore[assignment] + min_page_size = PallasAttentionBackend.get_min_page_size(vllm_config) + if min_page_size > cache_config.block_size: + logger.warning( + "Increase the page size from %s to %s to make sure there's" + "no SMEM OOM", + cache_config.block_size, + min_page_size, + ) + cache_config.block_size = min_page_size # type: ignore[assignment] parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config @@ -251,9 +246,6 @@ def validate_request( """Raises if this request is unsupported on this platform""" if isinstance(params, SamplingParams): - if params.structured_outputs is not None and not vllm_envs.VLLM_USE_V1: - raise ValueError("Structured output is not supported on " - f"{cls.device_name} V0.") if params.sampling_type == SamplingType.RANDOM_SEED: raise ValueError("JAX does not support per-request seed.")