[BugFix] Fix TensorDictPrimer init #491

Merged · 1 commit · Sep 27, 2022
13 changes: 13 additions & 0 deletions torchrl/envs/transforms/r3m.py
@@ -136,11 +136,24 @@ class R3MTransform(Compose):

R3M provides pre-trained ResNet weights aimed at facilitating visual
embedding for robotic tasks. The models are trained using Ego4d.

See the paper:
R3M: A Universal Visual Representation for Robot Manipulation (Suraj Nair,
Aravind Rajeswaran, Vikash Kumar, Chelsea Finn, Abhinav Gupta)
https://arxiv.org/abs/2203.12601

The R3MTransform is created in a lazy manner: the object is initialized
only when an attribute (a spec or the forward method) is queried.
The reason for this is that the `_init()` method needs to access some
attributes of the parent environment (if any): by making the class lazy we
can ensure that the following code snippet works as expected:

Examples:
>>> transform = R3MTransform("resnet50", keys_in=["next_pixels"])
>>> env.append_transform(transform)
>>> # the forward method will first call _init which will look at env.observation_spec
>>> env.reset()

Args:
model_name (str): one of resnet50, resnet34 or resnet18
keys_in (list of str, optional): list of input keys. If left empty, the
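The lazy-initialization pattern that the docstring above describes can be sketched in plain Python. This is a minimal illustration, not the torchrl implementation: `LazyTransform`, `_init`, and `observation_spec` here are hypothetical names standing in for the real class, which defers setup until an attribute is first queried so that the parent environment has a chance to be attached first.

```python
class LazyTransform:
    """Defers expensive setup until an attribute is first needed (illustrative)."""

    def __init__(self, model_name):
        self.model_name = model_name
        self._initialized = False
        self._spec = None
        self.parent = None  # attached later, e.g. by env.append_transform

    def _init(self):
        # Needs the parent environment: fails if called before attachment.
        if self.parent is None:
            raise RuntimeError("parent environment not set")
        self._spec = f"spec-derived-from-{self.parent}"
        self._initialized = True

    @property
    def observation_spec(self):
        # First access triggers initialization.
        if not self._initialized:
            self._init()
        return self._spec
```

Because initialization happens on first access rather than in `__init__`, the transform can be constructed before it is appended to an environment, exactly as in the docstring's example.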
25 changes: 17 additions & 8 deletions torchrl/envs/transforms/transforms.py
@@ -245,6 +245,7 @@ def parent(self) -> EnvBase:
if parent is None:
return parent
if not isinstance(parent, EnvBase):
print(parent, parent.parent)
# if it's not an env, it should be a Compose transform
if not isinstance(parent, Compose):
raise ValueError(
@@ -454,7 +455,9 @@ def append_transform(self, transform: Transform) -> None:
)
transform = transform.to(self.device)
if not isinstance(self.transform, Compose):
self.transform = Compose(self.transform)
prev_transform = self.transform
self.transform = Compose()
self.transform.append(prev_transform)
self.transform.set_parent(self)
self.transform.append(transform)

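The `append_transform` change above replaces `Compose(self.transform)` with an empty `Compose` followed by `append`, so that the wrapped transform gets its parent link set. A minimal sketch of why that matters, with hypothetical `Compose`/`Transform` stand-ins (not the torchrl classes):

```python
class Transform:
    """Illustrative leaf transform that only tracks its parent."""
    def __init__(self):
        self.parent = None

    def set_parent(self, parent):
        self.parent = parent


class Compose:
    """Illustrative container: append() parents items, set_parent() re-parents them."""
    def __init__(self):
        self.items = []
        self.parent = None

    def append(self, item):
        item.set_parent(self)  # the key step the constructor path would skip
        self.items.append(item)

    def set_parent(self, parent):
        self.parent = parent
        for item in self.items:
            item.set_parent(self)


# The fixed pattern: append the previous transform instead of passing it
# to the constructor, so set_parent is called on it.
prev = Transform()
compose = Compose()
compose.append(prev)       # prev.parent is now compose
compose.set_parent("env")  # the whole chain now hangs off the env
```

With the constructor path, the pre-existing transform could end up inside the container without ever having its parent set; routing it through `append` keeps the parent chain consistent.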
@@ -1853,7 +1856,6 @@ def __init__(self, random=False, default_value=0.0, **kwargs):
self.primers = kwargs
self.random = random
self.default_value = default_value
self._batch_size = []
self.device = kwargs.get("device", torch.device("cpu"))
# sanity check
for spec in self.primers.values():
@@ -1897,12 +1899,19 @@ def transform_observation_spec(
return observation_spec

def set_parent(self, parent: Union[Transform, EnvBase]) -> None:
parent_env = parent
while not isinstance(parent_env, EnvBase):
parent_env = parent_env.parent
self._batch_size = parent_env.batch_size
self.device = parent_env.device
return super().set_parent(parent)
super().set_parent(parent)

@property
def _batch_size(self):
return self.parent.batch_size

@property
def device(self):
return self._device

@device.setter
def device(self, value):
self._device = value

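The diff above drops the `_batch_size` value cached in `set_parent` and replaces it with a read-through property, so the batch size always reflects the current parent instead of a copy taken at attachment time. A minimal sketch of that delegation pattern, with a hypothetical `Primer` class and `FakeEnv` (not the torchrl API):

```python
class Primer:
    """Illustrative: delegate _batch_size to the parent instead of caching it."""

    def __init__(self):
        self.parent = None
        self._device = "cpu"

    def set_parent(self, parent):
        # No copying of parent state here: nothing can go stale.
        self.parent = parent

    @property
    def _batch_size(self):
        # Read through to the parent on every access.
        return self.parent.batch_size

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, value):
        self._device = value
```

Because the property reads the parent on every access, later changes to the parent's `batch_size` are observed immediately, which a value cached once in `set_parent` would miss.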
def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
for key, spec in self.primers.items():