Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm_ascend/ascend_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(self, vllm_config):
False) # Whether to enable DeepSeek models' prefill optimizations
self.enable_cpu_binding = additional_config.get( # Whether to enable the cpu binding
"enable_cpu_binding", False)
self.lmhead_tp_size = additional_config.get("lmhead_tp_size", -1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's better that the default value is 1



class TorchairGraphConfig:
Expand Down
22 changes: 22 additions & 0 deletions vllm_ascend/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,26 @@

# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
_LMHEAD: Optional[GroupCoordinator] = None


def get_mc2_group() -> GroupCoordinator:
assert _MC2 is not None, ("mc2 group is not initialized")
return _MC2


def get_lmhead_group() -> GroupCoordinator:
assert _LMHEAD is not None, ("lmhead group is not initialized")
return _LMHEAD


def model_parallel_initialized():
return (_MC2 is not None)


def init_ascend_model_parallel(
expert_parallel_size: int = 1,
lm_head_tp_size: int = -1,
backend: Optional[str] = None,
):
if model_parallel_initialized():
Expand All @@ -41,9 +48,24 @@ def init_ascend_model_parallel(
backend,
group_name="mc2")

if lm_head_tp_size > 0:
all_ranks = torch.arange(world_size).reshape(-1, lm_head_tp_size)
global _LMHEAD
group_ranks = all_ranks.unbind(0)
group_ranks = [x.tolist() for x in group_ranks]

_LMHEAD = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="lmhead")


def destroy_ascend_model_parallel():
global _MC2
if _MC2:
_MC2.destroy()
_MC2 = None
global _LMHEAD
if _LMHEAD:
_LMHEAD.destroy()
_LMHEAD = None
37 changes: 37 additions & 0 deletions vllm_ascend/distributed/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from vllm.distributed.parallel_state import (
get_dp_group, get_tensor_model_parallel_world_size)

from vllm_ascend.distributed.parallel_state import get_lmhead_group


def is_lmhead_tp():
# We only activate optimization of lmhead communication
# when tp_size == 1, dp_size > 1 and lmhead_tp_size > 1.

try:
get_lmhead_group()
except AssertionError:
return False

tp_size = get_tensor_model_parallel_world_size()
dp_size = get_dp_group().world_size
lmhead_tp_size = get_lmhead_group().world_size

return tp_size == 1 and dp_size > 1 and lmhead_tp_size > 1
18 changes: 10 additions & 8 deletions vllm_ascend/models/deepseek_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,20 @@
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.layers.vocab_parallel_embedding import \
VocabParallelEmbedding
from vllm.model_executor.models.deepseek_mtp import (
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
SharedHead)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm_ascend.ops.lmhead import CustomParallelLMHead
from vllm_ascend.ops.logits_processor import CustomLogitsProcessor

from .deepseek_v2 import CustomDeepseekV2DecoderLayer


Expand All @@ -49,10 +51,10 @@ def __init__(self,
prefix: str = "") -> None:
nn.Module.__init__(self)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "head"))
self.head = CustomParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "head"))


class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
Expand Down Expand Up @@ -145,7 +147,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
]
self.logits_processor = LogitsProcessor(config.vocab_size)
self.logits_processor = CustomLogitsProcessor(config.vocab_size)

def forward(
self,
Expand Down
19 changes: 10 additions & 9 deletions vllm_ascend/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,11 @@
ReplicatedLinear,
RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.layers.vocab_parallel_embedding import \
VocabParallelEmbedding
from vllm.model_executor.models.deepseek_v2 import \
DeepseekV2ForCausalLM # noqa: E501
from vllm.model_executor.models.deepseek_v2 import \
Expand All @@ -69,6 +68,8 @@

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.ops.lmhead import CustomParallelLMHead
from vllm_ascend.ops.logits_processor import CustomLogitsProcessor
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor, npu_prefetch
Expand Down Expand Up @@ -844,14 +845,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
prefix=maybe_prefix(
prefix, "model"))
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
self.lm_head = CustomParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.logits_processor = CustomLogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
Expand Down
150 changes: 150 additions & 0 deletions vllm_ascend/ops/lmhead.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from vllm/model_executor/layers/lmhead.py
# This file is a part of the vllm-ascend project.

from typing import Optional

import torch
from torch.nn.parameter import Parameter
from vllm.distributed import divide
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod, VocabParallelEmbedding, pad_vocab_size)
from vllm.model_executor.utils import set_weight_attrs

from vllm_ascend.distributed.parallel_state import get_lmhead_group

DEFAULT_VOCAB_PADDING_SIZE = 64


class CustomParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.

Output logits weight matrices used in the Sampler. The weight and bias
tensors are padded to make sure they are divisible by the number of
model parallel GPUs.

Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
bias: whether to use bias.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
"""

def __init__(self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size, quant_config,
prefix)
# Keep the input dimensions.
tp_rank = get_lmhead_group().rank_in_group
self.tp_size = get_lmhead_group().world_size
self.num_embeddings = num_embeddings
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings,
self.padding_size)
assert self.org_vocab_size_padded <= self.num_embeddings_padded

self.shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size, tp_rank,
self.tp_size)
self.embedding_dim = embedding_dim

quant_method = None
if quant_config is not None:
quant_method = quant_config.get_quant_method(self, prefix=prefix)
if quant_method is None:
quant_method = UnquantizedEmbeddingMethod()

# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer = type(self) is VocabParallelEmbedding
quant_method_implements_embedding = method_has_implemented_embedding(
type(quant_method))
if is_embedding_layer and not quant_method_implements_embedding:
raise NotImplementedError(
f"The class {type(quant_method).__name__} must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod.")

self.quant_method: QuantizeMethodBase = quant_method

if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the vocaburaly dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
self.tp_size)
assert (self.shard_indices.num_elements_padded ==
self.num_embeddings_per_partition)
self.num_org_embeddings_per_partition = (
self.shard_indices.org_vocab_end_index -
self.shard_indices.org_vocab_start_index)
self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index -
self.shard_indices.added_vocab_start_index)

self.quant_method.create_weights(self,
self.embedding_dim,
[self.num_embeddings_per_partition],
self.embedding_dim,
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader)

self.quant_config = quant_config
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)

def tie_weights(self, embed_tokens: VocabParallelEmbedding):
"""Tie the weights with word embeddings."""
# GGUF quantized embed_tokens.
if self.quant_config and self.quant_config.get_name() == "gguf":
return embed_tokens
else:
self.weight = embed_tokens.weight
return self

def forward(self, input_):
del input_
raise RuntimeError("LMHead's weights should be used in the sampler.")
Loading