Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
cc2df6b
add xdg ulysses
Jun 9, 2025
95f3466
add grpo scripts
Jun 9, 2025
86a66a6
适配redmoe+mcore by光速
Jun 26, 2025
75bd461
Bump from guangsu
Jul 12, 2025
28a0dd9
[feat] Add async-rl with param-sync and async-pipeline
ziqi-wlb Jul 14, 2025
c1f94a5
Update README
ziqi-wlb Aug 19, 2025
6a3e533
Refine code
ziqi-wlb Aug 19, 2025
ae10015
rebase to main
ziqi-wlb Aug 20, 2025
15e7718
add offload-grad for megatron-worker
ziqi-wlb Aug 20, 2025
ad39348
Refine code
ziqi-wlb Aug 20, 2025
c7e0216
Refine code
ziqi-wlb Aug 20, 2025
d1914e5
Refine code
ziqi-wlb Aug 20, 2025
6e42f66
Fix save checkpoint
ziqi-wlb Aug 21, 2025
f319332
Merge from feat/async-ref-logp
ziqi-wlb Aug 29, 2025
e4619d7
Fix pp param-sync
ziqi-wlb Aug 29, 2025
56a34c1
Fallback to per-tensor-generator and fix load-checkpoint
ziqi-wlb Sep 2, 2025
c3ecc80
Support valid skip first-val-step
ziqi-wlb Sep 2, 2025
68fdedc
Fix ref-model path for resume
ziqi-wlb Sep 2, 2025
06a9493
Fix cpu-oom for large model
ziqi-wlb Sep 4, 2025
c2f4988
Add memory_efficient_mode to fallback to single buffer for param-update
ziqi-wlb Sep 4, 2025
d7ae3ca
Add clear buffer for param-update
ziqi-wlb Sep 4, 2025
ce683e3
Add overlap logp and recv
ziqi-wlb Sep 4, 2025
9bce5a0
Add nccl-sync for param-update
ziqi-wlb Sep 4, 2025
b77c224
WIP: debug for nccl-sync
ziqi-wlb Sep 4, 2025
ecf14f8
WIP: hang at step2 param-update
ziqi-wlb Sep 5, 2025
ee4db3a
Fix hang for param-update when nccl-sync
ziqi-wlb Sep 5, 2025
dd9b5dc
Porting: support dots model register for engine
ziqi-wlb Sep 5, 2025
54716bd
Fix hang for infer_tp>1
ziqi-wlb Sep 8, 2025
94ad7da
Refine code for async-param
ziqi-wlb Sep 8, 2025
473669d
optimize for param-update nccl
ziqi-wlb Sep 9, 2025
3306bf4
optimize for param-update nccl: 3.5s->2s
ziqi-wlb Sep 9, 2025
247e30a
enable mem clear and refine log
ziqi-wlb Sep 10, 2025
e4f945b
refine mem clear
ziqi-wlb Sep 10, 2025
dbdd2c8
Merge conflict for ci-check
ziqi-wlb Sep 10, 2025
8799aac
Add hilab license and refine for ci-check
ziqi-wlb Sep 10, 2025
e04a74c
Refine for pre-commit ci
ziqi-wlb Sep 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion recipe/dapo/dapo_ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
This trainer supports model-agonistic model initialization with huggingface
"""

import os
import uuid
from collections import defaultdict
from copy import deepcopy
Expand All @@ -24,7 +25,6 @@
import numpy as np
import torch
from tqdm import tqdm
import os

from verl import DataProto
from verl.trainer.ppo.core_algos import agg_loss
Expand Down
66 changes: 50 additions & 16 deletions recipe/grpo/grpo_ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agonistic model initialization with huggingface
"""
import ray
import time

import uuid
from collections import defaultdict
from copy import deepcopy
from pprint import pprint

import numpy as np
import ray
import torch
from tqdm import tqdm

Expand All @@ -34,7 +33,14 @@
compute_timing_metrics,
reduce_metrics,
)
from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage, compute_response_mask
from verl.trainer.ppo.ray_trainer import (
AdvantageEstimator,
RayPPOTrainer,
_timer,
apply_kl_penalty,
compute_advantage,
compute_response_mask,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async


Expand Down Expand Up @@ -127,7 +133,9 @@ def fit(self):

del gen_baseline_batch, gen_baseline_output

batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
Expand Down Expand Up @@ -161,7 +169,9 @@ def fit(self):
entropys = old_log_prob.batch["entropys"]
response_masks = batch.batch["response_mask"]
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
entropy_loss = agg_loss(
loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode
)
old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop("entropys")
Expand Down Expand Up @@ -216,21 +226,39 @@ def fit(self):
print(f"{list(reward_extra_infos_dict.keys())=}")
if reward_extra_infos_dict:
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
metrics.update({
**{f"critic/rewards/{k}/mean": np.mean(v) for k, v in reward_extra_infos_dict.items() if '_sub' in k},
**{f"critic/rewards/{k}/max": np.max(v) for k, v in reward_extra_infos_dict.items() if '_sub' in k},
**{f"critic/rewards/{k}/min": np.min(v) for k, v in reward_extra_infos_dict.items() if '_sub' in k},
})
metrics.update(
{
**{
f"critic/rewards/{k}/mean": np.mean(v)
for k, v in reward_extra_infos_dict.items()
if "_sub" in k
},
**{
f"critic/rewards/{k}/max": np.max(v)
for k, v in reward_extra_infos_dict.items()
if "_sub" in k
},
**{
f"critic/rewards/{k}/min": np.min(v)
for k, v in reward_extra_infos_dict.items()
if "_sub" in k
},
}
)
# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty)
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

# compute advantages, executed on the driver process

norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) # GRPO adv normalization factor
norm_adv_by_std_in_grpo = self.config.algorithm.get(
"norm_adv_by_std_in_grpo", True
) # GRPO adv normalization factor

batch = compute_advantage(
batch,
Expand Down Expand Up @@ -278,14 +306,20 @@ def fit(self):
)

# validate
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
):
with _timer("testing", timing_raw):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)

if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
if self.config.trainer.save_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.save_freq == 0
):
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()

Expand All @@ -311,4 +345,4 @@ def fit(self):
if is_last_step:
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return
return
62 changes: 45 additions & 17 deletions recipe/grpo/main_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
"""

import hydra
import pandas as pd
import ray
from torch.utils.data import Dataset

import pandas as pd
from .grpo_ray_trainer import RayGRPOTrainer
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.trainer.ppo.reward import load_reward_manager
from torch.utils.data import Dataset
from verl.utils.dataset.rl_dataset import RLHFDataset as OriginalRLHFDataset


Expand All @@ -30,11 +30,11 @@ def _read_files_and_tokenize(self):
dataframes = []
for parquet_file in self.data_files:
# read parquet files and cache
if parquet_file.endswith('parquet'):
if parquet_file.endswith("parquet"):
dataframe = pd.read_parquet(parquet_file)
elif parquet_file.endswith('json'):
elif parquet_file.endswith("json"):
dataframe = pd.read_json(parquet_file)
elif parquet_file.endswith('jsonl'):
elif parquet_file.endswith("jsonl"):
chunks = []
for chunk in pd.read_json(
parquet_file,
Expand All @@ -51,23 +51,25 @@ def _read_files_and_tokenize(self):

print(f"dataset len: {len(self.dataframe)}")

if self.config.data.get('system_prompt', None) is not None:
if self.config.data.get("system_prompt", None) is not None:
system_prompt = self.config.data.system_prompt
self.dataframe[self.prompt_key] = self.dataframe[self.prompt_key].apply(
lambda x: [{'role': 'system', 'content': system_prompt}]+x
lambda x: [{"role": "system", "content": system_prompt}] + x
)
# filter out too long prompts
if self.filter_overlong_prompts:
tokenizer = self.tokenizer
prompt_key = self.prompt_key
self.dataframe = self.dataframe.filter(
lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,
lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True))
<= self.max_prompt_length,
num_proc=self.num_workers,
desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
)

print(f"filter dataset len: {len(self.dataframe)}")


@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
run_grpo(config)
Expand All @@ -77,7 +79,14 @@ def run_grpo(config) -> None:
if not ray.is_initialized():
# this is for local ray cluster
ray.init(
runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN", "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true"}},
runtime_env={
"env_vars": {
"TOKENIZERS_PARALLELISM": "true",
"NCCL_DEBUG": "WARN",
"VLLM_LOGGING_LEVEL": "WARN",
"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true",
}
},
num_cpus=config.ray_init.num_cpus,
)

Expand All @@ -103,14 +112,18 @@ def run(self, config):
OmegaConf.resolve(config)

# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False))
local_path = copy_to_local(
config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
)

# instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer

trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) # used for multimodal LLM, could be none
processor = hf_processor(
local_path, trust_remote_code=trust_remote_code, use_fast=True
) # used for multimodal LLM, could be none

# vllm early verify
if config.actor_rollout_ref.rollout.name in ["vllm"]:
Expand All @@ -126,15 +139,23 @@ def run(self, config):
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker

actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
actor_rollout_cls = (
AsyncActorRolloutRefWorker
if config.actor_rollout_ref.rollout.mode == "async"
else ActorRolloutRefWorker
)
ray_worker_group_cls = RayWorkerGroup

elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker

actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
actor_rollout_cls = (
AsyncActorRolloutRefWorker
if config.actor_rollout_ref.rollout.mode == "async"
else ActorRolloutRefWorker
)
ray_worker_group_cls = NVMegatronRayWorkerGroup

else:
Expand Down Expand Up @@ -177,8 +198,12 @@ def run(self, config):
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id

reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}))
val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}))
reward_fn = load_reward_manager(
config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
)
val_reward_fn = load_reward_manager(
config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

from verl.utils.dataset.rl_dataset import collate_fn
Expand Down Expand Up @@ -222,7 +247,10 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor):

dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
if not issubclass(dataset_cls, Dataset):
raise TypeError(f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset")
raise TypeError(
f"The custom dataset class '{data_config.custom_cls.name}' from "
f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset"
)
else:
dataset_cls = RLHFDataset
print(f"Using dataset class: {dataset_cls.__name__}")
Expand Down
7 changes: 4 additions & 3 deletions recipe/langgraph_agent/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,10 @@ async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any) -> tuple
tuple[str, list[int], list[int]]: Request id, prompt ids, response mask.
"""
# messages: [system], human, ai, human|tool, ai, human|tool, ...
assert messages[-1].type in ["human", "tool"], (
f"Last message must be human or tool, but got {messages[-1].type}"
)
assert messages[-1].type in [
"human",
"tool",
], f"Last message must be human or tool, but got {messages[-1].type}"
loop = asyncio.get_running_loop()

# Case 1: initial chat completion: [system], human
Expand Down
Loading