diff --git a/tests/test_protocol_v2_on_cpu.py b/tests/test_protocol_v2_on_cpu.py index ba2534aa094..294783fab5e 100644 --- a/tests/test_protocol_v2_on_cpu.py +++ b/tests/test_protocol_v2_on_cpu.py @@ -22,6 +22,7 @@ import numpy as np import pytest import torch +from tensordict.tensorclass import NonTensorData, NonTensorStack from verl.utils import tensordict_utils as tu @@ -45,10 +46,10 @@ def test_union_tensor_dict(): # conflict in tensor values tu.union_tensor_dict(data1, data_with_copied_obs) - data1 = tu.assign_non_tensor_dict(data1, meta_info1) + data1 = tu.assign_non_tensor(data1, **meta_info1) tu.union_tensor_dict(data1, data2) # works ok - data2 = tu.assign_non_tensor_dict(data2, meta_info2) + data2 = tu.assign_non_tensor(data2, **meta_info2) with pytest.raises(AssertionError): # conflict in NonTensorData @@ -651,3 +652,203 @@ def test_concat_tensordict(): # make sure tensordict1 and tensordict2 is untouched tu.assert_tensordict_eq(tensordict1, tensordict1_copy) tu.assert_tensordict_eq(tensordict2, tensordict2_copy) + + +def test_assign_non_tensor_stack_with_nested_lists(): + """Test assign_non_tensor_stack with lists of lists.""" + td = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={}) + + # Lists of varying lengths (like turn_scores or tool_rewards) + turn_scores = [[], [0.5, 0.8], [0.9]] + tu.assign_non_tensor_stack(td, "turn_scores", turn_scores) + + # Verify data is accessible + assert len(td["turn_scores"]) == 3 + assert list(td["turn_scores"][0]) == [] + assert list(td["turn_scores"][1]) == [0.5, 0.8] + assert list(td["turn_scores"][2]) == [0.9] + + +def test_assign_non_tensor_stack_with_nested_dicts(): + """Test assign_non_tensor_stack with lists of dicts.""" + td = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={}) + + # Lists of dicts (like reward_extra_info) + reward_extra_info = [{"acc": 1.0, "loss": 0.1}, {"acc": 0.0, "loss": 0.9}, {"acc": 1.0, "loss": 0.05}] + tu.assign_non_tensor_stack(td, "reward_extra_info", 
reward_extra_info) + + # Verify data is accessible + assert len(td["reward_extra_info"]) == 3 + assert dict(td["reward_extra_info"][0]) == {"acc": 1.0, "loss": 0.1} + assert dict(td["reward_extra_info"][1]) == {"acc": 0.0, "loss": 0.9} + assert dict(td["reward_extra_info"][2]) == {"acc": 1.0, "loss": 0.05} + + +def test_assign_non_tensor_stack_with_complex_nested(): + """Test assign_non_tensor_stack with lists of lists of dicts.""" + td = tu.get_tensordict({"obs": torch.randn(2, 4)}, non_tensor_dict={}) + + # Lists of lists of dicts (like raw_prompt) + raw_prompt = [ + [{"content": "Question 1", "role": "user"}], + [{"content": "Question 2", "role": "user"}, {"content": "Answer 2", "role": "assistant"}], + ] + tu.assign_non_tensor_stack(td, "raw_prompt", raw_prompt) + + # Verify data is accessible + assert len(td["raw_prompt"]) == 2 + assert len(td["raw_prompt"][0]) == 1 + assert dict(td["raw_prompt"][0][0]) == {"content": "Question 1", "role": "user"} + assert len(td["raw_prompt"][1]) == 2 + assert dict(td["raw_prompt"][1][0]) == {"content": "Question 2", "role": "user"} + + +def test_assign_non_tensor_handles_wrappers(): + td = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={}) + + meta = {"top_p": 0.8} + tu.assign_non_tensor(td, **meta) + assert td["top_p"] == 0.8 + + wrapped = NonTensorData(0.3) + stack = NonTensorStack.from_list([NonTensorData(1.0), NonTensorData(2.0), NonTensorData(3.0)]) + tu.assign_non_tensor(td, wrapped=wrapped, stack=stack) + + assert td["wrapped"] == 0.3 + assert td["stack"] == [1.0, 2.0, 3.0] + + +def test_assign_non_tensor_stack_batch_size_check(): + td = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={}) + stack = NonTensorStack.from_list([NonTensorData(1.0), NonTensorData(2.0)]) + + with pytest.raises(RuntimeError): + tu.assign_non_tensor(td, stack=stack) + + +def test_assign_non_tensor_with_auto_detection(): + """Test assign_non_tensor automatically detects and handles nested structures.""" + td = 
tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={}) + + # Mix of simple and nested data + tu.assign_non_tensor( + td, + metadata="experiment_1", # Simple value + turn_scores=[[], [0.5, 0.8], [0.9]], # Nested list + reward_extra_info=[{"acc": 1.0}, {"acc": 0.0}, {"acc": 1.0}], # List of dicts + simple_list=["a", "b", "c"], # Simple list (also uses NonTensorStack for consistency) + ) + + # Verify all data is accessible + assert td["metadata"] == "experiment_1" + assert len(td["turn_scores"]) == 3 + assert list(td["turn_scores"][1]) == [0.5, 0.8] + assert len(td["reward_extra_info"]) == 3 + assert dict(td["reward_extra_info"][0]) == {"acc": 1.0} + assert len(td["simple_list"]) == 3 + assert td["simple_list"][0] == "a" + + +def test_get_tensordict_with_nested_lists(): + """Test get_tensordict automatically handles nested lists.""" + obs = torch.randn(3, 4) + turn_scores = [[], [0.5, 0.8], [0.9]] + + # This should automatically convert turn_scores to NonTensorStack + td = tu.get_tensordict({"obs": obs, "turn_scores": turn_scores}) + + # Verify tensors and nested data are both accessible + assert torch.all(torch.eq(td["obs"], obs)) + assert len(td["turn_scores"]) == 3 + assert list(td["turn_scores"][0]) == [] + assert list(td["turn_scores"][1]) == [0.5, 0.8] + + +def test_get_tensordict_with_nested_dicts(): + """Test get_tensordict automatically handles lists of dicts.""" + obs = torch.randn(3, 4) + reward_extra_info = [{"acc": 1.0}, {"acc": 0.0}, {"acc": 1.0}] + + td = tu.get_tensordict({"obs": obs, "reward_extra_info": reward_extra_info}) + + assert torch.all(torch.eq(td["obs"], obs)) + assert len(td["reward_extra_info"]) == 3 + assert dict(td["reward_extra_info"][0]) == {"acc": 1.0} + + +def test_get_tensordict_with_complex_nested_structures(): + """Test get_tensordict with lists of lists of dicts.""" + obs = torch.randn(2, 4) + raw_prompt = [ + [{"content": "Q1", "role": "user"}], + [{"content": "Q2", "role": "user"}, {"content": "A2", "role": 
"assistant"}], + ] + + td = tu.get_tensordict({"obs": obs, "raw_prompt": raw_prompt}) + + assert torch.all(torch.eq(td["obs"], obs)) + assert len(td["raw_prompt"]) == 2 + assert dict(td["raw_prompt"][0][0]) == {"content": "Q1", "role": "user"} + + +def test_get_tensordict_agent_loop_scenario(): + """Test the complete agent loop scenario with all nested types. + + This simulates the exact use case from agent loops with: + - turn_scores: lists of lists + - reward_extra_info: lists of dicts + - raw_prompt: lists of lists of dicts + - tool_rewards: lists of lists + """ + prompts = torch.randn(2, 10) + responses = torch.randn(2, 5) + + # Nested structures from agent loop + data_source = ["lighteval/MATH", "lighteval/MATH"] + uid = ["uuid-1", "uuid-2"] + turn_scores = [[], [0.5, 0.8]] # Lists of varying lengths + reward_extra_info = [{"acc": 1.0, "loss": 0.1}, {"acc": 0.0, "loss": 0.9}] + raw_prompt = [ + [{"content": "Compute 4 @ 2", "role": "user"}], + [{"content": "Compute 8 @ 7", "role": "user"}], + ] + tool_rewards = [[0.0], []] # List of lists + + # This should handle all nested structures automatically + td = tu.get_tensordict( + tensor_dict={ + "prompts": prompts, + "responses": responses, + "data_source": data_source, + "uid": uid, + "turn_scores": turn_scores, + "reward_extra_info": reward_extra_info, + "raw_prompt": raw_prompt, + "tool_rewards": tool_rewards, + }, + non_tensor_dict={"global_steps": 42}, + ) + + # Verify all data types are accessible + assert torch.all(torch.eq(td["prompts"], prompts)) + assert torch.all(torch.eq(td["responses"], responses)) + assert td["data_source"] == data_source + assert td["uid"] == uid + + # Verify nested structures + assert len(td["turn_scores"]) == 2 + assert list(td["turn_scores"][0]) == [] + assert list(td["turn_scores"][1]) == [0.5, 0.8] + + assert len(td["reward_extra_info"]) == 2 + assert dict(td["reward_extra_info"][0]) == {"acc": 1.0, "loss": 0.1} + + assert len(td["raw_prompt"]) == 2 + assert 
dict(td["raw_prompt"][0][0]) == {"content": "Compute 4 @ 2", "role": "user"} + + assert len(td["tool_rewards"]) == 2 + assert list(td["tool_rewards"][0]) == [0.0] + assert list(td["tool_rewards"][1]) == [] + + # Verify metadata + assert td["global_steps"] == 42 diff --git a/verl/utils/tensordict_utils.py b/verl/utils/tensordict_utils.py index 25f828dbce7..fcbe4e29290 100644 --- a/verl/utils/tensordict_utils.py +++ b/verl/utils/tensordict_utils.py @@ -20,20 +20,71 @@ from tensordict.tensorclass import NonTensorData, NonTensorStack -def assign_non_tensor_dict(tensor_dict: TensorDict, non_tensor_dict: dict): - for key, val in non_tensor_dict.items(): - assign_non_tensor_data(tensor_dict=tensor_dict, key=key, val=val) - return tensor_dict - - def assign_non_tensor_data(tensor_dict: TensorDict, key, val): + assert isinstance(tensor_dict, TensorDict), "input dict must be a TensorDict" tensor_dict[key] = NonTensorData(val) -def assign_non_tensor(tensordict: TensorDict, **kwargs): +def assign_non_tensor_stack(tensor_dict: TensorDict, key, val: list): + """Assign a list with potentially nested structures (lists, dicts, etc.) to TensorDict. + + This function handles complex nested data structures like: + - Lists of lists: [[], [0.5, 0.8], [0.9]] + - Lists of dicts: [{"acc": 1.0}, {"acc": 0.0}] + - Lists of lists of dicts: [[{"content": "...", "role": "user"}]] + + These structures are wrapped in NonTensorStack so TensorDict can handle them correctly. 
+ + Args: + tensor_dict: The TensorDict to assign to + key: The key to assign the value under + val: A list containing potentially nested structures + + Example: + >>> td = TensorDict({}, batch_size=[]) + >>> turn_scores = [[], [0.5, 0.8], [0.9]] + >>> assign_non_tensor_stack(td, "turn_scores", turn_scores) + >>> # Now td["turn_scores"] contains the nested data + """ + # Convert list to NonTensorStack to handle nested structures + # This wraps each item in NonTensorData to preserve complex objects + # TODO(petersh6): can convert back to val directly if we are not accessing .data from the NonTensorStack + assert isinstance(tensor_dict, TensorDict), "input dict must be a TensorDict" + tensor_dict[key] = NonTensorStack.from_list([NonTensorData(item) for item in val]) + + +def assign_non_tensor(tensor_dict: TensorDict, **kwargs): + """Assign non-tensor data to a TensorDict. + + Dispatches on the type of each value: NonTensorData / NonTensorStack inputs + are stored as-is, every list (flat or nested) is stored as a NonTensorStack, + and any other value is wrapped in NonTensorData. + + Args: + tensor_dict: The TensorDict to assign to + **kwargs: Key-value pairs where values can be: + - Non-list values (stored as NonTensorData) + - Lists, flat or nested (stored as NonTensorStack, one entry per item) + + Example: + >>> td = TensorDict({"obs": torch.randn(3, 4)}, batch_size=[3]) + >>> assign_non_tensor( + ... tensor_dict=td, + ... metadata="experiment_1", # Simple value + ... turn_scores=[[], [0.5, 0.8], [0.9]] # Nested list + ... 
) + """ + assert isinstance(tensor_dict, TensorDict), "input dict must be a TensorDict" for key, val in kwargs.items(): - assign_non_tensor_data(tensor_dict=tensordict, key=key, val=val) - return tensordict + if isinstance(val, (NonTensorData | NonTensorStack)): + tensor_dict[key] = val + elif isinstance(val, list): + # For lists, use NonTensorStack + assign_non_tensor_stack(tensor_dict=tensor_dict, key=key, val=val) + else: + # For non-list values, use NonTensorData + assign_non_tensor_data(tensor_dict=tensor_dict, key=key, val=val) + return tensor_dict def unwrap_non_tensor_data(data): @@ -92,15 +143,31 @@ def concat_tensordict(data: list[TensorDict]) -> TensorDict: def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> TensorDict: - """ + """Create a TensorDict from tensors and non-tensor data. + + Automatically handles nested structures in lists by converting them to NonTensorStack. + This enables support for: + - Lists of lists: [[], [0.5, 0.8], [0.9]] + - Lists of dicts: [{"acc": 1.0}, {"acc": 0.0}] + - Lists of lists of dicts: [[{"content": "...", "role": "user"}]] Args: - data_dict: - meta_info: + tensor_dict: Dictionary of tensors and lists to include in the TensorDict + non_tensor_dict: Dictionary of metadata to store as NonTensorData Returns: - + TensorDict with proper handling of nested structures + + Example: + >>> td = get_tensordict( + ... tensor_dict={ + ... "obs": torch.randn(3, 4), + ... "turn_scores": [[], [0.5, 0.8], [0.9]] # Nested list + ... }, + ... non_tensor_dict={"experiment": "test"} + ... ) """ + tensor_dict = tensor_dict.copy() if non_tensor_dict is None: non_tensor_dict = {} @@ -127,6 +194,9 @@ def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: "Passing a list makes the data NonTensorStack, " "which doesn't support torch.Tensor. 
Please convert to numpy first" ) + # Convert to NonTensorStack to handle nested structures + tensor_dict[key] = NonTensorStack.from_list([NonTensorData(item) for item in val]) + assert isinstance(val, torch.Tensor | list) if batch_size is None: