diff --git a/.gitignore b/.gitignore
index d77a5b43ffc..b54bf5f02f1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+/ckpts_hdd
+cx_run.sh
+cx_run_noRoutingReplay.sh
 **/*.pt
 **/checkpoints
diff --git a/recipe/moe_routing_replay/config/dapo_trainer.yaml b/recipe/moe_routing_replay/config/dapo_trainer.yaml
new file mode 100644
index 00000000000..c901f3882cd
--- /dev/null
+++ b/recipe/moe_routing_replay/config/dapo_trainer.yaml
@@ -0,0 +1,32 @@
+hydra:
+  searchpath:
+    - file://verl/trainer/config
+
+defaults:
+  - ppo_trainer
+  - _self_
+
+actor_rollout_ref:
+  rollout:
+    enable_routing_replay: True
+
+data:
+  gen_batch_size: ${data.train_batch_size}
+
+reward_model:
+  reward_manager: dapo
+  overlong_buffer:
+    enable: False # set explicitly so that enabling it is never forgotten
+    len: 0
+    penalty_factor: 0.0
+    log: False
+
+algorithm:
+  filter_groups:
+    _target_: verl.trainer.config.FilterGroupsConfig
+    enable: False # set explicitly so that enabling it is never forgotten
+    metric: null # acc / score / seq_reward / seq_final_reward / ...
+    max_num_gen_batches: 0 # Non-positive values mean no upper limit
+
+trainer:
+  project_name: verl-dapo
diff --git a/recipe/moe_routing_replay/dapo_ray_trainer.py b/recipe/moe_routing_replay/dapo_ray_trainer.py
new file mode 100644
index 00000000000..9123dc243a1
--- /dev/null
+++ b/recipe/moe_routing_replay/dapo_ray_trainer.py
@@ -0,0 +1,429 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+FSDP PPO Trainer with Ray-based single controller.
+This trainer supports model-agnostic model initialization with HuggingFace.
+"""
+
+import os
+import uuid
+from collections import defaultdict
+from copy import deepcopy
+from pprint import pprint
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from verl import DataProto
+from verl.trainer.ppo.core_algos import agg_loss
+from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics
+from verl.trainer.ppo.ray_trainer import (
+    AdvantageEstimator,
+    RayPPOTrainer,
+    apply_kl_penalty,
+    compute_advantage,
+    compute_response_mask,
+)
+from verl.trainer.ppo.reward import compute_reward
+from verl.utils.metric import reduce_metrics
+from verl.utils.profiler import marked_timer
+from verl.utils.rollout_skip import RolloutSkip
+
+
+class RayDAPOTrainer(RayPPOTrainer):
+    """
+    Note that this trainer runs on the driver process on a single CPU/GPU node.
+ """ + + def compute_kl_related_metrics(self, batch: DataProto, metrics: dict, timing_raw: dict): + batch.batch["response_mask"] = compute_response_mask(batch) + + # recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, "blue"): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + from verl.utils.debug.metrics import calculate_debug_metrics + metrics.update(calculate_debug_metrics(batch)) + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer("ref", timing_raw, "olive"): + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + return batch + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + self.gen_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. 
+
+    def fit(self):
+        """
+        The training loop of PPO.
+        The driver process only needs to call the compute functions of the worker group through RPC
+        to construct the PPO dataflow.
+        The lightweight advantage computation is done on the driver process.
+        """
+        from omegaconf import OmegaConf
+
+        from verl.utils.tracking import Tracking
+
+        logger = Tracking(
+            project_name=self.config.trainer.project_name,
+            experiment_name=self.config.trainer.experiment_name,
+            default_backend=self.config.trainer.logger,
+            config=OmegaConf.to_container(self.config, resolve=True),
+        )
+
+        self.global_steps = 0
+        self.gen_steps = 0
+
+        # load checkpoint before doing anything
+        self._load_checkpoint()
+
+        # perform validation before training
+        # currently, we only support validation using the reward_function.
+        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
+            val_metrics = self._validate()
+            assert val_metrics, f"{val_metrics=}"
+            pprint(f"Initial validation metrics: {val_metrics}")
+            logger.log(data=val_metrics, step=self.global_steps)
+            if self.config.trainer.get("val_only", False):
+                return
+
+        if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
+            rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
+            rollout_skip.wrap_generate_sequences()
+
+        # add tqdm
+        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
+
+        # we start from step 1
+        self.global_steps += 1
+        self.gen_steps += 1
+        last_val_metrics = None
+
+        prev_step_profile = False
+        curr_step_profile = (
+            self.global_steps in self.config.global_profiler.steps
+            if self.config.global_profiler.steps is not None
+            else False
+        )
+        next_step_profile = False
+
+        timing_raw = defaultdict(float)
+        batch = None
+        num_prompt_in_batch = 0
+        num_gen_batches = 0
+        for epoch in range(self.config.trainer.total_epochs):
+            for batch_dict in self.train_dataloader:
+                metrics = {}
+
+                with marked_timer("start_profile", timing_raw):
+                    self._start_profiling(
+                        not prev_step_profile and curr_step_profile
+                        if self.config.global_profiler.profile_continuous_steps
+                        else curr_step_profile
+                    )
+
+                new_batch: DataProto = DataProto.from_single_dict(batch_dict)
+                num_gen_batches += 1
+                # pop those keys for generation
+                if "multi_modal_data" in new_batch.non_tensor_batch.keys():
+                    gen_batch = new_batch.pop(
+                        batch_keys=["input_ids", "attention_mask", "position_ids"],
+                        non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
+                    )
+                else:
+                    gen_batch = new_batch.pop(
+                        batch_keys=["input_ids", "attention_mask", "position_ids"],
+                        non_tensor_batch_keys=["raw_prompt_ids"],
+                    )
+                gen_batch_output = gen_batch.repeat(
+                    repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
+                )
+
+                is_last_step = self.global_steps >= self.total_training_steps
+
+                with marked_timer("step", timing_raw):
+                    # generate a batch
+                    with marked_timer("gen", timing_raw, "red"):
+                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
+                        timing_raw.update(gen_batch_output.meta_info["timing"])
+                        gen_batch_output.meta_info.pop("timing", None)
+
+                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
+                        with marked_timer("gen_max", timing_raw, "red"):
+                            gen_baseline_batch = deepcopy(gen_batch)
+                            gen_baseline_batch.meta_info["do_sample"] = False
+                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
+
+                            new_batch = new_batch.union(gen_baseline_output)
+                            # compute reward model score on new_batch
+                            rm_scores = None
+                            if self.use_rm and "rm_scores" not in new_batch.batch.keys():
+                                rm_scores = self.rm_wg.compute_rm_score(new_batch)
+                                new_batch = new_batch.union(rm_scores)
+                            reward_baseline_tensor, _ = compute_reward(new_batch, self.reward_fn)
+                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
+
+                            keys_to_pop = set(gen_baseline_output.batch.keys())
+                            if rm_scores is not None:
+                                keys_to_pop.update(rm_scores.batch.keys())
+                            new_batch.pop(batch_keys=list(keys_to_pop))
+
+                            new_batch.batch["reward_baselines"] = reward_baseline_tensor
+
+                            del rm_scores, gen_baseline_batch, gen_baseline_output
+
+                    new_batch.non_tensor_batch["uid"] = np.array(
+                        [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object
+                    )
+                    # repeat to align with repeated responses in rollout
+                    new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+                    new_batch = new_batch.union(gen_batch_output)
+
+                    if self.config.algorithm.use_kl_in_reward:
+                        # We need these metrics for apply_kl_penalty if using kl in reward
+                        new_batch = self.compute_kl_related_metrics(new_batch, metrics, timing_raw)
+                        # otherwise, we will compute those after dynamic sampling
+
+                    with marked_timer("reward", timing_raw, "yellow"):
+                        # compute scores. Support both model- and function-based rewards.
+                        # We first compute the scores using the reward model, then call reward_fn to combine
+                        # the reward model results with the rule-based results.
+                        if self.use_rm and "rm_scores" not in new_batch.batch.keys():
+                            # we first compute reward model score
+                            reward_tensor = self.rm_wg.compute_rm_score(new_batch)
+                            new_batch = new_batch.union(reward_tensor)
+
+                        # we combine with rule-based rm
+                        reward_tensor, reward_extra_infos_dict = compute_reward(new_batch, self.reward_fn)
+
+                        new_batch.batch["token_level_scores"] = reward_tensor
+
+                        if reward_extra_infos_dict:
+                            new_batch.non_tensor_batch.update(
+                                {k: np.array(v) for k, v in reward_extra_infos_dict.items()}
+                            )
+
+                        # compute rewards. apply_kl_penalty if available
+                        if self.config.algorithm.use_kl_in_reward:
+                            new_batch, kl_metrics = apply_kl_penalty(
+                                new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
+                            )
+                            metrics.update(
+                                kl_metrics
+                            )  # TODO: This will be cleared if we use multiple generation batches
+                        else:
+                            new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]
+
+                    if not self.config.algorithm.filter_groups.enable:
+                        batch = new_batch
+                    else:  # NOTE: When fewer prompts than the train batch size remain after filtering,
+                        # we skip to the next generation batch
+                        metric_name = self.config.algorithm.filter_groups.metric
+                        if metric_name == "seq_final_reward":
+                            # Turn to numpy for easier filtering
+                            new_batch.non_tensor_batch["seq_final_reward"] = (
+                                new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
+                            )
+                        elif metric_name == "seq_reward":
+                            new_batch.non_tensor_batch["seq_reward"] = (
+                                new_batch.batch["token_level_scores"].sum(dim=-1).numpy()
+                            )
+
+                        # Collect the sequence reward for each trajectory
+                        prompt_uid2metric_vals = defaultdict(list)
+                        for uid, metric_val in zip(
+                            new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True
+                        ):
+                            prompt_uid2metric_vals[uid].append(metric_val)
+
+                        prompt_uid2metric_std = {}
+                        for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
+                            prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)
+
+                        kept_prompt_uids = [
+                            uid
+                            for uid, std in prompt_uid2metric_std.items()
+                            if std > 0 or len(prompt_uid2metric_vals[uid]) == 1
+                        ]
+                        num_prompt_in_batch += len(kept_prompt_uids)
+
+                        kept_traj_idxs = []
+                        for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
+                            if traj_from_prompt_uid in kept_prompt_uids:
+                                kept_traj_idxs.append(idx)
+
+                        new_batch = new_batch[kept_traj_idxs]
+                        batch = new_batch if batch is None else DataProto.concat([batch, new_batch])
Keep generating...") + self.gen_steps += 1 + is_last_step = self.global_steps >= self.total_training_steps + continue + else: + raise ValueError( + f"{num_gen_batches=} >= {max_num_gen_batches=}." + + " Generated too many. Please check if your data are too difficult." + + " You could also try set max_num_gen_batches=0 to enable endless trials." + ) + else: + # Align the batch + traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n + batch = batch[:traj_bsz] + + # === Updating === + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + if not self.config.algorithm.use_kl_in_reward: + batch = self.compute_kl_related_metrics(batch, metrics, timing_raw) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, "cyan"): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + # Compute rollout correction weights and off-policy metrics (inherited from RayPPOTrainer) + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) + + with marked_timer("adv", timing_raw, "brown"): + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, "pink"): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, "red"): + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, "green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + if 
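+
+                    # Editor's note (illustration, not part of the original patch): with
+                    # n_resp_per_prompt=8, a prompt whose eight rollouts all receive the
+                    # same reward has zero reward std and thus no GRPO learning signal,
+                    # so it is filtered out; generation repeats until train_batch_size
+                    # prompts with non-zero std have been accumulated.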
+
+                    # === Updating ===
+                    # Balance the number of valid tokens across DP ranks.
+                    # NOTE: This usually changes the order of data in the `batch`,
+                    # which won't affect the advantage calculation (since it's based on uid),
+                    # but might affect the loss calculation (due to the change of mini-batching).
+                    # TODO: Decouple the DP balancing and mini-batching.
+                    if self.config.trainer.balance_batch:
+                        self._balance_batch(batch, metrics=metrics)
+
+                    # compute global_valid tokens
+                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
+
+                    if not self.config.algorithm.use_kl_in_reward:
+                        batch = self.compute_kl_related_metrics(batch, metrics, timing_raw)
+
+                    # compute values
+                    if self.use_critic:
+                        with marked_timer("values", timing_raw, "cyan"):
+                            values = self.critic_wg.compute_values(batch)
+                            batch = batch.union(values)
+
+                    # Compute rollout correction weights and off-policy metrics (inherited from RayPPOTrainer)
+                    from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch
+
+                    rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
+                    if rollout_corr_config is not None and "rollout_log_probs" in batch.batch:
+                        batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch)
+                        # IS and off-policy metrics already carry the rollout_corr/ prefix
+                        metrics.update(is_metrics)
+
+                    with marked_timer("adv", timing_raw, "brown"):
+                        # compute advantages, executed on the driver process
+                        norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
+                        batch = compute_advantage(
+                            batch,
+                            adv_estimator=self.config.algorithm.adv_estimator,
+                            gamma=self.config.algorithm.gamma,
+                            lam=self.config.algorithm.lam,
+                            num_repeat=self.config.actor_rollout_ref.rollout.n,
+                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
+                        )
+
+                    # update critic
+                    if self.use_critic:
+                        with marked_timer("update_critic", timing_raw, "pink"):
+                            critic_output = self.critic_wg.update_critic(batch)
+                        critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
+                        metrics.update(critic_output_metrics)
+
+                    # implement critic warmup
+                    if self.config.trainer.critic_warmup <= self.global_steps:
+                        # update actor
+                        with marked_timer("update_actor", timing_raw, "red"):
+                            actor_output = self.actor_rollout_wg.update_actor(batch)
+                        actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
+                        metrics.update(actor_output_metrics)
+
+                    # Log rollout generations if enabled
+                    rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
+                    if rollout_data_dir:
+                        self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
+
+                    # validate
+                    if (
+                        self.val_reward_fn is not None
+                        and self.config.trainer.test_freq > 0
+                        and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
+                    ):
+                        with marked_timer("testing", timing_raw, "green"):
+                            val_metrics: dict = self._validate()
+                            if is_last_step:
+                                last_val_metrics = val_metrics
+                        metrics.update(val_metrics)
+
+                    if self.config.trainer.save_freq > 0 and (
+                        is_last_step or self.global_steps % self.config.trainer.save_freq == 0
+                    ):
+                        with marked_timer("save_checkpoint", timing_raw, "green"):
+                            self._save_checkpoint()
+
+                with marked_timer("stop_profile", timing_raw):
+                    next_step_profile = (
+                        self.global_steps + 1 in self.config.global_profiler.steps
+                        if self.config.global_profiler.steps is not None
+                        else False
+                    )
+                    self._stop_profiling(
+                        curr_step_profile and not next_step_profile
+                        if self.config.global_profiler.profile_continuous_steps
+                        else curr_step_profile
+                    )
+                    prev_step_profile = curr_step_profile
+                    curr_step_profile = next_step_profile
+
+                # collect metrics
+                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
+                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
+                # TODO: implement actual TFLOPs and theoretical TFLOPs
+                n_gpus = self.resource_pool_manager.get_n_gpus()
+                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
+                timing_raw = defaultdict(float)  # clear timing
+
+                metrics["train/num_gen_batches"] = num_gen_batches
+                batch = None
+                num_prompt_in_batch = 0
+                num_gen_batches = 0
+
+                # TODO: make a canonical logger that supports various backends
+                logger.log(data=metrics, step=self.global_steps)
+
+                if is_last_step:
+                    pprint(f"Final validation metrics: {last_val_metrics}")
+                    progress_bar.close()
+                    return
+
+                progress_bar.update(1)
+                self.global_steps += 1
+                self.gen_steps += 1
+        # check if the last-step checkpoint exists
+        checkpoint_dir = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}")
+        if not os.path.exists(checkpoint_dir):
+            # save last step checkpoint
+            timing_raw = defaultdict(float)
+            with marked_timer("save_checkpoint", timing_raw, "green"):
+                self._save_checkpoint()
+            metrics = {f"timing/{k}": v for k, v in timing_raw.items()}
+            logger.log(data=metrics, step=self.global_steps)
+""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.trainer.ppo.reward import load_reward_manager +from verl.utils.device import is_cuda_available + +from .dapo_ray_trainer import RayDAPOTrainer + + +@hydra.main(config_path="config", config_name="dapo_trainer", version_base=None) +def main(config): + run_ppo(config) + + +def run_ppo(config) -> None: + if not ray.is_initialized(): + # this is for local ray cluster + default_runtime_env = { + "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} + } + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + try: + if ( + is_cuda_available + and config.global_profiler.tool == "nsys" + and OmegaConf.select(config.global_profiler, "steps") is not None + and len(OmegaConf.select(config.global_profiler, "steps")) > 0 + ): + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + finally: + if ray.is_initialized(): + ray.shutdown() + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + def run(self, config): + # print initial config + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_to_local(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # used for multimodal LLM, could be none + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + from verl.single_controller.ray import RayWorkerGroup + + # define worker classes + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} + + from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + Role.Critic: ray.remote(CriticWorker), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + } + + # we should adopt a multi-source reward 
+
+        # we should adopt a multi-source reward function here
+        # - for rule-based rm, we directly call a reward score
+        # - for model-based rm, we call a model
+        # - for code-related prompts, we send them to a sandbox if there are test cases
+        # - finally, we combine all the rewards together
+        # - The reward type depends on the tag of the data
+        if config.reward_model.enable:
+            if config.reward_model.strategy in {"fsdp", "fsdp2"}:
+                from verl.workers.fsdp_workers import RewardModelWorker
+            elif config.reward_model.strategy == "megatron":
+                from verl.workers.megatron_workers import RewardModelWorker
+            else:
+                raise NotImplementedError
+            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
+            mapping[Role.RewardModel] = global_pool_id
+
+        # reference model
+        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
+            role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
+            mapping[Role.RefPolicy] = global_pool_id
+
+        reward_fn = load_reward_manager(
+            config,
+            tokenizer,
+            0,
+            max_resp_len=config.data.max_response_length,
+            overlong_buffer_cfg=config.reward_model.overlong_buffer,
+        )
+
+        # Note that we always use function-based RM for validation
+        val_reward_fn = load_reward_manager(
+            config,
+            tokenizer,
+            1,
+            max_resp_len=config.data.max_response_length,
+            overlong_buffer_cfg=config.reward_model.overlong_buffer,
+        )
+        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
+
+        trainer = RayDAPOTrainer(
+            config=config,
+            tokenizer=tokenizer,
+            processor=processor,
+            role_worker_mapping=role_worker_mapping,
+            resource_pool_manager=resource_pool_manager,
+            ray_worker_group_cls=ray_worker_group_cls,
+            reward_fn=reward_fn,
+            val_reward_fn=val_reward_fn,
+        )
+        trainer.init_workers()
+        trainer.fit()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/recipe/moe_routing_replay/qwen3_routing_model.py b/recipe/moe_routing_replay/qwen3_routing_model.py
new file mode 100644
index 00000000000..a27c323627a
--- /dev/null
+++ b/recipe/moe_routing_replay/qwen3_routing_model.py
@@ -0,0 +1,586 @@
+import torch
+import torch.nn.functional as F
+from typing import Optional, Union, List, Dict, Any, Tuple
+from transformers.models.qwen3_moe.modeling_qwen3_moe import (
+    Qwen3MoeSparseMoeBlock,
+    Qwen3MoeDecoderLayer,
+    Qwen3MoeModel,
+    Qwen3MoeForCausalLM,
+    Qwen3MoePreTrainedModel,
+)
+
+# Try to import the required classes; fall back if they are unavailable
+try:
+    from transformers.cache_utils import Cache, DynamicCache
+except ImportError:
+    from transformers.modeling_utils import Cache
+
+    DynamicCache = Cache
+
+try:
+    from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
+except ImportError:
+    # If the import fails, fall back to the base output types
+    from transformers.modeling_outputs import BaseModelOutputWithPast as MoeModelOutputWithPast
+    from transformers.modeling_outputs import CausalLMOutputWithPast as MoeCausalLMOutputWithPast
+
+try:
+    from transformers.generation.utils import GenerationMixin
+except ImportError:
+    from transformers import GenerationMixin
+
+try:
+    from transformers.modeling_attn_mask_utils import (
+        create_causal_mask,
+        create_sliding_window_causal_mask,
+    )
+except ImportError:
+    # If the import fails, define simple stand-in functions
+    def create_causal_mask(*args, **kwargs):
+        return None
+
+    def create_sliding_window_causal_mask(*args, **kwargs):
+        return None
+
+
+def print_input_shapes(
+    input_ids: Any = None,
+    attention_mask: Any = None,
+    position_ids: Any = None,
+    routing_ids: Any = None,
+    tag="",
+) -> None:
+    """
+    Prints shapes for the given arguments. Handles:
+    - torch.Tensor
+    - numpy.ndarray
+    - list/tuple/dict containing tensors/arrays
+    - None
+    """
+
+    def shape_of(x):
+        # Lazy imports so the function is standalone
+        try:
+            import torch
+        except Exception:
+            torch = None
+        try:
+            import numpy as np
+        except Exception:
+            np = None
+
+        if x is None:
+            return "None"
+
+        if torch is not None and isinstance(x, torch.Tensor):
+            return f"{tuple(x.shape)}"
+
+        if np is not None and isinstance(x, np.ndarray):
+            return f"{tuple(x.shape)} (numpy)"
+
+        if isinstance(x, (list, tuple)):
+            if not x:
+                return "[]"
+            parts = []
+            for i, xi in enumerate(x):
+                if torch is not None and isinstance(xi, torch.Tensor):
+                    parts.append(f"{i}:{tuple(xi.shape)}")
+                elif np is not None and isinstance(xi, np.ndarray):
+                    parts.append(f"{i}:{tuple(xi.shape)}(np)")
+                else:
+                    parts.append(f"{i}:{type(xi).__name__}")
+            return f"[{', '.join(parts)}]"
+
+        if isinstance(x, dict):
+            parts = []
+            for k, v in x.items():
+                if torch is not None and isinstance(v, torch.Tensor):
+                    parts.append(f"{k}:{tuple(v.shape)}")
+                elif np is not None and isinstance(v, np.ndarray):
+                    parts.append(f"{k}:{tuple(v.shape)}(np)")
+                else:
+                    parts.append(f"{k}:{type(v).__name__}")
+            return "{" + ", ".join(parts) + "}"
+
+        return f"{type(x).__name__}"
+
+    print(f"[{tag}] input_ids     :", shape_of(input_ids))
+    print(f"[{tag}] attention_mask:", shape_of(attention_mask))
+    print(f"[{tag}] position_ids  :", shape_of(position_ids))
+    print(f"[{tag}] routing_ids   :", shape_of(routing_ids))
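+
+
+# Editor's note: high-level summary (not part of the original patch). The classes
+# below monkey-patch the HF Qwen3-MoE stack so that the expert indices recorded
+# during rollout ("routing replay", threaded through as `routing_ids`) can be
+# replayed in the training forward pass: the gate's softmax weights are still
+# recomputed, but the top-k expert *selection* is taken from the replayed map
+# instead of a fresh topk, keeping the expert assignment identical to rollout.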
+
+
+class CustomQwen3MoeSparseMoeBlock:
+    @staticmethod
+    def forward(
+        self, hidden_states: torch.Tensor, routing_map: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+
+        if routing_map is not None:
+            # Check that the routing map shape matches the hidden states
+            rp_batch_size, rp_sequence, rp_expert_num = routing_map.shape
+
+            assert rp_batch_size == batch_size, (
+                f"[qwen3_routing_model][CustomQwen3MoeSparseMoeBlock.forward] Shape mismatch: "
+                f"routing_map.shape={routing_map.shape} but hidden_states.shape={hidden_states.shape}"
+            )
+            assert rp_sequence == sequence_length, (
+                f"[qwen3_routing_model][CustomQwen3MoeSparseMoeBlock.forward] Shape mismatch: "
+                f"routing_map.shape={routing_map.shape} but hidden_states.shape={hidden_states.shape}"
+            )
+            assert rp_expert_num == self.top_k, (
+                f"[qwen3_routing_model][CustomQwen3MoeSparseMoeBlock.forward] Expert number mismatch: "
+                f"rp_expert_num={rp_expert_num} but self.top_k={self.top_k}"
+            )
+
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+
+        if routing_map is not None:
+            # Replay the routing map: gather the weights at the recorded expert indices
+            selected_experts = routing_map.view(-1, self.top_k)  # .long()  # TODO cx note: review required
+            routing_weights = routing_weights.gather(1, selected_experts)
+        else:
+            # TODO CRITICAL DEBUGGING ONLY
+            # routing_weights: values of [batch * sequence_length, self.top_k]
+            # selected_experts: indices of [batch * sequence_length, self.top_k]
+            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
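+
+        # Editor's note: a shape sketch of the replay path (illustrative, assuming
+        # top_k=8 as for Qwen3-30B-A3B):
+        #   routing_map:      [batch, seq, top_k]  expert ids recorded at rollout
+        #   selected_experts: [batch*seq, top_k]   after .view(-1, top_k)
+        #   routing_weights:  [batch*seq, top_k]   gathered from the fresh gate softmax,
+        # so gradients still flow through the gate while the expert choice stays fixed.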
+
+        if self.norm_topk_prob:  # this is the only difference from the Mixtral sparse MoE block!
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+
+        # One-hot encode the selected experts to create an expert mask;
+        # this will be used to easily index which expert is going to be solicited
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero().squeeze(-1)
+        for expert_idx in expert_hit:
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[top_x].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+            # However `index_add_` only supports torch tensors for indexing, so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+
+        return final_hidden_states, router_logits
+
+
+class CustomQwen3MoeDecoderLayer:
+    @staticmethod
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[torch.Tensor]] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        routing_map: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
+        """
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self attention, called in a way that is compatible across transformers versions
+        attn_output = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask.to(dtype=torch.float16) if attention_mask is not None else attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+
+        # Handle the return-value formats of different transformers versions
+        if isinstance(attn_output, tuple):
+            hidden_states = attn_output[0]
+            self_attn_weights = attn_output[1] if output_attentions else None
+            present_key_value = attn_output[-1] if use_cache else None
+        else:
+            hidden_states = attn_output
+            self_attn_weights = None
+            present_key_value = None
+
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+
+        # MLP forward with optional routing_map
+        if hasattr(self.mlp, "forward") and "routing_map" in self.mlp.forward.__code__.co_varnames:
+            mlp_output = self.mlp(hidden_states, routing_map=routing_map)
+        else:
+            mlp_output = self.mlp(hidden_states)
+
+        # For the MoE layers, we need to unpack
+        if isinstance(mlp_output, tuple):
+            hidden_states, router_logits = mlp_output
+        else:
+            hidden_states = mlp_output
+            router_logits = None
+
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        if router_logits is not None:
+            outputs += (router_logits,)
+
+        return outputs
+
+
+class CustomQwen3MoeModel:
+    @staticmethod
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        routing_maps=None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ):
+        if routing_maps is not None:
+            # transpose to [layer, batch, seq_length, expert]
+            routing_maps = routing_maps.permute(2, 0, 1, 3)
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_router_logits = (
+            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        )
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        # Create the causal mask (simplified version)
+        causal_mask = None
+        if attention_mask is not None:
+            causal_mask = attention_mask
+
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        if routing_maps is None:
+            routing_maps = [None] * self.config.num_hidden_layers
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_router_logits = () if output_router_logits else None
+        next_decoder_cache = None
+
+        for decoder_layer, routing_map in zip(self.layers, routing_maps):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                routing_map=routing_map,
+                **kwargs,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+            if output_router_logits and len(layer_outputs) > (2 if output_attentions else 1):
+                all_router_logits += (layer_outputs[-1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+                if v is not None
+            )
+
+        return MoeModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            router_logits=all_router_logits,
+        )
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + +class CustomMoeForCausalLM: + + @staticmethod + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + # routing_maps: Optional[List[torch.Tensor]] = None, + # routing_ids: Optional[List[torch.Tensor]] = None, # TODO NOTE cx modified + routing_ids = None, + + return_dict: Optional[bool] = None, + **kwargs, + ): + """ + Forward pass for the Custom MoE CausalLM model. 
+ """ + + # print_input_shapes(input_ids, attention_mask, position_ids, routing_ids, tag='[cx_debug][qwen3_routing_model][CustomMoeForCausalLM]') + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + cache_position=cache_position, + # routing_maps=routing_maps, + routing_maps=routing_ids, # TODO NOTE cx modified + return_dict=return_dict, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state if return_dict else outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = torch.nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + router_logits = outputs.router_logits if return_dict else (outputs[-1] if len(outputs) > 1 else None) + if router_logits: + aux_loss = load_balancing_loss_func( + router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None and aux_loss != 0: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) + + if not return_dict: + output = (logits,) + (outputs[1:] if isinstance(outputs, tuple) else ()) + if aux_loss is not None: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values if return_dict else None, + hidden_states=outputs.hidden_states if return_dict else None, + attentions=outputs.attentions if return_dict else None, + router_logits=outputs.router_logits if return_dict else None, + ) + +# 应用monkey patching - 替换原始类的forward方法 +def apply_patches(): + """应用所有的monkey patches""" + # Qwen3MoeSparseMoeBlock.forward = CustomQwen3MoeSparseMoeBlock.forward + + # # 为DecoderLayer创建一个wrapper + # original_decoder_init = Qwen3MoeDecoderLayer.__init__ + # original_decoder_forward = Qwen3MoeDecoderLayer.forward + + # def new_decoder_forward(self, *args, **kwargs): + # custom_layer = CustomQwen3MoeDecoderLayer() + # custom_layer.__dict__.update(self.__dict__) + # return custom_layer.forward(*args, **kwargs) + + # Qwen3MoeDecoderLayer.forward = new_decoder_forward + + # # 应用其他patches + # Qwen3MoeModel.forward = CustomQwen3MoeModel.forward + # Qwen3MoeForCausalLM.forward = CustomMoeForCausalLM.forward + + Qwen3MoeSparseMoeBlock.forward = CustomQwen3MoeSparseMoeBlock.forward + Qwen3MoeDecoderLayer.forward = 
+
+
+# Apply the patches automatically on import
+apply_patches()
\ No newline at end of file
diff --git a/recipe/moe_routing_replay/run.sh b/recipe/moe_routing_replay/run.sh
new file mode 100644
index 00000000000..0c28f93e0e8
--- /dev/null
+++ b/recipe/moe_routing_replay/run.sh
@@ -0,0 +1,188 @@
+#!/usr/bin/env bash
+set -xeuo pipefail
+
+# set -xuo pipefail
+# while true; do
+
+export http_proxy=http://oversea-squid2.ko.txyun:11080 https_proxy=http://oversea-squid2.ko.txyun:11080 no_proxy=localhost,127.0.0.1,localaddress,localdomain.com,internal,corp.kuaishou.com,test.gifshow.com,staging.kuaishou.com
+
+timestamp=$(date +"%Y-%m-%d-%H:%M:%S")
+
+loss_mode=vanilla
+adv_estimator=grpo
+loss_agg_mode="token-mean"
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 2))
+max_response_length=$((1024 * 32))
+
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 4))
+overlong_penalty_factor=1.0
+
+# Recommended setup: (train_prompt_bsz * n_resp_per_prompt) / n_machines <= 32, because SGLang's
+# expert_recorder records unreliably when the inference batch size is larger than 32
+train_prompt_bsz=16
+train_prompt_mini_bsz=16
+n_resp_per_prompt=8
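+
+# Editor's note (worked example, not part of the original script): with the
+# defaults above, train_prompt_bsz=16 * n_resp_per_prompt=8 = 128 trajectories
+# per step; on e.g. NNODES=4 machines that is 128 / 4 = 32 per machine, exactly
+# at the recommended limit.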
+
+
+# for TIS
+imp_ratio_cap=-1
+
+
+# Ray
+RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
+WORKING_DIR=${WORKING_DIR:-"${PWD}"}
+RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/recipe/moe_routing_replay/runtime_env.yaml"}
+# NNODES=${NNODES:-4}
+NNODES=${NNODES:-$(sort -u /etc/mpi/hostfile | wc -l)}
+NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
+
+project_name='DAPO-Qwen3-MOE-30B-FSDP-RoutingReplay'
+info_tag=""
+exp_name='qwen3moe-'${loss_mode}-${train_prompt_bsz}_${train_prompt_mini_bsz}'-n-'${n_resp_per_prompt}'-len-'${max_response_length}-${info_tag}
+
+
+# Paths
+MODEL_PATH=path/to/Qwen3-30B-A3B
+CKPTS_DIR=${CKPTS_DIR:-"${WORKING_DIR}/ckpts_hdd/${project_name}/${exp_name}"}
+OUTPUTS_ROLLOUT_DIR=${OUTPUTS_ROLLOUT_DIR:-"${WORKING_DIR}/outputs/${project_name}/${exp_name}/rollout/"}
+OUTPUTS_VALIDATION_DIR=${OUTPUTS_VALIDATION_DIR:-"${WORKING_DIR}/outputs/${project_name}/${exp_name}/validate/"}
+mkdir -p $OUTPUTS_ROLLOUT_DIR
+mkdir -p $OUTPUTS_VALIDATION_DIR
+mkdir -p "${WORKING_DIR}/logs"
+
+TRAIN_FILE="your_train_file.parquet"
+TEST_FILE="your_test_file.parquet"
+
+echo $OUTPUTS_ROLLOUT_DIR
+
+# rollout
+enable_routing_replay=True
+rollout_mode="sync"
+return_raw_chat="True"
+rollout_name="sglang"
+if [ "$rollout_mode" = "async" ]; then
+    # NOTE async rollout mode is not supported yet.
+    export VLLM_USE_V1=1
+fi
+
+
+
+# Algorithm
+temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+val_temperature=1.0
+val_top_p=0.7
+val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+
+
+# Performance Related Parameter
+sp_size=4
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))
+infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))
+offload=True
+gen_tp=8
+fsdp_size=-1 # default -1
+
+# Trade compute for memory
+entropy_checkpointing=True
+entropy_from_logits_with_chunking=True
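+
+# Editor's note (arithmetic illustration, not part of the original script): with
+# max_prompt_length=2048 and max_response_length=32768, the dynamic-batching
+# budget actor_ppo_max_token_len = (2048 + 32768) * 1 = 34816 tokens per GPU,
+# i.e. each micro-batch holds at most one full-length sequence.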
+
+# export RAY_DEDUP_LOGS=0
+PYTHONUNBUFFERED=1 python3 -m recipe.moe_routing_replay.main_dapo --config-path=config \
+    --config-name='dapo_trainer.yaml' \
+    data.train_files="${TRAIN_FILE}" \
+    data.val_files="${TEST_FILE}" \
+    data.prompt_key=prompt \
+    data.truncation='left' \
+    data.return_raw_chat=${return_raw_chat} \
+    data.max_prompt_length=${max_prompt_length} \
+    data.max_response_length=${max_response_length} \
+    data.train_batch_size=${train_prompt_bsz} \
+    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+    algorithm.adv_estimator=${adv_estimator} \
+    algorithm.use_kl_in_reward=${use_kl_in_reward} \
+    algorithm.kl_ctrl.kl_coef=${kl_coef} \
+    actor_rollout_ref.actor.entropy_checkpointing=${entropy_checkpointing} \
+    actor_rollout_ref.actor.entropy_from_logits_with_chunking=${entropy_from_logits_with_chunking} \
+    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
+    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+    actor_rollout_ref.actor.clip_ratio_c=10.0 \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.model.enable_activation_offload=${offload} \
+    critic.model.enable_activation_offload=${offload} \
+    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+    actor_rollout_ref.model.path="${MODEL_PATH}" \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+    actor_rollout_ref.actor.optim.weight_decay=0.1 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.actor.grad_clip=1.0 \
+    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+    actor_rollout_ref.rollout.name=$rollout_name \
+    actor_rollout_ref.rollout.mode=$rollout_mode \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+    actor_rollout_ref.rollout.enable_chunked_prefill=True \
+    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
+    actor_rollout_ref.rollout.temperature=${temperature} \
+    actor_rollout_ref.rollout.top_p=${top_p} \
+    actor_rollout_ref.rollout.top_k=${top_k} \
+    actor_rollout_ref.rollout.calculate_log_probs=True \
+    actor_rollout_ref.rollout.enable_routing_replay=${enable_routing_replay} \
+    actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
+    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
+    actor_rollout_ref.rollout.val_kwargs.top_k=${val_top_k} \
+    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+    actor_rollout_ref.rollout.val_kwargs.n=1 \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
+    reward_model.reward_manager=dapo \
+    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \
+    trainer.rollout_data_dir="${OUTPUTS_ROLLOUT_DIR}" \
+    trainer.validation_data_dir="${OUTPUTS_VALIDATION_DIR}" \
+    trainer.logger='["console","wandb"]' \
+    trainer.project_name="${project_name}" \
+    trainer.experiment_name="${exp_name}" \
+    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
+    trainer.nnodes="${NNODES}" \
+    trainer.val_before_train=False \
+    trainer.test_freq=25 \
+    trainer.save_freq=50 \
+    trainer.total_epochs=50 \
+    trainer.default_local_dir="${CKPTS_DIR}" \
+    trainer.resume_mode=auto \
+    trainer.log_val_generations=5 \
+    actor_rollout_ref.nccl_timeout=60000 2>&1 | tee logs/${project_name}_${exp_name}_$timestamp.log
diff --git a/recipe/moe_routing_replay/runtime_env.yaml b/recipe/moe_routing_replay/runtime_env.yaml
new file mode 100644
index 00000000000..13f4b2ba230
--- /dev/null
+++ b/recipe/moe_routing_replay/runtime_env.yaml
@@ -0,0 +1,5 @@
+working_dir: ./
+excludes: ["/.git/"]
+env_vars:
+  TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
+  VLLM_USE_V1: "1"
diff --git a/verl/protocol.py b/verl/protocol.py
index e0b1affe8f1..fd564d91d9c 100644
--- a/verl/protocol.py
+++ b/verl/protocol.py
@@ -590,6 +590,72 @@ def from_tensordict(
             meta_info=meta_info,
         )
 
+    def _tensor_debug_info(self, t: torch.Tensor) -> str:
+        base = getattr(t, "_base", None)
+        try:
+            stride = tuple(t.stride())
+        except Exception:
+            stride = "<unavailable>"
+        try:
+            off = t.storage_offset()
+        except Exception:
+            off = "<unavailable>"
+        return (
+            f"shape={tuple(t.shape)}, stride={stride}, dtype={t.dtype}, device={t.device}, "
+            f"is_contiguous={t.is_contiguous()}, layout={t.layout}, storage_offset={off}, "
+            f"is_view={'yes' if base is not None else 'no'}"
+        )
+
+    def _safe_move_tensor(self, t: torch.Tensor, device: str | torch.device) -> torch.Tensor:
+        # 1) try the normal move
+        try:
+            return t.to(device)
+        except Exception as e:
+            print(f"[_safe_move_tensor][diagnose] normal move failed: {e}")
+        # 2) detach + contiguous
+        try:
+            return t.detach().contiguous().to(device)
+        except Exception as e:
+            print(f"[_safe_move_tensor][diagnose] non-cloning move failed: {e}. Falling back to clone")
+        # 3) materialize fresh storage, then move
+        return t.detach().contiguous().clone().to(device)
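+
+    # Editor's note: the three attempts above form a ladder from cheapest to most
+    # expensive: a zero-copy `.to()`, then a contiguous re-layout, and finally a
+    # clone into fresh storage; each later step only runs if the previous raised.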
+
+    def _safely_move_tensordict(self, td, device: str | torch.device):
+        """Move a TensorDict to device with graceful fallback and diagnostics."""
+        # Fast path (prefer copy=True when available, to avoid storage-rebind bugs)
+        try:
+            return td.to(device, copy=True)  # tensordict>=0.4
+        except TypeError:
+            # older versions don't support copy=...
+            pass
+        except Exception as e:
+            print(f"[diagnose] td.to(copy=True, device={device!s}) failed: {e}")
+
+        try:
+            return td.to(device)  # try the fast zero-copy path
+        except Exception as e:
+            print(f"[diagnose] td.to({device!s}) failed: {e}\n[diagnose] per-key fallback...")
+
+        # Per-key move with detailed logging for the failing entry
+        moved = {}
+        for k, v in td.items():
+            if torch.is_tensor(v):
+                try:
+                    moved[k] = self._safe_move_tensor(v, device)
+                except Exception as ke:
+                    print(f"[diagnose] key {k!r} move failed: {ke}")
+                    print(f"[diagnose] key {k!r} info: {self._tensor_debug_info(v)}")
+                    # the last-resort attempt (clone then move) already ran inside _safe_move_tensor
+                    raise
+            else:
+                # Non-tensors are carried over as-is
+                moved[k] = v
+        return TensorDict(moved, batch_size=td.batch_size, device=device)
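+
+    # Editor's note: in `_safely_move_tensordict` above, `copy=True` is tried first
+    # because it materializes fresh storage on the target device instead of
+    # rebinding possibly shared or viewed storages, which is exactly the failure
+    # mode the per-key fallback then diagnoses.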
check out "/nlp_group/sunchenxi/qwen_moe/logs/qwen3_moe_2025-10-28-23:08:51.log" for log of enabling it (searching is_view=no and you'll find all no views) + self.batch = self.batch.to(device) + except RuntimeError as re: + self.batch = self._safely_move_tensordict(self.batch, device) return self def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> "DataProto": diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 3ec59760a05..53ea5316761 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -1204,6 +1204,7 @@ def fit(self): with marked_timer("update_actor", timing_raw, color="red"): batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable actor_output = self.actor_rollout_wg.update_actor(batch) + assert self.actor_rollout_wg._routing_cache == {} and self.actor_rollout_wg._routing_refs == {} and self.actor_rollout_wg._routing_prepared_batches == set(), f"self.actor_rollout_wg._routing_cache of len {len(self.actor_rollout_wg._routing_cache)} ={self.actor_rollout_wg._routing_cache}, self.actor_rollout_wg._routing_refs of len {len(self.actor_rollout_wg._routing_refs)} = {self.actor_rollout_wg._routing_refs == {}}, self.actor_rollout_wg._routing_prepared_batches of len {len(self.actor_rollout_wg._routing_prepared_batches)} = {self.actor_rollout_wg._routing_prepared_batches}" actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py index b65d94ec14d..35ba550bcea 100644 --- a/verl/utils/reward_score/__init__.py +++ b/verl/utils/reward_score/__init__.py @@ -56,7 +56,7 @@ def default_compute_score( # from . import math_verify # res = math_verify.compute_score(solution_str, ground_truth) - elif data_source in ["math_dapo", "math", "math_dapo_reasoning"] or data_source.startswith("aime"): + elif data_source in ["math_dapo", "math", "math_dapo_reasoning"] or data_source.startswith("aime") or data_source in ['math', 'aime2024', 'math500', 'aime2025', 'aime2023', 'amc23', 'dapo-math']: from . import math_dapo res = math_dapo.compute_score(solution_str, ground_truth) diff --git a/verl/utils/reward_score/math_dapo.py b/verl/utils/reward_score/math_dapo.py index 940500fd59e..c7e03b465d4 100644 --- a/verl/utils/reward_score/math_dapo.py +++ b/verl/utils/reward_score/math_dapo.py @@ -1,3 +1,277 @@ +# # Copyright 2024 Bytedance Ltd. and/or its affiliates +# # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. +# # Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py + +# import re +# from typing import Optional + + +# def last_boxed_only_string(string: str) -> Optional[str]: +# """Extract the last LaTeX boxed expression from a string. 
+ +# Args: +# string: Input string containing LaTeX code + +# Returns: +# The last boxed expression or None if not found +# """ +# idx = string.rfind("\\boxed{") +# if idx < 0: +# return None + +# i = idx +# right_brace_idx = None +# num_left_braces_open = 0 + +# while i < len(string): +# if string[i] == "{": +# num_left_braces_open += 1 +# if string[i] == "}": +# num_left_braces_open -= 1 +# if num_left_braces_open == 0: +# right_brace_idx = i +# break +# i += 1 + +# return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None + + +# def remove_boxed(s: str) -> str: +# """Remove the LaTeX boxed command from a string. + +# Args: +# s: String with format "\\boxed{content}" + +# Returns: +# The content inside the boxed command +# """ +# left = "\\boxed{" +# assert s[: len(left)] == left, f"box error: {s}" +# assert s[-1] == "}", f"box error: {s}" +# return s[len(left) : -1] + + +# # Constants for normalization +# SUBSTITUTIONS = [ +# ("an ", ""), +# ("a ", ""), +# (".$", "$"), +# ("\\$", ""), +# (r"\ ", ""), +# (" ", ""), +# ("mbox", "text"), +# (",\\text{and}", ","), +# ("\\text{and}", ","), +# ("\\text{m}", "\\text{}"), +# ] + +# REMOVED_EXPRESSIONS = [ +# "square", +# "ways", +# "integers", +# "dollars", +# "mph", +# "inches", +# "hours", +# "km", +# "units", +# "\\ldots", +# "sue", +# "points", +# "feet", +# "minutes", +# "digits", +# "cents", +# "degrees", +# "cm", +# "gm", +# "pounds", +# "meters", +# "meals", +# "edges", +# "students", +# "childrentickets", +# "multiples", +# "\\text{s}", +# "\\text{.}", +# "\\text{\ns}", +# "\\text{}^2", +# "\\text{}^3", +# "\\text{\n}", +# "\\text{}", +# r"\mathrm{th}", +# r"^\circ", +# r"^{\circ}", +# r"\;", +# r",\!", +# "{,}", +# '"', +# "\\dots", +# ] + + +# def normalize_final_answer(final_answer: str) -> str: +# """Normalize a final answer to a quantitative reasoning question. + +# Args: +# final_answer: The answer string to normalize + +# Returns: +# Normalized answer string +# """ +# final_answer = final_answer.split("=")[-1] + +# # Apply substitutions and removals +# for before, after in SUBSTITUTIONS: +# final_answer = final_answer.replace(before, after) +# for expr in REMOVED_EXPRESSIONS: +# final_answer = final_answer.replace(expr, "") + +# # Extract and normalize LaTeX math +# final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) +# final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) +# final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) +# final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) +# final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + +# # Normalize shorthand TeX: +# # \fracab -> \frac{a}{b} +# # \frac{abc}{bef} -> \frac{abc}{bef} +# # \fracabc -> \frac{a}{b}c +# # \sqrta -> \sqrt{a} +# # \sqrtab -> sqrt{a}b +# final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) +# final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) +# final_answer = final_answer.replace("$", "") + +# # Normalize numbers +# if final_answer.replace(",", "").isdigit(): +# final_answer = final_answer.replace(",", "") + +# return final_answer.strip() + + +# def is_correct_minerva( +# solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)" +# ) -> tuple[bool, str]: +# """Check if the solution is correct according to Minerva criteria. 
+ +# Args: +# solution_str: The solution string to check +# gt: The ground truth answer +# gt_need_extract: Whether the ground truth needs extraction +# answer_pattern: Regex pattern to extract the answer + +# Returns: +# Tuple of (is_correct, normalized_prediction) +# """ +# # Extract answer from solution +# match = re.findall(answer_pattern, solution_str) +# extracted_answer = match[-1] if match else "[INVALID]" +# pred = normalize_final_answer(extracted_answer) + +# # Process ground truth +# if gt_need_extract: +# gt = normalize_final_answer(remove_boxed(last_boxed_only_string(gt))) +# else: +# gt = normalize_final_answer(gt) + +# return (pred == gt), pred + + +# def is_correct_strict_box( +# pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None +# ) -> tuple[int, Optional[str]]: +# """Check if the prediction is correct using strict boxed answer criteria. + +# Args: +# pred: The prediction string +# gt: The ground truth answer +# pause_tokens_index: Indices of pause tokens + +# Returns: +# Tuple of (score, extracted_prediction) +# """ +# # Extract the relevant part of the prediction +# if pause_tokens_index is not None: +# assert len(pause_tokens_index) == 4 +# pred = pred[pause_tokens_index[-1] - 100 :] +# else: +# pred = pred[-100:] + +# # Extract and check the boxed answer +# boxed_pred = last_boxed_only_string(pred) +# extracted_pred = remove_boxed(boxed_pred) if boxed_pred is not None else None + +# return 1 if (extracted_pred == gt) else -1, extracted_pred + + +# def verify( +# solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None +# ) -> bool: +# """Verify if the solution is correct. + +# Args: +# solution_str: The solution string to verify +# answer: The ground truth answer +# strict_box_verify: Whether to use strict box verification +# pause_tokens_index: Indices of pause tokens + +# Returns: +# True if the solution is correct, False otherwise +# """ +# if strict_box_verify: +# correct, pred = is_correct_strict_box(solution_str, answer, pause_tokens_index) +# return correct == 1, pred + +# correct, pred = is_correct_minerva(solution_str, answer) +# return correct, pred + + +# def compute_score( +# solution_str: str, +# ground_truth: str, +# strict_box_verify: bool = False, +# pause_tokens_index: Optional[list[int]] = None, +# ) -> float: +# """Compute the reward score for a solution. + +# Args: +# solution_str: The solution string +# ground_truth: The ground truth answer +# strict_box_verify: Whether to use strict box verification +# pause_tokens_index: Indices of pause tokens + +# Returns: +# Reward score (1.0 for correct, -1.0 for incorrect) +# """ +# # Limit solution length for efficiency +# solution_str = solution_str[-300:] # The longest answer in MATH-500 has 159 characters + +# # Verify the solution +# correct, pred = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index) + +# reward = 1.0 if correct else -1.0 +# acc = correct + +# return { +# "score": reward, +# "acc": acc, +# "pred": pred, +# } + + # Copyright 2024 Bytedance Ltd. and/or its affiliates # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
# Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,212 +288,13 @@ # Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py import re -from typing import Optional - - -def last_boxed_only_string(string: str) -> Optional[str]: - """Extract the last LaTeX boxed expression from a string. - - Args: - string: Input string containing LaTeX code - - Returns: - The last boxed expression or None if not found - """ - idx = string.rfind("\\boxed{") - if idx < 0: - return None - - i = idx - right_brace_idx = None - num_left_braces_open = 0 - - while i < len(string): - if string[i] == "{": - num_left_braces_open += 1 - if string[i] == "}": - num_left_braces_open -= 1 - if num_left_braces_open == 0: - right_brace_idx = i - break - i += 1 - - return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None - - -def remove_boxed(s: str) -> str: - """Remove the LaTeX boxed command from a string. - - Args: - s: String with format "\\boxed{content}" - - Returns: - The content inside the boxed command - """ - left = "\\boxed{" - assert s[: len(left)] == left, f"box error: {s}" - assert s[-1] == "}", f"box error: {s}" - return s[len(left) : -1] - - -# Constants for normalization -SUBSTITUTIONS = [ - ("an ", ""), - ("a ", ""), - (".$", "$"), - ("\\$", ""), - (r"\ ", ""), - (" ", ""), - ("mbox", "text"), - (",\\text{and}", ","), - ("\\text{and}", ","), - ("\\text{m}", "\\text{}"), -] - -REMOVED_EXPRESSIONS = [ - "square", - "ways", - "integers", - "dollars", - "mph", - "inches", - "hours", - "km", - "units", - "\\ldots", - "sue", - "points", - "feet", - "minutes", - "digits", - "cents", - "degrees", - "cm", - "gm", - "pounds", - "meters", - "meals", - "edges", - "students", - "childrentickets", - "multiples", - "\\text{s}", - "\\text{.}", - "\\text{\ns}", - "\\text{}^2", - "\\text{}^3", - "\\text{\n}", - "\\text{}", - r"\mathrm{th}", - r"^\circ", - r"^{\circ}", - r"\;", - r",\!", - "{,}", - '"', - "\\dots", -] - - -def normalize_final_answer(final_answer: str) -> str: - """Normalize a final answer to a quantitative reasoning question. 
- - Args: - final_answer: The answer string to normalize - - Returns: - Normalized answer string - """ - final_answer = final_answer.split("=")[-1] - - # Apply substitutions and removals - for before, after in SUBSTITUTIONS: - final_answer = final_answer.replace(before, after) - for expr in REMOVED_EXPRESSIONS: - final_answer = final_answer.replace(expr, "") - - # Extract and normalize LaTeX math - final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) - final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) - - # Normalize shorthand TeX: - # \fracab -> \frac{a}{b} - # \frac{abc}{bef} -> \frac{abc}{bef} - # \fracabc -> \frac{a}{b}c - # \sqrta -> \sqrt{a} - # \sqrtab -> sqrt{a}b - final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) - final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) - final_answer = final_answer.replace("$", "") - - # Normalize numbers - if final_answer.replace(",", "").isdigit(): - final_answer = final_answer.replace(",", "") - - return final_answer.strip() - - -def is_correct_minerva( - solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)" -) -> tuple[bool, str]: - """Check if the solution is correct according to Minerva criteria. - - Args: - solution_str: The solution string to check - gt: The ground truth answer - gt_need_extract: Whether the ground truth needs extraction - answer_pattern: Regex pattern to extract the answer - - Returns: - Tuple of (is_correct, normalized_prediction) - """ - # Extract answer from solution - match = re.findall(answer_pattern, solution_str) - extracted_answer = match[-1] if match else "[INVALID]" - pred = normalize_final_answer(extracted_answer) - - # Process ground truth - if gt_need_extract: - gt = normalize_final_answer(remove_boxed(last_boxed_only_string(gt))) - else: - gt = normalize_final_answer(gt) - - return (pred == gt), pred - - -def is_correct_strict_box( - pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None -) -> tuple[int, Optional[str]]: - """Check if the prediction is correct using strict boxed answer criteria. - - Args: - pred: The prediction string - gt: The ground truth answer - pause_tokens_index: Indices of pause tokens - - Returns: - Tuple of (score, extracted_prediction) - """ - # Extract the relevant part of the prediction - if pause_tokens_index is not None: - assert len(pause_tokens_index) == 4 - pred = pred[pause_tokens_index[-1] - 100 :] - else: - pred = pred[-100:] - - # Extract and check the boxed answer - boxed_pred = last_boxed_only_string(pred) - extracted_pred = remove_boxed(boxed_pred) if boxed_pred is not None else None - - return 1 if (extracted_pred == gt) else -1, extracted_pred - +import signal +from typing import Optional, Tuple, Dict, Any, Union +from math_verify import parse, verify as m_verify def verify( solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None -) -> bool: +) -> Tuple[bool, str]: """Verify if the solution is correct. 
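+
+    Example (illustrative; math_verify's parsing/normalization may vary):
+        >>> verify("Thus the answer is \\boxed{42}.", "42")
+        (True, '')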
     Args:
@@ -231,12 +306,23 @@ def verify(

     Returns:
-        True if the solution is correct, False otherwise
+        Tuple of (is_correct, extracted_prediction); the math_verify-based
+        implementation below currently always returns "" for the prediction
     """
-    if strict_box_verify:
-        correct, pred = is_correct_strict_box(solution_str, answer, pause_tokens_index)
-        return correct == 1, pred
-
-    correct, pred = is_correct_minerva(solution_str, answer)
-    return correct, pred
+    if "boxed" not in solution_str[-300:]:
+        return False, ""
+    answer = str(answer)
+
+    try:
+        solution_val = parse(solution_str[-300:])
+        if "boxed" in answer:
+            gt_val = parse(answer)
+        else:
+            boxed_answer = "\\boxed{" + answer + "}"
+            gt_val = parse(boxed_answer)
+        return bool(m_verify(solution_val, gt_val)), ""
+    except Exception:
+        return False, ""


 def compute_score(
@@ -250,7 +336,7 @@ def compute_score(
     Args:
         solution_str: The solution string
         ground_truth: The ground truth answer
-        strict_box_verify: Whether to use strict box verification
+        strict_box_verify: Whether to use strict box verification (ignored by the
+            math_verify-based `verify` above)
         pause_tokens_index: Indices of pause tokens

     Returns:
@@ -264,9 +350,11 @@ def compute_score(
     reward = 1.0 if correct else -1.0
     acc = correct

+    if pred is None:
+        pred = ""
     return {
         "score": reward,
         "acc": acc,
         "pred": pred,
-    }
+    }
\ No newline at end of file
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index 1204fa01768..6ceae9cf6d8 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -113,6 +113,12 @@ def _forward_micro_batch(
                 )  # input_ids_rmpad (total_nnz, ...)
                 input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

+                if micro_batch.get("routing_ids", None) is not None:
+                    routing_ids = micro_batch["routing_ids"]  # (bsz, seqlen, n_layer, k_expert)
+                    routing_ids_flat = rearrange(routing_ids, "b s l k -> (b s) l k")  # (b*s, l, k)
+                    routing_ids_rmpad = index_first_axis(routing_ids_flat, indices)  # (total_nnz, l, k)
+                    routing_ids_rmpad = routing_ids_rmpad.unsqueeze(0)
+
                 # unpad the position_ids to align the rotary
                 if position_ids.dim() == 3:
                     position_ids_rmpad = (
@@ -147,12 +153,32 @@ def _forward_micro_batch(
                         position_ids_rmpad=position_ids_rmpad,
                         sp_size=self.ulysses_sequence_parallel_size,
                     )
+
+                    # Routing replay is not implemented for VLM models, so routing_ids must
+                    # be absent on this (multimodal) path.
+                    assert micro_batch.get("routing_ids", None) is None, (
+                        "[ERROR] routing replay not implemented for vlm models."
+                    )
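+
+                    # Worked shape example for the routing-ids slicing in the text-only
+                    # branch below (hypothetical sizes): with L=2 layers, K=2 experts,
+                    # T=8 unpadded tokens and sp_size=2, (1, 8, 2, 2) is permuted to
+                    # (1, 2, 2, 8) and flattened to (4, 8) so the 2-D Ulysses helper can
+                    # slice it to (4, 4) per SP rank; it is then restored to (1, 4, 2, 2),
+                    # keeping the token axis aligned with input_ids_rmpad.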
 else:
                 input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
                     input_ids_rmpad,
                     position_ids_rmpad=position_ids_rmpad,
                     sp_size=self.ulysses_sequence_parallel_size,
                 )
+
+                if micro_batch.get("routing_ids", None) is not None:
+                    B, T, L, K = routing_ids_rmpad.shape  # B should be 1
+                    r = routing_ids_rmpad.permute(0, 2, 3, 1).contiguous()  # (1, L, K, T)
+                    r = r.view(L * K, T)  # (B' = L*K, T)
+
+                    # Reuse the existing helper unchanged (it expects 2-D inputs)
+                    r, _, pad_size_r = ulysses_pad_and_slice_inputs(
+                        r, position_ids_rmpad=None, sp_size=self.ulysses_sequence_parallel_size
+                    )  # (L*K, T_local)
+
+                    T_local = r.size(1)
+                    r = r.view(1, L, K, T_local).permute(0, 3, 1, 2).contiguous()  # (1, T_local, L, K)
+                    routing_ids_rmpad = r
+
+                    # Sanity check: the token axis must match the other token-aligned inputs
+                    assert routing_ids_rmpad.size(1) == input_ids_rmpad.size(1)
+
                 input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
                     input_ids_rmpad_rolled,
                     position_ids_rmpad=None,
@@ -171,6 +197,7 @@ def _forward_micro_batch(
                     input_ids=input_ids_rmpad,
                     attention_mask=None,
                     position_ids=position_ids_rmpad,
+                    routing_ids=routing_ids_rmpad if micro_batch.get("routing_ids", None) is not None else None,
                     **multi_modal_inputs,
                     use_cache=False,
                     **extra_args,
@@ -249,6 +276,7 @@ def _forward_micro_batch(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
+                routing_ids=micro_batch.get("routing_ids", None),
                 **multi_modal_inputs,
                 use_cache=False,
                 **extra_args,
@@ -322,6 +350,9 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor:
         select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
         non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []

+        if "routing_ids" in data.batch.keys():
+            select_keys.append("routing_ids")
+
         data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)

         if use_dynamic_bsz:
@@ -377,6 +408,9 @@ def update_policy(self, data: DataProto):
         # Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True
         if "rollout_is_weights" in data.batch.keys():
             select_keys.append("rollout_is_weights")
+        # routing_ids are added to the batch when SGLang returns the routing matrix
+        if "routing_ids" in data.batch.keys():
+            select_keys.append("routing_ids")

         has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
         non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py
index 9a2514f4305..0779a1166cb 100644
--- a/verl/workers/config/rollout.py
+++ b/verl/workers/config/rollout.py
@@ -178,6 +178,8 @@ class RolloutConfig(BaseConfig):

     skip_tokenizer_init: bool = False

+    enable_routing_replay: bool = False
+
     def __post_init__(self):
         """Validate the rollout config"""
         if self.expert_parallel_size > 1:
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index a457925fc73..f9be54d9b87 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -36,6 +36,9 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType

+import ray
+from typing import Dict, List, Optional, Tuple
+
 try:  # for torch 2.5+
     from torch.distributed.tensor import DTensor
@@ -90,6 +93,8 @@
 from verl.workers.rollout import get_rollout_class
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

+from recipe.moe_routing_replay.qwen3_routing_model import CustomMoeForCausalLM
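+# NOTE (assumption): imported so that the routing-replay-capable Qwen3 MoE implementation
+# from this recipe (recipe/moe_routing_replay/qwen3_routing_model.py) is available and
+# registered when the actor module is built.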
+
 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -266,6 +271,607 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
             self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
             self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size

+        # --- routing caches / bookkeeping (training-time only) ---
+        self._routing_refs: dict[str, "ray.ObjectRef"] = {}  # rid -> ObjectRef (pins plasma lifetime)
+        self._routing_cache: dict[str, torch.Tensor] = {}  # rid -> CUDA tensor [T, L, K]
+        self._routing_prepared_batches: set[str] = set()  # batch ids we have already prepared
+
+    # ----------------------------------------------------------------------
+    # Routing: helpers, prepare, fetch (NCCL full sequence), guard, cleanup
+    # ----------------------------------------------------------------------
+    def _get_my_node_id(self) -> str:
+        """Ray node_id if running under Ray, else hostname."""
+        try:
+            return ray.get_runtime_context().get_node_id()
+        except Exception:
+            import socket
+            return socket.gethostname()
+
+    def _routing_batch_id(self, meta: dict) -> str:
+        """Stable batch id; prefer the explicit uuid provided by rollout."""
+        bid = meta.get("routing_batch_uuid")
+        if isinstance(bid, str) and bid:
+            return bid
+        rids = meta.get("routing_rid", [])
+        try:
+            rids = rids.tolist()  # numpy object array from non_tensor_batch
+        except AttributeError:
+            rids = list(rids)  # already a plain sequence
+        return "RID|" + "|".join(map(str, rids))
+
+    def _gather_node_ids(self, group=None) -> List[str]:
+        """Return global_rank -> node_id (Ray node_id) for the current group."""
+        group = group or torch.distributed.group.WORLD
+        me = self._get_my_node_id()
+        world = torch.distributed.get_world_size(group)
+        all_nodes = [None] * world
+        torch.distributed.all_gather_object(all_nodes, me, group=group)
+        return all_nodes
+
+    def _pick_owner_rank_for_node(
+        self, owner_node_id: str, ranks_in_group: List[int], node_id_per_rank: List[str]
+    ) -> Optional[int]:
+        cands = [r for r in ranks_in_group if node_id_per_rank[r] == owner_node_id]
+        return min(cands) if cands else None
+
+    def _owner_register_refs_from_batch(self, meta: dict) -> None:
+        """Pin plasma lifetime by copying any present ObjectRef into a long-lived dict, keyed by rid."""
+        rid_arr = meta.get("routing_rid", [])
+        ref_arr = meta.get("routing_ref", [])
+        for rid, ref in zip(rid_arr, ref_arr if len(ref_arr) == len(rid_arr) else [None] * len(rid_arr), strict=True):
+            if ref is not None and str(rid) not in self._routing_refs:
+                self._routing_refs[str(rid)] = ref
+
+    def _owner_materialize_blocks_to_cuda(self, rid_list: List[str]) -> None:
+        """
+        Owner rank only: for each rid not yet cached, ray.get the per-sample [T, L, K]
+        block from plasma and keep it on CUDA for fast NCCL serving.
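+
+        Illustrative example (hypothetical rids): with rid_list=["a", "b"] and "a" already
+        cached, a single ray.get([self._routing_refs["b"]]) returns (ids_TLK, pos, p2l) for
+        "b"; ids_TLK is then moved to CUDA as torch.long and stored in self._routing_cache["b"].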
+ """ + if not rid_list: + return # non-owner: no-op + need = [rid for rid in rid_list if rid not in self._routing_cache] + if not need: + return + refs = [self._routing_refs[rid] for rid in need] + + np_tuples = ray.get(refs) # each is (ids_TLK, pos_arr, p2l_arr) + + for rid, tup in zip(need, np_tuples, strict=True): + ids_TLK_np = tup[0] # NumPy [T, L, K], token-major + ids_TLK = torch.from_numpy(ids_TLK_np).to("cuda", dtype=torch.long, non_blocking=True) # keep original storage dtype # TODO cx note check dtype + self._routing_cache[rid] = ids_TLK + + def _fetch_routing_block_nccl( + self, + owner_rank: int, + rid: str, + need_block: bool, + group=None, + ) -> torch.Tensor: + """ + NCCL exchange for a SINGLE rid (one block per rid): + non-owner -> send flag (int32: 1 if need, 0 if not) + if need==1 then recv T,L,K header and then [T,L,K] payload (torch.long) + owner -> for each peer, recv flag; if flag==1 send T,L,K then payload + Returns [T, L, K] on CUDA (dtype torch.long) for non-owner requester; empty tensor otherwise. + """ + group = group or torch.distributed.group.WORLD + rank = torch.distributed.get_rank(group) + dev = torch.device("cuda") + + need_flag = torch.tensor([1 if need_block else 0], device=dev, dtype=torch.long) + + if rank != owner_rank: + # send flag + torch.distributed.isend(need_flag, dst=owner_rank, group=group).wait() + if not need_block: + return torch.empty((0, 0, 0), device=dev, dtype=torch.long) + + # recv header T,L,K then payload + T = torch.empty(1, device=dev, dtype=torch.long) + L = torch.empty(1, device=dev, dtype=torch.long) + K = torch.empty(1, device=dev, dtype=torch.long) + torch.distributed.irecv(T, src=owner_rank, group=group).wait() + torch.distributed.irecv(L, src=owner_rank, group=group).wait() + torch.distributed.irecv(K, src=owner_rank, group=group).wait() + + out = torch.empty((int(T.item()), int(L.item()), int(K.item())), device=dev, dtype=torch.long) + torch.distributed.irecv(out, src=owner_rank, group=group).wait() + if out.numel() > 0: + self._routing_cache[rid] = out + return out + + # --- owner path --- + world = torch.distributed.get_world_size(group) + # tensor to send (header + payload) is constant for all requesters + rblk = self._routing_cache[rid] # [T,L,K], on CUDA, dtype may be uint8/16/32 + + # In case offloaded - JIT warm to cuda + if not rblk.is_cuda: + rblk = rblk.to("cuda", non_blocking=True) # JIT warm for NCCL + self._routing_cache[rid] = rblk # keep consistent + + T_dim, L_dim, K_dim = map(int, rblk.shape) + header_T = torch.tensor([T_dim], device=dev, dtype=torch.long) + header_L = torch.tensor([L_dim], device=dev, dtype=torch.long) + header_K = torch.tensor([K_dim], device=dev, dtype=torch.long) + payload = rblk.to(torch.long, non_blocking=True).contiguous() + + for peer in range(world): + if peer == owner_rank: + continue + flag = torch.empty(1, device=dev, dtype=torch.long) + torch.distributed.irecv(flag, src=peer, group=group).wait() + if int(flag.item()) == 0: + continue + torch.distributed.isend(header_T, dst=peer, group=group).wait() + torch.distributed.isend(header_L, dst=peer, group=group).wait() + torch.distributed.isend(header_K, dst=peer, group=group).wait() + torch.distributed.isend(payload, dst=peer, group=group).wait() + + # if need_block: + # # if need_block is True, this tensor is meant to be transforred for other ranks and don't need to be kept on the owner rank after syncing is done. 
+ # del self._routing_cache[rid] + + return torch.empty(0, device=dev, dtype=torch.long) + + + def prepare_routing_for_actor_step(self, data: DataProto) -> DataProto: + """ + One-time per batch (idempotent): + 1) Pin any local ObjectRefs (rid->ref). + 2) All-gather (rid, ref) pairs and pick the first non-None ref per rid. + 3) Choose an owner rank per rid via 'routing_owner_node' colocation. + 4) On the owner rank: ray.get ONCE per rid; keep CUDA copy in self._routing_cache. + 5) Barrier. + """ + + meta = data.non_tensor_batch + if "routing_rid" not in meta or "routing_owner_node" not in meta: + return DataProto(meta_info={"routing_prepared": False}) + + batch_id = self._routing_batch_id(meta) + if batch_id in self._routing_prepared_batches: + return DataProto(meta_info={"routing_prepared": True}) + + # 1) pin lifetime for any refs visible on THIS rank (local shard) + self._owner_register_refs_from_batch(meta) + + # local (rid, ref) pairs + local_pairs = [] + rid_arr = meta["routing_rid"] + ref_arr = meta.get("routing_ref", []) + for rid, ref in zip(rid_arr, ref_arr if len(ref_arr) == len(rid_arr) else [None] * len(rid_arr), strict=True): + local_pairs.append((str(rid), ref)) + + # 2) global refs_by_rid + world = torch.distributed.get_world_size() + gathered: List[List[Tuple[str, "ray.ObjectRef"]]] = [None] * world # type: ignore + torch.distributed.all_gather_object(gathered, local_pairs) + + from itertools import chain + flat = list(chain.from_iterable(gathered)) + by_rid = {} + for rid, ref in flat: + if ref is not None: + by_rid.setdefault(rid, []).append(ref) + + for rid, refs in by_rid.items(): + # Ray ObjectRef supports equality by id; this is safe + if any(r != refs[0] for r in refs[1:]): + raise RuntimeError(f"[routing] Multiple distinct ObjectRefs found for rid={rid}") + + + refs_by_rid: dict[str, "ray.ObjectRef"] = {} + for pairs in gathered: + for rid, ref in pairs: + if ref is not None and rid not in refs_by_rid: + refs_by_rid[rid] = ref + + # 3) choose owner per rid using owner_node colocation + ranks = list(range(world)) + node_id_per_rank = self._gather_node_ids() + owner_nodes = meta["routing_owner_node"] # per-local-sample node ids + # build local list then global-union of (rid, owner_rank) + loc_owner_pairs: List[Tuple[str, int]] = [] + for node_id, rid in zip(owner_nodes, rid_arr, strict=True): + owner = self._pick_owner_rank_for_node(node_id, ranks, node_id_per_rank) + if owner is None: + raise RuntimeError( + f"[routing] No actor rank colocated with Ray node {node_id} for rid={rid}. " + "Cross-node ray.get is disallowed. Please schedule at least one actor rank on that node." 
+ ) + loc_owner_pairs.append((str(rid), owner)) + + gathered_pairs: List[List[Tuple[str, int]]] = [None] * world # type: ignore + torch.distributed.all_gather_object(gathered_pairs, loc_owner_pairs) + rid_owner_map: dict[str, int] = {} + for pairs in gathered_pairs: + for rid, own in pairs: + rid_owner_map[rid] = own # consistent across ranks; last write ok + + # 4) owners ray.get once per rid + my_rank = torch.distributed.get_rank() + + # TODO DEBUGGING + #### Enforce in-node ray.get to avoid cross code ray.get in `_owner_materialize_blocks_to_cuda` #### + # after rid_owner_map is built + for rid, own in rid_owner_map.items(): + if own != my_rank: + self._routing_refs.pop(rid, None) # keep only owner's handle + #################################################################################################### + + my_rids_to_materialize = [rid for rid, own in rid_owner_map.items() if own == my_rank] + # sanity: must have ObjectRef + missing = [rid for rid in my_rids_to_materialize if rid not in refs_by_rid] + if missing: + raise RuntimeError( + f"[routing] Owner rank {my_rank} lacks ObjectRef(s) for rid(s)={missing}; " + f"no rank provided a non-None routing_ref for those rids." + ) + # ensure owner rank actually has the ObjectRef locally + for rid in my_rids_to_materialize: + if rid not in self._routing_refs: + self._routing_refs[rid] = refs_by_rid[rid] + self._owner_materialize_blocks_to_cuda(my_rids_to_materialize) + + # 5) mesh-wide sync: owners are ready to serve + torch.distributed.barrier() + self._routing_prepared_batches.add(batch_id) + return DataProto(meta_info={"routing_prepared": True}) + + + def sync_routing_rows_to_local_cache(self, data: DataProto) -> DataProto: + """ + NCCL-sync stage (deadlock-safe, rid-only): + - Build global set of (owner_rank, rid) via all_gather_object so EVERY rank participates. + - For each (owner, rid): + * owner calls _fetch_routing_block_nccl(owner, rid, need_block=False) to serve. + * non-owners call it with need_block=(rid in my local shard AND rid not cached). + - Cache the received block as self._routing_cache[rid]. + Assumes prepare_routing_for_actor_step has already materialized owners' CUDA caches. + """ + meta = data.non_tensor_batch + if "routing_rid" not in meta or "routing_owner_node" not in meta: + return DataProto(meta_info={"routing_synced": False}) + + ranks = list(range(torch.distributed.get_world_size())) + node_id_per_rank = self._gather_node_ids() + owner_nodes = meta["routing_owner_node"] + rids = [str(r) for r in meta["routing_rid"]] + + my_rank = torch.distributed.get_rank() + + # local (owner, rid) pairs + local_pairs: List[Tuple[int, str]] = [] + for node_id, rid in zip(owner_nodes, rids, strict=True): + owner = self._pick_owner_rank_for_node(node_id, ranks, node_id_per_rank) + if owner is None: + raise RuntimeError( + f"[routing] No actor rank colocated with Ray node {node_id} for rid={rid}. " + "Cross-node ray.get is disallowed. Please schedule at least one actor rank on that node." 
+                )
+            local_pairs.append((owner, rid))
+
+        # global union of (owner, rid)
+        world = torch.distributed.get_world_size()
+        gathered: List[List[Tuple[int, str]]] = [None] * world  # type: ignore
+        torch.distributed.all_gather_object(gathered, local_pairs)
+        global_pairs = sorted({p for sub in gathered for p in sub})
+
+        # Make sure owners finished materialization (should already be true)
+        torch.distributed.barrier()
+        print("[DEBUG][sync_routing_rows_to_local_cache] Start fetching inference routing maps via NCCL")
+
+        # participate for EVERY pair to avoid deadlocks
+        local_rid_set = set(rids)
+        need_list = []  # debug bookkeeping; not used below
+        for owner_rank, rid in global_pairs:
+            if owner_rank == my_rank:
+                _ = self._fetch_routing_block_nccl(owner_rank, rid, need_block=False)
+            else:
+                need = (rid in local_rid_set) and (rid not in self._routing_cache)
+                need_list.append(need)
+                # assert need == True
+                block = self._fetch_routing_block_nccl(owner_rank, rid, need_block=need)
+                # if need and block.numel() > 0:
+                #     self._routing_cache[rid] = block  # [T,L,K], int32 CUDA
+
+        print("[DEBUG][sync_routing_rows_to_local_cache]: Fetching complete; waiting at the barrier...")
+
+        torch.distributed.barrier()
+        return DataProto(meta_info={"routing_synced": True})
+
+    def _as_list(self, x):
+        try:
+            return x.tolist()
+        except Exception:
+            return list(x)
+
+    def _rids_from_data(self, data):
+        meta = data.non_tensor_batch
+        if "routing_rid" not in meta:
+            return []
+        return [str(r) for r in self._as_list(meta["routing_rid"])]
+
+    def _move_cached_rids(self, rids, to: str, *, pin_cpu: bool = True):
+        for rid in rids:
+            t = self._routing_cache.get(rid)
+            if t is None:
+                continue
+            if to == "cpu" and t.is_cuda:
+                cpu = t.to("cpu", non_blocking=True)
+                self._routing_cache[rid] = cpu.pin_memory() if pin_cpu else cpu
+            elif to == "cuda" and not t.is_cuda:
+                self._routing_cache[rid] = t.to("cuda", non_blocking=True)
+            elif to not in ("cpu", "cuda"):
+                raise ValueError(f"bad target {to}")
+
+    def _offload_cached_rids_for_data(self, data, *, drop_instead: bool = False):
+        rids = self._rids_from_data(data)
+        if not rids:
+            return
+        if drop_instead:
+            for rid in rids:
+                self._routing_cache.pop(rid, None)
+        else:
+            # offload everything currently cached (not just this batch's rids) to CPU
+            rids = [str(r) for r in list(self._routing_cache.keys())]
+            self._move_cached_rids(rids, "cpu", pin_cpu=True)
+        try:
+            torch.cuda.empty_cache()
+        except Exception:
+            pass
+
+    def _attach_full_routing_to_data(self, data: DataProto) -> None:
+        """
+        Build and attach a dense routing tensor shaped [B_local, T_max, L, K] to data.batch["routing_ids"],
+        and lengths [B_local] to data.batch["routing_T_lens"].
+        Assumes prepare_routing_for_actor_step + sync_routing_rows_to_local_cache completed.
+        """
+        meta = data.non_tensor_batch
+        if "routing_rid" not in meta:
+            return
+
+        # lists
+        def _as_list(x):
+            try:
+                return x.tolist()
+            except Exception:
+                return list(x)
+
+        rid_arr = [str(r) for r in _as_list(meta["routing_rid"])]
+        shape_arr = _as_list(meta["routing_shape"])  # each item like (T, L, K)
+
+        B_local = len(rid_arr)
+        if B_local == 0:
+            return
+
+        # T may differ per rid; L, K must be consistent
+        shapes = [tuple(s) for s in shape_arr]
+        T_list = [int(tlk[0]) for tlk in shapes]
+        L_set = {int(tlk[1]) for tlk in shapes}
+        K_set = {int(tlk[2]) for tlk in shapes}
+        assert len(L_set) == 1 and len(K_set) == 1, "Routing L,K must be consistent across the batch."
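+
+        # Illustrative sizing (hypothetical numbers): B_local=2, T_max=16 (the width of
+        # input_ids), L=4 layers, K=2 experts -> out_ids below is a zero-initialized
+        # (2, 16, 4, 2) int64 tensor; each cached [T, L, K] block was already padded to
+        # the full width on the rollout side, so every row is copied in 1:1.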
+ L = L_set.pop(); K = K_set.pop() + # T_max = max(T_list) + T_max = data.batch['input_ids'].shape[1] # directly left padding to the same length of input_ids + + target_device = data.batch["input_ids"].device + out_ids = torch.full((B_local, T_max, L, K), 0, device=target_device, dtype=torch.long).contiguous() # TODO cx note: CRITICAL! default to -1 or 0? + out_len = torch.tensor(T_list, device=target_device, dtype=torch.long).contiguous() + + for i, rid in enumerate(rid_arr): + rblk = self._routing_cache.get(rid, None) + if rblk is None: + raise RuntimeError( + f"[routing] Missing cached routing block for rid={rid} on rank {torch.distributed.get_rank()}. " + f"Call prepare_routing_for_actor_step + sync_routing_rows_to_local_cache first." + ) + if rblk.device != target_device: + rblk = rblk.to(target_device, non_blocking=True) + + # assert int(data.batch['attention_mask'][i].sum()) == + assert data.batch['attention_mask'].shape[1] == data.batch['input_ids'].shape[1] == rblk.shape[0] + out_ids[i, :, :, :] = rblk.to(torch.long, non_blocking=True) + + # Ti = T_list[i] + + # # out_ids[i, :Ti, :, :] = rblk[:Ti, :, :].to(torch.long, non_blocking=True) # RIGHT padding + # start = T_max - Ti # number of padded steps on the left + # assert Ti == rblk.shape[0], f"n tokens {Ti} and routing matrix seq len {rblk.shape[0]} doesn't match" + # assert start >=0 and start+Ti<=out_ids.shape[1], f"start={start}, start+Ti={start+Ti}, out_ids.shape[1]={out_ids.shape[1]}" + # out_ids[i, start:start+Ti, :, :] = rblk[:Ti, :, :].to(torch.long, non_blocking=True) # TODO CRITICAL LEFT padding + + + # data.batch["routing_ids"] = out_ids.detach()#.contiguous()#.clone() # [B_local, T_max, L, K] + # data.batch["routing_T_lens"] = out_len.contiguous().detach()#.contiguous()#.clone() # [B_local] + + # cx NOTE make sure not attached a view, due to current cache offloading strategy + data.batch["routing_ids"] = out_ids.detach().clone() # [B_local, T_max, L, K] + data.batch["routing_T_lens"] = out_len.contiguous().detach().clone() # [B_local] + data.meta_info["routing_attached"] = True + + + def _ensure_routing_prepared(self, data: DataProto, tag='') -> DataProto: + """ + Guard that runs the 3-stage pipeline on demand: + 1) prepare_routing_for_actor_step + 2) sync_routing_rows_to_local_cache + 3) _attach_full_routing_to_data + """ + meta = data.non_tensor_batch + if "routing_rid" not in meta: + print(f"routing_rid not found. routing matrix syncing skipped.") + return data + + # TODO DEBUGGING ONLY + # snapshot rids before we mutate anything + def _to_list_str(x): + try: + return [str(y) for y in x.tolist()] + except Exception: + return [str(y) for y in list(x)] + + rid_before = None + if "routing_rid" in meta: + rid_before = _to_list_str(meta["routing_rid"]) + ################################################ + + bid = self._routing_batch_id(meta) + + if bid not in self._routing_prepared_batches: + print(f"[DEBUG][fsdp_workers]: Preparing routing matrices (owner materialization) for batch with id: {bid}") + _ = self.prepare_routing_for_actor_step(data) + else: + print(f"[DEBUG][fsdp_workers]: prepare stage already done. Skipped...") + + # TODO CRITICAL no syncing except for the first time + # TODO CRITICAL empty self.routing_cache etc once dataproto constructed? 
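+        # Call order at a glance (a recap of this guard, not new behavior):
+        #   1) prepare_routing_for_actor_step(data)     # once per batch id: pin refs, owners ray.get to CUDA
+        #   2) sync_routing_rows_to_local_cache(data)   # NCCL fan-out of [T, L, K] blocks to local ranks
+        #   3) _attach_full_routing_to_data(data)       # dense routing_ids [B, T_max, L, K] on the batch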
+ _ = self.sync_routing_rows_to_local_cache(data) + + if not data.meta_info.get("routing_attached", False): + self._attach_full_routing_to_data(data) + + data.batch = data.batch.contiguous() + self._offload_cached_rids_for_data(data, drop_instead=False) # NOTE setting `drop_instead` can cause re-sync local caches + + ########################## TODO DEBUGGING ONLY ########################## + ########################## ASSERTIONS: (1)token num matches (2) rid doesn't change ########################## + # # After self._attach_full_routing_to_data(data) + # if "routing_T_lens" in data.batch and "attention_mask" in data.batch: + # # Prefer counting tokens in the prompt region to avoid response spillover + # if "prompts" in data.batch: + # prompt_len = data.batch["prompts"].size(-1) # left-padded width + # attn_prompt = data.batch["attention_mask"][..., :] + # else: + # # Fallback: assume routing covers exactly the left segment of input_ids + # # (only use if 'prompts' is absent) + # attn_prompt = data.batch["attention_mask"] + + # # Sum to per-sample lengths (int32 for fair compare) + # attn_lens = attn_prompt.sum(dim=-1).to(dtype=torch.int32, device=data.batch["routing_T_lens"].device) + # t_lens = data.batch["routing_T_lens"] + + # # Strict equality: routing tokens must match prompt token count + # if not torch.equal(attn_lens, t_lens): + # # Helpful debug: show first few mismatches + # mismatch = (attn_lens != t_lens).nonzero(as_tuple=False).flatten() + # idx = mismatch[:8].tolist() + # raise AssertionError( + # "[routing] routing_T_lens mismatch vs prompt attention sum: " + # f"indices={idx}, routing_T_lens={t_lens[idx].tolist()}, attn_sum={attn_lens[idx].tolist()}" + # ) + + # ---------------- DEBUG: routing_ids vs attention_mask alignment ---------------- + if "routing_ids" in data.batch and "attention_mask" in data.batch: + mask = data.batch["attention_mask"] # [B,T] + rid = data.batch["routing_ids"] # [B,T,L,K] + # reduce [L,K] -> whether any element is non-zero at each (b,t) + any_nonpad = (rid != 0).any(dim=(-1, -2)) # [B,T], bool + + # (1) padded tokens must be all zeros + bad_pad = (mask == 0) & any_nonpad # non-zero routing under mask==0 + if bad_pad.any(): + idx = bad_pad.nonzero(as_tuple=False)[:8].tolist() + raise AssertionError( + f"[routing] non-zero routing under mask==0 at {idx} (showing up to 8)" + ) + + # (2) active count must match: sum over T of mask==1 equals sum over T of any_nonpad==1 + attn_lens = mask.sum(dim=-1).to(dtype=torch.long) # [B] + routed_lens = any_nonpad.sum(dim=-1).to(dtype=torch.long) # [B] + if not torch.equal(attn_lens, routed_lens): + mismatch = (attn_lens != routed_lens).nonzero(as_tuple=False).flatten() + idx = mismatch[:8].tolist() + raise AssertionError( + f"[routing][rank {torch.distributed.get_rank()}] active length mismatch vs non-zero routing length: " + f"indices={idx}, attn_lens={attn_lens[idx].tolist()}, " + f"routed_lens={routed_lens[idx].tolist()}" + ) + + # Optional: also assert full-length T matches input_ids T for sanity + T_full_ids = data.batch["input_ids"].shape[1] + T_full_mask = mask.shape[1] + T_full_route = rid.shape[1] + if not (T_full_ids == T_full_mask == T_full_route): + raise AssertionError( + f"[routing] T_full mismatch: input_ids={T_full_ids}, " + f"mask={T_full_mask}, routing_ids={T_full_route}" + ) + # ------------------------------------------------------------------------------- + + # TODO CRITICAL add routing_ids assertion. 8 ids should be different. 
This should be enough to detect under/overflow if any + + # --- integrity checks: rids unchanged across _ensure pipeline --- + if rid_before is not None: + rid_after = _to_list_str(data.non_tensor_batch.get("routing_rid", [])) + + # length must be identical + if len(rid_before) != len(rid_after): + raise AssertionError( + "[routing] routing_rid length changed across _ensure_routing_prepared: " + f"before={len(rid_before)} after={len(rid_after)}" + ) + + # content & order must be identical + if rid_before != rid_after: + # help debug: show first few mismatches + bad = [i for i, (a, b) in enumerate(zip(rid_before, rid_after)) if a != b] + bad = bad[:8] + raise AssertionError( + "[routing] routing_rid content/order changed across _ensure_routing_prepared: " + f"indices={bad}, before={[rid_before[i] for i in bad]}, after={[rid_after[i] for i in bad]}" + ) + + # optional: uniqueness & companion-field length sanity + if len(set(rid_after)) != len(rid_after): + raise AssertionError("[routing] routing_rid contains duplicates") + + # keep companion fields in lockstep with rid count (helpful for sneaky shape bugs) + for k in ("routing_owner_node", "routing_shape", "routing_layout", "routing_dtype_ids", "routing_ref", "routing_b_idx"): + if k in data.non_tensor_batch and len(data.non_tensor_batch[k]) != len(rid_after): + raise AssertionError( + f"[routing] {k} length ({len(data.non_tensor_batch[k])}) != routing_rid length ({len(rid_after)})" + ) + + ######################################################################################################## + + # print(f"""[cx_debug][fsdp_workers] len(data)={len(data)}. data.meta_info["routing_attached"]={data.meta_info["routing_attached"]}. data.batch["routing_T_lens"].shape={data.batch["routing_T_lens"].shape} and data.batch['input_ids'].shape=""") + + # if 'routing_attached' in data.batch: + # data.batch.pop('routing_attached') + if 'routing_T_lens' in data.batch: # TODO + data.pop(batch_keys=['routing_T_lens']) + + # # TODO + # if 'meta_info' in data: + # print(f"data.meta_info popped as: {data.meta_info}") + # data.batch.pop('meta_info') + + return data + + + + def _cleanup_routing_for_batch(self, meta: dict) -> None: + """Drop CUDA cache and ObjectRefs for this batch to avoid (CUDA) OOM.""" + if "routing_rid" not in meta: + return + rid_arr = meta.get("routing_rid", []) + for rid in rid_arr: + self._routing_cache.pop(str(rid), None) # free CUDA + self._routing_refs.pop(str(rid), None) # allow plasma to reclaim + self._routing_prepared_batches.discard(self._routing_batch_id(meta)) + # assert self._routing_cache == {} and self._routing_refs == {} and self._routing_prepared_batches == set(), f"self._routing_cache of len {len(self._routing_cache)} ={self._routing_cache}, self._routing_refs of len {len(self._routing_refs)} = {self._routing_refs == {}}, self._routing_prepared_batches of len {len(self._routing_prepared_batches)} = {self._routing_prepared_batches}" + try: + torch.cuda.empty_cache() + except Exception as e: + pass + + + def _build_model_optimizer( self, model_path, @@ -860,9 +1466,35 @@ def init_model(self): checkpoint_config=checkpoint_contents, ) + + # TODO cx_note debugging only + def print_tensordict_info(self, td, tag=""): + for k, v in td.items(): + if isinstance(v, torch.Tensor): + base = getattr(v, "_base", None) + try: + stride = tuple(v.stride()) + except Exception: + stride = "" + try: + off = v.storage_offset() + except Exception: + off = "" + print( + f"[cx_debug][fsdp_workers][update_actor][print_tensordict_info][{tag}] " + 
f"{k:>24} | dtype={str(v.dtype):>8} | device={str(v.device):>8} | " + f"shape={tuple(v.shape)} | stride={stride} | contiguous={v.is_contiguous()} | " + f"layout={v.layout} | storage_offset={off} | is_view={'yes' if base is not None else 'no'}" + ) + else: + print(f"{k:>24} | non-tensor type={type(v).__name__}") + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @DistProfiler.annotate(color="red", role="actor_update") def update_actor(self, data: DataProto): + # ensure routing prepared once (idempotent) + self._ensure_routing_prepared(data, tag='[update_actor]') # TODO debugging only + assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) @@ -901,6 +1533,9 @@ def update_actor(self, data: DataProto): offload_fsdp_optimizer(optimizer=self.actor_optimizer) log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger) + # NOTE Cleaning up is coupled with the current training order, assuming that update_actor is the last function that requires the rouing_ids for routing replay + self._cleanup_routing_for_batch(data.non_tensor_batch) + return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout")) @@ -927,7 +1562,8 @@ def generate_sequences(self, prompts: DataProto): log_gpu_memory_usage("After switch to rollout mode", logger=logger) with simple_timer("generate_sequences", timing_generate): - output = self.rollout.generate_sequences(prompts=prompts) + # output = self.rollout.generate_sequences(prompts=prompts) + output = self.rollout.generate_sequences(prompts=prompts, device_id=get_device_id(), device=get_torch_device()) if self._is_actor: loop.run_until_complete(self.trainer_mode()) @@ -956,6 +1592,9 @@ def generate_sequences(self, prompts: DataProto): @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") def compute_log_prob(self, data: DataProto): + owner_node_id = ray.get_runtime_context().get_node_id() # stable per Ray node + self._ensure_routing_prepared(data, tag='[compute_log_prob]') + # when is_lora is True, we use the actor without lora applied to calculate the log_prob # which is mostly used for ref log_prob calculation assert self._is_actor @@ -997,6 +1636,9 @@ def compute_log_prob(self, data: DataProto): @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") def compute_ref_log_prob(self, data: DataProto): + # ensure routing prepared once (idempotent) + self._ensure_routing_prepared(data, 'compute_ref_log_prob') + if self._is_lora: # if _is_lora, actor without lora applied is the ref data.meta_info["is_lora"] = True diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index ecf622406f9..a2564efe476 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -591,7 +591,42 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: """ if self.config.multi_turn.enable: return self._req_level_generate_sequences(prompts, **kwargs) - return self._batch_level_generate_sequences(prompts, **kwargs) + + device_id = kwargs.pop('device_id', "no device_id passed in") + device = kwargs.pop('device', "no device passed in") + + # validation without routing matrix + if prompts.meta_info.get("validate", False): + res = 
+            res = self._batch_level_generate_sequences(prompts, **kwargs)
+            # dist.barrier()
+            return res
+
+        if not self.config.enable_routing_replay:
+            print("Routing replay not enabled.")
+            return self._batch_level_generate_sequences(prompts, **kwargs)
+
+        # NOTE Currently, SGLang's expert_distribution_recorder does not behave as expected
+        # when the per-entry batch size exceeds 32, so we generate in chunks of step_bsz
+        step_bsz = 32
+        if len(prompts) <= step_bsz:
+            res = self._batch_level_generate_sequences_train(prompts, **kwargs)
+            # dist.barrier()
+            return res
+
+        prompts = prompts.to('cpu')
+        data_sliced = [prompts[start:start + step_bsz] for start in range(0, len(prompts), step_bsz)]
+        generations_list = []
+
+        while len(data_sliced) > 0:
+            data_cpu = data_sliced.pop(0)
+            data_gpu = data_cpu.to(device_id)
+            res = self._batch_level_generate_sequences_train(data_gpu, **kwargs).to('cpu')
+            generations_list.append(res)
+            del data_gpu
+            del data_cpu
+            # device.empty_cache()
+
+        # dist.barrier()
+        # NOTE no need to move the result back to GPU (.to(device_id)); the downstream logic
+        # in fsdp_workers.py `generate_sequences` does not require it
+        return DataProto.concat(generations_list)

     @GPUMemoryLogger(role="sglang rollout", logger=logger)
     @torch.no_grad()
     def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
@@ -804,6 +839,420 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:

         return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)

+    def _pad_and_put_routing_from_mask(
+        self,
+        ids_per_row,   # list of tuples: (ids_TLK, pos_arr, p2l_arr); ids_TLK shape (Ti, L, K)
+        attention_mask: torch.Tensor,  # [B, T_full], final mask after concat
+        *,
+        logger=None,
+    ):
+        """
+        For each row b:
+          - Scatter ids_TLK (Ti, L, K) into a zero-initialized (T_full, L, K) array
+            at the positions where attention_mask[b] == 1 (in order).
+          - Ti == sum(mask == 1) holds empirically, so this is a 1:1 fill.
+          - Returns (refs_per_row, shape_list), where each ref is
+            ray.put((ids_padded_TLK, pos_arr, p2l_arr)) and shape_list[b] = (T_full, L, K).
+        """
+
+        B, T_full = attention_mask.shape
+        refs_per_row = [None] * B
+        shape_list = [None] * B
+
+        mask_cpu = attention_mask.detach().to("cpu")
+
+        for b in range(B):
+            ids_TLK, pos_arr, p2l_arr = ids_per_row[b]  # ids_TLK: (Ti, L, K)
+            if ids_TLK is None:
+                refs_per_row[b] = None
+                continue
+
+            if not isinstance(ids_TLK, np.ndarray) or ids_TLK.ndim != 3:
+                raise ValueError(
+                    f"[routing] ids_TLK[{b}] must be np.ndarray with shape (Ti,L,K), "
+                    f"got {type(ids_TLK)} {getattr(ids_TLK, 'shape', None)}"
+                )
+
+            Ti, L, K = ids_TLK.shape
+            active_idx = (mask_cpu[b] != 0).nonzero(as_tuple=False).squeeze(1).numpy()  # (Ti,)
+            n_active = int(active_idx.size)
+
+            if n_active != Ti:
+                raise ValueError(
+                    f"[error][routing][rank {torch.distributed.get_rank()}] row {b}: "
+                    f"Ti={Ti} != active_mask_tokens={n_active}"
+                )
+
+            # Allocate the final frame and scatter
+            out = np.zeros((int(T_full), int(L), int(K)), dtype=ids_TLK.dtype)
+            n_fill = min(Ti, n_active)
+            if n_fill > 0:
+                out[active_idx[:n_fill], :, :] = ids_TLK[:n_fill]
+
+            out = np.ascontiguousarray(out)
+            # Put the tuple into plasma and remember the final shape
+            refs_per_row[b] = ray.put((out, pos_arr, p2l_arr))
+            shape_list[b] = (out.shape[0], L, K)
+
+        return refs_per_row, shape_list
+
+    @GPUMemoryLogger(role="sglang rollout", logger=logger)
+    @torch.no_grad()
+    def _batch_level_generate_sequences_train(self, prompts: DataProto, **kwargs) -> DataProto:
+        """Generates single-turn sequences for a batch of prompts.
+
+        For single-turn generation, all prompts are processed in one request.
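+        (Sizing note: when routing replay is enabled, `generate_sequences` above calls this
+        method in chunks of at most 32 prompts and concatenates the per-chunk outputs on
+        CPU, so e.g. 128 prompts become 4 sequential calls; the chunk size reflects an
+        empirical SGLang expert_distribution_recorder limit.)
+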
+ `_batch_level_generate_sequences` involves: + 1. Extracting and pre-processing prompt token IDs from the input + `prompts`. This includes handling padding and preparing raw + token ID lists. + 2. Preparing inputs for the SGLang engine, including multi-modal + data if present. + 3. Invoking the SGLang engine (`self._engine.async_generate`, + an async coroutine) with the batch of processed inputs and + specified sampling parameters on the master TP rank. + 4. Broadcasting the results from the master TP rank to all + other TP ranks. + 5. Post-processing the engine's output to format the generated + token IDs and (if applicable) log probabilities. + 6. Constructing the final sequences by concatenating original + prompts with the generated responses. + 7. Updating attention masks and position IDs to reflect the full + concatenated sequences. + 8. If `self.config.free_cache_engine` is true, the SGLang engine's + KV cache is flushed after generation on the master TP rank. + Args: + prompts: A `DataProto` object containing the batch of + input prompts, including tensor data (like `input_ids`, + `attention_mask`) and meta-information (like `eos_token_id`, + `do_sample`). + **kwargs: Additional keyword arguments that can override the + default sampling parameters (e.g., `temperature`, `top_p`, + `max_new_tokens`). These are temporarily applied using + `update_sampling_params`. + Returns: + DataProto: A `DataProto` object containing the batch of + generated sequences. This includes tensors for `prompts` + (original input IDs), `responses` (generated token IDs), + `input_ids` (concatenated prompt and response), + `attention_mask`, and `position_ids` for the full + sequences. + Note that in GRPO, if the prompts are validated, we repeat the prompts for rollout.n times in ray_trainer. + Thus we do not need to repeat the prompts here and set the sampling parameter n to 1. 
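+
+        Compared to `_batch_level_generate_sequences`, this training-time variant also asks
+        SGLang to return per-token expert routing (`return_expert_routing=True`), pads each
+        row's routing ids to the final attention mask, stores one plasma object per row via
+        `ray.put`, and ships only small per-sample metadata (the `routing_*` keys) in
+        `non_tensor_batch`.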
+ """ + # input ids: (bs, prompt_length), left-padded + idx = prompts.batch["input_ids"] + # attention_mask: (bs, seq_length), left-padded + attention_mask = prompts.batch["attention_mask"] + position_ids = prompts.batch["position_ids"] + + # used to generate attention mask for the + # response based on EOS token position + eos_token_id = prompts.meta_info["eos_token_id"] + + batch_size = idx.size(0) + + # Extract non-tensor data + non_tensor_batch = prompts.non_tensor_batch + if "raw_prompt_ids" not in non_tensor_batch: + non_tensor_batch["raw_prompt_ids"] = np.array( + [_pre_process_inputs(self.pad_token_id, idx[i]).tolist() for i in range(batch_size)], + dtype=object, + ) + + if "multi_modal_data" in non_tensor_batch: + sglang_inputs = [] + for raw_prompt_ids, multi_modal_data in zip( + non_tensor_batch.pop("raw_prompt_ids"), + non_tensor_batch.pop("multi_modal_data"), + strict=True, + ): + sglang_inputs.append( + { + "prompt_token_ids": raw_prompt_ids, + "multi_modal_data": multi_modal_data, + "image_data": ( + multi_modal_data.get("image", None) if isinstance(multi_modal_data, dict) else None + ), + } + ) + else: + sglang_inputs = [ + {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") + ] + + for input_data in sglang_inputs: + # Ensure token IDs are lists or numpy arrays + if not isinstance(input_data["prompt_token_ids"], list | np.ndarray): + raise TypeError( + f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}" + ) + + input_data["prompt_token_ids"] = list(input_data["prompt_token_ids"]) + + # Extract token IDs and image data for SGLang Engine + idx_list = [input_data["prompt_token_ids"] for input_data in sglang_inputs] + image_list = [input_data.get("image_data", None) for input_data in sglang_inputs] + + do_sample = prompts.meta_info.get("do_sample", True) + is_validate = prompts.meta_info.get("validate", False) + + + + # --- initialize locals for all ranks --- + routing_meta = None + routing_locals = None + + # Create request-level sampling parameters + request_sampling_params = self.sampling_params.copy() + if not do_sample: + request_sampling_params.update( + { + "n": 1, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "repetition_penalty": 1.0, + "temperature": 0, + "top_p": 1, + "top_k": -1, + "ignore_eos": False, + "min_new_tokens": 0, + "max_new_tokens": self.config.response_length, + "skip_special_tokens": True, + "spaces_between_special_tokens": True, + } + ) + elif is_validate: + request_sampling_params.update( + { + "top_k": self.config.val_kwargs.top_k, + "top_p": self.config.val_kwargs.top_p, + "temperature": self.config.val_kwargs.temperature, + "n": 1, # if validate, already repeat in ray_trainer + } + ) + + # Update with any additional kwargs + request_sampling_params.update(kwargs) + # TODO cx NOTE to be deleted + # request_sampling_params.update({'stop_token_ids': [151645]}) + + if self._tp_rank == 0: + loop = asyncio.get_event_loop() + output = loop.run_until_complete( + self._engine.async_generate( + prompt=None, # because we have already convert it to prompt token id + sampling_params=request_sampling_params, + return_logprob=True, + input_ids=idx_list, + image_data=image_list, + # return_expert_routing=(not is_validate), + return_expert_routing=True, + ) + ) + + owner_node = ray.get_runtime_context().get_node_id() # stable per Ray node + + B = batch_size + # refs_per_row = [None] * B + ids_per_row = [None] * B + rid_list = [None] * B + shape_list = [None] * B + # 
dtype_ids_list = [None] * B # store dtype per row (usually 'uint16') + # TODO CRITICAL move to tensor bfloat16? + dtype_ids = np.int64 # TODO cx note CRITICAL used to be uint8. modify to int64 but still witnessed over/underflow in qwen3_routing_model.py + layout = "T_L_K" # token-major for fast row fetch (ids[t] -> [L,K]) + + for b in range(B): + # TODO please review the np types + r = output[b]["meta_info"]["moe_routing"] + del output[b]["meta_info"]["moe_routing"] + + # Required meta + rid = r["rid"] # canonical UUID string from SGLang (e.g., "0e3b1f...-...") + L = int(r["shape"]["num_layers"]) + T = int(r["shape"]["num_tokens"]) + K = int(r["shape"]["top_k"]) + rid_list[b] = rid + shape_list[b] = (T, L, K) + + # 1) Expert IDs: list [L, T, K] -> ndarray [T, L, K] (token-major), compact dtype (seq_len, n_layers, n_experts) + ids = np.asarray(r["topk_ids_of_layer"], dtype=dtype_ids) # safe upcast first + if ids.shape != (L, T, K): + raise ValueError(f"topk_ids_of_layer shape mismatch, expected {(L,T,K)}, got {ids.shape}") + + ids_TLK = np.ascontiguousarray(np.transpose(ids, (1, 0, 2))) # [T, L, K], aka [seq_len, n_layers, k_experts_selected] + + # 2) positions (optional) -> int32 contiguous + pos = r.get("positions", None) + if pos is None: + pos_arr = np.empty((0,), dtype=np.int32) + else: + pos_arr = np.ascontiguousarray(np.asarray(pos, dtype=np.int32)) + + # 3) physical_to_logical_map [L, E] (tiny) -> int16 contiguous + p2l = r.get("physical_to_logical_map", None) + if p2l is None: + p2l_arr = np.empty((0, 0), dtype=np.int64) + else: + p2l_arr = np.ascontiguousarray(np.asarray(p2l, dtype=np.int32)).astype(np.int64, copy=False) + + # Put one plasma object per row (tuple of arrays). Arrays are zero-copy in plasma. + ids_per_row[b] = (ids_TLK, pos_arr, p2l_arr) + + + # Tiny meta to broadcast (NO ObjectRefs) + routing_meta = { + "owner_node": owner_node, # <-- add this + # "owner_rank": owner_rank, # (optional/legacy; not used for training) + "rid": rid_list, # list[str], len B + "shape": shape_list, # list[tuple(T,L,K)], len B + "layout": layout, # 'T_L_K' + "dtype_ids": np.dtype(dtype_ids).name, # list[str], e.g., 'uint16' + } + + else: + output = None + + # Most naive implementation, can extract tensor and send via gloo if too slow + dist.barrier() + # dist.barrier(self._device_mesh_cpu["tp"].get_group()) + [output] = broadcast_pyobj( + data=[output], + rank=self._rank, + dist_group=self._device_mesh_cpu["tp"].get_group(), + src=self._device_mesh_cpu["tp"].mesh[0].item(), + force_cpu_device=False, + ) + out = _post_process_outputs(self.processing_class, output) + + response = out[0].to(idx.device) + rollout_log_probs = None + if self.config.calculate_log_probs: + rollout_log_probs = out[1].to(idx.device) + + if response.shape[1] < self.config.response_length: + response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) + if self.config.calculate_log_probs: + rollout_log_probs = pad_sequence_to_length( + rollout_log_probs, self.config.response_length, self.pad_token_id + ) + + seq = torch.cat([idx, response], dim=-1) + + response_length = response.size(1) + delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) + delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1) + if position_ids.dim() == 3: # qwen2vl mrope + delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1) + + # TODO(sgm): fix position_ids on right_pad + # prompt: left pad + response: right pad + # attention_mask: 
+        response_position_ids = position_ids[..., -1:] + delta_position_id
+        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
+        response_attention_mask = get_response_mask(
+            response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype
+        )
+        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
+
+        if self._tp_rank == 0:
+            # Pad the routing ids in ids_per_row to match the final attention_mask
+            # (prompt + response), then ray.put each row.
+            refs_per_row, padded_shapes = self._pad_and_put_routing_from_mask(
+                ids_per_row=ids_per_row,
+                attention_mask=attention_mask,  # final mask (prompt + response)
+                logger=logger,
+            )
+            # Ensure the meta shape reflects the padded T_full so training sees consistent lengths
+            routing_meta["shape"] = padded_shapes
+
+        # ---------------------------
+        # 3) TP sync and broadcast ONLY the tiny routing_meta (NO ObjectRefs)
+        # ---------------------------
+        dist.barrier()
+        # dist.barrier(self._device_mesh_cpu["tp"].get_group())
+        [routing_meta] = broadcast_pyobj(
+            data=[routing_meta],
+            rank=self._rank,
+            dist_group=self._device_mesh_cpu["tp"].get_group(),
+            src=self._device_mesh_cpu["tp"].mesh[0].item(),
+            force_cpu_device=False,
+        )
+
+        # All TP ranks now hold the same data; the batch below is valid on every rank.
+        batch = TensorDict(
+            {
+                "prompts": idx,
+                "responses": response,
+                "input_ids": seq,  # here input_ids become the whole sentences
+                "attention_mask": attention_mask,
+                "position_ids": position_ids,
+            },
+            batch_size=batch_size,
+        )
+        if self.config.calculate_log_probs:
+            # we will recompute old log prob with actor
+            batch["rollout_log_probs"] = rollout_log_probs
+
+        # free cache engine
+        if self._engine is not None and self._tp_rank == 0:
+            loop = asyncio.get_event_loop()
+            loop.run_until_complete(self._engine.flush_cache())
+
+        # ---------------------------
+        # 5) Build per-sample arrays for DataProto.non_tensor_batch
+        # ---------------------------
+        B = batch_size
+
+        # Make sure refs_per_row exists on non-TP0 ranks
+        if self._tp_rank != 0 or "refs_per_row" not in locals() or refs_per_row is None:
+            refs_per_row = [None] * batch_size
+
+        # Normalize the dtype string (optional; keeps it readable, like "int64").
+        dtype_ids_str = routing_meta.get("dtype_ids", "int64")
+        # If someone passed a class repr like "<class 'numpy.int64'>", keep it as-is; it's fine for logging.
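+        # Illustrative sketch (an assumption, not part of this code path): a consumer
+        # co-located with the owner node could recover one row's routing arrays from
+        # the per-sample keys built below. Here `sample` is one row of
+        # non_tensor_batch and `fetch_routing_row` is a hypothetical helper:
+        #
+        #     def fetch_routing_row(sample):
+        #         ref = sample["routing_ref"]  # ray.ObjectRef on TP-0, None elsewhere
+        #         if ref is None:
+        #             return None
+        #         ids_TLK, pos_arr, p2l_arr = ray.get(ref)  # zero-copy reads from plasma
+        #         T, L, K = sample["routing_shape"]  # padded to the final mask length
+        #         assert ids_TLK.shape == (T, L, K)
+        #         return ids_TLK, pos_arr, p2l_arr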
+
+        # owner_arr = np.full((batch_size,), routing_meta["owner_rank"], dtype=np.int32)  # TODO: watch the dtype/size
+        owner_node_arr = np.empty((B,), dtype=object)
+        owner_node_arr.fill(routing_meta["owner_node"])
+
+        rid_arr = np.asarray(routing_meta["rid"], dtype=object)
+
+        shape_arr = np.empty((batch_size,), dtype=object)
+        shape_arr[:] = [tuple(s) for s in routing_meta["shape"]]
+
+        layout_arr = np.empty((batch_size,), dtype=object)
+        layout_arr.fill(routing_meta.get("layout", "T_L_K"))
+
+        dtype_ids_arr = np.empty((batch_size,), dtype=object)
+        dtype_ids_arr.fill(dtype_ids_str)
+
+        ref_arr = np.empty((batch_size,), dtype=object)
+        if self._tp_rank == 0:
+            for i in range(batch_size):
+                ref_arr[i] = refs_per_row[i]
+        else:
+            ref_arr.fill(None)
+
+        # TODO: unused while rows are stored as individual objects keyed by a unique rid;
+        # kept to future-proof a switch to one large tensor.
+        b_idx_arr = np.arange(batch_size, dtype=np.int64)
+
+        # non_tensor_batch["routing_owner_rank"] = owner_arr
+        non_tensor_batch["routing_owner_node"] = owner_node_arr
+        non_tensor_batch["routing_rid"] = rid_arr
+        non_tensor_batch["routing_shape"] = shape_arr  # (T, L, K) per row
+        non_tensor_batch["routing_layout"] = layout_arr  # 'T_L_K'
+        non_tensor_batch["routing_dtype_ids"] = dtype_ids_arr  # e.g., 'int64'
+        non_tensor_batch["routing_ref"] = ref_arr  # ObjectRef on TP-0; None elsewhere
+        non_tensor_batch["routing_b_idx"] = b_idx_arr
+
+        payload = DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
+        return payload
+
+    async def _async_rollout_a_request(
+        self,
+        req: AsyncRolloutRequest,