Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bnb/train only features hot fix #171

Merged
merged 2 commits into from
Oct 18, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions sup3r/models/abstract.py
Original file line number Diff line number Diff line change
@@ -327,7 +327,7 @@ def _combine_loss_input(self, high_res_true, high_res_gen):
"""
if high_res_true.shape[-1] > high_res_gen.shape[-1]:
for feature in self.exogenous_features:
f_idx = self.training_features.index(feature)
f_idx = self.hr_features.index(feature)
exo_data = high_res_true[..., f_idx: f_idx + 1]
high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1)
return high_res_gen
@@ -360,6 +360,23 @@ def training_features(self):
trained on."""
return self.meta.get('training_features', None)

@property
def train_only_features(self):
    """List of feature names that are used only for training.

    These features are expected as model input but are not included in
    the model output. The value is read from the
    ``'train_only_features'`` entry of ``self.meta``.

    Returns
    -------
    list | None
        The stored feature names, or None if the entry is missing.
    """
    meta = self.meta
    return meta.get('train_only_features')

@property
def hr_features(self):
    """Get the list of features stored in batch.high_res.

    This is the same as ``training_features`` but without any
    ``train_only_features`` (original order preserved). It is used to
    select the correct high res exogenous data.

    Returns
    -------
    list | None
        ``training_features`` itself when there are no train-only features
        to remove, otherwise a new filtered list. None when
        ``training_features`` is None.
    """
    # Read each property once instead of re-evaluating it per element.
    training_features = self.training_features
    train_only_features = self.train_only_features

    # Nothing to filter, or nothing to filter from: return
    # training_features untouched instead of iterating over None.
    if training_features is None or train_only_features is None:
        return training_features

    return [f for f in training_features if f not in train_only_features]

@property
def output_features(self):
"""Get the list of output feature names that the generative model
@@ -445,7 +462,8 @@ def set_model_params(self, **kwargs):
kwargs = self._check_exo_features(**kwargs)

keys = ('input_resolution', 'training_features', 'output_features',
'smoothed_features', 's_enhance', 't_enhance', 'smoothing')
'train_only_features', 'smoothed_features', 's_enhance',
't_enhance', 'smoothing')
keys = [k for k in keys if k in kwargs]

for var in keys:
@@ -660,7 +678,6 @@ def norm_input(self, low_res):
warn(msg)
else:
stdevs = self._stdevs

low_res = (low_res.copy() - self._means) / stdevs

return low_res
@@ -819,7 +836,7 @@ def get_high_res_exo_input(self, high_res):
"""
exo_data = {}
for feature in self.exogenous_features:
f_idx = self.training_features.index(feature)
f_idx = self.hr_features.index(feature)
exo_fdata = high_res[..., f_idx: f_idx + 1]
exo_data[feature] = exo_fdata
return exo_data
1 change: 1 addition & 0 deletions sup3r/models/base.py
Original file line number Diff line number Diff line change
@@ -855,6 +855,7 @@ def train(self,
t_enhance=batch_handler.t_enhance,
smoothing=batch_handler.smoothing,
training_features=batch_handler.training_features,
train_only_features=batch_handler.train_only_features,
output_features=batch_handler.output_features,
smoothed_features=batch_handler.smoothed_features)

1 change: 1 addition & 0 deletions sup3r/models/conditional_moments.py
Original file line number Diff line number Diff line change
@@ -395,6 +395,7 @@ def train(self, batch_handler,
s_enhance=batch_handler.s_enhance,
t_enhance=batch_handler.t_enhance,
smoothing=batch_handler.smoothing,
train_only_features=batch_handler.train_only_features,
training_features=batch_handler.training_features,
output_features=batch_handler.output_features,
smoothed_features=batch_handler.smoothed_features)
22 changes: 18 additions & 4 deletions sup3r/preprocessing/batch_handling.py
Original file line number Diff line number Diff line change
@@ -13,11 +13,19 @@
from scipy.ndimage.filters import gaussian_filter

from sup3r.preprocessing.data_handling.h5_data_handling import (
DataHandlerDCforH5, )
DataHandlerDCforH5,
)
from sup3r.utilities.utilities import (
estimate_max_workers, nn_fill_array, nsrdb_reduce_daily_data, smooth_data,
spatial_coarsening, temporal_coarsening, uniform_box_sampler,
uniform_time_sampler, weighted_box_sampler, weighted_time_sampler,
estimate_max_workers,
nn_fill_array,
nsrdb_reduce_daily_data,
smooth_data,
spatial_coarsening,
temporal_coarsening,
uniform_box_sampler,
uniform_time_sampler,
weighted_box_sampler,
weighted_time_sampler,
)

np.random.seed(42)
@@ -577,6 +585,12 @@ def training_features(self):
data handlers"""
return self.data_handlers[0].features

@property
def train_only_features(self):
    """Ordered list of feature names that are used only for training and
    will not be stored in batch.high_res. Taken from the first entry of
    ``self.data_handlers``."""
    first_handler = self.data_handlers[0]
    return first_handler.train_only_features

@property
def output_features(self):
"""Get the ordered list of feature names held in this object's
2 changes: 1 addition & 1 deletion sup3r/preprocessing/data_handling/base.py
Original file line number Diff line number Diff line change
@@ -1656,7 +1656,7 @@ def _check_grid_extent(cls, target, grid_shape, lat_lon):
min_lon = np.min(lat_lon[..., 1])
max_lat = np.max(lat_lon[..., 0])
max_lon = np.max(lat_lon[..., 1])
logger.debug('Calculating raster index from WRF file '
logger.debug('Calculating raster index from NETCDF file '
f'for shape {grid_shape} and target {target}')
logger.debug(f'lat/lon (min, max): {min_lat}/{min_lon}, '
f'{max_lat}/{max_lon}')
5 changes: 5 additions & 0 deletions sup3r/preprocessing/data_handling/dual_data_handling.py
Original file line number Diff line number Diff line change
@@ -181,6 +181,11 @@ def output_features(self):
GAN"""
return self.hr_dh.output_features

@property
def train_only_features(self):
    """Features to use for training only and not for output, as reported
    by ``self.lr_dh``."""
    lr_handler = self.lr_dh
    return lr_handler.train_only_features

def _shape_check(self):
"""Check if hr_handler.shape is divisible by s_enhance. If not take
the largest shape that can be."""
35 changes: 28 additions & 7 deletions sup3r/preprocessing/data_handling/nc_data_handling.py
Original file line number Diff line number Diff line change
@@ -19,13 +19,34 @@

from sup3r.preprocessing.data_handling.base import DataHandler, DataHandlerDC
from sup3r.preprocessing.feature_handling import (
BVFreqMon, BVFreqSquaredNC, ClearSkyRatioCC, Feature, InverseMonNC,
LatLonNC, PotentialTempNC, PressureNC, Rews, Shear, Tas, TasMax, TasMin,
TempNC, TempNCforCC, UWind, UWindPowerLaw, VWind, VWindPowerLaw,
WinddirectionNC, WindspeedNC)
BVFreqMon,
BVFreqSquaredNC,
ClearSkyRatioCC,
Feature,
InverseMonNC,
LatLonNC,
PotentialTempNC,
PressureNC,
Rews,
Shear,
Tas,
TasMax,
TasMin,
TempNC,
TempNCforCC,
UWind,
UWindPowerLaw,
VWind,
VWindPowerLaw,
WinddirectionNC,
WindspeedNC,
)
from sup3r.utilities.interpolation import Interpolator
from sup3r.utilities.utilities import (estimate_max_workers, get_time_dim_name,
np_to_pd_times)
from sup3r.utilities.utilities import (
estimate_max_workers,
get_time_dim_name,
np_to_pd_times,
)

np.random.seed(42)

@@ -385,7 +406,7 @@ def _check_grid_extent(cls, target, grid_shape, lat_lon):
min_lon = np.min(lat_lon[..., 1])
max_lat = np.max(lat_lon[..., 0])
max_lon = np.max(lat_lon[..., 1])
logger.debug('Calculating raster index from WRF file '
logger.debug('Calculating raster index from NETCDF file '
f'for shape {grid_shape} and target {target}')
logger.debug(f'lat/lon (min, max): {min_lat}/{min_lon}, '
f'{max_lat}/{max_lon}')
96 changes: 95 additions & 1 deletion tests/training/test_train_gan_exo.py
Original file line number Diff line number Diff line change
@@ -28,13 +28,107 @@
TARGET_S = (39.01, -105.13)

INPUT_FILE_W = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
FEATURES_W = ['U_100m', 'V_100m', 'temperature_100m', 'topography']
FEATURES_W = ['temperature_100m', 'U_100m', 'V_100m', 'topography']
TARGET_W = (39.01, -105.15)

FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
TARGET_COORD = (39.01, -105.15)


@pytest.mark.parametrize('custom_layer', ['Sup3rAdder', 'Sup3rConcat'])
def test_wind_hi_res_topo_with_train_only(custom_layer, log=False):
    """Train and run a wind cc model that injects hi-res topography through
    the custom Sup3rAdder or Sup3rConcat layer in the middle of the network,
    with temperature_100m configured as a train only feature.

    Checks that the train only feature is excluded from the model's
    hr_features, that topography is not a model output feature, that
    generation without exogenous data raises, and that the generated output
    has the expected shape."""

    def conv_stage(n_filters):
        """Build fresh reflect-pad -> transposed conv -> crop configs."""
        return [
            {"class": "FlexiblePadding",
             "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]],
             "mode": "REFLECT"},
            {"class": "Conv2DTranspose", "filters": n_filters,
             "kernel_size": 3, "strides": 1, "activation": "relu"},
            {"class": "Cropping2D", "cropping": 4},
        ]

    handler_kwargs = {
        'target': TARGET_W,
        'shape': SHAPE,
        'temporal_slice': slice(None, None, 2),
        'time_roll': -7,
        'val_split': 0.1,
        'sample_shape': (20, 20),
        'worker_kwargs': {'max_workers': 1},
        'train_only_features': ['temperature_100m'],
    }
    handler = DataHandlerH5WindCC(INPUT_FILE_W, FEATURES_W, **handler_kwargs)
    batcher = SpatialBatchHandlerCC([handler], batch_size=2, n_batches=2,
                                    s_enhance=2)

    if log:
        init_logger('sup3r', log_level='DEBUG')

    # Three 64-filter stages, 2x spatial expansion, the exogenous topography
    # layer under test, then a final stage reducing to the 2 output channels.
    gen_model = [
        *conv_stage(64),
        *conv_stage(64),
        *conv_stage(64),
        {"class": "SpatialExpansion", "spatial_mult": 2},
        {"class": "Activation", "activation": "relu"},
        {"class": custom_layer, "name": "topography"},
        *conv_stage(2),
    ]

    fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json')

    Sup3rGan.seed()
    model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4)

    with tempfile.TemporaryDirectory() as td:
        out_dir = os.path.join(td, 'test_{epoch}')
        resolution = {'spatial': '16km', 'temporal': '3600min'}
        model.train(batcher,
                    input_resolution=resolution,
                    n_epoch=1,
                    weight_gen_advers=0.0,
                    train_gen=True, train_disc=False,
                    checkpoint_int=None,
                    out_dir=out_dir)

        assert model.train_only_features == ['temperature_100m']
        assert model.hr_features == ['U_100m', 'V_100m', 'topography']
        # Must be checked while the temporary directory still exists.
        assert 'test_0' in os.listdir(td)
        assert model.meta['output_features'] == ['U_100m', 'V_100m']
        assert model.meta['class'] == 'Sup3rGan'
        assert 'topography' in batcher.output_features
        assert 'topography' not in model.output_features

    x = np.random.uniform(0, 1, (4, 30, 30, 4))
    hi_res_topo = np.random.uniform(0, 1, (4, 60, 60, 1))

    # Generation without the exogenous topography input must raise.
    with pytest.raises(RuntimeError):
        y = model.generate(x, exogenous_data=None)

    topo_step = {'model': 0, 'combine_type': 'layer', 'data': hi_res_topo}
    exo_tmp = {'topography': {'steps': [topo_step]}}
    y = model.generate(x, exogenous_data=exo_tmp)

    n_obs, n_rows, n_cols, n_feats = x.shape[:4]
    assert y.shape[0] == n_obs
    assert y.shape[1] == n_rows * 2
    assert y.shape[2] == n_cols * 2
    assert y.shape[3] == n_feats - 2


@pytest.mark.parametrize('custom_layer', ['Sup3rAdder', 'Sup3rConcat'])
def test_wind_hi_res_topo(custom_layer, log=False):
"""Test a special wind cc model with the custom Sup3rAdder or Sup3rConcat