From c6ced752401fc7779fce5774c2c7632cdcf04cbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20POIRET?= Date: Wed, 20 Aug 2025 12:07:34 +0200 Subject: [PATCH 01/10] feat(vit): separate `pos_emb` class --- src/equimo/layers/posemb.py | 143 ++++++++++++++++++++++++++++++++++++ src/equimo/models/vit.py | 134 ++++++--------------------------- 2 files changed, 166 insertions(+), 111 deletions(-) diff --git a/src/equimo/layers/posemb.py b/src/equimo/layers/posemb.py index 9646829..bdfc335 100644 --- a/src/equimo/layers/posemb.py +++ b/src/equimo/layers/posemb.py @@ -1,3 +1,4 @@ +import math from typing import Any, Literal, Optional, Tuple import equinox as eqx @@ -8,6 +9,148 @@ from jaxtyping import Array, Float, PRNGKeyArray +class LearnedPosEmbed(eqx.Module): + weight: jax.Array + + dim: int = eqx.field(static=True) + embed_size: int = eqx.field(static=True) + num_prefix_tokens: int = eqx.field(static=True) + num_embedded_prefix_tokens: int = eqx.field(static=True) + no_embed_class: bool = eqx.field(static=True) + pos_embed_reg_tokens: bool = eqx.field(static=True) + + antialias: bool = eqx.field(static=True, default=True) + + def resample( + self, + *, + new_size: tuple[int, int], + dim: int, + num_embedded_prefix_tokens: int, + old_size: tuple[int, int] | None, + interpolation: str = "bicubic", + ) -> jax.Array: + """Resample positional embeddings for different input sizes. + + Args: + new_size: Target size (height, width) + dim: Dimensionality of the sequence + num_embedded_prefix_tokens: To include cls and reg tokens + old_size: Original size (height, width), computed if None + interpolation: Interpolation method + + Returns: + Resampled positional embeddings + """ + pe = self.weight + prev_dtype = pe.dtype + H, W = new_size + dim = self.dim if dim is None else dim + num_embedded_prefix_tokens = ( + self.num_embedded_prefix_tokens + if num_embedded_prefix_tokens is None + else num_embedded_prefix_tokens + ) + + tgt_len = H * W + num_embedded_prefix_tokens + if ( + (tgt_len == pe.shape[0]) + and (old_size is not None) + and (H == W == old_size[0]) + ): + return pe + + if old_size is None: + L = pe.shape[0] - num_embedded_prefix_tokens + hw = int(math.sqrt(L)) + old_size = (hw, hw) + + prefix = pe[:num_embedded_prefix_tokens] if num_embedded_prefix_tokens else None + grid = pe[num_embedded_prefix_tokens:].astype(jnp.float32) + grid = rearrange(grid, "(h w) d -> h w d", h=old_size[0], w=old_size[1]) + grid = jax.image.resize( + grid, (H, W, dim), method=interpolation, antialias=self.antialias + ) + grid = rearrange(grid, "h w d -> (h w) d").astype(prev_dtype) + if prefix is not None: + grid = jnp.concatenate([prefix, grid], axis=0) + return grid + + def __call__( + self, + x: jax.Array, + *, + cls_token: Optional[jax.Array], + reg_tokens: Optional[jax.Array], + dynamic_img_size: bool, + interpolation: str = "bicubic", + ) -> jax.Array: + """Compose tokens and add positional embeddings. + + Inputs: + - x: + - If dynamic_img_size: shape (C, H, W) from PatchEmbedding(flatten=False) + - Else: shape ((H*W), C) from PatchEmbedding(flatten=True) + - cls_token: shape (1, dim) or None + - reg_tokens: shape (R, dim) or None + - dynamic_img_size: whether x is spatial or already flattened + + Returns: + - Token sequence with positional information and optional prefix tokens. 
+ """ + if dynamic_img_size: + C, H, W = x.shape + assert C == self.dim, f"Channel dim mismatch: {C} vs {self.dim}" + pos_embed = self.resample( + new_size=(H, W), + old_size=(self.embed_size, self.embed_size), + interpolation=interpolation, + ) + x = rearrange(x, "c h w -> (h w) c") + else: + pos_embed = self.weight + + to_cat = [] + if cls_token is not None: + # Expect (1, dim) + assert cls_token.shape[-1] == self.dim and cls_token.shape[0] == 1 + to_cat.append(cls_token) + if reg_tokens is not None: + # Expect (R, dim) + assert reg_tokens.ndim == 2 and reg_tokens.shape[-1] == self.dim + to_cat.append(reg_tokens) + + # Branching exactly mirrors your current _pos_embed logic + if self.no_embed_class: + # Add pos to patches only; then prepend any prefix tokens (cls/reg) + x = x + pos_embed + if to_cat: + x = jnp.concatenate(to_cat + [x], axis=0) + + elif self.pos_embed_reg_tokens: + # Prefix tokens are included in the positional grid length; concat first, then add + if to_cat: + x = jnp.concatenate(to_cat + [x], axis=0) + x = x + pos_embed + + else: + # Only class token is embedded with patches; reg tokens (if any) are inserted after + # the class token and before the patch tokens. + # Note: this branch assumes that if reg_tokens are used, a cls_token exists too. + if cls_token is None and reg_tokens is not None: + raise ValueError( + "Configuration invalid: reg_tokens without cls_token when pos_embed_reg_tokens=False " + "and no_embed_class=False." + ) + x = jnp.concatenate(to_cat[:1] + [x], axis=0) # cat cls_token if present + x = x + pos_embed + if reg_tokens is not None: + # Insert reg_tokens between cls and patch tokens + x = jnp.concatenate([x[:1], reg_tokens, x[1:]], axis=0) + + return x + + class PosEmbMLPSwinv1D(eqx.Module): """1D Positional Embedding using MLP for Swin Transformer. 
diff --git a/src/equimo/models/vit.py b/src/equimo/models/vit.py index a0fcfad..107ffa1 100644 --- a/src/equimo/models/vit.py +++ b/src/equimo/models/vit.py @@ -1,5 +1,4 @@ -import math -from typing import Callable, List, Literal, Optional, Tuple +from typing import Callable, List, Literal, Optional import equinox as eqx import jax @@ -18,7 +17,7 @@ from equimo.layers.ffn import Mlp, get_ffn from equimo.layers.norm import get_norm from equimo.layers.patch import PatchEmbedding -from equimo.layers.posemb import PosCNN +from equimo.layers.posemb import LearnedPosEmbed, PosCNN from equimo.utils import pool_sd, to_list @@ -142,10 +141,10 @@ class VisionTransformer(eqx.Module): """ patch_embed: PatchEmbedding - pos_embed: jnp.ndarray - cls_token: jnp.ndarray | None - reg_tokens: jnp.ndarray | None - mask_token: jnp.ndarray | None + pos_embed: LearnedPosEmbed + cls_token: jax.Array | None + reg_tokens: jax.Array | None + mask_token: jax.Array | None blocks: List[eqx.Module] pos_drop: eqx.nn.Dropout norm: eqx.Module @@ -180,6 +179,7 @@ def __init__( class_token: bool = True, no_embed_class: bool = False, reg_tokens: int = 4, + rope_pos_embed: bool = False, pos_embed_reg_tokens: bool = False, pos_drop_rate: float = 0.0, drop_path_rate: float = 0.0, @@ -252,7 +252,16 @@ def __init__( self.num_embedded_prefix_tokens += 1 self.embed_len = self.num_patches + 1 - self.pos_embed = jr.normal(key_posemb, (self.embed_len, dim)) + self.pos_embed = LearnedPosEmbed( + weight=jr.normal(key_posemb, (self.embed_len, dim)), + dim=dim, + embed_size=self.embed_size, + num_prefix_tokens=self.num_prefix_tokens, + num_embedded_prefix_tokens=self.num_embedded_prefix_tokens, + no_embed_class=self.no_embed_class, + pos_embed_reg_tokens=self.pos_embed_reg_tokens, + antialias=interpolate_antialias, + ) self.pos_drop = eqx.nn.Dropout(pos_drop_rate) if drop_path_uniform: @@ -296,107 +305,6 @@ def __init__( else eqx.nn.Identity() ) - def resample_pos_embed( - self, - pos_embed: Float[Array, "embed_len dim"], - new_size: Tuple[int, int], - old_size: Optional[Tuple[int, int]] = None, - interpolation: str = "bicubic", - antialias: bool = True, - ): - """Resample positional embeddings for different input sizes. - - Args: - pos_embed: Original positional embeddings - new_size: Target size (height, width) - old_size: Original size (height, width), computed if None - interpolation: Interpolation method - antialias: Whether to use antialiasing - - Returns: - Resampled positional embeddings - """ - previous_dtype = pos_embed.dtype - - num_new_tokens = new_size[0] * new_size[1] + self.num_embedded_prefix_tokens - - if num_new_tokens == self.embed_len and new_size[0] == new_size[1]: - return pos_embed - - if old_size is None: - hw = int(math.sqrt(self.num_patches)) - old_size = hw, hw - - prefix_embed = ( - pos_embed[: self.num_embedded_prefix_tokens] - if self.num_embedded_prefix_tokens - else None - ) - pos_embed = pos_embed[self.num_embedded_prefix_tokens :].astype("float32") - - pos_embed = rearrange( - pos_embed, "(h w) d -> h w d", h=old_size[0], w=old_size[1] - ) - pos_embed = jax.image.resize( - pos_embed, - (new_size[0], new_size[1], self.dim), - method=interpolation, - antialias=antialias, - ) - pos_embed = rearrange(pos_embed, "h w d -> (h w) d").astype(previous_dtype) - - if prefix_embed is not None: - pos_embed = jnp.concatenate([prefix_embed, pos_embed], axis=0) - - return pos_embed - - def _pos_embed(self, x: Float[Array, "..."], h: int, w: int): - """Add positional embeddings to input features. 
- - Args: - x: Input features - h: Height of feature map - w: Width of feature map - - Returns: - Features with positional embeddings and tokens added - """ - if self.pos_embed is None: - return rearrange(x, "c h w -> (h w) c") - - if self.dynamic_img_size: - C, H, W = x.shape - pos_embed = self.resample_pos_embed( - self.pos_embed, new_size=(H, W), antialias=self.antialias - ) - x = rearrange(x, "c h w -> (h w) c") - else: - pos_embed = self.pos_embed - - to_cat = [] - if self.cls_token is not None: - to_cat.append(self.cls_token) - if self.reg_tokens is not None: - to_cat.append(self.reg_tokens) - - if self.no_embed_class: - x = x + pos_embed - if to_cat: - x = jnp.concatenate(to_cat + [x], axis=0) - elif self.pos_embed_reg_tokens: - if to_cat: - x = jnp.concatenate(to_cat + [x], axis=0) - x = x + pos_embed - else: - x = jnp.concatenate(to_cat[:1] + [x], axis=0) # cat cls_token - x = x + pos_embed - if self.reg_tokens is not None: - x = jnp.concatenate( - [x[:1], to_cat[1], x[1:]], axis=0 - ) # insert reg_tokens in between - - return x - def features( self, x: Float[Array, "channels height width"], @@ -431,8 +339,12 @@ def features( value = self.mask_token x = jnp.where(mask, x, value.astype(x.dtype)) - # TODO: Decompose in multiple fns - x = self._pos_embed(x, h=self.embed_size, w=self.embed_size) + x = self.pos_embed( + x, + cls_token=self.cls_token, + reg_tokens=self.reg_tokens, + dynamic_img_size=self.dynamic_img_size, + ) for blk, key_block in zip(self.blocks, block_subkeys): x = blk(x, inference=inference, key=key_block, **kwargs) From feccad9754e491c27e79c0b5a479c30bfc4c8b98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20POIRET?= Date: Wed, 20 Aug 2025 12:07:34 +0200 Subject: [PATCH 02/10] fix(posemb): resample kwargs --- src/equimo/layers/posemb.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/equimo/layers/posemb.py b/src/equimo/layers/posemb.py index bdfc335..2bc5796 100644 --- a/src/equimo/layers/posemb.py +++ b/src/equimo/layers/posemb.py @@ -24,10 +24,10 @@ class LearnedPosEmbed(eqx.Module): def resample( self, *, - new_size: tuple[int, int], - dim: int, - num_embedded_prefix_tokens: int, - old_size: tuple[int, int] | None, + new_size: Tuple[int, int], + dim: int | None = None, + num_embedded_prefix_tokens: int | None = None, + old_size: Optional[Tuple[int, int]] = None, interpolation: str = "bicubic", ) -> jax.Array: """Resample positional embeddings for different input sizes. @@ -452,15 +452,15 @@ def __call__( class DinoRoPE(eqx.Module): """Axial RoPE that produces per-position sin/cos for later rotation of features. - - Enforces embed_dim % (4 * num_heads) == 0. + - Enforces dim % (4 * num_heads) == 0. - Periods can be specified via `base` or `min_period` + `max_period` (mutually exclusive). - Coordinates are normalized to [-1, 1] according to `normalize_coords`. - Optional training-time augmentations: shift, jitter (log-uniform per-axis), rescale (log-uniform shared). - - Returns (sin, cos) with shape [H*W, D_head], where D_head = embed_dim // num_heads. + - Returns (sin, cos) with shape [H*W, D_head], where D_head = dim // num_heads. Parameters ---------- - embed_dim: int + dim: int Total embedding dimension (across heads). num_heads: int Number of attention heads. 
@@ -497,7 +497,7 @@ class DinoRoPE(eqx.Module): def __init__( self, - embed_dim: int, + dim: int, *, num_heads: int, base: Optional[float] = 100.0, @@ -509,8 +509,8 @@ def __init__( rescale_coords: Optional[float] = None, dtype: jnp.dtype = jnp.float32, ): - if embed_dim % (4 * num_heads) != 0: - raise ValueError("embed_dim must be divisible by 4 * num_heads.") + if dim % (4 * num_heads) != 0: + raise ValueError("dim must be divisible by 4 * num_heads.") both_periods = (min_period is not None) and (max_period is not None) if (base is None and not both_periods) or (base is not None and both_periods): raise ValueError( @@ -525,7 +525,7 @@ def __init__( self.rescale_coords = rescale_coords self.dtype = dtype - self.D_head = embed_dim // num_heads + self.D_head = dim // num_heads D_quarter = self.D_head // 4 if base is not None: From 2a059ef66d6186634e5aa730730e2c0ed54c0d9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20POIRET?= Date: Wed, 20 Aug 2025 12:07:34 +0200 Subject: [PATCH 03/10] feat(vit): modular rope posemb --- src/equimo/layers/posemb.py | 31 ++++++----- src/equimo/models/vit.py | 100 ++++++++++++++++++++++++++++-------- 2 files changed, 98 insertions(+), 33 deletions(-) diff --git a/src/equimo/layers/posemb.py b/src/equimo/layers/posemb.py index 2bc5796..c839188 100644 --- a/src/equimo/layers/posemb.py +++ b/src/equimo/layers/posemb.py @@ -483,6 +483,9 @@ class DinoRoPE(eqx.Module): ----- - The `periods` buffer is persistent (part of the tree) and not trainable; we stop gradients on it inside `__call__`. + - I had to separate `dtype` and `periods_dtype`. For some obscure reasons, I faced cases + with the reference PyTorch impl. where `periods` were computed in bfloat16 (wanted behavior), + but subsequent computations (coords, angles, cos, sin) were at a float32 precision. """ D_head: int = eqx.field(static=True) @@ -507,6 +510,7 @@ def __init__( shift_coords: Optional[float] = None, jitter_coords: Optional[float] = None, rescale_coords: Optional[float] = None, + periods_dtype: jnp.dtype = jnp.bfloat16, dtype: jnp.dtype = jnp.float32, ): if dim % (4 * num_heads) != 0: @@ -530,36 +534,38 @@ def __init__( if base is not None: denom = self.D_head // 2 - k = jnp.arange(D_quarter, dtype=dtype) + k = jnp.arange(D_quarter, dtype=periods_dtype) periods = base ** (2.0 * k / float(denom)) else: # Geometric progression from min_period to max_period (inclusive endpoints behavior per torch linspace) assert min_period is not None and max_period is not None base_ratio = max_period / min_period - exponents = jnp.linspace(0.0, 1.0, D_quarter, dtype=dtype) + exponents = jnp.linspace(0.0, 1.0, D_quarter, dtype=periods_dtype) periods = base_ratio**exponents # in [1, base_ratio] periods = periods / base_ratio # in [1/base_ratio, 1] periods = periods * max_period # in [min_period, max_period] - periods = periods.astype(dtype) + periods = periods.astype(periods_dtype) # Persistent buffer (will be copied with the tree; we stop gradients in __call__) - self.periods = periods + self.periods = periods.astype(dtype) def _make_coords(self, H: int, W: int) -> jnp.ndarray: """Create normalized coords in [-1, 1], shape [H*W, 2], dtype=self.dtype.""" dtype = self.dtype + # WARNING: I removed `dtype=dtype` in those jnp.arange fns because it was + # creating a discrepancy w/ dinov3 pytorch impl. 
if self.normalize_coords == "max": denom = float(max(H, W)) - coords_h = jnp.arange(0.5, H, step=1.0, dtype=dtype) / denom # [H] - coords_w = jnp.arange(0.5, W, step=1.0, dtype=dtype) / denom # [W] + coords_h = jnp.arange(0.5, H, step=1.0) / denom # [H] + coords_w = jnp.arange(0.5, W, step=1.0) / denom # [W] elif self.normalize_coords == "min": denom = float(min(H, W)) - coords_h = jnp.arange(0.5, H, step=1.0, dtype=dtype) / denom - coords_w = jnp.arange(0.5, W, step=1.0, dtype=dtype) / denom + coords_h = jnp.arange(0.5, H, step=1.0) / denom + coords_w = jnp.arange(0.5, W, step=1.0) / denom else: # "separate" - coords_h = jnp.arange(0.5, H, step=1.0, dtype=dtype) / float(H) - coords_w = jnp.arange(0.5, W, step=1.0, dtype=dtype) / float(W) + coords_h = jnp.arange(0.5, H, step=1.0) / float(H) + coords_w = jnp.arange(0.5, W, step=1.0) / float(W) hh, ww = jnp.meshgrid(coords_h, coords_w, indexing="ij") # [H, W] coords = jnp.stack([hh, ww], axis=-1).reshape(H * W, 2) # [HW, 2] @@ -567,14 +573,14 @@ def _make_coords(self, H: int, W: int) -> jnp.ndarray: return coords.astype(dtype) - def __call__( + def get_sincos( self, *, H: int, W: int, key: jax.Array, inference: Optional[bool] = None, - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jax.Array, jax.Array]: """Compute (sin, cos) with shapes [H*W, D_head]. If `inference is False`, training-time augmentations may be applied @@ -632,6 +638,7 @@ def __call__( cos = jnp.cos(angles).astype(dtype) # [HW, D_head] sin = jnp.sin(angles).astype(dtype) # [HW, D_head] + return sin, cos diff --git a/src/equimo/models/vit.py b/src/equimo/models/vit.py index 107ffa1..3ba78db 100644 --- a/src/equimo/models/vit.py +++ b/src/equimo/models/vit.py @@ -17,7 +17,7 @@ from equimo.layers.ffn import Mlp, get_ffn from equimo.layers.norm import get_norm from equimo.layers.patch import PatchEmbedding -from equimo.layers.posemb import LearnedPosEmbed, PosCNN +from equimo.layers.posemb import DinoRoPE, LearnedPosEmbed, PosCNN from equimo.utils import pool_sd, to_list @@ -141,7 +141,7 @@ class VisionTransformer(eqx.Module): """ patch_embed: PatchEmbedding - pos_embed: LearnedPosEmbed + pos_embed: LearnedPosEmbed | DinoRoPE cls_token: jax.Array | None reg_tokens: jax.Array | None mask_token: jax.Array | None @@ -159,6 +159,7 @@ class VisionTransformer(eqx.Module): num_embedded_prefix_tokens: int = eqx.field(static=True) no_embed_class: bool = eqx.field(static=True) pos_embed_reg_tokens: bool = eqx.field(static=True) + use_rope_pos_embed: bool = eqx.field(static=True) embed_len: int = eqx.field(static=True) dynamic_img_size: bool = eqx.field(static=True) antialias: bool = eqx.field(static=True) @@ -179,7 +180,15 @@ def __init__( class_token: bool = True, no_embed_class: bool = False, reg_tokens: int = 4, - rope_pos_embed: bool = False, + use_rope_pos_embed: bool = False, + rope_pos_embed_base: float = 100.0, + rope_pos_embed_min_period: Optional[float] = None, + rope_pos_embed_max_period: Optional[float] = None, + rope_pos_embed_normalize_coords: Literal["min", "max", "separate"] = "separate", + rope_pos_embed_shift_coords: Optional[float] = None, + rope_pos_embed_jitter_coords: Optional[float] = None, + rope_pos_embed_rescale_coords: Optional[float] = None, + rope_pos_embed_dtype: jnp.dtype = jnp.float32, pos_embed_reg_tokens: bool = False, pos_drop_rate: float = 0.0, drop_path_rate: float = 0.0, @@ -218,6 +227,7 @@ def __init__( self.pos_embed_reg_tokens = pos_embed_reg_tokens self.global_pool = global_pool self.embed_size = img_size // patch_size + 
self.use_rope_pos_embed = use_rope_pos_embed block = get_attention_block(block) attn_layer = get_attention(attn_layer) @@ -252,16 +262,34 @@ def __init__( self.num_embedded_prefix_tokens += 1 self.embed_len = self.num_patches + 1 - self.pos_embed = LearnedPosEmbed( - weight=jr.normal(key_posemb, (self.embed_len, dim)), - dim=dim, - embed_size=self.embed_size, - num_prefix_tokens=self.num_prefix_tokens, - num_embedded_prefix_tokens=self.num_embedded_prefix_tokens, - no_embed_class=self.no_embed_class, - pos_embed_reg_tokens=self.pos_embed_reg_tokens, - antialias=interpolate_antialias, - ) + if use_rope_pos_embed: + if not isinstance(num_heads, int): + raise ValueError( + "RoPE pos embedding currently requires a static number of heads." + ) + self.pos_embed = DinoRoPE( + dim=dim, + num_heads=num_heads, + base=rope_pos_embed_base, + min_period=rope_pos_embed_min_period, + max_period=rope_pos_embed_max_period, + normalize_coords=rope_pos_embed_normalize_coords, + shift_coords=rope_pos_embed_shift_coords, + jitter_coords=rope_pos_embed_jitter_coords, + rescale_coords=rope_pos_embed_rescale_coords, + dtype=rope_pos_embed_dtype, + ) + else: + self.pos_embed = LearnedPosEmbed( + weight=jr.normal(key_posemb, (self.embed_len, dim)), + dim=dim, + embed_size=self.embed_size, + num_prefix_tokens=self.num_prefix_tokens, + num_embedded_prefix_tokens=self.num_embedded_prefix_tokens, + no_embed_class=self.no_embed_class, + pos_embed_reg_tokens=self.pos_embed_reg_tokens, + antialias=interpolate_antialias, + ) self.pos_drop = eqx.nn.Dropout(pos_drop_rate) if drop_path_uniform: @@ -324,7 +352,7 @@ def features( Returns: Processed feature tensor """ - key_posdrop, *block_subkeys = jr.split(key, len(self.blocks) + 1) + key_pos, *block_subkeys = jr.split(key, len(self.blocks) + 1) x = self.patch_embed(x) if mask is not None: @@ -339,15 +367,45 @@ def features( value = self.mask_token x = jnp.where(mask, x, value.astype(x.dtype)) - x = self.pos_embed( - x, - cls_token=self.cls_token, - reg_tokens=self.reg_tokens, - dynamic_img_size=self.dynamic_img_size, - ) + if self.use_rope_pos_embed: + # In models like Dinov3, RoPE is not applied here, but in self attention blocks + # It means that we have to dumbly cat prefix token and flattened x manually + _, H, W = x.shape + if inference: + rope_sincos = self.pos_embed.get_sincos( + H=H, W=W, inference=inference, key=key_pos + ) + x = jnp.concatenate( + [ + self.cls_token, + self.reg_tokens, + rearrange(x, "c h w -> (h w) c"), + ], + axis=0, + ) + else: + # TODO: pos drop + rope_sincos = None + x = self.pos_embed( + x, + cls_token=self.cls_token, + reg_tokens=self.reg_tokens, + dynamic_img_size=self.dynamic_img_size, + ) for blk, key_block in zip(self.blocks, block_subkeys): - x = blk(x, inference=inference, key=key_block, **kwargs) + if self.use_rope_pos_embed and not inference: + key_pos, key_rope = jr.split(key_pos, 2) + rope_sincos = ( + self.pos_embed.get_sincos( + H=H, W=W, inference=inference, key=key_rope + ) + if self.use_rope_pos_embed + else None + ) + x = blk( + x, rope_sincos=rope_sincos, inference=inference, key=key_block, **kwargs + ) return x From 6f21292cc79c5acec2052cbe7bd6eb2a876a45cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20POIRET?= Date: Wed, 20 Aug 2025 12:07:34 +0200 Subject: [PATCH 04/10] feat(attention): optional rope on `q` and `k` --- src/equimo/layers/attention.py | 82 +++++++++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 1 deletion(-) diff --git a/src/equimo/layers/attention.py b/src/equimo/layers/attention.py 
index 46e7791..6b66d54 100644 --- a/src/equimo/layers/attention.py +++ b/src/equimo/layers/attention.py @@ -16,6 +16,70 @@ from equimo.utils import nearest_power_of_2_divisor +def rope_rotate_half(x: jax.Array) -> jax.Array: + """Rotate last-dim pairs by 90 degrees: [x0..x_{D/2-1}, x_{D/2}..x_{D-1}] + -> [-x_{D/2}..-x_{D-1}, x0..x_{D/2-1}] + """ + x1, x2 = jnp.split(x, 2, axis=-1) + return jnp.concatenate([-x2, x1], axis=-1) + + +def rope_apply(x: jax.Array, sin: jax.Array, cos: jax.Array) -> jax.Array: + """Apply RoPE to `x` using per-position sin/cos. + + Shapes: + - x: [..., D] + - sin: [..., D] + - cos: [..., D] + Broadcasting across leading axes is supported. + """ + return (x * cos) + (rope_rotate_half(x) * sin) + + +def rope_apply_qk_last_hw( + q: jax.Array, k: jax.Array, sin: jax.Array, cos: jax.Array +) -> tuple[jax.Array, jax.Array]: + """Apply RoPE to the last HW tokens of q and k (keeping any prefix unchanged). + + Shapes: + - q, k: [H, N, D] (heads, tokens, head_dim) + - sin, cos: [HW, D] + - N = prefix + HW + + Returns: + - q_rot, k_rot: same shape as inputs. + """ + if sin.shape[-1] != q.shape[-1] or cos.shape[-1] != q.shape[-1]: + raise ValueError( + f"sin/cos last dim must equal head_dim; got {sin.shape[-1]} vs {q.shape[-1]}" + ) + N = q.shape[-2] + HW = sin.shape[-2] + prefix = N - HW + if prefix < 0: + raise ValueError(f"Sequence length N={N} smaller than HW={HW}.") + + q_dtype, k_dtype = q.dtype, k.dtype + rope_dtype = sin.dtype + + q = q.astype(rope_dtype) + k = k.astype(rope_dtype) + + # Broadcast sin/cos across heads + sin_b = sin[None, :, :] # [1, HW, D] + cos_b = cos[None, :, :] # [1, HW, D] + + q_prefix, q_tail = jnp.split(q, [prefix], axis=-2) # axis -2 is tokens + k_prefix, k_tail = jnp.split(k, [prefix], axis=-2) + + q_tail = rope_apply(q_tail, sin_b, cos_b) + k_tail = rope_apply(k_tail, sin_b, cos_b) + + q_out = jnp.concatenate([q_prefix, q_tail], axis=-2).astype(q_dtype) + k_out = jnp.concatenate([k_prefix, k_tail], axis=-2).astype(k_dtype) + return q_out, k_out + + class Attention(eqx.Module): """Multi-head self attention module. @@ -76,6 +140,7 @@ def __call__( key: PRNGKeyArray, inference: Optional[bool] = None, mask: Optional[Float[Array, ""]] = None, + rope_sincos: Optional[Tuple[jax.Array, jax.Array]] = None, ) -> Float[Array, "seqlen dim"]: key1, key2 = jr.split(key, 2) @@ -91,6 +156,15 @@ def __call__( q = jax.vmap(jax.vmap(self.q_norm))(q) k = jax.vmap(jax.vmap(self.k_norm))(k) + if rope_sincos is not None: + sin, cos = rope_sincos # [HW, D], [HW, D] + if sin.shape[-1] != self.head_dim or cos.shape[-1] != self.head_dim: + raise ValueError( + f"RoPE sin/cos last dim ({sin.shape[-1]}) must equal head_dim ({self.head_dim})." 
+ ) + # leave any prefix tokens untouched; rotate last HW tokens + q, k = rope_apply_qk_last_hw(q, k, sin, cos) + attn = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(self.head_dim) if mask is not None: @@ -317,6 +391,7 @@ def __call__( self, x: Float[Array, "seqlen dim"], key: PRNGKeyArray, + rope_sincos: Optional[Tuple[jax.Array, jax.Array]] = None, inference: Optional[bool] = None, mask: Optional[Float[Array, ""]] = None, ) -> Float[Array, "seqlen dim"]: @@ -325,6 +400,11 @@ def __call__( # I chose to define extra args here rather than passing mask directly # because not all attention mechanisms support masks as args extra_kwargs = {"mask": mask} if mask is not None else {} + attn_kwargs = ( + extra_kwargs | {"rope_sincos": rope_sincos} + if rope_sincos is not None + else extra_kwargs + ) x = self.drop_path1( x, @@ -334,7 +414,7 @@ def __call__( jax.vmap(self.prenorm)(x), inference=inference, key=key_attn, - **extra_kwargs, + **attn_kwargs, ) ) ), From e6a26742f53f5e406e7e7636b50abd3331bc7d77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20POIRET?= Date: Wed, 20 Aug 2025 12:07:34 +0200 Subject: [PATCH 05/10] fix(conversion): minor torch hub fixes --- src/equimo/conversion/utils.py | 46 ++++++++++++++++++++++------------ src/equimo/io.py | 2 +- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/equimo/conversion/utils.py b/src/equimo/conversion/utils.py index 406996f..26930e8 100644 --- a/src/equimo/conversion/utils.py +++ b/src/equimo/conversion/utils.py @@ -36,10 +36,11 @@ def convert_params_from_torch( replace_cfg: Dict[str, str], expand_cfg: Dict[str, list], squeeze_cfg: Dict[str, int | None], - whitelist: list[str], + torch_whitelist: list[str], + jax_whitelist: list[str], strict: bool = True, source: Literal["torchhub", "timm", "custom"] = "torchhub", - torch_hub_cfg: Optional[list[str]] = None, + torch_hub_cfg: Optional[dict] = None, torch_model=None, timm_cfg: Optional[list] = None, return_torch: bool = False, @@ -49,14 +50,15 @@ def convert_params_from_torch( Args: jax_model (eqx.Module): A preexisting Jax model corresponding to the checkpoint to download. - torch_hub_cfg (Tuple[str]): Arguments passed to `torch.hub.load()`. + torch_hub_cfg (dict): Arguments passed to `torch.hub.load()`. replace_cfg (Dict[str, str]): Rename parameters from key to value. expand_cfg (Dict[str, list]): Config to reshape params, see `expand_torch_tensor` sqeeze_cfg (Dict[str, int|None]): Config to squeeze tensors, opposite of expand. - whitelist (Set[str]): Parameters to exclude from format conversion. + torch_whitelist (Set[str]): Parameters to exclude from format conversion. + jax_whitelist (Set[str]): Parameters to exclude from format conversion. strict (bool): Whether to crash on missing parameters one of the models. source (str): Torch Hub or timm. - torch_hub_cfg (Optional[list]): args to pass to `torch.hub.load`. + torch_hub_cfg (dict): args to pass to `torch.hub.load`. torch_model [torch.nn.Module]: Custom torch model timm_cfg (Optional[list]): args to pass to `timm.create_model`. return_torch (bool): Return both jax and torch models. @@ -79,7 +81,7 @@ def convert_params_from_torch( raise ValueError( "The `torchhub` source is selected but `torch_hub_cfg` is None." ) - torch_model = torch.hub.load(*torch_hub_cfg) + torch_model = torch.hub.load(**torch_hub_cfg) case "timm": if timm_cfg is None: raise ValueError( @@ -107,12 +109,20 @@ def convert_params_from_torch( if param_path not in torch_params: _msg = f"{param_path} ({shape}) not found in PyTorch model." 
- if strict: + if strict and param_path not in jax_whitelist: logger.error(_msg) raise AttributeError(_msg) - logger.warning(f"{_msg} Appending `None` to flat param list.") - torch_params_flat.append(None) + if param_path in jax_whitelist: + p = param + logger.warning( + f"{_msg} Appending original parameters to flat param list because of `jax_whitelist`." + ) + else: + p = None + logger.warning(f"{_msg} Appending `None` to flat param list.") + + torch_params_flat.append(p) continue logger.info(f"Converting {param_path}...") @@ -137,7 +147,7 @@ def convert_params_from_torch( logger.warning( f"PyTorch parameters `{path}` ({param.shape}) were not converted." ) - if strict and path not in whitelist: + if strict and path not in torch_whitelist: _msg = f"The PyTorch model contains parameters ({path}) that do not have a Jax counterpart." logger.error(_msg) raise AttributeError(_msg) @@ -152,10 +162,11 @@ def convert_torch_to_equinox( replace_cfg: dict = {}, expand_cfg: dict = {}, squeeze_cfg: dict = {}, - whitelist: list[str] = [], + torch_whitelist: list[str] = [], + jax_whitelist: list[str] = [], strict: bool = True, source: Literal["torchhub", "timm"] = "torchhub", - torch_hub_cfg: Optional[list[str]] = None, + torch_hub_cfg: Optional[dict] = None, torch_model=None, timm_cfg: Optional[list] = None, return_torch: bool = False, @@ -168,10 +179,11 @@ def convert_torch_to_equinox( replace_cfg: Dict of parameter name replacements expand_cfg: Dict of dimensions to expand squeeze_cfg: Dict of dimensions to squeeze - whitelist: List of parameters to keep from JAX model + torch_whitelist: List of parameters allowed to be in PT model but not in Jax + jax_whitelist: List of parameters allowed to be in Jax model but not in PT strict: Wether to raise an issue if not all weights are converted source (str): Torch Hub or timm. - torch_hub_cfg: [repo, model_name] for torch.hub.load + torch_hub_cfg (dict): torch.hub.load config torch_model [torch.nn.Module]: Custom torch model timm_cfg (Optional[list]): args to pass to `timm.create_model`. return_torch (bool): Return both jax and torch models. 
@@ -186,7 +198,8 @@ def convert_torch_to_equinox( replace_cfg, expand_cfg, squeeze_cfg, - whitelist, + torch_whitelist, + jax_whitelist, strict, source, torch_hub_cfg, @@ -204,7 +217,8 @@ def convert_torch_to_equinox( replace_cfg, expand_cfg, squeeze_cfg, - whitelist, + torch_whitelist, + jax_whitelist, strict, source, torch_hub_cfg, diff --git a/src/equimo/io.py b/src/equimo/io.py index 161129f..6d9b097 100644 --- a/src/equimo/io.py +++ b/src/equimo/io.py @@ -25,7 +25,7 @@ def save_model( path: Path, model: eqx.Module, model_config: dict, - torch_hub_cfg: list[str] = [], + torch_hub_cfg: list[str] | dict = {}, timm_cfg: list = [], compression: bool = True, ): From b4ec21a6602075e83a6e2eebe4bc9bcd6ebf0372 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20POIRET?= Date: Wed, 20 Aug 2025 12:07:34 +0200 Subject: [PATCH 06/10] chore: bump version --- pyproject.toml | 4 ++-- src/equimo/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a840a74..64f137f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,9 @@ [project] name = "Equimo" -version = "0.4.0-alpha.14" +version = "0.4.1" description = "Implementation of popular vision models in Jax" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" dependencies = [ "einops>=0.8.0", "equinox>=0.11.5", diff --git a/src/equimo/__init__.py b/src/equimo/__init__.py index d744fa9..3d26edf 100644 --- a/src/equimo/__init__.py +++ b/src/equimo/__init__.py @@ -1 +1 @@ -__version__ = "0.4.0-alpha.14" +__version__ = "0.4.1" From 823c3183ff6bfbfccbec2a95753a4fb956d27f54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20POIRET?= Date: Wed, 20 Aug 2025 12:07:34 +0200 Subject: [PATCH 07/10] feat(vit): optional separate cls norm --- src/equimo/models/vit.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/equimo/models/vit.py b/src/equimo/models/vit.py index 3ba78db..7f836ac 100644 --- a/src/equimo/models/vit.py +++ b/src/equimo/models/vit.py @@ -148,6 +148,7 @@ class VisionTransformer(eqx.Module): blocks: List[eqx.Module] pos_drop: eqx.nn.Dropout norm: eqx.Module + local_cls_norm: eqx.Module | None head: eqx.Module dim: int = eqx.field(static=True) @@ -205,6 +206,7 @@ def __init__( ffn_layer: str | eqx.Module = Mlp, ffn_bias: bool = True, norm_layer: str | eqx.Module = eqx.nn.LayerNorm, + untie_global_and_local_cls_norm: bool = False, init_values: float | None = None, global_pool: Literal["", "token", "avg", "avgmax", "max"] = "avg", num_classes: int = 1000, @@ -327,6 +329,13 @@ def __init__( ] self.norm = norm_layer(dim, eps=eps) + + # WARNING: This has no effect in the code. 
+ # This norm layer is created to hold some training-only norm layer of Dinov3 + self.local_cls_norm = ( + norm_layer(dim, eps=eps) if untie_global_and_local_cls_norm else None + ) + self.head = ( eqx.nn.Linear(dim, num_classes, key=key_head) if num_classes > 0 From e810a6ac6c3a075f37c1234fa91444efcd017230 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20POIRET?= Date: Wed, 20 Aug 2025 12:07:34 +0200 Subject: [PATCH 08/10] fix(swiglu): non-fused swiglu ffn --- src/equimo/layers/ffn.py | 122 ++++++++++++++++++++++++++++++--------- 1 file changed, 95 insertions(+), 27 deletions(-) diff --git a/src/equimo/layers/ffn.py b/src/equimo/layers/ffn.py index 3e6bbe6..fef039f 100644 --- a/src/equimo/layers/ffn.py +++ b/src/equimo/layers/ffn.py @@ -244,14 +244,8 @@ class SwiGlu(eqx.Module): a gating mechanism where the input is transformed by two parallel paths and combined multiplicatively. - The computation flow is: - 1. Joint projection to higher dimension (w12) - 2. Split into two paths - 3. Apply SiLU to first path and multiply with second path - 4. Project back to original dimension (w3) - Attributes: - w12: Joint projection layer for both paths + w1, w2: projection layers for both paths w3: Final projection layer drop1: Dropout after gating drop2: Dropout after final projection @@ -260,7 +254,8 @@ class SwiGlu(eqx.Module): [1]: https://arxiv.org/pdf/2002.05202 """ - w12: eqx.nn.Linear + w1: eqx.nn.Linear + w2: eqx.nn.Linear w3: eqx.nn.Linear drop1: eqx.nn.Dropout drop2: eqx.nn.Dropout @@ -273,6 +268,7 @@ def __init__( out_features: int | None = None, hidden_features: int | None = None, dropout_rate: float = 0.0, + align_to: int = 8, bias: bool = True, **kwargs, ): @@ -284,19 +280,25 @@ def __init__( out_features: Number of output features (default: same as in_features) hidden_features: Size of hidden dimension (default: same as in_features) dropout_rate: Dropout probability (default: 0.0) + align_to: constrains hidden features to be a multiple of a given int (default: 8) bias: Whether to include bias in linear layers (default: True) **kwargs: Additional arguments """ - key_fc1, key_fc2 = jr.split(key, 2) + key_fc1, key_fc2, key_fc3 = jr.split(key, 3) - hidden_features = hidden_features or in_features out_features = out_features or in_features + hidden_features = hidden_features or in_features + d = int(hidden_features * 2 / 3) + hidden_features = d + (-d % align_to) - self.w12 = eqx.nn.Linear( - in_features, 2 * hidden_features, use_bias=bias, key=key_fc1 + self.w1 = eqx.nn.Linear( + in_features, hidden_features, use_bias=bias, key=key_fc1 + ) + self.w2 = eqx.nn.Linear( + in_features, hidden_features, use_bias=bias, key=key_fc2 ) self.w3 = eqx.nn.Linear( - hidden_features, out_features, use_bias=bias, key=key_fc2 + hidden_features, out_features, use_bias=bias, key=key_fc3 ) self.drop1 = eqx.nn.Dropout(dropout_rate) @@ -310,8 +312,9 @@ def __call__( ) -> Float[Array, "seqlen dim"]: key_dr1, key_dr2 = jr.split(key, 2) - x12 = jax.vmap(self.w12)(x) - x1, x2 = jnp.split(x12, 2, axis=-1) + x1 = jax.vmap(self.w1)(x) + x2 = jax.vmap(self.w2)(x) + x = self.drop1( jax.nn.silu(x1) * x2, inference=inference, @@ -327,7 +330,38 @@ def __call__( return x -class SwiGluFused(SwiGlu): +class SwiGluFused(eqx.Module): + """SwiGLU activation module with dropout. 
+ + This matches the implementation of Dinov2 giant at + https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/layers/swiglu_ffn.py#L54 + + Implements the SwiGLU (Swish-Gated Linear Unit) activation function with dropout, + as described in "GLU Variants Improve Transformer" paper [1]. The architecture uses + a gating mechanism where the input is transformed by two parallel paths and + combined multiplicatively. + + The computation flow is: + 1. Joint projection to higher dimension (w12) + 2. Split into two paths + 3. Apply SiLU to first path and multiply with second path + 4. Project back to original dimension (w3) + + Attributes: + w12: Joint projection layer for both paths + w3: Final projection layer + drop1: Dropout after gating + drop2: Dropout after final projection + + References: + [1]: https://arxiv.org/pdf/2002.05202 + """ + + w12: eqx.nn.Linear + w3: eqx.nn.Linear + drop1: eqx.nn.Dropout + drop2: eqx.nn.Dropout + def __init__( self, in_features: int, @@ -335,27 +369,61 @@ def __init__( key: PRNGKeyArray, out_features: int | None = None, hidden_features: int | None = None, - dropout_rate: float = 0, + dropout_rate: float = 0.0, bias: bool = True, **kwargs, ): - """This matches the implementation of Dinov2 giant at - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/layers/swiglu_ffn.py#L54 + """Initialize the SwiGLU module. + + Args: + in_features: Number of input features + key: PRNG key for initialization + out_features: Number of output features (default: same as in_features) + hidden_features: Size of hidden dimension (default: same as in_features) + dropout_rate: Dropout probability (default: 0.0) + bias: Whether to include bias in linear layers (default: True) + **kwargs: Additional arguments """ + key_fc1, key_fc2 = jr.split(key, 2) + out_features = out_features or in_features hidden_features = hidden_features or in_features hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 - super().__init__( - in_features, - key=key, - out_features=out_features, - hidden_features=hidden_features, - dropout_rate=dropout_rate, - bias=bias, - **kwargs, + self.w12 = eqx.nn.Linear( + in_features, 2 * hidden_features, use_bias=bias, key=key_fc1 + ) + self.w3 = eqx.nn.Linear( + hidden_features, out_features, use_bias=bias, key=key_fc2 + ) + + self.drop1 = eqx.nn.Dropout(dropout_rate) + self.drop2 = eqx.nn.Dropout(dropout_rate) + + def __call__( + self, + x: Float[Array, "seqlen dim"], + key: PRNGKeyArray, + inference: Optional[bool] = None, + ) -> Float[Array, "seqlen dim"]: + key_dr1, key_dr2 = jr.split(key, 2) + + x12 = jax.vmap(self.w12)(x) + x1, x2 = jnp.split(x12, 2, axis=-1) + x = self.drop1( + jax.nn.silu(x1) * x2, + inference=inference, + key=key_dr1, ) + x = self.drop2( + jax.vmap(self.w3)(x), + inference=inference, + key=key_dr2, + ) + + return x + def get_ffn(module: str | eqx.Module) -> eqx.Module: """Get an `eqx.Module` from its common name. 
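[Editor's note — not part of the patch] The only difference between the fused and non-fused variants above is whether the two input projections are stored as one matrix (`w12`, split after the matmul) or as two matrices (`w1`, `w2`), which is what checkpoints converted later in the series expose. The sketch below uses hypothetical toy weights to show that the two parameterizations are interchangeable once `w12` is the concatenation of `w1` and `w2`, and reproduces the hidden-width rule (two thirds of the MLP hidden size, rounded up to `align_to`).

import jax
import jax.numpy as jnp

def swiglu_fused(x, w12, w3):
    # one joint projection, then split into value and gate halves
    x12 = x @ w12
    x1, x2 = jnp.split(x12, 2, axis=-1)
    return (jax.nn.silu(x1) * x2) @ w3

def swiglu_unfused(x, w1, w2, w3):
    # two separate projections; matches checkpoints that store w1/w2 separately
    return (jax.nn.silu(x @ w1) * (x @ w2)) @ w3

def aligned_hidden(hidden_features, align_to=8):
    # hidden-width rule from the patch: 2/3 of the MLP hidden size, rounded up
    d = int(hidden_features * 2 / 3)
    return d + (-d % align_to)

dim, hidden = 384, aligned_hidden(384 * 4)  # 1024 for align_to=8
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
w1 = 0.02 * jax.random.normal(k1, (dim, hidden))
w2 = 0.02 * jax.random.normal(k2, (dim, hidden))
w3 = 0.02 * jax.random.normal(k3, (hidden, dim))
x = jax.random.normal(k4, (10, dim))

y_unfused = swiglu_unfused(x, w1, w2, w3)
y_fused = swiglu_fused(x, jnp.concatenate([w1, w2], axis=-1), w3)
assert jnp.allclose(y_unfused, y_fused, atol=1e-5)
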
From fd3c408f8e1ae381149ea6a886e8b7a34930fd77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20POIRET?= Date: Wed, 20 Aug 2025 12:07:34 +0200 Subject: [PATCH 09/10] feat(vit): pass kwargs to swiglu --- src/equimo/layers/attention.py | 2 ++ src/equimo/models/vit.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/equimo/layers/attention.py b/src/equimo/layers/attention.py index 6b66d54..c3a5b08 100644 --- a/src/equimo/layers/attention.py +++ b/src/equimo/layers/attention.py @@ -329,6 +329,7 @@ def __init__( ffn_layer: eqx.Module = Mlp, ffn_bias: bool = True, ffn_norm: bool = False, + ffn_kwargs: dict = {}, norm_layer: eqx.Module = eqx.nn.LayerNorm, post_attention_norm: bool = False, init_values: float | None = None, @@ -382,6 +383,7 @@ def __init__( bias=ffn_bias, eps=eps, key=key_mlp, + **ffn_kwargs, ) self.drop_path1 = DropPathAdd(dr1) diff --git a/src/equimo/models/vit.py b/src/equimo/models/vit.py index 7f836ac..9f26907 100644 --- a/src/equimo/models/vit.py +++ b/src/equimo/models/vit.py @@ -205,6 +205,7 @@ def __init__( attn_layer: str | eqx.Module = Attention, ffn_layer: str | eqx.Module = Mlp, ffn_bias: bool = True, + ffn_kwargs: dict = {}, norm_layer: str | eqx.Module = eqx.nn.LayerNorm, untie_global_and_local_cls_norm: bool = False, init_values: float | None = None, @@ -320,6 +321,7 @@ def __init__( attn_layer=attn_layer[i], ffn_layer=ffn_layer, ffn_bias=ffn_bias, + ffn_kwargs=ffn_kwargs, norm_layer=norm_layer, init_values=init_values, eps=eps, From 52d21570f201a1db99d9b5a893ee5bfbecd154b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20POIRET?= Date: Wed, 20 Aug 2025 12:07:34 +0200 Subject: [PATCH 10/10] feat(dinov3): conversion script --- models/dinov3.py | 249 +++++++++++++++++++++++++++++++++++++++++++++++ uv.lock | 30 ++++-- 2 files changed, 272 insertions(+), 7 deletions(-) create mode 100644 models/dinov3.py diff --git a/models/dinov3.py b/models/dinov3.py new file mode 100644 index 0000000..4f3df59 --- /dev/null +++ b/models/dinov3.py @@ -0,0 +1,249 @@ +from pathlib import Path + +import equinox as eqx +import jax +import jax.numpy as jnp +import numpy as np + +import equimo.models as em +from equimo.conversion.utils import convert_torch_to_equinox +from equimo.io import save_model + +DIR = Path("~/.cache/torch/hub/dinov3").expanduser() + + +def compare(j, t) -> float: + j = np.array(j) + t = t.squeeze().detach().numpy() + return float(np.mean(np.abs(j - t))) + + +weights = { + # LVD + "dinov3_vits16_pretrain_lvd1689m": str( + ( + Path( + "~/.cache/torch/hub/dinov3/weights/dinov3_vits16_pretrain_lvd1689m-08c60483.pth" + ) + ).expanduser() + ), + "dinov3_vits16plus_pretrain_lvd1689m": str( + ( + Path( + "~/.cache/torch/hub/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth" + ) + ).expanduser() + ), + "dinov3_vitb16_pretrain_lvd1689m": str( + ( + Path( + "~/.cache/torch/hub/dinov3/weights/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth" + ) + ).expanduser() + ), + "dinov3_vitl16_pretrain_lvd1689m": str( + ( + Path( + "~/.cache/torch/hub/dinov3/weights/dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth" + ) + ).expanduser() + ), + "dinov3_vith16plus_pretrain_lvd1689m": str( + ( + Path( + "~/.cache/torch/hub/dinov3/weights/dinov3_vith16plus_pretrain_lvd1689m-7c1da9a5.pth" + ) + ).expanduser() + ), + "dinov3_vit7b16_pretrain_lvd1689m": str( + ( + Path( + "~/.cache/torch/hub/dinov3/weights/dinov3_vit7b16_pretrain_lvd1689m-a955f4ea.pth" + ) + ).expanduser() + ), + # SAT + "dinov3_vitl16_pretrain_sat493m": str( + ( + Path( + 
"~/.cache/torch/hub/dinov3/weights/dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth" + ) + ).expanduser() + ), + "dinov3_vit7b16_pretrain_sat493m": str( + ( + Path( + "~/.cache/torch/hub/dinov3/weights/dinov3_vit7b16_pretrain_sat493m-a6675841.pth" + ) + ).expanduser() + ), +} + +configs = { + "dinov3_vits16_pretrain_lvd1689m": { + "dim": 384, + "num_heads": 6, + "depths": [12], + "reg_tokens": 4, + "mlp_ratio": 4.0, + }, + "dinov3_vits16plus_pretrain_lvd1689m": { + "dim": 384, + "num_heads": 6, + "depths": [12], + "reg_tokens": 4, + "mlp_ratio": 6.0, + "ffn_layer": "swiglu", + }, + "dinov3_vitb16_pretrain_lvd1689m": { + "dim": 768, + "num_heads": 12, + "depths": [12], + "reg_tokens": 4, + "mlp_ratio": 4.0, + }, + "dinov3_vitl16_pretrain_lvd1689m": { + "dim": 1024, + "num_heads": 16, + "depths": [24], + "reg_tokens": 4, + "mlp_ratio": 4.0, + }, + "dinov3_vith16plus_pretrain_lvd1689m": { + "dim": 1280, + "num_heads": 20, + "depths": [32], + "reg_tokens": 4, + "mlp_ratio": 6.0, + "ffn_layer": "swiglu", + }, + "dinov3_vit7b16_pretrain_lvd1689m": { + "dim": 4096, + "num_heads": 32, + "depths": [40], + "reg_tokens": 4, + "mlp_ratio": 3.0, + "untie_global_and_local_cls_norm": True, + "ffn_kwargs": {"align_to": 64}, + }, + "dinov3_vitl16_pretrain_sat493m": { + "dim": 1024, + "num_heads": 16, + "depths": [24], + "reg_tokens": 4, + "mlp_ratio": 4.0, + "untie_global_and_local_cls_norm": True, + }, + "dinov3_vit7b16_pretrain_sat493m": { + "dim": 4096, + "num_heads": 32, + "depths": [40], + "reg_tokens": 4, + "mlp_ratio": 3.0, + "untie_global_and_local_cls_norm": True, + "ffn_kwargs": {"align_to": 64}, + }, +} + +citr = iter(configs.items()) +name, config = next(citr) + + +def main(): + try: + import torch + except: + raise ImportError("`torch` not available") + + key = jax.random.PRNGKey(42) + dinov3_config = { + "img_size": 224, + "in_channels": 3, + "patch_size": 16, + "num_classes": 0, + "use_mask_token": True, + "use_rope_pos_embed": True, + "reg_tokens": 4, + "init_values": 1e-5, + "eps": 1e-5, + "dynamic_img_size": True, + "act_layer": "exactgelu", + } + + for name, config in configs.items(): + print(f"Converting {name}...") + + cfg = dinov3_config | config + + dinov3 = em.VisionTransformer( + **cfg, + key=key, + ) + + torch_name = "_".join(name.split("_")[:-2]) + torch_hub_cfg = { + "repo_or_dir": str(DIR / "dinov3"), + "model": torch_name, + "source": "local", + "weights": weights[name], + } + # model = torch.hub.load(**torch_hub_cfg) + + replace_cfg = { + "reg_tokens": "storage_tokens", + "blocks.0.blocks": "blocks", + ".prenorm.": ".norm1.", + ".norm.": ".norm2.", + } + expand_cfg = {"patch_embed.proj.bias": ["after", 2]} + squeeze_cfg = { + "pos_embed": 0, + "cls_token": 0, + "storage_tokens": 0, + } + torch_whitelist = [] + jax_whitelist = ["pos_embed.periods"] + + dinov3, torch_model = convert_torch_to_equinox( + dinov3, + replace_cfg, + expand_cfg, + squeeze_cfg, + torch_whitelist, + jax_whitelist, + strict=True, + torch_hub_cfg=torch_hub_cfg, + return_torch=True, + ) + dinov3 = eqx.nn.inference_mode(dinov3, True) + + arr = np.random.randn(3, cfg["img_size"], cfg["img_size"]) + jax_arr = jnp.array(arr) + torch_arr = torch.tensor(arr).unsqueeze(0).float() + + assert ( + err := compare( + dinov3.features(jax_arr, inference=True, key=key), + torch_model.forward_features(torch_arr)["x_prenorm"], + ) + < 5e-4 + ), f"Conversion error: {err}" + + save_path = Path(f"~/.cache/equimo/dinov3/{name}").expanduser() + save_model( + save_path, + dinov3, + cfg, + torch_hub_cfg, + compression=True, + ) + + # 
Ensure the serialization is okay + # loaded_model = load_model(cls="vit", path=save_path.with_suffix(".tar.lz4")) + # a = dinov3.features(jax_arr, inference=True, key=key) + # b = loaded_model.features(jax_arr, inference=True, key=key) + # jnp.mean((a - b) ** 2) + + +if __name__ == "__main__": + main() diff --git a/uv.lock b/uv.lock index 9fb0caa..4609c8f 100644 --- a/uv.lock +++ b/uv.lock @@ -2,9 +2,9 @@ version = 1 revision = 1 requires-python = ">=3.11" resolution-markers = [ - "python_full_version < '3.12'", - "python_full_version == '3.12.*'", "python_full_version >= '3.13'", + "python_full_version == '3.12.*'", + "python_full_version < '3.12'", ] [[package]] @@ -166,7 +166,7 @@ wheels = [ [[package]] name = "equimo" -version = "0.4.0a2" +version = "0.4.1" source = { virtual = "." } dependencies = [ { name = "einops" }, @@ -381,22 +381,38 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/97/34/165b87ea55184770a0c1fcdb7e017199974ad2e271451fd045cfe35f3add/h5py-3.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:4f97ecde7ac6513b21cd95efdfc38dc6d19f96f6ca6f2a30550e94e551458e0a", size = 2940890 }, ] +[[package]] +name = "hf-xet" +version = "1.1.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/0a/a0f56735940fde6dd627602fec9ab3bad23f66a272397560abd65aba416e/hf_xet-1.1.7.tar.gz", hash = "sha256:20cec8db4561338824a3b5f8c19774055b04a8df7fff0cb1ff2cb1a0c1607b80", size = 477719 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/7c/8d7803995caf14e7d19a392a486a040f923e2cfeff824e9b800b92072f76/hf_xet-1.1.7-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:60dae4b44d520819e54e216a2505685248ec0adbdb2dd4848b17aa85a0375cde", size = 2761743 }, + { url = "https://files.pythonhosted.org/packages/51/a3/fa5897099454aa287022a34a30e68dbff0e617760f774f8bd1db17f06bd4/hf_xet-1.1.7-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:b109f4c11e01c057fc82004c9e51e6cdfe2cb230637644ade40c599739067b2e", size = 2624331 }, + { url = "https://files.pythonhosted.org/packages/86/50/2446a132267e60b8a48b2e5835d6e24fd988000d0f5b9b15ebd6d64ef769/hf_xet-1.1.7-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6efaaf1a5a9fc3a501d3e71e88a6bfebc69ee3a716d0e713a931c8b8d920038f", size = 3183844 }, + { url = "https://files.pythonhosted.org/packages/20/8f/ccc670616bb9beee867c6bb7139f7eab2b1370fe426503c25f5cbb27b148/hf_xet-1.1.7-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:751571540f9c1fbad9afcf222a5fb96daf2384bf821317b8bfb0c59d86078513", size = 3074209 }, + { url = "https://files.pythonhosted.org/packages/21/0a/4c30e1eb77205565b854f5e4a82cf1f056214e4dc87f2918ebf83d47ae14/hf_xet-1.1.7-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:18b61bbae92d56ae731b92087c44efcac216071182c603fc535f8e29ec4b09b8", size = 3239602 }, + { url = "https://files.pythonhosted.org/packages/f5/1e/fc7e9baf14152662ef0b35fa52a6e889f770a7ed14ac239de3c829ecb47e/hf_xet-1.1.7-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:713f2bff61b252f8523739969f247aa354ad8e6d869b8281e174e2ea1bb8d604", size = 3348184 }, + { url = "https://files.pythonhosted.org/packages/a3/73/e354eae84ceff117ec3560141224724794828927fcc013c5b449bf0b8745/hf_xet-1.1.7-cp37-abi3-win_amd64.whl", hash = "sha256:2e356da7d284479ae0f1dea3cf5a2f74fdf925d6dca84ac4341930d892c7cb34", size = 2820008 }, +] + [[package]] name = "huggingface-hub" -version = "0.29.1" +version = "0.34.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = 
"fsspec" }, + { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, { name = "packaging" }, { name = "pyyaml" }, { name = "requests" }, { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/22/37/797d6476f13e5ef6af5fc48a5d641d32b39c37e166ccf40c3714c5854a85/huggingface_hub-0.29.1.tar.gz", hash = "sha256:9524eae42077b8ff4fc459ceb7a514eca1c1232b775276b009709fe2a084f250", size = 389776 } +sdist = { url = "https://files.pythonhosted.org/packages/45/c9/bdbe19339f76d12985bc03572f330a01a93c04dffecaaea3061bdd7fb892/huggingface_hub-0.34.4.tar.gz", hash = "sha256:a4228daa6fb001be3f4f4bdaf9a0db00e1739235702848df00885c9b5742c85c", size = 459768 } wheels = [ - { url = "https://files.pythonhosted.org/packages/ae/05/75b90de9093de0aadafc868bb2fa7c57651fd8f45384adf39bd77f63980d/huggingface_hub-0.29.1-py3-none-any.whl", hash = "sha256:352f69caf16566c7b6de84b54a822f6238e17ddd8ae3da4f8f2272aea5b198d5", size = 468049 }, + { url = "https://files.pythonhosted.org/packages/39/7b/bb06b061991107cd8783f300adff3e7b7f284e330fd82f507f2a1417b11d/huggingface_hub-0.34.4-py3-none-any.whl", hash = "sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a", size = 561452 }, ] [[package]] @@ -782,8 +798,8 @@ name = "ml-dtypes" version = "0.5.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.12'", "python_full_version == '3.12.*'", + "python_full_version < '3.12'", ] dependencies = [ { name = "numpy", marker = "python_full_version < '3.13'" },