diff --git a/tests/test_protocol_on_cpu.py b/tests/test_protocol_on_cpu.py
index 099fde53a3f..1874e834f0a 100644
--- a/tests/test_protocol_on_cpu.py
+++ b/tests/test_protocol_on_cpu.py
@@ -779,6 +779,217 @@ def test_from_tensordict():
     assert data.meta_info["name"] == "abdce"
 
 
+@pytest.mark.skipif(
+    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
+)
+def test_to_tensordict_with_nested_lists():
+    """Test converting DataProto with nested lists to TensorDict (lists of lists)."""
+    obs = torch.tensor([1, 2, 3])
+    # Simulate turn_scores or tool_rewards: array of lists with varying lengths
+    turn_scores = [[], [0.5, 0.8], [0.9]]
+
+    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"turn_scores": turn_scores})
+
+    # This should not raise an error
+    tensordict_output = data.to_tensordict()
+
+    # Verify the data is preserved
+    assert torch.all(torch.eq(tensordict_output["obs"], obs)).item()
+    # Verify nested structure is accessible (TensorDict wraps NonTensorStack as LinkedList)
+    retrieved_scores = tensordict_output["turn_scores"]
+    assert len(retrieved_scores) == len(turn_scores)
+    # Verify content matches
+    assert list(retrieved_scores[0]) == []
+    assert list(retrieved_scores[1]) == [0.5, 0.8]
+    assert list(retrieved_scores[2]) == [0.9]
+
+
+@pytest.mark.skipif(
+    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
+)
+def test_to_tensordict_with_nested_dicts():
+    """Test converting DataProto with lists of dicts to TensorDict."""
+    obs = torch.tensor([1, 2, 3])
+    # Simulate reward_extra_info: array of dicts
+    reward_extra_info = [{"acc": 1.0}, {"acc": 0.0}, {"acc": 1.0}]
+
+    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"reward_extra_info": reward_extra_info})
+
+    # This should not raise an error - this was the original bug
+    tensordict_output = data.to_tensordict()
+
+    # Verify the data is preserved
+    assert torch.all(torch.eq(tensordict_output["obs"], obs)).item()
+    # Verify nested dicts are accessible
+    retrieved_info = tensordict_output["reward_extra_info"]
+    assert len(retrieved_info) == len(reward_extra_info)
+    # Verify content matches
+    for i, expected_dict in enumerate(reward_extra_info):
+        assert dict(retrieved_info[i]) == expected_dict
+
+
+@pytest.mark.skipif(
+    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
+)
+def test_to_tensordict_with_complex_nested_structures():
+    """Test converting DataProto with complex nested structures (lists of lists of dicts)."""
+    obs = torch.tensor([1, 2, 3])
+    # Simulate raw_prompt: array of lists containing dicts
+    raw_prompt = [
+        [{"content": "Question 1", "role": "user"}],
+        [{"content": "Question 2", "role": "user"}, {"content": "Answer 2", "role": "assistant"}],
+        [{"content": "Question 3", "role": "user"}],
+    ]
+
+    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"raw_prompt": raw_prompt})
+
+    # This should not raise an error
+    tensordict_output = data.to_tensordict()
+
+    # Verify the data is preserved
+    assert torch.all(torch.eq(tensordict_output["obs"], obs)).item()
+    # Verify complex nested structure is accessible
+    retrieved_prompt = tensordict_output["raw_prompt"]
+    assert len(retrieved_prompt) == len(raw_prompt)
+    # Spot check: verify first prompt has correct structure
+    assert len(retrieved_prompt[0]) == 1
+    assert dict(retrieved_prompt[0][0]) == {"content": "Question 1", "role": "user"}
+
+
+@pytest.mark.skipif(
+    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
+)
+def test_to_tensordict_and_back_with_nested_data():
+    """Test round-trip conversion: DataProto → TensorDict → DataProto with nested structures."""
+    obs = torch.tensor([1, 2, 3, 4])
+    labels = ["a", "b", "c", "d"]
+
+    # Multiple types of nested structures
+    turn_scores = [[], [0.5], [0.8, 0.9], [0.7]]
+    reward_extra_info = [
+        {"acc": 1.0, "loss": 0.1},
+        {"acc": 0.5, "loss": 0.3},
+        {"acc": 1.0, "loss": 0.05},
+        {"acc": 0.0, "loss": 0.9},
+    ]
+    raw_prompt = [
+        [{"content": "Q1", "role": "user"}],
+        [{"content": "Q2", "role": "user"}],
+        [{"content": "Q3", "role": "user"}, {"content": "A3", "role": "assistant"}],
+        [{"content": "Q4", "role": "user"}],
+    ]
+
+    # Create original DataProto
+    original_data = DataProto.from_dict(
+        tensors={"obs": obs},
+        non_tensors={
+            "labels": labels,
+            "turn_scores": turn_scores,
+            "reward_extra_info": reward_extra_info,
+            "raw_prompt": raw_prompt,
+        },
+        meta_info={"experiment": "test_nested"},
+    )
+
+    # Convert to TensorDict
+    tensordict_output = original_data.to_tensordict()
+
+    # Convert back to DataProto
+    reconstructed_data = DataProto.from_tensordict(tensordict_output)
+
+    # Verify tensors are preserved
+    assert torch.all(torch.eq(reconstructed_data.batch["obs"], obs)).item()
+
+    # Verify non-tensor data is preserved
+    assert reconstructed_data.non_tensor_batch["labels"].tolist() == labels
+
+    # Verify nested structures are preserved
+    assert len(reconstructed_data.non_tensor_batch["turn_scores"]) == len(turn_scores)
+    for orig, recon in zip(turn_scores, reconstructed_data.non_tensor_batch["turn_scores"], strict=True):
+        assert list(orig) == list(recon)
+
+    assert len(reconstructed_data.non_tensor_batch["reward_extra_info"]) == len(reward_extra_info)
+    for orig, recon in zip(reward_extra_info, reconstructed_data.non_tensor_batch["reward_extra_info"], strict=True):
+        assert orig == recon
+
+    assert len(reconstructed_data.non_tensor_batch["raw_prompt"]) == len(raw_prompt)
+    for orig, recon in zip(raw_prompt, reconstructed_data.non_tensor_batch["raw_prompt"], strict=True):
+        assert orig == list(recon)
+
+    # Verify meta_info is preserved
+    assert reconstructed_data.meta_info["experiment"] == "test_nested"
+
+
+@pytest.mark.skipif(
+    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
+)
+def test_to_tensordict_agent_loop_scenario():
+    """Test the exact scenario from agent loop: DataProto with tool rewards, acc, etc.
+
+    This test reproduces the exact error from the agent loop where nested structures
+    (lists of lists, lists of dicts) failed to convert to TensorDict.
+    """
+    # Simulate real agent loop data structure
+    prompts = torch.tensor([[1, 2, 3], [4, 5, 6]])
+    responses = torch.tensor([[7, 8], [9, 10]])
+
+    # Non-tensor data with nested structures from agent loop
+    data_source = ["lighteval/MATH", "lighteval/MATH"]
+    uid = ["uuid-1", "uuid-2"]
+    num_turns = np.array([2, 4], dtype=np.int32)
+    acc = np.array([1.0, 0.0])
+    turn_scores = [[], [0.5, 0.8]]  # Lists of varying lengths
+    reward_extra_info = [{"acc": 1.0}, {"acc": 0.0}]  # List of dicts
+    raw_prompt = [
+        [{"content": "Compute 4 @ 2", "role": "user"}],
+        [{"content": "Compute 8 @ 7", "role": "user"}],
+    ]
+    tool_rewards = [[0.0], []]  # List of lists
+
+    data = DataProto.from_dict(
+        tensors={"prompts": prompts, "responses": responses},
+        non_tensors={
+            "data_source": data_source,
+            "uid": uid,
+            "num_turns": num_turns,
+            "acc": acc,
+            "turn_scores": turn_scores,
+            "reward_extra_info": reward_extra_info,
+            "raw_prompt": raw_prompt,
+            "tool_rewards": tool_rewards,
+        },
+        meta_info={"global_steps": 42},
+    )
+
+    # THE KEY TEST: This should not raise ValueError about TensorDict conversion
+    tensordict_output = data.to_tensordict()
+
+    # Verify tensors are accessible
+    assert torch.all(torch.eq(tensordict_output["prompts"], prompts)).item()
+    assert torch.all(torch.eq(tensordict_output["responses"], responses)).item()
+
+    # Verify all nested structures are accessible (content check, not type check)
+    assert len(tensordict_output["turn_scores"]) == 2
+    assert list(tensordict_output["turn_scores"][0]) == []
+    assert list(tensordict_output["turn_scores"][1]) == [0.5, 0.8]
+
+    assert len(tensordict_output["reward_extra_info"]) == 2
+    assert dict(tensordict_output["reward_extra_info"][0]) == {"acc": 1.0}
+
+    assert len(tensordict_output["raw_prompt"]) == 2
+    assert dict(tensordict_output["raw_prompt"][0][0]) == {"content": "Compute 4 @ 2", "role": "user"}
+
+    assert len(tensordict_output["tool_rewards"]) == 2
+    assert list(tensordict_output["tool_rewards"][0]) == [0.0]
+    assert list(tensordict_output["tool_rewards"][1]) == []
+
+    # Verify round-trip conversion preserves batch size, meta_info and the prompts tensor
+    reconstructed = DataProto.from_tensordict(tensordict_output)
+    assert len(reconstructed) == 2
+    assert reconstructed.meta_info["global_steps"] == 42
+    assert torch.all(torch.eq(reconstructed.batch["prompts"], prompts)).item()
+
+
 def test_serialize_deserialize_single_tensor():
     """Test serialization and deserialization of a single tensor"""
     # Create test tensor
diff --git a/verl/protocol.py b/verl/protocol.py
index e0b1affe8f1..53d291b1fe6 100644
--- a/verl/protocol.py
+++ b/verl/protocol.py
@@ -1118,6 +1118,8 @@ def to_tensordict(self) -> TensorDict:
         tensor_batch = self.batch.to_dict()
         non_tensor_batch = self.non_tensor_batch
 
+        from tensordict.tensorclass import NonTensorData, NonTensorStack
+
         from verl.utils import tensordict_utils as tu
 
         common_keys = set(tensor_batch.keys()) & set(non_tensor_batch.keys())
@@ -1125,7 +1127,8 @@ def to_tensordict(self) -> TensorDict:
 
         for key, val in non_tensor_batch.items():
             assert isinstance(val, np.ndarray)
-            tensor_batch[key] = val.tolist()
+            # Wrap each item in NonTensorData and build a NonTensorStack (not a plain list) to handle nested structures
+            tensor_batch[key] = NonTensorStack.from_list([NonTensorData(item) for item in val])
 
         output = tu.get_tensordict(tensor_dict=tensor_batch, non_tensor_dict=self.meta_info)
         return output
diff --git a/verl/utils/tensordict_utils.py b/verl/utils/tensordict_utils.py
index fb26f27c8fb..25f828dbce7 100644
--- a/verl/utils/tensordict_utils.py
+++ b/verl/utils/tensordict_utils.py
@@ -110,6 +110,17 @@ def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict:
         if isinstance(val, torch.Tensor) and val.is_nested:
             assert val.is_contiguous(), "Nested tensors must be contiguous. Try setting layout=torch.jagged"
 
+        # NonTensorStack is already properly formatted: only check its batch size, then skip the checks below
+        if isinstance(val, NonTensorStack):
+            if batch_size is None:
+                batch_size = len(val)
+            else:
+                assert len(val) == batch_size, (
+                    f"Batch size of NonTensorStack {key} is not consistent with other tensors. "
+                    f"Expected {batch_size}, got {len(val)}"
+                )
+            continue
+
         if isinstance(val, list):
             for v in val:
                 assert not isinstance(v, torch.Tensor), (