diff --git a/particle_tracking_manager/models/opendrift/config.json b/particle_tracking_manager/models/opendrift/config.json index 6b1edae..9d12f85 100644 --- a/particle_tracking_manager/models/opendrift/config.json +++ b/particle_tracking_manager/models/opendrift/config.json @@ -204,15 +204,5 @@ "default": "low", "ptm_level": 3, "description": "Log verbosity" - }, - "output_format": { - "type": "enum", - "enum": [ - "netcdf", - "parquet" - ], - "default": "netcdf", - "description": "Output file format. Options are \"netcdf\" or \"parquet\".", - "ptm_level": 2 } } diff --git a/particle_tracking_manager/models/opendrift/opendrift.py b/particle_tracking_manager/models/opendrift/opendrift.py index de8121d..5b46d53 100644 --- a/particle_tracking_manager/models/opendrift/opendrift.py +++ b/particle_tracking_manager/models/opendrift/opendrift.py @@ -144,8 +144,6 @@ class OpenDriftModel(ParticleTrackingManager): Oil mass is biodegraded (eaten by bacteria). log : str, optional Options are "low" and "high" verbosity for log, by default "low" - output_format : str, default "netcdf" - Name of input/output module type to use for writing Lagrangian model output. Default is "netcdf". 
Notes ----- @@ -166,7 +164,6 @@ class OpenDriftModel(ParticleTrackingManager): o: Union[OceanDrift, Leeway, LarvalFish, OpenOil] horizontal_diffusivity: Optional[float] config_model: dict - output_format: str def __init__( self, @@ -224,7 +221,6 @@ def __init__( ], biodegradation: bool = config_model["biodegradation"]["default"], log: str = config_model["log"]["default"], - output_format: str = config_model["output_format"]["default"], **kw, ) -> None: """Inputs for OpenDrift model.""" @@ -252,6 +248,9 @@ def __init__( # so do this before super initialization self.__dict__["drift_model"] = drift_model + # I left this code here but it isn't used for now + # it will be used if we can export to parquet/netcdf directly + # without needing to resave the file with extra config # # need output_format defined right away # self.__dict__["output_format"] = output_format @@ -263,7 +262,7 @@ def __init__( elif self.drift_model == "OceanDrift": o = OceanDrift( - loglevel=self.loglevel + loglevel=self.loglevel, ) # , output_format=self.output_format) elif self.drift_model == "LarvalFish": @@ -993,12 +992,6 @@ def run_drifters(self): self.o._config = config_input_to_opendrift # only OpenDrift config - output_file = ( - self.output_file - or f"output-results_{datetime.datetime.now():%Y-%m-%dT%H%M:%SZ}" - ) - output_file_initial = f"{output_file}_initial" + ".nc" - # initially output to netcdf even if parquet has been selected # since I do this weird 2 step saving process @@ -1014,39 +1007,33 @@ def run_drifters(self): time_step_output=self.time_step_output, steps=self.steps, export_variables=self.export_variables, - outfile=output_file_initial, + outfile=self.output_file_initial, ) self.o._config = full_config # reinstate config # open outfile file and add config to it # config can't be present earlier because it breaks opendrift - ds = xr.open_dataset(output_file_initial) + ds = xr.open_dataset(self.output_file_initial) for k, v in self.drift_model_config(): if isinstance(v, 
(bool, type(None), pd.Timestamp, pd.Timedelta)): v = str(v) ds.attrs[f"ptm_config_{k}"] = v if self.output_format == "netcdf": - output_file += ".nc" - elif self.output_format == "parquet": - output_file += ".parq" - else: - raise ValueError(f"output_format {self.output_format} not recognized.") - - if self.output_format == "netcdf": - ds.to_netcdf(output_file) + ds.to_netcdf(self.output_file) elif self.output_format == "parquet": - ds.to_dataframe().to_parquet(output_file) + ds.to_dataframe().to_parquet(self.output_file) else: raise ValueError(f"output_format {self.output_format} not recognized.") # update with new path name - self.o.outfile_name = output_file + self.o.outfile_name = self.output_file + self.output_file = self.output_file try: # remove initial file to save space - os.remove(output_file_initial) + os.remove(self.output_file_initial) except PermissionError: # windows issue pass diff --git a/particle_tracking_manager/the_manager.py b/particle_tracking_manager/the_manager.py index d4fee25..da35c6d 100644 --- a/particle_tracking_manager/the_manager.py +++ b/particle_tracking_manager/the_manager.py @@ -137,6 +137,8 @@ class ParticleTrackingManager: will be less accurate, especially in the tidal flat regions of the model. output_file : Optional[str], optional Name of output file to save, by default None. If None, default is set in the model. Without any suffix. + output_format : str, default "netcdf" + Name of input/output module type to use for writing Lagrangian model output. Default is "netcdf". 
Notes ----- @@ -160,6 +162,8 @@ class ParticleTrackingManager: config_ptm: dict config_model: Optional[dict] seed_seafloor: bool + output_file: str + output_format: str def __init__( self, @@ -187,6 +191,7 @@ def __init__( vertical_mixing: bool = config_ptm["vertical_mixing"]["default"], use_static_masks: bool = config_ptm["use_static_masks"]["default"], output_file: Optional[str] = config_ptm["output_file"]["default"], + output_format: str = config_ptm["output_format"]["default"], **kw, ) -> None: """Inputs necessary for any particle tracking.""" @@ -364,6 +369,43 @@ def __setattr__(self, name: str, value) -> None: self.__dict__["lon"] += 360 self.config_ptm["lon"]["value"] += 360 # this isn't really used + if name in ["output_file", "output_format"]: + # remove netcdf suffix if it is there to just have name + # by this point, output_file should already be a filename like what is + # available here, from OpenDrift (if run from there) + # NOTE: only a literal ".nc" suffix is removed; str.rstrip(".nc") strips a + # character set and would also eat trailing "n", "c" or "." (e.g. "ocean"). + if self.output_file is not None: + output_file = ( + self.output_file[: -len(".nc")] + if self.output_file.endswith(".nc") + else self.output_file + ) + else: + output_file = ( + f"output-results_{datetime.datetime.now():%Y-%m-%dT%H%M:%SZ}" + ) + + # make new attribute for initial output file + self.output_file_initial = str( + pathlib.Path(f"{output_file}_initial").with_suffix(".nc") + ) + + if self.output_format is not None: + if self.output_format == "netcdf": + output_file = str(pathlib.Path(output_file).with_suffix(".nc")) + elif self.output_format == "parquet": + output_file = str( + pathlib.Path(output_file).with_suffix(".parq") + ) + else: + raise ValueError( + f"output_format {self.output_format} not recognized." + ) + + self.__dict__["output_file"] = output_file + self.config_ptm["output_file"]["value"] = output_file + if name == "surface_only" and value: self.logger.info( "Overriding values for do3D, z, and vertical_mixing because surface_only is True (to False, 0, False)."
diff --git a/particle_tracking_manager/the_manager_config.json b/particle_tracking_manager/the_manager_config.json index 595371e..064460d 100644 --- a/particle_tracking_manager/the_manager_config.json +++ b/particle_tracking_manager/the_manager_config.json @@ -177,5 +177,15 @@ "default": "None", "description": "Name of file to write output to. If None, default name is used.", "ptm_level": 3 + }, + "output_format": { + "type": "enum", + "enum": [ + "netcdf", + "parquet" + ], + "default": "netcdf", + "description": "Output file format. Options are \"netcdf\" or \"parquet\".", + "ptm_level": 2 } } diff --git a/tests/test_opendrift.py b/tests/test_opendrift.py index 5a87597..3cb26a4 100644 --- a/tests/test_opendrift.py +++ b/tests/test_opendrift.py @@ -322,6 +322,16 @@ def test_output_format(): assert m.output_format == "parquet" +def test_output_file(): + """make sure output file is parquet if output_format is parquet""" + + m = OpenDriftModel(output_format="parquet") + assert m.output_file.endswith(".parq") + + m = OpenDriftModel(output_format="netcdf") + assert m.output_file.endswith(".nc") + + def test_horizontal_diffusivity_logic(): """Check logic for using default horizontal diff values for known models."""