vllm_ascend/attention.py (18 additions, 19 deletions)

@@ -121,7 +121,7 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
-        return (2, num_blocks, block_size, num_kv_heads * head_size)
+        return (2, num_blocks, block_size, num_kv_heads, head_size)

     @staticmethod
     def swap_blocks(
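This shape change is the heart of the diff: keeping num_kv_heads and head_size as separate dimensions means the cache no longer has to be re-viewed on every forward pass. A minimal sketch of the two layouts, with hypothetical sizes not taken from this PR:

    # Illustrative only: sizes are hypothetical.
    import torch

    num_blocks, block_size = 128, 16
    num_kv_heads, head_size = 8, 64

    # Old layout fused the head dims, so every forward had to re-view it:
    fused = torch.empty(2, num_blocks, block_size, num_kv_heads * head_size)
    key_cache = fused[0].view(num_blocks, block_size, num_kv_heads, head_size)

    # New layout keeps heads and head_size separate, so the cache slices can
    # be handed to the NPU kernels directly:
    kv = torch.empty(2, num_blocks, block_size, num_kv_heads, head_size)
    key_cache, value_cache = kv[0], kv[1]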
@@ -512,6 +512,8 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.seq_len_cpu_tensor = None
+        self.key_cache = None
+        self.value_cache = None

     def forward(
         self,
@@ -555,6 +557,11 @@ def forward(
                              dtype=query.dtype,
                              device=query.device)

+        if kv_cache.numel() > 0:
+            if self.key_cache is None:
+                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
+            slots = attn_metadata.slot_mapping
+
         if hasattr(layer, 'quant_method'):
             isPrefill = True if attn_metadata.num_prefills > 0 else False
             if isPrefill:
@@ -570,24 +577,16 @@ def forward(
             block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
             # Details of kv_cache arrangement in attention quantization
             # are implemented by quant_method.
-            layer.quant_method.apply(layer, query, key, value, kv_cache,
-                                     self.scale, self.seq_lens_tensor_cpu,
-                                     block_tables, isPrefill, attn_metadata,
-                                     output)
+            layer.quant_method.apply(layer, query, key, value, self.key_cache,
+                                     self.value_cache, self.scale,
+                                     self.seq_lens_tensor_cpu, block_tables,
+                                     isPrefill, attn_metadata, output)
         else:
-            if kv_cache.numel() > 0:
-                key_cache, value_cache = kv_cache[0], kv_cache[1]
-                num_blocks, block_size, _ = key_cache.shape
-                key_cache = key_cache.view(num_blocks, block_size,
-                                           self.num_kv_heads, self.head_size)
-                value_cache = value_cache.view(num_blocks, block_size,
-                                               self.num_kv_heads,
-                                               self.head_size)
-                slots = attn_metadata.slot_mapping
+            if self.key_cache is not None:
                 torch_npu._npu_reshape_and_cache(key=key,
                                                  value=value,
-                                                 key_cache=key_cache,
-                                                 value_cache=value_cache,
+                                                 key_cache=self.key_cache,
+                                                 value_cache=self.value_cache,
                                                  slot_indices=slots)

         if attn_metadata.num_prefills > 0:
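Taken together, the two hunks above replace per-call slicing and viewing with a bind-once pattern: the first forward that receives a non-empty kv_cache stores the two halves on the instance, and every later call reuses them. A minimal standalone sketch (class and method bodies simplified, not the actual vllm_ascend code):

    # Sketch of the bind-once pattern this PR introduces.
    class AttentionImplSketch:
        def __init__(self):
            self.key_cache = None
            self.value_cache = None

        def forward(self, kv_cache, attn_metadata):
            # Bind the persistent cache halves once, on the first call
            # that receives a real (non-empty) kv_cache tensor.
            if kv_cache.numel() > 0:
                if self.key_cache is None:
                    self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
                slots = attn_metadata.slot_mapping
            # Later branches test self.key_cache instead of re-slicing
            # and re-viewing kv_cache on every forward pass.
            ...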
@@ -617,15 +616,15 @@ def forward(
                     "Prefix cache and chunked prefill are currently not supported."
                 )
         elif attn_metadata.decode_metadata:
-            assert kv_cache is not None
+            assert self.key_cache is not None
             self.seq_lens_tensor_cpu = torch.from_numpy(
                 np.array(attn_metadata.decode_metadata.seq_lens).astype(
                     np.int32))
             block_tables = attn_metadata.decode_metadata.block_tables
             torch_npu._npu_paged_attention(
                 query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
                 num_kv_heads=self.num_kv_heads,
                 num_heads=self.num_heads,
                 scale_value=self.scale,
vllm_ascend/quantization/quant_config.py (4 additions, 4 deletions)

@@ -246,11 +246,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         self.quant_method.process_weights_after_loading(layer)

     def apply(self, layer: torch.nn.Module, query: torch.Tensor,
-              key: torch.Tensor, value: torch.Tensor,
-              kv_cache: List[torch.Tensor], scale: torch.Tensor,
+              key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor,
+              value_cache: torch.Tensor, scale: torch.Tensor,
               seq_lens_tensor_cpu: int, block_tables: torch.Tensor,
               isPrefill: bool, attn_metadata, output) -> torch.Tensor:
-        return self.quant_method.apply(layer, query, key, value, kv_cache,
-                                       scale, seq_lens_tensor_cpu,
+        return self.quant_method.apply(layer, query, key, value, key_cache,
+                                       value_cache, scale, seq_lens_tensor_cpu,
                                        block_tables, isPrefill, attn_metadata,
                                        output)
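For downstream quant-method implementations the change is mechanical: the old two-element kv_cache argument arrives pre-split. An illustrative sketch of how a concrete implementation sees the new arguments (body elided, names mirror the signature above):

    # Illustrative only: a downstream quant_method adapting to the split.
    def apply(self, layer, query, key, value, key_cache, value_cache,
              scale, seq_lens_tensor_cpu, block_tables, isPrefill,
              attn_metadata, output):
        # Before this change the method received kv_cache and split it
        # itself: key_cache, value_cache = kv_cache[0], kv_cache[1].
        # Both caches now arrive already shaped
        # (num_blocks, block_size, num_kv_heads, head_size).
        ...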
vllm_ascend/worker/worker.py (6 additions, 0 deletions)

@@ -266,6 +266,12 @@ def _init_cache_engine(self):
                          self.parallel_config, self.device_config)
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]
+        import torch_npu
+        for ve in range(self.parallel_config.pipeline_parallel_size):
+            num_layers = len(self.cache_engine[ve].gpu_cache)
+            for i in range(num_layers):
+                torch_npu.npu_format_cast(self.cache_engine[ve].gpu_cache[i],
+                                          2)
         self.gpu_cache = [
             self.cache_engine[ve].gpu_cache
             for ve in range(self.parallel_config.pipeline_parallel_size)
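For readers unfamiliar with Ascend storage formats: the literal 2 passed to npu_format_cast is, to my knowledge, the ACL format enum value for ND (plain row-major layout). A minimal sketch of the same cast in isolation (tensor shape hypothetical, requires an NPU device):

    import torch
    import torch_npu  # registers the "npu" device

    ACL_FORMAT_ND = 2  # plain N-dimensional (row-major) storage format

    # Hypothetical cache-shaped tensor on the NPU.
    kv = torch.empty(2, 128, 16, 8, 64, device="npu")
    # npu_format_cast returns a tensor whose NPU-internal storage format
    # is ND; the worker applies this to each layer's gpu_cache entry so
    # the attention kernels see a contiguous layout.
    kv = torch_npu.npu_format_cast(kv, ACL_FORMAT_ND)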