-
Notifications
You must be signed in to change notification settings - Fork 3.6k
[misc] fix: support nested datastructure in dataproto to convert to tensordict #4296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -779,6 +779,217 @@ def test_from_tensordict(): | |||||
| assert data.meta_info["name"] == "abdce" | ||||||
|
|
||||||
|
|
||||||
| @pytest.mark.skipif( | ||||||
| parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10" | ||||||
| ) | ||||||
| def test_to_tensordict_with_nested_lists(): | ||||||
| """Test converting DataProto with nested lists to TensorDict (lists of lists).""" | ||||||
| obs = torch.tensor([1, 2, 3]) | ||||||
| # Simulate turn_scores or tool_rewards: array of lists with varying lengths | ||||||
| turn_scores = [[], [0.5, 0.8], [0.9]] | ||||||
|
|
||||||
| data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"turn_scores": turn_scores}) | ||||||
|
|
||||||
| # This should not raise an error | ||||||
| tensordict_output = data.to_tensordict() | ||||||
|
|
||||||
| # Verify the data is preserved | ||||||
| assert torch.all(torch.eq(tensordict_output["obs"], obs)).item() | ||||||
| # Verify nested structure is accessible (TensorDict wraps NonTensorStack as LinkedList) | ||||||
| retrieved_scores = tensordict_output["turn_scores"] | ||||||
| assert len(retrieved_scores) == len(turn_scores) | ||||||
| # Verify content matches | ||||||
| assert list(retrieved_scores[0]) == [] | ||||||
| assert list(retrieved_scores[1]) == [0.5, 0.8] | ||||||
| assert list(retrieved_scores[2]) == [0.9] | ||||||
|
|
||||||
|
|
||||||
| @pytest.mark.skipif( | ||||||
| parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10" | ||||||
| ) | ||||||
| def test_to_tensordict_with_nested_dicts(): | ||||||
| """Test converting DataProto with lists of dicts to TensorDict.""" | ||||||
| obs = torch.tensor([1, 2, 3]) | ||||||
| # Simulate reward_extra_info: array of dicts | ||||||
| reward_extra_info = [{"acc": 1.0}, {"acc": 0.0}, {"acc": 1.0}] | ||||||
|
|
||||||
| data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"reward_extra_info": reward_extra_info}) | ||||||
|
|
||||||
| # This should not raise an error - this was the original bug | ||||||
| tensordict_output = data.to_tensordict() | ||||||
|
|
||||||
| # Verify the data is preserved | ||||||
| assert torch.all(torch.eq(tensordict_output["obs"], obs)).item() | ||||||
| # Verify nested dicts are accessible | ||||||
| retrieved_info = tensordict_output["reward_extra_info"] | ||||||
| assert len(retrieved_info) == len(reward_extra_info) | ||||||
| # Verify content matches | ||||||
| for i, expected_dict in enumerate(reward_extra_info): | ||||||
| assert dict(retrieved_info[i]) == expected_dict | ||||||
|
|
||||||
|
|
||||||
| @pytest.mark.skipif( | ||||||
| parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10" | ||||||
| ) | ||||||
| def test_to_tensordict_with_complex_nested_structures(): | ||||||
| """Test converting DataProto with complex nested structures (lists of lists of dicts).""" | ||||||
| obs = torch.tensor([1, 2, 3]) | ||||||
| # Simulate raw_prompt: array of lists containing dicts | ||||||
| raw_prompt = [ | ||||||
| [{"content": "Question 1", "role": "user"}], | ||||||
| [{"content": "Question 2", "role": "user"}, {"content": "Answer 2", "role": "assistant"}], | ||||||
| [{"content": "Question 3", "role": "user"}], | ||||||
| ] | ||||||
|
|
||||||
| data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"raw_prompt": raw_prompt}) | ||||||
|
|
||||||
| # This should not raise an error | ||||||
| tensordict_output = data.to_tensordict() | ||||||
|
|
||||||
| # Verify the data is preserved | ||||||
| assert torch.all(torch.eq(tensordict_output["obs"], obs)).item() | ||||||
| # Verify complex nested structure is accessible | ||||||
| retrieved_prompt = tensordict_output["raw_prompt"] | ||||||
| assert len(retrieved_prompt) == len(raw_prompt) | ||||||
| # Spot check: verify first prompt has correct structure | ||||||
| assert len(retrieved_prompt[0]) == 1 | ||||||
| assert dict(retrieved_prompt[0][0]) == {"content": "Question 1", "role": "user"} | ||||||
|
|
||||||
|
|
||||||
| @pytest.mark.skipif( | ||||||
| parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10" | ||||||
| ) | ||||||
| def test_to_tensordict_and_back_with_nested_data(): | ||||||
| """Test round-trip conversion: DataProto → TensorDict → DataProto with nested structures.""" | ||||||
| obs = torch.tensor([1, 2, 3, 4]) | ||||||
| labels = ["a", "b", "c", "d"] | ||||||
|
|
||||||
| # Multiple types of nested structures | ||||||
| turn_scores = [[], [0.5], [0.8, 0.9], [0.7]] | ||||||
| reward_extra_info = [ | ||||||
| {"acc": 1.0, "loss": 0.1}, | ||||||
| {"acc": 0.5, "loss": 0.3}, | ||||||
| {"acc": 1.0, "loss": 0.05}, | ||||||
| {"acc": 0.0, "loss": 0.9}, | ||||||
| ] | ||||||
| raw_prompt = [ | ||||||
| [{"content": "Q1", "role": "user"}], | ||||||
| [{"content": "Q2", "role": "user"}], | ||||||
| [{"content": "Q3", "role": "user"}, {"content": "A3", "role": "assistant"}], | ||||||
| [{"content": "Q4", "role": "user"}], | ||||||
| ] | ||||||
|
|
||||||
| # Create original DataProto | ||||||
| original_data = DataProto.from_dict( | ||||||
| tensors={"obs": obs}, | ||||||
| non_tensors={ | ||||||
| "labels": labels, | ||||||
| "turn_scores": turn_scores, | ||||||
| "reward_extra_info": reward_extra_info, | ||||||
| "raw_prompt": raw_prompt, | ||||||
| }, | ||||||
| meta_info={"experiment": "test_nested"}, | ||||||
| ) | ||||||
|
|
||||||
| # Convert to TensorDict | ||||||
| tensordict_output = original_data.to_tensordict() | ||||||
|
|
||||||
| # Convert back to DataProto | ||||||
| reconstructed_data = DataProto.from_tensordict(tensordict_output) | ||||||
|
|
||||||
| # Verify tensors are preserved | ||||||
| assert torch.all(torch.eq(reconstructed_data.batch["obs"], obs)).item() | ||||||
|
|
||||||
| # Verify non-tensor data is preserved | ||||||
| assert reconstructed_data.non_tensor_batch["labels"].tolist() == labels | ||||||
|
|
||||||
| # Verify nested structures are preserved | ||||||
| assert len(reconstructed_data.non_tensor_batch["turn_scores"]) == len(turn_scores) | ||||||
| for orig, recon in zip(turn_scores, reconstructed_data.non_tensor_batch["turn_scores"], strict=True): | ||||||
| assert list(orig) == list(recon) | ||||||
|
|
||||||
| assert len(reconstructed_data.non_tensor_batch["reward_extra_info"]) == len(reward_extra_info) | ||||||
| for orig, recon in zip(reward_extra_info, reconstructed_data.non_tensor_batch["reward_extra_info"], strict=True): | ||||||
| assert orig == recon | ||||||
|
|
||||||
| assert len(reconstructed_data.non_tensor_batch["raw_prompt"]) == len(raw_prompt) | ||||||
| for orig, recon in zip(raw_prompt, reconstructed_data.non_tensor_batch["raw_prompt"], strict=True): | ||||||
| assert orig == list(recon) | ||||||
|
|
||||||
| # Verify meta_info is preserved | ||||||
| assert reconstructed_data.meta_info["experiment"] == "test_nested" | ||||||
|
|
||||||
|
|
||||||
| @pytest.mark.skipif( | ||||||
| parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10" | ||||||
| ) | ||||||
| def test_to_tensordict_agent_loop_scenario(): | ||||||
| """Test the exact scenario from agent loop: DataProto with tool rewards, acc, etc. | ||||||
|
|
||||||
| This test reproduces the exact error from the agent loop where nested structures | ||||||
| (lists of lists, lists of dicts) failed to convert to TensorDict. | ||||||
| """ | ||||||
| # Simulate real agent loop data structure | ||||||
| prompts = torch.tensor([[1, 2, 3], [4, 5, 6]]) | ||||||
| responses = torch.tensor([[7, 8], [9, 10]]) | ||||||
|
|
||||||
| # Non-tensor data with nested structures from agent loop | ||||||
| data_source = ["lighteval/MATH", "lighteval/MATH"] | ||||||
| uid = ["uuid-1", "uuid-2"] | ||||||
| num_turns = np.array([2, 4], dtype=np.int32) | ||||||
| acc = np.array([1.0, 0.0]) | ||||||
| turn_scores = [[], [0.5, 0.8]] # Lists of varying lengths | ||||||
| reward_extra_info = [{"acc": 1.0}, {"acc": 0.0}] # List of dicts | ||||||
| raw_prompt = [ | ||||||
| [{"content": "Compute 4 @ 2", "role": "user"}], | ||||||
| [{"content": "Compute 8 @ 7", "role": "user"}], | ||||||
| ] | ||||||
| tool_rewards = [[0.0], []] # List of lists | ||||||
|
|
||||||
| data = DataProto.from_dict( | ||||||
| tensors={"prompts": prompts, "responses": responses}, | ||||||
| non_tensors={ | ||||||
| "data_source": data_source, | ||||||
| "uid": uid, | ||||||
| "num_turns": num_turns, | ||||||
| "acc": acc, | ||||||
| "turn_scores": turn_scores, | ||||||
| "reward_extra_info": reward_extra_info, | ||||||
| "raw_prompt": raw_prompt, | ||||||
| "tool_rewards": tool_rewards, | ||||||
| }, | ||||||
| meta_info={"global_steps": 42}, | ||||||
| ) | ||||||
|
|
||||||
| # THE KEY TEST: This should not raise ValueError about TensorDict conversion | ||||||
| tensordict_output = data.to_tensordict() | ||||||
|
|
||||||
| # Verify tensors are accessible | ||||||
| assert torch.all(torch.eq(tensordict_output["prompts"], prompts)).item() | ||||||
| assert torch.all(torch.eq(tensordict_output["responses"], responses)).item() | ||||||
|
|
||||||
| # Verify all nested structures are accessible (content check, not type check) | ||||||
| assert len(tensordict_output["turn_scores"]) == 2 | ||||||
| assert list(tensordict_output["turn_scores"][0]) == [] | ||||||
| assert list(tensordict_output["turn_scores"][1]) == [0.5, 0.8] | ||||||
|
|
||||||
| assert len(tensordict_output["reward_extra_info"]) == 2 | ||||||
| assert dict(tensordict_output["reward_extra_info"][0]) == {"acc": 1.0} | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This assertion is incorrect for the same reason as noted in another comment. The expression `dict(tensordict_output["reward_extra_info"][0])` will raise a `ValueError`. To correctly verify the content, you should compare the `.data` attribute of the retrieved item with the expected dictionary.
Suggested change
|
||||||
|
|
||||||
| assert len(tensordict_output["raw_prompt"]) == 2 | ||||||
| assert dict(tensordict_output["raw_prompt"][0][0]) == {"content": "Compute 4 @ 2", "role": "user"} | ||||||
|
|
||||||
| assert len(tensordict_output["tool_rewards"]) == 2 | ||||||
| assert list(tensordict_output["tool_rewards"][0]) == [0.0] | ||||||
| assert list(tensordict_output["tool_rewards"][1]) == [] | ||||||
|
|
||||||
| # Verify round-trip conversion works perfectly | ||||||
| reconstructed = DataProto.from_tensordict(tensordict_output) | ||||||
| assert len(reconstructed) == 2 | ||||||
| assert reconstructed.meta_info["global_steps"] == 42 | ||||||
| assert torch.all(torch.eq(reconstructed.batch["prompts"], prompts)).item() | ||||||
|
|
||||||
|
|
||||||
| def test_serialize_deserialize_single_tensor(): | ||||||
| """Test serialization and deserialization of a single tensor""" | ||||||
| # Create test tensor | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The assertion `dict(retrieved_info[i]) == expected_dict` is incorrect. `retrieved_info[i]` is a `tensordict.NonTensorData` object that wraps a dictionary. Calling `dict()` on this object will raise a `ValueError` because its iterator yields only the keys of the wrapped dictionary, not the required key-value pairs.
The correct way to perform this comparison is to either access the underlying data via the `.data` attribute or, more idiomatically, rely on the overloaded `__eq__` method of the `NonTensorData` object, which compares its content directly.