Commit 63b12f6

fracapuano authored and AdilZouitine committed

Fixes @torch.no_grad() usage (#1455)

* fix: decorator calls with parentheses
* fix no grad for normalize too

Signed-off-by: Francesco Capuano <[email protected]>
1 parent 947b462 commit 63b12f6

File tree

9 files changed: +16 -15 lines changed

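The fix across all nine files is the same: call the decorator, `@torch.no_grad()`, instead of applying the bare class, `@torch.no_grad`. A minimal sketch of the difference, assuming a recent PyTorch (the function name `double` is illustrative, not from the commit): the parenthesized form instantiates the context manager and wraps the function with it, so autograd is reliably disabled inside the call, whereas the bare form leans on newer PyTorch versions accepting the class itself as a decorator, which is not guaranteed on older releases.

import torch

@torch.no_grad()
def double(x: torch.Tensor) -> torch.Tensor:
    # Every op in here runs with gradient tracking disabled.
    return x * 2

x = torch.ones(3, requires_grad=True)
y = double(x)
print(y.requires_grad)  # False: the computation ran under no_grad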

src/lerobot/policies/act/modeling_act.py

Lines changed: 2 additions & 2 deletions

@@ -107,7 +107,7 @@ def reset(self):
         else:
             self._action_queue = deque([], maxlen=self.config.n_action_steps)
 
-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
 
@@ -132,7 +132,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()
 
-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         self.eval()

src/lerobot/policies/diffusion/modeling_diffusion.py

Lines changed: 2 additions & 2 deletions

@@ -99,7 +99,7 @@ def reset(self):
         if self.config.env_state_feature:
             self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
 
-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         # stack n latest observations from the queue
@@ -111,7 +111,7 @@ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
 
         return actions
 
-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
 

src/lerobot/policies/normalize.py

Lines changed: 2 additions & 2 deletions

@@ -149,7 +149,7 @@ def __init__(
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)
 
     # TODO(rcadene): should we remove torch.no_grad?
-    @torch.no_grad
+    @torch.no_grad()
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         # TODO: Remove this shallow copy
         batch = dict(batch)  # shallow copy avoids mutating the input batch
@@ -224,7 +224,7 @@ def __init__(
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)
 
     # TODO(rcadene): should we remove torch.no_grad?
-    @torch.no_grad
+    @torch.no_grad()
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, ft in self.features.items():
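The TODO carried through this hunk ("should we remove torch.no_grad?") has real stakes: a forward wrapped in @torch.no_grad() produces outputs detached from the autograd graph, so no gradient can flow back through the module to whatever produced its input. A hypothetical toy module (not the LeRobot Normalize class) illustrating the effect:

import torch
from torch import nn

class ScaleByBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # Normalization stats live in buffers, not parameters, so they never
        # need gradients themselves.
        self.register_buffer("scale", torch.tensor(2.0))

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale

x = torch.ones(3, requires_grad=True)
y = ScaleByBuffer()(x)
print(y.requires_grad)  # False: gradients stop at this module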

src/lerobot/policies/pi0/modeling_pi0.py

Lines changed: 2 additions & 2 deletions

@@ -260,12 +260,12 @@ def reset(self):
     def get_optim_params(self) -> dict:
         return self.parameters()
 
-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         raise NotImplementedError("Currently not implemented for PI0")
 
-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.
 

src/lerobot/policies/pi0fast/modeling_pi0fast.py

Lines changed: 2 additions & 2 deletions

@@ -192,12 +192,12 @@ def _pi_aloha_encode_actions_inv(self, actions):
             actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
         return actions
 
-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         raise NotImplementedError("Currently not implemented for PI0FAST")
 
-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
 

src/lerobot/policies/sac/modeling_sac.py

Lines changed: 1 addition & 1 deletion

@@ -76,7 +76,7 @@ def reset(self):
         """Reset the policy"""
         pass
 
-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")

src/lerobot/policies/smolvla/modeling_smolvla.py

Lines changed: 2 additions & 1 deletion

@@ -413,6 +413,7 @@ def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
 
         return batch
 
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         self.eval()
 
@@ -422,7 +423,7 @@ def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None =
         actions = self._get_action_chunk(batch, noise)
         return actions
 
-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.
 
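Note that this file is the one place the commit adds a missing decorator rather than fixing an existing one: predict_action_chunk already called self.eval() but ran with autograd enabled. A sketch with a made-up module (not SmolVLA) of the inference idiom the hunk completes; eval() and no_grad() are orthogonal, which is why the method needs both:

import torch
from torch import nn

class TinyPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

    @torch.no_grad()
    def predict_action_chunk(self, obs: torch.Tensor) -> torch.Tensor:
        self.eval()  # switches dropout to inference mode; no_grad handles autograd
        return self.net(obs)

actions = TinyPolicy().predict_action_chunk(torch.randn(1, 4))
print(actions.requires_grad)  # False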

src/lerobot/policies/tdmpc/modeling_tdmpc.py

Lines changed: 1 addition & 1 deletion

@@ -110,7 +110,7 @@ def reset(self):
         # CEM for the next step.
         self._prev_mean: torch.Tensor | None = None
 
-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}

src/lerobot/policies/vqbet/modeling_vqbet.py

Lines changed: 2 additions & 2 deletions

@@ -124,14 +124,14 @@ def reset(self):
             ACTION: deque(maxlen=self.config.action_chunk_size),
         }
 
-    @torch.no_grad
+    @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
         actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
         actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
         return actions
 
-    @torch.no_grad
+    @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
 