diff --git a/equimo/layers/norm.py b/equimo/layers/norm.py
index 837c715..2023361 100644
--- a/equimo/layers/norm.py
+++ b/equimo/layers/norm.py
@@ -58,12 +58,10 @@ class LayerScale(eqx.Module):
     dampening the contribution of each layer.
 
     Attributes:
-        init_values: Initial scale value (static)
         gamma: Learnable scale parameters of size dim
     """
 
-    init_values: float = eqx.field(static=True)
-    gamma: Optional[Float[Array, "dim"]]
+    gamma: Float[Array, "dim"]
 
     def __init__(self, dim: int, init_values: float):
         """Initialize LayerScale.
@@ -72,8 +70,7 @@ def __init__(self, dim: int, init_values: float):
             dim: Dimension of the input features
            init_values: Initial value for all scaling factors
         """
-        self.init_values = init_values
-        self.gamma = jnp.repeat(self.init_values, dim)
+        self.gamma = jnp.repeat(init_values, dim)
 
     def __call__(self, x: Float[Array, "dim"]):
         """Apply layer scaling to input tensor.
@@ -87,6 +84,49 @@ def __call__(self, x: Float[Array, "dim"]):
         return x * self.gamma
 
 
+class DyT(eqx.Module):
+    """Dynamic Tanh (DyT) layer.
+
+    This layer implements the DyT layer introduced in the Transformers
+    without Normalization paper [1].
+
+    Attributes:
+        alpha: Learnable scaling factors of size dim, applied inside the tanh
+        weight, bias: Learnable affine parameters of size dim, applied after the tanh
+
+    References:
+        [1] Zhu et al., Transformers without Normalization. 2025.
+            https://arxiv.org/abs/2503.10622
+    """
+
+    alpha: Float[Array, "dim"]
+    weight: Float[Array, "dim"]
+    bias: Float[Array, "dim"]
+
+    def __init__(self, dim: int, alpha_init_value: float = 0.5):
+        """Initialize DyT.
+
+        Args:
+            dim: Dimension of the input features
+            alpha_init_value: Initial value for the scaling factor
+        """
+        self.alpha = jnp.repeat(alpha_init_value, dim)
+        self.weight = jnp.ones(dim)
+        self.bias = jnp.zeros(dim)
+
+    def __call__(self, x: Float[Array, "dim"]):
+        """Apply dynamic tanh to input tensor.
+
+        Args:
+            x: Input tensor of shape (dim,)
+
+        Returns:
+            Scaled tensor of same shape as input
+        """
+        x = jnp.tanh(self.alpha * x)
+        return x * self.weight + self.bias
+
+
 def get_norm(module: str | eqx.Module) -> eqx.Module:
     """Get an `eqx.Module` from its common name.
 
@@ -107,5 +147,7 @@ def get_norm(module: str | eqx.Module) -> eqx.Module:
             return RMSNormGated
         case "layerscale":
            return LayerScale
+        case "dynamictanh":
+            return DyT
         case _:
             raise ValueError(f"Got an unknown module string: {module}")
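
For reviewers, a minimal usage sketch of the new DyT layer as a drop-in replacement for a normalization layer. Only DyT, get_norm, and the "dynamictanh" key come from this patch; the toy MLPBlock, its sizes, and the key handling are illustrative assumptions, not part of equimo.

import equinox as eqx
import jax
import jax.numpy as jnp

from equimo.layers.norm import DyT, get_norm  # DyT and the "dynamictanh" key are added by this patch


class MLPBlock(eqx.Module):
    """Toy pre-norm MLP block using DyT in place of LayerNorm (illustrative only)."""

    norm: DyT
    fc1: eqx.nn.Linear
    fc2: eqx.nn.Linear

    def __init__(self, dim: int, hidden: int, *, key):
        k1, k2 = jax.random.split(key)
        # get_norm("dynamictanh") returns the DyT class, instantiated here per feature dim
        self.norm = get_norm("dynamictanh")(dim)
        self.fc1 = eqx.nn.Linear(dim, hidden, key=k1)
        self.fc2 = eqx.nn.Linear(hidden, dim, key=k2)

    def __call__(self, x):
        # DyT replaces the normalization step: tanh(alpha * x) * weight + bias
        return x + self.fc2(jax.nn.gelu(self.fc1(self.norm(x))))


block = MLPBlock(dim=8, hidden=32, key=jax.random.PRNGKey(0))
y = jax.vmap(block)(jnp.ones((4, 8)))  # apply token-wise over a toy sequence of 4 tokens
print(y.shape)  # (4, 8)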