From a8027730aa042341803a399685a77c69f2427d16 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 4 Mar 2026 01:58:44 +0000 Subject: [PATCH 01/18] previous code Co-authored-by: Dev Patel --- .../skyrl_train/inference_engines/base.py | 5 ++ .../inference_engines/vllm/vllm_engine.py | 20 +++++- skyrl/backends/skyrl_train/training_batch.py | 1 + .../skyrl_train/utils/replay_utils.py | 62 +++++++++++++++++++ .../workers/megatron/megatron_worker.py | 11 ++++ skyrl/train/config/config.py | 1 + .../train/config/megatron_config/policy.yaml | 1 + skyrl/train/config/ppo_base_config.yaml | 1 + skyrl/train/dataset/preprocess.py | 10 ++- skyrl/train/entrypoints/main_base.py | 1 + skyrl/train/generators/base.py | 1 + skyrl/train/generators/skyrl_gym_generator.py | 5 ++ skyrl/train/trainer.py | 4 ++ 13 files changed, 120 insertions(+), 3 deletions(-) create mode 100644 skyrl/backends/skyrl_train/utils/replay_utils.py diff --git a/skyrl/backends/skyrl_train/inference_engines/base.py b/skyrl/backends/skyrl_train/inference_engines/base.py index 4b073da5a0..c3595ab7a9 100644 --- a/skyrl/backends/skyrl_train/inference_engines/base.py +++ b/skyrl/backends/skyrl_train/inference_engines/base.py @@ -29,6 +29,7 @@ class InferenceEngineOutput(TypedDict): response_ids: List[List[int]] stop_reasons: List[str] response_logprobs: Optional[List[List[float]]] + rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] class InferenceEngineInterface(ABC): @@ -63,6 +64,7 @@ async def sample( all_responses = [] all_stop_reasons = [] all_response_logprobs = [] + all_rollout_inference_indices = [] for _ in range(num_samples): input_batch: InferenceEngineInput = { @@ -79,12 +81,15 @@ async def sample( all_stop_reasons.append(output["stop_reasons"][0]) if output.get("response_logprobs") is not None: all_response_logprobs.append(output["response_logprobs"][0]) + if output.get("rollout_inference_indices") is not None: + all_rollout_inference_indices.append(output["rollout_inference_indices"][0]) return { "response_ids": all_response_ids, "responses": all_responses, "stop_reasons": all_stop_reasons, "response_logprobs": all_response_logprobs if all_response_logprobs else None, + "rollout_inference_indices": all_rollout_inference_indices if all_rollout_inference_indices else None, } @abstractmethod diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py index 1123088cb6..a0fa9ff4c6 100644 --- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -135,6 +135,7 @@ def _postprocess_outputs(self, outputs): stop_reasons: List[str] = [] response_ids: List[List[int]] = [] response_logprobs: Optional[List[List[float]]] = [] + rollout_inference_indices: Optional[List[List[List[List[int]]]]] = [] for output in outputs: # TODO(tgriggs): Support n>1 sampling. 
@@ -156,14 +157,25 @@ def _postprocess_outputs(self, outputs): del token_logprobs response_logprobs.append(_logprobs) + if resp.routed_experts is not None: + if hasattr(resp.routed_experts, "tolist"): + routed_experts_list = resp.routed_experts.tolist() + else: + routed_experts_list = resp.routed_experts + rollout_inference_indices.append(routed_experts_list) + if len(response_logprobs) and response_logprobs[0] is None: response_logprobs = None # hack: assume uniform sampling params + if len(rollout_inference_indices) == 0: + rollout_inference_indices = None + return InferenceEngineOutput( responses=responses, stop_reasons=stop_reasons, response_ids=response_ids, response_logprobs=response_logprobs, + rollout_inference_indices=rollout_inference_indices, ) def _get_engine(self): @@ -320,7 +332,13 @@ def _create_engine(self, *args, **kwargs): enable_ray_prometheus_stats = kwargs.pop("enable_ray_prometheus_stats", False) enable_log_requests = kwargs.pop("enable_log_requests", False) max_log_len = kwargs.pop("max_log_len", None) - + + # Log if enable_return_routed_experts is being passed + if "enable_return_routed_experts" in kwargs: + logger.info(f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs") + else: + logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs") + if version.parse(vllm.__version__) >= version.parse("0.10.0"): engine_args = vllm.AsyncEngineArgs(enable_log_requests=enable_log_requests, **kwargs) else: diff --git a/skyrl/backends/skyrl_train/training_batch.py b/skyrl/backends/skyrl_train/training_batch.py index 839508aa3c..8b00aeaa92 100644 --- a/skyrl/backends/skyrl_train/training_batch.py +++ b/skyrl/backends/skyrl_train/training_batch.py @@ -369,6 +369,7 @@ class TrainingInput(TypedDict, total=False): kl: Float[torch.Tensor, "batch_size seq_len"] rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]] rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]] + rollout_inference_indices: Optional[Integer[torch.Tensor, "batch_size seq_len layer_num topk"]] class TrainingInputBatch(TensorBatch[TrainingInput]): diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py new file mode 100644 index 0000000000..d315779c07 --- /dev/null +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -0,0 +1,62 @@ +""" +Utility functions for MoE Router Replay. +""" + +import torch +from typing import Optional, List +from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch + +def _split_replay_indices(rollout_inference_indices: torch.Tensor) -> List[torch.Tensor]: + if rollout_inference_indices is None: + return None + if rollout_inference_indices.dim() != 4: + raise ValueError(f"Expected 4D replay indices, got shape {rollout_inference_indices.shape}") + per_layer = rollout_inference_indices.permute(2, 0, 1, 3).contiguous() + return [per_layer[i] for i in range(per_layer.shape[0])] + +def setup_router_replay_forward(data: TrainingInputBatch, enable_router_replay: bool) -> bool: + """ + Set up router replay for forward pass (ref/policy inference). 
+ """ + if not enable_router_replay: + return False + + rollout_inference_indices = data.get("rollout_inference_indices") + if rollout_inference_indices is None: + return False + + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction + + RouterReplay.set_replay_data(_split_replay_indices(rollout_inference_indices)) + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + return True + + +def setup_router_replay_backward(data: TrainingInputBatch, enable_router_replay: bool) -> bool: + """ + Set up router replay for training forward/backward pass. + """ + if not enable_router_replay: + return False + + rollout_inference_indices = data.get("rollout_inference_indices") + if rollout_inference_indices is None: + return False + + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction + + RouterReplay.set_replay_data(_split_replay_indices(rollout_inference_indices)) + # Use REPLAY_FORWARD - Megatron handles REPLAY_BACKWARD automatically + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + return True + + +def clear_router_replay(): + """Clear all router replay state.""" + from megatron.core.transformer.moe.router_replay import RouterReplay + + RouterReplay.clear_global_indices() + RouterReplay.clear_global_router_replay_action() + RouterReplay.clear_global_router_replay_instances() diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index e80ad1c85a..850a429f42 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -322,6 +322,7 @@ def init_configs( self.strategy.hf_config = hf_config self.tokenizer = tokenizer + self.enable_router_replay = transformer_config_kwargs.get("moe_enable_routing_replay", False) def configure_lora(self, lora_config, lora_type: Optional[str] = "lora"): if lora_type == "lora": @@ -401,6 +402,10 @@ def forward(self, data: TrainingInputBatch): """ Override `Worker.forward` to support passing the full mini batch to the MegatronModelWrapper.forward method. 
""" + from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay + + setup_router_replay_forward(data, self.enable_router_replay) + # Run in micro batches grouped into a single mini-batch micro_bsz = self.cfg.micro_forward_batch_size_per_gpu micro_batches = data.chunk(micro_bsz) @@ -438,6 +443,7 @@ def forward(self, data: TrainingInputBatch): log_probs = log_probs.to("cpu") output = TrainingOutputBatch({"output": log_probs}) output.metadata = data.metadata + clear_router_replay() return output def save_hf_model(self, export_dir: str, tokenizer): @@ -593,6 +599,9 @@ def forward_backward( Returns: Aggregated metrics dict across all micro batches """ + from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay + + setup_router_replay_forward(data, self.enable_router_replay) self.model.train() for chunk in self.actor_module: # if use distributed optimizer, zero grad buffer will be handled by optimizer @@ -665,6 +674,8 @@ def forward_backward( # Add loss_fn_outputs back (not reduced, kept as list) if all_loss_fn_outputs: status["loss_fn_outputs"] = all_loss_fn_outputs + + clear_router_replay() return status diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index effaa7c6cc..4c4f749670 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -440,6 +440,7 @@ class InferenceEngineConfig(BaseConfig): """Sets ``VLLM_ENABLE_V1_MULTIPROCESSING=0`` for reproducibility.""" enable_prefix_caching: bool = True enable_chunked_prefill: bool = True + enable_return_routed_experts: bool = False max_num_batched_tokens: int = 8192 enforce_eager: bool = True """Disable CUDA graphs for stability. Set to ``False`` for higher performance, diff --git a/skyrl/train/config/megatron_config/policy.yaml b/skyrl/train/config/megatron_config/policy.yaml index d6ef8382f9..b641d1f706 100644 --- a/skyrl/train/config/megatron_config/policy.yaml +++ b/skyrl/train/config/megatron_config/policy.yaml @@ -57,6 +57,7 @@ transformer_config_kwargs: recompute_modules: ["core_attn"] recompute_method: uniform recompute_num_layers: 1 + moe_enable_routing_replay: ${generator.enable_return_routed_experts} # flag to manually empty torch's cuda cache between the forward/backward pass and the optimizer step # this will free reserved but unallocated memory, and can help avoid OoMs in the optimizer diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml index 8f4184b18d..9772eb5e9b 100644 --- a/skyrl/train/config/ppo_base_config.yaml +++ b/skyrl/train/config/ppo_base_config.yaml @@ -286,6 +286,7 @@ generator: vllm_v1_disable_multiproc: true enable_prefix_caching: true enable_chunked_prefill: true + enable_return_routed_experts: false max_num_batched_tokens: 8192 # Disable CUDA graphs by default for stability. Set to false for higher performance, but this may affect convergence for long-running and/or long context training jobs. 
enforce_eager: true diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index a17cea1b91..7c360e5de0 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -1,7 +1,7 @@ from typing import List, Tuple, Optional import torch from transformers import AutoTokenizer -from jaxtyping import Float +from jaxtyping import Float, Integer def _verify_inputs( @@ -32,6 +32,7 @@ def convert_prompts_responses_to_batch_tensors( rewards: List[List[float]], loss_masks: List[List[int]], logprobs: Optional[List[List[float]]] = None, + rollout_inference_indices: Optional[List[List[List[List[int]]]]] = None, ) -> Tuple[ Float[torch.Tensor, "batch seq_len"], Float[torch.Tensor, "batch seq_len"], @@ -39,6 +40,7 @@ def convert_prompts_responses_to_batch_tensors( Float[torch.Tensor, "batch response_len"], Float[torch.Tensor, "batch response_len"], Optional[Float[torch.Tensor, "batch response_len"]], + Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]], ]: """ Convert prompts and responses to batch tensors for training. @@ -129,4 +131,8 @@ def convert_prompts_responses_to_batch_tensors( ] logprobs_tensor = torch.tensor(padded_logprobs, dtype=torch.float) - return sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, logprobs_tensor + rollout_inference_indices_tensor = None + if rollout_inference_indices: + rollout_inference_indices_tensor = torch.tensor(rollout_inference_indices, dtype=torch.int32) + + return sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, logprobs_tensor, rollout_inference_indices_tensor diff --git a/skyrl/train/entrypoints/main_base.py b/skyrl/train/entrypoints/main_base.py index 10e78e9de3..e3c2b980f1 100644 --- a/skyrl/train/entrypoints/main_base.py +++ b/skyrl/train/entrypoints/main_base.py @@ -68,6 +68,7 @@ def create_ray_wrapped_inference_engines_from_config( "backend": ie_cfg.backend, "engine_init_kwargs": ie_cfg.engine_init_kwargs, "enable_ray_prometheus_stats": ie_cfg.enable_ray_prometheus_stats, + "enable_return_routed_experts": ie_cfg.enable_return_routed_experts, } # Conditionally add LoRA parameters if LoRA is enabled diff --git a/skyrl/train/generators/base.py b/skyrl/train/generators/base.py index fabe5524e8..02150a80c4 100644 --- a/skyrl/train/generators/base.py +++ b/skyrl/train/generators/base.py @@ -39,6 +39,7 @@ class GeneratorOutput(TypedDict): rollout_metrics: Optional[Dict[str, Any]] rollout_logprobs: Optional[List[List[float]]] trajectory_ids: Optional[List[TrajectoryID]] + rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [batch_size, seq_len, layer_num, topk] # Applicable only for step-wise training is_last_step: Optional[List[bool]] diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 35c7e189f3..3e2fa6d344 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -66,6 +66,7 @@ class TurnOutput: output_logprobs: Optional[List[float]] new_obs: ConversationType obs_ids: List[int] + rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] reward: Optional[float] added_eos: bool = False @@ -300,11 +301,14 @@ async def agent_loop( output_ids = engine_output["response_ids"][0] stop_reason = engine_output["stop_reasons"][0] response_logprobs = engine_output.get("response_logprobs", None) + rollout_inference_indices = engine_output.get("rollout_inference_indices", None) if response_logprobs is not 
None: response_logprobs = response_logprobs[0] if self.custom_chat_template is not None: raise ValueError("Response Logprobs bookkeeping is not supported with custom chat template") + if rollout_inference_indices is not None: + rollout_inference_indices = rollout_inference_indices[0] # Append eos when sampling_params.stop is not None. Does not affect 3.a as chat templates add eos_token. # sampling_params is not None for eval, but None for training (which uses engine.sampling_params which are from cfg) stop_strs = current_sampling_params.get("stop", None) @@ -348,6 +352,7 @@ async def agent_loop( reward=step_reward, obs_ids=obs_ids, added_eos=added_eos, + rollout_inference_indices=rollout_inference_indices, ) if is_step_wise: diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index ada3afbc5c..8c966d8749 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -604,6 +604,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis loss_masks: List[List[int]] = generator_output["loss_masks"] logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None) + rollout_inference_indices: Optional[List[List[List[List[int]]]]] = generator_output.get("rollout_inference_indices", None) ( sequences_tensor, @@ -612,6 +613,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis rewards_tensor, loss_masks_tensor, rollout_logprobs_tensor, + rollout_inference_indices_tensor, ) = convert_prompts_responses_to_batch_tensors( self.tokenizer, prompt_ids, @@ -619,6 +621,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis rewards, loss_masks, logprobs, + rollout_inference_indices, ) # sanity check for off_policy_correction @@ -639,6 +642,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis "rewards": rewards_tensor, "loss_mask": loss_masks_tensor, "rollout_logprobs": rollout_logprobs_tensor, + "rollout_inference_indices": rollout_inference_indices_tensor, "is_last_step": ( torch.tensor(generator_output["is_last_step"], dtype=torch.bool) if generator_output.get("is_last_step", None) is not None From 21b44dedbc639c8e3437b4529defdd636ddd152b Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 4 Mar 2026 02:05:28 +0000 Subject: [PATCH 02/18] lint --- .../skyrl_train/inference_engines/base.py | 2 +- .../inference_engines/vllm/vllm_engine.py | 10 +++++---- .../skyrl_train/utils/replay_utils.py | 22 ++++++++++--------- .../workers/megatron/megatron_worker.py | 6 ++--- skyrl/train/dataset/preprocess.py | 10 ++++++++- skyrl/train/generators/base.py | 2 +- skyrl/train/generators/skyrl_gym_generator.py | 2 +- skyrl/train/trainer.py | 4 +++- 8 files changed, 36 insertions(+), 22 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_engines/base.py b/skyrl/backends/skyrl_train/inference_engines/base.py index c3595ab7a9..ddecc965fb 100644 --- a/skyrl/backends/skyrl_train/inference_engines/base.py +++ b/skyrl/backends/skyrl_train/inference_engines/base.py @@ -29,7 +29,7 @@ class InferenceEngineOutput(TypedDict): response_ids: List[List[int]] stop_reasons: List[str] response_logprobs: Optional[List[List[float]]] - rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] + rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] class InferenceEngineInterface(ABC): diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py 
b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py index a0fa9ff4c6..66d4213f1c 100644 --- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -332,13 +332,15 @@ def _create_engine(self, *args, **kwargs): enable_ray_prometheus_stats = kwargs.pop("enable_ray_prometheus_stats", False) enable_log_requests = kwargs.pop("enable_log_requests", False) max_log_len = kwargs.pop("max_log_len", None) - - # Log if enable_return_routed_experts is being passed + + # Log if enable_return_routed_experts is being passed if "enable_return_routed_experts" in kwargs: - logger.info(f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs") + logger.info( + f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs" + ) else: logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs") - + if version.parse(vllm.__version__) >= version.parse("0.10.0"): engine_args = vllm.AsyncEngineArgs(enable_log_requests=enable_log_requests, **kwargs) else: diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index d315779c07..9e3bac18a5 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -3,9 +3,10 @@ """ import torch -from typing import Optional, List +from typing import List from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch + def _split_replay_indices(rollout_inference_indices: torch.Tensor) -> List[torch.Tensor]: if rollout_inference_indices is None: return None @@ -14,22 +15,23 @@ def _split_replay_indices(rollout_inference_indices: torch.Tensor) -> List[torch per_layer = rollout_inference_indices.permute(2, 0, 1, 3).contiguous() return [per_layer[i] for i in range(per_layer.shape[0])] + def setup_router_replay_forward(data: TrainingInputBatch, enable_router_replay: bool) -> bool: """ Set up router replay for forward pass (ref/policy inference). 
""" if not enable_router_replay: return False - + rollout_inference_indices = data.get("rollout_inference_indices") if rollout_inference_indices is None: return False - + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction - + RouterReplay.set_replay_data(_split_replay_indices(rollout_inference_indices)) RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) - + return True @@ -39,24 +41,24 @@ def setup_router_replay_backward(data: TrainingInputBatch, enable_router_replay: """ if not enable_router_replay: return False - + rollout_inference_indices = data.get("rollout_inference_indices") if rollout_inference_indices is None: return False - + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction - + RouterReplay.set_replay_data(_split_replay_indices(rollout_inference_indices)) # Use REPLAY_FORWARD - Megatron handles REPLAY_BACKWARD automatically RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) - + return True def clear_router_replay(): """Clear all router replay state.""" from megatron.core.transformer.moe.router_replay import RouterReplay - + RouterReplay.clear_global_indices() RouterReplay.clear_global_router_replay_action() RouterReplay.clear_global_router_replay_instances() diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 850a429f42..3480f0ef27 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -403,7 +403,7 @@ def forward(self, data: TrainingInputBatch): Override `Worker.forward` to support passing the full mini batch to the MegatronModelWrapper.forward method. 
""" from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay - + setup_router_replay_forward(data, self.enable_router_replay) # Run in micro batches grouped into a single mini-batch @@ -600,7 +600,7 @@ def forward_backward( Aggregated metrics dict across all micro batches """ from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay - + setup_router_replay_forward(data, self.enable_router_replay) self.model.train() for chunk in self.actor_module: @@ -674,7 +674,7 @@ def forward_backward( # Add loss_fn_outputs back (not reduced, kept as list) if all_loss_fn_outputs: status["loss_fn_outputs"] = all_loss_fn_outputs - + clear_router_replay() return status diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index 7c360e5de0..01b332a625 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -135,4 +135,12 @@ def convert_prompts_responses_to_batch_tensors( if rollout_inference_indices: rollout_inference_indices_tensor = torch.tensor(rollout_inference_indices, dtype=torch.int32) - return sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, logprobs_tensor, rollout_inference_indices_tensor + return ( + sequences, + attention_mask, + action_mask, + ret_rewards, + ret_loss_masks, + logprobs_tensor, + rollout_inference_indices_tensor, + ) diff --git a/skyrl/train/generators/base.py b/skyrl/train/generators/base.py index 02150a80c4..530fbbbe84 100644 --- a/skyrl/train/generators/base.py +++ b/skyrl/train/generators/base.py @@ -39,7 +39,7 @@ class GeneratorOutput(TypedDict): rollout_metrics: Optional[Dict[str, Any]] rollout_logprobs: Optional[List[List[float]]] trajectory_ids: Optional[List[TrajectoryID]] - rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [batch_size, seq_len, layer_num, topk] + rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [batch_size, seq_len, layer_num, topk] # Applicable only for step-wise training is_last_step: Optional[List[bool]] diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 3e2fa6d344..8f650418b3 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -66,7 +66,7 @@ class TurnOutput: output_logprobs: Optional[List[float]] new_obs: ConversationType obs_ids: List[int] - rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] + rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] reward: Optional[float] added_eos: bool = False diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 8c966d8749..f842868816 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -604,7 +604,9 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis loss_masks: List[List[int]] = generator_output["loss_masks"] logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None) - rollout_inference_indices: Optional[List[List[List[List[int]]]]] = generator_output.get("rollout_inference_indices", None) + rollout_inference_indices: Optional[List[List[List[List[int]]]]] = generator_output.get( + "rollout_inference_indices", None + ) ( sequences_tensor, From 23fdc45972c2a20f6af56021553c2d0d260f9417 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 4 Mar 2026 02:25:35 +0000 Subject: [PATCH 03/18] make opus take a pass at test + plumbing fully thru 
 generator
---
 skyrl/train/generators/skyrl_gym_generator.py |  77 ++++++++-
 .../gpu/gpu_ci/test_router_replay.py          | 163 ++++++++++++++++++
 2 files changed, 239 insertions(+), 1 deletion(-)
 create mode 100644 tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py

diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py
index 8f650418b3..67bc85a720 100644
--- a/skyrl/train/generators/skyrl_gym_generator.py
+++ b/skyrl/train/generators/skyrl_gym_generator.py
@@ -40,6 +40,7 @@ class TrajectoryOutput:
     prompt_ids: List[int]
     rollout_logprobs: Optional[List[float]]
     env_metrics: Dict[str, Any]
+    rollout_inference_indices: Optional[List[List[List[int]]]] = None
 
 
 @dataclass
@@ -57,6 +58,7 @@ class AgentLoopState:
     rollout_logprobs: Optional[List[float]]
     response_end_idx: Optional[int]
     done: bool
+    rollout_inference_indices: Optional[List[List[List[int]]]] = None
 
 
 @dataclass
@@ -66,10 +68,31 @@ class TurnOutput:
     output_logprobs: Optional[List[float]]
     new_obs: ConversationType
     obs_ids: List[int]
-    rollout_inference_indices: Optional[List[List[List[List[int]]]]]  # [seq_len, layer_num, topk]
+    rollout_inference_indices: Optional[List[List[List[int]]]]  # [seq_len, layer_num, topk]
     reward: Optional[float]
     added_eos: bool = False
 
+    def get_turn_rollout_inference_indices(self) -> Optional[List[List[List[int]]]]:
+        """
+        Get rollout inference indices for this turn's tokens (output + observation).
+
+        Returns indices for generated output tokens, with padding entries (all -1)
+        for any manually-added EOS token and observation tokens.
+        Returns None if rollout_inference_indices is None.
+        """
+        if self.rollout_inference_indices is None:
+            return None
+        if not self.rollout_inference_indices:
+            return self.rollout_inference_indices
+        layer_num = len(self.rollout_inference_indices[0])
+        topk = len(self.rollout_inference_indices[0][0]) if layer_num > 0 else 0
+        pad_entry = [[-1] * topk for _ in range(layer_num)]
+        indices = list(self.rollout_inference_indices)
+        if self.added_eos:
+            indices.append(pad_entry)
+        indices.extend(pad_entry for _ in range(len(self.obs_ids)))
+        return indices
+
     def get_turn_loss_mask(self) -> List[int]:
         """
         Get loss mask for this turn's tokens.
@@ -355,6 +378,9 @@ async def agent_loop(
                 rollout_inference_indices=rollout_inference_indices,
             )
 
+            if turn_output.rollout_inference_indices is not None and agent_loop_state.rollout_inference_indices is None:
+                agent_loop_state.rollout_inference_indices = []
+
             if is_step_wise:
                 # current response + observation ids
                 turn_response_ids = turn_output.output_ids + turn_output.obs_ids
@@ -372,6 +398,7 @@ async def agent_loop(
                     rollout_logprobs=turn_response_logprobs,
                     stop_reason=stop_reason,
                     env_metrics=env.get_metrics() if agent_loop_state.done else {},
+                    rollout_inference_indices=turn_output.get_turn_rollout_inference_indices(),
                 )
                 agent_loop_output.step_outputs.append(per_step_output)
 
@@ -400,6 +427,7 @@ async def agent_loop(
         prompt_ids = agent_loop_state.input_ids[:initial_prompt_length]
 
         rollout_logprobs = None
+        rollout_inference_indices_out = None
         response_ids = None
 
         # Prepare the final loss_mask, response_ids, and rollout_logprobs.
@@ -430,6 +458,10 @@ async def agent_loop( rollout_logprobs = agent_loop_state.rollout_logprobs[ : agent_loop_state.response_end_idx - initial_prompt_length + 1 ] + if agent_loop_state.rollout_inference_indices is not None: + rollout_inference_indices_out = agent_loop_state.rollout_inference_indices[ + : agent_loop_state.response_end_idx - initial_prompt_length + 1 + ] # fix index for per_step_rewards per_step_rewards = [(reward, idx - initial_prompt_length) for reward, idx in per_step_rewards] assert len(loss_mask) == len( @@ -446,6 +478,10 @@ async def agent_loop( loss_mask.append(1) if rollout_logprobs is not None: rollout_logprobs.append(0.0) + if rollout_inference_indices_out is not None and rollout_inference_indices_out: + layer_num = len(rollout_inference_indices_out[0]) + topk = len(rollout_inference_indices_out[0][0]) if layer_num > 0 else 0 + rollout_inference_indices_out.append([[-1] * topk for _ in range(layer_num)]) appended_eos_token = True if self.generator_cfg.step_wise_trajectories: @@ -465,6 +501,7 @@ async def agent_loop( prompt_ids=prompt_ids, rollout_logprobs=rollout_logprobs, env_metrics=env_metrics, + rollout_inference_indices=rollout_inference_indices_out, ) return agent_loop_output @@ -616,12 +653,14 @@ async def generate_batched( responses = engine_output["response_ids"] stop_reasons = engine_output["stop_reasons"] logprobs = engine_output.get("response_logprobs", None) + raw_rollout_inference_indices = engine_output.get("rollout_inference_indices", None) truncated_responses = [] rewards = [] loss_masks = [] env_metrics = [] truncated_logprobs: Optional[List[List[float]]] = [] if logprobs is not None else None + truncated_indices: Optional[List] = [] if raw_rollout_inference_indices is not None else None for i, (output, response, env, env_class) in enumerate(zip(outputs, responses, envs, env_classes)): # step on environment and compute reward @@ -636,6 +675,9 @@ async def generate_batched( if logprobs is not None: sample_logprobs = logprobs[i][: len(response)] truncated_logprobs.append(sample_logprobs) + if raw_rollout_inference_indices is not None: + sample_indices = raw_rollout_inference_indices[i][: len(response)] + truncated_indices.append(sample_indices) # Get environment-specific metrics env_metrics.append(env.get_metrics()) @@ -655,6 +697,7 @@ async def generate_batched( "stop_reasons": stop_reasons, "rollout_metrics": rollout_metrics, "rollout_logprobs": truncated_logprobs, + "rollout_inference_indices": truncated_indices, } return generator_output @@ -755,6 +798,18 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False else: rollout_logprobs = None + if self.generator_cfg.step_wise_trajectories: + all_indices = sum( + [ + [step_output.rollout_inference_indices for step_output in output.step_outputs] + for output in all_outputs + ], + [], + ) + else: + all_indices = [output.rollout_inference_indices for output in all_outputs] + rollout_inference_indices = all_indices if any(idx is not None for idx in all_indices) else None + rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes) if self.generator_cfg.zero_reward_on_non_stop: @@ -773,6 +828,7 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False "rollout_metrics": rollout_metrics, "rollout_logprobs": rollout_logprobs, "trajectory_ids": out_trajectory_ids, + "rollout_inference_indices": rollout_inference_indices, "is_last_step": is_last_step, } @@ -840,6 +896,8 @@ def _update_agent_state_by_retokenizing_chat_history( 
agent_loop_state.response_end_idx = None # `logprobs` are not computed because retokenizing breaks token-in-token-out agent_loop_state.rollout_logprobs = None + # indices are not meaningful when retokenizing + agent_loop_state.rollout_inference_indices = None return agent_loop_state def _update_agent_loop_state_with_multiturn_chat_template( @@ -891,6 +949,8 @@ def _update_agent_loop_state_with_multiturn_chat_template( loss_mask_for_turn = turn_output.get_turn_loss_mask() rollout_logprobs_for_turn = turn_output.get_turn_rollout_logprobs() + rollout_inference_indices_for_turn = turn_output.get_turn_rollout_inference_indices() + if self.generator_cfg.step_wise_trajectories: # cumulative input_ids is not tracked for step wise training agent_loop_state.response_end_idx = len(turn_output.output_ids) - 1 @@ -905,6 +965,11 @@ def _update_agent_loop_state_with_multiturn_chat_template( agent_loop_state.loss_mask += loss_mask_for_turn if agent_loop_state.rollout_logprobs is not None and rollout_logprobs_for_turn is not None: agent_loop_state.rollout_logprobs += rollout_logprobs_for_turn + if ( + agent_loop_state.rollout_inference_indices is not None + and rollout_inference_indices_for_turn is not None + ): + agent_loop_state.rollout_inference_indices += rollout_inference_indices_for_turn return agent_loop_state @@ -969,11 +1034,21 @@ def _update_agent_loop_state_with_singleturn_chat_template( obs_ids_to_add ) + rollout_inference_indices_for_turn = None + if turn_output.rollout_inference_indices is not None and turn_output.rollout_inference_indices: + layer_num = len(turn_output.rollout_inference_indices[0]) + topk = len(turn_output.rollout_inference_indices[0][0]) if layer_num > 0 else 0 + pad_entry = [[-1] * topk for _ in range(layer_num)] + rollout_inference_indices_for_turn = list(turn_output.rollout_inference_indices[: len(new_resp_tokens)]) + rollout_inference_indices_for_turn.extend(pad_entry for _ in range(len(obs_ids_to_add))) + # Directly append turn output agent_loop_state.response_end_idx = len(agent_loop_state.input_ids) + len(new_resp_tokens) - 1 agent_loop_state.input_ids += turn_ids agent_loop_state.loss_mask += loss_mask_for_turn if agent_loop_state.rollout_logprobs is not None and rollout_logprobs_for_turn is not None: agent_loop_state.rollout_logprobs += rollout_logprobs_for_turn + if agent_loop_state.rollout_inference_indices is not None and rollout_inference_indices_for_turn is not None: + agent_loop_state.rollout_inference_indices += rollout_inference_indices_for_turn return agent_loop_state diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py new file mode 100644 index 0000000000..eedbc61282 --- /dev/null +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -0,0 +1,163 @@ +""" +Run with: +uv run --isolated --extra dev --extra megatron -- pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +""" + +import ray +import pytest +import asyncio +from transformers import AutoTokenizer +from tests.backends.skyrl_train.gpu.utils import ( + InferenceEngineState, + get_test_generator_input, + Timer, +) +from skyrl.train.utils.utils import validate_cfg +from skyrl.train.config import ( + SkyRLTrainConfig, + SamplingParams, +) +from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator +from skyrl.train.generators.base import GeneratorInput +from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend + +MOE_MODEL_NAME = 
"Qwen/Qwen3-30B-A3B" + + +def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: + cfg = SkyRLTrainConfig() + cfg.trainer.policy.model.path = model_name + cfg.trainer.micro_forward_batch_size_per_gpu = 2 + cfg.trainer.micro_train_batch_size_per_gpu = 2 + cfg.trainer.use_sample_packing = False + cfg.trainer.logger = "console" + + validate_cfg(cfg) + + return cfg + + +@pytest.mark.megatron +def test_megatron_router_replay(ray_init_fixture): + """ + Test that SkyRLGymGenerator returns rollout_inference_indices + for MoE models with enable_return_routed_experts=True. + """ + try: + cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) + cfg.trainer.strategy = "megatron" + cfg.generator.inference_engine.enable_return_routed_experts = True + cfg.generator.inference_engine.tensor_parallel_size = 2 + cfg.generator.sampling_params = SamplingParams( + max_generate_length=64, + logprobs=1, + temperature=1.0, + ) + cfg.generator.batched = False + cfg.generator.max_turns = 1 + cfg.generator.use_conversation_multi_turn = True + cfg.generator.apply_overlong_filtering = False + cfg.generator.zero_reward_on_non_stop = False + + num_prompts = 2 + + tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) + + with InferenceEngineState.create( + cfg=cfg, + model=MOE_MODEL_NAME, + use_local=True, + backend="vllm", + sleep_level=1, + gpu_memory_utilization=0.8, + ) as engines: + client = engines.client + + asyncio.run(client.wake_up()) + + generator = SkyRLGymGenerator( + generator_cfg=cfg.generator, + skyrl_gym_cfg=cfg.environment.skyrl_gym, + inference_engine_client=client, + tokenizer=tokenizer, + ) + + input_batch: GeneratorInput = get_test_generator_input( + model=MOE_MODEL_NAME, + num_prompts=num_prompts, + n_samples_per_prompt=1, + max_prompt_length=512, + env_class="gsm8k", + ) + input_batch["sampling_params"] = get_sampling_params_for_backend( + "vllm", + SamplingParams( + temperature=1.0, + top_p=1.0, + top_k=-1, + max_generate_length=64, + min_p=0.0, + logprobs=1, + ), + ) + + with Timer("generate_with_router_replay"): + generator_output = asyncio.run(generator.generate(input_batch)) + + # --- Basic output checks --- + assert ( + "rollout_inference_indices" in generator_output + ), "rollout_inference_indices missing from GeneratorOutput" + indices = generator_output["rollout_inference_indices"] + assert ( + indices is not None + ), "rollout_inference_indices should not be None when enable_return_routed_experts=True" + + responses = generator_output["response_ids"] + assert len(indices) == len( + responses + ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" + + # --- Shape & value validation per sample --- + for i, (sample_indices, sample_response) in enumerate(zip(indices, responses)): + response_len = len(sample_response) + assert ( + len(sample_indices) == response_len + ), f"Sample {i}: indices length {len(sample_indices)} != response length {response_len}" + + if response_len == 0: + continue + + # Each token position should have [layer_num, topk] structure + layer_num = len(sample_indices[0]) + assert layer_num > 0, f"Sample {i}: expected > 0 MoE layers, got {layer_num}" + + topk = len(sample_indices[0][0]) + assert topk > 0, f"Sample {i}: expected topk > 0, got {topk}" + + for t, token_indices in enumerate(sample_indices): + assert ( + len(token_indices) == layer_num + ), f"Sample {i}, token {t}: expected {layer_num} layers, got {len(token_indices)}" + for l_idx, layer_indices in enumerate(token_indices): + assert ( + 
len(layer_indices) == topk + ), f"Sample {i}, token {t}, layer {l_idx}: expected topk={topk}, got {len(layer_indices)}" + for k, expert_id in enumerate(layer_indices): + assert isinstance(expert_id, int), ( + f"Sample {i}, token {t}, layer {l_idx}, k {k}: " + f"expected int expert id, got {type(expert_id)}" + ) + assert expert_id >= 0, ( + f"Sample {i}, token {t}, layer {l_idx}, k {k}: " + f"expected non-negative expert id, got {expert_id}" + ) + + print("Router replay test passed:") + print(f" Batch size: {len(indices)}") + print(f" Response lengths: {[len(r) for r in responses]}") + if indices and indices[0]: + print(f" Layers: {len(indices[0][0])}, TopK: {len(indices[0][0][0])}") + + finally: + ray.shutdown() From 8daac59a53d0a244cf322dc64c45b7798ebbd2a5 Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Wed, 4 Mar 2026 10:31:59 +0000 Subject: [PATCH 04/18] updated test utils and file to support rollout replay indices --- .../inference_engine_client.py | 16 +++ .../ray_wrapped_inference_engine.py | 2 + .../workers/megatron/megatron_worker.py | 24 ++++ .../gpu/gpu_ci/test_router_replay.py | 118 +++++++++++++++++- tests/backends/skyrl_train/gpu/utils.py | 1 + 5 files changed, 157 insertions(+), 4 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py index 02bad0d7be..19a6724395 100644 --- a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py +++ b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py @@ -153,8 +153,10 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu stop_reasons: list[str] = [""] * n response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)] response_ids: List[List[int]] = [[] for _ in range(n)] + rollout_inference_indices: List[Optional[List[List[List[int]]]]] = [None for _ in range(n)] # a bit hacky for now add_resp_logprobs = False + add_rollout_inference_indices = False for indices, result in zip(indices_list, results): for local_idx, original_idx in enumerate(indices): @@ -164,12 +166,16 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu if result.get("response_logprobs", None): add_resp_logprobs = True response_logprobs[original_idx] = result["response_logprobs"][local_idx] + if result.get("rollout_inference_indices", None): + add_rollout_inference_indices = True + rollout_inference_indices[original_idx] = result["rollout_inference_indices"][local_idx] return InferenceEngineOutput( responses=responses, stop_reasons=stop_reasons, response_ids=response_ids, response_logprobs=response_logprobs if add_resp_logprobs else None, + rollout_inference_indices=rollout_inference_indices if add_rollout_inference_indices else None, ) def _select_engine_idx(self, session_id: Optional[Union[str, int]] = None) -> int: @@ -265,6 +271,7 @@ async def _generate_single_with_retry( # 2. Initialize fields we want to accumulate or update in each loop iteration accum_response_ids: List[int] = [] accum_response_logprobs: List[float] = [] + accum_rollout_inference_indices: List[List[List[int]]] = [] stop_reason: str = "abort" # We only use it if generation is completed in one turn to maintain original behavior with no retry. 
@@ -300,6 +307,10 @@ async def _generate_single_with_retry( new_response_logprobs_list: Optional[List[List[float]]] = partial_response.get("response_logprobs", None) if new_response_logprobs_list is not None and len(new_response_logprobs_list) > 0: new_response_logprobs = new_response_logprobs_list[0] + new_rollout_inference_indices: Optional[List[List[List[int]]]] = None + new_rollout_inference_indices_list = partial_response.get("rollout_inference_indices", None) + if new_rollout_inference_indices_list is not None and len(new_rollout_inference_indices_list) > 0: + new_rollout_inference_indices = new_rollout_inference_indices_list[0] # 3.4 Aborted without generating tokens, so partial_response is useless. if stop_reason == "abort" and len(new_response_ids) == 0: @@ -309,6 +320,8 @@ async def _generate_single_with_retry( accum_response_ids.extend(new_response_ids) if new_response_logprobs is not None: accum_response_logprobs.extend(new_response_logprobs) + if new_rollout_inference_indices is not None: + accum_rollout_inference_indices.extend(new_rollout_inference_indices) num_turns += 1 # 4. Build the final response and return. @@ -321,6 +334,9 @@ async def _generate_single_with_retry( stop_reasons=[stop_reason], response_ids=[accum_response_ids], response_logprobs=[accum_response_logprobs] if len(accum_response_logprobs) > 0 else None, + rollout_inference_indices=[accum_rollout_inference_indices] + if len(accum_rollout_inference_indices) > 0 + else None, ) async def _chat_completion_with_retry( diff --git a/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py b/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py index 6e238f492f..0ec3c29ac3 100644 --- a/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py @@ -107,6 +107,7 @@ def create_ray_wrapped_inference_engines( rope_scaling: Dict[str, Any] = {}, rope_theta: float | None = None, enable_ray_prometheus_stats: bool = False, + enable_return_routed_experts: bool = False, served_model_name: str | None = None, ) -> List[InferenceEngineInterface]: """ @@ -249,6 +250,7 @@ def create_ray_wrapped_inference_engines( max_num_seqs=max_num_seqs, max_logprobs=1, # only need chosen-token logprobs enable_ray_prometheus_stats=enable_ray_prometheus_stats, + enable_return_routed_experts=enable_return_routed_experts, **dp_kwargs, **engine_init_kwargs, **lora_kwargs, diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 3480f0ef27..9cb8e1a2e4 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -253,6 +253,28 @@ def extract_weights(self, dtype: torch.dtype): class MegatronWorker: + def _read_router_replay_state(self): + from megatron.core.transformer.moe.router_replay import RouterReplay + + # See https://docs.nvidia.com/megatron-core/developer-guide/0.15.0/api-guide/router_replay.html docs for more info + global_indices = getattr(RouterReplay, "global_indices", None) or [] + instances = getattr(RouterReplay, "global_router_replay_instances", None) or [] + + # Track size to check if shapes are valid after replay / we consume only layers we need + replay_backward_total_entries = 0 + for instance in instances: + replay_backward_total_entries += len(getattr(instance, "replay_backward_list", None) or []) + + return { + 
"action": str(getattr(RouterReplay, "global_router_replay_action", None)), + "global_indices": [x.detach().cpu() for x in global_indices], + "num_instances": len(instances), + "replay_backward_total_entries": replay_backward_total_entries, + } + + def get_last_router_replay_state(self): + return getattr(self, "_last_router_replay_state", None) + def init_configs( self, model_path, @@ -443,6 +465,7 @@ def forward(self, data: TrainingInputBatch): log_probs = log_probs.to("cpu") output = TrainingOutputBatch({"output": log_probs}) output.metadata = data.metadata + self._last_router_replay_state = self._read_router_replay_state() clear_router_replay() return output @@ -675,6 +698,7 @@ def forward_backward( if all_loss_fn_outputs: status["loss_fn_outputs"] = all_loss_fn_outputs + self._last_router_replay_state = self._read_router_replay_state() clear_router_replay() return status diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index eedbc61282..69c25a9d07 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -6,11 +6,13 @@ import ray import pytest import asyncio +import torch from transformers import AutoTokenizer from tests.backends.skyrl_train.gpu.utils import ( InferenceEngineState, get_test_generator_input, Timer, + init_worker_with_type, ) from skyrl.train.utils.utils import validate_cfg from skyrl.train.config import ( @@ -20,6 +22,9 @@ from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator from skyrl.train.generators.base import GeneratorInput from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend +from skyrl.train.dataset.preprocess import convert_prompts_responses_to_batch_tensors +from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch +from skyrl.backends.skyrl_train.distributed.dispatch import concatenate_outputs_after_mesh_dispatch MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" @@ -49,7 +54,7 @@ def test_megatron_router_replay(ray_init_fixture): cfg.generator.inference_engine.enable_return_routed_experts = True cfg.generator.inference_engine.tensor_parallel_size = 2 cfg.generator.sampling_params = SamplingParams( - max_generate_length=64, + max_generate_length=16, logprobs=1, temperature=1.0, ) @@ -59,7 +64,7 @@ def test_megatron_router_replay(ray_init_fixture): cfg.generator.apply_overlong_filtering = False cfg.generator.zero_reward_on_non_stop = False - num_prompts = 2 + num_prompts = 1 tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) @@ -69,7 +74,8 @@ def test_megatron_router_replay(ray_init_fixture): use_local=True, backend="vllm", sleep_level=1, - gpu_memory_utilization=0.8, + gpu_memory_utilization=0.9, + max_num_seqs=1, ) as engines: client = engines.client @@ -95,7 +101,7 @@ def test_megatron_router_replay(ray_init_fixture): temperature=1.0, top_p=1.0, top_k=-1, - max_generate_length=64, + max_generate_length=16, min_p=0.0, logprobs=1, ), @@ -152,12 +158,116 @@ def test_megatron_router_replay(ray_init_fixture): f"Sample {i}, token {t}, layer {l_idx}, k {k}: " f"expected non-negative expert id, got {expert_id}" ) + from skyrl.backends.skyrl_train.utils.replay_utils import _split_replay_indices + replay_tensor = torch.tensor(indices, dtype=torch.long) + per_layer_replay = _split_replay_indices(replay_tensor) + reconstructed = torch.stack(per_layer_replay, dim=2) + assert torch.equal( + replay_tensor, 
reconstructed + ), "Replay index translation changed values between vLLM and Megatron layout" + + prompt_ids = generator_output["prompt_token_ids"] + rollout_logprobs = generator_output.get("rollout_logprobs", None) + loss_masks = generator_output["loss_masks"] + rewards = generator_output["rewards"] + if rewards and not isinstance(rewards[0], list): + rewards = [[reward] * len(response) for reward, response in zip(rewards, responses)] + + ( + sequences_tensor, + attention_masks_tensor, + response_masks_tensor, + rewards_tensor, + loss_masks_tensor, + rollout_logprobs_tensor, + rollout_inference_indices_tensor, + ) = convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=prompt_ids, + responses=responses, + rewards=rewards, + loss_masks=loss_masks, + logprobs=rollout_logprobs, + rollout_inference_indices=indices, + ) + + assert rollout_inference_indices_tensor is not None + assert rollout_inference_indices_tensor.shape[0] == len(responses) + assert rollout_inference_indices_tensor.shape[1] == response_masks_tensor.shape[1] + + training_input = TrainingInputBatch( + { + "sequences": sequences_tensor, + "attention_mask": attention_masks_tensor, + "response_mask": response_masks_tensor, + "rewards": rewards_tensor, + "loss_mask": loss_masks_tensor, + "rollout_logprobs": rollout_logprobs_tensor, + "rollout_inference_indices": rollout_inference_indices_tensor, + } + ) + training_input.metadata = {"response_length": response_masks_tensor.shape[1]} + assert training_input["rollout_inference_indices"] is not None + + cfg.trainer.policy.megatron_config.transformer_config_kwargs = {} + cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers"] = 2 + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_enable_routing_replay"] = True + cfg.trainer.placement.policy_num_gpus_per_node = 2 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.context_parallel_size = 1 + cfg.trainer.policy.megatron_config.expert_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 + cfg.trainer.micro_forward_batch_size_per_gpu = 1 + cfg.trainer.micro_train_batch_size_per_gpu = 1 + + num_actions = response_masks_tensor.shape[1] + batch_size = sequences_tensor.shape[0] + if training_input.get("rollout_logprobs") is None: + training_input["rollout_logprobs"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) + training_input["action_log_probs"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) + training_input["base_action_log_probs"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) + training_input["advantages"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) + training_input["action_mask"] = response_masks_tensor.to(dtype=torch.int64) + + actor_group = init_worker_with_type( + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=2, + cfg=cfg, + ) + + forward_refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) + all_rank_forward_outputs = ray.get(forward_refs) + forward_output = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, all_rank_forward_outputs)[ + "output" + ] + expected_per_layer = _split_replay_indices(training_input["rollout_inference_indices"].to(torch.long)) + forward_state = ray.get(actor_group.async_run_ray_method("pass_through", "get_last_router_replay_state"))[0] + + assert forward_state is 
not None
+        assert len(forward_state["global_indices"]) == len(expected_per_layer)
+        for got, expected in zip(forward_state["global_indices"], expected_per_layer):
+            assert torch.equal(got.to(torch.long), expected.to(torch.long))
+        assert forward_state["replay_backward_total_entries"] > 0
+
+        training_input.metadata["global_step"] = 0
+        fb_results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", training_input))
+        assert isinstance(fb_results[0], dict)
+        assert "policy_loss" in fb_results[0]
+        fb_state = ray.get(actor_group.async_run_ray_method("pass_through", "get_last_router_replay_state"))[0]
+        assert fb_state is not None
+        assert len(fb_state["global_indices"]) == len(expected_per_layer)
+        ray.get(actor_group.async_run_ray_method("pass_through", "optim_step"))
 
         print("Router replay test passed:")
         print(f"  Batch size: {len(indices)}")
         print(f"  Response lengths: {[len(r) for r in responses]}")
         if indices and indices[0]:
             print(f"  Layers: {len(indices[0][0])}, TopK: {len(indices[0][0][0])}")
+        print(f"  Handoff replay tensor shape: {tuple(rollout_inference_indices_tensor.shape)}")
+        print(f"  Megatron forward shape: {tuple(forward_output.shape)}")
 
     finally:
         ray.shutdown()

diff --git a/tests/backends/skyrl_train/gpu/utils.py b/tests/backends/skyrl_train/gpu/utils.py
index f6d726b902..a7f7545331 100644
--- a/tests/backends/skyrl_train/gpu/utils.py
+++ b/tests/backends/skyrl_train/gpu/utils.py
@@ -516,6 +516,7 @@ def create(
             sleep_level=sleep_level,
             enable_lora=enable_lora,
             engine_init_kwargs=ie_cfg.engine_init_kwargs,
+            enable_return_routed_experts=ie_cfg.enable_return_routed_experts,
             served_model_name=served_model_name,
         )
         client = InferenceEngineClient(

From 647426fff9d93c3da286f4da00ea0d2437db287e Mon Sep 17 00:00:00 2001
From: Dev Patel
Date: Wed, 4 Mar 2026 12:03:06 +0000
Subject: [PATCH 05/18] add helper functions for router visibility and
 Megatron testing
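
A minimal sketch of the layout handoff this commit settles on (shapes
taken from the type annotations in this series; the snippet is
illustrative and not part of the diff below): vLLM returns routed-expert
ids per sample as [batch, seq_len, layer_num, topk], while Megatron's
RouterReplay consumes one flattened [batch * seq_len, topk] tensor per
MoE layer, which is what the reshape added to _split_replay_indices
produces.

    import torch

    batch, seq_len, layer_num, topk = 2, 3, 4, 2
    indices = torch.arange(batch * seq_len * layer_num * topk).reshape(
        batch, seq_len, layer_num, topk
    )
    # layer-major view: [layer_num, batch, seq_len, topk]
    per_layer = indices.permute(2, 0, 1, 3).contiguous()
    # one [batch * seq_len, topk] tensor per layer, as RouterReplay expects
    split = [per_layer[i].reshape(-1, per_layer.shape[-1]) for i in range(layer_num)]
    assert split[0].shape == (batch * seq_len, topk)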
--- .../skyrl_train/utils/replay_utils.py | 3 +- .../workers/megatron/megatron_worker.py | 31 +-- .../gpu/gpu_ci/test_router_replay.py | 219 ++++++------------ 3 files changed, 94 insertions(+), 159 deletions(-) diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index 9e3bac18a5..2e5b166e1b 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -13,7 +13,8 @@ def _split_replay_indices(rollout_inference_indices: torch.Tensor) -> List[torch if rollout_inference_indices.dim() != 4: raise ValueError(f"Expected 4D replay indices, got shape {rollout_inference_indices.shape}") per_layer = rollout_inference_indices.permute(2, 0, 1, 3).contiguous() - return [per_layer[i] for i in range(per_layer.shape[0])] + # flatten [batch, seq, topk] to [batch * seq, topk] for each layer + return [per_layer[i].reshape(-1, per_layer.shape[-1]) for i in range(per_layer.shape[0])] def setup_router_replay_forward(data: TrainingInputBatch, enable_router_replay: bool) -> bool: diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 9cb8e1a2e4..d4cf00b0f0 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -254,26 +254,31 @@ def extract_weights(self, dtype: torch.dtype): class MegatronWorker: def _read_router_replay_state(self): + """Read the current RouterReplay state from all instances.""" from megatron.core.transformer.moe.router_replay import RouterReplay # See https://docs.nvidia.com/megatron-core/developer-guide/0.15.0/api-guide/router_replay.html docs for more info - global_indices = getattr(RouterReplay, "global_indices", None) or [] - instances = getattr(RouterReplay, "global_router_replay_instances", None) or [] + instances = RouterReplay.global_router_replay_instances or [] + action = instances[0].router_replay_action if instances else None - # Track size to check if shapes are valid after replay / we consume only layers we need - replay_backward_total_entries = 0 - for instance in instances: - replay_backward_total_entries += len(getattr(instance, "replay_backward_list", None) or []) + target_indices = [ + inst.target_topk_idx.detach().cpu() + for inst in instances if inst.target_topk_idx is not None + ] return { - "action": str(getattr(RouterReplay, "global_router_replay_action", None)), - "global_indices": [x.detach().cpu() for x in global_indices], + "action": str(action), + "target_indices": target_indices, "num_instances": len(instances), - "replay_backward_total_entries": replay_backward_total_entries, } - def get_last_router_replay_state(self): - return getattr(self, "_last_router_replay_state", None) + def debug_setup_router_replay_state(self, data: TrainingInputBatch): + from skyrl.backends.skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay + + setup_router_replay_forward(data, enable_router_replay=True) + state = self._read_router_replay_state() + clear_router_replay() + return state def init_configs( self, @@ -424,7 +429,7 @@ def forward(self, data: TrainingInputBatch): """ Override `Worker.forward` to support passing the full mini batch to the MegatronModelWrapper.forward method. 
""" - from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay + from skyrl.backends.skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay setup_router_replay_forward(data, self.enable_router_replay) @@ -622,7 +627,7 @@ def forward_backward( Returns: Aggregated metrics dict across all micro batches """ - from skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay + from skyrl.backends.skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay setup_router_replay_forward(data, self.enable_router_replay) self.model.train() diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 69c25a9d07..26ed26c7fb 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -15,18 +15,16 @@ init_worker_with_type, ) from skyrl.train.utils.utils import validate_cfg -from skyrl.train.config import ( - SkyRLTrainConfig, - SamplingParams, -) +from skyrl.train.config import SkyRLTrainConfig, SamplingParams from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator from skyrl.train.generators.base import GeneratorInput from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend from skyrl.train.dataset.preprocess import convert_prompts_responses_to_batch_tensors from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch -from skyrl.backends.skyrl_train.distributed.dispatch import concatenate_outputs_after_mesh_dispatch +from skyrl.backends.skyrl_train.utils.replay_utils import _split_replay_indices MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" +REPLAY_NUM_LAYERS = 2 def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: @@ -36,9 +34,7 @@ def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: cfg.trainer.micro_train_batch_size_per_gpu = 2 cfg.trainer.use_sample_packing = False cfg.trainer.logger = "console" - validate_cfg(cfg) - return cfg @@ -124,150 +120,83 @@ def test_megatron_router_replay(ray_init_fixture): responses ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" - # --- Shape & value validation per sample --- - for i, (sample_indices, sample_response) in enumerate(zip(indices, responses)): - response_len = len(sample_response) - assert ( - len(sample_indices) == response_len - ), f"Sample {i}: indices length {len(sample_indices)} != response length {response_len}" - - if response_len == 0: - continue - - # Each token position should have [layer_num, topk] structure - layer_num = len(sample_indices[0]) - assert layer_num > 0, f"Sample {i}: expected > 0 MoE layers, got {layer_num}" - - topk = len(sample_indices[0][0]) - assert topk > 0, f"Sample {i}: expected topk > 0, got {topk}" - - for t, token_indices in enumerate(sample_indices): - assert ( - len(token_indices) == layer_num - ), f"Sample {i}, token {t}: expected {layer_num} layers, got {len(token_indices)}" - for l_idx, layer_indices in enumerate(token_indices): - assert ( - len(layer_indices) == topk - ), f"Sample {i}, token {t}, layer {l_idx}: expected topk={topk}, got {len(layer_indices)}" - for k, expert_id in enumerate(layer_indices): - assert isinstance(expert_id, int), ( - f"Sample {i}, token {t}, layer {l_idx}, k {k}: " - f"expected int expert id, got {type(expert_id)}" - ) - assert expert_id >= 0, ( - f"Sample {i}, token {t}, layer 
{l_idx}, k {k}: " - f"expected non-negative expert id, got {expert_id}" - ) - from skyrl.backends.skyrl_train.utils.replay_utils import _split_replay_indices - replay_tensor = torch.tensor(indices, dtype=torch.long) - per_layer_replay = _split_replay_indices(replay_tensor) - reconstructed = torch.stack(per_layer_replay, dim=2) - assert torch.equal( - replay_tensor, reconstructed - ), "Replay index translation changed values between vLLM and Megatron layout" - - prompt_ids = generator_output["prompt_token_ids"] - rollout_logprobs = generator_output.get("rollout_logprobs", None) - loss_masks = generator_output["loss_masks"] - rewards = generator_output["rewards"] - if rewards and not isinstance(rewards[0], list): - rewards = [[reward] * len(response) for reward, response in zip(rewards, responses)] + rewards = generator_output["rewards"] + if rewards and not isinstance(rewards[0], list): + rewards = [[r] * len(resp) for r, resp in zip(rewards, responses)] + (sequences, attention_mask, response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=generator_output["prompt_token_ids"], + responses=responses, + rewards=rewards, + loss_masks=generator_output["loss_masks"], + logprobs=generator_output.get("rollout_logprobs"), + rollout_inference_indices=indices, + ) - ( - sequences_tensor, - attention_masks_tensor, - response_masks_tensor, - rewards_tensor, - loss_masks_tensor, - rollout_logprobs_tensor, - rollout_inference_indices_tensor, - ) = convert_prompts_responses_to_batch_tensors( - tokenizer=tokenizer, - prompts=prompt_ids, - responses=responses, - rewards=rewards, - loss_masks=loss_masks, - logprobs=rollout_logprobs, - rollout_inference_indices=indices, - ) + assert rii_tensor is not None + rii_tensor = rii_tensor[:, :, :REPLAY_NUM_LAYERS, :] + + num_actions = response_mask.shape[1] + batch_size = sequences.shape[0] + training_input = TrainingInputBatch({ + "sequences": sequences, + "attention_mask": attention_mask, + "response_mask": response_mask, + "rewards": rewards_t, + "loss_mask": loss_mask_t, + "rollout_logprobs": logprobs_t if logprobs_t is not None + else torch.zeros((batch_size, num_actions), dtype=torch.float32), + "rollout_inference_indices": rii_tensor, + "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "action_mask": response_mask.to(dtype=torch.int64), + }) + training_input.metadata = {"response_length": num_actions} + + cfg.trainer.policy.megatron_config.transformer_config_kwargs = { + "num_layers": REPLAY_NUM_LAYERS, + "moe_enable_routing_replay": True, + } + cfg.trainer.placement.policy_num_gpus_per_node = 2 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.context_parallel_size = 1 + cfg.trainer.policy.megatron_config.expert_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 + cfg.trainer.micro_forward_batch_size_per_gpu = 1 + cfg.trainer.micro_train_batch_size_per_gpu = 1 + + actor_group = init_worker_with_type( + "policy", shared_pg=None, colocate_all=False, + num_gpus_per_node=2, cfg=cfg, + ) - assert rollout_inference_indices_tensor is not None - assert rollout_inference_indices_tensor.shape[0] == len(responses) - assert 
rollout_inference_indices_tensor.shape[1] == response_masks_tensor.shape[1] + expected_per_layer = _split_replay_indices(rii_tensor.to(torch.long)) - training_input = TrainingInputBatch( - { - "sequences": sequences_tensor, - "attention_mask": attention_masks_tensor, - "response_mask": response_masks_tensor, - "rewards": rewards_tensor, - "loss_mask": loss_masks_tensor, - "rollout_logprobs": rollout_logprobs_tensor, - "rollout_inference_indices": rollout_inference_indices_tensor, - } + state = ray.get( + actor_group.async_run_ray_method( + "pass_through", "debug_setup_router_replay_state", + data=training_input, ) - training_input.metadata = {"response_length": response_masks_tensor.shape[1]} - assert training_input["rollout_inference_indices"] is not None - - cfg.trainer.policy.megatron_config.transformer_config_kwargs = {} - cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers"] = 2 - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_enable_routing_replay"] = True - cfg.trainer.placement.policy_num_gpus_per_node = 2 - cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 - cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 - cfg.trainer.policy.megatron_config.context_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_model_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 - cfg.trainer.micro_forward_batch_size_per_gpu = 1 - cfg.trainer.micro_train_batch_size_per_gpu = 1 + )[0] - num_actions = response_masks_tensor.shape[1] - batch_size = sequences_tensor.shape[0] - if training_input.get("rollout_logprobs") is None: - training_input["rollout_logprobs"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) - training_input["action_log_probs"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) - training_input["base_action_log_probs"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) - training_input["advantages"] = torch.zeros((batch_size, num_actions), dtype=torch.float32) - training_input["action_mask"] = response_masks_tensor.to(dtype=torch.int64) - - actor_group = init_worker_with_type( - "policy", - shared_pg=None, - colocate_all=False, - num_gpus_per_node=2, - cfg=cfg, + assert state is not None, "Worker returned None state" + assert "REPLAY_FORWARD" in state["action"], ( + f"RouterReplay action should be REPLAY_FORWARD, got: {state['action']}" + ) + assert state["num_instances"] == len(expected_per_layer), ( + f"Expected {len(expected_per_layer)} replay instances (one per layer), " + f"got {state['num_instances']}" + ) + for layer_idx, (got, expected) in enumerate( + zip(state["target_indices"], expected_per_layer) + ): + assert torch.equal(got.to(torch.long), expected.to(torch.long)), ( + f"Layer {layer_idx}: Megatron target indices differ from vLLM indices" ) - - forward_refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) - all_rank_forward_outputs = ray.get(forward_refs) - forward_output = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, all_rank_forward_outputs)[ - "output" - ] - expected_per_layer = _split_replay_indices(training_input["rollout_inference_indices"].to(torch.long)) - forward_state = ray.get(actor_group.async_run_ray_method("pass_through", "get_last_router_replay_state"))[0] - - assert forward_state is not None - assert len(forward_state["global_indices"]) == len(expected_per_layer) - for got, expected in zip(forward_state["global_indices"], expected_per_layer): - assert 
torch.equal(got.to(torch.long), expected.to(torch.long)) - assert forward_state["replay_backward_total_entries"] > 0 - - training_input.metadata["global_step"] = 0 - fb_results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", training_input)) - assert isinstance(fb_results[0], dict) - assert "policy_loss" in fb_results[0] - fb_state = ray.get(actor_group.async_run_ray_method("pass_through", "get_last_router_replay_state"))[0] - assert fb_state is not None - assert len(fb_state["global_indices"]) == len(expected_per_layer) - ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) - - print("Router replay test passed:") - print(f" Batch size: {len(indices)}") - print(f" Response lengths: {[len(r) for r in responses]}") - if indices and indices[0]: - print(f" Layers: {len(indices[0][0])}, TopK: {len(indices[0][0][0])}") - print(f" Handoff replay tensor shape: {tuple(rollout_inference_indices_tensor.shape)}") - print(f" Megatron forward shape: {tuple(forward_output.shape)}") + print(f"PASSED: vLLM routing indices ({rii_tensor.shape}) correctly " + f"loaded into {state['num_instances']} Megatron RouterReplay instances") finally: ray.shutdown() From d4b753fcef83c8aa37f6235a540af7e11607719a Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Wed, 4 Mar 2026 12:11:49 +0000 Subject: [PATCH 06/18] linter --- .../inference_engine_client.py | 6 +- .../workers/megatron/megatron_worker.py | 7 +- .../gpu/gpu_ci/test_router_replay.py | 86 +++++++++++-------- 3 files changed, 53 insertions(+), 46 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py index 19a6724395..264e060cd1 100644 --- a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py +++ b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py @@ -334,9 +334,9 @@ async def _generate_single_with_retry( stop_reasons=[stop_reason], response_ids=[accum_response_ids], response_logprobs=[accum_response_logprobs] if len(accum_response_logprobs) > 0 else None, - rollout_inference_indices=[accum_rollout_inference_indices] - if len(accum_rollout_inference_indices) > 0 - else None, + rollout_inference_indices=( + [accum_rollout_inference_indices] if len(accum_rollout_inference_indices) > 0 else None + ), ) async def _chat_completion_with_retry( diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index d4cf00b0f0..d01a5ad061 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -257,14 +257,11 @@ def _read_router_replay_state(self): """Read the current RouterReplay state from all instances.""" from megatron.core.transformer.moe.router_replay import RouterReplay - # See https://docs.nvidia.com/megatron-core/developer-guide/0.15.0/api-guide/router_replay.html docs for more info + # See https://docs.nvidia.com/megatron-core/developer-guide/0.15.0/api-guide/router_replay.html docs for more info instances = RouterReplay.global_router_replay_instances or [] action = instances[0].router_replay_action if instances else None - target_indices = [ - inst.target_topk_idx.detach().cpu() - for inst in instances if inst.target_topk_idx is not None - ] + target_indices = [inst.target_topk_idx.detach().cpu() for inst in instances if inst.target_topk_idx is not None] return { "action": str(action), diff --git 
a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 26ed26c7fb..13fb4f51cc 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -123,14 +123,16 @@ def test_megatron_router_replay(ray_init_fixture): rewards = generator_output["rewards"] if rewards and not isinstance(rewards[0], list): rewards = [[r] * len(resp) for r, resp in zip(rewards, responses)] - (sequences, attention_mask, response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = convert_prompts_responses_to_batch_tensors( - tokenizer=tokenizer, - prompts=generator_output["prompt_token_ids"], - responses=responses, - rewards=rewards, - loss_masks=generator_output["loss_masks"], - logprobs=generator_output.get("rollout_logprobs"), - rollout_inference_indices=indices, + (sequences, attention_mask, response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = ( + convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=generator_output["prompt_token_ids"], + responses=responses, + rewards=rewards, + loss_masks=generator_output["loss_masks"], + logprobs=generator_output.get("rollout_logprobs"), + rollout_inference_indices=indices, + ) ) assert rii_tensor is not None @@ -138,20 +140,25 @@ def test_megatron_router_replay(ray_init_fixture): num_actions = response_mask.shape[1] batch_size = sequences.shape[0] - training_input = TrainingInputBatch({ - "sequences": sequences, - "attention_mask": attention_mask, - "response_mask": response_mask, - "rewards": rewards_t, - "loss_mask": loss_mask_t, - "rollout_logprobs": logprobs_t if logprobs_t is not None - else torch.zeros((batch_size, num_actions), dtype=torch.float32), - "rollout_inference_indices": rii_tensor, - "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), - "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), - "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), - "action_mask": response_mask.to(dtype=torch.int64), - }) + training_input = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_mask, + "response_mask": response_mask, + "rewards": rewards_t, + "loss_mask": loss_mask_t, + "rollout_logprobs": ( + logprobs_t + if logprobs_t is not None + else torch.zeros((batch_size, num_actions), dtype=torch.float32) + ), + "rollout_inference_indices": rii_tensor, + "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "action_mask": response_mask.to(dtype=torch.int64), + } + ) training_input.metadata = {"response_length": num_actions} cfg.trainer.policy.megatron_config.transformer_config_kwargs = { @@ -168,35 +175,38 @@ def test_megatron_router_replay(ray_init_fixture): cfg.trainer.micro_train_batch_size_per_gpu = 1 actor_group = init_worker_with_type( - "policy", shared_pg=None, colocate_all=False, - num_gpus_per_node=2, cfg=cfg, + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=2, + cfg=cfg, ) expected_per_layer = _split_replay_indices(rii_tensor.to(torch.long)) state = ray.get( actor_group.async_run_ray_method( - "pass_through", "debug_setup_router_replay_state", + "pass_through", + "debug_setup_router_replay_state", data=training_input, ) )[0] assert state is not None, 
"Worker returned None state" - assert "REPLAY_FORWARD" in state["action"], ( - f"RouterReplay action should be REPLAY_FORWARD, got: {state['action']}" - ) + assert ( + "REPLAY_FORWARD" in state["action"] + ), f"RouterReplay action should be REPLAY_FORWARD, got: {state['action']}" assert state["num_instances"] == len(expected_per_layer), ( - f"Expected {len(expected_per_layer)} replay instances (one per layer), " - f"got {state['num_instances']}" + f"Expected {len(expected_per_layer)} replay instances (one per layer), " f"got {state['num_instances']}" + ) + for layer_idx, (got, expected) in enumerate(zip(state["target_indices"], expected_per_layer)): + assert torch.equal( + got.to(torch.long), expected.to(torch.long) + ), f"Layer {layer_idx}: Megatron target indices differ from vLLM indices" + print( + f"PASSED: vLLM routing indices ({rii_tensor.shape}) correctly " + f"loaded into {state['num_instances']} Megatron RouterReplay instances" ) - for layer_idx, (got, expected) in enumerate( - zip(state["target_indices"], expected_per_layer) - ): - assert torch.equal(got.to(torch.long), expected.to(torch.long)), ( - f"Layer {layer_idx}: Megatron target indices differ from vLLM indices" - ) - print(f"PASSED: vLLM routing indices ({rii_tensor.shape}) correctly " - f"loaded into {state['num_instances']} Megatron RouterReplay instances") finally: ray.shutdown() From f1b9c5333cc0f767b0b48c4fe4ba23b7ee976bfb Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 4 Mar 2026 22:21:52 +0000 Subject: [PATCH 07/18] worked w opus to get forward pass logprob diff lower with replay + running with tp + ep for megatron --- .../inference_engines/vllm/vllm_engine.py | 11 +- .../skyrl_train/utils/replay_utils.py | 98 +++++++++++ .../megatron/megatron_model_wrapper.py | 6 + .../workers/megatron/megatron_worker.py | 22 +-- skyrl/train/dataset/preprocess.py | 13 +- skyrl/train/generators/skyrl_gym_generator.py | 5 +- .../gpu/gpu_ci/test_router_replay.py | 160 +++++++++++++++++- 7 files changed, 291 insertions(+), 24 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py index 66d4213f1c..048fba7ec8 100644 --- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -157,18 +157,19 @@ def _postprocess_outputs(self, outputs): del token_logprobs response_logprobs.append(_logprobs) + _routed_experts = None if resp.routed_experts is not None: if hasattr(resp.routed_experts, "tolist"): - routed_experts_list = resp.routed_experts.tolist() + _routed_experts = resp.routed_experts.tolist() else: - routed_experts_list = resp.routed_experts - rollout_inference_indices.append(routed_experts_list) + _routed_experts = resp.routed_experts + rollout_inference_indices.append(_routed_experts) if len(response_logprobs) and response_logprobs[0] is None: response_logprobs = None # hack: assume uniform sampling params - if len(rollout_inference_indices) == 0: - rollout_inference_indices = None + if len(rollout_inference_indices) == 0 and _routed_experts is None: + rollout_inference_indices = None # hack: assume uniform sampling params return InferenceEngineOutput( responses=responses, diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index 2e5b166e1b..0b2ce73e50 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -7,6 +7,39 @@ from 
skyrl.backends.skyrl_train.training_batch import TrainingInputBatch +def _patch_alltoall_dispatcher_for_replay(): + """Monkey-patch MoEAlltoAllTokenDispatcher.preprocess to handle router replay. + + When router replay is enabled, duplicate indices in top_indices can cause + routing_map.sum() < num_tokens * topk, leading to a split size mismatch + in the alltoall collective. We fix this by deriving num_out_tokens from + the routing map instead of the static num_tokens * topk formula. + """ + try: + from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher + except ImportError: + return + + if getattr(MoEAlltoAllTokenDispatcher, "_preprocess_patched", False): + return + + original_preprocess = MoEAlltoAllTokenDispatcher.preprocess + + def patched_preprocess(self, routing_map): + result = original_preprocess(self, routing_map) + if ( + getattr(self.config, "moe_enable_routing_replay", False) + and not self.drop_and_pad + and self.config.moe_expert_capacity_factor is None + and not self.config.moe_router_padding_for_quantization + ): + self.num_out_tokens = int(routing_map.sum().item()) + return result + + MoEAlltoAllTokenDispatcher.preprocess = patched_preprocess + MoEAlltoAllTokenDispatcher._preprocess_patched = True + + def _split_replay_indices(rollout_inference_indices: torch.Tensor) -> List[torch.Tensor]: if rollout_inference_indices is None: return None @@ -17,6 +50,71 @@ def _split_replay_indices(rollout_inference_indices: torch.Tensor) -> List[torch return [per_layer[i].reshape(-1, per_layer.shape[-1]) for i in range(per_layer.shape[0])] +def _remove_left_padding_from_indices( + rollout_inference_indices: torch.Tensor, + attention_mask: torch.Tensor, +) -> torch.Tensor: + """Apply the same left-padding removal as remove_left_padding to routing indices. + + Args: + rollout_inference_indices: [batch, padded_seq_len, layers, topk] + attention_mask: [batch, padded_seq_len] (int or bool) + + Returns: + [batch, effective_seq_len, layers, topk] with real tokens packed left. + """ + import megatron.core.parallel_state as mpu + + seq_lens = attention_mask.sum(dim=1) + effective_seq_len = seq_lens.max().item() + sp_world_size = mpu.get_tensor_model_parallel_world_size() + if sp_world_size > 1: + pad_size = (sp_world_size - effective_seq_len % sp_world_size) % sp_world_size + effective_seq_len += pad_size + + batch_size = rollout_inference_indices.shape[0] + new_rii = torch.zeros( + batch_size, + effective_seq_len, + rollout_inference_indices.shape[2], + rollout_inference_indices.shape[3], + dtype=rollout_inference_indices.dtype, + device=rollout_inference_indices.device, + ) + for i in range(batch_size): + mask = attention_mask[i].bool() + new_rii[i, : seq_lens[i]] = rollout_inference_indices[i, mask] + return new_rii + + +def _setup_per_microbatch_replay( + rollout_inference_indices: torch.Tensor, + attention_mask: torch.Tensor, +) -> None: + """Set up RouterReplay for a single micro-batch, aligning indices + with the left-padding-removed token layout that the MoE layer sees. + + Handles sequence parallelism: when TP > 1, the sequence is split across + TP ranks, so each rank's MoE router only sees its local chunk of tokens. 
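+
+    A minimal, self-contained sketch of the alignment with hypothetical
+    shapes (the TP shard is shown as plain slicing, parallel_state elided):
+
+        import torch
+
+        # batch=1, padded_seq_len=6, 2 MoE layers, top-2 routing
+        rii = torch.arange(1 * 6 * 2 * 2).reshape(1, 6, 2, 2)
+        mask = torch.tensor([[0, 0, 1, 1, 1, 1]])  # two left-pad positions
+
+        seq_lens = mask.sum(dim=1)                         # tensor([4])
+        packed = torch.zeros(1, int(seq_lens.max()), 2, 2, dtype=rii.dtype)
+        packed[0, : seq_lens[0]] = rii[0, mask[0].bool()]  # real tokens packed left
+
+        # with TP=2, each rank keeps a contiguous half of the sequence dim
+        rank0, rank1 = packed[:, :2], packed[:, 2:]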
+ """ + import megatron.core.parallel_state as mpu + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction + + _patch_alltoall_dispatcher_for_replay() + + aligned = _remove_left_padding_from_indices(rollout_inference_indices, attention_mask) + + tp_size = mpu.get_tensor_model_parallel_world_size() + if tp_size > 1: + tp_rank = mpu.get_tensor_model_parallel_rank() + seq_len = aligned.shape[1] + chunk_size = seq_len // tp_size + aligned = aligned[:, tp_rank * chunk_size : (tp_rank + 1) * chunk_size, :, :] + + RouterReplay.set_replay_data(_split_replay_indices(aligned)) + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + def setup_router_replay_forward(data: TrainingInputBatch, enable_router_replay: bool) -> bool: """ Set up router replay for forward pass (ref/policy inference). diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 5ab565f989..aa64058e09 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -24,6 +24,7 @@ remove_left_padding, recover_left_padding, ) +from skyrl.backends.skyrl_train.utils.replay_utils import _setup_per_microbatch_replay class MegatronModelWrapper: @@ -103,6 +104,11 @@ def collection_func(logits, data): def forward_step(batch_iter, model): batch = next(batch_iter) + + rollout_inference_indices = batch.pop("rollout_inference_indices", None) + if rollout_inference_indices is not None: + _setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"]) + sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) position_ids = batch["position_ids"] diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index d01a5ad061..52f7d24807 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -426,9 +426,7 @@ def forward(self, data: TrainingInputBatch): """ Override `Worker.forward` to support passing the full mini batch to the MegatronModelWrapper.forward method. 
""" - from skyrl.backends.skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay - - setup_router_replay_forward(data, self.enable_router_replay) + from skyrl.backends.skyrl_train.utils.replay_utils import clear_router_replay # Run in micro batches grouped into a single mini-batch micro_bsz = self.cfg.micro_forward_batch_size_per_gpu @@ -444,14 +442,16 @@ def forward(self, data: TrainingInputBatch): num_actions = micro.metadata["response_length"] position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 0) - micro_dicts.append( - { - "sequences": sequences, - "attention_mask": attention_mask, - "position_ids": position_ids, - "num_actions": num_actions, - } - ) + micro_dict = { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": num_actions, + } + rii = micro.get("rollout_inference_indices") + if rii is not None and self.enable_router_replay: + micro_dict["rollout_inference_indices"] = rii + micro_dicts.append(micro_dict) self.model.eval() seq_len = micro_dicts[0]["sequences"].shape[1] diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index 01b332a625..be9c8fbcce 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -133,7 +133,18 @@ def convert_prompts_responses_to_batch_tensors( rollout_inference_indices_tensor = None if rollout_inference_indices: - rollout_inference_indices_tensor = torch.tensor(rollout_inference_indices, dtype=torch.int32) + first_non_empty = next((x for x in rollout_inference_indices if x), None) + if first_non_empty: + total_seq_len = max_input_len + max_output_len + num_layers = len(first_non_empty[0]) + topk = len(first_non_empty[0][0]) if num_layers > 0 else 0 + padded = torch.zeros(len(rollout_inference_indices), total_seq_len, num_layers, topk, dtype=torch.int32) + for i, sample_indices in enumerate(rollout_inference_indices): + if sample_indices: + left_pad = max_input_len - prompt_token_lens[i] + n = min(len(sample_indices), total_seq_len - left_pad) + padded[i, left_pad : left_pad + n] = torch.tensor(sample_indices[:n], dtype=torch.int32) + rollout_inference_indices_tensor = padded return ( sequences, diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 67bc85a720..b48249130d 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -460,7 +460,7 @@ async def agent_loop( ] if agent_loop_state.rollout_inference_indices is not None: rollout_inference_indices_out = agent_loop_state.rollout_inference_indices[ - : agent_loop_state.response_end_idx - initial_prompt_length + 1 + : agent_loop_state.response_end_idx + 1 ] # fix index for per_step_rewards per_step_rewards = [(reward, idx - initial_prompt_length) for reward, idx in per_step_rewards] @@ -676,8 +676,7 @@ async def generate_batched( sample_logprobs = logprobs[i][: len(response)] truncated_logprobs.append(sample_logprobs) if raw_rollout_inference_indices is not None: - sample_indices = raw_rollout_inference_indices[i][: len(response)] - truncated_indices.append(sample_indices) + truncated_indices.append(raw_rollout_inference_indices[i]) # Get environment-specific metrics env_metrics.append(env.get_metrics()) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 13fb4f51cc..f1134da5d6 100644 --- 
a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -14,6 +14,7 @@ Timer, init_worker_with_type, ) +from skyrl.backends.skyrl_train.distributed.dispatch import concatenate_outputs_after_mesh_dispatch from skyrl.train.utils.utils import validate_cfg from skyrl.train.config import SkyRLTrainConfig, SamplingParams from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator @@ -50,15 +51,12 @@ def test_megatron_router_replay(ray_init_fixture): cfg.generator.inference_engine.enable_return_routed_experts = True cfg.generator.inference_engine.tensor_parallel_size = 2 cfg.generator.sampling_params = SamplingParams( - max_generate_length=16, + max_generate_length=1024, logprobs=1, temperature=1.0, ) cfg.generator.batched = False cfg.generator.max_turns = 1 - cfg.generator.use_conversation_multi_turn = True - cfg.generator.apply_overlong_filtering = False - cfg.generator.zero_reward_on_non_stop = False num_prompts = 1 @@ -210,3 +208,157 @@ def test_megatron_router_replay(ray_init_fixture): finally: ray.shutdown() + + +@pytest.mark.megatron +def test_megatron_router_replay_logprobs(ray_init_fixture): + """ + Check that logprob diff is lower when using router replay. Requires full 8xH100 setup to do full forward pass. + """ + try: + cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) + cfg.trainer.strategy = "megatron" + cfg.generator.inference_engine.enable_return_routed_experts = True + cfg.generator.inference_engine.tensor_parallel_size = 8 + cfg.generator.sampling_params = SamplingParams( + max_generate_length=1024, + logprobs=1, + temperature=1.0, + ) + cfg.generator.batched = False + cfg.generator.max_turns = 1 + + tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) + + with InferenceEngineState.create( + cfg=cfg, + model=MOE_MODEL_NAME, + use_local=True, + colocate_all=True, + backend="vllm", + sleep_level=1, + gpu_memory_utilization=0.9, + ) as engines: + client, pg = engines.client, engines.pg + asyncio.run(client.wake_up()) + + generator = SkyRLGymGenerator( + generator_cfg=cfg.generator, + skyrl_gym_cfg=cfg.environment.skyrl_gym, + inference_engine_client=client, + tokenizer=tokenizer, + ) + + input_batch: GeneratorInput = get_test_generator_input( + model=MOE_MODEL_NAME, + num_prompts=4, + n_samples_per_prompt=1, + max_prompt_length=512, + env_class="gsm8k", + ) + input_batch["sampling_params"] = get_sampling_params_for_backend( + "vllm", + SamplingParams( + temperature=1.0, + top_p=1.0, + top_k=-1, + max_generate_length=16, + min_p=0.0, + logprobs=1, + ), + ) + + with Timer("generate_with_router_replay"): + generator_output = asyncio.run(generator.generate(input_batch)) + + indices = generator_output["rollout_inference_indices"] + responses = generator_output["response_ids"] + assert ( + indices is not None + ), "rollout_inference_indices should not be None when enable_return_routed_experts=True" + assert len(indices) == len( + responses + ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" + asyncio.run(client.sleep()) + + rewards = generator_output["rewards"] + if rewards and not isinstance(rewards[0], list): + rewards = [[r] * len(resp) for r, resp in zip(rewards, responses)] + (sequences, attention_mask, response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = ( + convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=generator_output["prompt_token_ids"], + responses=responses, + rewards=rewards, + 
loss_masks=generator_output["loss_masks"], + logprobs=generator_output.get("rollout_logprobs"), + rollout_inference_indices=indices, + ) + ) + + assert rii_tensor is not None + num_actions = response_mask.shape[1] + batch_size = sequences.shape[0] + training_input = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_mask, + "response_mask": response_mask, + "rewards": rewards_t, + "loss_mask": loss_mask_t, + "rollout_logprobs": ( + logprobs_t + if logprobs_t is not None + else torch.zeros((batch_size, num_actions), dtype=torch.float32) + ), + "rollout_inference_indices": rii_tensor, + "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "action_mask": response_mask.to(dtype=torch.int64), + } + ) + training_input.metadata = {"response_length": num_actions} + + cfg.trainer.placement.policy_num_gpus_per_node = 8 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 4 + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.context_parallel_size = 1 + cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 + cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 + cfg.trainer.micro_forward_batch_size_per_gpu = 1 + cfg.trainer.micro_train_batch_size_per_gpu = 1 + + def run_megatron_forward(enable_replay: bool) -> torch.Tensor: + cfg.trainer.policy.megatron_config.transformer_config_kwargs = { + "moe_enable_routing_replay": enable_replay, + } + actor_group = init_worker_with_type( + "policy", + shared_pg=pg, + colocate_all=True, + num_gpus_per_node=8, + cfg=cfg, + ) + refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) + results = ray.get(refs) + outputs = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, results)["output"] + + for actor in actor_group._actor_handlers: + ray.kill(actor) + return outputs + + r3_logprobs = run_megatron_forward(enable_replay=True) + no_r3_logprobs = run_megatron_forward(enable_replay=False) + + r3_diff = (logprobs_t - r3_logprobs).abs() + no_r3_diff = (logprobs_t - no_r3_logprobs).abs() + print(f"With replay - logprob diff mean: {r3_diff.mean().item():.6f}, std: {r3_diff.std().item():.6f}") + print(f"Without replay - logprob diff mean: {no_r3_diff.mean().item():.6f}, std: {no_r3_diff.std().item():.6f}") + + assert r3_diff.mean().item() < no_r3_diff.mean().item(), ( + f"Router replay should reduce logprob diff vs rollout, " + f"but with_replay={r3_diff.mean().item():.6f} >= without_replay={no_r3_diff.mean().item():.6f}" + ) + finally: + ray.shutdown() From 8a8fa701f9e8420582c1bc9c80dce287073a5284 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 4 Mar 2026 23:26:19 +0000 Subject: [PATCH 08/18] add test for forward backward and fix behavior --- .../megatron/megatron_model_wrapper.py | 4 + .../workers/megatron/megatron_worker.py | 44 ++--- .../gpu/gpu_ci/test_router_replay.py | 158 ++++++++++++++++++ 3 files changed, 185 insertions(+), 21 deletions(-) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index aa64058e09..540e9328e2 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -361,6 +361,10 @@ def 
loss_func(logits, data): def forward_step(batch_iter, model): batch = next(batch_iter) + rollout_inference_indices = batch.pop("rollout_inference_indices", None) + if rollout_inference_indices is not None: + _setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"]) + sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) position_ids = batch["position_ids"] diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 52f7d24807..3a78b8fb91 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -33,7 +33,7 @@ from skyrl.train.utils.utils import update_model_config, str_to_torch_dtype from skyrl.backends.skyrl_train.env_vars import SKYRL_WORKER_NCCL_TIMEOUT_IN_S from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch -from skyrl.backends.skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics, all_reduce_metrics +from skyrl.backends.skyrl_train.workers.worker_utils import reduce_metrics, all_reduce_metrics from skyrl.backends.skyrl_train.workers.worker import ( PolicyWorkerBase, RefWorkerBase, @@ -624,9 +624,8 @@ def forward_backward( Returns: Aggregated metrics dict across all micro batches """ - from skyrl.backends.skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay + from skyrl.backends.skyrl_train.utils.replay_utils import clear_router_replay - setup_router_replay_forward(data, self.enable_router_replay) self.model.train() for chunk in self.actor_module: # if use distributed optimizer, zero grad buffer will be handled by optimizer @@ -638,28 +637,31 @@ def forward_backward( # Move data to GPU data.to(torch.cuda.current_device()) - # Build micro-batch dicts expected by forward_backward_mini_batch + # Chunk manually so we can propagate rollout_inference_indices for + # per-micro-batch router replay (BatchIterator/Experience don't carry them). 
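+        # Sketch of the assumption behind the manual chunking (hypothetical
+        # sizes): a batch of 4 rows with micro_batch_size=2 yields two
+        # micro-batches whose "rollout_inference_indices" rows stay aligned
+        # row-for-row with "sequences", so each micro-batch can configure
+        # RouterReplay for exactly its own tokens.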
micro_buffer = [] - for experience in BatchIterator(data, micro_batch_size, drop_last=False): - sequences = experience.sequences - attention_mask = experience.attention_mask + for micro in data.chunk(micro_batch_size): + sequences = micro["sequences"] + attention_mask = micro["attention_mask"] position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 0) - micro_buffer.append( - { - "sequences": sequences, - "attention_mask": attention_mask, - "position_ids": position_ids, - "num_actions": experience.num_actions, - "old_action_log_probs": experience.action_log_probs, - "base_action_log_probs": experience.base_action_log_probs, - "advantages": experience.advantages, - "loss_mask": experience.loss_mask, - "rollout_action_logprobs": experience.rollout_logprobs, - "action_mask": experience.action_mask, - } - ) + micro_dict = { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": micro.metadata["response_length"], + "old_action_log_probs": micro.get("action_log_probs"), + "base_action_log_probs": micro.get("base_action_log_probs"), + "advantages": micro.get("advantages"), + "loss_mask": micro.get("loss_mask"), + "rollout_action_logprobs": micro.get("rollout_logprobs"), + "action_mask": micro.get("action_mask"), + } + rii = micro.get("rollout_inference_indices") + if rii is not None and self.enable_router_replay: + micro_dict["rollout_inference_indices"] = rii + micro_buffer.append(micro_dict) if not micro_buffer: return {} diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index f1134da5d6..f113240443 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -362,3 +362,161 @@ def run_megatron_forward(enable_replay: bool) -> torch.Tensor: ) finally: ray.shutdown() + + +@pytest.mark.megatron +def test_megatron_router_replay_forward_backward(ray_init_fixture): + """ + Check that forward_backward produces similar losses with and without + router replay (same weights, so routing decisions should nearly match). + Requires full 8xH100 setup. 
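+
+    Acceptance sketch (illustrative arithmetic, not measured values): with
+    identical weights, loss_replay=2.31 vs loss_no_replay=2.29 gives
+    |2.31 - 2.29| = 0.02 < 0.5, so the test passes; the 0.5 threshold
+    leaves headroom for routing-induced drift.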
+ """ + try: + cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) + cfg.trainer.strategy = "megatron" + cfg.generator.inference_engine.enable_return_routed_experts = True + cfg.generator.inference_engine.tensor_parallel_size = 8 + cfg.generator.sampling_params = SamplingParams( + max_generate_length=1024, + logprobs=1, + temperature=1.0, + ) + cfg.generator.batched = False + cfg.generator.max_turns = 1 + + tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) + + with InferenceEngineState.create( + cfg=cfg, + model=MOE_MODEL_NAME, + use_local=True, + colocate_all=True, + backend="vllm", + sleep_level=1, + gpu_memory_utilization=0.9, + ) as engines: + client, pg = engines.client, engines.pg + asyncio.run(client.wake_up()) + + generator = SkyRLGymGenerator( + generator_cfg=cfg.generator, + skyrl_gym_cfg=cfg.environment.skyrl_gym, + inference_engine_client=client, + tokenizer=tokenizer, + ) + + input_batch: GeneratorInput = get_test_generator_input( + model=MOE_MODEL_NAME, + num_prompts=10, + n_samples_per_prompt=5, + max_prompt_length=512, + env_class="gsm8k", + ) + input_batch["sampling_params"] = get_sampling_params_for_backend( + "vllm", + SamplingParams( + temperature=1.0, + top_p=1.0, + top_k=-1, + max_generate_length=16, + min_p=0.0, + logprobs=1, + ), + ) + + with Timer("generate_with_router_replay"): + generator_output = asyncio.run(generator.generate(input_batch)) + + indices = generator_output["rollout_inference_indices"] + responses = generator_output["response_ids"] + assert ( + indices is not None + ), "rollout_inference_indices should not be None when enable_return_routed_experts=True" + assert len(indices) == len( + responses + ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" + asyncio.run(client.sleep()) + + rewards = generator_output["rewards"] + if rewards and not isinstance(rewards[0], list): + rewards = [[r] * len(resp) for r, resp in zip(rewards, responses)] + (sequences, attention_mask, response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = ( + convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=generator_output["prompt_token_ids"], + responses=responses, + rewards=rewards, + loss_masks=generator_output["loss_masks"], + logprobs=generator_output.get("rollout_logprobs"), + rollout_inference_indices=indices, + ) + ) + + assert rii_tensor is not None + num_actions = response_mask.shape[1] + batch_size = sequences.shape[0] + training_input = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_mask, + "response_mask": response_mask, + "rewards": rewards_t, + "loss_mask": loss_mask_t, + "rollout_logprobs": ( + logprobs_t + if logprobs_t is not None + else torch.zeros((batch_size, num_actions), dtype=torch.float32) + ), + "rollout_inference_indices": rii_tensor, + "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "action_mask": response_mask.to(dtype=torch.int64), + } + ) + training_input.metadata = {"response_length": num_actions} + + cfg.trainer.placement.policy_num_gpus_per_node = 8 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 4 + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.context_parallel_size = 1 + cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 + 
cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 + cfg.trainer.micro_forward_batch_size_per_gpu = 1 + cfg.trainer.micro_train_batch_size_per_gpu = 1 + + def run_megatron_forward_backward(enable_replay: bool) -> dict: + cfg.trainer.policy.megatron_config.transformer_config_kwargs = { + "moe_enable_routing_replay": enable_replay, + } + actor_group = init_worker_with_type( + "policy", + shared_pg=pg, + colocate_all=True, + num_gpus_per_node=8, + cfg=cfg, + ) + results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=training_input)) + for actor in actor_group._actor_handlers: + ray.kill(actor) + return results[0] + + metrics_replay = run_megatron_forward_backward(enable_replay=True) + metrics_no_replay = run_megatron_forward_backward(enable_replay=False) + + loss_replay = metrics_replay["policy_loss"] + loss_no_replay = metrics_no_replay["policy_loss"] + print(f"With replay - loss: {loss_replay:.6f}") + print(f"Without replay - loss: {loss_no_replay:.6f}") + print(f"With replay metrics: {metrics_replay}") + print(f"Without replay metrics: {metrics_no_replay}") + + diff = abs(loss_replay - loss_no_replay) + threshold = 0.5 + print(f"Loss diff: {diff:.6f} (threshold: {threshold})") + assert diff < threshold, ( + f"Losses with/without replay should be similar (same weights), " + f"but diff={diff:.6f} >= threshold={threshold}" + ) + finally: + ray.shutdown() From 410995a64d644ea2ae654cd44c006dd0146ddccb Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Fri, 6 Mar 2026 02:22:30 +0000 Subject: [PATCH 09/18] working for qwen but not moonlight... debugging moonlight --- examples/train/gsm8k/run_gsm8k.sh | 4 +- .../skyrl_train/utils/replay_utils.py | 63 ++++++++++++++++- .../megatron/megatron_model_wrapper.py | 16 +++++ .../workers/megatron/megatron_worker.py | 59 +++++++++++++++- skyrl/train/utils/utils.py | 2 +- skyrl/utils/tok.py | 1 + .../skyrl_train/gpu/gpu_ci/conftest.py | 6 ++ .../gpu/gpu_ci/test_megatron_worker.py | 18 +++-- .../gpu/gpu_ci/test_router_replay.py | 67 ++++++++++++++----- 9 files changed, 207 insertions(+), 29 deletions(-) diff --git a/examples/train/gsm8k/run_gsm8k.sh b/examples/train/gsm8k/run_gsm8k.sh index 2693571dac..e5dc9e20f3 100755 --- a/examples/train/gsm8k/run_gsm8k.sh +++ b/examples/train/gsm8k/run_gsm8k.sh @@ -26,8 +26,8 @@ uv run --isolated --extra fsdp -m skyrl.train.entrypoints.main_base \ trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \ trainer.placement.critic_num_gpus_per_node=$NUM_GPUS \ trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \ - generator.inference_engine.num_engines=$NUM_GPUS \ - generator.inference_engine.tensor_parallel_size=1 \ + generator.inference_engine.num_engines=1 \ + generator.inference_engine.tensor_parallel_size=4 \ trainer.epochs=20 \ trainer.eval_batch_size=1024 \ trainer.eval_before_train=true \ diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index 0b2ce73e50..68d266e052 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -7,6 +7,38 @@ from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch +def _patch_topk_router_layer_number(): + """Monkey-patch TopKRouter.set_layer_number to propagate the global layer + number to the RouterReplay instance. + + DeepSeek V3 (and similar) architectures have dense FFN layers before the MoE + layers. 
vLLM reports routing indices for ALL transformer layers (including + dense), but Megatron only creates RouterReplay instances for MoE layers. + Storing the global layer_number on each RouterReplay instance lets us map + vLLM's per-layer data to the correct MoE router even when dense layers are + present. + + Must be called BEFORE model creation (i.e. before make_megatron_module). + """ + try: + from megatron.core.transformer.moe.router import TopKRouter + except ImportError: + return + + if getattr(TopKRouter, "_set_layer_number_patched", False): + return + + original_set_layer_number = TopKRouter.set_layer_number + + def patched_set_layer_number(self, layer_number: int): + original_set_layer_number(self, layer_number) + if self.router_replay is not None: + self.router_replay.layer_number = layer_number + + TopKRouter.set_layer_number = patched_set_layer_number + TopKRouter._set_layer_number_patched = True + + def _patch_alltoall_dispatcher_for_replay(): """Monkey-patch MoEAlltoAllTokenDispatcher.preprocess to handle router replay. @@ -96,6 +128,12 @@ def _setup_per_microbatch_replay( Handles sequence parallelism: when TP > 1, the sequence is split across TP ranks, so each rank's MoE router only sees its local chunk of tokens. + + Handles dense-layer mismatch: DeepSeek V3-style models have dense FFN + layers before the MoE layers. vLLM reports routing indices for ALL + transformer layers, but Megatron only has RouterReplay instances for MoE + layers. We use each instance's global layer_number (set by the patched + TopKRouter.set_layer_number) to index into the correct slice of the data. """ import megatron.core.parallel_state as mpu from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction @@ -111,7 +149,30 @@ def _setup_per_microbatch_replay( chunk_size = seq_len // tp_size aligned = aligned[:, tp_rank * chunk_size : (tp_rank + 1) * chunk_size, :, :] - RouterReplay.set_replay_data(_split_replay_indices(aligned)) + per_layer_data = _split_replay_indices(aligned) + num_layers_in_data = len(per_layer_data) + instances = RouterReplay.global_router_replay_instances + num_instances = len(instances) + + if num_layers_in_data == num_instances: + RouterReplay.set_replay_data(per_layer_data) + else: + # Dense-layer mismatch: map each MoE router to its global layer index. + # Prefer the patched layer_number; fall back to offset-based mapping + # (assumes dense layers precede MoE layers). 
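+        # Worked example (hypothetical model): 1 dense layer followed by
+        # 3 MoE layers gives num_layers_in_data=4, num_instances=3.
+        #   patched path:  layer_number 2..4 -> layer_idx 1..3
+        #   fallback path: layer_idx = i + (4 - 3) -> 1..3 for i = 0..2
+        # Both skip the dense layer's slot and select the MoE layers' slices.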
+ for i, router_instance in enumerate(instances): + layer_number = getattr(router_instance, "layer_number", None) + if layer_number is not None: + layer_idx = layer_number - 1 # layer_number is 1-based + else: + layer_idx = i + (num_layers_in_data - num_instances) + if layer_idx < 0 or layer_idx >= num_layers_in_data: + raise ValueError( + f"Router replay layer index {layer_idx} out of range " + f"for data with {num_layers_in_data} layers " + f"({num_instances} router instances)" + ) + router_instance.set_target_indices(per_layer_data[layer_idx]) RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 540e9328e2..ec9717819e 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -87,6 +87,22 @@ def collection_func(logits, data): tp_grp = mpu.get_tensor_model_parallel_group() tp_rank = mpu.get_tensor_model_parallel_rank() + if tp_rank == 0 and mpu.get_data_parallel_rank() == 0: + import os + + if os.environ.get("SKYRL_DEBUG_LOGITS"): + print( + f"[DEBUG] logits shape={logits.shape}, " + f"mean={logits.float().mean().item():.4f}, " + f"std={logits.float().std().item():.4f}, " + f"min={logits.float().min().item():.4f}, " + f"max={logits.float().max().item():.4f}, " + f"sequences shape={sequences.shape}, " + f"attention_backend={getattr(get_model_config(self.actor_module[0]), 'attention_backend', 'unknown')}, " + f"multi_latent_attention={getattr(get_model_config(self.actor_module[0]), 'multi_latent_attention', 'unknown')}, " + f"q_lora_rank={getattr(get_model_config(self.actor_module[0]), 'q_lora_rank', 'unknown')}" + ) + if temperature != 1.0: logits.div_(temperature) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 3a78b8fb91..77c4b1327e 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -253,6 +253,39 @@ def extract_weights(self, dtype: torch.dtype): class MegatronWorker: + def debug_model_config(self): + """Return model config diagnostics for debugging logprob mismatch.""" + from skyrl.backends.skyrl_train.distributed.megatron.megatron_utils import get_model_config + + config = get_model_config(self.actor_module[0]) + diag = { + "attention_backend": str(getattr(config, "attention_backend", "unknown")), + "multi_latent_attention": getattr(config, "multi_latent_attention", "unknown"), + "q_lora_rank": getattr(config, "q_lora_rank", "unknown"), + "kv_lora_rank": getattr(config, "kv_lora_rank", "unknown"), + "qk_head_dim": getattr(config, "qk_head_dim", "unknown"), + "qk_pos_emb_head_dim": getattr(config, "qk_pos_emb_head_dim", "unknown"), + "v_head_dim": getattr(config, "v_head_dim", "unknown"), + "num_layers": getattr(config, "num_layers", "unknown"), + "hidden_size": getattr(config, "hidden_size", "unknown"), + "num_attention_heads": getattr(config, "num_attention_heads", "unknown"), + "rope_type": getattr(config, "rope_type", "unknown"), + "layernorm_epsilon": getattr(config, "layernorm_epsilon", "unknown"), + "sequence_parallel": getattr(config, "sequence_parallel", "unknown"), + } + weight_stats = {} + model = self.actor_module[0] + for name, param in model.named_parameters(): + if any(k in name for k in 
["layers.0.", "output_layer", "word_embeddings"]): + weight_stats[name] = { + "shape": list(param.shape), + "mean": param.float().mean().item(), + "std": param.float().std().item(), + "norm": param.float().norm().item(), + } + diag["weight_stats"] = weight_stats + return diag + def _read_router_replay_state(self): """Read the current RouterReplay state from all instances.""" from megatron.core.transformer.moe.router_replay import RouterReplay @@ -301,13 +334,11 @@ def init_configs( override_config_kwargs.update(model_config_kwargs.get("model_config", {})) update_model_config(hf_config, override_config_kwargs=override_config_kwargs) - # if flash_attn is enabled, we use flash attention backend, otherwise fall back to fused attention backend transformer_config_kwargs = ( transformer_config_kwargs if isinstance(transformer_config_kwargs, dict) else OmegaConf.to_container(transformer_config_kwargs, resolve=True) ) - transformer_config_kwargs["attention_backend"] = "flash" if flash_attn else "fused" if not self.cfg.gradient_checkpointing: for key in ("recompute_granularity", "recompute_method", "recompute_num_layers"): @@ -315,6 +346,18 @@ def init_configs( bridge = AutoBridge.from_hf_pretrained(model_path, trust_remote_code=True) provider = bridge.to_megatron_provider() + + # Determine attention backend. MLA (Multi-Latent Attention) with TE fused + # attention can produce NaN/incorrect results; fall back to unfused for MLA. + if "attention_backend" not in transformer_config_kwargs: + has_mla = getattr(provider, "multi_latent_attention", False) + if flash_attn: + transformer_config_kwargs["attention_backend"] = "flash" + elif has_mla: + transformer_config_kwargs["attention_backend"] = "unfused" + else: + transformer_config_kwargs["attention_backend"] = "fused" + provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size provider.pipeline_dtype = torch.bfloat16 if bf16 else torch.float32 @@ -322,7 +365,7 @@ def init_configs( provider.expert_model_parallel_size = megatron_config.expert_model_parallel_size provider.expert_tensor_parallel_size = megatron_config.expert_tensor_parallel_size provider.sequence_parallel = megatron_config.tensor_model_parallel_size > 1 - provider.attention_backend = "flash" if flash_attn else "fused" + provider.attention_backend = transformer_config_kwargs["attention_backend"] provider.variable_seq_lengths = True provider.masked_softmax_fusion = True # Apply explicit MoE config fields to the provider. 
@@ -558,6 +601,11 @@ def init_model(self, model_path, num_training_steps: int = 1e9): flash_attn=self.cfg.flash_attn, ) + if self.enable_router_replay: + from skyrl.backends.skyrl_train.utils.replay_utils import _patch_topk_router_layer_number + + _patch_topk_router_layer_number() + # wrap with DDP for training self.actor_module = self.make_megatron_module( wrap_with_ddp=True, @@ -860,6 +908,11 @@ def init_model(self, model_path, num_training_steps: int = 1e9): flash_attn=self.cfg.flash_attn, ) + if self.enable_router_replay: + from skyrl.backends.skyrl_train.utils.replay_utils import _patch_topk_router_layer_number + + _patch_topk_router_layer_number() + self.actor_module = self.make_megatron_module( wrap_with_ddp=False, ddp_config=None, diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index 3008b615c0..d06f16b51f 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -801,7 +801,7 @@ def run_p2p_access_check(): if device_count < 2: return False - # Check P2P access between all GPU pairs + # # Check P2P access between all GPU pairs for i in range(device_count): for j in range(device_count): if i != j: diff --git a/skyrl/utils/tok.py b/skyrl/utils/tok.py index 1006830500..b328ee8334 100644 --- a/skyrl/utils/tok.py +++ b/skyrl/utils/tok.py @@ -7,6 +7,7 @@ def get_tokenizer(model_name_or_path, **tokenizer_kwargs) -> AutoTokenizer: """Gets tokenizer for the given base model with the given parameters Sets the pad token ID to EOS token ID if `None`""" + tokenizer_kwargs.setdefault("trust_remote_code", True) tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **tokenizer_kwargs) if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py index 3a0bfe532c..1394e2b43c 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py @@ -33,6 +33,8 @@ def ray_init_fixture(): # needed for megatron tests env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" env_vars["NVTE_FUSED_ATTN"] = "0" + env_vars["RAY_CGRAPH_get_timeout"] = "600" + env_vars["SKYRL_DEBUG_LOGITS"] = "1" if SKYRL_PYTHONPATH_EXPORT: pythonpath = os.environ.get("PYTHONPATH") @@ -40,6 +42,10 @@ def ray_init_fixture(): raise RuntimeError("SKYRL_PYTHONPATH_EXPORT is set but PYTHONPATH is not defined in environment") env_vars["PYTHONPATH"] = pythonpath + # RAY_CGRAPH_get_timeout must be set in os.environ so that vLLM subprocesses + # (EngineCore) inherit it — runtime_env alone doesn't propagate to subprocesses. 
+ os.environ["RAY_CGRAPH_get_timeout"] = env_vars.pop("RAY_CGRAPH_get_timeout") + logger.info(f"Initializing Ray with environment variables: {env_vars}") ray.init(runtime_env={"env_vars": env_vars}) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py index 6e31b5815b..55e30ddb6a 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py @@ -33,7 +33,8 @@ # TODO (erictang000): we would prefer to use this smaller MoE model for testing, but seeing incorrect logprobs when using EP > 1 # this might be a model specific mbridge issue - see if this persists when we transition to Megatron-Bridge # MOE_MODEL_NAME = "Qwen/Qwen1.5-MoE-A2.7B" -MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" +# MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" +MOE_MODEL_NAME = "/home/ray/moonlight16b" def get_test_actor_config(model_name=MODEL_NAME) -> SkyRLTrainConfig: @@ -228,10 +229,10 @@ async def test_megatron_forward( cfg.trainer.use_sample_packing = use_sample_packing batch = get_test_training_batch(max(4, gpus_per_node)) - if ep > 1: - if cfg.trainer.policy.megatron_config.transformer_config_kwargs is None: - cfg.trainer.policy.megatron_config.transformer_config_kwargs = dict() - cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers"] = 2 + # if ep > 1: + # if cfg.trainer.policy.megatron_config.transformer_config_kwargs is None: + # cfg.trainer.policy.megatron_config.transformer_config_kwargs = dict() + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers"] = 2 if lora: cfg.trainer.policy.model.lora = SkyRLLoraConfig(rank=16, alpha=16) @@ -246,6 +247,9 @@ async def test_megatron_forward( action_log_probs_refs = actor_group.async_run_ray_method("mesh", "forward", data=batch) all_rank_action_log_probs = ray.get(action_log_probs_refs) + print( + f"Megatron logprobs - mean: {all_rank_action_log_probs.mean().item():.6f}, std: {all_rank_action_log_probs.std().item():.6f}" + ) action_log_probs_megatron = concatenate_outputs_after_mesh_dispatch( actor_group.actor_infos, all_rank_action_log_probs )["output"] @@ -258,8 +262,8 @@ async def test_megatron_forward( @ray.remote(num_gpus=1) def run_hf_forward(batch, model_name): config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, dtype=torch.bfloat16) - if ep > 1: - config.num_hidden_layers = 2 + # if ep > 1: + # config.num_hidden_layers = 2 model = AutoModelForCausalLM.from_pretrained(model_name, config=config, dtype=torch.bfloat16) model.eval() model.to("cuda") diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index f113240443..6ab38cd31c 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -24,8 +24,11 @@ from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch from skyrl.backends.skyrl_train.utils.replay_utils import _split_replay_indices -MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" +MOE_MODEL_NAME = "/home/ray/moonlight16b" +# MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" REPLAY_NUM_LAYERS = 2 +NUM_PROMPTS = 10 +N_SAMPLES_PER_PROMPT = 5 def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: @@ -35,6 +38,20 @@ def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: cfg.trainer.micro_train_batch_size_per_gpu = 2 cfg.trainer.use_sample_packing = False 
cfg.trainer.logger = "console" + if "moonlight" in model_name: + # flash attn not supported for moonlight16b + cfg.trainer.policy.megatron_config.moe_token_dispatcher_type = "alltoall" + cfg.trainer.policy.megatron_config.moe_router_load_balancing_type = "seq_aux_loss" + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_aux_loss_coeff"] = 0 + cfg.trainer.policy.megatron_config.moe_router_score_function = "sigmoid" + cfg.trainer.policy.megatron_config.moe_router_enable_expert_bias = True + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_bias_update_rate"] = 0 + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_dtype"] = "fp32" + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_topk"] = 6 + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_pre_softmax"] = True + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_group_topk"] = 1 + cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_num_groups"] = 1 + cfg.trainer.flash_attn = False validate_cfg(cfg) return cfg @@ -51,7 +68,7 @@ def test_megatron_router_replay(ray_init_fixture): cfg.generator.inference_engine.enable_return_routed_experts = True cfg.generator.inference_engine.tensor_parallel_size = 2 cfg.generator.sampling_params = SamplingParams( - max_generate_length=1024, + max_generate_length=128, logprobs=1, temperature=1.0, ) @@ -95,7 +112,7 @@ def test_megatron_router_replay(ray_init_fixture): temperature=1.0, top_p=1.0, top_k=-1, - max_generate_length=16, + max_generate_length=128, min_p=0.0, logprobs=1, ), @@ -211,7 +228,7 @@ def test_megatron_router_replay(ray_init_fixture): @pytest.mark.megatron -def test_megatron_router_replay_logprobs(ray_init_fixture): +def test_logprobs(ray_init_fixture): """ Check that logprob diff is lower when using router replay. Requires full 8xH100 setup to do full forward pass. 
""" @@ -221,7 +238,7 @@ def test_megatron_router_replay_logprobs(ray_init_fixture): cfg.generator.inference_engine.enable_return_routed_experts = True cfg.generator.inference_engine.tensor_parallel_size = 8 cfg.generator.sampling_params = SamplingParams( - max_generate_length=1024, + max_generate_length=128, logprobs=1, temperature=1.0, ) @@ -251,8 +268,8 @@ def test_megatron_router_replay_logprobs(ray_init_fixture): input_batch: GeneratorInput = get_test_generator_input( model=MOE_MODEL_NAME, - num_prompts=4, - n_samples_per_prompt=1, + num_prompts=NUM_PROMPTS, + n_samples_per_prompt=N_SAMPLES_PER_PROMPT, max_prompt_length=512, env_class="gsm8k", ) @@ -262,7 +279,7 @@ def test_megatron_router_replay_logprobs(ray_init_fixture): temperature=1.0, top_p=1.0, top_k=-1, - max_generate_length=16, + max_generate_length=128, min_p=0.0, logprobs=1, ), @@ -329,7 +346,11 @@ def test_megatron_router_replay_logprobs(ray_init_fixture): cfg.trainer.micro_forward_batch_size_per_gpu = 1 cfg.trainer.micro_train_batch_size_per_gpu = 1 - def run_megatron_forward(enable_replay: bool) -> torch.Tensor: + import os + + os.environ["SKYRL_DEBUG_LOGITS"] = "1" + + def run_megatron_forward(enable_replay: bool, debug: bool = False) -> torch.Tensor: cfg.trainer.policy.megatron_config.transformer_config_kwargs = { "moe_enable_routing_replay": enable_replay, } @@ -340,6 +361,19 @@ def run_megatron_forward(enable_replay: bool) -> torch.Tensor: num_gpus_per_node=8, cfg=cfg, ) + + if debug: + diag = ray.get(actor_group.async_run_ray_method("pass_through", "debug_model_config"))[0] + print(f"\n=== Model Config (replay={enable_replay}) ===") + for k, v in diag.items(): + if k != "weight_stats": + print(f" {k}: {v}") + print(f" weight_stats ({len(diag['weight_stats'])} params):") + for name, stats in sorted(diag["weight_stats"].items()): + print( + f" {name}: shape={stats['shape']}, mean={stats['mean']:.6f}, std={stats['std']:.6f}, norm={stats['norm']:.2f}" + ) + refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) results = ray.get(refs) outputs = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, results)["output"] @@ -348,11 +382,14 @@ def run_megatron_forward(enable_replay: bool) -> torch.Tensor: ray.kill(actor) return outputs - r3_logprobs = run_megatron_forward(enable_replay=True) + r3_logprobs = run_megatron_forward(enable_replay=True, debug=True) no_r3_logprobs = run_megatron_forward(enable_replay=False) r3_diff = (logprobs_t - r3_logprobs).abs() no_r3_diff = (logprobs_t - no_r3_logprobs).abs() + print(f"vLLM logprobs - mean: {logprobs_t.mean().item():.6f}, std: {logprobs_t.std().item():.6f}") + print(f"Megatron (replay) - mean: {r3_logprobs.mean().item():.6f}, std: {r3_logprobs.std().item():.6f}") + print(f"Megatron (no rep) - mean: {no_r3_logprobs.mean().item():.6f}, std: {no_r3_logprobs.std().item():.6f}") print(f"With replay - logprob diff mean: {r3_diff.mean().item():.6f}, std: {r3_diff.std().item():.6f}") print(f"Without replay - logprob diff mean: {no_r3_diff.mean().item():.6f}, std: {no_r3_diff.std().item():.6f}") @@ -365,7 +402,7 @@ def run_megatron_forward(enable_replay: bool) -> torch.Tensor: @pytest.mark.megatron -def test_megatron_router_replay_forward_backward(ray_init_fixture): +def test_forward_backward(ray_init_fixture): """ Check that forward_backward produces similar losses with and without router replay (same weights, so routing decisions should nearly match). 
@@ -377,7 +414,7 @@ def test_megatron_router_replay_forward_backward(ray_init_fixture): cfg.generator.inference_engine.enable_return_routed_experts = True cfg.generator.inference_engine.tensor_parallel_size = 8 cfg.generator.sampling_params = SamplingParams( - max_generate_length=1024, + max_generate_length=128, logprobs=1, temperature=1.0, ) @@ -407,8 +444,8 @@ def test_megatron_router_replay_forward_backward(ray_init_fixture): input_batch: GeneratorInput = get_test_generator_input( model=MOE_MODEL_NAME, - num_prompts=10, - n_samples_per_prompt=5, + num_prompts=NUM_PROMPTS, + n_samples_per_prompt=N_SAMPLES_PER_PROMPT, max_prompt_length=512, env_class="gsm8k", ) @@ -418,7 +455,7 @@ def test_megatron_router_replay_forward_backward(ray_init_fixture): temperature=1.0, top_p=1.0, top_k=-1, - max_generate_length=16, + max_generate_length=128, min_p=0.0, logprobs=1, ), From 93eee6545070b432f1b0ff729420e86f41617a5b Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Fri, 6 Mar 2026 02:25:05 +0000 Subject: [PATCH 10/18] x --- .../workers/megatron/megatron_model_wrapper.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index ec9717819e..540e9328e2 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -87,22 +87,6 @@ def collection_func(logits, data): tp_grp = mpu.get_tensor_model_parallel_group() tp_rank = mpu.get_tensor_model_parallel_rank() - if tp_rank == 0 and mpu.get_data_parallel_rank() == 0: - import os - - if os.environ.get("SKYRL_DEBUG_LOGITS"): - print( - f"[DEBUG] logits shape={logits.shape}, " - f"mean={logits.float().mean().item():.4f}, " - f"std={logits.float().std().item():.4f}, " - f"min={logits.float().min().item():.4f}, " - f"max={logits.float().max().item():.4f}, " - f"sequences shape={sequences.shape}, " - f"attention_backend={getattr(get_model_config(self.actor_module[0]), 'attention_backend', 'unknown')}, " - f"multi_latent_attention={getattr(get_model_config(self.actor_module[0]), 'multi_latent_attention', 'unknown')}, " - f"q_lora_rank={getattr(get_model_config(self.actor_module[0]), 'q_lora_rank', 'unknown')}" - ) - if temperature != 1.0: logits.div_(temperature) From 9c716a16f251369dc7b2a728570283e00988f826 Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Mon, 9 Mar 2026 19:59:35 +0000 Subject: [PATCH 11/18] fixed test for moonlight by enforcing fused attn --- .../workers/megatron/megatron_model_wrapper.py | 16 ---------------- .../workers/megatron/megatron_worker.py | 6 +++--- 2 files changed, 3 insertions(+), 19 deletions(-) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index ec9717819e..540e9328e2 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -87,22 +87,6 @@ def collection_func(logits, data): tp_grp = mpu.get_tensor_model_parallel_group() tp_rank = mpu.get_tensor_model_parallel_rank() - if tp_rank == 0 and mpu.get_data_parallel_rank() == 0: - import os - - if os.environ.get("SKYRL_DEBUG_LOGITS"): - print( - f"[DEBUG] logits shape={logits.shape}, " - f"mean={logits.float().mean().item():.4f}, " - f"std={logits.float().std().item():.4f}, " - f"min={logits.float().min().item():.4f}, " - 
f"max={logits.float().max().item():.4f}, " - f"sequences shape={sequences.shape}, " - f"attention_backend={getattr(get_model_config(self.actor_module[0]), 'attention_backend', 'unknown')}, " - f"multi_latent_attention={getattr(get_model_config(self.actor_module[0]), 'multi_latent_attention', 'unknown')}, " - f"q_lora_rank={getattr(get_model_config(self.actor_module[0]), 'q_lora_rank', 'unknown')}" - ) - if temperature != 1.0: logits.div_(temperature) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 77c4b1327e..1643c4a86d 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -350,11 +350,11 @@ def init_configs( # Determine attention backend. MLA (Multi-Latent Attention) with TE fused # attention can produce NaN/incorrect results; fall back to unfused for MLA. if "attention_backend" not in transformer_config_kwargs: - has_mla = getattr(provider, "multi_latent_attention", False) + # has_mla = getattr(provider, "multi_latent_attention", False) if flash_attn: transformer_config_kwargs["attention_backend"] = "flash" - elif has_mla: - transformer_config_kwargs["attention_backend"] = "unfused" + # elif has_mla: + # transformer_config_kwargs["attention_backend"] = "unfused" else: transformer_config_kwargs["attention_backend"] = "fused" From 6de7d5c9e3ff6cbe524e18e59d804238de2ca7b6 Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Mon, 9 Mar 2026 20:01:34 +0000 Subject: [PATCH 12/18] x --- .../skyrl_train/gpu/gpu_ci/conftest.py | 1 - .../gpu/gpu_ci/test_router_replay.py | 205 +++++++++++++++--- 2 files changed, 169 insertions(+), 37 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py index 1394e2b43c..9ea416dcee 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py @@ -34,7 +34,6 @@ def ray_init_fixture(): env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" env_vars["NVTE_FUSED_ATTN"] = "0" env_vars["RAY_CGRAPH_get_timeout"] = "600" - env_vars["SKYRL_DEBUG_LOGITS"] = "1" if SKYRL_PYTHONPATH_EXPORT: pythonpath = os.environ.get("PYTHONPATH") diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 6ab38cd31c..477af7d64a 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -28,7 +28,8 @@ # MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" REPLAY_NUM_LAYERS = 2 NUM_PROMPTS = 10 -N_SAMPLES_PER_PROMPT = 5 +N_SAMPLES_PER_PROMPT = 4 +MAX_GENERATE_LENGTH = 1024 def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: @@ -36,26 +37,84 @@ def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: cfg.trainer.policy.model.path = model_name cfg.trainer.micro_forward_batch_size_per_gpu = 2 cfg.trainer.micro_train_batch_size_per_gpu = 2 - cfg.trainer.use_sample_packing = False + cfg.trainer.use_sample_packing = True + # flash attn + mla works without sample packing, logprobs are crazy/wrong + # but flash-attn correctly throws error with sample packing + # we should add an assert that if you set use_sample_packing=False flash attn can accidentally be used cfg.trainer.logger = "console" if "moonlight" in model_name: + if cfg.trainer.policy.megatron_config.transformer_config_kwargs is None: + 
cfg.trainer.policy.megatron_config.transformer_config_kwargs = {} # flash attn not supported for moonlight16b - cfg.trainer.policy.megatron_config.moe_token_dispatcher_type = "alltoall" - cfg.trainer.policy.megatron_config.moe_router_load_balancing_type = "seq_aux_loss" - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_aux_loss_coeff"] = 0 - cfg.trainer.policy.megatron_config.moe_router_score_function = "sigmoid" - cfg.trainer.policy.megatron_config.moe_router_enable_expert_bias = True - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_bias_update_rate"] = 0 - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_dtype"] = "fp32" - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_topk"] = 6 - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_pre_softmax"] = True - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_group_topk"] = 1 - cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_num_groups"] = 1 + # cfg.trainer.policy.megatron_config.moe_token_dispatcher_type = "alltoall" + # cfg.trainer.policy.megatron_config.moe_router_load_balancing_type = "seq_aux_loss" + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_aux_loss_coeff"] = 0 + # cfg.trainer.policy.megatron_config.moe_router_score_function = "sigmoid" + # cfg.trainer.policy.megatron_config.moe_router_enable_expert_bias = True + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_bias_update_rate"] = 0 + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_dtype"] = "fp32" + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_topk"] = 6 + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_pre_softmax"] = True + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_group_topk"] = 1 + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["moe_router_num_groups"] = 1 + # cfg.trainer.policy.megatron_config.transformer_config_kwargs["num_layers_in_last_pipeline_stage"] = 13 cfg.trainer.flash_attn = False validate_cfg(cfg) return cfg +def build_training_input_from_text_samples( + tokenizer: AutoTokenizer, prompt_response_pairs: list[tuple[str, str]] +) -> TrainingInputBatch: + prompts = [] + responses = [] + rewards = [] + loss_masks = [] + + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + for prompt_text, response_text in prompt_response_pairs: + prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False) + response_ids = tokenizer.encode(response_text, add_special_tokens=False) + if tokenizer.eos_token_id is not None and (not response_ids or response_ids[-1] != tokenizer.eos_token_id): + response_ids.append(tokenizer.eos_token_id) + + prompts.append(prompt_ids) + responses.append(response_ids) + rewards.append([0.0] * len(response_ids)) + loss_masks.append([1] * len(response_ids)) + + sequences, attention_mask, response_mask, rewards_t, loss_mask_t, _, _ = ( + convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=prompts, + responses=responses, + rewards=rewards, + loss_masks=loss_masks, + ) + ) + + num_actions = response_mask.shape[1] + batch_size = sequences.shape[0] + training_input = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_mask, + "response_mask": response_mask, + "rewards": rewards_t, + "loss_mask": loss_mask_t, + "rollout_logprobs": 
torch.zeros((batch_size, num_actions), dtype=torch.float32), + "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "action_mask": response_mask.to(dtype=torch.int64), + } + ) + training_input.metadata = {"response_length": num_actions} + return training_input + + @pytest.mark.megatron def test_megatron_router_replay(ray_init_fixture): """ @@ -227,6 +286,96 @@ def test_megatron_router_replay(ray_init_fixture): ray.shutdown() +@pytest.mark.megatron +def test_moonlight_logprobs(ray_init_fixture): + """ + Check magnitude of moonlight-16b-a3b logprobs without router replay. + """ + actor_group = None + try: + cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) + cfg.trainer.strategy = "megatron" + + tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) + input_batch: GeneratorInput = get_test_generator_input( + model=MOE_MODEL_NAME, + num_prompts=NUM_PROMPTS, + n_samples_per_prompt=N_SAMPLES_PER_PROMPT, + max_prompt_length=512, + env_class="gsm8k", + ) + training_input = build_training_input_from_text_samples( + tokenizer=tokenizer, + prompt_response_pairs=[ + ( + tokenizer.apply_chat_template( + prompt + if any(message["role"] == "system" for message in prompt) + else [{"role": "system", "content": "You are a helpful assistant."}] + prompt, + add_generation_prompt=True, + tokenize=False, + ), + " " + + ( + " ".join( + next( + ( + message["content"] + for message in reversed(prompt) + if message["role"] == "user" + ), + "", + ).split()[:16] + ).strip() + or "I will solve this step by step." + ), + ) + for prompt in input_batch["prompts"] + ], + ) + + cfg.trainer.placement.policy_num_gpus_per_node = 8 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.context_parallel_size = 1 + cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 + cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 + cfg.trainer.micro_forward_batch_size_per_gpu = 1 + cfg.trainer.micro_train_batch_size_per_gpu = 1 + cfg.trainer.policy.megatron_config.transformer_config_kwargs = { + **(cfg.trainer.policy.megatron_config.transformer_config_kwargs or {}), + "moe_enable_routing_replay": False, + } + + actor_group = init_worker_with_type( + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=8, + cfg=cfg, + ) + + with Timer("moonlight_forward"): + refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) + results = ray.get(refs) + + action_log_probs = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, results)["output"] + valid_log_probs = action_log_probs[training_input["response_mask"].bool()] + avg_logprob = valid_log_probs.mean().item() + print( + f"Moonlight Megatron logprobs - mean: {avg_logprob:.6f}, std: {valid_log_probs.std().item():.6f}, " + f"num_tokens: {valid_log_probs.numel()}" + ) + + assert valid_log_probs.numel() > 0, "Expected at least one valid response token" + assert torch.isfinite(valid_log_probs).all().item(), "Expected all response logprobs to be finite" + assert -20.0 < avg_logprob < -0.01, f"Unexpected average logprob magnitude: {avg_logprob:.6f}" + finally: + if actor_group is not None: + for actor in actor_group._actor_handlers: + ray.kill(actor) + ray.shutdown() + 
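The sanity band asserted in test_moonlight_logprobs reduces to masked averaging: select only the response-token logprobs via response_mask, then require them to be finite and of plausible magnitude. A self-contained sketch of that reduction, with toy tensors standing in for real forward() output:

import torch

# Toy stand-ins: 2 sequences, 4 response slots, right-padded.
action_log_probs = torch.tensor([[-1.2, -0.8, -3.0, 0.0],
                                 [-2.5, -0.1, 0.0, 0.0]])
response_mask = torch.tensor([[1, 1, 1, 0],
                              [1, 1, 0, 0]], dtype=torch.bool)

valid = action_log_probs[response_mask]  # flatten to real response tokens only
avg_logprob = valid.mean().item()

assert valid.numel() > 0
assert torch.isfinite(valid).all().item()
# A coherent model scoring its own samples should sit well above chance
# but below certainty; the bounds mirror the test above.
assert -20.0 < avg_logprob < -0.01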
@pytest.mark.megatron def test_logprobs(ray_init_fixture): """ @@ -238,7 +387,7 @@ def test_logprobs(ray_init_fixture): cfg.generator.inference_engine.enable_return_routed_experts = True cfg.generator.inference_engine.tensor_parallel_size = 8 cfg.generator.sampling_params = SamplingParams( - max_generate_length=128, + max_generate_length=MAX_GENERATE_LENGTH, logprobs=1, temperature=1.0, ) @@ -279,7 +428,7 @@ def test_logprobs(ray_init_fixture): temperature=1.0, top_p=1.0, top_k=-1, - max_generate_length=128, + max_generate_length=MAX_GENERATE_LENGTH, min_p=0.0, logprobs=1, ), @@ -338,7 +487,7 @@ def test_logprobs(ray_init_fixture): training_input.metadata = {"response_length": num_actions} cfg.trainer.placement.policy_num_gpus_per_node = 8 - cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 4 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 cfg.trainer.policy.megatron_config.context_parallel_size = 1 cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 @@ -346,11 +495,7 @@ def test_logprobs(ray_init_fixture): cfg.trainer.micro_forward_batch_size_per_gpu = 1 cfg.trainer.micro_train_batch_size_per_gpu = 1 - import os - - os.environ["SKYRL_DEBUG_LOGITS"] = "1" - - def run_megatron_forward(enable_replay: bool, debug: bool = False) -> torch.Tensor: + def run_megatron_forward(enable_replay: bool) -> torch.Tensor: cfg.trainer.policy.megatron_config.transformer_config_kwargs = { "moe_enable_routing_replay": enable_replay, } @@ -362,18 +507,6 @@ def run_megatron_forward(enable_replay: bool, debug: bool = False) -> torch.Tens cfg=cfg, ) - if debug: - diag = ray.get(actor_group.async_run_ray_method("pass_through", "debug_model_config"))[0] - print(f"\n=== Model Config (replay={enable_replay}) ===") - for k, v in diag.items(): - if k != "weight_stats": - print(f" {k}: {v}") - print(f" weight_stats ({len(diag['weight_stats'])} params):") - for name, stats in sorted(diag["weight_stats"].items()): - print( - f" {name}: shape={stats['shape']}, mean={stats['mean']:.6f}, std={stats['std']:.6f}, norm={stats['norm']:.2f}" - ) - refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) results = ray.get(refs) outputs = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, results)["output"] @@ -382,7 +515,7 @@ def run_megatron_forward(enable_replay: bool, debug: bool = False) -> torch.Tens ray.kill(actor) return outputs - r3_logprobs = run_megatron_forward(enable_replay=True, debug=True) + r3_logprobs = run_megatron_forward(enable_replay=True) no_r3_logprobs = run_megatron_forward(enable_replay=False) r3_diff = (logprobs_t - r3_logprobs).abs() @@ -414,7 +547,7 @@ def test_forward_backward(ray_init_fixture): cfg.generator.inference_engine.enable_return_routed_experts = True cfg.generator.inference_engine.tensor_parallel_size = 8 cfg.generator.sampling_params = SamplingParams( - max_generate_length=128, + max_generate_length=MAX_GENERATE_LENGTH, logprobs=1, temperature=1.0, ) @@ -455,7 +588,7 @@ def test_forward_backward(ray_init_fixture): temperature=1.0, top_p=1.0, top_k=-1, - max_generate_length=128, + max_generate_length=MAX_GENERATE_LENGTH, min_p=0.0, logprobs=1, ), From 591af9beee0450dde9202fda45ef9a234c876df4 Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Mon, 9 Mar 2026 21:34:52 +0000 Subject: [PATCH 13/18] x --- .../skyrl_train/gpu/gpu_ci/conftest.py | 1 - .../gpu/gpu_ci/test_router_replay.py | 174 +----------------- 2 files changed, 1 insertion(+), 174 
deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py index 9ea416dcee..297672c737 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py @@ -33,7 +33,6 @@ def ray_init_fixture(): # needed for megatron tests env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" env_vars["NVTE_FUSED_ATTN"] = "0" - env_vars["RAY_CGRAPH_get_timeout"] = "600" if SKYRL_PYTHONPATH_EXPORT: pythonpath = os.environ.get("PYTHONPATH") diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 477af7d64a..1504d04f96 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -29,7 +29,7 @@ REPLAY_NUM_LAYERS = 2 NUM_PROMPTS = 10 N_SAMPLES_PER_PROMPT = 4 -MAX_GENERATE_LENGTH = 1024 +MAX_GENERATE_LENGTH = 128 def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig: @@ -114,178 +114,6 @@ def build_training_input_from_text_samples( training_input.metadata = {"response_length": num_actions} return training_input - -@pytest.mark.megatron -def test_megatron_router_replay(ray_init_fixture): - """ - Test that SkyRLGymGenerator returns rollout_inference_indices - for MoE models with enable_return_routed_experts=True. - """ - try: - cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) - cfg.trainer.strategy = "megatron" - cfg.generator.inference_engine.enable_return_routed_experts = True - cfg.generator.inference_engine.tensor_parallel_size = 2 - cfg.generator.sampling_params = SamplingParams( - max_generate_length=128, - logprobs=1, - temperature=1.0, - ) - cfg.generator.batched = False - cfg.generator.max_turns = 1 - - num_prompts = 1 - - tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) - - with InferenceEngineState.create( - cfg=cfg, - model=MOE_MODEL_NAME, - use_local=True, - backend="vllm", - sleep_level=1, - gpu_memory_utilization=0.9, - max_num_seqs=1, - ) as engines: - client = engines.client - - asyncio.run(client.wake_up()) - - generator = SkyRLGymGenerator( - generator_cfg=cfg.generator, - skyrl_gym_cfg=cfg.environment.skyrl_gym, - inference_engine_client=client, - tokenizer=tokenizer, - ) - - input_batch: GeneratorInput = get_test_generator_input( - model=MOE_MODEL_NAME, - num_prompts=num_prompts, - n_samples_per_prompt=1, - max_prompt_length=512, - env_class="gsm8k", - ) - input_batch["sampling_params"] = get_sampling_params_for_backend( - "vllm", - SamplingParams( - temperature=1.0, - top_p=1.0, - top_k=-1, - max_generate_length=128, - min_p=0.0, - logprobs=1, - ), - ) - - with Timer("generate_with_router_replay"): - generator_output = asyncio.run(generator.generate(input_batch)) - - # --- Basic output checks --- - assert ( - "rollout_inference_indices" in generator_output - ), "rollout_inference_indices missing from GeneratorOutput" - indices = generator_output["rollout_inference_indices"] - assert ( - indices is not None - ), "rollout_inference_indices should not be None when enable_return_routed_experts=True" - - responses = generator_output["response_ids"] - assert len(indices) == len( - responses - ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" - - rewards = generator_output["rewards"] - if rewards and not isinstance(rewards[0], list): - rewards = [[r] * len(resp) for r, resp in zip(rewards, responses)] - (sequences, attention_mask, 
response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = ( - convert_prompts_responses_to_batch_tensors( - tokenizer=tokenizer, - prompts=generator_output["prompt_token_ids"], - responses=responses, - rewards=rewards, - loss_masks=generator_output["loss_masks"], - logprobs=generator_output.get("rollout_logprobs"), - rollout_inference_indices=indices, - ) - ) - - assert rii_tensor is not None - rii_tensor = rii_tensor[:, :, :REPLAY_NUM_LAYERS, :] - - num_actions = response_mask.shape[1] - batch_size = sequences.shape[0] - training_input = TrainingInputBatch( - { - "sequences": sequences, - "attention_mask": attention_mask, - "response_mask": response_mask, - "rewards": rewards_t, - "loss_mask": loss_mask_t, - "rollout_logprobs": ( - logprobs_t - if logprobs_t is not None - else torch.zeros((batch_size, num_actions), dtype=torch.float32) - ), - "rollout_inference_indices": rii_tensor, - "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), - "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), - "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), - "action_mask": response_mask.to(dtype=torch.int64), - } - ) - training_input.metadata = {"response_length": num_actions} - - cfg.trainer.policy.megatron_config.transformer_config_kwargs = { - "num_layers": REPLAY_NUM_LAYERS, - "moe_enable_routing_replay": True, - } - cfg.trainer.placement.policy_num_gpus_per_node = 2 - cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 - cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 - cfg.trainer.policy.megatron_config.context_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_model_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 - cfg.trainer.micro_forward_batch_size_per_gpu = 1 - cfg.trainer.micro_train_batch_size_per_gpu = 1 - - actor_group = init_worker_with_type( - "policy", - shared_pg=None, - colocate_all=False, - num_gpus_per_node=2, - cfg=cfg, - ) - - expected_per_layer = _split_replay_indices(rii_tensor.to(torch.long)) - - state = ray.get( - actor_group.async_run_ray_method( - "pass_through", - "debug_setup_router_replay_state", - data=training_input, - ) - )[0] - - assert state is not None, "Worker returned None state" - assert ( - "REPLAY_FORWARD" in state["action"] - ), f"RouterReplay action should be REPLAY_FORWARD, got: {state['action']}" - assert state["num_instances"] == len(expected_per_layer), ( - f"Expected {len(expected_per_layer)} replay instances (one per layer), " f"got {state['num_instances']}" - ) - for layer_idx, (got, expected) in enumerate(zip(state["target_indices"], expected_per_layer)): - assert torch.equal( - got.to(torch.long), expected.to(torch.long) - ), f"Layer {layer_idx}: Megatron target indices differ from vLLM indices" - print( - f"PASSED: vLLM routing indices ({rii_tensor.shape}) correctly " - f"loaded into {state['num_instances']} Megatron RouterReplay instances" - ) - - finally: - ray.shutdown() - - @pytest.mark.megatron def test_moonlight_logprobs(ray_init_fixture): """ From acb35ec2692afd2684b079787fb274257a7aab42 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Tue, 10 Mar 2026 22:33:37 +0000 Subject: [PATCH 14/18] clean up --- .../inference_engines/vllm/vllm_engine.py | 2 +- .../skyrl_train/utils/replay_utils.py | 44 +------ .../megatron/megatron_model_wrapper.py | 6 +- .../workers/megatron/megatron_worker.py | 119 ++++-------------- .../skyrl_train/workers/worker_utils.py | 1 + 
 skyrl/train/dataset/replay_buffer.py          |   5 +
 skyrl/train/utils/utils.py                    |   2 +-
 .../skyrl_train/gpu/gpu_ci/conftest.py        |   6 +-
 .../gpu/gpu_ci/test_router_replay.py          | 104 +--------------
 9 files changed, 42 insertions(+), 247 deletions(-)

diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
index 048fba7ec8..11541786de 100644
--- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
+++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -168,7 +168,7 @@ def _postprocess_outputs(self, outputs):
         if len(response_logprobs) and response_logprobs[0] is None:
             response_logprobs = None  # hack: assume uniform sampling params
-        if len(rollout_inference_indices) == 0 and _routed_experts is None:
+        if len(rollout_inference_indices) == 0 or rollout_inference_indices[0] is None:
             rollout_inference_indices = None  # hack: assume uniform sampling params

         return InferenceEngineOutput(
diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py
index 68d266e052..cf016f7e46 100644
--- a/skyrl/backends/skyrl_train/utils/replay_utils.py
+++ b/skyrl/backends/skyrl_train/utils/replay_utils.py
@@ -4,7 +4,6 @@

 import torch
 from typing import List
-from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch


 def _patch_topk_router_layer_number():
@@ -46,6 +45,8 @@ def _patch_alltoall_dispatcher_for_replay():
     routing_map.sum() < num_tokens * topk, leading to a split size mismatch in the
     alltoall collective. We fix this by deriving num_out_tokens from the routing map
     instead of the static num_tokens * topk formula.
+
+    Reference: https://github.com/verl-project/verl/pull/4986
     """
     try:
         from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
@@ -119,7 +120,7 @@ def _remove_left_padding_from_indices(
     return new_rii


-def _setup_per_microbatch_replay(
+def setup_per_microbatch_replay(
     rollout_inference_indices: torch.Tensor,
     attention_mask: torch.Tensor,
 ) -> None:
@@ -176,45 +177,6 @@ def _setup_per_microbatch_replay(
     RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)


-def setup_router_replay_forward(data: TrainingInputBatch, enable_router_replay: bool) -> bool:
-    """
-    Set up router replay for forward pass (ref/policy inference).
-    """
-    if not enable_router_replay:
-        return False
-
-    rollout_inference_indices = data.get("rollout_inference_indices")
-    if rollout_inference_indices is None:
-        return False
-
-    from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
-
-    RouterReplay.set_replay_data(_split_replay_indices(rollout_inference_indices))
-    RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
-
-    return True
-
-
-def setup_router_replay_backward(data: TrainingInputBatch, enable_router_replay: bool) -> bool:
-    """
-    Set up router replay for training forward/backward pass.
- """ - if not enable_router_replay: - return False - - rollout_inference_indices = data.get("rollout_inference_indices") - if rollout_inference_indices is None: - return False - - from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction - - RouterReplay.set_replay_data(_split_replay_indices(rollout_inference_indices)) - # Use REPLAY_FORWARD - Megatron handles REPLAY_BACKWARD automatically - RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) - - return True - - def clear_router_replay(): """Clear all router replay state.""" from megatron.core.transformer.moe.router_replay import RouterReplay diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 540e9328e2..a7f6ade581 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -24,7 +24,7 @@ remove_left_padding, recover_left_padding, ) -from skyrl.backends.skyrl_train.utils.replay_utils import _setup_per_microbatch_replay +from skyrl.backends.skyrl_train.utils.replay_utils import setup_per_microbatch_replay class MegatronModelWrapper: @@ -107,7 +107,7 @@ def forward_step(batch_iter, model): rollout_inference_indices = batch.pop("rollout_inference_indices", None) if rollout_inference_indices is not None: - _setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"]) + setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"]) sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) @@ -363,7 +363,7 @@ def forward_step(batch_iter, model): rollout_inference_indices = batch.pop("rollout_inference_indices", None) if rollout_inference_indices is not None: - _setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"]) + setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"]) sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 1643c4a86d..cd5343c5bb 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -33,7 +33,7 @@ from skyrl.train.utils.utils import update_model_config, str_to_torch_dtype from skyrl.backends.skyrl_train.env_vars import SKYRL_WORKER_NCCL_TIMEOUT_IN_S from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch -from skyrl.backends.skyrl_train.workers.worker_utils import reduce_metrics, all_reduce_metrics +from skyrl.backends.skyrl_train.workers.worker_utils import reduce_metrics, all_reduce_metrics, BatchIterator from skyrl.backends.skyrl_train.workers.worker import ( PolicyWorkerBase, RefWorkerBase, @@ -253,63 +253,6 @@ def extract_weights(self, dtype: torch.dtype): class MegatronWorker: - def debug_model_config(self): - """Return model config diagnostics for debugging logprob mismatch.""" - from skyrl.backends.skyrl_train.distributed.megatron.megatron_utils import get_model_config - - config = get_model_config(self.actor_module[0]) - diag = { - "attention_backend": str(getattr(config, "attention_backend", "unknown")), - "multi_latent_attention": getattr(config, "multi_latent_attention", "unknown"), - "q_lora_rank": getattr(config, "q_lora_rank", "unknown"), - 
"kv_lora_rank": getattr(config, "kv_lora_rank", "unknown"), - "qk_head_dim": getattr(config, "qk_head_dim", "unknown"), - "qk_pos_emb_head_dim": getattr(config, "qk_pos_emb_head_dim", "unknown"), - "v_head_dim": getattr(config, "v_head_dim", "unknown"), - "num_layers": getattr(config, "num_layers", "unknown"), - "hidden_size": getattr(config, "hidden_size", "unknown"), - "num_attention_heads": getattr(config, "num_attention_heads", "unknown"), - "rope_type": getattr(config, "rope_type", "unknown"), - "layernorm_epsilon": getattr(config, "layernorm_epsilon", "unknown"), - "sequence_parallel": getattr(config, "sequence_parallel", "unknown"), - } - weight_stats = {} - model = self.actor_module[0] - for name, param in model.named_parameters(): - if any(k in name for k in ["layers.0.", "output_layer", "word_embeddings"]): - weight_stats[name] = { - "shape": list(param.shape), - "mean": param.float().mean().item(), - "std": param.float().std().item(), - "norm": param.float().norm().item(), - } - diag["weight_stats"] = weight_stats - return diag - - def _read_router_replay_state(self): - """Read the current RouterReplay state from all instances.""" - from megatron.core.transformer.moe.router_replay import RouterReplay - - # See https://docs.nvidia.com/megatron-core/developer-guide/0.15.0/api-guide/router_replay.html docs for more info - instances = RouterReplay.global_router_replay_instances or [] - action = instances[0].router_replay_action if instances else None - - target_indices = [inst.target_topk_idx.detach().cpu() for inst in instances if inst.target_topk_idx is not None] - - return { - "action": str(action), - "target_indices": target_indices, - "num_instances": len(instances), - } - - def debug_setup_router_replay_state(self, data: TrainingInputBatch): - from skyrl.backends.skyrl_train.utils.replay_utils import setup_router_replay_forward, clear_router_replay - - setup_router_replay_forward(data, enable_router_replay=True) - state = self._read_router_replay_state() - clear_router_replay() - return state - def init_configs( self, model_path, @@ -347,17 +290,6 @@ def init_configs( bridge = AutoBridge.from_hf_pretrained(model_path, trust_remote_code=True) provider = bridge.to_megatron_provider() - # Determine attention backend. MLA (Multi-Latent Attention) with TE fused - # attention can produce NaN/incorrect results; fall back to unfused for MLA. - if "attention_backend" not in transformer_config_kwargs: - # has_mla = getattr(provider, "multi_latent_attention", False) - if flash_attn: - transformer_config_kwargs["attention_backend"] = "flash" - # elif has_mla: - # transformer_config_kwargs["attention_backend"] = "unfused" - else: - transformer_config_kwargs["attention_backend"] = "fused" - provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size provider.pipeline_dtype = torch.bfloat16 if bf16 else torch.float32 @@ -365,7 +297,7 @@ def init_configs( provider.expert_model_parallel_size = megatron_config.expert_model_parallel_size provider.expert_tensor_parallel_size = megatron_config.expert_tensor_parallel_size provider.sequence_parallel = megatron_config.tensor_model_parallel_size > 1 - provider.attention_backend = transformer_config_kwargs["attention_backend"] + provider.attention_backend = "flash" if flash_attn else "fused" provider.variable_seq_lengths = True provider.masked_softmax_fusion = True # Apply explicit MoE config fields to the provider. 
@@ -510,7 +442,6 @@ def forward(self, data: TrainingInputBatch): log_probs = log_probs.to("cpu") output = TrainingOutputBatch({"output": log_probs}) output.metadata = data.metadata - self._last_router_replay_state = self._read_router_replay_state() clear_router_replay() return output @@ -685,31 +616,29 @@ def forward_backward( # Move data to GPU data.to(torch.cuda.current_device()) - # Chunk manually so we can propagate rollout_inference_indices for - # per-micro-batch router replay (BatchIterator/Experience don't carry them). + # Build micro-batch dicts expected by forward_backward_mini_batch micro_buffer = [] - for micro in data.chunk(micro_batch_size): - sequences = micro["sequences"] - attention_mask = micro["attention_mask"] + for experience in BatchIterator(data, micro_batch_size, drop_last=False): + sequences = experience.sequences + attention_mask = experience.attention_mask position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 0) - micro_dict = { - "sequences": sequences, - "attention_mask": attention_mask, - "position_ids": position_ids, - "num_actions": micro.metadata["response_length"], - "old_action_log_probs": micro.get("action_log_probs"), - "base_action_log_probs": micro.get("base_action_log_probs"), - "advantages": micro.get("advantages"), - "loss_mask": micro.get("loss_mask"), - "rollout_action_logprobs": micro.get("rollout_logprobs"), - "action_mask": micro.get("action_mask"), - } - rii = micro.get("rollout_inference_indices") - if rii is not None and self.enable_router_replay: - micro_dict["rollout_inference_indices"] = rii - micro_buffer.append(micro_dict) + micro_buffer.append( + { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": experience.num_actions, + "old_action_log_probs": experience.action_log_probs, + "base_action_log_probs": experience.base_action_log_probs, + "advantages": experience.advantages, + "loss_mask": experience.loss_mask, + "rollout_action_logprobs": experience.rollout_logprobs, + "action_mask": experience.action_mask, + "rollout_inference_indices": experience.rollout_inference_indices, + } + ) if not micro_buffer: return {} @@ -750,7 +679,6 @@ def forward_backward( if all_loss_fn_outputs: status["loss_fn_outputs"] = all_loss_fn_outputs - self._last_router_replay_state = self._read_router_replay_state() clear_router_replay() return status @@ -908,11 +836,6 @@ def init_model(self, model_path, num_training_steps: int = 1e9): flash_attn=self.cfg.flash_attn, ) - if self.enable_router_replay: - from skyrl.backends.skyrl_train.utils.replay_utils import _patch_topk_router_layer_number - - _patch_topk_router_layer_number() - self.actor_module = self.make_megatron_module( wrap_with_ddp=False, ddp_config=None, diff --git a/skyrl/backends/skyrl_train/workers/worker_utils.py b/skyrl/backends/skyrl_train/workers/worker_utils.py index eb76c5ee7d..312e3ee52c 100644 --- a/skyrl/backends/skyrl_train/workers/worker_utils.py +++ b/skyrl/backends/skyrl_train/workers/worker_utils.py @@ -83,6 +83,7 @@ def batch_to_experience(batch: TrainingInputBatch): action_mask=batch.get("response_mask"), num_actions=batch.metadata["response_length"], # int rollout_logprobs=batch.get("rollout_logprobs"), + rollout_inference_indices=batch.get("rollout_inference_indices"), # additional info # can be used to log metrics etc for micro-batches in the worker info={}, diff --git a/skyrl/train/dataset/replay_buffer.py b/skyrl/train/dataset/replay_buffer.py index 07846937a6..109f1957ec 
100644 --- a/skyrl/train/dataset/replay_buffer.py +++ b/skyrl/train/dataset/replay_buffer.py @@ -66,6 +66,7 @@ class Experience: loss_mask: Optional[Integer[torch.LongTensor, "batch response_len"]] action_mask: Optional[Integer[torch.Tensor, "batch response_len"]] rollout_logprobs: Optional[Float[torch.Tensor, "batch response_len"]] + rollout_inference_indices: Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]] num_actions: int info: Optional[dict] kl: Optional[Float[torch.Tensor, "batch response_len"]] = None @@ -92,6 +93,8 @@ def to_device(self, device: torch.device) -> None: self.action_mask = to(self.action_mask, device) if self.rollout_logprobs is not None: self.rollout_logprobs = to(self.rollout_logprobs, device) + if self.rollout_inference_indices is not None: + self.rollout_inference_indices = to(self.rollout_inference_indices, device) def pin_memory(self): self.sequences = pin_memory(self.sequences) @@ -113,6 +116,8 @@ def pin_memory(self): self.action_mask = self.action_mask.pin_memory() if self.rollout_logprobs is not None: self.rollout_logprobs = self.rollout_logprobs.pin_memory() + if self.rollout_inference_indices is not None: + self.rollout_inference_indices = self.rollout_inference_indices.pin_memory() return self diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index d06f16b51f..3008b615c0 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -801,7 +801,7 @@ def run_p2p_access_check(): if device_count < 2: return False - # # Check P2P access between all GPU pairs + # Check P2P access between all GPU pairs for i in range(device_count): for j in range(device_count): if i != j: diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py index 297672c737..17bafaa98a 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py @@ -32,7 +32,7 @@ def ray_init_fixture(): # needed for megatron tests env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - env_vars["NVTE_FUSED_ATTN"] = "0" + # env_vars["NVTE_FUSED_ATTN"] = "0" if SKYRL_PYTHONPATH_EXPORT: pythonpath = os.environ.get("PYTHONPATH") @@ -40,10 +40,6 @@ def ray_init_fixture(): raise RuntimeError("SKYRL_PYTHONPATH_EXPORT is set but PYTHONPATH is not defined in environment") env_vars["PYTHONPATH"] = pythonpath - # RAY_CGRAPH_get_timeout must be set in os.environ so that vLLM subprocesses - # (EngineCore) inherit it — runtime_env alone doesn't propagate to subprocesses. 
- os.environ["RAY_CGRAPH_get_timeout"] = env_vars.pop("RAY_CGRAPH_get_timeout") - logger.info(f"Initializing Ray with environment variables: {env_vars}") ray.init(runtime_env={"env_vars": env_vars}) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 1504d04f96..9581b8f990 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -22,7 +22,6 @@ from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend from skyrl.train.dataset.preprocess import convert_prompts_responses_to_batch_tensors from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch -from skyrl.backends.skyrl_train.utils.replay_utils import _split_replay_indices MOE_MODEL_NAME = "/home/ray/moonlight16b" # MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" @@ -85,14 +84,12 @@ def build_training_input_from_text_samples( rewards.append([0.0] * len(response_ids)) loss_masks.append([1] * len(response_ids)) - sequences, attention_mask, response_mask, rewards_t, loss_mask_t, _, _ = ( - convert_prompts_responses_to_batch_tensors( - tokenizer=tokenizer, - prompts=prompts, - responses=responses, - rewards=rewards, - loss_masks=loss_masks, - ) + sequences, attention_mask, response_mask, rewards_t, loss_mask_t, _, _ = convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=prompts, + responses=responses, + rewards=rewards, + loss_masks=loss_masks, ) num_actions = response_mask.shape[1] @@ -114,95 +111,6 @@ def build_training_input_from_text_samples( training_input.metadata = {"response_length": num_actions} return training_input -@pytest.mark.megatron -def test_moonlight_logprobs(ray_init_fixture): - """ - Check magnitude of moonlight-16b-a3b logprobs without router replay. - """ - actor_group = None - try: - cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) - cfg.trainer.strategy = "megatron" - - tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) - input_batch: GeneratorInput = get_test_generator_input( - model=MOE_MODEL_NAME, - num_prompts=NUM_PROMPTS, - n_samples_per_prompt=N_SAMPLES_PER_PROMPT, - max_prompt_length=512, - env_class="gsm8k", - ) - training_input = build_training_input_from_text_samples( - tokenizer=tokenizer, - prompt_response_pairs=[ - ( - tokenizer.apply_chat_template( - prompt - if any(message["role"] == "system" for message in prompt) - else [{"role": "system", "content": "You are a helpful assistant."}] + prompt, - add_generation_prompt=True, - tokenize=False, - ), - " " - + ( - " ".join( - next( - ( - message["content"] - for message in reversed(prompt) - if message["role"] == "user" - ), - "", - ).split()[:16] - ).strip() - or "I will solve this step by step." 
- ), - ) - for prompt in input_batch["prompts"] - ], - ) - - cfg.trainer.placement.policy_num_gpus_per_node = 8 - cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 - cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 - cfg.trainer.policy.megatron_config.context_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 - cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 - cfg.trainer.micro_forward_batch_size_per_gpu = 1 - cfg.trainer.micro_train_batch_size_per_gpu = 1 - cfg.trainer.policy.megatron_config.transformer_config_kwargs = { - **(cfg.trainer.policy.megatron_config.transformer_config_kwargs or {}), - "moe_enable_routing_replay": False, - } - - actor_group = init_worker_with_type( - "policy", - shared_pg=None, - colocate_all=False, - num_gpus_per_node=8, - cfg=cfg, - ) - - with Timer("moonlight_forward"): - refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) - results = ray.get(refs) - - action_log_probs = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, results)["output"] - valid_log_probs = action_log_probs[training_input["response_mask"].bool()] - avg_logprob = valid_log_probs.mean().item() - print( - f"Moonlight Megatron logprobs - mean: {avg_logprob:.6f}, std: {valid_log_probs.std().item():.6f}, " - f"num_tokens: {valid_log_probs.numel()}" - ) - - assert valid_log_probs.numel() > 0, "Expected at least one valid response token" - assert torch.isfinite(valid_log_probs).all().item(), "Expected all response logprobs to be finite" - assert -20.0 < avg_logprob < -0.01, f"Unexpected average logprob magnitude: {avg_logprob:.6f}" - finally: - if actor_group is not None: - for actor in actor_group._actor_handlers: - ray.kill(actor) - ray.shutdown() @pytest.mark.megatron def test_logprobs(ray_init_fixture): From 5ad9426b3ac62d65f053f1bfad78b4eda107f6c1 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 11 Mar 2026 16:54:03 +0000 Subject: [PATCH 15/18] rename var and clean up --- .../skyrl_train/inference_engines/base.py | 10 +- .../inference_engine_client.py | 30 +++-- .../inference_engines/vllm/vllm_engine.py | 10 +- skyrl/backends/skyrl_train/training_batch.py | 2 +- .../skyrl_train/utils/replay_utils.py | 34 +++--- .../megatron/megatron_model_wrapper.py | 14 +-- .../workers/megatron/megatron_worker.py | 26 ++--- .../skyrl_train/workers/worker_utils.py | 2 +- skyrl/train/dataset/preprocess.py | 16 +-- skyrl/train/dataset/replay_buffer.py | 10 +- skyrl/train/generators/base.py | 2 +- skyrl/train/generators/skyrl_gym_generator.py | 107 +++++++++--------- skyrl/train/trainer.py | 10 +- .../gpu/gpu_ci/test_router_replay.py | 94 +++++++++++++-- .../gpu/gpu_ci/test_skyrl_gym_generator.py | 40 ++++++- 15 files changed, 259 insertions(+), 148 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_engines/base.py b/skyrl/backends/skyrl_train/inference_engines/base.py index ddecc965fb..819a071603 100644 --- a/skyrl/backends/skyrl_train/inference_engines/base.py +++ b/skyrl/backends/skyrl_train/inference_engines/base.py @@ -29,7 +29,7 @@ class InferenceEngineOutput(TypedDict): response_ids: List[List[int]] stop_reasons: List[str] response_logprobs: Optional[List[List[float]]] - rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] + rollout_expert_indices: Optional[List[List[List[List[int]]]]] # [seq_len, layer_num, topk] class InferenceEngineInterface(ABC): @@ -64,7 +64,7 @@ async def sample( all_responses = [] all_stop_reasons = 
[] all_response_logprobs = [] - all_rollout_inference_indices = [] + all_rollout_expert_indices = [] for _ in range(num_samples): input_batch: InferenceEngineInput = { @@ -81,15 +81,15 @@ async def sample( all_stop_reasons.append(output["stop_reasons"][0]) if output.get("response_logprobs") is not None: all_response_logprobs.append(output["response_logprobs"][0]) - if output.get("rollout_inference_indices") is not None: - all_rollout_inference_indices.append(output["rollout_inference_indices"][0]) + if output.get("rollout_expert_indices") is not None: + all_rollout_expert_indices.append(output["rollout_expert_indices"][0]) return { "response_ids": all_response_ids, "responses": all_responses, "stop_reasons": all_stop_reasons, "response_logprobs": all_response_logprobs if all_response_logprobs else None, - "rollout_inference_indices": all_rollout_inference_indices if all_rollout_inference_indices else None, + "rollout_expert_indices": all_rollout_expert_indices if all_rollout_expert_indices else None, } @abstractmethod diff --git a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py index 264e060cd1..d418089a72 100644 --- a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py +++ b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py @@ -153,10 +153,10 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu stop_reasons: list[str] = [""] * n response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)] response_ids: List[List[int]] = [[] for _ in range(n)] - rollout_inference_indices: List[Optional[List[List[List[int]]]]] = [None for _ in range(n)] + rollout_expert_indices: List[Optional[List[List[List[int]]]]] = [None for _ in range(n)] # a bit hacky for now add_resp_logprobs = False - add_rollout_inference_indices = False + add_rollout_expert_indices = False for indices, result in zip(indices_list, results): for local_idx, original_idx in enumerate(indices): @@ -166,16 +166,16 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu if result.get("response_logprobs", None): add_resp_logprobs = True response_logprobs[original_idx] = result["response_logprobs"][local_idx] - if result.get("rollout_inference_indices", None): - add_rollout_inference_indices = True - rollout_inference_indices[original_idx] = result["rollout_inference_indices"][local_idx] + if result.get("rollout_expert_indices", None): + add_rollout_expert_indices = True + rollout_expert_indices[original_idx] = result["rollout_expert_indices"][local_idx] return InferenceEngineOutput( responses=responses, stop_reasons=stop_reasons, response_ids=response_ids, response_logprobs=response_logprobs if add_resp_logprobs else None, - rollout_inference_indices=rollout_inference_indices if add_rollout_inference_indices else None, + rollout_expert_indices=rollout_expert_indices if add_rollout_expert_indices else None, ) def _select_engine_idx(self, session_id: Optional[Union[str, int]] = None) -> int: @@ -271,7 +271,7 @@ async def _generate_single_with_retry( # 2. 
Initialize fields we want to accumulate or update in each loop iteration accum_response_ids: List[int] = [] accum_response_logprobs: List[float] = [] - accum_rollout_inference_indices: List[List[List[int]]] = [] + accum_rollout_expert_indices: List[List[List[int]]] = [] stop_reason: str = "abort" # We only use it if generation is completed in one turn to maintain original behavior with no retry. @@ -307,10 +307,10 @@ async def _generate_single_with_retry( new_response_logprobs_list: Optional[List[List[float]]] = partial_response.get("response_logprobs", None) if new_response_logprobs_list is not None and len(new_response_logprobs_list) > 0: new_response_logprobs = new_response_logprobs_list[0] - new_rollout_inference_indices: Optional[List[List[List[int]]]] = None - new_rollout_inference_indices_list = partial_response.get("rollout_inference_indices", None) - if new_rollout_inference_indices_list is not None and len(new_rollout_inference_indices_list) > 0: - new_rollout_inference_indices = new_rollout_inference_indices_list[0] + new_rollout_expert_indices: Optional[List[List[List[int]]]] = None + new_rollout_expert_indices_list = partial_response.get("rollout_expert_indices", None) + if new_rollout_expert_indices_list is not None and len(new_rollout_expert_indices_list) > 0: + new_rollout_expert_indices = new_rollout_expert_indices_list[0] # 3.4 Aborted without generating tokens, so partial_response is useless. if stop_reason == "abort" and len(new_response_ids) == 0: @@ -320,8 +320,8 @@ async def _generate_single_with_retry( accum_response_ids.extend(new_response_ids) if new_response_logprobs is not None: accum_response_logprobs.extend(new_response_logprobs) - if new_rollout_inference_indices is not None: - accum_rollout_inference_indices.extend(new_rollout_inference_indices) + if new_rollout_expert_indices is not None: + accum_rollout_expert_indices.extend(new_rollout_expert_indices) num_turns += 1 # 4. Build the final response and return. @@ -334,9 +334,7 @@ async def _generate_single_with_retry( stop_reasons=[stop_reason], response_ids=[accum_response_ids], response_logprobs=[accum_response_logprobs] if len(accum_response_logprobs) > 0 else None, - rollout_inference_indices=( - [accum_rollout_inference_indices] if len(accum_rollout_inference_indices) > 0 else None - ), + rollout_expert_indices=([accum_rollout_expert_indices] if len(accum_rollout_expert_indices) > 0 else None), ) async def _chat_completion_with_retry( diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py index 11541786de..aa62a74f7c 100644 --- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -135,7 +135,7 @@ def _postprocess_outputs(self, outputs): stop_reasons: List[str] = [] response_ids: List[List[int]] = [] response_logprobs: Optional[List[List[float]]] = [] - rollout_inference_indices: Optional[List[List[List[List[int]]]]] = [] + rollout_expert_indices: Optional[List[List[List[List[int]]]]] = [] for output in outputs: # TODO(tgriggs): Support n>1 sampling. 
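
A note on the layout used throughout this series: each response's rollout_expert_indices follows [seq_len, layer_num, topk] -- one [layer_num, topk] slice of expert ids per generated token, normalized to nested lists via tolist() in the hunk below. A minimal standalone sketch of that convention (the helper and the fake tensor here are illustrative, not repo code):

    import torch

    def to_expert_index_lists(routed_experts):
        # Mirror the hasattr(..., "tolist") branch in _postprocess_outputs:
        # tensors become nested Python lists, lists pass through unchanged.
        if hasattr(routed_experts, "tolist"):
            return routed_experts.tolist()
        return routed_experts

    seq_len, layer_num, topk = 4, 2, 3
    fake = torch.randint(0, 64, (seq_len, layer_num, topk))  # fake routed experts
    indices = to_expert_index_lists(fake)
    assert len(indices) == seq_len
    assert len(indices[0]) == layer_num and len(indices[0][0]) == topk
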
@@ -163,20 +163,20 @@ def _postprocess_outputs(self, outputs):
                 _routed_experts = resp.routed_experts.tolist()
             else:
                 _routed_experts = resp.routed_experts
-            rollout_inference_indices.append(_routed_experts)
+            rollout_expert_indices.append(_routed_experts)
 
         if len(response_logprobs) and response_logprobs[0] is None:
             response_logprobs = None  # hack: assume uniform sampling params
 
-        if len(rollout_inference_indices) == 0 and rollout_inference_indices[0] is None:
-            rollout_inference_indices = None  # hack: assume uniform sampling params
+        if len(rollout_expert_indices) == 0 or rollout_expert_indices[0] is None:
+            rollout_expert_indices = None  # hack: assume uniform sampling params
 
         return InferenceEngineOutput(
             responses=responses,
             stop_reasons=stop_reasons,
             response_ids=response_ids,
             response_logprobs=response_logprobs,
-            rollout_inference_indices=rollout_inference_indices,
+            rollout_expert_indices=rollout_expert_indices,
         )
 
     def _get_engine(self):
diff --git a/skyrl/backends/skyrl_train/training_batch.py b/skyrl/backends/skyrl_train/training_batch.py
index 8b00aeaa92..5295c6c4bc 100644
--- a/skyrl/backends/skyrl_train/training_batch.py
+++ b/skyrl/backends/skyrl_train/training_batch.py
@@ -369,7 +369,7 @@ class TrainingInput(TypedDict, total=False):
     kl: Float[torch.Tensor, "batch_size seq_len"]
     rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]]
     rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]]
-    rollout_inference_indices: Optional[Integer[torch.Tensor, "batch_size seq_len layer_num topk"]]
+    rollout_expert_indices: Optional[Integer[torch.Tensor, "batch_size seq_len layer_num topk"]]
 
 
 class TrainingInputBatch(TensorBatch[TrainingInput]):
diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py
index cf016f7e46..72d14e67db 100644
--- a/skyrl/backends/skyrl_train/utils/replay_utils.py
+++ b/skyrl/backends/skyrl_train/utils/replay_utils.py
@@ -73,24 +73,24 @@ def patched_preprocess(self, routing_map):
     MoEAlltoAllTokenDispatcher._preprocess_patched = True
 
 
-def _split_replay_indices(rollout_inference_indices: torch.Tensor) -> List[torch.Tensor]:
-    if rollout_inference_indices is None:
+def _split_replay_indices(rollout_expert_indices: torch.Tensor) -> List[torch.Tensor]:
+    if rollout_expert_indices is None:
         return None
-    if rollout_inference_indices.dim() != 4:
-        raise ValueError(f"Expected 4D replay indices, got shape {rollout_inference_indices.shape}")
-    per_layer = rollout_inference_indices.permute(2, 0, 1, 3).contiguous()
+    if rollout_expert_indices.dim() != 4:
+        raise ValueError(f"Expected 4D replay indices, got shape {rollout_expert_indices.shape}")
+    per_layer = rollout_expert_indices.permute(2, 0, 1, 3).contiguous()
     # flatten [batch, seq, topk] to [batch * seq, topk] for each layer
     return [per_layer[i].reshape(-1, per_layer.shape[-1]) for i in range(per_layer.shape[0])]
 
 
 def _remove_left_padding_from_indices(
-    rollout_inference_indices: torch.Tensor,
+    rollout_expert_indices: torch.Tensor,
     attention_mask: torch.Tensor,
 ) -> torch.Tensor:
     """Apply the same left-padding removal as remove_left_padding to routing indices.
Args: - rollout_inference_indices: [batch, padded_seq_len, layers, topk] + rollout_expert_indices: [batch, padded_seq_len, layers, topk] attention_mask: [batch, padded_seq_len] (int or bool) Returns: @@ -105,23 +105,23 @@ def _remove_left_padding_from_indices( pad_size = (sp_world_size - effective_seq_len % sp_world_size) % sp_world_size effective_seq_len += pad_size - batch_size = rollout_inference_indices.shape[0] + batch_size = rollout_expert_indices.shape[0] new_rii = torch.zeros( batch_size, effective_seq_len, - rollout_inference_indices.shape[2], - rollout_inference_indices.shape[3], - dtype=rollout_inference_indices.dtype, - device=rollout_inference_indices.device, + rollout_expert_indices.shape[2], + rollout_expert_indices.shape[3], + dtype=rollout_expert_indices.dtype, + device=rollout_expert_indices.device, ) for i in range(batch_size): mask = attention_mask[i].bool() - new_rii[i, : seq_lens[i]] = rollout_inference_indices[i, mask] + new_rii[i, : seq_lens[i]] = rollout_expert_indices[i, mask] return new_rii -def setup_per_microbatch_replay( - rollout_inference_indices: torch.Tensor, +def setup_per_microbatch_replay_forward( + rollout_expert_indices: torch.Tensor, attention_mask: torch.Tensor, ) -> None: """Set up RouterReplay for a single micro-batch, aligning indices @@ -141,8 +141,10 @@ def setup_per_microbatch_replay( _patch_alltoall_dispatcher_for_replay() - aligned = _remove_left_padding_from_indices(rollout_inference_indices, attention_mask) + aligned = _remove_left_padding_from_indices(rollout_expert_indices, attention_mask) + # handles megatron sequence parallelism across the tensor model parallel region + # since we automatically enable sequence parallelism when TP > 1 tp_size = mpu.get_tensor_model_parallel_world_size() if tp_size > 1: tp_rank = mpu.get_tensor_model_parallel_rank() diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index a7f6ade581..c7d1c640b8 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -24,7 +24,7 @@ remove_left_padding, recover_left_padding, ) -from skyrl.backends.skyrl_train.utils.replay_utils import setup_per_microbatch_replay +from skyrl.backends.skyrl_train.utils.replay_utils import setup_per_microbatch_replay_forward class MegatronModelWrapper: @@ -105,9 +105,9 @@ def collection_func(logits, data): def forward_step(batch_iter, model): batch = next(batch_iter) - rollout_inference_indices = batch.pop("rollout_inference_indices", None) - if rollout_inference_indices is not None: - setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"]) + rollout_expert_indices = batch.pop("rollout_expert_indices", None) + if rollout_expert_indices is not None: + setup_per_microbatch_replay_forward(rollout_expert_indices, batch["attention_mask"]) sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) @@ -361,9 +361,9 @@ def loss_func(logits, data): def forward_step(batch_iter, model): batch = next(batch_iter) - rollout_inference_indices = batch.pop("rollout_inference_indices", None) - if rollout_inference_indices is not None: - setup_per_microbatch_replay(rollout_inference_indices, batch["attention_mask"]) + rollout_expert_indices = batch.pop("rollout_expert_indices", None) + if rollout_expert_indices is not None: + setup_per_microbatch_replay_forward(rollout_expert_indices, batch["attention_mask"]) 
sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index cd5343c5bb..ec67b91897 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -33,7 +33,7 @@ from skyrl.train.utils.utils import update_model_config, str_to_torch_dtype from skyrl.backends.skyrl_train.env_vars import SKYRL_WORKER_NCCL_TIMEOUT_IN_S from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch -from skyrl.backends.skyrl_train.workers.worker_utils import reduce_metrics, all_reduce_metrics, BatchIterator +from skyrl.backends.skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics, all_reduce_metrics from skyrl.backends.skyrl_train.workers.worker import ( PolicyWorkerBase, RefWorkerBase, @@ -289,7 +289,6 @@ def init_configs( bridge = AutoBridge.from_hf_pretrained(model_path, trust_remote_code=True) provider = bridge.to_megatron_provider() - provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size provider.pipeline_dtype = torch.bfloat16 if bf16 else torch.float32 @@ -417,16 +416,17 @@ def forward(self, data: TrainingInputBatch): num_actions = micro.metadata["response_length"] position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 0) - micro_dict = { - "sequences": sequences, - "attention_mask": attention_mask, - "position_ids": position_ids, - "num_actions": num_actions, - } - rii = micro.get("rollout_inference_indices") - if rii is not None and self.enable_router_replay: - micro_dict["rollout_inference_indices"] = rii - micro_dicts.append(micro_dict) + micro_dicts.append( + { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": num_actions, + "rollout_expert_indices": ( + micro.get("rollout_expert_indices") if self.enable_router_replay else None + ), + } + ) self.model.eval() seq_len = micro_dicts[0]["sequences"].shape[1] @@ -636,7 +636,7 @@ def forward_backward( "loss_mask": experience.loss_mask, "rollout_action_logprobs": experience.rollout_logprobs, "action_mask": experience.action_mask, - "rollout_inference_indices": experience.rollout_inference_indices, + "rollout_expert_indices": experience.rollout_expert_indices if self.enable_router_replay else None, } ) diff --git a/skyrl/backends/skyrl_train/workers/worker_utils.py b/skyrl/backends/skyrl_train/workers/worker_utils.py index 312e3ee52c..b8ad1b09fb 100644 --- a/skyrl/backends/skyrl_train/workers/worker_utils.py +++ b/skyrl/backends/skyrl_train/workers/worker_utils.py @@ -83,7 +83,7 @@ def batch_to_experience(batch: TrainingInputBatch): action_mask=batch.get("response_mask"), num_actions=batch.metadata["response_length"], # int rollout_logprobs=batch.get("rollout_logprobs"), - rollout_inference_indices=batch.get("rollout_inference_indices"), + rollout_expert_indices=batch.get("rollout_expert_indices"), # additional info # can be used to log metrics etc for micro-batches in the worker info={}, diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index be9c8fbcce..67907e977f 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -32,7 +32,7 @@ def convert_prompts_responses_to_batch_tensors( 
rewards: List[List[float]], loss_masks: List[List[int]], logprobs: Optional[List[List[float]]] = None, - rollout_inference_indices: Optional[List[List[List[List[int]]]]] = None, + rollout_expert_indices: Optional[List[List[List[List[int]]]]] = None, ) -> Tuple[ Float[torch.Tensor, "batch seq_len"], Float[torch.Tensor, "batch seq_len"], @@ -131,20 +131,20 @@ def convert_prompts_responses_to_batch_tensors( ] logprobs_tensor = torch.tensor(padded_logprobs, dtype=torch.float) - rollout_inference_indices_tensor = None - if rollout_inference_indices: - first_non_empty = next((x for x in rollout_inference_indices if x), None) + rollout_expert_indices_tensor = None + if rollout_expert_indices: + first_non_empty = next((x for x in rollout_expert_indices if x), None) if first_non_empty: total_seq_len = max_input_len + max_output_len num_layers = len(first_non_empty[0]) topk = len(first_non_empty[0][0]) if num_layers > 0 else 0 - padded = torch.zeros(len(rollout_inference_indices), total_seq_len, num_layers, topk, dtype=torch.int32) - for i, sample_indices in enumerate(rollout_inference_indices): + padded = torch.zeros(len(rollout_expert_indices), total_seq_len, num_layers, topk, dtype=torch.int32) + for i, sample_indices in enumerate(rollout_expert_indices): if sample_indices: left_pad = max_input_len - prompt_token_lens[i] n = min(len(sample_indices), total_seq_len - left_pad) padded[i, left_pad : left_pad + n] = torch.tensor(sample_indices[:n], dtype=torch.int32) - rollout_inference_indices_tensor = padded + rollout_expert_indices_tensor = padded return ( sequences, @@ -153,5 +153,5 @@ def convert_prompts_responses_to_batch_tensors( ret_rewards, ret_loss_masks, logprobs_tensor, - rollout_inference_indices_tensor, + rollout_expert_indices_tensor, ) diff --git a/skyrl/train/dataset/replay_buffer.py b/skyrl/train/dataset/replay_buffer.py index 109f1957ec..072c65fdf7 100644 --- a/skyrl/train/dataset/replay_buffer.py +++ b/skyrl/train/dataset/replay_buffer.py @@ -66,7 +66,7 @@ class Experience: loss_mask: Optional[Integer[torch.LongTensor, "batch response_len"]] action_mask: Optional[Integer[torch.Tensor, "batch response_len"]] rollout_logprobs: Optional[Float[torch.Tensor, "batch response_len"]] - rollout_inference_indices: Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]] + rollout_expert_indices: Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]] num_actions: int info: Optional[dict] kl: Optional[Float[torch.Tensor, "batch response_len"]] = None @@ -93,8 +93,8 @@ def to_device(self, device: torch.device) -> None: self.action_mask = to(self.action_mask, device) if self.rollout_logprobs is not None: self.rollout_logprobs = to(self.rollout_logprobs, device) - if self.rollout_inference_indices is not None: - self.rollout_inference_indices = to(self.rollout_inference_indices, device) + if self.rollout_expert_indices is not None: + self.rollout_expert_indices = to(self.rollout_expert_indices, device) def pin_memory(self): self.sequences = pin_memory(self.sequences) @@ -116,8 +116,8 @@ def pin_memory(self): self.action_mask = self.action_mask.pin_memory() if self.rollout_logprobs is not None: self.rollout_logprobs = self.rollout_logprobs.pin_memory() - if self.rollout_inference_indices is not None: - self.rollout_inference_indices = self.rollout_inference_indices.pin_memory() + if self.rollout_expert_indices is not None: + self.rollout_expert_indices = self.rollout_expert_indices.pin_memory() return self diff --git a/skyrl/train/generators/base.py b/skyrl/train/generators/base.py 
index 530fbbbe84..49b3ecc1ac 100644 --- a/skyrl/train/generators/base.py +++ b/skyrl/train/generators/base.py @@ -39,7 +39,7 @@ class GeneratorOutput(TypedDict): rollout_metrics: Optional[Dict[str, Any]] rollout_logprobs: Optional[List[List[float]]] trajectory_ids: Optional[List[TrajectoryID]] - rollout_inference_indices: Optional[List[List[List[List[int]]]]] # [batch_size, seq_len, layer_num, topk] + rollout_expert_indices: Optional[List[List[List[List[int]]]]] # [batch_size, seq_len, layer_num, topk] # Applicable only for step-wise training is_last_step: Optional[List[bool]] diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index b48249130d..cbad99eb31 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -40,7 +40,7 @@ class TrajectoryOutput: prompt_ids: List[int] rollout_logprobs: Optional[List[float]] env_metrics: Dict[str, Any] - rollout_inference_indices: Optional[List[List[List[int]]]] = None + rollout_expert_indices: Optional[List[List[List[int]]]] = None @dataclass @@ -58,7 +58,7 @@ class AgentLoopState: rollout_logprobs: Optional[List[float]] response_end_idx: Optional[int] done: bool - rollout_inference_indices: Optional[List[List[List[int]]]] = None + rollout_expert_indices: Optional[List[List[List[int]]]] = None @dataclass @@ -68,26 +68,26 @@ class TurnOutput: output_logprobs: Optional[List[float]] new_obs: ConversationType obs_ids: List[int] - rollout_inference_indices: Optional[List[List[List[int]]]] # [seq_len, layer_num, topk] + rollout_expert_indices: Optional[List[List[List[int]]]] # [seq_len, layer_num, topk] reward: Optional[float] added_eos: bool = False - def get_turn_rollout_inference_indices(self) -> Optional[List[List[List[int]]]]: + def get_turn_rollout_expert_indices(self) -> Optional[List[List[List[int]]]]: """ Get rollout inference indices for this turn's tokens (output + observation). Returns indices for generated output tokens, with padding entries (all -1) for any manually-added EOS token and observation tokens. - Returns None if rollout_inference_indices is None. + Returns None if rollout_expert_indices is None. 
""" - if self.rollout_inference_indices is None: + if self.rollout_expert_indices is None: return None - if not self.rollout_inference_indices: - return self.rollout_inference_indices - layer_num = len(self.rollout_inference_indices[0]) - topk = len(self.rollout_inference_indices[0][0]) if layer_num > 0 else 0 + if not self.rollout_expert_indices: + return self.rollout_expert_indices + layer_num = len(self.rollout_expert_indices[0]) + topk = len(self.rollout_expert_indices[0][0]) if layer_num > 0 else 0 pad_entry = [[-1] * topk for _ in range(layer_num)] - indices = list(self.rollout_inference_indices) + indices = list(self.rollout_expert_indices) if self.added_eos: indices.append(pad_entry) indices.extend(pad_entry for _ in range(len(self.obs_ids))) @@ -324,14 +324,14 @@ async def agent_loop( output_ids = engine_output["response_ids"][0] stop_reason = engine_output["stop_reasons"][0] response_logprobs = engine_output.get("response_logprobs", None) - rollout_inference_indices = engine_output.get("rollout_inference_indices", None) + rollout_expert_indices = engine_output.get("rollout_expert_indices", None) if response_logprobs is not None: response_logprobs = response_logprobs[0] if self.custom_chat_template is not None: raise ValueError("Response Logprobs bookkeeping is not supported with custom chat template") - if rollout_inference_indices is not None: - rollout_inference_indices = rollout_inference_indices[0] + if rollout_expert_indices is not None: + rollout_expert_indices = rollout_expert_indices[0] # Append eos when sampling_params.stop is not None. Does not affect 3.a as chat templates add eos_token. # sampling_params is not None for eval, but None for training (which uses engine.sampling_params which are from cfg) stop_strs = current_sampling_params.get("stop", None) @@ -375,11 +375,11 @@ async def agent_loop( reward=step_reward, obs_ids=obs_ids, added_eos=added_eos, - rollout_inference_indices=rollout_inference_indices, + rollout_expert_indices=rollout_expert_indices, ) - if turn_output.rollout_inference_indices is not None and agent_loop_state.rollout_inference_indices is None: - agent_loop_state.rollout_inference_indices = [] + if turn_output.rollout_expert_indices is not None and agent_loop_state.rollout_expert_indices is None: + agent_loop_state.rollout_expert_indices = [] if is_step_wise: # current response + observation ids @@ -398,7 +398,7 @@ async def agent_loop( rollout_logprobs=turn_response_logprobs, stop_reason=stop_reason, env_metrics=env.get_metrics() if agent_loop_state.done else {}, - rollout_inference_indices=turn_output.get_turn_rollout_inference_indices(), + rollout_expert_indices=turn_output.get_turn_rollout_expert_indices(), ) agent_loop_output.step_outputs.append(per_step_output) @@ -427,7 +427,7 @@ async def agent_loop( prompt_ids = agent_loop_state.input_ids[:initial_prompt_length] rollout_logprobs = None - rollout_inference_indices_out = None + rollout_expert_indices_out = None response_ids = None # Prepare the final loss_mask, response_ids and rollout_logprobs . 
@@ -458,8 +458,8 @@ async def agent_loop( rollout_logprobs = agent_loop_state.rollout_logprobs[ : agent_loop_state.response_end_idx - initial_prompt_length + 1 ] - if agent_loop_state.rollout_inference_indices is not None: - rollout_inference_indices_out = agent_loop_state.rollout_inference_indices[ + if agent_loop_state.rollout_expert_indices is not None: + rollout_expert_indices_out = agent_loop_state.rollout_expert_indices[ : agent_loop_state.response_end_idx + 1 ] # fix index for per_step_rewards @@ -478,10 +478,10 @@ async def agent_loop( loss_mask.append(1) if rollout_logprobs is not None: rollout_logprobs.append(0.0) - if rollout_inference_indices_out is not None and rollout_inference_indices_out: - layer_num = len(rollout_inference_indices_out[0]) - topk = len(rollout_inference_indices_out[0][0]) if layer_num > 0 else 0 - rollout_inference_indices_out.append([[-1] * topk for _ in range(layer_num)]) + if rollout_expert_indices_out is not None and rollout_expert_indices_out: + layer_num = len(rollout_expert_indices_out[0]) + topk = len(rollout_expert_indices_out[0][0]) if layer_num > 0 else 0 + rollout_expert_indices_out.append([[-1] * topk for _ in range(layer_num)]) appended_eos_token = True if self.generator_cfg.step_wise_trajectories: @@ -501,7 +501,7 @@ async def agent_loop( prompt_ids=prompt_ids, rollout_logprobs=rollout_logprobs, env_metrics=env_metrics, - rollout_inference_indices=rollout_inference_indices_out, + rollout_expert_indices=rollout_expert_indices_out, ) return agent_loop_output @@ -653,14 +653,14 @@ async def generate_batched( responses = engine_output["response_ids"] stop_reasons = engine_output["stop_reasons"] logprobs = engine_output.get("response_logprobs", None) - raw_rollout_inference_indices = engine_output.get("rollout_inference_indices", None) + raw_rollout_expert_indices = engine_output.get("rollout_expert_indices", None) truncated_responses = [] rewards = [] loss_masks = [] env_metrics = [] truncated_logprobs: Optional[List[List[float]]] = [] if logprobs is not None else None - truncated_indices: Optional[List] = [] if raw_rollout_inference_indices is not None else None + truncated_indices: Optional[List] = [] if raw_rollout_expert_indices is not None else None for i, (output, response, env, env_class) in enumerate(zip(outputs, responses, envs, env_classes)): # step on environment and compute reward @@ -675,8 +675,8 @@ async def generate_batched( if logprobs is not None: sample_logprobs = logprobs[i][: len(response)] truncated_logprobs.append(sample_logprobs) - if raw_rollout_inference_indices is not None: - truncated_indices.append(raw_rollout_inference_indices[i]) + if raw_rollout_expert_indices is not None: + truncated_indices.append(raw_rollout_expert_indices[i]) # Get environment-specific metrics env_metrics.append(env.get_metrics()) @@ -696,7 +696,7 @@ async def generate_batched( "stop_reasons": stop_reasons, "rollout_metrics": rollout_metrics, "rollout_logprobs": truncated_logprobs, - "rollout_inference_indices": truncated_indices, + "rollout_expert_indices": truncated_indices, } return generator_output @@ -799,15 +799,12 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False if self.generator_cfg.step_wise_trajectories: all_indices = sum( - [ - [step_output.rollout_inference_indices for step_output in output.step_outputs] - for output in all_outputs - ], + [[step_output.rollout_expert_indices for step_output in output.step_outputs] for output in all_outputs], [], ) else: - all_indices = 
[output.rollout_inference_indices for output in all_outputs] - rollout_inference_indices = all_indices if any(idx is not None for idx in all_indices) else None + all_indices = [output.rollout_expert_indices for output in all_outputs] + rollout_expert_indices = all_indices if any(idx is not None for idx in all_indices) else None rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes) @@ -827,7 +824,7 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False "rollout_metrics": rollout_metrics, "rollout_logprobs": rollout_logprobs, "trajectory_ids": out_trajectory_ids, - "rollout_inference_indices": rollout_inference_indices, + "rollout_expert_indices": rollout_expert_indices, "is_last_step": is_last_step, } @@ -896,7 +893,7 @@ def _update_agent_state_by_retokenizing_chat_history( # `logprobs` are not computed because retokenizing breaks token-in-token-out agent_loop_state.rollout_logprobs = None # indices are not meaningful when retokenizing - agent_loop_state.rollout_inference_indices = None + agent_loop_state.rollout_expert_indices = None return agent_loop_state def _update_agent_loop_state_with_multiturn_chat_template( @@ -948,14 +945,15 @@ def _update_agent_loop_state_with_multiturn_chat_template( loss_mask_for_turn = turn_output.get_turn_loss_mask() rollout_logprobs_for_turn = turn_output.get_turn_rollout_logprobs() - rollout_inference_indices_for_turn = turn_output.get_turn_rollout_inference_indices() + rollout_expert_indices_for_turn = turn_output.get_turn_rollout_expert_indices() if self.generator_cfg.step_wise_trajectories: # cumulative input_ids is not tracked for step wise training agent_loop_state.response_end_idx = len(turn_output.output_ids) - 1 - # no running loss_mask or `rollout_logprobs` are tracked for step-wise training + # no running loss_mask, `rollout_logprobs`, or `rollout_expert_indices` are tracked for step-wise training agent_loop_state.loss_mask = None agent_loop_state.rollout_logprobs = None + agent_loop_state.rollout_expert_indices = None else: # Directly append turn output turn_ids = turn_output.output_ids + turn_output.obs_ids @@ -964,11 +962,10 @@ def _update_agent_loop_state_with_multiturn_chat_template( agent_loop_state.loss_mask += loss_mask_for_turn if agent_loop_state.rollout_logprobs is not None and rollout_logprobs_for_turn is not None: agent_loop_state.rollout_logprobs += rollout_logprobs_for_turn - if ( - agent_loop_state.rollout_inference_indices is not None - and rollout_inference_indices_for_turn is not None - ): - agent_loop_state.rollout_inference_indices += rollout_inference_indices_for_turn + if agent_loop_state.rollout_expert_indices is not None and rollout_expert_indices_for_turn is not None: + # overwrite the existing rollout inference indices, since the inference engine should + # return the expert indices for the entire sequence including each turn's input + agent_loop_state.rollout_expert_indices = rollout_expert_indices_for_turn return agent_loop_state @@ -1033,13 +1030,13 @@ def _update_agent_loop_state_with_singleturn_chat_template( obs_ids_to_add ) - rollout_inference_indices_for_turn = None - if turn_output.rollout_inference_indices is not None and turn_output.rollout_inference_indices: - layer_num = len(turn_output.rollout_inference_indices[0]) - topk = len(turn_output.rollout_inference_indices[0][0]) if layer_num > 0 else 0 - pad_entry = [[-1] * topk for _ in range(layer_num)] - rollout_inference_indices_for_turn = list(turn_output.rollout_inference_indices[: 
len(new_resp_tokens)]) - rollout_inference_indices_for_turn.extend(pad_entry for _ in range(len(obs_ids_to_add))) + # rollout_expert_indices_for_turn = None + # if turn_output.rollout_expert_indices is not None and turn_output.rollout_expert_indices: + # layer_num = len(turn_output.rollout_expert_indices[0]) + # topk = len(turn_output.rollout_expert_indices[0][0]) if layer_num > 0 else 0 + # pad_entry = [[-1] * topk for _ in range(layer_num)] + # rollout_expert_indices_for_turn = list(turn_output.rollout_expert_indices[: len(new_resp_tokens)]) + # rollout_expert_indices_for_turn.extend(pad_entry for _ in range(len(obs_ids_to_add))) # Directly append turn output agent_loop_state.response_end_idx = len(agent_loop_state.input_ids) + len(new_resp_tokens) - 1 @@ -1047,7 +1044,7 @@ def _update_agent_loop_state_with_singleturn_chat_template( agent_loop_state.loss_mask += loss_mask_for_turn if agent_loop_state.rollout_logprobs is not None and rollout_logprobs_for_turn is not None: agent_loop_state.rollout_logprobs += rollout_logprobs_for_turn - if agent_loop_state.rollout_inference_indices is not None and rollout_inference_indices_for_turn is not None: - agent_loop_state.rollout_inference_indices += rollout_inference_indices_for_turn + if self.generator_cfg.enable_return_routed_experts and turn_output.rollout_expert_indices is not None: + agent_loop_state.rollout_expert_indices = turn_output.rollout_expert_indices return agent_loop_state diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index f842868816..ef73ab11b5 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -604,8 +604,8 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis loss_masks: List[List[int]] = generator_output["loss_masks"] logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None) - rollout_inference_indices: Optional[List[List[List[List[int]]]]] = generator_output.get( - "rollout_inference_indices", None + rollout_expert_indices: Optional[List[List[List[List[int]]]]] = generator_output.get( + "rollout_expert_indices", None ) ( @@ -615,7 +615,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis rewards_tensor, loss_masks_tensor, rollout_logprobs_tensor, - rollout_inference_indices_tensor, + rollout_expert_indices_tensor, ) = convert_prompts_responses_to_batch_tensors( self.tokenizer, prompt_ids, @@ -623,7 +623,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis rewards, loss_masks, logprobs, - rollout_inference_indices, + rollout_expert_indices, ) # sanity check for off_policy_correction @@ -644,7 +644,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis "rewards": rewards_tensor, "loss_mask": loss_masks_tensor, "rollout_logprobs": rollout_logprobs_tensor, - "rollout_inference_indices": rollout_inference_indices_tensor, + "rollout_expert_indices": rollout_expert_indices_tensor, "is_last_step": ( torch.tensor(generator_output["is_last_step"], dtype=torch.bool) if generator_output.get("is_last_step", None) is not None diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 9581b8f990..8e84ee61f2 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -23,7 +23,8 @@ from skyrl.train.dataset.preprocess import convert_prompts_responses_to_batch_tensors from 
skyrl.backends.skyrl_train.training_batch import TrainingInputBatch -MOE_MODEL_NAME = "/home/ray/moonlight16b" +MOE_MODEL_NAME = "arcee-ai/Trinity-Nano-Preview" +# MOE_MODEL_NAME = "/home/ray/moonlight16b" # MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" REPLAY_NUM_LAYERS = 2 NUM_PROMPTS = 10 @@ -112,6 +113,81 @@ def build_training_input_from_text_samples( return training_input +def test_generate_with_router_replay(ray_init_fixture): + """ + Check that generate with router replay produces the correct rollout inference indices. + """ + try: + cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) + cfg.trainer.strategy = "megatron" + cfg.generator.inference_engine.enable_return_routed_experts = True + cfg.generator.inference_engine.tensor_parallel_size = 8 + cfg.generator.sampling_params = SamplingParams( + max_generate_length=MAX_GENERATE_LENGTH, + logprobs=1, + temperature=1.0, + ) + cfg.generator.batched = False + cfg.generator.max_turns = 2 + cfg.generator.use_conversation_multi_turn = True + env_class = "gsm8k_multi_turn" + + tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) + + with InferenceEngineState.create( + cfg=cfg, + model=MOE_MODEL_NAME, + use_local=True, + colocate_all=True, + backend="vllm", + sleep_level=1, + gpu_memory_utilization=0.9, + ) as engines: + client = engines.client + asyncio.run(client.wake_up()) + + generator = SkyRLGymGenerator( + generator_cfg=cfg.generator, + skyrl_gym_cfg=cfg.environment.skyrl_gym, + inference_engine_client=client, + tokenizer=tokenizer, + ) + + input_batch: GeneratorInput = get_test_generator_input( + model=MOE_MODEL_NAME, + num_prompts=NUM_PROMPTS, + n_samples_per_prompt=N_SAMPLES_PER_PROMPT, + max_prompt_length=512, + env_class=env_class, + ) + input_batch["sampling_params"] = get_sampling_params_for_backend( + "vllm", + SamplingParams( + temperature=1.0, + top_p=1.0, + top_k=-1, + max_generate_length=MAX_GENERATE_LENGTH, + min_p=0.0, + logprobs=1, + ), + ) + + with Timer("generate_with_router_replay"): + generator_output = asyncio.run(generator.generate(input_batch)) + + indices = generator_output["rollout_expert_indices"] + responses = generator_output["response_ids"] + assert ( + indices is not None + ), "rollout_expert_indices should not be None when enable_return_routed_experts=True" + assert len(indices) == len( + responses + ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" + asyncio.run(client.sleep()) + finally: + ray.shutdown() + + @pytest.mark.megatron def test_logprobs(ray_init_fixture): """ @@ -173,11 +249,11 @@ def test_logprobs(ray_init_fixture): with Timer("generate_with_router_replay"): generator_output = asyncio.run(generator.generate(input_batch)) - indices = generator_output["rollout_inference_indices"] + indices = generator_output["rollout_expert_indices"] responses = generator_output["response_ids"] assert ( indices is not None - ), "rollout_inference_indices should not be None when enable_return_routed_experts=True" + ), "rollout_expert_indices should not be None when enable_return_routed_experts=True" assert len(indices) == len( responses ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" @@ -194,7 +270,7 @@ def test_logprobs(ray_init_fixture): rewards=rewards, loss_masks=generator_output["loss_masks"], logprobs=generator_output.get("rollout_logprobs"), - rollout_inference_indices=indices, + rollout_expert_indices=indices, ) ) @@ -213,7 +289,7 @@ def test_logprobs(ray_init_fixture): if logprobs_t is not None else 
torch.zeros((batch_size, num_actions), dtype=torch.float32) ), - "rollout_inference_indices": rii_tensor, + "rollout_expert_indices": rii_tensor, "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), @@ -333,11 +409,11 @@ def test_forward_backward(ray_init_fixture): with Timer("generate_with_router_replay"): generator_output = asyncio.run(generator.generate(input_batch)) - indices = generator_output["rollout_inference_indices"] + indices = generator_output["rollout_expert_indices"] responses = generator_output["response_ids"] assert ( indices is not None - ), "rollout_inference_indices should not be None when enable_return_routed_experts=True" + ), "rollout_expert_indices should not be None when enable_return_routed_experts=True" assert len(indices) == len( responses ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" @@ -354,7 +430,7 @@ def test_forward_backward(ray_init_fixture): rewards=rewards, loss_masks=generator_output["loss_masks"], logprobs=generator_output.get("rollout_logprobs"), - rollout_inference_indices=indices, + rollout_expert_indices=indices, ) ) @@ -373,7 +449,7 @@ def test_forward_backward(ray_init_fixture): if logprobs_t is not None else torch.zeros((batch_size, num_actions), dtype=torch.float32) ), - "rollout_inference_indices": rii_tensor, + "rollout_expert_indices": rii_tensor, "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py index 788bc8ef42..7a0b77bbfa 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py @@ -29,6 +29,7 @@ def get_test_config( is_step_wise, temperature, get_logprobs, + enable_return_routed_experts, ): cfg = SkyRLTrainConfig() cfg.trainer.policy.model.path = model @@ -49,7 +50,7 @@ def get_test_config( cfg.generator.inference_engine.http_endpoint_host = "127.0.0.1" cfg.generator.inference_engine.http_endpoint_port = 8000 cfg.generator.step_wise_trajectories = is_step_wise - + cfg.generator.inference_engine.enable_return_routed_experts = enable_return_routed_experts cfg.environment.skyrl_gym.search.log_requests = True cfg.environment.skyrl_gym.search.search_url = "http://127.0.0.1:8000/retrieve" cfg.environment.skyrl_gym.max_env_workers = max_env_workers @@ -108,6 +109,7 @@ async def run_generator_end_to_end( is_step_wise: bool = False, temperature=1.0, get_logprobs: bool = False, + enable_return_routed_experts: bool = False, ): """ End to end generator test - requires minimum 2 GPUs @@ -125,6 +127,7 @@ async def run_generator_end_to_end( is_step_wise, temperature, get_logprobs, + enable_return_routed_experts, ) # Use InferenceEngineState to support both legacy and new inference backends @@ -186,6 +189,11 @@ async def run_generator_end_to_end( "rollout_logprobs": ( generator_output["rollout_logprobs"][i] if generator_output["rollout_logprobs"] else None ), + "rollout_expert_indices": ( + generator_output["rollout_expert_indices"][i] + if generator_output["rollout_expert_indices"] + else None + ), } for i in 
range(len(generator_output["response_ids"])) ] @@ -450,3 +458,33 @@ async def test_generator_multi_turn_gsm8k_step_wise(ray_init_fixture): assert isinstance(generator_output["is_last_step"], list) and isinstance(generator_output["is_last_step"][0], bool) # Expect atleast one response with more than one turn assert sum(generator_output["is_last_step"]) != len(generator_output["is_last_step"]) + + +async def test_generator_multi_turn_gsm8k_router_replay(ray_init_fixture): + """ + Test the generator with the multi-turn GSM8K environment for router replay + """ + generator_output: GeneratorOutput = await run_generator_end_to_end( + use_async_engine=True, + batched=False, + n_samples_per_prompt=5, + num_inference_engines=2, + tensor_parallel_size=2, + model="arcee-ai/Trinity-Nano-Preview", + max_prompt_length=2048, + max_input_length=4096, + max_generate_length=1000, + data_path=os.path.expanduser("~/data/gsm8k/validation.parquet"), + env_class="gsm8k_multi_turn", + num_prompts=2, + max_turns=2, + use_conversation_multi_turn=True, + max_env_workers=0, + is_step_wise=True, + temperature=0, + enable_return_routed_experts=True, + ) + + assert isinstance(generator_output["is_last_step"], list) and isinstance(generator_output["is_last_step"][0], bool) + # Expect atleast one response with more than one turn + assert sum(generator_output["is_last_step"]) != len(generator_output["is_last_step"]) From 736735904a17fcd6e7ee86fa35fc4d9d93700795 Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Wed, 11 Mar 2026 23:03:45 +0000 Subject: [PATCH 16/18] testing replay utils with pp --- .../skyrl_train/utils/replay_utils.py | 82 +++++++++++++++++-- .../megatron/megatron_model_wrapper.py | 8 +- 2 files changed, 81 insertions(+), 9 deletions(-) diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index 72d14e67db..8122ca1ca8 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -120,9 +120,75 @@ def _remove_left_padding_from_indices( return new_rii +def _get_current_pp_stage_layer_range(model_config) -> tuple[int, int]: + """Return the current PP rank's transformer-layer range. + + Prefer Megatron's own helpers so replay indexing stays aligned with the + actual model partition, including embedding/loss pipeline accounting. 
+ """ + import megatron.core.parallel_state as mpu + from megatron.core.transformer.transformer_layer import get_transformer_layer_offset + from megatron.core.transformer.transformer_block import get_num_layers_to_build + + + if get_num_layers_to_build is not None: + return get_transformer_layer_offset(model_config), get_num_layers_to_build(model_config, pp_rank=pp_rank) + + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + + total_layers = model_config.num_layers + first_stage_layers = getattr(model_config, "num_layers_in_first_pipeline_stage", None) + last_stage_layers = getattr(model_config, "num_layers_in_last_pipeline_stage", None) + + if pp_size <= 1: + return 0, total_layers + + if first_stage_layers is None and last_stage_layers is None: + assert total_layers % pp_size == 0, ( + "For even pipelineing, num_layers should be divisible by pipeline_model_parallel_size" + ) + pp_layers = total_layers // pp_size + return pp_rank * pp_layers, pp_layers + + next_n_pp_layers = total_layers + next_n_pp_stages = pp_size + + if first_stage_layers is not None: + next_n_pp_layers -= first_stage_layers + next_n_pp_stages -= 1 + + if last_stage_layers is not None: + next_n_pp_layers -= last_stage_layers + next_n_pp_stages -= 1 + + if next_n_pp_stages > 0: + assert next_n_pp_layers % next_n_pp_stages == 0, ( + "Uneven pipelineing, not divisible by remaining pipeline stages" + ) + next_n_pp_layers = next_n_pp_layers // next_n_pp_stages + else: + next_n_pp_layers = 0 + + if pp_rank == 0 and first_stage_layers is not None: + return 0, first_stage_layers + + if pp_rank == pp_size - 1 and last_stage_layers is not None: + if first_stage_layers is not None: + start = first_stage_layers + (next_n_pp_layers * (pp_size - 2)) + else: + start = next_n_pp_layers * (pp_size - 1) + return start, last_stage_layers + + if first_stage_layers is not None: + return first_stage_layers + (next_n_pp_layers * (pp_rank - 1)), next_n_pp_layers + return next_n_pp_layers * pp_rank, next_n_pp_layers + + def setup_per_microbatch_replay_forward( rollout_expert_indices: torch.Tensor, attention_mask: torch.Tensor, + model_config, ) -> None: """Set up RouterReplay for a single micro-batch, aligning indices with the left-padding-removed token layout that the MoE layer sees. @@ -151,14 +217,16 @@ def setup_per_microbatch_replay_forward( seq_len = aligned.shape[1] chunk_size = seq_len // tp_size aligned = aligned[:, tp_rank * chunk_size : (tp_rank + 1) * chunk_size, :, :] - + per_layer_data = _split_replay_indices(aligned) - num_layers_in_data = len(per_layer_data) + global_num_layers_in_data = len(per_layer_data) instances = RouterReplay.global_router_replay_instances num_instances = len(instances) + local_layer_offset, local_num_layers = _get_current_pp_stage_layer_range(model_config) - if num_layers_in_data == num_instances: - RouterReplay.set_replay_data(per_layer_data) + if local_num_layers == num_instances: + local_per_layer_data = per_layer_data[local_layer_offset : local_layer_offset + local_num_layers] + RouterReplay.set_replay_data(local_per_layer_data) else: # Dense-layer mismatch: map each MoE router to its global layer index. 
# Prefer the patched layer_number; fall back to offset-based mapping @@ -168,11 +236,11 @@ def setup_per_microbatch_replay_forward( if layer_number is not None: layer_idx = layer_number - 1 # layer_number is 1-based else: - layer_idx = i + (num_layers_in_data - num_instances) - if layer_idx < 0 or layer_idx >= num_layers_in_data: + layer_idx = local_layer_offset + i + if layer_idx < 0 or layer_idx >= global_num_layers_in_data: raise ValueError( f"Router replay layer index {layer_idx} out of range " - f"for data with {num_layers_in_data} layers " + f"for data with {global_num_layers_in_data} layers " f"({num_instances} router instances)" ) router_instance.set_target_indices(per_layer_data[layer_idx]) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index c7d1c640b8..50ecb2fa90 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -107,7 +107,9 @@ def forward_step(batch_iter, model): rollout_expert_indices = batch.pop("rollout_expert_indices", None) if rollout_expert_indices is not None: - setup_per_microbatch_replay_forward(rollout_expert_indices, batch["attention_mask"]) + setup_per_microbatch_replay_forward( + rollout_expert_indices, batch["attention_mask"], get_model_config(model) + ) sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) @@ -363,7 +365,9 @@ def forward_step(batch_iter, model): rollout_expert_indices = batch.pop("rollout_expert_indices", None) if rollout_expert_indices is not None: - setup_per_microbatch_replay_forward(rollout_expert_indices, batch["attention_mask"]) + setup_per_microbatch_replay_forward( + rollout_expert_indices, batch["attention_mask"], get_model_config(model) + ) sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) From b88b820c3a9f3ac8ecc1255aeb10e62d9c420318 Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Wed, 11 Mar 2026 23:23:31 +0000 Subject: [PATCH 17/18] move rank up --- skyrl/backends/skyrl_train/utils/replay_utils.py | 7 ++++--- .../backends/skyrl_train/gpu/gpu_ci/test_router_replay.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index 8122ca1ca8..13abeaf85f 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -130,12 +130,12 @@ def _get_current_pp_stage_layer_range(model_config) -> tuple[int, int]: from megatron.core.transformer.transformer_layer import get_transformer_layer_offset from megatron.core.transformer.transformer_block import get_num_layers_to_build + pp_rank = mpu.get_pipeline_model_parallel_rank() if get_num_layers_to_build is not None: return get_transformer_layer_offset(model_config), get_num_layers_to_build(model_config, pp_rank=pp_rank) pp_size = mpu.get_pipeline_model_parallel_world_size() - pp_rank = mpu.get_pipeline_model_parallel_rank() total_layers = model_config.num_layers first_stage_layers = getattr(model_config, "num_layers_in_first_pipeline_stage", None) @@ -222,6 +222,7 @@ def setup_per_microbatch_replay_forward( global_num_layers_in_data = len(per_layer_data) instances = RouterReplay.global_router_replay_instances num_instances = len(instances) + local_layer_offset, local_num_layers = _get_current_pp_stage_layer_range(model_config) if local_num_layers == 
num_instances: @@ -231,12 +232,12 @@ def setup_per_microbatch_replay_forward( # Dense-layer mismatch: map each MoE router to its global layer index. # Prefer the patched layer_number; fall back to offset-based mapping # (assumes dense layers precede MoE layers). - for i, router_instance in enumerate(instances): + for local_router_idx, router_instance in enumerate(instances): layer_number = getattr(router_instance, "layer_number", None) if layer_number is not None: layer_idx = layer_number - 1 # layer_number is 1-based else: - layer_idx = local_layer_offset + i + layer_idx = local_layer_offset + local_router_idx if layer_idx < 0 or layer_idx >= global_num_layers_in_data: raise ValueError( f"Router replay layer index {layer_idx} out of range " diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index 8e84ee61f2..bd0ed6465d 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -302,7 +302,7 @@ def test_logprobs(ray_init_fixture): cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 cfg.trainer.policy.megatron_config.context_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 + cfg.trainer.policy.megatron_config.expert_model_parallel_size = 2 cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 cfg.trainer.micro_forward_batch_size_per_gpu = 1 cfg.trainer.micro_train_batch_size_per_gpu = 1 From 4bbf22bcb1797187509583d15daf3f60954795df Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Mon, 16 Mar 2026 08:32:29 +0000 Subject: [PATCH 18/18] working CP and PP implementation --- .../skyrl_train/utils/replay_utils.py | 40 +++++++++++++--- .../gpu/gpu_ci/test_router_replay.py | 48 +++++++++++++------ 2 files changed, 66 insertions(+), 22 deletions(-) diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index 13abeaf85f..1065989f70 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -100,9 +100,11 @@ def _remove_left_padding_from_indices( seq_lens = attention_mask.sum(dim=1) effective_seq_len = seq_lens.max().item() - sp_world_size = mpu.get_tensor_model_parallel_world_size() - if sp_world_size > 1: - pad_size = (sp_world_size - effective_seq_len % sp_world_size) % sp_world_size + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size + if align_size > 1: + pad_size = (align_size - effective_seq_len % align_size) % align_size effective_seq_len += pad_size batch_size = rollout_expert_indices.shape[0] @@ -193,14 +195,26 @@ def setup_per_microbatch_replay_forward( """Set up RouterReplay for a single micro-batch, aligning indices with the left-padding-removed token layout that the MoE layer sees. + Handles context parallelism: when CP > 1, the sequence is split into + 2*cp_size chunks with each CP rank receiving a front chunk and a back + chunk (for causal-mask load balancing). Replay indices are split using + the same pattern so they stay aligned with the tokens each rank sees. + Handles sequence parallelism: when TP > 1, the sequence is split across TP ranks, so each rank's MoE router only sees its local chunk of tokens. 
Handles dense-layer mismatch: DeepSeek V3-style models have dense FFN - layers before the MoE layers. vLLM reports routing indices for ALL + layers before the MoE layers. vLLM reports routing indices for ALL transformer layers, but Megatron only has RouterReplay instances for MoE - layers. We use each instance's global layer_number (set by the patched + layers. We use each instance's global layer_number (set by the patched TopKRouter.set_layer_number) to index into the correct slice of the data. + + Handles pipeline parallelism: when PP > 1, the sequence is split across + PP ranks, so each rank only sees its local RouterReplay instances. In cases + where the number of local RouterReplay instances does not match the local + layer count, indicating that the model has dense layers before MoE layers, + we use the global layer_number to index into the correct slice of the data. + """ import megatron.core.parallel_state as mpu from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction @@ -209,8 +223,20 @@ def setup_per_microbatch_replay_forward( aligned = _remove_left_padding_from_indices(rollout_expert_indices, attention_mask) - # handles megatron sequence parallelism across the tensor model parallel region - # since we automatically enable sequence parallelism when TP > 1 + # CP splitting: mirror the front+back chunking from preprocess_packed_seqs + cp_size = mpu.get_context_parallel_world_size() + if cp_size > 1: + cp_rank = mpu.get_context_parallel_rank() + seq_len = aligned.shape[1] + seqlen_per_cp = seq_len // cp_size + half = seqlen_per_cp // 2 # we do *2 for causal masking, so get half of the sequence length per CP rank + front = aligned[:, half * cp_rank : half * (cp_rank + 1), :, :] + back_start = seq_len - half * (cp_rank + 1) + back_end = seq_len - half * cp_rank + back = aligned[:, back_start:back_end, :, :] + aligned = torch.cat([front, back], dim=1) + + # TP splitting: sequence parallelism across the tensor model parallel region tp_size = mpu.get_tensor_model_parallel_world_size() if tp_size > 1: tp_rank = mpu.get_tensor_model_parallel_rank() diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py index bd0ed6465d..707455eca5 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py @@ -23,12 +23,12 @@ from skyrl.train.dataset.preprocess import convert_prompts_responses_to_batch_tensors from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch -MOE_MODEL_NAME = "arcee-ai/Trinity-Nano-Preview" # MOE_MODEL_NAME = "/home/ray/moonlight16b" # MOE_MODEL_NAME = "Qwen/Qwen3-30B-A3B" +MOE_MODEL_NAME = "moonshotai/Moonlight-16B-A3B" REPLAY_NUM_LAYERS = 2 -NUM_PROMPTS = 10 -N_SAMPLES_PER_PROMPT = 4 +NUM_PROMPTS = 2 +N_SAMPLES_PER_PROMPT = 2 MAX_GENERATE_LENGTH = 128 @@ -189,7 +189,16 @@ def test_generate_with_router_replay(ray_init_fixture): @pytest.mark.megatron -def test_logprobs(ray_init_fixture): +@pytest.mark.parametrize( + "tp,pp,cp,ep,etp,extra_tf_kwargs", + [ + pytest.param(2, 1, 1, 2, 1, {}, id="baseline"), + pytest.param(2, 2, 1, 2, 1, {"num_layers_in_last_pipeline_stage": 13}, id="pp2"), + pytest.param(4, 1, 2, 8, 1, {}, id="cp2"), + pytest.param(2, 2, 2, 4, 1, {"num_layers_in_last_pipeline_stage": 13}, id="cp2_pp2"), + ], +) +def test_logprobs(ray_init_fixture, tp, pp, cp, ep, etp, extra_tf_kwargs): """ Check that logprob diff is lower when using router replay. 
Requires full 8xH100 setup to do full forward pass. """ @@ -299,17 +308,18 @@ def test_logprobs(ray_init_fixture): training_input.metadata = {"response_length": num_actions} cfg.trainer.placement.policy_num_gpus_per_node = 8 - cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 - cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 - cfg.trainer.policy.megatron_config.context_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_model_parallel_size = 2 - cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = tp + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = pp + cfg.trainer.policy.megatron_config.context_parallel_size = cp + cfg.trainer.policy.megatron_config.expert_model_parallel_size = ep + cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = etp cfg.trainer.micro_forward_batch_size_per_gpu = 1 cfg.trainer.micro_train_batch_size_per_gpu = 1 def run_megatron_forward(enable_replay: bool) -> torch.Tensor: cfg.trainer.policy.megatron_config.transformer_config_kwargs = { "moe_enable_routing_replay": enable_replay, + **extra_tf_kwargs, } actor_group = init_worker_with_type( "policy", @@ -347,7 +357,14 @@ def run_megatron_forward(enable_replay: bool) -> torch.Tensor: @pytest.mark.megatron -def test_forward_backward(ray_init_fixture): +@pytest.mark.parametrize( + "tp,pp,cp,ep,etp,extra_tf_kwargs", + [ + pytest.param(4, 1, 1, 8, 1, {}, id="baseline"), + pytest.param(2, 2, 1, 2, 1, {"num_layers_in_last_pipeline_stage": 13}, id="pp2"), + ], +) +def test_forward_backward(ray_init_fixture, tp, pp, cp, ep, etp, extra_tf_kwargs): """ Check that forward_backward produces similar losses with and without router replay (same weights, so routing decisions should nearly match). @@ -459,17 +476,18 @@ def test_forward_backward(ray_init_fixture): training_input.metadata = {"response_length": num_actions} cfg.trainer.placement.policy_num_gpus_per_node = 8 - cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 4 - cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 - cfg.trainer.policy.megatron_config.context_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 - cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = tp + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = pp + cfg.trainer.policy.megatron_config.context_parallel_size = cp + cfg.trainer.policy.megatron_config.expert_model_parallel_size = ep + cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = etp cfg.trainer.micro_forward_batch_size_per_gpu = 1 cfg.trainer.micro_train_batch_size_per_gpu = 1 def run_megatron_forward_backward(enable_replay: bool) -> dict: cfg.trainer.policy.megatron_config.transformer_config_kwargs = { "moe_enable_routing_replay": enable_replay, + **extra_tf_kwargs, } actor_group = init_worker_with_type( "policy",
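
The CP front+back chunking introduced above splits the sequence into 2*cp_size chunks, with rank r taking chunk r from the front and chunk 2*cp_size-1-r from the back, so the shards partition the tokens exactly. A standalone sketch of that slicing (cp_slice is an illustrative stand-in for the inline code in setup_per_microbatch_replay_forward, not repo code):

    import torch

    def cp_slice(aligned: torch.Tensor, cp_rank: int, cp_size: int) -> torch.Tensor:
        # aligned: [batch, seq, layers, topk]; seq assumed divisible by 2 * cp_size
        seq_len = aligned.shape[1]
        half = (seq_len // cp_size) // 2
        front = aligned[:, half * cp_rank : half * (cp_rank + 1)]
        back = aligned[:, seq_len - half * (cp_rank + 1) : seq_len - half * cp_rank]
        return torch.cat([front, back], dim=1)

    cp_size = 2
    x = torch.arange(8).reshape(1, 8, 1, 1)  # token ids 0..7 as stand-in indices
    shards = [cp_slice(x, r, cp_size) for r in range(cp_size)]
    assert shards[0].flatten().tolist() == [0, 1, 6, 7]  # rank 0: chunks 0 and 3
    assert shards[1].flatten().tolist() == [2, 3, 4, 5]  # rank 1: chunks 1 and 2
    # together the shards cover every token exactly once
    assert sorted(t for s in shards for t in s.flatten().tolist()) == list(range(8))
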