diff --git a/examples/classify_order.py b/examples/classify_order.py index 10adf4f3e..5a8332bc6 100644 --- a/examples/classify_order.py +++ b/examples/classify_order.py @@ -14,7 +14,7 @@ def new_classify_order(rapi: RapidataClient): .workflow( ClassifyWorkflow( question="What is shown in the image?", - options=["Fish", "Cat", "Wallabe", "Airplane"], + options=["Fish", "Cat", "Wallaby", "Airplane"], ) ) .media(["examples/data/wallaby.jpg"]) diff --git a/rapidata/rapidata_client/feature_flags/feature_flags.py b/rapidata/rapidata_client/feature_flags/feature_flags.py index a9b0b7f87..9d2de560e 100644 --- a/rapidata/rapidata_client/feature_flags/feature_flags.py +++ b/rapidata/rapidata_client/feature_flags/feature_flags.py @@ -7,16 +7,19 @@ def __init__(self): def to_list(self) -> list[FeatureFlagModel]: return [FeatureFlagModel(key=name, value=value) for name, value in self._flags.items()] - + def alert_on_fast_response(self, value: int): - self._flags["alertOnFastResponse"] = str(value) + self._flags["aler_on_fast_response"] = str(value) return self - - def disable_translation(self, value: bool): - self._flags["disableTranslation"] = str(value) + + def disable_translation(self, value: bool = True): + self._flags["disable_translation"] = str(value) return self - + def free_text_minimum_characters(self, value: int): - self._flags["freeTextMinimumCharacters"] = str(value) + self._flags["free_text_minimum_characters"] = str(value) + return self + + def no_shuffle(self, value: bool = True): + self._flags["no_shuffle"] = str(value) return self - \ No newline at end of file diff --git a/rapidata/rapidata_client/order/dataset/validation_set_builder.py b/rapidata/rapidata_client/order/dataset/validation_set_builder.py index 0bd77848b..beb073a24 100644 --- a/rapidata/rapidata_client/order/dataset/validation_set_builder.py +++ b/rapidata/rapidata_client/order/dataset/validation_set_builder.py @@ -128,6 +128,9 @@ def add_compare_rapid(self, media_paths: list[str], question: str, truth: str): for media_path in media_paths: if not os.path.exists(media_path): raise FileNotFoundError(f"File not found: {media_path}") + + # take only last part of truth path + truth = os.path.basename(truth) self._rapid_parts.append( ValidatioRapidParts( diff --git a/rapidata/rapidata_client/referee/classify_early_stopping_referee.py b/rapidata/rapidata_client/referee/classify_early_stopping_referee.py index 293290070..e57078351 100644 --- a/rapidata/rapidata_client/referee/classify_early_stopping_referee.py +++ b/rapidata/rapidata_client/referee/classify_early_stopping_referee.py @@ -1,4 +1,6 @@ +from typing import Any from rapidata.rapidata_client.referee.base_referee import Referee +from rapidata.api_client.models.probabilistic_attach_category_referee_config import ProbabilisticAttachCategoryRefereeConfig class ClassifyEarlyStoppingReferee(Referee): @@ -18,3 +20,10 @@ def to_dict(self): "threshold": self.threshold, "maxVotes": self.max_vote_count, } + + def to_model(self) -> Any: + return ProbabilisticAttachCategoryRefereeConfig( + _t="ProbabilisticAttachCategoryRefereeConfig", + threshold=self.threshold, + maxVotes=self.max_vote_count, + )