Skip to content
Merged
66 changes: 50 additions & 16 deletions recipe/fully_async_policy/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(

async def generate_sequences_no_post(
self, batch: DataProto, partial_output_list: Optional[list[AgentLoopOutput]]
) -> list[AgentLoopOutput]:
) -> tuple[list[AgentLoopOutput], bool] | tuple[DataProto, bool]:
"""Generate sequences from agent loop.

Args:
Expand Down Expand Up @@ -126,15 +126,34 @@ async def generate_sequences_no_post(

if not partial_output_list:
partial_output_list = [None] * len(batch)

tasks = []
for i in range(len(batch)):
kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()}
kwargs["output"] = partial_output_list[i]
tasks.append(
asyncio.create_task(self._partial_run_agent_loop(sampling_params, trajectory_info[i], **kwargs))
)
return await asyncio.gather(*tasks)
try:
tasks = []
for i in range(len(batch)):
kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()}
kwargs["output"] = partial_output_list[i]
tasks.append(
asyncio.create_task(self._partial_run_agent_loop(sampling_params, trajectory_info[i], **kwargs))
)
outputs = await asyncio.gather(*tasks)
except Exception:
logger.exception("_partial_run_agent_loop failed")
raise

is_cancel = any(output.extra_fields.get("is_cancel", False) for output in outputs)
if not is_cancel:
output = self._postprocess(outputs)
output = self._addition_process(output)
return output, is_cancel
return outputs, is_cancel

def _addition_process(self, output: DataProto):
    """Lift per-trajectory timing metrics into non-tensor batch columns.

    Pops the ``metrics`` list (one dict per trajectory) out of
    ``output.meta_info`` and stores the generation and tool-call timings
    as ``processing_times`` / ``tool_calls_times`` columns, so downstream
    batch assembly can compute timing statistics without the raw metrics.
    Mutates ``output`` in place and returns it.
    """
    # Raises KeyError if "metrics" is absent — callers are expected to have
    # run _postprocess first, which attaches it.
    per_traj_metrics = output.meta_info.pop("metrics")
    output.non_tensor_batch["processing_times"] = [m["generate_sequences"] for m in per_traj_metrics]
    output.non_tensor_batch["tool_calls_times"] = [m["tool_calls"] for m in per_traj_metrics]
    return output

async def _partial_run_agent_loop(
self,
Expand All @@ -144,6 +163,10 @@ async def _partial_run_agent_loop(
agent_name: str,
**kwargs,
) -> AgentLoopOutput:
# Completed, return directly
if kwargs["output"] is not None and not kwargs["output"].extra_fields.get("is_cancel", False):
logger.info("In _partial_run_agent_loop, already completed, return derictly!")
return kwargs["output"]
try:
with rollout_trace_attr(
step=trajectory["step"],
Expand All @@ -164,10 +187,17 @@ async def _partial_run_agent_loop(
tokenizer=self.tokenizer,
processor=self.processor,
)
return await agent_loop.run(sampling_params, cancellation_event=self.cancellation_event, **kwargs)
except Exception as e:
logger.exception(f"Agent_loop run failed: {e}")
raise e
output: AgentLoopOutput = await agent_loop.run(
sampling_params, cancellation_event=self.cancellation_event, **kwargs
)
if not output.extra_fields.get("is_cancel", False):
kwargs.pop("output", None)
output = await self._agent_loop_postprocess(output, **kwargs)

return output
except Exception:
logger.exception("Agent_loop run failed")
raise

async def cancel_agent_loops(self):
"""Set the shared cancellation event to stop all agent loops."""
Expand Down Expand Up @@ -210,7 +240,11 @@ async def _async_init(self):
self._init_agent_loop_workers()

async def _initialize_llm_servers_async(self):
rollout_world_size = self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
rollout_world_size = (
self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
* self.config.actor_rollout_ref.rollout.data_parallel_size
* self.config.actor_rollout_ref.rollout.pipeline_model_parallel_size
)
world_size = (
self.worker_group.world_size
if self.worker_group
Expand Down Expand Up @@ -249,7 +283,7 @@ async def generate_single_sample_async(
self,
sample: DataProto,
partial_output_list: Optional[list[AgentLoopOutput]],
) -> list[AgentLoopOutput]:
) -> tuple[list[AgentLoopOutput], bool] | tuple[DataProto, bool]:
"""
Asynchronously process a single sample

Expand Down
188 changes: 8 additions & 180 deletions recipe/fully_async_policy/detach_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,149 +18,10 @@

import numpy as np
import torch
from tensordict import TensorDict

from verl import DataProto
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput
from verl.trainer.ppo.ray_trainer import compute_response_mask
from verl.utils.model import compute_position_id_with_mask


def postprocess_agent_loop_outputs(rs: "RolloutSample", tokenizer, config, processor) -> DataProto:
    """Static method to postprocess a list of AgentLoopOutput into DataProto

    Pads prompts (left) and responses (right) to the configured lengths,
    builds the combined input_ids / attention_mask / position_ids tensors,
    and attaches rollout log-probs, turn counts and per-trajectory metrics.

    Args:
        rs: RolloutSample holding ``agent_loop_output_list`` and ``full_batch``
        tokenizer: Tokenizer instance
        config: Configuration object
        processor: Optional multi-modal processor (Qwen2-VL family only)

    Returns:
        DataProto: Processed batch data
    """
    inputs: list[AgentLoopOutput] = rs.agent_loop_output_list
    full_batch = rs.full_batch
    # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
    # prompts: left pad
    # responses: right pad
    # input_ids: prompt + response
    # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
    # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]

    # prompts
    tokenizer.padding_side = "left"
    outputs = tokenizer.pad(
        [{"input_ids": input.prompt_ids} for input in inputs],
        padding="max_length",
        max_length=config.actor_rollout_ref.rollout.prompt_length,
        return_tensors="pt",
        return_attention_mask=True,
    )
    prompt_ids, prompt_attention_mask = outputs["input_ids"], outputs["attention_mask"]

    # responses
    tokenizer.padding_side = "right"
    outputs = tokenizer.pad(
        [{"input_ids": input.response_ids} for input in inputs],
        padding="max_length",
        max_length=config.actor_rollout_ref.rollout.response_length,
        return_tensors="pt",
        return_attention_mask=True,
    )
    response_ids, response_attention_mask = outputs["input_ids"], outputs["attention_mask"]

    # response_mask: per-token loss mask, padded the same way as responses
    outputs = tokenizer.pad(
        [{"input_ids": input.response_mask} for input in inputs],
        padding="max_length",
        max_length=config.actor_rollout_ref.rollout.response_length,
        return_tensors="pt",
        return_attention_mask=False,
    )
    response_mask = outputs["input_ids"]
    assert response_ids.shape == response_mask.shape, (
        f"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}"
    )
    # zero out mask positions that fall in response padding
    response_mask = response_mask * response_attention_mask

    # Handle multi-modal inputs and position_ids calculation
    # Only support Qwen2VLImageProcessor for multi-modal processing currently
    # TODO: support other multi-modal inputs
    multi_modal_inputs = None
    if processor is not None and "Qwen2VLImageProcessor" in processor.image_processor.__class__.__name__:
        # qwen-vl mrope
        # NOTE(review): both branches are placeholders — Qwen3VL is not handled
        # differently from Qwen2VL here yet.
        if "Qwen3VLProcessor" in processor.__class__.__name__:
            pass
        else:
            pass

        # Re-tokenize the decoded prompts through the processor so image
        # features and text ids are aligned; this replaces the plain
        # tokenizer padding computed above for prompts.
        images = [one.get("image", None) for one in full_batch.non_tensor_batch.get("multi_modal_data")]
        current_text = [tokenizer.decode(input.prompt_ids, skip_special_tokens=False) for input in inputs]
        multi_modal_inputs = processor(
            text=current_text,
            images=images,
            return_tensors="pt",
            max_length=config.actor_rollout_ref.rollout.prompt_length,
            padding="max_length",
            padding_side="left",
        )

        prompt_ids = multi_modal_inputs.pop("input_ids")
        prompt_attention_mask = multi_modal_inputs.pop("attention_mask")

        # TODO: megatron will cauculate rope position_ids in the forward pass, so we don't need to calculate it here
        # but for FSDP support, we need to calculate it here

        # # We must use dict(multi_modal_inputs) to convert BatchFeature values to a new dict
        # # because np.array() only keeps the keys for BatchFeature.
        # multi_modal_inputs = dict(multi_modal_inputs)

        # image_grid_thw = multi_modal_inputs.get("image_grid_thw")
        # video_grid_thw = multi_modal_inputs.get("video_grid_thw")
        # second_per_grid_ts = multi_modal_inputs.get("second_per_grid_ts")

        # vision_position_ids = get_rope_index(
        #     processor,
        #     input_ids=input_ids.squeeze(0),
        #     image_grid_thw=image_grid_thw,
        #     video_grid_thw=video_grid_thw,
        #     second_per_grid_ts=second_per_grid_ts,
        #     attention_mask=attention_mask.squeeze(0),
        # ).unsqueeze(0)  # (1, 3, seq_len)

        # valid_mask = attention_mask[0].bool()
        # text_position_ids = torch.ones((1, len(input_ids[0])), dtype=torch.long)
        # text_position_ids[0, valid_mask] = torch.arange(valid_mask.sum().item())
        # text_position_ids = text_position_ids.unsqueeze(0)
        # position_ids = torch.cat((text_position_ids, vision_position_ids), dim=1)  # (1, 4, seq_length)
    else:
        pass
    # Text-only (and, for now, multi-modal) path: simple 1-D position ids
    input_ids = torch.cat([prompt_ids, response_ids], dim=1)
    attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1)
    position_ids = compute_position_id_with_mask(attention_mask)  # (1, seq_len)

    batch = TensorDict(
        {
            "prompts": prompt_ids,  # [bsz, prompt_length]
            "responses": response_ids,  # [bsz, response_length]
            "response_mask": response_mask,  # [bsz, response_length]
            "input_ids": input_ids,  # [bsz, prompt_length + response_length]
            "attention_mask": attention_mask,  # [bsz, prompt_length + response_length]
            "position_ids": position_ids,  # [bsz, prompt_length + response_length]
        },
        batch_size=len(input_ids),
    )

    # Right-pad each trajectory's log-probs with 0.0 to response_length,
    # then stack into a single [bsz, response_length] tensor.
    response_logprobs_list = []
    for input in inputs:
        pad_size = config.actor_rollout_ref.rollout.response_length - len(input.response_logprobs)
        response_logprobs = torch.tensor(input.response_logprobs + [0.0] * pad_size).unsqueeze(0)
        response_logprobs_list.append(response_logprobs)
    rollout_log_probs = torch.cat(response_logprobs_list, dim=0)
    batch["rollout_log_probs"] = rollout_log_probs  # [bsz, response_length]

    num_turns = np.array([input.num_turns for input in inputs], dtype=np.int32)
    metrics = [input.metrics.model_dump() for input in inputs]
    return DataProto(batch=batch, non_tensor_batch={"__num_turns__": num_turns}, meta_info={"metrics": metrics})


@dataclass
Expand Down Expand Up @@ -230,40 +91,6 @@ def prepare_single_generation_data(batch_dict, config) -> DataProto:
return full_batch


def merge_rollout_sample(config, tokenizer, rs: RolloutSample, processor):
    """Finalize a RolloutSample in place and return it.

    Converts the raw AgentLoopOutputs into a padded DataProto, tags every
    row with this sample's uid, carries the original batch's non-tensor
    columns and meta info over, lifts per-trajectory timing and
    param-version bookkeeping onto the RolloutSample, and finally drops
    the raw outputs to free memory.
    """
    # Step 1: build the generation DataProto from the agent-loop outputs.
    merged = postprocess_agent_loop_outputs(rs, tokenizer, config, processor)

    # Step 2: tag every row of the original batch with the sample uid
    # (done before the copy loop so the uid column is merged too).
    uid_column = np.array([f"uid_{rs.sample_id}"] * len(rs.full_batch), dtype=object)
    rs.full_batch.non_tensor_batch["uid"] = uid_column

    # Step 3: carry original non-tensor columns and meta info into the result.
    for column, values in rs.full_batch.non_tensor_batch.items():
        merged.non_tensor_batch[column] = values
    merged.meta_info.update(rs.full_batch.meta_info)

    # Step 4: swap in the merged batch and lift per-trajectory bookkeeping.
    rs.full_batch = merged
    outputs = rs.agent_loop_output_list
    rs.processing_times = [o.metrics.generate_sequences for o in outputs]
    rs.tool_calls = [o.metrics.tool_calls for o in outputs]
    rs.param_version_start = [o.extra_fields.get("param_version_start", 0) for o in outputs]
    rs.param_version_end = [o.extra_fields.get("param_version_end", 0) for o in outputs]

    # Step 5: raw outputs are no longer needed once lifted — release them.
    rs.agent_loop_output_list = []
    return rs


def assemble_batch_from_rollout_samples(
rollout_samples: list[RolloutSample], tokenizer, config, balance_batch=None
) -> DataProto:
Expand Down Expand Up @@ -299,8 +126,6 @@ def assemble_batch_from_rollout_samples(

for rs in rollout_samples:
rollout_samples_batch.append(rs.full_batch)
processing_times.extend(rs.processing_times)
tool_calls.extend(rs.tool_calls)
final_batch = DataProto.concat(rollout_samples_batch)

# Calculate response_mask (if not present)
Expand All @@ -314,9 +139,9 @@ def assemble_batch_from_rollout_samples(
if "attention_mask" in final_batch.batch:
final_batch.meta_info["global_token_num"] = torch.sum(final_batch.batch["attention_mask"], dim=-1).tolist()

processing_times = final_batch.non_tensor_batch["processing_times"]
tool_calls = final_batch.non_tensor_batch["tool_calls_times"]
# Collect statistics
param_versions = [rs.param_version for rs in rollout_samples]
trajectorys_param_versions = [version for rs in rollout_samples for version in rs.param_version_end]

processing_time_stats = {
"processing_time/avg": np.mean(processing_times),
Expand All @@ -327,16 +152,16 @@ def assemble_batch_from_rollout_samples(
"processing_time/tp95": np.percentile(processing_times, 95),
}
tool_calls_stats = {}
if tool_calls:
if len(tool_calls) > 0:
tool_calls_stats = {
"timing_s/agent_loop/tool_calls/max": np.max(tool_calls),
"timing_s/agent_loop/tool_calls/min": np.min(tool_calls),
"timing_s/agent_loop/tool_calls/mean": np.mean(tool_calls),
}
processing_time_stats = {f"fully_async/{key}": value for key, value in processing_time_stats.items()}

param_version_start = [v for rs in rollout_samples for v in rs.param_version_start]
param_version_end = [v for rs in rollout_samples for v in rs.param_version_end]
param_version_start = final_batch.non_tensor_batch["param_version_start"]
param_version_end = final_batch.non_tensor_batch["param_version_end"]
param_version_diff = [abs(a - b) for a, b in zip(param_version_end, param_version_start, strict=False)]
num_diff0 = param_version_diff.count(0)
partial_stats = {
Expand All @@ -345,6 +170,9 @@ def assemble_batch_from_rollout_samples(
"fully_async/partial/max_partial_span": max(param_version_diff),
}
# add meta_info
param_versions = [rs.param_version for rs in rollout_samples]
trajectorys_param_versions = final_batch.non_tensor_batch["param_version_end"]

final_batch.meta_info.update(
{
"rollout_param_versions": param_versions,
Expand Down
3 changes: 2 additions & 1 deletion recipe/fully_async_policy/fully_async_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import socket
import threading
Expand Down Expand Up @@ -278,7 +279,7 @@ def _run_training_loop(self):
ray.cancel(future)
raise
finally:
self.components["message_queue_client"].clear_queue()
asyncio.run(self.components["message_queue_client"].clear_queue())
print("[ASYNC MAIN] Training completed or interrupted")


Expand Down
Loading
Loading