Skip to content

Commit 2c9b60d

Browse files
authored
Better loss system for broadcast_loss (#126)
* if compute_loss returns None, it's skipped + Broadcast loss uses tr, dcy, cy and fused losses * Lower lightning deps * add docstring to explain what returning None can offer
1 parent d3a3978 commit 2c9b60d

File tree

4 files changed

+59
-16
lines changed

4 files changed

+59
-16
lines changed

poetry.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pandas = "^2.2.2"
1313
matplotlib = "^3.9.1"
1414
migrate-ckpt = {git = "https://github.com/bdvllrs/migrate-ckpt.git", rev = "v0.2.0"}
1515
click = "^8.1.7"
16-
lightning = "^2.3.3"
16+
lightning = ">=2.1.0"
1717
torch = "^2.0.1"
1818

1919
[tool.poetry.group.dev.dependencies]

shimmer/modules/domain.py

+34-12
Original file line numberDiff line numberDiff line change
@@ -93,19 +93,25 @@ def decode(self, z: torch.Tensor) -> Any:
9393
"""
9494
raise NotImplementedError
9595

96-
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
96+
def compute_loss(
97+
self, pred: torch.Tensor, target: torch.Tensor
98+
) -> LossOutput | None:
9799
"""
98100
Generic loss computation for the modality.
99101
100102
Args:
101103
pred (`torch.Tensor`): prediction of the model
102104
target (`torch.Tensor`): target tensor
103105
Results:
104-
`LossOutput`: LossOuput with training loss and additional metrics.
106+
`LossOutput | None`: LossOutput with training loss and additional metrics.
107+
If `None` is returned, this loss will be ignored and will not
108+
participate in the total loss.
105109
"""
106110
raise NotImplementedError
107111

108-
def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
112+
def compute_dcy_loss(
113+
self, pred: torch.Tensor, target: torch.Tensor
114+
) -> LossOutput | None:
109115
"""
110116
Computes the loss for a demi-cycle. Override if the demi-cycle loss is
111117
different than the generic loss.
@@ -114,11 +120,16 @@ def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutp
114120
pred (`torch.Tensor`): prediction of the model
115121
target (`torch.Tensor`): target tensor
116122
Results:
117-
`LossOutput`: LossOuput with training loss and additional metrics.
123+
`LossOutput | None`: LossOutput with training loss and additional metrics.
124+
If `None` is returned, this loss will be ignored and will not
125+
participate in the total loss; it can be used to deactivate
126+
demi-cycle loss for this domain.
118127
"""
119128
return self.compute_loss(pred, target)
120129

121-
def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
130+
def compute_cy_loss(
131+
self, pred: torch.Tensor, target: torch.Tensor
132+
) -> LossOutput | None:
122133
"""
123134
Computes the loss for a cycle. Override if the cycle loss is
124135
different than the generic loss.
@@ -127,11 +138,16 @@ def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutpu
127138
pred (`torch.Tensor`): prediction of the model
128139
target (`torch.Tensor`): target tensor
129140
Results:
130-
`LossOutput`: LossOuput with training loss and additional metrics.
141+
`LossOutput | None`: LossOutput with training loss and additional metrics.
142+
If `None` is returned, this loss will be ignored and will not
143+
participate in the total loss; it can be used to deactivate
144+
cycle loss for this domain.
131145
"""
132146
return self.compute_loss(pred, target)
133147

134-
def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
148+
def compute_tr_loss(
149+
self, pred: torch.Tensor, target: torch.Tensor
150+
) -> LossOutput | None:
135151
"""
136152
Computes the loss for a translation. Override if the translation loss is
137153
different than the generic loss.
@@ -140,21 +156,27 @@ def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutpu
140156
pred (`torch.Tensor`): prediction of the model
141157
target (`torch.Tensor`): target tensor
142158
Results:
143-
`LossOutput`: LossOuput with training loss and additional metrics.
159+
`LossOutput | None`: LossOutput with training loss and additional metrics.
160+
If `None` is returned, this loss will be ignored and will not
161+
participate in the total loss; it can be used to deactivate
162+
translation loss for this domain.
144163
"""
145164
return self.compute_loss(pred, target)
146165

147-
def compute_broadcast_loss(
166+
def compute_fused_loss(
148167
self, pred: torch.Tensor, target: torch.Tensor
149-
) -> LossOutput:
168+
) -> LossOutput | None:
150169
"""
151-
Computes the loss for a broadcast (fusion). Override if the broadcast loss is
170+
Computes the fused (fusion) loss. Override if the fused loss is
152171
different than the generic loss.
153172
154173
Args:
155174
pred (`torch.Tensor`): prediction of the model
156175
target (`torch.Tensor`): target tensor
157176
Results:
158-
`LossOutput`: LossOuput with training loss and additional metrics.
177+
`LossOutput | None`: LossOutput with training loss and additional metrics.
178+
If `None` is returned, this loss will be ignored and will not
179+
participate in the total loss; it can be used to deactivate
180+
fused loss for this domain.
159181
"""
160182
return self.compute_loss(pred, target)

shimmer/modules/losses.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ def demi_cycle_loss(
7878
gw_mod.encode_and_fuse(latents, selection_mod), domains={domain_name}
7979
)[domain_name]
8080
loss_output = domain_mod.compute_dcy_loss(x_recons, latents[domain_name])
81+
if loss_output is None:
82+
continue
8183
losses[f"demi_cycle_{domain_name}"] = loss_output.loss
8284
metrics.update(
8385
{f"demi_cycle_{domain_name}_{k}": v for k, v in loss_output.metrics.items()}
@@ -138,6 +140,9 @@ def cycle_loss(
138140
x_recons[domain_name_source],
139141
latents_source[domain_name_source],
140142
)
143+
if loss_output is None:
144+
continue
145+
141146
metrics.update(
142147
{f"cycle_{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
143148
)
@@ -200,6 +205,9 @@ def translation_loss(
200205
prediction,
201206
latents[domain_name_target],
202207
)
208+
if loss_output is None:
209+
continue
210+
203211
losses[f"translation_{loss_name}"] = loss_output.loss
204212
metrics.update(
205213
{
@@ -565,7 +573,18 @@ def broadcast_loss(
565573
if domain not in group_domains: # if we don't have ground truth
566574
continue
567575
ground_truth = latents[domain]
568-
loss_output = domain_mods[domain].compute_loss(pred, ground_truth)
576+
577+
if num_active_domains == 1 and domain in selected_latents:
578+
loss_fn = domain_mods[domain].compute_dcy_loss
579+
elif domain not in selected_latents:
580+
loss_fn = domain_mods[domain].compute_tr_loss
581+
else:
582+
loss_fn = domain_mods[domain].compute_fused_loss
583+
584+
loss_output = loss_fn(pred, ground_truth)
585+
if loss_output is None:
586+
continue
587+
569588
loss_label = f"from_{selected_group_label}_to_{domain}"
570589
losses[loss_label + "_loss"] = loss_output.loss
571590
metrics.update(
@@ -601,9 +620,11 @@ def broadcast_loss(
601620

602621
for domain in selected_latents:
603622
re_ground_truth = latents[domain]
604-
re_loss_output = domain_mods[domain].compute_loss(
623+
re_loss_output = domain_mods[domain].compute_cy_loss(
605624
re_decoded_latents[domain], re_ground_truth
606625
)
626+
if re_loss_output is None:
627+
continue
607628
loss_label = (
608629
f"from_{selected_group_label}_"
609630
f"through_{inverse_selected_group_label}_to_{domain}_"

0 commit comments

Comments
 (0)