249 changes: 249 additions & 0 deletions models/dinov3.py
@@ -0,0 +1,249 @@
from pathlib import Path

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np

import equimo.models as em
from equimo.conversion.utils import convert_torch_to_equinox
from equimo.io import save_model

DIR = Path("~/.cache/torch/hub/dinov3").expanduser()


def compare(j, t) -> float:
j = np.array(j)
t = t.squeeze().detach().numpy()
return float(np.mean(np.abs(j - t)))


weights = {
# LVD
"dinov3_vits16_pretrain_lvd1689m": str(
(
Path(
"~/.cache/torch/hub/dinov3/weights/dinov3_vits16_pretrain_lvd1689m-08c60483.pth"
)
).expanduser()
),
"dinov3_vits16plus_pretrain_lvd1689m": str(
(
Path(
"~/.cache/torch/hub/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"
)
).expanduser()
),
"dinov3_vitb16_pretrain_lvd1689m": str(
(
Path(
"~/.cache/torch/hub/dinov3/weights/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth"
)
).expanduser()
),
"dinov3_vitl16_pretrain_lvd1689m": str(
(
Path(
"~/.cache/torch/hub/dinov3/weights/dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth"
)
).expanduser()
),
"dinov3_vith16plus_pretrain_lvd1689m": str(
(
Path(
"~/.cache/torch/hub/dinov3/weights/dinov3_vith16plus_pretrain_lvd1689m-7c1da9a5.pth"
)
).expanduser()
),
"dinov3_vit7b16_pretrain_lvd1689m": str(
(
Path(
"~/.cache/torch/hub/dinov3/weights/dinov3_vit7b16_pretrain_lvd1689m-a955f4ea.pth"
)
).expanduser()
),
# SAT
"dinov3_vitl16_pretrain_sat493m": str(
(
Path(
"~/.cache/torch/hub/dinov3/weights/dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth"
)
).expanduser()
),
"dinov3_vit7b16_pretrain_sat493m": str(
(
Path(
"~/.cache/torch/hub/dinov3/weights/dinov3_vit7b16_pretrain_sat493m-a6675841.pth"
)
).expanduser()
),
}

configs = {
"dinov3_vits16_pretrain_lvd1689m": {
"dim": 384,
"num_heads": 6,
"depths": [12],
"reg_tokens": 4,
"mlp_ratio": 4.0,
},
"dinov3_vits16plus_pretrain_lvd1689m": {
"dim": 384,
"num_heads": 6,
"depths": [12],
"reg_tokens": 4,
"mlp_ratio": 6.0,
"ffn_layer": "swiglu",
},
"dinov3_vitb16_pretrain_lvd1689m": {
"dim": 768,
"num_heads": 12,
"depths": [12],
"reg_tokens": 4,
"mlp_ratio": 4.0,
},
"dinov3_vitl16_pretrain_lvd1689m": {
"dim": 1024,
"num_heads": 16,
"depths": [24],
"reg_tokens": 4,
"mlp_ratio": 4.0,
},
"dinov3_vith16plus_pretrain_lvd1689m": {
"dim": 1280,
"num_heads": 20,
"depths": [32],
"reg_tokens": 4,
"mlp_ratio": 6.0,
"ffn_layer": "swiglu",
},
"dinov3_vit7b16_pretrain_lvd1689m": {
"dim": 4096,
"num_heads": 32,
"depths": [40],
"reg_tokens": 4,
"mlp_ratio": 3.0,
"untie_global_and_local_cls_norm": True,
"ffn_kwargs": {"align_to": 64},
},
"dinov3_vitl16_pretrain_sat493m": {
"dim": 1024,
"num_heads": 16,
"depths": [24],
"reg_tokens": 4,
"mlp_ratio": 4.0,
"untie_global_and_local_cls_norm": True,
},
"dinov3_vit7b16_pretrain_sat493m": {
"dim": 4096,
"num_heads": 32,
"depths": [40],
"reg_tokens": 4,
"mlp_ratio": 3.0,
"untie_global_and_local_cls_norm": True,
"ffn_kwargs": {"align_to": 64},
},
}

citr = iter(configs.items())
name, config = next(citr)


def main():
try:
import torch
    except ImportError as e:
        raise ImportError("`torch` not available") from e

key = jax.random.PRNGKey(42)
dinov3_config = {
"img_size": 224,
"in_channels": 3,
"patch_size": 16,
"num_classes": 0,
"use_mask_token": True,
"use_rope_pos_embed": True,
"reg_tokens": 4,
"init_values": 1e-5,
"eps": 1e-5,
"dynamic_img_size": True,
"act_layer": "exactgelu",
}

for name, config in configs.items():
print(f"Converting {name}...")

cfg = dinov3_config | config

dinov3 = em.VisionTransformer(
**cfg,
key=key,
)

torch_name = "_".join(name.split("_")[:-2])
torch_hub_cfg = {
"repo_or_dir": str(DIR / "dinov3"),
"model": torch_name,
"source": "local",
"weights": weights[name],
}
# model = torch.hub.load(**torch_hub_cfg)

replace_cfg = {
"reg_tokens": "storage_tokens",
"blocks.0.blocks": "blocks",
".prenorm.": ".norm1.",
".norm.": ".norm2.",
}
expand_cfg = {"patch_embed.proj.bias": ["after", 2]}
squeeze_cfg = {
"pos_embed": 0,
"cls_token": 0,
"storage_tokens": 0,
}
torch_whitelist = []
jax_whitelist = ["pos_embed.periods"]

dinov3, torch_model = convert_torch_to_equinox(
dinov3,
replace_cfg,
expand_cfg,
squeeze_cfg,
torch_whitelist,
jax_whitelist,
strict=True,
torch_hub_cfg=torch_hub_cfg,
return_torch=True,
)
dinov3 = eqx.nn.inference_mode(dinov3, True)

arr = np.random.randn(3, cfg["img_size"], cfg["img_size"])
jax_arr = jnp.array(arr)
torch_arr = torch.tensor(arr).unsqueeze(0).float()

        assert (
            err := compare(
                dinov3.features(jax_arr, inference=True, key=key),
                torch_model.forward_features(torch_arr)["x_prenorm"],
            )
        ) < 5e-4, f"Conversion error: {err}"

save_path = Path(f"~/.cache/equimo/dinov3/{name}").expanduser()
save_model(
save_path,
dinov3,
cfg,
torch_hub_cfg,
compression=True,
)

# Ensure the serialization is okay
# loaded_model = load_model(cls="vit", path=save_path.with_suffix(".tar.lz4"))
# a = dinov3.features(jax_arr, inference=True, key=key)
# b = loaded_model.features(jax_arr, inference=True, key=key)
# jnp.mean((a - b) ** 2)


if __name__ == "__main__":
main()
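The commented-out lines near the end of the script sketch a serialization round-trip check. Below is a minimal version of that check written out, assuming `load_model` lives alongside `save_model` in `equimo.io` and accepts the `cls`/`path` arguments used in the comment (both taken from the comment, not verified here); the checkpoint name and path mirror the script above.

from pathlib import Path

import jax
import jax.numpy as jnp

from equimo.io import load_model

# Reload one of the converted checkpoints from disk.
save_path = Path("~/.cache/equimo/dinov3/dinov3_vits16_pretrain_lvd1689m").expanduser()
loaded_model = load_model(cls="vit", path=save_path.with_suffix(".tar.lz4"))

# Run the reloaded model on a dummy image; comparing these features against the
# in-memory `dinov3` model (as in the comment) should give a near-zero difference.
key = jax.random.PRNGKey(0)
x = jnp.zeros((3, 224, 224))
feats = loaded_model.features(x, inference=True, key=key)
print(feats.shape)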
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,9 +1,9 @@
[project]
name = "Equimo"
version = "0.4.0-alpha.14"
version = "0.4.1"
description = "Implementation of popular vision models in Jax"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
dependencies = [
"einops>=0.8.0",
"equinox>=0.11.5",
2 changes: 1 addition & 1 deletion src/equimo/__init__.py
@@ -1 +1 @@
__version__ = "0.4.0-alpha.14"
__version__ = "0.4.1"
46 changes: 30 additions & 16 deletions src/equimo/conversion/utils.py
@@ -36,10 +36,11 @@ def convert_params_from_torch(
replace_cfg: Dict[str, str],
expand_cfg: Dict[str, list],
squeeze_cfg: Dict[str, int | None],
whitelist: list[str],
torch_whitelist: list[str],
jax_whitelist: list[str],
strict: bool = True,
source: Literal["torchhub", "timm", "custom"] = "torchhub",
torch_hub_cfg: Optional[list[str]] = None,
torch_hub_cfg: Optional[dict] = None,
torch_model=None,
timm_cfg: Optional[list] = None,
return_torch: bool = False,
@@ -49,14 +50,15 @@

Args:
jax_model (eqx.Module): A preexisting Jax model corresponding to the checkpoint to download.
torch_hub_cfg (Tuple[str]): Arguments passed to `torch.hub.load()`.
torch_hub_cfg (dict): Arguments passed to `torch.hub.load()`.
replace_cfg (Dict[str, str]): Rename parameters from key to value.
expand_cfg (Dict[str, list]): Config to reshape params, see `expand_torch_tensor`
        squeeze_cfg (Dict[str, int|None]): Config to squeeze tensors, opposite of expand.
whitelist (Set[str]): Parameters to exclude from format conversion.
        torch_whitelist (list[str]): PyTorch parameters allowed to have no Jax counterpart.
        jax_whitelist (list[str]): Jax parameters allowed to have no PyTorch counterpart (their initial values are kept).
        strict (bool): Whether to crash on parameters missing from one of the models.
source (str): Torch Hub or timm.
torch_hub_cfg (Optional[list]): args to pass to `torch.hub.load`.
torch_hub_cfg (dict): args to pass to `torch.hub.load`.
torch_model [torch.nn.Module]: Custom torch model
timm_cfg (Optional[list]): args to pass to `timm.create_model`.
return_torch (bool): Return both jax and torch models.
@@ -79,7 +81,7 @@
raise ValueError(
"The `torchhub` source is selected but `torch_hub_cfg` is None."
)
torch_model = torch.hub.load(*torch_hub_cfg)
torch_model = torch.hub.load(**torch_hub_cfg)
case "timm":
if timm_cfg is None:
raise ValueError(
@@ -107,12 +109,20 @@

if param_path not in torch_params:
_msg = f"{param_path} ({shape}) not found in PyTorch model."
if strict:
if strict and param_path not in jax_whitelist:
logger.error(_msg)
raise AttributeError(_msg)

logger.warning(f"{_msg} Appending `None` to flat param list.")
torch_params_flat.append(None)
if param_path in jax_whitelist:
p = param
logger.warning(
f"{_msg} Appending original parameters to flat param list because of `jax_whitelist`."
)
else:
p = None
logger.warning(f"{_msg} Appending `None` to flat param list.")

torch_params_flat.append(p)
continue

logger.info(f"Converting {param_path}...")
@@ -137,7 +147,7 @@
logger.warning(
f"PyTorch parameters `{path}` ({param.shape}) were not converted."
)
if strict and path not in whitelist:
if strict and path not in torch_whitelist:
_msg = f"The PyTorch model contains parameters ({path}) that do not have a Jax counterpart."
logger.error(_msg)
raise AttributeError(_msg)
@@ -152,10 +162,11 @@ def convert_torch_to_equinox(
replace_cfg: dict = {},
expand_cfg: dict = {},
squeeze_cfg: dict = {},
whitelist: list[str] = [],
torch_whitelist: list[str] = [],
jax_whitelist: list[str] = [],
strict: bool = True,
source: Literal["torchhub", "timm"] = "torchhub",
torch_hub_cfg: Optional[list[str]] = None,
torch_hub_cfg: Optional[dict] = None,
torch_model=None,
timm_cfg: Optional[list] = None,
return_torch: bool = False,
@@ -168,10 +179,11 @@
replace_cfg: Dict of parameter name replacements
expand_cfg: Dict of dimensions to expand
squeeze_cfg: Dict of dimensions to squeeze
whitelist: List of parameters to keep from JAX model
torch_whitelist: List of parameters allowed to be in PT model but not in Jax
jax_whitelist: List of parameters allowed to be in Jax model but not in PT
        strict: Whether to raise an error if not all weights are converted
source (str): Torch Hub or timm.
torch_hub_cfg: [repo, model_name] for torch.hub.load
torch_hub_cfg (dict): torch.hub.load config
torch_model [torch.nn.Module]: Custom torch model
timm_cfg (Optional[list]): args to pass to `timm.create_model`.
return_torch (bool): Return both jax and torch models.
@@ -186,7 +198,8 @@
replace_cfg,
expand_cfg,
squeeze_cfg,
whitelist,
torch_whitelist,
jax_whitelist,
strict,
source,
torch_hub_cfg,
@@ -204,7 +217,8 @@
replace_cfg,
expand_cfg,
squeeze_cfg,
whitelist,
torch_whitelist,
jax_whitelist,
strict,
source,
torch_hub_cfg,
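Taken together, these changes alter the call pattern slightly: the old `whitelist` argument is split into `torch_whitelist`/`jax_whitelist`, and `torch_hub_cfg` is now a dict forwarded as keyword arguments to `torch.hub.load`. A minimal sketch of the updated usage follows; the model config and conversion dicts are copied from the DINOv3 ViT-S/16 entry in models/dinov3.py above, while the repo and weight paths are placeholders.

import equinox as eqx
import jax

import equimo.models as em
from equimo.conversion.utils import convert_torch_to_equinox

key = jax.random.PRNGKey(0)

# Same config the conversion script builds for dinov3_vits16_pretrain_lvd1689m.
cfg = {
    "img_size": 224, "in_channels": 3, "patch_size": 16, "num_classes": 0,
    "use_mask_token": True, "use_rope_pos_embed": True, "reg_tokens": 4,
    "init_values": 1e-5, "eps": 1e-5, "dynamic_img_size": True,
    "act_layer": "exactgelu", "dim": 384, "num_heads": 6, "depths": [12],
    "mlp_ratio": 4.0,
}
jax_model = em.VisionTransformer(**cfg, key=key)

jax_model, torch_model = convert_torch_to_equinox(
    jax_model,
    replace_cfg={
        "reg_tokens": "storage_tokens",
        "blocks.0.blocks": "blocks",
        ".prenorm.": ".norm1.",
        ".norm.": ".norm2.",
    },
    expand_cfg={"patch_embed.proj.bias": ["after", 2]},
    squeeze_cfg={"pos_embed": 0, "cls_token": 0, "storage_tokens": 0},
    torch_whitelist=[],                   # PyTorch params allowed to lack a Jax counterpart
    jax_whitelist=["pos_embed.periods"],  # Jax params allowed to lack a PyTorch counterpart
    strict=True,
    source="torchhub",
    torch_hub_cfg={                       # forwarded as keyword args to torch.hub.load
        "repo_or_dir": "/path/to/dinov3/repo",    # placeholder
        "model": "dinov3_vits16",
        "source": "local",
        "weights": "/path/to/dinov3_vits16.pth",  # placeholder
    },
    return_torch=True,
)
jax_model = eqx.nn.inference_mode(jax_model, True)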
2 changes: 1 addition & 1 deletion src/equimo/io.py
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ def save_model(
path: Path,
model: eqx.Module,
model_config: dict,
torch_hub_cfg: list[str] = [],
torch_hub_cfg: list[str] | dict = {},
timm_cfg: list = [],
compression: bool = True,
):