Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 203 additions & 2 deletions tests/test_protocol_v2_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy as np
import pytest
import torch
from tensordict.tensorclass import NonTensorData, NonTensorStack

from verl.utils import tensordict_utils as tu

Expand All @@ -45,10 +46,10 @@ def test_union_tensor_dict():
# conflict in tensor values
tu.union_tensor_dict(data1, data_with_copied_obs)

data1 = tu.assign_non_tensor_dict(data1, meta_info1)
data1 = tu.assign_non_tensor(data1, **meta_info1)
tu.union_tensor_dict(data1, data2) # works ok

data2 = tu.assign_non_tensor_dict(data2, meta_info2)
data2 = tu.assign_non_tensor(data2, **meta_info2)

with pytest.raises(AssertionError):
# conflict in NonTensorData
Expand Down Expand Up @@ -651,3 +652,203 @@ def test_concat_tensordict():
# make sure tensordict1 and tensordict2 is untouched
tu.assert_tensordict_eq(tensordict1, tensordict1_copy)
tu.assert_tensordict_eq(tensordict2, tensordict2_copy)


def test_assign_non_tensor_stack_with_nested_lists():
    """Check that assign_non_tensor_stack stores a ragged list of lists."""
    batch = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={})

    # Per-sample lists of differing lengths (e.g. turn_scores or tool_rewards).
    ragged_scores = [[], [0.5, 0.8], [0.9]]
    tu.assign_non_tensor_stack(batch, "turn_scores", ragged_scores)

    # Every entry must read back as the list that was stored for that sample.
    assert len(batch["turn_scores"]) == 3
    for idx, expected in enumerate(([], [0.5, 0.8], [0.9])):
        assert list(batch["turn_scores"][idx]) == expected


def test_assign_non_tensor_stack_with_nested_dicts():
    """Check that assign_non_tensor_stack stores a list of per-sample dicts."""
    batch = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={})

    # One dict per sample (e.g. reward_extra_info).
    per_sample_info = [
        {"acc": 1.0, "loss": 0.1},
        {"acc": 0.0, "loss": 0.9},
        {"acc": 1.0, "loss": 0.05},
    ]
    tu.assign_non_tensor_stack(batch, "reward_extra_info", per_sample_info)

    # Every entry must read back as the dict that was stored for that sample.
    assert len(batch["reward_extra_info"]) == 3
    expected_entries = (
        {"acc": 1.0, "loss": 0.1},
        {"acc": 0.0, "loss": 0.9},
        {"acc": 1.0, "loss": 0.05},
    )
    for idx, expected in enumerate(expected_entries):
        assert dict(batch["reward_extra_info"][idx]) == expected


def test_assign_non_tensor_stack_with_complex_nested():
    """Check that assign_non_tensor_stack stores lists of lists of dicts."""
    batch = tu.get_tensordict({"obs": torch.randn(2, 4)}, non_tensor_dict={})

    # One conversation (a list of message dicts) per sample, like raw_prompt.
    single_turn = [{"content": "Question 1", "role": "user"}]
    two_turns = [
        {"content": "Question 2", "role": "user"},
        {"content": "Answer 2", "role": "assistant"},
    ]
    tu.assign_non_tensor_stack(batch, "raw_prompt", [single_turn, two_turns])

    assert len(batch["raw_prompt"]) == 2

    # For each sample: the conversation length and its first message.
    expected_heads = (
        (1, {"content": "Question 1", "role": "user"}),
        (2, {"content": "Question 2", "role": "user"}),
    )
    for idx, (num_messages, first_message) in enumerate(expected_heads):
        assert len(batch["raw_prompt"][idx]) == num_messages
        assert dict(batch["raw_prompt"][idx][0]) == first_message


def test_assign_non_tensor_handles_wrappers():
    """Check assign_non_tensor with a plain value and with pre-wrapped values."""
    batch = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={})

    # A plain Python value passed as a keyword argument.
    tu.assign_non_tensor(batch, top_p=0.8)
    assert batch["top_p"] == 0.8

    # Values that are already NonTensorData / NonTensorStack instances.
    scalar_wrapper = NonTensorData(0.3)
    stacked = NonTensorStack.from_list([NonTensorData(v) for v in (1.0, 2.0, 3.0)])
    tu.assign_non_tensor(batch, wrapped=scalar_wrapper, stack=stacked)

    assert batch["wrapped"] == 0.3
    assert batch["stack"] == [1.0, 2.0, 3.0]


def test_assign_non_tensor_stack_batch_size_check():
    """A 2-element stack assigned to a TensorDict built from a (3, 4) tensor must raise."""
    batch = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={})
    short_stack = NonTensorStack.from_list([NonTensorData(v) for v in (1.0, 2.0)])

    with pytest.raises(RuntimeError):
        tu.assign_non_tensor(batch, stack=short_stack)


def test_assign_non_tensor_with_auto_detection():
    """Check assign_non_tensor with a mix of scalar, nested-list and flat-list values."""
    batch = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={})

    mixed_values = {
        "metadata": "experiment_1",  # plain value
        "turn_scores": [[], [0.5, 0.8], [0.9]],  # list of lists
        "reward_extra_info": [{"acc": 1.0}, {"acc": 0.0}, {"acc": 1.0}],  # list of dicts
        "simple_list": ["a", "b", "c"],  # flat list
    }
    tu.assign_non_tensor(batch, **mixed_values)

    # The plain value reads back unchanged.
    assert batch["metadata"] == "experiment_1"

    # Each list-valued entry keeps one element per sample.
    assert len(batch["turn_scores"]) == 3
    assert list(batch["turn_scores"][1]) == [0.5, 0.8]

    assert len(batch["reward_extra_info"]) == 3
    assert dict(batch["reward_extra_info"][0]) == {"acc": 1.0}

    assert len(batch["simple_list"]) == 3
    assert batch["simple_list"][0] == "a"


def test_get_tensordict_with_nested_lists():
    """Check that get_tensordict accepts a ragged list of lists next to a tensor."""
    observations = torch.randn(3, 4)
    ragged_scores = [[], [0.5, 0.8], [0.9]]

    # The list is passed in the same mapping as the tensor.
    batch = tu.get_tensordict({"obs": observations, "turn_scores": ragged_scores})

    # The tensor is unchanged and the nested lists are readable per sample.
    assert torch.all(torch.eq(batch["obs"], observations))
    assert len(batch["turn_scores"]) == 3
    for idx, expected in enumerate(([], [0.5, 0.8])):
        assert list(batch["turn_scores"][idx]) == expected


def test_get_tensordict_with_nested_dicts():
    """Check that get_tensordict accepts a list of dicts next to a tensor."""
    observations = torch.randn(3, 4)
    per_sample_info = [{"acc": 1.0}, {"acc": 0.0}, {"acc": 1.0}]

    batch = tu.get_tensordict(
        {"obs": observations, "reward_extra_info": per_sample_info}
    )

    assert torch.all(torch.eq(batch["obs"], observations))
    assert len(batch["reward_extra_info"]) == 3
    assert dict(batch["reward_extra_info"][0]) == {"acc": 1.0}


def test_get_tensordict_with_complex_nested_structures():
    """Check that get_tensordict accepts lists of lists of dicts next to a tensor."""
    observations = torch.randn(2, 4)

    # One conversation (a list of message dicts) per sample.
    single_turn = [{"content": "Q1", "role": "user"}]
    two_turns = [
        {"content": "Q2", "role": "user"},
        {"content": "A2", "role": "assistant"},
    ]
    conversations = [single_turn, two_turns]

    batch = tu.get_tensordict({"obs": observations, "raw_prompt": conversations})

    assert torch.all(torch.eq(batch["obs"], observations))
    assert len(batch["raw_prompt"]) == 2
    assert dict(batch["raw_prompt"][0][0]) == {"content": "Q1", "role": "user"}


def test_get_tensordict_agent_loop_scenario():
    """Exercise get_tensordict with every nested type used by agent loops at once.

    The batch mixes tensors with:
    - data_source / uid: flat lists of strings
    - turn_scores / tool_rewards: lists of lists
    - reward_extra_info: lists of dicts
    - raw_prompt: lists of lists of dicts
    plus one metadata entry passed through non_tensor_dict.
    """
    prompts = torch.randn(2, 10)
    responses = torch.randn(2, 5)

    # Flat per-sample string lists.
    data_source = ["lighteval/MATH", "lighteval/MATH"]
    uid = ["uuid-1", "uuid-2"]

    # Everything handed to get_tensordict as batch data, in one mapping.
    batch_fields = {
        "prompts": prompts,
        "responses": responses,
        "data_source": data_source,
        "uid": uid,
        "turn_scores": [[], [0.5, 0.8]],  # ragged list of lists
        "reward_extra_info": [{"acc": 1.0, "loss": 0.1}, {"acc": 0.0, "loss": 0.9}],
        "raw_prompt": [
            [{"content": "Compute 4 @ 2", "role": "user"}],
            [{"content": "Compute 8 @ 7", "role": "user"}],
        ],
        "tool_rewards": [[0.0], []],  # ragged list of lists
    }

    batch = tu.get_tensordict(
        tensor_dict=batch_fields,
        non_tensor_dict={"global_steps": 42},
    )

    # Tensors come back unchanged.
    for name, tensor in (("prompts", prompts), ("responses", responses)):
        assert torch.all(torch.eq(batch[name], tensor))

    # Flat string lists compare equal to the inputs.
    assert batch["data_source"] == data_source
    assert batch["uid"] == uid

    # Ragged list of lists: turn_scores.
    assert len(batch["turn_scores"]) == 2
    for idx, expected in enumerate(([], [0.5, 0.8])):
        assert list(batch["turn_scores"][idx]) == expected

    # List of dicts: reward_extra_info.
    assert len(batch["reward_extra_info"]) == 2
    assert dict(batch["reward_extra_info"][0]) == {"acc": 1.0, "loss": 0.1}

    # List of lists of dicts: raw_prompt.
    assert len(batch["raw_prompt"]) == 2
    assert dict(batch["raw_prompt"][0][0]) == {"content": "Compute 4 @ 2", "role": "user"}

    # Ragged list of lists: tool_rewards.
    assert len(batch["tool_rewards"]) == 2
    for idx, expected in enumerate(([0.0], [])):
        assert list(batch["tool_rewards"][idx]) == expected

    # Metadata supplied via non_tensor_dict.
    assert batch["global_steps"] == 42
96 changes: 83 additions & 13 deletions verl/utils/tensordict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,71 @@
from tensordict.tensorclass import NonTensorData, NonTensorStack


def assign_non_tensor_dict(tensor_dict: TensorDict, non_tensor_dict: dict):
    """Store every entry of ``non_tensor_dict`` in ``tensor_dict`` via assign_non_tensor_data.

    Args:
        tensor_dict: The TensorDict to write into (modified in place).
        non_tensor_dict: Mapping of key to value; each value is stored under its key.

    Returns:
        The same ``tensor_dict`` object that was passed in.
    """
    for name, value in non_tensor_dict.items():
        assign_non_tensor_data(tensor_dict, name, value)
    return tensor_dict


def assign_non_tensor_data(tensor_dict: TensorDict, key, val):
    """Store ``val`` under ``key`` in ``tensor_dict``, wrapped in NonTensorData.

    Args:
        tensor_dict: The TensorDict to write into (modified in place).
        key: The key to store the value under.
        val: The value to wrap.
    """
    assert isinstance(tensor_dict, TensorDict), "input dict must be a TensorDict"
    wrapped_value = NonTensorData(val)
    tensor_dict[key] = wrapped_value


def assign_non_tensor(tensordict: TensorDict, **kwargs):
def assign_non_tensor_stack(tensor_dict: TensorDict, key, val: list):
    """Store a list in a TensorDict as a NonTensorStack, one entry per item.

    Each item of ``val`` is wrapped in NonTensorData and the wrapped items are
    combined with ``NonTensorStack.from_list``. Items may themselves be
    containers, for example:
        - lists of lists: [[], [0.5, 0.8], [0.9]]
        - lists of dicts: [{"acc": 1.0}, {"acc": 0.0}]
        - lists of lists of dicts: [[{"content": "...", "role": "user"}]]

    Args:
        tensor_dict: The TensorDict to write into (modified in place).
        key: The key to store the stack under.
        val: The list whose items become the entries of the stack.

    Example:
        >>> td = TensorDict({}, batch_size=[])
        >>> turn_scores = [[], [0.5, 0.8], [0.9]]
        >>> assign_non_tensor_stack(td, "turn_scores", turn_scores)
        >>> # td["turn_scores"] now holds the nested data
    """
    assert isinstance(tensor_dict, TensorDict), "input dict must be a TensorDict"
    # Wrap item by item so each (possibly nested) object is kept as one entry.
    # TODO(petersh6): can convert back to val directly if we are not accessing .data from the NonTensorStack
    wrapped_items = [NonTensorData(entry) for entry in val]
    tensor_dict[key] = NonTensorStack.from_list(wrapped_items)


def assign_non_tensor(tensor_dict: TensorDict, **kwargs):
    """Assign non-tensor data to a TensorDict, choosing the wrapper per value.

    Each keyword argument is stored under its own name:
        - ``NonTensorData`` / ``NonTensorStack`` instances are stored as given.
        - Any ``list`` (flat or nested) is stored through
          ``assign_non_tensor_stack``, i.e. as a NonTensorStack.
        - Every other value is stored through ``assign_non_tensor_data``,
          i.e. wrapped in NonTensorData.

    Args:
        tensor_dict: The TensorDict to write into (modified in place).
        **kwargs: Key-value pairs to store. Because the first parameter is
            named ``tensor_dict``, that name cannot be used as a key here.

    Returns:
        The same ``tensor_dict`` object that was passed in.

    Raises:
        AssertionError: If ``tensor_dict`` is not a TensorDict.

    Example:
        >>> td = TensorDict({"obs": torch.randn(3, 4)}, batch_size=[3])
        >>> assign_non_tensor(
        ...     tensor_dict=td,
        ...     metadata="experiment_1",  # Simple value
        ...     turn_scores=[[], [0.5, 0.8], [0.9]],  # Nested list
        ... )
    """
    assert isinstance(tensor_dict, TensorDict), "input dict must be a TensorDict"
    for key, val in kwargs.items():
        if isinstance(val, (NonTensorData, NonTensorStack)):
            # Already wrapped by the caller: store directly, do not wrap again.
            tensor_dict[key] = val
        elif isinstance(val, list):
            # Lists become a NonTensorStack with one entry per item.
            assign_non_tensor_stack(tensor_dict=tensor_dict, key=key, val=val)
        else:
            # Anything else is a single value wrapped in NonTensorData.
            assign_non_tensor_data(tensor_dict=tensor_dict, key=key, val=val)
    return tensor_dict


def unwrap_non_tensor_data(data):
Expand Down Expand Up @@ -92,15 +143,31 @@ def concat_tensordict(data: list[TensorDict]) -> TensorDict:


def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> TensorDict:
"""
"""Create a TensorDict from tensors and non-tensor data.

Automatically handles nested structures in lists by converting them to NonTensorStack.
This enables support for:
- Lists of lists: [[], [0.5, 0.8], [0.9]]
- Lists of dicts: [{"acc": 1.0}, {"acc": 0.0}]
- Lists of lists of dicts: [[{"content": "...", "role": "user"}]]

Args:
data_dict:
meta_info:
tensor_dict: Dictionary of tensors and lists to include in the TensorDict
non_tensor_dict: Dictionary of metadata to store as NonTensorData

Returns:

TensorDict with proper handling of nested structures

Example:
>>> td = get_tensordict(
... tensor_dict={
... "obs": torch.randn(3, 4),
... "turn_scores": [[], [0.5, 0.8], [0.9]] # Nested list
... },
... non_tensor_dict={"experiment": "test"}
... )
"""
tensor_dict = tensor_dict.copy()
if non_tensor_dict is None:
non_tensor_dict = {}

Expand All @@ -127,6 +194,9 @@ def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict:
"Passing a list makes the data NonTensorStack, "
"which doesn't support torch.Tensor. Please convert to numpy first"
)
# Convert to NonTensorStack to handle nested structures
tensor_dict[key] = NonTensorStack.from_list([NonTensorData(item) for item in val])

assert isinstance(val, torch.Tensor | list)

if batch_size is None:
Expand Down
Loading