Merged

36 commits
a802773
previous code
erictang000 Mar 4, 2026
21b44de
lint
erictang000 Mar 4, 2026
23fdc45
make opus take a pass at test + plumbing fully thru generator
erictang000 Mar 4, 2026
8daac59
updated test utils and file to support rollout replay indices
devpatelio Mar 4, 2026
647426f
add helper functions for router visibility and megatron testing, succ…
devpatelio Mar 4, 2026
d4b753f
linter
devpatelio Mar 4, 2026
f1b9c53
worked w opus to get forward pass logprob diff lower with replay + ru…
erictang000 Mar 4, 2026
8a8fa70
add test for forward backward and fix behavior
erictang000 Mar 4, 2026
410995a
working for qwen but not moonlight... debugging moonlight
erictang000 Mar 6, 2026
93eee65
x
erictang000 Mar 6, 2026
9c716a1
fixed test for moonlight by enforcing fused attn
devpatelio Mar 9, 2026
097d2ad
Merge branch 'r3' of https://github.com/NovaSky-AI/SkyRL into HEAD
devpatelio Mar 9, 2026
6de7d5c
x
devpatelio Mar 9, 2026
591af9b
x
devpatelio Mar 9, 2026
acb35ec
clean up
erictang000 Mar 10, 2026
5ad9426
rename var and clean up
erictang000 Mar 11, 2026
909f5ad
Merge branch 'main' of https://github.com/erictang000/SkyRL into HEAD
erictang000 Mar 12, 2026
d2cd56a
Merge branch 'main' of https://github.com/erictang000/SkyRL into HEAD
erictang000 Mar 12, 2026
4a60d4c
cleaning up
erictang000 Mar 12, 2026
205da19
lint
erictang000 Mar 12, 2026
f78dc75
cleaning up
erictang000 Mar 12, 2026
43297f0
x
erictang000 Mar 12, 2026
0468c37
x
erictang000 Mar 12, 2026
0dfed8d
Merge branch 'main' of https://github.com/erictang000/SkyRL into r3
erictang000 Mar 12, 2026
4878ed7
x
erictang000 Mar 13, 2026
a5babb4
fix bug not propagating router indices to fwd pass
erictang000 Mar 13, 2026
7c11d73
x
erictang000 Mar 13, 2026
ac1fb79
add supported settings to cfg validation
erictang000 Mar 13, 2026
38b15a1
add docs'
erictang000 Mar 13, 2026
465ec77
docs
erictang000 Mar 13, 2026
e6af1a0
remove legacy
erictang000 Mar 13, 2026
951bc24
x
erictang000 Mar 13, 2026
2f6c778
x
erictang000 Mar 13, 2026
e3c965c
ur right devin
erictang000 Mar 13, 2026
7681a33
Merge branch 'main' of https://github.com/erictang000/SkyRL into r3
erictang000 Mar 13, 2026
bd69614
add dapo moonlight with r3
erictang000 Mar 13, 2026
39 changes: 33 additions & 6 deletions docs/content/docs/algorithms/off_policy_correction.mdx
@@ -17,20 +17,26 @@ SkyRL provides built-in utilities for correcting off-policy drift from trainer/i
We recommend adding the following configs, in order, to your training runs to help address off-policy drift:

```yaml
# For dense models, we recommend trying basic TIS correction first
trainer.algorithm.off_policy_correction.tis_ratio_type="token"
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=2.0

# for MoE models, enabling router replay (R3) to fix the source of train/infer mismatch is recommended
trainer.policy.megatron_config.moe_enable_routing_replay=True
generator.inference_engine.enable_return_routed_experts=True
generator.inference_engine.distributed_executor_backend="mp" # this is temporarily needed for vLLM, since routed experts cause issues with the ray backend.

# The following masking strategies can additionally help mitigate off-policy drift, especially from sources other than train/infer mismatch
# geometric sequence masking - tune geo_mask_high/geo_mask_low as needed
trainer.algorithm.off_policy_correction.sequence_mask_metric="geometric"
trainer.algorithm.off_policy_correction.geo_mask_high=1.01
trainer.algorithm.off_policy_correction.geo_mask_low=0.99

# token masking (icepop): tune token_mask_is_threshold_low/high
trainer.algorithm.off_policy_correction.token_mask_is_threshold_low=0.5
trainer.algorithm.off_policy_correction.token_mask_is_threshold_high=2.0

# outlier based sequence masking: stacks on top of geometric sequence masking
trainer.algorithm.off_policy_correction.outlier_token_is_threshold_low=1e-4
trainer.algorithm.off_policy_correction.outlier_token_is_threshold_high=100
```
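
For intuition, here is a minimal sketch of what the geometric sequence mask above computes, under assumed tensor names (not SkyRL's actual implementation): the per-sequence geometric mean of token-level importance ratios is compared against `geo_mask_low`/`geo_mask_high`, and out-of-range sequences are dropped from the loss.

```python
import torch

def geometric_sequence_mask(logp_train, logp_rollout, response_mask,
                            geo_mask_low=0.99, geo_mask_high=1.01):
    """Drop sequences whose geometric-mean IS ratio drifts out of range.

    logp_train, logp_rollout, response_mask: [batch, seq_len] tensors;
    response_mask is 1 on response tokens, 0 elsewhere.
    Returns a [batch] 0/1 mask to multiply into the per-sequence loss.
    """
    log_ratio = (logp_train - logp_rollout) * response_mask
    n_tokens = response_mask.sum(dim=-1).clamp(min=1)
    # geometric mean of per-token ratios == exp(mean of per-token log-ratios)
    geo_ratio = torch.exp(log_ratio.sum(dim=-1) / n_tokens)
    keep = (geo_ratio >= geo_mask_low) & (geo_ratio <= geo_mask_high)
    return keep.float()
```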
@@ -125,10 +131,31 @@ policies. To mitigate this, the max staleness of trajectories can be tuned to pr
Mini-batching results in off-policy updates, which can be clamped within an acceptable range by the common dual-clip formulation of the PPO loss. Tuning the number of mini-batches per training batch
can impact convergence of RL runs and affect whether corrections like routing replay and masking are needed.
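
For reference, a minimal sketch of the dual-clip PPO loss mentioned above, with assumed tensor names (not SkyRL's actual implementation):

```python
import torch

def dual_clip_ppo_loss(logp_new, logp_old, advantages,
                       eps_low=0.2, eps_high=0.28, clip_ratio_c=10.0):
    """Dual-clip PPO: standard asymmetric PPO clipping, plus a floor of
    clip_ratio_c * A for negative advantages to bound off-policy updates."""
    ratio = torch.exp(logp_new - logp_old)
    surr = torch.min(
        ratio * advantages,
        torch.clamp(ratio, 1.0 - eps_low, 1.0 + eps_high) * advantages,
    )
    # dual clip: when A < 0, the surrogate is additionally bounded below by c * A
    surr = torch.where(advantages < 0,
                       torch.max(surr, clip_ratio_c * advantages),
                       surr)
    return -surr.mean()
```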

# Routing Replay

SkyRL supports rollout routing replay (R3), first introduced by [Ma et al.](https://arxiv.org/pdf/2510.11370), to help eliminate trainer/inference mismatch for MoE models at the source. R3 records the expert
routing decisions of MoE layers at inference time and replays the same per-layer decisions at training time, reducing the logprob mismatch between the two stacks.

```yaml
generator:
inference_engine:
enable_return_routed_experts: True # pass through argument to vLLM
distributed_executor_backend: "mp" # temporarily needed to work around hanging issues with other backends
...
trainer:
policy:
megatron_config:
moe_enable_routing_replay: True # enables Megatron native RoutingReplay feature
```

To enable rollout router replay, set `generator.inference_engine.enable_return_routed_experts=True`, `trainer.policy.megatron_config.moe_enable_routing_replay=True`, and use the `mp` `distributed_executor_backend` for vLLM. Note that
R3 does induce additional training bias when mini-batching, since routing decisions are fixed across all mini-batches in a training batch. However, it has been shown to be important for stabilizing large-scale MoE training, particularly
for models adopting a DeepSeek-V3-like architecture (notably the GLM family), due to their use of sigmoid-based affinity scoring instead of softmax for top-k routing.
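
Conceptually, the trainer-side router reuses the rollout's top-k expert choices instead of recomputing them, while gate weights are still taken from trainer-side scores so gradients flow. A schematic sketch with hypothetical names (shown with a softmax gate for simplicity; DeepSeek-V3-style models use sigmoid affinities, and the real plumbing lives in vLLM's routed-experts output and Megatron's native RoutingReplay feature):

```python
import torch

def route_with_replay(router_logits, top_k=2, replay_indices=None):
    """Top-k MoE routing that can replay rollout-time expert choices.

    router_logits:  [num_tokens, num_experts] trainer-side router scores
    replay_indices: [num_tokens, top_k] expert ids recorded at inference
                    time, or None for normal routing.
    """
    if replay_indices is None:
        gate_scores, indices = torch.topk(router_logits, top_k, dim=-1)
    else:
        # replay: reuse the rollout's expert choices, but recompute gate
        # weights from the trainer's logits so gradients still flow
        indices = replay_indices
        gate_scores = torch.gather(router_logits, -1, indices)
    return torch.softmax(gate_scores, dim=-1), indices
```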

# Algorithmic Off Policy Correction

In the previous sections, we described some reasons why off-policy drift can occur, and some ways to mitigate it (e.g., batch invariant kernels, routing replay). However,
these solutions come with tradeoffs (slower inference for batch invariant kernels), and are not sufficient to address all sources of drift, like fully async RL.

Recent works ([Liu et al. 2025](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda), [Yao et al. 2025](https://fengyao.notion.site/off-policy-rl))
have proposed additional techniques for off-policy correction. In this section, we describe these techniques and how to enable them in SkyRL.
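
As an example, the token-level TIS correction enabled at the top of this page (`tis_ratio_type="token"`) amounts to reweighting each token's loss by a truncated importance ratio between trainer and rollout logprobs. A minimal sketch under assumed names:

```python
import torch

def token_tis_weights(logp_train, logp_rollout, clip_high=2.0):
    """Per-token truncated importance sampling (TIS) weights.

    logp_train:   trainer-side logprobs of the sampled tokens
    logp_rollout: logprobs recorded by the inference engine at rollout time
    clip_high:    corresponds to token_tis_ratio_clip_high
    """
    ratio = torch.exp(logp_train - logp_rollout)
    # truncate from above to bound variance from train/infer mismatch;
    # detach so the weight acts as a fixed correction, not a gradient path
    return torch.clamp(ratio, max=clip_high).detach()
```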
6 changes: 6 additions & 0 deletions docs/content/docs/configuration/config.mdx
@@ -164,6 +164,8 @@ megatron_config:
expert_model_parallel_size: 1
expert_tensor_parallel_size: null

moe_enable_routing_replay: False

ddp_config: # pass-through config to Megatron's `DistributedDataParallelConfig` object
# https://github.com/NVIDIA/Megatron-LM/blob/core_r0.13.0/megatron/core/distributed/distributed_data_parallel_config.py#L8
...
@@ -203,6 +205,8 @@ Some rules for configuring these parameters:
- `world_size % (pp_size * ep_size * etp_size) == 0`
- This means that `ep_size * etp_size` can scale independently of `tp_size * cp_size`, and can go across data parallel ranks.

- `moe_enable_routing_replay`: Whether to enable Megatron router replay. Used together with `generator.inference_engine.enable_return_routed_experts` to enable R3.

<Callout type="warn">
`optimizer_config_kwargs.use_precision_aware_optimizer=true` can cause checkpointing to fail. See: https://github.com/nvidia/megatron-lm/issues/1820. We recommend leaving this setting to `false`.
</Callout>
@@ -631,6 +635,7 @@ generator:
max_num_seqs: 1024
vllm_v1_disable_multiproc: true
remote_urls: []
enable_return_routed_experts: false
distributed_executor_backend: "ray" # "mp", "ray"
engine_init_kwargs: {}
override_existing_update_group: "auto" # "auto", "enable", "disable"
@@ -709,6 +714,7 @@ For more details on how different placement options work, please refer to the [p
- `generator.inference_engine.max_num_batched_tokens`: Continuous batching parameter for vLLM. Maximum number of tokens to pack into a batch.
- `generator.inference_engine.enforce_eager`: Whether to disable CUDA graphs. Default is `true` for stability. Set to `false` for higher performance, but this may affect convergence for long-running or long-context training jobs.
- `generator.inference_engine.enable_ray_prometheus_stats`: Whether to enable Ray Prometheus stats logger for vLLM inference engine metrics (vLLM v1 only). When enabled, uses `vllm.v1.metrics.ray_wrappers.RayPrometheusStatLogger`.
- `generator.inference_engine.enable_return_routed_experts`: Whether to return per-layer expert routing indices to use for rollout router replay (R3) if training an MoE model. Used together with `trainer.policy.megatron_config.moe_enable_routing_replay` to enable R3.
- `generator.inference_engine.distributed_executor_backend`: The distributed executor backend to use for the vLLM engine. Options are either `mp` or `ray`.
- `generator.inference_engine.engine_init_kwargs`: Inference engine arguments passed directly to the vLLM engine. If duplicate kwargs are passed or kwargs clash with existing inference engine arguments (e.g., `tensor_parallel_size`), an error is raised.

129 changes: 129 additions & 0 deletions examples/train/router_replay/run_dapo_moonlight_16b_a3b.sh
@@ -0,0 +1,129 @@
set -x

# Colocated DAPO training+generation for Moonlight-16B-A3B on the DAPO dataset with Megatron and router replay.
# Should run on 2 nodes of 8xH100s

# bash examples/train/algorithms/dapo/prepare_dapo_data.sh
# bash examples/train/router_replay/run_dapo_moonlight_16b_a3b.sh

MODEL_NAME="moonshotai/Moonlight-16B-A3B-Instruct"
DATA_DIR="$HOME/data/dapo"
TRAIN_FILE="$DATA_DIR/dapo-math-17k-cleaned.parquet"
TEST_FILE="$DATA_DIR/aime-2024-cleaned.parquet"
NUM_NODES=2
NUM_GPUS_PER_NODE=8
NUM_INFERENCE_ENGINES=2
INFERENCE_ENGINE_TENSOR_PARALLEL_SIZE=8
LOGGER="wandb" # change to "console" to print to stdout

# flash attention off
FLASH_ATTN=false

CLIP_RATIO_LOW=0.2
CLIP_RATIO_HIGH=0.28
# use token mean loss reduction
LOSS_REDUCTION="token_mean"
# applies overlong filtering (but not soft overlong punishment)
APPLY_OVERLONG_FILTERING=true
# apply soft overlong punishment with custom trainer impl in main_dapo.py
OVERLONG_BUFFER_LEN=$((1024 * 4))
OVERLONG_BUFFER_PENALTY_FACTOR=1.0

# other DAPO parameters
USE_KL_LOSS=false
TEMPERATURE=1.0
TOP_P=1.0
EVAL_TOP_P=0.7
CLIP_RATIO_C=10.0
MAX_PROMPT_LENGTH=$((1024 * 2))
MAX_RESPONSE_LENGTH=$((1024 * 8))

# repro run parameters
TRAIN_BATCH_SIZE=128
MINI_BATCH_SIZE=32
N_SAMPLES_PER_PROMPT=16
EVAL_N_SAMPLES_PER_PROMPT=32
ENFORCE_EAGER=true # cuda graphs can cause some instability
LR=1e-6

# megatron config
MEGATRON_TP=4
MEGATRON_PP=1
MEGATRON_CP=1
MEGATRON_EP=8
MEGATRON_ETP=1


# Router replay (r3)
ROUTER_REPLAY=true
DISTRIBUTED_EXECUTION_BACKEND="mp"

SKYRL_RAY_PG_TIMEOUT_IN_S=300 uv run --isolated --extra megatron -m examples.train.algorithms.dapo.main_dapo \
data.train_data="['$TRAIN_FILE']" \
data.val_data="['$TEST_FILE']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.algorithm.policy_loss_type="dual_clip" \
trainer.algorithm.overlong_buffer_len=$OVERLONG_BUFFER_LEN \
trainer.algorithm.overlong_buffer_penalty_factor=$OVERLONG_BUFFER_PENALTY_FACTOR \
trainer.algorithm.loss_reduction=$LOSS_REDUCTION \
generator.inference_engine.enforce_eager=$ENFORCE_EAGER \
generator.apply_overlong_filtering=$APPLY_OVERLONG_FILTERING \
generator.sampling_params.temperature=$TEMPERATURE \
generator.sampling_params.top_p=$TOP_P \
generator.eval_sampling_params.top_p=$EVAL_TOP_P \
generator.eval_sampling_params.temperature=$TEMPERATURE \
generator.eval_sampling_params.max_generate_length=$MAX_RESPONSE_LENGTH \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
trainer.policy.model.path="$MODEL_NAME" \
trainer.placement.colocate_all=true \
trainer.strategy=megatron \
trainer.placement.policy_num_nodes=$NUM_NODES \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS_PER_NODE \
generator.inference_engine.num_engines=$NUM_INFERENCE_ENGINES \
generator.inference_engine.tensor_parallel_size=$INFERENCE_ENGINE_TENSOR_PARALLEL_SIZE \
trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.policy.megatron_config.moe_enable_routing_replay=$ROUTER_REPLAY \
generator.inference_engine.enable_return_routed_experts=$ROUTER_REPLAY \
generator.inference_engine.distributed_executor_backend=$DISTRIBUTED_EXECUTION_BACKEND \
trainer.epochs=20 \
trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \
trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=true \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=$TRAIN_BATCH_SIZE \
trainer.policy_mini_batch_size=$MINI_BATCH_SIZE \
trainer.micro_forward_batch_size_per_gpu=4 \
trainer.micro_train_batch_size_per_gpu=2 \
trainer.ckpt_interval=200 \
trainer.max_prompt_length=$MAX_PROMPT_LENGTH \
generator.sampling_params.max_generate_length=$MAX_RESPONSE_LENGTH \
trainer.policy.optimizer_config.lr=$LR \
trainer.policy.optimizer_config.num_warmup_steps=40 \
trainer.policy.optimizer_config.weight_decay=0.1 \
trainer.policy.optimizer_config.max_grad_norm=1.0 \
trainer.flash_attn=$FLASH_ATTN \
generator.inference_engine.backend=vllm \
generator.inference_engine.run_engines_locally=true \
generator.inference_engine.weight_sync_backend=nccl \
generator.inference_engine.async_engine=false \
generator.batched=true \
environment.env_class=aime \
generator.n_samples_per_prompt=$N_SAMPLES_PER_PROMPT \
generator.eval_n_samples_per_prompt=$EVAL_N_SAMPLES_PER_PROMPT \
generator.inference_engine.gpu_memory_utilization=0.7 \
trainer.logger="$LOGGER" \
trainer.project_name="router_replay" \
trainer.run_name="dapo_moonlight_16b_a3b_megatron_r3" \
trainer.export_path="$HOME/exports/dapo_moonlight_16b_a3b_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_r3" \
trainer.hf_save_interval=300 \
trainer.resume_mode=latest \
trainer.max_ckpts_to_keep=3 \
trainer.ckpt_path="$HOME/ckpts/dapo_moonlight_16b_a3b_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_r3" \
$@
84 changes: 84 additions & 0 deletions examples/train/router_replay/run_moonlight16b_router_replay.sh
@@ -0,0 +1,84 @@
set -x

# Colocated GRPO training+generation for Moonlight-16B-A3B-Instruct on GSM8K with Megatron and router replay (r3)
# Runs on 1 node of 8xH100s

# uv run examples/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
# export WANDB_API_KEY=<your_key_here>
# bash examples/train/router_replay/run_moonlight16b_router_replay.sh

DATA_DIR="$HOME/data/gsm8k"
LOGGER="wandb" # change to "console" to print to stdout
MODEL_NAME="moonshotai/Moonlight-16B-A3B-Instruct"

INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron

NUM_NODES=1
NUM_GPUS=8

MEGATRON_TP=4
MEGATRON_PP=1
MEGATRON_CP=1
MEGATRON_EP=8
MEGATRON_ETP=1

NUM_INFERENCE_ENGINES=1
INFERENCE_ENGINE_TP=8

# flash attn is not supported for moonlight16b since it is a DeepSeek-V3-like model that uses Multi-Head Latent Attention (MLA)
# https://github.com/NVIDIA/TransformerEngine/blob/483d9594fb070f62966f6a12ed6c90942310b48e/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L483
FLASH_ATTN=false

# router replay (r3)
ROUTER_REPLAY=true
DISTRIBUTED_EXECUTION_BACKEND="mp"

SKYRL_RAY_PG_TIMEOUT_IN_S=300 uv run --isolated --extra megatron --with blobfile -m skyrl.train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.policy.model.path=$MODEL_NAME \
trainer.placement.colocate_all=true \
trainer.strategy=megatron \
trainer.placement.policy_num_nodes=$NUM_NODES \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
generator.inference_engine.num_engines=$NUM_INFERENCE_ENGINES \
generator.inference_engine.tensor_parallel_size=$INFERENCE_ENGINE_TP \
trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.policy.megatron_config.moe_enable_routing_replay=$ROUTER_REPLAY \
generator.inference_engine.distributed_executor_backend=$DISTRIBUTED_EXECUTION_BACKEND \
generator.inference_engine.enable_return_routed_experts=$ROUTER_REPLAY \
trainer.use_sample_packing=true \
trainer.flash_attn=$FLASH_ATTN \
trainer.epochs=20 \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=false \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=256 \
trainer.policy_mini_batch_size=32 \
trainer.micro_forward_batch_size_per_gpu=4 \
trainer.micro_train_batch_size_per_gpu=4 \
trainer.ckpt_interval=100 \
trainer.max_prompt_length=512 \
generator.sampling_params.max_generate_length=1024 \
trainer.policy.optimizer_config.lr=1.0e-6 \
trainer.algorithm.use_kl_loss=false \
generator.inference_engine.backend=$INFERENCE_BACKEND \
generator.inference_engine.run_engines_locally=true \
generator.inference_engine.weight_sync_backend=nccl \
generator.inference_engine.async_engine=true \
generator.batched=true \
environment.env_class=gsm8k \
generator.n_samples_per_prompt=5 \
generator.inference_engine.gpu_memory_utilization=0.6 \
trainer.logger="$LOGGER" \
trainer.project_name="gsm8k_router_replay" \
trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_moonlight16b-a3b_with_router_replay" \
trainer.resume_mode=null \
trainer.ckpt_path="$HOME/ckpts/gsm8k_megatron_ckpt" \
$@
5 changes: 5 additions & 0 deletions skyrl/backends/skyrl_train/inference_engines/base.py
@@ -31,6 +31,7 @@ class InferenceEngineOutput(TypedDict):
response_ids: List[List[int]]
stop_reasons: List[str]
response_logprobs: Optional[List[List[float]]]
rollout_expert_indices: Optional[List[List[List[int]]]] # [seq_len, layer_num, topk]


class InferenceEngineInterface(ABC):
@@ -65,6 +66,7 @@ async def sample(
all_responses = []
all_stop_reasons = []
all_response_logprobs = []
all_rollout_expert_indices = []

for _ in range(num_samples):
input_batch: InferenceEngineInput = {
@@ -81,12 +83,15 @@
all_stop_reasons.append(output["stop_reasons"][0])
if output.get("response_logprobs") is not None:
all_response_logprobs.append(output["response_logprobs"][0])
if output.get("rollout_expert_indices") is not None:
all_rollout_expert_indices.append(output["rollout_expert_indices"][0])

return {
"response_ids": all_response_ids,
"responses": all_responses,
"stop_reasons": all_stop_reasons,
"response_logprobs": all_response_logprobs if all_response_logprobs else None,
"rollout_expert_indices": all_rollout_expert_indices if all_rollout_expert_indices else None,
}

@abstractmethod
@@ -133,8 +133,10 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu
stop_reasons: list[str] = [""] * n
response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)]
response_ids: List[List[int]] = [[] for _ in range(n)]
rollout_expert_indices: List[Optional[List[List[List[int]]]]] = [None for _ in range(n)]
# a bit hacky for now
add_resp_logprobs = False
add_rollout_expert_indices = False

for indices, result in zip(indices_list, results):
for local_idx, original_idx in enumerate(indices):
Expand All @@ -144,12 +146,16 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu
if result.get("response_logprobs", None):
add_resp_logprobs = True
response_logprobs[original_idx] = result["response_logprobs"][local_idx]
if result.get("rollout_expert_indices", None):
add_rollout_expert_indices = True
rollout_expert_indices[original_idx] = result["rollout_expert_indices"][local_idx]

return InferenceEngineOutput(
responses=responses,
stop_reasons=stop_reasons,
response_ids=response_ids,
response_logprobs=response_logprobs if add_resp_logprobs else None,
rollout_expert_indices=rollout_expert_indices if add_rollout_expert_indices else None,
)

def _select_engine_idx(self, session_id: Optional[Union[str, int]] = None) -> int: