Merged
39 commits
dc93a03
replay utils update
devpatelio Mar 17, 2026
d2b2d53
clean
devpatelio Mar 17, 2026
8cb6155
linter
devpatelio Mar 17, 2026
d1712bd
rm the sample packing flag
devpatelio Mar 17, 2026
3da803a
move cp out of pack_replay_indices for simpler implementation
devpatelio Mar 20, 2026
f7362c3
model wrapper
devpatelio Mar 20, 2026
5d74b89
done:
erictang000 Mar 20, 2026
5855ee8
lint
devpatelio Mar 20, 2026
153a2d0
Fix async script to use megatron
devpatelio Mar 20, 2026
c46287e
Fix fallback logic to account for dense layers for indexing replay in…
devpatelio Mar 20, 2026
33b3083
[train] Add `worker_process_setup_hook` to set mp start method to `sp…
SumanthRH Mar 17, 2026
7066bf4
[CI] Fix `test_inference_engines_generation` after vllm 0.16.0 upgrad…
SumanthRH Mar 18, 2026
7744e69
[train] Make TrainingInputBatch to PAD only to left, hence response t…
CharlieFRuan Mar 18, 2026
bfcd8db
[train] Revert "[train] Add `worker_process_setup_hook` to set mp sta…
SumanthRH Mar 19, 2026
daf5752
[Docs] Add docs on agent integration and step-wise training (#1347)
CharlieFRuan Mar 19, 2026
493387c
[Docs] Small update on docs (#1348)
CharlieFRuan Mar 19, 2026
089ee8b
[train] Add validation for step-wise GeneratorOutput (#1281)
CharlieFRuan Mar 19, 2026
0f18c70
[megatron] rebuild weight conversion tasks per sync to prevent stale …
erictang000 Mar 19, 2026
1a73422
[StepWise] Trivial fix to avg_response_length metric (#1351)
CharlieFRuan Mar 19, 2026
04d29f3
[CI] Make `MultiItemDataset` a global variable after switch to `spawn…
SumanthRH Mar 19, 2026
2619ed4
[train] Add support for LoRA in the new inference codepath (#1329)
SumanthRH Mar 20, 2026
9a0088f
[bug][algorithm] remove incorrect torch.no_grad() for kl in loss (use…
erictang000 Mar 20, 2026
4cf3a48
latest main
devpatelio Mar 20, 2026
4b844c0
Merge remote-tracking branch 'origin/main' into r3-pp-cp
devpatelio Mar 20, 2026
44aab7d
final done
devpatelio Mar 21, 2026
be50ba6
revert test file changes
devpatelio Mar 21, 2026
f7c3086
fix
devpatelio Mar 21, 2026
bf84467
done
devpatelio Mar 21, 2026
0a2fdb2
remove comment
devpatelio Mar 21, 2026
73741a3
rm file
devpatelio Mar 21, 2026
43468a8
Revert "remove comment"
devpatelio Mar 21, 2026
a4b9228
uv.lock bring back
devpatelio Mar 21, 2026
a841448
Delete 1
devpatelio Mar 21, 2026
5def4cb
Revert "done"
devpatelio Mar 21, 2026
266b9c1
done
devpatelio Mar 21, 2026
5628f22
addressed comments
devpatelio Mar 25, 2026
2bc0ee0
lint
devpatelio Mar 25, 2026
fff1c6a
Merge branch 'main' of https://github.com/NovaSky-AI/SkyRL into r3-pp-cp
erictang000 Mar 26, 2026
9463aec
x
erictang000 Mar 26, 2026
110 changes: 76 additions & 34 deletions skyrl/backends/skyrl_train/utils/replay_utils.py
@@ -103,9 +103,11 @@ def _remove_left_padding_from_indices(

seq_lens = attention_mask.sum(dim=1)
effective_seq_len = seq_lens.max().item()
sp_world_size = mpu.get_tensor_model_parallel_world_size()
if sp_world_size > 1:
pad_size = (sp_world_size - effective_seq_len % sp_world_size) % sp_world_size
tp_size = mpu.get_tensor_model_parallel_world_size()
cp_size = mpu.get_context_parallel_world_size()
align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size
if align_size > 1:
pad_size = (align_size - effective_seq_len % align_size) % align_size
effective_seq_len += pad_size

batch_size = rollout_expert_indices.shape[0]
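As a side note, a minimal standalone sketch of the alignment computation above (illustrative, not part of the diff): with CP enabled, the effective length is padded to a multiple of tp_size * cp_size * 2 so the sequence splits evenly into 2 * cp_size chunks and each CP rank's local slice stays divisible by tp_size for sequence parallelism. The sizes below are hypothetical.

def aligned_length(effective_seq_len: int, tp_size: int, cp_size: int) -> int:
    # Mirrors the pad_size formula from the hunk above.
    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size
    if align_size > 1:
        pad_size = (align_size - effective_seq_len % align_size) % align_size
        effective_seq_len += pad_size
    return effective_seq_len

assert aligned_length(1001, tp_size=2, cp_size=2) == 1008  # aligned to 8
assert aligned_length(1001, tp_size=2, cp_size=1) == 1002  # aligned to 2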
@@ -151,8 +153,6 @@ def _pack_replay_indices(
seqlens_padded = seq_lens + pad_sizes

total_packed_len = int(seqlens_padded.sum().item())
if cp_size > 1:
total_packed_len = total_packed_len // cp_size

packed = torch.zeros(
total_packed_len,
@@ -164,48 +164,88 @@

seq_lens_cpu = seq_lens.tolist()
seqlens_padded_cpu = seqlens_padded.tolist()
if cp_size > 1:
cp_rank = mpu.get_context_parallel_rank()
offset = 0
for i in range(batch_size):
n = seq_lens_cpu[i]
mask = attention_mask[i].bool()
d = rollout_expert_indices[i, mask]
if cp_size > 1:
chunk_size = seqlens_padded_cpu[i] // cp_size
start = cp_rank * chunk_size
end = min(start + chunk_size, n)
valid_len = max(0, end - start)
if valid_len > 0:
packed[offset : offset + valid_len] = d[start:end]
offset += chunk_size
else:
packed[offset : offset + n] = d
offset += seqlens_padded_cpu[i]
packed[offset : offset + n] = d
offset += seqlens_padded_cpu[i]

return packed.unsqueeze(0) # [1, total_packed_len, layers, topk]
if cp_size > 1:
cp_rank = mpu.get_context_parallel_rank()
out = torch.zeros(
total_packed_len // cp_size,
num_layers,
topk,
dtype=packed.dtype,
device=packed.device,
)
src_offset = 0
dst_offset = 0
for i in range(batch_size):
seqlen_padded_i = seqlens_padded_cpu[i]
seqlen_per_cp = seqlen_padded_i // cp_size
half = seqlen_per_cp // 2
out[dst_offset : dst_offset + half] = packed[
src_offset + half * cp_rank : src_offset + half * (cp_rank + 1)
]
back_start = src_offset + seqlen_padded_i - half * (cp_rank + 1)
back_end = src_offset + seqlen_padded_i - half * cp_rank
out[dst_offset + half : dst_offset + seqlen_per_cp] = packed[back_start:back_end]
src_offset += seqlen_padded_i
dst_offset += seqlen_per_cp
packed = out

return packed.unsqueeze(0) # [1, packed_len_per_cp, layers, topk]
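For intuition, here is a small self-contained sketch (illustrative only, with a hypothetical function name) of the CP split implemented above: each padded sequence is viewed as 2 * cp_size equal chunks, and CP rank r keeps chunk r from the front plus chunk 2 * cp_size - 1 - r from the back, matching Megatron's causal-mask load balancing.

import torch

def cp_slice(packed_seq: torch.Tensor, cp_rank: int, cp_size: int) -> torch.Tensor:
    # Assumes the length is divisible by 2 * cp_size (guaranteed by the alignment padding).
    seqlen_per_cp = packed_seq.shape[0] // cp_size
    half = seqlen_per_cp // 2
    front = packed_seq[half * cp_rank : half * (cp_rank + 1)]
    back = packed_seq[packed_seq.shape[0] - half * (cp_rank + 1) : packed_seq.shape[0] - half * cp_rank]
    return torch.cat([front, back], dim=0)

# Every token lands on exactly one rank, so concatenating all ranks' slices
# recovers a permutation of the original sequence.
seq = torch.arange(16)
parts = [cp_slice(seq, r, cp_size=2) for r in range(2)]
assert sorted(torch.cat(parts).tolist()) == list(range(16))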


def _get_current_pp_stage_layer_range(model_config) -> tuple[int, int]:
"""Return the current PP rank's transformer-layer range as (start_layer,
num_layers).

Prefer Megatron's own helpers so replay indexing stays aligned with the
actual model partition, including embedding/loss pipeline accounting.
"""
import megatron.core.parallel_state as mpu
from megatron.core.transformer.transformer_block import get_num_layers_to_build
from megatron.core.transformer.transformer_layer import get_transformer_layer_offset

pp_rank = mpu.get_pipeline_model_parallel_rank()
offset = get_transformer_layer_offset(model_config, pp_rank=pp_rank)
num_layers = get_num_layers_to_build(model_config, pp_rank=pp_rank)
return offset, num_layers
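For intuition, a hedged example of what this helper returns (the pure-Python stand-in below is hypothetical; the real code defers to get_transformer_layer_offset and get_num_layers_to_build): assuming a 27-layer model such as Moonlight with pipeline_model_parallel_size=2 and num_layers_in_first_pipeline_stage=13, PP rank 0 would own layers 0-12 and PP rank 1 layers 13-26.

def pp_layer_range(pp_rank: int, layers_per_stage: list[int]) -> tuple[int, int]:
    # (start_layer, num_layers) for a given PP rank under an explicit per-stage split.
    offset = sum(layers_per_stage[:pp_rank])
    return offset, layers_per_stage[pp_rank]

assert pp_layer_range(0, [13, 14]) == (0, 13)   # PP rank 0: layers 0-12
assert pp_layer_range(1, [13, 14]) == (13, 14)  # PP rank 1: layers 13-26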


def setup_per_microbatch_replay_forward(
rollout_expert_indices: torch.Tensor,
attention_mask: torch.Tensor,
model_config,
use_sample_packing: bool = False,
) -> None:
"""Set up RouterReplay for a single micro-batch, aligning indices
with the token layout that the MoE layer sees.
with the left-padding-removed token layout that the MoE layer sees.

Handles context parallelism: when CP > 1, the sequence is split into
2*cp_size chunks with each CP rank receiving a front chunk and a back
chunk (for causal-mask load balancing). Replay indices are split using
the same pattern so they stay aligned with the tokens each rank sees.

Handles sequence parallelism: when TP > 1, the sequence is split across
TP ranks, so each rank's MoE router only sees its local chunk of tokens.

Handles sample packing: when use_sample_packing is True, sequences are
concatenated into one packed sequence with per-sample alignment padding.
The replay indices must follow this same packed layout.

Handles dense-layer mismatch: DeepSeek V3-style models have dense FFN
layers before the MoE layers. vLLM reports routing indices for ALL
transformer layers, but Megatron only has RouterReplay instances for MoE
layers. We use each instance's global layer_number (set by the patched
TopKRouter.set_layer_number) to index into the correct slice of the data.

Handles pipeline parallelism: when PP > 1, the sequence is split across
PP ranks, so each rank only sees its local RouterReplay instances. In cases
where the number of local RouterReplay instances does not match the local
layer count, indicating that the model has dense layers before MoE layers,
we use the global layer_number to index into the correct slice of the data.

"""
import megatron.core.parallel_state as mpu
from megatron.core.transformer.moe.router_replay import (
@@ -220,34 +260,36 @@ def setup_per_microbatch_replay_forward(
else:
aligned = _remove_left_padding_from_indices(rollout_expert_indices, attention_mask)

# TP splitting: sequence parallelism across the tensor model parallel region
tp_size = mpu.get_tensor_model_parallel_world_size()
if tp_size > 1:
tp_rank = mpu.get_tensor_model_parallel_rank()
seq_len = aligned.shape[1]
chunk_size = seq_len // tp_size
aligned = aligned[:, tp_rank * chunk_size : (tp_rank + 1) * chunk_size, :, :]

per_layer_data = _split_replay_indices(aligned)
num_layers_in_data = len(per_layer_data)
global_num_layers_in_data = len(per_layer_data)
instances = RouterReplay.global_router_replay_instances
num_instances = len(instances)
local_layer_offset, local_num_layers = _get_current_pp_stage_layer_range(model_config)

if num_layers_in_data == num_instances:
RouterReplay.set_replay_data(per_layer_data)
if local_num_layers == num_instances:
local_per_layer_data = per_layer_data[local_layer_offset : local_layer_offset + local_num_layers]
RouterReplay.set_replay_data(local_per_layer_data)
else:
# Dense-layer mismatch: map each MoE router to its global layer index.
# Prefer the patched layer_number; fall back to offset-based mapping
# (assumes dense layers precede MoE layers).
for i, router_instance in enumerate(instances):
for local_router_idx, router_instance in enumerate(instances):
layer_number = getattr(router_instance, "layer_number", None)
if layer_number is not None:
layer_idx = layer_number - 1 # layer_number is 1-based
else:
layer_idx = i + (num_layers_in_data - num_instances)
if layer_idx < 0 or layer_idx >= num_layers_in_data:
layer_idx = local_layer_offset + local_router_idx + (local_num_layers - num_instances)
if layer_idx < 0 or layer_idx >= global_num_layers_in_data:
raise ValueError(
f"Router replay layer index {layer_idx} out of range "
f"for data with {num_layers_in_data} layers "
f"for data with {global_num_layers_in_data} layers "
f"({num_instances} router instances)"
)
router_instance.set_target_indices(per_layer_data[layer_idx])
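To make the dense-layer fallback above concrete, a small self-contained sketch (illustrative, with hypothetical names and data): the 1-based layer_number, when available, indexes directly into the global per-layer replay data, and the offset-based fallback reaches the same slice as long as the dense layers precede the MoE layers in the stage.

# Hypothetical mini-example: PP rank 0 holds global layers 0-2, layer 0 is a
# dense FFN layer, so the stage has 3 transformer layers but only 2 MoE routers.
local_layer_offset, local_num_layers = 0, 3
per_layer_data = [f"indices_for_layer_{i}" for i in range(3)]

class FakeRouter:
    def __init__(self, layer_number=None):
        self.layer_number = layer_number  # 1-based global layer number, if patched

    def set_target_indices(self, data):
        self.assigned = data

routers = [FakeRouter(layer_number=2), FakeRouter(layer_number=3)]
num_instances = len(routers)

for local_router_idx, router in enumerate(routers):
    if router.layer_number is not None:
        layer_idx = router.layer_number - 1  # 1-based -> 0-based global index
    else:
        # Fallback: assume the dense layers come first.
        layer_idx = local_layer_offset + local_router_idx + (local_num_layers - num_instances)
    router.set_target_indices(per_layer_data[layer_idx])

# Both paths map the two routers to layers 1 and 2.
assert [r.assigned for r in routers] == ["indices_for_layer_1", "indices_for_layer_2"]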
@@ -114,7 +114,10 @@ def forward_step(batch_iter, model):
rollout_expert_indices = batch.pop("rollout_expert_indices", None)
if rollout_expert_indices is not None:
setup_per_microbatch_replay_forward(
rollout_expert_indices, batch["attention_mask"], use_sample_packing=self.use_sample_packing
rollout_expert_indices,
batch["attention_mask"],
model_config=get_model_config(model),
use_sample_packing=self.use_sample_packing,
)

sequences = batch["sequences"]
@@ -380,7 +383,10 @@ def forward_step(batch_iter, model):
rollout_expert_indices = batch.pop("rollout_expert_indices", None)
if rollout_expert_indices is not None:
setup_per_microbatch_replay_forward(
rollout_expert_indices, batch["attention_mask"], use_sample_packing=self.use_sample_packing
rollout_expert_indices,
batch["attention_mask"],
model_config=get_model_config(model),
use_sample_packing=self.use_sample_packing,
)

sequences = batch["sequences"]
6 changes: 0 additions & 6 deletions skyrl/train/utils/utils.py
@@ -206,12 +206,6 @@ def validate_megatron_cfg(cfg: SkyRLTrainConfig):
assert (
cfg.generator.inference_engine.enable_return_routed_experts
), "rollout router replay (r3) is only supported when enable_return_routed_experts is True"
assert (
cfg.trainer.policy.megatron_config.pipeline_model_parallel_size == 1
), "pipeline parallel is not yet supported for router replay (r3) with megatron"
assert (
cfg.trainer.policy.megatron_config.context_parallel_size == 1
), "context parallel is not yet supported for router replay (r3) with megatron"

worker_configs = [(cfg.trainer.policy, "policy"), (cfg.trainer.ref, "ref")]
for config, worker_type in worker_configs:
58 changes: 38 additions & 20 deletions tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
@@ -30,8 +30,8 @@
)

MOE_MODEL_NAME = "moonshotai/Moonlight-16B-A3B-Instruct"
NUM_PROMPTS = 5
N_SAMPLES_PER_PROMPT = 2
NUM_PROMPTS = 10
N_SAMPLES_PER_PROMPT = 4
MAX_GENERATE_LENGTH = 128


@@ -45,10 +45,13 @@ def get_test_actor_config(model_name=MOE_MODEL_NAME) -> SkyRLTrainConfig:
# flash attn + mla works without sample packing, logprobs are crazy/wrong
# but flash-attn correctly throws error with sample packing
# we should add an assert that if you set use_sample_packing=False flash attn can accidentally be used
# and that we enable nvte fused attn for moonlight models with use_sample_packing=True
# need to enable nvte fused attn for router replay tests when using moonlight models with use_sample_packing=True
cfg.trainer.logger = "console"
if "moonlight" in model_name:
if cfg.trainer.policy.megatron_config.transformer_config_kwargs is None:
cfg.trainer.policy.megatron_config.transformer_config_kwargs = {}

cfg.trainer.flash_attn = False
validate_cfg(cfg)
return cfg
@@ -105,8 +108,14 @@ def build_training_input_from_text_samples(


@pytest.mark.megatron
@pytest.mark.skip(reason="Skipping router replay test for now due to size constraints")
def test_logprobs(ray_init_fixture):
@pytest.mark.skip(reason="Skipping router replay tests for now due to size constraints")
@pytest.mark.parametrize(
"tp,pp,cp,ep,etp,extra_tf_kwargs",
[
pytest.param(2, 2, 2, 4, 1, {"num_layers_in_first_pipeline_stage": 13}, id="max_parallelism"),
],
)
def test_logprobs(ray_init_fixture, tp, pp, cp, ep, etp, extra_tf_kwargs):
"""
Check that logprob diff is lower when using router replay. Requires full 8xH100 setup to do full forward pass.
"""
@@ -120,8 +129,7 @@
logprobs=1,
temperature=1.0,
)
cfg.generator.batched = True
cfg.generator.async_engine = False
cfg.generator.batched = False
cfg.generator.max_turns = 1

tokenizer = AutoTokenizer.from_pretrained(MOE_MODEL_NAME, trust_remote_code=True)
@@ -217,11 +225,13 @@ def test_logprobs(ray_init_fixture):
training_input.metadata = {"response_length": num_actions}

cfg.trainer.placement.policy_num_gpus_per_node = 8
cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2
cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1
cfg.trainer.policy.megatron_config.context_parallel_size = 1
cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8
cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1
if extra_tf_kwargs is not None:
cfg.trainer.policy.megatron_config.transformer_config_kwargs.update(extra_tf_kwargs)
cfg.trainer.policy.megatron_config.tensor_model_parallel_size = tp
cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = pp
cfg.trainer.policy.megatron_config.context_parallel_size = cp
cfg.trainer.policy.megatron_config.expert_model_parallel_size = ep
cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = etp
cfg.trainer.micro_forward_batch_size_per_gpu = 2
cfg.trainer.micro_train_batch_size_per_gpu = 2
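As a sanity check on the max_parallelism parametrization (a hedged sketch, not part of the test): tp * pp * cp = 2 * 2 * 2 accounts for all 8 GPUs with data-parallel size 1, and ep=4 with etp=1 reuses the 4 non-pipeline ranks of each stage, assuming Megatron's usual constraint that etp * ep * expert-DP equals tp * cp * DP.

num_gpus, tp, pp, cp, ep, etp = 8, 2, 2, 2, 4, 1
dp = num_gpus // (tp * pp * cp)
assert dp == 1
assert etp * ep * 1 == tp * cp * dp  # expert data-parallel size assumed to be 1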

@@ -245,11 +255,12 @@ def run_megatron_forward(enable_replay: bool) -> torch.Tensor:

r3_logprobs = run_megatron_forward(enable_replay=True)
no_r3_logprobs = run_megatron_forward(enable_replay=False)

mask = response_mask.bool()

vllm_valid = logprobs_t[mask]
r3_valid = r3_logprobs[mask]
no_r3_valid = no_r3_logprobs[mask]

r3_diff = (vllm_valid - r3_valid).abs()
no_r3_diff = (vllm_valid - no_r3_valid).abs()
print(f"vLLM logprobs - mean: {vllm_valid.mean().item():.6f}, std: {vllm_valid.std().item():.6f}")
@@ -267,8 +278,14 @@


@pytest.mark.megatron
@pytest.mark.skip(reason="Skipping router replay test for now due to size constraints")
def test_forward_backward(ray_init_fixture):
@pytest.mark.skip(reason="Skipping router replay tests for now due to size constraints")
@pytest.mark.parametrize(
"tp,pp,cp,ep,etp,extra_tf_kwargs",
[
pytest.param(2, 2, 2, 4, 1, {"num_layers_in_first_pipeline_stage": 13}, id="max_parallelism"),
],
)
def test_forward_backward(ray_init_fixture, tp, pp, cp, ep, etp, extra_tf_kwargs):
"""
Check that forward_backward with router replay completes without error.
Uses dummy expert routing indices (no vLLM engine needed).
@@ -339,11 +356,13 @@ def test_forward_backward(ray_init_fixture):
training_input.metadata = {"response_length": num_actions}

cfg.trainer.placement.policy_num_gpus_per_node = 8
cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 4
cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1
cfg.trainer.policy.megatron_config.context_parallel_size = 1
cfg.trainer.policy.megatron_config.expert_model_parallel_size = 8
cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = 1
if extra_tf_kwargs is not None:
cfg.trainer.policy.megatron_config.transformer_config_kwargs.update(extra_tf_kwargs)
cfg.trainer.policy.megatron_config.tensor_model_parallel_size = tp
cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = pp
cfg.trainer.policy.megatron_config.context_parallel_size = cp
cfg.trainer.policy.megatron_config.expert_model_parallel_size = ep
cfg.trainer.policy.megatron_config.expert_tensor_parallel_size = etp
cfg.trainer.micro_forward_batch_size_per_gpu = 2
cfg.trainer.micro_train_batch_size_per_gpu = 2
cfg.trainer.policy.megatron_config.moe_enable_routing_replay = True
@@ -354,7 +373,6 @@
cfg=cfg,
)

ray.get(actor_group.async_run_ray_method("pass_through", "setup_per_microbatch_replay_backward"))
ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=training_input))
ray.get(actor_group.async_run_ray_method("pass_through", "optim_step"))
results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=training_input))