diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
index 3bebbd8655..ef5e3902b9 100644
--- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
+++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -170,6 +170,18 @@ def _get_engine(self):
         """Get the underlying engine for RPC calls."""
         return self.llm.engine if hasattr(self.llm, "engine") else self.llm
 
+    @staticmethod
+    def _get_unfinished_request_ids(output_processor) -> list:
+        """Get unfinished request IDs suitable for abort/abort_request calls.
+
+        In vllm 0.16.0+, request_states is keyed by internal IDs (with a random suffix),
+        while abort() expects external IDs by default. We use external_req_ids when
+        available and fall back to request_states keys for older vllm versions.
+        """
+        if hasattr(output_processor, "external_req_ids"):
+            return list(output_processor.external_req_ids.keys())
+        return list(output_processor.request_states.keys())
+
     def reset_prefix_cache(self):
         """Reset the prefix cache. Subclasses override for async version."""
         return self.llm.llm_engine.reset_prefix_cache()
@@ -245,7 +257,7 @@ async def sleep(self, *args: Any, **kwargs: Any):
                 "generation should be done before sleep() is called. Check for potential failures or "
                 "dangling requests in your Generator/Env. Aborting all unfinished requests."
             )
-            unfinished_request_ids = list(output_processor.request_states.keys())
+            unfinished_request_ids = self._get_unfinished_request_ids(output_processor)
             await asyncio.to_thread(engine.abort_request, unfinished_request_ids)
 
         level = 1 if self._is_lora else kwargs.get("level", 2)
@@ -457,7 +469,7 @@ async def sleep(self, *args: Any, **kwargs: Any):
                 "generation should be done before sleep() is called. Check for potential failures or "
                 "dangling requests in your Generator/Env. Aborting all unfinished requests."
             )
-            unfinished_request_ids = list(output_processor.request_states.keys())
+            unfinished_request_ids = self._get_unfinished_request_ids(output_processor)
             await engine.abort(unfinished_request_ids)
 
         # TODO(team): remove once vllm fixes this
@@ -608,8 +620,7 @@ async def abort_generation(self) -> None:
         already-generated tokens with a stop_reason of "abort".
         """
         engine = self._get_engine()
-        # Collect all request IDs currently tracked by the scheduler/output processor
-        unfinished_request_ids = list(engine.output_processor.request_states.keys())
+        unfinished_request_ids = self._get_unfinished_request_ids(engine.output_processor)
         if unfinished_request_ids:
             await engine.abort(unfinished_request_ids)
             await engine.reset_prefix_cache()  # avoid KV-cache pollution