-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3df3cf6
commit 76a098c
Showing 7 changed files with 125 additions and 120 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
|
||
|
||
@dataclass
class VllmArguments:
    """Configuration options forwarded to the vLLM inference engine.

    Args:
        gpu_memory_utilization (float): Fraction of GPU memory vLLM may use. Default is 0.9.
        tensor_parallel_size (int): Tensor parallelism degree. Default is 1.
        pipeline_parallel_size (int): Pipeline parallelism degree. Default is 1.
        max_num_seqs (int): Maximum number of concurrent sequences. Default is 256.
        max_model_len (Optional[int]): Maximum model context length. Default is None.
        disable_custom_all_reduce (bool): Disable vLLM's custom all-reduce kernel. Default is False.
        enforce_eager (bool): Force eager-mode execution (no CUDA graphs). Default is False.
        limit_mm_per_prompt (Optional[Union[dict, str]]): Per-prompt multimodal limits, e.g.
            '{"image": 10, "video": 5}'; parsed to a dict in __post_init__. Default is None.
        vllm_max_lora_rank (int): Maximum LoRA rank supported by the engine. Default is 16.
        enable_prefix_caching (bool): Enable automatic prefix caching. Default is False.
    """
    # vllm
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    max_num_seqs: int = 256
    max_model_len: Optional[int] = None
    disable_custom_all_reduce: bool = False
    enforce_eager: bool = False
    limit_mm_per_prompt: Optional[Union[dict, str]] = None  # e.g. '{"image": 10, "video": 5}'
    vllm_max_lora_rank: int = 16
    enable_prefix_caching: bool = False

    def __post_init__(self):
        # Normalize the JSON-string/path form into a plain dict.
        self.limit_mm_per_prompt = ModelArguments.parse_to_dict(self.limit_mm_per_prompt)

    def get_vllm_engine_kwargs(self):
        """Build the keyword-argument dict passed to the vLLM engine constructor.

        NOTE(review): relies on `self.adapters` (and optionally `self.adapter_mapping`)
        being provided by a sibling arguments class — not defined here.
        """
        adapter_paths = list(self.adapters)
        if hasattr(self, 'adapter_mapping'):
            adapter_paths.extend(self.adapter_mapping.values())
        engine_kwargs = {
            'gpu_memory_utilization': self.gpu_memory_utilization,
            'tensor_parallel_size': self.tensor_parallel_size,
            'pipeline_parallel_size': self.pipeline_parallel_size,
            'max_num_seqs': self.max_num_seqs,
            'max_model_len': self.max_model_len,
            'disable_custom_all_reduce': self.disable_custom_all_reduce,
            'enforce_eager': self.enforce_eager,
            'limit_mm_per_prompt': self.limit_mm_per_prompt,
            'max_lora_rank': self.vllm_max_lora_rank,
            'enable_lora': len(adapter_paths) > 0,
            'max_loras': max(len(adapter_paths), 1),
            'enable_prefix_caching': self.enable_prefix_caching,
        }
        return engine_kwargs
|
||
|
||
@dataclass
class LmdeployArguments:
    """Configuration options forwarded to the lmdeploy inference engine.

    Args:
        tp (int): Tensor parallelism degree. Default is 1.
        session_len (Optional[int]): Session length. Default is None.
        cache_max_entry_count (float): Maximum fraction of cache entries. Default is 0.8.
        quant_policy (int): KV quantization policy, e.g. 4 or 8. Default is 0.
        vision_batch_size (int): max_batch_size in lmdeploy's VisionConfig. Default is 1.
    """

    # lmdeploy
    tp: int = 1
    session_len: Optional[int] = None
    cache_max_entry_count: float = 0.8
    quant_policy: int = 0  # e.g. 4, 8
    vision_batch_size: int = 1  # max_batch_size in VisionConfig

    def get_lmdeploy_engine_kwargs(self):
        """Build the keyword-argument dict passed to the lmdeploy engine constructor."""
        kwargs = {}
        kwargs['tp'] = self.tp
        kwargs['session_len'] = self.session_len
        kwargs['cache_max_entry_count'] = self.cache_max_entry_count
        kwargs['quant_policy'] = self.quant_policy
        kwargs['vision_batch_size'] = self.vision_batch_size
        return kwargs
|
||
|
||
def parse_to_dict(value: Union[str, Dict, None], strict: bool = True) -> Union[str, Dict]:
    """Convert a JSON string or a path to a JSON file into a dict.

    Args:
        value: None (yields {}), an existing dict (returned unchanged), a filesystem
            path to a JSON file, or a JSON string.
        strict: When True, an unparseable string is logged and re-raised as
            json.JSONDecodeError; when False, the original string is returned as-is.
            If the value could potentially be a plain string, set strict to False.

    Returns:
        The parsed dict, or the original value when it is a non-string or an
        unparseable string with strict=False.
    """
    if value is None:
        return {}
    if not isinstance(value, str):
        # Already parsed (e.g. a dict) — pass through untouched.
        return value
    if os.path.exists(value):  # local path to a JSON file
        with open(value, 'r', encoding='utf-8') as f:
            return json.load(f)
    try:  # JSON string
        return json.loads(value)
    except json.JSONDecodeError:
        if not strict:
            return value
        logger.error(f"Unable to parse string: '{value}'")
        raise