
Commit da19d0c

fracapuano, pre-commit-ci[bot], aliberts, AdilZouitine, and imstevenpmwork authored and committed
Add direct access to action chunks (#1020)
* fix: sharing predicted chunk with user
* [pre-commit.ci] pre-commit autoupdate (#1011) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Revert "[pre-commit.ci] pre-commit autoupdate" (#1025)
* fix(ci): Pin draccus (<0.10.0) and torch (<2.7) to fix pipeline (#1022) Co-authored-by: imstevenpmwork <[email protected]> Co-authored-by: Simon Alibert <[email protected]>
* fix(ci): Pin `torchcodec` (==0.2.1) to fix pipeline temporarly (#1030)
* Update tutorial (#1021) Co-authored-by: Simon Alibert <[email protected]>
* Add description motor order SO-101 leader (#1051)
* feat(encoding): switching to PyAV for ffmpeg related tasks (#983)
* feat(docs): Add new docs build process (#1046) Co-authored-by: Mishig Davaadorj <[email protected]> Co-authored-by: Steven Palma <[email protected]>
* Docs: adapt text + fix video code (#1064)
* Fix typos (#1070)
* docs: minor corrections and clean-up (#1089)
* Update 10_use_so100.md; use diff syntax (#944) Co-authored-by: Pepijn <[email protected]>
* Update 12_use_so101.md (#1081) Co-authored-by: Pepijn <[email protected]>
* bug fix for #1071 When --display_data=true, Failed running control_robot. (#1073)
* Add editable -e for feetech install command (#1133)
* Fix: emptying action queue between resets (#1117)
* fix: typos and grammar (#1148)
* Update README.md (#1160)
* Update README.md (#1163)
* [Fix] Unpin torch beyond 2.6.0 & torchcodec beyond 0.2.1 (#1127)
* (hotfix): nightly CI by clipping pymunk version below 7.0.0 (#1182)
* [pre-commit.ci] pre-commit autoupdate (#1048) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <[email protected]>
* Add SmolVLA (#1175) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: fracapuano <[email protected]> Co-authored-by: Steven Palma <[email protected]> Co-authored-by: Dana Aubakirova <[email protected]> Co-authored-by: Remi <[email protected]>
* Fix SmolVLA loss not sent to wandb (#1198)
* Hardware API redesign (#777) Co-authored-by: Pepijn <[email protected]> Co-authored-by: Steven Palma <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma <[email protected]> Co-authored-by: Adil Zouitine <[email protected]> Co-authored-by: Pepijn <[email protected]>
* fix(smolvla): update record.py, fix populate_queues and remove unused dependencies (#1208)
* replaced OBS_ROBOT with OBS_STATE constant (#1211)
* Fix test_teleoperate (#1216)
* Fix LeKiwi example (#1217)
* Fix smolVLA dependencies (#1218)
* fix(pyserial): adding pyserial dependency to global ones (#1219)
* Update SmolVLA README.md (#1228)
* Fix unable to set camera width/height to non-default (#1225)
* Update tutorial link (#1250)
* update KochFollower.get_observation() so it returns same observation structure as SO101 (#1248) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [pre-commit.ci] pre-commit autoupdate (#1185) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <[email protected]>
* Proposal for fix for enter_pressed on Windows (#1230) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <[email protected]>
* fix: update pi0 dependency version constraint (#1247) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Match motor names with ids lekiwi (#1261)
* fix issues: checkpoints keys mismatch and 'task' tokenisation in smolvla (#1256) Co-authored-by: danaaubakirova <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <[email protected]> Co-authored-by: Simon Alibert <[email protected]>
* fix(docs): update realsense documentation (#1268)
* Use HF Papers (#1120)
* Skip normalization parameters in load_smolvla (#1274)
* fix(record): no teleop needed when running with policy (#1284)
* Port HIL SERL (#644) Co-authored-by: Michel Aractingi <[email protected]> Co-authored-by: Eugene Mironov <[email protected]> Co-authored-by: s1lent4gnt <[email protected]> Co-authored-by: Ke Wang <[email protected]> Co-authored-by: Yoel Chornton <[email protected]> Co-authored-by: imstevenpmwork <[email protected]> Co-authored-by: Simon Alibert <[email protected]>
* fix(docs): SmolVLA fine-tuning getting started (#1201) Co-authored-by: Pepijn <[email protected]> Co-authored-by: danaaubakirova <[email protected]> Co-authored-by: Simon Alibert <[email protected]> Co-authored-by: Francesco Capuano <[email protected]> Co-authored-by: Steven Palma <[email protected]>
* chore(teleop): print calibration path saved (#1286)
* chore(dependencies): add gamepad support with pygame and hidapi (#1287)
* Robot integration tutorial (#1285)
* fix(docs): update send_feedback docstrings
* Add sim tutorial, fix lekiwi motor config, add notebook links (#1275) Co-authored-by: AdilZouitine <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michel Aractingi <[email protected]> Co-authored-by: s1lent4gnt <[email protected]> Co-authored-by: Michel Aractingi <[email protected]> Co-authored-by: Eugene Mironov <[email protected]> Co-authored-by: imstevenpmwork <[email protected]> Co-authored-by: Simon Alibert <[email protected]> Co-authored-by: Steven Palma <[email protected]>
* Fixes on robot integration tutorial (#1290)
* Add keyboard teleop device to control the end effector robot (#1289)
* Improve type hints (#1293)
* fix(record): no teleop arg in reset environment (#1294)
* `learner.py` import so101_leader instead of so100 (#1295) Co-authored-by: Adil Zouitine <[email protected]>
* Fixing `PI0` Policy (#1297)
* `gym_manipulator.py` Remove None value action_intervention of BaseLeaderTeleoperator (#1299)
* (chore): incorrect resume parameter in recording documentation (#1301)
* Update lekiwi.mdx (#1229)
* bump `pi0` and `hil` transformers version (#1298)
* docs: fix imitation learning robots docs command (#1308)
* fix(benchmarks): remove .numpy() from frame in benchmark script (#1354)
* add smolvla to the supported policies to run tests (:
* add: chunk-level access for the policy
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* add: smolvla in availables
* remove: smolvla from library supported policies
* fix: change env for training, xarm is broken as of now
* add: predict_action_chunk to all supported policies
* fix: add robot type constants
* add: predict action chunk in base policy class
* restore original Makefile
* fix: minor
* fix: dict keys come from lerobot/constants
* fix: improve act encapsulation, properly supporting temporal ensembling
* fix: smolvla action chunking
* fix: very minor, but very annoying
* fix: minor
* fix minor naming Co-authored-by: Steven Palma <[email protected]> Signed-off-by: Francesco Capuano <[email protected]>
* fix: refactoring inference for single actions and chunks into different components
* fix: minor
* fix: temporal ensembling
* fix: moving populate queues out of modular component for batch preparation
* fix: minor for CI
* fix: smovla debug
* fix: reward classifier, maybe the last policy lacking?

---------

Signed-off-by: Francesco Capuano <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <[email protected]>
Co-authored-by: Adil Zouitine <[email protected]>
Co-authored-by: imstevenpmwork <[email protected]>
Co-authored-by: Pepijn <[email protected]>
Co-authored-by: Caroline Pascal <[email protected]>
Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: omahs <[email protected]>
Co-authored-by: CharlesCNorton <[email protected]>
Co-authored-by: masato-ka <[email protected]>
Co-authored-by: Ragnar <[email protected]>
Co-authored-by: mshukor <[email protected]>
Co-authored-by: Simon Alibert <[email protected]>
Co-authored-by: Steven Palma <[email protected]>
Co-authored-by: Dana Aubakirova <[email protected]>
Co-authored-by: Remi <[email protected]>
Co-authored-by: Ben Zhang <[email protected]>
Co-authored-by: Pepijn <[email protected]>
Co-authored-by: Dhruva <[email protected]>
Co-authored-by: Daisuke Sato <[email protected]>
Co-authored-by: Sarunas Kalade <[email protected]>
Co-authored-by: koenvanwijk <[email protected]>
Co-authored-by: Yushun Xiang <[email protected]>
Co-authored-by: danaaubakirova <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Michel Aractingi <[email protected]>
Co-authored-by: Eugene Mironov <[email protected]>
Co-authored-by: s1lent4gnt <[email protected]>
Co-authored-by: Ke Wang <[email protected]>
Co-authored-by: Yoel Chornton <[email protected]>
Co-authored-by: Michel Aractingi <[email protected]>
Co-authored-by: tidely <[email protected]>
Co-authored-by: David <[email protected]>
1 parent 1b944c2 commit da19d0c
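
The change in a nutshell: every policy now exposes a predict_action_chunk method next to select_action, so callers can read the whole predicted chunk instead of consuming it one action at a time. A minimal usage sketch (not part of the diff; the checkpoint id and observation keys are placeholders and must match your policy's config):

    import torch

    from lerobot.common.policies.act.modeling_act import ACTPolicy

    policy = ACTPolicy.from_pretrained("<hf_user>/<act_checkpoint>")  # placeholder repo id
    policy.reset()

    batch = {
        "observation.state": torch.zeros(1, 14),                 # (batch_size, state_dim), placeholder key/shape
        "observation.images.top": torch.zeros(1, 3, 480, 640),   # (batch_size, C, H, W), placeholder key/shape
    }

    # Existing behaviour: one action per call, refilling an internal queue as needed.
    action = policy.select_action(batch)         # (batch_size, action_dim)

    # New behaviour: direct access to the full predicted chunk.
    chunk = policy.predict_action_chunk(batch)   # (batch_size, chunk_size, action_dim)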

File tree

11 files changed: +176, -109 lines changed


lerobot/common/constants.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
 REWARD = "next.reward"

 ROBOTS = "robots"
+ROBOT_TYPE = "robot_type"
 TELEOPERATORS = "teleoperators"

 # files & directories
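
Alongside the new ROBOT_TYPE key, much of this commit replaces hard-coded dictionary keys with these shared constants. A small illustration (not from the diff) of the pattern:

    from lerobot.common.constants import ACTION, OBS_IMAGES, ROBOT_TYPE

    # The constants resolve to the previous string literals, e.g. ACTION == "action" and
    # OBS_IMAGES == "observation.images", so batch[ACTION] replaces batch["action"] in the diffs below.
    print(ACTION, OBS_IMAGES, ROBOT_TYPE)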

lerobot/common/policies/act/modeling_act.py

Lines changed: 20 additions & 16 deletions
@@ -33,6 +33,7 @@
 from torchvision.models._utils import IntermediateLayerGetter
 from torchvision.ops.misc import FrozenBatchNorm2d

+from lerobot.common.constants import ACTION, OBS_IMAGES
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -114,46 +115,49 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
         queue is empty.
         """
-        self.eval()
+        self.eval()  # keeping the policy in eval mode as it could be set to train mode while queue is consumed

-        batch = self.normalize_inputs(batch)
-        if self.config.image_features:
-            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = [batch[key] for key in self.config.image_features]
-
-        # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
-        # we are ensembling over.
         if self.config.temporal_ensemble_coeff is not None:
-            actions = self.model(batch)[0]  # (batch_size, chunk_size, action_dim)
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+            actions = self.predict_action_chunk(batch)
             action = self.temporal_ensembler.update(actions)
             return action

         # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
         # querying the policy.
         if len(self._action_queue) == 0:
-            actions = self.model(batch)[0][:, : self.config.n_action_steps]
-
-            # TODO(rcadene): make _forward return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+            actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]

             # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
             # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
             self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        self.eval()
+
+        batch = self.normalize_inputs(batch)
+        if self.config.image_features:
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
+            batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
+
+        actions = self.model(batch)[0]
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+        return actions
+
     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = [batch[key] for key in self.config.image_features]
+            batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]

         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

         l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
+            F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
         ).mean()

         loss_dict = {"l1_loss": l1_loss.item()}
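
With this refactor, both inference paths in ACTPolicy.select_action (temporal ensembling and the action queue) are fed by the same predict_action_chunk call. A hedged sketch of the relationship, assuming an already configured `policy` and observation `batch`, with temporal ensembling disabled:

    import torch

    chunk = policy.predict_action_chunk(batch)  # (batch_size, chunk_size, action_dim)

    policy.reset()  # start from an empty action queue
    steps = [policy.select_action(batch) for _ in range(policy.config.n_action_steps)]
    stepwise = torch.stack(steps, dim=1)  # (batch_size, n_action_steps, action_dim)

    # The queue is filled from predict_action_chunk(batch)[:, :n_action_steps], so
    # `stepwise` should match chunk[:, : policy.config.n_action_steps] in this setup.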

lerobot/common/policies/diffusion/modeling_diffusion.py

Lines changed: 19 additions & 17 deletions
@@ -33,7 +33,7 @@
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from torch import Tensor, nn

-from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
+from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -99,6 +99,18 @@ def reset(self):
         if self.config.env_state_feature:
             self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        # stack n latest observations from the queue
+        batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
+        actions = self.diffusion.generate_actions(batch)
+
+        # TODO(rcadene): make above methods return output dictionary?
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+
+        return actions
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
@@ -124,33 +136,23 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
         # Note: It's important that this happens after stacking the images into a single key.
         self._queues = populate_queues(self._queues, batch)

-        if len(self._queues["action"]) == 0:
-            # stack n latest observations from the queue
-            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
-            actions = self.diffusion.generate_actions(batch)
-
-            # TODO(rcadene): make above methods return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+        if len(self._queues[ACTION]) == 0:
+            actions = self.predict_action_chunk(batch)
+            self._queues[ACTION].extend(actions.transpose(0, 1))

-            self._queues["action"].extend(actions.transpose(0, 1))
-
-        action = self._queues["action"].popleft()
+        action = self._queues[ACTION].popleft()
         return action

     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         # no output_dict so returning None
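
Note that unlike ACT, the diffusion policy's predict_action_chunk reads the last n_obs_steps observations from its internal queues, so those queues must already be populated (as select_action does via populate_queues) before calling it directly. The stacking step it relies on, shown standalone with hypothetical shapes:

    from collections import deque

    import torch

    n_obs_steps = 2
    state_queue = deque(maxlen=n_obs_steps)
    for _ in range(n_obs_steps):
        state_queue.append(torch.zeros(1, 7))  # one (batch_size, state_dim) observation per step

    stacked = torch.stack(list(state_queue), dim=1)  # (batch_size, n_obs_steps, state_dim)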

lerobot/common/policies/pi0/modeling_pi0.py

Lines changed: 5 additions & 0 deletions
@@ -260,6 +260,11 @@ def reset(self):
     def get_optim_params(self) -> dict:
         return self.parameters()

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        raise NotImplementedError("Currently not implemented for PI0")
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.

lerobot/common/policies/pi0fast/modeling_pi0fast.py

Lines changed: 5 additions & 0 deletions
@@ -192,6 +192,11 @@ def _pi_aloha_encode_actions_inv(self, actions):
             actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
         return actions

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        raise NotImplementedError("Currently not implemented for PI0FAST")
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.

lerobot/common/policies/pretrained.py

Lines changed: 9 additions & 0 deletions
@@ -171,6 +171,15 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
         """
         raise NotImplementedError

+    @abc.abstractmethod
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode.
+
+        Child classes using action chunking should use this method within `select_action` to form the action chunk
+        cached for selection.
+        """
+        raise NotImplementedError
+
     @abc.abstractmethod
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Return one action to run in the environment (potentially in batch mode).

lerobot/common/policies/sac/modeling_sac.py

Lines changed: 5 additions & 0 deletions
@@ -76,6 +76,11 @@ def reset(self):
         """Reset the policy"""
         pass

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
+
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""

lerobot/common/policies/sac/reward_model/modeling_classifier.py

Lines changed: 7 additions & 0 deletions
@@ -308,6 +308,13 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """
         raise NotImplementedError("Reward classifiers do not select actions")

+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """
+        This method is required by PreTrainedPolicy but not used for reward classifiers.
+        The reward classifier is not an actor and does not produce action chunks.
+        """
+        raise NotImplementedError("Reward classifiers do not predict action chunks")
+
     def reset(self):
         """
         This method is required by PreTrainedPolicy but not used for reward classifiers.

lerobot/common/policies/smolvla/modeling_smolvla.py

Lines changed: 44 additions & 25 deletions
@@ -383,6 +383,45 @@ def _load_as_safetensor(
     def get_optim_params(self) -> dict:
         return self.parameters()

+    def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+        for k in batch:
+            if k in self._queues:
+                batch[k] = torch.stack(list(self._queues[k]), dim=1)
+
+        images, img_masks = self.prepare_images(batch)
+        state = self.prepare_state(batch)
+        lang_tokens, lang_masks = self.prepare_language(batch)
+
+        actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
+
+        # Unpad actions
+        original_action_dim = self.config.action_feature.shape[0]
+        actions = actions[:, :, :original_action_dim]
+
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+
+        if self.config.adapt_to_pi_aloha:
+            actions = self._pi_aloha_encode_actions(actions)
+
+        return actions
+
+    def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        if self.config.adapt_to_pi_aloha:
+            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
+
+        batch = self.normalize_inputs(batch)
+
+        return batch
+
+    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+        self.eval()
+
+        batch = self._prepare_batch(batch)
+        self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
+
+        actions = self._get_action_chunk(batch, noise)
+        return actions
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.
@@ -392,38 +431,18 @@ def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         queue is empty.
         """
         self.eval()
-
-        if self.config.adapt_to_pi_aloha:
-            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
-
-        batch = self.normalize_inputs(batch)
-
+        batch = self._prepare_batch(batch)
         self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
+
         # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
         # querying the policy.
         if len(self._queues[ACTION]) == 0:
-            for k in batch:
-                if k in self._queues:
-                    batch[k] = torch.stack(list(self._queues[k]), dim=1)
-            images, img_masks = self.prepare_images(batch)
-            state = self.prepare_state(batch)
-            lang_tokens, lang_masks = self.prepare_language(batch)
-
-            actions = self.model.sample_actions(
-                images, img_masks, lang_tokens, lang_masks, state, noise=noise
-            )
-            # Unpad actions
-            original_action_dim = self.config.action_feature.shape[0]
-            actions = actions[:, :, :original_action_dim]
-
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+            actions = self._get_action_chunk(batch, noise)

-            if self.config.adapt_to_pi_aloha:
-                actions = self._pi_aloha_encode_actions(actions)
-
-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
+            # `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
             # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
             self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
+
         return self._queues[ACTION].popleft()

     def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
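
For SmolVLA the refactor routes both entry points through the same helpers (_prepare_batch, populate_queues, _get_action_chunk), so chunk access and step-wise control stay consistent. A hedged usage sketch, assuming a loaded SmolVLAPolicy named `policy` and an observation `batch` with the keys its config expects; the noise shape below is an assumption taken from how sample_actions is seeded, not something this diff specifies:

    import torch

    policy.reset()
    chunk = policy.predict_action_chunk(batch)  # (batch_size, chunk_size, action_dim)

    # A fixed noise tensor should make repeated chunk predictions reproducible,
    # since it is passed through to model.sample_actions as the sampler's starting noise.
    noise = torch.randn(1, policy.config.chunk_size, policy.config.max_action_dim)  # assumed shape
    chunk_again = policy.predict_action_chunk(batch, noise=noise)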
