@@ -60,19 +60,24 @@ def __init__(self, steps):
60
60
61
61
Parameters
62
62
----------
63
- steps : list | dict
64
- List of SingleExoDataStep objects or a feature dictionary with list
65
- of steps for each feature
63
+ steps : dict
64
+ Dictionary with feature keys each with entries describing whether
65
+ features should be combined at input, a mid network layer, or with
66
+ output. e.g.
67
+ {'topography': {'steps': [
68
+ {'combine_type': 'input', 'model': 0, 'data': ...,
69
+ 'resolution': ...},
70
+ {'combine_type': 'layer', 'model': 0, 'data': ...,
71
+ 'resolution': ...}]}}
72
+ Each array in in 'data' key has 3D or 4D shape:
73
+ (spatial_1, spatial_2, 1)
74
+ (spatial_1, spatial_2, n_temporal, 1)
66
75
"""
67
- if isinstance (steps , list ):
68
- for step in steps :
69
- self .append (step .feature , step )
70
- elif isinstance (steps , dict ):
76
+ if isinstance (steps , dict ):
71
77
for k , v in steps .items ():
72
78
self .__setitem__ (k , v )
73
79
else :
74
- msg = ('ExoData must be initialized with a dictionary of features '
75
- 'or list of SingleExoDataStep objects.' )
80
+ msg = 'ExoData must be initialized with a dictionary of features.'
76
81
logger .error (msg )
77
82
raise ValueError (msg )
78
83
@@ -117,18 +122,6 @@ def split_exo_dict(self, split_step):
117
122
spatial models and temporal models split_step should be
118
123
len(spatial_models). If this is for a TemporalThenSpatial model
119
124
split_step should be len(temporal_models).
120
- exogenous_data : dict
121
- Dictionary of exogenous feature data with entries describing
122
- whether features should be combined at input, a mid network layer,
123
- or with output. e.g.
124
- {'topography': {'steps': [
125
- {'combine_type': 'input', 'model': 0, 'data': ...,
126
- 'resolution': ...},
127
- {'combine_type': 'layer', 'model': 0, 'data': ...,
128
- 'resolution': ...}]}}
129
- Each array in in 'data' key has 3D or 4D shape:
130
- (spatial_1, spatial_2, 1)
131
- (spatial_1, spatial_2, n_temporal, 1)
132
125
133
126
Returns
134
127
-------
@@ -306,7 +299,7 @@ def __init__(self,
306
299
self .input_handler = input_handler
307
300
self .cache_data = cache_data
308
301
self .cache_dir = cache_dir
309
- self .data = []
302
+ self .data = { feature : { 'steps' : []}}
310
303
311
304
self .input_check ()
312
305
agg_enhance = self ._get_all_agg_and_enhancement ()
@@ -341,16 +334,16 @@ def __init__(self,
341
334
t_agg_factor = t_agg_factor )
342
335
step = SingleExoDataStep (feature , steps [i ]['combine_type' ],
343
336
steps [i ]['model' ], data )
344
- self .data .append (step )
337
+ self .data [ feature ][ 'steps' ] .append (step )
345
338
else :
346
339
msg = (f"Can only extract { list (self .AVAILABLE_HANDLERS )} ."
347
340
f" Received { feature } ." )
348
341
raise NotImplementedError (msg )
349
- shapes = [None if d is None else d [ 'data' ] .shape
350
- for d in self .data ]
342
+ shapes = [None if step is None else step .shape
343
+ for step in self .data [ feature ][ 'steps' ] ]
351
344
logger .info (
352
345
'Got exogenous_data of length {} with shapes: {}' .format (
353
- len (self .data ), shapes ))
346
+ len (self .data [ feature ][ 'steps' ] ), shapes ))
354
347
355
348
def input_check (self ):
356
349
"""Make sure agg factors are provided or exo_resolution and models are
0 commit comments