Skip to content

Commit 93c56f9

Browse files
authored
Merge pull request #140 from NREL/gb/ts_model
added temporal then spatial model with test
2 parents f16f3fe + 2fff120 commit 93c56f9

File tree

4 files changed

+229
-72
lines changed

4 files changed

+229
-72
lines changed

sup3r/models/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from .wind import WindGan
55
from .solar_cc import SolarCC
66
from .data_centric import Sup3rGanDC
7-
from .multi_step import (MultiStepGan, SpatialThenTemporalGan,
7+
from .multi_step import (MultiStepGan,
8+
SpatialThenTemporalGan, TemporalThenSpatialGan,
89
MultiStepSurfaceMetGan, SolarMultiStepGan)
910
from .surface import SurfaceSpatialMetModel
1011
from .linear import LinearInterp

sup3r/models/multi_step.py

+190-67
Original file line numberDiff line numberDiff line change
@@ -246,16 +246,8 @@ def model_params(self):
246246

247247

248248
class SpatialThenTemporalBase(MultiStepGan):
249-
"""A two-step model where the first step is a spatial-only enhancement on a
250-
4D tensor and the second step is (spatio)temporal enhancement on a 5D
251-
tensor.
252-
253-
NOTE: The low res input to the spatial enhancement should be a 4D tensor of
254-
the shape (temporal, spatial_1, spatial_2, features) where temporal
255-
(usually the observation index) is a series of sequential timesteps that
256-
will be transposed to a 5D tensor of shape
257-
(1, spatial_1, spatial_2, temporal, features) tensor and then fed to the
258-
2nd-step (spatio)temporal model.
249+
"""A base class for spatial-then-temporal or temporal-then-spatial multi
250+
step GANs
259251
"""
260252

261253
def __init__(self, spatial_models, temporal_models):
@@ -272,22 +264,6 @@ def __init__(self, spatial_models, temporal_models):
272264
self._spatial_models = spatial_models
273265
self._temporal_models = temporal_models
274266

275-
@property
276-
def models(self):
277-
"""Get an ordered tuple of the Sup3rGan models that are part of this
278-
MultiStepGan
279-
"""
280-
if isinstance(self.spatial_models, MultiStepGan):
281-
spatial_models = self.spatial_models.models
282-
else:
283-
spatial_models = [self.spatial_models]
284-
if isinstance(self.temporal_models, MultiStepGan):
285-
temporal_models = self.temporal_models.models
286-
else:
287-
temporal_models = [self.temporal_models]
288-
289-
return (*spatial_models, *temporal_models)
290-
291267
@property
292268
def spatial_models(self):
293269
"""Get the MultiStepGan object for the spatial-only model(s)
@@ -308,6 +284,72 @@ def temporal_models(self):
308284
"""
309285
return self._temporal_models
310286

287+
@classmethod
288+
def load(cls, spatial_model_dirs, temporal_model_dirs, verbose=True):
289+
"""Load the GANs with its sub-networks from a previously saved-to
290+
output directory.
291+
292+
Parameters
293+
----------
294+
spatial_model_dirs : str | list | tuple
295+
An ordered list/tuple of one or more directories containing trained
296+
+ saved Sup3rGan models created using the Sup3rGan.save() method.
297+
This must contain only spatial models that input/output 4D
298+
tensors.
299+
temporal_model_dirs : str | list | tuple
300+
An ordered list/tuple of one or more directories containing trained
301+
+ saved Sup3rGan models created using the Sup3rGan.save() method.
302+
This must contain only (spatio)temporal models that input/output 5D
303+
tensors.
304+
verbose : bool
305+
Flag to log information about the loaded model.
306+
307+
Returns
308+
-------
309+
out : MultiStepGan
310+
Returns a pretrained gan model that was previously saved to
311+
model_dirs
312+
"""
313+
if isinstance(spatial_model_dirs, str):
314+
spatial_model_dirs = [spatial_model_dirs]
315+
if isinstance(temporal_model_dirs, str):
316+
temporal_model_dirs = [temporal_model_dirs]
317+
318+
s_models = MultiStepGan.load(spatial_model_dirs, verbose=verbose)
319+
t_models = MultiStepGan.load(temporal_model_dirs, verbose=verbose)
320+
321+
return cls(s_models, t_models)
322+
323+
324+
class SpatialThenTemporalGan(SpatialThenTemporalBase):
325+
"""A two-step GAN where the first step is a spatial-only enhancement on a
326+
4D tensor and the second step is a (spatio)temporal enhancement on a 5D
327+
tensor.
328+
329+
NOTE: The low res input to the spatial enhancement should be a 4D tensor of
330+
the shape (temporal, spatial_1, spatial_2, features) where temporal
331+
(usually the observation index) is a series of sequential timesteps that
332+
will be transposed to a 5D tensor of shape
333+
(1, spatial_1, spatial_2, temporal, features) tensor and then fed to the
334+
2nd-step (spatio)temporal GAN.
335+
"""
336+
337+
@property
338+
def models(self):
339+
"""Get an ordered tuple of the Sup3rGan models that are part of this
340+
MultiStepGan
341+
"""
342+
if isinstance(self.spatial_models, MultiStepGan):
343+
spatial_models = self.spatial_models.models
344+
else:
345+
spatial_models = [self.spatial_models]
346+
if isinstance(self.temporal_models, MultiStepGan):
347+
temporal_models = self.temporal_models.models
348+
else:
349+
temporal_models = [self.temporal_models]
350+
351+
return (*spatial_models, *temporal_models)
352+
311353
@property
312354
def meta(self):
313355
"""Get a tuple of meta data dictionaries for all models
@@ -329,14 +371,14 @@ def meta(self):
329371
@property
330372
def training_features(self):
331373
"""Get the list of input feature names that the first spatial
332-
generative model in this SpatialThenTemporalBase model requires as
374+
generative model in this SpatialThenTemporalGan model requires as
333375
input."""
334376
return self.spatial_models.training_features
335377

336378
@property
337379
def output_features(self):
338380
"""Get the list of output feature names that the last spatiotemporal
339-
interpolation model in this SpatialThenTemporalBase model outputs."""
381+
interpolation model in this SpatialThenTemporalGan model outputs."""
340382
return self.temporal_models.output_features
341383

342384
def generate(self, low_res, norm_in=True, un_norm_out=True,
@@ -412,58 +454,139 @@ def generate(self, low_res, norm_in=True, un_norm_out=True,
412454

413455
return hi_res
414456

415-
@classmethod
416-
def load(cls, spatial_model_dirs, temporal_model_dirs, verbose=True):
417-
"""Load the GANs with its sub-networks from a previously saved-to
418-
output directory.
457+
458+
class TemporalThenSpatialGan(SpatialThenTemporalBase):
459+
"""A two-step GAN where the first step is a spatiotemporal enhancement on a
460+
5D tensor and the second step is a spatial enhancement on a 4D tensor.
461+
"""
462+
463+
@property
464+
def models(self):
465+
"""Get an ordered tuple of the Sup3rGan models that are part of this
466+
MultiStepGan
467+
"""
468+
if isinstance(self.spatial_models, MultiStepGan):
469+
spatial_models = self.spatial_models.models
470+
else:
471+
spatial_models = [self.spatial_models]
472+
if isinstance(self.temporal_models, MultiStepGan):
473+
temporal_models = self.temporal_models.models
474+
else:
475+
temporal_models = [self.temporal_models]
476+
477+
return (*temporal_models, *spatial_models)
478+
479+
@property
480+
def meta(self):
481+
"""Get a tuple of meta data dictionaries for all models
482+
483+
Returns
484+
-------
485+
tuple
486+
"""
487+
if isinstance(self.spatial_models, MultiStepGan):
488+
spatial_models = self.spatial_models.meta
489+
else:
490+
spatial_models = [self.spatial_models.meta]
491+
if isinstance(self.temporal_models, MultiStepGan):
492+
temporal_models = self.temporal_models.meta
493+
else:
494+
temporal_models = [self.temporal_models.meta]
495+
496+
return (*temporal_models, *spatial_models)
497+
498+
@property
499+
def training_features(self):
500+
"""Get the list of input feature names that the first temporal
501+
generative model in this TemporalThenSpatialGan model requires as
502+
input."""
503+
return self.temporal_models.training_features
504+
505+
@property
506+
def output_features(self):
507+
"""Get the list of output feature names that the last spatial
508+
interpolation model in this TemporalThenSpatialGan model outputs."""
509+
return self.spatial_models.output_features
510+
511+
def generate(self, low_res, norm_in=True, un_norm_out=True,
512+
exogenous_data=None):
513+
"""Use the generator model to generate high res data from low res
514+
input. This is the public generate function.
419515
420516
Parameters
421517
----------
422-
spatial_model_dirs : str | list | tuple
423-
An ordered list/tuple of one or more directories containing trained
424-
+ saved Sup3rGan models created using the Sup3rGan.save() method.
425-
This must contain only spatial models that input/output 4D
426-
tensors.
427-
temporal_model_dirs : str | list | tuple
428-
An ordered list/tuple of one or more directories containing trained
429-
+ saved Sup3rGan models created using the Sup3rGan.save() method.
430-
This must contain only (spatio)temporal models that input/output 5D
431-
tensors.
432-
verbose : bool
433-
Flag to log information about the loaded model.
518+
low_res : np.ndarray
519+
Low-resolution input data, a 5D array of shape:
520+
(1, spatial_1, spatial_2, n_temporal, n_features)
521+
norm_in : bool
522+
Flag to normalize low_res input data if the self.means,
523+
self.stdevs attributes are available. The generator should always
524+
received normalized data with mean=0 stdev=1.
525+
un_norm_out : bool
526+
Flag to un-normalize synthetically generated output data to physical
527+
units
528+
exogenous_data : list
529+
List of arrays of exogenous_data with length equal to the
530+
number of model steps. e.g. If we want to include topography as
531+
an exogenous feature in a temporal + spatial multistep model then
532+
we need to provide a list of length=2 with topography at the low
533+
spatial resolution and at the high resolution. If we include more
534+
than one exogenous feature the ordering must be consistent.
535+
Each array in the list has 3D or 4D shape:
536+
(spatial_1, spatial_2, n_features)
537+
(temporal, spatial_1, spatial_2, n_features)
434538
435539
Returns
436540
-------
437-
out : MultiStepGan
438-
Returns a pretrained gan model that was previously saved to
439-
model_dirs
541+
hi_res : ndarray
542+
Synthetically generated high-resolution data output from the 2nd
543+
step (spatio)temporal GAN with a 5D array shape:
544+
(1, spatial_1, spatial_2, n_temporal, n_features)
440545
"""
441-
if isinstance(spatial_model_dirs, str):
442-
spatial_model_dirs = [spatial_model_dirs]
443-
if isinstance(temporal_model_dirs, str):
444-
temporal_model_dirs = [temporal_model_dirs]
546+
logger.debug('Data input to the 1st step (spatio)temporal '
547+
'enhancement has shape {}'.format(low_res.shape))
548+
s_exogenous = None
549+
if exogenous_data is not None:
550+
s_exogenous = exogenous_data[len(self.temporal_models):]
445551

446-
s_models = MultiStepGan.load(spatial_model_dirs, verbose=verbose)
447-
t_models = MultiStepGan.load(temporal_model_dirs, verbose=verbose)
552+
assert low_res.shape[0] == 1, 'Low res input can only have 1 obs!'
448553

449-
return cls(s_models, t_models)
554+
try:
555+
hi_res = self.temporal_models.generate(
556+
low_res, norm_in=norm_in, un_norm_out=True,
557+
exogenous_data=exogenous_data)
558+
except Exception as e:
559+
msg = ('Could not run the 1st step (spatio)temporal GAN on input '
560+
'shape {}'.format(low_res.shape))
561+
logger.exception(msg)
562+
raise RuntimeError(msg) from e
450563

564+
logger.debug('Data output from the 1st step (spatio)temporal '
565+
'enhancement has shape {}'.format(hi_res.shape))
566+
hi_res = np.transpose(hi_res[0], axes=(2, 0, 1, 3))
567+
logger.debug('Data from the 1st step (spatio)temporal enhancement has '
568+
'been reshaped to {}'.format(hi_res.shape))
451569

452-
class SpatialThenTemporalGan(SpatialThenTemporalBase):
453-
"""A two-step GAN where the first step is a spatial-only enhancement on a
454-
4D tensor and the second step is a (spatio)temporal enhancement on a 5D
455-
tensor.
570+
try:
571+
hi_res = self.spatial_models.generate(
572+
hi_res, norm_in=True, un_norm_out=un_norm_out,
573+
exogenous_data=s_exogenous)
574+
except Exception as e:
575+
msg = ('Could not run the 2nd step spatial GAN on input '
576+
'shape {}'.format(low_res.shape))
577+
logger.exception(msg)
578+
raise RuntimeError(msg) from e
456579

457-
NOTE: The low res input to the spatial enhancement should be a 4D tensor of
458-
the shape (temporal, spatial_1, spatial_2, features) where temporal
459-
(usually the observation index) is a series of sequential timesteps that
460-
will be transposed to a 5D tensor of shape
461-
(1, spatial_1, spatial_2, temporal, features) tensor and then fed to the
462-
2nd-step (spatio)temporal GAN.
463-
"""
580+
hi_res = np.transpose(hi_res, axes=(1, 2, 0, 3))
581+
hi_res = np.expand_dims(hi_res, axis=0)
582+
583+
logger.debug('Final multistep GAN output has shape: {}'
584+
.format(hi_res.shape))
585+
586+
return hi_res
464587

465588

466-
class MultiStepSurfaceMetGan(SpatialThenTemporalBase):
589+
class MultiStepSurfaceMetGan(SpatialThenTemporalGan):
467590
"""A two-step GAN where the first step is a spatial-only enhancement on a
468591
4D tensor of near-surface temperature and relative humidity data, and the
469592
second step is a (spatio)temporal enhancement on a 5D tensor.
@@ -612,7 +735,7 @@ def load(cls, surface_model_class='SurfaceSpatialMetModel',
612735
return cls(s_models, t_models)
613736

614737

615-
class SolarMultiStepGan(SpatialThenTemporalBase):
738+
class SolarMultiStepGan(SpatialThenTemporalGan):
616739
"""Special multi step model for solar clearsky ratio super resolution.
617740
618741
This model takes in two parallel models for wind-only and solar-only

sup3r/models/wind.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ def init_weights(self, lr_shape, hr_shape, device=None):
5050
device = self.default_device
5151

5252
logger.info('Initializing model weights on device "{}"'.format(device))
53-
low_res = np.random.uniform(0, 1, lr_shape).astype(np.float32)
54-
hi_res = np.random.uniform(0, 1, hr_shape).astype(np.float32)
53+
low_res = np.ones(lr_shape).astype(np.float32)
54+
hi_res = np.ones(hr_shape).astype(np.float32)
5555

5656
hr_topo_shape = hr_shape[:-1] + (1,)
57-
hr_topo = np.random.uniform(0, 1, hr_topo_shape).astype(np.float32)
57+
hr_topo = np.ones(hr_topo_shape).astype(np.float32)
5858

5959
with tf.device(device):
6060
_ = self._tf_generate(low_res, hr_topo)

tests/test_multi_step.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import tempfile
77

88
from sup3r import CONFIG_DIR
9-
from sup3r.models import (Sup3rGan, MultiStepGan, SpatialThenTemporalGan,
9+
from sup3r.models import (Sup3rGan, MultiStepGan,
10+
SpatialThenTemporalGan, TemporalThenSpatialGan,
1011
SolarMultiStepGan, LinearInterp)
1112

1213
FEATURES = ['U_100m', 'V_100m']
@@ -129,6 +130,38 @@ def test_spatial_then_temporal_gan():
129130
assert out.shape == (1, 60, 60, 16, 2)
130131

131132

133+
def test_temporal_then_spatial_gan():
134+
"""Test the 2-step temporal-then-spatial GAN"""
135+
fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json')
136+
fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json')
137+
model1 = Sup3rGan(fp_gen, fp_disc)
138+
_ = model1.generate(np.ones((4, 10, 10, len(FEATURES))))
139+
140+
fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json')
141+
fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json')
142+
model2 = Sup3rGan(fp_gen, fp_disc)
143+
_ = model2.generate(np.ones((4, 10, 10, 6, len(FEATURES))))
144+
145+
model1.set_norm_stats([0.1, 0.2], [0.04, 0.02])
146+
model2.set_norm_stats([0.3, 0.9], [0.02, 0.07])
147+
model1.set_model_params(training_features=FEATURES,
148+
output_features=FEATURES)
149+
model2.set_model_params(training_features=FEATURES,
150+
output_features=FEATURES)
151+
152+
with tempfile.TemporaryDirectory() as td:
153+
fp1 = os.path.join(td, 'model1')
154+
fp2 = os.path.join(td, 'model2')
155+
model1.save(fp1)
156+
model2.save(fp2)
157+
158+
ms_model = TemporalThenSpatialGan.load(fp1, fp2)
159+
160+
x = np.ones((1, 10, 10, 4, len(FEATURES)))
161+
out = ms_model.generate(x)
162+
assert out.shape == (1, 60, 60, 16, 2)
163+
164+
132165
def test_spatial_gan_then_linear_interp():
133166
"""Test the 2-step spatial GAN then linear spatiotemporal interpolation"""
134167
fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json')

0 commit comments

Comments
 (0)