Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/basic_ranking_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"https://assets.rapidata.ai/c13b8feb-fb97-4646-8dfc-97f05d37a637.webp",
"https://assets.rapidata.ai/586dc517-c987-4d06-8a6f-553508b86356.webp",
"https://assets.rapidata.ai/f4884ecd-cacb-4387-ab18-3b6e7dcdf10c.webp",
"https://assets.rapidata.ai/79076f76-a432-4ef9-9007-6d09a218417a.webp"
"https://assets.rapidata.ai/79076f76-a432-4ef9-9007-6d09a218417a.webp",
]

if __name__ == "__main__":
Expand All @@ -21,9 +21,9 @@
order = rapi.order.create_ranking_order(
name="Example Ranking Order",
instruction="Which rabbit looks cooler?",
datapoints=DATAPOINTS,
total_comparison_budget=50, #Make 50 comparisons, each comparison containing 2 datapoints
random_comparisons_ratio=0.5 #First half of the comparisons are random, the second half are close matchups
datapoints=[DATAPOINTS],
comparison_budget_per_ranking=50, # Make 50 comparisons, each comparison containing 2 datapoints
random_comparisons_ratio=0.5, # First half of the comparisons are random, the second half are close matchups
).run()

order.display_progress_bar()
Expand Down
1 change: 1 addition & 0 deletions src/rapidata/rapidata_client/datapoints/_datapoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Datapoint(BaseModel):
media_context: str | None = None
sentence: str | None = None
private_note: str | None = None
group: str | None = None

@field_validator("context")
@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def upload_datapoint(
asset=uploaded_asset,
metadata=metadata,
sortIndex=index,
group=datapoint.group,
),
)

Expand Down
70 changes: 70 additions & 0 deletions src/rapidata/rapidata_client/datapoints/_datapoints_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from itertools import zip_longest
from typing import Literal, cast, Iterable
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'Iterable' is not used.

Suggested change
from typing import Literal, cast, Iterable
from typing import Literal, cast

Copilot uses AI. Check for mistakes.
from rapidata.rapidata_client.datapoints._datapoint import Datapoint


class DatapointsValidator:
@staticmethod
def validate_datapoints(
datapoints: list[str] | list[list[str]],
contexts: list[str] | None = None,
media_contexts: list[str] | None = None,
sentences: list[str] | None = None,
private_notes: list[str] | None = None,
groups: list[str] | None = None,
) -> None:
if contexts and len(contexts) != len(datapoints):
raise ValueError("Number of contexts must match number of datapoints")
if media_contexts and len(media_contexts) != len(datapoints):
raise ValueError("Number of media contexts must match number of datapoints")
if sentences and len(sentences) != len(datapoints):
raise ValueError("Number of sentences must match number of datapoints")
if private_notes and len(private_notes) != len(datapoints):
raise ValueError("Number of private notes must match number of datapoints")
if groups and (
len(groups) != len(datapoints) or len(groups) != len(set(groups))
):
raise ValueError(
"Number of groups must match number of datapoints and must be unique."
)

@staticmethod
def map_datapoints(
datapoints: list[str] | list[list[str]],
contexts: list[str] | None = None,
media_contexts: list[str] | None = None,
sentences: list[str] | None = None,
private_notes: list[str] | None = None,
groups: list[str] | None = None,
data_type: Literal["text", "media"] = "media",
) -> list[Datapoint]:
DatapointsValidator.validate_datapoints(
datapoints=datapoints,
contexts=contexts,
media_contexts=media_contexts,
sentences=sentences,
private_notes=private_notes,
groups=groups,
)
return [
Datapoint(
asset=asset,
data_type=data_type,
context=context,
media_context=media_context,
sentence=sentence,
private_note=private_note,
group=group,
)
for asset, context, media_context, sentence, private_note, group in cast(
"Iterable[tuple[str | list[str], str | None, str | None, str | None, str | None, str | None]]", # because iterator only supports 5 arguments with specific type casting
zip_longest(
datapoints,
contexts or [],
media_contexts or [],
sentences or [],
private_notes or [],
groups or [],
),
)
]
26 changes: 16 additions & 10 deletions src/rapidata/rapidata_client/order/_rapidata_order_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@
from rapidata.rapidata_client.referee._naive_referee import NaiveReferee
from rapidata.rapidata_client.selection._base_selection import RapidataSelection
from rapidata.rapidata_client.settings import RapidataSetting
from rapidata.rapidata_client.workflow import Workflow, FreeTextWorkflow
from rapidata.rapidata_client.workflow import (
Workflow,
FreeTextWorkflow,
MultiRankingWorkflow,
)
from rapidata.service.openapi_service import OpenAPIService
from rapidata.rapidata_client.api.rapidata_api_client import (
suppress_rapidata_error_logging,
Expand All @@ -63,7 +67,7 @@ def __init__(
):
self._name = name
self.order_id: str | None = None
self.__openapi_service = openapi_service
self._openapi_service = openapi_service
self.__dataset: Optional[RapidataDataset] = None
self.__workflow: Workflow | None = None
self.__referee: Referee | None = None
Expand All @@ -75,7 +79,7 @@ def __init__(
self.__datapoints: list[Datapoint] = []
self.__sticky_state_value: StickyStateLiteral | None = None
self.__validation_set_manager: ValidationSetManager = ValidationSetManager(
self.__openapi_service
self._openapi_service
)

def _to_model(self) -> CreateOrderModel:
Expand Down Expand Up @@ -144,7 +148,7 @@ def _set_validation_set_id(self) -> bool:
with suppress_rapidata_error_logging():
self.__validation_set_id = (
(
self.__openapi_service.validation_api.validation_set_recommended_get(
self._openapi_service.validation_api.validation_set_recommended_get(
asset_type=[self.__datapoints[0].get_asset_type()],
modality=[self.__workflow.modality],
instruction=self.__workflow._get_instruction(),
Expand Down Expand Up @@ -205,7 +209,9 @@ def _create(self) -> RapidataOrder:
"""
if (
rapidata_config.order.autoValidationSetCreation
and not isinstance(self.__workflow, FreeTextWorkflow)
and not isinstance(
self.__workflow, (FreeTextWorkflow, MultiRankingWorkflow)
)
and not self.__selections
):
new_validation_set = self._set_validation_set_id()
Expand All @@ -215,7 +221,7 @@ def _create(self) -> RapidataOrder:
order_model = self._to_model()
logger.debug("Creating order with model: %s", order_model)

result = self.__openapi_service.order_api.order_post(
result = self._openapi_service.order_api.order_post(
create_order_model=order_model
)

Expand All @@ -230,7 +236,7 @@ def _create(self) -> RapidataOrder:
+ f"A new validation set was created. Please annotate {required_amount} datapoint{('s' if required_amount != 1 else '')} so the order runs correctly."
+ Fore.RESET
)
link = f"https://app.{self.__openapi_service.environment}/validation-set/detail/{self.__validation_set_id}/annotate?orderId={self.order_id}&required={required_amount}"
link = f"https://app.{self._openapi_service.environment}/validation-set/detail/{self.__validation_set_id}/annotate?orderId={self.order_id}&required={required_amount}"
could_open_browser = webbrowser.open(link)
if not could_open_browser:
encoded_url = urllib.parse.quote(link, safe="%/:=&?~#+!$,;'@()*[]")
Expand All @@ -245,7 +251,7 @@ def _create(self) -> RapidataOrder:
managed_print()

self.__dataset = (
RapidataDataset(result.dataset_id, self.__openapi_service)
RapidataDataset(result.dataset_id, self._openapi_service)
if result.dataset_id
else None
)
Expand All @@ -256,7 +262,7 @@ def _create(self) -> RapidataOrder:

order = RapidataOrder(
order_id=self.order_id,
openapi_service=self.__openapi_service,
openapi_service=self._openapi_service,
name=self._name,
)

Expand All @@ -278,7 +284,7 @@ def _create(self) -> RapidataOrder:
logger.debug("Datapoints added to the order.")
logger.debug("Setting order to preview")
try:
self.__openapi_service.order_api.order_order_id_preview_post(self.order_id)
self._openapi_service.order_api.order_order_id_preview_post(self.order_id)
except Exception:
raise FailedUploadException(self.__dataset, order, failed_uploads)
return order
Expand Down
Loading
Loading