Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions verl/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from torch.utils.data import DataLoader

from verl.utils.device import get_device_id, get_torch_device
from verl.utils.py_functional import union_two_dict
from verl.utils.py_functional import union_two_dict, append_to_dict_recursion
from verl.utils.torch_functional import allgather_dict_tensors

__all__ = ["DataProto", "union_tensor_dict"]
Expand Down Expand Up @@ -699,16 +699,27 @@ def concat(data: List["DataProto"]) -> "DataProto":
DataProto: concatenated DataProto
"""
batch_lst = []
for batch in data:
meta_info = data[0].meta_info
if "metrics" in batch.meta_info:
meta_info["metrics"] = {}
for i, batch in enumerate(data):
batch_lst.append(batch.batch)
if batch.meta_info is not None:
if "metrics" in batch.meta_info:
for key ,val in batch.meta_info["metrics"]:
if isinstance(val, List):
append_to_dict_recursion(meta_info["metrics"][key], batch.meta_info["metrics"][key])
else:
meta_info["metrics"][key][i] = batch.meta_info["metrics"][key]

new_batch = torch.cat(batch_lst, dim=0) if batch_lst[0] is not None else None

non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data])
for key, val in non_tensor_batch.items():
non_tensor_batch[key] = np.concatenate(val, axis=0)

cls = type(data[0]) if len(data) > 0 else DataProto
return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)
return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=meta_info)

def reorder(self, indices):
"""
Expand Down
34 changes: 33 additions & 1 deletion verl/utils/py_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import signal
from functools import wraps
from types import SimpleNamespace
from typing import Any, Callable, Dict, Iterator, Optional, Tuple
from typing import Any, Callable, Dict, Iterator, Optional, Tuple, List


# --- Top-level helper for multiprocessing timeout ---
Expand Down Expand Up @@ -174,6 +174,38 @@ def append_to_dict(data: Dict, new_data: Dict):
data[key] = []
data[key].append(val)

def _flatten_and_append(data: List[Any], value: Any) -> None:
    """Flatten ``value`` and append its leaf elements to ``data`` in order.

    Lists and tuples are expanded element by element, dicts are expanded into
    their values (the keys are discarded), and every other object -- strings
    included -- is appended as a single leaf.

    Args:
        data (List): The target list; it is extended in place.
        value: A leaf, or an arbitrarily nested list/tuple/dict of leaves.

    Returns:
        None: The function modifies data in-place.
    """
    # Walk with an explicit stack of iterators instead of recursing, so the
    # nesting depth of ``value`` is not capped by the interpreter's recursion
    # limit. Each iterator remembers its position, which keeps the original
    # left-to-right, depth-first leaf order.
    stack = [iter((value,))]
    while stack:
        for item in stack[-1]:
            if isinstance(item, (list, tuple)):
                # Descend; the parent iterator resumes after this child is done.
                stack.append(iter(item))
                break
            if isinstance(item, dict):
                stack.append(iter(item.values()))
                break
            data.append(item)
        else:
            # The innermost iterator is exhausted: return to its parent.
            stack.pop()

def append_to_dict_recursion(data: Dict, new_data: Dict):
    """Merge ``new_data`` into ``data``, flattening each value into a list.

    For every key of ``new_data`` the matching entry of ``data`` is created as
    an empty list when absent, then extended via ``_flatten_and_append`` with
    the (possibly nested) value from ``new_data``.

    Args:
        data (Dict): The target dictionary containing lists as values.
        new_data (Dict): The source dictionary with values to append.

    Returns:
        None: The function modifies data in-place.
    """
    for name, payload in new_data.items():
        # setdefault hands back the existing list, or installs a fresh one.
        bucket = data.setdefault(name, [])
        _flatten_and_append(bucket, payload)

class NestedNamespace(SimpleNamespace):
"""A nested version of SimpleNamespace that recursively converts dictionaries to namespaces.
Expand Down