-
Notifications
You must be signed in to change notification settings - Fork 3.4k
fix issues: checkpoints keys mismatch and 'task' tokenisation in smolvla #1256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
d276ac7
8ce11f7
9ff4760
080be39
cf3e07c
f58ef7f
5947cb5
83af770
d0af78f
6c98188
5c3820c
380fd05
4e1b049
0d7f089
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,11 +14,13 @@ | |
| import abc | ||
| import logging | ||
| import os | ||
| import re | ||
| from pathlib import Path | ||
| from typing import Type, TypeVar | ||
| from typing import Dict, Tuple, Type, TypeVar | ||
|
|
||
| import packaging | ||
| import safetensors | ||
| import torch | ||
| from huggingface_hub import hf_hub_download | ||
| from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE | ||
| from huggingface_hub.errors import HfHubHTTPError | ||
|
|
@@ -42,6 +44,92 @@ | |
| - Docs: {{ docs_url | default("[More Information Needed]", true) }} | ||
| """ | ||
|
|
||
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")


def canonicalise(k: str) -> str:
    """
    Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
    normalisation-buffer key, e.g. 'x.so100_buffer_mean' -> 'x.buffer_mean'.
    """
    return _VARIANT_RE.sub(".buffer_", k)


def standardise_state_dict(
    ckpt: Dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
) -> Tuple[Dict[str, torch.Tensor], list[str]]:
    """
    Re-key `ckpt` so that every entry matches the *reference* key set.

    If several variant keys collapse to the same canonical name, the first one
    (in the checkpoint's iteration order) is kept and the collision is logged.
    Keys that cannot be matched are preserved under their original names so a
    later `load_state_dict(strict=False)` can report them as `unexpected`.

    Args:
        ckpt: Checkpoint state dict to standardise.
        ref_keys: Reference key set, typically `set(model.state_dict().keys())`.
        verbose: When True, log key collisions and the unmatched-key count.

    Returns:
        The re-keyed state dict and the list of keys that could not be matched.
    """
    out, collisions, unmatched = {}, {}, []

    for k, v in ckpt.items():
        canon = canonicalise(k)
        if canon in ref_keys:
            if canon in out:  # duplicate after collapsing the variant marker
                collisions.setdefault(canon, []).append(k)
            else:
                out[canon] = v
        else:
            unmatched.append(k)

    if verbose:
        # Use the module's logging (imported at top of file) rather than bare
        # print(), so diagnostics obey the application's logging configuration.
        for canon, variants in collisions.items():
            logging.warning("[standardise_state_dict] '%s' ← %s", canon, variants)
        if unmatched:
            logging.warning("[standardise_state_dict] kept %d unmatched keys", len(unmatched))

    # Keep unmatched entries verbatim instead of silently dropping weights.
    out.update({k: ckpt[k] for k in unmatched})
    return out, unmatched
|
|
||
|
|
||
def rename_checkpoint_keys(ckpt, rename_str):
    """
    Renames keys in a checkpoint dictionary based on the given rename string.

    Args:
        ckpt (dict): The checkpoint dictionary.
        rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".

    Returns:
        dict: The modified checkpoint with renamed keys.
    """
    rename_dict = dict(pair.split("//") for pair in rename_str.split(","))

    def _rewrite(key):
        # Apply every mapping whose old fragment occurs in the key; later
        # mappings see the result of earlier rewrites.
        for old_fragment, new_fragment in rename_dict.items():
            if old_fragment in key:
                key = key.replace(old_fragment, new_fragment)
        return key

    return {_rewrite(k): v for k, v in ckpt.items()}
danaaubakirova marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
def load_model(
    model: torch.nn.Module,
    filename: str | os.PathLike,
    *,
    strict: bool = True,
    device: str | int = "cpu",
    checkpoint_keys_mapping: str = "",
) -> tuple[list[str], list[str]]:
    """
    Load a safetensors checkpoint into `model`, reconciling key mismatches.

    Args:
        model: The module to load weights into.
        filename: Path to the safetensors file.
        strict: When True, log a warning if any keys are missing/unexpected
            after reconciliation (the load itself is always non-strict; see
            note below).
        device: Device passed through to `safetensors.torch.load_file`.
        checkpoint_keys_mapping: Optional renames in "old1//new1,old2//new2"
            form, e.g. "model._orig_mod.//model.".

    Returns:
        The (missing_keys, unexpected_keys) lists from `load_state_dict`.
    """
    state_dict = safetensors.torch.load_file(filename, device=device)

    # Optional user-supplied renames (e.g. "model._orig_mod.//model.")
    if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
        state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)

    # Collapse dataset-variant buffer names onto the model's own key set.
    state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))

    # NOTE(review): `strict` is deliberately not forwarded to load_state_dict —
    # the reconciliation above may leave benign mismatches that should not
    # abort the load. We surface them via a warning instead of silently
    # dropping them; confirm this is the intended contract for strict=True.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if strict and (missing or unexpected):
        logging.warning(
            "load_model: %d missing and %d unexpected keys after standardisation",
            len(missing),
            len(unexpected),
        )

    return missing, unexpected
|
|
||
|
|
||
| class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): | ||
| """ | ||
|
|
@@ -148,6 +236,13 @@ def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, stric | |
| model.to(map_location) | ||
| else: | ||
| safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) | ||
| missing, unexpected = load_model( | ||
| model, | ||
| model_file, | ||
| strict=strict, | ||
| device=map_location, | ||
| checkpoint_keys_mapping="model._orig_mod.//model.", | ||
|
||
| ) | ||
| return model | ||
|
|
||
| # def generate_model_card(self, *args, **kwargs) -> ModelCard: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.