
Commit

add the hyper-connections (multiple residual streams) proposed by bytedance ai labs
lucidrains committed Dec 21, 2024
1 parent 0fd37f5 commit 8b367e6
Showing 4 changed files with 133 additions and 8 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -2363,4 +2363,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
}
```

```bibtex
@article{Zhu2024HyperConnections,
title = {Hyper-Connections},
author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
journal = {ArXiv},
year = {2024},
volume = {abs/2409.19606},
url = {https://api.semanticscholar.org/CorpusID:272987528}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.42.28',
version = '1.43.0',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
16 changes: 16 additions & 0 deletions tests/test_x_transformers.py
@@ -590,3 +590,19 @@ def test_cross_attn_rotary(
context_pos = context_pos,
context_mask = context_mask
)

def test_hyper_connections():
model = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 128,
depth = 6,
heads = 8,
num_residual_streams = 8 # 8 dynamic hyper connection residual streams
)
)

x = torch.randint(0, 20000, (2, 1024))

model(x)
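
As a quick usage sketch (editorial, not part of the commit): the same option can be trained end to end through the library's `AutoregressiveWrapper`, following the README's standard pattern. The random token data and the choice of 4 streams here are illustrative placeholders.

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        num_residual_streams = 4   # 'expansion rate' in the paper's terms
    )
)

model = AutoregressiveWrapper(model)

x = torch.randint(0, 20000, (2, 1024))   # placeholder token ids

loss = model(x)   # autoregressive cross entropy loss
loss.backward()
```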
112 changes: 105 additions & 7 deletions x_transformers/x_transformers.py
@@ -824,12 +824,15 @@ def forward(self, x):
# residual and residual gates

class Residual(Module):
def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
def __init__(self, dim, scale_residual = False, scale_residual_constant = 1., **kwargs):
super().__init__()
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
self.scale_residual_constant = scale_residual_constant

def forward(self, x, residual):
def prepare(self, residual):
return residual, residual, dict()

def forward(self, x, residual, **kwargs):
if exists(self.residual_scale):
residual = residual * self.residual_scale

@@ -844,7 +847,10 @@ def __init__(self, dim, scale_residual = False, **kwargs):
self.gru = nn.GRUCell(dim, dim)
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

def forward(self, x, residual):
def prepare(self, residual):
return residual, residual, dict()

def forward(self, x, residual, **kwargs):
if exists(self.residual_scale):
residual = residual * self.residual_scale

@@ -855,6 +861,66 @@ def forward(self, x, residual):

return gated_output.reshape_as(x)

# hyper connections

class HyperConnection(Module):
def __init__(
self,
dim,
*,
layer_index,
num_residual_streams,
**kwargs
):
"""
https://arxiv.org/abs/2409.19606
Appendix J - Algorithm 2, Dynamic only
"""
super().__init__()

self.norm = nn.LayerNorm(dim, bias = False)

self.num_residual_streams = num_residual_streams
self.layer_index = layer_index

self.static_beta = nn.Parameter(torch.ones(num_residual_streams))

init_alpha0 = torch.zeros((num_residual_streams, 1))
init_alpha0[layer_index % num_residual_streams, 0] = 1.

self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)

def prepare(self, residuals):

residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)

normed = self.norm(residuals)

wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
alpha = dynamic_alpha + self.static_alpha

dc_weight = (normed @ self.dynamic_beta_fn).tanh()
dynamic_beta = dc_weight * self.dynamic_beta_scale
beta = dynamic_beta + self.static_beta

# width connection

mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)

branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]

return branch_input, residuals, dict(beta = beta)

def forward(self, x, residuals, *, beta):
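# depth connection - add the branch output back to every residual stream, weighted by beta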
residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
return rearrange(residuals, 'b n s d -> (b s) n d')

# token shifting

def shift(t, amount, mask = None):
@@ -1582,6 +1648,7 @@ def __init__(
use_layerscale = False,
layerscale_init_value = 0.,
unet_skips = False,
num_residual_streams = 1,
reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
@@ -1607,6 +1674,17 @@ def __init__(
self.causal = causal
self.layers = ModuleList([])

# greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606

assert num_residual_streams > 0

self.num_residual_streams = num_residual_streams
self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None

assert not (num_residual_streams > 1 and gate_residual)

# positions related

self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))

rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
@@ -1872,9 +1950,14 @@ def __init__(
if exists(post_branch_fn):
layer = post_branch_fn(layer)

residual_fn = GRUGating if gate_residual else Residual
if num_residual_streams > 1:
residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
elif gate_residual:
residual_fn = GRUGating
else:
residual_fn = Residual

residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)

# handle unet skip connection

@@ -2024,6 +2107,16 @@ def forward(

iter_attn_cache = iter(attn_cache)

# setup multistreams if needed

streams = self.num_residual_streams
is_multistream = streams > 1

if is_multistream:
x = repeat(x, 'b n d -> b n s d', s = streams)
x = x + self.stream_emb
x = rearrange(x, 'b n s d -> (b s) n d')

# outer residual - for resiDual paper

outer_residual = x * self.resi_dual_scale
@@ -2090,7 +2183,7 @@ def forward(
if self.training and self.cross_attn_tokens_dropout > 0.:
context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)

inner_residual = x
x, inner_residual, residual_kwargs = residual_fn.prepare(x)

if return_hiddens:
layer_hiddens.append(x)
@@ -2148,7 +2241,7 @@ def forward(
if exists(post_branch_norm):
out = post_branch_norm(out)

x = residual_fn(out, inner_residual)
x = residual_fn(out, inner_residual, **residual_kwargs)

if layer_type in ('a', 'c') and return_hiddens:
inter.layer_type = layer_type
@@ -2178,6 +2271,11 @@ def forward(
else:
x = final_norm(x)

# take care of multistreams if needed, use sum for now

if is_multistream:
x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)

if not return_hiddens:
return x

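To make the `prepare` / `forward` split easier to follow outside the diff, here is a minimal, self-contained sketch of the dynamic hyper-connection update (Appendix J, Algorithm 2 of the paper, dynamic variant), together with the stream expansion and sum reduction added to `AttentionLayers.forward`. The variable names, toy dimensions, and the identity "branch" standing in for attention/feedforward are illustrative assumptions, not part of the commit.

```python
import torch
from torch import nn
from einops import rearrange, repeat, reduce

batch, seq_len, dim, streams, depth = 2, 16, 32, 4, 3

x = torch.randn(batch, seq_len, dim)

# expand a single residual stream into `streams` copies folded into the batch,
# plus a per-stream embedding (mirrors the multistream setup in forward())
stream_emb = torch.zeros(streams, dim)
h = repeat(x, 'b n d -> b n s d', s = streams) + stream_emb
h = rearrange(h, 'b n s d -> (b s) n d')

norm = nn.LayerNorm(dim, bias = False)

for layer_index in range(depth):
    # parameters mirroring HyperConnection (static + dynamic alpha / beta)
    static_beta = torch.ones(streams)
    init_alpha0 = torch.zeros(streams, 1)
    init_alpha0[layer_index % streams, 0] = 1.
    static_alpha = torch.cat([init_alpha0, torch.eye(streams)], dim = 1)      # (s, s + 1)

    dynamic_alpha_fn = torch.zeros(dim, streams + 1)
    dynamic_beta_fn = torch.zeros(dim)
    alpha_scale = beta_scale = 1e-2

    # prepare(): width connection - mix information across the residual streams
    res = rearrange(h, '(b s) n d -> b n s d', s = streams)
    normed = norm(res)

    # at init the dynamic terms are zero, so alpha / beta reduce to their static values
    alpha = static_alpha + (normed @ dynamic_alpha_fn).tanh() * alpha_scale   # (b, n, s, s + 1)
    beta  = static_beta  + (normed @ dynamic_beta_fn).tanh()  * beta_scale    # (b, n, s)

    mix = torch.einsum('bnst,bnsd->bntd', alpha, res)
    branch_input, res = mix[:, :, 0], mix[:, :, 1:]                           # (b, n, d), (b, n, s, d)

    # the branch (attention or feedforward) runs on branch_input; identity here
    branch_output = branch_input

    # forward(): depth connection - add the branch output back to every stream, weighted by beta
    res = res + torch.einsum('bnd,bns->bnsd', branch_output, beta)
    h = rearrange(res, 'b n s d -> (b s) n d')

# collapse the streams back to a single one by summation, as in the commit
out = reduce(h, '(b s) n d -> b n d', 'sum', s = streams)
print(out.shape)   # torch.Size([2, 16, 32])
```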

1 comment on commit 8b367e6

@lucidrains (Owner) commented on 8b367e6 on Dec 21, 2024

[attached image]

seeing it at 4 streams ('expansion rate' in paper) even on my toy setup
