6 changes: 3 additions & 3 deletions lmdeploy/cli/cli.py
@@ -1,16 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.

import argparse
import os

from ..version import __version__
from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args, get_chat_template, get_lora_adapters
from .utils import (ArgumentHelper, DefaultsAndTypesHelpFormatter, FlexibleArgumentParser, convert_args,
get_chat_template, get_lora_adapters)


class CLI(object):
_desc = 'The CLI provides a unified API for converting, ' \
'compressing and deploying large language models.'
parser = argparse.ArgumentParser(prog='lmdeploy', description=_desc, add_help=True)
parser = FlexibleArgumentParser(prog='lmdeploy', description=_desc, add_help=True)
parser.add_argument('-v', '--version', action='version', version=__version__)
subparsers = parser.add_subparsers(title='Commands', description='lmdeploy has following commands:', dest='command')

48 changes: 27 additions & 21 deletions lmdeploy/cli/serve.py
@@ -165,6 +165,7 @@ def add_parser_api_server():
max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(pt_group)
quant_policy = ArgumentHelper.quant_policy(pt_group)
model_format = ArgumentHelper.model_format(pt_group)
hf_overrides = ArgumentHelper.hf_overrides(pt_group)
ArgumentHelper.dp(pt_group)
ArgumentHelper.ep(pt_group)
ArgumentHelper.enable_microbatch(pt_group)
@@ -189,6 +190,7 @@ def add_parser_api_server():
tb_group._group_actions.append(max_prefill_token_num_act)
tb_group._group_actions.append(quant_policy)
tb_group._group_actions.append(model_format)
tb_group._group_actions.append(hf_overrides)
ArgumentHelper.rope_scaling_factor(tb_group)
ArgumentHelper.num_tokens_per_iter(tb_group)
ArgumentHelper.max_prefill_iters(tb_group)
@@ -318,26 +320,29 @@ def api_server(args):
if backend == 'pytorch':
from lmdeploy.messages import PytorchEngineConfig
adapters = get_lora_adapters(args.adapters)
backend_config = PytorchEngineConfig(dtype=args.dtype,
tp=args.tp,
dp=args.dp,
ep=args.ep,
max_batch_size=max_batch_size,
cache_max_entry_count=args.cache_max_entry_count,
block_size=args.cache_block_seq_len,
session_len=args.session_len,
adapters=adapters,
enable_prefix_caching=args.enable_prefix_caching,
device_type=args.device,
quant_policy=args.quant_policy,
eager_mode=args.eager_mode,
max_prefill_token_num=args.max_prefill_token_num,
enable_microbatch=args.enable_microbatch,
enable_eplb=args.enable_eplb,
enable_metrics=args.enable_metrics,
role=EngineRole[args.role],
migration_backend=MigrationBackend[args.migration_backend],
model_format=args.model_format)
backend_config = PytorchEngineConfig(
dtype=args.dtype,
tp=args.tp,
dp=args.dp,
ep=args.ep,
max_batch_size=max_batch_size,
cache_max_entry_count=args.cache_max_entry_count,
block_size=args.cache_block_seq_len,
session_len=args.session_len,
adapters=adapters,
enable_prefix_caching=args.enable_prefix_caching,
device_type=args.device,
quant_policy=args.quant_policy,
eager_mode=args.eager_mode,
max_prefill_token_num=args.max_prefill_token_num,
enable_microbatch=args.enable_microbatch,
enable_eplb=args.enable_eplb,
enable_metrics=args.enable_metrics,
role=EngineRole[args.role],
migration_backend=MigrationBackend[args.migration_backend],
model_format=args.model_format,
hf_overrides=args.hf_overrides,
)
else:
from lmdeploy.messages import TurbomindEngineConfig
backend_config = TurbomindEngineConfig(dtype=args.dtype,
@@ -351,7 +356,8 @@ def api_server(args):
cache_block_seq_len=args.cache_block_seq_len,
enable_prefix_caching=args.enable_prefix_caching,
max_prefill_token_num=args.max_prefill_token_num,
communicator=args.communicator)
communicator=args.communicator,
hf_overrides=args.hf_overrides)
chat_template_config = get_chat_template(args.chat_template)

from lmdeploy.messages import VisionConfig
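With both branches above forwarding `args.hf_overrides`, the new flag reaches whichever backend the server selects. Below is a minimal sketch of the equivalent programmatic usage, assuming only the `hf_overrides` field added in this PR; the model path is a placeholder.

```python
# Programmatic equivalent of `lmdeploy serve api_server --hf-overrides ...`;
# a sketch, not the PR's test code. The model path is illustrative.
from lmdeploy import pipeline
from lmdeploy.messages import PytorchEngineConfig

backend_config = PytorchEngineConfig(
    # Forwarded verbatim into the HuggingFace config before the model is built.
    hf_overrides={'rope_scaling': {'rope_type': 'dynamic', 'factor': 2.0}})
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config)
```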
104 changes: 103 additions & 1 deletion lmdeploy/cli/utils.py
@@ -1,7 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.

import argparse
from typing import List
import json
import re
import sys
from collections import defaultdict
from typing import Any, List


class DefaultsAndTypesHelpFormatter(argparse.HelpFormatter):
@@ -231,6 +235,14 @@ def rope_scaling_factor(parser):

return parser.add_argument('--rope-scaling-factor', type=float, default=0.0, help='Rope scaling factor')

@staticmethod
def hf_overrides(parser):
"""Add argument hf_overrides to parser."""
return parser.add_argument('--hf-overrides',
type=json.loads,
default=None,
help='Extra arguments to be forwarded to the HuggingFace config.')

@staticmethod
def use_logn_attn(parser):
"""Add argument use_logn_attn to parser."""
@@ -580,3 +592,93 @@ def migration_backend(parser):
default='DLSlime',
choices=['DLSlime', 'Mooncake'],
help='kvcache migration management backend when PD disaggregation')


# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py
class FlexibleArgumentParser(argparse.ArgumentParser):
""""More flexible argument parser."""

def parse_args(self, args=None, namespace=None):
# If args is not provided, use arguments from the command line
if args is None:
args = sys.argv[1:]

def repl(match: re.Match) -> str:
"""Replaces underscores with dashes in the matched string."""
return match.group(0).replace('_', '-')

# Everything between the first -- and the first .
pattern = re.compile(r'(?<=--)[^\.]*')

# Convert underscores to dashes and vice versa in argument names
processed_args = []
for arg in args:
if arg.startswith('--'):
if '=' in arg:
key, value = arg.split('=', 1)
key = pattern.sub(repl, key, count=1)
processed_args.append(f'{key}={value}')
else:
key = pattern.sub(repl, arg, count=1)
processed_args.append(key)
elif arg.startswith('-O') and arg != '-O' and arg[2] != '.':
# allow -O flag to be used without space, e.g. -O3
processed_args.append('-O')
processed_args.append(arg[2:])
else:
processed_args.append(arg)

def _try_convert(value: str):
"""Try to convert string to float or int."""
if not isinstance(value, str):
return value
# try loads from json
try:
return json.loads(value)
except json.JSONDecodeError:
pass
return value

def create_nested_dict(keys: list[str], value: str):
"""Creates a nested dictionary from a list of keys and a value.

For example, `keys = ["a", "b", "c"]` and `value = 1` will create: `{"a": {"b": {"c": 1}}}`
"""
nested_dict: Any = _try_convert(value)
for key in reversed(keys):
nested_dict = {key: nested_dict}
return nested_dict

def recursive_dict_update(original: dict, update: dict):
"""Recursively updates a dictionary with another dictionary."""
for k, v in update.items():
if isinstance(v, dict) and isinstance(original.get(k), dict):
recursive_dict_update(original[k], v)
else:
original[k] = v

delete = set()
dict_args: dict[str, dict] = defaultdict(dict)
for i, processed_arg in enumerate(processed_args):
if processed_arg.startswith('--') and '.' in processed_arg:
if '=' in processed_arg:
processed_arg, value = processed_arg.split('=', 1)
if '.' not in processed_arg:
# False positive, . was only in the value
continue
else:
value = processed_args[i + 1]
delete.add(i + 1)
key, *keys = processed_arg.split('.')
# Merge all values with the same key into a single dict
arg_dict = create_nested_dict(keys, value)
recursive_dict_update(dict_args[key], arg_dict)
delete.add(i)
# Drop the dotted args (and their values) that were folded into dicts
processed_args = [a for i, a in enumerate(processed_args) if i not in delete]
# Add the dict args back as if they were originally passed as JSON
for dict_arg, dict_value in dict_args.items():
processed_args.append(dict_arg)
processed_args.append(json.dumps(dict_value))

return super().parse_args(processed_args, namespace)
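A quick sketch of how `FlexibleArgumentParser` behaves, assuming the class is importable from `lmdeploy.cli.utils` as defined above: dashes and underscores are interchangeable in option names, and dotted keys are folded into a single JSON value before argparse ever sees them.

```python
import json

from lmdeploy.cli.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser()
parser.add_argument('--hf-overrides', type=json.loads, default=None)

# Dotted keys, in either dash or underscore spelling, merge into one nested
# dict that is re-fed to argparse as a single JSON argument.
args = parser.parse_args(['--hf-overrides.rope_scaling.factor=2.0',
                          '--hf_overrides.rope_scaling.rope_type', 'dynamic'])
assert args.hf_overrides == {'rope_scaling': {'factor': 2.0,
                                              'rope_type': 'dynamic'}}
```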
8 changes: 7 additions & 1 deletion lmdeploy/messages.py
@@ -2,7 +2,7 @@
import enum
import time
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Literal, Optional
from typing import Any, Callable, Dict, List, Literal, Optional

import torch
from pydantic.dataclasses import dataclass as pydantic_dataclass
Expand Down Expand Up @@ -223,6 +223,8 @@ class TurbomindEngineConfig:
devices(List[int]): the used devices
empty_init (bool): Whether to load the model weights, you should set
it to True if you want to update weights after create the pipeline
hf_overrides (Dict[str, Any]): Huggingface overrides for the model.
It can be used to override the default config of the model.
"""

dtype: str = 'auto'
@@ -252,6 +254,7 @@ class TurbomindEngineConfig:
devices: Optional[List[int]] = None
empty_init: bool = False
communicator: str = 'nccl'
hf_overrides: Optional[Dict[str, Any]] = None

def __post_init__(self):
"""Check input validation."""
@@ -322,6 +325,8 @@ class PytorchEngineConfig:
Default to `MigrationBackend.DLSlime`.
enable_mp_engine (bool): run engine in multi-process mode.
model_format (str): weight quantization policy, options: ['fp8'].
hf_overrides (Dict[str, Any]): Huggingface overrides for the model.
It can be used to override the default config of the model.
"""
dtype: str = 'auto'
tp: int = 1
@@ -352,6 +357,7 @@ class PytorchEngineConfig:
enable_mp_engine: bool = False
model_format: str = None
enable_metrics: bool = False
hf_overrides: Optional[Dict[str, Any]] = None

role: EngineRole = EngineRole.Hybrid
migration_backend: MigrationBackend = MigrationBackend.DLSlime
32 changes: 24 additions & 8 deletions lmdeploy/pytorch/config.py
@@ -158,7 +158,8 @@ def from_pretrained(cls,
pretrained_model_name_or_path: str,
trust_remote_code: bool = True,
dtype: str = 'auto',
dist_config: DistConfig = None):
dist_config: DistConfig = None,
hf_overrides: Dict[str, Any] = None):
"""Instantiate one of the configuration classes of the library from a
pretrained model configuration.

@@ -168,13 +169,28 @@
models defined on the Hub in their own modeling files.
dtype (str): user specified data type for model weights and
activations. Refer to `PyTorchEngineConfig` for details
hf_overrides (Dict[str, Any]): overrides for the HF config.
"""
from transformers import AutoConfig

from lmdeploy.utils import get_logger

hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
if getattr(hf_config, 'model_type', None) in ['phi3']:
# phi3 + trust_remote_code leads to error when tp.
hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
return cls.from_hf_config(hf_config, pretrained_model_name_or_path, dtype=dtype, dist_config=dist_config)

model_config = cls.from_hf_config(hf_config,
pretrained_model_name_or_path,
dtype=dtype,
dist_config=dist_config)

if hf_overrides is not None:
logger = get_logger('lmdeploy')
logger.warning(f'Overriding HF config with {hf_overrides}')
model_config.hf_config.update(hf_overrides)

return model_config

@classmethod
def from_hf_config(cls,
@@ -223,14 +239,14 @@ class MiscConfig:
custom_module_map: str = None
empty_init: bool = False
model_format: str = None
hf_overrides: Dict[str, Any] = None

@classmethod
def from_engine_config(cls, engine_config: PytorchEngineConfig):
"""From engine config."""
misc_config = cls(
custom_module_map=engine_config.custom_module_map,
empty_init=engine_config.empty_init,
prefill_interval=engine_config.prefill_interval,
model_format=engine_config.model_format,
)
misc_config = cls(custom_module_map=engine_config.custom_module_map,
empty_init=engine_config.empty_init,
prefill_interval=engine_config.prefill_interval,
model_format=engine_config.model_format,
hf_overrides=engine_config.hf_overrides)
return misc_config
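Taken on its own, the override step in `from_pretrained` is a plain `PretrainedConfig.update` call. A standalone sketch using the standard transformers API; the model path is a placeholder:

```python
from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained('internlm/internlm2_5-7b-chat',
                                       trust_remote_code=True)
overrides = {'rope_scaling': {'rope_type': 'dynamic', 'factor': 2.0}}
hf_config.update(overrides)  # sets each key as an attribute on the config
assert hf_config.rope_scaling == {'rope_type': 'dynamic', 'factor': 2.0}
```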
6 changes: 5 additions & 1 deletion lmdeploy/pytorch/engine/executor/__init__.py
@@ -68,7 +68,11 @@ def build_executor(model_path: str,
dp = dist_config.dp
world_size = dist_config.world_size

model_config = ModelConfig.from_pretrained(model_path, trust_remote_code=True, dtype=dtype, dist_config=dist_config)
model_config = ModelConfig.from_pretrained(model_path,
trust_remote_code=True,
dtype=dtype,
hf_overrides=misc_config.hf_overrides,
dist_config=dist_config)

if distributed_executor_backend is None:
distributed_executor_backend = get_distributed_executor_backend(world_size, dp, device_type, logger)
6 changes: 5 additions & 1 deletion lmdeploy/pytorch/engine/executor/ray_executor.py
@@ -257,7 +257,11 @@ def __init__(

from lmdeploy.tokenizer import Tokenizer
tokenizer = Tokenizer(model_path).model.model
model_config = ModelConfig.from_pretrained(model_path, dtype=dtype, dist_config=dist_config)
model_config = ModelConfig.from_pretrained(model_path,
dtype=dtype,
hf_overrides=misc_config.hf_overrides,
dist_config=dist_config)

super().__init__(
model_path=model_path,
cache_config=cache_config,
34 changes: 27 additions & 7 deletions lmdeploy/turbomind/deploy/config.py
@@ -8,6 +8,9 @@
from pydantic.dataclasses import dataclass

from lmdeploy.messages import TurbomindEngineConfig
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


def config_from_dict(cls, env):
@@ -150,15 +153,32 @@ def update_from_engine_config(self, config: TurbomindEngineConfig):
if hasattr(self.attention_config, key):
setattr(self.attention_config, key, value)

# update from hf_overrides
if hasattr(config, 'hf_overrides') and config.hf_overrides:
hf_overrides = config.hf_overrides

if hf_overrides.get('rope_scaling'):
override_params = hf_overrides.get('rope_scaling')

rope_param = self.attention_config.rope_param or RopeParam(type='', base=0, dim=0)
rope_param.type = override_params.get('rope_type', '')
rope_param.factor = override_params.get('factor', 1.0)
rope_param.max_position_embeddings = override_params.get('original_max_position_embeddings', None)

self.attention_config.rope_param = rope_param
logger.warning(f'Overriding HF config with {hf_overrides}')

# use dynamic ntk
if config.rope_scaling_factor:
if self.attention_config.rope_param is None:
# some ut will create empty RopeParam, will check base/dim in src code
self.attention_config.rope_param = RopeParam(type='', base=0, dim=0)
self.attention_config.rope_param.__dict__.update(
type='dynamic',
factor=config.rope_scaling_factor,
max_position_embeddings=self.attention_config.max_position_embeddings)
# some ut will create empty RopeParam, will check base/dim in src code
rope_param = self.attention_config.rope_param or RopeParam(type='', base=0, dim=0)
rope_param.type = 'dynamic'
rope_param.factor = config.rope_scaling_factor
rope_param.max_position_embeddings = self.attention_config.max_position_embeddings

self.attention_config.rope_param = rope_param
logger.warning(
'`--rope-scaling-factor` will be removed in a future release. Please instead use `--hf-overrides`.')

@classmethod
def from_dict(cls, config: dict = {}):
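For users migrating off the deprecated flag, the two spellings below are meant to populate an equivalent dynamic-NTK `RopeParam`; a rough sketch with illustrative values (the two paths differ slightly in where `max_position_embeddings` comes from, per the code above):

```python
from lmdeploy.messages import TurbomindEngineConfig

# Deprecated: dynamic NTK via the dedicated flag.
legacy = TurbomindEngineConfig(rope_scaling_factor=2.0)

# Preferred: the same intent expressed as a HuggingFace config override.
preferred = TurbomindEngineConfig(
    hf_overrides={'rope_scaling': {'rope_type': 'dynamic', 'factor': 2.0}})
```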