@@ -16,27 +16,31 @@ class SolarCC(Sup3rGan):
16
16
Note
17
17
----
18
18
*Modifications to standard Sup3rGan*
19
- - Content loss is only on the n_days of the center 8 daylight hours of
20
- the daily true+synthetic high res samples
21
- - Discriminator only sees n_days of the center 8 daylight hours of the
22
- daily true high res sample.
23
- - Discriminator sees random n_days of 8-hour samples of the daily
24
- synthetic high res sample.
19
+ - Pointwise content loss (MAE/MSE) is only on the center 2 daylight
20
+ hours (POINT_LOSS_HOURS) of the daily true + synthetic days and the
21
+ temporal mean of the 24hours of synthetic for n_days
22
+ (usually just 1 day)
23
+ - Discriminator only sees n_days of the center 8 daylight hours
24
+ (DAYLIGHT_HOURS and STARTING_HOUR) of the daily true high res sample.
25
+ - Discriminator sees random n_days of 8-hour samples (DAYLIGHT_HOURS)
26
+ of the daily synthetic high res sample.
25
27
- Includes padding on high resolution output of :meth:`generate` so
26
28
that forward pass always outputs a multiple of 24 hours.
27
29
"""
28
30
29
- # starting hour is the hour that daylight starts at, daylight hours is the
30
- # number of daylight hours to sample, so for example if 8 and 8, the
31
- # daylight slice will be slice(8, 16). The stride length is the step size
32
- # for sampling the temporal axis of the generated data to send to the
33
- # discriminator for the adversarial loss component of the generator. For
34
- # example, if the generator produces 24 timesteps and stride is 4 and the
35
- # daylight hours is 8, slices of (0, 8) (4, 12), (8, 16), (12, 20), and
36
- # (16, 24) will be sent to the disc.
37
31
STARTING_HOUR = 8
32
+ """Starting hour is the hour that daylight starts at, typically
33
+ zero-indexed and rolled to local time"""
34
+
38
35
DAYLIGHT_HOURS = 8
39
- STRIDE_LEN = 4
36
+ """Daylight hours is the number of daylight hours to sample, so for example
37
+ if STARTING_HOUR is 8 and DAYLIGHT_HOURS is 8, the daylight slice will be
38
+ slice(8, 16). """
39
+
40
+ POINT_LOSS_HOURS = 2
41
+ """Number of hours from the center of the day to calculate pointwise loss
42
+ from, e.g., MAE/MSE based on data from the true 4km hourly high res
43
+ field."""
40
44
41
45
def __init__ (self , * args , t_enhance = None , ** kwargs ):
42
46
"""Add optional t_enhance adjustment.
@@ -142,32 +146,55 @@ def calc_loss(
142
146
143
147
t_len = hi_res_true .shape [3 ]
144
148
n_days = int (t_len // 24 )
145
- day_slices = [
149
+
150
+ # slices for 24-hour full days
151
+ day_24h_slices = [slice (x , x + 24 ) for x in range (0 , 24 * n_days , 24 )]
152
+
153
+ # slices for middle-daylight-hours
154
+ sub_day_slices = [
146
155
slice (
147
156
self .STARTING_HOUR + x ,
148
157
self .STARTING_HOUR + x + self .DAYLIGHT_HOURS ,
149
158
)
150
159
for x in range (0 , 24 * n_days , 24 )
151
160
]
152
161
162
+ # slices for middle-pointwise-loss-hours
163
+ point_loss_slices = [
164
+ slice (
165
+ (24 - self .POINT_LOSS_HOURS ) // 2 + x ,
166
+ (24 - self .POINT_LOSS_HOURS ) // 2 + x + self .POINT_LOSS_HOURS ,
167
+ )
168
+ for x in range (0 , 24 * n_days , 24 )
169
+ ]
170
+
153
171
# sample only daylight hours for disc training and gen content loss
154
172
disc_out_true = []
155
173
disc_out_gen = []
156
174
loss_gen_content = 0.0
157
- for tslice in day_slices :
158
- disc_t = self ._tf_discriminate (hi_res_true [:, :, :, tslice , :])
159
- gen_c = self .calc_loss_gen_content (
160
- hi_res_true [:, :, :, tslice , :], hi_res_gen [:, :, :, tslice , :]
161
- )
175
+ ziter = zip (sub_day_slices , point_loss_slices , day_24h_slices )
176
+ for tslice_sub , tslice_ploss , tslice_24h in ziter :
177
+ hr_true_sub = hi_res_true [:, :, :, tslice_sub , :]
178
+ hr_gen_24h = hi_res_gen [:, :, :, tslice_24h , :]
179
+ hr_true_ploss = hi_res_true [:, :, :, tslice_ploss , :]
180
+ hr_gen_ploss = hi_res_gen [:, :, :, tslice_ploss , :]
181
+
182
+ hr_true_mean = tf .math .reduce_mean (hr_true_sub , axis = 3 )
183
+ hr_gen_mean = tf .math .reduce_mean (hr_gen_24h , axis = 3 )
184
+
185
+ gen_c_sub = self .calc_loss_gen_content (hr_true_ploss , hr_gen_ploss )
186
+ gen_c_24h = self .calc_loss_gen_content (hr_true_mean , hr_gen_mean )
187
+ loss_gen_content += gen_c_24h + gen_c_sub
188
+
189
+ disc_t = self ._tf_discriminate (hr_true_sub )
162
190
disc_out_true .append (disc_t )
163
- loss_gen_content += gen_c
164
191
165
192
# Randomly sample daylight windows from generated data. Better than
166
193
# strided samples covering full day because the random samples will
167
194
# provide an evenly balanced training set for the disc
168
- logits = [[1.0 ] * (t_len - self .DAYLIGHT_HOURS )]
169
- time_samples = tf .random .categorical (logits , len ( day_slices ) )
170
- for i in range (len ( day_slices ) ):
195
+ logits = [[1.0 ] * (t_len - self .DAYLIGHT_HOURS + 1 )]
196
+ time_samples = tf .random .categorical (logits , n_days )
197
+ for i in range (n_days ):
171
198
t0 = time_samples [0 , i ]
172
199
t1 = t0 + self .DAYLIGHT_HOURS
173
200
disc_g = self ._tf_discriminate (hi_res_gen [:, :, :, t0 :t1 , :])
@@ -177,7 +204,7 @@ def calc_loss(
177
204
disc_out_gen = tf .concat ([disc_out_gen ], axis = 0 )
178
205
loss_disc = self .calc_loss_disc (disc_out_true , disc_out_gen )
179
206
180
- loss_gen_content /= len (day_slices )
207
+ loss_gen_content /= len (sub_day_slices )
181
208
loss_gen_advers = self .calc_loss_gen_advers (disc_out_gen )
182
209
loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers
183
210
0 commit comments