
Commit efd26d6

mshukor, pre-commit-ci[bot], fracapuano, imstevenpmwork, danaaubakirova
authored and committed
Add SmolVLA (#1175)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: fracapuano <[email protected]>
Co-authored-by: Steven Palma <[email protected]>
Co-authored-by: Dana Aubakirova <[email protected]>
Co-authored-by: Remi <[email protected]>
1 parent cf0669f commit efd26d6

File tree: 9 files changed, +1523 −13 lines


lerobot/__init__.py
Lines changed: 1 addition & 6 deletions

@@ -168,12 +168,7 @@
 )
 
 # lists all available policies from `lerobot/common/policies`
-available_policies = [
-    "act",
-    "diffusion",
-    "tdmpc",
-    "vqbet",
-]
+available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
 
 # lists all available robots from `lerobot/common/robot_devices/robots`
 available_robots = [

lerobot/common/policies/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -15,5 +15,6 @@
 from .act.configuration_act import ACTConfig as ACTConfig
 from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
 from .pi0.configuration_pi0 import PI0Config as PI0Config
+from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
 from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
 from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
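
With this re-export in place, the config class is importable from the policies package directly. A minimal usage sketch, assuming a LeRobot checkout that includes this commit:

    from lerobot.common.policies import SmolVLAConfig

    cfg = SmolVLAConfig()
    print(cfg.chunk_size)  # 50 by default (see configuration_smolvla.py below)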

lerobot/common/policies/factory.py
Lines changed: 7 additions & 0 deletions

@@ -27,6 +27,7 @@
 from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
 from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
 from lerobot.common.policies.pretrained import PreTrainedPolicy
+from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
 from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
 from lerobot.configs.policies import PreTrainedConfig
@@ -59,6 +60,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
         from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
 
         return PI0FASTPolicy
+    elif name == "smolvla":
+        from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
+
+        return SmolVLAPolicy
     else:
         raise NotImplementedError(f"Policy with name {name} is not implemented.")
 
@@ -76,6 +81,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return PI0Config(**kwargs)
     elif policy_type == "pi0fast":
         return PI0FASTConfig(**kwargs)
+    elif policy_type == "smolvla":
+        return SmolVLAConfig(**kwargs)
     else:
         raise ValueError(f"Policy type '{policy_type}' is not available.")
 
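
With both branches registered, the factory resolves SmolVLA by name like any other policy. A usage sketch based only on the two functions above (the keyword override is illustrative):

    from lerobot.common.policies.factory import get_policy_class, make_policy_config

    config = make_policy_config("smolvla", n_action_steps=10)
    policy_cls = get_policy_class("smolvla")  # imports SmolVLAPolicy lazily, inside the branch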

lerobot/common/policies/smolvla/configuration_smolvla.py
Lines changed: 154 additions & 0 deletions

@@ -0,0 +1,154 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+
+from lerobot.common.optim.optimizers import AdamWConfig
+from lerobot.common.optim.schedulers import (
+    CosineDecayWithWarmupSchedulerConfig,
+)
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+
+
+@PreTrainedConfig.register_subclass("smolvla")
+@dataclass
+class SmolVLAConfig(PreTrainedConfig):
+    # Input / output structure.
+    n_obs_steps: int = 1
+    chunk_size: int = 50
+    n_action_steps: int = 50
+
+    normalization_mapping: dict[str, NormalizationMode] = field(
+        default_factory=lambda: {
+            "VISUAL": NormalizationMode.IDENTITY,
+            "STATE": NormalizationMode.MEAN_STD,
+            "ACTION": NormalizationMode.MEAN_STD,
+        }
+    )
+
+    # Shorter state and action vectors will be padded.
+    max_state_dim: int = 32
+    max_action_dim: int = 32
+
+    # Image preprocessing.
+    resize_imgs_with_padding: tuple[int, int] = (512, 512)
+
+    # Add empty images. Used by smolvla_aloha_sim, which adds the empty
+    # left and right wrist cameras in addition to the top camera.
+    empty_cameras: int = 0
+
+    # Converts the joint and gripper values from the standard Aloha space to
+    # the space used by the pi internal runtime, which was used to train the base model.
+    adapt_to_pi_aloha: bool = False
+
+    # Converts joint dimensions to deltas with respect to the current state before passing to the model.
+    # Gripper dimensions will remain in absolute values.
+    use_delta_joint_actions_aloha: bool = False
+
+    # Tokenizer
+    tokenizer_max_length: int = 48
+
+    # Decoding
+    num_steps: int = 10
+
+    # Attention utils
+    use_cache: bool = True
+
+    # Finetuning settings
+    freeze_vision_encoder: bool = True
+    train_expert_only: bool = True
+    train_state_proj: bool = True
+
+    # Training presets
+    optimizer_lr: float = 1e-4
+    optimizer_betas: tuple[float, float] = (0.9, 0.95)
+    optimizer_eps: float = 1e-8
+    optimizer_weight_decay: float = 1e-10
+    optimizer_grad_clip_norm: float = 10
+
+    scheduler_warmup_steps: int = 1_000
+    scheduler_decay_steps: int = 30_000
+    scheduler_decay_lr: float = 2.5e-6
+
+    vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"  # Select the VLM backbone.
+    load_vlm_weights: bool = False  # Set to True when training the expert from scratch. True when initializing from pretrained SmolVLA weights.
+
+    add_image_special_tokens: bool = False  # Whether to use special image tokens around image features.
+
+    attention_mode: str = "cross_attn"
+
+    prefix_length: int = -1
+
+    pad_language_to: str = "longest"  # "max_length"
+
+    num_expert_layers: int = -1  # Less than or equal to 0 (the default): the action expert has the same number of layers as the VLM. Otherwise the expert has fewer layers.
+    num_vlm_layers: int = 16  # Number of layers used in the VLM (first num_vlm_layers layers).
+    self_attn_every_n_layers: int = 2  # Interleave self-attention layers every self_attn_every_n_layers layers.
+    expert_width_multiplier: float = 0.75  # The action expert hidden size (relative to the VLM's).
+
+    min_period: float = 4e-3  # Sensitivity range for the timestep used in sine-cosine positional encoding.
+    max_period: float = 4.0
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        """Input validation (not exhaustive)."""
+        if self.n_action_steps > self.chunk_size:
+            raise ValueError(
+                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
+                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
+            )
+        if self.use_delta_joint_actions_aloha:
+            raise NotImplementedError(
+                "`use_delta_joint_actions_aloha` is used by smolvla for real Aloha models. It is not yet ported to LeRobot."
+            )
+
+    def validate_features(self) -> None:
+        for i in range(self.empty_cameras):
+            key = f"observation.images.empty_camera_{i}"
+            empty_camera = PolicyFeature(
+                type=FeatureType.VISUAL,
+                shape=(3, 480, 640),
+            )
+            self.input_features[key] = empty_camera
+
+    def get_optimizer_preset(self) -> AdamWConfig:
+        return AdamWConfig(
+            lr=self.optimizer_lr,
+            betas=self.optimizer_betas,
+            eps=self.optimizer_eps,
+            weight_decay=self.optimizer_weight_decay,
+            grad_clip_norm=self.optimizer_grad_clip_norm,
+        )
+
+    def get_scheduler_preset(self):
+        return CosineDecayWithWarmupSchedulerConfig(
+            peak_lr=self.optimizer_lr,
+            decay_lr=self.scheduler_decay_lr,
+            num_warmup_steps=self.scheduler_warmup_steps,
+            num_decay_steps=self.scheduler_decay_steps,
+        )
+
+    @property
+    def observation_delta_indices(self) -> list:
+        return [0]
+
+    @property
+    def action_delta_indices(self) -> list:
+        return list(range(self.chunk_size))
+
+    @property
+    def reward_delta_indices(self) -> None:
+        return None
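
A short sketch of the behavior defined above, using only names from this file: __post_init__ rejects n_action_steps larger than chunk_size, and the preset getters return the training optimizer and scheduler configs.

    from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig

    cfg = SmolVLAConfig()
    optim = cfg.get_optimizer_preset()   # AdamWConfig(lr=1e-4, ..., grad_clip_norm=10)
    sched = cfg.get_scheduler_preset()   # cosine decay with 1_000 warmup and 30_000 decay steps
    print(cfg.action_delta_indices)      # [0, 1, ..., 49]: one index per action in the chunk

    try:
        SmolVLAConfig(chunk_size=50, n_action_steps=100)
    except ValueError as err:
        print(err)  # chunk size is the upper bound for n_action_steps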
