Skip to content

Commit

Permalink
Bugfix/remove components crash (and tests for pose.remove_components(…
Browse files Browse the repository at this point in the history
…)) (#150)

* CDL: minor doc typo fix

* Undoing some changes that got mixed in

* Fix remove_components crash #149

* Add test cases for pose.remove_components, update random pose object so header matches body

* Another quick test case
  • Loading branch information
cleong110 authored Feb 13, 2025
1 parent e22d323 commit e07ca68
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 20 deletions.
10 changes: 5 additions & 5 deletions src/python/pose_format/pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,11 @@ def remove_components(self, components_to_remove: Union[str, List[str]], points_
for component in self.header.components:
if component.name not in components_to_remove:
components_to_keep.append(component.name)
points_dict[component.name] = []
if points_to_remove is not None:
for point in component.points:
if point not in points_to_remove[component.name]:
points_dict[component.name].append(point)
if points_to_remove:
points_to_remove_list = points_to_remove.get(component.name, [])
points_dict[component.name] = [point for point in component.points if point not in points_to_remove_list]
else:
points_dict[component.name] = component.points[:]

return self.get_components(components_to_keep, points_dict)

Expand Down
36 changes: 29 additions & 7 deletions src/python/pose_format/utils/generic_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from typing import List, get_args
import numpy as np
import pytest
Expand Down Expand Up @@ -154,7 +155,34 @@ def test_correct_wrists(fake_poses: List[Pose]):
assert corrected_pose != pose
assert np.array_equal(corrected_pose.body.data, pose.body.data) is False


@pytest.mark.parametrize("fake_poses", ["holistic"], indirect=["fake_poses"])
def test_remove_one_point_and_one_component(fake_poses: List[Pose]):
component_to_drop = "POSE_WORLD_LANDMARKS"
point_to_drop = "LEFT_KNEE"
for pose in fake_poses:
original_component_names = []
original_points_dict = defaultdict(list)
for component in pose.header.components:
original_component_names.append(component.name)

for point in component.points:
original_points_dict[component.name].append(point)

assert component_to_drop in original_component_names
assert point_to_drop in original_points_dict["POSE_LANDMARKS"]
reduced_pose = pose.remove_components(component_to_drop, {"POSE_LANDMARKS": [point_to_drop]})
new_component_names, new_points_dict = [], defaultdict(list)
new_component_names = []
new_points_dict = defaultdict(list)
for component in reduced_pose.header.components:
new_component_names.append(component.name)

for point in component.points:
new_points_dict[component.name].append(point)


assert component_to_drop not in new_component_names
assert point_to_drop not in new_points_dict["POSE_LANDMARKS"]


@pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"])
Expand Down Expand Up @@ -205,9 +233,3 @@ def test_fake_pose(known_pose_format: KnownPoseFormat):
assert pose.header.num_dims() == pose.body.data.shape[-1]

poses = [fake_pose(25) for _ in range(5)]






131 changes: 123 additions & 8 deletions src/python/tests/pose_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@ def _create_pose_header_component(name: str, num_keypoints: int) -> PoseHeaderCo
return component


def _distribute_points_among_components(component_count: int, total_keypoint_count: int):
if component_count <= 0 or total_keypoint_count < component_count + 1:
raise ValueError("Total keypoints must be at least component count+1 (so that 0 can have two), and component count must be positive")

# Step 1: Initialize with required minimum values
keypoint_counts = [2] + [1] * (component_count - 1) # Ensure first is 2, others at least 1

# Step 2: Distribute remaining points
remaining_points = total_keypoint_count - sum(keypoint_counts)
for _ in range(remaining_points):
keypoint_counts[random.randint(0, component_count - 1)] += 1 # Add randomly

return keypoint_counts

def _create_pose_header(width: int, height: int, depth: int, num_components: int, num_keypoints: int) -> PoseHeader:
"""
Create a PoseHeader with given dimensions and components.
Expand All @@ -79,8 +93,10 @@ def _create_pose_header(width: int, height: int, depth: int, num_components: int
"""
dimensions = PoseHeaderDimensions(width=width, height=height, depth=depth)

keypoints_per_component = _distribute_points_among_components(num_components, num_keypoints)

components = [
_create_pose_header_component(name=str(index), num_keypoints=num_keypoints) for index in range(num_components)
_create_pose_header_component(name=str(index), num_keypoints=keypoints_per_component[index]) for index in range(num_components)
]

header = PoseHeader(version=1.0, dimensions=dimensions, components=components)
Expand Down Expand Up @@ -134,6 +150,8 @@ def _create_random_tensorflow_data(frames_min: Optional[int] = None,
return tensor, mask, confidence




def _create_random_numpy_data(frames_min: Optional[int] = None,
frames_max: Optional[int] = None,
num_frames: Optional[int] = None,
Expand Down Expand Up @@ -286,7 +304,7 @@ def _get_random_pose_object_with_tf_posebody(num_keypoints: int, frames_min: int
return Pose(header=header, body=body)


def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min: int = 1, frames_max: int = 10) -> Pose:
def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min: int = 1, frames_max: int = 10, num_components=3) -> Pose:
"""
Creates a random Pose object with Numpy pose body for testing.
Expand All @@ -313,7 +331,7 @@ def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min:

body = NumPyPoseBody(fps=10, data=masked_array, confidence=confidence)

header = _create_pose_header(width=10, height=7, depth=0, num_components=3, num_keypoints=num_keypoints)
header = _create_pose_header(width=10, height=7, depth=0, num_components=num_components, num_keypoints=num_keypoints)

return Pose(header=header, body=body)

Expand All @@ -329,6 +347,96 @@ def test_pose_object_should_be_callable(self):
"""
assert callable(Pose)

def test_pose_remove_components(self):
pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5)
assert pose.body.data.shape[-2] == 5
assert pose.body.data.shape[-1] == 2 # XY dimensions

self.assertEqual(len(pose.header.components), 3)
self.assertEqual(sum(len(c.points) for c in pose.header.components), 5)
self.assertEqual(pose.header.components[0].name, "0")
self.assertEqual(pose.header.components[1].name, "1")
self.assertEqual(pose.header.components[0].points[0], "0_a")
self.assertIn("1_a", pose.header.components[1].points)
self.assertNotIn("1_f", pose.header.components[1].points)
self.assertNotIn("4", pose.header.components)

# test that we can remove a component
component_to_remove = "0"
pose_copy = pose.copy()
self.assertIn(component_to_remove, [c.name for c in pose_copy.header.components])
pose_copy = pose_copy.remove_components(component_to_remove)
self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components])


# Remove a point only
point_to_remove = "0_a"
pose_copy = pose.copy()
self.assertIn(point_to_remove, pose_copy.header.components[0].points)
pose_copy = pose_copy.remove_components([], {point_to_remove[0]:[point_to_remove]})
self.assertNotIn(point_to_remove, pose_copy.header.components[0].points)


# Can we remove two things at once
component_to_remove = "1"
point_to_remove = "2_a"
component_to_remove_point_from = "2"

self.assertIn(component_to_remove, [c.name for c in pose_copy.header.components])
self.assertIn(component_to_remove_point_from, [c.name for c in pose_copy.header.components])
self.assertIn(point_to_remove, pose_copy.header.components[2].points)
pose_copy = pose_copy.remove_components([component_to_remove], {component_to_remove_point_from:[point_to_remove]})
self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components])
self.assertIn(component_to_remove_point_from, [c.name for c in pose_copy.header.components]) # this should still be around

# can we remove a component and a point FROM that component without crashing
component_to_remove = "0"
point_to_remove = "0_a"
pose_copy = pose.copy()
self.assertIn(point_to_remove, pose_copy.header.components[0].points)
pose_copy = pose_copy.remove_components([component_to_remove], {component_to_remove:[point_to_remove]})
self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components])
self.assertNotIn(point_to_remove, pose_copy.header.components[0].points)


# can we "remove" a component that doesn't exist without crashing
component_to_remove = "NOT EXISTING"
pose_copy = pose.copy()
initial_count = len(pose_copy.header.components)
pose_copy = pose_copy.remove_components([component_to_remove])
self.assertEqual(initial_count, len(pose_copy.header.components))




# can we "remove" a point that doesn't exist from a component that does without crashing
point_to_remove = "2_x"
component_to_remove_point_from = "2"
pose_copy = pose.copy()
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)
pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[point_to_remove]})
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)


# can we "remove" an empty list of points
component_to_remove_point_from = "2"
pose_copy = pose.copy()
initial_component_count = len(pose_copy.header.components)
initial_point_count = len(pose_copy.header.components[2].points)
pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[]})
self.assertEqual(initial_component_count, len(pose_copy.header.components))
self.assertEqual(len(pose_copy.header.components[2].points), initial_point_count)


# can we remove a point from a component that doesn't exist
point_to_remove = "2_x"
component_to_remove_point_from = "NOT EXISTING"
pose_copy = pose.copy()
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)
pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[point_to_remove]})
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)




class TestPoseTensorflowPoseBody(TestCase):
Expand Down Expand Up @@ -475,7 +583,7 @@ def create_pose_and_frame_dropout_uniform(example: tf.Tensor) -> tf.Tensor:
return example

dataset.map(create_pose_and_frame_dropout_uniform)


def test_pose_tf_posebody_copy_creates_deepcopy(self):
pose = _get_random_pose_object_with_tf_posebody(num_keypoints=5)
Expand All @@ -488,7 +596,7 @@ def test_pose_tf_posebody_copy_creates_deepcopy(self):

# Check that pose and pose_copy are not the same object
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")

# Ensure the data tensors are equal but independent
self.assertTrue(tf.reduce_all(pose.body.data == pose_copy.body.data), "Copy's data should match original")

Expand All @@ -515,6 +623,14 @@ class TestPoseNumpyPoseBody(TestCase):
Testcases for Pose objects containing NumPy PoseBody data.
"""

def test_pose_numpy_generated_with_correct_shape(self):
pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5, frames_min=3)

# does the header match the body?
expected_keypoints_count_from_header = sum(len(c.points) for c in pose.header.components)
self.assertEqual(expected_keypoints_count_from_header, pose.body.data.shape[-2])


def test_pose_numpy_posebody_normalize_preserves_shape(self):
"""
Tests if the normalization of Pose object with NumPy PoseBody preserves array shape.
Expand Down Expand Up @@ -593,17 +709,16 @@ def test_pose_torch_posebody_copy_creates_deepcopy(self):
pose = _get_random_pose_object_with_torch_posebody(num_keypoints=5)
self.assertIsInstance(pose.body, TorchPoseBody)
self.assertIsInstance(pose.body.data, TorchMaskedTensor)


pose_copy = pose.copy()
self.assertIsInstance(pose_copy.body, TorchPoseBody)
self.assertIsInstance(pose_copy.body.data, TorchMaskedTensor)

self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
self.assertTrue(pose.body.data.tensor.equal(pose_copy.body.data.tensor), "Copy's data should match original")
self.assertTrue(pose.body.data.mask.equal(pose_copy.body.data.mask), "Copy's mask should match original")

pose.body.data = TorchMaskedTensor(tensor=torch.zeros_like(pose.body.data.tensor),
pose.body.data = TorchMaskedTensor(tensor=torch.zeros_like(pose.body.data.tensor),
mask=torch.ones_like(pose.body.data.mask))


Expand Down

0 comments on commit e07ca68

Please sign in to comment.