Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 34 additions & 51 deletions pyfv3/stencils/fv_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ndsl.dsl.typing import Float, FloatField
from ndsl.grid import DampingCoefficients, GridData
from ndsl.logging import ndsl_log
from ndsl.performance import NullTimer, Timer
from ndsl.performance import Timer
from ndsl.stencils.basic_operations import copy
from ndsl.stencils.c2l_ord import CubedToLatLon
from ndsl.typing import Checkpointer, Communicator
Expand Down Expand Up @@ -55,9 +55,7 @@ def omega_from_w(delp: FloatField, delz: FloatField, w: FloatField, omega: Float
omega = delp / delz * w


def fvdyn_temporaries(
quantity_factory: QuantityFactory,
) -> Mapping[str, Quantity]:
def fvdyn_temporaries(quantity_factory: QuantityFactory) -> Mapping[str, Quantity]:
tmps = {}
for name in ["te_2d", "te0_2d", "wsd"]:
quantity = quantity_factory.zeros(
Expand All @@ -77,10 +75,10 @@ def fvdyn_temporaries(


@dace_inhibitor
def log_on_rank_0(msg: str):
def log_on_rank_0(message: str) -> None:
"""Print when rank is 0 - outside of DaCe critical path"""
if not MPI or MPI.COMM_WORLD.Get_rank() == 0:
ndsl_log.info(msg)
ndsl_log.info(message)


class DynamicalCore:
Expand All @@ -100,7 +98,7 @@ def __init__(
state: DycoreState,
timestep: timedelta,
checkpointer: Checkpointer | None = None,
):
) -> None:
"""
Args:
comm: object for cubed sphere or tile inter-process communication
Expand All @@ -127,7 +125,7 @@ def __init__(
obj=self,
config=stencil_factory.config.dace_config,
method_to_orchestrate="compute_preamble",
dace_compiletime_args=["state", "is_root_rank"],
dace_compiletime_args=["state"],
)

orchestrate(
Expand Down Expand Up @@ -181,7 +179,7 @@ def __init__(
# have not implemented, so they are hard-coded here.
self.call_checkpointer = checkpointer is not None
if checkpointer is None:
self.checkpointer: Checkpointer = NullCheckpointer()
self.checkpointer = NullCheckpointer()
else:
self.checkpointer = checkpointer
nested = False
Expand Down Expand Up @@ -336,7 +334,7 @@ def __init__(
def _get_da_min(self) -> float:
return self._da_min

def _checkpoint_fvdynamics(self, state: DycoreState, tag: str):
def _checkpoint_fvdynamics(self, state: DycoreState, tag: str) -> None:
if self.call_checkpointer:
self.checkpointer(
f"FVDynamics-{tag}",
Expand All @@ -355,10 +353,7 @@ def _checkpoint_fvdynamics(self, state: DycoreState, tag: str):
qvapor=state.qvapor,
)

def _checkpoint_remapping_in(
self,
state: DycoreState,
):
def _checkpoint_remapping_in(self, state: DycoreState) -> None:
if self.call_checkpointer:
self.checkpointer(
"Remapping-In",
Expand Down Expand Up @@ -386,10 +381,7 @@ def _checkpoint_remapping_in(
dp1=self._dp_initial,
)

def _checkpoint_remapping_out(
self,
state: DycoreState,
):
def _checkpoint_remapping_out(self, state: DycoreState) -> None:
if self.call_checkpointer:
self.checkpointer(
"Remapping-Out",
Expand All @@ -411,10 +403,7 @@ def _checkpoint_remapping_out(
dp1=self._dp_initial,
)

def _checkpoint_tracer_advection_in(
self,
state: DycoreState,
):
def _checkpoint_tracer_advection_in(self, state: DycoreState) -> None:
if self.call_checkpointer:
self.checkpointer(
"Tracer2D1L-In",
Expand All @@ -425,10 +414,7 @@ def _checkpoint_tracer_advection_in(
cyd=state.cyd,
)

def _checkpoint_tracer_advection_out(
self,
state: DycoreState,
):
def _checkpoint_tracer_advection_out(self, state: DycoreState) -> None:
if self.call_checkpointer:
self.checkpointer(
"Tracer2D1L-Out",
Expand All @@ -439,28 +425,22 @@ def _checkpoint_tracer_advection_out(
cyd=state.cyd,
)

def step_dynamics(
self,
state: DycoreState,
timer: Timer | None = None,
):
def step_dynamics(self, state: DycoreState, timer: Timer) -> None:
"""
Step the model state forward by one timestep.

Args:
state: model prognostic state and inputs
timer: keep time of model sections
"""
if timer is None:
timer = NullTimer()

self._checkpoint_fvdynamics(state=state, tag="In")
self._compute(state, timer)
self._checkpoint_fvdynamics(state=state, tag="Out")

def compute_preamble(self, state: DycoreState, is_root_rank: bool):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated: is_root_rank is unused and can thus be removed

def compute_preamble(self, state: DycoreState) -> None:
if self.config.hydrostatic:
raise NotImplementedError("Hydrostatic is not implemented")

if __debug__:
log_on_rank_0("FV Setup")

Expand Down Expand Up @@ -497,26 +477,24 @@ def compute_preamble(self, state: DycoreState, is_root_rank: bool):
"Dynamical Core (fv_dynamics): Adiabatic with positive kord_tm"
" is not implemented."
)
else:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated: no else needed because the if branch just raises

if __debug__:
log_on_rank_0("Adjust pt")
self._pt_to_potential_density_pt(
state.pkz,
self._dp_initial,
state.q_con,
state.pt,
)

def __call__(self, *args, **kwargs):
return self.step_dynamics(*args, **kwargs)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated: no need to explicitly return from the __call__(..) function. step_dynamics(..) also doesn't return any value.

if __debug__:
log_on_rank_0("Adjust pt")

def _compute(self, state: DycoreState, timer: Timer):
last_step = False
self.compute_preamble(
state,
is_root_rank=self.comm_rank == 0,
self._pt_to_potential_density_pt(
state.pkz,
self._dp_initial,
state.q_con,
state.pt,
)

def __call__(self, *args, **kwargs) -> None:
self.step_dynamics(*args, **kwargs)

def _compute(self, state: DycoreState, timer: Timer) -> None:
last_step = False
self.compute_preamble(state)

for k_split in dace_no_unroll(range(self._k_split)):
n_map = k_split + 1
last_step = k_split == self._k_split - 1
Expand All @@ -525,17 +503,21 @@ def _compute(self, state: DycoreState, timer: Timer):
state.delp,
self._dp_initial,
)

if __debug__:
log_on_rank_0("DynCore")

with timer.clock("DynCore"):
self.acoustic_dynamics(
state,
timestep=self._timestep / self._k_split,
n_map=n_map,
)

if self.config.z_tracer:
if __debug__:
log_on_rank_0("TracerAdvection")

with timer.clock("TracerAdvection"):
self._checkpoint_tracer_advection_in(state)
self.tracer_advection(
Expand Down Expand Up @@ -567,6 +549,7 @@ def _compute(self, state: DycoreState, timer: Timer):
# "surface" array
if __debug__:
log_on_rank_0("Remapping")

with timer.clock("Remapping"):
self._checkpoint_remapping_in(state)

Expand Down
3 changes: 2 additions & 1 deletion tests/mpi/test_doubly_periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
TilePartitioner,
)
from ndsl.grid import DampingCoefficients, GridData, MetricTerms
from ndsl.performance import NullTimer
from pyfv3 import DynamicalCore, DynamicalCoreConfig


Expand Down Expand Up @@ -134,4 +135,4 @@ def test_dycore_runs_one_step() -> None:
)

# run one step
dycore.step_dynamics(state)
dycore.step_dynamics(state, NullTimer())