@@ -93,19 +93,25 @@ def decode(self, z: torch.Tensor) -> Any:
93
93
"""
94
94
raise NotImplementedError
95
95
96
- def compute_loss (self , pred : torch .Tensor , target : torch .Tensor ) -> LossOutput :
96
+ def compute_loss (
97
+ self , pred : torch .Tensor , target : torch .Tensor
98
+ ) -> LossOutput | None :
97
99
"""
98
100
Generic loss computation the modality.
99
101
100
102
Args:
101
103
pred (`torch.Tensor`): prediction of the model
102
104
target (`torch.Tensor`): target tensor
103
105
Results:
104
- `LossOutput`: LossOuput with training loss and additional metrics.
106
+ `LossOutput | None`: LossOuput with training loss and additional metrics.
107
+ If `None` is returned, this loss will be ignored and will not
108
+ participate in the total loss.
105
109
"""
106
110
raise NotImplementedError
107
111
108
- def compute_dcy_loss (self , pred : torch .Tensor , target : torch .Tensor ) -> LossOutput :
112
+ def compute_dcy_loss (
113
+ self , pred : torch .Tensor , target : torch .Tensor
114
+ ) -> LossOutput | None :
109
115
"""
110
116
Computes the loss for a demi-cycle. Override if the demi-cycle loss is
111
117
different that the generic loss.
@@ -114,11 +120,16 @@ def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutp
114
120
pred (`torch.Tensor`): prediction of the model
115
121
target (`torch.Tensor`): target tensor
116
122
Results:
117
- `LossOutput`: LossOuput with training loss and additional metrics.
123
+ `LossOutput | None`: LossOuput with training loss and additional metrics.
124
+ If `None` is returned, this loss will be ignored and will not
125
+ participate in the total loss; it can be used to deactivate
126
+ demi-cycle loss for this domain.
118
127
"""
119
128
return self .compute_loss (pred , target )
120
129
121
- def compute_cy_loss (self , pred : torch .Tensor , target : torch .Tensor ) -> LossOutput :
130
+ def compute_cy_loss (
131
+ self , pred : torch .Tensor , target : torch .Tensor
132
+ ) -> LossOutput | None :
122
133
"""
123
134
Computes the loss for a cycle. Override if the cycle loss is
124
135
different that the generic loss.
@@ -127,11 +138,16 @@ def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutpu
127
138
pred (`torch.Tensor`): prediction of the model
128
139
target (`torch.Tensor`): target tensor
129
140
Results:
130
- `LossOutput`: LossOuput with training loss and additional metrics.
141
+ `LossOutput | None`: LossOuput with training loss and additional metrics.
142
+ If `None` is returned, this loss will be ignored and will not
143
+ participate in the total loss; it can be used to deactivate
144
+ cycle loss for this domain.
131
145
"""
132
146
return self .compute_loss (pred , target )
133
147
134
- def compute_tr_loss (self , pred : torch .Tensor , target : torch .Tensor ) -> LossOutput :
148
+ def compute_tr_loss (
149
+ self , pred : torch .Tensor , target : torch .Tensor
150
+ ) -> LossOutput | None :
135
151
"""
136
152
Computes the loss for a translation. Override if the translation loss is
137
153
different that the generic loss.
@@ -140,21 +156,27 @@ def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutpu
140
156
pred (`torch.Tensor`): prediction of the model
141
157
target (`torch.Tensor`): target tensor
142
158
Results:
143
- `LossOutput`: LossOuput with training loss and additional metrics.
159
+ `LossOutput | None`: LossOuput with training loss and additional metrics.
160
+ If `None` is returned, this loss will be ignored and will not
161
+ participate in the total loss; it can be used to deactivate
162
+ translation loss for this domain.
144
163
"""
145
164
return self .compute_loss (pred , target )
146
165
147
- def compute_broadcast_loss (
166
+ def compute_fused_loss (
148
167
self , pred : torch .Tensor , target : torch .Tensor
149
- ) -> LossOutput :
168
+ ) -> LossOutput | None :
150
169
"""
151
- Computes the loss for a broadcast (fusion). Override if the broadcast loss is
170
+ Computes the loss for fused (fusion). Override if the fused loss is
152
171
different that the generic loss.
153
172
154
173
Args:
155
174
pred (`torch.Tensor`): prediction of the model
156
175
target (`torch.Tensor`): target tensor
157
176
Results:
158
- `LossOutput`: LossOuput with training loss and additional metrics.
177
+ `LossOutput | None`: LossOuput with training loss and additional metrics.
178
+ If `None` is returned, this loss will be ignored and will not
179
+ participate in the total loss; it can be used to deactivate
180
+ fused loss for this domain.
159
181
"""
160
182
return self .compute_loss (pred , target )
0 commit comments