Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,18 @@ def _get_engine(self):
"""Get the underlying engine for RPC calls."""
return self.llm.engine if hasattr(self.llm, "engine") else self.llm

@staticmethod
def _get_unfinished_request_ids(output_processor) -> list:
"""Get unfinished request IDs suitable for abort/abort_request calls.

In vllm 0.16.0+, request_states is keyed by internal IDs (with a random suffix),
while abort() expects external IDs by default. We use external_req_ids when
available and fall back to request_states keys for older vllm versions.
"""
if hasattr(output_processor, "external_req_ids"):
return list(output_processor.external_req_ids.keys())
return list(output_processor.request_states.keys())

def reset_prefix_cache(self):
    """Clear the engine's prefix cache.

    Synchronous path; async subclasses override this with an awaitable
    version. Delegates directly to the wrapped vllm engine.
    """
    llm_engine = self.llm.llm_engine
    return llm_engine.reset_prefix_cache()
Expand Down Expand Up @@ -245,7 +257,7 @@ async def sleep(self, *args: Any, **kwargs: Any):
"generation should be done before sleep() is called. Check for potential failures or "
"dangling requests in your Generator/Env. Aborting all unfinished requests."
)
unfinished_request_ids = list(output_processor.request_states.keys())
unfinished_request_ids = self._get_unfinished_request_ids(output_processor)
await asyncio.to_thread(engine.abort_request, unfinished_request_ids)

level = 1 if self._is_lora else kwargs.get("level", 2)
Expand Down Expand Up @@ -457,7 +469,7 @@ async def sleep(self, *args: Any, **kwargs: Any):
"generation should be done before sleep() is called. Check for potential failures or "
"dangling requests in your Generator/Env. Aborting all unfinished requests."
)
unfinished_request_ids = list(output_processor.request_states.keys())
unfinished_request_ids = self._get_unfinished_request_ids(output_processor)
await engine.abort(unfinished_request_ids)

# TODO(team): remove once vllm fixes this
Expand Down Expand Up @@ -608,8 +620,7 @@ async def abort_generation(self) -> None:
already-generated tokens with a stop_reason of "abort".
"""
engine = self._get_engine()
# Collect all request IDs currently tracked by the scheduler/output processor
unfinished_request_ids = list(engine.output_processor.request_states.keys())
unfinished_request_ids = self._get_unfinished_request_ids(engine.output_processor)
if unfinished_request_ids:
await engine.abort(unfinished_request_ids)
await engine.reset_prefix_cache() # avoid KV-cache pollution
Expand Down
Loading