diff --git a/pyshield/_config.py b/pyshield/_config.py index 2be7cf02..b8cfe547 100644 --- a/pyshield/_config.py +++ b/pyshield/_config.py @@ -67,9 +67,6 @@ class PhysicsConfig: nwat: int = DEFAULT_INT schemes: List = None ntracers: int = int(len(tracer_variables)) - ntiw: int = DEFAULT_INT - ntcw: int = DEFAULT_INT - ntke: int = DEFAULT_INT do_qa: bool = DEFAULT_BOOL do_inline_mp: bool = False """Whether microphysics is inlined in the dycore""" diff --git a/pyshield/physics_state.py b/pyshield/physics_state.py index 5a413639..c28835a7 100644 --- a/pyshield/physics_state.py +++ b/pyshield/physics_state.py @@ -487,6 +487,7 @@ def init_from_storages( _field.metadata["units"], origin=sizer.get_origin(dims), extent=sizer.get_extent(dims), + backend=quantity_factory.backend, ) else: quantity = quantity_factory.zeros( diff --git a/pyshield/radiation/rad_astro.py b/pyshield/radiation/rad_astro.py index 4842b674..7557df61 100644 --- a/pyshield/radiation/rad_astro.py +++ b/pyshield/radiation/rad_astro.py @@ -1,7 +1,6 @@ import datetime import re from pathlib import Path -from typing import Union import numpy as np @@ -65,7 +64,7 @@ def date_to_julian(iyear: int, imonth: int, iday: int) -> int: def read_NOAA_solar_file( solar_fname: Path, -) -> dict[str : Union[float, int, dict[str : Union[int, float, dict[str:float]]]]]: +) -> dict[str, float | int | dict[str, int | float | dict[str, float]]]: """ function to read in a file of solar constants structured as: first_year last_year first_cycle last_cycle mean_value info_string @@ -98,7 +97,7 @@ def read_NOAA_solar_file( solar_constant_data["icy1"] = int(table_dat[2]) solar_constant_data["icy2"] = int(table_dat[3]) solar_constant_data["smean"] = float(table_dat[4]) - elif re.fullmatch(r"^(\*)\1{1,}$", table_dat[0]): + elif re.fullmatch(r"^(\*)\1{1,}$", table_dat[0]) is not None: break # end at the asterisks else: year = int(table_dat[0]) @@ -350,7 +349,7 @@ def solar_update( deltim: float, lsol_chg: bool, iyr_sav: int, - isolflg: bool, + isolflg: int, solar_constant_data: dict = None, ) -> tuple[float, float, float, float, float, float, int, int]: """ diff --git a/pyshield/radiation/rad_gases.py b/pyshield/radiation/rad_gases.py index 7bb67a84..0b2e8c0d 100644 --- a/pyshield/radiation/rad_gases.py +++ b/pyshield/radiation/rad_gases.py @@ -1,7 +1,6 @@ import os import re from pathlib import Path -from typing import Union import numpy as np @@ -32,7 +31,7 @@ F113VMR_DEF = 8.2000e-11 # gfdl 1999 value -def read_global_annual_co2(co2gbl_file: Path) -> dict[str : Union[Float, Int]]: +def read_global_annual_co2(co2gbl_file: Path) -> dict[str, Float | Int]: """ Function to read a text file of CO2 global half-yearly means and growth rates into a model. @@ -62,7 +61,7 @@ def read_global_annual_co2(co2gbl_file: Path) -> dict[str : Union[Float, Int]]: def read_monthly_resolved_co2( co2dat_file: Path, -) -> tuple[int, dict[Union[str, Int] : Union[Float, list[Float]]]]: +) -> tuple[int, dict[str | Int, Float | list[Float]]]: """ Function to read a text file of 15-degree CO2 monthly means into a model. Assumes a format of: @@ -103,7 +102,7 @@ def read_monthly_resolved_co2( def read_monthly_cycle_co2( co2cyc_file: Path, -) -> dict[Union[str, Int] : Union[Int, Float, dict[str : Union[Float, list[Float]]]]]: +) -> dict[str | Int, Int | Float | dict[str, Float | list[Float]]]: """ Function to read a text file of 15-degree CO2 monthly deviations into a model. Assumes a format of: @@ -175,7 +174,6 @@ def read_co2_files(input_dir: Path, prefix="") -> tuple[dict, dict, dict]: if co2_glob_file.is_file(): co2_glb_data = read_global_annual_co2(co2_glob_file) if monthly_co2_files: - co2_mvr_data = {} for monthly_file in monthly_co2_files: year, data = read_monthly_resolved_co2(input_dir.joinpath(monthly_file)) co2_mvr_data[year] = data @@ -242,6 +240,7 @@ def gas_init( float, float, float, + float, np.ndarray, np.ndarray, dict, @@ -336,6 +335,7 @@ def gas_init( cfc11, cfc12, cfc22, + cfc113, ccl4, co2_glb, co2_arr, diff --git a/pyshield/radiation/state.py b/pyshield/radiation/state.py index c2c0c886..33fbfb24 100644 --- a/pyshield/radiation/state.py +++ b/pyshield/radiation/state.py @@ -312,6 +312,7 @@ def init_from_storages( _field.metadata["units"], origin=sizer.get_origin(dims), extent=sizer.get_extent(dims), + backend=quantity_factory.backend, ) else: quantity = quantity_factory.zeros( @@ -358,7 +359,7 @@ def to_rterrtmgp_xr(self) -> xr.Dataset: for name, field_info in self.__dataclass_fields__.items(): if name not in ["quantity_factory", "np_like"]: if field_info.metadata["intent"] != "out": - if issubclass(field_info.type, Quantity): + if issubclass(field_info.type, Quantity): # type: ignore[arg-type] dims = [] slice_list = [] ndims = len(field_info.metadata["dims"]) @@ -366,22 +367,30 @@ def to_rterrtmgp_xr(self) -> xr.Dataset: for dim_name in field_info.metadata["dims"]: # dims.append(f"{dim_name}_{name}") if dim_name == "z_interface": - slice_list.append(self._np.s_[:]) + slice_list.append( + self._np.s_[:] # type: ignore[attr-defined] + ) nz = self._nz + 1 dims.append("level") elif dim_name == "z": - slice_list.append(self._np.s_[:-1]) + slice_list.append( + self._np.s_[:-1] # type: ignore[attr-defined] + ) dims.append("layer") elif "INTERFACE" in dim_name: - slice_list.append(self._np.s_[3:-3]) + slice_list.append( + self._np.s_[3:-3] # type: ignore[attr-defined] + ) else: - slice_list.append(self._np.s_[3:-4]) + slice_list.append( + self._np.s_[3:-4] # type: ignore[attr-defined] + ) # We have to reshape to get the max 2D shape rterrtmgp expects: if ndims == 3: # x-y-z array: newshape = (-1, nz) dims.insert(0, "column") elif ndims == 2: # x-y array: - newshape = (-1,) # noqa + newshape = (-1,) # type: ignore[assignment] dims.insert(0, "column") elif ndims == 1: # z-array newshape == (nz,) diff --git a/pyshield/stencils/gfdl_cld_microphysics/gfdl_cld_microphysics_state.py b/pyshield/stencils/gfdl_cld_microphysics/gfdl_cld_microphysics_state.py index 4cabeeeb..2fa20c03 100644 --- a/pyshield/stencils/gfdl_cld_microphysics/gfdl_cld_microphysics_state.py +++ b/pyshield/stencils/gfdl_cld_microphysics/gfdl_cld_microphysics_state.py @@ -519,6 +519,7 @@ def init_from_storages( _field.metadata["units"], origin=sizer.get_origin(_field.metadata["dims"]), extent=sizer.get_extent(_field.metadata["dims"]), + backend=quantity_factory.backend, ) return cls(**inputs) diff --git a/pyshield/stencils/pbl/satmedmfvdiff_state.py b/pyshield/stencils/pbl/satmedmfvdiff_state.py index 1fe94069..188da496 100644 --- a/pyshield/stencils/pbl/satmedmfvdiff_state.py +++ b/pyshield/stencils/pbl/satmedmfvdiff_state.py @@ -358,6 +358,7 @@ def init_from_storages( _field.metadata["units"], origin=sizer.get_origin(dims), extent=sizer.get_extent(dims), + backend=quantity_factory.backend, ) else: quantity = quantity_factory.zeros( diff --git a/pyshield/stencils/shallow_convection/samf_shalconv_state.py b/pyshield/stencils/shallow_convection/samf_shalconv_state.py index 2b5eb507..f11294ee 100644 --- a/pyshield/stencils/shallow_convection/samf_shalconv_state.py +++ b/pyshield/stencils/shallow_convection/samf_shalconv_state.py @@ -222,6 +222,7 @@ def init_from_storages( _field.metadata["units"], origin=sizer.get_origin(dims), extent=sizer.get_extent(dims), + backend=quantity_factory.backend, ) else: quantity = quantity_factory.zeros(dims, _field.metadata["units"]) diff --git a/pyshield/stencils/surface/sfc_state.py b/pyshield/stencils/surface/sfc_state.py index 79fed153..dd9492e2 100644 --- a/pyshield/stencils/surface/sfc_state.py +++ b/pyshield/stencils/surface/sfc_state.py @@ -481,6 +481,7 @@ def init_from_storages( _field.metadata["units"], origin=sizer.get_origin(dims), extent=sizer.get_extent(dims), + backend=quantity_factory.backend, ) else: quantity = quantity_factory.zeros( diff --git a/tests/savepoint/translate/translate_fv_update_phys.py b/tests/savepoint/translate/translate_fv_update_phys.py index 84e4a20c..2e690477 100644 --- a/tests/savepoint/translate/translate_fv_update_phys.py +++ b/tests/savepoint/translate/translate_fv_update_phys.py @@ -173,6 +173,7 @@ def compute_parallel(self, inputs, communicator): units="test", origin=(0, 0, 0), extent=storage.shape, + backend=self.grid.quantity_factory.backend, ) state = DycoreState(**inputs) self._base.compute_func = ApplyPhysicsToDycore( diff --git a/tests/savepoint/translate/translate_sedimentation.py b/tests/savepoint/translate/translate_sedimentation.py index 17bd081d..e74343e5 100644 --- a/tests/savepoint/translate/translate_sedimentation.py +++ b/tests/savepoint/translate/translate_sedimentation.py @@ -717,12 +717,14 @@ def compute(self, inputs): inputs[var], dims=[X_DIM, Y_DIM, Z_DIM], units="unknown", + backend=self.quantity_factory.backend, ) elif len(inputs[var].shape) == 2: inputs[var] = Quantity( inputs[var], dims=[X_DIM, Y_DIM], units="unknown", + backend=self.quantity_factory.backend, ) else: raise TypeError(