From ee26a04a679e23fc218b1a3a7e04bdf5dda67f1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20POIRET?= Date: Tue, 11 Feb 2025 11:09:44 +0100 Subject: [PATCH 1/3] refactor(dropout)!: inference This removes the `Dropout` class and renames `enable_dropout` to `inference`. `inference` is not mandatory anymore. --- equimo/layers/attention.py | 138 ++++++++++++++++----------------- equimo/layers/convolution.py | 4 +- equimo/layers/dropout.py | 97 +---------------------- equimo/layers/ffn.py | 36 ++++----- equimo/layers/generic.py | 8 +- equimo/layers/mamba.py | 2 +- equimo/layers/sharing.py | 8 +- equimo/models/fastervit.py | 30 +++---- equimo/models/mlla.py | 14 ++-- equimo/models/partialformer.py | 39 +++++----- equimo/models/shvit.py | 40 +++++----- equimo/models/vit.py | 35 ++++----- equimo/models/vssd.py | 14 ++-- 13 files changed, 186 insertions(+), 279 deletions(-) diff --git a/equimo/layers/attention.py b/equimo/layers/attention.py index 156a406..835d440 100644 --- a/equimo/layers/attention.py +++ b/equimo/layers/attention.py @@ -7,7 +7,7 @@ from einops import rearrange, reduce from jaxtyping import Array, Float, PRNGKeyArray -from equimo.layers.dropout import Dropout, DropPathAdd +from equimo.layers.dropout import DropPathAdd from equimo.layers.ffn import Mlp from equimo.layers.mamba import Mamba2Mixer from equimo.layers.norm import LayerScale @@ -35,8 +35,8 @@ class Attention(eqx.Module): proj: eqx.nn.Linear q_norm: eqx.Module k_norm: eqx.Module - attn_drop: Dropout - proj_drop: Dropout + attn_drop: eqx.nn.Dropout + proj_drop: eqx.nn.Dropout def __init__( self, @@ -65,14 +65,14 @@ def __init__( self.q_norm = norm_layer(dim) if qk_norm else eqx.nn.Identity() self.k_norm = norm_layer(dim) if qk_norm else eqx.nn.Identity() - self.attn_drop = Dropout(attn_drop) - self.proj_drop = Dropout(proj_drop) + self.attn_drop = eqx.nn.Dropout(attn_drop) + self.proj_drop = eqx.nn.Dropout(proj_drop) def __call__( self, x: Float[Array, "seqlen dim"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen dim"]: key1, key2 = jr.split(key, 2) @@ -90,12 +90,12 @@ def __call__( attn = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(self.head_dim) attn = jax.nn.softmax(attn, axis=-1) - attn = self.attn_drop(attn, inference=not enable_dropout, key=key1) + attn = self.attn_drop(attn, inference=inference, key=key1) x = jnp.einsum("hqk,hvd->hqd", attn, v) x = rearrange(x, "h s d -> s (h d)") x = jax.vmap(self.proj)(x) - x = self.proj_drop(x, inference=not enable_dropout, key=key2) + x = self.proj_drop(x, inference=inference, key=key2) return x @@ -122,8 +122,8 @@ class WindowedAttention(eqx.Module): proj: eqx.nn.Linear q_norm: eqx.Module k_norm: eqx.Module - attn_drop: Dropout - proj_drop: Dropout + attn_drop: eqx.nn.Dropout + proj_drop: eqx.nn.Dropout pos_emb_funct: PosEmbMLPSwinv2D def __init__( @@ -164,14 +164,14 @@ def __init__( self.q_norm = norm_layer(dim) if qk_norm else eqx.nn.Identity() self.k_norm = norm_layer(dim) if qk_norm else eqx.nn.Identity() - self.attn_drop = Dropout(attn_drop) - self.proj_drop = Dropout(proj_drop) + self.attn_drop = eqx.nn.Dropout(attn_drop) + self.proj_drop = eqx.nn.Dropout(proj_drop) def __call__( self, x: Float[Array, "seqlen dim"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen dim"]: key1, key2 = jr.split(key, 2) @@ -190,12 +190,12 @@ def __call__( attn = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(self.head_dim) attn = self.pos_emb_funct(attn, 
self.resolution**2) attn = jax.nn.softmax(attn, axis=-1) - attn = self.attn_drop(attn, inference=not enable_dropout, key=key1) + attn = self.attn_drop(attn, inference=inference, key=key1) x = jnp.einsum("hqk,hvd->hqd", attn, v) x = rearrange(x, "h s d -> s (h d)") x = jax.vmap(self.proj)(x) - x = self.proj_drop(x, inference=not enable_dropout, key=key2) + x = self.proj_drop(x, inference=inference, key=key2) return x @@ -209,7 +209,7 @@ class AttentionBlock(eqx.Module): - MLP feed-forward network - Residual connections - Optional layer scaling - - Dropout paths + - eqx.nn.Dropout paths Attributes: prenorm: First layer normalization (before attention) @@ -303,8 +303,8 @@ def __init__( def __call__( self, x: Float[Array, "seqlen dim"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen dim"]: key_attn, key_mlp, key_dr1, key_dr2 = jr.split(key, 4) @@ -314,12 +314,12 @@ def __call__( jax.vmap(self.postnorm)( self.attn( jax.vmap(self.prenorm)(x), - enable_dropout, + inference=inference, key=key_attn, ) ) ), - inference=not enable_dropout, + inference=inference, key=key_dr1, ) x = self.drop_path2( @@ -327,11 +327,11 @@ def __call__( self.ls2( self.mlp( jax.vmap(self.norm)(x), - enable_dropout, + inference=inference, key=key_mlp, ) ), - inference=not enable_dropout, + inference=inference, key=key_dr2, ) @@ -543,8 +543,8 @@ def __call__( self, x: Float[Array, "seqlen dim"], carrier_tokens: Float[Array, "..."], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen dim"]: key_attn, key_hattn, key_dr1, key_hdr1, key_dr2, key_hdr2, key_mlp, key_hmlp = ( jr.split(key, 8) @@ -576,11 +576,11 @@ def __call__( self.hat_ls1( self.hat_attn( jax.vmap(self.hat_norm1)(ct), - enable_dropout, - key_hattn, + inference=inference, + key=key_hattn, ) ), - inference=not enable_dropout, + inference=inference, key=key_hdr1, ) ct = self.hat_drop_path( @@ -588,11 +588,11 @@ def __call__( self.hat_ls2( self.hat_mlp( jax.vmap(self.hat_norm2)(ct), - enable_dropout, - key_hmlp, + inference=inference, + key=key_hmlp, ) ), - inference=not enable_dropout, + inference=inference, key=key_hdr2, ) @@ -612,11 +612,11 @@ def __call__( self.ls1( self.attn( jax.vmap(self.norm1)(x), - enable_dropout, - key_attn, + inference=inference, + key=key_attn, ) ), - inference=not enable_dropout, + inference=inference, key=key_dr1, ) x = self.drop_path2( @@ -624,11 +624,11 @@ def __call__( self.ls2( self.mlp( jax.vmap(self.norm2)(x), - enable_dropout, - key_mlp, + inference=inference, + key=key_mlp, ) ), - inference=not enable_dropout, + inference=inference, key=key_dr2, ) @@ -728,7 +728,7 @@ def flatten(self, x: Float[Array, "channels height width"]): def __call__( self, x: Float[Array, "channels height width"], - enable_dropout: Optional[bool] = None, + inference: Optional[bool] = None, key: Optional[PRNGKeyArray] = None, ) -> Float[Array, "channels height width"]: C, H, W = x.shape @@ -797,7 +797,7 @@ def __init__( def __call__( self, x: Float[Array, "seqlen dim"], - enable_dropout: Optional[bool] = None, + inference: Optional[bool] = None, key: Optional[PRNGKeyArray] = None, ) -> Float[Array, "seqlen dim"]: n, c = x.shape @@ -982,7 +982,7 @@ def __init__( def __call__( self, x: Float[Array, "seqlen dim"], - enable_dropout: Optional[bool] = None, + inference: Optional[bool] = None, key: Optional[PRNGKeyArray] = None, ) -> Float[Array, "seqlen dim"]: key_attn, key_dr1, key_dr2, key_mlp = jr.split(key, 4) @@ -1001,12 +1001,12 @@ def __call__( 
x1 = rearrange(jax.vmap(self.in_proj)(x1), "(h w) c -> c h w", h=h, w=w) x1 = self.act(rearrange(self.dwc(x1), "c h w -> (h w) c")) - x1 = self.attn(x1, enable_dropout, key_attn) + x1 = self.attn(x1, inference=inference, key=key_attn) if self.use_dwc: x1 = jax.vmap(self.out_proj)(x * act_res) - x = self.drop_path1(x, x1, inference=not enable_dropout, key=key_dr1) + x = self.drop_path1(x, x1, inference=inference, key=key_dr1) x += rearrange( self.cpe2(rearrange(x, "(h w) c -> c h w", h=h, w=w)), @@ -1015,8 +1015,8 @@ def __call__( return self.drop_path2( x, - self.mlp(jax.vmap(self.norm2)(x), enable_dropout, key_mlp), - inference=not enable_dropout, + self.mlp(jax.vmap(self.norm2)(x), inference=inference, key=key_mlp), + inference=inference, key=key_dr2, ) @@ -1028,7 +1028,7 @@ class MMSA(eqx.Module): - Multi-head structure - Attention score projection for multi-scale interaction - Normalized query/key processing - - Dropout regularization + - eqx.nn.Dropout regularization Attributes: dim: Total dimension of input/output @@ -1046,8 +1046,8 @@ class MMSA(eqx.Module): attn_proj2: eqx.nn.Linear q_norm: eqx.Module k_norm: eqx.Module - attn_drop: Dropout - proj_drop: Dropout + attn_drop: eqx.nn.Dropout + proj_drop: eqx.nn.Dropout def __init__( self, @@ -1087,14 +1087,14 @@ def __init__( self.q_norm = norm_layer(dim) if qk_norm else eqx.nn.Identity() self.k_norm = norm_layer(dim) if qk_norm else eqx.nn.Identity() - self.attn_drop = Dropout(attn_drop) - self.proj_drop = Dropout(proj_drop) + self.attn_drop = eqx.nn.Dropout(attn_drop) + self.proj_drop = eqx.nn.Dropout(proj_drop) def __call__( self, x: Float[Array, "seqlen dim"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen dim"]: key1, key2 = jr.split(key, 2) @@ -1119,12 +1119,12 @@ def __call__( jax.vmap(jax.vmap(self.attn_proj2))(attn), "q k h -> h q k", ) - attn = self.attn_drop(attn, inference=not enable_dropout, key=key1) + attn = self.attn_drop(attn, inference=inference, key=key1) x = jnp.einsum("hqk,hvd->hqd", attn, v) x = rearrange(x, "h s d -> s (h d)") x = jax.vmap(self.proj)(x) - x = self.proj_drop(x, inference=not enable_dropout, key=key2) + x = self.proj_drop(x, inference=inference, key=key2) return x @@ -1154,8 +1154,8 @@ class SQA(eqx.Module): proj_norm: eqx.Module q_norm: eqx.Module k_norm: eqx.Module - attn_drop: Dropout - proj_drop: Dropout + attn_drop: eqx.nn.Dropout + proj_drop: eqx.nn.Dropout def __init__( self, @@ -1191,15 +1191,15 @@ def __init__( self.q_norm = norm_layer(dim) if qk_norm else eqx.nn.Identity() self.k_norm = norm_layer(dim) if qk_norm else eqx.nn.Identity() - self.attn_drop = Dropout(attn_drop) - self.proj_drop = Dropout(proj_drop) + self.attn_drop = eqx.nn.Dropout(attn_drop) + self.proj_drop = eqx.nn.Dropout(proj_drop) def __call__( self, x: Float[Array, "seqlen dim"], q: Float[Array, "1 dim"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen_x dim"]: key1, key2 = jr.split(key, 2) @@ -1223,14 +1223,14 @@ def __call__( attn = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(self.head_dim) attn = jax.nn.softmax(attn, axis=-1) - attn = self.attn_drop(attn, inference=not enable_dropout, key=key1) + attn = self.attn_drop(attn, inference=inference, key=key1) x1 = jnp.einsum("hqk,hvd->hqd", attn, v) x1 = rearrange(x1, "h s d -> s (h d)") x1 = jax.vmap(self.proj_norm)(jax.vmap(self.proj1)(x1)) x1 = jax.vmap(self.proj2)(jax.nn.relu(x1)) - x = self.proj_drop(x + x1, inference=not enable_dropout, key=key2) + x = 
self.proj_drop(x + x1, inference=inference, key=key2) return x @@ -1361,7 +1361,7 @@ def __call__( self, x: Float[Array, "seqlen dim"], qa: Float[Array, "1 dim"], - enable_dropout: Optional[bool] = None, + inference: Optional[bool] = None, key: Optional[PRNGKeyArray] = None, ) -> Tuple[Float[Array, "seqlen dim"], Float[Array, "1 dim"]]: key_mmsa, key_sqa, key_dr1, key_dr2, key_mlp = jr.split(key, 5) @@ -1388,24 +1388,24 @@ def __call__( f = rearrange(f, "n s c -> (n s) c") b = rearrange(b, "n s c -> (n s) c") - qf = self.mmsa(jnp.concat([qa, f], axis=0), enable_dropout, key=key_mmsa) + qf = self.mmsa(jnp.concat([qa, f], axis=0), inference=inference, key=key_mmsa) qa, f = jnp.split(qf, [1]) - b = self.sqa(b, qa, enable_dropout, key=key_sqa) + b = self.sqa(b, qa, inference=inference, key=key_sqa) x1 = self.ls1(jnp.concat([f, b], axis=0)) - x = self.drop_path1(x, x1, inference=not enable_dropout, key=key_dr1) + x = self.drop_path1(x, x1, inference=inference, key=key_dr1) qa, x1 = jnp.split( self.mlp( jax.vmap(self.norm2)(jnp.concat([qa, x])), - enable_dropout, + inference=inference, key=key_mlp, )[1], ) x = self.drop_path2( x, self.ls2(x1), - inference=not enable_dropout, + inference=inference, key=key_dr2, ) @@ -1442,8 +1442,8 @@ class LinearAngularAttention(eqx.Module): qkv: eqx.nn.Linear proj: eqx.nn.Linear - attn_drop: Dropout - proj_drop: Dropout + attn_drop: eqx.nn.Dropout + proj_drop: eqx.nn.Dropout dconv: eqx.nn.Conv def __init__( @@ -1486,14 +1486,14 @@ def __init__( key=key_dconv, ) - self.attn_drop = Dropout(attn_drop) - self.proj_drop = Dropout(proj_drop) + self.attn_drop = eqx.nn.Dropout(attn_drop) + self.proj_drop = eqx.nn.Dropout(proj_drop) def __call__( self, x: Float[Array, "seqlen dim"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen dim"]: key1, key2 = jr.split(key, 2) @@ -1510,7 +1510,7 @@ def __call__( if self.sparse_reg: attn = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(self.head_dim) attn = jax.nn.softmax(attn, axis=-1) - attn = self.attn_drop(attn, inference=not enable_dropout, key=key1) + attn = self.attn_drop(attn, inference=inference, key=key1) sparse = jnp.where(attn > self.sparsity_threshold, attn, 0) q = q / jnp.linalg.norm(q, axis=-1, keepdims=True) @@ -1529,6 +1529,6 @@ def __call__( x = rearrange(x, "h s d -> s (h d)") x = jax.vmap(self.proj)(x) - x = self.proj_drop(x, inference=not enable_dropout, key=key2) + x = self.proj_drop(x, inference=inference, key=key2) return x diff --git a/equimo/layers/convolution.py b/equimo/layers/convolution.py index c7f6ad1..7c4b6e2 100644 --- a/equimo/layers/convolution.py +++ b/equimo/layers/convolution.py @@ -118,15 +118,15 @@ def depermute( def __call__( self, x: Float[Array, "channels height width"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "channels height width"]: _, h, w = x.shape x2 = self.act(self.norm1(self.conv1(x))) x2 = self.norm2(self.conv2(x2)) x2 = self.depermute(jax.vmap(jax.vmap(self.ls1))(self.permute(x2))) - return self.drop_path1(x, x2, inference=not enable_dropout, key=key) + return self.drop_path1(x, x2, inference=inference, key=key) class SingleConvBlock(eqx.Module): diff --git a/equimo/layers/dropout.py b/equimo/layers/dropout.py index 7e61865..a342305 100644 --- a/equimo/layers/dropout.py +++ b/equimo/layers/dropout.py @@ -1,99 +1,12 @@ -import warnings from typing import Optional import equinox as eqx import jax import jax.lax as lax -import jax.numpy as jnp import jax.random as jrandom 
from jaxtyping import Array, PRNGKeyArray -class Dropout(eqx.Module, strict=True): - """Applies dropout. - - Note that this layer behaves differently during training and inference. During - training then dropout is randomly applied; during inference this layer does nothing. - """ - - # Let's make them static fields, just to avoid possible filtering issues - p: float = eqx.field(static=True) - inference: bool = eqx.field(static=True) - - def __init__( - self, - p: float = 0.5, - inference: bool = False, - *, - deterministic: Optional[bool] = None, - ): - """**Arguments:** - - - `p`: The fraction of entries to set to zero. (On average.) - - `inference`: Whether to actually apply dropout at all. If `True` then dropout - is *not* applied. If `False` then dropout is applied. This may be toggled - with [`equinox.nn.inference_mode`][] or overridden during - [`equinox.nn.Dropout.__call__`][]. - - `deterministic`: Deprecated alternative to `inference`. - """ - - if deterministic is not None: - inference = deterministic - warnings.warn( - "Dropout(deterministic=...) is deprecated " - "in favour of Dropout(inference=...)" - ) - self.p = p - self.inference = inference - - # Backward compatibility - @property - def deterministic(self): - return self.inference - - @jax.named_scope("eqx.nn.Dropout") - def __call__( - self, - x: Array, - *, - key: Optional[PRNGKeyArray] = None, - inference: Optional[bool] = None, - deterministic: Optional[bool] = None, - ) -> Array: - """**Arguments:** - - - `x`: An any-dimensional JAX array to dropout. - - `key`: A `jax.random.PRNGKey` used to provide randomness for calculating - which elements to dropout. (Keyword only argument.) - - `inference`: As per [`equinox.nn.Dropout.__init__`][]. If `True` or - `False` then it will take priority over `self.inference`. If `None` - then the value from `self.inference` will be used. - - `deterministic`: Deprecated alternative to `inference`. - """ - - if deterministic is not None: - inference = deterministic - warnings.warn( - "Dropout()(deterministic=...) is deprecated " - "in favour of Dropout()(inference=...)" - ) - - if inference is None: - inference = self.inference - if isinstance(self.p, (int, float)) and self.p == 0: - inference = True - if inference: - return x - elif key is None: - raise RuntimeError( - "Dropout requires a key when running in non-deterministic mode." - ) - else: - q = 1 - lax.stop_gradient(self.p) - mask = jrandom.bernoulli(key, q, x.shape) - return jnp.where(mask, x / q, 0) - - class DropPath(eqx.Module, strict=True): """Applies drop path (stochastic depth). @@ -101,9 +14,8 @@ class DropPath(eqx.Module, strict=True): training then dropout is randomly applied; during inference this layer does nothing. """ - # Let's make them static fields, just to avoid possible filtering issues - p: float = eqx.field(static=True) - inference: bool = eqx.field(static=True) + p: float + inference: bool def __init__( self, @@ -162,9 +74,8 @@ class DropPathAdd(eqx.Module, strict=True): training then dropout is randomly applied; during inference this layer does nothing. 
""" - # Let's make them static fields, just to avoid possible filtering issues - p: float = eqx.field(static=True) - inference: bool = eqx.field(static=True) + p: float + inference: bool def __init__( self, diff --git a/equimo/layers/ffn.py b/equimo/layers/ffn.py index aa0fd98..e520bfb 100644 --- a/equimo/layers/ffn.py +++ b/equimo/layers/ffn.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, Optional import equinox as eqx import jax @@ -6,8 +6,6 @@ import jax.random as jr from jaxtyping import Array, Float, PRNGKeyArray -from equimo.layers.dropout import Dropout - class WeightNormLinear(eqx.Module): """Linear layer with weight normalization. @@ -112,14 +110,14 @@ def __init__( def __call__( self, x: Float[Array, "seqlen dim"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen dim"]: """Process input through the DINOv2 projection head. Args: x: Input feature tensor - enable_dropout: Whether to enable dropout (unused in original implementation) + inference: Whether to enable dropout (unused in original implementation) key: PRNG key for random operations Returns: @@ -158,8 +156,8 @@ class Mlp(eqx.Module): fc1: eqx.nn.Linear fc2: eqx.nn.Linear norm: eqx.Module - drop1: Dropout - drop2: Dropout + drop1: eqx.nn.Dropout + drop2: eqx.nn.Dropout def __init__( self, @@ -202,25 +200,25 @@ def __init__( hidden_features, out_features, use_bias=bias, key=key_fc2 ) - self.drop1 = Dropout(dropout_rate) - self.drop2 = Dropout(dropout_rate) + self.drop1 = eqx.nn.Dropout(dropout_rate) + self.drop2 = eqx.nn.Dropout(dropout_rate) def __call__( self, x: Float[Array, "seqlen dim"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen dim"]: key_dr1, key_dr2 = jr.split(key, 2) x = self.drop1( jax.vmap(self.norm)(self.act_layer(jax.vmap(self.fc1)(x))), - inference=not enable_dropout, + inference=inference, key=key_dr1, ) x = self.drop2( jax.vmap(self.fc2)(x), - inference=not enable_dropout, + inference=inference, key=key_dr2, ) @@ -253,8 +251,8 @@ class SwiGlu(eqx.Module): w12: eqx.nn.Linear w3: eqx.nn.Linear - drop1: Dropout - drop2: Dropout + drop1: eqx.nn.Dropout + drop2: eqx.nn.Dropout def __init__( self, @@ -290,14 +288,14 @@ def __init__( hidden_features // 2, out_features, use_bias=bias, key=key_fc2 ) - self.drop1 = Dropout(dropout_rate) - self.drop2 = Dropout(dropout_rate) + self.drop1 = eqx.nn.Dropout(dropout_rate) + self.drop2 = eqx.nn.Dropout(dropout_rate) def __call__( self, x: Float[Array, "seqlen dim"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen dim"]: key_dr1, key_dr2 = jr.split(key, 2) @@ -305,13 +303,13 @@ def __call__( x1, x2 = jnp.split(x12, 2, axis=-1) x = self.drop1( jax.nn.silu(x1) * x2, - inference=not enable_dropout, + inference=inference, key=key_dr1, ) x = self.drop2( jax.vmap(self.w3)(x), - inference=not enable_dropout, + inference=inference, key=key_dr2, ) diff --git a/equimo/layers/generic.py b/equimo/layers/generic.py index 5f8e4cd..6837da3 100644 --- a/equimo/layers/generic.py +++ b/equimo/layers/generic.py @@ -1,3 +1,5 @@ +from typing import Optional + import equinox as eqx from jaxtyping import Array, Float, PRNGKeyArray @@ -38,9 +40,9 @@ def __init__( def __call__( self, x: Float[Array, "..."], - enable_dropout: bool, key: PRNGKeyArray, pass_args: bool = False, + inference: Optional[bool] = None, ) -> Float[Array, "..."]: """Forward pass of the residual block. 
@@ -56,13 +58,13 @@ def __call__( with the residual connection through drop path """ if pass_args: - x2 = self.module(x, enable_dropout=enable_dropout, key=key) + x2 = self.module(x, inference=inference, key=key) else: x2 = self.module(x) return self.drop_path( x, x2, - inference=not enable_dropout, + inference=inference, key=key, ) diff --git a/equimo/layers/mamba.py b/equimo/layers/mamba.py index a95a0dc..aa6f628 100644 --- a/equimo/layers/mamba.py +++ b/equimo/layers/mamba.py @@ -124,8 +124,8 @@ def __init__( def __call__( self, x: Float[Array, "seqlen dim"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen dim"]: A = -jnp.exp(self.A_log) zxbcdt = jax.vmap(self.in_proj)(x) diff --git a/equimo/layers/sharing.py b/equimo/layers/sharing.py index a83291d..dbf3dac 100644 --- a/equimo/layers/sharing.py +++ b/equimo/layers/sharing.py @@ -111,15 +111,15 @@ def __call__( self, x: Array, *args, - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, **kwargs, ): if self.repeat == 1: return self.f( x, *args, - enable_dropout=enable_dropout, + inference=inference, key=key, **kwargs, ) @@ -135,7 +135,7 @@ def __call__( lora_x = x lora_output = self.dropouts[i]( jax.vmap(self.loras[i])(lora_x), - inference=not enable_dropout, + inference=inference, key=keys[i], ) if reshape: @@ -150,7 +150,7 @@ def __call__( self.f( x, *args, - enable_dropout=enable_dropout, + inference=inference, key=key, **kwargs, ) diff --git a/equimo/models/fastervit.py b/equimo/models/fastervit.py index 0b3f524..df75a64 100644 --- a/equimo/models/fastervit.py +++ b/equimo/models/fastervit.py @@ -28,8 +28,8 @@ def __call__( x: Array, ct: Array, # *args, - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, **kwargs, ): """Apply layer sharing with carrier token support. 
@@ -37,7 +37,7 @@ def __call__( Args: x: Input tensor ct: carrier token tensor - enable_dropout: Whether to enable dropout during inference + inference: Whether to enable dropout during inference key: PRNG key for random operations Returns: @@ -48,7 +48,7 @@ def __call__( x, ct, # *args, - enable_dropout=enable_dropout, + inference=inference, key=key, **kwargs, ) @@ -64,7 +64,7 @@ def __call__( lora_x = x lora_output = self.dropouts[i]( jax.vmap(self.loras[i])(lora_x), - inference=not enable_dropout, + inference=inference, key=keys[i], ) if reshape: @@ -78,7 +78,7 @@ def __call__( x, ct = self.f( x, ct, - enable_dropout=enable_dropout, + inference=inference, key=key, **kwargs, ) @@ -291,8 +291,8 @@ def window_reverse( def __call__( self, x: Float[Array, "channels height width"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "..."]: keys = jr.split(key, len(self.blocks)) ct = self.global_tokenizer(x) if self.do_gt else None @@ -305,14 +305,14 @@ def __call__( blk, in_axes=(0, None, None, None), out_axes=(0, None), - )(x, ct, enable_dropout, key_block) + )(x, ct, inference=inference, key=key_block) x = self.window_reverse(x, self.window_size, h, w) else: for blk, key_block in zip(self.blocks, keys): - x = blk(x, enable_dropout=enable_dropout, key=key_block) + x = blk(x, inference=inference, key=key_block) if self.downsampler_contains_dropout: - x = self.downsample(x, enable_dropout, key) + x = self.downsample(x, inference=inference, key=key) else: x = self.downsample(x) @@ -440,14 +440,14 @@ def __init__( def features( self, x: Float[Array, "channels height width"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen dim"]: """Extract features from input image. Args: x: Input image tensor - enable_dropout: Whether to enable dropout during inference + inference: Whether to enable dropout during inference key: PRNG key for random operations Returns: @@ -457,7 +457,7 @@ def features( x = self.patch_embed(x) for blk, key_block in zip(self.blocks, block_subkeys): - x = blk(x, enable_dropout=enable_dropout, key=key_block) + x = blk(x, inference=inference, key=key_block) x = rearrange(x, "c h w -> (h w) c") @@ -466,20 +466,20 @@ def features( def __call__( self, x: Float[Array, "channels height width"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "num_classes"]: """Process input image through the full network. 
Args: x: Input image tensor - enable_dropout: Whether to enable dropout during inference + inference: Whether to enable dropout during inference key: PRNG key for random operations Returns: Classification logits for each class """ - x = self.features(x, enable_dropout, key) + x = self.features(x, inference=inference, key=key) x = jax.vmap(self.norm)(x) x = pool_sd( x, diff --git a/equimo/models/mlla.py b/equimo/models/mlla.py index 4b956fa..1f9b1a9 100644 --- a/equimo/models/mlla.py +++ b/equimo/models/mlla.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import equinox as eqx import jax @@ -136,36 +136,36 @@ def __init__( def features( self, x: Float[Array, "..."], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "..."]: key_pd, *keys = jr.split(key, 1 + len(self.blocks)) x = self.patch_embed(x) - x = self.pos_drop(x, inference=not enable_dropout, key=key_pd) + x = self.pos_drop(x, inference=inference, key=key_pd) for i, blk in enumerate(self.blocks): - x = blk(x, enable_dropout=enable_dropout, key=keys[i]) + x = blk(x, inference=inference, key=keys[i]) return x def __call__( self, x: Float[Array, "..."], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "..."]: """Process input through the MLLA model. Args: x: Input tensor (typically an image) - enable_dropout: Whether to enable dropout during inference + inference: Whether to enable dropout during inference key: PRNG key for random operations Returns: Output tensor (class logits if num_classes > 0, otherwise feature representations) """ - x = self.features(x, enable_dropout=enable_dropout, key=key) + x = self.features(x, inference=inference, key=key) x = jax.vmap(self.norm)(x) x = reduce(x, "s d -> d", "mean") x = self.head(x) diff --git a/equimo/models/partialformer.py b/equimo/models/partialformer.py index 8d9f862..1bc5f7d 100644 --- a/equimo/models/partialformer.py +++ b/equimo/models/partialformer.py @@ -9,7 +9,6 @@ from equimo.layers.attention import PartialFormerBlock from equimo.layers.convolution import Stem -from equimo.layers.dropout import Dropout from equimo.layers.ffn import Mlp from equimo.layers.patch import PatchMerging from equimo.layers.posemb import PosCNN @@ -30,15 +29,15 @@ def __call__( x: Array, qa: Array, *args, - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, **kwargs, ): if self.repeat == 1: return self.f( x, *args, - enable_dropout=enable_dropout, + inference=inference, key=key, **kwargs, ) @@ -54,7 +53,7 @@ def __call__( lora_x = x lora_output = self.dropouts[i]( jax.vmap(self.loras[i])(lora_x), - inference=not enable_dropout, + inference=inference, key=keys[i], ) if reshape: @@ -68,7 +67,7 @@ def __call__( x, qa = self.f( x, qa=qa, - enable_dropout=enable_dropout, + inference=inference, key=key, **kwargs, ) @@ -182,8 +181,8 @@ def __call__( x: Float[Array, "seqlen dim"], qa: Float[Array, "1 dim"], *, - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, **kwargs, ) -> Tuple[Float[Array, "..."], Float[Array, "..."]]: """Process input features and query attention token. 
@@ -191,7 +190,7 @@ def __call__( Args: x: Input feature tensor qa: Query attention token - enable_dropout: Whether to enable dropout + inference: Whether to enable dropout key: PRNG key for random operations Returns: @@ -202,17 +201,15 @@ def __call__( x = self.posemb(x) for blk, key_block in zip(self.blocks, keys): - x, qa = blk( - x, qa=qa, enable_dropout=enable_dropout, key=key_block, **kwargs - ) + x, qa = blk(x, qa=qa, inference=inference, key=key_block, **kwargs) if self.downsampler_contains_dropout: - x = self.downsample(x, enable_dropout, key) + x = self.downsample(x, inference=inference, key=key) else: x = self.downsample(x) qa = self.qa_drop( jax.vmap(self.qa_proj)(qa), - inference=not enable_dropout, + inference=inference, key=key_qadrop, ) @@ -249,7 +246,7 @@ class PartialFormer(eqx.Module): qa_token: jnp.ndarray patch_embed: Stem - pos_drop: Dropout + pos_drop: eqx.nn.Dropout blocks: List[eqx.Module] norm: eqx.Module head: eqx.Module @@ -300,7 +297,7 @@ def __init__( key=key_stem, ) - self.pos_drop = Dropout(pos_drop_rate) + self.pos_drop = eqx.nn.Dropout(pos_drop_rate) if drop_path_uniform: dpr = [drop_path_rate] * depth @@ -358,15 +355,15 @@ def __init__( def features( self, x: Float[Array, "channels height width"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, return_qa: bool = False, ) -> Float[Array, "seqlen dim"]: """Extract features from input image using partial attention. Args: x: Input image tensor - enable_dropout: Whether to enable dropout during inference + inference: Whether to enable dropout during inference key: PRNG key for random operations Returns: @@ -374,14 +371,14 @@ def features( """ key_posdrop, *block_subkeys = jr.split(key, len(self.blocks) + 1) x = self.patch_embed(x) - x = self.pos_drop(x, inference=not enable_dropout, key=key_posdrop) + x = self.pos_drop(x, inference=inference, key=key_posdrop) qa = self.qa_token for blk, key_block in zip(self.blocks, block_subkeys): x, qa = blk( x, qa=qa, - enable_dropout=enable_dropout, + inference=inference, key=key_block, ) @@ -392,20 +389,20 @@ def features( def __call__( self, x: Float[Array, "channels height width"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "num_classes"]: """Process input image through the full network. Args: x: Input image tensor - enable_dropout: Whether to enable dropout during inference + inference: Whether to enable dropout during inference key: PRNG key for random operations Returns: Classification logits for each class """ - x = self.features(x, enable_dropout, key) + x = self.features(x, inference=inference, key=key) x = jax.vmap(self.norm)(x) x = reduce(x, "n c -> c", "mean") diff --git a/equimo/models/shvit.py b/equimo/models/shvit.py index 25a6ca8..d0772e9 100644 --- a/equimo/models/shvit.py +++ b/equimo/models/shvit.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import equinox as eqx import jax @@ -101,14 +101,14 @@ def __init__( def __call__( self, x: Float[Array, "seqlen dim"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen dim"]: """Apply downsampling to input features. 
Args: x: Input feature tensor - enable_dropout: Whether to enable dropout during inference + inference: Whether to enable dropout during inference key: PRNG key for random operations Returns: @@ -116,15 +116,15 @@ def __call__( """ key_conv1, key_conv2, key_conv3, key_conv4 = jr.split(key, 4) x = self.conv2( - self.conv1(x, enable_dropout, key_conv1), - enable_dropout, - key_conv2, + self.conv1(x, inference=inference, key=key_conv1), + inference=inference, + key=key_conv2, ) x = self.patch_merging(x) x = self.conv4( - self.conv3(x, enable_dropout, key_conv3), - enable_dropout, - key_conv4, + self.conv3(x, inference=inference, key=key_conv3), + inference=inference, + key=key_conv4, ) return x @@ -207,18 +207,18 @@ def __init__( def __call__( self, x: Float[Array, "seqlen dim"], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen dim"]: key_conv, key_mixer, key_ffn = jr.split(key, 3) return self.ffn( self.mixer( - self.conv(x, enable_dropout, key_conv), - enable_dropout, - key_mixer, + self.conv(x, inference=inference, key=key_conv), + inference=inference, + key=key_mixer, ), - enable_dropout, - key_ffn, + inference=inference, + key=key_ffn, ) @@ -346,34 +346,34 @@ def __init__( def features( self, x: Float[Array, "..."], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "..."]: keys = jr.split(key, len(self.blocks)) x = self.patch_embed(x) for i, blk in enumerate(self.blocks): - x = blk(x, enable_dropout=enable_dropout, key=keys[i]) + x = blk(x, inference=inference, key=keys[i]) return x def __call__( self, x: Float[Array, "..."], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "..."]: """Process input image through the full SHViT network. 
Args: x: Input image tensor - enable_dropout: Whether to enable dropout during inference + inference: Whether to enable dropout during inference key: PRNG key for random operations Returns: Classification logits for each class """ - x = self.features(x, enable_dropout=enable_dropout, key=key) + x = self.features(x, inference=inference, key=key) x = jax.vmap(self.norm)(x) x = reduce(x, "c h w -> c", "mean") x = self.head(x) diff --git a/equimo/models/vit.py b/equimo/models/vit.py index 04a1d7e..2c4d4ef 100644 --- a/equimo/models/vit.py +++ b/equimo/models/vit.py @@ -9,7 +9,6 @@ from jaxtyping import Array, Float, Int, PRNGKeyArray from equimo.layers.attention import Attention, AttentionBlock -from equimo.layers.dropout import Dropout from equimo.layers.ffn import Mlp from equimo.layers.patch import PatchEmbedding from equimo.layers.posemb import PosCNN @@ -96,8 +95,8 @@ def __call__( self, x: Float[Array, "..."], *, - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, **kwargs, ) -> Float[Array, "..."]: keys = jr.split(key, len(self.blocks)) @@ -105,10 +104,10 @@ def __call__( x = self.posemb(x) for blk, key_block in zip(self.blocks, keys): - x = blk(x, enable_dropout=enable_dropout, key=key_block, **kwargs) + x = blk(x, inference=inference, key=key_block, **kwargs) if self.downsampler_contains_dropout: - x = self.downsample(x, enable_dropout, key) + x = self.downsample(x, inference=inference, key=key) else: x = self.downsample(x) @@ -150,7 +149,7 @@ class VisionTransformer(eqx.Module): reg_tokens: jnp.ndarray | None mask_token: jnp.ndarray | None blocks: List[eqx.Module] - pos_drop: Dropout + pos_drop: eqx.nn.Dropout norm: eqx.Module head: eqx.Module @@ -248,7 +247,7 @@ def __init__( self.embed_len = self.num_patches + 1 self.pos_embed = jr.normal(key_posemb, (self.embed_len, dim)) - self.pos_drop = Dropout(pos_drop_rate) + self.pos_drop = eqx.nn.Dropout(pos_drop_rate) if drop_path_uniform: dpr = [drop_path_rate] * depth @@ -395,15 +394,15 @@ def _pos_embed(self, x: Float[Array, "..."], h: int, w: int): def features( self, x: Float[Array, "channels height width"], - enable_dropout: bool, key: PRNGKeyArray, mask: Optional[Int[Array, "embed_h embed_w"]] = None, + inference: Optional[bool] = None, ) -> Float[Array, "seqlen dim"]: """Extract features from input image. Args: x: Input image tensor - enable_dropout: Whether to enable dropout during inference + inference: Whether to enable dropout during inference key: PRNG key for random operations mask: optional binary mask of the size of the input after patch embedding @@ -414,9 +413,9 @@ def features( x = self.patch_embed(x) if mask is not None: - assert ( - self.mask_token is not None - ), "To use masked forward, init the model with `use_mask_token=True`." + assert self.mask_token is not None, ( + "To use masked forward, init the model with `use_mask_token=True`." + ) if self.dynamic_img_size: mask = rearrange(mask, "h w -> 1 h w") value = rearrange(self.mask_token, "1 c -> c 1 1") @@ -429,21 +428,21 @@ def features( x = self._pos_embed(x, h=self.embed_size, w=self.embed_size) for blk, key_block in zip(self.blocks, block_subkeys): - x = blk(x, enable_dropout=enable_dropout, key=key_block) + x = blk(x, inference=inference, key=key_block) return x def forward_features( self, x: Float[Array, "channels height width"], - enable_dropout: bool, + inference: bool, key: PRNGKeyArray, ) -> dict: """Process features and return intermediate representations. 
Args: x: Input image tensor - enable_dropout: Whether to enable dropout during inference + inference: Whether to enable dropout during inference key: PRNG key for random operations Returns: @@ -453,7 +452,7 @@ def forward_features( - x_norm_patchtokens: Normalized patch tokens - x_prenorm: Pre-normalized features """ - x = self.features(x, enable_dropout=enable_dropout, key=key) + x = self.features(x, inference=inference, key=key) x_norm = jax.vmap(self.norm)(x) return { @@ -466,20 +465,20 @@ def forward_features( def __call__( self, x: Float[Array, "channels height width"], - enable_dropout: bool = False, key: PRNGKeyArray = jr.PRNGKey(42), + inference: Optional[bool] = None, ) -> Float[Array, "num_classes"]: """Process input image through the full network. Args: x: Input image tensor - enable_dropout: Whether to enable dropout during inference + inference: Whether to enable dropout during inference key: PRNG key for random operations Returns: Classification logits """ - x = self.features(x, enable_dropout, key) + x = self.features(x, inference=inference, key=key) x = jax.vmap(self.norm)(x) x = pool_sd( x, diff --git a/equimo/models/vssd.py b/equimo/models/vssd.py index 956e676..973bd70 100644 --- a/equimo/models/vssd.py +++ b/equimo/models/vssd.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import equinox as eqx import jax @@ -154,29 +154,29 @@ def __init__( def features( self, x: Float[Array, "..."], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "..."]: key_pd, *keys = jr.split(key, 1 + len(self.blocks)) x = self.patch_embed(x) - x = self.pos_drop(x, inference=not enable_dropout, key=key_pd) + x = self.pos_drop(x, inference=inference, key=key_pd) for i, blk in enumerate(self.blocks): - x = blk(x, enable_dropout=enable_dropout, key=keys[i]) + x = blk(x, inference=inference, key=keys[i]) return x def __call__( self, x: Float[Array, "..."], - enable_dropout: bool, key: PRNGKeyArray, + inference: Optional[bool] = None, ) -> Float[Array, "..."]: """Process input image through the VSSD network. Args: x: Input image tensor - enable_dropout: Whether to enable dropout during inference + inference: Whether to enable dropout during inference key: PRNG key for random operations Returns: @@ -188,7 +188,7 @@ def __call__( 3. Global average pooling 4. 
Classification head """ - x = self.features(x, enable_dropout=enable_dropout, key=key) + x = self.features(x, inference=inference, key=key) x = jax.vmap(self.norm)(x) x = reduce(x, "s d -> d", "mean") x = self.head(x) From d78cb73e59956d35dd8908fcb5161740a29cf4f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20POIRET?= Date: Tue, 11 Feb 2025 16:10:01 +0100 Subject: [PATCH 2/3] feat(test): add simple vit test --- devenv.lock | 30 ++++------------- devenv.nix | 2 ++ equimo/layers/mamba.py | 2 +- equimo/layers/sharing.py | 2 +- equimo/models/partialformer.py | 2 +- pyproject.toml | 5 +++ pytest.ini | 2 ++ tests/test_models.py | 33 ++++++++++++++++++ uv.lock | 61 +++++++++++++++++++++++++++++++++- 9 files changed, 112 insertions(+), 27 deletions(-) create mode 100644 pytest.ini create mode 100644 tests/test_models.py diff --git a/devenv.lock b/devenv.lock index c7e65d0..8650f38 100644 --- a/devenv.lock +++ b/devenv.lock @@ -3,10 +3,10 @@ "devenv": { "locked": { "dir": "src/modules", - "lastModified": 1733788855, + "lastModified": 1739283003, "owner": "cachix", "repo": "devenv", - "rev": "d59fee8696cd48f69cf79f65992269df9891ba86", + "rev": "2921e0f7708a69a6f746db58491dd0f0e35cbc8e", "type": "github" }, "original": { @@ -53,10 +53,10 @@ }, "nixpkgs": { "locked": { - "lastModified": 1733749988, + "lastModified": 1739138025, "owner": "NixOS", "repo": "nixpkgs", - "rev": "bc27f0fde01ce4e1bfec1ab122d72b7380278e68", + "rev": "b2243f41e860ac85c0b446eadc6930359b294e79", "type": "github" }, "original": { @@ -66,35 +66,19 @@ "type": "github" } }, - "nixpkgs-stable": { - "locked": { - "lastModified": 1733730953, - "owner": "NixOS", - "repo": "nixpkgs", - "rev": "7109b680d161993918b0a126f38bc39763e5a709", - "type": "github" - }, - "original": { - "owner": "NixOS", - "ref": "nixos-24.05", - "repo": "nixpkgs", - "type": "github" - } - }, "pre-commit-hooks": { "inputs": { "flake-compat": "flake-compat", "gitignore": "gitignore", "nixpkgs": [ "nixpkgs" - ], - "nixpkgs-stable": "nixpkgs-stable" + ] }, "locked": { - "lastModified": 1733665616, + "lastModified": 1737465171, "owner": "cachix", "repo": "pre-commit-hooks.nix", - "rev": "d8c02f0ffef0ef39f6063731fc539d8c71eb463a", + "rev": "9364dc02281ce2d37a1f55b6e51f7c0f65a75f17", "type": "github" }, "original": { diff --git a/devenv.nix b/devenv.nix index 7af93cc..c8823cd 100644 --- a/devenv.nix +++ b/devenv.nix @@ -22,4 +22,6 @@ in enterShell = '' . 
.devenv/state/venv/bin/activate ''; + + enterTest = "uv run pytest tests/"; } diff --git a/equimo/layers/mamba.py b/equimo/layers/mamba.py index aa6f628..40d959c 100644 --- a/equimo/layers/mamba.py +++ b/equimo/layers/mamba.py @@ -1,5 +1,5 @@ import math -from typing import List, Tuple +from typing import List, Optional, Tuple import equinox as eqx import jax diff --git a/equimo/layers/sharing.py b/equimo/layers/sharing.py index dbf3dac..c38b7f9 100644 --- a/equimo/layers/sharing.py +++ b/equimo/layers/sharing.py @@ -1,5 +1,5 @@ import math -from typing import List +from typing import List, Optional import equinox as eqx import jax diff --git a/equimo/models/partialformer.py b/equimo/models/partialformer.py index 1bc5f7d..fa47985 100644 --- a/equimo/models/partialformer.py +++ b/equimo/models/partialformer.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Tuple +from typing import Callable, List, Optional, Tuple import equinox as eqx import jax diff --git a/pyproject.toml b/pyproject.toml index c0f4667..be554cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,3 +10,8 @@ dependencies = [ "jax>=0.4.25", "jaxlib>=0.4.25", ] + +[dependency-groups] +dev = [ + "pytest>=8.3.4", +] diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..7829689 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +pythonpath = "." diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..d4a444f --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,33 @@ +import equimo.models as em +import jax.numpy as jnp +import jax.random as jr + + +def test_vit_inference(): + key = jr.PRNGKey(42) + img_size = 224 + patch_size = 14 + + x1 = jr.normal(key, (3, 224, 224)) + x2 = jr.normal(key, (3, 98, 98)) + mask = jr.bernoulli(key, shape=(16, 16)) * 1 + + base_model = em.VisionTransformer( + img_size=img_size, + in_channels=3, + dim=384, + patch_size=patch_size, + num_heads=[6], + depths=[12], + num_classes=0, + use_mask_token=True, + dynamic_img_size=True, + key=key, + ) + + # Testing multiple img sizes, inference mode, and masking + f1 = base_model.features(x1, mask=mask, inference=True, key=key) + f2 = base_model.features(x2, inference=False, key=key) + + assert jnp.all(f1) + assert jnp.all(f2) diff --git a/uv.lock b/uv.lock index 62d3fe6..c274f73 100644 --- a/uv.lock +++ b/uv.lock @@ -6,6 +6,15 @@ resolution-markers = [ "python_full_version >= '3.13'", ] +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, +] + [[package]] name = "einops" version = "0.8.0" @@ -17,7 +26,7 @@ wheels = [ [[package]] name = "equimo" -version = "0.1.3a3" +version = "0.1.3a9" source = { virtual = "." 
} dependencies = [ { name = "einops" }, @@ -26,6 +35,11 @@ dependencies = [ { name = "jaxlib" }, ] +[package.dev-dependencies] +dev = [ + { name = "pytest" }, +] + [package.metadata] requires-dist = [ { name = "einops", specifier = ">=0.8.0" }, @@ -34,6 +48,9 @@ requires-dist = [ { name = "jaxlib", specifier = ">=0.4.25" }, ] +[package.metadata.requires-dev] +dev = [{ name = "pytest", specifier = ">=8.3.4" }] + [[package]] name = "equinox" version = "0.11.9" @@ -48,6 +65,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/6a/aece19b15fe69057e87c5c8accf3209d4c51d57d7997bbdbba49ad543f80/equinox-0.11.9-py3-none-any.whl", hash = "sha256:d9257a5f9d923b18e309ba046bffaf92f49c8b03c577d1df8bc446f9a64055f4", size = 179312 }, ] +[[package]] +name = "iniconfig" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, +] + [[package]] name = "jax" version = "0.4.35" @@ -178,6 +204,39 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932 }, ] +[[package]] +name = "packaging" +version = "24.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451 }, +] + +[[package]] +name = "pluggy" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, +] + +[[package]] +name = "pytest" +version = "8.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 }, +] + [[package]] name = "scipy" version = "1.14.1" From 6b75a51a2dd780c0b38a33717151b84403679480 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20POIRET?= Date: Tue, 11 Feb 2025 17:46:49 +0100 Subject: [PATCH 3/3] feat(action): init test action --- .github/workflows/python-tests.yaml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 .github/workflows/python-tests.yaml diff --git a/.github/workflows/python-tests.yaml b/.github/workflows/python-tests.yaml new file mode 100644 index 0000000..b66dab8 --- /dev/null +++ b/.github/workflows/python-tests.yaml @@ -0,0 +1,29 @@ +name: Python tests + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + test: + strategy: + matrix: + os: [ubuntu-latest] + runs-on: ${{ matrix.os }} + + steps: + - uses: actions/checkout@v4 + + - uses: DeterminateSystems/nix-installer-action@main + - uses: DeterminateSystems/magic-nix-cache-action@main + - uses: DeterminateSystems/flake-checker-action@main + + - name: Install devenv.sh + run: nix profile install nixpkgs#devenv + + - name: Run tests + run: devenv test