2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -6,4 +6,4 @@ repos:
args: ["--fix", "--show-fixes", "--output-format=full"]
exclude: ^.*\.(ipynb)$|^verl/.*$
- id: ruff-format
exclude: ^verl/.*$
exclude: ^.*\.(ipynb)$|^verl/.*$
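A note on the new exclude pattern: it now skips both notebooks and the vendored verl tree for ruff-format, matching the ruff hook above it. A minimal sketch of the expected behavior, assuming pre-commit's usual Python-regex matching of file paths (the sample paths below are illustrative):

import re

# Same pattern as the updated ruff-format exclude above.
exclude = re.compile(r"^.*\.(ipynb)$|^verl/.*$")

assert exclude.search("notebooks/demo.ipynb") is not None      # notebooks skipped
assert exclude.search("verl/trainer/ppo.py") is not None       # vendored verl tree skipped
assert exclude.search("rllm/agents/miniwob_agent.py") is None  # everything else still formatted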
4 changes: 3 additions & 1 deletion examples/archive/sft/run_sft_model.py
@@ -30,7 +30,9 @@
agent_args = {
"tools": ["python"],
"parser_name": "qwen",
"system_prompt": ('You are an expert mathematician and programmer. Your goal is to solve challenging math problems, like those from the AIME competition, by breaking them down into logical steps and using Python code for calculations. Strive for clarity and efficiency.\n\nFollow this process for every problem:\n1. **Analyze the Problem**: Read the question carefully. Identify the key information, constraints, and what is being asked.\n2. **Think Step-by-Step**: In the `<think>` block, outline your plan. Decompose the problem into the smallest, most logical steps. **You must not write code or perform calculations in this block.** Your goal is to create a plan that will be executed by the Python tool.\n3. **Write Python Code**: In the `<tool_call>` block, write an efficient Python script to execute your plan. The tool expects a JSON object with `name` and `arguments` keys. The `arguments` should be a dictionary with a single `code` key. Ensure the code is self-contained, runs quickly, and prints the final result.\n4. **State the Final Answer**: After receiving the `<tool_result>`, verify it. Then, state the final answer clearly and concisely in the format \\boxed{answer}.\n\nHere is an example:\nQuestion: What is the largest prime factor of 25! ?\n<think>The problem asks for the largest prime factor of 25 factorial. The largest prime factor of n! is the largest prime number less than or equal to n. In this case, n=25. I will write a Python script to find the largest prime number less than or equal to 25.</think>\n<tool_call>\n{"name": "python", "arguments": {"code": "import math\\ndef is_prime(n):\\n if n <= 1:\\n return False\\n for i in range(2, int(math.sqrt(n)) + 1):\\n if n % i == 0:\\n return False\\n return True\\n\\ndef largest_prime_up_to(n):\\n for i in range(n, 1, -1):\\n if is_prime(i):\\n return i\\n return None\\n\\nprint(largest_prime_up_to(25))"}}\n</tool_call>\n<tool_result>\n23\n</tool_result>\nThe largest prime factor of 25! is the largest prime number less than or equal to 25. The answer is \\boxed{23}.'),
"system_prompt": (
'You are an expert mathematician and programmer. Your goal is to solve challenging math problems, like those from the AIME competition, by breaking them down into logical steps and using Python code for calculations. Strive for clarity and efficiency.\n\nFollow this process for every problem:\n1. **Analyze the Problem**: Read the question carefully. Identify the key information, constraints, and what is being asked.\n2. **Think Step-by-Step**: In the `<think>` block, outline your plan. Decompose the problem into the smallest, most logical steps. **You must not write code or perform calculations in this block.** Your goal is to create a plan that will be executed by the Python tool.\n3. **Write Python Code**: In the `<tool_call>` block, write an efficient Python script to execute your plan. The tool expects a JSON object with `name` and `arguments` keys. The `arguments` should be a dictionary with a single `code` key. Ensure the code is self-contained, runs quickly, and prints the final result.\n4. **State the Final Answer**: After receiving the `<tool_result>`, verify it. Then, state the final answer clearly and concisely in the format \\boxed{answer}.\n\nHere is an example:\nQuestion: What is the largest prime factor of 25! ?\n<think>The problem asks for the largest prime factor of 25 factorial. The largest prime factor of n! is the largest prime number less than or equal to n. In this case, n=25. I will write a Python script to find the largest prime number less than or equal to 25.</think>\n<tool_call>\n{"name": "python", "arguments": {"code": "import math\\ndef is_prime(n):\\n if n <= 1:\\n return False\\n for i in range(2, int(math.sqrt(n)) + 1):\\n if n % i == 0:\\n return False\\n return True\\n\\ndef largest_prime_up_to(n):\\n for i in range(n, 1, -1):\\n if is_prime(i):\\n return i\\n return None\\n\\nprint(largest_prime_up_to(25))"}}\n</tool_call>\n<tool_result>\n23\n</tool_result>\nThe largest prime factor of 25! is the largest prime number less than or equal to 25. The answer is \\boxed{23}.'
),
}
env_args = {
"tools": ["python"],
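For context, the system prompt above instructs the model to emit tool calls as a JSON object with `name` and `arguments` keys, where `arguments` holds a single `code` string. A minimal parse sketch of that shape (illustrative only; the payload below is invented, not from this PR):

import json

raw = '{"name": "python", "arguments": {"code": "print(2 + 2)"}}'
payload = json.loads(raw)
assert payload["name"] == "python"
code = payload["arguments"]["code"]  # the script the tool is asked to run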
4 changes: 3 additions & 1 deletion examples/fully_async/deepresearch/refine_agent.py
@@ -70,7 +70,9 @@ async def end_request(self, success: bool, latency: float):
# Log periodically
if (self.total_completed + self.total_failed) % 50 == 0:
avg_latency = self.total_latency / max(1, self.total_completed)
logger.info(f"[STATS] In-flight: {self.in_flight}, Completed: {self.total_completed}, Failed: {self.total_failed}, Latency(avg/min/max): {avg_latency:.2f}s/{self.min_latency:.2f}s/{self.max_latency:.2f}s")
logger.info(
f"[STATS] In-flight: {self.in_flight}, Completed: {self.total_completed}, Failed: {self.total_failed}, Latency(avg/min/max): {avg_latency:.2f}s/{self.min_latency:.2f}s/{self.max_latency:.2f}s"
)

async def get_stats(self) -> dict:
async with self._lock:
23 changes: 22 additions & 1 deletion examples/fully_async/deepresearch/search_agent.py
@@ -242,7 +242,28 @@ async def run(self, question):
final_answer = extract_boxed_answer(content)

# Aggregate metrics across all tool calls
aggregated_metrics = {"num_turns": num_turns, "total_parse_tool_args_error": sum(m.get("parse_tool_args_error", 0) for m in metrics), "total_tool_return_error": sum(m.get("tool_return_error", 0) for m in metrics), "total_tool_calls": sum(m.get("tool_calls", 0) for m in metrics), "total_tool_wait_time": sum(m.get("tool_wait_time", 0) for m in metrics), "total_refine_time": sum(m.get("refine_time", 0) for m in metrics), "avg_refine_time": sum(m.get("refine_time", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1), "total_query_length": sum(m.get("query_length", 0) for m in metrics), "avg_query_length": sum(m.get("query_length", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1), "total_generation_time": total_generation_time, "total_completion_tokens": total_completion_tokens, "total_tool_tokens": sum(m.get("tool_tokens", 0) for m in metrics), "avg_completion_tokens_per_turn": total_completion_tokens / max(num_turns, 1), "avg_tool_tokens_per_call": sum(m.get("tool_tokens", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1), "duplicate_search_detected": duplicate_search_detected, "excessive_parallel_calls": excessive_parallel_calls, "tool_error_detected": tool_error_detected, "refine_error_detected": refine_error_detected, "overlong": overlong, "merged_step": len(trajectory.merge())}
aggregated_metrics = {
"num_turns": num_turns,
"total_parse_tool_args_error": sum(m.get("parse_tool_args_error", 0) for m in metrics),
"total_tool_return_error": sum(m.get("tool_return_error", 0) for m in metrics),
"total_tool_calls": sum(m.get("tool_calls", 0) for m in metrics),
"total_tool_wait_time": sum(m.get("tool_wait_time", 0) for m in metrics),
"total_refine_time": sum(m.get("refine_time", 0) for m in metrics),
"avg_refine_time": sum(m.get("refine_time", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1),
"total_query_length": sum(m.get("query_length", 0) for m in metrics),
"avg_query_length": sum(m.get("query_length", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1),
"total_generation_time": total_generation_time,
"total_completion_tokens": total_completion_tokens,
"total_tool_tokens": sum(m.get("tool_tokens", 0) for m in metrics),
"avg_completion_tokens_per_turn": total_completion_tokens / max(num_turns, 1),
"avg_tool_tokens_per_call": sum(m.get("tool_tokens", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1),
"duplicate_search_detected": duplicate_search_detected,
"excessive_parallel_calls": excessive_parallel_calls,
"tool_error_detected": tool_error_detected,
"refine_error_detected": refine_error_detected,
"overlong": overlong,
"merged_step": len(trajectory.merge()),
}

if OVERLONG_FILTER and overlong:
for seq in trajectory.sequences:
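One side note on the hunk above: several keys recompute the same generator sums (for example, sum(m.get("tool_calls", 0) for m in metrics) appears four times). A sketch of how the totals could be hoisted into locals so each metric is aggregated once; illustrative only, not part of this PR:

total_tool_calls = sum(m.get("tool_calls", 0) for m in metrics)
total_refine_time = sum(m.get("refine_time", 0) for m in metrics)
total_query_length = sum(m.get("query_length", 0) for m in metrics)
total_tool_tokens = sum(m.get("tool_tokens", 0) for m in metrics)

aggregated_metrics = {
    "total_tool_calls": total_tool_calls,
    "avg_refine_time": total_refine_time / max(total_tool_calls, 1),
    "avg_query_length": total_query_length / max(total_tool_calls, 1),
    "avg_tool_tokens_per_call": total_tool_tokens / max(total_tool_calls, 1),
    # ... remaining keys unchanged from the diff above ...
}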
16 changes: 4 additions & 12 deletions rllm-model-gateway/src/rllm_model_gateway/client.py
@@ -60,9 +60,7 @@ def get_session_info(self, session_id: str) -> dict[str, Any]:
resp.raise_for_status()
return resp.json()

def list_sessions(
self, since: float | None = None, limit: int | None = None
) -> list[dict[str, Any]]:
def list_sessions(self, since: float | None = None, limit: int | None = None) -> list[dict[str, Any]]:
params: dict[str, Any] = {}
if since is not None:
params["since"] = since
@@ -90,9 +88,7 @@ def get_session_traces(
params["since"] = since
if limit is not None:
params["limit"] = limit
resp = self._http.get(
f"{self.gateway_url}/sessions/{session_id}/traces", params=params
)
resp = self._http.get(f"{self.gateway_url}/sessions/{session_id}/traces", params=params)
resp.raise_for_status()
data = resp.json()
return [TraceRecord(**t) for t in data]
@@ -188,9 +184,7 @@ async def get_session_info(self, session_id: str) -> dict[str, Any]:
resp.raise_for_status()
return resp.json()

async def list_sessions(
self, since: float | None = None, limit: int | None = None
) -> list[dict[str, Any]]:
async def list_sessions(self, since: float | None = None, limit: int | None = None) -> list[dict[str, Any]]:
params: dict[str, Any] = {}
if since is not None:
params["since"] = since
@@ -218,9 +212,7 @@ async def get_session_traces(
params["since"] = since
if limit is not None:
params["limit"] = limit
resp = await self._http.get(
f"{self.gateway_url}/sessions/{session_id}/traces", params=params
)
resp = await self._http.get(f"{self.gateway_url}/sessions/{session_id}/traces", params=params)
resp.raise_for_status()
data = resp.json()
return [TraceRecord(**t) for t in data]
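A hypothetical usage sketch for the collapsed client methods; the class name, constructor, and the "session_id" key below are assumptions — only the method signatures come from the diff:

from rllm_model_gateway.client import GatewayClient  # class name assumed

client = GatewayClient(gateway_url="http://localhost:8000")  # constructor assumed
for session in client.list_sessions(limit=10):
    traces = client.get_session_traces(session["session_id"], limit=100)
    print(session["session_id"], len(traces))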
12 changes: 3 additions & 9 deletions rllm-model-gateway/src/rllm_model_gateway/middleware.py
@@ -67,17 +67,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

# Inject sampling parameters into POST request bodies (chat completions, etc.)
method = scope.get("method", "").upper()
needs_injection = (
self.add_logprobs or self.add_return_token_ids or self.sessions is not None
)
needs_injection = self.add_logprobs or self.add_return_token_ids or self.sessions is not None
if method == "POST" and needs_injection:
await self._inject_params(scope, receive, send, session_id)
else:
await self.app(scope, receive, send)

async def _inject_params(
self, scope: Scope, receive: Receive, send: Send, session_id: str | None = None
) -> None:
async def _inject_params(self, scope: Scope, receive: Receive, send: Send, session_id: str | None = None) -> None:
"""Read body, inject sampling params, then forward with mutated body."""
body_parts: list[bytes] = []
more = True
@@ -94,9 +90,7 @@ async def _inject_params(
# Record whether the client originally requested logprobs
# so the proxy can strip them from the response if not.
state = scope["state"]
state["originally_requested_logprobs"] = (
"logprobs" in payload and payload["logprobs"]
)
state["originally_requested_logprobs"] = "logprobs" in payload and payload["logprobs"]
self._mutate(payload, session_id)
raw = json.dumps(payload).encode("utf-8")
except (json.JSONDecodeError, UnicodeDecodeError):
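_inject_params follows the standard ASGI buffer-and-replay pattern: drain the request body, mutate the JSON, then hand the downstream app a receive callable that replays the new body. A simplified, self-contained sketch of that pattern (illustrative, assuming plain ASGI semantics, not the PR's exact code):

async def read_full_body(receive) -> bytes:
    # Drain all http.request messages until more_body is False.
    parts: list[bytes] = []
    more = True
    while more:
        message = await receive()
        parts.append(message.get("body", b""))
        more = message.get("more_body", False)
    return b"".join(parts)

def make_replay_receive(raw: bytes):
    # Return a receive callable that serves the mutated body exactly once.
    sent = False

    async def receive():
        nonlocal sent
        if sent:
            return {"type": "http.disconnect"}
        sent = True
        return {"type": "http.request", "body": raw, "more_body": False}

    return receive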
13 changes: 3 additions & 10 deletions rllm-model-gateway/src/rllm_model_gateway/server.py
@@ -359,10 +359,7 @@ def _load_config(args: argparse.Namespace) -> GatewayConfig:
# Workers from CLI --worker flags (WorkerConfig validator auto-splits URLs)
worker_urls = getattr(args, "worker", None) or []
if worker_urls:
data["workers"] = [
{"url": raw_url, "worker_id": str(i)}
for i, raw_url in enumerate(worker_urls)
]
data["workers"] = [{"url": raw_url, "worker_id": str(i)} for i, raw_url in enumerate(worker_urls)]

return GatewayConfig(**data)

@@ -373,9 +370,7 @@ def _load_config(args: argparse.Namespace) -> GatewayConfig:


def main() -> None:
parser = argparse.ArgumentParser(
description="rllm-model-gateway: lightweight LLM call proxy for RL training"
)
parser = argparse.ArgumentParser(description="rllm-model-gateway: lightweight LLM call proxy for RL training")
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=None)
parser.add_argument("--config", type=str, default=None, help="Path to YAML config")
@@ -398,9 +393,7 @@ def main() -> None:

import uvicorn

uvicorn.run(
app, host=config.host, port=config.port, log_level=config.log_level.lower()
)
uvicorn.run(app, host=config.host, port=config.port, log_level=config.log_level.lower())


if __name__ == "__main__":
4 changes: 1 addition & 3 deletions rllm-model-gateway/src/rllm_model_gateway/session_manager.py
@@ -23,9 +23,7 @@ def __init__(self, store: TraceStore) -> None:
self._created_at: dict[str, float] = {}
self._sampling_params: dict[str, dict[str, Any]] = {}

def ensure_session(
self, session_id: str, metadata: dict[str, Any] | None = None
) -> str:
def ensure_session(self, session_id: str, metadata: dict[str, Any] | None = None) -> str:
"""Ensure a session exists (create if needed). Returns session_id."""
if session_id not in self._created_at:
self._created_at[session_id] = time.time()
19 changes: 17 additions & 2 deletions rllm/agents/miniwob_agent.py
@@ -38,7 +38,17 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image) -> str:


class MiniWobAgent(BaseAgent):
def __init__(self, chat_mode: bool = False, use_html: bool = True, use_axtree: bool = True, use_screenshot: bool = False, use_accumulate_thinking: bool = True, cot_prompt: bool = False, use_full_conversation: bool = True, use_reward_shaping: bool = False):
def __init__(
self,
chat_mode: bool = False,
use_html: bool = True,
use_axtree: bool = True,
use_screenshot: bool = False,
use_accumulate_thinking: bool = True,
cot_prompt: bool = False,
use_full_conversation: bool = True,
use_reward_shaping: bool = False,
):
self.chat_mode: bool = chat_mode
self.use_html: bool = use_html
self.use_axtree: bool = use_axtree
@@ -217,7 +227,12 @@ def get_user_msgs(self, user_obs) -> list[dict[str, str]]:
user_msgs.append({"type": "text", "text": self._get_action_space_description()})

# Add next action prompt
user_msgs.append({"type": "text", "text": "# Next action\nThe task has not been completed yet. You will now think step by step and produce your next best action. Reflect on your past actions, any resulting error message, and the current state of the page before deciding on your next action. The content must be in the same format as shown before in the Action Space. You can plan ahead but only 1 immediate action is needed."})
user_msgs.append(
{
"type": "text",
"text": "# Next action\nThe task has not been completed yet. You will now think step by step and produce your next best action. Reflect on your past actions, any resulting error message, and the current state of the page before deciding on your next action. The content must be in the same format as shown before in the Action Space. You can plan ahead but only 1 immediate action is needed.",
}
)

return user_msgs

21 changes: 5 additions & 16 deletions rllm/environments/tools/mcp_env.py
@@ -403,9 +403,7 @@ def _ensure_connection_managers(self) -> list[str]:
existing_spec = MCPEnvironment._server_specs.get(server_name)
if existing_spec is not None:
if existing_spec != server_spec:
raise ValueError(
f"MCP server '{server_name}' is already initialized with a different configuration"
)
raise ValueError(f"MCP server '{server_name}' is already initialized with a different configuration")
continue

manager = MCPConnectionManager(
@@ -460,27 +458,18 @@ def _build_tool_routing(self) -> dict[str, str]:
explicit_server_name = self.tool_name_to_server_name.get(public_tool_name)
if explicit_server_name is not None:
if explicit_server_name not in candidate_servers:
raise ValueError(
f"Tool '{public_tool_name}' is not provided by mapped MCP server '{explicit_server_name}'"
)
raise ValueError(f"Tool '{public_tool_name}' is not provided by mapped MCP server '{explicit_server_name}'")
resolved[public_tool_name] = explicit_server_name
elif len(candidate_servers) == 1:
resolved[public_tool_name] = next(iter(candidate_servers))
else:
raise ValueError(
f"Tool '{public_tool_name}' is provided by multiple MCP servers {sorted(candidate_servers)}. "
"Supply 'tool_name_to_server_name' to disambiguate."
)
raise ValueError(f"Tool '{public_tool_name}' is provided by multiple MCP servers {sorted(candidate_servers)}. Supply 'tool_name_to_server_name' to disambiguate.")

for public_tool_name, mapped_server_name in self.tool_name_to_server_name.items():
if mapped_server_name not in self.mcp_servers:
raise ValueError(
f"Tool mapping for '{public_tool_name}' references unknown MCP server '{mapped_server_name}'"
)
raise ValueError(f"Tool mapping for '{public_tool_name}' references unknown MCP server '{mapped_server_name}'")
if public_tool_name not in discovered_tool_servers:
raise ValueError(
f"Tool mapping for '{public_tool_name}' does not match any discovered tool on the configured MCP servers"
)
raise ValueError(f"Tool mapping for '{public_tool_name}' does not match any discovered tool on the configured MCP servers")

return resolved

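The three ValueError branches above implement a small routing rule: an explicit tool-to-server mapping wins, a tool offered by exactly one server resolves automatically, and anything ambiguous must be disambiguated. A pure-function sketch of that rule (an illustrative restatement, not the PR's code):

def resolve_tool(tool: str, candidates: set[str], explicit: dict[str, str]) -> str:
    if tool in explicit:
        server = explicit[tool]
        if server not in candidates:
            raise ValueError(f"Tool '{tool}' is not provided by mapped MCP server '{server}'")
        return server
    if len(candidates) == 1:
        return next(iter(candidates))
    raise ValueError(f"Tool '{tool}' is provided by multiple MCP servers {sorted(candidates)}. Supply 'tool_name_to_server_name' to disambiguate.")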
4 changes: 1 addition & 3 deletions rllm/experimental/engine/remote_runtime/agentcore_runtime.py
@@ -112,9 +112,7 @@ async def _run_one(self, sub: TaskSubmission, timeout: float) -> RemoteTaskResult:
raw_result=result,
)

async def execute_tasks(
self, submissions: list[TaskSubmission], timeout: float | None = None
) -> list[RemoteTaskResult]:
async def execute_tasks(self, submissions: list[TaskSubmission], timeout: float | None = None) -> list[RemoteTaskResult]:
"""Submit all tasks concurrently via asyncio.gather.

Each task invokes then polls in sequence; all tasks run in parallel.
4 changes: 1 addition & 3 deletions rllm/experimental/engine/remote_runtime/protocol.py
@@ -48,9 +48,7 @@ def initialize(self) -> None:
"""Client setup from config."""
...

async def execute_tasks(
self, submissions: list[TaskSubmission], timeout: float | None = None
) -> list[RemoteTaskResult]:
async def execute_tasks(self, submissions: list[TaskSubmission], timeout: float | None = None) -> list[RemoteTaskResult]:
"""Submit tasks concurrently and gather results. Returns one result per submission."""
...

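Both execute_tasks signatures above promise gather-style fan-out: every submission runs concurrently while each task's invoke-then-poll stays sequential internally. A minimal sketch of that contract (the default timeout value below is an assumption):

import asyncio

async def execute_tasks(run_one, submissions, timeout: float | None = None):
    effective = timeout if timeout is not None else 300.0  # default is an assumption
    # One coroutine per submission; results come back in submission order.
    return await asyncio.gather(*(run_one(sub, effective) for sub in submissions))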