From 5dbefd6a648d25d94156e03737d1bda2cf959e76 Mon Sep 17 00:00:00 2001 From: Bartosz Myrcha Date: Tue, 8 Apr 2025 09:51:52 +0200 Subject: [PATCH 01/20] [SW-224648] Redirect test logs to file (#1017) Switched execution of versioned branches to _next and added logs redirection to file. --- .github/workflows/trigger_jenkins.yml | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/trigger_jenkins.yml b/.github/workflows/trigger_jenkins.yml index ae4f8d32167d..c9f199ce9549 100644 --- a/.github/workflows/trigger_jenkins.yml +++ b/.github/workflows/trigger_jenkins.yml @@ -220,11 +220,14 @@ jobs: RELEASED_SYNAPSE_VERSION: ${{ vars.RELEASED_SYNAPSE_VERSION }} BASE_BRANCH: ${{ needs.read_codeowners.outputs.pr_branch }} run: | - version_regex='^v([0-9]+)\.([0-9]+)\.([0-9]+)$' + LOG_REDIRECTION=">" + version_regex='^v([0-9]+)\.([0-9]+)\.([0-9]+)_next$' if [[ $TARGET_BRANCH =~ $version_regex ]]; then synapse_version=${TARGET_BRANCH#v} + synapse_version=${synapse_version%_*} else synapse_version=${RELEASED_SYNAPSE_VERSION#v} + LOG_REDIRECTION="\| tee" fi echo "Using SynapseAI version ${synapse_version}" synapse_build=$(curl "https://dms.habana-labs.com/api/v1.1/branch/info/v$synapse_version" | jq -r ".release_id") @@ -239,17 +242,20 @@ jobs: sed -i "s/##PYTORCH_VERSION##/${pt_version}/g" pod.yml sed -i "s|##GIT_BRANCH##|$BASE_BRANCH|g" pod.yml sed -i "s|##CMD##|$safe_cmd|g" pod.yml + sed -i "s|##LOG_REDIRECTION##|$LOG_REDIRECTION|g" pod.yml echo "Pod Template Created" - name: Run Test run: | random_string=$(tr -dc 'a-z0-9' Date: Tue, 8 Apr 2025 17:02:02 +0000 Subject: [PATCH 02/20] apply deepseek change Signed-off-by: Chendi.Xue --- vllm/attention/backends/hpu_attn.py | 251 +++++++++++++++++- vllm/attention/backends/mla/utils.py | 103 ++++--- vllm/attention/ops/hpu_paged_attn.py | 2 + vllm/envs.py | 5 + .../layers/fused_moe/fused_moe.py | 19 +- vllm/model_executor/layers/fused_moe/layer.py | 128 ++++++--- vllm/model_executor/layers/linear.py | 10 +- .../model_executor/layers/quantization/fp8.py | 40 ++- .../model_loader/weight_utils.py | 2 +- vllm/model_executor/models/deepseek_v2.py | 36 +-- vllm/platforms/hpu.py | 3 + vllm/worker/hpu_model_runner.py | 42 ++- vllm/worker/hpu_worker.py | 17 +- vllm/worker/model_runner_base.py | 6 +- 14 files changed, 545 insertions(+), 119 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index fd005dded2e9..5034dba7fb05 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -17,6 +17,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionMetadata, AttentionType) +from vllm.attention.backends.mla.utils import MLACommonImpl from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention, HPUPagedAttentionMetadata) @@ -49,7 +50,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, - ) -> Tuple[int, ...]: + ) -> List[Tuple[int, ...]]: return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size) @@ -69,6 +70,49 @@ def copy_blocks( HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts) +class HPUMLAAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "HPU_MLA" + + @staticmethod + def get_impl_cls() -> Type["HPUMLAImpl"]: + return HPUMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return HPUMLAMetadata + + 
@staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> List[Tuple[int, ...]]: + return [(num_blocks, block_size, head_size), None] + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + HPUPagedAttention.copy_blocks(kv_caches, src_to_dists) + + @dataclass class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): """Metadata for HPUAttentionbackend.""" @@ -78,6 +122,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): attn_bias: Optional[torch.Tensor] seq_lens_tensor: Optional[torch.Tensor] context_lens_tensor: Optional[torch.Tensor] + input_positions: torch.Tensor seq_lens: Optional[List[int]] = None encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None @@ -91,6 +136,204 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): cross_attn_bias: Optional[torch.Tensor] = None +@dataclass +class HPUMLAMetadata(HPUAttentionMetadata, AttentionMetadata): + pass + + +class HPUMLAImpl(MLACommonImpl[HPUAttentionMetadata], torch.nn.Module): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **kwargs) -> None: + torch.nn.Module.__init__(self) + MLACommonImpl.__init__(self, num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **kwargs) + + self.matmul_qk = Matmul() + self.softmax = Softmax() + self.matmul_av = Matmul() + self.batch2block_matmul = Matmul() + self.block2batch_matmul = Matmul() + self.latent_cache_k = VLLMKVCache() + self.fused_scaled_dot_product_attention = kernels.fsdpa() + + if "fsdpa" in enabled_flags(): + assert alibi_slopes is None, \ + 'Prefill with FusedSDPA not supported with alibi slopes!' 
+ self.prefill_impl = 'fsdpa' + else: + self.prefill_impl = 'naive' + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "TritonMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonMLAImpl") + + def forward( + self, + layer: AttentionLayer, + hidden_states_or_q_c: torch.Tensor, # query in unified attn + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: HPUAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if output is not None: + raise NotImplementedError( + "output is not yet supported for MLAImplBase") + + batch_size = hidden_states_or_q_c.shape[0] + + is_prefill = attn_metadata.is_prompt + + k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim) + + # Restore head dim (for rotary embedding) + # k_pe = k_pe.unsqueeze(1) + assert hasattr(attn_metadata, + "input_positions"), f"attn meta: {attn_metadata}" + + if not is_prefill: + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c) + q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\ + .view(-1, self.num_heads, self.qk_rope_head_dim) + else: + q_nope, q_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c) + input_positions = attn_metadata.input_positions.view(-1) + q_pe, k_pe = \ + self.rotary_emb(input_positions, q_pe, k_pe) + else: + q = self.q_proj(hidden_states_or_q_c)[0]\ + .view(-1, self.num_heads, self.qk_head_dim) + + q_pe = q[..., self.qk_nope_head_dim:] + + input_positions = attn_metadata.input_positions.view(-1) + # TODO(lucas): there must be a nicer way to write this line + q[..., self.qk_nope_head_dim:], k_pe = \ + self.rotary_emb(input_positions, q_pe, k_pe) + + block_indices = attn_metadata.block_indices + block_offsets = attn_metadata.block_offsets + + latent_vec_k = torch.concat( + (k_c_normed, k_pe.view(batch_size, -1, self.qk_rope_head_dim)), + dim=-1) + latent_vec_k = latent_vec_k.view( + -1, self.qk_rope_head_dim + self.kv_lora_rank) + if is_prefill: + latent_vec_k = latent_vec_k.unflatten(0, + (block_indices.size(0), -1)) + + # write the latent and rope to kv cache + if kv_cache is not None and len(kv_cache) == 2: + self.latent_cache_k(latent_vec_k, kv_cache[0], block_indices, + block_offsets) + k_cache = kv_cache[0] + v_cache = None + + if is_prefill: + return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata, + batch_size) + else: + return self._forward_decode(q_nope, q_pe, (k_cache, v_cache), + attn_metadata, batch_size) + + def _forward_prefill(self, q: torch.Tensor, k_c_normed: torch.Tensor, + k_pe: torch.Tensor, + attn_metadata: HPUAttentionMetadata, + batch_size: int) -> torch.Tensor: + kv_nope = self.kv_b_proj(k_c_normed)[0]\ + .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + q = q.view(batch_size, -1, self.num_heads, 
self.qk_head_dim) + k = k.view(batch_size, -1, self.num_heads, self.qk_head_dim) + v_padded = v_padded.view(batch_size, -1, self.num_heads, + self.qk_head_dim) + out = ops.prompt_attention( + impl=self.prefill_impl, + query=q, + key=k, + value=v_padded, + is_causal=True, + attn_bias=attn_metadata.attn_bias, + valid_seq_lengths=attn_metadata.seq_lens_tensor, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + softmax_op=self.softmax, + matmul_av_op=self.matmul_av, + fsdpa_op=self.fused_scaled_dot_product_attention.apply \ + if self.fused_scaled_dot_product_attention is not None else None) + attn_output = out.view(batch_size, -1, self.num_heads, q.shape[-1]) + attn_output = attn_output[..., :v.shape[-1]]\ + .reshape(batch_size, -1, self.num_heads * v.shape[-1]) + + return self.o_proj(attn_output)[0] + + def _forward_decode(self, q_nope: torch.Tensor, q_pe: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: HPUAttentionMetadata, + batch_size: int) -> torch.Tensor: + query = torch.cat([q_nope, q_pe], dim=-1) + + key_cache = kv_cache[0].unsqueeze(2) + value_cache = kv_cache[1] # value_cache is None + output = HPUPagedAttention.forward_decode( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_list=attn_metadata.block_list, + block_mapping=attn_metadata.block_mapping, + block_bias=attn_metadata.attn_bias, + block_groups=attn_metadata.block_groups, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + matmul_av_op=self.matmul_av, + batch2block_matmul_op=self.batch2block_matmul, + block2batch_matmul_op=self.block2batch_matmul, + keys_fetch_func=self.latent_cache_k.fetch_from_cache, + values_fetch_func=None, + kv_lora_rank=self.kv_lora_rank) + output = output.view(batch_size, 1, -1) + result = self._v_up_proj_and_o_proj(output) + result = result.view(batch_size, 1, -1) + return result + + class HPUAttentionImpl(AttentionImpl, torch.nn.Module): """ If the input tensors contain prompt tokens, the layout is as follows: @@ -153,12 +396,6 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. 
" - f"Supported head sizes are: {suppored_head_sizes}.") - self.attn_type = attn_type if (self.attn_type != AttentionType.DECODER and self.attn_type != AttentionType.ENCODER_DECODER diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index e9b4dff74f42..86e2378051bd 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -29,11 +29,16 @@ scaled_quantize) from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) +from vllm.platforms import current_platform -try: - from vllm.vllm_flash_attn import flash_attn_varlen_func -except ImportError: - from flash_attn import flash_attn_varlen_func +if current_platform.is_hpu(): + from vllm_hpu_extension.ops import is_hpu_gaudi2 + +if current_platform.is_cuda_alike(): + try: + from vllm.vllm_flash_attn import flash_attn_varlen_func + except ImportError: + from flash_attn import flash_attn_varlen_func @dataclass @@ -199,9 +204,14 @@ def _v_up_proj_and_o_proj(self, x): output = output_parallel return output else: - x = torch.einsum("bnl,lnv->bnv", x, self.W_UV) - return self.o_proj(x.reshape(-1, - self.num_heads * self.v_head_dim))[0] + # chendi: this is a cherry-pick of missing commit from upstream + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + return self.o_proj(x)[0] def _q_proj_and_k_up_proj(self, x): if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: @@ -214,10 +224,17 @@ def _q_proj_and_k_up_proj(self, x): return torch.matmul(x, self.W_Q_UK)\ .view(-1, self.num_heads, self.kv_lora_rank) else: - x = torch.matmul(x, self.W_Q)\ - .view(-1, self.num_heads, self.qk_nope_head_dim) - return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\ - .view(-1, self.num_heads, self.kv_lora_rank) + # chendi: this is a cherry-pick of missing commit from upstream + q_nope, q_pe = self.q_proj(x)[0]\ + .view(-1, self.num_heads, self.qk_head_dim)\ + .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + return ql_nope.transpose(0, 1), q_pe def process_weights_after_loading(self, act_dtype: torch.dtype): # TODO(lucas) This is very gross, we need a more wide scale refactor of @@ -302,19 +319,21 @@ def get_and_maybe_dequant_weights(layer: LinearBase): W_UK, W_UV = kv_b_proj_weight.split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ - .view(-1, self.num_heads, self.qk_head_dim) - - # can be W_Q or W_UQ depending q_lora_rank, the former if - # q_lora_rank is None, the latter otherwise. 
From the Attention backend - # perspective though we call these both W_Q and rely on the layer - # to pass in the correct matrix - W_Q = q_proj_weight[..., :self.qk_nope_head_dim] - self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\ - .flatten(start_dim=1).contiguous() + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + # Chendi: This is a cherry-pick of missing commit from upstream + q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ + .view(-1, self.num_heads, self.qk_head_dim) + + # can be W_Q or W_UQ depending q_lora_rank, the former if + # q_lora_rank is None, the latter otherwise. From the Attention + # backend perspective though we call these both W_Q and rely on + # the layer to pass in the correct matrix + W_Q = q_proj_weight[..., :self.qk_nope_head_dim] + self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\ + .flatten(start_dim=1).contiguous() - # W_QR is small so for simplicity we dont bother requantizing it - self.W_QR = self.W_QR.to(act_dtype) + # W_QR is small so for simplicity we dont bother requantizing it + self.W_QR = self.W_QR.to(act_dtype) if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION @@ -337,15 +356,20 @@ def get_and_maybe_dequant_weights(layer: LinearBase): # for decode, as a result we end up with absorbed weights for decode # and another copy of raw weights for prefill. # - self.W_UK, self.W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) # We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK # depending q_lora_rank, the former if q_lora_rank is None, the # latter otherwise # basically if q_lora_rank is none we are absorbing into q_proj # instead of UQ - W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\ - .flatten(start_dim=1).contiguous() + + # chendi: Important fix for Gaudi2 + if current_platform.is_hpu() and is_hpu_gaudi2(): + W_Q_UK = torch.einsum( + "qnd,lnd -> qnl", W_Q.bfloat16(), + W_UK.bfloat16()).flatten(start_dim=1).contiguous().float() + else: + W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, + W_UK).flatten(start_dim=1).contiguous() if is_fp8(weight_dtype) and requantization_enabled: W_Q_UK, W_Q_UK_scales = scaled_quantize( @@ -361,8 +385,17 @@ def get_and_maybe_dequant_weights(layer: LinearBase): W_O = get_and_maybe_dequant_weights(self.o_proj)\ .view(-1, self.num_heads, self.v_head_dim) - W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ - .flatten(start_dim=0, end_dim=1).contiguous() + + # chendi: Important fix for Gaudi2 + if current_platform.is_hpu() and is_hpu_gaudi2(): + W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV.bfloat16(), + W_O.bfloat16()).flatten( + start_dim=0, + end_dim=1).contiguous().float() + else: + W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, + W_O).flatten(start_dim=0, + end_dim=1).contiguous() if is_fp8(weight_dtype) and requantization_enabled: W_UV_O, W_UV_O_scales = scaled_quantize( @@ -378,13 +411,11 @@ def get_and_maybe_dequant_weights(layer: LinearBase): self.tp_size = get_tensor_model_parallel_world_size() else: - if is_fp8(weight_dtype): - raise NotImplementedError( - "Currently fp8 requires matrix absorption") - - self.W_UV = W_UV - self.W_UK = W_UK - self.W_Q = W_Q.flatten(start_dim=1) + # chendi: this is a cherry-pick of missing commit from upstream + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) @abstractmethod def _forward_prefill( diff --git a/vllm/attention/ops/hpu_paged_attn.py 
b/vllm/attention/ops/hpu_paged_attn.py index 5c59917500a7..088244dd3d24 100644 --- a/vllm/attention/ops/hpu_paged_attn.py +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -61,6 +61,8 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor, @staticmethod def forward_decode(**kwargs) -> torch.Tensor: + if kwargs.get("kv_lora_rank"): + return ops.flat_pa_mla(**kwargs) return ops.flat_pa(**kwargs) @staticmethod diff --git a/vllm/envs.py b/vllm/envs.py index f8a18cc662ab..30d8993dd18d 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -89,6 +89,7 @@ VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_BUNDLE_INDICES: str = "" VLLM_CUDART_SO_PATH: Optional[str] = None + VLLM_ENABLE_EXPERT_PARALLEL: bool = True def get_default_cache_root(): @@ -585,6 +586,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # specify the path through environment variable VLLM_CUDART_SO_PATH. "VLLM_CUDART_SO_PATH": lambda: os.getenv("VLLM_CUDART_SO_PATH", None), + + # Temporary add for enable expert parallel, should remove after rebase + "VLLM_ENABLE_EXPERT_PARALLEL": + lambda: bool(int(os.getenv("VLLM_ENABLE_EXPERT_PARALLEL", "1"))), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 4b4423561121..c3f7dab8f551 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -18,6 +18,7 @@ from vllm.utils import direct_register_custom_op if current_platform.is_hpu(): + import habana_frameworks.torch as htorch from vllm_hpu_extension.ops import scaled_fp8_quant ops.scaled_fp8_quant = scaled_fp8_quant @@ -943,15 +944,22 @@ def grouped_topk(hidden_states: torch.Tensor, else: raise ValueError(f"Unsupported scoring function: {scoring_func}") + # Chendi: Important fix, avoid fusedKernel due to acc issue + if current_platform.is_hpu(): + htorch.core.mark_step() + + num_token = scores.shape[0] if e_score_correction_bias is not None: # Store original scores before applying correction bias. 
We use biased # scores for expert selection but original scores for routing weights original_scores = scores scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] - num_token = scores.shape[0] - group_scores = scores.view(num_token, num_expert_group, - -1).max(dim=-1).values # [n, n_group] group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] @@ -959,7 +967,10 @@ def grouped_topk(hidden_states: torch.Tensor, score_mask = group_mask.unsqueeze(-1).expand( num_token, num_expert_group, scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + + # Chendi: pick up upstream fix + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] if e_score_correction_bias is not None: topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 8cfac29cdf2a..9e88d649fcda 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -6,6 +6,7 @@ import torch +import vllm.envs as envs from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -105,20 +106,19 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: else: raise NotImplementedError("CPU MOE only supports x86 arch.") - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None - ) -> torch.Tensor: + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + ep_rank: Optional[int] = None) -> torch.Tensor: return self.forward(x=x, layer=layer, router_logits=router_logits, @@ -129,7 +129,8 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + ep_rank=ep_rank) def forward_cuda( self, @@ -175,14 +176,38 @@ def forward_hpu( topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, **kwargs, ): - assert not use_grouped_topk, 'use_grouped_topk must be False on HPU' - assert num_expert_group is None, ('num_expert_group is ' - 'not supported on HPU') - assert topk_group is None, 'topk_group is not supported on HPU' - if layer is not None: - return layer.hpu_fused_moe(x, router_logits, top_k) + if use_grouped_topk or 
custom_routing_function is not None: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + else: + import torch.nn.functional as F + topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32) + topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) + topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.to(x.dtype) + + topk_ids = topk_ids.view(*x.shape[:-1], -1) + topk_weights = topk_weights.view(*x.shape[:-1], -1) + return layer.moe_op( + x, + topk_ids.to(torch.int64), + topk_weights.to(x.dtype), + permuted_weights=True, + activation="silu", + ) def forward_cpu( self, @@ -289,8 +314,16 @@ def __init__( self.tp_size = (tp_size if tp_size is not None else get_tensor_model_parallel_world_size()) + # Chendi: Temporarily add ep_support using parameter + # It should be gone once version updated to upstream + self.ep_size = self.tp_size \ + if envs.VLLM_ENABLE_EXPERT_PARALLEL else 1 + self.tp_size = self.tp_size // self.ep_size + assert num_experts % self.ep_size == 0 + self.ep_rank = get_tensor_model_parallel_rank() // self.tp_size + self.top_k = top_k - self.num_experts = num_experts + self.num_experts = num_experts // self.ep_size assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results @@ -302,8 +335,23 @@ def __init__( self.topk_group = topk_group self.custom_routing_function = custom_routing_function if is_hpu: - from vllm_hpu_extension.ops import DynamicFusedMOE - self.hpu_fused_moe = DynamicFusedMOE(self.num_experts) + ep_shift = self.ep_rank * self.num_experts + from vllm_hpu_extension.ops import (VllmMixtureOfExpertsOp, + VllmMixtureOfExpertsOpFP8) + experts_min, experts_max = ep_shift, self.num_experts + ep_shift - 1 + if quant_config is not None: + moe_op = VllmMixtureOfExpertsOpFP8( + self.num_experts, + experts_min, + experts_max, + ) + else: + moe_op = VllmMixtureOfExpertsOp( + self.num_experts, + experts_min, + experts_max, + ) + self.moe_op = moe_op self.scoring_func = scoring_func self.e_score_correction_bias = e_score_correction_bias @@ -319,7 +367,7 @@ def __init__( assert self.quant_method is not None moe_quant_params = { - "num_experts": num_experts, + "num_experts": self.num_experts, "hidden_size": hidden_size, "intermediate_size_per_partition": self.intermediate_size_per_partition, @@ -385,7 +433,7 @@ def _load_model_weight_or_group_weight_scale(self, def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str, loaded_weight: torch.Tensor, - tp_rank: int): + tp_rank: int, expert_id: int): # for per channel weight quantization if shard_id == "w2": expert_data.copy_(loaded_weight) @@ -394,7 +442,8 @@ def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank) + tp_rank=tp_rank, + expert_id=expert_id) def _load_w13(self, expert_data: torch.Tensor, @@ -420,9 +469,8 @@ def _load_w13(self, expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) expert_data.copy_(loaded_weight) - if is_hpu: - self.hpu_fused_moe.MoeOp.w13_list[expert_id].set_weight( - orig_exp_data) + 
if is_hpu and isinstance(self.quant_method, UnquantizedFusedMoEMethod): + self.moe_op.w13_list[expert_id].set_weight(orig_exp_data) def _load_w2(self, expert_data: torch.Tensor, @@ -442,8 +490,8 @@ def _load_w2(self, shard_size) # w2, down_proj: Load into only logical weight of w2. expert_data.copy_(loaded_weight) - if is_hpu: - self.hpu_fused_moe.MoeOp.w2_list[expert_id].set_weight(expert_data) + if is_hpu and isinstance(self.quant_method, UnquantizedFusedMoEMethod): + self.moe_op.w2_list[expert_id].set_weight(expert_data) def _load_single_value(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int): @@ -468,6 +516,16 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: + # Chendi: temporarily add ep_support using parameter + # expect it will be gone once rebase to latest community version + tp_rank = get_tensor_model_parallel_rank() + if self.ep_size > 1: + tp_rank = tp_rank // self.ep_size + # now we want to only load weights for current expert group + expert_id = expert_id - self.ep_rank * self.num_experts + if expert_id < 0 or expert_id >= self.num_experts: + return + # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality @@ -488,7 +546,6 @@ def weight_loader(self, param: torch.nn.Parameter, SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} expert_data = param.data[expert_id] - tp_rank = get_tensor_model_parallel_rank() # is_transposed: if the dim to shard the weight # should be flipped. Required by GPTQ, compressed-tensors @@ -538,7 +595,8 @@ def weight_loader(self, param: torch.nn.Parameter, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank) + tp_rank=tp_rank, + expert_id=expert_id) elif quant_method in [ FusedMoeWeightScaleSupported.GROUP.value, FusedMoeWeightScaleSupported.BLOCK.value, @@ -639,7 +697,7 @@ def forward(self, hidden_states: torch.Tensor, scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias) - if self.reduce_results and self.tp_size > 1: + if self.reduce_results and self.tp_size > 1 or self.ep_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 7fa461d9e287..add961291260 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -2,7 +2,7 @@ import itertools from abc import abstractmethod -from typing import Optional +from typing import Callable, Optional import torch import torch.nn.functional as F @@ -184,6 +184,14 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]: raise NotImplementedError + # Chendi: Necessary base func added by INC team + def get_dequant_weights_func( + self, ) -> Optional[Callable[[torch.nn.Module], torch.Tensor]]: + if self.quant_method is not None: + quant_method = self.quant_method + if hasattr(quant_method, "dequant_block_fp8_weight"): + return quant_method.dequant_block_fp8_weight + class ReplicatedLinear(LinearBase): """Replicated linear layer. 
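A minimal usage sketch for the new get_dequant_weights_func hook (assuming a LinearBase layer; the helper name and the fall-back to layer.weight are illustrative, not defined by this patch):

import torch

def get_high_precision_weight(layer: torch.nn.Module) -> torch.Tensor:
    # LinearBase.get_dequant_weights_func() returns
    # quant_method.dequant_block_fp8_weight when the quant method defines it,
    # otherwise None.
    dequant_fn = layer.get_dequant_weights_func()
    if dequant_fn is not None:
        # The callable has signature Callable[[torch.nn.Module], torch.Tensor].
        return dequant_fn(layer)
    # Unquantized layers keep their weight as a regular parameter.
    return layer.weight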
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e0534e154832..c485a86150c4 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -33,6 +33,7 @@ from vllm.platforms import current_platform if current_platform.is_hpu(): + import vllm_hpu_extension.ops as hpu_ops from vllm_hpu_extension.ops import scaled_fp8_quant ops.scaled_fp8_quant = scaled_fp8_quant @@ -170,6 +171,7 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): + layer.quant_config = self.quant_config output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") @@ -257,10 +259,26 @@ def create_weights( else: layer.register_parameter("input_scale", None) + def dequant_block_fp8_weight(self, layer) -> torch.Tensor: + if hasattr(layer, "updated_fp8_weight") and layer.updated_fp8_weight: + return layer.weight + dequant_weight = hpu_ops.dequant_block_fp8_weight_naive( + layer.weight, + layer.weight_scale_inv.data, + self.quant_config.weight_block_size, + original_M=layer.orig_M, + original_N=layer.orig_N, + do_unpad=True, + ) + return dequant_weight + def process_weights_after_loading(self, layer: Module) -> None: # TODO(rob): refactor block quant into separate class. if self.block_quant: assert self.quant_config.activation_scheme == "dynamic" + if current_platform.is_hpu(): + layer = hpu_ops.fp8_block_linear_postprocess_weights(layer) + return if current_platform.is_rocm(): weight, weight_scale_inv, _ = \ normalize_e4m3fn_to_e4m3fnuz( @@ -365,6 +383,18 @@ def apply(self, apply_w8a8_block_fp8_linear) if self.block_quant: assert self.quant_config.weight_block_size is not None + if current_platform.is_hpu(): + return hpu_ops.apply_block_fp8_linear_hpu_dequant( + input=x, + weight=layer.weight, + block_size=self.quant_config.weight_block_size, + weight_scale=layer.weight_scale_inv, + input_scale=layer.input_scale, + bias=bias, + original_M=layer.orig_M, + original_N=layer.orig_N, + do_unpad=True, + ) return apply_w8a8_block_fp8_linear( input=x, weight=layer.weight, @@ -406,7 +436,7 @@ def __init__(self, quant_config: Fp8Config): def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): - + layer.quant_config = self.quant_config if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn if self.block_quant: @@ -528,6 +558,9 @@ def process_weights_after_loading(self, layer: Module) -> None: # TODO (rob): refactor block quant into separate class. 
if self.block_quant: assert self.quant_config.activation_scheme == "dynamic" + if current_platform.is_hpu(): + layer = hpu_ops.fp8_block_moe_prepare_weights(layer) + return if current_platform.is_rocm(): w13_weight, w13_weight_scale_inv, w13_input_scale = \ normalize_e4m3fn_to_e4m3fnuz( @@ -680,6 +713,11 @@ def apply( e_score_correction_bias=e_score_correction_bias, ) + if current_platform.is_hpu(): + return layer.moe_op(x, + topk_weights=topk_weights.to(torch.int64), + topk_ids=topk_ids.to(x.dtype)) + return fused_experts( x, layer.w13_weight, diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 4f67ace2d602..fdf53559ea88 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -607,7 +607,7 @@ def initialize_dummy_weights( """ for param in model.state_dict().values(): if torch.is_floating_point(param): - if current_platform.is_tpu(): + if current_platform.is_tpu() or current_platform.is_hpu(): # XLA device does not support torch.Generator() param.uniform_(low, high) continue diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 054f3e46ed02..53aaa0a0376c 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -117,30 +117,6 @@ def __init__( if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. " "Only silu is supported for now.") - if is_hpu: - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=False, - prefix=f"{prefix}.experts") - else: - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts") self.gate = ReplicatedLinear(config.hidden_size, config.n_routed_experts, @@ -165,6 +141,7 @@ def __init__( num_expert_group=config.n_group, topk_group=config.topk_group, prefix=f"{prefix}.experts", + tp_size=self.tp_size, scoring_func=config.scoring_func, e_score_correction_bias=self.gate.e_score_correction_bias) @@ -180,12 +157,19 @@ def __init__( ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_dim = hidden_states.shape + if hidden_states.dim() == 3: + batch_size, seq_len, hidden_dim = hidden_states.shape + num_tokens = batch_size * seq_len + else: + batch_size, seq_len = None, None + num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) + if batch_size is not None and seq_len is not None: + hidden_states = hidden_states.view(batch_size, seq_len, hidden_dim) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits) * self.routed_scaling_factor @@ -195,6 +179,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = 
tensor_model_parallel_all_reduce( final_hidden_states) + if batch_size is not None: + return final_hidden_states.view(batch_size, seq_len, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim) diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index b6c6ee19ebe0..2fb586180f0e 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -35,6 +35,9 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, if use_v1: logger.info("Using HPUAttentionV1 backend.") return "vllm.v1.attention.backends.hpu_attn.HPUAttentionBackendV1" + if use_mla: + logger.info("Using HPUAttentionMLA backend.") + return "vllm.attention.backends.hpu_attn.HPUMLAAttentionBackend" logger.info("Using HPUAttention backend.") return "vllm.attention.backends.hpu_attn.HPUAttentionBackend" diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 7f4b3c25b75d..d65a51ae497a 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -665,6 +665,7 @@ def __init__( self.kv_cache_dtype, self.block_size, self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, ) if needs_attn_backend else None # Multi-modal data support @@ -750,6 +751,28 @@ def model_is_mrope(self) -> bool: config = self.model_config.hf_config return uses_mrope(config) + # Chendi: Requested to be added by INC team + def _remove_duplicate_submodules_(self, model, inc_config): + # FIXME: (Yi) for deepseek v3 only + self_attn = model.model.layers[0].self_attn + for layer in model.model.layers: + self_attn = layer.self_attn + # delete attrs: q_b_proj, kv_b_proj, o_proj in self_attn, + # as they have been transferred to the MLAImpl. + if hasattr(self_attn, "q_b_proj"): + delattr(self_attn, "q_b_proj") + if hasattr(self_attn, "kv_b_proj"): + delattr(self_attn, "kv_b_proj") + if hasattr(self_attn, "o_proj"): + delattr(self_attn, "o_proj") + + def _inc_preprocess_(self, model: torch.nn.Module, inc_config): + self._remove_duplicate_submodules_(model, inc_config) + + def _is_quant_with_inc(self): + quant_config = os.getenv("QUANT_CONFIG", None) is not None + return (self.model_config.quantization == "inc" or quant_config) + def load_model(self) -> None: import habana_frameworks.torch.core as htcore if self.model_config.quantization == 'inc' or \ @@ -801,19 +824,21 @@ def load_model(self) -> None: ) self.model = self.lora_manager.create_lora_manager(self.model) - if self.model_config.quantization == 'inc': + if self._is_quant_with_inc(): logger.info("Preparing model with INC..") with HabanaMemoryProfiler() as m_inc: from neural_compressor.torch.quantization import ( FP8Config, convert, prepare) config = FP8Config.from_json_file( os.getenv("QUANT_CONFIG", "")) + self._inc_preprocess_(self.model, config) if config.measure: self.model = prepare(self.model, config) elif config.quantize: self.model = convert(self.model, config) htcore.hpu_initialize(self.model, mark_only_scales_as_const=True) + torch.distributed.barrier() self.inc_initialized_successfully = True logger.info("Preparing model with INC took %s", m_inc.get_summary_string()) @@ -1316,6 +1341,7 @@ def _prepare_prompt( slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, enable_kv_scales_calculation=False, + input_positions=input_positions, ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) for t in multi_modal_kwargs: @@ -1603,7 +1629,7 @@ def _prepare_decode( slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, 
enable_kv_scales_calculation=False, - ) + input_positions=input_positions) return PrepareDecodeMetadata(input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, @@ -1909,6 +1935,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: 'block_indices', 'block_offsets', 'block_groups', + 'input_positions', ]) return attention_metadata @@ -2025,8 +2052,15 @@ def warmup_scenario(self, if is_pt_profiler_run and self.is_driver_worker: profiler = setup_profiler() profiler.start() - for _ in range(times): + for time_index in range(times): inputs = self.prepare_model_input(seqs) + # Chendi: Necessary fix for warmup with TP>1 + if time_index == 0: + if self.is_driver_worker: + broadcast_tensor_dict( + {"input_tokens": inputs.input_tokens}, src=0) + else: + broadcast_tensor_dict(src=0) is_single_step = \ self.vllm_config.scheduler_config.num_scheduler_steps == 1 if is_prompt or is_single_step: @@ -2339,7 +2373,7 @@ def finish_measurements(self): finalize_calibration(self.model.model) def shutdown_inc(self): - can_finalize_inc = (self.model_config.quantization == 'inc') and \ + can_finalize_inc = self._is_quant_with_inc() and \ (self.model.model is not None) and \ self.inc_initialized_successfully and \ not getattr(self, "_is_inc_finalized", False) diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index ddd336e3c5be..68847883e18c 100755 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -581,16 +581,25 @@ def _allocate_kv_cache( """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) + if len(kv_cache_shape) == 2: + k_cache_shape = kv_cache_shape[0] + v_cache_shape = kv_cache_shape[1] + else: + k_cache_shape = kv_cache_shape + v_cache_shape = kv_cache_shape kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] dtype = self.dtype if device != 'hpu' and not is_fake_hpu() \ and self.dtype == torch.float8_e4m3fn: dtype = torch.uint8 for _ in range(self.num_attention_layers): - key_cache = torch.zeros(kv_cache_shape, dtype=dtype, device=device) - value_cache = torch.zeros(kv_cache_shape, - dtype=dtype, - device=device) + key_cache = torch.zeros(k_cache_shape, dtype=dtype, device=device) + if v_cache_shape is not None: + value_cache = torch.zeros(v_cache_shape, + dtype=dtype, + device=device) + else: + value_cache = None kv_layer = (key_cache, value_cache) kv_cache.append(kv_layer) return kv_cache diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 38d2b712eff5..16b43e92cc32 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -46,7 +46,11 @@ def _init_attn_metadata_from_tensor_dict( valid_attn_kwargs = {} for field in dataclasses.fields(attn_backend.get_metadata_cls()): if field.name in tensor_dict: - valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) + # Chendi: Cherry-pick from upstream + if field.name == "input_positions": + valid_attn_kwargs[field.name] = tensor_dict[field.name] + else: + valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) tensor_dict["attn_metadata"] = attn_metadata From d273848499f329cc0e15743f8e1c0f948bc90b8d Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Tue, 8 Apr 2025 17:39:32 +0000 Subject: [PATCH 03/20] update for mypy Signed-off-by: Chendi.Xue --- vllm/attention/backends/hpu_attn.py | 22 +++++++++++----------- 
vllm/worker/hpu_worker.py | 6 +++--- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 5034dba7fb05..3cd61d6ca076 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -50,7 +50,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, - ) -> List[Tuple[int, ...]]: + ) -> Tuple[int, ...]: return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size) @@ -94,8 +94,8 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, - ) -> List[Tuple[int, ...]]: - return [(num_blocks, block_size, head_size), None] + ) -> Tuple[int, ...]: + return (num_blocks, block_size, head_size) @staticmethod def swap_blocks( @@ -265,10 +265,10 @@ def forward( return self._forward_decode(q_nope, q_pe, (k_cache, v_cache), attn_metadata, batch_size) - def _forward_prefill(self, q: torch.Tensor, k_c_normed: torch.Tensor, - k_pe: torch.Tensor, - attn_metadata: HPUAttentionMetadata, - batch_size: int) -> torch.Tensor: + def _forward_prefill( # type: ignore + self, q: torch.Tensor, k_c_normed: torch.Tensor, + k_pe: torch.Tensor, attn_metadata: HPUAttentionMetadata, + batch_size: int) -> torch.Tensor: kv_nope = self.kv_b_proj(k_c_normed)[0]\ .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv_nope\ @@ -304,10 +304,10 @@ def _forward_prefill(self, q: torch.Tensor, k_c_normed: torch.Tensor, return self.o_proj(attn_output)[0] - def _forward_decode(self, q_nope: torch.Tensor, q_pe: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: HPUAttentionMetadata, - batch_size: int) -> torch.Tensor: + def _forward_decode( # type: ignore + self, q_nope: torch.Tensor, q_pe: torch.Tensor, + kv_cache: torch.Tensor, attn_metadata: HPUAttentionMetadata, + batch_size: int) -> torch.Tensor: query = torch.cat([q_nope, q_pe], dim=-1) key_cache = kv_cache[0].unsqueeze(2) diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index 68847883e18c..2ce84415e2d5 100755 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -581,9 +581,9 @@ def _allocate_kv_cache( """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) - if len(kv_cache_shape) == 2: - k_cache_shape = kv_cache_shape[0] - v_cache_shape = kv_cache_shape[1] + if self.model_config.use_mla: + k_cache_shape = kv_cache_shape + v_cache_shape = None else: k_cache_shape = kv_cache_shape v_cache_shape = kv_cache_shape From 7cf1dcfd41a86cf802b1fff52f709bb8834d5f9f Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Tue, 8 Apr 2025 18:44:03 +0000 Subject: [PATCH 04/20] fix acc issue Signed-off-by: Chendi.Xue --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- vllm/model_executor/layers/quantization/fp8.py | 6 ++++-- vllm/model_executor/models/deepseek_v2.py | 18 +++--------------- 3 files changed, 8 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9e88d649fcda..26721b4db5e9 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -697,7 +697,7 @@ def forward(self, hidden_states: torch.Tensor, scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias) - if self.reduce_results and self.tp_size > 1 or self.ep_size > 1: + if self.reduce_results and 
(self.tp_size > 1 or self.ep_size > 1): final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c485a86150c4..91bba0cad8de 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -714,9 +714,11 @@ def apply( ) if current_platform.is_hpu(): + topk_ids = topk_ids.view(*x.shape[:-1], -1) + topk_weights = topk_weights.view(*x.shape[:-1], -1) return layer.moe_op(x, - topk_weights=topk_weights.to(torch.int64), - topk_ids=topk_ids.to(x.dtype)) + topk_ids=topk_ids.to(torch.int64), + topk_weights=topk_weights.to(x.dtype)) return fused_experts( x, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 53aaa0a0376c..2e414522339f 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -157,7 +157,7 @@ def __init__( ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if hidden_states.dim() == 3: + if is_hpu: batch_size, seq_len, hidden_dim = hidden_states.shape num_tokens = batch_size * seq_len else: @@ -168,8 +168,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - if batch_size is not None and seq_len is not None: - hidden_states = hidden_states.view(batch_size, seq_len, hidden_dim) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits) * self.routed_scaling_factor @@ -179,7 +177,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) - if batch_size is not None: + if is_hpu: return final_hidden_states.view(batch_size, seq_len, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim) @@ -575,6 +573,7 @@ def __init__( eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.prefix = prefix def forward( self, @@ -584,8 +583,6 @@ def forward( attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: - if is_hpu: - _batch_size = positions.shape[0] # Self Attention if residual is None: residual = hidden_states @@ -603,16 +600,7 @@ def forward( # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) - if is_hpu: - # need reshape from tensor(x0, y0) to tensor(x1) for hpu - hidden_states = hidden_states.reshape( - hidden_states.shape[0] * hidden_states.shape[1], - hidden_states.shape[2]) hidden_states = self.mlp(hidden_states) - if is_hpu: - hidden_states = hidden_states.reshape( - _batch_size, hidden_states.shape[0] // _batch_size, - hidden_states.shape[1]) return hidden_states, residual From 4560a09fd63b39bdcd32c47bea8f25ca2b7bc1e7 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Tue, 8 Apr 2025 22:01:46 +0000 Subject: [PATCH 05/20] fix mypy Signed-off-by: Chendi.Xue --- vllm/model_executor/layers/linear.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index add961291260..4be7b64af131 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -191,6 +191,7 @@ def get_dequant_weights_func( quant_method = self.quant_method if hasattr(quant_method, 
"dequant_block_fp8_weight"): return quant_method.dequant_block_fp8_weight + return None class ReplicatedLinear(LinearBase): From 85f06932eeb7afd2c10d7333d2aeb76c830f6d85 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Tue, 8 Apr 2025 17:03:04 +0000 Subject: [PATCH 06/20] update vllm-hpu-extension comit id for test Signed-off-by: Chendi.Xue --- requirements-hpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-hpu.txt b/requirements-hpu.txt index 5aeecab4a1db..f7fd372bb982 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@3e0fb39 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@35c8288 From ff61f890d5fdd4dd149304e56173b8e14f0fb3a9 Mon Sep 17 00:00:00 2001 From: Bartosz Myrcha Date: Wed, 9 Apr 2025 11:21:02 +0200 Subject: [PATCH 07/20] [SW-224648] Fix test logs redirection (#1027) Fixed test logs redirection --- .github/workflows/trigger_jenkins.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/trigger_jenkins.yml b/.github/workflows/trigger_jenkins.yml index c9f199ce9549..e86a8e6a333a 100644 --- a/.github/workflows/trigger_jenkins.yml +++ b/.github/workflows/trigger_jenkins.yml @@ -220,14 +220,14 @@ jobs: RELEASED_SYNAPSE_VERSION: ${{ vars.RELEASED_SYNAPSE_VERSION }} BASE_BRANCH: ${{ needs.read_codeowners.outputs.pr_branch }} run: | - LOG_REDIRECTION=">" + LOG_REDIRECTION="\&>" version_regex='^v([0-9]+)\.([0-9]+)\.([0-9]+)_next$' if [[ $TARGET_BRANCH =~ $version_regex ]]; then synapse_version=${TARGET_BRANCH#v} synapse_version=${synapse_version%_*} else synapse_version=${RELEASED_SYNAPSE_VERSION#v} - LOG_REDIRECTION="\| tee" + LOG_REDIRECTION="2>\&1 \| tee" fi echo "Using SynapseAI version ${synapse_version}" synapse_build=$(curl "https://dms.habana-labs.com/api/v1.1/branch/info/v$synapse_version" | jq -r ".release_id") @@ -248,6 +248,7 @@ jobs: run: | random_string=$(tr -dc 'a-z0-9' Date: Wed, 9 Apr 2025 16:00:57 +0200 Subject: [PATCH 08/20] [SW-225233] Adjust method of getting synapse_build (#1045) Adjusted method of extracting synapse build id for release branches --- .github/workflows/trigger_jenkins.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/trigger_jenkins.yml b/.github/workflows/trigger_jenkins.yml index e86a8e6a333a..ed5429e24517 100644 --- a/.github/workflows/trigger_jenkins.yml +++ b/.github/workflows/trigger_jenkins.yml @@ -225,12 +225,14 @@ jobs: if [[ $TARGET_BRANCH =~ $version_regex ]]; then synapse_version=${TARGET_BRANCH#v} synapse_version=${synapse_version%_*} + synapse_build_endpoint="https://dms.habana-labs.com/api/v1.1/guide/info/${synapse_version}/latest?type=docker-pt" else synapse_version=${RELEASED_SYNAPSE_VERSION#v} LOG_REDIRECTION="2>\&1 \| tee" + synapse_build_endpoint="https://dms.habana-labs.com/api/v1.1/branch/info/v${synapse_version}" fi - echo "Using SynapseAI version ${synapse_version}" - synapse_build=$(curl "https://dms.habana-labs.com/api/v1.1/branch/info/v$synapse_version" | jq -r ".release_id") + echo "Using SynapseAI version ${synapse_version}" + synapse_build=$(curl "${synapse_build_endpoint}" | jq -r ".release_id") pt_version=${{ vars.PT_VERSION }} BUILD_TAG="Github-vLLM-Fork-${{ github.event.number }}-${{github.run_number}}" safe_cmd=${TEST_COMMAND//&/\\&} From 5a9ddfdaaa144ab80b4f1b067d30f29848a8ffec Mon Sep 17 00:00:00 2001 From: Jakub 
Maksymczuk Date: Thu, 10 Apr 2025 14:13:02 +0200 Subject: [PATCH 09/20] Implement Pipeline Parallelism support for HPU. (#1000) (#1040) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements HPU support for pipeline parallelism. Tested accuracy and it's the same as TP accuracy on: - Llama3.1-70b-Instruct - Llama3.2-3b-Instruct - Mixtral-8x7b To serve with PP: `VLLM_DECODE_BS_BUCKET_MIN=384 VLLM_DECODE_BLOCK_BUCKET_MAX=896 vllm serve /mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-70B-Instruct/ --tensor-parallel-size 1 --pipeline-parallel-size 4 --max-num-seqs 384 --disable-log-requests --dtype bfloat16 --gpu-memory-util 0.9 --disable-log-stats --num_scheduler_steps 1 --max-num-batched-tokens 2048 --max-model-len 256 --block-size 128` Known issues: * since for Pipeline Parallelism max_num_seqs acts as a microbatch for a single virtual_engine - for bigger batch_size we fall into a very specific corner case and get flat_pa error -> set batch_size to approximately batch size that you would use in TP but divided by pp_size * delayed sampling is not yet compatible with pipeline parallelism * virtaul_engine ID is passed to HPUGraph which results in pp_size * amount of graphs Signed-off-by: jmaksymczuk Co-authored-by: Rafal Litka Co-authored-by: Michał Kuligowski --- vllm/model_executor/models/utils.py | 5 ++++- vllm/sequence.py | 3 +++ vllm/worker/hpu_model_runner.py | 27 ++++++++++++++++++++++++--- vllm/worker/hpu_worker.py | 11 ++++++++--- 4 files changed, 39 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 9925fe16d39c..806d92975c18 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -606,12 +606,15 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): def make_empty_intermediate_tensors( batch_size: int, + context_size: int, dtype: torch.dtype, device: torch.device, ) -> IntermediateTensors: return IntermediateTensors({ key: - torch.zeros((batch_size, hidden_size), dtype=dtype, device=device) + torch.zeros((batch_size, context_size, hidden_size), + dtype=dtype, + device=device) for key in keys }) diff --git a/vllm/sequence.py b/vllm/sequence.py index c4c5a131b0c2..aabf8c9a9309 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1146,6 +1146,9 @@ def __eq__(self, other: object): def __repr__(self) -> str: return f"IntermediateTensors(tensors={self.tensors})" + def __iter__(self): + return iter(self.tensors) + class PoolerOutput( msgspec.Struct, diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 7f4b3c25b75d..7fcb9a8571e5 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -32,7 +32,7 @@ from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.hpu_attn import HPUAttentionImpl from vllm.config import DeviceConfig, VllmConfig -from vllm.distributed import broadcast_tensor_dict +from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.distributed.parallel_state import get_world_group from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry @@ -421,6 +421,8 @@ def forward(self, *args, **kwargs): with set_forward_context(kwargs['attn_metadata'], self.vllm_config, virtual_engine): hidden_states = self.model(*args, **kwargs) + if not get_pp_group().is_last_rank: + return hidden_states hidden_states = hidden_states.view(-1, 
hidden_states.shape[-1]) if selected_token_indices is not None: hidden_states = hidden_states.index_select( @@ -433,6 +435,9 @@ def compute_logits(self, *args, **kwargs): def sample(self, *args, **kwargs): return self.model.sample(*args, **kwargs) + def make_empty_intermediate_tensors(self, *args, **kwargs): + return self.model.make_empty_intermediate_tensors(*args, **kwargs) + def generate_proposals(self, *args, **kwargs): return self.model.generate_proposals(*args, **kwargs) @@ -1949,7 +1954,7 @@ def profile_run(self) -> None: kv_caches = [None] * num_layers bind_kv_cache( self.vllm_config.compilation_config.static_forward_context, - [kv_caches]) + [kv_caches] * self.parallel_config.pipeline_parallel_size) _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() max_batch_size = min(self.max_num_seqs, self.max_num_batched_tokens // max_seq_len) @@ -2030,7 +2035,18 @@ def warmup_scenario(self, is_single_step = \ self.vllm_config.scheduler_config.num_scheduler_steps == 1 if is_prompt or is_single_step: - self.execute_model(inputs, kv_caches, warmup_mode=True) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = \ + self.model.make_empty_intermediate_tensors( + batch_size=batch_size, + context_size=seq_len if is_prompt else 1, + dtype=self.model_config.dtype, + device=self.device) + self.execute_model(inputs, + kv_caches, + intermediate_tensors=intermediate_tensors, + warmup_mode=True) else: # decode with multi-step inputs = dataclasses.replace(inputs, is_first_multi_step=True, @@ -2528,6 +2544,9 @@ def execute_model( use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode assert not (use_delayed_sampling and num_steps != 1), \ 'Delayed sampling is not compatible with MSS!' + assert not (use_delayed_sampling and + self.parallel_config.pipeline_parallel_size != 1), \ + 'Delayed sampling is not compatible with Pipeline Parallelism!' assert model_input.input_tokens is not None if use_delayed_sampling and not model_input.is_prompt and \ self.is_driver_worker: @@ -2684,6 +2703,8 @@ def try_revert_dummy_output_tokens(): LoraMask.setLoraMask( lora_logits_mask.index_select( 0, sampling_metadata.selected_token_indices)) + if not get_pp_group().is_last_rank: + return hidden_states # Compute the logits. with self.profiler.record_event( diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index ddd336e3c5be..f20c5adb258c 100755 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -20,7 +20,7 @@ import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, +from vllm.distributed import (ensure_model_parallel_initialized, get_pp_group, init_distributed_environment) from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -63,8 +63,9 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.is_driver_worker = is_driver_worker - if self.is_driver_worker: - assert self.rank == 0, "The driver worker must have rank 0." + if self.parallel_config and self.is_driver_worker: + assert self.rank % self.parallel_config.tensor_parallel_size == 0, \ + "The driver worker must have rank 0." 
if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing @@ -526,6 +527,10 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + if parallel_config.pipeline_parallel_size > 1: + # torch-ccl xpu need a collective API warm up + # before calling send/recv API + get_pp_group().all_reduce(torch.zeros(1).to('hpu')) if torch.distributed.is_initialized(): torch_world_size = torch.distributed.get_world_size() if torch_world_size != parallel_config.world_size: From ed47e1efba0e1e21d19bea5d3a4762139417622c Mon Sep 17 00:00:00 2001 From: Michal Adamczyk Date: Thu, 10 Apr 2025 14:16:50 +0200 Subject: [PATCH 10/20] [1.21 cherry-pick] Fix async callback ordering (#1023) (#1028) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cherry-pick of #1023 Co-authored-by: Michał Kuligowski --- vllm/worker/hpu_model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 7fcb9a8571e5..2496929b2d6d 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -2725,6 +2725,8 @@ def try_revert_dummy_output_tokens(): if use_delayed_sampling: fake_output = self._delayed_sampler_outputs(model_input) + elif model_input.async_callback is not None: + model_input.async_callback() with self.profiler.record_event( 'internal', ('sample_' @@ -2746,7 +2748,8 @@ def try_revert_dummy_output_tokens(): self.cached_step_outputs.append(output) self.cached_step_inputs.append(model_input) htorch.core.mark_step() - if model_input.async_callback is not None: + if use_delayed_sampling \ + and model_input.async_callback is not None: model_input.async_callback() if i < num_steps - 1: if i == 0: From 9a06a89bd71607ab3568c8b1e0e7ad764392042d Mon Sep 17 00:00:00 2001 From: Michal Adamczyk Date: Thu, 10 Apr 2025 14:17:25 +0200 Subject: [PATCH 11/20] [1.21 cherry-pick] Make lazy mode autodetection more robust (#1038) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cherry-pick of #921 Co-authored-by: Konrad Zawora Co-authored-by: Michał Kuligowski --- vllm/plugins/__init__.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 389cb8728103..057fbb3528a7 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -64,18 +64,10 @@ def load_general_plugins(): # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158 torch._dynamo.config.disable = True elif current_platform.is_hpu(): - # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1) - # does not support torch.compile - # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for - # torch.compile support - is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1' - if is_lazy: + os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true' + import habana_frameworks.torch as htorch + if htorch.utils.internal.is_lazy(): torch._dynamo.config.disable = True - # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only) - # requires enabling lazy collectives - # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501 - os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true' - plugins = load_plugins_by_group(group='vllm.general_plugins') # general plugins, we only need to execute 
the loaded functions for func in plugins.values(): From a93c26abeb0bfa7751c67621886f4961d04798f1 Mon Sep 17 00:00:00 2001 From: kwisniewski98 Date: Fri, 11 Apr 2025 18:54:58 +0300 Subject: [PATCH 12/20] Add temporary workaround for V1 Signed-off-by: kwisniewski98 --- vllm/v1/attention/backends/hpu_attn.py | 13 ++++++++++--- vllm/v1/worker/hpu_model_runner.py | 7 ++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/vllm/v1/attention/backends/hpu_attn.py b/vllm/v1/attention/backends/hpu_attn.py index 353a53b4a436..cfbff5e8e14b 100644 --- a/vllm/v1/attention/backends/hpu_attn.py +++ b/vllm/v1/attention/backends/hpu_attn.py @@ -30,6 +30,8 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: @dataclass class HPUAttentionMetadataV1(HPUAttentionMetadata): + # TODO(kwisniewski98): for now, in V1 input positions are not provided + # which needs to be fixed in the future, as we need to support MLA """Metadata for HPUAttentionbackend.""" is_prompt: bool attn_bias: Optional[torch.Tensor] @@ -39,7 +41,8 @@ class HPUAttentionMetadataV1(HPUAttentionMetadata): @classmethod def make_prefill_metadata(cls, seq_lens_tensor, num_prefills, - num_prefill_tokens, slot_mapping): + input_positions, num_prefill_tokens, + slot_mapping): return cls(is_prompt=True, block_list=None, block_mapping=None, @@ -53,6 +56,7 @@ def make_prefill_metadata(cls, seq_lens_tensor, num_prefills, multi_modal_placeholder_index_maps=None, seq_lens_tensor=seq_lens_tensor, num_prefills=num_prefills, + input_positions=input_positions, num_prefill_tokens=num_prefill_tokens, slot_mapping=slot_mapping, enable_kv_scales_calculation=False) @@ -60,7 +64,8 @@ def make_prefill_metadata(cls, seq_lens_tensor, num_prefills, @classmethod def make_cached_prefill_metadata(cls, seq_lens_tensor, context_lens_tensor, num_prefills, num_prefill_tokens, - slot_mapping, block_list): + input_positions, slot_mapping, + block_list): return cls(is_prompt=True, block_list=block_list, block_mapping=None, @@ -75,12 +80,13 @@ def make_cached_prefill_metadata(cls, seq_lens_tensor, context_lens_tensor, seq_lens_tensor=seq_lens_tensor, num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, + input_positions=input_positions, slot_mapping=slot_mapping, enable_kv_scales_calculation=False) @classmethod def make_decode_metadata(cls, block_list, block_usage, block_groups, - num_decode_tokens, slot_mapping): + input_positions, num_decode_tokens, slot_mapping): return cls(is_prompt=False, block_mapping=None, block_indices=None, @@ -94,6 +100,7 @@ def make_decode_metadata(cls, block_list, block_usage, block_groups, block_list=block_list, block_usage=block_usage, block_groups=block_groups, + input_positions=input_positions, num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, enable_kv_scales_calculation=False) diff --git a/vllm/v1/worker/hpu_model_runner.py b/vllm/v1/worker/hpu_model_runner.py index 246bcf99ecae..ad82d22d1251 100644 --- a/vllm/v1/worker/hpu_model_runner.py +++ b/vllm/v1/worker/hpu_model_runner.py @@ -1212,12 +1212,12 @@ def _prepare_prefill_inputs(self, slot_mapping, self.device) logits_indices_device = _async_h2d_tensor_copy( logits_indices, self.device) - prefill_request_ids.append(batch_req_ids) prefill_prompt_lens.append(batch_num_scheduled_tokens) prefill_token_ids.append(token_ids_device) prefill_position_ids.append(positions_device) prefill_logits_indices.append(logits_indices_device) + attn_metadata = None if use_prefix_prefill: # Prefix caching @@ -1247,6 +1247,7 @@ def _prepare_prefill_inputs(self, 
context_lens_tensor=context_lens_tensor_device, num_prefills=num_prefills, num_prefill_tokens=sum(batch_num_scheduled_tokens), + input_positions=None, slot_mapping=slot_mapping_device, block_list=block_list_device) else: @@ -1254,6 +1255,7 @@ def _prepare_prefill_inputs(self, seq_lens_tensor=seq_lens_tensor_device, num_prefills=num_prefills, num_prefill_tokens=sum(batch_num_scheduled_tokens), + input_positions=None, slot_mapping=slot_mapping_device, ) # ATTN_METADATA. @@ -1379,6 +1381,7 @@ def _prepare_decode_inputs(self, block_list=block_list_device, block_usage=block_usage_device, block_groups=block_groups_device, + input_positions=None, num_decode_tokens=num_decode_tokens_device, slot_mapping=slot_mapping_device, )) @@ -1747,6 +1750,7 @@ def warmup_scenario(self, batch_size, seq_or_block, is_prompt, seq_lens_tensor=seq_lens_device, num_prefills=batch_size, num_prefill_tokens=batch_size * seq_or_block, + input_positions=None, slot_mapping=slot_mapping_device) else: block_tables = [ @@ -1768,6 +1772,7 @@ def warmup_scenario(self, batch_size, seq_or_block, is_prompt, block_usage=block_usage_device, block_groups=block_groups_device, num_decode_tokens=batch_size, + input_positions=None, slot_mapping=slot_mapping_device) logits_indices = torch.arange(0, batch_size, device='cpu') From 035db32a278fa17eb864411b3a5660cb284371e8 Mon Sep 17 00:00:00 2001 From: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Date: Fri, 11 Apr 2025 20:08:16 +0200 Subject: [PATCH 13/20] APC - Remove prompt attn with context and use existing implementation (#1059) Same PR as [1020](https://github.com/HabanaAI/vllm-fork/pull/1020) but for 1.21 --- requirements-hpu.txt | 2 +- vllm/attention/backends/hpu_attn.py | 56 +++++++++++++--------------- vllm/attention/ops/hpu_paged_attn.py | 4 -- 3 files changed, 27 insertions(+), 35 deletions(-) diff --git a/requirements-hpu.txt b/requirements-hpu.txt index 5aeecab4a1db..f36de76312ff 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@3e0fb39 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@7208458 diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index fd005dded2e9..956c5efd7f8d 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -241,45 +241,40 @@ def forward( attn_bias.shape[-1]) attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) attn_bias.add_(position_bias) - if attn_metadata is None or attn_metadata.block_list is None: - out = ops.prompt_attention( - impl=self.prefill_impl, - query=query.view(query_shape), - key=key.view(kv_shape), - value=value.view(kv_shape), - is_causal=True, - attn_bias=attn_bias, - valid_seq_lengths=attn_metadata.seq_lens_tensor, - **self.common_attention_args()) - else: - # TODO: enable FusedSDPA - out = HPUPagedAttention.forward_prefix( - query=query.view(query_shape), - key=key.view(kv_shape), - value=value.view(kv_shape), - key_cache=key_cache, - value_cache=value_cache, - block_list=attn_metadata.block_list, - attn_bias=attn_metadata.attn_bias, - **self.common_attention_args()) + + block_list = attn_metadata.block_list if attn_metadata \ + and attn_metadata.block_list is not None else None + + out = ops.prompt_attention( + impl=self.prefill_impl, + query=query.view(query_shape), + key=key.view(kv_shape), + value=value.view(kv_shape), + is_causal=True, + 
attn_bias=attn_bias, + valid_seq_lengths=attn_metadata.seq_lens_tensor, + **self.common_attention_args(block_list, key_cache, + value_cache)) output = out.reshape(batch_size, seq_len, hidden_size) else: # Decoding run. output = HPUPagedAttention.forward_decode( query=query, - key_cache=key_cache, - value_cache=value_cache, - block_list=attn_metadata.block_list, block_mapping=attn_metadata.block_mapping, block_bias=attn_metadata.attn_bias, block_groups=attn_metadata.block_groups, - **self.common_attention_args()) + **self.common_attention_args(attn_metadata.block_list, + key_cache, value_cache)) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) - def common_attention_args(self): + def common_attention_args(self, + block_list=None, + key_cache=None, + value_cache=None): fsdpa_op = self.fused_scaled_dot_product_attention.apply \ if self.fused_scaled_dot_product_attention is not None else None + return { 'scale': self.scale, 'matmul_qk_op': self.matmul_qk, @@ -290,6 +285,9 @@ def common_attention_args(self): 'keys_fetch_func': self.k_cache.fetch_from_cache, 'values_fetch_func': self.v_cache.fetch_from_cache, 'softmax_op': self.softmax, + 'block_list': block_list, + 'key_cache': key_cache, + 'value_cache': value_cache, } def forward_encoder_decoder( @@ -371,13 +369,11 @@ def forward_encoder_decoder( # Decoding run. output = HPUPagedAttention.forward_decode( query=query, - key_cache=key_cache, - value_cache=value_cache, - block_list=block_list, block_mapping=block_mapping, block_bias=attn_bias, block_groups=block_groups, - **self.common_attention_args()) + **self.common_attention_args(block_list, key_cache, + value_cache)) # Reshape the output tensor. return output.view(batch_size, -1, hidden_size) diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py index 5c59917500a7..d81aee5e4b4e 100644 --- a/vllm/attention/ops/hpu_paged_attn.py +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -63,10 +63,6 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor, def forward_decode(**kwargs) -> torch.Tensor: return ops.flat_pa(**kwargs) - @staticmethod - def forward_prefix(**kwargs) -> torch.Tensor: - return ops.prompt_attention_with_context(**kwargs) - @staticmethod def swap_blocks( src_kv_cache: Tuple[torch.Tensor, torch.Tensor], From 496938dfe435e1bb846366cc10182dc3a2dd1b66 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Sat, 12 Apr 2025 08:37:24 +0300 Subject: [PATCH 14/20] Resolve review comments Signed-off-by: Chendi Xue --- vllm/attention/backends/hpu_attn.py | 5 +---- vllm/model_executor/models/deepseek_v2.py | 12 +++--------- vllm/worker/hpu_worker.py | 8 ++------ 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 3cd61d6ca076..7d6932d8a784 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -183,7 +183,7 @@ def __init__( ] if any(unsupported_features): raise NotImplementedError( - "TritonMLAImpl does not support one of the following: " + "HPUMLAImpl does not support one of the following: " "alibi_slopes, sliding_window, blocksparse_params, " "logits_soft_cap") @@ -213,8 +213,6 @@ def forward( k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim) - # Restore head dim (for rotary embedding) - # k_pe = k_pe.unsqueeze(1) assert hasattr(attn_metadata, "input_positions"), f"attn meta: {attn_metadata}" @@ -235,7 +233,6 @@ def forward( q_pe = q[..., self.qk_nope_head_dim:] input_positions = 
attn_metadata.input_positions.view(-1) - # TODO(lucas): there must be a nicer way to write this line q[..., self.qk_nope_head_dim:], k_pe = \ self.rotary_emb(input_positions, q_pe, k_pe) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 2e414522339f..70fca4ff18c1 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -157,12 +157,8 @@ def __init__( ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if is_hpu: - batch_size, seq_len, hidden_dim = hidden_states.shape - num_tokens = batch_size * seq_len - else: - batch_size, seq_len = None, None - num_tokens, hidden_dim = hidden_states.shape + input_shape = hidden_states.shape + hidden_dim = input_shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) @@ -177,9 +173,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) - if is_hpu: - return final_hidden_states.view(batch_size, seq_len, hidden_dim) - return final_hidden_states.view(num_tokens, hidden_dim) + return final_hidden_states.view(*input_shape) def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index 2ce84415e2d5..669fa43dfef6 100755 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -581,12 +581,8 @@ def _allocate_kv_cache( """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) - if self.model_config.use_mla: - k_cache_shape = kv_cache_shape - v_cache_shape = None - else: - k_cache_shape = kv_cache_shape - v_cache_shape = kv_cache_shape + k_cache_shape = kv_cache_shape + v_cache_shape = None if self.model_config.use_mla else kv_cache_shape kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] dtype = self.dtype if device != 'hpu' and not is_fake_hpu() \ From a6358a5a7da5ab562901fb3f37a12c53046d07b0 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Sat, 12 Apr 2025 08:50:36 +0300 Subject: [PATCH 15/20] update dependent vllm-hpu-extension Signed-off-by: Chendi Xue --- requirements-hpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-hpu.txt b/requirements-hpu.txt index f7fd372bb982..941fa5d6bddc 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@35c8288 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@8ced45a From b576015b80e7a4f857b84ce6b80a8894e8ad74b3 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Sat, 12 Apr 2025 08:37:21 +0200 Subject: [PATCH 16/20] Cherry pick exponential bucketing integration from #642 (#1067) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Iryna Boiko Co-authored-by: Michał Kuligowski --- .github/workflows/add_label_automerge.yml | 21 ------------------- README_GAUDI.md | 3 ++- requirements-hpu.txt | 2 +- vllm/core/scheduler.py | 25 ++++++----------------- vllm/v1/worker/hpu_model_runner.py | 4 ++-- vllm/worker/hpu_model_runner.py | 16 +++++++-------- 6 files changed, 19 insertions(+), 52 deletions(-) delete mode 100644 .github/workflows/add_label_automerge.yml diff --git 
a/.github/workflows/add_label_automerge.yml b/.github/workflows/add_label_automerge.yml deleted file mode 100644 index c9d6d4259df9..000000000000 --- a/.github/workflows/add_label_automerge.yml +++ /dev/null @@ -1,21 +0,0 @@ -name: Add label on auto-merge enabled -on: - pull_request_target: - types: - - auto_merge_enabled -jobs: - add-label-on-auto-merge: - runs-on: ubuntu-latest - steps: - - name: Add label - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 - with: - script: | - github.rest.issues.addLabels({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - labels: ['ready'] - }) - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/README_GAUDI.md b/README_GAUDI.md index ed635b19796e..5980357f2ea1 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -343,7 +343,8 @@ INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of devi - `VLLM_GRAPH_PROMPT_RATIO`: percentage of reserved graph memory dedicated for prompt graphs, `0.3` by default. - `VLLM_GRAPH_PROMPT_STRATEGY`: strategy determining order of prompt graph capture, `min_tokens` or `max_bs`, `min_tokens` by default. - `VLLM_GRAPH_DECODE_STRATEGY`: strategy determining order of decode graph capture, `min_tokens` or `max_bs`, `max_bs` by default. -- `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment variables configuring ranges of bucketing mechanism. +- `VLLM_EXPONENTIAL_BUCKETING`, if `true`, enables exponential bucket spacing instead of linear (experimental). +- `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment variables configuring ranges of bucketing mechanism (linear bucketing only). - `{phase}` is either `PROMPT` or `DECODE` - `{dim}` is either `BS`, `SEQ` or `BLOCK` - `{param}` is either `MIN`, `STEP` or `MAX` diff --git a/requirements-hpu.txt b/requirements-hpu.txt index f36de76312ff..00a8f56d2d1f 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@7208458 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@f855191 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 3c2fb779d5d4..f773c49ff488 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -132,25 +132,12 @@ def _generic_padding_fn(self, batch_size, max_seq_len) -> int: return batch_size * max_seq_len def _hpu_padding_fn(self, batch_size, max_seq_len): - from vllm_hpu_extension.bucketing import (HPUBucketingGlobalState, - find_bucket) - padded_bs = batch_size - padded_seq = max_seq_len - - hpu_bucketing_global_state = HPUBucketingGlobalState() - - bs_cfg = hpu_bucketing_global_state.prompt_bs_bucket_cfg - if bs_cfg is not None: - padded_bs = find_bucket(batch_size, bs_cfg) - else: - logger.warning( - "prompt_bs_bucket_cfg was not set! Using unpadded batch size.") - seq_cfg = hpu_bucketing_global_state.prompt_seq_bucket_cfg - if seq_cfg is not None: - padded_seq = find_bucket(max_seq_len, seq_cfg) - else: - logger.warning("prompt_seq_bucket_cfg was not set! 
" - "Using unpadded sequence length.") + from vllm_hpu_extension.bucketing.common import get_bucketing_context + hpu_bucketing_context = get_bucketing_context().get_instance() + padded_bs = hpu_bucketing_context.get_padded_prompt_batch_size( + batch_size) + padded_seq = hpu_bucketing_context.get_padded_prompt_seq_len( + max_seq_len) return padded_bs * padded_seq def _padding_fn_selector(self): diff --git a/vllm/v1/worker/hpu_model_runner.py b/vllm/v1/worker/hpu_model_runner.py index 246bcf99ecae..80c8e221dee7 100644 --- a/vllm/v1/worker/hpu_model_runner.py +++ b/vllm/v1/worker/hpu_model_runner.py @@ -17,7 +17,6 @@ import torch import torch.distributed import vllm_hpu_extension.environment as environment -from vllm_hpu_extension.bucketing import HPUBucketingContext from vllm_hpu_extension.flags import enabled_flags from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes @@ -45,6 +44,7 @@ if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput +from vllm_hpu_extension.bucketing.common import get_bucketing_context logger = init_logger(__name__) @@ -705,6 +705,7 @@ def __init__( self.seen_configs: set = set() if self.enable_bucketing: logger.info("Bucketing is ON.") + HPUBucketingContext = get_bucketing_context() self.bucketing_ctx = HPUBucketingContext( self.max_num_seqs, self.max_prefill_batch_size, self.block_size, self.scheduler_config.max_num_batched_tokens, @@ -1917,7 +1918,6 @@ def warmup_model(self) -> None: logger.info("Skipping warmup...") return max_blocks = kv_caches[0][0].size(0) - self.bucketing_ctx.generate_prompt_buckets() self.bucketing_ctx.generate_decode_buckets(max_blocks) if not htorch.utils.internal.is_lazy( diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 2496929b2d6d..80bda7407f49 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -22,7 +22,7 @@ import habana_frameworks.torch.internal.bridge_config as bc import torch import vllm_hpu_extension.environment as environment -from vllm_hpu_extension.bucketing import HPUBucketingContext +from vllm_hpu_extension.bucketing.common import get_bucketing_context from vllm_hpu_extension.flags import enabled_flags from vllm_hpu_extension.ops import LoraMask as LoraMask from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, @@ -690,11 +690,13 @@ def __init__( self.profiler_counter_helper = HabanaProfilerCounterHelper() self.seen_configs: set = set() self._mem_margin: Optional[int] = None + HPUBucketingContext = get_bucketing_context() self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs, self.max_num_prefill_seqs, self.block_size, self.max_num_batched_tokens, - self.use_merged_prefill) + self.use_merged_prefill, + self.max_model_len) self.graphed_buckets: Set[Any] = set() self._set_gc_threshold() @@ -1958,7 +1960,6 @@ def profile_run(self) -> None: _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() max_batch_size = min(self.max_num_seqs, self.max_num_batched_tokens // max_seq_len) - self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, False, True) return @@ -2188,6 +2189,10 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): @torch.inference_mode() def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: + if not self.is_pooler: + max_blocks = kv_caches[0][0].size(0) + self.bucketing_ctx.generate_decode_buckets(max_blocks) + if profile := os.environ.get('VLLM_PT_PROFILE', None): phase, bs, seq_len, graph = profile.split('_') is_prompt = phase == 'prompt' @@ -2197,11 
+2202,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, True) raise AssertionError("Finished profiling") - if not self.is_pooler: - max_blocks = kv_caches[0][0].size(0) - self.bucketing_ctx.generate_prompt_buckets() - if not self.is_pooler: - self.bucketing_ctx.generate_decode_buckets(max_blocks) if not htorch.utils.internal.is_lazy() and not self.enforce_eager: multiplier = 3 if os.getenv('VLLM_REGIONAL_COMPILATION', 'true').lower() == 'true' else 1 From b49caca9890008577417a4c0ad4d5bede5fe97f6 Mon Sep 17 00:00:00 2001 From: kwisniewski98 Date: Mon, 14 Apr 2025 18:32:50 +0300 Subject: [PATCH 17/20] Remove o_proj only for deepseek Signed-off-by: kwisniewski98 --- vllm/worker/hpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index c38d9560bf15..68de39df2c55 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -772,7 +772,8 @@ def _remove_duplicate_submodules_(self, model, inc_config): delattr(self_attn, "o_proj") def _inc_preprocess_(self, model: torch.nn.Module, inc_config): - self._remove_duplicate_submodules_(model, inc_config) + if "DeepseekV3ForCausalLM" in self.model.config.architectures: + self._remove_duplicate_submodules_(model, inc_config) def _is_quant_with_inc(self): quant_config = os.getenv("QUANT_CONFIG", None) is not None From 214bcaeb968a9724e233d81efe16f1dbff68aa7e Mon Sep 17 00:00:00 2001 From: kwisniewski98 Date: Mon, 14 Apr 2025 19:05:47 +0300 Subject: [PATCH 18/20] Change vllm-hpu-extension version Signed-off-by: kwisniewski98 --- requirements-hpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-hpu.txt b/requirements-hpu.txt index 12e406be887a..09107c2c551a 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@7208458 \ No newline at end of file +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@aee363c \ No newline at end of file From 9b85748f91bdb45b7c69917fbdf209254b0180e3 Mon Sep 17 00:00:00 2001 From: kwisniewski98 Date: Tue, 15 Apr 2025 14:14:28 +0300 Subject: [PATCH 19/20] Explicitly disable t.compile for deepseek Signed-off-by: kwisniewski98 --- vllm/model_executor/models/deepseek_v2.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 70fca4ff18c1..79716e767e08 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only DeepseekV2/DeepseekV3 model.""" +import os from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch @@ -685,6 +686,13 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + assert os.environ.get("PT_HPU_LAZY_MODE", "0") == "1", \ + ( + "Deepseek currently supports only lazy mode on HPU, " + "please set PT_HPU_LAZY_MODE=1" + ) + config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config From 1d7fb512d38f3910ba7f3102b264972df5e52b5d Mon Sep 17 00:00:00 2001 From: kwisniewski98 Date: Tue, 15 Apr 2025 14:30:16 +0300 Subject: [PATCH 20/20] Change method of checking lazy mode Signed-off-by: kwisniewski98 --- vllm/model_executor/models/deepseek_v2.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 79716e767e08..8d21ab0137d9 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -22,7 +22,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only DeepseekV2/DeepseekV3 model.""" -import os from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch @@ -686,12 +685,14 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + if is_hpu: + import habana_frameworks.torch as htorch - assert os.environ.get("PT_HPU_LAZY_MODE", "0") == "1", \ - ( - "Deepseek currently supports only lazy mode on HPU, " - "please set PT_HPU_LAZY_MODE=1" - ) + assert htorch.utils.internal.is_lazy(), \ + ( + "Deepseek currently supports only lazy mode on HPU, " + "please set PT_HPU_LAZY_MODE=1" + ) config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config