diff --git a/ndsl/buffer.py b/ndsl/buffer.py index 05cd6434..8f0c90fd 100644 --- a/ndsl/buffer.py +++ b/ndsl/buffer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib from typing import Callable, Dict, Generator, Iterable, List, Optional, Tuple @@ -40,7 +42,7 @@ def __init__(self, key: BufferKey, array: np.ndarray): @classmethod def pop_from_cache( cls, allocator: Allocator, shape: Iterable[int], dtype: type - ) -> "Buffer": + ) -> Buffer: """Retrieve or insert then retrieve of buffer from cache. Args: @@ -61,7 +63,7 @@ def pop_from_cache( return cls(key, array) @staticmethod - def push_to_cache(buffer: "Buffer"): + def push_to_cache(buffer: Buffer): """Push the buffer back into the cache. Args: diff --git a/ndsl/checkpointer/snapshots.py b/ndsl/checkpointer/snapshots.py index 1447c5fb..37f6dde3 100644 --- a/ndsl/checkpointer/snapshots.py +++ b/ndsl/checkpointer/snapshots.py @@ -31,7 +31,7 @@ def store(self, savepoint_name: str, variable_name: str, python_data): self._arrays[variable_name].append(python_data) @property - def dataset(self) -> "xr.Dataset": + def dataset(self) -> xr.Dataset: data_vars = {} for variable_name, savepoint_list in self._savepoints.items(): savepoint_dim = f"sp_{variable_name}" @@ -58,7 +58,7 @@ def __call__(self, savepoint_name, **kwargs): self._snapshots.store(savepoint_name, name, array_data) @property - def dataset(self) -> "xr.Dataset": + def dataset(self) -> xr.Dataset: return self._snapshots.dataset def cleanup(self): diff --git a/ndsl/checkpointer/thresholds.py b/ndsl/checkpointer/thresholds.py index fbf0e956..556bc0a9 100644 --- a/ndsl/checkpointer/thresholds.py +++ b/ndsl/checkpointer/thresholds.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections import contextlib import dataclasses @@ -23,7 +25,7 @@ class Threshold: relative: float absolute: float - def merge(self, other: "Threshold") -> "Threshold": + def merge(self, other: Threshold) -> Threshold: """ Provide a threshold which is always satisfied if both input thresholds are satisfied. @@ -126,9 +128,7 @@ def trial(self): self._n_trials += 1 @property - def thresholds( - self, - ) -> SavepointThresholds: + def thresholds(self) -> SavepointThresholds: if self._n_trials < 2: raise InsufficientTrialsError( "at least 2 trials required to generate thresholds" diff --git a/ndsl/comm/caching_comm.py b/ndsl/comm/caching_comm.py index 42f92ea2..75c4d2b8 100644 --- a/ndsl/comm/caching_comm.py +++ b/ndsl/comm/caching_comm.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import dataclasses import pickle @@ -82,7 +84,7 @@ def dump(self, file: BinaryIO): pickle.dump(self, file) @classmethod - def load(self, file: BinaryIO) -> "CachingCommData": + def load(self, file: BinaryIO) -> CachingCommData: return pickle.load(file) @@ -143,7 +145,7 @@ def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request: def sendrecv(self, sendbuf, dest, **kwargs): raise NotImplementedError() - def Split(self, color, key) -> "CachingCommReader": + def Split(self, color, key) -> CachingCommReader: new_data = self._data.get_split() return CachingCommReader(data=new_data) @@ -154,7 +156,7 @@ def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any: raise NotImplementedError("CachingCommReader.Allreduce") @classmethod - def load(cls, file: BinaryIO) -> "CachingCommReader": + def load(cls, file: BinaryIO) -> CachingCommReader: data = CachingCommData.load(file) return cls(data) @@ -223,7 +225,7 @@ def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request: def sendrecv(self, sendbuf, dest, **kwargs): raise NotImplementedError() - def Split(self, color, key) -> "CachingCommWriter": + def Split(self, color, key) -> CachingCommWriter: new_comm = self._comm.Split(color=color, key=key) new_wrapper = CachingCommWriter(new_comm) self._data.split_data.append(new_wrapper._data) diff --git a/ndsl/comm/comm_abc.py b/ndsl/comm/comm_abc.py index 45596f1e..a3cdc897 100644 --- a/ndsl/comm/comm_abc.py +++ b/ndsl/comm/comm_abc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc import enum from typing import List, Optional, TypeVar @@ -85,7 +87,7 @@ def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request: ... @abc.abstractmethod - def Split(self, color, key) -> "Comm": + def Split(self, color, key) -> Comm: ... @abc.abstractmethod diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index c523affb..b4ea9a0c 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast @@ -55,7 +57,7 @@ def __init__( self.timer: Timer = timer if timer is not None else NullTimer() @abc.abstractproperty - def tile(self) -> "TileCommunicator": + def tile(self) -> TileCommunicator: pass @classmethod @@ -661,7 +663,7 @@ def from_layout( layout: Tuple[int, int], force_cpu: bool = False, timer: Optional[Timer] = None, - ) -> "TileCommunicator": + ) -> TileCommunicator: partitioner = TilePartitioner(layout=layout) return cls(comm=comm, partitioner=partitioner, force_cpu=force_cpu, timer=timer) @@ -794,7 +796,7 @@ def from_layout( layout: Tuple[int, int], force_cpu: bool = False, timer: Optional[Timer] = None, - ) -> "CubedSphereCommunicator": + ) -> CubedSphereCommunicator: partitioner = CubedSpherePartitioner(tile=TilePartitioner(layout=layout)) return cls(comm=comm, partitioner=partitioner, force_cpu=force_cpu, timer=timer) diff --git a/ndsl/comm/mpi.py b/ndsl/comm/mpi.py index 0b4a5540..076ef1d0 100644 --- a/ndsl/comm/mpi.py +++ b/ndsl/comm/mpi.py @@ -71,7 +71,7 @@ def Recv(self, recvbuf, source, tag: int = 0, **kwargs): def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request: return self._comm.Irecv(recvbuf, source, tag=tag, **kwargs) - def Split(self, color, key) -> "Comm": + def Split(self, color, key) -> Comm: return self._comm.Split(color, key) def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: diff --git a/ndsl/dsl/dace/build.py b/ndsl/dsl/dace/build.py index 90d36a46..25c30d00 100644 --- a/ndsl/dsl/dace/build.py +++ b/ndsl/dsl/dace/build.py @@ -106,7 +106,7 @@ def get_sdfg_path( return sdfg_dir_path -def set_distributed_caches(config: "DaceConfig"): +def set_distributed_caches(config: DaceConfig): """In Run mode, check required file then point current rank cache to source cache""" # Execute specific initialization per orchestration state diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 27f17375..8f3d2688 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import enum import os from typing import Any, Dict, Optional, Tuple @@ -56,7 +58,7 @@ def _smallest_rank_middle(x: int, y: int, layout: Tuple[int, int]): def _determine_compiling_ranks( - config: "DaceConfig", + config: DaceConfig, partitioner: Partitioner, ) -> bool: """ diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 5c30367f..b8e7da39 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -370,7 +372,7 @@ class _LazyComputepathMethod: bound_callables: Dict[Tuple[int, int], "SDFGEnabledCallable"] = dict() class SDFGEnabledCallable(SDFGConvertible): - def __init__(self, lazy_method: "_LazyComputepathMethod", obj_to_bind): + def __init__(self, lazy_method: _LazyComputepathMethod, obj_to_bind): methodwrapper = dace.method(lazy_method.func) self.obj_to_bind = obj_to_bind self.lazy_method = lazy_method diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index d829bbd0..7cefa5ab 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import dataclasses import inspect @@ -605,7 +607,7 @@ def domain(self, domain): @classmethod def from_sizer_and_communicator( cls, sizer: GridSizer, comm: Communicator - ) -> "GridIndexing": + ) -> GridIndexing: # TODO: if this class is refactored to split off the *_edge booleans, # this init routine can be refactored to require only a GridSizer domain = cast( @@ -831,7 +833,7 @@ def get_shape( shape[i] += n return tuple(shape) - def restrict_vertical(self, k_start=0, nk=None) -> "GridIndexing": + def restrict_vertical(self, k_start=0, nk=None) -> GridIndexing: """ Returns a copy of itself with modified vertical origin and domain. @@ -969,7 +971,7 @@ def from_dims_halo( skip_passes=skip_passes, ) - def restrict_vertical(self, k_start=0, nk=None) -> "StencilFactory": + def restrict_vertical(self, k_start=0, nk=None) -> StencilFactory: return StencilFactory( config=self.config, grid_indexing=self.grid_indexing.restrict_vertical(k_start=k_start, nk=nk), diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index 7b28c2ff..4cbc04ac 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import functools import warnings @@ -465,7 +467,7 @@ def from_external( communicator, grid_type, eta_file: str = "None", - ) -> "MetricTerms": + ) -> MetricTerms: """ Generates a metric terms object, using input from data contained in an externally generated tile file @@ -499,7 +501,7 @@ def from_tile_sizing( dy_const: float = 1000.0, deglat: float = 15.0, eta_file: str = "None", - ) -> "MetricTerms": + ) -> MetricTerms: sizer = SubtileGridSizer.from_tile_params( nx_tile=npx - 1, ny_tile=npy - 1, diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py index 2fbc34a3..742d3343 100644 --- a/ndsl/grid/helper.py +++ b/ndsl/grid/helper.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import pathlib @@ -86,7 +88,7 @@ class HorizontalGridData: edge_n: Quantity @classmethod - def new_from_metric_terms(cls, metric_terms: MetricTerms) -> "HorizontalGridData": + def new_from_metric_terms(cls, metric_terms: MetricTerms) -> HorizontalGridData: return cls( lon=metric_terms.lon, lat=metric_terms.lat, @@ -145,7 +147,7 @@ def __post_init__(self): self._p_interface = None @classmethod - def new_from_metric_terms(cls, metric_terms: MetricTerms) -> "VerticalGridData": + def new_from_metric_terms(cls, metric_terms: MetricTerms) -> VerticalGridData: return cls( ak=metric_terms.ak, bk=metric_terms.bk, @@ -258,9 +260,7 @@ class ContravariantGridData: rsin2: Quantity @classmethod - def new_from_metric_terms( - cls, metric_terms: MetricTerms - ) -> "ContravariantGridData": + def new_from_metric_terms(cls, metric_terms: MetricTerms) -> ContravariantGridData: return cls( cosa=metric_terms.cosa, cosa_u=metric_terms.cosa_u, @@ -303,7 +303,7 @@ class AngleGridData: cos_sg9: Quantity @classmethod - def new_from_metric_terms(cls, metric_terms: MetricTerms) -> "AngleGridData": + def new_from_metric_terms(cls, metric_terms: MetricTerms) -> AngleGridData: return cls( sin_sg1=metric_terms.sin_sg1, sin_sg2=metric_terms.sin_sg2, @@ -752,7 +752,7 @@ class DriverGridData: grid_type: int @classmethod - def new_from_metric_terms(cls, metric_terms: MetricTerms) -> "DriverGridData": + def new_from_metric_terms(cls, metric_terms: MetricTerms) -> DriverGridData: return cls.new_from_grid_variables( vlon=metric_terms.vlon, vlat=metric_terms.vlon, @@ -777,7 +777,7 @@ def new_from_grid_variables( es1: Quantity, ew2: Quantity, grid_type: int = 0, - ) -> "DriverGridData": + ) -> DriverGridData: try: vlon1, vlon2, vlon3 = split_quantity_along_last_dim(vlon) vlat1, vlat2, vlat3 = split_quantity_along_last_dim(vlat) diff --git a/ndsl/halo/data_transformer.py b/ndsl/halo/data_transformer.py index f3133974..56f3063b 100644 --- a/ndsl/halo/data_transformer.py +++ b/ndsl/halo/data_transformer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc from dataclasses import dataclass from enum import Enum @@ -238,7 +240,7 @@ def get( np_module: NumpyModule, exchange_descriptors_x: Sequence[HaloExchangeSpec], exchange_descriptors_y: Optional[Sequence[HaloExchangeSpec]] = None, - ) -> "HaloDataTransformer": + ) -> HaloDataTransformer: """Construct a module from a numpy-like module. Args: diff --git a/ndsl/halo/updater.py b/ndsl/halo/updater.py index 76f7608f..36d57e47 100644 --- a/ndsl/halo/updater.py +++ b/ndsl/halo/updater.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections import defaultdict from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple @@ -43,7 +45,7 @@ class HaloUpdater: def __init__( self, - comm: "Communicator", + comm: Communicator, tag: int, transformers: Dict[int, HaloDataTransformer], timer: Timer, @@ -90,13 +92,13 @@ def __del__(self): @classmethod def from_scalar_specifications( cls, - comm: "Communicator", + comm: Communicator, numpy_like_module: NumpyModule, specifications: Iterable[QuantityHaloSpec], boundaries: Iterable[Boundary], tag: int, optional_timer: Optional[Timer] = None, - ) -> "HaloUpdater": + ) -> HaloUpdater: """ Create/retrieve as many packed buffer as needed and queue the slices to exchange. @@ -142,14 +144,14 @@ def from_scalar_specifications( @classmethod def from_vector_specifications( cls, - comm: "Communicator", + comm: Communicator, numpy_like_module: NumpyModule, specifications_x: Iterable[QuantityHaloSpec], specifications_y: Iterable[QuantityHaloSpec], boundaries: Iterable[Boundary], tag: int, optional_timer: Optional[Timer] = None, - ) -> "HaloUpdater": + ) -> HaloUpdater: """ Create/retrieve as many packed buffer as needed and queue the slices to exchange. diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 6d83cd89..2fc4f451 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings from typing import Any, Iterable, Optional, Sequence, Tuple, Union, cast @@ -133,7 +135,7 @@ def from_data_array( origin: Sequence[int] = None, extent: Sequence[int] = None, gt4py_backend: Union[str, None] = None, - ) -> "Quantity": + ) -> Quantity: """ Initialize a Quantity from an xarray.DataArray. @@ -311,7 +313,7 @@ def transpose( self, target_dims: Sequence[Union[str, Iterable[str]]], allow_mismatch_float_precision: bool = False, - ) -> "Quantity": + ) -> Quantity: """Change the dimension order of this Quantity. Args: diff --git a/ndsl/stencils/testing/grid.py b/ndsl/stencils/testing/grid.py index 6d4b4b38..3241bd39 100644 --- a/ndsl/stencils/testing/grid.py +++ b/ndsl/stencils/testing/grid.py @@ -449,7 +449,7 @@ def get_halo_update_spec( ) @property - def grid_indexing(self) -> "GridIndexing": + def grid_indexing(self) -> GridIndexing: return GridIndexing( domain=tuple(int(item) for item in self.domain_shape_compute()), n_halo=self.halo, @@ -460,7 +460,7 @@ def grid_indexing(self) -> "GridIndexing": ) @property - def damping_coefficients(self) -> "DampingCoefficients": + def damping_coefficients(self) -> DampingCoefficients: if self._damping_coefficients is not None: return self._damping_coefficients self._damping_coefficients = DampingCoefficients( @@ -473,11 +473,11 @@ def damping_coefficients(self) -> "DampingCoefficients": ) return self._damping_coefficients - def set_damping_coefficients(self, damping_coefficients: "DampingCoefficients"): + def set_damping_coefficients(self, damping_coefficients: DampingCoefficients): self._damping_coefficients = damping_coefficients @property - def grid_data(self) -> "GridData": + def grid_data(self) -> GridData: if self._grid_data is not None: return self._grid_data @@ -835,7 +835,7 @@ def driver_grid_data(self) -> DriverGridData: ) return self._driver_grid_data - def set_grid_data(self, grid_data: "GridData"): + def set_grid_data(self, grid_data: GridData): self._grid_data = grid_data def make_grid_data(self, npx, npy, npz, communicator, backend): diff --git a/ndsl/stencils/testing/translate.py b/ndsl/stencils/testing/translate.py index 0acee958..9f6693c2 100644 --- a/ndsl/stencils/testing/translate.py +++ b/ndsl/stencils/testing/translate.py @@ -129,7 +129,7 @@ def make_storage_data( names_4d: Optional[List[str]] = None, read_only: bool = False, full_shape: bool = False, - ) -> "Field": + ) -> Field: """Copy input data into a gt4py.storage with given shape. `array` is copied. Takes care of the device upload if necessary. diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index a67b6599..b3847ee4 100644 --- a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -324,7 +324,7 @@ def test_array_chunks(layout, tile_array_shape, array_dims, target): assert result == target -def _assert_no_nulls(dataset: "xr.Dataset"): +def _assert_no_nulls(dataset: xr.Dataset): number_of_null = dataset["var"].isnull().sum().item() total_size = dataset["var"].size