From 7b1639f0a02eb99c9557a43d74ce280f7967786f Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 23 Jun 2025 12:35:42 +0200 Subject: [PATCH] Update pre-commit hooks --- .pre-commit-config.yaml | 10 ++--- ndsl/checkpointer/base.py | 3 +- ndsl/comm/comm_abc.py | 54 +++++++++---------------- ndsl/comm/partitioner.py | 5 ++- ndsl/dsl/dace/orchestration.py | 6 +-- ndsl/dsl/dace/wrapped_halo_exchange.py | 4 +- ndsl/dsl/stencil.py | 12 +++--- ndsl/dsl/typing.py | 2 +- ndsl/grid/geometry.py | 2 +- ndsl/grid/global_setup.py | 12 +++--- ndsl/grid/gnomonic.py | 25 ++++++------ ndsl/grid/stretch_transformation.py | 4 +- ndsl/initialization/allocator.py | 8 ++-- ndsl/logging.py | 6 +-- ndsl/monitor/protocol.py | 6 +-- ndsl/performance/collector.py | 9 ++--- ndsl/quantity/quantity.py | 8 ++-- ndsl/stencils/testing/savepoint.py | 9 ++--- ndsl/stencils/testing/test_translate.py | 6 +-- ndsl/testing/comparison.py | 12 ++---- ndsl/viz/fv3/_plot_cube.py | 2 +- ndsl/viz/fv3/_plot_diagnostics.py | 1 + ndsl/viz/fv3/grid_metadata.py | 3 +- setup.cfg | 8 +--- tests/mpi/test_mpi_all_reduce_sum.py | 2 +- tests/mpi/test_mpi_halo_update.py | 2 +- tests/mpi/test_mpi_mock.py | 2 +- tests/quantity/test_quantity.py | 2 +- tests/test_g2g_communication.py | 11 ++--- 29 files changed, 103 insertions(+), 133 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8df5e5ae..0ad9802d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,19 +3,19 @@ default_language_version: repos: - repo: https://github.com/psf/black - rev: 20.8b1 + rev: 25.1.0 hooks: - id: black additional_dependencies: ["click==8.0.4"] - repo: https://github.com/pre-commit/mirrors-isort - rev: v5.4.2 + rev: v5.10.1 hooks: - id: isort args: ["--profile", "black"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.4.1 + rev: v1.16.1 hooks: - id: mypy name: mypy-ndsl @@ -27,14 +27,14 @@ repos: ndsl/ndsl/gt4py_utils.py | )$ - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 + rev: v5.0.0 hooks: - id: check-toml - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/pycqa/flake8 - rev: 3.9.2 + rev: 7.3.0 hooks: - id: flake8 name: flake8 diff --git a/ndsl/checkpointer/base.py b/ndsl/checkpointer/base.py index 8218bbfe..5b48b8b3 100644 --- a/ndsl/checkpointer/base.py +++ b/ndsl/checkpointer/base.py @@ -3,5 +3,4 @@ class Checkpointer(abc.ABC): @abc.abstractmethod - def __call__(self, savepoint_name, **kwargs): - ... + def __call__(self, savepoint_name, **kwargs): ... diff --git a/ndsl/comm/comm_abc.py b/ndsl/comm/comm_abc.py index a3cdc897..42c04b0b 100644 --- a/ndsl/comm/comm_abc.py +++ b/ndsl/comm/comm_abc.py @@ -29,74 +29,56 @@ class ReductionOperator(enum.Enum): class Request(abc.ABC): @abc.abstractmethod - def wait(self): - ... + def wait(self): ... class Comm(abc.ABC): @abc.abstractmethod - def Get_rank(self) -> int: - ... + def Get_rank(self) -> int: ... @abc.abstractmethod - def Get_size(self) -> int: - ... + def Get_size(self) -> int: ... @abc.abstractmethod - def bcast(self, value: Optional[T], root=0) -> T: - ... + def bcast(self, value: Optional[T], root=0) -> T: ... @abc.abstractmethod - def barrier(self): - ... + def barrier(self): ... @abc.abstractmethod - def Barrier(self): - ... + def Barrier(self): ... @abc.abstractmethod - def Scatter(self, sendbuf, recvbuf, root=0, **kwargs): - ... + def Scatter(self, sendbuf, recvbuf, root=0, **kwargs): ... @abc.abstractmethod - def Gather(self, sendbuf, recvbuf, root=0, **kwargs): - ... + def Gather(self, sendbuf, recvbuf, root=0, **kwargs): ... @abc.abstractmethod - def allgather(self, sendobj: T) -> List[T]: - ... + def allgather(self, sendobj: T) -> List[T]: ... @abc.abstractmethod - def Send(self, sendbuf, dest, tag: int = 0, **kwargs): - ... + def Send(self, sendbuf, dest, tag: int = 0, **kwargs): ... @abc.abstractmethod - def sendrecv(self, sendbuf, dest, **kwargs): - ... + def sendrecv(self, sendbuf, dest, **kwargs): ... @abc.abstractmethod - def Isend(self, sendbuf, dest, tag: int = 0, **kwargs) -> Request: - ... + def Isend(self, sendbuf, dest, tag: int = 0, **kwargs) -> Request: ... @abc.abstractmethod - def Recv(self, recvbuf, source, tag: int = 0, **kwargs): - ... + def Recv(self, recvbuf, source, tag: int = 0, **kwargs): ... @abc.abstractmethod - def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request: - ... + def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request: ... @abc.abstractmethod - def Split(self, color, key) -> Comm: - ... + def Split(self, color, key) -> Comm: ... @abc.abstractmethod - def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: - ... + def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: ... @abc.abstractmethod - def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: - ... + def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: ... - def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T: - ... + def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T: ... diff --git a/ndsl/comm/partitioner.py b/ndsl/comm/partitioner.py index 6b8750a1..8b26e925 100644 --- a/ndsl/comm/partitioner.py +++ b/ndsl/comm/partitioner.py @@ -61,8 +61,9 @@ def __init__(self): self.layout = None @abc.abstractmethod - def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary]: - ... + def boundary( + self, boundary_type: int, rank: int + ) -> Optional[bd.SimpleBoundary]: ... @abc.abstractmethod def tile_index(self, rank: int): diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index b8e7da39..0b5283e6 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -424,9 +424,9 @@ def __get__(self, obj, objtype=None) -> SDFGEnabledCallable: """Return SDFGEnabledCallable wrapping original obj.method from cache. Update cache first if need be""" if (id(obj), id(self.func)) not in _LazyComputepathMethod.bound_callables: - _LazyComputepathMethod.bound_callables[ - (id(obj), id(self.func)) - ] = _LazyComputepathMethod.SDFGEnabledCallable(self, obj) + _LazyComputepathMethod.bound_callables[(id(obj), id(self.func))] = ( + _LazyComputepathMethod.SDFGEnabledCallable(self, obj) + ) return _LazyComputepathMethod.bound_callables[(id(obj), id(self.func))] diff --git a/ndsl/dsl/dace/wrapped_halo_exchange.py b/ndsl/dsl/dace/wrapped_halo_exchange.py index 78a68fa4..0e90ccaa 100644 --- a/ndsl/dsl/dace/wrapped_halo_exchange.py +++ b/ndsl/dsl/dace/wrapped_halo_exchange.py @@ -32,8 +32,8 @@ def __init__( @staticmethod def check_for_attribute(state: Any, attr: str): if dataclasses.is_dataclass(state): - return state.__getattribute__(attr) - elif isinstance(state, dict): + return state.__getattribute__(attr) # type: ignore + if isinstance(state, dict): return attr in state.keys() return False diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index 7cefa5ab..1eea7cc0 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -359,14 +359,14 @@ def __init__( ): unblock_waiting_tiles(MPI.COMM_WORLD) - self._timing_collector.build_info[ - _stencil_object_name(self.stencil_object) - ] = build_info + self._timing_collector.build_info[_stencil_object_name(self.stencil_object)] = ( + build_info + ) field_info = self.stencil_object.field_info - self._field_origins: Dict[ - str, Tuple[int, ...] - ] = FrozenStencil._compute_field_origins(field_info, self.origin) + self._field_origins: Dict[str, Tuple[int, ...]] = ( + FrozenStencil._compute_field_origins(field_info, self.origin) + ) """mapping from field names to field origins""" self._stencil_run_kwargs: Dict[str, Any] = { diff --git a/ndsl/dsl/typing.py b/ndsl/dsl/typing.py index 5f60f401..33e01b10 100644 --- a/ndsl/dsl/typing.py +++ b/ndsl/dsl/typing.py @@ -38,7 +38,7 @@ def get_precision() -> int: def global_set_precision() -> Tuple[TypeAlias, TypeAlias]: """Set the global precision for all references of Float and Int in the codebase. Defaults to 64 bit.""" - global Float, Int + global Float, Int # noqa: F824 global ... is unused precision_in_bit = get_precision() if precision_in_bit == 64: return NDSL_64BIT_FLOAT_TYPE, NDSL_64BIT_INT_TYPE diff --git a/ndsl/grid/geometry.py b/ndsl/grid/geometry.py index 74441cde..c6862693 100644 --- a/ndsl/grid/geometry.py +++ b/ndsl/grid/geometry.py @@ -207,7 +207,7 @@ def calculate_supergrid_cos_sin( cos_sg[abs(1.0 - cos_sg) < 1e-15] = 1.0 - sin_sg_tmp = 1.0 - cos_sg ** 2 + sin_sg_tmp = 1.0 - cos_sg**2 sin_sg_tmp[sin_sg_tmp < 0] = 0.0 sin_sg = np.sqrt(sin_sg_tmp) sin_sg[sin_sg > 1.0] = 1.0 diff --git a/ndsl/grid/global_setup.py b/ndsl/grid/global_setup.py index 60bd3c3b..1f8f68dd 100644 --- a/ndsl/grid/global_setup.py +++ b/ndsl/grid/global_setup.py @@ -40,7 +40,7 @@ def gnomonic_grid(grid_type: int, lon, lat, np): # closer to the Fortran code def global_gnomonic_ed(lon, lat, np): im = lon.shape[0] - 1 - alpha = np.arcsin(3 ** -0.5) + alpha = np.arcsin(3**-0.5) dely = np.multiply(2.0, alpha) / float(im) pp = np.zeros((3, im + 1, im + 1)) @@ -68,16 +68,16 @@ def global_gnomonic_ed(lon, lat, np): i = 0 for j in range(1, im): pp[:, i, j] = _latlon2xyz(lon[i, j], lat[i, j], np) - pp[1, i, j] = -pp[1, i, j] * (3 ** -0.5) / pp[0, i, j] - pp[2, i, j] = -pp[2, i, j] * (3 ** -0.5) / pp[0, i, j] + pp[1, i, j] = -pp[1, i, j] * (3**-0.5) / pp[0, i, j] + pp[2, i, j] = -pp[2, i, j] * (3**-0.5) / pp[0, i, j] j = 0 for i in range(1, im): pp[:, i, j] = _latlon2xyz(lon[i, j], lat[i, j], np) - pp[1, i, j] = -pp[1, i, j] * (3 ** -0.5) / pp[0, i, j] - pp[2, i, j] = -pp[2, i, j] * (3 ** -0.5) / pp[0, i, j] + pp[1, i, j] = -pp[1, i, j] * (3**-0.5) / pp[0, i, j] + pp[2, i, j] = -pp[2, i, j] * (3**-0.5) / pp[0, i, j] - pp[0, :, :] = -(3 ** -0.5) + pp[0, :, :] = -(3**-0.5) for j in range(1, im + 1): # copy y-z face of the cube along j=0 pp[1, 1:, j] = pp[1, 1:, 0] diff --git a/ndsl/grid/gnomonic.py b/ndsl/grid/gnomonic.py index 0fc421ef..02d41e20 100644 --- a/ndsl/grid/gnomonic.py +++ b/ndsl/grid/gnomonic.py @@ -39,7 +39,7 @@ def local_gnomonic_ed( _check_shapes(lon, lat) # tile_im, wedge_dict, corner_dict, global_is, global_js im = lon.shape[0] - 1 - alpha = np.arcsin(3 ** -0.5) + alpha = np.arcsin(3**-0.5) tile_im = npx - 1 dely = np.multiply(2.0, alpha / float(tile_im)) halo = 3 @@ -96,10 +96,10 @@ def local_gnomonic_ed( lon_west_tile_edge[i, j], lat_west_tile_edge[i, j], np ) pp_west_tile_edge[1, i, j] = ( - -pp_west_tile_edge[1, i, j] * (3 ** -0.5) / pp_west_tile_edge[0, i, j] + -pp_west_tile_edge[1, i, j] * (3**-0.5) / pp_west_tile_edge[0, i, j] ) pp_west_tile_edge[2, i, j] = ( - -pp_west_tile_edge[2, i, j] * (3 ** -0.5) / pp_west_tile_edge[0, i, j] + -pp_west_tile_edge[2, i, j] * (3**-0.5) / pp_west_tile_edge[0, i, j] ) if west_edge: pp[:, 0, :] = pp_west_tile_edge[:, 0, :] @@ -110,10 +110,10 @@ def local_gnomonic_ed( lon_south_tile_edge[i, j], lat_south_tile_edge[i, j], np ) pp_south_tile_edge[1, i, j] = ( - -pp_south_tile_edge[1, i, j] * (3 ** -0.5) / pp_south_tile_edge[0, i, j] + -pp_south_tile_edge[1, i, j] * (3**-0.5) / pp_south_tile_edge[0, i, j] ) pp_south_tile_edge[2, i, j] = ( - -pp_south_tile_edge[2, i, j] * (3 ** -0.5) / pp_south_tile_edge[0, i, j] + -pp_south_tile_edge[2, i, j] * (3**-0.5) / pp_south_tile_edge[0, i, j] ) if south_edge: pp[:, :, 0] = pp_south_tile_edge[:, :, 0] @@ -138,7 +138,7 @@ def local_gnomonic_ed( if north_edge and east_edge: pp[:, im, im] = _latlon2xyz(lon_east, lat_north, np) - pp[0, :, :] = -(3 ** -0.5) + pp[0, :, :] = -(3**-0.5) for j in range(start_j, im + 1): # copy y-z face of the cube along j=0 pp[1, start_i:, j] = pp_south_tile_edge[1, start_i:, 0] # pp[1,:,0] @@ -167,14 +167,14 @@ def _corner_to_center_mean(corner_array): def normalize_vector(np, *vector_components): scale = np.divide( 1.0, - np.sum(np.asarray([item ** 2.0 for item in vector_components]), axis=0) ** 0.5, + np.sum(np.asarray([item**2.0 for item in vector_components]), axis=0) ** 0.5, ) return np.asarray([item * scale for item in vector_components]) def normalize_xyz(xyz): # double transpose to broadcast along last dimension instead of first - return (xyz.T / ((xyz ** 2).sum(axis=-1) ** 0.5).T).T + return (xyz.T / ((xyz**2).sum(axis=-1) ** 0.5).T).T def lon_lat_midpoint(lon1, lon2, lat1, lat2, np): @@ -606,7 +606,7 @@ def get_rectangle_area(p1, p2, p3, p4, radius, np): ) in ((p3, p2, p4), (p4, p3, p1), (p1, p4, p2)): total_angle += spherical_angle(q1, q2, q3, np) - return (total_angle - 2 * PI) * radius ** 2 + return (total_angle - 2 * PI) * radius**2 def get_triangle_area(p1, p2, p3, radius, np): @@ -618,7 +618,7 @@ def get_triangle_area(p1, p2, p3, radius, np): total_angle = spherical_angle(p1, p2, p3, np) for q1, q2, q3 in ((p2, p3, p1), (p3, p1, p2)): total_angle += spherical_angle(q1, q2, q3, np) - return (total_angle - PI) * radius ** 2 + return (total_angle - PI) * radius**2 def fortran_vector_spherical_angle(e1, e2, e3): @@ -678,8 +678,7 @@ def spherical_angle(p_center, p2, p3, np): p = np.cross(p_center, p2) q = np.cross(p_center, p3) angle = np.arccos( - np.sum(p * q, axis=-1) - / np.sqrt(np.sum(p ** 2, axis=-1) * np.sum(q ** 2, axis=-1)) + np.sum(p * q, axis=-1) / np.sqrt(np.sum(p**2, axis=-1) * np.sum(q**2, axis=-1)) ) if not np.isscalar(angle): angle[np.isnan(angle)] = 0.0 @@ -696,7 +695,7 @@ def spherical_cos(p_center, p2, p3, np): p = np.cross(p_center, p2) q = np.cross(p_center, p3) return np.sum(p * q, axis=-1) / np.sqrt( - np.sum(p ** 2, axis=-1) * np.sum(q ** 2, axis=-1) + np.sum(p**2, axis=-1) * np.sum(q**2, axis=-1) ) diff --git a/ndsl/grid/stretch_transformation.py b/ndsl/grid/stretch_transformation.py index 9a481ab9..8b0d554b 100644 --- a/ndsl/grid/stretch_transformation.py +++ b/ndsl/grid/stretch_transformation.py @@ -56,8 +56,8 @@ def direct_transform( lon_p, lat_p = np.deg2rad(lon_target), np.deg2rad(lat_target) sin_p, cos_p = np.sin(lat_p), np.cos(lat_p) - c2p1 = 1.0 + stretch_factor ** 2 - c2m1 = 1.0 - stretch_factor ** 2 + c2p1 = 1.0 + stretch_factor**2 + c2m1 = 1.0 - stretch_factor**2 # first limit longitude so it's between 0 and 2pi lon_data[lon_data < 0] += 2 * np.pi diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index 85ee17dd..081a9240 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -146,9 +146,11 @@ def _allocate( extent = self.sizer.get_extent(dims) shape = self.sizer.get_shape(dims) dimensions = [ - axis - if any(dim in axis_dims for axis_dims in SPATIAL_DIMS) - else str(shape[index]) + ( + axis + if any(dim in axis_dims for axis_dims in SPATIAL_DIMS) + else str(shape[index]) + ) for index, (dim, axis) in enumerate( zip(dims, ("I", "J", "K", *([None] * (len(dims) - 3)))) ) diff --git a/ndsl/logging.py b/ndsl/logging.py index 73b7979c..2ec6f979 100644 --- a/ndsl/logging.py +++ b/ndsl/logging.py @@ -62,9 +62,9 @@ def _ndsl_logger_on_rank_0() -> logging.Logger: return name_log -ndsl_log: Annotated[ - logging.Logger, "NDSL Python logger, logs on all rank" -] = _ndsl_logger() +ndsl_log: Annotated[logging.Logger, "NDSL Python logger, logs on all rank"] = ( + _ndsl_logger() +) ndsl_log_on_rank_0: Annotated[ logging.Logger, "NDSL Python logger, logs on rank 0 only" diff --git a/ndsl/monitor/protocol.py b/ndsl/monitor/protocol.py index e1448044..ec8bb5d5 100644 --- a/ndsl/monitor/protocol.py +++ b/ndsl/monitor/protocol.py @@ -12,8 +12,6 @@ def store(self, state: dict) -> None: """Append the model state dictionary to the stored data.""" ... - def store_constant(self, state: Dict[str, Quantity]) -> None: - ... + def store_constant(self, state: Dict[str, Quantity]) -> None: ... - def cleanup(self): - ... + def cleanup(self): ... diff --git a/ndsl/performance/collector.py b/ndsl/performance/collector.py index 8ec7a817..b64d51a2 100644 --- a/ndsl/performance/collector.py +++ b/ndsl/performance/collector.py @@ -25,21 +25,18 @@ class AbstractPerformanceCollector(Protocol): total_timer: Timer timestep_timer: Timer - def collect_performance(self): - ... + def collect_performance(self): ... def write_out_performance( self, backend: str, is_orchestrated: bool, dt_atmos: float, - ): - ... + ): ... def write_out_rank_0( self, backend: str, is_orchestrated: bool, dt_atmos: float, sim_status: str - ): - ... + ): ... @classmethod def start_cuda_profiler(cls): diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 2fc4f451..a4f7bd52 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -88,9 +88,11 @@ def __init__( dimensions: Tuple[Union[str, int], ...] = tuple( [ - axis - if any(dim in axis_dims for axis_dims in constants.SPATIAL_DIMS) - else str(data.shape[index]) + ( + axis + if any(dim in axis_dims for axis_dims in constants.SPATIAL_DIMS) + else str(data.shape[index]) + ) for index, (dim, axis) in enumerate( zip(dims, ("I", "J", "K", *([None] * (len(dims) - 3)))) ) diff --git a/ndsl/stencils/testing/savepoint.py b/ndsl/stencils/testing/savepoint.py index 2708011e..68a52fad 100644 --- a/ndsl/stencils/testing/savepoint.py +++ b/ndsl/stencils/testing/savepoint.py @@ -40,14 +40,11 @@ def load( class Translate(Protocol): - def collect_input_data(self, ds: xr.Dataset) -> dict: - ... + def collect_input_data(self, ds: xr.Dataset) -> dict: ... - def compute(self, data: dict): - ... + def compute(self, data: dict): ... - def extra_data_load(self, data_loader: DataLoader): - ... + def extra_data_load(self, data_loader: DataLoader): ... @dataclasses.dataclass diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index 18271a4a..02d00420 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -68,9 +68,9 @@ def process_override(threshold_overrides, testobj, test_name, backend): for key in testobj.out_vars.keys(): if key not in testobj.ignore_near_zero_errors: testobj.ignore_near_zero_errors[key] = {} - testobj.ignore_near_zero_errors[key][ - "near_zero" - ] = float(match["all_other_near_zero"]) + testobj.ignore_near_zero_errors[key]["near_zero"] = ( + float(match["all_other_near_zero"]) + ) else: raise TypeError( diff --git a/ndsl/testing/comparison.py b/ndsl/testing/comparison.py index 7862904d..ab4d881d 100644 --- a/ndsl/testing/comparison.py +++ b/ndsl/testing/comparison.py @@ -30,17 +30,13 @@ def __init__( self.computed = np.atleast_1d(computed_values) self.check = False - def __str__(self) -> str: - ... + def __str__(self) -> str: ... - def __repr__(self) -> str: - ... + def __repr__(self) -> str: ... - def report(self, file_path: Optional[str] = None) -> List[str]: - ... + def report(self, file_path: Optional[str] = None) -> List[str]: ... - def one_line_report(self) -> str: - ... + def one_line_report(self) -> str: ... class LegacyMetric(BaseMetric): diff --git a/ndsl/viz/fv3/_plot_cube.py b/ndsl/viz/fv3/_plot_cube.py index 8942d494..210ef5a7 100644 --- a/ndsl/viz/fv3/_plot_cube.py +++ b/ndsl/viz/fv3/_plot_cube.py @@ -205,7 +205,7 @@ def plot_cube( fig, ax = plt.subplots(1, 1, subplot_kw={"projection": projection}) else: fig = ax.figure - handle = _plot_func_short(array, ax=ax) + handle = _plot_func_short(array, ax=ax) # type: ignore axes = np.array(ax) handles = [handle] facet_grid = None diff --git a/ndsl/viz/fv3/_plot_diagnostics.py b/ndsl/viz/fv3/_plot_diagnostics.py index 9c102759..e54f709e 100644 --- a/ndsl/viz/fv3/_plot_diagnostics.py +++ b/ndsl/viz/fv3/_plot_diagnostics.py @@ -8,6 +8,7 @@ """ + import os import matplotlib.pyplot as plt diff --git a/ndsl/viz/fv3/grid_metadata.py b/ndsl/viz/fv3/grid_metadata.py index 2171360e..f44df8be 100644 --- a/ndsl/viz/fv3/grid_metadata.py +++ b/ndsl/viz/fv3/grid_metadata.py @@ -5,8 +5,7 @@ class GridMetadata(abc.ABC): @property @abc.abstractmethod - def coord_vars(self) -> dict: - ... + def coord_vars(self) -> dict: ... @dataclasses.dataclass diff --git a/setup.cfg b/setup.cfg index ce80d954..5b289889 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,16 +1,12 @@ [flake8] exclude = docs -ignore = E203,E501,W293,W503,E302,E203,F841 +ignore = E203,E302,E704,F841,W293,E501,W503 max-line-length = 88 [aliases] [tool:isort] -line_length = 88 -force_grid_wrap = 0 -include_trailing_comma = true -multi_line_output = 3 -use_parentheses = true +profile = black lines_after_imports = 2 default_section = THIRDPARTY sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index bec096dd..c7c864cf 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -18,7 +18,7 @@ def layout(): if MPI is not None: size = MPI.COMM_WORLD.Get_size() ranks_per_tile = size // 6 - ranks_per_edge = int(ranks_per_tile ** 0.5) + ranks_per_edge = int(ranks_per_tile**0.5) return (ranks_per_edge, ranks_per_edge) else: return (1, 1) diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index b6c38e95..c9b35ecc 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -40,7 +40,7 @@ def layout(): if MPI is not None: size = MPI.COMM_WORLD.Get_size() ranks_per_tile = size // 6 - ranks_per_edge = int(ranks_per_tile ** 0.5) + ranks_per_edge = int(ranks_per_tile**0.5) return (ranks_per_edge, ranks_per_edge) else: return (1, 1) diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index 6b441702..1e4113fd 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -304,7 +304,7 @@ def test_worker(comm, dummy_results, mpi_results, numpy): if isinstance(mpi, numpy.ndarray): numpy.testing.assert_array_equal(np.asarray(dummy), np.asarray(mpi)) elif isinstance(mpi, Exception): - assert type(dummy) == type(mpi) + assert type(dummy) is type(mpi) assert dummy.args == mpi.args else: assert dummy == mpi diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index 2b4954d1..0250a96b 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -160,7 +160,7 @@ def test_compute_view_edit_all_domain(quantity, n_halo, n_dims, extent_1d): pytest.skip("cannot edit an empty domain") quantity.data[:] = 0.0 quantity.view[:] = 1 - assert quantity.np.sum(quantity.data) == extent_1d ** n_dims + assert quantity.np.sum(quantity.data) == extent_1d**n_dims if n_dims > 1: quantity.np.testing.assert_array_equal(quantity.data[:n_halo, :], 0.0) quantity.np.testing.assert_array_equal( diff --git a/tests/test_g2g_communication.py b/tests/test_g2g_communication.py index 40595669..dab27cb3 100644 --- a/tests/test_g2g_communication.py +++ b/tests/test_g2g_communication.py @@ -1,7 +1,8 @@ -""" Test of the GPU to GPU communication strategy. +"""Test of the GPU to GPU communication strategy. Those test use halo_update but are separated from the entire """ + import contextlib import functools @@ -92,7 +93,7 @@ def gpu_communicators(cube_partitioner): @contextlib.contextmanager def module_count_calls_to_zeros(module): - global N_ZEROS_CALLS + global N_ZEROS_CALLS # noqa: F824 global ... is unused N_ZEROS_CALLS[module.zeros] = 0 def count_calls(func): @@ -100,7 +101,7 @@ def count_calls(func): @functools.wraps(func) def wrapped(*args, **kwargs): - global N_ZEROS_CALLS + global N_ZEROS_CALLS # noqa: F824 global ... is unused N_ZEROS_CALLS[func] = N_ZEROS_CALLS[func] + 1 return func(*args, **kwargs) @@ -135,7 +136,7 @@ def test_halo_update_only_communicate_on_gpu(backend, gpu_communicators): halo_updater.wait() # We expect no np calls and several cp calls - global N_ZEROS_CALLS + global N_ZEROS_CALLS # noqa: F824 global ... is unused print(f"Results {N_ZEROS_CALLS}") assert N_ZEROS_CALLS[cp.zeros] > 0 assert N_ZEROS_CALLS[np.zeros] == 0 @@ -165,6 +166,6 @@ def test_halo_update_communicate_though_cpu(backend, cpu_communicators): halo_updater.wait() # We expect several np calls and several cp calls - global N_ZEROS_CALLS + global N_ZEROS_CALLS # noqa: F824 global ... is unused assert N_ZEROS_CALLS[np.zeros] > 0 assert N_ZEROS_CALLS[cp.zeros] == 0