diff --git a/sleap_io/io/nwb.py b/sleap_io/io/nwb.py
index cddfd848..30768535 100644
--- a/sleap_io/io/nwb.py
+++ b/sleap_io/io/nwb.py
@@ -109,6 +109,21 @@ def write_labels_to_nwb(
            https://pynwb.readthedocs.io/en/stable/pynwb.file.html#pynwb.file.NWBFile
            Defaults to None and default values are used to generate the nwb file.
+
+        pose_estimation_metadata (dict): This argument has a dual purpose:
+
+            1) It can be used to pass time information about the video, which is
+            necessary for synchronizing frames in pose estimation tracking to other
+            modalities. Either the video timestamps can be passed with the key
+            `video_timestamps`, or the sampling rate can be passed with the key
+            `video_sample_rate`.
+
+            e.g. pose_estimation_metadata["video_timestamps"] = np.array(timestamps)
+            or pose_estimation_metadata["video_sample_rate"] = 15  # in Hz
+
+            2) The other use of this dictionary is to overwrite sleap-io default
+            arguments for the PoseEstimation container.
+            See https://github.com/rly/ndx-pose for a full list of arguments.
    """

    nwb_file_kwargs = nwb_file_kwargs or dict()
@@ -129,7 +144,6 @@
    )

    nwbfile = NWBFile(**nwb_file_kwargs)
-
    nwbfile = append_labels_data_to_nwb(labels, nwbfile, pose_estimation_metadata)

    with NWBHDF5IO(str(nwbfile_path), "w") as io:
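As context for the docstring above, a minimal usage sketch of the two documented metadata keys. The input file name and the timestamp values are assumptions for illustration; only the keys and the `write_labels_to_nwb` call come from this patch:

```python
import numpy as np
from sleap_io import load_slp
from sleap_io.io.nwb import write_labels_to_nwb

labels = load_slp("predictions.slp")  # hypothetical input file

# Option 1: pass explicit per-frame timestamps for synchronization with
# other modalities (here, 100 frames at a nominal 15 Hz).
pose_estimation_metadata = {"video_timestamps": np.arange(100) / 15.0}

# Option 2: pass only the (uniform) sampling rate in Hz instead:
# pose_estimation_metadata = {"video_sample_rate": 15}

write_labels_to_nwb(
    labels,
    "predictions.nwb",
    pose_estimation_metadata=pose_estimation_metadata,
)
```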
""" pose_estimation_metadata = pose_estimation_metadata or dict() + # Extract default metadata + provenance = labels.provenance + default_metadata = dict(scorer=str(provenance)) + sleap_version = provenance.get("sleap_version", None) + default_metadata["source_software_version"] = sleap_version + labels_data_df = _extract_predicted_instances_data(labels) # For every video create a processing module @@ -162,20 +197,27 @@ def append_labels_data_to_nwb( processing_module_name, nwbfile ) + # Propagate video metadata + default_metadata["original_videos"] = [f"{video.filename}"] # type: ignore + default_metadata["labeled_videos"] = [f"{video.filename}"] # type: ignore + + # Overwrite default with the user provided metadata + default_metadata.update(pose_estimation_metadata) + + # For every track in that video create a PoseEstimation container name_of_tracks_in_video = ( labels_data_df[video.filename] .columns.get_level_values("track_name") .unique() ) - # For every track in that video create a PoseEstimation container for track_index, track_name in enumerate(name_of_tracks_in_video): pose_estimation_container = build_pose_estimation_container_for_track( labels_data_df, labels, track_name, video, - pose_estimation_metadata, + default_metadata, ) nwb_processing_module.add(pose_estimation_container) @@ -260,7 +302,6 @@ def build_pose_estimation_container_for_track( ) # Arrange and mix metadata - pose_estimation_container_kwargs = dict( name=f"track={track_name}", description=f"Estimated positions of {skeleton.name} in video {video_path.name}", @@ -268,11 +309,7 @@ def build_pose_estimation_container_for_track( nodes=skeleton.node_names, edges=np.array(skeleton.edge_inds).astype("uint64"), source_software="SLEAP", - original_videos=[f"{video.filename}"], - labeled_videos=[f"{video.filename}"], # dimensions=np.array([[video.backend.height, video.backend.width]]), - # scorer=str(labels.provenance), - # source_software_version=f"{sleap.__version__} ) pose_estimation_container_kwargs.update(**pose_estimation_metadata) @@ -302,28 +339,32 @@ def build_track_pose_estimation_list( pose_estimation_series_list: List[PoseEstimationSeries] = [] for node_name in name_of_nodes_in_track: - # Add predicted instances only - data_for_node = track_data_df[ - node_name, - ] + # Drop data with missing values + data_for_node = track_data_df[node_name].dropna(axis="index", how="any") - data_for_node_cleaned = data_for_node.dropna(axis="index", how="any") - node_trajectory = data_for_node_cleaned[["x", "y"]].to_numpy() - confidence = data_for_node_cleaned["score"].to_numpy() + node_trajectory = data_for_node[["x", "y"]].to_numpy() + confidence = data_for_node["score"].to_numpy() + + reference_frame = ( + "The coordinates are in (x, y) relative to the top-left of the image. " + "Coordinates refer to the midpoint of the pixel. " + "That is, t the midpoint of the top-left pixel is at (0, 0), whereas " + "the top-left corner of that same pixel is at (-0.5, -0.5)." 
@@ -302,28 +339,32 @@
    pose_estimation_series_list: List[PoseEstimationSeries] = []
    for node_name in name_of_nodes_in_track:
-        # Add predicted instances only
-        data_for_node = track_data_df[
-            node_name,
-        ]
+        # Drop data with missing values
+        data_for_node = track_data_df[node_name].dropna(axis="index", how="any")

-        data_for_node_cleaned = data_for_node.dropna(axis="index", how="any")
-        node_trajectory = data_for_node_cleaned[["x", "y"]].to_numpy()
-        confidence = data_for_node_cleaned["score"].to_numpy()
+        node_trajectory = data_for_node[["x", "y"]].to_numpy()
+        confidence = data_for_node["score"].to_numpy()
+
+        reference_frame = (
+            "The coordinates are in (x, y) relative to the top-left of the image. "
+            "Coordinates refer to the midpoint of the pixel. "
+            "That is, the midpoint of the top-left pixel is at (0, 0), whereas "
+            "the top-left corner of that same pixel is at (-0.5, -0.5)."
+        )

        pose_estimation_kwargs = dict(
            name=f"{node_name}",
            description=f"Sequential trajectory of {node_name}.",
            data=node_trajectory,
            unit="pixels",
-            reference_frame="No reference.",
+            reference_frame=reference_frame,
            confidence=confidence,
            confidence_definition="Point-wise confidence scores.",
        )

-        # Add timestamps or rate if timestamps are uniform
-        frames = data_for_node_cleaned.index.values
+        # Add timestamps, or only the rate if the timestamps are uniform
+        frames = data_for_node.index.values
        timestamps_for_data = timestamps[frames]
        sample_periods = np.diff(timestamps_for_data)
        if sample_periods.size == 0:
@@ -334,6 +375,8 @@
        rate = 1 / sample_periods[0] if uniform_samples else None
        if rate:
+            # Video sample rates are ints but NWB expects floats
+            rate = float(int(rate))
            pose_estimation_kwargs.update(rate=rate)
        else:
            pose_estimation_kwargs.update(timestamps=timestamps_for_data)
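To illustrate the rate-versus-timestamps branch above: uniformly spaced timestamps collapse to a single `rate`, while irregular ones are stored explicitly. A small sketch; the computation of `uniform_samples` lies outside the shown hunks, so the `np.unique` check here is an assumption:

```python
import numpy as np

def rate_or_timestamps(timestamps_for_data):
    """Mimic the branch above: return a float rate if sampling is uniform,
    otherwise the explicit timestamps (uniformity check is an assumption)."""
    sample_periods = np.diff(timestamps_for_data)
    uniform_samples = np.unique(sample_periods).size == 1
    rate = 1 / sample_periods[0] if uniform_samples else None
    if rate:
        return float(int(rate))  # NWB expects a float rate
    return timestamps_for_data

# Uniform sampling at 2 Hz collapses to a single rate...
assert rate_or_timestamps(np.array([0.0, 0.5, 1.0, 1.5])) == 2.0

# ...while irregular sampling keeps the explicit timestamps.
irregular = np.array([0.0, 0.5, 1.2, 1.5])
assert rate_or_timestamps(irregular) is irregular
```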
diff --git a/tests/data/slp/tutorial_predictions_version_1.2.7_with_provenance.slp b/tests/data/slp/predictions_1.2.7_provenance_and_tracking.slp
similarity index 51%
rename from tests/data/slp/tutorial_predictions_version_1.2.7_with_provenance.slp
rename to tests/data/slp/predictions_1.2.7_provenance_and_tracking.slp
index 611015e7..00c31bd2 100644
Binary files a/tests/data/slp/tutorial_predictions_version_1.2.7_with_provenance.slp and b/tests/data/slp/predictions_1.2.7_provenance_and_tracking.slp differ
diff --git a/tests/fixtures/slp.py b/tests/fixtures/slp.py
index d3088f9c..3774a06f 100644
--- a/tests/fixtures/slp.py
+++ b/tests/fixtures/slp.py
@@ -29,4 +29,4 @@ def slp_predictions():
 @pytest.fixture
 def slp_predictions_with_provenance():
     """The slp file generated with the Colab tutorial and SLEAP version 1.2.7"""
-    return "tests/data/slp/tutorial_predictions_version_1.2.7_with_provenance.slp"
+    return "tests/data/slp/predictions_1.2.7_provenance_and_tracking.slp"
diff --git a/tests/io/test_nwb.py b/tests/io/test_nwb.py
index ebc13f62..a3f9e119 100644
--- a/tests/io/test_nwb.py
+++ b/tests/io/test_nwb.py
@@ -87,6 +87,39 @@ def test_typical_case_append_with_metadata_propagation(nwbfile, slp_typical):
     assert extracted_dimensions == expected_dimensions


+def test_provenance_writing(nwbfile, slp_predictions_with_provenance):
+    labels = load_slp(slp_predictions_with_provenance)
+    nwbfile = append_labels_data_to_nwb(labels, nwbfile)
+
+    # Extract processing module
+    video_index = 0
+    video = labels.videos[video_index]
+    video_path = Path(video.filename)
+    processing_module_name = f"SLEAP_VIDEO_{video_index:03}_{video_path.stem}"
+    processing_module = nwbfile.processing[processing_module_name]
+
+    # Test that the provenance information is propagated
+    for pose_estimation_container in processing_module.data_interfaces.values():
+        assert pose_estimation_container.scorer == str(labels.provenance)
+
+
+def test_default_metadata_overwriting(nwbfile, slp_predictions_with_provenance):
+    labels = load_slp(slp_predictions_with_provenance)
+    pose_estimation_metadata = {"scorer": "overwritten_value"}
+    nwbfile = append_labels_data_to_nwb(labels, nwbfile, pose_estimation_metadata)
+
+    # Extract processing module
+    video_index = 0
+    video = labels.videos[video_index]
+    video_path = Path(video.filename)
+    processing_module_name = f"SLEAP_VIDEO_{video_index:03}_{video_path.stem}"
+    processing_module = nwbfile.processing[processing_module_name]
+
+    # Test that the value of scorer was overwritten
+    for pose_estimation_container in processing_module.data_interfaces.values():
+        assert pose_estimation_container.scorer == "overwritten_value"
+
+
 def test_complex_case_append(nwbfile, slp_predictions):
     labels = load_slp(slp_predictions)
     nwbfile = append_labels_data_to_nwb(labels, nwbfile)
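Beyond the in-memory assertions above, a hypothetical end-to-end check could write to disk and read back with pynwb. This is only a sketch, not part of the test suite; it reuses the renamed fixture file and the functions shown in this patch, and assumes the container attributes follow ndx-pose:

```python
from pynwb import NWBHDF5IO
from sleap_io import load_slp
from sleap_io.io.nwb import write_labels_to_nwb

labels = load_slp("tests/data/slp/predictions_1.2.7_provenance_and_tracking.slp")
write_labels_to_nwb(
    labels,
    "roundtrip.nwb",
    pose_estimation_metadata={"scorer": "overwritten_value"},
)

with NWBHDF5IO("roundtrip.nwb", "r") as io:
    nwbfile = io.read()
    for processing_module in nwbfile.processing.values():
        for container in processing_module.data_interfaces.values():
            # The user-supplied value should survive the disk round trip.
            assert container.scorer == "overwritten_value"
```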