From e1be411dbe011467a3b36fd2ee50342b925f2d06 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Mon, 23 Dec 2024 17:01:19 -0800 Subject: [PATCH] merge hyper connection streams before final norm, to avoid edge case with adaptive layernorm --- setup.py | 2 +- x_transformers/x_transformers.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 07e66c5a..af5676ba 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'x-transformers', packages = find_packages(exclude=['examples']), - version = '1.43.2', + version = '1.43.4', license='MIT', description = 'X-Transformers - Pytorch', author = 'Phil Wang', diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py index 24c057ff..482a3800 100644 --- a/x_transformers/x_transformers.py +++ b/x_transformers/x_transformers.py @@ -2270,16 +2270,16 @@ def forward( if self.need_condition: final_norm = maybe(partial)(final_norm, **norm_kwargs) - if self.resi_dual: - x = x + final_norm(outer_residual) - else: - x = final_norm(x) - # take care of multistreams if needed, use sum for now if is_multistream: x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams) + if self.resi_dual: + x = x + final_norm(outer_residual) + else: + x = final_norm(x) + if not return_hiddens: return x