equimo/layers/attention.py (+50 -0)
@@ -1552,3 +1552,53 @@ def __call__(
         x = self.proj_drop(x, inference=inference, key=key2)

         return x
+
+
+def get_attention(module: str | eqx.Module) -> eqx.Module:
+    """Get an `eqx.Module` class from its common name.
+
+    This is necessary because configs have to be stringified and stored as
+    JSON files to allow (de)serialization.
+    """
+    if not isinstance(module, str):
+        return module
+
+    match module:
+        case "attention":
+            return Attention
+        case "windowedattention":
+            return WindowedAttention
+        case "shsa":
+            return SHSA
+        case "linearattention":
+            return LinearAttention
+        case "mmsa":
+            return MMSA
+        case "sqa":
+            return SQA
+        case "linearangularattention":
+            return LinearAngularAttention
+        case _:
+            raise ValueError(f"Got an unknown module string: {module}")
+
+
+def get_attention_block(module: str | eqx.Module) -> eqx.Module:
+    """Get an `eqx.Module` class from its common name.
+
+    This is necessary because configs have to be stringified and stored as
+    JSON files to allow (de)serialization.
+    """
+    if not isinstance(module, str):
+        return module
+
+    match module:
+        case "attentionblock":
+            return AttentionBlock
+        case "hatblock":
+            return HATBlock
+        case "mllablock":
+            return MllaBlock
+        case "partialformerblock":
+            return PartialFormerBlock
+        case _:
+            raise ValueError(f"Got an unknown module string: {module}")
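For illustration only (not part of this diff): a minimal sketch of how these lookups let a config round-trip through JSON, with layer choices stored as plain strings. The config keys used here are assumptions.

    import json

    from equimo.layers.attention import Attention, get_attention, get_attention_block

    # Layer choices serialize as plain strings...
    config = {"block": "attentionblock", "attn_layer": "attention"}
    restored = json.loads(json.dumps(config))

    # ...and resolve back to classes at model-construction time.
    block_cls = get_attention_block(restored["block"])  # AttentionBlock
    attn_cls = get_attention(restored["attn_layer"])  # Attention

    # Passing an already-resolved module through is a no-op.
    assert get_attention(Attention) is Attention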
equimo/layers/ffn.py (+56 -2)
@@ -285,10 +285,10 @@ def __init__(
         out_features = out_features or in_features

         self.w12 = eqx.nn.Linear(
-            in_features, hidden_features, use_bias=bias, key=key_fc1
+            in_features, 2 * hidden_features, use_bias=bias, key=key_fc1
         )
         self.w3 = eqx.nn.Linear(
-            hidden_features // 2, out_features, use_bias=bias, key=key_fc2
+            hidden_features, out_features, use_bias=bias, key=key_fc2
         )

         self.drop1 = eqx.nn.Dropout(dropout_rate)
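For context (an illustration, not part of the diff): these shapes imply a fused-projection SwiGLU, where `w12` produces both the gate and value halves in one matmul and `w3` projects the full hidden width back out. A minimal sketch of the corresponding forward computation, assuming this split; the actual `__call__` is not shown in this hunk.

    import jax
    import jax.numpy as jnp

    def swiglu_forward(w12, w3, x):
        # Single fused projection yields gate and value halves together.
        x1, x2 = jnp.split(w12(x), 2, axis=-1)
        # SwiGLU: silu(gate) * value, then project back to out_features.
        return w3(jax.nn.silu(x1) * x2)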
@@ -317,3 +317,57 @@ def __call__(
         )

         return x
+
+
+class SwiGluFused(SwiGlu):
+    def __init__(
+        self,
+        in_features: int,
+        *,
+        key: PRNGKeyArray,
+        out_features: int | None = None,
+        hidden_features: int | None = None,
+        dropout_rate: float = 0,
+        bias: bool = True,
+        **kwargs,
+    ):
+        """This matches the implementation of DINOv2 giant at
+        https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/layers/swiglu_ffn.py#L54
+        """
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+
+        super().__init__(
+            in_features,
+            key=key,
+            out_features=out_features,
+            hidden_features=hidden_features,
+            dropout_rate=dropout_rate,
+            bias=bias,
+            **kwargs,
+        )
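As a worked example of the rounding rule (the input width is an assumption, for illustration only): requesting hidden_features=4096 yields two-thirds of that width, rounded up to the nearest multiple of 8.

    h = 4096
    h = (int(h * 2 / 3) + 7) // 8 * 8  # int(2730.67) = 2730, rounded up to 2736
    assert h == 2736 and h % 8 == 0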
+
+
+def get_ffn(module: str | eqx.Module) -> eqx.Module:
+    """Get an `eqx.Module` class from its common name.
+
+    This is necessary because configs have to be stringified and stored as
+    JSON files to allow (de)serialization.
+    """
+    if not isinstance(module, str):
+        return module
+
+    match module:
+        case "mlp":
+            return Mlp
+        case "swiglu":
+            return SwiGlu
+        case "swiglufused":
+            return SwiGluFused
+        case "dinohead":
+            return DINOHead
+        case "weightnormlinear":
+            return WeightNormLinear
+        case _:
+            raise ValueError(f"Got an unknown module string: {module}")
equimo/layers/norm.py (+24 -0)
@@ -85,3 +85,27 @@ def __call__(self, x: Float[Array, "dim"]):
             Scaled tensor of same shape as input
         """
         return x * self.gamma
+
+
+def get_norm(module: str | eqx.Module) -> eqx.Module:
+    """Get an `eqx.Module` class from its common name.
+
+    This is necessary because configs have to be stringified and stored as
+    JSON files to allow (de)serialization.
+    """
+    if not isinstance(module, str):
+        return module
+
+    match module:
+        case "layernorm":
+            return eqx.nn.LayerNorm
+        case "rmsnorm":
+            return eqx.nn.RMSNorm
+        case "groupnorm":
+            return eqx.nn.GroupNorm
+        case "rmsnormgated":
+            return RMSNormGated
+        case "layerscale":
+            return LayerScale
+        case _:
+            raise ValueError(f"Got an unknown module string: {module}")
equimo/models/vit.py (+17 -6)
@@ -8,8 +8,14 @@
 from einops import rearrange
 from jaxtyping import Array, Float, Int, PRNGKeyArray

-from equimo.layers.attention import Attention, AttentionBlock
-from equimo.layers.ffn import Mlp
+from equimo.layers.attention import (
+    Attention,
+    AttentionBlock,
+    get_attention,
+    get_attention_block,
+)
+from equimo.layers.ffn import Mlp, get_ffn
+from equimo.layers.norm import get_norm
 from equimo.layers.patch import PatchEmbedding
 from equimo.layers.posemb import PosCNN
 from equimo.utils import pool_sd, to_list
@@ -177,18 +183,18 @@ def __init__(
         pos_drop_rate: float = 0.0,
         drop_path_rate: float = 0.0,
         drop_path_uniform: bool = False,
-        block: eqx.Module = AttentionBlock,
+        block: str | eqx.Module = AttentionBlock,
         mlp_ratio: float = 4.0,
         qkv_bias: bool = True,
         proj_bias: bool = True,
         qk_norm: bool = False,
         attn_drop: float = 0.0,
         proj_drop: float = 0.0,
         act_layer: Callable = jax.nn.gelu,
-        attn_layer: eqx.Module = Attention,
-        ffn_layer: eqx.Module = Mlp,
+        attn_layer: str | eqx.Module = Attention,
+        ffn_layer: str | eqx.Module = Mlp,
         ffn_bias: bool = True,
-        norm_layer: eqx.Module = eqx.nn.LayerNorm,
+        norm_layer: str | eqx.Module = eqx.nn.LayerNorm,
         init_values: float | None = None,
         global_pool: Literal["", "token", "avg", "avgmax", "max"] = "avg",
         num_classes: int = 1000,
@@ -212,6 +218,11 @@ def __init__(
         self.global_pool = global_pool
         self.embed_size = img_size // patch_size

+        block = get_attention_block(block)
+        attn_layer = get_attention(attn_layer)
+        ffn_layer = get_ffn(ffn_layer)
+        norm_layer = get_norm(norm_layer)
+
         self.patch_embed = PatchEmbedding(
             in_channels,
             dim,
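Taken together (a hedged sketch, not part of the diff): a model can now be built directly from a JSON-style config. Constructor arguments other than the four resolved layer choices are assumptions based on the hunks above.

    import jax.random as jr

    from equimo.models.vit import VisionTransformer

    # Hypothetical JSON-loaded config; the string values are resolved by the
    # new get_* lookups inside __init__.
    config = {
        "block": "attentionblock",
        "attn_layer": "attention",
        "ffn_layer": "swiglufused",
        "norm_layer": "layernorm",
    }

    model = VisionTransformer(
        img_size=224,  # assumed arguments, mirroring the signature above
        patch_size=16,
        in_channels=3,
        dim=384,
        **config,
        key=jr.PRNGKey(0),
    )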
equimo/utils.py (+1 -0)
@@ -1,4 +1,5 @@
+from functools import partial

 import jax
 import jax.numpy as jnp
 from jaxtyping import Array, Float