# Conda environment for developing and running the Rapidata Python client.
# Create it with: conda env create -f environment.yml
name: rapi-api
channels:
  - defaults
dependencies:
  - python=3.12
  - ipykernel
  - requests
  - pandas
  - matplotlib
  - pillow  # imported directly (PIL) by rapidata.service and rapidata.utils
  - python-dotenv
  - pyjwt
  - sphinx
# No `prefix:` entry on purpose: it pinned the environment to one developer's
# home directory. Conda derives the location from `name` on each machine.
"""Example: create and submit a free-text order for a video.

Credentials are read from the environment (optionally through a ``.env``
file): ``CLIENT_ID``, ``CLIENT_SECRET`` and ``ENDPOINT``.
"""
import dotenv
import os

# Load the .env file before the credentials are read below.
dotenv.load_dotenv()  # type: ignore

from rapidata.rapidata_client import RapidataClient
from rapidata.rapidata_client.workflow import FeatureFlags
from rapidata.rapidata_client.workflow import FreeTextWorkflow
from rapidata.rapidata_client.workflow.country_codes import CountryCodes

# TODO: insert video path
VIDEO_PATHS = [""]

CLIENT_ID = os.getenv("CLIENT_ID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")
ENDPOINT = os.getenv("ENDPOINT")

if not CLIENT_ID:
    raise Exception("CLIENT_ID not found in environment variables")

if not CLIENT_SECRET:
    raise Exception("CLIENT_SECRET not found in environment variables")

if not ENDPOINT:
    raise Exception("ENDPOINT not found in environment variables")

# Validate the local files *before* talking to the API: the upload helper
# rejects missing files only after the order has already been created, which
# would leave an empty order behind on the server.
missing_videos = [path for path in VIDEO_PATHS if not os.path.isfile(path)]
if missing_videos:
    raise FileNotFoundError(
        f"Video file(s) not found: {missing_videos}. "
        "Set VIDEO_PATHS to existing video files."
    )

rapi = RapidataClient(
    client_id=CLIENT_ID, client_secret=CLIENT_SECRET, endpoint=ENDPOINT
)

# Configure the order: one free-text question, answered with at least
# 15 characters, shown to the countries in CountryCodes.ENGLISH_SPEAKING.
order = (
    rapi.new_order(
        name="Example Video Free Text Order",
    )
    .workflow(
        FreeTextWorkflow(
            question="Describe the movement in this video!",
        )
        .feature_flags(
            FeatureFlags().free_text_minimum_characters(15).alert_on_fast_response(5)
        )
        .target_country_codes(CountryCodes.ENGLISH_SPEAKING)
    )
    .create()
)

# Add data
order.dataset.add_videos_from_paths(VIDEO_PATHS)

order.submit()
class RapidataDataset:
    """
    Handle for uploading data to a single Rapidata dataset.

    :param dataset_id: Identifier of the dataset that receives the uploads.
    :type dataset_id: str
    :param rapidata_service: Service whose ``dataset`` sub-service performs the uploads.
    :type rapidata_service: RapidataService
    """

    def __init__(self, dataset_id: str, rapidata_service: RapidataService):
        self.dataset_id = dataset_id
        self.rapidata_service = rapidata_service
        self.local_file_service = LocalFileService()

    def add_texts(self, texts: list[str]):
        """
        Upload the given strings as text sources of the dataset.

        :param texts: The texts to upload.
        """
        self.rapidata_service.dataset.upload_text_sources(self.dataset_id, texts)

    def add_images_from_paths(self, image_paths: list[str]):
        """
        Load images from disk and upload them under their file names.

        :param image_paths: Paths of the image files to upload.
        """
        image_names = [os.path.basename(image_path) for image_path in image_paths]
        images = self.local_file_service.load_images(image_paths)

        try:
            self.rapidata_service.dataset.upload_images(self.dataset_id, images, image_names)
        finally:
            # The images were opened only for this upload; release their
            # underlying file handles whether or not the upload succeeded.
            for image in images:
                image.close()

    def add_videos_from_paths(self, video_paths: list[str]):
        """
        Open video files from disk and upload them under their file names.

        :param video_paths: Paths of the video files to upload.
        """
        video_names = [os.path.basename(video_path) for video_path in video_paths]
        videos = self.local_file_service.load_videos(video_paths)

        try:
            self.rapidata_service.dataset.upload_videos(self.dataset_id, videos, video_names)
        finally:
            # load_videos() hands back open binary file objects; close them
            # here because nothing else keeps a reference to them.
            for video in videos:
                video.close()
class RapidataOrder:
    """
    Represents a Rapidata order.

    :param name: The name of the order.
    :type name: str
    :param workflow: The workflow associated with the order.
    :type workflow: Workflow
    :param rapidata_service: The Rapidata service used to create and manage the order.
    :type rapidata_service: RapidataService
    """

    def __init__(
        self, name: str, workflow: Workflow, rapidata_service: RapidataService
    ):
        self.name = name
        self.workflow = workflow
        self.rapidata_service = rapidata_service
        # Both stay None until create() has run.
        self.order_id: str | None = None
        self._dataset: RapidataDataset | None = None

    def create(self) -> "RapidataOrder":
        """
        Creates the order using the provided name and workflow.

        The service returns the order id together with the id of the dataset
        that belongs to the order; the dataset handle is available through
        :attr:`dataset` afterwards.

        :return: The created RapidataOrder instance.
        :rtype: RapidataOrder
        """
        self.order_id, dataset_id = self.rapidata_service.order.create_order(
            self.name, self.workflow.to_dict()
        )
        self._dataset = RapidataDataset(dataset_id, self.rapidata_service)
        return self

    def submit(self):
        """
        Submits the order for processing.

        :raises ValueError: If the order has not been created.
        """
        if self.order_id is None:
            raise ValueError("You must create the order before submitting it.")

        self.rapidata_service.order.submit(self.order_id)

    def approve(self):
        """
        Approves the order for execution.

        :raises ValueError: If the order has not been created.
        """
        if self.order_id is None:
            raise ValueError("You must create the order before approving it.")

        self.rapidata_service.order.approve(self.order_id)

    @property
    def dataset(self) -> "RapidataDataset":
        """
        The dataset associated with the order.

        The dataset is set up by :meth:`create`, not by :meth:`submit`, so it
        can be filled before the order is submitted.

        :raises ValueError: If the order has not been created.
        :return: The RapidataDataset instance.
        :rtype: RapidataDataset
        """
        if self._dataset is None:
            raise ValueError("You must create the order before accessing the dataset.")
        return self._dataset
+ :rtype: RapidataOrderBuilder + """ + self._workflow = workflow + return self diff --git a/rapidata/rapidata_client/rapidata_client.py b/rapidata/rapidata_client/rapidata_client.py new file mode 100644 index 000000000..85bdbcbf2 --- /dev/null +++ b/rapidata/rapidata_client/rapidata_client.py @@ -0,0 +1,34 @@ +from rapidata.rapidata_client.order.rapidata_order_builder import RapidataOrderBuilder +from rapidata.service import RapidataService + + +class RapidataClient: + """ + A client for interacting with the Rapidata API. + """ + + def __init__( + self, + client_id: str, + client_secret: str, + endpoint: str = "https://api.rapidata.ai", + ): + """ + Initialize the RapidataClient. + + :param client_id: The client ID for authentication. + :param client_secret: The client secret for authentication. + :param endpoint: The API endpoint URL. Defaults to "https://api.rapidata.ai". + """ + self._rapidata_service = RapidataService( + client_id=client_id, client_secret=client_secret, endpoint=endpoint + ) + + def new_order(self, name: str) -> RapidataOrderBuilder: + """ + Create a new order using a RapidataOrderBuilder instance. + + :param name: The name of the order. + :return: A RapidataOrderBuilder instance. 
class Workflow(ABC):
    """
    Base class for workflow configurations.

    Holds the settings shared by all workflows (referee, target country codes
    and feature flags) and serializes them with :meth:`to_dict`. The setters
    return ``self`` so that calls can be chained.

    :param type: Value written to the ``"_t"`` key of the serialized workflow.
    :type type: str
    """

    def __init__(self, type: str):
        self._type = type
        # Annotated with the base type so any Referee can be assigned later.
        self._referee: Referee = NaiveReferee()
        self._target_country_codes: list[str] = []
        self._feature_flags: FeatureFlags = FeatureFlags()

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize the shared workflow settings.

        :return: Dict with the keys ``_t``, ``referee``, ``targetCountryCodes`` and ``featureFlags``.
        """
        return {
            "_t": self._type,
            "referee": self._referee.to_dict(),
            # Copy so that callers mutating the result cannot change this workflow.
            "targetCountryCodes": list(self._target_country_codes),
            "featureFlags": self._feature_flags.to_list(),
        }

    def referee(self, referee: Referee):
        """
        Set the referee of the workflow.

        :param referee: The referee to use.
        :return: This workflow, for chaining.
        """
        self._referee = referee
        return self

    def target_country_codes(self, target_country_codes: list[str]):
        """
        Set the country codes the workflow is targeted at.

        :param target_country_codes: The country codes to target.
        :return: This workflow, for chaining.
        """
        # Store a copy: callers commonly pass shared module-level constants,
        # and later changes to either side must not leak into the other.
        self._target_country_codes = list(target_country_codes)
        return self

    def feature_flags(self, feature_flags: FeatureFlags):
        """
        Set the feature flags of the workflow.

        :param feature_flags: The feature flags to use.
        :return: This workflow, for chaining.
        """
        self._feature_flags = feature_flags
        return self
class CompareWorkflow(Workflow):
    """
    Workflow of type ``CompareWorkflowConfig`` built around a comparison criteria.

    :param criteria: The criteria text written to the serialized workflow.
    :type criteria: str
    """

    def __init__(self, criteria: str):
        super().__init__(type="CompareWorkflowConfig")
        self._criteria = criteria
        # Defaults of the tuning values; each has a fluent setter below.
        self._k_factor = 40
        self._match_size = 2
        self._matches_until_completed = 10

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize the workflow: the shared settings plus ``criteria``.

        NOTE(review): ``k_factor``, ``match_size`` and
        ``matches_until_completed`` are stored but not part of the result -
        confirm whether the API is meant to receive them.
        """
        config = super().to_dict()
        config["criteria"] = self._criteria
        return config

    def k_factor(self, k_factor: int):
        """Store the k-factor and return this workflow for chaining."""
        self._k_factor = k_factor
        return self

    def match_size(self, match_size: int):
        """Store the match size and return this workflow for chaining."""
        self._match_size = match_size
        return self

    def matches_until_completed(self, matches_until_completed: int):
        """Store the number of matches until completion and return this workflow for chaining."""
        self._matches_until_completed = matches_until_completed
        return self
class CountryCodes:
    """
    Predefined lists of two-letter country codes.

    The lists are meant to be passed to ``target_country_codes`` of a workflow.
    """

    ENGLISH_SPEAKING = [
        "AU",
        "BE",
        "CA",
        "DK",
        "FI",
        "IE",
        "LU",
        "NL",
        "NZ",
        "NO",
        "SG",
        "SE",
        "GB",
        "US",
    ]
    # A plain list like ENGLISH_SPEAKING. A stray trailing comma previously
    # turned this into a 1-tuple wrapping the list, so the codes were sent
    # nested one level too deep.
    GERMAN_SPEAKING = ["AT", "DE"]
class FreeTextWorkflow(Workflow):
    """
    Workflow of type ``SimpleWorkflowConfig`` carrying a ``FreeTextBlueprint``.

    :param question: The question written to the blueprint.
    :type question: str
    """

    def __init__(self, question: str):
        super().__init__(type="SimpleWorkflowConfig")
        self._question = question

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize the workflow: the shared settings plus the free-text blueprint.

        :return: The workflow configuration dict.
        """
        blueprint = {
            "_t": "FreeTextBlueprint",
            "question": self._question,
        }
        config = super().to_dict()
        config["blueprint"] = blueprint
        return config
class ClassifyEarlyStoppingReferee(Referee):
    """
    Referee configuration of type ``ProbabilisticAttachCategoryRefereeConfig``.

    The referee defines when a task is considered complete. This one stops the
    task when the confidence in the winning category is above a threshold.
    The threshold behaves logarithmically, i.e. if 0.99 stops too early,
    try 0.999 or 0.9999.

    :param threshold: Confidence threshold written to the configuration.
    :param max_vote_count: Value written to ``maxVotes`` in the configuration.
    """

    def __init__(self, threshold: float = 0.999, max_vote_count: int = 100):
        self.max_vote_count = max_vote_count
        self.threshold = threshold

    def to_dict(self):
        """Convert the referee to its configuration dict."""
        config = dict(
            _t="ProbabilisticAttachCategoryRefereeConfig",
            threshold=self.threshold,
            maxVotes=self.max_vote_count,
        )
        return config
class LocalFileService:
    """Loads images and videos from the local file system."""

    def load_image(self, image_path: str) -> Image.Image:
        """
        Open the image at ``image_path``.

        :raises FileNotFoundError: If ``image_path`` is not an existing file.
        """
        self.check_file_exists(image_path)
        return Image.open(image_path)

    def load_images(self, image_paths: list[str]) -> list[Image.Image]:
        """Open every image in ``image_paths``, in order."""
        return [self.load_image(image_path) for image_path in image_paths]

    def load_video(self, video_path: str):
        """
        Open the video at ``video_path`` for binary reading.

        The caller owns the returned file object and must close it.

        :raises FileNotFoundError: If ``video_path`` is not an existing file.
        """
        self.check_file_exists(video_path)
        return open(video_path, 'rb')

    def load_videos(self, video_paths: list[str]):
        """Open every video in ``video_paths``, in order."""
        return [self.load_video(video_path) for video_path in video_paths]

    def _file_exists(self, file_path: str) -> bool:
        # isfile() rather than exists(): a directory "exists" but cannot be
        # opened as an image or video, so it must not pass the check.
        return os.path.isfile(file_path)

    def check_file_exists(self, file_path: str):
        """
        Ensure ``file_path`` points to an existing regular file.

        :raises FileNotFoundError: If it does not (missing path, empty string or directory).
        """
        if not self._file_exists(file_path):
            raise FileNotFoundError(f"File {file_path} not found.")
class BaseRapidataAPIService:
    """
    Base class of the API services: authenticates and sends requests.

    An auth token is fetched on construction and refreshed before a request
    when it is missing or close to expiry.

    :param client_id: The client ID sent when requesting a token.
    :param client_secret: The client secret sent when requesting a token.
    :param endpoint: Base URL of the API.
    """

    def __init__(self, client_id: str, client_secret: str, endpoint: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.endpoint = endpoint
        self.auth_header: dict[str, str] | None = None
        self.token: str | None = None
        # _get_auth_token() stores token and header on the instance itself.
        # Its result must not be assigned to self.token unconditionally: doing
        # so used to overwrite the fresh token with None and forced a second
        # authentication round-trip on the first request.
        self._get_auth_token()

    def _check_response(self, response: requests.Response):
        """Raise if the response status is anything other than 200."""
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} - {response.text}")

    def _get_new_auth_token_if_outdated(self):
        """Fetch a new token when none is stored or the stored one is no longer valid."""
        if not self.token or not self._is_token_valid():
            self._get_auth_token()

    def _is_token_valid(self, expiration_threshold: timedelta = timedelta(minutes=5)):
        """
        Check whether the stored token is still usable.

        :param expiration_threshold: Safety margin; the token counts as invalid
            once it expires within this time span.
        :return: ``True`` only if the token decodes, carries an ``exp`` claim
            and that expiry is at least ``expiration_threshold`` away.
        """
        if not self.token:
            return False
        try:
            # Only the expiry claim is read here, so the signature is not verified.
            payload = jwt.decode(self.token, options={"verify_signature": False})  # type: ignore
            exp_timestamp = payload.get("exp")
            if exp_timestamp:
                expiration_time = datetime.fromtimestamp(exp_timestamp)
                return datetime.now() + expiration_threshold <= expiration_time
        except jwt.DecodeError:
            return False
        return False

    def _get_auth_token(self):
        """
        Request a new auth token and store it together with the bearer header.

        :return: The new token.
        :raises Exception: If the request fails or the response carries no token.
        """
        url = f"{self.endpoint}/Identity/GetClientAuthToken"
        params = {
            "clientId": self.client_id,
        }
        headers = {"Authorization": f"Basic {self.client_secret}"}
        response = requests.post(url, params=params, headers=headers)
        self._check_response(response)
        self.token = response.json().get("authToken")
        if not self.token:
            raise Exception("No token received")
        self.auth_header = {"Authorization": f"Bearer {self.token}"}
        return self.token

    def _post(
        self,
        url: str,
        params: dict[str, Any] | None = None,
        data: dict[str, Any] | None = None,
        json: dict[str, Any] | None = None,
        files: Any | None = None,
    ):
        """
        Send an authenticated POST request.

        :return: The response.
        :raises Exception: If the response status is not 200.
        """
        self._get_new_auth_token_if_outdated()
        response = requests.post(
            url,
            params=params,
            data=data,
            json=json,
            files=files,
            headers=self.auth_header,
        )
        self._check_response(response)
        return response

    def _get(
        self,
        url: str,
        params: dict[str, Any] | None = None,
    ):
        """
        Send an authenticated GET request.

        :return: The response.
        :raises Exception: If the response status is not 200.
        """
        self._get_new_auth_token_if_outdated()
        response = requests.get(url, params=params, headers=self.auth_header)
        self._check_response(response)
        return response
class OrderService(BaseRapidataAPIService):
    """Service for the ``/Order`` endpoints: create, submit and approve orders."""

    def __init__(self, client_id: str, client_secret: str, endpoint: str):
        super().__init__(
            client_id=client_id, client_secret=client_secret, endpoint=endpoint
        )

    def create_order(self, name: str, workflow_config: dict[str, Any]) -> tuple[str, str]:
        """
        Create an order together with its dataset.

        The dataset is named ``"<name> dataset"`` and the order is created as
        non-public.

        :param name: Name of the order.
        :param workflow_config: The serialized workflow, sent as ``workflowConfig``.
        :return: Tuple ``(order_id, dataset_id)`` taken from the response.
        """
        url = f"{self.endpoint}/Order/CreateDefaultOrder"

        payload = {
            "orderName": name,
            "datasetName": f"{name} dataset",
            "isPublic": False,
            "workflowConfig": workflow_config,
            # NOTE(review): fixed to "Classification" for every workflow type -
            # confirm this is intended for compare and free-text workflows.
            "aggregatorType": "Classification",
        }

        response = self._post(url, json=payload)

        # Decode the body once instead of once per field.
        body = response.json()
        return body["orderId"], body["datasetId"]

    def submit(self, order_id: str):
        """
        Submit the order with the given id.

        :return: The response of the submit request.
        """
        url = f"{self.endpoint}/Order/Submit"
        params = {"orderId": order_id}

        submit_response = self._post(url, params=params)

        return submit_response

    def approve(self, order_id: str):
        """
        Approve the order with the given id.

        :return: The response of the approve request.
        """
        url = f"{self.endpoint}/Order/Approve"
        params = {"orderId": order_id}

        approve_response = self._post(url, params=params)

        return approve_response
class ImageUtils:
    """Helpers for working with PIL images."""

    # Used when an image carries no format of its own.
    _FALLBACK_FORMAT = "PNG"

    @staticmethod
    def convert_PIL_image_to_bytes(image: Image.Image) -> bytes:
        """
        Convert a PIL image to bytes with meta data encoded. We can't just use image.tobytes() because this only returns the pixel data.

        The image is encoded in its own format. Images without one
        (``image.format`` is ``None``, e.g. images that were not loaded from a
        file) are encoded as PNG, because saving to an in-memory buffer without
        an explicit format fails.

        :param image: The image to encode.
        :return: The encoded image file content.
        """
        buffer = BytesIO()
        image.save(buffer, format=image.format or ImageUtils._FALLBACK_FORMAT)
        return buffer.getvalue()
class TestOrderBuilder(unittest.TestCase):
    """Tests for RapidataOrderBuilder against a mocked Rapidata service."""

    def setUp(self):
        # The order sub-service answers create_order with fixed ids.
        order_api = Mock()
        order_api.create_order.return_value = ("order_id", "dataset_id")
        self.rapidata_service = Mock()
        self.rapidata_service.order = order_api

    def test_raise_error_if_no_workflow(self):
        # Creating an order without a workflow must be rejected.
        with self.assertRaises(ValueError):
            builder = RapidataOrderBuilder(
                rapidata_service=self.rapidata_service, name="Test Order"
            )
            builder.create()

    def test_basic_order_build(self):
        workflow = ClassifyWorkflow(question="Test Question?", categories=["Yes", "No"])
        builder = RapidataOrderBuilder(
            rapidata_service=self.rapidata_service, name="Test Order"
        )
        order = builder.workflow(workflow).create()

        self.assertEqual(order.name, "Test Order")
        self.assertIsInstance(order.workflow, ClassifyWorkflow)

        # The workflow handed to the builder must reach the order unchanged.
        self.assertEqual(order.workflow._question, "Test Question?")  # type: ignore
        self.assertEqual(order.workflow._categories, ["Yes", "No"])  # type: ignore