5 changes: 3 additions & 2 deletions tests/v1/determinism/conftest.py
@@ -1,11 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

import vllm.model_executor.layers.batch_invariant as batch_invariant


@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
"""Automatically enable batch invariant kernel overrides for all tests."""
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
yield
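As a hedged sketch of how an individual test could opt back out of this autouse fixture (the module path and flag name come from the diff above; the test itself is illustrative and not part of this PR):

```python
import pytest

import vllm.model_executor.layers.batch_invariant as batch_invariant


def test_runs_without_batch_invariance(monkeypatch: pytest.MonkeyPatch) -> None:
    # The autouse fixture flipped the module-level flag to True for every test;
    # re-patching the attribute (and env var) here overrides it for this test only.
    monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
    assert batch_invariant.vllm_is_batch_invariant() is False
```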
35 changes: 9 additions & 26 deletions tests/v1/determinism/test_batch_invariance.py
@@ -6,29 +6,16 @@

import pytest
import torch
from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
from utils import (
    BACKENDS,
    _extract_step_logprobs,
    _random_prompt,
    resolve_model_name,
    skip_unsupported,
)

import vllm.model_executor.layers.batch_invariant as batch_invariant
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform

BACKENDS: list[str] = [
"FLASH_ATTN",
"FLASHINFER",
]

if current_platform.is_cuda() and current_platform.is_device_capability(90):
    BACKENDS.append("FLASH_ATTN_MLA")

DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"


def resolve_model_name(backend: str) -> str:
"""Resolve the model name for the given backend, respecting env overrides."""
model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
if backend.endswith("MLA") and model == DEFAULT_MODEL:
return MLA_MODEL
return model


@skip_unsupported
@@ -454,14 +441,10 @@ def test_logprobs_without_batch_invariance_should_fail(
    The test will PASS if we detect differences (proving batch invariance matters).
    The test will FAIL if everything matches (suggesting batch invariance isn't needed).
    """
    from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant

    vllm_is_batch_invariant.cache_clear()
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)

    # CRITICAL: Disable batch invariance for this test
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")

    monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)
    model_name = resolve_model_name(backend)
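For readers unfamiliar with the assertion style these determinism tests rely on, here is a minimal sketch of a bitwise per-step logprob comparison (the helper name and tensor layout are assumptions for illustration, not the tests' actual helpers):

```python
import torch


def logprobs_bitwise_equal(run_a: list[torch.Tensor], run_b: list[torch.Tensor]) -> bool:
    # Bitwise determinism means every decode step must match exactly;
    # a single differing bit in any logprob tensor counts as a mismatch.
    if len(run_a) != len(run_b):
        return False
    return all(torch.equal(a, b) for a, b in zip(run_a, run_b))
```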
12 changes: 9 additions & 3 deletions tests/v1/determinism/test_online_batch_invariance.py
@@ -16,7 +16,8 @@
from typing import Any

import openai
from utils import _random_prompt, skip_unsupported
import pytest
from utils import BACKENDS, _random_prompt, resolve_model_name, skip_unsupported

from tests.utils import RemoteOpenAIServer

@@ -133,9 +134,14 @@ def _compare_bs1_vs_bsn_single_process(


@skip_unsupported
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN():
@pytest.mark.parametrize("backend", BACKENDS)
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
    backend: str, monkeypatch: pytest.MonkeyPatch
) -> None:
    random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    # Override backend for this test (and the RemoteOpenAIServer child process).
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
    model_name = resolve_model_name(backend)
    prompts_all = [_random_prompt(10, 50) for _ in range(32)]

    sp_kwargs: dict[str, Any] = {
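The comment in the new test relies on environment inheritance: monkeypatch.setenv mutates os.environ of the test process, and a child process spawned without an explicit env= (as the server helper appears to do here) inherits that value. A minimal standalone sketch of the mechanism, not the RemoteOpenAIServer API:

```python
import subprocess
import sys

import pytest


def test_child_process_inherits_backend_env(monkeypatch: pytest.MonkeyPatch) -> None:
    # setenv patches os.environ of the current (parent) process...
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
    # ...so a subprocess launched without env= sees the same value.
    result = subprocess.run(
        [sys.executable, "-c",
         "import os; print(os.environ['VLLM_ATTENTION_BACKEND'])"],
        capture_output=True,
        text=True,
        check=True,
    )
    assert result.stdout.strip() == "FLASHINFER"
```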
20 changes: 20 additions & 0 deletions tests/v1/determinism/utils.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import random

import pytest
@@ -12,6 +13,25 @@
reason="Requires CUDA and >= Hopper (SM90)",
)

BACKENDS: list[str] = [
"FLASH_ATTN",
"FLASHINFER",
]

if current_platform.is_cuda() and current_platform.is_device_capability(90):
    BACKENDS.append("FLASH_ATTN_MLA")

DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"


def resolve_model_name(backend: str) -> str:
"""Resolve the model name for the given backend."""
model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
if backend.endswith("MLA") and model == DEFAULT_MODEL:
return MLA_MODEL
return model


def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
    # Generate more realistic prompts that will actually produce varied tokens
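A hedged illustration of the resolution rules encoded in resolve_model_name above (the override model name is a placeholder):

```python
import os

from utils import DEFAULT_MODEL, MLA_MODEL, resolve_model_name

# With no override, non-MLA backends use the default model and MLA backends
# fall back to the MLA-capable model.
os.environ.pop("VLLM_TEST_MODEL", None)
assert resolve_model_name("FLASH_ATTN") == DEFAULT_MODEL
assert resolve_model_name("FLASH_ATTN_MLA") == MLA_MODEL

# An explicit VLLM_TEST_MODEL override wins for every backend, MLA included.
os.environ["VLLM_TEST_MODEL"] = "my-org/placeholder-model"
assert resolve_model_name("FLASH_ATTN_MLA") == "my-org/placeholder-model"
```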
20 changes: 11 additions & 9 deletions vllm/model_executor/layers/batch_invariant.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Callable
from functools import cache
from typing import Any

import torch
@@ -785,16 +784,19 @@ def enable_batch_invariant_mode():
    torch.backends.cuda.preferred_blas_library(backend="cublaslt")


@cache
def vllm_is_batch_invariant():
env_key = "VLLM_BATCH_INVARIANT"
is_overridden = False
val = os.getenv(env_key, "0")
def _read_vllm_batch_invariant() -> bool:
val = os.getenv("VLLM_BATCH_INVARIANT", "0")
try:
is_overridden = int(val) != 0
return int(val) != 0
except ValueError:
is_overridden = False
return is_overridden
return False


VLLM_BATCH_INVARIANT: bool = _read_vllm_batch_invariant()


def vllm_is_batch_invariant() -> bool:
    return VLLM_BATCH_INVARIANT


def override_envs_for_invariance():
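A short sketch of the flag-parsing semantics introduced above, mirroring _read_vllm_batch_invariant (a standalone copy for illustration; only non-zero integer values enable the mode, and unparsable strings fall back to disabled):

```python
def _parse_batch_invariant(val: str) -> bool:
    # Mirrors _read_vllm_batch_invariant: any non-zero integer enables the mode.
    try:
        return int(val) != 0
    except ValueError:
        # Unparsable values such as "true" or "on" silently disable it.
        return False


assert _parse_batch_invariant("1") is True
assert _parse_batch_invariant("0") is False
assert _parse_batch_invariant("true") is False
```

Because VLLM_BATCH_INVARIANT is now bound once at import time, later changes to the environment variable do not affect an already-imported module, which is why the tests patch the module attribute rather than relying on the env var alone.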