diff --git a/deel/torchlip/modules/loss.py b/deel/torchlip/modules/loss.py index 1492fa2..6520328 100644 --- a/deel/torchlip/modules/loss.py +++ b/deel/torchlip/modules/loss.py @@ -223,7 +223,7 @@ def _update_mean(self, y_pred): self.alpha_mean * self.current_mean + (1 - self.alpha_mean) * current_global_mean ) - self.current_mean = self.clamp_current_mean(current_global_mean) + self.current_mean = self.clamp_current_mean(current_global_mean).detach() total_mean = current_global_mean total_mean = torch.clamp(total_mean, self.min_margin_v, 20000) return total_mean