From a8027730aa042341803a399685a77c69f2427d16 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 4 Mar 2026 01:58:44 +0000 Subject: [PATCH 01/31] previous code Co-authored-by: Dev Patel --- .../skyrl_train/inference_engines/base.py | 5 ++ .../inference_engines/vllm/vllm_engine.py | 20 +++++- skyrl/backends/skyrl_train/training_batch.py | 1 + .../skyrl_train/utils/replay_utils.py | 62 +++++++++++++++++++ .../workers/megatron/megatron_worker.py | 11 ++++ skyrl/train/config/config.py | 1 + .../train/config/megatron_config/policy.yaml | 1 + skyrl/train/config/ppo_base_config.yaml | 1 + skyrl/train/dataset/preprocess.py | 10 ++- skyrl/train/entrypoints/main_base.py | 1 + skyrl/train/generators/base.py | 1 + skyrl/train/generators/skyrl_gym_generator.py | 5 ++ skyrl/train/trainer.py | 4 ++ 13 files changed, 120 insertions(+), 3 deletions(-) create mode 100644 skyrl/backends/skyrl_train/utils/replay_utils.py diff --git a/skyrl/backends/skyrl_train/inference_engines/base.py b/skyrl/backends/skyrl_train/inference_engines/base.py index 4b073da5a0..c3595ab7a9 100644 --- a/skyrl/backends/skyrl_train/inference_engines/base.py +++ b/skyrl/backends/skyrl_train/inference_engines/base.py @@ -29,6 +29,7 @@ class InferenceEngineOutput(TypedDict): response_ids: List[List[int]] stop_reasons: List[str] response_logprobs: Optional[List[List[float]]] + rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] class InferenceEngineInterface(ABC): @@ -63,6 +64,7 @@ async def sample( all_responses = [] all_stop_reasons = [] all_response_logprobs = [] + all_rollout_inference_indices = [] for _ in range(num_samples): input_batch: InferenceEngineInput = { @@ -79,12 +81,15 @@ async def sample( all_stop_reasons.append(output["stop_reasons"][0]) if output.get("response_logprobs") is not None: all_response_logprobs.append(output["response_logprobs"][0]) + if output.get("rollout_inference_indices") is not None: + all_rollout_inference_indices.append(output["rollout_inference_indices"][0]) return { "response_ids": all_response_ids, "responses": all_responses, "stop_reasons": all_stop_reasons, "response_logprobs": all_response_logprobs if all_response_logprobs else None, + "rollout_inference_indices": all_rollout_inference_indices if all_rollout_inference_indices else None, } @abstractmethod diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py index 1123088cb6..a0fa9ff4c6 100644 --- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -135,6 +135,7 @@ def _postprocess_outputs(self, outputs): stop_reasons: List[str] = [] response_ids: List[List[int]] = [] response_logprobs: Optional[List[List[float]]] = [] + rollout_inference_indices: Optional[List[List[List[List[int]]]]] = [] for output in outputs: # TODO(tgriggs): Support n>1 sampling. 
@@ -156,14 +157,25 @@ def _postprocess_outputs(self, outputs): del token_logprobs response_logprobs.append(_logprobs) + if resp.routed_experts is not None: + if hasattr(resp.routed_experts, "tolist"): + routed_experts_list = resp.routed_experts.tolist() + else: + routed_experts_list = resp.routed_experts + rollout_inference_indices.append(routed_experts_list) + if len(response_logprobs) and response_logprobs[0] is None: response_logprobs = None # hack: assume uniform sampling params + if len(rollout_inference_indices) == 0: + rollout_inference_indices = None + return InferenceEngineOutput( responses=responses, stop_reasons=stop_reasons, response_ids=response_ids, response_logprobs=response_logprobs, + rollout_inference_indices=rollout_inference_indices, ) def _get_engine(self): @@ -320,7 +332,13 @@ def _create_engine(self, *args, **kwargs): enable_ray_prometheus_stats = kwargs.pop("enable_ray_prometheus_stats", False) enable_log_requests = kwargs.pop("enable_log_requests", False) max_log_len = kwargs.pop("max_log_len", None) - + + # Log if enable_return_routed_experts is being passed + if "enable_return_routed_experts" in kwargs: + logger.info(f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs") + else: + logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs") + if version.parse(vllm.__version__) >= version.parse("0.10.0"): engine_args = vllm.AsyncEngineArgs(enable_log_requests=enable_log_requests, **kwargs) else: diff --git a/skyrl/backends/skyrl_train/training_batch.py b/skyrl/backends/skyrl_train/training_batch.py index 839508aa3c..8b00aeaa92 100644 --- a/skyrl/backends/skyrl_train/training_batch.py +++ b/skyrl/backends/skyrl_train/training_batch.py @@ -369,6 +369,7 @@ class TrainingInput(TypedDict, total=False): kl: Float[torch.Tensor, "batch_size seq_len"] rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]] rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]] + rollout_inference_indices: Optional[Integer[torch.Tensor, "batch_size seq_len layer_num topk"]] class TrainingInputBatch(TensorBatch[TrainingInput]): diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py new file mode 100644 index 0000000000..d315779c07 --- /dev/null +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -0,0 +1,62 @@ +""" +Utility functions for MoE Router Replay. +""" + +import torch +from typing import Optional, List +from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch + +def _split_replay_indices(rollout_inference_indices: torch.Tensor) -> List[torch.Tensor]: + if rollout_inference_indices is None: + return None + if rollout_inference_indices.dim() != 4: + raise ValueError(f"Expected 4D replay indices, got shape {rollout_inference_indices.shape}") + per_layer = rollout_inference_indices.permute(2, 0, 1, 3).contiguous() + return [per_layer[i] for i in range(per_layer.shape[0])] + +def setup_router_replay_forward(data: TrainingInputBatch, enable_router_replay: bool) -> bool: + """ + Set up router replay for forward pass (ref/policy inference). 
+ """ + if not enable_router_replay: + return False + + rollout_inference_indices = data.get("rollout_inference_indices") + if rollout_inference_indices is None: + return False + + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction + + RouterReplay.set_replay_data(_split_replay_indices(rollout_inference_indices)) + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + return True + + +def setup_router_replay_backward(data: TrainingInputBatch, enable_router_replay: bool) -> bool: + """ + Set up router replay for training forward/backward pass. + """ + if not enable_router_replay: + return False + + rollout_inference_indices = data.get("rollout_inference_indices") + if rollout_inference_indices is None: + return False + + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction + + RouterReplay.set_replay_data(_split_replay_indices(rollout_inference_indices)) + # Use REPLAY_FORWARD - Megatron handles REPLAY_BACKWARD automatically + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + return True + + +def clear_router_replay(): + """Clear all router replay state.""" + from megatron.core.transformer.moe.router_replay import RouterReplay + + RouterReplay.clear_global_indices() + RouterReplay.clear_global_router_replay_action() + RouterReplay.clear_global_router_replay_instances() diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index e80ad1c85a..850a429f42 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -322,6 +322,7 @@ def init_configs( self.strategy.hf_config = hf_config self.tokenizer = tokenizer + self.enable_router_replay = transformer_config_kwargs.get("moe_enable_routing_replay", False) def configure_lora(self, lora_config, lora_type: Optional[str] = "lora"): if lora_type == "lora": @@ -401,6 +402,10 @@ def forward(self, data: TrainingInputBatch): """ Override `Worker.forward` to support passing the full mini batch to the MegatronModelWrapper.forward method. 
""" + from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay + + setup_router_replay_forward(data, self.enable_router_replay) + # Run in micro batches grouped into a single mini-batch micro_bsz = self.cfg.micro_forward_batch_size_per_gpu micro_batches = data.chunk(micro_bsz) @@ -438,6 +443,7 @@ def forward(self, data: TrainingInputBatch): log_probs = log_probs.to("cpu") output = TrainingOutputBatch({"output": log_probs}) output.metadata = data.metadata + clear_router_replay() return output def save_hf_model(self, export_dir: str, tokenizer): @@ -593,6 +599,9 @@ def forward_backward( Returns: Aggregated metrics dict across all micro batches """ + from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay + + setup_router_replay_forward(data, self.enable_router_replay) self.model.train() for chunk in self.actor_module: # if use distributed optimizer, zero grad buffer will be handled by optimizer @@ -665,6 +674,8 @@ def forward_backward( # Add loss_fn_outputs back (not reduced, kept as list) if all_loss_fn_outputs: status["loss_fn_outputs"] = all_loss_fn_outputs + + clear_router_replay() return status diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index effaa7c6cc..4c4f749670 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -440,6 +440,7 @@ class InferenceEngineConfig(BaseConfig): """Sets ``VLLM_ENABLE_V1_MULTIPROCESSING=0`` for reproducibility.""" enable_prefix_caching: bool = True enable_chunked_prefill: bool = True + enable_return_routed_experts: bool = False max_num_batched_tokens: int = 8192 enforce_eager: bool = True """Disable CUDA graphs for stability. Set to ``False`` for higher performance, diff --git a/skyrl/train/config/megatron_config/policy.yaml b/skyrl/train/config/megatron_config/policy.yaml index d6ef8382f9..b641d1f706 100644 --- a/skyrl/train/config/megatron_config/policy.yaml +++ b/skyrl/train/config/megatron_config/policy.yaml @@ -57,6 +57,7 @@ transformer_config_kwargs: recompute_modules: ["core_attn"] recompute_method: uniform recompute_num_layers: 1 + moe_enable_routing_replay: ${generator.enable_return_routed_experts} # flag to manually empty torch's cuda cache between the forward/backward pass and the optimizer step # this will free reserved but unallocated memory, and can help avoid OoMs in the optimizer diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml index 8f4184b18d..9772eb5e9b 100644 --- a/skyrl/train/config/ppo_base_config.yaml +++ b/skyrl/train/config/ppo_base_config.yaml @@ -286,6 +286,7 @@ generator: vllm_v1_disable_multiproc: true enable_prefix_caching: true enable_chunked_prefill: true + enable_return_routed_experts: false max_num_batched_tokens: 8192 # Disable CUDA graphs by default for stability. Set to false for higher performance, but this may affect convergence for long-running and/or long context training jobs. 
enforce_eager: true diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index a17cea1b91..7c360e5de0 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -1,7 +1,7 @@ from typing import List, Tuple, Optional import torch from transformers import AutoTokenizer -from jaxtyping import Float +from jaxtyping import Float, Integer def _verify_inputs( @@ -32,6 +32,7 @@ def convert_prompts_responses_to_batch_tensors( rewards: List[List[float]], loss_masks: List[List[int]], logprobs: Optional[List[List[float]]] = None, + rollout_inference_indices: Optional[List[List[List[List[int]]]]] = None, ) -> Tuple[ Float[torch.Tensor, "batch seq_len"], Float[torch.Tensor, "batch seq_len"], @@ -39,6 +40,7 @@ def convert_prompts_responses_to_batch_tensors( Float[torch.Tensor, "batch response_len"], Float[torch.Tensor, "batch response_len"], Optional[Float[torch.Tensor, "batch response_len"]], + Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]], ]: """ Convert prompts and responses to batch tensors for training. @@ -129,4 +131,8 @@ def convert_prompts_responses_to_batch_tensors( ] logprobs_tensor = torch.tensor(padded_logprobs, dtype=torch.float) - return sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, logprobs_tensor + rollout_inference_indices_tensor = None + if rollout_inference_indices: + rollout_inference_indices_tensor = torch.tensor(rollout_inference_indices, dtype=torch.int32) + + return sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, logprobs_tensor, rollout_inference_indices_tensor diff --git a/skyrl/train/entrypoints/main_base.py b/skyrl/train/entrypoints/main_base.py index 10e78e9de3..e3c2b980f1 100644 --- a/skyrl/train/entrypoints/main_base.py +++ b/skyrl/train/entrypoints/main_base.py @@ -68,6 +68,7 @@ def create_ray_wrapped_inference_engines_from_config( "backend": ie_cfg.backend, "engine_init_kwargs": ie_cfg.engine_init_kwargs, "enable_ray_prometheus_stats": ie_cfg.enable_ray_prometheus_stats, + "enable_return_routed_experts": ie_cfg.enable_return_routed_experts, } # Conditionally add LoRA parameters if LoRA is enabled diff --git a/skyrl/train/generators/base.py b/skyrl/train/generators/base.py index fabe5524e8..02150a80c4 100644 --- a/skyrl/train/generators/base.py +++ b/skyrl/train/generators/base.py @@ -39,6 +39,7 @@ class GeneratorOutput(TypedDict): rollout_metrics: Optional[Dict[str, Any]] rollout_logprobs: Optional[List[List[float]]] trajectory_ids: Optional[List[TrajectoryID]] + rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [batch_size, seq_len, layer_num, topk] # Applicable only for step-wise training is_last_step: Optional[List[bool]] diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 35c7e189f3..3e2fa6d344 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -66,6 +66,7 @@ class TurnOutput: output_logprobs: Optional[List[float]] new_obs: ConversationType obs_ids: List[int] + rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] reward: Optional[float] added_eos: bool = False @@ -300,11 +301,14 @@ async def agent_loop( output_ids = engine_output["response_ids"][0] stop_reason = engine_output["stop_reasons"][0] response_logprobs = engine_output.get("response_logprobs", None) + rollout_inference_indices = engine_output.get("rollout_inference_indices", None) if response_logprobs is not 
None: response_logprobs = response_logprobs[0] if self.custom_chat_template is not None: raise ValueError("Response Logprobs bookkeeping is not supported with custom chat template") + if rollout_inference_indices is not None: + rollout_inference_indices = rollout_inference_indices[0] # Append eos when sampling_params.stop is not None. Does not affect 3.a as chat templates add eos_token. # sampling_params is not None for eval, but None for training (which uses engine.sampling_params which are from cfg) stop_strs = current_sampling_params.get("stop", None) @@ -348,6 +352,7 @@ async def agent_loop( reward=step_reward, obs_ids=obs_ids, added_eos=added_eos, + rollout_inference_indices=rollout_inference_indices, ) if is_step_wise: diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index ada3afbc5c..8c966d8749 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -604,6 +604,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis loss_masks: List[List[int]] = generator_output["loss_masks"] logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None) + rollout_inference_indices: Optional[List[List[List[List[int]]]]] = generator_output.get("rollout_inference_indices", None) ( sequences_tensor, @@ -612,6 +613,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis rewards_tensor, loss_masks_tensor, rollout_logprobs_tensor, + rollout_inference_indices_tensor, ) = convert_prompts_responses_to_batch_tensors( self.tokenizer, prompt_ids, @@ -619,6 +621,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis rewards, loss_masks, logprobs, + rollout_inference_indices, ) # sanity check for off_policy_correction @@ -639,6 +642,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis "rewards": rewards_tensor, "loss_mask": loss_masks_tensor, "rollout_logprobs": rollout_logprobs_tensor, + "rollout_inference_indices": rollout_inference_indices_tensor, "is_last_step": ( torch.tensor(generator_output["is_last_step"], dtype=torch.bool) if generator_output.get("is_last_step", None) is not None From 21b44dedbc639c8e3437b4529defdd636ddd152b Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 4 Mar 2026 02:05:28 +0000 Subject: [PATCH 02/31] lint --- .../skyrl_train/inference_engines/base.py | 2 +- .../inference_engines/vllm/vllm_engine.py | 10 +++++---- .../skyrl_train/utils/replay_utils.py | 22 ++++++++++--------- .../workers/megatron/megatron_worker.py | 6 ++--- skyrl/train/dataset/preprocess.py | 10 ++++++++- skyrl/train/generators/base.py | 2 +- skyrl/train/generators/skyrl_gym_generator.py | 2 +- skyrl/train/trainer.py | 4 +++- 8 files changed, 36 insertions(+), 22 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_engines/base.py b/skyrl/backends/skyrl_train/inference_engines/base.py index c3595ab7a9..ddecc965fb 100644 --- a/skyrl/backends/skyrl_train/inference_engines/base.py +++ b/skyrl/backends/skyrl_train/inference_engines/base.py @@ -29,7 +29,7 @@ class InferenceEngineOutput(TypedDict): response_ids: List[List[int]] stop_reasons: List[str] response_logprobs: Optional[List[List[float]]] - rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] + rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] class InferenceEngineInterface(ABC): diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py 
b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py index a0fa9ff4c6..66d4213f1c 100644 --- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -332,13 +332,15 @@ def _create_engine(self, *args, **kwargs): enable_ray_prometheus_stats = kwargs.pop("enable_ray_prometheus_stats", False) enable_log_requests = kwargs.pop("enable_log_requests", False) max_log_len = kwargs.pop("max_log_len", None) - - # Log if enable_return_routed_experts is being passed + + # Log if enable_return_routed_experts is being passed if "enable_return_routed_experts" in kwargs: - logger.info(f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs") + logger.info( + f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs" + ) else: logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs") - + if version.parse(vllm.__version__) >= version.parse("0.10.0"): engine_args = vllm.AsyncEngineArgs(enable_log_requests=enable_log_requests, **kwargs) else: diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index d315779c07..9e3bac18a5 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -3,9 +3,10 @@ """ import torch -from typing import Optional, List +from typing import List from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch + def _split_replay_indices(rollout_inference_indices: torch.Tensor) -> List[torch.Tensor]: if rollout_inference_indices is None: return None @@ -14,22 +15,23 @@ def _split_replay_indices(rollout_inference_indices: torch.Tensor) -> List[torch per_layer = rollout_inference_indices.permute(2, 0, 1, 3).contiguous() return [per_layer[i] for i in range(per_layer.shape[0])] + def setup_router_replay_forward(data: TrainingInputBatch, enable_router_replay: bool) -> bool: """ Set up router replay for forward pass (ref/policy inference). 
""" if not enable_router_replay: return False - + rollout_inference_indices = data.get("rollout_inference_indices") if rollout_inference_indices is None: return False - + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction - + RouterReplay.set_replay_data(_split_replay_indices(rollout_inference_indices)) RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) - + return True @@ -39,24 +41,24 @@ def setup_router_replay_backward(data: TrainingInputBatch, enable_router_replay: """ if not enable_router_replay: return False - + rollout_inference_indices = data.get("rollout_inference_indices") if rollout_inference_indices is None: return False - + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction - + RouterReplay.set_replay_data(_split_replay_indices(rollout_inference_indices)) # Use REPLAY_FORWARD - Megatron handles REPLAY_BACKWARD automatically RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) - + return True def clear_router_replay(): """Clear all router replay state.""" from megatron.core.transformer.moe.router_replay import RouterReplay - + RouterReplay.clear_global_indices() RouterReplay.clear_global_router_replay_action() RouterReplay.clear_global_router_replay_instances() diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 850a429f42..3480f0ef27 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -403,7 +403,7 @@ def forward(self, data: TrainingInputBatch): Override `Worker.forward` to support passing the full mini batch to the MegatronModelWrapper.forward method. 
""" from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay - + setup_router_replay_forward(data, self.enable_router_replay) # Run in micro batches grouped into a single mini-batch @@ -600,7 +600,7 @@ def forward_backward( Aggregated metrics dict across all micro batches """ from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay - + setup_router_replay_forward(data, self.enable_router_replay) self.model.train() for chunk in self.actor_module: @@ -674,7 +674,7 @@ def forward_backward( # Add loss_fn_outputs back (not reduced, kept as list) if all_loss_fn_outputs: status["loss_fn_outputs"] = all_loss_fn_outputs - + clear_router_replay() return status diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index 7c360e5de0..01b332a625 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -135,4 +135,12 @@ def convert_prompts_responses_to_batch_tensors( if rollout_inference_indices: rollout_inference_indices_tensor = torch.tensor(rollout_inference_indices, dtype=torch.int32) - return sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, logprobs_tensor, rollout_inference_indices_tensor + return ( + sequences, + attention_mask, + action_mask, + ret_rewards, + ret_loss_masks, + logprobs_tensor, + rollout_inference_indices_tensor, + ) diff --git a/skyrl/train/generators/base.py b/skyrl/train/generators/base.py index 02150a80c4..530fbbbe84 100644 --- a/skyrl/train/generators/base.py +++ b/skyrl/train/generators/base.py @@ -39,7 +39,7 @@ class GeneratorOutput(TypedDict): rollout_metrics: Optional[Dict[str, Any]] rollout_logprobs: Optional[List[List[float]]] trajectory_ids: Optional[List[TrajectoryID]] - rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [batch_size, seq_len, layer_num, topk] + rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [batch_size, seq_len, layer_num, topk] # Applicable only for step-wise training is_last_step: Optional[List[bool]] diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 3e2fa6d344..8f650418b3 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -66,7 +66,7 @@ class TurnOutput: output_logprobs: Optional[List[float]] new_obs: ConversationType obs_ids: List[int] - rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] + rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] reward: Optional[float] added_eos: bool = False diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 8c966d8749..f842868816 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -604,7 +604,9 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis loss_masks: List[List[int]] = generator_output["loss_masks"] logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None) - rollout_inference_indices: Optional[List[List[List[List[int]]]]] = generator_output.get("rollout_inference_indices", None) + rollout_inference_indices: Optional[List[List[List[List[int]]]]] = generator_output.get( + "rollout_inference_indices", None + ) ( sequences_tensor, From 23fdc45972c2a20f6af56021553c2d0d260f9417 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 4 Mar 2026 02:25:35 +0000 Subject: [PATCH 03/31] make opus take a pass at test + plumbing fully thru 
generator --- skyrl/train/generators/skyrl_gym_generator.py | 77 ++++++++- .../gpu/gpu_ci/test_router_replay.py | 163 ++++++++++++++++++ 2 files changed, 239 insertions(+), 1 deletion(-) create mode 100644 tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 8f650418b3..67bc85a720 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -40,6 +40,7 @@ class TrajectoryOutput: prompt_ids: List[int] rollout_logprobs: Optional[List[float]] env_metrics: Dict[str, Any] + rollout_inference_indices: Optional[List[List[List[int]]]] = None @dataclass @@ -57,6 +58,7 @@ class AgentLoopState: rollout_logprobs: Optional[List[float]] response_end_idx: Optional[int] done: bool + rollout_inference_indices: Optional[List[List[List[int]]]] = None @dataclass @@ -66,10 +68,31 @@ class TurnOutput: output_logprobs: Optional[List[float]] new_obs: ConversationType obs_ids: List[int] - rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] + rollout_inference_indices: Optional[List[List[List[int]]]] # [seq_len, layer_num, topk] reward: Optional[float] added_eos: bool = False + def get_turn_rollout_inference_indices(self) -> Optional[List[List[List[int]]]]: + """ + Get rollout inference indices for this turn's tokens (output + observation). + + Returns indices for generated output tokens, with padding entries (all -1) + for any manually-added EOS token and observation tokens. + Returns None if rollout_inference_indices is None. + """ + if self.rollout_inference_indices is None: + return None + if not self.rollout_inference_indices: + return self.rollout_inference_indices + layer_num = len(self.rollout_inference_indices[0]) + topk = len(self.rollout_inference_indices[0][0]) if layer_num > 0 else 0 + pad_entry = [[-1] * topk for _ in range(layer_num)] + indices = list(self.rollout_inference_indices) + if self.added_eos: + indices.append(pad_entry) + indices.extend(pad_entry for _ in range(len(self.obs_ids))) + return indices + def get_turn_loss_mask(self) -> List[int]: """ Get loss mask for this turn's tokens. @@ -355,6 +378,9 @@ async def agent_loop( rollout_inference_indices=rollout_inference_indices, ) + if turn_output.rollout_inference_indices is not None and agent_loop_state.rollout_inference_indices is None: + agent_loop_state.rollout_inference_indices = [] + if is_step_wise: # current response + observation ids turn_response_ids = turn_output.output_ids + turn_output.obs_ids @@ -372,6 +398,7 @@ async def agent_loop( rollout_logprobs=turn_response_logprobs, stop_reason=stop_reason, env_metrics=env.get_metrics() if agent_loop_state.done else {}, + rollout_inference_indices=turn_output.get_turn_rollout_inference_indices(), ) agent_loop_output.step_outputs.append(per_step_output) @@ -400,6 +427,7 @@ async def agent_loop( prompt_ids = agent_loop_state.input_ids[:initial_prompt_length] rollout_logprobs = None + rollout_inference_indices_out = None response_ids = None # Prepare the final loss_mask, response_ids and rollout_logprobs . 
@@ -430,6 +458,10 @@ async def agent_loop( rollout_logprobs = agent_loop_state.rollout_logprobs[ : agent_loop_state.response_end_idx - initial_prompt_length + 1 ] + if agent_loop_state.rollout_inference_indices is not None: + rollout_inference_indices_out = agent_loop_state.rollout_inference_indices[ + : agent_loop_state.response_end_idx - initial_prompt_length + 1 + ] # fix index for per_step_rewards per_step_rewards = [(reward, idx - initial_prompt_length) for reward, idx in per_step_rewards] assert len(loss_mask) == len( @@ -446,6 +478,10 @@ async def agent_loop( loss_mask.append(1) if rollout_logprobs is not None: rollout_logprobs.append(0.0) + if rollout_inference_indices_out is not None and rollout_inference_indices_out: + layer_num = len(rollout_inference_indices_out[0]) + topk = len(rollout_inference_indices_out[0][0]) if layer_num > 0 else 0 + rollout_inference_indices_out.append([[-1] * topk for _ in range(layer_num)]) appended_eos_token = True if self.generator_cfg.step_wise_trajectories: @@ -465,6 +501,7 @@ async def agent_loop( prompt_ids=prompt_ids, rollout_logprobs=rollout_logprobs, env_metrics=env_metrics, + rollout_inference_indices=rollout_inference_indices_out, ) return agent_loop_output @@ -616,12 +653,14 @@ async def generate_batched( responses = engine_output["response_ids"] stop_reasons = engine_output["stop_reasons"] logprobs = engine_output.get("response_logprobs", None) + raw_rollout_inference_indices = engine_output.get("rollout_inference_indices", None) truncated_responses = [] rewards = [] loss_masks = [] env_metrics = [] truncated_logprobs: Optional[List[List[float]]] = [] if logprobs is not None else None + truncated_indices: Optional[List] = [] if raw_rollout_inference_indices is not None else None for i, (output, response, env, env_class) in enumerate(zip(outputs, responses, envs, env_classes)): # step on environment and compute reward @@ -636,6 +675,9 @@ async def generate_batched( if logprobs is not None: sample_logprobs = logprobs[i][: len(response)] truncated_logprobs.append(sample_logprobs) + if raw_rollout_inference_indices is not None: + sample_indices = raw_rollout_inference_indices[i][: len(response)] + truncated_indices.append(sample_indices) # Get environment-specific metrics env_metrics.append(env.get_metrics()) @@ -655,6 +697,7 @@ async def generate_batched( "stop_reasons": stop_reasons, "rollout_metrics": rollout_metrics, "rollout_logprobs": truncated_logprobs, + "rollout_inference_indices": truncated_indices, } return generator_output @@ -755,6 +798,18 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False else: rollout_logprobs = None + if self.generator_cfg.step_wise_trajectories: + all_indices = sum( + [ + [step_output.rollout_inference_indices for step_output in output.step_outputs] + for output in all_outputs + ], + [], + ) + else: + all_indices = [output.rollout_inference_indices for output in all_outputs] + rollout_inference_indices = all_indices if any(idx is not None for idx in all_indices) else None + rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes) if self.generator_cfg.zero_reward_on_non_stop: @@ -773,6 +828,7 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False "rollout_metrics": rollout_metrics, "rollout_logprobs": rollout_logprobs, "trajectory_ids": out_trajectory_ids, + "rollout_inference_indices": rollout_inference_indices, "is_last_step": is_last_step, } @@ -840,6 +896,8 @@ def _update_agent_state_by_retokenizing_chat_history( 
agent_loop_state.response_end_idx = None # `logprobs` are not computed because retokenizing breaks token-in-token-out agent_loop_state.rollout_logprobs = None + # indices are not meaningful when retokenizing + agent_loop_state.rollout_inference_indices = None return agent_loop_state def _update_agent_loop_state_with_multiturn_chat_template( @@ -891,6 +949,8 @@ def _update_agent_loop_state_with_multiturn_chat_template( loss_mask_for_turn = turn_output.get_turn_loss_mask() rollout_logprobs_for_turn = turn_output.get_turn_rollout_logprobs() + rollout_inference_indices_for_turn = turn_output.get_turn_rollout_inference_indices() + if self.generator_cfg.step_wise_trajectories: # cumulative input_ids is not tracked for step wise training agent_loop_state.response_end_idx = len(turn_output.output_ids) - 1 @@ -905,6 +965,11 @@ def _update_agent_loop_state_with_multiturn_chat_template( agent_loop_state.loss_mask += loss_mask_for_turn if agent_loop_state.rollout_logprobs is not None and rollout_logprobs_for_turn is not None: agent_loop_state.rollout_logprobs += rollout_logprobs_for_turn + if ( + agent_loop_state.rollout_inference_indices is not None + and rollout_inference_indices_for_turn is not None + ): + agent_loop_state.rollout_inference_indices += rollout_inference_indices_for_turn return agent_loop_state @@ -969,11 +1034,21 @@ def _update_agent_loop_state_with_singleturn_chat_template( obs_ids_to_add ) + rollout_inference_indices_for_turn = None + if turn_output.rollout_inference_indices is not None and turn_output.rollout_inference_indices: + layer_num = len(turn_output.rollout_inference_indices[0]) + topk = len(turn_output.rollout_inference_indices[0][0]) if layer_num > 0 else 0 + pad_entry = [[-1] * topk for _ in range(layer_num)] + rollout_inference_indices_for_turn = list(turn_output.rollout_inference_indices[: len(new_resp_tokens)]) + rollout_inference_indices_for_turn.extend(pad_entry for _ in range(len(obs_ids_to_add))) + # Directly append turn output agent_loop_state.response_end_idx = len(agent_loop_state.input_ids) + len(new_resp_tokens) - 1 agent_loop_state.input_ids += turn_ids agent_loop_state.loss_mask += loss_mask_for_turn if agent_loop_state.rollout_logprobs is not None and rollout_logprobs_for_turn is not None: agent_loop_state.rollout_logprobs += rollout_logprobs_for_turn + if agent_loop_state.rollout_inference_indices is not None and rollout_inference_indices_for_turn is not None: + agent_loop_state.rollout_inference_indices += rollout_inference_indices_for_turn return agent_loop_state diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py new file mode 100644 index 0000000000..eedbc61282 --- /dev/null +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -0,0 +1,163 @@ +""" +Run with: +uv run --isolated --extra dev --extra megatron -- pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +""" + +import ray +import pytest +import asyncio +from transformers import AutoTokenizer +from tests.backends.skyrl_train.gpu.utils import ( + InferenceEngineState, + get_test_generator_input, + Timer, +) +from skyrl.train.utils.utils import validate_cfg +from skyrl.train.config import ( + SkyRLTrainConfig, + SamplingParams, +) +from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator +from skyrl.train.generators.base import GeneratorInput +from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend + +MOE_MODEL_NAME = 
"Qwen/Qwen3-30B-A3B" + + +def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: + cfg = SkyRLTrainConfig() + cfg.trainer.policy.model.path = model_name + cfg.trainer.micro_forward_batch_size_per_gpu = 2 + cfg.trainer.micro_train_batch_size_per_gpu = 2 + cfg.trainer.use_sample_packing = False + cfg.trainer.logger = "console" + + validate_cfg(cfg) + + return cfg + + +@pytest.mark.megatron +def test_megatron_router_replay(ray_init_fixture): + """ + Test that SkyRLGymGenerator returns rollout_inference_indices + for MoE models with enable_return_routed_experts=True. + """ + try: + cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) + cfg.trainer.strategy = "megatron" + cfg.generator.inference_engine.enable_return_routed_experts = True + cfg.generator.inference_engine.tensor_parallel_size = 2 + cfg.generator.sampling_params = SamplingParams( + max_generate_length=64, + logprobs=1, + temperature=1.0, + ) + cfg.generator.batched = False + cfg.generator.max_turns = 1 + cfg.generator.use_conversation_multi_turn = True + cfg.generator.apply_overlong_filtering = False + cfg.generator.zero_reward_on_non_stop = False + + num_prompts = 2 + + tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) + + with InferenceEngineState.create( + cfg=cfg, + model=MOE_MODEL_NAME, + use_local=True, + backend="vllm", + sleep_level=1, + gpu_memory_utilization=0.8, + ) as engines: + client = engines.client + + asyncio.run(client.wake_up()) + + generator = SkyRLGymGenerator( + generator_cfg=cfg.generator, + skyrl_gym_cfg=cfg.environment.skyrl_gym, + inference_engine_client=client, + tokenizer=tokenizer, + ) + + input_batch: GeneratorInput = get_test_generator_input( + model=MOE_MODEL_NAME, + num_prompts=num_prompts, + n_samples_per_prompt=1, + max_prompt_length=512, + env_class="gsm8k", + ) + input_batch["sampling_params"] = get_sampling_params_for_backend( + "vllm", + SamplingParams( + temperature=1.0, + top_p=1.0, + top_k=-1, + max_generate_length=64, + min_p=0.0, + logprobs=1, + ), + ) + + with Timer("generate_with_router_replay"): + generator_output = asyncio.run(generator.generate(input_batch)) + + # --- Basic output checks --- + assert ( + "rollout_inference_indices" in generator_output + ), "rollout_inference_indices missing from GeneratorOutput" + indices = generator_output["rollout_inference_indices"] + assert ( + indices is not None + ), "rollout_inference_indices should not be None when enable_return_routed_experts=True" + + responses = generator_output["response_ids"] + assert len(indices) == len( + responses + ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" + + # --- Shape & value validation per sample --- + for i, (sample_indices, sample_response) in enumerate(zip(indices, responses)): + response_len = len(sample_response) + assert ( + len(sample_indices) == response_len + ), f"Sample {i}: indices length {len(sample_indices)} != response length {response_len}" + + if response_len == 0: + continue + + # Each token position should have [layer_num, topk] structure + layer_num = len(sample_indices[0]) + assert layer_num > 0, f"Sample {i}: expected > 0 MoE layers, got {layer_num}" + + topk = len(sample_indices[0][0]) + assert topk > 0, f"Sample {i}: expected topk > 0, got {topk}" + + for t, token_indices in enumerate(sample_indices): + assert ( + len(token_indices) == layer_num + ), f"Sample {i}, token {t}: expected {layer_num} layers, got {len(token_indices)}" + for l_idx, layer_indices in enumerate(token_indices): + assert ( + 
len(layer_indices) == topk + ), f"Sample {i}, token {t}, layer {l_idx}: expected topk={topk}, got {len(layer_indices)}" + for k, expert_id in enumerate(layer_indices): + assert isinstance(expert_id, int), ( + f"Sample {i}, token {t}, layer {l_idx}, k {k}: " + f"expected int expert id, got {type(expert_id)}" + ) + assert expert_id >= 0, ( + f"Sample {i}, token {t}, layer {l_idx}, k {k}: " + f"expected non-negative expert id, got {expert_id}" + ) + + print("Router replay test passed:") + print(f" Batch size: {len(indices)}") + print(f" Response lengths: {[len(r) for r in responses]}") + if indices and indices[0]: + print(f" Layers: {len(indices[0][0])}, TopK: {len(indices[0][0][0])}") + + finally: + ray.shutdown() From 8daac59a53d0a244cf322dc64c45b7798ebbd2a5 Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Wed, 4 Mar 2026 10:31:59 +0000 Subject: [PATCH 04/31] updated test utils and file to support rollout replay indices --- .../inference_engine_client.py | 16 +++ .../ray_wrapped_inference_engine.py | 2 + .../workers/megatron/megatron_worker.py | 24 ++++ .../gpu/gpu_ci/test_router_replay.py | 118 +++++++++++++++++- tests/backends/skyrl_train/gpu/utils.py | 1 + 5 files changed, 157 insertions(+), 4 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py index 02bad0d7be..19a6724395 100644 --- a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py +++ b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py @@ -153,8 +153,10 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu stop_reasons: list[str] = [""] * n response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)] response_ids: List[List[int]] = [[] for _ in range(n)] + rollout_inference_indices: List[Optional[List[List[List[int]]]]] = [None for _ in range(n)] # a bit hacky for now add_resp_logprobs = False + add_rollout_inference_indices = False for indices, result in zip(indices_list, results): for local_idx, original_idx in enumerate(indices): @@ -164,12 +166,16 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu if result.get("response_logprobs", None): add_resp_logprobs = True response_logprobs[original_idx] = result["response_logprobs"][local_idx] + if result.get("rollout_inference_indices", None): + add_rollout_inference_indices = True + rollout_inference_indices[original_idx] = result["rollout_inference_indices"][local_idx] return InferenceEngineOutput( responses=responses, stop_reasons=stop_reasons, response_ids=response_ids, response_logprobs=response_logprobs if add_resp_logprobs else None, + rollout_inference_indices=rollout_inference_indices if add_rollout_inference_indices else None, ) def _select_engine_idx(self, session_id: Optional[Union[str, int]] = None) -> int: @@ -265,6 +271,7 @@ async def _generate_single_with_retry( # 2. Initialize fields we want to accumulate or update in each loop iteration accum_response_ids: List[int] = [] accum_response_logprobs: List[float] = [] + accum_rollout_inference_indices: List[List[List[int]]] = [] stop_reason: str = "abort" # We only use it if generation is completed in one turn to maintain original behavior with no retry. 
@@ -300,6 +307,10 @@ async def _generate_single_with_retry( new_response_logprobs_list: Optional[List[List[float]]] = partial_response.get("response_logprobs", None) if new_response_logprobs_list is not None and len(new_response_logprobs_list) > 0: new_response_logprobs = new_response_logprobs_list[0] + new_rollout_inference_indices: Optional[List[List[List[int]]]] = None + new_rollout_inference_indices_list = partial_response.get("rollout_inference_indices", None) + if new_rollout_inference_indices_list is not None and len(new_rollout_inference_indices_list) > 0: + new_rollout_inference_indices = new_rollout_inference_indices_list[0] # 3.4 Aborted without generating tokens, so partial_response is useless. if stop_reason == "abort" and len(new_response_ids) == 0: @@ -309,6 +320,8 @@ async def _generate_single_with_retry( accum_response_ids.extend(new_response_ids) if new_response_logprobs is not None: accum_response_logprobs.extend(new_response_logprobs) + if new_rollout_inference_indices is not None: + accum_rollout_inference_indices.extend(new_rollout_inference_indices) num_turns += 1 # 4. Build the final response and return. @@ -321,6 +334,9 @@ async def _generate_single_with_retry( stop_reasons=[stop_reason], response_ids=[accum_response_ids], response_logprobs=[accum_response_logprobs] if len(accum_response_logprobs) > 0 else None, + rollout_inference_indices=[accum_rollout_inference_indices] + if len(accum_rollout_inference_indices) > 0 + else None, ) async def _chat_completion_with_retry( diff --git a/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py b/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py index 6e238f492f..0ec3c29ac3 100644 --- a/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py @@ -107,6 +107,7 @@ def create_ray_wrapped_inference_engines( rope_scaling: Dict[str, Any] = {}, rope_theta: float | None = None, enable_ray_prometheus_stats: bool = False, + enable_return_routed_experts: bool = False, served_model_name: str | None = None, ) -> List[InferenceEngineInterface]: """ @@ -249,6 +250,7 @@ def create_ray_wrapped_inference_engines( max_num_seqs=max_num_seqs, max_logprobs=1, # only need chosen-token logprobs enable_ray_prometheus_stats=enable_ray_prometheus_stats, + enable_return_routed_experts=enable_return_routed_experts, **dp_kwargs, **engine_init_kwargs, **lora_kwargs, diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 3480f0ef27..9cb8e1a2e4 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -253,6 +253,28 @@ def extract_weights(self, dtype: torch.dtype): class MegatronWorker: + def _read_router_replay_state(self): + from megatron.core.transformer.moe.router_replay import RouterReplay + + # See https://docs.nvidia.com/megatron-core/developer-guide/0.15.0/api-guide/router_replay.html docs for more info + global_indices = getattr(RouterReplay, "global_indices", None) or [] + instances = getattr(RouterReplay, "global_router_replay_instances", None) or [] + + # Track size to check if shapes are valid after replay / we consume only layers we need + replay_backward_total_entries = 0 + for instance in instances: + replay_backward_total_entries += len(getattr(instance, "replay_backward_list", None) or []) + + return { + 
"action": str(getattr(RouterReplay, "global_router_replay_action", None)), + "global_indices": [x.detach().cpu() for x in global_indices], + "num_instances": len(instances), + "replay_backward_total_entries": replay_backward_total_entries, + } + + def get_last_router_replay_state(self): + return getattr(self, "_last_router_replay_state", None) + def init_configs( self, model_path, @@ -443,6 +465,7 @@ def forward(self, data: TrainingInputBatch): log_probs = log_probs.to("cpu") output = TrainingOutputBatch({"output": log_probs}) output.metadata = data.metadata + self._last_router_replay_state = self._read_router_replay_state() clear_router_replay() return output @@ -675,6 +698,7 @@ def forward_backward( if all_loss_fn_outputs: status["loss_fn_outputs"] = all_loss_fn_outputs + self._last_router_replay_state = self._read_router_replay_state() clear_router_replay() return status diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index eedbc61282..69c25a9d07 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -6,11 +6,13 @@ import ray import pytest import asyncio +import torch from transformers import AutoTokenizer from tests.backends.skyrl_train.gpu.utils import ( InferenceEngineState, get_test_generator_input, Timer, + init_worker_with_type, ) from skyrl.train.utils.utils import validate_cfg from skyrl.train.config import ( @@ -20,6 +22,9 @@ from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator from skyrl.train.generators.base import GeneratorInput from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend +from skyrl.train.dataset.preprocess import convert_prompts_responses_to_batch_tensors +from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch +from skyrl.backends.skyrl_train.distributed.dispatch import concatenate_outputs_after_mesh_dispatch MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" @@ -49,7 +54,7 @@ def test_megatron_router_replay(ray_init_fixture): cfg.generator.inference_engine.enable_return_routed_experts = True cfg.generator.inference_engine.tensor_parallel_size = 2 cfg.generator.sampling_params = SamplingParams( - max_generate_length=64, + max_generate_length=16, logprobs=1, temperature=1.0, ) @@ -59,7 +64,7 @@ def test_megatron_router_replay(ray_init_fixture): cfg.generator.apply_overlong_filtering = False cfg.generator.zero_reward_on_non_stop = False - num_prompts = 2 + num_prompts = 1 tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) @@ -69,7 +74,8 @@ def test_megatron_router_replay(ray_init_fixture): use_local=True, backend="vllm", sleep_level=1, - gpu_memory_utilization=0.8, + gpu_memory_utilization=0.9, + max_num_seqs=1, ) as engines: client = engines.client @@ -95,7 +101,7 @@ def test_megatron_router_replay(ray_init_fixture): temperature=1.0, top_p=1.0, top_k=-1, - max_generate_length=64, + max_generate_length=16, min_p=0.0, logprobs=1, ), @@ -152,12 +158,116 @@ def test_megatron_router_replay(ray_init_fixture): f"Sample {i}, token {t}, layer {l_idx}, k {k}: " f"expected non-negative expert id, got {expert_id}" ) + from skyrl.backends.skyrl_train.utils.replay_utils import _split_replay_indices + replay_tensor = torch.tensor(indices, dtype=torch.long) + per_layer_replay = _split_replay_indices(replay_tensor) + reconstructed = torch.stack(per_layer_replay, dim=2) + assert torch.equal( + replay_tensor, 
reconstructed + ), "Replay index translation changed values between vLLM and Megatron layout" + + prompt_ids = generator_output["prompt_token_ids"] + rollout_logprobs = generator_output.get("rollout_logprobs", None) + loss_masks = generator_output["loss_masks"] + rewards = generator_output["rewards"] + if rewards and not isinstance(rewards[0], list): + rewards = [[reward] * len(response) for reward, response in zip(rewards, responses)] + + ( + sequences_tensor, + attention_masks_tensor, + response_masks_tensor, + rewards_tensor, + loss_masks_tensor, + rollout_logprobs_tensor, + rollout_inference_indices_tensor, + ) = convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=prompt_ids, + responses=responses, + rewards=rewards, + loss_masks=loss_masks, + logprobs=rollout_logprobs, + rollout_inference_indices=indices, + ) + + assert rollout_inference_indices_tensor is not None + assert rollout_inference_indices_tensor.shape[0] == len(responses) + assert rollout_inference_indices_tensor.shape[1] == response_masks_tensor.shape[1] + + training_input = TrainingInputBatch( + { + "sequences": sequences_tensor, + "attention_mask": attention_masks_tensor, + "response_mask": response_masks_tensor, + "rewards": rewards_tensor, + "loss_mask": loss_masks_tensor, + "rollout_logprobs": rollout_logprobs_tensor, + "rollout_inference_indices": rollout_inference_indices_tensor, + } + ) + training_input.metadata = {"response_length": response_masks_tensor.shape[1]} + assert training_input["rollout_inference_indices"] is not None + + cfg.trainer.policy.megatron_config.transformer_config_kwargs = {} + cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers"] = 2 + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_enable_routing_replay"] = True + cfg.trainer.placement.policy_num_gpus_per_node = 2 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.context_parallel_size = 1 + cfg.trainer.policy.megatron_config.expert_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 + cfg.trainer.micro_forward_batch_size_per_gpu = 1 + cfg.trainer.micro_train_batch_size_per_gpu = 1 + + num_actions = response_masks_tensor.shape[1] + batch_size = sequences_tensor.shape[0] + if training_input.get("rollout_logprobs") is None: + training_input["rollout_logprobs"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) + training_input["action_log_probs"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) + training_input["base_action_log_probs"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) + training_input["advantages"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) + training_input["action_mask"] = response_masks_tensor.to(dtype=torch.int64) + + actor_group = init_worker_with_type( + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=2, + cfg=cfg, + ) + + forward_refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) + all_rank_forward_outputs = ray.get(forward_refs) + forward_output = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, all_rank_forward_outputs)[ + "output" + ] + expected_per_layer = _split_replay_indices(training_input["rollout_inference_indices"].to(torch.long)) + forward_state = ray.get(actor_group.async_run_ray_method("pass_through", "get_last_router_replay_state"))[0] + + assert forward_state is 
not None + assert len(forward_state["global_indices"]) == len(expected_per_layer) + for got, expected in zip(forward_state["global_indices"], expected_per_layer): + assert torch.equal(got.to(torch.long), expected.to(torch.long)) + assert forward_state["replay_backward_total_entries"] > 0 + + training_input.metadata["global_step"] = 0 + fb_results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", training_input)) + assert isinstance(fb_results[0], dict) + assert "policy_loss" in fb_results[0] + fb_state = ray.get(actor_group.async_run_ray_method("pass_through", "get_last_router_replay_state"))[0] + assert fb_state is not None + assert len(fb_state["global_indices"]) == len(expected_per_layer) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) print("Router replay test passed:") print(f" Batch size: {len(indices)}") print(f" Response lengths: {[len(r) for r in responses]}") if indices and indices[0]: print(f" Layers: {len(indices[0][0])}, TopK: {len(indices[0][0][0])}") + print(f" Handoff replay tensor shape: {tuple(rollout_inference_indices_tensor.shape)}") + print(f" Megatron forward shape: {tuple(forward_output.shape)}") finally: ray.shutdown() diff --git a/tests/backends/skyrl_train/gpu/utils.py b/tests/backends/skyrl_train/gpu/utils.py index f6d726b902..a7f7545331 100644 --- a/tests/backends/skyrl_train/gpu/utils.py +++ b/tests/backends/skyrl_train/gpu/utils.py @@ -516,6 +516,7 @@ def create( sleep_level=sleep_level, enable_lora=enable_lora, engine_init_kwargs=ie_cfg.engine_init_kwargs, + enable_return_routed_experts=ie_cfg.enable_return_routed_experts, served_model_name=served_model_name, ) client = InferenceEngineClient( From 647426fff9d93c3da286f4da00ea0d2437db287e Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Wed, 4 Mar 2026 12:03:06 +0000 Subject: [PATCH 05/31] add helper functions for router visibility and megatron testing, successful! 
--- .../skyrl_train/utils/replay_utils.py | 3 +- .../workers/megatron/megatron_worker.py | 31 +-- .../gpu/gpu_ci/test_router_replay.py | 219 ++++++------------ 3 files changed, 94 insertions(+), 159 deletions(-) diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index 9e3bac18a5..2e5b166e1b 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -13,7 +13,8 @@ def _split_replay_indices(rollout_inference_indices: torch.Tensor) -> List[torch if rollout_inference_indices.dim() != 4: raise ValueError(f"Expected 4D replay indices, got shape {rollout_inference_indices.shape}") per_layer = rollout_inference_indices.permute(2, 0, 1, 3).contiguous() - return [per_layer[i] for i in range(per_layer.shape[0])] + # flatten [batch, seq, topk] to [batch * seq, topk] for each layer + return [per_layer[i].reshape(-1, per_layer.shape[-1]) for i in range(per_layer.shape[0])] def setup_router_replay_forward(data: TrainingInputBatch, enable_router_replay: bool) -> bool: diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 9cb8e1a2e4..d4cf00b0f0 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -254,26 +254,31 @@ def extract_weights(self, dtype: torch.dtype): class MegatronWorker: def _read_router_replay_state(self): + """Read the current RouterReplay state from all instances.""" from megatron.core.transformer.moe.router_replay import RouterReplay # See https://docs.nvidia.com/megatron-core/developer-guide/0.15.0/api-guide/router_replay.html docs for more info - global_indices = getattr(RouterReplay, "global_indices", None) or [] - instances = getattr(RouterReplay, "global_router_replay_instances", None) or [] + instances = RouterReplay.global_router_replay_instances or [] + action = instances[0].router_replay_action if instances else None - # Track size to check if shapes are valid after replay / we consume only layers we need - replay_backward_total_entries = 0 - for instance in instances: - replay_backward_total_entries += len(getattr(instance, "replay_backward_list", None) or []) + target_indices = [ + inst.target_topk_idx.detach().cpu() + for inst in instances if inst.target_topk_idx is not None + ] return { - "action": str(getattr(RouterReplay, "global_router_replay_action", None)), - "global_indices": [x.detach().cpu() for x in global_indices], + "action": str(action), + "target_indices": target_indices, "num_instances": len(instances), - "replay_backward_total_entries": replay_backward_total_entries, } - def get_last_router_replay_state(self): - return getattr(self, "_last_router_replay_state", None) + def debug_setup_router_replay_state(self, data: TrainingInputBatch): + from skyrl.backends.skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay + + setup_router_replay_forward(data, enable_router_replay=True) + state = self._read_router_replay_state() + clear_router_replay() + return state def init_configs( self, @@ -424,7 +429,7 @@ def forward(self, data: TrainingInputBatch): """ Override `Worker.forward` to support passing the full mini batch to the MegatronModelWrapper.forward method. 
""" - from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay + from skyrl.backends.skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay setup_router_replay_forward(data, self.enable_router_replay) @@ -622,7 +627,7 @@ def forward_backward( Returns: Aggregated metrics dict across all micro batches """ - from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay + from skyrl.backends.skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay setup_router_replay_forward(data, self.enable_router_replay) self.model.train() diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 69c25a9d07..26ed26c7fb 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -15,18 +15,16 @@ init_worker_with_type, ) from skyrl.train.utils.utils import validate_cfg -from skyrl.train.config import ( - SkyRLTrainConfig, - SamplingParams, -) +from skyrl.train.config import SkyRLTrainConfig, SamplingParams from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator from skyrl.train.generators.base import GeneratorInput from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend from skyrl.train.dataset.preprocess import convert_prompts_responses_to_batch_tensors from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch -from skyrl.backends.skyrl_train.distributed.dispatch import concatenate_outputs_after_mesh_dispatch +from skyrl.backends.skyrl_train.utils.replay_utils import _split_replay_indices MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" +REPLAY_NUM_LAYERS = 2 def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: @@ -36,9 +34,7 @@ def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: cfg.trainer.micro_train_batch_size_per_gpu = 2 cfg.trainer.use_sample_packing = False cfg.trainer.logger = "console" - validate_cfg(cfg) - return cfg @@ -124,150 +120,83 @@ def test_megatron_router_replay(ray_init_fixture): responses ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" - # --- Shape & value validation per sample --- - for i, (sample_indices, sample_response) in enumerate(zip(indices, responses)): - response_len = len(sample_response) - assert ( - len(sample_indices) == response_len - ), f"Sample {i}: indices length {len(sample_indices)} != response length {response_len}" - - if response_len == 0: - continue - - # Each token position should have [layer_num, topk] structure - layer_num = len(sample_indices[0]) - assert layer_num > 0, f"Sample {i}: expected > 0 MoE layers, got {layer_num}" - - topk = len(sample_indices[0][0]) - assert topk > 0, f"Sample {i}: expected topk > 0, got {topk}" - - for t, token_indices in enumerate(sample_indices): - assert ( - len(token_indices) == layer_num - ), f"Sample {i}, token {t}: expected {layer_num} layers, got {len(token_indices)}" - for l_idx, layer_indices in enumerate(token_indices): - assert ( - len(layer_indices) == topk - ), f"Sample {i}, token {t}, layer {l_idx}: expected topk={topk}, got {len(layer_indices)}" - for k, expert_id in enumerate(layer_indices): - assert isinstance(expert_id, int), ( - f"Sample {i}, token {t}, layer {l_idx}, k {k}: " - f"expected int expert id, got {type(expert_id)}" - ) - assert expert_id >= 0, ( - f"Sample {i}, token {t}, layer 
{l_idx}, k {k}: " - f"expected non-negative expert id, got {expert_id}" - ) - from skyrl.backends.skyrl_train.utils.replay_utils import _split_replay_indices - replay_tensor = torch.tensor(indices, dtype=torch.long) - per_layer_replay = _split_replay_indices(replay_tensor) - reconstructed = torch.stack(per_layer_replay, dim=2) - assert torch.equal( - replay_tensor, reconstructed - ), "Replay index translation changed values between vLLM and Megatron layout" - - prompt_ids = generator_output["prompt_token_ids"] - rollout_logprobs = generator_output.get("rollout_logprobs", None) - loss_masks = generator_output["loss_masks"] - rewards = generator_output["rewards"] - if rewards and not isinstance(rewards[0], list): - rewards = [[reward] * len(response) for reward, response in zip(rewards, responses)] + rewards = generator_output["rewards"] + if rewards and not isinstance(rewards[0], list): + rewards = [[r] * len(resp) for r, resp in zip(rewards, responses)] + (sequences, attention_mask, response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=generator_output["prompt_token_ids"], + responses=responses, + rewards=rewards, + loss_masks=generator_output["loss_masks"], + logprobs=generator_output.get("rollout_logprobs"), + rollout_inference_indices=indices, + ) - ( - sequences_tensor, - attention_masks_tensor, - response_masks_tensor, - rewards_tensor, - loss_masks_tensor, - rollout_logprobs_tensor, - rollout_inference_indices_tensor, - ) = convert_prompts_responses_to_batch_tensors( - tokenizer=tokenizer, - prompts=prompt_ids, - responses=responses, - rewards=rewards, - loss_masks=loss_masks, - logprobs=rollout_logprobs, - rollout_inference_indices=indices, - ) + assert rii_tensor is not None + rii_tensor = rii_tensor[:, :, :REPLAY_NUM_LAYERS, :] + + num_actions = response_mask.shape[1] + batch_size = sequences.shape[0] + training_input = TrainingInputBatch({ + "sequences": sequences, + "attention_mask": attention_mask, + "response_mask": response_mask, + "rewards": rewards_t, + "loss_mask": loss_mask_t, + "rollout_logprobs": logprobs_t if logprobs_t is not None + else torch.zeros((batch_size, num_actions), dtype=torch.float32), + "rollout_inference_indices": rii_tensor, + "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "action_mask": response_mask.to(dtype=torch.int64), + }) + training_input.metadata = {"response_length": num_actions} + + cfg.trainer.policy.megatron_config.transformer_config_kwargs = { + "num_layers": REPLAY_NUM_LAYERS, + "moe_enable_routing_replay": True, + } + cfg.trainer.placement.policy_num_gpus_per_node = 2 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.context_parallel_size = 1 + cfg.trainer.policy.megatron_config.expert_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 + cfg.trainer.micro_forward_batch_size_per_gpu = 1 + cfg.trainer.micro_train_batch_size_per_gpu = 1 + + actor_group = init_worker_with_type( + "policy", shared_pg=None, colocate_all=False, + num_gpus_per_node=2, cfg=cfg, + ) - assert rollout_inference_indices_tensor is not None - assert rollout_inference_indices_tensor.shape[0] == len(responses) - assert 
rollout_inference_indices_tensor.shape[1] == response_masks_tensor.shape[1] + expected_per_layer = _split_replay_indices(rii_tensor.to(torch.long)) - training_input = TrainingInputBatch( - { - "sequences": sequences_tensor, - "attention_mask": attention_masks_tensor, - "response_mask": response_masks_tensor, - "rewards": rewards_tensor, - "loss_mask": loss_masks_tensor, - "rollout_logprobs": rollout_logprobs_tensor, - "rollout_inference_indices": rollout_inference_indices_tensor, - } + state = ray.get( + actor_group.async_run_ray_method( + "pass_through", "debug_setup_router_replay_state", + data=training_input, ) - training_input.metadata = {"response_length": response_masks_tensor.shape[1]} - assert training_input["rollout_inference_indices"] is not None - - cfg.trainer.policy.megatron_config.transformer_config_kwargs = {} - cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers"] = 2 - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_enable_routing_replay"] = True - cfg.trainer.placement.policy_num_gpus_per_node = 2 - cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 - cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 - cfg.trainer.policy.megatron_config.context_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_model_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 - cfg.trainer.micro_forward_batch_size_per_gpu = 1 - cfg.trainer.micro_train_batch_size_per_gpu = 1 + )[0] - num_actions = response_masks_tensor.shape[1] - batch_size = sequences_tensor.shape[0] - if training_input.get("rollout_logprobs") is None: - training_input["rollout_logprobs"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) - training_input["action_log_probs"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) - training_input["base_action_log_probs"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) - training_input["advantages"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) - training_input["action_mask"] = response_masks_tensor.to(dtype=torch.int64) - - actor_group = init_worker_with_type( - "policy", - shared_pg=None, - colocate_all=False, - num_gpus_per_node=2, - cfg=cfg, + assert state is not None, "Worker returned None state" + assert "REPLAY_FORWARD" in state["action"], ( + f"RouterReplay action should be REPLAY_FORWARD, got: {state['action']}" + ) + assert state["num_instances"] == len(expected_per_layer), ( + f"Expected {len(expected_per_layer)} replay instances (one per layer), " + f"got {state['num_instances']}" + ) + for layer_idx, (got, expected) in enumerate( + zip(state["target_indices"], expected_per_layer) + ): + assert torch.equal(got.to(torch.long), expected.to(torch.long)), ( + f"Layer {layer_idx}: Megatron target indices differ from vLLM indices" ) - - forward_refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) - all_rank_forward_outputs = ray.get(forward_refs) - forward_output = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, all_rank_forward_outputs)[ - "output" - ] - expected_per_layer = _split_replay_indices(training_input["rollout_inference_indices"].to(torch.long)) - forward_state = ray.get(actor_group.async_run_ray_method("pass_through", "get_last_router_replay_state"))[0] - - assert forward_state is not None - assert len(forward_state["global_indices"]) == len(expected_per_layer) - for got, expected in zip(forward_state["global_indices"], expected_per_layer): - assert 
torch.equal(got.to(torch.long), expected.to(torch.long)) - assert forward_state["replay_backward_total_entries"] > 0 - - training_input.metadata["global_step"] = 0 - fb_results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", training_input)) - assert isinstance(fb_results[0], dict) - assert "policy_loss" in fb_results[0] - fb_state = ray.get(actor_group.async_run_ray_method("pass_through", "get_last_router_replay_state"))[0] - assert fb_state is not None - assert len(fb_state["global_indices"]) == len(expected_per_layer) - ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) - - print("Router replay test passed:") - print(f" Batch size: {len(indices)}") - print(f" Response lengths: {[len(r) for r in responses]}") - if indices and indices[0]: - print(f" Layers: {len(indices[0][0])}, TopK: {len(indices[0][0][0])}") - print(f" Handoff replay tensor shape: {tuple(rollout_inference_indices_tensor.shape)}") - print(f" Megatron forward shape: {tuple(forward_output.shape)}") + print(f"PASSED: vLLM routing indices ({rii_tensor.shape}) correctly " + f"loaded into {state['num_instances']} Megatron RouterReplay instances") finally: ray.shutdown() From d4b753fcef83c8aa37f6235a540af7e11607719a Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Wed, 4 Mar 2026 12:11:49 +0000 Subject: [PATCH 06/31] linter --- .../inference_engine_client.py | 6 +- .../workers/megatron/megatron_worker.py | 7 +- .../gpu/gpu_ci/test_router_replay.py | 86 +++++++++++-------- 3 files changed, 53 insertions(+), 46 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py index 19a6724395..264e060cd1 100644 --- a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py +++ b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py @@ -334,9 +334,9 @@ async def _generate_single_with_retry( stop_reasons=[stop_reason], response_ids=[accum_response_ids], response_logprobs=[accum_response_logprobs] if len(accum_response_logprobs) > 0 else None, - rollout_inference_indices=[accum_rollout_inference_indices] - if len(accum_rollout_inference_indices) > 0 - else None, + rollout_inference_indices=( + [accum_rollout_inference_indices] if len(accum_rollout_inference_indices) > 0 else None + ), ) async def _chat_completion_with_retry( diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index d4cf00b0f0..d01a5ad061 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -257,14 +257,11 @@ def _read_router_replay_state(self): """Read the current RouterReplay state from all instances.""" from megatron.core.transformer.moe.router_replay import RouterReplay - # See https://docs.nvidia.com/megatron-core/developer-guide/0.15.0/api-guide/router_replay.html docs for more info + # See https://docs.nvidia.com/megatron-core/developer-guide/0.15.0/api-guide/router_replay.html docs for more info instances = RouterReplay.global_router_replay_instances or [] action = instances[0].router_replay_action if instances else None - target_indices = [ - inst.target_topk_idx.detach().cpu() - for inst in instances if inst.target_topk_idx is not None - ] + target_indices = [inst.target_topk_idx.detach().cpu() for inst in instances if inst.target_topk_idx is not None] return { "action": str(action), diff --git 
a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 26ed26c7fb..13fb4f51cc 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -123,14 +123,16 @@ def test_megatron_router_replay(ray_init_fixture): rewards = generator_output["rewards"] if rewards and not isinstance(rewards[0], list): rewards = [[r] * len(resp) for r, resp in zip(rewards, responses)] - (sequences, attention_mask, response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = convert_prompts_responses_to_batch_tensors( - tokenizer=tokenizer, - prompts=generator_output["prompt_token_ids"], - responses=responses, - rewards=rewards, - loss_masks=generator_output["loss_masks"], - logprobs=generator_output.get("rollout_logprobs"), - rollout_inference_indices=indices, + (sequences, attention_mask, response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = ( + convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=generator_output["prompt_token_ids"], + responses=responses, + rewards=rewards, + loss_masks=generator_output["loss_masks"], + logprobs=generator_output.get("rollout_logprobs"), + rollout_inference_indices=indices, + ) ) assert rii_tensor is not None @@ -138,20 +140,25 @@ def test_megatron_router_replay(ray_init_fixture): num_actions = response_mask.shape[1] batch_size = sequences.shape[0] - training_input = TrainingInputBatch({ - "sequences": sequences, - "attention_mask": attention_mask, - "response_mask": response_mask, - "rewards": rewards_t, - "loss_mask": loss_mask_t, - "rollout_logprobs": logprobs_t if logprobs_t is not None - else torch.zeros((batch_size, num_actions), dtype=torch.float32), - "rollout_inference_indices": rii_tensor, - "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), - "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), - "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), - "action_mask": response_mask.to(dtype=torch.int64), - }) + training_input = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_mask, + "response_mask": response_mask, + "rewards": rewards_t, + "loss_mask": loss_mask_t, + "rollout_logprobs": ( + logprobs_t + if logprobs_t is not None + else torch.zeros((batch_size, num_actions), dtype=torch.float32) + ), + "rollout_inference_indices": rii_tensor, + "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "action_mask": response_mask.to(dtype=torch.int64), + } + ) training_input.metadata = {"response_length": num_actions} cfg.trainer.policy.megatron_config.transformer_config_kwargs = { @@ -168,35 +175,38 @@ def test_megatron_router_replay(ray_init_fixture): cfg.trainer.micro_train_batch_size_per_gpu = 1 actor_group = init_worker_with_type( - "policy", shared_pg=None, colocate_all=False, - num_gpus_per_node=2, cfg=cfg, + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=2, + cfg=cfg, ) expected_per_layer = _split_replay_indices(rii_tensor.to(torch.long)) state = ray.get( actor_group.async_run_ray_method( - "pass_through", "debug_setup_router_replay_state", + "pass_through", + "debug_setup_router_replay_state", data=training_input, ) )[0] assert state is not None, 
"Worker returned None state" - assert "REPLAY_FORWARD" in state["action"], ( - f"RouterReplay action should be REPLAY_FORWARD, got: {state['action']}" - ) + assert ( + "REPLAY_FORWARD" in state["action"] + ), f"RouterReplay action should be REPLAY_FORWARD, got: {state['action']}" assert state["num_instances"] == len(expected_per_layer), ( - f"Expected {len(expected_per_layer)} replay instances (one per layer), " - f"got {state['num_instances']}" + f"Expected {len(expected_per_layer)} replay instances (one per layer), " f"got {state['num_instances']}" + ) + for layer_idx, (got, expected) in enumerate(zip(state["target_indices"], expected_per_layer)): + assert torch.equal( + got.to(torch.long), expected.to(torch.long) + ), f"Layer {layer_idx}: Megatron target indices differ from vLLM indices" + print( + f"PASSED: vLLM routing indices ({rii_tensor.shape}) correctly " + f"loaded into {state['num_instances']} Megatron RouterReplay instances" ) - for layer_idx, (got, expected) in enumerate( - zip(state["target_indices"], expected_per_layer) - ): - assert torch.equal(got.to(torch.long), expected.to(torch.long)), ( - f"Layer {layer_idx}: Megatron target indices differ from vLLM indices" - ) - print(f"PASSED: vLLM routing indices ({rii_tensor.shape}) correctly " - f"loaded into {state['num_instances']} Megatron RouterReplay instances") finally: ray.shutdown() From f1b9c5333cc0f767b0b48c4fe4ba23b7ee976bfb Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 4 Mar 2026 22:21:52 +0000 Subject: [PATCH 07/31] worked w opus to get forward pass logprob diff lower with replay + running with tp + ep for megatron --- .../inference_engines/vllm/vllm_engine.py | 11 +- .../skyrl_train/utils/replay_utils.py | 98 +++++++++++ .../megatron/megatron_model_wrapper.py | 6 + .../workers/megatron/megatron_worker.py | 22 +-- skyrl/train/dataset/preprocess.py | 13 +- skyrl/train/generators/skyrl_gym_generator.py | 5 +- .../gpu/gpu_ci/test_router_replay.py | 160 +++++++++++++++++- 7 files changed, 291 insertions(+), 24 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py index 66d4213f1c..048fba7ec8 100644 --- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -157,18 +157,19 @@ def _postprocess_outputs(self, outputs): del token_logprobs response_logprobs.append(_logprobs) + _routed_experts = None if resp.routed_experts is not None: if hasattr(resp.routed_experts, "tolist"): - routed_experts_list = resp.routed_experts.tolist() + _routed_experts = resp.routed_experts.tolist() else: - routed_experts_list = resp.routed_experts - rollout_inference_indices.append(routed_experts_list) + _routed_experts = resp.routed_experts + rollout_inference_indices.append(_routed_experts) if len(response_logprobs) and response_logprobs[0] is None: response_logprobs = None # hack: assume uniform sampling params - if len(rollout_inference_indices) == 0: - rollout_inference_indices = None + if len(rollout_inference_indices) == 0 and _routed_experts is None: + rollout_inference_indices = None # hack: assume uniform sampling params return InferenceEngineOutput( responses=responses, diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index 2e5b166e1b..0b2ce73e50 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -7,6 +7,39 @@ from 
skyrl.backends.skyrl_train.training_batch import TrainingInputBatch +def _patch_alltoall_dispatcher_for_replay(): + """Monkey-patch MoEAlltoAllTokenDispatcher.preprocess to handle router replay. + + When router replay is enabled, duplicate indices in top_indices can cause + routing_map.sum() < num_tokens * topk, leading to a split size mismatch + in the alltoall collective. We fix this by deriving num_out_tokens from + the routing map instead of the static num_tokens * topk formula. + """ + try: + from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher + except ImportError: + return + + if getattr(MoEAlltoAllTokenDispatcher, "_preprocess_patched", False): + return + + original_preprocess = MoEAlltoAllTokenDispatcher.preprocess + + def patched_preprocess(self, routing_map): + result = original_preprocess(self, routing_map) + if ( + getattr(self.config, "moe_enable_routing_replay", False) + and not self.drop_and_pad + and self.config.moe_expert_capacity_factor is None + and not self.config.moe_router_padding_for_quantization + ): + self.num_out_tokens = int(routing_map.sum().item()) + return result + + MoEAlltoAllTokenDispatcher.preprocess = patched_preprocess + MoEAlltoAllTokenDispatcher._preprocess_patched = True + + def _split_replay_indices(rollout_inference_indices: torch.Tensor) -> List[torch.Tensor]: if rollout_inference_indices is None: return None @@ -17,6 +50,71 @@ def _split_replay_indices(rollout_inference_indices: torch.Tensor) -> List[torch return [per_layer[i].reshape(-1, per_layer.shape[-1]) for i in range(per_layer.shape[0])] +def _remove_left_padding_from_indices( + rollout_inference_indices: torch.Tensor, + attention_mask: torch.Tensor, +) -> torch.Tensor: + """Apply the same left-padding removal as remove_left_padding to routing indices. + + Args: + rollout_inference_indices: [batch, padded_seq_len, layers, topk] + attention_mask: [batch, padded_seq_len] (int or bool) + + Returns: + [batch, effective_seq_len, layers, topk] with real tokens packed left. + """ + import megatron.core.parallel_state as mpu + + seq_lens = attention_mask.sum(dim=1) + effective_seq_len = seq_lens.max().item() + sp_world_size = mpu.get_tensor_model_parallel_world_size() + if sp_world_size > 1: + pad_size = (sp_world_size - effective_seq_len % sp_world_size) % sp_world_size + effective_seq_len += pad_size + + batch_size = rollout_inference_indices.shape[0] + new_rii = torch.zeros( + batch_size, + effective_seq_len, + rollout_inference_indices.shape[2], + rollout_inference_indices.shape[3], + dtype=rollout_inference_indices.dtype, + device=rollout_inference_indices.device, + ) + for i in range(batch_size): + mask = attention_mask[i].bool() + new_rii[i, : seq_lens[i]] = rollout_inference_indices[i, mask] + return new_rii + + +def _setup_per_microbatch_replay( + rollout_inference_indices: torch.Tensor, + attention_mask: torch.Tensor, +) -> None: + """Set up RouterReplay for a single micro-batch, aligning indices + with the left-padding-removed token layout that the MoE layer sees. + + Handles sequence parallelism: when TP > 1, the sequence is split across + TP ranks, so each rank's MoE router only sees its local chunk of tokens. 
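+    """
+    # Illustrative sequence-parallel split (assumed sizes): with an effective
+    # seq_len of 8 and tp_size=2, rank 0 replays routing indices for tokens
+    # 0..3 and rank 1 for tokens 4..7, i.e. each rank's local chunk.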
+ """ + import megatron.core.parallel_state as mpu + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction + + _patch_alltoall_dispatcher_for_replay() + + aligned = _remove_left_padding_from_indices(rollout_inference_indices, attention_mask) + + tp_size = mpu.get_tensor_model_parallel_world_size() + if tp_size > 1: + tp_rank = mpu.get_tensor_model_parallel_rank() + seq_len = aligned.shape[1] + chunk_size = seq_len // tp_size + aligned = aligned[:, tp_rank * chunk_size : (tp_rank + 1) * chunk_size, :, :] + + RouterReplay.set_replay_data(_split_replay_indices(aligned)) + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + def setup_router_replay_forward(data: TrainingInputBatch, enable_router_replay: bool) -> bool: """ Set up router replay for forward pass (ref/policy inference). diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 5ab565f989..aa64058e09 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -24,6 +24,7 @@ remove_left_padding, recover_left_padding, ) +from skyrl.backends.skyrl_train.utils.replay_utils import _setup_per_microbatch_replay class MegatronModelWrapper: @@ -103,6 +104,11 @@ def collection_func(logits, data): def forward_step(batch_iter, model): batch = next(batch_iter) + + rollout_inference_indices = batch.pop("rollout_inference_indices", None) + if rollout_inference_indices is not None: + _setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"]) + sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) position_ids = batch["position_ids"] diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index d01a5ad061..52f7d24807 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -426,9 +426,7 @@ def forward(self, data: TrainingInputBatch): """ Override `Worker.forward` to support passing the full mini batch to the MegatronModelWrapper.forward method. 
""" - from skyrl.backends.skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay - - setup_router_replay_forward(data, self.enable_router_replay) + from skyrl.backends.skyrl_train.utils.replay_utils import clear_router_replay # Run in micro batches grouped into a single mini-batch micro_bsz = self.cfg.micro_forward_batch_size_per_gpu @@ -444,14 +442,16 @@ def forward(self, data: TrainingInputBatch): num_actions = micro.metadata["response_length"] position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 0) - micro_dicts.append( - { - "sequences": sequences, - "attention_mask": attention_mask, - "position_ids": position_ids, - "num_actions": num_actions, - } - ) + micro_dict = { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": num_actions, + } + rii = micro.get("rollout_inference_indices") + if rii is not None and self.enable_router_replay: + micro_dict["rollout_inference_indices"] = rii + micro_dicts.append(micro_dict) self.model.eval() seq_len = micro_dicts[0]["sequences"].shape[1] diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index 01b332a625..be9c8fbcce 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -133,7 +133,18 @@ def convert_prompts_responses_to_batch_tensors( rollout_inference_indices_tensor = None if rollout_inference_indices: - rollout_inference_indices_tensor = torch.tensor(rollout_inference_indices, dtype=torch.int32) + first_non_empty = next((x for x in rollout_inference_indices if x), None) + if first_non_empty: + total_seq_len = max_input_len + max_output_len + num_layers = len(first_non_empty[0]) + topk = len(first_non_empty[0][0]) if num_layers > 0 else 0 + padded = torch.zeros(len(rollout_inference_indices), total_seq_len, num_layers, topk, dtype=torch.int32) + for i, sample_indices in enumerate(rollout_inference_indices): + if sample_indices: + left_pad = max_input_len - prompt_token_lens[i] + n = min(len(sample_indices), total_seq_len - left_pad) + padded[i, left_pad : left_pad + n] = torch.tensor(sample_indices[:n], dtype=torch.int32) + rollout_inference_indices_tensor = padded return ( sequences, diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 67bc85a720..b48249130d 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -460,7 +460,7 @@ async def agent_loop( ] if agent_loop_state.rollout_inference_indices is not None: rollout_inference_indices_out = agent_loop_state.rollout_inference_indices[ - : agent_loop_state.response_end_idx - initial_prompt_length + 1 + : agent_loop_state.response_end_idx + 1 ] # fix index for per_step_rewards per_step_rewards = [(reward, idx - initial_prompt_length) for reward, idx in per_step_rewards] @@ -676,8 +676,7 @@ async def generate_batched( sample_logprobs = logprobs[i][: len(response)] truncated_logprobs.append(sample_logprobs) if raw_rollout_inference_indices is not None: - sample_indices = raw_rollout_inference_indices[i][: len(response)] - truncated_indices.append(sample_indices) + truncated_indices.append(raw_rollout_inference_indices[i]) # Get environment-specific metrics env_metrics.append(env.get_metrics()) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 13fb4f51cc..f1134da5d6 100644 --- 
a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -14,6 +14,7 @@ Timer, init_worker_with_type, ) +from skyrl.backends.skyrl_train.distributed.dispatch import concatenate_outputs_after_mesh_dispatch from skyrl.train.utils.utils import validate_cfg from skyrl.train.config import SkyRLTrainConfig, SamplingParams from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator @@ -50,15 +51,12 @@ def test_megatron_router_replay(ray_init_fixture): cfg.generator.inference_engine.enable_return_routed_experts = True cfg.generator.inference_engine.tensor_parallel_size = 2 cfg.generator.sampling_params = SamplingParams( - max_generate_length=16, + max_generate_length=1024, logprobs=1, temperature=1.0, ) cfg.generator.batched = False cfg.generator.max_turns = 1 - cfg.generator.use_conversation_multi_turn = True - cfg.generator.apply_overlong_filtering = False - cfg.generator.zero_reward_on_non_stop = False num_prompts = 1 @@ -210,3 +208,157 @@ def test_megatron_router_replay(ray_init_fixture): finally: ray.shutdown() + + +@pytest.mark.megatron +def test_megatron_router_replay_logprobs(ray_init_fixture): + """ + Check that logprob diff is lower when using router replay. Requires full 8xH100 setup to do full forward pass. + """ + try: + cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) + cfg.trainer.strategy = "megatron" + cfg.generator.inference_engine.enable_return_routed_experts = True + cfg.generator.inference_engine.tensor_parallel_size = 8 + cfg.generator.sampling_params = SamplingParams( + max_generate_length=1024, + logprobs=1, + temperature=1.0, + ) + cfg.generator.batched = False + cfg.generator.max_turns = 1 + + tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) + + with InferenceEngineState.create( + cfg=cfg, + model=MOE_MODEL_NAME, + use_local=True, + colocate_all=True, + backend="vllm", + sleep_level=1, + gpu_memory_utilization=0.9, + ) as engines: + client, pg = engines.client, engines.pg + asyncio.run(client.wake_up()) + + generator = SkyRLGymGenerator( + generator_cfg=cfg.generator, + skyrl_gym_cfg=cfg.environment.skyrl_gym, + inference_engine_client=client, + tokenizer=tokenizer, + ) + + input_batch: GeneratorInput = get_test_generator_input( + model=MOE_MODEL_NAME, + num_prompts=4, + n_samples_per_prompt=1, + max_prompt_length=512, + env_class="gsm8k", + ) + input_batch["sampling_params"] = get_sampling_params_for_backend( + "vllm", + SamplingParams( + temperature=1.0, + top_p=1.0, + top_k=-1, + max_generate_length=16, + min_p=0.0, + logprobs=1, + ), + ) + + with Timer("generate_with_router_replay"): + generator_output = asyncio.run(generator.generate(input_batch)) + + indices = generator_output["rollout_inference_indices"] + responses = generator_output["response_ids"] + assert ( + indices is not None + ), "rollout_inference_indices should not be None when enable_return_routed_experts=True" + assert len(indices) == len( + responses + ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" + asyncio.run(client.sleep()) + + rewards = generator_output["rewards"] + if rewards and not isinstance(rewards[0], list): + rewards = [[r] * len(resp) for r, resp in zip(rewards, responses)] + (sequences, attention_mask, response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = ( + convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=generator_output["prompt_token_ids"], + responses=responses, + rewards=rewards, + 
loss_masks=generator_output["loss_masks"], + logprobs=generator_output.get("rollout_logprobs"), + rollout_inference_indices=indices, + ) + ) + + assert rii_tensor is not None + num_actions = response_mask.shape[1] + batch_size = sequences.shape[0] + training_input = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_mask, + "response_mask": response_mask, + "rewards": rewards_t, + "loss_mask": loss_mask_t, + "rollout_logprobs": ( + logprobs_t + if logprobs_t is not None + else torch.zeros((batch_size, num_actions), dtype=torch.float32) + ), + "rollout_inference_indices": rii_tensor, + "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "action_mask": response_mask.to(dtype=torch.int64), + } + ) + training_input.metadata = {"response_length": num_actions} + + cfg.trainer.placement.policy_num_gpus_per_node = 8 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 4 + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.context_parallel_size = 1 + cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 + cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 + cfg.trainer.micro_forward_batch_size_per_gpu = 1 + cfg.trainer.micro_train_batch_size_per_gpu = 1 + + def run_megatron_forward(enable_replay: bool) -> torch.Tensor: + cfg.trainer.policy.megatron_config.transformer_config_kwargs = { + "moe_enable_routing_replay": enable_replay, + } + actor_group = init_worker_with_type( + "policy", + shared_pg=pg, + colocate_all=True, + num_gpus_per_node=8, + cfg=cfg, + ) + refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) + results = ray.get(refs) + outputs = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, results)["output"] + + for actor in actor_group._actor_handlers: + ray.kill(actor) + return outputs + + r3_logprobs = run_megatron_forward(enable_replay=True) + no_r3_logprobs = run_megatron_forward(enable_replay=False) + + r3_diff = (logprobs_t - r3_logprobs).abs() + no_r3_diff = (logprobs_t - no_r3_logprobs).abs() + print(f"With replay - logprob diff mean: {r3_diff.mean().item():.6f}, std: {r3_diff.std().item():.6f}") + print(f"Without replay - logprob diff mean: {no_r3_diff.mean().item():.6f}, std: {no_r3_diff.std().item():.6f}") + + assert r3_diff.mean().item() < no_r3_diff.mean().item(), ( + f"Router replay should reduce logprob diff vs rollout, " + f"but with_replay={r3_diff.mean().item():.6f} >= without_replay={no_r3_diff.mean().item():.6f}" + ) + finally: + ray.shutdown() From 8a8fa701f9e8420582c1bc9c80dce287073a5284 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 4 Mar 2026 23:26:19 +0000 Subject: [PATCH 08/31] add test for forward backward and fix behavior --- .../megatron/megatron_model_wrapper.py | 4 + .../workers/megatron/megatron_worker.py | 44 ++--- .../gpu/gpu_ci/test_router_replay.py | 158 ++++++++++++++++++ 3 files changed, 185 insertions(+), 21 deletions(-) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index aa64058e09..540e9328e2 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -361,6 +361,10 @@ def 
loss_func(logits, data): def forward_step(batch_iter, model): batch = next(batch_iter) + rollout_inference_indices = batch.pop("rollout_inference_indices", None) + if rollout_inference_indices is not None: + _setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"]) + sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) position_ids = batch["position_ids"] diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 52f7d24807..3a78b8fb91 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -33,7 +33,7 @@ from skyrl.train.utils.utils import update_model_config, str_to_torch_dtype from skyrl.backends.skyrl_train.env_vars import SKYRL_WORKER_NCCL_TIMEOUT_IN_S from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch -from skyrl.backends.skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics, all_reduce_metrics +from skyrl.backends.skyrl_train.workers.worker_utils import reduce_metrics, all_reduce_metrics from skyrl.backends.skyrl_train.workers.worker import ( PolicyWorkerBase, RefWorkerBase, @@ -624,9 +624,8 @@ def forward_backward( Returns: Aggregated metrics dict across all micro batches """ - from skyrl.backends.skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay + from skyrl.backends.skyrl_train.utils.replay_utils import clear_router_replay - setup_router_replay_forward(data, self.enable_router_replay) self.model.train() for chunk in self.actor_module: # if use distributed optimizer, zero grad buffer will be handled by optimizer @@ -638,28 +637,31 @@ def forward_backward( # Move data to GPU data.to(torch.cuda.current_device()) - # Build micro-batch dicts expected by forward_backward_mini_batch + # Chunk manually so we can propagate rollout_inference_indices for + # per-micro-batch router replay (BatchIterator/Experience don't carry them). 
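+        # Note: data.chunk(micro_batch_size) is assumed to split every tensor
+        # field row-wise, keeping rollout_inference_indices aligned with
+        # sequences and attention_mask inside each micro batch.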
micro_buffer = [] - for experience in BatchIterator(data, micro_batch_size, drop_last=False): - sequences = experience.sequences - attention_mask = experience.attention_mask + for micro in data.chunk(micro_batch_size): + sequences = micro["sequences"] + attention_mask = micro["attention_mask"] position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 0) - micro_buffer.append( - { - "sequences": sequences, - "attention_mask": attention_mask, - "position_ids": position_ids, - "num_actions": experience.num_actions, - "old_action_log_probs": experience.action_log_probs, - "base_action_log_probs": experience.base_action_log_probs, - "advantages": experience.advantages, - "loss_mask": experience.loss_mask, - "rollout_action_logprobs": experience.rollout_logprobs, - "action_mask": experience.action_mask, - } - ) + micro_dict = { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": micro.metadata["response_length"], + "old_action_log_probs": micro.get("action_log_probs"), + "base_action_log_probs": micro.get("base_action_log_probs"), + "advantages": micro.get("advantages"), + "loss_mask": micro.get("loss_mask"), + "rollout_action_logprobs": micro.get("rollout_logprobs"), + "action_mask": micro.get("action_mask"), + } + rii = micro.get("rollout_inference_indices") + if rii is not None and self.enable_router_replay: + micro_dict["rollout_inference_indices"] = rii + micro_buffer.append(micro_dict) if not micro_buffer: return {} diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index f1134da5d6..f113240443 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -362,3 +362,161 @@ def run_megatron_forward(enable_replay: bool) -> torch.Tensor: ) finally: ray.shutdown() + + +@pytest.mark.megatron +def test_megatron_router_replay_forward_backward(ray_init_fixture): + """ + Check that forward_backward produces similar losses with and without + router replay (same weights, so routing decisions should nearly match). + Requires full 8xH100 setup. 
+ """ + try: + cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) + cfg.trainer.strategy = "megatron" + cfg.generator.inference_engine.enable_return_routed_experts = True + cfg.generator.inference_engine.tensor_parallel_size = 8 + cfg.generator.sampling_params = SamplingParams( + max_generate_length=1024, + logprobs=1, + temperature=1.0, + ) + cfg.generator.batched = False + cfg.generator.max_turns = 1 + + tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) + + with InferenceEngineState.create( + cfg=cfg, + model=MOE_MODEL_NAME, + use_local=True, + colocate_all=True, + backend="vllm", + sleep_level=1, + gpu_memory_utilization=0.9, + ) as engines: + client, pg = engines.client, engines.pg + asyncio.run(client.wake_up()) + + generator = SkyRLGymGenerator( + generator_cfg=cfg.generator, + skyrl_gym_cfg=cfg.environment.skyrl_gym, + inference_engine_client=client, + tokenizer=tokenizer, + ) + + input_batch: GeneratorInput = get_test_generator_input( + model=MOE_MODEL_NAME, + num_prompts=10, + n_samples_per_prompt=5, + max_prompt_length=512, + env_class="gsm8k", + ) + input_batch["sampling_params"] = get_sampling_params_for_backend( + "vllm", + SamplingParams( + temperature=1.0, + top_p=1.0, + top_k=-1, + max_generate_length=16, + min_p=0.0, + logprobs=1, + ), + ) + + with Timer("generate_with_router_replay"): + generator_output = asyncio.run(generator.generate(input_batch)) + + indices = generator_output["rollout_inference_indices"] + responses = generator_output["response_ids"] + assert ( + indices is not None + ), "rollout_inference_indices should not be None when enable_return_routed_experts=True" + assert len(indices) == len( + responses + ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" + asyncio.run(client.sleep()) + + rewards = generator_output["rewards"] + if rewards and not isinstance(rewards[0], list): + rewards = [[r] * len(resp) for r, resp in zip(rewards, responses)] + (sequences, attention_mask, response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = ( + convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=generator_output["prompt_token_ids"], + responses=responses, + rewards=rewards, + loss_masks=generator_output["loss_masks"], + logprobs=generator_output.get("rollout_logprobs"), + rollout_inference_indices=indices, + ) + ) + + assert rii_tensor is not None + num_actions = response_mask.shape[1] + batch_size = sequences.shape[0] + training_input = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_mask, + "response_mask": response_mask, + "rewards": rewards_t, + "loss_mask": loss_mask_t, + "rollout_logprobs": ( + logprobs_t + if logprobs_t is not None + else torch.zeros((batch_size, num_actions), dtype=torch.float32) + ), + "rollout_inference_indices": rii_tensor, + "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "action_mask": response_mask.to(dtype=torch.int64), + } + ) + training_input.metadata = {"response_length": num_actions} + + cfg.trainer.placement.policy_num_gpus_per_node = 8 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 4 + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.context_parallel_size = 1 + cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 + 
cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 + cfg.trainer.micro_forward_batch_size_per_gpu = 1 + cfg.trainer.micro_train_batch_size_per_gpu = 1 + + def run_megatron_forward_backward(enable_replay: bool) -> dict: + cfg.trainer.policy.megatron_config.transformer_config_kwargs = { + "moe_enable_routing_replay": enable_replay, + } + actor_group = init_worker_with_type( + "policy", + shared_pg=pg, + colocate_all=True, + num_gpus_per_node=8, + cfg=cfg, + ) + results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=training_input)) + for actor in actor_group._actor_handlers: + ray.kill(actor) + return results[0] + + metrics_replay = run_megatron_forward_backward(enable_replay=True) + metrics_no_replay = run_megatron_forward_backward(enable_replay=False) + + loss_replay = metrics_replay["policy_loss"] + loss_no_replay = metrics_no_replay["policy_loss"] + print(f"With replay - loss: {loss_replay:.6f}") + print(f"Without replay - loss: {loss_no_replay:.6f}") + print(f"With replay metrics: {metrics_replay}") + print(f"Without replay metrics: {metrics_no_replay}") + + diff = abs(loss_replay - loss_no_replay) + threshold = 0.5 + print(f"Loss diff: {diff:.6f} (threshold: {threshold})") + assert diff < threshold, ( + f"Losses with/without replay should be similar (same weights), " + f"but diff={diff:.6f} >= threshold={threshold}" + ) + finally: + ray.shutdown() From 410995a64d644ea2ae654cd44c006dd0146ddccb Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Fri, 6 Mar 2026 02:22:30 +0000 Subject: [PATCH 09/31] working for qwen but not moonlight... debugging moonlight --- examples/train/gsm8k/run_gsm8k.sh | 4 +- .../skyrl_train/utils/replay_utils.py | 63 ++++++++++++++++- .../megatron/megatron_model_wrapper.py | 16 +++++ .../workers/megatron/megatron_worker.py | 59 +++++++++++++++- skyrl/train/utils/utils.py | 2 +- skyrl/utils/tok.py | 1 + .../skyrl_train/gpu/gpu_ci/conftest.py | 6 ++ .../gpu/gpu_ci/test_megatron_worker.py | 18 +++-- .../gpu/gpu_ci/test_router_replay.py | 67 ++++++++++++++----- 9 files changed, 207 insertions(+), 29 deletions(-) diff --git a/examples/train/gsm8k/run_gsm8k.sh b/examples/train/gsm8k/run_gsm8k.sh index 2693571dac..e5dc9e20f3 100755 --- a/examples/train/gsm8k/run_gsm8k.sh +++ b/examples/train/gsm8k/run_gsm8k.sh @@ -26,8 +26,8 @@ uv run --isolated --extra fsdp -m skyrl.train.entrypoints.main_base \ trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \ trainer.placement.critic_num_gpus_per_node=$NUM_GPUS \ trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \ - generator.inference_engine.num_engines=$NUM_GPUS \ - generator.inference_engine.tensor_parallel_size=1 \ + generator.inference_engine.num_engines=1 \ + generator.inference_engine.tensor_parallel_size=4 \ trainer.epochs=20 \ trainer.eval_batch_size=1024 \ trainer.eval_before_train=true \ diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index 0b2ce73e50..68d266e052 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -7,6 +7,38 @@ from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch +def _patch_topk_router_layer_number(): + """Monkey-patch TopKRouter.set_layer_number to propagate the global layer + number to the RouterReplay instance. + + DeepSeek V3 (and similar) architectures have dense FFN layers before the MoE + layers. 
vLLM reports routing indices for ALL transformer layers (including + dense), but Megatron only creates RouterReplay instances for MoE layers. + Storing the global layer_number on each RouterReplay instance lets us map + vLLM's per-layer data to the correct MoE router even when dense layers are + present. + + Must be called BEFORE model creation (i.e. before make_megatron_module). + """ + try: + from megatron.core.transformer.moe.router import TopKRouter + except ImportError: + return + + if getattr(TopKRouter, "_set_layer_number_patched", False): + return + + original_set_layer_number = TopKRouter.set_layer_number + + def patched_set_layer_number(self, layer_number: int): + original_set_layer_number(self, layer_number) + if self.router_replay is not None: + self.router_replay.layer_number = layer_number + + TopKRouter.set_layer_number = patched_set_layer_number + TopKRouter._set_layer_number_patched = True + + def _patch_alltoall_dispatcher_for_replay(): """Monkey-patch MoEAlltoAllTokenDispatcher.preprocess to handle router replay. @@ -96,6 +128,12 @@ def _setup_per_microbatch_replay( Handles sequence parallelism: when TP > 1, the sequence is split across TP ranks, so each rank's MoE router only sees its local chunk of tokens. + + Handles dense-layer mismatch: DeepSeek V3-style models have dense FFN + layers before the MoE layers. vLLM reports routing indices for ALL + transformer layers, but Megatron only has RouterReplay instances for MoE + layers. We use each instance's global layer_number (set by the patched + TopKRouter.set_layer_number) to index into the correct slice of the data. """ import megatron.core.parallel_state as mpu from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction @@ -111,7 +149,30 @@ def _setup_per_microbatch_replay( chunk_size = seq_len // tp_size aligned = aligned[:, tp_rank * chunk_size : (tp_rank + 1) * chunk_size, :, :] - RouterReplay.set_replay_data(_split_replay_indices(aligned)) + per_layer_data = _split_replay_indices(aligned) + num_layers_in_data = len(per_layer_data) + instances = RouterReplay.global_router_replay_instances + num_instances = len(instances) + + if num_layers_in_data == num_instances: + RouterReplay.set_replay_data(per_layer_data) + else: + # Dense-layer mismatch: map each MoE router to its global layer index. + # Prefer the patched layer_number; fall back to offset-based mapping + # (assumes dense layers precede MoE layers). 
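+        # Illustrative mapping (hypothetical sizes): with 1 dense + 3 MoE
+        # layers, num_layers_in_data=4 and num_instances=3; the router with
+        # layer_number=2 reads per_layer_data[1], and the offset fallback
+        # i + (4 - 3) selects the same index when layer_number is missing.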
+ for i, router_instance in enumerate(instances): + layer_number = getattr(router_instance, "layer_number", None) + if layer_number is not None: + layer_idx = layer_number - 1 # layer_number is 1-based + else: + layer_idx = i + (num_layers_in_data - num_instances) + if layer_idx < 0 or layer_idx >= num_layers_in_data: + raise ValueError( + f"Router replay layer index {layer_idx} out of range " + f"for data with {num_layers_in_data} layers " + f"({num_instances} router instances)" + ) + router_instance.set_target_indices(per_layer_data[layer_idx]) RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 540e9328e2..ec9717819e 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -87,6 +87,22 @@ def collection_func(logits, data): tp_grp = mpu.get_tensor_model_parallel_group() tp_rank = mpu.get_tensor_model_parallel_rank() + if tp_rank == 0 and mpu.get_data_parallel_rank() == 0: + import os + + if os.environ.get("SKYRL_DEBUG_LOGITS"): + print( + f"[DEBUG] logits shape={logits.shape}, " + f"mean={logits.float().mean().item():.4f}, " + f"std={logits.float().std().item():.4f}, " + f"min={logits.float().min().item():.4f}, " + f"max={logits.float().max().item():.4f}, " + f"sequences shape={sequences.shape}, " + f"attention_backend={getattr(get_model_config(self.actor_module[0]), 'attention_backend', 'unknown')}, " + f"multi_latent_attention={getattr(get_model_config(self.actor_module[0]), 'multi_latent_attention', 'unknown')}, " + f"q_lora_rank={getattr(get_model_config(self.actor_module[0]), 'q_lora_rank', 'unknown')}" + ) + if temperature != 1.0: logits.div_(temperature) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 3a78b8fb91..77c4b1327e 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -253,6 +253,39 @@ def extract_weights(self, dtype: torch.dtype): class MegatronWorker: + def debug_model_config(self): + """Return model config diagnostics for debugging logprob mismatch.""" + from skyrl.backends.skyrl_train.distributed.megatron.megatron_utils import get_model_config + + config = get_model_config(self.actor_module[0]) + diag = { + "attention_backend": str(getattr(config, "attention_backend", "unknown")), + "multi_latent_attention": getattr(config, "multi_latent_attention", "unknown"), + "q_lora_rank": getattr(config, "q_lora_rank", "unknown"), + "kv_lora_rank": getattr(config, "kv_lora_rank", "unknown"), + "qk_head_dim": getattr(config, "qk_head_dim", "unknown"), + "qk_pos_emb_head_dim": getattr(config, "qk_pos_emb_head_dim", "unknown"), + "v_head_dim": getattr(config, "v_head_dim", "unknown"), + "num_layers": getattr(config, "num_layers", "unknown"), + "hidden_size": getattr(config, "hidden_size", "unknown"), + "num_attention_heads": getattr(config, "num_attention_heads", "unknown"), + "rope_type": getattr(config, "rope_type", "unknown"), + "layernorm_epsilon": getattr(config, "layernorm_epsilon", "unknown"), + "sequence_parallel": getattr(config, "sequence_parallel", "unknown"), + } + weight_stats = {} + model = self.actor_module[0] + for name, param in model.named_parameters(): + if any(k in name for k in 
["layers.0.", "output_layer", "word_embeddings"]): + weight_stats[name] = { + "shape": list(param.shape), + "mean": param.float().mean().item(), + "std": param.float().std().item(), + "norm": param.float().norm().item(), + } + diag["weight_stats"] = weight_stats + return diag + def _read_router_replay_state(self): """Read the current RouterReplay state from all instances.""" from megatron.core.transformer.moe.router_replay import RouterReplay @@ -301,13 +334,11 @@ def init_configs( override_config_kwargs.update(model_config_kwargs.get("model_config", {})) update_model_config(hf_config, override_config_kwargs=override_config_kwargs) - # if flash_attn is enabled, we use flash attention backend, otherwise fall back to fused attention backend transformer_config_kwargs = ( transformer_config_kwargs if isinstance(transformer_config_kwargs, dict) else OmegaConf.to_container(transformer_config_kwargs, resolve=True) ) - transformer_config_kwargs["attention_backend"] = "flash" if flash_attn else "fused" if not self.cfg.gradient_checkpointing: for key in ("recompute_granularity", "recompute_method", "recompute_num_layers"): @@ -315,6 +346,18 @@ def init_configs( bridge = AutoBridge.from_hf_pretrained(model_path, trust_remote_code=True) provider = bridge.to_megatron_provider() + + # Determine attention backend. MLA (Multi-Latent Attention) with TE fused + # attention can produce NaN/incorrect results; fall back to unfused for MLA. + if "attention_backend" not in transformer_config_kwargs: + has_mla = getattr(provider, "multi_latent_attention", False) + if flash_attn: + transformer_config_kwargs["attention_backend"] = "flash" + elif has_mla: + transformer_config_kwargs["attention_backend"] = "unfused" + else: + transformer_config_kwargs["attention_backend"] = "fused" + provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size provider.pipeline_dtype = torch.bfloat16 if bf16 else torch.float32 @@ -322,7 +365,7 @@ def init_configs( provider.expert_model_parallel_size = megatron_config.expert_model_parallel_size provider.expert_tensor_parallel_size = megatron_config.expert_tensor_parallel_size provider.sequence_parallel = megatron_config.tensor_model_parallel_size > 1 - provider.attention_backend = "flash" if flash_attn else "fused" + provider.attention_backend = transformer_config_kwargs["attention_backend"] provider.variable_seq_lengths = True provider.masked_softmax_fusion = True # Apply explicit MoE config fields to the provider. 
@@ -558,6 +601,11 @@ def init_model(self, model_path, num_training_steps: int = 1e9): flash_attn=self.cfg.flash_attn, ) + if self.enable_router_replay: + from skyrl.backends.skyrl_train.utils.replay_utils import _patch_topk_router_layer_number + + _patch_topk_router_layer_number() + # wrap with DDP for training self.actor_module = self.make_megatron_module( wrap_with_ddp=True, @@ -860,6 +908,11 @@ def init_model(self, model_path, num_training_steps: int = 1e9): flash_attn=self.cfg.flash_attn, ) + if self.enable_router_replay: + from skyrl.backends.skyrl_train.utils.replay_utils import _patch_topk_router_layer_number + + _patch_topk_router_layer_number() + self.actor_module = self.make_megatron_module( wrap_with_ddp=False, ddp_config=None, diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index 3008b615c0..d06f16b51f 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -801,7 +801,7 @@ def run_p2p_access_check(): if device_count < 2: return False - # Check P2P access between all GPU pairs + # # Check P2P access between all GPU pairs for i in range(device_count): for j in range(device_count): if i != j: diff --git a/skyrl/utils/tok.py b/skyrl/utils/tok.py index 1006830500..b328ee8334 100644 --- a/skyrl/utils/tok.py +++ b/skyrl/utils/tok.py @@ -7,6 +7,7 @@ def get_tokenizer(model_name_or_path, **tokenizer_kwargs) -> AutoTokenizer: """Gets tokenizer for the given base model with the given parameters Sets the pad token ID to EOS token ID if `None`""" + tokenizer_kwargs.setdefault("trust_remote_code", True) tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **tokenizer_kwargs) if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py index 3a0bfe532c..1394e2b43c 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py @@ -33,6 +33,8 @@ def ray_init_fixture(): # needed for megatron tests env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" env_vars["NVTE_FUSED_ATTN"] = "0" + env_vars["RAY_CGRAPH_get_timeout"] = "600" + env_vars["SKYRL_DEBUG_LOGITS"] = "1" if SKYRL_PYTHONPATH_EXPORT: pythonpath = os.environ.get("PYTHONPATH") @@ -40,6 +42,10 @@ def ray_init_fixture(): raise RuntimeError("SKYRL_PYTHONPATH_EXPORT is set but PYTHONPATH is not defined in environment") env_vars["PYTHONPATH"] = pythonpath + # RAY_CGRAPH_get_timeout must be set in os.environ so that vLLM subprocesses + # (EngineCore) inherit it — runtime_env alone doesn't propagate to subprocesses. 
+ os.environ["RAY_CGRAPH_get_timeout"] = env_vars.pop("RAY_CGRAPH_get_timeout") + logger.info(f"Initializing Ray with environment variables: {env_vars}") ray.init(runtime_env={"env_vars": env_vars}) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py index 6e31b5815b..55e30ddb6a 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py @@ -33,7 +33,8 @@ # TODO (erictang000): we would prefer to use this smaller MoE model for testing, but seeing incorrect logprobs when using EP > 1 # this might be a model specific mbridge issue - see if this persists when we transition to Megatron-Bridge # MOE_MODEL_NAME = "Qwen/Qwen1.5-MoE-A2.7B" -MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" +# MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" +MOE_MODEL_NAME = "/home/ray/moonlight16b" def get_test_actor_config(model_name=MODEL_NAME) -> SkyRLTrainConfig: @@ -228,10 +229,10 @@ async def test_megatron_forward( cfg.trainer.use_sample_packing = use_sample_packing batch = get_test_training_batch(max(4, gpus_per_node)) - if ep > 1: - if cfg.trainer.policy.megatron_config.transformer_config_kwargs is None: - cfg.trainer.policy.megatron_config.transformer_config_kwargs = dict() - cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers"] = 2 + # if ep > 1: + # if cfg.trainer.policy.megatron_config.transformer_config_kwargs is None: + # cfg.trainer.policy.megatron_config.transformer_config_kwargs = dict() + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers"] = 2 if lora: cfg.trainer.policy.model.lora = SkyRLLoraConfig(rank=16, alpha=16) @@ -246,6 +247,9 @@ async def test_megatron_forward( action_log_probs_refs = actor_group.async_run_ray_method("mesh", "forward", data=batch) all_rank_action_log_probs = ray.get(action_log_probs_refs) + print( + f"Megatron logprobs - mean: {all_rank_action_log_probs.mean().item():.6f}, std: {all_rank_action_log_probs.std().item():.6f}" + ) action_log_probs_megatron = concatenate_outputs_after_mesh_dispatch( actor_group.actor_infos, all_rank_action_log_probs )["output"] @@ -258,8 +262,8 @@ async def test_megatron_forward( @ray.remote(num_gpus=1) def run_hf_forward(batch, model_name): config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, dtype=torch.bfloat16) - if ep > 1: - config.num_hidden_layers = 2 + # if ep > 1: + # config.num_hidden_layers = 2 model = AutoModelForCausalLM.from_pretrained(model_name, config=config, dtype=torch.bfloat16) model.eval() model.to("cuda") diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index f113240443..6ab38cd31c 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -24,8 +24,11 @@ from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch from skyrl.backends.skyrl_train.utils.replay_utils import _split_replay_indices -MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" +MOE_MODEL_NAME = "/home/ray/moonlight16b" +# MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" REPLAY_NUM_LAYERS = 2 +NUM_PROMPTS = 10 +N_SAMPLES_PER_PROMPT = 5 def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: @@ -35,6 +38,20 @@ def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: cfg.trainer.micro_train_batch_size_per_gpu = 2 cfg.trainer.use_sample_packing = False 
cfg.trainer.logger = "console" + if "moonlight" in model_name: + # flash attn not supported for moonlight16b + cfg.trainer.policy.megatron_config.moe_token_dispatcher_type = "alltoall" + cfg.trainer.policy.megatron_config.moe_router_load_balancing_type = "seq_aux_loss" + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_aux_loss_coeff"] = 0 + cfg.trainer.policy.megatron_config.moe_router_score_function = "sigmoid" + cfg.trainer.policy.megatron_config.moe_router_enable_expert_bias = True + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_bias_update_rate"] = 0 + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_dtype"] = "fp32" + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_topk"] = 6 + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_pre_softmax"] = True + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_group_topk"] = 1 + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_num_groups"] = 1 + cfg.trainer.flash_attn = False validate_cfg(cfg) return cfg @@ -51,7 +68,7 @@ def test_megatron_router_replay(ray_init_fixture): cfg.generator.inference_engine.enable_return_routed_experts = True cfg.generator.inference_engine.tensor_parallel_size = 2 cfg.generator.sampling_params = SamplingParams( - max_generate_length=1024, + max_generate_length=128, logprobs=1, temperature=1.0, ) @@ -95,7 +112,7 @@ def test_megatron_router_replay(ray_init_fixture): temperature=1.0, top_p=1.0, top_k=-1, - max_generate_length=16, + max_generate_length=128, min_p=0.0, logprobs=1, ), @@ -211,7 +228,7 @@ def test_megatron_router_replay(ray_init_fixture): @pytest.mark.megatron -def test_megatron_router_replay_logprobs(ray_init_fixture): +def test_logprobs(ray_init_fixture): """ Check that logprob diff is lower when using router replay. Requires full 8xH100 setup to do full forward pass. 
""" @@ -221,7 +238,7 @@ def test_megatron_router_replay_logprobs(ray_init_fixture): cfg.generator.inference_engine.enable_return_routed_experts = True cfg.generator.inference_engine.tensor_parallel_size = 8 cfg.generator.sampling_params = SamplingParams( - max_generate_length=1024, + max_generate_length=128, logprobs=1, temperature=1.0, ) @@ -251,8 +268,8 @@ def test_megatron_router_replay_logprobs(ray_init_fixture): input_batch: GeneratorInput = get_test_generator_input( model=MOE_MODEL_NAME, - num_prompts=4, - n_samples_per_prompt=1, + num_prompts=NUM_PROMPTS, + n_samples_per_prompt=N_SAMPLES_PER_PROMPT, max_prompt_length=512, env_class="gsm8k", ) @@ -262,7 +279,7 @@ def test_megatron_router_replay_logprobs(ray_init_fixture): temperature=1.0, top_p=1.0, top_k=-1, - max_generate_length=16, + max_generate_length=128, min_p=0.0, logprobs=1, ), @@ -329,7 +346,11 @@ def test_megatron_router_replay_logprobs(ray_init_fixture): cfg.trainer.micro_forward_batch_size_per_gpu = 1 cfg.trainer.micro_train_batch_size_per_gpu = 1 - def run_megatron_forward(enable_replay: bool) -> torch.Tensor: + import os + + os.environ["SKYRL_DEBUG_LOGITS"] = "1" + + def run_megatron_forward(enable_replay: bool, debug: bool = False) -> torch.Tensor: cfg.trainer.policy.megatron_config.transformer_config_kwargs = { "moe_enable_routing_replay": enable_replay, } @@ -340,6 +361,19 @@ def run_megatron_forward(enable_replay: bool) -> torch.Tensor: num_gpus_per_node=8, cfg=cfg, ) + + if debug: + diag = ray.get(actor_group.async_run_ray_method("pass_through", "debug_model_config"))[0] + print(f"\n=== Model Config (replay={enable_replay}) ===") + for k, v in diag.items(): + if k != "weight_stats": + print(f" {k}: {v}") + print(f" weight_stats ({len(diag['weight_stats'])} params):") + for name, stats in sorted(diag["weight_stats"].items()): + print( + f" {name}: shape={stats['shape']}, mean={stats['mean']:.6f}, std={stats['std']:.6f}, norm={stats['norm']:.2f}" + ) + refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) results = ray.get(refs) outputs = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, results)["output"] @@ -348,11 +382,14 @@ def run_megatron_forward(enable_replay: bool) -> torch.Tensor: ray.kill(actor) return outputs - r3_logprobs = run_megatron_forward(enable_replay=True) + r3_logprobs = run_megatron_forward(enable_replay=True, debug=True) no_r3_logprobs = run_megatron_forward(enable_replay=False) r3_diff = (logprobs_t - r3_logprobs).abs() no_r3_diff = (logprobs_t - no_r3_logprobs).abs() + print(f"vLLM logprobs - mean: {logprobs_t.mean().item():.6f}, std: {logprobs_t.std().item():.6f}") + print(f"Megatron (replay) - mean: {r3_logprobs.mean().item():.6f}, std: {r3_logprobs.std().item():.6f}") + print(f"Megatron (no rep) - mean: {no_r3_logprobs.mean().item():.6f}, std: {no_r3_logprobs.std().item():.6f}") print(f"With replay - logprob diff mean: {r3_diff.mean().item():.6f}, std: {r3_diff.std().item():.6f}") print(f"Without replay - logprob diff mean: {no_r3_diff.mean().item():.6f}, std: {no_r3_diff.std().item():.6f}") @@ -365,7 +402,7 @@ def run_megatron_forward(enable_replay: bool) -> torch.Tensor: @pytest.mark.megatron -def test_megatron_router_replay_forward_backward(ray_init_fixture): +def test_forward_backward(ray_init_fixture): """ Check that forward_backward produces similar losses with and without router replay (same weights, so routing decisions should nearly match). 
@@ -377,7 +414,7 @@ def test_megatron_router_replay_forward_backward(ray_init_fixture): cfg.generator.inference_engine.enable_return_routed_experts = True cfg.generator.inference_engine.tensor_parallel_size = 8 cfg.generator.sampling_params = SamplingParams( - max_generate_length=1024, + max_generate_length=128, logprobs=1, temperature=1.0, ) @@ -407,8 +444,8 @@ def test_megatron_router_replay_forward_backward(ray_init_fixture): input_batch: GeneratorInput = get_test_generator_input( model=MOE_MODEL_NAME, - num_prompts=10, - n_samples_per_prompt=5, + num_prompts=NUM_PROMPTS, + n_samples_per_prompt=N_SAMPLES_PER_PROMPT, max_prompt_length=512, env_class="gsm8k", ) @@ -418,7 +455,7 @@ def test_megatron_router_replay_forward_backward(ray_init_fixture): temperature=1.0, top_p=1.0, top_k=-1, - max_generate_length=16, + max_generate_length=128, min_p=0.0, logprobs=1, ), From 93eee6545070b432f1b0ff729420e86f41617a5b Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Fri, 6 Mar 2026 02:25:05 +0000 Subject: [PATCH 10/31] x --- .../workers/megatron/megatron_model_wrapper.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index ec9717819e..540e9328e2 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -87,22 +87,6 @@ def collection_func(logits, data): tp_grp = mpu.get_tensor_model_parallel_group() tp_rank = mpu.get_tensor_model_parallel_rank() - if tp_rank == 0 and mpu.get_data_parallel_rank() == 0: - import os - - if os.environ.get("SKYRL_DEBUG_LOGITS"): - print( - f"[DEBUG] logits shape={logits.shape}, " - f"mean={logits.float().mean().item():.4f}, " - f"std={logits.float().std().item():.4f}, " - f"min={logits.float().min().item():.4f}, " - f"max={logits.float().max().item():.4f}, " - f"sequences shape={sequences.shape}, " - f"attention_backend={getattr(get_model_config(self.actor_module[0]), 'attention_backend', 'unknown')}, " - f"multi_latent_attention={getattr(get_model_config(self.actor_module[0]), 'multi_latent_attention', 'unknown')}, " - f"q_lora_rank={getattr(get_model_config(self.actor_module[0]), 'q_lora_rank', 'unknown')}" - ) - if temperature != 1.0: logits.div_(temperature) From 9c716a16f251369dc7b2a728570283e00988f826 Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Mon, 9 Mar 2026 19:59:35 +0000 Subject: [PATCH 11/31] fixed test for moonlight by enforcing fused attn --- .../workers/megatron/megatron_model_wrapper.py | 16 ---------------- .../workers/megatron/megatron_worker.py | 6 +++--- 2 files changed, 3 insertions(+), 19 deletions(-) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index ec9717819e..540e9328e2 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -87,22 +87,6 @@ def collection_func(logits, data): tp_grp = mpu.get_tensor_model_parallel_group() tp_rank = mpu.get_tensor_model_parallel_rank() - if tp_rank == 0 and mpu.get_data_parallel_rank() == 0: - import os - - if os.environ.get("SKYRL_DEBUG_LOGITS"): - print( - f"[DEBUG] logits shape={logits.shape}, " - f"mean={logits.float().mean().item():.4f}, " - f"std={logits.float().std().item():.4f}, " - f"min={logits.float().min().item():.4f}, " - 
f"max={logits.float().max().item():.4f}, " - f"sequences shape={sequences.shape}, " - f"attention_backend={getattr(get_model_config(self.actor_module[0]), 'attention_backend', 'unknown')}, " - f"multi_latent_attention={getattr(get_model_config(self.actor_module[0]), 'multi_latent_attention', 'unknown')}, " - f"q_lora_rank={getattr(get_model_config(self.actor_module[0]), 'q_lora_rank', 'unknown')}" - ) - if temperature != 1.0: logits.div_(temperature) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 77c4b1327e..1643c4a86d 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -350,11 +350,11 @@ def init_configs( # Determine attention backend. MLA (Multi-Latent Attention) with TE fused # attention can produce NaN/incorrect results; fall back to unfused for MLA. if "attention_backend" not in transformer_config_kwargs: - has_mla = getattr(provider, "multi_latent_attention", False) + # has_mla = getattr(provider, "multi_latent_attention", False) if flash_attn: transformer_config_kwargs["attention_backend"] = "flash" - elif has_mla: - transformer_config_kwargs["attention_backend"] = "unfused" + # elif has_mla: + # transformer_config_kwargs["attention_backend"] = "unfused" else: transformer_config_kwargs["attention_backend"] = "fused" From 6de7d5c9e3ff6cbe524e18e59d804238de2ca7b6 Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Mon, 9 Mar 2026 20:01:34 +0000 Subject: [PATCH 12/31] x --- .../skyrl_train/gpu/gpu_ci/conftest.py | 1 - .../gpu/gpu_ci/test_router_replay.py | 205 +++++++++++++++--- 2 files changed, 169 insertions(+), 37 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py index 1394e2b43c..9ea416dcee 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py @@ -34,7 +34,6 @@ def ray_init_fixture(): env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" env_vars["NVTE_FUSED_ATTN"] = "0" env_vars["RAY_CGRAPH_get_timeout"] = "600" - env_vars["SKYRL_DEBUG_LOGITS"] = "1" if SKYRL_PYTHONPATH_EXPORT: pythonpath = os.environ.get("PYTHONPATH") diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 6ab38cd31c..477af7d64a 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -28,7 +28,8 @@ # MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" REPLAY_NUM_LAYERS = 2 NUM_PROMPTS = 10 -N_SAMPLES_PER_PROMPT = 5 +N_SAMPLES_PER_PROMPT = 4 +MAX_GENERATE_LENGTH = 1024 def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: @@ -36,26 +37,84 @@ def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: cfg.trainer.policy.model.path = model_name cfg.trainer.micro_forward_batch_size_per_gpu = 2 cfg.trainer.micro_train_batch_size_per_gpu = 2 - cfg.trainer.use_sample_packing = False + cfg.trainer.use_sample_packing = True + # flash attn + mla works without sample packing, logprobs are crazy/wrong + # but flash-attn correctly throws error with sample packing + # we should add an assert that if you set use_sample_packing=False flash attn can accidentally be used cfg.trainer.logger = "console" if "moonlight" in model_name: + if cfg.trainer.policy.megatron_config.transformer_config_kwargs is None: + 
cfg.trainer.policy.megatron_config.transformer_config_kwargs = {} # flash attn not supported for moonlight16b - cfg.trainer.policy.megatron_config.moe_token_dispatcher_type = "alltoall" - cfg.trainer.policy.megatron_config.moe_router_load_balancing_type = "seq_aux_loss" - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_aux_loss_coeff"] = 0 - cfg.trainer.policy.megatron_config.moe_router_score_function = "sigmoid" - cfg.trainer.policy.megatron_config.moe_router_enable_expert_bias = True - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_bias_update_rate"] = 0 - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_dtype"] = "fp32" - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_topk"] = 6 - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_pre_softmax"] = True - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_group_topk"] = 1 - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_num_groups"] = 1 + # cfg.trainer.policy.megatron_config.moe_token_dispatcher_type = "alltoall" + # cfg.trainer.policy.megatron_config.moe_router_load_balancing_type = "seq_aux_loss" + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_aux_loss_coeff"] = 0 + # cfg.trainer.policy.megatron_config.moe_router_score_function = "sigmoid" + # cfg.trainer.policy.megatron_config.moe_router_enable_expert_bias = True + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_bias_update_rate"] = 0 + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_dtype"] = "fp32" + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_topk"] = 6 + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_pre_softmax"] = True + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_group_topk"] = 1 + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_num_groups"] = 1 + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers_in_last_pipeline_stage"] = 13 cfg.trainer.flash_attn = False validate_cfg(cfg) return cfg +def build_training_input_from_text_samples( + tokenizer: AutoTokenizer, prompt_response_pairs: list[tuple[str, str]] +) -> TrainingInputBatch: + prompts = [] + responses = [] + rewards = [] + loss_masks = [] + + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + for prompt_text, response_text in prompt_response_pairs: + prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False) + response_ids = tokenizer.encode(response_text, add_special_tokens=False) + if tokenizer.eos_token_id is not None and (not response_ids or response_ids[-1] != tokenizer.eos_token_id): + response_ids.append(tokenizer.eos_token_id) + + prompts.append(prompt_ids) + responses.append(response_ids) + rewards.append([0.0] * len(response_ids)) + loss_masks.append([1] * len(response_ids)) + + sequences, attention_mask, response_mask, rewards_t, loss_mask_t, _, _ = ( + convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=prompts, + responses=responses, + rewards=rewards, + loss_masks=loss_masks, + ) + ) + + num_actions = response_mask.shape[1] + batch_size = sequences.shape[0] + training_input = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_mask, + "response_mask": response_mask, + "rewards": rewards_t, + "loss_mask": loss_mask_t, + "rollout_logprobs": 
torch.zeros((batch_size, num_actions), dtype=torch.float32), + "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "action_mask": response_mask.to(dtype=torch.int64), + } + ) + training_input.metadata = {"response_length": num_actions} + return training_input + + @pytest.mark.megatron def test_megatron_router_replay(ray_init_fixture): """ @@ -227,6 +286,96 @@ def test_megatron_router_replay(ray_init_fixture): ray.shutdown() +@pytest.mark.megatron +def test_moonlight_logprobs(ray_init_fixture): + """ + Check magnitude of moonlight-16b-a3b logprobs without router replay. + """ + actor_group = None + try: + cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) + cfg.trainer.strategy = "megatron" + + tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) + input_batch: GeneratorInput = get_test_generator_input( + model=MOE_MODEL_NAME, + num_prompts=NUM_PROMPTS, + n_samples_per_prompt=N_SAMPLES_PER_PROMPT, + max_prompt_length=512, + env_class="gsm8k", + ) + training_input = build_training_input_from_text_samples( + tokenizer=tokenizer, + prompt_response_pairs=[ + ( + tokenizer.apply_chat_template( + prompt + if any(message["role"] == "system" for message in prompt) + else [{"role": "system", "content": "You are a helpful assistant."}] + prompt, + add_generation_prompt=True, + tokenize=False, + ), + " " + + ( + " ".join( + next( + ( + message["content"] + for message in reversed(prompt) + if message["role"] == "user" + ), + "", + ).split()[:16] + ).strip() + or "I will solve this step by step." + ), + ) + for prompt in input_batch["prompts"] + ], + ) + + cfg.trainer.placement.policy_num_gpus_per_node = 8 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.context_parallel_size = 1 + cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 + cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 + cfg.trainer.micro_forward_batch_size_per_gpu = 1 + cfg.trainer.micro_train_batch_size_per_gpu = 1 + cfg.trainer.policy.megatron_config.transformer_config_kwargs = { + **(cfg.trainer.policy.megatron_config.transformer_config_kwargs or {}), + "moe_enable_routing_replay": False, + } + + actor_group = init_worker_with_type( + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=8, + cfg=cfg, + ) + + with Timer("moonlight_forward"): + refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) + results = ray.get(refs) + + action_log_probs = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, results)["output"] + valid_log_probs = action_log_probs[training_input["response_mask"].bool()] + avg_logprob = valid_log_probs.mean().item() + print( + f"Moonlight Megatron logprobs - mean: {avg_logprob:.6f}, std: {valid_log_probs.std().item():.6f}, " + f"num_tokens: {valid_log_probs.numel()}" + ) + + assert valid_log_probs.numel() > 0, "Expected at least one valid response token" + assert torch.isfinite(valid_log_probs).all().item(), "Expected all response logprobs to be finite" + assert -20.0 < avg_logprob < -0.01, f"Unexpected average logprob magnitude: {avg_logprob:.6f}" + finally: + if actor_group is not None: + for actor in actor_group._actor_handlers: + ray.kill(actor) + ray.shutdown() + 
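
For calibrating the bounds asserted above: a mean token logprob of -2.0 corresponds to a perplexity of exp(2.0) ≈ 7.4 on the model's own sampled continuation, so an average at or above -0.01 or below -20.0 is treated as a sign of a broken forward pass rather than of model quality. A quick conversion sketch:

    import math

    def perplexity_from_mean_logprob(mean_logprob: float) -> float:
        # Perplexity is exp(-mean per-token log-likelihood).
        return math.exp(-mean_logprob)

    assert round(perplexity_from_mean_logprob(-2.0), 1) == 7.4
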
@pytest.mark.megatron def test_logprobs(ray_init_fixture): """ @@ -238,7 +387,7 @@ def test_logprobs(ray_init_fixture): cfg.generator.inference_engine.enable_return_routed_experts = True cfg.generator.inference_engine.tensor_parallel_size = 8 cfg.generator.sampling_params = SamplingParams( - max_generate_length=128, + max_generate_length=MAX_GENERATE_LENGTH, logprobs=1, temperature=1.0, ) @@ -279,7 +428,7 @@ def test_logprobs(ray_init_fixture): temperature=1.0, top_p=1.0, top_k=-1, - max_generate_length=128, + max_generate_length=MAX_GENERATE_LENGTH, min_p=0.0, logprobs=1, ), @@ -338,7 +487,7 @@ def test_logprobs(ray_init_fixture): training_input.metadata = {"response_length": num_actions} cfg.trainer.placement.policy_num_gpus_per_node = 8 - cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 4 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 cfg.trainer.policy.megatron_config.context_parallel_size = 1 cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 @@ -346,11 +495,7 @@ def test_logprobs(ray_init_fixture): cfg.trainer.micro_forward_batch_size_per_gpu = 1 cfg.trainer.micro_train_batch_size_per_gpu = 1 - import os - - os.environ["SKYRL_DEBUG_LOGITS"] = "1" - - def run_megatron_forward(enable_replay: bool, debug: bool = False) -> torch.Tensor: + def run_megatron_forward(enable_replay: bool) -> torch.Tensor: cfg.trainer.policy.megatron_config.transformer_config_kwargs = { "moe_enable_routing_replay": enable_replay, } @@ -362,18 +507,6 @@ def run_megatron_forward(enable_replay: bool, debug: bool = False) -> torch.Tens cfg=cfg, ) - if debug: - diag = ray.get(actor_group.async_run_ray_method("pass_through", "debug_model_config"))[0] - print(f"\n=== Model Config (replay={enable_replay}) ===") - for k, v in diag.items(): - if k != "weight_stats": - print(f" {k}: {v}") - print(f" weight_stats ({len(diag['weight_stats'])} params):") - for name, stats in sorted(diag["weight_stats"].items()): - print( - f" {name}: shape={stats['shape']}, mean={stats['mean']:.6f}, std={stats['std']:.6f}, norm={stats['norm']:.2f}" - ) - refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) results = ray.get(refs) outputs = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, results)["output"] @@ -382,7 +515,7 @@ def run_megatron_forward(enable_replay: bool, debug: bool = False) -> torch.Tens ray.kill(actor) return outputs - r3_logprobs = run_megatron_forward(enable_replay=True, debug=True) + r3_logprobs = run_megatron_forward(enable_replay=True) no_r3_logprobs = run_megatron_forward(enable_replay=False) r3_diff = (logprobs_t - r3_logprobs).abs() @@ -414,7 +547,7 @@ def test_forward_backward(ray_init_fixture): cfg.generator.inference_engine.enable_return_routed_experts = True cfg.generator.inference_engine.tensor_parallel_size = 8 cfg.generator.sampling_params = SamplingParams( - max_generate_length=128, + max_generate_length=MAX_GENERATE_LENGTH, logprobs=1, temperature=1.0, ) @@ -455,7 +588,7 @@ def test_forward_backward(ray_init_fixture): temperature=1.0, top_p=1.0, top_k=-1, - max_generate_length=128, + max_generate_length=MAX_GENERATE_LENGTH, min_p=0.0, logprobs=1, ), From 591af9beee0450dde9202fda45ef9a234c876df4 Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Mon, 9 Mar 2026 21:34:52 +0000 Subject: [PATCH 13/31] x --- .../skyrl_train/gpu/gpu_ci/conftest.py | 1 - .../gpu/gpu_ci/test_router_replay.py | 174 +----------------- 2 files changed, 1 insertion(+), 174 
deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py index 9ea416dcee..297672c737 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py @@ -33,7 +33,6 @@ def ray_init_fixture(): # needed for megatron tests env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" env_vars["NVTE_FUSED_ATTN"] = "0" - env_vars["RAY_CGRAPH_get_timeout"] = "600" if SKYRL_PYTHONPATH_EXPORT: pythonpath = os.environ.get("PYTHONPATH") diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 477af7d64a..1504d04f96 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -29,7 +29,7 @@ REPLAY_NUM_LAYERS = 2 NUM_PROMPTS = 10 N_SAMPLES_PER_PROMPT = 4 -MAX_GENERATE_LENGTH = 1024 +MAX_GENERATE_LENGTH = 128 def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: @@ -114,178 +114,6 @@ def build_training_input_from_text_samples( training_input.metadata = {"response_length": num_actions} return training_input - -@pytest.mark.megatron -def test_megatron_router_replay(ray_init_fixture): - """ - Test that SkyRLGymGenerator returns rollout_inference_indices - for MoE models with enable_return_routed_experts=True. - """ - try: - cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) - cfg.trainer.strategy = "megatron" - cfg.generator.inference_engine.enable_return_routed_experts = True - cfg.generator.inference_engine.tensor_parallel_size = 2 - cfg.generator.sampling_params = SamplingParams( - max_generate_length=128, - logprobs=1, - temperature=1.0, - ) - cfg.generator.batched = False - cfg.generator.max_turns = 1 - - num_prompts = 1 - - tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) - - with InferenceEngineState.create( - cfg=cfg, - model=MOE_MODEL_NAME, - use_local=True, - backend="vllm", - sleep_level=1, - gpu_memory_utilization=0.9, - max_num_seqs=1, - ) as engines: - client = engines.client - - asyncio.run(client.wake_up()) - - generator = SkyRLGymGenerator( - generator_cfg=cfg.generator, - skyrl_gym_cfg=cfg.environment.skyrl_gym, - inference_engine_client=client, - tokenizer=tokenizer, - ) - - input_batch: GeneratorInput = get_test_generator_input( - model=MOE_MODEL_NAME, - num_prompts=num_prompts, - n_samples_per_prompt=1, - max_prompt_length=512, - env_class="gsm8k", - ) - input_batch["sampling_params"] = get_sampling_params_for_backend( - "vllm", - SamplingParams( - temperature=1.0, - top_p=1.0, - top_k=-1, - max_generate_length=128, - min_p=0.0, - logprobs=1, - ), - ) - - with Timer("generate_with_router_replay"): - generator_output = asyncio.run(generator.generate(input_batch)) - - # --- Basic output checks --- - assert ( - "rollout_inference_indices" in generator_output - ), "rollout_inference_indices missing from GeneratorOutput" - indices = generator_output["rollout_inference_indices"] - assert ( - indices is not None - ), "rollout_inference_indices should not be None when enable_return_routed_experts=True" - - responses = generator_output["response_ids"] - assert len(indices) == len( - responses - ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" - - rewards = generator_output["rewards"] - if rewards and not isinstance(rewards[0], list): - rewards = [[r] * len(resp) for r, resp in zip(rewards, responses)] - (sequences, attention_mask, 
response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = ( - convert_prompts_responses_to_batch_tensors( - tokenizer=tokenizer, - prompts=generator_output["prompt_token_ids"], - responses=responses, - rewards=rewards, - loss_masks=generator_output["loss_masks"], - logprobs=generator_output.get("rollout_logprobs"), - rollout_inference_indices=indices, - ) - ) - - assert rii_tensor is not None - rii_tensor = rii_tensor[:, :, :REPLAY_NUM_LAYERS, :] - - num_actions = response_mask.shape[1] - batch_size = sequences.shape[0] - training_input = TrainingInputBatch( - { - "sequences": sequences, - "attention_mask": attention_mask, - "response_mask": response_mask, - "rewards": rewards_t, - "loss_mask": loss_mask_t, - "rollout_logprobs": ( - logprobs_t - if logprobs_t is not None - else torch.zeros((batch_size, num_actions), dtype=torch.float32) - ), - "rollout_inference_indices": rii_tensor, - "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), - "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), - "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), - "action_mask": response_mask.to(dtype=torch.int64), - } - ) - training_input.metadata = {"response_length": num_actions} - - cfg.trainer.policy.megatron_config.transformer_config_kwargs = { - "num_layers": REPLAY_NUM_LAYERS, - "moe_enable_routing_replay": True, - } - cfg.trainer.placement.policy_num_gpus_per_node = 2 - cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 - cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 - cfg.trainer.policy.megatron_config.context_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_model_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 - cfg.trainer.micro_forward_batch_size_per_gpu = 1 - cfg.trainer.micro_train_batch_size_per_gpu = 1 - - actor_group = init_worker_with_type( - "policy", - shared_pg=None, - colocate_all=False, - num_gpus_per_node=2, - cfg=cfg, - ) - - expected_per_layer = _split_replay_indices(rii_tensor.to(torch.long)) - - state = ray.get( - actor_group.async_run_ray_method( - "pass_through", - "debug_setup_router_replay_state", - data=training_input, - ) - )[0] - - assert state is not None, "Worker returned None state" - assert ( - "REPLAY_FORWARD" in state["action"] - ), f"RouterReplay action should be REPLAY_FORWARD, got: {state['action']}" - assert state["num_instances"] == len(expected_per_layer), ( - f"Expected {len(expected_per_layer)} replay instances (one per layer), " f"got {state['num_instances']}" - ) - for layer_idx, (got, expected) in enumerate(zip(state["target_indices"], expected_per_layer)): - assert torch.equal( - got.to(torch.long), expected.to(torch.long) - ), f"Layer {layer_idx}: Megatron target indices differ from vLLM indices" - print( - f"PASSED: vLLM routing indices ({rii_tensor.shape}) correctly " - f"loaded into {state['num_instances']} Megatron RouterReplay instances" - ) - - finally: - ray.shutdown() - - @pytest.mark.megatron def test_moonlight_logprobs(ray_init_fixture): """ From acb35ec2692afd2684b079787fb274257a7aab42 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Tue, 10 Mar 2026 22:33:37 +0000 Subject: [PATCH 14/31] clean up --- .../inference_engines/vllm/vllm_engine.py | 2 +- .../skyrl_train/utils/replay_utils.py | 44 +------ .../megatron/megatron_model_wrapper.py | 6 +- .../workers/megatron/megatron_worker.py | 119 ++++-------------- .../skyrl_train/workers/worker_utils.py | 1 + 
 skyrl/train/dataset/replay_buffer.py          |   5 +
 skyrl/train/utils/utils.py                    |   2 +-
 .../skyrl_train/gpu/gpu_ci/conftest.py        |   6 +-
 .../gpu/gpu_ci/test_router_replay.py          | 104 +--------------
 9 files changed, 42 insertions(+), 247 deletions(-)

diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
index 048fba7ec8..11541786de 100644
--- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
+++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -168,7 +168,7 @@ def _postprocess_outputs(self, outputs):
     if len(response_logprobs) and response_logprobs[0] is None:
         response_logprobs = None  # hack: assume uniform sampling params
 
-    if len(rollout_inference_indices) == 0 and _routed_experts is None:
+    if len(rollout_inference_indices) == 0 or rollout_inference_indices[0] is None:
         rollout_inference_indices = None  # hack: assume uniform sampling params
 
     return InferenceEngineOutput(
diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py
index 68d266e052..cf016f7e46 100644
--- a/skyrl/backends/skyrl_train/utils/replay_utils.py
+++ b/skyrl/backends/skyrl_train/utils/replay_utils.py
@@ -4,7 +4,6 @@
 
 import torch
 from typing import List
-from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
 
 
 def _patch_topk_router_layer_number():
@@ -46,6 +45,8 @@ def _patch_alltoall_dispatcher_for_replay():
     routing_map.sum() < num_tokens * topk, leading to a split size mismatch in
     the alltoall collective. We fix this by deriving num_out_tokens from the
     routing map instead of the static num_tokens * topk formula.
+
+    Reference: https://github.com/verl-project/verl/pull/4986
     """
     try:
         from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
@@ -119,7 +120,7 @@ def _remove_left_padding_from_indices(
     return new_rii
 
 
-def _setup_per_microbatch_replay(
+def setup_per_microbatch_replay(
     rollout_inference_indices: torch.Tensor,
     attention_mask: torch.Tensor,
 ) -> None:
@@ -176,45 +177,6 @@ def _setup_per_microbatch_replay(
     RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
 
 
-def setup_router_replay_forward(data: TrainingInputBatch, enable_router_replay: bool) -> bool:
-    """
-    Set up router replay for forward pass (ref/policy inference).
-    """
-    if not enable_router_replay:
-        return False
-
-    rollout_inference_indices = data.get("rollout_inference_indices")
-    if rollout_inference_indices is None:
-        return False
-
-    from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
-
-    RouterReplay.set_replay_data(_split_replay_indices(rollout_inference_indices))
-    RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
-
-    return True
-
-
-def setup_router_replay_backward(data: TrainingInputBatch, enable_router_replay: bool) -> bool:
-    """
-    Set up router replay for training forward/backward pass.
- """ - if not enable_router_replay: - return False - - rollout_inference_indices = data.get("rollout_inference_indices") - if rollout_inference_indices is None: - return False - - from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction - - RouterReplay.set_replay_data(_split_replay_indices(rollout_inference_indices)) - # Use REPLAY_FORWARD - Megatron handles REPLAY_BACKWARD automatically - RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) - - return True - - def clear_router_replay(): """Clear all router replay state.""" from megatron.core.transformer.moe.router_replay import RouterReplay diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 540e9328e2..a7f6ade581 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -24,7 +24,7 @@ remove_left_padding, recover_left_padding, ) -from skyrl.backends.skyrl_train.utils.replay_utils import _setup_per_microbatch_replay +from skyrl.backends.skyrl_train.utils.replay_utils import setup_per_microbatch_replay class MegatronModelWrapper: @@ -107,7 +107,7 @@ def forward_step(batch_iter, model): rollout_inference_indices = batch.pop("rollout_inference_indices", None) if rollout_inference_indices is not None: - _setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"]) + setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"]) sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) @@ -363,7 +363,7 @@ def forward_step(batch_iter, model): rollout_inference_indices = batch.pop("rollout_inference_indices", None) if rollout_inference_indices is not None: - _setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"]) + setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"]) sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 1643c4a86d..cd5343c5bb 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -33,7 +33,7 @@ from skyrl.train.utils.utils import update_model_config, str_to_torch_dtype from skyrl.backends.skyrl_train.env_vars import SKYRL_WORKER_NCCL_TIMEOUT_IN_S from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch -from skyrl.backends.skyrl_train.workers.worker_utils import reduce_metrics, all_reduce_metrics +from skyrl.backends.skyrl_train.workers.worker_utils import reduce_metrics, all_reduce_metrics, BatchIterator from skyrl.backends.skyrl_train.workers.worker import ( PolicyWorkerBase, RefWorkerBase, @@ -253,63 +253,6 @@ def extract_weights(self, dtype: torch.dtype): class MegatronWorker: - def debug_model_config(self): - """Return model config diagnostics for debugging logprob mismatch.""" - from skyrl.backends.skyrl_train.distributed.megatron.megatron_utils import get_model_config - - config = get_model_config(self.actor_module[0]) - diag = { - "attention_backend": str(getattr(config, "attention_backend", "unknown")), - "multi_latent_attention": getattr(config, "multi_latent_attention", "unknown"), - "q_lora_rank": getattr(config, "q_lora_rank", "unknown"), - 
"kv_lora_rank": getattr(config, "kv_lora_rank", "unknown"), - "qk_head_dim": getattr(config, "qk_head_dim", "unknown"), - "qk_pos_emb_head_dim": getattr(config, "qk_pos_emb_head_dim", "unknown"), - "v_head_dim": getattr(config, "v_head_dim", "unknown"), - "num_layers": getattr(config, "num_layers", "unknown"), - "hidden_size": getattr(config, "hidden_size", "unknown"), - "num_attention_heads": getattr(config, "num_attention_heads", "unknown"), - "rope_type": getattr(config, "rope_type", "unknown"), - "layernorm_epsilon": getattr(config, "layernorm_epsilon", "unknown"), - "sequence_parallel": getattr(config, "sequence_parallel", "unknown"), - } - weight_stats = {} - model = self.actor_module[0] - for name, param in model.named_parameters(): - if any(k in name for k in ["layers.0.", "output_layer", "word_embeddings"]): - weight_stats[name] = { - "shape": list(param.shape), - "mean": param.float().mean().item(), - "std": param.float().std().item(), - "norm": param.float().norm().item(), - } - diag["weight_stats"] = weight_stats - return diag - - def _read_router_replay_state(self): - """Read the current RouterReplay state from all instances.""" - from megatron.core.transformer.moe.router_replay import RouterReplay - - # See https://docs.nvidia.com/megatron-core/developer-guide/0.15.0/api-guide/router_replay.html docs for more info - instances = RouterReplay.global_router_replay_instances or [] - action = instances[0].router_replay_action if instances else None - - target_indices = [inst.target_topk_idx.detach().cpu() for inst in instances if inst.target_topk_idx is not None] - - return { - "action": str(action), - "target_indices": target_indices, - "num_instances": len(instances), - } - - def debug_setup_router_replay_state(self, data: TrainingInputBatch): - from skyrl.backends.skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay - - setup_router_replay_forward(data, enable_router_replay=True) - state = self._read_router_replay_state() - clear_router_replay() - return state - def init_configs( self, model_path, @@ -347,17 +290,6 @@ def init_configs( bridge = AutoBridge.from_hf_pretrained(model_path, trust_remote_code=True) provider = bridge.to_megatron_provider() - # Determine attention backend. MLA (Multi-Latent Attention) with TE fused - # attention can produce NaN/incorrect results; fall back to unfused for MLA. - if "attention_backend" not in transformer_config_kwargs: - # has_mla = getattr(provider, "multi_latent_attention", False) - if flash_attn: - transformer_config_kwargs["attention_backend"] = "flash" - # elif has_mla: - # transformer_config_kwargs["attention_backend"] = "unfused" - else: - transformer_config_kwargs["attention_backend"] = "fused" - provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size provider.pipeline_dtype = torch.bfloat16 if bf16 else torch.float32 @@ -365,7 +297,7 @@ def init_configs( provider.expert_model_parallel_size = megatron_config.expert_model_parallel_size provider.expert_tensor_parallel_size = megatron_config.expert_tensor_parallel_size provider.sequence_parallel = megatron_config.tensor_model_parallel_size > 1 - provider.attention_backend = transformer_config_kwargs["attention_backend"] + provider.attention_backend = "flash" if flash_attn else "fused" provider.variable_seq_lengths = True provider.masked_softmax_fusion = True # Apply explicit MoE config fields to the provider. 
@@ -510,7 +442,6 @@ def forward(self, data: TrainingInputBatch): log_probs = log_probs.to("cpu") output = TrainingOutputBatch({"output": log_probs}) output.metadata = data.metadata - self._last_router_replay_state = self._read_router_replay_state() clear_router_replay() return output @@ -685,31 +616,29 @@ def forward_backward( # Move data to GPU data.to(torch.cuda.current_device()) - # Chunk manually so we can propagate rollout_inference_indices for - # per-micro-batch router replay (BatchIterator/Experience don't carry them). + # Build micro-batch dicts expected by forward_backward_mini_batch micro_buffer = [] - for micro in data.chunk(micro_batch_size): - sequences = micro["sequences"] - attention_mask = micro["attention_mask"] + for experience in BatchIterator(data, micro_batch_size, drop_last=False): + sequences = experience.sequences + attention_mask = experience.attention_mask position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 0) - micro_dict = { - "sequences": sequences, - "attention_mask": attention_mask, - "position_ids": position_ids, - "num_actions": micro.metadata["response_length"], - "old_action_log_probs": micro.get("action_log_probs"), - "base_action_log_probs": micro.get("base_action_log_probs"), - "advantages": micro.get("advantages"), - "loss_mask": micro.get("loss_mask"), - "rollout_action_logprobs": micro.get("rollout_logprobs"), - "action_mask": micro.get("action_mask"), - } - rii = micro.get("rollout_inference_indices") - if rii is not None and self.enable_router_replay: - micro_dict["rollout_inference_indices"] = rii - micro_buffer.append(micro_dict) + micro_buffer.append( + { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": experience.num_actions, + "old_action_log_probs": experience.action_log_probs, + "base_action_log_probs": experience.base_action_log_probs, + "advantages": experience.advantages, + "loss_mask": experience.loss_mask, + "rollout_action_logprobs": experience.rollout_logprobs, + "action_mask": experience.action_mask, + "rollout_inference_indices": experience.rollout_inference_indices, + } + ) if not micro_buffer: return {} @@ -750,7 +679,6 @@ def forward_backward( if all_loss_fn_outputs: status["loss_fn_outputs"] = all_loss_fn_outputs - self._last_router_replay_state = self._read_router_replay_state() clear_router_replay() return status @@ -908,11 +836,6 @@ def init_model(self, model_path, num_training_steps: int = 1e9): flash_attn=self.cfg.flash_attn, ) - if self.enable_router_replay: - from skyrl.backends.skyrl_train.utils.replay_utils import _patch_topk_router_layer_number - - _patch_topk_router_layer_number() - self.actor_module = self.make_megatron_module( wrap_with_ddp=False, ddp_config=None, diff --git a/skyrl/backends/skyrl_train/workers/worker_utils.py b/skyrl/backends/skyrl_train/workers/worker_utils.py index eb76c5ee7d..312e3ee52c 100644 --- a/skyrl/backends/skyrl_train/workers/worker_utils.py +++ b/skyrl/backends/skyrl_train/workers/worker_utils.py @@ -83,6 +83,7 @@ def batch_to_experience(batch: TrainingInputBatch): action_mask=batch.get("response_mask"), num_actions=batch.metadata["response_length"], # int rollout_logprobs=batch.get("rollout_logprobs"), + rollout_inference_indices=batch.get("rollout_inference_indices"), # additional info # can be used to log metrics etc for micro-batches in the worker info={}, diff --git a/skyrl/train/dataset/replay_buffer.py b/skyrl/train/dataset/replay_buffer.py index 07846937a6..109f1957ec 
100644 --- a/skyrl/train/dataset/replay_buffer.py +++ b/skyrl/train/dataset/replay_buffer.py @@ -66,6 +66,7 @@ class Experience: loss_mask: Optional[Integer[torch.LongTensor, "batch response_len"]] action_mask: Optional[Integer[torch.Tensor, "batch response_len"]] rollout_logprobs: Optional[Float[torch.Tensor, "batch response_len"]] + rollout_inference_indices: Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]] num_actions: int info: Optional[dict] kl: Optional[Float[torch.Tensor, "batch response_len"]] = None @@ -92,6 +93,8 @@ def to_device(self, device: torch.device) -> None: self.action_mask = to(self.action_mask, device) if self.rollout_logprobs is not None: self.rollout_logprobs = to(self.rollout_logprobs, device) + if self.rollout_inference_indices is not None: + self.rollout_inference_indices = to(self.rollout_inference_indices, device) def pin_memory(self): self.sequences = pin_memory(self.sequences) @@ -113,6 +116,8 @@ def pin_memory(self): self.action_mask = self.action_mask.pin_memory() if self.rollout_logprobs is not None: self.rollout_logprobs = self.rollout_logprobs.pin_memory() + if self.rollout_inference_indices is not None: + self.rollout_inference_indices = self.rollout_inference_indices.pin_memory() return self diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index d06f16b51f..3008b615c0 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -801,7 +801,7 @@ def run_p2p_access_check(): if device_count < 2: return False - # # Check P2P access between all GPU pairs + # Check P2P access between all GPU pairs for i in range(device_count): for j in range(device_count): if i != j: diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py index 297672c737..17bafaa98a 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py @@ -32,7 +32,7 @@ def ray_init_fixture(): # needed for megatron tests env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - env_vars["NVTE_FUSED_ATTN"] = "0" + # env_vars["NVTE_FUSED_ATTN"] = "0" if SKYRL_PYTHONPATH_EXPORT: pythonpath = os.environ.get("PYTHONPATH") @@ -40,10 +40,6 @@ def ray_init_fixture(): raise RuntimeError("SKYRL_PYTHONPATH_EXPORT is set but PYTHONPATH is not defined in environment") env_vars["PYTHONPATH"] = pythonpath - # RAY_CGRAPH_get_timeout must be set in os.environ so that vLLM subprocesses - # (EngineCore) inherit it — runtime_env alone doesn't propagate to subprocesses. 
- os.environ["RAY_CGRAPH_get_timeout"] = env_vars.pop("RAY_CGRAPH_get_timeout") - logger.info(f"Initializing Ray with environment variables: {env_vars}") ray.init(runtime_env={"env_vars": env_vars}) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 1504d04f96..9581b8f990 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -22,7 +22,6 @@ from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend from skyrl.train.dataset.preprocess import convert_prompts_responses_to_batch_tensors from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch -from skyrl.backends.skyrl_train.utils.replay_utils import _split_replay_indices MOE_MODEL_NAME = "/home/ray/moonlight16b" # MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" @@ -85,14 +84,12 @@ def build_training_input_from_text_samples( rewards.append([0.0] * len(response_ids)) loss_masks.append([1] * len(response_ids)) - sequences, attention_mask, response_mask, rewards_t, loss_mask_t, _, _ = ( - convert_prompts_responses_to_batch_tensors( - tokenizer=tokenizer, - prompts=prompts, - responses=responses, - rewards=rewards, - loss_masks=loss_masks, - ) + sequences, attention_mask, response_mask, rewards_t, loss_mask_t, _, _ = convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=prompts, + responses=responses, + rewards=rewards, + loss_masks=loss_masks, ) num_actions = response_mask.shape[1] @@ -114,95 +111,6 @@ def build_training_input_from_text_samples( training_input.metadata = {"response_length": num_actions} return training_input -@pytest.mark.megatron -def test_moonlight_logprobs(ray_init_fixture): - """ - Check magnitude of moonlight-16b-a3b logprobs without router replay. - """ - actor_group = None - try: - cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) - cfg.trainer.strategy = "megatron" - - tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) - input_batch: GeneratorInput = get_test_generator_input( - model=MOE_MODEL_NAME, - num_prompts=NUM_PROMPTS, - n_samples_per_prompt=N_SAMPLES_PER_PROMPT, - max_prompt_length=512, - env_class="gsm8k", - ) - training_input = build_training_input_from_text_samples( - tokenizer=tokenizer, - prompt_response_pairs=[ - ( - tokenizer.apply_chat_template( - prompt - if any(message["role"] == "system" for message in prompt) - else [{"role": "system", "content": "You are a helpful assistant."}] + prompt, - add_generation_prompt=True, - tokenize=False, - ), - " " - + ( - " ".join( - next( - ( - message["content"] - for message in reversed(prompt) - if message["role"] == "user" - ), - "", - ).split()[:16] - ).strip() - or "I will solve this step by step." 
- ), - ) - for prompt in input_batch["prompts"] - ], - ) - - cfg.trainer.placement.policy_num_gpus_per_node = 8 - cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 - cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 - cfg.trainer.policy.megatron_config.context_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 - cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 - cfg.trainer.micro_forward_batch_size_per_gpu = 1 - cfg.trainer.micro_train_batch_size_per_gpu = 1 - cfg.trainer.policy.megatron_config.transformer_config_kwargs = { - **(cfg.trainer.policy.megatron_config.transformer_config_kwargs or {}), - "moe_enable_routing_replay": False, - } - - actor_group = init_worker_with_type( - "policy", - shared_pg=None, - colocate_all=False, - num_gpus_per_node=8, - cfg=cfg, - ) - - with Timer("moonlight_forward"): - refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) - results = ray.get(refs) - - action_log_probs = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, results)["output"] - valid_log_probs = action_log_probs[training_input["response_mask"].bool()] - avg_logprob = valid_log_probs.mean().item() - print( - f"Moonlight Megatron logprobs - mean: {avg_logprob:.6f}, std: {valid_log_probs.std().item():.6f}, " - f"num_tokens: {valid_log_probs.numel()}" - ) - - assert valid_log_probs.numel() > 0, "Expected at least one valid response token" - assert torch.isfinite(valid_log_probs).all().item(), "Expected all response logprobs to be finite" - assert -20.0 < avg_logprob < -0.01, f"Unexpected average logprob magnitude: {avg_logprob:.6f}" - finally: - if actor_group is not None: - for actor in actor_group._actor_handlers: - ray.kill(actor) - ray.shutdown() @pytest.mark.megatron def test_logprobs(ray_init_fixture): From 5ad9426b3ac62d65f053f1bfad78b4eda107f6c1 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 11 Mar 2026 16:54:03 +0000 Subject: [PATCH 15/31] rename var and clean up --- .../skyrl_train/inference_engines/base.py | 10 +- .../inference_engine_client.py | 30 +++-- .../inference_engines/vllm/vllm_engine.py | 10 +- skyrl/backends/skyrl_train/training_batch.py | 2 +- .../skyrl_train/utils/replay_utils.py | 34 +++--- .../megatron/megatron_model_wrapper.py | 14 +-- .../workers/megatron/megatron_worker.py | 26 ++--- .../skyrl_train/workers/worker_utils.py | 2 +- skyrl/train/dataset/preprocess.py | 16 +-- skyrl/train/dataset/replay_buffer.py | 10 +- skyrl/train/generators/base.py | 2 +- skyrl/train/generators/skyrl_gym_generator.py | 107 +++++++++--------- skyrl/train/trainer.py | 10 +- .../gpu/gpu_ci/test_router_replay.py | 94 +++++++++++++-- .../gpu/gpu_ci/test_skyrl_gym_generator.py | 40 ++++++- 15 files changed, 259 insertions(+), 148 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_engines/base.py b/skyrl/backends/skyrl_train/inference_engines/base.py index ddecc965fb..819a071603 100644 --- a/skyrl/backends/skyrl_train/inference_engines/base.py +++ b/skyrl/backends/skyrl_train/inference_engines/base.py @@ -29,7 +29,7 @@ class InferenceEngineOutput(TypedDict): response_ids: List[List[int]] stop_reasons: List[str] response_logprobs: Optional[List[List[float]]] - rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] + rollout_expert_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] class InferenceEngineInterface(ABC): @@ -64,7 +64,7 @@ async def sample( all_responses = [] all_stop_reasons = 
         all_response_logprobs = []
-        all_rollout_inference_indices = []
+        all_rollout_expert_indices = []

         for _ in range(num_samples):
             input_batch: InferenceEngineInput = {
@@ -81,15 +81,15 @@ async def sample(
             all_stop_reasons.append(output["stop_reasons"][0])
             if output.get("response_logprobs") is not None:
                 all_response_logprobs.append(output["response_logprobs"][0])
-            if output.get("rollout_inference_indices") is not None:
-                all_rollout_inference_indices.append(output["rollout_inference_indices"][0])
+            if output.get("rollout_expert_indices") is not None:
+                all_rollout_expert_indices.append(output["rollout_expert_indices"][0])

         return {
             "response_ids": all_response_ids,
             "responses": all_responses,
             "stop_reasons": all_stop_reasons,
             "response_logprobs": all_response_logprobs if all_response_logprobs else None,
-            "rollout_inference_indices": all_rollout_inference_indices if all_rollout_inference_indices else None,
+            "rollout_expert_indices": all_rollout_expert_indices if all_rollout_expert_indices else None,
         }

     @abstractmethod
diff --git a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
index 264e060cd1..d418089a72 100644
--- a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
+++ b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
@@ -153,10 +153,10 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu
         stop_reasons: list[str] = [""] * n
         response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)]
         response_ids: List[List[int]] = [[] for _ in range(n)]
-        rollout_inference_indices: List[Optional[List[List[List[int]]]]] = [None for _ in range(n)]
+        rollout_expert_indices: List[Optional[List[List[List[int]]]]] = [None for _ in range(n)]

         # a bit hacky for now
         add_resp_logprobs = False
-        add_rollout_inference_indices = False
+        add_rollout_expert_indices = False

         for indices, result in zip(indices_list, results):
             for local_idx, original_idx in enumerate(indices):
@@ -166,16 +166,16 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu
                 if result.get("response_logprobs", None):
                     add_resp_logprobs = True
                     response_logprobs[original_idx] = result["response_logprobs"][local_idx]
-                if result.get("rollout_inference_indices", None):
-                    add_rollout_inference_indices = True
-                    rollout_inference_indices[original_idx] = result["rollout_inference_indices"][local_idx]
+                if result.get("rollout_expert_indices", None):
+                    add_rollout_expert_indices = True
+                    rollout_expert_indices[original_idx] = result["rollout_expert_indices"][local_idx]

         return InferenceEngineOutput(
             responses=responses,
             stop_reasons=stop_reasons,
             response_ids=response_ids,
             response_logprobs=response_logprobs if add_resp_logprobs else None,
-            rollout_inference_indices=rollout_inference_indices if add_rollout_inference_indices else None,
+            rollout_expert_indices=rollout_expert_indices if add_rollout_expert_indices else None,
         )

     def _select_engine_idx(self, session_id: Optional[Union[str, int]] = None) -> int:
@@ -271,7 +271,7 @@ async def _generate_single_with_retry(
         # 2. Initialize fields we want to accumulate or update in each loop iteration
         accum_response_ids: List[int] = []
         accum_response_logprobs: List[float] = []
-        accum_rollout_inference_indices: List[List[List[int]]] = []
+        accum_rollout_expert_indices: List[List[List[int]]] = []
         stop_reason: str = "abort"

         # We only use it if generation is completed in one turn to maintain original behavior with no retry.
@@ -307,10 +307,10 @@ async def _generate_single_with_retry(
             new_response_logprobs_list: Optional[List[List[float]]] = partial_response.get("response_logprobs", None)
             if new_response_logprobs_list is not None and len(new_response_logprobs_list) > 0:
                 new_response_logprobs = new_response_logprobs_list[0]
-            new_rollout_inference_indices: Optional[List[List[List[int]]]] = None
-            new_rollout_inference_indices_list = partial_response.get("rollout_inference_indices", None)
-            if new_rollout_inference_indices_list is not None and len(new_rollout_inference_indices_list) > 0:
-                new_rollout_inference_indices = new_rollout_inference_indices_list[0]
+            new_rollout_expert_indices: Optional[List[List[List[int]]]] = None
+            new_rollout_expert_indices_list = partial_response.get("rollout_expert_indices", None)
+            if new_rollout_expert_indices_list is not None and len(new_rollout_expert_indices_list) > 0:
+                new_rollout_expert_indices = new_rollout_expert_indices_list[0]

             # 3.4 Aborted without generating tokens, so partial_response is useless.
             if stop_reason == "abort" and len(new_response_ids) == 0:
@@ -320,8 +320,8 @@ async def _generate_single_with_retry(
             accum_response_ids.extend(new_response_ids)
             if new_response_logprobs is not None:
                 accum_response_logprobs.extend(new_response_logprobs)
-            if new_rollout_inference_indices is not None:
-                accum_rollout_inference_indices.extend(new_rollout_inference_indices)
+            if new_rollout_expert_indices is not None:
+                accum_rollout_expert_indices.extend(new_rollout_expert_indices)
             num_turns += 1

         # 4. Build the final response and return.
@@ -334,9 +334,7 @@ async def _generate_single_with_retry(
             stop_reasons=[stop_reason],
             response_ids=[accum_response_ids],
             response_logprobs=[accum_response_logprobs] if len(accum_response_logprobs) > 0 else None,
-            rollout_inference_indices=(
-                [accum_rollout_inference_indices] if len(accum_rollout_inference_indices) > 0 else None
-            ),
+            rollout_expert_indices=([accum_rollout_expert_indices] if len(accum_rollout_expert_indices) > 0 else None),
         )

     async def _chat_completion_with_retry(
diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
index 11541786de..aa62a74f7c 100644
--- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
+++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -135,7 +135,7 @@ def _postprocess_outputs(self, outputs):
         stop_reasons: List[str] = []
         response_ids: List[List[int]] = []
         response_logprobs: Optional[List[List[float]]] = []
-        rollout_inference_indices: Optional[List[List[List[List[int]]]]] = []
+        rollout_expert_indices: Optional[List[List[List[List[int]]]]] = []

         for output in outputs:
             # TODO(tgriggs): Support n>1 sampling.
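[Note] The retry loop above relies on one invariant: the engine reports exactly one [layer_num][topk] routing entry per generated token, so per-retry chunks can be concatenated in token order, just like accum_response_ids. A minimal, self-contained sketch of that invariant (names here are illustrative only, not part of the engine API):

    # Python sketch: merging per-retry expert indices. Assumes each partial
    # response carries one [layer_num][topk] entry per generated token.
    from typing import List

    LayerTopK = List[List[int]]  # [layer_num][topk]

    def merge_turn_indices(turns: List[List[LayerTopK]]) -> List[LayerTopK]:
        accum: List[LayerTopK] = []
        for per_token_entries in turns:
            # mirrors accum_rollout_expert_indices.extend(new_rollout_expert_indices)
            accum.extend(per_token_entries)
        return accum

    turn_a = [[[0, 3], [1, 2]], [[4, 5], [6, 7]]]  # 2 tokens, layer_num=2, topk=2
    turn_b = [[[2, 2], [0, 1]]]                    # 1 token
    assert len(merge_turn_indices([turn_a, turn_b])) == 3  # one entry per token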
@@ -163,20 +163,20 @@ def _postprocess_outputs(self, outputs):
                     _routed_experts = resp.routed_experts.tolist()
                 else:
                     _routed_experts = resp.routed_experts
-                rollout_inference_indices.append(_routed_experts)
+                rollout_expert_indices.append(_routed_experts)

         if len(response_logprobs) and response_logprobs[0] is None:
             response_logprobs = None  # hack: assume uniform sampling params

-        if len(rollout_inference_indices) == 0 and rollout_inference_indices[0] is None:
-            rollout_inference_indices = None  # hack: assume uniform sampling params
+        if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None:
+            rollout_expert_indices = None  # hack: assume uniform sampling params

         return InferenceEngineOutput(
             responses=responses,
             stop_reasons=stop_reasons,
             response_ids=response_ids,
             response_logprobs=response_logprobs,
-            rollout_inference_indices=rollout_inference_indices,
+            rollout_expert_indices=rollout_expert_indices,
         )

     def _get_engine(self):
diff --git a/skyrl/backends/skyrl_train/training_batch.py b/skyrl/backends/skyrl_train/training_batch.py
index 8b00aeaa92..5295c6c4bc 100644
--- a/skyrl/backends/skyrl_train/training_batch.py
+++ b/skyrl/backends/skyrl_train/training_batch.py
@@ -369,7 +369,7 @@ class TrainingInput(TypedDict, total=False):
    kl: Float[torch.Tensor, "batch_size seq_len"]
    rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]]
    rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]]
-   rollout_inference_indices: Optional[Integer[torch.Tensor, "batch_size seq_len layer_num topk"]]
+   rollout_expert_indices: Optional[Integer[torch.Tensor, "batch_size seq_len layer_num topk"]]


 class TrainingInputBatch(TensorBatch[TrainingInput]):
diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py
index cf016f7e46..72d14e67db 100644
--- a/skyrl/backends/skyrl_train/utils/replay_utils.py
+++ b/skyrl/backends/skyrl_train/utils/replay_utils.py
@@ -73,24 +73,24 @@ def patched_preprocess(self, routing_map):
     MoEAlltoAllTokenDispatcher._preprocess_patched = True


-def _split_replay_indices(rollout_inference_indices: torch.Tensor) -> List[torch.Tensor]:
-    if rollout_inference_indices is None:
+def _split_replay_indices(rollout_expert_indices: torch.Tensor) -> List[torch.Tensor]:
+    if rollout_expert_indices is None:
         return None
-    if rollout_inference_indices.dim() != 4:
-        raise ValueError(f"Expected 4D replay indices, got shape {rollout_inference_indices.shape}")
-    per_layer = rollout_inference_indices.permute(2, 0, 1, 3).contiguous()
+    if rollout_expert_indices.dim() != 4:
+        raise ValueError(f"Expected 4D replay indices, got shape {rollout_expert_indices.shape}")
+    per_layer = rollout_expert_indices.permute(2, 0, 1, 3).contiguous()
     # flatten [batch, seq, topk] to [batch * seq, topk] for each layer
     return [per_layer[i].reshape(-1, per_layer.shape[-1]) for i in range(per_layer.shape[0])]


 def _remove_left_padding_from_indices(
-    rollout_inference_indices: torch.Tensor,
+    rollout_expert_indices: torch.Tensor,
     attention_mask: torch.Tensor,
 ) -> torch.Tensor:
     """Apply the same left-padding removal as remove_left_padding to routing indices.
     Args:
-        rollout_inference_indices: [batch, padded_seq_len, layers, topk]
+        rollout_expert_indices: [batch, padded_seq_len, layers, topk]
         attention_mask: [batch, padded_seq_len] (int or bool)

     Returns:
@@ -105,23 +105,23 @@ def _remove_left_padding_from_indices(
         pad_size = (sp_world_size - effective_seq_len % sp_world_size) % sp_world_size
         effective_seq_len += pad_size

-    batch_size = rollout_inference_indices.shape[0]
+    batch_size = rollout_expert_indices.shape[0]
     new_rii = torch.zeros(
         batch_size,
         effective_seq_len,
-        rollout_inference_indices.shape[2],
-        rollout_inference_indices.shape[3],
-        dtype=rollout_inference_indices.dtype,
-        device=rollout_inference_indices.device,
+        rollout_expert_indices.shape[2],
+        rollout_expert_indices.shape[3],
+        dtype=rollout_expert_indices.dtype,
+        device=rollout_expert_indices.device,
     )
     for i in range(batch_size):
         mask = attention_mask[i].bool()
-        new_rii[i, : seq_lens[i]] = rollout_inference_indices[i, mask]
+        new_rii[i, : seq_lens[i]] = rollout_expert_indices[i, mask]
     return new_rii


-def setup_per_microbatch_replay(
-    rollout_inference_indices: torch.Tensor,
+def setup_per_microbatch_replay_forward(
+    rollout_expert_indices: torch.Tensor,
     attention_mask: torch.Tensor,
 ) -> None:
     """Set up RouterReplay for a single micro-batch, aligning indices
@@ -141,8 +141,10 @@ def setup_per_microbatch_replay_forward(
     _patch_alltoall_dispatcher_for_replay()

-    aligned = _remove_left_padding_from_indices(rollout_inference_indices, attention_mask)
+    aligned = _remove_left_padding_from_indices(rollout_expert_indices, attention_mask)

+    # handles megatron sequence parallelism across the tensor model parallel region
+    # since we automatically enable sequence parallelism when TP > 1
     tp_size = mpu.get_tensor_model_parallel_world_size()
     if tp_size > 1:
         tp_rank = mpu.get_tensor_model_parallel_rank()
diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py
index a7f6ade581..c7d1c640b8 100644
--- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py
+++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py
@@ -24,7 +24,7 @@
     remove_left_padding,
     recover_left_padding,
 )
-from skyrl.backends.skyrl_train.utils.replay_utils import setup_per_microbatch_replay
+from skyrl.backends.skyrl_train.utils.replay_utils import setup_per_microbatch_replay_forward


 class MegatronModelWrapper:
@@ -105,9 +105,9 @@ def collection_func(logits, data):
         def forward_step(batch_iter, model):
             batch = next(batch_iter)
-            rollout_inference_indices = batch.pop("rollout_inference_indices", None)
-            if rollout_inference_indices is not None:
-                setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"])
+            rollout_expert_indices = batch.pop("rollout_expert_indices", None)
+            if rollout_expert_indices is not None:
+                setup_per_microbatch_replay_forward(rollout_expert_indices, batch["attention_mask"])

             sequences = batch["sequences"]
             attention_mask = batch["attention_mask"].to(bool)
@@ -361,9 +361,9 @@ def loss_func(logits, data):
         def forward_step(batch_iter, model):
             batch = next(batch_iter)
-            rollout_inference_indices = batch.pop("rollout_inference_indices", None)
-            if rollout_inference_indices is not None:
-                setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"])
+            rollout_expert_indices = batch.pop("rollout_expert_indices", None)
+            if rollout_expert_indices is not None:
+                setup_per_microbatch_replay_forward(rollout_expert_indices, batch["attention_mask"])
sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index cd5343c5bb..ec67b91897 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -33,7 +33,7 @@ from skyrl.train.utils.utils import update_model_config, str_to_torch_dtype from skyrl.backends.skyrl_train.env_vars import SKYRL_WORKER_NCCL_TIMEOUT_IN_S from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch -from skyrl.backends.skyrl_train.workers.worker_utils import reduce_metrics, all_reduce_metrics, BatchIterator +from skyrl.backends.skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics, all_reduce_metrics from skyrl.backends.skyrl_train.workers.worker import ( PolicyWorkerBase, RefWorkerBase, @@ -289,7 +289,6 @@ def init_configs( bridge = AutoBridge.from_hf_pretrained(model_path, trust_remote_code=True) provider = bridge.to_megatron_provider() - provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size provider.pipeline_dtype = torch.bfloat16 if bf16 else torch.float32 @@ -417,16 +416,17 @@ def forward(self, data: TrainingInputBatch): num_actions = micro.metadata["response_length"] position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 0) - micro_dict = { - "sequences": sequences, - "attention_mask": attention_mask, - "position_ids": position_ids, - "num_actions": num_actions, - } - rii = micro.get("rollout_inference_indices") - if rii is not None and self.enable_router_replay: - micro_dict["rollout_inference_indices"] = rii - micro_dicts.append(micro_dict) + micro_dicts.append( + { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": num_actions, + "rollout_expert_indices": ( + micro.get("rollout_expert_indices") if self.enable_router_replay else None + ), + } + ) self.model.eval() seq_len = micro_dicts[0]["sequences"].shape[1] @@ -636,7 +636,7 @@ def forward_backward( "loss_mask": experience.loss_mask, "rollout_action_logprobs": experience.rollout_logprobs, "action_mask": experience.action_mask, - "rollout_inference_indices": experience.rollout_inference_indices, + "rollout_expert_indices": experience.rollout_expert_indices if self.enable_router_replay else None, } ) diff --git a/skyrl/backends/skyrl_train/workers/worker_utils.py b/skyrl/backends/skyrl_train/workers/worker_utils.py index 312e3ee52c..b8ad1b09fb 100644 --- a/skyrl/backends/skyrl_train/workers/worker_utils.py +++ b/skyrl/backends/skyrl_train/workers/worker_utils.py @@ -83,7 +83,7 @@ def batch_to_experience(batch: TrainingInputBatch): action_mask=batch.get("response_mask"), num_actions=batch.metadata["response_length"], # int rollout_logprobs=batch.get("rollout_logprobs"), - rollout_inference_indices=batch.get("rollout_inference_indices"), + rollout_expert_indices=batch.get("rollout_expert_indices"), # additional info # can be used to log metrics etc for micro-batches in the worker info={}, diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index be9c8fbcce..67907e977f 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -32,7 +32,7 @@ def convert_prompts_responses_to_batch_tensors( 
     rewards: List[List[float]],
     loss_masks: List[List[int]],
     logprobs: Optional[List[List[float]]] = None,
-    rollout_inference_indices: Optional[List[List[List[List[int]]]]] = None,
+    rollout_expert_indices: Optional[List[List[List[List[int]]]]] = None,
 ) -> Tuple[
     Float[torch.Tensor, "batch seq_len"],
     Float[torch.Tensor, "batch seq_len"],
@@ -131,20 +131,20 @@ def convert_prompts_responses_to_batch_tensors(
         ]
         logprobs_tensor = torch.tensor(padded_logprobs, dtype=torch.float)

-    rollout_inference_indices_tensor = None
-    if rollout_inference_indices:
-        first_non_empty = next((x for x in rollout_inference_indices if x), None)
+    rollout_expert_indices_tensor = None
+    if rollout_expert_indices:
+        first_non_empty = next((x for x in rollout_expert_indices if x), None)
         if first_non_empty:
             total_seq_len = max_input_len + max_output_len
             num_layers = len(first_non_empty[0])
             topk = len(first_non_empty[0][0]) if num_layers > 0 else 0
-            padded = torch.zeros(len(rollout_inference_indices), total_seq_len, num_layers, topk, dtype=torch.int32)
-            for i, sample_indices in enumerate(rollout_inference_indices):
+            padded = torch.zeros(len(rollout_expert_indices), total_seq_len, num_layers, topk, dtype=torch.int32)
+            for i, sample_indices in enumerate(rollout_expert_indices):
                 if sample_indices:
                     left_pad = max_input_len - prompt_token_lens[i]
                     n = min(len(sample_indices), total_seq_len - left_pad)
                     padded[i, left_pad : left_pad + n] = torch.tensor(sample_indices[:n], dtype=torch.int32)
-            rollout_inference_indices_tensor = padded
+            rollout_expert_indices_tensor = padded

     return (
         sequences,
@@ -153,5 +153,5 @@ def convert_prompts_responses_to_batch_tensors(
         ret_rewards,
         ret_loss_masks,
         logprobs_tensor,
-        rollout_inference_indices_tensor,
+        rollout_expert_indices_tensor,
     )
diff --git a/skyrl/train/dataset/replay_buffer.py b/skyrl/train/dataset/replay_buffer.py
index 109f1957ec..072c65fdf7 100644
--- a/skyrl/train/dataset/replay_buffer.py
+++ b/skyrl/train/dataset/replay_buffer.py
@@ -66,7 +66,7 @@ class Experience:
     loss_mask: Optional[Integer[torch.LongTensor, "batch response_len"]]
     action_mask: Optional[Integer[torch.Tensor, "batch response_len"]]
     rollout_logprobs: Optional[Float[torch.Tensor, "batch response_len"]]
-    rollout_inference_indices: Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]]
+    rollout_expert_indices: Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]]
     num_actions: int
     info: Optional[dict]
     kl: Optional[Float[torch.Tensor, "batch response_len"]] = None
@@ -93,8 +93,8 @@ def to_device(self, device: torch.device) -> None:
         self.action_mask = to(self.action_mask, device)
         if self.rollout_logprobs is not None:
             self.rollout_logprobs = to(self.rollout_logprobs, device)
-        if self.rollout_inference_indices is not None:
-            self.rollout_inference_indices = to(self.rollout_inference_indices, device)
+        if self.rollout_expert_indices is not None:
+            self.rollout_expert_indices = to(self.rollout_expert_indices, device)

     def pin_memory(self):
         self.sequences = pin_memory(self.sequences)
@@ -116,8 +116,8 @@ def pin_memory(self):
             self.action_mask = self.action_mask.pin_memory()
         if self.rollout_logprobs is not None:
             self.rollout_logprobs = self.rollout_logprobs.pin_memory()
-        if self.rollout_inference_indices is not None:
-            self.rollout_inference_indices = self.rollout_inference_indices.pin_memory()
+        if self.rollout_expert_indices is not None:
+            self.rollout_expert_indices = self.rollout_expert_indices.pin_memory()

         return self
diff --git a/skyrl/train/generators/base.py b/skyrl/train/generators/base.py
index 530fbbbe84..49b3ecc1ac 100644
--- a/skyrl/train/generators/base.py
+++ b/skyrl/train/generators/base.py
@@ -39,7 +39,7 @@ class GeneratorOutput(TypedDict):
     rollout_metrics: Optional[Dict[str, Any]]
     rollout_logprobs: Optional[List[List[float]]]
     trajectory_ids: Optional[List[TrajectoryID]]
-    rollout_inference_indices: Optional[List[List[List[List[int]]]]]  # [batch_size, seq_len, layer_num, topk]
+    rollout_expert_indices: Optional[List[List[List[List[int]]]]]  # [batch_size, seq_len, layer_num, topk]

     # Applicable only for step-wise training
     is_last_step: Optional[List[bool]]
diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py
index b48249130d..cbad99eb31 100644
--- a/skyrl/train/generators/skyrl_gym_generator.py
+++ b/skyrl/train/generators/skyrl_gym_generator.py
@@ -40,7 +40,7 @@ class TrajectoryOutput:
     prompt_ids: List[int]
     rollout_logprobs: Optional[List[float]]
     env_metrics: Dict[str, Any]
-    rollout_inference_indices: Optional[List[List[List[int]]]] = None
+    rollout_expert_indices: Optional[List[List[List[int]]]] = None


 @dataclass
@@ -58,7 +58,7 @@ class AgentLoopState:
     rollout_logprobs: Optional[List[float]]
     response_end_idx: Optional[int]
     done: bool
-    rollout_inference_indices: Optional[List[List[List[int]]]] = None
+    rollout_expert_indices: Optional[List[List[List[int]]]] = None


 @dataclass
@@ -68,26 +68,26 @@ class TurnOutput:
     output_logprobs: Optional[List[float]]
     new_obs: ConversationType
     obs_ids: List[int]
-    rollout_inference_indices: Optional[List[List[List[int]]]]  # [seq_len, layer_num, topk]
+    rollout_expert_indices: Optional[List[List[List[int]]]]  # [seq_len, layer_num, topk]
     reward: Optional[float]
     added_eos: bool = False

-    def get_turn_rollout_inference_indices(self) -> Optional[List[List[List[int]]]]:
+    def get_turn_rollout_expert_indices(self) -> Optional[List[List[List[int]]]]:
         """
         Get rollout inference indices for this turn's tokens (output + observation).

         Returns indices for generated output tokens, with padding entries (all -1)
         for any manually-added EOS token and observation tokens.

-        Returns None if rollout_inference_indices is None.
+        Returns None if rollout_expert_indices is None.
""" - if self.rollout_inference_indices is None: + if self.rollout_expert_indices is None: return None - if not self.rollout_inference_indices: - return self.rollout_inference_indices - layer_num = len(self.rollout_inference_indices[0]) - topk = len(self.rollout_inference_indices[0][0]) if layer_num > 0 else 0 + if not self.rollout_expert_indices: + return self.rollout_expert_indices + layer_num = len(self.rollout_expert_indices[0]) + topk = len(self.rollout_expert_indices[0][0]) if layer_num > 0 else 0 pad_entry = [[-1] * topk for _ in range(layer_num)] - indices = list(self.rollout_inference_indices) + indices = list(self.rollout_expert_indices) if self.added_eos: indices.append(pad_entry) indices.extend(pad_entry for _ in range(len(self.obs_ids))) @@ -324,14 +324,14 @@ async def agent_loop( output_ids = engine_output["response_ids"][0] stop_reason = engine_output["stop_reasons"][0] response_logprobs = engine_output.get("response_logprobs", None) - rollout_inference_indices = engine_output.get("rollout_inference_indices", None) + rollout_expert_indices = engine_output.get("rollout_expert_indices", None) if response_logprobs is not None: response_logprobs = response_logprobs[0] if self.custom_chat_template is not None: raise ValueError("Response Logprobs bookkeeping is not supported with custom chat template") - if rollout_inference_indices is not None: - rollout_inference_indices = rollout_inference_indices[0] + if rollout_expert_indices is not None: + rollout_expert_indices = rollout_expert_indices[0] # Append eos when sampling_params.stop is not None. Does not affect 3.a as chat templates add eos_token. # sampling_params is not None for eval, but None for training (which uses engine.sampling_params which are from cfg) stop_strs = current_sampling_params.get("stop", None) @@ -375,11 +375,11 @@ async def agent_loop( reward=step_reward, obs_ids=obs_ids, added_eos=added_eos, - rollout_inference_indices=rollout_inference_indices, + rollout_expert_indices=rollout_expert_indices, ) - if turn_output.rollout_inference_indices is not None and agent_loop_state.rollout_inference_indices is None: - agent_loop_state.rollout_inference_indices = [] + if turn_output.rollout_expert_indices is not None and agent_loop_state.rollout_expert_indices is None: + agent_loop_state.rollout_expert_indices = [] if is_step_wise: # current response + observation ids @@ -398,7 +398,7 @@ async def agent_loop( rollout_logprobs=turn_response_logprobs, stop_reason=stop_reason, env_metrics=env.get_metrics() if agent_loop_state.done else {}, - rollout_inference_indices=turn_output.get_turn_rollout_inference_indices(), + rollout_expert_indices=turn_output.get_turn_rollout_expert_indices(), ) agent_loop_output.step_outputs.append(per_step_output) @@ -427,7 +427,7 @@ async def agent_loop( prompt_ids = agent_loop_state.input_ids[:initial_prompt_length] rollout_logprobs = None - rollout_inference_indices_out = None + rollout_expert_indices_out = None response_ids = None # Prepare the final loss_mask, response_ids and rollout_logprobs . 
@@ -458,8 +458,8 @@ async def agent_loop(
                 rollout_logprobs = agent_loop_state.rollout_logprobs[
                     : agent_loop_state.response_end_idx - initial_prompt_length + 1
                 ]
-            if agent_loop_state.rollout_inference_indices is not None:
-                rollout_inference_indices_out = agent_loop_state.rollout_inference_indices[
+            if agent_loop_state.rollout_expert_indices is not None:
+                rollout_expert_indices_out = agent_loop_state.rollout_expert_indices[
                     : agent_loop_state.response_end_idx + 1
                 ]
             # fix index for per_step_rewards
@@ -478,10 +478,10 @@ async def agent_loop(
             loss_mask.append(1)
             if rollout_logprobs is not None:
                 rollout_logprobs.append(0.0)
-            if rollout_inference_indices_out is not None and rollout_inference_indices_out:
-                layer_num = len(rollout_inference_indices_out[0])
-                topk = len(rollout_inference_indices_out[0][0]) if layer_num > 0 else 0
-                rollout_inference_indices_out.append([[-1] * topk for _ in range(layer_num)])
+            if rollout_expert_indices_out is not None and rollout_expert_indices_out:
+                layer_num = len(rollout_expert_indices_out[0])
+                topk = len(rollout_expert_indices_out[0][0]) if layer_num > 0 else 0
+                rollout_expert_indices_out.append([[-1] * topk for _ in range(layer_num)])
             appended_eos_token = True

         if self.generator_cfg.step_wise_trajectories:
@@ -501,7 +501,7 @@ async def agent_loop(
             prompt_ids=prompt_ids,
             rollout_logprobs=rollout_logprobs,
             env_metrics=env_metrics,
-            rollout_inference_indices=rollout_inference_indices_out,
+            rollout_expert_indices=rollout_expert_indices_out,
         )
         return agent_loop_output

@@ -653,14 +653,14 @@ async def generate_batched(
         responses = engine_output["response_ids"]
         stop_reasons = engine_output["stop_reasons"]
         logprobs = engine_output.get("response_logprobs", None)
-        raw_rollout_inference_indices = engine_output.get("rollout_inference_indices", None)
+        raw_rollout_expert_indices = engine_output.get("rollout_expert_indices", None)

         truncated_responses = []
         rewards = []
         loss_masks = []
         env_metrics = []
         truncated_logprobs: Optional[List[List[float]]] = [] if logprobs is not None else None
-        truncated_indices: Optional[List] = [] if raw_rollout_inference_indices is not None else None
+        truncated_indices: Optional[List] = [] if raw_rollout_expert_indices is not None else None

         for i, (output, response, env, env_class) in enumerate(zip(outputs, responses, envs, env_classes)):
             # step on environment and compute reward
@@ -675,8 +675,8 @@ async def generate_batched(
             if logprobs is not None:
                 sample_logprobs = logprobs[i][: len(response)]
                 truncated_logprobs.append(sample_logprobs)
-            if raw_rollout_inference_indices is not None:
-                truncated_indices.append(raw_rollout_inference_indices[i])
+            if raw_rollout_expert_indices is not None:
+                truncated_indices.append(raw_rollout_expert_indices[i])

             # Get environment-specific metrics
             env_metrics.append(env.get_metrics())
@@ -696,7 +696,7 @@ async def generate_batched(
             "stop_reasons": stop_reasons,
             "rollout_metrics": rollout_metrics,
             "rollout_logprobs": truncated_logprobs,
-            "rollout_inference_indices": truncated_indices,
+            "rollout_expert_indices": truncated_indices,
         }

         return generator_output
@@ -799,15 +799,12 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False

         if self.generator_cfg.step_wise_trajectories:
             all_indices = sum(
-                [
-                    [step_output.rollout_inference_indices for step_output in output.step_outputs]
-                    for output in all_outputs
-                ],
+                [[step_output.rollout_expert_indices for step_output in output.step_outputs] for output in all_outputs],
                 [],
             )
         else:
-            all_indices = [output.rollout_inference_indices for output in all_outputs]
-            rollout_inference_indices = all_indices if any(idx is not None for idx in all_indices) else None
+            all_indices = [output.rollout_expert_indices for output in all_outputs]
+            rollout_expert_indices = all_indices if any(idx is not None for idx in all_indices) else None

         rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes)

@@ -827,7 +824,7 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False
             "rollout_metrics": rollout_metrics,
             "rollout_logprobs": rollout_logprobs,
             "trajectory_ids": out_trajectory_ids,
-            "rollout_inference_indices": rollout_inference_indices,
+            "rollout_expert_indices": rollout_expert_indices,
             "is_last_step": is_last_step,
         }

@@ -896,7 +893,7 @@ def _update_agent_state_by_retokenizing_chat_history(
         # `logprobs` are not computed because retokenizing breaks token-in-token-out
         agent_loop_state.rollout_logprobs = None
         # indices are not meaningful when retokenizing
-        agent_loop_state.rollout_inference_indices = None
+        agent_loop_state.rollout_expert_indices = None
         return agent_loop_state

     def _update_agent_loop_state_with_multiturn_chat_template(
@@ -948,14 +945,15 @@ def _update_agent_loop_state_with_multiturn_chat_template(

         loss_mask_for_turn = turn_output.get_turn_loss_mask()
         rollout_logprobs_for_turn = turn_output.get_turn_rollout_logprobs()
-        rollout_inference_indices_for_turn = turn_output.get_turn_rollout_inference_indices()
+        rollout_expert_indices_for_turn = turn_output.get_turn_rollout_expert_indices()

         if self.generator_cfg.step_wise_trajectories:
             # cumulative input_ids is not tracked for step wise training
             agent_loop_state.response_end_idx = len(turn_output.output_ids) - 1
-            # no running loss_mask or `rollout_logprobs` are tracked for step-wise training
+            # no running loss_mask, `rollout_logprobs`, or `rollout_expert_indices` are tracked for step-wise training
             agent_loop_state.loss_mask = None
             agent_loop_state.rollout_logprobs = None
+            agent_loop_state.rollout_expert_indices = None
         else:
             # Directly append turn output
             turn_ids = turn_output.output_ids + turn_output.obs_ids
@@ -964,11 +962,10 @@ def _update_agent_loop_state_with_multiturn_chat_template(
                 agent_loop_state.loss_mask += loss_mask_for_turn
             if agent_loop_state.rollout_logprobs is not None and rollout_logprobs_for_turn is not None:
                 agent_loop_state.rollout_logprobs += rollout_logprobs_for_turn
-            if (
-                agent_loop_state.rollout_inference_indices is not None
-                and rollout_inference_indices_for_turn is not None
-            ):
-                agent_loop_state.rollout_inference_indices += rollout_inference_indices_for_turn
+            if agent_loop_state.rollout_expert_indices is not None and rollout_expert_indices_for_turn is not None:
+                # overwrite the existing rollout expert indices, since the inference engine should
+                # return the expert indices for the entire sequence including each turn's input
+                agent_loop_state.rollout_expert_indices = rollout_expert_indices_for_turn

         return agent_loop_state

@@ -1033,13 +1030,13 @@ def _update_agent_loop_state_with_singleturn_chat_template(
             obs_ids_to_add
         )

-        rollout_inference_indices_for_turn = None
-        if turn_output.rollout_inference_indices is not None and turn_output.rollout_inference_indices:
-            layer_num = len(turn_output.rollout_inference_indices[0])
-            topk = len(turn_output.rollout_inference_indices[0][0]) if layer_num > 0 else 0
-            pad_entry = [[-1] * topk for _ in range(layer_num)]
-            rollout_inference_indices_for_turn = list(turn_output.rollout_inference_indices[: len(new_resp_tokens)])
-            rollout_inference_indices_for_turn.extend(pad_entry for _ in range(len(obs_ids_to_add)))
+        # rollout_expert_indices_for_turn = None
+        # if turn_output.rollout_expert_indices is not None and turn_output.rollout_expert_indices:
+        #     layer_num = len(turn_output.rollout_expert_indices[0])
+        #     topk = len(turn_output.rollout_expert_indices[0][0]) if layer_num > 0 else 0
+        #     pad_entry = [[-1] * topk for _ in range(layer_num)]
+        #     rollout_expert_indices_for_turn = list(turn_output.rollout_expert_indices[: len(new_resp_tokens)])
+        #     rollout_expert_indices_for_turn.extend(pad_entry for _ in range(len(obs_ids_to_add)))

         # Directly append turn output
         agent_loop_state.response_end_idx = len(agent_loop_state.input_ids) + len(new_resp_tokens) - 1
@@ -1047,7 +1044,7 @@ def _update_agent_loop_state_with_singleturn_chat_template(
             agent_loop_state.loss_mask += loss_mask_for_turn
         if agent_loop_state.rollout_logprobs is not None and rollout_logprobs_for_turn is not None:
             agent_loop_state.rollout_logprobs += rollout_logprobs_for_turn
-        if agent_loop_state.rollout_inference_indices is not None and rollout_inference_indices_for_turn is not None:
-            agent_loop_state.rollout_inference_indices += rollout_inference_indices_for_turn
+        if self.generator_cfg.enable_return_routed_experts and turn_output.rollout_expert_indices is not None:
+            agent_loop_state.rollout_expert_indices = turn_output.rollout_expert_indices

         return agent_loop_state
diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py
index f842868816..ef73ab11b5 100644
--- a/skyrl/train/trainer.py
+++ b/skyrl/train/trainer.py
@@ -604,8 +604,8 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
         loss_masks: List[List[int]] = generator_output["loss_masks"]
         logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None)

-        rollout_inference_indices: Optional[List[List[List[List[int]]]]] = generator_output.get(
-            "rollout_inference_indices", None
+        rollout_expert_indices: Optional[List[List[List[List[int]]]]] = generator_output.get(
+            "rollout_expert_indices", None
         )

         (
@@ -615,7 +615,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
             rewards_tensor,
             loss_masks_tensor,
             rollout_logprobs_tensor,
-            rollout_inference_indices_tensor,
+            rollout_expert_indices_tensor,
         ) = convert_prompts_responses_to_batch_tensors(
             self.tokenizer,
             prompt_ids,
@@ -623,7 +623,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
             rewards,
             loss_masks,
             logprobs,
-            rollout_inference_indices,
+            rollout_expert_indices,
         )

         # sanity check for off_policy_correction
@@ -644,7 +644,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
             "rewards": rewards_tensor,
             "loss_mask": loss_masks_tensor,
             "rollout_logprobs": rollout_logprobs_tensor,
-            "rollout_inference_indices": rollout_inference_indices_tensor,
+            "rollout_expert_indices": rollout_expert_indices_tensor,
             "is_last_step": (
                 torch.tensor(generator_output["is_last_step"], dtype=torch.bool)
                 if generator_output.get("is_last_step", None) is not None
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
index 9581b8f990..8e84ee61f2 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
@@ -23,7 +23,8 @@
 from skyrl.train.dataset.preprocess import convert_prompts_responses_to_batch_tensors
 from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
-MOE_MODEL_NAME = "/home/ray/moonlight16b"
+MOE_MODEL_NAME = "arcee-ai/Trinity-Nano-Preview"
+# MOE_MODEL_NAME = "/home/ray/moonlight16b"
 # MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B"
 REPLAY_NUM_LAYERS = 2
 NUM_PROMPTS = 10
@@ -112,6 +113,81 @@ def build_training_input_from_text_samples(
     return training_input


+def test_generate_with_router_replay(ray_init_fixture):
+    """
+    Check that generate with router replay produces the correct rollout inference indices.
+    """
+    try:
+        cfg = get_test_actor_config(model_name=MOE_MODEL_NAME)
+        cfg.trainer.strategy = "megatron"
+        cfg.generator.inference_engine.enable_return_routed_experts = True
+        cfg.generator.inference_engine.tensor_parallel_size = 8
+        cfg.generator.sampling_params = SamplingParams(
+            max_generate_length=MAX_GENERATE_LENGTH,
+            logprobs=1,
+            temperature=1.0,
+        )
+        cfg.generator.batched = False
+        cfg.generator.max_turns = 2
+        cfg.generator.use_conversation_multi_turn = True
+        env_class = "gsm8k_multi_turn"
+
+        tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True)
+
+        with InferenceEngineState.create(
+            cfg=cfg,
+            model=MOE_MODEL_NAME,
+            use_local=True,
+            colocate_all=True,
+            backend="vllm",
+            sleep_level=1,
+            gpu_memory_utilization=0.9,
+        ) as engines:
+            client = engines.client
+            asyncio.run(client.wake_up())
+
+            generator = SkyRLGymGenerator(
+                generator_cfg=cfg.generator,
+                skyrl_gym_cfg=cfg.environment.skyrl_gym,
+                inference_engine_client=client,
+                tokenizer=tokenizer,
+            )
+
+            input_batch: GeneratorInput = get_test_generator_input(
+                model=MOE_MODEL_NAME,
+                num_prompts=NUM_PROMPTS,
+                n_samples_per_prompt=N_SAMPLES_PER_PROMPT,
+                max_prompt_length=512,
+                env_class=env_class,
+            )
+            input_batch["sampling_params"] = get_sampling_params_for_backend(
+                "vllm",
+                SamplingParams(
+                    temperature=1.0,
+                    top_p=1.0,
+                    top_k=-1,
+                    max_generate_length=MAX_GENERATE_LENGTH,
+                    min_p=0.0,
+                    logprobs=1,
+                ),
+            )
+
+            with Timer("generate_with_router_replay"):
+                generator_output = asyncio.run(generator.generate(input_batch))
+
+            indices = generator_output["rollout_expert_indices"]
+            responses = generator_output["response_ids"]
+            assert (
+                indices is not None
+            ), "rollout_expert_indices should not be None when enable_return_routed_experts=True"
+            assert len(indices) == len(
+                responses
+            ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses"
+            asyncio.run(client.sleep())
+    finally:
+        ray.shutdown()
+
+
 @pytest.mark.megatron
 def test_logprobs(ray_init_fixture):
     """
@@ -173,11 +249,11 @@ def test_logprobs(ray_init_fixture):
         with Timer("generate_with_router_replay"):
             generator_output = asyncio.run(generator.generate(input_batch))

-        indices = generator_output["rollout_inference_indices"]
+        indices = generator_output["rollout_expert_indices"]
         responses = generator_output["response_ids"]
         assert (
             indices is not None
-        ), "rollout_inference_indices should not be None when enable_return_routed_experts=True"
+        ), "rollout_expert_indices should not be None when enable_return_routed_experts=True"
         assert len(indices) == len(
             responses
         ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses"
@@ -194,7 +270,7 @@ def test_logprobs(ray_init_fixture):
                 rewards=rewards,
                 loss_masks=generator_output["loss_masks"],
                 logprobs=generator_output.get("rollout_logprobs"),
-                rollout_inference_indices=indices,
+                rollout_expert_indices=indices,
             )
         )

@@ -213,7 +289,7 @@ def test_logprobs(ray_init_fixture):
                 if logprobs_t is not None
                 else torch.zeros((batch_size, num_actions), dtype=torch.float32)
             ),
-            "rollout_inference_indices": rii_tensor,
+            "rollout_expert_indices": rii_tensor,
             "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32),
             "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32),
             "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32),
@@ -333,11 +409,11 @@ def test_forward_backward(ray_init_fixture):
         with Timer("generate_with_router_replay"):
             generator_output = asyncio.run(generator.generate(input_batch))

-        indices = generator_output["rollout_inference_indices"]
+        indices = generator_output["rollout_expert_indices"]
         responses = generator_output["response_ids"]
         assert (
             indices is not None
-        ), "rollout_inference_indices should not be None when enable_return_routed_experts=True"
+        ), "rollout_expert_indices should not be None when enable_return_routed_experts=True"
         assert len(indices) == len(
             responses
         ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses"
@@ -354,7 +430,7 @@ def test_forward_backward(ray_init_fixture):
                 rewards=rewards,
                 loss_masks=generator_output["loss_masks"],
                 logprobs=generator_output.get("rollout_logprobs"),
-                rollout_inference_indices=indices,
+                rollout_expert_indices=indices,
             )
         )

@@ -373,7 +449,7 @@ def test_forward_backward(ray_init_fixture):
                 if logprobs_t is not None
                 else torch.zeros((batch_size, num_actions), dtype=torch.float32)
             ),
-            "rollout_inference_indices": rii_tensor,
+            "rollout_expert_indices": rii_tensor,
             "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32),
             "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32),
             "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32),
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py
index 788bc8ef42..7a0b77bbfa 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py
@@ -29,6 +29,7 @@ def get_test_config(
     is_step_wise,
     temperature,
     get_logprobs,
+    enable_return_routed_experts,
 ):
     cfg = SkyRLTrainConfig()
     cfg.trainer.policy.model.path = model
@@ -49,7 +50,7 @@ def get_test_config(
     cfg.generator.inference_engine.http_endpoint_host = "127.0.0.1"
     cfg.generator.inference_engine.http_endpoint_port = 8000
     cfg.generator.step_wise_trajectories = is_step_wise
-
+    cfg.generator.inference_engine.enable_return_routed_experts = enable_return_routed_experts
     cfg.environment.skyrl_gym.search.log_requests = True
     cfg.environment.skyrl_gym.search.search_url = "http://127.0.0.1:8000/retrieve"
     cfg.environment.skyrl_gym.max_env_workers = max_env_workers
@@ -108,6 +109,7 @@ async def run_generator_end_to_end(
     is_step_wise: bool = False,
     temperature=1.0,
     get_logprobs: bool = False,
+    enable_return_routed_experts: bool = False,
 ):
     """
     End to end generator test - requires minimum 2 GPUs
@@ -125,6 +127,7 @@ async def run_generator_end_to_end(
         is_step_wise,
         temperature,
         get_logprobs,
+        enable_return_routed_experts,
     )

     # Use InferenceEngineState to support both legacy and new inference backends
@@ -186,6 +189,11 @@ async def run_generator_end_to_end(
             "rollout_logprobs": (
                 generator_output["rollout_logprobs"][i] if generator_output["rollout_logprobs"] else None
             ),
+            "rollout_expert_indices": (
+                generator_output["rollout_expert_indices"][i]
+                if generator_output["rollout_expert_indices"]
+                else None
+            ),
         }
         for i in range(len(generator_output["response_ids"]))
range(len(generator_output["response_ids"])) ] @@ -450,3 +458,33 @@ async def test_generator_multi_turn_gsm8k_step_wise(ray_init_fixture): assert isinstance(generator_output["is_last_step"], list) and isinstance(generator_output["is_last_step"][0], bool) # Expect atleast one response with more than one turn assert sum(generator_output["is_last_step"]) != len(generator_output["is_last_step"]) + + +async def test_generator_multi_turn_gsm8k_router_replay(ray_init_fixture): + """ + Test the generator with the multi-turn GSM8K environment for router replay + """ + generator_output: GeneratorOutput = await run_generator_end_to_end( + use_async_engine=True, + batched=False, + n_samples_per_prompt=5, + num_inference_engines=2, + tensor_parallel_size=2, + model="arcee-ai/Trinity-Nano-Preview", + max_prompt_length=2048, + max_input_length=4096, + max_generate_length=1000, + data_path=os.path.expanduser("~/data/gsm8k/validation.parquet"), + env_class="gsm8k_multi_turn", + num_prompts=2, + max_turns=2, + use_conversation_multi_turn=True, + max_env_workers=0, + is_step_wise=True, + temperature=0, + enable_return_routed_experts=True, + ) + + assert isinstance(generator_output["is_last_step"], list) and isinstance(generator_output["is_last_step"][0], bool) + # Expect atleast one response with more than one turn + assert sum(generator_output["is_last_step"]) != len(generator_output["is_last_step"]) From 4a60d4c7765624689e6dcb4b85bbd8cd40d66da2 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 12 Mar 2026 20:32:51 +0000 Subject: [PATCH 16/31] cleaning up --- .../inference_engines/vllm/vllm_engine.py | 2 +- .../skyrl_train/utils/replay_utils.py | 25 ++- .../megatron/megatron_model_wrapper.py | 8 +- .../workers/megatron/megatron_worker.py | 7 +- skyrl/train/config/config.py | 1 + .../train/config/megatron_config/policy.yaml | 4 +- skyrl/train/generators/skyrl_gym_generator.py | 9 +- skyrl/train/utils/utils.py | 5 + .../skyrl_train/gpu/gpu_ci/conftest.py | 2 +- .../gpu/gpu_ci/test_router_replay.py | 155 +++++------------- .../gpu/gpu_ci/test_skyrl_gym_generator.py | 20 ++- 11 files changed, 108 insertions(+), 130 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py index 565e4bfffd..cad5bce14b 100644 --- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -182,7 +182,7 @@ def _postprocess_outputs(self, outputs): if len(response_logprobs) and response_logprobs[0] is None: response_logprobs = None # hack: assume uniform sampling params - if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None: + if len(rollout_expert_indices) > 0 and rollout_expert_indices[0] is None: rollout_expert_indices = None # hack: assume uniform sampling params return InferenceEngineOutput( diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index 72d14e67db..87ec12dedb 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -2,9 +2,10 @@ Utility functions for MoE Router Replay. 
""" -import torch from typing import List +import torch + def _patch_topk_router_layer_number(): """Monkey-patch TopKRouter.set_layer_number to propagate the global layer @@ -49,7 +50,9 @@ def _patch_alltoall_dispatcher_for_replay(): Reference: https://github.com/verl-project/verl/pull/4986 """ try: - from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher + from megatron.core.transformer.moe.token_dispatcher import ( + MoEAlltoAllTokenDispatcher, + ) except ImportError: return @@ -137,7 +140,10 @@ def setup_per_microbatch_replay_forward( TopKRouter.set_layer_number) to index into the correct slice of the data. """ import megatron.core.parallel_state as mpu - from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction + from megatron.core.transformer.moe.router_replay import ( + RouterReplay, + RouterReplayAction, + ) _patch_alltoall_dispatcher_for_replay() @@ -179,6 +185,19 @@ def setup_per_microbatch_replay_forward( RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) +def setup_per_microbatch_replay_backward() -> None: + """Switch RouterReplay to backward mode so that activation-checkpoint + recomputation during the backward pass consumes indices from + ``replay_backward_list`` in FIFO order (populated during the forward pass). + """ + from megatron.core.transformer.moe.router_replay import ( + RouterReplay, + RouterReplayAction, + ) + + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_BACKWARD) + + def clear_router_replay(): """Clear all router replay state.""" from megatron.core.transformer.moe.router_replay import RouterReplay diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 7f20ff3d14..17c6b5db19 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -25,7 +25,10 @@ PolicyLossRegistry, compute_approx_kl, ) -from skyrl.backends.skyrl_train.utils.replay_utils import setup_per_microbatch_replay_forward +from skyrl.backends.skyrl_train.utils.replay_utils import ( + setup_per_microbatch_replay_backward, + setup_per_microbatch_replay_forward, +) from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean from skyrl.train.config import TrainerConfig @@ -414,6 +417,9 @@ def forward_step(batch_iter, model): post_process=mpu.is_pipeline_last_stage(ignore_virtual=True), ) + if rollout_expert_indices is not None: + setup_per_microbatch_replay_backward() + return outputs, partial(loss_func, data=batch) # batch should be a list of micro-batches diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 054874b04b..9076106e2f 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -351,6 +351,7 @@ def init_configs( provider.moe_router_score_function = megatron_config.moe_router_score_function if megatron_config.moe_router_enable_expert_bias is not None: provider.moe_router_enable_expert_bias = megatron_config.moe_router_enable_expert_bias + provider.moe_enable_routing_replay = megatron_config.moe_enable_routing_replay # Apply any additional transformer config kwargs (can override the above). 
         for k, v in transformer_config_kwargs.items():
@@ -362,7 +363,7 @@ def init_configs(
         self.strategy.hf_config = hf_config
         self.tokenizer = tokenizer

-        self.enable_router_replay = transformer_config_kwargs.get("moe_enable_routing_replay", False)
+        self.enable_router_replay = megatron_config.moe_enable_routing_replay

     def configure_lora(self, lora_config, lora_type: Optional[str] = "lora"):
         if lora_type == "lora":
@@ -577,7 +578,9 @@ def init_model(self, model_path, num_training_steps: int = 1e9):
         )

         if self.enable_router_replay:
-            from skyrl.backends.skyrl_train.utils.replay_utils import _patch_topk_router_layer_number
+            from skyrl.backends.skyrl_train.utils.replay_utils import (
+                _patch_topk_router_layer_number,
+            )

             _patch_topk_router_layer_number()

diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py
index 014e8fbe99..62c696e1d2 100644
--- a/skyrl/train/config/config.py
+++ b/skyrl/train/config/config.py
@@ -158,6 +158,7 @@ class MegatronConfig(BaseConfig):
     moe_grouped_gemm: bool = True
     moe_router_score_function: Optional[str] = None
     moe_router_enable_expert_bias: Optional[bool] = None
+    moe_enable_routing_replay: bool = False
     ddp_config: MegatronDDPConfig = field(default_factory=MegatronDDPConfig)
     torch_profiler_config: MegatronTorchProfilerConfig = field(default_factory=MegatronTorchProfilerConfig)
     lora_config: MegatronLoraConfig = field(default_factory=MegatronLoraConfig)
diff --git a/skyrl/train/config/megatron_config/policy.yaml b/skyrl/train/config/megatron_config/policy.yaml
index b641d1f706..f44181e55a 100644
--- a/skyrl/train/config/megatron_config/policy.yaml
+++ b/skyrl/train/config/megatron_config/policy.yaml
@@ -14,6 +14,9 @@
 moe_grouped_gemm: true
 moe_router_score_function: null
 moe_router_enable_expert_bias: null

+# whether to enable router replay (r3)
+moe_enable_routing_replay: False
+
 # pass-through config to Megatron's `DistributedDataParallelConfig` object
 # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.13.0/megatron/core/distributed/distributed_data_parallel_config.py#L8
 ddp_config:
@@ -57,7 +60,6 @@ transformer_config_kwargs:
   recompute_modules: ["core_attn"]
   recompute_method: uniform
   recompute_num_layers: 1
-  moe_enable_routing_replay: ${generator.enable_return_routed_experts}

 # flag to manually empty torch's cuda cache between the forward/backward pass and the optimizer step
 # this will free reserved but unallocated memory, and can help avoid OoMs in the optimizer
diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py
index 1c5883fe64..acd25af560 100644
--- a/skyrl/train/generators/skyrl_gym_generator.py
+++ b/skyrl/train/generators/skyrl_gym_generator.py
@@ -686,7 +686,9 @@ async def generate_batched(
                 sample_logprobs = logprobs[i][: len(response)]
                 truncated_logprobs.append(sample_logprobs)
             if raw_rollout_expert_indices is not None:
-                truncated_indices.append(raw_rollout_expert_indices[i])
+                sample_indices = raw_rollout_expert_indices[i]
+                prompt_len = len(prompt_token_ids[i])
+                truncated_indices.append(sample_indices[: prompt_len + len(response)])

             # Get environment-specific metrics
             env_metrics.append(env.get_metrics())
@@ -1054,7 +1056,10 @@ def _update_agent_loop_state_with_singleturn_chat_template(
             agent_loop_state.loss_mask += loss_mask_for_turn
         if agent_loop_state.rollout_logprobs is not None and rollout_logprobs_for_turn is not None:
             agent_loop_state.rollout_logprobs += rollout_logprobs_for_turn
-        if self.generator_cfg.enable_return_routed_experts and turn_output.rollout_expert_indices is not None:
+        if (
+            self.generator_cfg.inference_engine.enable_return_routed_experts
+            and turn_output.rollout_expert_indices is not None
+        ):
             agent_loop_state.rollout_expert_indices = turn_output.rollout_expert_indices

         return agent_loop_state
diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py
index 46babd7be3..6c12066c7f 100644
--- a/skyrl/train/utils/utils.py
+++ b/skyrl/train/utils/utils.py
@@ -201,6 +201,11 @@ def validate_megatron_cfg(cfg: SkyRLTrainConfig):
         if version > "2.8.1":
             logger.warning("flash_attn > 2.8.1 is not supported for using the megatron backend with flash_attn")

+    if cfg.trainer.policy.megatron_config.moe_enable_routing_replay:
+        assert (
+            cfg.generator.inference_engine.enable_return_routed_experts
+        ), "rollout router replay (r3) is only supported when enable_return_routed_experts is True"
+
     worker_configs = [(cfg.trainer.policy, "policy"), (cfg.trainer.ref, "ref")]
     for config, worker_type in worker_configs:
         # context, expert, and expert tensor parallel are not yet supported for megatron
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py
index f572b11aab..8f906e4e55 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py
@@ -39,7 +39,7 @@ def ray_init_fixture():

     # needed for megatron tests
     env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
-    # env_vars["NVTE_FUSED_ATTN"] = "0"
+    env_vars["NVTE_FUSED_ATTN"] = "1"

     if SKYRL_PYTHONPATH_EXPORT:
         pythonpath = os.environ.get("PYTHONPATH")
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
index 8e84ee61f2..f08e31c490 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
@@ -3,28 +3,34 @@
 uv run --isolated --extra dev --extra megatron -- pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
 """

-import ray
-import pytest
 import asyncio
+
+import pytest
+import ray
 import torch
 from transformers import AutoTokenizer
+
+from skyrl.backends.skyrl_train.distributed.dispatch import (
+    concatenate_outputs_after_mesh_dispatch,
+)
+from skyrl.backends.skyrl_train.inference_engines.utils import (
+    get_sampling_params_for_backend,
+)
+from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
+from skyrl.train.config import SamplingParams, SkyRLTrainConfig
+from skyrl.train.dataset.preprocess import convert_prompts_responses_to_batch_tensors
+from skyrl.train.generators.base import GeneratorInput
+from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator
+from skyrl.train.utils.utils import validate_cfg
 from tests.backends.skyrl_train.gpu.utils import (
     InferenceEngineState,
-    get_test_generator_input,
     Timer,
+    get_test_generator_input,
     init_worker_with_type,
 )
-from skyrl.backends.skyrl_train.distributed.dispatch import concatenate_outputs_after_mesh_dispatch
-from skyrl.train.utils.utils import validate_cfg
-from skyrl.train.config import SkyRLTrainConfig, SamplingParams
-from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator
-from skyrl.train.generators.base import GeneratorInput
-from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend
-from skyrl.train.dataset.preprocess import convert_prompts_responses_to_batch_tensors
-from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch

-MOE_MODEL_NAME = "arcee-ai/Trinity-Nano-Preview"
"arcee-ai/Trinity-Nano-Preview" -# MOE_MODEL_NAME = "/home/ray/moonlight16b" +# MOE_MODEL_NAME = "arcee-ai/Trinity-Nano-Preview" +MOE_MODEL_NAME = "/home/ray/moonlight16b" # MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" REPLAY_NUM_LAYERS = 2 NUM_PROMPTS = 10 @@ -38,6 +44,7 @@ def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: cfg.trainer.micro_forward_batch_size_per_gpu = 2 cfg.trainer.micro_train_batch_size_per_gpu = 2 cfg.trainer.use_sample_packing = True + cfg.generator.inference_engine.distributed_executor_backend = "mp" # flash attn + mla works without sample packing, logprobs are crazy/wrong # but flash-attn correctly throws error with sample packing # we should add an assert that if you set use_sample_packing=False flash attn can accidentally be used @@ -45,19 +52,6 @@ def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: if "moonlight" in model_name: if cfg.trainer.policy.megatron_config.transformer_config_kwargs is None: cfg.trainer.policy.megatron_config.transformer_config_kwargs = {} - # flash attn not supported for moonlight16b - # cfg.trainer.policy.megatron_config.moe_token_dispatcher_type = "alltoall" - # cfg.trainer.policy.megatron_config.moe_router_load_balancing_type = "seq_aux_loss" - # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_aux_loss_coeff"] = 0 - # cfg.trainer.policy.megatron_config.moe_router_score_function = "sigmoid" - # cfg.trainer.policy.megatron_config.moe_router_enable_expert_bias = True - # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_bias_update_rate"] = 0 - # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_dtype"] = "fp32" - # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_topk"] = 6 - # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_pre_softmax"] = True - # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_group_topk"] = 1 - # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_num_groups"] = 1 - # cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers_in_last_pipeline_stage"] = 13 cfg.trainer.flash_attn = False validate_cfg(cfg) return cfg @@ -112,82 +106,6 @@ def build_training_input_from_text_samples( training_input.metadata = {"response_length": num_actions} return training_input - -def test_generate_with_router_replay(ray_init_fixture): - """ - Check that generate with router replay produces the correct rollout inference indices. 
- """ - try: - cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) - cfg.trainer.strategy = "megatron" - cfg.generator.inference_engine.enable_return_routed_experts = True - cfg.generator.inference_engine.tensor_parallel_size = 8 - cfg.generator.sampling_params = SamplingParams( - max_generate_length=MAX_GENERATE_LENGTH, - logprobs=1, - temperature=1.0, - ) - cfg.generator.batched = False - cfg.generator.max_turns = 2 - cfg.generator.use_conversation_multi_turn = True - env_class = "gsm8k_multi_turn" - - tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) - - with InferenceEngineState.create( - cfg=cfg, - model=MOE_MODEL_NAME, - use_local=True, - colocate_all=True, - backend="vllm", - sleep_level=1, - gpu_memory_utilization=0.9, - ) as engines: - client = engines.client - asyncio.run(client.wake_up()) - - generator = SkyRLGymGenerator( - generator_cfg=cfg.generator, - skyrl_gym_cfg=cfg.environment.skyrl_gym, - inference_engine_client=client, - tokenizer=tokenizer, - ) - - input_batch: GeneratorInput = get_test_generator_input( - model=MOE_MODEL_NAME, - num_prompts=NUM_PROMPTS, - n_samples_per_prompt=N_SAMPLES_PER_PROMPT, - max_prompt_length=512, - env_class=env_class, - ) - input_batch["sampling_params"] = get_sampling_params_for_backend( - "vllm", - SamplingParams( - temperature=1.0, - top_p=1.0, - top_k=-1, - max_generate_length=MAX_GENERATE_LENGTH, - min_p=0.0, - logprobs=1, - ), - ) - - with Timer("generate_with_router_replay"): - generator_output = asyncio.run(generator.generate(input_batch)) - - indices = generator_output["rollout_expert_indices"] - responses = generator_output["response_ids"] - assert ( - indices is not None - ), "rollout_expert_indices should not be None when enable_return_routed_experts=True" - assert len(indices) == len( - responses - ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" - asyncio.run(client.sleep()) - finally: - ray.shutdown() - - @pytest.mark.megatron def test_logprobs(ray_init_fixture): """ @@ -203,7 +121,8 @@ def test_logprobs(ray_init_fixture): logprobs=1, temperature=1.0, ) - cfg.generator.batched = False + cfg.generator.batched = True + cfg.generator.async_engine = False cfg.generator.max_turns = 1 tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) @@ -308,9 +227,7 @@ def test_logprobs(ray_init_fixture): cfg.trainer.micro_train_batch_size_per_gpu = 1 def run_megatron_forward(enable_replay: bool) -> torch.Tensor: - cfg.trainer.policy.megatron_config.transformer_config_kwargs = { - "moe_enable_routing_replay": enable_replay, - } + cfg.trainer.policy.megatron_config.moe_enable_routing_replay = enable_replay actor_group = init_worker_with_type( "policy", shared_pg=pg, @@ -330,11 +247,15 @@ def run_megatron_forward(enable_replay: bool) -> torch.Tensor: r3_logprobs = run_megatron_forward(enable_replay=True) no_r3_logprobs = run_megatron_forward(enable_replay=False) - r3_diff = (logprobs_t - r3_logprobs).abs() - no_r3_diff = (logprobs_t - no_r3_logprobs).abs() - print(f"vLLM logprobs - mean: {logprobs_t.mean().item():.6f}, std: {logprobs_t.std().item():.6f}") - print(f"Megatron (replay) - mean: {r3_logprobs.mean().item():.6f}, std: {r3_logprobs.std().item():.6f}") - print(f"Megatron (no rep) - mean: {no_r3_logprobs.mean().item():.6f}, std: {no_r3_logprobs.std().item():.6f}") + mask = response_mask.bool() + vllm_valid = logprobs_t[mask] + r3_valid = r3_logprobs[mask] + no_r3_valid = no_r3_logprobs[mask] + r3_diff = (vllm_valid - r3_valid).abs() + no_r3_diff = 
(vllm_valid - no_r3_valid).abs() + print(f"vLLM logprobs - mean: {vllm_valid.mean().item():.6f}, std: {vllm_valid.std().item():.6f}") + print(f"Megatron (replay) - mean: {r3_valid.mean().item():.6f}, std: {r3_valid.std().item():.6f}") + print(f"Megatron (no rep) - mean: {no_r3_valid.mean().item():.6f}, std: {no_r3_valid.std().item():.6f}") print(f"With replay - logprob diff mean: {r3_diff.mean().item():.6f}, std: {r3_diff.std().item():.6f}") print(f"Without replay - logprob diff mean: {no_r3_diff.mean().item():.6f}, std: {no_r3_diff.std().item():.6f}") @@ -468,9 +389,7 @@ def test_forward_backward(ray_init_fixture): cfg.trainer.micro_train_batch_size_per_gpu = 1 def run_megatron_forward_backward(enable_replay: bool) -> dict: - cfg.trainer.policy.megatron_config.transformer_config_kwargs = { - "moe_enable_routing_replay": enable_replay, - } + cfg.trainer.policy.megatron_config.moe_enable_routing_replay = enable_replay actor_group = init_worker_with_type( "policy", shared_pg=pg, @@ -478,6 +397,10 @@ def run_megatron_forward_backward(enable_replay: bool) -> dict: num_gpus_per_node=8, cfg=cfg, ) + + ray.get(actor_group.async_run_ray_method("pass_through", "setup_per_microbatch_replay_backward")) + ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=training_input)) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=training_input)) for actor in actor_group._actor_handlers: ray.kill(actor) @@ -490,8 +413,8 @@ def run_megatron_forward_backward(enable_replay: bool) -> dict: loss_no_replay = metrics_no_replay["policy_loss"] print(f"With replay - loss: {loss_replay:.6f}") print(f"Without replay - loss: {loss_no_replay:.6f}") - print(f"With replay metrics: {metrics_replay}") - print(f"Without replay metrics: {metrics_no_replay}") + # print(f"With replay metrics: {metrics_replay}") + # print(f"Without replay metrics: {metrics_no_replay}") diff = abs(loss_replay - loss_no_replay) threshold = 0.5 diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py index 0fa7f6c3b5..5f0acd46ac 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py @@ -468,23 +468,27 @@ async def test_generator_multi_turn_gsm8k_step_wise(ray_init_fixture): assert sum(generator_output["is_last_step"]) != len(generator_output["is_last_step"]) +@pytest.mark.asyncio async def test_generator_multi_turn_gsm8k_router_replay(ray_init_fixture): """ Test the generator with the multi-turn GSM8K environment for router replay """ + num_prompts = 5 + n_samples_per_prompt = 2 + max_input_length = 4096 generator_output: GeneratorOutput = await run_generator_end_to_end( use_async_engine=True, batched=False, - n_samples_per_prompt=5, + n_samples_per_prompt=n_samples_per_prompt, num_inference_engines=2, tensor_parallel_size=2, model="arcee-ai/Trinity-Nano-Preview", max_prompt_length=2048, - max_input_length=4096, + max_input_length=max_input_length, max_generate_length=1000, data_path=os.path.expanduser("~/data/gsm8k/validation.parquet"), env_class="gsm8k_multi_turn", - num_prompts=2, + num_prompts=num_prompts, max_turns=2, use_conversation_multi_turn=True, max_env_workers=0, @@ -496,3 +500,13 @@ async def test_generator_multi_turn_gsm8k_router_replay(ray_init_fixture): assert isinstance(generator_output["is_last_step"], list) and 
isinstance(generator_output["is_last_step"][0], bool) # Expect atleast one response with more than one turn assert sum(generator_output["is_last_step"]) != len(generator_output["is_last_step"]) + assert generator_output["rollout_expert_indices"] is not None + + # check that the rollout expert indices are non-zero, and that the shape is (bs, seq_len, layer_num, topk) + rollout_expert_indices = generator_output["rollout_expert_indices"] + total_batch_size = num_prompts * n_samples_per_prompt + + assert len(rollout_expert_indices) == total_batch_size + assert len(rollout_expert_indices[0]) < max_input_length + assert len(rollout_expert_indices[0][0]) == 56 # 56 layers in Trinity-Nano-Preview + assert len(rollout_expert_indices[0][0][0]) == 8 # 8 topk for each layer From 205da19ec441a8dce6b3e54e2451a38971829f74 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 12 Mar 2026 20:39:34 +0000 Subject: [PATCH 17/31] lint --- tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index f08e31c490..3ed69a7608 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -106,6 +106,7 @@ def build_training_input_from_text_samples( training_input.metadata = {"response_length": num_actions} return training_input + @pytest.mark.megatron def test_logprobs(ray_init_fixture): """ From f78dc753dc1f9649359a9261a8ace83fea0d52d4 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 12 Mar 2026 22:53:40 +0000 Subject: [PATCH 18/31] cleaning up --- .../skyrl_train/inference_engines/base.py | 2 +- .../skyrl_train/utils/replay_utils.py | 67 +++++++++++++++++-- .../megatron/megatron_model_wrapper.py | 8 ++- .../workers/megatron/megatron_worker.py | 4 +- skyrl/train/config/legacy.py | 2 + skyrl/train/generators/skyrl_gym_generator.py | 39 ++++++----- skyrl/train/utils/utils.py | 5 ++ .../gpu/gpu_ci/test_router_replay.py | 8 +-- .../test_inference_engine_client.py | 5 +- tests/train/dataset/test_preprocess.py | 4 +- tests/train/generators/test_datatypes.py | 1 + .../generators/test_skyrl_gym_generator.py | 1 + 12 files changed, 109 insertions(+), 37 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_engines/base.py b/skyrl/backends/skyrl_train/inference_engines/base.py index 8b21738e77..c72ed7af89 100644 --- a/skyrl/backends/skyrl_train/inference_engines/base.py +++ b/skyrl/backends/skyrl_train/inference_engines/base.py @@ -31,7 +31,7 @@ class InferenceEngineOutput(TypedDict): response_ids: List[List[int]] stop_reasons: List[str] response_logprobs: Optional[List[List[float]]] - rollout_expert_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] + rollout_expert_indices: Optional[List[List[List[int]]]] # [seq_len, layer_num, topk] class InferenceEngineInterface(ABC): diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index 87ec12dedb..4941b4c0f8 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -7,7 +7,7 @@ import torch -def _patch_topk_router_layer_number(): +def patch_topk_router_layer_number(): """Monkey-patch TopKRouter.set_layer_number to propagate the global layer number to the RouterReplay instance. 
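For intuition, here is a minimal sketch of the wrap-and-delegate pattern this monkey-patch describes. It is illustrative only: `router_cls` and the `router_replay` attribute are assumptions standing in for the real Megatron `TopKRouter` and SkyRL `RouterReplay` objects, not their exact APIs.

```python
def patch_set_layer_number(router_cls):
    """Sketch: wrap set_layer_number so the layer number also reaches the replay helper."""
    original = router_cls.set_layer_number

    def patched(self, layer_number):
        original(self, layer_number)  # preserve the router's own behavior
        # Assumed attribute: forward the global layer number to the replay
        # instance so replayed expert indices land on the correct layer.
        replay = getattr(self, "router_replay", None)
        if replay is not None:
            replay.layer_number = layer_number

    router_cls.set_layer_number = patched
```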
@@ -123,16 +123,72 @@ def _remove_left_padding_from_indices( return new_rii +def _pack_replay_indices( + rollout_expert_indices: torch.Tensor, + attention_mask: torch.Tensor, +) -> torch.Tensor: + """Pack routing indices to match the token layout produced by preprocess_packed_seqs. + + With sample packing, Megatron concatenates all sequences into one packed + sequence with per-sample alignment padding. The MoE router sees tokens in + this packed order, so replay indices must follow the same layout. + + Returns: + [1, total_packed_len, layers, topk] matching the packed model input. + """ + import megatron.core.parallel_state as mpu + + batch_size = rollout_expert_indices.shape[0] + num_layers = rollout_expert_indices.shape[2] + topk = rollout_expert_indices.shape[3] + + seq_lens = attention_mask.sum(dim=-1, dtype=torch.int32) + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size + + pad_sizes = (align_size - seq_lens % align_size) % align_size + seqlens_padded = seq_lens + pad_sizes + + total_packed_len = int(seqlens_padded.sum().item()) + if cp_size > 1: + total_packed_len = total_packed_len // cp_size + + packed = torch.zeros( + total_packed_len, + num_layers, + topk, + dtype=rollout_expert_indices.dtype, + device=rollout_expert_indices.device, + ) + + seq_lens_cpu = seq_lens.tolist() + seqlens_padded_cpu = seqlens_padded.tolist() + offset = 0 + for i in range(batch_size): + n = seq_lens_cpu[i] + mask = attention_mask[i].bool() + packed[offset : offset + n] = rollout_expert_indices[i, mask] + offset += seqlens_padded_cpu[i] // cp_size if cp_size > 1 else seqlens_padded_cpu[i] + + return packed.unsqueeze(0) # [1, total_packed_len, layers, topk] + + def setup_per_microbatch_replay_forward( rollout_expert_indices: torch.Tensor, attention_mask: torch.Tensor, + use_sample_packing: bool = False, ) -> None: """Set up RouterReplay for a single micro-batch, aligning indices - with the left-padding-removed token layout that the MoE layer sees. + with the token layout that the MoE layer sees. Handles sequence parallelism: when TP > 1, the sequence is split across TP ranks, so each rank's MoE router only sees its local chunk of tokens. + Handles sample packing: when use_sample_packing is True, sequences are + concatenated into one packed sequence with per-sample alignment padding. + The replay indices must follow this same packed layout. + Handles dense-layer mismatch: DeepSeek V3-style models have dense FFN layers before the MoE layers. 
vLLM reports routing indices for ALL transformer layers, but Megatron only has RouterReplay instances for MoE @@ -147,10 +203,11 @@ def setup_per_microbatch_replay_forward( _patch_alltoall_dispatcher_for_replay() - aligned = _remove_left_padding_from_indices(rollout_expert_indices, attention_mask) + if use_sample_packing: + aligned = _pack_replay_indices(rollout_expert_indices, attention_mask) + else: + aligned = _remove_left_padding_from_indices(rollout_expert_indices, attention_mask) - # handles megatron sequence parallelism across the tensor model parallel region - # since we automatically enable sequence parallelism when TP > 1 tp_size = mpu.get_tensor_model_parallel_world_size() if tp_size > 1: tp_rank = mpu.get_tensor_model_parallel_rank() diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 17c6b5db19..df86161c9a 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -113,7 +113,9 @@ def forward_step(batch_iter, model): rollout_expert_indices = batch.pop("rollout_expert_indices", None) if rollout_expert_indices is not None: - setup_per_microbatch_replay_forward(rollout_expert_indices, batch["attention_mask"]) + setup_per_microbatch_replay_forward( + rollout_expert_indices, batch["attention_mask"], use_sample_packing=self.use_sample_packing + ) sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) @@ -369,7 +371,9 @@ def forward_step(batch_iter, model): rollout_expert_indices = batch.pop("rollout_expert_indices", None) if rollout_expert_indices is not None: - setup_per_microbatch_replay_forward(rollout_expert_indices, batch["attention_mask"]) + setup_per_microbatch_replay_forward( + rollout_expert_indices, batch["attention_mask"], use_sample_packing=self.use_sample_packing + ) sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 9076106e2f..31daa3c36a 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -579,10 +579,10 @@ def init_model(self, model_path, num_training_steps: int = 1e9): if self.enable_router_replay: from skyrl.backends.skyrl_train.utils.replay_utils import ( - _patch_topk_router_layer_number, + patch_topk_router_layer_number, ) - _patch_topk_router_layer_number() + patch_topk_router_layer_number() # wrap with DDP for training self.actor_module = self.make_megatron_module( diff --git a/skyrl/train/config/legacy.py b/skyrl/train/config/legacy.py index 46ac9a7a4b..9b23832e8b 100644 --- a/skyrl/train/config/legacy.py +++ b/skyrl/train/config/legacy.py @@ -41,6 +41,8 @@ "override_existing_update_group": None, "external_proxy_url": None, "external_server_urls": None, + "enable_return_routed_experts": None, + "distributed_executor_backend": None, } # Fields that should be removed (deprecated or derived) diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index acd25af560..e5d09aeffd 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -84,10 +84,10 @@ class TurnOutput: def get_turn_rollout_expert_indices(self) -> Optional[List[List[List[int]]]]: """ - Get 
rollout inference indices for this turn's tokens (output + observation). + Get rollout inference indices for this turn's tokens (output tokens + observation tokens). - Returns indices for generated output tokens, with padding entries (all -1) - for any manually-added EOS token and observation tokens. + Returns indices for generated output tokens, with padding entries (all 0) + for any manually-added EOS token and observation tokens. Returns None if rollout_expert_indices is None. """ if self.rollout_expert_indices is None: @@ -96,7 +96,7 @@ def get_turn_rollout_expert_indices(self) -> Optional[List[List[List[int]]]]: return self.rollout_expert_indices layer_num = len(self.rollout_expert_indices[0]) topk = len(self.rollout_expert_indices[0][0]) if layer_num > 0 else 0 - pad_entry = [[-1] * topk for _ in range(layer_num)] + pad_entry = [[0] * topk for _ in range(layer_num)] indices = list(self.rollout_expert_indices) if self.added_eos: indices.append(pad_entry) @@ -204,6 +204,9 @@ def _validate_cfg(self, generator_cfg: GeneratorConfig): f"`step_wise_trajectories` doesn't support custom chat template, got {generator_cfg.chat_template}" ) + if self.generator_cfg.inference_engine.enable_return_routed_experts: + raise ValueError("`step_wise_trajectories` doesn't support `enable_return_routed_experts=True`") + if not self.use_conversation_multi_turn: raise ValueError("`step_wise_trajectories` doesn't support `use_conversation_multi_turn=False`") @@ -342,6 +345,8 @@ async def agent_loop( if rollout_expert_indices is not None: rollout_expert_indices = rollout_expert_indices[0] + if self.custom_chat_template is not None: + raise ValueError("Rollout expert indices bookkeeping is not supported with custom chat template") # Append eos when sampling_params.stop is not None. Does not affect 3.a as chat templates add eos_token.
# sampling_params is not None for eval, but None for training (which uses engine.sampling_params which are from cfg) stop_strs = current_sampling_params.get("stop", None) @@ -809,14 +814,10 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False else: rollout_logprobs = None - if self.generator_cfg.step_wise_trajectories: - all_indices = sum( - [[step_output.rollout_expert_indices for step_output in output.step_outputs] for output in all_outputs], - [], - ) + if self.generator_cfg.inference_engine.enable_return_routed_experts: + rollout_expert_indices = [output.rollout_expert_indices for output in all_outputs] else: - all_indices = [output.rollout_expert_indices for output in all_outputs] - rollout_expert_indices = all_indices if any(idx is not None for idx in all_indices) else None + rollout_expert_indices = None rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes) @@ -957,7 +958,9 @@ def _update_agent_loop_state_with_multiturn_chat_template( loss_mask_for_turn = turn_output.get_turn_loss_mask() rollout_logprobs_for_turn = turn_output.get_turn_rollout_logprobs() - rollout_expert_indices_for_turn = turn_output.get_turn_rollout_expert_indices() + # use the raw rollout expert indices without any appending of observation tokens + # this will be overwritten each turn, so we don't need to append observation tokens to it + rollout_expert_indices_for_turn = turn_output.rollout_expert_indices if self.generator_cfg.step_wise_trajectories: # cumulative input_ids is not tracked for step wise training @@ -977,6 +980,7 @@ def _update_agent_loop_state_with_multiturn_chat_template( if agent_loop_state.rollout_expert_indices is not None and rollout_expert_indices_for_turn is not None: # overwrite the existing rollout inference indices, since the inference engine should # return the expert indices for the entire sequence including each turn's input + # and the final response should not have an observation appended to it agent_loop_state.rollout_expert_indices = rollout_expert_indices_for_turn return agent_loop_state @@ -1042,14 +1046,6 @@ def _update_agent_loop_state_with_singleturn_chat_template( obs_ids_to_add ) - # rollout_expert_indices_for_turn = None - # if turn_output.rollout_expert_indices is not None and turn_output.rollout_expert_indices: - # layer_num = len(turn_output.rollout_expert_indices[0]) - # topk = len(turn_output.rollout_expert_indices[0][0]) if layer_num > 0 else 0 - # pad_entry = [[-1] * topk for _ in range(layer_num)] - # rollout_expert_indices_for_turn = list(turn_output.rollout_expert_indices[: len(new_resp_tokens)]) - # rollout_expert_indices_for_turn.extend(pad_entry for _ in range(len(obs_ids_to_add))) - # Directly append turn output agent_loop_state.response_end_idx = len(agent_loop_state.input_ids) + len(new_resp_tokens) - 1 agent_loop_state.input_ids += turn_ids @@ -1060,6 +1056,9 @@ def _update_agent_loop_state_with_singleturn_chat_template( self.generator_cfg.inference_engine.enable_return_routed_experts and turn_output.rollout_expert_indices is not None ): + # overwrite the existing rollout inference indices, since the inference engine should + # return the expert indices for the entire sequence including each turn's input and observation tokens + # and the final response should not have an observation appended to it agent_loop_state.rollout_expert_indices = turn_output.rollout_expert_indices return agent_loop_state diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index 6c12066c7f..7aca1b74e6 
100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -445,6 +445,11 @@ def validate_generator_cfg(cfg: SkyRLTrainConfig): assert ie_cfg.distributed_executor_backend in ("mp", "ray"), "invalid distributed executor backend" + if ie_cfg.enable_return_routed_experts: + assert ie_cfg.distributed_executor_backend == "mp", "rollout router replay (r3) can hang with the ray backend - use the vLLM mp backend instead" + assert cfg.trainer.strategy == "megatron", "rollout router replay (r3) is only supported with Megatron training backend" + assert cfg.trainer.policy.megatron_config.moe_enable_routing_replay, "moe_enable_routing_replay must be True to consume rollout expert indices" + pp_size = ie_cfg.pipeline_parallel_size tp_pp_size = tp_size * pp_size num_gpus_per_node = cfg.trainer.placement.policy_num_gpus_per_node diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 3ed69a7608..ddd9b57e77 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -224,8 +224,8 @@ def test_logprobs(ray_init_fixture): cfg.trainer.policy.megatron_config.context_parallel_size = 1 cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 - cfg.trainer.micro_forward_batch_size_per_gpu = 1 - cfg.trainer.micro_train_batch_size_per_gpu = 1 + cfg.trainer.micro_forward_batch_size_per_gpu = 2 + cfg.trainer.micro_train_batch_size_per_gpu = 2 def run_megatron_forward(enable_replay: bool) -> torch.Tensor: cfg.trainer.policy.megatron_config.moe_enable_routing_replay = enable_replay @@ -386,8 +386,8 @@ def test_forward_backward(ray_init_fixture): cfg.trainer.policy.megatron_config.context_parallel_size = 1 cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 - cfg.trainer.micro_forward_batch_size_per_gpu = 1 - cfg.trainer.micro_train_batch_size_per_gpu = 1 + cfg.trainer.micro_forward_batch_size_per_gpu = 2 + cfg.trainer.micro_train_batch_size_per_gpu = 2 def run_megatron_forward_backward(enable_replay: bool) -> dict: cfg.trainer.policy.megatron_config.moe_enable_routing_replay = enable_replay diff --git a/tests/backends/skyrl_train/inference_engines/test_inference_engine_client.py b/tests/backends/skyrl_train/inference_engines/test_inference_engine_client.py index fcc6b98034..5ed4c39e2e 100644 --- a/tests/backends/skyrl_train/inference_engines/test_inference_engine_client.py +++ b/tests/backends/skyrl_train/inference_engines/test_inference_engine_client.py @@ -933,4 +933,7 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu # client should return the second response directly (no aggregation) # Besides, since we completed in one turn, we return the text response of the first turn returned by # the underlying engine instead re-tokenizing the accumulated tokens - assert out == engines[0].responses[1] + assert out["responses"] == engines[0].responses[1]["responses"] + assert out["response_ids"] == engines[0].responses[1]["response_ids"] + assert out["stop_reasons"] == engines[0].responses[1]["stop_reasons"] + assert out["response_logprobs"] == engines[0].responses[1]["response_logprobs"] diff --git a/tests/train/dataset/test_preprocess.py b/tests/train/dataset/test_preprocess.py index 69a6e7e543..b6aac3396f 100644 --- a/tests/train/dataset/test_preprocess.py 
+++ b/tests/train/dataset/test_preprocess.py @@ -65,7 +65,7 @@ def test_convert_prompts_responses_to_batch_tensors_exact(tokenizer): loss_masks = [[1, 1, 0], [1, 1, 1, 0, 0]] rewards = [torch.tensor([0, 1, 0]), torch.tensor([1, 0, 0, 0, 0])] - sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, ret_log_probs = ( + sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, ret_log_probs, _ = ( convert_prompts_responses_to_batch_tensors( tokenizer, prompts, @@ -93,7 +93,7 @@ def test_convert_prompts_responses_to_batch_tensors_different_lengths(tokenizer) rewards = [torch.tensor([1.0, 0.5, 0.3]), torch.tensor([0.8])] loss_masks = [[1, 1, 1], [1]] - sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, ret_log_probs = ( + sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, ret_log_probs, _ = ( convert_prompts_responses_to_batch_tensors( tokenizer, prompts, diff --git a/tests/train/generators/test_datatypes.py b/tests/train/generators/test_datatypes.py index 580cbd97f5..12c9efdba2 100644 --- a/tests/train/generators/test_datatypes.py +++ b/tests/train/generators/test_datatypes.py @@ -31,6 +31,7 @@ def test_turn_output(output_ids, observation_ids, output_logprobs, added_eos, ex output_logprobs=output_logprobs, new_obs=[], obs_ids=observation_ids, + rollout_expert_indices=None, added_eos=added_eos, reward=1.0, ) diff --git a/tests/train/generators/test_skyrl_gym_generator.py b/tests/train/generators/test_skyrl_gym_generator.py index dade800fdb..93ae2c7ffb 100644 --- a/tests/train/generators/test_skyrl_gym_generator.py +++ b/tests/train/generators/test_skyrl_gym_generator.py @@ -371,6 +371,7 @@ def test_generator_output_concatenation(): "stop_reasons", "rollout_metrics", "rollout_logprobs", + "rollout_expert_indices", # optional but present in the signature "trajectory_ids", "is_last_step", From 43297f0ea42116f78b34bdf026fd48ca07a4f5a1 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 12 Mar 2026 22:55:58 +0000 Subject: [PATCH 19/31] x --- examples/train/gsm8k/run_gsm8k.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/train/gsm8k/run_gsm8k.sh b/examples/train/gsm8k/run_gsm8k.sh index e5dc9e20f3..2693571dac 100755 --- a/examples/train/gsm8k/run_gsm8k.sh +++ b/examples/train/gsm8k/run_gsm8k.sh @@ -26,8 +26,8 @@ uv run --isolated --extra fsdp -m skyrl.train.entrypoints.main_base \ trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \ trainer.placement.critic_num_gpus_per_node=$NUM_GPUS \ trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \ - generator.inference_engine.num_engines=1 \ - generator.inference_engine.tensor_parallel_size=4 \ + generator.inference_engine.num_engines=$NUM_GPUS \ + generator.inference_engine.tensor_parallel_size=1 \ trainer.epochs=20 \ trainer.eval_batch_size=1024 \ trainer.eval_before_train=true \ From 0468c374842c26ac8993962325578e9dd351eb8d Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 12 Mar 2026 23:06:10 +0000 Subject: [PATCH 20/31] x --- .../gpu/gpu_ci/test_megatron_worker.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py index a37acabf7e..626dce4656 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py @@ -41,9 +41,7 @@ # TODO (erictang000): we would prefer to use this smaller MoE model for testing, but seeing incorrect 
logprobs when using EP > 1 # this might be a model specific mbridge issue - see if this persists when we transition to Megatron-Bridge # MOE_MODEL_NAME = "Qwen/Qwen1.5-MoE-A2.7B" -# MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" -MOE_MODEL_NAME = "/home/ray/moonlight16b" - +MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" def get_test_actor_config(model_name=MODEL_NAME) -> SkyRLTrainConfig: cfg = SkyRLTrainConfig() @@ -245,10 +243,10 @@ async def test_megatron_forward( cfg.trainer.use_sample_packing = use_sample_packing batch = get_test_training_batch(max(4, gpus_per_node)) - # if ep > 1: - # if cfg.trainer.policy.megatron_config.transformer_config_kwargs is None: - # cfg.trainer.policy.megatron_config.transformer_config_kwargs = dict() - # cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers"] = 2 + if ep > 1: + if cfg.trainer.policy.megatron_config.transformer_config_kwargs is None: + cfg.trainer.policy.megatron_config.transformer_config_kwargs = dict() + cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers"] = 2 if lora: cfg.trainer.policy.model.lora = SkyRLLoraConfig(rank=16, alpha=16) @@ -263,9 +261,6 @@ async def test_megatron_forward( action_log_probs_refs = actor_group.async_run_ray_method("mesh", "forward", data=batch) all_rank_action_log_probs = ray.get(action_log_probs_refs) - print( - f"Megatron logprobs - mean: {all_rank_action_log_probs.mean().item():.6f}, std: {all_rank_action_log_probs.std().item():.6f}" - ) action_log_probs_megatron = concatenate_outputs_after_mesh_dispatch( actor_group.actor_infos, all_rank_action_log_probs )["output"] @@ -278,8 +273,8 @@ async def test_megatron_forward( @ray.remote(num_gpus=1) def run_hf_forward(batch, model_name): config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, dtype=torch.bfloat16) - # if ep > 1: - # config.num_hidden_layers = 2 + if ep > 1: + config.num_hidden_layers = 2 model = AutoModelForCausalLM.from_pretrained(model_name, config=config, dtype=torch.bfloat16) model.eval() model.to("cuda") From 4878ed7117900cde6b901fa5561b667dc3bb6f34 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Fri, 13 Mar 2026 01:08:19 +0000 Subject: [PATCH 21/31] x --- .../skyrl_train/utils/replay_utils.py | 17 +- skyrl/train/generators/skyrl_gym_generator.py | 2 +- skyrl/train/utils/utils.py | 12 +- .../gpu/gpu_ci/test_megatron_worker.py | 7 +- .../gpu/gpu_ci/test_router_replay.py | 180 ++++++------------ .../gpu/gpu_ci/test_skyrl_gym_generator.py | 14 +- tests/backends/skyrl_train/gpu/utils.py | 16 ++ 7 files changed, 113 insertions(+), 135 deletions(-) diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index 4941b4c0f8..47e405769f 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -164,12 +164,24 @@ def _pack_replay_indices( seq_lens_cpu = seq_lens.tolist() seqlens_padded_cpu = seqlens_padded.tolist() + if cp_size > 1: + cp_rank = mpu.get_context_parallel_rank() offset = 0 for i in range(batch_size): n = seq_lens_cpu[i] mask = attention_mask[i].bool() - packed[offset : offset + n] = rollout_expert_indices[i, mask] - offset += seqlens_padded_cpu[i] // cp_size if cp_size > 1 else seqlens_padded_cpu[i] + d = rollout_expert_indices[i, mask] + if cp_size > 1: + chunk_size = seqlens_padded_cpu[i] // cp_size + start = cp_rank * chunk_size + end = min(start + chunk_size, n) + valid_len = max(0, end - start) + if valid_len > 0: + packed[offset : offset + valid_len] = d[start:end] + offset += 
chunk_size + else: + packed[offset : offset + n] = d + offset += seqlens_padded_cpu[i] return packed.unsqueeze(0) # [1, total_packed_len, layers, topk] @@ -261,4 +273,3 @@ def clear_router_replay(): RouterReplay.clear_global_indices() RouterReplay.clear_global_router_replay_action() - RouterReplay.clear_global_router_replay_instances() diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index dc89f7cb34..30433cd9ac 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -496,7 +496,7 @@ async def agent_loop( if rollout_expert_indices_out is not None and rollout_expert_indices_out: layer_num = len(rollout_expert_indices_out[0]) topk = len(rollout_expert_indices_out[0][0]) if layer_num > 0 else 0 - rollout_expert_indices_out.append([[-1] * topk for _ in range(layer_num)]) + rollout_expert_indices_out.append([[0] * topk for _ in range(layer_num)]) appended_eos_token = True if self.generator_cfg.step_wise_trajectories: diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index 8fbde0de03..a1d711fa3d 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -446,9 +446,15 @@ def validate_generator_cfg(cfg: SkyRLTrainConfig): assert ie_cfg.distributed_executor_backend in ("mp", "ray"), "invalid distributed executor backend" if ie_cfg.enable_return_routed_experts: - assert ie_cfg.distributed_executor_backend == "mp", "rollout router replay (r3) can hang with the ray backend - use the vLLM mp backend instead" - assert cfg.trainer.strategy == "megatron", "rollout router replay (r3) is only supported with Megatron training backend" - assert cfg.trainer.policy.megatron_config.moe_enable_routing_replay, "moe_enable_routing_replay must be True to consume rollout expert indices" + assert ( + ie_cfg.distributed_executor_backend == "mp" + ), "rollout router replay (r3) can hang with the ray backend - use the vLLM mp backend instead" + assert ( + cfg.trainer.strategy == "megatron" + ), "rollout router replay (r3) is only supported with Megatron training backend" + assert ( + cfg.trainer.policy.megatron_config.moe_enable_routing_replay + ), "moe_enable_routing_replay must be True to consume rollout expert indices" pp_size = ie_cfg.pipeline_parallel_size tp_pp_size = tp_size * pp_size diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py index 626dce4656..94726a4f4d 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py @@ -43,6 +43,7 @@ # MOE_MODEL_NAME = "Qwen/Qwen1.5-MoE-A2.7B" MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" + def get_test_actor_config(model_name=MODEL_NAME) -> SkyRLTrainConfig: cfg = SkyRLTrainConfig() cfg.trainer.policy.model.path = model_name @@ -244,9 +245,9 @@ async def test_megatron_forward( batch = get_test_training_batch(max(4, gpus_per_node)) if ep > 1: - if cfg.trainer.policy.megatron_config.transformer_config_kwargs is None: - cfg.trainer.policy.megatron_config.transformer_config_kwargs = dict() - cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers"] = 2 + if cfg.trainer.policy.megatron_config.transformer_config_kwargs is None: + cfg.trainer.policy.megatron_config.transformer_config_kwargs = dict() + cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers"] = 2 if lora: cfg.trainer.policy.model.lora = SkyRLLoraConfig(rank=16, alpha=16) 
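To make the layout that `_pack_replay_indices` above produces concrete, here is a small self-contained sketch of the same bookkeeping. The parallel sizes are passed in as plain arguments instead of being read from `megatron.core.parallel_state`, and the helper name and toy shapes are assumptions for illustration only.

```python
import torch

def pack_indices_sketch(indices, attention_mask, tp_size=1, cp_size=1, cp_rank=0):
    # indices: [bs, seq_len, layers, topk]; attention_mask: [bs, seq_len]
    bs, _, layers, topk = indices.shape
    seq_lens = attention_mask.sum(dim=-1).tolist()
    # Per-sample alignment padding, mirroring the packed-sequence preprocessing.
    align = tp_size * cp_size * 2 if cp_size > 1 else tp_size
    padded = [n + (align - n % align) % align for n in seq_lens]
    total = sum(padded) // cp_size if cp_size > 1 else sum(padded)
    out = torch.zeros(total, layers, topk, dtype=indices.dtype)
    offset = 0
    for i in range(bs):
        valid = indices[i, attention_mask[i].bool()]  # drop padding positions
        if cp_size > 1:
            # Each context-parallel rank keeps one contiguous chunk per sample.
            chunk = padded[i] // cp_size
            start = cp_rank * chunk
            end = min(start + chunk, seq_lens[i])
            if end > start:
                out[offset : offset + (end - start)] = valid[start:end]
            offset += chunk
        else:
            out[offset : offset + seq_lens[i]] = valid
            offset += padded[i]
    return out.unsqueeze(0)  # [1, total_packed_len, layers, topk]

# Two samples of lengths 3 and 5 with tp_size=2 align to 4 and 6, so 10 packed slots.
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
idx = torch.arange(2 * 5 * 2 * 2).reshape(2, 5, 2, 2)
print(pack_indices_sketch(idx, mask, tp_size=2).shape)  # torch.Size([1, 10, 2, 2])
```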
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index ddd9b57e77..f1d2b666b5 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -108,6 +108,7 @@ def build_training_input_from_text_samples( @pytest.mark.megatron +@pytest.mark.skip(reason="Skipping router replay test for now due to size constraints") def test_logprobs(ray_init_fixture): """ Check that logprob diff is lower when using router replay. Requires full 8xH100 setup to do full forward pass. @@ -269,96 +270,60 @@ def run_megatron_forward(enable_replay: bool) -> torch.Tensor: @pytest.mark.megatron +@pytest.mark.skip(reason="Skipping router replay test for now due to size constraints") def test_forward_backward(ray_init_fixture): """ - Check that forward_backward produces similar losses with and without - router replay (same weights, so routing decisions should nearly match). - Requires full 8xH100 setup. + Check that forward_backward with router replay completes without error. + Uses dummy expert routing indices (no vLLM engine needed). + Non-zero advantages / action_log_probs verify the loss is actually computed. """ try: cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) cfg.trainer.strategy = "megatron" - cfg.generator.inference_engine.enable_return_routed_experts = True - cfg.generator.inference_engine.tensor_parallel_size = 8 - cfg.generator.sampling_params = SamplingParams( - max_generate_length=MAX_GENERATE_LENGTH, - logprobs=1, - temperature=1.0, - ) - cfg.generator.batched = False - cfg.generator.max_turns = 1 tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) - - with InferenceEngineState.create( - cfg=cfg, - model=MOE_MODEL_NAME, - use_local=True, - colocate_all=True, - backend="vllm", - sleep_level=1, - gpu_memory_utilization=0.9, - ) as engines: - client, pg = engines.client, engines.pg - asyncio.run(client.wake_up()) - - generator = SkyRLGymGenerator( - generator_cfg=cfg.generator, - skyrl_gym_cfg=cfg.environment.skyrl_gym, - inference_engine_client=client, - tokenizer=tokenizer, - ) - - input_batch: GeneratorInput = get_test_generator_input( - model=MOE_MODEL_NAME, - num_prompts=NUM_PROMPTS, - n_samples_per_prompt=N_SAMPLES_PER_PROMPT, - max_prompt_length=512, - env_class="gsm8k", - ) - input_batch["sampling_params"] = get_sampling_params_for_backend( - "vllm", - SamplingParams( - temperature=1.0, - top_p=1.0, - top_k=-1, - max_generate_length=MAX_GENERATE_LENGTH, - min_p=0.0, - logprobs=1, - ), - ) - - with Timer("generate_with_router_replay"): - generator_output = asyncio.run(generator.generate(input_batch)) - - indices = generator_output["rollout_expert_indices"] - responses = generator_output["response_ids"] - assert ( - indices is not None - ), "rollout_expert_indices should not be None when enable_return_routed_experts=True" - assert len(indices) == len( - responses - ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" - asyncio.run(client.sleep()) - - rewards = generator_output["rewards"] - if rewards and not isinstance(rewards[0], list): - rewards = [[r] * len(resp) for r, resp in zip(rewards, responses)] - (sequences, attention_mask, response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = ( + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + num_samples = NUM_PROMPTS * N_SAMPLES_PER_PROMPT + prompts = [] + responses = [] + rewards = 
[] + loss_masks = [] + for i in range(num_samples): + prompt_ids = tokenizer.encode(f"What is {i} + {i}?", add_special_tokens=False) + response_ids = tokenizer.encode(f"The answer is {i + i}.", add_special_tokens=False) + if tokenizer.eos_token_id is not None and (not response_ids or response_ids[-1] != tokenizer.eos_token_id): + response_ids.append(tokenizer.eos_token_id) + prompts.append(prompt_ids) + responses.append(response_ids) + rewards.append([1.0] * len(response_ids)) + loss_masks.append([1] * len(response_ids)) + + sequences, attention_mask, response_mask, rewards_t, loss_mask_t, _, _ = ( convert_prompts_responses_to_batch_tensors( tokenizer=tokenizer, - prompts=generator_output["prompt_token_ids"], + prompts=prompts, responses=responses, rewards=rewards, - loss_masks=generator_output["loss_masks"], - logprobs=generator_output.get("rollout_logprobs"), - rollout_expert_indices=indices, + loss_masks=loss_masks, ) ) - assert rii_tensor is not None - num_actions = response_mask.shape[1] batch_size = sequences.shape[0] + seq_len = sequences.shape[1] + num_actions = response_mask.shape[1] + + # Moonlight 16B: 27 MoE layers, top_k=6, 64 routed experts + MOONLIGHT_NUM_LAYERS = 27 + MOONLIGHT_TOPK = 6 + MOONLIGHT_NUM_EXPERTS = 64 + rollout_expert_indices = torch.randint( + 0, MOONLIGHT_NUM_EXPERTS, (batch_size, seq_len, MOONLIGHT_NUM_LAYERS, MOONLIGHT_TOPK), dtype=torch.int32 + ) + rollout_expert_indices[attention_mask == 0] = 0 + + gen = torch.Generator().manual_seed(42) training_input = TrainingInputBatch( { "sequences": sequences, @@ -366,15 +331,11 @@ def test_forward_backward(ray_init_fixture): "response_mask": response_mask, "rewards": rewards_t, "loss_mask": loss_mask_t, - "rollout_logprobs": ( - logprobs_t - if logprobs_t is not None - else torch.zeros((batch_size, num_actions), dtype=torch.float32) - ), - "rollout_expert_indices": rii_tensor, - "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), - "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), - "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "rollout_logprobs": -torch.rand((batch_size, num_actions), generator=gen) * 2.0, + "rollout_expert_indices": rollout_expert_indices, + "action_log_probs": -torch.rand((batch_size, num_actions), generator=gen) * 2.0, + "base_action_log_probs": -torch.rand((batch_size, num_actions), generator=gen) * 2.0, + "advantages": torch.randn((batch_size, num_actions), generator=gen), "action_mask": response_mask.to(dtype=torch.int64), } ) @@ -388,41 +349,26 @@ def test_forward_backward(ray_init_fixture): cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 cfg.trainer.micro_forward_batch_size_per_gpu = 2 cfg.trainer.micro_train_batch_size_per_gpu = 2 + cfg.trainer.policy.megatron_config.moe_enable_routing_replay = True - def run_megatron_forward_backward(enable_replay: bool) -> dict: - cfg.trainer.policy.megatron_config.moe_enable_routing_replay = enable_replay - actor_group = init_worker_with_type( - "policy", - shared_pg=pg, - colocate_all=True, - num_gpus_per_node=8, - cfg=cfg, - ) - - ray.get(actor_group.async_run_ray_method("pass_through", "setup_per_microbatch_replay_backward")) - ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=training_input)) - ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) - results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=training_input)) - for actor in actor_group._actor_handlers: - 
ray.kill(actor) - return results[0] - - metrics_replay = run_megatron_forward_backward(enable_replay=True) - metrics_no_replay = run_megatron_forward_backward(enable_replay=False) - - loss_replay = metrics_replay["policy_loss"] - loss_no_replay = metrics_no_replay["policy_loss"] - print(f"With replay - loss: {loss_replay:.6f}") - print(f"Without replay - loss: {loss_no_replay:.6f}") - # print(f"With replay metrics: {metrics_replay}") - # print(f"Without replay metrics: {metrics_no_replay}") - - diff = abs(loss_replay - loss_no_replay) - threshold = 0.5 - print(f"Loss diff: {diff:.6f} (threshold: {threshold})") - assert diff < threshold, ( - f"Losses with/without replay should be similar (same weights), " - f"but diff={diff:.6f} >= threshold={threshold}" + actor_group = init_worker_with_type( + "policy", + num_gpus_per_node=8, + cfg=cfg, ) + + ray.get(actor_group.async_run_ray_method("pass_through", "setup_per_microbatch_replay_backward")) + ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=training_input)) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) + results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=training_input)) + + metrics = results[0] + loss = metrics["policy_loss"] + print(f"Router replay forward_backward - loss: {loss:.6f}") + assert loss is not None and not torch.isnan(torch.tensor(loss)), "Loss should be valid (not NaN)" + assert loss != 0.0, "Loss should be non-zero with non-zero advantages" + + for actor in actor_group._actor_handlers: + ray.kill(actor) finally: ray.shutdown() diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py index 5f0acd46ac..a8b782f924 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py @@ -20,6 +20,7 @@ from tests.backends.skyrl_train.gpu.utils import ( InferenceEngineState, Timer, + _ensure_chat_template, get_test_generator_input, ) @@ -123,6 +124,7 @@ async def run_generator_end_to_end( End to end generator test - requires minimum 2 GPUs """ tokenizer = AutoTokenizer.from_pretrained(model) + _ensure_chat_template(tokenizer) cfg = get_test_config( max_generate_length, @@ -482,24 +484,20 @@ async def test_generator_multi_turn_gsm8k_router_replay(ray_init_fixture): n_samples_per_prompt=n_samples_per_prompt, num_inference_engines=2, tensor_parallel_size=2, - model="arcee-ai/Trinity-Nano-Preview", + model="allenai/OLMoE-1B-7B-0924", max_prompt_length=2048, max_input_length=max_input_length, max_generate_length=1000, - data_path=os.path.expanduser("~/data/gsm8k/validation.parquet"), + data_path=os.path.expanduser("/mnt/cluster_storage/data/gsm8k/validation.parquet"), env_class="gsm8k_multi_turn", num_prompts=num_prompts, max_turns=2, use_conversation_multi_turn=True, max_env_workers=0, - is_step_wise=True, + is_step_wise=False, temperature=0, enable_return_routed_experts=True, ) - - assert isinstance(generator_output["is_last_step"], list) and isinstance(generator_output["is_last_step"][0], bool) - # Expect atleast one response with more than one turn - assert sum(generator_output["is_last_step"]) != len(generator_output["is_last_step"]) assert generator_output["rollout_expert_indices"] is not None # check that the rollout expert indices are non-zero, and that the shape is (bs, seq_len, layer_num, topk) @@ -508,5 +506,5 @@ async def 
test_generator_multi_turn_gsm8k_router_replay(ray_init_fixture): assert len(rollout_expert_indices) == total_batch_size assert len(rollout_expert_indices[0]) < max_input_length - assert len(rollout_expert_indices[0][0]) == 56 # 56 layers in Trinity-Nano-Preview + assert len(rollout_expert_indices[0][0]) == 16 # 16 layers in OLMoE-1B-7B-0924 assert len(rollout_expert_indices[0][0][0]) == 8 # 8 topk for each layer diff --git a/tests/backends/skyrl_train/gpu/utils.py b/tests/backends/skyrl_train/gpu/utils.py index 07754a9e9f..754feadd07 100644 --- a/tests/backends/skyrl_train/gpu/utils.py +++ b/tests/backends/skyrl_train/gpu/utils.py @@ -273,6 +273,7 @@ def get_test_prompts(model: str, num_samples: int = 20) -> List[ConversationType # Ensure pad_token is set correctly if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + _ensure_chat_template(tokenizer) dataset = PromptDataset( datasets=[TEST_DATA_PATH], @@ -289,6 +290,20 @@ def get_test_prompts(model: str, num_samples: int = 20) -> List[ConversationType return prompts +def _ensure_chat_template(tokenizer): + """Set a minimal chat template if the tokenizer doesn't ship with one.""" + if tokenizer.chat_template is None: + tokenizer.chat_template = ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}{{ message['content'] + '\n' }}" + "{% elif message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n' }}" + "{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '\n' }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" + ) + + def get_test_generator_input( model: str, num_prompts: int = 20, @@ -301,6 +316,7 @@ def get_test_generator_input( # Ensure pad_token is set correctly if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + _ensure_chat_template(tokenizer) dataset = PromptDataset( datasets=[data_path], From a5babb4e7feeb6a83f20bf328de14f8b91fb3980 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Fri, 13 Mar 2026 02:30:00 +0000 Subject: [PATCH 22/31] fix bug not propagating router indices to fwd pass --- .../run_moonlight16b_router_replay.sh | 84 +++++++++++++++++++ .../inference_engine_client.py | 6 +- skyrl/train/trainer.py | 5 +- .../gpu/gpu_ci/test_router_replay.py | 11 +-- 4 files changed, 95 insertions(+), 11 deletions(-) create mode 100644 examples/train/router_replay/run_moonlight16b_router_replay.sh diff --git a/examples/train/router_replay/run_moonlight16b_router_replay.sh b/examples/train/router_replay/run_moonlight16b_router_replay.sh new file mode 100644 index 0000000000..a12b264836 --- /dev/null +++ b/examples/train/router_replay/run_moonlight16b_router_replay.sh @@ -0,0 +1,84 @@ +set -x + +# Colocated GRPO training+generation for Moonlight-16B-A3B-Instruct on GSM8K with Megatron and router replay (r3) +# Runs on 1 node of 8xH100s + +# uv run examples/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k +# export WANDB_API_KEY= +# bash examples/train/router_replay/run_moonlight16b_router_replay.sh + +DATA_DIR="$HOME/data/gsm8k" +LOGGER="wandb" # change to "console" to print to stdout +MODEL_NAME="moonshotai/Moonlight-16B-A3B-Instruct" + +INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron + +NUM_NODES=1 +NUM_GPUS=8 + +MEGATRON_TP=4 +MEGATRON_PP=1 +MEGATRON_CP=1 +MEGATRON_EP=8 +MEGATRON_ETP=1 + +NUM_INFERENCE_ENGINES=1 +INFERENCE_ENGINE_TP=8 + +# flash attn is not supported for moonlight16b since it is a DeepSeekV3-like model, and uses Multi-Head Latent
Attention (MLA)
+# https://github.com/NVIDIA/TransformerEngine/blob/483d9594fb070f62966f6a12ed6c90942310b48e/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L483
+FLASH_ATTN=false
+
+# router replay (r3)
+ROUTER_REPLAY=true
+DISTRIBUTED_EXECUTION_BACKEND="mp"
+
+SKYRL_RAY_PG_TIMEOUT_IN_S=300 uv run --isolated --extra megatron --with blobfile -m skyrl.train.entrypoints.main_base \
+    data.train_data="['$DATA_DIR/train.parquet']" \
+    data.val_data="['$DATA_DIR/validation.parquet']" \
+    trainer.algorithm.advantage_estimator="grpo" \
+    trainer.policy.model.path=$MODEL_NAME \
+    trainer.placement.colocate_all=true \
+    trainer.strategy=megatron \
+    trainer.placement.policy_num_nodes=$NUM_NODES \
+    trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
+    generator.inference_engine.num_engines=$NUM_INFERENCE_ENGINES \
+    generator.inference_engine.tensor_parallel_size=$INFERENCE_ENGINE_TP \
+    trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
+    trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
+    trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
+    trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
+    trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
+    trainer.policy.megatron_config.moe_enable_routing_replay=$ROUTER_REPLAY \
+    generator.inference_engine.distributed_executor_backend=$DISTRIBUTED_EXECUTION_BACKEND \
+    generator.inference_engine.enable_return_routed_experts=$ROUTER_REPLAY \
+    trainer.use_sample_packing=true \
+    trainer.flash_attn=$FLASH_ATTN \
+    trainer.epochs=20 \
+    trainer.eval_batch_size=1024 \
+    trainer.eval_before_train=false \
+    trainer.eval_interval=5 \
+    trainer.update_epochs_per_batch=1 \
+    trainer.train_batch_size=256 \
+    trainer.policy_mini_batch_size=32 \
+    trainer.micro_forward_batch_size_per_gpu=4 \
+    trainer.micro_train_batch_size_per_gpu=4 \
+    trainer.ckpt_interval=100 \
+    trainer.max_prompt_length=512 \
+    generator.sampling_params.max_generate_length=1024 \
+    trainer.policy.optimizer_config.lr=1.0e-6 \
+    trainer.algorithm.use_kl_loss=false \
+    generator.inference_engine.backend=$INFERENCE_BACKEND \
+    generator.inference_engine.run_engines_locally=true \
+    generator.inference_engine.weight_sync_backend=nccl \
+    generator.inference_engine.async_engine=true \
+    generator.batched=true \
+    environment.env_class=gsm8k \
+    generator.n_samples_per_prompt=5 \
+    generator.inference_engine.gpu_memory_utilization=0.6 \
+    trainer.logger="$LOGGER" \
+    trainer.project_name="gsm8k_router_replay" \
+    trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_moonlight16b-a3b_with_router_replay" \
+    trainer.resume_mode=null \
+    trainer.ckpt_path="$HOME/ckpts/gsm8k_megatron_ckpt" \
+    $@
\ No newline at end of file
diff --git a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
index da5ab0fef9..d718366ff3 100644
--- a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
+++ b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
@@ -273,7 +273,7 @@ async def _generate_single_with_retry(
         # 2. Initialize fields we want to accumulate or update in each loop iteration
         accum_response_ids: List[int] = []
         accum_response_logprobs: List[float] = []
-        accum_rollout_expert_indices: List[List[List[int]]] = []
+        rollout_expert_indices: Optional[List[List[List[int]]]] = None
         stop_reason: str = "abort"

         # We only use it if generation is completed in one turn to maintain original behavior with no retry.
@@ -323,7 +323,7 @@ async def _generate_single_with_retry(
                 if new_response_logprobs is not None:
                     accum_response_logprobs.extend(new_response_logprobs)
                 if new_rollout_expert_indices is not None:
-                    accum_rollout_expert_indices.extend(new_rollout_expert_indices)
+                    rollout_expert_indices = new_rollout_expert_indices
                 num_turns += 1

         # 4. Build the final response and return.
@@ -336,7 +336,7 @@ async def _generate_single_with_retry(
             stop_reasons=[stop_reason],
             response_ids=[accum_response_ids],
             response_logprobs=[accum_response_logprobs] if len(accum_response_logprobs) > 0 else None,
-            rollout_expert_indices=([accum_rollout_expert_indices] if len(accum_rollout_expert_indices) > 0 else None),
+            rollout_expert_indices=([rollout_expert_indices] if rollout_expert_indices is not None else None),
         )

     async def _chat_completion_with_retry(
diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py
index 30d6deca93..75e4809ee8 100644
--- a/skyrl/train/trainer.py
+++ b/skyrl/train/trainer.py
@@ -939,7 +939,10 @@ def fwd_logprobs_values_reward(
         - `["action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
         - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
         """
-        data_fwd_pass = training_input.select(keys=["sequences", "attention_mask"], metadata_keys=["response_length"])
+        fwd_keys = ["sequences", "attention_mask"]
+        if training_input.get("rollout_expert_indices") is not None:
+            fwd_keys.append("rollout_expert_indices")
+        data_fwd_pass = training_input.select(keys=fwd_keys, metadata_keys=["response_length"])

         values = None
         base_log_probs = None
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
index f1d2b666b5..a82f4d8038 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
@@ -29,12 +29,9 @@
     init_worker_with_type,
 )

-# MOE_MODEL_NAME = "arcee-ai/Trinity-Nano-Preview"
-MOE_MODEL_NAME = "/home/ray/moonlight16b"
-# MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B"
-REPLAY_NUM_LAYERS = 2
-NUM_PROMPTS = 10
-N_SAMPLES_PER_PROMPT = 4
+MOE_MODEL_NAME = "moonshotai/Moonlight-16B-A3B-Instruct"
+NUM_PROMPTS = 5
+N_SAMPLES_PER_PROMPT = 2
 MAX_GENERATE_LENGTH = 128

@@ -222,7 +219,7 @@ def test_logprobs(ray_init_fixture):
     cfg.trainer.placement.policy_num_gpus_per_node = 8
     cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2
     cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1
-    cfg.trainer.policy.megatron_config.context_parallel_size = 1
+    cfg.trainer.policy.megatron_config.context_parallel_size = 2
     cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8
     cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1
     cfg.trainer.micro_forward_batch_size_per_gpu = 2
From 7c11d73b586bfad4650b04b7f4d886ea24c516d2 Mon Sep 17 00:00:00 2001
From: Eric Tang
Date: Fri, 13 Mar 2026 02:51:54 +0000
Subject: [PATCH 23/31] x
---
 tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
index a82f4d8038..2e4be5a9a5 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
@@ -218,8 +218,8 @@ def test_logprobs(ray_init_fixture):
     cfg.trainer.placement.policy_num_gpus_per_node = 8
     cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2
-    cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1
-    cfg.trainer.policy.megatron_config.context_parallel_size = 2
+    cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1
+    cfg.trainer.policy.megatron_config.context_parallel_size = 1
     cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8
     cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1
     cfg.trainer.micro_forward_batch_size_per_gpu = 2
From ac1fb79d7228c9a6fb725c76b809360f82a01a7c Mon Sep 17 00:00:00 2001
From: Eric Tang
Date: Fri, 13 Mar 2026 06:26:17 +0000
Subject: [PATCH 24/31] add supported settings to cfg validation
---
 skyrl/train/utils/utils.py                        | 11 +++++++++++
 .../skyrl_train/gpu/gpu_ci/test_router_replay.py  |  2 +-
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py
index a1d711fa3d..a647b0d12c 100644
--- a/skyrl/train/utils/utils.py
+++ b/skyrl/train/utils/utils.py
@@ -205,6 +205,12 @@ def validate_megatron_cfg(cfg: SkyRLTrainConfig):
         assert (
             cfg.generator.inference_engine.enable_return_routed_experts
         ), "rollout router replay (r3) is only supported when enable_return_routed_experts is True"
+        assert (
+            cfg.trainer.policy.megatron_config.pipeline_model_parallel_size == 1
+        ), "pipeline parallel is not yet supported for router replay (r3) with megatron"
+        assert (
+            cfg.trainer.policy.megatron_config.context_parallel_size == 1
+        ), "context parallel is not yet supported for router replay (r3) with megatron"

     worker_configs = [(cfg.trainer.policy, "policy"), (cfg.trainer.ref, "ref")]
     for config, worker_type in worker_configs:
@@ -513,6 +519,11 @@ def _validate_new_inference_cfg(cfg: SkyRLTrainConfig):
             "the mp backend for vLLM is not yet fully supported for the new inference backend. See https://github.com/NovaSky-AI/SkyRL/issues/1309. Use the ray backend instead."
         )

+    if cfg.generator.inference_engine.enable_return_routed_experts:
+        raise ValueError(
+            "rollout router replay (r3) is not yet fully supported for the new inference backend. See https://github.com/NovaSky-AI/SkyRL/issues/815."
+        )
+

 @ray.remote
 def get_all_env_variables():
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
index 2e4be5a9a5..bd114d7b76 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
@@ -218,7 +218,7 @@ def test_logprobs(ray_init_fixture):
     cfg.trainer.placement.policy_num_gpus_per_node = 8
     cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2
-    cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1
+    cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1
     cfg.trainer.policy.megatron_config.context_parallel_size = 1
     cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8
     cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1
From 38b15a1291477386f4dbe8a6d1cb02fd100ad9c5 Mon Sep 17 00:00:00 2001
From: Eric Tang
Date: Fri, 13 Mar 2026 07:07:55 +0000
Subject: [PATCH 25/31] add docs
---
 .../docs/algorithms/off_policy_correction.mdx | 39 ++++++++++++++++---
 docs/content/docs/configuration/config.mdx    |  6 +++
 2 files changed, 39 insertions(+), 6 deletions(-)

diff --git a/docs/content/docs/algorithms/off_policy_correction.mdx b/docs/content/docs/algorithms/off_policy_correction.mdx
index 652bd47114..b07ddaa8ae 100644
--- a/docs/content/docs/algorithms/off_policy_correction.mdx
+++ b/docs/content/docs/algorithms/off_policy_correction.mdx
@@ -17,20 +17,26 @@ SkyRL provides built-in utilities for correcting off-policy drift from trainer/i
 We recommend adding the following configs in order to your training runs to help address off-policy drift:

 ```yaml
-# we recommend trying basic TIS correction first
+# For dense models, we recommend trying basic TIS correction first
 trainer.algorithm.off_policy_correction.tis_ratio_type="token"
 trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=2.0

-# for long context + MoE models, try geometric sequence masking - tune geo_mask_high/geo_mask_low as needed
+# for MoE models, enabling router replay (R3) to fix the source of train/infer mismatch is recommended
+trainer.policy.megatron_config.moe_enable_routing_replay=True
+generator.inference_engine.enable_return_routed_experts=True
+generator.inference_engine.distributed_executor_backend="mp" # this is temporarily needed for vLLM, since routed experts cause issues with the ray backend.
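+# note: replay here is wired through Megatron (trainer.strategy=megatron); per the config
+# validation added in this series, pipeline_model_parallel_size and context_parallel_size must both be 1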
+
+# The following masking strategies can additionally help mitigate off-policy drift, especially from sources other than train/infer mismatch
+# geometric sequence masking - tune geo_mask_high/geo_mask_low as needed
 trainer.algorithm.off_policy_correction.sequence_mask_metric="geometric"
 trainer.algorithm.off_policy_correction.geo_mask_high=1.01
 trainer.algorithm.off_policy_correction.geo_mask_low=0.99

-# alternatively, for long context + MoE you can try token masking (icepop) and tune token_mask_is_threshold_low/high
+# token masking (icepop): tune token_mask_is_threshold_low/high
 trainer.algorithm.off_policy_correction.token_mask_is_threshold_low=0.5
 trainer.algorithm.off_policy_correction.token_mask_is_threshold_high=2.0

-# for longer context + MoE, you can also try outlier based sequence masking, which stacks on top of geometric sequence masking
+# outlier-based sequence masking: stacks on top of geometric sequence masking
 trainer.algorithm.off_policy_correction.outlier_token_is_threshold_low=1e-4
 trainer.algorithm.off_policy_correction.outlier_token_is_threshold_high=100
 ```
@@ -125,10 +131,31 @@ policies. To mitigate this, the max staleness of trajectories can be tuned to pr
 Mini batching results in off-policy updates, which can be clamped within an acceptable range in the common dual clip
 formulation of the PPO loss. Tuning the number of mini batches per training batch can impact convergence of RL runs, and
 impact whether corrections like routing replay and masking are needed.

+# Routing Replay
+
+SkyRL supports rollout routing replay (R3), first introduced by [Ma et. al](https://arxiv.org/pdf/2510.11370) to help eliminate a source of trainer/inference mismatch for MoE at the source. Rollout routing replay works by recording expert
+routing decisions at inference time, and replaying the same expert routing decisions at training time.
+
+```yaml
+generator:
+  inference_engine:
+    enable_return_routed_experts: True # pass through argument to vLLM
+    distributed_executor_backend: "mp" # temporarily needed to work around hanging issues with other backends
+...
+trainer:
+  policy:
+    megatron_config:
+      moe_enable_routing_replay: True # enables Megatron native RoutingReplay feature
+```
+
+To enable rollout router replay, set `generator.inference_engine.enable_return_routed_experts=True`, `trainer.policy.megatron_config.moe_enable_routing_replay=True`, and use the `mp` distributed_executor_backend for vLLM. Note that
+R3 does induce additional training bias when mini-batching, since routing decisions are fixed for all mini-batches in a training batch. However, it has been shown to be important for stabilizing large scale MoE training, particularly
+in models adopting a Deepseek-v3 like architecture (notably the GLM family) due to the use of sigmoid-based affinity scoring instead of softmax for top-k routing.
+
 # Algorithmic Off Policy Correction

-In the previous section, we described some reasons why off-policy drift can occur, and some ways to mitigate it (e.g., batch invariant kernels, routing replay). However,
-these solutions come with tradeoffs (slower inference for batch invariant kernels, additional bias for routing replay), and are not sufficient to address all sources of drift, like fully async RL.
+In the previous sections, we described some reasons why off-policy drift can occur, and some ways to mitigate it (e.g., batch invariant kernels, routing replay). However,
+these solutions come with tradeoffs (slower inference for batch invariant kernels), and are not sufficient to address all sources of drift, like fully async RL.

 Recent works ([Liu et. al 2025](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda), [Yao et. al 2025](https://fengyao.notion.site/off-policy-rl)) have proposed
 additional techniques for off-policy correction. In this section, we describe these techniques and how to enable them in SkyRL.

diff --git a/docs/content/docs/configuration/config.mdx b/docs/content/docs/configuration/config.mdx
index c1ae99edbc..85166535e6 100644
--- a/docs/content/docs/configuration/config.mdx
+++ b/docs/content/docs/configuration/config.mdx
@@ -164,6 +164,8 @@ megatron_config:
   expert_model_parallel_size: 1
   expert_tensor_parallel_size: null

+  moe_enable_routing_replay: False
+
 ddp_config: # pass-through config to Megatron's `DistributedDataParallelConfig` object
   # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.13.0/megatron/core/distributed/distributed_data_parallel_config.py#L8
   ...
@@ -203,6 +205,8 @@ Some rules for configuring these parameters:
 - `world_size % (pp_size * ep_size * etp_size) == 0`
   - This means that `ep_size * etp_size` can scale independently of `tp_size * cp_size`, and can go across data parallel ranks.

+- `moe_enable_routing_replay`: Whether to enable megatron router replay. Used together with `generator.inference_engine.enable_return_routed_experts` to enable R3.
+
 `optimizer_config_kwargs.use_precision_aware_optimizer=true` can cause checkpointing to fail. See: https://github.com/nvidia/megatron-lm/issues/1820.
 We recommend leaving this setting to `false`.
@@ -631,6 +635,7 @@ generator:
     max_num_seqs: 1024
     vllm_v1_disable_multiproc: true
     remote_urls: []
+    enable_return_routed_experts: false
     distributed_executor_backend: "ray" # "mp", "ray"
     engine_init_kwargs: {}
     override_existing_update_group: "auto" # "auto", "enable", "disable"
@@ -709,6 +714,7 @@ For more details on how different placement options work, please refer to the [p
 - `generator.inference_engine.max_num_batched_tokens`: Continous batching parameter for vLLM. Maximum number of tokens to pack into a batch.
 - `generator.inference_engine.enforce_eager`: Whether to disable CUDA graphs. Default is `true` for stability. Set to `false` for higher performance, but this may affect convergence for long-running or long-context training jobs.
 - `generator.inference_engine.enable_ray_prometheus_stats`: Whether to enable Ray Prometheus stats logger for vLLM inference engine metrics (vLLM v1 only). When enabled, uses `vllm.v1.metrics.ray_wrappers.RayPrometheusStatLogger`.
+- `generator.inference_engine.enable_return_routed_experts`: Whether to return per-layer expert routing indices to use for rollout router replay (r3) if training an MoE model. Used together with `trainer.policy.megatron_config.enable_return_routed_experts` to enable R3.
 - `generator.inference_engine.distributed_executor_backend`: The distributed executor backend to use for the vLLM engine. Options are either `mp` or `ray`.
 - `generator.inference_engine.engine_init_kwargs`: Inference engine arguments passed directly to the vLLM engine. If duplicate kwargs are passed or kwargs clash with existing inference engine arguments (e.g., `tensor_parallel_size`), an error is raised.
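A note on the contract behind `enable_return_routed_experts` as documented above: the engine emits one routing record per generated token, nested as [seq_len][layer_num][topk], and the multi-turn client change earlier in this series keeps only the final turn's record. A minimal shape check in that spirit (a sketch only; `check_routed_experts`, `num_moe_layers`, and `topk` are illustrative names, not the repo's API) could look like:

```python
# Sketch of a shape check for recorded routing indices (illustrative, not the
# repo's API). Assumes the nested [seq_len][layer_num][topk] layout and
# single-turn generation, where one record exists per response token.
from typing import List


def check_routed_experts(
    response_ids: List[int],
    routed_experts: List[List[List[int]]],
    num_moe_layers: int,
    topk: int,
) -> None:
    # One routing record per generated token.
    if len(routed_experts) != len(response_ids):
        raise ValueError(f"{len(routed_experts)} routing records for {len(response_ids)} tokens")
    for t, per_token in enumerate(routed_experts):
        # Each token carries one top-k expert list per MoE layer.
        if len(per_token) != num_moe_layers:
            raise ValueError(f"token {t}: {len(per_token)} layers, expected {num_moe_layers}")
        if any(len(layer) != topk for layer in per_token):
            raise ValueError(f"token {t}: expected top-{topk} expert indices per layer")
```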
From 465ec77d394de3c1a1e1ab98755b7d22626b985a Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Fri, 13 Mar 2026 07:14:40 +0000 Subject: [PATCH 26/31] docs --- docs/content/docs/algorithms/off_policy_correction.mdx | 8 ++++---- docs/content/docs/configuration/config.mdx | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/content/docs/algorithms/off_policy_correction.mdx b/docs/content/docs/algorithms/off_policy_correction.mdx index b07ddaa8ae..b8ce40a70d 100644 --- a/docs/content/docs/algorithms/off_policy_correction.mdx +++ b/docs/content/docs/algorithms/off_policy_correction.mdx @@ -133,8 +133,8 @@ can impact convergence of RL runs, and impact whether corrections like routing r # Routing Replay -SkyRL supports rollout routing replay (R3), first introduced by [Ma et. al](https://arxiv.org/pdf/2510.11370) to help eliminate a source of trainer/inference mismatch for MoE at the source. Rollout routing replay works by recording expert -routing decisions at inference time, and replaying the same expert routing decisions at training time. +SkyRL supports rollout routing replay (R3), first introduced by [Ma et al.](https://arxiv.org/pdf/2510.11370) to help eliminate trainer/inference mismatch for MoE at the source. Rollout routing replay works by recording expert +routing decisions for MoE layers at inference time, and replaying the same per-layer expert routing decisions at training time, which helps reduce mismatched logprobs. ```yaml generator: @@ -149,8 +149,8 @@ trainer: ``` To enable rollout router replay, set `generator.inference_engine.enable_return_routed_experts=True`, `trainer.policy.megatron_config.moe_enable_routing_replay=True`, and use the `mp` distributed_executor_backend for vLLM. Note that -R3 does induce additional training bias when mini-batching, since routing decisions are fixed for all mini-batches in a training batch. However, it has been shown to be important for stabilizing large scale MoE training, particularly -in models adopting a Deepseek-v3 like architecture (notably the GLM family) due to the use of sigmoid-based affinity scoring instead of softmax for top-k routing. +R3 does induce additional training bias when mini-batching, since routing decisions are fixed for all mini-batches in a training batch. However, it has been shown to be important for stabilizing large-scale MoE training, particularly +in models adopting a DeepSeek-V3 like architecture (notably the GLM family) due to the use of sigmoid-based affinity scoring instead of softmax for top-k routing. # Algorithmic Off Policy Correction diff --git a/docs/content/docs/configuration/config.mdx b/docs/content/docs/configuration/config.mdx index 85166535e6..a9b7514a30 100644 --- a/docs/content/docs/configuration/config.mdx +++ b/docs/content/docs/configuration/config.mdx @@ -205,7 +205,7 @@ Some rules for configuring these parameters: - `world_size % (pp_size * ep_size * etp_size) == 0` - This means that `ep_size * etp_size` can scale independently of `tp_size * cp_size`, and can go across data parallel ranks. -- `moe_enable_routing_replay`: Whether to enable megatron router replay. Used together with `generator.inference_engine.enable_return_routed_experts` to enable R3. +- `moe_enable_routing_replay`: Whether to enable Megatron router replay. Used together with `generator.inference_engine.enable_return_routed_experts` to enable R3. `optimizer_config_kwargs.use_precision_aware_optimizer=true` can cause checkpointing to fail. See: https://github.com/nvidia/megatron-lm/issues/1820. 
 We recommend leaving this setting to `false`.
@@ -714,7 +714,7 @@ For more details on how different placement options work, please refer to the [p
 - `generator.inference_engine.max_num_batched_tokens`: Continous batching parameter for vLLM. Maximum number of tokens to pack into a batch.
 - `generator.inference_engine.enforce_eager`: Whether to disable CUDA graphs. Default is `true` for stability. Set to `false` for higher performance, but this may affect convergence for long-running or long-context training jobs.
 - `generator.inference_engine.enable_ray_prometheus_stats`: Whether to enable Ray Prometheus stats logger for vLLM inference engine metrics (vLLM v1 only). When enabled, uses `vllm.v1.metrics.ray_wrappers.RayPrometheusStatLogger`.
-- `generator.inference_engine.enable_return_routed_experts`: Whether to return per-layer expert routing indices to use for rollout router replay (r3) if training an MoE model. Used together with `trainer.policy.megatron_config.enable_return_routed_experts` to enable R3.
+- `generator.inference_engine.enable_return_routed_experts`: Whether to return per-layer expert routing indices to use for rollout router replay (R3) if training an MoE model. Used together with `trainer.policy.megatron_config.moe_enable_routing_replay` to enable R3.
 - `generator.inference_engine.distributed_executor_backend`: The distributed executor backend to use for the vLLM engine. Options are either `mp` or `ray`.
 - `generator.inference_engine.engine_init_kwargs`: Inference engine arguments passed directly to the vLLM engine. If duplicate kwargs are passed or kwargs clash with existing inference engine arguments (e.g., `tensor_parallel_size`), an error is raised.
From e6af1a08af4d37b880c175090fe9c38257eb7236 Mon Sep 17 00:00:00 2001
From: Eric Tang
Date: Fri, 13 Mar 2026 07:17:12 +0000
Subject: [PATCH 27/31] remove legacy
---
 skyrl/train/config/legacy.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/skyrl/train/config/legacy.py b/skyrl/train/config/legacy.py
index 9b23832e8b..46ac9a7a4b 100644
--- a/skyrl/train/config/legacy.py
+++ b/skyrl/train/config/legacy.py
@@ -41,8 +41,6 @@
     "override_existing_update_group": None,
     "external_proxy_url": None,
     "external_server_urls": None,
-    "enable_return_routed_experts": None,
-    "distributed_executor_backend": None,
 }

 # Fields that should be removed (deprecated or derived)
From 951bc2473e59ce9c8112cdee2a8d6e6e043bf933 Mon Sep 17 00:00:00 2001
From: Eric Tang
Date: Fri, 13 Mar 2026 07:20:29 +0000
Subject: [PATCH 28/31] x
---
 .../skyrl_train/inference_engines/vllm/vllm_engine.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
index cad5bce14b..f7f1a83ea8 100644
--- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
+++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -348,14 +348,6 @@ def _create_engine(self, *args, **kwargs):
         enable_log_requests = kwargs.pop("enable_log_requests", False)
         max_log_len = kwargs.pop("max_log_len", None)

-        # Log if enable_return_routed_experts is being passed
-        if "enable_return_routed_experts" in kwargs:
-            logger.info(
-                f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs"
-            )
-        else:
-            logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs")
-
         if version.parse(vllm.__version__) >= version.parse("0.10.0"):
             engine_args = vllm.AsyncEngineArgs(enable_log_requests=enable_log_requests, **kwargs)
         else:
From 2f6c778e0532b3869f7702c60ab55baeb1a29c5c Mon Sep 17 00:00:00 2001
From: Eric Tang
Date: Fri, 13 Mar 2026 07:39:23 +0000
Subject: [PATCH 29/31] x
---
 skyrl/train/config/ppo_base_config.yaml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml
index 2390a2305f..c3f48297a8 100644
--- a/skyrl/train/config/ppo_base_config.yaml
+++ b/skyrl/train/config/ppo_base_config.yaml
@@ -291,7 +291,6 @@ generator:
     vllm_v1_disable_multiproc: true
     enable_prefix_caching: true
     enable_chunked_prefill: true
-    enable_return_routed_experts: false
     max_num_batched_tokens: 8192
     # Disable CUDA graphs by default for stability. Set to false for higher performance, but this may affect convergence for long-running and/or long context training jobs.
     enforce_eager: true
From e3c965ca360ef3fda717b4144f4031d7ffb755a6 Mon Sep 17 00:00:00 2001
From: Eric Tang
Date: Fri, 13 Mar 2026 07:43:14 +0000
Subject: [PATCH 30/31] ur right devin
---
 tests/backends/skyrl_train/gpu/gpu_ci/conftest.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py
index 8f906e4e55..f0af226f85 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py
@@ -39,7 +39,7 @@ def ray_init_fixture():

         # needed for megatron tests
         env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
-        env_vars["NVTE_FUSED_ATTN"] = "1"
+        env_vars["NVTE_FUSED_ATTN"] = "0"

     if SKYRL_PYTHONPATH_EXPORT:
         pythonpath = os.environ.get("PYTHONPATH")
From bd696142ab3d6b2d99edfcfb6c3edfc9e5944bbb Mon Sep 17 00:00:00 2001
From: Eric Tang
Date: Fri, 13 Mar 2026 21:03:29 +0000
Subject: [PATCH 31/31] add dapo moonlight with r3
---
 .../run_dapo_moonlight_16b_a3b.sh | 129 ++++++++++++++++++
 1 file changed, 129 insertions(+)
 create mode 100644 examples/train/router_replay/run_dapo_moonlight_16b_a3b.sh

diff --git a/examples/train/router_replay/run_dapo_moonlight_16b_a3b.sh b/examples/train/router_replay/run_dapo_moonlight_16b_a3b.sh
new file mode 100644
index 0000000000..5a480569bd
--- /dev/null
+++ b/examples/train/router_replay/run_dapo_moonlight_16b_a3b.sh
@@ -0,0 +1,129 @@
+set -x
+
+# Colocated DAPO training+generation for Moonlight-16B-A3B on DAPO with Megatron and router replay.
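+# Parallelism layout: generation runs NUM_INFERENCE_ENGINES=2 vLLM engines at TP=8;
+# training uses Megatron TP=4 x EP=8 with PP=1 and CP=1, which the router replay
+# config validation added earlier in this series requires.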
+# Should run on 2 nodes of 8xH100s
+
+# bash examples/train/algorithms/dapo/prepare_dapo_data.sh
+# bash examples/train/router_replay/run_dapo_moonlight_16b_a3b.sh
+
+MODEL_NAME="moonshotai/Moonlight-16B-A3B-Instruct"
+DATA_DIR="$HOME/data/dapo"
+TRAIN_FILE="$DATA_DIR/dapo-math-17k-cleaned.parquet"
+TEST_FILE="$DATA_DIR/aime-2024-cleaned.parquet"
+NUM_NODES=2
+NUM_GPUS_PER_NODE=8
+NUM_INFERENCE_ENGINES=2
+INFERENCE_ENGINE_TENSOR_PARALLEL_SIZE=8
+LOGGER="wandb"  # change to "console" to print to stdout
+
+# flash attention off
+FLASH_ATTN=false
+
+CLIP_RATIO_LOW=0.2
+CLIP_RATIO_HIGH=0.28
+# use token mean loss reduction
+LOSS_REDUCTION="token_mean"
+# applies overlong filtering (but not soft overlong punishment)
+APPLY_OVERLONG_FILTERING=true
+# apply soft overlong punishment with custom trainer impl in main_dapo.py
+OVERLONG_BUFFER_LEN=$((1024 * 4))
+OVERLONG_BUFFER_PENALTY_FACTOR=1.0
+
+# other DAPO parameters
+USE_KL_LOSS=false
+TEMPERATURE=1.0
+TOP_P=1.0
+EVAL_TOP_P=0.7
+CLIP_RATIO_C=10.0
+MAX_PROMPT_LENGTH=$((1024 * 2))
+MAX_RESPONSE_LENGTH=$((1024 * 8))
+
+# repro run parameters
+TRAIN_BATCH_SIZE=128
+MINI_BATCH_SIZE=32
+N_SAMPLES_PER_PROMPT=16
+EVAL_N_SAMPLES_PER_PROMPT=32
+ENFORCE_EAGER=true  # cuda graphs can cause some instability
+LR=1e-6
+
+# megatron config
+MEGATRON_TP=4
+MEGATRON_PP=1
+MEGATRON_CP=1
+MEGATRON_EP=8
+MEGATRON_ETP=1
+
+
+# Router replay (r3)
+ROUTER_REPLAY=true
+DISTRIBUTED_EXECUTION_BACKEND="mp"
+
+SKYRL_RAY_PG_TIMEOUT_IN_S=300 uv run --isolated --extra megatron -m examples.train.algorithms.dapo.main_dapo \
+    data.train_data="['$TRAIN_FILE']" \
+    data.val_data="['$TEST_FILE']" \
+    trainer.algorithm.advantage_estimator="grpo" \
+    trainer.algorithm.policy_loss_type="dual_clip" \
+    trainer.algorithm.overlong_buffer_len=$OVERLONG_BUFFER_LEN \
+    trainer.algorithm.overlong_buffer_penalty_factor=$OVERLONG_BUFFER_PENALTY_FACTOR \
+    trainer.algorithm.loss_reduction=$LOSS_REDUCTION \
+    generator.inference_engine.enforce_eager=$ENFORCE_EAGER \
+    generator.apply_overlong_filtering=$APPLY_OVERLONG_FILTERING \
+    generator.sampling_params.temperature=$TEMPERATURE \
+    generator.sampling_params.top_p=$TOP_P \
+    generator.eval_sampling_params.top_p=$EVAL_TOP_P \
+    generator.eval_sampling_params.temperature=$TEMPERATURE \
+    generator.eval_sampling_params.max_generate_length=$MAX_RESPONSE_LENGTH \
+    trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
+    trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
+    trainer.policy.model.path="$MODEL_NAME" \
+    trainer.placement.colocate_all=true \
+    trainer.strategy=megatron \
+    trainer.placement.policy_num_nodes=$NUM_NODES \
+    trainer.placement.policy_num_gpus_per_node=$NUM_GPUS_PER_NODE \
+    generator.inference_engine.num_engines=$NUM_INFERENCE_ENGINES \
+    generator.inference_engine.tensor_parallel_size=$INFERENCE_ENGINE_TENSOR_PARALLEL_SIZE \
+    trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
+    trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
+    trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
+    trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
+    trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
+    trainer.policy.megatron_config.moe_enable_routing_replay=$ROUTER_REPLAY \
+    generator.inference_engine.enable_return_routed_experts=$ROUTER_REPLAY \
+    generator.inference_engine.distributed_executor_backend=$DISTRIBUTED_EXECUTION_BACKEND \
+    trainer.epochs=20 \
+    trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \
+    trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \
+    trainer.eval_batch_size=1024 \
+    trainer.eval_before_train=true \
+    trainer.eval_interval=5 \
+    trainer.update_epochs_per_batch=1 \
+    trainer.train_batch_size=$TRAIN_BATCH_SIZE \
+    trainer.policy_mini_batch_size=$MINI_BATCH_SIZE \
+    trainer.micro_forward_batch_size_per_gpu=4 \
+    trainer.micro_train_batch_size_per_gpu=2 \
+    trainer.ckpt_interval=200 \
+    trainer.max_prompt_length=$MAX_PROMPT_LENGTH \
+    generator.sampling_params.max_generate_length=$MAX_RESPONSE_LENGTH \
+    trainer.policy.optimizer_config.lr=$LR \
+    trainer.policy.optimizer_config.num_warmup_steps=40 \
+    trainer.policy.optimizer_config.weight_decay=0.1 \
+    trainer.policy.optimizer_config.max_grad_norm=1.0 \
+    trainer.flash_attn=$FLASH_ATTN \
+    generator.inference_engine.backend=vllm \
+    generator.inference_engine.run_engines_locally=true \
+    generator.inference_engine.weight_sync_backend=nccl \
+    generator.inference_engine.async_engine=false \
+    generator.batched=true \
+    environment.env_class=aime \
+    generator.n_samples_per_prompt=$N_SAMPLES_PER_PROMPT \
+    generator.eval_n_samples_per_prompt=$EVAL_N_SAMPLES_PER_PROMPT \
+    generator.inference_engine.gpu_memory_utilization=0.7 \
+    trainer.logger="$LOGGER" \
+    trainer.project_name="router_replay" \
+    trainer.run_name="dapo_moonlight_16b_a3b_megatron_r3" \
+    trainer.export_path="$HOME/exports/dapo_moonlight_16b_a3b_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_r3" \
+    trainer.hf_save_interval=300 \
+    trainer.resume_mode=latest \
+    trainer.max_ckpts_to_keep=3 \
+    trainer.ckpt_path="$HOME/ckpts/dapo_moonlight_16b_a3b_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_r3" \
+    $@
\ No newline at end of file
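As a closing note on verifying the feature: the `test_logprobs` GPU test touched throughout this series compares rollout logprobs against logprobs recomputed by the trainer. A standalone metric in the same spirit (a sketch with hypothetical tensor names, not the test's actual code) could be:

```python
# Sketch: quantify rollout/training logprob mismatch. With router replay enabled,
# the recomputed training logprobs should track the rollout logprobs more closely
# for MoE models; tensor names here are illustrative.
import torch


def logprob_mismatch(
    rollout_logprobs: torch.Tensor,  # [batch, seq_len], recorded by the inference engine
    train_logprobs: torch.Tensor,    # [batch, seq_len], recomputed in the training forward pass
    response_mask: torch.Tensor,     # [batch, seq_len], 1.0 on response tokens
) -> float:
    # Mean absolute per-token gap over response tokens only.
    gap = (rollout_logprobs - train_logprobs).abs() * response_mask
    return (gap.sum() / response_mask.sum().clamp(min=1)).item()
```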