Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions examples/classify_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,23 @@
from rapidata.rapidata_client.rapidata_client import RapidataClient
from rapidata.rapidata_client.workflow import ClassifyWorkflow
from rapidata.rapidata_client.referee import NaiveReferee
from rapidata.rapidata_client.metadata.prompt_metadata import PromptMetadata


def new_classify_order(rapi: RapidataClient):
# Validation set
validation_set = (
rapi.new_validation_set("Example Validation Set")
.add_classify_rapid(
media_path="examples/data/wallaby.jpg",
question="What kind of animal is this?",
categories=["Mammal", "Marsupial", "Bird", "Reptile"],
truths=["Marsupial"],
metadata=[PromptMetadata(prompt="Hint: It has a pouch")],
)
.create()
)

# Configure order
order = (
rapi.new_order(
Expand All @@ -20,6 +34,7 @@ def new_classify_order(rapi: RapidataClient):
.media(["examples/data/wallaby.jpg"])
.referee(NaiveReferee(required_guesses=15))
.feature_flags(FeatureFlags().alert_on_fast_response(3))
.validation_set_id(validation_set.id)
.create()
)

Expand Down
4 changes: 2 additions & 2 deletions examples/compare_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ def new_compare_order(rapi: RapidataClient):
)
.referee(NaiveReferee(required_guesses=1))
.media(
media_paths=[[ # this is a list of lists of media paths, since each rapid shows two images
media_paths=[[ # this is a list of lists of paths, since each rapid shows two images
"examples/data/rapidata_concept_logo.jpg",
"examples/data/rapidata_logo.png",
]]
)
.validation_set(validation_set.id)
.validation_set_id(validation_set.id)
.feature_flags(FeatureFlags().claire_design().alert_on_fast_response(4))
.create()
)
Expand Down
2 changes: 1 addition & 1 deletion examples/transcription_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def new_transcription_order(rapi: RapidataClient):
.referee(NaiveReferee(required_guesses=30))
.feature_flags(FeatureFlags().alert_on_fast_response(4))
.media(media_paths=["examples/data/waiting.mp4"], metadata=[transcription])
.validation_set(validation_set.id)
.validation_set_id(validation_set.id)
.create()
)

Expand Down
6 changes: 4 additions & 2 deletions rapidata/rapidata_client/dataset/rapidata_validation_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from rapidata.api_client.models.classify_payload import ClassifyPayload
from rapidata.api_client.models.compare_payload import ComparePayload
from rapidata.api_client.models.compare_truth import CompareTruth
from rapidata.api_client.models.datapoint_metadata_model_metadata_inner import DatapointMetadataModelMetadataInner
from rapidata.api_client.models.empty_validation_truth import EmptyValidationTruth
from rapidata.api_client.models.free_text_payload import FreeTextPayload
from rapidata.api_client.models.line_payload import LinePayload
Expand All @@ -27,6 +28,7 @@
from rapidata.api_client.models.transcription_payload import TranscriptionPayload
from rapidata.api_client.models.transcription_truth import TranscriptionTruth
from rapidata.rapidata_client.dataset.validation_rapid_parts import ValidatioRapidParts
from rapidata.rapidata_client.metadata.base_metadata import Metadata
from rapidata.rapidata_client.types import RapidAsset
from rapidata.service.openapi_service import OpenAPIService

Expand Down Expand Up @@ -61,15 +63,15 @@ def add_validation_rapid(
| PolygonTruth
| TranscriptionTruth
),
metadata: Any,
metadata: list[Metadata],
media_paths: RapidAsset,
randomCorrectProbability: float,
):
model = AddValidationRapidModel(
validationSetId=self.id,
payload=AddValidationRapidModelPayload(payload),
truth=AddValidationRapidModelTruth(truths),
metadata=metadata or [],
metadata=[DatapointMetadataModelMetadataInner(meta.to_model()) for meta in metadata],
randomCorrectProbability=randomCorrectProbability,
)

Expand Down
3 changes: 2 additions & 1 deletion rapidata/rapidata_client/dataset/validation_rapid_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from rapidata.api_client.models.polygon_truth import PolygonTruth
from rapidata.api_client.models.transcription_payload import TranscriptionPayload
from rapidata.api_client.models.transcription_truth import TranscriptionTruth
from rapidata.rapidata_client.metadata.base_metadata import Metadata
from rapidata.rapidata_client.types import RapidAsset


Expand Down Expand Up @@ -48,5 +49,5 @@ class ValidatioRapidParts:
| PolygonTruth
| TranscriptionTruth
)
metadata: Any
metadata: list[Metadata]
randomCorrectProbability: float
49 changes: 27 additions & 22 deletions rapidata/rapidata_client/dataset/validation_set_builder.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,19 @@
from dataclasses import dataclass
import os
from typing import Any, Union
from rapidata.api_client.models.add_validation_rapid_model import AddValidationRapidModel
from rapidata.api_client.models.add_validation_rapid_model_payload import (
AddValidationRapidModelPayload,
)
from rapidata.api_client.models.add_validation_rapid_model_truth import (
AddValidationRapidModelTruth,
)
from rapidata.api_client.models.attach_category_truth import AttachCategoryTruth
from rapidata.api_client.models.classify_payload import ClassifyPayload
from rapidata.api_client.models.compare_payload import ComparePayload
from rapidata.api_client.models.compare_truth import CompareTruth
from rapidata.api_client.models.transcription_payload import TranscriptionPayload
from rapidata.api_client.models.transcription_truth import TranscriptionTruth
from rapidata.api_client.models.transcription_word import TranscriptionWord
from rapidata.rapidata_client.dataset.rapidata_validation_set import RapidataValidationSet
from rapidata.rapidata_client.dataset.rapidata_validation_set import (
RapidataValidationSet,
)
from rapidata.rapidata_client.dataset.validation_rapid_parts import ValidatioRapidParts
from rapidata.rapidata_client.metadata.base_metadata import Metadata
from rapidata.service.openapi_service import OpenAPIService





class ValidationSetBuilder:

def __init__(self, name: str, openapi_service: OpenAPIService):
Expand All @@ -42,7 +33,10 @@ def create(self):
if self.validation_set_id is None:
raise ValueError("Failed to create validation set")

validation_set = RapidataValidationSet(validation_set_id=self.validation_set_id, openapi_service=self.openapi_service)
validation_set = RapidataValidationSet(
validation_set_id=self.validation_set_id,
openapi_service=self.openapi_service,
)

for rapid_part in self._rapid_parts:
validation_set.add_validation_rapid(
Expand All @@ -52,14 +46,19 @@ def create(self):
media_paths=rapid_part.media_paths,
randomCorrectProbability=rapid_part.randomCorrectProbability,
)

return validation_set

def add_classify_rapid(
self, media_path: str, question: str, categories: list[str], truths: list[str]
self,
media_path: str,
question: str,
categories: list[str],
truths: list[str],
metadata: list[Metadata] = [],
):
payload = ClassifyPayload(
_t="ClassifyPaylod", possibleCategories=categories, title=question
_t="ClassifyPayload", possibleCategories=categories, title=question
)
model_truth = AttachCategoryTruth(
correctCategories=truths, _t="AttachCategoryTruth"
Expand All @@ -71,14 +70,20 @@ def add_classify_rapid(
media_paths=media_path,
payload=payload,
truths=model_truth,
metadata=None,
metadata=metadata,
randomCorrectProbability=len(truths) / len(categories),
)
)

return self

def add_compare_rapid(self, media_paths: list[str], question: str, truth: str):
def add_compare_rapid(
self,
media_paths: list[str],
question: str,
truth: str,
metadata: list[Metadata] = [],
):
payload = ComparePayload(_t="ComparePayload", criteria=question)
# take only last part of truth path
truth = os.path.basename(truth)
Expand All @@ -91,15 +96,14 @@ def add_compare_rapid(self, media_paths: list[str], question: str, truth: str):
for media_path in media_paths:
if not os.path.exists(media_path):
raise FileNotFoundError(f"File not found: {media_path}")


self._rapid_parts.append(
ValidatioRapidParts(
question=question,
media_paths=media_paths,
payload=payload,
truths=model_truth,
metadata=None,
metadata=metadata,
randomCorrectProbability=1 / len(media_paths),
)
)
Expand All @@ -113,6 +117,7 @@ def add_transcription_rapid(
transcription: list[str],
correct_words: list[str],
strict_grading: bool | None = None,
metadata: list[Metadata] = [],
):
transcription_words = [
TranscriptionWord(word=word, wordIndex=i)
Expand Down Expand Up @@ -143,7 +148,7 @@ def add_transcription_rapid(
media_paths=media_path,
payload=payload,
truths=model_truth,
metadata=None,
metadata=metadata,
randomCorrectProbability=1 / len(transcription),
)
)
Expand Down
2 changes: 1 addition & 1 deletion rapidata/rapidata_client/order/rapidata_order_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def aggregator(self, aggregator: AggregatorType):
self._aggregator = aggregator
return self

def validation_set(self, validation_set_id: str):
def validation_set_id(self, validation_set_id: str):
"""
Set the validation set for the order.

Expand Down