diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 1133ae9f..aeba3695 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -51,6 +51,8 @@ def __init__( logit_scale: torch.Tensor, reduction: Literal["mean", "sum", "none"] = "mean", ) -> None: + super().__init__() + self.register_buffer("logit_scale", logit_scale) self.reduction: Literal["mean", "sum", "none"] = reduction @@ -88,6 +90,8 @@ def __init__( logit_scale: torch.Tensor, reduction: Literal["mean", "sum", "none"] = "mean", ) -> None: + super().__init__() + self.register_buffer("logit_scale", logit_scale) self.reduction: Literal["mean", "sum", "none"] = reduction