diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml index f3ec1d48262d..a74f70c54379 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/black.yml @@ -5,38 +5,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - id: files - uses: tj-actions/changed-files@v41.0.0 - with: - files: | - cvat-sdk/**/*.py - cvat-cli/**/*.py - tests/python/**/*.py - cvat/apps/quality_control/**/*.py - cvat/apps/analytics_report/**/*.py - dir_names: true - name: Run checks - env: - PR_FILES_AM: ${{ steps.files.outputs.added_modified }} - PR_FILES_RENAMED: ${{ steps.files.outputs.renamed }} run: | - # If different modules use different Black configs, - # we need to run Black for each python component group separately. - # Otherwise, they all will use the same config. + pipx install $(grep "^black" ./cvat-cli/requirements/development.txt) - UPDATED_DIRS="${{steps.files.outputs.all_changed_files}}" + echo "Black version: $(black --version)" - if [[ ! -z $UPDATED_DIRS ]]; then - pipx install $(egrep "black.*" ./cvat-cli/requirements/development.txt) - - echo "Black version: "$(black --version) - echo "The dirs will be checked: $UPDATED_DIRS" - EXIT_CODE=0 - for DIR in $UPDATED_DIRS; do - black --check --diff $DIR || EXIT_CODE=$(($? | $EXIT_CODE)) || true - done - exit $EXIT_CODE - else - echo "No files with the \"py\" extension found" - fi + black --check --diff . diff --git a/.github/workflows/finalize-release.yml b/.github/workflows/finalize-release.yml index 2cb6035769ae..8f19cb1e9e60 100644 --- a/.github/workflows/finalize-release.yml +++ b/.github/workflows/finalize-release.yml @@ -65,7 +65,7 @@ jobs: - name: Bump version run: - ./dev/update_version.py --minor + ./dev/update_version.py --patch - name: Commit post-release changes run: | diff --git a/CHANGELOG.md b/CHANGELOG.md index 42d2893080be..dd8854040e02 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,55 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 + +## \[2.22.0\] - 2024-11-11 + +### Added + +- Feature to hide a mask during editing () + +- A quality setting to compare point groups without using bbox + () + +- A quality check option to consider empty frames matching + () + +### Changed + +- Reduced memory usage of the utils container + () + +### Removed + +- Removed unused business group + () + +### Fixed + +- Propagation creates copies on non-existing frames in a ground truth job + () + +- Exporting projects with tasks containing honeypots. Honeypots are no longer exported. + () + +- Error after creating GT job on Create job page with frame selection method `random_per_job` + () + +- Fixed issue 'Cannot read properties of undefined (reading 'push')' + () + +- Re-newed import/export request failed immediately if the previous failed + () + +- Fixed automatic zooming in attribute annotation mode for masks + () + +- Export dataset in CVAT format misses frames in tasks with non-default frame step + () + +- Incorrect progress representation on `Requests` page + () + ## \[2.21.3\] - 2024-10-31 diff --git a/cvat-canvas/package.json b/cvat-canvas/package.json index 2b24ff47e347..c89e7506854c 100644 --- a/cvat-canvas/package.json +++ b/cvat-canvas/package.json @@ -1,6 +1,6 @@ { "name": "cvat-canvas", - "version": "2.20.9", + "version": "2.20.10", "type": "module", "description": "Part of Computer Vision Annotation Tool which presents its canvas library", "main": "src/canvas.ts", diff --git a/cvat-canvas/src/typescript/canvasModel.ts b/cvat-canvas/src/typescript/canvasModel.ts index 2c7a1f08d203..0ad62484c14c 100644 --- a/cvat-canvas/src/typescript/canvasModel.ts +++ b/cvat-canvas/src/typescript/canvasModel.ts @@ -96,6 +96,7 @@ export interface Configuration { controlPointsSize?: number; outlinedBorders?: string | false; resetZoom?: boolean; + hideEditedObject?: boolean; } export interface BrushTool { @@ -416,6 +417,7 @@ export class CanvasModelImpl extends MasterImpl implements CanvasModel { textPosition: consts.DEFAULT_SHAPE_TEXT_POSITION, textContent: consts.DEFAULT_SHAPE_TEXT_CONTENT, undefinedAttrValue: consts.DEFAULT_UNDEFINED_ATTR_VALUE, + hideEditedObject: false, }, imageBitmap: false, image: null, @@ -981,6 +983,10 @@ export class CanvasModelImpl extends MasterImpl implements CanvasModel { this.data.configuration.CSSImageFilter = configuration.CSSImageFilter; } + if (typeof configuration.hideEditedObject === 'boolean') { + this.data.configuration.hideEditedObject = configuration.hideEditedObject; + } + this.notify(UpdateReasons.CONFIG_UPDATED); } diff --git a/cvat-canvas/src/typescript/canvasView.ts b/cvat-canvas/src/typescript/canvasView.ts index ab9a96682746..4c346b4d6735 100644 --- a/cvat-canvas/src/typescript/canvasView.ts +++ b/cvat-canvas/src/typescript/canvasView.ts @@ -1918,15 +1918,26 @@ export class CanvasViewImpl implements CanvasView, Listener { this.gridPattern.setAttribute('height', `${size.height}`); } else if (reason === UpdateReasons.SHAPE_FOCUSED) { const { padding, clientID } = this.controller.focusData; + const drawnState = this.drawnStates[clientID]; const object = this.svgShapes[clientID]; - if (object) { - const bbox: SVG.BBox = object.bbox(); - this.onFocusRegion( - bbox.x - padding, - bbox.y - padding, - bbox.width + padding * 2, - bbox.height + padding * 2, - ); + if (drawnState && object) { + const { offset } = this.geometry; + let [x, y, width, height] = [0, 0, 0, 0]; + + if (drawnState.shapeType === 'mask') { + const [xtl, ytl, xbr, ybr] = drawnState.points.slice(-4); + x = xtl + offset; + y = ytl + offset; + width = xbr - xtl + 1; + height = ybr - ytl + 1; + } else { + const bbox: SVG.BBox = object.bbox(); + ({ + x, y, width, height, + } = bbox); + } + + this.onFocusRegion(x - padding, y - padding, width + padding * 2, height + padding * 2); } } else if (reason === UpdateReasons.SHAPE_ACTIVATED) { this.activate(this.controller.activeElement); diff --git a/cvat-canvas/src/typescript/drawHandler.ts b/cvat-canvas/src/typescript/drawHandler.ts index b7e9cbb90130..77b674dec05e 100644 --- a/cvat-canvas/src/typescript/drawHandler.ts +++ b/cvat-canvas/src/typescript/drawHandler.ts @@ -5,7 +5,7 @@ import * as SVG from 'svg.js'; import 'svg.draw.js'; -import './svg.patch'; +import { CIRCLE_STROKE } from './svg.patch'; import { AutoborderHandler } from './autoborderHandler'; import { @@ -104,6 +104,7 @@ export class DrawHandlerImpl implements DrawHandler { private controlPointsSize: number; private selectedShapeOpacity: number; private outlinedBorders: string; + private isHidden: boolean; // we should use any instead of SVG.Shape because svg plugins cannot change declared interface // so, methods like draw() just undefined for SVG.Shape, but nevertheless they exist @@ -1276,6 +1277,7 @@ export class DrawHandlerImpl implements DrawHandler { this.selectedShapeOpacity = configuration.selectedShapeOpacity; this.outlinedBorders = configuration.outlinedBorders || 'black'; this.autobordersEnabled = false; + this.isHidden = false; this.startTimestamp = Date.now(); this.onDrawDoneDefault = onDrawDone; this.canvas = canvas; @@ -1301,10 +1303,28 @@ export class DrawHandlerImpl implements DrawHandler { }); } + private strokePoint(point: SVG.Element): void { + point.attr('stroke', this.isHidden ? 'none' : CIRCLE_STROKE); + point.fill({ opacity: this.isHidden ? 0 : 1 }); + } + + private updateHidden(value: boolean) { + this.isHidden = value; + + if (value) { + this.canvas.attr('pointer-events', 'none'); + } else { + this.canvas.attr('pointer-events', 'all'); + } + } + public configurate(configuration: Configuration): void { this.controlPointsSize = configuration.controlPointsSize; this.selectedShapeOpacity = configuration.selectedShapeOpacity; this.outlinedBorders = configuration.outlinedBorders || 'black'; + if (this.isHidden !== configuration.hideEditedObject) { + this.updateHidden(configuration.hideEditedObject); + } const isFillableRect = this.drawData && this.drawData.shapeType === 'rectangle' && @@ -1315,15 +1335,26 @@ export class DrawHandlerImpl implements DrawHandler { const isFilalblePolygon = this.drawData && this.drawData.shapeType === 'polygon'; if (this.drawInstance && (isFillableRect || isFillableCuboid || isFilalblePolygon)) { - this.drawInstance.fill({ opacity: configuration.selectedShapeOpacity }); + this.drawInstance.fill({ + opacity: configuration.hideEditedObject ? 0 : configuration.selectedShapeOpacity, + }); + } + + if (this.drawInstance && (isFilalblePolygon)) { + const paintHandler = this.drawInstance.remember('_paintHandler'); + if (paintHandler) { + for (const point of (paintHandler as any).set.members) { + this.strokePoint(point); + } + } } if (this.drawInstance && this.drawInstance.attr('stroke')) { - this.drawInstance.attr('stroke', this.outlinedBorders); + this.drawInstance.attr('stroke', configuration.hideEditedObject ? 'none' : this.outlinedBorders); } if (this.pointsGroup && this.pointsGroup.attr('stroke')) { - this.pointsGroup.attr('stroke', this.outlinedBorders); + this.pointsGroup.attr('stroke', configuration.hideEditedObject ? 'none' : this.outlinedBorders); } this.autobordersEnabled = configuration.autoborders; @@ -1369,6 +1400,7 @@ export class DrawHandlerImpl implements DrawHandler { const paintHandler = this.drawInstance.remember('_paintHandler'); for (const point of (paintHandler as any).set.members) { + this.strokePoint(point); point.attr('stroke-width', `${consts.POINTS_STROKE_WIDTH / geometry.scale}`); point.attr('r', `${this.controlPointsSize / geometry.scale}`); } diff --git a/cvat-canvas/src/typescript/editHandler.ts b/cvat-canvas/src/typescript/editHandler.ts index 567eea29c7de..84ecb1684ad4 100644 --- a/cvat-canvas/src/typescript/editHandler.ts +++ b/cvat-canvas/src/typescript/editHandler.ts @@ -472,7 +472,7 @@ export class EditHandlerImpl implements EditHandler { const paintHandler = this.editLine.remember('_paintHandler'); - for (const point of (paintHandler as any).set.members) { + for (const point of paintHandler.set.members) { point.attr('stroke-width', `${consts.POINTS_STROKE_WIDTH / geometry.scale}`); point.attr('r', `${this.controlPointsSize / geometry.scale}`); } diff --git a/cvat-canvas/src/typescript/masksHandler.ts b/cvat-canvas/src/typescript/masksHandler.ts index cdaa4d86d2fa..ca6e5e469a63 100644 --- a/cvat-canvas/src/typescript/masksHandler.ts +++ b/cvat-canvas/src/typescript/masksHandler.ts @@ -6,7 +6,7 @@ import { fabric } from 'fabric'; import debounce from 'lodash/debounce'; import { - DrawData, MasksEditData, Geometry, Configuration, BrushTool, ColorBy, + DrawData, MasksEditData, Geometry, Configuration, BrushTool, ColorBy, Position, } from './canvasModel'; import consts from './consts'; import { DrawHandler } from './drawHandler'; @@ -61,10 +61,11 @@ export class MasksHandlerImpl implements MasksHandler { private editData: MasksEditData | null; private colorBy: ColorBy; - private latestMousePos: { x: number; y: number; }; + private latestMousePos: Position; private startTimestamp: number; private geometry: Geometry; private drawingOpacity: number; + private isHidden: boolean; private keepDrawnPolygon(): void { const canvasWrapper = this.canvas.getElement().parentElement; @@ -217,12 +218,29 @@ export class MasksHandlerImpl implements MasksHandler { private imageDataFromCanvas(wrappingBBox: WrappingBBox): Uint8ClampedArray { const imageData = this.canvas.toCanvasElement() .getContext('2d').getImageData( - wrappingBBox.left, wrappingBBox.top, - wrappingBBox.right - wrappingBBox.left + 1, wrappingBBox.bottom - wrappingBBox.top + 1, + wrappingBBox.left, + wrappingBBox.top, + wrappingBBox.right - wrappingBBox.left + 1, + wrappingBBox.bottom - wrappingBBox.top + 1, ).data; return imageData; } + private updateHidden(value: boolean) { + this.isHidden = value; + + // Need to update style of upper canvas explicitly because update of default cursor is not applied immediately + // https://github.com/fabricjs/fabric.js/issues/1456 + const newOpacity = value ? '0' : ''; + const newCursor = value ? 'inherit' : 'none'; + this.canvas.getElement().parentElement.style.opacity = newOpacity; + const upperCanvas = this.canvas.getElement().parentElement.querySelector('.upper-canvas') as HTMLElement; + if (upperCanvas) { + upperCanvas.style.cursor = newCursor; + } + this.canvas.defaultCursor = newCursor; + } + private updateBrushTools(brushTool?: BrushTool, opts: Partial = {}): void { if (this.isPolygonDrawing) { // tool was switched from polygon to brush for example @@ -350,6 +368,7 @@ export class MasksHandlerImpl implements MasksHandler { this.editData = null; this.drawingOpacity = 0.5; this.brushMarker = null; + this.isHidden = false; this.colorBy = ColorBy.LABEL; this.onDrawDone = onDrawDone; this.onDrawRepeat = onDrawRepeat; @@ -452,7 +471,7 @@ export class MasksHandlerImpl implements MasksHandler { this.canvas.renderAll(); } - if (isMouseDown && !isBrushSizeChanging && ['brush', 'eraser'].includes(tool?.type)) { + if (isMouseDown && !this.isHidden && !isBrushSizeChanging && ['brush', 'eraser'].includes(tool?.type)) { const color = fabric.Color.fromHex(tool.color); color.setAlpha(tool.type === 'eraser' ? 1 : 0.5); @@ -530,6 +549,10 @@ export class MasksHandlerImpl implements MasksHandler { public configurate(configuration: Configuration): void { this.colorBy = configuration.colorBy; + + if (this.isHidden !== configuration.hideEditedObject) { + this.updateHidden(configuration.hideEditedObject); + } } public transform(geometry: Geometry): void { @@ -563,7 +586,10 @@ export class MasksHandlerImpl implements MasksHandler { const color = fabric.Color.fromHex(this.getStateColor(drawData.initialState)).getSource(); const [left, top, right, bottom] = points.slice(-4); const imageBitmap = expandChannels(color[0], color[1], color[2], points); - imageDataToDataURL(imageBitmap, right - left + 1, bottom - top + 1, + imageDataToDataURL( + imageBitmap, + right - left + 1, + bottom - top + 1, (dataURL: string) => new Promise((resolve) => { fabric.Image.fromURL(dataURL, (image: fabric.Image) => { try { @@ -654,7 +680,10 @@ export class MasksHandlerImpl implements MasksHandler { const color = fabric.Color.fromHex(this.getStateColor(editData.state)).getSource(); const [left, top, right, bottom] = points.slice(-4); const imageBitmap = expandChannels(color[0], color[1], color[2], points); - imageDataToDataURL(imageBitmap, right - left + 1, bottom - top + 1, + imageDataToDataURL( + imageBitmap, + right - left + 1, + bottom - top + 1, (dataURL: string) => new Promise((resolve) => { fabric.Image.fromURL(dataURL, (image: fabric.Image) => { try { diff --git a/cvat-canvas/src/typescript/svg.patch.ts b/cvat-canvas/src/typescript/svg.patch.ts index 40af155a956f..7b728b274335 100644 --- a/cvat-canvas/src/typescript/svg.patch.ts +++ b/cvat-canvas/src/typescript/svg.patch.ts @@ -86,6 +86,7 @@ SVG.Element.prototype.draw.extend( }), ); +export const CIRCLE_STROKE = '#000'; // Fix method drawCircles function drawCircles(): void { const array = this.el.array().valueOf(); @@ -109,6 +110,7 @@ function drawCircles(): void { .circle(5) .stroke({ width: 1, + color: CIRCLE_STROKE, }) .fill('#ccc') .center(p.x, p.y), diff --git a/cvat-cli/requirements/base.txt b/cvat-cli/requirements/base.txt index 063da9c864b4..e9be53974d91 100644 --- a/cvat-cli/requirements/base.txt +++ b/cvat-cli/requirements/base.txt @@ -1,3 +1,3 @@ -cvat-sdk~=2.21.3 +cvat-sdk~=2.22.0 Pillow>=10.3.0 setuptools>=70.0.0 # not directly required, pinned by Snyk to avoid a vulnerability diff --git a/cvat-cli/src/cvat_cli/__main__.py b/cvat-cli/src/cvat_cli/__main__.py index b18c8d8bb751..3d89935fc2d0 100755 --- a/cvat-cli/src/cvat_cli/__main__.py +++ b/cvat-cli/src/cvat_cli/__main__.py @@ -1,79 +1,37 @@ # Copyright (C) 2020-2022 Intel Corporation -# Copyright (C) 2022 CVAT.ai Corporation +# Copyright (C) 2022-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT +import argparse import logging import sys -from http.client import HTTPConnection -from types import SimpleNamespace -from typing import List import urllib3.exceptions from cvat_sdk import exceptions -from cvat_sdk.core.client import Client, Config -from cvat_cli.cli import CLI -from cvat_cli.parser import get_action_args, make_cmdline_parser +from ._internal.commands import COMMANDS +from ._internal.common import build_client, configure_common_arguments, configure_logger +from ._internal.utils import popattr logger = logging.getLogger(__name__) -def configure_logger(level): - formatter = logging.Formatter( - "[%(asctime)s] %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S", style="%" - ) - handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(formatter) - logger.addHandler(handler) - logger.setLevel(level) - if level <= logging.DEBUG: - HTTPConnection.debuglevel = 1 +def main(args: list[str] = None): + parser = argparse.ArgumentParser(description=COMMANDS.description) + configure_common_arguments(parser) + COMMANDS.configure_parser(parser) - -def build_client(parsed_args: SimpleNamespace, logger: logging.Logger) -> Client: - config = Config(verify_ssl=not parsed_args.insecure) - - url = parsed_args.server_host - if parsed_args.server_port: - url += f":{parsed_args.server_port}" - - client = Client( - url=url, - logger=logger, - config=config, - check_server_version=False, # version is checked after auth to support versions < 2.3 - ) - - client.organization_slug = parsed_args.organization - - return client - - -def main(args: List[str] = None): - actions = { - "create": CLI.tasks_create, - "delete": CLI.tasks_delete, - "ls": CLI.tasks_list, - "frames": CLI.tasks_frames, - "dump": CLI.tasks_dump, - "upload": CLI.tasks_upload, - "export": CLI.tasks_export, - "import": CLI.tasks_import, - "auto-annotate": CLI.tasks_auto_annotate, - } - parser = make_cmdline_parser() parsed_args = parser.parse_args(args) - configure_logger(parsed_args.loglevel) - with build_client(parsed_args, logger=logger) as client: - action_args = get_action_args(parser, parsed_args) - try: - cli = CLI(client=client, credentials=parsed_args.auth) - actions[parsed_args.action](cli, **vars(action_args)) - except (exceptions.ApiException, urllib3.exceptions.HTTPError) as e: - logger.critical(e) - return 1 + configure_logger(logger, parsed_args) + + try: + with build_client(parsed_args, logger=logger) as client: + popattr(parsed_args, "_executor")(client, **vars(parsed_args)) + except (exceptions.ApiException, urllib3.exceptions.HTTPError) as e: + logger.critical(e) + return 1 return 0 diff --git a/cvat-cli/src/cvat_cli/_internal/__init__.py b/cvat-cli/src/cvat_cli/_internal/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/cvat-cli/src/cvat_cli/_internal/command_base.py b/cvat-cli/src/cvat_cli/_internal/command_base.py new file mode 100644 index 000000000000..ec6ccbbcd47f --- /dev/null +++ b/cvat-cli/src/cvat_cli/_internal/command_base.py @@ -0,0 +1,53 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import argparse +import types +from collections.abc import Mapping +from typing import Callable, Protocol + + +class Command(Protocol): + @property + def description(self) -> str: ... + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: ... + + # The exact parameters accepted by `execute` vary between commands, + # so we're forced to declare it like this instead of as a method. + @property + def execute(self) -> Callable[..., None]: ... + + +class CommandGroup: + def __init__(self, *, description: str) -> None: + self._commands: dict[str, Command] = {} + self.description = description + + def command_class(self, name: str): + def decorator(cls: type): + self._commands[name] = cls() + return cls + + return decorator + + def add_command(self, name: str, command: Command) -> None: + self._commands[name] = command + + @property + def commands(self) -> Mapping[str, Command]: + return types.MappingProxyType(self._commands) + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + subparsers = parser.add_subparsers(required=True) + + for name, command in self._commands.items(): + subparser = subparsers.add_parser(name, description=command.description) + subparser.set_defaults(_executor=command.execute) + command.configure_parser(subparser) + + def execute(self) -> None: + # It should be impossible for a command group to be executed, + # because configure_parser requires that a subcommand is specified. + assert False, "unreachable code" diff --git a/cvat-cli/src/cvat_cli/_internal/commands.py b/cvat-cli/src/cvat_cli/_internal/commands.py new file mode 100644 index 000000000000..e86ef3b6350f --- /dev/null +++ b/cvat-cli/src/cvat_cli/_internal/commands.py @@ -0,0 +1,500 @@ +# Copyright (C) 2022-2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import argparse +import importlib +import importlib.util +import json +import textwrap +from collections.abc import Sequence +from pathlib import Path +from typing import Any, Optional + +import cvat_sdk.auto_annotation as cvataa +from attr.converters import to_bool +from cvat_sdk import Client, models +from cvat_sdk.core.helpers import DeferredTqdmProgressReporter +from cvat_sdk.core.proxies.tasks import ResourceType + +from .command_base import CommandGroup +from .parsers import BuildDictAction, parse_function_parameter, parse_label_arg, parse_resource_type + +COMMANDS = CommandGroup(description="Perform common operations related to CVAT tasks.") + + +@COMMANDS.command_class("ls") +class TaskList: + description = "List all CVAT tasks in either basic or JSON format." + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--json", + dest="use_json_output", + default=False, + action="store_true", + help="output JSON data", + ) + + def execute(self, client: Client, *, use_json_output: bool = False): + results = client.tasks.list(return_json=use_json_output) + if use_json_output: + print(json.dumps(json.loads(results), indent=2)) + else: + for r in results: + print(r.id) + + +@COMMANDS.command_class("create") +class TaskCreate: + description = textwrap.dedent( + """\ + Create a new CVAT task. To create a task, you need + to specify labels using the --labels argument or + attach the task to an existing project using the + --project_id argument. + """ + ) + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument("name", type=str, help="name of the task") + parser.add_argument( + "resource_type", + default="local", + choices=list(ResourceType), + type=parse_resource_type, + help="type of files specified", + ) + parser.add_argument("resources", type=str, help="list of paths or URLs", nargs="+") + parser.add_argument( + "--annotation_path", default="", type=str, help="path to annotation file" + ) + parser.add_argument( + "--annotation_format", + default="CVAT 1.1", + type=str, + help="format of the annotation file being uploaded, e.g. CVAT 1.1", + ) + parser.add_argument( + "--bug_tracker", "--bug", default=argparse.SUPPRESS, type=str, help="bug tracker URL" + ) + parser.add_argument( + "--chunk_size", + default=argparse.SUPPRESS, + type=int, + help="the number of frames per chunk", + ) + parser.add_argument( + "--completion_verification_period", + dest="status_check_period", + default=2, + type=float, + help=textwrap.dedent( + """\ + number of seconds to wait until checking + if data compression finished (necessary before uploading annotations) + """ + ), + ) + parser.add_argument( + "--copy_data", + default=False, + action="store_true", + help=textwrap.dedent( + """\ + set the option to copy the data, only used when resource type is + share (default: %(default)s) + """ + ), + ) + parser.add_argument( + "--frame_step", + default=argparse.SUPPRESS, + type=int, + help=textwrap.dedent( + """\ + set the frame step option in the advanced configuration + when uploading image series or videos + """ + ), + ) + parser.add_argument( + "--image_quality", + default=70, + type=int, + help=textwrap.dedent( + """\ + set the image quality option in the advanced configuration + when creating tasks.(default: %(default)s) + """ + ), + ) + parser.add_argument( + "--labels", + default="[]", + type=parse_label_arg, + help="string or file containing JSON labels specification", + ) + parser.add_argument( + "--project_id", default=argparse.SUPPRESS, type=int, help="project ID if project exists" + ) + parser.add_argument( + "--overlap", + default=argparse.SUPPRESS, + type=int, + help="the number of intersected frames between different segments", + ) + parser.add_argument( + "--segment_size", + default=argparse.SUPPRESS, + type=int, + help="the number of frames in a segment", + ) + parser.add_argument( + "--sorting-method", + default="lexicographical", + choices=["lexicographical", "natural", "predefined", "random"], + help="""data soring method (default: %(default)s)""", + ) + parser.add_argument( + "--start_frame", + default=argparse.SUPPRESS, + type=int, + help="the start frame of the video", + ) + parser.add_argument( + "--stop_frame", default=argparse.SUPPRESS, type=int, help="the stop frame of the video" + ) + parser.add_argument( + "--use_cache", + action="store_true", + help="""use cache""", # automatically sets default=False + ) + parser.add_argument( + "--use_zip_chunks", + action="store_true", # automatically sets default=False + help="""zip chunks before sending them to the server""", + ) + parser.add_argument( + "--cloud_storage_id", + default=argparse.SUPPRESS, + type=int, + help="cloud storage ID if you would like to use data from cloud storage", + ) + parser.add_argument( + "--filename_pattern", + default=argparse.SUPPRESS, + type=str, + help=textwrap.dedent( + """\ + pattern for filtering data from the manifest file for the upload. + Only shell-style wildcards are supported: + * - matches everything; + ? - matches any single character; + [seq] - matches any character in 'seq'; + [!seq] - matches any character not in seq + """ + ), + ) + + def execute( + self, + client, + *, + name: str, + labels: list[dict[str, str]], + resources: Sequence[str], + resource_type: ResourceType, + annotation_path: str, + annotation_format: str, + status_check_period: int, + **kwargs, + ) -> None: + task_params = {} + data_params = {} + + for k, v in kwargs.items(): + if k in models.DataRequest.attribute_map or k == "frame_step": + data_params[k] = v + else: + task_params[k] = v + + task = client.tasks.create_from_data( + spec=models.TaskWriteRequest(name=name, labels=labels, **task_params), + resource_type=resource_type, + resources=resources, + data_params=data_params, + annotation_path=annotation_path, + annotation_format=annotation_format, + status_check_period=status_check_period, + pbar=DeferredTqdmProgressReporter(), + ) + print("Created task id", task.id) + + +@COMMANDS.command_class("delete") +class TaskDelete: + description = "Delete a list of tasks, ignoring those which don't exist." + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument("task_ids", type=int, help="list of task IDs", nargs="+") + + def execute(self, client: Client, *, task_ids: Sequence[int]) -> None: + client.tasks.remove_by_ids(task_ids=task_ids) + + +@COMMANDS.command_class("frames") +class TaskFrames: + description = textwrap.dedent( + """\ + Download the requested frame numbers for a task and save images as + task__frame_.jpg. + """ + ) + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument("task_id", type=int, help="task ID") + parser.add_argument("frame_ids", type=int, help="list of frame IDs to download", nargs="+") + parser.add_argument( + "--outdir", type=str, default="", help="directory to save images (default: CWD)" + ) + parser.add_argument( + "--quality", + type=str, + choices=("original", "compressed"), + default="original", + help="choose quality of images (default: %(default)s)", + ) + + def execute( + self, + client: Client, + *, + task_id: int, + frame_ids: Sequence[int], + outdir: str, + quality: str, + ) -> None: + client.tasks.retrieve(obj_id=task_id).download_frames( + frame_ids=frame_ids, + outdir=outdir, + quality=quality, + filename_pattern=f"task_{task_id}" + "_frame_{frame_id:06d}{frame_ext}", + ) + + +@COMMANDS.command_class("dump") +class TaskDump: + description = textwrap.dedent( + """\ + Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0'). + """ + ) + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument("task_id", type=int, help="task ID") + parser.add_argument("filename", type=str, help="output file") + parser.add_argument( + "--format", + dest="fileformat", + type=str, + default="CVAT for images 1.1", + help="annotation format (default: %(default)s)", + ) + parser.add_argument( + "--completion_verification_period", + dest="status_check_period", + default=2, + type=float, + help="number of seconds to wait until checking if dataset building finished", + ) + parser.add_argument( + "--with-images", + type=to_bool, + default=False, + dest="include_images", + help="Whether to include images or not (default: %(default)s)", + ) + + def execute( + self, + client: Client, + *, + task_id: int, + fileformat: str, + filename: str, + status_check_period: int, + include_images: bool, + ) -> None: + client.tasks.retrieve(obj_id=task_id).export_dataset( + format_name=fileformat, + filename=filename, + pbar=DeferredTqdmProgressReporter(), + status_check_period=status_check_period, + include_images=include_images, + ) + + +@COMMANDS.command_class("upload") +class TaskUpload: + description = textwrap.dedent( + """\ + Upload annotations for a task in the specified format + (e.g. 'YOLO ZIP 1.0'). + """ + ) + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument("task_id", type=int, help="task ID") + parser.add_argument("filename", type=str, help="upload file") + parser.add_argument( + "--format", + dest="fileformat", + type=str, + default="CVAT 1.1", + help="annotation format (default: %(default)s)", + ) + + def execute( + self, + client: Client, + *, + task_id: int, + fileformat: str, + filename: str, + ) -> None: + client.tasks.retrieve(obj_id=task_id).import_annotations( + format_name=fileformat, + filename=filename, + pbar=DeferredTqdmProgressReporter(), + ) + + +@COMMANDS.command_class("export") +class TaskExport: + description = """Download a task backup.""" + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument("task_id", type=int, help="task ID") + parser.add_argument("filename", type=str, help="output file") + parser.add_argument( + "--completion_verification_period", + dest="status_check_period", + default=2, + type=float, + help="time interval between checks if archive building has been finished, in seconds", + ) + + def execute( + self, client: Client, *, task_id: int, filename: str, status_check_period: int + ) -> None: + client.tasks.retrieve(obj_id=task_id).download_backup( + filename=filename, + status_check_period=status_check_period, + pbar=DeferredTqdmProgressReporter(), + ) + + +@COMMANDS.command_class("import") +class TaskImport: + description = """Import a task from a backup file.""" + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument("filename", type=str, help="upload file") + parser.add_argument( + "--completion_verification_period", + dest="status_check_period", + default=2, + type=float, + help="time interval between checks if archive processing was finished, in seconds", + ) + + def execute(self, client: Client, *, filename: str, status_check_period: int) -> None: + client.tasks.create_from_backup( + filename=filename, + status_check_period=status_check_period, + pbar=DeferredTqdmProgressReporter(), + ) + + +@COMMANDS.command_class("auto-annotate") +class TaskAutoAnnotate: + description = "Automatically annotate a CVAT task by running a function on the local machine." + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument("task_id", type=int, help="task ID") + + function_group = parser.add_mutually_exclusive_group(required=True) + + function_group.add_argument( + "--function-module", + metavar="MODULE", + help="qualified name of a module to use as the function", + ) + + function_group.add_argument( + "--function-file", + metavar="PATH", + type=Path, + help="path to a Python source file to use as the function", + ) + + parser.add_argument( + "--function-parameter", + "-p", + metavar="NAME=TYPE:VALUE", + type=parse_function_parameter, + action=BuildDictAction, + dest="function_parameters", + help="parameter for the function", + ) + + parser.add_argument( + "--clear-existing", + action="store_true", + help="Remove existing annotations from the task", + ) + + parser.add_argument( + "--allow-unmatched-labels", + action="store_true", + help="Allow the function to declare labels not configured in the task", + ) + + def execute( + self, + client: Client, + *, + task_id: int, + function_module: Optional[str] = None, + function_file: Optional[Path] = None, + function_parameters: dict[str, Any], + clear_existing: bool = False, + allow_unmatched_labels: bool = False, + ) -> None: + if function_module is not None: + function = importlib.import_module(function_module) + elif function_file is not None: + module_spec = importlib.util.spec_from_file_location("__cvat_function__", function_file) + function = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(function) + else: + assert False, "function identification arguments missing" + + if hasattr(function, "create"): + # this is actually a function factory + function = function.create(**function_parameters) + else: + if function_parameters: + raise TypeError("function takes no parameters") + + cvataa.annotate_task( + client, + task_id, + function, + pbar=DeferredTqdmProgressReporter(), + clear_existing=clear_existing, + allow_unmatched_labels=allow_unmatched_labels, + ) diff --git a/cvat-cli/src/cvat_cli/_internal/common.py b/cvat-cli/src/cvat_cli/_internal/common.py new file mode 100644 index 000000000000..415a1340958e --- /dev/null +++ b/cvat-cli/src/cvat_cli/_internal/common.py @@ -0,0 +1,104 @@ +# Copyright (C) 2021-2022 Intel Corporation +# Copyright (C) 2022-2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import argparse +import getpass +import logging +import os +import sys +from http.client import HTTPConnection + +from cvat_sdk.core.client import Client, Config + +from ..version import VERSION +from .utils import popattr + + +def get_auth(s): + """Parse USER[:PASS] strings and prompt for password if none was + supplied.""" + user, _, password = s.partition(":") + password = password or os.environ.get("PASS") or getpass.getpass() + return user, password + + +def configure_common_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--version", action="version", version=VERSION) + parser.add_argument( + "--insecure", + action="store_true", + help="Allows to disable SSL certificate check", + ) + + parser.add_argument( + "--auth", + type=get_auth, + metavar="USER:[PASS]", + default=getpass.getuser(), + help="""defaults to the current user and supports the PASS + environment variable or password prompt + (default user: %(default)s).""", + ) + parser.add_argument( + "--server-host", type=str, default="localhost", help="host (default: %(default)s)" + ) + parser.add_argument( + "--server-port", + type=int, + default=None, + help="port (default: 80 for http and 443 for https connections)", + ) + parser.add_argument( + "--organization", + "--org", + metavar="SLUG", + help="""short name (slug) of the organization + to use when listing or creating resources; + set to blank string to use the personal workspace + (default: list all accessible objects, create in personal workspace)""", + ) + parser.add_argument( + "--debug", + action="store_const", + dest="loglevel", + const=logging.DEBUG, + default=logging.INFO, + help="show debug output", + ) + + +def configure_logger(logger: logging.Logger, parsed_args: argparse.Namespace) -> None: + level = popattr(parsed_args, "loglevel") + formatter = logging.Formatter( + "[%(asctime)s] %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S", style="%" + ) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(level) + if level <= logging.DEBUG: + HTTPConnection.debuglevel = 1 + + +def build_client(parsed_args: argparse.Namespace, logger: logging.Logger) -> Client: + config = Config(verify_ssl=not popattr(parsed_args, "insecure")) + + url = popattr(parsed_args, "server_host") + if server_port := popattr(parsed_args, "server_port"): + url += f":{server_port}" + + client = Client( + url=url, + logger=logger, + config=config, + check_server_version=False, # version is checked after auth to support versions < 2.3 + ) + + client.login(popattr(parsed_args, "auth")) + client.check_server_version(fail_if_unsupported=False) + + client.organization_slug = popattr(parsed_args, "organization") + + return client diff --git a/cvat-cli/src/cvat_cli/_internal/parsers.py b/cvat-cli/src/cvat_cli/_internal/parsers.py new file mode 100644 index 000000000000..a66710a09f47 --- /dev/null +++ b/cvat-cli/src/cvat_cli/_internal/parsers.py @@ -0,0 +1,62 @@ +# Copyright (C) 2021-2022 Intel Corporation +# Copyright (C) 2022-2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import argparse +import json +import os.path +from typing import Any + +from attr.converters import to_bool +from cvat_sdk.core.proxies.tasks import ResourceType + + +def parse_resource_type(s: str) -> ResourceType: + try: + return ResourceType[s.upper()] + except KeyError: + return s + + +def parse_label_arg(s): + """If s is a file load it as JSON, otherwise parse s as JSON.""" + if os.path.exists(s): + with open(s, "r") as fp: + return json.load(fp) + else: + return json.loads(s) + + +def parse_function_parameter(s: str) -> tuple[str, Any]: + key, sep, type_and_value = s.partition("=") + + if not sep: + raise argparse.ArgumentTypeError("parameter value not specified") + + type_, sep, value = type_and_value.partition(":") + + if not sep: + raise argparse.ArgumentTypeError("parameter type not specified") + + if type_ == "int": + value = int(value) + elif type_ == "float": + value = float(value) + elif type_ == "str": + pass + elif type_ == "bool": + value = to_bool(value) + else: + raise argparse.ArgumentTypeError(f"unsupported parameter type {type_!r}") + + return (key, value) + + +class BuildDictAction(argparse.Action): + def __init__(self, option_strings, dest, default=None, **kwargs): + super().__init__(option_strings, dest, default=default or {}, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + key, value = values + getattr(namespace, self.dest)[key] = value diff --git a/cvat-cli/src/cvat_cli/_internal/utils.py b/cvat-cli/src/cvat_cli/_internal/utils.py new file mode 100644 index 000000000000..b541534790c4 --- /dev/null +++ b/cvat-cli/src/cvat_cli/_internal/utils.py @@ -0,0 +1,9 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + + +def popattr(obj, name): + value = getattr(obj, name) + delattr(obj, name) + return value diff --git a/cvat-cli/src/cvat_cli/cli.py b/cvat-cli/src/cvat_cli/cli.py deleted file mode 100644 index e7945b18bb2e..000000000000 --- a/cvat-cli/src/cvat_cli/cli.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright (C) 2022 CVAT.ai Corporation -# -# SPDX-License-Identifier: MIT - -from __future__ import annotations - -import importlib -import importlib.util -import json -from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Tuple - -import cvat_sdk.auto_annotation as cvataa -from cvat_sdk import Client, models -from cvat_sdk.core.helpers import DeferredTqdmProgressReporter -from cvat_sdk.core.proxies.tasks import ResourceType - - -class CLI: - def __init__(self, client: Client, credentials: Tuple[str, str]): - self.client = client - - self.client.login(credentials) - - self.client.check_server_version(fail_if_unsupported=False) - - def tasks_list(self, *, use_json_output: bool = False, **kwargs): - """List all tasks in either basic or JSON format.""" - results = self.client.tasks.list(return_json=use_json_output, **kwargs) - if use_json_output: - print(json.dumps(json.loads(results), indent=2)) - else: - for r in results: - print(r.id) - - def tasks_create( - self, - name: str, - labels: List[Dict[str, str]], - resources: Sequence[str], - *, - resource_type: ResourceType = ResourceType.LOCAL, - annotation_path: str = "", - annotation_format: str = "CVAT XML 1.1", - status_check_period: int = 2, - **kwargs, - ) -> None: - """ - Create a new task with the given name and labels JSON and add the files to it. - """ - - task_params = {} - data_params = {} - - for k, v in kwargs.items(): - if k in models.DataRequest.attribute_map or k == "frame_step": - data_params[k] = v - else: - task_params[k] = v - - task = self.client.tasks.create_from_data( - spec=models.TaskWriteRequest(name=name, labels=labels, **task_params), - resource_type=resource_type, - resources=resources, - data_params=data_params, - annotation_path=annotation_path, - annotation_format=annotation_format, - status_check_period=status_check_period, - pbar=DeferredTqdmProgressReporter(), - ) - print("Created task id", task.id) - - def tasks_delete(self, task_ids: Sequence[int]) -> None: - """Delete a list of tasks, ignoring those which don't exist.""" - self.client.tasks.remove_by_ids(task_ids=task_ids) - - def tasks_frames( - self, - task_id: int, - frame_ids: Sequence[int], - *, - outdir: str = "", - quality: str = "original", - ) -> None: - """ - Download the requested frame numbers for a task and save images as - task__frame_.jpg. - """ - self.client.tasks.retrieve(obj_id=task_id).download_frames( - frame_ids=frame_ids, - outdir=outdir, - quality=quality, - filename_pattern=f"task_{task_id}" + "_frame_{frame_id:06d}{frame_ext}", - ) - - def tasks_dump( - self, - task_id: int, - fileformat: str, - filename: str, - *, - status_check_period: int = 2, - include_images: bool = False, - ) -> None: - """ - Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0'). - """ - self.client.tasks.retrieve(obj_id=task_id).export_dataset( - format_name=fileformat, - filename=filename, - pbar=DeferredTqdmProgressReporter(), - status_check_period=status_check_period, - include_images=include_images, - ) - - def tasks_upload( - self, task_id: str, fileformat: str, filename: str, *, status_check_period: int = 2 - ) -> None: - """Upload annotations for a task in the specified format - (e.g. 'YOLO ZIP 1.0').""" - self.client.tasks.retrieve(obj_id=task_id).import_annotations( - format_name=fileformat, - filename=filename, - status_check_period=status_check_period, - pbar=DeferredTqdmProgressReporter(), - ) - - def tasks_export(self, task_id: str, filename: str, *, status_check_period: int = 2) -> None: - """Download a task backup""" - self.client.tasks.retrieve(obj_id=task_id).download_backup( - filename=filename, - status_check_period=status_check_period, - pbar=DeferredTqdmProgressReporter(), - ) - - def tasks_import(self, filename: str, *, status_check_period: int = 2) -> None: - """Import a task from a backup file""" - self.client.tasks.create_from_backup( - filename=filename, - status_check_period=status_check_period, - pbar=DeferredTqdmProgressReporter(), - ) - - def tasks_auto_annotate( - self, - task_id: int, - *, - function_module: Optional[str] = None, - function_file: Optional[Path] = None, - function_parameters: Dict[str, Any], - clear_existing: bool = False, - allow_unmatched_labels: bool = False, - ) -> None: - if function_module is not None: - function = importlib.import_module(function_module) - elif function_file is not None: - module_spec = importlib.util.spec_from_file_location("__cvat_function__", function_file) - function = importlib.util.module_from_spec(module_spec) - module_spec.loader.exec_module(function) - else: - assert False, "function identification arguments missing" - - if hasattr(function, "create"): - # this is actually a function factory - function = function.create(**function_parameters) - else: - if function_parameters: - raise TypeError("function takes no parameters") - - cvataa.annotate_task( - self.client, - task_id, - function, - pbar=DeferredTqdmProgressReporter(), - clear_existing=clear_existing, - allow_unmatched_labels=allow_unmatched_labels, - ) diff --git a/cvat-cli/src/cvat_cli/parser.py b/cvat-cli/src/cvat_cli/parser.py deleted file mode 100644 index d456b087cd65..000000000000 --- a/cvat-cli/src/cvat_cli/parser.py +++ /dev/null @@ -1,452 +0,0 @@ -# Copyright (C) 2021-2022 Intel Corporation -# Copyright (C) 2022 CVAT.ai Corporation -# -# SPDX-License-Identifier: MIT - -import argparse -import getpass -import json -import logging -import os -import textwrap -from pathlib import Path -from typing import Any, Tuple - -from attr.converters import to_bool -from cvat_sdk.core.proxies.tasks import ResourceType - -from .version import VERSION - - -def get_auth(s): - """Parse USER[:PASS] strings and prompt for password if none was - supplied.""" - user, _, password = s.partition(":") - password = password or os.environ.get("PASS") or getpass.getpass() - return user, password - - -def parse_label_arg(s): - """If s is a file load it as JSON, otherwise parse s as JSON.""" - if os.path.exists(s): - with open(s, "r") as fp: - return json.load(fp) - else: - return json.loads(s) - - -def parse_resource_type(s: str) -> ResourceType: - try: - return ResourceType[s.upper()] - except KeyError: - return s - - -def parse_function_parameter(s: str) -> Tuple[str, Any]: - key, sep, type_and_value = s.partition("=") - - if not sep: - raise argparse.ArgumentTypeError("parameter value not specified") - - type_, sep, value = type_and_value.partition(":") - - if not sep: - raise argparse.ArgumentTypeError("parameter type not specified") - - if type_ == "int": - value = int(value) - elif type_ == "float": - value = float(value) - elif type_ == "str": - pass - elif type_ == "bool": - value = to_bool(value) - else: - raise argparse.ArgumentTypeError(f"unsupported parameter type {type_!r}") - - return (key, value) - - -class BuildDictAction(argparse.Action): - def __init__(self, option_strings, dest, default=None, **kwargs): - super().__init__(option_strings, dest, default=default or {}, **kwargs) - - def __call__(self, parser, namespace, values, option_string=None): - key, value = values - getattr(namespace, self.dest)[key] = value - - -def make_cmdline_parser() -> argparse.ArgumentParser: - ####################################################################### - # Command line interface definition - ####################################################################### - parser = argparse.ArgumentParser( - description="Perform common operations related to CVAT tasks.\n\n" - ) - parser.add_argument("--version", action="version", version=VERSION) - parser.add_argument( - "--insecure", - action="store_true", - help="Allows to disable SSL certificate check", - ) - - task_subparser = parser.add_subparsers(dest="action") - - ####################################################################### - # Positional arguments - ####################################################################### - parser.add_argument( - "--auth", - type=get_auth, - metavar="USER:[PASS]", - default=getpass.getuser(), - help="""defaults to the current user and supports the PASS - environment variable or password prompt - (default user: %(default)s).""", - ) - parser.add_argument( - "--server-host", type=str, default="localhost", help="host (default: %(default)s)" - ) - parser.add_argument( - "--server-port", - type=int, - default=None, - help="port (default: 80 for http and 443 for https connections)", - ) - parser.add_argument( - "--organization", - "--org", - metavar="SLUG", - help="""short name (slug) of the organization - to use when listing or creating resources; - set to blank string to use the personal workspace - (default: list all accessible objects, create in personal workspace)""", - ) - parser.add_argument( - "--debug", - action="store_const", - dest="loglevel", - const=logging.DEBUG, - default=logging.INFO, - help="show debug output", - ) - - ####################################################################### - # Create - ####################################################################### - task_create_parser = task_subparser.add_parser( - "create", - description=textwrap.dedent( - """\ - Create a new CVAT task. To create a task, you need - to specify labels using the --labels argument or - attach the task to an existing project using the - --project_id argument. - """ - ), - formatter_class=argparse.RawTextHelpFormatter, - ) - task_create_parser.add_argument("name", type=str, help="name of the task") - task_create_parser.add_argument( - "resource_type", - default="local", - choices=list(ResourceType), - type=parse_resource_type, - help="type of files specified", - ) - task_create_parser.add_argument("resources", type=str, help="list of paths or URLs", nargs="+") - task_create_parser.add_argument( - "--annotation_path", default="", type=str, help="path to annotation file" - ) - task_create_parser.add_argument( - "--annotation_format", - default="CVAT 1.1", - type=str, - help="format of the annotation file being uploaded, e.g. CVAT 1.1", - ) - task_create_parser.add_argument( - "--bug_tracker", "--bug", default=None, type=str, help="bug tracker URL" - ) - task_create_parser.add_argument( - "--chunk_size", default=None, type=int, help="the number of frames per chunk" - ) - task_create_parser.add_argument( - "--completion_verification_period", - dest="status_check_period", - default=2, - type=float, - help=textwrap.dedent( - """\ - number of seconds to wait until checking - if data compression finished (necessary before uploading annotations) - """ - ), - ) - task_create_parser.add_argument( - "--copy_data", - default=False, - action="store_true", - help=textwrap.dedent( - """\ - set the option to copy the data, only used when resource type is - share (default: %(default)s) - """ - ), - ) - task_create_parser.add_argument( - "--frame_step", - default=None, - type=int, - help=textwrap.dedent( - """\ - set the frame step option in the advanced configuration - when uploading image series or videos (default: %(default)s) - """ - ), - ) - task_create_parser.add_argument( - "--image_quality", - default=70, - type=int, - help=textwrap.dedent( - """\ - set the image quality option in the advanced configuration - when creating tasks.(default: %(default)s) - """ - ), - ) - task_create_parser.add_argument( - "--labels", - default="[]", - type=parse_label_arg, - help="string or file containing JSON labels specification", - ) - task_create_parser.add_argument( - "--project_id", default=None, type=int, help="project ID if project exists" - ) - task_create_parser.add_argument( - "--overlap", - default=None, - type=int, - help="the number of intersected frames between different segments", - ) - task_create_parser.add_argument( - "--segment_size", default=None, type=int, help="the number of frames in a segment" - ) - task_create_parser.add_argument( - "--sorting-method", - default="lexicographical", - choices=["lexicographical", "natural", "predefined", "random"], - help="""data soring method (default: %(default)s)""", - ) - task_create_parser.add_argument( - "--start_frame", default=None, type=int, help="the start frame of the video" - ) - task_create_parser.add_argument( - "--stop_frame", default=None, type=int, help="the stop frame of the video" - ) - task_create_parser.add_argument( - "--use_cache", action="store_true", help="""use cache""" # automatically sets default=False - ) - task_create_parser.add_argument( - "--use_zip_chunks", - action="store_true", # automatically sets default=False - help="""zip chunks before sending them to the server""", - ) - task_create_parser.add_argument( - "--cloud_storage_id", - default=None, - type=int, - help="cloud storage ID if you would like to use data from cloud storage", - ) - task_create_parser.add_argument( - "--filename_pattern", - type=str, - help=textwrap.dedent( - """\ - pattern for filtering data from the manifest file for the upload. - Only shell-style wildcards are supported: - * - matches everything - ? - matches any single character - [seq] - matches any character in 'seq' - [!seq] - matches any character not in seq - """ - ), - ) - - ####################################################################### - # Delete - ####################################################################### - delete_parser = task_subparser.add_parser("delete", description="Delete a CVAT task.") - delete_parser.add_argument("task_ids", type=int, help="list of task IDs", nargs="+") - - ####################################################################### - # List - ####################################################################### - ls_parser = task_subparser.add_parser( - "ls", description="List all CVAT tasks in simple or JSON format." - ) - ls_parser.add_argument( - "--json", - dest="use_json_output", - default=False, - action="store_true", - help="output JSON data", - ) - - ####################################################################### - # Frames - ####################################################################### - frames_parser = task_subparser.add_parser( - "frames", description="Download all frame images for a CVAT task." - ) - frames_parser.add_argument("task_id", type=int, help="task ID") - frames_parser.add_argument( - "frame_ids", type=int, help="list of frame IDs to download", nargs="+" - ) - frames_parser.add_argument( - "--outdir", type=str, default="", help="directory to save images (default: CWD)" - ) - frames_parser.add_argument( - "--quality", - type=str, - choices=("original", "compressed"), - default="original", - help="choose quality of images (default: %(default)s)", - ) - - ####################################################################### - # Dump - ####################################################################### - dump_parser = task_subparser.add_parser( - "dump", description="Download annotations for a CVAT task." - ) - dump_parser.add_argument("task_id", type=int, help="task ID") - dump_parser.add_argument("filename", type=str, help="output file") - dump_parser.add_argument( - "--format", - dest="fileformat", - type=str, - default="CVAT for images 1.1", - help="annotation format (default: %(default)s)", - ) - dump_parser.add_argument( - "--completion_verification_period", - dest="status_check_period", - default=2, - type=float, - help="number of seconds to wait until checking if dataset building finished", - ) - dump_parser.add_argument( - "--with-images", - type=to_bool, - default=False, - dest="include_images", - help="Whether to include images or not (default: %(default)s)", - ) - - ####################################################################### - # Upload Annotations - ####################################################################### - upload_parser = task_subparser.add_parser( - "upload", description="Upload annotations for a CVAT task." - ) - upload_parser.add_argument("task_id", type=int, help="task ID") - upload_parser.add_argument("filename", type=str, help="upload file") - upload_parser.add_argument( - "--format", - dest="fileformat", - type=str, - default="CVAT 1.1", - help="annotation format (default: %(default)s)", - ) - - ####################################################################### - # Export task - ####################################################################### - export_task_parser = task_subparser.add_parser("export", description="Export a CVAT task.") - export_task_parser.add_argument("task_id", type=int, help="task ID") - export_task_parser.add_argument("filename", type=str, help="output file") - export_task_parser.add_argument( - "--completion_verification_period", - dest="status_check_period", - default=2, - type=float, - help="time interval between checks if archive building has been finished, in seconds", - ) - - ####################################################################### - # Import task - ####################################################################### - import_task_parser = task_subparser.add_parser("import", description="Import a CVAT task.") - import_task_parser.add_argument("filename", type=str, help="upload file") - import_task_parser.add_argument( - "--completion_verification_period", - dest="status_check_period", - default=2, - type=float, - help="time interval between checks if archive processing was finished, in seconds", - ) - - ####################################################################### - # Auto-annotate - ####################################################################### - auto_annotate_task_parser = task_subparser.add_parser( - "auto-annotate", - description="Automatically annotate a CVAT task by running a function on the local machine.", - ) - auto_annotate_task_parser.add_argument("task_id", type=int, help="task ID") - - function_group = auto_annotate_task_parser.add_mutually_exclusive_group(required=True) - - function_group.add_argument( - "--function-module", - metavar="MODULE", - help="qualified name of a module to use as the function", - ) - - function_group.add_argument( - "--function-file", - metavar="PATH", - type=Path, - help="path to a Python source file to use as the function", - ) - - auto_annotate_task_parser.add_argument( - "--function-parameter", - "-p", - metavar="NAME=TYPE:VALUE", - type=parse_function_parameter, - action=BuildDictAction, - dest="function_parameters", - help="parameter for the function", - ) - - auto_annotate_task_parser.add_argument( - "--clear-existing", action="store_true", help="Remove existing annotations from the task" - ) - - auto_annotate_task_parser.add_argument( - "--allow-unmatched-labels", - action="store_true", - help="Allow the function to declare labels not configured in the task", - ) - - return parser - - -def get_action_args( - parser: argparse.ArgumentParser, parsed_args: argparse.Namespace -) -> argparse.Namespace: - # FIXME: a hacky way to remove unnecessary args - action_args = dict(vars(parsed_args)) - - for action in parser._actions: - action_args.pop(action.dest, None) - - # remove default args - for k, v in dict(action_args).items(): - if v is None: - action_args.pop(k, None) - - return argparse.Namespace(**action_args) diff --git a/cvat-cli/src/cvat_cli/version.py b/cvat-cli/src/cvat_cli/version.py index 3899bcb7cd96..b2829a54b105 100644 --- a/cvat-cli/src/cvat_cli/version.py +++ b/cvat-cli/src/cvat_cli/version.py @@ -1 +1 @@ -VERSION = "2.21.3" +VERSION = "2.22.0" diff --git a/cvat-core/src/annotations-actions.ts b/cvat-core/src/annotations-actions.ts index 9e956421ae08..43d3ef29a910 100644 --- a/cvat-core/src/annotations-actions.ts +++ b/cvat-core/src/annotations-actions.ts @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -import { omit, throttle } from 'lodash'; +import { omit, range, throttle } from 'lodash'; import { ArgumentError } from './exceptions'; import { SerializedCollection, SerializedShape } from './server-response-types'; import { Job, Task } from './session'; @@ -107,13 +107,15 @@ class PropagateShapes extends BaseSingleFrameAction { } public async run( - instance, + instance: Job | Task, { collection: { shapes }, frameData: { number } }, ): Promise { if (number === this.#targetFrame) { return { collection: { shapes } }; } - const propagatedShapes = propagateShapes(shapes, number, this.#targetFrame); + + const frameNumbers = instance instanceof Job ? await instance.frames.frameNumbers() : range(0, instance.size); + const propagatedShapes = propagateShapes(shapes, number, this.#targetFrame, frameNumbers); return { collection: { shapes: [...shapes, ...propagatedShapes] } }; } diff --git a/cvat-core/src/frames.ts b/cvat-core/src/frames.ts index b29335865d01..e02b3e640a83 100644 --- a/cvat-core/src/frames.ts +++ b/cvat-core/src/frames.ts @@ -211,6 +211,11 @@ export class FramesMetaData { return Math.floor(this.getFrameIndex(dataFrameNumber) / this.chunkSize); } + getSegmentFrameNumbers(jobStartFrame: number): number[] { + const frames = this.getDataFrameNumbers(); + return frames.map((frame) => this.getJobRelativeFrameNumber(frame) + jobStartFrame); + } + getDataFrameNumbers(): number[] { if (this.includedFrames) { return [...this.includedFrames]; @@ -348,9 +353,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { const requestId = +_.uniqueId(); const requestedDataFrameNumber = meta.getDataFrameNumber(this.number - jobStartFrame); const chunkIndex = meta.getFrameChunkIndex(requestedDataFrameNumber); - const segmentFrameNumbers = meta.getDataFrameNumbers().map((dataFrameNumber: number) => ( - meta.getJobRelativeFrameNumber(dataFrameNumber) + jobStartFrame - )); + const segmentFrameNumbers = meta.getSegmentFrameNumbers(jobStartFrame); const frame = provider.frame(this.number); function findTheNextNotDecodedChunk(currentFrameIndex: number): number | null { @@ -889,9 +892,7 @@ export function getJobFrameNumbers(jobID: number): number[] { } const { meta, jobStartFrame } = frameDataCache[jobID]; - return meta.getDataFrameNumbers().map((dataFrameNumber: number): number => ( - meta.getJobRelativeFrameNumber(dataFrameNumber) + jobStartFrame - )); + return meta.getSegmentFrameNumbers(jobStartFrame); } export function clear(jobID: number): void { diff --git a/cvat-core/src/object-utils.ts b/cvat-core/src/object-utils.ts index 6e7fcbbd8d8c..b3592e5cbe1c 100644 --- a/cvat-core/src/object-utils.ts +++ b/cvat-core/src/object-utils.ts @@ -360,7 +360,7 @@ export function rle2Mask(rle: number[], width: number, height: number): number[] } export function propagateShapes( - shapes: T[], from: number, to: number, + shapes: T[], from: number, to: number, frameNumbers: number[], ): T[] { const getCopy = (shape: T): SerializedShape | SerializedData => { if (shape instanceof ObjectState) { @@ -397,9 +397,18 @@ export function propagateShapes( }; }; + const targetFrameNumbers = frameNumbers.filter( + (frameNumber: number) => frameNumber >= Math.min(from, to) && + frameNumber <= Math.max(from, to) && + frameNumber !== from, + ); + const states: T[] = []; - const sign = Math.sign(to - from); - for (let frame = from + sign; sign > 0 ? frame <= to : frame >= to; frame += sign) { + for (const frame of targetFrameNumbers) { + if (frame === from) { + continue; + } + for (const shape of shapes) { const copy = getCopy(shape); diff --git a/cvat-core/src/quality-settings.ts b/cvat-core/src/quality-settings.ts index c5e3ea6974c2..7c591e371cc4 100644 --- a/cvat-core/src/quality-settings.ts +++ b/cvat-core/src/quality-settings.ts @@ -14,6 +14,11 @@ export enum TargetMetric { RECALL = 'recall', } +export enum PointSizeBase { + IMAGE_SIZE = 'image_size', + GROUP_BBOX_SIZE = 'group_bbox_size', +} + export default class QualitySettings { #id: number; #targetMetric: TargetMetric; @@ -22,6 +27,7 @@ export default class QualitySettings { #task: number; #iouThreshold: number; #oksSigma: number; + #pointSizeBase: PointSizeBase; #lineThickness: number; #lowOverlapThreshold: number; #orientedLines: boolean; @@ -32,6 +38,7 @@ export default class QualitySettings { #objectVisibilityThreshold: number; #panopticComparison: boolean; #compareAttributes: boolean; + #matchEmptyFrames: boolean; #descriptions: Record; constructor(initialData: SerializedQualitySettingsData) { @@ -42,6 +49,7 @@ export default class QualitySettings { this.#maxValidationsPerJob = initialData.max_validations_per_job; this.#iouThreshold = initialData.iou_threshold; this.#oksSigma = initialData.oks_sigma; + this.#pointSizeBase = initialData.point_size_base as PointSizeBase; this.#lineThickness = initialData.line_thickness; this.#lowOverlapThreshold = initialData.low_overlap_threshold; this.#orientedLines = initialData.compare_line_orientation; @@ -52,6 +60,7 @@ export default class QualitySettings { this.#objectVisibilityThreshold = initialData.object_visibility_threshold; this.#panopticComparison = initialData.panoptic_comparison; this.#compareAttributes = initialData.compare_attributes; + this.#matchEmptyFrames = initialData.match_empty_frames; this.#descriptions = initialData.descriptions; } @@ -79,6 +88,14 @@ export default class QualitySettings { this.#oksSigma = newVal; } + get pointSizeBase(): PointSizeBase { + return this.#pointSizeBase; + } + + set pointSizeBase(newVal: PointSizeBase) { + this.#pointSizeBase = newVal; + } + get lineThickness(): number { return this.#lineThickness; } @@ -183,6 +200,14 @@ export default class QualitySettings { this.#maxValidationsPerJob = newVal; } + get matchEmptyFrames(): boolean { + return this.#matchEmptyFrames; + } + + set matchEmptyFrames(newVal: boolean) { + this.#matchEmptyFrames = newVal; + } + get descriptions(): Record { const descriptions: Record = Object.keys(this.#descriptions).reduce((acc, key) => { const camelCaseKey = _.camelCase(key); @@ -197,6 +222,7 @@ export default class QualitySettings { const result: SerializedQualitySettingsData = { iou_threshold: this.#iouThreshold, oks_sigma: this.#oksSigma, + point_size_base: this.#pointSizeBase, line_thickness: this.#lineThickness, low_overlap_threshold: this.#lowOverlapThreshold, compare_line_orientation: this.#orientedLines, @@ -210,6 +236,7 @@ export default class QualitySettings { target_metric: this.#targetMetric, target_metric_threshold: this.#targetMetricThreshold, max_validations_per_job: this.#maxValidationsPerJob, + match_empty_frames: this.#matchEmptyFrames, }; return result; diff --git a/cvat-core/src/request.ts b/cvat-core/src/request.ts index 1935f78b2f0c..ad8aa04d45aa 100644 --- a/cvat-core/src/request.ts +++ b/cvat-core/src/request.ts @@ -55,6 +55,7 @@ export class Request { return this.#status.toLowerCase() as RQStatus; } + // The `progress` represents a value between 0 and 1 get progress(): number | undefined { return this.#progress; } diff --git a/cvat-core/src/requests-manager.ts b/cvat-core/src/requests-manager.ts index 711073988955..800b577242c8 100644 --- a/cvat-core/src/requests-manager.ts +++ b/cvat-core/src/requests-manager.ts @@ -34,7 +34,7 @@ class RequestsManager { requestDelayIdx: number | null, request: Request | null, timeout: number | null; - promise?: Promise; + promise: Promise; }>; private requestStack: number[]; @@ -71,6 +71,7 @@ class RequestsManager { } return this.listening[requestID].promise; } + const promise = new Promise((resolve, reject) => { const timeoutCallback = async (): Promise => { // We make sure that no more than REQUESTS_COUNT requests are sent simultaneously @@ -131,27 +132,29 @@ class RequestsManager { message: `Could not get a status of the request ${requestID}. ${error.toString()}`, }))); } + + delete this.listening[requestID]; reject(error); } } }; - if (initialRequest?.status === RQStatus.FAILED) { - reject(new RequestError(initialRequest?.message)); - } else { - this.listening[requestID] = { - onUpdate: callback ? [callback] : [], - timeout: window.setTimeout(timeoutCallback), - request: initialRequest, - requestDelayIdx: 0, - }; - } + Promise.resolve().then(() => { + // running as microtask to make sure "promise" was initialized + if (initialRequest?.status === RQStatus.FAILED) { + reject(new RequestError(initialRequest?.message)); + } else { + this.listening[requestID] = { + onUpdate: callback ? [callback] : [], + timeout: window.setTimeout(timeoutCallback), + request: initialRequest, + requestDelayIdx: 0, + promise, + }; + } + }); }); - this.listening[requestID] = { - ...this.listening[requestID], - promise, - }; return promise; } diff --git a/cvat-core/src/server-response-types.ts b/cvat-core/src/server-response-types.ts index 4bf7a482bccb..ea97c0730aaa 100644 --- a/cvat-core/src/server-response-types.ts +++ b/cvat-core/src/server-response-types.ts @@ -47,7 +47,7 @@ export interface SerializedUser { first_name: string; last_name: string; email?: string; - groups?: ('user' | 'business' | 'admin')[]; + groups?: ('user' | 'admin')[]; is_staff?: boolean; is_superuser?: boolean; is_active?: boolean; @@ -247,6 +247,7 @@ export interface SerializedQualitySettingsData { max_validations_per_job?: number; iou_threshold?: number; oks_sigma?: number; + point_size_base?: string; line_thickness?: number; low_overlap_threshold?: number; compare_line_orientation?: boolean; @@ -257,6 +258,7 @@ export interface SerializedQualitySettingsData { object_visibility_threshold?: number; panoptic_comparison?: boolean; compare_attributes?: boolean; + match_empty_frames?: boolean; descriptions?: Record; } diff --git a/cvat-core/src/session-implementation.ts b/cvat-core/src/session-implementation.ts index 1c2194250155..904899831abf 100644 --- a/cvat-core/src/session-implementation.ts +++ b/cvat-core/src/session-implementation.ts @@ -879,6 +879,14 @@ export function implementTask(Task: typeof TaskClass): typeof TaskClass { }, }); + Object.defineProperty(Task.prototype.frames.frameNumbers, 'implementation', { + value: function includedFramesImplementation( + this: TaskClass, + ): ReturnType { + throw new Error('Not implemented for Task'); + }, + }); + Object.defineProperty(Task.prototype.frames.preview, 'implementation', { value: function previewImplementation( this: TaskClass, diff --git a/cvat-core/src/session.ts b/cvat-core/src/session.ts index 1164ae0c07de..a2bc2008aef0 100644 --- a/cvat-core/src/session.ts +++ b/cvat-core/src/session.ts @@ -373,8 +373,8 @@ export class Session { }; public actions: { - undo: (count: number) => Promise; - redo: (count: number) => Promise; + undo: (count?: number) => Promise; + redo: (count?: number) => Promise; freeze: (frozen: boolean) => Promise; clear: () => Promise; get: () => Promise<{ undo: [HistoryActions, number][], redo: [HistoryActions, number][] }>; @@ -403,8 +403,8 @@ export class Session { public logger: { log: ( scope: Parameters[0], - payload: Parameters[1], - wait: Parameters[2], + payload?: Parameters[1], + wait?: Parameters[2], ) => ReturnType; }; @@ -463,7 +463,7 @@ export class Session { } } -type InitializerType = Readonly & { labels?: SerializedLabel[] }>; +type InitializerType = Readonly & { labels?: SerializedLabel[] }>>; export class Job extends Session { #data: { diff --git a/cvat-core/src/user.ts b/cvat-core/src/user.ts index 6d7366151fb4..ef28f3633f0e 100644 --- a/cvat-core/src/user.ts +++ b/cvat-core/src/user.ts @@ -11,7 +11,7 @@ export default class User { public readonly email: string; public readonly firstName: string; public readonly lastName: string; - public readonly groups: ('user' | 'business' | 'admin')[]; + public readonly groups: ('user' | 'admin')[]; public readonly lastLogin: string; public readonly dateJoined: string; public readonly isStaff: boolean; diff --git a/cvat-sdk/cvat_sdk/auto_annotation/driver.py b/cvat-sdk/cvat_sdk/auto_annotation/driver.py index d6294f44f8f6..0f3d82ea32ea 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/driver.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/driver.py @@ -3,7 +3,8 @@ # SPDX-License-Identifier: MIT import logging -from typing import List, Mapping, Optional, Sequence +from collections.abc import Mapping, Sequence +from typing import Optional import attrs @@ -119,7 +120,7 @@ def __init__( fun_label, ds_labels_by_name ) - def validate_and_remap(self, shapes: List[models.LabeledShapeRequest], ds_frame: int) -> None: + def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: int) -> None: new_shapes = [] for shape in shapes: diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py index 57457d742256..d257cb7ec889 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: MIT from functools import cached_property -from typing import List import PIL.Image import torchvision.models @@ -28,7 +27,7 @@ def spec(self) -> cvataa.DetectionFunctionSpec: ] ) - def detect(self, context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]: + def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]: results = self._model([self._transforms(image)]) return [ diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py index b4eb47d476d3..c7199b67738b 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: MIT from functools import cached_property -from typing import List import PIL.Image import torchvision.models @@ -36,7 +35,7 @@ def spec(self) -> cvataa.DetectionFunctionSpec: ] ) - def detect(self, context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]: + def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]: results = self._model([self._transforms(image)]) return [ diff --git a/cvat-sdk/cvat_sdk/auto_annotation/interface.py b/cvat-sdk/cvat_sdk/auto_annotation/interface.py index 67313a7da6e5..20a21fe4a5cf 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/interface.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/interface.py @@ -3,7 +3,8 @@ # SPDX-License-Identifier: MIT import abc -from typing import List, Protocol, Sequence +from collections.abc import Sequence +from typing import Protocol import attrs import PIL.Image @@ -79,7 +80,7 @@ def spec(self) -> DetectionFunctionSpec: def detect( self, context: DetectionFunctionContext, image: PIL.Image.Image - ) -> List[models.LabeledShapeRequest]: + ) -> list[models.LabeledShapeRequest]: """ Detects objects on the supplied image and returns the results. diff --git a/cvat-sdk/cvat_sdk/core/client.py b/cvat-sdk/cvat_sdk/core/client.py index 0ae0b88ecad9..168259920c0f 100644 --- a/cvat-sdk/cvat_sdk/core/client.py +++ b/cvat-sdk/cvat_sdk/core/client.py @@ -7,10 +7,11 @@ import logging import urllib.parse +from collections.abc import Generator, Sequence from contextlib import contextmanager, suppress from pathlib import Path from time import sleep -from typing import Any, Dict, Generator, Optional, Sequence, Tuple, TypeVar +from typing import Any, Optional, TypeVar import attrs import packaging.specifiers as specifiers @@ -95,7 +96,7 @@ def __init__( if check_server_version: self.check_server_version() - self._repos: Dict[str, Repo] = {} + self._repos: dict[str, Repo] = {} """A cache for created Repository instances""" _ORG_SLUG_HEADER = "X-Organization" @@ -183,7 +184,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: def close(self) -> None: return self.__exit__(None, None, None) - def login(self, credentials: Tuple[str, str]) -> None: + def login(self, credentials: tuple[str, str]) -> None: (auth, _) = self.api_client.auth_api.create_login( models.LoginSerializerExRequest(username=credentials[0], password=credentials[1]) ) @@ -211,7 +212,7 @@ def wait_for_completion( rq_id: str, *, status_check_period: Optional[int] = None, - ) -> Tuple[models.Request, urllib3.HTTPResponse]: + ) -> tuple[models.Request, urllib3.HTTPResponse]: if status_check_period is None: status_check_period = self.config.status_check_period @@ -319,8 +320,8 @@ def make_endpoint_url( path: str, *, psub: Optional[Sequence[Any]] = None, - kwsub: Optional[Dict[str, Any]] = None, - query_params: Optional[Dict[str, Any]] = None, + kwsub: Optional[dict[str, Any]] = None, + query_params: Optional[dict[str, Any]] = None, ) -> str: url = self.host + path if psub or kwsub: @@ -331,7 +332,7 @@ def make_endpoint_url( def make_client( - host: str, *, port: Optional[int] = None, credentials: Optional[Tuple[str, str]] = None + host: str, *, port: Optional[int] = None, credentials: Optional[tuple[str, str]] = None ) -> Client: url = host.rstrip("/") if port: diff --git a/cvat-sdk/cvat_sdk/core/downloading.py b/cvat-sdk/cvat_sdk/core/downloading.py index 2e8263373350..d44535b2fc82 100644 --- a/cvat-sdk/cvat_sdk/core/downloading.py +++ b/cvat-sdk/cvat_sdk/core/downloading.py @@ -8,7 +8,7 @@ import json from contextlib import closing from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Optional from cvat_sdk.api_client.api_client import Endpoint from cvat_sdk.core.helpers import expect_status @@ -80,8 +80,8 @@ def prepare_file( self, endpoint: Endpoint, *, - url_params: Optional[Dict[str, Any]] = None, - query_params: Optional[Dict[str, Any]] = None, + url_params: Optional[dict[str, Any]] = None, + query_params: Optional[dict[str, Any]] = None, status_check_period: Optional[int] = None, ): client = self._client @@ -118,8 +118,8 @@ def prepare_and_download_file_from_endpoint( endpoint: Endpoint, filename: Path, *, - url_params: Optional[Dict[str, Any]] = None, - query_params: Optional[Dict[str, Any]] = None, + url_params: Optional[dict[str, Any]] = None, + query_params: Optional[dict[str, Any]] = None, pbar: Optional[ProgressReporter] = None, status_check_period: Optional[int] = None, ): diff --git a/cvat-sdk/cvat_sdk/core/helpers.py b/cvat-sdk/cvat_sdk/core/helpers.py index b04e33e4c687..425fbc78a083 100644 --- a/cvat-sdk/cvat_sdk/core/helpers.py +++ b/cvat-sdk/cvat_sdk/core/helpers.py @@ -7,7 +7,8 @@ import io import json import warnings -from typing import Any, Dict, Iterable, List, Optional, Union +from collections.abc import Iterable +from typing import Any, Optional, Union import tqdm import urllib3 @@ -19,7 +20,7 @@ def get_paginated_collection( endpoint: Endpoint, *, return_json: bool = False, **kwargs -) -> Union[List, List[Dict[str, Any]]]: +) -> Union[list, list[dict[str, Any]]]: """ Accumulates results from all the pages """ diff --git a/cvat-sdk/cvat_sdk/core/progress.py b/cvat-sdk/cvat_sdk/core/progress.py index fd844de722a0..33c7e420714e 100644 --- a/cvat-sdk/cvat_sdk/core/progress.py +++ b/cvat-sdk/cvat_sdk/core/progress.py @@ -6,7 +6,8 @@ from __future__ import annotations import contextlib -from typing import Generator, Iterable, Optional, TypeVar +from collections.abc import Generator, Iterable +from typing import Optional, TypeVar T = TypeVar("T") diff --git a/cvat-sdk/cvat_sdk/core/proxies/annotations.py b/cvat-sdk/cvat_sdk/core/proxies/annotations.py index e9353888119f..53db2af34712 100644 --- a/cvat-sdk/cvat_sdk/core/proxies/annotations.py +++ b/cvat-sdk/cvat_sdk/core/proxies/annotations.py @@ -3,8 +3,9 @@ # SPDX-License-Identifier: MIT from abc import ABC +from collections.abc import Sequence from enum import Enum -from typing import Optional, Sequence +from typing import Optional from cvat_sdk import models from cvat_sdk.core.proxies.model_proxy import _EntityT diff --git a/cvat-sdk/cvat_sdk/core/proxies/issues.py b/cvat-sdk/cvat_sdk/core/proxies/issues.py index 5df1069c1178..8f844d68522a 100644 --- a/cvat-sdk/cvat_sdk/core/proxies/issues.py +++ b/cvat-sdk/cvat_sdk/core/proxies/issues.py @@ -4,8 +4,6 @@ from __future__ import annotations -from typing import List - from cvat_sdk.api_client import apis, models from cvat_sdk.core.helpers import get_paginated_collection from cvat_sdk.core.proxies.model_proxy import ( @@ -53,7 +51,7 @@ class Issue( ): _model_partial_update_arg = "patched_issue_write_request" - def get_comments(self) -> List[Comment]: + def get_comments(self) -> list[Comment]: return [ Comment(self._client, m) for m in get_paginated_collection( diff --git a/cvat-sdk/cvat_sdk/core/proxies/jobs.py b/cvat-sdk/cvat_sdk/core/proxies/jobs.py index ac81380b7566..5eb3e6767477 100644 --- a/cvat-sdk/cvat_sdk/core/proxies/jobs.py +++ b/cvat-sdk/cvat_sdk/core/proxies/jobs.py @@ -6,8 +6,9 @@ import io import mimetypes +from collections.abc import Sequence from pathlib import Path -from typing import TYPE_CHECKING, List, Optional, Sequence +from typing import TYPE_CHECKING, Optional from PIL import Image @@ -93,7 +94,7 @@ def download_frames( outdir: StrPath = ".", quality: str = "original", filename_pattern: str = "frame_{frame_id:06d}{frame_ext}", - ) -> Optional[List[Image.Image]]: + ) -> Optional[list[Image.Image]]: """ Download the requested frame numbers for a job and save images as outdir/filename_pattern """ @@ -125,12 +126,12 @@ def get_meta(self) -> models.IDataMetaRead: (meta, _) = self.api.retrieve_data_meta(self.id) return meta - def get_labels(self) -> List[models.ILabel]: + def get_labels(self) -> list[models.ILabel]: return get_paginated_collection( self._client.api_client.labels_api.list_endpoint, job_id=self.id ) - def get_frames_info(self) -> List[models.IFrameMeta]: + def get_frames_info(self) -> list[models.IFrameMeta]: return self.get_meta().frames def remove_frames_by_ids(self, ids: Sequence[int]) -> None: @@ -141,7 +142,7 @@ def remove_frames_by_ids(self, ids: Sequence[int]) -> None: ), ) - def get_issues(self) -> List[Issue]: + def get_issues(self) -> list[Issue]: return [ Issue(self._client, m) for m in get_paginated_collection( diff --git a/cvat-sdk/cvat_sdk/core/proxies/model_proxy.py b/cvat-sdk/cvat_sdk/core/proxies/model_proxy.py index 40b6ffd27549..1557a61861ae 100644 --- a/cvat-sdk/cvat_sdk/core/proxies/model_proxy.py +++ b/cvat-sdk/cvat_sdk/core/proxies/model_proxy.py @@ -12,13 +12,9 @@ TYPE_CHECKING, Any, Callable, - Dict, Generic, - List, Literal, Optional, - Tuple, - Type, TypeVar, Union, overload, @@ -96,15 +92,15 @@ class Repo(ModelProxy[ModelType, ApiType]): Implements group and management operations for entities. """ - _entity_type: Type[Entity[ModelType, ApiType]] + _entity_type: type[Entity[ModelType, ApiType]] ### Utilities def build_model_bases( - mt: Type[ModelType], at: Type[ApiType], *, api_member_name: Optional[str] = None -) -> Tuple[Type[Entity[ModelType, ApiType]], Type[Repo[ModelType, ApiType]]]: + mt: type[ModelType], at: type[ApiType], *, api_member_name: Optional[str] = None +) -> tuple[type[Entity[ModelType, ApiType]], type[Repo[ModelType, ApiType]]]: """ Helps to remove code duplication in declarations of derived classes """ @@ -128,7 +124,7 @@ class _RepoBase(Repo[ModelType, ApiType]): class ModelCreateMixin(Generic[_EntityT, IModel]): - def create(self: Repo, spec: Union[Dict[str, Any], IModel]) -> _EntityT: + def create(self: Repo, spec: Union[dict[str, Any], IModel]) -> _EntityT: """ Creates a new object on the server and returns the corresponding local object """ @@ -149,12 +145,12 @@ def retrieve(self: Repo, obj_id: int) -> _EntityT: class ModelListMixin(Generic[_EntityT]): @overload - def list(self: Repo, *, return_json: Literal[False] = False) -> List[_EntityT]: ... + def list(self: Repo, *, return_json: Literal[False] = False) -> list[_EntityT]: ... @overload - def list(self: Repo, *, return_json: Literal[True] = False) -> List[Any]: ... + def list(self: Repo, *, return_json: Literal[True] = False) -> list[Any]: ... - def list(self: Repo, *, return_json: bool = False) -> List[Union[_EntityT, Any]]: + def list(self: Repo, *, return_json: bool = False) -> list[Union[_EntityT, Any]]: """ Retrieves all objects from the server and returns them in basic or JSON format. """ @@ -174,8 +170,8 @@ class ModelUpdateMixin(ABC, Generic[IModel]): def _model_partial_update_arg(self: Entity) -> str: ... def _export_update_fields( - self: Entity, overrides: Optional[Union[Dict[str, Any], IModel]] = None - ) -> Dict[str, Any]: + self: Entity, overrides: Optional[Union[dict[str, Any], IModel]] = None + ) -> dict[str, Any]: # TODO: support field conversion and assignment updating # fields = to_json(self._model) @@ -194,7 +190,7 @@ def fetch(self: Entity) -> Self: (self._model, _) = self.api.retrieve(id=getattr(self, self._model_id_field)) return self - def update(self: Entity, values: Union[Dict[str, Any], IModel]) -> Self: + def update(self: Entity, values: Union[dict[str, Any], IModel]) -> Self: """ Commits model changes to the server diff --git a/cvat-sdk/cvat_sdk/core/proxies/projects.py b/cvat-sdk/cvat_sdk/core/proxies/projects.py index 70c647bfd033..3e0eddc00e36 100644 --- a/cvat-sdk/cvat_sdk/core/proxies/projects.py +++ b/cvat-sdk/cvat_sdk/core/proxies/projects.py @@ -7,7 +7,7 @@ import io import json from pathlib import Path -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from cvat_sdk.api_client import apis, models from cvat_sdk.core.helpers import get_paginated_collection @@ -72,7 +72,7 @@ def get_annotations(self) -> models.ILabeledData: (annotations, _) = self.api.retrieve_annotations(self.id) return annotations - def get_tasks(self) -> List[Task]: + def get_tasks(self) -> list[Task]: return [ Task(self._client, m) for m in get_paginated_collection( @@ -80,7 +80,7 @@ def get_tasks(self) -> List[Task]: ) ] - def get_labels(self) -> List[models.ILabel]: + def get_labels(self) -> list[models.ILabel]: return get_paginated_collection( self._client.api_client.labels_api.list_endpoint, project_id=self.id ) diff --git a/cvat-sdk/cvat_sdk/core/proxies/tasks.py b/cvat-sdk/cvat_sdk/core/proxies/tasks.py index fe2d80d857b0..e0db111f8511 100644 --- a/cvat-sdk/cvat_sdk/core/proxies/tasks.py +++ b/cvat-sdk/cvat_sdk/core/proxies/tasks.py @@ -8,10 +8,11 @@ import json import mimetypes import shutil +from collections.abc import Sequence from enum import Enum from pathlib import Path from time import sleep -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional from PIL import Image @@ -72,7 +73,7 @@ def upload_data( *, resource_type: ResourceType = ResourceType.LOCAL, pbar: Optional[ProgressReporter] = None, - params: Optional[Dict[str, Any]] = None, + params: Optional[dict[str, Any]] = None, wait_for_completion: bool = True, status_check_period: Optional[int] = None, ) -> None: @@ -226,7 +227,7 @@ def download_frames( outdir: StrPath = ".", quality: str = "original", filename_pattern: str = "frame_{frame_id:06d}{frame_ext}", - ) -> Optional[List[Image.Image]]: + ) -> Optional[list[Image.Image]]: """ Download the requested frame numbers for a task and save images as outdir/filename_pattern """ @@ -253,7 +254,7 @@ def download_frames( outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext) im.save(outdir / outfile) - def get_jobs(self) -> List[Job]: + def get_jobs(self) -> list[Job]: return [ Job(self._client, model=m) for m in get_paginated_collection( @@ -265,12 +266,12 @@ def get_meta(self) -> models.IDataMetaRead: (meta, _) = self.api.retrieve_data_meta(self.id) return meta - def get_labels(self) -> List[models.ILabel]: + def get_labels(self) -> list[models.ILabel]: return get_paginated_collection( self._client.api_client.labels_api.list_endpoint, task_id=self.id ) - def get_frames_info(self) -> List[models.IFrameMeta]: + def get_frames_info(self) -> list[models.IFrameMeta]: return self.get_meta().frames def remove_frames_by_ids(self, ids: Sequence[int]) -> None: @@ -295,7 +296,7 @@ def create_from_data( resources: Sequence[StrPath], *, resource_type: ResourceType = ResourceType.LOCAL, - data_params: Optional[Dict[str, Any]] = None, + data_params: Optional[dict[str, Any]] = None, annotation_path: str = "", annotation_format: str = "CVAT XML 1.1", status_check_period: int = None, diff --git a/cvat-sdk/cvat_sdk/core/uploading.py b/cvat-sdk/cvat_sdk/core/uploading.py index 0ccfd902da61..068e4d89a0c3 100644 --- a/cvat-sdk/cvat_sdk/core/uploading.py +++ b/cvat-sdk/cvat_sdk/core/uploading.py @@ -6,8 +6,9 @@ import json import os +from contextlib import AbstractContextManager from pathlib import Path -from typing import TYPE_CHECKING, Any, ContextManager, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Optional import requests import urllib3 @@ -147,9 +148,9 @@ def upload_file( url: str, filename: Path, *, - meta: Dict[str, Any], - query_params: Dict[str, Any] = None, - fields: Optional[Dict[str, Any]] = None, + meta: dict[str, Any], + query_params: dict[str, Any] = None, + fields: Optional[dict[str, Any]] = None, pbar: Optional[ProgressReporter] = None, logger=None, ) -> urllib3.HTTPResponse: @@ -194,7 +195,7 @@ def upload_file( return self._tus_finish_upload(url, query_params=query_params, fields=fields) @staticmethod - def _uploading_task(pbar: ProgressReporter, total_size: int) -> ContextManager[None]: + def _uploading_task(pbar: ProgressReporter, total_size: int) -> AbstractContextManager[None]: return pbar.task( total=total_size, desc="Uploading data", unit_scale=True, unit="B", unit_divisor=1024 ) @@ -256,7 +257,7 @@ def upload_file_and_wait( filename: Path, format_name: str, *, - url_params: Optional[Dict[str, Any]] = None, + url_params: Optional[dict[str, Any]] = None, pbar: Optional[ProgressReporter] = None, status_check_period: Optional[int] = None, ): @@ -279,7 +280,7 @@ def upload_file_and_wait( filename: Path, format_name: str, *, - url_params: Optional[Dict[str, Any]] = None, + url_params: Optional[dict[str, Any]] = None, pbar: Optional[ProgressReporter] = None, status_check_period: Optional[int] = None, ): @@ -302,7 +303,7 @@ def __init__(self, client: Client, *, max_request_size: int = MAX_REQUEST_SIZE): def upload_files( self, url: str, - resources: List[Path], + resources: list[Path], *, pbar: Optional[ProgressReporter] = None, **kwargs, @@ -351,10 +352,10 @@ def upload_files( return self._tus_finish_upload(url, fields=kwargs) def _split_files_by_requests( - self, filenames: List[Path] - ) -> Tuple[List[Tuple[List[Path], int]], List[Path], int]: - bulk_files: Dict[str, int] = {} - separate_files: Dict[str, int] = {} + self, filenames: list[Path] + ) -> tuple[list[tuple[list[Path], int]], list[Path], int]: + bulk_files: dict[str, int] = {} + separate_files: dict[str, int] = {} max_request_size = self.max_request_size # sort by size @@ -369,9 +370,9 @@ def _split_files_by_requests( total_size = sum(bulk_files.values()) + sum(separate_files.values()) # group small files by requests - bulk_file_groups: List[Tuple[List[str], int]] = [] + bulk_file_groups: list[tuple[list[str], int]] = [] current_group_size: int = 0 - current_group: List[str] = [] + current_group: list[str] = [] for filename, file_size in bulk_files.items(): if max_request_size < current_group_size + file_size: bulk_file_groups.append((current_group, current_group_size)) diff --git a/cvat-sdk/cvat_sdk/core/utils.py b/cvat-sdk/cvat_sdk/core/utils.py index 1ef434e3ad5b..efcc787d96de 100644 --- a/cvat-sdk/cvat_sdk/core/utils.py +++ b/cvat-sdk/cvat_sdk/core/utils.py @@ -7,37 +7,26 @@ import contextlib import itertools import os -from typing import ( - IO, - Any, - BinaryIO, - ContextManager, - Dict, - Generator, - Literal, - Sequence, - TextIO, - Union, - overload, -) +from collections.abc import Generator, Sequence +from typing import IO, Any, BinaryIO, Literal, TextIO, Union, overload def filter_dict( - d: Dict[str, Any], *, keep: Sequence[str] = None, drop: Sequence[str] = None -) -> Dict[str, Any]: + d: dict[str, Any], *, keep: Sequence[str] = None, drop: Sequence[str] = None +) -> dict[str, Any]: return {k: v for k, v in d.items() if (not keep or k in keep) and (not drop or k not in drop)} @overload def atomic_writer( path: Union[os.PathLike, str], mode: Literal["wb"] -) -> ContextManager[BinaryIO]: ... +) -> contextlib.AbstractContextManager[BinaryIO]: ... @overload def atomic_writer( path: Union[os.PathLike, str], mode: Literal["w"], encoding: str = "UTF-8" -) -> ContextManager[TextIO]: ... +) -> contextlib.AbstractContextManager[TextIO]: ... @contextlib.contextmanager diff --git a/cvat-sdk/cvat_sdk/datasets/caching.py b/cvat-sdk/cvat_sdk/datasets/caching.py index 08e0c123bfe1..f47cdfc3260f 100644 --- a/cvat-sdk/cvat_sdk/datasets/caching.py +++ b/cvat-sdk/cvat_sdk/datasets/caching.py @@ -6,9 +6,10 @@ import json import shutil from abc import ABCMeta, abstractmethod +from collections.abc import Mapping from enum import Enum, auto from pathlib import Path -from typing import Any, Callable, Dict, List, Mapping, Type, TypeVar, Union, cast +from typing import Any, Callable, TypeVar, Union, cast from attrs import define @@ -39,7 +40,7 @@ class UpdatePolicy(Enum): """ -_CacheObject = Dict[str, Any] +_CacheObject = dict[str, Any] class _CacheObjectModel(metaclass=ABCMeta): @@ -106,7 +107,7 @@ def _serialize_model(self, model: _ModelType) -> _CacheObject: else: raise NotImplementedError("Unexpected model type") - def load_model(self, path: Path, model_type: Type[_ModelType]) -> _ModelType: + def load_model(self, path: Path, model_type: type[_ModelType]) -> _ModelType: return self._deserialize_model(self._load_object(path), model_type) def save_model(self, path: Path, model: _ModelType) -> None: @@ -120,7 +121,7 @@ def ensure_task_model( self, task_id: int, filename: str, - model_type: Type[_ModelType], + model_type: type[_ModelType], downloader: Callable[[], _ModelType], model_description: str, ) -> _ModelType: ... @@ -166,7 +167,7 @@ def ensure_task_model( self, task_id: int, filename: str, - model_type: Type[_ModelType], + model_type: type[_ModelType], downloader: Callable[[], _ModelType], model_description: str, ) -> _ModelType: @@ -225,7 +226,7 @@ def ensure_task_model( self, task_id: int, filename: str, - model_type: Type[_ModelType], + model_type: type[_ModelType], downloader: Callable[[], _ModelType], model_description: str, ) -> _ModelType: @@ -247,7 +248,7 @@ def retrieve_project(self, project_id: int) -> Project: @define class _OfflineTaskModel(_CacheObjectModel): api_model: models.ITaskRead - labels: List[models.ILabel] + labels: list[models.ILabel] def dump(self) -> _CacheObject: return { @@ -278,15 +279,15 @@ def __init__( self._offline_model = cached_model self._cache_manager = cache_manager - def get_labels(self) -> List[models.ILabel]: + def get_labels(self) -> list[models.ILabel]: return self._offline_model.labels @define class _OfflineProjectModel(_CacheObjectModel): api_model: models.IProjectRead - task_ids: List[int] - labels: List[models.ILabel] + task_ids: list[int] + labels: list[models.ILabel] def dump(self) -> _CacheObject: return { @@ -320,14 +321,14 @@ def __init__( self._offline_model = cached_model self._cache_manager = cache_manager - def get_tasks(self) -> List[Task]: + def get_tasks(self) -> list[Task]: return [self._cache_manager.retrieve_task(t) for t in self._offline_model.task_ids] - def get_labels(self) -> List[models.ILabel]: + def get_labels(self) -> list[models.ILabel]: return self._offline_model.labels -_CACHE_MANAGER_CLASSES: Mapping[UpdatePolicy, Type[CacheManager]] = { +_CACHE_MANAGER_CLASSES: Mapping[UpdatePolicy, type[CacheManager]] = { UpdatePolicy.IF_MISSING_OR_STALE: _CacheManagerOnline, UpdatePolicy.NEVER: _CacheManagerOffline, } diff --git a/cvat-sdk/cvat_sdk/datasets/common.py b/cvat-sdk/cvat_sdk/datasets/common.py index b407c490802c..9b816e688bf4 100644 --- a/cvat-sdk/cvat_sdk/datasets/common.py +++ b/cvat-sdk/cvat_sdk/datasets/common.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: MIT import abc -from typing import List, Optional +from typing import Optional import attrs import attrs.validators @@ -24,8 +24,8 @@ class FrameAnnotations: Contains annotations that pertain to a single frame. """ - tags: List[models.LabeledImage] = attrs.Factory(list) - shapes: List[models.LabeledShape] = attrs.Factory(list) + tags: list[models.LabeledImage] = attrs.Factory(list) + shapes: list[models.LabeledShape] = attrs.Factory(list) class MediaElement(metaclass=abc.ABCMeta): diff --git a/cvat-sdk/cvat_sdk/datasets/task_dataset.py b/cvat-sdk/cvat_sdk/datasets/task_dataset.py index cf66fa7ab0ea..68424cbb3815 100644 --- a/cvat-sdk/cvat_sdk/datasets/task_dataset.py +++ b/cvat-sdk/cvat_sdk/datasets/task_dataset.py @@ -5,8 +5,8 @@ from __future__ import annotations import zipfile +from collections.abc import Iterable, Sequence from concurrent.futures import ThreadPoolExecutor -from typing import Iterable, Sequence import PIL.Image diff --git a/cvat-sdk/cvat_sdk/pytorch/common.py b/cvat-sdk/cvat_sdk/pytorch/common.py index 97ef38bc33a8..0c208cfc0bd4 100644 --- a/cvat-sdk/cvat_sdk/pytorch/common.py +++ b/cvat-sdk/cvat_sdk/pytorch/common.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: MIT -from typing import Mapping +from collections.abc import Mapping import attrs diff --git a/cvat-sdk/cvat_sdk/pytorch/project_dataset.py b/cvat-sdk/cvat_sdk/pytorch/project_dataset.py index ada554ee1210..7548d9e233a0 100644 --- a/cvat-sdk/cvat_sdk/pytorch/project_dataset.py +++ b/cvat-sdk/cvat_sdk/pytorch/project_dataset.py @@ -3,7 +3,8 @@ # SPDX-License-Identifier: MIT import os -from typing import Callable, Container, Mapping, Optional +from collections.abc import Container, Mapping +from typing import Callable, Optional import torch import torch.utils.data diff --git a/cvat-sdk/cvat_sdk/pytorch/task_dataset.py b/cvat-sdk/cvat_sdk/pytorch/task_dataset.py index 8964d2db47db..8434102d9e63 100644 --- a/cvat-sdk/cvat_sdk/pytorch/task_dataset.py +++ b/cvat-sdk/cvat_sdk/pytorch/task_dataset.py @@ -4,7 +4,8 @@ import os import types -from typing import Callable, Mapping, Optional +from collections.abc import Mapping +from typing import Callable, Optional import torchvision.datasets diff --git a/cvat-sdk/cvat_sdk/pytorch/transforms.py b/cvat-sdk/cvat_sdk/pytorch/transforms.py index 1fb99362defc..5c8a4f7390cb 100644 --- a/cvat-sdk/cvat_sdk/pytorch/transforms.py +++ b/cvat-sdk/cvat_sdk/pytorch/transforms.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: MIT -from typing import FrozenSet, TypedDict +from typing import TypedDict import attrs import attrs.validators @@ -63,7 +63,7 @@ class ExtractBoundingBoxes: * Rotated shapes are not supported. """ - include_shape_types: FrozenSet[str] = attrs.field( + include_shape_types: frozenset[str] = attrs.field( converter=frozenset, validator=attrs.validators.deep_iterable(attrs.validators.in_(_SUPPORTED_SHAPE_TYPES)), kw_only=True, diff --git a/cvat-sdk/gen/generate.sh b/cvat-sdk/gen/generate.sh index c506564999b6..ca9a08be98fe 100755 --- a/cvat-sdk/gen/generate.sh +++ b/cvat-sdk/gen/generate.sh @@ -8,7 +8,7 @@ set -e GENERATOR_VERSION="v6.0.1" -VERSION="2.21.3" +VERSION="2.22.0" LIB_NAME="cvat_sdk" LAYER1_LIB_NAME="${LIB_NAME}/api_client" DST_DIR="$(cd "$(dirname -- "$0")/.." && pwd)" diff --git a/cvat-sdk/gen/postprocess.py b/cvat-sdk/gen/postprocess.py index 779dc2e10326..d45d9680dc2f 100755 --- a/cvat-sdk/gen/postprocess.py +++ b/cvat-sdk/gen/postprocess.py @@ -30,8 +30,8 @@ def collect_operations(schema): class Replacer: - REPLACEMENT_TOKEN = r"%%%" - ARGS_TOKEN = r"!!!" + REPLACEMENT_TOKEN = r"%%%" # nosec: hardcoded_password_string + ARGS_TOKEN = r"!!!" # nosec: hardcoded_password_string def __init__(self, schema): self._schema = schema @@ -57,9 +57,9 @@ def make_api_name(self, name: str) -> str: return underscore(name) def make_type_annotation(self, type_repr: str) -> str: - type_repr = type_repr.replace("[", "typing.List[") + type_repr = type_repr.replace("[", "list[") type_repr = type_repr.replace("(", "typing.Union[").replace(")", "]") - type_repr = type_repr.replace("{", "typing.Dict[").replace(":", ",").replace("}", "]") + type_repr = type_repr.replace("{", "dict[").replace(":", ",").replace("}", "]") ANY_pattern = "bool, date, datetime, dict, float, int, list, str" type_repr = type_repr.replace(ANY_pattern, "typing.Any") diff --git a/cvat-sdk/gen/templates/openapi-generator/api.mustache b/cvat-sdk/gen/templates/openapi-generator/api.mustache index 160d641bc305..aa7c3ac686a1 100644 --- a/cvat-sdk/gen/templates/openapi-generator/api.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/api.mustache @@ -240,10 +240,10 @@ class {{classname}}(object): _spec_property_naming: bool = False, _content_type: typing.Optional[str] = None, _host_index: typing.Optional[int] = None, - _request_auths: typing.Optional[typing.List] = None, + _request_auths: typing.Optional[list] = None, _async_call: bool = False, **kwargs, - ) -> typing.Tuple[typing.Optional[{{>return_type}}], urllib3.HTTPResponse]: + ) -> tuple[typing.Optional[{{>return_type}}], urllib3.HTTPResponse]: """{{{summary}}}{{^summary}}{{>operation_name}}{{/summary}} # noqa: E501 {{#notes}} diff --git a/cvat-sdk/gen/templates/openapi-generator/api_client.mustache b/cvat-sdk/gen/templates/openapi-generator/api_client.mustache index 436bd26f2d54..d49af604ce94 100644 --- a/cvat-sdk/gen/templates/openapi-generator/api_client.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/api_client.mustache @@ -68,8 +68,8 @@ class ApiClient(object): def __init__(self, configuration: typing.Optional[Configuration] = None, - headers: typing.Optional[typing.Dict[str, str]] = None, - cookies: typing.Optional[typing.Dict[str, str]] = None, + headers: typing.Optional[dict[str, str]] = None, + cookies: typing.Optional[dict[str, str]] = None, pool_threads: int = 1): """ :param configuration: configuration object for this client @@ -85,7 +85,7 @@ class ApiClient(object): self.pool_threads = pool_threads self.rest_client = rest.RESTClientObject(configuration) - self.default_headers: typing.Dict[str, str] = headers or {} + self.default_headers: dict[str, str] = headers or {} self.cookies = SimpleCookie() if cookies: self.cookies.update(cookies) @@ -161,22 +161,22 @@ class ApiClient(object): self, resource_path: str, method: str, - path_params: typing.Optional[typing.Dict[str, typing.Any]] = None, - query_params: typing.Optional[typing.List[typing.Tuple[str, typing.Any]]] = None, - header_params: typing.Optional[typing.Dict[str, typing.Any]] = None, + path_params: typing.Optional[dict[str, typing.Any]] = None, + query_params: typing.Optional[list[tuple[str, typing.Any]]] = None, + header_params: typing.Optional[dict[str, typing.Any]] = None, body: typing.Optional[typing.Any] = None, - post_params: typing.Optional[typing.List[typing.Tuple[str, typing.Any]]] = None, - files: typing.Optional[typing.Dict[str, typing.List[io.IOBase]]] = None, - response_schema: typing.Optional[typing.Tuple[typing.Any]] = None, - auth_settings: typing.Optional[typing.List[str]] = None, - collection_formats: typing.Optional[typing.Dict[str, str]] = None, + post_params: typing.Optional[list[tuple[str, typing.Any]]] = None, + files: typing.Optional[dict[str, list[io.IOBase]]] = None, + response_schema: typing.Optional[tuple[typing.Any]] = None, + auth_settings: typing.Optional[list[str]] = None, + collection_formats: typing.Optional[dict[str, str]] = None, *, _parse_response: bool = True, - _request_timeout: typing.Optional[typing.Union[int, float, typing.Tuple]] = None, + _request_timeout: typing.Optional[typing.Union[int, float, tuple]] = None, _host: typing.Optional[str] = None, _check_type: typing.Optional[bool] = None, _check_status: bool = True, - _request_auths: typing.Optional[typing.List[typing.Dict[str, typing.Any]]] = None + _request_auths: typing.Optional[list[dict[str, typing.Any]]] = None ): config = self.configuration @@ -271,7 +271,7 @@ class ApiClient(object): return (return_data, response) {{/tornado}} - def get_common_headers(self) -> typing.Dict[str, str]: + def get_common_headers(self) -> dict[str, str]: """ Returns a headers dict with all the required headers for requests """ @@ -324,7 +324,7 @@ class ApiClient(object): """ return to_json(obj, read_files=read_files) - def deserialize(self, response: HTTPResponse, response_schema: typing.Tuple, *, _check_type: bool): + def deserialize(self, response: HTTPResponse, response_schema: tuple, *, _check_type: bool): """Deserializes response into an object. :param response (urllib3.HTTPResponse): object to be deserialized. @@ -384,22 +384,22 @@ class ApiClient(object): self, resource_path: str, method: str, - path_params: typing.Optional[typing.Dict[str, typing.Any]] = None, - query_params: typing.Optional[typing.List[typing.Tuple[str, typing.Any]]] = None, - header_params: typing.Optional[typing.Dict[str, typing.Any]] = None, + path_params: typing.Optional[dict[str, typing.Any]] = None, + query_params: typing.Optional[list[tuple[str, typing.Any]]] = None, + header_params: typing.Optional[dict[str, typing.Any]] = None, body: typing.Optional[typing.Any] = None, - post_params: typing.Optional[typing.List[typing.Tuple[str, typing.Any]]] = None, - files: typing.Optional[typing.Dict[str, typing.List[io.IOBase]]] = None, - response_schema: typing.Optional[typing.Tuple[typing.Any]] = None, - auth_settings: typing.Optional[typing.List[str]] = None, - collection_formats: typing.Optional[typing.Dict[str, str]] = None, + post_params: typing.Optional[list[tuple[str, typing.Any]]] = None, + files: typing.Optional[dict[str, list[io.IOBase]]] = None, + response_schema: typing.Optional[tuple[typing.Any]] = None, + auth_settings: typing.Optional[list[str]] = None, + collection_formats: typing.Optional[dict[str, str]] = None, *, _async_call: typing.Optional[bool] = None, _parse_response: bool = True, - _request_timeout: typing.Optional[typing.Union[int, float, typing.Tuple]] = None, + _request_timeout: typing.Optional[typing.Union[int, float, tuple]] = None, _host: typing.Optional[str] = None, _check_type: typing.Optional[bool] = None, - _request_auths: typing.Optional[typing.List[typing.Dict[str, typing.Any]]] = None, + _request_auths: typing.Optional[list[dict[str, typing.Any]]] = None, _check_status: bool = True, ): """Makes the HTTP request (synchronous) and returns deserialized data. @@ -580,7 +580,7 @@ class ApiClient(object): new_params.append((k, v)) return new_params - def _serialize_file(self, file_instance: io.IOBase) -> typing.Tuple[str, typing.Union[str, bytes], str]: + def _serialize_file(self, file_instance: io.IOBase) -> tuple[str, typing.Union[str, bytes], str]: if file_instance.closed is True: raise ApiValueError("Cannot read a closed file.") filename = os.path.basename(file_instance.name) @@ -592,8 +592,7 @@ class ApiClient(object): return filename, filedata, mimetype def files_parameters(self, - files: typing.Optional[typing.Dict[str, - typing.List[io.IOBase]]] = None): + files: typing.Optional[dict[str, list[io.IOBase]]] = None): """Builds form parameters. :param files: None or a dict with key=param_name and @@ -714,7 +713,7 @@ class ApiClient(object): {{#apiInfo}}{{#apis}} {{>api_name}}: '{{classname}}'{{/apis}}{{/apiInfo}} - _apis: typing.Dict[str, object] = { {{#apiInfo}}{{#apis}} + _apis: dict[str, object] = { {{#apiInfo}}{{#apis}} '{{>api_name}}': [None, '{{classname}}'],{{/apis}}{{/apiInfo}} } @@ -739,10 +738,10 @@ class ApiClient(object): class Endpoint(object): def __init__(self, - settings: typing.Optional[typing.Dict[str, typing.Any]] = None, - params_map: typing.Optional[typing.Dict[str, typing.Any]] = None, - root_map: typing.Optional[typing.Dict[str, typing.Any]] = None, - headers_map: typing.Optional[typing.Dict[str, typing.Any]] = None, + settings: typing.Optional[dict[str, typing.Any]] = None, + params_map: typing.Optional[dict[str, typing.Any]] = None, + root_map: typing.Optional[dict[str, typing.Any]] = None, + headers_map: typing.Optional[dict[str, typing.Any]] = None, api_client: typing.Optional[ApiClient] = None ): """Creates an endpoint @@ -897,9 +896,9 @@ class Endpoint(object): _spec_property_naming: bool = False, _content_type: typing.Optional[str] = None, _host_index: typing.Optional[int] = None, - _request_auths: typing.Optional[typing.List] = None, + _request_auths: typing.Optional[list] = None, _async_call: bool = False, - **kwargs) -> typing.Tuple[typing.Optional[typing.Any], HTTPResponse]: + **kwargs) -> tuple[typing.Optional[typing.Any], HTTPResponse]: """ Keyword Args: endpoint args diff --git a/cvat-sdk/gen/templates/openapi-generator/configuration.mustache b/cvat-sdk/gen/templates/openapi-generator/configuration.mustache index cec0c548f1d7..e66aec294afc 100644 --- a/cvat-sdk/gen/templates/openapi-generator/configuration.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/configuration.mustache @@ -169,8 +169,8 @@ class Configuration: def __init__(self, host: typing.Optional[str] = None, - api_key: typing.Optional[typing.Dict[str, str]] = None, - api_key_prefix: typing.Optional[typing.Dict[str, str]] = None, + api_key: typing.Optional[dict[str, str]] = None, + api_key_prefix: typing.Optional[dict[str, str]] = None, username: typing.Optional[str] = None, password: typing.Optional[str]=None, discard_unknown_keys: bool = False, @@ -179,9 +179,9 @@ class Configuration: signing_info=None, {{/hasHttpSignatureMethods}} server_index: typing.Optional[int] = None, - server_variables: typing.Optional[typing.Dict[str, str]] = None, + server_variables: typing.Optional[dict[str, str]] = None, server_operation_index: typing.Optional[int] = None, - server_operation_variables: typing.Optional[typing.Dict[str, str]] = None, + server_operation_variables: typing.Optional[dict[str, str]] = None, ssl_ca_cert: typing.Optional[str] = None, verify_ssl: typing.Optional[bool] = None, ) -> None: diff --git a/cvat-ui/package.json b/cvat-ui/package.json index 2c43904a3fb9..a74485fa107d 100644 --- a/cvat-ui/package.json +++ b/cvat-ui/package.json @@ -1,6 +1,6 @@ { "name": "cvat-ui", - "version": "1.66.2", + "version": "1.66.4", "description": "CVAT single-page application", "main": "src/index.tsx", "scripts": { diff --git a/cvat-ui/src/actions/annotation-actions.ts b/cvat-ui/src/actions/annotation-actions.ts index 670ace099e5a..0cc8f3052bc0 100644 --- a/cvat-ui/src/actions/annotation-actions.ts +++ b/cvat-ui/src/actions/annotation-actions.ts @@ -126,6 +126,8 @@ export enum AnnotationActionTypes { COLLAPSE_APPEARANCE = 'COLLAPSE_APPEARANCE', COLLAPSE_OBJECT_ITEMS = 'COLLAPSE_OBJECT_ITEMS', ACTIVATE_OBJECT = 'ACTIVATE_OBJECT', + UPDATE_EDITED_STATE = 'UPDATE_EDITED_STATE', + HIDE_ACTIVE_OBJECT = 'HIDE_ACTIVE_OBJECT', REMOVE_OBJECT = 'REMOVE_OBJECT', REMOVE_OBJECT_SUCCESS = 'REMOVE_OBJECT_SUCCESS', REMOVE_OBJECT_FAILED = 'REMOVE_OBJECT_FAILED', @@ -450,6 +452,7 @@ export function propagateObjectAsync(from: number, to: number): ThunkAction { const { job: { instance: sessionInstance, + frameNumbers, }, annotations: { activatedStateID, @@ -463,12 +466,17 @@ export function propagateObjectAsync(from: number, to: number): ThunkAction { throw new Error('There is not an activated object state to be propagated'); } - await sessionInstance.logger.log(EventScope.propagateObject, { count: Math.abs(to - from) }); - const states = cvat.utils.propagateShapes([objectState], from, to); + if (!sessionInstance) { + throw new Error('SessionInstance is not defined, propagation is not possible'); + } - await sessionInstance.annotations.put(states); - const history = await sessionInstance.actions.get(); + const states = cvat.utils.propagateShapes([objectState], from, to, frameNumbers); + if (states.length) { + await sessionInstance.logger.log(EventScope.propagateObject, { count: states.length }); + await sessionInstance.annotations.put(states); + } + const history = await sessionInstance.actions.get(); dispatch({ type: AnnotationActionTypes.PROPAGATE_OBJECT_SUCCESS, payload: { history }, @@ -594,10 +602,10 @@ export function confirmCanvasReadyAsync(): ThunkAction { return async (dispatch: ThunkDispatch, getState: () => CombinedState): Promise => { try { const state: CombinedState = getState(); - const { instance: job } = state.annotation.job; + const job = state.annotation.job.instance as Job; + const includedFrames = state.annotation.job.frameNumbers; const { changeFrameEvent } = state.annotation.player.frame; const chunks = await job.frames.cachedChunks() as number[]; - const includedFrames = await job.frames.frameNumbers() as number[]; const { frameCount, dataChunkSize } = job; const ranges = chunks.map((chunk) => ( @@ -914,7 +922,6 @@ export function getJobAsync({ } } - const jobMeta = await cvat.frames.getMeta('job', job.id); // frame query parameter does not work for GT job const frameNumber = Number.isInteger(initialFrame) && gtJob?.id !== job.id ? initialFrame as number : @@ -923,6 +930,8 @@ export function getJobAsync({ )) || job.startFrame; const frameData = await job.frames.get(frameNumber); + const jobMeta = await cvat.frames.getMeta('job', job.id); + const frameNumbers = await job.frames.frameNumbers(); try { // call first getting of frame data before rendering interface // to load and decode first chunk @@ -960,6 +969,7 @@ export function getJobAsync({ payload: { openTime, job, + frameNumbers, jobMeta, queryParameters, groundTruthInstance: gtJob || null, @@ -1320,7 +1330,7 @@ export function searchAnnotationsAsync( }; } -const ShapeTypeToControl: Record = { +export const ShapeTypeToControl: Record = { [ShapeType.RECTANGLE]: ActiveControl.DRAW_RECTANGLE, [ShapeType.POLYLINE]: ActiveControl.DRAW_POLYLINE, [ShapeType.POLYGON]: ActiveControl.DRAW_POLYGON, @@ -1608,3 +1618,50 @@ export function restoreFrameAsync(frame: number): ThunkAction { } }; } + +export function changeHideActiveObjectAsync(hide: boolean): ThunkAction { + return async (dispatch: ThunkDispatch, getState): Promise => { + const state = getState(); + const { instance: canvas } = state.annotation.canvas; + if (canvas) { + (canvas as Canvas).configure({ + hideEditedObject: hide, + }); + + const { objectState } = state.annotation.editing; + if (objectState) { + objectState.hidden = hide; + await dispatch(updateAnnotationsAsync([objectState])); + } + + dispatch({ + type: AnnotationActionTypes.HIDE_ACTIVE_OBJECT, + payload: { + hide, + }, + }); + } + }; +} + +export function updateEditedStateAsync(objectState: ObjectState | null): ThunkAction { + return async (dispatch: ThunkDispatch, getState): Promise => { + let newActiveObjectHidden = false; + if (objectState) { + newActiveObjectHidden = objectState.hidden; + } + + dispatch({ + type: AnnotationActionTypes.UPDATE_EDITED_STATE, + payload: { + objectState, + }, + }); + + const state = getState(); + const { activeObjectHidden } = state.annotation.canvas; + if (activeObjectHidden !== newActiveObjectHidden) { + dispatch(changeHideActiveObjectAsync(newActiveObjectHidden)); + } + }; +} diff --git a/cvat-ui/src/actions/export-actions.ts b/cvat-ui/src/actions/export-actions.ts index db59a6315c6a..6872e4f853c0 100644 --- a/cvat-ui/src/actions/export-actions.ts +++ b/cvat-ui/src/actions/export-actions.ts @@ -4,11 +4,12 @@ // SPDX-License-Identifier: MIT import { ActionUnion, createAction, ThunkAction } from 'utils/redux'; - -import { Storage, ProjectOrTaskOrJob, Job } from 'cvat-core-wrapper'; import { - getInstanceType, RequestInstanceType, listen, RequestsActions, - shouldListenForProgress, + Storage, ProjectOrTaskOrJob, Job, getCore, StorageLocation, +} from 'cvat-core-wrapper'; +import { + getInstanceType, RequestInstanceType, listen, + RequestsActions, updateRequestProgress, } from './requests-actions'; export enum ExportActionTypes { @@ -24,6 +25,8 @@ export enum ExportActionTypes { EXPORT_BACKUP_FAILED = 'EXPORT_BACKUP_FAILED', } +const core = getCore(); + export const exportActions = { openExportDatasetModal: (instance: ProjectOrTaskOrJob) => ( createAction(ExportActionTypes.OPEN_EXPORT_DATASET_MODAL, { instance }) @@ -36,7 +39,7 @@ export const exportActions = { instanceType: 'project' | 'task' | 'job', format: string, resource: 'dataset' | 'annotations', - target?: 'local' | 'cloudstorage', + target?: StorageLocation, ) => ( createAction(ExportActionTypes.EXPORT_DATASET_SUCCESS, { instance, @@ -67,7 +70,7 @@ export const exportActions = { closeExportBackupModal: (instance: ProjectOrTaskOrJob) => ( createAction(ExportActionTypes.CLOSE_EXPORT_BACKUP_MODAL, { instance }) ), - exportBackupSuccess: (instance: Exclude | RequestInstanceType, instanceType: 'task' | 'project', target?: 'local' | 'cloudstorage') => ( + exportBackupSuccess: (instance: Exclude | RequestInstanceType, instanceType: 'task' | 'project', target?: StorageLocation) => ( createAction(ExportActionTypes.EXPORT_BACKUP_SUCCESS, { instance, instanceType, target }) ), exportBackupFailed: (instance: Exclude | RequestInstanceType, instanceType: 'task' | 'project', error: any) => ( @@ -75,30 +78,9 @@ export const exportActions = { ), }; -export async function listenExportDatasetAsync( - rqID: string, - dispatch: (action: ExportActions | RequestsActions) => void, - params: { - instance: ProjectOrTaskOrJob | RequestInstanceType, - format: string, - saveImages: boolean, - }, -): Promise { - const { instance, format, saveImages } = params; - const resource = saveImages ? 'dataset' : 'annotations'; - - const instanceType = getInstanceType(instance); - try { - const result = await listen(rqID, dispatch); - const target = !result?.url ? 'cloudstorage' : 'local'; - dispatch(exportActions.exportDatasetSuccess( - instance, instanceType, format, resource, target, - )); - } catch (error) { - dispatch(exportActions.exportDatasetFailed(instance, instanceType, format, resource, error)); - } -} - +/** * + * Function is supposed to be used when a new dataset export request initiated by a user +** */ export const exportDatasetAsync = ( instance: ProjectOrTaskOrJob, format: string, @@ -106,21 +88,23 @@ export const exportDatasetAsync = ( useDefaultSettings: boolean, targetStorage: Storage, name?: string, -): ThunkAction => async (dispatch, getState) => { - const state = getState(); - +): ThunkAction => async (dispatch) => { const resource = saveImages ? 'dataset' : 'annotations'; const instanceType = getInstanceType(instance); try { const rqID = await instance.annotations .exportDataset(format, saveImages, useDefaultSettings, targetStorage, name); - if (shouldListenForProgress(rqID, state.requests)) { - await listenExportDatasetAsync(rqID, dispatch, { - instance, format, saveImages, + + if (rqID) { + await core.requests.listen(rqID, { + callback: (updatedRequest) => updateRequestProgress(updatedRequest, dispatch), }); - } - if (!rqID) { + const target = targetStorage.location; + dispatch(exportActions.exportDatasetSuccess( + instance, instanceType, format, resource, target, + )); + } else { dispatch(exportActions.exportDatasetSuccess( instance, instanceType, format, resource, )); @@ -130,47 +114,79 @@ export const exportDatasetAsync = ( } }; -export async function listenExportBackupAsync( +/** * + * Function is supposed to be used when a new backup export request initiated by a user +** */ +export const exportBackupAsync = ( + instance: Exclude, + targetStorage: Storage, + useDefaultSetting: boolean, + fileName: string, +): ThunkAction => async (dispatch) => { + const instanceType = getInstanceType(instance) as 'project' | 'task'; + try { + const rqID = await instance.backup(targetStorage, useDefaultSetting, fileName); + if (rqID) { + await core.requests.listen(rqID, { + callback: (updatedRequest) => updateRequestProgress(updatedRequest, dispatch), + }); + const target = targetStorage.location; + dispatch(exportActions.exportBackupSuccess(instance, instanceType, target)); + } else { + dispatch(exportActions.exportBackupSuccess(instance, instanceType)); + } + } catch (error) { + dispatch(exportActions.exportBackupFailed(instance, instanceType, error as Error)); + } +}; + +/** * + * Function is supposed to be used when application starts listening to existing dataset export request +** */ +export async function listenExportDatasetAsync( rqID: string, dispatch: (action: ExportActions | RequestsActions) => void, params: { - instance: Exclude | RequestInstanceType, + instance: ProjectOrTaskOrJob | RequestInstanceType, + format: string, + saveImages: boolean, }, ): Promise { - const { instance } = params; - const instanceType = getInstanceType(instance) as 'project' | 'task'; + const { instance, format, saveImages } = params; + const resource = saveImages ? 'dataset' : 'annotations'; + const instanceType = getInstanceType(instance); try { const result = await listen(rqID, dispatch); - const target = !result?.url ? 'cloudstorage' : 'local'; - dispatch(exportActions.exportBackupSuccess(instance, instanceType, target)); + const target = !result?.url ? StorageLocation.CLOUD_STORAGE : StorageLocation.LOCAL; + dispatch(exportActions.exportDatasetSuccess( + instance, instanceType, format, resource, target, + )); } catch (error) { - dispatch(exportActions.exportBackupFailed(instance, instanceType, error as Error)); + dispatch(exportActions.exportDatasetFailed(instance, instanceType, format, resource, error)); } } -export const exportBackupAsync = ( - instance: Exclude, - targetStorage: Storage, - useDefaultSetting: boolean, - fileName: string, -): ThunkAction => async (dispatch, getState) => { - const state = getState(); - +/** * + * Function is supposed to be used when application starts listening to existing backup export request +** */ +export async function listenExportBackupAsync( + rqID: string, + dispatch: (action: ExportActions | RequestsActions) => void, + params: { + instance: Exclude | RequestInstanceType, + }, +): Promise { + const { instance } = params; const instanceType = getInstanceType(instance) as 'project' | 'task'; try { - const rqID = await instance - .backup(targetStorage, useDefaultSetting, fileName); - if (shouldListenForProgress(rqID, state.requests)) { - await listenExportBackupAsync(rqID, dispatch, { instance }); - } - if (!rqID) { - dispatch(exportActions.exportBackupSuccess(instance, instanceType)); - } + const result = await listen(rqID, dispatch); + const target = !result?.url ? StorageLocation.CLOUD_STORAGE : StorageLocation.LOCAL; + dispatch(exportActions.exportBackupSuccess(instance, instanceType, target)); } catch (error) { dispatch(exportActions.exportBackupFailed(instance, instanceType, error as Error)); } -}; +} export type ExportActions = ActionUnion; diff --git a/cvat-ui/src/actions/import-actions.ts b/cvat-ui/src/actions/import-actions.ts index e47db0b47818..d7e3a548bb3b 100644 --- a/cvat-ui/src/actions/import-actions.ts +++ b/cvat-ui/src/actions/import-actions.ts @@ -4,15 +4,14 @@ // SPDX-License-Identifier: MIT import { createAction, ActionUnion, ThunkAction } from 'utils/redux'; -import { CombinedState } from 'reducers'; import { getCore, Storage, Job, Task, Project, ProjectOrTaskOrJob, } from 'cvat-core-wrapper'; import { getProjectsAsync } from './projects-actions'; import { AnnotationActionTypes, fetchAnnotationsAsync } from './annotation-actions'; import { - getInstanceType, listen, RequestInstanceType, RequestsActions, - shouldListenForProgress, + getInstanceType, listen, RequestInstanceType, + RequestsActions, updateRequestProgress, } from './requests-actions'; const core = getCore(); @@ -69,25 +68,9 @@ export const importActions = { ), }; -export async function listenImportDatasetAsync( - rqID: string, - dispatch: (action: ImportActions | RequestsActions) => void, - params: { - instance: ProjectOrTaskOrJob | RequestInstanceType, - }, -): Promise { - const { instance } = params; - - const instanceType = getInstanceType(instance); - const resource = instanceType === 'project' ? 'dataset' : 'annotation'; - try { - await listen(rqID, dispatch); - dispatch(importActions.importDatasetSuccess(instance, resource)); - } catch (error) { - dispatch(importActions.importDatasetFailed(instance, resource, error)); - } -} - +/** * + * Function is supposed to be used when a new dataset import request initiated by a user +** */ export const importDatasetAsync = ( instance: ProjectOrTaskOrJob, format: string, @@ -100,55 +83,63 @@ export const importDatasetAsync = ( const instanceType = getInstanceType(instance); const resource = instanceType === 'project' ? 'dataset' : 'annotation'; - try { - const state: CombinedState = getState(); + const listenForImport = (rqID: string) => core.requests.listen(rqID, { + callback: (updatedRequest) => updateRequestProgress(updatedRequest, dispatch), + }); + try { if (instanceType === 'project') { dispatch(importActions.importDataset(instance, format)); - const rqID = await (instance as Project).annotations - .importDataset(format, useDefaultSettings, sourceStorage, file, { + const rqID = await (instance as Project).annotations.importDataset( + format, + useDefaultSettings, + sourceStorage, + file, + { convMaskToPoly, updateStatusCallback: (message: string, progress: number) => ( dispatch(importActions.importDatasetUpdateStatus( instance, Math.floor(progress * 100), message, )) ), - }); - if (shouldListenForProgress(rqID, state.requests)) { - await listen(rqID, dispatch); - } + }, + ); + + await listenForImport(rqID); } else if (instanceType === 'task') { dispatch(importActions.importDataset(instance, format)); - const rqID = await (instance as Task).annotations - .upload(format, useDefaultSettings, sourceStorage, file, { - convMaskToPoly, - }); - if (shouldListenForProgress(rqID, state.requests)) { - await listen(rqID, dispatch); - } + const rqID = await (instance as Task).annotations.upload( + format, + useDefaultSettings, + sourceStorage, + file, + { convMaskToPoly }, + ); + await listenForImport(rqID); } else { // job dispatch(importActions.importDataset(instance, format)); - const rqID = await (instance as Job).annotations - .upload(format, useDefaultSettings, sourceStorage, file, { - convMaskToPoly, + const rqID = await (instance as Job).annotations.upload( + format, + useDefaultSettings, + sourceStorage, + file, + { convMaskToPoly }, + ); + + await listenForImport(rqID); + await (instance as Job).annotations.clear({ reload: true }); + await (instance as Job).actions.clear(); + + // first set empty objects list + // to escape some problems in canvas when shape with the same + // clientID has different type (polygon, rectangle) for example + dispatch({ type: AnnotationActionTypes.UPLOAD_JOB_ANNOTATIONS_SUCCESS }); + + const relevantInstance = getState().annotation.job.instance; + if (relevantInstance && relevantInstance.id === instance.id) { + setTimeout(() => { + dispatch(fetchAnnotationsAsync()); }); - if (shouldListenForProgress(rqID, state.requests)) { - await listen(rqID, dispatch); - - await (instance as Job).annotations.clear({ reload: true }); - await (instance as Job).actions.clear(); - - // first set empty objects list - // to escape some problems in canvas when shape with the same - // clientID has different type (polygon, rectangle) for example - dispatch({ type: AnnotationActionTypes.UPLOAD_JOB_ANNOTATIONS_SUCCESS }); - - const relevantInstance = getState().annotation.job.instance; - if (relevantInstance && relevantInstance.id === instance.id) { - setTimeout(() => { - dispatch(fetchAnnotationsAsync()); - }); - } } } } catch (error) { @@ -163,6 +154,28 @@ export const importDatasetAsync = ( } ); +/** * + * Function is supposed to be used when a new backup import request initiated by a user +** */ +export const importBackupAsync = (instanceType: 'project' | 'task', storage: Storage, file: File | string): ThunkAction => ( + async (dispatch) => { + dispatch(importActions.importBackup()); + try { + const instanceClass = (instanceType === 'task') ? core.classes.Task : core.classes.Project; + const rqID = await instanceClass.restore(storage, file); + const result = await core.requests.listen(rqID, { + callback: (updatedRequest) => updateRequestProgress(updatedRequest, dispatch), + }); + dispatch(importActions.importBackupSuccess(result?.resultID as number, instanceType)); + } catch (error) { + dispatch(importActions.importBackupFailed(instanceType, error)); + } + } +); + +/** * + * Function is supposed to be used when application starts listening to existing backup import request +** */ export async function listenImportBackupAsync( rqID: string, dispatch: (action: ImportActions | RequestsActions) => void, @@ -171,32 +184,34 @@ export async function listenImportBackupAsync( }, ): Promise { const { instanceType } = params; - try { const result = await listen(rqID, dispatch); - - dispatch(importActions.importBackupSuccess(result?.resultID, instanceType)); + dispatch(importActions.importBackupSuccess(result?.resultID as number, instanceType)); } catch (error) { dispatch(importActions.importBackupFailed(instanceType, error)); } } -export const importBackupAsync = (instanceType: 'project' | 'task', storage: Storage, file: File | string): ThunkAction => ( - async (dispatch, getState) => { - const state: CombinedState = getState(); - - dispatch(importActions.importBackup()); +/** * + * Function is supposed to be used when application starts listening to existing dataset import request +** */ +export async function listenImportDatasetAsync( + rqID: string, + dispatch: (action: ImportActions | RequestsActions) => void, + params: { + instance: ProjectOrTaskOrJob | RequestInstanceType, + }, +): Promise { + const { instance } = params; - try { - const instanceClass = (instanceType === 'task') ? core.classes.Task : core.classes.Project; - const rqID = await instanceClass.restore(storage, file); - if (shouldListenForProgress(rqID, state.requests)) { - await listenImportBackupAsync(rqID, dispatch, { instanceType }); - } - } catch (error) { - dispatch(importActions.importBackupFailed(instanceType, error)); - } + const instanceType = getInstanceType(instance); + const resource = instanceType === 'project' ? 'dataset' : 'annotation'; + try { + await listen(rqID, dispatch); + dispatch(importActions.importDatasetSuccess(instance, resource)); + } catch (error) { + dispatch(importActions.importDatasetFailed(instance, resource, error)); } -); +} export type ImportActions = ActionUnion; diff --git a/cvat-ui/src/actions/jobs-actions.ts b/cvat-ui/src/actions/jobs-actions.ts index 7c2df71a9270..e7d13e23b7f1 100644 --- a/cvat-ui/src/actions/jobs-actions.ts +++ b/cvat-ui/src/actions/jobs-actions.ts @@ -1,5 +1,5 @@ // Copyright (C) 2022 Intel Corporation -// Copyright (C) 2023 CVAT.ai Corporation +// Copyright (C) 2023-2024 CVAT.ai Corporation // // SPDX-License-Identifier: MIT @@ -34,7 +34,9 @@ interface JobsList extends Array { } const jobsActions = { - getJobs: (query: Partial) => createAction(JobsActionTypes.GET_JOBS, { query }), + getJobs: (query: Partial, fetchingTimestamp: number) => ( + createAction(JobsActionTypes.GET_JOBS, { query, fetchingTimestamp }) + ), getJobsSuccess: (jobs: JobsList) => ( createAction(JobsActionTypes.GET_JOBS_SUCCESS, { jobs }) ), @@ -73,16 +75,26 @@ const jobsActions = { export type JobsActions = ActionUnion; -export const getJobsAsync = (query: JobsQuery): ThunkAction => async (dispatch) => { +export const getJobsAsync = (query: JobsQuery): ThunkAction => async (dispatch, getState) => { + const requestedOn = Date.now(); + const isRequestRelevant = (): boolean => ( + getState().jobs.fetchingTimestamp === requestedOn + ); + try { // We remove all keys with null values from the query const filteredQuery = filterNull(query); - dispatch(jobsActions.getJobs(filteredQuery as JobsQuery)); + dispatch(jobsActions.getJobs(filteredQuery as JobsQuery, requestedOn)); const jobs = await cvat.jobs.get(filteredQuery); - dispatch(jobsActions.getJobsSuccess(jobs)); + + if (isRequestRelevant()) { + dispatch(jobsActions.getJobsSuccess(jobs)); + } } catch (error) { - dispatch(jobsActions.getJobsFailed(error)); + if (isRequestRelevant()) { + dispatch(jobsActions.getJobsFailed(error)); + } } }; @@ -96,10 +108,20 @@ export const getJobPreviewAsync = (job: Job): ThunkAction => async (dispatch) => } }; -export const createJobAsync = (data: JobData): ThunkAction => async (dispatch) => { - const jobInstance = new cvat.classes.Job(data); +export const createJobAsync = (data: JobData): ThunkAction> => async (dispatch) => { + const initialData = { + type: data.type, + task_id: data.taskID, + }; + const jobInstance = new cvat.classes.Job(initialData); try { - const savedJob = await jobInstance.save(data); + const extras = { + frame_selection_method: data.frameSelectionMethod, + seed: data.seed, + frame_count: data.frameCount, + frames_per_job_count: data.framesPerJobCount, + }; + const savedJob = await jobInstance.save(extras); return savedJob; } catch (error) { dispatch(jobsActions.createJobFailed(error)); diff --git a/cvat-ui/src/actions/projects-actions.ts b/cvat-ui/src/actions/projects-actions.ts index 6ab08543caf4..7dab31145895 100644 --- a/cvat-ui/src/actions/projects-actions.ts +++ b/cvat-ui/src/actions/projects-actions.ts @@ -33,7 +33,7 @@ export enum ProjectsActionTypes { } const projectActions = { - getProjects: () => createAction(ProjectsActionTypes.GET_PROJECTS), + getProjects: (fetchingTimestamp: number) => createAction(ProjectsActionTypes.GET_PROJECTS, { fetchingTimestamp }), getProjectsSuccess: (array: any[], count: number) => ( createAction(ProjectsActionTypes.GET_PROJECTS_SUCCESS, { array, count }) ), @@ -86,8 +86,13 @@ export function getProjectTasksAsync(tasksQuery: Partial = {}): Thun export function getProjectsAsync( query: Partial, tasksQuery: Partial = {}, ): ThunkAction { - return async (dispatch: ThunkDispatch): Promise => { - dispatch(projectActions.getProjects()); + return async (dispatch: ThunkDispatch, getState): Promise => { + const requestedOn = Date.now(); + const isRequestRelevant = (): boolean => ( + getState().projects.fetchingTimestamp === requestedOn + ); + + dispatch(projectActions.getProjects(requestedOn)); dispatch(projectActions.updateProjectsGettingQuery(query, tasksQuery)); // Clear query object from null fields @@ -100,20 +105,22 @@ export function getProjectsAsync( try { result = await cvat.projects.get(filteredQuery); } catch (error) { - dispatch(projectActions.getProjectsFailed(error)); + if (isRequestRelevant()) { + dispatch(projectActions.getProjectsFailed(error)); + } return; } - const array = Array.from(result); - - dispatch(projectActions.getProjectsSuccess(array, result.count)); - - // Appropriate tasks fetching process needs with retrieving only a single project - if (Object.keys(filteredQuery).includes('id') && typeof filteredQuery.id === 'number') { - dispatch(getProjectTasksAsync({ - ...tasksQuery, - projectId: filteredQuery.id, - })); + if (isRequestRelevant()) { + const array = Array.from(result); + dispatch(projectActions.getProjectsSuccess(array, result.count)); + // Appropriate tasks fetching process needs with retrieving only a single project + if (Object.keys(filteredQuery).includes('id') && typeof filteredQuery.id === 'number') { + dispatch(getProjectTasksAsync({ + ...tasksQuery, + projectId: filteredQuery.id, + })); + } } }; } diff --git a/cvat-ui/src/actions/requests-actions.ts b/cvat-ui/src/actions/requests-actions.ts index 1f3972746e7e..f0e2d2f6adc1 100644 --- a/cvat-ui/src/actions/requests-actions.ts +++ b/cvat-ui/src/actions/requests-actions.ts @@ -3,10 +3,8 @@ // SPDX-License-Identifier: MIT import { ActionUnion, createAction } from 'utils/redux'; -import { CombinedState, RequestsQuery, RequestsState } from 'reducers'; -import { - Request, ProjectOrTaskOrJob, getCore, RQStatus, -} from 'cvat-core-wrapper'; +import { CombinedState, RequestsQuery } from 'reducers'; +import { Request, ProjectOrTaskOrJob, getCore } from 'cvat-core-wrapper'; import { Store } from 'redux'; import { getCVATStore } from 'cvat-store'; @@ -88,23 +86,15 @@ export function updateRequestProgress(request: Request, dispatch: (action: Reque ); } -export function shouldListenForProgress(rqID: string | void, state: RequestsState): boolean { - return ( - typeof rqID === 'string' && - (!state.requests[rqID] || [RQStatus.FINISHED, RQStatus.FAILED].includes(state.requests[rqID]?.status)) - ); -} - export function listen( requestID: string, dispatch: (action: RequestsActions) => void, ) : Promise { const { requests } = getStore().getState().requests; - return core.requests - .listen(requestID, { - callback: (updatedRequest) => { - updateRequestProgress(updatedRequest, dispatch); - }, - initialRequest: requests[requestID], - }); + return core.requests.listen(requestID, { + callback: (updatedRequest) => { + updateRequestProgress(updatedRequest, dispatch); + }, + initialRequest: requests[requestID], + }); } diff --git a/cvat-ui/src/actions/requests-async-actions.ts b/cvat-ui/src/actions/requests-async-actions.ts index 06a137eafd28..86151cbd076e 100644 --- a/cvat-ui/src/actions/requests-async-actions.ts +++ b/cvat-ui/src/actions/requests-async-actions.ts @@ -37,12 +37,14 @@ export function getRequestsAsync(query: RequestsQuery): ThunkAction { .forEach((request: Request): void => { const { id: rqID, + status, operation: { type, target, format, taskID, projectID, jobID, }, } = request; - if (state.requests.requests[rqID]) { + const isRequestFinished = [RQStatus.FINISHED, RQStatus.FAILED].includes(status); + if (state.requests.requests[rqID] || isRequestFinished) { return; } diff --git a/cvat-ui/src/actions/tasks-actions.ts b/cvat-ui/src/actions/tasks-actions.ts index d15f033f1e6f..644f5aa7b021 100644 --- a/cvat-ui/src/actions/tasks-actions.ts +++ b/cvat-ui/src/actions/tasks-actions.ts @@ -32,10 +32,11 @@ export enum TasksActionTypes { UPDATE_TASK_IN_STATE = 'UPDATE_TASK_IN_STATE', } -function getTasks(query: Partial, updateQuery: boolean): AnyAction { +function getTasks(query: Partial, updateQuery: boolean, fetchingTimestamp: number): AnyAction { const action = { type: TasksActionTypes.GET_TASKS, payload: { + fetchingTimestamp, updateQuery, query, }, @@ -69,23 +70,30 @@ export function getTasksAsync( query: Partial, updateQuery = true, ): ThunkAction { - return async (dispatch: ThunkDispatch): Promise => { - dispatch(getTasks(query, updateQuery)); + return async (dispatch: ThunkDispatch, getState): Promise => { + const requestedOn = Date.now(); + const isRequestRelevant = (): boolean => ( + getState().tasks.fetchingTimestamp === requestedOn + ); + dispatch(getTasks(query, updateQuery, requestedOn)); const filteredQuery = filterNull(query); let result = null; try { result = await cvat.tasks.get(filteredQuery); } catch (error) { - dispatch(getTasksFailed(error)); + if (isRequestRelevant()) { + dispatch(getTasksFailed(error)); + } return; } - const array = Array.from(result); - - dispatch(getInferenceStatusAsync()); - dispatch(getTasksSuccess(array, result.count)); + if (isRequestRelevant()) { + const array = Array.from(result); + dispatch(getInferenceStatusAsync()); + dispatch(getTasksSuccess(array, result.count)); + } }; } diff --git a/cvat-ui/src/components/annotation-page/annotation-page.tsx b/cvat-ui/src/components/annotation-page/annotation-page.tsx index 37ba42116711..a04734b9c371 100644 --- a/cvat-ui/src/components/annotation-page/annotation-page.tsx +++ b/cvat-ui/src/components/annotation-page/annotation-page.tsx @@ -5,7 +5,6 @@ import React, { useEffect } from 'react'; import Layout from 'antd/lib/layout'; -import Result from 'antd/lib/result'; import Spin from 'antd/lib/spin'; import notification from 'antd/lib/notification'; import Button from 'antd/lib/button'; @@ -19,6 +18,7 @@ import StandardWorkspaceComponent from 'components/annotation-page/standard-work import StandardWorkspace3DComponent from 'components/annotation-page/standard3D-workspace/standard3D-workspace'; import TagAnnotationWorkspace from 'components/annotation-page/tag-annotation-workspace/tag-annotation-workspace'; import FiltersModalComponent from 'components/annotation-page/top-bar/filters-modal'; +import { JobNotFoundComponent } from 'components/common/not-found'; import StatisticsModalComponent from 'components/annotation-page/top-bar/statistics-modal'; import AnnotationTopBarContainer from 'containers/annotation-page/top-bar/top-bar'; import { Workspace } from 'reducers'; @@ -139,14 +139,7 @@ export default function AnnotationPageComponent(props: Props): JSX.Element { } if (typeof job === 'undefined') { - return ( - - ); + return ; } return ( diff --git a/cvat-ui/src/components/annotation-page/canvas/views/canvas2d/brush-tools.tsx b/cvat-ui/src/components/annotation-page/canvas/views/canvas2d/brush-tools.tsx index 6c140438c20e..b6a43ce20cf6 100644 --- a/cvat-ui/src/components/annotation-page/canvas/views/canvas2d/brush-tools.tsx +++ b/cvat-ui/src/components/annotation-page/canvas/views/canvas2d/brush-tools.tsx @@ -6,9 +6,9 @@ import './brush-toolbox-styles.scss'; import React, { useCallback, useEffect, useState } from 'react'; import ReactDOM from 'react-dom'; -import { useDispatch, useSelector } from 'react-redux'; +import { shallowEqual, useDispatch, useSelector } from 'react-redux'; import Button from 'antd/lib/button'; -import Icon, { VerticalAlignBottomOutlined } from '@ant-design/icons'; +import Icon, { EyeInvisibleFilled, EyeOutlined, VerticalAlignBottomOutlined } from '@ant-design/icons'; import InputNumber from 'antd/lib/input-number'; import Select from 'antd/lib/select'; import notification from 'antd/lib/notification'; @@ -23,7 +23,7 @@ import { import CVATTooltip from 'components/common/cvat-tooltip'; import { CombinedState, ObjectType, ShapeType } from 'reducers'; import LabelSelector from 'components/label-selector/label-selector'; -import { rememberObject, updateCanvasBrushTools } from 'actions/annotation-actions'; +import { changeHideActiveObjectAsync, rememberObject, updateCanvasBrushTools } from 'actions/annotation-actions'; import { ShortcutScope } from 'utils/enums'; import GlobalHotKeys from 'utils/mousetrap-react'; import { subKeyMap } from 'utils/component-subkeymap'; @@ -71,12 +71,17 @@ registerComponentShortcuts(componentShortcuts); const MIN_BRUSH_SIZE = 1; function BrushTools(): React.ReactPortal | null { const dispatch = useDispatch(); - const defaultLabelID = useSelector((state: CombinedState) => state.annotation.drawing.activeLabelID); - const config = useSelector((state: CombinedState) => state.annotation.canvas.brushTools); - const canvasInstance = useSelector((state: CombinedState) => state.annotation.canvas.instance); - const labels = useSelector((state: CombinedState) => state.annotation.job.labels); - const { keyMap, normalizedKeyMap } = useSelector((state: CombinedState) => state.shortcuts); - const { visible } = config; + const { + defaultLabelID, visible, canvasInstance, labels, activeObjectHidden, keyMap, normalizedKeyMap, + } = useSelector((state: CombinedState) => ({ + defaultLabelID: state.annotation.drawing.activeLabelID, + visible: state.annotation.canvas.brushTools.visible, + canvasInstance: state.annotation.canvas.instance, + labels: state.annotation.job.labels, + activeObjectHidden: state.annotation.canvas.activeObjectHidden, + keyMap: state.shortcuts.keyMap, + normalizedKeyMap: state.shortcuts.normalizedKeyMap, + }), shallowEqual); const [editableState, setEditableState] = useState(null); const [currentTool, setCurrentTool] = useState<'brush' | 'eraser' | 'polygon-plus' | 'polygon-minus'>('brush'); @@ -103,6 +108,10 @@ function BrushTools(): React.ReactPortal | null { } }, [setCurrentTool, blockedTools['polygon-minus']]); + const hideMask = useCallback((hide: boolean) => { + dispatch(changeHideActiveObjectAsync(hide)); + }, []); + const handlers: Record void> = { ACTIVATE_BRUSH_TOOL_STANDARD_CONTROLS: setBrushTool, ACTIVATE_ERASER_TOOL_STANDARD_CONTROLS: setEraserTool, @@ -365,6 +374,14 @@ function BrushTools(): React.ReactPortal | null { icon={} onClick={() => setRemoveUnderlyingPixels(!removeUnderlyingPixels)} /> + +