Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions verl/workers/rollout/vllm_rollout/vllm_async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ async def clear_kv_cache(self):
async def wait_for_requests_to_drain(self):
await self.engine.wait_for_requests_to_drain()

async def abort_all_requests(self) -> dict[str, Any]:
async def abort_all_requests(self, reset_prefix_cache: bool = True) -> dict[str, Any]:
"""Abort all ongoing generation requests.

Returns:
Expand Down Expand Up @@ -552,14 +552,19 @@ async def abort_all_requests(self) -> dict[str, Any]:
self.engine.output_processor.abort_requests(request_ids)
await self.engine.engine_core.abort_requests_async(request_ids)

# Try to reset prefix cache to ensure clean state
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should always reset_prefix_cache when abort request?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we should reset prefix cache after abort as the rollout weights are usually updated after aborting.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If not reset, it may degenerate to a paradigm similar to pipelineRL

if reset_prefix_cache:
await self.clear_kv_cache()
logger.info("Prefix cache reset after abort")

logger.info(f"Aborted {len(request_ids)} requests: {request_ids}")
return {"aborted_count": len(request_ids), "request_ids": request_ids}

except Exception as e:
logger.error(f"Error aborting requests: {e}")
return {"aborted_count": 0, "request_ids": [], "error": str(e)}

async def abort_request(self, request_id: str) -> dict[str, Any]:
async def abort_request(self, request_id: str, reset_prefix_cache: bool = True) -> dict[str, Any]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When will we want to abort single request?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may not be used quite often, but we can leave it as an API for fine-grained control

"""Abort a specific generation request.

Args:
Expand Down Expand Up @@ -587,6 +592,11 @@ async def abort_request(self, request_id: str) -> dict[str, Any]:
self.engine.output_processor.abort_requests([request_id])
await self.engine.engine_core.abort_requests_async([request_id])

# Try to reset prefix cache to ensure clean state
if reset_prefix_cache:
await self.clear_kv_cache()
logger.info(f"Prefix cache reset after abort request {request_id}")

logger.info(f"Aborted request: {request_id}")
return {"aborted": True, "request_id": request_id}

Expand Down
Loading