diff --git a/cads_adaptors/tools/area_selector.py b/cads_adaptors/tools/area_selector.py index 75291e38..58bfff2d 100644 --- a/cads_adaptors/tools/area_selector.py +++ b/cads_adaptors/tools/area_selector.py @@ -11,6 +11,8 @@ from cads_adaptors.exceptions import CdsFormatConversionError, InvalidRequest from cads_adaptors.tools import adaptor_tools, convertors +DEFAULT_DASK_SCHEDULER_MODE = os.getenv("CADS_ADAPTOR_DASK_SCHEDULER_MODE", "threads") + def area_to_checked_dictionary(area: list[float | int]) -> dict[str, float | int]: north, east, south, west = area @@ -183,7 +185,9 @@ def area_selector( ds, lat_key, area["south"], area["north"], context, **extra_kwargs )[0] - context.debug(f"lat_slice: {lat_slice}\nlon_slices: {lon_slices}") + context.debug( + f"Area selector: lat_slice: {lat_slice}\nlon_slices: {lon_slices}" + ) sub_selections = [] for lon_slice in lon_slices: @@ -197,12 +201,12 @@ def area_selector( **sel_kwargs, ) ) - context.debug(f"selections: {sub_selections}") + # context.debug(f"selections: {sub_selections}") ds_area = xr.concat( sub_selections, dim=lon_key, data_vars="minimal", coords="minimal" ) - context.debug(f"ds_area: {ds_area}") + context.debug(f"Area selector: ds_area: {ds_area}") # Ensure that there are no length zero dimensions for dim in [lat_key, lon_key]: @@ -231,6 +235,7 @@ def area_selector_path( target_dir: str | None = None, area_selector_kwargs: dict[str, Any] = {}, open_datasets_kwargs: list[dict[str, Any]] | dict[str, Any] = {}, + dask_scheduler_mode: str = "threads", **kwargs: dict[str, Any], ) -> list[str]: if isinstance(area, list): @@ -281,8 +286,10 @@ def area_selector_path( out_path = os.path.join(target_dir, f"{fname_tag}.nc") for var in ds_area.variables: ds_area[var].encoding.setdefault("_FillValue", None) - # Need to compute before writing to disk as dask loses too many jobs - ds_area.compute().to_netcdf(out_path) + # If threads, need to compute before writing to disk as dask loses too many jobs + if 
dask_scheduler_mode == "threads": + ds_area = ds_area.compute() + ds_area.to_netcdf(out_path) out_paths.append(out_path) else: context.add_user_visible_error( @@ -292,7 +299,10 @@ def area_selector_path( out_path = os.path.join(target_dir, f"{fname_tag}.nc") for var in ds_area.variables: ds_area[var].encoding.setdefault("_FillValue", None) - ds_area.compute().to_netcdf(out_path) + # If threads, need to compute before writing to disk as dask loses too many jobs + if dask_scheduler_mode == "threads": + ds_area = ds_area.compute() + ds_area.to_netcdf(out_path) out_paths.append(out_path) return out_paths @@ -304,17 +314,28 @@ def area_selector_paths( context: Context = Context(), **kwargs: Any, ) -> list[str]: - with dask.config.set(scheduler="threads"): + import time + + dask_scheduler_mode: str = kwargs.pop( + "dask_scheduler_mode", DEFAULT_DASK_SCHEDULER_MODE + ) + with dask.config.set(scheduler=dask_scheduler_mode): + time0 = time.time() # We try to select the area for all paths, if any fail we return the original paths out_paths = [] for path in paths: try: out_paths += area_selector_path( - path, area=area, context=context, **kwargs + path, + area=area, + context=context, + dask_scheduler_mode=dask_scheduler_mode, + **kwargs, ) except (NotImplementedError, CdsFormatConversionError): - context.logger.debug( + context.debug( f"could not convert {path} to xarray; returning the original data" ) out_paths.append(path) + context.logger.info("Area selection complete", delta_time=time.time() - time0) return out_paths diff --git a/tests/test_30_area_selector.py b/tests/test_30_area_selector.py index 86ebaf6d..fdb59431 100644 --- a/tests/test_30_area_selector.py +++ b/tests/test_30_area_selector.py @@ -252,6 +252,13 @@ def test_area_selector_zero_length_dim(): ) +@pytest.mark.parametrize( + "dask_mode", + [ + "single-threaded", + "threads", + ], +) @pytest.mark.parametrize( "url", [ @@ -259,20 +266,29 @@ 
f"{TEST_DATA_BASE_URL}/C3S-312bL1-L3C-MONTHLY-SRB-ATSR2_ORAC_ERS2_199506_fv3.0.nc", ], ) -def test_area_selector_real_files(url): +def test_area_selector_real_files(dask_mode, url): with tempfile.TemporaryDirectory() as temp_dir: test_file = os.path.join(temp_dir, TEMP_FILENAME) remote_file = requests.get(url) with open(test_file, "wb") as f: f.write(remote_file.content) - result = area_selector_path(test_file, area=[90, -180, -90, 180]) + result = area_selector_path( + test_file, area=[90, -180, -90, 180], dask_scheduler_mode=dask_mode + ) assert isinstance(result, list) assert len(result) == 1 assert isinstance(result[0], str) # Test with lists of urls +@pytest.mark.parametrize( + "dask_mode", + [ + "single-threaded", + "threads", + ], +) @pytest.mark.parametrize( "urls", [ @@ -292,7 +308,7 @@ def test_area_selector_real_files(url): ], ], ) -def test_area_selector_paths_real_files(urls): +def test_area_selector_paths_real_files(dask_mode, urls): with tempfile.TemporaryDirectory() as temp_dir: test_files = [] for i, file in enumerate(urls): @@ -306,6 +322,7 @@ def test_area_selector_paths_real_files(urls): result = area_selector_paths( test_files, area=[90, -180, -90, 180], + dask_scheduler_mode=dask_mode, ) assert isinstance(result, list) assert len(result) == len(urls)