diff --git a/equimo/layers/attention.py b/equimo/layers/attention.py
index 03c721b..786da34 100644
--- a/equimo/layers/attention.py
+++ b/equimo/layers/attention.py
@@ -1552,3 +1552,53 @@ def __call__(
         x = self.proj_drop(x, inference=inference, key=key2)
 
         return x
+
+
+def get_attention(module: str | eqx.Module) -> eqx.Module:
+    """Get an `eqx.Module` from its common name.
+
+    This is necessary because configs have to be stringified and stored as
+    json files to allow (de)serialization.
+    """
+    if not isinstance(module, str):
+        return module
+
+    match module:
+        case "attention":
+            return Attention
+        case "windowedattention":
+            return WindowedAttention
+        case "shsa":
+            return SHSA
+        case "linearattention":
+            return LinearAttention
+        case "mmsa":
+            return MMSA
+        case "sqa":
+            return SQA
+        case "linearangularattention":
+            return LinearAngularAttention
+        case _:
+            raise ValueError(f"Got an unknown module string: {module}")
+
+
+def get_attention_block(module: str | eqx.Module) -> eqx.Module:
+    """Get an `eqx.Module` from its common name.
+
+    This is necessary because configs have to be stringified and stored as
+    json files to allow (de)serialization.
+    """
+    if not isinstance(module, str):
+        return module
+
+    match module:
+        case "attentionblock":
+            return AttentionBlock
+        case "hatblock":
+            return HATBlock
+        case "mllablock":
+            return MllaBlock
+        case "partialformerblock":
+            return PartialFormerBlock
+        case _:
+            raise ValueError(f"Got an unknown module string: {module}")
diff --git a/equimo/layers/ffn.py b/equimo/layers/ffn.py
index 852b757..a7f6f97 100644
--- a/equimo/layers/ffn.py
+++ b/equimo/layers/ffn.py
@@ -285,10 +285,10 @@ def __init__(
         out_features = out_features or in_features
 
         self.w12 = eqx.nn.Linear(
-            in_features, hidden_features, use_bias=bias, key=key_fc1
+            in_features, 2 * hidden_features, use_bias=bias, key=key_fc1
         )
         self.w3 = eqx.nn.Linear(
-            hidden_features // 2, out_features, use_bias=bias, key=key_fc2
+            hidden_features, out_features, use_bias=bias, key=key_fc2
         )
 
         self.drop1 = eqx.nn.Dropout(dropout_rate)
@@ -317,3 +317,57 @@ def __call__(
         )
 
         return x
+
+
+class SwiGluFused(SwiGlu):
+    def __init__(
+        self,
+        in_features: int,
+        *,
+        key: PRNGKeyArray,
+        out_features: int | None = None,
+        hidden_features: int | None = None,
+        dropout_rate: float = 0,
+        bias: bool = True,
+        **kwargs,
+    ):
+        """This matches the implementation of Dinov2 giant at
+        https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/layers/swiglu_ffn.py#L54
+        """
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+
+        super().__init__(
+            in_features,
+            key=key,
+            out_features=out_features,
+            hidden_features=hidden_features,
+            dropout_rate=dropout_rate,
+            bias=bias,
+            **kwargs,
+        )
+
+
+def get_ffn(module: str | eqx.Module) -> eqx.Module:
+    """Get an `eqx.Module` from its common name.
+
+    This is necessary because configs have to be stringified and stored as
+    json files to allow (de)serialization.
+    """
+    if not isinstance(module, str):
+        return module
+
+    match module:
+        case "mlp":
+            return Mlp
+        case "swiglu":
+            return SwiGlu
+        case "swiglufused":
+            return SwiGluFused
+        case "dinohead":
+            return DINOHead
+        case "weightnormlinear":
+            return WeightNormLinear
+        case _:
+            raise ValueError(f"Got an unknown module string: {module}")
diff --git a/equimo/layers/norm.py b/equimo/layers/norm.py
index 85c257b..837c715 100644
--- a/equimo/layers/norm.py
+++ b/equimo/layers/norm.py
@@ -85,3 +85,27 @@ def __call__(self, x: Float[Array, "dim"]):
             Scaled tensor of same shape as input
         """
         return x * self.gamma
+
+
+def get_norm(module: str | eqx.Module) -> eqx.Module:
+    """Get an `eqx.Module` from its common name.
+
+    This is necessary because configs have to be stringified and stored as
+    json files to allow (de)serialization.
+    """
+    if not isinstance(module, str):
+        return module
+
+    match module:
+        case "layernorm":
+            return eqx.nn.LayerNorm
+        case "rmsnorm":
+            return eqx.nn.RMSNorm
+        case "groupnorm":
+            return eqx.nn.GroupNorm
+        case "rmsnormgated":
+            return RMSNormGated
+        case "layerscale":
+            return LayerScale
+        case _:
+            raise ValueError(f"Got an unknown module string: {module}")
diff --git a/equimo/models/vit.py b/equimo/models/vit.py
index 744e0a3..8080a23 100644
--- a/equimo/models/vit.py
+++ b/equimo/models/vit.py
@@ -8,8 +8,14 @@
 from einops import rearrange
 from jaxtyping import Array, Float, Int, PRNGKeyArray
 
-from equimo.layers.attention import Attention, AttentionBlock
-from equimo.layers.ffn import Mlp
+from equimo.layers.attention import (
+    Attention,
+    AttentionBlock,
+    get_attention,
+    get_attention_block,
+)
+from equimo.layers.ffn import Mlp, get_ffn
+from equimo.layers.norm import get_norm
 from equimo.layers.patch import PatchEmbedding
 from equimo.layers.posemb import PosCNN
 from equimo.utils import pool_sd, to_list
@@ -177,7 +183,7 @@ def __init__(
         pos_drop_rate: float = 0.0,
         drop_path_rate: float = 0.0,
         drop_path_uniform: bool = False,
-        block: eqx.Module = AttentionBlock,
+        block: str | eqx.Module = AttentionBlock,
         mlp_ratio: float = 4.0,
         qkv_bias: bool = True,
         proj_bias: bool = True,
@@ -185,10 +191,10 @@ def __init__(
         attn_drop: float = 0.0,
         proj_drop: float = 0.0,
         act_layer: Callable = jax.nn.gelu,
-        attn_layer: eqx.Module = Attention,
-        ffn_layer: eqx.Module = Mlp,
+        attn_layer: str | eqx.Module = Attention,
+        ffn_layer: str | eqx.Module = Mlp,
         ffn_bias: bool = True,
-        norm_layer: eqx.Module = eqx.nn.LayerNorm,
+        norm_layer: str | eqx.Module = eqx.nn.LayerNorm,
         init_values: float | None = None,
         global_pool: Literal["", "token", "avg", "avgmax", "max"] = "avg",
         num_classes: int = 1000,
@@ -212,6 +218,11 @@ def __init__(
         self.global_pool = global_pool
         self.embed_size = img_size // patch_size
 
+        block = get_attention_block(block)
+        attn_layer = get_attention(attn_layer)
+        ffn_layer = get_ffn(ffn_layer)
+        norm_layer = get_norm(norm_layer)
+
         self.patch_embed = PatchEmbedding(
             in_channels,
             dim,
diff --git a/equimo/utils.py b/equimo/utils.py
index f5aebc0..e20bc36 100644
--- a/equimo/utils.py
+++ b/equimo/utils.py
@@ -1,4 +1,5 @@
 from functools import partial
+
 import jax
 import jax.numpy as jnp
 from jaxtyping import Array, Float