From 1ae7c40a1067af3099eb3cb338b699b45ccea79b Mon Sep 17 00:00:00 2001 From: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> Date: Sun, 28 Sep 2025 16:32:13 +0800 Subject: [PATCH 1/8] Origin recipe --- .../config/transfer_queue_ppo_trainer.yaml | 7 + recipe/transfer_queue/main_ppo.py | 412 ++++++ recipe/transfer_queue/ray_trainer.py | 1236 +++++++++++++++++ .../run_qwen3-8b_transferqueue_npu.sh | 59 + 4 files changed, 1714 insertions(+) create mode 100644 recipe/transfer_queue/config/transfer_queue_ppo_trainer.yaml create mode 100644 recipe/transfer_queue/main_ppo.py create mode 100644 recipe/transfer_queue/ray_trainer.py create mode 100644 recipe/transfer_queue/run_qwen3-8b_transferqueue_npu.sh diff --git a/recipe/transfer_queue/config/transfer_queue_ppo_trainer.yaml b/recipe/transfer_queue/config/transfer_queue_ppo_trainer.yaml new file mode 100644 index 00000000000..3ca7a21b4bd --- /dev/null +++ b/recipe/transfer_queue/config/transfer_queue_ppo_trainer.yaml @@ -0,0 +1,7 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ diff --git a/recipe/transfer_queue/main_ppo.py b/recipe/transfer_queue/main_ppo.py new file mode 100644 index 00000000000..1320cdaa145 --- /dev/null +++ b/recipe/transfer_queue/main_ppo.py @@ -0,0 +1,412 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other mpain. +""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.experimental.dataset.sampler import AbstractSampler +from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from verl.trainer.ppo.ray_trainer import RayPPOTrainer +from verl.trainer.ppo.reward import load_reward_manager +from verl.trainer.ppo.utils import need_critic, need_reference_policy +from verl.utils.config import validate_config +from verl.utils.device import is_cuda_available +from verl.utils.import_utils import load_extern_type + + +@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) +def main(config): + """Main entry point for PPO training with Hydra configuration management. + + Args: + config_dict: Hydra configuration dictionary containing training parameters. + """ + run_ppo(config) + + +# Define a function to run the PPO-like training process +def run_ppo(config) -> None: + """Initialize Ray cluster and run distributed PPO training process. + + Args: + config: Training configuration object containing all necessary parameters + for distributed PPO training including Ray initialization settings, + model paths, and training hyperparameters. + """ + # Check if Ray is not initialized + if not ray.is_initialized(): + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration + default_runtime_env = get_ppo_ray_runtime_env() + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + if ( + is_cuda_available + and config.global_profiler.tool == "nsys" + and config.global_profiler.get("steps") is not None + and len(config.global_profiler.get("steps", [])) > 0 + ): + from verl.utils.import_utils import is_nvtx_available + + assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'" + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_kwargs.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + """Ray remote class for executing distributed PPO training tasks. + + This class encapsulates the main training logic and runs as a Ray remote actor + to enable distributed execution across multiple nodes and GPUs. + + Attributes: + role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes + mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation + """ + + def __init__(self): + self.role_worker_mapping = {} + self.mapping = {} + + def add_actor_rollout_worker(self, config): + """Add actor rollout worker based on the actor strategy.""" + from verl.single_controller.ray import RayWorkerGroup + + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import Role + + self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls) + + return actor_rollout_cls, ray_worker_group_cls + + def add_critic_worker(self, config): + """Add critic worker to role mapping.""" + if config.critic.strategy in {"fsdp", "fsdp2"}: + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl in ["auto", "enable"]: + from verl.workers.fsdp_workers import CriticWorker + elif use_legacy_worker_impl == "disable": + from verl.workers.roles import CriticWorker + + print("Using new worker implementation") + else: + raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + + elif config.critic.strategy == "megatron": + from verl.workers.megatron_workers import CriticWorker + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import Role + + self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker) + + def init_resource_pool_mgr(self, config): + """Initialize resource pool manager.""" + from verl.trainer.ppo.ray_trainer import Role + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + # TODO Here you can use the new registration method to support dynamic registration of roles + if config.reward_model.enable_resource_pool: + if config.reward_model.n_gpus_per_node <= 0: + raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0") + if config.reward_model.nnodes <= 0: + raise ValueError("config.reward_model.nnodes must be greater than 0") + + reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes + resource_pool_spec["reward_pool"] = reward_pool + + self.mapping[Role.ActorRollout] = global_pool_id + self.mapping[Role.Critic] = global_pool_id + from verl.trainer.ppo.ray_trainer import ResourcePoolManager + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping) + return resource_pool_manager + + def add_reward_model_worker(self, config): + """Add reward model worker if enabled.""" + from verl.trainer.ppo.ray_trainer import Role + + if config.reward_model.enable: + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl in ["auto", "enable"]: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + elif use_legacy_worker_impl == "disable": + from verl.workers.roles import RewardModelWorker + + print("Using new worker implementation") + else: + raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + + self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + if config.reward_model.enable_resource_pool: + self.mapping[Role.RewardModel] = "reward_pool" + else: + self.mapping[Role.RewardModel] = "global_pool" + + def add_ref_policy_worker(self, config, ref_policy_cls): + """Add reference policy worker if KL loss or KL reward is used.""" + from verl.trainer.ppo.ray_trainer import Role + + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls) + self.mapping[Role.RefPolicy] = "global_pool" + + def run(self, config): + """Execute the main PPO training workflow. + + This method sets up the distributed training environment, initializes + workers, datasets, and reward functions, then starts the training process. + + Args: + config: Training configuration object containing all parameters needed + for setting up and running the PPO training process. + """ + # Print the initial configuration. `resolve=True` will evaluate symbolic values. + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config) + self.add_critic_worker(config) + + # We should adopt a multi-source reward function here: + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # finally, we combine all the rewards together + # The reward type depends on the tag of the data + self.add_reward_model_worker(config) + + # Add a reference policy worker if KL loss or KL reward is used. + self.add_ref_policy_worker(config, actor_rollout_cls) + + # validate config + validate_config( + config=config, + use_reference_policy=need_reference_policy(self.role_worker_mapping), + use_critic=need_critic(config), + ) + + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + + # Instantiate the tokenizer and processor. + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + # Load the reward manager for training and validation. + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + val_reward_fn = load_reward_manager( + config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + ) + + resource_pool_manager = self.init_resource_pool_mgr(config) + + from verl.utils.dataset.rl_dataset import collate_fn + + # Create training and validation datasets. + train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True) + val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False) + train_sampler = create_rl_sampler(config.data, train_dataset) + + # Initialize the PPO trainer. + trainer = RayPPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=self.role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + ) + # Initialize the workers of the trainer. + trainer.init_workers() + + # Start the training process. + trainer.fit() + + +def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True): + """Create a dataset. + + Arguments: + data_paths: List of paths to data files. + data_config: The data config. + tokenizer (Tokenizer): The tokenizer. + processor (Processor): The processor. + + Returns: + dataset (Dataset): The dataset. + """ + from torch.utils.data import Dataset + + from verl.utils.dataset.rl_dataset import RLHFDataset + + # Check if a custom dataset class is specified in the data configuration + # and if the path to the custom class is provided + if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: + # Dynamically load the custom dataset class + dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + # Verify that the custom dataset class inherits from torch.utils.data.Dataset + if not issubclass(dataset_cls, Dataset): + raise TypeError( + f"The custom dataset class '{data_config.custom_cls.name}' from " + f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset" + ) + elif "datagen" in data_config and data_config.datagen.get("path", None) is not None and is_train: + # If a data generation strategy is specified, use the DynamicGenDataset class + from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset + + dataset_cls = DynamicGenDataset + print("Using DynamicGenDataset for data generation.") + else: + # Use the default RLHFDataset class if no custom class is specified + dataset_cls = RLHFDataset + print(f"Using dataset class: {dataset_cls.__name__}") + + # Instantiate the dataset using the determined dataset class + dataset = dataset_cls( + data_files=data_paths, + tokenizer=tokenizer, + processor=processor, + config=data_config, + ) + + return dataset + + +def create_rl_sampler(data_config, dataset): + """Create a sampler for the dataset. + + Arguments: + data_config: The data config. + dataset (Dataset): The dataset. + + Returns: + sampler (Sampler): The sampler. + """ + import torch + from torch.utils.data import RandomSampler, SequentialSampler + + if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None: + curriculum_class = load_extern_type( + data_config.sampler.class_path, + data_config.sampler.class_name, + ) + sampler = curriculum_class( + data_source=dataset, + data_config=data_config, + ) + assert isinstance(sampler, AbstractSampler) + assert data_config.get("dataloader_num_workers", 8) == 0, ( + "If using curriculum, num_workers must be 0 to prevent data caching. " + "If the dataloader caches data before the batch is done the " + "curriculum sampler won't have the opportunity to reorder it. " + ) + + # Use a sampler to facilitate checkpoint resumption. + # If shuffling is enabled in the data configuration, create a random sampler. + elif data_config.shuffle: + train_dataloader_generator = torch.Generator() + train_dataloader_generator.manual_seed(data_config.get("seed", 1)) + sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) + else: + # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. + sampler = SequentialSampler(data_source=dataset) + + return sampler + + +if __name__ == "__main__": + main() diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py new file mode 100644 index 00000000000..2752f08feb2 --- /dev/null +++ b/recipe/transfer_queue/ray_trainer.py @@ -0,0 +1,1236 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import json +import os +import uuid +from collections import defaultdict +from copy import deepcopy +from dataclasses import dataclass, field +from pprint import pprint +from typing import Optional + +import numpy as np +import ray +import torch +from omegaconf import OmegaConf, open_dict +from torch.utils.data import Dataset, Sampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + +from verl import DataProto +from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.config import AlgoConfig +from verl.trainer.ppo import core_algos +from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + process_validation_metrics, +) +from verl.trainer.ppo.reward import compute_reward, compute_reward_async +from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.debug import marked_timer +from verl.utils.metric import reduce_metrics +from verl.utils.rollout_skip import RolloutSkip +from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance +from verl.utils.torch_functional import masked_mean +from verl.utils.tracking import ValidationGenerationsLogger + + +@dataclass +class ResourcePoolManager: + """ + Define a resource pool specification. Resource pool will be initialized first. + """ + + resource_pool_spec: dict[str, list[int]] + mapping: dict[Role, str] + resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) + + def create_resource_pool(self): + """Create Ray resource pools for distributed training. + + Initializes resource pools based on the resource pool specification, + with each pool managing GPU resources across multiple nodes. + For FSDP backend, uses max_colocate_count=1 to merge WorkerGroups. + For Megatron backend, uses max_colocate_count>1 for different models. + """ + for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): + # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool + # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. + # For Megatron backend, we recommend using max_colocate_count>1 + # that can utilize different WorkerGroup for differnt models + resource_pool = RayResourcePool( + process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name + ) + self.resource_pool_dict[resource_pool_name] = resource_pool + + self._check_resource_available() + + def get_resource_pool(self, role: Role) -> RayResourcePool: + """Get the resource pool of the worker_cls""" + return self.resource_pool_dict[self.mapping[role]] + + def get_n_gpus(self) -> int: + """Get the number of gpus in this cluster.""" + return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) + + def _check_resource_available(self): + """Check if the resource pool can be satisfied in this ray cluster.""" + node_available_resources = ray._private.state.available_resources_per_node() + node_available_gpus = { + node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) + for node, node_info in node_available_resources.items() + } + + # check total required gpus can be satisfied + total_available_gpus = sum(node_available_gpus.values()) + total_required_gpus = sum( + [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] + ) + if total_available_gpus < total_required_gpus: + raise ValueError( + f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" + ) + + +def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): + """Apply KL penalty to the token-level rewards. + + This function computes the KL divergence between the reference policy and current policy, + then applies a penalty to the token-level rewards based on this divergence. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty. + kl_penalty (str, optional): Type of KL penalty to apply. Defaults to "kl". + + Returns: + tuple: A tuple containing: + - The updated data with token-level rewards adjusted by KL penalty + - A dictionary of metrics related to the KL penalty + """ + response_mask = data.batch["response_mask"] + token_level_scores = data.batch["token_level_scores"] + batch_size = data.batch.batch_size[0] + + # compute kl between ref_policy and current policy + # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. + kld = core_algos.kl_penalty( + data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty + ) # (batch_size, response_length) + kld = kld * response_mask + beta = kl_ctrl.value + + token_level_rewards = token_level_scores - beta * kld + + current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence + current_kl = torch.mean(current_kl, dim=0).item() + + # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 + kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) + data.batch["token_level_rewards"] = token_level_rewards + + metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta} + + return data, metrics + + +def compute_response_mask(data: DataProto): + """Compute the attention mask for the response part of the sequence. + + This function extracts the portion of the attention mask that corresponds to the model's response, + which is used for masking computations that should only apply to response tokens. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + + Returns: + torch.Tensor: The attention mask for the response tokens. + """ + responses = data.batch["responses"] + response_length = responses.size(1) + attention_mask = data.batch["attention_mask"] + return attention_mask[:, -response_length:] + + +def compute_advantage( + data: DataProto, + adv_estimator: AdvantageEstimator, + gamma: float = 1.0, + lam: float = 1.0, + num_repeat: int = 1, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, +) -> DataProto: + """Compute advantage estimates for policy optimization. + + This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc. + The advantage estimates are used to guide policy optimization in RL algorithms. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++). + gamma (float, optional): Discount factor for future rewards. Defaults to 1.0. + lam (float, optional): Lambda parameter for GAE. Defaults to 1.0. + num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1. + norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in + GRPO. Defaults to True. + config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None. + + Returns: + DataProto: The updated data with computed advantages and returns. + """ + # Back-compatible with trainers that do not compute response mask in fit + if "response_mask" not in data.batch.keys(): + data.batch["response_mask"] = compute_response_mask(data) + # prepare response group + if adv_estimator == AdvantageEstimator.GAE: + # Compute advantages and returns using Generalized Advantage Estimation (GAE) + advantages, returns = core_algos.compute_gae_advantage_return( + token_level_rewards=data.batch["token_level_rewards"], + values=data.batch["values"], + response_mask=data.batch["response_mask"], + gamma=gamma, + lam=lam, + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + if config.get("use_pf_ppo", False): + data = core_algos.compute_pf_ppo_reweight_data( + data, + config.pf_ppo.get("reweight_method"), + config.pf_ppo.get("weight_pow"), + ) + elif adv_estimator == AdvantageEstimator.GRPO: + # Initialize the mask for GRPO calculation + grpo_calculation_mask = data.batch["response_mask"] + + # Call compute_grpo_outcome_advantage with parameters matching its definition + advantages, returns = core_algos.compute_grpo_outcome_advantage( + token_level_rewards=data.batch["token_level_rewards"], + response_mask=grpo_calculation_mask, + index=data.non_tensor_batch["uid"], + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + else: + # handle all other adv estimator type other than GAE and GRPO + adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator) + adv_kwargs = { + "token_level_rewards": data.batch["token_level_rewards"], + "response_mask": data.batch["response_mask"], + "config": config, + } + if "uid" in data.non_tensor_batch: # optional + adv_kwargs["index"] = data.non_tensor_batch["uid"] + if "reward_baselines" in data.batch: # optional + adv_kwargs["reward_baselines"] = data.batch["reward_baselines"] + + # calculate advantage estimator + advantages, returns = adv_estimator_fn(**adv_kwargs) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + return data + + +class RayPPOTrainer: + """Distributed PPO trainer using Ray for scalable reinforcement learning. + + This trainer orchestrates distributed PPO training across multiple nodes and GPUs, + managing actor rollouts, critic training, and reward computation with Ray backend. + Supports various model architectures including FSDP, Megatron, vLLM, and SGLang integration. + """ + + # TODO: support each role have individual ray_worker_group_cls, + # i.e., support different backend of different role + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup, + processor=None, + reward_fn=None, + val_reward_fn=None, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + collate_fn=None, + train_sampler: Optional[Sampler] = None, + device_name=None, + ): + """ + Initialize distributed PPO trainer with Ray backend. + Note that this trainer runs on the driver process on a single CPU/GPU node. + + Args: + config: Configuration object containing training parameters. + tokenizer: Tokenizer used for encoding and decoding text. + role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes. + resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools. + ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup. + processor: Optional data processor, used for multimodal data + reward_fn: Function for computing rewards during training. + val_reward_fn: Function for computing rewards during validation. + train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None. + val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None. + collate_fn: Function to collate data samples into batches. + train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. + device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to None. + """ + + # Store the tokenizer for text processing + self.tokenizer = tokenizer + self.processor = processor + self.config = config + self.reward_fn = reward_fn + self.val_reward_fn = val_reward_fn + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + assert self.hybrid_engine, "Currently, only support hybrid engine" + + if self.hybrid_engine: + assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.use_reference_policy = need_reference_policy(self.role_worker_mapping) + self.use_rm = need_reward_model(self.role_worker_mapping) + self.use_critic = need_critic(self.config) + self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name if device_name else self.config.trainer.device + self.validation_generations_logger = ValidationGenerationsLogger( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + ) + + # if ref_in_actor is True, the reference policy will be actor without lora applied + self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0 + + # define in-reward KL control + # kl loss control currently not suppoorted + if self.config.algorithm.use_kl_in_reward: + self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) + + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): + """ + Creates the train and validation dataloaders. + """ + # TODO: we have to make sure the batch size is divisible by the dp size + from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler + + if train_dataset is None: + train_dataset = create_rl_dataset( + self.config.data.train_files, self.config.data, self.tokenizer, self.processor + ) + if val_dataset is None: + val_dataset = create_rl_dataset( + self.config.data.val_files, self.config.data, self.tokenizer, self.processor + ) + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + if train_sampler is None: + train_sampler = create_rl_sampler(self.config.data, self.train_dataset) + if collate_fn is None: + from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn + + collate_fn = default_collate_fn + + num_workers = self.config.data["dataloader_num_workers"] + + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=num_workers, + drop_last=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + + val_batch_size = self.config.data.val_batch_size # Prefer config value if set + if val_batch_size is None: + val_batch_size = len(self.val_dataset) + + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=val_batch_size, + num_workers=num_workers, + shuffle=self.config.data.get("validation_shuffle", True), + drop_last=False, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" + + print( + f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: " + f"{len(self.val_dataloader)}" + ) + + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") + + def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path): + """Dump rollout/validation samples as JSONL.""" + os.makedirs(dump_path, exist_ok=True) + filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") + + n = len(inputs) + base_data = { + "input": inputs, + "output": outputs, + "gts": gts, + "score": scores, + "step": [self.global_steps] * n, + } + + for k, v in reward_extra_infos_dict.items(): + if len(v) == n: + base_data[k] = v + + lines = [] + for i in range(n): + entry = {k: v[i] for k, v in base_data.items()} + lines.append(json.dumps(entry, ensure_ascii=False)) + + with open(filename, "w") as f: + f.write("\n".join(lines) + "\n") + + print(f"Dumped generations to {filename}") + + def _log_rollout_data( + self, batch: DataProto, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str + ): + """Log rollout data to disk. + Args: + batch (DataProto): The batch containing rollout data + reward_extra_infos_dict (dict): Additional reward information to log + timing_raw (dict): Timing information for profiling + rollout_data_dir (str): Directory path to save the rollout data + """ + with marked_timer("dump_rollout_generations", timing_raw, color="green"): + inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) + outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) + scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() + sample_gts = [item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch] + + reward_extra_infos_to_dump = reward_extra_infos_dict.copy() + if "request_id" in batch.non_tensor_batch: + reward_extra_infos_dict.setdefault( + "request_id", + batch.non_tensor_batch["request_id"].tolist(), + ) + + self._dump_generations( + inputs=inputs, + outputs=outputs, + gts=sample_gts, + scores=scores, + reward_extra_infos_dict=reward_extra_infos_to_dump, + dump_path=rollout_data_dir, + ) + + def _maybe_log_val_generations(self, inputs, outputs, scores): + """Log a table of validation samples to the configured logger (wandb or swanlab)""" + + generations_to_log = self.config.trainer.log_val_generations + + if generations_to_log == 0: + return + + import numpy as np + + # Create tuples of (input, output, score) and sort by input text + samples = list(zip(inputs, outputs, scores, strict=True)) + samples.sort(key=lambda x: x[0]) # Sort by input text + + # Use fixed random seed for deterministic shuffling + rng = np.random.RandomState(42) + rng.shuffle(samples) + + # Take first N samples after shuffling + samples = samples[:generations_to_log] + + # Log to each configured logger + self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) + + def _get_gen_batch(self, batch: DataProto) -> DataProto: + reward_model_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys() + + # pop those keys for generation + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_model_keys + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop), + ) + + # For agent loop, we need reward model keys to compute score. + if self.async_rollout_mode: + gen_batch.non_tensor_batch.update(batch.non_tensor_batch) + + return gen_batch + + def _validate(self): + data_source_lst = [] + reward_extra_infos_dict: dict[str, list] = defaultdict(list) + + # Lists to collect samples for the table + sample_inputs = [] + sample_outputs = [] + sample_gts = [] + sample_scores = [] + sample_turns = [] + sample_uids = [] + + for test_data in self.val_dataloader: + test_batch = DataProto.from_single_dict(test_data) + + if "uid" not in test_batch.non_tensor_batch: + test_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object + ) + + # repeat test batch + test_batch = test_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True + ) + + # we only do validation on rule-based rm + if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": + return {} + + # Store original inputs + input_ids = test_batch.batch["input_ids"] + # TODO: Can we keep special tokens except for padding tokens? + input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + sample_inputs.extend(input_texts) + sample_uids.extend(test_batch.non_tensor_batch["uid"]) + + ground_truths = [ + item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch + ] + sample_gts.extend(ground_truths) + + test_gen_batch = self._get_gen_batch(test_batch) + test_gen_batch.meta_info = { + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + "recompute_log_prob": False, + "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, + "validate": True, + "global_steps": self.global_steps, + } + print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") + + # pad to be divisible by dp_size + size_divisor = ( + self.actor_rollout_wg.world_size + if not self.async_rollout_mode + else self.config.actor_rollout_ref.rollout.agent.num_workers + ) + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor) + if not self.async_rollout_mode: + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + else: + test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + + # unpad + test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + + print("validation generation end") + + # Store generated outputs + output_ids = test_output_gen_batch.batch["responses"] + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + sample_outputs.extend(output_texts) + + test_batch = test_batch.union(test_output_gen_batch) + test_batch.meta_info["validate"] = True + + # evaluate using reward_function + if self.val_reward_fn is None: + raise ValueError("val_reward_fn must be provided for validation.") + result = self.val_reward_fn(test_batch, return_dict=True) + reward_tensor = result["reward_tensor"] + scores = reward_tensor.sum(-1).cpu().tolist() + sample_scores.extend(scores) + + reward_extra_infos_dict["reward"].extend(scores) + print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}") + if "reward_extra_info" in result: + for key, lst in result["reward_extra_info"].items(): + reward_extra_infos_dict[key].extend(lst) + print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}") + + # collect num_turns of each prompt + if "__num_turns__" in test_batch.non_tensor_batch: + sample_turns.append(test_batch.non_tensor_batch["__num_turns__"]) + + data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + + self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + + # dump generations + val_data_dir = self.config.trainer.get("validation_data_dir", None) + if val_data_dir: + self._dump_generations( + inputs=sample_inputs, + outputs=sample_outputs, + gts=sample_gts, + scores=sample_scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=val_data_dir, + ) + + for key_info, lst in reward_extra_infos_dict.items(): + assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" + + data_sources = np.concatenate(data_source_lst, axis=0) + + data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict) + metric_dict = {} + for data_source, var2metric2val in data_src2var2metric2val.items(): + core_var = "acc" if "acc" in var2metric2val else "reward" + for var_name, metric2val in var2metric2val.items(): + n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + for metric_name, metric_val in metric2val.items(): + if ( + (var_name == core_var) + and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) + and (f"@{n_max}" in metric_name) + ): + metric_sec = "val-core" + else: + metric_sec = "val-aux" + pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" + metric_dict[pfx] = metric_val + + if len(sample_turns) > 0: + sample_turns = np.concatenate(sample_turns) + metric_dict["val-aux/num_turns/min"] = sample_turns.min() + metric_dict["val-aux/num_turns/max"] = sample_turns.max() + metric_dict["val-aux/num_turns/mean"] = sample_turns.mean() + + return metric_dict + + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + if self.hybrid_engine: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.ActorRollout], + config=self.config.actor_rollout_ref, + role="actor_rollout", + ) + self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cfg = omega_conf_to_dataclass(self.config.critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) + self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls + + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role="ref", + ) + self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.global_profiler, "steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps") + # Only require nsight worker options when tool is nsys + if OmegaConf.select(self.config.global_profiler, "tool") == "nsys": + assert ( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + is not None + ), "worker_nsight_options must be set when using nsys with profile_steps" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + ) + wg_kwargs["device_name"] = self.device_name + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if self.use_critic: + self.critic_wg = all_wg["critic"] + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + self.ref_policy_wg = all_wg["ref"] + self.ref_policy_wg.init_model() + + self.rm_wg = None + if self.use_rm: + self.rm_wg = all_wg["rm"] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg["actor_rollout"] + self.actor_rollout_wg.init_model() + + # create async rollout manager and request scheduler + self.async_rollout_mode = False + if self.config.actor_rollout_ref.rollout.mode == "async": + from verl.experimental.agent_loop import AgentLoopManager + + self.async_rollout_mode = True + self.async_rollout_manager = AgentLoopManager( + config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg + ) + + def _save_checkpoint(self): + from verl.utils.fs import local_mkdir_safe + + # path: given_path + `/global_step_{global_steps}` + `/actor` + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" + ) + + print(f"local_global_step_folder: {local_global_step_folder}") + actor_local_path = os.path.join(local_global_step_folder, "actor") + + actor_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + ) + + remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) + if remove_previous_ckpt_in_save: + print( + "Warning: remove_previous_ckpt_in_save is deprecated," + + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead" + ) + max_actor_ckpt_to_keep = ( + self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + max_critic_ckpt_to_keep = ( + self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + + self.actor_rollout_wg.save_checkpoint( + actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep + ) + + if self.use_critic: + critic_local_path = os.path.join(local_global_step_folder, "critic") + critic_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") + ) + self.critic_wg.save_checkpoint( + critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep + ) + + # save dataloader + local_mkdir_safe(local_global_step_folder) + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") + dataloader_state_dict = self.train_dataloader.state_dict() + torch.save(dataloader_state_dict, dataloader_local_path) + + # latest checkpointed iteration tracker (for atomic usage) + local_latest_checkpointed_iteration = os.path.join( + self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" + ) + with open(local_latest_checkpointed_iteration, "w") as f: + f.write(str(self.global_steps)) + + def _load_checkpoint(self): + if self.config.trainer.resume_mode == "disable": + return 0 + + # load from hdfs + if self.config.trainer.default_hdfs_dir is not None: + raise NotImplementedError("load from hdfs is not implemented yet") + else: + checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + + # find global_step_folder + if self.config.trainer.resume_mode == "auto": + if global_step_folder is None: + print("Training from scratch") + return 0 + else: + if self.config.trainer.resume_mode == "resume_path": + assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" + assert "global_step_" in self.config.trainer.resume_from_path, ( + "resume ckpt must specify the global_steps" + ) + global_step_folder = self.config.trainer.resume_from_path + if not os.path.isabs(global_step_folder): + working_dir = os.getcwd() + global_step_folder = os.path.join(working_dir, global_step_folder) + print(f"Load from checkpoint folder: {global_step_folder}") + # set global step + self.global_steps = int(global_step_folder.split("global_step_")[-1]) + + print(f"Setting global step to {self.global_steps}") + print(f"Resuming from {global_step_folder}") + + actor_path = os.path.join(global_step_folder, "actor") + critic_path = os.path.join(global_step_folder, "critic") + # load actor + self.actor_rollout_wg.load_checkpoint( + actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + # load critic + if self.use_critic: + self.critic_wg.load_checkpoint( + critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + + # load dataloader, + # TODO: from remote not implemented yet + dataloader_local_path = os.path.join(global_step_folder, "data.pt") + if os.path.exists(dataloader_local_path): + dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) + self.train_dataloader.load_state_dict(dataloader_state_dict) + else: + print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") + + def _start_profiling(self, do_profile: bool) -> None: + """Start profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) + if self.use_reference_policy: + self.ref_policy_wg.start_profile(profile_step=self.global_steps) + if self.use_critic: + self.critic_wg.start_profile(profile_step=self.global_steps) + if self.use_rm: + self.rm_wg.start_profile(profile_step=self.global_steps) + + def _stop_profiling(self, do_profile: bool) -> None: + """Stop profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.stop_profile() + if self.use_reference_policy: + self.ref_policy_wg.stop_profile() + if self.use_critic: + self.critic_wg.stop_profile() + if self.use_rm: + self.rm_wg.stop_profile() + + def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): + """Reorder the data on single controller such that each dp rank gets similar total tokens""" + attention_mask = batch.batch["attention_mask"] + batch_size = attention_mask.shape[0] + global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + world_size = self.actor_rollout_wg.world_size + global_partition_lst = get_seqlen_balanced_partitions( + global_seqlen_lst, k_partitions=world_size, equal_size=True + ) + # reorder based on index. The data will be automatically equally partitioned by dispatch function + global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) + batch.reorder(global_idx) + global_balance_stats = log_seqlen_unbalance( + seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix + ) + metrics.update(global_balance_stats) + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): + rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) + rollout_skip.wrap_generate_sequences() + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + self.max_steps_duration = 0 + + prev_step_profile = False + curr_step_profile = ( + self.global_steps in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + next_step_profile = False + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + batch: DataProto = DataProto.from_single_dict(batch_dict) + + # add uid to batch + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) + + gen_batch = self._get_gen_batch(batch) + + # pass global_steps to trace + gen_batch.meta_info["global_steps"] = self.global_steps + gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + + is_last_step = self.global_steps >= self.total_training_steps + with marked_timer("step", timing_raw): + # generate a batch + with marked_timer("gen", timing_raw, color="red"): + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + else: + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) + + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + if self.reward_fn is None: + raise ValueError("A reward_fn is required for REMAX advantage estimation.") + + with marked_timer("gen_max", timing_raw, color="purple"): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + if not self.async_rollout_mode: + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + else: + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + batch = batch.union(gen_baseline_output) + reward_baseline_tensor = self.reward_fn(batch) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + + batch.batch["reward_baselines"] = reward_baseline_tensor + + del gen_baseline_batch, gen_baseline_output + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + if "response_mask" not in batch.batch.keys(): + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + with marked_timer("reward", timing_raw, color="yellow"): + # compute reward model score + if self.use_rm and "rm_scores" not in batch.batch.keys(): + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + if self.config.reward_model.launch_reward_fn_async: + future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn) + else: + reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) + + # recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer("ref", timing_raw, color="olive"): + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, color="cyan"): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + if self.config.reward_model.launch_reward_fn_async: + reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor + + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. + # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + + steps_duration = timing_raw["step"] + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + + # this is experimental and may be changed/removed in the future in favor of a general-purpose one + if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): + self.train_dataloader.sampler.update(batch=batch) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + if ( + hasattr(self.config.actor_rollout_ref.actor, "profiler") + and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory" + ): + self.actor_rollout_wg.dump_memory_snapshot( + tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}" + ) + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + # this is experimental and may be changed/removed in the future + # in favor of a general-purpose data buffer pool + if hasattr(self.train_dataset, "on_batch_end"): + # The dataset may be changed after each training batch + self.train_dataset.on_batch_end(batch=batch) diff --git a/recipe/transfer_queue/run_qwen3-8b_transferqueue_npu.sh b/recipe/transfer_queue/run_qwen3-8b_transferqueue_npu.sh new file mode 100644 index 00000000000..1d3a190b3f2 --- /dev/null +++ b/recipe/transfer_queue/run_qwen3-8b_transferqueue_npu.sh @@ -0,0 +1,59 @@ +set -x + +project_name='GRPO-Qwen3' +exp_name='GRPO-Qwen3-8B-npu' +gen_tp=2 +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-8B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.train_batch_size=256 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=${MODEL_PATH} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.ref.use_torch_compile=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.default_local_dir=${CKPTS_DIR} \ + trainer.device=npu \ + trainer.resume_mode=auto \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ + ++actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \ + ++actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + trainer.val_before_train=True \ + trainer.save_freq=5 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 \ No newline at end of file From db24ce4d6b37d0a95d46ccb9a59239e608fbcfda Mon Sep 17 00:00:00 2001 From: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> Date: Sun, 28 Sep 2025 18:07:13 +0800 Subject: [PATCH 2/8] Integrate TransferQueue with Ray Trainer --- recipe/transfer_queue/main_ppo.py | 7 +- recipe/transfer_queue/ray_trainer.py | 453 ++++++++++++++---- .../run_qwen3-8b_transferqueue_npu.sh | 10 +- verl/single_controller/base/decorator.py | 12 +- 4 files changed, 384 insertions(+), 98 deletions(-) diff --git a/recipe/transfer_queue/main_ppo.py b/recipe/transfer_queue/main_ppo.py index 1320cdaa145..0e692ceaf35 100644 --- a/recipe/transfer_queue/main_ppo.py +++ b/recipe/transfer_queue/main_ppo.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Note that we don't combine the main with ray_trainer as ray_trainer is used by other mpain. +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ import os @@ -24,13 +24,14 @@ from verl.experimental.dataset.sampler import AbstractSampler from verl.trainer.constants_ppo import get_ppo_ray_runtime_env -from verl.trainer.ppo.ray_trainer import RayPPOTrainer from verl.trainer.ppo.reward import load_reward_manager from verl.trainer.ppo.utils import need_critic, need_reference_policy from verl.utils.config import validate_config from verl.utils.device import is_cuda_available from verl.utils.import_utils import load_extern_type +from .ray_trainer import RayPPOTrainer + @hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) def main(config): @@ -312,7 +313,6 @@ def run(self, config): ) # Initialize the workers of the trainer. trainer.init_workers() - # Start the training process. trainer.fit() @@ -350,6 +350,7 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=Tr dataset_cls = DynamicGenDataset print("Using DynamicGenDataset for data generation.") + else: # Use the default RLHFDataset class if no custom class is specified dataset_cls = RLHFDataset diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index 2752f08feb2..bbf6db748f1 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -18,12 +18,16 @@ This trainer supports model-agonistic model initialization with huggingface """ +import asyncio import json +import math +import logging import os import uuid from collections import defaultdict from copy import deepcopy from dataclasses import dataclass, field +from packaging.version import parse as parse_version from pprint import pprint from typing import Optional @@ -31,12 +35,19 @@ import ray import torch from omegaconf import OmegaConf, open_dict +import tensordict +from tensordict import TensorDict from torch.utils.data import Dataset, Sampler from torchdata.stateful_dataloader import StatefulDataLoader from tqdm import tqdm from verl import DataProto from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.experimental.transfer_queue.client import AsyncTransferQueueClient, process_zmq_server_info +from verl.experimental.transfer_queue.controller import TransferQueueController +from verl.experimental.transfer_queue.metadata import BatchMeta +from verl.experimental.transfer_queue.storage import TransferQueueStorageSimpleUnit +from verl.experimental.transfer_queue.utils.utils import get_placement_group from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup from verl.single_controller.ray.base import create_colocated_worker_cls @@ -160,22 +171,30 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, return data, metrics -def compute_response_mask(data: DataProto): +def compute_response_mask(batch_meta: BatchMeta, data_system_client): """Compute the attention mask for the response part of the sequence. This function extracts the portion of the attention mask that corresponds to the model's response, which is used for masking computations that should only apply to response tokens. Args: - data (DataProto): The data containing batched model outputs and inputs. + batch_meta (BatchMeta): The data containing batched model outputs and inputs. Returns: - torch.Tensor: The attention mask for the response tokens. + BatchMeta: The BatchMeta of attention mask for the response tokens. """ - responses = data.batch["responses"] + data = asyncio.run(data_system_client.async_get_data(batch_meta)) + + responses = data["responses"] response_length = responses.size(1) - attention_mask = data.batch["attention_mask"] - return attention_mask[:, -response_length:] + attention_mask = data["attention_mask"] + response_mask = attention_mask[:, -response_length:] + output = TensorDict({"response_mask": response_mask}, batch_size=response_mask.size(0)) + + asyncio.run(data_system_client.async_put(data=output, metadata=batch_meta)) + batch_meta.add_fields(output) + + return batch_meta def compute_advantage( @@ -229,7 +248,6 @@ def compute_advantage( elif adv_estimator == AdvantageEstimator.GRPO: # Initialize the mask for GRPO calculation grpo_calculation_mask = data.batch["response_mask"] - # Call compute_grpo_outcome_advantage with parameters matching its definition advantages, returns = core_algos.compute_grpo_outcome_advantage( token_level_rewards=data.batch["token_level_rewards"], @@ -340,6 +358,59 @@ def __init__( self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + self._initialize_data_system() + + def _initialize_data_system(self): + num_n_samples = self.config.actor_rollout_ref.rollout.n + # 1. 初始化TransferQueueStorage + total_storage_size = self.config.data.train_batch_size * self.config.trainer.num_global_batch * num_n_samples + self.data_system_storage_units = {} + storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1) + for storage_unit_rank in range(self.config.trainer.num_data_storage_units): + # TransferQueueStorage通过Ray拉起,是一个ray.remote修饰的类 + storage_node = TransferQueueStorageSimpleUnit.options( + placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank + ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) + self.data_system_storage_units[storage_unit_rank] = storage_node + logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.") + + # 2. 初始化TransferQueueController + # 这里支持多controller实例以实现负载均衡,支持大规模扩展。不同controller可分配至不同RL计算任务 + self.data_system_controllers = {} + controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1) + for controller_rank in range(self.config.trainer.num_data_controllers): + self.data_system_controllers[controller_rank] = TransferQueueController.options( + placement_group=controller_placement_group, placement_group_bundle_index=controller_rank + ).remote( + num_storage_units=self.config.trainer.num_data_storage_units, + global_batch_size=self.config.data.train_batch_size, + num_global_batch=self.config.trainer.num_global_batch, + num_n_samples=num_n_samples, + ) + logging.info(f"TransferQueueController #{controller_rank} has been created.") + + # 3. 将Controller注册至各个Storage + # 每个Storage Unit拿到所有Controller的handler,通过Ray拿到对应的IP+端口,之后建立ZMQ Socket进行消息传输 + self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers) + self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units) + + ray.get( + [ + storage_unit.register_controller_info.remote(self.data_system_controller_infos) + for storage_unit in self.data_system_storage_units.values() + ] + ) + + # 4. 创建Client + self.data_system_client = AsyncTransferQueueClient( + client_id="Trainer", + controller_infos=self.data_system_controller_infos[0], + # TODO: 主控Client感知所有controller,WorkerGroup和Worker的Client感知一个controller + storage_infos=self.data_system_storage_unit_infos, + ) + + return self.data_system_client + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): """ Creates the train and validation dataloaders. @@ -443,26 +514,29 @@ def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dic print(f"Dumped generations to {filename}") def _log_rollout_data( - self, batch: DataProto, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str + self, log_rollout_meta: BatchMeta, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str ): """Log rollout data to disk. Args: - batch (DataProto): The batch containing rollout data + log_rollout_meta (BatchMeta): The batch_meta of rollout data reward_extra_infos_dict (dict): Additional reward information to log timing_raw (dict): Timing information for profiling rollout_data_dir (str): Directory path to save the rollout data """ with marked_timer("dump_rollout_generations", timing_raw, color="green"): - inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) - outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) - scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() - sample_gts = [item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch] + + data = asyncio.run(self.data_system_client.async_get_data(log_rollout_meta)) + + inputs = self.tokenizer.batch_decode(data["prompts"], skip_special_tokens=True) + outputs = self.tokenizer.batch_decode(data["responses"], skip_special_tokens=True) + scores = data["token_level_scores"].sum(-1).cpu().tolist() + sample_gts = [item.get("ground_truth", None) for item in data.get("reward_model", {})] reward_extra_infos_to_dump = reward_extra_infos_dict.copy() - if "request_id" in batch.non_tensor_batch: + if "request_id" in log_rollout_meta.field_names: reward_extra_infos_dict.setdefault( "request_id", - batch.non_tensor_batch["request_id"].tolist(), + data["request_id"].tolist(), ) self._dump_generations( @@ -719,8 +793,8 @@ def init_workers(self): # Only require nsight worker options when tool is nsys if OmegaConf.select(self.config.global_profiler, "tool") == "nsys": assert ( - OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") - is not None + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + is not None ), "worker_nsight_options must be set when using nsys with profile_steps" wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") @@ -901,22 +975,93 @@ def _stop_profiling(self, do_profile: bool) -> None: if self.use_rm: self.rm_wg.stop_profile() - def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): - """Reorder the data on single controller such that each dp rank gets similar total tokens""" - attention_mask = batch.batch["attention_mask"] + def _balance_batch(self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen"): + """Reorder the batchmeta on single controller such that each dp rank gets similar total tokens""" + data = asyncio.run(data_system_client.async_get_data(batch)) + + attention_mask = data["attention_mask"] batch_size = attention_mask.shape[0] - global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) world_size = self.actor_rollout_wg.world_size global_partition_lst = get_seqlen_balanced_partitions( global_seqlen_lst, k_partitions=world_size, equal_size=True ) # reorder based on index. The data will be automatically equally partitioned by dispatch function - global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) - batch.reorder(global_idx) + global_idx = [j for partition in global_partition_lst for j in partition] global_balance_stats = log_seqlen_unbalance( seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix ) metrics.update(global_balance_stats) + return global_idx + + @classmethod + def repeat_dict(cls, batch_dict: dict[str, torch.Tensor | np.ndarray], repeat_times=2, interleave=True) \ + -> dict[str, torch.Tensor | np.ndarray]: + """ + Repeat the batch dict a specified number of times. + + Args: + repeat_times (int): Number of times to repeat the data. + interleave (bool): Whether to interleave the repeated data. + + Returns: + dict: A new dict with repeated data. + """ + if repeat_times == 1: + return batch_dict + + repeated_batch_dict = {} + if batch_dict: + if interleave: + # Interleave the data + for key, val in batch_dict.items(): + if isinstance(val, torch.Tensor): + repeated_batch_dict[key] = val.repeat_interleave(repeat_times, dim=0) + elif isinstance(val, np.ndarray): + repeated_batch_dict[key] = np.repeat(val, repeat_times, axis=0) + else: + raise ValueError(f"Unsupported type in data {type(val)}") + else: + # Stack the data + for key, val in batch_dict.items(): + if isinstance(val, torch.Tensor): + repeated_batch_dict[key] = val.unsqueeze(0).expand(repeat_times, *val.shape).reshape(-1, + *val.shape[ + 1:]) + elif isinstance(val, np.ndarray): + repeated_batch_dict[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1)) + else: + raise ValueError(f"Unsupported type in data {type(val)}") + return repeated_batch_dict + + @classmethod + def dict_to_tensordict(cls, data: dict[str, torch.Tensor | np.ndarray]) -> TensorDict: + """Create a TensorDict from a dict of tensors and non_tensors. + Note that this requires tensordict version at least 0.10 + """ + assert parse_version(tensordict.__version__) >= parse_version("0.10"), ( + "Storing non-tensor data in TensorDict at least requires tensordict version 0.10" + ) + tensors_batch = {} + batch_size = None + + for key, val in data.items(): + if isinstance(val, (torch.Tensor, np.ndarray)): + tensors_batch[key] = val + else: + raise ValueError(f"Unsupported type in data {type(val)}") + + if batch_size is None: + batch_size = len(val) + else: + assert len(val) == batch_size + + if batch_size is None: + batch_size = [] + else: + batch_size = [batch_size] + + return TensorDict(tensors_batch, batch_size=batch_size) def fit(self): """ @@ -982,74 +1127,110 @@ def fit(self): if self.config.global_profiler.profile_continuous_steps else curr_step_profile ) - batch: DataProto = DataProto.from_single_dict(batch_dict) # add uid to batch - batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + batch_dict["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch_dict["input_ids"]))], dtype=object + ) + # When n > 1, repeat input data before putting to data system, simulating DataProto repeat. + repeated_batch_dict = self.repeat_dict( + batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict) + asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1)) + gen_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=["input_ids", "attention_mask", "position_ids", "index", + "tools_kwargs", "interaction_kwargs", "ability", "raw_prompt_ids"], + batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, + global_step=self.global_steps - 1, # self.global_steps start from 1 + get_n_samples=False, + task_name="generate_sequences", + ) ) - - gen_batch = self._get_gen_batch(batch) - # pass global_steps to trace - gen_batch.meta_info["global_steps"] = self.global_steps - gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + gen_meta.set_extra_info("global_steps", self.global_steps) is_last_step = self.global_steps >= self.total_training_steps - with marked_timer("step", timing_raw): + + with ((marked_timer("step", timing_raw))): # generate a batch with marked_timer("gen", timing_raw, color="red"): if not self.async_rollout_mode: - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + gen_output_meta = self.actor_rollout_wg.generate_sequences(gen_meta) else: - gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) - - timing_raw.update(gen_batch_output.meta_info["timing"]) - gen_batch_output.meta_info.pop("timing", None) - - if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - if self.reward_fn is None: - raise ValueError("A reward_fn is required for REMAX advantage estimation.") - - with marked_timer("gen_max", timing_raw, color="purple"): - gen_baseline_batch = deepcopy(gen_batch) - gen_baseline_batch.meta_info["do_sample"] = False - if not self.async_rollout_mode: - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) - else: - gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) - batch = batch.union(gen_baseline_output) - reward_baseline_tensor = self.reward_fn(batch) - reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) - - batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - - batch.batch["reward_baselines"] = reward_baseline_tensor - - del gen_baseline_batch, gen_baseline_output - # repeat to align with repeated responses in rollout - batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - batch = batch.union(gen_batch_output) + gen_output_meta = self.async_rollout_manager.generate_sequences(gen_meta) + timing_raw.update(gen_output_meta.extra_info["timing"]) + gen_output_meta.extra_info.pop("timing", None) + + # TODO: (transferqueue) + # if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + # if self.reward_fn is None: + # raise ValueError("A reward_fn is required for REMAX advantage estimation.") + # + # with marked_timer("gen_max", timing_raw, color="purple"): + # gen_baseline_meta = deepcopy(gen_meta) + # gen_baseline_meta.extra_info["do_sample"] = False + # if not self.async_rollout_mode: + # gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_meta) + # else: + # gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_meta) + # batch = batch.union(gen_baseline_output) + # reward_baseline_tensor = self.reward_fn(batch) + # reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + # + # batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + # + # batch.batch["reward_baselines"] = reward_baseline_tensor + # + # del gen_baseline_batch, gen_baseline_output + + batch_meta = gen_meta.union(gen_meta_output) + + if "response_mask" not in batch_meta.field_names: + response_mask_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=["responses", "attention_mask"], + batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, + global_step=self.global_steps - 1, + get_n_samples=False, + task_name="compute_response_mask", + ) + ) + response_mask_output_meta = compute_response_mask(response_mask_meta, self.data_system_client) + batch_meta = batch_meta.union(response_mask_output_meta) - if "response_mask" not in batch.batch.keys(): - batch.batch["response_mask"] = compute_response_mask(batch) # Balance the number of valid tokens across DP ranks. # NOTE: This usually changes the order of data in the `batch`, # which won't affect the advantage calculation (since it's based on uid), # but might affect the loss calculation (due to the change of mini-batching). # TODO: Decouple the DP balancing and mini-batching. + balanced_idx = None if self.config.trainer.balance_batch: - self._balance_batch(batch, metrics=metrics) + attention_mask_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=["attention_mask"], + batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, + global_step=self.global_steps - 1, + get_n_samples=False, + task_name="balance_batch", + ) + ) + + balanced_idx = self._balance_batch(attention_mask_meta, self.data_system_client, + metrics=metrics) + batch_meta.reorder(balanced_idx) # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + data = asyncio.run(self.data_system_client.async_get_data(attention_mask_meta)) + batch_meta.extra_info["global_token_num"] = torch.sum(data["attention_mask"], dim=-1).tolist() with marked_timer("reward", timing_raw, color="yellow"): # compute reward model score - if self.use_rm and "rm_scores" not in batch.batch.keys(): - reward_tensor = self.rm_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor) + if self.use_rm and "rm_scores" not in batch_meta.field_names: + reward_meta = self.rm_wg.compute_rm_score(batch_meta) + batch_meta = batch_meta.union(reward_meta) + # TODO: (transferqueue) (zjj) if self.config.reward_model.launch_reward_fn_async: future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn) else: @@ -1057,61 +1238,120 @@ def fit(self): # recompute old_log_probs with marked_timer("old_log_prob", timing_raw, color="blue"): - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - entropys = old_log_prob.batch["entropys"] - response_masks = batch.batch["response_mask"] + old_log_prob_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=["input_ids", "attention_mask", "position_ids", "response", + "response_mask", "data_source", "reward_model", "extra_info", + "uid", "index", "tools_kwargs", "interaction_kwargs", "ability"], + batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, + global_step=self.global_steps - 1, + get_n_samples=False, + task_name="compute_log_prob", + ) + ) + old_log_prob_meta.reorder(balanced_idx) + + old_log_prob_output_meta = self.actor_rollout_wg.compute_log_prob(old_log_prob_meta) + data = asyncio.run(self.data_system_client.async_get_data(old_log_prob_output_meta)) + entropys = data["entropys"] + response_masks = data["response_mask"] loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} metrics.update(old_log_prob_metrics) - old_log_prob.batch.pop("entropys") - batch = batch.union(old_log_prob) - if "rollout_log_probs" in batch.batch.keys(): + batch_meta = batch_meta.union(old_log_prob_output_meta) + + if "rollout_log_probs" in batch_meta.field_names: # TODO: we may want to add diff of probs too. from verl.utils.debug.metrics import calculate_debug_metrics - + # # TODO: (transferqueue) batch_meta(BatchMeta) -> batch(DataProto) metrics.update(calculate_debug_metrics(batch)) if self.use_reference_policy: # compute reference log_prob + ref_log_prob_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=["input_ids", "attention_mask", "position_ids", "response", "response_mask", + "old_log_probs", "data_source", "reward_model", "extra_info", + "uid", "index", "tools_kwargs", "interaction_kwargs", "ability"], + batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, + global_step=self.global_steps - 1, + get_n_samples=False, + task_name="compute_ref_log_prob", + ) + ) + ref_log_prob_meta.reorder(balanced_idx) with marked_timer("ref", timing_raw, color="olive"): if not self.ref_in_actor: - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + ref_log_prob_output_meta = self.ref_policy_wg.compute_ref_log_prob(ref_log_prob_meta) else: - ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) + ref_log_prob_output_meta = self.actor_rollout_wg.compute_ref_log_prob(ref_log_prob_meta) + batch_meta = batch_meta.union(ref_log_prob_output_meta) # compute values if self.use_critic: with marked_timer("values", timing_raw, color="cyan"): - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) + values_meta = self.critic_wg.compute_values(batch_meta) + batch_meta = batch_meta.union(values_meta) with marked_timer("adv", timing_raw, color="brown"): # we combine with rule-based rm reward_extra_infos_dict: dict[str, list] if self.config.reward_model.launch_reward_fn_async: reward_tensor, reward_extra_infos_dict = ray.get(future_reward) - batch.batch["token_level_scores"] = reward_tensor + reward_td = TensorDict( + {"token_level_scores": reward_tensor}, batch_size=reward_tensor.size(0) + ) + asyncio.run(self.data_system_client.async_put(data=reward_td, metadata=batch_meta)) + batch_meta.add_fields(reward_td) if reward_extra_infos_dict: - batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + reward_extra_infos_dict_new = {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + reward_extra_infos_td = self.dict_to_tensordict(reward_extra_infos_dict_new) + asyncio.run( + self.data_system_client.async_put(data=reward_extra_infos_td, metadata=batch_meta) + ) + batch_meta.add_fields(reward_extra_infos_td) # compute rewards. apply_kl_penalty if available + # TODO: (transferqueue) (zjj) if self.config.algorithm.use_kl_in_reward: batch, kl_metrics = apply_kl_penalty( batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty ) metrics.update(kl_metrics) else: - batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + token_level_scores_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=["token_level_scores"], + batch_size=self.config.data.train_batch_size * + self.config.actor_rollout_ref.rollout.n, + global_step=self.global_steps - 1, + get_n_samples=False, + task_name="token_level_scores", + ) + ) + token_level_scores_meta.reorder(balanced_idx) + data = asyncio.run(self.data_system_client.async_get_data(token_level_scores_meta)) + token_level_rewards_td = TensorDict( + {"token_level_rewards": data["token_level_scores"]}, + batch_size=data["token_level_scores"].size(0) + ) + asyncio.run( + self.data_system_client.async_put( + data=token_level_rewards_td, metadata=token_level_scores_meta + ) + ) + batch_meta.add_fields(token_level_rewards_td) # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get( "norm_adv_by_std_in_grpo", True ) # GRPO adv normalization factor + # TODO: (transferqueue) (zjj) batch = compute_advantage( batch, adv_estimator=self.config.algorithm.adv_estimator, @@ -1125,23 +1365,56 @@ def fit(self): # update critic if self.use_critic: with marked_timer("update_critic", timing_raw, color="pink"): - critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + critic_output_meta = self.critic_wg.update_critic(batch_meta) + batch_meta.union(critic_output_meta) + critic_output_metrics = reduce_metrics(critic_output_meta.extra_info["metrics"]) metrics.update(critic_output_metrics) # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: # update actor with marked_timer("update_actor", timing_raw, color="red"): - batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable - actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + batch_meta.extra_info[ + "multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + + update_actor_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=["input_ids", "attention_mask", "position_ids", "response", + "response_mask", "old_log_probs", "ref_log_prob", "advantages", + "returns", "token_level_rewards", "token_level_scores", "data_source", + "reward_model", "extra_info", "uid", "index", "tools_kwargs", + "interaction_kwargs", "ability"], + batch_size=self.config.data.train_batch_size * + self.config.actor_rollout_ref.rollout.n, + global_step=self.global_steps - 1, + get_n_samples=False, + task_name="update_actor", + ) + ) + update_actor_meta.reorder(balanced_idx) + + actor_output_meta = self.actor_rollout_wg.update_actor(update_actor_meta) + batch_meta = batch_meta.union(actor_output_meta) + actor_output_metrics = reduce_metrics(actor_output_meta.extra_info["metrics"]) metrics.update(actor_output_metrics) # Log rollout generations if enabled rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) if rollout_data_dir: - self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) + data_fields = ["prompts", "responses", "token_level_scores", "reward_model"] + if "request_id" in batch_meta.field_names: data_fields.append("request_id") + log_rollout_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=data_fields, + batch_size=self.config.data.train_batch_size * + self.config.actor_rollout_ref.rollout.n, + global_step=self.global_steps - 1, + get_n_samples=False, + task_name="log_rollout", + ) + ) + log_rollout_meta.reorder(balanced_idx) + self._log_rollout_data(log_rollout_meta, reward_extra_infos_dict, timing_raw, rollout_data_dir) # validate if ( @@ -1150,6 +1423,7 @@ def fit(self): and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) ): with marked_timer("testing", timing_raw, color="green"): + # TODO: (transferqueue) _validate(DataProto->BatchMeta) val_metrics: dict = self._validate() if is_last_step: last_val_metrics = val_metrics @@ -1200,6 +1474,7 @@ def fit(self): } ) # collect metrics + # TODO: (transferqueue) batch_meta(BatchMeta) - > batch(DataProto) metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) # TODO: implement actual tflpo and theoretical tflpo diff --git a/recipe/transfer_queue/run_qwen3-8b_transferqueue_npu.sh b/recipe/transfer_queue/run_qwen3-8b_transferqueue_npu.sh index 1d3a190b3f2..70b7e23976d 100644 --- a/recipe/transfer_queue/run_qwen3-8b_transferqueue_npu.sh +++ b/recipe/transfer_queue/run_qwen3-8b_transferqueue_npu.sh @@ -9,7 +9,8 @@ CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} -python3 -m verl.trainer.main_ppo \ +python3 -m recipe.transfer_queue.main_ppo \ + --config-name='transfer_queue_ppo_trainer' \ algorithm.adv_estimator=grpo \ data.train_files="${TRAIN_FILE}" \ data.val_files="${TEST_FILE}" \ @@ -53,7 +54,10 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ ++actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \ ++actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ - trainer.val_before_train=True \ + trainer.val_before_train=False \ trainer.save_freq=5 \ trainer.test_freq=5 \ - trainer.total_epochs=15 \ No newline at end of file + trainer.total_epochs=15 \ + +trainer.num_global_batch=1 \ + +trainer.num_data_storage_units=2 \ + +trainer.num_data_controllers=1 \ No newline at end of file diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py index b246aaf7eec..e87547cd142 100644 --- a/verl/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -70,15 +70,16 @@ def init_predefined_execute_mode(): def _split_args_kwargs_data_proto(chunks, *args, **kwargs): from verl.protocol import DataProto, DataProtoFuture + from verl.experimental.transfer_queue.metadata import BatchMeta splitted_args = [] for arg in args: - assert isinstance(arg, DataProto | DataProtoFuture) + assert isinstance(arg, DataProto | DataProtoFuture | BatchMeta) splitted_args.append(arg.chunk(chunks=chunks)) splitted_kwargs = {} for key, val in kwargs.items(): - assert isinstance(val, DataProto | DataProtoFuture) + assert isinstance(val, DataProto | DataProtoFuture | BatchMeta) splitted_kwargs[key] = val.chunk(chunks=chunks) return splitted_args, splitted_kwargs @@ -135,6 +136,7 @@ def _concat_data_proto_or_future(output: list): import ray from verl.protocol import DataProto, DataProtoFuture + from verl.experimental.transfer_queue.metadata import BatchMeta # make sure all the elements in output has the same type for o in output: @@ -146,6 +148,8 @@ def _concat_data_proto_or_future(output: list): return DataProto.concat(output) elif isinstance(o, ray.ObjectRef): return DataProtoFuture.concat(output) + elif isinstance(o, BatchMeta): + return BatchMeta.concat(output) else: raise NotImplementedError @@ -263,9 +267,11 @@ def collect_nd_compute_dataproto(collect_mask: list[bool], worker_group, output) import ray from verl.protocol import DataProto + from verl.experimental.transfer_queue.metadata import BatchMeta for o in output: - assert isinstance(o, DataProto | ray.ObjectRef), f"expecting {o} to be DataProto, but got {type(o)}" + assert isinstance(o, DataProto | ray.ObjectRef | BatchMeta), \ + f"expecting {o} to be DataProto or BatchMeta, but got {type(o)}" return _concat_data_proto_or_future(output) From b35596f105925d3c7518a4d01f41ff5e34888c53 Mon Sep 17 00:00:00 2001 From: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> Date: Sun, 28 Sep 2025 18:50:10 +0800 Subject: [PATCH 3/8] Fix codecheck --- recipe/transfer_queue/ray_trainer.py | 133 +++++++++++++++-------- verl/single_controller/base/decorator.py | 9 +- 2 files changed, 94 insertions(+), 48 deletions(-) diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index bbf6db748f1..16edb5c3846 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -20,22 +20,21 @@ import asyncio import json -import math import logging +import math import os import uuid from collections import defaultdict -from copy import deepcopy from dataclasses import dataclass, field -from packaging.version import parse as parse_version from pprint import pprint from typing import Optional import numpy as np import ray +import tensordict import torch from omegaconf import OmegaConf, open_dict -import tensordict +from packaging.version import parse as parse_version from tensordict import TensorDict from torch.utils.data import Dataset, Sampler from torchdata.stateful_dataloader import StatefulDataLoader @@ -793,8 +792,8 @@ def init_workers(self): # Only require nsight worker options when tool is nsys if OmegaConf.select(self.config.global_profiler, "tool") == "nsys": assert ( - OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") - is not None + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + is not None ), "worker_nsight_options must be set when using nsys with profile_steps" wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") @@ -995,8 +994,9 @@ def _balance_batch(self, batch: BatchMeta, data_system_client, metrics, logging_ return global_idx @classmethod - def repeat_dict(cls, batch_dict: dict[str, torch.Tensor | np.ndarray], repeat_times=2, interleave=True) \ - -> dict[str, torch.Tensor | np.ndarray]: + def repeat_dict( + cls, batch_dict: dict[str, torch.Tensor | np.ndarray], repeat_times=2, interleave=True + ) -> dict[str, torch.Tensor | np.ndarray]: """ Repeat the batch dict a specified number of times. @@ -1025,9 +1025,9 @@ def repeat_dict(cls, batch_dict: dict[str, torch.Tensor | np.ndarray], repeat_ti # Stack the data for key, val in batch_dict.items(): if isinstance(val, torch.Tensor): - repeated_batch_dict[key] = val.unsqueeze(0).expand(repeat_times, *val.shape).reshape(-1, - *val.shape[ - 1:]) + repeated_batch_dict[key] = ( + val.unsqueeze(0).expand(repeat_times, *val.shape).reshape(-1, *val.shape[1:]) + ) elif isinstance(val, np.ndarray): repeated_batch_dict[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1)) else: @@ -1036,8 +1036,9 @@ def repeat_dict(cls, batch_dict: dict[str, torch.Tensor | np.ndarray], repeat_ti @classmethod def dict_to_tensordict(cls, data: dict[str, torch.Tensor | np.ndarray]) -> TensorDict: - """Create a TensorDict from a dict of tensors and non_tensors. - Note that this requires tensordict version at least 0.10 + """ + Create a TensorDict from a dict of tensors and non_tensors. + Note that this requires tensordict version at least 0.10 """ assert parse_version(tensordict.__version__) >= parse_version("0.10"), ( "Storing non-tensor data in TensorDict at least requires tensordict version 0.10" @@ -1046,7 +1047,7 @@ def dict_to_tensordict(cls, data: dict[str, torch.Tensor | np.ndarray]) -> Tenso batch_size = None for key, val in data.items(): - if isinstance(val, (torch.Tensor, np.ndarray)): + if isinstance(val, torch.Tensor | np.ndarray): tensors_batch[key] = val else: raise ValueError(f"Unsupported type in data {type(val)}") @@ -1134,13 +1135,20 @@ def fit(self): ) # When n > 1, repeat input data before putting to data system, simulating DataProto repeat. repeated_batch_dict = self.repeat_dict( - batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict) asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1)) gen_meta = asyncio.run( self.data_system_client.async_get_meta( - data_fields=["input_ids", "attention_mask", "position_ids", "index", - "tools_kwargs", "interaction_kwargs", "ability", "raw_prompt_ids"], + data_fields=["input_ids", + "attention_mask", + "position_ids", + "index", + "tools_kwargs", + "interaction_kwargs", + "ability", + "raw_prompt_ids"], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, # self.global_steps start from 1 get_n_samples=False, @@ -1152,7 +1160,7 @@ def fit(self): is_last_step = self.global_steps >= self.total_training_steps - with ((marked_timer("step", timing_raw))): + with marked_timer("step", timing_raw): # generate a batch with marked_timer("gen", timing_raw, color="red"): if not self.async_rollout_mode: @@ -1184,7 +1192,7 @@ def fit(self): # # del gen_baseline_batch, gen_baseline_output - batch_meta = gen_meta.union(gen_meta_output) + batch_meta = gen_meta.union(gen_output_meta) if "response_mask" not in batch_meta.field_names: response_mask_meta = asyncio.run( @@ -1216,8 +1224,9 @@ def fit(self): ) ) - balanced_idx = self._balance_batch(attention_mask_meta, self.data_system_client, - metrics=metrics) + balanced_idx = self._balance_batch( + attention_mask_meta, self.data_system_client,bmetrics=metrics + ) batch_meta.reorder(balanced_idx) # compute global_valid tokens @@ -1240,9 +1249,19 @@ def fit(self): with marked_timer("old_log_prob", timing_raw, color="blue"): old_log_prob_meta = asyncio.run( self.data_system_client.async_get_meta( - data_fields=["input_ids", "attention_mask", "position_ids", "response", - "response_mask", "data_source", "reward_model", "extra_info", - "uid", "index", "tools_kwargs", "interaction_kwargs", "ability"], + data_fields=["input_ids", + "attention_mask", + "position_ids", + "response", + "response_mask", + "data_source", + "reward_model", + "extra_info", + "uid", + "index", + "tools_kwargs", + "interaction_kwargs", + "ability"], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, get_n_samples=False, @@ -1272,9 +1291,20 @@ def fit(self): # compute reference log_prob ref_log_prob_meta = asyncio.run( self.data_system_client.async_get_meta( - data_fields=["input_ids", "attention_mask", "position_ids", "response", "response_mask", - "old_log_probs", "data_source", "reward_model", "extra_info", - "uid", "index", "tools_kwargs", "interaction_kwargs", "ability"], + data_fields=["input_ids", + "attention_mask", + "position_ids", + "response", + "response_mask", + "old_log_probs", + "data_source", + "reward_model", + "extra_info", + "uid", + "index", + "tools_kwargs", + "interaction_kwargs", + "ability"], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, get_n_samples=False, @@ -1300,9 +1330,8 @@ def fit(self): reward_extra_infos_dict: dict[str, list] if self.config.reward_model.launch_reward_fn_async: reward_tensor, reward_extra_infos_dict = ray.get(future_reward) - reward_td = TensorDict( - {"token_level_scores": reward_tensor}, batch_size=reward_tensor.size(0) - ) + reward_td = TensorDict({"token_level_scores": reward_tensor}, + batch_size=reward_tensor.size(0)) asyncio.run(self.data_system_client.async_put(data=reward_td, metadata=batch_meta)) batch_meta.add_fields(reward_td) @@ -1325,8 +1354,8 @@ def fit(self): token_level_scores_meta = asyncio.run( self.data_system_client.async_get_meta( data_fields=["token_level_scores"], - batch_size=self.config.data.train_batch_size * - self.config.actor_rollout_ref.rollout.n, + batch_size=self.config.data.train_batch_size + * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, get_n_samples=False, task_name="token_level_scores", @@ -1336,7 +1365,7 @@ def fit(self): data = asyncio.run(self.data_system_client.async_get_data(token_level_scores_meta)) token_level_rewards_td = TensorDict( {"token_level_rewards": data["token_level_scores"]}, - batch_size=data["token_level_scores"].size(0) + batch_size=data["token_level_scores"].size(0), ) asyncio.run( self.data_system_client.async_put( @@ -1375,17 +1404,33 @@ def fit(self): # update actor with marked_timer("update_actor", timing_raw, color="red"): batch_meta.extra_info[ - "multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + "multi_turn"] = ( + self.config.actor_rollout_ref.rollout.multi_turn.enable + ) update_actor_meta = asyncio.run( self.data_system_client.async_get_meta( - data_fields=["input_ids", "attention_mask", "position_ids", "response", - "response_mask", "old_log_probs", "ref_log_prob", "advantages", - "returns", "token_level_rewards", "token_level_scores", "data_source", - "reward_model", "extra_info", "uid", "index", "tools_kwargs", - "interaction_kwargs", "ability"], - batch_size=self.config.data.train_batch_size * - self.config.actor_rollout_ref.rollout.n, + data_fields=["input_ids", + "attention_mask", + "position_ids", + "response", + "response_mask", + "old_log_probs", + "ref_log_prob", + "advantages", + "returns", + "token_level_rewards", + "token_level_scores", + "data_source", + "reward_model", + "extra_info", + "uid", + "index", + "tools_kwargs", + "interaction_kwargs", + "ability"], + batch_size=self.config.data.train_batch_size + * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, get_n_samples=False, task_name="update_actor", @@ -1402,12 +1447,12 @@ def fit(self): rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) if rollout_data_dir: data_fields = ["prompts", "responses", "token_level_scores", "reward_model"] - if "request_id" in batch_meta.field_names: data_fields.append("request_id") + if "request_id" in batch_meta.field_names: + data_fields.append("request_id") log_rollout_meta = asyncio.run( self.data_system_client.async_get_meta( data_fields=data_fields, - batch_size=self.config.data.train_batch_size * - self.config.actor_rollout_ref.rollout.n, + batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, get_n_samples=False, task_name="log_rollout", diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py index e87547cd142..8d195f92209 100644 --- a/verl/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -69,8 +69,8 @@ def init_predefined_execute_mode(): def _split_args_kwargs_data_proto(chunks, *args, **kwargs): - from verl.protocol import DataProto, DataProtoFuture from verl.experimental.transfer_queue.metadata import BatchMeta + from verl.protocol import DataProto, DataProtoFuture splitted_args = [] for arg in args: @@ -135,8 +135,8 @@ def collect_all_to_all(worker_group, output): def _concat_data_proto_or_future(output: list): import ray - from verl.protocol import DataProto, DataProtoFuture from verl.experimental.transfer_queue.metadata import BatchMeta + from verl.protocol import DataProto, DataProtoFuture # make sure all the elements in output has the same type for o in output: @@ -266,12 +266,13 @@ def collect_nd_compute_dataproto(collect_mask: list[bool], worker_group, output) output = collect_nd_compute(collect_mask, worker_group, output) import ray - from verl.protocol import DataProto from verl.experimental.transfer_queue.metadata import BatchMeta + from verl.protocol import DataProto for o in output: - assert isinstance(o, DataProto | ray.ObjectRef | BatchMeta), \ + assert isinstance(o, DataProto | ray.ObjectRef | BatchMeta), ( f"expecting {o} to be DataProto or BatchMeta, but got {type(o)}" + ) return _concat_data_proto_or_future(output) From 305d8dd229bff47414728deb35991809794d0d26 Mon Sep 17 00:00:00 2001 From: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> Date: Sun, 28 Sep 2025 19:05:20 +0800 Subject: [PATCH 4/8] Fix codecheck --- recipe/transfer_queue/ray_trainer.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index 16edb5c3846..863f2ced88e 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -995,7 +995,7 @@ def _balance_batch(self, batch: BatchMeta, data_system_client, metrics, logging_ @classmethod def repeat_dict( - cls, batch_dict: dict[str, torch.Tensor | np.ndarray], repeat_times=2, interleave=True + cls, batch_dict: dict[str, torch.Tensor | np.ndarray], repeat_times=2, interleave=True ) -> dict[str, torch.Tensor | np.ndarray]: """ Repeat the batch dict a specified number of times. @@ -1148,7 +1148,8 @@ def fit(self): "tools_kwargs", "interaction_kwargs", "ability", - "raw_prompt_ids"], + "raw_prompt_ids" + ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, # self.global_steps start from 1 get_n_samples=False, @@ -1225,7 +1226,7 @@ def fit(self): ) balanced_idx = self._balance_batch( - attention_mask_meta, self.data_system_client,bmetrics=metrics + attention_mask_meta, self.data_system_client, metrics=metrics ) batch_meta.reorder(balanced_idx) @@ -1261,7 +1262,8 @@ def fit(self): "index", "tools_kwargs", "interaction_kwargs", - "ability"], + "ability" + ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, get_n_samples=False, @@ -1304,7 +1306,8 @@ def fit(self): "index", "tools_kwargs", "interaction_kwargs", - "ability"], + "ability" + ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, get_n_samples=False, @@ -1355,7 +1358,7 @@ def fit(self): self.data_system_client.async_get_meta( data_fields=["token_level_scores"], batch_size=self.config.data.train_batch_size - * self.config.actor_rollout_ref.rollout.n, + * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, get_n_samples=False, task_name="token_level_scores", @@ -1403,8 +1406,7 @@ def fit(self): if self.config.trainer.critic_warmup <= self.global_steps: # update actor with marked_timer("update_actor", timing_raw, color="red"): - batch_meta.extra_info[ - "multi_turn"] = ( + batch_meta.extra_info["multi_turn"] = ( self.config.actor_rollout_ref.rollout.multi_turn.enable ) @@ -1428,9 +1430,10 @@ def fit(self): "index", "tools_kwargs", "interaction_kwargs", - "ability"], + "ability" + ], batch_size=self.config.data.train_batch_size - * self.config.actor_rollout_ref.rollout.n, + * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, get_n_samples=False, task_name="update_actor", From 9e6bd5058376e226ab492d78bff66c5951736525 Mon Sep 17 00:00:00 2001 From: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> Date: Sun, 28 Sep 2025 19:15:31 +0800 Subject: [PATCH 5/8] Fix codecheck --- recipe/transfer_queue/ray_trainer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index 863f2ced88e..11b3909480d 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -1149,7 +1149,7 @@ def fit(self): "interaction_kwargs", "ability", "raw_prompt_ids" - ], + ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, # self.global_steps start from 1 get_n_samples=False, @@ -1263,7 +1263,7 @@ def fit(self): "tools_kwargs", "interaction_kwargs", "ability" - ], + ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, get_n_samples=False, @@ -1307,7 +1307,7 @@ def fit(self): "tools_kwargs", "interaction_kwargs", "ability" - ], + ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, get_n_samples=False, @@ -1333,8 +1333,9 @@ def fit(self): reward_extra_infos_dict: dict[str, list] if self.config.reward_model.launch_reward_fn_async: reward_tensor, reward_extra_infos_dict = ray.get(future_reward) - reward_td = TensorDict({"token_level_scores": reward_tensor}, - batch_size=reward_tensor.size(0)) + reward_td = TensorDict( + {"token_level_scores": reward_tensor}, batch_size=reward_tensor.size(0) + ) asyncio.run(self.data_system_client.async_put(data=reward_td, metadata=batch_meta)) batch_meta.add_fields(reward_td) @@ -1431,7 +1432,7 @@ def fit(self): "tools_kwargs", "interaction_kwargs", "ability" - ], + ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, From 6017818f5a2aecf87e8cf1c4f5b2b2d22052b39c Mon Sep 17 00:00:00 2001 From: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> Date: Sun, 28 Sep 2025 19:26:17 +0800 Subject: [PATCH 6/8] Fix codecheck --- recipe/transfer_queue/ray_trainer.py | 118 ++++++++++++++------------- 1 file changed, 60 insertions(+), 58 deletions(-) diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index 11b3909480d..aa12f3b5c22 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -1141,14 +1141,15 @@ def fit(self): asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1)) gen_meta = asyncio.run( self.data_system_client.async_get_meta( - data_fields=["input_ids", - "attention_mask", - "position_ids", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability", - "raw_prompt_ids" + data_fields=[ + "input_ids", + "attention_mask", + "position_ids", + "index", + "tools_kwargs", + "interaction_kwargs", + "ability", + "raw_prompt_ids" ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, # self.global_steps start from 1 @@ -1250,19 +1251,20 @@ def fit(self): with marked_timer("old_log_prob", timing_raw, color="blue"): old_log_prob_meta = asyncio.run( self.data_system_client.async_get_meta( - data_fields=["input_ids", - "attention_mask", - "position_ids", - "response", - "response_mask", - "data_source", - "reward_model", - "extra_info", - "uid", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability" + data_fields=[ + "input_ids", + "attention_mask", + "position_ids", + "response", + "response_mask", + "data_source", + "reward_model", + "extra_info", + "uid", + "index", + "tools_kwargs", + "interaction_kwargs", + "ability" ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, @@ -1293,20 +1295,21 @@ def fit(self): # compute reference log_prob ref_log_prob_meta = asyncio.run( self.data_system_client.async_get_meta( - data_fields=["input_ids", - "attention_mask", - "position_ids", - "response", - "response_mask", - "old_log_probs", - "data_source", - "reward_model", - "extra_info", - "uid", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability" + data_fields=[ + "input_ids", + "attention_mask", + "position_ids", + "response", + "response_mask", + "old_log_probs", + "data_source", + "reward_model", + "extra_info", + "uid", + "index", + "tools_kwargs", + "interaction_kwargs", + "ability" ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, @@ -1333,9 +1336,7 @@ def fit(self): reward_extra_infos_dict: dict[str, list] if self.config.reward_model.launch_reward_fn_async: reward_tensor, reward_extra_infos_dict = ray.get(future_reward) - reward_td = TensorDict( - {"token_level_scores": reward_tensor}, batch_size=reward_tensor.size(0) - ) + reward_td = TensorDict({"token_level_scores": reward_tensor}, batch_size=reward_tensor.size(0)) asyncio.run(self.data_system_client.async_put(data=reward_td, metadata=batch_meta)) batch_meta.add_fields(reward_td) @@ -1413,28 +1414,29 @@ def fit(self): update_actor_meta = asyncio.run( self.data_system_client.async_get_meta( - data_fields=["input_ids", - "attention_mask", - "position_ids", - "response", - "response_mask", - "old_log_probs", - "ref_log_prob", - "advantages", - "returns", - "token_level_rewards", - "token_level_scores", - "data_source", - "reward_model", - "extra_info", - "uid", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability" + data_fields=[ + "input_ids", + "attention_mask", + "position_ids", + "response", + "response_mask", + "old_log_probs", + "ref_log_prob", + "advantages", + "returns", + "token_level_rewards", + "token_level_scores", + "data_source", + "reward_model", + "extra_info", + "uid", + "index", + "tools_kwargs", + "interaction_kwargs", + "ability" ], batch_size=self.config.data.train_batch_size - * self.config.actor_rollout_ref.rollout.n, + * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, get_n_samples=False, task_name="update_actor", From 8818c4d08bba21b93aba6caf2c91ccd56b765111 Mon Sep 17 00:00:00 2001 From: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> Date: Sun, 28 Sep 2025 20:21:55 +0800 Subject: [PATCH 7/8] Fix --- recipe/transfer_queue/ray_trainer.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index aa12f3b5c22..efba4c7b98d 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -1149,7 +1149,7 @@ def fit(self): "tools_kwargs", "interaction_kwargs", "ability", - "raw_prompt_ids" + "raw_prompt_ids", ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, # self.global_steps start from 1 @@ -1255,7 +1255,8 @@ def fit(self): "input_ids", "attention_mask", "position_ids", - "response", + "prompts", + "responses", "response_mask", "data_source", "reward_model", @@ -1264,7 +1265,7 @@ def fit(self): "index", "tools_kwargs", "interaction_kwargs", - "ability" + "ability", ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, @@ -1299,7 +1300,8 @@ def fit(self): "input_ids", "attention_mask", "position_ids", - "response", + "prompts", + "responses", "response_mask", "old_log_probs", "data_source", @@ -1309,7 +1311,7 @@ def fit(self): "index", "tools_kwargs", "interaction_kwargs", - "ability" + "ability", ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, @@ -1418,7 +1420,8 @@ def fit(self): "input_ids", "attention_mask", "position_ids", - "response", + "prompts", + "responses", "response_mask", "old_log_probs", "ref_log_prob", @@ -1433,10 +1436,10 @@ def fit(self): "index", "tools_kwargs", "interaction_kwargs", - "ability" + "ability", ], batch_size=self.config.data.train_batch_size - * self.config.actor_rollout_ref.rollout.n, + * self.config.actor_rollout_ref.rollout.n, global_step=self.global_steps - 1, get_n_samples=False, task_name="update_actor", From 5bb745d7734bf1efd4b0a9530ad8184b1e147240 Mon Sep 17 00:00:00 2001 From: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> Date: Sun, 28 Sep 2025 20:31:11 +0800 Subject: [PATCH 8/8] Fix codecheck --- recipe/transfer_queue/ray_trainer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index efba4c7b98d..d19b3d5b3d1 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -515,7 +515,9 @@ def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dic def _log_rollout_data( self, log_rollout_meta: BatchMeta, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str ): - """Log rollout data to disk. + """ + Log rollout data to disk. + Args: log_rollout_meta (BatchMeta): The batch_meta of rollout data reward_extra_infos_dict (dict): Additional reward information to log @@ -523,7 +525,6 @@ def _log_rollout_data( rollout_data_dir (str): Directory path to save the rollout data """ with marked_timer("dump_rollout_generations", timing_raw, color="green"): - data = asyncio.run(self.data_system_client.async_get_data(log_rollout_meta)) inputs = self.tokenizer.batch_decode(data["prompts"], skip_special_tokens=True) @@ -1289,7 +1290,8 @@ def fit(self): if "rollout_log_probs" in batch_meta.field_names: # TODO: we may want to add diff of probs too. from verl.utils.debug.metrics import calculate_debug_metrics - # # TODO: (transferqueue) batch_meta(BatchMeta) -> batch(DataProto) + + # TODO: (transferqueue) batch_meta(BatchMeta) -> batch(DataProto) metrics.update(calculate_debug_metrics(batch)) if self.use_reference_policy: