Skip to content

Commit 651f396

Browse files
authored
Merge pull request #250 from NREL/gb/solar_experiment
Gb/solar experiment
2 parents 94eebae + dc0f42c commit 651f396

File tree

8 files changed

+82
-42
lines changed

8 files changed

+82
-42
lines changed

sup3r/models/solar_cc.py

+53-26
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,31 @@ class SolarCC(Sup3rGan):
1616
Note
1717
----
1818
*Modifications to standard Sup3rGan*
19-
- Content loss is only on the n_days of the center 8 daylight hours of
20-
the daily true+synthetic high res samples
21-
- Discriminator only sees n_days of the center 8 daylight hours of the
22-
daily true high res sample.
23-
- Discriminator sees random n_days of 8-hour samples of the daily
24-
synthetic high res sample.
19+
- Pointwise content loss (MAE/MSE) is only on the center 2 daylight
20+
hours (POINT_LOSS_HOURS) of the daily true + synthetic days and the
21+
temporal mean of the 24 hours of synthetic for n_days
22+
(usually just 1 day)
23+
- Discriminator only sees n_days of the center 8 daylight hours
24+
(DAYLIGHT_HOURS and STARTING_HOUR) of the daily true high res sample.
25+
- Discriminator sees random n_days of 8-hour samples (DAYLIGHT_HOURS)
26+
of the daily synthetic high res sample.
2527
- Includes padding on high resolution output of :meth:`generate` so
2628
that forward pass always outputs a multiple of 24 hours.
2729
"""
2830

29-
# starting hour is the hour that daylight starts at, daylight hours is the
30-
# number of daylight hours to sample, so for example if 8 and 8, the
31-
# daylight slice will be slice(8, 16). The stride length is the step size
32-
# for sampling the temporal axis of the generated data to send to the
33-
# discriminator for the adversarial loss component of the generator. For
34-
# example, if the generator produces 24 timesteps and stride is 4 and the
35-
# daylight hours is 8, slices of (0, 8) (4, 12), (8, 16), (12, 20), and
36-
# (16, 24) will be sent to the disc.
3731
STARTING_HOUR = 8
32+
"""Starting hour is the hour that daylight starts at, typically
33+
zero-indexed and rolled to local time"""
34+
3835
DAYLIGHT_HOURS = 8
39-
STRIDE_LEN = 4
36+
"""Daylight hours is the number of daylight hours to sample, so for example
37+
if STARTING_HOUR is 8 and DAYLIGHT_HOURS is 8, the daylight slice will be
38+
slice(8, 16). """
39+
40+
POINT_LOSS_HOURS = 2
41+
"""Number of hours from the center of the day to calculate pointwise loss
42+
from, e.g., MAE/MSE based on data from the true 4km hourly high res
43+
field."""
4044

4145
def __init__(self, *args, t_enhance=None, **kwargs):
4246
"""Add optional t_enhance adjustment.
@@ -142,32 +146,55 @@ def calc_loss(
142146

143147
t_len = hi_res_true.shape[3]
144148
n_days = int(t_len // 24)
145-
day_slices = [
149+
150+
# slices for 24-hour full days
151+
day_24h_slices = [slice(x, x + 24) for x in range(0, 24 * n_days, 24)]
152+
153+
# slices for middle-daylight-hours
154+
sub_day_slices = [
146155
slice(
147156
self.STARTING_HOUR + x,
148157
self.STARTING_HOUR + x + self.DAYLIGHT_HOURS,
149158
)
150159
for x in range(0, 24 * n_days, 24)
151160
]
152161

162+
# slices for middle-pointwise-loss-hours
163+
point_loss_slices = [
164+
slice(
165+
(24 - self.POINT_LOSS_HOURS) // 2 + x,
166+
(24 - self.POINT_LOSS_HOURS) // 2 + x + self.POINT_LOSS_HOURS,
167+
)
168+
for x in range(0, 24 * n_days, 24)
169+
]
170+
153171
# sample only daylight hours for disc training and gen content loss
154172
disc_out_true = []
155173
disc_out_gen = []
156174
loss_gen_content = 0.0
157-
for tslice in day_slices:
158-
disc_t = self._tf_discriminate(hi_res_true[:, :, :, tslice, :])
159-
gen_c = self.calc_loss_gen_content(
160-
hi_res_true[:, :, :, tslice, :], hi_res_gen[:, :, :, tslice, :]
161-
)
175+
ziter = zip(sub_day_slices, point_loss_slices, day_24h_slices)
176+
for tslice_sub, tslice_ploss, tslice_24h in ziter:
177+
hr_true_sub = hi_res_true[:, :, :, tslice_sub, :]
178+
hr_gen_24h = hi_res_gen[:, :, :, tslice_24h, :]
179+
hr_true_ploss = hi_res_true[:, :, :, tslice_ploss, :]
180+
hr_gen_ploss = hi_res_gen[:, :, :, tslice_ploss, :]
181+
182+
hr_true_mean = tf.math.reduce_mean(hr_true_sub, axis=3)
183+
hr_gen_mean = tf.math.reduce_mean(hr_gen_24h, axis=3)
184+
185+
gen_c_sub = self.calc_loss_gen_content(hr_true_ploss, hr_gen_ploss)
186+
gen_c_24h = self.calc_loss_gen_content(hr_true_mean, hr_gen_mean)
187+
loss_gen_content += gen_c_24h + gen_c_sub
188+
189+
disc_t = self._tf_discriminate(hr_true_sub)
162190
disc_out_true.append(disc_t)
163-
loss_gen_content += gen_c
164191

165192
# Randomly sample daylight windows from generated data. Better than
166193
# strided samples covering full day because the random samples will
167194
# provide an evenly balanced training set for the disc
168-
logits = [[1.0] * (t_len - self.DAYLIGHT_HOURS)]
169-
time_samples = tf.random.categorical(logits, len(day_slices))
170-
for i in range(len(day_slices)):
195+
logits = [[1.0] * (t_len - self.DAYLIGHT_HOURS + 1)]
196+
time_samples = tf.random.categorical(logits, n_days)
197+
for i in range(n_days):
171198
t0 = time_samples[0, i]
172199
t1 = t0 + self.DAYLIGHT_HOURS
173200
disc_g = self._tf_discriminate(hi_res_gen[:, :, :, t0:t1, :])
@@ -177,7 +204,7 @@ def calc_loss(
177204
disc_out_gen = tf.concat([disc_out_gen], axis=0)
178205
loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen)
179206

180-
loss_gen_content /= len(day_slices)
207+
loss_gen_content /= len(sub_day_slices)
181208
loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen)
182209
loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers
183210

sup3r/preprocessing/accessor.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,9 @@ def compute(self, **kwargs):
230230
logger.debug(f'Loading dataset into memory: {self._ds}')
231231
logger.debug(f'Pre-loading: {_mem_check()}')
232232

233-
for f in self._ds.data_vars:
234-
self._ds[f] = self._ds[f].compute(**kwargs)
233+
for f in list(self._ds.data_vars) + list(self._ds.coords):
234+
if hasattr(self._ds[f], 'compute'):
235+
self._ds[f] = self._ds[f].compute(**kwargs)
235236
logger.debug(
236237
f'Loaded {f} into memory with shape '
237238
f'{self._ds[f].shape}. {_mem_check()}'

sup3r/preprocessing/derivers/methods.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -384,8 +384,6 @@ class TasMax(Tas):
384384
class Sza(DerivedFeature):
385385
"""Solar zenith angle derived feature."""
386386

387-
inputs = ()
388-
389387
@classmethod
390388
def compute(cls, data):
391389
"""Compute method for sza."""
@@ -402,6 +400,8 @@ def compute(cls, data):
402400
'cloud_mask': CloudMask,
403401
'clearsky_ratio': ClearSkyRatio,
404402
'sza': Sza,
403+
'latitude_feature': 'latitude',
404+
'longitude_feature': 'longitude',
405405
}
406406

407407
RegistryH5WindCC = {

sup3r/preprocessing/rasterizers/extended.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,8 @@ def get_lat_lon(self):
193193
return self._get_flat_data_lat_lon()
194194

195195
def _get_flat_data_lat_lon(self):
196-
"""Get lat lon for flattened source data."""
196+
"""Get lat lon for flattened source data. Output is shape (y, x, 2)
197+
where 2 is (lat, lon)"""
197198
if hasattr(self.full_lat_lon, 'vindex'):
198199
return self.full_lat_lon.vindex[self.raster_index]
199-
return self.full_lat_lon[self.raster_index.flatten]
200+
return self.full_lat_lon[self.raster_index]

sup3r/solar/solar.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(
3434
sup3r_fps,
3535
nsrdb_fp,
3636
t_slice=slice(None),
37-
tz=-6,
37+
tz=-7,
3838
agg_factor=1,
3939
nn_threshold=0.5,
4040
cloud_threshold=0.99,
@@ -64,8 +64,8 @@ def __init__(
6464
tz : int
6565
The timezone offset for the data in sup3r_fps. It is assumed that
6666
the GAN is trained on data in local time and therefore the output
67-
in sup3r_fps should be treated as local time. For example, -6 is
68-
CST which is default for CONUS training data.
67+
in sup3r_fps should be treated as local time. For example, -7 is
68+
MST which is default for CONUS training data.
6969
agg_factor : int
7070
Spatial aggregation factor for nsrdb-to-GAN-meta e.g. the number of
7171
NSRDB spatial pixels to average for a single sup3r GAN output site.
@@ -585,7 +585,7 @@ def run_temporal_chunks(
585585
fp_pattern,
586586
nsrdb_fp,
587587
fp_out_suffix='irradiance',
588-
tz=-6,
588+
tz=-7,
589589
agg_factor=1,
590590
nn_threshold=0.5,
591591
cloud_threshold=0.99,
@@ -610,8 +610,8 @@ def run_temporal_chunks(
610610
tz : int
611611
The timezone offset for the data in sup3r_fps. It is assumed that
612612
the GAN is trained on data in local time and therefore the output
613-
in sup3r_fps should be treated as local time. For example, -6 is
614-
CST which is default for CONUS training data.
613+
in sup3r_fps should be treated as local time. For example, -7 is
614+
MST which is default for CONUS training data.
615615
agg_factor : int
616616
Spatial aggregation factor for nsrdb-to-GAN-meta e.g. the number of
617617
NSRDB spatial pixels to average for a single sup3r GAN output site.
@@ -663,7 +663,7 @@ def _run_temporal_chunk(
663663
fp_pattern,
664664
nsrdb_fp,
665665
fp_out_suffix='irradiance',
666-
tz=-6,
666+
tz=-7,
667667
agg_factor=1,
668668
nn_threshold=0.5,
669669
cloud_threshold=0.99,

tests/data_handlers/test_dh_nc_cc.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,11 @@ def test_data_handling_nc_cc():
122122

123123
handler = DataHandlerNCforCC(
124124
pytest.FPS_GCM,
125-
features=['u_100m', 'v_100m'],
125+
features=['u_100m', 'v_100m', 'latitude_feature', 'longitude_feature'],
126126
target=target,
127127
shape=(20, 20),
128128
)
129-
assert handler.data.shape == (20, 20, 20, 2)
129+
assert handler.data.shape == (20, 20, 20, 4)
130130

131131
# upper case features warning
132132
features = [f'U_{int(plevel)}pa', f'V_{int(plevel)}pa']

tests/rasterizers/test_rasterizer_general.py

+12
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,15 @@ def test_topography_h5():
8080
topo = res.get_meta_arr('elevation')[ri.flatten(),]
8181
topo = topo.reshape((ri.shape[0], ri.shape[1]))
8282
assert np.allclose(topo, rasterizer['topography'][..., 0])
83+
84+
85+
def test_preloaded_h5():
86+
"""Test preload of h5 file"""
87+
rasterizer = Rasterizer(
88+
file_paths=pytest.FP_WTK,
89+
target=(39.01, -105.15),
90+
shape=(20, 20),
91+
chunks=None,
92+
)
93+
for f in list(rasterizer.data.data_vars) + list(Dimension.coords_2d()):
94+
assert isinstance(rasterizer[f].data, np.ndarray)

tests/training/test_train_solar.py

-1
Original file line numberDiff line numberDiff line change
@@ -246,4 +246,3 @@ def test_solar_custom_loss():
246246
)
247247

248248
assert loss1 > loss2
249-
assert loss2 == 0

0 commit comments

Comments
 (0)