4 changes: 2 additions & 2 deletions src/lerobot/policies/act/modeling_act.py

@@ -107,7 +107,7 @@ def reset(self):
         else:
             self._action_queue = deque([], maxlen=self.config.n_action_steps)

-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.

@@ -132,7 +132,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
             self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()

-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         self.eval()
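Editorial aside (not part of the diff): torch.no_grad is a context-manager class, and the decorator spelling the PyTorch documentation consistently shows is the instantiated form, @torch.no_grad(). Recent PyTorch releases also accept the bare @torch.no_grad through a convenience shim in the decorator machinery, but the parenthesized form is the canonical one, so normalizing on it keeps the files consistent. A minimal sketch of the two forms, assuming a recent PyTorch build:

    import torch

    @torch.no_grad()  # canonical form: instantiate the context manager, then decorate
    def double_called(x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @torch.no_grad  # bare form: accepted by recent PyTorch via a constructor shim
    def double_bare(x: torch.Tensor) -> torch.Tensor:
        return x * 2

    x = torch.ones(3, requires_grad=True)
    # Both outputs are detached from the autograd graph on recent PyTorch:
    print(double_called(x).requires_grad, double_bare(x).requires_grad)  # False False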
4 changes: 2 additions & 2 deletions src/lerobot/policies/diffusion/modeling_diffusion.py

@@ -99,7 +99,7 @@ def reset(self):
         if self.config.env_state_feature:
             self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         # stack n latest observations from the queue

@@ -111,7 +111,7 @@ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:

         return actions

-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
4 changes: 2 additions & 2 deletions src/lerobot/policies/pi0/modeling_pi0.py

@@ -260,12 +260,12 @@ def reset(self):
     def get_optim_params(self) -> dict:
         return self.parameters()

-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         raise NotImplementedError("Currently not implemented for PI0")

-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.
4 changes: 2 additions & 2 deletions src/lerobot/policies/pi0fast/modeling_pi0fast.py

@@ -192,12 +192,12 @@ def _pi_aloha_encode_actions_inv(self, actions):
             actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
         return actions

-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         raise NotImplementedError("Currently not implemented for PI0FAST")

-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
2 changes: 1 addition & 1 deletion src/lerobot/policies/sac/modeling_sac.py

@@ -76,7 +76,7 @@ def reset(self):
         """Reset the policy"""
         pass

-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
3 changes: 2 additions & 1 deletion src/lerobot/policies/smolvla/modeling_smolvla.py

@@ -413,6 +413,7 @@ def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:

         return batch

+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         self.eval()

@@ -422,7 +423,7 @@ def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None =
         actions = self._get_action_chunk(batch, noise)
         return actions

-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
        """Select a single action given environment observations.
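Editorial aside (not part of the diff): in modeling_smolvla.py the decorator on predict_action_chunk is newly added rather than corrected, so that method previously ran with autograd enabled. Decorating inference entry points this way means no autograd graph is recorded, which saves memory in a rollout loop and spares callers an explicit .detach(). A minimal, self-contained sketch of the effect (the Demo module is hypothetical, chosen only to mirror the pattern):

    import torch

    class Demo(torch.nn.Module):  # hypothetical stand-in for a policy module
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        @torch.no_grad()
        def predict_action_chunk(self, x: torch.Tensor) -> torch.Tensor:
            # No autograd graph is recorded here, so intermediate
            # activations are not kept alive for a backward pass.
            return self.linear(x)

    demo = Demo()
    out = demo.predict_action_chunk(torch.randn(1, 4))
    print(out.requires_grad)  # False: callers need no explicit .detach()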
2 changes: 1 addition & 1 deletion src/lerobot/policies/tdmpc/modeling_tdmpc.py

@@ -110,7 +110,7 @@ def reset(self):
         # CEM for the next step.
         self._prev_mean: torch.Tensor | None = None

-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
4 changes: 2 additions & 2 deletions src/lerobot/policies/vqbet/modeling_vqbet.py

@@ -124,14 +124,14 @@ def reset(self):
             ACTION: deque(maxlen=self.config.action_chunk_size),
         }

-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
         actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
         actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
         return actions

-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.