Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import importlib
import os
from unittest.mock import MagicMock, patch

import torch
from torch.distributed import ProcessGroup, ReduceOp
from vllm.config import ParallelConfig

from tests.ut.base import TestBase


class TestPatchPlatformDistributed(TestBase):
    """Unit tests for the HCCL-backed ParallelConfig patches.

    The patches in ``patch_distributed`` are applied at import time and
    only when ``VLLM_ASCEND_ENABLE_HCCL_ALLREDUCE`` is set, so each test
    first patches the environment and then reloads the patch module.
    """

    @staticmethod
    def _reload_patch_module():
        # The monkey patches take effect at import time, so the module must
        # be reloaded after VLLM_ASCEND_ENABLE_HCCL_ALLREDUCE is patched in.
        import vllm_ascend.patch.platform.patch_common.patch_distributed
        importlib.reload(
            vllm_ascend.patch.platform.patch_common.patch_distributed)

    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_HCCL_ALLREDUCE": "1"})
    @patch(
        "vllm.distributed.utils.stateless_init_torch_distributed_process_group"
    )
    def test_ascend_stateless_init_dp_group_called_when_optimized(
            self, mock_init_process_group):
        self._reload_patch_module()

        parallel_cfg = ParallelConfig()
        parallel_cfg.data_parallel_master_ip = "127.0.0.1"
        parallel_cfg.data_parallel_rank = 0
        parallel_cfg.data_parallel_size = 2

        fake_port = 12345
        parallel_cfg.get_next_dp_init_port = MagicMock(
            return_value=fake_port)

        fake_group = MagicMock(spec=ProcessGroup)
        mock_init_process_group.return_value = fake_group

        # The patched method must delegate to the stateless init helper with
        # the hccl backend and return its process group unchanged.
        self.assertIs(parallel_cfg.stateless_init_dp_group(), fake_group)

        mock_init_process_group.assert_called_once_with("127.0.0.1",
                                                        fake_port,
                                                        0,
                                                        2,
                                                        backend="hccl")
        parallel_cfg.get_next_dp_init_port.assert_called_once()

    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_HCCL_ALLREDUCE": "1"})
    @patch("torch.distributed.all_reduce")
    @patch("torch.tensor")
    def test_ascend_has_unfinished_dp_when_optimized2(self, mock_tensor,
                                                      mock_all_reduce):
        self._reload_patch_module()

        fake_tensor = MagicMock()
        fake_tensor.dtype = torch.int32
        fake_tensor.device.type = "npu"
        fake_tensor.item.return_value = 1
        mock_tensor.return_value = fake_tensor
        mock_all_reduce.return_value = None

        parallel_cfg = ParallelConfig()
        fake_group = MagicMock(spec=ProcessGroup)

        self.assertTrue(
            parallel_cfg.has_unfinished_dp(fake_group, has_unfinished=True))

        # The flag must be materialized as an int32 tensor on the NPU and
        # reduced with MAX (the integer equivalent of logical OR) over the
        # given process group.
        mock_tensor.assert_called_once_with([True],
                                            dtype=torch.int32,
                                            device="npu")
        mock_all_reduce.assert_called_once()
        _, reduce_kwargs = mock_all_reduce.call_args
        self.assertEqual(reduce_kwargs["op"], ReduceOp.MAX)
        self.assertEqual(reduce_kwargs["group"], fake_group)
4 changes: 4 additions & 0 deletions vllm_ascend/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@
# 1: enable moe all2all seq.
"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ":
lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ', '0'))),
# Whether to enable hccl allreduce in has_unfinished_dp function.
# this feature is supported in A3, and will get better performance.
"VLLM_ASCEND_ENABLE_HCCL_ALLREDUCE":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_HCCL_ALLREDUCE", '0'))),
}

# end-env-vars-definition
Expand Down
22 changes: 22 additions & 0 deletions vllm_ascend/patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,28 @@
# Need a PR to vllm to support get port from environment.
# Future Plan:
# Remove those patch when vllm merged them
# 2. `vllm.config.ParallelConfig.stateless_init_dp_group`
# Why:
#     vLLM uses the gloo backend by default to initialize the stateless dp process group, but we want to use hccl here to
# get better performance
# How:
# adopt hccl backend to init process group.(Now use VLLM_ASCEND_ENABLE_HCCL_ALLREDUCE
# environment variable to enable this patch)
# Related PR (if no, explain why):
# Need a PR to vllm to support more backend.
# Future Plan:
# Remove those patch when vllm merged them
# 3. `vllm.config.ParallelConfig.has_unfinished_dp`
# Why:
#     vLLM uses the gloo backend by default to initialize the stateless dp process group, but we want to use hccl here to
# get better performance in has_unfinished_dp function
# How:
# adopt hccl backend to init process group.(Now use VLLM_ASCEND_ENABLE_HCCL_ALLREDUCE
# environment variable to enable this patch)
# Related PR (if no, explain why):
# Need a PR to vllm to support more backend.
# Future Plan:
# Remove those patch when vllm merged them
#
# * Worker Patch:
# ===============
Expand Down
30 changes: 30 additions & 0 deletions vllm_ascend/patch/platform/patch_common/patch_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@

import torch
import vllm.envs as envs
from torch.distributed import ProcessGroup, ReduceOp
from vllm.config import ParallelConfig
from vllm.distributed.utils import \
stateless_init_torch_distributed_process_group

from vllm_ascend import envs as ascend_envs
from vllm_ascend.utils import is_310p


Expand All @@ -41,7 +45,33 @@ def parallel_config_get_dp_port(self) -> int:
return port


def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
    """Initialize the stateless data-parallel process group over HCCL.

    Drop-in replacement for ``ParallelConfig.stateless_init_dp_group``:
    same arguments are forwarded to the vLLM helper, but the backend is
    forced to ``hccl`` instead of the upstream default.
    """
    return stateless_init_torch_distributed_process_group(
        self.data_parallel_master_ip,
        self.get_next_dp_init_port(),
        self.data_parallel_rank,
        self.data_parallel_size,
        backend="hccl")


def ascend_has_unfinished_dp(dp_group: "ProcessGroup",
                             has_unfinished: bool) -> bool:
    """Return True if any data-parallel rank still has unfinished requests.

    The local boolean is encoded as an int32 tensor on the NPU and combined
    across ranks with a MAX all-reduce, which over {0, 1} is exactly a
    logical OR:
        rank 0: has_unfinished=True, rank 1: has_unfinished=False
        -> aggregated result: True
    """
    flag = torch.tensor([has_unfinished], dtype=torch.int32, device="npu")
    torch.distributed.all_reduce(flag, op=ReduceOp.MAX, group=dp_group)
    return bool(flag.item())


# Always route DP init-port selection through the Ascend-aware helper.
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
# Swap in the HCCL-backed DP-group helpers only when explicitly enabled via
# the VLLM_ASCEND_ENABLE_HCCL_ALLREDUCE environment variable. Note this is
# evaluated once at import time, so the flag must be set before this module
# is (re)imported.
if ascend_envs.VLLM_ASCEND_ENABLE_HCCL_ALLREDUCE:
    ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group
    # Wrapped in staticmethod so the replacement is not bound to the
    # ParallelConfig instance — presumably matching the unbound call
    # convention of upstream has_unfinished_dp; verify against vllm.config.
    ParallelConfig.has_unfinished_dp = staticmethod(ascend_has_unfinished_dp)


class NullHandle:
Expand Down
2 changes: 2 additions & 0 deletions vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2172,6 +2172,8 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int):
else:
# Generate a new forward proxy code object to prevent the invalidation of
# compilation cache caused by dynamo retracing
# Detailed introduction is as follows:
# https://www.hiascend.com/document/detail/zh/Pytorch/710/modthirdparty/torchairuseguide/torchair_00016.html
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
forward_fn = self.model.forward
code = forward_fn.__code__
Expand Down
Loading