[Feature] Allow multiple (nested) action, reward, done keys in env, vec_env and collectors #1462

Merged: 37 commits, Aug 30, 2023
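In short, this PR lets an environment expose several (possibly nested) action, reward and done entries at once, and teaches step_mdp, the batched envs and the collectors to handle them. As a minimal sketch, a multi-reward spec now looks like the snippet below; names and shapes are copied from the MultiKeyCountingEnv mock added in this diff, not a prescribed API:

from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

# One root "reward" plus a nested ("nested_2", "reward") entry.
reward_spec = CompositeSpec(
    reward=UnboundedContinuousTensorSpec(shape=(1,)),
    nested_2=CompositeSpec(
        reward=UnboundedContinuousTensorSpec(shape=(2, 1)),
        shape=(2,),
    ),
)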
Changes from 25 commits

Commits (37)
4f597bf  temp (matteobettini, Aug 4, 2023)
49bc8e5  action (matteobettini, Aug 4, 2023)
36e1afb  amend (matteobettini, Aug 4, 2023)
eee5045  Merge branch 'main' into allow-all-specs-compsite (matteobettini, Aug 15, 2023)
92f62a9  reward spec (matteobettini, Aug 15, 2023)
5a77edd  reward spec (matteobettini, Aug 15, 2023)
1c334b1  done spec (matteobettini, Aug 15, 2023)
7dc7548  done spec (matteobettini, Aug 15, 2023)
ba13680  fix (matteobettini, Aug 15, 2023)
2f548ea  rollout and step_mdp (matteobettini, Aug 15, 2023)
4054f61  fix (matteobettini, Aug 15, 2023)
a772289  amend (matteobettini, Aug 15, 2023)
5baa353  added todos for _reset (matteobettini, Aug 15, 2023)
b6c1047  docs (matteobettini, Aug 16, 2023)
5f294d6  fix transforms (matteobettini, Aug 16, 2023)
e20298e  vec_env (matteobettini, Aug 16, 2023)
873dbbf  collector (matteobettini, Aug 16, 2023)
4332984  treat done (matteobettini, Aug 17, 2023)
162e40f  amend (matteobettini, Aug 18, 2023)
d9c0dbb  amend (matteobettini, Aug 18, 2023)
451e9a9  collectors and vec_env (matteobettini, Aug 18, 2023)
e8e410e  TEMP (matteobettini, Aug 18, 2023)
d3cbd5d  Revert "TEMP" (matteobettini, Aug 18, 2023)
ea1fe3f  amend (matteobettini, Aug 21, 2023)
8d5abef  Merge branch 'main' into allow-all-specs-compsite (vmoens, Aug 30, 2023)
334aa8d  fix review (matteobettini, Aug 30, 2023)
4830358  Update torchrl/envs/vec_env.py (matteobettini, Aug 30, 2023)
78be054  Update torchrl/envs/vec_env.py (matteobettini, Aug 30, 2023)
836d085  Update torchrl/envs/vec_env.py (matteobettini, Aug 30, 2023)
95dd02c  Update torchrl/envs/common.py (matteobettini, Aug 30, 2023)
d755d89  Update torchrl/envs/common.py (matteobettini, Aug 30, 2023)
9187b28  Update torchrl/envs/common.py (matteobettini, Aug 30, 2023)
19875bd  Update torchrl/envs/common.py (matteobettini, Aug 30, 2023)
01fc27a  Update torchrl/envs/common.py (matteobettini, Aug 30, 2023)
790ff36  Update torchrl/envs/common.py (matteobettini, Aug 30, 2023)
793b738  Update torchrl/envs/common.py (matteobettini, Aug 30, 2023)
6f1debe  preappend full_ before specs (matteobettini, Aug 30, 2023)
6 changes: 3 additions & 3 deletions benchmarks/test_envs_benchmark.py
@@ -118,9 +118,9 @@ def test_step_mdp_speed(
     benchmark(
         step_mdp,
         td,
-        action_key=action_key,
-        reward_key=reward_key,
-        done_key=done_key,
+        action_keys=action_key,
+        reward_keys=reward_key,
+        done_keys=done_key,
         keep_other=keep_other,
         exclude_reward=exclude_reward,
         exclude_done=exclude_done,
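Reviewer note: the step_mdp keywords go from singular to plural here, and each now accepts a list of (possibly nested) keys. A sketch of a call with several entries per role, assuming td comes from a rollout of the MultiKeyCountingEnv mock below (key names are taken from that mock):

next_td = step_mdp(
    td,
    action_keys=[("nested_1", "action"), ("nested_2", "azione"), "action"],
    reward_keys=[("nested_1", "gift"), ("nested_2", "reward"), "reward"],
    done_keys=[("nested_1", "done"), ("nested_2", "done"), "done"],
)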
234 changes: 234 additions & 0 deletions test/mocking_classes.py
@@ -1483,3 +1483,237 @@ def _step(

    def _set_seed(self, seed: Optional[int]):
        torch.manual_seed(seed)


class MultiKeyCountingEnvPolicy:
    def __init__(
        self,
        full_action_spec: TensorSpec,
        count: bool = True,
        deterministic: bool = False,
    ):
        if not deterministic and not count:
            raise ValueError("A non-counting policy is always deterministic")

        self.full_action_spec = full_action_spec
        self.count = count
        self.deterministic = deterministic

    def __call__(self, td: TensorDictBase) -> TensorDictBase:
        action_td = self.full_action_spec.zero()
        if self.count:
            if self.deterministic:
                action_td["nested_1", "action"] += 1
                action_td["nested_2", "azione"] += 1
                action_td["action"][..., 1] = 1
            else:
                # We choose one of the three action entries at random
                choice = torch.randint(0, 3, ()).item()
                if choice == 0:
                    action_td["nested_1", "action"] += 1
                elif choice == 1:
                    action_td["nested_2", "azione"] += 1
                else:
                    action_td["action"][..., 1] = 1
        return td.update(action_td)
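A quick usage sketch for this policy (illustrative only; the spec is retrieved the same way the new collector tests below do, via the env defined just after this class):

env = MultiKeyCountingEnv(batch_size=(2,))
policy = MultiKeyCountingEnvPolicy(env.input_spec["_action_spec"])
td = policy(env.reset())  # fills the root and both nested action entries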


class MultiKeyCountingEnv(EnvBase):
    def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
        super().__init__(**kwargs)

        self.max_steps = max_steps
        self.start_val = start_val
        self.nested_dim_1 = 3
        self.nested_dim_2 = 2

        count = torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int)
        count_nested_1 = torch.zeros(
            (*self.batch_size, self.nested_dim_1, 1),
            device=self.device,
            dtype=torch.int,
        )
        count_nested_2 = torch.zeros(
            (*self.batch_size, self.nested_dim_2, 1),
            device=self.device,
            dtype=torch.int,
        )

        count[:] = self.start_val
        count_nested_1[:] = self.start_val
        count_nested_2[:] = self.start_val

        self.register_buffer("count", count)
        self.register_buffer("count_nested_1", count_nested_1)
        self.register_buffer("count_nested_2", count_nested_2)

        self.make_specs()

        self.action_spec = self.unbatched_action_spec.expand(
            *self.batch_size, *self.unbatched_action_spec.shape
        )
        self.observation_spec = self.unbatched_observation_spec.expand(
            *self.batch_size, *self.unbatched_observation_spec.shape
        )
        self.reward_spec = self.unbatched_reward_spec.expand(
            *self.batch_size, *self.unbatched_reward_spec.shape
        )
        self.done_spec = self.unbatched_done_spec.expand(
            *self.batch_size, *self.unbatched_done_spec.shape
        )

    def make_specs(self):
        self.unbatched_observation_spec = CompositeSpec(
            nested_1=CompositeSpec(
                observation=BoundedTensorSpec(
                    minimum=0, maximum=200, shape=(self.nested_dim_1, 3)
                ),
                shape=(self.nested_dim_1,),
            ),
            nested_2=CompositeSpec(
                observation=UnboundedContinuousTensorSpec(shape=(self.nested_dim_2, 2)),
                shape=(self.nested_dim_2,),
            ),
            observation=UnboundedContinuousTensorSpec(shape=(10, 10, 3)),
        )

        self.unbatched_action_spec = CompositeSpec(
            nested_1=CompositeSpec(
                action=DiscreteTensorSpec(n=2, shape=(self.nested_dim_1,)),
                shape=(self.nested_dim_1,),
            ),
            nested_2=CompositeSpec(
                azione=BoundedTensorSpec(
                    minimum=0, maximum=100, shape=(self.nested_dim_2, 1)
                ),
                shape=(self.nested_dim_2,),
            ),
            action=OneHotDiscreteTensorSpec(n=2),
        )

        self.unbatched_reward_spec = CompositeSpec(
            nested_1=CompositeSpec(
                gift=UnboundedContinuousTensorSpec(shape=(self.nested_dim_1, 1)),
                shape=(self.nested_dim_1,),
            ),
            nested_2=CompositeSpec(
                reward=UnboundedContinuousTensorSpec(shape=(self.nested_dim_2, 1)),
                shape=(self.nested_dim_2,),
            ),
            reward=UnboundedContinuousTensorSpec(shape=(1,)),
        )

        self.unbatched_done_spec = CompositeSpec(
            nested_1=CompositeSpec(
                done=DiscreteTensorSpec(
                    n=2,
                    shape=(self.nested_dim_1, 1),
                    dtype=torch.bool,
                ),
                shape=(self.nested_dim_1,),
            ),
            nested_2=CompositeSpec(
                done=DiscreteTensorSpec(
                    n=2,
                    shape=(self.nested_dim_2, 1),
                    dtype=torch.bool,
                ),
                shape=(self.nested_dim_2,),
            ),
            done=DiscreteTensorSpec(
                n=2,
                shape=(1,),
                dtype=torch.bool,
            ),
        )
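Each of these is a single CompositeSpec mixing a root leaf with nested ones, so zeroing a spec yields a TensorDict with one entry per leaf; this is how _reset and _step below build their outputs. Illustration:

reward_td = env.unbatched_reward_spec.zero()
# entries: "reward", ("nested_1", "gift"), ("nested_2", "reward")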

    def _reset(
        self,
        tensordict: Optional[TensorDictBase] = None,
        **kwargs,
    ) -> TensorDictBase:
        reset_all = False
        if tensordict is not None:
            # Each done entry has a matching "_reset" mask at the same nesting
            # level; only the counters whose mask is present are reset.
            _reset = tensordict.get("_reset", None)
            if _reset is not None:
                self.count[_reset.squeeze(-1)] = self.start_val

            _reset_nested_1 = tensordict.get(("nested_1", "_reset"), None)
            if _reset_nested_1 is not None:
                self.count_nested_1[_reset_nested_1.squeeze(-1)] = self.start_val

            _reset_nested_2 = tensordict.get(("nested_2", "_reset"), None)
            if _reset_nested_2 is not None:
                self.count_nested_2[_reset_nested_2.squeeze(-1)] = self.start_val

            # No mask at any level means a full reset was requested.
            if _reset is None and _reset_nested_1 is None and _reset_nested_2 is None:
                reset_all = True

        if tensordict is None or reset_all:
            self.count[:] = self.start_val
            self.count_nested_1[:] = self.start_val
            self.count_nested_2[:] = self.start_val

        reset_td = self.observation_spec.zero()
        reset_td["observation"] += expand_right(
            self.count, reset_td["observation"].shape
        )
        reset_td["nested_1", "observation"] += expand_right(
            self.count_nested_1, reset_td["nested_1", "observation"].shape
        )
        reset_td["nested_2", "observation"] += expand_right(
            self.count_nested_2, reset_td["nested_2", "observation"].shape
        )

        reset_td.update(self.output_spec["_done_spec"].zero())

        assert reset_td.batch_size == self.batch_size

        return reset_td
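A partial-reset illustration (in practice the "_reset" masks are written by the batched envs and collectors rather than by hand):

from tensordict import TensorDict

env = MultiKeyCountingEnv(batch_size=(2,))
# Reset only the root counter of env 0; the nested counters keep their values.
env.reset(TensorDict({"_reset": torch.tensor([[True], [False]])}, batch_size=[2]))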

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        # Each action has a corresponding reward, done, and observation
        reward = self.output_spec["_reward_spec"].zero()
        done = self.output_spec["_done_spec"].zero()
        td = self.observation_spec.zero()

        one_hot_action = tensordict["action"].argmax(-1).unsqueeze(-1)
        reward["reward"] += one_hot_action.to(torch.float)
        self.count += one_hot_action.to(torch.int)
        td["observation"] += expand_right(self.count, td["observation"].shape)
        done["done"] = self.count > self.max_steps

        discrete_action = tensordict["nested_1"]["action"].unsqueeze(-1)
        reward["nested_1"]["gift"] += discrete_action.to(torch.float)
        self.count_nested_1 += discrete_action.to(torch.int)
        td["nested_1", "observation"] += expand_right(
            self.count_nested_1, td["nested_1", "observation"].shape
        )
        done["nested_1", "done"] = self.count_nested_1 > self.max_steps

        continuous_action = tensordict["nested_2"]["azione"]
        reward["nested_2"]["reward"] += continuous_action.to(torch.float)
        # This counter increments wherever the continuous action is nonzero.
        self.count_nested_2 += continuous_action.to(torch.bool)
        td["nested_2", "observation"] += expand_right(
            self.count_nested_2, td["nested_2", "observation"].shape
        )
        done["nested_2", "done"] = self.count_nested_2 > self.max_steps

        td.update(done)
        td.update(reward)

        assert td.batch_size == self.batch_size
        return td.select().set("next", td)

    def _set_seed(self, seed: Optional[int]):
        torch.manual_seed(seed)
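Taken together, a short rollout through this mock exercises all three action/reward/done branches at once; a sketch mirroring what TestMultiKeyEnvs checks in test_env.py:

env = MultiKeyCountingEnv(batch_size=(2,), max_steps=3)
policy = MultiKeyCountingEnvPolicy(env.input_spec["_action_spec"])
td = env.rollout(10, policy=policy)
# td carries one reward/done pair per level: "reward", ("nested_1", "gift"), ...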
80 changes: 80 additions & 0 deletions test/test_collector.py
@@ -22,10 +22,14 @@
     HeteroCountingEnv,
     HeteroCountingEnvPolicy,
     MockSerialEnv,
+    MultiKeyCountingEnv,
+    MultiKeyCountingEnvPolicy,
     NestedCountingEnv,
 )
 from tensordict.nn import TensorDictModule
 from tensordict.tensordict import assert_allclose_td, TensorDict
+
+from test_env import TestMultiKeyEnvs
 from torch import nn
 from torchrl._utils import prod, seed_generator
 from torchrl.collectors import aSyncDataCollector, SyncDataCollector
@@ -47,6 +51,7 @@
 )
 from torchrl.envs.libs.gym import _has_gym, GymEnv
 from torchrl.envs.transforms import TransformedEnv, VecNorm
+from torchrl.envs.utils import _replace_last
 from torchrl.modules import Actor, LSTMNet, OrnsteinUhlenbeckProcessWrapper, SafeModule

# torch.set_default_dtype(torch.double)
@@ -1552,6 +1557,81 @@ def test_multi_collector_het_env_consistency(
        assert_allclose_td(c2, d2)


class TestMultiKeyEnvsCollector:
    @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)])
    @pytest.mark.parametrize("frames_per_batch", [4, 8, 16])
    @pytest.mark.parametrize("max_steps", [2, 3])
    def test_collector(self, batch_size, frames_per_batch, max_steps, seed=1):
        env = MultiKeyCountingEnv(batch_size=batch_size, max_steps=max_steps)
        torch.manual_seed(seed)
        policy = MultiKeyCountingEnvPolicy(env.input_spec["_action_spec"])
        ccollector = SyncDataCollector(
            create_env_fn=env,
            policy=policy,
            frames_per_batch=frames_per_batch,
            total_frames=100,
            device="cpu",
        )

        for _td in ccollector:
            break
        ccollector.shutdown()
        for done_key in env.done_keys:
            assert _replace_last(done_key, "_reset") not in _td.keys(True, True)
        TestMultiKeyEnvs.check_rollout_consistency(_td, max_steps=max_steps)
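The loop above verifies that collectors strip the per-level "_reset" entries from the data they return. For reference, _replace_last substitutes the last element of a (possibly nested) key; a sketch of the assumed behavior:

# assumed behavior of torchrl.envs.utils._replace_last
assert _replace_last(("nested_1", "done"), "_reset") == ("nested_1", "_reset")
assert _replace_last("done", "_reset") == "_reset"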

    def test_multi_collector_consistency(
        self, seed=1, frames_per_batch=20, batch_dim=10
    ):
        env = MultiKeyCountingEnv(batch_size=(batch_dim,))
        env_fn = lambda: env
        torch.manual_seed(seed)
        policy = MultiKeyCountingEnvPolicy(
            env.input_spec["_action_spec"], deterministic=True
        )

        ccollector = MultiaSyncDataCollector(
            create_env_fn=[env_fn],
            policy=policy,
            frames_per_batch=frames_per_batch,
            total_frames=100,
            device="cpu",
        )
        for i, d in enumerate(ccollector):
            if i == 0:
                c1 = d
            elif i == 1:
                c2 = d
            else:
                break
            assert d.names[-1] == "time"
        with pytest.raises(AssertionError):
            assert_allclose_td(c1, c2)
        ccollector.shutdown()

        ccollector = MultiSyncDataCollector(
            create_env_fn=[env_fn],
            policy=policy,
            frames_per_batch=frames_per_batch,
            total_frames=100,
            device="cpu",
        )
        for i, d in enumerate(ccollector):
            if i == 0:
                d1 = d
            elif i == 1:
                d2 = d
            else:
                break
            assert d.names[-1] == "time"
        with pytest.raises(AssertionError):
            assert_allclose_td(d1, d2)
        ccollector.shutdown()

        assert_allclose_td(c1, d1)
        assert_allclose_td(c2, d2)


@pytest.mark.skipif(not torch.cuda.device_count(), reason="No casting if no cuda")
class TestUpdateParams:
    class DummyEnv(EnvBase):