support new backend cambricon (#3002)
* [dlinfer] add camb support

* [camb] fix multiple-of-8 issue where exp raised a core dump

* [camb] format code

* [camb] prefer power-of-2 sizes

* [camb] remove local_adapterids

* [camb] modify graph runner

* [camb] mock graph runner

* [camb] add requirements.txt

* [camb] post_init: set block_size to 16

* lint
JackWeiw authored Jan 13, 2025
1 parent 39af9c8 commit 5820107
Showing 11 changed files with 182 additions and 4 deletions.
2 changes: 1 addition & 1 deletion lmdeploy/cli/utils.py
@@ -377,7 +377,7 @@ def calib_search_scale(parser):
@staticmethod
def device(parser,
default: str = 'cuda',
choices: List[str] = ['cuda', 'ascend', 'maca']):
choices: List[str] = ['cuda', 'ascend', 'maca', 'camb']):
"""Add argument device to parser."""

return parser.add_argument('--device',
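With 'camb' now accepted by the --device flag, a Cambricon device can also be selected from the Python API. A minimal sketch, assuming the standard lmdeploy pipeline entry point; the model id below is a placeholder:

    from lmdeploy import pipeline, PytorchEngineConfig

    # 'some-org/some-model' is a placeholder model id, not part of this commit.
    pipe = pipeline('some-org/some-model',
                    backend_config=PytorchEngineConfig(device_type='camb'))
    print(pipe(['Hello from Cambricon']))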
10 changes: 9 additions & 1 deletion lmdeploy/messages.py
@@ -7,6 +7,9 @@
from pydantic.dataclasses import dataclass as pydantic_dataclass

from .tokenizer import Tokenizer
from .utils import get_logger

logger = get_logger('lmdeploy')

LogitsProcessor = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
"""LogitsProcessor is a function that takes a tensor of input_ids, the logits
@@ -297,13 +300,18 @@ def __post_init__(self):
assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'
assert self.device_type in [
'cuda', 'ascend', 'maca'
'cuda', 'ascend', 'maca', 'camb'
], (f'invalid device_type: {self.device_type}')
if self.quant_policy > 0 and self.device_type not in [
'cuda', 'ascend'
]:
assert False, \
'kv cache quantization only works for CUDA and ASCEND.'
if self.device_type == 'camb' and self.block_size != 16:
self.block_size = 16
logger.warning('Currently, camb device requires block size to be 16, '
'setting block size to 16')


class ResponseType(enum.Enum):
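For reference, a small sketch of how the new __post_init__ check behaves. It assumes the dataclass in this hunk is PytorchEngineConfig and that its default block_size is not already 16; both are assumptions, not shown in the diff:

    from lmdeploy.messages import PytorchEngineConfig

    cfg = PytorchEngineConfig(device_type='camb', block_size=64)
    # __post_init__ coerces the block size on camb and logs a warning.
    assert cfg.block_size == 16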
1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/dlinfer/__init__.py
@@ -1,3 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ascend import AscendOpsBackend # noqa: F401
from .camb import CambOpsBackend # noqa: F401
from .maca import MacaOpsBackend # noqa: F401
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/camb/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .op_backend import CambOpsBackend # noqa: F401
132 changes: 132 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py
@@ -0,0 +1,132 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.utils import get_logger

from ..op_backend import DlinferOpsBackend

logger = get_logger('lmdeploy')


class CambOpsBackend(DlinferOpsBackend):
"""camb layer backend."""
total_slots = None

@staticmethod
def get_name() -> str:
"""backend name."""
return 'camb'

@staticmethod
def get_k_block_shape(
block_size: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
) -> Tuple[int, ...]:
return (
num_heads,
block_size,
head_size,
)

@staticmethod
def get_v_block_shape(
block_size: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
) -> Tuple[int, ...]:
return (
num_heads,
block_size,
head_size,
)

@classmethod
def update_step_context(cls, step_context):
"""update step context."""

def get_total_slots():
if cls.total_slots is None:
cls.total_slots = torch.arange(
block_num * block_size,
dtype=torch.int32,
device=step_context.block_offsets.device)
cls.total_slots = cls.total_slots.view(block_num, block_size)
return cls.total_slots

kv_start_indices = []
block_num, _, block_size, _ = step_context.kv_caches[0][0].shape

is_unpaged_prefill = False
q_start_loc = step_context.q_start_loc
q_seqlens = step_context.q_seqlens
kv_seqlens = step_context.kv_seqlens.to(torch.int32)
block_offsets = step_context.block_offsets.to(torch.int32)
max_q_seq_len = torch.max(q_seqlens).cpu().item()
max_kv_seq_len = torch.max(kv_seqlens).cpu().item()

cu_seqlens = torch.cat(
(q_start_loc, q_seqlens.sum().unsqueeze(0))).int()
cu_seq_lens_kv = None

q_seqlens_list = step_context.q_seqlens.tolist()
kv_seqlens_list = step_context.kv_seqlens.tolist()
if not step_context.is_decoding:
is_unpaged_prefill = q_seqlens_list == kv_seqlens_list
# get kv_indices
for i in range(q_start_loc.size(0)):
q_seq_len = q_seqlens_list[i]
kv_seq_len = kv_seqlens_list[i]
# collect kv start indices.
history_length = kv_seq_len - q_seq_len
total_slots = get_total_slots()
slot_tables = total_slots[block_offsets[i]].view(-1)
slots = slot_tables[history_length:kv_seq_len]
kv_start_indices.append(slots)
kv_start_indices = torch.cat(kv_start_indices)
if not is_unpaged_prefill:
cu_seq_lens_kv = torch.cat(
(torch.tensor([0], device=kv_seqlens.device),
kv_seqlens.cumsum(0))).int()
else:
# collect kv_start_indices without using a for-loop,
# (fill kv-cache for just ONE token during the decoding phase)
idx = (step_context.kv_seqlens - 1) % block_size
block_num = (step_context.kv_seqlens - 1) // block_size
last_block = block_offsets.gather( # dtype of gather must be int64
1, block_num.view(-1, 1)).view(-1)
kv_start_indices = (last_block * block_size + idx).to(torch.int32)

attn_meta_cls = cls.get_attention_metadata_cls()
attn_metadata = attn_meta_cls(
step_context.is_decoding,
block_offsets,
q_start_loc=cu_seqlens,
cu_seq_lens_kv=cu_seq_lens_kv,
q_seqlens=q_seqlens,
kv_seqlens=kv_seqlens,
kv_start_indices=kv_start_indices,
block_size=block_size,
attention_mask=None,
is_unpaged_prefill=is_unpaged_prefill,
max_q_seq_len=max_q_seq_len,
max_kv_seq_len=max_kv_seq_len,
)

step_context.attn_metadata = attn_metadata
return step_context

@staticmethod
def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig,
cache_config: CacheConfig,
backend_config: BackendConfig,
device: torch.device):
"""build graph runner."""
from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner
return CUDAGraphRunner(model, model_config, cache_config,
backend_config, device)
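A standalone sketch of the kv_start_indices arithmetic from the decoding branch of update_step_context above, using toy CPU tensors; shapes and values are illustrative only:

    import torch

    block_size = 16
    # block table of one sequence: logical block i -> physical block block_offsets[0, i]
    block_offsets = torch.tensor([[3, 7, 2, 0]], dtype=torch.int64)
    kv_seqlens = torch.tensor([21])  # 20 cached tokens + the token decoded this step

    idx = (kv_seqlens - 1) % block_size               # offset inside the last block -> 4
    blk = (kv_seqlens - 1) // block_size              # logical index of the last block -> 1
    last_block = block_offsets.gather(1, blk.view(-1, 1)).view(-1)  # physical block -> 7
    kv_start_indices = (last_block * block_size + idx).to(torch.int32)
    print(kv_start_indices)  # tensor([116], dtype=torch.int32), i.e. 7 * 16 + 4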
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/backends/selector.py
@@ -18,5 +18,8 @@ def get_backend():
if device_type == 'maca':
from .dlinfer import MacaOpsBackend
return MacaOpsBackend
if device_type == 'camb':
from .dlinfer import CambOpsBackend
return CambOpsBackend
else:
raise RuntimeError(f'Unsupported device type: {device_type}')
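As a quick sanity check of the new dispatch branch (how device_type is resolved lies outside this hunk and is not assumed here):

    from lmdeploy.pytorch.backends.dlinfer import CambOpsBackend

    # The backend returned for the 'camb' device reports its name accordingly.
    assert CambOpsBackend.get_name() == 'camb'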
1 change: 1 addition & 0 deletions lmdeploy/pytorch/check_env/deeplink.py
@@ -5,6 +5,7 @@
'ascend',
'npu',
'maca',
'camb',
]


4 changes: 3 additions & 1 deletion lmdeploy/pytorch/models/module_map.py
@@ -6,9 +6,11 @@
MODULE_MAP = dict()
ASCEND_MODULE_MAP = dict()
MACA_MODULE_MAP = dict()
CAMB_MODULE_MAP = dict()

DEVICE_SPECIAL_MODULE_MAP = dict(ascend=ASCEND_MODULE_MAP,
maca=MACA_MODULE_MAP)
maca=MACA_MODULE_MAP,
camb=CAMB_MODULE_MAP)

# llama
MODULE_MAP.update({
6 changes: 5 additions & 1 deletion lmdeploy/utils.py
@@ -332,7 +332,7 @@ def get_max_batch_size(device_type: str):
Args:
device_type (str): the type of device
"""
assert device_type in ['cuda', 'ascend', 'maca']
assert device_type in ['cuda', 'ascend', 'maca', 'camb']
if device_type == 'cuda':
max_batch_size_map = {
'a100': 256,
@@ -352,6 +352,8 @@ def get_max_batch_size(device_type: str):
return 16
elif device_type == 'maca':
return 128
elif device_type == 'camb':
return 128


def is_bf16_supported(device_type: str = 'cuda'):
@@ -387,5 +389,7 @@ def is_bf16_supported(device_type: str = 'cuda'):
# return False
elif device_type == 'maca':
return True
elif device_type == 'camb':
return True
else:
return False
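The two helpers above can be exercised directly; a minimal check mirroring the values added in this diff:

    from lmdeploy.utils import get_max_batch_size, is_bf16_supported

    assert get_max_batch_size('camb') == 128   # default max batch size for camb
    assert is_bf16_supported('camb') is True   # bf16 reported as supported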
21 changes: 21 additions & 0 deletions requirements/runtime_camb.txt
@@ -0,0 +1,21 @@
accelerate==1.2.0
einops
fastapi
fire
mmengine-lite
numpy<2.0.0
openai
outlines<0.1.0
peft<=0.11.1
pillow
protobuf
pydantic>2.0.0
pynvml
safetensors
sentencepiece
shortuuid
tiktoken
torch==2.4.0
torchvision<=0.19.0,>=0.15.0
transformers
uvicorn
4 changes: 4 additions & 0 deletions requirements_camb.txt
@@ -0,0 +1,4 @@
-r requirements/build.txt
-r requirements/runtime_camb.txt
-r requirements/lite.txt
-r requirements/serve.txt
