Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
args: ["--profile", "black"]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.16.1
rev: v1.18.2
hooks:
- id: mypy
name: mypy-ndsl
Expand Down
10 changes: 9 additions & 1 deletion ndsl/comm/caching_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any:
def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any:
raise NotImplementedError("CachingCommReader.Allreduce")

def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> Any:
raise NotImplementedError("CachingCommReader.Allreduce_inplace")

@classmethod
def load(cls, file: BinaryIO) -> CachingCommReader:
data = CachingCommData.load(file)
Expand Down Expand Up @@ -234,10 +237,15 @@ def Split(self, color, key) -> CachingCommWriter:
def dump(self, file: BinaryIO):
self._data.dump(file)

def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any:
def allreduce(
self, sendobj, op: ReductionOperator = ReductionOperator.NO_OP
) -> Any:
result = self._comm.allreduce(sendobj, op)
self._data.generic_obj_buffers.append(copy.deepcopy(result))
return result

def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any:
raise NotImplementedError("CachingCommWriter.Allreduce")

def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> Any:
raise NotImplementedError("CachingCommWriter.Allreduce_inplace")
5 changes: 4 additions & 1 deletion ndsl/comm/comm_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,12 @@ def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request: ...
def Split(self, color, key) -> Comm: ...

@abc.abstractmethod
def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: ...
def allreduce(
self, sendobj: T, op: ReductionOperator = ReductionOperator.NO_OP
) -> T: ...

@abc.abstractmethod
def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: ...

@abc.abstractmethod
def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T: ...
8 changes: 4 additions & 4 deletions ndsl/comm/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def all_reduce(
self,
input_quantity: Quantity,
op: ReductionOperator,
output_quantity: Quantity = None,
output_quantity: Quantity | None = None,
):
reduced_quantity_data = self.comm.allreduce(input_quantity.data, op)
if output_quantity is None:
Expand Down Expand Up @@ -243,8 +243,8 @@ def _get_scatter_recv_quantity(
return recv_quantity

def gather(
self, send_quantity: Quantity, recv_quantity: Quantity = None
) -> Optional[Quantity]:
self, send_quantity: Quantity, recv_quantity: Quantity | None = None
) -> Quantity | None:
"""Transfer subtile regions of a full-tile quantity
from each rank to the tile root rank.

Expand All @@ -255,7 +255,7 @@ def gather(
Returns:
recv_quantity: quantity if on root rank, otherwise None
"""
result: Optional[Quantity]
result: Quantity | None
if self.rank == constants.ROOT_RANK:
with array_buffer(
send_quantity.np.zeros,
Expand Down
8 changes: 7 additions & 1 deletion ndsl/comm/local_comm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
from typing import Any

from ndsl.comm.comm_abc import Comm
from ndsl.comm.comm_abc import Comm, ReductionOperator
from ndsl.logging import ndsl_log
from ndsl.utils import ensure_contiguous, safe_assign_array

Expand Down Expand Up @@ -200,3 +200,9 @@ def Allreduce(self, sendobj, recvobj, op) -> Any:
"Allreduce fundamentally cannot be written for LocalComm, "
"as it requires synchronicity"
)

def Allreduce_inplace(self, obj: Any, op: ReductionOperator) -> Any:
raise NotImplementedError(
"Allreduce_inplace fundamentally cannot be written for LocalComm, "
"as it requires synchronicity"
)
4 changes: 3 additions & 1 deletion ndsl/comm/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request:
def Split(self, color, key) -> Comm:
return self._comm.Split(color, key)

def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T:
def allreduce(
self, sendobj: T, op: ReductionOperator = ReductionOperator.NO_OP
) -> T:
return self._comm.allreduce(sendobj, self._op_mapping[op])

def Allreduce(self, sendobj_or_inplace: T, recvobj: T, op: ReductionOperator) -> T:
Expand Down
7 changes: 5 additions & 2 deletions ndsl/comm/null_comm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import Any, Mapping, Optional
from typing import Any, Mapping

from ndsl.comm.comm_abc import Comm, ReductionOperator, Request

Expand Down Expand Up @@ -91,9 +91,12 @@ def Split(self, color, key):
self._split_comms[color].append(new_comm)
return new_comm

def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any:
def allreduce(self, sendobj, op: ReductionOperator | None = None) -> Any:
return self._fill_value

def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any:
recvobj = sendobj
return recvobj

def Allreduce_inplace(self, obj: Any, op: ReductionOperator) -> Any:
raise NotImplementedError("NullComm.Allreduce_inplace")
4 changes: 2 additions & 2 deletions ndsl/dsl/dace/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _build_sdfg(
raise ValueError("Couldn't load SDFG post build")
compiledSDFG, _ = dace_program.load_precompiled_sdfg(
sdfg_path, *args, **kwargs
) # type: ignore
)
config.loaded_precompiled_SDFG[dace_program] = compiledSDFG


Expand Down Expand Up @@ -553,7 +553,7 @@ def closure_resolver(self, constant_args, given_args, parent_closure=None):


def orchestrate_function(
config: DaceConfig = None,
config: DaceConfig,
dace_compiletime_args: Sequence[str] | None = None,
) -> Callable[..., Any] | _LazyComputepathFunction:
"""
Expand Down
42 changes: 22 additions & 20 deletions ndsl/dsl/dace/sdfg_debug_passes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
from typing import List, Optional, Tuple

import dace
import sympy as sp
Expand All @@ -14,11 +13,11 @@

def _filter_all_maps(
sdfg: dace.SDFG,
whitelist: List[str] = None,
blacklist: List[str] = None,
whitelist: list[str] = [],
blacklist: list[str] = [],
skip_dynamic_memlet=True,
) -> List[
Tuple[dace.SDFGState, dace.nodes.AccessNode, gr.MultiConnectorEdge[dace.Memlet]]
) -> list[
tuple[dace.SDFGState, dace.nodes.AccessNode, gr.MultiConnectorEdge[dace.Memlet]]
]:
"""
Grab all maps outputs and filter by variable name (either black or whitelist)
Expand All @@ -34,8 +33,8 @@ def _filter_all_maps(
[state, node, edges]
"""

checks: List[
Tuple[dace.SDFGState, dace.nodes.AccessNode, gr.MultiConnectorEdge[dace.Memlet]]
checks: list[
tuple[dace.SDFGState, dace.nodes.AccessNode, gr.MultiConnectorEdge[dace.Memlet]]
] = []
all_maps = [
(me, state)
Expand All @@ -54,13 +53,11 @@ def _filter_all_maps(
continue
node = sdutil.get_last_view_node(state, e.dst)
# Whitelist
if whitelist is not None:
if all([varname not in node.data for varname in whitelist]):
continue
if all([varname not in node.data for varname in whitelist]):
continue
# Blacklist
if blacklist is not None:
if any([varname in node.data for varname in blacklist]):
continue
if any([varname in node.data for varname in blacklist]):
continue
# Skip dynamic (region) outputs
if skip_dynamic_memlet and state.memlet_path(e)[0].data.dynamic:
dynamic_skipped += 1
Expand All @@ -81,7 +78,7 @@ def _check_node(
check_c_code: str,
comment_c_code: str,
assert_out: bool = False,
array_range: Optional[List[Tuple[int, int, int]]] = None,
array_range: list[tuple[int, int, int]] | None = None,
):
"""
Grab all maps outputs and filter by variable name (either black or whitelist)
Expand Down Expand Up @@ -268,20 +265,25 @@ def negative_qtracers_checker(sdfg: dace.SDFG):

def sdfg_nan_checker(
sdfg: dace.SDFG,
i_range: Optional[Tuple[int, int, int]] = None,
j_range: Optional[Tuple[int, int, int]] = None,
k_range: Optional[Tuple[int, int, int]] = None,
i_range: tuple[int, int, int] | None = None,
j_range: tuple[int, int, int] | None = None,
k_range: tuple[int, int, int] | None = None,
):
"""
Insert a check on array after each computational map to check for NaN
in the domain. Assert when check is True.
"""
all_maps_filtered = _filter_all_maps(sdfg, blacklist=["diss_estd"])

if i_range or j_range or k_range:
array_range = [i_range, j_range, k_range]
else:
if i_range is None and j_range is None and k_range is None:
array_range = None
else:
if i_range is not None and j_range is not None and k_range is not None:
array_range = [i_range, j_range, k_range]
else:
raise RuntimeError(
"It looks like you have to specify either all or not of the ranges."
)

for state, node, e in all_maps_filtered:
_check_node(
Expand Down
8 changes: 4 additions & 4 deletions ndsl/dsl/dace/wrapped_halo_exchange.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import Any, List, Optional
from typing import Any

from ndsl.comm.communicator import Communicator
from ndsl.dsl.dace.orchestration import dace_inhibitor
Expand All @@ -19,9 +19,9 @@ def __init__(
self,
updater: HaloUpdater,
state,
qty_x_names: List[str],
qty_y_names: List[str] = None,
comm: Optional[Communicator] = None,
qty_x_names: list[str],
qty_y_names: list[str] | None = None,
comm: Communicator | None = None,
) -> None:
self._updater = updater
self._state = state
Expand Down
Loading
Loading