using tanh in hyperconnection was not clear cut
lucidrains committed Dec 22, 2024
1 parent 8880051 commit cdf51f7
Showing 3 changed files with 15 additions and 6 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.43.0',
+  version = '1.43.1',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
tests/test_x_transformers.py (9 changes: 7 additions & 2 deletions)
@@ -591,15 +591,20 @@ def test_cross_attn_rotary(
         context_mask = context_mask
     )

-def test_hyper_connections():
+@pytest.mark.parametrize('tanh', (True, False))
+def test_hyper_connections(tanh):

     model = TransformerWrapper(
         num_tokens = 20000,
         max_seq_len = 1024,
         attn_layers = Decoder(
             dim = 128,
             depth = 6,
             heads = 8,
-            num_residual_streams = 8 # 8 dynamic hyper connection residual streams
+            num_residual_streams = 8, # 8 dynamic hyper connection residual streams
+            residual_fn_kwargs = dict(
+                tanh = tanh
+            )
         )
     )

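For context, a minimal usage sketch of what the parametrized test exercises: the new tanh flag is forwarded through residual_fn_kwargs to the hyper connection residual function. The forward pass below is an assumption based on the standard x-transformers API (version >= 1.43.1) and is not part of this commit.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        num_residual_streams = 8,                # 8 dynamic hyper connection residual streams
        residual_fn_kwargs = dict(tanh = False)  # forwarded to the residual function; True restores the previous behavior
    )
)

x = torch.randint(0, 20000, (2, 1024))
logits = model(x)  # (2, 1024, 20000)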
x_transformers/x_transformers.py (10 changes: 7 additions & 3 deletions)
@@ -870,6 +870,7 @@ def __init__(
         *,
         layer_index,
         num_residual_streams,
+        tanh = True,
         **kwargs
     ):
         """
@@ -878,6 +879,8 @@
         """
         super().__init__()

+        self.act = nn.Tanh() if tanh else nn.Identity()
+
         self.norm = nn.LayerNorm(dim, bias = False)

         self.num_residual_streams = num_residual_streams
@@ -901,11 +904,11 @@ def prepare(self, residuals):

         normed = self.norm(residuals)

-        wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
+        wc_weight = self.act(normed @ self.dynamic_alpha_fn)
         dynamic_alpha = wc_weight * self.dynamic_alpha_scale
         alpha = dynamic_alpha + self.static_alpha

-        dc_weight = (normed @ self.dynamic_beta_fn).tanh()
+        dc_weight = self.act(normed @ self.dynamic_beta_fn)
         dynamic_beta = dc_weight * self.dynamic_beta_scale
         beta = dynamic_beta + self.static_beta

@@ -1653,6 +1656,7 @@ def __init__(
         add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may become a necessity for every transformer soon
         learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
         rel_pos_kwargs: dict = dict(),
+        residual_fn_kwargs: dict = dict(),
         **kwargs
     ):
         super().__init__()
@@ -1957,7 +1961,7 @@ def __init__(
             else:
                 residual_fn = Residual

-            residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
+            residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant, **residual_fn_kwargs)

             # handle unet skip connection

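To make the change above easier to follow, here is a self-contained sketch of the dynamic hyper connection weight computation that the diff touches. It is not the library's actual HyperConnection class: the alpha/beta formulas and the Tanh-or-Identity switch follow the diff, while the parameter shapes and initial values are illustrative assumptions.

import torch
from torch import nn

class HyperConnectionWeights(nn.Module):
    def __init__(self, dim, num_residual_streams, layer_index = 0, tanh = True):
        super().__init__()

        # the change in this commit: tanh on the dynamic weights is now optional
        self.act = nn.Tanh() if tanh else nn.Identity()

        self.norm = nn.LayerNorm(dim, bias = False)

        # static (input-independent) mixing weights - initialization here is a placeholder
        init_alpha0 = torch.zeros((num_residual_streams, 1))
        init_alpha0[layer_index % num_residual_streams, 0] = 1.
        self.static_alpha = nn.Parameter(torch.cat([torch.eye(num_residual_streams), init_alpha0], dim = 1))
        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))

        # dynamic (input-dependent) projections, scaled down at init
        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
        self.dynamic_alpha_scale = nn.Parameter(torch.tensor(1e-2))
        self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
        self.dynamic_beta_scale = nn.Parameter(torch.tensor(1e-2))

    def forward(self, residuals):
        # residuals: (batch, seq, num_residual_streams, dim)
        normed = self.norm(residuals)

        wc_weight = self.act(normed @ self.dynamic_alpha_fn)  # previously always .tanh()
        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
        alpha = dynamic_alpha + self.static_alpha             # (..., streams, streams + 1)

        dc_weight = self.act(normed @ self.dynamic_beta_fn)   # previously always .tanh()
        dynamic_beta = dc_weight * self.dynamic_beta_scale
        beta = dynamic_beta + self.static_beta                # (..., streams)

        return alpha, beta

# with tanh = True the dynamic weights are squashed to (-1, 1); with tanh = False they are unbounded
weights = HyperConnectionWeights(dim = 128, num_residual_streams = 8, tanh = False)
alpha, beta = weights(torch.randn(2, 16, 8, 128))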
