4 changes: 2 additions & 2 deletions examples/train/gsm8k/run_gsm8k.sh
@@ -26,8 +26,8 @@ uv run --isolated --extra fsdp -m skyrl.train.entrypoints.main_base \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
trainer.placement.critic_num_gpus_per_node=$NUM_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
generator.inference_engine.num_engines=$NUM_GPUS \
generator.inference_engine.tensor_parallel_size=1 \
generator.inference_engine.num_engines=1 \
generator.inference_engine.tensor_parallel_size=4 \
trainer.epochs=20 \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=true \
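
The two overrides above switch from four single-GPU engines to one engine with tensor parallel size 4. As a quick hedged sanity check (assuming, as is typical for vLLM-style engines, that the generator's GPUs are split as num_engines × tensor_parallel_size; SkyRL may partition differently):

```python
# Hedged sketch: check the generator GPU layout implied by the overrides above.
# Assumption: inference GPUs = num_engines * tensor_parallel_size; NUM_GPUS mirrors
# the shell variable in the script and is set to 4 here purely for illustration.
NUM_GPUS = 4

def check_layout(num_engines: int, tensor_parallel_size: int, total_gpus: int = NUM_GPUS) -> None:
    used = num_engines * tensor_parallel_size
    assert used == total_gpus, f"engines would use {used} GPUs but {total_gpus} are allocated"

check_layout(num_engines=1, tensor_parallel_size=4)  # new config: one TP-4 engine
check_layout(num_engines=4, tensor_parallel_size=1)  # old config: four TP-1 engines
```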
5 changes: 5 additions & 0 deletions skyrl/backends/skyrl_train/inference_engines/base.py
@@ -29,6 +29,7 @@ class InferenceEngineOutput(TypedDict):
response_ids: List[List[int]]
stop_reasons: List[str]
response_logprobs: Optional[List[List[float]]]
rollout_expert_indices: Optional[List[List[List[List[int]]]]] # per response: [seq_len, layer_num, topk]


class InferenceEngineInterface(ABC):
@@ -63,6 +64,7 @@ async def sample(
all_responses = []
all_stop_reasons = []
all_response_logprobs = []
all_rollout_expert_indices = []

for _ in range(num_samples):
input_batch: InferenceEngineInput = {
@@ -79,12 +81,15 @@
all_stop_reasons.append(output["stop_reasons"][0])
if output.get("response_logprobs") is not None:
all_response_logprobs.append(output["response_logprobs"][0])
if output.get("rollout_expert_indices") is not None:
all_rollout_expert_indices.append(output["rollout_expert_indices"][0])

return {
"response_ids": all_response_ids,
"responses": all_responses,
"stop_reasons": all_stop_reasons,
"response_logprobs": all_response_logprobs if all_response_logprobs else None,
"rollout_expert_indices": all_rollout_expert_indices if all_rollout_expert_indices else None,
}

@abstractmethod
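
For readers tracking the new field: a minimal, illustrative sketch of how a `rollout_expert_indices` entry is laid out, following the shape comment above (outer list per response, then [seq_len, layer_num, topk]). The numbers are made up.

```python
# Hedged illustration of the nested-list layout of rollout_expert_indices.
# One response with 3 generated tokens, 2 MoE layers, top-2 routed experts per layer.
one_response = [                     # seq_len = 3
    [[5, 12], [7, 0]],               # token 0: layer 0 -> experts 5, 12; layer 1 -> 7, 0
    [[3, 12], [7, 9]],               # token 1
    [[5, 1], [2, 0]],                # token 2
]
rollout_expert_indices = [one_response]  # outer list: one entry per response in the batch

# Experts routed by layer 1 for the second generated token of response 0:
print(rollout_expert_indices[0][1][1])   # [7, 9]
```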
@@ -153,8 +153,10 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
stop_reasons: list[str] = [""] * n
response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)]
response_ids: List[List[int]] = [[] for _ in range(n)]
rollout_expert_indices: List[Optional[List[List[List[int]]]]] = [None for _ in range(n)]
# a bit hacky for now
add_resp_logprobs = False
add_rollout_expert_indices = False

for indices, result in zip(indices_list, results):
for local_idx, original_idx in enumerate(indices):
@@ -164,12 +166,16 @@
if result.get("response_logprobs", None):
add_resp_logprobs = True
response_logprobs[original_idx] = result["response_logprobs"][local_idx]
if result.get("rollout_expert_indices", None):
add_rollout_expert_indices = True
rollout_expert_indices[original_idx] = result["rollout_expert_indices"][local_idx]

return InferenceEngineOutput(
responses=responses,
stop_reasons=stop_reasons,
response_ids=response_ids,
response_logprobs=response_logprobs if add_resp_logprobs else None,
rollout_expert_indices=rollout_expert_indices if add_rollout_expert_indices else None,
)

def _select_engine_idx(self, session_id: Optional[Union[str, int]] = None) -> int:
@@ -265,6 +271,7 @@ async def _generate_single_with_retry(
# 2. Initialize fields we want to accumulate or update in each loop iteration
accum_response_ids: List[int] = []
accum_response_logprobs: List[float] = []
accum_rollout_expert_indices: List[List[List[int]]] = []
stop_reason: str = "abort"

# We only use it if generation is completed in one turn to maintain original behavior with no retry.
@@ -300,6 +307,10 @@
new_response_logprobs_list: Optional[List[List[float]]] = partial_response.get("response_logprobs", None)
if new_response_logprobs_list is not None and len(new_response_logprobs_list) > 0:
new_response_logprobs = new_response_logprobs_list[0]
new_rollout_expert_indices: Optional[List[List[List[int]]]] = None
new_rollout_expert_indices_list = partial_response.get("rollout_expert_indices", None)
if new_rollout_expert_indices_list is not None and len(new_rollout_expert_indices_list) > 0:
new_rollout_expert_indices = new_rollout_expert_indices_list[0]

# 3.4 Aborted without generating tokens, so partial_response is useless.
if stop_reason == "abort" and len(new_response_ids) == 0:
@@ -309,6 +320,8 @@
accum_response_ids.extend(new_response_ids)
if new_response_logprobs is not None:
accum_response_logprobs.extend(new_response_logprobs)
if new_rollout_expert_indices is not None:
accum_rollout_expert_indices.extend(new_rollout_expert_indices)
num_turns += 1

# 4. Build the final response and return.
@@ -321,6 +334,7 @@
stop_reasons=[stop_reason],
response_ids=[accum_response_ids],
response_logprobs=[accum_response_logprobs] if len(accum_response_logprobs) > 0 else None,
rollout_expert_indices=([accum_rollout_expert_indices] if len(accum_rollout_expert_indices) > 0 else None),
)
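
The accumulation above keeps the routed-expert indices aligned with the accumulated token ids across retries: each turn contributes one inner [layer_num, topk] entry per generated token. A small self-contained sketch of that invariant (the turn payloads below are made up, not the client's internals):

```python
# Hedged sketch: expert indices accumulate turn by turn, in step with response_ids.
accum_response_ids: list[int] = []
accum_rollout_expert_indices: list[list[list[int]]] = []  # [seq_len, layer_num, topk]

turns = [
    {"response_ids": [11, 12], "rollout_expert_indices": [[[1, 2]], [[3, 4]]]},
    {"response_ids": [13],     "rollout_expert_indices": [[[5, 6]]]},
]
for turn in turns:
    accum_response_ids.extend(turn["response_ids"])
    accum_rollout_expert_indices.extend(turn["rollout_expert_indices"])

# One [layer_num, topk] entry per generated token, in generation order.
assert len(accum_rollout_expert_indices) == len(accum_response_ids) == 3
```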

async def _chat_completion_with_retry(
@@ -107,6 +107,7 @@ def create_ray_wrapped_inference_engines(
rope_scaling: Dict[str, Any] = {},
rope_theta: float | None = None,
enable_ray_prometheus_stats: bool = False,
enable_return_routed_experts: bool = False,
served_model_name: str | None = None,
) -> List[InferenceEngineInterface]:
"""
@@ -249,6 +250,7 @@
max_num_seqs=max_num_seqs,
max_logprobs=1, # only need chosen-token logprobs
enable_ray_prometheus_stats=enable_ray_prometheus_stats,
enable_return_routed_experts=enable_return_routed_experts,
**dp_kwargs,
**engine_init_kwargs,
**lora_kwargs,
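
The new flag is simply threaded through to the engine constructor alongside the existing keyword arguments. A hedged, self-contained sketch of that passthrough pattern; the helper below is illustrative, not SkyRL's actual factory:

```python
# Hedged sketch: the flag defaults to False and is forwarded into the engine kwargs
# next to the existing options, so callers opt in explicitly to MoE routing capture.
def build_engine_kwargs(enable_return_routed_experts: bool = False, **engine_init_kwargs) -> dict:
    return dict(
        max_logprobs=1,  # only need chosen-token logprobs, as in the hunk above
        enable_return_routed_experts=enable_return_routed_experts,
        **engine_init_kwargs,
    )

print(build_engine_kwargs())                                   # default: disabled
print(build_engine_kwargs(enable_return_routed_experts=True))  # opt in
```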
21 changes: 21 additions & 0 deletions skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -135,6 +135,7 @@ def _postprocess_outputs(self, outputs):
stop_reasons: List[str] = []
response_ids: List[List[int]] = []
response_logprobs: Optional[List[List[float]]] = []
rollout_expert_indices: Optional[List[List[List[List[int]]]]] = []

for output in outputs:
# TODO(tgriggs): Support n>1 sampling.
@@ -156,14 +157,26 @@
del token_logprobs
response_logprobs.append(_logprobs)

_routed_experts = None
if resp.routed_experts is not None:
if hasattr(resp.routed_experts, "tolist"):
_routed_experts = resp.routed_experts.tolist()
else:
_routed_experts = resp.routed_experts
rollout_expert_indices.append(_routed_experts)

if len(response_logprobs) and response_logprobs[0] is None:
response_logprobs = None # hack: assume uniform sampling params

if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None:
rollout_expert_indices = None # hack: assume uniform sampling params
Comment on lines +171 to +172

🔴 Wrong comparison operator (== 0 instead of > 0) prevents rollout_expert_indices from being set to None

In _postprocess_outputs, line 171 uses len(rollout_expert_indices) == 0 but should use len(rollout_expert_indices) > 0 (or truthiness, like the logprobs check on line 168). With == 0: (1) if the list is empty, rollout_expert_indices[0] raises an IndexError; (2) if the list is non-empty (the normal case), the condition is always False, so a list of all None values (e.g. [None, None, ...]) is never collapsed to None. This means when enable_return_routed_experts is disabled (the default), downstream code in inference_engine_client.py:169-171 sees a truthy list of Nones, sets add_rollout_expert_indices = True, and propagates [None, None, ...] instead of None through the pipeline.

Comparison with correct pattern on line 168

Line 168 (correct): if len(response_logprobs) and response_logprobs[0] is None:
Line 171 (broken): if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None:

Suggested change
- if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None:
- rollout_expert_indices = None # hack: assume uniform sampling params
+ if len(rollout_expert_indices) and rollout_expert_indices[0] is None:
+ rollout_expert_indices = None # hack: assume uniform sampling params
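
To make the failure mode concrete, here is a short self-contained repro of the guard behavior described above, using a dummy batch where no request returned routed experts (the default when `enable_return_routed_experts` is off):

```python
# Hedged repro of the review comment: a non-empty list of Nones is never collapsed
# by the broken guard, and an empty list would raise IndexError on [0].
rollout_expert_indices = [None, None, None]

# Broken guard: `len(...) == 0` is False for a non-empty list, so the truthy
# [None, None, None] leaks downstream and flips add_rollout_expert_indices to True.
if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None:
    rollout_expert_indices = None
assert rollout_expert_indices == [None, None, None]  # not collapsed

# Corrected guard (mirrors the logprobs check): truthiness plus first-element sentinel.
rollout_expert_indices = [None, None, None]
if len(rollout_expert_indices) and rollout_expert_indices[0] is None:
    rollout_expert_indices = None
assert rollout_expert_indices is None  # collapsed as intended
```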


return InferenceEngineOutput(
responses=responses,
stop_reasons=stop_reasons,
response_ids=response_ids,
response_logprobs=response_logprobs,
rollout_expert_indices=rollout_expert_indices,
)

def _get_engine(self):
@@ -321,6 +334,14 @@ def _create_engine(self, *args, **kwargs):
enable_log_requests = kwargs.pop("enable_log_requests", False)
max_log_len = kwargs.pop("max_log_len", None)

# Log if enable_return_routed_experts is being passed
if "enable_return_routed_experts" in kwargs:
logger.info(
f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs"
)
else:
logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs")
Comment on lines +337 to +343

Severity: medium

These DEBUG log messages are helpful during development but should ideally be removed or made configurable (e.g., tied to a verbose flag) before merging to production. Excessive logging can clutter output and potentially impact performance.

Suggested change
- # Log if enable_return_routed_experts is being passed
- if "enable_return_routed_experts" in kwargs:
- logger.info(
- f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs"
- )
- else:
- logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs")
+ # Log if enable_return_routed_experts is being passed
+ # Consider making this logging configurable or removing it for production builds.
+ # if "enable_return_routed_experts" in kwargs:
+ #     logger.info(
+ #         f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs"
+ #     )
+ # else:
+ #     logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs")
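
One way to act on this without losing the diagnostic entirely is to gate it behind an explicit opt-in, for example an environment variable. A minimal hedged sketch; the variable name and helper are illustrative, not part of this PR:

```python
import logging
import os

logger = logging.getLogger(__name__)

# Hedged sketch: emit the routed-experts diagnostic only when explicitly requested.
VERBOSE_ENGINE_ARGS = os.environ.get("SKYRL_VERBOSE_ENGINE_ARGS", "0") == "1"  # hypothetical env var

def log_routed_experts_flag(kwargs: dict) -> None:
    if not VERBOSE_ENGINE_ARGS:
        return
    if "enable_return_routed_experts" in kwargs:
        logger.info("enable_return_routed_experts=%s passed to AsyncEngineArgs",
                    kwargs["enable_return_routed_experts"])
    else:
        logger.warning("enable_return_routed_experts not present in engine kwargs")
```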


if version.parse(vllm.__version__) >= version.parse("0.10.0"):
engine_args = vllm.AsyncEngineArgs(enable_log_requests=enable_log_requests, **kwargs)
else:
1 change: 1 addition & 0 deletions skyrl/backends/skyrl_train/training_batch.py
@@ -369,6 +369,7 @@ class TrainingInput(TypedDict, total=False):
kl: Float[torch.Tensor, "batch_size seq_len"]
rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]]
rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]]
rollout_expert_indices: Optional[Integer[torch.Tensor, "batch_size seq_len layer_num topk"]]


class TrainingInputBatch(TensorBatch[TrainingInput]):
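
The new rollout_expert_indices field above stores, per sample, the generator's nested [seq_len, layer_num, topk] lists as a single integer tensor. A hedged sketch of such a conversion, padding shorter responses with -1; the helper and padding choice are illustrative, not SkyRL's actual collate logic:

```python
import torch

# Hedged sketch: pack per-sample [seq_len, layer_num, topk] expert-index lists into one
# integer tensor of shape [batch_size, seq_len, layer_num, topk], padding with -1.
def pack_rollout_expert_indices(batch: list, pad_value: int = -1) -> torch.Tensor:
    layer_num = len(batch[0][0])
    topk = len(batch[0][0][0])
    max_len = max(len(sample) for sample in batch)
    out = torch.full((len(batch), max_len, layer_num, topk), pad_value, dtype=torch.long)
    for i, sample in enumerate(batch):
        out[i, : len(sample)] = torch.tensor(sample, dtype=torch.long)
    return out

batch = [
    [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],  # 2 tokens, 2 layers, top-2 experts
    [[[0, 1], [2, 3]]],                    # 1 token, padded up to length 2
]
print(pack_rollout_expert_indices(batch).shape)  # torch.Size([2, 2, 2, 2])
```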