File tree 1 file changed +15
-0
lines changed
1 file changed +15
-0
lines changed Original file line number Diff line number Diff line change @@ -333,6 +333,21 @@ def combine_loss(
333
333
metrics : dict [str , torch .Tensor ],
334
334
coefs : Mapping [str , float ] | LossCoefs | BroadcastLossCoefs ,
335
335
) -> torch .Tensor :
336
+ """
337
+ Combines the metrics according to the ones selected in coefs
338
+
339
+ Args:
340
+ metrics (`dict[str, torch.Tensor]`): all metrics to combine
341
+ coefs (`Mapping[str, float] | LossCoefs | BroadcastLossCoefs`): coefs for
342
+ selected metrics. Note, every metric does not need to be included here.
343
+ If not specified, the metric will not count in the final loss.
344
+ Also not that some metrics are redundant (e.g. `translations` contains
345
+ all of the `translation_{domain_1}_to_{domain_2}`). You can look at the
346
+ docs of the different losses for available values.
347
+
348
+ Returns:
349
+ `torch.Tensor`: the combined loss.
350
+ """
336
351
loss = torch .stack (
337
352
[
338
353
metrics [name ] * coef
You can’t perform that action at this time.
0 commit comments