diff --git a/pyaerocom/colocation_auto.py b/pyaerocom/colocation_auto.py index f4b41b17f..9f3c0d3fe 100644 --- a/pyaerocom/colocation_auto.py +++ b/pyaerocom/colocation_auto.py @@ -593,7 +593,7 @@ def _instantiate_gridded_reader(self, what): reader_class = self._get_gridded_reader_class(what=what) if what == "model" and reader_class in self.MODELS_WITH_KWARGS: reader = reader_class( - data_id=data_id, data_dir=data_dir, **self.colocation_setup.model_read_kwargs + data_id=data_id, data_dir=data_dir, **self.colocation_setup.model_kwargs ) else: reader = reader_class(data_id=data_id, data_dir=data_dir) diff --git a/pyaerocom/colocation_setup.py b/pyaerocom/colocation_setup.py index 6d8c379c4..3203679e0 100644 --- a/pyaerocom/colocation_setup.py +++ b/pyaerocom/colocation_setup.py @@ -196,9 +196,7 @@ class ColocationSetup(BaseModel): active, only single year analysis are supported (i.e. provide int to :attr:`start` to specify the year and leave :attr:`stop` empty). model_kwargs: dict - Key word arguments to be given to the model reader class's read_var function - model_read_kwargs: dict - Key word arguments to be given to the model reader class's init function + Key word arguments to be given to the model reader class's read_var and init function gridded_reader_id : dict BETA: dictionary specifying which gridded reader is supposed to be used for model (and gridded obs) reading. Note: this is a workaround @@ -390,15 +388,11 @@ def validate_basedirs(cls, v): model_to_stp: bool = False model_ts_type_read: str | dict | None = None - model_read_aux: dict[ - str, dict[Literal["vars_required", "fun"], list[str] | Callable] - ] | None = {} + model_read_aux: ( + dict[str, dict[Literal["vars_required", "fun"], list[str] | Callable]] | None + ) = {} model_use_climatology: bool = False - model_kwargs: dict = {} - # model_read_kwargs are arguments that are sent to the model reader - model_read_kwargs: dict = {} - gridded_reader_id: dict[str, str] = {"model": "ReadGridded", "obs": "ReadGridded"} flex_ts_type: bool = True @@ -423,6 +417,19 @@ def validate_basedirs(cls, v): keep_data: bool = True add_meta: dict | None = {} + model_kwargs: dict = {} + + @field_validator("model_kwargs") + @classmethod + def validate_kwargs(cls, v): + forbidden = [ + "vert_which", + ] # Forbidden key names which are not found in colocation_setup.model_field, or has another name there + for key in v: + if key in list(cls.model_fields.keys()) + forbidden: + raise ValueError(f"Key {key} not allowed in model_kwargs") + return v + # Override __init__ to allow for positional arguments def __init__( self, diff --git a/tests/test_colocation_setup.py b/tests/test_colocation_setup.py index 2100cff45..3619e1782 100644 --- a/tests/test_colocation_setup.py +++ b/tests/test_colocation_setup.py @@ -1,6 +1,7 @@ from pathlib import Path import pytest +from pydantic import ValidationError from pyaerocom import const from pyaerocom.colocation_setup import ColocationSetup @@ -65,3 +66,15 @@ def test_ColocationSetup(stp: ColocationSetup, should_be: dict): assert Path(val) == Path(stp_dict["basedir_coldata"]) else: assert val == stp_dict[key], key + + +def test_ColocationSetup_model_kwargs_validationerror() -> None: + stp_dict = default_setup + + with pytest.raises(ValidationError): + stp_dict["model_kwargs"] = "not a dict" + stp = ColocationSetup(**stp_dict) + + with pytest.raises(ValidationError): + stp_dict["model_kwargs"] = {"emep_vars": {}, "ts_type": "daily"} + stp = ColocationSetup(**stp_dict)