Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
dc93a03
replay utils update
devpatelio Mar 17, 2026
d2b2d53
clean
devpatelio Mar 17, 2026
8cb6155
linter
devpatelio Mar 17, 2026
d1712bd
rm the sampke packing flag
devpatelio Mar 17, 2026
3da803a
move cp out of pack_replay_indices for simpler implementation
devpatelio Mar 20, 2026
f7362c3
model wrapper
devpatelio Mar 20, 2026
5d74b89
done:
erictang000 Mar 20, 2026
5855ee8
lint
devpatelio Mar 20, 2026
153a2d0
Fix async script to use megatron
devpatelio Mar 20, 2026
c46287e
Fix fallback logic to account for dense layers for indexing replay in…
devpatelio Mar 20, 2026
33b3083
[train] Add `worker_process_setup_hook` to set mp start method to `sp…
SumanthRH Mar 17, 2026
7066bf4
[CI] Fix `test_inference_engines_generation` after vllm 0.16.0 upgrad…
SumanthRH Mar 18, 2026
7744e69
[train] Make TrainingInputBatch to PAD only to left, hence response t…
CharlieFRuan Mar 18, 2026
bfcd8db
[train] Revert "[train] Add `worker_process_setup_hook` to set mp sta…
SumanthRH Mar 19, 2026
daf5752
[Docs] Add docs on agent integration and step-wise training (#1347)
CharlieFRuan Mar 19, 2026
493387c
[Docs] Small update on docs (#1348)
CharlieFRuan Mar 19, 2026
089ee8b
[train] Add validation for step-wise GeneratorOutput (#1281)
CharlieFRuan Mar 19, 2026
0f18c70
[megatron] rebuild weight conversion tasks per sync to prevent stale …
erictang000 Mar 19, 2026
1a73422
[StepWise] Trivial fix to avg_response_length metric (#1351)
CharlieFRuan Mar 19, 2026
04d29f3
[CI] Make `MultiItemDataset` a global variable after switch to `spawn…
SumanthRH Mar 19, 2026
2619ed4
[train] Add support for LoRA in the new inference codepath (#1329)
SumanthRH Mar 20, 2026
9a0088f
[bug][algorithm] remove incorrect torch.no_grad() for kl in loss (use…
erictang000 Mar 20, 2026
4cf3a48
latest main
devpatelio Mar 20, 2026
4b844c0
Merge remote-tracking branch 'origin/main' into r3-pp-cp
devpatelio Mar 20, 2026
44aab7d
final done
devpatelio Mar 21, 2026
be50ba6
revert test file changes
devpatelio Mar 21, 2026
f7c3086
fix
devpatelio Mar 21, 2026
bf84467
done
devpatelio Mar 21, 2026
0a2fdb2
remove comment
devpatelio Mar 21, 2026
73741a3
rm file
devpatelio Mar 21, 2026
43468a8
Revert "remove comment"
devpatelio Mar 21, 2026
a4b9228
uv.lock bring back
devpatelio Mar 21, 2026
a841448
Delete 1
devpatelio Mar 21, 2026
5def4cb
Revert "done"
devpatelio Mar 21, 2026
266b9c1
done
devpatelio Mar 21, 2026
5628f22
addressed comments
devpatelio Mar 25, 2026
2bc0ee0
lint
devpatelio Mar 25, 2026
fff1c6a
Merge branch 'main' of https://github.com/NovaSky-AI/SkyRL into r3-pp-cp
erictang000 Mar 26, 2026
9463aec
x
erictang000 Mar 26, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions examples/train/router_replay/router_replay_fully_async.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
set -x

# Fully async GRPO training+generation for Qwen2.5-1.5B-Instruct on GSM8K.
# This bash script is copied from examples/async/async_run_gsm8k.sh, except for:
# - running examples.train.fully_async.main_fully_async
# - setting the generator.batched=false.
# - colocate_all=false
# - the various generator configs at the end (http, chat template, etc.)

# uv run examples/train/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
# export WANDB_API_KEY=<your_key_here>
# bash examples/train/router_replay/router_replay_fully_async.sh

# NOTE (sumanthrh): `micro_train_batch_size_per_gpu` and `micro_forward_batch_size_per_gpu` can be tuned

# You can override the default values with e.g.: `NUM_GPUS=1 bash examples/train/fully_async/fully_async_run_gsm8k.sh`.

: "${DATA_DIR:="$HOME/data/gsm8k"}"
: "${NUM_INFERENCE_GPUS:=4}"
: "${NUM_POLICY_GPUS:=4}"
: "${LOGGER:=wandb}" # change to "console" to print to stdout / or use wandb

: "${INFERENCE_BACKEND:=vllm}"

# Fully async specific configuration knobs:
: "${MINI_BATCH_SIZE:=256}"
: "${MAX_STALENESS_STEPS:=4}"
: "${NUM_PARALLEL_GENERATION_WORKERS:=$(( MINI_BATCH_SIZE * (MAX_STALENESS_STEPS + 1) ))}"

TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0

# moonlight16b
MODEL_NAME="moonshotai/Moonlight-16B-A3B-Instruct"

NUM_NODES=1
NUM_GPUS=8

MEGATRON_TP=1
MEGATRON_PP=2
MEGATRON_CP=2
MEGATRON_EP=4
MEGATRON_ETP=1

NUM_INFERENCE_ENGINES=1
INFERENCE_ENGINE_TP=4

# router replay (r3)
ROUTER_REPLAY=true
DISTRIBUTED_EXECUTION_BACKEND="mp"

RUN_NAME=gsm8k-fully-async-moonlight16b-a3b-useTIS_${TIS_TYPE}-maxStale${MAX_STALENESS_STEPS}-numCon${NUM_PARALLEL_GENERATION_WORKERS}-${NUM_POLICY_GPUS}train${NUM_INFERENCE_GPUS}gen_r3

uv run --isolated --extra fsdp -m examples.train.fully_async.main_fully_async \
Comment thread
devpatelio marked this conversation as resolved.
Outdated
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.fully_async.max_staleness_steps=${MAX_STALENESS_STEPS} \
trainer.fully_async.num_parallel_generation_workers=${NUM_PARALLEL_GENERATION_WORKERS} \
trainer.algorithm.advantage_estimator="grpo" \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path=$MODEL_NAME \
trainer.placement.colocate_all=false \
trainer.strategy=fsdp2 \

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This line sets trainer.strategy=fsdp2, but it's later overridden by trainer.strategy=megatron on line 76. Since this script is for a Megatron-based run, this line is redundant and could cause confusion. It should be removed.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Duplicate trainer.strategy setting in example shell script — first value (fsdp2) is a leftover

Line 64 sets trainer.strategy=fsdp2 and line 76 sets trainer.strategy=megatron. The second override wins, but the first is clearly a leftover from copying the async example script. Since the rest of the script configures Megatron-specific parameters (TP/PP/CP/EP), the fsdp2 value on line 64 is incorrect and misleading. A user who copies only the first section of this script would get the wrong strategy.

Suggested change
trainer.strategy=fsdp2 \
trainer.strategy=megatron \
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

trainer.placement.policy_num_gpus_per_node=$NUM_POLICY_GPUS \
trainer.placement.critic_num_gpus_per_node=$NUM_POLICY_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_POLICY_GPUS \
generator.inference_engine.num_engines=$NUM_INFERENCE_GPUS \
generator.inference_engine.distributed_executor_backend=$DISTRIBUTED_EXECUTION_BACKEND \
generator.inference_engine.enable_return_routed_experts=$ROUTER_REPLAY \
generator.inference_engine.tensor_parallel_size=1 \
Comment thread
erictang000 marked this conversation as resolved.
Outdated
trainer.epochs=20 \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=false \
trainer.eval_interval=4 \
trainer.strategy=megatron \
trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.policy.megatron_config.moe_enable_routing_replay=$ROUTER_REPLAY \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=${MINI_BATCH_SIZE} \
trainer.policy_mini_batch_size=${MINI_BATCH_SIZE} \
trainer.micro_forward_batch_size_per_gpu=8 \
trainer.micro_train_batch_size_per_gpu=8 \
trainer.ckpt_interval=10 \
trainer.max_prompt_length=512 \
generator.sampling_params.max_generate_length=1024 \
trainer.policy.optimizer_config.lr=1.0e-6 \
trainer.algorithm.use_kl_loss=true \
generator.inference_engine.backend=$INFERENCE_BACKEND \
generator.inference_engine.run_engines_locally=true \
generator.inference_engine.weight_sync_backend=nccl \
generator.inference_engine.async_engine=true \
generator.batched=false \
environment.env_class=gsm8k \
generator.n_samples_per_prompt=5 \
generator.inference_engine.gpu_memory_utilization=0.8 \
trainer.logger="$LOGGER" \
trainer.project_name="gsm8k-async" \
trainer.run_name=${RUN_NAME} \
trainer.resume_mode=latest \
trainer.ckpt_path="$HOME/ckpts/${RUN_NAME}" \
generator.inference_engine.enforce_eager=true \
$@
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron
NUM_NODES=1
NUM_GPUS=8

MEGATRON_TP=4
MEGATRON_PP=1
MEGATRON_CP=1
MEGATRON_EP=8
MEGATRON_TP=1
MEGATRON_PP=2
MEGATRON_CP=2
MEGATRON_EP=4
MEGATRON_ETP=1

NUM_INFERENCE_ENGINES=1
Expand All @@ -41,6 +41,7 @@ SKYRL_RAY_PG_TIMEOUT_IN_S=300 uv run --isolated --extra megatron --with blobfile
trainer.placement.colocate_all=true \
trainer.strategy=megatron \
trainer.placement.policy_num_nodes=$NUM_NODES \
trainer.policy.megatron_config.transformer_config_kwargs.num_layers_in_last_pipeline_stage=13 \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
generator.inference_engine.num_engines=$NUM_INFERENCE_ENGINES \
generator.inference_engine.tensor_parallel_size=$INFERENCE_ENGINE_TP \
Expand Down
156 changes: 124 additions & 32 deletions skyrl/backends/skyrl_train/utils/replay_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,11 @@ def _remove_left_padding_from_indices(

seq_lens = attention_mask.sum(dim=1)
effective_seq_len = seq_lens.max().item()
sp_world_size = mpu.get_tensor_model_parallel_world_size()
if sp_world_size > 1:
pad_size = (sp_world_size - effective_seq_len % sp_world_size) % sp_world_size
tp_size = mpu.get_tensor_model_parallel_world_size()
cp_size = mpu.get_context_parallel_world_size()
align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size
if align_size > 1:
pad_size = (align_size - effective_seq_len % align_size) % align_size
effective_seq_len += pad_size

batch_size = rollout_expert_indices.shape[0]
Expand Down Expand Up @@ -151,8 +153,6 @@ def _pack_replay_indices(
seqlens_padded = seq_lens + pad_sizes

total_packed_len = int(seqlens_padded.sum().item())
if cp_size > 1:
total_packed_len = total_packed_len // cp_size

packed = torch.zeros(
total_packed_len,
Expand All @@ -164,48 +164,136 @@ def _pack_replay_indices(

seq_lens_cpu = seq_lens.tolist()
seqlens_padded_cpu = seqlens_padded.tolist()
if cp_size > 1:
cp_rank = mpu.get_context_parallel_rank()
offset = 0
for i in range(batch_size):
n = seq_lens_cpu[i]
mask = attention_mask[i].bool()
d = rollout_expert_indices[i, mask]
if cp_size > 1:
chunk_size = seqlens_padded_cpu[i] // cp_size
start = cp_rank * chunk_size
end = min(start + chunk_size, n)
valid_len = max(0, end - start)
if valid_len > 0:
packed[offset : offset + valid_len] = d[start:end]
offset += chunk_size
packed[offset : offset + n] = d
offset += seqlens_padded_cpu[i]

if cp_size > 1:
cp_rank = mpu.get_context_parallel_rank()
out = torch.zeros(
total_packed_len // cp_size,
num_layers,
topk,
dtype=packed.dtype,
device=packed.device,
)
src_offset = 0
dst_offset = 0
for i in range(batch_size):
seqlen_padded_i = seqlens_padded_cpu[i]
seqlen_per_cp = seqlen_padded_i // cp_size
half = seqlen_per_cp // 2
out[dst_offset : dst_offset + half] = packed[
src_offset + half * cp_rank : src_offset + half * (cp_rank + 1)
]
back_start = src_offset + seqlen_padded_i - half * (cp_rank + 1)
back_end = src_offset + seqlen_padded_i - half * cp_rank
out[dst_offset + half : dst_offset + seqlen_per_cp] = packed[back_start:back_end]
src_offset += seqlen_padded_i
dst_offset += seqlen_per_cp
packed = out

return packed.unsqueeze(0) # [1, packed_len_per_cp, layers, topk]


def _get_current_pp_stage_layer_range(model_config) -> tuple[int, int]:
"""Return the current PP rank's transformer-layer range.
Comment thread
erictang000 marked this conversation as resolved.
Outdated

Prefer Megatron's own helpers so replay indexing stays aligned with the
actual model partition, including embedding/loss pipeline accounting.
"""
import megatron.core.parallel_state as mpu
from megatron.core.transformer.transformer_block import get_num_layers_to_build
from megatron.core.transformer.transformer_layer import get_transformer_layer_offset

pp_rank = mpu.get_pipeline_model_parallel_rank()

if get_num_layers_to_build is not None:
Comment thread
devpatelio marked this conversation as resolved.
Outdated
return get_transformer_layer_offset(model_config), get_num_layers_to_build(model_config, pp_rank=pp_rank)

pp_size = mpu.get_pipeline_model_parallel_world_size()

total_layers = model_config.num_layers
first_stage_layers = getattr(model_config, "num_layers_in_first_pipeline_stage", None)
last_stage_layers = getattr(model_config, "num_layers_in_last_pipeline_stage", None)

if pp_size <= 1:
return 0, total_layers

if first_stage_layers is None and last_stage_layers is None:
assert (
total_layers % pp_size == 0
), "For even pipelineing, num_layers should be divisible by pipeline_model_parallel_size"
pp_layers = total_layers // pp_size
return pp_rank * pp_layers, pp_layers

next_n_pp_layers = total_layers
next_n_pp_stages = pp_size

if first_stage_layers is not None:
next_n_pp_layers -= first_stage_layers
next_n_pp_stages -= 1

if last_stage_layers is not None:
next_n_pp_layers -= last_stage_layers
next_n_pp_stages -= 1

if next_n_pp_stages > 0:
assert (
next_n_pp_layers % next_n_pp_stages == 0
), "Uneven pipelineing, not divisible by remaining pipeline stages"
next_n_pp_layers = next_n_pp_layers // next_n_pp_stages
else:
next_n_pp_layers = 0

if pp_rank == 0 and first_stage_layers is not None:
return 0, first_stage_layers

if pp_rank == pp_size - 1 and last_stage_layers is not None:
if first_stage_layers is not None:
start = first_stage_layers + (next_n_pp_layers * (pp_size - 2))
else:
packed[offset : offset + n] = d
offset += seqlens_padded_cpu[i]
start = next_n_pp_layers * (pp_size - 1)
return start, last_stage_layers

return packed.unsqueeze(0) # [1, total_packed_len, layers, topk]
if first_stage_layers is not None:
return first_stage_layers + (next_n_pp_layers * (pp_rank - 1)), next_n_pp_layers
return next_n_pp_layers * pp_rank, next_n_pp_layers


def setup_per_microbatch_replay_forward(
rollout_expert_indices: torch.Tensor,
attention_mask: torch.Tensor,
model_config,
use_sample_packing: bool = False,
) -> None:
"""Set up RouterReplay for a single micro-batch, aligning indices
with the token layout that the MoE layer sees.
with the left-padding-removed token layout that the MoE layer sees.

Handles context parallelism: when CP > 1, the sequence is split into
2*cp_size chunks with each CP rank receiving a front chunk and a back
chunk (for causal-mask load balancing). Replay indices are split using
the same pattern so they stay aligned with the tokens each rank sees.

Handles sequence parallelism: when TP > 1, the sequence is split across
TP ranks, so each rank's MoE router only sees its local chunk of tokens.

Handles sample packing: when use_sample_packing is True, sequences are
concatenated into one packed sequence with per-sample alignment padding.
The replay indices must follow this same packed layout.

Handles dense-layer mismatch: DeepSeek V3-style models have dense FFN
layers before the MoE layers. vLLM reports routing indices for ALL
layers before the MoE layers. vLLM reports routing indices for ALL
transformer layers, but Megatron only has RouterReplay instances for MoE
layers. We use each instance's global layer_number (set by the patched
layers. We use each instance's global layer_number (set by the patched
TopKRouter.set_layer_number) to index into the correct slice of the data.

Handles pipeline parallelism: when PP > 1, the sequence is split across
PP ranks, so each rank only sees its local RouterReplay instances. In cases
where the number of local RouterReplay instances does not match the local
layer count, indicating that the model has dense layers before MoE layers,
we use the global layer_number to index into the correct slice of the data.

"""
import megatron.core.parallel_state as mpu
from megatron.core.transformer.moe.router_replay import (
Expand All @@ -220,6 +308,7 @@ def setup_per_microbatch_replay_forward(
else:
aligned = _remove_left_padding_from_indices(rollout_expert_indices, attention_mask)

# TP splitting: sequence parallelism across the tensor model parallel region
tp_size = mpu.get_tensor_model_parallel_world_size()
if tp_size > 1:
tp_rank = mpu.get_tensor_model_parallel_rank()
Expand All @@ -228,26 +317,29 @@ def setup_per_microbatch_replay_forward(
aligned = aligned[:, tp_rank * chunk_size : (tp_rank + 1) * chunk_size, :, :]

per_layer_data = _split_replay_indices(aligned)
num_layers_in_data = len(per_layer_data)
global_num_layers_in_data = len(per_layer_data)
instances = RouterReplay.global_router_replay_instances
num_instances = len(instances)

if num_layers_in_data == num_instances:
RouterReplay.set_replay_data(per_layer_data)
local_layer_offset, local_num_layers = _get_current_pp_stage_layer_range(model_config)

if local_num_layers == num_instances:
Comment thread
erictang000 marked this conversation as resolved.
local_per_layer_data = per_layer_data[local_layer_offset : local_layer_offset + local_num_layers]
RouterReplay.set_replay_data(local_per_layer_data)
Comment thread
erictang000 marked this conversation as resolved.
else:
# Dense-layer mismatch: map each MoE router to its global layer index.
# Prefer the patched layer_number; fall back to offset-based mapping
# (assumes dense layers precede MoE layers).
for i, router_instance in enumerate(instances):
for local_router_idx, router_instance in enumerate(instances):
layer_number = getattr(router_instance, "layer_number", None)
if layer_number is not None:
layer_idx = layer_number - 1 # layer_number is 1-based
else:
layer_idx = i + (num_layers_in_data - num_instances)
if layer_idx < 0 or layer_idx >= num_layers_in_data:
layer_idx = local_layer_offset + local_router_idx
Comment thread
devpatelio marked this conversation as resolved.
Outdated
if layer_idx < 0 or layer_idx >= global_num_layers_in_data:
raise ValueError(
f"Router replay layer index {layer_idx} out of range "
f"for data with {num_layers_in_data} layers "
f"for data with {global_num_layers_in_data} layers "
f"({num_instances} router instances)"
)
router_instance.set_target_indices(per_layer_data[layer_idx])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@ def forward_step(batch_iter, model):
rollout_expert_indices = batch.pop("rollout_expert_indices", None)
if rollout_expert_indices is not None:
setup_per_microbatch_replay_forward(
rollout_expert_indices, batch["attention_mask"], use_sample_packing=self.use_sample_packing
rollout_expert_indices,
batch["attention_mask"],
model_config=get_model_config(model),
use_sample_packing=self.use_sample_packing,
)

sequences = batch["sequences"]
Expand Down Expand Up @@ -372,7 +375,10 @@ def forward_step(batch_iter, model):
rollout_expert_indices = batch.pop("rollout_expert_indices", None)
if rollout_expert_indices is not None:
setup_per_microbatch_replay_forward(
rollout_expert_indices, batch["attention_mask"], use_sample_packing=self.use_sample_packing
rollout_expert_indices,
batch["attention_mask"],
model_config=get_model_config(model),
use_sample_packing=self.use_sample_packing,
)

sequences = batch["sequences"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,9 @@ def init_configs(

self.strategy.hf_config = hf_config
self.tokenizer = tokenizer
self.enable_router_replay = megatron_config.moe_enable_routing_replay
self.enable_router_replay = transformer_config_kwargs.get(
Comment thread
devpatelio marked this conversation as resolved.
Outdated
"moe_enable_routing_replay", megatron_config.moe_enable_routing_replay
)

@devin-ai-integration devin-ai-integration Bot Mar 20, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Validation for moe_enable_routing_replay doesn't check transformer_config_kwargs, creating a gap with the worker's logic

The PR changes megatron_worker.py:377-379 to determine self.enable_router_replay from transformer_config_kwargs.get("moe_enable_routing_replay", megatron_config.moe_enable_routing_replay), but the validation in skyrl/train/utils/utils.py:205 only checks cfg.trainer.policy.megatron_config.moe_enable_routing_replay (the top-level config field). If a user sets moe_enable_routing_replay=True only in transformer_config_kwargs (as the tests do at tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py:237-239), the worker enables replay but the validation skips the assertion that enable_return_routed_experts must also be True. This could lead to router replay being silently ineffective (since rollout_expert_indices would be None without enable_return_routed_experts).

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.


def configure_lora(self, lora_config, lora_type: Optional[str] = "lora"):
if lora_type == "lora":
Expand Down
Loading
Loading