Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
CreatesComm,
CreatesCommSelector,
MPICommConfig,
NullComm,
NullCommConfig,
ReaderCommConfig,
WriterCommConfig,
Expand All @@ -25,6 +26,7 @@
__version__ = "0.2.0"

__all__ = [
"NullComm",
"CreatesComm",
"CreatesCommSelector",
"MPICommConfig",
Expand Down
121 changes: 117 additions & 4 deletions pace/comm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,126 @@
import abc
import copy
import dataclasses
import os
from typing import Any, ClassVar, List

from ndsl import MPIComm, NullComm
from ndsl.comm import CachingCommReader, CachingCommWriter, Comm
from typing import Any, ClassVar, List, Mapping, TypeVar, cast

from ndsl import MPIComm
from ndsl.comm import (
CachingCommReader,
CachingCommWriter,
Comm,
ReductionOperator,
Request,
)
from pace.registry import Registry


T = TypeVar("T")


class NullAsyncResult(Request):
def __init__(self, recvbuf: Any = None) -> None:
self._recvbuf = recvbuf

def wait(self) -> None:
if self._recvbuf is not None:
self._recvbuf[:] = 0.0


class NullComm(Comm[T]):
"""
A class with a subset of the mpi4py Comm API, but which
'receives' a fill value (default zero) instead of using MPI.
"""

default_fill_value: T = cast(T, 0)

def __init__(self, rank: int, total_ranks: int, fill_value: T = default_fill_value):
"""
Args:
rank: rank to mock
total_ranks: number of total MPI ranks to mock
fill_value: fill halos with this value when performing
halo updates.
"""
self.rank = rank
self.total_ranks = total_ranks
self._fill_value = fill_value
self._split_comms: Mapping[Any, list[NullComm]] = {}

def __repr__(self) -> str:
return f"NullComm(rank={self.rank}, total_ranks={self.total_ranks})"

def Get_rank(self) -> int:
return self.rank

def Get_size(self) -> int:
return self.total_ranks

def bcast(self, value: T | None, root: int = 0) -> T | None:
return value

def barrier(self) -> None:
return

def Barrier(self) -> None:
return

def Scatter(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def]
if recvbuf is not None:
recvbuf[:] = self._fill_value

def Gather(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def]
if recvbuf is not None:
recvbuf[:] = self._fill_value

def allgather(self, sendobj: T) -> list[T]:
return [copy.deepcopy(sendobj) for _ in range(self.total_ranks)]

def Send(self, sendbuf, dest, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def]
pass

def Isend(self, sendbuf, dest, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def]
return NullAsyncResult()

def Recv(self, recvbuf, source, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def]
recvbuf[:] = self._fill_value

def Irecv(self, recvbuf, source, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def]
return NullAsyncResult(recvbuf)

def sendrecv(self, sendbuf, dest, **kwargs: dict): # type: ignore[no-untyped-def]
return sendbuf

def Split(self, color, key) -> Comm: # type: ignore[no-untyped-def]
# key argument is ignored, assumes we're calling the ranks from least to
# greatest when mocking Split
self._split_comms[color] = self._split_comms.get(color, []) # type: ignore[index]
rank = len(self._split_comms[color])
total_ranks = rank + 1
new_comm = NullComm(
rank=rank, total_ranks=total_ranks, fill_value=self._fill_value
)
for comm in self._split_comms[color]:
# won't know how many ranks there are until everything is split
comm.total_ranks = total_ranks
self._split_comms[color].append(new_comm)
return new_comm

def allreduce(
self, sendobj: T, op: ReductionOperator = ReductionOperator.NO_OP
) -> T:
return self._fill_value

def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T:
# TODO: what about reduction operator `op`?
recvobj = sendobj
return recvobj

def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T:
raise NotImplementedError("NullComm.Allreduce_inplace")


class CreatesComm(abc.ABC):
"""
Retrieves and does cleanup for a mpi4py-style Comm object.
Expand Down
4 changes: 2 additions & 2 deletions tests/main/driver/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

import pytest

from ndsl import NullComm, StencilConfig
from ndsl import StencilConfig
from ndsl.performance.report import (
TimeReport,
gather_hit_counts,
gather_timing_data,
get_sypd,
)
from pace import CreatesCommSelector, DriverConfig, NullCommConfig
from pace import CreatesCommSelector, DriverConfig, NullComm, NullCommConfig


def get_driver_config(
Expand Down
3 changes: 1 addition & 2 deletions tests/main/driver/test_restart_fortran.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
CubedSphereCommunicator,
CubedSpherePartitioner,
LocalComm,
NullComm,
QuantityFactory,
SubtileGridSizer,
TilePartitioner,
)
from pace import FortranRestartInit, GeneratedGridConfig
from pace import FortranRestartInit, GeneratedGridConfig, NullComm
from pyshield import PHYSICS_PACKAGES
from tests.paths import REPO_ROOT

Expand Down
18 changes: 1 addition & 17 deletions tests/main/driver/test_restart_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@
from ndsl import (
CubedSphereCommunicator,
CubedSpherePartitioner,
NullComm,
Quantity,
QuantityFactory,
SubtileGridSizer,
TilePartitioner,
)
from pace import (
AnalyticInit,
CreatesComm,
DriverConfig,
GeneratedGridConfig,
NullComm,
RestartConfig,
)
from pyshield import PHYSICS_PACKAGES
Expand All @@ -30,21 +29,6 @@
DIR = os.path.dirname(os.path.abspath(__file__))


class NullCommConfig(CreatesComm):
def __init__(self, layout):
self.layout = layout

def get_comm(self):
return NullComm(
rank=0,
total_ranks=6 * self.layout[0] * self.layout[1],
fill_value=0.0,
)

def cleanup(self, comm):
pass


def test_default_save_restart():
restart_config = RestartConfig()
assert restart_config.save_restart is False
Expand Down
3 changes: 2 additions & 1 deletion tests/main/fv3core/test_cartesian_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import numpy as np
import pytest

from ndsl import NullComm, TileCommunicator, TilePartitioner
from ndsl import TileCommunicator, TilePartitioner
from ndsl.constants import PI
from ndsl.grid import MetricTerms
from pace import NullComm


@pytest.mark.parametrize("npx", [8])
Expand Down
2 changes: 1 addition & 1 deletion tests/main/fv3core/test_dycore_baroclinic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
CubedSpherePartitioner,
DaceConfig,
GridIndexing,
NullComm,
QuantityFactory,
StencilConfig,
StencilFactory,
Expand All @@ -24,6 +23,7 @@
)
from ndsl.grid import DampingCoefficients, GridData, MetricTerms
from ndsl.performance.timer import NullTimer
from pace import NullComm
from pyfv3 import DycoreState, DynamicalCore, DynamicalCoreConfig


Expand Down
2 changes: 1 addition & 1 deletion tests/main/fv3core/test_dycore_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
CubedSpherePartitioner,
DaceConfig,
GridIndexing,
NullComm,
Quantity,
QuantityFactory,
StencilConfig,
Expand All @@ -22,6 +21,7 @@
from ndsl.grid import DampingCoefficients, GridData, MetricTerms
from ndsl.performance.timer import NullTimer, Timer
from ndsl.stencils.testing import assert_same_temporaries, copy_temporaries
from pace import NullComm
from pyfv3 import DycoreState, DynamicalCore, DynamicalCoreConfig
from pyfv3.initialization.analytic_init import AnalyticCase

Expand Down
2 changes: 1 addition & 1 deletion tests/main/fv3core/test_init_from_geos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import pytest # noqa

from ndsl import NullComm
from pace import NullComm
from pyfv3 import DynamicalCore
from pyfv3.wrappers import GeosDycoreWrapper

Expand Down
2 changes: 1 addition & 1 deletion tests/main/physics/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
DaceConfig,
DaCeOrchestration,
GridIndexing,
NullComm,
QuantityFactory,
StencilConfig,
StencilFactory,
Expand All @@ -20,6 +19,7 @@
)
from ndsl.grid import GridData, MetricTerms
from ndsl.stencils.testing import assert_same_temporaries, copy_temporaries
from pace import NullComm
from pyshield import PHYSICS_PACKAGES, Physics, PhysicsConfig, PhysicsState


Expand Down
10 changes: 10 additions & 0 deletions tests/main/test_comm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from ndsl import CubedSphereCommunicator, CubedSpherePartitioner, TilePartitioner
from pace import NullComm


def test_can_create_cube_communicator():
null_comm = NullComm(rank=2, total_ranks=24)
partitioner = CubedSpherePartitioner(TilePartitioner(layout=(2, 2)))
communicator = CubedSphereCommunicator(null_comm, partitioner)

assert communicator.tile.partitioner
2 changes: 1 addition & 1 deletion tests/main/test_grid_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from ndsl import (
CubedSphereCommunicator,
CubedSpherePartitioner,
NullComm,
Quantity,
QuantityFactory,
SubtileGridSizer,
TilePartitioner,
)
from ndsl.grid import MetricTerms
from pace import NullComm


def get_cube_comm(layout, rank: int):
Expand Down
3 changes: 1 addition & 2 deletions tests/mpi/restart/test_restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
from ndsl import (
CubedSphereCommunicator,
CubedSpherePartitioner,
NullComm,
Quantity,
TilePartitioner,
)
from pace import DriverConfig, DriverState
from pace import DriverConfig, DriverState, NullComm
from pyshield import PHYSICS_PACKAGES


Expand Down