Skip to content

Commit

Permalink
Merge pull request #32 from axiom-data-science/change_config
Browse files Browse the repository at this point in the history
moved output_format to manager config
  • Loading branch information
kthyng authored Nov 13, 2024
2 parents e4121c6 + 1655af6 commit 37f7544
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 34 deletions.
10 changes: 0 additions & 10 deletions particle_tracking_manager/models/opendrift/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -204,15 +204,5 @@
"default": "low",
"ptm_level": 3,
"description": "Log verbosity"
},
"output_format": {
"type": "enum",
"enum": [
"netcdf",
"parquet"
],
"default": "netcdf",
"description": "Output file format. Options are \"netcdf\" or \"parquet\".",
"ptm_level": 2
}
}
35 changes: 11 additions & 24 deletions particle_tracking_manager/models/opendrift/opendrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,6 @@ class OpenDriftModel(ParticleTrackingManager):
Oil mass is biodegraded (eaten by bacteria).
log : str, optional
Options are "low" and "high" verbosity for log, by default "low"
output_format : str, default "netcdf"
Name of input/output module type to use for writing Lagrangian model output. Default is "netcdf".
Notes
-----
Expand All @@ -166,7 +164,6 @@ class OpenDriftModel(ParticleTrackingManager):
o: Union[OceanDrift, Leeway, LarvalFish, OpenOil]
horizontal_diffusivity: Optional[float]
config_model: dict
output_format: str

def __init__(
self,
Expand Down Expand Up @@ -224,7 +221,6 @@ def __init__(
],
biodegradation: bool = config_model["biodegradation"]["default"],
log: str = config_model["log"]["default"],
output_format: str = config_model["output_format"]["default"],
**kw,
) -> None:
"""Inputs for OpenDrift model."""
Expand Down Expand Up @@ -252,6 +248,9 @@ def __init__(
# so do this before super initialization
self.__dict__["drift_model"] = drift_model

# I left this code here but it isn't used for now
# it will be used if we can export to parquet/netcdf directly
# without needing to resave the file with extra config
# # need output_format defined right away
# self.__dict__["output_format"] = output_format

Expand All @@ -263,7 +262,7 @@ def __init__(

elif self.drift_model == "OceanDrift":
o = OceanDrift(
loglevel=self.loglevel
loglevel=self.loglevel,
) # , output_format=self.output_format)

elif self.drift_model == "LarvalFish":
Expand Down Expand Up @@ -993,12 +992,6 @@ def run_drifters(self):

self.o._config = config_input_to_opendrift # only OpenDrift config

output_file = (
self.output_file
or f"output-results_{datetime.datetime.now():%Y-%m-%dT%H%M:%SZ}"
)
output_file_initial = f"{output_file}_initial" + ".nc"

# initially output to netcdf even if parquet has been selected
# since I do this weird 2 step saving process

Expand All @@ -1014,39 +1007,33 @@ def run_drifters(self):
time_step_output=self.time_step_output,
steps=self.steps,
export_variables=self.export_variables,
outfile=output_file_initial,
outfile=self.output_file_initial,
)

self.o._config = full_config # reinstate config

# open outfile file and add config to it
# config can't be present earlier because it breaks opendrift
ds = xr.open_dataset(output_file_initial)
ds = xr.open_dataset(self.output_file_initial)
for k, v in self.drift_model_config():
if isinstance(v, (bool, type(None), pd.Timestamp, pd.Timedelta)):
v = str(v)
ds.attrs[f"ptm_config_{k}"] = v

if self.output_format == "netcdf":
output_file += ".nc"
elif self.output_format == "parquet":
output_file += ".parq"
else:
raise ValueError(f"output_format {self.output_format} not recognized.")

if self.output_format == "netcdf":
ds.to_netcdf(output_file)
ds.to_netcdf(self.output_file)
elif self.output_format == "parquet":
ds.to_dataframe().to_parquet(output_file)
ds.to_dataframe().to_parquet(self.output_file)
else:
raise ValueError(f"output_format {self.output_format} not recognized.")

# update with new path name
self.o.outfile_name = output_file
self.o.outfile_name = self.output_file
self.output_file = self.output_file

try:
# remove initial file to save space
os.remove(output_file_initial)
os.remove(self.output_file_initial)
except PermissionError:
# windows issue
pass
Expand Down
38 changes: 38 additions & 0 deletions particle_tracking_manager/the_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ class ParticleTrackingManager:
will be less accurate, especially in the tidal flat regions of the model.
output_file : Optional[str], optional
Name of output file to save, by default None. If None, default is set in the model. Without any suffix.
output_format : str, default "netcdf"
Name of input/output module type to use for writing Lagrangian model output. Default is "netcdf".
Notes
-----
Expand All @@ -160,6 +162,8 @@ class ParticleTrackingManager:
config_ptm: dict
config_model: Optional[dict]
seed_seafloor: bool
output_file: str
output_format: str

def __init__(
self,
Expand Down Expand Up @@ -187,6 +191,7 @@ def __init__(
vertical_mixing: bool = config_ptm["vertical_mixing"]["default"],
use_static_masks: bool = config_ptm["use_static_masks"]["default"],
output_file: Optional[str] = config_ptm["output_file"]["default"],
output_format: str = config_ptm["output_format"]["default"],
**kw,
) -> None:
"""Inputs necessary for any particle tracking."""
Expand Down Expand Up @@ -364,6 +369,39 @@ def __setattr__(self, name: str, value) -> None:
self.__dict__["lon"] += 360
self.config_ptm["lon"]["value"] += 360 # this isn't really used

if name in ["output_file", "output_format"]:
# import pdb; pdb.set_trace()

# remove netcdf suffix if it is there to just have name
# by this point, output_file should already be a filename like what is
# available here, from OpenDrift (if run from there)
if self.output_file is not None:
output_file = self.output_file.rstrip(".nc")
else:
output_file = (
f"output-results_{datetime.datetime.now():%Y-%m-%dT%H%M:%SZ}"
)

# make new attribute for initial output file
self.output_file_initial = str(
pathlib.Path(f"{output_file}_initial").with_suffix(".nc")
)

if self.output_format is not None:
if self.output_format == "netcdf":
output_file = str(pathlib.Path(output_file).with_suffix(".nc"))
elif self.output_format == "parquet":
output_file = str(
pathlib.Path(output_file).with_suffix(".parq")
)
else:
raise ValueError(
f"output_format {self.output_format} not recognized."
)

self.__dict__["output_file"] = output_file
self.config_ptm["output_file"]["value"] = output_file

if name == "surface_only" and value:
self.logger.info(
"Overriding values for do3D, z, and vertical_mixing because surface_only is True (to False, 0, False)."
Expand Down
10 changes: 10 additions & 0 deletions particle_tracking_manager/the_manager_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -177,5 +177,15 @@
"default": "None",
"description": "Name of file to write output to. If None, default name is used.",
"ptm_level": 3
},
"output_format": {
"type": "enum",
"enum": [
"netcdf",
"parquet"
],
"default": "netcdf",
"description": "Output file format. Options are \"netcdf\" or \"parquet\".",
"ptm_level": 2
}
}
10 changes: 10 additions & 0 deletions tests/test_opendrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,16 @@ def test_output_format():
assert m.output_format == "parquet"


def test_output_file():
"""make sure output file is parquet if output_format is parquet"""

m = OpenDriftModel(output_format="parquet")
assert m.output_file.endswith(".parq")

m = OpenDriftModel(output_format="netcdf")
assert m.output_file.endswith(".nc")


def test_horizontal_diffusivity_logic():
"""Check logic for using default horizontal diff values for known models."""

Expand Down

0 comments on commit 37f7544

Please sign in to comment.