Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: move suggestion score validations to `SuggestionCreateValidator…
Browse files Browse the repository at this point in the history
…` and add a missing cardinality validation for value and score attributes (#111)

# Description

This PR moves all suggestion `score` validations to
`SuggestionCreateValidator` and adds a new one checking that `value` and
`score` attributes have the same cardinality.

The additional validation will raise an error when a suggestion has a
single item inside the `value` attribute and a list of items inside `score`.

Refs argilla-io/argilla#4638

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Refactor (change restructuring the codebase without changing
functionality)
- [x] Improvement (change adding some improvement to an existing
functionality)
- [ ] Documentation update

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [x] Adding new tests.

**Checklist**

- [ ] I added relevant documentation
- [ ] follows the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
jfcalvo authored Apr 25, 2024
1 parent ea30a60 commit 3337707
Showing 5 changed files with 129 additions and 26 deletions.
14 changes: 1 addition & 13 deletions src/argilla_server/schemas/v1/suggestions.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
from uuid import UUID

from argilla_server.models import SuggestionType
from argilla_server.pydantic_v1 import BaseModel, Field, root_validator
from argilla_server.pydantic_v1 import BaseModel, Field
from argilla_server.schemas.v1.questions import QuestionName
from argilla_server.schemas.v1.responses import (
SPAN_QUESTION_RESPONSE_VALUE_MAX_ITEMS,
@@ -114,15 +114,3 @@ class SuggestionCreate(BaseSuggestion):
description="Agent used to generate the suggestion",
)
score: SuggestionScoreField

@root_validator(skip_on_failure=True)
def check_value_and_score_length(cls, values: dict) -> dict:
    """Reject payloads whose list `value` and list `score` differ in length.

    The check only applies when both attributes are lists; any other
    combination is returned unchanged.

    Raises:
        ValueError: if both attributes are lists with different lengths.
    """
    value, score = values.get("value"), values.get("score")

    # Lengths are only comparable when both sides are lists.
    if isinstance(value, list) and isinstance(score, list) and len(value) != len(score):
        raise ValueError("number of items on value and score attributes doesn't match")

    return values
23 changes: 23 additions & 0 deletions src/argilla_server/validators/suggestions.py
Original file line number Diff line number Diff line change
@@ -24,6 +24,29 @@ def __init__(self, suggestion_create: SuggestionCreate):

def validate_for(self, question_settings: QuestionSettings, record: Record) -> None:
    """Run every suggestion check: first the value, then the score.

    The value is validated against `question_settings` and `record`; the
    score checks only look at the suggestion itself.

    Raises:
        ValueError: raised by the score checks when `value` and `score`
            disagree in cardinality or length.
    """
    self._validate_value(question_settings, record)
    self._validate_score()

def _validate_value(self, question_settings: QuestionSettings, record: Record) -> None:
    """Validate the suggestion value by delegating to `ResponseValueValidator`."""
    ResponseValueValidator(self._suggestion_create.value).validate_for(question_settings, record)

def _validate_score(self) -> None:
    """Run the score checks in order: cardinality first, then length.

    Raises:
        ValueError: propagated from either check.
    """
    self._validate_value_and_score_cardinality()
    self._validate_value_and_score_have_same_length()

def _validate_value_and_score_cardinality(self) -> None:
    """Ensure `value` and `score` are both single items or both lists.

    A missing score (`None`) is accepted for any kind of value.

    Raises:
        ValueError: if a single value comes with a list of scores, or a
            list value comes with a single (non-list, non-None) score.
    """
    value = self._suggestion_create.value
    score = self._suggestion_create.score
    value_is_list = isinstance(value, list)
    score_is_list = isinstance(score, list)

    if not value_is_list and score_is_list:
        raise ValueError("a list of score values is not allowed for a suggestion with a single value")

    if value_is_list and score is not None and not score_is_list:
        raise ValueError("a single score value is not allowed for a suggestion with a multiple items value")

def _validate_value_and_score_have_same_length(self) -> None:
    """Ensure a list `value` and a list `score` have the same number of items.

    Does nothing unless both attributes are lists.

    Raises:
        ValueError: if both are lists and their lengths differ.
    """
    value = self._suggestion_create.value
    score = self._suggestion_create.score

    # Lengths are only comparable when both sides are lists.
    if not (isinstance(value, list) and isinstance(score, list)):
        return

    if len(value) != len(score):
        raise ValueError("number of items on value and score attributes doesn't match")
Original file line number Diff line number Diff line change
@@ -172,7 +172,7 @@ async def test_create_dataset_records(
},
{
"type": SuggestionType.model,
"score": 1.0,
"score": [1.0, 0.1],
"value": [
"label-a",
"label-b",
@@ -187,7 +187,7 @@ async def test_create_dataset_records(
},
{
"type": SuggestionType.model,
"score": 0.5,
"score": [0.2, 0.5],
"value": [
{"value": "ranking-a", "rank": 1},
{"value": "ranking-b", "rank": 2},
Original file line number Diff line number Diff line change
@@ -143,7 +143,7 @@ async def test_update_dataset_records(
},
{
"type": SuggestionType.model,
"score": 1.0,
"score": [1.0, 0.1],
"value": [
"label-a",
"label-b",
@@ -158,7 +158,7 @@ async def test_update_dataset_records(
},
{
"type": SuggestionType.model,
"score": 0.5,
"score": [0.2, 0.5],
"value": [
{"value": "ranking-a", "rank": 1},
{"value": "ranking-b", "rank": 2},
110 changes: 101 additions & 9 deletions tests/unit/api/v1/records/test_upsert_suggestion.py
Original file line number Diff line number Diff line change
@@ -17,13 +17,13 @@
from uuid import UUID, uuid4

import pytest
from argilla_server.enums import SuggestionType
from argilla_server.enums import QuestionType, SuggestionType
from argilla_server.models import Suggestion
from httpx import AsyncClient
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from tests.factories import DatasetFactory, RecordFactory, SpanQuestionFactory, TextQuestionFactory
from tests.factories import DatasetFactory, QuestionFactory, RecordFactory, SpanQuestionFactory, TextQuestionFactory


@pytest.mark.asyncio
@@ -54,26 +54,65 @@ async def test_upsert_suggestion_with_valid_agent(

assert response.status_code == 201

@pytest.mark.parametrize("score", [[1.0], [1.0, 0.5], [1.0, 0.5, 0.9, 0.3]])
async def test_upsert_suggestion_with_list_of_scores(
self, async_client: AsyncClient, owner_auth_header: dict, score: List[float]
async def test_upsert_suggestion_with_single_value_score_list(
self, async_client: AsyncClient, owner_auth_header: dict
):
record = await RecordFactory.create()
question = await TextQuestionFactory.create(dataset=record.dataset)
question = await QuestionFactory.create(
dataset=record.dataset,
settings={
"type": QuestionType.multi_label_selection,
"options": [
{"value": "label-1", "text": "Label 1"},
{"value": "label-2", "text": "Label 2"},
],
},
)

response = await async_client.put(
self.url(record.id),
headers=owner_auth_header,
json={
"question_id": str(question.id),
"type": SuggestionType.model,
"value": "value",
"score": score,
"value": ["label-1"],
"score": [1.0],
},
)

assert response.status_code == 201
assert response.json()["score"] == score
assert response.json()["score"] == [1.0]

async def test_upsert_suggestion_with_multiple_values_score_list(
    self, async_client: AsyncClient, owner_auth_header: dict
):
    # Four selected labels paired with four scores: the lengths match, so
    # the upsert is expected to succeed.
    record = await RecordFactory.create()
    question = await QuestionFactory.create(
        dataset=record.dataset,
        settings={
            "type": QuestionType.multi_label_selection,
            "options": [
                {"value": "label-1", "text": "Label 1"},
                {"value": "label-2", "text": "Label 2"},
                {"value": "label-3", "text": "Label 3"},
                {"value": "label-4", "text": "Label 4"},
            ],
        },
    )

    response = await async_client.put(
        self.url(record.id),
        headers=owner_auth_header,
        json={
            "question_id": str(question.id),
            "type": SuggestionType.model,
            "value": ["label-1", "label-2", "label-3", "label-4"],
            "score": [1.0, 0.5, 0.9, 0.3],
        },
    )

    assert response.status_code == 201
    assert response.json()["score"] == [1.0, 0.5, 0.9, 0.3]

@pytest.mark.parametrize("agent", ["", " ", " ", "-", "_", ":", ".", "/", ","])
async def test_upsert_suggestion_with_invalid_agent(
@@ -225,6 +264,58 @@ async def test_upsert_suggestion_with_empty_list_of_scores(

assert response.status_code == 422

async def test_upsert_suggestion_with_single_value_and_multiple_scores(
    self, async_client: AsyncClient, owner_auth_header: dict
):
    # A scalar value combined with a list score must be rejected. Even a
    # one-item list counts as a list of scores.
    record = await RecordFactory.create()
    question = await TextQuestionFactory.create(dataset=record.dataset)

    response = await async_client.put(
        self.url(record.id),
        headers=owner_auth_header,
        json={
            "question_id": str(question.id),
            "type": SuggestionType.model,
            "value": "value",
            "agent": "agent",
            "score": [1.0],
        },
    )

    assert response.status_code == 422
    assert response.json() == {
        "detail": "a list of score values is not allowed for a suggestion with a single value"
    }

async def test_upsert_suggestion_with_multiple_values_and_single_score(
    self, async_client: AsyncClient, owner_auth_header: dict
):
    # A list value (three spans) combined with a scalar score must be
    # rejected.
    dataset = await DatasetFactory.create()

    question = await SpanQuestionFactory.create(name="span-question", dataset=dataset)

    record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset)

    response = await async_client.put(
        self.url(record.id),
        headers=owner_auth_header,
        json={
            "question_id": str(question.id),
            "type": SuggestionType.model,
            "value": [
                {"label": "label-a", "start": 0, "end": 1},
                {"label": "label-b", "start": 2, "end": 3},
                {"label": "label-c", "start": 4, "end": 5},
            ],
            "score": 0.5,
        },
    )

    assert response.status_code == 422
    assert response.json() == {
        "detail": "a single score value is not allowed for a suggestion with a multiple items value"
    }

@pytest.mark.parametrize("score", [[1.0], [1.0, 0.5], [1.0, 0.5, 0.9, 0.3]])
async def test_upsert_suggestion_with_list_of_scores_not_matching_values_length(
self, async_client: AsyncClient, owner_auth_header: dict, score: List[float]
@@ -251,6 +342,7 @@ async def test_upsert_suggestion_with_list_of_scores_not_matching_values_length(
)

assert response.status_code == 422
assert response.json() == {"detail": "number of items on value and score attributes doesn't match"}

async def test_upsert_suggestion_for_span_question(
self, async_client: AsyncClient, db: AsyncSession, owner_auth_header: dict

0 comments on commit 3337707

Please sign in to comment.