Merged
Changes from all commits
90 commits
63f7144
fix: sharing predicted chunk with user
fracapuano Apr 23, 2025
9020109
[pre-commit.ci] pre-commit autoupdate (#1011)
pre-commit-ci[bot] Apr 23, 2025
da8bec0
Revert "[pre-commit.ci] pre-commit autoupdate" (#1025)
aliberts Apr 24, 2025
309deca
fix(ci): Pin draccus (<0.10.0) and torch (<2.7) to fix pipeline (#1022)
AdilZouitine Apr 24, 2025
3ce6e22
fix(ci): Pin `torchcodec` (==0.2.1) to fix pipeline temporarly (#1030)
AdilZouitine Apr 24, 2025
2e5aab3
Update tutorial (#1021)
pkooij Apr 28, 2025
be8f9a4
Add description motor order SO-101 leader (#1051)
pkooij Apr 29, 2025
1a45b26
feat(encoding): switching to PyAV for ffmpeg related tasks (#983)
CarolinePascal Apr 29, 2025
d3f5991
feat(docs): Add new docs build process (#1046)
pkooij May 2, 2025
60b5a21
Docs: adapt text + fix video code (#1064)
pkooij May 2, 2025
9d59f12
Fix typos (#1070)
omahs May 5, 2025
8bcfe4e
docs: minor corrections and clean-up (#1089)
pkooij May 9, 2025
a220c15
Update 10_use_so100.md; use diff syntax (#944)
mishig25 May 9, 2025
a1793bc
Update 12_use_so101.md (#1081)
CharlesCNorton May 9, 2025
d3fc33e
bug fix for #1071 When --display_data=true, Failed running control_ro…
masato-ka May 9, 2025
5517320
Add editable -e for feetech install command (#1133)
pkooij May 20, 2025
e099821
Fix: emptying action queue between resets (#1117)
fracapuano May 22, 2025
d162f08
fix: typos and grammar (#1148)
DeVikingMark May 25, 2025
06271de
Update README.md (#1160)
mshukor May 27, 2025
8427a73
Update README.md (#1163)
mshukor May 27, 2025
6eb03fb
[Fix] Unpin torch beyond 2.6.0 & torchcodec beyond 0.2.1 (#1127)
AdilZouitine May 28, 2025
8ad376a
(hotfix): nightly CI by clipping pymunk version below 7.0.0 (#1182)
AdilZouitine Jun 2, 2025
4583b99
[pre-commit.ci] pre-commit autoupdate (#1048)
pre-commit-ci[bot] Jun 2, 2025
eba1747
Add SmolVLA (#1175)
mshukor Jun 3, 2025
8457cfd
Fix SmolVLA loss not sent to wandb (#1198)
ben-z Jun 5, 2025
42bad54
Hardware API redesign (#777)
aliberts Jun 5, 2025
6384eee
fix(smolvla): update record.py, fix populate_queues and remove unused…
imstevenpmwork Jun 6, 2025
79d922d
replaced OBS_ROBOT with OBS_STATE constant (#1211)
utterwqlnut Jun 6, 2025
60feeaf
Fix test_teleoperate (#1216)
aliberts Jun 6, 2025
335c38b
Fix LeKiwi example (#1217)
aliberts Jun 6, 2025
83c0c1f
Fix smolVLA dependencies (#1218)
aliberts Jun 6, 2025
f9db80c
fix(pyserial): adding pyserial dependency to global ones (#1219)
CarolinePascal Jun 6, 2025
9ec68eb
Update SmolVLA README.md (#1228)
mshukor Jun 8, 2025
4da5d93
Fix unable to set camera width/height to non-default (#1225)
ben-z Jun 10, 2025
b275656
Update tutorial link (#1250)
Tiryoh Jun 10, 2025
3b32731
update KochFollower.get_observation() so it returns same observation …
skalade Jun 10, 2025
79b928e
[pre-commit.ci] pre-commit autoupdate (#1185)
pre-commit-ci[bot] Jun 10, 2025
36908fc
Proposal for fix for enter_pressed on Windows (#1230)
koenvanwijk Jun 10, 2025
235d8b3
fix: update pi0 dependency version constraint (#1247)
YushunXiang Jun 10, 2025
3305d2e
Match motor names with ids lekiwi (#1261)
pkooij Jun 11, 2025
bf99e98
fix issues: checkpoints keys mismatch and 'task' tokenisation in smol…
danaaubakirova Jun 11, 2025
50e6761
fix(docs): update realsense documentation (#1268)
imstevenpmwork Jun 11, 2025
456359c
Use HF Papers (#1120)
qgallouedec Jun 12, 2025
be64bd2
Skip normalization parameters in load_smolvla (#1274)
aliberts Jun 13, 2025
8460c7c
fix(record): no teleop needed when running with policy (#1284)
imstevenpmwork Jun 13, 2025
0d2800d
Port HIL SERL (#644)
AdilZouitine Jun 13, 2025
402f14a
fix(docs): SmolVLA fine-tuning getting started (#1201)
danaaubakirova Jun 13, 2025
c662f8c
chore(teleop): print calibration path saved (#1286)
imstevenpmwork Jun 13, 2025
bd4cc25
chore(dependencies): add gamepad support with pygame and hidapi (#1287)
AdilZouitine Jun 13, 2025
0bfc27d
Robot integration tutorial (#1285)
aliberts Jun 13, 2025
f775326
fix(docs): update send_feedback docstrings
imstevenpmwork Jun 13, 2025
88f137d
Add sim tutorial, fix lekiwi motor config, add notebook links (#1275)
pkooij Jun 13, 2025
814e48f
Fixes on robot integration tutorial (#1290)
aliberts Jun 13, 2025
c54e9d4
Add keyboard teleop device to control the end effector robot (#1289)
michel-aractingi Jun 14, 2025
3d920f7
Improve type hints (#1293)
tidely Jun 14, 2025
67d016b
fix(record): no teleop arg in reset environment (#1294)
imstevenpmwork Jun 14, 2025
27e47fe
`learner.py` import so101_leader instead of so100 (#1295)
michel-aractingi Jun 14, 2025
3933448
Fixing `PI0` Policy (#1297)
fracapuano Jun 14, 2025
2c83f2e
`gym_manipulator.py` Remove None value action_intervention of BaseLea…
michel-aractingi Jun 14, 2025
c27735a
(chore): incorrect resume parameter in recording documentation (#1301)
DavidLMS Jun 14, 2025
ee63451
Update lekiwi.mdx (#1229)
koenvanwijk Jun 14, 2025
16ce5e7
bump `pi0` and `hil` transformers version (#1298)
fracapuano Jun 15, 2025
04d46e5
docs: fix imitation learning robots docs command (#1308)
imstevenpmwork Jun 15, 2025
e46cccb
fix(benchmarks): remove .numpy() from frame in benchmark script (#1354)
imstevenpmwork Jun 19, 2025
fee9422
add smolvla to the supported policies to run tests (:
Jun 21, 2025
1317a99
add: chunk-level access for the policy
Jun 21, 2025
5c3119d
Merge branch 'main' into user/fracapuano/2025-04-23-predicting-chunks
fracapuano Jun 21, 2025
ab06419
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2025
7fbf738
add: smolvla in availables
Jun 21, 2025
b307886
remove: smolvla from library supported policies
Jun 26, 2025
197b3f4
fix: change env for training, xarm is broken as of now
Jun 26, 2025
35b7a08
add: predict_action_chunk to all supported policies
Jun 26, 2025
6a8f2bf
fix: add robot type constants
Jun 26, 2025
ffc3d8f
add: predict action chunk in base policy class
Jun 26, 2025
a4074f3
restore original Makefile
Jun 26, 2025
21734e5
fix: minor
Jun 26, 2025
8df1809
fix: dict keys come from lerobot/constants
Jun 26, 2025
342e3f2
fix: improve act encapsulation, properly supporting temporal ensembling
Jun 26, 2025
d0187a3
fix: smolvla action chunking
Jun 26, 2025
aa01e8c
fix: very minor, but very annoying
Jun 26, 2025
b70573e
fix: minor
Jun 26, 2025
bdb1f5c
fix minor naming
fracapuano Jun 26, 2025
cdeaf19
fix: refactoring inference for single actions and chunks into differe…
Jun 26, 2025
cba1e62
fix: minor
Jun 26, 2025
d4277d1
fix: temporal ensembling
Jun 26, 2025
e653a58
fix: moving populate queues out of modular component for batch prepar…
Jun 26, 2025
bcfb66e
Merge branch 'main' into user/fracapuano/2025-04-23-predicting-chunks
imstevenpmwork Jun 26, 2025
a7ba5ab
fix: minor for CI
Jun 26, 2025
9a2fe6a
fix: smovla debug
Jun 26, 2025
9617c90
fix: reward classifier, maybe the last policy lacking?
Jun 26, 2025
1 change: 1 addition & 0 deletions lerobot/common/constants.py
@@ -25,6 +25,7 @@
 REWARD = "next.reward"

 ROBOTS = "robots"
+ROBOT_TYPE = "robot_type"
 TELEOPERATORS = "teleoperators"

 # files & directories
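Aside (not part of the diff): the hunks below replace hard-coded dict keys with these constants. A minimal sketch of the pattern; the ACTION and OBS_IMAGES values are inferred from the literals they replace further down, and the batch is a toy placeholder.

    # Sketch of the constants pattern this PR leans on: one canonical string per
    # concept, imported everywhere instead of retyping the literal.
    ACTION = "action"  # inferred: replaces batch["action"] below
    OBS_IMAGES = "observation.images"  # inferred: replaces batch["observation.images"] below
    ROBOT_TYPE = "robot_type"  # the line added in this hunk

    batch = {ACTION: [0.0, 1.0]}
    assert batch[ACTION] == batch["action"]  # same key, minus stringly-typed typos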
36 changes: 20 additions & 16 deletions lerobot/common/policies/act/modeling_act.py
@@ -33,6 +33,7 @@
 from torchvision.models._utils import IntermediateLayerGetter
 from torchvision.ops.misc import FrozenBatchNorm2d

+from lerobot.common.constants import ACTION, OBS_IMAGES
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -114,46 +115,49 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
         queue is empty.
         """
-        self.eval()
+        self.eval()  # keeping the policy in eval mode as it could be set to train mode while queue is consumed

-        batch = self.normalize_inputs(batch)
-        if self.config.image_features:
-            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = [batch[key] for key in self.config.image_features]
-
         # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
         # we are ensembling over.
         if self.config.temporal_ensemble_coeff is not None:
-            actions = self.model(batch)[0]  # (batch_size, chunk_size, action_dim)
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+            actions = self.predict_action_chunk(batch)
             action = self.temporal_ensembler.update(actions)
             return action

         # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
         # querying the policy.
         if len(self._action_queue) == 0:
-            actions = self.model(batch)[0][:, : self.config.n_action_steps]
-
-            # TODO(rcadene): make _forward return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+            actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]

             # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
             # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
             self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        self.eval()
+
+        batch = self.normalize_inputs(batch)
+        if self.config.image_features:
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
+            batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
+
+        actions = self.model(batch)[0]
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+        return actions
+
     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = [batch[key] for key in self.config.image_features]
+            batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]

         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

         l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
+            F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
         ).mean()

         loss_dict = {"l1_loss": l1_loss.item()}
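Aside (not part of the diff): a minimal sketch of how the two inference entry points differ after this change. The checkpoint id and observation keys/shapes are illustrative assumptions, not taken from this PR; `from_pretrained` and `reset` come from the policy base class.

    # Hypothetical usage sketch: chunk-level vs. single-step inference on ACT.
    import torch

    from lerobot.common.policies.act.modeling_act import ACTPolicy

    policy = ACTPolicy.from_pretrained("lerobot/act_aloha_sim_transfer_cube_human")  # placeholder id
    policy.reset()

    batch = {
        "observation.state": torch.randn(1, 14),  # assumed feature keys and shapes
        "observation.images.top": torch.randn(1, 3, 480, 640),
    }

    # New in this PR: one call returns the whole (batch, chunk_size, action_dim) chunk.
    chunk = policy.predict_action_chunk(batch)

    # Existing API: returns one (batch, action_dim) action per call, refilling its
    # internal queue from predict_action_chunk only when the queue runs dry.
    action = policy.select_action(batch)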
36 changes: 19 additions & 17 deletions lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -33,7 +33,7 @@
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from torch import Tensor, nn

-from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
+from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -99,6 +99,18 @@ def reset(self):
         if self.config.env_state_feature:
             self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        # stack n latest observations from the queue
+        batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
+        actions = self.diffusion.generate_actions(batch)
+
+        # TODO(rcadene): make above methods return output dictionary?
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+
+        return actions
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
@@ -124,33 +136,23 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
         # Note: It's important that this happens after stacking the images into a single key.
         self._queues = populate_queues(self._queues, batch)

-        if len(self._queues["action"]) == 0:
-            # stack n latest observations from the queue
-            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
-            actions = self.diffusion.generate_actions(batch)
-
-            # TODO(rcadene): make above methods return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+        if len(self._queues[ACTION]) == 0:
+            actions = self.predict_action_chunk(batch)
+            self._queues[ACTION].extend(actions.transpose(0, 1))

-            self._queues["action"].extend(actions.transpose(0, 1))
-
-        action = self._queues["action"].popleft()
+        action = self._queues[ACTION].popleft()
         return action

     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         # no output_dict so returning None
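Aside (not part of the diff): the transpose-then-extend idiom shared by the queue-based policies above, reduced to a standalone sketch with invented shapes.

    # A (batch, n_steps, dim) chunk is transposed to (n_steps, batch, dim) so that
    # deque.extend() enqueues one batch-of-actions per timestep and popleft()
    # dequeues them in temporal order.
    from collections import deque

    import torch

    batch_size, n_action_steps, action_dim = 2, 8, 6
    queue = deque([], maxlen=n_action_steps)

    chunk = torch.randn(batch_size, n_action_steps, action_dim)
    queue.extend(chunk.transpose(0, 1))  # 8 entries, each of shape (2, 6)

    first = queue.popleft()
    assert torch.equal(first, chunk[:, 0])  # actions come out timestep by timestep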
5 changes: 5 additions & 0 deletions lerobot/common/policies/pi0/modeling_pi0.py
@@ -260,6 +260,11 @@ def reset(self):
     def get_optim_params(self) -> dict:
         return self.parameters()

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        raise NotImplementedError("Currently not implemented for PI0")
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.
5 changes: 5 additions & 0 deletions lerobot/common/policies/pi0fast/modeling_pi0fast.py
@@ -192,6 +192,11 @@ def _pi_aloha_encode_actions_inv(self, actions):
             actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
         return actions

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        raise NotImplementedError("Currently not implemented for PI0FAST")
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
9 changes: 9 additions & 0 deletions lerobot/common/policies/pretrained.py
@@ -171,6 +171,15 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
         """
         raise NotImplementedError

+    @abc.abstractmethod
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode.
+
+        Child classes using action chunking should use this method within `select_action` to form the action chunk
+        cached for selection.
+        """
+        raise NotImplementedError
+
     @abc.abstractmethod
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Return one action to run in the environment (potentially in batch mode).
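Aside (not part of the diff): what the new abstract contract asks of a subclass, as a self-contained sketch. `ToyChunkingPolicy` and its model are invented for illustration; a real implementation subclasses `PreTrainedPolicy` and handles normalization as in the files above.

    from collections import deque

    import torch
    from torch import Tensor


    class ToyChunkingPolicy:  # stand-in for a PreTrainedPolicy subclass
        def __init__(self, model, n_action_steps: int):
            self.model = model  # maps a batch to (batch, chunk_size, action_dim)
            self.n_action_steps = n_action_steps
            self._action_queue = deque([], maxlen=n_action_steps)

        @torch.no_grad
        def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
            # Chunk-level entry point: one model call, whole chunk back.
            return self.model(batch)

        @torch.no_grad
        def select_action(self, batch: dict[str, Tensor]) -> Tensor:
            # Single-step entry point built on top of the chunk-level one.
            if len(self._action_queue) == 0:
                chunk = self.predict_action_chunk(batch)[:, : self.n_action_steps]
                self._action_queue.extend(chunk.transpose(0, 1))
            return self._action_queue.popleft()


    policy = ToyChunkingPolicy(model=lambda b: torch.randn(1, 100, 6), n_action_steps=8)
    action = policy.select_action({"observation.state": torch.randn(1, 14)})
    assert action.shape == (1, 6)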
5 changes: 5 additions & 0 deletions lerobot/common/policies/sac/modeling_sac.py
@@ -76,6 +76,11 @@ def reset(self):
         """Reset the policy"""
         pass

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
+
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
@@ -308,6 +308,13 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """
         raise NotImplementedError("Reward classifiers do not select actions")

+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """
+        This method is required by PreTrainedPolicy but not used for reward classifiers.
+        The reward classifier is not an actor and does not produce action chunks.
+        """
+        raise NotImplementedError("Reward classifiers do not predict action chunks")
+
     def reset(self):
         """
         This method is required by PreTrainedPolicy but not used for reward classifiers.
69 changes: 44 additions & 25 deletions lerobot/common/policies/smolvla/modeling_smolvla.py
@@ -383,6 +383,45 @@ def _load_as_safetensor(
     def get_optim_params(self) -> dict:
         return self.parameters()

+    def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+        for k in batch:
+            if k in self._queues:
+                batch[k] = torch.stack(list(self._queues[k]), dim=1)
+
+        images, img_masks = self.prepare_images(batch)
+        state = self.prepare_state(batch)
+        lang_tokens, lang_masks = self.prepare_language(batch)
+
+        actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
+
+        # Unpad actions
+        original_action_dim = self.config.action_feature.shape[0]
+        actions = actions[:, :, :original_action_dim]
+
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+
+        if self.config.adapt_to_pi_aloha:
+            actions = self._pi_aloha_encode_actions(actions)
+
+        return actions
+
+    def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        if self.config.adapt_to_pi_aloha:
+            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
+
+        batch = self.normalize_inputs(batch)
+
+        return batch
+
+    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+        self.eval()
+
+        batch = self._prepare_batch(batch)
+        self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
+
+        actions = self._get_action_chunk(batch, noise)
+        return actions
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.
@@ -392,38 +431,18 @@ def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -
         queue is empty.
         """
         self.eval()

-        if self.config.adapt_to_pi_aloha:
-            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
-
-        batch = self.normalize_inputs(batch)
-
+        batch = self._prepare_batch(batch)
         self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])

         # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
         # querying the policy.
         if len(self._queues[ACTION]) == 0:
-            for k in batch:
-                if k in self._queues:
-                    batch[k] = torch.stack(list(self._queues[k]), dim=1)
-            images, img_masks = self.prepare_images(batch)
-            state = self.prepare_state(batch)
-            lang_tokens, lang_masks = self.prepare_language(batch)
-
-            actions = self.model.sample_actions(
-                images, img_masks, lang_tokens, lang_masks, state, noise=noise
-            )
-            # Unpad actions
-            original_action_dim = self.config.action_feature.shape[0]
-            actions = actions[:, :, :original_action_dim]
-
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+            actions = self._get_action_chunk(batch, noise)

-            if self.config.adapt_to_pi_aloha:
-                actions = self._pi_aloha_encode_actions(actions)
-
-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
+            # `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
             # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
             self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])

         return self._queues[ACTION].popleft()

     def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
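Aside (not part of the diff): the unpadding step inside `_get_action_chunk` is easy to miss. The model samples actions in a padded action space and the chunk is sliced back to the robot's true dimensionality before unnormalization; a standalone sketch with invented sizes.

    import torch

    batch_size, chunk_size = 2, 50
    max_action_dim, original_action_dim = 32, 6  # pad width vs. actual robot DoF (invented)

    padded = torch.randn(batch_size, chunk_size, max_action_dim)  # stand-in for sample_actions output
    actions = padded[:, :, :original_action_dim]  # keep only the real action dims

    assert actions.shape == (batch_size, chunk_size, original_action_dim)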