From 24a9a3b26e15a6eee9650dfe09fca81d25be4d5d Mon Sep 17 00:00:00 2001 From: Chi Zhang Date: Thu, 27 Nov 2025 19:14:53 +0800 Subject: [PATCH 1/5] update --- tests/test_protocol_v2_on_cpu.py | 61 ++++++++++++++++++++++++++++---- verl/utils/tensordict_utils.py | 56 +++++++++++++++++++++++++++-- 2 files changed, 107 insertions(+), 10 deletions(-) diff --git a/tests/test_protocol_v2_on_cpu.py b/tests/test_protocol_v2_on_cpu.py index 294783fab5e..c8e7941a195 100644 --- a/tests/test_protocol_v2_on_cpu.py +++ b/tests/test_protocol_v2_on_cpu.py @@ -328,17 +328,64 @@ def test_chunk_concat(): def test_pop(): - obs = torch.randn(100, 10) - act = torch.randn(100, 3) - dataset = tu.get_tensordict({"obs": obs, "act": act}, non_tensor_dict={"2": 2, "1": 1}) + obs = torch.randn(3, 10) + act = torch.randn(3, 3) + labels = ["a", ["b"], []] + dataset = tu.get_tensordict({"obs": obs, "act": act, "labels": labels}, non_tensor_dict={"2": 2, "1": 1}) + + dataset1 = copy.deepcopy(dataset) + + # test pop keys + popped_dataset = tu.pop_keys(dataset, keys=["obs", "2"]) + + assert popped_dataset.batch_size[0] == 3 + + assert popped_dataset.keys() == {"obs", "2"} + + assert dataset.keys() == {"act", "1", "labels"} + + # test pop non-exist key + with pytest.raises(KeyError): + tu.pop_keys(dataset, keys=["obs", "2"]) + + # test single pop + # NonTensorData + assert tu.pop(dataset1, key="2") == 2 + # NonTensorStack + assert tu.pop(dataset1, key="labels") == ['a', ['b'], []] + # Tensor + assert torch.all(torch.eq(tu.pop(dataset1, key='obs'), obs)).item() + + +def test_get(): + obs = torch.randn(3, 10) + act = torch.randn(3, 3) + labels = ["a", ["b"], []] + dataset = tu.get_tensordict({"obs": obs, "act": act, "labels": labels}, non_tensor_dict={"2": 2, "1": 1}) + + # test pop keys + popped_dataset = tu.get_keys(dataset, keys=["obs", "2"]) + + assert popped_dataset.batch_size[0] == 3 + + assert torch.all(torch.eq(popped_dataset['obs'], dataset['obs'])).item() - poped_dataset = tu.pop(dataset, keys=["obs", "2"]) + assert popped_dataset['2'] == dataset['2'] - assert poped_dataset.batch_size[0] == 100 + # test pop non-exist key + with pytest.raises(KeyError): + tu.get_keys(dataset, keys=["obs", "3"]) - assert poped_dataset.keys() == {"obs", "2"} + # test single pop + # NonTensorData + assert tu.get(dataset, key="2") == 2 + # NonTensorStack + assert tu.get(dataset, key="labels") == ['a', ['b'], []] + # Tensor + assert torch.all(torch.eq(tu.get(dataset, key='obs'), obs)).item() + # Non-exist key + assert tu.get(dataset, key="3", default=3) == 3 - assert dataset.keys() == {"act", "1"} def test_repeat(): diff --git a/verl/utils/tensordict_utils.py b/verl/utils/tensordict_utils.py index fcbe4e29290..8e8ccc68fdc 100644 --- a/verl/utils/tensordict_utils.py +++ b/verl/utils/tensordict_utils.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Iterator +from typing import Iterator, Iterable, Any import torch from tensordict import TensorDict @@ -256,7 +256,8 @@ def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> Ten ) for key in tensor_dict2.keys(): if key not in tensor_dict1.keys(): - tensor_dict1[key] = tensor_dict2[key] + # Note that there is a difference between tensor_dict2[key] and tensor_dict2.get(key) + tensor_dict1[key] = tensor_dict2.get(key) else: if isinstance(tensor_dict2[key], torch.Tensor): assert tensor_dict1[key].equal(tensor_dict2[key]), ( @@ -325,10 +326,59 @@ def assert_tensordict_eq(tensordict1: TensorDict, tensordict2: TensorDict): assert val == val2 -def pop(tensordict: TensorDict, keys: Iterator[str]) -> TensorDict: + +def get(tensordict: TensorDict, key: str, default=None) -> Any: + if key not in tensordict: + return default + + output = tensordict.get(key) + if isinstance(output, torch.Tensor): + return output + elif isinstance(output, NonTensorStack): + return output.tolist() + else: + assert isinstance(output, NonTensorData) + return output.data + + +def get_keys(tensordict: TensorDict, keys: Iterable[str]) -> TensorDict: + tensor_output = {} + non_tensor_output = {} + for key in keys: + if key not in tensordict.keys(): + raise KeyError(f"key {key} not in tensordict") + output = tensordict.get(key) + if isinstance(output, torch.Tensor): + tensor_output[key] = output + elif isinstance(output, NonTensorStack): + tensor_output[key] = output.tolist() + else: + assert isinstance(output, NonTensorData) + non_tensor_output[key] = output.data + + return get_tensordict(tensor_output, non_tensor_output) + + +def pop(tensordict: TensorDict, key: str, default=None) -> Any: + if key not in tensordict.keys(): + return default + + output = tensordict.get(key) + if isinstance(output, torch.Tensor): + return tensordict.pop(key) + elif isinstance(output, NonTensorStack): + return tensordict.pop(key).tolist() + else: + assert isinstance(output, NonTensorData) + return tensordict.pop(key).data + + +def pop_keys(tensordict: TensorDict, keys: Iterable[str]) -> TensorDict: tensor_output = {} non_tensor_output = {} for key in keys: + if key not in tensordict.keys(): + raise KeyError(f"key {key} not in tensordict") output = tensordict.get(key) if isinstance(output, torch.Tensor): tensor_output[key] = tensordict.pop(key) From 07b14687aaa376bee46e21b0be808b361d46b39c Mon Sep 17 00:00:00 2001 From: Chi Zhang Date: Thu, 27 Nov 2025 19:15:35 +0800 Subject: [PATCH 2/5] fix precommit --- tests/test_protocol_v2_on_cpu.py | 13 ++++++------- verl/utils/tensordict_utils.py | 3 +-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/test_protocol_v2_on_cpu.py b/tests/test_protocol_v2_on_cpu.py index c8e7941a195..f99e56fdde8 100644 --- a/tests/test_protocol_v2_on_cpu.py +++ b/tests/test_protocol_v2_on_cpu.py @@ -352,9 +352,9 @@ def test_pop(): # NonTensorData assert tu.pop(dataset1, key="2") == 2 # NonTensorStack - assert tu.pop(dataset1, key="labels") == ['a', ['b'], []] + assert tu.pop(dataset1, key="labels") == ["a", ["b"], []] # Tensor - assert torch.all(torch.eq(tu.pop(dataset1, key='obs'), obs)).item() + assert torch.all(torch.eq(tu.pop(dataset1, key="obs"), obs)).item() def test_get(): @@ -368,9 +368,9 @@ def test_get(): assert popped_dataset.batch_size[0] == 3 - assert torch.all(torch.eq(popped_dataset['obs'], dataset['obs'])).item() + assert torch.all(torch.eq(popped_dataset["obs"], dataset["obs"])).item() - assert popped_dataset['2'] == dataset['2'] + assert popped_dataset["2"] == dataset["2"] # test pop non-exist key with pytest.raises(KeyError): @@ -380,14 +380,13 @@ def test_get(): # NonTensorData assert tu.get(dataset, key="2") == 2 # NonTensorStack - assert tu.get(dataset, key="labels") == ['a', ['b'], []] + assert tu.get(dataset, key="labels") == ["a", ["b"], []] # Tensor - assert torch.all(torch.eq(tu.get(dataset, key='obs'), obs)).item() + assert torch.all(torch.eq(tu.get(dataset, key="obs"), obs)).item() # Non-exist key assert tu.get(dataset, key="3", default=3) == 3 - def test_repeat(): # Create a DataProto object with some batch and non-tensor data obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) diff --git a/verl/utils/tensordict_utils.py b/verl/utils/tensordict_utils.py index 8e8ccc68fdc..62d035e4169 100644 --- a/verl/utils/tensordict_utils.py +++ b/verl/utils/tensordict_utils.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Iterator, Iterable, Any +from typing import Any, Iterable import torch from tensordict import TensorDict @@ -326,7 +326,6 @@ def assert_tensordict_eq(tensordict1: TensorDict, tensordict2: TensorDict): assert val == val2 - def get(tensordict: TensorDict, key: str, default=None) -> Any: if key not in tensordict: return default From f9c1504a859d9e00a3614b54517bf6e326f8cf41 Mon Sep 17 00:00:00 2001 From: Guangming Sheng Date: Thu, 27 Nov 2025 19:24:31 +0800 Subject: [PATCH 3/5] Update test_protocol_v2_on_cpu.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/test_protocol_v2_on_cpu.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_protocol_v2_on_cpu.py b/tests/test_protocol_v2_on_cpu.py index f99e56fdde8..cb75de4314e 100644 --- a/tests/test_protocol_v2_on_cpu.py +++ b/tests/test_protocol_v2_on_cpu.py @@ -341,6 +341,8 @@ def test_pop(): assert popped_dataset.batch_size[0] == 3 assert popped_dataset.keys() == {"obs", "2"} + assert torch.all(torch.eq(popped_dataset["obs"], obs)).item() + assert popped_dataset["2"] == 2 assert dataset.keys() == {"act", "1", "labels"} From 09667a9c777514fff4997d85b821f07f7e9f1ad1 Mon Sep 17 00:00:00 2001 From: Guangming Sheng Date: Thu, 27 Nov 2025 19:24:45 +0800 Subject: [PATCH 4/5] Update tensordict_utils.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- verl/utils/tensordict_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/verl/utils/tensordict_utils.py b/verl/utils/tensordict_utils.py index 62d035e4169..bf7ce4fa068 100644 --- a/verl/utils/tensordict_utils.py +++ b/verl/utils/tensordict_utils.py @@ -359,17 +359,18 @@ def get_keys(tensordict: TensorDict, keys: Iterable[str]) -> TensorDict: def pop(tensordict: TensorDict, key: str, default=None) -> Any: - if key not in tensordict.keys(): + _sentinel = object() + output = tensordict.pop(key, _sentinel) + if output is _sentinel: return default - output = tensordict.get(key) if isinstance(output, torch.Tensor): - return tensordict.pop(key) + return output elif isinstance(output, NonTensorStack): - return tensordict.pop(key).tolist() + return output.tolist() else: assert isinstance(output, NonTensorData) - return tensordict.pop(key).data + return output.data def pop_keys(tensordict: TensorDict, keys: Iterable[str]) -> TensorDict: From f19f27424c9bd0723c8384f93f196a2dec992321 Mon Sep 17 00:00:00 2001 From: Chi Zhang Date: Thu, 27 Nov 2025 19:26:48 +0800 Subject: [PATCH 5/5] fix --- tests/test_protocol_v2_on_cpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_protocol_v2_on_cpu.py b/tests/test_protocol_v2_on_cpu.py index cb75de4314e..ea5d6755fc7 100644 --- a/tests/test_protocol_v2_on_cpu.py +++ b/tests/test_protocol_v2_on_cpu.py @@ -579,7 +579,7 @@ def test_dataproto_no_batch(): selected = data.select("labels") assert selected["labels"] == labels - pop_data = tu.pop(data, keys=["labels"]) + pop_data = tu.pop_keys(data, keys=["labels"]) assert pop_data["labels"] == labels assert "labels" not in data