163 changes: 163 additions & 0 deletions tests/ut/batch_invariant/test_batch_invariant.py
@@ -0,0 +1,163 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# type: ignore
import importlib
import os
import sys
from unittest.mock import MagicMock, patch

import pytest

import vllm_ascend.batch_invariant as batch_invariant


class TestBatchInvariant:
"""Complete test suite for batch_invariant.py"""

def test_override_envs_for_invariance(self):
"""Test environment variable override"""
# Clear environment variables
env_vars = ["VLLM_ASCEND_ENABLE_NZ", "HCCL_DETERMINISTIC", "LCCL_DETERMINISTIC"]
for var in env_vars:
if var in os.environ:
del os.environ[var]

# Call function
batch_invariant.override_envs_for_invariance()

# Verify environment variables
assert os.environ["VLLM_ASCEND_ENABLE_NZ"] == "0"
assert os.environ["HCCL_DETERMINISTIC"] == "strict"
assert os.environ["LCCL_DETERMINISTIC"] == "1"

@pytest.mark.parametrize("custom_ops_available, expected_value", [(True, True), (False, False)])
def test_has_ascendc_batch_invariant(self, custom_ops_available, expected_value):
"""Test HAS_ASCENDC_BATCH_INVARIANT detection"""
# Control custom_ops availability
if custom_ops_available:
sys.modules["batch_invariant_ops"] = MagicMock()
else:
sys.modules.pop("batch_invariant_ops", None)

# Reload module to re-evaluate the flag
importlib.reload(batch_invariant)

# Verify result
assert batch_invariant.HAS_ASCENDC_BATCH_INVARIANT == expected_value

@patch("vllm_ascend.batch_invariant.HAS_TRITON", False)
@patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", True)
def test_enable_batch_invariant_mode_ascendc_path(self):
"""Test enable_batch_invariant_mode with AscendC ops available"""
# Mock dependencies
mock_library = MagicMock()
batch_invariant.torch.library.Library = MagicMock(return_value=mock_library)
batch_invariant.torch.ops.batch_invariant_ops = MagicMock()

# Call function
batch_invariant.enable_batch_invariant_mode()

# Verify library created
batch_invariant.torch.library.Library.assert_called_once_with("aten", "IMPL")

# Verify operator registrations
assert mock_library.impl.call_count == 3
mock_library.impl.assert_any_call(
"aten::mm", batch_invariant.torch.ops.batch_invariant_ops.npu_mm_batch_invariant, "NPU"
)
mock_library.impl.assert_any_call(
"aten::matmul", batch_invariant.torch.ops.batch_invariant_ops.npu_matmul_batch_invariant, "NPU"
)
mock_library.impl.assert_any_call(
"aten::sum", batch_invariant.torch.ops.batch_invariant_ops.npu_reduce_sum_batch_invariant, "NPU"
)

# Verify torch_npu function patching
assert (
batch_invariant.torch_npu.npu_fused_infer_attention_score
== batch_invariant.torch.ops.batch_invariant_ops.npu_fused_infer_attention_score_batch_invariant
)

@patch("vllm_ascend.batch_invariant.HAS_TRITON", True)
@patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", False)
def test_enable_batch_invariant_mode_triton_path(self):
"""Test enable_batch_invariant_mode with only Triton available"""
# Mock dependencies
mock_library = MagicMock()
batch_invariant.torch.library.Library = MagicMock(return_value=mock_library)

# Mock triton imports
batch_invariant.addmm_batch_invariant = MagicMock()
batch_invariant.bmm_batch_invariant = MagicMock()
batch_invariant.mm_batch_invariant = MagicMock()
batch_invariant.matmul_batch_invariant = MagicMock()
batch_invariant.linear_batch_invariant = MagicMock()

# Call function
batch_invariant.enable_batch_invariant_mode()

# Verify operator registrations
assert mock_library.impl.call_count == 5
mock_library.impl.assert_any_call("aten::addmm", batch_invariant.addmm_batch_invariant, "NPU")
mock_library.impl.assert_any_call("aten::bmm", batch_invariant.bmm_batch_invariant, "NPU")
mock_library.impl.assert_any_call("aten::mm", batch_invariant.mm_batch_invariant, "NPU")
mock_library.impl.assert_any_call("aten::matmul", batch_invariant.matmul_batch_invariant, "NPU")
mock_library.impl.assert_any_call("aten::linear", batch_invariant.linear_batch_invariant, "NPU")

@patch("vllm_ascend.batch_invariant.HAS_TRITON", False)
@patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", False)
def test_enable_batch_invariant_mode_no_backend(self):
"""Test enable_batch_invariant_mode with no backends available"""
# Mock library
mock_library = MagicMock()
batch_invariant.torch.library.Library = MagicMock(return_value=mock_library)

# Call function
batch_invariant.enable_batch_invariant_mode()

# Verify no operators registered
mock_library.impl.assert_not_called()

@pytest.mark.parametrize(
"batch_invariant_enabled, has_backend, expected_logger_call",
[(True, True, "info"), (True, False, "warning"), (False, True, None), (False, False, None)],
)
    def test_init_batch_invariance(self, batch_invariant_enabled, has_backend, expected_logger_call):
        """Test init_batch_invariance under different conditions"""
        # Mock dependencies
        batch_invariant.vllm_is_batch_invariant = MagicMock(return_value=batch_invariant_enabled)
        batch_invariant.HAS_TRITON = has_backend
        batch_invariant.HAS_ASCENDC_BATCH_INVARIANT = has_backend
        batch_invariant.override_envs_for_invariance = MagicMock()
        batch_invariant.enable_batch_invariant_mode = MagicMock()
        batch_invariant.logger = MagicMock()

        # Call function
        batch_invariant.init_batch_invariance()

        # Verify the logger call matches the scenario
        if expected_logger_call == "info":
            batch_invariant.logger.info.assert_called_once()
        elif expected_logger_call == "warning":
            batch_invariant.logger.warning.assert_called_once()
        else:
            batch_invariant.logger.info.assert_not_called()
            batch_invariant.logger.warning.assert_not_called()

        # Verify function calls based on conditions
        if batch_invariant_enabled and has_backend:
            batch_invariant.override_envs_for_invariance.assert_called_once()
            batch_invariant.enable_batch_invariant_mode.assert_called_once()
        else:
            batch_invariant.override_envs_for_invariance.assert_not_called()
            batch_invariant.enable_batch_invariant_mode.assert_not_called()


if __name__ == "__main__":
    pytest.main([__file__])
48 changes: 37 additions & 11 deletions vllm_ascend/batch_invariant.py
@@ -19,6 +19,7 @@
import os

import torch
import torch_npu
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.triton_utils import HAS_TRITON
@@ -35,15 +36,21 @@
)


def override_envs_for_invariance():
    # TODO(Ronald): set attention backend to deterministic mode
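# Probe for the optional AscendC batch-invariant kernel package; when it is
# missing, enable_batch_invariant_mode() falls back to the Triton kernels.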
try:
    import batch_invariant_ops  # type: ignore[import-not-found] # noqa

    HAS_ASCENDC_BATCH_INVARIANT = True
except ImportError:
    HAS_ASCENDC_BATCH_INVARIANT = False


def override_envs_for_invariance():
    # Enabling NZ mode introduces NZ-format inputs to the Triton operators,
    # resulting in accuracy anomalies.
    os.environ["VLLM_ASCEND_ENABLE_NZ"] = "0"

    # Communication determinism settings
    os.environ["HCCL_DETERMINISTIC"] = "true"
    os.environ["HCCL_DETERMINISTIC"] = "strict"
    os.environ["LCCL_DETERMINISTIC"] = "1"


@@ -52,14 +59,32 @@ def override_envs_for_invariance():

def enable_batch_invariant_mode():
    global _batch_invariant_LIB

    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")

    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")
    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")
    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")
    # Register operators that are only implemented in Triton.
    if HAS_TRITON:
        _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
        _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")

    # Prefer the AscendC batch-invariant implementations when they are available.
    if HAS_ASCENDC_BATCH_INVARIANT:
        _batch_invariant_LIB.impl("aten::mm", torch.ops.batch_invariant_ops.npu_mm_batch_invariant, "NPU")
        _batch_invariant_LIB.impl("aten::matmul", torch.ops.batch_invariant_ops.npu_matmul_batch_invariant, "NPU")
        _batch_invariant_LIB.impl("aten::sum", torch.ops.batch_invariant_ops.npu_reduce_sum_batch_invariant, "NPU")
        # torch_npu.npu_fused_infer_attention_score is a plain torch_npu function,
        # not a torch.ops Operator, so it has to be patched directly.
        torch_npu.npu_fused_infer_attention_score = (
            torch.ops.batch_invariant_ops.npu_fused_infer_attention_score_batch_invariant
        )

    # Fall back to the Triton implementations when AscendC is not available.
    elif HAS_TRITON:
        _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
        _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")

        # aten::linear calls matmul internally, so register the Triton linear only
        # when AscendC is unavailable; with AscendC, the internal matmul already
        # dispatches to the faster AscendC kernel.
        _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")


def init_batch_invariance():
@@ -75,13 +100,14 @@ def init_batch_invariance():
    environment variable to enable automatically.
    """
    if vllm_is_batch_invariant():
        if HAS_TRITON:
        if HAS_TRITON or HAS_ASCENDC_BATCH_INVARIANT:
            logger.info(
                "Enabling batch-invariant mode for vLLM on Ascend NPU.",
            )
            override_envs_for_invariance()
            enable_batch_invariant_mode()
        else:
            logger.warning(
                "Batch-invariant mode requested but Triton is not available. Skipping batch-invariant initialization.",
                "Batch-invariant mode requested but Triton or AscendC batch-invariant "
                "ops are not available. Skipping batch-invariant initialization."
)