Skip to content

Commit

Permalink
[Feature] Add Hash transform
Browse files Browse the repository at this point in the history
ghstack-source-id: 80f920674e13db2fcbed6e82a990d35cb14c6d11
Pull Request resolved: #2648
  • Loading branch information
kurtamohler committed Dec 13, 2024
1 parent e3c3047 commit a3f8d18
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 0 deletions.
153 changes: 153 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
FrameSkipTransform,
GrayScale,
gSDENoise,
Hash,
InitTracker,
MultiStepTransform,
NoopResetEnv,
Expand Down Expand Up @@ -2177,6 +2178,158 @@ def test_transform_no_env(self, device, batch):
pytest.skip("TrajCounter cannot be called without env")


# TODO: Add tests that hash NonTensorStacks of strings
class TestHash(TransformBase):
    """Tests for the ``Hash`` transform.

    ``Hash`` reads each entry of ``in_keys`` and writes its hash under the
    matching entry of ``out_keys``; the hashed inputs are kept in the
    tensordict.
    """

    @pytest.mark.parametrize("datatype", ["tensor", "str"])
    def test_transform_no_env(self, datatype):
        if datatype == "tensor":
            obs = torch.tensor(10)
        elif datatype == "str":
            obs = "abcdefg"
        else:
            raise RuntimeError(f"please add a test case for datatype {datatype}")

        td = TensorDict(
            {
                "observation": obs,
            }
        )
        t = Hash(in_keys=["observation"], out_keys=["hash"])
        td_hashed = t(td)

        # The input entry must be left untouched, and the new entry must
        # hold the built-in hash of that input.
        assert td_hashed["observation"] is td["observation"]
        assert td_hashed["hash"] == hash(td["observation"])

    def test_single_trans_env_check(self):
        t = Hash(in_keys=["observation"], out_keys=["hash"])
        env = TransformedEnv(CountingEnv(), t)
        check_env_specs(env)

    def test_serial_trans_env_check(self):
        def make_env():
            t = Hash(
                in_keys=["observation"],
                out_keys=["hash"],
            )
            return TransformedEnv(CountingEnv(), t)

        env = SerialEnv(2, make_env)
        check_env_specs(env)

    def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv):
        def make_env():
            t = Hash(in_keys=["observation"], out_keys=["hash"])
            return TransformedEnv(CountingEnv(), t)

        env = maybe_fork_ParallelEnv(2, make_env)
        try:
            check_env_specs(env)
        finally:
            # Best-effort cleanup: closing may raise if the env is already
            # closed, which must not mask the outcome of the check above.
            try:
                env.close()
            except RuntimeError:
                pass

    def test_trans_serial_env_check(self):
        t = Hash(
            in_keys=["observation"],
            out_keys=["hash"],
        )

        env = TransformedEnv(SerialEnv(2, CountingEnv), t)
        check_env_specs(env)

    def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
        t = Hash(
            in_keys=["observation"],
            out_keys=["hash"],
        )

        env = TransformedEnv(maybe_fork_ParallelEnv(2, CountingEnv), t)
        try:
            check_env_specs(env)
        finally:
            # Best-effort cleanup, see test_parallel_trans_env_check.
            try:
                env.close()
            except RuntimeError:
                pass

    @pytest.mark.parametrize("datatype", ["tensor", "str"])
    def test_transform_compose(self, datatype):
        if datatype == "tensor":
            obs = torch.tensor(10)
        elif datatype == "str":
            obs = "abcdefg"
        else:
            raise RuntimeError(f"please add a test case for datatype {datatype}")

        td = TensorDict(
            {
                "observation": obs,
            }
        )
        t = Hash(in_keys=["observation"], out_keys=["hash"])
        t = Compose(t)
        td_hashed = t(td)

        assert td_hashed["observation"] is td["observation"]
        assert td_hashed["hash"] == hash(td["observation"])

    def test_transform_model(self):
        t = Hash(
            in_keys=[("next", "observation"), ("observation",)],
            out_keys=[("next", "hash"), ("hash",)],
        )
        model = nn.Sequential(t, nn.Identity())
        td = TensorDict(
            {("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, []
        )
        td_out = model(td)
        assert ("next", "hash") in td_out.keys(True)
        assert ("hash",) in td_out.keys(True)
        assert td_out["next", "hash"] == hash(td["next", "observation"])
        assert td_out["hash"] == hash(td["observation"])

    @pytest.mark.skipif(not _has_gym, reason="Gym not found")
    def test_transform_env(self):
        t = Hash(
            in_keys=["observation"],
            out_keys=["hash"],
        )
        env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
        # Membership test rather than truthiness of the spec object, in line
        # with the two assertions that follow.
        assert "hash" in env.observation_spec
        assert "observation" in env.observation_spec
        assert "observation" in env.base_env.observation_spec
        check_env_specs(env)

    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
    def test_transform_rb(self, rbclass):
        t = Hash(
            in_keys=[("next", "observation"), ("observation",)],
            out_keys=[("next", "hash"), ("hash",)],
        )
        rb = rbclass(storage=LazyTensorStorage(10))
        rb.append_transform(t)
        td = TensorDict(
            {
                "observation": torch.randn(3, 4),
                "next": TensorDict(
                    {"observation": torch.randn(3, 4)},
                    [],
                ),
            },
            [],
        ).expand(10)
        rb.extend(td)
        # NOTE(review): the sampled tensordict is batched; confirm that Hash
        # supports batched inputs, otherwise `rb.sample` raises before the
        # assertions below are reached.
        td = rb.sample(2)
        # Hash only adds the out_keys; the hashed inputs stay in the sample.
        assert "hash" in td.keys()
        assert ("next", "hash") in td.keys(True)
        assert "observation" in td.keys()
        assert ("next", "observation") in td.keys(True)

    def test_transform_inverse(self):
        raise pytest.skip("No inverse for Hash")


class TestStack(TransformBase):
def test_single_trans_env_check(self):
t = Stack(
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
FrameSkipTransform,
GrayScale,
gSDENoise,
Hash,
InitTracker,
KLRewardTransform,
MultiStepTransform,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
FrameSkipTransform,
GrayScale,
gSDENoise,
Hash,
InitTracker,
NoopResetEnv,
ObservationNorm,
Expand Down
46 changes: 46 additions & 0 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4400,6 +4400,52 @@ def __repr__(self) -> str:
)


class Hash(Transform):
    """Adds a hash value to a tensordict.

    Every entry listed in ``in_keys`` is passed to Python's built-in
    :func:`hash`, and the result is registered in the observation spec as a
    scalar ``torch.int64`` entry under the corresponding ``out_keys`` name.

    Args:
        in_keys (sequence of NestedKey): the keys of the data to create the
            hashes from.
        out_keys (sequence of NestedKey): the keys of the resulting hashes.
    """

    def __init__(
        self,
        in_keys: Sequence[NestedKey],
        out_keys: Sequence[NestedKey],
    ):
        super().__init__(in_keys=in_keys, out_keys=out_keys)

    # TODO: If this transform is run on a tensordict like
    # `TensorDict({"obs": torch.rand(2)}, batch_size=[2])`, then
    # `_apply_transform` will create only one hash value for the tensor of size
    # 2. Then, when `forward` tries to add the hash to the tensordict, an error
    # is raised since the hash doesn't have a leading dimension of size 2.
    # TODO: Add support for NonTensorStack inputs.
    def _apply_transform(self, observation: torch.Tensor) -> int:
        """Return the built-in ``hash`` of a single input entry.

        Args:
            observation: the entry read from one of ``in_keys``; either a
                tensor or a ``NonTensorData`` wrapper.

        Returns:
            The integer produced by ``hash``.
        """
        # NOTE(review): built-in `hash` is salted per process for `str` and
        # appears to be identity-based for `torch.Tensor` -- confirm whether
        # content-based, reproducible hashes are required here.
        if isinstance(observation, NonTensorData):
            # Hash the wrapped payload rather than the wrapper object.
            obs = observation.get("data")
        else:
            obs = observation
        return hash(obs)

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        """Apply the transform to the reset output and return it."""
        # Presumably lets `_call` skip any in_keys that are absent from the
        # reset data instead of raising -- TODO confirm.
        with _set_missing_tolerance(self, True):
            tensordict_reset = self._call(tensordict_reset)
        return tensordict_reset

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        """Register one scalar ``int64`` spec per out_key on ``observation_spec``.

        The spec is updated in place and returned.

        Raises:
            TypeError: if ``observation_spec`` is not a ``Composite``.
        """
        if not isinstance(observation_spec, Composite):
            raise TypeError(f"{self}: Only specs of type Composite can be transformed")
        for out_key in self.out_keys:
            observation_spec.set(
                out_key,
                Unbounded(shape=(), dtype=torch.int64),
            )
        return observation_spec


class Stack(Transform):
"""Stacks tensors and tensordicts.
Expand Down

0 comments on commit a3f8d18

Please sign in to comment.