Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 211 additions & 0 deletions tests/test_protocol_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,217 @@ def test_from_tensordict():
assert data.meta_info["name"] == "abdce"


@pytest.mark.skipif(
parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict_with_nested_lists():
"""Test converting DataProto with nested lists to TensorDict (lists of lists)."""
obs = torch.tensor([1, 2, 3])
# Simulate turn_scores or tool_rewards: array of lists with varying lengths
turn_scores = [[], [0.5, 0.8], [0.9]]

data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"turn_scores": turn_scores})

# This should not raise an error
tensordict_output = data.to_tensordict()

# Verify the data is preserved
assert torch.all(torch.eq(tensordict_output["obs"], obs)).item()
# Verify nested structure is accessible (TensorDict wraps NonTensorStack as LinkedList)
retrieved_scores = tensordict_output["turn_scores"]
assert len(retrieved_scores) == len(turn_scores)
# Verify content matches
assert list(retrieved_scores[0]) == []
assert list(retrieved_scores[1]) == [0.5, 0.8]
assert list(retrieved_scores[2]) == [0.9]


@pytest.mark.skipif(
parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict_with_nested_dicts():
"""Test converting DataProto with lists of dicts to TensorDict."""
obs = torch.tensor([1, 2, 3])
# Simulate reward_extra_info: array of dicts
reward_extra_info = [{"acc": 1.0}, {"acc": 0.0}, {"acc": 1.0}]

data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"reward_extra_info": reward_extra_info})

# This should not raise an error - this was the original bug
tensordict_output = data.to_tensordict()

# Verify the data is preserved
assert torch.all(torch.eq(tensordict_output["obs"], obs)).item()
# Verify nested dicts are accessible
retrieved_info = tensordict_output["reward_extra_info"]
assert len(retrieved_info) == len(reward_extra_info)
# Verify content matches
for i, expected_dict in enumerate(reward_extra_info):
assert dict(retrieved_info[i]) == expected_dict
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The assertion dict(retrieved_info[i]) == expected_dict is incorrect. retrieved_info[i] is a tensordict.NonTensorData object that wraps a dictionary. Calling dict() on this object will raise a ValueError because its iterator yields only the keys of the wrapped dictionary, not the required key-value pairs.

The correct way to perform this comparison is to either access the underlying data via the .data attribute or, more idiomatically, rely on the overloaded __eq__ method of the NonTensorData object, which compares its content directly.

Suggested change
assert dict(retrieved_info[i]) == expected_dict
assert retrieved_info[i] == expected_dict



@pytest.mark.skipif(
parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict_with_complex_nested_structures():
"""Test converting DataProto with complex nested structures (lists of lists of dicts)."""
obs = torch.tensor([1, 2, 3])
# Simulate raw_prompt: array of lists containing dicts
raw_prompt = [
[{"content": "Question 1", "role": "user"}],
[{"content": "Question 2", "role": "user"}, {"content": "Answer 2", "role": "assistant"}],
[{"content": "Question 3", "role": "user"}],
]

data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"raw_prompt": raw_prompt})

# This should not raise an error
tensordict_output = data.to_tensordict()

# Verify the data is preserved
assert torch.all(torch.eq(tensordict_output["obs"], obs)).item()
# Verify complex nested structure is accessible
retrieved_prompt = tensordict_output["raw_prompt"]
assert len(retrieved_prompt) == len(raw_prompt)
# Spot check: verify first prompt has correct structure
assert len(retrieved_prompt[0]) == 1
assert dict(retrieved_prompt[0][0]) == {"content": "Question 1", "role": "user"}


@pytest.mark.skipif(
parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict_and_back_with_nested_data():
"""Test round-trip conversion: DataProto → TensorDict → DataProto with nested structures."""
obs = torch.tensor([1, 2, 3, 4])
labels = ["a", "b", "c", "d"]

# Multiple types of nested structures
turn_scores = [[], [0.5], [0.8, 0.9], [0.7]]
reward_extra_info = [
{"acc": 1.0, "loss": 0.1},
{"acc": 0.5, "loss": 0.3},
{"acc": 1.0, "loss": 0.05},
{"acc": 0.0, "loss": 0.9},
]
raw_prompt = [
[{"content": "Q1", "role": "user"}],
[{"content": "Q2", "role": "user"}],
[{"content": "Q3", "role": "user"}, {"content": "A3", "role": "assistant"}],
[{"content": "Q4", "role": "user"}],
]

# Create original DataProto
original_data = DataProto.from_dict(
tensors={"obs": obs},
non_tensors={
"labels": labels,
"turn_scores": turn_scores,
"reward_extra_info": reward_extra_info,
"raw_prompt": raw_prompt,
},
meta_info={"experiment": "test_nested"},
)

# Convert to TensorDict
tensordict_output = original_data.to_tensordict()

# Convert back to DataProto
reconstructed_data = DataProto.from_tensordict(tensordict_output)

# Verify tensors are preserved
assert torch.all(torch.eq(reconstructed_data.batch["obs"], obs)).item()

# Verify non-tensor data is preserved
assert reconstructed_data.non_tensor_batch["labels"].tolist() == labels

# Verify nested structures are preserved
assert len(reconstructed_data.non_tensor_batch["turn_scores"]) == len(turn_scores)
for orig, recon in zip(turn_scores, reconstructed_data.non_tensor_batch["turn_scores"], strict=True):
assert list(orig) == list(recon)

assert len(reconstructed_data.non_tensor_batch["reward_extra_info"]) == len(reward_extra_info)
for orig, recon in zip(reward_extra_info, reconstructed_data.non_tensor_batch["reward_extra_info"], strict=True):
assert orig == recon

assert len(reconstructed_data.non_tensor_batch["raw_prompt"]) == len(raw_prompt)
for orig, recon in zip(raw_prompt, reconstructed_data.non_tensor_batch["raw_prompt"], strict=True):
assert orig == list(recon)

# Verify meta_info is preserved
assert reconstructed_data.meta_info["experiment"] == "test_nested"


@pytest.mark.skipif(
parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict_agent_loop_scenario():
"""Test the exact scenario from agent loop: DataProto with tool rewards, acc, etc.

This test reproduces the exact error from the agent loop where nested structures
(lists of lists, lists of dicts) failed to convert to TensorDict.
"""
# Simulate real agent loop data structure
prompts = torch.tensor([[1, 2, 3], [4, 5, 6]])
responses = torch.tensor([[7, 8], [9, 10]])

# Non-tensor data with nested structures from agent loop
data_source = ["lighteval/MATH", "lighteval/MATH"]
uid = ["uuid-1", "uuid-2"]
num_turns = np.array([2, 4], dtype=np.int32)
acc = np.array([1.0, 0.0])
turn_scores = [[], [0.5, 0.8]] # Lists of varying lengths
reward_extra_info = [{"acc": 1.0}, {"acc": 0.0}] # List of dicts
raw_prompt = [
[{"content": "Compute 4 @ 2", "role": "user"}],
[{"content": "Compute 8 @ 7", "role": "user"}],
]
tool_rewards = [[0.0], []] # List of lists

data = DataProto.from_dict(
tensors={"prompts": prompts, "responses": responses},
non_tensors={
"data_source": data_source,
"uid": uid,
"num_turns": num_turns,
"acc": acc,
"turn_scores": turn_scores,
"reward_extra_info": reward_extra_info,
"raw_prompt": raw_prompt,
"tool_rewards": tool_rewards,
},
meta_info={"global_steps": 42},
)

# THE KEY TEST: This should not raise ValueError about TensorDict conversion
tensordict_output = data.to_tensordict()

# Verify tensors are accessible
assert torch.all(torch.eq(tensordict_output["prompts"], prompts)).item()
assert torch.all(torch.eq(tensordict_output["responses"], responses)).item()

# Verify all nested structures are accessible (content check, not type check)
assert len(tensordict_output["turn_scores"]) == 2
assert list(tensordict_output["turn_scores"][0]) == []
assert list(tensordict_output["turn_scores"][1]) == [0.5, 0.8]

assert len(tensordict_output["reward_extra_info"]) == 2
assert dict(tensordict_output["reward_extra_info"][0]) == {"acc": 1.0}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This assertion is incorrect for the same reason as noted in another comment. The expression tensordict_output["reward_extra_info"][0] returns a tensordict.NonTensorData object wrapping a dictionary. Attempting to cast it to a dict using dict() will fail.

To correctly verify the content, you should compare the NonTensorData object directly with the expected dictionary. The object's __eq__ method is overloaded to handle this comparison correctly.

Suggested change
assert dict(tensordict_output["reward_extra_info"][0]) == {"acc": 1.0}
assert tensordict_output["reward_extra_info"][0] == {"acc": 1.0}


assert len(tensordict_output["raw_prompt"]) == 2
assert dict(tensordict_output["raw_prompt"][0][0]) == {"content": "Compute 4 @ 2", "role": "user"}

assert len(tensordict_output["tool_rewards"]) == 2
assert list(tensordict_output["tool_rewards"][0]) == [0.0]
assert list(tensordict_output["tool_rewards"][1]) == []

# Verify round-trip conversion works perfectly
reconstructed = DataProto.from_tensordict(tensordict_output)
assert len(reconstructed) == 2
assert reconstructed.meta_info["global_steps"] == 42
assert torch.all(torch.eq(reconstructed.batch["prompts"], prompts)).item()


def test_serialize_deserialize_single_tensor():
"""Test serialization and deserialization of a single tensor"""
# Create test tensor
Expand Down
5 changes: 4 additions & 1 deletion verl/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,14 +1118,17 @@ def to_tensordict(self) -> TensorDict:
tensor_batch = self.batch.to_dict()
non_tensor_batch = self.non_tensor_batch

from tensordict.tensorclass import NonTensorData, NonTensorStack

from verl.utils import tensordict_utils as tu

common_keys = set(tensor_batch.keys()) & set(non_tensor_batch.keys())
assert len(common_keys) == 0, f"tensor_batch and non_tensor_batch have common keys {common_keys}"

for key, val in non_tensor_batch.items():
assert isinstance(val, np.ndarray)
tensor_batch[key] = val.tolist()
# Convert to NonTensorStack instead of plain list to handle nested structures
tensor_batch[key] = NonTensorStack.from_list([NonTensorData(item) for item in val])
output = tu.get_tensordict(tensor_dict=tensor_batch, non_tensor_dict=self.meta_info)
return output

Expand Down
11 changes: 11 additions & 0 deletions verl/utils/tensordict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,17 @@ def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict:
if isinstance(val, torch.Tensor) and val.is_nested:
assert val.is_contiguous(), "Nested tensors must be contiguous. Try setting layout=torch.jagged"

# Skip validation for NonTensorStack as it's already properly formatted
if isinstance(val, NonTensorStack):
if batch_size is None:
batch_size = len(val)
else:
assert len(val) == batch_size, (
f"Batch size of NonTensorStack {key} is not consistent with other tensors. "
f"Expected {batch_size}, got {len(val)}"
)
continue

if isinstance(val, list):
for v in val:
assert not isinstance(v, torch.Tensor), (
Expand Down
Loading