Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
f46aa54
init new feature on unified fully async design
listar2000 Mar 6, 2026
fd69d8f
add coordinator control and refactor queue
listar2000 Mar 8, 2026
fb85d2a
cherrypick Kyle's async design refinements from kyle/deepresearch
listar2000 Mar 9, 2026
f9f01e5
Refactor chat parser and migrate experimental rollout to engine (#435)
listar2000 Mar 10, 2026
9af8505
merge nightly
listar2000 Mar 10, 2026
9a9cb76
dump changes to rollout_engine into main file
listar2000 Mar 10, 2026
18ca0f4
refactor base rollout engine class to standardize gating behaviors
listar2000 Mar 11, 2026
764d0e1
make tinker backend fully compatible
listar2000 Mar 11, 2026
1da0085
merge Kyle's fork
kylemontgomery1 Mar 31, 2026
8a2db48
Merge remote-tracking branch 'origin/main' into unified-fully-async
kylemontgomery1 Mar 31, 2026
f77f94a
bump vllm, deepcopy msgs in Step's post_init
kylemontgomery1 Mar 31, 2026
46b3356
[wip] make fully-async unified trainer compatible with agent flow eng…
kylemontgomery1 Mar 31, 2026
497d35a
fix staleness thottling
kylemontgomery1 Mar 31, 2026
8170c7a
enfore concurrency across engines
kylemontgomery1 Mar 31, 2026
3e2eb8d
Merge remote-tracking branch 'origin/main' into unified-fully-async
kylemontgomery1 Apr 1, 2026
2f8e2f1
fix fully async, refactor metrics
kylemontgomery1 Apr 3, 2026
ec49de5
Merge origin/main into unified-fully-async
kylemontgomery1 Apr 4, 2026
0f01be7
revert engine/rollout to main, restore experimental/rollout engines
kylemontgomery1 Apr 4, 2026
c86083b
revert TinkerChatTemplateParser and parser changes for separate PR
kylemontgomery1 Apr 4, 2026
a5b8b4f
revert bypass_render_with_parser and tinker parser-related changes
kylemontgomery1 Apr 4, 2026
4b67829
remove engine/gateway-level gate mechanism
kylemontgomery1 Apr 4, 2026
bc7c37f
refactor: move task tracking to coordinator, revert validation rename…
kylemontgomery1 Apr 4, 2026
7550fda
restore load_balancer assertion in verl_engine, revert tool_base to main
kylemontgomery1 Apr 4, 2026
4f05c8e
fix: add future annotations to rollout_engine for TYPE_CHECKING imports
kylemontgomery1 Apr 4, 2026
44d95a5
Merge remote-tracking branch 'origin/main' into unified-fully-async
kylemontgomery1 Apr 5, 2026
7993243
style: fix ruff lint and format issues on unified-fully-async branch
kylemontgomery1 Apr 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import hydra

from rllm.data.dataset import DatasetRegistry
from rllm.experimental.unified_trainer import AgentTrainer
from rllm.rewards.countdown_reward import countdown_reward_fn
from rllm.workflows.simple_workflow import SimpleWorkflow


@hydra.main(config_path="pkg://rllm.experimental.config", config_name="unified", version_base=None)
def main(config):
    """Train a SimpleWorkflow agent on the countdown task using the Tinker backend.

    Hydra injects ``config`` from the ``unified`` config under
    ``rllm.experimental.config``; CLI overrides are applied before this runs.
    """
    # Pull both pre-registered countdown splits from the dataset registry.
    datasets = {split: DatasetRegistry.load_dataset("countdown", split) for split in ("train", "test")}

    # The workflow scores rollouts with the countdown reward function.
    trainer = AgentTrainer(
        workflow_class=SimpleWorkflow,
        workflow_args={"reward_function": countdown_reward_fn},
        config=config,
        train_dataset=datasets["train"],
        val_dataset=datasets["test"],
        backend="tinker",
    )
    trainer.train()


if __name__ == "__main__":
    main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Launch fully-async RL training on the countdown task with the Tinker backend.
# Trains Qwen/Qwen3-8B with LoRA (rank 32), GRPO advantage estimation, async
# training enabled (staleness threshold 0.5, partial rollouts), logging to wandb.
# NOTE(review): assumes it is invoked from the repo root so the `examples`
# module path resolves — confirm.

# Echo each command before executing, for easier debugging of launch args.
set -x

# Every `key=value` argument below is a Hydra config override consumed by the
# `unified` config loaded in train_countdown_unified_tinker.
python -m examples.countdown.unified_trainer.train_countdown_unified_tinker \
rllm/backend=tinker \
model.name=Qwen/Qwen3-8B \
model.lora_rank=32 \
training.group_size=8 \
training.learning_rate=2e-5 \
training.max_length=4096 \
sampling.train.temperature=1.0 \
sampling.train.top_p=1.0 \
sampling.val.temperature=1.0 \
sampling.val.top_p=1.0 \
validation.group_size=1 \
rllm.workflow.n_parallel_tasks=256 \
rllm.workflow.retry_limit=1 \
rllm.workflow.raise_on_error=false \
data.max_prompt_length=2048 \
data.max_response_length=2048 \
data.train_batch_size=1 \
data.val_batch_size=1024 \
rllm.algorithm.adv_estimator=grpo \
rllm.algorithm.norm_adv_by_std_in_grpo=true \
rllm.async_training.enable=true \
rllm.async_training.mini_batch_size=32 \
rllm.async_training.fwd_bwd_group_size=8 \
rllm.async_training.staleness_threshold=0.5 \
rllm.async_training.trigger_parameter_sync_step=1 \
rllm.async_training.partial_rollout=true \
rllm.trainer.total_epochs=1 \
rllm.trainer.logger='[wandb]' \
rllm.trainer.project_name='rllm-countdown' \
rllm.trainer.experiment_name='countdown-tinker-async-staleness-0.5' \
rllm.trainer.val_before_train=true \
rllm.trainer.test_freq=10 \
rllm.trainer.save_freq=-1
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Launch synchronous RL training on the countdown task with the Tinker backend.
# Same model/sampling setup as the async variant (Qwen/Qwen3-8B, LoRA rank 32,
# GRPO), but with `rllm.async_training.enable=false` and a train batch size of
# 32, logging to wandb.
# NOTE(review): assumes it is invoked from the repo root so the `examples`
# module path resolves — confirm.

# Echo each command before executing, for easier debugging of launch args.
set -x

# Every `key=value` argument below is a Hydra config override consumed by the
# `unified` config loaded in train_countdown_unified_tinker.
python -m examples.countdown.unified_trainer.train_countdown_unified_tinker \
rllm/backend=tinker \
model.name=Qwen/Qwen3-8B \
model.lora_rank=32 \
training.group_size=8 \
training.learning_rate=2e-5 \
training.max_length=4096 \
sampling.train.temperature=1.0 \
sampling.train.top_p=1.0 \
sampling.val.temperature=1.0 \
sampling.val.top_p=1.0 \
validation.group_size=1 \
rllm.workflow.n_parallel_tasks=256 \
rllm.workflow.retry_limit=1 \
rllm.workflow.raise_on_error=false \
data.max_prompt_length=2048 \
data.max_response_length=2048 \
data.train_batch_size=32 \
data.val_batch_size=1024 \
rllm.algorithm.adv_estimator=grpo \
rllm.algorithm.norm_adv_by_std_in_grpo=true \
rllm.async_training.enable=false \
rllm.trainer.total_epochs=1 \
rllm.trainer.logger='[wandb]' \
rllm.trainer.project_name='rllm-countdown' \
rllm.trainer.experiment_name='countdown-tinker-sync' \
rllm.trainer.val_before_train=true \
rllm.trainer.test_freq=10 \
rllm.trainer.save_freq=-1
34 changes: 32 additions & 2 deletions rllm/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import uuid
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel, ConfigDict, Field
Expand All @@ -23,6 +24,7 @@ class Step(_StepBase):
prompt_ids: list[int] | list[Any] = Field(default_factory=list)
response_ids: list[int] = Field(default_factory=list)
logprobs: list[float] = Field(default_factory=list)
routing_matrices: list[str] | None = None # per-token routing matrices (R3, transient)

chat_completions: list[dict[str, Any]] = Field(default_factory=list)

Expand All @@ -38,6 +40,9 @@ class Step(_StepBase):
# Per-token or scalar advantages
advantage: list[float] | float | None = None

# weight version at time of generation (for async training staleness tracking)
weight_version: int | None = None

@property
def info(self) -> dict:
"""Alias for metadata. Auto-initializes to {} if None so mutation works."""
Expand All @@ -50,6 +55,7 @@ def info(self, value: dict) -> None:
self.metadata = value

def model_post_init(self, __context: Any) -> None:
self.chat_completions = deepcopy(self.chat_completions)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This deepcopy was incorrectly removed during the refactor from dataclasses to pydantic. Many old workflows operate with:

for turn in max_turns:
    output: ModelOutput = await self.rollout_engine.get_model_response(messages)
    messages.append({"role": "assistant", "content": output.content, ...})
    trajectory.steps.append(Step(chat_completions=messages, model_output=output))

If chat completions is not deepcopied, then appending a message on a future turn would mutate a previous turn's step.chat_completions.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess in the future we should really have a rLLM built-in messages format & class (similar to Tinker's Message), and ensure (1) it's as easy to work with as a plain dictionary, while (2) every step only holds a "view" of it (so no need to keep lots of copies, while earlier steps are not affected).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, I think we can spend some time this week rethinking messages/parsers.

if self.model_output is None:
return
# backfill fields like prompt_ids, response_ids, logprobs, etc.
Expand All @@ -59,17 +65,35 @@ def model_post_init(self, __context: Any) -> None:
self.response_ids = self.model_output.completion_ids
if len(self.logprobs) == 0 and self.model_output.logprobs is not None:
self.logprobs = self.model_output.logprobs
if self.routing_matrices is None and getattr(self.model_output, "routing_matrices", None) is not None:
self.routing_matrices = self.model_output.routing_matrices
if self.weight_version is None and hasattr(self.model_output, "weight_version"):
self.weight_version = self.model_output.weight_version

# check that the lengths would match up
if len(self.logprobs) > 0:
assert len(self.response_ids) == len(self.logprobs), f"length mismatch between response_ids and logprobs, got {len(self.response_ids)}, {len(self.logprobs)}"

def to_dict(self) -> dict:
from rllm.tools.tool_base import ToolCall, ToolOutput

# Helper function to recursively convert ToolCall and ToolOutput objects to dicts
def _serialize_value(value):
if isinstance(value, ToolCall | ToolOutput):
return value.to_dict()
elif isinstance(value, list):
return [_serialize_value(item) for item in value]
elif isinstance(value, dict):
return {k: _serialize_value(v) for k, v in value.items()}
else:
return value

return {
"prompt_ids": self.prompt_ids,
"response_ids": self.response_ids,
"logprobs": self.logprobs,
"chat_completions": self.chat_completions,
"routing_matrices": self.routing_matrices,
"chat_completions": _serialize_value(self.chat_completions),
"observation": self.observation,
"thought": self.thought,
"action": self.action.action if isinstance(self.action, Action) else self.action,
Expand All @@ -80,6 +104,7 @@ def to_dict(self) -> dict:
"done": self.done,
"mc_return": self.mc_return,
"advantage": self.advantage,
"weight_version": self.weight_version,
}

@classmethod
Expand All @@ -90,6 +115,7 @@ def from_dict(cls, data: dict) -> Step:
prompt_ids=data["prompt_ids"],
response_ids=data["response_ids"],
logprobs=data["logprobs"],
routing_matrices=data.get("routing_matrices"),
chat_completions=data["chat_completions"],
observation=data["observation"],
thought=data["thought"],
Expand All @@ -100,7 +126,8 @@ def from_dict(cls, data: dict) -> Step:
reward=data["reward"],
done=data["done"],
mc_return=data["mc_return"],
advantage=data["advantage"],
advantage=data.get("advantage", 0.0),
weight_version=data.get("weight_version"),
)

@classmethod
Expand All @@ -109,11 +136,13 @@ def from_model_output(cls, model_output: ModelOutput, messages: list[dict] | Non
prompt_ids=model_output.prompt_ids or [],
response_ids=model_output.completion_ids or [],
logprobs=model_output.logprobs or [],
routing_matrices=getattr(model_output, "routing_matrices", None),
chat_completions=(messages or []) + [{"role": "assistant", "content": model_output.content, "reasoning": model_output.reasoning}],
thought=model_output.reasoning or "",
action=action,
model_response=model_output.content or "",
model_output=model_output,
weight_version=model_output.weight_version,
)


Expand Down Expand Up @@ -259,6 +288,7 @@ class TrajectoryGroup(BaseModel):
trajectories: list[Trajectory]
group_id: str = ""
metadata: list[dict] = Field(default_factory=list)
weight_version: int = 0

@property
def group_role(self) -> str:
Expand Down
Loading
Loading