diff --git a/examples/classify_order.py b/examples/classify_order.py index 5a8332bc6..94ccadacc 100644 --- a/examples/classify_order.py +++ b/examples/classify_order.py @@ -3,9 +3,23 @@ from rapidata.rapidata_client.rapidata_client import RapidataClient from rapidata.rapidata_client.workflow import ClassifyWorkflow from rapidata.rapidata_client.referee import NaiveReferee +from rapidata.rapidata_client.metadata.prompt_metadata import PromptMetadata def new_classify_order(rapi: RapidataClient): + # Validation set + validation_set = ( + rapi.new_validation_set("Example Validation Set") + .add_classify_rapid( + media_path="examples/data/wallaby.jpg", + question="What kind of animal is this?", + categories=["Mammal", "Marsupial", "Bird", "Reptile"], + truths=["Marsupial"], + metadata=[PromptMetadata(prompt="Hint: It has a pouch")], + ) + .create() + ) + # Configure order order = ( rapi.new_order( @@ -20,6 +34,7 @@ def new_classify_order(rapi: RapidataClient): .media(["examples/data/wallaby.jpg"]) .referee(NaiveReferee(required_guesses=15)) .feature_flags(FeatureFlags().alert_on_fast_response(3)) + .validation_set_id(validation_set.id) .create() ) diff --git a/examples/compare_order.py b/examples/compare_order.py index 9a6d39993..3c11558d2 100644 --- a/examples/compare_order.py +++ b/examples/compare_order.py @@ -30,12 +30,12 @@ def new_compare_order(rapi: RapidataClient): ) .referee(NaiveReferee(required_guesses=1)) .media( - media_paths=[[ # this is a list of lists of media paths, since each rapid shows two images + media_paths=[[ # this is a list of lists of paths, since each rapid shows two images "examples/data/rapidata_concept_logo.jpg", "examples/data/rapidata_logo.png", ]] ) - .validation_set(validation_set.id) + .validation_set_id(validation_set.id) .feature_flags(FeatureFlags().claire_design().alert_on_fast_response(4)) .create() ) diff --git a/examples/transcription_order.py b/examples/transcription_order.py index ddfef3735..696dca483 100644 --- a/examples/transcription_order.py +++ b/examples/transcription_order.py @@ -38,7 +38,7 @@ def new_transcription_order(rapi: RapidataClient): .referee(NaiveReferee(required_guesses=30)) .feature_flags(FeatureFlags().alert_on_fast_response(4)) .media(media_paths=["examples/data/waiting.mp4"], metadata=[transcription]) - .validation_set(validation_set.id) + .validation_set_id(validation_set.id) .create() ) diff --git a/rapidata/rapidata_client/dataset/rapidata_validation_set.py b/rapidata/rapidata_client/dataset/rapidata_validation_set.py index ed474a7b8..266c22a19 100644 --- a/rapidata/rapidata_client/dataset/rapidata_validation_set.py +++ b/rapidata/rapidata_client/dataset/rapidata_validation_set.py @@ -14,6 +14,7 @@ from rapidata.api_client.models.classify_payload import ClassifyPayload from rapidata.api_client.models.compare_payload import ComparePayload from rapidata.api_client.models.compare_truth import CompareTruth +from rapidata.api_client.models.datapoint_metadata_model_metadata_inner import DatapointMetadataModelMetadataInner from rapidata.api_client.models.empty_validation_truth import EmptyValidationTruth from rapidata.api_client.models.free_text_payload import FreeTextPayload from rapidata.api_client.models.line_payload import LinePayload @@ -27,6 +28,7 @@ from rapidata.api_client.models.transcription_payload import TranscriptionPayload from rapidata.api_client.models.transcription_truth import TranscriptionTruth from rapidata.rapidata_client.dataset.validation_rapid_parts import ValidatioRapidParts +from rapidata.rapidata_client.metadata.base_metadata import Metadata from rapidata.rapidata_client.types import RapidAsset from rapidata.service.openapi_service import OpenAPIService @@ -61,7 +63,7 @@ def add_validation_rapid( | PolygonTruth | TranscriptionTruth ), - metadata: Any, + metadata: list[Metadata], media_paths: RapidAsset, randomCorrectProbability: float, ): @@ -69,7 +71,7 @@ def add_validation_rapid( validationSetId=self.id, payload=AddValidationRapidModelPayload(payload), truth=AddValidationRapidModelTruth(truths), - metadata=metadata or [], + metadata=[DatapointMetadataModelMetadataInner(meta.to_model()) for meta in metadata], randomCorrectProbability=randomCorrectProbability, ) diff --git a/rapidata/rapidata_client/dataset/validation_rapid_parts.py b/rapidata/rapidata_client/dataset/validation_rapid_parts.py index b9a349fa0..3238d1276 100644 --- a/rapidata/rapidata_client/dataset/validation_rapid_parts.py +++ b/rapidata/rapidata_client/dataset/validation_rapid_parts.py @@ -19,6 +19,7 @@ from rapidata.api_client.models.polygon_truth import PolygonTruth from rapidata.api_client.models.transcription_payload import TranscriptionPayload from rapidata.api_client.models.transcription_truth import TranscriptionTruth +from rapidata.rapidata_client.metadata.base_metadata import Metadata from rapidata.rapidata_client.types import RapidAsset @@ -48,5 +49,5 @@ class ValidatioRapidParts: | PolygonTruth | TranscriptionTruth ) - metadata: Any + metadata: list[Metadata] randomCorrectProbability: float diff --git a/rapidata/rapidata_client/dataset/validation_set_builder.py b/rapidata/rapidata_client/dataset/validation_set_builder.py index 1f6ed1ef9..67675ea71 100644 --- a/rapidata/rapidata_client/dataset/validation_set_builder.py +++ b/rapidata/rapidata_client/dataset/validation_set_builder.py @@ -1,13 +1,4 @@ -from dataclasses import dataclass import os -from typing import Any, Union -from rapidata.api_client.models.add_validation_rapid_model import AddValidationRapidModel -from rapidata.api_client.models.add_validation_rapid_model_payload import ( - AddValidationRapidModelPayload, -) -from rapidata.api_client.models.add_validation_rapid_model_truth import ( - AddValidationRapidModelTruth, -) from rapidata.api_client.models.attach_category_truth import AttachCategoryTruth from rapidata.api_client.models.classify_payload import ClassifyPayload from rapidata.api_client.models.compare_payload import ComparePayload @@ -15,14 +6,14 @@ from rapidata.api_client.models.transcription_payload import TranscriptionPayload from rapidata.api_client.models.transcription_truth import TranscriptionTruth from rapidata.api_client.models.transcription_word import TranscriptionWord -from rapidata.rapidata_client.dataset.rapidata_validation_set import RapidataValidationSet +from rapidata.rapidata_client.dataset.rapidata_validation_set import ( + RapidataValidationSet, +) from rapidata.rapidata_client.dataset.validation_rapid_parts import ValidatioRapidParts +from rapidata.rapidata_client.metadata.base_metadata import Metadata from rapidata.service.openapi_service import OpenAPIService - - - class ValidationSetBuilder: def __init__(self, name: str, openapi_service: OpenAPIService): @@ -42,7 +33,10 @@ def create(self): if self.validation_set_id is None: raise ValueError("Failed to create validation set") - validation_set = RapidataValidationSet(validation_set_id=self.validation_set_id, openapi_service=self.openapi_service) + validation_set = RapidataValidationSet( + validation_set_id=self.validation_set_id, + openapi_service=self.openapi_service, + ) for rapid_part in self._rapid_parts: validation_set.add_validation_rapid( @@ -52,14 +46,19 @@ def create(self): media_paths=rapid_part.media_paths, randomCorrectProbability=rapid_part.randomCorrectProbability, ) - + return validation_set def add_classify_rapid( - self, media_path: str, question: str, categories: list[str], truths: list[str] + self, + media_path: str, + question: str, + categories: list[str], + truths: list[str], + metadata: list[Metadata] = [], ): payload = ClassifyPayload( - _t="ClassifyPaylod", possibleCategories=categories, title=question + _t="ClassifyPayload", possibleCategories=categories, title=question ) model_truth = AttachCategoryTruth( correctCategories=truths, _t="AttachCategoryTruth" @@ -71,14 +70,20 @@ def add_classify_rapid( media_paths=media_path, payload=payload, truths=model_truth, - metadata=None, + metadata=metadata, randomCorrectProbability=len(truths) / len(categories), ) ) return self - def add_compare_rapid(self, media_paths: list[str], question: str, truth: str): + def add_compare_rapid( + self, + media_paths: list[str], + question: str, + truth: str, + metadata: list[Metadata] = [], + ): payload = ComparePayload(_t="ComparePayload", criteria=question) # take only last part of truth path truth = os.path.basename(truth) @@ -91,7 +96,6 @@ def add_compare_rapid(self, media_paths: list[str], question: str, truth: str): for media_path in media_paths: if not os.path.exists(media_path): raise FileNotFoundError(f"File not found: {media_path}") - self._rapid_parts.append( ValidatioRapidParts( @@ -99,7 +103,7 @@ def add_compare_rapid(self, media_paths: list[str], question: str, truth: str): media_paths=media_paths, payload=payload, truths=model_truth, - metadata=None, + metadata=metadata, randomCorrectProbability=1 / len(media_paths), ) ) @@ -113,6 +117,7 @@ def add_transcription_rapid( transcription: list[str], correct_words: list[str], strict_grading: bool | None = None, + metadata: list[Metadata] = [], ): transcription_words = [ TranscriptionWord(word=word, wordIndex=i) @@ -143,7 +148,7 @@ def add_transcription_rapid( media_paths=media_path, payload=payload, truths=model_truth, - metadata=None, + metadata=metadata, randomCorrectProbability=1 / len(transcription), ) ) diff --git a/rapidata/rapidata_client/order/rapidata_order_builder.py b/rapidata/rapidata_client/order/rapidata_order_builder.py index 2b5ee37d5..93ee114ea 100644 --- a/rapidata/rapidata_client/order/rapidata_order_builder.py +++ b/rapidata/rapidata_client/order/rapidata_order_builder.py @@ -186,7 +186,7 @@ def aggregator(self, aggregator: AggregatorType): self._aggregator = aggregator return self - def validation_set(self, validation_set_id: str): + def validation_set_id(self, validation_set_id: str): """ Set the validation set for the order.