Skip to content

Commit

Permalink
[BugFix] union -> intersection in _StepMDP check (#2039)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Mar 25, 2024
1 parent 1fcd3e3 commit 247ed6e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
15 changes: 13 additions & 2 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1727,7 +1727,7 @@ def test_multi_purpose_env(self, serial):
env = SerialEnv(2, ContinuousActionVecMockEnv)
else:
env = ContinuousActionVecMockEnv()
rollout = env.rollout(10)
env.rollout(10)
assert env._step_mdp.validate(None)
c = SyncDataCollector(
env, env.rand_action, frames_per_batch=10, total_frames=20
Expand All @@ -1736,7 +1736,18 @@ def test_multi_purpose_env(self, serial):
pass
assert ("collector", "traj_ids") in data.keys(True)
assert env._step_mdp.validate(None)
rollout = env.rollout(10)
env.rollout(10)

# An exception will be raised when the collector sees extra keys
if serial:
env = SerialEnv(2, ContinuousActionVecMockEnv)
else:
env = ContinuousActionVecMockEnv()
c = SyncDataCollector(
env, env.rand_action, frames_per_batch=10, total_frames=20
)
for data in c: # noqa: B007
pass


@pytest.mark.parametrize("device", get_default_devices())
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def validate(self, tensordict):
)
actual = set(tensordict.keys(True, True))
expected = set(expected)
self.validated = expected.union(actual) == expected
self.validated = expected.intersection(actual) == expected
if not self.validated:
warnings.warn(
"The expected key set and actual key set differ. "
Expand Down

0 comments on commit 247ed6e

Please sign in to comment.