Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NWB support III - provenance writing #17

Merged
merged 2 commits into from
Sep 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 61 additions & 18 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,21 @@ def write_labels_to_nwb(
https://pynwb.readthedocs.io/en/stable/pynwb.file.html#pynwb.file.NWBFile

Defaults to None and default values are used to generate the nwb file.

pose_estimation_metadata (dict): This argument has a dual purpose:

1) It can be used to pass time information about the video which is
necessary for synchronizing frames in pose estimation tracking to other
modalities. Either the video timestamps can be passed with the key
`video_timestamps`, or the sampling rate can be passed with the key
`video_sample_rate`.

e.g. pose_estimation_metadata["video_timestamps"] = np.array(timestamps)
or pose_estimation_metadata["video_sample_rate"] = 15  # In Hz

2) The other use of this dictionary is to overwrite sleap-io default
arguments for the PoseEstimation container.
See https://github.com/rly/ndx-pose for a full list of arguments.
"""

nwb_file_kwargs = nwb_file_kwargs or dict()
Expand All @@ -129,7 +144,6 @@ def write_labels_to_nwb(
)

nwbfile = NWBFile(**nwb_file_kwargs)

nwbfile = append_labels_data_to_nwb(labels, nwbfile, pose_estimation_metadata)

with NWBHDF5IO(str(nwbfile_path), "w") as io:
Expand All @@ -145,12 +159,33 @@ def append_labels_data_to_nwb(
labels (Labels): A general labels object
nwbfile (NWBFile): And in-memory nwbfile where the data is to be appended.

pose_estimation_metadata (dict): This argument has a dual purpose:

1) It can be used to pass time information about the video which is
necessary for synchronizing frames in pose estimation tracking to other
modalities. Either the video timestamps can be passed with the key
`video_timestamps`, or the sampling rate can be passed with the key
`video_sample_rate`.

e.g. pose_estimation_metadata["video_timestamps"] = np.array(timestamps)
or pose_estimation_metadata["video_sample_rate"] = 15  # In Hz

2) The other use of this dictionary is to overwrite sleap-io default
arguments for the PoseEstimation container.
See https://github.com/rly/ndx-pose for a full list of arguments.

Returns:
NWBFile: An in-memory nwbfile with the data from the labels object appended.
"""

pose_estimation_metadata = pose_estimation_metadata or dict()

# Extract default metadata
provenance = labels.provenance
default_metadata = dict(scorer=str(provenance))
sleap_version = provenance.get("sleap_version", None)
default_metadata["source_software_version"] = sleap_version

labels_data_df = _extract_predicted_instances_data(labels)

# For every video create a processing module
Expand All @@ -162,20 +197,27 @@ def append_labels_data_to_nwb(
processing_module_name, nwbfile
)

# Propagate video metadata
default_metadata["original_videos"] = [f"{video.filename}"] # type: ignore
default_metadata["labeled_videos"] = [f"{video.filename}"] # type: ignore

Copy link
Contributor Author

@h-mayorquin h-mayorquin Sep 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was getting error sleap_io/io/nwb.py:201: error: Incompatible types in assignment (expression has type "List[str]", target has type "str") in lines 200 and 201. Is there a more elegant way of avoiding this? I guess I have to define the dict above in a specific way but I expect a heterogeneous input for each of the keys.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugh, yeah, the typing for complicated dict value types is a bit hairy. Technically I suppose this is intended to discourage the use of dicts as generic grab bags of different data structures that aren't documented explicitly, but it sacrifices a lot of flexibility.

Ignoring is fine for now. In my view, typing is a "best effort" practice and should be ignored if it makes the code harder to read or iterate on.

# Overwrite default with the user provided metadata
default_metadata.update(pose_estimation_metadata)

# For every track in that video create a PoseEstimation container
name_of_tracks_in_video = (
labels_data_df[video.filename]
.columns.get_level_values("track_name")
.unique()
)

# For every track in that video create a PoseEstimation container
for track_index, track_name in enumerate(name_of_tracks_in_video):
pose_estimation_container = build_pose_estimation_container_for_track(
labels_data_df,
labels,
track_name,
video,
pose_estimation_metadata,
default_metadata,
)
nwb_processing_module.add(pose_estimation_container)

Expand Down Expand Up @@ -260,19 +302,14 @@ def build_pose_estimation_container_for_track(
)

# Arrange and mix metadata

pose_estimation_container_kwargs = dict(
name=f"track={track_name}",
description=f"Estimated positions of {skeleton.name} in video {video_path.name}",
pose_estimation_series=pose_estimation_series_list,
nodes=skeleton.node_names,
edges=np.array(skeleton.edge_inds).astype("uint64"),
source_software="SLEAP",
original_videos=[f"{video.filename}"],
labeled_videos=[f"{video.filename}"],
# dimensions=np.array([[video.backend.height, video.backend.width]]),
# scorer=str(labels.provenance),
# source_software_version=f"{sleap.__version__}
)

pose_estimation_container_kwargs.update(**pose_estimation_metadata)
Expand Down Expand Up @@ -302,28 +339,32 @@ def build_track_pose_estimation_list(

pose_estimation_series_list: List[PoseEstimationSeries] = []
for node_name in name_of_nodes_in_track:
# Add predicted instances only

data_for_node = track_data_df[
node_name,
]
# Drop data with missing values
data_for_node = track_data_df[node_name].dropna(axis="index", how="any")

data_for_node_cleaned = data_for_node.dropna(axis="index", how="any")
node_trajectory = data_for_node_cleaned[["x", "y"]].to_numpy()
confidence = data_for_node_cleaned["score"].to_numpy()
node_trajectory = data_for_node[["x", "y"]].to_numpy()
confidence = data_for_node["score"].to_numpy()

reference_frame = (
"The coordinates are in (x, y) relative to the top-left of the image. "
"Coordinates refer to the midpoint of the pixel. "
"That is, the midpoint of the top-left pixel is at (0, 0), whereas "
"the top-left corner of that same pixel is at (-0.5, -0.5)."
)

pose_estimation_kwargs = dict(
name=f"{node_name}",
description=f"Sequential trajectory of {node_name}.",
data=node_trajectory,
unit="pixels",
reference_frame="No reference.",
reference_frame=reference_frame,
confidence=confidence,
confidence_definition="Point-wise confidence scores.",
)

# Add timestamps or rate if timestamps are uniform
frames = data_for_node_cleaned.index.values
# Add timestamps or only rate if the timestamps are uniform
frames = data_for_node.index.values
timestamps_for_data = timestamps[frames]
sample_periods = np.diff(timestamps_for_data)
if sample_periods.size == 0:
Expand All @@ -334,6 +375,8 @@ def build_track_pose_estimation_list(
rate = 1 / sample_periods[0] if uniform_samples else None

if rate:
# Video sample rates are ints but nwb expect floats
rate = float(int(rate))
pose_estimation_kwargs.update(rate=rate)
else:
pose_estimation_kwargs.update(timestamps=timestamps_for_data)
Expand Down
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/fixtures/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ def slp_predictions():
@pytest.fixture
def slp_predictions_with_provenance():
"""The slp file generated with the collab tutorial and sleap version 1.2.7"""
return "tests/data/slp/tutorial_predictions_version_1.2.7_with_provenance.slp"
return "tests/data/slp/predictions_1.2.7_provenance_and_tracking.slp"
33 changes: 33 additions & 0 deletions tests/io/test_nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,39 @@ def test_typical_case_append_with_metadata_propagation(nwbfile, slp_typical):
assert extracted_dimensions == expected_dimensions


def test_provenance_writing(nwbfile, slp_predictions_with_provenance):
    """Provenance from the labels should land on every PoseEstimation container."""
    labels = load_slp(slp_predictions_with_provenance)
    nwbfile = append_labels_data_to_nwb(labels, nwbfile)

    # Locate the processing module that was created for the first video.
    first_video = labels.videos[0]
    stem = Path(first_video.filename).stem
    module = nwbfile.processing[f"SLEAP_VIDEO_{0:03}_{stem}"]

    # Every container in the module carries the stringified provenance as scorer.
    expected_scorer = str(labels.provenance)
    assert all(
        container.scorer == expected_scorer
        for container in module.data_interfaces.values()
    )


def test_default_metadata_overwriting(nwbfile, slp_predictions_with_provenance):
    """User-supplied pose_estimation_metadata should override sleap-io defaults."""
    labels = load_slp(slp_predictions_with_provenance)
    metadata_override = {"scorer": "overwritten_value"}
    nwbfile = append_labels_data_to_nwb(labels, nwbfile, metadata_override)

    # Locate the processing module that was created for the first video.
    first_video = labels.videos[0]
    stem = Path(first_video.filename).stem
    module = nwbfile.processing[f"SLEAP_VIDEO_{0:03}_{stem}"]

    # The override should have replaced the default scorer on every container.
    assert all(
        container.scorer == "overwritten_value"
        for container in module.data_interfaces.values()
    )


def test_complex_case_append(nwbfile, slp_predictions):
labels = load_slp(slp_predictions)
nwbfile = append_labels_data_to_nwb(labels, nwbfile)
Expand Down