Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/experimental/rllm-and-backend-config.md
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ This file contains:
| `rollout_engine.reasoning_effort` | `str` | `medium` | Reasoning effort mode |
| `rollout_engine.accumulate_reasoning` | `bool` | `false` | Whether to accumulate reasoning across steps |
| `rollout_engine.disable_thinking` | `bool` | `false` | Whether to disable thinking tokens |
| `rollout_engine.bypass_render_with_parser` | `bool` | `false` | Whether to bypass render parsing |
| `rollout_engine.renderer_name` | `str \| null` | `null` | Optional renderer name |
| `data.max_prompt_length` | `int` | `2048` | Max prompt length |
| `data.max_response_length` | `int` | `2048` | Max response length |
Expand Down
1 change: 0 additions & 1 deletion examples/countdown/train_countdown_distill_tinker.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,3 @@ python -m examples.countdown.train_countdown_tinker \
trainer.test_freq=10 \
trainer.save_freq=1000 \
trainer.default_local_dir='./outputs/countdown-distill-tinker-8b' \
rollout_engine.bypass_render_with_parser=True
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,4 @@ python -m examples.math_distill.opsd.train_deepmath_distill_tinker \
training.default_local_dir='./outputs/opsd-deepmath-8b-rllm' \
rllm.algorithm.use_precomputed_advantage=true \
rllm.algorithm.loss_fn=importance_sampling \
rollout_engine.bypass_render_with_parser=True \
rllm.workflow.n_parallel_tasks=512
1 change: 0 additions & 1 deletion examples/math_distill/train_deepmath_distill_tinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def main(config: DictConfig):
tokenizer=teacher_tokenizer,
service_client=teacher_service_client,
sampling_client=teacher_sampling_client,
bypass_render_with_parser=True,
)

trainer = AgentTrainer(
Expand Down
1 change: 0 additions & 1 deletion examples/math_distill/train_deepmath_distill_tinker.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,5 @@ python -m examples.math_distill.train_deepmath_distill_tinker \
training.default_local_dir='./outputs/deepmath-distill-8b-32b-unified' \
rllm.algorithm.use_precomputed_advantage=true \
rllm.algorithm.loss_fn=importance_sampling \
rollout_engine.bypass_render_with_parser=False \
rollout_engine.renderer_name=qwen3 \
rllm.workflow.n_parallel_tasks=512
4 changes: 2 additions & 2 deletions rllm/engine/agent_sdk_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,11 +444,11 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto":
self.rollout_engine.wake_up()

if batch.meta_info.get("validate", False):
self.rollout_engine.validate = True
self.rollout_engine.is_validation = True
tasks = batch.non_tensor_batch["extra_info"].tolist()
task_ids = batch.non_tensor_batch["task_ids"].tolist()
episodes = await self.execute_tasks(tasks, task_ids, **kwargs) # list of Episodes
self.rollout_engine.validate = False
self.rollout_engine.is_validation = False

if isinstance(self.rollout_engine, VerlEngine):
await self.rollout_engine.sleep()
Expand Down
4 changes: 2 additions & 2 deletions rllm/engine/agent_workflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,14 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto":

is_validation = batch.meta_info.get("validate", False)
if is_validation:
self.rollout_engine.validate = True
self.rollout_engine.is_validation = True
self.current_mode = "val"
else:
self.current_mode = "train"
tasks = batch.non_tensor_batch["extra_info"].tolist()
task_ids = batch.non_tensor_batch["task_ids"].tolist()
results = await self.execute_tasks(tasks, task_ids, **kwargs) # list of Episodes
self.rollout_engine.validate = False
self.rollout_engine.is_validation = False

await self.rollout_engine.sleep()

Expand Down
23 changes: 21 additions & 2 deletions rllm/engine/rollout/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
# Avoid importing concrete engines at module import time to prevent circular imports
from typing import TYPE_CHECKING

from .rollout_engine import ModelOutput, RolloutEngine
from .types import TinkerTokenInput, TinkerTokenOutput, TokenInput, Tokenizer, TokenOutput, VerlTokenInput, VerlTokenOutput

if TYPE_CHECKING:
from .tinker_engine import TinkerEngine
from .verl_engine import VerlEngine

__all__ = [
"ModelOutput",
"RolloutEngine",
"OpenAIEngine",
"TinkerEngine",
"VerlEngine",
# Token types
"TokenInput",
"TokenOutput",
"TinkerTokenInput",
"TinkerTokenOutput",
"VerlTokenInput",
"VerlTokenOutput",
"Tokenizer",
]


Expand All @@ -14,11 +29,15 @@ def __getattr__(name):
from .openai_engine import OpenAIEngine as _OpenAIEngine

return _OpenAIEngine
if name == "TinkerEngine":
from .tinker_engine import TinkerEngine as _TinkerEngine

return _TinkerEngine
if name == "VerlEngine":
try:
from .verl_engine import VerlEngine as _VerlEngine

return _VerlEngine
except Exception:
raise AttributeError(name) from None
raise AttributeError(name)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from typing import Any

from rllm.agents.agent import Step
from rllm.experimental.rollout.rollout_engine import ModelOutput, RolloutEngine
from rllm.experimental.rollout.types import TokenInput, Tokenizer, TokenOutput
from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine
from rllm.engine.rollout.types import TokenInput, Tokenizer, TokenOutput
from rllm.parser import ChatTemplateParser


Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(self, rollout_engine: RolloutEngine):
raise ValueError(f"The rollout engine {cls_name} does not support token-in-token-out")
# we also require the rollout engine has a chat parser and a tokenizer
if rollout_engine.chat_parser is None or rollout_engine.tokenizer is None:
raise ValueError("The rollout engine must have a chat parser and a tokenizer. For Tinker engine, make sure you have set bypass_render_with_parser=True.")
raise ValueError("The rollout engine must have a chat parser and a tokenizer.")
self.tokenizer = rollout_engine.tokenizer
self.chat_parser = rollout_engine.chat_parser

Expand Down
23 changes: 22 additions & 1 deletion rllm/engine/rollout/rollout_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass

from rllm.engine.rollout.types import TokenInput, Tokenizer, TokenOutput
from rllm.parser import ChatTemplateParser
from rllm.tools.tool_base import ToolCall


Expand All @@ -9,7 +11,7 @@ class ModelOutput:
content: str | None = None
reasoning: str | None = None
tool_calls: list[ToolCall] | None = None
prompt_ids: list[int] | None = None
prompt_ids: TokenInput | None = None
completion_ids: list[int] | None = None
multi_modal_inputs: dict[str, list] | None = None
logprobs: list[float] | None = None # completion logprobs
Expand Down Expand Up @@ -53,12 +55,31 @@ def from_dict(cls, data: dict):


class RolloutEngine:
chat_parser: ChatTemplateParser | None = None
tokenizer: Tokenizer | None = None
is_validation: bool = False # flag enabled/disabled by AgentWorkflowEngine.execute_tasks

def __init__(self, *args, **kwargs):
pass

async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput:
raise NotImplementedError("get_model_response is not implemented")

def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput) -> ModelOutput:
"""
Assemble model output from a token output.
"""
raise NotImplementedError("assemble_model_output is not implemented")

async def get_token_output_from_token_input(self, token_input: TokenInput, **kwargs) -> TokenOutput:
"""Obtain the token output from the given token input."""
raise NotImplementedError("get_token_output_from_token_input is not implemented")

@property
def supports_token_in_token_out(self) -> bool:
"""Whether the engine supports token-in-token-out (TITO) generation. Defaults to false."""
return False

async def wake_up(self):
pass

Expand Down
Loading