Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gb/trh loss #142

Merged
merged 12 commits into from
Jan 20, 2023
Merged
17 changes: 10 additions & 7 deletions sup3r/models/linear.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
"""Simple models for super resolution such as linear interp models."""
import numpy as np
import logging
from inspect import signature
import os
import json
from sup3r.utilities.utilities import st_interp
@@ -45,7 +46,9 @@ def load(cls, model_dir, verbose=False):
Parameters
----------
model_dir : str
Directory to load LinearInterp model files from.
Directory to load LinearInterp model files from. Must
have a model_params.json file containing "meta" key with all of the
class init args.
verbose : bool
Flag to log information about the loaded model.

@@ -59,11 +62,10 @@ def load(cls, model_dir, verbose=False):
with open(fp_params, 'r') as f:
params = json.load(f)

meta = params.get('meta', {'class': 'Sup3rGan'})
model = cls(features=meta['training_features'],
s_enhance=meta['s_enhance'],
t_enhance=meta['t_enhance'],
t_centered=meta['t_centered'])
meta = params['meta']
args = signature(cls.__init__).parameters
kwargs = {k: v for k, v in meta.items() if k in args}
model = cls(**kwargs)

if verbose:
logger.info('Loading LinearInterp with meta data: {}'
@@ -74,7 +76,8 @@ def load(cls, model_dir, verbose=False):
@property
def meta(self):
"""Get meta data dictionary that defines the model params"""
return {'s_enhance': self._s_enhance,
return {'features': self._features,
's_enhance': self._s_enhance,
't_enhance': self._t_enhance,
't_centered': self._t_centered,
'training_features': self.training_features,
148 changes: 82 additions & 66 deletions sup3r/models/surface.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
# -*- coding: utf-8 -*-
"""Special models for surface meteorological data."""
import os
import json
import logging
from inspect import signature
from fnmatch import fnmatch
import numpy as np
from PIL import Image
from sklearn import linear_model
from warnings import warn

from sup3r.models.abstract import AbstractInterface
from sup3r.models.linear import LinearInterp
from sup3r.utilities.utilities import spatial_coarsening

logger = logging.getLogger(__name__)


class SurfaceSpatialMetModel(AbstractInterface):
class SurfaceSpatialMetModel(LinearInterp):
"""Model to spatially downscale daily-average near-surface temperature,
relative humidity, and pressure

@@ -43,7 +46,8 @@ class SurfaceSpatialMetModel(AbstractInterface):

def __init__(self, features, s_enhance, noise_adders=None,
temp_lapse=None, w_delta_temp=None, w_delta_topo=None,
pres_div=None, pres_exp=None):
pres_div=None, pres_exp=None, interp_method='LANCZOS',
fix_bias=True):
"""
Parameters
----------
@@ -85,6 +89,15 @@ def __init__(self, features, s_enhance, noise_adders=None,
pres_exp : None | float
Exponential factor in the pressure scale height equation. Defaults
to the cls.PRES_EXP attribute.
interp_method : str
Name of the interpolation method to use from PIL.Image.Resampling
(NEAREST, BILINEAR, BICUBIC, LANCZOS)
LANCZOS is default and has been tested to work best for
SurfaceSpatialMetModel.
fix_bias : bool
Some local bias can be introduced by the spatial interp + lapse
rate. This flag will attempt to correct that bias by using the
low-resolution deviation from the input data.
"""

self._features = features
@@ -95,6 +108,8 @@ def __init__(self, features, s_enhance, noise_adders=None,
self._w_delta_topo = w_delta_topo or self.W_DELTA_TOPO
self._pres_div = pres_div or self.PRES_DIV
self._pres_exp = pres_exp or self.PRES_EXP
self._fix_bias = fix_bias
self._interp_method = getattr(Image.Resampling, interp_method)

if isinstance(self._noise_adders, (int, float)):
self._noise_adders = [self._noise_adders] * len(self._features)
@@ -103,42 +118,6 @@ def __len__(self):
"""Get number of model steps (match interface of MultiStepGan)"""
return 1

@classmethod
def load(cls, features, s_enhance, verbose=False, **kwargs):
    """Initialize the model from its constructor arguments.

    This method does not read any files itself: the inputs are forwarded
    to the class constructor and the new instance is returned.

    Parameters
    ----------
    features : list
        List of feature names that this model will operate on for both
        input and output. This must match the feature axis ordering in the
        array input to generate(). Typically this is a list containing:
        temperature_*m, relativehumidity_*m, and pressure_*m. The list can
        contain multiple instances of each variable at different heights.
        relativehumidity_*m entries must have corresponding temperature_*m
        entries at the same hub height.
    s_enhance : int
        Integer factor by which the spatial axes are to be enhanced.
    verbose : bool
        Flag to log the meta data of the newly created model.
    kwargs : None | dict
        Optional kwargs passed through to the class constructor.

    Returns
    -------
    out : SurfaceSpatialMetModel
        Instance created by calling cls(features, s_enhance, **kwargs)
    """
    instance = cls(features, s_enhance, **kwargs)

    if not verbose:
        return instance

    logger.info('Loading SurfaceSpatialMetModel with meta data: {}'
                .format(instance.meta))
    return instance

@staticmethod
def _get_s_enhance(topo_lr, topo_hr):
"""Get the spatial enhancement factor given low-res and high-res
@@ -227,8 +206,39 @@ def _get_temp_rh_ind(self, idf_rh):

return idf_temp

def _fix_downscaled_bias(self, single_lr, single_hr,
                         method=Image.Resampling.LANCZOS):
    """Remove the low-resolution bias left in a downscaled raster.

    The high-res data is coarsened back by the model's s_enhance factor,
    compared against the low-res input, and the difference is interpolated
    to high resolution and subtracted from the high-res data.

    Parameters
    ----------
    single_lr : np.ndarray
        Single timestep raster data with shape
        (lat, lon) matching the low-resolution input data.
    single_hr : np.ndarray
        Single timestep downscaled raster data with shape
        (lat, lon) matching the high-resolution input data. This array is
        updated in place.
    method : Image.Resampling
        An Image.Resampling method (NEAREST, BILINEAR, BICUBIC, LANCZOS)
        used to interpolate the bias to high resolution. LANCZOS is the
        default. NEAREST enforces zero bias but makes slightly more
        spatial seams.

    Returns
    -------
    single_hr : np.ndarray
        The single_hr input array after the interpolated bias has been
        subtracted from it, with shape (lat, lon) matching the
        high-resolution input data.
    """

    # spatial_coarsening is given a trailing singleton axis, which is
    # dropped again before comparing to the low-res input
    hr_expanded = np.expand_dims(single_hr, axis=-1)
    coarsened = spatial_coarsening(hr_expanded,
                                   s_enhance=self._s_enhance,
                                   obs_axis=False)
    lr_bias = coarsened[..., 0] - single_lr

    hr_correction = self.downscale_arr(lr_bias, s_enhance=self._s_enhance,
                                       method=method)
    single_hr -= hr_correction

    return single_hr

@staticmethod
def downscale_arr(arr, s_enhance, method=Image.Resampling.BILINEAR):
def downscale_arr(arr, s_enhance, method=Image.Resampling.LANCZOS):
"""Downscale a 2D array of data Image.resize() method

Parameters
@@ -238,9 +248,9 @@ def downscale_arr(arr, s_enhance, method=Image.Resampling.BILINEAR):
(lat, lon)
s_enhance : int
Integer factor by which the spatial axes are to be enhanced.
method : Image.Resampling.BILINEAR
method : Image.Resampling.LANCZOS
An Image.Resampling method (NEAREST, BILINEAR, BICUBIC, LANCZOS).
BILINEAR is default and has been tested to work best for
LANCZOS is default and has been tested to work best for
SurfaceSpatialMetModel.
"""
im = Image.fromarray(arr)
@@ -284,9 +294,15 @@ def downscale_temp(self, single_lr_temp, topo_lr, topo_hr):
assert len(topo_hr.shape) == 2, 'Bad shape for topo_hr'

lower_data = single_lr_temp.copy() + topo_lr * self._temp_lapse
hi_res_temp = self.downscale_arr(lower_data, self._s_enhance)
hi_res_temp = self.downscale_arr(lower_data, self._s_enhance,
method=self._interp_method)
hi_res_temp -= topo_hr * self._temp_lapse

if self._fix_bias:
hi_res_temp = self._fix_downscaled_bias(single_lr_temp,
hi_res_temp,
method=self._interp_method)

return hi_res_temp

def downscale_rh(self, single_lr_rh, single_lr_temp, single_hr_temp,
@@ -336,9 +352,12 @@ def downscale_rh(self, single_lr_rh, single_lr_temp, single_hr_temp,
assert len(topo_lr.shape) == 2, 'Bad shape for topo_lr'
assert len(topo_hr.shape) == 2, 'Bad shape for topo_hr'

interp_rh = self.downscale_arr(single_lr_rh, self._s_enhance)
interp_temp = self.downscale_arr(single_lr_temp, self._s_enhance)
interp_topo = self.downscale_arr(topo_lr, self._s_enhance)
interp_rh = self.downscale_arr(single_lr_rh, self._s_enhance,
method=self._interp_method)
interp_temp = self.downscale_arr(single_lr_temp, self._s_enhance,
method=self._interp_method)
interp_topo = self.downscale_arr(topo_lr, self._s_enhance,
method=self._interp_method)

delta_temp = single_hr_temp - interp_temp
delta_topo = topo_hr - interp_topo
@@ -347,6 +366,10 @@ def downscale_rh(self, single_lr_rh, single_lr_temp, single_hr_temp,
+ self._w_delta_temp * delta_temp
+ self._w_delta_topo * delta_topo)

if self._fix_bias:
hi_res_rh = self._fix_downscaled_bias(single_lr_rh, hi_res_rh,
method=self._interp_method)

return hi_res_rh

def downscale_pres(self, single_lr_pres, topo_lr, topo_hr):
@@ -388,21 +411,28 @@ def downscale_pres(self, single_lr_pres, topo_lr, topo_hr):
warn(msg)

const = 101325 * (1 - (1 - topo_lr / self._pres_div)**self._pres_exp)
single_lr_pres = single_lr_pres.copy() + const
lr_pres_adj = single_lr_pres.copy() + const

if np.min(single_lr_pres) < 0.0:
if np.min(lr_pres_adj) < 0.0:
msg = ('Spatial interpolation of surface pressure '
'resulted in negative values. Incorrectly '
'scaled/unscaled values or incorrect units are '
'the most likely causes.')
'the most likely causes. All pressure data should be '
'in Pascals.')
logger.error(msg)
raise ValueError(msg)

hi_res_pres = self.downscale_arr(single_lr_pres, self._s_enhance)
hi_res_pres = self.downscale_arr(lr_pres_adj, self._s_enhance,
method=self._interp_method)

const = 101325 * (1 - (1 - topo_hr / self._pres_div)**self._pres_exp)
hi_res_pres -= const

if self._fix_bias:
hi_res_pres = self._fix_downscaled_bias(single_lr_pres,
hi_res_pres,
method=self._interp_method)

if np.min(hi_res_pres) < 0.0:
msg = ('Spatial interpolation of surface pressure '
'resulted in negative values. Incorrectly '
@@ -524,25 +554,11 @@ def meta(self):
'pressure_exponent': self._pres_exp,
'training_features': self.training_features,
'output_features': self.output_features,
'interp_method': str(self._interp_method),
'fix_bias': self._fix_bias,
'class': self.__class__.__name__,
}

@property
def training_features(self):
    """List of input feature names for this model (the features list
    stored on the instance).

    Note that topography needs to be passed into generate() as an exogenous
    data input.
    """
    return self._features

@property
def output_features(self):
    """List of output feature names for this model (the features list
    stored on the instance)"""
    return self._features

def train(self, true_hr_temp, true_hr_rh, true_hr_topo):
"""This method trains the relative humidity linear model. The
temperature and surface lapse rate models are parameterizations taken
2 changes: 1 addition & 1 deletion sup3r/preprocessing/data_handling.py
Original file line number Diff line number Diff line change
@@ -1106,7 +1106,7 @@ def preflight(self):
msg = (f'sample_shape[2] ({self.sample_shape[2]}) cannot be larger '
'than the number of time steps in the raw data '
f'({len(self.raw_time_index)}).')
if len(self.raw_time_index) >= self.sample_shape[2]:
if len(self.raw_time_index) < self.sample_shape[2]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, fixed this in my current dev branch also.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was driving me insane haha

logger.warning(msg)
warnings.warn(msg)

38 changes: 37 additions & 1 deletion sup3r/utilities/loss_metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Loss metrics for Sup3r"""

from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError
import tensorflow as tf


@@ -171,3 +171,39 @@ def __call__(self, x1, x2):
x1_coarse = tf.reduce_mean(x1, axis=(1, 2))
x2_coarse = tf.reduce_mean(x2, axis=(1, 2))
return self.MSE_LOSS(x1_coarse, x2_coarse)


class TemporalExtremesLoss(tf.keras.losses.Loss):
    """Content loss that adds mean absolute error terms on the temporal
    minimum and maximum to the mean absolute error on the full data"""

    MAE_LOSS = MeanAbsoluteError()

    def __call__(self, x1, x2):
        """Sum of three mean absolute error terms: on the full tensors, on
        their minima along axis 3, and on their maxima along axis 3.

        Parameters
        ----------
        x1 : tf.tensor
            synthetic generator output
            (n_observations, spatial_1, spatial_2, temporal, features)
        x2 : tf.tensor
            high resolution data
            (n_observations, spatial_1, spatial_2, temporal, features)

        Returns
        -------
        tf.tensor
            0D tensor with loss value
        """
        # start from the plain MAE, then add the min term followed by the
        # max term
        loss = self.MAE_LOSS(x1, x2)
        for reduce_fn in (tf.reduce_min, tf.reduce_max):
            loss = loss + self.MAE_LOSS(reduce_fn(x1, axis=3),
                                        reduce_fn(x2, axis=3))

        return loss
2 changes: 1 addition & 1 deletion sup3r/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""SUP3R Version"""

__version__ = '0.0.8'
__version__ = '0.0.9'
Loading