diff --git a/README.md b/README.md
index fac8a78..e21bb72 100644
--- a/README.md
+++ b/README.md
@@ -41,6 +41,7 @@ Beyond normal ViT (e.g., dinov2 or siglip), equimo proposes other SotA architect
 | SHViT | [SHViT: Single-Head Vision Transformer with Memory Efficient Macro Design](https://arxiv.org/abs/2401.16456) | 2024 | ✅ |
 | VSSD | [VSSD: Vision Mamba with Non-Causal State Space Duality](https://arxiv.org/abs/2407.18559) | 2024 | ✅ |
 | ReduceFormer | [ReduceFormer: Attention with Tensor Reduction by Summation](https://arxiv.org/abs/2406.07488) | 2024 | ✅ |
+| LowFormer | [LowFormer: Hardware Efficient Design for Convolutional Transformer Backbones](https://arxiv.org/abs/2409.03460) | 2024 | ✅ |
 
 \*: Only contains the [Linear Angular Attention](https://github.com/clementpoiret/Equimo/blob/f8fcc79e45ca65e9deb1d970c4286c0b8562f9c2/equimo/layers/attention.py#L1407) module. It is straight forward to build a ViT around it, but may require an additional `__call__` kwarg to control the `sparse_reg` bool.
diff --git a/pyproject.toml b/pyproject.toml
index 84ef7d4..946d955 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "Equimo"
-version = "0.4.0-alpha.1"
+version = "0.4.0-alpha.2"
 description = "Implementation of popular vision models in Jax"
 readme = "README.md"
 requires-python = ">=3.11"
diff --git a/src/equimo/__init__.py b/src/equimo/__init__.py
index 37587dc..2bccb3c 100644
--- a/src/equimo/__init__.py
+++ b/src/equimo/__init__.py
@@ -1 +1 @@
-__version__ = "0.4.0-alpha.1"
+__version__ = "0.4.0-alpha.2"
diff --git a/src/equimo/layers/attention.py b/src/equimo/layers/attention.py
index 9dc2855..ad1c88d 100644
--- a/src/equimo/layers/attention.py
+++ b/src/equimo/layers/attention.py
@@ -1,4 +1,4 @@
-from typing import Callable, List, Optional, Sequence, Tuple
+from typing import Callable, List, Literal, Optional, Sequence, Tuple
 
 import equinox as eqx
 import jax
@@ -7,7 +7,7 @@
 from einops import rearrange, reduce
 from jaxtyping import Array, Float, PRNGKeyArray
 
-from equimo.layers.convolution import SingleConvBlock, MBConv
+from equimo.layers.convolution import ConvBlock, SingleConvBlock, MBConv
 from equimo.layers.dropout import DropPathAdd
 from equimo.layers.ffn import Mlp
 from equimo.layers.mamba import Mamba2Mixer
@@ -1738,6 +1738,316 @@ def __call__(
         return x
 
 
+class ConvAttention(eqx.Module):
+    """Lightweight ConvAttention from LowFormer."""
+
+    num_heads: int = eqx.field(static=True)
+    head_dim: int = eqx.field(static=True)
+    attention_type: Literal["softmax", "sigmoid"] = eqx.field(static=True)
+
+    qkv: eqx.nn.Sequential
+    out_proj: eqx.nn.Conv2d | eqx.nn.Identity
+    upsample: eqx.nn.ConvTranspose2d
+
+    def __init__(
+        self,
+        in_channels: int,
+        *,
+        key: PRNGKeyArray,
+        att_kernel: int = 7,
+        att_stride: int = 4,
+        fuse: bool = True,
+        attention_type: Literal["softmax", "sigmoid"] = "softmax",
+        norm_layer: eqx.Module | None = eqx.nn.GroupNorm,
+        norm_kwargs: dict = {},
+        **kwargs,
+    ):
+        key_qkv1, key_qkv2, key_oproj, key_upsampling = jr.split(key, 4)
+
+        self.attention_type = attention_type
+        self.num_heads = int(max(1, in_channels // 30))
+        self.head_dim = in_channels // self.num_heads
+        total_dim = int(self.head_dim * self.num_heads * 3)
+
+        self.qkv = eqx.nn.Sequential(
+            [
+                SingleConvBlock(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    kernel_size=att_kernel,
+                    stride=att_stride,
+                    padding=att_kernel // 2,
+                    groups=in_channels,
+                    use_bias=False,
+                    norm_layer=norm_layer,
+                    act_layer=None,
+                    key=key_qkv1,
+                ),
+                eqx.nn.Conv2d(
+                    in_channels=in_channels,
+                    out_channels=total_dim,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                    use_bias=False,
+                    key=key_qkv2,
+                ),
+            ]
+        )
+
+        if fuse:
+            self.out_proj = eqx.nn.Identity()
+            self.upsample = eqx.nn.ConvTranspose2d(
+                in_channels=self.head_dim * self.num_heads,
+                out_channels=in_channels,
+                kernel_size=3 if att_stride == 1 else (att_stride * 2),
+                stride=att_stride,
+                padding=1 if att_stride == 1 else (att_stride // 2),
+                key=key_upsampling,
+            )
+        else:
+            self.out_proj = eqx.nn.Conv2d(
+                self.head_dim * self.num_heads,
+                in_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                key=key_oproj,
+            )
+            self.upsample = eqx.nn.ConvTranspose2d(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                kernel_size=3 if att_stride == 1 else (att_stride * 2),
+                stride=att_stride,
+                padding=1 if att_stride == 1 else (att_stride // 2),
+                groups=in_channels,
+                key=key_upsampling,
+            )
+
+    def __call__(
+        self,
+        x: Float[Array, "seqlen height width"],
+        key: PRNGKeyArray,
+        inference: Optional[bool] = None,
+    ) -> Float[Array, "seqlen height width"]:
+        q, k, v = rearrange(
+            self.qkv(x),
+            "(n h d) s1 s2 -> n h (s1 s2) d",
+            n=3,
+            h=self.num_heads,
+            d=self.head_dim,
+        )
+
+        _, s, _ = q.shape
+        h = w = int(s**0.5)
+
+        attn_logits = q @ rearrange(k, "h s d -> h d s")
+        attn_logits /= self.head_dim**0.5
+
+        if self.attention_type == "softmax":
+            attn = jax.nn.softmax(attn_logits, axis=-1)
+        elif self.attention_type == "sigmoid":
+            attn = jax.nn.sigmoid(attn_logits)
+
+        v = attn @ v
+
+        out = self.out_proj(rearrange(v, "h (s1 s2) d -> (h d) s1 s2", s1=h, s2=w))
+        out = self.upsample(out)
+
+        return out
+
+
+class ConvAttentionBlock(eqx.Module):
+    prenorm: eqx.Module
+    postnorm: eqx.Module
+    norm: eqx.Module
+    attn: eqx.Module
+    mlp: eqx.Module
+    drop_path1: DropPathAdd
+    drop_path2: DropPathAdd
+
+    def __init__(
+        self,
+        in_channels: int,
+        *,
+        key: PRNGKeyArray,
+        mlp_ratio: float = 4.0,
+        att_stride: int = 1,
+        attention_type: Literal["softmax", "sigmoid"] = "softmax",
+        fuse: bool = True,  # TODO: verify
+        drop_path: float | List[float] = 0.0,
+        act_layer: Callable = jax.nn.gelu,  # TODO: try hardswish
+        norm_layer: eqx.Module = eqx.nn.GroupNorm,
+        norm_max_group: int = 32,
+        post_attention_norm: bool = False,
+        eps: float = 1e-5,
+        **kwargs,
+    ):
+        key_attn, key_mlp = jr.split(key, 2)
+
+        if isinstance(drop_path, list):
+            if len(drop_path) != 2:
+                raise AssertionError(
+                    f"`drop_path` needs to have 2 elements, got {len(drop_path)} ({drop_path})."
+                )
+            dr1, dr2 = drop_path
+            dr1 = float(dr1)
+            dr2 = float(dr2)
+        else:
+            dr1 = dr2 = float(drop_path)
+
+        num_groups = nearest_power_of_2_divisor(in_channels, norm_max_group)
+        self.prenorm = norm_layer(num_groups, in_channels, eps=eps)
+        self.postnorm = (
+            norm_layer(num_groups, in_channels, eps=eps)
+            if post_attention_norm
+            else eqx.nn.Identity()
+        )
+        self.norm = norm_layer(num_groups, in_channels, eps=eps)
+
+        self.attn = ConvAttention(
+            in_channels=in_channels,
+            att_kernel=5 if att_stride > 1 else 3,
+            att_stride=att_stride,
+            fuse=fuse,
+            attention_type=attention_type,
+            key=key_attn,
+        )
+
+        self.mlp = ConvBlock(
+            dim=in_channels,
+            hidden_dim=int(in_channels * mlp_ratio),
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            drop_path=dr1,
+            act_layer=act_layer,
+            key=key_mlp,
+        )
+
+        self.drop_path1 = DropPathAdd(dr1)
+        self.drop_path2 = DropPathAdd(dr2)
+
+    def __call__(
+        self,
+        x: Float[Array, "seqlen dim"],
+        key: PRNGKeyArray,
+        inference: Optional[bool] = None,
+    ) -> Float[Array, "seqlen dim"]:
+        key_attn, key_mlp, key_dr1, key_dr2 = jr.split(key, 4)
+
+        x = self.drop_path1(
+            x,
+            self.postnorm(
+                self.attn(
+                    self.prenorm(x),
+                    inference=inference,
+                    key=key_attn,
+                )
+            ),
+            inference=inference,
+            key=key_dr1,
+        )
+        x = self.drop_path2(
+            x,
+            self.mlp(
+                self.norm(x),
+                inference=inference,
+                key=key_mlp,
+            ),
+            inference=inference,
+            key=key_dr2,
+        )
+
+        return x
+
+
+class LowFormerBlock(eqx.Module):
+    context_module: ConvAttentionBlock
+    local_module: MBConv
+    drop_path1: DropPathAdd
+    drop_path2: DropPathAdd
+
+    def __init__(
+        self,
+        in_channels: int,
+        *,
+        key,
+        mlp_ratio: float = 4.0,
+        att_stride: int = 1,
+        attention_type: Literal["softmax", "sigmoid"] = "softmax",
+        fuse_conv: bool = True,
+        drop_path: float | List[float] = 0.0,
+        act_layer: Callable = jax.nn.hard_swish,
+        norm_layer: eqx.Module = eqx.nn.GroupNorm,
+        expand_ratio: float = 4.0,
+        mbconv_norm_layers: tuple = (None, None, eqx.nn.GroupNorm),
+        mbconv_act_layers: tuple = (jax.nn.hard_swish, jax.nn.hard_swish, None),
+        fuse_mbconv: bool = False,
+        **kwargs,
+    ):
+        key_context, key_local = jr.split(key, 2)
+
+        if isinstance(drop_path, list):
+            if len(drop_path) != 2:
+                raise AssertionError(
+                    f"`drop_path` needs to have 2 elements, got {len(drop_path)} ({drop_path})."
+                )
+            dr1, dr2 = drop_path
+            dr1 = float(dr1)
+            dr2 = float(dr2)
+        else:
+            dr1 = dr2 = float(drop_path)
+
+        self.context_module = ConvAttentionBlock(
+            in_channels=in_channels,
+            mlp_ratio=mlp_ratio,
+            att_stride=att_stride,
+            attention_type=attention_type,
+            fuse=fuse_conv,
+            drop_path=drop_path,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+            key=key_context,
+        )
+        self.local_module = MBConv(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            expand_ratio=expand_ratio,
+            norm_layers=mbconv_norm_layers,
+            act_layers=mbconv_act_layers,
+            use_bias=(True, True, False),
+            fuse=fuse_mbconv,
+            key=key_local,
+        )
+
+        self.drop_path1 = DropPathAdd(dr1)
+        self.drop_path2 = DropPathAdd(dr2)
+
+    def __call__(
+        self,
+        x: Float[Array, "dim height width"],
+        key: PRNGKeyArray,
+        inference: Optional[bool] = None,
+    ):
+        key_context, key_local, key_dr1, key_dr2 = jr.split(key, 4)
+
+        x = self.drop_path1(
+            x,
+            self.context_module(x, inference=inference, key=key_context),
+            inference=inference,
+            key=key_dr1,
+        )
+        x = self.drop_path2(
+            x,
+            self.local_module(x, inference=inference, key=key_local),
+            inference=inference,
+            key=key_dr2,
+        )
+
+        return x
+
+
 def get_attention(module: str | eqx.Module) -> eqx.Module:
     """Get an `eqx.Module` from its common name.
diff --git a/src/equimo/layers/convolution.py b/src/equimo/layers/convolution.py
index ffac2b7..37de634 100644
--- a/src/equimo/layers/convolution.py
+++ b/src/equimo/layers/convolution.py
@@ -36,7 +36,7 @@ class ConvBlock(eqx.Module):
     norm2: eqx.Module
     drop_path1: DropPathAdd
     act: Callable
-    ls1: LayerScale
+    ls1: LayerScale | None
 
     def __init__(
         self,
@@ -71,7 +71,8 @@ def __init__(
         key_conv1, key_conv2 = jr.split(key, 2)
 
         hidden_dim = hidden_dim or dim
-        num_groups = nearest_power_of_2_divisor(dim, norm_max_group)
+        num_groups1 = nearest_power_of_2_divisor(hidden_dim, norm_max_group)
+        num_groups2 = nearest_power_of_2_divisor(dim, norm_max_group)
         self.conv1 = eqx.nn.Conv(
             num_spatial_dims=2,
             in_channels=dim,
@@ -82,7 +83,7 @@ def __init__(
             use_bias=True,
             key=key_conv1,
         )
-        self.norm1 = eqx.nn.GroupNorm(num_groups, hidden_dim)
+        self.norm1 = eqx.nn.GroupNorm(num_groups1, hidden_dim)
         self.act = act_layer
         self.conv2 = eqx.nn.Conv(
             num_spatial_dims=2,
@@ -94,16 +95,12 @@ def __init__(
             use_bias=True,
             key=key_conv2,
         )
-        self.norm2 = eqx.nn.GroupNorm(num_groups, dim)
+        self.norm2 = eqx.nn.GroupNorm(num_groups2, dim)
 
         dpr = drop_path[0] if isinstance(drop_path, list) else float(drop_path)
         self.drop_path1 = DropPathAdd(dpr)
 
-        self.ls1 = (
-            LayerScale(dim, init_values=init_values)
-            if init_values
-            else eqx.nn.Identity()
-        )
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else None
 
     def permute(
         self, x: Float[Array, "channels height width"]
@@ -122,10 +119,10 @@ def __call__(
         key: PRNGKeyArray,
         inference: Optional[bool] = None,
     ) -> Float[Array, "channels height width"]:
-        _, h, w = x.shape
         x2 = self.act(self.norm1(self.conv1(x)))
         x2 = self.norm2(self.conv2(x2))
-        x2 = self.depermute(jax.vmap(jax.vmap(self.ls1))(self.permute(x2)))
+        if self.ls1 is not None:
+            x2 = self.depermute(jax.vmap(jax.vmap(self.ls1))(self.permute(x2)))
 
         return self.drop_path1(x, x2, inference=inference, key=key)
 
@@ -143,7 +140,7 @@ class SingleConvBlock(eqx.Module):
         act: Activation layer (Lambda or Identity)
     """
 
-    conv: eqx.nn.Conv
+    conv: eqx.nn.Conv2d | eqx.nn.ConvTranspose2d
     norm: eqx.Module
    act: eqx.Module
 
@@ -159,6 +156,7 @@ def __init__(
         norm_layer: eqx.Module | None = eqx.nn.GroupNorm,
         norm_max_group: int = 32,
         act_layer: Callable | None = None,
+        transposed: bool = False,
         norm_kwargs: dict = {},
         **kwargs,
     ):
@@ -175,7 +173,8 @@ def __init__(
             **kwargs: Additional arguments passed to Conv layer
         """
 
-        self.conv = eqx.nn.Conv2d(
+        conv = eqx.nn.ConvTranspose2d if transposed else eqx.nn.Conv2d
+        self.conv = conv(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
diff --git a/src/equimo/models/lowformer.py b/src/equimo/models/lowformer.py
new file mode 100644
index 0000000..e0e920a
--- /dev/null
+++ b/src/equimo/models/lowformer.py
@@ -0,0 +1,387 @@
+from typing import Callable, Literal, Optional, Tuple
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+from einops import reduce
+from jaxtyping import Array, Float, PRNGKeyArray
+
+from equimo.layers.activation import get_act
+from equimo.layers.attention import LowFormerBlock
+from equimo.layers.convolution import DSConv, MBConv, SingleConvBlock
+from equimo.layers.norm import get_norm
+
+
+class BlockChunk(eqx.Module):
+    residuals: list[bool] = eqx.field(static=True)
+    blocks: list[DSConv | MBConv | LowFormerBlock]
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        depth: int,
+        *,
+        key: PRNGKeyArray,
+        block_type: Literal["conv", "attention"] = "conv",
+        mlp_ratio: float = 4.0,
+        att_stride: int = 1,
+        attention_type: Literal["softmax", "sigmoid"] = "softmax",
+        fuse_conv: bool = True,
+        stride: int = 1,
+        expand_ratio: float = 4.0,
+        attention_expand_ratio: float = 4.0,
+        norm_layer: eqx.Module = eqx.nn.GroupNorm,
+        act_layer: Callable = jax.nn.hard_swish,
+        fewer_norm: bool = False,
+        fuse_mbconv: bool = False,
+        drop_path: list[float] = [0.0],
+        **kwargs,
+    ):
+        key, *block_subkeys = jr.split(key, depth + 1)
+
+        keys_to_spread = [
+            k for k, v in kwargs.items() if isinstance(v, list) and len(v) == depth
+        ]
+
+        blocks = []
+        residuals = []
+
+        # TODO: simplify logic
+        match block_type:
+            case "conv":
+                block = DSConv if expand_ratio == 1.0 else MBConv
+                if fewer_norm:
+                    use_bias: Tuple[bool, ...] | bool = (
+                        (True, False) if block == DSConv else (True, True, False)
+                    )
+                    norm_layer = (
+                        (None, norm_layer)
+                        if block == DSConv
+                        else (None, None, norm_layer)
+                    )
+                else:
+                    use_bias = False
+
+                for i in range(depth):
+                    config = kwargs | {k: kwargs[k][i] for k in keys_to_spread}
+
+                    if block == MBConv:
+                        config["expand_ratio"] = expand_ratio
+                        config["fuse"] = fuse_mbconv
+
+                    blocks.append(
+                        block(
+                            in_channels=in_channels if i == 0 else out_channels,
+                            out_channels=out_channels,
+                            stride=stride if i == 0 else 1,
+                            use_bias=use_bias,
+                            norm_layers=norm_layer,
+                            act_layers=(act_layer, None)
+                            if block == DSConv
+                            else (act_layer, act_layer, None),
+                            **config,
+                            key=block_subkeys[i],
+                        )
+                    )
+                    residuals.append(
+                        (in_channels == out_channels and stride == 1) or i > 0
+                    )
+
+            case "attention":
+                blocks.append(
+                    MBConv(
+                        in_channels,
+                        out_channels,
+                        stride=2,  # TODO: make downsampling optional
+                        expand_ratio=attention_expand_ratio,
+                        norm_layers=(None, None, norm_layer),
+                        act_layers=(act_layer, act_layer, None),
+                        use_bias=(True, True, False),
+                        fuse=fuse_mbconv,
+                        key=key,
+                    )
+                )
+                for i in range(depth):
+                    blocks.append(
+                        LowFormerBlock(
+                            in_channels=out_channels,
+                            mlp_ratio=mlp_ratio,
+                            att_stride=att_stride,
+                            attention_type=attention_type,
+                            fuse_conv=fuse_conv,
+                            drop_path=drop_path[i],
+                            act_layer=act_layer,
+                            norm_layer=norm_layer,
+                            expand_ratio=expand_ratio,
+                            mbconv_norm_layers=(None, None, norm_layer),
+                            mbconv_act_layers=(act_layer, act_layer, None),
+                            fuse_mbconv=fuse_mbconv,
+                            key=block_subkeys[i],
+                        )
+                    )
+                    residuals.append(False)
+
+        self.blocks = blocks
+        self.residuals = residuals
+
+    def __call__(
+        self,
+        x: Float[Array, "..."],
+        *,
+        key: PRNGKeyArray,
+        inference: Optional[bool] = None,
+        **kwargs,
+    ) -> Float[Array, "..."]:
+        keys = jr.split(key, len(self.blocks))
+
+        # TODO: Dropout and Stochastic Path Add
+        for blk, residual, key_block in zip(self.blocks, self.residuals, keys):
+            res = blk(x, inference=inference, key=key_block, **kwargs)
+            x = x + res if residual else res
+
+        return x
+
+
+class LowFormer(eqx.Module):
+    input_stem: eqx.nn.Sequential
+    blocks: list[BlockChunk]
+    head: eqx.nn.Linear | eqx.nn.Identity
+
+    def __init__(
+        self,
+        in_channels: int,
+        widths: list[int],
+        depths: list[int],
+        att_strides: list[int],
+        block_types: list[Literal["conv", "attention"]],
+        *,
+        key: PRNGKeyArray,
+        mlp_ratio: float = 4.0,
+        attention_type: Literal["softmax", "sigmoid"],
+        stem_expand_ratio: float = 2.0,
+        blocks_expand_ratio: float = 4.0,
+        blocks_attention_expand_ratio: float = 4.0,
+        norm_layer: eqx.Module | str = eqx.nn.GroupNorm,
+        act_layer: Callable | str = jax.nn.hard_swish,
+        fuse_mbconv: bool = False,
+        num_classes: int | None = 1000,
+        drop_path_rate: float = 0.0,
+        drop_path_uniform: bool = False,
+        **kwargs,
+    ):
+        if not len(widths) == len(depths) == len(att_strides) == len(block_types):
+            raise ValueError(
+                "`widths`, `depths`, `att_strides`, and `block_types` must have the same lengths."
+            )
+
+        key_stem, key_head, *key_blocks = jr.split(key, 3 + len(depths))
+
+        depth = sum(depths)
+        act_layer = get_act(act_layer)
+        norm_layer = get_norm(norm_layer)
+
+        width_stem = widths.pop(0)
+        depth_stem = depths.pop(0)
+        block_type_stem = block_types.pop(0)
+        key_block_stem = key_blocks.pop(0)
+
+        if drop_path_uniform:
+            dpr = [drop_path_rate] * depth
+        else:
+            dpr = list(jnp.linspace(0.0, drop_path_rate, depth))
+
+        self.input_stem = eqx.nn.Sequential(
+            [
+                SingleConvBlock(
+                    in_channels=in_channels,
+                    out_channels=width_stem,
+                    kernel_size=3,
+                    stride=2,
+                    padding="SAME",
+                    use_bias=False,
+                    norm_layer=norm_layer,
+                    act_layer=act_layer,
+                    key=key_stem,
+                ),
+                BlockChunk(
+                    in_channels=width_stem,
+                    out_channels=width_stem,
+                    depth=depth_stem,
+                    block_type=block_type_stem,
+                    stride=1,
+                    expand_ratio=stem_expand_ratio,
+                    fuse_mbconv=fuse_mbconv,
+                    norm_layer=norm_layer,
+                    act_layer=act_layer,
+                    key=key_block_stem,
+                ),
+            ]
+        )
+
+        self.blocks = [
+            BlockChunk(
+                in_channels=widths[i - 1] if i > 0 else width_stem,
+                out_channels=widths[i],
+                depth=depth,
+                block_type=block_type,
+                mlp_ratio=mlp_ratio,
+                stride=2,
+                att_stride=att_stride,
+                attention_type=attention_type,
+                expand_ratio=blocks_expand_ratio,
+                attention_expand_ratio=blocks_attention_expand_ratio,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                fuse_mbconv=fuse_mbconv,
+                drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
+                key=key_block,
+            )
+            for i, (depth, att_stride, block_type, key_block) in enumerate(
+                zip(depths, att_strides, block_types, key_blocks)
+            )
+        ]
+
+        self.head = (
+            eqx.nn.Linear(
+                in_features=widths[-1], out_features=num_classes, key=key_head
+            )
+            if num_classes and num_classes > 0
+            else eqx.nn.Identity()
+        )
+
+    def features(
+        self,
+        x: Float[Array, "channels height width"],
+        key: PRNGKeyArray,
+        inference: Optional[bool] = None,
+        **kwargs,
+    ) -> Float[Array, "seqlen dim"]:
+        """Extract features from input image.
+
+        Args:
+            x: Input image tensor
+            inference: Whether to enable dropout during inference
+            key: PRNG key for random operations
+
+        Returns:
+            Processed feature tensor
+        """
+        key_stem, *key_blocks = jr.split(key, len(self.blocks) + 1)
+
+        x = self.input_stem(x, key=key_stem)
+
+        for i, blk in enumerate(self.blocks):
+            x = blk(x, inference=inference, key=key_blocks[i])
+
+        return x
+
+    def __call__(
+        self,
+        x: Float[Array, "channels height width"],
+        key: PRNGKeyArray = jr.PRNGKey(42),
+        inference: Optional[bool] = None,
+        **kwargs,
+    ) -> Float[Array, "num_classes"]:
+        """Process input image through the full network.
+
+        Args:
+            x: Input image tensor
+            inference: Whether to enable dropout during inference
+            key: PRNG key for random operations
+
+        Returns:
+            Classification logits
+        """
+        x = self.features(x, inference=inference, key=key, **kwargs)
+
+        x = reduce(x, "c h w -> c", "mean")
+
+        x = self.head(x)
+
+        return x
+
+
+def lowformer_backbone_b0(**kwargs) -> LowFormer:
+    backbone = LowFormer(
+        widths=[8, 16, 32, 64, 128],
+        depths=[0, 1, 1, 3, 4],
+        block_types=[
+            "conv",
+            "conv",
+            "conv",
+            "attention",
+            "attention",
+        ],
+        att_strides=[2, 2, 2, 2, 1],
+        stem_expand_ratio=2.0,
+        blocks_expand_ratio=4.0,
+        blocks_attention_expand_ratio=4.0,
+        fuse_mbconv=True,
+        **kwargs,
+    )
+    return backbone
+
+
+def lowformer_backbone_b1(**kwargs) -> LowFormer:
+    backbone = LowFormer(
+        widths=[16, 32, 64, 128, 256],
+        depths=[1, 2, 3, 3, 4],
+        block_types=[
+            "conv",
+            "conv",
+            "conv",
+            "attention",
+            "attention",
+        ],
+        att_strides=[2, 2, 2, 2, 1],
+        stem_expand_ratio=2.0,
+        blocks_expand_ratio=4.0,
+        blocks_attention_expand_ratio=4.0,
+        fuse_mbconv=True,
+        **kwargs,
+    )
+    return backbone
+
+
+def lowformer_backbone_b2(**kwargs) -> LowFormer:
+    backbone = LowFormer(
+        widths=[24, 48, 96, 192, 384],
+        depths=[1, 3, 4, 4, 6],
+        block_types=[
+            "conv",
+            "conv",
+            "conv",
+            "attention",
+            "attention",
+        ],
+        att_strides=[2, 2, 2, 2, 1],
+        stem_expand_ratio=4.0,
+        blocks_expand_ratio=4.0,
+        blocks_attention_expand_ratio=6.0,
+        fuse_mbconv=True,
+        **kwargs,
+    )
+    return backbone
+
+
+def lowformer_backbone_b3(**kwargs) -> LowFormer:
+    backbone = LowFormer(
+        widths=[32, 64, 128, 256, 512],
+        depths=[1, 4, 6, 6, 9],
+        block_types=[
+            "conv",
+            "conv",
+            "conv",
+            "attention",
+            "attention",
+        ],
+        att_strides=[2, 2, 2, 2, 1],
+        stem_expand_ratio=4.0,
+        blocks_expand_ratio=6.0,
+        blocks_attention_expand_ratio=6.0,
+        fuse_mbconv=True,
+        **kwargs,
+    )
+    return backbone
diff --git a/uv.lock b/uv.lock
index a487c8b..9fb0caa 100644
--- a/uv.lock
+++ b/uv.lock
@@ -166,7 +166,7 @@ wheels = [
 
 [[package]]
 name = "equimo"
-version = "0.3.4"
+version = "0.4.0a2"
 source = { virtual = "." }
 dependencies = [
     { name = "einops" },
@@ -175,7 +175,6 @@ dependencies = [
     { name = "equinox" },
     { name = "jaxlib" },
     { name = "loguru" },
     { name = "lz4" },
-    { name = "pytest" },
     { name = "requests" },
     { name = "semver" },
 ]
@@ -210,7 +209,6 @@ requires-dist = [
     { name = "lz4", specifier = ">=4.4.3" },
     { name = "matplotlib", marker = "extra == 'extras'", specifier = ">=3.10.1" },
     { name = "pillow", marker = "extra == 'extras'", specifier = ">=11.1.0" },
-    { name = "pytest", specifier = ">=8.3.5" },
     { name = "requests", specifier = ">=2.32.3" },
     { name = "scikit-learn", marker = "extra == 'extras'", specifier = ">=1.6.1" },
     { name = "semver", specifier = ">=3.0.4" },