Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 13 additions & 19 deletions examples/standalone/runfile/acoustics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# type: ignore
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Tuple

import click
import f90nml
Expand All @@ -14,24 +14,20 @@
CubedSphereCommunicator,
CubedSpherePartitioner,
DaceConfig,
NullComm,
LocalComm,
MPIComm,
StencilConfig,
StencilFactory,
TilePartitioner,
)
from ndsl.comm import Comm
from ndsl.performance import Timer
from ndsl.stencils.testing import Grid
from pyfv3 import DynamicalCoreConfig
from pyfv3.stencils import AcousticDynamics
from pyfv3.testing import TranslateDynCore


try:
from mpi4py import MPI
except ImportError:
MPI = None


def dycore_config_from_namelist(data_directory: str) -> DynamicalCoreConfig:
"""
Reads the namelist at the given directory and sets
Expand Down Expand Up @@ -89,17 +85,15 @@ def get_state_from_input(
def set_up_communicator(
disable_halo_exchange: bool,
layout: Tuple[int, int],
) -> Tuple[Optional[MPI.Comm], Optional[CubedSphereCommunicator]]:
partitioner = CubedSpherePartitioner(TilePartitioner(layout))
if MPI is not None:
comm = MPI.COMM_WORLD
else:
comm = None
if not disable_halo_exchange:
assert comm is not None
cube_comm = CubedSphereCommunicator(comm, partitioner)
else:
cube_comm = CubedSphereCommunicator(NullComm(0, 0), partitioner)
) -> Tuple[Comm, CubedSphereCommunicator]:
comm = (
LocalComm(rank=0, total_ranks=1, buffer={})
if disable_halo_exchange
else MPIComm()
)
cube_comm = CubedSphereCommunicator(
comm, CubedSpherePartitioner(TilePartitioner(layout))
)
return comm, cube_comm


Expand Down
14 changes: 4 additions & 10 deletions examples/standalone/runfile/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,12 @@

import f90nml
import gt4py.cartesian.config
from mpi4py import MPI

from ndsl import NullComm
from ndsl import LocalComm
from pyfv3 import DynamicalCoreConfig


try:
from mpi4py import MPI
except ImportError:
MPI = None

local = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, local)
from runfile.dynamics import get_experiment_info, setup_dycore # noqa: E402
Expand Down Expand Up @@ -68,10 +64,8 @@ def parse_args() -> Namespace:
for iteration in range(iterations):
top_tile_rank = global_rank + size * iteration
if top_tile_rank < sub_tiles:
mpi_comm = NullComm(
rank=top_tile_rank,
total_ranks=6 * sub_tiles,
fill_value=0.0,
mpi_comm = LocalComm(
rank=top_tile_rank, total_ranks=6 * sub_tiles, buffer_dict={}
)
gt4py.cartesian.config.cache_settings["dir_name"] = os.environ.get(
"GT_CACHE_ROOT", f".gt_cache_{mpi_comm.Get_rank():06}"
Expand Down
16 changes: 11 additions & 5 deletions examples/standalone/runfile/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
CubedSphereCommunicator,
CubedSpherePartitioner,
DaceConfig,
NullComm,
LocalComm,
MPIComm,
StencilConfig,
StencilFactory,
TilePartitioner,
Expand Down Expand Up @@ -286,10 +287,15 @@ def setup_dycore(
namelist = f90nml.read(args.data_dir + "/input.nml")
dycore_config = DynamicalCoreConfig.from_f90nml(namelist)
experiment_name, is_baroclinic_test_case = get_experiment_info(args.data_dir)
if args.disable_halo_exchange:
mpi_comm = NullComm(MPI.COMM_WORLD.Get_rank(), MPI.COMM_WORLD.Get_size())
else:
mpi_comm = MPI.COMM_WORLD
mpi_comm = (
LocalComm(
rank=MPI.COMM_WORLD.Get_rank(),
total_ranks=MPI.COMM_WORLD.Get_size(),
buffer_dict={},
)
if args.disable_halo_exchange
else MPIComm()
)
dycore, state, stencil_factory = setup_dycore(
dycore_config,
mpi_comm,
Expand Down
6 changes: 4 additions & 2 deletions pyfv3/wrappers/geos_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
DaceConfig,
DaCeOrchestration,
GridIndexing,
NullComm,
LocalComm,
PerformanceCollector,
QuantityFactory,
StencilConfig,
Expand Down Expand Up @@ -114,7 +114,9 @@ def __init__(
# Look for an override to run on a single node
gtfv3_single_rank_override = int(os.getenv("GTFV3_SINGLE_RANK_OVERRIDE", -1))
if gtfv3_single_rank_override >= 0:
comm = NullComm(gtfv3_single_rank_override, 6, 42)
comm = LocalComm(
rank=gtfv3_single_rank_override, total_ranks=6, buffer_dict={}
)

# Make a custom performance collector for the GEOS wrapper
self.perf_collector = PerformanceCollector("GEOS wrapper", comm)
Expand Down