-
Notifications
You must be signed in to change notification settings - Fork 6
[data] feat: Provide general decorator for DataProto <-> BatchMeta #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
6eab275
c87a7a6
7aba92f
9dc427b
5d9cf9d
cbd3907
616a8fa
ebc1bed
1198419
254ef97
34b08d2
5aabe8a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,291 @@ | ||||||||||||||||||||||||||
| # Copyright 2025 The TransferQueue Team | ||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| DataProto conversion decorator for TransferQueue integration. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| This decorator wraps functions that take DataProto as input and return DataProto as output, | ||||||||||||||||||||||||||
| enabling them to work with BatchMeta and TransferQueue system. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Pattern: | ||||||||||||||||||||||||||
| 1. Input: BatchMeta + TransferQueueClient | ||||||||||||||||||||||||||
| 2. Decorator: BatchMeta -> DataProto -> function(DataProto) -> DataProto -> update BatchMeta | ||||||||||||||||||||||||||
| 3. Output: Updated BatchMeta | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||
| import functools | ||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||
| from typing import Any, Callable, Optional | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||
| from tensordict import TensorDict, NonTensorData, NonTensorStack | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from verl import DataProto | ||||||||||||||||||||||||||
| from verl.experimental.transfer_queue import AsyncTransferQueueClient, BatchMeta | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def dataproto_batchmeta_conversion(transfer_queue_client: Optional[AsyncTransferQueueClient] = None): | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| Decorator for converting DataProto functions to work with BatchMeta. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| This decorator enables DataProto-based functions to work with TransferQueue's | ||||||||||||||||||||||||||
| BatchMeta system by: | ||||||||||||||||||||||||||
| 1. Converting BatchMeta input to DataProto via client | ||||||||||||||||||||||||||
| 2. Calling the wrapped function with DataProto | ||||||||||||||||||||||||||
| 3. Converting function's DataProto output back to update BatchMeta | ||||||||||||||||||||||||||
| 4. Returning the updated BatchMeta | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||
| transfer_queue_client: AsyncTransferQueueClient for data operations | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Usage: | ||||||||||||||||||||||||||
| @dataproto_batchmeta_conversion(client) | ||||||||||||||||||||||||||
| def apply_kl_penalty(data: DataProto, kl_ctrl) -> DataProto: | ||||||||||||||||||||||||||
| # Function works with DataProto as usual | ||||||||||||||||||||||||||
| response_mask = data.batch["response_mask"] | ||||||||||||||||||||||||||
| # ... compute kl_penalty ... | ||||||||||||||||||||||||||
| data.batch["kl_penalty"] = kl_penalty_result | ||||||||||||||||||||||||||
| return data | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Usage with BatchMeta: | ||||||||||||||||||||||||||
| batch_meta = apply_kl_penalty(batch_meta, kl_ctrl, transfer_queue_client=client) | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def decorator(func: Callable) -> Callable: | ||||||||||||||||||||||||||
| @functools.wraps(func) | ||||||||||||||||||||||||||
| async def async_wrapper(*args, **kwargs): | ||||||||||||||||||||||||||
| # Extract batch_meta and client from arguments | ||||||||||||||||||||||||||
| batch_meta, client, other_args, other_kwargs = _extract_args(args, kwargs, transfer_queue_client) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Convert BatchMeta to DataProto | ||||||||||||||||||||||||||
| data = await _batchmeta_to_dataproto_async(batch_meta, client) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Call function with DataProto | ||||||||||||||||||||||||||
| result_data = await func(data, *other_args, **other_kwargs) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Update BatchMeta with result | ||||||||||||||||||||||||||
| await _update_batchmeta_with_result_async(result_data, batch_meta, client) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| return batch_meta | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @functools.wraps(func) | ||||||||||||||||||||||||||
| def sync_wrapper(*args, **kwargs): | ||||||||||||||||||||||||||
| # Extract batch_meta and client from arguments | ||||||||||||||||||||||||||
| batch_meta, client, other_args, other_kwargs = _extract_args(args, kwargs, transfer_queue_client) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Convert BatchMeta to DataProto | ||||||||||||||||||||||||||
| data = _batchmeta_to_dataproto_sync(batch_meta, client) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Call function with DataProto | ||||||||||||||||||||||||||
| result_data = func(data, *other_args, **other_kwargs) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Update BatchMeta with result | ||||||||||||||||||||||||||
| _update_batchmeta_with_result_sync(result_data, batch_meta, client) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| return batch_meta | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| return decorator | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _extract_args(args: tuple, kwargs: dict, default_client: Optional[AsyncTransferQueueClient]): | ||||||||||||||||||||||||||
| """Extract batch_meta, client, and other arguments from function call.""" | ||||||||||||||||||||||||||
| # Find batch_meta (first argument) | ||||||||||||||||||||||||||
| batch_meta = args[0] if args else None | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Find client in kwargs or use default | ||||||||||||||||||||||||||
| client = kwargs.pop('transfer_queue_client', default_client) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Remaining arguments | ||||||||||||||||||||||||||
| other_args = args[1:] if len(args) > 1 else () | ||||||||||||||||||||||||||
| other_kwargs = kwargs | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| return batch_meta, client, other_args, other_kwargs | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _batchmeta_to_dataproto_sync(batch_meta: BatchMeta, client: Optional[AsyncTransferQueueClient]) -> DataProto: | ||||||||||||||||||||||||||
| """Convert BatchMeta to DataProto (synchronous).""" | ||||||||||||||||||||||||||
| if client is not None: | ||||||||||||||||||||||||||
| # Check if we're already in an event loop | ||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| loop = asyncio.get_running_loop() | ||||||||||||||||||||||||||
| # We're in a running loop, this shouldn't happen for sync wrapper | ||||||||||||||||||||||||||
| raise RuntimeError("Sync wrapper called from within async context") | ||||||||||||||||||||||||||
| except RuntimeError: | ||||||||||||||||||||||||||
| # No running loop, we can use asyncio.run | ||||||||||||||||||||||||||
| data_dict = asyncio.run(client.async_get_data(batch_meta)) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
| # We're in a running loop, this shouldn't happen for sync wrapper | |
| raise RuntimeError("Sync wrapper called from within async context") | |
| except RuntimeError: | |
| # No running loop, we can use asyncio.run | |
| data_dict = asyncio.run(client.async_get_data(batch_meta)) | |
| except RuntimeError: | |
| # No running loop, we can use asyncio.run | |
| data_dict = asyncio.run(client.async_get_data(batch_meta)) | |
| else: | |
| # We're in a running loop, this shouldn't happen for sync wrapper | |
| raise RuntimeError("Sync wrapper called from within async context") |
Copilot
AI
Sep 29, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The mock data generation logic is duplicated between sync and async versions (lines 141-150 and 170-179). This should be extracted into a separate helper function to avoid code duplication.
Copilot
AI
Sep 29, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Catching broad Exception makes debugging difficult. Consider catching specific TensorDict-related exceptions or at minimum log the specific exception type and tensor_dict contents for better debugging.
Copilot
AI
Sep 29, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The _v2 function appears to be unused and provides the same functionality as the main decorator. Consider removing this duplicate implementation to reduce code complexity.
| def dataproto_batchmeta_conversion_v2(func: Optional[Callable] = None, *, transfer_queue_client: Optional[AsyncTransferQueueClient] = None): | |
| """ | |
| Alternative decorator syntax that supports both @decorator and @decorator() usage. | |
| """ | |
| def decorator(f: Callable) -> Callable: | |
| return dataproto_batchmeta_conversion(transfer_queue_client)(f) | |
| if func is not None: | |
| return decorator(func) | |
| return decorator |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function is being awaited but may not be async. The wrapper assumes
funcis async but should check if it's a coroutine function first, or handle both sync and async functions appropriately.