diff --git a/tests/ut/batch_invariant/test_batch_invariant.py b/tests/ut/batch_invariant/test_batch_invariant.py
new file mode 100644
index 00000000000..55fe02e9fae
--- /dev/null
+++ b/tests/ut/batch_invariant/test_batch_invariant.py
@@ -0,0 +1,170 @@
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+# type: ignore
+import importlib
+import os
+import sys
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+import vllm_ascend.batch_invariant as batch_invariant
+
+
+class TestBatchInvariant:
+    """Complete test suite for batch_invariant.py"""
+
+    def test_override_envs_for_invariance(self):
+        """Test environment variable override"""
+        # Clear environment variables
+        env_vars = ["VLLM_ASCEND_ENABLE_NZ", "HCCL_DETERMINISTIC", "LCCL_DETERMINISTIC"]
+        for var in env_vars:
+            if var in os.environ:
+                del os.environ[var]
+
+        # Call function
+        batch_invariant.override_envs_for_invariance()
+
+        # Verify environment variables
+        assert os.environ["VLLM_ASCEND_ENABLE_NZ"] == "0"
+        assert os.environ["HCCL_DETERMINISTIC"] == "strict"
+        assert os.environ["LCCL_DETERMINISTIC"] == "1"
+
+    @pytest.mark.parametrize("custom_ops_available, expected_value", [(True, True), (False, False)])
+    def test_has_ascendc_batch_invariant(self, custom_ops_available, expected_value):
+        """Test HAS_ASCENDC_BATCH_INVARIANT detection"""
+        # Control custom_ops availability
+        if custom_ops_available:
+            sys.modules["batch_invariant_ops"] = MagicMock()
+        else:
+            sys.modules.pop("batch_invariant_ops", None)
+
+        # Reload module to re-evaluate the flag
+        importlib.reload(batch_invariant)
+
+        # Verify result
+        assert batch_invariant.HAS_ASCENDC_BATCH_INVARIANT == expected_value
+
+    @patch("vllm_ascend.batch_invariant.HAS_TRITON", False)
+    @patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", True)
+    def test_enable_batch_invariant_mode_ascendc_path(self):
+        """Test enable_batch_invariant_mode with AscendC ops available"""
+        # Mock dependencies
+        mock_library = MagicMock()
+        batch_invariant.torch.library.Library = MagicMock(return_value=mock_library)
+        batch_invariant.torch.ops.batch_invariant_ops = MagicMock()
+
+        # Call function
+        batch_invariant.enable_batch_invariant_mode()
+
+        # Verify library created
+        batch_invariant.torch.library.Library.assert_called_once_with("aten", "IMPL")
+
+        # Verify operator registrations
+        assert mock_library.impl.call_count == 3
+        mock_library.impl.assert_any_call(
+            "aten::mm", batch_invariant.torch.ops.batch_invariant_ops.npu_mm_batch_invariant, "NPU"
+        )
+        mock_library.impl.assert_any_call(
+            "aten::matmul", batch_invariant.torch.ops.batch_invariant_ops.npu_matmul_batch_invariant, "NPU"
+        )
+        mock_library.impl.assert_any_call(
+            "aten::sum", batch_invariant.torch.ops.batch_invariant_ops.npu_reduce_sum_batch_invariant, "NPU"
+        )
+
+        # Verify torch_npu function patching
+        assert (
+            batch_invariant.torch_npu.npu_fused_infer_attention_score
+            == batch_invariant.torch.ops.batch_invariant_ops.npu_fused_infer_attention_score_batch_invariant
+        )
+
+    @patch("vllm_ascend.batch_invariant.HAS_TRITON", True)
+    @patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", False)
+    def test_enable_batch_invariant_mode_triton_path(self):
+        """Test enable_batch_invariant_mode with only Triton available"""
+        # Mock dependencies
+        mock_library = MagicMock()
+        batch_invariant.torch.library.Library = MagicMock(return_value=mock_library)
+
+        # Mock triton imports
+        batch_invariant.addmm_batch_invariant = MagicMock()
+        batch_invariant.bmm_batch_invariant = MagicMock()
+        batch_invariant.mm_batch_invariant = MagicMock()
+        batch_invariant.matmul_batch_invariant = MagicMock()
+        batch_invariant.linear_batch_invariant = MagicMock()
+
+        # Call function
+        batch_invariant.enable_batch_invariant_mode()
+
+        # Verify operator registrations
+        assert mock_library.impl.call_count == 5
+        mock_library.impl.assert_any_call("aten::addmm", batch_invariant.addmm_batch_invariant, "NPU")
+        mock_library.impl.assert_any_call("aten::bmm", batch_invariant.bmm_batch_invariant, "NPU")
+        mock_library.impl.assert_any_call("aten::mm", batch_invariant.mm_batch_invariant, "NPU")
+        mock_library.impl.assert_any_call("aten::matmul", batch_invariant.matmul_batch_invariant, "NPU")
+        mock_library.impl.assert_any_call("aten::linear", batch_invariant.linear_batch_invariant, "NPU")
+
+    @patch("vllm_ascend.batch_invariant.HAS_TRITON", False)
+    @patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", False)
+    def test_enable_batch_invariant_mode_no_backend(self):
+        """Test enable_batch_invariant_mode with no backends available"""
+        # Mock library
+        mock_library = MagicMock()
+        batch_invariant.torch.library.Library = MagicMock(return_value=mock_library)
+
+        # Call function
+        batch_invariant.enable_batch_invariant_mode()
+
+        # Verify no operators registered
+        mock_library.impl.assert_not_called()
+
+    @pytest.mark.parametrize(
+        "batch_invariant_enabled, has_backend, expected_logger_call",
+        [(True, True, "info"), (True, False, "warning"), (False, True, None), (False, False, None)],
+    )
+    def test_init_batch_invariance(self, batch_invariant_enabled, has_backend, expected_logger_call):
+        """Test init_batch_invariance under different conditions"""
+        # Mock dependencies
+        batch_invariant.vllm_is_batch_invariant = MagicMock(return_value=batch_invariant_enabled)
+        batch_invariant.HAS_TRITON = has_backend
+        batch_invariant.HAS_ASCENDC_BATCH_INVARIANT = has_backend
+        batch_invariant.override_envs_for_invariance = MagicMock()
+        batch_invariant.enable_batch_invariant_mode = MagicMock()
+        batch_invariant.logger = MagicMock()
+
+        # Call function
+        batch_invariant.init_batch_invariance()
+
+        # Verify function calls based on conditions
+        if batch_invariant_enabled and has_backend:
+            batch_invariant.override_envs_for_invariance.assert_called_once()
+            batch_invariant.enable_batch_invariant_mode.assert_called_once()
+        else:
+            batch_invariant.override_envs_for_invariance.assert_not_called()
+            batch_invariant.enable_batch_invariant_mode.assert_not_called()
+
+        # Verify the logger call matches the expected severity
+        if expected_logger_call == "info":
+            batch_invariant.logger.info.assert_called_once()
+        elif expected_logger_call == "warning":
+            batch_invariant.logger.warning.assert_called_once()
+        else:
+            batch_invariant.logger.info.assert_not_called()
+            batch_invariant.logger.warning.assert_not_called()
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/vllm_ascend/batch_invariant.py b/vllm_ascend/batch_invariant.py
index fa3e10a99a3..38b6ad7d757 100644
--- a/vllm_ascend/batch_invariant.py
+++ b/vllm_ascend/batch_invariant.py
@@ -19,6 +19,7 @@
 import os
 
 import torch
+import torch_npu
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.triton_utils import HAS_TRITON
@@ -35,15 +36,21 @@
 )
 
 
-def override_envs_for_invariance():
-    # TODO(Ronald) set attntion backend to deterministic mode
+try:
+    import batch_invariant_ops  # type: ignore[import-not-found] # noqa
+
+    HAS_ASCENDC_BATCH_INVARIANT = True
+except ImportError:
+    HAS_ASCENDC_BATCH_INVARIANT = False
+
+
+def override_envs_for_invariance():
     # enabling NZ mode introduces NZ format input to the triton operator,
     # resulting in accuracy anomalies.
     os.environ["VLLM_ASCEND_ENABLE_NZ"] = "0"
 
     # communication determinism settings
-    os.environ["HCCL_DETERMINISTIC"] = "true"
+    os.environ["HCCL_DETERMINISTIC"] = "strict"
     os.environ["LCCL_DETERMINISTIC"] = "1"
 
 
@@ -52,14 +59,32 @@ def override_envs_for_invariance():
 def enable_batch_invariant_mode():
     global _batch_invariant_LIB
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
-    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
-    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
-    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")
-    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")
-    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")
+
+    # Register operators that only have Triton implementations.
+    if HAS_TRITON:
+        _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")
+
+    # Prefer the AscendC batch-invariant implementations when they are available.
+    if HAS_ASCENDC_BATCH_INVARIANT:
+        _batch_invariant_LIB.impl("aten::mm", torch.ops.batch_invariant_ops.npu_mm_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::matmul", torch.ops.batch_invariant_ops.npu_matmul_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::sum", torch.ops.batch_invariant_ops.npu_reduce_sum_batch_invariant, "NPU")
+        # torch_npu.npu_fused_infer_attention_score is a plain torch_npu function,
+        # not a torch.ops operator, so it has to be patched directly.
+        torch_npu.npu_fused_infer_attention_score = (
+            torch.ops.batch_invariant_ops.npu_fused_infer_attention_score_batch_invariant
+        )
+
+    # Fall back to the Triton implementations when AscendC is not available.
+    elif HAS_TRITON:
+        _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")
+
+        # linear calls matmul internally, so it only needs an explicit registration
+        # on the Triton path; when AscendC is available, the batch-invariant matmul
+        # already covers linear with better performance.
+        _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")
 
 
 def init_batch_invariance():
@@ -75,7 +100,7 @@ def init_batch_invariance():
     environment variable to enable automatically.
     """
     if vllm_is_batch_invariant():
-        if HAS_TRITON:
+        if HAS_TRITON or HAS_ASCENDC_BATCH_INVARIANT:
             logger.info(
                 "Enabling batch-invariant mode for vLLM on Ascend NPU.",
             )
@@ -83,5 +108,6 @@ def init_batch_invariance():
             enable_batch_invariant_mode()
         else:
            logger.warning(
-                "Batch-invariant mode requested but Triton is not available.skipping batch-invariant initialization.",
+                "Batch-invariant mode requested but neither Triton nor the AscendC "
+                "batch-invariant ops are available. Skipping batch-invariant initialization.",
             )
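
For reviewers unfamiliar with the mechanism: `torch.library.Library("aten", "IMPL")` re-registers kernels for existing ATen operators under a dispatch key, which is how both the Triton and AscendC paths above take effect on the NPU key. Below is a minimal, CPU-only sketch of the same technique. It is not this PR's Triton or AscendC kernel; `mm_fixed_order` is a hypothetical stand-in that accumulates the reduction dimension in a fixed order, so the result of `torch.mm` no longer depends on how a backend might split the reduction for different batch sizes.

```python
import torch


def mm_fixed_order(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical batch-invariant mm: walk the K dimension in a fixed,
    # data-independent order instead of letting the backend choose a split
    # (on accelerators that choice can vary with batch size).
    out = a.new_zeros((a.shape[0], b.shape[1]))
    for k in range(a.shape[1]):
        # Uses only mul/add, so it never re-enters the overridden aten::mm.
        out += a[:, k : k + 1] * b[k : k + 1, :]
    return out


# Same registration pattern as enable_batch_invariant_mode(), but on the CPU
# dispatch key. Keep a reference to the Library object alive, otherwise the
# registration is torn down when it is garbage collected.
_lib = torch.library.Library("aten", "IMPL")
_lib.impl("aten::mm", mm_fixed_order, "CPU")

a, b = torch.randn(8, 64), torch.randn(64, 16)
full = torch.mm(a, b)        # dispatches to mm_fixed_order
single = torch.mm(a[:1], b)  # same reduction order regardless of batch size
assert torch.equal(full[:1], single)
```

Running that last comparison against the stock kernels is also a quick way to check whether a backend is batch-invariant to begin with: CPU mm usually passes it, while accelerator kernels that tile the reduction differently per batch size may not, which is the behavior this PR's fixed-order implementations are meant to remove.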