Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 168 additions & 70 deletions trl/experimental/async_grpo/async_rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,94 @@
Messages: TypeAlias = list[dict[str, str]]


@dataclass(slots=True)
class ToolCallRecord:
tool_call_id: str | None
name: str
arguments: dict[str, Any]
result: str
failed: bool
duration: float # wall-clock seconds for this tool call


@dataclass(slots=True)
class TaggedMessage:
message: dict[str, Any] # raw chat message (has "role", "content", etc.)
is_completion: bool # True = generated by the model or resulting tool execution


@dataclass(slots=True)
class TurnRecord:
messages: list[TaggedMessage] # full turn in order: [context..., assistant, tool_results...]
generation_ids: list[int] # token ids for the assistant response only
generation_logprobs: list[float]
generation_duration: float # wall-clock seconds for the vLLM call

tool_calls: list[ToolCallRecord] # per-call records with timing (empty if no tool calls)
tool_response_ids: list[int] # token ids for the tool-result suffix (empty if no tool calls)
tool_execution_duration_s: float # total wall-clock seconds for all tool calls in this turn


@dataclass(slots=True)
class RolloutCompletion:
"""Result of a single (possibly multi-turn) generation for one prompt."""

turns: list[TurnRecord]
truncated: bool
total_duration_s: float

def get_completion_messages(self) -> Messages:
"""Return the completion messages (assistant + tool results) across all turns."""
return [tm.message for turn in self.turns for tm in turn.messages if tm.is_completion]

def get_completion_ids(self) -> list[int]:
"""Flattened token IDs for the completion (generation + tool response per turn)."""
ids: list[int] = []
for turn in self.turns:
ids.extend(turn.generation_ids)
ids.extend(turn.tool_response_ids)
return ids

def get_completion_logprobs(self) -> list[float]:
"""Flattened logprobs (real for generation tokens, 0.0 for tool response tokens)."""
logprobs: list[float] = []
for turn in self.turns:
logprobs.extend(turn.generation_logprobs)
logprobs.extend([0.0] * len(turn.tool_response_ids))
return logprobs

def get_tool_mask(self) -> list[int]:
"""1 for generation tokens (trainable), 0 for tool response tokens."""
mask: list[int] = []
for turn in self.turns:
mask.extend([1] * len(turn.generation_ids))
mask.extend([0] * len(turn.tool_response_ids))
return mask

def get_trajectory(self) -> list[TaggedMessage]:
"""Full trajectory in order — ready for export to dataset/file."""
return [tm for turn in self.turns for tm in turn.messages]
Comment thread
cursor[bot] marked this conversation as resolved.



@dataclass(slots=True)
class RolloutGroup:
"""Single GRPO group for one prompt with multiple completions."""

prompt: Messages
prompt_ids: list[int]
reward_kwargs: dict[str, list[Any]]
completions: list[Messages]
completions_ids: list[list[int]]
completions_logprobs: list[list[float]]
tool_mask: list[list[int]]
tool_call_counts: list[int]
tool_failure_counts: list[int]
completions: list[RolloutCompletion]
model_version: int
queued_at: float = 0.0

def append_completion(self, result: RolloutCompletion):
self.completions.append(result)

def _build_completion(turns: list[TurnRecord], truncated: bool, total_duration: float) -> RolloutCompletion:
"""Derive a RolloutCompletion from a list of TurnRecords."""
return RolloutCompletion(turns=turns, truncated=truncated, total_duration_s=total_duration)


@dataclass(slots=True)
class RolloutSample:
Expand Down Expand Up @@ -332,11 +404,6 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None:
prompt_ids=prompt_ids,
reward_kwargs=reward_kwargs,
completions=[],
completions_ids=[],
completions_logprobs=[],
tool_mask=[],
tool_call_counts=[],
tool_failure_counts=[],
model_version=self.model_version,
)
pending_completed[group_id] = 0
Expand Down Expand Up @@ -375,23 +442,10 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None:
if task.exception() is not None:
raise task.exception()

(
completion,
completion_ids,
completion_logprobs,
tool_mask,
tool_call_count,
tool_failure_count,
) = task.result()
result: RolloutCompletion = task.result()
group = pending_groups[group_id]
group.completions.append(completion)
group.completions_ids.append(completion_ids)
group.completions_logprobs.append(completion_logprobs)
group.tool_mask.append(tool_mask)
group.tool_call_counts.append(tool_call_count)
group.tool_failure_counts.append(tool_failure_count)
# TODO: move this in generation task, shouldn't matter but is correct
self._total_completion_tokens += sum(tool_mask)
group.append_completion(result)
self._total_completion_tokens += sum(len(t.generation_ids) for t in result.turns)
pending_completed[group_id] += 1

if pending_completed[group_id] == self.num_generations:
Expand Down Expand Up @@ -508,12 +562,10 @@ def _repeat_iterator(self) -> Iterator[tuple[int, dict[str, Any]]]:

async def _generate_one(
self, prompt: Messages, tool_dict: dict[str, Callable]
) -> tuple[list[dict[str, str]], list[int], list[float], list[int], int, int]:
completion, completion_ids, completion_logprobs, tool_mask = [], [], [], []
tool_call_count = 0
tool_failure_count = 0
iteration_num = 0
max_iterations = self.max_tool_calling_iterations
) -> RolloutCompletion:
turns: list[TurnRecord] = []
max_num_turns = self.max_tool_calling_iterations
t_start = time.monotonic()
prompt_ids = self.tokenizer.apply_chat_template(
prompt,
return_dict=False,
Expand All @@ -522,26 +574,57 @@ async def _generate_one(
chat_template=self.chat_template,
**self.chat_template_kwargs,
)

iteration_num = 0
# context_messages for the current turn: initial prompt for turn 0, tool results for turn N>0
context_messages: Messages = list(prompt)
while True:
t_gen = time.monotonic()
turn_ids, turn_logprobs = await self._generate_one_turn(prompt_ids)
generation_duration = time.monotonic() - t_gen

assistant_message = parse_response(self.tokenizer, turn_ids)
completion.append(assistant_message)
completion_ids.extend(turn_ids)
completion_logprobs.extend(turn_logprobs)
tool_mask.extend([1] * len(turn_ids))
tool_calls = assistant_message.get("tool_calls")
if tool_calls is None or (max_iterations is not None and iteration_num >= max_iterations):
return completion, completion_ids, completion_logprobs, tool_mask, tool_call_count, tool_failure_count

tool_messages, n_calls, n_failures = self._execute_tool_calls(tool_calls, tool_dict)
tool_call_count += n_calls
tool_failure_count += n_failures
completion.extend(tool_messages)
tool_calls_raw = assistant_message.get("tool_calls")

if tool_calls_raw is None or (max_num_turns is not None and iteration_num >= max_num_turns):
messages = [TaggedMessage(m, is_completion=False) for m in context_messages]
messages.append(TaggedMessage(assistant_message, is_completion=True))
turns.append(
TurnRecord(
messages=messages,
generation_ids=turn_ids,
generation_logprobs=turn_logprobs,
generation_duration=generation_duration,
tool_calls=[],
tool_response_ids=[],
tool_execution_duration_s=0.0,
)
)
return _build_completion(turns, truncated=False, total_duration=time.monotonic() - t_start)
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated

t_tools = time.monotonic()
tool_call_records, tool_messages = self._execute_tool_calls(tool_calls_raw, tool_dict)
tool_execution_duration = time.monotonic() - t_tools

tool_suffix_ids = self._build_messages_suffix_ids(tool_messages)
completion_ids.extend(tool_suffix_ids)
completion_logprobs.extend([0.0] * len(tool_suffix_ids))
tool_mask.extend([0] * len(tool_suffix_ids))
messages = [TaggedMessage(m, is_completion=False) for m in context_messages]
messages.append(TaggedMessage(assistant_message, is_completion=True))
messages.extend(TaggedMessage(m, is_completion=True) for m in tool_messages)
turns.append(
TurnRecord(
messages=messages,
generation_ids=turn_ids,
generation_logprobs=turn_logprobs,
generation_duration=generation_duration,
tool_calls=tool_call_records,
tool_response_ids=tool_suffix_ids,
tool_execution_duration_s=tool_execution_duration,
)
)
context_messages = list(tool_messages)

prompt_ids = prompt_ids + turn_ids + tool_suffix_ids

iteration_num += 1

def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int]:
Expand Down Expand Up @@ -571,22 +654,34 @@ def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int

def _execute_tool_calls(
self, tool_calls: list[dict[str, Any]], tool_dict: dict[str, Callable]
) -> tuple[list[dict[str, str]], int, int]:
tool_messages = []
n_calls = 0
n_failures = 0
) -> tuple[list[ToolCallRecord], list[dict[str, str]]]:
records: list[ToolCallRecord] = []
tool_messages: list[dict[str, str]] = []
for tool_call in tool_calls:
n_calls += 1
function = tool_call["function"]
name = function["name"]
arguments = function.get("arguments", {})
tool_call_id = tool_call.get("id")
t_tool = time.monotonic()
failed = False
try:
arguments = function.get("arguments", {})
result = tool_dict[name](**arguments)
except Exception as error:
n_failures += 1
failed = True
result = {"error": str(error)}
duration = time.monotonic() - t_tool
records.append(
ToolCallRecord(
tool_call_id=tool_call_id,
name=name,
arguments=arguments,
result=str(result),
failed=failed,
duration=duration,
)
)
tool_messages.append({"role": "tool", "name": name, "content": str(result)})
return tool_messages, n_calls, n_failures
return records, tool_messages

async def _generate_one_turn(self, prompt_ids: list[int]) -> tuple[list[int], list[float]]:
payload = {
Expand All @@ -612,11 +707,12 @@ async def _generate_one_turn(self, prompt_ids: list[int]) -> tuple[list[int], li
return completion_ids, completion_logprobs

async def _score_group(self, group: RolloutGroup) -> list[RolloutSample]:
completion_messages = [c.get_completion_messages() for c in group.completions]
kwargs = dict(
completions=group.completions,
completions=completion_messages,
prompt=group.prompt,
prompts=[group.prompt] * len(group.completions),
completion_ids=group.completions_ids,
completion_ids=[c.get_completion_ids() for c in group.completions],
**group.reward_kwargs,
)
all_rewards = await asyncio.gather(
Expand All @@ -641,14 +737,19 @@ async def _score_group(self, group: RolloutGroup) -> list[RolloutSample]:
# tools/call_frequency: mean calls per completion (matches TRL's total_calls / num_completions)
# tools/failure_frequency: per-completion failure rate; averaged across samples in compute_loss
# (TRL uses total_failures / total_calls, ours weights equally per completion — close enough)
total_calls = sum(group.tool_call_counts)
total_calls = sum(len(t.tool_calls) for c in group.completions for t in c.turns)
tool_metrics = (
[
{
"tools/call_frequency": float(n_calls),
"tools/failure_frequency": (n_failures / n_calls) if n_calls > 0 else 0.0,
"tools/call_frequency": float(sum(len(t.tool_calls) for t in c.turns)),
"tools/failure_frequency": (
sum(1 for t in c.turns for tc in t.tool_calls if tc.failed)
/ sum(len(t.tool_calls) for t in c.turns)
)
if sum(len(t.tool_calls) for t in c.turns) > 0
else 0.0,
}
for n_calls, n_failures in zip(group.tool_call_counts, group.tool_failure_counts, strict=True)
for c in group.completions
]
if total_calls > 0
else [{}] * len(group.completions)
Expand All @@ -659,10 +760,10 @@ async def _score_group(self, group: RolloutGroup) -> list[RolloutSample]:
return [
RolloutSample(
prompt=group.prompt,
completion=completion,
input_ids=group.prompt_ids + completion_ids,
completion_mask=[0] * len(group.prompt_ids) + tool_mask,
old_log_probs=[0.0] * len(group.prompt_ids) + logprobs,
completion=c.get_completion_messages(),
input_ids=group.prompt_ids + c.get_completion_ids(),
completion_mask=[0] * len(group.prompt_ids) + c.get_tool_mask(),
old_log_probs=[0.0] * len(group.prompt_ids) + c.get_completion_logprobs(),
advantage=advantage,
model_version=group.model_version,
metrics={
Expand All @@ -675,12 +776,9 @@ async def _score_group(self, group: RolloutGroup) -> list[RolloutSample]:
**tm,
},
)
for i, (completion, completion_ids, logprobs, tool_mask, advantage, reward, tm) in enumerate(
for i, (c, advantage, reward, tm) in enumerate(
zip(
group.completions,
group.completions_ids,
group.completions_logprobs,
group.tool_mask,
advantages,
rewards,
tool_metrics,
Expand Down
Loading