Merged
18 commits
5b54789
fix(trainer): supplement dfed770 by adding missing update_weights in …
MarkJoson Apr 2, 2026
829da71
Fix norm_adv_by_std_in_grpo read from algorithm not stepwise_advantage
JiwaniZakir Apr 2, 2026
662f2f6
Add multi-server support to MCPEnvironment
taivu1998 Apr 2, 2026
3bbe160
fix: update verl import paths for verl 0.7.1+ compatibility
Lidang-Jiang Apr 3, 2026
4328573
test: add import path verification tests for verl 0.7.1
Lidang-Jiang Apr 3, 2026
61a5145
Merge pull request #471 from JiwaniZakir/fix/447-norm-adv-by-std-in-g…
kylemontgomery1 Apr 4, 2026
4520ad7
Merge pull request #480 from Lidang-Jiang/fix/verl-import-path
kylemontgomery1 Apr 4, 2026
4f49efc
Merge pull request #476 from taivu1998/tdv/issue-321-multi-mcp
kylemontgomery1 Apr 4, 2026
af297ca
additional fixes of sdk trainer
MarkJoson Apr 4, 2026
ec8bd7a
fix: migrate VerlBackend to new EngineWorker path (verl 0.7.1) (#483)
listar2000 Apr 4, 2026
d6101dc
feat: add hf_template tokenize_and_mask method + verl SFTTrainer compat
yifannnwu Apr 4, 2026
bc54009
fix: handle signal.signal ValueError in non-main threads (#484)
yifannnwu Apr 4, 2026
19618b2
Merge pull request #469 from MarkJoson/fix-sdk-rollout-engine-crash
kylemontgomery1 Apr 4, 2026
e5b81c1
Merge pull request #485 from yifannnwu/feat/sft-hf-template
kylemontgomery1 Apr 4, 2026
7fb0450
fix: resolve CI failures — E501 lint, tinker test deps, disable Claud…
listar2000 Apr 5, 2026
ac1852d
style: auto-format 21 files to fix ruff-format pre-commit failures (#…
listar2000 Apr 5, 2026
6d97845
Integrate fully async training to UnifiedTrainer (#481)
kylemontgomery1 Apr 5, 2026
86c5c7f
fix(verl): disable vllm compile cache to work around corruption bug (…
luyuzhe111 Apr 6, 2026
6 changes: 1 addition & 5 deletions .github/workflows/claude-code-review.yml
@@ -12,11 +12,7 @@ on:

jobs:
claude-review:
# Optional: Filter by PR author
# if: |
# github.event.pull_request.user.login == 'external-contributor' ||
# github.event.pull_request.user.login == 'new-developer' ||
# github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR'
if: false # Disabled due to credential issue with CLAUDE_CODE_OAUTH_TOKEN

runs-on: ubuntu-latest
permissions:
6 changes: 1 addition & 5 deletions .github/workflows/claude.yml
@@ -12,11 +12,7 @@ on:

jobs:
claude:
if: |
(github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) ||
(github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) ||
(github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) ||
(github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))
if: false # Disabled due to credential issue with CLAUDE_CODE_OAUTH_TOKEN
runs-on: ubuntu-latest
permissions:
contents: read
17 changes: 12 additions & 5 deletions .github/workflows/pre-commit.yml
@@ -12,24 +12,31 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

with:
fetch-depth: 0

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'

- name: Cache pre-commit
uses: actions/cache@v3
with:
path: ~/.cache/pre-commit
key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
restore-keys: |
pre-commit-

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit

- name: Run pre-commit
run: pre-commit run --all-files
run: |
if [ "${{ github.event_name }}" = "pull_request" ]; then
pre-commit run --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }} --show-diff-on-fail
else
pre-commit run --all-files --show-diff-on-fail
fi
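The PR-scoped invocation above can be reproduced locally before pushing; the base ref below is an assumption, substitute your own:

```
# Run hooks only on files changed relative to the base branch
# (local equivalent of the pull_request branch of the CI step above).
# "origin/main" is an illustrative base ref.
pre-commit run --from-ref origin/main --to-ref HEAD --show-diff-on-fail
```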
2 changes: 1 addition & 1 deletion .github/workflows/test-tinker.yml
@@ -32,7 +32,7 @@ jobs:
uses: astral-sh/setup-uv@v4

- name: Install dependencies
run: uv sync --extra tinker
run: uv sync --extra tinker --extra dev

- name: Run Tinker engine tests
run: uv run pytest tests/engine/test_tinker_engine.py -v
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -6,4 +6,4 @@ repos:
args: ["--fix", "--show-fixes", "--output-format=full"]
exclude: ^.*\.(ipynb)$|^verl/.*$
- id: ruff-format
exclude: ^verl/.*$
exclude: ^.*\.(ipynb)$|^verl/.*$
29 changes: 28 additions & 1 deletion examples/archive/mcp/README.md
@@ -7,7 +7,7 @@ This example demonstrates how to use external MCP servers as tool providers with
```bash
# Install MCP CLI (if needed for other MCP servers)
uv pip install mcp

```

## Files

@@ -71,6 +71,33 @@ This will:
- **`MCPEnvironment`** - Environment that manages MCP server connections and tool execution
- **`MCPConnectionManager`** - Handles MCP server lifecycle and tool discovery

### Multiple MCP Servers

`MCPEnvironment` also supports routing tool calls across multiple named MCP servers:

```python
env = MCPEnvironment(
task={"question": "Find and summarize the latest updates."},
mcp_servers={
"search": {
"command": "npx",
"args": ["-y", "tavily-mcp@0.2.4"],
"env": {"TAVILY_API_KEY": "..."},
},
"filesystem": {
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
},
},
tool_name_to_server_name={
"tavily_search": "search",
"read_file": "filesystem",
},
)
```

Use `tool_name_to_server_name` to disambiguate when multiple servers expose the same public tool name; it also lets you register underscore aliases for tools whose original MCP names contain hyphens.
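The underscore-alias convention can be illustrated with a small sketch; the helper function and tool names below are hypothetical, not part of the `MCPEnvironment` API:

```python
def underscore_alias(tool_name: str) -> str:
    """Return the underscore alias for an MCP tool name containing hyphens."""
    return tool_name.replace("-", "_")

# Hypothetical routing table: a server exposes "web-search", and the
# underscore alias "web_search" is registered to route to the same server.
tool_name_to_server_name = {"web-search": "search"}
tool_name_to_server_name[underscore_alias("web-search")] = "search"

# Both the original hyphenated name and its alias now resolve to "search".
print(tool_name_to_server_name)
```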

### Integration with RLLM

The example follows standard RLLM patterns:
4 changes: 3 additions & 1 deletion examples/archive/sft/run_sft_model.py
@@ -30,7 +30,9 @@
agent_args = {
"tools": ["python"],
"parser_name": "qwen",
"system_prompt": ('You are an expert mathematician and programmer. Your goal is to solve challenging math problems, like those from the AIME competition, by breaking them down into logical steps and using Python code for calculations. Strive for clarity and efficiency.\n\nFollow this process for every problem:\n1. **Analyze the Problem**: Read the question carefully. Identify the key information, constraints, and what is being asked.\n2. **Think Step-by-Step**: In the `<think>` block, outline your plan. Decompose the problem into the smallest, most logical steps. **You must not write code or perform calculations in this block.** Your goal is to create a plan that will be executed by the Python tool.\n3. **Write Python Code**: In the `<tool_call>` block, write an efficient Python script to execute your plan. The tool expects a JSON object with `name` and `arguments` keys. The `arguments` should be a dictionary with a single `code` key. Ensure the code is self-contained, runs quickly, and prints the final result.\n4. **State the Final Answer**: After receiving the `<tool_result>`, verify it. Then, state the final answer clearly and concisely in the format \\boxed{answer}.\n\nHere is an example:\nQuestion: What is the largest prime factor of 25! ?\n<think>The problem asks for the largest prime factor of 25 factorial. The largest prime factor of n! is the largest prime number less than or equal to n. In this case, n=25. I will write a Python script to find the largest prime number less than or equal to 25.</think>\n<tool_call>\n{"name": "python", "arguments": {"code": "import math\\ndef is_prime(n):\\n if n <= 1:\\n return False\\n for i in range(2, int(math.sqrt(n)) + 1):\\n if n % i == 0:\\n return False\\n return True\\n\\ndef largest_prime_up_to(n):\\n for i in range(n, 1, -1):\\n if is_prime(i):\\n return i\\n return None\\n\\nprint(largest_prime_up_to(25))"}}\n</tool_call>\n<tool_result>\n23\n</tool_result>\nThe largest prime factor of 25! is the largest prime number less than or equal to 25. The answer is \\boxed{23}.'),
"system_prompt": (
'You are an expert mathematician and programmer. Your goal is to solve challenging math problems, like those from the AIME competition, by breaking them down into logical steps and using Python code for calculations. Strive for clarity and efficiency.\n\nFollow this process for every problem:\n1. **Analyze the Problem**: Read the question carefully. Identify the key information, constraints, and what is being asked.\n2. **Think Step-by-Step**: In the `<think>` block, outline your plan. Decompose the problem into the smallest, most logical steps. **You must not write code or perform calculations in this block.** Your goal is to create a plan that will be executed by the Python tool.\n3. **Write Python Code**: In the `<tool_call>` block, write an efficient Python script to execute your plan. The tool expects a JSON object with `name` and `arguments` keys. The `arguments` should be a dictionary with a single `code` key. Ensure the code is self-contained, runs quickly, and prints the final result.\n4. **State the Final Answer**: After receiving the `<tool_result>`, verify it. Then, state the final answer clearly and concisely in the format \\boxed{answer}.\n\nHere is an example:\nQuestion: What is the largest prime factor of 25! ?\n<think>The problem asks for the largest prime factor of 25 factorial. The largest prime factor of n! is the largest prime number less than or equal to n. In this case, n=25. I will write a Python script to find the largest prime number less than or equal to 25.</think>\n<tool_call>\n{"name": "python", "arguments": {"code": "import math\\ndef is_prime(n):\\n if n <= 1:\\n return False\\n for i in range(2, int(math.sqrt(n)) + 1):\\n if n % i == 0:\\n return False\\n return True\\n\\ndef largest_prime_up_to(n):\\n for i in range(n, 1, -1):\\n if is_prime(i):\\n return i\\n return None\\n\\nprint(largest_prime_up_to(25))"}}\n</tool_call>\n<tool_result>\n23\n</tool_result>\nThe largest prime factor of 25! is the largest prime number less than or equal to 25. The answer is \\boxed{23}.'
),
}
env_args = {
"tools": ["python"],
@@ -0,0 +1,28 @@
import hydra

from rllm.data.dataset import DatasetRegistry
from rllm.experimental.unified_trainer import AgentTrainer
from rllm.rewards.countdown_reward import countdown_reward_fn
from rllm.workflows.simple_workflow import SimpleWorkflow


@hydra.main(config_path="pkg://rllm.experimental.config", config_name="unified", version_base=None)
def main(config):
train_dataset = DatasetRegistry.load_dataset("countdown", "train")
test_dataset = DatasetRegistry.load_dataset("countdown", "test")

trainer = AgentTrainer(
workflow_class=SimpleWorkflow,
workflow_args={
"reward_function": countdown_reward_fn,
},
config=config,
train_dataset=train_dataset,
val_dataset=test_dataset,
backend="tinker",
)
trainer.train()


if __name__ == "__main__":
main()
@@ -0,0 +1,36 @@
set -x

python -m examples.countdown.unified_trainer.train_countdown_unified_tinker \
rllm/backend=tinker \
model.name=Qwen/Qwen3-8B \
model.lora_rank=32 \
training.group_size=8 \
training.learning_rate=2e-5 \
training.max_length=4096 \
sampling.train.temperature=1.0 \
sampling.train.top_p=1.0 \
sampling.val.temperature=1.0 \
sampling.val.top_p=1.0 \
validation.group_size=1 \
rllm.workflow.n_parallel_tasks=256 \
rllm.workflow.retry_limit=1 \
rllm.workflow.raise_on_error=false \
data.max_prompt_length=2048 \
data.max_response_length=2048 \
data.train_batch_size=1 \
data.val_batch_size=1024 \
rllm.algorithm.adv_estimator=grpo \
rllm.algorithm.norm_adv_by_std_in_grpo=true \
rllm.async_training.enable=true \
rllm.async_training.mini_batch_size=32 \
rllm.async_training.fwd_bwd_group_size=8 \
rllm.async_training.staleness_threshold=0.5 \
rllm.async_training.trigger_parameter_sync_step=1 \
rllm.async_training.partial_rollout=true \
rllm.trainer.total_epochs=1 \
rllm.trainer.logger='[wandb]' \
rllm.trainer.project_name='rllm-countdown' \
rllm.trainer.experiment_name='countdown-tinker-async-staleness-0.5' \
rllm.trainer.val_before_train=true \
rllm.trainer.test_freq=10 \
rllm.trainer.save_freq=-1
@@ -0,0 +1,31 @@
set -x

python -m examples.countdown.unified_trainer.train_countdown_unified_tinker \
rllm/backend=tinker \
model.name=Qwen/Qwen3-8B \
model.lora_rank=32 \
training.group_size=8 \
training.learning_rate=2e-5 \
training.max_length=4096 \
sampling.train.temperature=1.0 \
sampling.train.top_p=1.0 \
sampling.val.temperature=1.0 \
sampling.val.top_p=1.0 \
validation.group_size=1 \
rllm.workflow.n_parallel_tasks=256 \
rllm.workflow.retry_limit=1 \
rllm.workflow.raise_on_error=false \
data.max_prompt_length=2048 \
data.max_response_length=2048 \
data.train_batch_size=32 \
data.val_batch_size=1024 \
rllm.algorithm.adv_estimator=grpo \
rllm.algorithm.norm_adv_by_std_in_grpo=true \
rllm.async_training.enable=false \
rllm.trainer.total_epochs=1 \
rllm.trainer.logger='[wandb]' \
rllm.trainer.project_name='rllm-countdown' \
rllm.trainer.experiment_name='countdown-tinker-sync' \
rllm.trainer.val_before_train=true \
rllm.trainer.test_freq=10 \
rllm.trainer.save_freq=-1
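Both launch scripts set `rllm.algorithm.norm_adv_by_std_in_grpo`. As a hedged sketch (not the rLLM implementation), this flag typically toggles whether group-relative advantages are divided by the group's reward standard deviation in the standard GRPO estimator:

```python
import statistics


def grpo_advantages(rewards: list[float], norm_by_std: bool = True, eps: float = 1e-6) -> list[float]:
    """Group-relative advantages: subtract the group's mean reward,
    optionally dividing by the group's std (GRPO-style normalization)."""
    mean = statistics.mean(rewards)
    adv = [r - mean for r in rewards]
    if norm_by_std:
        std = statistics.pstdev(rewards)
        adv = [a / (std + eps) for a in adv]
    return adv


# One group of 4 rollouts for the same task, with binary rewards:
print(grpo_advantages([1.0, 0.0, 1.0, 0.0], norm_by_std=False))  # [0.5, -0.5, 0.5, -0.5]
```

With `norm_by_std=True` the same group yields advantages near ±1.0, since the group std here is 0.5.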
25 changes: 22 additions & 3 deletions examples/deepcoder/prepare_deepcoder_data.py
@@ -7,8 +7,19 @@


def prepare_deepcoder_data(train_size: int = None, test_size: int = None):
train_dataset = concatenate_datasets([load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="primeintellect", split="train"), load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="taco", split="train"), load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="lcbv5", split="train")])
test_dataset = concatenate_datasets([load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="codeforces", split="test"), load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="lcbv5", split="test")])
train_dataset = concatenate_datasets(
[
load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="primeintellect", split="train"),
load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="taco", split="train"),
load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="lcbv5", split="train"),
]
)
test_dataset = concatenate_datasets(
[
load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="codeforces", split="test"),
load_dataset("agentica-org/DeepCoder-Preview-Dataset", name="lcbv5", split="test"),
]
)

def preprocess_fn(example, idx):
starter_code = example.get("starter_code", "")
@@ -39,7 +50,15 @@ def preprocess_fn(example, idx):
else:
test["metadata"] = {"func_name": None}

return {"question": question, "ground_truth": json.dumps(tests), "data_source": "livecodebench", "uid": f"deepcoder_{idx}", "index": idx, "starter_code": starter_code, "metadata": json.dumps(metadata)}
return {
"question": question,
"ground_truth": json.dumps(tests),
"data_source": "livecodebench",
"uid": f"deepcoder_{idx}",
"index": idx,
"starter_code": starter_code,
"metadata": json.dumps(metadata),
}

if train_size:
train_dataset = train_dataset.select(range(min(train_size, len(train_dataset))))
4 changes: 3 additions & 1 deletion examples/fully_async/deepresearch/refine_agent.py
@@ -70,7 +70,9 @@ async def end_request(self, success: bool, latency: float):
# Log periodically
if (self.total_completed + self.total_failed) % 50 == 0:
avg_latency = self.total_latency / max(1, self.total_completed)
logger.info(f"[STATS] In-flight: {self.in_flight}, Completed: {self.total_completed}, Failed: {self.total_failed}, Latency(avg/min/max): {avg_latency:.2f}s/{self.min_latency:.2f}s/{self.max_latency:.2f}s")
logger.info(
f"[STATS] In-flight: {self.in_flight}, Completed: {self.total_completed}, Failed: {self.total_failed}, Latency(avg/min/max): {avg_latency:.2f}s/{self.min_latency:.2f}s/{self.max_latency:.2f}s"
)

async def get_stats(self) -> dict:
async with self._lock:
23 changes: 22 additions & 1 deletion examples/fully_async/deepresearch/search_agent.py
@@ -242,7 +242,28 @@ async def run(self, question):
final_answer = extract_boxed_answer(content)

# Aggregate metrics across all tool calls
aggregated_metrics = {"num_turns": num_turns, "total_parse_tool_args_error": sum(m.get("parse_tool_args_error", 0) for m in metrics), "total_tool_return_error": sum(m.get("tool_return_error", 0) for m in metrics), "total_tool_calls": sum(m.get("tool_calls", 0) for m in metrics), "total_tool_wait_time": sum(m.get("tool_wait_time", 0) for m in metrics), "total_refine_time": sum(m.get("refine_time", 0) for m in metrics), "avg_refine_time": sum(m.get("refine_time", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1), "total_query_length": sum(m.get("query_length", 0) for m in metrics), "avg_query_length": sum(m.get("query_length", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1), "total_generation_time": total_generation_time, "total_completion_tokens": total_completion_tokens, "total_tool_tokens": sum(m.get("tool_tokens", 0) for m in metrics), "avg_completion_tokens_per_turn": total_completion_tokens / max(num_turns, 1), "avg_tool_tokens_per_call": sum(m.get("tool_tokens", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1), "duplicate_search_detected": duplicate_search_detected, "excessive_parallel_calls": excessive_parallel_calls, "tool_error_detected": tool_error_detected, "refine_error_detected": refine_error_detected, "overlong": overlong, "merged_step": len(trajectory.merge())}
aggregated_metrics = {
"num_turns": num_turns,
"total_parse_tool_args_error": sum(m.get("parse_tool_args_error", 0) for m in metrics),
"total_tool_return_error": sum(m.get("tool_return_error", 0) for m in metrics),
"total_tool_calls": sum(m.get("tool_calls", 0) for m in metrics),
"total_tool_wait_time": sum(m.get("tool_wait_time", 0) for m in metrics),
"total_refine_time": sum(m.get("refine_time", 0) for m in metrics),
"avg_refine_time": sum(m.get("refine_time", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1),
"total_query_length": sum(m.get("query_length", 0) for m in metrics),
"avg_query_length": sum(m.get("query_length", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1),
"total_generation_time": total_generation_time,
"total_completion_tokens": total_completion_tokens,
"total_tool_tokens": sum(m.get("tool_tokens", 0) for m in metrics),
"avg_completion_tokens_per_turn": total_completion_tokens / max(num_turns, 1),
"avg_tool_tokens_per_call": sum(m.get("tool_tokens", 0) for m in metrics) / max(sum(m.get("tool_calls", 0) for m in metrics), 1),
"duplicate_search_detected": duplicate_search_detected,
"excessive_parallel_calls": excessive_parallel_calls,
"tool_error_detected": tool_error_detected,
"refine_error_detected": refine_error_detected,
"overlong": overlong,
"merged_step": len(trajectory.merge()),
}

if OVERLONG_FILTER and overlong:
for seq in trajectory.sequences:
4 changes: 3 additions & 1 deletion examples/math_tinker/rl_loop_tinker_original.py
@@ -191,7 +191,9 @@ def main(config: Config):
target_tokens = tokens[1:]
all_logprobs = [0.0] * ob_len + logprob
all_advantages = [0.0] * ob_len + [advantage] * (len(input_tokens) - ob_len)
assert len(input_tokens) == len(target_tokens) == len(all_logprobs) == len(all_advantages), f"len(input_tokens): {len(input_tokens)}, len(target_tokens): {len(target_tokens)}, len(all_logprobs): {len(all_logprobs)}, len(all_advantages): {len(all_advantages)}"
assert len(input_tokens) == len(target_tokens) == len(all_logprobs) == len(all_advantages), (
f"len(input_tokens): {len(input_tokens)}, len(target_tokens): {len(target_tokens)}, len(all_logprobs): {len(all_logprobs)}, len(all_advantages): {len(all_advantages)}"
)
datum = types.Datum(
model_input=types.ModelInput.from_ints(tokens=input_tokens),
loss_fn_inputs={
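The invariant asserted in the diff above can be seen in a minimal standalone sketch; the token ids and lengths below are illustrative, not taken from the trainer:

```python
# Next-token prediction shifts targets by one position; the observation
# (prompt) prefix gets zero logprobs/advantages so only sampled completion
# tokens contribute to the loss. All values are illustrative.
tokens = [101, 5, 7, 9, 102]   # prompt + sampled completion ids
ob_len = 2                     # observation (prompt) length
logprob = [-0.2, -0.3]         # one logprob per predicted completion token
advantage = 0.7

input_tokens = tokens[:-1]     # position i predicts token i + 1
target_tokens = tokens[1:]
all_logprobs = [0.0] * ob_len + logprob
all_advantages = [0.0] * ob_len + [advantage] * (len(input_tokens) - ob_len)

# The four parallel arrays must line up position-for-position.
assert len(input_tokens) == len(target_tokens) == len(all_logprobs) == len(all_advantages) == 4
```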