Skip to content
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
a802773
previous code
erictang000 Mar 4, 2026
21b44de
lint
erictang000 Mar 4, 2026
23fdc45
make opus take a pass at test + plumbing fully thru generator
erictang000 Mar 4, 2026
8daac59
updated test utils and file to support rollout replay indices
devpatelio Mar 4, 2026
647426f
add helper functions for router visibility and megatron testing, succ…
devpatelio Mar 4, 2026
d4b753f
linter
devpatelio Mar 4, 2026
f1b9c53
worked w opus to get forward pass logprob diff lower with replay + ru…
erictang000 Mar 4, 2026
8a8fa70
add test for forward backward and fix behavior
erictang000 Mar 4, 2026
410995a
working for qwen but not moonlight... debugging moonlight
erictang000 Mar 6, 2026
93eee65
x
erictang000 Mar 6, 2026
9c716a1
fixed test for moonlight by enforcing fused attn
devpatelio Mar 9, 2026
097d2ad
Merge branch 'r3' of https://github.com/NovaSky-AI/SkyRL into HEAD
devpatelio Mar 9, 2026
6de7d5c
x
devpatelio Mar 9, 2026
591af9b
x
devpatelio Mar 9, 2026
acb35ec
clean up
erictang000 Mar 10, 2026
5ad9426
rename var and clean up
erictang000 Mar 11, 2026
909f5ad
Merge branch 'main' of https://github.com/erictang000/SkyRL into HEAD
erictang000 Mar 12, 2026
d2cd56a
Merge branch 'main' of https://github.com/erictang000/SkyRL into HEAD
erictang000 Mar 12, 2026
4a60d4c
cleaning up
erictang000 Mar 12, 2026
205da19
lint
erictang000 Mar 12, 2026
f78dc75
cleaning up
erictang000 Mar 12, 2026
43297f0
x
erictang000 Mar 12, 2026
0468c37
x
erictang000 Mar 12, 2026
0dfed8d
Merge branch 'main' of https://github.com/erictang000/SkyRL into r3
erictang000 Mar 12, 2026
4878ed7
x
erictang000 Mar 13, 2026
a5babb4
fix bug not propagating router indices to fwd pass
erictang000 Mar 13, 2026
7c11d73
x
erictang000 Mar 13, 2026
ac1fb79
add supported settings to cfg validation
erictang000 Mar 13, 2026
38b15a1
add docs'
erictang000 Mar 13, 2026
465ec77
docs
erictang000 Mar 13, 2026
e6af1a0
remove legacy
erictang000 Mar 13, 2026
951bc24
x
erictang000 Mar 13, 2026
2f6c778
x
erictang000 Mar 13, 2026
e3c965c
ur right devin
erictang000 Mar 13, 2026
7681a33
Merge branch 'main' of https://github.com/erictang000/SkyRL into r3
erictang000 Mar 13, 2026
bd69614
add dapo moonlight with r3
erictang000 Mar 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions examples/train/router_replay/run_moonlight16b_router_replay.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
set -x

# Colocated GRPO training+generation for Moonlight-16B-A3B-Instruct on GSM8K with Megatron with router replay (r3)
# Runs on 1 node of 8xH100s
#
# Usage:
#   uv run examples/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
#   export WANDB_API_KEY=<your_key_here>
#   bash examples/train/router_replay/run_moonlight16b_router_replay.sh [extra overrides...]
#
# Any extra arguments given to this script are appended to the training
# command unchanged, after all of the settings below.

DATA_DIR="$HOME/data/gsm8k"
LOGGER="wandb" # change to "console" to print to stdout
MODEL_NAME="moonshotai/Moonlight-16B-A3B-Instruct"

INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron

# Cluster shape used for the policy placement.
NUM_NODES=1
NUM_GPUS=8

# Megatron parallelism degrees (tensor / pipeline / context / expert / expert-tensor).
MEGATRON_TP=4
MEGATRON_PP=1
MEGATRON_CP=1
MEGATRON_EP=8
MEGATRON_ETP=1

# Inference engine layout: a single engine sharded across all 8 GPUs.
NUM_INFERENCE_ENGINES=1
INFERENCE_ENGINE_TP=8

# flash attn is not supported for moonlight16b since it is a DeepSeekV3 like model, and uses Multi-Head Latent Attention (MLA)
# https://github.com/NVIDIA/TransformerEngine/blob/483d9594fb070f62966f6a12ed6c90942310b48e/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L483
FLASH_ATTN=false

# router replay (r3)
# The same flag drives both sides: the inference engine is asked to return the
# routed experts, and Megatron is asked to enable routing replay.
ROUTER_REPLAY=true
DISTRIBUTED_EXECUTION_BACKEND="mp"

# SKYRL_RAY_PG_TIMEOUT_IN_S is set for this invocation only.
# "$@" is quoted so that user-supplied overrides containing spaces or glob
# characters reach the entrypoint as the exact arguments that were typed.
SKYRL_RAY_PG_TIMEOUT_IN_S=300 uv run --isolated --extra megatron --with blobfile -m skyrl.train.entrypoints.main_base \
  data.train_data="['$DATA_DIR/train.parquet']" \
  data.val_data="['$DATA_DIR/validation.parquet']" \
  trainer.algorithm.advantage_estimator="grpo" \
  trainer.policy.model.path="$MODEL_NAME" \
  trainer.placement.colocate_all=true \
  trainer.strategy=megatron \
  trainer.placement.policy_num_nodes=$NUM_NODES \
  trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
  generator.inference_engine.num_engines=$NUM_INFERENCE_ENGINES \
  generator.inference_engine.tensor_parallel_size=$INFERENCE_ENGINE_TP \
  trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
  trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
  trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
  trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
  trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
  trainer.policy.megatron_config.moe_enable_routing_replay=$ROUTER_REPLAY \
  generator.inference_engine.distributed_executor_backend=$DISTRIBUTED_EXECUTION_BACKEND \
  generator.inference_engine.enable_return_routed_experts=$ROUTER_REPLAY \
  trainer.use_sample_packing=true \
  trainer.flash_attn=$FLASH_ATTN \
  trainer.epochs=20 \
  trainer.eval_batch_size=1024 \
  trainer.eval_before_train=false \
  trainer.eval_interval=5 \
  trainer.update_epochs_per_batch=1 \
  trainer.train_batch_size=256 \
  trainer.policy_mini_batch_size=32 \
  trainer.micro_forward_batch_size_per_gpu=4 \
  trainer.micro_train_batch_size_per_gpu=4 \
  trainer.ckpt_interval=100 \
  trainer.max_prompt_length=512 \
  generator.sampling_params.max_generate_length=1024 \
  trainer.policy.optimizer_config.lr=1.0e-6 \
  trainer.algorithm.use_kl_loss=false \
  generator.inference_engine.backend=$INFERENCE_BACKEND \
  generator.inference_engine.run_engines_locally=true \
  generator.inference_engine.weight_sync_backend=nccl \
  generator.inference_engine.async_engine=true \
  generator.batched=true \
  environment.env_class=gsm8k \
  generator.n_samples_per_prompt=5 \
  generator.inference_engine.gpu_memory_utilization=0.6 \
  trainer.logger="$LOGGER" \
  trainer.project_name="gsm8k_router_replay" \
  trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_moonlight16b-a3b_with_router_replay" \
  trainer.resume_mode=null \
  trainer.ckpt_path="$HOME/ckpts/gsm8k_megatron_ckpt" \
  "$@"
5 changes: 5 additions & 0 deletions skyrl/backends/skyrl_train/inference_engines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class InferenceEngineOutput(TypedDict):
response_ids: List[List[int]]
stop_reasons: List[str]
response_logprobs: Optional[List[List[float]]]
rollout_expert_indices: Optional[List[List[List[List[int]]]]]  # [batch_size, seq_len, layer_num, topk]


class InferenceEngineInterface(ABC):
Expand Down Expand Up @@ -65,6 +66,7 @@ async def sample(
all_responses = []
all_stop_reasons = []
all_response_logprobs = []
all_rollout_expert_indices = []

for _ in range(num_samples):
input_batch: InferenceEngineInput = {
Expand All @@ -81,12 +83,15 @@ async def sample(
all_stop_reasons.append(output["stop_reasons"][0])
if output.get("response_logprobs") is not None:
all_response_logprobs.append(output["response_logprobs"][0])
if output.get("rollout_expert_indices") is not None:
all_rollout_expert_indices.append(output["rollout_expert_indices"][0])

return {
"response_ids": all_response_ids,
"responses": all_responses,
"stop_reasons": all_stop_reasons,
"response_logprobs": all_response_logprobs if all_response_logprobs else None,
"rollout_expert_indices": all_rollout_expert_indices if all_rollout_expert_indices else None,
}

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,10 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu
stop_reasons: list[str] = [""] * n
response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)]
response_ids: List[List[int]] = [[] for _ in range(n)]
rollout_expert_indices: List[Optional[List[List[List[int]]]]] = [None for _ in range(n)]
# a bit hacky for now
add_resp_logprobs = False
add_rollout_expert_indices = False

for indices, result in zip(indices_list, results):
for local_idx, original_idx in enumerate(indices):
Expand All @@ -166,12 +168,16 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu
if result.get("response_logprobs", None):
add_resp_logprobs = True
response_logprobs[original_idx] = result["response_logprobs"][local_idx]
if result.get("rollout_expert_indices", None):
add_rollout_expert_indices = True
rollout_expert_indices[original_idx] = result["rollout_expert_indices"][local_idx]

return InferenceEngineOutput(
responses=responses,
stop_reasons=stop_reasons,
response_ids=response_ids,
response_logprobs=response_logprobs if add_resp_logprobs else None,
rollout_expert_indices=rollout_expert_indices if add_rollout_expert_indices else None,
)

def _select_engine_idx(self, session_id: Optional[Union[str, int]] = None) -> int:
Expand Down Expand Up @@ -267,6 +273,7 @@ async def _generate_single_with_retry(
# 2. Initialize fields we want to accumulate or update in each loop iteration
accum_response_ids: List[int] = []
accum_response_logprobs: List[float] = []
rollout_expert_indices: List[List[List[int]]] = None
stop_reason: str = "abort"

# We only use it if generation is completed in one turn to maintain original behavior with no retry.
Expand Down Expand Up @@ -302,6 +309,10 @@ async def _generate_single_with_retry(
new_response_logprobs_list: Optional[List[List[float]]] = partial_response.get("response_logprobs", None)
if new_response_logprobs_list is not None and len(new_response_logprobs_list) > 0:
new_response_logprobs = new_response_logprobs_list[0]
new_rollout_expert_indices: Optional[List[List[List[int]]]] = None
new_rollout_expert_indices_list = partial_response.get("rollout_expert_indices", None)
if new_rollout_expert_indices_list is not None and len(new_rollout_expert_indices_list) > 0:
new_rollout_expert_indices = new_rollout_expert_indices_list[0]

# 3.4 Aborted without generating tokens, so partial_response is useless.
if stop_reason == "abort" and len(new_response_ids) == 0:
Expand All @@ -311,6 +322,8 @@ async def _generate_single_with_retry(
accum_response_ids.extend(new_response_ids)
if new_response_logprobs is not None:
accum_response_logprobs.extend(new_response_logprobs)
if new_rollout_expert_indices is not None:
rollout_expert_indices = new_rollout_expert_indices
num_turns += 1

# 4. Build the final response and return.
Expand All @@ -323,6 +336,7 @@ async def _generate_single_with_retry(
stop_reasons=[stop_reason],
response_ids=[accum_response_ids],
response_logprobs=[accum_response_logprobs] if len(accum_response_logprobs) > 0 else None,
rollout_expert_indices=([rollout_expert_indices] if rollout_expert_indices is not None else None),
)

async def _chat_completion_with_retry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def create_ray_wrapped_inference_engines(
rope_scaling: Dict[str, Any] = {},
rope_theta: float | None = None,
enable_ray_prometheus_stats: bool = False,
enable_return_routed_experts: bool = False,
served_model_name: str | None = None,
distributed_executor_backend: str = "ray",
) -> List[InferenceEngineInterface]:
Expand Down Expand Up @@ -281,6 +282,7 @@ def create_ray_wrapped_inference_engines(
max_num_seqs=max_num_seqs,
max_logprobs=1, # only need chosen-token logprobs
enable_ray_prometheus_stats=enable_ray_prometheus_stats,
enable_return_routed_experts=enable_return_routed_experts,
**dp_kwargs,
**engine_init_kwargs,
**lora_kwargs,
Expand Down
21 changes: 21 additions & 0 deletions skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def _postprocess_outputs(self, outputs):
stop_reasons: List[str] = []
response_ids: List[List[int]] = []
response_logprobs: Optional[List[List[float]]] = []
rollout_expert_indices: Optional[List[List[List[List[int]]]]] = []

for output in outputs:
# TODO(tgriggs): Support n>1 sampling.
Expand All @@ -170,14 +171,26 @@ def _postprocess_outputs(self, outputs):
del token_logprobs
response_logprobs.append(_logprobs)

_routed_experts = None
if resp.routed_experts is not None:
if hasattr(resp.routed_experts, "tolist"):
_routed_experts = resp.routed_experts.tolist()
else:
_routed_experts = resp.routed_experts
rollout_expert_indices.append(_routed_experts)

if len(response_logprobs) and response_logprobs[0] is None:
response_logprobs = None # hack: assume uniform sampling params

if len(rollout_expert_indices) > 0 and rollout_expert_indices[0] is None:
rollout_expert_indices = None # hack: assume uniform sampling params

return InferenceEngineOutput(
responses=responses,
stop_reasons=stop_reasons,
response_ids=response_ids,
response_logprobs=response_logprobs,
rollout_expert_indices=rollout_expert_indices,
)

def _get_engine(self):
Expand Down Expand Up @@ -335,6 +348,14 @@ def _create_engine(self, *args, **kwargs):
enable_log_requests = kwargs.pop("enable_log_requests", False)
max_log_len = kwargs.pop("max_log_len", None)

# Log if enable_return_routed_experts is being passed
if "enable_return_routed_experts" in kwargs:
logger.info(
f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs"
)
else:
logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs")

if version.parse(vllm.__version__) >= version.parse("0.10.0"):
engine_args = vllm.AsyncEngineArgs(enable_log_requests=enable_log_requests, **kwargs)
else:
Expand Down
1 change: 1 addition & 0 deletions skyrl/backends/skyrl_train/training_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ class TrainingInput(TypedDict, total=False):
kl: Float[torch.Tensor, "batch_size seq_len"]
rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]]
rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]]
rollout_expert_indices: Optional[Integer[torch.Tensor, "batch_size seq_len layer_num topk"]]


class TrainingInputBatch(TensorBatch[TrainingInput]):
Expand Down
Loading
Loading