97 changes: 96 additions & 1 deletion lerobot/common/policies/pretrained.py
@@ -14,11 +14,13 @@
import abc
import logging
import os
import re
from pathlib import Path
from typing import Type, TypeVar
from typing import Dict, Tuple, Type, TypeVar

import packaging
import safetensors
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
@@ -42,6 +44,92 @@
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""

# Matches a dataset-variant marker such as ".so100_buffer_" or ".so100-blue_buffer_":
# ".so" + digits, an optional "-suffix", and the trailing "_buffer_" marker
_VARIANT_RE = re.compile(r"\.so\d+(?:-\w+)?_buffer_")


def canonicalise(k: str) -> str:
    """
    Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
    normalisation-buffer key.
    """
    return _VARIANT_RE.sub(".buffer_", k)

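For illustration, a variant-tagged buffer key collapses to its reference form (the key name below is made up, not taken from a real checkpoint):

    canonicalise("normalize_inputs.so100-blue_buffer_observation_state.mean")
    # -> "normalize_inputs.buffer_observation_state.mean"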

def standardise_state_dict(
    ckpt: Dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
) -> Tuple[Dict[str, torch.Tensor], list[str]]:
    """
    • Re-keys `ckpt` so that every entry matches the *reference* key set.
    • If several variant keys collapse to the same canonical name, the first
      one wins and the collision is reported.
    • Returns the re-keyed dict plus a list of keys that could not be matched;
      unmatched keys are kept unchanged in the output.
    """
    out, collisions, unmatched = {}, {}, []

    for k, v in ckpt.items():
        canon = canonicalise(k)
        if canon in ref_keys:
            if canon in out:  # duplicate after collapsing
                collisions.setdefault(canon, []).append(k)
            else:
                out[canon] = v
        else:
            unmatched.append(k)

    if verbose:
        for canon, variants in collisions.items():
            print(f"[standardise_state_dict] '{canon}' ← {variants}")
        if unmatched:
            print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")

    out.update({k: ckpt[k] for k in unmatched})
    return out, unmatched

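A minimal sketch of the intended behaviour, with made-up keys:

    ref = {"normalize_inputs.buffer_observation_state.mean"}
    ckpt = {
        "normalize_inputs.so100_buffer_observation_state.mean": torch.zeros(6),
        "normalize_inputs.so100-blue_buffer_observation_state.mean": torch.ones(6),
    }
    out, unmatched = standardise_state_dict(ckpt, ref)
    # Both keys collapse to the same canonical name: the first one wins, the
    # collision is printed, and `unmatched` comes back empty.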

def rename_checkpoint_keys(ckpt, rename_str):
    """
    Renames keys in a checkpoint dictionary based on the given rename string.

    Args:
        ckpt (dict): The checkpoint dictionary.
        rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".

    Returns:
        dict: The modified checkpoint with renamed keys.
    """

    rename_dict = dict(pair.split("//") for pair in rename_str.split(","))

    new_ckpt = {}
    for k, v in ckpt.items():
        for old_key, new_key in rename_dict.items():
            if old_key in k:
                k = k.replace(old_key, new_key)
        new_ckpt[k] = v
    return new_ckpt

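The mapping string is a comma-separated list of "old//new" substring pairs; every occurrence of the left-hand side is replaced in each key (hypothetical key below):

    ckpt = {"model._orig_mod.backbone.weight": torch.zeros(1)}
    rename_checkpoint_keys(ckpt, "model._orig_mod.//model.")
    # -> {"model.backbone.weight": tensor([0.])}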

def load_model(
    model: torch.nn.Module,
    filename: str | os.PathLike,
    *,
    strict: bool = True,
    device: str | int = "cpu",
Collaborator: What does it mean to have an int device?

Collaborator Author: resolved

    checkpoint_keys_mapping: str = "",
) -> tuple[list[str], list[str]]:
    state_dict = safetensors.torch.load_file(filename, device=device)

    # Optional user-supplied renames, e.g. "model._orig_mod.//model." to strip
    # the prefix that torch.compile adds to parameter names
    if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
        state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)

    state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))

    # Keys are matched leniently here; missing and unexpected entries are
    # returned to the caller instead of raising.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    return missing, unexpected


class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
"""
Expand Down Expand Up @@ -148,6 +236,13 @@ def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, stric
                model.to(map_location)
        else:
            safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
            missing, unexpected = load_model(
                model,
                model_file,
                strict=strict,
                device=map_location,
                checkpoint_keys_mapping="model._orig_mod.//model.",
Collaborator: Where's this string mapping coming from?
            )
        return model

    # def generate_model_card(self, *args, **kwargs) -> ModelCard:
4 changes: 4 additions & 0 deletions lerobot/common/policies/smolvla/modeling_smolvla.py
@@ -387,10 +387,14 @@ def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_STATE].device
tasks = batch["task"]
if isinstance(tasks, str):
tasks = [tasks]

if len(tasks) == 1:
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]

tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]

tokenized_prompt = self.language_tokenizer.__call__(
tasks,
padding=self.config.pad_language_to,
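For reference, the added normalisation turns a bare string task into a batch-sized, newline-terminated list (illustrative values; assume a batch of 2):

    tasks = "pick up the cube"
    if isinstance(tasks, str):
        tasks = [tasks]
    if len(tasks) == 1:
        tasks = [tasks[0] for _ in range(2)]  # broadcast to the batch size
    tasks = [t if t.endswith("\n") else f"{t}\n" for t in tasks]
    # -> ["pick up the cube\n", "pick up the cube\n"]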