diff --git a/README.md b/README.md
index fcfc7a0..eeee49d 100644
--- a/README.md
+++ b/README.md
@@ -128,9 +128,13 @@ model = load_model(
 
 #### List of pretrained models
 
-Currently, only [DinoV2](https://github.com/facebookresearch/dinov2/) and SigLIP2 have been ported.
+The following models have pretrained weights available in Equimo:
 
-Model identifiers allows downloading from equimo's [repository on huggingface](https://huggingface.co/poiretclement/equimo/tree/main/models/default)
+- [DinoV2](https://arxiv.org/abs/2304.07193),
+- [SigLIP2](https://arxiv.org/abs/2502.14786),
+- [TIPS](https://arxiv.org/abs/2410.16512).
+
+Model identifiers allow downloading from equimo's [repository on huggingface](https://huggingface.co/poiretclement/equimo/tree/main/models/default)
 
 Identifiers are filenames without the extensions, such as:
 
@@ -138,6 +142,7 @@ Identifiers are filenames without the extensions, such as:
 - `dinov2_vits14_reg`
 - `siglip2_vitl16_512`
 - `siglip2_vitso400m16_384`
+- `tips_vitg14_lr`
 
 ## Contributing
 
diff --git a/equimo/converters/__init__.py b/equimo/conversion/__init__.py
similarity index 100%
rename from equimo/converters/__init__.py
rename to equimo/conversion/__init__.py
diff --git a/equimo/converters/utils.py b/equimo/conversion/utils.py
similarity index 92%
rename from equimo/converters/utils.py
rename to equimo/conversion/utils.py
index 91460da..406996f 100644
--- a/equimo/converters/utils.py
+++ b/equimo/conversion/utils.py
@@ -31,15 +31,16 @@ def expand_torch_tensor(tensor, pos: str, n: int):
     )
 
 
-def convert_params_from_torch_hub(
+def convert_params_from_torch(
     jax_model: eqx.Module,
     replace_cfg: Dict[str, str],
     expand_cfg: Dict[str, list],
     squeeze_cfg: Dict[str, int | None],
     whitelist: list[str],
     strict: bool = True,
-    source: Literal["torchhub", "timm"] = "torchhub",
+    source: Literal["torchhub", "timm", "custom"] = "torchhub",
     torch_hub_cfg: Optional[list[str]] = None,
+    torch_model=None,
     timm_cfg: Optional[list] = None,
     return_torch: bool = False,
 ):
@@ -56,6 +57,7 @@ def convert_params_from_torch_hub(
         strict (bool): Whether to crash on missing parameters one of the models.
         source (str): Torch Hub or timm.
         torch_hub_cfg (Optional[list]): args to pass to `torch.hub.load`.
+        torch_model (Optional[torch.nn.Module]): Custom torch model to convert, required when `source` is `"custom"`.
         timm_cfg (Optional[list]): args to pass to `timm.create_model`.
         return_torch (bool): Return both jax and torch models.
     """
@@ -67,6 +69,11 @@ def convert_params_from_torch_hub(
 
     # Load the pytorch model
     match source:
+        case "custom":
+            if torch_model is None:
+                raise ValueError(
+                    "The `custom` source is selected but `torch_model` is None."
+                )
         case "torchhub":
             if torch_hub_cfg is None:
                 raise ValueError(
@@ -149,6 +156,7 @@ def convert_torch_to_equinox(
     strict: bool = True,
     source: Literal["torchhub", "timm"] = "torchhub",
     torch_hub_cfg: Optional[list[str]] = None,
+    torch_model=None,
     timm_cfg: Optional[list] = None,
     return_torch: bool = False,
 ) -> eqx.Module | Tuple[eqx.Module, Any]:
@@ -164,6 +172,7 @@ def convert_torch_to_equinox(
         strict: Wether to raise an issue if not all weights are converted
         source (str): Torch Hub or timm.
         torch_hub_cfg: [repo, model_name] for torch.hub.load
+        torch_model: Custom torch.nn.Module to convert, required when source="custom"
         timm_cfg (Optional[list]): args to pass to `timm.create_model`.
         return_torch (bool): Return both jax and torch models.
 
@@ -172,7 +181,7 @@ def convert_torch_to_equinox(
     """
     dynamic, static = eqx.partition(jax_model, eqx.is_array)
     if return_torch:
-        converted_params, torch_model = convert_params_from_torch_hub(
+        converted_params, torch_model = convert_params_from_torch(
             dynamic,
             replace_cfg,
             expand_cfg,
@@ -181,6 +190,7 @@ def convert_torch_to_equinox(
             strict,
             source,
             torch_hub_cfg,
+            torch_model,
             timm_cfg,
             return_torch,
         )
@@ -189,7 +199,7 @@ def convert_torch_to_equinox(
             eqx.combine(converted_params, static), value=True
         ), torch_model.eval()
     else:
-        converted_params = convert_params_from_torch_hub(
+        converted_params = convert_params_from_torch(
             dynamic,
             replace_cfg,
             expand_cfg,
@@ -198,6 +208,7 @@ def convert_torch_to_equinox(
             strict,
             source,
             torch_hub_cfg,
+            torch_model,
             timm_cfg,
             return_torch,
         )
diff --git a/equimo/converters/dinov2.py b/models/dinov2.py
similarity index 93%
rename from equimo/converters/dinov2.py
rename to models/dinov2.py
index 67275c9..c7853d2 100644
--- a/equimo/converters/dinov2.py
+++ b/models/dinov2.py
@@ -5,8 +5,8 @@ import numpy as np
 
 import equimo.models as em
 
-from equimo.converters.utils import convert_torch_to_equinox
-from equimo.io import load_model, save_model
+from equimo.conversion.utils import convert_torch_to_equinox
+from equimo.io import save_model
 
 
 def compare(j, t) -> float:
@@ -151,8 +151,3 @@ def main():
         torch_hub_cfg,
         compression=True,
     )
-
-    # _ = load_model(
-    #     cls="vit",
-    #     path=Path(f"~/.cache/equimo/dinov2/{name}.tar.lz4").expanduser(),
-    # )
diff --git a/equimo/converters/siglip2.py b/models/siglip2.py
similarity index 97%
rename from equimo/converters/siglip2.py
rename to models/siglip2.py
index c00fe16..fe06f15 100644
--- a/equimo/converters/siglip2.py
+++ b/models/siglip2.py
@@ -5,7 +5,7 @@ import numpy as np
 
 import equimo.models as em
 
-from equimo.converters.utils import convert_torch_to_equinox
+from equimo.conversion.utils import convert_torch_to_equinox
 from equimo.io import load_model, save_model
 
 
diff --git a/models/tips.py b/models/tips.py
new file mode 100644
index 0000000..60baaf8
--- /dev/null
+++ b/models/tips.py
@@ -0,0 +1,176 @@
+from pathlib import Path
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+import equimo.models as em
+from equimo.conversion.utils import convert_torch_to_equinox
+from equimo.io import save_model
+from tips.pytorch import image_encoder
+
+CKPT_PATHS = {
+    "tips_vits14_hr": "/mnt/hdd/torch/tips/tips_oss_s14_highres_distilled_vision.npz",
+    "tips_vitb14_hr": "/mnt/hdd/torch/tips/tips_oss_b14_highres_distilled_vision.npz",
+    "tips_vitl14_hr": "/mnt/hdd/torch/tips/tips_oss_l14_highres_distilled_vision.npz",
+    "tips_vitso400m14_hr": "/mnt/hdd/torch/tips/tips_oss_so400m14_highres_largetext_distilled_vision.npz",
+    "tips_vitg14_lr": "/mnt/hdd/torch/tips/tips_oss_g14_lowres_vision.npz",
+    "tips_vitg14_hr": "/mnt/hdd/torch/tips/tips_oss_g14_highres_vision.npz",
+}
+
+CLS = {
+    "tips_vits14_hr": image_encoder.vit_small,
+    "tips_vitb14_hr": image_encoder.vit_base,
+    "tips_vitl14_hr": image_encoder.vit_large,
+    "tips_vitso400m14_hr": image_encoder.vit_so400m,
+    "tips_vitg14_lr": image_encoder.vit_giant2,
+    "tips_vitg14_hr": image_encoder.vit_giant2,
+}
+
+
+def compare(j, t) -> float:
+    j = np.array(j)
+    t = t.squeeze().detach().numpy()
+    return float(np.mean(np.abs(j - t)))
+
+
+configs = {
+    "tips_vits14_hr": {
+        "img_size": 448,
+        "dim": 384,
+        "num_heads": [6],
+        "depths": [12],
+    },
+    "tips_vitb14_hr": {
+        "img_size": 448,
+        "dim": 768,
+        "num_heads": [12],
+        "depths": [12],
+    },
+    "tips_vitl14_hr": {
"img_size": 448, + "dim": 1024, + "num_heads": [16], + "depths": [24], + }, + "tips_vitso400m14_hr": { + "img_size": 448, + "dim": 1152, + "num_heads": [16], + "depths": [27], + "mlp_ratio": 4304 / 1152, + }, + "tips_vitg14_lr": { + "img_size": 224, + "dim": 1536, + "num_heads": [24], + "depths": [40], + "ffn_layer": "swiglufused", + }, + "tips_vitg14_hr": { + "img_size": 448, + "dim": 1536, + "num_heads": [24], + "depths": [40], + "ffn_layer": "swiglufused", + }, +} + +citr = iter(configs.items()) +name, config = next(citr) + + +def main(): + try: + import torch + except: + raise ImportError("`torch` not available") + + key = jax.random.PRNGKey(42) + base_config = { + # "img_size": 448, + "in_channels": 3, + # "dim": 384, + "patch_size": 14, + # "num_heads": [6], + # "depths": [12], + "num_classes": 0, + "use_mask_token": True, + "reg_tokens": 1, + "init_values": 1e-5, + "eps": 1e-6, + "dynamic_img_size": False, + "act_layer": "exactgelu", + } + + for name, config in configs.items(): + print(f"Converting {name}...") + + cfg = base_config | config + + tips = em.VisionTransformer( + **cfg, + key=key, + ) + + weights_image = dict(np.load(CKPT_PATHS[name], allow_pickle=False)) + for k in weights_image: + weights_image[k] = torch.tensor(weights_image[k]) + + with torch.no_grad(): + # Load the vision encoder. + model_image = CLS[name]( + img_size=224 if "lr" in name else 448, + patch_size=14, + ffn_layer="swiglu" if "vitg" in name else "mlp", + block_chunks=0, + init_values=1.0, + interpolate_antialias=True, + interpolate_offset=0.0, + ) + model_image.load_state_dict(weights_image) + + replace_cfg = { + "reg_tokens": "register_tokens", + "blocks.0.blocks": "blocks", + ".prenorm.": ".norm1.", + ".norm.": ".norm2.", + } + expand_cfg = {"patch_embed.proj.bias": ["after", 2]} + squeeze_cfg = { + "pos_embed": 0, + "cls_token": 0, + "register_tokens": 0, + } + whitelist = [] + + tips, torch_model = convert_torch_to_equinox( + tips, + replace_cfg, + expand_cfg, + squeeze_cfg, + whitelist, + strict=True, + source="custom", + torch_model=model_image, + return_torch=True, + ) + + arr = np.random.randn(3, cfg["img_size"], cfg["img_size"]) + jax_arr = jnp.array(arr) + torch_arr = torch.tensor(arr).unsqueeze(0).float() + + assert ( + compare( + tips.features(jax_arr, key), + torch_model.forward_features(torch_arr)["x_prenorm"], + ) + < 1e-5 + ) + + save_model( + Path(f"~/.cache/equimo/tips/{name}").expanduser(), + tips, + cfg, + compression=True, + ) diff --git a/pyproject.toml b/pyproject.toml index a8f6b51..4fafc5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,3 +43,4 @@ dev = [ [tool.setuptools] license-files = [] +packages = ["equimo"]