Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 56 additions & 8 deletions tests/test_protocol_v2_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,17 +328,65 @@ def test_chunk_concat():


def test_pop():
    """Check ``tu.pop_keys`` (several keys at once) and ``tu.pop`` (one key) on a small dataset."""
    obs = torch.randn(3, 10)
    act = torch.randn(3, 3)
    labels = ["a", ["b"], []]
    dataset = tu.get_tensordict({"obs": obs, "act": act, "labels": labels}, non_tensor_dict={"2": 2, "1": 1})

    # Untouched copy: the single-key checks at the end need keys that
    # pop_keys removes from ``dataset`` below.
    dataset1 = copy.deepcopy(dataset)

    # test pop keys
    popped_dataset = tu.pop_keys(dataset, keys=["obs", "2"])

    assert popped_dataset.batch_size[0] == 3

    assert popped_dataset.keys() == {"obs", "2"}
    assert torch.all(torch.eq(popped_dataset["obs"], obs)).item()
    assert popped_dataset["2"] == 2

    # The popped keys must be gone from the source dataset.
    assert dataset.keys() == {"act", "1", "labels"}

    # test pop non-exist key ("obs" and "2" were already removed above)
    with pytest.raises(KeyError):
        tu.pop_keys(dataset, keys=["obs", "2"])

    # test single pop
    # NonTensorData
    assert tu.pop(dataset1, key="2") == 2
    # NonTensorStack
    assert tu.pop(dataset1, key="labels") == ["a", ["b"], []]
    # Tensor
    assert torch.all(torch.eq(tu.pop(dataset1, key="obs"), obs)).item()


def test_get():
    """Check ``tu.get_keys`` (several keys at once) and ``tu.get`` (one key) on a small dataset."""
    obs = torch.randn(3, 10)
    act = torch.randn(3, 3)
    labels = ["a", ["b"], []]
    dataset = tu.get_tensordict({"obs": obs, "act": act, "labels": labels}, non_tensor_dict={"2": 2, "1": 1})

    # test get keys
    selected = tu.get_keys(dataset, keys=["obs", "2"])

    assert selected.batch_size[0] == 3

    # Comparing against ``dataset`` itself also shows the keys are still
    # present in the source after get_keys.
    assert torch.all(torch.eq(selected["obs"], dataset["obs"])).item()

    assert selected["2"] == dataset["2"]

    # test get non-exist key
    with pytest.raises(KeyError):
        tu.get_keys(dataset, keys=["obs", "3"])

    # test single get
    # NonTensorData
    assert tu.get(dataset, key="2") == 2
    # NonTensorStack
    assert tu.get(dataset, key="labels") == ["a", ["b"], []]
    # Tensor
    assert torch.all(torch.eq(tu.get(dataset, key="obs"), obs)).item()
    # Non-exist key
    assert tu.get(dataset, key="3", default=3) == 3


def test_repeat():
Expand Down Expand Up @@ -531,7 +579,7 @@ def test_dataproto_no_batch():
selected = data.select("labels")

assert selected["labels"] == labels
pop_data = tu.pop(data, keys=["labels"])
pop_data = tu.pop_keys(data, keys=["labels"])
assert pop_data["labels"] == labels
assert "labels" not in data

Expand Down
56 changes: 53 additions & 3 deletions verl/utils/tensordict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Iterator
from typing import Any, Iterable

import torch
from tensordict import TensorDict
Expand Down Expand Up @@ -256,7 +256,8 @@ def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> Ten
)
for key in tensor_dict2.keys():
if key not in tensor_dict1.keys():
tensor_dict1[key] = tensor_dict2[key]
# Note that there is a difference between tensor_dict2[key] and tensor_dict2.get(key)
tensor_dict1[key] = tensor_dict2.get(key)
else:
if isinstance(tensor_dict2[key], torch.Tensor):
assert tensor_dict1[key].equal(tensor_dict2[key]), (
Expand Down Expand Up @@ -325,10 +326,59 @@ def assert_tensordict_eq(tensordict1: TensorDict, tensordict2: TensorDict):
assert val == val2


def pop(tensordict: TensorDict, keys: Iterator[str]) -> TensorDict:
def get(tensordict: TensorDict, key: str, default=None) -> Any:
    """Read one entry from ``tensordict`` and return it in unwrapped form.

    Args:
        tensordict: Container to read from. It is not modified.
        key: Name of the entry to fetch.
        default: Value returned when ``key`` is not present.

    Returns:
        ``default`` if ``key`` is absent. Otherwise the stored ``torch.Tensor``
        unchanged, the ``tolist()`` result of a ``NonTensorStack``, or the
        ``.data`` payload of a ``NonTensorData``.

    Raises:
        TypeError: If the stored entry is none of the three kinds above.
    """
    if key not in tensordict:
        return default

    # The branches below dispatch on the object handed back by ``.get()``.
    output = tensordict.get(key)
    if isinstance(output, torch.Tensor):
        return output
    if isinstance(output, NonTensorStack):
        return output.tolist()
    if isinstance(output, NonTensorData):
        return output.data
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    raise TypeError(f"unsupported value type {type(output).__name__} for key {key!r}")


def get_keys(tensordict: TensorDict, keys: Iterable[str]) -> TensorDict:
    """Build a new tensordict holding only the requested ``keys``.

    The source ``tensordict`` is only read, never modified.

    Args:
        tensordict: Container to read from.
        keys: Names of the entries to collect.

    Returns:
        The result of ``get_tensordict`` called with the collected entries.
        Tensors and ``NonTensorStack`` entries (converted with ``tolist()``)
        go in the first argument; ``NonTensorData`` payloads go in the second.

    Raises:
        KeyError: If any requested key is missing from ``tensordict``.
        TypeError: If a stored entry is not a tensor, ``NonTensorStack`` or
            ``NonTensorData``.
    """
    tensor_output = {}
    non_tensor_output = {}
    for key in keys:
        if key not in tensordict.keys():
            raise KeyError(f"key {key} not in tensordict")
        output = tensordict.get(key)
        if isinstance(output, torch.Tensor):
            tensor_output[key] = output
        elif isinstance(output, NonTensorStack):
            # Stacks are passed on as plain lists alongside the tensors.
            tensor_output[key] = output.tolist()
        elif isinstance(output, NonTensorData):
            non_tensor_output[key] = output.data
        else:
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            raise TypeError(f"unsupported value type {type(output).__name__} for key {key!r}")

    return get_tensordict(tensor_output, non_tensor_output)


def pop(tensordict: TensorDict, key: str, default=None) -> Any:
    """Remove one entry from ``tensordict`` and return it in unwrapped form.

    Args:
        tensordict: Container to remove the entry from (modified in place).
        key: Name of the entry to remove.
        default: Value returned when ``key`` is not present.

    Returns:
        ``default`` if ``key`` is absent. Otherwise the removed ``torch.Tensor``
        unchanged, the ``tolist()`` result of a ``NonTensorStack``, or the
        ``.data`` payload of a ``NonTensorData``.

    Raises:
        TypeError: If the removed entry is none of the three kinds above. The
            entry has already been removed from ``tensordict`` at that point.
    """
    # A private sentinel tells "key missing" apart from any real stored value.
    _sentinel = object()
    output = tensordict.pop(key, _sentinel)
    if output is _sentinel:
        return default

    if isinstance(output, torch.Tensor):
        return output
    if isinstance(output, NonTensorStack):
        return output.tolist()
    if isinstance(output, NonTensorData):
        return output.data
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    raise TypeError(f"unsupported value type {type(output).__name__} for key {key!r}")


def pop_keys(tensordict: TensorDict, keys: Iterable[str]) -> TensorDict:
tensor_output = {}
non_tensor_output = {}
for key in keys:
if key not in tensordict.keys():
raise KeyError(f"key {key} not in tensordict")
output = tensordict.get(key)
if isinstance(output, torch.Tensor):
tensor_output[key] = tensordict.pop(key)
Expand Down
Loading