Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Target metric in quality UI #8347

Merged
merged 5 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
### Added

- New quality settings `Target metric`, `Target metric threshold`, `Max validations per job`
(<https://github.com/cvat-ai/cvat/pull/8347>)

### Changed
- `Mean annotaion quality` card on quality page now displays a value depending on `Target metric` setting
(<https://github.com/cvat-ai/cvat/pull/8347>)
2 changes: 1 addition & 1 deletion cvat-core/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "cvat-core",
"version": "15.1.2",
"version": "15.1.3",
"type": "module",
"description": "Part of Computer Vision Tool which presents an interface for client-side integration",
"main": "src/api.ts",
Expand Down
12 changes: 10 additions & 2 deletions cvat-core/src/api-implementation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import QualitySettings from './quality-settings';
import { FramesMetaData } from './frames';
import AnalyticsReport from './analytics-report';
import { listActions, registerAction, runActions } from './annotations-actions';
import { convertDescriptions, getServerAPISchema } from './server-schema';
import { JobType } from './enums';
import { PaginatedResource } from './core-types';
import CVATCore from '.';
Expand Down Expand Up @@ -142,7 +143,10 @@ export default function implementAPI(cvat: CVATCore): CVATCore {
return result;
});

implementationMixin(cvat.server.apiSchema, serverProxy.server.apiSchema);
implementationMixin(cvat.server.apiSchema, async () => {
const result = await getServerAPISchema();
return result;
});

implementationMixin(cvat.assets.create, async (file: File, guideId: number): Promise<SerializedAsset> => {
if (!(file instanceof File)) {
Expand Down Expand Up @@ -514,7 +518,11 @@ export default function implementAPI(cvat: CVATCore): CVATCore {
const params = fieldsToSnakeCase(filter);

const settings = await serverProxy.analytics.quality.settings.get(params);
return new QualitySettings({ ...settings });
const schema = await getServerAPISchema();
const descriptions = convertDescriptions(schema.components.schemas.QualitySettings.properties);
return new QualitySettings({
...settings, descriptions,
});
});
implementationMixin(cvat.analytics.performance.reports, async (filter: AnalyticsReportFilter) => {
checkFilter(filter, {
Expand Down
65 changes: 61 additions & 4 deletions cvat-core/src/quality-settings.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
// Copyright (C) 2023 CVAT.ai Corporation
// Copyright (C) 2023-2024 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT

import _ from 'lodash';
import { SerializedQualitySettingsData } from './server-response-types';
import PluginRegistry from './plugins';
import serverProxy from './server-proxy';
import { convertDescriptions, getServerAPISchema } from './server-schema';

export enum TargetMetric {
ACCURACY = 'accuracy',
PRECISION = 'precision',
RECALL = 'recall',
}

export default class QualitySettings {
#id: number;
#targetMetric: TargetMetric;
#targetMetricThreshold: number;
#maxValidationsPerJob: number;
#task: number;
#iouThreshold: number;
#oksSigma: number;
Expand All @@ -21,10 +32,14 @@ export default class QualitySettings {
#objectVisibilityThreshold: number;
#panopticComparison: boolean;
#compareAttributes: boolean;
#descriptions: Record<string, string>;

constructor(initialData: SerializedQualitySettingsData) {
this.#id = initialData.id;
this.#task = initialData.task;
this.#targetMetric = initialData.target_metric as TargetMetric;
this.#targetMetricThreshold = initialData.target_metric_threshold;
this.#maxValidationsPerJob = initialData.max_validations_per_job;
this.#iouThreshold = initialData.iou_threshold;
this.#oksSigma = initialData.oks_sigma;
this.#lineThickness = initialData.line_thickness;
Expand All @@ -37,6 +52,7 @@ export default class QualitySettings {
this.#objectVisibilityThreshold = initialData.object_visibility_threshold;
this.#panopticComparison = initialData.panoptic_comparison;
this.#compareAttributes = initialData.compare_attributes;
this.#descriptions = initialData.descriptions;
}

get id(): number {
Expand Down Expand Up @@ -143,6 +159,40 @@ export default class QualitySettings {
this.#compareAttributes = newVal;
}

get targetMetric(): TargetMetric {
return this.#targetMetric;
}

set targetMetric(newVal: TargetMetric) {
this.#targetMetric = newVal;
}

get targetMetricThreshold(): number {
return this.#targetMetricThreshold;
}

set targetMetricThreshold(newVal: number) {
this.#targetMetricThreshold = newVal;
}

get maxValidationsPerJob(): number {
return this.#maxValidationsPerJob;
}

set maxValidationsPerJob(newVal: number) {
this.#maxValidationsPerJob = newVal;
}

get descriptions(): Record<string, string> {
const descriptions: Record<string, string> = Object.keys(this.#descriptions).reduce((acc, key) => {
const camelCaseKey = _.camelCase(key);
acc[camelCaseKey] = this.#descriptions[key];
return acc;
}, {});

return descriptions;
}

public toJSON(): SerializedQualitySettingsData {
const result: SerializedQualitySettingsData = {
iou_threshold: this.#iouThreshold,
Expand All @@ -157,6 +207,9 @@ export default class QualitySettings {
object_visibility_threshold: this.#objectVisibilityThreshold,
panoptic_comparison: this.#panopticComparison,
compare_attributes: this.#compareAttributes,
target_metric: this.#targetMetric,
target_metric_threshold: this.#targetMetricThreshold,
max_validations_per_job: this.#maxValidationsPerJob,
};

return result;
Expand All @@ -172,9 +225,13 @@ Object.defineProperties(QualitySettings.prototype.save, {
implementation: {
writable: false,
enumerable: false,
value: async function implementation() {
const result = await serverProxy.analytics.quality.settings.update(this.id, this.toJSON());
return new QualitySettings(result);
value: async function implementation(): Promise<QualitySettings> {
const result = await serverProxy.analytics.quality.settings.update(
this.id, this.toJSON(),
);
const schema = await getServerAPISchema();
const descriptions = convertDescriptions(schema.components.schemas.QualitySettings.properties);
return new QualitySettings({ ...result, descriptions });
},
},
});
4 changes: 4 additions & 0 deletions cvat-core/src/server-response-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ export type QualitySettingsFilter = Camelized<APIQualitySettingsFilter>;
export interface SerializedQualitySettingsData {
id?: number;
task?: number;
target_metric?: string;
target_metric_threshold?: number;
max_validations_per_job?: number;
iou_threshold?: number;
oks_sigma?: number;
line_thickness?: number;
Expand All @@ -253,6 +256,7 @@ export interface SerializedQualitySettingsData {
object_visibility_threshold?: number;
panoptic_comparison?: boolean;
compare_attributes?: boolean;
descriptions?: Record<string, string>;
}

export interface APIQualityConflictsFilter extends APICommonFilterParams {
Expand Down
24 changes: 24 additions & 0 deletions cvat-core/src/server-schema.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (C) 2024 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT

import { SerializedAPISchema } from 'server-response-types';
import serverProxy from './server-proxy';

let schemaCache: SerializedAPISchema | null = null;

export async function getServerAPISchema(): Promise<SerializedAPISchema> {
if (schemaCache) {
return schemaCache;
}

schemaCache = await serverProxy.server.apiSchema();
return schemaCache;
}

export function convertDescriptions(descriptions: Record<string, { description?: string }>): Record<string, string> {
return Object.keys(descriptions).reduce((acc, key) => {
acc[key] = descriptions[key].description ?? '';
return acc;
}, {} as Record<string, string>);
}
2 changes: 1 addition & 1 deletion cvat-ui/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "cvat-ui",
"version": "1.65.0",
"version": "1.65.1",
"description": "CVAT single-page application",
"main": "src/index.tsx",
"scripts": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ export default function QualitySettingsModal(props: Props): JSX.Element | null {
try {
if (settings) {
const values = await form.validateFields();

settings.targetMetric = values.targetMetric;
settings.targetMetricThreshold = values.targetMetricThreshold / 100;

settings.maxValidationsPerJob = values.maxValidationsPerJob;

bsekachev marked this conversation as resolved.
Show resolved Hide resolved
settings.lowOverlapThreshold = values.lowOverlapThreshold / 100;
settings.iouThreshold = values.iouThreshold / 100;
settings.compareAttributes = values.compareAttributes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import Text from 'antd/lib/typography/Text';
import { Col, Row } from 'antd/lib/grid';

import { QualityReport, QualitySummary } from 'cvat-core-wrapper';
import { clampValue, percent } from 'utils/quality';
import AnalyticsCard from '../views/analytics-card';
import { percent, clampValue } from '../utils/text-formatting';

interface Props {
taskReport: QualityReport | null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import Text from 'antd/lib/typography/Text';
import notification from 'antd/lib/notification';
import { Task } from 'cvat-core-wrapper';
import { useIsMounted } from 'utils/hooks';
import { clampValue, percent } from 'utils/quality';
import AnalyticsCard from '../views/analytics-card';
import { percent, clampValue } from '../utils/text-formatting';

interface Props {
task: Task;
Expand Down
62 changes: 13 additions & 49 deletions cvat-ui/src/components/analytics-page/task-quality/job-list.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,35 @@ import React, { useState } from 'react';
import { useHistory } from 'react-router';
import { Row, Col } from 'antd/lib/grid';
import { DownloadOutlined, QuestionCircleOutlined } from '@ant-design/icons';
import { ColumnFilterItem, Key } from 'antd/lib/table/interface';
import { Key } from 'antd/lib/table/interface';
import Table from 'antd/lib/table';
import Button from 'antd/lib/button';
import Text from 'antd/lib/typography/Text';

import {
Task, Job, JobType, QualityReport, getCore,
TargetMetric,
} from 'cvat-core-wrapper';
import CVATTooltip from 'components/common/cvat-tooltip';
import { getQualityColor } from 'utils/quality-color';
import Tag from 'antd/lib/tag';
import { toRepresentation } from '../utils/text-formatting';
import {
collectAssignees, QualityColors, sorter, toRepresentation,
} from 'utils/quality';
import { ConflictsTooltip } from './gt-conflicts';

interface Props {
task: Task;
jobsReports: QualityReport[];
getQualityColor: (value?: number) => QualityColors;
targetMetric: TargetMetric;
}

function JobListComponent(props: Props): JSX.Element {
const {
task: taskInstance,
jobsReports: jobsReportsArray,
getQualityColor,
targetMetric,
} = props;

const jobsReports: Record<number, QualityReport> = jobsReportsArray
Expand All @@ -37,49 +43,6 @@ function JobListComponent(props: Props): JSX.Element {
const { id: taskId, jobs } = taskInstance;
const [renderedJobs] = useState<Job[]>(jobs.filter((job: Job) => job.type === JobType.ANNOTATION));

function sorter(path: string) {
return (obj1: any, obj2: any): number => {
let currentObj1 = obj1;
let currentObj2 = obj2;
let field1: string | number | null = null;
let field2: string | number | null = null;
for (const pathSegment of path.split('.')) {
field1 = currentObj1 && pathSegment in currentObj1 ? currentObj1[pathSegment] : null;
field2 = currentObj2 && pathSegment in currentObj2 ? currentObj2[pathSegment] : null;
currentObj1 = currentObj1 && pathSegment in currentObj1 ? currentObj1[pathSegment] : null;
currentObj2 = currentObj2 && pathSegment in currentObj2 ? currentObj2[pathSegment] : null;
}

if (field1 !== null && field2 !== null) {
if (typeof field1 === 'string' && typeof field2 === 'string') return field1.localeCompare(field2);
if (typeof field1 === 'number' && typeof field2 === 'number' &&
Number.isFinite(field1) && Number.isFinite(field2)) return field1 - field2;
}

if (field1 === null && field2 === null) return 0;

if (field1 === null || (typeof field1 === 'number' && !Number.isFinite(field1))) {
return -1;
}

return 1;
};
}

function collectUsers(path: string): ColumnFilterItem[] {
return Array.from<string | null>(
new Set(
Object.values(jobsReports).map((report: QualityReport) => {
if (report[path] === null) {
return null;
}

return report[path].username;
}),
),
).map((value: string | null) => ({ text: value || 'Is Empty', value: value || false }));
}

const columns = [
{
title: 'Job',
Expand Down Expand Up @@ -133,7 +96,7 @@ function JobListComponent(props: Props): JSX.Element {
<Text>{report?.assignee?.username}</Text>
),
sorter: sorter('assignee.assignee.username'),
filters: collectUsers('assignee'),
filters: collectAssignees(jobsReportsArray),
onFilter: (value: boolean | Key, record: any) => (
record.assignee.assignee?.username || false
) === value,
Expand Down Expand Up @@ -187,10 +150,11 @@ function JobListComponent(props: Props): JSX.Element {
key: 'quality',
align: 'center' as const,
className: 'cvat-job-item-quality',
sorter: sorter('quality.summary.accuracy'),
sorter: sorter(`quality.summary.${targetMetric}`),
render: (report?: QualityReport): JSX.Element => {
const meanAccuracy = report?.summary?.accuracy;
const meanAccuracy = report?.summary?.[targetMetric];
const accuracyRepresentation = toRepresentation(meanAccuracy);

return (
accuracyRepresentation.includes('N/A') ? (
<Text
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,21 @@ import { Col, Row } from 'antd/lib/grid';
import Text from 'antd/lib/typography/Text';
import Button from 'antd/lib/button';

import { QualityReport, getCore } from 'cvat-core-wrapper';
import { QualityReport, TargetMetric, getCore } from 'cvat-core-wrapper';
import { toRepresentation } from 'utils/quality';
import AnalyticsCard from '../views/analytics-card';
import { toRepresentation } from '../utils/text-formatting';

interface Props {
taskID: number;
taskReport: QualityReport | null;
targetMetric: TargetMetric;
setQualitySettingsVisible: (visible: boolean) => void;
}

function MeanQuality(props: Props): JSX.Element {
const { taskID, taskReport, setQualitySettingsVisible } = props;
const {
taskID, taskReport, targetMetric, setQualitySettingsVisible,
} = props;
const reportSummary = taskReport?.summary;

const tooltip = (
Expand Down Expand Up @@ -87,7 +90,7 @@ function MeanQuality(props: Props): JSX.Element {
<AnalyticsCard
title='Mean annotation quality'
className='cvat-task-mean-annotation-quality'
value={toRepresentation(reportSummary?.accuracy)}
value={toRepresentation(reportSummary?.[targetMetric])}
tooltip={tooltip}
rightElement={downloadReportButton}
/>
Expand Down
Loading
Loading