From 10792d6b4097bf00abf21ee224d0494e1459b2a5 Mon Sep 17 00:00:00 2001
From: Ronald1995
Date: Wed, 4 Feb 2026 11:06:41 +0800
Subject: [PATCH 1/7] implement batch_invariant with AscendC operators

Signed-off-by: Ronald1995
---
 vllm_ascend/batch_invariant.py | 47 +++++++++++++++++++++++++---------
 1 file changed, 35 insertions(+), 12 deletions(-)

diff --git a/vllm_ascend/batch_invariant.py b/vllm_ascend/batch_invariant.py
index fa3e10a99a3..4491b0776c7 100644
--- a/vllm_ascend/batch_invariant.py
+++ b/vllm_ascend/batch_invariant.py
@@ -17,6 +17,7 @@
 # This file is a part of the vllm-ascend project.
 #
 import os
+import importlib.util
 
 import torch
 from vllm.logger import init_logger
@@ -34,6 +35,8 @@
     mm_batch_invariant,
 )
 
+HAS_ASCENDC_BATCH_INVARIANT = importlib.util.find_spec("custom_ops") is not None
+
 
 def override_envs_for_invariance():
     # TODO(Ronald) set attention backend to deterministic mode
@@ -43,23 +46,43 @@ def override_envs_for_invariance():
     os.environ["VLLM_ASCEND_ENABLE_NZ"] = "0"
 
     # communication determinism settings
-    os.environ["HCCL_DETERMINISTIC"] = "true"
+    os.environ["HCCL_DETERMINISTIC"] = "strict"
     os.environ["LCCL_DETERMINISTIC"] = "1"
 
 
-_batch_invariant_LIB = None
+_aten_batch_invariant_LIB = None
+_npu_batch_invariant_LIB = None
 
 
 def enable_batch_invariant_mode():
-    global _batch_invariant_LIB
-
-    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
-
-    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
-    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
-    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")
-    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")
-    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")
+    global _aten_batch_invariant_LIB
+    global _npu_batch_invariant_LIB
+    _aten_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
+    _npu_batch_invariant_LIB = torch.library.Library("npu", "IMPL")
+
+    # Register operators only implemented in Triton.
+    if HAS_TRITON:
+        _aten_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
+        _aten_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")
+
+    # Prefer the AscendC batch-invariant implementations when available.
+    if HAS_ASCENDC_BATCH_INVARIANT:
+        _aten_batch_invariant_LIB.impl("aten::mm", torch.ops.my_ops.npu_mm_batch_invariant, "NPU")
+        _aten_batch_invariant_LIB.impl("aten::matmul", torch.ops.my_ops.npu_matmul_batch_invariant, "NPU")
+        _npu_batch_invariant_LIB.impl(
+            "npu::npu_fused_infer_attention_score",
+            torch.ops.my_ops.npu_fused_infer_attention_score_batch_invariant,
+            "NPU",
+        )
+
+    # Register the Triton implementations if AscendC is not available.
+    elif HAS_TRITON:
+        _aten_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
+        _aten_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")
+
+        # linear calls matmul internally, so register linear only when AscendC
+        # is not available; it will get better performance with AscendC.
+        _aten_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")
 
 
 def init_batch_invariance():
@@ -75,7 +98,7 @@ def init_batch_invariance():
     environment variable to enable automatically.
""" if vllm_is_batch_invariant(): - if HAS_TRITON: + if HAS_TRITON or HAS_ASCENDC_BATCH_INVARIANT: logger.info( "Enabling batch-invariant mode for vLLM on Ascend NPU.", ) From 7815013468081be3f281510de397c352a73d0b36 Mon Sep 17 00:00:00 2001 From: Ronald1995 Date: Fri, 6 Feb 2026 11:40:25 +0800 Subject: [PATCH 2/7] ruff format Signed-off-by: Ronald1995 --- vllm_ascend/batch_invariant.py | 39 ++++++++++++++++------------------ 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/vllm_ascend/batch_invariant.py b/vllm_ascend/batch_invariant.py index 4491b0776c7..522598caea5 100644 --- a/vllm_ascend/batch_invariant.py +++ b/vllm_ascend/batch_invariant.py @@ -16,10 +16,11 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # -import os import importlib +import os import torch +import torch_npu from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant from vllm.triton_utils import HAS_TRITON @@ -39,8 +40,6 @@ def override_envs_for_invariance(): - # TODO(Ronald) set attntion backend to deterministic mode - # enabling NZ mode introduces NZ format input to the triton operator, # resulting in accuracy anomalies. os.environ["VLLM_ASCEND_ENABLE_NZ"] = "0" @@ -50,39 +49,37 @@ def override_envs_for_invariance(): os.environ["LCCL_DETERMINISTIC"] = "1" -_aten_batch_invariant_LIB = None -_npu_batch_invariant_LIB = None +_batch_invariant_LIB = None def enable_batch_invariant_mode(): - global _aten_batch_invariant_LIB - global _npu_batch_invariant_LIB - _aten_batch_invariant_LIB = torch.library.Library("aten", "IMPL") - _npu_batch_invariant_LIB = torch.library.Library("npu", "IMPL") + global _batch_invariant_LIB + _batch_invariant_LIB = torch.library.Library("aten", "IMPL") # Register operators only implemented in triton. if HAS_TRITON: - _aten_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU") - _aten_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU") + _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU") + _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU") # Register operators implemented in Ascend batch-invariant ops in priority. if HAS_ASCENDC_BATCH_INVARIANT: - _aten_batch_invariant_LIB.impl("aten::mm", torch.ops.my_ops.npu_mm_batch_invariant, "NPU") - _aten_batch_invariant_LIB.impl("aten::matmul", torch.ops.my_ops.npu_matmul_batch_invariant, "NPU") - _npu_batch_invariant_LIB.impl( - "npu::npu_fused_infer_attention_score", - torch.ops.my_ops.npu_fused_infer_attention_score_batch_invariant, - "NPU", - ) + import custom_ops # noqa + + _batch_invariant_LIB.impl("aten::mm", torch.ops.myops.npu_mm_batch_invariant, "NPU") + _batch_invariant_LIB.impl("aten::matmul", torch.ops.myops.npu_matmul_batch_invariant, "NPU") + _batch_invariant_LIB.impl("aten::sum", torch.ops.myops.npu_reduce_sum_batch_invariant, "NPU") + # torch_npu.npu_fused_infer_attention_score is a function of torch_npu, not a torch.ops.Operator, + # so we need to patch it directly. + torch_npu.npu_fused_infer_attention_score = torch.ops.myops.npu_fused_infer_attention_score_batch_invariant # register triton implementations if ascendc is not available. 
    elif HAS_TRITON:
-        _aten_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
-        _aten_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")
 
         # linear calls matmul internally, so register linear only when AscendC
         # is not available; it will get better performance with AscendC.
-        _aten_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")
 
 
 def init_batch_invariance():

From 6902b3ad91d21a2f6d1e3521edbc95ae5f084d09 Mon Sep 17 00:00:00 2001
From: Ronald1995
Date: Mon, 9 Feb 2026 10:20:05 +0800
Subject: [PATCH 3/7] add ut

Signed-off-by: Ronald1995
---
 .../batch_invariant/test_batch_invariant.py   | 228 ++++++++++++++++++
 vllm_ascend/batch_invariant.py                |  10 +-
 2 files changed, 235 insertions(+), 3 deletions(-)
 create mode 100644 tests/ut/batch_invariant/test_batch_invariant.py

diff --git a/tests/ut/batch_invariant/test_batch_invariant.py b/tests/ut/batch_invariant/test_batch_invariant.py
new file mode 100644
index 00000000000..244c207e880
--- /dev/null
+++ b/tests/ut/batch_invariant/test_batch_invariant.py
@@ -0,0 +1,228 @@
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# + +import os +import sys +import importlib +import pytest +import torch +from unittest.mock import patch, MagicMock, mock_open + +# Mock the vllm and torch_npu modules to isolate the tests +class MockVLLM: + class Logger: + @staticmethod + def init_logger(name): + logger = MagicMock() + logger.info = MagicMock() + logger.warning = MagicMock() + return logger + + class ModelExecutor: + class Layers: + class BatchInvariant: + @staticmethod + def vllm_is_batch_invariant(): + return False + + class TritonUtils: + HAS_TRITON = False + +class MockTorchNPU: + npu_fused_infer_attention_score = MagicMock() + +# Set up mock modules +@pytest.fixture(autouse=True) +def setup_mocks(): + # Save original modules + original_modules = { + 'vllm': sys.modules.get('vllm'), + 'torch': sys.modules.get('torch'), + 'torch_npu': sys.modules.get('torch_npu'), + 'custom_ops': sys.modules.get('custom_ops') + } + + # Create mock modules + sys.modules['vllm'] = MockVLLM() + sys.modules['vllm.logger'] = MockVLLM.Logger() + sys.modules['vllm.model_executor'] = MockVLLM.ModelExecutor() + sys.modules['vllm.model_executor.layers'] = MockVLLM.ModelExecutor.Layers() + sys.modules['vllm.model_executor.layers.batch_invariant'] = MockVLLM.ModelExecutor.Layers.BatchInvariant() + sys.modules['vllm.triton_utils'] = MockVLLM.TritonUtils() + sys.modules['torch'] = MagicMock() + sys.modules['torch_npu'] = MockTorchNPU() + + # Import the module after mocking + global batch_invariant + import vllm_ascend.batch_invariant as batch_invariant + + yield + + # Restore original modules + for name, module in original_modules.items(): + if module is not None: + sys.modules[name] = module + else: + sys.modules.pop(name, None) + + +class TestBatchInvariant: + """Complete test suite for batch_invariant.py""" + + def test_override_envs_for_invariance(self): + """Test environment variable override""" + # Clear environment variables + env_vars = ['VLLM_ASCEND_ENABLE_NZ', 'HCCL_DETERMINISTIC', 'LCCL_DETERMINISTIC'] + for var in env_vars: + if var in os.environ: + del os.environ[var] + + # Call function + batch_invariant.override_envs_for_invariance() + + # Verify environment variables + assert os.environ['VLLM_ASCEND_ENABLE_NZ'] == '0' + assert os.environ['HCCL_DETERMINISTIC'] == 'strict' + assert os.environ['LCCL_DETERMINISTIC'] == '1' + + @pytest.mark.parametrize("custom_ops_available, expected_value", [ + (True, True), + (False, False) + ]) + def test_has_ascendc_batch_invariant(self, custom_ops_available): + """Test HAS_ASCENDC_BATCH_INVARIANT detection""" + # Control custom_ops availability + if custom_ops_available: + sys.modules['custom_ops'] = MagicMock() + else: + sys.modules.pop('custom_ops', None) + + # Reload module to re-evaluate the flag + importlib.reload(batch_invariant) + + # Verify result + assert batch_invariant.HAS_ASCENDC_BATCH_INVARIANT == expected_value + + @patch('vllm_ascend.batch_invariant.HAS_TRITON', False) + @patch('vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT', True) + def test_enable_batch_invariant_mode_ascendc_path(self): + """Test enable_batch_invariant_mode with AscendC ops available""" + # Mock dependencies + mock_library = MagicMock() + batch_invariant.torch.library.Library = MagicMock(return_value=mock_library) + batch_invariant.torch.ops.myops = MagicMock() + + # Call function + batch_invariant.enable_batch_invariant_mode() + + # Verify library created + batch_invariant.torch.library.Library.assert_called_once_with("aten", "IMPL") + + # Verify operator registrations + assert 
mock_library.impl.call_count == 3 + mock_library.impl.assert_any_call( + "aten::mm", batch_invariant.torch.ops.myops.npu_mm_batch_invariant, "NPU" + ) + mock_library.impl.assert_any_call( + "aten::matmul", batch_invariant.torch.ops.myops.npu_matmul_batch_invariant, "NPU" + ) + mock_library.impl.assert_any_call( + "aten::sum", batch_invariant.torch.ops.myops.npu_reduce_sum_batch_invariant, "NPU" + ) + + # Verify torch_npu function patching + assert batch_invariant.torch_npu.npu_fused_infer_attention_score == \ + batch_invariant.torch.ops.myops.npu_fused_infer_attention_score_batch_invariant + + @patch('vllm_ascend.batch_invariant.HAS_TRITON', True) + @patch('vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT', False) + def test_enable_batch_invariant_mode_triton_path(self): + """Test enable_batch_invariant_mode with only Triton available""" + # Mock dependencies + mock_library = MagicMock() + batch_invariant.torch.library.Library = MagicMock(return_value=mock_library) + + # Mock triton imports + batch_invariant.addmm_batch_invariant = MagicMock() + batch_invariant.bmm_batch_invariant = MagicMock() + batch_invariant.mm_batch_invariant = MagicMock() + batch_invariant.matmul_batch_invariant = MagicMock() + batch_invariant.linear_batch_invariant = MagicMock() + + # Call function + batch_invariant.enable_batch_invariant_mode() + + # Verify operator registrations + assert mock_library.impl.call_count == 4 + mock_library.impl.assert_any_call("aten::addmm", batch_invariant.addmm_batch_invariant, "NPU") + mock_library.impl.assert_any_call("aten::bmm", batch_invariant.bmm_batch_invariant, "NPU") + mock_library.impl.assert_any_call("aten::mm", batch_invariant.mm_batch_invariant, "NPU") + mock_library.impl.assert_any_call("aten::matmul", batch_invariant.matmul_batch_invariant, "NPU") + mock_library.impl.assert_any_call("aten::linear", batch_invariant.linear_batch_invariant, "NPU") + + @patch('vllm_ascend.batch_invariant.HAS_TRITON', False) + @patch('vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT', False) + def test_enable_batch_invariant_mode_no_backend(self): + """Test enable_batch_invariant_mode with no backends available""" + # Mock library + mock_library = MagicMock() + batch_invariant.torch.library.Library = MagicMock(return_value=mock_library) + + # Call function + batch_invariant.enable_batch_invariant_mode() + + # Verify no operators registered + mock_library.impl.assert_not_called() + + @pytest.mark.parametrize("batch_invariant_enabled, has_backend, expected_logger_call", [ + (True, True, 'info'), + (True, False, 'warning'), + (False, True, None), + (False, False, None) + ]) + def test_init_batch_invariance(self, batch_invariant_enabled, has_backend, expected_logger_call): + """Test init_batch_invariance under different conditions""" + # Mock dependencies + batch_invariant.vllm_is_batch_invariant = MagicMock(return_value=batch_invariant_enabled) + batch_invariant.HAS_TRITON = has_backend + batch_invariant.HAS_ASCENDC_BATCH_INVARIANT = has_backend + batch_invariant.override_envs_for_invariance = MagicMock() + batch_invariant.enable_batch_invariant_mode = MagicMock() + logger = batch_invariant.logger + + # Call function + batch_invariant.init_batch_invariance() + + # Verify function calls based on conditions + if batch_invariant_enabled and has_backend: + batch_invariant.override_envs_for_invariance.assert_called_once() + batch_invariant.enable_batch_invariant_mode.assert_called_once() + logger.info.assert_called_once_with("Enabling batch-invariant mode for vLLM on Ascend NPU.") + 
elif batch_invariant_enabled and not has_backend:
+            batch_invariant.override_envs_for_invariance.assert_not_called()
+            batch_invariant.enable_batch_invariant_mode.assert_not_called()
+            logger.warning.assert_called_once_with(
+                "Batch-invariant mode requested but Triton is not available.skipping batch-invariant initialization."
+            )
+        else:
+            batch_invariant.override_envs_for_invariance.assert_not_called()
+            batch_invariant.enable_batch_invariant_mode.assert_not_called()
+            logger.info.assert_not_called()
+            logger.warning.assert_not_called()
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
\ No newline at end of file
diff --git a/vllm_ascend/batch_invariant.py b/vllm_ascend/batch_invariant.py
index 522598caea5..6f4d5c51364 100644
--- a/vllm_ascend/batch_invariant.py
+++ b/vllm_ascend/batch_invariant.py
@@ -36,7 +36,13 @@
     mm_batch_invariant,
 )
 
-HAS_ASCENDC_BATCH_INVARIANT = importlib.util.find_spec("custom_ops") is not None
+
+try:
+    import custom_ops  # noqa
+
+    HAS_ASCENDC_BATCH_INVARIANT = True
+except ImportError:
+    HAS_ASCENDC_BATCH_INVARIANT = False
 
 
 def override_envs_for_invariance():
@@ -63,8 +69,6 @@ def enable_batch_invariant_mode():
 
     # Prefer the AscendC batch-invariant implementations when available.
     if HAS_ASCENDC_BATCH_INVARIANT:
-        import custom_ops  # noqa
-
         _batch_invariant_LIB.impl("aten::mm", torch.ops.myops.npu_mm_batch_invariant, "NPU")
         _batch_invariant_LIB.impl("aten::matmul", torch.ops.myops.npu_matmul_batch_invariant, "NPU")
         _batch_invariant_LIB.impl("aten::sum", torch.ops.myops.npu_reduce_sum_batch_invariant, "NPU")

From 629435457af9a18d86eafdc6466f73a134a9665b Mon Sep 17 00:00:00 2001
From: Ronald1995
Date: Mon, 9 Feb 2026 10:42:03 +0800
Subject: [PATCH 4/7] fix ut

Signed-off-by: Ronald1995
---
 tests/ut/batch_invariant/test_batch_invariant.py | 13 ++++++++-----
 vllm_ascend/batch_invariant.py                   |  3 ++-
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/tests/ut/batch_invariant/test_batch_invariant.py b/tests/ut/batch_invariant/test_batch_invariant.py
index 244c207e880..ea8721d7ade 100644
--- a/tests/ut/batch_invariant/test_batch_invariant.py
+++ b/tests/ut/batch_invariant/test_batch_invariant.py
@@ -102,7 +102,7 @@ def test_override_envs_for_invariance(self):
         (True, True),
         (False, False)
     ])
-    def test_has_ascendc_batch_invariant(self, custom_ops_available):
+    def test_has_ascendc_batch_invariant(self, custom_ops_available, expected_value):
         """Test HAS_ASCENDC_BATCH_INVARIANT detection"""
         # Control custom_ops availability
         if custom_ops_available:
@@ -166,7 +166,7 @@ def test_enable_batch_invariant_mode_triton_path(self):
         batch_invariant.enable_batch_invariant_mode()
 
         # Verify operator registrations
-        assert mock_library.impl.call_count == 4
+        assert mock_library.impl.call_count == 5
         mock_library.impl.assert_any_call("aten::addmm", batch_invariant.addmm_batch_invariant, "NPU")
         mock_library.impl.assert_any_call("aten::bmm", batch_invariant.bmm_batch_invariant, "NPU")
         mock_library.impl.assert_any_call("aten::mm", batch_invariant.mm_batch_invariant, "NPU")
@@ -215,13 +215,16 @@ def test_init_batch_invariance(self, batch_invariant_enabled, has_backend, expec
             batch_invariant.override_envs_for_invariance.assert_not_called()
             batch_invariant.enable_batch_invariant_mode.assert_not_called()
             logger.warning.assert_called_once_with(
-                "Batch-invariant mode requested but Triton is not available.skipping batch-invariant initialization."
+                "Batch-invariant mode requested but Triton or AscendC batch-invariant "
+                "ops are not available; skipping batch-invariant initialization."
             )
         else:
             batch_invariant.override_envs_for_invariance.assert_not_called()
             batch_invariant.enable_batch_invariant_mode.assert_not_called()
-            logger.info.assert_not_called()
-            logger.warning.assert_not_called()
+            logger.warning.assert_called_once_with(
+                "Batch-invariant mode requested but Triton or AscendC batch-invariant "
+                "ops are not available; skipping batch-invariant initialization."
+            )
 
 
 if __name__ == "__main__":
diff --git a/vllm_ascend/batch_invariant.py b/vllm_ascend/batch_invariant.py
index 6f4d5c51364..712d9835d74 100644
--- a/vllm_ascend/batch_invariant.py
+++ b/vllm_ascend/batch_invariant.py
@@ -107,5 +107,6 @@ def init_batch_invariance():
         enable_batch_invariant_mode()
     else:
         logger.warning(
-            "Batch-invariant mode requested but Triton is not available.skipping batch-invariant initialization.",
+            "Batch-invariant mode requested but Triton or AscendC batch-invariant "
+            "ops are not available; skipping batch-invariant initialization."
         )

From 834711088cefa1080d790be801dc3cfc828d9fb6 Mon Sep 17 00:00:00 2001
From: Ronald1995
Date: Mon, 9 Feb 2026 11:01:57 +0800
Subject: [PATCH 5/7] fix import error

Signed-off-by: Ronald1995
---
 vllm_ascend/batch_invariant.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm_ascend/batch_invariant.py b/vllm_ascend/batch_invariant.py
index 712d9835d74..db2d6f92bfd 100644
--- a/vllm_ascend/batch_invariant.py
+++ b/vllm_ascend/batch_invariant.py
@@ -16,7 +16,6 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
-import importlib.util
 import os
 
 import torch

From 82ab2fee0e55396f235b5ef5b50f5b0fa04ed46a Mon Sep 17 00:00:00 2001
From: Ronald1995
Date: Mon, 9 Feb 2026 11:17:13 +0800
Subject: [PATCH 6/7] fix mypy error

Signed-off-by: Ronald1995
---
 .../batch_invariant/test_batch_invariant.py   | 122 +++++++++---------
 vllm_ascend/batch_invariant.py                |  12 +-
 2 files changed, 68 insertions(+), 66 deletions(-)

diff --git a/tests/ut/batch_invariant/test_batch_invariant.py b/tests/ut/batch_invariant/test_batch_invariant.py
index ea8721d7ade..db6af966cbe 100644
--- a/tests/ut/batch_invariant/test_batch_invariant.py
+++ b/tests/ut/batch_invariant/test_batch_invariant.py
@@ -13,13 +13,14 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
# - +# type: ignore +import importlib import os import sys -import importlib +from unittest.mock import MagicMock, patch + import pytest -import torch -from unittest.mock import patch, MagicMock, mock_open + # Mock the vllm and torch_npu modules to isolate the tests class MockVLLM: @@ -30,40 +31,42 @@ def init_logger(name): logger.info = MagicMock() logger.warning = MagicMock() return logger - + class ModelExecutor: class Layers: class BatchInvariant: @staticmethod def vllm_is_batch_invariant(): return False - + class TritonUtils: HAS_TRITON = False + class MockTorchNPU: npu_fused_infer_attention_score = MagicMock() + # Set up mock modules @pytest.fixture(autouse=True) def setup_mocks(): # Save original modules original_modules = { - 'vllm': sys.modules.get('vllm'), - 'torch': sys.modules.get('torch'), - 'torch_npu': sys.modules.get('torch_npu'), - 'custom_ops': sys.modules.get('custom_ops') + "vllm": sys.modules.get("vllm"), + "torch": sys.modules.get("torch"), + "torch_npu": sys.modules.get("torch_npu"), + "custom_ops": sys.modules.get("custom_ops"), } # Create mock modules - sys.modules['vllm'] = MockVLLM() - sys.modules['vllm.logger'] = MockVLLM.Logger() - sys.modules['vllm.model_executor'] = MockVLLM.ModelExecutor() - sys.modules['vllm.model_executor.layers'] = MockVLLM.ModelExecutor.Layers() - sys.modules['vllm.model_executor.layers.batch_invariant'] = MockVLLM.ModelExecutor.Layers.BatchInvariant() - sys.modules['vllm.triton_utils'] = MockVLLM.TritonUtils() - sys.modules['torch'] = MagicMock() - sys.modules['torch_npu'] = MockTorchNPU() + sys.modules["vllm"] = MockVLLM() + sys.modules["vllm.logger"] = MockVLLM.Logger() + sys.modules["vllm.model_executor"] = MockVLLM.ModelExecutor() + sys.modules["vllm.model_executor.layers"] = MockVLLM.ModelExecutor.Layers() + sys.modules["vllm.model_executor.layers.batch_invariant"] = MockVLLM.ModelExecutor.Layers.BatchInvariant() + sys.modules["vllm.triton_utils"] = MockVLLM.TritonUtils() + sys.modules["torch"] = MagicMock() + sys.modules["torch_npu"] = MockTorchNPU() # Import the module after mocking global batch_invariant @@ -85,7 +88,7 @@ class TestBatchInvariant: def test_override_envs_for_invariance(self): """Test environment variable override""" # Clear environment variables - env_vars = ['VLLM_ASCEND_ENABLE_NZ', 'HCCL_DETERMINISTIC', 'LCCL_DETERMINISTIC'] + env_vars = ["VLLM_ASCEND_ENABLE_NZ", "HCCL_DETERMINISTIC", "LCCL_DETERMINISTIC"] for var in env_vars: if var in os.environ: del os.environ[var] @@ -94,77 +97,76 @@ def test_override_envs_for_invariance(self): batch_invariant.override_envs_for_invariance() # Verify environment variables - assert os.environ['VLLM_ASCEND_ENABLE_NZ'] == '0' - assert os.environ['HCCL_DETERMINISTIC'] == 'strict' - assert os.environ['LCCL_DETERMINISTIC'] == '1' - - @pytest.mark.parametrize("custom_ops_available, expected_value", [ - (True, True), - (False, False) - ]) + assert os.environ["VLLM_ASCEND_ENABLE_NZ"] == "0" + assert os.environ["HCCL_DETERMINISTIC"] == "strict" + assert os.environ["LCCL_DETERMINISTIC"] == "1" + + @pytest.mark.parametrize("custom_ops_available, expected_value", [(True, True), (False, False)]) def test_has_ascendc_batch_invariant(self, custom_ops_available, expected_value): """Test HAS_ASCENDC_BATCH_INVARIANT detection""" # Control custom_ops availability if custom_ops_available: - sys.modules['custom_ops'] = MagicMock() + sys.modules["batch_invariant_ops"] = MagicMock() else: - sys.modules.pop('custom_ops', None) - + sys.modules.pop("batch_invariant_ops", None) + # Reload module to 
re-evaluate the flag importlib.reload(batch_invariant) - + # Verify result assert batch_invariant.HAS_ASCENDC_BATCH_INVARIANT == expected_value - @patch('vllm_ascend.batch_invariant.HAS_TRITON', False) - @patch('vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT', True) + @patch("vllm_ascend.batch_invariant.HAS_TRITON", False) + @patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", True) def test_enable_batch_invariant_mode_ascendc_path(self): """Test enable_batch_invariant_mode with AscendC ops available""" # Mock dependencies mock_library = MagicMock() batch_invariant.torch.library.Library = MagicMock(return_value=mock_library) - batch_invariant.torch.ops.myops = MagicMock() - + batch_invariant.torch.ops.batch_invariant_ops = MagicMock() + # Call function batch_invariant.enable_batch_invariant_mode() - + # Verify library created batch_invariant.torch.library.Library.assert_called_once_with("aten", "IMPL") - + # Verify operator registrations assert mock_library.impl.call_count == 3 mock_library.impl.assert_any_call( - "aten::mm", batch_invariant.torch.ops.myops.npu_mm_batch_invariant, "NPU" + "aten::mm", batch_invariant.torch.ops.batch_invariant_ops.npu_mm_batch_invariant, "NPU" ) mock_library.impl.assert_any_call( - "aten::matmul", batch_invariant.torch.ops.myops.npu_matmul_batch_invariant, "NPU" + "aten::matmul", batch_invariant.torch.ops.batch_invariant_ops.npu_matmul_batch_invariant, "NPU" ) mock_library.impl.assert_any_call( - "aten::sum", batch_invariant.torch.ops.myops.npu_reduce_sum_batch_invariant, "NPU" + "aten::sum", batch_invariant.torch.ops.batch_invariant_ops.npu_reduce_sum_batch_invariant, "NPU" ) - + # Verify torch_npu function patching - assert batch_invariant.torch_npu.npu_fused_infer_attention_score == \ - batch_invariant.torch.ops.myops.npu_fused_infer_attention_score_batch_invariant + assert ( + batch_invariant.torch_npu.npu_fused_infer_attention_score + == batch_invariant.torch.ops.batch_invariant_ops.npu_fused_infer_attention_score_batch_invariant + ) - @patch('vllm_ascend.batch_invariant.HAS_TRITON', True) - @patch('vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT', False) + @patch("vllm_ascend.batch_invariant.HAS_TRITON", True) + @patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", False) def test_enable_batch_invariant_mode_triton_path(self): """Test enable_batch_invariant_mode with only Triton available""" # Mock dependencies mock_library = MagicMock() batch_invariant.torch.library.Library = MagicMock(return_value=mock_library) - + # Mock triton imports batch_invariant.addmm_batch_invariant = MagicMock() batch_invariant.bmm_batch_invariant = MagicMock() batch_invariant.mm_batch_invariant = MagicMock() batch_invariant.matmul_batch_invariant = MagicMock() batch_invariant.linear_batch_invariant = MagicMock() - + # Call function batch_invariant.enable_batch_invariant_mode() - + # Verify operator registrations assert mock_library.impl.call_count == 5 mock_library.impl.assert_any_call("aten::addmm", batch_invariant.addmm_batch_invariant, "NPU") @@ -173,26 +175,24 @@ def test_enable_batch_invariant_mode_triton_path(self): mock_library.impl.assert_any_call("aten::matmul", batch_invariant.matmul_batch_invariant, "NPU") mock_library.impl.assert_any_call("aten::linear", batch_invariant.linear_batch_invariant, "NPU") - @patch('vllm_ascend.batch_invariant.HAS_TRITON', False) - @patch('vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT', False) + @patch("vllm_ascend.batch_invariant.HAS_TRITON", False) + 
@patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", False)
     def test_enable_batch_invariant_mode_no_backend(self):
         """Test enable_batch_invariant_mode with no backends available"""
         # Mock library
         mock_library = MagicMock()
         batch_invariant.torch.library.Library = MagicMock(return_value=mock_library)
-        
+
         # Call function
         batch_invariant.enable_batch_invariant_mode()
-        
+
         # Verify no operators registered
         mock_library.impl.assert_not_called()
 
-    @pytest.mark.parametrize("batch_invariant_enabled, has_backend, expected_logger_call", [
-        (True, True, 'info'),
-        (True, False, 'warning'),
-        (False, True, None),
-        (False, False, None)
-    ])
+    @pytest.mark.parametrize(
+        "batch_invariant_enabled, has_backend, expected_logger_call",
+        [(True, True, "info"), (True, False, "warning"), (False, True, None), (False, False, None)],
+    )
     def test_init_batch_invariance(self, batch_invariant_enabled, has_backend, expected_logger_call):
         """Test init_batch_invariance under different conditions"""
         # Mock dependencies
@@ -202,10 +202,10 @@ def test_init_batch_invariance(self, batch_invariant_enabled, has_backend, expec
         batch_invariant.override_envs_for_invariance = MagicMock()
         batch_invariant.enable_batch_invariant_mode = MagicMock()
         logger = batch_invariant.logger
-        
+
         # Call function
         batch_invariant.init_batch_invariance()
-        
+
         # Verify function calls based on conditions
         if batch_invariant_enabled and has_backend:
             batch_invariant.override_envs_for_invariance.assert_called_once()
@@ -228,4 +228,4 @@
 
 
 if __name__ == "__main__":
-    pytest.main([__file__])
\ No newline at end of file
+    pytest.main([__file__])
diff --git a/vllm_ascend/batch_invariant.py b/vllm_ascend/batch_invariant.py
index db2d6f92bfd..38b6ad7d757 100644
--- a/vllm_ascend/batch_invariant.py
+++ b/vllm_ascend/batch_invariant.py
@@ -37,7 +37,7 @@
 
 
 try:
-    import custom_ops  # noqa
+    import batch_invariant_ops  # type: ignore[import-not-found] # noqa
 
     HAS_ASCENDC_BATCH_INVARIANT = True
 except ImportError:
@@ -68,12 +68,14 @@ def enable_batch_invariant_mode():
 
     # Prefer the AscendC batch-invariant implementations when available.
     if HAS_ASCENDC_BATCH_INVARIANT:
-        _batch_invariant_LIB.impl("aten::mm", torch.ops.myops.npu_mm_batch_invariant, "NPU")
-        _batch_invariant_LIB.impl("aten::matmul", torch.ops.myops.npu_matmul_batch_invariant, "NPU")
-        _batch_invariant_LIB.impl("aten::sum", torch.ops.myops.npu_reduce_sum_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::mm", torch.ops.batch_invariant_ops.npu_mm_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::matmul", torch.ops.batch_invariant_ops.npu_matmul_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::sum", torch.ops.batch_invariant_ops.npu_reduce_sum_batch_invariant, "NPU")
         # torch_npu.npu_fused_infer_attention_score is a plain torch_npu function, not a torch.ops operator,
         # so we need to patch it directly.
-        torch_npu.npu_fused_infer_attention_score = torch.ops.myops.npu_fused_infer_attention_score_batch_invariant
+        torch_npu.npu_fused_infer_attention_score = (
+            torch.ops.batch_invariant_ops.npu_fused_infer_attention_score_batch_invariant
+        )
 
     # Register the Triton implementations if AscendC is not available.
    elif HAS_TRITON:

From 165ab49864ad0662f1f7f6515f770b7e3c738d92 Mon Sep 17 00:00:00 2001
From: Ronald1995
Date: Tue, 10 Feb 2026 10:15:19 +0800
Subject: [PATCH 7/7] fix vllm.logger import error

Signed-off-by: Ronald1995
---
 .../batch_invariant/test_batch_invariant.py   | 70 +------------------
 1 file changed, 1 insertion(+), 69 deletions(-)

diff --git a/tests/ut/batch_invariant/test_batch_invariant.py b/tests/ut/batch_invariant/test_batch_invariant.py
index db6af966cbe..55fe02e9fae 100644
--- a/tests/ut/batch_invariant/test_batch_invariant.py
+++ b/tests/ut/batch_invariant/test_batch_invariant.py
@@ -21,65 +21,7 @@
 
 import pytest
 
-
-# Mock the vllm and torch_npu modules to isolate the tests
-class MockVLLM:
-    class Logger:
-        @staticmethod
-        def init_logger(name):
-            logger = MagicMock()
-            logger.info = MagicMock()
-            logger.warning = MagicMock()
-            return logger
-
-    class ModelExecutor:
-        class Layers:
-            class BatchInvariant:
-                @staticmethod
-                def vllm_is_batch_invariant():
-                    return False
-
-    class TritonUtils:
-        HAS_TRITON = False
-
-
-class MockTorchNPU:
-    npu_fused_infer_attention_score = MagicMock()
-
-
-# Set up mock modules
-@pytest.fixture(autouse=True)
-def setup_mocks():
-    # Save original modules
-    original_modules = {
-        "vllm": sys.modules.get("vllm"),
-        "torch": sys.modules.get("torch"),
-        "torch_npu": sys.modules.get("torch_npu"),
-        "custom_ops": sys.modules.get("custom_ops"),
-    }
-
-    # Create mock modules
-    sys.modules["vllm"] = MockVLLM()
-    sys.modules["vllm.logger"] = MockVLLM.Logger()
-    sys.modules["vllm.model_executor"] = MockVLLM.ModelExecutor()
-    sys.modules["vllm.model_executor.layers"] = MockVLLM.ModelExecutor.Layers()
-    sys.modules["vllm.model_executor.layers.batch_invariant"] = MockVLLM.ModelExecutor.Layers.BatchInvariant()
-    sys.modules["vllm.triton_utils"] = MockVLLM.TritonUtils()
-    sys.modules["torch"] = MagicMock()
-    sys.modules["torch_npu"] = MockTorchNPU()
-
-    # Import the module after mocking
-    global batch_invariant
-    import vllm_ascend.batch_invariant as batch_invariant
-
-    yield
-
-    # Restore original modules
-    for name, module in original_modules.items():
-        if module is not None:
-            sys.modules[name] = module
-        else:
-            sys.modules.pop(name, None)
+import vllm_ascend.batch_invariant as batch_invariant
 
 
 class TestBatchInvariant:
@@ -201,7 +143,6 @@ def test_init_batch_invariance(self, batch_invariant_enabled, has_backend, expec
         batch_invariant.HAS_ASCENDC_BATCH_INVARIANT = has_backend
         batch_invariant.override_envs_for_invariance = MagicMock()
         batch_invariant.enable_batch_invariant_mode = MagicMock()
-        logger = batch_invariant.logger
 
         # Call function
         batch_invariant.init_batch_invariance()
@@ -210,21 +151,12 @@ def test_init_batch_invariance(self, batch_invariant_enabled, has_backend, expec
         if batch_invariant_enabled and has_backend:
             batch_invariant.override_envs_for_invariance.assert_called_once()
             batch_invariant.enable_batch_invariant_mode.assert_called_once()
-            logger.info.assert_called_once_with("Enabling batch-invariant mode for vLLM on Ascend NPU.")
         elif batch_invariant_enabled and not has_backend:
             batch_invariant.override_envs_for_invariance.assert_not_called()
             batch_invariant.enable_batch_invariant_mode.assert_not_called()
-            logger.warning.assert_called_once_with(
-                "Batch-invariant mode requested but Triton or AscendC batch-invariant "
-                "ops are not available; skipping batch-invariant initialization."
-            )
         else:
             batch_invariant.override_envs_for_invariance.assert_not_called()
             batch_invariant.enable_batch_invariant_mode.assert_not_called()
-            logger.warning.assert_called_once_with(
-                "Batch-invariant mode requested but Triton or AscendC batch-invariant "
-                "ops are not available; skipping batch-invariant initialization."
-            )
 
 
 if __name__ == "__main__":
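
---
Reviewer note: a minimal end-to-end sketch of how this series is meant to be
exercised, under stated assumptions rather than as part of the patches. It
assumes the upstream VLLM_BATCH_INVARIANT environment variable that
vllm_is_batch_invariant() reads, an Ascend NPU device, and either the optional
AscendC batch_invariant_ops package or Triton being installed; the shapes and
dtype are illustrative only.

    import os

    # Assumption: upstream vLLM gates batch-invariant mode on this variable.
    os.environ["VLLM_BATCH_INVARIANT"] = "1"

    import torch
    import torch_npu  # noqa  # registers the "npu" device backend

    from vllm_ascend.batch_invariant import init_batch_invariance

    # Registers the AscendC batch-invariant kernels when batch_invariant_ops
    # is importable, otherwise the Triton fallbacks, and forces deterministic
    # HCCL/LCCL settings via override_envs_for_invariance().
    init_batch_invariance()

    a = torch.randn(8, 64, device="npu", dtype=torch.float16)
    w = torch.randn(64, 32, device="npu", dtype=torch.float16)

    # Batch invariance: row 0 computed as part of a batch of 8 should be
    # bitwise identical to row 0 computed with batch size 1.
    full = a @ w
    single = a[:1] @ w
    assert torch.equal(full[:1], single)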