Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Feature/single threaded url area selector #265

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions cads_adaptors/tools/area_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from cads_adaptors.exceptions import CdsFormatConversionError, InvalidRequest
from cads_adaptors.tools import adaptor_tools, convertors

DEFAULT_DASK_SCHEDULER_MODE = os.getenv("CADS_ADAPTOR_DASK_SCHEDULER_MODE", "threads")


def area_to_checked_dictionary(area: list[float | int]) -> dict[str, float | int]:
north, east, south, west = area
Expand Down Expand Up @@ -183,7 +185,9 @@ def area_selector(
ds, lat_key, area["south"], area["north"], context, **extra_kwargs
)[0]

context.debug(f"lat_slice: {lat_slice}\nlon_slices: {lon_slices}")
context.debug(
f"Area selector: lat_slice: {lat_slice}\nlon_slices: {lon_slices}"
)

sub_selections = []
for lon_slice in lon_slices:
Expand All @@ -197,12 +201,12 @@ def area_selector(
**sel_kwargs,
)
)
context.debug(f"selections: {sub_selections}")
# context.debug(f"selections: {sub_selections}")

ds_area = xr.concat(
sub_selections, dim=lon_key, data_vars="minimal", coords="minimal"
)
context.debug(f"ds_area: {ds_area}")
context.debug(f"Area selector: ds_area: {ds_area}")

# Ensure that there are no length zero dimensions
for dim in [lat_key, lon_key]:
Expand Down Expand Up @@ -231,6 +235,7 @@ def area_selector_path(
target_dir: str | None = None,
area_selector_kwargs: dict[str, Any] = {},
open_datasets_kwargs: list[dict[str, Any]] | dict[str, Any] = {},
dask_scheduler_mode: str = "threads",
**kwargs: dict[str, Any],
) -> list[str]:
if isinstance(area, list):
Expand Down Expand Up @@ -281,8 +286,10 @@ def area_selector_path(
out_path = os.path.join(target_dir, f"{fname_tag}.nc")
for var in ds_area.variables:
ds_area[var].encoding.setdefault("_FillValue", None)
# Need to compute before writing to disk as dask loses too many jobs
ds_area.compute().to_netcdf(out_path)
# If threads, need to compute before writing to disk as dask loses too many jobs
if dask_scheduler_mode == "threads":
ds_area.compute()
ds_area.to_netcdf(out_path)
out_paths.append(out_path)
else:
context.add_user_visible_error(
Expand All @@ -292,7 +299,10 @@ def area_selector_path(
out_path = os.path.join(target_dir, f"{fname_tag}.nc")
for var in ds_area.variables:
ds_area[var].encoding.setdefault("_FillValue", None)
ds_area.compute().to_netcdf(out_path)
# If threads, need to compute before writing to disk as dask loses too many jobs
if dask_scheduler_mode == "threads":
ds_area.compute()
ds_area.to_netcdf(out_path)
out_paths.append(out_path)

return out_paths
Expand All @@ -304,17 +314,28 @@ def area_selector_paths(
context: Context = Context(),
**kwargs: Any,
) -> list[str]:
with dask.config.set(scheduler="threads"):
import time

dask_scheduler_mode: str = kwargs.pop(
"dask_scheduler_mode", DEFAULT_DASK_SCHEDULER_MODE
)
with dask.config.set(scheduler=dask_scheduler_mode):
time0 = time.time()
# We try to select the area for all paths, if any fail we return the original paths
out_paths = []
for path in paths:
try:
out_paths += area_selector_path(
path, area=area, context=context, **kwargs
path,
area=area,
context=context,
dask_scheduler_mode=dask_scheduler_mode,
**kwargs,
)
except (NotImplementedError, CdsFormatConversionError):
context.logger.debug(
context.debug(
f"could not convert {path} to xarray; returning the original data"
)
out_paths.append(path)
context.logger.info("Area selection complete", delta_time=time.time() - time0)
return out_paths
23 changes: 20 additions & 3 deletions tests/test_30_area_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,27 +252,43 @@ def test_area_selector_zero_length_dim():
)


@pytest.mark.parametrize(
"dask_mode",
[
"single-threaded",
"threads",
],
)
@pytest.mark.parametrize(
"url",
[
f"{TEST_DATA_BASE_URL}/CAMS-GLOB-AIR_Glb_0.5x0.5_anthro_voc25_v1.1_2012.nc",
f"{TEST_DATA_BASE_URL}/C3S-312bL1-L3C-MONTHLY-SRB-ATSR2_ORAC_ERS2_199506_fv3.0.nc",
],
)
def test_area_selector_real_files(url):
def test_area_selector_real_files(dask_mode, url):
with tempfile.TemporaryDirectory() as temp_dir:
test_file = os.path.join(temp_dir, TEMP_FILENAME)
remote_file = requests.get(url)
with open(test_file, "wb") as f:
f.write(remote_file.content)

result = area_selector_path(test_file, area=[90, -180, -90, 180])
result = area_selector_path(
test_file, area=[90, -180, -90, 180], dask_scheduler_mode=dask_mode
)
assert isinstance(result, list)
assert len(result) == 1
assert isinstance(result[0], str)


# Test with lists of urls
@pytest.mark.parametrize(
"dask_mode",
[
"single-threaded",
"threads",
],
)
@pytest.mark.parametrize(
"urls",
[
Expand All @@ -292,7 +308,7 @@ def test_area_selector_real_files(url):
],
],
)
def test_area_selector_paths_real_files(urls):
def test_area_selector_paths_real_files(dask_mode, urls):
with tempfile.TemporaryDirectory() as temp_dir:
test_files = []
for i, file in enumerate(urls):
Expand All @@ -306,6 +322,7 @@ def test_area_selector_paths_real_files(urls):
result = area_selector_paths(
test_files,
area=[90, -180, -90, 180],
dask_scheduler_mode=dask_mode,
)
assert isinstance(result, list)
assert len(result) == len(urls)
Expand Down