diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index f4e26f13fde..4aea687fe66 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -412,11 +412,21 @@ async def run_server(self, args: argparse.Namespace): method="monkey_patch_model", kwargs={"vocab_size": len(self.model_config.tokenizer)} ) - app = build_app(args) - if _VLLM_VERSION > version.parse("0.11.0"): - await init_app_state(engine_client, app.state, args) + build_app_sig = inspect.signature(build_app) + supported_tasks: tuple[Any, ...] = () + if "supported_tasks" in build_app_sig.parameters: + supported_tasks = await engine_client.get_supported_tasks() + app = build_app(args, supported_tasks) else: + app = build_app(args) + + init_app_sig = inspect.signature(init_app_state) + if "vllm_config" in init_app_sig.parameters: await init_app_state(engine_client, vllm_config, app.state, args) + elif "supported_tasks" in init_app_sig.parameters: + await init_app_state(engine_client, app.state, args, supported_tasks) + else: + await init_app_state(engine_client, app.state, args) if self.replica_rank == 0 and self.node_rank == 0: logger.info(f"Initializing a V1 LLM engine with config: {vllm_config}")