Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ndsl/buffer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import contextlib
from typing import Callable, Dict, Generator, Iterable, List, Optional, Tuple

Expand Down Expand Up @@ -40,7 +42,7 @@ def __init__(self, key: BufferKey, array: np.ndarray):
@classmethod
def pop_from_cache(
cls, allocator: Allocator, shape: Iterable[int], dtype: type
) -> "Buffer":
) -> Buffer:
"""Retrieve or insert then retrieve of buffer from cache.

Args:
Expand All @@ -61,7 +63,7 @@ def pop_from_cache(
return cls(key, array)

@staticmethod
def push_to_cache(buffer: "Buffer"):
def push_to_cache(buffer: Buffer):
"""Push the buffer back into the cache.

Args:
Expand Down
4 changes: 2 additions & 2 deletions ndsl/checkpointer/snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def store(self, savepoint_name: str, variable_name: str, python_data):
self._arrays[variable_name].append(python_data)

@property
def dataset(self) -> "xr.Dataset":
def dataset(self) -> xr.Dataset:
data_vars = {}
for variable_name, savepoint_list in self._savepoints.items():
savepoint_dim = f"sp_{variable_name}"
Expand All @@ -58,7 +58,7 @@ def __call__(self, savepoint_name, **kwargs):
self._snapshots.store(savepoint_name, name, array_data)

@property
def dataset(self) -> "xr.Dataset":
def dataset(self) -> xr.Dataset:
return self._snapshots.dataset

def cleanup(self):
Expand Down
8 changes: 4 additions & 4 deletions ndsl/checkpointer/thresholds.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import collections
import contextlib
import dataclasses
Expand All @@ -23,7 +25,7 @@ class Threshold:
relative: float
absolute: float

def merge(self, other: "Threshold") -> "Threshold":
def merge(self, other: Threshold) -> Threshold:
"""
Provide a threshold which is always satisfied
if both input thresholds are satisfied.
Expand Down Expand Up @@ -126,9 +128,7 @@ def trial(self):
self._n_trials += 1

@property
def thresholds(
self,
) -> SavepointThresholds:
def thresholds(self) -> SavepointThresholds:
if self._n_trials < 2:
raise InsufficientTrialsError(
"at least 2 trials required to generate thresholds"
Expand Down
10 changes: 6 additions & 4 deletions ndsl/comm/caching_comm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import copy
import dataclasses
import pickle
Expand Down Expand Up @@ -82,7 +84,7 @@ def dump(self, file: BinaryIO):
pickle.dump(self, file)

@classmethod
def load(self, file: BinaryIO) -> "CachingCommData":
def load(self, file: BinaryIO) -> CachingCommData:
return pickle.load(file)


Expand Down Expand Up @@ -143,7 +145,7 @@ def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request:
def sendrecv(self, sendbuf, dest, **kwargs):
raise NotImplementedError()

def Split(self, color, key) -> "CachingCommReader":
def Split(self, color, key) -> CachingCommReader:
new_data = self._data.get_split()
return CachingCommReader(data=new_data)

Expand All @@ -154,7 +156,7 @@ def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any:
raise NotImplementedError("CachingCommReader.Allreduce")

@classmethod
def load(cls, file: BinaryIO) -> "CachingCommReader":
def load(cls, file: BinaryIO) -> CachingCommReader:
data = CachingCommData.load(file)
return cls(data)

Expand Down Expand Up @@ -223,7 +225,7 @@ def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request:
def sendrecv(self, sendbuf, dest, **kwargs):
raise NotImplementedError()

def Split(self, color, key) -> "CachingCommWriter":
def Split(self, color, key) -> CachingCommWriter:
new_comm = self._comm.Split(color=color, key=key)
new_wrapper = CachingCommWriter(new_comm)
self._data.split_data.append(new_wrapper._data)
Expand Down
4 changes: 3 additions & 1 deletion ndsl/comm/comm_abc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import abc
import enum
from typing import List, Optional, TypeVar
Expand Down Expand Up @@ -85,7 +87,7 @@ def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request:
...

@abc.abstractmethod
def Split(self, color, key) -> "Comm":
def Split(self, color, key) -> Comm:
...

@abc.abstractmethod
Expand Down
8 changes: 5 additions & 3 deletions ndsl/comm/communicator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import abc
from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast

Expand Down Expand Up @@ -55,7 +57,7 @@ def __init__(
self.timer: Timer = timer if timer is not None else NullTimer()

@abc.abstractproperty
def tile(self) -> "TileCommunicator":
def tile(self) -> TileCommunicator:
pass

@classmethod
Expand Down Expand Up @@ -661,7 +663,7 @@ def from_layout(
layout: Tuple[int, int],
force_cpu: bool = False,
timer: Optional[Timer] = None,
) -> "TileCommunicator":
) -> TileCommunicator:
partitioner = TilePartitioner(layout=layout)
return cls(comm=comm, partitioner=partitioner, force_cpu=force_cpu, timer=timer)

Expand Down Expand Up @@ -794,7 +796,7 @@ def from_layout(
layout: Tuple[int, int],
force_cpu: bool = False,
timer: Optional[Timer] = None,
) -> "CubedSphereCommunicator":
) -> CubedSphereCommunicator:
partitioner = CubedSpherePartitioner(tile=TilePartitioner(layout=layout))
return cls(comm=comm, partitioner=partitioner, force_cpu=force_cpu, timer=timer)

Expand Down
2 changes: 1 addition & 1 deletion ndsl/comm/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def Recv(self, recvbuf, source, tag: int = 0, **kwargs):
def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request:
return self._comm.Irecv(recvbuf, source, tag=tag, **kwargs)

def Split(self, color, key) -> "Comm":
def Split(self, color, key) -> Comm:
return self._comm.Split(color, key)

def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T:
Expand Down
2 changes: 1 addition & 1 deletion ndsl/dsl/dace/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_sdfg_path(
return sdfg_dir_path


def set_distributed_caches(config: "DaceConfig"):
def set_distributed_caches(config: DaceConfig):
"""In Run mode, check required file then point current rank cache to source cache"""

# Execute specific initialization per orchestration state
Expand Down
4 changes: 3 additions & 1 deletion ndsl/dsl/dace/dace_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import enum
import os
from typing import Any, Dict, Optional, Tuple
Expand Down Expand Up @@ -56,7 +58,7 @@ def _smallest_rank_middle(x: int, y: int, layout: Tuple[int, int]):


def _determine_compiling_ranks(
config: "DaceConfig",
config: DaceConfig,
partitioner: Partitioner,
) -> bool:
"""
Expand Down
4 changes: 3 additions & 1 deletion ndsl/dsl/dace/orchestration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -370,7 +372,7 @@ class _LazyComputepathMethod:
bound_callables: Dict[Tuple[int, int], "SDFGEnabledCallable"] = dict()

class SDFGEnabledCallable(SDFGConvertible):
def __init__(self, lazy_method: "_LazyComputepathMethod", obj_to_bind):
def __init__(self, lazy_method: _LazyComputepathMethod, obj_to_bind):
methodwrapper = dace.method(lazy_method.func)
self.obj_to_bind = obj_to_bind
self.lazy_method = lazy_method
Expand Down
8 changes: 5 additions & 3 deletions ndsl/dsl/stencil.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import copy
import dataclasses
import inspect
Expand Down Expand Up @@ -605,7 +607,7 @@ def domain(self, domain):
@classmethod
def from_sizer_and_communicator(
cls, sizer: GridSizer, comm: Communicator
) -> "GridIndexing":
) -> GridIndexing:
# TODO: if this class is refactored to split off the *_edge booleans,
# this init routine can be refactored to require only a GridSizer
domain = cast(
Expand Down Expand Up @@ -831,7 +833,7 @@ def get_shape(
shape[i] += n
return tuple(shape)

def restrict_vertical(self, k_start=0, nk=None) -> "GridIndexing":
def restrict_vertical(self, k_start=0, nk=None) -> GridIndexing:
"""
Returns a copy of itself with modified vertical origin and domain.

Expand Down Expand Up @@ -969,7 +971,7 @@ def from_dims_halo(
skip_passes=skip_passes,
)

def restrict_vertical(self, k_start=0, nk=None) -> "StencilFactory":
def restrict_vertical(self, k_start=0, nk=None) -> StencilFactory:
return StencilFactory(
config=self.config,
grid_indexing=self.grid_indexing.restrict_vertical(k_start=k_start, nk=nk),
Expand Down
6 changes: 4 additions & 2 deletions ndsl/grid/generation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import dataclasses
import functools
import warnings
Expand Down Expand Up @@ -465,7 +467,7 @@ def from_external(
communicator,
grid_type,
eta_file: str = "None",
) -> "MetricTerms":
) -> MetricTerms:
"""
Generates a metric terms object, using input from data contained in an
externally generated tile file
Expand Down Expand Up @@ -499,7 +501,7 @@ def from_tile_sizing(
dy_const: float = 1000.0,
deglat: float = 15.0,
eta_file: str = "None",
) -> "MetricTerms":
) -> MetricTerms:
sizer = SubtileGridSizer.from_tile_params(
nx_tile=npx - 1,
ny_tile=npy - 1,
Expand Down
16 changes: 8 additions & 8 deletions ndsl/grid/helper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import dataclasses
import pathlib

Expand Down Expand Up @@ -86,7 +88,7 @@ class HorizontalGridData:
edge_n: Quantity

@classmethod
def new_from_metric_terms(cls, metric_terms: MetricTerms) -> "HorizontalGridData":
def new_from_metric_terms(cls, metric_terms: MetricTerms) -> HorizontalGridData:
return cls(
lon=metric_terms.lon,
lat=metric_terms.lat,
Expand Down Expand Up @@ -145,7 +147,7 @@ def __post_init__(self):
self._p_interface = None

@classmethod
def new_from_metric_terms(cls, metric_terms: MetricTerms) -> "VerticalGridData":
def new_from_metric_terms(cls, metric_terms: MetricTerms) -> VerticalGridData:
return cls(
ak=metric_terms.ak,
bk=metric_terms.bk,
Expand Down Expand Up @@ -258,9 +260,7 @@ class ContravariantGridData:
rsin2: Quantity

@classmethod
def new_from_metric_terms(
cls, metric_terms: MetricTerms
) -> "ContravariantGridData":
def new_from_metric_terms(cls, metric_terms: MetricTerms) -> ContravariantGridData:
return cls(
cosa=metric_terms.cosa,
cosa_u=metric_terms.cosa_u,
Expand Down Expand Up @@ -303,7 +303,7 @@ class AngleGridData:
cos_sg9: Quantity

@classmethod
def new_from_metric_terms(cls, metric_terms: MetricTerms) -> "AngleGridData":
def new_from_metric_terms(cls, metric_terms: MetricTerms) -> AngleGridData:
return cls(
sin_sg1=metric_terms.sin_sg1,
sin_sg2=metric_terms.sin_sg2,
Expand Down Expand Up @@ -752,7 +752,7 @@ class DriverGridData:
grid_type: int

@classmethod
def new_from_metric_terms(cls, metric_terms: MetricTerms) -> "DriverGridData":
def new_from_metric_terms(cls, metric_terms: MetricTerms) -> DriverGridData:
return cls.new_from_grid_variables(
vlon=metric_terms.vlon,
vlat=metric_terms.vlon,
Expand All @@ -777,7 +777,7 @@ def new_from_grid_variables(
es1: Quantity,
ew2: Quantity,
grid_type: int = 0,
) -> "DriverGridData":
) -> DriverGridData:
try:
vlon1, vlon2, vlon3 = split_quantity_along_last_dim(vlon)
vlat1, vlat2, vlat3 = split_quantity_along_last_dim(vlat)
Expand Down
4 changes: 3 additions & 1 deletion ndsl/halo/data_transformer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import abc
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -238,7 +240,7 @@ def get(
np_module: NumpyModule,
exchange_descriptors_x: Sequence[HaloExchangeSpec],
exchange_descriptors_y: Optional[Sequence[HaloExchangeSpec]] = None,
) -> "HaloDataTransformer":
) -> HaloDataTransformer:
"""Construct a module from a numpy-like module.

Args:
Expand Down
12 changes: 7 additions & 5 deletions ndsl/halo/updater.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple

Expand Down Expand Up @@ -43,7 +45,7 @@ class HaloUpdater:

def __init__(
self,
comm: "Communicator",
comm: Communicator,
tag: int,
transformers: Dict[int, HaloDataTransformer],
timer: Timer,
Expand Down Expand Up @@ -90,13 +92,13 @@ def __del__(self):
@classmethod
def from_scalar_specifications(
cls,
comm: "Communicator",
comm: Communicator,
numpy_like_module: NumpyModule,
specifications: Iterable[QuantityHaloSpec],
boundaries: Iterable[Boundary],
tag: int,
optional_timer: Optional[Timer] = None,
) -> "HaloUpdater":
) -> HaloUpdater:
"""
Create/retrieve as many packed buffer as needed and
queue the slices to exchange.
Expand Down Expand Up @@ -142,14 +144,14 @@ def from_scalar_specifications(
@classmethod
def from_vector_specifications(
cls,
comm: "Communicator",
comm: Communicator,
numpy_like_module: NumpyModule,
specifications_x: Iterable[QuantityHaloSpec],
specifications_y: Iterable[QuantityHaloSpec],
boundaries: Iterable[Boundary],
tag: int,
optional_timer: Optional[Timer] = None,
) -> "HaloUpdater":
) -> HaloUpdater:
"""
Create/retrieve as many packed buffer as needed and queue
the slices to exchange.
Expand Down
Loading