Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 185 additions & 73 deletions trl/experimental/async_grpo/async_rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from accelerate.logging import get_logger
from datasets import Dataset
from transformers import AutoTokenizer
from transformers.utils import get_json_schema

from trl.chat_template_utils import add_response_schema, get_training_chat_template, parse_response
from trl.import_utils import is_vllm_available
Expand All @@ -43,22 +44,94 @@
Messages: TypeAlias = list[dict[str, str]]


@dataclass(slots=True)
class ToolCallRecord:
tool_call_id: str | None
name: str
arguments: dict[str, Any]
result: str
failed: bool
duration: float # wall-clock seconds for this tool call


@dataclass(slots=True)
class TaggedMessage:
message: dict[str, Any] # raw chat message (has "role", "content", etc.)
is_completion: bool # True = generated by the model or resulting tool execution


@dataclass(slots=True)
class TurnRecord:
messages: list[TaggedMessage] # full turn in order: [context..., assistant, tool_results...]
generation_ids: list[int] # token ids for the assistant response only
generation_logprobs: list[float]
generation_duration: float # wall-clock seconds for the vLLM call

tool_calls: list[ToolCallRecord] # per-call records with timing (empty if no tool calls)
tool_response_ids: list[int] # token ids for the tool-result suffix (empty if no tool calls)
tool_execution_duration_s: float # total wall-clock seconds for all tool calls in this turn


@dataclass(slots=True)
class RolloutCompletion:
"""Result of a single (possibly multi-turn) generation for one prompt."""

turns: list[TurnRecord]
truncated: bool
total_duration_s: float

def get_completion_messages(self) -> Messages:
"""Return the completion messages (assistant + tool results) across all turns."""
return [tm.message for turn in self.turns for tm in turn.messages if tm.is_completion]

def get_completion_ids(self) -> list[int]:
"""Flattened token IDs for the completion (generation + tool response per turn)."""
ids: list[int] = []
for turn in self.turns:
ids.extend(turn.generation_ids)
ids.extend(turn.tool_response_ids)
return ids

def get_completion_logprobs(self) -> list[float]:
"""Flattened logprobs (real for generation tokens, 0.0 for tool response tokens)."""
logprobs: list[float] = []
for turn in self.turns:
logprobs.extend(turn.generation_logprobs)
logprobs.extend([0.0] * len(turn.tool_response_ids))
return logprobs

def get_tool_mask(self) -> list[int]:
"""1 for generation tokens (trainable), 0 for tool response tokens."""
mask: list[int] = []
for turn in self.turns:
mask.extend([1] * len(turn.generation_ids))
mask.extend([0] * len(turn.tool_response_ids))
return mask

def get_trajectory(self) -> list[TaggedMessage]:
"""Full trajectory in order — ready for export to dataset/file."""
return [tm for turn in self.turns for tm in turn.messages]


@dataclass(slots=True)
class RolloutGroup:
"""Single GRPO group for one prompt with multiple completions."""

prompt: Messages
prompt_ids: list[int]
reward_kwargs: dict[str, list[Any]]
completions: list[Messages]
completions_ids: list[list[int]]
completions_logprobs: list[list[float]]
tool_mask: list[list[int]]
tool_call_counts: list[int]
tool_failure_counts: list[int]
completions: list[RolloutCompletion]
model_version: int
queued_at: float = 0.0

def append_completion(self, result: RolloutCompletion):
self.completions.append(result)


def _build_completion(turns: list[TurnRecord], truncated: bool, total_duration: float) -> RolloutCompletion:
    """Package the recorded turns, truncation flag and total duration into a RolloutCompletion."""
    completion = RolloutCompletion(
        turns=turns,
        truncated=truncated,
        total_duration_s=total_duration,
    )
    return completion


@dataclass(slots=True)
class RolloutSample:
Expand Down Expand Up @@ -127,6 +200,7 @@ def __init__(
self.num_generations = num_generations
self.max_inflight_tasks = max_inflight_tasks
self.environments = None
self._is_done_methods = [None] * self.max_inflight_tasks
environment_methods = [[] for _ in range(self.max_inflight_tasks)]
if environment_factory is not None:
self.environments = [environment_factory() for _ in range(self.max_inflight_tasks)]
Expand All @@ -135,6 +209,8 @@ def __init__(
for name, member in inspect.getmembers(environment, predicate=inspect.ismethod):
if name == "reset":
has_reset = True
elif name == "is_done":
self._is_done_methods[i] = member
elif not name.startswith("_"):
environment_methods[i].append(member)
if not has_reset:
Expand All @@ -149,7 +225,10 @@ def __init__(
if inspect.iscoroutinefunction(tool):
raise ValueError("Asynchronous tools are not supported in AsyncRolloutWorker yet.")
self._sync_tool_dicts[i][tool.__name__] = tool
self.tools = base_tools + (environment_methods[0] if self.environments is not None else [])
self.tools = base_tools + (
# Pre-convert any bound methods to JSON schema dicts so they pass transformers' `isfunction` check.
[get_json_schema(t) for t in environment_methods[0]] if self.environments is not None else []
)

self.vllm_server_url = vllm_server_url.rstrip("/")
self.model_update_group = None
Expand Down Expand Up @@ -332,11 +411,6 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None:
prompt_ids=prompt_ids,
reward_kwargs=reward_kwargs,
completions=[],
completions_ids=[],
completions_logprobs=[],
tool_mask=[],
tool_call_counts=[],
tool_failure_counts=[],
model_version=self.model_version,
)
pending_completed[group_id] = 0
Expand All @@ -349,7 +423,11 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None:

logger.info(f"[slot] assigned slot={slot} group={group_id} free_after={len(free_slots)}")
task = asyncio.create_task(
self._generate_one(pending_groups[group_id].prompt, tool_dict=self._sync_tool_dicts[slot])
self._generate_one(
pending_groups[group_id].prompt,
tool_dict=self._sync_tool_dicts[slot],
is_done=self._is_done_methods[slot],
)
)
inflight_tasks[task] = (group_id, slot)

Expand All @@ -375,23 +453,10 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None:
if task.exception() is not None:
raise task.exception()

(
completion,
completion_ids,
completion_logprobs,
tool_mask,
tool_call_count,
tool_failure_count,
) = task.result()
result: RolloutCompletion = task.result()
group = pending_groups[group_id]
group.completions.append(completion)
group.completions_ids.append(completion_ids)
group.completions_logprobs.append(completion_logprobs)
group.tool_mask.append(tool_mask)
group.tool_call_counts.append(tool_call_count)
group.tool_failure_counts.append(tool_failure_count)
# TODO: move this in generation task, shouldn't matter but is correct
self._total_completion_tokens += sum(tool_mask)
group.append_completion(result)
self._total_completion_tokens += sum(len(t.generation_ids) for t in result.turns)
pending_completed[group_id] += 1

if pending_completed[group_id] == self.num_generations:
Expand Down Expand Up @@ -507,13 +572,11 @@ def _repeat_iterator(self) -> Iterator[tuple[int, dict[str, Any]]]:
group_id += 1

async def _generate_one(
self, prompt: Messages, tool_dict: dict[str, Callable]
) -> tuple[list[dict[str, str]], list[int], list[float], list[int], int, int]:
completion, completion_ids, completion_logprobs, tool_mask = [], [], [], []
tool_call_count = 0
tool_failure_count = 0
iteration_num = 0
max_iterations = self.max_tool_calling_iterations
self, prompt: Messages, tool_dict: dict[str, Callable], is_done: Callable[[], bool] | None = None
) -> RolloutCompletion:
turns: list[TurnRecord] = []
max_num_turns = self.max_tool_calling_iterations
t_start = time.monotonic()
prompt_ids = self.tokenizer.apply_chat_template(
prompt,
return_dict=False,
Expand All @@ -522,26 +585,60 @@ async def _generate_one(
chat_template=self.chat_template,
**self.chat_template_kwargs,
)

iteration_num = 0
# context_messages for the current turn: initial prompt for turn 0, tool results for turn N>0
context_messages: Messages = list(prompt)
while True:
t_gen = time.monotonic()
turn_ids, turn_logprobs = await self._generate_one_turn(prompt_ids)
generation_duration = time.monotonic() - t_gen

assistant_message = parse_response(self.tokenizer, turn_ids)
completion.append(assistant_message)
completion_ids.extend(turn_ids)
completion_logprobs.extend(turn_logprobs)
tool_mask.extend([1] * len(turn_ids))
tool_calls = assistant_message.get("tool_calls")
if tool_calls is None or (max_iterations is not None and iteration_num >= max_iterations):
return completion, completion_ids, completion_logprobs, tool_mask, tool_call_count, tool_failure_count

tool_messages, n_calls, n_failures = self._execute_tool_calls(tool_calls, tool_dict)
tool_call_count += n_calls
tool_failure_count += n_failures
completion.extend(tool_messages)
tool_calls_raw = assistant_message.get("tool_calls")

if tool_calls_raw is None or (max_num_turns is not None and iteration_num >= max_num_turns):
messages = [TaggedMessage(m, is_completion=False) for m in context_messages]
messages.append(TaggedMessage(assistant_message, is_completion=True))
turns.append(
TurnRecord(
messages=messages,
generation_ids=turn_ids,
generation_logprobs=turn_logprobs,
generation_duration=generation_duration,
tool_calls=[],
tool_response_ids=[],
tool_execution_duration_s=0.0,
)
)
return _build_completion(turns, truncated=False, total_duration=time.monotonic() - t_start)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

truncated always False even when max turns exceeded

Low Severity

The truncated field on RolloutCompletion is always set to False across all return paths. On line 600, two semantically different conditions are combined with or: (1) tool_calls_raw is None (natural completion) and (2) iteration_num >= max_num_turns (forced stop). When the model produces tool calls but max_num_turns forces an early exit, truncated is incorrectly set to False — the generation was cut short, which is truncation. The truncated value needs to depend on which sub-condition triggered the exit (e.g., truncated=(tool_calls_raw is not None)).

Fix in Cursor Fix in Web


t_tools = time.monotonic()
tool_call_records, tool_messages = self._execute_tool_calls(tool_calls_raw, tool_dict)
tool_execution_duration = time.monotonic() - t_tools

tool_suffix_ids = self._build_messages_suffix_ids(tool_messages)
completion_ids.extend(tool_suffix_ids)
completion_logprobs.extend([0.0] * len(tool_suffix_ids))
tool_mask.extend([0] * len(tool_suffix_ids))
messages = [TaggedMessage(m, is_completion=False) for m in context_messages]
messages.append(TaggedMessage(assistant_message, is_completion=True))
messages.extend(TaggedMessage(m, is_completion=True) for m in tool_messages)
turns.append(
TurnRecord(
messages=messages,
generation_ids=turn_ids,
generation_logprobs=turn_logprobs,
generation_duration=generation_duration,
tool_calls=tool_call_records,
tool_response_ids=tool_suffix_ids,
tool_execution_duration_s=tool_execution_duration,
)
)
context_messages = list(tool_messages)

prompt_ids = prompt_ids + turn_ids + tool_suffix_ids

if is_done is not None and is_done():
return _build_completion(turns, truncated=False, total_duration=time.monotonic() - t_start)

iteration_num += 1

def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int]:
Expand Down Expand Up @@ -571,22 +668,34 @@ def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int

def _execute_tool_calls(
    self, tool_calls: list[dict[str, Any]], tool_dict: dict[str, Callable]
) -> tuple[list[ToolCallRecord], list[dict[str, str]]]:
    """Run each requested tool call and record its outcome.

    Args:
        tool_calls: Parsed tool calls; each entry has a "function" dict with a "name"
            and optional "arguments", and may carry an "id".
        tool_dict: Mapping from tool name to the callable to invoke.

    Returns:
        A pair `(records, tool_messages)`: one `ToolCallRecord` per call (with timing and
        failure flag) and the matching `{"role": "tool", ...}` chat messages, both in call order.
    """
    records: list[ToolCallRecord] = []
    tool_messages: list[dict[str, str]] = []
    for tool_call in tool_calls:
        function = tool_call["function"]
        name = function["name"]
        # Read the arguments before the try block so the record keeps them even when the call fails.
        arguments = function.get("arguments", {})
        tool_call_id = tool_call.get("id")
        t_tool = time.monotonic()
        failed = False
        try:
            result = tool_dict[name](**arguments)
        except Exception as error:
            # Any failure (unknown tool, bad arguments, error raised by the tool) is reported
            # back as the tool result instead of aborting the rollout.
            failed = True
            result = {"error": str(error)}
        duration = time.monotonic() - t_tool
        result_text = str(result)
        records.append(
            ToolCallRecord(
                tool_call_id=tool_call_id,
                name=name,
                arguments=arguments,
                result=result_text,
                failed=failed,
                duration=duration,
            )
        )
        tool_messages.append({"role": "tool", "name": name, "content": result_text})
    return records, tool_messages

async def _generate_one_turn(self, prompt_ids: list[int]) -> tuple[list[int], list[float]]:
payload = {
Expand All @@ -612,11 +721,12 @@ async def _generate_one_turn(self, prompt_ids: list[int]) -> tuple[list[int], li
return completion_ids, completion_logprobs

async def _score_group(self, group: RolloutGroup) -> list[RolloutSample]:
completion_messages = [c.get_completion_messages() for c in group.completions]
kwargs = dict(
completions=group.completions,
completions=completion_messages,
prompt=group.prompt,
prompts=[group.prompt] * len(group.completions),
completion_ids=group.completions_ids,
completion_ids=[c.get_completion_ids() for c in group.completions],
**group.reward_kwargs,
)
all_rewards = await asyncio.gather(
Expand All @@ -641,14 +751,19 @@ async def _score_group(self, group: RolloutGroup) -> list[RolloutSample]:
# tools/call_frequency: mean calls per completion (matches TRL's total_calls / num_completions)
# tools/failure_frequency: per-completion failure rate; averaged across samples in compute_loss
# (TRL uses total_failures / total_calls, ours weights equally per completion — close enough)
total_calls = sum(group.tool_call_counts)
total_calls = sum(len(t.tool_calls) for c in group.completions for t in c.turns)
tool_metrics = (
[
{
"tools/call_frequency": float(n_calls),
"tools/failure_frequency": (n_failures / n_calls) if n_calls > 0 else 0.0,
"tools/call_frequency": float(sum(len(t.tool_calls) for t in c.turns)),
"tools/failure_frequency": (
sum(1 for t in c.turns for tc in t.tool_calls if tc.failed)
/ sum(len(t.tool_calls) for t in c.turns)
)
if sum(len(t.tool_calls) for t in c.turns) > 0
else 0.0,
}
for n_calls, n_failures in zip(group.tool_call_counts, group.tool_failure_counts, strict=True)
for c in group.completions
]
if total_calls > 0
else [{}] * len(group.completions)
Expand All @@ -659,10 +774,10 @@ async def _score_group(self, group: RolloutGroup) -> list[RolloutSample]:
return [
RolloutSample(
prompt=group.prompt,
completion=completion,
input_ids=group.prompt_ids + completion_ids,
completion_mask=[0] * len(group.prompt_ids) + tool_mask,
old_log_probs=[0.0] * len(group.prompt_ids) + logprobs,
completion=c.get_completion_messages(),
input_ids=group.prompt_ids + c.get_completion_ids(),
completion_mask=[0] * len(group.prompt_ids) + c.get_tool_mask(),
old_log_probs=[0.0] * len(group.prompt_ids) + c.get_completion_logprobs(),
advantage=advantage,
model_version=group.model_version,
metrics={
Expand All @@ -675,12 +790,9 @@ async def _score_group(self, group: RolloutGroup) -> list[RolloutSample]:
**tm,
},
)
for i, (completion, completion_ids, logprobs, tool_mask, advantage, reward, tm) in enumerate(
for i, (c, advantage, reward, tm) in enumerate(
zip(
group.completions,
group.completions_ids,
group.completions_logprobs,
group.tool_mask,
advantages,
rewards,
tool_metrics,
Expand Down
Loading