Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: add DatasetPublishValidator class #5568

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 2 additions & 19 deletions argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
)
from argilla_server.models.suggestions import SuggestionCreateWithRecordId
from argilla_server.search_engine import SearchEngine
from argilla_server.validators.datasets import DatasetCreateValidator, DatasetUpdateValidator
from argilla_server.validators.datasets import DatasetCreateValidator, DatasetPublishValidator, DatasetUpdateValidator
from argilla_server.validators.responses import (
ResponseCreateValidator,
ResponseUpdateValidator,
Expand Down Expand Up @@ -145,16 +145,6 @@ async def create_dataset(db: AsyncSession, dataset_attrs: dict):
return await dataset.save(db)


async def _count_required_fields_by_dataset_id(db: AsyncSession, dataset_id: UUID) -> int:
return (await db.execute(select(func.count(Field.id)).filter_by(dataset_id=dataset_id, required=True))).scalar_one()


async def _count_required_questions_by_dataset_id(db: AsyncSession, dataset_id: UUID) -> int:
return (
await db.execute(select(func.count(Question.id)).filter_by(dataset_id=dataset_id, required=True))
).scalar_one()


def _allowed_roles_for_metadata_property_create(metadata_property_create: MetadataPropertyCreate) -> List[UserRole]:
if metadata_property_create.visible_for_annotators:
return VISIBLE_FOR_ANNOTATORS_ALLOWED_ROLES
Expand All @@ -163,14 +153,7 @@ def _allowed_roles_for_metadata_property_create(metadata_property_create: Metada


async def publish_dataset(db: AsyncSession, search_engine: SearchEngine, dataset: Dataset) -> Dataset:
if dataset.is_ready:
raise UnprocessableEntityError("Dataset is already published")

if await _count_required_fields_by_dataset_id(db, dataset.id) == 0:
raise UnprocessableEntityError("Dataset cannot be published without required fields")

if await _count_required_questions_by_dataset_id(db, dataset.id) == 0:
raise UnprocessableEntityError("Dataset cannot be published without required questions")
await DatasetPublishValidator.validate(db, dataset)

async with db.begin_nested():
dataset = await dataset.update(db, status=DatasetStatus.ready, autocommit=False)
Expand Down
6 changes: 5 additions & 1 deletion argilla-server/src/argilla_server/models/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Set, TypeVar, Union
from uuid import UUID

from sqlalchemy import select, sql
from sqlalchemy import select, func, sql
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as postgres_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
Expand Down Expand Up @@ -100,6 +100,10 @@ async def get_by_or_raise(cls, db: AsyncSession, **conditions) -> Self:

raise NotFoundError(f"{cls.__name__} not found filtering by {conditions_str}")

@classmethod
async def count_by(cls, db: AsyncSession, **conditions) -> int:
return (await db.execute(select(func.count(cls.id)).filter_by(**conditions))).scalar_one()

async def update(
self,
db: AsyncSession,
Expand Down
25 changes: 24 additions & 1 deletion argilla-server/src/argilla_server/validators/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.models import Dataset, Field, Question, Workspace
from argilla_server.errors.future import (
NotUniqueError,
UnprocessableEntityError,
UpdateDistributionWithExistingResponsesError,
)
from argilla_server.models import Dataset, Workspace


class DatasetCreateValidator:
Expand All @@ -41,6 +41,29 @@ async def _validate_name_is_not_duplicated(cls, db: AsyncSession, name: str, wor
raise NotUniqueError(f"Dataset with name `{name}` already exists for workspace with id `{workspace_id}`")


class DatasetPublishValidator:
@classmethod
async def validate(cls, db: AsyncSession, dataset: Dataset) -> None:
await cls._validate_has_not_been_published_yet(db, dataset)
await cls._validate_has_at_least_one_required_field(db, dataset)
await cls._validate_has_at_least_one_required_question(db, dataset)

@classmethod
async def _validate_has_not_been_published_yet(cls, db: AsyncSession, dataset: Dataset) -> None:
if dataset.is_ready:
raise UnprocessableEntityError("Dataset has already been published")

@classmethod
async def _validate_has_at_least_one_required_field(cls, db: AsyncSession, dataset: Dataset) -> None:
if await Field.count_by(db, dataset_id=dataset.id, required=True) == 0:
raise UnprocessableEntityError("Dataset cannot be published without required fields")

@classmethod
async def _validate_has_at_least_one_required_question(cls, db: AsyncSession, dataset: Dataset) -> None:
if await Question.count_by(db, dataset_id=dataset.id, required=True) == 0:
raise UnprocessableEntityError("Dataset cannot be published without required questions")


class DatasetUpdateValidator:
@classmethod
async def validate(cls, db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> None:
Expand Down
6 changes: 3 additions & 3 deletions argilla-server/tests/unit/api/handlers/v1/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4672,10 +4672,10 @@ async def test_publish_dataset_already_published(
response = await async_client.put(f"/api/v1/datasets/{dataset.id}/publish", headers=owner_auth_header)

assert response.status_code == 422
assert response.json() == {"detail": "Dataset is already published"}
assert response.json() == {"detail": "Dataset has already been published"}
assert (await db.execute(select(func.count(Record.id)))).scalar() == 0

async def test_publish_dataset_without_fields(
async def test_publish_dataset_without_required_fields(
self, async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict
):
dataset = await DatasetFactory.create()
Expand All @@ -4688,7 +4688,7 @@ async def test_publish_dataset_without_fields(
assert response.json() == {"detail": "Dataset cannot be published without required fields"}
assert (await db.execute(select(func.count(Record.id)))).scalar() == 0

async def test_publish_dataset_without_questions(
async def test_publish_dataset_without_required_questions(
self, async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict
):
dataset = await DatasetFactory.create()
Expand Down
Loading