52 changes: 47 additions & 5 deletions equimo/layers/norm.py
@@ -58,12 +58,10 @@ class LayerScale(eqx.Module):
dampening the contribution of each layer.

Attributes:
init_values: Initial scale value (static)
gamma: Learnable scale parameters of size dim
"""

init_values: float = eqx.field(static=True)
gamma: Optional[Float[Array, "dim"]]
gamma: Float[Array, "dim"]

def __init__(self, dim: int, init_values: float):
"""Initialize LayerScale.
@@ -72,8 +70,7 @@ def __init__(self, dim: int, init_values: float):
dim: Dimension of the input features
init_values: Initial value for all scaling factors
"""
self.init_values = init_values
self.gamma = jnp.repeat(self.init_values, dim)
self.gamma = jnp.repeat(init_values, dim)

def __call__(self, x: Float[Array, "dim"]):
"""Apply layer scaling to input tensor.
@@ -87,6 +84,49 @@ def __call__(self, x: Float[Array, "dim"]):
return x * self.gamma
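
For orientation, a minimal usage sketch of the simplified LayerScale (the import path is assumed from the file touched in this diff):

```python
import jax.numpy as jnp

from equimo.layers.norm import LayerScale  # assumed import path

ls = LayerScale(dim=4, init_values=1e-5)
x = jnp.array([1.0, -2.0, 0.5, 3.0])
y = ls(x)  # elementwise x * gamma, with gamma initialized to 1e-5
```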


class DyT(eqx.Module):
"""Dynamic Tanh layear.

This layer implements the DyT layer introduced in the Transformers
without Normalization paper[1].

Attributes:
alpha: Learnable scaling factors of size dim, applied before the tanh
weight: Learnable scale parameters of size dim
bias: Learnable bias parameters of size dim

References:
[1] Zhu et al., Transformers without Normalization, 2025.
https://arxiv.org/abs/2503.10622
"""

alpha: Float[Array, "dim"]
weight: Float[Array, "dim"]
bias: Float[Array, "dim"]

def __init__(self, dim: int, alpha_init_value: float = 0.5):
"""Initialize DyT.

Args:
dim: Dimension of the input features
alpha_init_value: Initial value for the scaling factor
"""
self.alpha = jnp.repeat(alpha_init_value, dim)
self.weight = jnp.ones(dim)
self.bias = jnp.zeros(dim)

def __call__(self, x: Float[Array, "dim"]):
"""Apply dynamic tanh to input tensor.

Args:
x: Input tensor of shape (dim,)

Returns:
Scaled tensor of same shape as input
"""
x = jnp.tanh(self.alpha * x)
return x * self.weight + self.bias
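
For reference, a minimal sketch of the new layer in use; DyT computes tanh(alpha * x) * weight + bias elementwise, and with the default initialization (weight = 1, bias = 0) it reduces to tanh(alpha * x). The import path is assumed from the file touched in this diff:

```python
import jax.numpy as jnp

from equimo.layers.norm import DyT  # assumed import path

dyt = DyT(dim=4, alpha_init_value=0.5)
x = jnp.array([1.0, -2.0, 0.5, 3.0])
y = dyt(x)
# With weight=1 and bias=0 this is exactly tanh(0.5 * x)
assert jnp.allclose(y, jnp.tanh(0.5 * x))
```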


def get_norm(module: str | eqx.Module) -> eqx.Module:
"""Get an `eqx.Module` from its common name.

@@ -107,5 +147,7 @@ def get_norm(module: str | eqx.Module) -> eqx.Module:
return RMSNormGated
case "layerscale":
return LayerScale
case "dynamictanh":
return DyT
case _:
raise ValueError(f"Got an unknown module string: {module}")
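
A short sketch of resolving the new name through get_norm, with the string spelling taken from the match arm above (import path assumed):

```python
from equimo.layers.norm import DyT, get_norm  # assumed import path

norm_cls = get_norm("dynamictanh")
assert norm_cls is DyT

layer = norm_cls(dim=64)  # equivalent to DyT(dim=64)
```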