From 76a098cfe91ea2598569c086c72ce598c59fcf9c Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 11 Feb 2025 13:56:06 +0800 Subject: [PATCH] update --- swift/llm/argument/base_args/model_args.py | 18 ---- swift/llm/argument/infer_args.py | 82 ----------------- swift/llm/argument/rlhf_args.py | 31 ++++--- swift/llm/argument/train_args.py | 6 +- swift/trainers/rlhf_arguments.py | 4 +- swift/utils/__init__.py | 3 +- swift/utils/argumens.py | 101 +++++++++++++++++++++ 7 files changed, 125 insertions(+), 120 deletions(-) create mode 100644 swift/utils/argumens.py diff --git a/swift/llm/argument/base_args/model_args.py b/swift/llm/argument/base_args/model_args.py index a39c0e769d..2646f321b3 100644 --- a/swift/llm/argument/base_args/model_args.py +++ b/swift/llm/argument/base_args/model_args.py @@ -48,24 +48,6 @@ class ModelArguments: # this parameter specifies the path to the locally downloaded repository. local_repo_path: Optional[str] = None - @staticmethod - def parse_to_dict(value: Union[str, Dict, None], strict: bool = True) -> Union[str, Dict]: - """Convert a JSON string or JSON file into a dict""" - # If the value could potentially be a string, it is generally advisable to set strict to False. - if value is None: - value = {} - elif isinstance(value, str): - if os.path.exists(value): # local path - with open(value, 'r', encoding='utf-8') as f: - value = json.load(f) - else: # json str - try: - value = json.loads(value) - except json.JSONDecodeError: - if strict: - logger.error(f"Unable to parse string: '{value}'") - raise - return value def _init_device_map(self): """Prepare device map args""" diff --git a/swift/llm/argument/infer_args.py b/swift/llm/argument/infer_args.py index e131341b16..3c2fc89654 100644 --- a/swift/llm/argument/infer_args.py +++ b/swift/llm/argument/infer_args.py @@ -16,88 +16,6 @@ logger = get_logger() -@dataclass -class LmdeployArguments: - """ - LmdeployArguments is a dataclass that holds the configuration for lmdeploy. - - Args: - tp (int): Tensor parallelism size. Default is 1. - session_len(Optional[int]): The session length, default None. - cache_max_entry_count (float): Maximum entry count for cache. Default is 0.8. - quant_policy (int): Quantization policy, e.g., 4, 8. Default is 0. - vision_batch_size (int): Maximum batch size in VisionConfig. Default is 1. - """ - - # lmdeploy - tp: int = 1 - session_len: Optional[int] = None - cache_max_entry_count: float = 0.8 - quant_policy: int = 0 # e.g. 4, 8 - vision_batch_size: int = 1 # max_batch_size in VisionConfig - - def get_lmdeploy_engine_kwargs(self): - return { - 'tp': self.tp, - 'session_len': self.session_len, - 'cache_max_entry_count': self.cache_max_entry_count, - 'quant_policy': self.quant_policy, - 'vision_batch_size': self.vision_batch_size - } - - -@dataclass -class VllmArguments: - """ - VllmArguments is a dataclass that holds the configuration for vllm. - - Args: - gpu_memory_utilization (float): GPU memory utilization. Default is 0.9. - tensor_parallel_size (int): Tensor parallelism size. Default is 1. - pipeline_parallel_size(int): Pipeline parallelism size. Default is 1. - max_num_seqs (int): Maximum number of sequences. Default is 256. - max_model_len (Optional[int]): Maximum model length. Default is None. - disable_custom_all_reduce (bool): Flag to disable custom all-reduce. Default is False. - enforce_eager (bool): Flag to enforce eager execution. Default is False. - limit_mm_per_prompt (Optional[str]): Limit multimedia per prompt. Default is None. 
-        vllm_max_lora_rank (int): Maximum LoRA rank. Default is 16.
-        enable_prefix_caching (bool): Flag to enable automatic prefix caching. Default is False.
-    """
-    # vllm
-    gpu_memory_utilization: float = 0.9
-    tensor_parallel_size: int = 1
-    pipeline_parallel_size: int = 1
-    max_num_seqs: int = 256
-    max_model_len: Optional[int] = None
-    disable_custom_all_reduce: bool = False
-    enforce_eager: bool = False
-    limit_mm_per_prompt: Optional[Union[dict, str]] = None  # '{"image": 10, "video": 5}'
-    vllm_max_lora_rank: int = 16
-    enable_prefix_caching: bool = False
-
-    def __post_init__(self):
-        self.limit_mm_per_prompt = ModelArguments.parse_to_dict(self.limit_mm_per_prompt)
-
-    def get_vllm_engine_kwargs(self):
-        adapters = self.adapters
-        if hasattr(self, 'adapter_mapping'):
-            adapters = adapters + list(self.adapter_mapping.values())
-        return {
-            'gpu_memory_utilization': self.gpu_memory_utilization,
-            'tensor_parallel_size': self.tensor_parallel_size,
-            'pipeline_parallel_size': self.pipeline_parallel_size,
-            'max_num_seqs': self.max_num_seqs,
-            'max_model_len': self.max_model_len,
-            'disable_custom_all_reduce': self.disable_custom_all_reduce,
-            'enforce_eager': self.enforce_eager,
-            'limit_mm_per_prompt': self.limit_mm_per_prompt,
-            'max_lora_rank': self.vllm_max_lora_rank,
-            'enable_lora': len(adapters) > 0,
-            'max_loras': max(len(adapters), 1),
-            'enable_prefix_caching': self.enable_prefix_caching,
-        }
-
-
 @dataclass
 class InferArguments(MergeArguments, VllmArguments, LmdeployArguments, BaseArguments):
     """
diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py
index 4dd513b232..e08501df80 100644
--- a/swift/llm/argument/rlhf_args.py
+++ b/swift/llm/argument/rlhf_args.py
@@ -4,20 +4,22 @@
 from typing import List, Literal, Optional
 
 from swift.llm import MODEL_MAPPING
-from swift.utils import get_logger
+from swift.utils import get_logger, VllmArguments
 from .train_args import TrainArguments
 
 logger = get_logger()
 
 
 @dataclass
-class PPOArguments:
+class RewardModelArguments:
     reward_model: Optional[str] = None
     reward_adapters: List[str] = field(default_factory=list)
     reward_model_type: Optional[str] = field(
         default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
     reward_model_revision: Optional[str] = None
 
+@dataclass
+class PPOArguments:
     num_ppo_epochs: int = 4
     whiten_rewards: bool = False
     kl_coef: float = 0.05
@@ -31,12 +33,20 @@ class PPOArguments:
     local_rollout_forward_batch_size: int = 64
     num_sample_generations: int = 10
     response_length: int = 512
-    temperature: float = 0.7
     missing_eos_penalty: Optional[float] = None
 
 
 @dataclass
-class RLHFArguments(PPOArguments, TrainArguments):
+class GRPOArguments(VllmArguments):
+    num_generations: int = 8  # G in the GRPO paper
+    max_completion_length: int = 512
+    reward_funcs: List[str] = field(default_factory=list)
+    # vLLM in GRPO
+    use_vllm: bool = False
+    vllm_device: Optional[str] = 'auto'  # 'cuda:1'
+
+@dataclass
+class RLHFArguments(GRPOArguments, PPOArguments, RewardModelArguments, TrainArguments):
     """
     RLHFArguments is a dataclass that holds arguments specific to the Reinforcement
     Learning with Human Feedback (RLHF) training backend.
@@ -62,6 +72,7 @@ class RLHFArguments(PPOArguments, TrainArguments):
 
     beta: Optional[float] = None
     label_smoothing: float = 0
+    loss_scale: Optional[str] = None  # 'last_round'
     # DPO
     rpo_alpha: float = 1.
     # CPO
@@ -71,16 +82,8 @@ class RLHFArguments(PPOArguments, TrainArguments):
     # KTO
     desirable_weight: float = 1.0
     undesirable_weight: float = 1.0
-    # GRPO
-    num_generations: int = 8  # G in the GRPO paper
-    max_completion_length: int = 512
-    reward_funcs: List[str] = field(default_factory=list)
-    # vLLM in GRPO
-    use_vllm: bool = False
-    vllm_device: Optional[str] = 'auto'  # 'cuda:1'
-    vllm_gpu_memory_utilization: float = 0.9
-    vllm_max_model_len: Optional[int] = None
-    loss_scale: Optional[str] = None
+    # PPO/GRPO
+    temperature: float = 0.7
 
     def __post_init__(self):
         self._init_grpo()
diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py
index 732b4b4c84..e0c8909022 100644
--- a/swift/llm/argument/train_args.py
+++ b/swift/llm/argument/train_args.py
@@ -10,7 +10,7 @@
 from swift.plugin import LOSS_MAPPING
 from swift.trainers import TrainerFactory
 from swift.utils import (add_version_to_work_dir, get_logger, get_pai_tensorboard_dir, is_liger_available,
-                         is_local_master, is_mp, is_pai_training_job, use_torchacc)
+                         is_local_master, is_mp, is_pai_training_job, parse_to_dict, use_torchacc)
 from .base_args import BaseArguments, to_abspath
 from .tuner_args import TunerArguments
 
@@ -65,9 +65,9 @@ def __post_init__(self):
         else:
             self.learning_rate = 1e-4
         if self.lr_scheduler_kwargs:
-            self.lr_scheduler_kwargs = self.parse_to_dict(self.lr_scheduler_kwargs)
+            self.lr_scheduler_kwargs = parse_to_dict(self.lr_scheduler_kwargs)
         if getattr(self, 'gradient_checkpointing_kwargs', None):
-            self.gradient_checkpointing_kwargs = self.parse_to_dict(self.gradient_checkpointing_kwargs)
+            self.gradient_checkpointing_kwargs = parse_to_dict(self.gradient_checkpointing_kwargs)
 
         self._init_eval_strategy()
 
diff --git a/swift/trainers/rlhf_arguments.py b/swift/trainers/rlhf_arguments.py
index 350b9d44a0..5136a6393e 100644
--- a/swift/trainers/rlhf_arguments.py
+++ b/swift/trainers/rlhf_arguments.py
@@ -8,7 +8,8 @@
 from trl import PPOConfig as HfPPOConfig
 from trl import RewardConfig as HfRewardConfig
 
-from .arguments import SwiftArgumentsMixin
+from swift.utils import VllmArguments
+from .arguments import SwiftArgumentsMixin
 
 
 @dataclass
@@ -42,5 +42,5 @@ class PPOConfig(SwiftArgumentsMixin, HfPPOConfig):
 
 
 @dataclass
-class GRPOConfig(SwiftArgumentsMixin, HfGRPOConfig):
+class GRPOConfig(VllmArguments, SwiftArgumentsMixin, HfGRPOConfig):
     pass
diff --git a/swift/utils/__init__.py b/swift/utils/__init__.py
index b9c49c7503..7e9dc7c39e 100644
--- a/swift/utils/__init__.py
+++ b/swift/utils/__init__.py
@@ -12,5 +12,6 @@
 from .torch_utils import (Serializer, activate_parameters, find_all_linears, find_embedding, find_norm,
                           freeze_parameters, get_model_parameter_info, safe_ddp_context, show_layers, time_synchronize)
 from .utils import (add_version_to_work_dir, check_json_format, deep_getattr, find_free_port, get_env_args, lower_bound,
                     parse_args, patch_getattr, read_multi_line, seed_everything, split_list, subprocess_run, test_time,
                     upper_bound)
+from .argumens import LmdeployArguments, VllmArguments, parse_to_dict
diff --git a/swift/utils/argumens.py b/swift/utils/argumens.py
new file mode 100644
index 0000000000..b0830f7bf2
--- /dev/null
+++ b/swift/utils/argumens.py
@@ -0,0 +1,109 @@
+import json
+import os
+from dataclasses import dataclass
+from typing import Dict, Optional, Union
+
+from .logger import get_logger
+
+logger = get_logger()
+
+
+@dataclass
+class VllmArguments:
+    """
+    VllmArguments is a dataclass that holds the configuration for vllm.
+
+    Args:
+        gpu_memory_utilization (float): GPU memory utilization. Default is 0.9.
+        tensor_parallel_size (int): Tensor parallelism size. Default is 1.
+        pipeline_parallel_size (int): Pipeline parallelism size. Default is 1.
+        max_num_seqs (int): Maximum number of sequences. Default is 256.
+        max_model_len (Optional[int]): Maximum model length. Default is None.
+        disable_custom_all_reduce (bool): Flag to disable custom all-reduce. Default is False.
+        enforce_eager (bool): Flag to enforce eager execution. Default is False.
+        limit_mm_per_prompt (Optional[Union[dict, str]]): Limit multimedia per prompt. Default is None.
+        vllm_max_lora_rank (int): Maximum LoRA rank. Default is 16.
+        enable_prefix_caching (bool): Flag to enable automatic prefix caching. Default is False.
+    """
+    # vllm
+    gpu_memory_utilization: float = 0.9
+    tensor_parallel_size: int = 1
+    pipeline_parallel_size: int = 1
+    max_num_seqs: int = 256
+    max_model_len: Optional[int] = None
+    disable_custom_all_reduce: bool = False
+    enforce_eager: bool = False
+    limit_mm_per_prompt: Optional[Union[dict, str]] = None  # '{"image": 10, "video": 5}'
+    vllm_max_lora_rank: int = 16
+    enable_prefix_caching: bool = False
+
+    def __post_init__(self):
+        self.limit_mm_per_prompt = parse_to_dict(self.limit_mm_per_prompt)
+
+    def get_vllm_engine_kwargs(self):
+        adapters = self.adapters
+        if hasattr(self, 'adapter_mapping'):
+            adapters = adapters + list(self.adapter_mapping.values())
+        return {
+            'gpu_memory_utilization': self.gpu_memory_utilization,
+            'tensor_parallel_size': self.tensor_parallel_size,
+            'pipeline_parallel_size': self.pipeline_parallel_size,
+            'max_num_seqs': self.max_num_seqs,
+            'max_model_len': self.max_model_len,
+            'disable_custom_all_reduce': self.disable_custom_all_reduce,
+            'enforce_eager': self.enforce_eager,
+            'limit_mm_per_prompt': self.limit_mm_per_prompt,
+            'max_lora_rank': self.vllm_max_lora_rank,
+            'enable_lora': len(adapters) > 0,
+            'max_loras': max(len(adapters), 1),
+            'enable_prefix_caching': self.enable_prefix_caching,
+        }
+
+
+@dataclass
+class LmdeployArguments:
+    """
+    LmdeployArguments is a dataclass that holds the configuration for lmdeploy.
+
+    Args:
+        tp (int): Tensor parallelism size. Default is 1.
+        session_len (Optional[int]): The session length. Default is None.
+        cache_max_entry_count (float): Maximum entry count for cache. Default is 0.8.
+        quant_policy (int): Quantization policy, e.g., 4, 8. Default is 0.
+        vision_batch_size (int): Maximum batch size in VisionConfig. Default is 1.
+    """
+
+    # lmdeploy
+    tp: int = 1
+    session_len: Optional[int] = None
+    cache_max_entry_count: float = 0.8
+    quant_policy: int = 0  # e.g. 4, 8
+    vision_batch_size: int = 1  # max_batch_size in VisionConfig
+
+    def get_lmdeploy_engine_kwargs(self):
+        return {
+            'tp': self.tp,
+            'session_len': self.session_len,
+            'cache_max_entry_count': self.cache_max_entry_count,
+            'quant_policy': self.quant_policy,
+            'vision_batch_size': self.vision_batch_size
+        }
+
+
+def parse_to_dict(value: Union[str, Dict, None], strict: bool = True) -> Union[str, Dict]:
+    """Convert a JSON string or JSON file into a dict"""
+    # If the value could potentially be a string, it is generally advisable to set strict to False.
+    if value is None:
+        value = {}
+    elif isinstance(value, str):
+        if os.path.exists(value):  # local path
+            with open(value, 'r', encoding='utf-8') as f:
+                value = json.load(f)
+        else:  # json str
+            try:
+                value = json.loads(value)
+            except json.JSONDecodeError:
+                if strict:
+                    logger.error(f"Unable to parse string: '{value}'")
+                    raise
+    return value
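
A minimal usage sketch of the relocated parse_to_dict helper follows. It is illustrative only and not part of the patch; it assumes the patch above has been applied and that no file named 'not-json' exists in the working directory.

    from swift.utils import parse_to_dict

    # A JSON string is parsed into a dict.
    assert parse_to_dict('{"image": 10, "video": 5}') == {'image': 10, 'video': 5}

    # None becomes an empty dict.
    assert parse_to_dict(None) == {}

    # With strict=False, a string that is neither an existing file path nor valid JSON
    # is returned unchanged instead of raising json.JSONDecodeError.
    assert parse_to_dict('not-json', strict=False) == 'not-json'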