Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ndsl/comm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
CachingRequestReader,
CachingRequestWriter,
)
from .comm_abc import Comm, Request
from .comm_abc import Comm, ReductionOperator, Request


__all__ = [
Expand All @@ -15,5 +15,6 @@
"CachingRequestReader",
"CachingRequestWriter",
"Comm",
"ReductionOperator",
"Request",
]
50 changes: 47 additions & 3 deletions ndsl/monitor/zarr_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import xarray as xr

import ndsl.constants as constants
from ndsl.comm import Comm, ReductionOperator, Request
from ndsl.comm.partitioner import Partitioner, subtile_slice
from ndsl.logging import ndsl_log
from ndsl.monitor.convert import to_numpy
Expand All @@ -19,14 +20,16 @@
T = TypeVar("T")


class DummyComm:
class DummyComm(Comm[T]):
"""Dummy comm object that works in single-core mode."""

def Get_rank(self) -> int:
return 0

def Get_size(self) -> int:
return 1

def bcast(self, value: T, root: int = 0) -> T:
def bcast(self, value: T | None, root: int = 0) -> T | None:
assert root == 0, (
"DummyComm should only be used on a single core, "
"so root should only ever be 0"
Expand All @@ -36,6 +39,47 @@ def bcast(self, value: T, root: int = 0) -> T:
def barrier(self) -> None:
return

def Barrier(self) -> None:
raise NotImplementedError("DummyComm.Barrier")

def Scatter(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def]
raise NotImplementedError("DummyComm.Scatter")

def Gather(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def]
raise NotImplementedError("DummyComm.Gather")

def allgather(self, sendobj: T) -> list[T]:
raise NotImplementedError("DummyComm.allgather")

def Send(self, sendbuf, dest, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def]
raise NotImplementedError("DummyComm.Send")

def sendrecv(self, sendbuf, dest, **kwargs: dict): # type: ignore[no-untyped-def]
raise NotImplementedError("DummyComm.sendrcv")

def Isend(self, sendbuf, dest, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def]
raise NotImplementedError("DummyComm.Isend")

def Recv(self, recvbuf, source, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def]
raise NotImplementedError("DummyComm.Recv")

def Irecv(self, recvbuf, source, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def]
raise NotImplementedError("DummyComm.Irecv")

def Split(self, color, key) -> DummyComm: # type: ignore[no-untyped-def]
raise NotImplementedError("DummyComm.Split")

def allreduce(
self, sendobj: T, op: ReductionOperator = ReductionOperator.NO_OP
) -> T:
raise NotImplementedError("DummyComm.allreduce")

def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T:
raise NotImplementedError("DummyComm.Allreduce")

def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T:
raise NotImplementedError("DummyComm.Allreduce_inplace")


class ZarrMonitor:
"""
Expand All @@ -47,7 +91,7 @@ def __init__(
store: str | zarr.storage.MutableMapping,
partitioner: Partitioner,
mode: str = "w",
mpi_comm: DummyComm | None = None,
mpi_comm: Comm | None = None,
) -> None:
"""Create a ZarrMonitor.

Expand Down