136 changes: 136 additions & 0 deletions docs/source/user_guide/feature_guide/batch_invariance.md
@@ -0,0 +1,136 @@
# Batch Invariance

```{note}
Batch invariance is currently in beta. Some features are still under active development.
Track progress and planned improvements at <https://github.com/vllm-project/vllm-ascend/issues/5487>
```

This document shows how to enable batch invariance in vLLM-Ascend. Batch invariance ensures that the output of a model is deterministic and independent of the batch size or the order of requests in a batch.

## Motivation

Batch invariance is crucial for several use cases:

- **Framework debugging**: Deterministic outputs make it easier to debug issues in the inference framework, as the same input will always produce the same output regardless of batching.
- **Model debugging**: Helps identify issues in model implementations by ensuring consistent behavior across different batch configurations.
- **Reinforcement Learning (RL)**: RL training often requires deterministic rollouts for reproducibility and stable training.
- **Large-scale inference systems**: Systems that use vLLM as a component benefit from deterministic behavior for testing, validation, and consistency guarantees.

## Hardware Requirements

Batch invariance currently requires Ascend 910B NPUs: the 910B is the only series that supports
batch-invariant HCCL communication today. Support for other NPU series is planned for future releases.

## Software Requirements

Batch invariance requires a custom operator library for the 910B.
This library will be released in a future version.

## Enabling Batch Invariance

Batch invariance can be enabled by setting the `VLLM_BATCH_INVARIANT` environment variable to `1`:

```bash
export VLLM_BATCH_INVARIANT=1
```

### Online Inference (Server Mode)

To start a vLLM server with batch invariance enabled:

```bash
VLLM_BATCH_INVARIANT=1 vllm serve Qwen/Qwen3-8B
```

Then use the OpenAI-compatible client:

```python
from openai import OpenAI

client = OpenAI(
    api_key="EMPTY",
    base_url="http://localhost:8000/v1",
)

# These requests will produce deterministic outputs
# regardless of batch size or order
response = client.completions.create(
    model="Qwen/Qwen3-8B",
    prompt="The future of AI is",
    max_tokens=100,
    temperature=0.7,
    seed=42,
)

print(response.choices[0].text)
```
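
A quick way to sanity-check this from the client side is to repeat the same request (ideally while other traffic keeps the server's batching shifting) and compare the completions. The sketch below assumes the server started above is running; the repetition count and the assertion are illustrative, not part of any vLLM API.

```python
from openai import OpenAI

client = OpenAI(
    api_key="EMPTY",
    base_url="http://localhost:8000/v1",
)

def complete_once() -> str:
    # Same prompt, temperature, and seed every time; with batch invariance
    # the returned text should not depend on how the request was batched.
    response = client.completions.create(
        model="Qwen/Qwen3-8B",
        prompt="The future of AI is",
        max_tokens=100,
        temperature=0.7,
        seed=42,
    )
    return response.choices[0].text

texts = [complete_once() for _ in range(3)]
assert all(t == texts[0] for t in texts), "Outputs diverged across requests"
print("All repeated requests returned identical completions.")
```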

### Offline Inference

For offline batch inference with batch invariance:

```python
import os
os.environ["VLLM_BATCH_INVARIANT"] = "1"

from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
    "Machine learning enables",
    "Deep learning models can",
]

sampling_params = SamplingParams(
    temperature=0.7,
    max_tokens=100,
    seed=42,
)

llm = LLM(
    model="Qwen/Qwen3-8B",
    tensor_parallel_size=1,
)

# Outputs will be deterministic regardless of batch size
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}")
    print(f"Generated: {generated_text!r}\n")
```
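
Continuing the example above, one way to check batch invariance directly is to generate the same prompt on its own and again as part of a larger batch, then compare the two outputs. This is a minimal sketch; the prompts and the assertion are illustrative.

```python
# Minimal sketch: with VLLM_BATCH_INVARIANT=1, the same prompt generated
# alone and inside a larger batch should produce identical text.
target = "The future of AI is"

solo_text = llm.generate([target], sampling_params)[0].outputs[0].text
batched_text = llm.generate(
    [target, "Machine learning enables", "Deep learning models can"],
    sampling_params,
)[0].outputs[0].text

assert solo_text == batched_text, "Batch invariance violated"
print("Solo and batched generations match.")
```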

## Tested Models

Batch invariance has been tested and verified on the following models:

- **Qwen3 (Dense)**: `Qwen/Qwen3-1.7B`, `Qwen/Qwen3-8B`
- **Qwen3 (MoE)**: `Qwen/Qwen3-30B-A3B`

Other models may also work, but these have been explicitly validated. If you encounter issues with a specific model, please report them on the [GitHub issue tracker](https://github.com/vllm-project/vllm-ascend/issues/new/choose).

## Implementation Details

When batch invariance is enabled, vLLM:

1. Uses deterministic kernel implementations for attention and other operations
2. Ensures consistent numerical behavior across different batch sizes
3. Disables certain optimizations that may introduce non-determinism

```{note}
Enabling batch invariance may impact performance compared to the default non-deterministic mode. This trade-off is intentional to guarantee reproducibility.
```
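
For illustration, the override mechanism in `vllm_ascend.batch_invariant` registers batch-invariant kernels over the stock ATen operators for the NPU dispatch key. The sketch below is a simplified, hypothetical version of that pattern; the placeholder kernel body and the `_lib` name are not the actual implementation.

```python
import os

import torch


def mm_batch_invariant(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Placeholder body: a real kernel performs the matmul with a fixed
    # tiling and reduction order so the result does not depend on the
    # shape of the surrounding batch.
    raise NotImplementedError("deterministic NPU matmul kernel goes here")


if os.environ.get("VLLM_BATCH_INVARIANT") == "1":
    # Route aten::mm on NPU tensors to the batch-invariant implementation.
    _lib = torch.library.Library("aten", "IMPL")
    _lib.impl("aten::mm", mm_batch_invariant, "NPU")
```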

## Future Improvements

The batch invariance feature is under active development. Planned improvements include:

- Support for additional NPU series
- Expanded model coverage
- Performance optimizations
- Additional testing and validation

For the latest status and to contribute ideas, see the [tracking issue](https://github.com/vllm-project/vllm-ascend/issues/5487).
1 change: 1 addition & 0 deletions docs/source/user_guide/feature_guide/index.md
@@ -25,4 +25,5 @@ context_parallel
npugraph_ex
weight_prefetch
sequence_parallelism
batch_invariance
:::
97 changes: 79 additions & 18 deletions tests/ut/batch_invariant/test_batch_invariant.py
@@ -14,13 +14,13 @@
# This file is a part of the vllm-ascend project.
#
# type: ignore
import importlib
import os
import sys
from unittest.mock import MagicMock, patch

import pytest
import torch

# Now import the module under test
import vllm_ascend.batch_invariant as batch_invariant


@@ -43,21 +43,6 @@ def test_override_envs_for_invariance(self):
assert os.environ["HCCL_DETERMINISTIC"] == "strict"
assert os.environ["LCCL_DETERMINISTIC"] == "1"

@pytest.mark.parametrize("custom_ops_available, expected_value", [(True, True), (False, False)])
def test_has_ascendc_batch_invariant(self, custom_ops_available, expected_value):
"""Test HAS_ASCENDC_BATCH_INVARIANT detection"""
# Control custom_ops availability
if custom_ops_available:
sys.modules["batch_invariant_ops"] = MagicMock()
else:
sys.modules.pop("batch_invariant_ops", None)

# Reload module to re-evaluate the flag
importlib.reload(batch_invariant)

# Verify result
assert batch_invariant.HAS_ASCENDC_BATCH_INVARIANT == expected_value

@patch("vllm_ascend.batch_invariant.HAS_TRITON", False)
@patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", True)
def test_enable_batch_invariant_mode_ascendc_path(self):
@@ -105,17 +90,20 @@ def test_enable_batch_invariant_mode_triton_path(self):
batch_invariant.mm_batch_invariant = MagicMock()
batch_invariant.matmul_batch_invariant = MagicMock()
batch_invariant.linear_batch_invariant = MagicMock()
batch_invariant.softmax_batch_invariant = MagicMock()

# Call function
batch_invariant.enable_batch_invariant_mode()

# Verify operator registrations
assert mock_library.impl.call_count == 5
assert mock_library.impl.call_count == 7
mock_library.impl.assert_any_call("aten::addmm", batch_invariant.addmm_batch_invariant, "NPU")
mock_library.impl.assert_any_call("aten::bmm", batch_invariant.bmm_batch_invariant, "NPU")
mock_library.impl.assert_any_call("aten::mm", batch_invariant.mm_batch_invariant, "NPU")
mock_library.impl.assert_any_call("aten::matmul", batch_invariant.matmul_batch_invariant, "NPU")
mock_library.impl.assert_any_call("aten::linear", batch_invariant.linear_batch_invariant, "NPU")
mock_library.impl.assert_any_call("aten::softmax", batch_invariant.softmax_batch_invariant, "NPU")
mock_library.impl.assert_any_call("aten::_softmax", batch_invariant.softmax_batch_invariant, "NPU")

@patch("vllm_ascend.batch_invariant.HAS_TRITON", False)
@patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", False)
@@ -158,6 +146,79 @@ def test_init_batch_invariance(self, batch_invariant_enabled, has_backend, expec
batch_invariant.override_envs_for_invariance.assert_not_called()
batch_invariant.enable_batch_invariant_mode.assert_not_called()

@patch("vllm_ascend.batch_invariant.torch_npu")
def test_add_rms_norm(self, mock_torch_npu):
"""Test add_rms_norm function"""
# Mock dependencies
mock_torch = batch_invariant.torch

# Create mock tensors
batch_size = 2
hidden_size = 4
x = MagicMock(spec=torch.Tensor)
residual = MagicMock(spec=torch.Tensor)
weight = MagicMock(spec=torch.Tensor)
eps = 1e-6

# Set up mock return value for addition
x_plus_residual = MagicMock(spec=torch.Tensor)
x.__add__.return_value = x_plus_residual

# Set up expected outputs from npu_rms_norm
expected_output = MagicMock(spec=torch.Tensor)
expected_residual = MagicMock(spec=torch.Tensor)
mock_torch_npu.npu_rms_norm.return_value = (expected_output, expected_residual)

# Call the function
result_x, result_placeholder, result_residual = batch_invariant.add_rms_norm(x, residual, weight, eps)

# Verify the addition was called
x.__add__.assert_called_once_with(residual)

# Verify the npu_rms_norm was called with the correct parameters
mock_torch_npu.npu_rms_norm.assert_called_once_with(x_plus_residual, weight, eps)

# Verify the results
assert result_x is expected_output
assert result_placeholder is None

@patch("vllm_ascend.batch_invariant.torch_npu")
def test_add_rms_norm_consistency(self, mock_torch_npu):
"""Test that add_rms_norm produces the same output as torch_npu.npu_add_rms_norm"""
# Create mock tensors
batch_size = 2
hidden_size = 4
x = MagicMock(spec=torch.Tensor)
residual = MagicMock(spec=torch.Tensor)
weight = MagicMock(spec=torch.Tensor)
eps = 1e-6

# Set up mock values
x_plus_residual = MagicMock(spec=torch.Tensor)
x.__add__.return_value = x_plus_residual

# Define consistent mock results
expected_output = MagicMock(spec=torch.Tensor)
expected_residual = MagicMock(spec=torch.Tensor)

# Set up mock_npu_rms_norm to return the same results as if it were npu_add_rms_norm
mock_torch_npu.npu_rms_norm.return_value = (expected_output, expected_residual)
mock_torch_npu.npu_add_rms_norm.return_value = (expected_output, None, expected_residual)

# Call add_rms_norm
add_rms_norm_result = batch_invariant.add_rms_norm(x, residual, weight, eps)

# Call npu_add_rms_norm directly
npu_add_rms_norm_result = mock_torch_npu.npu_add_rms_norm(x, residual, weight, eps)

# Verify both functions return the same results
assert add_rms_norm_result[0] == npu_add_rms_norm_result[0]

# Verify the function composition is correct
x.__add__.assert_called_once_with(residual)
mock_torch_npu.npu_rms_norm.assert_called_once_with(x_plus_residual, weight, eps)
mock_torch_npu.npu_add_rms_norm.assert_called_once_with(x, residual, weight, eps)


if __name__ == "__main__":
pytest.main([__file__])
9 changes: 8 additions & 1 deletion vllm_ascend/ascend_config.py
@@ -126,7 +126,14 @@ def __init__(self, vllm_config: "VllmConfig"):
# npu_fused_infer_attention_score performs better on all scenarios.
self.pa_shape_list = additional_config.get("pa_shape_list", [])

self.enable_async_exponential = bool(additional_config.get("enable_async_exponential", False))
# When enable_async_exponential is True, AscendSampler diverges from vLLM's Sampler,
# which breaks batch_invariant mode,
# so async exponential is disabled when batch_invariant mode is enabled.
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant

self.enable_async_exponential = (
bool(additional_config.get("enable_async_exponential", False)) and not vllm_is_batch_invariant()
)

self.enable_kv_nz = additional_config.get("enable_kv_nz", False)
if self.enable_kv_nz:
38 changes: 38 additions & 0 deletions vllm_ascend/batch_invariant.py
@@ -24,6 +24,9 @@
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.triton_utils import HAS_TRITON

# Keep a reference to the original torch.sum so reduce_sum does not call itself recursively once torch.sum is patched.
torch_sum = torch.sum

logger = init_logger(__name__)

if HAS_TRITON:
@@ -34,6 +37,7 @@
matmul_batch_invariant,
mm_batch_invariant,
)
from vllm_ascend.ops.triton.batch_invariant.softmax import softmax_batch_invariant


try:
@@ -44,10 +48,38 @@
HAS_ASCENDC_BATCH_INVARIANT = False


def add_rms_norm(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float,
):
"""AclnnAddRmsNorm can't ensure batch invariant,
so we need to split it into add and rms_norm.
"""
x_ = x + residual
residual_ = x_
x_, _ = torch_npu.npu_rms_norm(x_, weight, eps)
return x_, None, residual_


def reduce_sum(x: torch.Tensor, dim: int | None = None, keepdim: bool = False) -> torch.Tensor:
"""npu_reduce_sum_batch_invariant requires dim to be specified, but torch.sum
doesn't require it, so we set dim to -1 by default if dim is None and x.dim()==1.
"""
dim = -1 if dim is None and x.dim() == 1 else dim
if x.device.type == "npu" and dim is not None:
return torch.ops.batch_invariant_ops.npu_reduce_sum_batch_invariant(x, dim, keepdim)
# CPU tensors can't use npu_reduce_sum_batch_invariant, so fall back to the original torch.sum.
return torch_sum(x, dim, keepdim)


def override_envs_for_invariance():
# Enabling NZ mode feeds NZ-format inputs to the Triton operators,
# resulting in accuracy anomalies.
os.environ["VLLM_ASCEND_ENABLE_NZ"] = "0"
# The fused matmul-allreduce operator can't guarantee batch invariance, so it is disabled.
os.environ["VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE"] = "0"

# communication determinism settings
os.environ["HCCL_DETERMINISTIC"] = "strict"
@@ -65,6 +97,8 @@ def enable_batch_invariant_mode():
if HAS_TRITON:
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::softmax", softmax_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::_softmax", softmax_batch_invariant, "NPU")

# Register operators implemented in Ascend batch-invariant ops in priority.
if HAS_ASCENDC_BATCH_INVARIANT:
@@ -76,6 +110,10 @@
torch_npu.npu_fused_infer_attention_score = (
torch.ops.batch_invariant_ops.npu_fused_infer_attention_score_batch_invariant
)
# Patch npu_add_rms_norm to guarantee batch invariance.
torch_npu.npu_add_rms_norm = add_rms_norm
# torch.sum can't be replaced by dispatch logic, so we patch it directly.
torch.sum = reduce_sum

# register triton implementations if ascendc is not available.
elif HAS_TRITON:
5 changes: 5 additions & 0 deletions vllm_ascend/sample/sampler.py
@@ -1,4 +1,5 @@
import torch
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.sampler import Sampler

@@ -73,6 +74,10 @@ def set_q_event(self, q, event):

def forward_native(self, logits, generators, k, p):
"""Override pytorch native implementation to torch_npu"""
# When batch_invariant mode is enabled, use vLLM's native implementation;
# otherwise batch_invariant mode would not work.
if vllm_is_batch_invariant():
return super().forward_native(logits, generators, k, p)
logits = self.apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == "processed_logits":