Skip to content

Commit 426dce0

Browse files
committed
Remove ContrastiveLossBase in favor of plain callable types (ContrastiveLossType, VarContrastiveLossType)
1 parent 65469df commit 426dce0

File tree

5 files changed

+56
-43
lines changed

5 files changed

+56
-43
lines changed

shimmer/__init__.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from shimmer.config import (ShimmerInfoConfig, load_config,
22
load_structured_config)
33
from shimmer.modules.contrastive_loss import (
4-
ContrastiveLoss, ContrastiveLossBase, ContrastiveLossWithUncertainty,
5-
contrastive_loss, contrastive_loss_with_uncertainty)
4+
ContrastiveLoss, ContrastiveLossType, ContrastiveLossWithUncertainty,
5+
VarContrastiveLossType, contrastive_loss,
6+
contrastive_loss_with_uncertainty)
67
from shimmer.modules.domain import DomainModule, LossOutput
78
from shimmer.modules.global_workspace import (GlobalWorkspace,
89
GlobalWorkspaceBase,
@@ -39,7 +40,8 @@
3940
"VariationalGWInterface",
4041
"VariationalGWModule",
4142
"ContrastiveLoss",
42-
"ContrastiveLossBase",
43+
"ContrastiveLossType",
44+
"VarContrastiveLossType",
4345
"ContrastiveLossWithUncertainty",
4446
"contrastive_loss",
4547
"contrastive_loss_with_uncertainty",

shimmer/modules/__init__.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from shimmer.modules.contrastive_loss import (
2-
ContrastiveLoss, ContrastiveLossBase, ContrastiveLossWithUncertainty,
3-
contrastive_loss, contrastive_loss_with_uncertainty)
2+
ContrastiveLoss, ContrastiveLossType, ContrastiveLossWithUncertainty,
3+
VarContrastiveLossType, contrastive_loss,
4+
contrastive_loss_with_uncertainty)
45
from shimmer.modules.domain import DomainModule, LossOutput
56
from shimmer.modules.global_workspace import (GlobalWorkspace,
67
GlobalWorkspaceBase,
@@ -30,11 +31,12 @@
3031
"VariationalGWInterface",
3132
"VariationalGWModule",
3233
"ContrastiveLoss",
33-
"ContrastiveLossBase",
34+
"ContrastiveLossType",
35+
"VarContrastiveLossType",
3436
"ContrastiveLossWithUncertainty",
3537
"contrastive_loss",
3638
"contrastive_loss_with_uncertainty",
37-
"LatentT",
39+
"LatentsT",
3840
"LatentsDomainGroupT",
3941
"GWLosses",
4042
"GWLossesBase",

shimmer/modules/contrastive_loss.py

+9-14
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1-
from abc import ABC, abstractmethod
2-
from collections.abc import Mapping
3-
from typing import Literal, TypedDict
1+
from collections.abc import Callable
2+
from typing import Literal
43

54
import torch
65
from torch.nn.functional import cross_entropy, normalize
76

87
from shimmer.modules.domain import LossOutput
98

9+
ContrastiveLossType = Callable[[torch.Tensor, torch.Tensor], LossOutput]
10+
VarContrastiveLossType = Callable[
11+
[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], LossOutput
12+
]
13+
1014

1115
def info_nce(
1216
x: torch.Tensor,
@@ -58,13 +62,7 @@ def contrastive_loss_with_uncertainty(
5862
return 0.5 * (ce + ce_t)
5963

6064

61-
class ContrastiveLossBase(torch.nn.Module, ABC):
62-
@abstractmethod
63-
def forward(self, x: torch.Tensor, y: torch.Tensor) -> LossOutput:
64-
...
65-
66-
67-
class ContrastiveLoss(ContrastiveLossBase):
65+
class ContrastiveLoss(torch.nn.Module):
6866
logit_scale: torch.Tensor
6967

7068
def __init__(
@@ -77,17 +75,14 @@ def __init__(
7775
self.register_buffer("logit_scale", logit_scale)
7876
self.reduction: Literal["mean", "sum", "none"] = reduction
7977

80-
def __call__(self, *args, **kwargs) -> LossOutput:
81-
return super().__call__(*args, **kwargs)
82-
8378
def forward(self, x: torch.Tensor, y: torch.Tensor) -> LossOutput:
8479
return LossOutput(
8580
contrastive_loss(x, y, self.logit_scale, self.reduction),
8681
{"logit_scale": self.logit_scale.exp()},
8782
)
8883

8984

90-
class ContrastiveLossWithUncertainty(ContrastiveLossBase):
85+
class ContrastiveLossWithUncertainty(torch.nn.Module):
9186
logit_scale: torch.Tensor
9287

9388
def __init__(

shimmer/modules/global_workspace.py

+21-15
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch.optim.lr_scheduler import OneCycleLR
1010

1111
from shimmer.modules.contrastive_loss import (ContrastiveLoss,
12-
ContrastiveLossBase,
12+
ContrastiveLossType,
1313
ContrastiveLossWithUncertainty)
1414
from shimmer.modules.dict_buffer import DictBuffer
1515
from shimmer.modules.domain import DomainModule
@@ -316,18 +316,24 @@ def __init__(
316316
domain_mods = freeze_domain_modules(domain_mods)
317317
coef_buffers = DictBuffer(loss_coefs)
318318

319-
contrastive_fn_class = (
320-
ContrastiveLossWithUncertainty
321-
if var_contrastive_loss
322-
else ContrastiveLoss
323-
)
324-
325-
loss_mod = VariationalGWLosses(
326-
gw_mod,
327-
domain_mods,
328-
coef_buffers,
329-
contrastive_fn_class(torch.tensor([1]).log(), "mean"),
330-
)
319+
if var_contrastive_loss:
320+
loss_mod = VariationalGWLosses(
321+
gw_mod,
322+
domain_mods,
323+
coef_buffers,
324+
var_contrastive_fn=ContrastiveLossWithUncertainty(
325+
torch.tensor([1]).log(), "mean"
326+
),
327+
)
328+
else:
329+
loss_mod = VariationalGWLosses(
330+
gw_mod,
331+
domain_mods,
332+
coef_buffers,
333+
contrastive_fn=ContrastiveLoss(
334+
torch.tensor([1]).log(), "mean"
335+
),
336+
)
331337

332338
super().__init__(
333339
gw_mod,
@@ -346,14 +352,14 @@ def pretrained_global_workspace(
346352
gw_interfaces: Mapping[str, GWInterfaceBase],
347353
workspace_dim: int,
348354
loss_coefs: Mapping[str, torch.Tensor],
349-
var_contrastive_loss: bool = False,
355+
contrastive_fn: ContrastiveLossType,
350356
**kwargs,
351357
) -> GlobalWorkspaceBase:
352358
gw_mod = VariationalGWModule(gw_interfaces, workspace_dim)
353359
domain_mods = freeze_domain_modules(domain_mods)
354360
coef_buffers = DictBuffer(loss_coefs)
355361
loss_mod = VariationalGWLosses(
356-
gw_mod, domain_mods, coef_buffers, var_contrastive_loss
362+
gw_mod, domain_mods, coef_buffers, contrastive_fn
357363
)
358364

359365
return cast(

shimmer/modules/losses.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
import torch
55

6-
from shimmer.modules.contrastive_loss import ContrastiveLossBase
6+
from shimmer.modules.contrastive_loss import (ContrastiveLossType,
7+
VarContrastiveLossType)
78
from shimmer.modules.dict_buffer import DictBuffer
89
from shimmer.modules.domain import DomainModule
910
from shimmer.modules.gw_module import (GWModule, GWModuleBase,
@@ -156,7 +157,7 @@ def _translation_loss(
156157
def _contrastive_loss(
157158
gw_mod: GWModuleBase,
158159
latent_domains: LatentsT,
159-
contrastive_fn: ContrastiveLossBase,
160+
contrastive_fn: ContrastiveLossType,
160161
) -> dict[str, torch.Tensor]:
161162
losses: dict[str, torch.Tensor] = {}
162163
metrics: dict[str, torch.Tensor] = {}
@@ -197,7 +198,7 @@ def _contrastive_loss(
197198
def _contrastive_loss_with_uncertainty(
198199
gw_mod: VariationalGWModule,
199200
latent_domains: LatentsT,
200-
contrastive_fn: ContrastiveLossBase,
201+
contrastive_fn: VarContrastiveLossType,
201202
) -> dict[str, torch.Tensor]:
202203
losses: dict[str, torch.Tensor] = {}
203204
metrics: dict[str, torch.Tensor] = {}
@@ -246,7 +247,7 @@ def __init__(
246247
gw_mod: GWModule,
247248
domain_mods: dict[str, DomainModule],
248249
coef_buffers: DictBuffer,
249-
contrastive_fn: ContrastiveLossBase,
250+
contrastive_fn: ContrastiveLossType,
250251
):
251252
super().__init__()
252253
self.gw_mod = gw_mod
@@ -303,14 +304,19 @@ def __init__(
303304
gw_mod: VariationalGWModule,
304305
domain_mods: dict[str, DomainModule],
305306
coef_buffers: DictBuffer,
306-
contrastive_fn: ContrastiveLossBase,
307+
contrastive_fn: ContrastiveLossType | None = None,
308+
var_contrastive_fn: VarContrastiveLossType | None = None,
307309
):
308310
super().__init__()
309311

310312
self.gw_mod = gw_mod
311313
self.domain_mods = domain_mods
312314
self.loss_coefs = coef_buffers
315+
assert (contrastive_fn is not None) != (
316+
var_contrastive_fn is not None
317+
), "Should either have contrastive_fn or var_contrastive_fn"
313318
self.contrastive_fn = contrastive_fn
319+
self.var_contrastive_fn = var_contrastive_fn
314320

315321
def demi_cycle_loss(
316322
self, latent_domains: LatentsT
@@ -328,10 +334,12 @@ def translation_loss(
328334
def contrastive_loss(
329335
self, latent_domains: LatentsT
330336
) -> dict[str, torch.Tensor]:
331-
if self.var_contrastive_loss:
337+
if self.var_contrastive_fn is not None:
332338
return _contrastive_loss_with_uncertainty(
333-
self.gw_mod, latent_domains, self.contrastive_fn
339+
self.gw_mod, latent_domains, self.var_contrastive_fn
334340
)
341+
342+
assert self.contrastive_fn is not None
335343
return _contrastive_loss(
336344
self.gw_mod, latent_domains, self.contrastive_fn
337345
)

0 commit comments

Comments
 (0)