diff --git a/.github/workflows/e2e_ppo_trainer.yml b/.github/workflows/e2e_ppo_trainer.yml index 65e507d2ade..9fc1821b006 100644 --- a/.github/workflows/e2e_ppo_trainer.yml +++ b/.github/workflows/e2e_ppo_trainer.yml @@ -84,11 +84,15 @@ jobs: - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with validation and saving run: | ray stop --force - VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 bash tests/e2e/ppo_trainer/run_function_reward.sh + VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 SAVE_HF_MODEL=True bash tests/e2e/ppo_trainer/run_function_reward.sh - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm after resuming run: | ray stop --force RESUME_MODE=auto bash tests/e2e/ppo_trainer/run_function_reward.sh + - name: Test FSDP checkpoints merging function (Qwen Actor) + run: | + exp_name="qwen2.5-0.5b-function-reward-minimal" + python scripts/model_merger.py --backend fsdp --hf_model_path ~/models/Qwen/Qwen2.5-0.5B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - name: Running GSM8K E2E without rmpad using function rm run: | ray stop --force diff --git a/scripts/model_merger.py b/scripts/model_merger.py index 02b246bb997..590f4508c04 100644 --- a/scripts/model_merger.py +++ b/scripts/model_merger.py @@ -85,6 +85,40 @@ def upload_model_to_huggingface(hf_path): api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model") +def test_fsdp_state_dict( + auto_model_class, + original_hf_model_path: str, + collected_state_dict: Dict[str, torch.Tensor], +) -> bool: + # load original model using bf16 since we collected state_dict with bf16 + original_model = auto_model_class.from_pretrained(original_hf_model_path, torch_dtype=torch.bfloat16) + original_state_dict = original_model.state_dict() + del original_model # Free memory + + original_keys = set(original_state_dict.keys()) + 
collected_keys = set(collected_state_dict.keys()) + + missing_keys = original_keys - collected_keys + assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}" + + extra_keys = collected_keys - original_keys + assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}" + + for key in original_keys: + original_shape = original_state_dict[key].shape + collected_shape = collected_state_dict[key].shape + assert original_shape == collected_shape, f"Shape mismatch for key '{key}': original {original_shape} vs collected {collected_shape}" + + original_dtype = original_state_dict[key].dtype + collected_dtype = collected_state_dict[key].dtype + assert original_dtype == collected_dtype, f"Dtype mismatch for key '{key}': original {original_dtype} vs collected {collected_dtype}" + + torch.testing.assert_close(original_state_dict[key], collected_state_dict[key], atol=1e-4, rtol=1e-4) + + print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") + return True + + def patch_model_generation_config(model, hf_model_path): """ The generation_config created from model config may be different to the pretrained model, @@ -94,9 +128,9 @@ def patch_model_generation_config(model, hf_model_path): """ if model.can_generate(): try: - model.generation_config = GenerationConfig.from_pretrained(args.hf_model_path) + model.generation_config = GenerationConfig.from_pretrained(hf_model_path) except OSError: - print(f"Warning: Generation config file not found in {args.hf_model_path}, using a generation config created from the model config.") + print(f"Warning: Generation config file not found in {hf_model_path}, using a generation config created from the model config.") pass return model @@ -200,7 +234,6 @@ def process_one_shard(rank, model_state_dict_lst): else: state_dict[key] = torch.cat(state_dict[key], dim=0) - print("Writing to local disk") hf_path = 
os.path.join(local_dir, "huggingface") if args.target_dir is None else args.target_dir config = AutoConfig.from_pretrained(args.hf_model_path) @@ -213,6 +246,10 @@ def process_one_shard(rank, model_state_dict_lst): else: raise NotImplementedError(f"Unknown architecture {config['architectures']}") + if args.test: + print("Running compatibility test") + test_fsdp_state_dict(auto_model, args.test_hf_dir, state_dict) + with torch.device("meta"): model = auto_model.from_config(config, torch_dtype=torch.bfloat16) model.to_empty(device="cpu") diff --git a/tests/e2e/ppo_trainer/run_function_reward.sh b/tests/e2e/ppo_trainer/run_function_reward.sh index 89bae9eee11..0c842e86ffe 100644 --- a/tests/e2e/ppo_trainer/run_function_reward.sh +++ b/tests/e2e/ppo_trainer/run_function_reward.sh @@ -30,6 +30,15 @@ RESUME_MODE=${RESUME_MODE:-disable} SAVE_FREQ=${SAVE_FREQ:--1} TOT_TRAIN_STEPS=${TOT_TRAIN_STEPS:-1} +# whether to save hf_model +SAVE_HF_MODEL=${SAVE_HF_MODEL:-False} + +if [ "${SAVE_HF_MODEL}" = "True" ]; then + CHECKPOINT_CONTENTS="['model','hf_model','optimizer','extra']" +else + CHECKPOINT_CONTENTS="['model','optimizer','extra']" +fi + train_traj_micro_bsz_per_gpu=2 # b n_resp_per_prompt=4 # g @@ -70,6 +79,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.actor.fsdp_config.param_offload=${ACTOR_FSDP_PARAM_OFFLOAD} \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \ + actor_rollout_ref.actor.checkpoint.contents=${CHECKPOINT_CONTENTS} \ actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py index a3e4303e6c7..7df4b39a60d 100644 --- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ 
b/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -18,9 +18,9 @@ import torch import torch.distributed +from torch.distributed.fsdp import FullStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType -from transformers import PreTrainedTokenizer, ProcessorMixin +from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin from verl.utils.fs import copy_to_local, is_non_local from verl.utils.fsdp_utils import fsdp_version, get_fsdp_state_ctx @@ -150,19 +150,69 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i torch.save(optimizer_state_dict, optim_path) # TODO: address optimizer is None torch.save(extra_state_dict, extra_path) - if "hf_model" in self.checkpoint_contents: - # wait for everyone to dump to local - torch.distributed.barrier() - if self.rank == 0: - hf_local_path = os.path.join(local_path, "huggingface") - os.makedirs(hf_local_path, exist_ok=True) if fsdp_version(self.model) == 1: - self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path) + unwrap_model = self.model._fsdp_wrapped_module + else: + unwrap_model = self.model + + model_config = unwrap_model.config + if unwrap_model.can_generate() and hasattr(model_config, "name_or_path") and model_config.name_or_path: + # Some models' name_or_path is empty if not initialized from pretrained; + # in these cases, we don't save the generation config. 
+ generation_config = GenerationConfig.from_pretrained(model_config.name_or_path) + generation_config.save_pretrained(local_path) else: - self.model.config.save_pretrained(hf_local_path) - self.processing_class.save_pretrained(hf_local_path) + generation_config = None + model_config.save_pretrained(local_path) + self.processing_class.save_pretrained(local_path) + + # wait for everyone to dump to local torch.distributed.barrier() + if "hf_model" in self.checkpoint_contents: + hf_local_path = os.path.join(local_path, "huggingface") + os.makedirs(hf_local_path, exist_ok=True) + + # Only rank 0 will save hf model and, + # offload to cpu to save LLMs which may be too large to fit in one GPU + state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with get_fsdp_state_ctx(self.model, StateDictType.FULL_STATE_DICT, state_dict_config, None): + state_dict = self.model.state_dict() + + if self.rank == 0: + if "ForTokenClassification" in model_config.architectures[0]: + from transformers import AutoModelForTokenClassification + + auto_model_cls = AutoModelForTokenClassification + elif "ForCausalLM" in model_config.architectures[0]: + from transformers import AutoModelForCausalLM + + auto_model_cls = AutoModelForCausalLM + elif "ForConditionalGeneration" in model_config.architectures[0]: + from transformers import AutoModelForVision2Seq + + auto_model_cls = AutoModelForVision2Seq + else: + raise NotImplementedError(f"Unknown architecture {model_config['architectures']}") + + with torch.device("meta"): + save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16) + save_model.to_empty(device="cpu") + + if save_model.can_generate(): + if generation_config is not None: + save_model.generation_config = generation_config + else: + print(f"Warning: {self.__class__.__name__}.save_checkpoint: Generation config file not found in, using a generation config created from the model config when saving hf_model.") + + 
save_model.save_pretrained(hf_local_path, state_dict=state_dict) + self.processing_class.save_pretrained(hf_local_path) + del state_dict + del save_model + + # wait for rank0 to dump hf_model to local + torch.distributed.barrier() + self.previous_saved_paths.append(local_path)