diff --git a/examples/basic_ranking_order.py b/examples/basic_ranking_order.py index 6134e154..929f489d 100644 --- a/examples/basic_ranking_order.py +++ b/examples/basic_ranking_order.py @@ -12,7 +12,7 @@ "https://assets.rapidata.ai/c13b8feb-fb97-4646-8dfc-97f05d37a637.webp", "https://assets.rapidata.ai/586dc517-c987-4d06-8a6f-553508b86356.webp", "https://assets.rapidata.ai/f4884ecd-cacb-4387-ab18-3b6e7dcdf10c.webp", - "https://assets.rapidata.ai/79076f76-a432-4ef9-9007-6d09a218417a.webp" + "https://assets.rapidata.ai/79076f76-a432-4ef9-9007-6d09a218417a.webp", ] if __name__ == "__main__": @@ -21,9 +21,9 @@ order = rapi.order.create_ranking_order( name="Example Ranking Order", instruction="Which rabbit looks cooler?", - datapoints=DATAPOINTS, - total_comparison_budget=50, #Make 50 comparisons, each comparison containing 2 datapoints - random_comparisons_ratio=0.5 #First half of the comparisons are random, the second half are close matchups + datapoints=[DATAPOINTS], + comparison_budget_per_ranking=50, # Make 50 comparisons, each comparison containing 2 datapoints + random_comparisons_ratio=0.5, # First half of the comparisons are random, the second half are close matchups ).run() order.display_progress_bar() diff --git a/src/rapidata/rapidata_client/datapoints/_datapoint.py b/src/rapidata/rapidata_client/datapoints/_datapoint.py index 183078ae..9d4d0d93 100644 --- a/src/rapidata/rapidata_client/datapoints/_datapoint.py +++ b/src/rapidata/rapidata_client/datapoints/_datapoint.py @@ -19,6 +19,7 @@ class Datapoint(BaseModel): media_context: str | None = None sentence: str | None = None private_note: str | None = None + group: str | None = None @field_validator("context") @classmethod diff --git a/src/rapidata/rapidata_client/datapoints/_datapoint_uploader.py b/src/rapidata/rapidata_client/datapoints/_datapoint_uploader.py index 8728b517..f4810531 100644 --- a/src/rapidata/rapidata_client/datapoints/_datapoint_uploader.py +++ 
b/src/rapidata/rapidata_client/datapoints/_datapoint_uploader.py @@ -45,6 +45,7 @@ def upload_datapoint( asset=uploaded_asset, metadata=metadata, sortIndex=index, + group=datapoint.group, ), ) diff --git a/src/rapidata/rapidata_client/datapoints/_datapoints_validator.py b/src/rapidata/rapidata_client/datapoints/_datapoints_validator.py new file mode 100644 index 00000000..c4e760b0 --- /dev/null +++ b/src/rapidata/rapidata_client/datapoints/_datapoints_validator.py @@ -0,0 +1,70 @@ +from itertools import zip_longest +from typing import Literal, cast, Iterable +from rapidata.rapidata_client.datapoints._datapoint import Datapoint + + +class DatapointsValidator: + @staticmethod + def validate_datapoints( + datapoints: list[str] | list[list[str]], + contexts: list[str] | None = None, + media_contexts: list[str] | None = None, + sentences: list[str] | None = None, + private_notes: list[str] | None = None, + groups: list[str] | None = None, + ) -> None: + if contexts and len(contexts) != len(datapoints): + raise ValueError("Number of contexts must match number of datapoints") + if media_contexts and len(media_contexts) != len(datapoints): + raise ValueError("Number of media contexts must match number of datapoints") + if sentences and len(sentences) != len(datapoints): + raise ValueError("Number of sentences must match number of datapoints") + if private_notes and len(private_notes) != len(datapoints): + raise ValueError("Number of private notes must match number of datapoints") + if groups and ( + len(groups) != len(datapoints) or len(groups) != len(set(groups)) + ): + raise ValueError( + "Number of groups must match number of datapoints and must be unique." 
+ ) + + @staticmethod + def map_datapoints( + datapoints: list[str] | list[list[str]], + contexts: list[str] | None = None, + media_contexts: list[str] | None = None, + sentences: list[str] | None = None, + private_notes: list[str] | None = None, + groups: list[str] | None = None, + data_type: Literal["text", "media"] = "media", + ) -> list[Datapoint]: + DatapointsValidator.validate_datapoints( + datapoints=datapoints, + contexts=contexts, + media_contexts=media_contexts, + sentences=sentences, + private_notes=private_notes, + groups=groups, + ) + return [ + Datapoint( + asset=asset, + data_type=data_type, + context=context, + media_context=media_context, + sentence=sentence, + private_note=private_note, + group=group, + ) + for asset, context, media_context, sentence, private_note, group in cast( + "Iterable[tuple[str | list[str], str | None, str | None, str | None, str | None, str | None]]", # because iterator only supports 5 arguments with specific type casting + zip_longest( + datapoints, + contexts or [], + media_contexts or [], + sentences or [], + private_notes or [], + groups or [], + ), + ) + ] diff --git a/src/rapidata/rapidata_client/order/_rapidata_order_builder.py b/src/rapidata/rapidata_client/order/_rapidata_order_builder.py index e2ea6ef9..89f798df 100644 --- a/src/rapidata/rapidata_client/order/_rapidata_order_builder.py +++ b/src/rapidata/rapidata_client/order/_rapidata_order_builder.py @@ -37,7 +37,11 @@ from rapidata.rapidata_client.referee._naive_referee import NaiveReferee from rapidata.rapidata_client.selection._base_selection import RapidataSelection from rapidata.rapidata_client.settings import RapidataSetting -from rapidata.rapidata_client.workflow import Workflow, FreeTextWorkflow +from rapidata.rapidata_client.workflow import ( + Workflow, + FreeTextWorkflow, + MultiRankingWorkflow, +) from rapidata.service.openapi_service import OpenAPIService from rapidata.rapidata_client.api.rapidata_api_client import ( 
suppress_rapidata_error_logging, @@ -63,7 +67,7 @@ def __init__( ): self._name = name self.order_id: str | None = None - self.__openapi_service = openapi_service + self._openapi_service = openapi_service self.__dataset: Optional[RapidataDataset] = None self.__workflow: Workflow | None = None self.__referee: Referee | None = None @@ -75,7 +79,7 @@ def __init__( self.__datapoints: list[Datapoint] = [] self.__sticky_state_value: StickyStateLiteral | None = None self.__validation_set_manager: ValidationSetManager = ValidationSetManager( - self.__openapi_service + self._openapi_service ) def _to_model(self) -> CreateOrderModel: @@ -144,7 +148,7 @@ def _set_validation_set_id(self) -> bool: with suppress_rapidata_error_logging(): self.__validation_set_id = ( ( - self.__openapi_service.validation_api.validation_set_recommended_get( + self._openapi_service.validation_api.validation_set_recommended_get( asset_type=[self.__datapoints[0].get_asset_type()], modality=[self.__workflow.modality], instruction=self.__workflow._get_instruction(), @@ -205,7 +209,9 @@ def _create(self) -> RapidataOrder: """ if ( rapidata_config.order.autoValidationSetCreation - and not isinstance(self.__workflow, FreeTextWorkflow) + and not isinstance( + self.__workflow, (FreeTextWorkflow, MultiRankingWorkflow) + ) and not self.__selections ): new_validation_set = self._set_validation_set_id() @@ -215,7 +221,7 @@ def _create(self) -> RapidataOrder: order_model = self._to_model() logger.debug("Creating order with model: %s", order_model) - result = self.__openapi_service.order_api.order_post( + result = self._openapi_service.order_api.order_post( create_order_model=order_model ) @@ -230,7 +236,7 @@ def _create(self) -> RapidataOrder: + f"A new validation set was created. Please annotate {required_amount} datapoint{('s' if required_amount != 1 else '')} so the order runs correctly." 
+ Fore.RESET ) - link = f"https://app.{self.__openapi_service.environment}/validation-set/detail/{self.__validation_set_id}/annotate?orderId={self.order_id}&required={required_amount}" + link = f"https://app.{self._openapi_service.environment}/validation-set/detail/{self.__validation_set_id}/annotate?orderId={self.order_id}&required={required_amount}" could_open_browser = webbrowser.open(link) if not could_open_browser: encoded_url = urllib.parse.quote(link, safe="%/:=&?~#+!$,;'@()*[]") @@ -245,7 +251,7 @@ def _create(self) -> RapidataOrder: managed_print() self.__dataset = ( - RapidataDataset(result.dataset_id, self.__openapi_service) + RapidataDataset(result.dataset_id, self._openapi_service) if result.dataset_id else None ) @@ -256,7 +262,7 @@ def _create(self) -> RapidataOrder: order = RapidataOrder( order_id=self.order_id, - openapi_service=self.__openapi_service, + openapi_service=self._openapi_service, name=self._name, ) @@ -278,7 +284,7 @@ def _create(self) -> RapidataOrder: logger.debug("Datapoints added to the order.") logger.debug("Setting order to preview") try: - self.__openapi_service.order_api.order_order_id_preview_post(self.order_id) + self._openapi_service.order_api.order_order_id_preview_post(self.order_id) except Exception: raise FailedUploadException(self.__dataset, order, failed_uploads) return order diff --git a/src/rapidata/rapidata_client/order/rapidata_order_manager.py b/src/rapidata/rapidata_client/order/rapidata_order_manager.py index c70f19be..f6af3daa 100644 --- a/src/rapidata/rapidata_client/order/rapidata_order_manager.py +++ b/src/rapidata/rapidata_client/order/rapidata_order_manager.py @@ -1,5 +1,4 @@ from typing import Sequence, Optional, Literal, get_args -from itertools import zip_longest from rapidata.rapidata_client.config.tracer import tracer from rapidata.rapidata_client.datapoints.metadata._base_metadata import Metadata @@ -21,17 +20,14 @@ DrawWorkflow, TimestampWorkflow, RankingWorkflow, + MultiRankingWorkflow, ) from 
rapidata.rapidata_client.datapoints._datapoint import Datapoint -from rapidata.rapidata_client.datapoints.metadata import ( - PromptMetadata, - MediaAssetMetadata, -) from rapidata.rapidata_client.filter import RapidataFilter from rapidata.rapidata_client.filter.rapidata_filters import RapidataFilters from rapidata.rapidata_client.settings import RapidataSettings, RapidataSetting from rapidata.rapidata_client.selection.rapidata_selections import RapidataSelections -from rapidata.rapidata_client.config import logger, rapidata_config +from rapidata.rapidata_client.config import logger from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader from rapidata.api_client.models.query_model import QueryModel @@ -41,9 +37,11 @@ from rapidata.api_client.models.filter_operator import FilterOperator from rapidata.api_client.models.sort_criterion import SortCriterion from rapidata.api_client.models.sort_direction import SortDirection -from rapidata.rapidata_client.order._rapidata_order_builder import StickyStateLiteral +from rapidata.rapidata_client.datapoints._datapoints_validator import ( + DatapointsValidator, +) -from tqdm import tqdm +from rapidata.rapidata_client.order._rapidata_order_builder import StickyStateLiteral class RapidataOrderManager: @@ -69,38 +67,14 @@ def _create_general_order( self, name: str, workflow: Workflow, - assets: list[str] | list[list[str]], - data_type: Literal["media", "text"] = "media", + datapoints: list[Datapoint], responses_per_datapoint: int = 10, - contexts: list[str] | None = None, - media_contexts: list[str] | None = None, validation_set_id: str | None = None, confidence_threshold: float | None = None, filters: Sequence[RapidataFilter] = [], settings: Sequence[RapidataSetting] = [], - sentences: list[str] | None = None, selections: Sequence[RapidataSelection] = [], - private_notes: list[str] | None = None, ) -> RapidataOrder: - - if not assets: - raise ValueError("No datapoints provided") - - if contexts and 
len(contexts) != len(assets): - raise ValueError("Number of contexts must match number of datapoints") - - if media_contexts and len(media_contexts) != len(assets): - raise ValueError("Number of media contexts must match number of datapoints") - - if sentences and len(sentences) != len(assets): - raise ValueError("Number of sentences must match number of datapoints") - - if private_notes and len(private_notes) != len(assets): - raise ValueError("Number of private notes must match number of datapoints") - - if sentences and contexts: - raise ValueError("You can only use contexts or sentences, not both") - if not confidence_threshold: referee = NaiveReferee(responses=responses_per_datapoint) else: @@ -109,25 +83,17 @@ def _create_general_order( max_vote_count=responses_per_datapoint, ) - if data_type not in ["media", "text"]: - raise ValueError("Data type must be one of 'media' or 'text'") - logger.debug( - "Creating order with parameters: name %s, workflow %s, datapoints %s, data_type %s, responses_per_datapoint %s, contexts %s, media_contexts %s, validation_set_id %s, confidence_threshold %s, filters %s, settings %s, sentences %s, selections %s, private_notes %s", + "Creating order with parameters: name %s, workflow %s, datapoints %s, responses_per_datapoint %s, validation_set_id %s, confidence_threshold %s, filters %s, settings %s, selections %s", name, workflow, - assets, - data_type, + datapoints, responses_per_datapoint, - contexts, - media_contexts, validation_set_id, confidence_threshold, filters, settings, - sentences, selections, - private_notes, ) order_builder = RapidataOrderBuilder( @@ -141,25 +107,7 @@ def _create_general_order( order = ( order_builder._workflow(workflow) - ._datapoints( - datapoints=[ - Datapoint( - asset=asset, - data_type=data_type, - context=context, - media_context=media_context, - sentence=sentence, - private_note=private_note, - ) - for asset, context, media_context, sentence, private_note in zip_longest( - assets, - contexts or 
[], - media_contexts or [], - sentences or [], - private_notes or [], - ) - ] - ) + ._datapoints(datapoints=datapoints) ._referee(referee) ._filters(filters) ._selections(selections) @@ -245,22 +193,25 @@ def create_classification_order( ): raise ValueError("Datapoints must be a list of strings") + datapoints_instances = DatapointsValidator.map_datapoints( + datapoints=datapoints, + contexts=contexts, + media_contexts=media_contexts, + private_notes=private_notes, + data_type=data_type, + ) return self._create_general_order( name=name, workflow=ClassifyWorkflow( instruction=instruction, answer_options=answer_options ), - assets=datapoints, - data_type=data_type, + datapoints=datapoints_instances, responses_per_datapoint=responses_per_datapoint, - contexts=contexts, - media_contexts=media_contexts, validation_set_id=validation_set_id, confidence_threshold=confidence_threshold, filters=filters, selections=selections, settings=settings, - private_notes=private_notes, ) def create_compare_order( @@ -331,33 +282,36 @@ def create_compare_order( "A_B_naming must be a list of exactly two strings or None" ) + datapoints_instances = DatapointsValidator.map_datapoints( + datapoints=datapoints, + contexts=contexts, + media_contexts=media_contexts, + private_notes=private_notes, + data_type=data_type, + ) return self._create_general_order( name=name, workflow=CompareWorkflow(instruction=instruction, a_b_names=a_b_names), - assets=datapoints, - data_type=data_type, + datapoints=datapoints_instances, responses_per_datapoint=responses_per_datapoint, - contexts=contexts, - media_contexts=media_contexts, validation_set_id=validation_set_id, confidence_threshold=confidence_threshold, filters=filters, selections=selections, settings=settings, - private_notes=private_notes, ) def create_ranking_order( self, name: str, instruction: str, - datapoints: list[str], - total_comparison_budget: int, + datapoints: list[list[str]], + comparison_budget_per_ranking: int, responses_per_comparison: 
int = 1, data_type: Literal["media", "text"] = "media", random_comparisons_ratio: float = 0.5, - context: Optional[str] = None, - media_context: Optional[str] = None, + contexts: Optional[list[str]] = None, + media_contexts: Optional[list[str]] = None, validation_set_id: Optional[str] = None, filters: Sequence[RapidataFilter] = [], settings: Sequence[RapidataSetting] = [], @@ -366,49 +320,74 @@ """ Create a ranking order. - With this order you can rank a list of datapoints (image, text, video, audio) based on the instruction. - The annotators will be shown two datapoints at a time. The ranking happens in terms of an elo system based on the matchup results. + With this order you can have multiple lists of datapoints (image, text, video, audio) be ranked based on the instruction. + Each list will be ranked independently, based on comparison matchups. Args: name (str): The name of the order. - instruction (str): The question asked from People when They see two datapoints. - datapoints (list[str]): A list of datapoints that will participate in the ranking. - total_comparison_budget (int): The total number of (pairwise-)comparisons that can be made. - responses_per_comparison (int, optional): The number of responses collected per comparison. Defaults to 1. + instruction (str): The instruction for the ranking. Will be shown with each matchup. + datapoints (list[list[str]]): The outer list determines the independent rankings, the inner list is the datapoints for each ranking. + comparison_budget_per_ranking (int): The number of comparisons that will be collected per ranking (outer list of datapoints). + responses_per_comparison (int, optional): The number of responses that will be collected per comparison. Defaults to 1. data_type (str, optional): The data type of the datapoints. Defaults to "media" (any form of image, video or audio). \n Other option: "text". 
- random_comparisons_ratio (float, optional): The fraction of random comparisons in the ranking process. - The rest will focus on pairing similarly ranked datapoints. Defaults to 0.5 and can be left untouched. - context (str, optional): The context for all the comparison. Defaults to None.\n - If provided will be shown in addition to the instruction for all the matchups. - media_context (str, optional): The media context for all the comparison. Defaults to None.\n - If provided will be shown in addition to the instruction for all the matchups. + random_comparisons_ratio (float, optional): The ratio of random comparisons to the total number of comparisons. Defaults to 0.5. + contexts (list[str], optional): The list of contexts for the ranking. Defaults to None.\n + If provided has to be the same length as the outer list of datapoints and will be shown in addition to the instruction. (Therefore will be different for each ranking) + Will be matched up with the datapoints using the list index. + media_contexts (list[str], optional): The list of media contexts for the ranking i.e links to the images / videos. Defaults to None.\n + If provided has to be the same length as the outer list of datapoints and will be shown in addition to the instruction. (Therefore will be different for each ranking) + Will be matched up with the datapoints using the list index. validation_set_id (str, optional): The ID of the validation set. Defaults to None.\n If provided, one validation task will be shown infront of the datapoints that will be labeled. - filters (Sequence[RapidataFilter], optional): The list of filters for the order. Defaults to []. Decides who the tasks should be shown to. - settings (Sequence[RapidataSetting], optional): The list of settings for the order. Defaults to []. Decides how the tasks should be shown. - selections (Sequence[RapidataSelection], optional): The list of selections for the order. Defaults to []. Decides in what order the tasks should be shown. 
+ filters (Sequence[RapidataFilter], optional): The list of filters for the ranking. Defaults to []. Decides who the tasks should be shown to. + settings (Sequence[RapidataSetting], optional): The list of settings for the ranking. Defaults to []. Decides how the tasks should be shown. + selections (Sequence[RapidataSelection], optional): The list of selections for the ranking. Defaults to []. Decides in what order the tasks should be shown. """ - with tracer.start_as_current_span("RapidataOrderManager.create_ranking_order"): - if len(datapoints) < 2: - raise ValueError("At least two datapoints are required") + if contexts and len(contexts) != len(datapoints): + raise ValueError( + "Number of contexts must match the number of sets that will be ranked" + ) + if media_contexts and len(media_contexts) != len(datapoints): + raise ValueError( + "Number of media contexts must match the number of sets that will be ranked" + ) + if not isinstance(datapoints, list) or not all( + isinstance(dp, list) for dp in datapoints + ): + raise ValueError( + "Datapoints must be a list of lists. Outer list is the independent rankings, inner list is the datapoints for each ranking." + ) + if not all(len(set(dp)) == len(dp) for dp in datapoints): + raise ValueError("Each inner list must contain unique datapoints.") + + if not all(len(inner_list) >= 2 for inner_list in datapoints): + raise ValueError( + "Each ranking must contain at least two unique datapoints." 
+ ) - if len(set(datapoints)) != len(datapoints): - raise ValueError("Datapoints must be unique") + datapoints_instances = [] + for i, datapoint in enumerate(datapoints): + for d in datapoint: + datapoints_instances.append( + Datapoint( + asset=d, + data_type=data_type, + context=contexts[i] if contexts else None, + media_context=media_contexts[i] if media_contexts else None, + group=str(i), + ) + ) return self._create_general_order( name=name, - workflow=RankingWorkflow( - criteria=instruction, - total_comparison_budget=total_comparison_budget, + workflow=MultiRankingWorkflow( + instruction=instruction, + comparison_budget_per_ranking=comparison_budget_per_ranking, random_comparisons_ratio=random_comparisons_ratio, - context=context, - media_context=media_context, - file_uploader=self.__asset_uploader, ), - assets=datapoints, - data_type=data_type, + datapoints=datapoints_instances, responses_per_datapoint=responses_per_comparison, validation_set_id=validation_set_id, filters=filters, @@ -458,18 +437,21 @@ def create_free_text_order( with tracer.start_as_current_span( "RapidataOrderManager.create_free_text_order" ): + datapoints_instances = DatapointsValidator.map_datapoints( + datapoints=datapoints, + contexts=contexts, + media_contexts=media_contexts, + private_notes=private_notes, + data_type=data_type, + ) return self._create_general_order( name=name, workflow=FreeTextWorkflow(instruction=instruction), - assets=datapoints, - data_type=data_type, + datapoints=datapoints_instances, responses_per_datapoint=responses_per_datapoint, - contexts=contexts, - media_contexts=media_contexts, filters=filters, selections=selections, settings=settings, - private_notes=private_notes, ) def create_select_words_order( @@ -510,19 +492,22 @@ def create_select_words_order( with tracer.start_as_current_span( "RapidataOrderManager.create_select_words_order" ): + datapoints_instances = DatapointsValidator.map_datapoints( + datapoints=datapoints, + sentences=sentences, + 
private_notes=private_notes, + ) return self._create_general_order( name=name, workflow=SelectWordsWorkflow( instruction=instruction, ), - assets=datapoints, + datapoints=datapoints_instances, responses_per_datapoint=responses_per_datapoint, validation_set_id=validation_set_id, filters=filters, selections=selections, settings=settings, - sentences=sentences, - private_notes=private_notes, ) def create_locate_order( @@ -564,19 +549,21 @@ def create_locate_order( This will NOT be shown to the labelers but will be included in the result purely for your own reference. """ with tracer.start_as_current_span("RapidataOrderManager.create_locate_order"): - + datapoints_instances = DatapointsValidator.map_datapoints( + datapoints=datapoints, + contexts=contexts, + media_contexts=media_contexts, + private_notes=private_notes, + ) return self._create_general_order( name=name, workflow=LocateWorkflow(target=instruction), - assets=datapoints, + datapoints=datapoints_instances, responses_per_datapoint=responses_per_datapoint, - contexts=contexts, - media_contexts=media_contexts, validation_set_id=validation_set_id, filters=filters, selections=selections, settings=settings, - private_notes=private_notes, ) def create_draw_order( @@ -618,19 +605,21 @@ def create_draw_order( This will NOT be shown to the labelers but will be included in the result purely for your own reference. 
""" with tracer.start_as_current_span("RapidataOrderManager.create_draw_order"): - + datapoints_instances = DatapointsValidator.map_datapoints( + datapoints=datapoints, + contexts=contexts, + media_contexts=media_contexts, + private_notes=private_notes, + ) return self._create_general_order( name=name, workflow=DrawWorkflow(target=instruction), - assets=datapoints, + datapoints=datapoints_instances, responses_per_datapoint=responses_per_datapoint, - contexts=contexts, - media_contexts=media_contexts, validation_set_id=validation_set_id, filters=filters, selections=selections, settings=settings, - private_notes=private_notes, ) def create_timestamp_order( @@ -678,18 +667,21 @@ def create_timestamp_order( with tracer.start_as_current_span( "RapidataOrderManager.create_timestamp_order" ): + datapoints_instances = DatapointsValidator.map_datapoints( + datapoints=datapoints, + contexts=contexts, + media_contexts=media_contexts, + private_notes=private_notes, + ) return self._create_general_order( name=name, workflow=TimestampWorkflow(instruction=instruction), - assets=datapoints, + datapoints=datapoints_instances, responses_per_datapoint=responses_per_datapoint, - contexts=contexts, - media_contexts=media_contexts, validation_set_id=validation_set_id, filters=filters, selections=selections, settings=settings, - private_notes=private_notes, ) def get_order_by_id(self, order_id: str) -> RapidataOrder: diff --git a/src/rapidata/rapidata_client/workflow/__init__.py b/src/rapidata/rapidata_client/workflow/__init__.py index ebdece25..85cb4f94 100644 --- a/src/rapidata/rapidata_client/workflow/__init__.py +++ b/src/rapidata/rapidata_client/workflow/__init__.py @@ -8,3 +8,4 @@ from ._evaluation_workflow import EvaluationWorkflow from ._timestamp_workflow import TimestampWorkflow from ._ranking_workflow import RankingWorkflow +from ._multi_ranking_workflow import MultiRankingWorkflow diff --git a/src/rapidata/rapidata_client/workflow/_base_workflow.py 
b/src/rapidata/rapidata_client/workflow/_base_workflow.py index 815b1244..7ccaaf21 100644 --- a/src/rapidata/rapidata_client/workflow/_base_workflow.py +++ b/src/rapidata/rapidata_client/workflow/_base_workflow.py @@ -4,7 +4,9 @@ from rapidata.api_client.models.simple_workflow_model import SimpleWorkflowModel from rapidata.api_client.models.evaluation_workflow_model import EvaluationWorkflowModel from rapidata.api_client.models.compare_workflow_model import CompareWorkflowModel -from rapidata.rapidata_client.referee._base_referee import Referee +from rapidata.api_client.models.grouped_ranking_workflow_model import ( + GroupedRankingWorkflowModel, +) from rapidata.api_client import ( ClassifyPayload, ComparePayload, @@ -51,7 +53,12 @@ def _get_instruction(self) -> str: @abstractmethod def _to_model( self, - ) -> SimpleWorkflowModel | CompareWorkflowModel | EvaluationWorkflowModel: + ) -> ( + SimpleWorkflowModel + | CompareWorkflowModel + | EvaluationWorkflowModel + | GroupedRankingWorkflowModel + ): pass def _format_datapoints(self, datapoints: list[Datapoint]) -> list[Datapoint]: diff --git a/src/rapidata/rapidata_client/workflow/_multi_ranking_workflow.py b/src/rapidata/rapidata_client/workflow/_multi_ranking_workflow.py new file mode 100644 index 00000000..524603f1 --- /dev/null +++ b/src/rapidata/rapidata_client/workflow/_multi_ranking_workflow.py @@ -0,0 +1,83 @@ +from rapidata.api_client import ( + CompareWorkflowModelPairMakerConfig, + OnlinePairMakerConfigModel, + EloConfigModel, +) +from rapidata.api_client.models.grouped_ranking_workflow_model import ( + GroupedRankingWorkflowModel, +) +from rapidata.api_client.models.create_datapoint_from_files_model_metadata_inner import ( + CreateDatapointFromFilesModelMetadataInner, +) +from rapidata.rapidata_client.workflow._base_workflow import Workflow +from rapidata.api_client import ComparePayload +from rapidata.rapidata_client.datapoints._datapoint import Datapoint +from rapidata.api_client.models.rapid_modality 
import RapidModality +from rapidata.rapidata_client.datapoints.metadata import Metadata + + +class MultiRankingWorkflow(Workflow): + modality = RapidModality.COMPARE + + def __init__( + self, + instruction: str, + comparison_budget_per_ranking: int, + random_comparisons_ratio, + elo_start: int = 1200, + elo_k_factor: int = 40, + elo_scaling_factor: int = 400, + metadatas: list[Metadata] = [], + ): + super().__init__(type="CompareWorkflowConfig") + + self.metadatas = metadatas + + self.instruction = instruction + self.comparison_budget_per_ranking = comparison_budget_per_ranking + self.random_comparisons_ratio = random_comparisons_ratio + self.elo_start = elo_start + self.elo_k_factor = elo_k_factor + self.elo_scaling_factor = elo_scaling_factor + + self.pair_maker_config = CompareWorkflowModelPairMakerConfig( + OnlinePairMakerConfigModel( + _t="OnlinePairMaker", + totalComparisonBudget=comparison_budget_per_ranking, + randomMatchesRatio=random_comparisons_ratio, + ) + ) + + self.elo_config = EloConfigModel( + startingElo=elo_start, + kFactor=elo_k_factor, + scalingFactor=elo_scaling_factor, + ) + + def _to_model(self) -> GroupedRankingWorkflowModel: + + return GroupedRankingWorkflowModel( + _t="GroupedRankingWorkflow", + criteria=self.instruction, + eloConfig=self.elo_config, + pairMakerConfig=self.pair_maker_config, + metadata=[ + CreateDatapointFromFilesModelMetadataInner(metadata.to_model()) + for metadata in self.metadatas + ], + ) + + def _get_instruction(self) -> str: + return self.instruction + + def _to_payload(self, datapoint: Datapoint) -> ComparePayload: + return ComparePayload( + _t="ComparePayload", + criteria=self.instruction, + ) + + def __str__(self) -> str: + return f"MultiRankingWorkflow(instruction='{self.instruction}', metadatas={self.metadatas})" + + def __repr__(self) -> str: + return f"MultiRankingWorkflow(instruction={self.instruction!r}, comparison_budget_per_ranking={self.comparison_budget_per_ranking!r}, 
random_comparisons_ratio={self.random_comparisons_ratio!r}, elo_start={self.elo_start!r}, elo_k_factor={self.elo_k_factor!r}, elo_scaling_factor={self.elo_scaling_factor!r}, metadatas={self.metadatas!r})" diff --git a/src/rapidata/rapidata_client/workflow/_ranking_workflow.py b/src/rapidata/rapidata_client/workflow/_ranking_workflow.py index f0a7a56c..7704c996 100644 --- a/src/rapidata/rapidata_client/workflow/_ranking_workflow.py +++ b/src/rapidata/rapidata_client/workflow/_ranking_workflow.py @@ -26,7 +26,7 @@ class RankingWorkflow(Workflow): def __init__( self, - criteria: str, + instruction: str, total_comparison_budget: int, random_comparisons_ratio, elo_start: int = 1200, @@ -54,7 +54,7 @@ def __init__( if context: self.metadatas.append(PromptMetadata(prompt=context)) - self.criteria = criteria + self.instruction = instruction self.total_comparison_budget = total_comparison_budget self.random_comparisons_ratio = random_comparisons_ratio self.elo_start = elo_start @@ -76,13 +76,13 @@ def __init__( ) def _get_instruction(self) -> str: - return self.criteria + return self.instruction def _to_model(self) -> CompareWorkflowModel: return CompareWorkflowModel( _t="CompareWorkflow", - criteria=self.criteria, + criteria=self.instruction, eloConfig=self.elo_config, pairMakerConfig=self.pair_maker_config, metadata=[ @@ -94,7 +94,7 @@ def _to_model(self) -> CompareWorkflowModel: def _to_payload(self, datapoint: Datapoint) -> ComparePayload: return ComparePayload( _t="ComparePayload", - criteria=self.criteria, + criteria=self.instruction, ) def _format_datapoints(self, datapoints: list[Datapoint]) -> list[Datapoint]: @@ -116,9 +116,7 @@ def _format_datapoints(self, datapoints: list[Datapoint]) -> list[Datapoint]: return formatted_datapoints def __str__(self) -> str: - return ( - f"RankingWorkflow(criteria='{self.criteria}', metadatas={self.metadatas})" - ) + return f"RankingWorkflow(instruction='{self.instruction}', metadatas={self.metadatas})" def __repr__(self) -> str: 
- return f"RankingWorkflow(criteria={self.criteria!r}, total_comparison_budget={self.total_comparison_budget!r}, random_comparisons_ratio={self.random_comparisons_ratio!r}, elo_start={self.elo_start!r}, elo_k_factor={self.elo_k_factor!r}, elo_scaling_factor={self.elo_scaling_factor!r}, metadatas={self.metadatas!r})" + return f"RankingWorkflow(instruction={self.instruction!r}, total_comparison_budget={self.total_comparison_budget!r}, random_comparisons_ratio={self.random_comparisons_ratio!r}, elo_start={self.elo_start!r}, elo_k_factor={self.elo_k_factor!r}, elo_scaling_factor={self.elo_scaling_factor!r}, metadatas={self.metadatas!r})"