def _set_debugger():
    """Build the global `ndsl_debugger` from the file named by `NDSL_DEBUG_CONFIG`.

    The environment variable is read once. Nothing happens when it is unset or
    empty. When it is set but does not name an existing path, a warning is
    logged and the debugger is left untouched (previously execution fell
    through to `open()` and crashed on the missing file).

    Otherwise the file is parsed as YAML with the safe loader and its top-level
    keys are forwarded as keyword arguments to `Debugger`, together with the
    rank reported by `MPIComm().Get_rank()`.
    """
    global ndsl_debugger

    config = os.getenv("NDSL_DEBUG_CONFIG", "")
    if config == "":
        # No configuration requested: debugger stays off
        return
    if not os.path.exists(config):
        ndsl_log.warning(
            f"NDSL_DEBUG_CONFIG set but path {config} does not exists."
        )
        return

    with open(config) as file:
        # An empty YAML document parses to None; treat it as "no options"
        config_dict = yaml.load(file.read(), Loader=yaml.SafeLoader) or {}

    ndsl_debugger = Debugger(rank=MPIComm().Get_rank(), **config_dict)
    ndsl_log.info("[NDSL Debugger] On")
    ndsl_log.debug(f"[NDSL Debugger] Config:\n{config_dict}")
def save_as_dataset(self, data_as_dict, savename, is_in) -> None:
    """Dump a dictionary of data into one NetCDF file.

    Nothing is written unless `savename` is listed in `self.stencils_or_class`.
    Each value is converted with `self._to_xarray`; a dataclass value is
    expanded field by field and stored under `"<name>.<field>"` keys.

    The file lands in `<dir_name>/debug/savepoints/R<rank>/` and is named
    `<savename>-Call<count>-<In|Out>.nc4`, where `count` is the current entry
    of `self.calls_count` for `savename` (0 when absent).

    A `ValueError` raised while assembling or writing the dataset is logged
    and swallowed.
    """
    if savename not in self.stencils_or_class:
        return

    arrays_by_key = {}
    for key, value in data_as_dict.items():
        if not dataclasses.is_dataclass(value):
            arrays_by_key[key] = self._to_xarray(value, key)
            continue
        # Flatten dataclass members into dotted keys
        for member in dataclasses.fields(value):
            arrays_by_key[f"{key}.{member.name}"] = self._to_xarray(
                getattr(value, member.name), member.name
            )

    call_count = self.calls_count.get(savename, 0)
    path = pathlib.Path(f"{self.dir_name}/debug/savepoints/R{self.rank}/")
    os.makedirs(path, exist_ok=True)
    path = pathlib.Path(
        f"{path}/{savename}-Call{call_count}-{'In' if is_in else 'Out'}.nc4"
    )
    try:
        xr.Dataset(arrays_by_key).to_netcdf(path)
    except ValueError as e:
        ndsl_log.error(f"[DebugInfo] Failure to save {savename}: {e}")
def instrument(func) -> Callable:
    """Decorate a method so the global `ndsl_debugger` records its arguments.

    When `ndsl_debugger` is `None` the wrapped method is called directly.
    Otherwise the call arguments are collected by parameter name (the leading
    `self` is skipped, keyword arguments are kept only if they match a declared
    parameter) and handed to `save_as_dataset` and `track_data` once before
    and once after the call, under the method's `__qualname__`. The call
    counter for that name is incremented after the call.

    Args:
        func: method to instrument; its first positional parameter is
            treated as `self`.

    Returns:
        The wrapping callable, carrying `func`'s metadata.
    """
    savename = func.__qualname__
    # The signature is fixed: introspect once at decoration time, not per call
    params = inspect.signature(func).parameters
    # Names that positional arguments bind to, without the leading `self`
    positional_names = [
        name
        for name, param in params.items()
        if param.kind
        in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
    ][1:]

    @wraps(func)
    def wrapper(self, *args: Any, **kwargs: Any):
        if ndsl_debugger is None:
            return func(self, *args, **kwargs)

        # Positional arguments first, then keyword arguments that name a parameter
        data_as_dict = dict(zip(positional_names, args))
        data_as_dict.update(
            (name, value) for name, value in kwargs.items() if name in params
        )

        ndsl_debugger.save_as_dataset(data_as_dict, savename, is_in=True)
        ndsl_debugger.track_data(data_as_dict, savename, is_in=True)
        r = func(self, *args, **kwargs)
        ndsl_debugger.save_as_dataset(data_as_dict, savename, is_in=False)
        ndsl_debugger.track_data(data_as_dict, savename, is_in=False)
        ndsl_debugger.increment_call_count(savename)
        return r

    return wrapper
import SDFGConvertible from ndsl.dsl.stencil_config import CompilationConfig, RunMode, StencilConfig from ndsl.dsl.typing import Float, Index3D, cast_to_index3d @@ -295,10 +297,11 @@ def __init__( externals = {} self.externals = externals self._func_name = func.__name__ + self._func_qualname = func.__qualname__ stencil_kwargs = self.stencil_config.stencil_kwargs( skip_passes=skip_passes, func=func ) - self.stencil_object = None + self.stencil_object: StencilObject | None = None self._argument_names = tuple(inspect.getfullargspec(func).args) @@ -350,7 +353,7 @@ def __init__( dtypes={float: Float}, **stencil_kwargs, build_info=(build_info := {}), - ) + ) # type: ignore if ( compilation_config.use_minimal_caching @@ -384,13 +387,17 @@ def nothing_function(*args, **kwargs): setattr(self, "__call__", nothing_function) def __call__(self, *args, **kwargs) -> None: + # Verbose stencil execution if self.stencil_config.verbose: ndsl_log.debug(f"Running {self._func_name}") + + # Marshal arguments args_list = list(args) _convert_quantities_to_storage(args_list, kwargs) args = tuple(args_list) - args_as_kwargs = dict(zip(self._argument_names, args)) + + # Ranks comparison tool if self.comm is not None: differences = compare_ranks(self.comm, {**args_as_kwargs, **kwargs}) if len(differences) > 0: @@ -398,6 +405,14 @@ def __call__(self, *args, **kwargs) -> None: f"rank {self.comm.Get_rank()} has differences {differences} " f"before calling {self._func_name}" ) + + # Debugger actions if turned on + if ndsl_debugger: + all_args = args_as_kwargs | kwargs + ndsl_debugger.save_as_dataset(all_args, self._func_qualname, is_in=True) + ndsl_debugger.track_data(all_args, self._func_qualname, is_in=True) + + # Execute stencil if self.stencil_config.compilation_config.validate_args: if __debug__ and "origin" in kwargs: raise TypeError("origin cannot be passed to FrozenStencil call") @@ -410,7 +425,7 @@ def __call__(self, *args, **kwargs) -> None: domain=self.domain, validate_args=True, 
def plot_cube_sphere(
    quantity: Quantity,
    k_level: int,
    comm: Communicator,
    grid_data: GridData,
    save_to_path: str,
):
    """Gather a quantity across ranks and save a cubed-sphere plot of it.

    The quantity and the grid latitude/longitude are gathered through `comm`;
    only the rank whose `comm.rank` is 0 draws (Robinson projection) and writes
    the figure to `save_to_path`.

    Args:
        quantity: field to plot; must have 2 or 3 dimensions, otherwise an
            error is logged and nothing is plotted.
        k_level: index taken on the last axis when the gathered data does not
            have exactly 3 dimensions.
        comm: communicator used for the gathers and the rank test.
        grid_data: provides `lat` and `lon`, converted from radians to degrees
            before plotting.
        save_to_path: destination passed to `Figure.savefig`.
    """
    if len(quantity.shape) < 2 or len(quantity.shape) > 3:
        ndsl_log.error(
            f"[Plot Cube] Can't plot quantity with shape == {quantity.shape}"
        )
        return

    data = comm.gather(quantity)
    lat = comm.gather(grid_data.lat)
    lon = comm.gather(grid_data.lon)

    if comm.rank == 0:
        fig, ax = plt.subplots(1, 1, subplot_kw={"projection": ccrs.Robinson()})
        try:
            pcolormesh_cube(
                lat.view[:] * 180.0 / np.pi,
                lon.view[:] * 180.0 / np.pi,
                # 3 gathered dims: already a horizontal field per tile;
                # otherwise pick one level on the trailing axis
                data.view[:] if len(data.shape) == 3 else data.view[:, :, :, k_level],
                ax=ax,
            )
            fig.savefig(save_to_path)
        finally:
            # pyplot keeps every figure alive until closed; release it so
            # repeated calls do not accumulate open figures
            plt.close(fig)
def _mask_antimeridian_quads(lonb: np.ndarray, central_longitude: float):
    """Computes mask of cubed-sphere tile grid quadrilaterals bisected by a
    projection system's antimeridian, in order to avoid cartopy plotting
    artifacts

    A quad is excluded when its four corners are neither all on one side of
    the antimeridian nor farther than 90 degrees from it. All tiles present
    on the last axis are processed.

    Args:
        lonb (np.ndarray):
            Array of grid edge longitudes, of dimensions (npy + 1, npx + 1,
            tile)
        central_longitude (float):
            Central longitude from which the antimeridian is computed

    Returns:
        mask (np.ndarray):
            Boolean array of grid centers, False = excluded, of dimensions
            (npy, npx, tile)


    Example:
        masked_array = np.where(
            _mask_antimeridian_quads(lonb, central_longitude),
            array,
            np.nan
        )
    """
    antimeridian = (central_longitude + 180.0) % 360.0

    # Longitudes of the four corners of every quad, stacked on a new leading
    # axis so each test below reduces over "the 4 vertices of one quad"
    corners = np.stack(
        [lonb[:-1, :-1], lonb[1:, :-1], lonb[:-1, 1:], lonb[1:, 1:]]
    )

    all_on_or_before = np.all(
        _periodic_equal_or_less_than(corners, antimeridian), axis=0
    )
    all_after = np.all(_periodic_greater_than(corners, antimeridian), axis=0)
    all_nearby = np.all(
        _periodic_difference(corners, antimeridian) < 90.0, axis=0
    )

    # Bisected: vertices on both sides, and all of them close to the
    # antimeridian (rules out quads that merely span the central meridian)
    bisected = ~all_on_or_before & ~all_after & all_nearby
    return ~bisected
def _periodic_greater_than(x1, x2, period=360.0):
    """Tell whether x1 lies after x2 along the shortest path on a periodic domain

    Args:
        x1 (float), x2 (float):
            Values to be compared
        period (float, optional):
            Period of domain. Default 360 (degrees).

    Returns:
        greater_than (bool array):
            Whether x1 is greater than x2 once the difference has been
            wrapped onto the shortest distance
    """
    delta = x1 - x2
    # No wrapping needed when the raw difference is already the shortest one
    no_wrap = np.abs(delta) <= period / 2.0
    plain = delta > 0
    # Otherwise shift whichever operand brings the two values closest
    shifted_second = x1 - (x2 + period) > 0
    shifted_first = (x1 + period) - x2 > 0
    return np.where(
        no_wrap,
        plain,
        np.where(delta >= 0, shifted_second, shifted_first),
    )
def plot_cube(
    ds: xr.Dataset,
    var_name: str,
    grid_metadata: GridMetadata = WRAPPER_GRID_METADATA,
    plotting_function: str = "pcolormesh",
    ax: plt.axes = None,
    row: str = None,
    col: str = None,
    col_wrap: int = None,
    projection: ccrs.Projection = None,
    colorbar: bool = True,
    cmap_percentiles_lim: bool = True,
    cbar_label: str = None,
    coastlines: bool = True,
    coastlines_kwargs: dict = None,
    **kwargs,
):
    """Plots an xr.DataArray containing tiled cubed sphere gridded data
    onto a global map projection, with optional faceting of additional dims

    Args:
        ds:
            Dataset containing variable to be plotted, along with the grid
            variables defining cell center latitudes and longitudes and the
            cell bounds latitudes and longitudes, which must share common
            dimension names
        var_name:
            name of the data variable in `ds` to be plotted
        grid_metadata:
            a vcm.cubedsphere.GridMetadata data structure that
            defines the names of plot and grid variable dimensions and the names
            of the grid variables themselves; defaults to those used by the
            fv3gfs Python wrapper (i.e., 'x', 'y', 'x_interface', 'y_interface' and
            'lat', 'lon', 'latb', 'lonb'). Must be a GridMetadataFV3 or a
            GridMetadataScream.
        plotting_function:
            Name of matplotlib 2-d plotting function. Available
            options are "pcolormesh", "contour", and "contourf". Defaults to
            "pcolormesh".
        ax:
            Axes onto which the map should be plotted; must be created with
            a cartopy projection argument. If not supplied, axes are generated
            with a projection. If ax is supplied, faceting is disabled.
        row:
            Name of dimension to be faceted along subplot rows. Must not be a
            tile, lat, or lon dimension. Defaults to no row facets.
        col:
            Name of dimension to be faceted along subplot columns. Must not be
            a tile, lat, or lon dimension. Defaults to no column facets.
        col_wrap:
            If only one of `col`, `row` is specified, number of columns to plot
            before wrapping onto next row. Defaults to None, i.e. no limit.
        projection:
            Cartopy projection object to be used in creating axes. Ignored
            if cartopy geo-axes are supplied. Defaults to Robinson projection.
        colorbar:
            Flag for whether to plot a colorbar. Defaults to True.
        cmap_percentiles_lim:
            If False, use the absolute min/max to set color limits.
            If True, use 2/98 percentile values.
        cbar_label:
            If provided, use this as the color bar label.
        coastlines:
            Whether to plot coastlines on map. Default True.
        coastlines_kwargs:
            Dict of arguments to be passed to cartopy axes's
            `coastline` function if `coastlines` flag is set to True.
        **kwargs: Additional keyword arguments to be passed to the plotting function.

    Returns:
        figure (plt.Figure):
            matplotlib figure object onto which axes grid is created
        axes (np.ndarray):
            Array of `plt.axes` objects associated with map subplots if faceting;
            otherwise array containing single axes object.
        handles (list):
            List or nested list of matplotlib object handles associated with
            map subplots if faceting; otherwise list of single object handle.
        cbar (plt.colorbar):
            object handle associated with figure, if `colorbar`
            arg is True, else None.
        facet_grid (xarray.plot.facetgrid):
            xarray plotting facetgrid for multi-axes case. In single-axes case,
            returns None.

    Raises:
        ValueError: if `grid_metadata` is neither a GridMetadataFV3 nor a
            GridMetadataScream.

    Example:
        # plot diag winds at two times
        fig, axes, hs, cbar, facet_grid = plot_cube(
            diag_ds.isel(time = slice(2, 4)),
            'VGRD850',
            plotting_function = "contourf",
            col = "time",
            coastlines = True,
            colorbar = True,
            vmin = -20,
            vmax = 20
        )
    """
    # Reject unsupported metadata up front: the former `assert ValueError(...)`
    # never fired, so the failure surfaced later as an unrelated error.
    if not isinstance(grid_metadata, (GridMetadataFV3, GridMetadataScream)):
        raise ValueError(
            "grid_metadata needs to be either GridMetadataFV3 or "
            f"GridMetadataScream, but got {type(grid_metadata)}"
        )

    mappable_ds = _mappable_var(ds, var_name, grid_metadata)
    array = mappable_ds[var_name].values

    kwargs["vmin"], kwargs["vmax"], kwargs["cmap"] = infer_cmap_params(
        array,
        vmin=kwargs.get("vmin"),
        vmax=kwargs.get("vmax"),
        cmap=kwargs.get("cmap"),
        robust=cmap_percentiles_lim,
    )
    if isinstance(grid_metadata, GridMetadataFV3):
        _plot_func_short = partial(
            _plot_cube_axes,
            lat=mappable_ds.lat.values,
            lon=mappable_ds.lon.values,
            latb=mappable_ds.latb.values,
            lonb=mappable_ds.lonb.values,
            plotting_function=plotting_function,
            **kwargs,
        )
    else:
        # GridMetadataScream (guaranteed by the check above)
        _plot_func_short = partial(
            _plot_scream_axes,
            lat=mappable_ds.lat.values,
            lon=mappable_ds.lon.values,
            plotting_function=plotting_function,
            **kwargs,
        )

    projection = ccrs.Robinson() if not projection else projection

    if ax is None and (row or col):
        # facets
        facet_grid = xr.plot.FacetGrid(
            data=mappable_ds,
            row=row,
            col=col,
            col_wrap=col_wrap,
            subplot_kws={"projection": projection},
        )
        facet_grid = facet_grid.map(_plot_func_short, var_name)
        fig = facet_grid.fig
        axes = facet_grid.axes
        handles = facet_grid._mappables
    else:
        # single axes
        if ax is None:
            fig, ax = plt.subplots(1, 1, subplot_kw={"projection": projection})
        else:
            fig = ax.figure
        handle = _plot_func_short(array, ax=ax)
        axes = np.array(ax)
        handles = [handle]
        facet_grid = None

    if coastlines:
        coastlines_kwargs = dict() if not coastlines_kwargs else coastlines_kwargs
        # Distinct loop name: `ax` is still needed below for the inset colorbar
        for geo_ax in axes.flatten():
            geo_ax.coastlines(**coastlines_kwargs)

    if colorbar:
        if row or col:
            fig.subplots_adjust(
                bottom=0.1, top=0.9, left=0.1, right=0.8, wspace=0.02, hspace=0.02
            )
            cb_ax = fig.add_axes([0.83, 0.1, 0.02, 0.8])
        else:
            fig.subplots_adjust(wspace=0.25)
            cb_ax = ax.inset_axes([1.05, 0, 0.02, 1])
        cbar = plt.colorbar(handles[0], cax=cb_ax, extend="both")
        cbar.set_label(cbar_label or _get_var_label(ds[var_name].attrs, var_name))
    else:
        cbar = None

    return fig, axes, handles, cbar, facet_grid
+ grid_metadata: + vcm.cubedsphere.GridMetadata object describing dim + names and grid variable names + Returns: + ds (xr.Dataset): Dataset containing variable to be plotted as well as grid + variables, all of whose dimensions are ordered for plotting. + """ + mappable_ds = xr.Dataset() + for var, dims in grid_metadata.coord_vars.items(): + mappable_ds[var] = _align_grid_var_dims(ds[var], required_dims=dims) + if isinstance(grid_metadata, GridMetadataFV3): + var_da = _align_plot_var_dims(ds[var_name], grid_metadata.y, grid_metadata.x) + return mappable_ds.merge(var_da) + elif isinstance(grid_metadata, GridMetadataScream): + return mappable_ds.merge(ds[var_name]) + + +def pcolormesh_cube( + lat: np.ndarray, lon: np.ndarray, array: np.ndarray, ax: plt.axes = None, **kwargs +): + """Plots tiled cubed sphere. This function applies nan to gridcells which cross + the antimeridian, and then iteratively plots rectangles of array which avoid nan + gridcells. This is done to avoid artifacts when plotting gridlines with the + `edgecolor` argument. In comparison to :py:func:`plot_cube`, this function takes + np.ndarrays of the lat and lon cell corners and the variable to be plotted + at cell centers, and makes only one plot on an optionally specified axes object. + + Args: + lat: + Array of latitudes with dimensions (tile, ny + 1, nx + 1). + Should be given at cell corners. + lon: + Array of longitudes with dimensions (tile, ny + 1, nx + 1). + Should be given at cell corners. + array: + Array of variables values at cell centers, of dimensions (tile, ny, nx) + ax: + Matplotlib geoaxes object onto which plotting function will be + called. Default None uses current axes. + **kwargs: + Keyword arguments to be passed to plotting function. 
class UpdateablePColormesh:
    """Cubed-sphere pcolormesh whose plotted values can be replaced in place.

    Construction draws the field with `_pcolormesh_cube_all_handles`, keeps
    every returned handle and attaches a colorbar to the last one. `update`
    then pushes a new array into those same handles.
    """

    def __init__(self, lat, lon, array: np.ndarray, ax: plt.axes = None, **kwargs):
        """Draw `array` on `ax` and remember what is needed to update it.

        Args:
            lat, lon: corner coordinates forwarded to the plotting helper.
            array: cell-centered values to draw.
            ax: geoaxes to draw on; the current axes are used when omitted.
            **kwargs: forwarded to the plotting helper.
        """
        # Resolve the default here so `update` always has an axes (and thus a
        # projection) to read; leaving it as None made `update` fail.
        if ax is None:
            ax = plt.gca()
        self.handles = _pcolormesh_cube_all_handles(lat, lon, array, ax=ax, **kwargs)
        plt.colorbar(self.handles[-1], ax=ax)
        self.lat = lat
        self.lon = lon
        self.ax = ax

    def update(self, array):
        """Replace the plotted values with `array`.

        The antimeridian mask is recomputed from the stored longitudes and the
        axes projection, then each non-NaN segment is written into the handles
        in the order they were created.

        NOTE(review): assumes `array` has the same shape and NaN layout as the
        array given at construction so segments and handles line up - confirm.
        """
        central_longitude = self.ax.projection.proj4_params["lon_0"]
        array = np.where(
            _mask_antimeridian_quads(self.lon.T, central_longitude), array.T, np.nan
        ).T

        iter_handles = iter(self.handles)

        def update_handle(x, y, array):
            # Coordinates are unchanged; only the data of the next handle is swapped
            handle = next(iter_handles)
            handle.set_array(array.ravel())

        _apply_to_non_non_nan_segments(update_handle, self.lat, self.lon, array)
def _segment_plot_inputs(x, y, masked_array):
    """Takes in two arrays at corners of grid cells and an array at grid cell centers
    which may contain NaNs. Yields 3-tuples of rectangular segments of
    these arrays which cover all non-nan points without duplicates, and don't contain
    NaNs.

    Args:
        x, y: corner arrays, one element longer than `masked_array` along
            each axis.
        masked_array: 2-d cell-centered values, NaN where cells are excluded.

    Yields:
        (x_segment, y_segment, array_segment) tuples; segments coming from the
        transposed branch are yielded transposed.
    """
    is_nan = np.isnan(masked_array)
    if not is_nan.any():  # contiguous section, just plot it
        # `.size` replaces `np.product(shape)`, which was removed in NumPy 2.0
        if masked_array.size > 0:
            yield (x, y, masked_array)
        return

    # Fraction of NaNs along each line of the first / second dimension
    x_nans = np.sum(is_nan, axis=1) / is_nan.shape[1]
    y_nans = np.sum(is_nan, axis=0) / is_nan.shape[0]
    if x_nans.max() >= y_nans.max():  # most nan-y line is in first dimension
        i_split = x_nans.argmax()
        if x_nans[i_split] == 1.0:  # split cleanly along line
            # Drop the all-NaN line; corners keep the shared edge on each side
            yield from _segment_plot_inputs(
                x[: i_split + 1, :],
                y[: i_split + 1, :],
                masked_array[:i_split, :],
            )
            yield from _segment_plot_inputs(
                x[i_split + 1 :, :],
                y[i_split + 1 :, :],
                masked_array[i_split + 1 :, :],
            )
        else:
            # split to create segments of complete nans
            # which subsequent recursive calls will split on and remove
            i_start = 0
            i_end = 1
            while i_end < is_nan.shape[1]:
                while (
                    i_end < is_nan.shape[1]
                    and is_nan[i_split, i_start] == is_nan[i_split, i_end]
                ):
                    i_end += 1
                # we have a largest-possible contiguous segment of nans/not nans
                yield from _segment_plot_inputs(
                    x[:, i_start : i_end + 1],
                    y[:, i_start : i_end + 1],
                    masked_array[:, i_start:i_end],
                )
                i_start = i_end  # start the next segment
    else:
        # put most nan-y line in first dimension
        # so the first part of this if block catches it
        yield from _segment_plot_inputs(
            x.T,
            y.T,
            masked_array.T,
        )
+ ) + + if ( + (lonb_shape[0] != latb_shape[0]) + or (latb_shape[0] != (array_shape[0] + 1)) + or (lonb_shape[1] != latb_shape[1]) + or (latb_shape[1] != (array_shape[1] + 1)) + ): + raise ValueError( + """Horizontal axis lengths of latb and lonb + must be one greater than those of array.""" + ) + + if (len(lon_shape) != 3) or (len(lat_shape) != 3) or (len(array_shape) != 3): + raise ValueError("Lon, lat, and data_var each must be 3-dimensional.") + + +def _plot_cube_axes( + array: np.ndarray, + lat: np.ndarray, + lon: np.ndarray, + latb: np.ndarray, + lonb: np.ndarray, + plotting_function: str, + ax: plt.axes = None, + **kwargs, +): + """Plots tiled cubed sphere for a given subplot axis, + using np.ndarrays for all data + + Args: + array: + Array of variables values at cell centers, of dimensions (npy, npx, + tile) + lat: + Array of latitudes of cell centers, of dimensions (npy, npx, tile) + lon: + Array of longitudes of cell centers, of dimensions (npy, npx, tile) + latb: + Array of latitudes of cell edges, of dimensions (npy + 1, npx + 1, + tile) + lonb: + Array of longitudes of cell edges, of dimensions (npy + 1, npx + 1, + tile) + plotting_function: + Name of matplotlib 2-d plotting function. Available options + are "pcolormesh", "contour", and "contourf". + ax: + Matplotlib geoaxes object onto which plotting function will be + called. Default None uses current axes. + **kwargs: + Keyword arguments to be passed to plotting function. 
+ + Returns: + p_handle (obj): + matplotlib object handle associated with map subplot + """ + _validate_cube_shape(lon.shape, lat.shape, lonb.shape, latb.shape, array.shape) + + if ax is None: + ax = plt.gca() + + if plotting_function in ["pcolormesh", "contour", "contourf"]: + _plotting_function = getattr(ax, plotting_function) + else: + raise ValueError( + """Plotting functions only include pcolormesh, contour, + and contourf.""" + ) + + if "vmin" not in kwargs: + kwargs["vmin"] = np.nanmin(array) + + if "vmax" not in kwargs: + kwargs["vmax"] = np.nanmax(array) + + if np.isnan(kwargs["vmin"]): + kwargs["vmin"] = -0.1 + if np.isnan(kwargs["vmax"]): + kwargs["vmax"] = 0.1 + + if plotting_function != "pcolormesh": + if "levels" not in kwargs: + kwargs["n_levels"] = 11 if "n_levels" not in kwargs else kwargs["n_levels"] + kwargs["levels"] = np.linspace( + kwargs["vmin"], kwargs["vmax"], kwargs["n_levels"] + ) + + central_longitude = ax.projection.proj4_params["lon_0"] + + masked_array = np.where( + _mask_antimeridian_quads(lonb, central_longitude), array, np.nan + ) + + for tile in range(6): + if plotting_function == "pcolormesh": + x = lonb[:, :, tile] + y = latb[:, :, tile] + else: + # contouring + x = center_longitudes(lon[:, :, tile], central_longitude) + y = lat[:, :, tile] + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + p_handle = _plotting_function( + x, y, masked_array[:, :, tile], transform=ccrs.PlateCarree(), **kwargs + ) + + ax.set_global() + + return p_handle + + +def _plot_scream_axes( + array: np.ndarray, + lat: np.ndarray, + lon: np.ndarray, + plotting_function: str, + ax: plt.axes = None, + **kwargs, +): + if ax is None: + ax = plt.gca() + if plotting_function in ["pcolormesh", "contour", "contourf"]: + mapping = { + "pcolormesh": "tripcolor", + "contour": "tricontour", + "contourf": "tricontourf", + } + _plotting_function = getattr(ax, mapping[plotting_function]) + else: + raise ValueError( + """Plotting functions only include 
pcolormesh, contour, + and contourf.""" + ) + if "vmin" not in kwargs: + kwargs["vmin"] = np.nanmin(array) + + if "vmax" not in kwargs: + kwargs["vmax"] = np.nanmax(array) + + if np.isnan(kwargs["vmin"]): + kwargs["vmin"] = -0.1 + if np.isnan(kwargs["vmax"]): + kwargs["vmax"] = 0.1 + + if plotting_function != "pcolormesh": + if "levels" not in kwargs: + kwargs["n_levels"] = 11 if "n_levels" not in kwargs else kwargs["n_levels"] + kwargs["levels"] = np.linspace( + kwargs["vmin"], kwargs["vmax"], kwargs["n_levels"] + ) + lon = np.where(lon > 180, lon - 360, lon) + p_handle = _plotting_function( + lon.flatten(), + lat.flatten(), + array.flatten(), + transform=ccrs.PlateCarree(), + **kwargs, + ) + ax.set_global() + return p_handle diff --git a/ndsl/viz/fv3/_plot_diagnostics.py b/ndsl/viz/fv3/_plot_diagnostics.py new file mode 100644 index 00000000..9c102759 --- /dev/null +++ b/ndsl/viz/fv3/_plot_diagnostics.py @@ -0,0 +1,125 @@ +""" +Some helper functions for creating diagnostic plots. + +These are specifically for usage in fv3net. + +Uses the general purpose plotting functions in +fv3viz such as plot_cube. 
+ + +""" +import os + +import matplotlib.pyplot as plt +import numpy as np +import xarray as xr +from scipy.stats import binned_statistic + +from ._constants import COORD_X_CENTER, COORD_Y_CENTER, INIT_TIME_DIM + + +STACK_DIMS = ["tile", INIT_TIME_DIM, COORD_X_CENTER, COORD_Y_CENTER] + + +def _mask_nan_lines(x, y): + nan_mask = np.isfinite(y) + return np.array(x)[nan_mask], np.array(y)[nan_mask] + + +def plot_diurnal_cycle( + merged_ds, var, stack_dims=STACK_DIMS, num_time_bins=24, title=None, ylabel=None +): + """ + + Args: + merged_ds (xr.dataset): + can either provide a merged dataset with a "dataset" dim + that will be used to plot separate lines for each variable, or a + single dataset with no "dataset" dim + var (str): + name of variable to plot + num_time_bins (int): + number of bins per day + title(str): + optional plot title + + Returns: + matplotlib figure + """ + plt.clf() + fig = plt.figure() + if "dataset" not in merged_ds.dims: + merged_ds = xr.concat([merged_ds], "dataset") + for label in merged_ds["dataset"].values: + # TODO this function mixes computation, plotting, and implicitly + # I/O via deferred dask calculations. + # and should be extensively refactored. 
+ ds = merged_ds.sel(dataset=label) + if len([dim for dim in ds.dims if dim in stack_dims]) > 1: + ds = ds.stack(sample=stack_dims).dropna("sample") + local_time = ds["local_time"].values.flatten() + data_var = ds[var].values.flatten() + bin_means, bin_edges, _ = binned_statistic( + local_time, data_var, bins=num_time_bins + ) + bin_centers = [ + 0.5 * (bin_edges[i] + bin_edges[i + 1]) for i in range(num_time_bins) + ] + bin_centers, bin_means = _mask_nan_lines(bin_centers, bin_means) + plt.plot(bin_centers, bin_means, label=label) + plt.xlabel("local_time [hr]") + plt.ylabel(ylabel or var) + plt.legend(loc="lower left") + if title: + plt.title(title) + return fig + + +# function below here are from the previous design and probably outdated +# leaving for now as it might be adapted to work with new design + + +def plot_time_series( + ds, + vars_to_plot, + output_dir, + plot_filename="time_series.png", + time_var=INIT_TIME_DIM, + xlabel=None, + ylabel=None, + title=None, +): + """Plot one or more variables as a time series. 
+ + Args: + ds (xr.dataset): + dataset containing time series variables to plot + vars_to_plot(list[str]): + data variables to plot + output_dir (str): + output directory to save figure into + plot_filename (str): + filename to save figure to + time_var (str): + name of time dimension + xlabel (str): + x axis label + ylabel (str): + y axis label + title (str): + plot title + Returns: + matplotlib figure + """ + plt.clf() + for var in vars_to_plot: + time = ds[time_var].values + plt.plot(time, ds[var].values, label=var) + if xlabel: + plt.xlabel(xlabel) + if ylabel: + plt.ylabel(ylabel) + plt.legend() + if title: + plt.title(title) + plt.savefig(os.path.join(output_dir, plot_filename)) diff --git a/ndsl/viz/fv3/_plot_helpers.py b/ndsl/viz/fv3/_plot_helpers.py new file mode 100644 index 00000000..75da6983 --- /dev/null +++ b/ndsl/viz/fv3/_plot_helpers.py @@ -0,0 +1,172 @@ +import textwrap +from typing import Optional, Tuple + +import numpy as np + + +def _align_grid_var_dims(da, required_dims): + missing_dims = set(required_dims).difference(da.dims) + if len(missing_dims) > 0: + raise ValueError( + f"Grid variable {da.name} missing dims {missing_dims}. " + "Incompatible grid metadata may have been passed." + ) + redundant_dims = set(da.dims).difference(required_dims) + if len(redundant_dims) == 0: + da_out = da.transpose(*required_dims) + else: + redundant_dims_index = {dim: 0 for dim in redundant_dims} + da_out = ( + da.isel(redundant_dims_index) + .drop_vars(redundant_dims, errors="ignore") + .transpose(*required_dims) + ) + return da_out + + +def _align_plot_var_dims(da, coord_y_center, coord_x_center): + first_dims = [coord_y_center, coord_x_center, "tile"] + missing_dims = set(first_dims).difference(set(da.dims)) + if len(missing_dims) > 0: + raise ValueError( + f"Data array to be plotted {da.name} missing dims {missing_dims}. " + "Incompatible grid metadata may have been passed." 
+ ) + rest = set(da.dims).difference(set(first_dims)) + xpose_dims = first_dims + list(rest) + return da.transpose(*xpose_dims) + + +def _min_max_from_percentiles(x, min_percentile=2, max_percentile=98): + """Use +/- small percentile to determine bounds for colorbar. Avoids the case + where an outlier in the data causes the color scale to be washed out. + + Args: + x: array of data values + min_percentile: lower percentile to use instead of absolute min + max_percentile: upper percentile to use instead of absolute max + + Returns: + Tuple of values at min_percentile, max_percentile + """ + x = np.array(x).flatten() + x = x[~np.isnan(x)] + if len(x) == 0: + # all values of x are equal to np.nan + xmin, xmax = np.nan, np.nan + else: + xmin, xmax = np.percentile(x, [min_percentile, max_percentile]) + return xmin, xmax + + +def _infer_color_limits( + xmin: float, xmax: float, vmin: float = None, vmax: float = None, cmap: str = None +): + """ "auto-magical" handling of color limits and colormap if not supplied by + user + + Args: + xmin (float): + Smallest value in data to be plotted + xmax (float): + Largest value in data to be plotted + vmin (float, optional): + Colormap minimum value. Default None. + vmax (float, optional): + Colormap minimum value. Default None. + cmap (str, optional): + Name of colormap. Default None. + + Returns: + vmin (float) + Inferred colormap minimum value if not supplied, or user value if + supplied. + vmax (float) + Inferred colormap maximum value if not supplied, or user value if + supplied. + cmap (str) + Inferred colormap if not supplied, or user value if supplied. 
+ + Example: + # choose limits and cmap for data spanning 0 + >>>> _infer_color_limits(-10, 20) + (-20, 20, 'RdBu_r') + """ + if vmin is None and vmax is None: + if xmin < 0 and xmax > 0: + cmap = "RdBu_r" if not cmap else cmap + vabs_max = np.max([np.abs(xmin), np.abs(xmax)]) + vmin, vmax = (-vabs_max, vabs_max) + else: + vmin, vmax = xmin, xmax + cmap = "viridis" if not cmap else cmap + elif vmin is None: + if xmin < 0 and vmax > 0: + vmin = -vmax + cmap = "RdBu_r" if not cmap else cmap + else: + vmin = xmin + cmap = "viridis" if not cmap else cmap + elif vmax is None: + if xmax > 0 and vmin < 0: + vmax = -vmin + cmap = "RdBu_r" if not cmap else cmap + else: + vmax = xmax + cmap = "viridis" if not cmap else cmap + elif not cmap: + cmap = "RdBu_r" if vmin == -vmax else "viridis" + + return vmin, vmax, cmap + + +def _get_var_label(attrs: dict, var_name: str, max_line_length: int = 30): + """Get the label for the variable on the colorbar + + Args: + attrs (dict): + Variable aattribute dict + var_name (str): + Short name of variable + max_line_length (int, optional): + Max number of characters on each line of returned label. + Defaults to 30. + + Returns: + var_label (str) + long_name [units], var_name [units] or var_name depending on attrs + """ + if "long_name" in attrs: + var_label = attrs["long_name"] + else: + var_label = var_name + if "units" in attrs: + var_label += f" [{attrs['units']}]" + return "\n".join(textwrap.wrap(var_label, max_line_length)) + + +def infer_cmap_params( + data: np.ndarray, + vmin: Optional[float] = None, + vmax: Optional[float] = None, + cmap: Optional[str] = None, + robust: bool = False, +) -> Tuple[float, float, str]: + """Determine useful colorbar limits and cmap for given data. + + Args: + data: The data to be plotted. + vmin: Optional minimum for colorbar. + vmax: Optional maximum for colorbar. + cmap: Optional colormap to use. + robust: If true, use 2nd and 98th percentiles for colorbar limits. 
+ + Returns: + Tuple of (vmin, vmax, cmap). + """ + if robust: + xmin, xmax = _min_max_from_percentiles(data) + else: + xmin, xmax = np.nanmin(data), np.nanmax(data) + vmin, vmax, cmap = _infer_color_limits(xmin, xmax, vmin, vmax, cmap) + return vmin, vmax, cmap diff --git a/ndsl/viz/fv3/_styles.py b/ndsl/viz/fv3/_styles.py new file mode 100644 index 00000000..2eb221ea --- /dev/null +++ b/ndsl/viz/fv3/_styles.py @@ -0,0 +1,18 @@ +import matplotlib.pyplot as plt +from cycler import cycler + + +# adapted from https://davidmathlogic.com/colorblind +wong_palette = [ + "#56B4E9", + "#E69F00", + "#009E73", + "#0072B2", + "#D55E00", + "#CC79A7", + "#F0E442", # put yellow last, remove black +] + + +def use_colorblind_friendly_style(): + plt.rcParams["axes.prop_cycle"] = cycler("color", wong_palette) diff --git a/ndsl/viz/fv3/_timestep_histograms.py b/ndsl/viz/fv3/_timestep_histograms.py new file mode 100644 index 00000000..38985478 --- /dev/null +++ b/ndsl/viz/fv3/_timestep_histograms.py @@ -0,0 +1,36 @@ +import datetime +from typing import Sequence, Union + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib.axes import Axes + + +def plot_daily_and_hourly_hist( + time_list: Sequence[Union[datetime.datetime, np.datetime64]], +) -> plt.figure: + """Given a sequence of datetimes (anything that can be handled by pandas) create + and return 2-subplot figure with histograms of daily and hourly counts.""" + fig, axes = plt.subplots(1, 2, figsize=(8, 3)) + plot_daily_hist(axes[0], time_list) + plot_hourly_hist(axes[1], time_list) + fig.suptitle(f"total count: {len(time_list)}") + plt.tight_layout() + return fig + + +def plot_daily_hist(ax: Axes, time_list: Sequence[datetime.datetime]): + """Given list of datetimes, plot histogram of count per calendar day on ax""" + ser = pd.Series(time_list) + groupby_list = [ser.dt.year, ser.dt.month, ser.dt.day] + ser.groupby(groupby_list).count().plot(ax=ax, kind="bar", title="Daily count") + 
ax.set_ylabel("Count") + + +def plot_hourly_hist(ax: Axes, time_list: Sequence[datetime.datetime]): + """Given list of datetimes, plot histogram of count per UTC hour on ax""" + ser = pd.Series(time_list) + ser.groupby(ser.dt.hour).count().plot(ax=ax, kind="bar", title="Hourly count") + ax.set_ylabel("Count") + ax.set_xlabel("UTC hour") diff --git a/ndsl/viz/fv3/grid_metadata.py b/ndsl/viz/fv3/grid_metadata.py new file mode 100644 index 00000000..2171360e --- /dev/null +++ b/ndsl/viz/fv3/grid_metadata.py @@ -0,0 +1,47 @@ +import abc +import dataclasses + + +class GridMetadata(abc.ABC): + @property + @abc.abstractmethod + def coord_vars(self) -> dict: + ... + + +@dataclasses.dataclass +class GridMetadataFV3(GridMetadata): + x: str = "x" + y: str = "y" + x_interface: str = "x_interface" + y_interface: str = "y_interface" + tile: str = "tile" + lon: str = "lon" + lonb: str = "lonb" + lat: str = "lat" + latb: str = "latb" + + @property + def coord_vars(self): + coord_vars = { + self.lonb: [self.y_interface, self.x_interface, self.tile], + self.latb: [self.y_interface, self.x_interface, self.tile], + self.lon: [self.y, self.x, self.tile], + self.lat: [self.y, self.x, self.tile], + } + return coord_vars + + +@dataclasses.dataclass +class GridMetadataScream(GridMetadata): + ncol: str = "ncol" + lon: str = "lon" + lat: str = "lat" + + @property + def coord_vars(self): + coord_vars = { + self.lon: [self.ncol], + self.lat: [self.ncol], + } + return coord_vars diff --git a/setup.py b/setup.py index e9f1c2e6..238065c6 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ def local_pkg(name: str, relative_path: str) -> str: "dask", # for xarray "numpy==1.26.4", "matplotlib", # for plotting in boilerplate + "cartopy", # for plotting in ndsl.viz ]