8 changes: 8 additions & 0 deletions verl/single_controller/base/decorator.py
@@ -48,6 +48,7 @@ def init_predefined_dispatch_mode():
Dispatch.register("DP_COMPUTE_PROTO")
Dispatch.register("DP_COMPUTE_PROTO_WITH_FUNC")
Dispatch.register("DP_COMPUTE_METRIC")
Dispatch.register("DP_DISPATCH")
# This is a special dispatch mode for vllm ExternalRayDistributedExecutor
Dispatch.register("DIRECT_ROLLOUT_METHOD")

@@ -135,6 +136,12 @@ def dispatch_all_to_all(worker_group, *args, **kwargs):
def collect_all_to_all(worker_group, output):
return output

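# Pass-through helpers for Dispatch.DP_DISPATCH: every worker receives the
# caller's (args, kwargs) unchanged, and per-worker outputs are returned as-is.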
def dispatch_dp(worker_group, *args, **kwargs):
return args, kwargs

def collect_dp(worker_group, output):
return output


def dispatch_megatron_compute(worker_group, *args, **kwargs):
"""
@@ -415,6 +422,7 @@ def collect_dp_compute_data_proto(worker_group, output):
"dispatch_fn": dummy_direct_rollout_call,
"collect_fn": dummy_direct_rollout_call,
},
Dispatch.DP_DISPATCH: {"dispatch_fn": dispatch_dp, "collect_fn": collect_dp},
}
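# Usage sketch (illustrative): a method registered with
#   @register(dispatch_mode=Dispatch.DP_DISPATCH)
# is invoked on every worker with identical arguments, as the
# prepare_for_generate/finish_generate hooks in megatron_workers.py do.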


5 changes: 5 additions & 0 deletions verl/trainer/config/rollout/rollout.yaml
@@ -13,6 +13,11 @@ top_k: -1
# Top-p sampling parameter. Default 1.0.
top_p: 1

# Over-sampling batch size for rollout requests. Default 0 (disabled).
over_sampling_batch_size: 0

# Prompts consumed per rollout step; an epoch runs len(dataset) // rollout_batch_size iterations. Default 0.
rollout_batch_size: 0

# Whether aborted generations may be buffered and resumed in a later rollout. Default false.
partial_rollout: false

# typically the same as data max prompt length
# same as data.max_prompt_length if it exists
prompt_length: ${oc.select:data.max_prompt_length,512}
95 changes: 32 additions & 63 deletions verl/trainer/ppo/ray_trainer.py
@@ -60,6 +60,7 @@
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger
from verl.workers.rollout.rollout_manager import RolloutManager

WorkerType = type[Worker]

@@ -732,22 +733,8 @@ def _validate(self):
}
print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")

# pad to be divisible by dp_size
size_divisor = (
self.actor_rollout_wg.world_size
if not self.async_rollout_mode
else self.config.actor_rollout_ref.rollout.agent.num_workers
)
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)
if not self.async_rollout_mode:
test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
else:
test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)

# unpad
test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
test_output_gen_batch = self.generate_sequences(test_gen_batch)

print("validation generation end")

# Store generated outputs
output_ids = test_output_gen_batch.batch["responses"]
Expand Down Expand Up @@ -830,6 +817,10 @@ def init_workers(self):
self.resource_pool_manager.create_resource_pool()

self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
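# A CPU-only Ray actor owns rollout; the trainer records the sglang router
# endpoint it exposes so the workers can connect to it.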
self.rollout_manager = RolloutManager.options(num_cpus=1, num_gpus=0).remote(config=self.config)
self.sglang_router_ip, self.sglang_router_port = ray.get(
self.rollout_manager.get_sglang_router_ip_and_port.remote()
)

# create actor and rollout
if self.hybrid_engine:
@@ -839,6 +830,8 @@
config=self.config.actor_rollout_ref,
role="actor_rollout",
profile_option=self.config.trainer.npu_profile.options,
sglang_router_ip=self.sglang_router_ip,
sglang_router_port=self.sglang_router_port,
)
self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
else:
@@ -1078,6 +1071,19 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqle
)
metrics.update(global_balance_stats)

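# Rollout entry points: both helpers flip the actor workers into generate
# mode, delegate generation to the RolloutManager actor, and restore
# training mode when done.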
def rollout(self) -> DataProto:
self.actor_rollout_wg.prepare_for_generate()
batch = ray.get(self.rollout_manager.rollout.remote())
self.actor_rollout_wg.finish_generate()
return batch

def generate_sequences(self, batch: DataProto) -> DataProto:
# For compatibility with verl's original generate_sequences
self.actor_rollout_wg.prepare_for_generate()
batch = ray.get(self.rollout_manager.generate_sequences.remote(batch))
self.actor_rollout_wg.finish_generate()
return batch

def fit(self):
"""
The training loop of PPO.
@@ -1103,13 +1109,13 @@ def fit(self):

# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
assert val_metrics, f"{val_metrics=}"
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return
# if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
# val_metrics = self._validate()
# assert val_metrics, f"{val_metrics=}"
# pprint(f"Initial validation metrics: {val_metrics}")
# logger.log(data=val_metrics, step=self.global_steps)
# if self.config.trainer.get("val_only", False):
# return

# add tqdm
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
@@ -1118,9 +1124,10 @@
self.global_steps += 1
last_val_metrics = None
self.max_steps_duration = 0
per_epoch_iters = ray.get(self.rollout_manager.get_num_rollout_per_epoch.remote())
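# Iterations per epoch come from the rollout manager: len(dataset) // rollout_batch_size.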

for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
for _ in range(per_epoch_iters):
metrics = {}
timing_raw = {}

@@ -1132,45 +1139,14 @@
with marked_timer("start_profile", timing_raw):
self._start_profiling(do_profile)

batch: DataProto = DataProto.from_single_dict(batch_dict)

# pop those keys for generation
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
if "multi_modal_data" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("multi_modal_data")
if "raw_prompt" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("raw_prompt")
if "tools_kwargs" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("tools_kwargs")
if "interaction_kwargs" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("interaction_kwargs")
if "index" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("index")
if "agent_name" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("agent_name")

gen_batch = batch.pop(
batch_keys=batch_keys_to_pop,
non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
)

# pass global_steps to trace
gen_batch.meta_info["global_steps"] = self.global_steps
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)

is_last_step = self.global_steps >= self.total_training_steps

with marked_timer("step", timing_raw):
# generate a batch
with marked_timer("gen", timing_raw, color="red"):
if not self.async_rollout_mode:
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
else:
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)

batch = self.rollout()
gen_batch = batch
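# Alias kept so the REMAX baseline branch below can reuse the rolled-out batch.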
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with marked_timer("gen_max", timing_raw, color="purple"):
gen_baseline_batch = deepcopy(gen_batch)
@@ -1189,13 +1165,6 @@

del gen_baseline_batch, gen_baseline_output

batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)

if "response_mask" not in batch.batch.keys():
batch.batch["response_mask"] = compute_response_mask(batch)
# Balance the number of valid tokens across DP ranks.
22 changes: 21 additions & 1 deletion verl/workers/megatron_workers.py
@@ -87,7 +87,12 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):

def __init__(self, config: DictConfig, role: str, **kwargs):
MegatronWorker.__init__(self)
from loguru import logger as log

self.config = config
self.sglang_router_ip = kwargs["sglang_router_ip"]
self.sglang_router_port = kwargs["sglang_router_port"]
Comment on lines +93 to +94
Contributor

high

Accessing kwargs with [] is unsafe as it will raise a KeyError if the keys sglang_router_ip or sglang_router_port are not provided by the caller. This can lead to runtime crashes that are hard to debug. For better robustness, you should use kwargs.get() and handle the case where the keys might be missing, or explicitly check for their existence and raise a more informative ValueError.

Suggested change
self.sglang_router_ip = kwargs["sglang_router_ip"]
self.sglang_router_port = kwargs["sglang_router_port"]
self.sglang_router_ip = kwargs.get("sglang_router_ip")
self.sglang_router_port = kwargs.get("sglang_router_port")
if self.sglang_router_ip is None or self.sglang_router_port is None:
raise ValueError("sglang_router_ip and sglang_router_port must be provided for ActorRolloutRefWorker")

log.info(f"sglang_router_host: {self.sglang_router_ip}, sglang_router_port: {self.sglang_router_port}")

# NOTE(sgm): We utilize colocate WorkerGroup by default.
# As a result, Workers for different model share the same process.
@@ -110,7 +115,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size,
pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=self.config.actor.megatron.virtual_pipeline_model_parallel_size,
pipeline_model_parallel_split_rank=None,
# pipeline_model_parallel_split_rank=None,
use_sharp=False,
context_parallel_size=self.config.actor.megatron.context_parallel_size,
expert_model_parallel_size=self.config.actor.megatron.expert_model_parallel_size,
@@ -358,6 +363,8 @@ def _build_rollout(self, trust_remote_code=False):
model_hf_config=self.actor_model_config,
trust_remote_code=trust_remote_code,
device_mesh=rollout_device_mesh,
sglang_router_ip=self.sglang_router_ip,
sglang_router_port=self.sglang_router_port,
)
log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=None)

@@ -576,6 +583,19 @@ def generate_sequences(self, prompts: DataProto):
get_torch_device().empty_cache()
return output

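# prepare_for_generate/finish_generate bracket a rollout phase: the sharding
# manager switches weights over for generation, and the device cache is
# emptied once generation finishes.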
@register(dispatch_mode=Dispatch.DP_DISPATCH)
@GPUMemoryLogger(role="prepare_for_generate", logger=logger)
@DistProfiler.annotate(color="olive")
def prepare_for_generate(self):
self.sharding_manager.prepare_for_generate()

@register(dispatch_mode=Dispatch.DP_DISPATCH)
@GPUMemoryLogger(role="finish_generate", logger=logger)
@DistProfiler.annotate(color="olive")
def finish_generate(self):
self.sharding_manager.finish_generate()
get_torch_device().empty_cache()

@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
@GPUMemoryLogger(role="compute_ref_log_prob", logger=logger)
@DistProfiler.annotate(color="olive")
134 changes: 134 additions & 0 deletions verl/workers/rollout/buffer.py
@@ -0,0 +1,134 @@
import copy
import uuid
from enum import Enum
from typing import Dict, List

from tqdm import tqdm

from verl.utils import hf_processor, hf_tokenizer
from verl.utils.dataset.rl_dataset import RLHFDataset


class Status(Enum):
PENDING = "pending"
COMPLETED = "completed"
TRUNCATED = "truncated"
ABORTED = "aborted"


class Buffer:
def __init__(self, config):
# init_wandb_secondary(args, wandb_run_id)
self.config = config

# Data source state
self.epoch_id = 0
self.sample_index = 0
self.sample_offset = 0
# TODO remove this
self.metadata = {}

# Initialize the tokenizer and processor
local_path = self.config.actor_rollout_ref.model.path
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True)
# Used for multimodal LLM, could be None
self.processor = hf_processor(local_path, trust_remote_code=True, use_fast=True)

# Load the RLHF dataset
rldataset = RLHFDataset(
data_files=self.config.data.train_files,
tokenizer=self.tokenizer,
processor=self.processor,
config=self.config.data,
)
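# Materialize the dataset eagerly so it can be sliced with wrap-around in
# _get_samples_from_data_source below.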
self.dataset = []
for item in tqdm(rldataset, desc="Loading RLHF dataset", total=len(rldataset)):
self.dataset.append(item)

self.n_samples_per_prompt = self.config.actor_rollout_ref.rollout.n

self.buffer: List[List[Dict]] = []

def get_num_rollout_per_epoch(self):
return len(self.dataset) // self.config.actor_rollout_ref.rollout.rollout_batch_size

def _get_samples_from_data_source(self, num_samples: int) -> List[List[Dict]]:
"""从数据源获取样本,整合了原RolloutDataSource.get_samples的逻辑"""
samples = []
# TODO unify the two branches
if self.sample_offset + num_samples <= len(self.dataset):
prompt_samples = self.dataset[self.sample_offset : self.sample_offset + num_samples]
self.sample_offset += num_samples
else:
prompt_samples = self.dataset[self.sample_offset :]
num_samples -= len(prompt_samples)
# self.epoch_id += 1
# if self.args.rollout_shuffle:
# self.dataset.shuffle(self.epoch_id)
prompt_samples += self.dataset[:num_samples]
self.sample_offset = num_samples
# self.sample_offset = 0

for prompt_sample in prompt_samples:
group = []
prompt_sample["status"] = Status.PENDING
prompt_sample["response_length"] = 0
prompt_sample["response"] = []
prompt_sample["uid"] = str(uuid.uuid4())

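# Replicate each prompt n_samples_per_prompt times; the copies share one uid
# so responses can later be regrouped by source prompt.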
for _ in range(self.n_samples_per_prompt):
sample = copy.deepcopy(prompt_sample)
group.append(sample)
samples.append(group)

return samples

# TODO simplify remaining logic
def get_samples(self, num_samples: int) -> List[List[Dict]]:
"""
Return num_samples samples
"""

samples = self._get_samples_from_buffer(num_samples)
num_samples -= len(samples)

if num_samples == 0:
return samples

samples += self._get_samples_from_data_source(num_samples=num_samples)
return samples

def _get_samples_from_buffer(self, num_samples: int) -> List[List[Dict]]:
if len(self.buffer) == 0 or num_samples == 0:
return []
num_to_pop = min(len(self.buffer), num_samples)
samples = self.buffer[:num_to_pop]
del self.buffer[:num_to_pop]
return samples

def add_samples(self, samples: List[List[Dict]]):
"""
Add a sample group to buffer.
"""
if not samples:
return
assert isinstance(samples, list), f"samples must be a list, got {type(samples)}"
assert isinstance(samples[0], list), f"the elements of samples must be list, got {type(samples[0])}"
for group in samples:
assert len(group) == self.n_samples_per_prompt, (
f"each element of samples must have length n_samples_per_prompt, "
f"got {len(group)} != {self.n_samples_per_prompt}"
)
self.buffer.append(group)

def get_buffer_length(self):
return len(self.buffer)