10
10
from dataclasses import dataclass
11
11
from functools import cached_property
12
12
from typing import Dict , Optional , Tuple , Union
13
+ from warnings import warn
13
14
14
15
import dask .array as da
15
16
import numpy as np
@@ -228,6 +229,18 @@ def __post_init__(self):
228
229
)
229
230
self .n_chunks = self .fwp_slicer .n_chunks
230
231
232
+ msg = (
233
+ 'The same exogenous data is used by all nodes, so it will be '
234
+ 'cached on the head_node. This can take a long time and might be '
235
+ 'worth doing as an independent preprocessing step instead.'
236
+ )
237
+ if self .head_node and not all (
238
+ os .path .exists (fp ) for fp in self .get_exo_cache_files (model )
239
+ ):
240
+ logger .warning (msg )
241
+ warn (msg )
242
+ _ = self .timer (self .load_exo_data , log = True )(model )
243
+
231
244
if not self .head_node :
232
245
hr_shape = self .hr_lat_lon .shape [:- 1 ]
233
246
self .gids = np .arange (np .prod (hr_shape )).reshape (hr_shape )
@@ -532,19 +545,9 @@ def init_chunk(self, chunk_index=0):
532
545
index = chunk_index ,
533
546
)
534
547
535
- def load_exo_data (self , model ):
536
- """Extract exogenous data for each exo feature and store data in
537
- dictionary with key for each exo feature
538
-
539
- Returns
540
- -------
541
- exo_data : ExoData
542
- :class:`ExoData` object composed of multiple
543
- :class:`SingleExoDataStep` objects. This is the exo data for the
544
- full spatiotemporal extent.
545
- """
546
- data = {}
547
- exo_data = None
548
+ def get_exo_kwargs (self , model ):
549
+ """Get list of exo kwargs for all exo features."""
550
+ exo_kwargs_list = []
548
551
if self .exo_handler_kwargs :
549
552
for feature in self .exo_features :
550
553
exo_kwargs = copy .deepcopy (self .exo_handler_kwargs [feature ])
@@ -558,8 +561,32 @@ def load_exo_data(self, model):
558
561
_ = input_handler_kwargs .pop ('time_slice' , None )
559
562
exo_kwargs ['input_handler_kwargs' ] = input_handler_kwargs
560
563
exo_kwargs = get_class_kwargs (ExoDataHandler , exo_kwargs )
561
- data .update (ExoDataHandler (** exo_kwargs ).data )
562
- exo_data = ExoData (data )
564
+ exo_kwargs_list .append (exo_kwargs )
565
+ return exo_kwargs_list
566
+
567
def get_exo_cache_files(self, model):
    """Collect the cache file paths reported by every exo data handler.

    One :class:`ExoDataHandler` is built per kwargs dictionary returned by
    ``self.get_exo_kwargs(model)`` and its ``cache_files`` entries are
    gathered, in order, into a single flat list so callers can check
    whether the files exist.

    Parameters
    ----------
    model
        Forwarded unchanged to ``self.get_exo_kwargs``.

    Returns
    -------
    list
        Flat list of cache file entries; empty when there are no exo
        kwargs.
    """
    return [
        cache_fp
        for handler_kwargs in self.get_exo_kwargs(model)
        for cache_fp in ExoDataHandler(**handler_kwargs).cache_files
    ]
573
+
574
def load_exo_data(self, model):
    """Extract exogenous data for each exo feature and store it in a
    dictionary with a key for each exo feature.

    Parameters
    ----------
    model
        Forwarded unchanged to ``self.get_exo_kwargs``.

    Returns
    -------
    exo_data : ExoData | None
        :class:`ExoData` object composed of multiple
        :class:`SingleExoDataStep` objects. This is the exo data for the
        full spatiotemporal extent. ``None`` when ``get_exo_kwargs``
        yields no entries.
    """
    combined = {}
    result = None
    for handler_kwargs in self.get_exo_kwargs(model):
        handler = ExoDataHandler(**handler_kwargs)
        combined.update(handler.data)
        # Re-wrap after every feature so the returned object always
        # reflects everything merged so far.
        result = ExoData(combined)
    return result
564
591
565
592
@cached_property
0 commit comments