diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py
index 4edf23d684..25fcdd4620 100644
--- a/lmdeploy/cli/utils.py
+++ b/lmdeploy/cli/utils.py
@@ -377,7 +377,7 @@ def calib_search_scale(parser):
     @staticmethod
     def device(parser,
                default: str = 'cuda',
-               choices: List[str] = ['cuda', 'ascend', 'maca']):
+               choices: List[str] = ['cuda', 'ascend', 'maca', 'camb']):
         """Add argument device to parser."""
 
         return parser.add_argument('--device',
diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index 11626f44a2..cfc146f86d 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -7,6 +7,9 @@
 from pydantic.dataclasses import dataclass as pydantic_dataclass
 
 from .tokenizer import Tokenizer
+from .utils import get_logger
+
+logger = get_logger('lmdeploy')
 
 LogitsProcessor = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
 """LogitsProcessor is a function that takes a tensor of input_ids, the logits
@@ -297,13 +300,18 @@ def __post_init__(self):
         assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
         assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'
         assert self.device_type in [
-            'cuda', 'ascend', 'maca'
+            'cuda', 'ascend', 'maca', 'camb'
         ], (f'invalid device_type: {self.device_type}')
         if self.quant_policy > 0 and self.device_type not in [
                 'cuda', 'ascend'
         ]:
             assert False, \
                 'kv cache quantization only works for CUDA and ASCEND.'
+        if self.device_type == 'camb' and self.block_size != 16:
+            self.block_size = 16
+            logger.warning(
+                'Currently, camb device requires block size to be 16, '
+                'setting block size to 16')
 
 
 class ResponseType(enum.Enum):
diff --git a/lmdeploy/pytorch/backends/dlinfer/__init__.py b/lmdeploy/pytorch/backends/dlinfer/__init__.py
index af3ccff085..1cf6eea440 100644
--- a/lmdeploy/pytorch/backends/dlinfer/__init__.py
+++ b/lmdeploy/pytorch/backends/dlinfer/__init__.py
@@ -1,3 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .ascend import AscendOpsBackend  # noqa: F401
+from .camb import CambOpsBackend  # noqa: F401
 from .maca import MacaOpsBackend  # noqa: F401
diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/__init__.py b/lmdeploy/pytorch/backends/dlinfer/camb/__init__.py
new file mode 100644
index 0000000000..897495c209
--- /dev/null
+++ b/lmdeploy/pytorch/backends/dlinfer/camb/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .op_backend import CambOpsBackend  # noqa: F401
diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py
new file mode 100644
index 0000000000..89c71f46fb
--- /dev/null
+++ b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py
@@ -0,0 +1,132 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+
+from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
+from lmdeploy.utils import get_logger
+
+from ..op_backend import DlinferOpsBackend
+
+logger = get_logger('lmdeploy')
+
+
+class CambOpsBackend(DlinferOpsBackend):
+    """camb layer backend."""
+    total_slots = None
+
+    @staticmethod
+    def get_name() -> str:
+        """backend name."""
+        return 'camb'
+
+    @staticmethod
+    def get_k_block_shape(
+        block_size: int,
+        num_heads: int,
+        head_size: int,
+        dtype: torch.dtype,
+    ) -> Tuple[int, ...]:
+        return (
+            num_heads,
+            block_size,
+            head_size,
+        )
+
+    @staticmethod
+    def get_v_block_shape(
+        block_size: int,
+        num_heads: int,
+        head_size: int,
+        dtype: torch.dtype,
+    ) -> Tuple[int, ...]:
+        return (
+            num_heads,
+            block_size,
+            head_size,
+        )
+
+    @classmethod
+    def update_step_context(cls, step_context):
+        """update step context."""
+
+        def get_total_slots():
+            if cls.total_slots is None:
+                cls.total_slots = torch.arange(
+                    block_num * block_size,
+                    dtype=torch.int32,
+                    device=step_context.block_offsets.device)
+                cls.total_slots = cls.total_slots.view(block_num, block_size)
+            return cls.total_slots
+
+        kv_start_indices = []
+        block_num, _, block_size, _ = step_context.kv_caches[0][0].shape
+
+        is_unpaged_prefill = False
+        q_start_loc = step_context.q_start_loc
+        q_seqlens = step_context.q_seqlens
+        kv_seqlens = step_context.kv_seqlens.to(torch.int32)
+        block_offsets = step_context.block_offsets.to(torch.int32)
+        max_q_seq_len = torch.max(q_seqlens).cpu().item()
+        max_kv_seq_len = torch.max(kv_seqlens).cpu().item()
+
+        cu_seqlens = torch.cat(
+            (q_start_loc, q_seqlens.sum().unsqueeze(0))).int()
+        cu_seq_lens_kv = None
+
+        q_seqlens_list = step_context.q_seqlens.tolist()
+        kv_seqlens_list = step_context.kv_seqlens.tolist()
+        if not step_context.is_decoding:
+            is_unpaged_prefill = q_seqlens_list == kv_seqlens_list
+            # get kv_indices
+            for i in range(q_start_loc.size(0)):
+                q_seq_len = q_seqlens_list[i]
+                kv_seq_len = kv_seqlens_list[i]
+                # collect kv start indices.
+                history_length = kv_seq_len - q_seq_len
+                total_slots = get_total_slots()
+                slot_tables = total_slots[block_offsets[i]].view(-1)
+                slots = slot_tables[history_length:kv_seq_len]
+                kv_start_indices.append(slots)
+            kv_start_indices = torch.cat(kv_start_indices)
+            if not is_unpaged_prefill:
+                cu_seq_lens_kv = torch.cat(
+                    (torch.tensor([0], device=kv_seqlens.device),
+                     kv_seqlens.cumsum(0))).int()
+        else:
+            # collect kv_start_indices without using a for-loop,
+            # (fill kv-cache for just ONE token during the decoding phase)
+            idx = (step_context.kv_seqlens - 1) % block_size
+            block_num = (step_context.kv_seqlens - 1) // block_size
+            last_block = block_offsets.gather(  # dtype of gather must be int64
+                1, block_num.view(-1, 1)).view(-1)
+            kv_start_indices = (last_block * block_size + idx).to(torch.int32)
+
+        attn_meta_cls = cls.get_attention_metadata_cls()
+        attn_metadata = attn_meta_cls(
+            step_context.is_decoding,
+            block_offsets,
+            q_start_loc=cu_seqlens,
+            cu_seq_lens_kv=cu_seq_lens_kv,
+            q_seqlens=q_seqlens,
+            kv_seqlens=kv_seqlens,
+            kv_start_indices=kv_start_indices,
+            block_size=block_size,
+            attention_mask=None,
+            is_unpaged_prefill=is_unpaged_prefill,
+            max_q_seq_len=max_q_seq_len,
+            max_kv_seq_len=max_kv_seq_len,
+        )
+
+        step_context.attn_metadata = attn_metadata
+        return step_context
+
+    @staticmethod
+    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig,
+                           cache_config: CacheConfig,
+                           backend_config: BackendConfig,
+                           device: torch.device):
+        """build graph runner."""
+        from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner
+        return CUDAGraphRunner(model, model_config, cache_config,
+                               backend_config, device)
diff --git a/lmdeploy/pytorch/backends/selector.py b/lmdeploy/pytorch/backends/selector.py
index 987730a981..4db73fa370 100644
--- a/lmdeploy/pytorch/backends/selector.py
+++ b/lmdeploy/pytorch/backends/selector.py
@@ -18,5 +18,8 @@ def get_backend():
     if device_type == 'maca':
         from .dlinfer import MacaOpsBackend
         return MacaOpsBackend
+    if device_type == 'camb':
+        from .dlinfer import CambOpsBackend
+        return CambOpsBackend
     else:
         raise RuntimeError(f'Unsupported device type: {device_type}')
diff --git a/lmdeploy/pytorch/check_env/deeplink.py b/lmdeploy/pytorch/check_env/deeplink.py
index 74ab5a7b87..00bcfdf77c 100644
--- a/lmdeploy/pytorch/check_env/deeplink.py
+++ b/lmdeploy/pytorch/check_env/deeplink.py
@@ -5,6 +5,7 @@
     'ascend',
     'npu',
     'maca',
+    'camb',
 ]
 
 
diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py
index c1b62736f7..c01a166b94 100644
--- a/lmdeploy/pytorch/models/module_map.py
+++ b/lmdeploy/pytorch/models/module_map.py
@@ -6,9 +6,11 @@
 MODULE_MAP = dict()
 ASCEND_MODULE_MAP = dict()
 MACA_MODULE_MAP = dict()
+CAMB_MODULE_MAP = dict()
 
 DEVICE_SPECIAL_MODULE_MAP = dict(ascend=ASCEND_MODULE_MAP,
-                                 maca=MACA_MODULE_MAP)
+                                 maca=MACA_MODULE_MAP,
+                                 camb=CAMB_MODULE_MAP)
 
 # llama
 MODULE_MAP.update({
diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py
index fbdd374f80..e9ef0ba2bb 100644
--- a/lmdeploy/utils.py
+++ b/lmdeploy/utils.py
@@ -332,7 +332,7 @@ def get_max_batch_size(device_type: str):
     Args:
         device_type (str): the type of device
     """
-    assert device_type in ['cuda', 'ascend', 'maca']
+    assert device_type in ['cuda', 'ascend', 'maca', 'camb']
     if device_type == 'cuda':
         max_batch_size_map = {
             'a100': 256,
@@ -352,6 +352,8 @@ def get_max_batch_size(device_type: str):
         return 16
     elif device_type == 'maca':
         return 128
+    elif device_type == 'camb':
+        return 128
 
 
 def is_bf16_supported(device_type: str = 'cuda'):
@@ -387,5 +389,7 @@ def is_bf16_supported(device_type: str = 'cuda'):
 #         return False
     elif device_type == 'maca':
         return True
+    elif device_type == 'camb':
+        return True
     else:
         return False
diff --git a/requirements/runtime_camb.txt b/requirements/runtime_camb.txt
new file mode 100644
index 0000000000..e56d0cb494
--- /dev/null
+++ b/requirements/runtime_camb.txt
@@ -0,0 +1,21 @@
+accelerate==1.2.0
+einops
+fastapi
+fire
+mmengine-lite
+numpy<2.0.0
+openai
+outlines<0.1.0
+peft<=0.11.1
+pillow
+protobuf
+pydantic>2.0.0
+pynvml
+safetensors
+sentencepiece
+shortuuid
+tiktoken
+torch==2.4.0
+torchvision<=0.19.0,>=0.15.0
+transformers
+uvicorn
diff --git a/requirements_camb.txt b/requirements_camb.txt
new file mode 100644
index 0000000000..24b1f3e796
--- /dev/null
+++ b/requirements_camb.txt
@@ -0,0 +1,4 @@
+-r requirements/build.txt
+-r requirements/runtime_camb.txt
+-r requirements/lite.txt
+-r requirements/serve.txt
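
A minimal usage sketch, not part of the patch above: it assumes the existing lmdeploy `pipeline` and `PytorchEngineConfig` APIs and uses a placeholder model path, and shows how the new 'camb' device type is selected once this change is in place. On the command line the equivalent is the `--device camb` choice added in lmdeploy/cli/utils.py.

# Usage sketch (assumptions: standard lmdeploy pipeline API; the model path is
# a placeholder, replace it with any model supported by the PyTorch engine).
from lmdeploy import PytorchEngineConfig, pipeline

if __name__ == '__main__':
    # 'camb' is now a valid device_type; a block_size other than 16 is coerced
    # to 16 with a warning (see the __post_init__ change in lmdeploy/messages.py).
    backend_config = PytorchEngineConfig(device_type='camb')
    pipe = pipeline('internlm/internlm2_5-7b-chat',
                    backend_config=backend_config)
    print(pipe(['Hi, please introduce yourself.']))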