Skip to content

Commit 945ca3a

Browse files
committed
combined fill extend method and interior smoothing
1 parent 12c468f commit 945ca3a

File tree

3 files changed

+105
-34
lines changed

3 files changed

+105
-34
lines changed

sup3r/bias/bias_calc.py

+48-10
Original file line numberDiff line numberDiff line change
@@ -553,9 +553,15 @@ def _run_single(cls, bias_data, base_fps, bias_feature, base_dset,
553553
bias_feature, base_dset)
554554
return out
555555

556-
def fill_extend(self, out, smooth_extend):
556+
def fill_smooth_extend(self, out, fill_extend=True, smooth_extend=0,
557+
smooth_interior=0):
557558
"""Fill data extending beyond the base meta data extent by doing a
558-
nearest neighbor gap fill.
559+
nearest neighbor gap fill. Smooth interior and extended region with
560+
given smoothing values.
561+
Interior smoothing can reduce the effect of extreme values
562+
within aggregations over a large number of pixels.
563+
The interior is assumed to be defined by the region without nan values.
564+
The extended region is assumed to be the region with nan values.
559565
560566
Parameters
561567
----------
@@ -564,11 +570,20 @@ def fill_extend(self, out, smooth_extend):
564570
data and the scalar + adder factors to correct the biased data
565571
like: bias_data * scalar + adder. Each value is of shape
566572
(lat, lon, time).
573+
fill_extend : bool
574+
Whether to fill data extending beyond the base meta data with
575+
nearest neighbor values.
567576
smooth_extend : float
568577
Option to smooth the scalar/adder data outside of the spatial
569578
domain set by the threshold input. This alleviates the weird seams
570579
far from the domain of interest. This value is the standard
571580
deviation for the gaussian_filter kernel
581+
smooth_interior : float
582+
Value to use to smooth the scalar/adder data inside of the spatial
583+
domain set by the threshold input. This can reduce the effect of
584+
extreme values within aggregations over a large number of pixels.
585+
This value is the standard deviation for the gaussian_filter
586+
kernel.
572587
573588
Returns
574589
-------
@@ -581,12 +596,30 @@ def fill_extend(self, out, smooth_extend):
581596
for key, arr in out.items():
582597
nan_mask = np.isnan(arr[..., 0])
583598
for idt in range(self.NT):
584-
arr[..., idt] = nn_fill_array(arr[..., idt])
599+
600+
arr_smooth = arr[..., idt]
601+
602+
needs_fill = (fill_extend or smooth_extend > 0
603+
or smooth_interior > 0)
604+
605+
if needs_fill:
606+
arr_smooth = nn_fill_array(arr_smooth)
607+
608+
arr_smooth_int = arr_smooth_ext = arr_smooth
609+
585610
if smooth_extend > 0:
586-
arr_smooth = gaussian_filter(arr[..., idt],
587-
smooth_extend,
588-
mode='nearest')
589-
out[key][nan_mask, idt] = arr_smooth[nan_mask]
611+
arr_smooth_ext = gaussian_filter(arr_smooth_ext,
612+
smooth_extend,
613+
mode='nearest')
614+
615+
if smooth_interior > 0:
616+
arr_smooth_int = gaussian_filter(arr_smooth_int,
617+
smooth_interior,
618+
mode='nearest')
619+
620+
out[key][nan_mask, idt] = arr_smooth_ext[nan_mask]
621+
out[key][~nan_mask, idt] = arr_smooth_int[~nan_mask]
622+
590623
return out
591624

592625
def write_outputs(self, fp_out, out):
@@ -623,7 +656,8 @@ def write_outputs(self, fp_out, out):
623656
.format(fp_out))
624657

625658
def run(self, knn, threshold=0.6, fp_out=None, max_workers=None,
626-
daily_reduction='avg', fill_extend=True, smooth_extend=0):
659+
daily_reduction='avg', fill_extend=True, smooth_extend=0,
660+
smooth_interior=0):
627661
"""Run linear correction factor calculations for every site in the bias
628662
dataset
629663
@@ -654,6 +688,10 @@ def run(self, knn, threshold=0.6, fp_out=None, max_workers=None,
654688
domain set by the threshold input. This alleviates the weird seams
655689
far from the domain of interest. This value is the standard
656690
deviation for the gaussian_filter kernel
691+
smooth_interior : float
692+
Option to smooth the scalar/adder data within the valid spatial
693+
domain. This can reduce the effect of extreme values within
694+
aggregations over a large number of pixels.
657695
658696
Returns
659697
-------
@@ -732,8 +770,8 @@ def run(self, knn, threshold=0.6, fp_out=None, max_workers=None,
732770

733771
logger.info('Finished calculating bias correction factors.')
734772

735-
if fill_extend:
736-
out = self.fill_extend(out, smooth_extend)
773+
out = self.fill_smooth_extend(out, fill_extend, smooth_extend,
774+
smooth_interior)
737775

738776
self.write_outputs(fp_out, out)
739777

sup3r/utilities/utilities.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,22 @@
44
@author: bbenton
55
"""
66

7-
import numpy as np
8-
import logging
97
import glob
10-
from scipy import ndimage as nd
11-
from scipy.interpolate import RegularGridInterpolator
12-
from scipy.interpolate import interp1d
13-
from scipy.ndimage import zoom
14-
from scipy.ndimage.filters import gaussian_filter
15-
from fnmatch import fnmatch
8+
import logging
169
import os
1710
import re
11+
from fnmatch import fnmatch
1812
from warnings import warn
19-
import psutil
13+
14+
import numpy as np
2015
import pandas as pd
21-
from packaging import version
16+
import psutil
2217
import xarray as xr
18+
from packaging import version
19+
from scipy import ndimage as nd
20+
from scipy.interpolate import RegularGridInterpolator, interp1d
21+
from scipy.ndimage import zoom
22+
from scipy.ndimage.filters import gaussian_filter
2323

2424
np.random.seed(42)
2525

@@ -140,7 +140,7 @@ def get_chunk_slices(arr_size, chunk_size, index_slice=slice(None)):
140140

141141

142142
def get_raster_shape(raster_index):
143-
"""method to get shape of raster_index"""
143+
"""Method to get shape of raster_index"""
144144

145145
if any(isinstance(r, slice) for r in raster_index):
146146
shape = (raster_index[0].stop - raster_index[0].start,
@@ -182,7 +182,7 @@ def get_wrf_date_range(files):
182182

183183

184184
def uniform_box_sampler(data, shape):
185-
'''Extracts a sample cut from data.
185+
"""Extracts a sample cut from data.
186186
187187
Parameters
188188
----------
@@ -197,7 +197,7 @@ def uniform_box_sampler(data, shape):
197197
-------
198198
slices : list
199199
List of slices corresponding to row and col extent of arr sample
200-
'''
200+
"""
201201

202202
shape_1 = data.shape[0] if data.shape[0] < shape[0] else shape[0]
203203
shape_2 = data.shape[1] if data.shape[1] < shape[1] else shape[1]
@@ -299,7 +299,7 @@ def weighted_time_sampler(data, shape, weights):
299299

300300

301301
def uniform_time_sampler(data, shape):
302-
'''Extracts a temporal slice from data.
302+
"""Extracts a temporal slice from data.
303303
304304
Parameters
305305
----------
@@ -314,7 +314,7 @@ def uniform_time_sampler(data, shape):
314314
-------
315315
slice : slice
316316
time slice with size shape
317-
'''
317+
"""
318318
shape = data.shape[2] if data.shape[2] < shape else shape
319319
start = np.random.randint(0, data.shape[2] - shape + 1)
320320
stop = start + shape
@@ -996,7 +996,7 @@ def potential_temperature_difference(T_top, P_top, T_bottom, P_bottom):
996996
"""Potential temp difference calculation
997997
998998
Parameters
999-
---------
999+
----------
10001000
T_top : ndarray
10011001
Temperature at higher height. Used in the approximation of potential
10021002
temperature derivative
@@ -1023,7 +1023,7 @@ def potential_temperature_average(T_top, P_top, T_bottom, P_bottom):
10231023
"""Potential temp average calculation
10241024
10251025
Parameters
1026-
---------
1026+
----------
10271027
T_top : ndarray
10281028
Temperature at higher height. Used in the approximation of potential
10291029
temperature derivative

tests/training/test_bias_correction.py

+40-7
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
# -*- coding: utf-8 -*-
22
"""pytests bias correction calculations"""
3-
import h5py
43
import os
5-
import pytest
64
import tempfile
5+
6+
import h5py
77
import numpy as np
8+
import pytest
89
import xarray as xr
910

10-
from sup3r import TEST_DATA_DIR, CONFIG_DIR
11-
from sup3r.models import Sup3rGan
12-
from sup3r.qa.qa import Sup3rQa
13-
from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy
11+
from sup3r import CONFIG_DIR, TEST_DATA_DIR
1412
from sup3r.bias.bias_calc import LinearCorrection, MonthlyLinearCorrection
1513
from sup3r.bias.bias_transforms import local_linear_bc, monthly_local_linear_bc
16-
14+
from sup3r.models import Sup3rGan
15+
from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy
16+
from sup3r.qa.qa import Sup3rQa
1717

1818
FP_NSRDB = os.path.join(TEST_DATA_DIR, 'test_nsrdb_co_2018.h5')
1919
FP_CC = os.path.join(TEST_DATA_DIR, 'rsds_test.nc')
@@ -25,6 +25,39 @@
2525
SHAPE = (len(fh.lat.values), len(fh.lon.values))
2626

2727

28+
def test_smooth_interior_bc():
29+
"""Test linear bias correction with interior smoothing"""
30+
31+
calc = LinearCorrection(FP_NSRDB, FP_CC, 'ghi', 'rsds',
32+
TARGET, SHAPE, bias_handler='DataHandlerNCforCC')
33+
34+
out = calc.run(knn=1, threshold=0.6, fill_extend=False, max_workers=1)
35+
og_scalar = out['rsds_scalar']
36+
og_adder = out['rsds_adder']
37+
nan_mask = np.isnan(og_scalar)
38+
assert np.isnan(og_adder[nan_mask]).all()
39+
40+
out = calc.run(knn=1, threshold=0.6, fill_extend=True, smooth_interior=0,
41+
max_workers=1)
42+
scalar = out['rsds_scalar']
43+
adder = out['rsds_adder']
44+
# Make sure smooth_interior=0 does not change interior pixels
45+
assert np.allclose(og_scalar[~nan_mask], scalar[~nan_mask])
46+
assert np.allclose(og_adder[~nan_mask], adder[~nan_mask])
47+
assert not np.isnan(adder[nan_mask]).any()
48+
assert not np.isnan(scalar[nan_mask]).any()
49+
50+
# make sure smoothing affects the interior pixels but not the exterior
51+
out = calc.run(knn=1, threshold=0.6, fill_extend=True, smooth_interior=1,
52+
max_workers=1)
53+
smooth_scalar = out['rsds_scalar']
54+
smooth_adder = out['rsds_adder']
55+
assert not np.allclose(smooth_scalar[~nan_mask], scalar[~nan_mask])
56+
assert not np.allclose(smooth_adder[~nan_mask], adder[~nan_mask])
57+
assert np.allclose(smooth_scalar[nan_mask], scalar[nan_mask])
58+
assert np.allclose(smooth_adder[nan_mask], adder[nan_mask])
59+
60+
2861
def test_linear_bc():
2962
"""Test linear bias correction"""
3063

0 commit comments

Comments
 (0)