support new backend cambricon #3002

Merged (12 commits) on Jan 13, 2025
2 changes: 1 addition & 1 deletion lmdeploy/cli/utils.py
@@ -377,7 +377,7 @@ def calib_search_scale(parser):
@staticmethod
def device(parser,
default: str = 'cuda',
choices: List[str] = ['cuda', 'ascend', 'maca']):
choices: List[str] = ['cuda', 'ascend', 'maca', 'camb']):
"""Add argument device to parser."""

return parser.add_argument('--device',
2 changes: 1 addition & 1 deletion lmdeploy/messages.py
@@ -291,7 +291,7 @@ def __post_init__(self):
assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'
assert self.device_type in [
'cuda', 'ascend', 'maca'
'cuda', 'ascend', 'maca', 'camb'
], (f'invalid device_type: {self.device_type}')
if self.quant_policy > 0 and self.device_type not in [
'cuda', 'ascend'
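A minimal usage sketch of selecting the new device from the user-facing API, assuming the public lmdeploy pipeline interface and an illustrative model path; device_type='camb' is the value these option lists now accept.

# Hedged sketch: route inference through the Cambricon backend via the
# PyTorch engine config. The model path below is illustrative only.
from lmdeploy import pipeline, PytorchEngineConfig

pipe = pipeline('internlm/internlm2_5-7b-chat',
                backend_config=PytorchEngineConfig(device_type='camb'))
print(pipe(['Hello from Cambricon']))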
1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/dlinfer/__init__.py
@@ -1,3 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ascend import AscendOpsBackend # noqa: F401
from .camb import CambOpsBackend # noqa: F401
from .maca import MacaOpsBackend # noqa: F401
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/camb/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .op_backend import CambOpsBackend # noqa: F401
171 changes: 171 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py
@@ -0,0 +1,171 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Tuple

import torch
from torch import Tensor

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.model_inputs import StepContext
from lmdeploy.pytorch.models.utils.cudagraph import (CudaGraphMeta,
next_power_of_2)
from lmdeploy.utils import get_logger

from ...graph_runner import GraphRunner

logger = get_logger('lmdeploy')

BuffType = Dict[str, Tensor]


def _false(*args, **kwargs):
"""default value of not support cuda graph."""
return False


class CAMBSingleGraphRunner:
"""camb single graph runner."""

def __init__(
self,
model: torch.nn.Module,
max_batches: int,
max_tokens: int,
num_blocks: int,
pool: Tuple[int, int],
device: torch.device,
):
self.model = model
self.ctx_mgr = model.ctx_mgr
self.meta = CudaGraphMeta(
max_batchs=max_batches,
max_tokens=max_tokens,
num_blocks=num_blocks,
is_decoding=True,
device=device,
input_buffers=dict(),
output_buffers=dict(),
)
self.device = device
self.max_batches = max_batches
self.max_tokens = max_tokens
self.num_blocks = num_blocks
self.pool = pool
self._graph: torch.mlu.CUDAGraph = None

def capture(self, **kwargs):
"""capture graph."""
self.meta.input_buffers = self.model.make_buffers_cudagraph(
self.meta, **kwargs)
padded_kwargs = self.model.fill_buffers_cudagraph(self.meta, **kwargs)

context = self.ctx_mgr.current_context()
self.model.update_context_cudagraph(self.meta, context)
current_stream = torch.mlu.current_stream()

# warmup
self.model(**padded_kwargs)

self._graph = torch.mlu.CUDAGraph()
        # An unsafe kernel call in another thread might invalidate the capture,
        # so capture_error_mode='thread_local' is set here.
with torch.mlu.graph(self._graph,
pool=self.pool,
stream=current_stream,
capture_error_mode='thread_local'):
output = self.model(**padded_kwargs)

output_buffers = dict(logits=output)
self.meta.output_buffers = output_buffers
return output

def forward(self, **kwargs):
"""forward."""
num_tokens = kwargs['input_ids'].size(-1)
assert self._graph is not None
self.model.fill_buffers_cudagraph(self.meta, **kwargs)
context = self.ctx_mgr.current_context()
self.model.update_context_cudagraph(self.meta, context)
self._graph.replay()

output = self.meta.output_buffers['logits'][:, :num_tokens]
return output

def __del__(self):
"""del."""
del self._graph


class CAMBGraphRunner(GraphRunner):
"""CAMB graph runner."""

def __init__(self, model: torch.nn.Module, model_config: ModelConfig,
cache_config: CacheConfig, backend_config: BackendConfig,
device: torch.device):
super().__init__(model, model_config, cache_config, backend_config,
device)
self.max_batches = cache_config.max_batches
self.max_tokens = cache_config.max_prefill_token_num
self.num_blocks = cache_config.num_gpu_blocks

self.enable_graph = self.check_enable_graph()

self.graph_pool_handle = torch.mlu.graph_pool_handle()
self._runner_map: Dict[Any, CAMBSingleGraphRunner] = dict()

def check_enable_graph(self):
"""check enable graph."""
if self.backend_config.eager_mode:
return _false

return getattr(self.model, 'support_cuda_graph', _false)

def get_graph_key(self, input_ids: torch.Tensor,
position_ids: torch.Tensor, past_key_values: List,
attn_metadata: Any, inputs_embeds: torch.Tensor,
**kwargs):
"""get graph key."""
context = self.ctx_mgr.current_context()
is_decoding = context.is_decoding
num_tokens = input_ids.numel()
new_num_tokens = next_power_of_2(num_tokens)
return (new_num_tokens, is_decoding)

def __call__(self, **kwargs):
"""call."""
enable_graph = self.enable_graph(**kwargs)
graph_key = self.get_graph_key(**kwargs)
max_tokens = graph_key[0]
is_decoding = graph_key[1]

# only enable graph when decoding
if (not enable_graph) or (not is_decoding):
return self.model(**kwargs)

if graph_key not in self._runner_map:
max_batches = max_tokens
runner = CAMBSingleGraphRunner(self.model,
max_batches=max_batches,
max_tokens=max_tokens,
num_blocks=self.num_blocks,
pool=self.graph_pool_handle,
device=self.device)
runner.capture(**kwargs)
self._runner_map[graph_key] = runner
else:
runner = self._runner_map[graph_key]

output = runner.forward(**kwargs)
return output

def prepare_inputs_for_generation(
self,
past_key_values: List[List[torch.Tensor]],
inputs_embeds: torch.Tensor = None,
context: StepContext = None,
):
"""prepare inputs."""
return self.model.prepare_inputs_for_generation(
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
context=context,
)
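The graph key above buckets decode batches with next_power_of_2(num_tokens), so a small number of captured graphs covers every batch size. A minimal sketch of that bucketing, using a local stand-in for the imported next_power_of_2 helper:

# Local stand-in for lmdeploy.pytorch.models.utils.cudagraph.next_power_of_2,
# shown only to illustrate how decode batch sizes map onto captured graphs.
def next_power_of_2(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

for num_tokens in (1, 3, 9, 100):
    print(num_tokens, '->', next_power_of_2(num_tokens))
# 1 -> 1, 3 -> 4, 9 -> 16, 100 -> 128: each bucket replays one captured graph.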
132 changes: 132 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py
@@ -0,0 +1,132 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.utils import get_logger

from ..op_backend import DlinferOpsBackend

logger = get_logger('lmdeploy')


class CambOpsBackend(DlinferOpsBackend):
"""camb layer backend."""
total_slots = None

@staticmethod
def get_name() -> str:
"""backend name."""
return 'camb'

@staticmethod
def get_k_block_shape(
block_size: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
) -> Tuple[int, ...]:
return (
num_heads,
block_size,
head_size,
)

@staticmethod
def get_v_block_shape(
block_size: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
) -> Tuple[int, ...]:
return (
num_heads,
block_size,
head_size,
)

@classmethod
def update_step_context(cls, step_context):
"""update step context."""

def get_total_slots():
if cls.total_slots is None:
cls.total_slots = torch.arange(
block_num * block_size,
dtype=torch.int32,
device=step_context.block_offsets.device)
cls.total_slots = cls.total_slots.view(block_num, block_size)
return cls.total_slots

kv_start_indices = []
block_num, _, block_size, _ = step_context.kv_caches[0][0].shape

is_unpaged_prefill = False
q_start_loc = step_context.q_start_loc
q_seqlens = step_context.q_seqlens
kv_seqlens = step_context.kv_seqlens.to(torch.int32)
block_offsets = step_context.block_offsets.to(torch.int32)
max_q_seq_len = torch.max(q_seqlens).cpu().item()
max_kv_seq_len = torch.max(kv_seqlens).cpu().item()

cu_seqlens = torch.cat(
(q_start_loc, q_seqlens.sum().unsqueeze(0))).int()
cu_seq_lens_kv = None

q_seqlens_list = step_context.q_seqlens.tolist()
kv_seqlens_list = step_context.kv_seqlens.tolist()
if not step_context.is_decoding:
is_unpaged_prefill = q_seqlens_list == kv_seqlens_list
# get kv_indices
for i in range(q_start_loc.size(0)):
q_seq_len = q_seqlens_list[i]
kv_seq_len = kv_seqlens_list[i]
# collect kv start indices.
history_length = kv_seq_len - q_seq_len
total_slots = get_total_slots()
slot_tables = total_slots[block_offsets[i]].view(-1)
slots = slot_tables[history_length:kv_seq_len]
kv_start_indices.append(slots)
kv_start_indices = torch.cat(kv_start_indices)
if not is_unpaged_prefill:
cu_seq_lens_kv = torch.cat(
(torch.tensor([0], device=kv_seqlens.device),
kv_seqlens.cumsum(0))).int()
else:
# collect kv_start_indices without using a for-loop,
# (fill kv-cache for just ONE token during the decoding phase)
idx = (step_context.kv_seqlens - 1) % block_size
block_num = (step_context.kv_seqlens - 1) // block_size
last_block = block_offsets.gather( # dtype of gather must be int64
1, block_num.view(-1, 1)).view(-1)
kv_start_indices = (last_block * block_size + idx).to(torch.int32)

attn_meta_cls = cls.get_attention_metadata_cls()
attn_metadata = attn_meta_cls(
step_context.is_decoding,
block_offsets,
q_start_loc=cu_seqlens,
cu_seq_lens_kv=cu_seq_lens_kv,
q_seqlens=q_seqlens,
kv_seqlens=kv_seqlens,
kv_start_indices=kv_start_indices,
block_size=block_size,
attention_mask=None,
is_unpaged_prefill=is_unpaged_prefill,
max_q_seq_len=max_q_seq_len,
max_kv_seq_len=max_kv_seq_len,
)

step_context.attn_metadata = attn_metadata
return step_context

@staticmethod
def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig,
cache_config: CacheConfig,
backend_config: BackendConfig,
device: torch.device):
"""build graph runner."""
from .graph_runner import CAMBGraphRunner
return CAMBGraphRunner(model, model_config, cache_config,
backend_config, device)
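A self-contained sketch of the decoding-phase slot arithmetic in update_step_context above, with made-up block tables and kv lengths: the slot written for the single new token of each sequence is last_block * block_size + (kv_len - 1) % block_size.

# Hedged sketch with hypothetical inputs; mirrors the decoding branch above.
import torch

block_size = 4
kv_seqlens = torch.tensor([5, 10])           # current kv length per sequence
block_offsets = torch.tensor([[2, 7, 0],     # block table (block ids) per sequence
                              [1, 3, 6]])

idx = (kv_seqlens - 1) % block_size          # offset of the new token in its block
block_num = (kv_seqlens - 1) // block_size   # index of that block in the table
last_block = block_offsets.gather(1, block_num.view(-1, 1)).view(-1)
kv_start_indices = (last_block * block_size + idx).to(torch.int32)
print(kv_start_indices)                      # tensor([28, 25], dtype=torch.int32)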
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/backends/selector.py
@@ -18,5 +18,8 @@ def get_backend():
if device_type == 'maca':
from .dlinfer import MacaOpsBackend
return MacaOpsBackend
if device_type == 'camb':
from .dlinfer import CambOpsBackend
return CambOpsBackend
else:
raise RuntimeError(f'Unsupported device type: {device_type}')
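For illustration only, a stripped-down mirror of the dispatch added above; the real get_backend() takes no arguments and reads device_type from lmdeploy's device context, which is omitted here.

# Hypothetical, simplified version of the selector logic for the new backend.
def get_backend_for(device_type: str):
    if device_type == 'camb':
        from lmdeploy.pytorch.backends.dlinfer import CambOpsBackend
        return CambOpsBackend
    raise RuntimeError(f'Unsupported device type: {device_type}')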
1 change: 1 addition & 0 deletions lmdeploy/pytorch/check_env/deeplink.py
@@ -5,6 +5,7 @@
'ascend',
'npu',
'maca',
'camb',
]


4 changes: 3 additions & 1 deletion lmdeploy/pytorch/models/module_map.py
@@ -6,9 +6,11 @@
MODULE_MAP = dict()
ASCEND_MODULE_MAP = dict()
MACA_MODULE_MAP = dict()
CAMB_MODULE_MAP = dict()

DEVICE_SPECIAL_MODULE_MAP = dict(ascend=ASCEND_MODULE_MAP,
maca=MACA_MODULE_MAP)
maca=MACA_MODULE_MAP,
camb=CAMB_MODULE_MAP)

# llama
MODULE_MAP.update({
6 changes: 5 additions & 1 deletion lmdeploy/utils.py
@@ -332,7 +332,7 @@ def get_max_batch_size(device_type: str):
Args:
device_type (str): the type of device
"""
assert device_type in ['cuda', 'ascend', 'maca']
assert device_type in ['cuda', 'ascend', 'maca', 'camb']
if device_type == 'cuda':
max_batch_size_map = {
'a100': 256,
@@ -352,6 +352,8 @@ def get_max_batch_size(device_type: str):
return 16
elif device_type == 'maca':
return 128
elif device_type == 'camb':
return 128


def is_bf16_supported(device_type: str = 'cuda'):
@@ -387,5 +389,7 @@ def is_bf16_supported(device_type: str = 'cuda'):
# return False
elif device_type == 'maca':
return True
elif device_type == 'camb':
return True
else:
return False
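A quick, hedged check of the two helpers touched above; the expected values simply mirror the new 'camb' branches.

# Sketch: expected behaviour of the updated helpers for the new device type.
from lmdeploy.utils import get_max_batch_size, is_bf16_supported

assert get_max_batch_size('camb') == 128
assert is_bf16_supported('camb') is True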