diff --git a/examples/train/gsm8k/run_gsm8k.sh b/examples/train/gsm8k/run_gsm8k.sh
index 2693571dac..e5dc9e20f3 100755
--- a/examples/train/gsm8k/run_gsm8k.sh
+++ b/examples/train/gsm8k/run_gsm8k.sh
@@ -26,8 +26,8 @@ uv run --isolated --extra fsdp -m skyrl.train.entrypoints.main_base \
   trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
   trainer.placement.critic_num_gpus_per_node=$NUM_GPUS \
   trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
-  generator.inference_engine.num_engines=$NUM_GPUS \
-  generator.inference_engine.tensor_parallel_size=1 \
+  generator.inference_engine.num_engines=1 \
+  generator.inference_engine.tensor_parallel_size=4 \
   trainer.epochs=20 \
   trainer.eval_batch_size=1024 \
   trainer.eval_before_train=true \
diff --git a/skyrl/backends/skyrl_train/inference_engines/base.py b/skyrl/backends/skyrl_train/inference_engines/base.py
index 4b073da5a0..819a071603 100644
--- a/skyrl/backends/skyrl_train/inference_engines/base.py
+++ b/skyrl/backends/skyrl_train/inference_engines/base.py
@@ -29,6 +29,7 @@ class InferenceEngineOutput(TypedDict):
     response_ids: List[List[int]]
     stop_reasons: List[str]
     response_logprobs: Optional[List[List[float]]]
+    rollout_expert_indices: Optional[List[List[List[List[int]]]]]  # [batch, seq_len, layer_num, topk]


 class InferenceEngineInterface(ABC):
@@ -63,6 +64,7 @@ async def sample(
         all_responses = []
         all_stop_reasons = []
         all_response_logprobs = []
+        all_rollout_expert_indices = []

         for _ in range(num_samples):
             input_batch: InferenceEngineInput = {
@@ -79,12 +81,15 @@
             all_stop_reasons.append(output["stop_reasons"][0])
             if output.get("response_logprobs") is not None:
                 all_response_logprobs.append(output["response_logprobs"][0])
+            if output.get("rollout_expert_indices") is not None:
+                all_rollout_expert_indices.append(output["rollout_expert_indices"][0])

         return {
             "response_ids": all_response_ids,
             "responses": all_responses,
             "stop_reasons": all_stop_reasons,
             "response_logprobs": all_response_logprobs if all_response_logprobs else None,
+            "rollout_expert_indices": all_rollout_expert_indices if all_rollout_expert_indices else None,
         }

     @abstractmethod
diff --git a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
index 02bad0d7be..d418089a72 100644
--- a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
+++ b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py
@@ -153,8 +153,10 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu
         stop_reasons: list[str] = [""] * n
         response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)]
         response_ids: List[List[int]] = [[] for _ in range(n)]
+        rollout_expert_indices: List[Optional[List[List[List[int]]]]] = [None for _ in range(n)]
         # a bit hacky for now
         add_resp_logprobs = False
+        add_rollout_expert_indices = False

         for indices, result in zip(indices_list, results):
             for local_idx, original_idx in enumerate(indices):
@@ -164,12 +166,16 @@
                 if result.get("response_logprobs", None):
                     add_resp_logprobs = True
                     response_logprobs[original_idx] = result["response_logprobs"][local_idx]
+                if result.get("rollout_expert_indices", None):
+                    add_rollout_expert_indices = True
+                    rollout_expert_indices[original_idx] = result["rollout_expert_indices"][local_idx]

         return InferenceEngineOutput(
             responses=responses,
             stop_reasons=stop_reasons,
             response_ids=response_ids,
response_logprobs=response_logprobs if add_resp_logprobs else None, + rollout_expert_indices=rollout_expert_indices if add_rollout_expert_indices else None, ) def _select_engine_idx(self, session_id: Optional[Union[str, int]] = None) -> int: @@ -265,6 +271,7 @@ async def _generate_single_with_retry( # 2. Initialize fields we want to accumulate or update in each loop iteration accum_response_ids: List[int] = [] accum_response_logprobs: List[float] = [] + accum_rollout_expert_indices: List[List[List[int]]] = [] stop_reason: str = "abort" # We only use it if generation is completed in one turn to maintain original behavior with no retry. @@ -300,6 +307,10 @@ async def _generate_single_with_retry( new_response_logprobs_list: Optional[List[List[float]]] = partial_response.get("response_logprobs", None) if new_response_logprobs_list is not None and len(new_response_logprobs_list) > 0: new_response_logprobs = new_response_logprobs_list[0] + new_rollout_expert_indices: Optional[List[List[List[int]]]] = None + new_rollout_expert_indices_list = partial_response.get("rollout_expert_indices", None) + if new_rollout_expert_indices_list is not None and len(new_rollout_expert_indices_list) > 0: + new_rollout_expert_indices = new_rollout_expert_indices_list[0] # 3.4 Aborted without generating tokens, so partial_response is useless. if stop_reason == "abort" and len(new_response_ids) == 0: @@ -309,6 +320,8 @@ async def _generate_single_with_retry( accum_response_ids.extend(new_response_ids) if new_response_logprobs is not None: accum_response_logprobs.extend(new_response_logprobs) + if new_rollout_expert_indices is not None: + accum_rollout_expert_indices.extend(new_rollout_expert_indices) num_turns += 1 # 4. Build the final response and return. @@ -321,6 +334,7 @@ async def _generate_single_with_retry( stop_reasons=[stop_reason], response_ids=[accum_response_ids], response_logprobs=[accum_response_logprobs] if len(accum_response_logprobs) > 0 else None, + rollout_expert_indices=([accum_rollout_expert_indices] if len(accum_rollout_expert_indices) > 0 else None), ) async def _chat_completion_with_retry( diff --git a/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py b/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py index 6e238f492f..0ec3c29ac3 100644 --- a/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py @@ -107,6 +107,7 @@ def create_ray_wrapped_inference_engines( rope_scaling: Dict[str, Any] = {}, rope_theta: float | None = None, enable_ray_prometheus_stats: bool = False, + enable_return_routed_experts: bool = False, served_model_name: str | None = None, ) -> List[InferenceEngineInterface]: """ @@ -249,6 +250,7 @@ def create_ray_wrapped_inference_engines( max_num_seqs=max_num_seqs, max_logprobs=1, # only need chosen-token logprobs enable_ray_prometheus_stats=enable_ray_prometheus_stats, + enable_return_routed_experts=enable_return_routed_experts, **dp_kwargs, **engine_init_kwargs, **lora_kwargs, diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py index 1123088cb6..aa62a74f7c 100644 --- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -135,6 +135,7 @@ def _postprocess_outputs(self, outputs): stop_reasons: List[str] = [] response_ids: List[List[int]] = [] 
        response_logprobs: Optional[List[List[float]]] = []
+        rollout_expert_indices: Optional[List[List[List[List[int]]]]] = []

         for output in outputs:
             # TODO(tgriggs): Support n>1 sampling.
@@ -156,14 +157,26 @@
             del token_logprobs
             response_logprobs.append(_logprobs)

+            _routed_experts = None
+            if resp.routed_experts is not None:
+                if hasattr(resp.routed_experts, "tolist"):
+                    _routed_experts = resp.routed_experts.tolist()
+                else:
+                    _routed_experts = resp.routed_experts
+            rollout_expert_indices.append(_routed_experts)
+
         if len(response_logprobs) and response_logprobs[0] is None:
             response_logprobs = None  # hack: assume uniform sampling params

+        if len(rollout_expert_indices) and rollout_expert_indices[0] is None:
+            rollout_expert_indices = None  # hack: assume uniform sampling params
+
         return InferenceEngineOutput(
             responses=responses,
             stop_reasons=stop_reasons,
             response_ids=response_ids,
             response_logprobs=response_logprobs,
+            rollout_expert_indices=rollout_expert_indices,
         )

     def _get_engine(self):
diff --git a/skyrl/backends/skyrl_train/training_batch.py b/skyrl/backends/skyrl_train/training_batch.py
index 839508aa3c..5295c6c4bc 100644
--- a/skyrl/backends/skyrl_train/training_batch.py
+++ b/skyrl/backends/skyrl_train/training_batch.py
@@ -369,6 +369,7 @@ class TrainingInput(TypedDict, total=False):
     kl: Float[torch.Tensor, "batch_size seq_len"]
     rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]]
     rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]]
+    rollout_expert_indices: Optional[Integer[torch.Tensor, "batch_size seq_len layer_num topk"]]


 class TrainingInputBatch(TensorBatch[TrainingInput]):
diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py
new file mode 100644
index 0000000000..1065989f70
--- /dev/null
+++ b/skyrl/backends/skyrl_train/utils/replay_utils.py
@@ -0,0 +1,283 @@
+"""
+Utility functions for MoE Router Replay.
+"""
+
+import torch
+from typing import List, Optional
+
+
+def _patch_topk_router_layer_number():
+    """Monkey-patch TopKRouter.set_layer_number to propagate the global layer
+    number to the RouterReplay instance.
+
+    DeepSeek V3 (and similar) architectures have dense FFN layers before the MoE
+    layers. vLLM reports routing indices for ALL transformer layers (including
+    dense), but Megatron only creates RouterReplay instances for MoE layers.
+    Storing the global layer_number on each RouterReplay instance lets us map
+    vLLM's per-layer data to the correct MoE router even when dense layers are
+    present.
+
+    Must be called BEFORE model creation (i.e. before make_megatron_module).
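+
+    For illustration (hypothetical layout): with 1 dense layer followed by 3 MoE
+    layers, Megatron creates 3 RouterReplay instances whose routers receive
+    layer_number 2, 3 and 4, while vLLM's replay data is indexed 0..3 over ALL
+    layers, so instance i must read slice layer_number - 1 = i + 1, not slice i.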
+ """ + try: + from megatron.core.transformer.moe.router import TopKRouter + except ImportError: + return + + if getattr(TopKRouter, "_set_layer_number_patched", False): + return + + original_set_layer_number = TopKRouter.set_layer_number + + def patched_set_layer_number(self, layer_number: int): + original_set_layer_number(self, layer_number) + if self.router_replay is not None: + self.router_replay.layer_number = layer_number + + TopKRouter.set_layer_number = patched_set_layer_number + TopKRouter._set_layer_number_patched = True + + +def _patch_alltoall_dispatcher_for_replay(): + """Monkey-patch MoEAlltoAllTokenDispatcher.preprocess to handle router replay. + + When router replay is enabled, duplicate indices in top_indices can cause + routing_map.sum() < num_tokens * topk, leading to a split size mismatch + in the alltoall collective. We fix this by deriving num_out_tokens from + the routing map instead of the static num_tokens * topk formula. + + Reference: https://github.com/verl-project/verl/pull/4986 + """ + try: + from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher + except ImportError: + return + + if getattr(MoEAlltoAllTokenDispatcher, "_preprocess_patched", False): + return + + original_preprocess = MoEAlltoAllTokenDispatcher.preprocess + + def patched_preprocess(self, routing_map): + result = original_preprocess(self, routing_map) + if ( + getattr(self.config, "moe_enable_routing_replay", False) + and not self.drop_and_pad + and self.config.moe_expert_capacity_factor is None + and not self.config.moe_router_padding_for_quantization + ): + self.num_out_tokens = int(routing_map.sum().item()) + return result + + MoEAlltoAllTokenDispatcher.preprocess = patched_preprocess + MoEAlltoAllTokenDispatcher._preprocess_patched = True + + +def _split_replay_indices(rollout_expert_indices: torch.Tensor) -> List[torch.Tensor]: + if rollout_expert_indices is None: + return None + if rollout_expert_indices.dim() != 4: + raise ValueError(f"Expected 4D replay indices, got shape {rollout_expert_indices.shape}") + per_layer = rollout_expert_indices.permute(2, 0, 1, 3).contiguous() + # flatten [batch, seq, topk] to [batch * seq, topk] for each layer + return [per_layer[i].reshape(-1, per_layer.shape[-1]) for i in range(per_layer.shape[0])] + + +def _remove_left_padding_from_indices( + rollout_expert_indices: torch.Tensor, + attention_mask: torch.Tensor, +) -> torch.Tensor: + """Apply the same left-padding removal as remove_left_padding to routing indices. + + Args: + rollout_expert_indices: [batch, padded_seq_len, layers, topk] + attention_mask: [batch, padded_seq_len] (int or bool) + + Returns: + [batch, effective_seq_len, layers, topk] with real tokens packed left. 
+ """ + import megatron.core.parallel_state as mpu + + seq_lens = attention_mask.sum(dim=1) + effective_seq_len = seq_lens.max().item() + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size + if align_size > 1: + pad_size = (align_size - effective_seq_len % align_size) % align_size + effective_seq_len += pad_size + + batch_size = rollout_expert_indices.shape[0] + new_rii = torch.zeros( + batch_size, + effective_seq_len, + rollout_expert_indices.shape[2], + rollout_expert_indices.shape[3], + dtype=rollout_expert_indices.dtype, + device=rollout_expert_indices.device, + ) + for i in range(batch_size): + mask = attention_mask[i].bool() + new_rii[i, : seq_lens[i]] = rollout_expert_indices[i, mask] + return new_rii + + +def _get_current_pp_stage_layer_range(model_config) -> tuple[int, int]: + """Return the current PP rank's transformer-layer range. + + Prefer Megatron's own helpers so replay indexing stays aligned with the + actual model partition, including embedding/loss pipeline accounting. + """ + import megatron.core.parallel_state as mpu + from megatron.core.transformer.transformer_layer import get_transformer_layer_offset + from megatron.core.transformer.transformer_block import get_num_layers_to_build + + pp_rank = mpu.get_pipeline_model_parallel_rank() + + if get_num_layers_to_build is not None: + return get_transformer_layer_offset(model_config), get_num_layers_to_build(model_config, pp_rank=pp_rank) + + pp_size = mpu.get_pipeline_model_parallel_world_size() + + total_layers = model_config.num_layers + first_stage_layers = getattr(model_config, "num_layers_in_first_pipeline_stage", None) + last_stage_layers = getattr(model_config, "num_layers_in_last_pipeline_stage", None) + + if pp_size <= 1: + return 0, total_layers + + if first_stage_layers is None and last_stage_layers is None: + assert total_layers % pp_size == 0, ( + "For even pipelineing, num_layers should be divisible by pipeline_model_parallel_size" + ) + pp_layers = total_layers // pp_size + return pp_rank * pp_layers, pp_layers + + next_n_pp_layers = total_layers + next_n_pp_stages = pp_size + + if first_stage_layers is not None: + next_n_pp_layers -= first_stage_layers + next_n_pp_stages -= 1 + + if last_stage_layers is not None: + next_n_pp_layers -= last_stage_layers + next_n_pp_stages -= 1 + + if next_n_pp_stages > 0: + assert next_n_pp_layers % next_n_pp_stages == 0, ( + "Uneven pipelineing, not divisible by remaining pipeline stages" + ) + next_n_pp_layers = next_n_pp_layers // next_n_pp_stages + else: + next_n_pp_layers = 0 + + if pp_rank == 0 and first_stage_layers is not None: + return 0, first_stage_layers + + if pp_rank == pp_size - 1 and last_stage_layers is not None: + if first_stage_layers is not None: + start = first_stage_layers + (next_n_pp_layers * (pp_size - 2)) + else: + start = next_n_pp_layers * (pp_size - 1) + return start, last_stage_layers + + if first_stage_layers is not None: + return first_stage_layers + (next_n_pp_layers * (pp_rank - 1)), next_n_pp_layers + return next_n_pp_layers * pp_rank, next_n_pp_layers + + +def setup_per_microbatch_replay_forward( + rollout_expert_indices: torch.Tensor, + attention_mask: torch.Tensor, + model_config, +) -> None: + """Set up RouterReplay for a single micro-batch, aligning indices + with the left-padding-removed token layout that the MoE layer sees. 
+
+    Handles context parallelism: when CP > 1, the sequence is split into
+    2*cp_size chunks with each CP rank receiving a front chunk and a back
+    chunk (for causal-mask load balancing). Replay indices are split using
+    the same pattern so they stay aligned with the tokens each rank sees.
+
+    Handles sequence parallelism: when TP > 1, the sequence is split across
+    TP ranks, so each rank's MoE router only sees its local chunk of tokens.
+
+    Handles dense-layer mismatch: DeepSeek V3-style models have dense FFN
+    layers before the MoE layers. vLLM reports routing indices for ALL
+    transformer layers, but Megatron only has RouterReplay instances for MoE
+    layers. We use each instance's global layer_number (set by the patched
+    TopKRouter.set_layer_number) to index into the correct slice of the data.
+
+    Handles pipeline parallelism: when PP > 1, the model's layers are split
+    across PP ranks, so each rank only sees its local RouterReplay instances.
+    When the number of local RouterReplay instances does not match the local
+    layer count (i.e. the model has dense layers before the MoE layers), we
+    use the global layer_number to index into the correct slice of the data.
+    """
+    import megatron.core.parallel_state as mpu
+    from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
+
+    _patch_alltoall_dispatcher_for_replay()
+
+    aligned = _remove_left_padding_from_indices(rollout_expert_indices, attention_mask)
+
+    # CP splitting: mirror the front+back chunking from preprocess_packed_seqs
+    cp_size = mpu.get_context_parallel_world_size()
+    if cp_size > 1:
+        cp_rank = mpu.get_context_parallel_rank()
+        seq_len = aligned.shape[1]
+        seqlen_per_cp = seq_len // cp_size
+        half = seqlen_per_cp // 2  # each CP rank takes two half-chunks (front + back) for causal-mask balancing
+        front = aligned[:, half * cp_rank : half * (cp_rank + 1), :, :]
+        back_start = seq_len - half * (cp_rank + 1)
+        back_end = seq_len - half * cp_rank
+        back = aligned[:, back_start:back_end, :, :]
+        aligned = torch.cat([front, back], dim=1)

+    # TP splitting: sequence parallelism across the tensor model parallel region
+    tp_size = mpu.get_tensor_model_parallel_world_size()
+    if tp_size > 1:
+        tp_rank = mpu.get_tensor_model_parallel_rank()
+        seq_len = aligned.shape[1]
+        chunk_size = seq_len // tp_size
+        aligned = aligned[:, tp_rank * chunk_size : (tp_rank + 1) * chunk_size, :, :]
+
+    per_layer_data = _split_replay_indices(aligned)
+    global_num_layers_in_data = len(per_layer_data)
+    instances = RouterReplay.global_router_replay_instances
+    num_instances = len(instances)
+
+    local_layer_offset, local_num_layers = _get_current_pp_stage_layer_range(model_config)
+
+    if local_num_layers == num_instances:
+        local_per_layer_data = per_layer_data[local_layer_offset : local_layer_offset + local_num_layers]
+        RouterReplay.set_replay_data(local_per_layer_data)
+    else:
+        # Dense-layer mismatch: map each MoE router to its global layer index.
+        # Prefer the patched layer_number; fall back to offset-based mapping
+        # (assumes dense layers precede MoE layers).
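+        # e.g. a router whose layer_number is 2 (the first MoE layer after one
+        # dense layer) reads per_layer_data[1].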
+ for local_router_idx, router_instance in enumerate(instances): + layer_number = getattr(router_instance, "layer_number", None) + if layer_number is not None: + layer_idx = layer_number - 1 # layer_number is 1-based + else: + layer_idx = local_layer_offset + local_router_idx + if layer_idx < 0 or layer_idx >= global_num_layers_in_data: + raise ValueError( + f"Router replay layer index {layer_idx} out of range " + f"for data with {global_num_layers_in_data} layers " + f"({num_instances} router instances)" + ) + router_instance.set_target_indices(per_layer_data[layer_idx]) + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + +def clear_router_replay(): + """Clear all router replay state.""" + from megatron.core.transformer.moe.router_replay import RouterReplay + + RouterReplay.clear_global_indices() + RouterReplay.clear_global_router_replay_action() + RouterReplay.clear_global_router_replay_instances() diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 5ab565f989..50ecb2fa90 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -24,6 +24,7 @@ remove_left_padding, recover_left_padding, ) +from skyrl.backends.skyrl_train.utils.replay_utils import setup_per_microbatch_replay_forward class MegatronModelWrapper: @@ -103,6 +104,13 @@ def collection_func(logits, data): def forward_step(batch_iter, model): batch = next(batch_iter) + + rollout_expert_indices = batch.pop("rollout_expert_indices", None) + if rollout_expert_indices is not None: + setup_per_microbatch_replay_forward( + rollout_expert_indices, batch["attention_mask"], get_model_config(model) + ) + sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) position_ids = batch["position_ids"] @@ -355,6 +363,12 @@ def loss_func(logits, data): def forward_step(batch_iter, model): batch = next(batch_iter) + rollout_expert_indices = batch.pop("rollout_expert_indices", None) + if rollout_expert_indices is not None: + setup_per_microbatch_replay_forward( + rollout_expert_indices, batch["attention_mask"], get_model_config(model) + ) + sequences = batch["sequences"] attention_mask = batch["attention_mask"].to(bool) position_ids = batch["position_ids"] diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index e80ad1c85a..ec67b91897 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -277,13 +277,11 @@ def init_configs( override_config_kwargs.update(model_config_kwargs.get("model_config", {})) update_model_config(hf_config, override_config_kwargs=override_config_kwargs) - # if flash_attn is enabled, we use flash attention backend, otherwise fall back to fused attention backend transformer_config_kwargs = ( transformer_config_kwargs if isinstance(transformer_config_kwargs, dict) else OmegaConf.to_container(transformer_config_kwargs, resolve=True) ) - transformer_config_kwargs["attention_backend"] = "flash" if flash_attn else "fused" if not self.cfg.gradient_checkpointing: for key in ("recompute_granularity", "recompute_method", "recompute_num_layers"): @@ -322,6 +320,7 @@ def init_configs( self.strategy.hf_config = hf_config self.tokenizer = tokenizer + self.enable_router_replay = 
transformer_config_kwargs.get("moe_enable_routing_replay", False) def configure_lora(self, lora_config, lora_type: Optional[str] = "lora"): if lora_type == "lora": @@ -401,6 +400,8 @@ def forward(self, data: TrainingInputBatch): """ Override `Worker.forward` to support passing the full mini batch to the MegatronModelWrapper.forward method. """ + from skyrl.backends.skyrl_train.utils.replay_utils import clear_router_replay + # Run in micro batches grouped into a single mini-batch micro_bsz = self.cfg.micro_forward_batch_size_per_gpu micro_batches = data.chunk(micro_bsz) @@ -421,6 +422,9 @@ def forward(self, data: TrainingInputBatch): "attention_mask": attention_mask, "position_ids": position_ids, "num_actions": num_actions, + "rollout_expert_indices": ( + micro.get("rollout_expert_indices") if self.enable_router_replay else None + ), } ) @@ -438,6 +442,7 @@ def forward(self, data: TrainingInputBatch): log_probs = log_probs.to("cpu") output = TrainingOutputBatch({"output": log_probs}) output.metadata = data.metadata + clear_router_replay() return output def save_hf_model(self, export_dir: str, tokenizer): @@ -527,6 +532,11 @@ def init_model(self, model_path, num_training_steps: int = 1e9): flash_attn=self.cfg.flash_attn, ) + if self.enable_router_replay: + from skyrl.backends.skyrl_train.utils.replay_utils import _patch_topk_router_layer_number + + _patch_topk_router_layer_number() + # wrap with DDP for training self.actor_module = self.make_megatron_module( wrap_with_ddp=True, @@ -593,6 +603,8 @@ def forward_backward( Returns: Aggregated metrics dict across all micro batches """ + from skyrl.backends.skyrl_train.utils.replay_utils import clear_router_replay + self.model.train() for chunk in self.actor_module: # if use distributed optimizer, zero grad buffer will be handled by optimizer @@ -624,6 +636,7 @@ def forward_backward( "loss_mask": experience.loss_mask, "rollout_action_logprobs": experience.rollout_logprobs, "action_mask": experience.action_mask, + "rollout_expert_indices": experience.rollout_expert_indices if self.enable_router_replay else None, } ) @@ -666,6 +679,8 @@ def forward_backward( if all_loss_fn_outputs: status["loss_fn_outputs"] = all_loss_fn_outputs + clear_router_replay() + return status def optim_step(self) -> Optional[float]: diff --git a/skyrl/backends/skyrl_train/workers/worker_utils.py b/skyrl/backends/skyrl_train/workers/worker_utils.py index eb76c5ee7d..b8ad1b09fb 100644 --- a/skyrl/backends/skyrl_train/workers/worker_utils.py +++ b/skyrl/backends/skyrl_train/workers/worker_utils.py @@ -83,6 +83,7 @@ def batch_to_experience(batch: TrainingInputBatch): action_mask=batch.get("response_mask"), num_actions=batch.metadata["response_length"], # int rollout_logprobs=batch.get("rollout_logprobs"), + rollout_expert_indices=batch.get("rollout_expert_indices"), # additional info # can be used to log metrics etc for micro-batches in the worker info={}, diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index effaa7c6cc..4c4f749670 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -440,6 +440,7 @@ class InferenceEngineConfig(BaseConfig): """Sets ``VLLM_ENABLE_V1_MULTIPROCESSING=0`` for reproducibility.""" enable_prefix_caching: bool = True enable_chunked_prefill: bool = True + enable_return_routed_experts: bool = False max_num_batched_tokens: int = 8192 enforce_eager: bool = True """Disable CUDA graphs for stability. 
    Set to ``False`` for higher performance,
diff --git a/skyrl/train/config/megatron_config/policy.yaml b/skyrl/train/config/megatron_config/policy.yaml
index d6ef8382f9..b641d1f706 100644
--- a/skyrl/train/config/megatron_config/policy.yaml
+++ b/skyrl/train/config/megatron_config/policy.yaml
@@ -57,6 +57,7 @@ transformer_config_kwargs:
   recompute_modules: ["core_attn"]
   recompute_method: uniform
   recompute_num_layers: 1
+  moe_enable_routing_replay: ${generator.inference_engine.enable_return_routed_experts}

 # flag to manually empty torch's cuda cache between the forward/backward pass and the optimizer step
 # this will free reserved but unallocated memory, and can help avoid OoMs in the optimizer
diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml
index 8f4184b18d..9772eb5e9b 100644
--- a/skyrl/train/config/ppo_base_config.yaml
+++ b/skyrl/train/config/ppo_base_config.yaml
@@ -286,6 +286,7 @@ generator:
     vllm_v1_disable_multiproc: true
     enable_prefix_caching: true
     enable_chunked_prefill: true
+    enable_return_routed_experts: false
     max_num_batched_tokens: 8192
     # Disable CUDA graphs by default for stability. Set to false for higher performance, but this may affect convergence for long-running and/or long context training jobs.
     enforce_eager: true
diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py
index a17cea1b91..67907e977f 100644
--- a/skyrl/train/dataset/preprocess.py
+++ b/skyrl/train/dataset/preprocess.py
@@ -1,7 +1,7 @@
 from typing import List, Tuple, Optional
 import torch
 from transformers import AutoTokenizer
-from jaxtyping import Float
+from jaxtyping import Float, Integer


 def _verify_inputs(
@@ -32,6 +32,7 @@ def convert_prompts_responses_to_batch_tensors(
     rewards: List[List[float]],
     loss_masks: List[List[int]],
     logprobs: Optional[List[List[float]]] = None,
+    rollout_expert_indices: Optional[List[List[List[List[int]]]]] = None,
 ) -> Tuple[
     Float[torch.Tensor, "batch seq_len"],
     Float[torch.Tensor, "batch seq_len"],
@@ -39,6 +40,7 @@
     Float[torch.Tensor, "batch response_len"],
     Float[torch.Tensor, "batch response_len"],
     Optional[Float[torch.Tensor, "batch response_len"]],
+    Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]],
 ]:
     """
     Convert prompts and responses to batch tensors for training.
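+
+    The new optional rollout_expert_indices argument uses the layout
+    [batch][seq_len][layer_num][topk]; per sample it is left-padded to
+    max_input_len with zero-filled entries (matching the other tensors) and is
+    returned as the final tuple element, or None when absent.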
@@ -129,4 +131,27 @@ def convert_prompts_responses_to_batch_tensors( ] logprobs_tensor = torch.tensor(padded_logprobs, dtype=torch.float) - return sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, logprobs_tensor + rollout_expert_indices_tensor = None + if rollout_expert_indices: + first_non_empty = next((x for x in rollout_expert_indices if x), None) + if first_non_empty: + total_seq_len = max_input_len + max_output_len + num_layers = len(first_non_empty[0]) + topk = len(first_non_empty[0][0]) if num_layers > 0 else 0 + padded = torch.zeros(len(rollout_expert_indices), total_seq_len, num_layers, topk, dtype=torch.int32) + for i, sample_indices in enumerate(rollout_expert_indices): + if sample_indices: + left_pad = max_input_len - prompt_token_lens[i] + n = min(len(sample_indices), total_seq_len - left_pad) + padded[i, left_pad : left_pad + n] = torch.tensor(sample_indices[:n], dtype=torch.int32) + rollout_expert_indices_tensor = padded + + return ( + sequences, + attention_mask, + action_mask, + ret_rewards, + ret_loss_masks, + logprobs_tensor, + rollout_expert_indices_tensor, + ) diff --git a/skyrl/train/dataset/replay_buffer.py b/skyrl/train/dataset/replay_buffer.py index 07846937a6..072c65fdf7 100644 --- a/skyrl/train/dataset/replay_buffer.py +++ b/skyrl/train/dataset/replay_buffer.py @@ -66,6 +66,7 @@ class Experience: loss_mask: Optional[Integer[torch.LongTensor, "batch response_len"]] action_mask: Optional[Integer[torch.Tensor, "batch response_len"]] rollout_logprobs: Optional[Float[torch.Tensor, "batch response_len"]] + rollout_expert_indices: Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]] num_actions: int info: Optional[dict] kl: Optional[Float[torch.Tensor, "batch response_len"]] = None @@ -92,6 +93,8 @@ def to_device(self, device: torch.device) -> None: self.action_mask = to(self.action_mask, device) if self.rollout_logprobs is not None: self.rollout_logprobs = to(self.rollout_logprobs, device) + if self.rollout_expert_indices is not None: + self.rollout_expert_indices = to(self.rollout_expert_indices, device) def pin_memory(self): self.sequences = pin_memory(self.sequences) @@ -113,6 +116,8 @@ def pin_memory(self): self.action_mask = self.action_mask.pin_memory() if self.rollout_logprobs is not None: self.rollout_logprobs = self.rollout_logprobs.pin_memory() + if self.rollout_expert_indices is not None: + self.rollout_expert_indices = self.rollout_expert_indices.pin_memory() return self diff --git a/skyrl/train/entrypoints/main_base.py b/skyrl/train/entrypoints/main_base.py index 10e78e9de3..e3c2b980f1 100644 --- a/skyrl/train/entrypoints/main_base.py +++ b/skyrl/train/entrypoints/main_base.py @@ -68,6 +68,7 @@ def create_ray_wrapped_inference_engines_from_config( "backend": ie_cfg.backend, "engine_init_kwargs": ie_cfg.engine_init_kwargs, "enable_ray_prometheus_stats": ie_cfg.enable_ray_prometheus_stats, + "enable_return_routed_experts": ie_cfg.enable_return_routed_experts, } # Conditionally add LoRA parameters if LoRA is enabled diff --git a/skyrl/train/generators/base.py b/skyrl/train/generators/base.py index fabe5524e8..49b3ecc1ac 100644 --- a/skyrl/train/generators/base.py +++ b/skyrl/train/generators/base.py @@ -39,6 +39,7 @@ class GeneratorOutput(TypedDict): rollout_metrics: Optional[Dict[str, Any]] rollout_logprobs: Optional[List[List[float]]] trajectory_ids: Optional[List[TrajectoryID]] + rollout_expert_indices: Optional[List[List[List[List[int]]]]] # [batch_size, seq_len, layer_num, topk] # Applicable only for step-wise training 
 is_last_step: Optional[List[bool]]
diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py
index 35c7e189f3..cbad99eb31 100644
--- a/skyrl/train/generators/skyrl_gym_generator.py
+++ b/skyrl/train/generators/skyrl_gym_generator.py
@@ -40,6 +40,7 @@ class TrajectoryOutput:
     prompt_ids: List[int]
     rollout_logprobs: Optional[List[float]]
     env_metrics: Dict[str, Any]
+    rollout_expert_indices: Optional[List[List[List[int]]]] = None


 @dataclass
@@ -57,6 +58,7 @@ class AgentLoopState:
     rollout_logprobs: Optional[List[float]]
     response_end_idx: Optional[int]
     done: bool
+    rollout_expert_indices: Optional[List[List[List[int]]]] = None


 @dataclass
@@ -66,9 +68,31 @@ class TurnOutput:
     output_logprobs: Optional[List[float]]
     new_obs: ConversationType
     obs_ids: List[int]
+    rollout_expert_indices: Optional[List[List[List[int]]]]  # [seq_len, layer_num, topk]
     reward: Optional[float]
     added_eos: bool = False

+    def get_turn_rollout_expert_indices(self) -> Optional[List[List[List[int]]]]:
+        """
+        Get rollout expert indices for this turn's tokens (output + observation).
+
+        Returns indices for generated output tokens, with padding entries (all -1)
+        for any manually-added EOS token and observation tokens.
+        Returns None if rollout_expert_indices is None.
+        """
+        if self.rollout_expert_indices is None:
+            return None
+        if not self.rollout_expert_indices:
+            return self.rollout_expert_indices
+        layer_num = len(self.rollout_expert_indices[0])
+        topk = len(self.rollout_expert_indices[0][0]) if layer_num > 0 else 0
+        pad_entry = [[-1] * topk for _ in range(layer_num)]
+        indices = list(self.rollout_expert_indices)
+        if self.added_eos:
+            indices.append(pad_entry)
+        indices.extend(pad_entry for _ in range(len(self.obs_ids)))
+        return indices
+
     def get_turn_loss_mask(self) -> List[int]:
         """
         Get loss mask for this turn's tokens.
@@ -300,11 +324,14 @@ async def agent_loop(
             output_ids = engine_output["response_ids"][0]
             stop_reason = engine_output["stop_reasons"][0]
             response_logprobs = engine_output.get("response_logprobs", None)
+            rollout_expert_indices = engine_output.get("rollout_expert_indices", None)
             if response_logprobs is not None:
                 response_logprobs = response_logprobs[0]
                 if self.custom_chat_template is not None:
                     raise ValueError("Response Logprobs bookkeeping is not supported with custom chat template")
+            if rollout_expert_indices is not None:
+                rollout_expert_indices = rollout_expert_indices[0]

             # Append eos when sampling_params.stop is not None. Does not affect 3.a as chat templates add eos_token.
# sampling_params is not None for eval, but None for training (which uses engine.sampling_params which are from cfg) stop_strs = current_sampling_params.get("stop", None) @@ -348,8 +375,12 @@ async def agent_loop( reward=step_reward, obs_ids=obs_ids, added_eos=added_eos, + rollout_expert_indices=rollout_expert_indices, ) + if turn_output.rollout_expert_indices is not None and agent_loop_state.rollout_expert_indices is None: + agent_loop_state.rollout_expert_indices = [] + if is_step_wise: # current response + observation ids turn_response_ids = turn_output.output_ids + turn_output.obs_ids @@ -367,6 +398,7 @@ async def agent_loop( rollout_logprobs=turn_response_logprobs, stop_reason=stop_reason, env_metrics=env.get_metrics() if agent_loop_state.done else {}, + rollout_expert_indices=turn_output.get_turn_rollout_expert_indices(), ) agent_loop_output.step_outputs.append(per_step_output) @@ -395,6 +427,7 @@ async def agent_loop( prompt_ids = agent_loop_state.input_ids[:initial_prompt_length] rollout_logprobs = None + rollout_expert_indices_out = None response_ids = None # Prepare the final loss_mask, response_ids and rollout_logprobs . @@ -425,6 +458,10 @@ async def agent_loop( rollout_logprobs = agent_loop_state.rollout_logprobs[ : agent_loop_state.response_end_idx - initial_prompt_length + 1 ] + if agent_loop_state.rollout_expert_indices is not None: + rollout_expert_indices_out = agent_loop_state.rollout_expert_indices[ + : agent_loop_state.response_end_idx + 1 + ] # fix index for per_step_rewards per_step_rewards = [(reward, idx - initial_prompt_length) for reward, idx in per_step_rewards] assert len(loss_mask) == len( @@ -441,6 +478,10 @@ async def agent_loop( loss_mask.append(1) if rollout_logprobs is not None: rollout_logprobs.append(0.0) + if rollout_expert_indices_out is not None and rollout_expert_indices_out: + layer_num = len(rollout_expert_indices_out[0]) + topk = len(rollout_expert_indices_out[0][0]) if layer_num > 0 else 0 + rollout_expert_indices_out.append([[-1] * topk for _ in range(layer_num)]) appended_eos_token = True if self.generator_cfg.step_wise_trajectories: @@ -460,6 +501,7 @@ async def agent_loop( prompt_ids=prompt_ids, rollout_logprobs=rollout_logprobs, env_metrics=env_metrics, + rollout_expert_indices=rollout_expert_indices_out, ) return agent_loop_output @@ -611,12 +653,14 @@ async def generate_batched( responses = engine_output["response_ids"] stop_reasons = engine_output["stop_reasons"] logprobs = engine_output.get("response_logprobs", None) + raw_rollout_expert_indices = engine_output.get("rollout_expert_indices", None) truncated_responses = [] rewards = [] loss_masks = [] env_metrics = [] truncated_logprobs: Optional[List[List[float]]] = [] if logprobs is not None else None + truncated_indices: Optional[List] = [] if raw_rollout_expert_indices is not None else None for i, (output, response, env, env_class) in enumerate(zip(outputs, responses, envs, env_classes)): # step on environment and compute reward @@ -631,6 +675,8 @@ async def generate_batched( if logprobs is not None: sample_logprobs = logprobs[i][: len(response)] truncated_logprobs.append(sample_logprobs) + if raw_rollout_expert_indices is not None: + truncated_indices.append(raw_rollout_expert_indices[i]) # Get environment-specific metrics env_metrics.append(env.get_metrics()) @@ -650,6 +696,7 @@ async def generate_batched( "stop_reasons": stop_reasons, "rollout_metrics": rollout_metrics, "rollout_logprobs": truncated_logprobs, + "rollout_expert_indices": truncated_indices, } return generator_output 
@@ -750,6 +797,15 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False
         else:
             rollout_logprobs = None

+        if self.generator_cfg.step_wise_trajectories:
+            all_indices = sum(
+                [[step_output.rollout_expert_indices for step_output in output.step_outputs] for output in all_outputs],
+                [],
+            )
+        else:
+            all_indices = [output.rollout_expert_indices for output in all_outputs]
+        rollout_expert_indices = all_indices if any(idx is not None for idx in all_indices) else None
+
         rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes)

         if self.generator_cfg.zero_reward_on_non_stop:
@@ -768,6 +824,7 @@
             "rollout_metrics": rollout_metrics,
             "rollout_logprobs": rollout_logprobs,
             "trajectory_ids": out_trajectory_ids,
+            "rollout_expert_indices": rollout_expert_indices,
             "is_last_step": is_last_step,
         }

@@ -835,6 +892,8 @@ def _update_agent_state_by_retokenizing_chat_history(
         agent_loop_state.response_end_idx = None
         # `logprobs` are not computed because retokenizing breaks token-in-token-out
         agent_loop_state.rollout_logprobs = None
+        # expert indices are not meaningful when retokenizing
+        agent_loop_state.rollout_expert_indices = None
         return agent_loop_state

     def _update_agent_loop_state_with_multiturn_chat_template(
@@ -886,12 +945,15 @@
         loss_mask_for_turn = turn_output.get_turn_loss_mask()
         rollout_logprobs_for_turn = turn_output.get_turn_rollout_logprobs()
+        rollout_expert_indices_for_turn = turn_output.get_turn_rollout_expert_indices()
+
         if self.generator_cfg.step_wise_trajectories:
             # cumulative input_ids is not tracked for step wise training
             agent_loop_state.response_end_idx = len(turn_output.output_ids) - 1
-            # no running loss_mask or `rollout_logprobs` are tracked for step-wise training
+            # no running loss_mask, `rollout_logprobs`, or `rollout_expert_indices` are tracked for step-wise training
             agent_loop_state.loss_mask = None
             agent_loop_state.rollout_logprobs = None
+            agent_loop_state.rollout_expert_indices = None
         else:
             # Directly append turn output
             turn_ids = turn_output.output_ids + turn_output.obs_ids
@@ -900,6 +962,10 @@
             agent_loop_state.loss_mask += loss_mask_for_turn
             if agent_loop_state.rollout_logprobs is not None and rollout_logprobs_for_turn is not None:
                 agent_loop_state.rollout_logprobs += rollout_logprobs_for_turn
+            if agent_loop_state.rollout_expert_indices is not None and rollout_expert_indices_for_turn is not None:
+                # overwrite the existing rollout expert indices, since the inference engine should
+                # return the expert indices for the entire sequence including each turn's input
+                agent_loop_state.rollout_expert_indices = rollout_expert_indices_for_turn

         return agent_loop_state

@@ -964,11 +1030,21 @@ def _update_agent_loop_state_with_singleturn_chat_template(
             obs_ids_to_add
         )

         # Directly append turn output
         agent_loop_state.response_end_idx = len(agent_loop_state.input_ids) + len(new_resp_tokens) - 1
         agent_loop_state.input_ids += turn_ids
         agent_loop_state.loss_mask += loss_mask_for_turn
         if agent_loop_state.rollout_logprobs is not None and rollout_logprobs_for_turn is not None:
             agent_loop_state.rollout_logprobs += rollout_logprobs_for_turn
+        if self.generator_cfg.inference_engine.enable_return_routed_experts and turn_output.rollout_expert_indices is not None:
+            agent_loop_state.rollout_expert_indices = turn_output.rollout_expert_indices

         return agent_loop_state
diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py
index ada3afbc5c..ef73ab11b5 100644
--- a/skyrl/train/trainer.py
+++ b/skyrl/train/trainer.py
@@ -604,6 +604,9 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
         loss_masks: List[List[int]] = generator_output["loss_masks"]
         logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None)

+        rollout_expert_indices: Optional[List[List[List[List[int]]]]] = generator_output.get(
+            "rollout_expert_indices", None
+        )
         (
             sequences_tensor,
             attention_mask_tensor,
@@ -612,6 +615,7 @@
             rewards_tensor,
             loss_masks_tensor,
             rollout_logprobs_tensor,
+            rollout_expert_indices_tensor,
         ) = convert_prompts_responses_to_batch_tensors(
             self.tokenizer,
             prompt_ids,
@@ -619,6 +623,7 @@
             rewards,
             loss_masks,
             logprobs,
+            rollout_expert_indices,
         )

         # sanity check for off_policy_correction
@@ -639,6 +644,7 @@
             "rewards": rewards_tensor,
             "loss_mask": loss_masks_tensor,
             "rollout_logprobs": rollout_logprobs_tensor,
+            "rollout_expert_indices": rollout_expert_indices_tensor,
             "is_last_step": (
                 torch.tensor(generator_output["is_last_step"], dtype=torch.bool)
                 if generator_output.get("is_last_step", None) is not None
diff --git a/skyrl/utils/tok.py b/skyrl/utils/tok.py
index 1006830500..b328ee8334 100644
--- a/skyrl/utils/tok.py
+++ b/skyrl/utils/tok.py
@@ -7,6 +7,7 @@ def get_tokenizer(model_name_or_path, **tokenizer_kwargs) -> AutoTokenizer:
     """Gets tokenizer for the given base model with the given parameters

     Sets the pad token ID to EOS token ID if `None`"""
+    tokenizer_kwargs.setdefault("trust_remote_code", True)
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **tokenizer_kwargs)
     if tokenizer.pad_token_id is None:
         tokenizer.pad_token_id = tokenizer.eos_token_id
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py
index 3a0bfe532c..17bafaa98a 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/conftest.py
@@ -32,7 +32,6 @@ def ray_init_fixture():
     # needed for megatron tests
     env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
-    env_vars["NVTE_FUSED_ATTN"] = "0"

     if SKYRL_PYTHONPATH_EXPORT:
         pythonpath = os.environ.get("PYTHONPATH")
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
new file mode 100644
index 0000000000..707455eca5
--- /dev/null
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
@@ -0,0 +1,522 @@
+"""
+Run with:
+uv run --isolated --extra dev --extra megatron -- pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
+"""
+
+import ray
+import pytest
+import asyncio
+import torch
+from transformers import AutoTokenizer
+from tests.backends.skyrl_train.gpu.utils import (
+    InferenceEngineState,
+    get_test_generator_input,
+    Timer,
+    init_worker_with_type,
+)
+from skyrl.backends.skyrl_train.distributed.dispatch import concatenate_outputs_after_mesh_dispatch
+from skyrl.train.utils.utils import validate_cfg
+from skyrl.train.config import SkyRLTrainConfig, SamplingParams
+from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator
+from skyrl.train.generators.base import GeneratorInput
+from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend
+from skyrl.train.dataset.preprocess import convert_prompts_responses_to_batch_tensors
+from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
+
+MOE_MODEL_NAME = "moonshotai/Moonlight-16B-A3B"
+REPLAY_NUM_LAYERS = 2
+NUM_PROMPTS = 2
+N_SAMPLES_PER_PROMPT = 2
+MAX_GENERATE_LENGTH = 128
+
+
+def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig:
+    cfg = SkyRLTrainConfig()
+    cfg.trainer.policy.model.path = model_name
+    cfg.trainer.micro_forward_batch_size_per_gpu = 2
+    cfg.trainer.micro_train_batch_size_per_gpu = 2
+    cfg.trainer.use_sample_packing = True
+    # With MLA models, flash attn without sample packing silently produces wrong logprobs,
+    # while flash attn with sample packing correctly raises an error. We should add an
+    # assert so that flash attn cannot accidentally be used with use_sample_packing=False.
+    cfg.trainer.logger = "console"
+    if "moonlight" in model_name.lower():
+        if cfg.trainer.policy.megatron_config.transformer_config_kwargs is None:
+            cfg.trainer.policy.megatron_config.transformer_config_kwargs = {}
+        # flash attn is not supported for Moonlight-16B
+        cfg.trainer.flash_attn = False
+    validate_cfg(cfg)
+    return cfg
+
+
+def build_training_input_from_text_samples(
+    tokenizer: AutoTokenizer, prompt_response_pairs: list[tuple[str, str]]
+) -> TrainingInputBatch:
+    prompts = []
+    responses = []
+    rewards = []
+    loss_masks = []
+
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    for prompt_text, response_text in prompt_response_pairs:
+        prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
+        response_ids = tokenizer.encode(response_text, add_special_tokens=False)
+        if tokenizer.eos_token_id is not None and (not response_ids or response_ids[-1] != tokenizer.eos_token_id):
+            response_ids.append(tokenizer.eos_token_id)
+
+        prompts.append(prompt_ids)
+        responses.append(response_ids)
+        rewards.append([0.0] * len(response_ids))
+        loss_masks.append([1] * len(response_ids))
+
+    sequences, attention_mask, response_mask, rewards_t, loss_mask_t, _, _ = convert_prompts_responses_to_batch_tensors(
+        tokenizer=tokenizer,
+        prompts=prompts,
+        responses=responses,
+        rewards=rewards,
+        loss_masks=loss_masks,
+    )
+
+    num_actions = response_mask.shape[1]
+    batch_size = sequences.shape[0]
+    training_input = TrainingInputBatch(
+        {
+            "sequences": sequences,
+            "attention_mask": attention_mask,
+            "response_mask": response_mask,
+            "rewards": rewards_t,
+            "loss_mask": loss_mask_t,
+            "rollout_logprobs": torch.zeros((batch_size, num_actions), dtype=torch.float32),
+            "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32),
+            "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32),
+            "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32),
+            "action_mask": response_mask.to(dtype=torch.int64),
+        }
+    )
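+    # The zero-filled logprob/advantage tensors are schema placeholders so that
+    # TrainingInputBatch validates; their values are not meaningful.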
+    training_input.metadata = {"response_length": num_actions}
+    return training_input
+
+
+def test_generate_with_router_replay(ray_init_fixture):
+    """
+    Check that generate with router replay produces the correct rollout expert indices.
+    """
+    try:
+        cfg = get_test_actor_config(model_name=MOE_MODEL_NAME)
+        cfg.trainer.strategy = "megatron"
+        cfg.generator.inference_engine.enable_return_routed_experts = True
+        cfg.generator.inference_engine.tensor_parallel_size = 8
+        cfg.generator.sampling_params = SamplingParams(
+            max_generate_length=MAX_GENERATE_LENGTH,
+            logprobs=1,
+            temperature=1.0,
+        )
+        cfg.generator.batched = False
+        cfg.generator.max_turns = 2
+        cfg.generator.use_conversation_multi_turn = True
+        env_class = "gsm8k_multi_turn"
+
+        tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True)
+
+        with InferenceEngineState.create(
+            cfg=cfg,
+            model=MOE_MODEL_NAME,
+            use_local=True,
+            colocate_all=True,
+            backend="vllm",
+            sleep_level=1,
+            gpu_memory_utilization=0.9,
+        ) as engines:
+            client = engines.client
+            asyncio.run(client.wake_up())
+
+            generator = SkyRLGymGenerator(
+                generator_cfg=cfg.generator,
+                skyrl_gym_cfg=cfg.environment.skyrl_gym,
+                inference_engine_client=client,
+                tokenizer=tokenizer,
+            )
+
+            input_batch: GeneratorInput = get_test_generator_input(
+                model=MOE_MODEL_NAME,
+                num_prompts=NUM_PROMPTS,
+                n_samples_per_prompt=N_SAMPLES_PER_PROMPT,
+                max_prompt_length=512,
+                env_class=env_class,
+            )
+            input_batch["sampling_params"] = get_sampling_params_for_backend(
+                "vllm",
+                SamplingParams(
+                    temperature=1.0,
+                    top_p=1.0,
+                    top_k=-1,
+                    max_generate_length=MAX_GENERATE_LENGTH,
+                    min_p=0.0,
+                    logprobs=1,
+                ),
+            )
+
+            with Timer("generate_with_router_replay"):
+                generator_output = asyncio.run(generator.generate(input_batch))
+
+            indices = generator_output["rollout_expert_indices"]
+            responses = generator_output["response_ids"]
+            assert (
+                indices is not None
+            ), "rollout_expert_indices should not be None when enable_return_routed_experts=True"
+            assert len(indices) == len(
+                responses
+            ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses"
+            asyncio.run(client.sleep())
+    finally:
+        ray.shutdown()
+
+
+@pytest.mark.megatron
+@pytest.mark.parametrize(
+    "tp,pp,cp,ep,etp,extra_tf_kwargs",
+    [
+        pytest.param(2, 1, 1, 2, 1, {}, id="baseline"),
+        pytest.param(2, 2, 1, 2, 1, {"num_layers_in_last_pipeline_stage": 13}, id="pp2"),
+        pytest.param(4, 1, 2, 8, 1, {}, id="cp2"),
+        pytest.param(2, 2, 2, 4, 1, {"num_layers_in_last_pipeline_stage": 13}, id="cp2_pp2"),
+    ],
+)
+def test_logprobs(ray_init_fixture, tp, pp, cp, ep, etp, extra_tf_kwargs):
+    """
+    Check that logprob diff is lower when using router replay. Requires full 8xH100 setup to do full forward pass.
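+
+    Run a single parallelism layout by param id, e.g.:
+    uv run --isolated --extra dev --extra megatron -- pytest -s \
+        tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py::test_logprobs -k pp2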
+ """ + try: + cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) + cfg.trainer.strategy = "megatron" + cfg.generator.inference_engine.enable_return_routed_experts = True + cfg.generator.inference_engine.tensor_parallel_size = 8 + cfg.generator.sampling_params = SamplingParams( + max_generate_length=MAX_GENERATE_LENGTH, + logprobs=1, + temperature=1.0, + ) + cfg.generator.batched = False + cfg.generator.max_turns = 1 + + tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) + + with InferenceEngineState.create( + cfg=cfg, + model=MOE_MODEL_NAME, + use_local=True, + colocate_all=True, + backend="vllm", + sleep_level=1, + gpu_memory_utilization=0.9, + ) as engines: + client, pg = engines.client, engines.pg + asyncio.run(client.wake_up()) + + generator = SkyRLGymGenerator( + generator_cfg=cfg.generator, + skyrl_gym_cfg=cfg.environment.skyrl_gym, + inference_engine_client=client, + tokenizer=tokenizer, + ) + + input_batch: GeneratorInput = get_test_generator_input( + model=MOE_MODEL_NAME, + num_prompts=NUM_PROMPTS, + n_samples_per_prompt=N_SAMPLES_PER_PROMPT, + max_prompt_length=512, + env_class="gsm8k", + ) + input_batch["sampling_params"] = get_sampling_params_for_backend( + "vllm", + SamplingParams( + temperature=1.0, + top_p=1.0, + top_k=-1, + max_generate_length=MAX_GENERATE_LENGTH, + min_p=0.0, + logprobs=1, + ), + ) + + with Timer("generate_with_router_replay"): + generator_output = asyncio.run(generator.generate(input_batch)) + + indices = generator_output["rollout_expert_indices"] + responses = generator_output["response_ids"] + assert ( + indices is not None + ), "rollout_expert_indices should not be None when enable_return_routed_experts=True" + assert len(indices) == len( + responses + ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" + asyncio.run(client.sleep()) + + rewards = generator_output["rewards"] + if rewards and not isinstance(rewards[0], list): + rewards = [[r] * len(resp) for r, resp in zip(rewards, responses)] + (sequences, attention_mask, response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = ( + convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=generator_output["prompt_token_ids"], + responses=responses, + rewards=rewards, + loss_masks=generator_output["loss_masks"], + logprobs=generator_output.get("rollout_logprobs"), + rollout_expert_indices=indices, + ) + ) + + assert rii_tensor is not None + num_actions = response_mask.shape[1] + batch_size = sequences.shape[0] + training_input = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_mask, + "response_mask": response_mask, + "rewards": rewards_t, + "loss_mask": loss_mask_t, + "rollout_logprobs": ( + logprobs_t + if logprobs_t is not None + else torch.zeros((batch_size, num_actions), dtype=torch.float32) + ), + "rollout_expert_indices": rii_tensor, + "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "action_mask": response_mask.to(dtype=torch.int64), + } + ) + training_input.metadata = {"response_length": num_actions} + + cfg.trainer.placement.policy_num_gpus_per_node = 8 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = tp + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = pp + cfg.trainer.policy.megatron_config.context_parallel_size = cp + 
cfg.trainer.policy.megatron_config.expert_model_parallel_size = ep + cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = etp + cfg.trainer.micro_forward_batch_size_per_gpu = 1 + cfg.trainer.micro_train_batch_size_per_gpu = 1 + + def run_megatron_forward(enable_replay: bool) -> torch.Tensor: + cfg.trainer.policy.megatron_config.transformer_config_kwargs = { + "moe_enable_routing_replay": enable_replay, + **extra_tf_kwargs, + } + actor_group = init_worker_with_type( + "policy", + shared_pg=pg, + colocate_all=True, + num_gpus_per_node=8, + cfg=cfg, + ) + + refs = actor_group.async_run_ray_method("mesh", "forward", data=training_input) + results = ray.get(refs) + outputs = concatenate_outputs_after_mesh_dispatch(actor_group.actor_infos, results)["output"] + + # Tear down the actor group so the next configuration starts from fresh workers. + for actor in actor_group._actor_handlers: + ray.kill(actor) + return outputs + + r3_logprobs = run_megatron_forward(enable_replay=True) + no_r3_logprobs = run_megatron_forward(enable_replay=False) + + r3_diff = (logprobs_t - r3_logprobs).abs() + no_r3_diff = (logprobs_t - no_r3_logprobs).abs() + print(f"vLLM logprobs - mean: {logprobs_t.mean().item():.6f}, std: {logprobs_t.std().item():.6f}") + print(f"Megatron (replay) - mean: {r3_logprobs.mean().item():.6f}, std: {r3_logprobs.std().item():.6f}") + print(f"Megatron (no rep) - mean: {no_r3_logprobs.mean().item():.6f}, std: {no_r3_logprobs.std().item():.6f}") + print(f"With replay - logprob diff mean: {r3_diff.mean().item():.6f}, std: {r3_diff.std().item():.6f}") + print(f"Without replay - logprob diff mean: {no_r3_diff.mean().item():.6f}, std: {no_r3_diff.std().item():.6f}") + + assert r3_diff.mean().item() < no_r3_diff.mean().item(), ( + f"Router replay should reduce logprob diff vs rollout, " + f"but with_replay={r3_diff.mean().item():.6f} >= without_replay={no_r3_diff.mean().item():.6f}" + ) + finally: + ray.shutdown() + + +@pytest.mark.megatron +@pytest.mark.parametrize( + "tp,pp,cp,ep,etp,extra_tf_kwargs", + [ + pytest.param(4, 1, 1, 8, 1, {}, id="baseline"), + pytest.param(2, 2, 1, 2, 1, {"num_layers_in_last_pipeline_stage": 13}, id="pp2"), + ], +) +def test_forward_backward(ray_init_fixture, tp, pp, cp, ep, etp, extra_tf_kwargs): + """ + Check that forward_backward produces similar losses with and without + router replay (same weights, so routing decisions should nearly match). + Requires a full 8xH100 setup.
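+ Rollouts are generated once; forward_backward is then run twice on the same TrainingInputBatch with moe_enable_routing_replay toggled, and the resulting policy losses are compared against a fixed threshold.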
+ """ + try: + cfg = get_test_actor_config(model_name=MOE_MODEL_NAME) + cfg.trainer.strategy = "megatron" + cfg.generator.inference_engine.enable_return_routed_experts = True + cfg.generator.inference_engine.tensor_parallel_size = 8 + cfg.generator.sampling_params = SamplingParams( + max_generate_length=MAX_GENERATE_LENGTH, + logprobs=1, + temperature=1.0, + ) + cfg.generator.batched = False + cfg.generator.max_turns = 1 + + tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) + + with InferenceEngineState.create( + cfg=cfg, + model=MOE_MODEL_NAME, + use_local=True, + colocate_all=True, + backend="vllm", + sleep_level=1, + gpu_memory_utilization=0.9, + ) as engines: + client, pg = engines.client, engines.pg + asyncio.run(client.wake_up()) + + generator = SkyRLGymGenerator( + generator_cfg=cfg.generator, + skyrl_gym_cfg=cfg.environment.skyrl_gym, + inference_engine_client=client, + tokenizer=tokenizer, + ) + + input_batch: GeneratorInput = get_test_generator_input( + model=MOE_MODEL_NAME, + num_prompts=NUM_PROMPTS, + n_samples_per_prompt=N_SAMPLES_PER_PROMPT, + max_prompt_length=512, + env_class="gsm8k", + ) + input_batch["sampling_params"] = get_sampling_params_for_backend( + "vllm", + SamplingParams( + temperature=1.0, + top_p=1.0, + top_k=-1, + max_generate_length=MAX_GENERATE_LENGTH, + min_p=0.0, + logprobs=1, + ), + ) + + with Timer("generate_with_router_replay"): + generator_output = asyncio.run(generator.generate(input_batch)) + + indices = generator_output["rollout_expert_indices"] + responses = generator_output["response_ids"] + assert ( + indices is not None + ), "rollout_expert_indices should not be None when enable_return_routed_experts=True" + assert len(indices) == len( + responses + ), f"Batch size mismatch: {len(indices)} indices vs {len(responses)} responses" + asyncio.run(client.sleep()) + + rewards = generator_output["rewards"] + if rewards and not isinstance(rewards[0], list): + rewards = [[r] * len(resp) for r, resp in zip(rewards, responses)] + (sequences, attention_mask, response_mask, rewards_t, loss_mask_t, logprobs_t, rii_tensor) = ( + convert_prompts_responses_to_batch_tensors( + tokenizer=tokenizer, + prompts=generator_output["prompt_token_ids"], + responses=responses, + rewards=rewards, + loss_masks=generator_output["loss_masks"], + logprobs=generator_output.get("rollout_logprobs"), + rollout_expert_indices=indices, + ) + ) + + assert rii_tensor is not None + num_actions = response_mask.shape[1] + batch_size = sequences.shape[0] + training_input = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_mask, + "response_mask": response_mask, + "rewards": rewards_t, + "loss_mask": loss_mask_t, + "rollout_logprobs": ( + logprobs_t + if logprobs_t is not None + else torch.zeros((batch_size, num_actions), dtype=torch.float32) + ), + "rollout_expert_indices": rii_tensor, + "action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "base_action_log_probs": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "advantages": torch.zeros((batch_size, num_actions), dtype=torch.float32), + "action_mask": response_mask.to(dtype=torch.int64), + } + ) + training_input.metadata = {"response_length": num_actions} + + cfg.trainer.placement.policy_num_gpus_per_node = 8 + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = tp + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = pp + cfg.trainer.policy.megatron_config.context_parallel_size = cp + 
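# Same Megatron parallelism layout as test_logprobs; micro batch sizes of 1 bound per-GPU activation memory. +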
cfg.trainer.policy.megatron_config.expert_model_parallel_size = ep + cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = etp + cfg.trainer.micro_forward_batch_size_per_gpu = 1 + cfg.trainer.micro_train_batch_size_per_gpu = 1 + + def run_megatron_forward_backward(enable_replay: bool) -> dict: + cfg.trainer.policy.megatron_config.transformer_config_kwargs = { + "moe_enable_routing_replay": enable_replay, + **extra_tf_kwargs, + } + actor_group = init_worker_with_type( + "policy", + shared_pg=pg, + colocate_all=True, + num_gpus_per_node=8, + cfg=cfg, + ) + results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=training_input)) + for actor in actor_group._actor_handlers: + ray.kill(actor) + return results[0] + + metrics_replay = run_megatron_forward_backward(enable_replay=True) + metrics_no_replay = run_megatron_forward_backward(enable_replay=False) + + loss_replay = metrics_replay["policy_loss"] + loss_no_replay = metrics_no_replay["policy_loss"] + print(f"With replay - loss: {loss_replay:.6f}") + print(f"Without replay - loss: {loss_no_replay:.6f}") + print(f"With replay metrics: {metrics_replay}") + print(f"Without replay metrics: {metrics_no_replay}") + + diff = abs(loss_replay - loss_no_replay) + threshold = 0.5 + print(f"Loss diff: {diff:.6f} (threshold: {threshold})") + assert diff < threshold, ( + f"Losses with/without replay should be similar (same weights), " + f"but diff={diff:.6f} >= threshold={threshold}" + ) + finally: + ray.shutdown() diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py index 788bc8ef42..7a0b77bbfa 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py @@ -29,6 +29,7 @@ def get_test_config( is_step_wise, temperature, get_logprobs, + enable_return_routed_experts, ): cfg = SkyRLTrainConfig() cfg.trainer.policy.model.path = model @@ -49,7 +50,7 @@ def get_test_config( cfg.generator.inference_engine.http_endpoint_host = "127.0.0.1" cfg.generator.inference_engine.http_endpoint_port = 8000 cfg.generator.step_wise_trajectories = is_step_wise - + cfg.generator.inference_engine.enable_return_routed_experts = enable_return_routed_experts cfg.environment.skyrl_gym.search.log_requests = True cfg.environment.skyrl_gym.search.search_url = "http://127.0.0.1:8000/retrieve" cfg.environment.skyrl_gym.max_env_workers = max_env_workers @@ -108,6 +109,7 @@ async def run_generator_end_to_end( is_step_wise: bool = False, temperature=1.0, get_logprobs: bool = False, + enable_return_routed_experts: bool = False, ): """ End to end generator test - requires minimum 2 GPUs @@ -125,6 +127,7 @@ async def run_generator_end_to_end( is_step_wise, temperature, get_logprobs, + enable_return_routed_experts, ) # Use InferenceEngineState to support both legacy and new inference backends @@ -186,6 +189,11 @@ async def run_generator_end_to_end( "rollout_logprobs": ( generator_output["rollout_logprobs"][i] if generator_output["rollout_logprobs"] else None ), + "rollout_expert_indices": ( + generator_output["rollout_expert_indices"][i] + if generator_output["rollout_expert_indices"] + else None + ), } for i in range(len(generator_output["response_ids"])) ] @@ -450,3 +458,33 @@ async def test_generator_multi_turn_gsm8k_step_wise(ray_init_fixture): assert isinstance(generator_output["is_last_step"], list) and isinstance(generator_output["is_last_step"][0], bool) # Expect atleast 
one response with more than one turn assert sum(generator_output["is_last_step"]) != len(generator_output["is_last_step"]) + + +async def test_generator_multi_turn_gsm8k_router_replay(ray_init_fixture): + """ + Test the generator on the multi-turn GSM8K environment with router replay enabled. + """ + generator_output: GeneratorOutput = await run_generator_end_to_end( + use_async_engine=True, + batched=False, + n_samples_per_prompt=5, + num_inference_engines=2, + tensor_parallel_size=2, + model="arcee-ai/Trinity-Nano-Preview", + max_prompt_length=2048, + max_input_length=4096, + max_generate_length=1000, + data_path=os.path.expanduser("~/data/gsm8k/validation.parquet"), + env_class="gsm8k_multi_turn", + num_prompts=2, + max_turns=2, + use_conversation_multi_turn=True, + max_env_workers=0, + is_step_wise=True, + temperature=0, + enable_return_routed_experts=True, + ) + + assert isinstance(generator_output["is_last_step"], list) and isinstance(generator_output["is_last_step"][0], bool) + # Expect at least one response with more than one turn + assert sum(generator_output["is_last_step"]) != len(generator_output["is_last_step"]) diff --git a/tests/backends/skyrl_train/gpu/utils.py b/tests/backends/skyrl_train/gpu/utils.py index f6d726b902..a7f7545331 100644 --- a/tests/backends/skyrl_train/gpu/utils.py +++ b/tests/backends/skyrl_train/gpu/utils.py @@ -516,6 +516,7 @@ def create( sleep_level=sleep_level, enable_lora=enable_lora, engine_init_kwargs=ie_cfg.engine_init_kwargs, + enable_return_routed_experts=ie_cfg.enable_return_routed_experts, served_model_name=served_model_name, ) client = InferenceEngineClient(