# Remove extra validation checks and add some tests for detect and evaluate #59
**Merged**
## Commits (8)
- `2bfa525` Remove extra validation checks (alexlyzhov)
- `7574117` Add tests for detect and evaluate (alexlyzhov)
- `dacaddc` Update run.sh for running tests (alexlyzhov)
- `1521e02` Add tests for Analyze functions (alexlyzhov)
- `14d5a84` Remove AnalyzeProd, AnalyzeEval (alexlyzhov)
- `b5db64f` Bump version (alexlyzhov)
- `734356d` Add low-level API tests (alexlyzhov)
- `90d8a96` Update the version (alexlyzhov)
## Files changed

The version module (generated by Stainless):

```diff
@@ -1,4 +1,4 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

 __title__ = "aimon"
-__version__ = "0.9.2"
+__version__ = "0.9.3"
```
The second changed file holds the `evaluate` helper and the deprecated Analyze decorators:

```diff
@@ -225,8 +225,6 @@ def evaluate(
     # Validata headers to be non-empty and contain atleast the context_docs column
     if not headers:
         raise ValueError("Headers must be a non-empty list")
-    if "context_docs" not in headers:
-        raise ValueError("Headers must contain the column 'context_docs'")

     # Create application and models
     am_app = client.applications.create(
@@ -276,287 +274,22 @@ def evaluate(
             if ag not in record:
                 raise ValueError("Dataset record must contain the column '{}' as specified in the 'headers'"
                                  " argument in the decorator".format(ag))

-            if "context_docs" not in record:
-                raise ValueError("Dataset record must contain the column 'context_docs'")

             _context = record['context_docs'] if isinstance(record['context_docs'], list) else [record['context_docs']]
             # Construct the payload for the analysis
             payload = {
                 **record,
                 "config": config,
                 "application_id": am_app.id,
                 "version": am_app.version,
                 "context_docs": [d for d in _context],
                 "evaluation_id": am_eval.id,
                 "evaluation_run_id": eval_run.id,
             }
             if "prompt" in record and record["prompt"]:
```
Review thread on `if "prompt" in record and record["prompt"]:`

> **Contributor:** when async is True, we still need to handle the validation but we can do that in the backend service
>
> **Contributor:** looks like we have this handled in the backend 👍
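Before the diff resumes below, here is a standalone sketch of the check under discussion. It replicates the `context_docs` validation this PR deletes from `evaluate`, which the thread above says the backend service now performs; the helper name is illustrative, not part of the SDK:

```python
# Illustrative re-creation of the removed client-side check; per the review
# thread, this validation now happens in the backend service instead.
def require_context_docs(record: dict) -> None:
    if "context_docs" not in record:
        raise ValueError("Dataset record must contain the column 'context_docs'")

require_context_docs({"context_docs": ["retrieved passage"], "user_query": "q"})  # passes
# require_context_docs({"user_query": "q"})  # would raise ValueError before this PR
```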
```diff
                 payload["prompt"] = record["prompt"]
             if "user_query" in record and record["user_query"]:
                 payload["user_query"] = record["user_query"]
             if "output" in record and record["output"]:
                 payload["output"] = record["output"]
             if "instruction_adherence" in config and "instructions" not in record:
                 raise ValueError("When instruction_adherence is specified in the config, "
                                  "'instructions' must be present in the dataset")
             if "instructions" in record and "instruction_adherence" in config:
                 # Only pass instructions if instruction_adherence is specified in the config
                 payload["instructions"] = record["instructions"] or ""

             if "retrieval_relevance" in config:
                 if "task_definition" in record:
                     payload["task_definition"] = record["task_definition"]
                 else:
                     raise ValueError("When retrieval_relevance is specified in the config, "
                                      "'task_definition' must be present in the dataset")

             payload["config"] = config
             if "instructions" in payload and not payload["instructions"]:
                 payload["instructions"] = ""

             results.append(EvaluateResponse(record['output'], client.analyze.create(body=[payload])))

     return results

-
-class AnalyzeBase:
-    DEFAULT_CONFIG = {'hallucination': {'detector_name': 'default'}}
-
-    def __init__(self, application, model, api_key=None, config=None):
-        """
-        :param application: An Application object
-        :param model: A Model object
-        :param api_key: The API key to use for the Aimon client
-        """
-        self.client = Client(auth_header="Bearer {}".format(api_key))
-        self.application = application
-        self.model = model
-        self.config = config if config else self.DEFAULT_CONFIG
-        self.initialize()
-
-    def initialize(self):
-        # Create or retrieve the model
-        self._am_model = self.client.models.create(
-            name=self.model.name,
-            type=self.model.model_type,
-            description="This model is named {} and is of type {}".format(self.model.name, self.model.model_type),
-            metadata=self.model.metadata
-        )
-
-        # Create or retrieve the application
-        self._am_app = self.client.applications.create(
-            name=self.application.name,
-            model_name=self._am_model.name,
-            stage=self.application.stage,
-            type=self.application.type,
-            metadata=self.application.metadata
-        )
-
-
-class AnalyzeEval(AnalyzeBase):
-
-    def __init__(self, application, model, evaluation_name, dataset_collection_name, headers,
-                 api_key=None, eval_tags=None, config=None):
-        """
-        The wrapped function should have a signature as follows:
-        def func(context_docs, user_query, prompt, instructions *args, **kwargs):
-            # Your code here
-            return output
-        [Required] The first argument must be a 'context_docs' which is of type List[str].
-        [Required] The second argument must be a 'user_query' which is of type str.
-        [Optional] The third argument must be a 'prompt' which is of type str
-        [Optional] If an 'instructions' column is present in the dataset, then the fourth argument
-                   must be 'instructions' which is of type str
-        [Optional] If an 'output' column is present in the dataset, then the fifth argument
-                   must be 'output' which is of type str
-        Return: The function must return an output which is of type str
-
-        :param application: An Application object
-        :param model: A Model object
-        :param evaluation_name: The name of the evaluation
-        :param dataset_collection_name: The name of the dataset collection
-        :param headers: A list containing the headers to be used for the evaluation
-        :param api_key: The API key to use for the AIMon client
-        :param eval_tags: A list of tags to associate with the evaluation
-        :param config: A dictionary containing the AIMon configuration for the evaluation
-        """
-        super().__init__(application, model, api_key, config)
-        warnings.warn(
-            f"{self.__class__.__name__} is deprecated and will be removed in a later release. Please use the evaluate method instead.",
-            DeprecationWarning,
-            stacklevel=2
-        )
-        self.headers = headers
-        self.evaluation_name = evaluation_name
-        self.dataset_collection_name = dataset_collection_name
-        self.eval_tags = eval_tags
-        self.eval_initialize()
-
-    def eval_initialize(self):
-        if self.dataset_collection_name is None:
-            raise ValueError("Dataset collection name must be provided for running an evaluation.")
-
-        # Create or retrieve the dataset collection
-        self._am_dataset_collection = self.client.datasets.collection.retrieve(name=self.dataset_collection_name)
-
-        # Create or retrieve the evaluation
-        self._eval = self.client.evaluations.create(
-            name=self.evaluation_name,
-            application_id=self._am_app.id,
-            model_id=self._am_model.id,
-            dataset_collection_id=self._am_dataset_collection.id
-        )
-
-    def _run_eval(self, func, args, kwargs):
-        # Create an evaluation run
-        eval_run = self.client.evaluations.run.create(
-            evaluation_id=self._eval.id,
-            metrics_config=self.config,
-        )
-        # Get all records from the datasets
-        dataset_collection_records = []
-        for dataset_id in self._am_dataset_collection.dataset_ids:
-            dataset_records = self.client.datasets.records.list(sha=dataset_id)
-            dataset_collection_records.extend(dataset_records)
-        results = []
-        for record in dataset_collection_records:
-            # The record must contain the context_docs and user_query fields.
-            # The prompt, output and instructions fields are optional.
-            # Inspect the record and call the function with the appropriate arguments
-            arguments = []
-            for ag in self.headers:
-                if ag not in record:
-                    raise ValueError("Record must contain the column '{}' as specified in the 'headers'"
-                                     " argument in the decorator".format(ag))
-                arguments.append(record[ag])
-            # Inspect the function signature to ensure that it accepts the correct arguments
-            sig = inspect.signature(func)
-            params = sig.parameters
-            if len(params) < len(arguments):
-                raise ValueError("Function must accept at least {} arguments".format(len(arguments)))
-            # Ensure that the first len(arguments) parameters are named correctly
-            param_names = list(params.keys())
-            if param_names[:len(arguments)] != self.headers:
-                raise ValueError("Function arguments must be named as specified by the 'headers' argument: {}".format(
-                    self.headers))
-
-            result = func(*arguments, *args, **kwargs)
-            _context = record['context_docs'] if isinstance(record['context_docs'], list) else [record['context_docs']]
-            payload = {
-                "application_id": self._am_app.id,
-                "version": self._am_app.version,
-                "prompt": record['prompt'] or "",
-                "user_query": record['user_query'] or "",
-                "context_docs": [d for d in _context],
-                "output": result,
-                "evaluation_id": self._eval.id,
-                "evaluation_run_id": eval_run.id,
-            }
-            if "instruction_adherence" in self.config and "instructions" not in record:
-                raise ValueError("When instruction_adherence is specified in the config, "
-                                 "'instructions' must be present in the dataset")
-            if "instructions" in record and "instruction_adherence" in self.config:
-                # Only pass instructions if instruction_adherence is specified in the config
-                payload["instructions"] = record["instructions"] or ""
-
-            if "retrieval_relevance" in self.config:
-                if "task_definition" in record:
-                    payload["task_definition"] = record["task_definition"]
-                else:
-                    raise ValueError("When retrieval_relevance is specified in the config, "
-                                     "'task_definition' must be present in the dataset")
-
-            payload["config"] = self.config
-            results.append((result, self.client.analyze.create(body=[payload])))
-        return results
-
-    def __call__(self, func):
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            return self._run_eval(func, args, kwargs)
-
-        return wrapper
-
-
-class AnalyzeProd(AnalyzeBase):
-
-    def __init__(self, application, model, values_returned, api_key=None, config=None):
-        """
-        The wrapped function should return a tuple of values in the order specified by values_returned. In addition,
-        the wrapped function should accept a parameter named eval_obj which will be used when using this decorator
-        in evaluation mode.
-
-        :param application: An Application object
-        :param model: A Model object
-        :param values_returned: A list of values in the order returned by the decorated function
-                                Acceptable values are 'generated_text', 'context', 'user_query', 'instructions'
-        """
-        application.stage = "production"
-        super().__init__(application, model, api_key, config)
-        warnings.warn(
-            f"{self.__class__.__name__} is deprecated and will be removed in a later release. Please use Detect with async=True instead.",
-            DeprecationWarning,
-            stacklevel=2
-        )
-        self.values_returned = values_returned
-        if self.values_returned is None or len(self.values_returned) == 0:
-            raise ValueError("Values returned by the decorated function must be specified")
-        if "generated_text" not in self.values_returned:
-            raise ValueError("values_returned must contain 'generated_text'")
-        if "context" not in self.values_returned:
-            raise ValueError("values_returned must contain 'context'")
-        if "instruction_adherence" in self.config and "instructions" not in self.values_returned:
-            raise ValueError(
-                "When instruction_adherence is specified in the config, 'instructions' must be returned by the decorated function")
-
-        if "retrieval_relevance" in self.config and "task_definition" not in self.values_returned:
-            raise ValueError("When retrieval_relevance is specified in the config, "
-                             "'task_definition' must be returned by the decorated function")
-
-        if "instructions" in self.values_returned and "instruction_adherence" not in self.config:
-            raise ValueError(
-                "instruction_adherence must be specified in the config for returning 'instructions' by the decorated function")
-        self.config = config if config else self.DEFAULT_CONFIG
-
-    def _run_production_analysis(self, func, args, kwargs):
-        result = func(*args, **kwargs)
-        if result is None:
-            raise ValueError("Result must be returned by the decorated function")
-        # Handle the case where the result is a single value
-        if not isinstance(result, tuple):
-            result = (result,)
-
-        # Create a dictionary mapping output names to results
-        result_dict = {name: value for name, value in zip(self.values_returned, result)}
-
-        if "generated_text" not in result_dict:
-            raise ValueError("Result of the wrapped function must contain 'generated_text'")
-        if "context" not in result_dict:
-            raise ValueError("Result of the wrapped function must contain 'context'")
-        _context = result_dict['context'] if isinstance(result_dict['context'], list) else [result_dict['context']]
-        aimon_payload = {
-            "application_id": self._am_app.id,
-            "version": self._am_app.version,
-            "output": result_dict['generated_text'],
-            "context_docs": _context,
-            "user_query": result_dict["user_query"] if 'user_query' in result_dict else "No User Query Specified",
-            "prompt": result_dict['prompt'] if 'prompt' in result_dict else "No Prompt Specified",
-        }
-        if 'instructions' in result_dict:
-            aimon_payload['instructions'] = result_dict['instructions']
-        if 'actual_request_timestamp' in result_dict:
-            aimon_payload["actual_request_timestamp"] = result_dict['actual_request_timestamp']
-        if 'task_definition' in result_dict:
-            aimon_payload['task_definition'] = result_dict['task_definition']
-
-        aimon_payload['config'] = self.config
-        aimon_response = self.client.analyze.create(body=[aimon_payload])
-        return result + (aimon_response,)
-
-    def __call__(self, func):
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            # Production mode, run the provided args through the user function
-            return self._run_production_analysis(func, args, kwargs)
-
-        return wrapper
```
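The removed `AnalyzeEval` and `AnalyzeProd` classes already carried deprecation warnings pointing users to the `evaluate` method and to `Detect` with `async=True`, so this deletion completes that migration. For readers tracking what survives it, `evaluate` still normalizes `context_docs` and assembles the analyze payload shown earlier. A minimal standalone sketch of that retained logic, with placeholder IDs (the real values come from the created application and evaluation run):

```python
# Standalone sketch of the payload normalization evaluate() retains.
# Field names come from the diff above; the IDs are placeholders.
record = {
    "context_docs": "a single passage",  # datasets may store a str or a list here
    "user_query": "What does the document say?",
    "output": "The document says ...",
}
config = {"hallucination": {"detector_name": "default"}}

_context = record["context_docs"] if isinstance(record["context_docs"], list) else [record["context_docs"]]
payload = {
    **record,
    "config": config,
    "application_id": "app-id-placeholder",
    "version": "v1-placeholder",
    "context_docs": [d for d in _context],
    "evaluation_id": "eval-id-placeholder",
    "evaluation_run_id": "run-id-placeholder",
}
print(payload["context_docs"])  # ['a single passage'], always normalized to a list
```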
A further review comment on the version bump:

> **Contributor:** let's do 0.10.0 since we are including the new IA detector
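The "IA detector" mentioned here is the instruction adherence detector, referenced throughout the diff via the `instruction_adherence` config key. Below is a sketch of a config that enables it, assuming the new detector entry follows the same `{'detector_name': 'default'}` shape as the hallucination entry in `DEFAULT_CONFIG`; the exact contents of the new entry are not shown in this PR:

```python
# Hypothetical detector config; the nested shape is copied from DEFAULT_CONFIG
# in the diff, and its reuse for the instruction adherence detector is an
# assumption, not something this PR specifies.
config = {
    "hallucination": {"detector_name": "default"},
    "instruction_adherence": {"detector_name": "default"},
}

# Per the validation shown in the diff, enabling instruction_adherence means
# each dataset record must carry an 'instructions' column.
record = {
    "context_docs": ["..."],
    "user_query": "...",
    "instructions": "Answer in one sentence.",
}
assert "instruction_adherence" not in config or "instructions" in record
```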