21
21
from tensorflow .keras import optimizers
22
22
23
23
import sup3r .utilities .loss_metrics
24
+ from sup3r .preprocessing .data_handling .exogenous_data_handling import ExoData
24
25
from sup3r .utilities import VERSION_RECORD
25
26
26
27
logger = logging .getLogger (__name__ )
@@ -223,14 +224,11 @@ def _combine_fwp_input(self, low_res, exogenous_data=None):
223
224
Low-resolution input data, usually a 4D or 5D array of shape:
224
225
(n_obs, spatial_1, spatial_2, n_features)
225
226
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
226
- exogenous_data : dict | None
227
- Dictionary of exogenous feature data with entries describing
228
- whether features should be combined at input, a mid network layer,
229
- or with output. This doesn't have to include the 'model' key since
230
- this data is for a single step model. e.g.
231
- {'topography': {'steps': [
232
- {'combine_type': 'input', 'data': ..., 'resolution': ...},
233
- {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}}
227
+ exogenous_data : dict | ExoData | None
228
+ Special dictionary (class:`ExoData`) of exogenous feature data with
229
+ entries describing whether features should be combined at input, a
230
+ mid network layer, or with output. This doesn't have to include
231
+ the 'model' key since this data is for a single step model.
234
232
235
233
Returns
236
234
-------
@@ -243,6 +241,10 @@ def _combine_fwp_input(self, low_res, exogenous_data=None):
243
241
if exogenous_data is None :
244
242
return low_res
245
243
244
+ if (not isinstance (exogenous_data , ExoData )
245
+ and exogenous_data is not None ):
246
+ exogenous_data = ExoData (exogenous_data )
247
+
246
248
training_features = ([] if self .training_features is None
247
249
else self .training_features )
248
250
fnum_diff = len (training_features ) - low_res .shape [- 1 ]
@@ -253,14 +255,10 @@ def _combine_fwp_input(self, low_res, exogenous_data=None):
253
255
assert all (feature in exogenous_data for feature in exo_feats ), msg
254
256
if exogenous_data is not None and fnum_diff > 0 :
255
257
for feature in exo_feats :
256
- entry = exogenous_data [feature ]
257
- combine_types = [step ['combine_type' ]
258
- for step in entry ['steps' ]]
259
- if 'input' in combine_types :
260
- idx = combine_types .index ('input' )
261
- low_res = np .concatenate ((low_res ,
262
- entry ['steps' ][idx ]['data' ]),
263
- axis = - 1 )
258
+ exo_input = exogenous_data .get_combine_type_data (
259
+ feature , 'input' )
260
+ if exo_input is not None :
261
+ low_res = np .concatenate ((low_res , exo_input ), axis = - 1 )
264
262
return low_res
265
263
266
264
def _combine_fwp_output (self , hi_res , exogenous_data = None ):
@@ -273,14 +271,11 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None):
273
271
High-resolution output data, usually a 4D or 5D array of shape:
274
272
(n_obs, spatial_1, spatial_2, n_features)
275
273
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
276
- exogenous_data : dict | None
277
- Dictionary of exogenous feature data with entries describing
278
- whether features should be combined at input, a mid network layer,
279
- or with output. This doesn't have to include the 'model' key since
280
- this data is for a single step model. e.g.
281
- {'topography': {'steps': [
282
- {'combine_type': 'input', 'data': ..., 'resolution': ...},
283
- {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}}
274
+ exogenous_data : dict | ExoData | None
275
+ Special dictionary (class:`ExoData`) of exogenous feature data with
276
+ entries describing whether features should be combined at input, a
277
+ mid network layer, or with output. This doesn't have to include
278
+ the 'model' key since this data is for a single step model.
284
279
285
280
Returns
286
281
-------
@@ -293,6 +288,10 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None):
293
288
if exogenous_data is None :
294
289
return hi_res
295
290
291
+ if (not isinstance (exogenous_data , ExoData )
292
+ and exogenous_data is not None ):
293
+ exogenous_data = ExoData (exogenous_data )
294
+
296
295
output_features = ([] if self .output_features is None
297
296
else self .output_features )
298
297
fnum_diff = len (output_features ) - hi_res .shape [- 1 ]
@@ -303,14 +302,10 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None):
303
302
assert all (feature in exogenous_data for feature in exo_feats ), msg
304
303
if exogenous_data is not None and fnum_diff > 0 :
305
304
for feature in exo_feats :
306
- entry = exogenous_data [feature ]
307
- combine_types = [step ['combine_type' ]
308
- for step in entry ['steps' ]]
309
- if 'output' in combine_types :
310
- idx = combine_types .index ('output' )
311
- hi_res = np .concatenate ((hi_res ,
312
- entry ['steps' ][idx ]['data' ]),
313
- axis = - 1 )
305
+ exo_output = exogenous_data .get_combine_type_data (
306
+ feature , 'output' )
307
+ if exo_output is not None :
308
+ hi_res = np .concatenate ((hi_res , exo_output ), axis = - 1 )
314
309
return hi_res
315
310
316
311
def _combine_loss_input (self , high_res_true , high_res_gen ):
@@ -1237,39 +1232,6 @@ def _reshape_norm_exo(self, hi_res, hi_res_exo, exo_name, norm_in=True):
1237
1232
1238
1233
return hi_res_exo
1239
1234
1240
- def _get_layer_exo_input (self , layer_name , exogenous_data ):
1241
- """Get the high-resolution exo data for the given layer name from the
1242
- full exogenous_data dictionary.
1243
-
1244
- Parameters
1245
- ----------
1246
- layer_name : str
1247
- Name of Sup3rAdder or Sup3rConcat layer. This should match a
1248
- feature key in exogenous_data
1249
- exogenous_data : dict | None
1250
- Dictionary of exogenous feature data with entries describing
1251
- whether features should be combined at input, a mid network layer,
1252
- or with output. This doesn't have to include the 'model' key since
1253
- this data is for a single step model. e.g.
1254
- {'topography': {'steps': [
1255
- {'combine_type': 'input', 'data': ..., 'resolution': ...},
1256
- {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}}
1257
-
1258
- """
1259
- msg = (f'layer.name = { layer_name } does not match any '
1260
- 'features in exogenous_data '
1261
- f'({ list (exogenous_data )} )' )
1262
- assert layer_name in exogenous_data , msg
1263
- steps = exogenous_data [layer_name ]['steps' ]
1264
- combine_types = [step ['combine_type' ] for step in steps ]
1265
- msg = ('Received exogenous_data without any combine_type '
1266
- '= "layer" steps, for a model with an Adder/Concat '
1267
- 'layer.' )
1268
- assert 'layer' in combine_types , msg
1269
- idx = combine_types .index ('layer' )
1270
- hi_res_exo = steps [idx ]['data' ]
1271
- return hi_res_exo
1272
-
1273
1235
def generate (self ,
1274
1236
low_res ,
1275
1237
norm_in = True ,
@@ -1292,14 +1254,11 @@ def generate(self,
1292
1254
un_norm_out : bool
1293
1255
Flag to un-normalize synthetically generated output data to physical
1294
1256
units
1295
- exogenous_data : dict | None
1296
- Dictionary of exogenous feature data with entries describing
1297
- whether features should be combined at input, a mid network layer,
1298
- or with output. This doesn't have to include the 'model' key since
1299
- this data is for a single step model. e.g.
1300
- {'topography': {'steps': [
1301
- {'combine_type': 'input', 'data': ..., 'resolution': ...},
1302
- {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}}
1257
+ exogenous_data : dict | ExoData | None
1258
+ Special dictionary (class:`ExoData`) of exogenous feature data with
1259
+ entries describing whether features should be combined at input, a
1260
+ mid network layer, or with output. This doesn't have to include
1261
+ the 'model' key since this data is for a single step model.
1303
1262
1304
1263
Returns
1305
1264
-------
@@ -1309,6 +1268,10 @@ def generate(self,
1309
1268
(n_obs, spatial_1, spatial_2, n_features)
1310
1269
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
1311
1270
"""
1271
+ if (not isinstance (exogenous_data , ExoData )
1272
+ and exogenous_data is not None ):
1273
+ exogenous_data = ExoData (exogenous_data )
1274
+
1312
1275
low_res = self ._combine_fwp_input (low_res , exogenous_data )
1313
1276
if norm_in and self ._means is not None :
1314
1277
low_res = self .norm_input (low_res )
@@ -1317,8 +1280,12 @@ def generate(self,
1317
1280
for i , layer in enumerate (self .generator .layers [1 :]):
1318
1281
try :
1319
1282
if isinstance (layer , (Sup3rAdder , Sup3rConcat )):
1320
- hi_res_exo = self ._get_layer_exo_input (layer .name ,
1321
- exogenous_data )
1283
+ msg = (f'layer.name = { layer .name } does not match any '
1284
+ 'features in exogenous_data '
1285
+ f'({ list (exogenous_data )} )' )
1286
+ assert layer .name in exogenous_data , msg
1287
+ hi_res_exo = exogenous_data .get_combine_type_data (
1288
+ layer .name , 'layer' )
1322
1289
hi_res_exo = self ._reshape_norm_exo (hi_res ,
1323
1290
hi_res_exo ,
1324
1291
layer .name ,
0 commit comments