diff --git a/test/test_transforms.py b/test/test_transforms.py index cc3ca40b059..7fae20ad7e1 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -116,6 +116,7 @@ FrameSkipTransform, GrayScale, gSDENoise, + Hash, InitTracker, MultiStepTransform, NoopResetEnv, @@ -2177,6 +2178,158 @@ def test_transform_no_env(self, device, batch): pytest.skip("TrajCounter cannot be called without env") +# TODO: Add tests that hash NonTensorStacks of strings +class TestHash(TransformBase): + @pytest.mark.parametrize("datatype", ["tensor", "str"]) + def test_transform_no_env(self, datatype): + if datatype == "tensor": + obs = torch.tensor(10) + elif datatype == "str": + obs = "abcdefg" + else: + raise RuntimeError(f"please add a test case for datatype {datatype}") + + td = TensorDict( + { + "observation": obs, + } + ) + t = Hash(in_keys=["observation"], out_keys=["hash"]) + td_hashed = t(td) + + assert td_hashed["observation"] is td["observation"] + assert td_hashed["hash"] == hash(td["observation"]) + + def test_single_trans_env_check(self): + t = Hash(in_keys=["observation"], out_keys=["hash"]) + env = TransformedEnv(CountingEnv(), t) + check_env_specs(env) + + def test_serial_trans_env_check(self): + def make_env(): + t = Hash( + in_keys=["observation"], + out_keys=["hash"], + ) + return TransformedEnv(CountingEnv(), t) + + env = SerialEnv(2, make_env) + check_env_specs(env) + + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): + def make_env(): + t = Hash(in_keys=["observation"], out_keys=["hash"]) + return TransformedEnv(CountingEnv(), t) + + env = maybe_fork_ParallelEnv(2, make_env) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + def test_trans_serial_env_check(self): + t = Hash( + in_keys=["observation"], + out_keys=["hash"], + ) + + env = TransformedEnv(SerialEnv(2, CountingEnv), t) + check_env_specs(env) + + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): + t = Hash( + in_keys=["observation"], + out_keys=["hash"], + ) + + env = TransformedEnv(maybe_fork_ParallelEnv(2, CountingEnv), t) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + @pytest.mark.parametrize("datatype", ["tensor", "str"]) + def test_transform_compose(self, datatype): + if datatype == "tensor": + obs = torch.tensor(10) + elif datatype == "str": + obs = "abcdefg" + else: + raise RuntimeError(f"please add a test case for datatype {datatype}") + + td = TensorDict( + { + "observation": obs, + } + ) + t = Hash(in_keys=["observation"], out_keys=["hash"]) + t = Compose(t) + td_hashed = t(td) + + assert td_hashed["observation"] is td["observation"] + assert td_hashed["hash"] == hash(td["observation"]) + + def test_transform_model(self): + t = Hash( + in_keys=[("next", "observation"), ("observation",)], + out_keys=[("next", "hash"), ("hash",)], + ) + model = nn.Sequential(t, nn.Identity()) + td = TensorDict( + {("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, [] + ) + td_out = model(td) + assert ("next", "hash") in td_out.keys(True) + assert ("hash",) in td_out.keys(True) + assert td_out["next", "hash"] == hash(td["next", "observation"]) + assert td_out["hash"] == hash(td["observation"]) + + @pytest.mark.skipif(not _has_gym, reason="Gym not found") + def test_transform_env(self): + t = Hash( + in_keys=["observation"], + out_keys=["hash"], + ) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t) + assert env.observation_spec["hash"] + assert "observation" in env.observation_spec + assert "observation" in env.base_env.observation_spec + check_env_specs(env) + + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) + def test_transform_rb(self, rbclass): + t = Hash( + in_keys=[("next", "observation"), ("observation",)], + out_keys=[("next", "hash"), ("hash",)], + ) + rb = rbclass(storage=LazyTensorStorage(10)) + rb.append_transform(t) + td = TensorDict( + { + "observation": torch.randn(3, 4), + "next": TensorDict( + {"observation": torch.randn(3, 4)}, + [], + ), + }, + [], + ).expand(10) + rb.extend(td) + td = rb.sample(2) + assert "observation_out" in td.keys() + assert "observation" not in td.keys() + assert ("next", "observation") not in td.keys(True) + + def test_transform_inverse(self): + raise pytest.skip("No inverse for Hash") + + class TestStack(TransformBase): def test_single_trans_env_check(self): t = Stack( diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index b863ad0801c..85d8b993335 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -67,6 +67,7 @@ FrameSkipTransform, GrayScale, gSDENoise, + Hash, InitTracker, KLRewardTransform, MultiStepTransform, diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 77f6ecc03bf..8e7ecbf2c65 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -31,6 +31,7 @@ FrameSkipTransform, GrayScale, gSDENoise, + Hash, InitTracker, NoopResetEnv, ObservationNorm, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f3329d085df..e3dca6ca069 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4400,6 +4400,52 @@ def __repr__(self) -> str: ) +class Hash(Transform): + """Adds a hash value to a tensordict. + + Args: + in_keys (sequence of NestedKey): the key of the data to create the hash from. + out_key (sequence of NestedKey): the key of the resulting hash. + """ + + def __init__( + self, + in_keys: Sequence[NestedKey], + out_keys: Sequence[NestedKey], + ): + super().__init__(in_keys=in_keys, out_keys=out_keys) + + # TODO: If this transform is run on a tensordict like + # `TensorDict({"obs": # tensor.rand(2)}, batch_size=[2])`, then + # `_apply_transform` will create only one hash value for the tensor of size + # 2. Then, when `forward` tries to add the hash to the tensordict, an error + # is raised since the hash doesn't have a leading dimension of size 2. + # TODO: Add support for NonTensorStack inputs. + def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: + if isinstance(observation, NonTensorData): + obs = observation.get("data") + else: + obs = observation + return hash(obs) + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + with _set_missing_tolerance(self, True): + tensordict_reset = self._call(tensordict_reset) + return tensordict_reset + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + if not isinstance(observation_spec, Composite): + raise TypeError(f"{self}: Only specs of type Composite can be transformed") + for out_key in self.out_keys: + observation_spec.set( + out_key, + Unbounded(shape=(), dtype=torch.int64), + ) + return observation_spec + + class Stack(Transform): """Stacks tensors and tensordicts.