@@ -34,6 +34,7 @@ The following table lists the additional configuration options available in vLLM
| `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. |
| `kv_cache_dtype` | str | `None` | The kv cache dtype to use with kv cache quantization; currently only int8 is supported. |
| `enable_shared_expert_dp` | bool | `False` | Runs the shared expert in DP, which improves performance but consumes more memory. Currently only DeepSeek series models are supported. |
| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lm_head. Only supported in the pure DP scenario (`tensor_parallel_size=1`). |

The details of each config option are as follows:

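Since the table only names the option, here is a minimal usage sketch (not part of this PR) showing how it could be passed through `additional_config`; the model name and group size are illustrative, and `tensor_parallel_size` must stay at 1 because the option is restricted to the pure DP scenario:

```python
# Hypothetical offline-inference sketch; assumes the standard vLLM `additional_config`
# plumbing that vllm-ascend reads in AscendConfig. Values are illustrative only.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",   # illustrative model name
    tensor_parallel_size=1,                 # must be 1: pure DP scenario only
    additional_config={
        "lmhead_tensor_parallel_size": 2,   # lm_head gets its own TP group of size 2
    },
)
```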
44 changes: 44 additions & 0 deletions tests/ut/distributed/test_parallel_state.py
@@ -0,0 +1,44 @@
from unittest.mock import MagicMock, patch

import pytest
from vllm.config import ParallelConfig

from vllm_ascend.distributed.parallel_state import (
    _LMTP, _MC2, destroy_ascend_model_parallel, get_lmhead_tp_group,
    get_mc2_group, init_ascend_model_parallel)


@pytest.fixture
def parallel_config():
    return ParallelConfig(data_parallel_size=2,
                          tensor_parallel_size=2,
                          pipeline_parallel_size=2)


@pytest.fixture
def mock_distributed():
    with patch('torch.distributed.is_initialized', return_value=True), \
         patch('torch.distributed.get_world_size', return_value=8), \
         patch('torch.distributed.get_backend', return_value='nccl'), \
         patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group:
        mock_group.return_value.local_rank = 0
        mock_group.return_value.device_group = MagicMock()
        yield


def test_init_ascend_model_parallel(mock_distributed, parallel_config):
    mock_ascend_config = MagicMock()
    mock_ascend_config.lmhead_tensor_parallel_size = 2
    with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
         patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
        init_ascend_model_parallel(parallel_config)

        mc2_group = get_mc2_group()
        assert mc2_group is not None
        lmheadtp_group = get_lmhead_tp_group()
        assert lmheadtp_group is not None

        destroy_ascend_model_parallel()
        assert _MC2 is None
        assert _LMTP is None
17 changes: 16 additions & 1 deletion tests/ut/models/test_deepseek_mtp.py
@@ -31,6 +31,11 @@ def setup_mtp_layer(self, mocker: MockerFixture):
mocker_deepseek_v2_decode_layer = mocker.patch(
"vllm_ascend.models.deepseek_v2.CustomDeepseekV2DecoderLayer.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())

mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "", None)
mocker_deepseek_v2_decode_layer.assert_called_once()
@@ -83,6 +88,11 @@ def setup_predictor(self, mocker: MockerFixture):
mocker.patch(
"vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())

predictor = CustomDeepSeekMultiTokenPredictor(
vllm_config=mock_vllm_config)
@@ -157,6 +167,11 @@ def setup_mtp(self, mocker: MockerFixture):
return_value=None)
mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())

mtp = CustomDeepSeekMTP(vllm_config=vllm_config)
return mtp
@@ -177,4 +192,4 @@ def test_forward(self, mocker: MockerFixture, setup_mtp):
output = setup_mtp.forward(input_ids, positions, kv_caches, None,
previous_hidden_states, inputs_embeds,
spec_step_idx)
assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
29 changes: 28 additions & 1 deletion tests/ut/models/test_deepseek_v2.py
@@ -26,7 +26,7 @@
CustomDeepseekV2MLP, CustomDeepseekV2MoE,
CustomDeepseekV2RowParallelLinear,
CustomDeepseekV2RowParallelLinearReplaceAllreduce,
CustomDeepseekV2SiluAndMul)
CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)


@pytest.fixture
@@ -266,3 +266,30 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
kv_lora_rank=16,
prefix="layers.1.self_attn")
assert hasattr(attn, "q_proj")


def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
    # Create a simple config object
    class SimpleConfig:

        def __init__(self):
            self.vocab_size = 10000
            self.hidden_size = 128

    config = SimpleConfig()

    # Create the lm_head and logits processor directly
    lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
    logits_processor = LogitsProcessor(config.vocab_size)

    # Create mock outputs
    mock_output = torch.randn(2, 4, config.hidden_size)
    mock_logits = torch.randn(2, 4, config.vocab_size)

    # Test the logits processor directly
    with patch.object(lmhead.quant_method, "apply", return_value=mock_logits):
        with patch.object(logits_processor,
                          "_gather_logits",
                          return_value=mock_logits):
            logits = logits_processor(lmhead, mock_output)
            assert logits.shape == (2, 4, config.vocab_size)
62 changes: 59 additions & 3 deletions tests/ut/ops/test_vocab_parallel_embedding.py
@@ -18,8 +18,8 @@

import torch

from vllm_ascend.ops.vocab_parallel_embedding import \
AscendVocabParallelEmbedding
from vllm_ascend.ops.vocab_parallel_embedding import (
AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)

VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128

@@ -34,7 +34,11 @@ def setUp(self):

def _create_layer(self):
# Patch methods and dependencies for VocabParallelEmbedding
with patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
mock_group = MagicMock()
mock_group.world_size = 2
mock_group.rank_in_group = 0
with patch("vllm_ascend.ops.vocab_parallel_embedding.get_tp_group", return_value=mock_group), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y):
@@ -174,3 +178,55 @@ def test_output_shape(self):
# Call the forward method
output = layer.forward(input_)
self.assertEqual(output.shape, expected_shape)


class TestAscendLogitsProcessor(unittest.TestCase):

    def setUp(self):
        self.vocab_size = 50
        self.num_embeddings = 50
        self.embedding_dim = 10
        self.org_num_embeddings = 40
        self.padding_size = 8

        self.mock_group = MagicMock()
        self.mock_group.world_size = 2
        self.mock_group.rank_in_group = 0
        self.mock_ascend_config = MagicMock()
        self.mock_quant_method = MagicMock()
        self.mock_quant_method.apply = MagicMock(
            return_value=torch.randn(1, self.vocab_size))
        self.patches = [
            patch("vllm_ascend.ascend_config.get_ascend_config",
                  return_value=self.mock_ascend_config),
            patch(
                "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group",
                return_value=self.mock_group),
            patch("vllm_ascend.ops.vocab_parallel_embedding.lmhead_tp_enable",
                  return_value=True),
            patch(
                "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all",
                return_value=torch.randn(1, self.vocab_size))
        ]

        for p in self.patches:
            p.start()

    def tearDown(self):
        for p in self.patches:
            p.stop()

    def test_create_processor(self):
        processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
        self.assertEqual(processor.vocab_size, self.vocab_size)

    def test_get_logits(self):
        processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
        lmhead = AscendParallelLMHead(num_embeddings=self.num_embeddings,
                                      embedding_dim=self.embedding_dim,
                                      prefix="lm_head")
        lmhead.quant_method = self.mock_quant_method
        lmhead.quant_method.apply = self.mock_quant_method.apply
        hidden_state = torch.randn(1, self.org_num_embeddings)
        processor._get_logits(hidden_state, lmhead)
        self.mock_quant_method.apply.assert_called_once()
13 changes: 11 additions & 2 deletions tests/ut/test_ascend_config.py
@@ -16,7 +16,7 @@
import os

from transformers import PretrainedConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.config import ModelConfig, ParallelConfig, VllmConfig

from tests.ut.base import TestBase
from vllm_ascend.ascend_config import (_check_torchair_supported,
@@ -75,7 +75,7 @@ def test_init_ascend_config_with_additional_config(self):
"enabled": True
},
"expert_map_path": "test_expert_map_path",
"refresh": True
"refresh": True,
}
ascend_config = init_ascend_config(test_vllm_config)
self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path")
@@ -304,3 +304,12 @@ def test_ascend_config_load_error(self):
"refresh": True
}
init_ascend_config(test_vllm_config)

        with self.assertRaises(AssertionError):
            test_vllm_config.additional_config = {
                "lmhead_tensor_parallel_size": 2,
                "refresh": True
            }
            test_vllm_config.parallel_config = ParallelConfig(
                data_parallel_size=4, tensor_parallel_size=2)
            init_ascend_config(test_vllm_config)
4 changes: 2 additions & 2 deletions tests/ut/test_utils.py
@@ -289,13 +289,13 @@ def test_register_ascend_customop(self, mock_ascend_rmsnorm,
# ascend custom op is not registered
utils.register_ascend_customop()
# register_oot should be called 12 times
self.assertEqual(mock_customop.register_oot.call_count, 10)
self.assertEqual(mock_customop.register_oot.call_count, 12)
self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)

# ascend custom op is already registered
utils.register_ascend_customop()
# should not call register_oot again, so the count stays at 12 in this UT
self.assertEqual(mock_customop.register_oot.call_count, 10)
self.assertEqual(mock_customop.register_oot.call_count, 12)


class TestProfileExecuteDuration(TestBase):
17 changes: 16 additions & 1 deletion tests/ut/torchair/models/test_torchair_deepseek_mtp.py
@@ -31,6 +31,11 @@ def setup_mtp_layer(self, mocker: MockerFixture):
mocker_deepseek_v2_decode_layer = mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())

mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer(config, "", None)
mocker_deepseek_v2_decode_layer.assert_called_once()
@@ -83,6 +88,11 @@ def setup_predictor(self, mocker: MockerFixture):
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())

predictor = TorchairDeepSeekMultiTokenPredictor(
vllm_config=mock_vllm_config)
@@ -157,6 +167,11 @@ def setup_mtp(self, mocker: MockerFixture):
return_value=None)
mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())

mtp = TorchairDeepSeekMTP(vllm_config=vllm_config)
return mtp
@@ -177,4 +192,4 @@ def test_forward(self, mocker: MockerFixture, setup_mtp):
output = setup_mtp.forward(input_ids, positions, kv_caches, None,
previous_hidden_states, inputs_embeds,
spec_step_idx)
assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
10 changes: 10 additions & 0 deletions vllm_ascend/ascend_config.py
@@ -51,6 +51,16 @@ def __init__(self, vllm_config):
"enable_shared_expert_dp", False
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
self.enable_prefetch = additional_config.get("enable_prefetch", False)
        self.lmhead_tensor_parallel_size = additional_config.get(
            "lmhead_tensor_parallel_size", None)
        if self.lmhead_tensor_parallel_size is not None:
            logger.info(
                f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario"
            )
            if vllm_config.parallel_config.tensor_parallel_size != 1:
                raise AssertionError(
                    "lmhead_tensor_parallel_size is only supported in the pure DP scenario"
                )


class TorchairGraphConfig:
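For clarity, a standalone restatement (not from this PR) of the constraint enforced above; the parallel sizes are illustrative, and the helper name is hypothetical:

```python
# Simplified restatement of the check added in AscendConfig.__init__ (illustrative only).
def check_lmhead_tp(lmhead_tensor_parallel_size, tensor_parallel_size):
    if lmhead_tensor_parallel_size is not None and tensor_parallel_size != 1:
        raise AssertionError(
            "lmhead_tensor_parallel_size is only supported in the pure DP scenario")

check_lmhead_tp(2, tensor_parallel_size=1)       # accepted: pure DP (TP=1)
try:
    check_lmhead_tp(2, tensor_parallel_size=2)   # rejected: TP must stay 1, as the new UT expects
except AssertionError as e:
    print(e)
```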
31 changes: 31 additions & 0 deletions vllm_ascend/distributed/parallel_state.py
@@ -6,17 +6,26 @@
init_model_parallel_group)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config

# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
_MLP_TP: Optional[GroupCoordinator] = None

_LMTP: Optional[GroupCoordinator] = None


def get_mc2_group() -> GroupCoordinator:
    assert _MC2 is not None, ("mc2 group is not initialized")
    return _MC2


def get_lmhead_tp_group() -> GroupCoordinator:
    assert _LMTP is not None, (
        "lm head tensor parallel group is not initialized")
    return _LMTP


def get_mlp_tp_group() -> GroupCoordinator:
    assert _MLP_TP is not None, ("mlp group is not initialized")
    return _MLP_TP
@@ -65,6 +74,23 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="mlp_tp")

    lmhead_tensor_parallel_size = get_ascend_config(
    ).lmhead_tensor_parallel_size
    if lmhead_tensor_parallel_size is not None:
        group_ranks = []
        global _LMTP
        num_lmhead_tensor_parallel_groups: int = (world_size //
                                                  lmhead_tensor_parallel_size)
        for i in range(num_lmhead_tensor_parallel_groups):
            ranks = list(
                range(i * lmhead_tensor_parallel_size,
                      (i + 1) * lmhead_tensor_parallel_size))
            group_ranks.append(ranks)
        _LMTP = init_model_parallel_group(group_ranks,
                                          get_world_group().local_rank,
                                          backend,
                                          group_name="lmheadtp")


def get_mlp_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
@@ -86,3 +112,8 @@ def destroy_ascend_model_parallel():
    if _MLP_TP:
        _MLP_TP.destroy()
        _MLP_TP = None

    global _LMTP
    if _LMTP:
        _LMTP.destroy()
        _LMTP = None
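To make the grouping in `init_ascend_model_parallel` concrete, here is a standalone sketch of the same arithmetic (world size and group size are illustrative, not tied to any launcher):

```python
# Standalone sketch of the lm_head TP rank grouping built above (illustrative values).
world_size = 8
lmhead_tensor_parallel_size = 2

num_lmhead_tensor_parallel_groups = world_size // lmhead_tensor_parallel_size
group_ranks = [
    list(range(i * lmhead_tensor_parallel_size,
               (i + 1) * lmhead_tensor_parallel_size))
    for i in range(num_lmhead_tensor_parallel_groups)
]
print(group_ranks)  # [[0, 1], [2, 3], [4, 5], [6, 7]] -> one lm_head TP group per block
```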