Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions recipe/transfer_queue/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio

import numpy as np
import ray
from transfer_queue import BatchMeta
Expand Down Expand Up @@ -65,19 +67,33 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Batc
timing["agent_loop/tool_calls/max"] = t_tool_calls.max()
timing["agent_loop/tool_calls/mean"] = t_tool_calls.mean()

# TODO (TQ): pass tq info throughout AgentLoop so we can retrieve tensor for these metrics
# TODO (TQ): initialize tq during init when enable TQ switch is stable
tq_client = self._create_transferqueue_client()
# batch sequence generation is bounded by the slowest sample
# slowest = np.argmax(t_generate_sequences + t_tool_calls)
# attention_mask = output.extra_info.pop("attention_mask_perf")[slowest]
# prompt_length = output.extra_info.pop("prompts_perf").shape[1]
# timing["agent_loop/slowest/generate_sequences"] = t_generate_sequences[slowest]
# timing["agent_loop/slowest/tool_calls"] = t_tool_calls[slowest]
# timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item()
# timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item()
slowest = np.argmax(t_generate_sequences + t_tool_calls)
attention_mask = asyncio.run(tq_client.async_get_data(output[slowest]))["attention_mask"]
prompt_length = output.samples[0].fields["prompts"].shape[0]
timing["agent_loop/slowest/generate_sequences"] = t_generate_sequences[slowest]
timing["agent_loop/slowest/tool_calls"] = t_tool_calls[slowest]
timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item()
timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item()

return timing

def create_transferqueue_client(self, controller_info, config):
ray.get(
[worker.create_transferqueue_client.remote(controller_info, config) for worker in self.agent_loop_workers]
def create_transferqueue_client_for_workers(self):
    """Ask every agent-loop worker to build its TransferQueue client, blocking until all are done."""
    # TODO (TQ): initialize tq during worker init when enable TQ switch is stable
    pending = [worker.create_transferqueue_client.remote() for worker in self.agent_loop_workers]
    ray.get(pending)

def _create_transferqueue_client(self):
    """Create a client for data system (TransferQueue)."""
    from verl.single_controller.ray.base import get_random_string
    from verl.utils.transferqueue_utils import create_transferqueue_client

    # Random suffix keeps manager client ids unique across processes.
    suffix = get_random_string(length=6)
    return create_transferqueue_client(
        client_id=f"AgentLoopManager_{suffix}",
        config=self.config.transfer_queue,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ defaults:
# config for TransferQueue
transfer_queue:
enable: True
num_global_batch: 1
storage_backend: AsyncSimpleStorageManager
num_data_storage_units: 8
3 changes: 3 additions & 0 deletions recipe/transfer_queue/config/transfer_queue_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ defaults:
# config for TransferQueue
transfer_queue:
enable: True
num_global_batch: 1
storage_backend: AsyncSimpleStorageManager
num_data_storage_units: 8
179 changes: 89 additions & 90 deletions recipe/transfer_queue/ray_trainer.py

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions recipe/transfer_queue/run_qwen3-8b_transferqueue.sh
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,5 @@ python3 -m recipe.transfer_queue.main_ppo \
trainer.total_epochs=15 \
trainer.total_training_steps=2 \
trainer.val_before_train=False \
+trainer.num_global_batch=1 \
+trainer.num_data_storage_units=8 \
2>&1 | tee "$log_file"
echo "Finished, log is saved in: $log_file"
2 changes: 0 additions & 2 deletions tests/special_e2e/run_transferqueue.sh
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ common_params=(
trainer.total_training_steps=2
trainer.total_epochs=15
trainer.val_before_train=True
+trainer.num_global_batch=1
+trainer.num_data_storage_units=8
)

if [ "${ACTOR_STRATEGY}" == "fsdp" ]; then
Expand Down
14 changes: 8 additions & 6 deletions verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,16 +633,18 @@ def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
meta_info={"metrics": metrics, "reward_extra_keys": reward_extra_keys},
)

def create_transferqueue_client(self):
    """Create a client for data system (TransferQueue).

    Builds a TransferQueue client with a randomized client id and stores it
    on ``self.tq_client`` for subsequent data transfers.
    NOTE(review): this span of the diff paste interleaved the removed and the
    added versions of the method; this body is the reconstructed added version.
    """
    from verl.single_controller.ray.base import get_random_string
    from verl.utils.transferqueue_utils import create_transferqueue_client

    # Random suffix avoids client-id collisions between workers.
    client_name = get_random_string(length=6)

    self.tq_client = create_transferqueue_client(
        client_id=f"AgentLoopWorker_{client_name}",
        config=self.config.transfer_queue,
    )


Expand Down
5 changes: 2 additions & 3 deletions verl/single_controller/base/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,12 @@ def set_dispatch_collect(self, mesh_name: str, dispatch_dp_rank: dict[str, int],
self.__collect_dp_rank[mesh_name] = is_collect

@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
def create_transferqueue_client(self, config):
    """Create the process-global TransferQueue client on this worker.

    Args:
        config: Configuration object whose ``transfer_queue`` sub-config is
            forwarded to the client factory.
    NOTE(review): the diff paste carried both the old signature (with
    ``controller_info``) and the new one; this is the reconstructed new version.
    """
    from verl.utils.transferqueue_utils import create_transferqueue_client

    create_transferqueue_client(
        client_id=f"worker_{self.rank}",
        config=config.transfer_queue,
    )

@classmethod
Expand Down
14 changes: 8 additions & 6 deletions verl/utils/transferqueue_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from transfer_queue import (
AsyncTransferQueueClient,
BatchMeta,
ZMQServerInfo,
)

except ImportError:
Expand All @@ -44,18 +43,21 @@ class BatchMeta:

def create_transferqueue_client(
    client_id: str,
    config,
) -> "AsyncTransferQueueClient":
    """Create (once) and return the process-global TransferQueue client.

    Args:
        client_id: Identifier for the new client; only used on first call.
        config: TransferQueue config; must provide ``controller_info`` and
            ``storage_backend``.

    Returns:
        The cached ``AsyncTransferQueueClient`` singleton.
    NOTE(review): the diff paste interleaved the old unconditional body with
    the new guarded-singleton body; this is the reconstructed new version.
    """
    global _TRANSFER_QUEUE_CLIENT
    # Lazy singleton: later callers reuse the first client, and their
    # client_id/config arguments are ignored.
    if _TRANSFER_QUEUE_CLIENT is None:
        _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, config.controller_info)
        _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type=config.storage_backend, config=config)

    return _TRANSFER_QUEUE_CLIENT


def get_transferqueue_client() -> "AsyncTransferQueueClient":
    """Return the process-global TransferQueue client (presumably None until create_transferqueue_client has run — confirm against callers)."""
    return _TRANSFER_QUEUE_CLIENT


# TODO (TQ): verl will make all actor async, so this can be cleanup later.
def _run_async_in_temp_loop(async_func: Callable[..., Any], *args, **kwargs) -> Any:
# Use a temporary event loop in a new thread because event
# loop may already exist in server mode
Expand Down Expand Up @@ -127,7 +129,7 @@ def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta") ->


def tqbridge(put_data: bool = True):
""" "Creates a decorator for bridging BatchMeta and DataProto.
"""Creates a decorator for bridging BatchMeta and DataProto.

This decorator automatically handles conversions between `BatchMeta` and
`DataProto` in function parameters, and decides whether to sync function
Expand Down
Loading