Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ default_language_version:

repos:
- repo: https://github.com/psf/black
rev: 20.8b1
rev: 25.1.0
hooks:
- id: black
additional_dependencies: ["click==8.0.4"]

- repo: https://github.com/pre-commit/mirrors-isort
rev: v5.4.2
rev: v5.10.1
hooks:
- id: isort
args: ["--profile", "black"]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.4.1
rev: v1.16.1
hooks:
- id: mypy
name: mypy-ndsl
Expand All @@ -27,14 +27,14 @@ repos:
ndsl/ndsl/gt4py_utils.py |
)$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
rev: v5.0.0
hooks:
- id: check-toml
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/pycqa/flake8
rev: 3.9.2
rev: 7.3.0
hooks:
- id: flake8
name: flake8
Expand Down
3 changes: 1 addition & 2 deletions ndsl/checkpointer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,4 @@

class Checkpointer(abc.ABC):
@abc.abstractmethod
def __call__(self, savepoint_name, **kwargs):
...
def __call__(self, savepoint_name, **kwargs): ...
54 changes: 18 additions & 36 deletions ndsl/comm/comm_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,74 +29,56 @@ class ReductionOperator(enum.Enum):

class Request(abc.ABC):
@abc.abstractmethod
def wait(self):
...
def wait(self): ...


class Comm(abc.ABC):
@abc.abstractmethod
def Get_rank(self) -> int:
...
def Get_rank(self) -> int: ...

@abc.abstractmethod
def Get_size(self) -> int:
...
def Get_size(self) -> int: ...

@abc.abstractmethod
def bcast(self, value: Optional[T], root=0) -> T:
...
def bcast(self, value: Optional[T], root=0) -> T: ...

@abc.abstractmethod
def barrier(self):
...
def barrier(self): ...

@abc.abstractmethod
def Barrier(self):
...
def Barrier(self): ...

@abc.abstractmethod
def Scatter(self, sendbuf, recvbuf, root=0, **kwargs):
...
def Scatter(self, sendbuf, recvbuf, root=0, **kwargs): ...

@abc.abstractmethod
def Gather(self, sendbuf, recvbuf, root=0, **kwargs):
...
def Gather(self, sendbuf, recvbuf, root=0, **kwargs): ...

@abc.abstractmethod
def allgather(self, sendobj: T) -> List[T]:
...
def allgather(self, sendobj: T) -> List[T]: ...

@abc.abstractmethod
def Send(self, sendbuf, dest, tag: int = 0, **kwargs):
...
def Send(self, sendbuf, dest, tag: int = 0, **kwargs): ...

@abc.abstractmethod
def sendrecv(self, sendbuf, dest, **kwargs):
...
def sendrecv(self, sendbuf, dest, **kwargs): ...

@abc.abstractmethod
def Isend(self, sendbuf, dest, tag: int = 0, **kwargs) -> Request:
...
def Isend(self, sendbuf, dest, tag: int = 0, **kwargs) -> Request: ...

@abc.abstractmethod
def Recv(self, recvbuf, source, tag: int = 0, **kwargs):
...
def Recv(self, recvbuf, source, tag: int = 0, **kwargs): ...

@abc.abstractmethod
def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request:
...
def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request: ...

@abc.abstractmethod
def Split(self, color, key) -> Comm:
...
def Split(self, color, key) -> Comm: ...

@abc.abstractmethod
def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T:
...
def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: ...

@abc.abstractmethod
def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T:
...
def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: ...

def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T:
...
def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T: ...
5 changes: 3 additions & 2 deletions ndsl/comm/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ def __init__(self):
self.layout = None

@abc.abstractmethod
def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary]:
...
def boundary(
self, boundary_type: int, rank: int
) -> Optional[bd.SimpleBoundary]: ...

@abc.abstractmethod
def tile_index(self, rank: int):
Expand Down
6 changes: 3 additions & 3 deletions ndsl/dsl/dace/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,9 @@ def __get__(self, obj, objtype=None) -> SDFGEnabledCallable:
"""Return SDFGEnabledCallable wrapping original obj.method from cache.
Update cache first if need be"""
if (id(obj), id(self.func)) not in _LazyComputepathMethod.bound_callables:
_LazyComputepathMethod.bound_callables[
(id(obj), id(self.func))
] = _LazyComputepathMethod.SDFGEnabledCallable(self, obj)
_LazyComputepathMethod.bound_callables[(id(obj), id(self.func))] = (
_LazyComputepathMethod.SDFGEnabledCallable(self, obj)
)
Comment thread
romanc marked this conversation as resolved.

return _LazyComputepathMethod.bound_callables[(id(obj), id(self.func))]

Expand Down
4 changes: 2 additions & 2 deletions ndsl/dsl/dace/wrapped_halo_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def __init__(
@staticmethod
def check_for_attribute(state: Any, attr: str):
if dataclasses.is_dataclass(state):
return state.__getattribute__(attr)
elif isinstance(state, dict):
return state.__getattribute__(attr) # type: ignore
if isinstance(state, dict):
return attr in state.keys()
return False

Expand Down
12 changes: 6 additions & 6 deletions ndsl/dsl/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,14 +359,14 @@ def __init__(
):
unblock_waiting_tiles(MPI.COMM_WORLD)

self._timing_collector.build_info[
_stencil_object_name(self.stencil_object)
] = build_info
self._timing_collector.build_info[_stencil_object_name(self.stencil_object)] = (
build_info
)
field_info = self.stencil_object.field_info

self._field_origins: Dict[
str, Tuple[int, ...]
] = FrozenStencil._compute_field_origins(field_info, self.origin)
self._field_origins: Dict[str, Tuple[int, ...]] = (
FrozenStencil._compute_field_origins(field_info, self.origin)
)
"""mapping from field names to field origins"""

self._stencil_run_kwargs: Dict[str, Any] = {
Expand Down
2 changes: 1 addition & 1 deletion ndsl/dsl/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_precision() -> int:
def global_set_precision() -> Tuple[TypeAlias, TypeAlias]:
"""Set the global precision for all references of
Float and Int in the codebase. Defaults to 64 bit."""
global Float, Int
global Float, Int # noqa: F824 global ... is unused
precision_in_bit = get_precision()
if precision_in_bit == 64:
return NDSL_64BIT_FLOAT_TYPE, NDSL_64BIT_INT_TYPE
Expand Down
2 changes: 1 addition & 1 deletion ndsl/grid/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def calculate_supergrid_cos_sin(

cos_sg[abs(1.0 - cos_sg) < 1e-15] = 1.0

sin_sg_tmp = 1.0 - cos_sg ** 2
sin_sg_tmp = 1.0 - cos_sg**2
sin_sg_tmp[sin_sg_tmp < 0] = 0.0
sin_sg = np.sqrt(sin_sg_tmp)
sin_sg[sin_sg > 1.0] = 1.0
Expand Down
12 changes: 6 additions & 6 deletions ndsl/grid/global_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def gnomonic_grid(grid_type: int, lon, lat, np):
# closer to the Fortran code
def global_gnomonic_ed(lon, lat, np):
im = lon.shape[0] - 1
alpha = np.arcsin(3 ** -0.5)
alpha = np.arcsin(3**-0.5)
dely = np.multiply(2.0, alpha) / float(im)
pp = np.zeros((3, im + 1, im + 1))

Expand Down Expand Up @@ -68,16 +68,16 @@ def global_gnomonic_ed(lon, lat, np):
i = 0
for j in range(1, im):
pp[:, i, j] = _latlon2xyz(lon[i, j], lat[i, j], np)
pp[1, i, j] = -pp[1, i, j] * (3 ** -0.5) / pp[0, i, j]
pp[2, i, j] = -pp[2, i, j] * (3 ** -0.5) / pp[0, i, j]
pp[1, i, j] = -pp[1, i, j] * (3**-0.5) / pp[0, i, j]
pp[2, i, j] = -pp[2, i, j] * (3**-0.5) / pp[0, i, j]

j = 0
for i in range(1, im):
pp[:, i, j] = _latlon2xyz(lon[i, j], lat[i, j], np)
pp[1, i, j] = -pp[1, i, j] * (3 ** -0.5) / pp[0, i, j]
pp[2, i, j] = -pp[2, i, j] * (3 ** -0.5) / pp[0, i, j]
pp[1, i, j] = -pp[1, i, j] * (3**-0.5) / pp[0, i, j]
pp[2, i, j] = -pp[2, i, j] * (3**-0.5) / pp[0, i, j]

pp[0, :, :] = -(3 ** -0.5)
pp[0, :, :] = -(3**-0.5)
for j in range(1, im + 1):
# copy y-z face of the cube along j=0
pp[1, 1:, j] = pp[1, 1:, 0]
Expand Down
25 changes: 12 additions & 13 deletions ndsl/grid/gnomonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def local_gnomonic_ed(
_check_shapes(lon, lat)
# tile_im, wedge_dict, corner_dict, global_is, global_js
im = lon.shape[0] - 1
alpha = np.arcsin(3 ** -0.5)
alpha = np.arcsin(3**-0.5)
tile_im = npx - 1
dely = np.multiply(2.0, alpha / float(tile_im))
halo = 3
Expand Down Expand Up @@ -96,10 +96,10 @@ def local_gnomonic_ed(
lon_west_tile_edge[i, j], lat_west_tile_edge[i, j], np
)
pp_west_tile_edge[1, i, j] = (
-pp_west_tile_edge[1, i, j] * (3 ** -0.5) / pp_west_tile_edge[0, i, j]
-pp_west_tile_edge[1, i, j] * (3**-0.5) / pp_west_tile_edge[0, i, j]
)
pp_west_tile_edge[2, i, j] = (
-pp_west_tile_edge[2, i, j] * (3 ** -0.5) / pp_west_tile_edge[0, i, j]
-pp_west_tile_edge[2, i, j] * (3**-0.5) / pp_west_tile_edge[0, i, j]
)
if west_edge:
pp[:, 0, :] = pp_west_tile_edge[:, 0, :]
Expand All @@ -110,10 +110,10 @@ def local_gnomonic_ed(
lon_south_tile_edge[i, j], lat_south_tile_edge[i, j], np
)
pp_south_tile_edge[1, i, j] = (
-pp_south_tile_edge[1, i, j] * (3 ** -0.5) / pp_south_tile_edge[0, i, j]
-pp_south_tile_edge[1, i, j] * (3**-0.5) / pp_south_tile_edge[0, i, j]
)
pp_south_tile_edge[2, i, j] = (
-pp_south_tile_edge[2, i, j] * (3 ** -0.5) / pp_south_tile_edge[0, i, j]
-pp_south_tile_edge[2, i, j] * (3**-0.5) / pp_south_tile_edge[0, i, j]
)
if south_edge:
pp[:, :, 0] = pp_south_tile_edge[:, :, 0]
Expand All @@ -138,7 +138,7 @@ def local_gnomonic_ed(
if north_edge and east_edge:
pp[:, im, im] = _latlon2xyz(lon_east, lat_north, np)

pp[0, :, :] = -(3 ** -0.5)
pp[0, :, :] = -(3**-0.5)
for j in range(start_j, im + 1):
# copy y-z face of the cube along j=0
pp[1, start_i:, j] = pp_south_tile_edge[1, start_i:, 0] # pp[1,:,0]
Expand Down Expand Up @@ -167,14 +167,14 @@ def _corner_to_center_mean(corner_array):
def normalize_vector(np, *vector_components):
scale = np.divide(
1.0,
np.sum(np.asarray([item ** 2.0 for item in vector_components]), axis=0) ** 0.5,
np.sum(np.asarray([item**2.0 for item in vector_components]), axis=0) ** 0.5,
)
return np.asarray([item * scale for item in vector_components])


def normalize_xyz(xyz):
# double transpose to broadcast along last dimension instead of first
return (xyz.T / ((xyz ** 2).sum(axis=-1) ** 0.5).T).T
return (xyz.T / ((xyz**2).sum(axis=-1) ** 0.5).T).T


def lon_lat_midpoint(lon1, lon2, lat1, lat2, np):
Expand Down Expand Up @@ -606,7 +606,7 @@ def get_rectangle_area(p1, p2, p3, p4, radius, np):
) in ((p3, p2, p4), (p4, p3, p1), (p1, p4, p2)):
total_angle += spherical_angle(q1, q2, q3, np)

return (total_angle - 2 * PI) * radius ** 2
return (total_angle - 2 * PI) * radius**2


def get_triangle_area(p1, p2, p3, radius, np):
Expand All @@ -618,7 +618,7 @@ def get_triangle_area(p1, p2, p3, radius, np):
total_angle = spherical_angle(p1, p2, p3, np)
for q1, q2, q3 in ((p2, p3, p1), (p3, p1, p2)):
total_angle += spherical_angle(q1, q2, q3, np)
return (total_angle - PI) * radius ** 2
return (total_angle - PI) * radius**2


def fortran_vector_spherical_angle(e1, e2, e3):
Expand Down Expand Up @@ -678,8 +678,7 @@ def spherical_angle(p_center, p2, p3, np):
p = np.cross(p_center, p2)
q = np.cross(p_center, p3)
angle = np.arccos(
np.sum(p * q, axis=-1)
/ np.sqrt(np.sum(p ** 2, axis=-1) * np.sum(q ** 2, axis=-1))
np.sum(p * q, axis=-1) / np.sqrt(np.sum(p**2, axis=-1) * np.sum(q**2, axis=-1))
)
if not np.isscalar(angle):
angle[np.isnan(angle)] = 0.0
Expand All @@ -696,7 +695,7 @@ def spherical_cos(p_center, p2, p3, np):
p = np.cross(p_center, p2)
q = np.cross(p_center, p3)
return np.sum(p * q, axis=-1) / np.sqrt(
np.sum(p ** 2, axis=-1) * np.sum(q ** 2, axis=-1)
np.sum(p**2, axis=-1) * np.sum(q**2, axis=-1)
)


Expand Down
4 changes: 2 additions & 2 deletions ndsl/grid/stretch_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def direct_transform(

lon_p, lat_p = np.deg2rad(lon_target), np.deg2rad(lat_target)
sin_p, cos_p = np.sin(lat_p), np.cos(lat_p)
c2p1 = 1.0 + stretch_factor ** 2
c2m1 = 1.0 - stretch_factor ** 2
c2p1 = 1.0 + stretch_factor**2
c2m1 = 1.0 - stretch_factor**2

# first limit longitude so it's between 0 and 2pi
lon_data[lon_data < 0] += 2 * np.pi
Expand Down
8 changes: 5 additions & 3 deletions ndsl/initialization/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,11 @@ def _allocate(
extent = self.sizer.get_extent(dims)
shape = self.sizer.get_shape(dims)
dimensions = [
axis
if any(dim in axis_dims for axis_dims in SPATIAL_DIMS)
else str(shape[index])
(
axis
if any(dim in axis_dims for axis_dims in SPATIAL_DIMS)
else str(shape[index])
)
for index, (dim, axis) in enumerate(
zip(dims, ("I", "J", "K", *([None] * (len(dims) - 3))))
)
Expand Down
6 changes: 3 additions & 3 deletions ndsl/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def _ndsl_logger_on_rank_0() -> logging.Logger:
return name_log


ndsl_log: Annotated[
logging.Logger, "NDSL Python logger, logs on all rank"
] = _ndsl_logger()
ndsl_log: Annotated[logging.Logger, "NDSL Python logger, logs on all rank"] = (
_ndsl_logger()
)

ndsl_log_on_rank_0: Annotated[
logging.Logger, "NDSL Python logger, logs on rank 0 only"
Expand Down
6 changes: 2 additions & 4 deletions ndsl/monitor/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ def store(self, state: dict) -> None:
"""Append the model state dictionary to the stored data."""
...

def store_constant(self, state: Dict[str, Quantity]) -> None:
...
def store_constant(self, state: Dict[str, Quantity]) -> None: ...

def cleanup(self):
...
def cleanup(self): ...
Loading