Skip to content

Commit 08169d6

Browse files
committed
Update changelog
1 parent e799073 commit 08169d6

File tree

4 files changed

+26
-22
lines changed

4 files changed

+26
-22
lines changed

CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,9 @@ refers to `DeterministicGlobalWorkspace`.
3030

3131
# 0.4.0
3232
* Use ABC for abstract methods.
33+
* Replace `DomainDescription` with `GWInterface`.
34+
* Add `contrastive_fn` attribute in `DeterministicGWLosses` to compute the contrastive loss.
35+
It can then be customized.
36+
* Rename every abstract class with ClassNameBase. Rename every "Deterministic" classes
37+
to remove "Deterministic".
38+

shimmer/__init__.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66
SchedulerArgs,
77
VariationalGlobalWorkspace,
88
pretrained_global_workspace)
9-
from shimmer.modules.gw_module import (DeterministicGWModule, GWDecoder,
10-
GWEncoder, GWInterface, GWInterfaceBase,
11-
GWModule, VariationalGWEncoder,
9+
from shimmer.modules.gw_module import (GWDecoder, GWEncoder, GWInterface,
10+
GWInterfaceBase, GWModule, GWModuleBase,
11+
VariationalGWEncoder,
1212
VariationalGWInterface,
1313
VariationalGWModule)
14-
from shimmer.modules.losses import (DeterministicGWLosses, GWLosses,
15-
VariationalGWLosses)
14+
from shimmer.modules.losses import GWLosses, GWLossesBase, VariationalGWLosses
1615
from shimmer.version import __version__
1716

1817
__all__ = [
@@ -22,16 +21,16 @@
2221
"ShimmerInfoConfig",
2322
"DomainModule",
2423
"GWInterfaceBase",
25-
"DeterministicGWModule",
24+
"GWModule",
2625
"GWDecoder",
2726
"GWEncoder",
2827
"GWInterface",
29-
"GWModule",
28+
"GWModuleBase",
3029
"VariationalGWEncoder",
3130
"VariationalGWInterface",
3231
"VariationalGWModule",
33-
"DeterministicGWLosses",
3432
"GWLosses",
33+
"GWLossesBase",
3534
"VariationalGWLosses",
3635
"GlobalWorkspace",
3736
"GlobalWorkspaceBase",

shimmer/modules/__init__.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,26 @@
44
SchedulerArgs,
55
VariationalGlobalWorkspace,
66
pretrained_global_workspace)
7-
from shimmer.modules.gw_module import (DeterministicGWModule, GWDecoder,
8-
GWEncoder, GWInterface, GWInterfaceBase,
9-
GWModule, VariationalGWEncoder,
7+
from shimmer.modules.gw_module import (GWDecoder, GWEncoder, GWInterface,
8+
GWInterfaceBase, GWModule, GWModuleBase,
9+
VariationalGWEncoder,
1010
VariationalGWInterface,
1111
VariationalGWModule)
12-
from shimmer.modules.losses import (DeterministicGWLosses, GWLosses,
13-
VariationalGWLosses)
12+
from shimmer.modules.losses import GWLosses, GWLossesBase, VariationalGWLosses
1413

1514
__all__ = [
1615
"DomainModule",
1716
"GWInterfaceBase",
18-
"DeterministicGWModule",
17+
"GWModule",
1918
"GWDecoder",
2019
"GWEncoder",
2120
"GWInterface",
22-
"GWModule",
21+
"GWModuleBase",
2322
"VariationalGWEncoder",
2423
"VariationalGWInterface",
2524
"VariationalGWModule",
26-
"DeterministicGWLosses",
2725
"GWLosses",
26+
"GWLossesBase",
2827
"VariationalGWLosses",
2928
"GlobalWorkspace",
3029
"GlobalWorkspaceBase",

shimmer/modules/losses.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from shimmer.modules.dict_buffer import DictBuffer
99
from shimmer.modules.domain import DomainModule
10-
from shimmer.modules.gw_module import (DeterministicGWModule, GWModule,
10+
from shimmer.modules.gw_module import (GWModule, GWModuleBase,
1111
VariationalGWModule)
1212
from shimmer.modules.vae import kl_divergence_loss
1313

@@ -134,7 +134,7 @@ def step(
134134

135135

136136
def _demi_cycle_loss(
137-
gw_mod: GWModule,
137+
gw_mod: GWModuleBase,
138138
domain_mods: dict[str, DomainModule],
139139
latent_domains: LatentsT,
140140
) -> dict[str, torch.Tensor]:
@@ -164,7 +164,7 @@ def _demi_cycle_loss(
164164

165165

166166
def _cycle_loss(
167-
gw_mod: GWModule,
167+
gw_mod: GWModuleBase,
168168
domain_mods: dict[str, DomainModule],
169169
latent_domains: LatentsT,
170170
) -> dict[str, torch.Tensor]:
@@ -205,7 +205,7 @@ def _cycle_loss(
205205

206206

207207
def _translation_loss(
208-
gw_mod: GWModule,
208+
gw_mod: GWModuleBase,
209209
domain_mods: dict[str, DomainModule],
210210
latent_domains: LatentsT,
211211
) -> dict[str, torch.Tensor]:
@@ -252,7 +252,7 @@ def _translation_loss(
252252

253253

254254
def _contrastive_loss(
255-
gw_mod: GWModule,
255+
gw_mod: GWModuleBase,
256256
latent_domains: LatentsT,
257257
contrastive_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
258258
) -> dict[str, torch.Tensor]:
@@ -330,7 +330,7 @@ def _contrastive_loss_with_uncertainty(
330330
class GWLosses(GWLossesBase):
331331
def __init__(
332332
self,
333-
gw_mod: DeterministicGWModule,
333+
gw_mod: GWModule,
334334
domain_mods: dict[str, DomainModule],
335335
coef_buffers: DictBuffer,
336336
):

0 commit comments

Comments
 (0)