Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/e2e_ppo_trainer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,15 @@ jobs:
- name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with validation and saving
run: |
ray stop --force
VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 bash tests/e2e/ppo_trainer/run_function_reward.sh
VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 SAVE_HF_MODEL=True bash tests/e2e/ppo_trainer/run_function_reward.sh
- name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm after resuming
run: |
ray stop --force
RESUME_MODE=auto bash tests/e2e/ppo_trainer/run_function_reward.sh
- name: Test FSDP checkpoints merging function (Qwen Actor)
run: |
exp_name="qwen2.5-0.5b-function-reward-minimal"
python scripts/model_merger.py --backend fsdp --hf_model_path Qwen/Qwen2.5-0.5B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface
- name: Running GSM8K E2E without rmpad using function rm
run: |
ray stop --force
Expand Down
49 changes: 40 additions & 9 deletions scripts/model_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@
"--local_dir",
type=str,
required=True,
help=(
"The path for your saved model. For megatron, point to the base dir of model, rng, optimizer checkpoints, "
"commonly be `config.default_local_dir/global_step_\{global_step\}`."
),
help=("The path for your saved model. For megatron, point to the base dir of model, rng, optimizer checkpoints, commonly be `config.default_local_dir/global_step_\{global_step\}`."),
)
parser.add_argument("--target_dir", required=False, default="tmp", type=str, help="The path for the target model")
parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
Expand Down Expand Up @@ -85,6 +82,40 @@ def upload_model_to_huggingface(hf_path):
api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model")


def test_fsdp_state_dict(
auto_model_class,
original_hf_model_path: str,
collected_state_dict: Dict[str, torch.Tensor],
) -> bool:
# load original model using bf16 since we collected state_dict with bf16
original_model = auto_model_class.from_pretrained(original_hf_model_path, torch_dtype=torch.bfloat16)
original_state_dict = original_model.state_dict()
del original_model # Free memory

original_keys = set(original_state_dict.keys())
collected_keys = set(collected_state_dict.keys())

missing_keys = original_keys - collected_keys
assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}"

extra_keys = collected_keys - original_keys
assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}"

for key in original_keys:
original_shape = original_state_dict[key].shape
collected_shape = collected_state_dict[key].shape
assert original_shape == collected_shape, f"Shape mismatch for key '{key}': original {original_shape} vs collected {collected_shape}"

original_dtype = original_state_dict[key].dtype
collected_dtype = collected_state_dict[key].dtype
assert original_dtype == collected_dtype, f"Dtype mismatch for key '{key}': original {original_dtype} vs collected {collected_dtype}"

torch.testing.assert_close(original_state_dict[key], collected_state_dict[key], atol=1e-4, rtol=1e-4)

print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.")
return True


def patch_model_generation_config(model, hf_model_path):
"""
The generation_config created from model config may be different to the pretrained model,
Expand All @@ -96,10 +127,7 @@ def patch_model_generation_config(model, hf_model_path):
try:
model.generation_config = GenerationConfig.from_pretrained(args.hf_model_path)
except OSError:
print(
f"Warning: Generation config file not found in {args.hf_model_path}, "
"using a generation config created from the model config."
)
print(f"Warning: Generation config file not found in {args.hf_model_path}, using a generation config created from the model config.")
pass
return model

Expand Down Expand Up @@ -203,7 +231,6 @@ def process_one_shard(rank, model_state_dict_lst):
else:
state_dict[key] = torch.cat(state_dict[key], dim=0)

print("Writing to local disk")
hf_path = os.path.join(local_dir, "huggingface") if args.target_dir is None else args.target_dir
config = AutoConfig.from_pretrained(args.hf_model_path)

Expand All @@ -216,6 +243,10 @@ def process_one_shard(rank, model_state_dict_lst):
else:
raise NotImplementedError(f"Unknown architecture {config['architectures']}")

if args.test:
print("Running compatibility test")
test_fsdp_state_dict(auto_model, args.hf_model_path, state_dict)

with torch.device("meta"):
model = auto_model.from_config(config, torch_dtype=torch.bfloat16)
model.to_empty(device="cpu")
Expand Down
10 changes: 10 additions & 0 deletions tests/e2e/ppo_trainer/run_function_reward.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ RESUME_MODE=${RESUME_MODE:-disable}
SAVE_FREQ=${SAVE_FREQ:--1}
TOT_TRAIN_STEPS=${TOT_TRAIN_STEPS:-1}

# Select checkpoint contents; include the HF-format model only when requested.
SAVE_HF_MODEL=${SAVE_HF_MODEL:-False}

case "${SAVE_HF_MODEL}" in
  True) CHECKPOINT_CONTENTS="['model','hf_model','optimizer','extra']" ;;
  *)    CHECKPOINT_CONTENTS="['model','optimizer','extra']" ;;
esac

train_traj_micro_bsz_per_gpu=2 # b
n_resp_per_prompt=4 # g

Expand Down Expand Up @@ -70,6 +79,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
actor_rollout_ref.actor.fsdp_config.param_offload=${ACTOR_FSDP_PARAM_OFFLOAD} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \
actor_rollout_ref.actor.checkpoint.contents=${CHECKPOINT_CONTENTS} \
actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
Expand Down