diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 0dc079c22e2..cb2f7141b8a 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -522,7 +522,7 @@ async def clear_kv_cache(self): async def wait_for_requests_to_drain(self): await self.engine.wait_for_requests_to_drain() - async def abort_all_requests(self) -> dict[str, Any]: + async def abort_all_requests(self, reset_prefix_cache: bool = True) -> dict[str, Any]: """Abort all ongoing generation requests. Returns: @@ -552,6 +552,11 @@ async def abort_all_requests(self) -> dict[str, Any]: self.engine.output_processor.abort_requests(request_ids) await self.engine.engine_core.abort_requests_async(request_ids) + # Try to reset prefix cache to ensure clean state + if reset_prefix_cache: + await self.clear_kv_cache() + logger.info("Prefix cache reset after abort") + logger.info(f"Aborted {len(request_ids)} requests: {request_ids}") return {"aborted_count": len(request_ids), "request_ids": request_ids} @@ -559,7 +564,7 @@ async def abort_all_requests(self) -> dict[str, Any]: logger.error(f"Error aborting requests: {e}") return {"aborted_count": 0, "request_ids": [], "error": str(e)} - async def abort_request(self, request_id: str) -> dict[str, Any]: + async def abort_request(self, request_id: str, reset_prefix_cache: bool = True) -> dict[str, Any]: """Abort a specific generation request. Args: @@ -587,6 +592,11 @@ async def abort_request(self, request_id: str) -> dict[str, Any]: self.engine.output_processor.abort_requests([request_id]) await self.engine.engine_core.abort_requests_async([request_id]) + # Try to reset prefix cache to ensure clean state + if reset_prefix_cache: + await self.clear_kv_cache() + logger.info(f"Prefix cache reset after abort request {request_id}") + logger.info(f"Aborted request: {request_id}") return {"aborted": True, "request_id": request_id}