Commit 298e551 (parent 9d6363c)

Exogenous data caching is now done on the head node, after first checking for existing cache files. Every node uses the same exogenous data, so it doesn't make sense to have each node try to cache that data itself.
File tree: 14 files changed, +110 -49 lines


.github/workflows/codecov.yml (+2 -1)

@@ -2,7 +2,8 @@ name: Codecov
 
 on:
   push:
-    branches: [main, master]
+    branches: [main]
+  workflow_dispatch:
 
 jobs:
   run:

.github/workflows/gh_pages.yml (+2 -1)

@@ -2,7 +2,8 @@ name: Documentation
 
 on:
   push:
-    branches: [main, master]
+    branches: [main]
+  workflow_dispatch:
 
 jobs:
   make-pages:

.github/workflows/release_drafter.yml (+1 -1)

@@ -2,7 +2,7 @@ name: Release Drafter
 
 on:
   push:
-    branches: [main, master]
+    branches: [main]
 
 jobs:
   update_release_draft:

docs/source/conf.py (+1)

@@ -62,6 +62,7 @@
     "sphinx.ext.napoleon",
     "sphinx_autosummary_accessors",
     "sphinx_copybutton",
+    "pygments_lexer"
 ]
 
 intersphinx_mapping = {

sup3r/bias/bias_transforms.py (+12)

@@ -350,6 +350,7 @@ def monthly_local_linear_bc(
     temporal_avg=True,
     out_range=None,
     smoothing=0,
+    range_kwargs=None
 ):
     """Bias correct data using a simple monthly *scalar +adder method on a
     site-by-site basis.
@@ -396,6 +397,9 @@ def monthly_local_linear_bc(
         effect of extreme values within aggregations over large number of
         pixels. This value is the standard deviation for the gaussian_filter
         kernel.
+    range_kwargs : dict | None
+        Dictionary of ranges for scalar and adder values. e.g.
+        ``{'scalar': (0, 3), 'adder': (-2, 2)}``
 
     Returns
     -------
@@ -450,6 +454,14 @@ def monthly_local_linear_bc(
             adder[..., idt], smoothing, mode='nearest'
         )
 
+    if range_kwargs is not None:
+        scalar_range = range_kwargs.get('scalar', (-np.inf, np.inf))
+        adder_range = range_kwargs.get('adder', (-np.inf, np.inf))
+        scalar = np.minimum(scalar, np.max(scalar_range))
+        scalar = np.maximum(scalar, np.min(scalar_range))
+        adder = np.minimum(adder, np.max(adder_range))
+        adder = np.maximum(adder, np.min(adder_range))
+
     out = data * scalar + adder
     if out_range is not None:
         out = np.maximum(out, np.min(out_range))
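To make the new clipping behavior concrete, here is a minimal standalone sketch of what the ``range_kwargs`` handling does, using toy 1D values (the shapes and numbers are hypothetical; the real ``scalar``/``adder`` arrays are per-site, per-month):

```python
import numpy as np

# Toy 1D stand-ins; the real scalar/adder arrays are multi-dimensional.
scalar = np.array([-1.0, 0.5, 4.2])
adder = np.array([-3.0, 0.0, 2.5])

# Mirrors the diff: missing keys fall back to (-inf, inf), i.e. no clipping.
range_kwargs = {'scalar': (0, 3), 'adder': (-2, 2)}
scalar_range = range_kwargs.get('scalar', (-np.inf, np.inf))
adder_range = range_kwargs.get('adder', (-np.inf, np.inf))

# The paired np.minimum/np.maximum calls in the diff are equivalent to np.clip.
scalar = np.clip(scalar, np.min(scalar_range), np.max(scalar_range))
adder = np.clip(adder, np.min(adder_range), np.max(adder_range))

print(scalar)  # [0.  0.5 3. ]
print(adder)   # [-2.  0.  2.]
```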

sup3r/pipeline/slicer.py (+3 -2)

@@ -25,8 +25,9 @@ class ForwardPassSlicer:
     time_steps : int
         Number of time steps for full temporal domain of low res data. This
        is used to construct a dummy_time_index from np.arange(time_steps)
-    time_slice : slice
-        Slice to use to extract range from time_index
+    time_slice : slice | list
+        Slice to use to extract range from time_index. Can be a ``slice(start,
+        stop, step)`` or a list ``[start, stop, step]``
     chunk_shape : tuple
         Max shape (spatial_1, spatial_2, temporal) of an unpadded coarse
         chunk to use for a forward pass. The number of nodes that the
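Since ``time_slice`` may now arrive as a list, a minimal sketch of the normalization one would expect (this helper is illustrative, not part of the diff):

```python
def _parse_time_slice(time_slice):
    """Normalize a [start, stop, step] list to a builtin slice."""
    if isinstance(time_slice, list):
        return slice(*time_slice)
    return time_slice if time_slice is not None else slice(None)

assert _parse_time_slice([0, 100, 2]) == slice(0, 100, 2)
assert _parse_time_slice(slice(10, 20)) == slice(10, 20)
```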

sup3r/pipeline/strategy.py (+42 -15)

@@ -10,6 +10,7 @@
 from dataclasses import dataclass
 from functools import cached_property
 from typing import Dict, Optional, Tuple, Union
+from warnings import warn
 
 import dask.array as da
 import numpy as np
@@ -228,6 +229,18 @@ def __post_init__(self):
         )
         self.n_chunks = self.fwp_slicer.n_chunks
 
+        msg = (
+            'The same exogenous data is used by all nodes, so it will be '
+            'cached on the head_node. This can take a long time and might be '
+            'worth doing as an independent preprocessing step instead.'
+        )
+        if self.head_node and not all(
+            os.path.exists(fp) for fp in self.get_exo_cache_files(model)
+        ):
+            logger.warning(msg)
+            warn(msg)
+            _ = self.timer(self.load_exo_data, log=True)(model)
+
         if not self.head_node:
             hr_shape = self.hr_lat_lon.shape[:-1]
             self.gids = np.arange(np.prod(hr_shape)).reshape(hr_shape)
@@ -532,19 +545,9 @@ def init_chunk(self, chunk_index=0):
             index=chunk_index,
         )
 
-    def load_exo_data(self, model):
-        """Extract exogenous data for each exo feature and store data in
-        dictionary with key for each exo feature
-
-        Returns
-        -------
-        exo_data : ExoData
-            :class:`ExoData` object composed of multiple
-            :class:`SingleExoDataStep` objects. This is the exo data for the
-            full spatiotemporal extent.
-        """
-        data = {}
-        exo_data = None
+    def get_exo_kwargs(self, model):
+        """Get list of exo kwargs for all exo features."""
+        exo_kwargs_list = []
         if self.exo_handler_kwargs:
             for feature in self.exo_features:
                 exo_kwargs = copy.deepcopy(self.exo_handler_kwargs[feature])
@@ -558,8 +561,32 @@ def load_exo_data(self, model):
                 _ = input_handler_kwargs.pop('time_slice', None)
                 exo_kwargs['input_handler_kwargs'] = input_handler_kwargs
                 exo_kwargs = get_class_kwargs(ExoDataHandler, exo_kwargs)
-                data.update(ExoDataHandler(**exo_kwargs).data)
-            exo_data = ExoData(data)
+                exo_kwargs_list.append(exo_kwargs)
+        return exo_kwargs_list
+
+    def get_exo_cache_files(self, model):
+        """Get list of exo cache files so we can check if they exist or not."""
+        cache_files = []
+        for exo_kwargs in self.get_exo_kwargs(model):
+            cache_files.extend(ExoDataHandler(**exo_kwargs).cache_files)
+        return cache_files
+
+    def load_exo_data(self, model):
+        """Extract exogenous data for each exo feature and store data in
+        dictionary with key for each exo feature
+
+        Returns
+        -------
+        exo_data : ExoData
+            :class:`ExoData` object composed of multiple
+            :class:`SingleExoDataStep` objects. This is the exo data for the
+            full spatiotemporal extent.
+        """
+        data = {}
+        exo_data = None
+        for exo_kwargs in self.get_exo_kwargs(model):
+            data.update(ExoDataHandler(**exo_kwargs).data)
+            exo_data = ExoData(data)
         return exo_data
 
     @cached_property
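The shape of the new head-node flow, pulled out of ``__post_init__`` for readability: only the head node checks for and builds the exo cache, and only when at least one cache file is missing. A simplified sketch (``strategy`` and ``model`` are assumed in scope; this is an illustration, not the actual method):

```python
import os

def ensure_exo_cache(strategy, model):
    """Build the exo cache once, on the head node only."""
    cache_files = strategy.get_exo_cache_files(model)
    missing = [fp for fp in cache_files if not os.path.exists(fp)]
    if strategy.head_node and missing:
        # All nodes consume the same exogenous data, so caching it once
        # here avoids N nodes redundantly rasterizing (and racing on)
        # the same files.
        strategy.load_exo_data(model)
```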

sup3r/preprocessing/batch_queues/abstract.py (+3 -5)

@@ -12,9 +12,9 @@
 import time
 from abc import ABC, abstractmethod
 from collections import namedtuple
+from concurrent.futures import ThreadPoolExecutor
 from typing import TYPE_CHECKING, List, Optional, Union
 
-import dask
 import numpy as np
 import tensorflow as tf
 
@@ -244,15 +244,13 @@ def enqueue_batches(self) -> None:
             if needed == 1 or self.max_workers == 1:
                 self.enqueue_batch()
             elif needed > 0:
-                tasks = [
-                    dask.delayed(self.enqueue_batch)() for _ in range(needed)
-                ]
+                with ThreadPoolExecutor(self.max_workers) as exe:
+                    _ = [exe.submit(self.enqueue_batch) for _ in range(needed)]
                 logger.debug(
                     'Added %s enqueue futures to %s queue.',
                     needed,
                     self._thread_name,
                 )
-                dask.compute(*tasks)
             if time.time() > log_time + 10:
                 logger.debug(self.log_queue_info())
                 log_time = time.time()
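The replacement keeps the same blocking fan-out semantics: exiting the ``ThreadPoolExecutor`` context joins all submitted work, just as ``dask.compute(*tasks)`` blocked on the delayed task list. A self-contained sketch of the pattern:

```python
from concurrent.futures import ThreadPoolExecutor
import time

def enqueue_batch(i):
    """Stand-in for the queue's real enqueue_batch method."""
    time.sleep(0.05)
    return i

needed, max_workers = 4, 2
with ThreadPoolExecutor(max_workers) as exe:
    futures = [exe.submit(enqueue_batch, i) for i in range(needed)]
# The with-block exit waits for all futures, mirroring dask.compute.
print([f.result() for f in futures])  # [0, 1, 2, 3]
```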

sup3r/preprocessing/cachers/base.py (+13 -6)

@@ -59,10 +59,15 @@ def __init__(
             of dictionaries for each feature (or a single dictionary to use
             for all features). e.g.
             .. code-block:: JSON
-                {'cache_pattern': ...,
-                 'chunks': {
-                     'u_10m': {'time': 20, 'south_north': 100, 'west_east': 100}}
-                }
+                {'cache_pattern': ...,
+                 'chunks': {
+                     'u_10m': {
+                         'time': 20,
+                         'south_north': 100,
+                         'west_east': 100
+                     }
+                 }
+                }
 
         Note
         ----
@@ -414,8 +419,10 @@ def write_netcdf(
         features : str | list
             Names of feature(s) to write to file.
         chunks : dict | None
-            Chunk sizes for coordinate dimensions. e.g. ``{'windspeed':
-            {'south_north': 100, 'west_east': 100, 'time': 10}}``
+            Chunk sizes for coordinate dimensions. e.g. ``{'south_north': 100,
+            'west_east': 100, 'time': 10}``. Can also include dataset-specific
+            values. e.g. ``{'windspeed': {'south_north': 100, 'west_east': 100,
+            'time': 10}}``
         max_workers : int | None
             Number of workers to use for parallel writing of chunks
         mode : str
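A hedged sketch of how a flat vs. feature-keyed ``chunks`` spec could be resolved for a single feature (the helper name is illustrative; the actual resolution logic is not shown in this diff):

```python
def resolve_chunks(chunks, feature):
    """Pick the chunk mapping for one feature from a flat or nested spec."""
    if chunks is None:
        return None
    # Nested, dataset-specific form: {'windspeed': {'south_north': 100, ...}}
    if isinstance(chunks.get(feature), dict):
        return chunks[feature]
    # Flat form applies to every feature: {'south_north': 100, ...}
    return chunks

flat = {'south_north': 100, 'west_east': 100, 'time': 10}
nested = {'windspeed': {'south_north': 50, 'west_east': 50, 'time': 20}}
assert resolve_chunks(flat, 'windspeed') == flat
assert resolve_chunks(nested, 'windspeed') == nested['windspeed']
```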

sup3r/preprocessing/data_handlers/exo.py (+19 -5)

@@ -250,7 +250,9 @@ def get_chunk(self, lr_slices):
                 if k == 'data':
                     # last dimension is feature channel, so we use only the
                     # spatial slices if data is 2d and all slices otherwise
-                    chunk_step[k] = v[tuple(exo_slices)[:len(v.shape) - 1]]
+                    chunk_step[k] = v[
+                        tuple(exo_slices)[: len(v.shape) - 1]
+                    ]
                 else:
                     chunk_step[k] = v
             exo_chunk[feature]['steps'].append(chunk_step)
@@ -380,9 +382,8 @@ def get_exo_steps(cls, feature, models):
                 steps.append({'model': i, 'combine_type': 'output'})
         return steps
 
-    def get_single_step_data(self, s_enhance, t_enhance):
-        """Get exo data for a single model step, with specific enhancement
-        factors."""
+    def get_exo_rasterizer(self, s_enhance, t_enhance):
+        """Get exo rasterizer instance for given enhancement factors"""
         return ExoRasterizer(
             file_paths=self.file_paths,
             source_file=self.source_file,
@@ -394,7 +395,20 @@ def get_single_step_data(self, s_enhance, t_enhance):
             cache_dir=self.cache_dir,
             chunks=self.chunks,
             distance_upper_bound=self.distance_upper_bound,
-        ).data
+        )
+
+    def get_single_step_data(self, s_enhance, t_enhance):
+        """Get exo data for a single model step, with specific enhancement
+        factors."""
+        return self.get_exo_rasterizer(s_enhance, t_enhance).data
+
+    @property
+    def cache_files(self):
+        """Get exo data cache file for all enhancement factors"""
+        return [
+            self.get_exo_rasterizer(s_en, t_en).cache_file
+            for s_en, t_en in zip(self.s_enhancements, self.t_enhancements)
+        ]
 
     def get_all_step_data(self):
         """Get exo data for each model step."""

sup3r/preprocessing/loaders/base.py (+2 -2)

@@ -8,7 +8,6 @@
 from typing import Callable
 
 import numpy as np
-import xarray as xr
 
 from sup3r.preprocessing.base import Container
 from sup3r.preprocessing.names import FEATURE_NAMES
@@ -17,6 +16,7 @@
     log_args,
     ordered_dims,
 )
+from sup3r.utilities.utilities import xr_open_mfdataset
 
 from .utilities import (
     lower_names,
@@ -35,7 +35,7 @@ class BaseLoader(Container, ABC):
     by :class:`~sup3r.preprocessing.rasterizers.Rasterizer` objects to derive /
     extract specific features / regions / time_periods."""
 
-    BASE_LOADER: Callable = xr.open_mfdataset
+    BASE_LOADER: Callable = xr_open_mfdataset
 
     @log_args
     def __init__(

sup3r/preprocessing/rasterizers/exo.py (+5 -8)

@@ -143,22 +143,19 @@ def source_handler(self):
         )
         return self._source_handler
 
-    def get_cache_file(self, feature):
+    @property
+    def cache_file(self):
         """Get cache file name
 
-        Parameters
-        ----------
-        feature : str
-            Name of feature to get cache file for
-
         Returns
         -------
         cache_fp : str
             Name of cache file. This is a netcdf file which will be saved with
             :class:`~sup3r.preprocessing.cachers.Cacher` and loaded with
             :class:`~sup3r.preprocessing.loaders.Loader`
         """
-        fn = f'exo_{feature}_{"_".join(map(str, self.input_handler.target))}_'
+        fn = f'exo_{self.feature}_'
+        fn += f'{"_".join(map(str, self.input_handler.target))}_'
         fn += f'{"x".join(map(str, self.input_handler.grid_shape))}_'
 
         if len(self.source_data.shape) == 3:
@@ -278,8 +275,8 @@ def data(self):
         """Get a raster of source values corresponding to the
         high-resolution grid (the file_paths input grid * s_enhance *
         t_enhance). The shape is (lats, lons, temporal, 1)"""
-        cache_fp = self.get_cache_file(feature=self.feature)
 
+        cache_fp = self.cache_file
         if os.path.exists(cache_fp):
             data = Loader(cache_fp)
         else:
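Moving from ``get_cache_file(feature)`` to a ``cache_file`` property works because each ``ExoRasterizer`` already knows its own feature. A rough, illustrative reconstruction of the naming scheme (the enhancement-factor suffix and file extension built later in the full source are omitted here):

```python
def exo_cache_name(feature, target, grid_shape):
    """Illustrative sketch of the deterministic cache file name prefix."""
    fn = f'exo_{feature}_'
    fn += f'{"_".join(map(str, target))}_'
    fn += f'{"x".join(map(str, grid_shape))}_'
    return fn

print(exo_cache_name('topography', (39.0, -105.0), (20, 20)))
# exo_topography_39.0_-105.0_20x20_
```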

sup3r/preprocessing/samplers/base.py (+3 -3)

@@ -15,7 +15,7 @@
     uniform_box_sampler,
     uniform_time_sampler,
 )
-from sup3r.preprocessing.utilities import log_args, lowered
+from sup3r.preprocessing.utilities import compute_if_dask, log_args, lowered
 
 logger = logging.getLogger(__name__)
 
@@ -195,9 +195,9 @@ def _reshape_samples(self, samples):
             new_shape[-1],
         ]
         # (lats, lons, batch_size, times, feats)
-        out = samples.reshape(new_shape)
+        out = np.reshape(samples, new_shape)
         # (batch_size, lats, lons, times, feats)
-        return np.asarray(out.transpose((2, 0, 1, 3, 4)))
+        return compute_if_dask(np.transpose(out, axes=(2, 0, 1, 3, 4)))
 
     def _stack_samples(self, samples):
         """Used to build batch arrays in the case of independent time samples

sup3r/utilities/utilities.py (+2)

@@ -58,6 +58,8 @@ def xr_open_mfdataset(files, **kwargs):
     """Wrapper for xr.open_mfdataset with default opening options."""
     default_kwargs = {'engine': 'netcdf4'}
     default_kwargs.update(kwargs)
+    if isinstance(files, str):
+        files = [files]
     try:
         return xr.open_mfdataset(files, **default_kwargs)
     except Exception as e:
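With the new guard, a bare path string is normalized to a one-element list before reaching ``xr.open_mfdataset``, so single-file and multi-file calls go down the same code path. Hedged usage (the file path is hypothetical):

```python
from sup3r.utilities.utilities import xr_open_mfdataset

# Both forms are now handled identically inside the wrapper.
ds_a = xr_open_mfdataset('/tmp/era5_2020.nc')    # hypothetical path
ds_b = xr_open_mfdataset(['/tmp/era5_2020.nc'])
```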
