diff --git a/pyfv3/stencils/fv_dynamics.py b/pyfv3/stencils/fv_dynamics.py index a60998dc..bb6057f6 100644 --- a/pyfv3/stencils/fv_dynamics.py +++ b/pyfv3/stencils/fv_dynamics.py @@ -14,7 +14,7 @@ from ndsl.dsl.typing import Float, FloatField from ndsl.grid import DampingCoefficients, GridData from ndsl.logging import ndsl_log -from ndsl.performance import NullTimer, Timer +from ndsl.performance import Timer from ndsl.stencils.basic_operations import copy from ndsl.stencils.c2l_ord import CubedToLatLon from ndsl.typing import Checkpointer, Communicator @@ -55,9 +55,7 @@ def omega_from_w(delp: FloatField, delz: FloatField, w: FloatField, omega: Float omega = delp / delz * w -def fvdyn_temporaries( - quantity_factory: QuantityFactory, -) -> Mapping[str, Quantity]: +def fvdyn_temporaries(quantity_factory: QuantityFactory) -> Mapping[str, Quantity]: tmps = {} for name in ["te_2d", "te0_2d", "wsd"]: quantity = quantity_factory.zeros( @@ -77,10 +75,10 @@ def fvdyn_temporaries( @dace_inhibitor -def log_on_rank_0(msg: str): +def log_on_rank_0(message: str) -> None: """Print when rank is 0 - outside of DaCe critical path""" if not MPI or MPI.COMM_WORLD.Get_rank() == 0: - ndsl_log.info(msg) + ndsl_log.info(message) class DynamicalCore: @@ -100,7 +98,7 @@ def __init__( state: DycoreState, timestep: timedelta, checkpointer: Checkpointer | None = None, - ): + ) -> None: """ Args: comm: object for cubed sphere or tile inter-process communication @@ -127,7 +125,7 @@ def __init__( obj=self, config=stencil_factory.config.dace_config, method_to_orchestrate="compute_preamble", - dace_compiletime_args=["state", "is_root_rank"], + dace_compiletime_args=["state"], ) orchestrate( @@ -181,7 +179,7 @@ def __init__( # have not implemented, so they are hard-coded here. self.call_checkpointer = checkpointer is not None if checkpointer is None: - self.checkpointer: Checkpointer = NullCheckpointer() + self.checkpointer = NullCheckpointer() else: self.checkpointer = checkpointer nested = False @@ -336,7 +334,7 @@ def __init__( def _get_da_min(self) -> float: return self._da_min - def _checkpoint_fvdynamics(self, state: DycoreState, tag: str): + def _checkpoint_fvdynamics(self, state: DycoreState, tag: str) -> None: if self.call_checkpointer: self.checkpointer( f"FVDynamics-{tag}", @@ -355,10 +353,7 @@ def _checkpoint_fvdynamics(self, state: DycoreState, tag: str): qvapor=state.qvapor, ) - def _checkpoint_remapping_in( - self, - state: DycoreState, - ): + def _checkpoint_remapping_in(self, state: DycoreState) -> None: if self.call_checkpointer: self.checkpointer( "Remapping-In", @@ -386,10 +381,7 @@ def _checkpoint_remapping_in( dp1=self._dp_initial, ) - def _checkpoint_remapping_out( - self, - state: DycoreState, - ): + def _checkpoint_remapping_out(self, state: DycoreState) -> None: if self.call_checkpointer: self.checkpointer( "Remapping-Out", @@ -411,10 +403,7 @@ def _checkpoint_remapping_out( dp1=self._dp_initial, ) - def _checkpoint_tracer_advection_in( - self, - state: DycoreState, - ): + def _checkpoint_tracer_advection_in(self, state: DycoreState) -> None: if self.call_checkpointer: self.checkpointer( "Tracer2D1L-In", @@ -425,10 +414,7 @@ def _checkpoint_tracer_advection_in( cyd=state.cyd, ) - def _checkpoint_tracer_advection_out( - self, - state: DycoreState, - ): + def _checkpoint_tracer_advection_out(self, state: DycoreState) -> None: if self.call_checkpointer: self.checkpointer( "Tracer2D1L-Out", @@ -439,11 +425,7 @@ def _checkpoint_tracer_advection_out( cyd=state.cyd, ) - def step_dynamics( - self, - state: DycoreState, - timer: Timer | None = None, - ): + def step_dynamics(self, state: DycoreState, timer: Timer) -> None: """ Step the model state forward by one timestep. @@ -451,16 +433,14 @@ def step_dynamics( state: model prognostic state and inputs timer: keep time of model sections """ - if timer is None: - timer = NullTimer() - self._checkpoint_fvdynamics(state=state, tag="In") self._compute(state, timer) self._checkpoint_fvdynamics(state=state, tag="Out") - def compute_preamble(self, state: DycoreState, is_root_rank: bool): + def compute_preamble(self, state: DycoreState) -> None: if self.config.hydrostatic: raise NotImplementedError("Hydrostatic is not implemented") + if __debug__: log_on_rank_0("FV Setup") @@ -497,26 +477,24 @@ def compute_preamble(self, state: DycoreState, is_root_rank: bool): "Dynamical Core (fv_dynamics): Adiabatic with positive kord_tm" " is not implemented." ) - else: - if __debug__: - log_on_rank_0("Adjust pt") - self._pt_to_potential_density_pt( - state.pkz, - self._dp_initial, - state.q_con, - state.pt, - ) - def __call__(self, *args, **kwargs): - return self.step_dynamics(*args, **kwargs) + if __debug__: + log_on_rank_0("Adjust pt") - def _compute(self, state: DycoreState, timer: Timer): - last_step = False - self.compute_preamble( - state, - is_root_rank=self.comm_rank == 0, + self._pt_to_potential_density_pt( + state.pkz, + self._dp_initial, + state.q_con, + state.pt, ) + def __call__(self, *args, **kwargs) -> None: + self.step_dynamics(*args, **kwargs) + + def _compute(self, state: DycoreState, timer: Timer) -> None: + last_step = False + self.compute_preamble(state) + for k_split in dace_no_unroll(range(self._k_split)): n_map = k_split + 1 last_step = k_split == self._k_split - 1 @@ -525,17 +503,21 @@ def _compute(self, state: DycoreState, timer: Timer): state.delp, self._dp_initial, ) + if __debug__: log_on_rank_0("DynCore") + with timer.clock("DynCore"): self.acoustic_dynamics( state, timestep=self._timestep / self._k_split, n_map=n_map, ) + if self.config.z_tracer: if __debug__: log_on_rank_0("TracerAdvection") + with timer.clock("TracerAdvection"): self._checkpoint_tracer_advection_in(state) self.tracer_advection( @@ -567,6 +549,7 @@ def _compute(self, state: DycoreState, timer: Timer): # "surface" array if __debug__: log_on_rank_0("Remapping") + with timer.clock("Remapping"): self._checkpoint_remapping_in(state) diff --git a/tests/mpi/test_doubly_periodic.py b/tests/mpi/test_doubly_periodic.py index 79022afd..6afeed64 100644 --- a/tests/mpi/test_doubly_periodic.py +++ b/tests/mpi/test_doubly_periodic.py @@ -16,6 +16,7 @@ TilePartitioner, ) from ndsl.grid import DampingCoefficients, GridData, MetricTerms +from ndsl.performance import NullTimer from pyfv3 import DynamicalCore, DynamicalCoreConfig @@ -134,4 +135,4 @@ def test_dycore_runs_one_step() -> None: ) # run one step - dycore.step_dynamics(state) + dycore.step_dynamics(state, NullTimer())