Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/lerobot/policies/act/modeling_act.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def reset(self):
else:
self._action_queue = deque([], maxlen=self.config.n_action_steps)

@torch.no_grad
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.

Expand All @@ -132,7 +132,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()

@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
Expand Down
4 changes: 2 additions & 2 deletions src/lerobot/policies/diffusion/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def reset(self):
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
# stack n latest observations from the queue
Expand All @@ -111,7 +111,7 @@ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:

return actions

@torch.no_grad
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.

Expand Down
4 changes: 2 additions & 2 deletions src/lerobot/policies/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __init__(
setattr(self, "buffer_" + key.replace(".", "_"), buffer)

# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
@torch.no_grad()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# TODO: Remove this shallow copy
batch = dict(batch) # shallow copy avoids mutating the input batch
Expand Down Expand Up @@ -224,7 +224,7 @@ def __init__(
setattr(self, "buffer_" + key.replace(".", "_"), buffer)

# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
@torch.no_grad()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
Expand Down
4 changes: 2 additions & 2 deletions src/lerobot/policies/pi0/modeling_pi0.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,12 @@ def reset(self):
def get_optim_params(self) -> dict:
return self.parameters()

@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0")

@torch.no_grad
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.

Expand Down
4 changes: 2 additions & 2 deletions src/lerobot/policies/pi0fast/modeling_pi0fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,12 @@ def _pi_aloha_encode_actions_inv(self, actions):
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions

@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0FAST")

@torch.no_grad
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.

Expand Down
2 changes: 1 addition & 1 deletion src/lerobot/policies/sac/modeling_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def reset(self):
"""Reset the policy"""
pass

@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
Expand Down
3 changes: 2 additions & 1 deletion src/lerobot/policies/smolvla/modeling_smolvla.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:

return batch

@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
self.eval()

Expand All @@ -422,7 +423,7 @@ def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None =
actions = self._get_action_chunk(batch, noise)
return actions

@torch.no_grad
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.

Expand Down
2 changes: 1 addition & 1 deletion src/lerobot/policies/tdmpc/modeling_tdmpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def reset(self):
# CEM for the next step.
self._prev_mean: torch.Tensor | None = None

@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
Expand Down
4 changes: 2 additions & 2 deletions src/lerobot/policies/vqbet/modeling_vqbet.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,14 @@ def reset(self):
ACTION: deque(maxlen=self.config.action_chunk_size),
}

@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions

@torch.no_grad
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.

Expand Down
Loading