3
3
4
4
import torch
5
5
6
- from shimmer .modules .contrastive_loss import ContrastiveLossBase
6
+ from shimmer .modules .contrastive_loss import (ContrastiveLossType ,
7
+ VarContrastiveLossType )
7
8
from shimmer .modules .dict_buffer import DictBuffer
8
9
from shimmer .modules .domain import DomainModule
9
10
from shimmer .modules .gw_module import (GWModule , GWModuleBase ,
@@ -156,7 +157,7 @@ def _translation_loss(
156
157
def _contrastive_loss (
157
158
gw_mod : GWModuleBase ,
158
159
latent_domains : LatentsT ,
159
- contrastive_fn : ContrastiveLossBase ,
160
+ contrastive_fn : ContrastiveLossType ,
160
161
) -> dict [str , torch .Tensor ]:
161
162
losses : dict [str , torch .Tensor ] = {}
162
163
metrics : dict [str , torch .Tensor ] = {}
@@ -197,7 +198,7 @@ def _contrastive_loss(
197
198
def _contrastive_loss_with_uncertainty (
198
199
gw_mod : VariationalGWModule ,
199
200
latent_domains : LatentsT ,
200
- contrastive_fn : ContrastiveLossBase ,
201
+ contrastive_fn : VarContrastiveLossType ,
201
202
) -> dict [str , torch .Tensor ]:
202
203
losses : dict [str , torch .Tensor ] = {}
203
204
metrics : dict [str , torch .Tensor ] = {}
@@ -246,7 +247,7 @@ def __init__(
246
247
gw_mod : GWModule ,
247
248
domain_mods : dict [str , DomainModule ],
248
249
coef_buffers : DictBuffer ,
249
- contrastive_fn : ContrastiveLossBase ,
250
+ contrastive_fn : ContrastiveLossType ,
250
251
):
251
252
super ().__init__ ()
252
253
self .gw_mod = gw_mod
@@ -303,14 +304,19 @@ def __init__(
303
304
gw_mod : VariationalGWModule ,
304
305
domain_mods : dict [str , DomainModule ],
305
306
coef_buffers : DictBuffer ,
306
- contrastive_fn : ContrastiveLossBase ,
307
+ contrastive_fn : ContrastiveLossType | None = None ,
308
+ var_contrastive_fn : VarContrastiveLossType | None = None ,
307
309
):
308
310
super ().__init__ ()
309
311
310
312
self .gw_mod = gw_mod
311
313
self .domain_mods = domain_mods
312
314
self .loss_coefs = coef_buffers
315
+ assert (contrastive_fn is not None ) != (
316
+ var_contrastive_fn is not None
317
+ ), "Should either have contrastive_fn or var_contrastive_fn"
313
318
self .contrastive_fn = contrastive_fn
319
+ self .var_contrastive_fn = var_contrastive_fn
314
320
315
321
def demi_cycle_loss (
316
322
self , latent_domains : LatentsT
@@ -328,10 +334,12 @@ def translation_loss(
328
334
def contrastive_loss (
329
335
self , latent_domains : LatentsT
330
336
) -> dict [str , torch .Tensor ]:
331
- if self .var_contrastive_loss :
337
+ if self .var_contrastive_fn is not None :
332
338
return _contrastive_loss_with_uncertainty (
333
- self .gw_mod , latent_domains , self .contrastive_fn
339
+ self .gw_mod , latent_domains , self .var_contrastive_fn
334
340
)
341
+
342
+ assert self .contrastive_fn is not None
335
343
return _contrastive_loss (
336
344
self .gw_mod , latent_domains , self .contrastive_fn
337
345
)
0 commit comments