Skip to content

Commit 60fab96

Browse files
authored
Merge pull request #142 from NREL/gb/trh_loss
Gb/trh loss
2 parents 9cfa20c + 2ff434e commit 60fab96

File tree

8 files changed

+182
-87
lines changed

8 files changed

+182
-87
lines changed

sup3r/models/linear.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"""Simple models for super resolution such as linear interp models."""
33
import numpy as np
44
import logging
5+
from inspect import signature
56
import os
67
import json
78
from sup3r.utilities.utilities import st_interp
@@ -45,7 +46,9 @@ def load(cls, model_dir, verbose=False):
4546
Parameters
4647
----------
4748
model_dir : str
48-
Directory to load LinearInterp model files from.
49+
Directory to load LinearInterp model files from. Must
50+
have a model_params.json file containing "meta" key with all of the
51+
class init args.
4952
verbose : bool
5053
Flag to log information about the loaded model.
5154
@@ -59,11 +62,10 @@ def load(cls, model_dir, verbose=False):
5962
with open(fp_params, 'r') as f:
6063
params = json.load(f)
6164

62-
meta = params.get('meta', {'class': 'Sup3rGan'})
63-
model = cls(features=meta['training_features'],
64-
s_enhance=meta['s_enhance'],
65-
t_enhance=meta['t_enhance'],
66-
t_centered=meta['t_centered'])
65+
meta = params['meta']
66+
args = signature(cls.__init__).parameters
67+
kwargs = {k: v for k, v in meta.items() if k in args}
68+
model = cls(**kwargs)
6769

6870
if verbose:
6971
logger.info('Loading LinearInterp with meta data: {}'
@@ -74,7 +76,8 @@ def load(cls, model_dir, verbose=False):
7476
@property
7577
def meta(self):
7678
"""Get meta data dictionary that defines the model params"""
77-
return {'s_enhance': self._s_enhance,
79+
return {'features': self._features,
80+
's_enhance': self._s_enhance,
7881
't_enhance': self._t_enhance,
7982
't_centered': self._t_centered,
8083
'training_features': self.training_features,

sup3r/models/surface.py

+82-66
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
# -*- coding: utf-8 -*-
22
"""Special models for surface meteorological data."""
3+
import os
4+
import json
35
import logging
6+
from inspect import signature
47
from fnmatch import fnmatch
58
import numpy as np
69
from PIL import Image
710
from sklearn import linear_model
811
from warnings import warn
912

10-
from sup3r.models.abstract import AbstractInterface
13+
from sup3r.models.linear import LinearInterp
1114
from sup3r.utilities.utilities import spatial_coarsening
1215

1316
logger = logging.getLogger(__name__)
1417

1518

16-
class SurfaceSpatialMetModel(AbstractInterface):
19+
class SurfaceSpatialMetModel(LinearInterp):
1720
"""Model to spatially downscale daily-average near-surface temperature,
1821
relative humidity, and pressure
1922
@@ -43,7 +46,8 @@ class SurfaceSpatialMetModel(AbstractInterface):
4346

4447
def __init__(self, features, s_enhance, noise_adders=None,
4548
temp_lapse=None, w_delta_temp=None, w_delta_topo=None,
46-
pres_div=None, pres_exp=None):
49+
pres_div=None, pres_exp=None, interp_method='LANCZOS',
50+
fix_bias=True):
4751
"""
4852
Parameters
4953
----------
@@ -85,6 +89,15 @@ def __init__(self, features, s_enhance, noise_adders=None,
8589
pres_div : None | float
8690
Exponential factor in the pressure scale height equation. Defaults
8791
to the cls.PRES_EXP attribute.
92+
interp_method : str
93+
Name of the interpolation method to use from PIL.Image.Resampling
94+
(NEAREST, BILINEAR, BICUBIC, LANCZOS)
95+
LANCZOS is default and has been tested to work best for
96+
SurfaceSpatialMetModel.
97+
fix_bias : bool
98+
Some local bias can be introduced by the bilinear interp + lapse
99+
rate, this flag will attempt to correct that bias by using the
100+
low-resolution deviation from the input data
88101
"""
89102

90103
self._features = features
@@ -95,6 +108,8 @@ def __init__(self, features, s_enhance, noise_adders=None,
95108
self._w_delta_topo = w_delta_topo or self.W_DELTA_TOPO
96109
self._pres_div = pres_div or self.PRES_DIV
97110
self._pres_exp = pres_exp or self.PRES_EXP
111+
self._fix_bias = fix_bias
112+
self._interp_method = getattr(Image.Resampling, interp_method)
98113

99114
if isinstance(self._noise_adders, (int, float)):
100115
self._noise_adders = [self._noise_adders] * len(self._features)
@@ -103,42 +118,6 @@ def __len__(self):
103118
"""Get number of model steps (match interface of MultiStepGan)"""
104119
return 1
105120

106-
@classmethod
107-
def load(cls, features, s_enhance, verbose=False, **kwargs):
108-
"""Load the GAN with its sub-networks from a previously saved-to output
109-
directory.
110-
111-
Parameters
112-
----------
113-
features : list
114-
List of feature names that this model will operate on for both
115-
input and output. This must match the feature axis ordering in the
116-
array input to generate(). Typically this is a list containing:
117-
temperature_*m, relativehumidity_*m, and pressure_*m. The list can
118-
contain multiple instances of each variable at different heights.
119-
relativehumidity_*m entries must have corresponding temperature_*m
120-
entires at the same hub height.
121-
s_enhance : int
122-
Integer factor by which the spatial axes are to be enhanced.
123-
verbose : bool
124-
Flag to log information about the loaded model.
125-
kwargs : None | dict
126-
Optional kwargs to initialize SurfaceSpatialMetModel
127-
128-
Returns
129-
-------
130-
out : SurfaceSpatialMetModel
131-
Returns an initialized SurfaceSpatialMetModel
132-
"""
133-
134-
model = cls(features, s_enhance, **kwargs)
135-
136-
if verbose:
137-
logger.info('Loading SurfaceSpatialMetModel with meta data: {}'
138-
.format(model.meta))
139-
140-
return model
141-
142121
@staticmethod
143122
def _get_s_enhance(topo_lr, topo_hr):
144123
"""Get the spatial enhancement factor given low-res and high-res
@@ -227,8 +206,39 @@ def _get_temp_rh_ind(self, idf_rh):
227206

228207
return idf_temp
229208

209+
def _fix_downscaled_bias(self, single_lr, single_hr,
210+
method=Image.Resampling.LANCZOS):
211+
"""Fix any bias introduced by the spatial downscaling with lapse rate.
212+
213+
Parameters
214+
----------
215+
single_lr : np.ndarray
216+
Single timestep raster data with shape
217+
(lat, lon) matching the low-resolution input data.
218+
single_hr : np.ndarray
219+
Single timestep downscaled raster data with shape
220+
(lat, lon) matching the high-resolution input data.
221+
method : Image.Resampling.LANCZOS
222+
An Image.Resampling method (NEAREST, BILINEAR, BICUBIC, LANCZOS).
223+
NEAREST enforces zero bias but makes slightly more spatial seams.
224+
225+
Returns
226+
-------
227+
single_hr : np.ndarray
228+
Single timestep downscaled raster data with shape
229+
(lat, lon) matching the high-resolution input data.
230+
"""
231+
232+
re_coarse = spatial_coarsening(np.expand_dims(single_hr, axis=-1),
233+
s_enhance=self._s_enhance,
234+
obs_axis=False)[..., 0]
235+
bias = re_coarse - single_lr
236+
bc = self.downscale_arr(bias, s_enhance=self._s_enhance, method=method)
237+
single_hr -= bc
238+
return single_hr
239+
230240
@staticmethod
231-
def downscale_arr(arr, s_enhance, method=Image.Resampling.BILINEAR):
241+
def downscale_arr(arr, s_enhance, method=Image.Resampling.LANCZOS):
232242
"""Downscale a 2D array of data Image.resize() method
233243
234244
Parameters
@@ -238,9 +248,9 @@ def downscale_arr(arr, s_enhance, method=Image.Resampling.BILINEAR):
238248
(lat, lon)
239249
s_enhance : int
240250
Integer factor by which the spatial axes are to be enhanced.
241-
method : Image.Resampling.BILINEAR
251+
method : Image.Resampling.LANCZOS
242252
An Image.Resampling method (NEAREST, BILINEAR, BICUBIC, LANCZOS).
243-
BILINEAR is default and has been tested to work best for
253+
LANCZOS is default and has been tested to work best for
244254
SurfaceSpatialMetModel.
245255
"""
246256
im = Image.fromarray(arr)
@@ -284,9 +294,15 @@ def downscale_temp(self, single_lr_temp, topo_lr, topo_hr):
284294
assert len(topo_hr.shape) == 2, 'Bad shape for topo_hr'
285295

286296
lower_data = single_lr_temp.copy() + topo_lr * self._temp_lapse
287-
hi_res_temp = self.downscale_arr(lower_data, self._s_enhance)
297+
hi_res_temp = self.downscale_arr(lower_data, self._s_enhance,
298+
method=self._interp_method)
288299
hi_res_temp -= topo_hr * self._temp_lapse
289300

301+
if self._fix_bias:
302+
hi_res_temp = self._fix_downscaled_bias(single_lr_temp,
303+
hi_res_temp,
304+
method=self._interp_method)
305+
290306
return hi_res_temp
291307

292308
def downscale_rh(self, single_lr_rh, single_lr_temp, single_hr_temp,
@@ -336,9 +352,12 @@ def downscale_rh(self, single_lr_rh, single_lr_temp, single_hr_temp,
336352
assert len(topo_lr.shape) == 2, 'Bad shape for topo_lr'
337353
assert len(topo_hr.shape) == 2, 'Bad shape for topo_hr'
338354

339-
interp_rh = self.downscale_arr(single_lr_rh, self._s_enhance)
340-
interp_temp = self.downscale_arr(single_lr_temp, self._s_enhance)
341-
interp_topo = self.downscale_arr(topo_lr, self._s_enhance)
355+
interp_rh = self.downscale_arr(single_lr_rh, self._s_enhance,
356+
method=self._interp_method)
357+
interp_temp = self.downscale_arr(single_lr_temp, self._s_enhance,
358+
method=self._interp_method)
359+
interp_topo = self.downscale_arr(topo_lr, self._s_enhance,
360+
method=self._interp_method)
342361

343362
delta_temp = single_hr_temp - interp_temp
344363
delta_topo = topo_hr - interp_topo
@@ -347,6 +366,10 @@ def downscale_rh(self, single_lr_rh, single_lr_temp, single_hr_temp,
347366
+ self._w_delta_temp * delta_temp
348367
+ self._w_delta_topo * delta_topo)
349368

369+
if self._fix_bias:
370+
hi_res_rh = self._fix_downscaled_bias(single_lr_rh, hi_res_rh,
371+
method=self._interp_method)
372+
350373
return hi_res_rh
351374

352375
def downscale_pres(self, single_lr_pres, topo_lr, topo_hr):
@@ -388,21 +411,28 @@ def downscale_pres(self, single_lr_pres, topo_lr, topo_hr):
388411
warn(msg)
389412

390413
const = 101325 * (1 - (1 - topo_lr / self._pres_div)**self._pres_exp)
391-
single_lr_pres = single_lr_pres.copy() + const
414+
lr_pres_adj = single_lr_pres.copy() + const
392415

393-
if np.min(single_lr_pres) < 0.0:
416+
if np.min(lr_pres_adj) < 0.0:
394417
msg = ('Spatial interpolation of surface pressure '
395418
'resulted in negative values. Incorrectly '
396419
'scaled/unscaled values or incorrect units are '
397-
'the most likely causes.')
420+
'the most likely causes. All pressure data should be '
421+
'in Pascals.')
398422
logger.error(msg)
399423
raise ValueError(msg)
400424

401-
hi_res_pres = self.downscale_arr(single_lr_pres, self._s_enhance)
425+
hi_res_pres = self.downscale_arr(lr_pres_adj, self._s_enhance,
426+
method=self._interp_method)
402427

403428
const = 101325 * (1 - (1 - topo_hr / self._pres_div)**self._pres_exp)
404429
hi_res_pres -= const
405430

431+
if self._fix_bias:
432+
hi_res_pres = self._fix_downscaled_bias(single_lr_pres,
433+
hi_res_pres,
434+
method=self._interp_method)
435+
406436
if np.min(hi_res_pres) < 0.0:
407437
msg = ('Spatial interpolation of surface pressure '
408438
'resulted in negative values. Incorrectly '
@@ -524,25 +554,11 @@ def meta(self):
524554
'pressure_exponent': self._pres_exp,
525555
'training_features': self.training_features,
526556
'output_features': self.output_features,
557+
'interp_method': str(self._interp_method),
558+
'fix_bias': self._fix_bias,
527559
'class': self.__class__.__name__,
528560
}
529561

530-
@property
531-
def training_features(self):
532-
"""Get the list of input feature names that the generative model was
533-
trained on.
534-
535-
Note that topography needs to be passed into generate() as an exogenous
536-
data input.
537-
"""
538-
return self._features
539-
540-
@property
541-
def output_features(self):
542-
"""Get the list of output feature names that the generative model
543-
outputs"""
544-
return self._features
545-
546562
def train(self, true_hr_temp, true_hr_rh, true_hr_topo):
547563
"""This method trains the relative humidity linear model. The
548564
temperature and surface lapse rate models are parameterizations taken

sup3r/preprocessing/data_handling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1106,7 +1106,7 @@ def preflight(self):
11061106
msg = (f'sample_shape[2] ({self.sample_shape[2]}) cannot be larger '
11071107
'than the number of time steps in the raw data '
11081108
f'({len(self.raw_time_index)}).')
1109-
if len(self.raw_time_index) >= self.sample_shape[2]:
1109+
if len(self.raw_time_index) < self.sample_shape[2]:
11101110
logger.warning(msg)
11111111
warnings.warn(msg)
11121112

sup3r/utilities/loss_metrics.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Loss metrics for Sup3r"""
22

3-
from tensorflow.keras.losses import MeanSquaredError
3+
from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError
44
import tensorflow as tf
55

66

@@ -171,3 +171,39 @@ def __call__(self, x1, x2):
171171
x1_coarse = tf.reduce_mean(x1, axis=(1, 2))
172172
x2_coarse = tf.reduce_mean(x2, axis=(1, 2))
173173
return self.MSE_LOSS(x1_coarse, x2_coarse)
174+
175+
176+
class TemporalExtremesLoss(tf.keras.losses.Loss):
177+
"""Loss class that encourages accuracy of the min/max values in the
178+
timeseries"""
179+
180+
MAE_LOSS = MeanAbsoluteError()
181+
182+
def __call__(self, x1, x2):
183+
"""Custom content loss that encourages temporal min/max accuracy
184+
185+
Parameters
186+
----------
187+
x1 : tf.tensor
188+
synthetic generator output
189+
(n_observations, spatial_1, spatial_2, temporal, features)
190+
x2 : tf.tensor
191+
high resolution data
192+
(n_observations, spatial_1, spatial_2, temporal, features)
193+
194+
Returns
195+
-------
196+
tf.tensor
197+
0D tensor with loss value
198+
"""
199+
x1_min = tf.reduce_min(x1, axis=3)
200+
x2_min = tf.reduce_min(x2, axis=3)
201+
202+
x1_max = tf.reduce_max(x1, axis=3)
203+
x2_max = tf.reduce_max(x2, axis=3)
204+
205+
mae = self.MAE_LOSS(x1, x2)
206+
mae_min = self.MAE_LOSS(x1_min, x2_min)
207+
mae_max = self.MAE_LOSS(x1_max, x2_max)
208+
209+
return mae + mae_min + mae_max

sup3r/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# -*- coding: utf-8 -*-
22
"""SUP3R Version"""
33

4-
__version__ = '0.0.8'
4+
__version__ = '0.0.9'

0 commit comments

Comments
 (0)