Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,21 @@ model = load_model(

#### List of pretrained models

Currently, only [DinoV2](https://github.com/facebookresearch/dinov2/) and SigLIP2 have been ported.
The following models have pretrained weights available in Equimo:

Model identifiers allows downloading from equimo's [repository on huggingface](https://huggingface.co/poiretclement/equimo/tree/main/models/default)
- [DinoV2](https://arxiv.org/abs/2304.07193),
- [SigLIP2](https://arxiv.org/abs/2502.14786),
- [TIPS](https://arxiv.org/abs/2410.16512).

Model identifiers allow downloading pretrained weights from Equimo's [repository on Hugging Face](https://huggingface.co/poiretclement/equimo/tree/main/models/default).

Identifiers are filenames without the extensions, such as:

- `dinov2_vitb14`
- `dinov2_vits14_reg`
- `siglip2_vitl16_512`
- `siglip2_vitso400m16_384`
- `tips_vitg14_lr`

## Contributing

Expand Down
File renamed without changes.
19 changes: 15 additions & 4 deletions equimo/converters/utils.py → equimo/conversion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ def expand_torch_tensor(tensor, pos: str, n: int):
)


def convert_params_from_torch_hub(
def convert_params_from_torch(
jax_model: eqx.Module,
replace_cfg: Dict[str, str],
expand_cfg: Dict[str, list],
squeeze_cfg: Dict[str, int | None],
whitelist: list[str],
strict: bool = True,
source: Literal["torchhub", "timm"] = "torchhub",
source: Literal["torchhub", "timm", "custom"] = "torchhub",
torch_hub_cfg: Optional[list[str]] = None,
torch_model=None,
timm_cfg: Optional[list] = None,
return_torch: bool = False,
):
Expand All @@ -56,6 +57,7 @@ def convert_params_from_torch_hub(
strict (bool): Whether to crash on parameters missing from one of the models.
source (str): Torch Hub or timm.
torch_hub_cfg (Optional[list]): args to pass to `torch.hub.load`.
torch_model [torch.nn.Module]: Custom torch model
timm_cfg (Optional[list]): args to pass to `timm.create_model`.
return_torch (bool): Return both jax and torch models.
"""
Expand All @@ -67,6 +69,11 @@ def convert_params_from_torch_hub(

# Load the pytorch model
match source:
case "custom":
if torch_model is None:
raise ValueError(
"The `custom` source is selected but `torch_model` is None."
)
case "torchhub":
if torch_hub_cfg is None:
raise ValueError(
Expand Down Expand Up @@ -149,6 +156,7 @@ def convert_torch_to_equinox(
strict: bool = True,
source: Literal["torchhub", "timm"] = "torchhub",
torch_hub_cfg: Optional[list[str]] = None,
torch_model=None,
timm_cfg: Optional[list] = None,
return_torch: bool = False,
) -> eqx.Module | Tuple[eqx.Module, Any]:
Expand All @@ -164,6 +172,7 @@ def convert_torch_to_equinox(
strict: Whether to raise an error if not all weights are converted
source (str): Torch Hub or timm.
torch_hub_cfg: [repo, model_name] for torch.hub.load
torch_model [torch.nn.Module]: Custom torch model
timm_cfg (Optional[list]): args to pass to `timm.create_model`.
return_torch (bool): Return both jax and torch models.

Expand All @@ -172,7 +181,7 @@ def convert_torch_to_equinox(
"""
dynamic, static = eqx.partition(jax_model, eqx.is_array)
if return_torch:
converted_params, torch_model = convert_params_from_torch_hub(
converted_params, torch_model = convert_params_from_torch(
dynamic,
replace_cfg,
expand_cfg,
Expand All @@ -181,6 +190,7 @@ def convert_torch_to_equinox(
strict,
source,
torch_hub_cfg,
torch_model,
timm_cfg,
return_torch,
)
Expand All @@ -189,7 +199,7 @@ def convert_torch_to_equinox(
eqx.combine(converted_params, static), value=True
), torch_model.eval()
else:
converted_params = convert_params_from_torch_hub(
converted_params = convert_params_from_torch(
dynamic,
replace_cfg,
expand_cfg,
Expand All @@ -198,6 +208,7 @@ def convert_torch_to_equinox(
strict,
source,
torch_hub_cfg,
torch_model,
timm_cfg,
return_torch,
)
Expand Down
9 changes: 2 additions & 7 deletions equimo/converters/dinov2.py → models/dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import numpy as np

import equimo.models as em
from equimo.converters.utils import convert_torch_to_equinox
from equimo.io import load_model, save_model
from equimo.conversion.utils import convert_torch_to_equinox
from equimo.io import save_model


def compare(j, t) -> float:
Expand Down Expand Up @@ -151,8 +151,3 @@ def main():
torch_hub_cfg,
compression=True,
)

# _ = load_model(
# cls="vit",
# path=Path(f"~/.cache/equimo/dinov2/{name}.tar.lz4").expanduser(),
# )
2 changes: 1 addition & 1 deletion equimo/converters/siglip2.py → models/siglip2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

import equimo.models as em
from equimo.converters.utils import convert_torch_to_equinox
from equimo.conversion.utils import convert_torch_to_equinox
from equimo.io import load_model, save_model


Expand Down
176 changes: 176 additions & 0 deletions models/tips.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import os
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np
from tips.pytorch import image_encoder

import equimo.models as em
from equimo.conversion.utils import convert_torch_to_equinox
from equimo.io import save_model

# Directory holding the original TIPS `.npz` vision checkpoints. The default is
# the location that used to be hard-coded here; set the `TIPS_CKPT_DIR`
# environment variable to run the conversion from another download location.
_CKPT_DIR = os.environ.get("TIPS_CKPT_DIR", "/mnt/hdd/torch/tips")

# Equimo model identifier -> path of the checkpoint archive read by `np.load`.
CKPT_PATHS = {
    identifier: os.path.join(_CKPT_DIR, filename)
    for identifier, filename in {
        "tips_vits14_hr": "tips_oss_s14_highres_distilled_vision.npz",
        "tips_vitb14_hr": "tips_oss_b14_highres_distilled_vision.npz",
        "tips_vitl14_hr": "tips_oss_l14_highres_distilled_vision.npz",
        "tips_vitso400m14_hr": "tips_oss_so400m14_highres_largetext_distilled_vision.npz",
        "tips_vitg14_lr": "tips_oss_g14_lowres_vision.npz",
        "tips_vitg14_hr": "tips_oss_g14_highres_vision.npz",
    }.items()
}

# Equimo model identifier -> `tips.pytorch.image_encoder` factory that `main`
# calls to build the reference PyTorch encoder for that variant.
_giant_factory = image_encoder.vit_giant2  # shared by both g14 resolutions

CLS = {
    "tips_vits14_hr": image_encoder.vit_small,
    "tips_vitb14_hr": image_encoder.vit_base,
    "tips_vitl14_hr": image_encoder.vit_large,
    "tips_vitso400m14_hr": image_encoder.vit_so400m,
    "tips_vitg14_lr": _giant_factory,
    "tips_vitg14_hr": _giant_factory,
}


def compare(j, t) -> float:
    """Return the mean absolute difference between a JAX and a torch output.

    Args:
        j: Array-like output of the JAX model.
        t: `torch.Tensor` output of the reference model. It may live on any
            device and may still be attached to the autograd graph.

    Returns:
        Mean of `|j - t|` over all elements, computed after size-1 axes have
        been dropped from both inputs.

    Raises:
        ValueError: If the two outputs differ in shape once size-1 axes are
            removed. Without this check NumPy could silently broadcast the
            operands and report a meaningless error value.
    """
    j = np.squeeze(np.asarray(j))
    # `.cpu()` is needed before `.numpy()` for tensors living on an accelerator.
    t = t.detach().cpu().numpy().squeeze()
    if j.shape != t.shape:
        raise ValueError(
            f"Cannot compare outputs of shapes {j.shape} (jax) and {t.shape} (torch)."
        )
    return float(np.mean(np.abs(j - t)))


def _variant(img_size, dim, heads, depth, **extra):
    """Build the per-variant VisionTransformer overrides for one TIPS model."""
    return {
        "img_size": img_size,
        "dim": dim,
        "num_heads": [heads],
        "depths": [depth],
        **extra,
    }


# Equimo model identifier -> settings merged on top of `base_config` in `main`.
configs = {
    "tips_vits14_hr": _variant(448, 384, 6, 12),
    "tips_vitb14_hr": _variant(448, 768, 12, 12),
    "tips_vitl14_hr": _variant(448, 1024, 16, 24),
    "tips_vitso400m14_hr": _variant(448, 1152, 16, 27, mlp_ratio=4304 / 1152),
    "tips_vitg14_lr": _variant(224, 1536, 24, 40, ffn_layer="swiglufused"),
    "tips_vitg14_hr": _variant(448, 1536, 24, 40, ffn_layer="swiglufused"),
}

# Binds the first `configs` entry to the module-level names `name` and `config`.
# NOTE(review): `main()` rebinds both names in its own loop, so this looks like
# leftover interactive scaffolding — confirm nothing imports these before
# removing them.
citr = iter(configs.items())
name, config = next(citr)


def main():
    """Convert every TIPS vision checkpoint to an Equimo model and save it.

    For each entry of `configs` this:

    1. builds an `em.VisionTransformer` from `base_config | config`;
    2. loads the matching `.npz` weights into the reference PyTorch encoder
       obtained from `CLS`;
    3. transfers the parameters with `convert_torch_to_equinox`;
    4. checks that both models produce the same features on a random image;
    5. saves the converted model under `~/.cache/equimo/tips/<name>`.

    Raises:
        ImportError: If `torch` cannot be imported.
        AssertionError: If the mean absolute difference between the JAX and
            PyTorch features is not below 1e-5.
    """
    try:
        import torch
    except ImportError as err:
        # Keep the original cause so a broken install is distinguishable from
        # a missing one.
        raise ImportError("`torch` not available") from err

    key = jax.random.PRNGKey(42)
    # Settings shared by all variants; `img_size`, `dim`, `num_heads` and
    # `depths` are supplied per model by `configs`.
    base_config = {
        "in_channels": 3,
        "patch_size": 14,
        "num_classes": 0,
        "use_mask_token": True,
        "reg_tokens": 1,
        "init_values": 1e-5,
        "eps": 1e-6,
        "dynamic_img_size": False,
        "act_layer": "exactgelu",
    }

    for name, config in configs.items():
        print(f"Converting {name}...")

        cfg = base_config | config

        jax_model = em.VisionTransformer(
            **cfg,
            key=key,
        )

        # `np.load` keeps the archive open until closed; `torch.tensor` copies
        # each array, so the handle can be released right after the conversion.
        with np.load(CKPT_PATHS[name], allow_pickle=False) as archive:
            weights_image = {k: torch.tensor(archive[k]) for k in archive.files}

        with torch.no_grad():
            # Load the vision encoder.
            model_image = CLS[name](
                # Same value the name-based "lr"/"hr" rule produced, read from
                # the single source of truth instead.
                img_size=cfg["img_size"],
                patch_size=14,
                ffn_layer="swiglu" if "vitg" in name else "mlp",
                block_chunks=0,
                init_values=1.0,
                interpolate_antialias=True,
                interpolate_offset=0.0,
            )
            model_image.load_state_dict(weights_image)

        replace_cfg = {
            "reg_tokens": "register_tokens",
            "blocks.0.blocks": "blocks",
            ".prenorm.": ".norm1.",
            ".norm.": ".norm2.",
        }
        expand_cfg = {"patch_embed.proj.bias": ["after", 2]}
        squeeze_cfg = {
            "pos_embed": 0,
            "cls_token": 0,
            "register_tokens": 0,
        }
        whitelist = []

        jax_model, torch_model = convert_torch_to_equinox(
            jax_model,
            replace_cfg,
            expand_cfg,
            squeeze_cfg,
            whitelist,
            strict=True,
            source="custom",
            torch_model=model_image,
            return_torch=True,
        )

        arr = np.random.randn(3, cfg["img_size"], cfg["img_size"])
        jax_arr = jnp.array(arr)
        torch_arr = torch.tensor(arr).unsqueeze(0).float()

        # The reference forward pass is only used for comparison, so no
        # autograd graph is needed.
        with torch.no_grad():
            torch_features = torch_model.forward_features(torch_arr)["x_prenorm"]
        err = compare(jax_model.features(jax_arr, key), torch_features)

        # Raised explicitly (not via `assert`) so the check still guards
        # `save_model` under `python -O`; `not (err < ...)` also rejects NaN.
        if not (err < 1e-5):
            raise AssertionError(
                f"{name}: JAX and PyTorch features differ "
                f"(mean absolute difference {err:.3e})."
            )

        save_model(
            Path(f"~/.cache/equimo/tips/{name}").expanduser(),
            jax_model,
            cfg,
            compression=True,
        )
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@ dev = [

[tool.setuptools]
license-files = []
packages = ["equimo"]
Loading