vllm_ascend/torchair/torchair_model_runner.py (7 changes: 6 additions, 1 deletion)
@@ -20,10 +20,11 @@
 from typing import Optional
 
 import torch
+import torch_npu
 from vllm.config import VllmConfig
 from vllm.forward_context import get_forward_context
 
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                maybe_converting_weight_acl_format)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
@@ -113,3 +114,7 @@ def _generate_dummy_run_hidden_states(self, with_prefill,
             with_prefill, is_torchair_compile, input_ids, positions,
             attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
         return hidden_states
+
+    def _convert_torch_format(self, kv_cache):
+        kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
+        return kv_cache
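
Taken together with the base-runner changes below, this override means torchair graph mode always keeps KV-cache tensors in the ND fractal format, while the default runner picks NZ on 310P and ND on other SoCs. A minimal sketch of the resulting dispatch; class names and the import location of is_310p() are taken as assumptions from the diff context, and nothing here is the PR's exact code:

    import torch
    import torch_npu
    from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND,
                                   ACL_FORMAT_FRACTAL_NZ, is_310p)

    # Module-level default, chosen once (mirrors the model_runner_v1.py hunk below).
    ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ if is_310p() else ACL_FORMAT_FRACTAL_ND

    class NPUModelRunner:
        def _convert_torch_format(self, tensor: torch.Tensor) -> torch.Tensor:
            # Base runner: SoC-dependent default (NZ on 310P, ND otherwise).
            return torch_npu.npu_format_cast(tensor, ACL_FORMAT)

    class NPUTorchairModelRunner(NPUModelRunner):
        def _convert_torch_format(self, kv_cache: torch.Tensor) -> torch.Tensor:
            # Torchair graph mode: always ND, even on 310P.
            return torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)

This preserves the behaviour of the per-call check removed further down (NZ only when is_310p() and torchair graph mode is disabled), but moves the decision to class dispatch instead of recomputing it inside initialize_kv_cache.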
vllm_ascend/worker/model_runner_v1.py (24 changes: 12 additions, 12 deletions)
@@ -110,6 +110,9 @@

 if is_310p():
     torch_npu.npu.set_compile_mode(jit_compile=False)
+    ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ
+else:
+    ACL_FORMAT = ACL_FORMAT_FRACTAL_ND
 
 
 @dataclass
@@ -2047,8 +2050,8 @@ def load_model(self) -> None:
                 if isinstance(module,
                               (MergedColumnParallelLinear,
                                QKVParallelLinear, RowParallelLinear)):
-                    module.weight.data = torch_npu.npu_format_cast(
-                        module.weight.data, ACL_FORMAT_FRACTAL_NZ)
+                    module.weight.data = self._convert_torch_format(
+                        module.weight.data)
         if self.drafter:
             logger.info("Loading drafter model...")
             if isinstance(self.drafter, EagleProposer):
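
The load_model hunk above swaps the hard-coded NZ cast for the shared helper, so linear-layer weights follow the same SoC-dependent format as the KV cache. A hedged sketch of how such a pass over the model could look; the function name is hypothetical, the import path for the linear classes is an assumption, and the real code does this inline via self._convert_torch_format:

    import torch
    import torch_npu
    from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                   QKVParallelLinear,
                                                   RowParallelLinear)

    def cast_linear_weights(model: torch.nn.Module, acl_format: int) -> None:
        # Re-cast the weights of the parallel linear layers into the runner's
        # preferred ACL format (NZ on 310P, ND elsewhere).
        for module in model.modules():
            if isinstance(module, (MergedColumnParallelLinear,
                                   QKVParallelLinear, RowParallelLinear)):
                module.weight.data = torch_npu.npu_format_cast(
                    module.weight.data, acl_format)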
@@ -2133,6 +2136,10 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int):
             ge_cache=False)
         return self.torchair_compiled_models[batch_size]
 
+    def _convert_torch_format(self, tensor):
+        tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
+        return tensor
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -2141,9 +2148,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             cache size of each layer
         """
         self.kv_cache_config = kv_cache_config
-        import torch_npu
-        acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p(
-        ) and not self.torchair_graph_enabled else ACL_FORMAT_FRACTAL_ND
         kv_caches: Dict[str, torch.Tensor] = {}
 
         def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
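
The body of align_memory sits outside the hunks shown here. For readers unfamiliar with the pattern, a plausible sketch of such a helper follows; this is an assumption about what it does (return an alignment-adjusted view of an over-allocated flat buffer), not the repository's actual implementation:

    import torch

    def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
        # Return a view of `tensor` whose first element lies on an address that
        # is a multiple of `alignment` bytes; the caller over-allocates by
        # `alignment` elements to leave room for the shift.
        data_ptr = tensor.data_ptr()
        aligned_ptr = (data_ptr + alignment - 1) // alignment * alignment
        offset = (aligned_ptr - data_ptr) // tensor.element_size()
        return tensor[offset:]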
@@ -2202,7 +2206,6 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
                                          kv_cache_spec.head_size)
             dtype = kv_cache_spec.dtype
             if self.model_config.is_deepseek_mla:
-
                 num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
                 rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
                 nope_dim = head_size - rope_dim
@@ -2218,10 +2221,8 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
                     nope_cache = torch.zeros(nope_cache_shape,
                                              dtype=dtype,
                                              device=self.device)
-                    rope_cache = torch_npu.npu_format_cast(
-                        rope_cache, acl_format)
-                    nope_cache = torch_npu.npu_format_cast(
-                        nope_cache, acl_format)
+                    rope_cache = self._convert_torch_format(rope_cache)
+                    nope_cache = self._convert_torch_format(nope_cache)
                 else:
 
                     # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory
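
For a sense of the rope/nope split used in the MLA branch above, here are DeepSeek-style numbers; these values are illustrative assumptions, not read from this PR:

    # Hypothetical DeepSeek-V3-style MLA sizes, for illustration only.
    head_size = 576                   # kv_lora_rank (512) + qk_rope_head_dim (64)
    rope_dim = 64                     # hf_text_config.qk_rope_head_dim
    nope_dim = head_size - rope_dim   # 512
    # Keeping rope and nope in separate caches lets each tensor be passed
    # through self._convert_torch_format(...) on its own, as in the hunk above.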
@@ -2259,8 +2260,7 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
                     kv_cache = torch.zeros(cache_shape,
                                            dtype=dtype,
                                            device=self.device)
-                    kv_cache = torch_npu.npu_format_cast(
-                        kv_cache, acl_format)
+                    kv_cache = self._convert_torch_format(kv_cache)
                 else:
                     cache_size = math.prod(cache_shape)
                     cache_size_aligned = cache_size + alignment