@@ -246,16 +246,8 @@ def model_params(self):
246
246
247
247
248
248
class SpatialThenTemporalBase (MultiStepGan ):
249
- """A two-step model where the first step is a spatial-only enhancement on a
250
- 4D tensor and the second step is (spatio)temporal enhancement on a 5D
251
- tensor.
252
-
253
- NOTE: The low res input to the spatial enhancement should be a 4D tensor of
254
- the shape (temporal, spatial_1, spatial_2, features) where temporal
255
- (usually the observation index) is a series of sequential timesteps that
256
- will be transposed to a 5D tensor of shape
257
- (1, spatial_1, spatial_2, temporal, features) tensor and then fed to the
258
- 2nd-step (spatio)temporal model.
249
+ """A base class for spatial-then-temporal or temporal-then-spatial multi
250
+ step GANs
259
251
"""
260
252
261
253
def __init__ (self , spatial_models , temporal_models ):
@@ -272,22 +264,6 @@ def __init__(self, spatial_models, temporal_models):
272
264
self ._spatial_models = spatial_models
273
265
self ._temporal_models = temporal_models
274
266
275
- @property
276
- def models (self ):
277
- """Get an ordered tuple of the Sup3rGan models that are part of this
278
- MultiStepGan
279
- """
280
- if isinstance (self .spatial_models , MultiStepGan ):
281
- spatial_models = self .spatial_models .models
282
- else :
283
- spatial_models = [self .spatial_models ]
284
- if isinstance (self .temporal_models , MultiStepGan ):
285
- temporal_models = self .temporal_models .models
286
- else :
287
- temporal_models = [self .temporal_models ]
288
-
289
- return (* spatial_models , * temporal_models )
290
-
291
267
@property
292
268
def spatial_models (self ):
293
269
"""Get the MultiStepGan object for the spatial-only model(s)
@@ -308,6 +284,72 @@ def temporal_models(self):
308
284
"""
309
285
return self ._temporal_models
310
286
287
+ @classmethod
288
+ def load (cls , spatial_model_dirs , temporal_model_dirs , verbose = True ):
289
+ """Load the GANs with its sub-networks from a previously saved-to
290
+ output directory.
291
+
292
+ Parameters
293
+ ----------
294
+ spatial_model_dirs : str | list | tuple
295
+ An ordered list/tuple of one or more directories containing trained
296
+ + saved Sup3rGan models created using the Sup3rGan.save() method.
297
+ This must contain only spatial models that input/output 4D
298
+ tensors.
299
+ temporal_model_dirs : str | list | tuple
300
+ An ordered list/tuple of one or more directories containing trained
301
+ + saved Sup3rGan models created using the Sup3rGan.save() method.
302
+ This must contain only (spatio)temporal models that input/output 5D
303
+ tensors.
304
+ verbose : bool
305
+ Flag to log information about the loaded model.
306
+
307
+ Returns
308
+ -------
309
+ out : MultiStepGan
310
+ Returns a pretrained gan model that was previously saved to
311
+ model_dirs
312
+ """
313
+ if isinstance (spatial_model_dirs , str ):
314
+ spatial_model_dirs = [spatial_model_dirs ]
315
+ if isinstance (temporal_model_dirs , str ):
316
+ temporal_model_dirs = [temporal_model_dirs ]
317
+
318
+ s_models = MultiStepGan .load (spatial_model_dirs , verbose = verbose )
319
+ t_models = MultiStepGan .load (temporal_model_dirs , verbose = verbose )
320
+
321
+ return cls (s_models , t_models )
322
+
323
+
324
+ class SpatialThenTemporalGan (SpatialThenTemporalBase ):
325
+ """A two-step GAN where the first step is a spatial-only enhancement on a
326
+ 4D tensor and the second step is a (spatio)temporal enhancement on a 5D
327
+ tensor.
328
+
329
+ NOTE: The low res input to the spatial enhancement should be a 4D tensor of
330
+ the shape (temporal, spatial_1, spatial_2, features) where temporal
331
+ (usually the observation index) is a series of sequential timesteps that
332
+ will be transposed to a 5D tensor of shape
333
+ (1, spatial_1, spatial_2, temporal, features) tensor and then fed to the
334
+ 2nd-step (spatio)temporal GAN.
335
+ """
336
+
337
+ @property
338
+ def models (self ):
339
+ """Get an ordered tuple of the Sup3rGan models that are part of this
340
+ MultiStepGan
341
+ """
342
+ if isinstance (self .spatial_models , MultiStepGan ):
343
+ spatial_models = self .spatial_models .models
344
+ else :
345
+ spatial_models = [self .spatial_models ]
346
+ if isinstance (self .temporal_models , MultiStepGan ):
347
+ temporal_models = self .temporal_models .models
348
+ else :
349
+ temporal_models = [self .temporal_models ]
350
+
351
+ return (* spatial_models , * temporal_models )
352
+
311
353
@property
312
354
def meta (self ):
313
355
"""Get a tuple of meta data dictionaries for all models
@@ -329,14 +371,14 @@ def meta(self):
329
371
@property
330
372
def training_features (self ):
331
373
"""Get the list of input feature names that the first spatial
332
- generative model in this SpatialThenTemporalBase model requires as
374
+ generative model in this SpatialThenTemporalGan model requires as
333
375
input."""
334
376
return self .spatial_models .training_features
335
377
336
378
@property
337
379
def output_features (self ):
338
380
"""Get the list of output feature names that the last spatiotemporal
339
- interpolation model in this SpatialThenTemporalBase model outputs."""
381
+ interpolation model in this SpatialThenTemporalGan model outputs."""
340
382
return self .temporal_models .output_features
341
383
342
384
def generate (self , low_res , norm_in = True , un_norm_out = True ,
@@ -412,58 +454,139 @@ def generate(self, low_res, norm_in=True, un_norm_out=True,
412
454
413
455
return hi_res
414
456
415
- @classmethod
416
- def load (cls , spatial_model_dirs , temporal_model_dirs , verbose = True ):
417
- """Load the GANs with its sub-networks from a previously saved-to
418
- output directory.
457
+
458
+ class TemporalThenSpatialGan (SpatialThenTemporalBase ):
459
+ """A two-step GAN where the first step is a spatiotemporal enhancement on a
460
+ 5D tensor and the second step is a spatial enhancement on a 4D tensor.
461
+ """
462
+
463
+ @property
464
+ def models (self ):
465
+ """Get an ordered tuple of the Sup3rGan models that are part of this
466
+ MultiStepGan
467
+ """
468
+ if isinstance (self .spatial_models , MultiStepGan ):
469
+ spatial_models = self .spatial_models .models
470
+ else :
471
+ spatial_models = [self .spatial_models ]
472
+ if isinstance (self .temporal_models , MultiStepGan ):
473
+ temporal_models = self .temporal_models .models
474
+ else :
475
+ temporal_models = [self .temporal_models ]
476
+
477
+ return (* temporal_models , * spatial_models )
478
+
479
+ @property
480
+ def meta (self ):
481
+ """Get a tuple of meta data dictionaries for all models
482
+
483
+ Returns
484
+ -------
485
+ tuple
486
+ """
487
+ if isinstance (self .spatial_models , MultiStepGan ):
488
+ spatial_models = self .spatial_models .meta
489
+ else :
490
+ spatial_models = [self .spatial_models .meta ]
491
+ if isinstance (self .temporal_models , MultiStepGan ):
492
+ temporal_models = self .temporal_models .meta
493
+ else :
494
+ temporal_models = [self .temporal_models .meta ]
495
+
496
+ return (* temporal_models , * spatial_models )
497
+
498
+ @property
499
+ def training_features (self ):
500
+ """Get the list of input feature names that the first temporal
501
+ generative model in this TemporalThenSpatialGan model requires as
502
+ input."""
503
+ return self .temporal_models .training_features
504
+
505
+ @property
506
+ def output_features (self ):
507
+ """Get the list of output feature names that the last spatial
508
+ interpolation model in this TemporalThenSpatialGan model outputs."""
509
+ return self .spatial_models .output_features
510
+
511
+ def generate (self , low_res , norm_in = True , un_norm_out = True ,
512
+ exogenous_data = None ):
513
+ """Use the generator model to generate high res data from low res
514
+ input. This is the public generate function.
419
515
420
516
Parameters
421
517
----------
422
- spatial_model_dirs : str | list | tuple
423
- An ordered list/tuple of one or more directories containing trained
424
- + saved Sup3rGan models created using the Sup3rGan.save() method.
425
- This must contain only spatial models that input/output 4D
426
- tensors.
427
- temporal_model_dirs : str | list | tuple
428
- An ordered list/tuple of one or more directories containing trained
429
- + saved Sup3rGan models created using the Sup3rGan.save() method.
430
- This must contain only (spatio)temporal models that input/output 5D
431
- tensors.
432
- verbose : bool
433
- Flag to log information about the loaded model.
518
+ low_res : np.ndarray
519
+ Low-resolution input data, a 5D array of shape:
520
+ (1, spatial_1, spatial_2, n_temporal, n_features)
521
+ norm_in : bool
522
+ Flag to normalize low_res input data if the self.means,
523
+ self.stdevs attributes are available. The generator should always
524
+ received normalized data with mean=0 stdev=1.
525
+ un_norm_out : bool
526
+ Flag to un-normalize synthetically generated output data to physical
527
+ units
528
+ exogenous_data : list
529
+ List of arrays of exogenous_data with length equal to the
530
+ number of model steps. e.g. If we want to include topography as
531
+ an exogenous feature in a temporal + spatial multistep model then
532
+ we need to provide a list of length=2 with topography at the low
533
+ spatial resolution and at the high resolution. If we include more
534
+ than one exogenous feature the ordering must be consistent.
535
+ Each array in the list has 3D or 4D shape:
536
+ (spatial_1, spatial_2, n_features)
537
+ (temporal, spatial_1, spatial_2, n_features)
434
538
435
539
Returns
436
540
-------
437
- out : MultiStepGan
438
- Returns a pretrained gan model that was previously saved to
439
- model_dirs
541
+ hi_res : ndarray
542
+ Synthetically generated high-resolution data output from the 2nd
543
+ step (spatio)temporal GAN with a 5D array shape:
544
+ (1, spatial_1, spatial_2, n_temporal, n_features)
440
545
"""
441
- if isinstance (spatial_model_dirs , str ):
442
- spatial_model_dirs = [spatial_model_dirs ]
443
- if isinstance (temporal_model_dirs , str ):
444
- temporal_model_dirs = [temporal_model_dirs ]
546
+ logger .debug ('Data input to the 1st step (spatio)temporal '
547
+ 'enhancement has shape {}' .format (low_res .shape ))
548
+ s_exogenous = None
549
+ if exogenous_data is not None :
550
+ s_exogenous = exogenous_data [len (self .temporal_models ):]
445
551
446
- s_models = MultiStepGan .load (spatial_model_dirs , verbose = verbose )
447
- t_models = MultiStepGan .load (temporal_model_dirs , verbose = verbose )
552
+ assert low_res .shape [0 ] == 1 , 'Low res input can only have 1 obs!'
448
553
449
- return cls (s_models , t_models )
554
+ try :
555
+ hi_res = self .temporal_models .generate (
556
+ low_res , norm_in = norm_in , un_norm_out = True ,
557
+ exogenous_data = exogenous_data )
558
+ except Exception as e :
559
+ msg = ('Could not run the 1st step (spatio)temporal GAN on input '
560
+ 'shape {}' .format (low_res .shape ))
561
+ logger .exception (msg )
562
+ raise RuntimeError (msg ) from e
450
563
564
+ logger .debug ('Data output from the 1st step (spatio)temporal '
565
+ 'enhancement has shape {}' .format (hi_res .shape ))
566
+ hi_res = np .transpose (hi_res [0 ], axes = (2 , 0 , 1 , 3 ))
567
+ logger .debug ('Data from the 1st step (spatio)temporal enhancement has '
568
+ 'been reshaped to {}' .format (hi_res .shape ))
451
569
452
- class SpatialThenTemporalGan (SpatialThenTemporalBase ):
453
- """A two-step GAN where the first step is a spatial-only enhancement on a
454
- 4D tensor and the second step is a (spatio)temporal enhancement on a 5D
455
- tensor.
570
+ try :
571
+ hi_res = self .spatial_models .generate (
572
+ hi_res , norm_in = True , un_norm_out = un_norm_out ,
573
+ exogenous_data = s_exogenous )
574
+ except Exception as e :
575
+ msg = ('Could not run the 2nd step spatial GAN on input '
576
+ 'shape {}' .format (low_res .shape ))
577
+ logger .exception (msg )
578
+ raise RuntimeError (msg ) from e
456
579
457
- NOTE: The low res input to the spatial enhancement should be a 4D tensor of
458
- the shape (temporal, spatial_1, spatial_2, features) where temporal
459
- (usually the observation index) is a series of sequential timesteps that
460
- will be transposed to a 5D tensor of shape
461
- (1, spatial_1, spatial_2, temporal, features) tensor and then fed to the
462
- 2nd-step (spatio)temporal GAN.
463
- """
580
+ hi_res = np . transpose ( hi_res , axes = ( 1 , 2 , 0 , 3 ))
581
+ hi_res = np . expand_dims ( hi_res , axis = 0 )
582
+
583
+ logger . debug ( 'Final multistep GAN output has shape: {}'
584
+ . format ( hi_res . shape ))
585
+
586
+ return hi_res
464
587
465
588
466
- class MultiStepSurfaceMetGan (SpatialThenTemporalBase ):
589
+ class MultiStepSurfaceMetGan (SpatialThenTemporalGan ):
467
590
"""A two-step GAN where the first step is a spatial-only enhancement on a
468
591
4D tensor of near-surface temperature and relative humidity data, and the
469
592
second step is a (spatio)temporal enhancement on a 5D tensor.
@@ -612,7 +735,7 @@ def load(cls, surface_model_class='SurfaceSpatialMetModel',
612
735
return cls (s_models , t_models )
613
736
614
737
615
- class SolarMultiStepGan (SpatialThenTemporalBase ):
738
+ class SolarMultiStepGan (SpatialThenTemporalGan ):
616
739
"""Special multi step model for solar clearsky ratio super resolution.
617
740
618
741
This model takes in two parallel models for wind-only and solar-only
0 commit comments