Skip to content

Commit 6ef5591

Browse files
committed
PR updates
1 parent 40155ae commit 6ef5591

File tree

3 files changed

+33
-32
lines changed

3 files changed

+33
-32
lines changed

sup3r/models/abstract.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def _combine_fwp_input(self, low_res, exogenous_data=None):
224224
Low-resolution input data, usually a 4D or 5D array of shape:
225225
(n_obs, spatial_1, spatial_2, n_features)
226226
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
227-
exogenous_data : ExoData | None
227+
exogenous_data : dict | ExoData | None
228228
Special dictionary (class:`ExoData`) of exogenous feature data with
229229
entries describing whether features should be combined at input, a
230230
mid network layer, or with output. This doesn't have to include
@@ -241,6 +241,10 @@ def _combine_fwp_input(self, low_res, exogenous_data=None):
241241
if exogenous_data is None:
242242
return low_res
243243

244+
if (not isinstance(exogenous_data, ExoData)
245+
and exogenous_data is not None):
246+
exogenous_data = ExoData(exogenous_data)
247+
244248
training_features = ([] if self.training_features is None
245249
else self.training_features)
246250
fnum_diff = len(training_features) - low_res.shape[-1]
@@ -267,7 +271,7 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None):
267271
High-resolution output data, usually a 4D or 5D array of shape:
268272
(n_obs, spatial_1, spatial_2, n_features)
269273
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
270-
exogenous_data : dict | None
274+
exogenous_data : dict | ExoData | None
271275
Special dictionary (class:`ExoData`) of exogenous feature data with
272276
entries describing whether features should be combined at input, a
273277
mid network layer, or with output. This doesn't have to include
@@ -284,6 +288,10 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None):
284288
if exogenous_data is None:
285289
return hi_res
286290

291+
if (not isinstance(exogenous_data, ExoData)
292+
and exogenous_data is not None):
293+
exogenous_data = ExoData(exogenous_data)
294+
287295
output_features = ([] if self.output_features is None
288296
else self.output_features)
289297
fnum_diff = len(output_features) - hi_res.shape[-1]
@@ -1260,8 +1268,8 @@ def generate(self,
12601268
(n_obs, spatial_1, spatial_2, n_features)
12611269
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
12621270
"""
1263-
if (isinstance(exogenous_data, dict)
1264-
and not isinstance(exogenous_data, ExoData)):
1271+
if (not isinstance(exogenous_data, ExoData)
1272+
and exogenous_data is not None):
12651273
exogenous_data = ExoData(exogenous_data)
12661274

12671275
low_res = self._combine_fwp_input(low_res, exogenous_data)

sup3r/pipeline/forward_pass.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,7 @@ def load_exo_data(self):
11281128
class:`ExoData` object composed of multiple
11291129
class:`SingleExoDataStep` objects.
11301130
"""
1131-
data = []
1131+
data = {}
11321132
exo_data = None
11331133
if self.exo_kwargs:
11341134
self.features = [f for f in self.features
@@ -1144,7 +1144,7 @@ def load_exo_data(self):
11441144
sig = signature(ExogenousDataHandler)
11451145
exo_kwargs = {k: v for k, v in exo_kwargs.items()
11461146
if k in sig.parameters}
1147-
data += ExogenousDataHandler(**exo_kwargs).data
1147+
data.update(ExogenousDataHandler(**exo_kwargs).data)
11481148
exo_data = ExoData(data)
11491149
return exo_data
11501150

sup3r/preprocessing/data_handling/exogenous_data_handling.py

+19-26
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,24 @@ def __init__(self, steps):
6060
6161
Parameters
6262
----------
63-
steps : list | dict
64-
List of SingleExoDataStep objects or a feature dictionary with list
65-
of steps for each feature
63+
steps : dict
64+
Dictionary with feature keys each with entries describing whether
65+
features should be combined at input, a mid network layer, or with
66+
output. e.g.
67+
{'topography': {'steps': [
68+
{'combine_type': 'input', 'model': 0, 'data': ...,
69+
'resolution': ...},
70+
{'combine_type': 'layer', 'model': 0, 'data': ...,
71+
'resolution': ...}]}}
72+
Each array in in 'data' key has 3D or 4D shape:
73+
(spatial_1, spatial_2, 1)
74+
(spatial_1, spatial_2, n_temporal, 1)
6675
"""
67-
if isinstance(steps, list):
68-
for step in steps:
69-
self.append(step.feature, step)
70-
elif isinstance(steps, dict):
76+
if isinstance(steps, dict):
7177
for k, v in steps.items():
7278
self.__setitem__(k, v)
7379
else:
74-
msg = ('ExoData must be initialized with a dictionary of features '
75-
'or list of SingleExoDataStep objects.')
80+
msg = 'ExoData must be initialized with a dictionary of features.'
7681
logger.error(msg)
7782
raise ValueError(msg)
7883

@@ -117,18 +122,6 @@ def split_exo_dict(self, split_step):
117122
spatial models and temporal models split_step should be
118123
len(spatial_models). If this is for a TemporalThenSpatial model
119124
split_step should be len(temporal_models).
120-
exogenous_data : dict
121-
Dictionary of exogenous feature data with entries describing
122-
whether features should be combined at input, a mid network layer,
123-
or with output. e.g.
124-
{'topography': {'steps': [
125-
{'combine_type': 'input', 'model': 0, 'data': ...,
126-
'resolution': ...},
127-
{'combine_type': 'layer', 'model': 0, 'data': ...,
128-
'resolution': ...}]}}
129-
Each array in in 'data' key has 3D or 4D shape:
130-
(spatial_1, spatial_2, 1)
131-
(spatial_1, spatial_2, n_temporal, 1)
132125
133126
Returns
134127
-------
@@ -306,7 +299,7 @@ def __init__(self,
306299
self.input_handler = input_handler
307300
self.cache_data = cache_data
308301
self.cache_dir = cache_dir
309-
self.data = []
302+
self.data = {feature: {'steps': []}}
310303

311304
self.input_check()
312305
agg_enhance = self._get_all_agg_and_enhancement()
@@ -341,16 +334,16 @@ def __init__(self,
341334
t_agg_factor=t_agg_factor)
342335
step = SingleExoDataStep(feature, steps[i]['combine_type'],
343336
steps[i]['model'], data)
344-
self.data.append(step)
337+
self.data[feature]['steps'].append(step)
345338
else:
346339
msg = (f"Can only extract {list(self.AVAILABLE_HANDLERS)}."
347340
f" Received {feature}.")
348341
raise NotImplementedError(msg)
349-
shapes = [None if d is None else d['data'].shape
350-
for d in self.data]
342+
shapes = [None if step is None else step.shape
343+
for step in self.data[feature]['steps']]
351344
logger.info(
352345
'Got exogenous_data of length {} with shapes: {}'.format(
353-
len(self.data), shapes))
346+
len(self.data[feature]['steps']), shapes))
354347

355348
def input_check(self):
356349
"""Make sure agg factors are provided or exo_resolution and models are

0 commit comments

Comments
 (0)