Skip to content

Commit 91fab51

Browse files
authored
Merge pull request #169 from NREL/bnb/exo_refactor_plus
Bnb/exo refactor plus
2 parents 4330268 + 53d8bc7 commit 91fab51

File tree

8 files changed

+273
-325
lines changed

8 files changed

+273
-325
lines changed

sup3r/models/abstract.py

+42-75
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from tensorflow.keras import optimizers
2222

2323
import sup3r.utilities.loss_metrics
24+
from sup3r.preprocessing.data_handling.exogenous_data_handling import ExoData
2425
from sup3r.utilities import VERSION_RECORD
2526

2627
logger = logging.getLogger(__name__)
@@ -223,14 +224,11 @@ def _combine_fwp_input(self, low_res, exogenous_data=None):
223224
Low-resolution input data, usually a 4D or 5D array of shape:
224225
(n_obs, spatial_1, spatial_2, n_features)
225226
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
226-
exogenous_data : dict | None
227-
Dictionary of exogenous feature data with entries describing
228-
whether features should be combined at input, a mid network layer,
229-
or with output. This doesn't have to include the 'model' key since
230-
this data is for a single step model. e.g.
231-
{'topography': {'steps': [
232-
{'combine_type': 'input', 'data': ..., 'resolution': ...},
233-
{'combine_type': 'layer', 'data': ..., 'resolution': ...}]}}
227+
exogenous_data : dict | ExoData | None
228+
Special dictionary (class:`ExoData`) of exogenous feature data with
229+
entries describing whether features should be combined at input, a
230+
mid network layer, or with output. This doesn't have to include
231+
the 'model' key since this data is for a single step model.
234232
235233
Returns
236234
-------
@@ -243,6 +241,10 @@ def _combine_fwp_input(self, low_res, exogenous_data=None):
243241
if exogenous_data is None:
244242
return low_res
245243

244+
if (not isinstance(exogenous_data, ExoData)
245+
and exogenous_data is not None):
246+
exogenous_data = ExoData(exogenous_data)
247+
246248
training_features = ([] if self.training_features is None
247249
else self.training_features)
248250
fnum_diff = len(training_features) - low_res.shape[-1]
@@ -253,14 +255,10 @@ def _combine_fwp_input(self, low_res, exogenous_data=None):
253255
assert all(feature in exogenous_data for feature in exo_feats), msg
254256
if exogenous_data is not None and fnum_diff > 0:
255257
for feature in exo_feats:
256-
entry = exogenous_data[feature]
257-
combine_types = [step['combine_type']
258-
for step in entry['steps']]
259-
if 'input' in combine_types:
260-
idx = combine_types.index('input')
261-
low_res = np.concatenate((low_res,
262-
entry['steps'][idx]['data']),
263-
axis=-1)
258+
exo_input = exogenous_data.get_combine_type_data(
259+
feature, 'input')
260+
if exo_input is not None:
261+
low_res = np.concatenate((low_res, exo_input), axis=-1)
264262
return low_res
265263

266264
def _combine_fwp_output(self, hi_res, exogenous_data=None):
@@ -273,14 +271,11 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None):
273271
High-resolution output data, usually a 4D or 5D array of shape:
274272
(n_obs, spatial_1, spatial_2, n_features)
275273
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
276-
exogenous_data : dict | None
277-
Dictionary of exogenous feature data with entries describing
278-
whether features should be combined at input, a mid network layer,
279-
or with output. This doesn't have to include the 'model' key since
280-
this data is for a single step model. e.g.
281-
{'topography': {'steps': [
282-
{'combine_type': 'input', 'data': ..., 'resolution': ...},
283-
{'combine_type': 'layer', 'data': ..., 'resolution': ...}]}}
274+
exogenous_data : dict | ExoData | None
275+
Special dictionary (class:`ExoData`) of exogenous feature data with
276+
entries describing whether features should be combined at input, a
277+
mid network layer, or with output. This doesn't have to include
278+
the 'model' key since this data is for a single step model.
284279
285280
Returns
286281
-------
@@ -293,6 +288,10 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None):
293288
if exogenous_data is None:
294289
return hi_res
295290

291+
if (not isinstance(exogenous_data, ExoData)
292+
and exogenous_data is not None):
293+
exogenous_data = ExoData(exogenous_data)
294+
296295
output_features = ([] if self.output_features is None
297296
else self.output_features)
298297
fnum_diff = len(output_features) - hi_res.shape[-1]
@@ -303,14 +302,10 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None):
303302
assert all(feature in exogenous_data for feature in exo_feats), msg
304303
if exogenous_data is not None and fnum_diff > 0:
305304
for feature in exo_feats:
306-
entry = exogenous_data[feature]
307-
combine_types = [step['combine_type']
308-
for step in entry['steps']]
309-
if 'output' in combine_types:
310-
idx = combine_types.index('output')
311-
hi_res = np.concatenate((hi_res,
312-
entry['steps'][idx]['data']),
313-
axis=-1)
305+
exo_output = exogenous_data.get_combine_type_data(
306+
feature, 'output')
307+
if exo_output is not None:
308+
hi_res = np.concatenate((hi_res, exo_output), axis=-1)
314309
return hi_res
315310

316311
def _combine_loss_input(self, high_res_true, high_res_gen):
@@ -1237,39 +1232,6 @@ def _reshape_norm_exo(self, hi_res, hi_res_exo, exo_name, norm_in=True):
12371232

12381233
return hi_res_exo
12391234

1240-
def _get_layer_exo_input(self, layer_name, exogenous_data):
1241-
"""Get the high-resolution exo data for the given layer name from the
1242-
full exogenous_data dictionary.
1243-
1244-
Parameters
1245-
----------
1246-
layer_name : str
1247-
Name of Sup3rAdder or Sup3rConcat layer. This should match a
1248-
feature key in exogenous_data
1249-
exogenous_data : dict | None
1250-
Dictionary of exogenous feature data with entries describing
1251-
whether features should be combined at input, a mid network layer,
1252-
or with output. This doesn't have to include the 'model' key since
1253-
this data is for a single step model. e.g.
1254-
{'topography': {'steps': [
1255-
{'combine_type': 'input', 'data': ..., 'resolution': ...},
1256-
{'combine_type': 'layer', 'data': ..., 'resolution': ...}]}}
1257-
1258-
"""
1259-
msg = (f'layer.name = {layer_name} does not match any '
1260-
'features in exogenous_data '
1261-
f'({list(exogenous_data)})')
1262-
assert layer_name in exogenous_data, msg
1263-
steps = exogenous_data[layer_name]['steps']
1264-
combine_types = [step['combine_type'] for step in steps]
1265-
msg = ('Received exogenous_data without any combine_type '
1266-
'= "layer" steps, for a model with an Adder/Concat '
1267-
'layer.')
1268-
assert 'layer' in combine_types, msg
1269-
idx = combine_types.index('layer')
1270-
hi_res_exo = steps[idx]['data']
1271-
return hi_res_exo
1272-
12731235
def generate(self,
12741236
low_res,
12751237
norm_in=True,
@@ -1292,14 +1254,11 @@ def generate(self,
12921254
un_norm_out : bool
12931255
Flag to un-normalize synthetically generated output data to physical
12941256
units
1295-
exogenous_data : dict | None
1296-
Dictionary of exogenous feature data with entries describing
1297-
whether features should be combined at input, a mid network layer,
1298-
or with output. This doesn't have to include the 'model' key since
1299-
this data is for a single step model. e.g.
1300-
{'topography': {'steps': [
1301-
{'combine_type': 'input', 'data': ..., 'resolution': ...},
1302-
{'combine_type': 'layer', 'data': ..., 'resolution': ...}]}}
1257+
exogenous_data : dict | ExoData | None
1258+
Special dictionary (class:`ExoData`) of exogenous feature data with
1259+
entries describing whether features should be combined at input, a
1260+
mid network layer, or with output. This doesn't have to include
1261+
the 'model' key since this data is for a single step model.
13031262
13041263
Returns
13051264
-------
@@ -1309,6 +1268,10 @@ def generate(self,
13091268
(n_obs, spatial_1, spatial_2, n_features)
13101269
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
13111270
"""
1271+
if (not isinstance(exogenous_data, ExoData)
1272+
and exogenous_data is not None):
1273+
exogenous_data = ExoData(exogenous_data)
1274+
13121275
low_res = self._combine_fwp_input(low_res, exogenous_data)
13131276
if norm_in and self._means is not None:
13141277
low_res = self.norm_input(low_res)
@@ -1317,8 +1280,12 @@ def generate(self,
13171280
for i, layer in enumerate(self.generator.layers[1:]):
13181281
try:
13191282
if isinstance(layer, (Sup3rAdder, Sup3rConcat)):
1320-
hi_res_exo = self._get_layer_exo_input(layer.name,
1321-
exogenous_data)
1283+
msg = (f'layer.name = {layer.name} does not match any '
1284+
'features in exogenous_data '
1285+
f'({list(exogenous_data)})')
1286+
assert layer.name in exogenous_data, msg
1287+
hi_res_exo = exogenous_data.get_combine_type_data(
1288+
layer.name, 'layer')
13221289
hi_res_exo = self._reshape_norm_exo(hi_res,
13231290
hi_res_exo,
13241291
layer.name,

0 commit comments

Comments
 (0)