diff --git a/verl/utils/tensordict_utils.py b/verl/utils/tensordict_utils.py index f0e6b980a04..c757450f37e 100644 --- a/verl/utils/tensordict_utils.py +++ b/verl/utils/tensordict_utils.py @@ -21,6 +21,23 @@ def assign_non_tensor_data(tensor_dict: TensorDict, key, val): + """Assign a single non-tensor value to a TensorDict. + + Wraps the value in NonTensorData so it can be stored alongside tensors + in the TensorDict. Use this for scalar metadata or simple non-tensor values. + + Args: + tensor_dict: The TensorDict to assign to. + key: The key under which to store the value. + val: Any non-tensor value to store (e.g., string, int, dict). + + Raises: + AssertionError: If tensor_dict is not a TensorDict. + + Example: + >>> td = TensorDict({"obs": torch.randn(3, 4)}, batch_size=[3]) + >>> assign_non_tensor_data(td, "experiment_name", "run_001") + """ assert isinstance(tensor_dict, TensorDict), "input dict must be a TensorDict" tensor_dict[key] = NonTensorData(val) @@ -88,17 +105,81 @@ def assign_non_tensor(tensor_dict: TensorDict, **kwargs): def unwrap_non_tensor_data(data): + """Unwrap a NonTensorData object to get the underlying value. + + If the input is a NonTensorData wrapper, extracts and returns the + underlying data. Otherwise, returns the input unchanged. + + Args: + data: Either a NonTensorData object or any other value. + + Returns: + The unwrapped data if input was NonTensorData, otherwise the + original input unchanged. + + Example: + >>> wrapped = NonTensorData("hello") + >>> unwrap_non_tensor_data(wrapped) + 'hello' + >>> unwrap_non_tensor_data(42) # Non-wrapped value + 42 + """ if isinstance(data, NonTensorData): return data.data return data def get_non_tensor_data(data: TensorDict, key: str, default): + """Retrieve and unwrap non-tensor data from a TensorDict. + + Fetches the value for the given key from the TensorDict and automatically + unwraps it if it's stored as NonTensorData. + + Args: + data: The TensorDict to retrieve from. 
+ key: The key to look up. + default: Value to return if the key is not found. + + Returns: + The unwrapped value if the key exists and was wrapped in NonTensorData, + the raw value if it wasn't wrapped, or the default if key not found. + + Example: + >>> td = TensorDict({}, batch_size=[]) + >>> assign_non_tensor_data(td, "config", {"lr": 0.01}) + >>> get_non_tensor_data(td, "config", None) + {'lr': 0.01} + >>> get_non_tensor_data(td, "missing", "default_value") + 'default_value' + """ output = data.get(key, default) return unwrap_non_tensor_data(output) def concat_nested_tensors(tensors: list[torch.Tensor]) -> torch.Tensor: + """Concatenate multiple 2D nested tensors along the batch dimension. + + Takes a list of nested tensors with jagged layout and concatenates them + into a single nested tensor. Each input tensor must be 2D and contiguous. + + Args: + tensors: List of 2D nested tensors to concatenate. All tensors must + be nested, contiguous, and have exactly 2 dimensions. + + Returns: + A new nested tensor with jagged layout containing all rows from + the input tensors concatenated along dimension 0. + + Raises: + AssertionError: If any tensor is not nested, not contiguous, or + doesn't have exactly 2 dimensions. + + Example: + >>> t1 = torch.nested.as_nested_tensor([torch.randn(3), torch.randn(5)], layout=torch.jagged) + >>> t2 = torch.nested.as_nested_tensor([torch.randn(2), torch.randn(4)], layout=torch.jagged) + >>> result = concat_nested_tensors([t1, t2]) + >>> # result contains 4 rows: lengths [3, 5, 2, 4] + """ for tensor in tensors: assert tensor.is_nested and tensor.is_contiguous() unbind_tensors = [] @@ -112,6 +193,25 @@ def concat_nested_tensors(tensors: list[torch.Tensor]) -> torch.Tensor: def concat_tensordict_with_none_bsz(data: list[TensorDict]): + """Handle concatenation of TensorDicts with empty batch size. 
+ + For TensorDicts that contain only metadata (NonTensorData) with no batch + dimension, returns the first TensorDict as the concatenation result. + + Args: + data: List of TensorDicts, each with empty batch_size (batch_size=[]). + + Returns: + The first TensorDict from the list, as metadata concatenation + simply preserves the first instance. + + Raises: + AssertionError: If any TensorDict has a non-empty batch_size. + + Note: + This is used internally by concat_tensordict when handling + TensorDicts that contain only non-tensor metadata. + """ for d in data: assert len(d.batch_size) == 0 # directly return the first meta info @@ -119,7 +219,28 @@ def concat_tensordict_with_none_bsz(data: list[TensorDict]): def concat_tensordict(data: list[TensorDict]) -> TensorDict: - """Concatenates tensordicts into a single tensordict on dim zero. Support nested tensor""" + """Concatenate multiple TensorDicts along dimension zero. + + Combines a list of TensorDicts into a single TensorDict by concatenating + all tensors along the batch dimension (dim=0). Handles nested tensors + specially by unbinding and rebinding them. + + Args: + data: List of TensorDicts to concatenate. All TensorDicts must have + the same keys and the same set of nested tensor keys. + + Returns: + A new TensorDict containing concatenated tensors from all inputs. + + Raises: + AssertionError: If data is empty or if TensorDicts have inconsistent + nested tensor keys. 
+ + Note: + - For TensorDicts with empty batch_size, returns the first one + - Nested tensors are handled specially via concat_nested_tensors + - Regular tensors use TensorDict.cat for efficient concatenation + """ assert len(data) > 0, "Must have at least one tensordict" # Find nested tensor keys from the first tensordict @@ -153,10 +274,27 @@ def concat_tensordict(data: list[TensorDict]) -> TensorDict: def chunk_tensordict(td: TensorDict, chunks: int) -> list[TensorDict]: - """Splits a tensordict into the specified number of chunks with special handling of 3d nested tensors. + """Split a TensorDict into equal-sized chunks with special nested tensor handling. + + Divides a TensorDict into the specified number of chunks along the batch + dimension. Handles 3D+ nested tensors specially since torch.chunk() doesn't + support jagged tensors with 3 or more dimensions. - This is a workaround for torch.chunk() not support 3d jagged tensor, e.g. MRoPE position_id. - https://github.com/pytorch/pytorch/issues/153238 + Args: + td: The TensorDict to split. + chunks: Number of chunks to create. Must evenly divide len(td). + + Returns: + List of TensorDicts, each containing a portion of the original data. + + Raises: + AssertionError: If td is not a TensorDict or if its length is not + evenly divisible by chunks. + + Note: + This is a workaround for PyTorch issue #153238 where torch.chunk() + doesn't support 3D jagged tensors (e.g., MRoPE position_ids). + See: https://github.com/pytorch/pytorch/issues/153238 """ assert isinstance(td, TensorDict) and len(td) % chunks == 0, ( f"expecting td with length divisible by chunks, but got {len(td)} and {chunks}" @@ -254,7 +392,29 @@ def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: def index_select_tensor_dict(batch: TensorDict, indices: torch.Tensor | list[int]) -> TensorDict: - """Index a tensor dict with a tensor of indices.""" + """Select rows from a TensorDict using indices. 
+ + Creates a new TensorDict containing only the rows specified by indices. + Handles regular tensors, nested tensors, NonTensorStack, and NonTensorData + appropriately. + + Args: + batch: The TensorDict to index into. Can be None. + indices: 1D tensor or list of integers specifying which rows to select. + + Returns: + A new TensorDict containing only the selected rows, or None if + batch was None. + + Raises: + AssertionError: If indices is not 1-dimensional. + + Note: + - Regular tensors are indexed directly + - Nested tensors are unbound, indexed, and rebound + - NonTensorStack is indexed by batch dimension + - NonTensorData (scalar metadata) is preserved unchanged + """ if isinstance(indices, list): indices = torch.tensor(indices) @@ -286,7 +446,30 @@ def index_select_tensor_dict(batch: TensorDict, indices: torch.Tensor | list[int def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict: - """Union two tensordicts.""" + """Merge two TensorDicts, adding keys from the second to the first. + + Performs an in-place union of two TensorDicts. Keys from tensor_dict2 + that don't exist in tensor_dict1 are added. Keys that exist in both + must have identical values. + + Args: + tensor_dict1: The base TensorDict to merge into (modified in-place). + tensor_dict2: The TensorDict whose keys will be added to tensor_dict1. + + Returns: + The modified tensor_dict1 containing the union of both TensorDicts. + + Raises: + AssertionError: If batch sizes don't match, or if a key exists in + both TensorDicts with different values. + + Example: + >>> td1 = TensorDict({"a": torch.tensor([1, 2])}, batch_size=[2]) + >>> td2 = TensorDict({"b": torch.tensor([3, 4])}, batch_size=[2]) + >>> result = union_tensor_dict(td1, td2) + >>> list(result.keys()) + ['a', 'b'] + """ assert tensor_dict1.batch_size == tensor_dict2.batch_size, ( f"Two tensor dict must have identical batch size. 
Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}" ) @@ -309,6 +492,32 @@ def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> Ten def make_iterator(tensordict: TensorDict, mini_batch_size, epochs, seed=None, dataloader_kwargs=None): + """Create an iterator that yields mini-batches from a TensorDict. + + Wraps a TensorDict in a DataLoader-style iterator that yields mini-batches + for the specified number of epochs. Useful for training loops. + + Args: + tensordict: The TensorDict to iterate over. + mini_batch_size: Size of each mini-batch. Must evenly divide the + TensorDict's batch size. + epochs: Number of times to iterate through the entire dataset. + seed: Optional random seed for reproducible shuffling. + dataloader_kwargs: Optional dict of additional kwargs to pass to + the underlying DataLoader (e.g., shuffle=True, num_workers=4). + + Returns: + An iterator that yields TensorDict mini-batches. + + Raises: + AssertionError: If batch size is not divisible by mini_batch_size. + + Example: + >>> td = TensorDict({"obs": torch.randn(100, 4)}, batch_size=[100]) + >>> for batch in make_iterator(td, mini_batch_size=10, epochs=2): + ... # batch is a TensorDict with batch_size=[10] + ... pass + """ from torch.utils.data import DataLoader assert tensordict.batch_size[0] % mini_batch_size == 0, f"{tensordict.batch_size[0]} % {mini_batch_size} != 0" @@ -339,6 +548,25 @@ def get_data(): def assert_tensordict_eq(tensordict1: TensorDict, tensordict2: TensorDict): + """Assert that two TensorDicts are equal. + + Performs a deep equality check between two TensorDicts, verifying that + they have the same keys with identical values. Handles nested tensors + by comparing their unbound components. + + Args: + tensordict1: First TensorDict to compare. + tensordict2: Second TensorDict to compare. + + Raises: + AssertionError: If the TensorDicts differ in keys, value types, or + value contents. The error message indicates what differs. 
+ + Note: + - Regular tensors are compared element-wise + - Nested tensors are unbound and compared component by component + - Non-tensor values are compared with standard equality + """ tensordict1_key_set = set(tensordict1.keys()) tensordict2_key_set = set(tensordict2.keys()) assert tensordict1_key_set == tensordict2_key_set, ( @@ -367,6 +595,28 @@ def assert_tensordict_eq(tensordict1: TensorDict, tensordict2: TensorDict): def get(tensordict: TensorDict, key: str, default=None) -> Any: + """Get a value from a TensorDict with automatic unwrapping. + + Retrieves a value from the TensorDict and automatically converts it + to a Python-native format: + - Tensors are returned as-is + - NonTensorStack is converted to a Python list + - NonTensorData is unwrapped to its underlying value + + Args: + tensordict: The TensorDict to retrieve from. + key: The key to look up. + default: Value to return if the key doesn't exist. Defaults to None. + + Returns: + The value for the key in its native format, or default if not found. + + Example: + >>> td = get_tensordict({"obs": torch.randn(3, 4), "labels": ["a", "b", "c"]}) + >>> get(td, "obs") # Returns torch.Tensor + >>> get(td, "labels") # Returns ["a", "b", "c"] as a list + >>> get(td, "missing", "default") # Returns "default" + """ if key not in tensordict: return default @@ -381,6 +631,27 @@ def get(tensordict: TensorDict, key: str, default=None) -> Any: def get_keys(tensordict: TensorDict, keys: Iterable[str]) -> TensorDict: + """Extract a subset of keys from a TensorDict into a new TensorDict. + + Creates a new TensorDict containing only the specified keys. Values + are properly categorized as tensor or non-tensor data. + + Args: + tensordict: The source TensorDict. + keys: Iterable of key names to extract. + + Returns: + A new TensorDict containing only the specified keys with their values. + + Raises: + KeyError: If any key in keys doesn't exist in the tensordict. 
+ + Example: + >>> td = get_tensordict({"a": torch.randn(3), "b": torch.randn(3), "c": torch.randn(3)}) + >>> subset = get_keys(td, ["a", "c"]) + >>> list(subset.keys()) + ['a', 'c'] + """ tensor_output = {} non_tensor_output = {} for key in keys: @@ -399,6 +670,26 @@ def get_keys(tensordict: TensorDict, keys: Iterable[str]) -> TensorDict: def pop(tensordict: TensorDict, key: str, default=None) -> Any: + """Remove and return a value from a TensorDict with automatic unwrapping. + + Removes the specified key from the TensorDict and returns its value, + automatically converting to Python-native format (same as get()). + + Args: + tensordict: The TensorDict to pop from. + key: The key to remove and return. + default: Value to return if the key doesn't exist. Defaults to None. + + Returns: + The value for the key in its native format, or default if not found. + The key is removed from the TensorDict. + + Example: + >>> td = get_tensordict({"obs": torch.randn(3, 4), "labels": ["a", "b", "c"]}) + >>> labels = pop(td, "labels") # Returns ["a", "b", "c"], removes from td + >>> "labels" in td.keys() + False + """ _sentinel = object() output = tensordict.pop(key, _sentinel) if output is _sentinel: @@ -414,6 +705,29 @@ def pop(tensordict: TensorDict, key: str, default=None) -> Any: def pop_keys(tensordict: TensorDict, keys: Iterable[str]) -> TensorDict: + """Remove multiple keys from a TensorDict and return them as a new TensorDict. + + Removes the specified keys from the source TensorDict and creates a new + TensorDict containing those keys and their values. + + Args: + tensordict: The source TensorDict to pop from (modified in-place). + keys: Iterable of key names to remove and return. + + Returns: + A new TensorDict containing the popped keys and their values. + + Raises: + KeyError: If any key in keys doesn't exist in the tensordict. 
+ + Example: + >>> td = get_tensordict({"a": torch.randn(3), "b": torch.randn(3), "c": torch.randn(3)}) + >>> popped = pop_keys(td, ["a", "c"]) + >>> list(td.keys()) # Only 'b' remains + ['b'] + >>> list(popped.keys()) + ['a', 'c'] + """ tensor_output = {} non_tensor_output = {} for key in keys: @@ -432,14 +746,31 @@ def pop_keys(tensordict: TensorDict, keys: Iterable[str]) -> TensorDict: def pad_to_divisor(data: TensorDict, size_divisor: int): - """Pad a TensorDict to size divisible by size_divisor + """Pad a TensorDict's batch dimension to be divisible by a given divisor. + + If the TensorDict's length is not evenly divisible by size_divisor, + pads the batch dimension by repeating elements from the beginning. + Useful for ensuring even distribution across workers in distributed training. Args: - size_divisor (int): size divisor + data: The TensorDict to pad. + size_divisor: The divisor that the padded length must be divisible by. Returns: - data: (TensorDict): the padded TensorDict - pad_size (int) + tuple: A tuple containing: + - data (TensorDict): The padded TensorDict (or original if no padding needed) + - pad_size (int): Number of elements added as padding (0 if none) + + Raises: + AssertionError: If data is not a TensorDict. + + Example: + >>> td = TensorDict({"obs": torch.randn(10, 4)}, batch_size=[10]) + >>> padded, pad_size = pad_to_divisor(td, 4) + >>> len(padded) # 12 (next multiple of 4 after 10) + 12 + >>> pad_size + 2 """ assert isinstance(data, TensorDict), "data must be a TensorDict" if len(data) % size_divisor != 0: @@ -460,7 +791,25 @@ def pad_to_divisor(data: TensorDict, size_divisor: int): def unpad(data: TensorDict, pad_size): - """Unpad the data proto with pad_size. i.e. `data[:-pad_size]`""" + """Remove padding from a TensorDict. + + Reverses the effect of pad_to_divisor by removing the specified number + of elements from the end of the TensorDict. + + Args: + data: The padded TensorDict. + pad_size: Number of padding elements to remove. 
If 0, returns + data unchanged. + + Returns: + The TensorDict with padding removed — equivalent to data[:-pad_size] when pad_size > 0, + or data returned unchanged when pad_size == 0 (note data[:-0] would be empty, which is not what happens). + + Example: + >>> td = TensorDict({"obs": torch.randn(12, 4)}, batch_size=[12]) + >>> unpadded = unpad(td, pad_size=2) + >>> len(unpadded) + 10 + """ if pad_size != 0: data = data[:-pad_size] return data