Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix task creation with gt_pool validation and cloud storage data #8539

Merged
merged 6 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Fixed

- Task creation with cloud storage data and GT_POOL validation mode
(<https://github.com/cvat-ai/cvat/pull/8539>)
28 changes: 3 additions & 25 deletions cvat/apps/engine/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,19 +1157,6 @@ def _update_status(msg: str) -> None:
assert job_file_mapping[-1] == validation_params['frames']
job_file_mapping.pop(-1)

# Update manifest
manifest = ImageManifestManager(db_data.get_manifest_path())
manifest.link(
sources=[extractor.get_path(image.frame) for image in images],
meta={
k: {'related_images': related_images[k] }
for k in related_images
},
data_dir=upload_dir,
DIM_3D=(db_task.dimension == models.DimensionType.DIM_3D),
)
manifest.create()

db_data.update_validation_layout(models.ValidationLayout(
mode=models.ValidationMode.GT_POOL,
frames=list(frame_idx_map.values()),
Expand Down Expand Up @@ -1324,24 +1311,15 @@ def _update_status(msg: str) -> None:
assert image.is_placeholder
image.real_frame = frame_id_map[image.real_frame]

# Update manifest
manifest.reorder([images[frame_idx_map[image.frame]].path for image in new_db_images])

images = new_db_images
db_data.size = len(images)
db_data.start_frame = 0
db_data.stop_frame = 0
db_data.frame_filter = ''

# Update manifest
manifest = ImageManifestManager(db_data.get_manifest_path())
manifest.link(
sources=[extractor.get_path(frame_idx_map[image.frame]) for image in images],
meta={
k: {'related_images': related_images[k] }
for k in related_images
},
data_dir=upload_dir,
DIM_3D=(db_task.dimension == models.DimensionType.DIM_3D),
)
manifest.create()

db_data.update_validation_layout(models.ValidationLayout(
mode=models.ValidationMode.GT_POOL,
Expand Down
95 changes: 89 additions & 6 deletions tests/python/rest_api/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ClassVar,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Expand Down Expand Up @@ -1529,12 +1530,13 @@ def _create_task_with_cloud_data(
server_files: List[str],
use_cache: bool = True,
sorting_method: str = "lexicographical",
spec: Optional[Dict[str, Any]] = None,
data_type: str = "image",
video_frame_count: int = 10,
server_files_exclude: Optional[List[str]] = None,
org: Optional[str] = None,
org: str = "",
filenames: Optional[List[str]] = None,
task_spec_kwargs: Optional[Dict[str, Any]] = None,
data_spec_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[int, Any]:
s3_client = s3.make_client(bucket=cloud_storage["resource"])
if data_type == "video":
Expand All @@ -1551,7 +1553,9 @@ def _create_task_with_cloud_data(
)
else:
images = generate_image_files(
3, **({"prefixes": ["img_"] * 3} if not filenames else {"filenames": filenames})
3,
sizes=[(100, 50) if i % 2 else (50, 100) for i in range(3)],
**({"prefixes": ["img_"] * 3} if not filenames else {"filenames": filenames}),
)

for image in images:
Expand Down Expand Up @@ -1598,6 +1602,7 @@ def _create_task_with_cloud_data(
"name": "car",
}
],
**(task_spec_kwargs or {}),
}

data_spec = {
Expand All @@ -1608,9 +1613,8 @@ def _create_task_with_cloud_data(
server_files if not use_manifest else server_files + ["test/manifest.jsonl"]
),
"sorting_method": sorting_method,
**(data_spec_kwargs or {}),
}
if spec is not None:
data_spec.update(spec)

if server_files_exclude:
data_spec["server_files_exclude"] = server_files_exclude
Expand Down Expand Up @@ -1984,7 +1988,7 @@ def test_create_task_with_cloud_storage_and_check_retrieve_data_meta(
use_cache=False,
server_files=["test/video/video.avi"],
org=org,
spec=data_spec,
data_spec_kwargs=data_spec,
data_type="video",
)

Expand Down Expand Up @@ -2550,6 +2554,85 @@ def test_can_create_task_with_gt_job_from_video(
else:
assert len(validation_frames) == validation_frames_count

@pytest.mark.with_external_services
@pytest.mark.parametrize("cloud_storage_id", [2])
@pytest.mark.parametrize(
"validation_mode",
[
models.ValidationMode("gt"),
models.ValidationMode("gt_pool"),
],
)
def test_can_create_task_with_validation_and_cloud_data(
self,
cloud_storage_id: int,
validation_mode: models.ValidationMode,
request: pytest.FixtureRequest,
admin_user: str,
cloud_storages: Iterable,
):
cloud_storage = cloud_storages[cloud_storage_id]
server_files = [f"test/sub_0/img_{i}.jpeg" for i in range(3)]
validation_frames = ["test/sub_0/img_1.jpeg"]

(task_id, _) = self._create_task_with_cloud_data(
request,
cloud_storage,
use_manifest=False,
server_files=server_files,
sorting_method=models.SortingMethod(
"random"
), # only random sorting can be used with gt_pool
Marishka17 marked this conversation as resolved.
Show resolved Hide resolved
data_spec_kwargs={
"validation_params": models.DataRequestValidationParams._from_openapi_data(
mode=validation_mode,
frames=validation_frames,
frame_selection_method=models.FrameSelectionMethod("manual"),
frames_per_job_count=1,
)
},
task_spec_kwargs={
# in case of gt_pool: each regular job will contain 1 regular and 1 validation frame
# (the number of validation frames is not included in segment_size)
"segment_size": 1,
},
)

with make_api_client(admin_user) as api_client:
# check that GT job was created
(paginated_jobs, _) = api_client.jobs_api.list(task_id=task_id, type="ground_truth")
assert 1 == len(paginated_jobs["results"])

(paginated_jobs, _) = api_client.jobs_api.list(task_id=task_id, type="annotation")
jobs_count = (
len(server_files) - len(validation_frames)
if validation_mode == models.ValidationMode("gt_pool")
else len(server_files)
)
assert jobs_count == len(paginated_jobs["results"])
# check that the returned meta of images corresponds to the chunk data
# Note: meta is based on the order of images from database
# while chunk with CS data is based on the order of images in a manifest
for job in paginated_jobs["results"]:
(job_meta, _) = api_client.jobs_api.retrieve_data_meta(job["id"])
(_, response) = api_client.jobs_api.retrieve_data(
job["id"], type="chunk", quality="compressed", index=0
)
chunk_file = io.BytesIO(response.data)
assert zipfile.is_zipfile(chunk_file)

with zipfile.ZipFile(chunk_file, "r") as chunk_archive:
chunk_images = {
int(os.path.splitext(name)[0]): np.array(
Image.open(io.BytesIO(chunk_archive.read(name)))
)
for name in chunk_archive.namelist()
}
chunk_images = dict(sorted(chunk_images.items(), key=lambda e: e[0]))

for img, img_meta in zip(chunk_images.values(), job_meta.frames):
assert (img.shape[0], img.shape[1]) == (img_meta.height, img_meta.width)
Marishka17 marked this conversation as resolved.
Show resolved Hide resolved


class _SourceDataType(str, Enum):
images = "images"
Expand Down
12 changes: 9 additions & 3 deletions tests/python/shared/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import subprocess
from contextlib import closing
from io import BytesIO
from typing import Generator, List, Optional
from typing import Generator, List, Optional, Tuple

import av
import av.video.reformatter
Expand All @@ -25,7 +25,11 @@ def generate_image_file(filename="image.png", size=(100, 50), color=(0, 0, 0)):


def generate_image_files(
count, prefixes=None, *, filenames: Optional[List[str]] = None
count: int,
*,
prefixes: Optional[List[str]] = None,
filenames: Optional[List[str]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
) -> List[BytesIO]:
assert not (prefixes and filenames), "prefixes cannot be used together with filenames"
assert not prefixes or len(prefixes) == count
Expand All @@ -35,7 +39,9 @@ def generate_image_files(
for i in range(count):
prefix = prefixes[i] if prefixes else ""
filename = f"{prefix}{i}.jpeg" if not filenames else filenames[i]
image = generate_image_file(filename, color=(i, i, i))
image = generate_image_file(
filename, color=(i, i, i), **({"size": sizes[i]}) if sizes else {}
Marishka17 marked this conversation as resolved.
Show resolved Hide resolved
)
images.append(image)

return images
Expand Down
15 changes: 15 additions & 0 deletions utils/dataset_manifest/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,21 @@ def emulate_hierarchical_structure(
'next': next_start_index,
}

def reorder(self, reordered_images: List[str]) -> None:
    """
    Rewrite the manifest so its entries follow the order of the given image names.

    Because of the Honeypots implementation, ``reordered_images`` may contain
    duplicate names; every occurrence reuses the single entry collected from
    the current manifest contents.

    Raises:
        InvalidManifestError: if ``reordered_images`` references an image name
            that is not present in the current manifest.
    """
    # Collect one entry per unique image name from the current manifest;
    # the first occurrence wins, duplicates map to the same details object.
    unique_images: Dict[str, Any] = {}
    for _, image_details in self:
        unique_images.setdefault(image_details.full_name, image_details)

    try:
        self.create(content=(unique_images[x] for x in reordered_images))
    except KeyError as ex:
        # Chain the original KeyError so the missing name survives in tracebacks.
        raise InvalidManifestError(f"Previous manifest does not contain {ex} image") from ex

class _BaseManifestValidator(ABC):
def __init__(self, full_manifest_path):
self._manifest = _Manifest(full_manifest_path)
Expand Down
Loading