From b8dff810c7c80d4c792309e19828fcb8808108eb Mon Sep 17 00:00:00 2001 From: zhangzihang Date: Fri, 8 Aug 2025 14:37:05 +0800 Subject: [PATCH 01/10] [Feat]: Add custom lmhead tensor model parallel in pure DP and graph-mode scenarios. Signed-off-by: zzhx1 --- vllm_ascend/ascend_config.py | 8 +- vllm_ascend/distributed/parallel_state.py | 28 ++ vllm_ascend/models/deepseek_v2.py | 15 +- vllm_ascend/ops/vocab_parallel_embedding.py | 268 +++++++++++++++++- vllm_ascend/platform.py | 6 + vllm_ascend/torchair/torchair_model_runner.py | 12 + vllm_ascend/worker/model_runner_v1.py | 19 ++ 7 files changed, 345 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 81cf177c4be..3d8ba1a1d05 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -51,7 +51,13 @@ def __init__(self, vllm_config): "enable_shared_expert_dp", False ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel self.enable_prefetch = additional_config.get("enable_prefetch", False) - + self.lmhead_tensor_parallel_size = additional_config.get( + "lmhead_tensor_parallel_size", None) + if self.lmhead_tensor_parallel_size is not None: + logger.info(f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario") + assert( + vllm_config.parallel_config.tensor_parallel_size == 1 + ),"lmhead_tensor_parallel_size is only supported in the pure DP scenario" class TorchairGraphConfig: """ diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index db1c5a88ae5..ef7894aab4e 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -11,11 +11,16 @@ _MC2: Optional[GroupCoordinator] = None _MLP_TP: Optional[GroupCoordinator] = None +_LMTP: Optional[GroupCoordinator] = None def get_mc2_group() -> GroupCoordinator: assert _MC2 is not None, ("mc2 group is not initialized") return _MC2 +def get_lmheadtp_group() -> GroupCoordinator: + assert _LMTP is not None, ( + "lm head tensor parallel group is not initialized") + return _LMTP def get_mlp_tp_group() -> GroupCoordinator: assert _MLP_TP is not None, ("mlp group is not initialized") @@ -64,6 +69,22 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): get_world_group().local_rank, backend, group_name="mlp_tp") + + lmhead_tensor_parallel_size = parallel_config.lmhead_tensor_parallel_size + if lmhead_tensor_parallel_size is not None: + group_ranks = [] + global _LMTP + num_lmhead_tensor_parallel_groups: int = (world_size // + lmhead_tensor_parallel_size) + for i in range(num_lmhead_tensor_parallel_groups): + ranks = list( + range(i * lmhead_tensor_parallel_size, + (i + 1) * lmhead_tensor_parallel_size)) + group_ranks.append(ranks) + _LMTP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="lmheadtp") def get_mlp_tensor_model_parallel_world_size(): @@ -75,6 +96,8 @@ def get_mlp_tensor_model_parallel_rank(): """Return world size for the tensor model parallel group.""" return get_mlp_tp_group().rank_in_group + + def destroy_ascend_model_parallel(): global _MC2 @@ -86,3 +109,8 @@ def destroy_ascend_model_parallel(): if _MLP_TP: _MLP_TP.destroy() _MLP_TP = None + + global _LMTP + if _LMTP: + _LMTP.destroy() + _LMTP = None diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 6d0913c18c9..b51130df405 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ 
b/vllm_ascend/models/deepseek_v2.py @@ -67,6 +67,7 @@ PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.sequence import IntermediateTensors +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ops.fused_moe import AscendFusedMoE @@ -74,6 +75,8 @@ from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import dispose_tensor +from vllm_ascend.utils import dispose_tensor, npu_prefetch +from vllm_ascend.ops.vocab_parallel_embedding import CustomParallelLMHead, CustomLogitsProcessor class CustomDeepseekV2SiluAndMul(SiluAndMul): @@ -874,14 +877,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix( prefix, "model")) if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = CustomParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) else: self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor(config.vocab_size) + self.logits_processor = CustomLogitsProcessor(config.vocab_size) self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py index 05b08a4d12c..64a6a9c0b88 100644 --- a/vllm_ascend/ops/vocab_parallel_embedding.py +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -15,12 +15,42 @@ # limitations under the License. # -from typing import Tuple +from typing import Optional, Tuple import torch -from vllm.distributed import tensor_model_parallel_all_reduce -from vllm.model_executor.layers.vocab_parallel_embedding import \ - VocabParallelEmbedding +from torch.nn import Module +import torch.distributed as dist +from torch.nn.parameter import Parameter, UninitializedParameter + +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce +) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, + DEFAULT_VOCAB_PADDING_SIZE, + pad_vocab_size, + UnquantizedEmbeddingMethod, + ParallelLMHead +) +from vllm.model_executor.layers.logits_processor import ( + LogitsProcessor, + _apply_logits_processors, + _prune_hidden_states +) +from vllm.model_executor.parameter import BasevLLMParameter +from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, + method_has_implemented_embedding +) + +from vllm_ascend.distributed.parallel_state import get_lmheadtp_group +from vllm_ascend.ascend_config import get_ascend_config class AscendVocabParallelEmbedding(VocabParallelEmbedding): @@ -71,3 +101,233 @@ def forward(self, input_): # Reduce across all the model parallel GPUs. output = tensor_model_parallel_all_reduce(output_parallel) return output + + +class CustomParallelLMHead(ParallelLMHead): + + """Costom Parallelized LM head, added the feature of lmheadTP in pure dp scenario + + Output logits weight matrices used in the Sampler. 
The weight and bias
+    tensors are padded to make sure they are divisible by the number of
+    model parallel GPUs.
+
+    Args:
+        num_embeddings: vocabulary size.
+        embedding_dim: size of hidden state.
+        bias: whether to use bias.
+        params_dtype: type of the parameters.
+        org_num_embeddings: original vocabulary size (without LoRA).
+        padding_size: padding size for the vocabulary.
+    """
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 bias: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 org_num_embeddings: Optional[int] = None,
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        Module.__init__(self)
+        self._enable_lmhead_tp = False
+        if get_ascend_config().lmhead_tensor_parallel_size is not None:
+            tp_rank = get_lmheadtp_group().rank_in_group
+            self.tp_size = get_lmheadtp_group().world_size
+            self._enable_lmhead_tp = True
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.num_embeddings = num_embeddings
+        self.padding_size = padding_size
+        self.org_vocab_size = org_num_embeddings or num_embeddings
+        num_added_embeddings = num_embeddings - self.org_vocab_size
+        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
+                                                    self.padding_size)
+        self.num_embeddings_padded = pad_vocab_size(
+            self.org_vocab_size_padded + num_added_embeddings,
+            self.padding_size)
+        assert self.org_vocab_size_padded <= self.num_embeddings_padded
+
+        self.shard_indices = self._get_indices(self.num_embeddings_padded,
+                                               self.org_vocab_size_padded,
+                                               self.num_embeddings,
+                                               self.org_vocab_size, tp_rank,
+                                               self.tp_size)
+        self.embedding_dim = embedding_dim
+
+        quant_method = None
+        if quant_config is not None:
+            quant_method = quant_config.get_quant_method(self, prefix=prefix)
+        if quant_method is None:
+            quant_method = UnquantizedEmbeddingMethod()
+
+        # If we are making an embedding layer, then our quantization linear
+        # method must implement the embedding operation. If we are another
+        # layer type like ParallelLMHead, this is not important.
+        is_embedding_layer = type(self) is VocabParallelEmbedding
+        quant_method_implements_embedding = method_has_implemented_embedding(
+            type(quant_method))
+        if is_embedding_layer and not quant_method_implements_embedding:
+            raise NotImplementedError(
+                f"The class {type(quant_method).__name__} must implement "
+                "the 'embedding' method, see UnquantizedEmbeddingMethod.")
+
+        self.quant_method: QuantizeMethodBase = quant_method
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        # Divide the weight matrix along the vocabulary dimension.
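# Illustrative aside (not part of the patch): a standalone sketch of the
# vocabulary-partition arithmetic the constructor above relies on. Helper
# names and the example numbers are made up; vLLM's DEFAULT_VOCAB_PADDING_SIZE
# is 64, and the real shard bookkeeping lives in
# VocabParallelEmbedding._get_indices.
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    # round up to the next multiple of pad_to
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

def vocab_range(num_embeddings_padded: int, tp_rank: int, tp_size: int):
    # contiguous, equally sized vocab slice owned by this rank
    per_partition = num_embeddings_padded // tp_size
    start = tp_rank * per_partition
    return start, start + per_partition

vocab = 129280                     # example vocabulary size
padded = pad_vocab_size(vocab)     # 129280 is already a multiple of 64
for rank in range(4):              # e.g. lmhead_tensor_parallel_size = 4
    print(rank, vocab_range(padded, rank, 4))
# 0 (0, 32320) / 1 (32320, 64640) / 2 (64640, 96960) / 3 (96960, 129280)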
+ self.num_added_embeddings = self.num_embeddings - self.org_vocab_size + self.num_embeddings_per_partition = divide(self.num_embeddings_padded, + self.tp_size) + assert (self.shard_indices.num_elements_padded == + self.num_embeddings_per_partition) + self.num_org_embeddings_per_partition = ( + self.shard_indices.org_vocab_end_index - + self.shard_indices.org_vocab_start_index) + self.num_added_embeddings_per_partition = ( + self.shard_indices.added_vocab_end_index - + self.shard_indices.added_vocab_start_index) + + self.quant_method.create_weights(self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader) + + self.quant_config = quant_config + if bias: + self.bias = Parameter( + torch.empty(self.num_embeddings_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + +class CustomLogitsProcessor(LogitsProcessor): + """Custom logits processor extending base LogitsProcessor functionality. + Added the feature of lmheadTP in pure dp scenario + """ + + def __init__(self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None) -> None: + super().__init__( + vocab_size=vocab_size, + org_vocab_size=org_vocab_size, + scale=scale, + logits_as_input=logits_as_input, + soft_cap=soft_cap + ) + + def forward( + self, + lm_head: CustomParallelLMHead, + hidden_states: torch.Tensor, + sampling_metadata: Optional[SamplingMetadata] = None, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + if self.logits_as_input: + logits = hidden_states + else: + if sampling_metadata is not None: + hidden_states = _prune_hidden_states(hidden_states, + sampling_metadata) + + # Get the logits for the next tokens. + logits = self._get_logits(hidden_states, lm_head, embedding_bias) + if logits is not None: + if self.soft_cap is not None: + logits = logits / self.soft_cap + logits = torch.tanh(logits) + logits = logits * self.soft_cap + + if self.scale != 1.0: + logits *= self.scale + + # Apply logits processors (if any). + if sampling_metadata is not None and \ + sampling_metadata.seq_groups is not None: + logits = _apply_logits_processors(logits, sampling_metadata) + + return logits + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: CustomParallelLMHead, + embedding_bias: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + """ + Compute logits for next token prediction using parallel processing. + + Args: + hidden_states: Current hidden states from the model with shape [batch_size, hidden_size] + lm_head: Parallel embedding layer for vocabulary predictions + embedding_bias: Optional bias tensor to add to logits with shape [vocab_size] + + Returns: + Logits tensor for next token prediction with shape [batch_size, vocab_size] or None + """ + + if lm_head._enable_lmhead_tp: + # _enable_lmhead_tp used in the graph mode,and batch_size is the same across different dp. 
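# Illustrative aside (not part of the patch): a single-process emulation of
# the data movement the branch below performs -- all_gather the per-rank
# batches, matmul against the local vocab shard, then all_to_all so each rank
# ends up with the full vocabulary for its own batch. Shapes and helper names
# are made up; the real code runs the collectives of the lm-head TP group.
import torch

def lmhead_tp_logits(hidden_per_rank, weight_shards):
    # hidden_per_rank: one [B, H] tensor per rank (B equal across ranks)
    # weight_shards:   one [V/tp, H] lm-head shard per rank
    tp = len(hidden_per_rank)
    B = hidden_per_rank[0].shape[0]
    gathered = torch.cat(hidden_per_rank, dim=0)       # all_gather -> [B*tp, H]
    local = [gathered @ w.t() for w in weight_shards]  # [B*tp, V/tp] per rank
    # all_to_all: rank r keeps its own batch rows from every peer's shard,
    # then concatenates along the vocab dimension -> [B, V]
    return [torch.cat([l[r * B:(r + 1) * B] for l in local], dim=-1)
            for r in range(tp)]

torch.manual_seed(0)
tp, B, H, V = 4, 2, 8, 16
W = torch.randn(V, H)
hs = [torch.randn(B, H) for _ in range(tp)]
out = lmhead_tp_logits(hs, list(W.chunk(tp, dim=0)))
assert torch.allclose(out[1], hs[1] @ W.t(), atol=1e-5)  # matches dense matmul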
+ lmhead_tp_size = lm_head.tp_size + local_batch_size = hidden_states.size(0) + vocab_size_per_partition = lm_head.num_embeddings_per_partition + + # Gather hidden states from all devices in tensor parallel group + gathered_hidden_states = get_lmheadtp_group().all_gather(hidden_states, dim=0) + else: + gathered_hidden_states = hidden_states + + # Compute logits using quantized matrix multiplication + local_logits = lm_head.quant_method.apply( + lm_head, + gathered_hidden_states, + bias=embedding_bias + ) + + if lm_head._enable_lmhead_tp: + # Prepare for all-to-all communication to redistribute logits + input_split_sizes = [local_batch_size] * get_lmheadtp_group().world_size + output_split_sizes = [local_batch_size * vocab_size_per_partition] * lmhead_tp_size + + all_to_all_output = torch.empty( + local_batch_size * (vocab_size_per_partition * lmhead_tp_size), + dtype=local_logits.dtype, + device='npu' + ) + + # Perform all-to-all communication to get correct logit partitions + dist.all_to_all_single( + all_to_all_output, + local_logits, + output_split_sizes, + input_split_sizes, + group=get_lmheadtp_group().device_group + ) + + # Reshape and combine logits from all partitions + reshaped_logits = all_to_all_output.view( + lmhead_tp_size, local_batch_size, vocab_size_per_partition + ) + combined_logits = reshaped_logits.permute(1, 0, 2).reshape(local_batch_size, -1) + else: + # Gather logits for tensor parallel + combined_logits = self._gather_logits(local_logits) + + # Remove paddings in vocab (if any) + if combined_logits is not None: + combined_logits = combined_logits[..., :self.org_vocab_size] + + return combined_logits diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index afc0e6b5458..9383a6afc98 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -131,6 +131,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if kv_cache_dtype is not None: vllm_config.cache_config.cache_dtype = kv_cache_dtype + if parallel_config: + # assign lmhead tensor parallel size + parallel_config.lmhead_tensor_parallel_size = ( + ascend_config.lmhead_tensor_parallel_size + ) + if model_config is None: logger.warning("Model config is missing. This may indicate " "that we are running a test case") diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 24fd33a1ea4..b37d6823d88 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -168,6 +168,18 @@ def _generate_dummy_run_hidden_states(self, with_prefill, hidden_states = super()._generate_dummy_run_hidden_states( with_prefill, is_torchair_compile, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, inputs_embeds) + if not self.in_profile_run and self._enable_lmhead_tp(): + # lmhead_tp introduces additional communication across + # dp when computing logits. Hence we need to add it + # in profile_run. 
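# Illustrative aside (not part of the patch): every rank in the lm-head TP
# group must enter the all_gather/all_to_all, even ranks that are only running
# a dummy batch, otherwise the collective deadlocks -- hence the placeholder
# logits pass sketched below. The fake model and sizes are made up.
import torch

class _FakeModel:
    def compute_logits(self, hidden_states, sampling_metadata):
        return hidden_states.sum(dim=-1)  # stand-in for lm-head + processor

def dummy_logits_pass(model, hidden_states, max_num_reqs_across_dp):
    # index with zeros of the agreed-upon length so shapes match on all ranks
    dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32)
    model.compute_logits(hidden_states[dummy_indices], None)

dummy_logits_pass(_FakeModel(), torch.randn(4, 16), max_num_reqs_across_dp=8)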
+ if not with_prefill: + max_num_reqs_across_dp = num_tokens + else: + max_num_reqs_across_dp = self.scheduler_config.max_num_seqs + dummy_indices = torch.zeros(max_num_reqs_across_dp, + device=hidden_states.device, + dtype=torch.int32) + self.model.compute_logits(hidden_states[dummy_indices], None) return hidden_states def _convert_torch_format(self, kv_cache): diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 6930a1ce789..7da504d4862 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2080,6 +2080,19 @@ def _dummy_run( assert aclgraph_runtime_mode == _cg_mode, ( f"Aclgraph runtime mode mismatch at dummy_run. " f"Expected {_cg_mode}, but got {aclgraph_runtime_mode}.") + + need_dummy_logits = (not self.in_profile_run and self._enable_lmhead_tp()) + + if need_dummy_logits: + # lmhead_tp introduces additional communication across + # dp when computing logits. Hence we need to add it + # in profile_run. + max_num_reqs_across_dp = num_tokens if not with_prefill else self.scheduler_config.max_num_seqs + dummy_indices = torch.zeros(max_num_reqs_across_dp, + device=hidden_states.device, + dtype=torch.int32) + def dummy_compute_logits(hidden_states): + return self.model.compute_logits(hidden_states[dummy_indices], None) with set_ascend_forward_context( attn_metadata, @@ -2097,6 +2110,9 @@ def _dummy_run( with_prefill, is_torchair_compile, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, inputs_embeds) + if need_dummy_logits: + dummy_compute_logits(hidden_states) + if self.speculative_config and self.speculative_config.method == "deepseek_mtp": assert isinstance(self.drafter, MtpProposer) self.drafter.dummy_run( @@ -2105,6 +2121,9 @@ def _dummy_run( skip_attn=True, num_reqs=num_reqs, num_tokens_across_dp=num_tokens_across_dp) + if need_dummy_logits: + dummy_compute_logits(hidden_states) + return hidden_states From c0993bcfbc185420f18914ca15fcbc6422164960 Mon Sep 17 00:00:00 2001 From: zzhx1 Date: Sun, 10 Aug 2025 17:28:43 +0800 Subject: [PATCH 02/10] pad tensor in compute_logits and adapting to the mixed deployment mode Signed-off-by: zzhx1 --- vllm_ascend/ops/vocab_parallel_embedding.py | 50 +++++---------------- vllm_ascend/utils.py | 6 +++ vllm_ascend/worker/model_runner_v1.py | 17 ++++++- 3 files changed, 32 insertions(+), 41 deletions(-) diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py index 64a6a9c0b88..7d458f8e310 100644 --- a/vllm_ascend/ops/vocab_parallel_embedding.py +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -41,7 +41,7 @@ _prune_hidden_states ) from vllm.model_executor.parameter import BasevLLMParameter -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.utils import set_weight_attrs, _enable_lmhead_tp from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, @@ -129,11 +129,10 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): Module.__init__(self) - self._enable_lmhead_tp = False - if get_ascend_config().lmhead_tensor_parallel_size is not None: + + if _enable_lmhead_tp(): tp_rank = get_lmheadtp_group().rank_in_group self.tp_size = get_lmheadtp_group().world_size - self._enable_lmhead_tp = True else: tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() @@ -279,12 +278,7 @@ def _get_logits( Logits tensor 
for next token prediction with shape [batch_size, vocab_size] or None """ - if lm_head._enable_lmhead_tp: - # _enable_lmhead_tp used in the graph mode,and batch_size is the same across different dp. - lmhead_tp_size = lm_head.tp_size - local_batch_size = hidden_states.size(0) - vocab_size_per_partition = lm_head.num_embeddings_per_partition - + if _enable_lmhead_tp(): # Gather hidden states from all devices in tensor parallel group gathered_hidden_states = get_lmheadtp_group().all_gather(hidden_states, dim=0) else: @@ -297,37 +291,15 @@ def _get_logits( bias=embedding_bias ) - if lm_head._enable_lmhead_tp: - # Prepare for all-to-all communication to redistribute logits - input_split_sizes = [local_batch_size] * get_lmheadtp_group().world_size - output_split_sizes = [local_batch_size * vocab_size_per_partition] * lmhead_tp_size - - all_to_all_output = torch.empty( - local_batch_size * (vocab_size_per_partition * lmhead_tp_size), - dtype=local_logits.dtype, - device='npu' - ) - - # Perform all-to-all communication to get correct logit partitions - dist.all_to_all_single( - all_to_all_output, - local_logits, - output_split_sizes, - input_split_sizes, - group=get_lmheadtp_group().device_group - ) - - # Reshape and combine logits from all partitions - reshaped_logits = all_to_all_output.view( - lmhead_tp_size, local_batch_size, vocab_size_per_partition - ) - combined_logits = reshaped_logits.permute(1, 0, 2).reshape(local_batch_size, -1) + if _enable_lmhead_tp(): + logits = get_lmheadtp_group().all_to_all(local_logits) else: # Gather logits for tensor parallel - combined_logits = self._gather_logits(local_logits) + logits = self._gather_logits(local_logits) # Remove paddings in vocab (if any) - if combined_logits is not None: - combined_logits = combined_logits[..., :self.org_vocab_size] + if logits is not None: + logits = logits[..., :self.org_vocab_size] - return combined_logits + return logits + diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index a99a491f393..438dd7cd3d7 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -547,3 +547,9 @@ def get_ascend_soc_version(): global _ascend_soc_version assert _ascend_soc_version is not None return _ascend_soc_version + + +def _enable_lmhead_tp() -> bool: + if get_ascend_config().lmhead_tensor_parallel_size is not None: + return True + return False diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 7da504d4862..828e54092a2 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -58,8 +58,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - LazyLoader, cdiv, is_pin_memory_available) -from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher + LazyLoader, cdiv, is_pin_memory_available, _enable_lmhead_tp) +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, @@ -1276,6 +1276,15 @@ def _prepare_inputs( spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices + + if _enable_lmhead_tp(): # + if not with_prefill: + max_num_reqs_across_dp = padded_num_tokens_across_dp + else: + max_num_reqs_across_dp = self.max_num_reqs + logits_indices = 
nn.functional.pad( + logits_indices, + (0, max_num_reqs_across_dp - logits_indices.shape[0])) return (attn_metadata, positions, num_scheduled_tokens, num_input_tokens, num_tokens_across_dp, @@ -1734,11 +1743,15 @@ def execute_model( # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: + if _enable_lmhead_tp(): + logits = logits[:self.input_batch.num_reqs] sampler_output = self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) else: + if _enable_lmhead_tp(): + logits = logits[:len(spec_decode_metadata.logits_indices)] # When indexing with a tensor (bonus_logits_indices), PyTorch # creates a new tensor with separate storage from the original # logits tensor. This means any in-place operations on bonus_logits From 684173392908e805730d0dad78547183eb9b72d9 Mon Sep 17 00:00:00 2001 From: zzhx1 Date: Sun, 10 Aug 2025 20:08:01 +0800 Subject: [PATCH 03/10] Fix bugs and CI and optimize codes Signed-off-by: zzhx1 --- tests/ut/distributed/test_parallel_state.py | 42 +++++ tests/ut/models/test_deepseek_v2.py | 2 +- tests/ut/test_ascend_config.py | 13 +- vllm_ascend/ascend_config.py | 12 +- vllm_ascend/distributed/parallel_state.py | 21 +-- vllm_ascend/models/deepseek_mtp.py | 18 ++- vllm_ascend/models/deepseek_v2.py | 12 +- vllm_ascend/ops/vocab_parallel_embedding.py | 145 ++++-------------- vllm_ascend/platform.py | 3 +- vllm_ascend/torchair/torchair_model_runner.py | 12 -- vllm_ascend/utils.py | 6 +- vllm_ascend/worker/model_runner_v1.py | 48 +++--- vllm_ascend/worker/mtp_proposer_v1.py | 14 +- 13 files changed, 160 insertions(+), 188 deletions(-) create mode 100644 tests/ut/distributed/test_parallel_state.py diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py new file mode 100644 index 00000000000..58e025ede46 --- /dev/null +++ b/tests/ut/distributed/test_parallel_state.py @@ -0,0 +1,42 @@ +from unittest.mock import MagicMock, patch + +import pytest +from vllm.config import ParallelConfig + +from vllm_ascend.distributed.parallel_state import ( + _LMTP, _MC2, destroy_ascend_model_parallel, get_lmhead_tp_group, + get_mc2_group, init_ascend_model_parallel) + + +@pytest.fixture +def parallel_config(): + return ParallelConfig(data_parallel_size=2, + tensor_parallel_size=2, + pipeline_parallel_size=2) + + +@pytest.fixture +def mock_distributed(): + with patch('torch.distributed.is_initialized', return_value=True), \ + patch('torch.distributed.get_world_size', return_value=8), \ + patch('torch.distributed.get_backend', return_value='nccl'), \ + patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group: + mock_group.return_value.local_rank = 0 + mock_group.return_value.device_group = MagicMock() + yield + + +def test_init_ascend_model_parallel(mock_distributed, parallel_config): + with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \ + patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'): + parallel_config.lmhead_tensor_parallel_size = 2 + init_ascend_model_parallel(parallel_config) + + mc2_group = get_mc2_group() + assert mc2_group is not None + lmheadtp_group = get_lmhead_tp_group() + assert lmheadtp_group is not None + + destroy_ascend_model_parallel() + assert _MC2 is None + assert _LMTP is None diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py index e0c50e81758..0b3fe9c8e5f 100644 --- a/tests/ut/models/test_deepseek_v2.py +++ 
b/tests/ut/models/test_deepseek_v2.py @@ -26,7 +26,7 @@ CustomDeepseekV2MLP, CustomDeepseekV2MoE, CustomDeepseekV2RowParallelLinear, CustomDeepseekV2RowParallelLinearReplaceAllreduce, - CustomDeepseekV2SiluAndMul) + CustomDeepseekV2SiluAndMul, CustomLogitsProcessor, CustomParallelLMHead) @pytest.fixture diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index 49b9abeabcf..622b751d31b 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -16,7 +16,7 @@ import os from transformers import PretrainedConfig -from vllm.config import ModelConfig, VllmConfig +from vllm.config import ModelConfig, ParallelConfig, VllmConfig from tests.ut.base import TestBase from vllm_ascend.ascend_config import (_check_torchair_supported, @@ -75,7 +75,7 @@ def test_init_ascend_config_with_additional_config(self): "enabled": True }, "expert_map_path": "test_expert_map_path", - "refresh": True + "refresh": True, } ascend_config = init_ascend_config(test_vllm_config) self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path") @@ -304,3 +304,12 @@ def test_ascend_config_load_error(self): "refresh": True } init_ascend_config(test_vllm_config) + + with self.assertRaises(AssertionError): + test_vllm_config.additional_config = { + "lmhead_tensor_parallel_size": 2, + "refresh": True + } + test_vllm_config.parallel_config = ParallelConfig( + data_parallel_size=4, tensor_parallel_size=2) + init_ascend_config(test_vllm_config) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 3d8ba1a1d05..2a2ac7bf642 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -54,10 +54,14 @@ def __init__(self, vllm_config): self.lmhead_tensor_parallel_size = additional_config.get( "lmhead_tensor_parallel_size", None) if self.lmhead_tensor_parallel_size is not None: - logger.info(f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario") - assert( - vllm_config.parallel_config.tensor_parallel_size == 1 - ),"lmhead_tensor_parallel_size is only supported in the pure DP scenario" + logger.info( + f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario" + ) + if vllm_config.parallel_config.tensor_parallel_size != 1: + raise AssertionError( + "lmhead_tensor_parallel_size is only supported in the pure DP scenario" + ) + class TorchairGraphConfig: """ diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index ef7894aab4e..f562ea88b21 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -13,15 +13,18 @@ _LMTP: Optional[GroupCoordinator] = None + def get_mc2_group() -> GroupCoordinator: assert _MC2 is not None, ("mc2 group is not initialized") return _MC2 -def get_lmheadtp_group() -> GroupCoordinator: + +def get_lmhead_tp_group() -> GroupCoordinator: assert _LMTP is not None, ( "lm head tensor parallel group is not initialized") return _LMTP + def get_mlp_tp_group() -> GroupCoordinator: assert _MLP_TP is not None, ("mlp group is not initialized") return _MLP_TP @@ -69,22 +72,22 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): get_world_group().local_rank, backend, group_name="mlp_tp") - + lmhead_tensor_parallel_size = parallel_config.lmhead_tensor_parallel_size if lmhead_tensor_parallel_size is not None: group_ranks = [] global _LMTP num_lmhead_tensor_parallel_groups: int = (world_size // - lmhead_tensor_parallel_size) + lmhead_tensor_parallel_size) 
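# Illustrative aside (not part of the patch): what the loop below builds.
# With 8 pure-DP ranks and lmhead_tensor_parallel_size = 4, the world is
# carved into two lm-head TP groups; the numbers are examples only.
def lmhead_group_ranks(world_size: int, lmhead_tp: int):
    assert world_size % lmhead_tp == 0, "world size must be divisible"
    return [list(range(i * lmhead_tp, (i + 1) * lmhead_tp))
            for i in range(world_size // lmhead_tp)]

assert lmhead_group_ranks(8, 4) == [[0, 1, 2, 3], [4, 5, 6, 7]]

# A hypothetical way to drive this from vLLM's offline API (flag plumbing may
# differ across versions; model name and sizes are examples):
#
#   from vllm import LLM
#   llm = LLM(model="deepseek-ai/DeepSeek-V3",
#             tensor_parallel_size=1,  # pure-DP requirement asserted above
#             data_parallel_size=4,
#             additional_config={"lmhead_tensor_parallel_size": 4})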
for i in range(num_lmhead_tensor_parallel_groups): ranks = list( range(i * lmhead_tensor_parallel_size, - (i + 1) * lmhead_tensor_parallel_size)) + (i + 1) * lmhead_tensor_parallel_size)) group_ranks.append(ranks) _LMTP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="lmheadtp") + get_world_group().local_rank, + backend, + group_name="lmheadtp") def get_mlp_tensor_model_parallel_world_size(): @@ -96,8 +99,6 @@ def get_mlp_tensor_model_parallel_rank(): """Return world size for the tensor model parallel group.""" return get_mlp_tp_group().rank_in_group - - def destroy_ascend_model_parallel(): global _MC2 @@ -109,7 +110,7 @@ def destroy_ascend_model_parallel(): if _MLP_TP: _MLP_TP.destroy() _MLP_TP = None - + global _LMTP if _LMTP: _LMTP.destroy() diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py index 8bcc4fbed30..f2784aa74c2 100644 --- a/vllm_ascend/models/deepseek_mtp.py +++ b/vllm_ascend/models/deepseek_mtp.py @@ -25,11 +25,10 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import \ + VocabParallelEmbedding from vllm.model_executor.models.deepseek_mtp import ( DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, SharedHead) @@ -37,6 +36,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm_ascend.ops.vocab_parallel_embedding import (CustomLogitsProcessor, + CustomParallelLMHead) + from .deepseek_v2 import CustomDeepseekV2DecoderLayer @@ -48,10 +50,10 @@ def __init__(self, prefix: str = "") -> None: nn.Module.__init__(self) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "head")) + self.head = CustomParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "head")) class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer): @@ -141,7 +143,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): for idx in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers) ] - self.logits_processor = LogitsProcessor(config.vocab_size) + self.logits_processor = CustomLogitsProcessor(config.vocab_size) def forward( self, diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index b51130df405..3dbb8fea686 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -48,12 +48,11 @@ ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod) -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) +from 
vllm.model_executor.layers.vocab_parallel_embedding import \ + VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.deepseek_v2 import \ @@ -67,16 +66,17 @@ PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.sequence import IntermediateTensors -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.ops.vocab_parallel_embedding import (CustomLogitsProcessor, + CustomParallelLMHead) from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import dispose_tensor from vllm_ascend.utils import dispose_tensor, npu_prefetch -from vllm_ascend.ops.vocab_parallel_embedding import CustomParallelLMHead, CustomLogitsProcessor + class CustomDeepseekV2SiluAndMul(SiluAndMul): @@ -881,7 +881,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size, quant_config=quant_config, prefix=maybe_prefix( - prefix, "lm_head")) + prefix, "lm_head")) else: self.lm_head = PPMissingLayer() self.logits_processor = CustomLogitsProcessor(config.vocab_size) diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py index 7d458f8e310..8dd0718d405 100644 --- a/vllm_ascend/ops/vocab_parallel_embedding.py +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -19,38 +19,20 @@ import torch from torch.nn import Module -import torch.distributed as dist -from torch.nn.parameter import Parameter, UninitializedParameter - -from vllm.distributed import ( - divide, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce -) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, - DEFAULT_VOCAB_PADDING_SIZE, - pad_vocab_size, - UnquantizedEmbeddingMethod, - ParallelLMHead -) -from vllm.model_executor.layers.logits_processor import ( - LogitsProcessor, - _apply_logits_processors, - _prune_hidden_states -) -from vllm.model_executor.parameter import BasevLLMParameter -from vllm.model_executor.utils import set_weight_attrs, _enable_lmhead_tp -from vllm.model_executor.sampling_metadata import SamplingMetadata +from torch.nn.parameter import Parameter +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, - QuantizeMethodBase, - method_has_implemented_embedding -) + QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod, + VocabParallelEmbedding, pad_vocab_size) +from vllm.model_executor.utils import set_weight_attrs -from vllm_ascend.distributed.parallel_state import get_lmheadtp_group -from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group +from vllm_ascend.utils import lmhead_tp_enable class AscendVocabParallelEmbedding(VocabParallelEmbedding): @@ -104,8 +86,7 @@ def 
forward(self, input_): class CustomParallelLMHead(ParallelLMHead): - - """Costom Parallelized LM head, added the feature of lmheadTP in pure dp scenario + """Custom Parallelized LM head, added the feature of lmheadTP in pure dp scenario Output logits weight matrices used in the Sampler. The weight and bias tensors are padded to make sure they are divisible by the number of @@ -119,6 +100,7 @@ class CustomParallelLMHead(ParallelLMHead): org_num_embeddings: original vocabulary size (without LoRA). padding_size: padding size for the vocabulary. """ + def __init__(self, num_embeddings: int, embedding_dim: int, @@ -127,16 +109,16 @@ def __init__(self, org_num_embeddings: Optional[int] = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = ""): Module.__init__(self) - if _enable_lmhead_tp(): - tp_rank = get_lmheadtp_group().rank_in_group - self.tp_size = get_lmheadtp_group().world_size + if lmhead_tp_enable(): + tp_rank = get_lmhead_tp_group().rank_in_group + self.tp_size = get_lmhead_tp_group().world_size else: tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() - + self.num_embeddings = num_embeddings self.padding_size = padding_size self.org_vocab_size = org_num_embeddings or num_embeddings @@ -196,7 +178,7 @@ def __init__(self, self.num_embeddings_padded, params_dtype=params_dtype, weight_loader=self.weight_loader) - + self.quant_config = quant_config if bias: self.bias = Parameter( @@ -208,91 +190,33 @@ def __init__(self, }) else: self.register_parameter("bias", None) - + + class CustomLogitsProcessor(LogitsProcessor): """Custom logits processor extending base LogitsProcessor functionality. Added the feature of lmheadTP in pure dp scenario """ - - def __init__(self, - vocab_size: int, - org_vocab_size: Optional[int] = None, - scale: float = 1.0, - logits_as_input: bool = False, - soft_cap: Optional[float] = None) -> None: - super().__init__( - vocab_size=vocab_size, - org_vocab_size=org_vocab_size, - scale=scale, - logits_as_input=logits_as_input, - soft_cap=soft_cap - ) - def forward( + def _get_logits( self, - lm_head: CustomParallelLMHead, hidden_states: torch.Tensor, - sampling_metadata: Optional[SamplingMetadata] = None, - embedding_bias: Optional[torch.Tensor] = None, + lm_head: CustomParallelLMHead, + embedding_bias: Optional[torch.Tensor], ) -> Optional[torch.Tensor]: - if self.logits_as_input: - logits = hidden_states - else: - if sampling_metadata is not None: - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) - - # Get the logits for the next tokens. - logits = self._get_logits(hidden_states, lm_head, embedding_bias) - if logits is not None: - if self.soft_cap is not None: - logits = logits / self.soft_cap - logits = torch.tanh(logits) - logits = logits * self.soft_cap - - if self.scale != 1.0: - logits *= self.scale - # Apply logits processors (if any). - if sampling_metadata is not None and \ - sampling_metadata.seq_groups is not None: - logits = _apply_logits_processors(logits, sampling_metadata) - - return logits - - def _get_logits( - self, - hidden_states: torch.Tensor, - lm_head: CustomParallelLMHead, - embedding_bias: Optional[torch.Tensor], - ) -> Optional[torch.Tensor]: - """ - Compute logits for next token prediction using parallel processing. 
- - Args: - hidden_states: Current hidden states from the model with shape [batch_size, hidden_size] - lm_head: Parallel embedding layer for vocabulary predictions - embedding_bias: Optional bias tensor to add to logits with shape [vocab_size] - - Returns: - Logits tensor for next token prediction with shape [batch_size, vocab_size] or None - """ - - if _enable_lmhead_tp(): + if lmhead_tp_enable(): # Gather hidden states from all devices in tensor parallel group - gathered_hidden_states = get_lmheadtp_group().all_gather(hidden_states, dim=0) + gathered_hidden_states = get_lmhead_tp_group().all_gather( + hidden_states, dim=0) else: gathered_hidden_states = hidden_states - # Compute logits using quantized matrix multiplication - local_logits = lm_head.quant_method.apply( - lm_head, - gathered_hidden_states, - bias=embedding_bias - ) + local_logits = lm_head.quant_method.apply(lm_head, + gathered_hidden_states, + bias=embedding_bias) - if _enable_lmhead_tp(): - logits = get_lmheadtp_group().all_to_all(local_logits) + if lmhead_tp_enable(): + logits = get_lmhead_tp_group().all_to_all(local_logits) else: # Gather logits for tensor parallel logits = self._gather_logits(local_logits) @@ -300,6 +224,5 @@ def _get_logits( # Remove paddings in vocab (if any) if logits is not None: logits = logits[..., :self.org_vocab_size] - + return logits - diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 9383a6afc98..29ac635b9bd 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -134,8 +134,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if parallel_config: # assign lmhead tensor parallel size parallel_config.lmhead_tensor_parallel_size = ( - ascend_config.lmhead_tensor_parallel_size - ) + ascend_config.lmhead_tensor_parallel_size) if model_config is None: logger.warning("Model config is missing. This may indicate " diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index b37d6823d88..24fd33a1ea4 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -168,18 +168,6 @@ def _generate_dummy_run_hidden_states(self, with_prefill, hidden_states = super()._generate_dummy_run_hidden_states( with_prefill, is_torchair_compile, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, inputs_embeds) - if not self.in_profile_run and self._enable_lmhead_tp(): - # lmhead_tp introduces additional communication across - # dp when computing logits. Hence we need to add it - # in profile_run. 
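# Illustrative aside (not part of the patch): the first revision rebuilt
# [B, V] from the all_to_all buffer by hand with view/permute/reshape; this
# revision delegates to the group coordinator's all_to_all, assumed here to
# yield the same layout. A single-process check of the reshape math it
# replaces (sizes are made up):
import torch

tp, B, Vp = 4, 2, 5
# received buffer: tp consecutive chunks, each [B, Vp], one per peer shard
recv = torch.randn(tp * B * Vp)
a = recv.view(tp, B, Vp).permute(1, 0, 2).reshape(B, tp * Vp)
b = torch.cat([recv.view(tp, B, Vp)[p] for p in range(tp)], dim=-1)
assert torch.equal(a, b)  # per-row concatenation over the vocab shards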
- if not with_prefill: - max_num_reqs_across_dp = num_tokens - else: - max_num_reqs_across_dp = self.scheduler_config.max_num_seqs - dummy_indices = torch.zeros(max_num_reqs_across_dp, - device=hidden_states.device, - dtype=torch.int32) - self.model.compute_logits(hidden_states[dummy_indices], None) return hidden_states def _convert_torch_format(self, kv_cache): diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 438dd7cd3d7..f191a0968a1 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -549,7 +549,5 @@ def get_ascend_soc_version(): return _ascend_soc_version -def _enable_lmhead_tp() -> bool: - if get_ascend_config().lmhead_tensor_parallel_size is not None: - return True - return False +def lmhead_tp_enable() -> bool: + return get_ascend_config().lmhead_tensor_parallel_size is not None diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 828e54092a2..1ce98463f9a 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -58,8 +58,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - LazyLoader, cdiv, is_pin_memory_available, _enable_lmhead_tp) -from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher + LazyLoader, cdiv, is_pin_memory_available) +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, @@ -89,7 +89,7 @@ from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - ProfileExecuteDuration, is_310p, + ProfileExecuteDuration, is_310p,lmhead_tp_enable, vllm_version_is) from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer @@ -1276,15 +1276,12 @@ def _prepare_inputs( spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices - - if _enable_lmhead_tp(): # - if not with_prefill: - max_num_reqs_across_dp = padded_num_tokens_across_dp - else: - max_num_reqs_across_dp = self.max_num_reqs - logits_indices = nn.functional.pad( - logits_indices, - (0, max_num_reqs_across_dp - logits_indices.shape[0])) + + if lmhead_tp_enable(): + max_num_reqs_across_dp = padded_num_tokens_across_dp if not with_prefill else self.max_num_reqs + logits_indices = nn.functional.pad( + logits_indices, + (0, max_num_reqs_across_dp - logits_indices.shape[0])) return (attn_metadata, positions, num_scheduled_tokens, num_input_tokens, num_tokens_across_dp, @@ -1743,14 +1740,14 @@ def execute_model( # Sample the next token and get logprobs if needed. 
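# Illustrative aside (not part of the patch): round trip of the padding that
# _prepare_inputs adds so the lm-head all_gather stays shape-uniform across
# DP ranks, followed by the trim execute_model applies just below once only
# the real rows matter. Sizes are made up.
import torch

def pad_logits_indices(logits_indices, max_num_reqs_across_dp):
    # extra slots point at row 0; their logits are garbage but keep the
    # collective shapes identical on every rank
    return torch.nn.functional.pad(
        logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0]))

num_reqs = 3
idx = pad_logits_indices(torch.tensor([3, 7, 11]), max_num_reqs_across_dp=8)
assert idx.tolist() == [3, 7, 11, 0, 0, 0, 0, 0]
logits = torch.randn(8, 32)   # stand-in for compute_logits() output
logits = logits[:num_reqs]    # same trim as execute_model()
assert logits.shape == (3, 32)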
sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: - if _enable_lmhead_tp(): + if lmhead_tp_enable() and logits is not None: logits = logits[:self.input_batch.num_reqs] sampler_output = self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) else: - if _enable_lmhead_tp(): + if lmhead_tp_enable() and logits is not None: logits = logits[:len(spec_decode_metadata.logits_indices)] # When indexing with a tensor (bonus_logits_indices), PyTorch # creates a new tensor with separate storage from the original @@ -2093,19 +2090,18 @@ def _dummy_run( assert aclgraph_runtime_mode == _cg_mode, ( f"Aclgraph runtime mode mismatch at dummy_run. " f"Expected {_cg_mode}, but got {aclgraph_runtime_mode}.") - - need_dummy_logits = (not self.in_profile_run and self._enable_lmhead_tp()) - + + need_dummy_logits = (not self.in_profile_run + and lmhead_tp_enable()) + if need_dummy_logits: - # lmhead_tp introduces additional communication across - # dp when computing logits. Hence we need to add it - # in profile_run. - max_num_reqs_across_dp = num_tokens if not with_prefill else self.scheduler_config.max_num_seqs + max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs dummy_indices = torch.zeros(max_num_reqs_across_dp, - device=hidden_states.device, dtype=torch.int32) + def dummy_compute_logits(hidden_states): - return self.model.compute_logits(hidden_states[dummy_indices], None) + return self.model.compute_logits( + hidden_states[dummy_indices], None) with set_ascend_forward_context( attn_metadata, @@ -2124,8 +2120,8 @@ def dummy_compute_logits(hidden_states): attn_metadata, num_tokens, intermediate_tensors, inputs_embeds) if need_dummy_logits: - dummy_compute_logits(hidden_states) - + dummy_compute_logits(hidden_states) + if self.speculative_config and self.speculative_config.method == "deepseek_mtp": assert isinstance(self.drafter, MtpProposer) self.drafter.dummy_run( @@ -2136,8 +2132,6 @@ def dummy_compute_logits(hidden_states): num_tokens_across_dp=num_tokens_across_dp) if need_dummy_logits: dummy_compute_logits(hidden_states) - - return hidden_states @contextmanager diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 120b17a652c..848da933a11 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -19,7 +19,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata -from vllm_ascend.utils import ProfileExecuteDuration +from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable class MtpProposer: @@ -235,8 +235,20 @@ def propose( previous_hidden_states=self. 
hidden_states[:num_input_tokens], kv_caches=self.runner.kv_caches[-1:]) + + num_indices = last_token_indices.shape[0] + if lmhead_tp_enable(): + if not self.runner.with_prefill: + max_num_reqs_across_dp = num_input_tokens + else: + max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs + last_token_indices = nn.functional.pad( + last_token_indices, (0, max_num_reqs_across_dp - num_indices)) + sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) + if lmhead_tp_enable() and num_indices < logits.shape[0]: + logits = logits[:num_indices] draft_token_ids = logits.argmax(dim=-1) # [batch_size, 1] From e4d6bec5d165676d06e4a86a7dc04aeb33e7b411 Mon Sep 17 00:00:00 2001 From: zzhx1 Date: Mon, 25 Aug 2025 19:36:37 +0800 Subject: [PATCH 04/10] Register VocabParallelEmbedding as a custom op for Ascend Signed-off-by: zzhx1 --- tests/ut/models/test_deepseek_v2.py | 3 +- vllm_ascend/models/deepseek_mtp.py | 18 ++-- vllm_ascend/models/deepseek_v2.py | 19 ++-- vllm_ascend/ops/vocab_parallel_embedding.py | 111 ++++++++++++-------- vllm_ascend/utils.py | 15 ++- 5 files changed, 100 insertions(+), 66 deletions(-) diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py index 0b3fe9c8e5f..435527a1299 100644 --- a/tests/ut/models/test_deepseek_v2.py +++ b/tests/ut/models/test_deepseek_v2.py @@ -20,13 +20,14 @@ from transformers import PretrainedConfig from vllm.config import CacheConfig from vllm.distributed.parallel_state import GroupCoordinator +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm_ascend.models.deepseek_v2 import ( CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention, CustomDeepseekV2MLP, CustomDeepseekV2MoE, CustomDeepseekV2RowParallelLinear, CustomDeepseekV2RowParallelLinearReplaceAllreduce, - CustomDeepseekV2SiluAndMul, CustomLogitsProcessor, CustomParallelLMHead) + CustomDeepseekV2SiluAndMul, CustomParallelLMHead) @pytest.fixture diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py index f2784aa74c2..8bcc4fbed30 100644 --- a/vllm_ascend/models/deepseek_mtp.py +++ b/vllm_ascend/models/deepseek_mtp.py @@ -25,10 +25,11 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import \ - VocabParallelEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.models.deepseek_mtp import ( DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, SharedHead) @@ -36,9 +37,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm_ascend.ops.vocab_parallel_embedding import (CustomLogitsProcessor, - CustomParallelLMHead) - from .deepseek_v2 import CustomDeepseekV2DecoderLayer @@ -50,10 +48,10 @@ def __init__(self, prefix: str = "") -> None: nn.Module.__init__(self) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.head = CustomParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, 
"head")) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "head")) class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer): @@ -143,7 +141,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): for idx in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers) ] - self.logits_processor = CustomLogitsProcessor(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) def forward( self, diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 3dbb8fea686..a9ceaed80e9 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -48,11 +48,12 @@ ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import \ - VocabParallelEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.deepseek_v2 import \ @@ -69,8 +70,6 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ops.fused_moe import AscendFusedMoE -from vllm_ascend.ops.vocab_parallel_embedding import (CustomLogitsProcessor, - CustomParallelLMHead) from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import dispose_tensor @@ -877,14 +876,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix( prefix, "model")) if get_pp_group().is_last_rank: - self.lm_head = CustomParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) else: self.lm_head = PPMissingLayer() - self.logits_processor = CustomLogitsProcessor(config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py index 8dd0718d405..3b00b407876 100644 --- a/vllm_ascend/ops/vocab_parallel_embedding.py +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -18,11 +18,10 @@ from typing import Optional, Tuple import torch -from torch.nn import Module +from torch import nn from torch.nn.parameter import Parameter -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import divide, tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import get_tp_group from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) @@ -85,39 +84,30 @@ def 
forward(self, input_): return output -class CustomParallelLMHead(ParallelLMHead): - """Custom Parallelized LM head, added the feature of lmheadTP in pure dp scenario - - Output logits weight matrices used in the Sampler. The weight and bias - tensors are padded to make sure they are divisible by the number of - model parallel GPUs. - - Args: - num_embeddings: vocabulary size. - embedding_dim: size of hidden state. - bias: whether to use bias. - params_dtype: type of the parameters. - org_num_embeddings: original vocabulary size (without LoRA). - padding_size: padding size for the vocabulary. +class AscendVocabParallelEmbedding(VocabParallelEmbedding): + """ + Register VocabParallelEmbedding as a custom op for Ascend. + AscendVocabParallelEmbedding support different communication parallel groups + Added the feature of lmheadTP in pure dp scenario """ def __init__(self, num_embeddings: int, embedding_dim: int, - bias: bool = False, params_dtype: Optional[torch.dtype] = None, org_num_embeddings: Optional[int] = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): - Module.__init__(self) + nn.Module.__init__(self) - if lmhead_tp_enable(): - tp_rank = get_lmhead_tp_group().rank_in_group - self.tp_size = get_lmhead_tp_group().world_size + if lmhead_tp_enable() and prefix.find("lm_head") != -1: + self.comm_group = get_lmhead_tp_group() else: - tp_rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() + self.comm_group = get_tp_group() + + self.tp_size = self.comm_group.world_size + self.tp_rank = self.comm_group.rank_in_group self.num_embeddings = num_embeddings self.padding_size = padding_size @@ -133,10 +123,9 @@ def __init__(self, self.shard_indices = self._get_indices(self.num_embeddings_padded, self.org_vocab_size_padded, self.num_embeddings, - self.org_vocab_size, tp_rank, - self.tp_size) + self.org_vocab_size, + self.tp_rank, self.tp_size) self.embedding_dim = embedding_dim - quant_method = None if quant_config is not None: quant_method = quant_config.get_quant_method(self, prefix=prefix) @@ -179,6 +168,25 @@ def __init__(self, params_dtype=params_dtype, weight_loader=self.weight_loader) + +class AscendParallelLMHead(ParallelLMHead): + """ + Register ParallelLMHead as a custom op for Ascend.""" + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + AscendVocabParallelEmbedding.__init__(self, num_embeddings, + embedding_dim, params_dtype, + org_num_embeddings, padding_size, + quant_config, prefix) + self.quant_config = quant_config if bias: self.bias = Parameter( @@ -192,34 +200,55 @@ def __init__(self, self.register_parameter("bias", None) -class CustomLogitsProcessor(LogitsProcessor): - """Custom logits processor extending base LogitsProcessor functionality. +class AscendLogitsProcessor(LogitsProcessor): + """ + Register LogitsProcessor as a custom op for Ascend. 
Added the feature of lmheadTP in pure dp scenario """ def _get_logits( self, hidden_states: torch.Tensor, - lm_head: CustomParallelLMHead, + lm_head: AscendParallelLMHead, embedding_bias: Optional[torch.Tensor], ) -> Optional[torch.Tensor]: - if lmhead_tp_enable(): - # Gather hidden states from all devices in tensor parallel group - gathered_hidden_states = get_lmhead_tp_group().all_gather( - hidden_states, dim=0) + return self._get_logits_lmheadtp(hidden_states, lm_head, + embedding_bias) else: - gathered_hidden_states = hidden_states + return self._get_logits_normal(hidden_states, lm_head, + embedding_bias) + def _get_logits_lmheadtp( + self, + hidden_states: torch.Tensor, + lm_head: AscendParallelLMHead, + embedding_bias: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + # Gather hidden states from all devices in tensor parallel group + gathered_hidden_states = get_lmhead_tp_group().all_gather( + hidden_states, dim=0) local_logits = lm_head.quant_method.apply(lm_head, gathered_hidden_states, bias=embedding_bias) + # Gather logits for tensor parallel + logits = get_lmhead_tp_group().all_to_all(local_logits) + # Remove paddings in vocab (if any) + if logits is not None: + logits = logits[..., :self.org_vocab_size] + return logits - if lmhead_tp_enable(): - logits = get_lmhead_tp_group().all_to_all(local_logits) - else: - # Gather logits for tensor parallel - logits = self._gather_logits(local_logits) + def _get_logits_normal( + self, + hidden_states: torch.Tensor, + lm_head: AscendParallelLMHead, + embedding_bias: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + local_logits = lm_head.quant_method.apply(lm_head, + hidden_states, + bias=embedding_bias) + # Gather logits for tensor parallel + logits = self._gather_logits(local_logits) # Remove paddings in vocab (if any) if logits is not None: diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index f191a0968a1..4bf6f97c24c 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -485,10 +485,12 @@ def register_ascend_customop(): from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear, - AscendMlpMergedColumnParallelLinear, - AscendMlpRowParallelLinear) + AscendMlpMergedColumnParallelLinear) from vllm_ascend.ops.rotary_embedding import ( AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) + from vllm_ascend.ops.vocab_parallel_embedding import ( + AscendLogitsProcessor, AscendParallelLMHead, + AscendVocabParallelEmbedding) CustomOp.register_oot(_decorated_op_cls=AscendQuickGELU, name="QuickGELU") CustomOp.register_oot(_decorated_op_cls=AscendSiluAndMul, name="SiluAndMul") @@ -497,11 +499,16 @@ def register_ascend_customop(): CustomOp.register_oot( _decorated_op_cls=AscendDeepseekScalingRotaryEmbedding, name="DeepseekScalingRotaryEmbedding") + CustomOp.register_oot(_decorated_op_cls=AscendVocabParallelEmbedding, + name="VocabParallelEmbedding") + CustomOp.register_oot(_decorated_op_cls=AscendParallelLMHead, + name="ParallelLMHead") + CustomOp.register_oot(_decorated_op_cls=AscendLogitsProcessor, + name="LogitsProcessor") if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE: CustomOp.register_oot(_decorated_op_cls=AscendMlpColumnParallelLinear, name="ColumnParallelLinear") - CustomOp.register_oot(_decorated_op_cls=AscendMlpRowParallelLinear, - name="RowParallelLinear") + CustomOp.register_oot( _decorated_op_cls=AscendMlpMergedColumnParallelLinear, name="MergedColumnParallelLinear") From f7efbc61443d53c950d564b7eb364c8899ac4896 Mon 
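The `_get_logits_lmheadtp` path introduced above trades the usual logits gather for an all-gather of hidden states plus an all-to-all of locally computed logits: every rank in the lmhead TP group scores all tokens against its own vocab shard, then each rank takes back the full-vocab rows for its own tokens. A minimal single-process sketch of the shapes involved, simulating both collectives with plain tensor ops (the group size, token count, and dimensions are made-up values):

import torch

# Assumed sizes: 2 ranks in the lmhead TP group, each owning one vocab shard.
tp, n_tokens, hidden, vocab = 2, 4, 8, 16
shard = vocab // tp
weights = [torch.randn(shard, hidden) for _ in range(tp)]     # per-rank LM head shards
hiddens = [torch.randn(n_tokens, hidden) for _ in range(tp)]  # per-rank token batches

# all_gather(hidden_states, dim=0): every rank now sees every rank's tokens.
gathered = torch.cat(hiddens, dim=0)                          # (tp * n_tokens, hidden)

# Each rank scores all tokens against its own vocab shard.
local_logits = [gathered @ w.t() for w in weights]            # each (tp * n_tokens, shard)

# all_to_all: rank r keeps its own n_tokens rows from every rank's shard,
# reassembling full-vocab logits for exactly the tokens it started with.
r = 0
logits_r = torch.cat(
    [chunk[r * n_tokens:(r + 1) * n_tokens] for chunk in local_logits], dim=-1)
assert logits_r.shape == (n_tokens, vocab)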
From f7efbc61443d53c950d564b7eb364c8899ac4896 Mon Sep 17 00:00:00 2001
From: zzhx1
Date: Tue, 26 Aug 2025 12:29:26 +0800
Subject: [PATCH 05/10] [CI] fix

Signed-off-by: zzhx1
---
 tests/ut/models/test_deepseek_v2.py   | 3 ++-
 vllm_ascend/utils.py                  | 1 +
 vllm_ascend/worker/model_runner_v1.py | 4 ++--
 3 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py
index 435527a1299..2bb3bb2b1db 100644
--- a/tests/ut/models/test_deepseek_v2.py
+++ b/tests/ut/models/test_deepseek_v2.py
@@ -21,13 +21,14 @@
 from vllm.config import CacheConfig
 from vllm.distributed.parallel_state import GroupCoordinator
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 
 from vllm_ascend.models.deepseek_v2 import (
     CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
     CustomDeepseekV2MLP, CustomDeepseekV2MoE,
     CustomDeepseekV2RowParallelLinear,
     CustomDeepseekV2RowParallelLinearReplaceAllreduce,
-    CustomDeepseekV2SiluAndMul, CustomParallelLMHead)
+    CustomDeepseekV2SiluAndMul)
 
 
 @pytest.fixture
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 4bf6f97c24c..f0a9dabfaf9 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -33,6 +33,7 @@
 from vllm.logger import logger
 
 import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_config import get_ascend_config
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 1ce98463f9a..949759b6e98 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -89,8 +89,8 @@
 from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
 from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               ProfileExecuteDuration, is_310p,lmhead_tp_enable,
-                               vllm_version_is)
+                               ProfileExecuteDuration, is_310p,
+                               lmhead_tp_enable, vllm_version_is)
 from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
From 9af76017e0ae7dc4c92b7871a223761a37e129ab Mon Sep 17 00:00:00 2001
From: zzhx1
Date: Tue, 26 Aug 2025 14:14:00 +0800
Subject: [PATCH 06/10] [CI] fix UT about vocab_parallel_embedding

Signed-off-by: zzhx1
---
 tests/ut/models/test_deepseek_mtp.py     | 17 ++++++++++++++++-
 tests/ut/test_utils.py                   |  4 ++--
 .../models/test_torchair_deepseek_mtp.py  | 17 ++++++++++++++++-
 3 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/tests/ut/models/test_deepseek_mtp.py b/tests/ut/models/test_deepseek_mtp.py
index 45b6ed5d198..61fdf9829fc 100644
--- a/tests/ut/models/test_deepseek_mtp.py
+++ b/tests/ut/models/test_deepseek_mtp.py
@@ -31,6 +31,11 @@ def setup_mtp_layer(self, mocker: MockerFixture):
         mocker_deepseek_v2_decode_layer = mocker.patch(
             "vllm_ascend.models.deepseek_v2.CustomDeepseekV2DecoderLayer.__init__",
             return_value=None)
+        mocker.patch(
+            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+            return_value=None)
+        mocker.patch("vllm_ascend.utils.get_ascend_config",
+                     return_value=mocker.Mock())
 
         mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "", None)
         mocker_deepseek_v2_decode_layer.assert_called_once()
@@ -83,6 +88,11 @@ def setup_predictor(self, mocker: MockerFixture):
         mocker.patch(
             "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__init__",
             return_value=None)
+        mocker.patch(
+            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+            return_value=None)
+        mocker.patch("vllm_ascend.utils.get_ascend_config",
+                     return_value=mocker.Mock())
 
         predictor = CustomDeepSeekMultiTokenPredictor(
             vllm_config=mock_vllm_config)
@@ -157,6 +167,11 @@ def setup_mtp(self, mocker: MockerFixture):
                      return_value=None)
         mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
                      return_value=None)
+        mocker.patch(
+            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+            return_value=None)
+        mocker.patch("vllm_ascend.utils.get_ascend_config",
+                     return_value=mocker.Mock())
 
         mtp = CustomDeepSeekMTP(vllm_config=vllm_config)
         return mtp
@@ -177,4 +192,4 @@ def test_forward(self, mocker: MockerFixture, setup_mtp):
         output = setup_mtp.forward(input_ids, positions, kv_caches, None,
                                    previous_hidden_states, inputs_embeds,
                                    spec_step_idx)
-        assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
\ No newline at end of file
+        assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py
index 396f457cf81..16cc455e22d 100644
--- a/tests/ut/test_utils.py
+++ b/tests/ut/test_utils.py
@@ -289,13 +289,13 @@ def test_register_ascend_customop(self, mock_ascend_rmsnorm,
         # ascend custom op is not registered
         utils.register_ascend_customop()
         # should call register_oot three
-        self.assertEqual(mock_customop.register_oot.call_count, 10)
+        self.assertEqual(mock_customop.register_oot.call_count, 11)
         self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
 
         # ascend custom op is already registered
         utils.register_ascend_customop()
         # should not register_oot again, thus only called three in this ut
-        self.assertEqual(mock_customop.register_oot.call_count, 10)
+        self.assertEqual(mock_customop.register_oot.call_count, 11)
 
 
 class TestProfileExecuteDuration(TestBase):
diff --git a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py
index 1c1e6c711ad..7aafdfc4f45 100644
--- a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py
+++ b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py
@@ -31,6 +31,11 @@ def setup_mtp_layer(self, mocker: MockerFixture):
         mocker_deepseek_v2_decode_layer = mocker.patch(
             "vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__",
             return_value=None)
+        mocker.patch(
+            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+            return_value=None)
+        mocker.patch("vllm_ascend.utils.get_ascend_config",
+                     return_value=mocker.Mock())
 
         mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer(config, "", None)
         mocker_deepseek_v2_decode_layer.assert_called_once()
@@ -83,6 +88,11 @@ def setup_predictor(self, mocker: MockerFixture):
         mocker.patch(
             "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__",
             return_value=None)
+        mocker.patch(
+            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+            return_value=None)
+        mocker.patch("vllm_ascend.utils.get_ascend_config",
+                     return_value=mocker.Mock())
 
         predictor = TorchairDeepSeekMultiTokenPredictor(
             vllm_config=mock_vllm_config)
@@ -157,6 +167,11 @@ def setup_mtp(self, mocker: MockerFixture):
                      return_value=None)
         mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
                      return_value=None)
+        mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__", + return_value=None) + mocker.patch("vllm_ascend.utils.get_ascend_config", + return_value=mocker.Mock()) mtp = TorchairDeepSeekMTP(vllm_config=vllm_config) return mtp @@ -177,4 +192,4 @@ def test_forward(self, mocker: MockerFixture, setup_mtp): output = setup_mtp.forward(input_ids, positions, kv_caches, None, previous_hidden_states, inputs_embeds, spec_step_idx) - assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]])) \ No newline at end of file + assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]])) From f8b99dc33d46f898bdfbf4e1e176186013da5a58 Mon Sep 17 00:00:00 2001 From: zzhx1 Date: Wed, 27 Aug 2025 19:09:47 +0800 Subject: [PATCH 07/10] fix code from reviewer Signed-off-by: zzhx1 --- docs/source/user_guide/configuration/additional_config.md | 1 + tests/ut/distributed/test_parallel_state.py | 6 ++++-- vllm_ascend/distributed/parallel_state.py | 4 +++- vllm_ascend/platform.py | 5 ----- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index cf4d1dcf697..cdb090848e6 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -34,6 +34,7 @@ The following table lists the additional configuration options available in vLLM | `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. | | `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. | | `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. | +| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. | The details of each config option are as follows: diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py index 58e025ede46..afc22c86586 100644 --- a/tests/ut/distributed/test_parallel_state.py +++ b/tests/ut/distributed/test_parallel_state.py @@ -27,9 +27,11 @@ def mock_distributed(): def test_init_ascend_model_parallel(mock_distributed, parallel_config): + mock_ascend_config = MagicMock() + mock_ascend_config.lmhead_tensor_parallel_size = 2 with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \ - patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'): - parallel_config.lmhead_tensor_parallel_size = 2 + patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \ + patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config): init_ascend_model_parallel(parallel_config) mc2_group = get_mc2_group() diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index f562ea88b21..f81d501d782 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -6,6 +6,7 @@ init_model_parallel_group) import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config # Currently, mc2 op need their own group coordinator. 
From f8b99dc33d46f898bdfbf4e1e176186013da5a58 Mon Sep 17 00:00:00 2001
From: zzhx1
Date: Wed, 27 Aug 2025 19:09:47 +0800
Subject: [PATCH 07/10] fix code from reviewer

Signed-off-by: zzhx1
---
 docs/source/user_guide/configuration/additional_config.md | 1 +
 tests/ut/distributed/test_parallel_state.py               | 6 ++++--
 vllm_ascend/distributed/parallel_state.py                 | 4 +++-
 vllm_ascend/platform.py                                   | 5 -----
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index cf4d1dcf697..cdb090848e6 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -34,6 +34,7 @@ The following table lists the additional configuration options available in vLL
 | `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. |
 | `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
 | `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. |
+| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. |
 
 The details of each config option are as follows:
 
diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py
index 58e025ede46..afc22c86586 100644
--- a/tests/ut/distributed/test_parallel_state.py
+++ b/tests/ut/distributed/test_parallel_state.py
@@ -27,9 +27,11 @@ def mock_distributed():
 
 
 def test_init_ascend_model_parallel(mock_distributed, parallel_config):
+    mock_ascend_config = MagicMock()
+    mock_ascend_config.lmhead_tensor_parallel_size = 2
     with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
-         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'):
-        parallel_config.lmhead_tensor_parallel_size = 2
+         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
+         patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
         init_ascend_model_parallel(parallel_config)
 
         mc2_group = get_mc2_group()
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index f562ea88b21..f81d501d782 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -6,6 +6,7 @@
                                              init_model_parallel_group)
 
 import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_config import get_ascend_config
 
 # Currently, mc2 op need their own group coordinator.
 _MC2: Optional[GroupCoordinator] = None
@@ -73,7 +74,8 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                          backend,
                                          group_name="mlp_tp")
 
-    lmhead_tensor_parallel_size = parallel_config.lmhead_tensor_parallel_size
+    lmhead_tensor_parallel_size = get_ascend_config(
+    ).lmhead_tensor_parallel_size
     if lmhead_tensor_parallel_size is not None:
         group_ranks = []
         global _LMTP
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 29ac635b9bd..afc0e6b5458 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -131,11 +131,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if kv_cache_dtype is not None:
             vllm_config.cache_config.cache_dtype = kv_cache_dtype
 
-        if parallel_config:
-            # assign lmhead tensor parallel size
-            parallel_config.lmhead_tensor_parallel_size = (
-                ascend_config.lmhead_tensor_parallel_size)
-
         if model_config is None:
             logger.warning("Model config is missing. This may indicate "
                            "that we are running a test case")
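With patch 07, `lmhead_tensor_parallel_size` now travels only through `get_ascend_config()`, and `init_ascend_model_parallel` carves the world into contiguous lmhead TP groups. A quick sketch of the rank partitioning that loop produces, assuming a hypothetical pure-DP world of 8 ranks and `lmhead_tensor_parallel_size=2` (values invented for illustration):

world_size, lmhead_tp_size = 8, 2  # assumed values
num_groups = world_size // lmhead_tp_size
group_ranks = [
    list(range(i * lmhead_tp_size, (i + 1) * lmhead_tp_size))
    for i in range(num_groups)
]
print(group_ranks)  # [[0, 1], [2, 3], [4, 5], [6, 7]] -- one lmhead TP group per sublist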
From cb4090b775e391be87f2ea560e1b65b9b2122f94 Mon Sep 17 00:00:00 2001
From: zzhx1
Date: Thu, 28 Aug 2025 09:14:09 +0800
Subject: [PATCH 08/10] rebase custom AscendVocabParallelEmbedding

Signed-off-by: zzhx1
---
 vllm_ascend/ops/vocab_parallel_embedding.py | 97 ++++++++++-----------
 vllm_ascend/utils.py                        |  4 -
 2 files changed, 47 insertions(+), 54 deletions(-)

diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py
index 3b00b407876..f43387a73db 100644
--- a/vllm_ascend/ops/vocab_parallel_embedding.py
+++ b/vllm_ascend/ops/vocab_parallel_embedding.py
@@ -34,56 +34,6 @@
 from vllm_ascend.utils import lmhead_tp_enable
 
 
-class AscendVocabParallelEmbedding(VocabParallelEmbedding):
-
-    def _get_masked_input_and_mask(
-            self, input_: torch.Tensor, org_vocab_start_index: int,
-            org_vocab_end_index: int, num_org_vocab_padding: int,
-            added_vocab_start_index: int,
-            added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        # torch.compile will fuse all of the pointwise ops below
-        # into a single kernel, making it very fast
-        org_vocab_mask = (input_ >= org_vocab_start_index) & (
-            input_ < org_vocab_end_index)
-        # Adapt: avoid create added_vocab_mask when added_vocab_start_index == added_vocab_end_index.
-        if added_vocab_start_index == added_vocab_end_index:
-            valid_offset = (org_vocab_start_index * org_vocab_mask)
-            vocab_mask = org_vocab_mask
-        else:
-            added_vocab_mask = (input_ >= added_vocab_start_index) & (
-                input_ < added_vocab_end_index)
-            added_offset = added_vocab_start_index - (
-                org_vocab_end_index -
-                org_vocab_start_index) - num_org_vocab_padding
-            valid_offset = (org_vocab_start_index *
-                            org_vocab_mask) + (added_offset * added_vocab_mask)
-            vocab_mask = org_vocab_mask | added_vocab_mask
-        # Adapt end.
-        input_ = vocab_mask * (input_ - valid_offset)
-        return input_, ~vocab_mask
-
-    def forward(self, input_):
-        if self.tp_size > 1:
-            # Build the mask.
-            masked_input, input_mask = self._get_masked_input_and_mask(
-                input_, self.shard_indices.org_vocab_start_index,
-                self.shard_indices.org_vocab_end_index,
-                self.shard_indices.num_org_vocab_padding,
-                self.shard_indices.added_vocab_start_index,
-                self.shard_indices.added_vocab_end_index)
-        else:
-            masked_input = input_
-        # Get the embeddings.
-        output_parallel = self.quant_method.embedding(self,
-                                                      masked_input.long())
-        # Mask the output embedding.
-        if self.tp_size > 1:
-            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
-        # Reduce across all the model parallel GPUs.
-        output = tensor_model_parallel_all_reduce(output_parallel)
-        return output
-
-
 class AscendVocabParallelEmbedding(VocabParallelEmbedding):
     """
     Register VocabParallelEmbedding as a custom op for Ascend.
@@ -168,6 +118,53 @@ def __init__(self,
                                 params_dtype=params_dtype,
                                 weight_loader=self.weight_loader)
 
+    def _get_masked_input_and_mask(
+            self, input_: torch.Tensor, org_vocab_start_index: int,
+            org_vocab_end_index: int, num_org_vocab_padding: int,
+            added_vocab_start_index: int,
+            added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        # torch.compile will fuse all of the pointwise ops below
+        # into a single kernel, making it very fast
+        org_vocab_mask = (input_ >= org_vocab_start_index) & (
+            input_ < org_vocab_end_index)
+        # Adapt: avoid create added_vocab_mask when added_vocab_start_index == added_vocab_end_index.
+        if added_vocab_start_index == added_vocab_end_index:
+            valid_offset = (org_vocab_start_index * org_vocab_mask)
+            vocab_mask = org_vocab_mask
+        else:
+            added_vocab_mask = (input_ >= added_vocab_start_index) & (
+                input_ < added_vocab_end_index)
+            added_offset = added_vocab_start_index - (
+                org_vocab_end_index -
+                org_vocab_start_index) - num_org_vocab_padding
+            valid_offset = (org_vocab_start_index *
+                            org_vocab_mask) + (added_offset * added_vocab_mask)
+            vocab_mask = org_vocab_mask | added_vocab_mask
+        # Adapt end.
+        input_ = vocab_mask * (input_ - valid_offset)
+        return input_, ~vocab_mask
+
+    def forward(self, input_):
+        if self.tp_size > 1:
+            # Build the mask.
+            masked_input, input_mask = self._get_masked_input_and_mask(
+                input_, self.shard_indices.org_vocab_start_index,
+                self.shard_indices.org_vocab_end_index,
+                self.shard_indices.num_org_vocab_padding,
+                self.shard_indices.added_vocab_start_index,
+                self.shard_indices.added_vocab_end_index)
+        else:
+            masked_input = input_
+        # Get the embeddings.
+        output_parallel = self.quant_method.embedding(self,
+                                                      masked_input.long())
+        # Mask the output embedding.
+        if self.tp_size > 1:
+            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
+        # Reduce across all the model parallel GPUs.
+        output = tensor_model_parallel_all_reduce(output_parallel)
+        return output
+
 
 class AscendParallelLMHead(ParallelLMHead):
     """
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index f0a9dabfaf9..f4ab5d9a8b8 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -520,10 +520,6 @@ def register_ascend_customop():
     from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
     CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE")
 
-    from vllm_ascend.ops.vocab_parallel_embedding import \
-        AscendVocabParallelEmbedding
-    CustomOp.register_oot(_decorated_op_cls=AscendVocabParallelEmbedding,
-                          name="VocabParallelEmbedding")
     # NOTE: Keep this at last to ensure all custom actions are registered
     _ASCEND_CUSTOMOP_IS_REIGISTERED = True
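The `_get_masked_input_and_mask` logic that patch 08 moves into the new class is easiest to sanity-check with concrete numbers. A standalone re-run of its org-vocab branch (the `added_vocab_start_index == added_vocab_end_index` case), assuming a shard that owns original vocab ids [4, 8):

import torch


def masked_input_and_mask(input_, org_start, org_end):
    # Mirrors the org-vocab-only branch: ids outside [org_start, org_end)
    # are mapped to 0 and flagged for masking.
    org_vocab_mask = (input_ >= org_start) & (input_ < org_end)
    valid_offset = org_start * org_vocab_mask
    masked = org_vocab_mask * (input_ - valid_offset)
    return masked, ~org_vocab_mask


ids = torch.tensor([2, 4, 7, 9])
masked, oob = masked_input_and_mask(ids, 4, 8)
print(masked)  # tensor([0, 0, 3, 0]) -- in-shard ids shifted into local range
print(oob)     # tensor([ True, False, False,  True]) -- rows zeroed after lookup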
From ca30d377b14c405783cd4d58f494384816be2234 Mon Sep 17 00:00:00 2001
From: zzhx1
Date: Thu, 28 Aug 2025 12:55:46 +0800
Subject: [PATCH 09/10] rebase custom main

Signed-off-by: zzhx1
---
 tests/ut/models/test_deepseek_v2.py           | 39 ++++++++++++++++---
 tests/ut/ops/test_vocab_parallel_embedding.py |  6 ++-
 tests/ut/test_utils.py                        |  4 +-
 vllm_ascend/models/deepseek_v2.py             |  2 -
 vllm_ascend/utils.py                          |  7 ++--
 5 files changed, 45 insertions(+), 13 deletions(-)

diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py
index 2bb3bb2b1db..cdb2ad05c68 100644
--- a/tests/ut/models/test_deepseek_v2.py
+++ b/tests/ut/models/test_deepseek_v2.py
@@ -20,15 +20,13 @@
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed.parallel_state import GroupCoordinator
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 
 from vllm_ascend.models.deepseek_v2 import (
-    CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
-    CustomDeepseekV2MLP, CustomDeepseekV2MoE,
+    CustomDeepseekV2ForCausalLM, CustomDeepseekV2MergedReplicatedLinear,
+    CustomDeepseekV2MLAAttention, CustomDeepseekV2MLP, CustomDeepseekV2MoE,
     CustomDeepseekV2RowParallelLinear,
     CustomDeepseekV2RowParallelLinearReplaceAllreduce,
-    CustomDeepseekV2SiluAndMul)
+    CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)
 
 
 @pytest.fixture
@@ -268,3 +266,34 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
                                           kv_lora_rank=16,
                                           prefix="layers.1.self_attn")
     assert hasattr(attn, "q_proj")
+
+def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
+    # Create a simple config object
+    class SimpleConfig:
+        def __init__(self):
+            self.vocab_size = 10000
+            self.hidden_size = 128
+
+    config = SimpleConfig()
+
+    # Create the lmhead and logits_processor directly
+    lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
+    logits_processor = LogitsProcessor(config.vocab_size)
+
+    # Create test inputs
+    input_ids = torch.randint(0, config.vocab_size, (2, 4))
+    positions = torch.arange(4).repeat(2, 1)
+
+    # Create mock outputs
+    mock_output = torch.randn(2, 4, config.hidden_size)
+    mock_logits = torch.randn(2, 4, config.vocab_size)
+
+    # Test the logits_processor directly
+    with patch.object(lmhead.quant_method,
+                      "apply",
+                      return_value=mock_logits):
+        with patch.object(logits_processor,
+                          "_gather_logits",
+                          return_value=mock_logits):
+            logits = logits_processor(lmhead, mock_output)
+            assert logits.shape == (2, 4, config.vocab_size)
\ No newline at end of file
diff --git a/tests/ut/ops/test_vocab_parallel_embedding.py b/tests/ut/ops/test_vocab_parallel_embedding.py
index 13ede67085e..33cbaf69340 100644
--- a/tests/ut/ops/test_vocab_parallel_embedding.py
+++ b/tests/ut/ops/test_vocab_parallel_embedding.py
@@ -34,7 +34,11 @@ def setUp(self):
 
     def _create_layer(self):
         # Patch methods and dependencies for VocabParallelEmbedding
-        with patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
+        mock_group = MagicMock()
+        mock_group.world_size = 2
+        mock_group.rank_in_group = 0
+        with patch("vllm_ascend.ops.vocab_parallel_embedding.get_tp_group", return_value=mock_group), \
+             patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
              patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \
              patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \
              patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y):
diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py
index 16cc455e22d..0d264c7fe90 100644
--- a/tests/ut/test_utils.py
+++ b/tests/ut/test_utils.py
@@ -289,13 +289,13 @@ def test_register_ascend_customop(self, mock_ascend_rmsnorm,
         # ascend custom op is not registered
         utils.register_ascend_customop()
         # should call register_oot three
-        self.assertEqual(mock_customop.register_oot.call_count, 11)
+        self.assertEqual(mock_customop.register_oot.call_count, 12)
         self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
 
         # ascend custom op is already registered
         utils.register_ascend_customop()
         # should not register_oot again, thus only called three in this ut
-        self.assertEqual(mock_customop.register_oot.call_count, 11)
+        self.assertEqual(mock_customop.register_oot.call_count, 12)
 
 
 class TestProfileExecuteDuration(TestBase):
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index a9ceaed80e9..6d0913c18c9 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -74,8 +74,6 @@
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor
 
-from vllm_ascend.utils import dispose_tensor, npu_prefetch
-
 
 class CustomDeepseekV2SiluAndMul(SiluAndMul):
 
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index f4ab5d9a8b8..adab4906e03 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -486,7 +486,8 @@ def register_ascend_customop():
 
     from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
     from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
-                                        AscendMlpMergedColumnParallelLinear)
+                                        AscendMlpMergedColumnParallelLinear,
+                                        AscendMlpRowParallelLinear)
     from vllm_ascend.ops.rotary_embedding import (
         AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
     from vllm_ascend.ops.vocab_parallel_embedding import (
@@ -509,7 +510,8 @@ def register_ascend_customop():
     if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
         CustomOp.register_oot(_decorated_op_cls=AscendMlpColumnParallelLinear,
                               name="ColumnParallelLinear")
-
+        CustomOp.register_oot(_decorated_op_cls=AscendMlpRowParallelLinear,
+                              name="RowParallelLinear")
         CustomOp.register_oot(
             _decorated_op_cls=AscendMlpMergedColumnParallelLinear,
             name="MergedColumnParallelLinear")
@@ -520,7 +522,6 @@ def register_ascend_customop():
     from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
     CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE")
 
-
     # NOTE: Keep this at last to ensure all custom actions are registered
     _ASCEND_CUSTOMOP_IS_REIGISTERED = True
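End to end, the option is driven entirely by `additional_config` (see the doc table added in patch 07). Enabling it from client code would look roughly like the sketch below, assuming vLLM's offline `LLM` entry point accepts `additional_config` on the Ascend platform; the model name is a placeholder:

from vllm import LLM

# lmhead_tensor_parallel_size is only accepted with tensor_parallel_size == 1
# (the assertion in ascend_config enforces the pure-DP constraint), so lmhead
# TP groups are carved out of otherwise data-parallel workers.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder
    tensor_parallel_size=1,
    additional_config={"lmhead_tensor_parallel_size": 2},
)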
From 8c80738e141b1025d83455d29bfea8e235bf0144 Mon Sep 17 00:00:00 2001
From: zzhx1
Date: Thu, 28 Aug 2025 19:21:32 +0800
Subject: [PATCH 10/10] [CI] fix

Signed-off-by: zzhx1
---
 tests/ut/models/test_deepseek_v2.py           | 18 +++---
 tests/ut/ops/test_vocab_parallel_embedding.py | 56 ++++++++++++++++++-
 vllm_ascend/ops/vocab_parallel_embedding.py   |  2 +-
 vllm_ascend/worker/model_runner_v1.py         |  2 +-
 4 files changed, 63 insertions(+), 15 deletions(-)

diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py
index cdb2ad05c68..df14a2a68e5 100644
--- a/tests/ut/models/test_deepseek_v2.py
+++ b/tests/ut/models/test_deepseek_v2.py
@@ -22,8 +22,8 @@
 from vllm.distributed.parallel_state import GroupCoordinator
 
 from vllm_ascend.models.deepseek_v2 import (
-    CustomDeepseekV2ForCausalLM, CustomDeepseekV2MergedReplicatedLinear,
-    CustomDeepseekV2MLAAttention, CustomDeepseekV2MLP, CustomDeepseekV2MoE,
+    CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
+    CustomDeepseekV2MLP, CustomDeepseekV2MoE,
     CustomDeepseekV2RowParallelLinear,
     CustomDeepseekV2RowParallelLinearReplaceAllreduce,
     CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)
@@ -267,33 +267,29 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
                                           prefix="layers.1.self_attn")
     assert hasattr(attn, "q_proj")
 
+
 def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
     # Create a simple config object
     class SimpleConfig:
+
         def __init__(self):
             self.vocab_size = 10000
             self.hidden_size = 128
 
     config = SimpleConfig()
-
+
     # Create the lmhead and logits_processor directly
     lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
     logits_processor = LogitsProcessor(config.vocab_size)
 
-    # Create test inputs
-    input_ids = torch.randint(0, config.vocab_size, (2, 4))
-    positions = torch.arange(4).repeat(2, 1)
-
     # Create mock outputs
     mock_output = torch.randn(2, 4, config.hidden_size)
     mock_logits = torch.randn(2, 4, config.vocab_size)
 
     # Test the logits_processor directly
-    with patch.object(lmhead.quant_method,
-                      "apply",
-                      return_value=mock_logits):
+    with patch.object(lmhead.quant_method, "apply", return_value=mock_logits):
         with patch.object(logits_processor,
                           "_gather_logits",
                           return_value=mock_logits):
             logits = logits_processor(lmhead, mock_output)
-            assert logits.shape == (2, 4, config.vocab_size)
\ No newline at end of file
+            assert logits.shape == (2, 4, config.vocab_size)
diff --git a/tests/ut/ops/test_vocab_parallel_embedding.py b/tests/ut/ops/test_vocab_parallel_embedding.py
index 33cbaf69340..5378b1905fd 100644
--- a/tests/ut/ops/test_vocab_parallel_embedding.py
+++ b/tests/ut/ops/test_vocab_parallel_embedding.py
@@ -18,8 +18,8 @@
 
 import torch
 
-from vllm_ascend.ops.vocab_parallel_embedding import \
-    AscendVocabParallelEmbedding
+from vllm_ascend.ops.vocab_parallel_embedding import (
+    AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
 
 VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
 
@@ -178,3 +178,55 @@ def test_output_shape(self):
         # Call the forward method
         output = layer.forward(input_)
         self.assertEqual(output.shape, expected_shape)
+
+
+class TestAscendLogitsProcessor(unittest.TestCase):
+
+    def setUp(self):
+        self.vocab_size = 50
+        self.num_embeddings = 50
+        self.embedding_dim = 10
+        self.org_num_embeddings = 40
+        self.padding_size = 8
+
+        self.mock_group = MagicMock()
+        self.mock_group.world_size = 2
+        self.mock_group.rank_in_group = 0
+        self.mock_ascend_config = MagicMock()
+        self.mock_quant_method = MagicMock()
+        self.mock_quant_method.apply = MagicMock(
+            return_value=torch.randn(1, self.vocab_size))
+        self.patches = [
patch("vllm_ascend.ascend_config.get_ascend_config", + return_value=self.mock_ascend_config), + patch( + "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group", + return_value=self.mock_group), + patch("vllm_ascend.ops.vocab_parallel_embedding.lmhead_tp_enable", + return_value=True), + patch( + "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all", + return_value=torch.randn(1, self.vocab_size)) + ] + + for p in self.patches: + p.start() + + def tearDown(self): + for p in self.patches: + p.stop() + + def test_create_processor(self): + processor = AscendLogitsProcessor(vocab_size=self.vocab_size) + self.assertEqual(processor.vocab_size, self.vocab_size) + + def test_get_logits(self): + processor = AscendLogitsProcessor(vocab_size=self.vocab_size) + lmhead = AscendParallelLMHead(num_embeddings=self.num_embeddings, + embedding_dim=self.embedding_dim, + prefix="lm_head") + lmhead.quant_method = self.mock_quant_method + lmhead.quant_method.apply = self.mock_quant_method.apply + hidden_state = torch.randn(1, self.org_num_embeddings) + processor._get_logits(hidden_state, lmhead) + self.mock_quant_method.apply.assert_called_once() diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py index f43387a73db..7ad35dc288e 100644 --- a/vllm_ascend/ops/vocab_parallel_embedding.py +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -207,7 +207,7 @@ def _get_logits( self, hidden_states: torch.Tensor, lm_head: AscendParallelLMHead, - embedding_bias: Optional[torch.Tensor], + embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[torch.Tensor]: if lmhead_tp_enable(): return self._get_logits_lmheadtp(hidden_states, lm_head, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 949759b6e98..15effb7ce6a 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1278,7 +1278,7 @@ def _prepare_inputs( logits_indices = spec_decode_metadata.logits_indices if lmhead_tp_enable(): - max_num_reqs_across_dp = padded_num_tokens_across_dp if not with_prefill else self.max_num_reqs + max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs logits_indices = nn.functional.pad( logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0]))