Merged
Commits (32)
3d3c556
fix batch invariant warning
WindChimeRan Feb 21, 2026
0a0ea10
add pre-commit
WindChimeRan Feb 21, 2026
f42bf0e
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Feb 21, 2026
aa879f6
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Feb 22, 2026
6e3d53f
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Feb 22, 2026
899d1e8
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Feb 24, 2026
98d2996
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Feb 27, 2026
bb3e4b2
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Mar 1, 2026
48eeee6
update all env var consumers
WindChimeRan Mar 1, 2026
4801f60
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Mar 1, 2026
0a26ebc
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Mar 2, 2026
09c4dc9
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Mar 2, 2026
aff083e
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Mar 2, 2026
e0f34c8
fix tests for env var
WindChimeRan Mar 3, 2026
4da8190
Merge remote-tracking branch 'fork/fix/register-batch-invariant-env' …
WindChimeRan Mar 3, 2026
f86edc3
Merge remote-tracking branch 'origin/main' into fix/register-batch-in…
WindChimeRan Mar 3, 2026
ddb04a1
Merge remote-tracking branch 'origin/main' into fix/register-batch-in…
WindChimeRan Mar 4, 2026
b7454e9
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Mar 4, 2026
aead361
fix legacy kernel test
WindChimeRan Mar 4, 2026
b688ae3
Merge remote-tracking branch 'fork/fix/register-batch-invariant-env' …
WindChimeRan Mar 4, 2026
463a60f
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Mar 5, 2026
8b59cb2
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Mar 6, 2026
c4ac7c3
Merge branch 'main' into fix/register-batch-invariant-env
yewentao256 Mar 6, 2026
152853d
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Mar 7, 2026
b6b9efa
fix fa_utils
WindChimeRan Mar 7, 2026
a331e79
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Mar 7, 2026
51aa06a
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Mar 8, 2026
95670b3
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Mar 10, 2026
b4b8d85
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Mar 21, 2026
0fdc21e
Merge branch 'main' into fix/register-batch-invariant-env
WindChimeRan Mar 22, 2026
e9da71a
fix fp8 utils merging conflict
WindChimeRan Mar 22, 2026
0843198
Merge remote-tracking branch 'origin/main' into fix/register-batch-in…
WindChimeRan Mar 23, 2026
16 changes: 8 additions & 8 deletions tests/kernels/attention/test_use_trtllm_attention.py
@@ -55,37 +55,37 @@ def _clear_supports_cache():
 # supports_trtllm_attention


-@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=True)
-def test_supports_batch_invariant_disables(_mock):
+@patch("vllm.envs.VLLM_BATCH_INVARIANT", True)
+def test_supports_batch_invariant_disables():
     assert supports_trtllm_attention() is False


-@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
+@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
 @patch(
     "vllm.utils.flashinfer.current_platform.is_device_capability_family",
     return_value=True,
 )
 @patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=True)
-def test_supports_sm100_with_artifactory(_art, _cap, _bi):
+def test_supports_sm100_with_artifactory(_art, _cap):
     assert supports_trtllm_attention() is True


-@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
+@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
 @patch(
     "vllm.utils.flashinfer.current_platform.is_device_capability_family",
     return_value=False,
 )
-def test_supports_non_sm100_platform(_cap, _bi):
+def test_supports_non_sm100_platform(_cap):
     assert supports_trtllm_attention() is False


-@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
+@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
 @patch(
     "vllm.utils.flashinfer.current_platform.is_device_capability_family",
     return_value=True,
 )
 @patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=False)
-def test_supports_sm100_without_artifactory(_art, _cap, _bi):
+def test_supports_sm100_without_artifactory(_art, _cap):
     assert supports_trtllm_attention() is False


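When `patch` is given an explicit replacement value (here `True` or `False`), it swaps the attribute in place and does not inject a mock object into the decorated test, which is why the `_mock` and `_bi` parameters disappear above. A minimal, self-contained sketch of that behavior; `fake_envs` and `supports_feature` are hypothetical names used only for illustration:

```python
from types import SimpleNamespace
from unittest.mock import patch

# Hypothetical stand-ins for vllm.envs and a feature gate that reads it.
fake_envs = SimpleNamespace(VLLM_BATCH_INVARIANT=False)


def supports_feature() -> bool:
    # Reads the module-level flag on every call.
    return not fake_envs.VLLM_BATCH_INVARIANT


# Patching with an explicit value replaces the attribute for the block's
# duration and passes no mock argument anywhere.
with patch.object(fake_envs, "VLLM_BATCH_INVARIANT", True):
    assert supports_feature() is False

# The original value is restored on exit.
assert supports_feature() is True
```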
4 changes: 2 additions & 2 deletions tests/kernels/moe/test_grouped_topk.py
@@ -8,7 +8,7 @@
 import pytest
 import torch

-import vllm.model_executor.layers.batch_invariant as batch_invariant
+import vllm.envs as envs
 from vllm.config import (
     CompilationConfig,
     VllmConfig,
@@ -69,7 +69,7 @@ def test_grouped_topk(

     with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
         m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
-        m.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
+        m.setattr(envs, "VLLM_BATCH_INVARIANT", True)
         grouped_topk = GroupedTopk(
             topk=topk,
             renormalize=renormalize,
4 changes: 2 additions & 2 deletions tests/v1/determinism/conftest.py
@@ -2,11 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest

-import vllm.model_executor.layers.batch_invariant as batch_invariant
+import vllm.envs as envs


 @pytest.fixture(autouse=True)
 def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
     """Automatically enable batch invariant kernel overrides for all tests."""
-    monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
+    monkeypatch.setattr(envs, "VLLM_BATCH_INVARIANT", True)
     monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
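The fixture patches both the already-imported attribute and the process environment: the attribute covers in-process readers of `envs.VLLM_BATCH_INVARIANT`, while the environment variable covers any code path (including spawned workers) that re-reads `os.environ`. A small standalone sketch of the same belt-and-braces pattern, using a hypothetical `my_envs` module:

```python
import os
from types import SimpleNamespace

import pytest

# Hypothetical module whose flag mirrors an environment variable.
my_envs = SimpleNamespace(MY_FLAG=bool(int(os.getenv("MY_FLAG", "0"))))


@pytest.fixture
def enable_my_flag(monkeypatch: pytest.MonkeyPatch):
    # Patch the already-resolved attribute for in-process readers...
    monkeypatch.setattr(my_envs, "MY_FLAG", True)
    # ...and the environment for anything that re-reads os.environ.
    monkeypatch.setenv("MY_FLAG", "1")


def test_flag_enabled(enable_my_flag):
    assert my_envs.MY_FLAG is True
    assert os.environ["MY_FLAG"] == "1"
```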
16 changes: 6 additions & 10 deletions tests/v1/determinism/test_batch_invariance.py
@@ -15,7 +15,7 @@
     skip_unsupported,
 )

-import vllm.model_executor.layers.batch_invariant as batch_invariant
+import vllm.envs as envs
 from vllm import LLM, SamplingParams

 IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
@@ -173,11 +173,9 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(

     # For batch invariance, disable custom all-reduce to ensure deterministic
     # all-reduce operations (custom all-reduce may not be deterministic)
-    from vllm.model_executor.layers.batch_invariant import (
-        vllm_is_batch_invariant,
-    )
+    import vllm.envs as envs

-    disable_custom_ar = vllm_is_batch_invariant()
+    disable_custom_ar = envs.VLLM_BATCH_INVARIANT

     if disable_custom_ar:
         print(f"\n{'=' * 80}")
@@ -454,7 +452,7 @@ def test_logprobs_without_batch_invariance_should_fail(
     """
     # CRITICAL: Disable batch invariance for this test
     monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
-    monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
+    monkeypatch.setattr(envs, "VLLM_BATCH_INVARIANT", False)
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
@@ -674,11 +672,9 @@ def test_decode_logprobs_match_prefill_logprobs(
     random.seed(seed)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

-    from vllm.model_executor.layers.batch_invariant import (
-        vllm_is_batch_invariant,
-    )
+    import vllm.envs as envs

-    disable_custom_ar = vllm_is_batch_invariant()
+    disable_custom_ar = envs.VLLM_BATCH_INVARIANT

     if disable_custom_ar:
         print(f"\n{'=' * 80}")
5 changes: 1 addition & 4 deletions vllm/config/parallel.py
@@ -14,9 +14,6 @@
 import vllm.envs as envs
 from vllm.config.utils import config
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms import current_platform
 from vllm.utils.network_utils import get_open_ports_list
 from vllm.utils.torch_utils import cuda_device_count_stateless
@@ -786,7 +783,7 @@ def _verify_args(self) -> Self:
         from vllm.v1.executor import Executor

         # Enable batch invariance settings if requested
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             self.disable_custom_all_reduce = True

         if (
4 changes: 1 addition & 3 deletions vllm/config/vllm.py
@@ -1112,11 +1112,9 @@ def has_blocked_weights():
                 "when cudagraph_mode piecewise cudagraphs is used, "
                 f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
             )
-        from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
-
         if (
             self.model_config
-            and vllm_is_batch_invariant()
+            and envs.VLLM_BATCH_INVARIANT
             and not self.model_config.disable_cascade_attn
         ):
             self.model_config.disable_cascade_attn = True
5 changes: 1 addition & 4 deletions vllm/distributed/device_communicators/all_reduce_utils.py
@@ -19,9 +19,6 @@
 import vllm.envs as envs
 from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.utils.system_utils import update_environment_variables
 from vllm.utils.torch_utils import cuda_device_count_stateless

@@ -115,7 +112,7 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
         is_symmetric_memory_enabled,
     )

-    if vllm_is_batch_invariant():
+    if envs.VLLM_BATCH_INVARIANT:
         return False

     if not is_symmetric_memory_enabled():
6 changes: 2 additions & 4 deletions vllm/distributed/device_communicators/symm_mem.py
@@ -5,13 +5,11 @@
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

+import vllm.envs as envs
 from vllm.distributed.device_communicators.all_reduce_utils import (
     SYMM_MEM_ALL_REDUCE_MAX_SIZES,
 )
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms import current_platform

 try:
@@ -112,7 +110,7 @@ def __init__(
             return
         self.force_multimem = force_multimem
         self.disabled = False
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             self.disabled = True

     def should_use_symm_mem(self, inp: torch.Tensor):
9 changes: 5 additions & 4 deletions vllm/envs.py
@@ -74,6 +74,7 @@
     VLLM_TARGET_DEVICE: str = "cuda"
     VLLM_MAIN_CUDA_VERSION: str = "12.9"
     VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
+    VLLM_BATCH_INVARIANT: bool = False
     MAX_JOBS: str | None = None
     NVCC_THREADS: str | None = None
     VLLM_USE_PRECOMPILED: bool = False
@@ -280,9 +281,6 @@ def disable_compile_cache() -> bool:


 def use_aot_compile() -> bool:
-    from vllm.model_executor.layers.batch_invariant import (
-        vllm_is_batch_invariant,
-    )
     from vllm.utils.torch_utils import is_torch_equal_or_newer

     default_value = (
@@ -292,7 +290,7 @@ def use_aot_compile() -> bool:
     )

     return (
-        not vllm_is_batch_invariant()
+        not bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0")))
         and os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
     )

@@ -498,6 +496,9 @@ def _get_or_set_default() -> str:
         ["highest", "high", "medium"],
         case_sensitive=False,
     ),
+    # Enable batch-invariant mode: deterministic results regardless of
+    # batch composition. Requires NVIDIA GPU with compute capability >= 9.0.
+    "VLLM_BATCH_INVARIANT": lambda: bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0"))),
     # Maximum number of compilation jobs to run in parallel.
     # By default this is the number of CPUs
     "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
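For context on why `envs.VLLM_BATCH_INVARIANT` works as a drop-in replacement for the deleted helper: `vllm/envs.py` keeps a registry of name-to-lambda entries and resolves attribute access through a module-level `__getattr__`, so the value is read from the environment on access rather than frozen into an import-time constant. A simplified sketch of that registry pattern (heavily reduced, not the full vLLM implementation):

```python
import os

# name -> thunk that reads the environment when the attribute is accessed
environment_variables = {
    "VLLM_BATCH_INVARIANT": lambda: bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0"))),
}


def __getattr__(name: str):
    # PEP 562 module __getattr__: called for names not found in the module
    # namespace, e.g. `import vllm.envs as envs; envs.VLLM_BATCH_INVARIANT`.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```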
3 changes: 1 addition & 2 deletions vllm/lora/ops/triton_ops/utils.py
@@ -11,12 +11,11 @@

 from vllm import envs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import next_power_of_2

 logger = init_logger(__name__)
-is_batch_invariant = vllm_is_batch_invariant()
+is_batch_invariant = envs.VLLM_BATCH_INVARIANT

 _LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
 _LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
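One behavioral nuance of this hunk: `is_batch_invariant` is assigned at module import, so it snapshots `envs.VLLM_BATCH_INVARIANT` once; flipping the environment variable after this module has been imported will not be reflected here (the same was true of the old call, so behavior is unchanged). A tiny illustration of the difference between import-time capture and per-call reads (names are hypothetical):

```python
import os

# Captured once, when this module/script is first executed.
IS_BATCH_INVARIANT_SNAPSHOT = bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0")))


def is_batch_invariant_now() -> bool:
    # Re-reads the environment on every call instead.
    return bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0")))


os.environ["VLLM_BATCH_INVARIANT"] = "1"
assert is_batch_invariant_now() is True
# IS_BATCH_INVARIANT_SNAPSHOT still holds whatever was true at import time.
```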
3 changes: 1 addition & 2 deletions vllm/model_executor/kernels/linear/scaled_mm/marlin.py
@@ -6,7 +6,6 @@
 import torch

 import vllm.envs as envs
-from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     process_fp8_weight_block_strategy,
 )
@@ -42,7 +41,7 @@ def is_supported(
         # Check if platform supports FP8 Marlin
         if not is_fp8_marlin_supported():
             return False, "FP8 Marlin requires compute capability 7.5 or higher"
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             return False, "FP8 Marlin not supported for batch invariant execution."
         if (
             compute_capability is not None
3 changes: 1 addition & 2 deletions vllm/model_executor/layers/attention/attention.py
@@ -15,7 +15,6 @@
     maybe_transfer_kv_layer,
 )
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.model_executor.layers.linear import (
     UnquantizedLinearMethod,
 )
@@ -296,7 +295,7 @@ def __init__(
         if (
             cache_config is not None
             and cache_config.enable_prefix_caching
-            and vllm_is_batch_invariant()
+            and envs.VLLM_BATCH_INVARIANT
             and (
                 self.attn_backend.get_name() == "FLASHINFER"
                 or self.attn_backend.get_name() == "TRITON_MLA"
5 changes: 2 additions & 3 deletions vllm/model_executor/layers/attention/mla_attention.py
@@ -227,7 +227,6 @@
     maybe_transfer_kv_layer,
 )
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
 )
@@ -372,7 +371,7 @@ def __init__(
         if (
             cache_config is not None
             and cache_config.enable_prefix_caching
-            and vllm_is_batch_invariant()
+            and envs.VLLM_BATCH_INVARIANT
             and (
                 self.attn_backend.get_name() == "TRITON_MLA"
                 or self.attn_backend.get_name() == "FLASHINFER"
@@ -2188,7 +2187,7 @@ def _flash_attn_varlen_diff_headdims(
             # ROCm leverages the upstream flash_attn, which takes a parameter
             # called "return_attn_probs" instead of return_softmax_lse
             kwargs["return_attn_probs"] = return_softmax_lse
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             kwargs["num_splits"] = 1

         attn_out = self.flash_attn_varlen_func(
18 changes: 2 additions & 16 deletions vllm/model_executor/layers/batch_invariant.py
@@ -6,6 +6,7 @@

 import torch

+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
@@ -986,21 +987,6 @@ def enable_batch_invariant_mode():
     torch.backends.cuda.preferred_blas_library(backend="cublaslt")


-def _read_vllm_batch_invariant() -> bool:
-    val = os.getenv("VLLM_BATCH_INVARIANT", "0")
-    try:
-        return int(val) != 0
-    except ValueError:
-        return False
-
-
-VLLM_BATCH_INVARIANT: bool = _read_vllm_batch_invariant()
-
-
-def vllm_is_batch_invariant() -> bool:
-    return VLLM_BATCH_INVARIANT
-
-
 def override_envs_for_invariance(
     attention_backend: AttentionBackendEnum | None,
 ):
@@ -1059,7 +1045,7 @@ def init_batch_invariance(
     attention_backend: AttentionBackendEnum | None,
 ):
     # this will hit all the csrc overrides as well
-    if vllm_is_batch_invariant():
+    if envs.VLLM_BATCH_INVARIANT:
         override_envs_for_invariance(attention_backend)
         enable_batch_invariant_mode()

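With the flag now registered in `vllm.envs`, enabling batch-invariant mode is purely environment-driven and `init_batch_invariance` picks it up during engine initialization. A hedged end-to-end usage sketch (the model name and sampling settings are placeholders, not part of this PR):

```python
import os

# Must be set before vLLM resolves the flag and initializes the engine.
os.environ["VLLM_BATCH_INVARIANT"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(temperature=0.0, max_tokens=16)

# With batch-invariant mode on, a prompt's outputs/logprobs should not depend
# on which other prompts happen to share the batch.
outputs = llm.generate(["Hello, world"], params)
print(outputs[0].outputs[0].text)
```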
7 changes: 2 additions & 5 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -14,9 +14,6 @@
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.model_executor.layers.fused_moe.activation import (
     MoEActivation,
     apply_moe_activation,
@@ -1051,7 +1048,7 @@ def get_moe_configs(
     """

     # Avoid optimizing for the batch invariant case. Use default config
-    if vllm_is_batch_invariant():
+    if envs.VLLM_BATCH_INVARIANT:
         return None

     # First look up if an optimized configuration is available in the configs
@@ -1232,7 +1229,7 @@ def get_default_config(
     dtype: str | None,
     block_shape: list[int] | None = None,
 ) -> dict[str, int]:
-    if vllm_is_batch_invariant():
+    if envs.VLLM_BATCH_INVARIANT:
         return {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": 64,
@@ -6,11 +6,9 @@
 import torch

 import vllm._custom_ops as ops
+import vllm.envs as envs
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.distributed.eplb.eplb_state import EplbLayerState
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.model_executor.layers.fused_moe.config import (
     RoutingMethodType,
     get_routing_method_type,
@@ -160,7 +158,7 @@ def fused_topk_bias(
     ) + e_score_correction_bias.unsqueeze(0)

     # For batch invariance, use sorted=True to ensure deterministic expert selection
-    use_sorted = vllm_is_batch_invariant()
+    use_sorted = envs.VLLM_BATCH_INVARIANT
     topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
     topk_weights = scores.gather(1, topk_indices)
     if renormalize:
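The comment in the hunk above is the substantive point: `torch.topk(..., sorted=False)` makes no guarantee about the order of the returned elements, so the downstream `gather` can see a different expert order depending on the kernel path, while `sorted=True` pins a descending-score order. A tiny illustration with arbitrary values:

```python
import torch

scores = torch.tensor([[0.30, 0.90, 0.85, 0.10]])

# sorted=True returns the top-k in descending-score order, giving a
# reproducible index order; with sorted=False the element order is
# unspecified and may vary between kernels or batch shapes.
topk_indices = torch.topk(scores, k=2, dim=-1, sorted=True)[1]
topk_weights = scores.gather(1, topk_indices)
print(topk_indices)  # tensor([[1, 2]])
print(topk_weights)  # tensor([[0.9000, 0.8500]])
```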