diff --git a/skyrl/backends/skyrl_train/utils/replay_utils.py b/skyrl/backends/skyrl_train/utils/replay_utils.py index 47e405769f..7f6f7f6a21 100644 --- a/skyrl/backends/skyrl_train/utils/replay_utils.py +++ b/skyrl/backends/skyrl_train/utils/replay_utils.py @@ -103,9 +103,11 @@ def _remove_left_padding_from_indices( seq_lens = attention_mask.sum(dim=1) effective_seq_len = seq_lens.max().item() - sp_world_size = mpu.get_tensor_model_parallel_world_size() - if sp_world_size > 1: - pad_size = (sp_world_size - effective_seq_len % sp_world_size) % sp_world_size + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size + if align_size > 1: + pad_size = (align_size - effective_seq_len % align_size) % align_size effective_seq_len += pad_size batch_size = rollout_expert_indices.shape[0] @@ -151,8 +153,6 @@ def _pack_replay_indices( seqlens_padded = seq_lens + pad_sizes total_packed_len = int(seqlens_padded.sum().item()) - if cp_size > 1: - total_packed_len = total_packed_len // cp_size packed = torch.zeros( total_packed_len, @@ -164,48 +164,88 @@ def _pack_replay_indices( seq_lens_cpu = seq_lens.tolist() seqlens_padded_cpu = seqlens_padded.tolist() - if cp_size > 1: - cp_rank = mpu.get_context_parallel_rank() offset = 0 for i in range(batch_size): n = seq_lens_cpu[i] mask = attention_mask[i].bool() d = rollout_expert_indices[i, mask] - if cp_size > 1: - chunk_size = seqlens_padded_cpu[i] // cp_size - start = cp_rank * chunk_size - end = min(start + chunk_size, n) - valid_len = max(0, end - start) - if valid_len > 0: - packed[offset : offset + valid_len] = d[start:end] - offset += chunk_size - else: - packed[offset : offset + n] = d - offset += seqlens_padded_cpu[i] + packed[offset : offset + n] = d + offset += seqlens_padded_cpu[i] - return packed.unsqueeze(0) # [1, total_packed_len, layers, topk] + if cp_size > 1: + cp_rank = mpu.get_context_parallel_rank() + out = torch.zeros( + total_packed_len // cp_size, + num_layers, + topk, + dtype=packed.dtype, + device=packed.device, + ) + src_offset = 0 + dst_offset = 0 + for i in range(batch_size): + seqlen_padded_i = seqlens_padded_cpu[i] + seqlen_per_cp = seqlen_padded_i // cp_size + half = seqlen_per_cp // 2 + out[dst_offset : dst_offset + half] = packed[ + src_offset + half * cp_rank : src_offset + half * (cp_rank + 1) + ] + back_start = src_offset + seqlen_padded_i - half * (cp_rank + 1) + back_end = src_offset + seqlen_padded_i - half * cp_rank + out[dst_offset + half : dst_offset + seqlen_per_cp] = packed[back_start:back_end] + src_offset += seqlen_padded_i + dst_offset += seqlen_per_cp + packed = out + + return packed.unsqueeze(0) # [1, packed_len_per_cp, layers, topk] + + +def _get_current_pp_stage_layer_range(model_config) -> tuple[int, int]: + """Return the current PP rank's transformer-layer range as (start_layer, + num_layers). + + Prefer Megatron's own helpers so replay indexing stays aligned with the + actual model partition, including embedding/loss pipeline accounting. 
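+
+    Illustrative example (layer counts assumed for illustration, not taken
+    from the patch): with 27 total layers, pp_size=2, and
+    num_layers_in_first_pipeline_stage=13, this returns (0, 13) on PP rank 0
+    and (13, 14) on PP rank 1.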
+ """ + import megatron.core.parallel_state as mpu + from megatron.core.transformer.transformer_block import get_num_layers_to_build + from megatron.core.transformer.transformer_layer import get_transformer_layer_offset + + pp_rank = mpu.get_pipeline_model_parallel_rank() + offset = get_transformer_layer_offset(model_config, pp_rank=pp_rank) + num_layers = get_num_layers_to_build(model_config, pp_rank=pp_rank) + return offset, num_layers def setup_per_microbatch_replay_forward( rollout_expert_indices: torch.Tensor, attention_mask: torch.Tensor, + model_config, use_sample_packing: bool = False, ) -> None: """Set up RouterReplay for a single micro-batch, aligning indices - with the token layout that the MoE layer sees. + with the left-padding-removed token layout that the MoE layer sees. + + Handles context parallelism: when CP > 1, the sequence is split into + 2*cp_size chunks with each CP rank receiving a front chunk and a back + chunk (for causal-mask load balancing). Replay indices are split using + the same pattern so they stay aligned with the tokens each rank sees. Handles sequence parallelism: when TP > 1, the sequence is split across TP ranks, so each rank's MoE router only sees its local chunk of tokens. - Handles sample packing: when use_sample_packing is True, sequences are - concatenated into one packed sequence with per-sample alignment padding. - The replay indices must follow this same packed layout. - Handles dense-layer mismatch: DeepSeek V3-style models have dense FFN - layers before the MoE layers. vLLM reports routing indices for ALL + layers before the MoE layers. vLLM reports routing indices for ALL transformer layers, but Megatron only has RouterReplay instances for MoE - layers. We use each instance's global layer_number (set by the patched + layers. We use each instance's global layer_number (set by the patched TopKRouter.set_layer_number) to index into the correct slice of the data. + + Handles pipeline parallelism: when PP > 1, the sequence is split across + PP ranks, so each rank only sees its local RouterReplay instances. In cases + where the number of local RouterReplay instances does not match the local + layer count, indicating that the model has dense layers before MoE layers, + we use the global layer_number to index into the correct slice of the data. 
+ """ import megatron.core.parallel_state as mpu from megatron.core.transformer.moe.router_replay import ( @@ -220,34 +260,36 @@ def setup_per_microbatch_replay_forward( else: aligned = _remove_left_padding_from_indices(rollout_expert_indices, attention_mask) + # TP splitting: sequence parallelism across the tensor model parallel region tp_size = mpu.get_tensor_model_parallel_world_size() if tp_size > 1: tp_rank = mpu.get_tensor_model_parallel_rank() seq_len = aligned.shape[1] chunk_size = seq_len // tp_size aligned = aligned[:, tp_rank * chunk_size : (tp_rank + 1) * chunk_size, :, :] - per_layer_data = _split_replay_indices(aligned) - num_layers_in_data = len(per_layer_data) + global_num_layers_in_data = len(per_layer_data) instances = RouterReplay.global_router_replay_instances num_instances = len(instances) + local_layer_offset, local_num_layers = _get_current_pp_stage_layer_range(model_config) - if num_layers_in_data == num_instances: - RouterReplay.set_replay_data(per_layer_data) + if local_num_layers == num_instances: + local_per_layer_data = per_layer_data[local_layer_offset : local_layer_offset + local_num_layers] + RouterReplay.set_replay_data(local_per_layer_data) else: # Dense-layer mismatch: map each MoE router to its global layer index. # Prefer the patched layer_number; fall back to offset-based mapping # (assumes dense layers precede MoE layers). - for i, router_instance in enumerate(instances): + for local_router_idx, router_instance in enumerate(instances): layer_number = getattr(router_instance, "layer_number", None) if layer_number is not None: layer_idx = layer_number - 1 # layer_number is 1-based else: - layer_idx = i + (num_layers_in_data - num_instances) - if layer_idx < 0 or layer_idx >= num_layers_in_data: + layer_idx = local_layer_offset + local_router_idx + (local_num_layers - num_instances) + if layer_idx < 0 or layer_idx >= global_num_layers_in_data: raise ValueError( f"Router replay layer index {layer_idx} out of range " - f"for data with {num_layers_in_data} layers " + f"for data with {global_num_layers_in_data} layers " f"({num_instances} router instances)" ) router_instance.set_target_indices(per_layer_data[layer_idx]) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 49f6bde57f..597d59e446 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -114,7 +114,10 @@ def forward_step(batch_iter, model): rollout_expert_indices = batch.pop("rollout_expert_indices", None) if rollout_expert_indices is not None: setup_per_microbatch_replay_forward( - rollout_expert_indices, batch["attention_mask"], use_sample_packing=self.use_sample_packing + rollout_expert_indices, + batch["attention_mask"], + model_config=get_model_config(model), + use_sample_packing=self.use_sample_packing, ) sequences = batch["sequences"] @@ -380,7 +383,10 @@ def forward_step(batch_iter, model): rollout_expert_indices = batch.pop("rollout_expert_indices", None) if rollout_expert_indices is not None: setup_per_microbatch_replay_forward( - rollout_expert_indices, batch["attention_mask"], use_sample_packing=self.use_sample_packing + rollout_expert_indices, + batch["attention_mask"], + model_config=get_model_config(model), + use_sample_packing=self.use_sample_packing, ) sequences = batch["sequences"] diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index 
787434c7a2..1fa7b1d47b 100644
--- a/skyrl/train/utils/utils.py
+++ b/skyrl/train/utils/utils.py
@@ -206,12 +206,6 @@ def validate_megatron_cfg(cfg: SkyRLTrainConfig):
         assert (
             cfg.generator.inference_engine.enable_return_routed_experts
         ), "rollout router replay (r3) is only supported when enable_return_routed_experts is True"
-        assert (
-            cfg.trainer.policy.megatron_config.pipeline_model_parallel_size == 1
-        ), "pipeline parallel is not yet supported for router replay (r3) with megatron"
-        assert (
-            cfg.trainer.policy.megatron_config.context_parallel_size == 1
-        ), "context parallel is not yet supported for router replay (r3) with megatron"
 
     worker_configs = [(cfg.trainer.policy, "policy"), (cfg.trainer.ref, "ref")]
     for config, worker_type in worker_configs:
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
index bd114d7b76..21a163230a 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
@@ -30,8 +30,8 @@
 )
 
 MOE_MODEL_NAME = "moonshotai/Moonlight-16B-A3B-Instruct"
-NUM_PROMPTS = 5
-N_SAMPLES_PER_PROMPT = 2
+NUM_PROMPTS = 10
+N_SAMPLES_PER_PROMPT = 4
 MAX_GENERATE_LENGTH = 128
 
 
@@ -45,10 +45,13 @@ def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig:
     # flash attn + mla works without sample packing, logprobs are crazy/wrong
     # but flash-attn correctly throws error with sample packing
     # we should add an assert that if you set use_sample_packing=False flash attn can accidentally be used
+    # router replay tests on moonlight models need nvte fused attn when
+    # use_sample_packing=True, so force flash attn off here
     cfg.trainer.logger = "console"
     if "moonlight" in model_name:
         if cfg.trainer.policy.megatron_config.transformer_config_kwargs is None:
             cfg.trainer.policy.megatron_config.transformer_config_kwargs = {}
+        cfg.trainer.flash_attn = False
     validate_cfg(cfg)
     return cfg
 
@@ -105,8 +108,14 @@ def build_training_input_from_text_samples(
 
 
 @pytest.mark.megatron
-@pytest.mark.skip(reason="Skipping router replay test for now due to size constraints")
-def test_logprobs(ray_init_fixture):
+@pytest.mark.skip(reason="Skipping router replay tests for now due to size constraints")
+@pytest.mark.parametrize(
+    "tp,pp,cp,ep,etp,extra_tf_kwargs",
+    [
+        pytest.param(2, 2, 2, 4, 1, {"num_layers_in_first_pipeline_stage": 13}, id="max_parallelism"),
+    ],
+)
+def test_logprobs(ray_init_fixture, tp, pp, cp, ep, etp, extra_tf_kwargs):
     """
     Check that logprob diff is lower when using router replay.
     Requires full 8xH100 setup to do full forward pass.
""" @@ -120,8 +129,7 @@ def test_logprobs(ray_init_fixture): logprobs=1, temperature=1.0, ) - cfg.generator.batched = True - cfg.generator.async_engine = False + cfg.generator.batched = False cfg.generator.max_turns = 1 tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True) @@ -217,11 +225,13 @@ def test_logprobs(ray_init_fixture): training_input.metadata = {"response_length": num_actions} cfg.trainer.placement.policy_num_gpus_per_node = 8 - cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 - cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 - cfg.trainer.policy.megatron_config.context_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 - cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 + if extra_tf_kwargs is not None: + cfg.trainer.policy.megatron_config.transformer_config_kwargs.update(extra_tf_kwargs) + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = tp + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = pp + cfg.trainer.policy.megatron_config.context_parallel_size = cp + cfg.trainer.policy.megatron_config.expert_model_parallel_size = ep + cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = etp cfg.trainer.micro_forward_batch_size_per_gpu = 2 cfg.trainer.micro_train_batch_size_per_gpu = 2 @@ -245,11 +255,12 @@ def run_megatron_forward(enable_replay: bool) -> torch.Tensor: r3_logprobs = run_megatron_forward(enable_replay=True) no_r3_logprobs = run_megatron_forward(enable_replay=False) - mask = response_mask.bool() + vllm_valid = logprobs_t[mask] r3_valid = r3_logprobs[mask] no_r3_valid = no_r3_logprobs[mask] + r3_diff = (vllm_valid - r3_valid).abs() no_r3_diff = (vllm_valid - no_r3_valid).abs() print(f"vLLM logprobs - mean: {vllm_valid.mean().item():.6f}, std: {vllm_valid.std().item():.6f}") @@ -267,8 +278,14 @@ def run_megatron_forward(enable_replay: bool) -> torch.Tensor: @pytest.mark.megatron -@pytest.mark.skip(reason="Skipping router replay test for now due to size constraints") -def test_forward_backward(ray_init_fixture): +@pytest.mark.skip(reason="Skipping router replay tests for now due to size constraints") +@pytest.mark.parametrize( + "tp,pp,cp,ep,etp,extra_tf_kwargs", + [ + pytest.param(2, 2, 2, 4, 1, {"num_layers_in_first_pipeline_stage": 13}, id="max_parallelism"), + ], +) +def test_forward_backward(ray_init_fixture, tp, pp, cp, ep, etp, extra_tf_kwargs): """ Check that forward_backward with router replay completes without error. Uses dummy expert routing indices (no vLLM engine needed). 
@@ -339,11 +356,13 @@ def test_forward_backward(ray_init_fixture): training_input.metadata = {"response_length": num_actions} cfg.trainer.placement.policy_num_gpus_per_node = 8 - cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 4 - cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 - cfg.trainer.policy.megatron_config.context_parallel_size = 1 - cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8 - cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1 + if extra_tf_kwargs is not None: + cfg.trainer.policy.megatron_config.transformer_config_kwargs.update(extra_tf_kwargs) + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = tp + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = pp + cfg.trainer.policy.megatron_config.context_parallel_size = cp + cfg.trainer.policy.megatron_config.expert_model_parallel_size = ep + cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = etp cfg.trainer.micro_forward_batch_size_per_gpu = 2 cfg.trainer.micro_train_batch_size_per_gpu = 2 cfg.trainer.policy.megatron_config.moe_enable_routing_replay = True @@ -354,7 +373,6 @@ def test_forward_backward(ray_init_fixture): cfg=cfg, ) - ray.get(actor_group.async_run_ray_method("pass_through", "setup_per_microbatch_replay_backward")) ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=training_input)) ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=training_input))
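
Reviewer note: the zigzag arithmetic in _pack_replay_indices is easy to get
off by one, so here is a minimal standalone sketch of the front-chunk /
back-chunk selection. It is not part of the patch; zigzag_cp_split is a
hypothetical helper, and the 1-D toy sequence stands in for the packed
[tokens, layers, topk] replay tensor.

import torch


def zigzag_cp_split(seq: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    # Cut a CP-aligned sequence into 2*cp_size equal chunks and return this
    # rank's pair: chunk cp_rank from the front plus the mirrored chunk
    # (2*cp_size - 1 - cp_rank) from the back.
    total = seq.shape[0]
    half = total // (2 * cp_size)  # matches `half = seqlen_per_cp // 2` in the patch
    front = seq[half * cp_rank : half * (cp_rank + 1)]
    back = seq[total - half * (cp_rank + 1) : total - half * cp_rank]
    return torch.cat([front, back], dim=0)


# With cp_size=2 on tokens 0..7: rank 0 -> [0, 1, 6, 7], rank 1 -> [2, 3, 4, 5].
seq = torch.arange(8)
assert zigzag_cp_split(seq, cp_size=2, cp_rank=0).tolist() == [0, 1, 6, 7]
assert zigzag_cp_split(seq, cp_size=2, cp_rank=1).tolist() == [2, 3, 4, 5]

Pairing an early chunk with a late chunk gives every CP rank roughly equal
causal-attention work, which is why the replay indices must be rearranged in
the same pattern as the tokens they describe.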