Merged

158 commits
3d509aa
Add draccus, create MainConfig
Dec 5, 2024
82f197b
WIP refactor train.py and ACT
Dec 18, 2024
bed1ec3
Add policies training presets
Dec 23, 2024
0ab28eb
Update diffusion policy
Dec 23, 2024
a82e004
Add pusht and xarm env configs
Dec 23, 2024
d2ca27a
Update tdmpc
Dec 23, 2024
250e380
Update vqbet
Dec 23, 2024
d8ad763
Fix poetry relax
Dec 23, 2024
928a417
Add feature types to envs
Dec 27, 2024
b5f3287
Add EvalPipelineConfig, parse features from envs
Dec 27, 2024
72e84f2
Add custom parser
Jan 6, 2025
f6443d9
Update pretrained loading mechanisms
Jan 6, 2025
06b604b
Add dependency fixes & lock update
Jan 6, 2025
4a4ef9b
Fix pretrained_path
Jan 6, 2025
68463a3
Refactor envs, remove RealEnv
Jan 7, 2025
2bdf1d2
Fix typo
Jan 7, 2025
9c6edc2
Enable end-to-end tests
Jan 7, 2025
a29a1f1
Fix Makefile
Jan 7, 2025
d83a94c
Log eval config
Jan 8, 2025
26eef6e
Fix end-to-end tests
Jan 8, 2025
e2508f7
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_11…
Jan 8, 2025
b799e02
Remove amp & add resume test
Jan 8, 2025
6c5667a
Speed-up tests
Jan 8, 2025
af96b04
Fix poetry relax
Jan 8, 2025
4261c5a
Remove config yaml for robot devices (#594)
Cadene Jan 9, 2025
6f62154
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_11…
Jan 9, 2025
02b996a
Fix logger
Jan 9, 2025
aa93aa1
Fix bug when wandb.enable=True
Cadene Jan 9, 2025
a69b425
Remove hydra-core
Jan 9, 2025
5871fe8
Remove NoneSchedulerConfig
Jan 9, 2025
3c5e8a5
Add push_pretrained
Jan 9, 2025
1eb8527
Remove eval.episode_length
Jan 9, 2025
abaf654
Fix wandb_video
Jan 9, 2025
6bd9e12
Fix typo
Jan 10, 2025
0d0f290
Fix wandb log + RL
Cadene Jan 11, 2025
fced457
Fix decoding with None not found for NormalizationMode. Replaced by I…
Cadene Jan 12, 2025
aa228c0
Add features back into policy configs (#643)
aliberts Jan 16, 2025
83174ea
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_11…
Jan 16, 2025
e8505a7
Add transformers
Cadene Jan 16, 2025
4eac86e
WIP Add Pi0
Cadene Jan 16, 2025
3623feb
Fix env_to_policy_features call
Jan 16, 2025
d33b7d8
WIP Add Pi0
Cadene Jan 16, 2025
e11e762
Fix wandb init
Jan 17, 2025
ae44cd0
remove omegaconf
Jan 17, 2025
b2feb9f
input of block works (xs[0][0,:256]) xs[0][0,-48:]
Cadene Jan 17, 2025
340ed02
Add branch arg
Jan 20, 2025
bbd1d84
fill_kv_cache working
Cadene Jan 20, 2025
14008ab
Move deprecated
Jan 21, 2025
95a1670
Move training config
Jan 21, 2025
b953595
Remove pathable_args
Jan 21, 2025
b319ce7
local paths
molbap Jan 21, 2025
73b690e
prefix works, debuging suffix
Cadene Jan 21, 2025
618e2f6
Merge branch 'user/rcadene/2025_01_14_pi0' into various_fixes
molbap Jan 21, 2025
2f0b8cc
fix sample_step (incl weight path)
molbap Jan 21, 2025
b579bd0
move backbone modeling out of pi0
molbap Jan 21, 2025
753461f
draft clean conversion script
molbap Jan 21, 2025
9ff3927
clean up conversion and init
molbap Jan 22, 2025
2f81239
fix PreTrainedModel interface
molbap Jan 22, 2025
ba1de7e
fix
molbap Jan 22, 2025
26226b8
move bits around
molbap Jan 22, 2025
635f390
add precision
molbap Jan 22, 2025
dff0322
fix, best results so far
molbap Jan 23, 2025
c6fde78
Can evaluate in simulation
Cadene Jan 23, 2025
e971fcc
Add noise as input
Cadene Jan 23, 2025
c14c471
merge
Cadene Jan 23, 2025
0fe1a6a
small fix
Cadene Jan 23, 2025
1a7e55a
....706
Cadene Jan 23, 2025
a1ebc44
....7118
Cadene Jan 23, 2025
7dc28d4
fix attention_mask broadcast
Cadene Jan 23, 2025
3bfdc5e
Implement custom HubMixin
Jan 23, 2025
5c8d5bd
Fixes
Jan 23, 2025
af36a88
Implement PreTrainedPolicy base class
Jan 23, 2025
a12f474
Add HubMixin to TrainPipelineConfig
Jan 23, 2025
8e3777a
Udpate example 2 & 3
Jan 23, 2025
71212b8
Update push_pretrained
Jan 23, 2025
309ae9c
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_11…
Jan 23, 2025
136bf50
Fix config_class
Jan 23, 2025
17358d6
Fix from_pretrained kwargs
Jan 23, 2025
0862e20
Remove policy_protocol
Jan 23, 2025
25f9020
Camelize PretrainedConfig
Jan 23, 2025
4e7c4dd
Additional fix while retraining policies (#629)
Cadene Jan 24, 2025
48d1817
Actually reactivate tdmpc online test
Jan 24, 2025
7824a59
Enable compute loss, backward
Cadene Jan 24, 2025
0d7fa33
local
molbap Jan 24, 2025
1ee59ed
Update example 4
Jan 24, 2025
3b1e64b
Remove advanced example 1
Jan 24, 2025
a5ab25a
Remove example 5
Jan 24, 2025
5d112c6
Move example 6 to advanced
Jan 24, 2025
220f818
Use HubMixin.save_pretrained
Jan 24, 2025
d2d536a
Enable config_path to be a repo_id
Jan 25, 2025
6c4bc32
Dry has_method
Jan 25, 2025
119b269
Update example 4
Jan 25, 2025
2d7d533
Update README
Jan 25, 2025
457e4bc
Cleanup pyproject.toml
Jan 25, 2025
22455a6
Add config resize, pad image, normalization -1 1 in model
Cadene Jan 25, 2025
645c440
fix precision issue
Cadene Jan 25, 2025
c923461
Add self.config.train_expert_only
Cadene Jan 25, 2025
ca3e362
Update eval docstring
Jan 25, 2025
d3ef145
Update README
Jan 25, 2025
2bf3d75
Clean example 4
Jan 25, 2025
105e9b8
Update README
Jan 25, 2025
b0771b7
Make 'last' checkpoint symlink relative
Jan 26, 2025
225c4f6
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_11…
Jan 26, 2025
8ac9429
Simplify example 4
Jan 26, 2025
080d8b0
Update docstrings
Jan 26, 2025
c53167c
Change default device selection, add warnings & errors
Jan 26, 2025
fc28057
add torch.compile to train
molbap Jan 27, 2025
deb51ce
Hard code dataset_stats in policy, Set fix noise=False, train_export_…
Cadene Jan 27, 2025
0eda592
Merge, Disable torch.compile
Cadene Jan 27, 2025
967ae99
Use validate instead of __post_init__
Jan 27, 2025
56a6f58
Fix
Jan 27, 2025
8cca5bd
Remove validate
Jan 27, 2025
d06505f
Fix tests
Jan 27, 2025
85c6f3a
Skip push_dataset_to_hub tests
Jan 28, 2025
ad458b6
Add exceptions
Jan 28, 2025
3551dc5
Update factories docstrings
Jan 28, 2025
d86fc23
Update validations
Jan 28, 2025
940e9d8
Remove deprecated config files
Jan 28, 2025
214083f
Update robot examples with draccus commands (#654)
Cadene Jan 28, 2025
57512d1
Add eval mode when requires_grad=False
Cadene Jan 28, 2025
aa65bb7
Add pusht hack
Jan 28, 2025
584691c
Fix
Jan 28, 2025
bac217c
Fix logging
Jan 28, 2025
742848b
Fix logging
Jan 28, 2025
cb18417
Fix policy factory
Jan 28, 2025
9f85df2
Simplify config validation logic
Jan 28, 2025
693810f
Update example 4
Jan 28, 2025
03da0a8
Revert "Add pusht hack"
Jan 29, 2025
ee67f8b
remove unused module
Cadene Jan 29, 2025
20501b7
WIP add EMA
Cadene Jan 29, 2025
f2b9845
Port base Pi0
Cadene Jan 29, 2025
68056cc
Train dana dataset with pi0 base
Cadene Jan 29, 2025
b845391
Merge remote-tracking branch 'origin/user/aliberts/2024_11_30_remove_…
Cadene Jan 29, 2025
4a25bec
Fix --control.policy.path
Cadene Jan 29, 2025
af38ed8
Add get_safe_dtype
Cadene Jan 29, 2025
7976f4f
Fix eval on real robots
Cadene Jan 30, 2025
6246399
Works
Cadene Jan 31, 2025
96050a9
Latest model (DOESNT WORK YET)
Cadene Feb 2, 2025
9eae715
add flex attention kernel
molbap Feb 2, 2025
d5114ac
fix conversion
molbap Feb 2, 2025
b9a8195
Merge branch 'main' of github.com:huggingface/lerobot into pi0
Cadene Feb 3, 2025
cb97aed
nit
Cadene Feb 3, 2025
1d3ea1d
Add comments + benchmark
Cadene Feb 3, 2025
45c65fb
benchmark no local dir, use_false=False
Cadene Feb 3, 2025
d317024
style
Cadene Feb 3, 2025
2ca3185
comment
Cadene Feb 3, 2025
f7ee800
fix
Cadene Feb 3, 2025
d4459df
fix
Cadene Feb 3, 2025
a766d20
fix
Cadene Feb 3, 2025
a218e96
Merge remote-tracking branch 'origin/main' into pi0
Cadene Feb 3, 2025
ff311a9
fix
Cadene Feb 3, 2025
5ea1e4a
fix unit tests
Cadene Feb 3, 2025
945e33d
fix
Cadene Feb 3, 2025
03f9ef7
fix
Cadene Feb 3, 2025
f3c27d4
fix
Cadene Feb 3, 2025
1a63e27
fix
Cadene Feb 3, 2025
0f8875b
Refactor
Cadene Feb 4, 2025
4ed481e
remove test-tdmpc-ete-train-with-online
Cadene Feb 4, 2025
52 changes: 26 additions & 26 deletions Makefile
@@ -26,7 +26,6 @@ test-end-to-end:
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train-with-online

test-act-ete-train:
python lerobot/scripts/train.py \
@@ -128,28 +127,29 @@ test-tdmpc-ete-eval:
--eval.batch_size=1 \
--device=$(DEVICE)

test-tdmpc-ete-train-with-online:
python lerobot/scripts/train.py \
--policy.type=tdmpc \
--env.type=pusht \
--env.obs_type=environment_state_agent_pos \
--env.episode_length=5 \
--dataset.repo_id=lerobot/pusht_keypoints \
--dataset.image_transforms.enable=true \
--dataset.episodes="[0]" \
--batch_size=2 \
--offline.steps=2 \
--online.steps=20 \
--online.rollout_n_episodes=2 \
--online.rollout_batch_size=2 \
--online.steps_between_rollouts=10 \
--online.buffer_capacity=1000 \
--online.env_seed=10000 \
--save_checkpoint=false \
--save_freq=10 \
--log_freq=1 \
--eval.use_async_envs=true \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--device=$(DEVICE) \
--output_dir=tests/outputs/tdmpc_online/
# TODO(rcadene): fix online buffer to store "task"
# test-tdmpc-ete-train-with-online:
# python lerobot/scripts/train.py \
# --policy.type=tdmpc \
# --env.type=pusht \
# --env.obs_type=environment_state_agent_pos \
# --env.episode_length=5 \
# --dataset.repo_id=lerobot/pusht_keypoints \
# --dataset.image_transforms.enable=true \
# --dataset.episodes="[0]" \
# --batch_size=2 \
# --offline.steps=2 \
# --online.steps=20 \
# --online.rollout_n_episodes=2 \
# --online.rollout_batch_size=2 \
# --online.steps_between_rollouts=10 \
# --online.buffer_capacity=1000 \
# --online.env_seed=10000 \
# --save_checkpoint=false \
# --save_freq=10 \
# --log_freq=1 \
# --eval.use_async_envs=true \
# --eval.n_episodes=1 \
# --eval.batch_size=1 \
# --device=$(DEVICE) \
# --output_dir=tests/outputs/tdmpc_online/
4 changes: 4 additions & 0 deletions lerobot/common/datasets/lerobot_dataset.py
@@ -672,6 +672,10 @@ def __getitem__(self, idx) -> dict:
for cam in image_keys:
item[cam] = self.image_transforms(item[cam])

# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks[task_idx]

return item

def __repr__(self):
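With this change, each item returned by `__getitem__` also carries its natural-language task alongside the existing `task_index`. A minimal sketch of the resulting behavior, assuming a standard `LeRobotDataset` (the repo_id and printed string are illustrative):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")  # illustrative repo_id
item = dataset[0]

task_idx = item["task_index"].item()       # integer index into dataset.meta.tasks
print(item["task"])                        # task string, e.g. "Push the T to the target."
assert item["task"] == dataset.meta.tasks[task_idx]
```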
18 changes: 16 additions & 2 deletions lerobot/common/optim/optimizers.py
@@ -8,8 +8,6 @@
@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
lr: float
betas: tuple[float, float]
eps: float
weight_decay: float
grad_clip_norm: float

@@ -54,3 +52,19 @@ def build(self, params: dict) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.AdamW(params, **kwargs)


@OptimizerConfig.register_subclass("sgd")
@dataclass
class SGDConfig(OptimizerConfig):
lr: float = 1e-3
momentum: float = 0.0
dampening: float = 0.0
nesterov: bool = False
weight_decay: float = 0.0
grad_clip_norm: float = 10.0

def build(self, params: dict) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.SGD(params, **kwargs)
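Moving `betas` and `eps` off the abstract base class is what lets a non-Adam optimizer like SGD subclass `OptimizerConfig` cleanly. A minimal usage sketch for the new `SGDConfig` (the model is a stand-in):

```python
import torch

from lerobot.common.optim.optimizers import SGDConfig

model = torch.nn.Linear(4, 2)  # stand-in module
cfg = SGDConfig(lr=1e-2, momentum=0.9, nesterov=True)
optimizer = cfg.build(model.parameters())
# Note: grad_clip_norm is popped before torch.optim.SGD is constructed;
# gradient clipping is applied separately in the training loop.
```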
35 changes: 35 additions & 0 deletions lerobot/common/optim/schedulers.py
@@ -54,3 +54,38 @@ def lr_lambda(current_step):
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))

return LambdaLR(optimizer, lr_lambda, -1)


@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
@dataclass
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
"""Used by Physical Intelligence to train Pi0"""

num_warmup_steps: int
num_decay_steps: int
peak_lr: float
decay_lr: float

def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
del num_training_steps

def lr_lambda(current_step):
def linear_warmup_schedule(current_step):
if current_step <= 0:
return 1 / (self.num_warmup_steps + 1)
frac = 1 - current_step / self.num_warmup_steps
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1

def cosine_decay_schedule(current_step):
step = min(current_step, self.num_decay_steps)
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
alpha = self.decay_lr / self.peak_lr
decayed = (1 - alpha) * cosine_decay + alpha
return decayed

if current_step < self.num_warmup_steps:
return linear_warmup_schedule(current_step)

return cosine_decay_schedule(current_step)

return LambdaLR(optimizer, lr_lambda, -1)
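For intuition, here is a standalone sketch of the multiplier this `lr_lambda` produces. Since the optimizer's base lr is set to `peak_lr`, the effective lr is `peak_lr` times the multiplier: it ramps roughly linearly from `1/(num_warmup_steps + 1)` up to 1, then follows a cosine down to `decay_lr` and stays flat. The constants below are the pi0 presets from `configuration_pi0.py`:

```python
import math

W, D = 1_000, 30_000           # num_warmup_steps, num_decay_steps
peak, decay = 2.5e-5, 2.5e-6   # peak_lr, decay_lr

def multiplier(step: int) -> float:
    if step < W:               # linear warmup from ~1/(W+1) up to 1
        if step <= 0:
            return 1 / (W + 1)
        frac = 1 - step / W
        return (1 / (W + 1) - 1) * frac + 1
    s = min(step, D)           # cosine decay toward decay/peak, then flat
    alpha = decay / peak
    return (1 - alpha) * 0.5 * (1 + math.cos(math.pi * s / D)) + alpha

for step in (0, 500, 1_000, 15_000, 30_000, 50_000):
    print(step, peak * multiplier(step))
```

Note that the cosine phase is computed from step 0 rather than from the end of warmup, which matches the implementation above.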
1 change: 1 addition & 0 deletions lerobot/common/policies/__init__.py
@@ -1,4 +1,5 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
9 changes: 9 additions & 0 deletions lerobot/common/policies/factory.py
@@ -25,6 +25,7 @@
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
@@ -50,6 +51,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy

return VQBeTPolicy
elif name == "pi0":
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy

return PI0Policy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")

@@ -63,6 +68,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")

@@ -141,4 +148,6 @@ def make_policy(
policy.to(device)
assert isinstance(policy, nn.Module)

# policy = torch.compile(policy, mode="reduce-overhead")

return policy
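A short sketch of the factory round-trip for the newly registered policy type, using only the functions shown above:

```python
from lerobot.common.policies.factory import get_policy_class, make_policy_config

cfg = make_policy_config("pi0", chunk_size=50, n_action_steps=50)
policy_cls = get_policy_class("pi0")  # lazily imports PI0Policy
print(type(cfg).__name__, policy_cls.__name__)  # -> PI0Config PI0Policy
```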
6 changes: 6 additions & 0 deletions lerobot/common/policies/normalize.py
@@ -140,6 +140,9 @@ def __init__(
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
continue

norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
@@ -210,6 +213,9 @@ def __init__(
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
continue

norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
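The `if key not in batch: continue` guard makes normalization tolerant of batches that omit some configured features (e.g. optional or empty cameras). A sketch under the assumed constructor signature `Normalize(features, norm_map, stats)`; the feature names and shapes are illustrative:

```python
import torch

from lerobot.common.policies.normalize import Normalize
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

features = {
    "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
    "observation.images.top": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 96, 96)),
}
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
stats = {"observation.state": {"mean": torch.zeros(2), "std": torch.ones(2)}}

normalize = Normalize(features, norm_map, stats)
batch = {"observation.state": torch.randn(1, 2)}  # image key absent
out = normalize(batch)  # previously a KeyError; the missing key is now skipped
```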
134 changes: 134 additions & 0 deletions lerobot/common/policies/pi0/configuration_pi0.py
@@ -0,0 +1,134 @@
from dataclasses import dataclass, field

from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature


@PreTrainedConfig.register_subclass("pi0")
@dataclass
class PI0Config(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50

normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)

# Shorter state and action vectors will be padded
max_state_dim: int = 32
max_action_dim: int = 32

# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (224, 224)

# Add empty images. Used by pi0_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0

# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False

# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False

# Tokenizer
tokenizer_max_length: int = 48

# Projector
proj_width: int = 1024

# Decoding
num_steps: int = 10

# Attention utils
use_cache: bool = True
attention_implementation: str = "eager" # or fa2, flex

# Finetuning settings
freeze_vision_encoder: bool = True
train_expert_only: bool = False
train_state_proj: bool = True

# Training presets
optimizer_lr: float = 2.5e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10

scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6

# TODO: Add EMA

def __post_init__(self):
super().__post_init__()

"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)

if self.use_delta_joint_actions_aloha:
raise NotImplementedError(
"`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
)

def validate_features(self) -> None:
# TODO: implement value error
# if not self.image_features and not self.env_state_feature:
# raise ValueError("You must provide at least one image or the environment state among the inputs.")

for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
)
self.input_features[key] = empty_camera

def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)

def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)

@property
def observation_delta_indices(self) -> None:
return None

@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))

@property
def reward_delta_indices(self) -> None:
return None
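A sketch of the config's construction-time validation and delta indices, using only the defaults defined above (assuming the base `PreTrainedConfig` defaults suffice for standalone construction):

```python
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config

cfg = PI0Config()  # defaults: chunk_size=50, n_action_steps=50
assert cfg.action_delta_indices == list(range(50))
assert cfg.observation_delta_indices is None

try:
    PI0Config(n_action_steps=60, chunk_size=50)
except ValueError as err:
    print(err)  # chunk_size is the upper bound for n_action_steps
```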