[ascend]feat: support kv int8 #2736

Merged
9 commits merged on Dec 6, 2024
6 changes: 6 additions & 0 deletions docs/en/get_started/ascend/get_started.md
@@ -136,3 +136,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu
```

Please check [supported_models](../../supported_models/supported_models.md) before using this feature.

### int8 KV-cache Quantization

The Ascend backend supports offline int8 KV-cache quantization in eager mode.

Please refer to this [doc](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md) for details.
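
As a usage illustration for this section, a minimal sketch follows. The model path and record-file path are placeholders, the record file itself is generated offline per the dlinfer doc linked above, and `device_type`, `eager_mode`, and `quant_policy` are assumed to be the relevant `PytorchEngineConfig` fields (as touched elsewhere in this PR).

```python
import os

from lmdeploy import PytorchEngineConfig, pipeline

# Hypothetical path; the Ascend backend reads per-layer KV scales/offsets
# from this record file when quant_policy=8 (see op_backend.py below).
os.environ['ASCEND_QUANT_RECORD_FILE'] = '/path/to/ascend_kv_quant_record.txt'

if __name__ == '__main__':
    pipe = pipeline(
        '/path/to/your/model',  # placeholder model path
        backend_config=PytorchEngineConfig(
            device_type='ascend',
            eager_mode=True,  # int8 KV cache is only supported in eager mode
            quant_policy=8,   # 8 -> int8 KV-cache quantization
        ))
    print(pipe(['Hello, introduce yourself.']))
```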
6 changes: 6 additions & 0 deletions docs/zh_cn/get_started/ascend/get_started.md
@@ -133,3 +133,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu
```

Please refer to [supported models](../../supported_models/supported_models.md) for the list of supported models.

### int8 KV-cache Quantization

The Ascend backend now supports offline int8 KV-cache quantization in eager mode.

Please refer to this [doc](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md) for detailed usage.
7 changes: 5 additions & 2 deletions lmdeploy/messages.py
@@ -293,8 +293,11 @@ def __post_init__(self):
assert self.device_type in [
'cuda', 'ascend', 'maca'
], (f'invalid device_type: {self.device_type}')
if self.quant_policy > 0 and self.device_type != 'cuda':
assert False, 'kv cache quantization only works for CUDA.'
if self.quant_policy > 0 and self.device_type not in [
'cuda', 'ascend'
]:
assert False, \
'kv cache quantization only works for CUDA and ASCEND.'


class ResponseType(enum.Enum):
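The user-visible effect of this relaxed check is sketched below, under the assumption that the dataclass validated in this `__post_init__` is `PytorchEngineConfig`.

```python
from lmdeploy import PytorchEngineConfig

# int8 KV cache now passes validation on Ascend as well as CUDA.
cfg = PytorchEngineConfig(device_type='ascend', quant_policy=8)
print(cfg.device_type, cfg.quant_policy)  # ascend 8

# Other device types are still rejected at construction time.
try:
    PytorchEngineConfig(device_type='maca', quant_policy=8)
except AssertionError as err:
    print(err)  # kv cache quantization only works for CUDA and ASCEND.
```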
88 changes: 87 additions & 1 deletion lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
@@ -1,5 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import itertools
import os
import re
from pathlib import Path
from typing import Dict, Tuple

import torch

@@ -11,6 +15,71 @@
logger = get_logger('lmdeploy')


class AscendKVQuantMeta:
has_set_value: bool = False
quant_meta: Dict = {}

@classmethod
def set_value(cls, device: str, dtype: torch.dtype, record_file: str,
total_layers: int):
with open(record_file, 'r') as file:
data = file.read()
scale_offset_pairs = re.findall(
r'scale:\s*([\d\.\-]+)\s*offset:\s*(-?\d+)', data)
scale_offset_pairs = [(float(scale), float(offset))
for scale, offset in scale_offset_pairs]
k_scales, v_scales, kv_scales = [], [], []
k_zeros, v_zeros, kv_zeros = [], [], []
if len(scale_offset_pairs) == total_layers:
for scale, offset in scale_offset_pairs:
k_scales.append(
torch.tensor([scale], device=device, dtype=dtype))
v_scales.append(
torch.tensor([scale], device=device, dtype=dtype))
kv_scales.append(
torch.tensor([scale, scale], device=device, dtype=dtype))
k_zeros.append(
torch.tensor([offset], device=device, dtype=dtype))
v_zeros.append(
torch.tensor([offset], device=device, dtype=dtype))
kv_zeros.append(
torch.tensor([offset, offset], device=device, dtype=dtype))
elif len(scale_offset_pairs) == total_layers * 2:
for i in range(total_layers):
scale_k, offset_k = scale_offset_pairs[2 * i]
scale_v, offset_v = scale_offset_pairs[2 * i + 1]
k_scales.append(
torch.tensor([scale_k], device=device, dtype=dtype))
v_scales.append(
torch.tensor([scale_v], device=device, dtype=dtype))
kv_scales.append(
torch.tensor([scale_k, scale_v],
device=device,
dtype=dtype))
k_zeros.append(
torch.tensor([offset_k], device=device, dtype=dtype))
v_zeros.append(
torch.tensor([offset_v], device=device, dtype=dtype))
kv_zeros.append(
torch.tensor([offset_k, offset_v],
device=device,
dtype=dtype))
else:
raise ValueError(
f'num of scale_offset_pairs({len(scale_offset_pairs)}) '
f'must match num of total_layers({total_layers})')

cls.quant_meta.update({
'k_scales': itertools.cycle(k_scales),
'k_zeros': itertools.cycle(k_zeros),
'v_scales': itertools.cycle(v_scales),
'v_zeros': itertools.cycle(v_zeros),
'kv_scales': itertools.cycle(kv_scales),
'kv_zeros': itertools.cycle(kv_zeros)
})
cls.has_set_value = True


class AscendOpsBackend(DlinferOpsBackend):
"""ascend layer backend."""
enable_graph = False
@@ -164,6 +233,21 @@ def get_total_slots():
.repeat_interleave(step_context.q_seqlens, 0)
kv_seqlens = kv_seqlens_cpu

if not cls.enable_graph and step_context.kv_quant_policy == 8:
record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
assert record_file, 'please specify valid ASCEND_QUANT_RECORD_FILE'
path = Path(record_file)
is_path = path.is_absolute() or path.is_relative_to('/')
exists = path.exists()
if not (is_path and exists):
raise ValueError(
'please specify valid ASCEND_QUANT_RECORD_FILE')
if not AscendKVQuantMeta.has_set_value:
total_layers = len(step_context.kv_caches)
AscendKVQuantMeta.set_value(step_context.block_offsets.device,
step_context.model_config.dtype,
record_file, total_layers)

attn_meta_cls = cls.get_attention_metadata_cls()
attn_metadata = attn_meta_cls(
step_context.is_decoding,
@@ -177,6 +261,8 @@ def get_total_slots():
is_unpaged_prefill=is_unpaged_prefill,
max_q_seq_len=max_q_seq_len,
max_kv_seq_len=max_kv_seq_len,
quant_policy=step_context.kv_quant_policy,
quant_meta=AscendKVQuantMeta.quant_meta,
)

step_context.attn_metadata = attn_metadata
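For clarity on what `AscendKVQuantMeta.set_value` expects: the record file is plain text with `scale: ... offset: ...` entries, either one pair per layer or, with `2 * total_layers` entries, alternating K/V pairs per layer. A small standalone sketch of the parsing step, reusing the regex above with made-up values:

```python
import re

# Hypothetical contents of ASCEND_QUANT_RECORD_FILE for a 2-layer model,
# one (scale, offset) pair per layer; with 2 * num_layers pairs the entries
# are interpreted as alternating K and V pairs per layer instead.
record_text = """\
scale: 0.0123 offset: -5
scale: 0.0456 offset: 3
"""

pairs = re.findall(r'scale:\s*([\d\.\-]+)\s*offset:\s*(-?\d+)', record_text)
pairs = [(float(scale), float(offset)) for scale, offset in pairs]
print(pairs)  # [(0.0123, -5.0), (0.0456, 3.0)]
```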
37 changes: 34 additions & 3 deletions lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from typing import Optional, Sequence
from typing import Dict, Optional, Sequence

from torch import Tensor

@@ -15,6 +15,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
is_unpaged_prefill: Optional[bool] = None
max_q_seq_len: int = 1
max_kv_seq_len: int = 1
quant_meta: Dict = None


class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
@@ -74,10 +75,37 @@ def forward(
is_unpaged_prefill = attn_metadata.is_unpaged_prefill
max_q_seq_len = attn_metadata.max_q_seq_len
max_kv_seq_len = attn_metadata.max_kv_seq_len
quant_bits = attn_metadata.quant_policy
if attn_metadata.quant_meta is not None:
k_scales_zeros = [
next(attn_metadata.quant_meta['k_scales']),
next(attn_metadata.quant_meta['k_zeros'])
] if 'k_scales' in attn_metadata.quant_meta else []
v_scales_zeros = [
next(attn_metadata.quant_meta['v_scales']),
next(attn_metadata.quant_meta['v_zeros'])
] if 'v_scales' in attn_metadata.quant_meta else []
kv_scales = next(
attn_metadata.quant_meta['kv_scales']
) if 'kv_scales' in attn_metadata.quant_meta else None
kv_zeros = next(
attn_metadata.quant_meta['kv_zeros']
) if 'kv_zeros' in attn_metadata.quant_meta else None
else:
k_scales_zeros = []
v_scales_zeros = []
kv_scales = None
kv_zeros = None

# fill kv cache
k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache,
kv_start_indices)
k_cache, v_cache = self.fill_kv_cache(key,
value,
k_cache,
v_cache,
kv_start_indices,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_bits=quant_bits)

if inplace:
attn_output = query[..., :self.v_head_size]
@@ -103,6 +131,9 @@ def forward(
block_size=block_size,
attn_mask=attn_mask,
is_unpaged_prefill=is_unpaged_prefill,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
quant_bits=quant_bits,
)

return attn_output
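One detail worth noting about the metadata plumbing above: the scales/zeros are stored as `itertools.cycle` iterators, and each layer's `forward` consumes one entry via `next()`, so the record file must list layers in order and the iterators simply wrap around on the next forward pass. A tiny illustration with plain floats standing in for the tensors:

```python
import itertools

# Stand-ins for the per-layer k_scales built in AscendKVQuantMeta.set_value.
k_scales = itertools.cycle([0.0123, 0.0456, 0.0789])  # 3-layer model

# First forward pass: layers 0..2 each consume one value, in layer order.
first_pass = [next(k_scales) for _ in range(3)]
# Second forward pass: the cycle wraps around and yields the same sequence.
second_pass = [next(k_scales) for _ in range(3)]

print(first_pass)   # [0.0123, 0.0456, 0.0789]
print(second_pass)  # [0.0123, 0.0456, 0.0789]
```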
1 change: 1 addition & 0 deletions lmdeploy/pytorch/config.py
@@ -77,6 +77,7 @@ class CacheConfig:
max_prefill_token_num: int = 4096
enable_prefix_caching: bool = False
quant_policy: Literal[0, 4, 8] = 0
device_type: str = 'cuda'

def __post_init__(self):
"""post init."""
8 changes: 7 additions & 1 deletion lmdeploy/pytorch/engine/cache_engine.py
@@ -44,7 +44,13 @@ def __init__(
self.num_layers = model_config.num_layers
self.kv_cache_dtype = model_config.dtype
if cache_config.quant_policy > 0:
self.kv_cache_dtype = torch.uint8
if self.cache_config.device_type in ['cuda']:
self.kv_cache_dtype = torch.uint8
elif self.cache_config.device_type in ['ascend', 'npu']:
self.kv_cache_dtype = torch.int8
else:
raise ValueError(
f'unsupported device_type {self.cache_config.device_type}')

# Initialize the cache.
self.local_gpu_cache = self.allocate_gpu_cache()
1 change: 1 addition & 0 deletions lmdeploy/pytorch/engine/engine.py
@@ -130,6 +130,7 @@ def __init__(self,
max_prefill_token_num=engine_config.max_prefill_token_num,
enable_prefix_caching=engine_config.enable_prefix_caching,
quant_policy=engine_config.quant_policy,
device_type=engine_config.device_type,
)

if not os.path.exists(model_path):
1 change: 1 addition & 0 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -139,6 +139,7 @@ def model_forward(
ctx_mgr = model.ctx_mgr
context = ctx_mgr.build_context(
inputs=inputs,
model_config=cache_engine.model_config,
world_size=world_size,
kv_caches=cache_engine.gpu_cache,
kv_quant_policy=cache_engine.cache_config.quant_policy,
15 changes: 13 additions & 2 deletions lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence

import dlinfer.ops as ext_ops
from torch import Tensor

@@ -9,7 +11,16 @@ def fill_kv_cache(
key_caches: Tensor,
value_caches: Tensor,
kv_start_indices: Tensor,
k_scales_zeros: Sequence[Optional[Tensor]],
v_scales_zeros: Sequence[Optional[Tensor]],
quant_bits: int = 0,
):
"""fill key/value state to cache for paged attention."""
return ext_ops.fill_kv_cache(key_states, value_states, key_caches,
value_caches, kv_start_indices)
return ext_ops.fill_kv_cache(key_states,
value_states,
key_caches,
value_caches,
kv_start_indices,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_bits=quant_bits)
33 changes: 31 additions & 2 deletions lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
@@ -19,6 +19,9 @@ def prefill_attention(
block_size: int,
attn_mask: Sequence[Optional[Tensor]],
is_unpaged_prefill: Optional[bool],
kv_scales: Optional[Tensor],
kv_zeros: Optional[Tensor],
quant_bits: Optional[int],
) -> Tensor:
num_q_heads = query_states.shape[1]
num_kv_heads = value_states.shape[1]
@@ -53,11 +56,25 @@
num_kv_heads,
attn_mask,
attn_output=attn_output,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
quant_bits=quant_bits,
)


def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
max_kv_seq_len, block_offsets, block_size):
def paged_token_attention(
q,
k_cache,
v_cache,
attn_output,
kv_seq_len,
max_kv_seq_len,
block_offsets,
block_size,
kv_scales: Optional[Tensor],
kv_zeros: Optional[Tensor],
quant_bits: Optional[int],
):
num_q_heads, q_head_dim = q.shape[1:3]
num_kv_heads = k_cache.shape[-1] // q_head_dim
return ext_ops.paged_decode_attention(
@@ -71,6 +88,9 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
num_q_heads,
num_kv_heads,
attn_output=attn_output,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
quant_bits=quant_bits,
)


@@ -91,6 +111,9 @@ def paged_attention_fwd(
block_size: int,
attn_mask: Sequence[Optional[Tensor]] = (),
is_unpaged_prefill: Optional[bool] = None,
kv_scales: Optional[Tensor] = None,
kv_zeros: Optional[Tensor] = None,
quant_bits: Optional[int] = 0,
):
if not is_decoding:
return prefill_attention(
@@ -108,6 +131,9 @@
block_size,
attn_mask,
is_unpaged_prefill,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
quant_bits=quant_bits,
)
else:
return paged_token_attention(
@@ -119,4 +145,7 @@
max_kv_seq_len,
block_offsets,
block_size,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
quant_bits=quant_bits,
)