diff --git a/setup.py b/setup.py index 07e66c5a..af5676ba 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'x-transformers', packages = find_packages(exclude=['examples']), - version = '1.43.2', + version = '1.43.4', license='MIT', description = 'X-Transformers - Pytorch', author = 'Phil Wang', diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py index 24c057ff..482a3800 100644 --- a/x_transformers/x_transformers.py +++ b/x_transformers/x_transformers.py @@ -2270,16 +2270,16 @@ def forward( if self.need_condition: final_norm = maybe(partial)(final_norm, **norm_kwargs) - if self.resi_dual: - x = x + final_norm(outer_residual) - else: - x = final_norm(x) - # take care of multistreams if needed, use sum for now if is_multistream: x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams) + if self.resi_dual: + x = x + final_norm(outer_residual) + else: + x = final_norm(x) + if not return_hiddens: return x