diff --git a/examples/standalone/runfile/acoustics.py b/examples/standalone/runfile/acoustics.py index 67ef6f02..b6c4ae28 100755 --- a/examples/standalone/runfile/acoustics.py +++ b/examples/standalone/runfile/acoustics.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # type: ignore from types import SimpleNamespace -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple import click import f90nml @@ -14,11 +14,13 @@ CubedSphereCommunicator, CubedSpherePartitioner, DaceConfig, - NullComm, + LocalComm, + MPIComm, StencilConfig, StencilFactory, TilePartitioner, ) +from ndsl.comm import Comm from ndsl.performance import Timer from ndsl.stencils.testing import Grid from pyfv3 import DynamicalCoreConfig @@ -26,12 +28,6 @@ from pyfv3.testing import TranslateDynCore -try: - from mpi4py import MPI -except ImportError: - MPI = None - - def dycore_config_from_namelist(data_directory: str) -> DynamicalCoreConfig: """ Reads the namelist at the given directory and sets @@ -89,17 +85,15 @@ def get_state_from_input( def set_up_communicator( disable_halo_exchange: bool, layout: Tuple[int, int], -) -> Tuple[Optional[MPI.Comm], Optional[CubedSphereCommunicator]]: - partitioner = CubedSpherePartitioner(TilePartitioner(layout)) - if MPI is not None: - comm = MPI.COMM_WORLD - else: - comm = None - if not disable_halo_exchange: - assert comm is not None - cube_comm = CubedSphereCommunicator(comm, partitioner) - else: - cube_comm = CubedSphereCommunicator(NullComm(0, 0), partitioner) +) -> Tuple[Comm, CubedSphereCommunicator]: + comm = ( + LocalComm(rank=0, total_ranks=1, buffer={}) + if disable_halo_exchange + else MPIComm() + ) + cube_comm = CubedSphereCommunicator( + comm, CubedSpherePartitioner(TilePartitioner(layout)) + ) return comm, cube_comm diff --git a/examples/standalone/runfile/compile.py b/examples/standalone/runfile/compile.py index 15f7a880..245fd64f 100755 --- a/examples/standalone/runfile/compile.py +++ b/examples/standalone/runfile/compile.py @@ -7,16 +7,12 @@ import f90nml import gt4py.cartesian.config +from mpi4py import MPI -from ndsl import NullComm +from ndsl import LocalComm from pyfv3 import DynamicalCoreConfig -try: - from mpi4py import MPI -except ImportError: - MPI = None - local = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, local) from runfile.dynamics import get_experiment_info, setup_dycore # noqa: E402 @@ -68,10 +64,8 @@ def parse_args() -> Namespace: for iteration in range(iterations): top_tile_rank = global_rank + size * iteration if top_tile_rank < sub_tiles: - mpi_comm = NullComm( - rank=top_tile_rank, - total_ranks=6 * sub_tiles, - fill_value=0.0, + mpi_comm = LocalComm( + rank=top_tile_rank, total_ranks=6 * sub_tiles, buffer_dict={} ) gt4py.cartesian.config.cache_settings["dir_name"] = os.environ.get( "GT_CACHE_ROOT", f".gt_cache_{mpi_comm.Get_rank():06}" diff --git a/examples/standalone/runfile/dynamics.py b/examples/standalone/runfile/dynamics.py index ba99c94b..cd17968f 100755 --- a/examples/standalone/runfile/dynamics.py +++ b/examples/standalone/runfile/dynamics.py @@ -22,7 +22,8 @@ CubedSphereCommunicator, CubedSpherePartitioner, DaceConfig, - NullComm, + LocalComm, + MPIComm, StencilConfig, StencilFactory, TilePartitioner, @@ -286,10 +287,15 @@ def setup_dycore( namelist = f90nml.read(args.data_dir + "/input.nml") dycore_config = DynamicalCoreConfig.from_f90nml(namelist) experiment_name, is_baroclinic_test_case = get_experiment_info(args.data_dir) - if args.disable_halo_exchange: - mpi_comm = NullComm(MPI.COMM_WORLD.Get_rank(), MPI.COMM_WORLD.Get_size()) - else: - mpi_comm = MPI.COMM_WORLD + mpi_comm = ( + LocalComm( + rank=MPI.COMM_WORLD.Get_rank(), + total_ranks=MPI.COMM_WORLD.Get_size(), + buffer_dict={}, + ) + if args.disable_halo_exchange + else MPIComm() + ) dycore, state, stencil_factory = setup_dycore( dycore_config, mpi_comm, diff --git a/pyfv3/wrappers/geos_wrapper.py b/pyfv3/wrappers/geos_wrapper.py index 271c29cb..07353761 100644 --- a/pyfv3/wrappers/geos_wrapper.py +++ b/pyfv3/wrappers/geos_wrapper.py @@ -17,7 +17,7 @@ DaceConfig, DaCeOrchestration, GridIndexing, - NullComm, + LocalComm, PerformanceCollector, QuantityFactory, StencilConfig, @@ -114,7 +114,9 @@ def __init__( # Look for an override to run on a single node gtfv3_single_rank_override = int(os.getenv("GTFV3_SINGLE_RANK_OVERRIDE", -1)) if gtfv3_single_rank_override >= 0: - comm = NullComm(gtfv3_single_rank_override, 6, 42) + comm = LocalComm( + rank=gtfv3_single_rank_override, total_ranks=6, buffer_dict={} + ) # Make a custom performance collector for the GEOS wrapper self.perf_collector = PerformanceCollector("GEOS wrapper", comm)