diff --git a/verl/protocol.py b/verl/protocol.py index 0461be9c022..f90488a81a8 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -35,7 +35,7 @@ from torch.utils.data import DataLoader from verl.utils.device import get_device_id, get_torch_device -from verl.utils.py_functional import union_two_dict +from verl.utils.py_functional import union_two_dict, append_to_dict_recursion from verl.utils.torch_functional import allgather_dict_tensors __all__ = ["DataProto", "union_tensor_dict"] @@ -699,8 +699,19 @@ def concat(data: List["DataProto"]) -> "DataProto": DataProto: concatenated DataProto """ batch_lst = [] - for batch in data: + meta_info = data[0].meta_info + if "metrics" in batch.meta_info: + meta_info["metrics"] = {} + for i, batch in enumerate(data): batch_lst.append(batch.batch) + if batch.meta_info is not None: + if "metrics" in batch.meta_info: + for key ,val in batch.meta_info["metrics"]: + if isinstance(val, List): + append_to_dict_recursion(meta_info["metrics"][key], batch.meta_info["metrics"][key]) + else: + meta_info["metrics"][key][i] = batch.meta_info["metrics"][key] + new_batch = torch.cat(batch_lst, dim=0) if batch_lst[0] is not None else None non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data]) @@ -708,7 +719,7 @@ def concat(data: List["DataProto"]) -> "DataProto": non_tensor_batch[key] = np.concatenate(val, axis=0) cls = type(data[0]) if len(data) > 0 else DataProto - return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info) + return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=meta_info) def reorder(self, indices): """ diff --git a/verl/utils/py_functional.py b/verl/utils/py_functional.py index 7f476387c16..999b643e9c2 100644 --- a/verl/utils/py_functional.py +++ b/verl/utils/py_functional.py @@ -22,7 +22,7 @@ import signal from functools import wraps from types import SimpleNamespace -from typing import Any, Callable, Dict, Iterator, Optional, Tuple +from typing import Any, Callable, Dict, Iterator, Optional, Tuple, List # --- Top-level helper for multiprocessing timeout --- @@ -174,6 +174,38 @@ def append_to_dict(data: Dict, new_data: Dict): data[key] = [] data[key].append(val) +def _flatten_and_append(data: List[Any], value: Any): + """Flatten and append values from value to lists in data. + + Args: + data (List): The target list. + value: The source values to append. + + Returns: + None: The function modifies data in-place. + """ + if isinstance(value, (list, tuple)): + for item in value: + _flatten_and_append(data, item) + elif isinstance(value, dict): + for item in value.values(): + _flatten_and_append(data, item) + else: + data.append(value) + +def append_to_dict_recursion(data: Dict, new_data: Dict): + """Recursively append values from new_data to a list in the data. + Args: + data (Dict): The target dictionary containing lists as values. + new_data (Dict): The source dictionary with values to append. + + Returns: + None: The function modifies data in-place. + """ + for key, val in new_data.items(): + if key not in data: + data[key] = [] + _flatten_and_append(data[key], val) class NestedNamespace(SimpleNamespace): """A nested version of SimpleNamespace that recursively converts dictionaries to namespaces.