Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/translate.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ jobs:
-v -s --data_path=${{ env.TEST_DATA_PATH }} \
--backend=numpy \
--threshold_overrides_file=./tests/savepoint/translate/overrides/standard.yaml \
--no_legacy_namelist \
./tests/savepoint

- name: Orchestrated dace:cpu Translate Test
Expand All @@ -125,4 +126,5 @@ jobs:
-vvv -x -s --data_path=${{ env.TEST_DATA_PATH }} \
--backend=dace:cpu \
--threshold_overrides_file=./tests/savepoint/translate/overrides/standard.yaml \
--no_legacy_namelist \
./tests/savepoint
123 changes: 52 additions & 71 deletions pyshield/_config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
from __future__ import annotations

import dataclasses
from enum import Enum, unique
from typing import List, Optional, Tuple

import f90nml
from dacite import Config, from_dict

from ndsl import MetaEnumStr
from ndsl.namelist import Namelist
from ndsl.utils import f90nml_as_dict


DEFAULT_INT = 0
DEFAULT_BOOL = False
DEFAULT_SCHEMES = ["GFS_microphysics"]
DEFAULT_PHYS_NML_GROUPS = (
"main_nml",
"coupler_nml",
"gfdl_cloud_microphysics_nml",
"integ_phys_nml",
"gfs_physics_nml",
)


@unique
Expand Down Expand Up @@ -144,6 +154,7 @@ class PhysicsConfig:
clin: float = 4.8
""""c" in lin 1983, 4.8 -- > 6. (to enhance ql -- > qs)"""
namelist_override: Optional[str] = None
target_nml_groups: Optional[Tuple[str, ...]] = DEFAULT_PHYS_NML_GROUPS
daily_mean: bool = DEFAULT_BOOL # flag to replace cosz with daily mean value

def __post_init__(self):
Expand All @@ -160,80 +171,50 @@ def __post_init__(self):
f90_nml = f90nml.read(self.namelist_override)
except FileNotFoundError:
print(f"{self.namelist_override} does not exist")
physics_config = self.from_f90nml(f90_nml)
# TODO: Find a better way to do below. Passing self.* as an argument
# to a class function of the same class is always a bit fishy.
physics_config = self.from_f90nml(f90_nml, self.target_nml_groups)
for var in physics_config.__dict__.keys():
setattr(self, var, physics_config.__dict__[var])

@classmethod
def from_f90nml(self, f90_namelist: f90nml.Namelist) -> "PhysicsConfig":
namelist = Namelist.from_f90nml(f90_namelist)
return self.from_namelist(namelist)
def from_f90nml(
cls,
nml: f90nml.Namelist,
target_groups: Tuple[str, ...] | None = DEFAULT_PHYS_NML_GROUPS,
) -> PhysicsConfig:
"""Uses the nml to create a PhysicsConfig.

Args:
nml: f90nml.Namelist
target_groups: Tuple[str,...] | None
This list will be used to specify which groups in the nml to
use when initializing the PhysicsConfig. If None, all
groups will be used. (Default: DEFAULT_PHYS_NML_GROUPS)
"""
groups = list(target_groups) if target_groups is not None else None
nml_dict = f90nml_as_dict(nml, flatten=True, target_groups=groups)
nml_dict["target_nml_groups"] = target_groups
return cls.from_dict(nml_dict)

@classmethod
def from_namelist(cls, namelist: Namelist) -> "PhysicsConfig":
return cls(
dt_atmos=namelist.dt_atmos,
hydrostatic=namelist.hydrostatic,
npx=namelist.npx,
npy=namelist.npy,
npz=namelist.npz,
nwat=namelist.nwat,
do_qa=namelist.do_qa,
c_cracw=namelist.c_cracw,
c_paut=namelist.c_paut,
c_pgacs=namelist.c_pgacs,
c_psaci=namelist.c_psaci,
ccn_l=namelist.ccn_l,
ccn_o=namelist.ccn_o,
const_vg=namelist.const_vg,
const_vi=namelist.const_vi,
const_vr=namelist.const_vr,
const_vs=namelist.const_vs,
vs_fac=namelist.vs_fac,
vg_fac=namelist.vg_fac,
vi_fac=namelist.vi_fac,
vr_fac=namelist.vr_fac,
de_ice=namelist.de_ice,
layout=namelist.layout,
tau_imlt=namelist.tau_imlt,
tau_i2s=namelist.tau_i2s,
tau_g2v=namelist.tau_g2v,
tau_v2g=namelist.tau_v2g,
ql_mlt=namelist.ql_mlt,
qs_mlt=namelist.qs_mlt,
t_sub=namelist.t_sub,
qi_gen=namelist.qi_gen,
qi_lim=namelist.qi_lim,
qi0_max=namelist.qi0_max,
rad_snow=namelist.rad_snow,
rad_rain=namelist.rad_rain,
dw_ocean=namelist.dw_ocean,
dw_land=namelist.dw_land,
tau_l2v=namelist.tau_l2v,
c2l_ord=namelist.c2l_ord,
do_sedi_heat=namelist.do_sedi_heat,
do_sedi_w=namelist.do_sedi_w,
fast_sat_adj=namelist.fast_sat_adj,
qc_crt=namelist.qc_crt,
fix_negative=namelist.fix_negative,
irain_f=namelist.irain_f,
mp_time=namelist.mp_time,
prog_ccn=namelist.prog_ccn,
qi0_crt=namelist.qi0_crt,
qs0_crt=namelist.qs0_crt,
rh_inc=namelist.rh_inc,
rh_inr=namelist.rh_inr,
rthresh=namelist.rthresh,
sedi_transport=namelist.sedi_transport,
use_ppm=namelist.use_ppm,
vg_max=namelist.vg_max,
vi_max=namelist.vi_max,
vr_max=namelist.vr_max,
vs_max=namelist.vs_max,
z_slope_ice=namelist.z_slope_ice,
z_slope_liq=namelist.z_slope_liq,
tice=namelist.tice,
alin=namelist.alin,
clin=namelist.clin,
daily_mean=namelist.daily_mean,
def from_dict(
cls,
data: dict,
) -> PhysicsConfig:
"""Create a PhysicsConfig from the given data.

Args:
data: "flattened" dictionary where the keys match the class member variables
"""
# NOTE: We're setting strict to False so that extra keys in the data are
# ignored. Eventually, we'd like to turn this to True once we move away from
# expecting dicts that are basically flattened f90nml files.
dacite_config = Config(
strict=False,
type_hooks={
Tuple[int, int]: lambda x: tuple(x),
Tuple[str, ...]: lambda x: tuple(x) if x is not None else None,
},
)
return from_dict(data_class=PhysicsConfig, data=data, config=dacite_config)
8 changes: 4 additions & 4 deletions tests/savepoint/translate/translate_atmos_phy_statein.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,19 @@ def __init__(self, grid, namelist, stencil_factory):
"pt": {
"serialname": "IPD_tgrs",
"out_roll_zero": True,
"kend": namelist.npz - 1,
"kend": self.config.npz - 1,
"order": "F",
},
"qgrs": {
"serialname": "IPD_qgrs",
"kend": namelist.npz,
"kend": self.config.npz,
"order": "F",
"manual": True,
},
"delp": {
"serialname": "IPD_prsl",
"out_roll_zero": True,
"kend": namelist.npz - 1,
"kend": self.config.npz - 1,
"order": "F",
},
}
Expand Down Expand Up @@ -82,7 +82,7 @@ def post_process_qgrs(self, inputs):
self.update_info(info, inputs)
ds = self.grid.compute_dict()
ds.update(info)
k_length = info["kend"] if "kend" in info else self.namelist.npz
k_length = info["kend"] if "kend" in info else self.config.npz
index_order = info["order"] if "order" in info else "C"
ij_slice = self.grid.slice_dict(ds)
qgrs = qgrs[ij_slice[0], ij_slice[1], 0:k_length, :]
Expand Down
4 changes: 2 additions & 2 deletions tests/savepoint/translate/translate_dcyc.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def __init__(self, grid, namelist, stencil_factory):
"adjvisdfd": {"shield": True},
}
self.grid_indexing = stencil_factory.grid_indexing
self._daily_mean = namelist.daily_mean
self._daily_mean = self.config.daily_mean
self.compute_func = stencil_factory.from_origin_domain(
interpolate_radiation,
externals={"daily_mean": namelist.daily_mean},
externals={"daily_mean": self.config.daily_mean},
origin=self.grid_indexing.origin_full(),
domain=self.grid_indexing.domain_full(),
)
Expand Down
2 changes: 1 addition & 1 deletion tests/savepoint/translate/translate_fillgfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, grid, namelist, stencil_factory):
"q": {"serialname": "IPD_gq0"},
}
self.out_vars = {
"q": {"serialname": "IPD_qvapor", "kend": namelist.npz - 1},
"q": {"serialname": "IPD_qvapor", "kend": self.config.npz - 1},
}
self.grid_indexing = stencil_factory.grid_indexing
self.compute_func = stencil_factory.from_origin_domain(
Expand Down
3 changes: 2 additions & 1 deletion tests/savepoint/translate/translate_fpvs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from f90nml import Namelist

from ndsl import Namelist, StencilFactory
from ndsl import StencilFactory
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.gt4py import FORWARD, PARALLEL, computation, interval
from ndsl.dsl.typing import FloatField, FloatFieldIJ
Expand Down
18 changes: 6 additions & 12 deletions tests/savepoint/translate/translate_fv_update_phys.py
Comment thread
jjuyeonkim marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import dataclasses

import numpy as np
from f90nml import Namelist

import ndsl.dsl.gt4py_utils as utils
from ndsl import Namelist, Quantity, StencilFactory
from ndsl import Quantity, StencilFactory
from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM
from ndsl.dsl.typing import FloatField, FloatFieldIJ
from ndsl.utils import safe_assign_array
Expand All @@ -14,12 +15,6 @@
)


try:
import cupy as cp
except ImportError:
cp = None


try:
import cupy as cp
except ImportError:
Expand Down Expand Up @@ -78,7 +73,7 @@ def __init__(
"iend": self.grid_indexing.iec,
"jstart": self.grid_indexing.jsc,
"jend": self.grid_indexing.jec,
"kend": namelist.npz,
"kend": self.config.npz,
"kaxis": 1,
},
"delp": {},
Expand All @@ -89,7 +84,7 @@ def __init__(
"iend": self.grid_indexing.iec + 1,
"jstart": self.grid_indexing.jsc - 1,
"jend": self.grid_indexing.jec + 1,
"kend": namelist.npz + 1,
"kend": self.config.npz + 1,
"kaxis": 1,
},
"pk": grid.compute_buffer_k_dict(),
Expand All @@ -107,7 +102,6 @@ def __init__(
"ua": {},
"va": {},
}
self.namelist = namelist

def transform_dwind_serialized_data(self, data):
return transform_dwind_serialized_data(
Expand Down Expand Up @@ -185,7 +179,7 @@ def compute_parallel(self, inputs, communicator):
self.stencil_factory,
self.grid.quantity_factory,
self.grid.grid_data,
self.namelist,
self.config,
communicator,
self.grid.driver_grid_data,
state,
Expand Down Expand Up @@ -213,7 +207,7 @@ def compute_parallel(self, inputs, communicator):
tendencies["u_dt"],
tendencies["v_dt"],
tendencies["t_dt"],
dt=float(self.namelist.dt_atmos),
dt=float(self.config.dt_atmos),
)
out = {}
ds = self.grid.default_domain_dict()
Expand Down
Loading