Skip to content

Commit

Permalink
[Doc] Tutorial revamp (#926)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 20, 2023
1 parent 099ced3 commit e2d5dbe
Show file tree
Hide file tree
Showing 15 changed files with 961 additions and 1,971 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ jobs:
- name: Get output time
run: echo "The time was ${{ steps.build.outputs.time }}"
- name: Deploy
if: ${{ github.ref == 'refs/heads/main' }}
if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }}
uses: JamesIves/github-pages-deploy-action@releases/v4
with:
token: ${{ secrets.GITHUB_TOKEN }}
Expand Down
Binary file added docs/source/_static/img/cartpole.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/img/cartpole_demo.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/img/dqn.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/img/dqn_td0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/img/dqn_tdlambda.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/img/pendulum.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/img/transforms.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 6 additions & 8 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,9 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
)
elif is_tensorclass(data):
out = (
data.expand(self.max_size, *data.shape)
.clone()
.zero_()
.memmap_(prefix=self.scratch_dir)
data.clone()
.expand(self.max_size, *data.shape)
.memmap_like(prefix=self.scratch_dir)
.to(self.device)
)
for key, tensor in sorted(
Expand All @@ -384,10 +383,9 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
# out = TensorDict({}, [self.max_size, *data.shape])
print("The storage is being created: ")
out = (
data.expand(self.max_size, *data.shape)
.to_tensordict()
.zero_()
.memmap_(prefix=self.scratch_dir)
data.clone()
.expand(self.max_size, *data.shape)
.memmap_like(prefix=self.scratch_dir)
.to(self.device)
)
for key, tensor in sorted(
Expand Down
37 changes: 27 additions & 10 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,23 +584,24 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
f"got {tensordict.batch_size} and {self.batch_size}"
)

def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
"""Performs a random step in the environment given the action_spec attribute.
def rand_action(self, tensordict: Optional[TensorDictBase] = None):
"""Performs a random action given the action_spec attribute.
Args:
tensordict (TensorDictBase, optional): tensordict where the resulting info should be written.
tensordict (TensorDictBase, optional): tensordict where the resulting action should be written.
Returns:
a tensordict object with the new observation after a random step in the environment. The action will
be stored with the "action" key.
a tensordict object with the "action" entry updated with a random
sample from the action-spec.
"""
shape = torch.Size([])
if tensordict is None:
tensordict = TensorDict(
{}, device=self.device, batch_size=self.batch_size, _run_checks=False
)
elif not self.batch_locked and not self.batch_size:

if not self.batch_locked and not self.batch_size:
shape = tensordict.shape
elif not self.batch_locked and tensordict.shape != self.batch_size:
raise RuntimeError(
Expand All @@ -611,6 +612,20 @@ def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBa
)
action = self.action_spec.rand(shape)
tensordict.set("action", action)
return tensordict

def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
"""Performs a random step in the environment given the action_spec attribute.
Args:
tensordict (TensorDictBase, optional): tensordict where the resulting info should be written.
Returns:
a tensordict object with the new observation after a random step in the environment. The action will
be stored with the "action" key.
"""
tensordict = self.rand_action(tensordict)
return self.step(tensordict)

@property
Expand Down Expand Up @@ -680,7 +695,7 @@ def rollout(
if policy is None:

def policy(td):
self.rand_step(td)
self.rand_action(td)
return td

tensordicts = []
Expand Down Expand Up @@ -796,16 +811,18 @@ def to(self, device: DEVICE_TYPING) -> EnvBase:
def fake_tensordict(self) -> TensorDictBase:
"""Returns a fake tensordict with key-value pairs that match in shape, device and dtype what can be expected during an environment rollout."""
input_spec = self.input_spec
fake_input = input_spec.zero()
observation_spec = self.observation_spec
fake_obs = observation_spec.zero()
fake_input = input_spec.zero()
# the input and output key may match, but the output prevails
# Hence we generate the input, and override using the output
fake_in_out = fake_input.clone().update(fake_obs)
reward_spec = self.reward_spec
fake_reward = reward_spec.zero()
fake_td = TensorDict(
{
**fake_obs,
**fake_in_out,
"next": fake_obs.clone(),
**fake_input,
"reward": fake_reward,
"done": torch.zeros(
(*self.batch_size, 1), dtype=torch.bool, device=self.device
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
if len(info) == 2:
# gym 0.26
truncation, info = info
done = done | truncation
elif len(info) == 1:
info = info[0]
elif len(info) == 0:
Expand Down
Loading

0 comments on commit e2d5dbe

Please sign in to comment.