diff --git a/examples/compare_order.py b/examples/compare_order.py index c1f747463..9a6d39993 100644 --- a/examples/compare_order.py +++ b/examples/compare_order.py @@ -10,7 +10,7 @@ def new_compare_order(rapi: RapidataClient): concept_path = "examples/data/rapidata_concept_logo.jpg" # configure validation set - validation_set_id = rapi.new_validation_set( + validation_set = rapi.new_validation_set( name="Example SimpleMatchup Validation Set" ).add_compare_rapid( media_paths=[logo_path, concept_path], @@ -35,7 +35,7 @@ def new_compare_order(rapi: RapidataClient): "examples/data/rapidata_logo.png", ]] ) - .validation_set(validation_set_id) + .validation_set(validation_set.id) .feature_flags(FeatureFlags().claire_design().alert_on_fast_response(4)) .create() ) diff --git a/examples/transcription_order.py b/examples/transcription_order.py index 322cfce38..ddfef3735 100644 --- a/examples/transcription_order.py +++ b/examples/transcription_order.py @@ -12,7 +12,7 @@ def new_transcription_order(rapi: RapidataClient): - validation_set_id = ( + validation_set = ( rapi.new_validation_set( name="Example Transcription Validation Set" ).add_transcription_rapid( @@ -38,7 +38,7 @@ def new_transcription_order(rapi: RapidataClient): .referee(NaiveReferee(required_guesses=30)) .feature_flags(FeatureFlags().alert_on_fast_response(4)) .media(media_paths=["examples/data/waiting.mp4"], metadata=[transcription]) - .validation_set(validation_set_id) + .validation_set(validation_set.id) .create() ) diff --git a/rapidata/rapidata_client/order/dataset/__init__.py b/rapidata/rapidata_client/dataset/__init__.py similarity index 100% rename from rapidata/rapidata_client/order/dataset/__init__.py rename to rapidata/rapidata_client/dataset/__init__.py diff --git a/rapidata/rapidata_client/order/dataset/rapidata_dataset.py b/rapidata/rapidata_client/dataset/rapidata_dataset.py similarity index 100% rename from rapidata/rapidata_client/order/dataset/rapidata_dataset.py rename to rapidata/rapidata_client/dataset/rapidata_dataset.py diff --git a/rapidata/rapidata_client/dataset/rapidata_validation_set.py b/rapidata/rapidata_client/dataset/rapidata_validation_set.py new file mode 100644 index 000000000..ed474a7b8 --- /dev/null +++ b/rapidata/rapidata_client/dataset/rapidata_validation_set.py @@ -0,0 +1,78 @@ +from typing import Any +from rapidata.api_client.models.add_validation_rapid_model import ( + AddValidationRapidModel, +) +from rapidata.api_client.models.add_validation_rapid_model_payload import ( + AddValidationRapidModelPayload, +) +from rapidata.api_client.models.add_validation_rapid_model_truth import ( + AddValidationRapidModelTruth, +) +from rapidata.api_client.models.attach_category_truth import AttachCategoryTruth +from rapidata.api_client.models.bounding_box_payload import BoundingBoxPayload +from rapidata.api_client.models.bounding_box_truth import BoundingBoxTruth +from rapidata.api_client.models.classify_payload import ClassifyPayload +from rapidata.api_client.models.compare_payload import ComparePayload +from rapidata.api_client.models.compare_truth import CompareTruth +from rapidata.api_client.models.empty_validation_truth import EmptyValidationTruth +from rapidata.api_client.models.free_text_payload import FreeTextPayload +from rapidata.api_client.models.line_payload import LinePayload +from rapidata.api_client.models.line_truth import LineTruth +from rapidata.api_client.models.locate_box_truth import LocateBoxTruth +from rapidata.api_client.models.locate_payload import LocatePayload +from rapidata.api_client.models.named_entity_payload import NamedEntityPayload +from rapidata.api_client.models.named_entity_truth import NamedEntityTruth +from rapidata.api_client.models.polygon_payload import PolygonPayload +from rapidata.api_client.models.polygon_truth import PolygonTruth +from rapidata.api_client.models.transcription_payload import TranscriptionPayload +from rapidata.api_client.models.transcription_truth import TranscriptionTruth +from rapidata.rapidata_client.dataset.validation_rapid_parts import ValidatioRapidParts +from rapidata.rapidata_client.types import RapidAsset +from rapidata.service.openapi_service import OpenAPIService + + +class RapidataValidationSet: + + def __init__(self, validation_set_id, openapi_service: OpenAPIService): + self.id = validation_set_id + self.openapi_service = openapi_service + + def add_validation_rapid( + self, + payload: ( + BoundingBoxPayload + | ClassifyPayload + | ComparePayload + | FreeTextPayload + | LinePayload + | LocatePayload + | NamedEntityPayload + | PolygonPayload + | TranscriptionPayload + ), + truths: ( + AttachCategoryTruth + | BoundingBoxTruth + | CompareTruth + | EmptyValidationTruth + | LineTruth + | LocateBoxTruth + | NamedEntityTruth + | PolygonTruth + | TranscriptionTruth + ), + metadata: Any, + media_paths: RapidAsset, + randomCorrectProbability: float, + ): + model = AddValidationRapidModel( + validationSetId=self.id, + payload=AddValidationRapidModelPayload(payload), + truth=AddValidationRapidModelTruth(truths), + metadata=metadata or [], + randomCorrectProbability=randomCorrectProbability, + ) + + self.openapi_service.validation_api.validation_add_validation_rapid_post( + model=model, files=media_paths if isinstance(media_paths, list) else [media_paths] # type: ignore + ) diff --git a/rapidata/rapidata_client/dataset/validation_rapid_parts.py b/rapidata/rapidata_client/dataset/validation_rapid_parts.py new file mode 100644 index 000000000..b9a349fa0 --- /dev/null +++ b/rapidata/rapidata_client/dataset/validation_rapid_parts.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import Any + +from rapidata.api_client.models.attach_category_truth import AttachCategoryTruth +from rapidata.api_client.models.bounding_box_payload import BoundingBoxPayload +from rapidata.api_client.models.bounding_box_truth import BoundingBoxTruth +from rapidata.api_client.models.classify_payload import ClassifyPayload +from rapidata.api_client.models.compare_payload import ComparePayload +from rapidata.api_client.models.compare_truth import CompareTruth +from rapidata.api_client.models.empty_validation_truth import EmptyValidationTruth +from rapidata.api_client.models.free_text_payload import FreeTextPayload +from rapidata.api_client.models.line_payload import LinePayload +from rapidata.api_client.models.line_truth import LineTruth +from rapidata.api_client.models.locate_box_truth import LocateBoxTruth +from rapidata.api_client.models.locate_payload import LocatePayload +from rapidata.api_client.models.named_entity_payload import NamedEntityPayload +from rapidata.api_client.models.named_entity_truth import NamedEntityTruth +from rapidata.api_client.models.polygon_payload import PolygonPayload +from rapidata.api_client.models.polygon_truth import PolygonTruth +from rapidata.api_client.models.transcription_payload import TranscriptionPayload +from rapidata.api_client.models.transcription_truth import TranscriptionTruth +from rapidata.rapidata_client.types import RapidAsset + + +@dataclass +class ValidatioRapidParts: + question: str + media_paths: RapidAsset + payload: ( + BoundingBoxPayload + | ClassifyPayload + | ComparePayload + | FreeTextPayload + | LinePayload + | LocatePayload + | NamedEntityPayload + | PolygonPayload + | TranscriptionPayload + ) + truths: ( + AttachCategoryTruth + | BoundingBoxTruth + | CompareTruth + | EmptyValidationTruth + | LineTruth + | LocateBoxTruth + | NamedEntityTruth + | PolygonTruth + | TranscriptionTruth + ) + metadata: Any + randomCorrectProbability: float diff --git a/rapidata/rapidata_client/order/dataset/validation_set_builder.py b/rapidata/rapidata_client/dataset/validation_set_builder.py similarity index 69% rename from rapidata/rapidata_client/order/dataset/validation_set_builder.py rename to rapidata/rapidata_client/dataset/validation_set_builder.py index a886c8382..1f6ed1ef9 100644 --- a/rapidata/rapidata_client/order/dataset/validation_set_builder.py +++ b/rapidata/rapidata_client/dataset/validation_set_builder.py @@ -9,55 +9,18 @@ AddValidationRapidModelTruth, ) from rapidata.api_client.models.attach_category_truth import AttachCategoryTruth -from rapidata.api_client.models.bounding_box_payload import BoundingBoxPayload -from rapidata.api_client.models.bounding_box_truth import BoundingBoxTruth from rapidata.api_client.models.classify_payload import ClassifyPayload from rapidata.api_client.models.compare_payload import ComparePayload from rapidata.api_client.models.compare_truth import CompareTruth -from rapidata.api_client.models.empty_validation_truth import EmptyValidationTruth -from rapidata.api_client.models.free_text_payload import FreeTextPayload -from rapidata.api_client.models.line_payload import LinePayload -from rapidata.api_client.models.line_truth import LineTruth -from rapidata.api_client.models.locate_box_truth import LocateBoxTruth -from rapidata.api_client.models.locate_payload import LocatePayload -from rapidata.api_client.models.named_entity_payload import NamedEntityPayload -from rapidata.api_client.models.named_entity_truth import NamedEntityTruth -from rapidata.api_client.models.polygon_payload import PolygonPayload -from rapidata.api_client.models.polygon_truth import PolygonTruth from rapidata.api_client.models.transcription_payload import TranscriptionPayload from rapidata.api_client.models.transcription_truth import TranscriptionTruth from rapidata.api_client.models.transcription_word import TranscriptionWord +from rapidata.rapidata_client.dataset.rapidata_validation_set import RapidataValidationSet +from rapidata.rapidata_client.dataset.validation_rapid_parts import ValidatioRapidParts from rapidata.service.openapi_service import OpenAPIService -@dataclass -class ValidatioRapidParts: - question: str - media_paths: str | list[str] - payload: Union[ - BoundingBoxPayload, - ClassifyPayload, - ComparePayload, - FreeTextPayload, - LinePayload, - LocatePayload, - NamedEntityPayload, - PolygonPayload, - TranscriptionPayload, - ] - truths: Union[ - AttachCategoryTruth, - BoundingBoxTruth, - CompareTruth, - EmptyValidationTruth, - LineTruth, - LocateBoxTruth, - NamedEntityTruth, - PolygonTruth, - TranscriptionTruth, - ] - metadata: Any - randomCorrectProbability: float + class ValidationSetBuilder: @@ -79,20 +42,18 @@ def create(self): if self.validation_set_id is None: raise ValueError("Failed to create validation set") + validation_set = RapidataValidationSet(validation_set_id=self.validation_set_id, openapi_service=self.openapi_service) + for rapid_part in self._rapid_parts: - model = AddValidationRapidModel( - validationSetId=self.validation_set_id, - payload=AddValidationRapidModelPayload(rapid_part.payload), - truth=AddValidationRapidModelTruth(rapid_part.truths), - metadata=rapid_part.metadata or [], + validation_set.add_validation_rapid( + payload=rapid_part.payload, + truths=rapid_part.truths, + metadata=rapid_part.metadata, + media_paths=rapid_part.media_paths, randomCorrectProbability=rapid_part.randomCorrectProbability, ) - - self.openapi_service.validation_api.validation_add_validation_rapid_post( - model=model, files=rapid_part.media_paths if isinstance(rapid_part.media_paths, list) else [rapid_part.media_paths] # type: ignore - ) - - return str(self.validation_set_id) + + return validation_set def add_classify_rapid( self, media_path: str, question: str, categories: list[str], truths: list[str] diff --git a/rapidata/rapidata_client/order/rapidata_order.py b/rapidata/rapidata_client/order/rapidata_order.py index ead8f0867..09f9757ed 100644 --- a/rapidata/rapidata_client/order/rapidata_order.py +++ b/rapidata/rapidata_client/order/rapidata_order.py @@ -1,7 +1,7 @@ import time from rapidata.api_client.models.create_order_model_referee import CreateOrderModelReferee from rapidata.api_client.models.create_order_model_workflow import CreateOrderModelWorkflow -from rapidata.rapidata_client.order.dataset.rapidata_dataset import RapidataDataset +from rapidata.rapidata_client.dataset.rapidata_dataset import RapidataDataset from rapidata.rapidata_client.workflow import Workflow from rapidata.api_client.models.create_order_model import CreateOrderModel from rapidata.rapidata_client.referee import Referee diff --git a/rapidata/rapidata_client/order/rapidata_order_builder.py b/rapidata/rapidata_client/order/rapidata_order_builder.py index 4ad2f4537..06820ebcc 100644 --- a/rapidata/rapidata_client/order/rapidata_order_builder.py +++ b/rapidata/rapidata_client/order/rapidata_order_builder.py @@ -8,7 +8,7 @@ ) from rapidata.rapidata_client.feature_flags import FeatureFlags from rapidata.rapidata_client.metadata.base_metadata import Metadata -from rapidata.rapidata_client.order.dataset.rapidata_dataset import RapidataDataset +from rapidata.rapidata_client.dataset.rapidata_dataset import RapidataDataset from rapidata.rapidata_client.referee.naive_referee import NaiveReferee from rapidata.rapidata_client.types import RapidAsset from rapidata.rapidata_client.workflow import Workflow diff --git a/rapidata/rapidata_client/rapidata_client.py b/rapidata/rapidata_client/rapidata_client.py index 52a4c89cf..eeccdd553 100644 --- a/rapidata/rapidata_client/rapidata_client.py +++ b/rapidata/rapidata_client/rapidata_client.py @@ -1,7 +1,11 @@ -from rapidata.rapidata_client.order.dataset.validation_set_builder import ValidationSetBuilder +from rapidata.rapidata_client.dataset.rapidata_validation_set import ( + RapidataValidationSet, +) +from rapidata.rapidata_client.dataset.validation_set_builder import ValidationSetBuilder from rapidata.rapidata_client.order.rapidata_order_builder import RapidataOrderBuilder from rapidata.service.openapi_service import OpenAPIService + class RapidataClient: """ A client for interacting with the Rapidata API. @@ -24,8 +28,6 @@ def __init__( client_id=client_id, client_secret=client_secret, endpoint=endpoint ) - - def new_order(self, name: str) -> RapidataOrderBuilder: """ Create a new order using a RapidataOrderBuilder instance. @@ -34,7 +36,6 @@ def new_order(self, name: str) -> RapidataOrderBuilder: :return: A RapidataOrderBuilder instance. """ return RapidataOrderBuilder(openapi_service=self.openapi_service, name=name) - def new_validation_set(self, name: str) -> ValidationSetBuilder: """ @@ -44,3 +45,12 @@ def new_validation_set(self, name: str) -> ValidationSetBuilder: :return: A ValidationDatasetBuilder instance. """ return ValidationSetBuilder(name=name, openapi_service=self.openapi_service) + + def get_validation_set(self, validation_set_id: str) -> RapidataValidationSet: + """ + Get a validation set by ID. + + :param validation_set_id: The ID of the validation set. + :return: The ValidationSet instance. + """ + return RapidataValidationSet(validation_set_id, self.openapi_service)