diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index e29c13ddcf2..20bdc98e3bc 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -65,6 +65,7 @@ def __init__(self, vllm_config):
             False)  # Whether to enable DeepSeek models' prefill optimizations
         self.enable_cpu_binding = additional_config.get(  # Whether to enable the cpu binding
             "enable_cpu_binding", False)
+        self.lmhead_tp_size = additional_config.get("lmhead_tp_size", -1)
 
 
 class TorchairGraphConfig:
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index 181231ed238..14fbdf08006 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -6,6 +6,7 @@
 
 # Currently, mc2 op need their own group coordinator.
 _MC2: Optional[GroupCoordinator] = None
+_LMHEAD: Optional[GroupCoordinator] = None
 
 
 def get_mc2_group() -> GroupCoordinator:
@@ -13,12 +14,18 @@ def get_mc2_group() -> GroupCoordinator:
     return _MC2
 
 
+def get_lmhead_group() -> GroupCoordinator:
+    assert _LMHEAD is not None, ("lmhead group is not initialized")
+    return _LMHEAD
+
+
 def model_parallel_initialized():
     return (_MC2 is not None)
 
 
 def init_ascend_model_parallel(
     expert_parallel_size: int = 1,
+    lm_head_tp_size: int = -1,
     backend: Optional[str] = None,
 ):
     if model_parallel_initialized():
@@ -41,9 +48,24 @@ def init_ascend_model_parallel(
                                      backend,
                                      group_name="mc2")
 
+    if lm_head_tp_size > 0:
+        all_ranks = torch.arange(world_size).reshape(-1, lm_head_tp_size)
+        global _LMHEAD
+        group_ranks = all_ranks.unbind(0)
+        group_ranks = [x.tolist() for x in group_ranks]
+
+        _LMHEAD = init_model_parallel_group(group_ranks,
+                                            get_world_group().local_rank,
+                                            backend,
+                                            group_name="lmhead")
+
 
 def destroy_ascend_model_parallel():
     global _MC2
     if _MC2:
         _MC2.destroy()
     _MC2 = None
+    global _LMHEAD
+    if _LMHEAD:
+        _LMHEAD.destroy()
+    _LMHEAD = None
diff --git a/vllm_ascend/distributed/utils.py b/vllm_ascend/distributed/utils.py
new file mode 100644
index 00000000000..6101482feac
--- /dev/null
+++ b/vllm_ascend/distributed/utils.py
@@ -0,0 +1,37 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+from vllm.distributed.parallel_state import (
+    get_dp_group, get_tensor_model_parallel_world_size)
+
+from vllm_ascend.distributed.parallel_state import get_lmhead_group
+
+
+def is_lmhead_tp():
+    # We only activate optimization of lmhead communication
+    # when tp_size == 1, dp_size > 1 and lmhead_tp_size > 1.
+
+    try:
+        get_lmhead_group()
+    except AssertionError:
+        return False
+
+    tp_size = get_tensor_model_parallel_world_size()
+    dp_size = get_dp_group().world_size
+    lmhead_tp_size = get_lmhead_group().world_size
+
+    return tp_size == 1 and dp_size > 1 and lmhead_tp_size > 1
\ No newline at end of file
diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py
index 8b6f8ccbe3f..6cba6822c3e 100644
--- a/vllm_ascend/models/deepseek_mtp.py
+++ b/vllm_ascend/models/deepseek_mtp.py
@@ -26,11 +26,10 @@
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
 from vllm.model_executor.models.deepseek_mtp import (
     DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
     SharedHead)
@@ -38,6 +37,9 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from vllm_ascend.ops.lmhead import CustomParallelLMHead
+from vllm_ascend.ops.logits_processor import CustomLogitsProcessor
+
 from .deepseek_v2 import CustomDeepseekV2DecoderLayer
 
 
@@ -49,10 +51,10 @@ def __init__(self,
                  prefix: str = "") -> None:
         nn.Module.__init__(self)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.head = ParallelLMHead(config.vocab_size,
-                                   config.hidden_size,
-                                   quant_config=quant_config,
-                                   prefix=maybe_prefix(prefix, "head"))
+        self.head = CustomParallelLMHead(config.vocab_size,
+                                         config.hidden_size,
+                                         quant_config=quant_config,
+                                         prefix=maybe_prefix(prefix, "head"))
 
 
 class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
@@ -145,7 +147,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         ]
-        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.logits_processor = CustomLogitsProcessor(config.vocab_size)
 
     def forward(
         self,
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 48cb0ee02c4..a390f5821fa 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -49,12 +49,11 @@
                                                ReplicatedLinear,
                                                RowParallelLinear,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
 from vllm.model_executor.models.deepseek_v2 import \
     DeepseekV2ForCausalLM  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import \
@@ -69,6 +68,8 @@
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.ops.lmhead import CustomParallelLMHead
+from vllm_ascend.ops.logits_processor import CustomLogitsProcessor
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor, npu_prefetch
@@ -844,14 +845,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                           prefix=maybe_prefix(
                                               prefix, "model"))
         if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config,
-                                          prefix=maybe_prefix(
-                                              prefix, "lm_head"))
+            self.lm_head = CustomParallelLMHead(config.vocab_size,
+                                                config.hidden_size,
+                                                quant_config=quant_config,
+                                                prefix=maybe_prefix(
+                                                    prefix, "lm_head"))
         else:
             self.lm_head = PPMissingLayer()
-        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.logits_processor = CustomLogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
diff --git a/vllm_ascend/ops/lmhead.py b/vllm_ascend/ops/lmhead.py
new file mode 100644
index 00000000000..a98aea2b99e
--- /dev/null
+++ b/vllm_ascend/ops/lmhead.py
@@ -0,0 +1,150 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from vllm/model_executor/layers/lmhead.py
+# This file is a part of the vllm-ascend project.
+
+from typing import Optional
+
+import torch
+from torch.nn.parameter import Parameter
+from vllm.distributed import divide
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    UnquantizedEmbeddingMethod, VocabParallelEmbedding, pad_vocab_size)
+from vllm.model_executor.utils import set_weight_attrs
+
+from vllm_ascend.distributed.parallel_state import get_lmhead_group
+
+DEFAULT_VOCAB_PADDING_SIZE = 64
+
+
+class CustomParallelLMHead(VocabParallelEmbedding):
+    """Parallelized LM head.
+
+    Output logits weight matrices used in the Sampler. The weight and bias
+    tensors are padded to make sure they are divisible by the number of
+    model parallel GPUs.
+
+    Args:
+        num_embeddings: vocabulary size.
+        embedding_dim: size of hidden state.
+        bias: whether to use bias.
+        params_dtype: type of the parameters.
+        org_num_embeddings: original vocabulary size (without LoRA).
+        padding_size: padding size for the vocabulary.
+    """
+
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 bias: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 org_num_embeddings: Optional[int] = None,
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__(num_embeddings, embedding_dim, params_dtype,
+                         org_num_embeddings, padding_size, quant_config,
+                         prefix)
+        # Keep the input dimensions.
+        tp_rank = get_lmhead_group().rank_in_group
+        self.tp_size = get_lmhead_group().world_size
+        self.num_embeddings = num_embeddings
+        self.padding_size = padding_size
+        self.org_vocab_size = org_num_embeddings or num_embeddings
+        num_added_embeddings = num_embeddings - self.org_vocab_size
+        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
+                                                    self.padding_size)
+        self.num_embeddings_padded = pad_vocab_size(
+            self.org_vocab_size_padded + num_added_embeddings,
+            self.padding_size)
+        assert self.org_vocab_size_padded <= self.num_embeddings_padded
+
+        self.shard_indices = self._get_indices(self.num_embeddings_padded,
+                                               self.org_vocab_size_padded,
+                                               self.num_embeddings,
+                                               self.org_vocab_size, tp_rank,
+                                               self.tp_size)
+        self.embedding_dim = embedding_dim
+
+        quant_method = None
+        if quant_config is not None:
+            quant_method = quant_config.get_quant_method(self, prefix=prefix)
+        if quant_method is None:
+            quant_method = UnquantizedEmbeddingMethod()
+
+        # If we are making an embedding layer, then our quantization linear
+        # method must implement the embedding operation. If we are another
+        # layer type like ParallelLMHead, this is not important.
+        is_embedding_layer = type(self) is VocabParallelEmbedding
+        quant_method_implements_embedding = method_has_implemented_embedding(
+            type(quant_method))
+        if is_embedding_layer and not quant_method_implements_embedding:
+            raise NotImplementedError(
+                f"The class {type(quant_method).__name__} must implement "
+                "the 'embedding' method, see UnquantizedEmbeddingMethod.")
+
+        self.quant_method: QuantizeMethodBase = quant_method
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        # Divide the weight matrix along the vocabulary dimension.
+        self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
+        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
+                                                   self.tp_size)
+        assert (self.shard_indices.num_elements_padded ==
+                self.num_embeddings_per_partition)
+        self.num_org_embeddings_per_partition = (
+            self.shard_indices.org_vocab_end_index -
+            self.shard_indices.org_vocab_start_index)
+        self.num_added_embeddings_per_partition = (
+            self.shard_indices.added_vocab_end_index -
+            self.shard_indices.added_vocab_start_index)
+
+        self.quant_method.create_weights(self,
+                                         self.embedding_dim,
+                                         [self.num_embeddings_per_partition],
+                                         self.embedding_dim,
+                                         self.num_embeddings_padded,
+                                         params_dtype=params_dtype,
+                                         weight_loader=self.weight_loader)
+
+        self.quant_config = quant_config
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.num_embeddings_per_partition,
+                            dtype=params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+
+    def tie_weights(self, embed_tokens: VocabParallelEmbedding):
+        """Tie the weights with word embeddings."""
+        # GGUF quantized embed_tokens.
+        if self.quant_config and self.quant_config.get_name() == "gguf":
+            return embed_tokens
+        else:
+            self.weight = embed_tokens.weight
+            return self
+
+    def forward(self, input_):
+        del input_
+        raise RuntimeError("LMHead's weights should be used in the sampler.")
diff --git a/vllm_ascend/ops/logits_processor.py b/vllm_ascend/ops/logits_processor.py
new file mode 100644
index 00000000000..6702f273489
--- /dev/null
+++ b/vllm_ascend/ops/logits_processor.py
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm-project/vllm/vllm/model_executor/layers/vocab_parallel_embedding.py
+#
+
+from typing import Optional
+
+import torch
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
+
+from vllm_ascend.distributed.parallel_state import get_lmhead_group
+from vllm_ascend.distributed.utils import is_lmhead_tp
+
+
+class CustomLogitsProcessor(LogitsProcessor):
+    """Process logits and apply logits processors from sampling metadata.
+
+    This layer does the following:
+    1. Gather logits from model hidden_states.
+    2. Scale logits if needed.
+    3. Apply logits processors (if any).
+    """
+
+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor],
+    ) -> Optional[torch.Tensor]:
+        # Get the logits for the next tokens.
+
+        if is_lmhead_tp():
+            hidden_states = get_lmhead_group().all_gather(hidden_states, 0)
+
+        logits = lm_head.quant_method.apply(lm_head,
+                                            hidden_states,
+                                            bias=embedding_bias)
+
+        if is_lmhead_tp():
+            logits = get_lmhead_group().all_to_all(logits)
+
+        # Gather logits for TP
+        logits = self._gather_logits(logits)
+
+        # Remove paddings in vocab (if any).
+        if logits is not None:
+            logits = logits[..., :self.org_vocab_size]
+        return logits
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index ad7b9838f0b..60b673362be 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -135,6 +135,28 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 parallel_config.world_size_across_dp //
                 parallel_config.expert_tensor_parallel_size)
 
+        # Check and configure lmhead_tp
+        if model_config is not None and "deepseek" not in model_config.hf_config.model_type:
+            if ascend_config.lmhead_tp_size > 0:
+                logger.warning(
+                    "LMHead TP only supports deepseek models for now. For other models,"
+                    " the lmhead parallelism is determined by tensor_parallel_size.")
+            parallel_config.lmhead_tp_size = -1
+        elif ascend_config.lmhead_tp_size <= 0:
+            # For the deepseek series, lmhead_tp is not needed in the default case,
+            # but we still need to initialize the communication group.
+            parallel_config.lmhead_tp_size = parallel_config.tensor_parallel_size
+        elif parallel_config.tensor_parallel_size > 1:
+            logger.warning(
+                "LMHead tp_size is set separately while tensor_parallel_size is larger than 1. "
+                "For better performance, we recommend setting LMHead tp_size only when "
+                "tensor_parallel_size equals 1. LMHead tp_size will be automatically changed "
+                "to tensor_parallel_size when tensor_parallel_size is larger than 1."
+            )
+            parallel_config.lmhead_tp_size = parallel_config.tensor_parallel_size
+        else:
+            parallel_config.lmhead_tp_size = ascend_config.lmhead_tp_size
+
         if model_config is None:
             logger.warning("Model config is missing. This may indicate "
                            "that we are running a test case")
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 0284a016c08..e8b8eecb570 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -80,6 +80,7 @@
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import \
     AscendCommonAttentionMetadata as CommonAttentionMetadata
+from vllm_ascend.distributed.utils import is_lmhead_tp
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.eplb_updator import EplbUpdator
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
@@ -1190,6 +1191,15 @@ def _process_reqs(
                 num_draft_tokens, cu_num_tokens)
             sample_indices = spec_decode_metadata.logits_indices
 
+        if is_lmhead_tp():
+            if not with_prefill:
+                padded_num_indices = padded_num_tokens_across_dp
+            else:
+                padded_num_indices = self.max_num_reqs
+            sample_indices = nn.functional.pad(
+                sample_indices,
+                (0, padded_num_indices - sample_indices.shape[0]))
+
         return (attn_metadata, hidden_states, spec_decode_metadata, positions,
                 total_num_scheduled_tokens, sample_indices, finished_sending,
                 finished_recving)
@@ -1393,11 +1403,17 @@ def execute_model(
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         if spec_decode_metadata is None:
+            if is_lmhead_tp():
+                logits = logits[:self.input_batch.num_reqs]
+
             sampler_output = self.sampler(
                 logits=logits,
                 sampling_metadata=sampling_metadata,
             )
         else:
+            if is_lmhead_tp():
+                logits = logits[:len(spec_decode_metadata.logits_indices)]
+
             # When indexing with a tensor (bonus_logits_indices), PyTorch
             # creates a new tensor with separate storage from the original
             # logits tensor. This means any in-place operations on bonus_logits
@@ -1769,6 +1785,20 @@ def _dummy_run(
                     intermediate_tensors=intermediate_tensors,
                     inputs_embeds=inputs_embeds,
                     **model_kwargs)
+
+            if not self.in_profile_run and is_lmhead_tp():
+                # lmhead_tp introduces additional communication across
+                # dp when computing logits, so the dummy run needs to
+                # cover that path as well.
+                if not with_prefill:
+                    padded_num_indices = num_tokens
+                else:
+                    padded_num_indices = max_num_reqs
+                dummy_indices = torch.zeros(padded_num_indices,
+                                            device=hidden_states.device,
+                                            dtype=torch.int32)
+                model.compute_logits(hidden_states[dummy_indices], None)
+
             if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
                 assert isinstance(self.drafter, MtpProposer)
                 self.drafter.dummy_run(
@@ -1778,6 +1808,16 @@
                     num_reqs=num_reqs,
                     num_tokens_across_dp=num_tokens_across_dp)
 
+            if not self.in_profile_run and is_lmhead_tp():
+                if not with_prefill:
+                    padded_num_indices = num_reqs
+                else:
+                    padded_num_indices = max_num_reqs
+                dummy_indices = torch.zeros(padded_num_indices,
+                                            device=hidden_states.device,
+                                            dtype=torch.int32)
+                model.compute_logits(hidden_states[dummy_indices], None)
+
             if self.in_profile_run and self.dynamic_eplb:
                 self.model.clear_all_moe_loads()
             if not is_torchair_compile and not self.in_profile_run and self.dynamic_eplb:
diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py
index ee8d7c5e308..4b3d00a6da0 100644
--- a/vllm_ascend/worker/mtp_proposer_v1.py
+++ b/vllm_ascend/worker/mtp_proposer_v1.py
@@ -18,6 +18,7 @@
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.utils import \
     AscendCommonAttentionMetadata as CommonAttentionMetadata
+from vllm_ascend.distributed.utils import is_lmhead_tp
 from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
 from vllm_ascend.utils import ProfileExecuteDuration
 
@@ -263,8 +264,22 @@ def propose(
                 previous_hidden_states=self.
                 hidden_states[:num_input_tokens],
                 kv_caches=self.runner.kv_caches[-1:])
 
+        num_indices = last_token_indices.shape[0]
+        if is_lmhead_tp():
+            if not self.runner.with_prefill:
+                padded_num_indices = num_input_tokens // self.runner.decode_token_per_req
+            else:
+                padded_num_indices = self.vllm_config.scheduler_config.max_num_seqs
+            last_token_indices = nn.functional.pad(
+                last_token_indices, (0, padded_num_indices - num_indices))
+
         sample_hidden_states = hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
+
+        if is_lmhead_tp() and num_indices < logits.shape[0]:
+            logits = logits[:num_indices]
+
         draft_token_ids = logits.argmax(dim=-1)  # [batch_size, 1]
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index eb410777023..7f014185808 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -302,7 +302,8 @@ def _init_worker_distributed_environment(self) -> None:
         ensure_model_parallel_initialized(
             self.parallel_config.tensor_parallel_size,
             self.parallel_config.pipeline_parallel_size)
-        init_ascend_model_parallel(self.parallel_config.expert_parallel_size)
+        init_ascend_model_parallel(self.parallel_config.expert_parallel_size,
+                                   self.parallel_config.lmhead_tp_size)
         ensure_kv_transfer_initialized(self.vllm_config)
 
     def _init_profiler(self):
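
Usage sketch (not part of the patch): the only knob introduced here is the "lmhead_tp_size" key read from additional_config; the surrounding arguments follow vLLM's standard offline LLM/EngineArgs interface, and the model name and parallel sizes below are illustrative assumptions rather than values taken from this change.

```python
# Hypothetical example: tp=1, dp=4, with a dedicated LM head TP group of 4 ranks.
# is_lmhead_tp() only activates the optimization when tp_size == 1, dp_size > 1
# and lmhead_tp_size > 1; for non-deepseek models platform.py resets the value
# to -1 and the extra group is not used.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",      # illustrative model choice
    tensor_parallel_size=1,
    data_parallel_size=4,
    additional_config={"lmhead_tp_size": 4},   # ranks grouped in contiguous blocks of 4
)
outputs = llm.generate("Hello, my name is")
```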