Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .buildkite/scripts/run_in_docker.sh
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ exec docker run \
-e MODEL_IMPL_TYPE="$MODEL_IMPL_TYPE" \
-e HF_TOKEN="$HF_TOKEN" \
-e VLLM_XLA_CACHE_PATH="$DOCKER_HF_HOME/.cache/jax_cache" \
-e VLLM_USE_V1=1 \
-e VLLM_XLA_CHECK_RECOMPILATION=1 \
${QUANTIZATION:+-e QUANTIZATION="$QUANTIZATION"} \
${NEW_MODEL_DESIGN:+-e NEW_MODEL_DESIGN="$NEW_MODEL_DESIGN"} \
Expand Down
8 changes: 2 additions & 6 deletions scripts/vllm/integration/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch,
elif tp_size < 1 or tp_size > 8:
raise ValueError

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

with monkeypatch.context() as _:
more_args = None
if current_platform.is_tpu():
more_args = "max_model_len=2048,max_num_seqs=64"
Expand Down Expand Up @@ -104,9 +102,7 @@ def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
elif tp_size < 1 or tp_size > 8:
raise ValueError

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

with monkeypatch.context() as _:
more_args = None
if current_platform.is_tpu():
more_args = "max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8"
Expand Down
1 change: 0 additions & 1 deletion tests/executors/test_ray_distributed_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def test_init_executor_basic_flow(self, mock_wait_until_pg_ready,
# --- Setup mocks ---
mock_envs.VLLM_USE_RAY_COMPILED_DAG = True
mock_envs.VLLM_USE_RAY_SPMD_WORKER = True
mock_envs.VLLM_USE_V1 = True
mock_envs.VLLM_RAY_BUNDLE_INDICES = ""

mock_platform.ray_device_key = "TPU"
Expand Down
1 change: 0 additions & 1 deletion tpu_inference/executors/ray_distributed_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def _init_executor(self) -> None:
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=_encode_hook)
self.output_decoder = msgspec.msgpack.Decoder(
Optional[List[SamplerOutput]])
self.use_v1 = envs.VLLM_USE_V1

self.pp_locks: Optional[List[asyncio.Lock]] = None

Expand Down
14 changes: 0 additions & 14 deletions tpu_inference/mock/vllm_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
Expand Down Expand Up @@ -725,10 +724,6 @@ def get_vllm_port() -> Optional[int]:
lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
"VLLM_DISABLED_KERNELS"].split(","),

# If set, use the V1 code path.
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),

# Disable aiter ops unless specifically enabled.
# Acts as a parent switch to enable the rest of the other operations.
"VLLM_ROCM_USE_AITER":
Expand Down Expand Up @@ -1180,15 +1175,6 @@ def is_set(name: str):
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def set_vllm_use_v1(use_v1: bool):
if is_set("VLLM_USE_V1"):
raise ValueError(
"Should not call set_vllm_use_v1() if VLLM_USE_V1 is set "
"explicitly by the user. Please raise this as a Github "
"Issue and explicitly set VLLM_USE_V1=0 or 1.")
os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0"


def compute_hash() -> str:
"""
WARNING: Whenever a new key is added to this environment
Expand Down
36 changes: 14 additions & 22 deletions tpu_inference/platforms/tpu_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:

@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return not vllm_envs.VLLM_USE_V1
return False

@classmethod
def get_punica_wrapper(cls) -> str:
Expand Down Expand Up @@ -119,8 +119,6 @@ def _initialize_sharding_config(cls, vllm_config: VllmConfig) -> None:

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if not vllm_envs.VLLM_USE_V1:
raise RuntimeError("VLLM_USE_V1=1 must be set for JAX backend.")

if vllm_envs.VLLM_TPU_USING_PATHWAYS:
assert not vllm_envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
Expand Down Expand Up @@ -165,22 +163,19 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
vllm_config.model_config.dtype = j2t_dtype(
vllm_config.model_config.dtype.dtype)

if vllm_envs.VLLM_USE_V1:
# TODO(cuiq): remove this dependency.
from vllm.v1.attention.backends.pallas import \
PallasAttentionBackend
cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config) # type: ignore[assignment]
min_page_size = PallasAttentionBackend.get_min_page_size(
vllm_config)
if min_page_size > cache_config.block_size:
logger.warning(
"Increase the page size from %s to %s to make sure there's"
"no SMEM OOM",
cache_config.block_size,
min_page_size,
)
cache_config.block_size = min_page_size # type: ignore[assignment]
# TODO(cuiq): remove this dependency.
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config) # type: ignore[assignment]
min_page_size = PallasAttentionBackend.get_min_page_size(vllm_config)
if min_page_size > cache_config.block_size:
logger.warning(
"Increase the page size from %s to %s to make sure there's"
"no SMEM OOM",
cache_config.block_size,
min_page_size,
)
cache_config.block_size = min_page_size # type: ignore[assignment]

parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
Expand Down Expand Up @@ -251,9 +246,6 @@ def validate_request(
"""Raises if this request is unsupported on this platform"""

if isinstance(params, SamplingParams):
if params.structured_outputs is not None and not vllm_envs.VLLM_USE_V1:
raise ValueError("Structured output is not supported on "
f"{cls.device_name} V0.")
if params.sampling_type == SamplingType.RANDOM_SEED:
raise ValueError("JAX does not support per-request seed.")

Expand Down