Skip to content

Commit

Permalink
Replace implicit subloss addition with user's "total" loss key
Browse files Browse the repository at this point in the history
  • Loading branch information
ibro45 committed May 17, 2024
1 parent 388c9b6 commit bcffb91
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions lighter/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,18 +205,16 @@ def _base_step(self, batch: Dict, batch_idx: int, mode: str) -> Union[Dict[str,
target = apply_fns(target, self.postprocessing["logging"]["target"])
pred = apply_fns(pred, self.postprocessing["logging"]["pred"])

# If the loss is a dict, sum the sublosses under "combined" key. Any weightings should be applied in the criterion.
if isinstance(loss, dict):
if "combined" in loss:
raise ValueError("The loss dictionary cannot contain a key 'combined'.")
loss["combined"] = sum(loss.values())
# If the loss is a dict, the sublosses must be combined under "total" key.
if isinstance(loss, dict) and "total" not in loss:
raise ValueError("The loss dictionary must have 'total' loss, combining all the sublosses.")

# Logging
self._log_stats(loss, metrics, mode, batch_idx)

# Return the loss as required by Lightning as well as other data that can be used in hooks or callbacks.
return {
"loss": loss["combined"] if isinstance(loss, dict) else loss,
"loss": loss["total"] if isinstance(loss, dict) else loss,
"metrics": metrics,
"input": input,
"target": target,
Expand Down

0 comments on commit bcffb91

Please sign in to comment.