Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/e2e_ppo_trainer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,15 @@ jobs:
- name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with validation and saving
run: |
ray stop --force
VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 bash tests/e2e/ppo_trainer/run_function_reward.sh
VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 SAVE_HF_MODEL=True bash tests/e2e/ppo_trainer/run_function_reward.sh
- name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm after resuming
run: |
ray stop --force
RESUME_MODE=auto bash tests/e2e/ppo_trainer/run_function_reward.sh
- name: Test FSDP checkpoints merging function (Qwen Actor)
run: |
exp_name="qwen2.5-0.5b-function-reward-minimal"
python scripts/model_merger.py --backend fsdp --hf_model_path ~/models/Qwen/Qwen2.5-0.5B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface
- name: Running GSM8K E2E without rmpad using function rm
run: |
ray stop --force
Expand Down
43 changes: 40 additions & 3 deletions scripts/model_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,40 @@ def upload_model_to_huggingface(hf_path):
api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model")


def test_fsdp_state_dict(
auto_model_class,
original_hf_model_path: str,
collected_state_dict: Dict[str, torch.Tensor],
) -> bool:
# load original model using bf16 since we collected state_dict with bf16
original_model = auto_model_class.from_pretrained(original_hf_model_path, torch_dtype=torch.bfloat16)
original_state_dict = original_model.state_dict()
del original_model # Free memory

original_keys = set(original_state_dict.keys())
collected_keys = set(collected_state_dict.keys())

missing_keys = original_keys - collected_keys
assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}"

extra_keys = collected_keys - original_keys
assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}"

for key in original_keys:
original_shape = original_state_dict[key].shape
collected_shape = collected_state_dict[key].shape
assert original_shape == collected_shape, f"Shape mismatch for key '{key}': original {original_shape} vs collected {collected_shape}"

original_dtype = original_state_dict[key].dtype
collected_dtype = collected_state_dict[key].dtype
assert original_dtype == collected_dtype, f"Dtype mismatch for key '{key}': original {original_dtype} vs collected {collected_dtype}"

torch.testing.assert_close(original_state_dict[key], collected_state_dict[key], atol=1e-4, rtol=1e-4)

print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.")
return True


def patch_model_generation_config(model, hf_model_path):
"""
The generation_config created from model config may be different to the pretrained model,
Expand All @@ -94,9 +128,9 @@ def patch_model_generation_config(model, hf_model_path):
"""
if model.can_generate():
try:
model.generation_config = GenerationConfig.from_pretrained(args.hf_model_path)
model.generation_config = GenerationConfig.from_pretrained(hf_model_path)
except OSError:
print(f"Warning: Generation config file not found in {args.hf_model_path}, using a generation config created from the model config.")
print(f"Warning: Generation config file not found in {hf_model_path}, using a generation config created from the model config.")
pass
return model

Expand Down Expand Up @@ -200,7 +234,6 @@ def process_one_shard(rank, model_state_dict_lst):
else:
state_dict[key] = torch.cat(state_dict[key], dim=0)

print("Writing to local disk")
hf_path = os.path.join(local_dir, "huggingface") if args.target_dir is None else args.target_dir
config = AutoConfig.from_pretrained(args.hf_model_path)

Expand All @@ -213,6 +246,10 @@ def process_one_shard(rank, model_state_dict_lst):
else:
raise NotImplementedError(f"Unknown architecture {config['architectures']}")

if args.test:
print("Running compatibility test")
test_fsdp_state_dict(auto_model, args.test_hf_dir, state_dict)

with torch.device("meta"):
model = auto_model.from_config(config, torch_dtype=torch.bfloat16)
model.to_empty(device="cpu")
Expand Down
10 changes: 10 additions & 0 deletions tests/e2e/ppo_trainer/run_function_reward.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ RESUME_MODE=${RESUME_MODE:-disable}
SAVE_FREQ=${SAVE_FREQ:--1}
TOT_TRAIN_STEPS=${TOT_TRAIN_STEPS:-1}

# whether to save hf_model
SAVE_HF_MODEL=${SAVE_HF_MODEL:-False}

if [ "${SAVE_HF_MODEL}" = "True" ]; then
CHECKPOINT_CONTENTS="['model','hf_model','optimizer','extra']"
else
CHECKPOINT_CONTENTS="['model','optimizer','extra']"
fi

train_traj_micro_bsz_per_gpu=2 # b
n_resp_per_prompt=4 # g

Expand Down Expand Up @@ -70,6 +79,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
actor_rollout_ref.actor.fsdp_config.param_offload=${ACTOR_FSDP_PARAM_OFFLOAD} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \
actor_rollout_ref.actor.checkpoint.contents=${CHECKPOINT_CONTENTS} \
actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
Expand Down
72 changes: 61 additions & 11 deletions verl/utils/checkpoint/fsdp_checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

import torch
import torch.distributed
from torch.distributed.fsdp import FullStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin

from verl.utils.fs import copy_to_local, is_non_local
from verl.utils.fsdp_utils import fsdp_version, get_fsdp_state_ctx
Expand Down Expand Up @@ -150,19 +150,69 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
torch.save(optimizer_state_dict, optim_path) # TODO: address optimizer is None
torch.save(extra_state_dict, extra_path)

if "hf_model" in self.checkpoint_contents:
# wait for everyone to dump to local
torch.distributed.barrier()

if self.rank == 0:
hf_local_path = os.path.join(local_path, "huggingface")
os.makedirs(hf_local_path, exist_ok=True)
if fsdp_version(self.model) == 1:
self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path)
unwrap_model = self.model._fsdp_wrapped_module
else:
unwrap_model = self.model

model_config = unwrap_model.config
if unwrap_model.can_generate() and hasattr(model_config, "name_or_path") and model_config.name_or_path:
# Some models have an empty name_or_path if not initialized from pretrained;
# in this case, we don't save the generation config.
generation_config = GenerationConfig.from_pretrained(model_config.name_or_path)
generation_config.save_pretrained(local_path)
else:
self.model.config.save_pretrained(hf_local_path)
self.processing_class.save_pretrained(hf_local_path)
generation_config = None

model_config.save_pretrained(local_path)
self.processing_class.save_pretrained(local_path)

# wait for everyone to dump to local
torch.distributed.barrier()

if "hf_model" in self.checkpoint_contents:
hf_local_path = os.path.join(local_path, "huggingface")
os.makedirs(hf_local_path, exist_ok=True)

# Only rank 0 will save hf model and,
# offload to cpu to save LLMs which may be too large to fit in one GPU
state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with get_fsdp_state_ctx(self.model, StateDictType.FULL_STATE_DICT, state_dict_config, None):
state_dict = self.model.state_dict()

if self.rank == 0:
if "ForTokenClassification" in model_config.architectures[0]:
from transformers import AutoModelForTokenClassification

auto_model_cls = AutoModelForTokenClassification
elif "ForCausalLM" in model_config.architectures[0]:
from transformers import AutoModelForCausalLM

auto_model_cls = AutoModelForCausalLM
elif "ForConditionalGeneration" in model_config.architectures[0]:
from transformers import AutoModelForVision2Seq

auto_model_cls = AutoModelForVision2Seq
else:
raise NotImplementedError(f"Unknown architecture {model_config['architectures']}")

with torch.device("meta"):
save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16)
save_model.to_empty(device="cpu")

if save_model.can_generate():
if generation_config is not None:
save_model.generation_config = generation_config
else:
print(f"Warning: {self.__class__.__name__}.save_checkpoint: Generation config file not found in, using a generation config created from the model config when saving hf_model.")

save_model.save_pretrained(hf_local_path, state_dict=state_dict)
self.processing_class.save_pretrained(hf_local_path)
del state_dict
del save_model

# wait for rank0 to dump hf_model to local
torch.distributed.barrier()

self.previous_saved_paths.append(local_path)