Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update prov attributes combine #1116

Merged
merged 4 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 94 additions & 8 deletions echopype/echodata/combine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import re
from collections import ChainMap
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from warnings import warn
Expand Down Expand Up @@ -643,6 +644,50 @@ def _capture_prov_attrs(
return prov_ds


def _get_prov_attrs(
ds: xr.Dataset, is_combined: bool = True
) -> Optional[Dict[str, List[Dict[str, str]]]]:
"""
Get the provenance attributes from the dataset.
This function is meant to be used on an already combined dataset.

Parameters
----------
ds : xr.Dataset
The provenance dataset to get attributes from
is_combined: bool
The flag to indicate if it's combined

Returns
-------
Dict[str, List[Dict[str, str]]]
The provenance attributes
"""

if is_combined:
attrs_dict = {}
for k, v in ds.data_vars.items():
# Go through each data variable and extract the attribute values
# based on the echodata group as stored in the variable attribute
if ED_GROUP in v.attrs:
ed_group = v.attrs[ED_GROUP]
if ed_group not in attrs_dict:
attrs_dict[ed_group] = []
# Store the values as a list of dictionary for each group
attrs_dict[ed_group].append([{k: i} for i in v.values])

# Merge the attributes for each group so it matches the
# attributes dict for later merging
return {
ed_group: [
dict(ChainMap(*v))
for _, v in pd.DataFrame.from_dict(attrs).to_dict(orient="list").items()
]
for ed_group, attrs in attrs_dict.items()
}
return None


def _combine(
sonar_model: str,
eds: List[EchoData] = [],
Expand Down Expand Up @@ -680,7 +725,27 @@ def _combine(
attrs_dict = {}

# Check if input data are combined datasets
all_combined = all(ed["Provenance"].attrs.get("is_combined", False) for ed in eds)
# Create combined mapping for later use
combined_mapping = []
for idx, ed in enumerate(eds):
is_combined = ed["Provenance"].attrs.get("is_combined", False)
combined_mapping.append(
{
"is_combined": is_combined,
"attrs_dict": _get_prov_attrs(ed["Provenance"], is_combined),
"echodata_filename": [str(s) for s in ed["Provenance"][ED_FILENAME].values]
if is_combined
else [echodata_filenames[idx]],
}
)
# Get single boolean value to see if there's any combined files
any_combined = any(d["is_combined"] for d in combined_mapping)

if any_combined:
# Fetches the true echodata filenames if there are any combined files
echodata_filenames = list(
itertools.chain.from_iterable([d[ED_FILENAME] for d in combined_mapping])
)

# Create Echodata tree dict
tree_dict = {}
Expand All @@ -697,8 +762,28 @@ def _combine(
]

if ds_list:
# Get all of the keys and attributes
ds_attrs = [ds.attrs for ds in ds_list]
if not any_combined:
# Get all of the keys and attributes
# for regular non combined echodata object
ds_attrs = [ds.attrs for ds in ds_list]
else:
# If there are any combined files,
# iterate through from mapping above
ds_attrs = []
for idx, ds in enumerate(ds_list):
# Retrieve the echodata attrs dict
# parsed from provenance group above
ed_attrs_dict = combined_mapping[idx]["attrs_dict"]
if ed_attrs_dict is not None:
# Set attributes to the appropriate group
# from echodata attrs provenance,
# set default empty dict for missing group
attrs = ed_attrs_dict.get(ed_group, {})
else:
# This is for non combined echodata object
attrs = [ds.attrs]
ds_attrs += attrs

# Attribute holding
attrs_dict[ed_group] = ds_attrs

Expand Down Expand Up @@ -753,15 +838,16 @@ def _combine(
# Data holding
tree_dict[ed_group] = combined_ds

if not all_combined:
# Capture provenance for all the attributes
prov_ds = _capture_prov_attrs(attrs_dict, echodata_filenames, sonar_model)
# Capture provenance for all the attributes
prov_ds = _capture_prov_attrs(attrs_dict, echodata_filenames, sonar_model)
if not any_combined:
# Update the provenance dataset with the captured data
prov_ds = tree_dict["Provenance"].assign(prov_ds)
else:
prov_ds = tree_dict["Provenance"]
prov_ds = tree_dict["Provenance"].drop_dims(ED_FILENAME).assign(prov_ds)

# Update filenames to iter integers
prov_ds[FILENAMES] = prov_ds[FILENAMES].copy(data=np.arange(*prov_ds[FILENAMES].shape))
prov_ds[FILENAMES] = prov_ds[FILENAMES].copy(data=np.arange(*prov_ds[FILENAMES].shape)) # noqa
tree_dict["Provenance"] = prov_ds

return tree_dict
Expand Down
41 changes: 41 additions & 0 deletions echopype/tests/echodata/test_echodata_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,25 @@ def attr_time_to_dt(time_str):
assert test_ds.identical(combined_group.drop_dims(grp_drop_dims))


def _check_prov_ds(prov_ds, eds):
"""Checks the provenance dataset against original echodata object"""
for i in range(prov_ds.dims["echodata_filename"]):
ed_ds = eds[i]
one_ds = prov_ds.isel(echodata_filename=i, filenames=i)
for key, value in one_ds.data_vars.items():
if key == "source_filenames":
ed_group = "Provenance"
assert np.array_equal(
ed_ds[ed_group][key].isel(filenames=0).values, value.values
)
else:
ed_group = value.attrs.get("echodata_group")
group_attrs = ed_ds[ed_group].attrs
expected_val = group_attrs[key]
if not isinstance(expected_val, str):
expected_val = str(expected_val)
assert str(value.values) == expected_val

@pytest.mark.parametrize("test_param", [
"single",
"multi",
Expand Down Expand Up @@ -263,6 +282,16 @@ def test_combine_echodata_combined_append(ek60_multi_test_data, test_param, sona
# First combined file
combined_ed = echopype.combine_echodata(eds[:2])
combined_ed.to_zarr(first_zarr, overwrite=True)

# Checks for Provenance group
prov_ds = combined_ed["Provenance"]
for _, n_val in prov_ds.dims.items():
# Both dims of filenames and echodata filename
# should be 2 at this point
assert n_val == 2

_check_prov_ds(prov_ds, eds)


# Second combined file
combined_ed_other = echopype.combine_echodata(eds[2:])
Expand All @@ -271,8 +300,11 @@ def test_combine_echodata_combined_append(ek60_multi_test_data, test_param, sona
combined_ed = echopype.open_converted(first_zarr)
combined_ed_other = echopype.open_converted(second_zarr)

# Set expected values for Provenance
expected_n_vals = 4
if test_param == "single":
data_inputs = [combined_ed, eds[2]]
expected_n_vals = 3
elif test_param == "multi":
data_inputs = [combined_ed, eds[2], eds[3]]
else:
Expand Down Expand Up @@ -313,6 +345,15 @@ def test_combine_echodata_combined_append(ek60_multi_test_data, test_param, sona

filt_combined = combined_ed2[group_path].sel(ping_time=combined_ed[group_path].ping_time)
assert filt_combined.identical(combined_ed[group_path])

# Checks for Provenance group
prov_ds = combined_ed2["Provenance"]
for _, n_val in prov_ds.dims.items():
# Both dims of filenames and echodata filename
# should be expected_n_vals at this point
assert n_val == expected_n_vals

_check_prov_ds(prov_ds, eds)


def test_combine_echodata_channel_selection():
Expand Down