Skip to content

Commit 74398e3

Browse files
committed
docs: add docstrings to combine_loss
1 parent de2e0c9 commit 74398e3

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

shimmer/modules/losses.py

+15
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,21 @@ def combine_loss(
333333
metrics: dict[str, torch.Tensor],
334334
coefs: Mapping[str, float] | LossCoefs | BroadcastLossCoefs,
335335
) -> torch.Tensor:
336+
"""
337+
Combines the metrics according to the ones selected in coefs
338+
339+
Args:
340+
metrics (`dict[str, torch.Tensor]`): all metrics to combine
341+
coefs (`Mapping[str, float] | LossCoefs | BroadcastLossCoefs`): coefs for
342+
selected metrics. Note, every metric does not need to be included here.
343+
If not specified, the metric will not count in the final loss.
344+
Also not that some metrics are redundant (e.g. `translations` contains
345+
all of the `translation_{domain_1}_to_{domain_2}`). You can look at the
346+
docs of the different losses for available values.
347+
348+
Returns:
349+
`torch.Tensor`: the combined loss.
350+
"""
336351
loss = torch.stack(
337352
[
338353
metrics[name] * coef

0 commit comments

Comments
 (0)