Skip to content

Commit 6259db4

Browse files
Remove fusion_activation_fn from GWModuleBase
1 parent 39669f7 commit 6259db4

File tree

1 file changed

+0
-6
lines changed

1 file changed

+0
-6
lines changed

shimmer/modules/gw_module.py

-6
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,6 @@ def __init__(
242242
self,
243243
domain_mods: Mapping[str, DomainModule],
244244
workspace_dim: int,
245-
fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
246245
*args,
247246
**kwargs,
248247
) -> None:
@@ -252,8 +251,6 @@ def __init__(
252251
Args:
253252
domain_modules (`Mapping[str, DomainModule]`): the domain modules.
254253
workspace_dim (`int`): dimension of the GW.
255-
fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
256-
function used to fuse the domains.
257254
"""
258255
super().__init__()
259256

@@ -263,9 +260,6 @@ def __init__(
263260
self.workspace_dim = workspace_dim
264261
"""Dimension of the GW"""
265262

266-
self.fusion_activation_fn = fusion_activation_fn
267-
"""Activation function used to fuse the domains"""
268-
269263
@abstractmethod
270264
def fuse(
271265
self, x: LatentsDomainGroupT, selection_scores: Mapping[str, torch.Tensor]

0 commit comments

Comments
 (0)