Skip to content

Commit

Permalink
Refactor code for the updated features (#39)
Browse files Browse the repository at this point in the history
* Fix invalid polygon comparison in tests

* Improve docstring for the updated crop_covered_segments ratio_tolerance

* Update changelog

* Add changelog entry for the older PR
  • Loading branch information
zhiltsov-max authored Apr 3, 2024
1 parent 3e33d6e commit 82982b1
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 6 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
<https://github.com/cvat-ai/datumaro/pull/15>)
- Storing labels with the same name but with a different parent
(<https://github.com/cvat-ai/datumaro/pull/8>)
- Functions to work with plain polygons (COCO-style) - `close_polygon`, `simplify_polygon`
(<https://github.com/cvat-ai/datumaro/pull/39>)

### Changed
- `env.detect_dataset()` now returns a list of detected formats at all recursion levels
Expand Down Expand Up @@ -82,6 +84,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/cvat-ai/datumaro/pull/29>)
- Incorrect writing of `media` field in the Datumaro format, when there are specific media fields
(<https://github.com/cvat-ai/datumaro/pull/34>)
- Added missing `PointCloud` media type in the datumaro module namespace
(<https://github.com/cvat-ai/datumaro/pull/34>)

### Security
- TBD
Expand Down
49 changes: 46 additions & 3 deletions datumaro/util/mask_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: MIT

from functools import partial
from itertools import chain
from itertools import chain, repeat
from typing import List, NamedTuple, NewType, Optional, Sequence, Tuple, TypedDict, Union

import numpy as np
Expand All @@ -22,9 +22,15 @@ class CompressedRle(TypedDict):


# Either of the two RLE dictionary representations
Rle = Union[CompressedRle, UncompressedRle]

Polygon = List[int]
"2d polygon with points [x1, y1, x2, y2, ...]"

PolygonGroup = List[Polygon]
"A group of polygons, describing a single object"

# Named tuple with the integer fields x, y, w, h
# NOTE(review): presumably the top-left corner plus width and height - confirm against callers
BboxCoords = NamedTuple("BboxCoords", [("x", int), ("y", int), ("w", int), ("h", int)])

# A segment is represented either by a polygon group or by an RLE
Segment = Union[PolygonGroup, Rle]

# A distinct static type over np.ndarray; no runtime checks are added by NewType
BinaryMask = NewType("BinaryMask", np.ndarray)
Expand Down Expand Up @@ -272,8 +278,11 @@ def crop_covered_segments(
iou_threshold: IoU threshold for objects to be counted as intersected
By default is set to 0 to process any intersected objects
ratio_tolerance: an IoU "handicap" value for a situation
when an object is (almost) fully covered by another one and we
don't want make a "hole" in the background object
when a foreground object is (almost) fully inside of another one,
and we don't want to make a "hole" in the background object.
If the foreground object is fully or almost fully (iou - this ratio)
inside the background object, it will be kept.
The default is to keep tiny (0.1% of IoU) foreground objects.
area_threshold: minimal area of included segments
Returns:
Expand Down Expand Up @@ -394,3 +403,37 @@ def merge_masks(
merged_mask = np.where(m, m, merged_mask)

return merged_mask


def close_polygon(p: Polygon) -> Polygon:
    """
    Returns the closed version of the polygon, i.e. the one where
    the last point repeats the first point.

    The input is never modified: a new flat list is always returned,
    even if the polygon is already closed or has no points.

    Args:
        p: a flat list of coordinates [x1, y1, x2, y2, ...]

    Returns:
        A flat list of coordinates with the same first and last points.
        A polygon with no points results in an empty list.

    Raises:
        ValueError: if the number of coordinates is odd
    """
    points = np.asarray(p).reshape((-1, 2))

    if len(points) > 0 and not np.array_equal(points[-1], points[0]):
        # Append the first point as a new row. The axis is specified explicitly,
        # because np.append flattens both of its arguments otherwise
        points = np.append(points, points[:1], axis=0)

    return points.flatten().tolist()


def simplify_polygon(p: Polygon) -> Polygon:
    """
    Simplifies the polygon by removing consecutive repeated points.

    Only a point equal to the one directly before it is dropped, so equal points
    in non-adjacent positions (e.g. the closing point of a closed polygon) are kept.
    If fewer than 3 points remain, the last point is repeated until
    the polygon has 3 points.

    Args:
        p: a flat list of coordinates [x1, y1, x2, y2, ...]

    Returns:
        A new flat list of coordinates. A polygon with no points
        results in an empty list.

    Raises:
        ValueError: if the number of coordinates is odd
    """
    points = np.asarray(p).reshape((-1, 2))

    if len(points) == 0:
        # There is no point to pad the polygon with, so it is returned empty
        return []

    updated_points = [points[0]]
    for prev_point, point in zip(points, points[1:]):
        if not np.array_equal(point, prev_point):
            updated_points.append(point)

    if len(updated_points) < 3:
        # Keep the result a valid (possibly degenerate) polygon of 3 points
        updated_points.extend(repeat(updated_points[-1], 3 - len(updated_points)))

    return np.asarray(updated_points).flatten().tolist()
105 changes: 102 additions & 3 deletions tests/test_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,49 @@
from .requirements import Requirements, mark_requirement


def _compare_polygons(a: mask_tools.Polygon, b: mask_tools.Polygon) -> bool:
    """
    Checks if two plain polygons describe the same closed contour.

    Both polygons are simplified and closed first, and then the closing point
    is dropped, so an explicitly closed polygon can match an open one.
    The polygons are equal if their point sequences match
    up to a cyclic shift and the traversal direction.
    """
    a = mask_tools.close_polygon(mask_tools.simplify_polygon(a))[:-2]
    b = mask_tools.close_polygon(mask_tools.simplify_polygon(b))[:-2]
    if len(a) != len(b):
        return False

    a_points = np.reshape(a, (-1, 2))
    b_points = np.reshape(b, (-1, 2))

    # Polygons can be reversed, need to check both directions
    for b_ordered in (b_points, b_points[::-1]):
        # The starting point can differ, so every cyclic shift is tried
        for shift in range(len(b_ordered)):
            if np.array_equal(a_points, np.roll(b_ordered, shift, axis=0)):
                return True

    return False


def _compare_polygon_groups(a: mask_tools.PolygonGroup, b: mask_tools.PolygonGroup) -> bool:
    """
    Checks if two polygon groups contain the same polygons.

    The order of the polygons and repeated polygons are ignored.
    Single polygons are matched with _compare_polygons().
    """

    def _deduplicate(group: mask_tools.PolygonGroup) -> mask_tools.PolygonGroup:
        # Keeps the first occurrence of every distinct polygon
        distinct = []

        for candidate in group:
            if not any(_compare_polygons(candidate, kept) for kept in distinct):
                distinct.append(candidate)

        return distinct

    unique_a = _deduplicate(a)
    unique_b = _deduplicate(b)

    if len(unique_a) != len(unique_b):
        return False

    # The groups are equal if merging them gives no new distinct polygons
    return len(unique_a) == len(_deduplicate(unique_a + unique_b))


class PolygonConversionsTest(TestCase):
Expand All @@ -31,7 +72,7 @@ def test_mask_can_be_converted_to_polygon(self):

computed = mask_tools.mask_to_polygons(mask)

self.assertTrue(_compare_polygons(expected, computed))
self.assertTrue(_compare_polygon_groups(expected, computed))

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_crop_covered_segments(self):
Expand Down Expand Up @@ -196,6 +237,64 @@ def test_mask_to_rle_multi(self):
for case in cases:
self._test_mask_to_rle(case)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_close_open_polygon(self):
    # An open polygon must get a copy of its first point at the end
    open_polygon = [1, 1, 2, 3, 4, 5]

    closed = mask_tools.close_polygon(open_polygon)

    self.assertListEqual([1, 1, 2, 3, 4, 5, 1, 1], closed)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_close_closed_polygon(self):
    # A polygon that is already closed must come back with the same coordinates
    closed_polygon = [1, 1, 2, 3, 4, 5, 1, 1]

    result = mask_tools.close_polygon(closed_polygon)

    self.assertListEqual([1, 1, 2, 3, 4, 5, 1, 1], result)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_close_polygon_with_no_points(self):
    # A polygon without points has nothing to close and must stay empty
    closed = mask_tools.close_polygon([])

    self.assertListEqual([], closed)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_simplify_polygon(self):
    # The first and the last points are repeated twice in a row
    with_repeats = [1, 1, 1, 1, 2, 3, 4, 5, 4, 5]

    simplified = mask_tools.simplify_polygon(with_repeats)

    self.assertListEqual([1, 1, 2, 3, 4, 5], simplified)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_simplify_polygon_with_less_3_points(self):
    # A single point must be repeated until the polygon has 3 points
    single_point = [1, 1]

    simplified = mask_tools.simplify_polygon(single_point)

    self.assertListEqual([1, 1, 1, 1, 1, 1], simplified)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_compare_polygons(self):
    # The reference polygon: 4 distinct points, explicitly closed
    a = [1, 1, 2, 3, 4, 4, 5, 6, 1, 1]

    # The same contour with different starting points, open and closed,
    # and traversed in the opposite direction
    equal_variants = [
        [2, 3, 4, 4, 5, 6, 1, 1],
        [4, 4, 5, 6, 1, 1, 2, 3],
        [5, 6, 1, 1, 2, 3, 4, 4],
        [1, 1, 2, 3, 4, 4, 5, 6],
        [1, 1, 2, 3, 4, 4, 5, 6, 1, 1],
        [5, 6, 4, 4, 2, 3, 1, 1],
    ]

    # Without negative cases, a comparison that always succeeds would pass the test
    different_variants = [
        [1, 1, 4, 4, 2, 3, 5, 6],  # the same points in a different order
        [1, 1, 2, 3, 4, 4],  # one of the points is missing
    ]

    # subTest reports every failing variant instead of stopping at the first one
    for b in equal_variants:
        with self.subTest(b=b):
            self.assertTrue(_compare_polygons(a, b), b)

    for b in different_variants:
        with self.subTest(b=b):
            self.assertFalse(_compare_polygons(a, b), b)


class ColormapOperationsTest(TestCase):
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
Expand Down

0 comments on commit 82982b1

Please sign in to comment.