Skip to content

Commit b8ebeeb

Browse files
authored
Merge pull request #240 from NREL/gb/bc_kwargs
Gb/bc kwargs
2 parents 9ac7518 + 6e46d6c commit b8ebeeb

File tree

8 files changed

+190
-62
lines changed

8 files changed

+190
-62
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ classifiers=[
2727
"Programming Language :: Python :: 3.11",
2828
]
2929
dependencies = [
30-
"NREL-rex>=0.2.90",
30+
"NREL-rex>=0.2.91",
3131
"NREL-phygnn>=0.0.23",
3232
"NREL-gaps>=0.6.13",
3333
"NREL-farms>=1.0.4",

sup3r/bias/base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,12 @@ def _read_base_sup3r_data(dh, base_dset, base_gid):
621621
gid_raster = np.arange(len(dh.meta))
622622
gid_raster = gid_raster.reshape(dh.shape[:2])
623623
idy, idx = np.where(np.isin(gid_raster, base_gid))
624-
base_data = dh.data[[base_dset]][idy, idx].squeeze(axis=-1)
624+
if dh.data.loaded:
625+
# faster direct access of numpy array if loaded
626+
base_data = dh.data[base_dset].data[idy, idx]
627+
else:
628+
base_data = dh.data[base_dset].data.vindex[idy, idx]
629+
625630
assert base_data.shape[0] == len(base_gid)
626631
assert base_data.shape[1] == len(dh.time_index)
627632
return base_data.mean(axis=0)

sup3r/bias/bias_transforms.py

+104-46
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414

1515
import dask.array as da
1616
import numpy as np
17-
import pandas as pd
1817
from rex.utilities.bc_utils import QuantileDeltaMapping
1918
from scipy.ndimage import gaussian_filter
2019

2120
from sup3r.preprocessing import Rasterizer
21+
from sup3r.preprocessing.utilities import make_time_index_from_kws
2222

2323
logger = logging.getLogger(__name__)
2424

@@ -27,7 +27,7 @@ def _get_factors(target, shape, var_names, bias_fp, threshold=0.1):
2727
"""Get bias correction factors from sup3r's standard resource
2828
2929
This was stripped without any change from original
30-
`get_spatial_bc_factors` to allow re-use in other `*_bc_factors`
30+
`_get_spatial_bc_factors` to allow re-use in other `*_bc_factors`
3131
functions.
3232
3333
Parameters
@@ -76,7 +76,7 @@ def _get_factors(target, shape, var_names, bias_fp, threshold=0.1):
7676
return out
7777

7878

79-
def get_spatial_bc_factors(lat_lon, feature_name, bias_fp, threshold=0.1):
79+
def _get_spatial_bc_factors(lat_lon, feature_name, bias_fp, threshold=0.1):
8080
"""Get bc factors (scalar/adder) for the given feature for the given
8181
domain (specified by lat_lon).
8282
@@ -114,7 +114,7 @@ def get_spatial_bc_factors(lat_lon, feature_name, bias_fp, threshold=0.1):
114114
)
115115

116116

117-
def get_spatial_bc_quantiles(
117+
def _get_spatial_bc_quantiles(
118118
lat_lon: Union[np.ndarray, da.core.Array],
119119
base_dset: str,
120120
feature_name: str,
@@ -200,7 +200,7 @@ def get_spatial_bc_quantiles(
200200
>>> lat_lon = np.array([
201201
... [39.649033, -105.46875 ],
202202
... [39.649033, -104.765625]])
203-
>>> params = get_spatial_bc_quantiles(
203+
>>> params = _get_spatial_bc_quantiles(
204204
... lat_lon, "ghi", "rsds", "./dist_params.hdf")
205205
206206
"""
@@ -297,7 +297,7 @@ def local_linear_bc(
297297
out = data * scalar + adder
298298
"""
299299

300-
out = get_spatial_bc_factors(lat_lon, feature_name, bias_fp)
300+
out = _get_spatial_bc_factors(lat_lon, feature_name, bias_fp)
301301
scalar, adder = out['scalar'], out['adder']
302302
# 3D bias correction factors have seasonal/monthly correction in last axis
303303
if len(scalar.shape) == 3 and len(adder.shape) == 3:
@@ -402,8 +402,8 @@ def monthly_local_linear_bc(
402402
out : np.ndarray
403403
out = data * scalar + adder
404404
"""
405-
time_index = pd.date_range(**date_range_kwargs)
406-
out = get_spatial_bc_factors(lat_lon, feature_name, bias_fp)
405+
time_index = make_time_index_from_kws(date_range_kwargs)
406+
out = _get_spatial_bc_factors(lat_lon, feature_name, bias_fp)
407407
scalar, adder = out['scalar'], out['adder']
408408

409409
assert len(scalar.shape) == 3, 'Monthly bias correct needs 3D scalars'
@@ -471,6 +471,7 @@ def local_qdm_bc(
471471
no_trend=False,
472472
delta_denom_min=None,
473473
delta_denom_zero=None,
474+
delta_range=None,
474475
out_range=None,
475476
max_workers=1,
476477
):
@@ -536,6 +537,11 @@ def local_qdm_bc(
536537
division by a very small number making delta blow up and resulting
537538
in very large output bias corrected values. See equation 4 of
538539
Cannon et al., 2015 for the delta term.
540+
delta_range : tuple | None
541+
Option to set a (min, max) on the delta term in QDM. This can help
542+
prevent QDM from making non-realistic increases/decreases in
543+
otherwise physical values. See equation 4 of Cannon et al., 2015 for
544+
the delta term.
539545
out_range : None | tuple
540546
Option to set floor/ceiling values on the output data.
541547
max_workers: int | None
@@ -583,12 +589,15 @@ def local_qdm_bc(
583589
584590
"""
585591
# Confirm that the given time matches the expected data size
586-
time_index = pd.date_range(**date_range_kwargs)
587-
assert (
588-
data.shape[2] == time_index.size
589-
), 'Time should align with data 3rd dimension'
590-
591-
params = get_spatial_bc_quantiles(
592+
msg = f'data was expected to be a 3D array but got shape {data.shape}'
593+
assert data.ndim == 3, msg
594+
time_index = make_time_index_from_kws(date_range_kwargs)
595+
msg = (f'Time should align with data 3rd dimension but got data '
596+
f'{data.shape} and time_index length '
597+
f'{time_index.size}: {time_index}')
598+
assert data.shape[-1] == time_index.size, msg
599+
600+
params = _get_spatial_bc_quantiles(
592601
lat_lon=lat_lon,
593602
base_dset=base_dset,
594603
feature_name=feature_name,
@@ -635,6 +644,7 @@ def local_qdm_bc(
635644
log_base=cfg['log_base'],
636645
delta_denom_min=delta_denom_min,
637646
delta_denom_zero=delta_denom_zero,
647+
delta_range=delta_range,
638648
)
639649

640650
subset_idx = nearest_window_idx == window_idx
@@ -654,10 +664,17 @@ def local_qdm_bc(
654664
output = np.maximum(output, np.min(out_range))
655665
output = np.minimum(output, np.max(out_range))
656666

667+
if np.isnan(output).any():
668+
msg = ('Presrat bias correction resulted in NaN values! If this is a '
669+
'relative QDM, you may try setting ``delta_denom_min`` or '
670+
'``delta_denom_zero``')
671+
logger.error(msg)
672+
raise RuntimeError(msg)
673+
657674
return output
658675

659676

660-
def get_spatial_bc_presrat(
677+
def _get_spatial_bc_presrat(
661678
lat_lon: np.array,
662679
base_dset: str,
663680
feature_name: str,
@@ -766,7 +783,7 @@ def get_spatial_bc_presrat(
766783
>>> lat_lon = np.array([
767784
... [39.649033, -105.46875 ],
768785
... [39.649033, -104.765625]])
769-
>>> params = get_spatial_bc_quantiles(
786+
>>> params = _get_spatial_bc_quantiles(
770787
... lat_lon, "ghi", "rsds", "./dist_params.hdf")
771788
772789
"""
@@ -788,12 +805,12 @@ def get_spatial_bc_presrat(
788805
)
789806

790807

791-
def apply_presrat_bc(data, time_index, base_params, bias_params,
792-
bias_fut_params, bias_tau_fut, k_factor,
793-
time_window_center, dist='empirical', sampling='invlog',
794-
log_base=10, relative=True, no_trend=False,
795-
zero_rate_threshold=1.157e-7, out_range=None,
796-
max_workers=1):
808+
def _apply_presrat_bc(data, time_index, base_params, bias_params,
809+
bias_fut_params, bias_tau_fut, k_factor,
810+
time_window_center, dist='empirical', sampling='invlog',
811+
log_base=10, relative=True, no_trend=False,
812+
delta_denom_min=None, delta_range=None, out_range=None,
813+
max_workers=1):
797814
"""Run PresRat to bias correct data from input parameters and not from bias
798815
correction file on disk.
799816
@@ -868,13 +885,18 @@ def apply_presrat_bc(data, time_index, base_params, bias_params,
868885
:class:`rex.utilities.bc_utils.QuantileDeltaMapping`. Note that this
869886
assumes that params_mh is the data distribution representative for the
870887
target data.
871-
zero_rate_threshold : float, default=1.157e-7
872-
Threshold value used to determine the zero rate in the observed
873-
historical dataset and the minimum value in the denominator in relative
874-
QDM. For instance, 0.01 means that anything less than that will be
875-
considered negligible, hence equal to zero. Dai 2006 defined this as
876-
1mm/day. Pierce 2015 used 0.01mm/day. We recommend 0.01mm/day
877-
(1.157e-7 kg/m2/s).
888+
delta_denom_min : float | None
889+
Option to specify a minimum value for the denominator term in the
890+
calculation of a relative delta value. This prevents division by a
891+
very small number making delta blow up and resulting in very large
892+
output bias corrected values. See equation 4 of Cannon et al., 2015
893+
for the delta term. If this is not set, the ``zero_rate_threshold``
894+
calculated as part of the presrat bias calculation will be used
895+
delta_range : tuple | None
896+
Option to set a (min, max) on the delta term in QDM. This can help
897+
prevent QDM from making non-realistic increases/decreases in
898+
otherwise physical values. See equation 4 of Cannon et al., 2015 for
899+
the delta term.
878900
out_range : None | tuple
879901
Option to set floor/ceiling values on the output data.
880902
max_workers : int | None
@@ -904,7 +926,8 @@ def apply_presrat_bc(data, time_index, base_params, bias_params,
904926
relative=relative,
905927
sampling=sampling,
906928
log_base=log_base,
907-
delta_denom_min=zero_rate_threshold,
929+
delta_denom_min=delta_denom_min,
930+
delta_range=delta_range,
908931
)
909932

910933
# input 3D shape (spatial, spatial, temporal)
@@ -928,6 +951,13 @@ def apply_presrat_bc(data, time_index, base_params, bias_params,
928951
data_unbiased = np.maximum(data_unbiased, np.min(out_range))
929952
data_unbiased = np.minimum(data_unbiased, np.max(out_range))
930953

954+
if np.isnan(data_unbiased).any():
955+
msg = ('Presrat bias correction resulted in NaN values! If this is a '
956+
'relative QDM, you may try setting ``delta_denom_min`` or '
957+
'``delta_denom_zero``')
958+
logger.error(msg)
959+
raise RuntimeError(msg)
960+
931961
return data_unbiased
932962

933963

@@ -941,6 +971,9 @@ def local_presrat_bc(data: np.ndarray,
941971
threshold=0.1,
942972
relative=True,
943973
no_trend=False,
974+
delta_denom_min=None,
975+
delta_range=None,
976+
k_range=None,
944977
out_range=None,
945978
max_workers=1,
946979
):
@@ -996,18 +1029,34 @@ def local_presrat_bc(data: np.ndarray,
9961029
:class:`rex.utilities.bc_utils.QuantileDeltaMapping`. Note that this
9971030
assumes that params_mh is the data distribution representative for the
9981031
target data.
1032+
delta_denom_min : float | None
1033+
Option to specify a minimum value for the denominator term in the
1034+
calculation of a relative delta value. This prevents division by a
1035+
very small number making delta blow up and resulting in very large
1036+
output bias corrected values. See equation 4 of Cannon et al., 2015
1037+
for the delta term. If this is not set, the ``zero_rate_threshold``
1038+
calculated as part of the presrat bias calculation will be used
1039+
delta_range : tuple | None
1040+
Option to set a (min, max) on the delta term in QDM. This can help
1041+
prevent QDM from making non-realistic increases/decreases in
1042+
otherwise physical values. See equation 4 of Cannon et al., 2015 for
1043+
the delta term.
1044+
k_range : tuple | None
1045+
Option to set a (min, max) value for the k-factor multiplier
9991046
out_range : None | tuple
10001047
Option to set floor/ceiling values on the output data.
10011048
max_workers : int | None
10021049
Max number of workers to use for QDM process pool
10031050
"""
1004-
time_index = pd.date_range(**date_range_kwargs)
1005-
assert data.ndim == 3, 'data was expected to be a 3D array'
1006-
assert (
1007-
data.shape[-1] == time_index.size
1008-
), 'The last dimension of data should be time'
1009-
1010-
params = get_spatial_bc_presrat(
1051+
time_index = make_time_index_from_kws(date_range_kwargs)
1052+
msg = f'data was expected to be a 3D array but got shape {data.shape}'
1053+
assert data.ndim == 3, msg
1054+
msg = (f'Time should align with data 3rd dimension but got data '
1055+
f'{data.shape} and time_index length '
1056+
f'{time_index.size}: {time_index}')
1057+
assert data.shape[-1] == time_index.size, msg
1058+
1059+
params = _get_spatial_bc_presrat(
10111060
lat_lon, base_dset, feature_name, bias_fp, threshold
10121061
)
10131062
cfg = params['cfg']
@@ -1022,21 +1071,30 @@ def local_presrat_bc(data: np.ndarray,
10221071
sampling = cfg['sampling']
10231072
log_base = cfg['log_base']
10241073
zero_rate_threshold = cfg['zero_rate_threshold']
1074+
delta_denom_min = delta_denom_min or zero_rate_threshold
1075+
1076+
if k_range is not None:
1077+
k_factor = np.maximum(k_factor, np.min(k_range))
1078+
k_factor = np.minimum(k_factor, np.max(k_range))
1079+
1080+
logger.debug(f'Presrat K Factor has shape {k_factor.shape} and ranges '
1081+
f'from {k_factor.min()} to {k_factor.max()}')
10251082

10261083
if lr_padded_slice is not None:
10271084
spatial_slice = (lr_padded_slice[0], lr_padded_slice[1])
10281085
base_params = base_params[spatial_slice]
10291086
bias_params = bias_params[spatial_slice]
10301087
bias_fut_params = bias_fut_params[spatial_slice]
10311088

1032-
data_unbiased = apply_presrat_bc(data, time_index, base_params,
1033-
bias_params, bias_fut_params,
1034-
bias_tau_fut, k_factor,
1035-
time_window_center, dist=dist,
1036-
sampling=sampling, log_base=log_base,
1037-
relative=relative, no_trend=no_trend,
1038-
zero_rate_threshold=zero_rate_threshold,
1039-
out_range=out_range,
1040-
max_workers=max_workers)
1089+
data_unbiased = _apply_presrat_bc(data, time_index, base_params,
1090+
bias_params, bias_fut_params,
1091+
bias_tau_fut, k_factor,
1092+
time_window_center, dist=dist,
1093+
sampling=sampling, log_base=log_base,
1094+
relative=relative, no_trend=no_trend,
1095+
delta_denom_min=delta_denom_min,
1096+
delta_range=delta_range,
1097+
out_range=out_range,
1098+
max_workers=max_workers)
10411099

10421100
return data_unbiased

sup3r/bias/utilities.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from rex import Resource
1010

1111
import sup3r.bias.bias_transforms
12-
from sup3r.bias.bias_transforms import get_spatial_bc_factors, local_qdm_bc
12+
from sup3r.bias.bias_transforms import _get_spatial_bc_factors, local_qdm_bc
1313
from sup3r.preprocessing.utilities import (
1414
_parse_time_slice,
1515
get_date_range_kwargs,
@@ -56,7 +56,7 @@ def lin_bc(handler, bc_files, threshold=0.1):
5656
and dset_adder.lower() in dsets
5757
)
5858
if feature not in completed and check:
59-
out = get_spatial_bc_factors(
59+
out = _get_spatial_bc_factors(
6060
lat_lon=handler.lat_lon,
6161
feature_name=feature,
6262
bias_fp=fp,
@@ -268,11 +268,19 @@ def bias_correct_features(
268268

269269
time_slice = _parse_time_slice(time_slice)
270270
for feat in features:
271-
input_handler[feat][..., time_slice] = bias_correct_feature(
272-
source_feature=feat,
273-
input_handler=input_handler,
274-
time_slice=time_slice,
275-
bc_method=bc_method,
276-
bc_kwargs=bc_kwargs,
277-
)
271+
try:
272+
input_handler[feat][..., time_slice] = bias_correct_feature(
273+
source_feature=feat,
274+
input_handler=input_handler,
275+
time_slice=time_slice,
276+
bc_method=bc_method,
277+
bc_kwargs=bc_kwargs,
278+
)
279+
except Exception as e:
280+
msg = (f'Could not run bias correction method {bc_method} on '
281+
f'feature {feat} time slice {time_slice} with input '
282+
f'handler of class {type(input_handler)} with shape '
283+
f'{input_handler.shape}. Received error: {e}')
284+
logger.exception(msg)
285+
raise RuntimeError(msg) from e
278286
return input_handler

sup3r/pipeline/strategy.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,12 @@ def init_input_handler(self):
261261

262262
InputHandler = get_input_handler_class(self.input_handler_name)
263263
input_handler_kwargs = copy.deepcopy(self.input_handler_kwargs)
264-
features = [] if self.head_node else self.features
265-
input_handler_kwargs['features'] = features
264+
265+
input_handler_kwargs['features'] = self.features
266+
if self.head_node:
267+
input_handler_kwargs['features'] = []
268+
input_handler_kwargs['chunks'] = 'auto'
269+
266270
input_handler_kwargs['time_slice'] = slice(None)
267271

268272
return InputHandler(**input_handler_kwargs)

0 commit comments

Comments
 (0)