Add NWB support III - provenance writing (#17)
* added tests for provenance writing

* MyPy compliance
h-mayorquin authored Sep 17, 2022
1 parent 0a153dc commit 7f3dd13
Showing 4 changed files with 95 additions and 19 deletions.
79 changes: 61 additions & 18 deletions sleap_io/io/nwb.py
@@ -109,6 +109,21 @@ def write_labels_to_nwb(
https://pynwb.readthedocs.io/en/stable/pynwb.file.html#pynwb.file.NWBFile
Defaults to None, in which case default values are used to generate the NWB file.
pose_estimation_metadata (dict): This argument has a dual purpose:
1) It can be used to pass time information about the video, which is
necessary for synchronizing frames in pose estimation tracking to other
modalities. The timestamps can be passed with the key `video_timestamps`
and the sampling rate with the key `video_sample_rate`, e.g.:
pose_estimation_metadata["video_timestamps"] = np.array(timestamps)
pose_estimation_metadata["video_sample_rate"] = 15  # In Hz
2) It can also be used to overwrite the sleap-io default arguments
for the PoseEstimation container.
See https://github.com/rly/ndx-pose for a full list of arguments.
"""

nwb_file_kwargs = nwb_file_kwargs or dict()
@@ -129,7 +144,6 @@
)

nwbfile = NWBFile(**nwb_file_kwargs)

nwbfile = append_labels_data_to_nwb(labels, nwbfile, pose_estimation_metadata)

with NWBHDF5IO(str(nwbfile_path), "w") as io:
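
For orientation, a minimal usage sketch of write_labels_to_nwb as defined above (the .slp/.nwb paths, the 15 Hz rate, and the "my_lab" scorer are illustrative assumptions, not values from this commit):

    from sleap_io import load_slp
    from sleap_io.io.nwb import write_labels_to_nwb

    labels = load_slp("predictions.slp")  # hypothetical input file

    # Timing keys are consumed for synchronization; any other key overrides
    # the matching PoseEstimation argument (see the docstring above).
    pose_estimation_metadata = {
        "video_sample_rate": 15.0,  # Hz
        "scorer": "my_lab",  # overrides the sleap-io default scorer
    }

    write_labels_to_nwb(
        nwbfile_path="predictions.nwb",
        labels=labels,
        pose_estimation_metadata=pose_estimation_metadata,
    )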
@@ -145,12 +159,33 @@ def append_labels_data_to_nwb(
labels (Labels): A general labels object.
nwbfile (NWBFile): An in-memory NWBFile where the data is to be appended.
pose_estimation_metadata (dict): This argument has a dual purpose:
1) It can be used to pass time information about the video, which is
necessary for synchronizing frames in pose estimation tracking to other
modalities. The timestamps can be passed with the key `video_timestamps`
and the sampling rate with the key `video_sample_rate`, e.g.:
pose_estimation_metadata["video_timestamps"] = np.array(timestamps)
pose_estimation_metadata["video_sample_rate"] = 15  # In Hz
2) It can also be used to overwrite the sleap-io default arguments
for the PoseEstimation container.
See https://github.com/rly/ndx-pose for a full list of arguments.
Returns:
NWBFile: An in-memory nwbfile with the data from the labels object appended.
"""

pose_estimation_metadata = pose_estimation_metadata or dict()

# Extract default metadata
provenance = labels.provenance
default_metadata = dict(scorer=str(provenance))
sleap_version = provenance.get("sleap_version", None)
default_metadata["source_software_version"] = sleap_version

labels_data_df = _extract_predicted_instances_data(labels)

# For every video create a processing module
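
For orientation, labels_data_df extracted above is frame-indexed with a column MultiIndex that is sliced per video and per track below. A small sketch of that shape (level names other than "track_name", and all values, are assumptions for illustration):

    import numpy as np
    import pandas as pd

    # Hypothetical columns: (video_path, track_name, node_name, value).
    columns = pd.MultiIndex.from_tuples(
        [
            ("video.mp4", "track_0", "head", "x"),
            ("video.mp4", "track_0", "head", "y"),
            ("video.mp4", "track_0", "head", "score"),
        ],
        names=["video_path", "track_name", "node_name", "value"],  # assumed names
    )
    labels_data_df = pd.DataFrame(np.zeros((3, 3)), columns=columns)

    # Mirrors the per-video, per-track slicing used below.
    tracks = labels_data_df["video.mp4"].columns.get_level_values("track_name").unique()
    print(list(tracks))  # ['track_0']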
@@ -162,20 +197,27 @@
processing_module_name, nwbfile
)

# Propagate video metadata
default_metadata["original_videos"] = [f"{video.filename}"] # type: ignore
default_metadata["labeled_videos"] = [f"{video.filename}"] # type: ignore

# Overwrite default with the user provided metadata
default_metadata.update(pose_estimation_metadata)

# For every track in that video create a PoseEstimation container
name_of_tracks_in_video = (
labels_data_df[video.filename]
.columns.get_level_values("track_name")
.unique()
)

for track_index, track_name in enumerate(name_of_tracks_in_video):
pose_estimation_container = build_pose_estimation_container_for_track(
labels_data_df,
labels,
track_name,
video,
-pose_estimation_metadata,
+default_metadata,
)
nwb_processing_module.add(pose_estimation_container)
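
The default_metadata.update(pose_estimation_metadata) call above encodes a simple precedence rule: defaults are filled in first, then user-provided keys win on collision. A minimal sketch with made-up values:

    default_metadata = {
        "scorer": "sleap-io default",
        "source_software_version": "1.2.7",
    }
    pose_estimation_metadata = {"scorer": "overwritten_value"}

    # Overwrite defaults with the user-provided metadata.
    default_metadata.update(pose_estimation_metadata)

    assert default_metadata["scorer"] == "overwritten_value"
    assert default_metadata["source_software_version"] == "1.2.7"  # default kept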

@@ -260,19 +302,14 @@ def build_pose_estimation_container_for_track(
)

# Arrange and mix metadata

pose_estimation_container_kwargs = dict(
name=f"track={track_name}",
description=f"Estimated positions of {skeleton.name} in video {video_path.name}",
pose_estimation_series=pose_estimation_series_list,
nodes=skeleton.node_names,
edges=np.array(skeleton.edge_inds).astype("uint64"),
source_software="SLEAP",
-original_videos=[f"{video.filename}"],
-labeled_videos=[f"{video.filename}"],
# dimensions=np.array([[video.backend.height, video.backend.width]]),
-# scorer=str(labels.provenance),
-# source_software_version=f"{sleap.__version__}",
)

pose_estimation_container_kwargs.update(**pose_estimation_metadata)
@@ -302,28 +339,32 @@ def build_track_pose_estimation_list(

pose_estimation_series_list: List[PoseEstimationSeries] = []
for node_name in name_of_nodes_in_track:
# Add predicted instances only

-data_for_node = track_data_df[
-    node_name,
-]
+# Drop data with missing values
+data_for_node = track_data_df[node_name].dropna(axis="index", how="any")

-data_for_node_cleaned = data_for_node.dropna(axis="index", how="any")
-node_trajectory = data_for_node_cleaned[["x", "y"]].to_numpy()
-confidence = data_for_node_cleaned["score"].to_numpy()
+node_trajectory = data_for_node[["x", "y"]].to_numpy()
+confidence = data_for_node["score"].to_numpy()

reference_frame = (
"The coordinates are in (x, y) relative to the top-left of the image. "
"Coordinates refer to the midpoint of the pixel. "
"That is, t the midpoint of the top-left pixel is at (0, 0), whereas "
"the top-left corner of that same pixel is at (-0.5, -0.5)."
)

pose_estimation_kwargs = dict(
name=f"{node_name}",
description=f"Sequential trajectory of {node_name}.",
data=node_trajectory,
unit="pixels",
reference_frame="No reference.",
reference_frame=reference_frame,
confidence=confidence,
confidence_definition="Point-wise confidence scores.",
)

-# Add timestamps or rate if timestamps are uniform
-frames = data_for_node_cleaned.index.values
+# Add timestamps or only rate if the timestamps are uniform
+frames = data_for_node.index.values
timestamps_for_data = timestamps[frames]
sample_periods = np.diff(timestamps_for_data)
if sample_periods.size == 0:
@@ -334,6 +375,8 @@
rate = 1 / sample_periods[0] if uniform_samples else None

if rate:
# Video sample rates are ints but NWB expects floats
rate = float(int(rate))
pose_estimation_kwargs.update(rate=rate)
else:
pose_estimation_kwargs.update(timestamps=timestamps_for_data)
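
The rate-versus-timestamps branch above can be exercised in isolation; a sketch with made-up, exactly representable timestamps (uniform 2 Hz):

    import numpy as np

    timestamps_for_data = np.array([0.0, 0.5, 1.0, 1.5])  # made-up, uniform 2 Hz

    sample_periods = np.diff(timestamps_for_data)
    uniform_samples = np.unique(sample_periods).size == 1

    if uniform_samples:
        # NWB expects a float rate, not an int.
        rate = float(int(1 / sample_periods[0]))
        print(f"store rate={rate} Hz")  # store rate=2.0 Hz
    else:
        print("store explicit per-frame timestamps")

Note that float round-off can make visually uniform timestamps compare unequal under np.unique, in which case the code falls back to storing explicit timestamps.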
Binary file tests/data/slp/predictions_1.2.7_provenance_and_tracking.slp not shown.
2 changes: 1 addition & 1 deletion tests/fixtures/slp.py
@@ -29,4 +29,4 @@ def slp_predictions():
@pytest.fixture
def slp_predictions_with_provenance():
"""The slp file generated with the collab tutorial and sleap version 1.27"""
return "tests/data/slp/tutorial_predictions_version_1.2.7_with_provenance.slp"
return "tests/data/slp/predictions_1.2.7_provenance_and_tracking.slp"
33 changes: 33 additions & 0 deletions tests/io/test_nwb.py
@@ -87,6 +87,39 @@ def test_typical_case_append_with_metadata_propagation(nwbfile, slp_typical):
assert extracted_dimensions == expected_dimensions


def test_provenance_writing(nwbfile, slp_predictions_with_provenance):
labels = load_slp(slp_predictions_with_provenance)
nwbfile = append_labels_data_to_nwb(labels, nwbfile)

# Extract processing module
video_index = 0
video = labels.videos[video_index]
video_path = Path(video.filename)
processing_module_name = f"SLEAP_VIDEO_{video_index:03}_{video_path.stem}"
processing_module = nwbfile.processing[processing_module_name]

# Test that the provenance information is propagated
for pose_estimation_container in processing_module.data_interfaces.values():
assert pose_estimation_container.scorer == str(labels.provenance)


def test_default_metadata_overwriting(nwbfile, slp_predictions_with_provenance):
labels = load_slp(slp_predictions_with_provenance)
pose_estimation_metadata = {"scorer": "overwritten_value"}
nwbfile = append_labels_data_to_nwb(labels, nwbfile, pose_estimation_metadata)

# Extract processing module
video_index = 0
video = labels.videos[video_index]
video_path = Path(video.filename)
processing_module_name = f"SLEAP_VIDEO_{video_index:03}_{video_path.stem}"
processing_module = nwbfile.processing[processing_module_name]

# Test that the value of scorer was overwritten
for pose_estimation_container in processing_module.data_interfaces.values():
assert pose_estimation_container.scorer == "overwritten_value"


def test_complex_case_append(nwbfile, slp_predictions):
labels = load_slp(slp_predictions)
nwbfile = append_labels_data_to_nwb(labels, nwbfile)
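
The tests above assert against the in-memory NWBFile; a sketch of a disk round-trip check under the same assumptions (the fixture path is the one added in this commit, the output path is illustrative):

    from pynwb import NWBHDF5IO
    from sleap_io import load_slp
    from sleap_io.io.nwb import write_labels_to_nwb

    labels = load_slp("tests/data/slp/predictions_1.2.7_provenance_and_tracking.slp")
    write_labels_to_nwb(nwbfile_path="roundtrip.nwb", labels=labels)

    with NWBHDF5IO("roundtrip.nwb", "r") as io:
        nwbfile = io.read()
        # Provenance should survive the write/read cycle in every container.
        for processing_module in nwbfile.processing.values():
            for container in processing_module.data_interfaces.values():
                assert container.scorer == str(labels.provenance)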
