Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions lerobot/common/policies/smolvla/modeling_smolvla.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,11 @@
"""

import math
import os
import re
from collections import deque

import safetensors
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
Expand All @@ -73,6 +76,98 @@
)
from lerobot.common.utils.utils import get_safe_dtype

# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")


def canonicalise(k: str) -> str:
"""
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
normalisation-buffer key.
"""
return _VARIANT_RE.sub(".buffer_", k)


def standardise_state_dict(
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
) -> tuple[dict[str, torch.Tensor], list[str]]:
"""
• Re-keys `checkpoint ` so that every entry matches the *reference* key set.
• If several variant keys collapse to the same canonical name we keep the
first one and log the collision.
• Returns the new dict + a list of entries that could not be matched.
"""
out, collisions, unmatched = {}, {}, []

for k, v in checkpoint.items():
canon = canonicalise(k)
if canon in ref_keys:
if canon in out: # duplicate after collapsing
collisions.setdefault(canon, []).append(k)
else:
out[canon] = v
else:
unmatched.append(k)

if verbose:
for canon, variants in collisions.items():
print(f"[standardise_state_dict] '{canon}' ← {variants}")
if unmatched:
print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")

out.update({k: checkpoint[k] for k in unmatched})
return out, unmatched


def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
"""
Renames keys in a checkpoint dictionary based on the given rename string.

Args:
checkpoint (dict): The checkpoint dictionary.
rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".

Returns:
dict: The modified checkpoint with renamed keys.
"""

rename_dict = dict(pair.split("//") for pair in rename_str.split(","))

new_checkpoint = {}
for k, v in checkpoint.items():
for old_key, new_key in rename_dict.items():
if old_key in k:
k = k.replace(old_key, new_key)
new_checkpoint[k] = v
return new_checkpoint


def load_smolvla(
model: torch.nn.Module,
filename: str | os.PathLike,
*,
device: str = "cpu",
checkpoint_keys_mapping: str = "",
) -> torch.nn.Module:
state_dict = safetensors.torch.load_file(filename, device=device)

# Optional user-supplied renames (e.g. "model._orig_mod.//model.")
if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)

state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))

missing, unexpected = model.load_state_dict(state_dict)

if missing or unexpected:
raise RuntimeError(
"SmolVLA %d missing / %d unexpected keys",
len(missing),
len(unexpected),
)

return model


def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
Expand Down Expand Up @@ -264,6 +359,23 @@ def reset(self):
ACTION: deque(maxlen=self.config.n_action_steps),
}

# HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
@classmethod
def _load_as_safetensor(
cls,
model: "SmolVLAPolicy",
model_file: str,
map_location: str,
strict: bool,
):
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
return load_smolvla(
model,
model_file,
device=map_location,
checkpoint_keys_mapping="model._orig_mod.//model.",
)

def get_optim_params(self) -> dict:
return self.parameters()

Expand Down Expand Up @@ -387,10 +499,14 @@ def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_STATE].device
tasks = batch["task"]
if isinstance(tasks, str):
tasks = [tasks]

if len(tasks) == 1:
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]

tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]

tokenized_prompt = self.language_tokenizer.__call__(
tasks,
padding=self.config.pad_language_to,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ intelrealsense = [
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
]
pi0 = ["transformers>=4.48.0"]
smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0"]
smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
stretch = [
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
Expand Down
Loading