1 change: 1 addition & 0 deletions README.md
@@ -130,6 +130,7 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
<li>Qwen2-MoE (57BA14B)</li>
<li>Qwen2.5 (0.5B - 32B)</li>
<li>Qwen3, Qwen3-MoE</li>
<li>Qwen3-Next(80B)</li>
<li>Baichuan (7B)</li>
<li>Baichuan2 (7B-13B)</li>
<li>Code Llama (7B - 34B)</li>
1 change: 1 addition & 0 deletions README_ja.md
@@ -117,6 +117,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
<li>Qwen2-MoE (57BA14B)</li>
<li>Qwen2.5 (0.5B - 32B)</li>
<li>Qwen3, Qwen3-MoE</li>
<li>Qwen3-Next(80B)</li>
<li>Baichuan (7B)</li>
<li>Baichuan2 (7B-13B)</li>
<li>Code Llama (7B - 34B)</li>
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -131,6 +131,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
<li>Qwen2-MoE (57BA14B)</li>
<li>Qwen2.5 (0.5B - 32B)</li>
<li>Qwen3, Qwen3-MoE</li>
<li>Qwen3-Next(80B)</li>
<li>Baichuan (7B)</li>
<li>Baichuan2 (7B-13B)</li>
<li>Code Llama (7B - 34B)</li>
1 change: 1 addition & 0 deletions docs/en/supported_models/supported_models.md
@@ -85,6 +85,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes\* |
| QWen3-Next | 80B | LLM | Yes | No | No | No | No |
| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes |
| QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No |
| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
1 change: 1 addition & 0 deletions docs/zh_cn/supported_models/supported_models.md
@@ -85,6 +85,7 @@
| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes |
| QWen3-Next | 80B | LLM | Yes | No | No | No | No |
| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes |
| QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No |
| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
8 changes: 7 additions & 1 deletion lmdeploy/pytorch/check_env/model.py
@@ -57,7 +57,13 @@ def check_dtype(self, config):
if not is_bf16_supported(device_type):
logger.warning('Device does not support bfloat16.')
except Exception as e:
message = (f'Checking failed with error {e}', 'Please send issue to LMDeploy with error logs.')
message = (f'Checking failed with error {e}. Please send issue to LMDeploy with error logs.')
self.log_and_exit(e, 'Model', message=message)

try:
model_config.check_env_func(device_type)
except Exception as e:
message = (f'Checking failed with error {e}.')
self.log_and_exit(e, 'Model', message=message)

def check(self):
17 changes: 15 additions & 2 deletions lmdeploy/pytorch/config.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import enum
from dataclasses import dataclass
from typing import Any, Dict, List, Literal
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Literal, Tuple

import torch

@@ -86,6 +86,8 @@ class CacheConfig:
enable_prefix_caching: bool = False
quant_policy: Literal[0, 4, 8] = 0
device_type: str = 'cuda'
num_state_caches: int = None
states_shapes: List[Tuple] = field(default_factory=list)

# For PD Disaggregation
role: EngineRole = EngineRole.Hybrid
@@ -183,6 +185,10 @@ def override_hf_config(hf_config: Any, hf_overrides: Dict[str, Any]):
_override_hf_config(hf_config, k, v)


def _default_check_env(device: str):
pass


@dataclass
class ModelConfig:
"""Config of model."""
@@ -208,6 +214,13 @@ class ModelConfig:
dllm_mask_token: int = 0
dllm_block_length: int = None

# added for qwen3_next
# could be used for any SSM model.
states_shapes: List[Tuple[Tuple[int], torch.dtype]] = field(default_factory=list)

# check env for model-device combination
check_env_func: Callable = _default_check_env

def get_head_size(self):
"""Get head size."""
return self.head_dim
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/configurations/default.py
@@ -37,6 +37,8 @@ def build(cls, hf_config, model_path: str = None, **kwargs):
eos_token_id=hf_config.eos_token_id,
sliding_window=sliding_window,
head_dim=head_dim,
k_head_dim=head_dim,
v_head_dim=head_dim,
vocab_size=hf_config.vocab_size,
llm_config=hf_config,
)
58 changes: 58 additions & 0 deletions lmdeploy/pytorch/configurations/qwen3_next.py
@@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from .builder import AutoModelConfigBuilder
from .default import DefaultModelConfigBuilder


def _check_env_qwen3_next(device: str):
"""Check env for qwen3 next."""
if device != 'cuda':
return

# check cuda
try:
import causal_conv1d # noqa: F401
except ImportError:
raise ImportError('Qwen3-Next cuda support requires https://github.com/Dao-AILab/causal-conv1d.')

try:
import fla # noqa: F401
except ImportError:
raise ImportError('Qwen3-Next cuda support requires https://github.com/fla-org/flash-linear-attention.')


class Qwen3NextModelConfigBuilder(AutoModelConfigBuilder):

@classmethod
def condition(cls, hf_config):
"""config."""
return hf_config.model_type == 'qwen3_next'

@classmethod
def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs):
"""build."""
cfg = DefaultModelConfigBuilder.build(hf_config, model_path, tp=tp, **kwargs)

# update num layers
num_layers = cfg.num_layers
num_full_layers = num_layers // hf_config.full_attention_interval
num_delta_layers = num_full_layers * (hf_config.full_attention_interval - 1)
cfg.num_layers = num_full_layers

# set state shapes
head_k_dim = hf_config.linear_key_head_dim
head_v_dim = hf_config.linear_value_head_dim
num_v_heads = hf_config.linear_num_value_heads // tp
num_k_heads = hf_config.linear_num_key_heads // tp
key_dim = head_k_dim * num_k_heads
value_dim = head_v_dim * num_v_heads
conv_dim = key_dim * 2 + value_dim
conv_kernel_size = hf_config.linear_conv_kernel_dim

conv_state_shape = (num_delta_layers, conv_dim, conv_kernel_size)
recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, head_v_dim)
dtype = torch.bfloat16
cfg.states_shapes = [(conv_state_shape, dtype), (recurrent_state_shape, dtype)]
cfg.check_env_func = _check_env_qwen3_next
return cfg
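
As a sanity check on the shape arithmetic above, here is a minimal standalone sketch using made-up Qwen3-Next-like hyperparameters; the real values come from the HF config, and every number below is an assumption for illustration only:

```python
# Illustrative only: hypothetical hyperparameters, not the official Qwen3-Next config.
full_attention_interval = 4    # assumed: one full-attention layer every 4 layers
num_layers = 48                # assumed
linear_num_key_heads = 16      # assumed
linear_num_value_heads = 32    # assumed
linear_key_head_dim = 128      # assumed
linear_value_head_dim = 128    # assumed
linear_conv_kernel_dim = 4     # assumed
tp = 1

num_full_layers = num_layers // full_attention_interval              # 12 layers keep a KV cache
num_delta_layers = num_full_layers * (full_attention_interval - 1)   # 36 layers keep SSM states

key_dim = linear_key_head_dim * (linear_num_key_heads // tp)         # 2048
value_dim = linear_value_head_dim * (linear_num_value_heads // tp)   # 4096
conv_dim = key_dim * 2 + value_dim                                   # 8192

conv_state_shape = (num_delta_layers, conv_dim, linear_conv_kernel_dim)
recurrent_state_shape = (num_delta_layers, linear_num_value_heads // tp,
                         linear_key_head_dim, linear_value_head_dim)
print(conv_state_shape)        # (36, 8192, 4)
print(recurrent_state_shape)   # (36, 32, 128, 128)
```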
28 changes: 28 additions & 0 deletions lmdeploy/pytorch/engine/cache_engine.py
@@ -384,3 +384,31 @@ def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote
))

""" Metheds for PD Disaggregation End. """


class StateCacheEngine:
"""Cache engine for state cache."""

def __init__(self, cache_config: CacheConfig):
self.cache_config = cache_config
self._state_caches = self._allocate_caches(num_caches=cache_config.num_state_caches,
state_shapes=cache_config.states_shapes,
device='cuda')

@staticmethod
def _allocate_caches(num_caches: int, state_shapes: List[Tuple[Tuple[int], torch.dtype]], device: torch.device):
"""Allocate cache implement."""
caches = []
for shape, dtype in state_shapes:
cache = torch.zeros(
size=(num_caches, *shape),
dtype=dtype,
device=device,
)
caches.append(cache)
return caches

@property
def state_caches(self):
"""State caches."""
return self._state_caches
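
For context, a rough sketch (not LMDeploy's actual call path) of how these flat caches are meant to be indexed: each running sequence holds a `logical_state` slot that indexes the leading dimension of every state tensor, gathered into `state_offsets` by the engine (see `create_model_inputs` in engine.py below). Shapes and sizes here are toy values, not the real ones:

```python
import torch

# Toy stand-ins; the real shapes come from ModelConfig.states_shapes.
num_state_caches = 8
states_shapes = [((2, 64, 4), torch.bfloat16),       # assumed conv-state shape per sequence
                 ((2, 4, 16, 16), torch.bfloat16)]   # assumed recurrent-state shape per sequence

caches = [torch.zeros((num_state_caches, *shape), dtype=dtype)
          for shape, dtype in states_shapes]

# One cache slot per running sequence.
state_offsets = torch.tensor([0, 3, 5])
conv_states = caches[0][state_offsets]        # (3, 2, 64, 4)
recurrent_states = caches[1][state_offsets]   # (3, 2, 4, 16, 16)
```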
5 changes: 5 additions & 0 deletions lmdeploy/pytorch/engine/engine.py
@@ -801,6 +801,11 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool):
vision_model_inputs = self._create_vision_model_inputs(messages, model_inputs)
model_inputs.vision_inputs = vision_model_inputs

# ssm
if len(self.cache_config.states_shapes) > 0:
state_offsets = torch.tensor([msg.logical_state for msg in messages])
model_inputs.state_offsets = state_offsets

return model_inputs

def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, stopped: torch.Tensor,
36 changes: 35 additions & 1 deletion lmdeploy/pytorch/engine/executor/base.py
@@ -153,16 +153,50 @@ def _adjust_block_size(self):
f'Update `block_size={self.cache_config.block_size}` for large `head_dim={self.model_config.k_head_dim}`.' # noqa
)

def _get_state_cache_mem(self):
"""Get state cache mem usage."""
cache_config = self.cache_config
if len(cache_config.states_shapes) == 0:
return 0

import math

num_state_caches = cache_config.num_state_caches
if num_state_caches is None:
# add more caches for eviction
# TODO: Share memory between state cache and pageable cache
num_state_caches = int(cache_config.max_batches + 8)
cache_config.num_state_caches = num_state_caches

mems = 0
for shape, dtype in cache_config.states_shapes:
dtype_size = dtype.itemsize
mems += math.prod(shape) * num_state_caches * dtype_size

if cache_config.enable_prefix_caching:
cache_config.enable_prefix_caching = False
logger.warning('Prefix caching is not supported for state space models.')

return mems

def update_configs(self):
"""Update cache config."""
self._adjust_block_size()
cache_config = self.cache_config
model_config = self.model_config
cache_config.states_shapes = model_config.states_shapes

# get free mems
free_mems = self.gather_free_mem()
free_mem = min(free_mems)
logger.debug(f'minimal free gpu memory: {free_mem >> 20} mb')
vocal_size = self.model_config.vocab_size

# get state cache size
state_cache_mem = self._get_state_cache_mem()
free_mem = free_mem - state_cache_mem
assert free_mem > 0, 'Not enough GPU memory for state cache. Please reduce max_batch_size.'

vocal_size = self.model_config.vocab_size
tp = self.dist_config.attn_config.tp
cache_block_size = CacheEngine.get_cache_block_size(cache_config.block_size, model_config, tp,
cache_config.quant_policy)
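
A rough back-of-envelope for the reservation computed by `_get_state_cache_mem` above, using toy numbers in place of the real config (the actual shapes and `max_batches` come from the model and cache configs):

```python
import math
import torch

# Toy stand-ins; real values come from ModelConfig.states_shapes and CacheConfig.max_batches.
num_state_caches = 64 + 8    # assumed max_batches=64 plus the extra slots kept for eviction
states_shapes = [((36, 8192, 4), torch.bfloat16),
                 ((36, 32, 128, 128), torch.bfloat16)]

state_cache_mem = sum(math.prod(shape) * num_state_caches * dtype.itemsize
                      for shape, dtype in states_shapes)
print(f'{state_cache_mem >> 20} MiB reserved before sizing KV-cache blocks')
```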
8 changes: 6 additions & 2 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -29,7 +29,7 @@
from ..strategies.base.model_agent import ExtraInputs, ExtraOutputs, StoppingCriteria
from ..utils import get_gpu_memory
from ..weight_loader.model_weight_loader import load_model_weights
from .cache_engine import CacheEngine
from .cache_engine import CacheEngine, StateCacheEngine
from .logits_process import FusedLogitsProcessor, SamplingInputs

logger = get_logger('lmdeploy')
@@ -221,6 +221,7 @@ def model_forward(
model: torch.nn.Module,
inputs: ModelInputs,
cache_engine: CacheEngine,
state_cache_engine: StateCacheEngine,
stream: torch.cuda.Stream = None,
):
"""Perform model forward."""
@@ -232,6 +233,7 @@ def model_forward(
inputs=inputs,
model_config=cache_engine.model_config,
kv_caches=cache_engine.gpu_cache,
state_caches=state_cache_engine.state_caches,
kv_quant_policy=cache_engine.cache_config.quant_policy,
)
with ctx_mgr.context(context):
@@ -351,6 +353,7 @@ def __init__(self,

self.patched_model = None
self.cache_engine = None
self.state_cache_engine = None
self.profiler: AgentProfiler = None

# microbatch
@@ -391,7 +394,6 @@ def get_free_mem(self):

def warmup(self):
"""warmup."""
# TODO: disable for now, do not remove the comments.
with self.all_context():
max_batches = self.cache_config.max_batches
num_tokens = max_batches
@@ -958,12 +960,14 @@ def build_cache_engine(self):
tp_rank=self.tp_rank,
world_size=tp,
cache_stream=self.cache_stream)
self.state_cache_engine = StateCacheEngine(self.cache_config)

def _forward_impl(self, inputs: ModelInputs):
output = model_forward(
self.patched_model,
inputs,
self.cache_engine,
state_cache_engine=self.state_cache_engine,
stream=self.stream,
)
return output
1 change: 1 addition & 0 deletions lmdeploy/pytorch/messages.py
@@ -483,6 +483,7 @@ class SchedulerSequence:
num_new_tokens: int = 0
sampling_param: SamplingParam = field(default_factory=SamplingParam)
logical_blocks: LogicalTokenBlocks = field(default_factory=LogicalTokenBlocks)
logical_state: int = -1
adapter_name: str = None
arrive_time: float = 0.0
output_start_pos: int = 0
11 changes: 11 additions & 0 deletions lmdeploy/pytorch/model_inputs.py
@@ -144,6 +144,7 @@ class ModelInputs:
model_metas: List[Dict[str, Any]] = None
dp_meta: 'DPMeta' = None
enable_microbatch: bool = False
state_offsets: torch.LongTensor = None

def step(self, input_ids: torch.LongTensor, step_seqlens: torch.Tensor = None):
"""Update input ids."""
@@ -256,6 +257,7 @@ def __make_next_vision_inputs(flatten_mms: List, start: int):
model_metas=self.model_metas,
cross_length=cross_length,
history_cross_length=history_cross_length,
state_offsets=self.state_offsets,
)
ret.append(inp)
history_cross_length = cross_length
@@ -322,6 +324,10 @@ class StepContext:
dp_meta: DPMeta = None
enable_microbatch: bool = False

# states for ssm
state_caches: List = None
state_offsets: torch.LongTensor = None

_outputs: Dict = field(default_factory=dict)

@classmethod
@@ -330,6 +336,7 @@ def new(
inputs: ModelInputs,
model_config: ModelConfig,
kv_caches: List = None,
state_caches: List = None,
kv_quant_policy: Literal[0, 4, 8] = 0,
):
"""Build step context.
@@ -389,6 +396,8 @@
cross_kv_seqlens=cross_kv_seqlens,
dp_meta=inputs.dp_meta,
enable_microbatch=inputs.enable_microbatch,
state_caches=state_caches,
state_offsets=inputs.state_offsets,
)

ret = get_backend().update_step_context(ret)
@@ -454,13 +463,15 @@ def build_context(
inputs: ModelInputs,
model_config: ModelConfig,
kv_caches: List = None,
state_caches: List = None,
kv_quant_policy: Literal[0, 4, 8] = 0,
):
"""Build context."""
return StepContext.new(
inputs,
model_config,
kv_caches,
state_caches,
kv_quant_policy,
)
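
To make the plumbing concrete, a hypothetical sketch (not the actual Qwen3-Next module code) of how a linear-attention layer could pull its per-sequence states out of the `StepContext` populated above; the `(num_caches, num_delta_layers, ...)` cache layout is assumed from the config builder:

```python
def gather_layer_states(context, delta_layer_idx: int):
    """Hypothetical helper: slice conv/recurrent states for one SSM layer."""
    conv_cache, recurrent_cache = context.state_caches    # allocated by StateCacheEngine
    offsets = context.state_offsets                       # one cache slot per running sequence
    conv_state = conv_cache[offsets, delta_layer_idx]            # (batch, conv_dim, kernel_size)
    recurrent_state = recurrent_cache[offsets, delta_layer_idx]  # (batch, heads, k_dim, v_dim)
    return conv_state, recurrent_state
```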

5 changes: 5 additions & 0 deletions lmdeploy/pytorch/models/module_map.py
@@ -225,6 +225,11 @@
'GptOssForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gpt_oss.GptOssForCausalLM',
})

# qwen3 next model
MODULE_MAP.update({
'Qwen3NextForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_next.Qwen3NextForCausalLM',
})

# SDAR
MODULE_MAP.update({
'SDARForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.sdar.SDARForCausalLM',