Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 38 additions & 20 deletions benchmarks/benchmarks/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from subprocess import CalledProcessError, check_output, run
from os import environ
from pathlib import Path
import re
from textwrap import dedent

from iris import load_cube
Expand All @@ -34,6 +35,14 @@
error = "Env variable DATA_GEN_PYTHON not a runnable python executable path."
raise ValueError(error)

default_data_dir = (Path(__file__).parent.parent / ".data").resolve()
BENCHMARK_DATA = Path(environ.get("BENCHMARK_DATA", default_data_dir))
if BENCHMARK_DATA == default_data_dir:
BENCHMARK_DATA.mkdir(exist_ok=True)
elif not BENCHMARK_DATA.is_dir():
message = f"Not a directory: {BENCHMARK_DATA} ."
raise ValueError(message)


def run_function_elsewhere(func_to_run, *args, **kwargs):
"""
Expand Down Expand Up @@ -109,21 +118,32 @@ def external(*args, **kwargs):
cube = original(*args, **kwargs)
save(cube, save_path)

save_dir = (Path(__file__).parent.parent / ".data").resolve()
save_dir.mkdir(exist_ok=True)
# TODO: caching? Currently written assuming overwrite every time.
save_path = save_dir / "_grid_cube.nc"

_ = run_function_elsewhere(
external,
file_name_sections = [
"_grid_cube",
n_lons,
n_lats,
lon_outer_bounds,
lat_outer_bounds,
circular,
alt_coord_system=alt_coord_system,
save_path=str(save_path),
)
alt_coord_system,
]
file_name = "_".join(str(section) for section in file_name_sections)
# Remove 'unsafe' characters.
file_name = re.sub(r"\W+", "", file_name)
save_path = (BENCHMARK_DATA / file_name).with_suffix(".nc")

if not save_path.is_file():
_ = run_function_elsewhere(
external,
n_lons,
n_lats,
lon_outer_bounds,
lat_outer_bounds,
circular,
alt_coord_system=alt_coord_system,
save_path=str(save_path),
)

return_cube = load_cube(str(save_path))
return return_cube

Expand All @@ -149,17 +169,15 @@ def external(*args, **kwargs):
cube = original(*args, **kwargs)
save(cube, save_path)

save_dir = (Path(__file__).parent.parent / ".data").resolve()
save_dir.mkdir(exist_ok=True)
# TODO: caching? Currently written assuming overwrite every time.
save_path = save_dir / f"_mesh_cube_{n_lons}_{n_lats}.nc"
save_path = BENCHMARK_DATA / f"_mesh_cube_{n_lons}_{n_lats}.nc"

_ = run_function_elsewhere(
external,
n_lons,
n_lats,
save_path=str(save_path),
)
if not save_path.is_file():
_ = run_function_elsewhere(
external,
n_lons,
n_lats,
save_path=str(save_path),
)

with PARSE_UGRID_ON_LOAD.context():
return_cube = load_cube(str(save_path))
Expand Down