Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/compare_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def new_compare_order(rapi: RapidataClient):
concept_path = "examples/data/rapidata_concept_logo.jpg"

# configure validation set
validation_set_id = rapi.new_validation_set(
validation_set = rapi.new_validation_set(
name="Example SimpleMatchup Validation Set"
).add_compare_rapid(
media_paths=[logo_path, concept_path],
Expand All @@ -35,7 +35,7 @@ def new_compare_order(rapi: RapidataClient):
"examples/data/rapidata_logo.png",
]]
)
.validation_set(validation_set_id)
.validation_set(validation_set.id)
.feature_flags(FeatureFlags().claire_design().alert_on_fast_response(4))
.create()
)
Expand Down
4 changes: 2 additions & 2 deletions examples/transcription_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def new_transcription_order(rapi: RapidataClient):
validation_set_id = (
validation_set = (
rapi.new_validation_set(
name="Example Transcription Validation Set"
).add_transcription_rapid(
Expand All @@ -38,7 +38,7 @@ def new_transcription_order(rapi: RapidataClient):
.referee(NaiveReferee(required_guesses=30))
.feature_flags(FeatureFlags().alert_on_fast_response(4))
.media(media_paths=["examples/data/waiting.mp4"], metadata=[transcription])
.validation_set(validation_set_id)
.validation_set(validation_set.id)
.create()
)

Expand Down
78 changes: 78 additions & 0 deletions rapidata/rapidata_client/dataset/rapidata_validation_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Any
from rapidata.api_client.models.add_validation_rapid_model import (
AddValidationRapidModel,
)
from rapidata.api_client.models.add_validation_rapid_model_payload import (
AddValidationRapidModelPayload,
)
from rapidata.api_client.models.add_validation_rapid_model_truth import (
AddValidationRapidModelTruth,
)
from rapidata.api_client.models.attach_category_truth import AttachCategoryTruth
from rapidata.api_client.models.bounding_box_payload import BoundingBoxPayload
from rapidata.api_client.models.bounding_box_truth import BoundingBoxTruth
from rapidata.api_client.models.classify_payload import ClassifyPayload
from rapidata.api_client.models.compare_payload import ComparePayload
from rapidata.api_client.models.compare_truth import CompareTruth
from rapidata.api_client.models.empty_validation_truth import EmptyValidationTruth
from rapidata.api_client.models.free_text_payload import FreeTextPayload
from rapidata.api_client.models.line_payload import LinePayload
from rapidata.api_client.models.line_truth import LineTruth
from rapidata.api_client.models.locate_box_truth import LocateBoxTruth
from rapidata.api_client.models.locate_payload import LocatePayload
from rapidata.api_client.models.named_entity_payload import NamedEntityPayload
from rapidata.api_client.models.named_entity_truth import NamedEntityTruth
from rapidata.api_client.models.polygon_payload import PolygonPayload
from rapidata.api_client.models.polygon_truth import PolygonTruth
from rapidata.api_client.models.transcription_payload import TranscriptionPayload
from rapidata.api_client.models.transcription_truth import TranscriptionTruth
from rapidata.rapidata_client.dataset.validation_rapid_parts import ValidatioRapidParts
from rapidata.rapidata_client.types import RapidAsset
from rapidata.service.openapi_service import OpenAPIService


class RapidataValidationSet:

def __init__(self, validation_set_id, openapi_service: OpenAPIService):
self.id = validation_set_id
self.openapi_service = openapi_service

def add_validation_rapid(
self,
payload: (
BoundingBoxPayload
| ClassifyPayload
| ComparePayload
| FreeTextPayload
| LinePayload
| LocatePayload
| NamedEntityPayload
| PolygonPayload
| TranscriptionPayload
),
truths: (
AttachCategoryTruth
| BoundingBoxTruth
| CompareTruth
| EmptyValidationTruth
| LineTruth
| LocateBoxTruth
| NamedEntityTruth
| PolygonTruth
| TranscriptionTruth
),
metadata: Any,
media_paths: RapidAsset,
randomCorrectProbability: float,
):
model = AddValidationRapidModel(
validationSetId=self.id,
payload=AddValidationRapidModelPayload(payload),
truth=AddValidationRapidModelTruth(truths),
metadata=metadata or [],
randomCorrectProbability=randomCorrectProbability,
)

self.openapi_service.validation_api.validation_add_validation_rapid_post(
model=model, files=media_paths if isinstance(media_paths, list) else [media_paths] # type: ignore
)
52 changes: 52 additions & 0 deletions rapidata/rapidata_client/dataset/validation_rapid_parts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from dataclasses import dataclass
from typing import Any

from rapidata.api_client.models.attach_category_truth import AttachCategoryTruth
from rapidata.api_client.models.bounding_box_payload import BoundingBoxPayload
from rapidata.api_client.models.bounding_box_truth import BoundingBoxTruth
from rapidata.api_client.models.classify_payload import ClassifyPayload
from rapidata.api_client.models.compare_payload import ComparePayload
from rapidata.api_client.models.compare_truth import CompareTruth
from rapidata.api_client.models.empty_validation_truth import EmptyValidationTruth
from rapidata.api_client.models.free_text_payload import FreeTextPayload
from rapidata.api_client.models.line_payload import LinePayload
from rapidata.api_client.models.line_truth import LineTruth
from rapidata.api_client.models.locate_box_truth import LocateBoxTruth
from rapidata.api_client.models.locate_payload import LocatePayload
from rapidata.api_client.models.named_entity_payload import NamedEntityPayload
from rapidata.api_client.models.named_entity_truth import NamedEntityTruth
from rapidata.api_client.models.polygon_payload import PolygonPayload
from rapidata.api_client.models.polygon_truth import PolygonTruth
from rapidata.api_client.models.transcription_payload import TranscriptionPayload
from rapidata.api_client.models.transcription_truth import TranscriptionTruth
from rapidata.rapidata_client.types import RapidAsset


@dataclass
class ValidatioRapidParts:
question: str
media_paths: RapidAsset
payload: (
BoundingBoxPayload
| ClassifyPayload
| ComparePayload
| FreeTextPayload
| LinePayload
| LocatePayload
| NamedEntityPayload
| PolygonPayload
| TranscriptionPayload
)
truths: (
AttachCategoryTruth
| BoundingBoxTruth
| CompareTruth
| EmptyValidationTruth
| LineTruth
| LocateBoxTruth
| NamedEntityTruth
| PolygonTruth
| TranscriptionTruth
)
metadata: Any
randomCorrectProbability: float
Original file line number Diff line number Diff line change
Expand Up @@ -9,55 +9,18 @@
AddValidationRapidModelTruth,
)
from rapidata.api_client.models.attach_category_truth import AttachCategoryTruth
from rapidata.api_client.models.bounding_box_payload import BoundingBoxPayload
from rapidata.api_client.models.bounding_box_truth import BoundingBoxTruth
from rapidata.api_client.models.classify_payload import ClassifyPayload
from rapidata.api_client.models.compare_payload import ComparePayload
from rapidata.api_client.models.compare_truth import CompareTruth
from rapidata.api_client.models.empty_validation_truth import EmptyValidationTruth
from rapidata.api_client.models.free_text_payload import FreeTextPayload
from rapidata.api_client.models.line_payload import LinePayload
from rapidata.api_client.models.line_truth import LineTruth
from rapidata.api_client.models.locate_box_truth import LocateBoxTruth
from rapidata.api_client.models.locate_payload import LocatePayload
from rapidata.api_client.models.named_entity_payload import NamedEntityPayload
from rapidata.api_client.models.named_entity_truth import NamedEntityTruth
from rapidata.api_client.models.polygon_payload import PolygonPayload
from rapidata.api_client.models.polygon_truth import PolygonTruth
from rapidata.api_client.models.transcription_payload import TranscriptionPayload
from rapidata.api_client.models.transcription_truth import TranscriptionTruth
from rapidata.api_client.models.transcription_word import TranscriptionWord
from rapidata.rapidata_client.dataset.rapidata_validation_set import RapidataValidationSet
from rapidata.rapidata_client.dataset.validation_rapid_parts import ValidatioRapidParts
from rapidata.service.openapi_service import OpenAPIService


@dataclass
class ValidatioRapidParts:
question: str
media_paths: str | list[str]
payload: Union[
BoundingBoxPayload,
ClassifyPayload,
ComparePayload,
FreeTextPayload,
LinePayload,
LocatePayload,
NamedEntityPayload,
PolygonPayload,
TranscriptionPayload,
]
truths: Union[
AttachCategoryTruth,
BoundingBoxTruth,
CompareTruth,
EmptyValidationTruth,
LineTruth,
LocateBoxTruth,
NamedEntityTruth,
PolygonTruth,
TranscriptionTruth,
]
metadata: Any
randomCorrectProbability: float



class ValidationSetBuilder:
Expand All @@ -79,20 +42,18 @@ def create(self):
if self.validation_set_id is None:
raise ValueError("Failed to create validation set")

validation_set = RapidataValidationSet(validation_set_id=self.validation_set_id, openapi_service=self.openapi_service)

for rapid_part in self._rapid_parts:
model = AddValidationRapidModel(
validationSetId=self.validation_set_id,
payload=AddValidationRapidModelPayload(rapid_part.payload),
truth=AddValidationRapidModelTruth(rapid_part.truths),
metadata=rapid_part.metadata or [],
validation_set.add_validation_rapid(
payload=rapid_part.payload,
truths=rapid_part.truths,
metadata=rapid_part.metadata,
media_paths=rapid_part.media_paths,
randomCorrectProbability=rapid_part.randomCorrectProbability,
)

self.openapi_service.validation_api.validation_add_validation_rapid_post(
model=model, files=rapid_part.media_paths if isinstance(rapid_part.media_paths, list) else [rapid_part.media_paths] # type: ignore
)

return str(self.validation_set_id)

return validation_set

def add_classify_rapid(
self, media_path: str, question: str, categories: list[str], truths: list[str]
Expand Down
2 changes: 1 addition & 1 deletion rapidata/rapidata_client/order/rapidata_order.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
from rapidata.api_client.models.create_order_model_referee import CreateOrderModelReferee
from rapidata.api_client.models.create_order_model_workflow import CreateOrderModelWorkflow
from rapidata.rapidata_client.order.dataset.rapidata_dataset import RapidataDataset
from rapidata.rapidata_client.dataset.rapidata_dataset import RapidataDataset
from rapidata.rapidata_client.workflow import Workflow
from rapidata.api_client.models.create_order_model import CreateOrderModel
from rapidata.rapidata_client.referee import Referee
Expand Down
2 changes: 1 addition & 1 deletion rapidata/rapidata_client/order/rapidata_order_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)
from rapidata.rapidata_client.feature_flags import FeatureFlags
from rapidata.rapidata_client.metadata.base_metadata import Metadata
from rapidata.rapidata_client.order.dataset.rapidata_dataset import RapidataDataset
from rapidata.rapidata_client.dataset.rapidata_dataset import RapidataDataset
from rapidata.rapidata_client.referee.naive_referee import NaiveReferee
from rapidata.rapidata_client.types import RapidAsset
from rapidata.rapidata_client.workflow import Workflow
Expand Down
18 changes: 14 additions & 4 deletions rapidata/rapidata_client/rapidata_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from rapidata.rapidata_client.order.dataset.validation_set_builder import ValidationSetBuilder
from rapidata.rapidata_client.dataset.rapidata_validation_set import (
RapidataValidationSet,
)
from rapidata.rapidata_client.dataset.validation_set_builder import ValidationSetBuilder
from rapidata.rapidata_client.order.rapidata_order_builder import RapidataOrderBuilder
from rapidata.service.openapi_service import OpenAPIService


class RapidataClient:
"""
A client for interacting with the Rapidata API.
Expand All @@ -24,8 +28,6 @@ def __init__(
client_id=client_id, client_secret=client_secret, endpoint=endpoint
)



def new_order(self, name: str) -> RapidataOrderBuilder:
"""
Create a new order using a RapidataOrderBuilder instance.
Expand All @@ -34,7 +36,6 @@ def new_order(self, name: str) -> RapidataOrderBuilder:
:return: A RapidataOrderBuilder instance.
"""
return RapidataOrderBuilder(openapi_service=self.openapi_service, name=name)


def new_validation_set(self, name: str) -> ValidationSetBuilder:
"""
Expand All @@ -44,3 +45,12 @@ def new_validation_set(self, name: str) -> ValidationSetBuilder:
:return: A ValidationDatasetBuilder instance.
"""
return ValidationSetBuilder(name=name, openapi_service=self.openapi_service)

def get_validation_set(self, validation_set_id: str) -> RapidataValidationSet:
"""
Get a validation set by ID.

:param validation_set_id: The ID of the validation set.
:return: The ValidationSet instance.
"""
return RapidataValidationSet(validation_set_id, self.openapi_service)