diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml index b5e8cfd4d..4b3f55290 100644 --- a/.github/workflows/claude-code-review.yml +++ b/.github/workflows/claude-code-review.yml @@ -12,11 +12,7 @@ on: jobs: claude-review: - # Optional: Filter by PR author - # if: | - # github.event.pull_request.user.login == 'external-contributor' || - # github.event.pull_request.user.login == 'new-developer' || - # github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR' + if: false # Disabled due to credential issue with CLAUDE_CODE_OAUTH_TOKEN runs-on: ubuntu-latest permissions: diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index d300267f1..1cead94bd 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -12,11 +12,7 @@ on: jobs: claude: - if: | - (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) || - (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) || - (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) || - (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude'))) + if: false # Disabled due to credential issue with CLAUDE_CODE_OAUTH_TOKEN runs-on: ubuntu-latest permissions: contents: read diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 5d32ba068..8540013a2 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -12,12 +12,14 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - + with: + fetch-depth: 0 + - name: Set up Python uses: actions/setup-python@v4 with: python-version: '3.11' - + - name: Cache pre-commit uses: actions/cache@v3 with: @@ -25,11 +27,16 @@ jobs: key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }} restore-keys: | pre-commit- - + - name: Install dependencies run: | python -m pip install --upgrade pip pip install pre-commit - + - name: Run pre-commit - run: pre-commit run --all-files + run: | + if [ "${{ github.event_name }}" = "pull_request" ]; then + pre-commit run --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }} --show-diff-on-fail + else + pre-commit run --all-files --show-diff-on-fail + fi diff --git a/.github/workflows/test-tinker.yml b/.github/workflows/test-tinker.yml index 94d244cff..96d69993f 100644 --- a/.github/workflows/test-tinker.yml +++ b/.github/workflows/test-tinker.yml @@ -32,7 +32,7 @@ jobs: uses: astral-sh/setup-uv@v4 - name: Install dependencies - run: uv sync --extra tinker + run: uv sync --extra tinker --extra dev - name: Run Tinker engine tests run: uv run pytest tests/engine/test_tinker_engine.py -v diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 421b0384b..c2cb4428f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,4 +6,4 @@ repos: args: ["--fix", "--show-fixes", "--output-format=full"] exclude: ^.*\.(ipynb)$|^verl/.*$ - id: ruff-format - exclude: ^verl/.*$ \ No newline at end of file + exclude: ^.*\.(ipynb)$|^verl/.*$ \ No newline at end of file diff --git a/examples/archive/mcp/README.md b/examples/archive/mcp/README.md index d87c3226c..ece63b705 100644 --- a/examples/archive/mcp/README.md +++ b/examples/archive/mcp/README.md @@ -7,7 +7,7 @@ This example demonstrates how to use external MCP servers as tool providers with ```bash # Install MCP CLI 
(if needed for other MCP servers) uv pip install mcp - +``` ## Files @@ -71,6 +71,33 @@ This will: - **`MCPEnvironment`** - Environment that manages MCP server connections and tool execution - **`MCPConnectionManager`** - Handles MCP server lifecycle and tool discovery +### Multiple MCP Servers + +`MCPEnvironment` also supports routing tool calls across multiple named MCP servers: + +```python +env = MCPEnvironment( + task={"question": "Find and summarize the latest updates."}, + mcp_servers={ + "search": { + "command": "npx", + "args": ["-y", "tavily-mcp@0.2.4"], + "env": {"TAVILY_API_KEY": "..."}, + }, + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + }, + }, + tool_name_to_server_name={ + "tavily_search": "search", + "read_file": "filesystem", + }, +) +``` + +Use `tool_name_to_server_name` when multiple servers expose the same public tool name, including underscore aliases for tools whose original MCP names contain hyphens. + ### Integration with RLLM The example follows standard RLLM patterns: diff --git a/examples/archive/sft/run_sft_model.py b/examples/archive/sft/run_sft_model.py index 961782d21..ca010ca16 100644 --- a/examples/archive/sft/run_sft_model.py +++ b/examples/archive/sft/run_sft_model.py @@ -30,7 +30,9 @@ agent_args = { "tools": ["python"], "parser_name": "qwen", - "system_prompt": ('You are an expert mathematician and programmer. Your goal is to solve challenging math problems, like those from the AIME competition, by breaking them down into logical steps and using Python code for calculations. Strive for clarity and efficiency.\n\nFollow this process for every problem:\n1. **Analyze the Problem**: Read the question carefully. Identify the key information, constraints, and what is being asked.\n2. **Think Step-by-Step**: In the `` block, outline your plan. Decompose the problem into the smallest, most logical steps. **You must not write code or perform calculations in this block.** Your goal is to create a plan that will be executed by the Python tool.\n3. **Write Python Code**: In the `` block, write an efficient Python script to execute your plan. The tool expects a JSON object with `name` and `arguments` keys. The `arguments` should be a dictionary with a single `code` key. Ensure the code is self-contained, runs quickly, and prints the final result.\n4. **State the Final Answer**: After receiving the ``, verify it. Then, state the final answer clearly and concisely in the format \\boxed{answer}.\n\nHere is an example:\nQuestion: What is the largest prime factor of 25! ?\nThe problem asks for the largest prime factor of 25 factorial. The largest prime factor of n! is the largest prime number less than or equal to n. In this case, n=25. I will write a Python script to find the largest prime number less than or equal to 25.\n\n{"name": "python", "arguments": {"code": "import math\\ndef is_prime(n):\\n if n <= 1:\\n return False\\n for i in range(2, int(math.sqrt(n)) + 1):\\n if n % i == 0:\\n return False\\n return True\\n\\ndef largest_prime_up_to(n):\\n for i in range(n, 1, -1):\\n if is_prime(i):\\n return i\\n return None\\n\\nprint(largest_prime_up_to(25))"}}\n\n\n23\n\nThe largest prime factor of 25! is the largest prime number less than or equal to 25. The answer is \\boxed{23}.'), + "system_prompt": ( + 'You are an expert mathematician and programmer. 
Your goal is to solve challenging math problems, like those from the AIME competition, by breaking them down into logical steps and using Python code for calculations. Strive for clarity and efficiency.\n\nFollow this process for every problem:\n1. **Analyze the Problem**: Read the question carefully. Identify the key information, constraints, and what is being asked.\n2. **Think Step-by-Step**: In the `` block, outline your plan. Decompose the problem into the smallest, most logical steps. **You must not write code or perform calculations in this block.** Your goal is to create a plan that will be executed by the Python tool.\n3. **Write Python Code**: In the `` block, write an efficient Python script to execute your plan. The tool expects a JSON object with `name` and `arguments` keys. The `arguments` should be a dictionary with a single `code` key. Ensure the code is self-contained, runs quickly, and prints the final result.\n4. **State the Final Answer**: After receiving the ``, verify it. Then, state the final answer clearly and concisely in the format \\boxed{answer}.\n\nHere is an example:\nQuestion: What is the largest prime factor of 25! ?\nThe problem asks for the largest prime factor of 25 factorial. The largest prime factor of n! is the largest prime number less than or equal to n. In this case, n=25. I will write a Python script to find the largest prime number less than or equal to 25.\n\n{"name": "python", "arguments": {"code": "import math\\ndef is_prime(n):\\n if n <= 1:\\n return False\\n for i in range(2, int(math.sqrt(n)) + 1):\\n if n % i == 0:\\n return False\\n return True\\n\\ndef largest_prime_up_to(n):\\n for i in range(n, 1, -1):\\n if is_prime(i):\\n return i\\n return None\\n\\nprint(largest_prime_up_to(25))"}}\n\n\n23\n\nThe largest prime factor of 25! is the largest prime number less than or equal to 25. The answer is \\boxed{23}.' 
+ ), } env_args = { "tools": ["python"], diff --git a/examples/countdown/unified_trainer/train_countdown_unified_tinker.py b/examples/countdown/unified_trainer/train_countdown_unified_tinker.py new file mode 100644 index 000000000..dadc98b20 --- /dev/null +++ b/examples/countdown/unified_trainer/train_countdown_unified_tinker.py @@ -0,0 +1,28 @@ +import hydra + +from rllm.data.dataset import DatasetRegistry +from rllm.experimental.unified_trainer import AgentTrainer +from rllm.rewards.countdown_reward import countdown_reward_fn +from rllm.workflows.simple_workflow import SimpleWorkflow + + +@hydra.main(config_path="pkg://rllm.experimental.config", config_name="unified", version_base=None) +def main(config): + train_dataset = DatasetRegistry.load_dataset("countdown", "train") + test_dataset = DatasetRegistry.load_dataset("countdown", "test") + + trainer = AgentTrainer( + workflow_class=SimpleWorkflow, + workflow_args={ + "reward_function": countdown_reward_fn, + }, + config=config, + train_dataset=train_dataset, + val_dataset=test_dataset, + backend="tinker", + ) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/examples/countdown/unified_trainer/train_countdown_unified_tinker_async.sh b/examples/countdown/unified_trainer/train_countdown_unified_tinker_async.sh new file mode 100644 index 000000000..1f3e8586b --- /dev/null +++ b/examples/countdown/unified_trainer/train_countdown_unified_tinker_async.sh @@ -0,0 +1,36 @@ +set -x + +python -m examples.countdown.unified_trainer.train_countdown_unified_tinker \ + rllm/backend=tinker \ + model.name=Qwen/Qwen3-8B \ + model.lora_rank=32 \ + training.group_size=8 \ + training.learning_rate=2e-5 \ + training.max_length=4096 \ + sampling.train.temperature=1.0 \ + sampling.train.top_p=1.0 \ + sampling.val.temperature=1.0 \ + sampling.val.top_p=1.0 \ + validation.group_size=1 \ + rllm.workflow.n_parallel_tasks=256 \ + rllm.workflow.retry_limit=1 \ + rllm.workflow.raise_on_error=false \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.train_batch_size=1 \ + data.val_batch_size=1024 \ + rllm.algorithm.adv_estimator=grpo \ + rllm.algorithm.norm_adv_by_std_in_grpo=true \ + rllm.async_training.enable=true \ + rllm.async_training.mini_batch_size=32 \ + rllm.async_training.fwd_bwd_group_size=8 \ + rllm.async_training.staleness_threshold=0.5 \ + rllm.async_training.trigger_parameter_sync_step=1 \ + rllm.async_training.partial_rollout=true \ + rllm.trainer.total_epochs=1 \ + rllm.trainer.logger='[wandb]' \ + rllm.trainer.project_name='rllm-countdown' \ + rllm.trainer.experiment_name='countdown-tinker-async-staleness-0.5' \ + rllm.trainer.val_before_train=true \ + rllm.trainer.test_freq=10 \ + rllm.trainer.save_freq=-1 diff --git a/examples/countdown/unified_trainer/train_countdown_unified_tinker_sync.sh b/examples/countdown/unified_trainer/train_countdown_unified_tinker_sync.sh new file mode 100644 index 000000000..a9d2748fa --- /dev/null +++ b/examples/countdown/unified_trainer/train_countdown_unified_tinker_sync.sh @@ -0,0 +1,31 @@ +set -x + +python -m examples.countdown.unified_trainer.train_countdown_unified_tinker \ + rllm/backend=tinker \ + model.name=Qwen/Qwen3-8B \ + model.lora_rank=32 \ + training.group_size=8 \ + training.learning_rate=2e-5 \ + training.max_length=4096 \ + sampling.train.temperature=1.0 \ + sampling.train.top_p=1.0 \ + sampling.val.temperature=1.0 \ + sampling.val.top_p=1.0 \ + validation.group_size=1 \ + rllm.workflow.n_parallel_tasks=256 \ + rllm.workflow.retry_limit=1 \ + 
rllm.workflow.raise_on_error=false \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.train_batch_size=32 \ + data.val_batch_size=1024 \ + rllm.algorithm.adv_estimator=grpo \ + rllm.algorithm.norm_adv_by_std_in_grpo=true \ + rllm.async_training.enable=false \ + rllm.trainer.total_epochs=1 \ + rllm.trainer.logger='[wandb]' \ + rllm.trainer.project_name='rllm-countdown' \ + rllm.trainer.experiment_name='countdown-tinker-sync' \ + rllm.trainer.val_before_train=true \ + rllm.trainer.test_freq=10 \ + rllm.trainer.save_freq=-1 diff --git a/examples/deepcoder/prepare_deepcoder_data.py b/examples/deepcoder/prepare_deepcoder_data.py index fedb8f18e..533a1a3ef 100644 --- a/examples/deepcoder/prepare_deepcoder_data.py +++ b/examples/deepcoder/prepare_deepcoder_data.py @@ -7,8 +7,19 @@ def prepare_deepcoder_data(train_size: int = None, test_size: int = None): - train_dataset = concatenate_datasets([load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="primeintellect", split="train"), load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="taco", split="train"), load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="lcbv5", split="train")]) - test_dataset = concatenate_datasets([load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="codeforces", split="test"), load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="lcbv5", split="test")]) + train_dataset = concatenate_datasets( + [ + load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="primeintellect", split="train"), + load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="taco", split="train"), + load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="lcbv5", split="train"), + ] + ) + test_dataset = concatenate_datasets( + [ + load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="codeforces", split="test"), + load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="lcbv5", split="test"), + ] + ) def preprocess_fn(example, idx): starter_code = example.get("starter_code", "") @@ -39,7 +50,15 @@ def preprocess_fn(example, idx): else: test["metadata"] = {"func_name": None} - return {"question": question, "ground_truth": json.dumps(tests), "data_source": "livecodebench", "uid": f"deepcoder_{idx}", "index": idx, "starter_code": starter_code, "metadata": json.dumps(metadata)} + return { + "question": question, + "ground_truth": json.dumps(tests), + "data_source": "livecodebench", + "uid": f"deepcoder_{idx}", + "index": idx, + "starter_code": starter_code, + "metadata": json.dumps(metadata), + } if train_size: train_dataset = train_dataset.select(range(min(train_size, len(train_dataset)))) diff --git a/examples/fully_async/deepresearch/refine_agent.py b/examples/fully_async/deepresearch/refine_agent.py index 1ffab4ece..a44bd43d4 100644 --- a/examples/fully_async/deepresearch/refine_agent.py +++ b/examples/fully_async/deepresearch/refine_agent.py @@ -70,7 +70,9 @@ async def end_request(self, success: bool, latency: float): # Log periodically if (self.total_completed + self.total_failed) % 50 == 0: avg_latency = self.total_latency / max(1, self.total_completed) - logger.info(f"[STATS] In-flight: {self.in_flight}, Completed: {self.total_completed}, Failed: {self.total_failed}, Latency(avg/min/max): {avg_latency:.2f}s/{self.min_latency:.2f}s/{self.max_latency:.2f}s") + logger.info( + f"[STATS] In-flight: {self.in_flight}, Completed: {self.total_completed}, Failed: {self.total_failed}, Latency(avg/min/max): {avg_latency:.2f}s/{self.min_latency:.2f}s/{self.max_latency:.2f}s" + ) 
async def get_stats(self) -> dict: async with self._lock: diff --git a/examples/fully_async/deepresearch/search_agent.py b/examples/fully_async/deepresearch/search_agent.py index 4ffed4a7f..139d9ff74 100644 --- a/examples/fully_async/deepresearch/search_agent.py +++ b/examples/fully_async/deepresearch/search_agent.py @@ -242,7 +242,28 @@ async def run(self, question): final_answer = extract_boxed_answer(content) # Aggregate metrics across all tool calls - aggregated_metrics = {"num_turns": num_turns, "total_parse_tool_args_error": sum(m.get("parse_tool_args_error", 0) for m in metrics), "total_tool_return_error": sum(m.get("tool_return_error", 0) for m in metrics), "total_tool_calls": sum(m.get("tool_calls", 0) for m in metrics), "total_tool_wait_time": sum(m.get("tool_wait_time", 0) for m in metrics), "total_refine_time": sum(m.get("refine_time", 0) for m in metrics), "avg_refine_time": sum(m.get("refine_time", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1), "total_query_length": sum(m.get("query_length", 0) for m in metrics), "avg_query_length": sum(m.get("query_length", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1), "total_generation_time": total_generation_time, "total_completion_tokens": total_completion_tokens, "total_tool_tokens": sum(m.get("tool_tokens", 0) for m in metrics), "avg_completion_tokens_per_turn": total_completion_tokens / max(num_turns, 1), "avg_tool_tokens_per_call": sum(m.get("tool_tokens", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1), "duplicate_search_detected": duplicate_search_detected, "excessive_parallel_calls": excessive_parallel_calls, "tool_error_detected": tool_error_detected, "refine_error_detected": refine_error_detected, "overlong": overlong, "merged_step": len(trajectory.merge())} + aggregated_metrics = { + "num_turns": num_turns, + "total_parse_tool_args_error": sum(m.get("parse_tool_args_error", 0) for m in metrics), + "total_tool_return_error": sum(m.get("tool_return_error", 0) for m in metrics), + "total_tool_calls": sum(m.get("tool_calls", 0) for m in metrics), + "total_tool_wait_time": sum(m.get("tool_wait_time", 0) for m in metrics), + "total_refine_time": sum(m.get("refine_time", 0) for m in metrics), + "avg_refine_time": sum(m.get("refine_time", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1), + "total_query_length": sum(m.get("query_length", 0) for m in metrics), + "avg_query_length": sum(m.get("query_length", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1), + "total_generation_time": total_generation_time, + "total_completion_tokens": total_completion_tokens, + "total_tool_tokens": sum(m.get("tool_tokens", 0) for m in metrics), + "avg_completion_tokens_per_turn": total_completion_tokens / max(num_turns, 1), + "avg_tool_tokens_per_call": sum(m.get("tool_tokens", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1), + "duplicate_search_detected": duplicate_search_detected, + "excessive_parallel_calls": excessive_parallel_calls, + "tool_error_detected": tool_error_detected, + "refine_error_detected": refine_error_detected, + "overlong": overlong, + "merged_step": len(trajectory.merge()), + } if OVERLONG_FILTER and overlong: for seq in trajectory.sequences: diff --git a/examples/math_tinker/rl_loop_tinker_original.py b/examples/math_tinker/rl_loop_tinker_original.py index e0b495f0f..5df78c7ca 100644 --- a/examples/math_tinker/rl_loop_tinker_original.py +++ 
b/examples/math_tinker/rl_loop_tinker_original.py @@ -191,7 +191,9 @@ def main(config: Config): target_tokens = tokens[1:] all_logprobs = [0.0] * ob_len + logprob all_advantages = [0.0] * ob_len + [advantage] * (len(input_tokens) - ob_len) - assert len(input_tokens) == len(target_tokens) == len(all_logprobs) == len(all_advantages), f"len(input_tokens): {len(input_tokens)}, len(target_tokens): {len(target_tokens)}, len(all_logprobs): {len(all_logprobs)}, len(all_advantages): {len(all_advantages)}" + assert len(input_tokens) == len(target_tokens) == len(all_logprobs) == len(all_advantages), ( + f"len(input_tokens): {len(input_tokens)}, len(target_tokens): {len(target_tokens)}, len(all_logprobs): {len(all_logprobs)}, len(all_advantages): {len(all_advantages)}" + ) datum = types.Datum( model_input=types.ModelInput.from_ints(tokens=input_tokens), loss_fn_inputs={ diff --git a/examples/search/retrieval/server.py b/examples/search/retrieval/server.py index ce3071fd5..cde188342 100755 --- a/examples/search/retrieval/server.py +++ b/examples/search/retrieval/server.py @@ -149,7 +149,20 @@ def encode(self, query_list: list[str], is_query: bool = True) -> np.ndarray: class Config: """Configuration class for retrieval server.""" - def __init__(self, retrieval_method: str = "e5", retrieval_topk: int = 10, index_path: str = "./index/e5_Flat.index", corpus_path: str = "./data/corpus.jsonl", faiss_gpu: bool = True, gpu_id: int = 0, retrieval_model_path: str = "intfloat/e5-base-v2", retrieval_pooling_method: str = "mean", retrieval_query_max_length: int = 256, retrieval_use_fp16: bool = True, retrieval_batch_size: int = 512): + def __init__( + self, + retrieval_method: str = "e5", + retrieval_topk: int = 10, + index_path: str = "./index/e5_Flat.index", + corpus_path: str = "./data/corpus.jsonl", + faiss_gpu: bool = True, + gpu_id: int = 0, + retrieval_model_path: str = "intfloat/e5-base-v2", + retrieval_pooling_method: str = "mean", + retrieval_query_max_length: int = 256, + retrieval_use_fp16: bool = True, + retrieval_batch_size: int = 512, + ): self.retrieval_method = retrieval_method self.retrieval_topk = retrieval_topk self.index_path = index_path @@ -274,7 +287,14 @@ def __init__(self, config: Config): logger.info(f"Loaded corpus with {len(self.corpus)} documents") # Initialize encoder - self.encoder = Encoder(model_name=self.retrieval_method, model_path=config.retrieval_model_path, pooling_method=config.retrieval_pooling_method, max_length=config.retrieval_query_max_length, use_fp16=config.retrieval_use_fp16, gpu_id=config.gpu_id) + self.encoder = Encoder( + model_name=self.retrieval_method, + model_path=config.retrieval_model_path, + pooling_method=config.retrieval_pooling_method, + max_length=config.retrieval_query_max_length, + use_fp16=config.retrieval_use_fp16, + gpu_id=config.gpu_id, + ) self.topk = config.retrieval_topk self.batch_size = config.retrieval_batch_size @@ -371,7 +391,18 @@ class HealthResponse(BaseModel): @app.get("/health") def health_check(): """Health check endpoint.""" - return {"status": "healthy", "corpus_size": len(retriever.corpus), "index_type": "dense" if not config.retrieval_method == "bm25" else "bm25", "index_loaded": retriever.index is not None if hasattr(retriever, "index") else True, "retrieval_method": config.retrieval_method, "faiss_gpu": config.faiss_gpu, "batch_size": config.retrieval_batch_size, "topk": config.retrieval_topk, "model_path": config.retrieval_model_path, "use_fp16": config.retrieval_use_fp16} + return { + "status": "healthy", + "corpus_size": 
len(retriever.corpus), + "index_type": "dense" if config.retrieval_method != "bm25" else "bm25", + "index_loaded": (retriever.index is not None if hasattr(retriever, "index") else True), + "retrieval_method": config.retrieval_method, + "faiss_gpu": config.faiss_gpu, + "batch_size": config.retrieval_batch_size, + "topk": config.retrieval_topk, + "model_path": config.retrieval_model_path, + "use_fp16": config.retrieval_use_fp16, + } @app.post("/retrieve") diff --git a/pyproject.toml b/pyproject.toml index 590d5405c..4ea3c3201 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -199,6 +199,38 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "rllm/rewards/code_utils/**" = ["ALL"] +# Prompt / string-heavy files where long lines are intentional +"rllm/agents/system_prompts.py" = ["E501"] +"rllm/system_prompts.py" = ["E501"] +"rllm/agents/appworld_react_agents.py" = ["E501"] +"rllm/agents/code_agent.py" = ["E501"] +"rllm/agents/frozenlake_agent.py" = ["E501"] +"rllm/agents/miniwob_agent.py" = ["E501"] +"rllm/agents/webarena_agent.py" = ["E501"] +"agenthub/frozenlake_agent/agent/agent.py" = ["E501"] +"examples/archive/deepresearch/deepresearch_agent.py" = ["E501"] +"examples/fully_async/deepresearch/refine_agent.py" = ["E501"] +"examples/fully_async/deepresearch/search_agent.py" = ["E501"] +"examples/archive/sft/run_sft_model.py" = ["E501"] +"examples/archive/vimgolf/prepare_vimgolf_data.py" = ["E501"] +"examples/countdown/prepare_countdown_data.py" = ["E501"] +"examples/eval_protocol/prepare_frozen_lake_data.py" = ["E501"] +"examples/solver_judge/prepare_countdown_data.py" = ["E501"] +"cookbooks/solver_judge_flow/solver_judge_flow.py" = ["E501"] +"examples/sdk/solver_judge/solver_judge_flow_decorator.py" = ["E501"] +"examples/sdk/solver_judge/solver_judge_flow_session.py" = ["E501"] +"examples/solver_judge/solver_judge_flow.py" = ["E501"] +"rllm/experimental/cli/main.py" = ["E501"] +"rllm/experimental/rllm_telemetry/examples/experiment_comparison.py" = ["E501"] +"rllm/experimental/test_examples/opsd/math_opsd_workflow.py" = ["E501"] +"rllm/experimental/agents/tools/file_editor_tool.py" = ["E501"] +"rllm/rewards/math_reward.py" = ["E501"] +"projects/finqa/fin_qa_tools.py" = ["E501"] +"projects/finqa/scripts/data_generation/cleanup_tables.py" = ["E501"] +"tests/agents/test_appworld_agent.py" = ["E501"] +"tests/integration/test_agentcore_runtime.py" = ["E501"] +"tests/integration/test_remote_engine.py" = ["E501"] +"tests/rewards/test_code_reward.py" = ["E501"] [tool.setuptools.packages.find] where = ["."] diff --git a/rllm-model-gateway/src/rllm_model_gateway/client.py b/rllm-model-gateway/src/rllm_model_gateway/client.py index 48292fe8d..b245c3c64 100644 --- a/rllm-model-gateway/src/rllm_model_gateway/client.py +++ b/rllm-model-gateway/src/rllm_model_gateway/client.py @@ -60,9 +60,7 @@ def get_session_info(self, session_id: str) -> dict[str, Any]: resp.raise_for_status() return resp.json() - def list_sessions( - self, since: float | None = None, limit: int | None = None - ) -> list[dict[str, Any]]: + def list_sessions(self, since: float | None = None, limit: int | None = None) -> list[dict[str, Any]]: params: dict[str, Any] = {} if since is not None: params["since"] = since @@ -90,9 +88,7 @@ def get_session_traces( params["since"] = since if limit is not None: params["limit"] = limit - resp = self._http.get( - f"{self.gateway_url}/sessions/{session_id}/traces", params=params - ) + resp = self._http.get(f"{self.gateway_url}/sessions/{session_id}/traces", params=params) resp.raise_for_status() 
data = resp.json() return [TraceRecord(**t) for t in data] @@ -188,9 +184,7 @@ async def get_session_info(self, session_id: str) -> dict[str, Any]: resp.raise_for_status() return resp.json() - async def list_sessions( - self, since: float | None = None, limit: int | None = None - ) -> list[dict[str, Any]]: + async def list_sessions(self, since: float | None = None, limit: int | None = None) -> list[dict[str, Any]]: params: dict[str, Any] = {} if since is not None: params["since"] = since @@ -218,9 +212,7 @@ async def get_session_traces( params["since"] = since if limit is not None: params["limit"] = limit - resp = await self._http.get( - f"{self.gateway_url}/sessions/{session_id}/traces", params=params - ) + resp = await self._http.get(f"{self.gateway_url}/sessions/{session_id}/traces", params=params) resp.raise_for_status() data = resp.json() return [TraceRecord(**t) for t in data] diff --git a/rllm-model-gateway/src/rllm_model_gateway/middleware.py b/rllm-model-gateway/src/rllm_model_gateway/middleware.py index 13cd2453e..6e428f52e 100644 --- a/rllm-model-gateway/src/rllm_model_gateway/middleware.py +++ b/rllm-model-gateway/src/rllm_model_gateway/middleware.py @@ -67,17 +67,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # Inject sampling parameters into POST request bodies (chat completions, etc.) method = scope.get("method", "").upper() - needs_injection = ( - self.add_logprobs or self.add_return_token_ids or self.sessions is not None - ) + needs_injection = self.add_logprobs or self.add_return_token_ids or self.sessions is not None if method == "POST" and needs_injection: await self._inject_params(scope, receive, send, session_id) else: await self.app(scope, receive, send) - async def _inject_params( - self, scope: Scope, receive: Receive, send: Send, session_id: str | None = None - ) -> None: + async def _inject_params(self, scope: Scope, receive: Receive, send: Send, session_id: str | None = None) -> None: """Read body, inject sampling params, then forward with mutated body.""" body_parts: list[bytes] = [] more = True @@ -94,9 +90,7 @@ async def _inject_params( # Record whether the client originally requested logprobs # so the proxy can strip them from the response if not. 
state = scope["state"] - state["originally_requested_logprobs"] = ( - "logprobs" in payload and payload["logprobs"] - ) + state["originally_requested_logprobs"] = "logprobs" in payload and payload["logprobs"] self._mutate(payload, session_id) raw = json.dumps(payload).encode("utf-8") except (json.JSONDecodeError, UnicodeDecodeError): diff --git a/rllm-model-gateway/src/rllm_model_gateway/server.py b/rllm-model-gateway/src/rllm_model_gateway/server.py index 74b3cb18b..0990ec3cd 100644 --- a/rllm-model-gateway/src/rllm_model_gateway/server.py +++ b/rllm-model-gateway/src/rllm_model_gateway/server.py @@ -359,10 +359,7 @@ def _load_config(args: argparse.Namespace) -> GatewayConfig: # Workers from CLI --worker flags (WorkerConfig validator auto-splits URLs) worker_urls = getattr(args, "worker", None) or [] if worker_urls: - data["workers"] = [ - {"url": raw_url, "worker_id": str(i)} - for i, raw_url in enumerate(worker_urls) - ] + data["workers"] = [{"url": raw_url, "worker_id": str(i)} for i, raw_url in enumerate(worker_urls)] return GatewayConfig(**data) @@ -373,9 +370,7 @@ def _load_config(args: argparse.Namespace) -> GatewayConfig: def main() -> None: - parser = argparse.ArgumentParser( - description="rllm-model-gateway: lightweight LLM call proxy for RL training" - ) + parser = argparse.ArgumentParser(description="rllm-model-gateway: lightweight LLM call proxy for RL training") parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=None) parser.add_argument("--config", type=str, default=None, help="Path to YAML config") @@ -398,9 +393,7 @@ def main() -> None: import uvicorn - uvicorn.run( - app, host=config.host, port=config.port, log_level=config.log_level.lower() - ) + uvicorn.run(app, host=config.host, port=config.port, log_level=config.log_level.lower()) if __name__ == "__main__": diff --git a/rllm-model-gateway/src/rllm_model_gateway/session_manager.py b/rllm-model-gateway/src/rllm_model_gateway/session_manager.py index 72ee79a0f..7c9df28f0 100644 --- a/rllm-model-gateway/src/rllm_model_gateway/session_manager.py +++ b/rllm-model-gateway/src/rllm_model_gateway/session_manager.py @@ -23,9 +23,7 @@ def __init__(self, store: TraceStore) -> None: self._created_at: dict[str, float] = {} self._sampling_params: dict[str, dict[str, Any]] = {} - def ensure_session( - self, session_id: str, metadata: dict[str, Any] | None = None - ) -> str: + def ensure_session(self, session_id: str, metadata: dict[str, Any] | None = None) -> str: """Ensure a session exists (create if needed). 
Returns session_id.""" if session_id not in self._created_at: self._created_at[session_id] = time.time() diff --git a/rllm/agents/agent.py b/rllm/agents/agent.py index 5ef7094d9..f72e334d3 100644 --- a/rllm/agents/agent.py +++ b/rllm/agents/agent.py @@ -2,6 +2,7 @@ import uuid from abc import ABC, abstractmethod +from copy import deepcopy from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict, Field @@ -23,6 +24,7 @@ class Step(_StepBase): prompt_ids: list[int] | list[Any] = Field(default_factory=list) response_ids: list[int] = Field(default_factory=list) logprobs: list[float] = Field(default_factory=list) + routing_matrices: list[str] | None = None # per-token routing matrices (R3, transient) chat_completions: list[dict[str, Any]] = Field(default_factory=list) @@ -38,6 +40,9 @@ class Step(_StepBase): # Per-token or scalar advantages advantage: list[float] | float | None = None + # weight version at time of generation (for async training staleness tracking) + weight_version: int | None = None + @property def info(self) -> dict: """Alias for metadata. Auto-initializes to {} if None so mutation works.""" @@ -50,6 +55,7 @@ def info(self, value: dict) -> None: self.metadata = value def model_post_init(self, __context: Any) -> None: + self.chat_completions = deepcopy(self.chat_completions) if self.model_output is None: return # backfill fields like prompt_ids, response_ids, logprobs, etc. @@ -59,17 +65,35 @@ def model_post_init(self, __context: Any) -> None: self.response_ids = self.model_output.completion_ids if len(self.logprobs) == 0 and self.model_output.logprobs is not None: self.logprobs = self.model_output.logprobs + if self.routing_matrices is None and getattr(self.model_output, "routing_matrices", None) is not None: + self.routing_matrices = self.model_output.routing_matrices + if self.weight_version is None and hasattr(self.model_output, "weight_version"): + self.weight_version = self.model_output.weight_version # check that the lengths would match up if len(self.logprobs) > 0: assert len(self.response_ids) == len(self.logprobs), f"length mismatch between response_ids and logprobs, got {len(self.response_ids)}, {len(self.logprobs)}" def to_dict(self) -> dict: + from rllm.tools.tool_base import ToolCall, ToolOutput + + # Helper function to recursively convert ToolCall and ToolOutput objects to dicts + def _serialize_value(value): + if isinstance(value, ToolCall | ToolOutput): + return value.to_dict() + elif isinstance(value, list): + return [_serialize_value(item) for item in value] + elif isinstance(value, dict): + return {k: _serialize_value(v) for k, v in value.items()} + else: + return value + return { "prompt_ids": self.prompt_ids, "response_ids": self.response_ids, "logprobs": self.logprobs, - "chat_completions": self.chat_completions, + "routing_matrices": self.routing_matrices, + "chat_completions": _serialize_value(self.chat_completions), "observation": self.observation, "thought": self.thought, "action": self.action.action if isinstance(self.action, Action) else self.action, @@ -80,6 +104,7 @@ def to_dict(self) -> dict: "done": self.done, "mc_return": self.mc_return, "advantage": self.advantage, + "weight_version": self.weight_version, } @classmethod @@ -90,6 +115,7 @@ def from_dict(cls, data: dict) -> Step: prompt_ids=data["prompt_ids"], response_ids=data["response_ids"], logprobs=data["logprobs"], + routing_matrices=data.get("routing_matrices"), chat_completions=data["chat_completions"], observation=data["observation"], thought=data["thought"], @@ 
-100,7 +126,8 @@ def from_dict(cls, data: dict) -> Step: reward=data["reward"], done=data["done"], mc_return=data["mc_return"], - advantage=data["advantage"], + advantage=data.get("advantage", 0.0), + weight_version=data.get("weight_version"), ) @classmethod @@ -109,11 +136,13 @@ def from_model_output(cls, model_output: ModelOutput, messages: list[dict] | Non prompt_ids=model_output.prompt_ids or [], response_ids=model_output.completion_ids or [], logprobs=model_output.logprobs or [], + routing_matrices=getattr(model_output, "routing_matrices", None), chat_completions=(messages or []) + [{"role": "assistant", "content": model_output.content, "reasoning": model_output.reasoning}], thought=model_output.reasoning or "", action=action, model_response=model_output.content or "", model_output=model_output, + weight_version=model_output.weight_version, ) @@ -259,6 +288,7 @@ class TrajectoryGroup(BaseModel): trajectories: list[Trajectory] group_id: str = "" metadata: list[dict] = Field(default_factory=list) + weight_version: int = 0 @property def group_role(self) -> str: diff --git a/rllm/agents/miniwob_agent.py b/rllm/agents/miniwob_agent.py index f77b7d31b..d944f6246 100644 --- a/rllm/agents/miniwob_agent.py +++ b/rllm/agents/miniwob_agent.py @@ -38,7 +38,17 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image) -> str: class MiniWobAgent(BaseAgent): - def __init__(self, chat_mode: bool = False, use_html: bool = True, use_axtree: bool = True, use_screenshot: bool = False, use_accumulate_thinking: bool = True, cot_prompt: bool = False, use_full_conversation: bool = True, use_reward_shaping: bool = False): + def __init__( + self, + chat_mode: bool = False, + use_html: bool = True, + use_axtree: bool = True, + use_screenshot: bool = False, + use_accumulate_thinking: bool = True, + cot_prompt: bool = False, + use_full_conversation: bool = True, + use_reward_shaping: bool = False, + ): self.chat_mode: bool = chat_mode self.use_html: bool = use_html self.use_axtree: bool = use_axtree @@ -217,7 +227,12 @@ def get_user_msgs(self, user_obs) -> list[dict[str, str]]: user_msgs.append({"type": "text", "text": self._get_action_space_description()}) # Add next action prompt - user_msgs.append({"type": "text", "text": "# Next action\nThe task has not been completed yet. You will now think step by step and produce your next best action. Reflect on your past actions, any resulting error message, and the current state of the page before deciding on your next action. The content must be in the same format as shown before in the Action Space. You can plan ahead but only 1 immediate action is needed."}) + user_msgs.append( + { + "type": "text", + "text": "# Next action\nThe task has not been completed yet. You will now think step by step and produce your next best action. Reflect on your past actions, any resulting error message, and the current state of the page before deciding on your next action. The content must be in the same format as shown before in the Action Space. 
You can plan ahead but only 1 immediate action is needed.", + } + ) return user_msgs diff --git a/rllm/engine/agent_execution_engine.py b/rllm/engine/agent_execution_engine.py index 07c33da39..1ea59c860 100644 --- a/rllm/engine/agent_execution_engine.py +++ b/rllm/engine/agent_execution_engine.py @@ -271,7 +271,12 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te except asyncio.TimeoutError: termination_reason = "ENV_TIMEOUT" if step_idx == 0: - colorful_print(f"Warning: Trajectory {idx} completed due to: {termination_reason} before able to perform 1 complete action. This might cause unexpected behavior. Consider increasing trajectory timeout limit.\n", "red") + colorful_print( + f"Warning: Trajectory {idx} completed due to: {termination_reason}" + " before it could complete a single action. This might cause" + " unexpected behavior. Consider increasing the trajectory timeout limit.\n", + "red", + ) reward = 0 cur_step = agent.get_current_state() @@ -307,9 +312,21 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te assistant_msg_tokens, assistant_msg_masks = [], [] env_msg_tokens, env_msg_masks = [], [] if assistant_message: - assistant_msg_tokens, assistant_msg_masks = convert_messages_to_tokens_and_masks([assistant_message], tokenizer=self.tokenizer, parser=self.chat_parser, contains_first_msg=False, contains_generation_msg=False) + assistant_msg_tokens, assistant_msg_masks = convert_messages_to_tokens_and_masks( + [assistant_message], + tokenizer=self.tokenizer, + parser=self.chat_parser, + contains_first_msg=False, + contains_generation_msg=False, + ) if env_messages: - env_msg_tokens, env_msg_masks = convert_messages_to_tokens_and_masks(env_messages, tokenizer=self.tokenizer, parser=self.chat_parser, contains_first_msg=False, contains_generation_msg=True) + env_msg_tokens, env_msg_masks = convert_messages_to_tokens_and_masks( + env_messages, + tokenizer=self.tokenizer, + parser=self.chat_parser, + contains_first_msg=False, + contains_generation_msg=True, + ) # Update response token length response_token_len += len(assistant_msg_tokens) + len(env_msg_tokens) @@ -473,7 +490,12 @@ def assemble_steps(self, steps: list[dict]): break if diff_pos is not None: - logger.warning(f"When assemble steps, detect the trajectory not accumulative at position {diff_pos}. Expected: {accumulated_sequence[diff_pos : diff_pos + 5]}, Got: {prefix[diff_pos : diff_pos + 5]}. Setting response_masks to all 0s. This is likely due to retokenization.") + logger.warning( + f"While assembling steps, detected that the trajectory is not accumulative at position {diff_pos}. " + f"Expected: {accumulated_sequence[diff_pos : diff_pos + 5]}, " + f"Got: {prefix[diff_pos : diff_pos + 5]}. " + "Setting response_masks to all 0s. This is likely due to retokenization." + ) else: logger.warning(f"While assembling steps, detected a length mismatch. Expected length: {len(accumulated_sequence)}, Got length: {len(prefix)}. 
Setting response_masks to all 0s.") diff --git a/rllm/engine/agent_sdk_engine.py b/rllm/engine/agent_sdk_engine.py index 353d6e1ac..8019fb6d9 100644 --- a/rllm/engine/agent_sdk_engine.py +++ b/rllm/engine/agent_sdk_engine.py @@ -40,7 +40,18 @@ class AgentSdkEngine: - def __init__(self, agent_run_func: Callable, rollout_engine: RolloutEngine, config=None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, proxy_config: dict | None = None, tracer: Optional["TracerProtocol"] = None, **kwargs): + def __init__( + self, + agent_run_func: Callable, + rollout_engine: RolloutEngine, + config=None, + n_parallel_tasks: int = 128, + retry_limit: int = 3, + raise_on_error: bool = True, + proxy_config: dict | None = None, + tracer: Optional["TracerProtocol"] = None, + **kwargs, + ): """Initialize SdkEngine for executing agent_run_func on multiple tasks. Args: @@ -112,7 +123,13 @@ def _setup_verl_proxy(self, proxy_config: dict, tracer: Optional["TracerProtocol requires_sync_storage = SESSION_BACKEND == "opentelemetry" if requires_sync_storage and proxy_mode == "external": - logger.warning("OpenTelemetry-based sessions require synchronous storage mode for proper synchronization. When using external proxy mode, ensure the proxy is started with --sync-tracer flag. Alternatively, use proxy_mode='subprocess' to automatically enable sync storage. Without sync storage, there may be synchronization issues between tracer persistence and session reads.") + logger.warning( + "OpenTelemetry-based sessions require synchronous storage mode for proper synchronization. " + "When using external proxy mode, ensure the proxy is started with --sync-tracer flag. " + "Alternatively, use proxy_mode='subprocess' to automatically enable sync storage. " + "Without sync storage, there may be synchronization issues between tracer persistence " + "and session reads." + ) add_logprobs = proxy_config.get("add_logprobs", False) @@ -519,7 +536,11 @@ def transform_results_for_verl(self, episodes: list[Episode], task_ids: np.ndarr # if not self.config.rllm.stepwise_advantage.enable: # if len(trajectory.steps) > 1: # if not trajectory.is_cumulative(): - # logger.warning(f"Warning: Multi-step trajectory {trajectory_id} is not cumulative, but stepwise mode is not enabled. There could be a token mismatch during trajectory generation.") + # logger.warning( + # f"Warning: Multi-step trajectory {trajectory_id} is not cumulative, " + # "but stepwise mode is not enabled. There could be a token mismatch " + # "during trajectory generation." 
+ # ) # chat_completions = trajectory.steps[-1].chat_completions # prompt, response, mask = self.rollout_engine.chat_parser.tokenize_and_mask_cumulative(chat_completions) @@ -669,7 +690,15 @@ def transform_results_for_verl(self, episodes: list[Episode], task_ids: np.ndarr if cf.enable: for i in range(len(episode_ids)): termination_reason = termination_reasons[i] - if (cf.mask_max_prompt_length_exceeded and termination_reason == TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) or (cf.mask_max_response_length_exceeded and termination_reason == TerminationReason.MAX_RESPONSE_LENGTH_EXCEEDED) or (cf.mask_env_done and termination_reason == TerminationReason.ENV_DONE) or (cf.mask_max_turns_exceeded and termination_reason == TerminationReason.MAX_TURNS_EXCEEDED) or (cf.mask_timeout and termination_reason == TerminationReason.TIMEOUT) or (cf.mask_unknown and termination_reason == TerminationReason.UNKNOWN) or (cf.mask_error and termination_reason == TerminationReason.ERROR): + if ( + (cf.mask_max_prompt_length_exceeded and termination_reason == TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) + or (cf.mask_max_response_length_exceeded and termination_reason == TerminationReason.MAX_RESPONSE_LENGTH_EXCEEDED) + or (cf.mask_env_done and termination_reason == TerminationReason.ENV_DONE) + or (cf.mask_max_turns_exceeded and termination_reason == TerminationReason.MAX_TURNS_EXCEEDED) + or (cf.mask_timeout and termination_reason == TerminationReason.TIMEOUT) + or (cf.mask_unknown and termination_reason == TerminationReason.UNKNOWN) + or (cf.mask_error and termination_reason == TerminationReason.ERROR) + ): is_valid[i] = False # set flag to filter out the episode later (after advantages are computed) # Build tensors dict, conditionally include rollout_log_probs if available diff --git a/rllm/engine/rollout/openai_engine.py b/rllm/engine/rollout/openai_engine.py index 60c130505..1f59b8c3e 100644 --- a/rllm/engine/rollout/openai_engine.py +++ b/rllm/engine/rollout/openai_engine.py @@ -15,7 +15,22 @@ class OpenAIEngine(RolloutEngine): - def __init__(self, model: str = "", tokenizer=None, chat_parser=None, max_prompt_length: int = 4096, max_response_length: int = 4096, max_model_length: int | None = None, api_retries: int = 3, base_url: str = "https://api.openai.com/v1", api_key: str = os.getenv("OPENAI_API_KEY"), sampling_params: dict | None = None, tools: list[Tool | dict] = None, accumulate_reasoning: bool = False, **kwargs): + def __init__( + self, + model: str = "", + tokenizer=None, + chat_parser=None, + max_prompt_length: int = 4096, + max_response_length: int = 4096, + max_model_length: int | None = None, + api_retries: int = 3, + base_url: str = "https://api.openai.com/v1", + api_key: str = os.getenv("OPENAI_API_KEY"), + sampling_params: dict | None = None, + tools: list[Tool | dict] | None = None, + accumulate_reasoning: bool = False, + **kwargs, + ): self.model = model self.max_prompt_length = max_prompt_length self.max_response_length = max_response_length diff --git a/rllm/environments/appworld/appworld_env.py b/rllm/environments/appworld/appworld_env.py index 746672efc..cc1967843 100644 --- a/rllm/environments/appworld/appworld_env.py +++ b/rllm/environments/appworld/appworld_env.py @@ -92,7 +92,12 @@ def reset(self): user_info = {} if self.world and hasattr(self.world, "task") and hasattr(self.world.task, "supervisor"): main_user = self.world.task.supervisor - user_info = {"first_name": main_user.first_name if hasattr(main_user, "first_name") else "User", "last_name": main_user.last_name if 
hasattr(main_user, "last_name") else "Test", "email": main_user.email if hasattr(main_user, "email") else "user@example.com", "phone_number": main_user.phone_number if hasattr(main_user, "phone_number") else "+1234567890"} + user_info = { + "first_name": main_user.first_name if hasattr(main_user, "first_name") else "User", + "last_name": main_user.last_name if hasattr(main_user, "last_name") else "Test", + "email": main_user.email if hasattr(main_user, "email") else "user@example.com", + "phone_number": main_user.phone_number if hasattr(main_user, "phone_number") else "+1234567890", + } else: # Default user info if not available user_info = {"first_name": "User", "last_name": "Test", "email": "user@example.com", "phone_number": "+1234567890"} @@ -106,7 +111,12 @@ def reset(self): "instruction": instruction, "user_info": user_info, "available_apps": ["spotify", "gmail", "calendar", "contacts", "messages", "notes", "todo", "files", "banking"], - "helper_apis": {"show_app_descriptions": "apis.api_docs.show_app_descriptions()", "show_api_descriptions": "apis.api_docs.show_api_descriptions(app_name='app')", "show_api_doc": "apis.api_docs.show_api_doc(app_name='app', api_name='api')", "complete_task": "apis.supervisor.complete_task(answer='your_answer')"}, + "helper_apis": { + "show_app_descriptions": "apis.api_docs.show_app_descriptions()", + "show_api_descriptions": "apis.api_docs.show_api_descriptions(app_name='app')", + "show_api_doc": "apis.api_docs.show_api_doc(app_name='app', api_name='api')", + "complete_task": "apis.supervisor.complete_task(answer='your_answer')", + }, "app_descriptions": app_descriptions, } diff --git a/rllm/environments/frozenlake/frozenlake.py b/rllm/environments/frozenlake/frozenlake.py index 9156ac830..c32210cae 100644 --- a/rllm/environments/frozenlake/frozenlake.py +++ b/rllm/environments/frozenlake/frozenlake.py @@ -209,7 +209,12 @@ def _get_player_position(self): def reset(self, task=None): task = task or {} - self.__init__(size=task.get("size", self.map_kwargs["size"]), p=task.get("p", self.map_kwargs["p"]), seed=task.get("seed", self.env_kwargs["seed"]), is_slippery=task.get("is_slippery", self.env_kwargs["is_slippery"])) + self.__init__( + size=task.get("size", self.map_kwargs["size"]), + p=task.get("p", self.map_kwargs["p"]), + seed=task.get("seed", self.env_kwargs["seed"]), + is_slippery=task.get("is_slippery", self.env_kwargs["is_slippery"]), + ) GymFrozenLakeEnv.reset(self, seed=self.seed) return self.render(mode="tiny_rgb_array"), {} diff --git a/rllm/environments/tools/mcp_env.py b/rllm/environments/tools/mcp_env.py index 16c5c4692..cd5b6c494 100644 --- a/rllm/environments/tools/mcp_env.py +++ b/rllm/environments/tools/mcp_env.py @@ -4,6 +4,7 @@ import threading import warnings from contextlib import AsyncExitStack +from dataclasses import dataclass from typing import Any try: @@ -17,6 +18,135 @@ from rllm.tools.mcp_tool import MCPTool +@dataclass(frozen=True) +class MCPServerSpec: + name: str + command: str + args: tuple[str, ...] = () + env_items: tuple[tuple[str, str], ...] 
| None = None + + @property + def args_list(self) -> list[str]: + return list(self.args) + + @property + def env_dict(self) -> dict[str, str] | None: + if self.env_items is None: + return None + return dict(self.env_items) + + +def _normalize_server_spec(name: str, config: dict[str, Any]) -> MCPServerSpec: + if not isinstance(config, dict): + raise ValueError(f"Config for MCP server '{name}' must be a dictionary") + + command = config.get("command", config.get("mcp_server_command")) + if not command: + raise ValueError(f"Config for MCP server '{name}' must include 'command'") + + raw_args = config.get("args", config.get("mcp_server_args")) or [] + if not isinstance(raw_args, list | tuple): + raise ValueError(f"Config for MCP server '{name}' must include list-like 'args'") + args = tuple(str(arg) for arg in raw_args) + + raw_env = config.get("env", config.get("mcp_server_env")) + if raw_env is not None and not isinstance(raw_env, dict): + raise ValueError(f"Config for MCP server '{name}' must include dict-like 'env'") + env_items = tuple(sorted((str(key), str(value)) for key, value in raw_env.items())) if raw_env is not None else None + + return MCPServerSpec(name=name, command=str(command), args=args, env_items=env_items) + + +def _normalize_mcp_servers( + mcp_server_command: str | None, + mcp_server_args: list[str] | None, + mcp_server_env: dict[str, str] | None, + mcp_servers: dict[str, dict[str, Any]] | None, +) -> dict[str, MCPServerSpec]: + has_legacy_config = mcp_server_command is not None or mcp_server_args is not None or mcp_server_env is not None + + if mcp_servers is not None and has_legacy_config: + raise ValueError("Cannot specify both legacy single-server MCP args and 'mcp_servers'") + + if mcp_servers is not None: + if not isinstance(mcp_servers, dict): + raise ValueError("'mcp_servers' must be a dictionary mapping server names to configs") + return {server_name: _normalize_server_spec(server_name, server_config) for server_name, server_config in mcp_servers.items()} + + if mcp_server_command is None: + return {} + + return { + "default": MCPServerSpec( + name="default", + command=str(mcp_server_command), + args=tuple(str(arg) for arg in (mcp_server_args or [])), + env_items=tuple(sorted((str(key), str(value)) for key, value in mcp_server_env.items())) if mcp_server_env is not None else None, + ) + } + + +def _tool_call_id(tool_call: Any, fallback_idx: int) -> str: + if isinstance(tool_call, dict): + tool_call_id = tool_call.get("id") + if isinstance(tool_call_id, str) and tool_call_id: + return tool_call_id + return f"tool_call_{fallback_idx}" + + +def _tool_call_name(tool_call: Any) -> str | None: + if not isinstance(tool_call, dict): + return None + function = tool_call.get("function") + if not isinstance(function, dict): + return None + tool_name = function.get("name") + if isinstance(tool_name, str) and tool_name: + return tool_name + return None + + +def _parse_tool_arguments(tool_call: Any) -> tuple[dict[str, Any] | None, str | None]: + if not isinstance(tool_call, dict): + return None, "Tool call must be a dictionary" + + function = tool_call.get("function") + if not isinstance(function, dict): + return None, "Tool call missing function payload" + + raw_arguments = function.get("arguments", {}) + if isinstance(raw_arguments, dict): + return raw_arguments, None + if isinstance(raw_arguments, str): + try: + parsed = json.loads(raw_arguments) + except json.JSONDecodeError as exc: + return None, f"Invalid tool arguments JSON: {exc}" + if not isinstance(parsed, dict): + 
return None, "Tool arguments JSON must decode to an object" + return parsed, None + return None, "Tool arguments must be a dict or JSON string" + + +def _assign_missing_tool_call_ids(tool_calls: list[Any]) -> list[Any]: + normalized_tool_calls: list[Any] = [] + for idx, tool_call in enumerate(tool_calls): + if not isinstance(tool_call, dict): + normalized_tool_calls.append(tool_call) + continue + + tool_call_id = tool_call.get("id") + if isinstance(tool_call_id, str) and tool_call_id: + normalized_tool_calls.append(tool_call) + continue + + normalized_tool_call = dict(tool_call) + normalized_tool_call["id"] = _tool_call_id(tool_call, idx) + normalized_tool_calls.append(normalized_tool_call) + + return normalized_tool_calls + + class MCPConnectionManager: """Manages MCP connections in a dedicated thread to avoid asyncio context issues.""" @@ -154,16 +284,24 @@ async def _execute_tools(self, tool_calls: list[dict[str, Any]]) -> dict[str, st """Execute tool calls.""" tool_outputs: dict[str, str] = {} - for tool_call in tool_calls: - tool_name = tool_call["function"]["name"] - tool_args = json.loads(tool_call["function"]["arguments"]) + for idx, tool_call in enumerate(tool_calls): + tool_call_id = _tool_call_id(tool_call, idx) + tool_name = _tool_call_name(tool_call) + if tool_name is None: + tool_outputs[tool_call_id] = "Error: Tool call missing function.name" + continue + + tool_args, parse_error = _parse_tool_arguments(tool_call) + if parse_error is not None or tool_args is None: + tool_outputs[tool_call_id] = f"Error: {parse_error}" + continue if tool_name in self.tool_map: tool_instance = self.tool_map[tool_name] result = await tool_instance.async_forward(**tool_args) - tool_outputs[tool_call["id"]] = result.to_string() + tool_outputs[tool_call_id] = result.to_string() else: - tool_outputs[tool_call["id"]] = f"Error: Tool {tool_name} not found" + tool_outputs[tool_call_id] = f"Error: Tool {tool_name} not found" return tool_outputs @@ -179,11 +317,23 @@ class MCPEnvironment(BaseEnv): Uses a dedicated connection manager to avoid asyncio context issues. """ - # Class-level connection manager to share across instances - _connection_manager: MCPConnectionManager | None = None + # Class-level connection managers shared across instances + _connection_manager: MCPConnectionManager | None = None # backward-compatible alias for single-server usage + _connection_managers: dict[str, MCPConnectionManager] = {} + _server_specs: dict[str, MCPServerSpec] = {} _manager_lock = threading.Lock() - def __init__(self, task: dict[str, Any] | None = None, mcp_server_command: str | None = None, mcp_server_args: list[str] | None = None, mcp_server_env: dict[str, str] | None = None, reward_fn: RewardFunction | None = None, max_steps: int = 10): + def __init__( + self, + task: dict[str, Any] | None = None, + mcp_server_command: str | None = None, + mcp_server_args: list[str] | None = None, + mcp_server_env: dict[str, str] | None = None, + mcp_servers: dict[str, dict[str, Any]] | None = None, + tool_name_to_server_name: dict[str, str] | None = None, + reward_fn: RewardFunction | None = None, + max_steps: int = 10, + ): """ Initialize the MCPEnvironment. @@ -192,6 +342,9 @@ def __init__(self, task: dict[str, Any] | None = None, mcp_server_command: str | mcp_server_command: Command to run the MCP server. mcp_server_args: Arguments for the MCP server. mcp_server_env: Environment variables for the MCP server. + mcp_servers: Named MCP server configurations for multi-server routing. 
+ tool_name_to_server_name: Optional explicit mapping from public tool names + to MCP server names. reward_fn: Reward function to use for evaluation. max_steps: Maximum number of steps allowed in the environment. """ @@ -206,12 +359,191 @@ def __init__(self, task: dict[str, Any] | None = None, mcp_server_command: str | self.mcp_server_command = mcp_server_command self.mcp_server_args = mcp_server_args or [] self.mcp_server_env = mcp_server_env + self.mcp_servers = _normalize_mcp_servers(mcp_server_command, mcp_server_args, mcp_server_env, mcp_servers) + self.tool_name_to_server_name = dict(tool_name_to_server_name or {}) + self._resolved_tool_name_to_server_name: dict[str, str] = {} - # Initialize shared connection manager - with MCPEnvironment._manager_lock: - if MCPEnvironment._connection_manager is None and mcp_server_command is not None: - MCPEnvironment._connection_manager = MCPConnectionManager(mcp_server_command, mcp_server_args, mcp_server_env) - MCPEnvironment._connection_manager.start() + newly_created_server_names: list[str] = [] + try: + newly_created_server_names = self._ensure_connection_managers() + self._resolved_tool_name_to_server_name = self._build_tool_routing() + except Exception: + if newly_created_server_names: + self._rollback_connection_managers(newly_created_server_names) + raise + + @classmethod + def _sync_connection_manager_alias_locked(cls) -> None: + cls._connection_manager = next(iter(cls._connection_managers.values())) if len(cls._connection_managers) == 1 else None + + @classmethod + def _rollback_connection_managers(cls, server_names: list[str]) -> None: + managers_to_stop: list[MCPConnectionManager] = [] + with cls._manager_lock: + for server_name in server_names: + manager = cls._connection_managers.pop(server_name, None) + cls._server_specs.pop(server_name, None) + if manager is not None: + managers_to_stop.append(manager) + cls._sync_connection_manager_alias_locked() + + for manager in managers_to_stop: + try: + manager.stop() + except Exception: + pass + + def _ensure_connection_managers(self) -> list[str]: + newly_created_server_names: list[str] = [] + managers_to_stop: list[MCPConnectionManager] = [] + + try: + with MCPEnvironment._manager_lock: + for server_name, server_spec in self.mcp_servers.items(): + existing_spec = MCPEnvironment._server_specs.get(server_name) + if existing_spec is not None: + if existing_spec != server_spec: + raise ValueError(f"MCP server '{server_name}' is already initialized with a different configuration") + continue + + manager = MCPConnectionManager( + mcp_server_command=server_spec.command, + mcp_server_args=server_spec.args_list, + mcp_server_env=server_spec.env_dict, + ) + try: + manager.start() + except Exception: + managers_to_stop.append(manager) + raise + + MCPEnvironment._connection_managers[server_name] = manager + MCPEnvironment._server_specs[server_name] = server_spec + newly_created_server_names.append(server_name) + + MCPEnvironment._sync_connection_manager_alias_locked() + except Exception: + with MCPEnvironment._manager_lock: + for server_name in newly_created_server_names: + manager = MCPEnvironment._connection_managers.pop(server_name, None) + MCPEnvironment._server_specs.pop(server_name, None) + if manager is not None: + managers_to_stop.append(manager) + MCPEnvironment._sync_connection_manager_alias_locked() + + for manager in managers_to_stop: + try: + manager.stop() + except Exception: + pass + raise + + return newly_created_server_names + + def _build_tool_routing(self) -> dict[str, str]: + if 
not self.mcp_servers: + return {} + + discovered_tool_servers: dict[str, set[str]] = {} + for server_name in self.mcp_servers: + manager = MCPEnvironment._connection_managers.get(server_name) + if manager is None: + continue + for public_tool_name in getattr(manager, "tool_map", {}): + discovered_tool_servers.setdefault(public_tool_name, set()).add(server_name) + + resolved: dict[str, str] = {} + + for public_tool_name, candidate_servers in discovered_tool_servers.items(): + explicit_server_name = self.tool_name_to_server_name.get(public_tool_name) + if explicit_server_name is not None: + if explicit_server_name not in candidate_servers: + raise ValueError(f"Tool '{public_tool_name}' is not provided by mapped MCP server '{explicit_server_name}'") + resolved[public_tool_name] = explicit_server_name + elif len(candidate_servers) == 1: + resolved[public_tool_name] = next(iter(candidate_servers)) + else: + raise ValueError(f"Tool '{public_tool_name}' is provided by multiple MCP servers {sorted(candidate_servers)}. Supply 'tool_name_to_server_name' to disambiguate.") + + for public_tool_name, mapped_server_name in self.tool_name_to_server_name.items(): + if mapped_server_name not in self.mcp_servers: + raise ValueError(f"Tool mapping for '{public_tool_name}' references unknown MCP server '{mapped_server_name}'") + if public_tool_name not in discovered_tool_servers: + raise ValueError(f"Tool mapping for '{public_tool_name}' does not match any discovered tool on the configured MCP servers") + + return resolved + + @staticmethod + def _is_finish_tool_call(tool_call: Any) -> bool: + return _tool_call_name(tool_call) == "finish" + + def _extract_final_response(self, action: list[dict[str, Any]] | str) -> str: + if isinstance(action, str): + return action + + finish_action = None + for tool_call in action: + if self._is_finish_tool_call(tool_call): + finish_action = tool_call + break + + if finish_action is None: + return str(action) + + arguments, parse_error = _parse_tool_arguments(finish_action) + if parse_error is not None or arguments is None: + return str(action) + + response = arguments.get("response", "") + return response if isinstance(response, str) else str(response) + + def _execute_tool_calls_by_server(self, tool_calls: list[dict[str, Any]]) -> dict[str, str]: + tool_calls = _assign_missing_tool_call_ids(tool_calls) + tool_outputs: dict[str, str] = {} + grouped_calls: dict[str, list[dict[str, Any]]] = {} + + for idx, tool_call in enumerate(tool_calls): + tool_call_id = _tool_call_id(tool_call, idx) + tool_name = _tool_call_name(tool_call) + if tool_name is None: + tool_outputs[tool_call_id] = "Error: Tool call missing function.name" + continue + + server_name = self._resolved_tool_name_to_server_name.get(tool_name) + if server_name is None and len(self.mcp_servers) == 1: + # Preserve legacy single-server behavior where every tool call is + # forwarded to the sole configured MCP server. 
+ server_name = next(iter(self.mcp_servers)) + if server_name is None: + tool_outputs[tool_call_id] = f"Error: Tool {tool_name} not found" + continue + + grouped_calls.setdefault(server_name, []).append(tool_call) + + for server_name, grouped_tool_calls in grouped_calls.items(): + manager = MCPEnvironment._connection_managers.get(server_name) + if manager is None: + for idx, tool_call in enumerate(grouped_tool_calls): + tool_outputs[_tool_call_id(tool_call, idx)] = f"Error: MCP server {server_name} is not available" + continue + + try: + tool_outputs.update(manager.execute_tool_calls(grouped_tool_calls)) + except Exception as exc: + for idx, tool_call in enumerate(grouped_tool_calls): + tool_outputs[_tool_call_id(tool_call, idx)] = f"Error: MCP server {server_name} failed: {exc}" + + ordered_tool_outputs: dict[str, str] = {} + for idx, tool_call in enumerate(tool_calls): + tool_call_id = _tool_call_id(tool_call, idx) + if tool_call_id in tool_outputs: + ordered_tool_outputs[tool_call_id] = tool_outputs[tool_call_id] + + for tool_call_id, tool_output in tool_outputs.items(): + if tool_call_id not in ordered_tool_outputs: + ordered_tool_outputs[tool_call_id] = tool_output + + return ordered_tool_outputs def reset(self): """Reset the environment and return initial observations.""" @@ -239,32 +571,13 @@ def step(self, action: Any): # Check if action contains a "finish" tool call if isinstance(action, list) and action: for tool_call in action: - if tool_call.get("function", {}).get("name") == "finish": + if self._is_finish_tool_call(tool_call): done = True break if done: # Agent is done - evaluate the response - if isinstance(action, str): - llm_response = action - elif isinstance(action, list): - # Find the finish tool call - finish_action = None - for tool_call in action: - if tool_call.get("function", {}).get("name") == "finish": - finish_action = tool_call - break - if finish_action: - arguments = finish_action.get("function", {}).get("arguments", {}) - if isinstance(arguments, str): - arguments = json.loads(arguments) - - if isinstance(arguments, dict): - llm_response = arguments.get("response", "") - else: - llm_response = str(arguments) - else: - llm_response = str(action) + llm_response = self._extract_final_response(action) if self.reward_fn and self.task is not None: reward_output = self.reward_fn(task_info=self.task, action=llm_response) @@ -275,11 +588,8 @@ def step(self, action: Any): # Execute tool calls using the connection manager tool_calls = action try: - if MCPEnvironment._connection_manager is not None: - tool_outputs = MCPEnvironment._connection_manager.execute_tool_calls(tool_calls) - next_obs = {"tool_outputs": tool_outputs} - else: - next_obs = {"tool_outputs": {}} + tool_outputs = self._execute_tool_calls_by_server(tool_calls) if isinstance(tool_calls, list) else {} + next_obs = {"tool_outputs": tool_outputs} except Exception as e: print(f"Tool execution error: {e}") next_obs = {"tool_outputs": {}} @@ -293,17 +603,37 @@ def close(self): @staticmethod def cleanup_global_resources(): - """Clean up global connection manager.""" + """Clean up global connection managers.""" + managers_to_stop: list[MCPConnectionManager] = [] with MCPEnvironment._manager_lock: - if MCPEnvironment._connection_manager: - MCPEnvironment._connection_manager.stop() - MCPEnvironment._connection_manager = None + managers_to_stop = list(MCPEnvironment._connection_managers.values()) + MCPEnvironment._connection_managers = {} + MCPEnvironment._server_specs = {} + 
MCPEnvironment._sync_connection_manager_alias_locked() + + for manager in managers_to_stop: + try: + manager.stop() + except Exception: + pass @staticmethod def from_dict(env_args: dict[str, Any]) -> "MCPEnvironment": + env_args = dict(env_args) mcp_server_command = env_args.pop("mcp_server_command", None) mcp_server_args = env_args.pop("mcp_server_args", None) mcp_server_env = env_args.pop("mcp_server_env", None) + mcp_servers = env_args.pop("mcp_servers", None) + tool_name_to_server_name = env_args.pop("tool_name_to_server_name", None) reward_fn = env_args.pop("reward_fn", None) max_steps = env_args.pop("max_steps", 10) - return MCPEnvironment(task=env_args, mcp_server_command=mcp_server_command, mcp_server_args=mcp_server_args, mcp_server_env=mcp_server_env, max_steps=max_steps, reward_fn=reward_fn) + return MCPEnvironment( + task=env_args, + mcp_server_command=mcp_server_command, + mcp_server_args=mcp_server_args, + mcp_server_env=mcp_server_env, + mcp_servers=mcp_servers, + tool_name_to_server_name=tool_name_to_server_name, + max_steps=max_steps, + reward_fn=reward_fn, + ) diff --git a/rllm/experimental/buffer.py b/rllm/experimental/buffer.py new file mode 100644 index 000000000..288c72fcc --- /dev/null +++ b/rllm/experimental/buffer.py @@ -0,0 +1,273 @@ +"""TrajectoryGroupBuffer for async training. + +Accumulates episodes, processes into ready-to-train trajectory groups, +with optional NVMe offloading for memory management. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import pickle +import tempfile +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from tqdm import tqdm + +from rllm.agents.agent import Episode, TrajectoryGroup +from rllm.experimental.common import ( + AlgorithmConfig, + CompactFilteringConfig, + RejectionSamplingConfig, + TransformConfig, + collect_reward_and_advantage_from_trajectory_groups, +) +from rllm.experimental.common.transform import transform_episodes_to_trajectory_groups +from rllm.experimental.metrics import MetricsAggregator +from rllm.experimental.sync_coordinator import SyncCoordinator +from rllm.workflows.workflow import TerminationReason + +logger = logging.getLogger(__name__) + + +@dataclass +class TaskBatch: + """All trajectory groups produced from one task's episodes, plus stripped episodes for UI logging.""" + + groups: list[TrajectoryGroup] + episodes: list[Episode] = field(default_factory=list) + + +class TrajectoryGroupBuffer: + """Accumulates episodes, processes into trajectory groups, yields to training. + + When all rollouts for a task arrive: + 1. Record episode-level metrics to aggregator (before any filtering) + 2. Transform episodes -> trajectory groups + 3. Compact filtering + drop groups with < min_trajs_per_group + 4. Compute advantages + 5. If rejection sampling enabled: drop groups with all-zero advantage + 6. Queue the task batch for training + + Filtered groups are reported directly to the coordinator (which tracks + throttle slots and filter counts). Only non-empty task batches are queued. + All metrics flow through the shared MetricsAggregator. + + Optionally offloads pending episodes and/or queued task batches to + disk to reduce memory pressure (disabled by default). 
+ """ + + def __init__( + self, + group_size: int, + coordinator: SyncCoordinator, + aggregator: MetricsAggregator, + algorithm_config: AlgorithmConfig, + transform_config: TransformConfig, + cf_config: CompactFilteringConfig, + rs_config: RejectionSamplingConfig, + episode_offload_dir: str | None = None, + trajectory_group_offload_dir: str | None = None, + pbar: tqdm | None = None, + ): + self._group_size = group_size + self._coordinator = coordinator + self._aggregator = aggregator + self._algorithm_config = algorithm_config + self._transform_config = transform_config + self._cf_config = cf_config + self._rs_config = rs_config + self._pbar = pbar + + # Episode offloading: pending episodes serialized to disk + self._episode_offload_dir = episode_offload_dir + if episode_offload_dir: + os.makedirs(episode_offload_dir, exist_ok=True) + self._pending: dict[str, list[Episode | str]] = {} # str = offloaded file path + + # Trajectory group offloading: queued task batches serialized to disk + self._tg_offload_dir = trajectory_group_offload_dir + if trajectory_group_offload_dir: + os.makedirs(trajectory_group_offload_dir, exist_ok=True) + self._queue: asyncio.Queue[TaskBatch | str | None] = asyncio.Queue() + + async def _offload_episode(self, task_id: str, episode: Episode) -> str: + """Serialize episode to disk, return file path.""" + idx = len(self._pending.get(task_id, [])) + path = os.path.join(self._episode_offload_dir, f"{task_id}_{idx}.pkl") + await asyncio.to_thread(self._pickle_dump, path, episode) + return path + + async def _load_pending_episodes(self, task_id: str) -> list[Episode]: + """Load all pending episodes for a task, deserializing offloaded ones.""" + episodes = [] + for item in self._pending.pop(task_id, []): + if isinstance(item, str): + ep = await asyncio.to_thread(self._pickle_load, item) + episodes.append(ep) + else: + episodes.append(item) + return episodes + + async def _offload_task_batch(self, batch: TaskBatch) -> str: + """Serialize task batch to disk, return file path.""" + fd, path = tempfile.mkstemp(dir=self._tg_offload_dir, suffix=".pkl") + os.close(fd) + await asyncio.to_thread(self._pickle_dump, path, batch) + return path + + async def _load_task_batch(self, item: TaskBatch | str) -> TaskBatch: + """Load task batch, deserializing if offloaded.""" + if isinstance(item, str): + return await asyncio.to_thread(self._pickle_load, item) + return item + + @staticmethod + def _pickle_dump(path: str, obj) -> None: + with open(path, "wb") as f: + pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) + + @staticmethod + def _pickle_load(path: str): + with open(path, "rb") as f: + obj = pickle.load(f) + os.remove(path) + return obj + + async def add_episode(self, task_id: str, episode: Episode) -> bool: + """Add episode. When group completes, process and queue task batch.""" + # Offload episode to disk if enabled + if self._episode_offload_dir: + path = await self._offload_episode(task_id, episode) + self._pending.setdefault(task_id, []).append(path) + else: + self._pending.setdefault(task_id, []).append(episode) + + if len(self._pending[task_id]) < self._group_size: + return False + + # Group complete — tick progress bar + if self._pbar is not None: + self._pbar.update(1) + + # Load all episodes + if self._episode_offload_dir: + episodes = await self._load_pending_episodes(task_id) + else: + episodes = self._pending.pop(task_id, []) + + weight_version = self._min_weight_version(episodes) + + # 1. 
Record episode-level metrics (includes filtered tasks) + self._record_episode_metrics(episodes) + + # 2. Transform episodes -> trajectory groups + traj_groups, transform_metrics = transform_episodes_to_trajectory_groups( + episodes, + self._transform_config, + self._cf_config, + ) + self._aggregator.record_dict(transform_metrics) + + # 3. Drop groups with too few trajectories + before_min_traj = len(traj_groups) + traj_groups = [g for g in traj_groups if len(g.trajectories) >= self._rs_config.min_trajs_per_group] + self._aggregator.record("groups/dropped_min_trajs", before_min_traj - len(traj_groups)) + + if not traj_groups: + self._coordinator.on_group_filtered() + return True + + # 4. Compute advantages + adv_metrics = collect_reward_and_advantage_from_trajectory_groups( + traj_groups, + self._algorithm_config, + ) + self._aggregator.record_dict(adv_metrics) + + # 5. Rejection sampling: drop groups with all-zero advantage + filtered_zero_adv = 0 + if self._rs_config.filter_uniform_groups: + before_adv = len(traj_groups) + traj_groups = [g for g in traj_groups if any(abs(step.advantage) > 1e-8 for traj in g.trajectories for step in traj.steps if step.advantage is not None)] + filtered_zero_adv = before_adv - len(traj_groups) + self._aggregator.record("groups/dropped_zero_adv", filtered_zero_adv) + + if not traj_groups: + self._coordinator.on_group_filtered() + return True + + # 6. Set weight version and queue + for g in traj_groups: + g.weight_version = weight_version + + batch = TaskBatch(groups=traj_groups, episodes=episodes) + if self._tg_offload_dir: + await self._queue.put(await self._offload_task_batch(batch)) + else: + await self._queue.put(batch) + + return True + + async def get(self) -> TaskBatch | None: + """Get next task batch. Returns None when generation is done and buffer is drained.""" + item = await self._queue.get() + if item is None: + return None + return await self._load_task_batch(item) + + def mark_generation_complete(self) -> None: + """Signal that generation is finished. 
Discards incomplete groups and enqueues a sentinel.""" + for task_id in list(self._pending.keys()): + items = self._pending.pop(task_id, []) + for item in items: + if isinstance(item, str): + try: + os.remove(item) + except OSError: + pass + self._coordinator.on_group_filtered() + self._queue.put_nowait(None) + + def stats(self) -> dict: + return { + "async/buffer_qsize": self._queue.qsize(), + "async/buffer_pending": len(self._pending), + } + + def _record_episode_metrics(self, episodes: list[Episode]) -> None: + """Record episode-level metrics to aggregator (all episodes, including filtered).""" + for ep in episodes: + reason = ep.termination_reason or TerminationReason.UNKNOWN + for r in TerminationReason: + self._aggregator.record( + f"episode/termination_reason/{r.value}", + 1.0 if reason == r else 0.0, + ) + for k, v in ep.metrics.items(): + try: + self._aggregator.record(f"episode/{k}", float(v)) + except (TypeError, ValueError): + continue + + # Episode-level totals across all trajectories + total_turns = sum(len(traj.steps) for traj in ep.trajectories) + total_prompt_tokens = sum(len(s.prompt_ids) for traj in ep.trajectories for s in traj.steps) + total_response_tokens = sum(len(s.response_ids) for traj in ep.trajectories for s in traj.steps) + self._aggregator.record("episode/num_turns", total_turns) + self._aggregator.record("episode/prompt_tokens", total_prompt_tokens) + self._aggregator.record("episode/response_tokens", total_response_tokens) + self._aggregator.record("episode/correct", 1.0 if ep.is_correct else 0.0) + + @staticmethod + def _min_weight_version(episodes: list[Episode]) -> int: + min_v = float("inf") + for ep in episodes: + for traj in ep.trajectories: + for step in traj.steps: + if step.weight_version is not None: + min_v = min(min_v, step.weight_version) + return int(min_v) if min_v != float("inf") else 0 diff --git a/rllm/experimental/cli/eval.py b/rllm/experimental/cli/eval.py index 884738c31..40ff766ff 100644 --- a/rllm/experimental/cli/eval.py +++ b/rllm/experimental/cli/eval.py @@ -30,7 +30,19 @@ def _suggest_benchmarks(name: str, catalog_names: list[str], max_suggestions: in return get_close_matches(name, catalog_names, n=max_suggestions, cutoff=0.5) -def _run_eval(benchmark: str, agent_name: str, evaluator_name: str | None, base_url: str, model: str, split: str, concurrency: int, max_examples: int | None, output_path: str | None, agent_metadata: dict | None = None, enable_ui: bool = False): +def _run_eval( + benchmark: str, + agent_name: str, + evaluator_name: str | None, + base_url: str, + model: str, + split: str, + concurrency: int, + max_examples: int | None, + output_path: str | None, + agent_metadata: dict | None = None, + enable_ui: bool = False, +): """Core eval logic, extracted for clean proxy lifecycle management.""" from rllm.data import DatasetRegistry from rllm.experimental.eval.agent_loader import load_agent @@ -238,11 +250,37 @@ def on_episode_complete(episode): @click.option("--concurrency", default=64, type=int, help="Number of parallel requests.") @click.option("--max-examples", default=None, type=int, help="Limit number of examples (for dev/testing).") @click.option("--output", "output_path", default=None, help="Output file path for results JSON.") -@click.option("--search-backend", "search_backend", default=None, type=click.Choice(["serper", "brave"], case_sensitive=False), help="Search backend for the search agent (auto-detected from API keys if omitted).") -@click.option("--sandbox-backend", "sandbox_backend", default=None, 
type=click.Choice(["docker", "local", "modal"], case_sensitive=False), help="Sandbox backend for sandboxed agents (auto-detected from agent if omitted).") +@click.option( + "--search-backend", + "search_backend", + default=None, + type=click.Choice(["serper", "brave"], case_sensitive=False), + help="Search backend for the search agent (auto-detected from API keys if omitted).", +) +@click.option( + "--sandbox-backend", + "sandbox_backend", + default=None, + type=click.Choice(["docker", "local", "modal"], case_sensitive=False), + help="Sandbox backend for sandboxed agents (auto-detected from agent if omitted).", +) @click.option("--sandbox-concurrency", "sandbox_concurrency", default=None, type=int, help="Override max concurrent sandboxes (default: agent's max_concurrent).") @click.option("--ui/--no-ui", "enable_ui", default=None, help="Enable/disable live UI logging. Default: auto-enabled when logged in (see 'rllm login').") -def eval_cmd(benchmark: str, agent_name: str | None, evaluator_name: str | None, base_url: str | None, model: str | None, split: str | None, concurrency: int, max_examples: int | None, output_path: str | None, search_backend: str | None, sandbox_backend: str | None, sandbox_concurrency: int | None, enable_ui: bool | None): +def eval_cmd( + benchmark: str, + agent_name: str | None, + evaluator_name: str | None, + base_url: str | None, + model: str | None, + split: str | None, + concurrency: int, + max_examples: int | None, + output_path: str | None, + search_backend: str | None, + sandbox_backend: str | None, + sandbox_concurrency: int | None, + enable_ui: bool | None, +): """Evaluate a model on a benchmark dataset.""" # Auto-detect UI logging: enable if user is logged in (has ui_api_key or RLLM_API_KEY) _ui_explicit = enable_ui is not None diff --git a/rllm/experimental/cli/init.py b/rllm/experimental/cli/init.py index e6e09de79..22afae7f0 100644 --- a/rllm/experimental/cli/init.py +++ b/rllm/experimental/cli/init.py @@ -173,7 +173,9 @@ def init_cmd(project_name: str | None, template: str | None, evaluator: bool, ou console.print() console.print( Panel( - f"[bold green]Project created:[/bold green] {project_dir}\n\n[bold]Template:[/bold] {tpl_info['label']}\n[bold]Agent:[/bold] {module_name}.agent:{agent_instance}\n" + (f"[bold]Evaluator:[/bold] {module_name}.evaluator:{evaluator_class}\n" if evaluator else ""), + f"[bold green]Project created:[/bold green] {project_dir}\n\n" + f"[bold]Template:[/bold] {tpl_info['label']}\n" + f"[bold]Agent:[/bold] {module_name}.agent:{agent_instance}\n" + (f"[bold]Evaluator:[/bold] {module_name}.evaluator:{evaluator_class}\n" if evaluator else ""), title="[bold cyan]rllm init[/bold cyan]", border_style="cyan", ) diff --git a/rllm/experimental/common/__init__.py b/rllm/experimental/common/__init__.py index ed169b372..b75f43c90 100644 --- a/rllm/experimental/common/__init__.py +++ b/rllm/experimental/common/__init__.py @@ -7,8 +7,10 @@ from rllm.experimental.common.advantage import collect_reward_and_advantage_from_trajectory_groups from rllm.experimental.common.config import ( AlgorithmConfig, + AsyncTrainingConfig, CompactFilteringConfig, RejectionSamplingConfig, + RolloutCorrectionConfig, TransformConfig, rLLMAdvantageEstimator, ) @@ -24,8 +26,10 @@ __all__ = [ # Config + "AsyncTrainingConfig", "CompactFilteringConfig", "RejectionSamplingConfig", + "RolloutCorrectionConfig", "TransformConfig", "AlgorithmConfig", # Transform pipeline diff --git a/rllm/experimental/common/advantage.py b/rllm/experimental/common/advantage.py index 
be808fe77..0dd213077 100644 --- a/rllm/experimental/common/advantage.py +++ b/rllm/experimental/common/advantage.py @@ -49,7 +49,10 @@ def get_rllm_adv_estimator(name: str | rLLMAdvantageEstimator) -> Callable: @register_rllm_adv_estimator(rLLMAdvantageEstimator.GRPO) def calculate_grpo_advantages(rewards: list[np.ndarray], norm_adv_by_std_in_grpo=True, episilon=1e-6, **kwargs) -> tuple[list[np.ndarray], list[np.ndarray]]: - advantages_by_group, returns_by_group = zip(*[calculate_grpo_advantages_per_group(group_rewards, norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, episilon=episilon) for group_rewards in rewards], strict=True) + advantages_by_group, returns_by_group = zip( + *[calculate_grpo_advantages_per_group(group_rewards, norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, episilon=episilon) for group_rewards in rewards], + strict=True, + ) return advantages_by_group, returns_by_group diff --git a/rllm/experimental/common/config.py b/rllm/experimental/common/config.py index 5399da600..dee2f2d13 100644 --- a/rllm/experimental/common/config.py +++ b/rllm/experimental/common/config.py @@ -8,6 +8,38 @@ from rllm.workflows.workflow import TerminationReason +@dataclass +class AsyncTrainingConfig: + """Controls the async training behavior spectrum. + + When `enable` is False, the trainer uses the current synchronous pipeline. + When `enable` is True, the trainer runs concurrent generation + training + with group-level streaming and dispatch-time throttle. + + Behavior spectrum: + - staleness_threshold=0, trigger_parameter_sync_step=1: On-policy + - staleness_threshold=0, trigger_parameter_sync_step=K: Stream off-policy + - staleness_threshold>0, partial_rollout=False: Async with staleness + - staleness_threshold>0, partial_rollout=True: Async with partial rollout + """ + + enable: bool = False + mini_batch_size: int = 1 # episode groups per optimizer step + fwd_bwd_group_size: int | None = None # task batches per forward-backward pass (default: mini_batch_size) + staleness_threshold: float = 0.0 # 0.0 = on-policy. Controls dispatch throttle quota. + trigger_parameter_sync_step: int = 1 # optimizer steps between weight sync + version bump + partial_rollout: bool = True # enable turn-level gating during weight sync + episode_offload_dir: str | None = None # NVMe offload dir for pending episodes (None = disabled) + trajectory_group_offload_dir: str | None = None # NVMe offload dir for queued task batches (None = disabled) + + def __post_init__(self): + if self.fwd_bwd_group_size is None: + self.fwd_bwd_group_size = self.mini_batch_size + if self.enable: + assert self.fwd_bwd_group_size >= 1 + assert self.mini_batch_size % self.fwd_bwd_group_size == 0, f"mini_batch_size ({self.mini_batch_size}) must be divisible by fwd_bwd_group_size ({self.fwd_bwd_group_size})" + + @dataclass class CompactFilteringConfig: """Configuration for compact filtering of episodes based on termination reasons. @@ -93,6 +125,30 @@ class RejectionSamplingConfig: # For "episode" mode (verl compatibility): minimum number of tasks with partial solves before proceeding min_partial_solve_tasks: int = 1 + # Filter out episode groups where all rollouts have the same is_correct (no gradient signal). + # Applied at the accumulator level in async training, before groups enter the buffer. + filter_uniform_groups: bool = False + + +@dataclass +class RolloutCorrectionConfig: + """Configuration for rollout correction (TIS, proximal forward passes). + + Backend-agnostic — each backend interprets these according to its infrastructure. 
+ + Attributes: + tis_mode: None = disabled (string loss names, current behavior). + "token" or "sequence" = enable custom callable loss with TIS at that level. + bypass_mode: When True, use rollout (inference) logprobs as π_old — no + proximal forward pass. When False, compute π_old via policy.forward() + (3-policy / decoupled PPO). + tis_cap: Upper clamp on the TIS importance weight. + """ + + tis_mode: str | None = None + bypass_mode: bool = True + tis_cap: float = 5.0 + class rLLMAdvantageEstimator(str, Enum): """ @@ -143,6 +199,14 @@ class AlgorithmConfig: lr_schedule: Literal["linear", "cosine", "constant"] = "constant" warmup_steps_ratio: float = 0.0 + # Custom loss / rollout correction fields (used by Fireworks backend with cookbook losses) + kl_beta: float = 0.0 + eps_clip: float = 0.2 + eps_clip_high: float | None = None + loss_agg_mode: Literal["token_mean", "seq_mean_token_sum", "seq_mean_token_mean", None] = None + rollout_correction: RolloutCorrectionConfig = field(default_factory=RolloutCorrectionConfig) + router_replay: bool = False + @classmethod def from_config(cls, config: DictConfig) -> "AlgorithmConfig": """Create an AlgorithmConfig from a dictionary configuration. @@ -152,15 +216,27 @@ def from_config(cls, config: DictConfig) -> "AlgorithmConfig": Returns: AlgorithmConfig: The AlgorithmConfig built from the configuration. """ + rc_section = config.rllm.algorithm.get("rollout_correction", {}) + rollout_correction = RolloutCorrectionConfig( + tis_mode=rc_section.get("tis_mode", None), + bypass_mode=rc_section.get("bypass_mode", True), + tis_cap=rc_section.get("tis_cap", 5.0), + ) return cls( estimator=rLLMAdvantageEstimator(config.algorithm.adv_estimator), stepwise_advantage_mode=config.rllm.stepwise_advantage.mode, - norm_adv_by_std_in_grpo=config.rllm.stepwise_advantage.get("norm_adv_by_std_in_grpo", True), + norm_adv_by_std_in_grpo=config.rllm.algorithm.get("norm_adv_by_std_in_grpo", True), use_rllm=config.rllm.stepwise_advantage.get("use_rllm", False), use_precomputed_advantage=config.rllm.algorithm.get("use_precomputed_advantage", False), loss_fn=config.rllm.algorithm.get("loss_fn", None), lr_schedule=config.rllm.algorithm.get("lr_schedule", "constant"), warmup_steps_ratio=config.rllm.algorithm.get("warmup_steps_ratio", 0.0), + kl_beta=config.rllm.algorithm.get("kl_beta", 0.0), + eps_clip=config.rllm.algorithm.get("eps_clip", 0.2), + eps_clip_high=config.rllm.algorithm.get("eps_clip_high", None), + loss_agg_mode=config.rllm.algorithm.get("loss_agg_mode", None), + rollout_correction=rollout_correction, + router_replay=config.rllm.algorithm.get("router_replay", False), ) def __post_init__(self): diff --git a/rllm/experimental/common/rejection_sampling.py b/rllm/experimental/common/rejection_sampling.py index 0b8d9e8cb..ec033b8ea 100644 --- a/rllm/experimental/common/rejection_sampling.py +++ b/rllm/experimental/common/rejection_sampling.py @@ -157,7 +157,12 @@ def filter_episodes( return filtered_episodes -def apply_rejection_sampling_and_filtering(episodes: list[Episode], groups: list[TrajectoryGroup], config: RejectionSamplingConfig, state: RejectionSamplingState) -> tuple[list[TrajectoryGroup], list[Episode], dict]: +def apply_rejection_sampling_and_filtering( + episodes: list[Episode], + groups: list[TrajectoryGroup], + config: RejectionSamplingConfig, + state: RejectionSamplingState, +) -> tuple[list[TrajectoryGroup], list[Episode], dict]: """ Apply rejection sampling to trajectory groups and episodes. 
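The `RolloutCorrectionConfig` above deliberately leaves the weight computation to each backend. As a review aid, here is a minimal sketch of one plausible interpretation of these fields, assuming per-token logprob arrays; the function name and shapes are illustrative, not the backend API:

```python
import numpy as np


def tis_weights(
    logp_train: np.ndarray,  # per-token logprobs under the training policy (pi)
    logp_rollout: np.ndarray,  # per-token logprobs recorded at rollout (pi_old when bypass_mode=True)
    tis_mode: str | None,  # None = disabled, "token" or "sequence"
    tis_cap: float = 5.0,  # upper clamp on the importance weight
) -> np.ndarray:
    """Illustrative truncated importance sampling (TIS) weights; not the backend code."""
    if tis_mode is None:
        return np.ones_like(logp_train)  # correction disabled
    log_ratio = logp_train - logp_rollout
    if tis_mode == "sequence":
        # One weight for the whole sequence: exp(sum of per-token log ratios).
        log_ratio = np.full_like(log_ratio, log_ratio.sum())
    return np.minimum(np.exp(log_ratio), tis_cap)  # clamp at tis_cap
```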
diff --git a/rllm/experimental/common/transform.py b/rllm/experimental/common/transform.py index 723b491af..f13e70f29 100644 --- a/rllm/experimental/common/transform.py +++ b/rllm/experimental/common/transform.py @@ -87,7 +87,9 @@ def _validate_and_propagate_rewards( for group in groups: if config.broadcast: num_missing_rewards = sum(traj.reward is None for traj in group.trajectories) - assert num_missing_rewards == 0 or num_missing_rewards == len(group.trajectories), "Trajectories in a group must either ALL NOT have a trajectory-level reward or ALL have a trajectory-level reward" + assert num_missing_rewards == 0 or num_missing_rewards == len(group.trajectories), ( + "Trajectories in a group must either ALL NOT have a trajectory-level reward or ALL have a trajectory-level reward" + ) if num_missing_rewards > 0: for traj in group.trajectories: assert len(traj.steps) > 0, "Trajectory within a group must have at least one step" @@ -150,7 +152,7 @@ def _build_trajectory_groups(episodes: list[Episode], compact_filtering_config: """ -def _get_transform_metrics(episodes: list[Episode], groups: list[TrajectoryGroup], prefix: str = "grouping") -> dict: +def _get_transform_metrics(episodes: list[Episode], groups: list[TrajectoryGroup], prefix: str = "groups") -> dict: """ Get metrics for the transformation pipeline. """ @@ -180,12 +182,11 @@ def _default_traj_grouping_hook(episodes: list[Episode], transform_config: Trans """ trajectory_groups = _build_trajectory_groups(episodes, compact_filtering_config) # part 1 reward_warnings = _validate_and_propagate_rewards(trajectory_groups, transform_config) # part 2 - - for warning in reward_warnings[:LOG_N_WARNINGS]: - logger.warning(warning) - - if len(reward_warnings) > LOG_N_WARNINGS: - logger.warning(f"Skipping {len(reward_warnings) - LOG_N_WARNINGS} more similar warnings with reward validation") + if reward_warnings: + for warning in reward_warnings[:LOG_N_WARNINGS]: + logger.debug(warning) + if len(reward_warnings) > LOG_N_WARNINGS: + logger.debug(f"Skipping {len(reward_warnings) - LOG_N_WARNINGS} more similar reward validation warnings") return trajectory_groups @@ -194,7 +195,7 @@ def transform_episodes_to_trajectory_groups( episodes: list[Episode], transform_config: TransformConfig, compact_filtering_config: CompactFilteringConfig | None = None, - metrics_prefix: str = "grouping", + metrics_prefix: str = "groups", traj_grouping_hook: Callable[[list[Episode], TransformConfig, CompactFilteringConfig | None], list[TrajectoryGroup]] = _default_traj_grouping_hook, ) -> tuple[list[TrajectoryGroup], dict]: """ @@ -232,12 +233,11 @@ def transform_episodes_to_trajectory_groups( # Step 1: Name imputation rename_warnings = _impute_trajectory_names(episodes, transform_config) - - for warning in rename_warnings[:LOG_N_WARNINGS]: - logger.warning(warning) - - if len(rename_warnings) > LOG_N_WARNINGS: - logger.warning(f"Skipping {len(rename_warnings) - LOG_N_WARNINGS} more similar warnings with trajectory names") + if rename_warnings: + for warning in rename_warnings[:LOG_N_WARNINGS]: + logger.debug(warning) + if len(rename_warnings) > LOG_N_WARNINGS: + logger.debug(f"Skipping {len(rename_warnings) - LOG_N_WARNINGS} more similar trajectory name warnings") # Step 2: Invoke the trajectory grouping hook groups = traj_grouping_hook(episodes, transform_config, compact_filtering_config) diff --git a/rllm/experimental/common/visualization.py b/rllm/experimental/common/visualization.py index cee8e9d01..d7d45d242 100644 --- a/rllm/experimental/common/visualization.py +++ 
b/rllm/experimental/common/visualization.py @@ -24,6 +24,40 @@ class VisualizationConfig: failure_style: dict[str, Any] = field(default_factory=lambda: {"fg": "red", "bold": True}) +def print_metrics_table(metrics: dict, step: int, title: str | None = None) -> None: + """Print metrics as a formatted Rich table with fallback to plain text.""" + try: + from rich.console import Console + from rich.table import Table + + table = Table(title=title or f"Step {step}", show_header=True, header_style="bold magenta") + table.add_column("Metric", style="cyan", no_wrap=False) + table.add_column("Value", justify="right", style="green") + + for key, value in sorted(metrics.items()): + if isinstance(value, float): + value_str = f"{value:.6f}" if abs(value) < 1000 else f"{value:.2f}" + elif isinstance(value, int): + value_str = str(value) + else: + value_str = str(value) + table.add_row(key, value_str) + + Console().print(table) + except ImportError: + print(f"\n{title or f'Step {step}'}") + print("=" * 60) + for key, value in sorted(metrics.items()): + if isinstance(value, float): + value_str = f"{value:.6f}" if abs(value) < 1000 else f"{value:.2f}" + elif isinstance(value, int): + value_str = str(value) + else: + value_str = str(value) + print(f"{key:40s} {value_str:>15s}") + print("=" * 60) + + def colorful_print(string: str, *args, **kwargs) -> None: end = kwargs.pop("end", "\n") print(click.style(string, *args, **kwargs), end=end, flush=True) diff --git a/rllm/experimental/config/rllm/base.yaml b/rllm/experimental/config/rllm/base.yaml index 152310f9b..9aaee0296 100644 --- a/rllm/experimental/config/rllm/base.yaml +++ b/rllm/experimental/config/rllm/base.yaml @@ -56,8 +56,16 @@ algorithm: # When true, always use pre-computed step.advantage from the workflow (e.g. distillation) # and skip advantage computation (GRPO/REINFORCE). Missing advantages default to 0. use_precomputed_advantage: false - # for tinker backend only (avaiable options: importance_sampling, ppo, cispo, dro, cross_entropy) - loss_fn: null + loss_fn: null # [null, importance_sampling, ppo, cispo, dro, cross_entropy] + loss_agg_mode: null # [null, token_mean, seq_mean_token_sum, seq_mean_token_mean] + kl_beta: 0.0 # KL penalty coefficient + eps_clip: 0.2 # PPO clip epsilon + eps_clip_high: null # Asymmetric upper clip bound (null = symmetric) + router_replay: false # Router Replay (R3): replay MoE expert routing from inference during training + rollout_correction: + bypass_mode: true # true = use rollout logprobs as pi_old (2-policy), false = proximal forward (3-policy) + tis_mode: null # null = disabled, "token" or "sequence" = TIS importance sampling level + tis_cap: 5.0 # Upper clamp on TIS importance weight # Stepwise advantage # TODO(listar2000): deprecate the `per_step` mode and refactor this config. 
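These YAML keys feed `AlgorithmConfig.from_config` shown in `config.py` above. A minimal wiring sketch, assuming `"grpo"` and `"broadcast"` are valid values for `adv_estimator` and `stepwise_advantage.mode` in this tree and that the elided `__post_init__` accepts them:

```python
from omegaconf import OmegaConf

from rllm.experimental.common import AlgorithmConfig

# Only the keys from_config() reads are set; everything else falls back to defaults.
cfg = OmegaConf.create(
    {
        "algorithm": {"adv_estimator": "grpo"},  # assumed enum value
        "rllm": {
            "stepwise_advantage": {"mode": "broadcast"},  # assumed mode name
            "algorithm": {
                "loss_fn": "ppo",
                "kl_beta": 0.0,
                "eps_clip": 0.2,
                "rollout_correction": {"tis_mode": "token", "bypass_mode": True, "tis_cap": 5.0},
            },
        },
    }
)

algo = AlgorithmConfig.from_config(cfg)
assert algo.rollout_correction.tis_mode == "token"
assert algo.rollout_correction.tis_cap == 5.0
```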
@@ -93,6 +101,7 @@ rejection_sample: multiplier: 1 min_partial_solve_tasks: 1 min_trajs_per_group: 2 + filter_uniform_groups: false # SDK Configuration # DEPRECATED: This section is only kept for backward compatibility with @@ -153,6 +162,17 @@ gateway: host: null # Auto-detects routable IP; set explicitly to override db_path: null # Defaults to temp file +# Async Training Configuration +async_training: + enable: false + mini_batch_size: 1 + fwd_bwd_group_size: null # task batches per forward-backward pass (default: mini_batch_size) + staleness_threshold: 0.0 + trigger_parameter_sync_step: 1 + partial_rollout: true + episode_offload_dir: null # NVMe offload dir for pending episodes (null = disabled) + trajectory_group_offload_dir: null # NVMe offload dir for queued task batches (null = disabled) + # Episode Logging Configuration episode_logging: log_episodes: false diff --git a/rllm/experimental/engine/agent_flow_engine.py b/rllm/experimental/engine/agent_flow_engine.py index 60c3ccdd6..1b582dc31 100644 --- a/rllm/experimental/engine/agent_flow_engine.py +++ b/rllm/experimental/engine/agent_flow_engine.py @@ -74,6 +74,7 @@ def __init__( self.raise_on_error = raise_on_error self.episode_logger = episode_logger self.executor = ThreadPoolExecutor(max_workers=n_parallel_tasks) + self._semaphore = asyncio.Semaphore(n_parallel_tasks) # Raise the file descriptor limit to avoid "Too many open files" when # running many parallel agent flows with individual HTTP clients. @@ -117,7 +118,7 @@ async def execute_tasks( for idx, (task, task_id) in enumerate(zip(tasks, task_ids, strict=True)): rollout_idx = task_id_counter[task_id] task_id_counter[task_id] += 1 - futures.append(self._process_task_with_retry(task, task_id, rollout_idx, idx, is_validation=is_validation)) + futures.append(self.process_task_with_retry(task, task_id, rollout_idx, idx, is_validation=is_validation)) with tqdm(total=len(tasks), desc="Generating trajectories") as pbar: for future in asyncio.as_completed(futures): @@ -141,7 +142,7 @@ async def execute_tasks( return ordered_results - async def _process_task_with_retry( + async def process_task_with_retry( self, task: dict, task_id: str, @@ -150,64 +151,61 @@ async def _process_task_with_retry( is_validation: bool = False, ) -> tuple[str, int, int, Episode]: """Process a single task with retry logic.""" - for retry_attempt in range(1, self.retry_limit + 1): - uid = f"{task_id}:{rollout_idx}" - try: - episode = await self._run_single(task, uid, is_validation=is_validation) - episode.id = uid - episode.task = task - - # Display rewards - reward_strs = [] - for traj in episode.trajectories: - reward = "N/A" - if traj.reward is not None: - reward = f"{traj.reward:.1f}" - elif len(traj.steps) > 0: - reward = f"{traj.steps[-1].reward:.1f}" - reward_strs.append(f"{traj.name}: {reward}") - colorful_print( - f"[{uid}] Rollout completed. 
Rewards: [{', '.join(reward_strs)}], Termination: {episode.termination_reason}", - fg="green" if episode.is_correct else "yellow", - ) - - return task_id, rollout_idx, result_idx, episode - - except Exception as e: - logger.error("[%s] Attempt %d/%d failed: %s", uid, retry_attempt, self.retry_limit, e) - if retry_attempt < self.retry_limit: - continue - - if self.raise_on_error: - raise - - # Return an error episode - return ( - task_id, - rollout_idx, - result_idx, - Episode( - id=uid, - task=task, - is_correct=False, - termination_reason=TerminationReason.ERROR, - metadata={"error": {"message": str(e)}}, - ), - ) - - # Should not reach here, but satisfy type checker - raise RuntimeError(f"[{uid}] Exhausted all retries") + async with self._semaphore: + for retry_attempt in range(1, self.retry_limit + 1): + uid = f"{task_id}:{rollout_idx}" + try: + episode = await self._run_single(task, uid, is_validation=is_validation) + episode.id = uid + episode.task = task + + # Display rewards + reward_strs = [] + for traj in episode.trajectories: + reward = "N/A" + if traj.reward is not None: + reward = f"{traj.reward:.1f}" + elif len(traj.steps) > 0: + reward = f"{traj.steps[-1].reward:.1f}" + reward_strs.append(f"{traj.name}: {reward}") + colorful_print( + f"[{uid}] Rollout completed. Rewards: [{', '.join(reward_strs)}], Termination: {episode.termination_reason}", + fg="green" if episode.is_correct else "yellow", + ) + + return task_id, rollout_idx, result_idx, episode + + except Exception as e: + logger.error("[%s] Attempt %d/%d failed: %s", uid, retry_attempt, self.retry_limit, e) + if retry_attempt < self.retry_limit: + continue + + if self.raise_on_error: + raise + + # Return an error episode + return ( + task_id, + rollout_idx, + result_idx, + Episode( + id=uid, + task=task, + is_correct=False, + termination_reason=TerminationReason.ERROR, + metadata={"error": {"message": str(e)}}, + ), + ) + + # Should not reach here, but satisfy type checker + raise RuntimeError(f"[{uid}] Exhausted all retries") async def _run_single(self, task: dict, uid: str, is_validation: bool = False) -> Episode: """Run a single AgentFlow task: execute, evaluate, enrich.""" loop = asyncio.get_event_loop() - # 1. Create gateway session (run in executor to avoid blocking event loop) - await loop.run_in_executor( - self.executor, - self.gateway.create_session, - uid, - ) + # 1. Create gateway session + await self.gateway.acreate_session(uid, is_validation=is_validation) session_url = self.gateway.get_session_url(uid) # 2. Build config @@ -239,12 +237,8 @@ async def _run_single(self, task: dict, uid: str, is_validation: bool = False) - traj.reward = eval_output.reward episode.is_correct = eval_output.is_correct - # 5. Retrieve traces from gateway (run in executor to avoid blocking event loop) - traces = await loop.run_in_executor( - self.executor, - self.gateway.get_traces, - uid, - ) + # 5. Retrieve traces from gateway + traces = await self.gateway.aget_traces(uid) # 6. 
Enrich episode with token data enriched = self._enrich_episode(episode, traces, uid, task) diff --git a/rllm/experimental/engine/gateway_manager.py b/rllm/experimental/engine/gateway_manager.py index 44e75b474..abcc24191 100644 --- a/rllm/experimental/engine/gateway_manager.py +++ b/rllm/experimental/engine/gateway_manager.py @@ -18,7 +18,7 @@ import time from typing import TYPE_CHECKING, Any -from rllm_model_gateway.client import GatewayClient +from rllm_model_gateway.client import AsyncGatewayClient, GatewayClient from rllm_model_gateway.models import TraceRecord if TYPE_CHECKING: @@ -89,6 +89,7 @@ def __init__(self, config: DictConfig, mode: str = "thread") -> None: self._server: Any = None # uvicorn.Server when using thread mode self._local_handler: Any = None # in-process handler for tinker self._client: GatewayClient | None = None + self._async_client: AsyncGatewayClient | None = None # Per-mode sampling params (extracted from rollout engine in start()) self._train_sampling_params: dict[str, Any] = {} @@ -100,10 +101,18 @@ def gateway_url(self) -> str: @property def client(self) -> GatewayClient: + """Sync client for lifecycle operations (start, stop, health polling).""" if self._client is None: self._client = GatewayClient(self.gateway_url) return self._client + @property + def async_client(self) -> AsyncGatewayClient: + """Async client for runtime operations (sessions, traces).""" + if self._async_client is None: + self._async_client = AsyncGatewayClient(self.gateway_url) + return self._async_client + # -- Lifecycle ----------------------------------------------------------- def start(self, rollout_engine: RolloutEngine) -> None: @@ -171,6 +180,16 @@ def get_traces(self, session_id: str) -> list[TraceRecord]: self.client.flush() return self.client.get_session_traces(session_id) + # -- Async session / trace API ------------------------------------------- + + async def acreate_session(self, session_id: str, is_validation: bool = False) -> str: + sp = self._val_sampling_params if is_validation else self._train_sampling_params + return await self.async_client.create_session(session_id=session_id, sampling_params=sp or None) + + async def aget_traces(self, session_id: str) -> list[TraceRecord]: + await self.async_client.flush() + return await self.async_client.get_session_traces(session_id) + # -- Worker setup -------------------------------------------------------- def _ensure_workers(self, rollout_engine: RolloutEngine) -> list[str]: diff --git a/rllm/experimental/engine/remote_agent_flow_engine.py b/rllm/experimental/engine/remote_agent_flow_engine.py index c59adc814..0bf67d6f6 100644 --- a/rllm/experimental/engine/remote_agent_flow_engine.py +++ b/rllm/experimental/engine/remote_agent_flow_engine.py @@ -5,6 +5,7 @@ converting gateway traces to training Steps. 
""" +import asyncio import logging import uuid from collections import defaultdict @@ -31,12 +32,15 @@ def __init__( runtime: RemoteAgentRuntime, gateway: GatewayManager, session_timeout: float = 900.0, + n_parallel_tasks: int = 128, episode_logger: EpisodeLogger | None = None, ) -> None: self.runtime = runtime self.gateway = gateway self.session_timeout = session_timeout + self.n_parallel_tasks = n_parallel_tasks self.episode_logger = episode_logger + self._semaphore = asyncio.Semaphore(n_parallel_tasks) # Training step tracking (set by set_training_step) self.current_step = 0 @@ -55,59 +59,24 @@ async def execute_tasks( is_validation: bool = False, **kwargs, ) -> list[Episode]: - """Submit tasks to remote runtime, gather results, build Episodes from gateway traces. - - 1. Prepare submissions (create gateway sessions) - 2. Submit all and gather results concurrently via runtime - 3. Retrieve traces from gateway + build Episodes - """ + """Submit tasks to remote runtime, gather results, build Episodes from gateway traces.""" if task_ids is None: task_ids = [str(uuid.uuid4()) for _ in tasks] - # Phase 1: Prepare submissions task_id_counter: dict[str, int] = defaultdict(int) - submissions: list[TaskSubmission] = [] - # Map session_id -> (idx, uid, task) for result correlation - session_metadata: dict[str, tuple[int, str, dict]] = {} + results: list[Episode | None] = [None] * len(tasks) + futures = [] for idx, (task, task_id) in enumerate(zip(tasks, task_ids, strict=True)): rollout_idx = task_id_counter[task_id] task_id_counter[task_id] += 1 - uid = f"{task_id}:{rollout_idx}" - session_id = str(uuid.uuid4()) - - self.gateway.create_session(session_id, is_validation=is_validation) - session_url = self.gateway.get_session_url(session_id) - - submissions.append( - TaskSubmission( - task=task, - session_id=session_id, - task_id=task_id, - inference_url=session_url, - ) - ) - session_metadata[session_id] = (idx, uid, task) - - # Phase 2: Submit all and gather results concurrently - logger.info("Submitting %d tasks to remote runtime (timeout=%.0fs)", len(submissions), self.session_timeout) - remote_results = await self.runtime.execute_tasks(submissions, timeout=self.session_timeout) + futures.append(self.process_task_with_retry(task, task_id, rollout_idx, idx, is_validation=is_validation)) - # Phase 3: Retrieve traces from gateway + build Episodes (match by session_id) - episode_map: dict[int, Episode] = {} + for future in asyncio.as_completed(futures): + task_id, rollout_idx, idx, episode = await future + results[idx] = episode - for result in remote_results: - idx, uid, task = session_metadata[result.session_id] - if not result.finished: - logger.warning("[%s] Remote task failed (assigning reward=0): %s", uid, result.error) - result.reward = 0.0 - traces = self.gateway.get_traces(result.session_id) - episode = _build_episode(traces, result, uid, task) - if not result.finished: - episode.metadata["error"] = {"message": result.error or "Unknown error"} - episode_map[idx] = episode - - episodes = [episode_map[i] for i in range(len(tasks))] + episodes: list[Episode] = results # type: ignore[assignment] # Log episodes if logger is provided if self.episode_logger is not None: @@ -123,6 +92,43 @@ async def execute_tasks( return episodes + async def process_task_with_retry( + self, + task: dict, + task_id: str, + rollout_idx: int, + result_idx: int, + **kwargs, + ) -> tuple[str, int, int, Episode]: + """Process a single task with concurrency control.""" + async with self._semaphore: + uid = 
f"{task_id}:{rollout_idx}" + session_id = str(uuid.uuid4()) + is_validation = kwargs.get("is_validation", False) + + await self.gateway.acreate_session(session_id, is_validation=is_validation) + session_url = self.gateway.get_session_url(session_id) + + submission = TaskSubmission( + task=task, + session_id=session_id, + task_id=task_id, + inference_url=session_url, + ) + results = await self.runtime.execute_tasks([submission], timeout=self.session_timeout) + result = results[0] + + if not result.finished: + logger.warning("[%s] Remote task failed (assigning reward=0): %s", uid, result.error) + result.reward = 0.0 + + traces = await self.gateway.aget_traces(session_id) + episode = _build_episode(traces, result, uid, task) + if not result.finished: + episode.metadata["error"] = {"message": result.error or "Unknown error"} + + return task_id, rollout_idx, result_idx, episode + def shutdown(self) -> None: """No local resources to clean up (runtime shutdown is separate).""" pass diff --git a/rllm/experimental/engine/remote_runtime/agentcore_runtime.py b/rllm/experimental/engine/remote_runtime/agentcore_runtime.py index 744114018..136f2a463 100644 --- a/rllm/experimental/engine/remote_runtime/agentcore_runtime.py +++ b/rllm/experimental/engine/remote_runtime/agentcore_runtime.py @@ -112,9 +112,7 @@ async def _run_one(self, sub: TaskSubmission, timeout: float) -> RemoteTaskResul raw_result=result, ) - async def execute_tasks( - self, submissions: list[TaskSubmission], timeout: float | None = None - ) -> list[RemoteTaskResult]: + async def execute_tasks(self, submissions: list[TaskSubmission], timeout: float | None = None) -> list[RemoteTaskResult]: """Submit all tasks concurrently via asyncio.gather. Each task invokes then polls in sequence; all tasks run in parallel. diff --git a/rllm/experimental/engine/remote_runtime/protocol.py b/rllm/experimental/engine/remote_runtime/protocol.py index b90399cd2..e00e3b8d3 100644 --- a/rllm/experimental/engine/remote_runtime/protocol.py +++ b/rllm/experimental/engine/remote_runtime/protocol.py @@ -48,9 +48,7 @@ def initialize(self) -> None: """Client setup from config.""" ... - async def execute_tasks( - self, submissions: list[TaskSubmission], timeout: float | None = None - ) -> list[RemoteTaskResult]: + async def execute_tasks(self, submissions: list[TaskSubmission], timeout: float | None = None) -> list[RemoteTaskResult]: """Submit tasks concurrently and gather results. Returns one result per submission.""" ... 
diff --git a/rllm/experimental/engine/unified_workflow_engine.py b/rllm/experimental/engine/unified_workflow_engine.py index da03c4d4f..037cbd9a3 100644 --- a/rllm/experimental/engine/unified_workflow_engine.py +++ b/rllm/experimental/engine/unified_workflow_engine.py @@ -12,12 +12,12 @@ from rllm.agents.agent import Episode from rllm.experimental.rollout import RolloutEngine -from rllm.utils import colorful_print from rllm.workflows.store import Store from rllm.workflows.workflow import TerminationReason, Workflow # Avoid hard dependency on verl at import time; only for typing if TYPE_CHECKING: + from omegaconf import DictConfig from verl import DataProto from rllm.utils.episode_logger import EpisodeLogger @@ -31,7 +31,7 @@ def __init__( self, workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: RolloutEngine, - config=None, + config: DictConfig | None = None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, @@ -104,6 +104,7 @@ async def initialize_pool(self): assert self.executor is not None, "executor is not initialized" if self.workflow_queue is not None: return + logger.info(f"[WorkflowEngine] Initializing pool with {self.n_parallel_tasks} workflows") self.workflow_queue = asyncio.Queue(maxsize=self.n_parallel_tasks) for i in range(self.n_parallel_tasks): workflow = self.workflow_cls( @@ -114,6 +115,7 @@ ) assert workflow.is_multithread_safe(), "Workflows must contain only thread-safe environments" self.workflow_queue.put_nowait(workflow) + logger.info(f"[WorkflowEngine] Pool initialized. Queue size: {self.workflow_queue.qsize()}") async def process_task_with_retry(self, task: dict, task_id: str, rollout_idx: int, result_idx: int, **kwargs) -> tuple[str, int, int, Episode]: """Process a single task rollout with retry logic based on termination reasons. @@ -132,10 +134,12 @@ Exception: If task fails permanently after retry_limit attempts and raise_on_error is True. """ assert self.workflow_queue is not None, "workflow_queue is not initialized" + logger.debug(f"[WorkflowEngine] Waiting for workflow from queue. Available: {self.workflow_queue.qsize()}") workflow = await self.workflow_queue.get() try: for retry_attempt in range(1, self.retry_limit + 1): uid = f"{task_id}:{rollout_idx}" + logger.debug(f"[WorkflowEngine] [{uid}] Starting attempt {retry_attempt}/{self.retry_limit}") workflow.reset(task=task, uid=uid) episode = await workflow.run_with_termination_handling(task=task, uid=uid, **kwargs) @@ -152,24 +156,21 @@ elif len(traj.steps) > 0: reward = f"{traj.steps[-1].reward:.1f}" reward_strs.append(f"{traj.name}: {reward}") - colorful_print( - f"[{uid}] Rollout completed. Rewards: [{', '.join(reward_strs)}], Termination: {episode.termination_reason}", - fg="green" if episode.is_correct else "yellow", - ) + logger.debug(f"[{uid}] Rollout completed. 
Rewards: [{', '.join(reward_strs)}], Termination: {episode.termination_reason}") if episode.termination_reason != TerminationReason.ERROR: return task_id, rollout_idx, result_idx, episode error_tb = episode.info.get("error", {}).get("traceback") if error_tb: - print(error_tb) + logger.error(f"[WorkflowEngine] [{uid}] Error on attempt {retry_attempt}/{self.retry_limit}:\n{error_tb}") if retry_attempt < self.retry_limit: - print(f"[{uid}] Rollout failed on attempt {retry_attempt}/{self.retry_limit}, retrying...") + logger.warning(f"[WorkflowEngine] [{uid}] Rollout failed on attempt {retry_attempt}/{self.retry_limit}, retrying...") continue if not self.raise_on_error: - print(f"[{uid}] Rollout failed permanently after {self.retry_limit} attempts.") + logger.error(f"[WorkflowEngine] [{uid}] Rollout failed permanently after {self.retry_limit} attempts.") else: raise Exception(f"[{uid}] Rollout failed permanently after {self.retry_limit} attempts.") @@ -177,6 +178,7 @@ async def process_task_with_retry(self, task: dict, task_id: str, rollout_idx: i finally: await self.workflow_queue.put(workflow) + logger.debug(f"[WorkflowEngine] Returned workflow to queue. Available: {self.workflow_queue.qsize()}") async def execute_tasks(self, tasks: list[dict], task_ids: list[str] | None = None, is_validation: bool = False, **kwargs) -> list[Episode]: """Run asynchronous workflow execution with retry logic for multiple tasks. diff --git a/rllm/experimental/eval/task_spec.py b/rllm/experimental/eval/task_spec.py index 96f4acba1..b13e32bc8 100644 --- a/rllm/experimental/eval/task_spec.py +++ b/rllm/experimental/eval/task_spec.py @@ -61,7 +61,12 @@ def __post_init__(self): BENCHMARK_INSTRUCTIONS: dict[str, str] = { "math_reward_fn": ("Solve the math problem step by step, showing your reasoning clearly. Put your final answer in \\boxed{} notation.\nExample: The answer is \\boxed{42}."), - "countdown_reward_fn": ("You are given a target number and a set of numbers. Use each number exactly once with basic arithmetic (+, -, *, /) to reach the target. Show your reasoning, then provide your equation inside ... tags.\nExample: (25 + 3) * 2"), + "countdown_reward_fn": ( + "You are given a target number and a set of numbers. Use each number exactly once" + " with basic arithmetic (+, -, *, /) to reach the target. Show your reasoning," + " then provide your equation inside ... tags.\n" + "Example: (25 + 3) * 2" + ), "mcq_reward_fn": ("Choose the correct answer from the given options. Think through the problem carefully, then respond with ONLY the letter of the correct answer (A, B, C, D, etc.)."), "code_reward_fn": ("Write a Python function that solves the problem. Your code will be tested against hidden test cases. Put your complete solution in a ```python code block."), "f1_reward_fn": ("Answer the question directly and concisely. 
Provide only the answer, no additional explanation."), diff --git a/rllm/experimental/fully_async/fully_async_trainer.py b/rllm/experimental/fully_async/fully_async_trainer.py index 55832f801..0f9825093 100644 --- a/rllm/experimental/fully_async/fully_async_trainer.py +++ b/rllm/experimental/fully_async/fully_async_trainer.py @@ -21,7 +21,7 @@ from omegaconf import OmegaConf from tqdm import tqdm from verl import DataProto -from verl.experimental.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer +from verl.experimental.separation.ray_trainer import SeparateRayPPOTrainer from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup from verl.trainer.ppo import core_algos from verl.trainer.ppo.core_algos import agg_loss @@ -40,7 +40,7 @@ @ray.remote(num_cpus=10) -class FullyAsyncTrainer(FullyAsyncRayPPOTrainer): +class FullyAsyncTrainer(SeparateRayPPOTrainer): """ A fully asynchronous PPO trainer that obtains samples from a MessageQueue for training. Based on an improved implementation of OneStepOffRayTrainer @@ -274,7 +274,15 @@ def _save_checkpoint(self): if self.use_critic: critic_local_path = os.path.join(local_global_step_folder, str(Role.Critic)) - critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.current_param_version}", str(Role.Critic)) + critic_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join( + self.config.trainer.default_hdfs_dir, + f"global_step_{self.current_param_version}", + str(Role.Critic), + ) + ) self.critic_wg.save_checkpoint( critic_local_path, critic_remote_path, diff --git a/rllm/experimental/fully_async/inference_manager.py b/rllm/experimental/fully_async/inference_manager.py index db2bd4fb6..fbc83c300 100644 --- a/rllm/experimental/fully_async/inference_manager.py +++ b/rllm/experimental/fully_async/inference_manager.py @@ -15,15 +15,15 @@ import subprocess import ray -from verl.experimental.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer +from verl.experimental.separation.ray_trainer import SeparateRayPPOTrainer from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup from verl.trainer.ppo.ray_trainer import ResourcePoolManager from verl.trainer.ppo.utils import Role, WorkerType -from verl.workers.rollout.utils import get_free_port +from verl.utils.net_utils import get_free_port @ray.remote(num_cpus=10, max_concurrency=100) -class InferenceManager(FullyAsyncRayPPOTrainer): +class InferenceManager(SeparateRayPPOTrainer): """ Manages SGLang inference servers for async training. 
Responsible for: @@ -120,10 +120,10 @@ def _init_models(self): async def _init_async_rollout_manager(self): # create async rollout manager and request scheduler assert self.config.actor_rollout_ref.rollout.mode == "async" - from verl.experimental.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager + from verl.experimental.agent_loop import AgentLoopManager self.async_rollout_mode = True - self.async_rollout_manager = await FullyAsyncAgentLoopManager.create( + self.async_rollout_manager = await AgentLoopManager.create( config=self.config, worker_group=self.rollout_wg, ) diff --git a/rllm/experimental/fully_async/rollout_executor.py b/rllm/experimental/fully_async/rollout_executor.py index 6e6f36bc4..0a0792438 100644 --- a/rllm/experimental/fully_async/rollout_executor.py +++ b/rllm/experimental/fully_async/rollout_executor.py @@ -71,7 +71,13 @@ def __init__(self, router_url, rollout_fn, n, config, tokenizer, processor, max_ self.total_train_steps = int(self.total_rollout_steps / (required_samples * trigger_parameter_sync_step)) self.max_queue_size = self.max_required_samples - print(f"[RolloutExecutor] required_samples={required_samples} max_required_samples={self.max_required_samples} max_queue_size={self.max_queue_size} total_train_steps={self.total_train_steps} total_rollout_steps={self.total_rollout_steps}") + print( + f"[RolloutExecutor] required_samples={required_samples}" + f" max_required_samples={self.max_required_samples}" + f" max_queue_size={self.max_queue_size}" + f" total_train_steps={self.total_train_steps}" + f" total_rollout_steps={self.total_rollout_steps}" + ) # Lock for dataloader access (async safety) self.dataloader_lock = asyncio.Lock() @@ -355,7 +361,12 @@ async def fit(self): datum = batch[0] # batch_size=1, extract single item datum_count += 1 if datum_count % 128 == 1: - print(f"[RolloutExecutor] Processing datum {datum_count}, global_steps={self.global_steps}/{self.total_rollout_steps}, active={self.active_sample}, enqueued={self.enqueued_sample}", flush=True) + print( + f"[RolloutExecutor] Processing datum {datum_count}," + f" global_steps={self.global_steps}/{self.total_rollout_steps}," + f" active={self.active_sample}, enqueued={self.enqueued_sample}", + flush=True, + ) if self.active_sample + self.enqueued_sample >= self.max_staleness_samples: self.continue_event.clear() @@ -403,7 +414,15 @@ async def update_staleness_tracking(self): mq_total_consumed = mq_stats.get("total_consumed", "N/A") mq_total_produced = mq_stats.get("total_produced", "N/A") - print(f"[RolloutExecutor] update_staleness_tracking CALLED, current enqueued_sample={self.enqueued_sample}, active_sample={self.active_sample}, mq_queue_size={mq_queue_size}, mq_total_consumed={mq_total_consumed}, mq_total_produced={mq_total_produced}", flush=True) + print( + f"[RolloutExecutor] update_staleness_tracking CALLED," + f" current enqueued_sample={self.enqueued_sample}," + f" active_sample={self.active_sample}," + f" mq_queue_size={mq_queue_size}," + f" mq_total_consumed={mq_total_consumed}," + f" mq_total_produced={mq_total_produced}", + flush=True, + ) self.enqueued_sample = mq_queue_size print(f"[RolloutExecutor] update_staleness_tracking DONE, new enqueued_sample={self.enqueued_sample}", flush=True) diff --git a/rllm/experimental/fully_async/runner.py b/rllm/experimental/fully_async/runner.py index d5d5186ef..bfd542439 100644 --- a/rllm/experimental/fully_async/runner.py +++ b/rllm/experimental/fully_async/runner.py @@ -22,7 +22,7 @@ import ray from omegaconf import OmegaConf -from 
verl.experimental.fully_async_policy.fully_async_main import create_resource_pool_manager, create_role_worker_mapping +from verl.experimental.separation.utils import create_resource_pool_manager, create_role_worker_mapping from verl.trainer.ppo.utils import Role from verl.utils.fs import copy_to_local diff --git a/rllm/experimental/fully_async/utils.py b/rllm/experimental/fully_async/utils.py index b818089ae..7d567db1b 100644 --- a/rllm/experimental/fully_async/utils.py +++ b/rllm/experimental/fully_async/utils.py @@ -456,9 +456,23 @@ def apply_rejection_sampling( unfiltered_mean = stats.get("rejection_sample/unfiltered_reward_mean", 0) filtered_mean = stats.get("rejection_sample/filtered_reward_mean", 0) if enable: - print(f"[RejectionSampling] Applied rejection sampling: solve_none={stats['rejection_sample/solve_none']}, solve_all={stats['rejection_sample/solve_all']}, solve_partial={stats['rejection_sample/solve_partial']}, kept {len(filtered_groups)}/{len(trajectory_group_ls)} groups, unfiltered_reward={unfiltered_mean:.4f}, filtered_reward={filtered_mean:.4f}") + print( + f"[RejectionSampling] Applied rejection sampling:" + f" solve_none={stats['rejection_sample/solve_none']}," + f" solve_all={stats['rejection_sample/solve_all']}," + f" solve_partial={stats['rejection_sample/solve_partial']}," + f" kept {len(filtered_groups)}/{len(trajectory_group_ls)} groups," + f" unfiltered_reward={unfiltered_mean:.4f}," + f" filtered_reward={filtered_mean:.4f}" + ) else: - print(f"[RejectionSampling] Stats (filtering disabled): solve_none={stats['rejection_sample/solve_none']}, solve_all={stats['rejection_sample/solve_all']}, solve_partial={stats['rejection_sample/solve_partial']}, reward_mean={unfiltered_mean:.4f}") + print( + f"[RejectionSampling] Stats (filtering disabled):" + f" solve_none={stats['rejection_sample/solve_none']}," + f" solve_all={stats['rejection_sample/solve_all']}," + f" solve_partial={stats['rejection_sample/solve_partial']}," + f" reward_mean={unfiltered_mean:.4f}" + ) return filtered_groups, stats diff --git a/rllm/experimental/metrics.py b/rllm/experimental/metrics.py new file mode 100644 index 000000000..fe2ce1ef4 --- /dev/null +++ b/rllm/experimental/metrics.py @@ -0,0 +1,119 @@ +"""MetricsAggregator for async training. + +Accumulates metric observations from multiple sources (buffer, training loop, +coordinator) and reduces them with per-key aggregation rules at flush time. +""" + +from __future__ import annotations + +from collections import defaultdict + +import numpy as np + +# Keys that should be summed rather than averaged. +_SUM_KEYS: set[str] = { + "groups/num_trajs_before_filter", + "groups/num_trajs_after_filter", + "groups/num_groups", + "groups/dropped_min_trajs", + "groups/dropped_zero_adv", +} + +# Prefixes where "last value" is the correct reduction. +_LAST_PREFIXES: tuple[str, ...] = ( + "time/", + "train/", + "progress/", + "async/", +) + +# Prefixes where "mean" is the correct reduction. +_MEAN_PREFIXES: tuple[str, ...] = ("episode/",) + + +def _infer_rule(key: str) -> str: + """Infer aggregation rule from metric key name. + + Resolution order: + 1. Explicit sum keys + 2. Prefix-based rules (last or mean) + 3. Keyword-based rules (/max, /min, /mean, /avg, /std, /fraction) + 4. 
Default: mean + """ + if key in _SUM_KEYS: + return "sum" + + for prefix in _LAST_PREFIXES: + if key.startswith(prefix): + return "last" + + for prefix in _MEAN_PREFIXES: + if key.startswith(prefix): + return "mean" + + # Keyword inference from the key name + if "/max" in key: + return "max" + if "/min" in key: + return "min" + if "/mean" in key or "/avg" in key: + return "mean" + if "/std" in key or "/fraction" in key: + return "mean" + + return "mean" + + +def _reduce(rule: str, values: list[float]) -> float: + if rule == "mean": + return sum(values) / len(values) + if rule == "sum": + return sum(values) + if rule == "max": + return max(values) + if rule == "min": + return min(values) + if rule == "last": + return values[-1] + return sum(values) / len(values) + + +class MetricsAggregator: + """Accumulates metric observations and flushes as an aggregated plain dict. + + Usage:: + + agg = MetricsAggregator() + + # record from various sources + agg.record("episode/queue_wait", 0.3) + agg.record("episode/queue_wait", 0.5) + agg.record_dict(transform_metrics) + + # at log time + plain_dict = agg.flush() # reduces, clears, returns dict + """ + + def __init__(self) -> None: + self._values: dict[str, list[float]] = defaultdict(list) + + def record(self, key: str, value: float) -> None: + """Record a single metric observation.""" + self._values[key].append(float(value)) + + def record_dict(self, metrics: dict) -> None: + """Record all numeric values from a dict, coercing types.""" + for k, v in metrics.items(): + if isinstance(v, int | float): + self._values[k].append(float(v)) + elif isinstance(v, np.number): + self._values[k].append(float(v)) + + def flush(self) -> dict[str, float]: + """Reduce all accumulated values and return a plain dict. Clears state.""" + result = {} + for key, values in self._values.items(): + if values: + result[key] = _reduce(_infer_rule(key), values) + self._values.clear() + return result diff --git a/rllm/experimental/protocol.py b/rllm/experimental/protocol.py index 0c8491dbe..ebdb79a0b 100644 --- a/rllm/experimental/protocol.py +++ b/rllm/experimental/protocol.py @@ -205,6 +205,10 @@ async def on_epoch_end(self, trainer_state: TrainerState) -> None: """Hook method called at the end of an epoch.""" pass + async def on_policy_updated(self, trainer_state: TrainerState) -> None: + """Hook called immediately after update_policy() for weight sync.""" + pass + async def on_validation_start(self, trainer_state: TrainerState) -> bool: """Hook method called at the start of validation. diff --git a/rllm/experimental/rllm_telemetry/src/rllm_telemetry/exporter.py b/rllm/experimental/rllm_telemetry/src/rllm_telemetry/exporter.py index 05c6d804b..ef1349aa0 100755 --- a/rllm/experimental/rllm_telemetry/src/rllm_telemetry/exporter.py +++ b/rllm/experimental/rllm_telemetry/src/rllm_telemetry/exporter.py @@ -376,7 +376,9 @@ async def _validate_dataset_and_table(self, loop: asyncio.AbstractEventLoop, dat try: await loop.run_in_executor(None, self._client.get_dataset, dataset_ref) except Exception as exc: - raise BigQueryValidationError(f"BigQuery dataset '{self._config.bq_dataset}' not found in project '{self._config.bq_project}'. Set bq_auto_create=True to create it automatically: {exc}") from exc + raise BigQueryValidationError( + f"BigQuery dataset '{self._config.bq_dataset}' not found in project '{self._config.bq_project}'. 
Set bq_auto_create=True to create it automatically: {exc}" + ) from exc try: await loop.run_in_executor(None, self._client.get_table, self._table_ref) diff --git a/rllm/experimental/rollout/__init__.py b/rllm/experimental/rollout/__init__.py index 6e5c4d681..50ab03477 100644 --- a/rllm/experimental/rollout/__init__.py +++ b/rllm/experimental/rollout/__init__.py @@ -9,11 +9,10 @@ __all__ = [ "ModelOutput", - # Rollout engines "RolloutEngine", "TinkerEngine", "VerlEngine", - # Token input/output types + # Token types "TokenInput", "TokenOutput", "TinkerTokenInput", @@ -30,7 +29,10 @@ def __getattr__(name): return _TinkerEngine if name == "VerlEngine": - from .verl_engine import VerlEngine as _VerlEngine + try: + from .verl_engine import VerlEngine as _VerlEngine - return _VerlEngine + return _VerlEngine + except Exception: + raise AttributeError(name) from None raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/rllm/experimental/rollout/rollout_engine.py b/rllm/experimental/rollout/rollout_engine.py index 1660d65a2..3837d59ec 100644 --- a/rllm/experimental/rollout/rollout_engine.py +++ b/rllm/experimental/rollout/rollout_engine.py @@ -1,13 +1,15 @@ from __future__ import annotations +import logging from dataclasses import dataclass from typing import TYPE_CHECKING -from rllm.experimental.rollout.types import TokenInput, Tokenizer, TokenOutput -from rllm.tools.tool_base import ToolCall - if TYPE_CHECKING: + from rllm.experimental.rollout.types import TokenInput, Tokenizer, TokenOutput from rllm.parser import ChatTemplateParser + from rllm.tools.tool_base import ToolCall + +logger = logging.getLogger(__name__) @dataclass @@ -21,9 +23,12 @@ class ModelOutput: multi_modal_inputs: dict[str, list] | None = None logprobs: list[float] | None = None # completion logprobs prompt_logprobs: list[float] | None = None # prompt logprobs aligned to prompt_ids + routing_matrices: list[str] | None = None # per-token routing matrices (R3, transient) prompt_length: int = 0 completion_length: int = 0 finish_reason: str | None = None + weight_version: int | None = None # policy version at time of generation + metrics: dict | None = None # per-turn server metrics (e.g. 
ttft, queue durations) def to_dict(self): return { @@ -39,10 +44,14 @@ def to_dict(self): "prompt_length": self.prompt_length, "completion_length": self.completion_length, "finish_reason": self.finish_reason, + "weight_version": self.weight_version, + "metrics": self.metrics, } @classmethod def from_dict(cls, data: dict): + from rllm.tools.tool_base import ToolCall + return cls( text=data.get("text"), content=data.get("content"), @@ -56,6 +65,8 @@ def from_dict(cls, data: dict): prompt_length=data.get("prompt_length", 0), completion_length=data.get("completion_length", 0), finish_reason=data.get("finish_reason"), + weight_version=data.get("weight_version"), + metrics=data.get("metrics"), ) @@ -65,10 +76,16 @@ class RolloutEngine: is_validation: bool = False # flag enabled/disabled by AgentWorkflowEngine.execute_tasks def __init__(self, *args, **kwargs): - pass + self.weight_version: int = 0 + + # --- Model response --- + async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + raise NotImplementedError(f"_get_model_response is not implemented for {self.__class__.__name__}") async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: - raise NotImplementedError("get_model_response is not implemented") + result = await self._get_model_response(messages, **kwargs) + result.weight_version = self.weight_version + return result def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput) -> ModelOutput: """ @@ -80,13 +97,13 @@ async def get_token_output_from_token_input(self, token_input: TokenInput, **kwa """Obtain the token output from the given token input.""" raise NotImplementedError("get_token_output_from_token_input is not implemented") + @property + def supports_token_in_token_out(self) -> bool: + """Whether the engine supports token-in-token-out (TITO) generation. Defaults to false.""" + return False + async def wake_up(self): pass async def sleep(self): pass - - @property - def supports_token_in_token_out(self) -> bool: - """Whether the engine supports token-in-token-out (TITO) generation. Defaults to false.""" - return False diff --git a/rllm/experimental/rollout/tinker_engine.py b/rllm/experimental/rollout/tinker_engine.py index 77cdc24a2..0665777d2 100644 --- a/rllm/experimental/rollout/tinker_engine.py +++ b/rllm/experimental/rollout/tinker_engine.py @@ -186,6 +186,7 @@ def __init__( reasoning_effort: The effort level for reasoning (used when bypass_render_with_parser=True) renderer_name: The name of the renderer to use (used when bypass_render_with_parser=True) """ + super().__init__() self.base_url = base_url self.model_name = model_name self.max_prompt_length = max_prompt_length @@ -365,7 +366,7 @@ def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutp ) @override - async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: """ Generate model response for a given set of messages. 
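To make the `get_model_response` / `_get_model_response` split above concrete: subclasses now implement only the private hook, and the public base-class method stamps the current `weight_version` on every result. A toy subclass follows; the `EchoEngine` name, the canned output, and the assumption that `ModelOutput`'s `text`/`finish_reason` fields accept keyword construction are illustrative only:

```python
import asyncio

from rllm.experimental.rollout.rollout_engine import ModelOutput, RolloutEngine


class EchoEngine(RolloutEngine):
    """Hypothetical engine that just echoes the last user message."""

    async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput:
        return ModelOutput(text=messages[-1]["content"], finish_reason="stop")


async def main() -> None:
    engine = EchoEngine()
    engine.weight_version = 3  # normally advanced by the weight-sync path
    out = await engine.get_model_response([{"role": "user", "content": "hi"}])
    assert out.weight_version == 3  # stamped by the public wrapper, not the subclass


asyncio.run(main())
```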
diff --git a/rllm/experimental/rollout/verl_engine.py b/rllm/experimental/rollout/verl_engine.py index 4540a8d1e..a9b5d99e3 100644 --- a/rllm/experimental/rollout/verl_engine.py +++ b/rllm/experimental/rollout/verl_engine.py @@ -24,6 +24,7 @@ def __init__(self, config: DictConfig, rollout_manager: AgentLoopManager, tokeni # reconstruct the servers list from the server_addresses and server_handles (Verl 0.7.0+) servers = zip(rollout_manager.server_addresses, rollout_manager.server_handles, strict=True) self.server_manager = AsyncLLMServerManager(config, servers=servers, load_balancer_handle=rollout_manager.global_load_balancer) + self.tokenizer = tokenizer self.processor = processor self.chat_parser = ChatTemplateParser.get_parser(tokenizer, processor=processor, disable_thinking=config.get("rllm", {}).get("disable_thinking", False)) @@ -74,12 +75,13 @@ async def get_token_output_from_token_input(self, token_input: TokenInput, **kwa return token_output @override - async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: # these go to the parser tools = kwargs.pop("tools", []) accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) + reasoning_effort = kwargs.pop("reasoning_effort", "medium") - prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning) + prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning, reasoning_effort=reasoning_effort) request_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) # list[int] if any(msg.get("images", None) is not None and msg["role"] == "user" for msg in messages) and self.processor is not None: diff --git a/rllm/experimental/sync_coordinator.py b/rllm/experimental/sync_coordinator.py new file mode 100644 index 000000000..1405f3c0a --- /dev/null +++ b/rllm/experimental/sync_coordinator.py @@ -0,0 +1,129 @@ +"""SyncCoordinator: manages rollout quotas and parameter sync timing for fully-async training.""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass + + +@dataclass +class SyncCoordinatorConfig: + mini_batch_size: int # episode groups per optimizer step + group_size: int # episodes per group (rollout.n) + staleness_threshold: float + trigger_parameter_sync_step: int + + @property + def max_rollout_quota(self) -> int: + """Max dispatches per sync window (Verl/AReaL formulation).""" + return int((1 + self.staleness_threshold) * self.trigger_parameter_sync_step * self.mini_batch_size) + + +class SyncCoordinator: + """Coordinates rollout scheduling and parameter sync between generation and training loops. + + Uses a per-sync-window dispatch counter (matching Verl/AReaL). The counter + resets only on weight sync, not on consume. This guarantees zero staleness + when staleness_threshold=0. 
+ """ + + def __init__(self, config: SyncCoordinatorConfig): + self.config = config + + self._weight_version: int = 0 + self._quota_used: int = 0 # groups counting toward current sync window quota (includes carryover) + self._in_flight: int = 0 # groups dispatched but not yet consumed/filtered + self._steps_since_sync: int = 0 + self._total_syncs: int = 0 + + # Throttle — blocks generation when dispatched_since_sync >= max_rollout_quota + self._throttle_event: asyncio.Event = asyncio.Event() + self._throttle_event.set() + + # Generation pause — blocks generation during validation or weight sync + self._generation_paused: asyncio.Event = asyncio.Event() + self._generation_paused.set() + + # Tracks in-flight async rollout tasks for drain/wait logic + self._in_flight_tasks: set[asyncio.Task] = set() + + @property + def weight_version(self) -> int: + return self._weight_version + + # --- Throttle --- + + def on_group_dispatched(self) -> None: + """Generation loop dispatched one prompt (n rollouts).""" + self._quota_used += 1 + self._in_flight += 1 + if self._quota_used >= self.config.max_rollout_quota: + self._throttle_event.clear() + + def on_group_consumed(self) -> None: + """Training loop consumed one group from the buffer.""" + self._in_flight = max(0, self._in_flight - 1) + + def on_group_filtered(self) -> None: + """Accumulator filtered out a group. Decrements in-flight count.""" + self._in_flight = max(0, self._in_flight - 1) + + async def wait_for_throttle(self) -> None: + """Generation loop blocks here when dispatch window is full.""" + await self._throttle_event.wait() + + def has_quota(self) -> bool: + """Whether the generation loop can dispatch another group.""" + return self._quota_used < self.config.max_rollout_quota + + # --- Weight sync --- + + def on_training_step_complete(self) -> None: + self._steps_since_sync += 1 + + def should_sync(self) -> bool: + return self._steps_since_sync >= self.config.trigger_parameter_sync_step + + def on_sync_complete(self) -> None: + self._weight_version += 1 + self._steps_since_sync = 0 + self._total_syncs += 1 + # Reset dispatch window. In-flight items span the sync boundary — + # they were dispatched with old weights and count toward the new window. 
+ self._quota_used = self._in_flight + if self._quota_used < self.config.max_rollout_quota: + self._throttle_event.set() + + # --- Generation pause (for validation / weight sync if partial_rollout is False) --- + + def pause_generation(self) -> None: + self._generation_paused.clear() + + def resume_generation(self) -> None: + self._generation_paused.set() + + async def wait_for_generation_allowed(self) -> None: + await self._generation_paused.wait() + + # --- In-flight task tracking --- + + def track_task(self, task: asyncio.Task) -> None: + """Register an in-flight rollout task.""" + self._in_flight_tasks.add(task) + task.add_done_callback(self._in_flight_tasks.discard) + + async def wait_for_drain(self) -> None: + """Wait for all in-flight rollout tasks to complete.""" + while self._in_flight_tasks: + await asyncio.sleep(0.1) + + def stats(self) -> dict: + return { + "async/weight_version": self._weight_version, + "async/dispatched_since_sync": self._quota_used - self._in_flight, + "async/quota_used": self._quota_used, + "async/in_flight_groups": self._in_flight, + "async/steps_since_sync": self._steps_since_sync, + "async/max_rollout_quota": self.config.max_rollout_quota, + "async/total_syncs": self._total_syncs, + } diff --git a/rllm/experimental/test_examples/opsd/math_opsd_workflow.py b/rllm/experimental/test_examples/opsd/math_opsd_workflow.py index 8f9488591..25852b76e 100644 --- a/rllm/experimental/test_examples/opsd/math_opsd_workflow.py +++ b/rllm/experimental/test_examples/opsd/math_opsd_workflow.py @@ -26,7 +26,13 @@ async def run(self, task: dict, uid: str, **kwargs) -> Episode: self.reset(task, uid) student_prompt = f"Problem: {task['question']}" - teacher_prompt = student_prompt + "\n\n" + f"Here is a reference solution:\n\n{task['ground_truth']}" + "\n\n" + "After understanding the reference solution, please try to solve this problem using your own approach below:" + teacher_prompt = ( + student_prompt + + "\n\n" + + f"Here is a reference solution:\n\n{task['ground_truth']}" + + "\n\n" + + "After understanding the reference solution, please try to solve this problem using your own approach below:" + ) student_messages = [{"role": "user", "content": student_prompt}] teacher_messages = [{"role": "user", "content": teacher_prompt}] diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index effda0473..45819f7c7 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -1,5 +1,7 @@ import asyncio +import logging import time +import uuid from abc import ABC, abstractmethod from collections import Counter, defaultdict from collections.abc import Callable, Iterable @@ -9,14 +11,17 @@ import numpy as np from omegaconf import DictConfig, OmegaConf +from tqdm import tqdm from rllm.agents.agent import Episode, TrajectoryGroup from rllm.data import Dataset +from rllm.experimental.buffer import TrajectoryGroupBuffer from rllm.experimental.common.advantage import ( AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups, ) from rllm.experimental.common.config import ( + AsyncTrainingConfig, CompactFilteringConfig, RejectionSamplingConfig, TransformConfig, @@ -27,15 +32,22 @@ RejectionSamplingState, apply_rejection_sampling_and_filtering, ) -from rllm.experimental.common.transform import _default_traj_grouping_hook, transform_episodes_to_trajectory_groups -from rllm.experimental.common.visualization import visualize_trajectory_last_steps +from rllm.experimental.common.transform import ( + 
_default_traj_grouping_hook, + transform_episodes_to_trajectory_groups, +) +from rllm.experimental.common.visualization import print_metrics_table, visualize_trajectory_last_steps from rllm.experimental.engine.unified_workflow_engine import UnifiedWorkflowEngine +from rllm.experimental.metrics import MetricsAggregator from rllm.experimental.protocol import BackendProtocol from rllm.experimental.rollout import RolloutEngine +from rllm.experimental.sync_coordinator import SyncCoordinator, SyncCoordinatorConfig from rllm.utils import EpisodeLogger, Tracking, extract_source_metadata from rllm.workflows.store import Store from rllm.workflows.workflow import TerminationReason, Workflow +logger = logging.getLogger(__name__) + @dataclass class TrainerState: @@ -46,6 +58,7 @@ epoch: int = 0 total_steps: int = 0 is_training: bool = True + weight_version: int = 0 # For timing and metrics timing_dict: dict = field(default_factory=dict) metrics: dict = field(default_factory=dict) @@ -129,11 +142,24 @@ def __init__( # Extract the TrajectoryGroup-specific estimator from kwargs self.traj_group_adv_estimator_map = traj_group_adv_estimator_map or {} + # TODO(kylemontgomery1): disaggregate UnifiedTrainer.__init__ from engine/infra setup + self.backend = backend_cls(config=config, **(backend_args or {})) self._validate_and_setup_configs() self._setup_logging() + # Async training config + async_cfg = self.rllm_config.get("async_training", {}) + self.async_config = AsyncTrainingConfig( + enable=async_cfg.get("enable", False), + mini_batch_size=async_cfg.get("mini_batch_size", 1), + fwd_bwd_group_size=async_cfg.get("fwd_bwd_group_size", 1), + staleness_threshold=async_cfg.get("staleness_threshold", 0.0), + trigger_parameter_sync_step=async_cfg.get("trigger_parameter_sync_step", 1), + partial_rollout=async_cfg.get("partial_rollout", True), + ) + rollout_engine: RolloutEngine = self.backend.init_rollout_engine( cf_config=self.cf_config, transform_config=self.transform_config, @@ -201,6 +227,7 @@ runtime=self._remote_runtime, gateway=self._gateway, session_timeout=remote_runtime_config.session_timeout, + n_parallel_tasks=self.rllm_config.workflow.n_parallel_tasks, episode_logger=self.episode_logger, ) else: @@ -247,6 +274,7 @@ def _validate_and_setup_configs(self): mode=rs_mode, min_partial_solve_tasks=self.rllm_config.rejection_sample.min_partial_solve_tasks, min_trajs_per_group=self.rllm_config.rejection_sample.min_trajs_per_group, + filter_uniform_groups=self.rllm_config.rejection_sample.get("filter_uniform_groups", False), ) # algorithm config (used for rLLM-native advantage computation) @@ -254,7 +282,7 @@ estimator=self.rllm_config.algorithm.adv_estimator, estimator_map=self.traj_group_adv_estimator_map, # TODO(listar2000): see if we can make this configurable in config as well stepwise_advantage_mode=self.rllm_config.stepwise_advantage.mode, - norm_adv_by_std_in_grpo=self.rllm_config.stepwise_advantage.get("norm_adv_by_std_in_grpo", True), + norm_adv_by_std_in_grpo=self.rllm_config.algorithm.get("norm_adv_by_std_in_grpo", True), use_rllm=self.rllm_config.algorithm.get("use_rllm", False), use_precomputed_advantage=self.rllm_config.algorithm.get("use_precomputed_advantage", False), loss_fn=self.rllm_config.algorithm.get("loss_fn", None), @@ -290,6 +318,8 @@ def _setup_logging(self): # Main training loop methods # ========================================================================= + # TODO(kylemontgomery1): better separation of on
policy vs fully async training code + def fit(self): """Main training loop (sync entry point).""" asyncio.run(self.fit_async()) @@ -310,8 +340,7 @@ async def fit_async(self) -> None: await self.backend.on_train_start(trainer_state) if self.rllm_config.trainer.get("val_before_train", True): - val_metrics = await self._validate_async(trainer_state) - pprint(f"Initial validation metrics: {val_metrics}") + await self._validate_async(trainer_state) if self.rllm_config.trainer.get("val_only", False): return @@ -324,7 +353,16 @@ async def fit_async(self) -> None: await self.backend.on_train_end(trainer_state) async def _fit_async(self, trainer_state: TrainerState) -> None: - """Internal async main training loop.""" + """Dispatch to sync or concurrent training based on config.""" + # TODO(listar2000): after some benchmarking, maybe we just keep the fully-async and treat on-policy as a special case. + if self.async_config.enable: + await self._fit_fully_async(trainer_state) + else: + await self._fit_on_policy(trainer_state) + + async def _fit_on_policy(self, trainer_state: TrainerState) -> None: + """Synchronous training loop (the most vanilla, standalone case that does not support minibatching or off-policy training).""" + # TODO(kylemontgomery1): dataloader should be backend-agnostic train_dataloader: Iterable = self.backend.get_dataloader(self.train_dataset, trainer_state) break_via_total_batches = False # used to break the training loop via the `total_batches` parameter use_total_batches = self.rllm_config.trainer.get("total_batches") is not None and self.rllm_config.trainer.total_batches > 0 @@ -351,6 +389,7 @@ async def _fit_async(self, trainer_state: TrainerState) -> None: await self._train_batch_async(batch, trainer_state) await self.backend.on_batch_end(trainer_state) + print_metrics_table(trainer_state.metrics, trainer_state.global_step) self.logger.log( data=trainer_state.metrics, step=trainer_state.global_step, @@ -373,13 +412,13 @@ async def _fit_async(self, trainer_state: TrainerState) -> None: # final validation after training if self.rllm_config.trainer.test_freq > 0: - val_metrics = await self._validate_async(trainer_state) - pprint(f"Final validation metrics: {val_metrics}") + await self._validate_async(trainer_state) async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> None: """Train a batch (async implementation).""" self.agent_workflow_engine.set_training_step(trainer_state.global_step, mode="train", epoch=trainer_state.epoch) + # TODO(kylemontgomery1): episode generation should be backend-agnostic # stage 1: generate episodes (async) and collect metrics (sync) trainer_state.episodes = await self.backend.generate_episodes(batch, agent_workflow_engine=self.agent_workflow_engine, is_validation=False) if not trainer_state.has_episodes: @@ -413,6 +452,7 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N await self.backend.process_backend_batch(trainer_state) assert trainer_state.has_backend_batch, "Backend batch is not transformed or processed successfully" + # TODO(kylemontgomery1): compute advantages should be backend-agnostic # stage 6: compute advantages (async) await self.backend.compute_advantages(trainer_state, self.algorithm_config) @@ -435,6 +475,262 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N for r in TerminationReason: trainer_state.metrics[f"batch/termination_reason/{r.value}"] = termination_counts[r.value] / total_counts + # 
========================================================================= + # Fully-asynchronous training pipeline + # ========================================================================= + + async def _fit_fully_async(self, trainer_state: TrainerState) -> None: + """Fully-async generation + training with group-level streaming.""" + assert self.config.data.train_batch_size == 1, f"Async training requires train_batch_size=1, got {self.config.data.train_batch_size}" + assert not getattr(self.agent_workflow_engine, "raise_on_error", False), "Async training requires raise_on_error=False so that process_task_with_retry always returns an episode" + coord_config = SyncCoordinatorConfig( + mini_batch_size=self.async_config.mini_batch_size, + group_size=self.rllm_config.rollout.n, + staleness_threshold=self.async_config.staleness_threshold, + trigger_parameter_sync_step=self.async_config.trigger_parameter_sync_step, + ) + coordinator = SyncCoordinator(coord_config) + aggregator = MetricsAggregator() + buffer = TrajectoryGroupBuffer( + group_size=self.rllm_config.rollout.n, + coordinator=coordinator, + aggregator=aggregator, + algorithm_config=self.algorithm_config, + transform_config=self.transform_config, + cf_config=self.cf_config, + rs_config=self.rs_config, + episode_offload_dir=self.async_config.episode_offload_dir, + trajectory_group_offload_dir=self.async_config.trajectory_group_offload_dir, + ) + + # Compute total_steps for LR scheduling + train_dataloader = self.backend.get_dataloader(self.train_dataset, trainer_state) + use_total_batches = self.rllm_config.trainer.get("total_batches", -1) > 0 + if use_total_batches: + trainer_state.total_steps = self.rllm_config.trainer.total_batches + else: + trainer_state.total_steps = len(train_dataloader) * self.rllm_config.trainer.total_epochs + + total_tasks = len(train_dataloader) * self.rllm_config.trainer.total_epochs + pbar = tqdm(total=total_tasks, desc="Tasks", unit="task") + buffer._pbar = pbar + + try: + gen_task = asyncio.create_task(self._generation_loop(trainer_state, buffer, coordinator)) + await self._training_loop(trainer_state, buffer, coordinator, aggregator) + if not gen_task.done(): + gen_task.cancel() + try: + await gen_task + except asyncio.CancelledError: + pass + finally: + pbar.close() + + async def _generation_loop( + self, + trainer_state: TrainerState, + buffer: TrajectoryGroupBuffer, + coordinator: SyncCoordinator, + ) -> None: + """Generate episodes and stream to TrajectoryGroupBuffer.""" + group_size = self.rllm_config.rollout.n + + try: + for epoch in range(self.rllm_config.trainer.total_epochs): + await self.backend.on_epoch_start(trainer_state) + train_dataloader = self.backend.get_dataloader(self.train_dataset, trainer_state) + self.agent_workflow_engine.set_training_step(trainer_state.global_step, mode="train", epoch=epoch) + + for batch in train_dataloader: + task = batch[0] + + await coordinator.wait_for_generation_allowed() + if not coordinator.has_quota(): + await coordinator.wait_for_throttle() + coordinator.on_group_dispatched() + + task_id = str(uuid.uuid4()) + for rollout_idx in range(group_size): + + async def _run_rollout(t=task, tid=task_id, ridx=rollout_idx): + _, _, _, episode = await self.agent_workflow_engine.process_task_with_retry(task=t, task_id=tid, rollout_idx=ridx, result_idx=0) + await buffer.add_episode(tid, episode) + + t = asyncio.create_task(_run_rollout()) + coordinator.track_task(t) + + await self.backend.on_epoch_end(trainer_state) + + await coordinator.wait_for_drain() + finally: + 
buffer.mark_generation_complete() + + async def _training_loop( + self, + trainer_state: TrainerState, + buffer: TrajectoryGroupBuffer, + coordinator: SyncCoordinator, + aggregator: MetricsAggregator, + ) -> None: + """Consume task batches from buffer, run forward-backward + optimizer step.""" + mini_batch_size = self.async_config.mini_batch_size + fwd_bwd_group_size = self.async_config.fwd_bwd_group_size + num_fwd_bwd_passes = mini_batch_size // fwd_bwd_group_size + use_total_batches = self.rllm_config.trainer.get("total_batches", -1) > 0 + rollout_engine = getattr(self.agent_workflow_engine, "rollout_engine", None) + + while True: + trainer_state.reset_batch() + step_start = time.perf_counter() + weight_versions = [] + all_trajectory_groups: list[TrajectoryGroup] = [] + all_episodes: list[Episode] = [] + groups_consumed = 0 + buffer_wait_time = 0.0 + done = False + + buffered = buffer._queue.qsize() + logger.info( + f"[TrainingLoop] Step {trainer_state.global_step}: waiting for {mini_batch_size} task batches ({num_fwd_bwd_passes} fwd-bwd passes x {fwd_bwd_group_size} groups), {buffered} buffered" + ) + + # 1. Pull mini_batch_size task batches total, split into + # num_fwd_bwd_passes forward-backward passes of fwd_bwd_group_size each. + for pass_idx in range(num_fwd_bwd_passes): + chunk_groups: list[TrajectoryGroup] = [] + + for _ in range(fwd_bwd_group_size): + t_wait = time.perf_counter() + task_batch = await buffer.get() + buffer_wait_time += time.perf_counter() - t_wait + if task_batch is None: + done = True + break + + coordinator.on_group_consumed() + groups_consumed += 1 + + for group in task_batch.groups: + weight_versions.append(group.weight_version) + chunk_groups.extend(task_batch.groups) + all_trajectory_groups.extend(task_batch.groups) + all_episodes.extend(task_batch.episodes) + + if not chunk_groups or done: + break + + # Forward-backward on this chunk + trainer_state.trajectory_groups = chunk_groups + + if trainer_state.has_trajectory_groups: + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: fwd-bwd pass {pass_idx + 1}/{num_fwd_bwd_passes} ({len(chunk_groups)} groups)") + await self.backend.on_batch_start(trainer_state) + trainer_state.backend_batch = self.backend.transform_to_backend_batch(trainer_state) + await self.backend.process_backend_batch(trainer_state) + + # Drain per-chunk backend metrics into aggregator + aggregator.record_dict(trainer_state.metrics) + trainer_state.metrics = {} + + # Only run optimizer step on a full batch + if groups_consumed < mini_batch_size: + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: incomplete batch ({groups_consumed}/{mini_batch_size}), stopping") + break + + # 2. Optimizer step + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: optimizer step") + await self.backend.update_policy(trainer_state) + + # 3. Capture pre-sync metrics (before weight sync resets coordinator state) + staleness_values = [coordinator.weight_version - v for v in weight_versions] + aggregator.record("async/staleness_mean", float(np.mean(staleness_values))) + aggregator.record("async/staleness_min", float(np.min(staleness_values))) + aggregator.record("async/staleness_max", float(np.max(staleness_values))) + aggregator.record("async/groups_consumed", groups_consumed) + aggregator.record("time/buffer_wait", buffer_wait_time) + pre_sync_coordinator_stats = coordinator.stats() + pre_sync_buffer_stats = buffer.stats() + + # 4. 
Weight sync + coordinator.on_training_step_complete() + sync_time = 0.0 + if coordinator.should_sync(): + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: triggering weight sync") + t0 = time.perf_counter() + await self._perform_weight_sync(trainer_state, coordinator, rollout_engine) + sync_time = time.perf_counter() - t0 + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: weight sync complete ({sync_time:.2f}s)") + if sync_time > 0: + aggregator.record("time/weight_sync", sync_time) + aggregator.record("time/step", time.perf_counter() - step_start) + + # Set all trajectory groups and stripped episodes for visualization/logging + trainer_state.trajectory_groups = all_trajectory_groups + trainer_state.episodes = all_episodes + + if self.tokenizer is not None and trainer_state.has_trajectory_groups: + visualize_trajectory_last_steps( + trainer_state.trajectory_groups, + tokenizer=self.tokenizer, + max_steps_to_visualize=2, + show_workflow_metadata=True, + ) + + # 5. Flush aggregator and merge pre-sync snapshots into trainer_state.metrics + trainer_state.metrics.update(aggregator.flush()) + trainer_state.metrics.update(pre_sync_buffer_stats) + trainer_state.metrics.update(pre_sync_coordinator_stats) + + # 6. Compute derived metrics + step_time = trainer_state.metrics.get("time/step", 1.0) + trainer_state.metrics["async/trainer_idle_ratio"] = buffer_wait_time / max(step_time, 1e-9) + + # 7. on_batch_end writes backend metrics (progress, optim, timing) + await self.backend.on_batch_end(trainer_state) + + # 8. Print and log + print_metrics_table(trainer_state.metrics, trainer_state.global_step) + self.logger.log( + data=trainer_state.metrics, + step=trainer_state.global_step, + episodes=trainer_state.episodes, + trajectory_groups=trainer_state.trajectory_groups, + ) + + # Periodic validation + if self.rllm_config.trainer.test_freq > 0 and trainer_state.global_step % self.rllm_config.trainer.test_freq == 0: + await self._validate_async_with_pause(trainer_state, coordinator) + + trainer_state.global_step += 1 + + if use_total_batches and trainer_state.global_step >= self.rllm_config.trainer.total_batches: + break + + async def _perform_weight_sync(self, trainer_state: TrainerState, coordinator: SyncCoordinator, rollout_engine: RolloutEngine | None) -> None: + """Synchronize weights between training and rollout engines.""" + if not self.async_config.partial_rollout: + coordinator.pause_generation() + await coordinator.wait_for_drain() + + trainer_state.weight_version = coordinator.weight_version + 1 + await self.backend.on_policy_updated(trainer_state) + if rollout_engine is not None: + rollout_engine.weight_version = trainer_state.weight_version + coordinator.on_sync_complete() + + if not self.async_config.partial_rollout: + coordinator.resume_generation() + + async def _validate_async_with_pause(self, trainer_state: TrainerState, coordinator: SyncCoordinator) -> dict: + """Validation with dispatch-level pause.
Waits for workflows to drain, then runs validation.""" + coordinator.pause_generation() + await coordinator.wait_for_drain() + try: + return await self._validate_async(trainer_state) + finally: + coordinator.resume_generation() + async def _validate_async(self, trainer_state: TrainerState) -> dict: """Validate the model (async implementation).""" n_val_samples = self.rllm_config.rollout.n_val @@ -453,7 +749,7 @@ async def _validate_async(self, trainer_state: TrainerState) -> dict: for batch in val_dataloader: # Generate episodes and transform to trajectory groups val_episodes = await self.backend.generate_episodes(batch, agent_workflow_engine=self.agent_workflow_engine, is_validation=True) - val_trajectory_groups, transform_metrics = transform_episodes_to_trajectory_groups(val_episodes, self.transform_config, self.cf_config, traj_grouping_hook=self.traj_grouping_hook) + val_trajectory_groups, _ = transform_episodes_to_trajectory_groups(val_episodes, self.transform_config, self.cf_config, traj_grouping_hook=self.traj_grouping_hook) reward_metrics = collect_reward_and_advantage_from_trajectory_groups(val_trajectory_groups, self.algorithm_config, collect_advantage=False) is_correct_lst.extend([episode.is_correct for episode in val_episodes]) @@ -466,7 +762,7 @@ async def _validate_async(self, trainer_state: TrainerState) -> dict: for key, value in episode.metrics.items(): workflow_metrics_by_source[data_source][key].append(float(value)) - for key, value in (transform_metrics | reward_metrics).items(): + for key, value in reward_metrics.items(): val_metrics[f"val/{key}"].append(value) test_end = time.perf_counter() @@ -496,6 +792,7 @@ async def _validate_async(self, trainer_state: TrainerState) -> dict: # post-process the val metrics to reduce any "list values" into scalars reduce_metrics_lists(val_metrics) + print_metrics_table(val_metrics, trainer_state.global_step, title="Validation") self.logger.log(data=val_metrics, step=trainer_state.global_step) await self.backend.on_validation_end(trainer_state) return val_metrics diff --git a/rllm/experimental/verl/patch.py b/rllm/experimental/verl/patch.py index 81183d7b8..cc74c9f71 100644 --- a/rllm/experimental/verl/patch.py +++ b/rllm/experimental/verl/patch.py @@ -10,48 +10,10 @@ logger = logging.getLogger(__name__) -_VERL_ACTOR_PATCHED = False _VERL_DYNAMIC_BATCH_PATCHED = False _VLLM_SDK_PATCHED = False -# --------------------------------------------------------------------------- -# Verl actor: per-call policy loss mode override -# --------------------------------------------------------------------------- - - -def patch_verl_actor_for_loss_override() -> None: - """Patch ``DataParallelPPOActor.update_policy`` to support per-call loss mode. - - When ``data.meta_info`` contains ``"policy_loss_mode_override"``, the - actor temporarily uses that loss mode instead of the one baked into - ``self.config.policy_loss.loss_mode``. The original config value is - restored after the call (even on exception). 
- """ - global _VERL_ACTOR_PATCHED - if _VERL_ACTOR_PATCHED: - return - - from verl.workers.actor.dp_actor import DataParallelPPOActor - - _original_update_policy = DataParallelPPOActor.update_policy - - def _patched_update_policy(self, data): - override = data.meta_info.get("policy_loss_mode_override") - if override is not None: - original = self.config.policy_loss.get("loss_mode", "vanilla") - self.config.policy_loss["loss_mode"] = override - try: - return _original_update_policy(self, data) - finally: - self.config.policy_loss["loss_mode"] = original - return _original_update_policy(self, data) - - DataParallelPPOActor.update_policy = _patched_update_policy - _VERL_ACTOR_PATCHED = True - logger.info("Patched DataParallelPPOActor.update_policy for per-call loss mode override") - - # --------------------------------------------------------------------------- # Verl dynamic batch: sync micro-batch counts across DP ranks # --------------------------------------------------------------------------- diff --git a/rllm/experimental/verl/verl_advantage.py b/rllm/experimental/verl/verl_advantage.py index e6a092dc9..5c8ca2cc3 100644 --- a/rllm/experimental/verl/verl_advantage.py +++ b/rllm/experimental/verl/verl_advantage.py @@ -31,11 +31,8 @@ def compute_advantage_verl(batch: DataProto, config: DictConfig) -> tuple[DataPr is_last_step = batch.non_tensor_batch["is_last_step"] last_step_indices = np.where(is_last_step)[0] not_last_step_indices = np.where(~is_last_step)[0] - non_last_step_batch = batch.select_idxs(not_last_step_indices) - batch = batch.select_idxs(last_step_indices) - batch = compute_advantage( - batch, + adv_kwargs = dict( adv_estimator=config.algorithm.adv_estimator, gamma=config.algorithm.gamma, lam=config.algorithm.lam, @@ -44,6 +41,17 @@ def compute_advantage_verl(batch: DataProto, config: DictConfig) -> tuple[DataPr config=config.algorithm, ) + if len(not_last_step_indices) == 0: + # All steps are last steps (e.g. 
single-step trajectories) — compute directly, no broadcast needed + batch = compute_advantage(batch, **adv_kwargs) + return batch, metrics + + # Multi-step: split by last step, compute advantages on last steps, broadcast to earlier steps + non_last_step_batch = batch.select_idxs(not_last_step_indices) + batch = batch.select_idxs(last_step_indices) + + batch = compute_advantage(batch, **adv_kwargs) + _stepwise_advantage_broadcast(batch, non_last_step_batch, config) batch = DataProto.concat([batch, non_last_step_batch]) @@ -73,7 +81,9 @@ def _stepwise_advantage_broadcast(last_step_batch: DataProto, non_last_step_batc traj_ep_to_scalar_adv[(traj_id, eps_id)] = scalar - scalar_rows = torch.stack([torch.full_like(tgt_mask[i], fill_value=traj_ep_to_scalar_adv[(traj_id, eps_id)], dtype=torch.float32) for i, (traj_id, eps_id) in enumerate(zip(tgt_traj_ids, tgt_eps_ids, strict=False))]) + scalar_rows = torch.stack( + [torch.full_like(tgt_mask[i], fill_value=traj_ep_to_scalar_adv[(traj_id, eps_id)], dtype=torch.float32) for i, (traj_id, eps_id) in enumerate(zip(tgt_traj_ids, tgt_eps_ids, strict=False))] + ) final_advantage = scalar_rows * tgt_mask non_last_step_batch.batch["advantages"] = final_advantage diff --git a/rllm/experimental/verl/verl_backend.py b/rllm/experimental/verl/verl_backend.py index 8050974cd..06bbd49e7 100644 --- a/rllm/experimental/verl/verl_backend.py +++ b/rllm/experimental/verl/verl_backend.py @@ -9,6 +9,7 @@ import math import uuid +from collections import defaultdict from collections.abc import Iterable from functools import reduce from typing import TYPE_CHECKING, Any @@ -26,7 +27,9 @@ ) from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager from verl.trainer.ppo.utils import Role, WorkerType +from verl.utils import tensordict_utils as tu from verl.utils.metric import reduce_metrics +from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding from rllm.agents.agent import Episode from rllm.data import Dataset @@ -51,6 +54,36 @@ _VERL_KNOWN_LOSSES: set[str] | None = None +class CustomPPOLoss: + """Wraps Verl's ``ppo_loss`` to support per-call loss mode override. + + When the data TensorDict contains ``policy_loss_mode_override``, + the loss mode is temporarily overridden for that call. Instances + are serialised via cloudpickle and sent to remote workers through + Verl's ``set_loss_fn`` RPC. + """ + + def __init__(self, config): + # Convert OmegaConf DictConfig → ActorConfig dataclass + from verl.utils.config import omega_conf_to_dataclass + + self.config = omega_conf_to_dataclass(config) + + def __call__(self, model_output, data, dp_group=None): + from verl.utils import tensordict_utils as _tu + from verl.workers.utils.losses import ppo_loss + + override = _tu.get(data, "policy_loss_mode_override", default=None) + if override is not None: + original = self.config.policy_loss.get("loss_mode", "vanilla") + self.config.policy_loss["loss_mode"] = override + try: + return ppo_loss(self.config, model_output, data, dp_group) + finally: + self.config.policy_loss["loss_mode"] = original + return ppo_loss(self.config, model_output, data, dp_group) + + def _get_verl_known_losses() -> set[str]: """Lazily load the set of registered Verl policy loss function names.""" global _VERL_KNOWN_LOSSES @@ -135,9 +168,8 @@ def init_rollout_engine(self, **kwargs) -> RolloutEngine: VerlEngine: The initialized rollout engine. 
""" # Apply Verl patches - from rllm.experimental.verl.patch import patch_verl_actor_for_loss_override, patch_verl_dynamic_batch_sync + from rllm.experimental.verl.patch import patch_verl_dynamic_batch_sync - patch_verl_actor_for_loss_override() patch_verl_dynamic_batch_sync() # If SDK is enabled, instrument vLLM replicas before creating workers @@ -150,7 +182,10 @@ def init_rollout_engine(self, **kwargs) -> RolloutEngine: assert self.async_rollout_manager is not None, "async_rollout_manager is not available. Issues with RayPPOTrainer's `init_workers()` function." - # Step 2: initialize the rollout engine + # Step 2: replace loss function on remote workers to support per-role loss override + self.actor_rollout_wg.set_loss_fn(CustomPPOLoss(self.config.actor_rollout_ref.actor)) + + # Step 3: initialize the rollout engine self.rollout_engine = VerlEngine( config=self.config, rollout_manager=self.async_rollout_manager, @@ -158,7 +193,7 @@ def init_rollout_engine(self, **kwargs) -> RolloutEngine: processor=self.processor, ) - # Step 3: store the algorithm config + # Step 4: store the algorithm config self.algorithm_config = kwargs.get("algorithm_config") return self.rollout_engine @@ -173,6 +208,14 @@ def validate_config(self) -> None: """Validate verl-specific configuration settings.""" assert self.config.actor_rollout_ref.rollout.mode == "async", "Only async rollout mode is supported for VerlBackend" assert self.use_rm is False, "Reward models are not supported. Rewards should be assigned using a reward function in the workflow or environment." + # Enforce new EngineWorker path (TensorDict + no-padding) + legacy_mode = self.config.trainer.get("use_legacy_worker_impl", "auto") + if legacy_mode != "disable": # force to disable legacy worker impl + logger.warning( + "VerlBackend forces use_legacy_worker_impl='disable' (new EngineWorker path), got '{legacy_mode}'." + "If you insist on using the legacy worker implementation, consider using the older agent workflow trainer." + ) + self.config.trainer.use_legacy_worker_impl = "disable" if self.config.rllm.stepwise_advantage.mode != "broadcast": # automatically set the stepwise_advantage_mode to "broadcast", the warning is already shown in AlgorithmConfig.from_config self.config.rllm.stepwise_advantage.mode = "broadcast" @@ -290,41 +333,13 @@ def _pad_dataproto_to_world_size(self, batch: DataProto) -> DataProto: batch.non_tensor_batch["is_valid"][pad_start:pad_end] = False return batch - def _pad_dataproto_for_megatron_training(self, batch: DataProto) -> DataProto: - """Pad batch for megatron actor update using uniform random sampling. - - Megatron's make_minibatch_iterator requires per-GPU batch to be divisible - by ppo_mini_batch_size. Padded samples are randomly sampled real data that - participate in training as redundant samples from the same distribution. 
- """ - - world_size = self._get_dp_world_size() - if world_size is None: - return batch - - ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size - rollout_n = self.config.actor_rollout_ref.rollout.n - divisor = math.lcm(world_size, ppo_mini_batch_size * rollout_n) - - batch = self._remove_padding(batch) - original_batch_size = batch.batch["prompts"].shape[0] - - pad_size = (-original_batch_size) % divisor - if pad_size > 0: - pad_indices = np.random.choice(original_batch_size, size=pad_size, replace=True) - pad_batch = batch.select_idxs(pad_indices) - batch = DataProto.concat([batch, pad_batch]) - # Deliberately skip setting is_pad_step/is_last_step/is_valid on padded rows. - # This method is only called in update_policy (the last pipeline stage before - # actor update), so _remove_padding is never called on this output. The padded - # rows are real duplicates that participate in training with real advantages. - return batch - async def process_backend_batch(self, trainer_state: TrainerState, **kwargs) -> None: """Compute step-level values: old_log_probs, ref_log_probs, critic values. - Reuses logic from AgentWorkflowPPOTrainer._compute_step_level_values. - Note: This is async for protocol compatibility but operations are sync (blocking) + Uses the new EngineWorker path: converts DataProto to TensorDict in + no-padding format, calls workers, converts results back to padded + DataProto. The no-padding TensorDict (batch_td) is created once and + reused across all inference worker calls. """ metrics = trainer_state.metrics timing_dict = trainer_state.timing_dict @@ -339,29 +354,42 @@ async def process_backend_batch(self, trainer_state: TrainerState, **kwargs) -> batch = self._pad_dataproto_to_world_size(batch=batch) self._balance_batch(batch, metrics=metrics) + # Set meta_info needed by workers batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() - # get images_seqlens - if "multi_modal_inputs" in batch.non_tensor_batch.keys(): + batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature + if "multi_modal_inputs" in batch.non_tensor_batch: images_seqlens_all = [] for multi_modal_input in batch.non_tensor_batch["multi_modal_inputs"]: - if "image_grid_thw" not in multi_modal_input.keys(): + if "image_grid_thw" not in multi_modal_input: continue images_seqlens_all.extend(multi_modal_input["images_seqlens"].tolist()) batch.meta_info["images_seqlens"] = images_seqlens_all + # Convert to TensorDict + no-padding ONCE — reused for all inference calls. + # to_tensordict() does NOT mutate the original DataProto. + # left_right_2_no_padding mutates batch_td in-place. 
+ batch_td = batch.to_tensordict() + batch_td = left_right_2_no_padding(batch_td) + + # --- Compute old_log_probs --- with simple_timer("old_log_probs", timing_dict): - # Compute old_log_probs from actor - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - entropys = old_log_prob.batch["entropys"] + tu.assign_non_tensor(batch_td, calculate_entropy=True, compute_loss=False) + output = self.actor_rollout_wg.compute_log_prob(batch_td) + log_probs = no_padding_2_padding(tu.get(output, "log_probs"), batch_td) + entropy = no_padding_2_padding(tu.get(output, "entropy"), batch_td) + + # Entropy metric (for logging only) response_masks = batch.batch["response_mask"] loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + entropy_agg = agg_loss(loss_mat=entropy, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) metrics["actor/entropy"] = entropy_agg.detach().item() - old_log_prob.batch.pop("entropys") + + # Merge old_log_probs back into the padded DataProto + old_log_prob = DataProto.from_tensordict(tu.get_tensordict({"old_log_probs": log_probs.float()})) batch = batch.union(old_log_prob) # Compute rollout log prob diff if available - if "rollout_log_probs" in batch.batch.keys(): + if "rollout_log_probs" in batch.batch: rollout_old_log_probs = batch.batch["rollout_log_probs"] actor_old_log_probs = batch.batch["old_log_probs"] attention_mask = batch.batch["attention_mask"] @@ -381,19 +409,27 @@ async def process_backend_batch(self, trainer_state: TrainerState, **kwargs) -> } metrics.update(rollout_probs_diff_metrics) - # Compute reference log_probs if using reference policy + # --- Compute reference log_probs (reuse batch_td) --- if self.use_reference_policy: with simple_timer("ref", timing_dict): + tu.assign_non_tensor(batch_td, calculate_entropy=False, compute_loss=False) if not self.ref_in_actor: - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + ref_output = self.ref_policy_wg.compute_ref_log_prob(batch_td) else: - ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + tu.assign_non_tensor(batch_td, no_lora_adapter=True) + ref_output = self.actor_rollout_wg.compute_log_prob(batch_td) + ref_lp = no_padding_2_padding(tu.get(ref_output, "log_probs"), batch_td) + ref_log_prob = DataProto.from_tensordict(tu.get_tensordict({"ref_log_prob": ref_lp.float()})) batch = batch.union(ref_log_prob) - # Compute critic values if using critic + # --- Compute critic values --- if self.use_critic: with simple_timer("values", timing_dict): - values = self.critic_wg.compute_values(batch) + tu.assign_non_tensor(batch_td, compute_loss=False) + values_output = self.critic_wg.infer_batch(batch_td) + values_output = values_output.get() # blocking await on future + values_tensor = no_padding_2_padding(tu.get(values_output, "values"), batch_td) + values = DataProto.from_tensordict(tu.get_tensordict({"values": values_tensor.float()})) batch = batch.union(values) # Mask truncated samples if configured @@ -426,55 +462,77 @@ async def compute_advantages(self, trainer_state: TrainerState, algorithm_config async def update_policy(self, trainer_state: TrainerState, **kwargs) -> None: """Update actor and critic policies. - Note: This is async for protocol compatibility but operations are sync (blocking) + Uses the new EngineWorker path: converts DataProto to TensorDict in + no-padding format with training metadata, then calls workers. 
The new + workers handle micro-batching internally, so no manual re-padding is + needed before the update. """ global_steps = trainer_state.global_step - batch = trainer_state.backend_batch - - # Re-pad batch before gradient updates. For megatron, use the larger - # divisor (world_size * ppo_mini_batch_size) with uniform random sampling. - actor_strategy = self.config.actor_rollout_ref.actor.get("strategy", None) - if actor_strategy == "megatron": - batch = self._pad_dataproto_for_megatron_training(batch) - else: - batch = self._pad_dataproto_to_world_size(batch) - trainer_state.backend_batch = batch + batch: DataProto = trainer_state.backend_batch # type: ignore[assignment] # Update critic - # NOTE: The megatron-padded batch (with duplicated samples) is also used for - # the critic update. This is acceptable because: (1) GRPO disables the critic, - # and (2) if a megatron critic is used, it needs the same divisibility and the - # duplicated samples are real data from the same distribution. If the critic has - # a different ppo_mini_batch_size, the divisor may need to account for it. if self.use_critic: with simple_timer("update_critic", trainer_state.timing_dict): - critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) - trainer_state.metrics.update(critic_output_metrics) + critic_td = batch.to_tensordict() + critic_td = left_right_2_no_padding(critic_td) + ppo_mbs_critic = self.config.critic.ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n + tu.assign_non_tensor( + critic_td, + global_batch_size=ppo_mbs_critic, + mini_batch_size=ppo_mbs_critic, + epochs=self.config.critic.ppo_epochs, + seed=self.config.critic.data_loader_seed, + dataloader_kwargs={"shuffle": self.config.critic.shuffle}, + ) + critic_output = self.critic_wg.train_mini_batch(critic_td) + critic_output = critic_output.get() + critic_output_metrics = tu.get(critic_output, "metrics") + trainer_state.metrics.update(reduce_metrics(critic_output_metrics)) # Update actor (after critic warmup) if self.config.trainer.get("critic_warmup", 0) <= global_steps: with simple_timer("update_actor", trainer_state.timing_dict): self._update_actor_with_loss_routing(batch, trainer_state) - def _update_actor_with_loss_routing(self, batch, trainer_state: TrainerState) -> None: + def _update_actor_with_loss_routing(self, batch: DataProto, trainer_state: TrainerState) -> None: """Update actor with per-loss-group splitting when ``loss_fn_map`` is set. Roles that share the same policy loss function are grouped together into a single ``update_actor`` call, minimising the number of - optimiser steps. + optimiser steps. Each (sub-)batch is converted to TensorDict + + no-padding format with training metadata before being sent to the + worker. 
""" - from collections import defaultdict - - import numpy as np - loss_fn_map = self.algorithm_config.loss_fn_map if self.algorithm_config is not None else {} group_roles = batch.non_tensor_batch.get("group_roles") if hasattr(batch, "non_tensor_batch") and batch.non_tensor_batch is not None else None + # Common training metadata + rollout_n = self.config.actor_rollout_ref.rollout.n + actor_cfg = self.config.actor_rollout_ref.actor + ppo_mbs = actor_cfg.ppo_mini_batch_size * rollout_n + + def _send_actor_update(sub_batch: DataProto, loss_override: str | None = None) -> None: + """Convert DataProto to TensorDict, inject metadata, send to worker.""" + batch_td = sub_batch.to_tensordict() + batch_td = left_right_2_no_padding(batch_td) + metadata: dict[str, Any] = dict( + calculate_entropy=(actor_cfg.entropy_coeff != 0.0), + global_batch_size=ppo_mbs, + mini_batch_size=ppo_mbs, + epochs=actor_cfg.ppo_epochs, + seed=actor_cfg.data_loader_seed, + dataloader_kwargs={"shuffle": actor_cfg.shuffle}, + ) + if loss_override is not None: + metadata["policy_loss_mode_override"] = loss_override + tu.assign_non_tensor(batch_td, **metadata) + actor_output = self.actor_rollout_wg.update_actor(batch_td) + actor_metrics = tu.get(actor_output, "metrics") + trainer_state.metrics.update(reduce_metrics(actor_metrics)) + # Fast path: no per-role loss overrides or no role annotations. if not loss_fn_map or group_roles is None: - actor_output = self.actor_rollout_wg.update_actor(batch) - trainer_state.metrics.update(reduce_metrics(actor_output.meta_info["metrics"])) + _send_actor_update(batch) return # Resolve each role to a Verl loss name with validation + fallback. @@ -494,10 +552,7 @@ def _update_actor_with_loss_routing(self, batch, trainer_state: TrainerState) -> if len(loss_to_roles) <= 1: # All roles share the same loss — single update. - loss_name = next(iter(loss_to_roles)) - batch.meta_info["policy_loss_mode_override"] = loss_name - actor_output = self.actor_rollout_wg.update_actor(batch) - trainer_state.metrics.update(reduce_metrics(actor_output.meta_info["metrics"])) + _send_actor_update(batch, next(iter(loss_to_roles))) return # Multiple distinct losses: split batch by loss group, update each. 
@@ -506,9 +561,7 @@ def _update_actor_with_loss_routing(self, batch, trainer_state: TrainerState) -> mask = np.array([r in role_set for r in group_roles]) indices = np.where(mask)[0] sub_batch = batch[indices] - sub_batch.meta_info["policy_loss_mode_override"] = loss_name - actor_output = self.actor_rollout_wg.update_actor(sub_batch) - trainer_state.metrics.update(reduce_metrics(actor_output.meta_info["metrics"])) + _send_actor_update(sub_batch, loss_name) def shutdown(self) -> None: """Placeholder, just use the BackendProtocol's default shutdown method.""" @@ -569,10 +622,10 @@ async def on_validation_start(self, trainer_state: TrainerState) -> bool: return False else: trainer_state.is_training = False - self.rollout_engine.validate = True # type: ignore[attr-defined] + self.rollout_engine.is_validation = True return True async def on_validation_end(self, trainer_state: TrainerState) -> None: """Called at the end of validation.""" trainer_state.is_training = True - self.rollout_engine.validate = False # type: ignore[attr-defined] + self.rollout_engine.is_validation = False diff --git a/rllm/patches/vllm_instrumentation.py b/rllm/patches/vllm_instrumentation.py index 3e97f6291..e407aad55 100644 --- a/rllm/patches/vllm_instrumentation.py +++ b/rllm/patches/vllm_instrumentation.py @@ -189,7 +189,10 @@ async def _generate_interceptor(): response_logprobs_lists = [] for output in res.outputs: if output.logprobs: # list[LogprobsOnePosition]: list of dict[int, Logprob] - curr_log_probs = [output.logprobs[i][token_id].logprob if i < len(output.logprobs) and token_id in output.logprobs[i] else float("-inf") for i, token_id in enumerate(output.token_ids)] + curr_log_probs = [ + output.logprobs[i][token_id].logprob if i < len(output.logprobs) and token_id in output.logprobs[i] else float("-inf") + for i, token_id in enumerate(output.token_ids) + ] response_logprobs_lists.append(curr_log_probs) else: response_logprobs_lists.append([]) diff --git a/rllm/rewards/code_reward.py b/rllm/rewards/code_reward.py index 4d9845a62..28e4807f5 100644 --- a/rllm/rewards/code_reward.py +++ b/rllm/rewards/code_reward.py @@ -19,7 +19,6 @@ # from rllm.rewards.code_utils.code_contests import run_test as code_contests_run_test from rllm.rewards.code_utils.livecodebench import run_test as lcb_run_test -from rllm.rewards.code_utils.taco import run_test as taco_run_test from rllm.rewards.reward_types import RewardConfig, RewardOutput, RewardType from rllm.tools.code_tools.code_tool import CodeTool from rllm.tools.code_tools.together_tool import TogetherCodeTool @@ -175,6 +174,8 @@ def postprocess_lcb_sample(sample): # https://huggingface.co/datasets/PrimeIntellect/verifiable-coding-problems def primeintellect_check_correctness(tests, code, use_tci=False): + from rllm.rewards.code_utils.taco import run_test as taco_run_test + if isinstance(tests, str): try: tests = ast.literal_eval(tests) @@ -247,7 +248,17 @@ def lcb_check_correctness_v2(sample, generation, timeout=6, debug=False): # Create detailed test results in_outs = json.loads(sample["input_output"]) detailed_results["total_tests"] = len(result[0]) - detailed_results["test_results"] = [{"input": inp, "expected": out, "passed": res == True, "error": metadata_list[0].get("error", None), "error_message": metadata_list[0].get("error_message", None), "output": metadata_list[0].get("output", None)} for inp, out, res in zip(in_outs["inputs"], in_outs["outputs"], result[0], strict=False)] + detailed_results["test_results"] = [ + { + "input": inp, + "expected": out, + 
"passed": res == True, + "error": metadata_list[0].get("error", None), + "error_message": metadata_list[0].get("error_message", None), + "output": metadata_list[0].get("output", None), + } + for inp, out, res in zip(in_outs["inputs"], in_outs["outputs"], result[0], strict=False) + ] detailed_results["passed_tests"] = sum(1 for r in result[0] if r == True) detailed_results["all_passed"] = all(r == True for r in result[0]) diff --git a/rllm/rewards/code_utils/pyext2.py b/rllm/rewards/code_utils/pyext2.py index 5e33119f7..f950f9a58 100644 --- a/rllm/rewards/code_utils/pyext2.py +++ b/rllm/rewards/code_utils/pyext2.py @@ -2,7 +2,23 @@ __version__ = "0.7" -__all__ = ["overload", "RuntimeModule", "switch", "tail_recurse", "copyfunc", "set_docstring", "annotate", "safe_unpack", "modify_function", "assign", "fannotate", "compare_and_swap", "is_main", "call_if_main", "run_main"] +__all__ = [ + "overload", + "RuntimeModule", + "switch", + "tail_recurse", + "copyfunc", + "set_docstring", + "annotate", + "safe_unpack", + "modify_function", + "assign", + "fannotate", + "compare_and_swap", + "is_main", + "call_if_main", + "run_main", +] import inspect import sys diff --git a/rllm/rewards/code_utils/swebench.py b/rllm/rewards/code_utils/swebench.py index c0716e162..160d732a3 100644 --- a/rllm/rewards/code_utils/swebench.py +++ b/rllm/rewards/code_utils/swebench.py @@ -54,7 +54,21 @@ run_threadpool, ) -from rllm.globals import CACHE_LEVEL, CLEAN, FORCE_REBUILD, INSTANCE_IMAGE_TAG, MAX_WORKERS, MODEL_NAME_OR_PATH, NAMESPACE, OPEN_FILE_LIMIT, REPORT_DIR, REWRITE_REPORTS, SPLIT, SWEBENCH_DATASET_NAME, TIMEOUT +from rllm.globals import ( + CACHE_LEVEL, + CLEAN, + FORCE_REBUILD, + INSTANCE_IMAGE_TAG, + MAX_WORKERS, + MODEL_NAME_OR_PATH, + NAMESPACE, + OPEN_FILE_LIMIT, + REPORT_DIR, + REWRITE_REPORTS, + SPLIT, + SWEBENCH_DATASET_NAME, + TIMEOUT, +) GIT_APPLY_CMDS = [ "git apply --verbose", @@ -627,7 +641,23 @@ def swebench_check_correctness( run_id = uuid.uuid4().hex instance_ids = [instance_id] - eval_report_path = run_evaluation(SWEBENCH_DATASET_NAME, instance_ids, actions, MAX_WORKERS, FORCE_REBUILD, CACHE_LEVEL, CLEAN, OPEN_FILE_LIMIT, run_id, TIMEOUT, NAMESPACE, REWRITE_REPORTS, SPLIT, INSTANCE_IMAGE_TAG, REPORT_DIR) + eval_report_path = run_evaluation( + SWEBENCH_DATASET_NAME, + instance_ids, + actions, + MAX_WORKERS, + FORCE_REBUILD, + CACHE_LEVEL, + CLEAN, + OPEN_FILE_LIMIT, + run_id, + TIMEOUT, + NAMESPACE, + REWRITE_REPORTS, + SPLIT, + INSTANCE_IMAGE_TAG, + REPORT_DIR, + ) # read from eval report and get the correct/incorrect stats for reward calculation with open(eval_report_path) as f: diff --git a/rllm/rewards/code_utils/taco.py b/rllm/rewards/code_utils/taco.py index 62085d396..a1cdb0ffb 100644 --- a/rllm/rewards/code_utils/taco.py +++ b/rllm/rewards/code_utils/taco.py @@ -66,7 +66,10 @@ def timeout_handler(signum, frame): raise TimeoutException -signal.signal(signal.SIGALRM, timeout_handler) +try: + signal.signal(signal.SIGALRM, timeout_handler) +except ValueError: + pass # signal only works in main thread; skip in Ray workers TIMEOUT = 90 # seconds EXECUTION_RESULTS = {1: "passed", 0: "false", -1: "timeout", -2: "runtime_error", -3: "returncode:{code}", -4: "compile_error"} @@ -379,7 +382,15 @@ def execute_std_code(method, synthesized_code, inputs_list, outputs_list, timeou temp_file_name = temp_input.name stdout, stderr = "", "" try: - result = subprocess.run(["bash", "-c", "ulimit -v 10485760; python3 " + temp_program_path], stdin=temp_input, stdout=subprocess.PIPE, stderr=subprocess.PIPE, 
preexec_fn=os.setsid, timeout=timeout, text=True) + result = subprocess.run( + ["bash", "-c", "ulimit -v 10485760; python3 " + temp_program_path], + stdin=temp_input, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + preexec_fn=os.setsid, + timeout=timeout, + text=True, + ) stdout, stderr = result.stdout, result.stderr return_code = result.returncode diff --git a/rllm/rewards/countdown_reward.py b/rllm/rewards/countdown_reward.py index ccdf6157a..093d35f95 100644 --- a/rllm/rewards/countdown_reward.py +++ b/rllm/rewards/countdown_reward.py @@ -1,9 +1,12 @@ +import logging import random import re from rllm import Action from rllm.rewards.reward_types import RewardOutput +logger = logging.getLogger(__name__) + def extract_solution(solution_str): """Extract the equation from the solution string.""" @@ -72,20 +75,17 @@ def compute_score(solution_str, ground_truth, method="strict", format_score=0.1, do_print = random.randint(1, 64) == 1 if do_print: - print("--------------------------------") - print(f"Target: {target} | Numbers: {numbers}") - print(f"Extracted equation: {equation}") - print(f"Solution string: {solution_str}") + logger.debug(f"Target: {target} | Numbers: {numbers} | Equation: {equation} | Solution: {solution_str}") if equation is None: if do_print: - print("No equation found") + logger.debug("No equation found") return 0 # Validate equation uses correct numbers if not validate_equation(equation, numbers): if do_print: - print("Invalid equation") + logger.debug("Invalid equation") return format_score # Evaluate equation @@ -93,20 +93,20 @@ def compute_score(solution_str, ground_truth, method="strict", format_score=0.1, result = evaluate_equation(equation) if result is None: if do_print: - print("Could not evaluate equation") + logger.debug("Could not evaluate equation") return format_score if abs(result - target) < 1e-5: # Account for floating point precision if do_print: - print(f"Correct equation: {equation} = {result}") + logger.debug(f"Correct equation: {equation} = {result}") return score else: if do_print: - print(f"Wrong result: equation = {result}, target = {target}") + logger.debug(f"Wrong result: equation = {result}, target = {target}") return format_score except Exception: if do_print: - print("Error evaluating equation") + logger.debug("Error evaluating equation") return format_score diff --git a/rllm/rewards/math_reward.py b/rllm/rewards/math_reward.py index 4a4c17bbb..ee4ad0b75 100644 --- a/rllm/rewards/math_reward.py +++ b/rllm/rewards/math_reward.py @@ -130,7 +130,9 @@ def rllm_reward_fn_math(data_source: str, llm_solution: str, ground_truth: str | reward = RewardMathFn(RewardConfig()) task_info = { "data_source": "", - "problem": ("Let $P(x)=x^{4}+2 x^{3}-13 x^{2}-14 x+24$ be a polynomial with roots $r_{1}, r_{2}, r_{3}, r_{4}$. Let $Q$ be the quartic polynomial with roots $r_{1}^{2}, r_{2}^{2}, r_{3}^{2}, r_{4}^{2}$, such that the coefficient of the $x^{4}$ term of $Q$ is 1. Simplify the quotient $Q\\left(x^{2}\\right) / P(x)$, leaving your answer in terms of $x$. (You may assume that $x$ is not equal to any of $\\left.r_{1}, r_{2}, r_{3}, r_{4}\\right)$."), + "problem": ( + "Let $P(x)=x^{4}+2 x^{3}-13 x^{2}-14 x+24$ be a polynomial with roots $r_{1}, r_{2}, r_{3}, r_{4}$. Let $Q$ be the quartic polynomial with roots $r_{1}^{2}, r_{2}^{2}, r_{3}^{2}, r_{4}^{2}$, such that the coefficient of the $x^{4}$ term of $Q$ is 1. Simplify the quotient $Q\\left(x^{2}\\right) / P(x)$, leaving your answer in terms of $x$. 
(You may assume that $x$ is not equal to any of $\\left.r_{1}, r_{2}, r_{3}, r_{4}\\right)$." + ), "problem_type": RewardType.MATH, "ground_truth": ["10", "$x^{4}-2 x^{3}-13 x^{2}+14 x+24$"], "has_toolcall": True, diff --git a/rllm/sdk/proxy/proxy_manager.py b/rllm/sdk/proxy/proxy_manager.py index c51154227..13ae2e431 100644 --- a/rllm/sdk/proxy/proxy_manager.py +++ b/rllm/sdk/proxy/proxy_manager.py @@ -366,7 +366,12 @@ def _instrument_vllm_servers(self) -> None: print(f"[PROXY_MANAGER] Detailed instrumentation status: {status}") if support == "none": - logger.warning("vLLM < 0.10.2 detected, but VERL servers are already running. Token IDs will NOT be available! To enable token IDs, call instrument_vllm() BEFORE creating AgentLoopManager. See docs/howto/instrument_verl_vllm_for_token_ids.md for details.") + logger.warning( + "vLLM < 0.10.2 detected, but VERL servers are already running. " + "Token IDs will NOT be available! To enable token IDs, call " + "instrument_vllm() BEFORE creating AgentLoopManager. " + "See docs/howto/instrument_verl_vllm_for_token_ids.md for details." + ) elif support == "native": logger.info("vLLM >= 0.10.2 detected, token IDs available via native support") elif support == "instrumented": diff --git a/rllm/sdk/session/opentelemetry.py b/rllm/sdk/session/opentelemetry.py index 4657b896d..50f690809 100644 --- a/rllm/sdk/session/opentelemetry.py +++ b/rllm/sdk/session/opentelemetry.py @@ -105,7 +105,9 @@ def get_active_otel_session_uids() -> list[str]: def get_current_otel_session() -> OpenTelemetrySession: """Not implemented - use get_current_otel_metadata() or get_current_otel_session_name() instead.""" - raise NotImplementedError("get_current_otel_session() is not supported with baggage-based sessions. Use get_current_otel_metadata(), get_current_otel_session_name(), or get_active_otel_session_uids() instead.") + raise NotImplementedError( + "get_current_otel_session() is not supported with baggage-based sessions. Use get_current_otel_metadata(), get_current_otel_session_name(), or get_active_otel_session_uids() instead." + ) def otel_session(**kwargs: Any) -> OpenTelemetrySession: diff --git a/rllm/system_prompts.py b/rllm/system_prompts.py index f5d108675..aa4fe9eb3 100644 --- a/rllm/system_prompts.py +++ b/rllm/system_prompts.py @@ -269,7 +269,9 @@ - IF THE PROBLEM DOESNT HAVE MULTIPLE CHOICE, OUTPUT 'NO MULTIPLE CHOICE'.""" -LCB_SYSTEM_MESSAGE_GENERIC = "You are an expert Python programmer. You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests." +LCB_SYSTEM_MESSAGE_GENERIC = ( + "You are an expert Python programmer. You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests." +) LCB_FORMATTING_MESSAGE_WITH_STARTER_CODE = "You will use the following starter code to write the solution to the problem and enclose your code within delimiters." 
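A note on the string reflow pattern used throughout these hunks: Python concatenates adjacent string literals with no separator, so each fragment must carry an explicit trailing (or leading) space, or the rendered message runs its words together at the seams. A standalone illustration, not code from the repo:

```python
# Adjacent string literals are joined with no separator inserted.
broken = (
    "vLLM < 0.10.2 detected."
    "Token IDs will NOT be available!"
)
assert broken == "vLLM < 0.10.2 detected.Token IDs will NOT be available!"

# Each fragment must end (or begin) with an explicit space.
fixed = (
    "vLLM < 0.10.2 detected. "
    "Token IDs will NOT be available!"
)
assert fixed == "vLLM < 0.10.2 detected. Token IDs will NOT be available!"
```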
diff --git a/rllm/tools/__init__.py b/rllm/tools/__init__.py index 193687d71..b647f3cf7 100644 --- a/rllm/tools/__init__.py +++ b/rllm/tools/__init__.py @@ -22,4 +22,4 @@ tool_registry = ToolRegistry() tool_registry.register_all(DEFAULT_TOOLS) -__all__ = ["PythonInterpreter", "LocalRetrievalTool", "GoogleSearchTool", "FirecrawlTool", "TavilyExtractTool", "TavilySearchTool", "ToolRegistry", "tool_registry"] +__all__ = ["PythonInterpreter", "GoogleSearchTool", "FirecrawlTool", "TavilyExtractTool", "TavilySearchTool", "ToolRegistry", "tool_registry"] diff --git a/rllm/tools/code_tools/__init__.py b/rllm/tools/code_tools/__init__.py index b45a44e50..a7fb5f3f1 100644 --- a/rllm/tools/code_tools/__init__.py +++ b/rllm/tools/code_tools/__init__.py @@ -6,7 +6,6 @@ __all__ = [ "PythonInterpreter", # New unified interpreter "E2BPythonInterpreter", # Legacy interpreters for backward compatibility - "LocalPythonInterpreter", "LCBPythonInterpreter", "TogetherCodeTool", ] diff --git a/rllm/tools/code_tools/python_interpreter.py b/rllm/tools/code_tools/python_interpreter.py index 5dda1919a..fe8c4311a 100644 --- a/rllm/tools/code_tools/python_interpreter.py +++ b/rllm/tools/code_tools/python_interpreter.py @@ -18,7 +18,14 @@ class PythonInterpreter(CodeTool): and LiveCodeBench environment. """ - def __init__(self, backend: BackendType = "local", n_sandboxes: int = 1, api_key: str | None = None, name: str = "python", description: str = "Execute Python code in a sandboxed environment. Returns results and standard output/error."): + def __init__( + self, + backend: BackendType = "local", + n_sandboxes: int = 1, + api_key: str | None = None, + name: str = "python", + description: str = "Execute Python code in a sandboxed environment. Returns results and standard output/error.", + ): """ Initialize the unified Python interpreter with the specified backend. 
diff --git a/rllm/tools/code_tools/together_tool.py b/rllm/tools/code_tools/together_tool.py index a8a4451d1..9f23dbabd 100644 --- a/rllm/tools/code_tools/together_tool.py +++ b/rllm/tools/code_tools/together_tool.py @@ -79,7 +79,13 @@ def forward(self, code: str, timeout: int = 12, session_id: str | None = None, * output += str(output_item.data) + "\n" # Return formatted output - return CodeToolOutput(name=self.name or "together_python", output=output.strip() if output else None, stdout=stdout.strip() if stdout else None, stderr=stderr.strip() if stderr else None, error=error) + return CodeToolOutput( + name=self.name or "together_python", + output=output.strip() if output else None, + stdout=stdout.strip() if stdout else None, + stderr=stderr.strip() if stderr else None, + error=error, + ) except Exception as e: return CodeToolOutput(name=self.name or "together_python", error=f"{type(e).__name__} - {str(e)}", stderr=str(e)) diff --git a/rllm/tools/web_tools/firecrawl_tool.py b/rllm/tools/web_tools/firecrawl_tool.py index 31b27993f..0c08f2d6c 100644 --- a/rllm/tools/web_tools/firecrawl_tool.py +++ b/rllm/tools/web_tools/firecrawl_tool.py @@ -66,7 +66,23 @@ def _start_firecrawl_job(self, url): @property def json(self): """Return the tool's information in a standardized format for tool registration.""" - return {"type": "function", "function": {"name": self.name, "description": self.description, "parameters": {"type": "object", "properties": {"url": {"type": "string", "description": "Web URL to scrape content from."}}, "required": ["url"]}}} + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "Web URL to scrape content from.", + } + }, + "required": ["url"], + }, + }, + } def forward(self, url: str) -> ToolOutput: """ diff --git a/rllm/tools/web_tools/gsearch_tool.py b/rllm/tools/web_tools/gsearch_tool.py index 8d0674538..61bfbccff 100644 --- a/rllm/tools/web_tools/gsearch_tool.py +++ b/rllm/tools/web_tools/gsearch_tool.py @@ -41,7 +41,23 @@ def _init_client(self): @property def json(self): - return {"type": "function", "function": {"name": self.name, "description": self.description, "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Query to be submitted to Google search engine."}}, "required": ["query"]}}} + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Query to be submitted to Google search engine.", + } + }, + "required": ["query"], + }, + }, + } def _search_with_google(self, query: str): """ diff --git a/rllm/tools/web_tools/tavily_tool.py b/rllm/tools/web_tools/tavily_tool.py index 2ec202bf7..6d716ab83 100644 --- a/rllm/tools/web_tools/tavily_tool.py +++ b/rllm/tools/web_tools/tavily_tool.py @@ -17,7 +17,24 @@ def __init__(self): @property def json(self): - return {"type": "function", "function": {"name": self.name, "description": self.description, "parameters": {"type": "object", "properties": {"urls": {"type": "array", "items": {"type": "string"}, "description": "Array of URLs to extract content from"}}, "required": ["urls"]}}} + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": { + "urls": { + "type": "array", + "items": 
{"type": "string"}, + "description": "Array of URLs to extract content from", + } + }, + "required": ["urls"], + }, + }, + } def _init_client(self): self.client: httpx.Client | None = httpx.Client() diff --git a/rllm/trainer/deprecated/__init__.py b/rllm/trainer/deprecated/__init__.py index 2e78fcb4a..51ee74b06 100644 --- a/rllm/trainer/deprecated/__init__.py +++ b/rllm/trainer/deprecated/__init__.py @@ -8,6 +8,19 @@ from rllm.trainer.deprecated.tinker_sft_trainer import TinkerSFTTrainer from rllm.trainer.deprecated.tinker_workflow_trainer import TinkerWorkflowTrainer -warnings.warn(("`rllm.trainer.deprecated` contains deprecated Tinker trainer backends and may be removed in a future release.\nIf you are using the TinkerWorkflowTrainer, we recommend you migrate to the experimental unified trainer with the Tinker backend.\nThe change to config will be minimal, and will become the standard way to train Tinker workflows in the future.\nSee https://rllm-project.readthedocs.io/en/latest/experimental/unified-trainer.html for more details."), FutureWarning, stacklevel=2) +warnings.warn( + ( + "`rllm.trainer.deprecated` contains deprecated Tinker trainer backends " + "and may be removed in a future release.\n" + "If you are using the TinkerWorkflowTrainer, we recommend you migrate to " + "the experimental unified trainer with the Tinker backend.\n" + "The change to config will be minimal, and will become the standard way " + "to train Tinker workflows in the future.\n" + "See https://rllm-project.readthedocs.io/en/latest/experimental/" + "unified-trainer.html for more details." + ), + FutureWarning, + stacklevel=2, +) __all__ = ["TinkerAgentTrainer", "TinkerSFTTrainer", "TinkerWorkflowTrainer"] diff --git a/rllm/trainer/deprecated/tinker_agent_trainer.py b/rllm/trainer/deprecated/tinker_agent_trainer.py index 68a0b61cb..427b262bb 100644 --- a/rllm/trainer/deprecated/tinker_agent_trainer.py +++ b/rllm/trainer/deprecated/tinker_agent_trainer.py @@ -445,7 +445,12 @@ async def produce_episodes(): elif produce_completed.is_set() and episode_queue.empty(): break else: - raise TimeoutError(f"Episode generation stuck: no episodes received for {max_timeouts * 0.1:.1f} seconds. Producer completed: {produce_completed.is_set()}, Queue size: {episode_queue.qsize()}") from None + raise TimeoutError( + f"Episode generation stuck: no episodes received for " + f"{max_timeouts * 0.1:.1f} seconds. 
Producer completed: " + f"{produce_completed.is_set()}, Queue size: " + f"{episode_queue.qsize()}" + ) from None if produce_completed.is_set(): break diff --git a/rllm/trainer/deprecated/tinker_data_processor.py b/rllm/trainer/deprecated/tinker_data_processor.py index 44cda2ff2..b6f33dd7d 100644 --- a/rllm/trainer/deprecated/tinker_data_processor.py +++ b/rllm/trainer/deprecated/tinker_data_processor.py @@ -225,7 +225,9 @@ def build_datum_from_step(step: Step, advantage: float) -> tinker.Datum: all_mask = [0.0] * ob_len + [1.0] * (len(input_tokens) - ob_len) # Ensure all lists have the same length - assert len(input_tokens) == len(target_tokens) == len(all_logprobs) == len(all_advantages) == len(all_mask), f"Length mismatch: input={len(input_tokens)}, target={len(target_tokens)}, logprobs={len(all_logprobs)}, advantages={len(all_advantages)}, mask={len(all_mask)}" + assert len(input_tokens) == len(target_tokens) == len(all_logprobs) == len(all_advantages) == len(all_mask), ( + f"Length mismatch: input={len(input_tokens)}, target={len(target_tokens)}, logprobs={len(all_logprobs)}, advantages={len(all_advantages)}, mask={len(all_mask)}" + ) # Create Datum datum = tinker.types.Datum( @@ -268,7 +270,9 @@ def build_datum(trajectory: dict, advantage: float) -> tinker.Datum: all_mask = [0.0] * ob_len + [1.0] * (len(input_tokens) - ob_len) # Ensure all lists have the same length - assert len(input_tokens) == len(target_tokens) == len(all_logprobs) == len(all_advantages) == len(all_mask), f"Length mismatch: input={len(input_tokens)}, target={len(target_tokens)}, logprobs={len(all_logprobs)}, advantages={len(all_advantages)}, mask={len(all_mask)}" + assert len(input_tokens) == len(target_tokens) == len(all_logprobs) == len(all_advantages) == len(all_mask), ( + f"Length mismatch: input={len(input_tokens)}, target={len(target_tokens)}, logprobs={len(all_logprobs)}, advantages={len(all_advantages)}, mask={len(all_mask)}" + ) # Create Datum datum = tinker.types.Datum( diff --git a/rllm/trainer/deprecated/tinker_workflow_trainer.py b/rllm/trainer/deprecated/tinker_workflow_trainer.py index e7229c95d..cecffffd7 100644 --- a/rllm/trainer/deprecated/tinker_workflow_trainer.py +++ b/rllm/trainer/deprecated/tinker_workflow_trainer.py @@ -14,7 +14,7 @@ import tinker import torch -from transformers import AutoProcessor, AutoTokenizer +from transformers import AutoTokenizer from rllm.agents.agent import Episode from rllm.engine.agent_workflow_engine import AgentWorkflowEngine @@ -91,6 +91,8 @@ def __init__( model_name_lower = self.config.model.name.lower() if "vl" in model_name_lower or "vision" in model_name_lower: try: + from transformers import AutoProcessor + processor = AutoProcessor.from_pretrained(self.config.model.name, trust_remote_code=True) if hasattr(processor, "image_processor") and processor.image_processor is not None: image_processor = processor.image_processor diff --git a/rllm/trainer/tinker/tinker_backend.py b/rllm/trainer/tinker/tinker_backend.py index 14d1727c4..00ffd46ef 100644 --- a/rllm/trainer/tinker/tinker_backend.py +++ b/rllm/trainer/tinker/tinker_backend.py @@ -19,7 +19,7 @@ import tinker import torch from omegaconf import DictConfig -from transformers import AutoProcessor, AutoTokenizer +from transformers import AutoTokenizer from rllm.agents.agent import Episode from rllm.data import Dataset @@ -27,7 +27,6 @@ from rllm.experimental.protocol import BackendProtocol from rllm.experimental.rollout import RolloutEngine, TinkerEngine from rllm.trainer.tinker.tinker_metrics_utils import ( 
-    print_metrics_table,
     update_training_metrics,
 )
 from rllm.trainer.tinker.tinker_policy_trainer import TinkerPolicyTrainer
@@ -98,6 +97,9 @@ def __init__(
         # Store algorithm config for use in process_backend_batch
         self._algorithm_config: AlgorithmConfig | None = None

+        # Track whether on_policy_updated was called this step (for backward compat)
+        self._policy_updated_this_step: bool = False
+
         # Specific optimizer parameters for Tinker
         self.learning_rate = self.full_config.training.get("learning_rate", 1e-6)
         self.beta1 = self.full_config.training.get("beta1", 0.9)
@@ -132,6 +134,8 @@ def init_rollout_engine(self, **kwargs) -> RolloutEngine:
         model_name_lower = self.full_config.model.name.lower()
         if "vl" in model_name_lower or "vision" in model_name_lower:
             try:
+                from transformers import AutoProcessor
+
                 processor = AutoProcessor.from_pretrained(self.full_config.model.name, trust_remote_code=True)
                 if hasattr(processor, "image_processor") and processor.image_processor is not None:
                     image_processor = processor.image_processor
@@ -139,7 +143,18 @@
             except Exception as e:
                 logger.warning(f"Failed to load image_processor for VLM model: {e}")

-        self.rollout_engine = TinkerEngine(base_url=self.full_config.tinker_base_url, model_name=self.full_config.model.name, service_client=self.service_client, tokenizer=self.tokenizer, max_prompt_length=self.full_config.data.max_prompt_length, max_response_length=self.full_config.data.max_response_length, max_model_length=self.full_config.training.max_length, sampling_params=self.full_config.sampling, **self.full_config.rollout_engine, image_processor=image_processor)
+        self.rollout_engine = TinkerEngine(
+            base_url=self.full_config.tinker_base_url,
+            model_name=self.full_config.model.name,
+            service_client=self.service_client,
+            tokenizer=self.tokenizer,
+            max_prompt_length=self.full_config.data.max_prompt_length,
+            max_response_length=self.full_config.data.max_response_length,
+            max_model_length=self.full_config.training.max_length,
+            sampling_params=self.full_config.sampling,
+            **self.full_config.rollout_engine,
+            image_processor=image_processor,
+        )
         return self.rollout_engine

     def validate_config(self) -> None:
@@ -147,7 +162,10 @@
         # Check for recommended sampling parameters
         sampling_params = self.full_config.sampling
         if sampling_params.get("temperature", 1.0) != 1.0 or sampling_params.get("top_p", 1.0) != 1.0:
-            logger.warning("Temperature and top_p are set away from 1.0, this is not recommended by Tinker and can cause mysterious issues with logprobs. See https://github.com/thinking-machines-lab/tinker-cookbook/pull/86 for discussion.")
+            logger.warning(
+                "Temperature and top_p are set away from 1.0, this is not recommended by Tinker and can cause mysterious issues with logprobs. "
+                "See https://github.com/thinking-machines-lab/tinker-cookbook/pull/86 for discussion."
+ ) # Validate num_minibatches (currently only support 1) if self.full_config.training.get("num_minibatches", 1) != 1: @@ -173,6 +191,8 @@ def get_dataloader(self, dataset: Dataset | None, trainer_state: TrainerState) - shuffle = True else: batch_size = self.full_config.data.get("val_batch_size", self.full_config.data.train_batch_size) + if batch_size == -1: + batch_size = len(dataset) shuffle = False return torch.utils.data.DataLoader( @@ -388,6 +408,9 @@ async def on_train_start(self, trainer_state: TrainerState) -> None: resume = bool(self.full_config.training.resume_from_tinker_id) start_batch, self.sampling_client = await self.policy_trainer.initialize_async(resume_from_checkpoint=resume) + # Propagate sampling_client to rollout engine so it can make inference calls + self.rollout_engine.set_sampling_client(self.sampling_client) + # Update trainer state with the start batch from checkpoint trainer_state.global_step = start_batch @@ -400,29 +423,38 @@ async def on_train_end(self, trainer_state: TrainerState) -> None: logger.info(f"Saving final checkpoint at step {trainer_state.global_step}") await self.policy_trainer.save_checkpoint_and_get_sampling_client(trainer_state.global_step, kind="both", do_save=True) + async def on_policy_updated(self, trainer_state: TrainerState) -> None: + """Save checkpoint and update sampling_client after policy update.""" + assert self.policy_trainer is not None, "policy_trainer is not initialized" + self._policy_updated_this_step = True + + global_step = trainer_state.global_step + save_freq = self.full_config.rllm.trainer.save_freq + do_save = save_freq > 0 and global_step % save_freq == 0 + self.sampling_client = await self.policy_trainer.save_checkpoint_and_get_sampling_client(global_step, kind="both", do_save=do_save) + + # Propagate updated sampling_client to rollout engine for async weight sync + self.rollout_engine.set_sampling_client(self.sampling_client) + async def on_batch_end(self, trainer_state: TrainerState) -> None: """Called at the end of each batch. - Saves checkpoint, updates sampling client, and prints metrics. + In sync mode, on_policy_updated() is not called separately, so we + do the checkpoint/sampling_client update here for backward compat. 
""" assert self.policy_trainer is not None, "policy_trainer is not initialized" - global_step = trainer_state.global_step - # Save sampler checkpoint after each batch - with simple_timer("save_checkpoint", trainer_state.timing_dict): - logger.info(f"Saving state checkpoint and sampler at step {global_step}") - save_freq = self.full_config.rllm.trainer.save_freq - do_save = save_freq > 0 and global_step % save_freq == 0 - self.sampling_client = await self.policy_trainer.save_checkpoint_and_get_sampling_client(global_step, kind="both", do_save=do_save) + # If on_policy_updated() wasn't called (sync mode), do checkpoint here + if not self._policy_updated_this_step: + with simple_timer("save_checkpoint", trainer_state.timing_dict): + logger.info(f"Saving state checkpoint and sampler at step {trainer_state.global_step}") + await self.on_policy_updated(trainer_state) + self._policy_updated_this_step = False # Update metrics learning_rate = trainer_state.extra_info.get("scheduled_learning_rate", self.learning_rate) update_training_metrics(trainer_state, learning_rate, trainer_state.total_steps) - # Print metrics table - if trainer_state.metrics: - print_metrics_table(trainer_state.metrics, global_step) - async def on_epoch_start(self, trainer_state: TrainerState) -> None: """Called at the start of an epoch.""" logger.info(f"Starting epoch {trainer_state.epoch}") diff --git a/rllm/trainer/tinker/tinker_metrics_utils.py b/rllm/trainer/tinker/tinker_metrics_utils.py index 3976081f3..618a6ce30 100644 --- a/rllm/trainer/tinker/tinker_metrics_utils.py +++ b/rllm/trainer/tinker/tinker_metrics_utils.py @@ -5,61 +5,12 @@ import tinker import torch +from rllm.experimental.common.visualization import print_metrics_table # noqa: F401 (re-export) from rllm.experimental.unified_trainer import TrainerState logger = logging.getLogger(__name__) -def print_metrics_table(metrics: dict, step: int): - """ - Print metrics as a formatted table (similar to tinker_cookbook). - - Args: - metrics: Dictionary of metrics - step: Current step number - """ - try: - from rich.console import Console - from rich.table import Table - - console = Console() - - # Create table - table = Table(title=f"Step {step}", show_header=True, header_style="bold magenta") - table.add_column("Metric", style="cyan", no_wrap=False) - table.add_column("Value", justify="right", style="green") - - # Sort metrics by key for consistent ordering - sorted_metrics = sorted(metrics.items()) - - for key, value in sorted_metrics: - # Format value based on type - if isinstance(value, float): - value_str = f"{value:.6f}" if abs(value) < 1000 else f"{value:.2f}" - elif isinstance(value, int): - value_str = str(value) - else: - value_str = str(value) - - table.add_row(key, value_str) - - console.print(table) - - except ImportError: - # Fallback to simple text table if rich is not available - print(f"\nStep {step}") - print("=" * 60) - for key, value in sorted(metrics.items()): - if isinstance(value, float): - value_str = f"{value:.6f}" if abs(value) < 1000 else f"{value:.2f}" - elif isinstance(value, int): - value_str = str(value) - else: - value_str = str(value) - print(f"{key:40s} {value_str:>15s}") - print("=" * 60) - - def compute_kl_and_entropy_metrics(training_datums: list[tinker.Datum], training_logprobs: list[torch.Tensor]) -> dict: """ Compute KL divergence and entropy metrics from training. 
@@ -102,10 +53,10 @@ def compute_kl_and_entropy_metrics(training_datums: list[tinker.Datum], training perplexity = torch.exp(torch.tensor(entropy_sample)).item() return { - "optim/kl_sample_train_v1": kl_sample_train_v1, - "optim/kl_sample_train_v2": kl_sample_train_v2, - "optim/entropy": entropy_sample, - "optim/perplexity": perplexity, + "train/kl_sample_train_v1": kl_sample_train_v1, + "train/kl_sample_train_v2": kl_sample_train_v2, + "train/entropy": entropy_sample, + "train/perplexity": perplexity, } @@ -125,7 +76,7 @@ def update_training_metrics(trainer_state: TrainerState, learning_rate: float, t { "progress/batch": trainer_state.global_step, "progress/epoch": trainer_state.epoch, - "optim/lr": learning_rate, + "progress/lr": learning_rate, } ) diff --git a/rllm/trainer/tinker/tinker_policy_trainer.py b/rllm/trainer/tinker/tinker_policy_trainer.py index 57d7bd6ca..30f1caa6f 100644 --- a/rllm/trainer/tinker/tinker_policy_trainer.py +++ b/rllm/trainer/tinker/tinker_policy_trainer.py @@ -255,12 +255,18 @@ async def forward_backward_from_trajectory_groups( # Wait for completion and extract logprobs fwd_bwd_results = await asyncio.gather(*fwd_bwd_futures) - # Extract training logprobs from loss_fn_outputs + # Extract training logprobs and server-side metrics from results training_logprobs = [] for fwd_bwd_result in fwd_bwd_results: for output in fwd_bwd_result.loss_fn_outputs: logprobs = output["logprobs"].to_torch() training_logprobs.append(logprobs) + # Capture server-side metrics (e.g. loss) under train/ prefix + if fwd_bwd_result.metrics: + for k, v in fwd_bwd_result.metrics.items(): + if k.startswith("clock_cycle"): + continue + adv_metrics[f"train/{k.replace(':', '/')}"] = v return training_datums, training_logprobs, adv_metrics @@ -335,6 +341,11 @@ async def fused_forward_backward_and_optim_step( for output in fwd_bwd_result.loss_fn_outputs: logprobs = output["logprobs"].to_torch() training_logprobs.append(logprobs) + if fwd_bwd_result.metrics: + for k, v in fwd_bwd_result.metrics.items(): + if k.startswith("clock_cycle"): + continue + adv_metrics[f"train/{k.replace(':', '/')}"] = v return training_datums, training_logprobs, adv_metrics, scheduled_learning_rate diff --git a/rllm/trainer/tinker/transform.py b/rllm/trainer/tinker/transform.py index f758b7ce1..13497ab63 100644 --- a/rllm/trainer/tinker/transform.py +++ b/rllm/trainer/tinker/transform.py @@ -36,7 +36,7 @@ def _flatten_token_input(token_input: TinkerTokenInput) -> TinkerTokenInput: return flattened -def trajectory_to_datums(traj: Trajectory) -> list[tinker.Datum]: +def trajectory_to_datums(traj: Trajectory, router_replay: bool = False) -> list[tinker.Datum]: """ Return one or more Datum objects corresponding to the trajectory. 
If the sequence grows by appending, i.e., each successive observation contains @@ -61,6 +61,7 @@ class SequenceAccumulator: sampled_logprobs: list[float] = [] advantages: list[float] = [] mask: list[float] = [] + routing_matrices: list[str] = [] @classmethod def clear(cls): @@ -68,6 +69,7 @@ def clear(cls): cls.sampled_logprobs = [] cls.advantages = [] cls.mask = [] + cls.routing_matrices = [] def make_datum_from_state(): all_tokens_T = _flat_token_input_to_model_input(SequenceAccumulator.full_sequence) @@ -77,6 +79,9 @@ def make_datum_from_state(): advantages_T = SequenceAccumulator.advantages[1:] mask_T = SequenceAccumulator.mask[1:] assert input_tokens_T.length == len(target_tokens_T) == len(sampled_logprobs_T) == len(advantages_T) == len(mask_T) + if router_replay and SequenceAccumulator.routing_matrices: + rm_shifted = SequenceAccumulator.routing_matrices[1:] # match rightshift + input_tokens_T = input_tokens_T.model_copy(update={"routing_matrices": rm_shifted}) return tinker.Datum( model_input=input_tokens_T, loss_fn_inputs={ @@ -118,6 +123,9 @@ def make_datum_from_state(): SequenceAccumulator.sampled_logprobs.extend([0.0] * delta_token_input_length + output_logprobs) SequenceAccumulator.advantages.extend([0] * delta_token_input_length + advantages) SequenceAccumulator.mask.extend([0.0] * delta_token_input_length + [1.0] * len(output_token_ids)) + if router_replay: + step_rm = step.routing_matrices or [] + SequenceAccumulator.routing_matrices.extend([""] * delta_token_input_length + (list(step_rm) if step_rm else [""] * len(output_token_ids))) if SequenceAccumulator.full_sequence: data.append(make_datum_from_state()) @@ -137,9 +145,12 @@ def transform_trajectory_groups_to_datums( If the `estimator_map` is used in the algorithm config, we return a dictionary of datums, keyed by the trajectory group role. Otherwise, we return a list of datums. 
""" - # step 1: compute the advantages for each group using the common functionality - # this fills the `advantage` attribute of all the steps in the trajectory groups - adv_metrics = collect_reward_and_advantage_from_trajectory_groups(trajectory_groups, algorithm_config) + # step 1: compute advantages (skip if already pre-computed by buffer) + has_advantages = any(step.advantage is not None for group in trajectory_groups for traj in group.trajectories for step in traj.steps) + if has_advantages: + adv_metrics = {} + else: + adv_metrics = collect_reward_and_advantage_from_trajectory_groups(trajectory_groups, algorithm_config) if algorithm_config.estimator_map: datums_dict = defaultdict(list) @@ -147,11 +158,27 @@ def transform_trajectory_groups_to_datums( datums = [] # step 2: iterate over all steps and build the Tinker Datum objects + seqs_per_traj = [] + seq_lengths = [] for group in trajectory_groups: for trajectory in group.trajectories: + traj_datums = trajectory_to_datums(trajectory, router_replay=algorithm_config.router_replay) + seqs_per_traj.append(len(traj_datums)) + for d in traj_datums: + seq_lengths.append(d.model_input.length) if algorithm_config.estimator_map: - datums_dict[group.group_role].extend(trajectory_to_datums(trajectory)) + datums_dict[group.group_role].extend(traj_datums) else: - datums.extend(trajectory_to_datums(trajectory)) + datums.extend(traj_datums) + + if seqs_per_traj: + import numpy as _np + + adv_metrics["batch/seqs_per_traj/mean"] = _np.mean(seqs_per_traj) + adv_metrics["batch/seqs_per_traj/min"] = _np.min(seqs_per_traj) + adv_metrics["batch/seqs_per_traj/max"] = _np.max(seqs_per_traj) + adv_metrics["batch/seq_length/mean"] = _np.mean(seq_lengths) + adv_metrics["batch/seq_length/min"] = _np.min(seq_lengths) + adv_metrics["batch/seq_length/max"] = _np.max(seq_lengths) return (datums if not algorithm_config.estimator_map else datums_dict), adv_metrics diff --git a/rllm/trainer/verl/agent_sdk_trainer.py b/rllm/trainer/verl/agent_sdk_trainer.py index 0d59b7c19..7880f6f76 100644 --- a/rllm/trainer/verl/agent_sdk_trainer.py +++ b/rllm/trainer/verl/agent_sdk_trainer.py @@ -175,6 +175,7 @@ def fit_agent(self): self.global_steps = 0 self._load_checkpoint() + self.checkpoint_manager.update_weights(self.global_steps) start_time = time.time() if self.config.trainer.get("val_before_train", True): @@ -213,6 +214,7 @@ def fit_agent(self): with marked_timer("step", timing_raw): # generate trajectories final_gen_batch_output = self.generate_trajectories(batch=new_batch, timing_raw=timing_raw) + self.checkpoint_manager.sleep_replicas() # need to repeat to make shape match repeat_counts = final_gen_batch_output.meta_info["repeat_counts"] @@ -473,6 +475,9 @@ def fit_agent(self): with marked_timer("save_checkpoint", timing_raw, color="green"): self._save_checkpoint() + # update weights from trainer to rollout + with marked_timer("update_weights", timing_raw, color="red"): + self.checkpoint_manager.update_weights(self.global_steps) # Visualize some sample trajectories if batch is not None and len(batch) > 0: # Randomly select a few samples to visualize diff --git a/rllm/trainer/verl/agent_workflow_trainer.py b/rllm/trainer/verl/agent_workflow_trainer.py index 279461d09..3ce94db4c 100644 --- a/rllm/trainer/verl/agent_workflow_trainer.py +++ b/rllm/trainer/verl/agent_workflow_trainer.py @@ -30,8 +30,8 @@ compute_advantage, ) from verl.trainer.ppo.utils import Role, WorkerType -from verl.utils.debug import marked_timer from verl.utils import tensordict_utils as tu +from 
verl.utils.debug import marked_timer from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding from rllm.engine.agent_workflow_engine import AgentWorkflowEngine @@ -336,8 +336,7 @@ def fit_agent(self): with marked_timer("old_log_prob", timing_raw, color="blue"): batch_td = batch.to_tensordict() batch_td = left_right_2_no_padding(batch_td) - tu.assign_non_tensor(batch_td, calculate_entropy=True, compute_loss=False, - temperature=self.config.actor_rollout_ref.rollout.temperature) + tu.assign_non_tensor(batch_td, calculate_entropy=True, compute_loss=False, temperature=self.config.actor_rollout_ref.rollout.temperature) old_log_prob_output = self.actor_rollout_wg.compute_log_prob(batch_td) # New worker returns TensorDict in no-padding format if isinstance(old_log_prob_output, TensorDict): @@ -346,9 +345,7 @@ def fit_agent(self): # Convert from no-padding back to padding format entropy = no_padding_2_padding(entropy, batch_td) log_probs = no_padding_2_padding(log_probs, batch_td) - old_log_prob = DataProto.from_tensordict( - tu.get_tensordict({"old_log_probs": log_probs.float(), "entropys": entropy.float()}) - ) + old_log_prob = DataProto.from_tensordict(tu.get_tensordict({"old_log_probs": log_probs.float(), "entropys": entropy.float()})) else: old_log_prob = old_log_prob_output entropys = old_log_prob.batch["entropys"] @@ -376,9 +373,7 @@ def fit_agent(self): if isinstance(ref_log_prob_output, TensorDict): ref_lp = tu.get(ref_log_prob_output, "log_probs") ref_lp = no_padding_2_padding(ref_lp, batch_td) - ref_log_prob = DataProto.from_tensordict( - tu.get_tensordict({"ref_log_prob": ref_lp.float()}) - ) + ref_log_prob = DataProto.from_tensordict(tu.get_tensordict({"ref_log_prob": ref_lp.float()})) else: ref_log_prob = ref_log_prob_output batch = batch.union(ref_log_prob) @@ -463,10 +458,7 @@ def fit_agent(self): # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: # verl 0.7.1 new worker path: update_actor needs training metadata - ppo_mini_batch_size = ( - self.config.actor_rollout_ref.actor.ppo_mini_batch_size - * self.config.actor_rollout_ref.rollout.n - ) + ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature batch.meta_info["mini_batch_size"] = ppo_mini_batch_size batch.meta_info["epochs"] = self.config.actor_rollout_ref.actor.ppo_epochs @@ -680,16 +672,13 @@ def _validate_agent(self): # Follow verl's no-padding pattern for compute_log_prob cb_td = combined_batch.to_tensordict() cb_td = left_right_2_no_padding(cb_td) - tu.assign_non_tensor(cb_td, calculate_entropy=True, compute_loss=False, - temperature=self.config.actor_rollout_ref.rollout.temperature) + tu.assign_non_tensor(cb_td, calculate_entropy=True, compute_loss=False, temperature=self.config.actor_rollout_ref.rollout.temperature) old_log_prob_output = self.actor_rollout_wg.compute_log_prob(cb_td) if isinstance(old_log_prob_output, TensorDict): log_probs = tu.get(old_log_prob_output, "log_probs") log_probs = no_padding_2_padding(log_probs, cb_td) - old_log_prob = DataProto.from_tensordict( - tu.get_tensordict({"old_log_probs": log_probs.float()}) - ) + old_log_prob = DataProto.from_tensordict(tu.get_tensordict({"old_log_probs": log_probs.float()})) else: old_log_prob = old_log_prob_output combined_batch = combined_batch.union(old_log_prob) diff --git a/rllm/trainer/verl/ray_runtime_env.py b/rllm/trainer/verl/ray_runtime_env.py 
index f89058e13..d22e803bd 100644 --- a/rllm/trainer/verl/ray_runtime_env.py +++ b/rllm/trainer/verl/ray_runtime_env.py @@ -7,7 +7,9 @@ "VLLM_LOGGING_LEVEL": "WARN", "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true", "CUDA_DEVICE_MAX_CONNECTIONS": "1", - "VLLM_USE_V1": "1", + # TODO: disable compile cache due to cache corruption issue + # https://github.com/vllm-project/vllm/issues/31199 + "VLLM_DISABLE_COMPILE_CACHE": "1", # To prevent hanging or crash during synchronization of weights between actor and rollout # in disaggregated mode. See: # https://docs.vllm.ai/en/latest/usage/troubleshooting.html?h=nccl_cumem_enable#known-issues diff --git a/rllm/trainer/verl/sft_dataset.py b/rllm/trainer/verl/sft_dataset.py index 17bc68888..e70202d70 100644 --- a/rllm/trainer/verl/sft_dataset.py +++ b/rllm/trainer/verl/sft_dataset.py @@ -9,8 +9,8 @@ class RLLMSFTDataset(MultiTurnSFTDataset): - def __init__(self, parquet_files: str | list[str], tokenizer, config=None): - super().__init__(parquet_files, tokenizer, config) + def __init__(self, parquet_files: str | list[str], tokenizer, config=None, processor=None, max_samples=-1): + super().__init__(parquet_files, tokenizer, config, processor=processor, max_samples=max_samples) self.tokenize_and_mask_method = config.rllm.tokenize_and_mask_method logger.info(f"Using {self.tokenize_and_mask_method} tokenization and masking method") @@ -22,6 +22,8 @@ def _tokenize_and_mask(self, messages): return self._tokenize_and_mask_cumulative(messages) elif self.tokenize_and_mask_method == "stepwise": return self._tokenize_and_mask_stepwise(messages) + elif self.tokenize_and_mask_method == "hf_template": + return self._tokenize_and_mask_hf_template(messages) else: raise ValueError(f"Unknown tokenize_and_mask_method {self.tokenize_and_mask_method}") @@ -40,6 +42,43 @@ def _tokenize_and_mask_cumulative(self, messages): return tokens, loss_mask + def _tokenize_and_mask_hf_template(self, messages): + """Use HF tokenizer.apply_chat_template for native tool call rendering. + + Renders incrementally: messages[0:i] vs messages[0:i+1] to isolate each + message's tokens, then applies loss mask only on assistant tokens. 
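+        Note: re-tokenizing each character segment separately can differ from
+        tokenizing the full rendering at segment boundaries, so the loss mask
+        is approximate there.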
+ """ + full_text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ) + + # Build prefix lengths to find boundaries + prefix_lengths = [0] # char offset where each message starts + for i in range(len(messages)): + prefix_text = self.tokenizer.apply_chat_template( + messages[: i + 1], + tokenize=False, + add_generation_prompt=False, + ) + prefix_lengths.append(len(prefix_text)) + + # Tokenize each segment and assign loss mask + tokens = [] + loss_mask = [] + for i in range(len(messages)): + segment = full_text[prefix_lengths[i] : prefix_lengths[i + 1]] + seg_ids = self.tokenizer.encode(segment, add_special_tokens=False) + + if messages[i]["role"] == "assistant": + loss_mask.extend([1] * len(seg_ids)) + else: + loss_mask.extend([0] * len(seg_ids)) + tokens.extend(seg_ids) + + return tokens, loss_mask + def _tokenize_and_mask_stepwise(self, messages): tokens = [] loss_mask = [] diff --git a/rllm/trainer/verl/train_agent_ppo_pipeline.py b/rllm/trainer/verl/train_agent_ppo_pipeline.py index c20976936..600938da8 100644 --- a/rllm/trainer/verl/train_agent_ppo_pipeline.py +++ b/rllm/trainer/verl/train_agent_ppo_pipeline.py @@ -83,7 +83,17 @@ def main_task(config, compute_score=None): agent_class = AGENT_CLASS_MAPPING[config.rllm.agent.name] setup_environment(config) - trainer = PipelineAgentPPOTrainer(config=config, tokenizer=tokenizer, role_worker_mapping=role_worker_mapping, resource_pool_manager=resource_pool_manager, ray_worker_group_cls=RayWorkerGroup, reward_fn=reward_fn, val_reward_fn=val_reward_fn, env_class=env_class, agent_class=agent_class) + trainer = PipelineAgentPPOTrainer( + config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=RayWorkerGroup, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + env_class=env_class, + agent_class=agent_class, + ) trainer.init_workers() trainer.fit_agent() diff --git a/rllm/trajectory_visualizer.py b/rllm/trajectory_visualizer.py index 1cfac0bc9..a79e6c8b6 100644 --- a/rllm/trajectory_visualizer.py +++ b/rllm/trajectory_visualizer.py @@ -119,7 +119,12 @@ def detect_response_structure(step): # Updated logic for new API pattern uses_thinking_format = has_thinking and step.thought.strip() - return {"has_thinking": has_thinking, "has_direct_response": has_direct_response, "uses_thinking_format": uses_thinking_format, "response_field": "thought" if uses_thinking_format else "model_response"} + return { + "has_thinking": has_thinking, + "has_direct_response": has_direct_response, + "uses_thinking_format": uses_thinking_format, + "response_field": "thought" if uses_thinking_format else "model_response", + } def get_task_type(metadata): """Detect task type from metadata""" @@ -476,13 +481,25 @@ def update_step_view(traj_idx: int, step_idx: int, filter_option: str): with gr.Row(): with gr.Column(scale=1): - filter_dropdown = gr.Dropdown(choices=["All Trajectories", "Zero Reward (Failed)", "Nonzero Reward (Partial/Full Success)", "Perfect Score (Reward = 1)"], value="All Trajectories", label="🎯 Filter by Reward", interactive=True) + filter_dropdown = gr.Dropdown( + choices=[ + "All Trajectories", + "Zero Reward (Failed)", + "Nonzero Reward (Partial/Full Success)", + "Perfect Score (Reward = 1)", + ], + value="All Trajectories", + label="🎯 Filter by Reward", + interactive=True, + ) with gr.Column(scale=1): zero_count = len([t for t in all_trajs if float(t.reward) == 0.0]) nonzero_count = len([t for t in all_trajs 
if float(t.reward) > 0.0]) perfect_count = len([t for t in all_trajs if float(t.reward) == 1.0]) - _ = gr.Markdown(f"**Dataset Stats:**\n- Total: {len(all_trajs)} trajectories\n- Failed (0): {zero_count}\n- Partial/Full Success (>0): {nonzero_count}\n- Perfect Score (=1): {perfect_count}") + _ = gr.Markdown( + f"**Dataset Stats:**\n- Total: {len(all_trajs)} trajectories\n- Failed (0): {zero_count}\n- Partial/Full Success (>0): {nonzero_count}\n- Perfect Score (=1): {perfect_count}" + ) with gr.Row(): with gr.Column(scale=1): @@ -524,15 +541,42 @@ def update_step_view(traj_idx: int, step_idx: int, filter_option: str): with gr.Accordion("📋 Tool Results", open=True): outputs_output = gr.Markdown(elem_classes=["outputs-box"]) - all_outputs = [current_pos_display, metadata_output, performance_output, question_output, thinking_output, response_output, step_perf_output, actions_output, outputs_output, final_answer_output] + all_outputs = [ + current_pos_display, + metadata_output, + performance_output, + question_output, + thinking_output, + response_output, + step_perf_output, + actions_output, + outputs_output, + final_answer_output, + ] def reset_to_first_trajectory(): return 0, 0 - prev_traj_button.click(fn=lambda t, s, f: advance_step_or_trajectory(t, s, "prev", "trajectory", filter_trajectories_by_reward(f)), inputs=[current_traj_idx_state, current_step_idx_state, filter_dropdown], outputs=[current_traj_idx_state, current_step_idx_state]) - next_traj_button.click(fn=lambda t, s, f: advance_step_or_trajectory(t, s, "next", "trajectory", filter_trajectories_by_reward(f)), inputs=[current_traj_idx_state, current_step_idx_state, filter_dropdown], outputs=[current_traj_idx_state, current_step_idx_state]) - prev_step_button.click(fn=lambda t, s, f: advance_step_or_trajectory(t, s, "prev", "step", filter_trajectories_by_reward(f)), inputs=[current_traj_idx_state, current_step_idx_state, filter_dropdown], outputs=[current_traj_idx_state, current_step_idx_state]) - next_step_button.click(fn=lambda t, s, f: advance_step_or_trajectory(t, s, "next", "step", filter_trajectories_by_reward(f)), inputs=[current_traj_idx_state, current_step_idx_state, filter_dropdown], outputs=[current_traj_idx_state, current_step_idx_state]) + prev_traj_button.click( + fn=lambda t, s, f: advance_step_or_trajectory(t, s, "prev", "trajectory", filter_trajectories_by_reward(f)), + inputs=[current_traj_idx_state, current_step_idx_state, filter_dropdown], + outputs=[current_traj_idx_state, current_step_idx_state], + ) + next_traj_button.click( + fn=lambda t, s, f: advance_step_or_trajectory(t, s, "next", "trajectory", filter_trajectories_by_reward(f)), + inputs=[current_traj_idx_state, current_step_idx_state, filter_dropdown], + outputs=[current_traj_idx_state, current_step_idx_state], + ) + prev_step_button.click( + fn=lambda t, s, f: advance_step_or_trajectory(t, s, "prev", "step", filter_trajectories_by_reward(f)), + inputs=[current_traj_idx_state, current_step_idx_state, filter_dropdown], + outputs=[current_traj_idx_state, current_step_idx_state], + ) + next_step_button.click( + fn=lambda t, s, f: advance_step_or_trajectory(t, s, "next", "step", filter_trajectories_by_reward(f)), + inputs=[current_traj_idx_state, current_step_idx_state, filter_dropdown], + outputs=[current_traj_idx_state, current_step_idx_state], + ) filter_dropdown.change(fn=reset_to_first_trajectory, outputs=[current_traj_idx_state, current_step_idx_state]) diff --git a/rllm/utils/episode_logger.py b/rllm/utils/episode_logger.py index bdd665bac..5031aeddb 
100644 --- a/rllm/utils/episode_logger.py +++ b/rllm/utils/episode_logger.py @@ -85,7 +85,18 @@ def log_episode(self, episode: Episode, step: int, mode: str = "train", epoch: i mode: Mode identifier ('train' or 'val'), defaults to 'train' epoch: Current epoch number, defaults to 0 """ - episode_data = {"training_step": step, "epoch": epoch, "episode_id": episode.id, "task": episode.task, "task_hash": self.compute_task_hash(episode.task), "is_correct": episode.is_correct, "termination_reason": episode.termination_reason.value if episode.termination_reason else None, "metrics": episode.metrics, "timing": episode.info.get("timing", {}), "trajectories": []} + episode_data = { + "training_step": step, + "epoch": epoch, + "episode_id": episode.id, + "task": episode.task, + "task_hash": self.compute_task_hash(episode.task), + "is_correct": episode.is_correct, + "termination_reason": (episode.termination_reason.value if episode.termination_reason else None), + "metrics": episode.metrics, + "timing": episode.info.get("timing", {}), + "trajectories": [], + } for traj in episode.trajectories: traj_data = { diff --git a/rllm/utils/tracking.py b/rllm/utils/tracking.py index e9bc4169a..917c74281 100644 --- a/rllm/utils/tracking.py +++ b/rllm/utils/tracking.py @@ -683,7 +683,9 @@ def log(self, data, step): iteration=step, ) else: - logger.warning(f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}". This invocation of ClearML logger\'s function is incorrect so this attribute was dropped. ') + logger.warning( + f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}". This invocation of ClearML logger\'s function is incorrect so this attribute was dropped.' + ) def finish(self): self._task.close() diff --git a/rllm/utils/visualization.py b/rllm/utils/visualization.py index 45d5c385b..7210bce14 100644 --- a/rllm/utils/visualization.py +++ b/rllm/utils/visualization.py @@ -118,7 +118,14 @@ def visualize_trajectories( _visualize_metadata(batch, idx, config) # 2. Print Legend (simplified) - legend = " ".join([_format_token("masked", config.masked_token_style), _format_token("unmasked", config.unmasked_token_style), _format_token("reward > 0", config.reward_pos_style), _format_token("reward <= 0", config.reward_neg_style)]) + legend = " ".join( + [ + _format_token("masked", config.masked_token_style), + _format_token("unmasked", config.unmasked_token_style), + _format_token("reward > 0", config.reward_pos_style), + _format_token("reward <= 0", config.reward_neg_style), + ] + ) print(f"[{legend}]") # 3. 
Render Prompt diff --git a/rllm/workflows/cumulative_workflow.py b/rllm/workflows/cumulative_workflow.py index 45fa5e30b..932e6036d 100644 --- a/rllm/workflows/cumulative_workflow.py +++ b/rllm/workflows/cumulative_workflow.py @@ -49,7 +49,14 @@ async def run(self, task: dict, uid: str, **kwargs) -> Episode | None: if max_tokens <= 0: raise TerminationEvent(TerminationReason.MAX_RESPONSE_LENGTH_EXCEEDED) - output: ModelOutput = await self.timed_llm_call(self.agent.chat_completions, application_id=uid, accumulate_reasoning=True, enforce_max_prompt_length=False, max_tokens=max_tokens, **kwargs) + output: ModelOutput = await self.timed_llm_call( + self.agent.chat_completions, + application_id=uid, + accumulate_reasoning=True, + enforce_max_prompt_length=False, + max_tokens=max_tokens, + **kwargs, + ) response = output.text action = self.agent.update_from_model(response) diff --git a/rllm/workflows/distillation_workflow.py b/rllm/workflows/distillation_workflow.py index 4423f04bf..31e5531c9 100644 --- a/rllm/workflows/distillation_workflow.py +++ b/rllm/workflows/distillation_workflow.py @@ -21,7 +21,16 @@ class DistillationWorkflow(Workflow): **kwargs: Additional arguments passed to Workflow. """ - def __init__(self, rollout_engine: RolloutEngine, reward_function: RewardFunction | None = None, teacher_engine: RolloutEngine | None = None, shared_tokenizer: bool = False, clip_min: float | None = None, clip_max: float | None = None, **kwargs): + def __init__( + self, + rollout_engine: RolloutEngine, + reward_function: RewardFunction | None = None, + teacher_engine: RolloutEngine | None = None, + shared_tokenizer: bool = False, + clip_min: float | None = None, + clip_max: float | None = None, + **kwargs, + ): super().__init__(rollout_engine, **kwargs) self.reward_function = reward_function self.teacher_engine = teacher_engine diff --git a/rllm/workflows/workflow.py b/rllm/workflows/workflow.py index d427c5648..dd4f092b9 100644 --- a/rllm/workflows/workflow.py +++ b/rllm/workflows/workflow.py @@ -197,7 +197,7 @@ def assign_episode_correctness(self, episode: Episode) -> None: """ total_reward = 0 for trajectory in episode.trajectories: - total_reward += trajectory.reward + total_reward += trajectory.reward or 0 episode.is_correct = total_reward > 0 def collect_metrics(self, episode: Episode) -> None: @@ -210,7 +210,7 @@ def collect_metrics(self, episode: Episode) -> None: metrics = defaultdict(list) for traj in episode.trajectories: name = traj.name - metrics[name].append(traj.reward) + metrics[name].append(traj.reward or 0.0) episode.metrics = {f"{k}_acc": float(np.mean(v)) for k, v in metrics.items()} def postprocess_episode(self, episode: Episode, termination_reason: TerminationReason = None, error: dict = None) -> Episode: diff --git a/scripts/data/gaia_dataset.py b/scripts/data/gaia_dataset.py index d59f6265f..fa6373511 100644 --- a/scripts/data/gaia_dataset.py +++ b/scripts/data/gaia_dataset.py @@ -43,7 +43,19 @@ def process_fn(example: dict[str, Any], idx: int, dataset_name=None) -> dict[str tests = json.dumps(tests) if isinstance(question, dict): question = json.dumps(question) - data = {"data_source": dataset_name, "prompt": [{"role": "user", "content": question}], "ability": "code", "reward_model": {"style": "rule", "ground_truth": tests}, "extra_info": {"split": split, "index": idx, "task": {"question": question}}, "task": {"question": question}, "uid": idx} + data = { + "data_source": dataset_name, + "prompt": [{"role": "user", "content": question}], + "ability": "code", + 
"reward_model": {"style": "rule", "ground_truth": tests}, + "extra_info": { + "split": split, + "index": idx, + "task": {"question": question}, + }, + "task": {"question": question}, + "uid": idx, + } return data return process_fn diff --git a/tests/agents/test_appworld_agent.py b/tests/agents/test_appworld_agent.py index 0c1faa44b..e54fbb25b 100644 --- a/tests/agents/test_appworld_agent.py +++ b/tests/agents/test_appworld_agent.py @@ -53,7 +53,11 @@ def test_update_from_env_initial_task(self): """Test updating environment with initial task observation""" agent = AppWorldReactAgent() - observation = {"instruction": "How many playlists do I have in Spotify?", "user_info": {"first_name": "Test", "last_name": "User", "email": "test@example.com", "phone_number": "+1234567890"}, "app_descriptions": "spotify: Music streaming app\nsupervisor: User management"} + observation = { + "instruction": "How many playlists do I have in Spotify?", + "user_info": {"first_name": "Test", "last_name": "User", "email": "test@example.com", "phone_number": "+1234567890"}, + "app_descriptions": "spotify: Music streaming app\nsupervisor: User management", + } agent.update_from_env(observation, 0.0, False, {}) @@ -247,7 +251,11 @@ def test_basic_interaction_flow(self): agent = AppWorldReactAgent() # Step 1: Receive initial task - task_observation = {"instruction": "How many playlists do I have in Spotify?", "user_info": {"first_name": "Test", "last_name": "User", "email": "test@example.com", "phone_number": "+1234567890"}, "app_descriptions": "spotify: Music streaming app"} + task_observation = { + "instruction": "How many playlists do I have in Spotify?", + "user_info": {"first_name": "Test", "last_name": "User", "email": "test@example.com", "phone_number": "+1234567890"}, + "app_descriptions": "spotify: Music streaming app", + } agent.update_from_env(task_observation, 0.0, False, {}) assert agent.initialized is True diff --git a/tests/cli/test_eval_command.py b/tests/cli/test_eval_command.py index 83fd968d7..27d9fa178 100644 --- a/tests/cli/test_eval_command.py +++ b/tests/cli/test_eval_command.py @@ -86,7 +86,11 @@ def test_eval_with_proxy_mode(runner, tmp_rllm_home, mock_dataset): mock_pm.get_proxy_url.return_value = "http://127.0.0.1:4000/v1" mock_pm.build_proxy_config.return_value = {"model_list": []} - with patch("rllm.experimental.eval.config.load_config", return_value=config), patch("rllm.experimental.eval.proxy.EvalProxyManager", return_value=mock_pm), patch("rllm.experimental.cli.eval._run_eval"): + with ( + patch("rllm.experimental.eval.config.load_config", return_value=config), + patch("rllm.experimental.eval.proxy.EvalProxyManager", return_value=mock_pm), + patch("rllm.experimental.cli.eval._run_eval"), + ): result = runner.invoke( cli, [ @@ -106,7 +110,11 @@ def test_eval_base_url_skips_proxy(runner, tmp_rllm_home, mock_dataset): """Eval with --base-url should not create a proxy.""" mock_agent = _MockAgentFlow() - with patch("rllm.experimental.eval.proxy.EvalProxyManager") as mock_pm_cls, patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=_MockEvaluator()): + with ( + patch("rllm.experimental.eval.proxy.EvalProxyManager") as mock_pm_cls, + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=_MockEvaluator()), + ): result = runner.invoke( cli, [ @@ -129,7 
+137,10 @@ def test_eval_with_mock_agent(runner, tmp_rllm_home, mock_dataset): """Eval with a mock agent should produce results.""" mock_agent = _MockAgentFlow() - with patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=_MockEvaluator()): + with ( + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=_MockEvaluator()), + ): result = runner.invoke( cli, [ @@ -154,7 +165,10 @@ def test_eval_with_max_examples(runner, tmp_rllm_home, mock_dataset): """Eval with --max-examples should limit evaluation.""" mock_agent = _MockAgentFlow() - with patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=_MockEvaluator()): + with ( + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=_MockEvaluator()), + ): result = runner.invoke( cli, [ @@ -180,7 +194,10 @@ def test_eval_saves_results(runner, tmp_rllm_home, mock_dataset): """Eval should save results to a JSON file.""" mock_agent = _MockAgentFlow() - with patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=_MockEvaluator()): + with ( + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=_MockEvaluator()), + ): result = runner.invoke( cli, [ @@ -203,7 +220,10 @@ def test_eval_with_explicit_evaluator(runner, tmp_rllm_home, mock_dataset): """Eval with --evaluator should use specified evaluator.""" mock_agent = _MockAgentFlow() - with patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.load_evaluator", return_value=_MockEvaluator()) as mock_load_eval: + with ( + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + patch("rllm.experimental.eval.evaluator_loader.load_evaluator", return_value=_MockEvaluator()) as mock_load_eval, + ): result = runner.invoke( cli, [ diff --git a/tests/cli/test_train_command.py b/tests/cli/test_train_command.py index cca4d9f21..9cf64144e 100644 --- a/tests/cli/test_train_command.py +++ b/tests/cli/test_train_command.py @@ -258,7 +258,12 @@ def test_train_agent_resolution_from_catalog(self, runner, tmp_rllm_home, mock_t mock_evaluator = _MockEvaluator() mock_trainer = MagicMock() - with patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent) as mock_load_agent, patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer): + with ( + patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent) as mock_load_agent, + patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), + 
patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer), + ): result = runner.invoke(cli, ["train", "test_math", "--model", "test-model"]) assert result.exit_code == 0 @@ -272,7 +277,12 @@ def test_train_explicit_agent_and_evaluator(self, runner, tmp_rllm_home, mock_tr mock_evaluator = _MockEvaluator() mock_trainer = MagicMock() - with patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent) as mock_load_agent, patch("rllm.experimental.eval.evaluator_loader.load_evaluator", return_value=mock_evaluator) as mock_load_eval, patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer): + with ( + patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent) as mock_load_agent, + patch("rllm.experimental.eval.evaluator_loader.load_evaluator", return_value=mock_evaluator) as mock_load_eval, + patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer), + ): result = runner.invoke( cli, [ @@ -298,7 +308,12 @@ def test_train_header_display(self, runner, tmp_rllm_home, mock_train_dataset): mock_evaluator = _MockEvaluator() mock_trainer = MagicMock() - with patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer): + with ( + patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), + patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer), + ): result = runner.invoke( cli, [ @@ -325,7 +340,12 @@ def test_train_with_max_examples(self, runner, tmp_rllm_home, mock_train_dataset mock_evaluator = _MockEvaluator() mock_trainer = MagicMock() - with patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer): + with ( + patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), + patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer), + ): result = runner.invoke( cli, [ @@ -349,7 +369,12 @@ def test_train_passes_correct_config_to_trainer(self, runner, tmp_rllm_home, moc mock_evaluator = _MockEvaluator() mock_trainer = MagicMock() - with patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), patch("rllm.experimental.unified_trainer.AgentTrainer", 
return_value=mock_trainer) as mock_at_cls: + with ( + patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), + patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer) as mock_at_cls, + ): result = runner.invoke( cli, [ @@ -394,7 +419,12 @@ def test_train_separate_val_dataset(self, runner, tmp_rllm_home): mock_evaluator = _MockEvaluator() mock_trainer = MagicMock() - with patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer): + with ( + patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), + patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer), + ): result = runner.invoke( cli, [ @@ -417,7 +447,11 @@ def test_train_no_evaluator_found(self, runner, tmp_rllm_home, mock_train_datase catalog = {"datasets": {"test_math": {"default_agent": "math"}}} mock_agent = _MockAgentFlow() - with patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=None): + with ( + patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=None), + ): result = runner.invoke(cli, ["train", "test_math", "--model", "test-model"]) assert result.exit_code != 0 @@ -430,7 +464,12 @@ def test_train_default_experiment_name(self, runner, tmp_rllm_home, mock_train_d mock_evaluator = _MockEvaluator() mock_trainer = MagicMock() - with patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer) as mock_at_cls: + with ( + patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), + patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer) as mock_at_cls, + ): result = runner.invoke(cli, ["train", "test_math", "--model", "test-model"]) assert result.exit_code == 0 @@ -445,7 +484,12 @@ def test_train_default_no_ui_logger(self, runner, tmp_rllm_home, mock_train_data mock_evaluator = _MockEvaluator() mock_trainer = MagicMock() - with patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), 
patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer) as mock_at_cls: + with ( + patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), + patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer) as mock_at_cls, + ): result = runner.invoke(cli, ["train", "test_math", "--model", "test-model"]) assert result.exit_code == 0 @@ -461,7 +505,12 @@ def test_train_ui_flag_appends_ui_logger(self, runner, tmp_rllm_home, mock_train mock_evaluator = _MockEvaluator() mock_trainer = MagicMock() - with patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer) as mock_at_cls: + with ( + patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), + patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer) as mock_at_cls, + ): result = runner.invoke(cli, ["train", "test_math", "--model", "test-model", "--ui"]) assert result.exit_code == 0 @@ -476,7 +525,12 @@ def test_train_ui_url_implies_ui(self, runner, tmp_rllm_home, mock_train_dataset mock_evaluator = _MockEvaluator() mock_trainer = MagicMock() - with patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer) as mock_at_cls: + with ( + patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), + patch("rllm.experimental.unified_trainer.AgentTrainer", return_value=mock_trainer) as mock_at_cls, + ): result = runner.invoke(cli, ["train", "test_math", "--model", "test-model", "--ui-url", "http://localhost:3000"]) assert result.exit_code == 0 @@ -493,7 +547,11 @@ def test_train_ui_without_api_key_errors(self, runner, tmp_rllm_home, mock_train mock_agent = _MockAgentFlow() mock_evaluator = _MockEvaluator() - with patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator): + with ( + patch("rllm.experimental.cli.train.load_dataset_catalog", return_value=catalog), + patch("rllm.experimental.eval.agent_loader.load_agent", return_value=mock_agent), + 
patch("rllm.experimental.eval.evaluator_loader.resolve_evaluator_from_catalog", return_value=mock_evaluator), + ): result = runner.invoke(cli, ["train", "test_math", "--model", "test-model", "--ui"]) assert result.exit_code != 0 diff --git a/tests/envs/test_mcp_env.py b/tests/envs/test_mcp_env.py index 0ce9953cf..f1189a3e6 100644 --- a/tests/envs/test_mcp_env.py +++ b/tests/envs/test_mcp_env.py @@ -16,6 +16,25 @@ def __call__(self, task_info, action, **kwargs): return RewardOutput(reward=reward, metadata=metadata) +@pytest.fixture(autouse=True) +def reset_mcp_environment_state(): + MCPEnvironment._connection_manager = None + MCPEnvironment._connection_managers = {} + MCPEnvironment._server_specs = {} + yield + MCPEnvironment._connection_manager = None + MCPEnvironment._connection_managers = {} + MCPEnvironment._server_specs = {} + + +def make_start_side_effect(command_to_tools): + def _start(manager): + manager.running = True + manager.tool_map = command_to_tools.get(manager.mcp_server_command, {}) + + return _start + + class TestMCPConnectionManager: """Test suite for MCPConnectionManager class.""" @@ -277,6 +296,7 @@ def test_step_with_regular_tool_calls(self, mock_init, mock_start): mock_manager = Mock() mock_manager.execute_tool_calls.return_value = {"call_1": "Tool output"} MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} action = [{"id": "call_1", "function": {"name": "search", "arguments": {"query": "test"}}}] @@ -302,6 +322,7 @@ def test_step_max_steps_termination(self, mock_init, mock_start): mock_manager = Mock() mock_manager.execute_tool_calls.return_value = {"call_1": "Tool output"} MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} # Take steps until max_steps for i in range(2): @@ -328,6 +349,7 @@ def test_step_with_dict_action(self, mock_init, mock_start): mock_manager = Mock() mock_manager.execute_tool_calls.return_value = {"call_1": "Tool output"} MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} action = {"id": "call_1", "function": {"name": "search", "arguments": {"query": "test"}}} @@ -370,11 +392,14 @@ def test_cleanup_global_resources_with_manager(self): """Test cleanup_global_resources with existing manager.""" mock_manager = Mock() MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} MCPEnvironment.cleanup_global_resources() mock_manager.stop.assert_called_once() assert MCPEnvironment._connection_manager is None + assert MCPEnvironment._connection_managers == {} + assert MCPEnvironment._server_specs == {} @patch.object(MCPConnectionManager, "start") @patch.object(MCPConnectionManager, "__init__", return_value=None) @@ -398,7 +423,14 @@ def test_is_multithread_safe(self): def test_from_dict(self): """Test creating environment from dictionary.""" - env_args = {"question": "Test question", "mcp_server_command": "test_command", "mcp_server_args": ["--arg1"], "mcp_server_env": {"VAR": "value"}, "max_steps": 15, "reward_fn": MockRewardFunction()} + env_args = { + "question": "Test question", + "mcp_server_command": "test_command", + "mcp_server_args": ["--arg1"], + "mcp_server_env": {"VAR": "value"}, + "max_steps": 15, + "reward_fn": MockRewardFunction(), + } with patch.object(MCPConnectionManager, "start"), patch.object(MCPConnectionManager, "__init__", return_value=None): # Clear any existing manager @@ -447,6 
+479,7 @@ def test_full_interaction_flow(self, mock_init, mock_start): mock_manager = Mock() mock_manager.execute_tool_calls.return_value = {"call_1": "Paris is the capital of France"} MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} action1 = [{"id": "call_1", "function": {"name": "search", "arguments": {"query": "capital of France"}}}] @@ -499,12 +532,15 @@ def test_step_with_tool_execution_error(self, mock_init, mock_start): mock_manager = Mock() mock_manager.execute_tool_calls.side_effect = Exception("Tool execution failed") MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} action = [{"id": "call_1", "function": {"name": "search", "arguments": {"query": "test"}}}] - # Should not raise error but return empty tool outputs obs, reward, done, info = env.step(action) - assert obs == {"tool_outputs": {}} + assert obs == {"tool_outputs": {"call_1": "Error: MCP server default failed: Tool execution failed"}} + assert reward == 0 + assert done is False + assert info["response"] == action @patch.object(MCPConnectionManager, "start") @patch.object(MCPConnectionManager, "__init__", return_value=None) @@ -520,6 +556,7 @@ def test_edge_cases(self, mock_init, mock_start): mock_manager = Mock() mock_manager.execute_tool_calls.return_value = {} MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} # Empty action list obs, reward, done, info = env.step([]) @@ -569,6 +606,8 @@ def test_connection_manager_thread_safety(self): """Test that connection manager handles thread safety correctly.""" # Both environments should be able to access the class-level manager assert hasattr(MCPEnvironment, "_connection_manager") + assert hasattr(MCPEnvironment, "_connection_managers") + assert hasattr(MCPEnvironment, "_server_specs") assert hasattr(MCPEnvironment, "_manager_lock") @patch.object(MCPConnectionManager, "__init__", return_value=None) @@ -583,6 +622,9 @@ def test_connection_manager_singleton_behavior(self, mock_start, mock_init): # Both environments should use the same manager assert MCPEnvironment._connection_manager is not None + assert len(MCPEnvironment._connection_managers) == 1 + assert mock_init.call_count == 1 + assert mock_start.call_count == 1 @patch.object(MCPConnectionManager, "start") @patch.object(MCPConnectionManager, "__init__", return_value=None) @@ -598,12 +640,360 @@ def test_malformed_tool_call_handling(self, mock_init, mock_start): mock_manager = Mock() mock_manager.execute_tool_calls.return_value = {"call_1": "Tool output"} MCPEnvironment._connection_manager = mock_manager + MCPEnvironment._connection_managers = {"default": mock_manager} # Malformed action (missing required fields) action = [{"id": "call_1"}] # Missing function field obs, reward, done, info = env.step(action) - # Should still process the action - assert obs == {"tool_outputs": {"call_1": "Tool output"}} + assert obs == {"tool_outputs": {"call_1": "Error: Tool call missing function.name"}} + assert done is False + mock_manager.execute_tool_calls.assert_not_called() + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_init_with_multiple_servers(self, mock_start): + """Test initializing MCPEnvironment with multiple named servers.""" + mock_start.side_effect = make_start_side_effect( + { + "search-command": {"search": Mock()}, + "wiki-command": {"lookup": Mock()}, + } + ) + + env = MCPEnvironment( + 
mcp_servers={ + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + } + ) + + assert set(env.mcp_servers) == {"search_server", "wiki_server"} + assert set(MCPEnvironment._connection_managers) == {"search_server", "wiki_server"} + assert MCPEnvironment._connection_manager is None + assert env._resolved_tool_name_to_server_name == { + "search": "search_server", + "lookup": "wiki_server", + } + assert mock_start.call_count == 2 + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_same_server_name_with_different_config_raises(self, mock_start): + """Test that reusing a server name with a different config fails fast.""" + mock_start.side_effect = make_start_side_effect({"search-command": {"search": Mock()}}) + + MCPEnvironment(mcp_servers={"shared": {"command": "search-command"}}) + + with pytest.raises(ValueError, match="already initialized with a different configuration"): + MCPEnvironment(mcp_servers={"shared": {"command": "different-command"}}) + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_explicit_tool_name_to_server_name_resolves_ambiguity(self, mock_start): + """Test that explicit tool routing resolves duplicate tool names.""" + mock_start.side_effect = make_start_side_effect( + { + "command-a": {"shared_tool": Mock()}, + "command-b": {"shared_tool": Mock()}, + } + ) + + env = MCPEnvironment( + mcp_servers={ + "server_a": {"command": "command-a"}, + "server_b": {"command": "command-b"}, + }, + tool_name_to_server_name={"shared_tool": "server_b"}, + ) + + assert env._resolved_tool_name_to_server_name == {"shared_tool": "server_b"} + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_duplicate_public_tool_name_without_mapping_raises(self, mock_start): + """Test that duplicate tool names across servers require explicit routing.""" + mock_start.side_effect = make_start_side_effect( + { + "command-a": {"shared_tool": Mock()}, + "command-b": {"shared_tool": Mock()}, + } + ) + + with pytest.raises(ValueError, match="Tool 'shared_tool' is provided by multiple MCP servers"): + MCPEnvironment( + mcp_servers={ + "server_a": {"command": "command-a"}, + "server_b": {"command": "command-b"}, + } + ) + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_step_routes_tool_calls_to_correct_server(self, mock_start): + """Test that tool calls are routed to the correct MCP server.""" + mock_start.side_effect = make_start_side_effect( + { + "search-command": {"search": Mock()}, + "wiki-command": {"lookup": Mock()}, + } + ) + env = MCPEnvironment( + mcp_servers={ + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + } + ) + env.reset() + + search_manager = MCPEnvironment._connection_managers["search_server"] + wiki_manager = MCPEnvironment._connection_managers["wiki_server"] + search_manager.execute_tool_calls = Mock(return_value={"call_1": "Search output"}) + wiki_manager.execute_tool_calls = Mock(return_value={"call_2": "Lookup output"}) + + action = [ + {"id": "call_1", "function": {"name": "search", "arguments": {"query": "France"}}}, + {"id": "call_2", "function": {"name": "lookup", "arguments": {"topic": "Paris"}}}, + ] + + obs, reward, done, info = env.step(action) + + assert obs == {"tool_outputs": {"call_1": "Search output", "call_2": "Lookup output"}} + assert reward == 0 + assert done is False + assert info["response"] == action + search_manager.execute_tool_calls.assert_called_once_with([action[0]]) + 
wiki_manager.execute_tool_calls.assert_called_once_with([action[1]]) + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_partial_server_failure_does_not_erase_other_outputs(self, mock_start): + """Test that one server failure does not discard successful tool outputs.""" + mock_start.side_effect = make_start_side_effect( + { + "search-command": {"search": Mock()}, + "wiki-command": {"lookup": Mock()}, + } + ) + env = MCPEnvironment( + mcp_servers={ + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + } + ) + env.reset() + + search_manager = MCPEnvironment._connection_managers["search_server"] + wiki_manager = MCPEnvironment._connection_managers["wiki_server"] + search_manager.execute_tool_calls = Mock(return_value={"call_1": "Search output"}) + wiki_manager.execute_tool_calls = Mock(side_effect=Exception("wiki unavailable")) + + action = [ + {"id": "call_1", "function": {"name": "search", "arguments": {"query": "France"}}}, + {"id": "call_2", "function": {"name": "lookup", "arguments": {"topic": "Paris"}}}, + ] + + obs, reward, done, info = env.step(action) + + assert obs == { + "tool_outputs": { + "call_1": "Search output", + "call_2": "Error: MCP server wiki_server failed: wiki unavailable", + } + } + assert reward == 0 + assert done is False + assert info["response"] == action + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_step_assigns_missing_tool_call_ids_across_servers(self, mock_start): + """Test that synthetic tool call ids stay unique across routed server groups.""" + mock_start.side_effect = make_start_side_effect( + { + "search-command": {"search": Mock()}, + "wiki-command": {"lookup": Mock()}, + } + ) + env = MCPEnvironment( + mcp_servers={ + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + } + ) + env.reset() + + search_manager = MCPEnvironment._connection_managers["search_server"] + wiki_manager = MCPEnvironment._connection_managers["wiki_server"] + search_manager.execute_tool_calls = Mock(side_effect=lambda tool_calls: {tool_calls[0]["id"]: "Search output"}) + wiki_manager.execute_tool_calls = Mock(side_effect=lambda tool_calls: {tool_calls[0]["id"]: "Lookup output"}) + + action = [ + {"function": {"name": "search", "arguments": {"query": "France"}}}, + {"function": {"name": "lookup", "arguments": {"topic": "Paris"}}}, + ] + + obs, reward, done, info = env.step(action) + + assert obs == { + "tool_outputs": { + "tool_call_0": "Search output", + "tool_call_1": "Lookup output", + } + } + assert reward == 0 + assert done is False + assert info["response"] == action + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_step_preserves_interleaved_tool_output_order_across_servers(self, mock_start): + """Test that output ordering follows the original tool-call order.""" + mock_start.side_effect = make_start_side_effect( + { + "search-command": {"search": Mock()}, + "wiki-command": {"lookup": Mock()}, + } + ) + env = MCPEnvironment( + mcp_servers={ + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + } + ) + env.reset() + + search_manager = MCPEnvironment._connection_managers["search_server"] + wiki_manager = MCPEnvironment._connection_managers["wiki_server"] + search_manager.execute_tool_calls = Mock( + return_value={ + "call_1": "Search output 1", + "call_3": "Search output 2", + } + ) + wiki_manager.execute_tool_calls = Mock(return_value={"call_2": "Lookup output"}) + + action = 
[ + {"id": "call_1", "function": {"name": "search", "arguments": {"query": "France"}}}, + {"id": "call_2", "function": {"name": "lookup", "arguments": {"topic": "Paris"}}}, + {"id": "call_3", "function": {"name": "search", "arguments": {"query": "Europe"}}}, + ] + + obs, reward, done, info = env.step(action) + + assert list(obs["tool_outputs"]) == ["call_1", "call_2", "call_3"] + assert reward == 0 + assert done is False + assert info["response"] == action + + @patch.object(MCPConnectionManager, "stop", autospec=True) + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_start_failure_rolls_back_previously_started_managers(self, mock_start, mock_stop): + """Test that manager startup failures do not leave partial global state behind.""" + + def _start(manager): + if manager.mcp_server_command == "search-command": + manager.running = True + manager.tool_map = {"search": Mock()} + return + raise RuntimeError("startup failed") + + mock_start.side_effect = _start + + with pytest.raises(RuntimeError, match="startup failed"): + MCPEnvironment( + mcp_servers={ + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + } + ) + + assert MCPEnvironment._connection_manager is None + assert MCPEnvironment._connection_managers == {} + assert MCPEnvironment._server_specs == {} + assert mock_stop.call_count == 2 + + @patch.object(MCPConnectionManager, "stop", autospec=True) + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_invalid_tool_mapping_rolls_back_new_managers(self, mock_start, mock_stop): + """Test that routing validation failures clean up newly started managers.""" + mock_start.side_effect = make_start_side_effect( + { + "search-command": {"search": Mock()}, + "wiki-command": {"lookup": Mock()}, + } + ) + + with pytest.raises(ValueError, match="does not match any discovered tool"): + MCPEnvironment( + mcp_servers={ + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + }, + tool_name_to_server_name={"missing_tool": "search_server"}, + ) + + assert MCPEnvironment._connection_manager is None + assert MCPEnvironment._connection_managers == {} + assert MCPEnvironment._server_specs == {} + assert mock_stop.call_count == 2 + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_from_dict_with_mcp_servers(self, mock_start): + """Test creating an environment from dictionary with multi-server config.""" + mock_start.side_effect = make_start_side_effect( + { + "search-command": {"search": Mock()}, + "wiki-command": {"lookup": Mock()}, + } + ) + env_args = { + "question": "Test question", + "mcp_servers": { + "search_server": {"command": "search-command"}, + "wiki_server": {"command": "wiki-command"}, + }, + "tool_name_to_server_name": {"lookup": "wiki_server"}, + "max_steps": 15, + "reward_fn": MockRewardFunction(), + } + + env = MCPEnvironment.from_dict(env_args) + + assert isinstance(env, MCPEnvironment) + assert env.task == {"question": "Test question"} + assert env.max_steps == 15 + assert set(env.mcp_servers) == {"search_server", "wiki_server"} + assert env.tool_name_to_server_name == {"lookup": "wiki_server"} + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_from_dict_does_not_mutate_input(self, mock_start): + """Test that from_dict does not mutate the provided env_args dictionary.""" + mock_start.side_effect = make_start_side_effect({"search-command": {"search": Mock()}}) + env_args = { + "question": "Test question", + "mcp_servers": 
{"search_server": {"command": "search-command"}}, + "tool_name_to_server_name": {"search": "search_server"}, + "max_steps": 7, + } + expected_env_args = { + "question": "Test question", + "mcp_servers": {"search_server": {"command": "search-command"}}, + "tool_name_to_server_name": {"search": "search_server"}, + "max_steps": 7, + } + + MCPEnvironment.from_dict(env_args) + + assert env_args == expected_env_args + + @patch.object(MCPConnectionManager, "start", autospec=True) + def test_hyphenated_tool_aliases_are_checked_for_duplicates(self, mock_start): + """Test that hyphenated tool aliases participate in duplicate detection.""" + mock_start.side_effect = make_start_side_effect( + { + "command-a": {"search-tool": Mock(), "search_tool": Mock()}, + "command-b": {"search-tool": Mock(), "search_tool": Mock()}, + } + ) + + with pytest.raises(ValueError, match="Tool 'search-tool' is provided by multiple MCP servers"): + MCPEnvironment( + mcp_servers={ + "server_a": {"command": "command-a"}, + "server_b": {"command": "command-b"}, + } + ) diff --git a/tests/envs/test_tool_env.py b/tests/envs/test_tool_env.py index 7a834cecd..8cb25c5d9 100644 --- a/tests/envs/test_tool_env.py +++ b/tests/envs/test_tool_env.py @@ -235,7 +235,10 @@ def test_execute_tool_calls_multiple_tools(self): mock_tool_output = ToolOutput(name="mock_tool", output="Mock result") env.tools.forward = Mock(return_value=mock_tool_output) - tool_calls = [{"id": "call_1", "function": {"name": "mock_tool", "arguments": json.dumps({"query": "test1"})}}, {"id": "call_2", "function": {"name": "mock_tool", "arguments": json.dumps({"query": "test2"})}}] + tool_calls = [ + {"id": "call_1", "function": {"name": "mock_tool", "arguments": json.dumps({"query": "test1"})}}, + {"id": "call_2", "function": {"name": "mock_tool", "arguments": json.dumps({"query": "test2"})}}, + ] result = env._execute_tool_calls(tool_calls) @@ -259,7 +262,10 @@ def mock_forward(*args, **kwargs): env.tools.forward = mock_forward - tool_calls = [{"id": "call_1", "function": {"name": "mock_tool", "arguments": json.dumps({"query": "test1"})}}, {"id": "call_2", "function": {"name": "mock_tool", "arguments": json.dumps({"query": "test2"})}}] + tool_calls = [ + {"id": "call_1", "function": {"name": "mock_tool", "arguments": json.dumps({"query": "test1"})}}, + {"id": "call_2", "function": {"name": "mock_tool", "arguments": json.dumps({"query": "test2"})}}, + ] result = env._execute_tool_calls(tool_calls) @@ -394,7 +400,10 @@ def slow_forward(*args, **kwargs): env.tools.forward = slow_forward - tool_calls = [{"id": "call_1", "function": {"name": "slow_tool", "arguments": json.dumps({"query": "test1"})}}, {"id": "call_2", "function": {"name": "slow_tool", "arguments": json.dumps({"query": "test2"})}}] + tool_calls = [ + {"id": "call_1", "function": {"name": "slow_tool", "arguments": json.dumps({"query": "test1"})}}, + {"id": "call_2", "function": {"name": "slow_tool", "arguments": json.dumps({"query": "test2"})}}, + ] start_time = time.time() result = env._execute_tool_calls(tool_calls) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index c4879a5a6..7b10c65a3 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -18,4 +18,4 @@ requires_tinker = pytest.mark.skipif( not TINKER_API_KEY, reason="TINKER_API_KEY env var required", -) \ No newline at end of file +) diff --git a/tests/integration/test_agentcore_runtime.py b/tests/integration/test_agentcore_runtime.py index 2bcab0ae7..fa34a122f 100644 --- 
a/tests/integration/test_agentcore_runtime.py +++ b/tests/integration/test_agentcore_runtime.py @@ -47,7 +47,9 @@ async def test_single_task(self): """Submit a GSM8K problem, verify success=True, reward is not None.""" runtime = _make_runtime() sub = _make_submission( - prompt=("Toula went to the bakery and bought various types of pastries. She bought 3 dozen donuts which cost $68 per dozen, 2 dozen mini cupcakes which cost $80 per dozen, and 6 dozen mini cheesecakes for $55 per dozen. How much was the total cost?"), + prompt=( + "Toula went to the bakery and bought various types of pastries. She bought 3 dozen donuts which cost $68 per dozen, 2 dozen mini cupcakes which cost $80 per dozen, and 6 dozen mini cheesecakes for $55 per dozen. How much was the total cost?" + ), answer="694", ) diff --git a/tests/integration/test_minimax_integration.py b/tests/integration/test_minimax_integration.py index 9f3529219..db34e4116 100644 --- a/tests/integration/test_minimax_integration.py +++ b/tests/integration/test_minimax_integration.py @@ -9,7 +9,7 @@ import pytest -from rllm.experimental.eval.config import get_provider_info, load_config, save_config, RllmConfig +from rllm.experimental.eval.config import RllmConfig, load_config, save_config from rllm.experimental.eval.proxy import EvalProxyManager MINIMAX_API_KEY = os.environ.get("MINIMAX_API_KEY") diff --git a/tests/rewards/test_code_reward.py b/tests/rewards/test_code_reward.py index 47ccbf91f..00386b1c2 100644 --- a/tests/rewards/test_code_reward.py +++ b/tests/rewards/test_code_reward.py @@ -358,7 +358,9 @@ def test_reward_leetcode(self): class Solution:\n def minOperations(self, nums: List[int], k: int) -> int:\n is_added = [False] * k\n count = 0\n n = len(nums)\n for i in range(n - 1, -1, -1):\n if nums[i] > k or is_added[nums[i] - 1]:\n continue\n is_added[nums[i] - 1] = True\n count += 1\n if count == k:\n return n - i\n ``` """ - tests = {"functional": "def check(candidate):\n assert candidate(nums = [3,1,5,4,2], k = 2) == 4\n assert candidate(nums = [3,1,5,4,2], k = 5) == 5\n assert candidate(nums = [3,2,5,3,1], k = 3) == 4\n\n\ncheck(Solution().minOperations)"} + tests = { + "functional": "def check(candidate):\n assert candidate(nums = [3,1,5,4,2], k = 2) == 4\n assert candidate(nums = [3,1,5,4,2], k = 5) == 5\n assert candidate(nums = [3,2,5,3,1], k = 3) == 4\n\n\ncheck(Solution().minOperations)" + } reward = RewardCodeFn(RewardConfig()) task_info = {"problem": "", "problem_type": RewardType.CODE, "data_source": "leetcode", "ground_truth": tests} output = reward(task_info, model_response) @@ -373,7 +375,9 @@ def test_reward_leetcode_format_error(self): Here is my bad response, it is not in markdown oops class Solution:\n def minOperations(self, nums: List[int], k: int) -> int:\n is_added = [False] * k\n count = 0\n n = len(nums)\n for i in range(n - 1, -1, -1):\n if nums[i] > k or is_added[nums[i] - 1]:\n continue\n is_added[nums[i] - 1] = True\n count += 1\n if count == k:\n return n - i\n """ - tests = {"functional": "def check(candidate):\n assert candidate(nums = [3,1,5,4,2], k = 2) == 4\n assert candidate(nums = [3,1,5,4,2], k = 5) == 5\n assert candidate(nums = [3,2,5,3,1], k = 3) == 4\n\n\ncheck(Solution().minOperations)"} + tests = { + "functional": "def check(candidate):\n assert candidate(nums = [3,1,5,4,2], k = 2) == 4\n assert candidate(nums = [3,1,5,4,2], k = 5) == 5\n assert candidate(nums = [3,2,5,3,1], k = 3) == 4\n\n\ncheck(Solution().minOperations)" + } reward = RewardCodeFn(RewardConfig()) task_info = {"problem": 
"", "problem_type": RewardType.CODE, "data_source": "leetcode", "ground_truth": tests} output = reward(task_info, model_response) diff --git a/tests/test_verl_imports.py b/tests/test_verl_imports.py new file mode 100644 index 000000000..914bb16e8 --- /dev/null +++ b/tests/test_verl_imports.py @@ -0,0 +1,50 @@ +"""Test that verl import paths are compatible with verl 0.7.1+. + +Regression test for https://github.com/rllm-org/rllm/issues/470 +After verl 0.7.1 restructured module paths, several imports in the +fully_async module broke with ModuleNotFoundError. +""" + +import importlib + +import pytest + + +@pytest.mark.parametrize( + "module_path,names", + [ + ( + "rllm.experimental.fully_async.runner", + ["AsyncAgentTrainer", "FullyAsyncTaskRunner"], + ), + ( + "rllm.experimental.fully_async.fully_async_trainer", + ["FullyAsyncTrainer"], + ), + ( + "rllm.experimental.fully_async.inference_manager", + ["InferenceManager"], + ), + ], +) +def test_fully_async_imports(module_path: str, names: list[str]) -> None: + """Verify fully_async modules can be imported without ModuleNotFoundError.""" + mod = importlib.import_module(module_path) + for name in names: + assert hasattr(mod, name), f"{module_path} is missing attribute {name}" + + +@pytest.mark.parametrize( + "module_path,name", + [ + ("verl.experimental.separation.ray_trainer", "SeparateRayPPOTrainer"), + ("verl.experimental.separation.utils", "create_resource_pool_manager"), + ("verl.experimental.separation.utils", "create_role_worker_mapping"), + ("verl.experimental.agent_loop", "AgentLoopManager"), + ("verl.utils.net_utils", "get_free_port"), + ], +) +def test_verl_module_paths_exist(module_path: str, name: str) -> None: + """Verify the verl module paths used by rllm exist in the installed verl package.""" + mod = importlib.import_module(module_path) + assert hasattr(mod, name), f"{module_path} is missing attribute {name}" diff --git a/tests/tools/test_tools.py b/tests/tools/test_tools.py index f28a44d6c..a989a2a4f 100644 --- a/tests/tools/test_tools.py +++ b/tests/tools/test_tools.py @@ -118,10 +118,22 @@ async def main(): # Test queries for search tools search_test_cases = { - "google_search": [{"name": "Basic Search", "query": "What is Python programming?", "expected_fields": ["title", "snippet", "link"]}, {"name": "Technical Search", "query": "Python async await syntax example", "expected_fields": ["title", "snippet", "link"]}], - # "tavily_search": [{"name": "News Search", "query": "Latest developments in AI", "expected_fields": ["title", "snippet", "url"]}, {"name": "Technical Search", "query": "Python type hints tutorial", "expected_fields": ["title", "snippet", "url"]}], - # "tavily_extract": [{"name": "Python.org", "url": "https://www.python.org/about/", "expected_fields": ["title", "text"]}, {"name": "Python Docs", "url": "https://docs.python.org/3/tutorial/", "expected_fields": ["title", "text"]}], - # "firecrawl": [{"name": "Python.org", "url": "https://www.python.org", "expected_fields": ["title", "text", "links"]}, {"name": "Python Docs", "url": "https://docs.python.org/3/", "expected_fields": ["title", "text", "links"]}], + "google_search": [ + {"name": "Basic Search", "query": "What is Python programming?", "expected_fields": ["title", "snippet", "link"]}, + {"name": "Technical Search", "query": "Python async await syntax example", "expected_fields": ["title", "snippet", "link"]}, + ], + # "tavily_search": [ + # {"name": "News Search", "query": "Latest developments in AI", "expected_fields": ["title", "snippet", "url"]}, + # 
{"name": "Technical Search", "query": "Python type hints tutorial", "expected_fields": ["title", "snippet", "url"]}, + # ], + # "tavily_extract": [ + # {"name": "Python.org", "url": "https://www.python.org/about/", "expected_fields": ["title", "text"]}, + # {"name": "Python Docs", "url": "https://docs.python.org/3/tutorial/", "expected_fields": ["title", "text"]}, + # ], + # "firecrawl": [ + # {"name": "Python.org", "url": "https://www.python.org", "expected_fields": ["title", "text", "links"]}, + # {"name": "Python Docs", "url": "https://docs.python.org/3/", "expected_fields": ["title", "text", "links"]}, + # ], } diff --git a/tests/unified_trainer/test_algorithm_config.py b/tests/unified_trainer/test_algorithm_config.py new file mode 100644 index 000000000..b55cf09f8 --- /dev/null +++ b/tests/unified_trainer/test_algorithm_config.py @@ -0,0 +1,58 @@ +""" +Tests for AlgorithmConfig to verify norm_adv_by_std_in_grpo is read from +rllm.algorithm (not rllm.stepwise_advantage). + +See: https://github.com/rllm-org/rllm/issues/447 +""" + +import importlib.util +import os + +from omegaconf import OmegaConf + +# Import config module directly to avoid heavy transitive deps (codetiming, verl, etc.) +_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "../../rllm/experimental/common/config.py") +_spec = importlib.util.spec_from_file_location("rllm_common_config", _CONFIG_PATH) +_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_mod) +AlgorithmConfig = _mod.AlgorithmConfig + + +def _make_config(norm_adv_by_std_in_grpo: bool = True): + """Build a minimal DictConfig mirroring the real rllm config structure.""" + return OmegaConf.create( + { + "algorithm": { + "adv_estimator": "grpo", + }, + "rllm": { + "algorithm": { + "adv_estimator": "grpo", + "norm_adv_by_std_in_grpo": norm_adv_by_std_in_grpo, + "use_precomputed_advantage": False, + "loss_fn": None, + "lr_schedule": "constant", + "warmup_steps_ratio": 0.0, + }, + "stepwise_advantage": { + "mode": "broadcast", + # Intentionally omit norm_adv_by_std_in_grpo here to confirm + # the code reads from rllm.algorithm, not stepwise_advantage. + }, + }, + } + ) + + +def test_norm_adv_by_std_in_grpo_true_from_algorithm(): + """norm_adv_by_std_in_grpo=True is read from rllm.algorithm, not stepwise_advantage.""" + config = _make_config(norm_adv_by_std_in_grpo=True) + algo_config = AlgorithmConfig.from_config(config) + assert algo_config.norm_adv_by_std_in_grpo is True + + +def test_norm_adv_by_std_in_grpo_false_from_algorithm(): + """norm_adv_by_std_in_grpo=False is read from rllm.algorithm, not stepwise_advantage.""" + config = _make_config(norm_adv_by_std_in_grpo=False) + algo_config = AlgorithmConfig.from_config(config) + assert algo_config.norm_adv_by_std_in_grpo is False