diff --git a/aimon/__init__.py b/aimon/__init__.py
index 7dac827..da70dfd 100644
--- a/aimon/__init__.py
+++ b/aimon/__init__.py
@@ -82,4 +82,4 @@
     pass
 
 from .decorators.detect import Detect
-from .decorators.evaluate import AnalyzeEval, AnalyzeProd, Application, Model, evaluate, EvaluateResponse
+from .decorators.evaluate import Application, Model, evaluate, EvaluateResponse
diff --git a/aimon/_version.py b/aimon/_version.py
index 886c204..ded9082 100644
--- a/aimon/_version.py
+++ b/aimon/_version.py
@@ -1,4 +1,4 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
 
 __title__ = "aimon"
-__version__ = "0.9.2"
+__version__ = "0.10.0"
diff --git a/aimon/decorators/detect.py b/aimon/decorators/detect.py
index a40a9d6..99c7acf 100644
--- a/aimon/decorators/detect.py
+++ b/aimon/decorators/detect.py
@@ -151,10 +151,8 @@ def __init__(self, values_returned, api_key=None, config=None, async_mode=False,
         self.client = Client(auth_header="Bearer {}".format(api_key))
         self.config = config if config else self.DEFAULT_CONFIG
         self.values_returned = values_returned
-        if self.values_returned is None or len(self.values_returned) == 0:
-            raise ValueError("values_returned by the decorated function must be specified")
-        if "context" not in self.values_returned:
-            raise ValueError("values_returned must contain 'context'")
+        if self.values_returned is None or not hasattr(self.values_returned, '__iter__') or len(self.values_returned) == 0:
+            raise ValueError("values_returned must be specified and be an iterable")
         self.async_mode = async_mode
         self.publish = publish
         if self.async_mode:
@@ -178,29 +176,7 @@ def wrapper(*args, **kwargs):
                 result = (result,)
 
             # Create a dictionary mapping output names to results
-            result_dict = {name: value for name, value in zip(self.values_returned, result)}
-
-            aimon_payload = {}
-            if 'generated_text' in result_dict:
-                aimon_payload['generated_text'] = result_dict['generated_text']
-            else:
-                raise ValueError("Result of the wrapped function must contain 'generated_text'")
-            if 'context' in result_dict:
-                aimon_payload['context'] = result_dict['context']
-            else:
-                raise ValueError("Result of the wrapped function must contain 'context'")
-            if 'user_query' in result_dict:
-                aimon_payload['user_query'] = result_dict['user_query']
-            if 'instructions' in result_dict:
-                aimon_payload['instructions'] = result_dict['instructions']
-
-            if 'retrieval_relevance' in self.config:
-                if 'task_definition' in result_dict:
-                    aimon_payload['task_definition'] = result_dict['task_definition']
-                else:
-                    raise ValueError( "When retrieval_relevance is specified in the config, "
-                                      "'task_definition' must be present in the result of the wrapped function.")
-
+            aimon_payload = {name: value for name, value in zip(self.values_returned, result)}
             aimon_payload['config'] = self.config
             aimon_payload['publish'] = self.publish
diff --git a/aimon/decorators/evaluate.py b/aimon/decorators/evaluate.py
index 0e84f84..47b7c40 100644
--- a/aimon/decorators/evaluate.py
+++ b/aimon/decorators/evaluate.py
@@ -225,8 +225,6 @@ def evaluate(
     # Validata headers to be non-empty and contain atleast the context_docs column
     if not headers:
         raise ValueError("Headers must be a non-empty list")
-    if "context_docs" not in headers:
-        raise ValueError("Headers must contain the column 'context_docs'")
 
     # Create application and models
     am_app = client.applications.create(
@@ -276,287 +274,22 @@ def evaluate(
             if ag not in record:
                 raise ValueError("Dataset record must contain the column '{}' as specified in the 
'headers'" " argument in the decorator".format(ag)) - - if "context_docs" not in record: - raise ValueError("Dataset record must contain the column 'context_docs'") - _context = record['context_docs'] if isinstance(record['context_docs'], list) else [record['context_docs']] # Construct the payload for the analysis payload = { + **record, + "config": config, "application_id": am_app.id, "version": am_app.version, - "context_docs": [d for d in _context], "evaluation_id": am_eval.id, "evaluation_run_id": eval_run.id, } - if "prompt" in record and record["prompt"]: - payload["prompt"] = record["prompt"] - if "user_query" in record and record["user_query"]: - payload["user_query"] = record["user_query"] - if "output" in record and record["output"]: - payload["output"] = record["output"] - if "instruction_adherence" in config and "instructions" not in record: - raise ValueError("When instruction_adherence is specified in the config, " - "'instructions' must be present in the dataset") - if "instructions" in record and "instruction_adherence" in config: - # Only pass instructions if instruction_adherence is specified in the config - payload["instructions"] = record["instructions"] or "" - - if "retrieval_relevance" in config: - if "task_definition" in record: - payload["task_definition"] = record["task_definition"] - else: - raise ValueError( "When retrieval_relevance is specified in the config, " - "'task_definition' must be present in the dataset") - - payload["config"] = config + if "instructions" in payload and not payload["instructions"]: + payload["instructions"] = "" + results.append(EvaluateResponse(record['output'], client.analyze.create(body=[payload]))) return results -class AnalyzeBase: - DEFAULT_CONFIG = {'hallucination': {'detector_name': 'default'}} - - def __init__(self, application, model, api_key=None, config=None): - """ - :param application: An Application object - :param model: A Model object - :param api_key: The API key to use for the Aimon client - """ - self.client = Client(auth_header="Bearer {}".format(api_key)) - self.application = application - self.model = model - self.config = config if config else self.DEFAULT_CONFIG - self.initialize() - - def initialize(self): - # Create or retrieve the model - self._am_model = self.client.models.create( - name=self.model.name, - type=self.model.model_type, - description="This model is named {} and is of type {}".format(self.model.name, self.model.model_type), - metadata=self.model.metadata - ) - - # Create or retrieve the application - self._am_app = self.client.applications.create( - name=self.application.name, - model_name=self._am_model.name, - stage=self.application.stage, - type=self.application.type, - metadata=self.application.metadata - ) - - -class AnalyzeEval(AnalyzeBase): - - def __init__(self, application, model, evaluation_name, dataset_collection_name, headers, - api_key=None, eval_tags=None, config=None): - """ - The wrapped function should have a signature as follows: - def func(context_docs, user_query, prompt, instructions *args, **kwargs): - # Your code here - return output - [Required] The first argument must be a 'context_docs' which is of type List[str]. - [Required] The second argument must be a 'user_query' which is of type str. 
- [Optional] The third argument must be a 'prompt' which is of type str - [Optional] If an 'instructions' column is present in the dataset, then the fourth argument - must be 'instructions' which is of type str - [Optional] If an 'output' column is present in the dataset, then the fifth argument - must be 'output' which is of type str - Return: The function must return an output which is of type str - - :param application: An Application object - :param model: A Model object - :param evaluation_name: The name of the evaluation - :param dataset_collection_name: The name of the dataset collection - :param headers: A list containing the headers to be used for the evaluation - :param api_key: The API key to use for the AIMon client - :param eval_tags: A list of tags to associate with the evaluation - :param config: A dictionary containing the AIMon configuration for the evaluation - - - """ - super().__init__(application, model, api_key, config) - warnings.warn( - f"{self.__class__.__name__} is deprecated and will be removed in a later release. Please use the evaluate method instead.", - DeprecationWarning, - stacklevel=2 - ) - self.headers = headers - self.evaluation_name = evaluation_name - self.dataset_collection_name = dataset_collection_name - self.eval_tags = eval_tags - self.eval_initialize() - - def eval_initialize(self): - if self.dataset_collection_name is None: - raise ValueError("Dataset collection name must be provided for running an evaluation.") - - # Create or retrieve the dataset collection - self._am_dataset_collection = self.client.datasets.collection.retrieve(name=self.dataset_collection_name) - - # Create or retrieve the evaluation - self._eval = self.client.evaluations.create( - name=self.evaluation_name, - application_id=self._am_app.id, - model_id=self._am_model.id, - dataset_collection_id=self._am_dataset_collection.id - ) - - def _run_eval(self, func, args, kwargs): - # Create an evaluation run - eval_run = self.client.evaluations.run.create( - evaluation_id=self._eval.id, - metrics_config=self.config, - ) - # Get all records from the datasets - dataset_collection_records = [] - for dataset_id in self._am_dataset_collection.dataset_ids: - dataset_records = self.client.datasets.records.list(sha=dataset_id) - dataset_collection_records.extend(dataset_records) - results = [] - for record in dataset_collection_records: - # The record must contain the context_docs and user_query fields. - # The prompt, output and instructions fields are optional. 
- # Inspect the record and call the function with the appropriate arguments - arguments = [] - for ag in self.headers: - if ag not in record: - raise ValueError("Record must contain the column '{}' as specified in the 'headers'" - " argument in the decorator".format(ag)) - arguments.append(record[ag]) - # Inspect the function signature to ensure that it accepts the correct arguments - sig = inspect.signature(func) - params = sig.parameters - if len(params) < len(arguments): - raise ValueError("Function must accept at least {} arguments".format(len(arguments))) - # Ensure that the first len(arguments) parameters are named correctly - param_names = list(params.keys()) - if param_names[:len(arguments)] != self.headers: - raise ValueError("Function arguments must be named as specified by the 'headers' argument: {}".format( - self.headers)) - - result = func(*arguments, *args, **kwargs) - _context = record['context_docs'] if isinstance(record['context_docs'], list) else [record['context_docs']] - payload = { - "application_id": self._am_app.id, - "version": self._am_app.version, - "prompt": record['prompt'] or "", - "user_query": record['user_query'] or "", - "context_docs": [d for d in _context], - "output": result, - "evaluation_id": self._eval.id, - "evaluation_run_id": eval_run.id, - } - if "instruction_adherence" in self.config and "instructions" not in record: - raise ValueError("When instruction_adherence is specified in the config, " - "'instructions' must be present in the dataset") - if "instructions" in record and "instruction_adherence" in self.config: - # Only pass instructions if instruction_adherence is specified in the config - payload["instructions"] = record["instructions"] or "" - - if "retrieval_relevance" in self.config: - if "task_definition" in record: - payload["task_definition"] = record["task_definition"] - else: - raise ValueError( "When retrieval_relevance is specified in the config, " - "'task_definition' must be present in the dataset") - - payload["config"] = self.config - results.append((result, self.client.analyze.create(body=[payload]))) - return results - - def __call__(self, func): - @wraps(func) - def wrapper(*args, **kwargs): - return self._run_eval(func, args, kwargs) - - return wrapper - - -class AnalyzeProd(AnalyzeBase): - - def __init__(self, application, model, values_returned, api_key=None, config=None): - """ - The wrapped function should return a tuple of values in the order specified by values_returned. In addition, - the wrapped function should accept a parameter named eval_obj which will be used when using this decorator - in evaluation mode. - - :param application: An Application object - :param model: A Model object - :param values_returned: A list of values in the order returned by the decorated function - Acceptable values are 'generated_text', 'context', 'user_query', 'instructions' - """ - application.stage = "production" - super().__init__(application, model, api_key, config) - warnings.warn( - f"{self.__class__.__name__} is deprecated and will be removed in a later release. 
Please use Detect with async=True instead.", - DeprecationWarning, - stacklevel=2 - ) - self.values_returned = values_returned - if self.values_returned is None or len(self.values_returned) == 0: - raise ValueError("Values returned by the decorated function must be specified") - if "generated_text" not in self.values_returned: - raise ValueError("values_returned must contain 'generated_text'") - if "context" not in self.values_returned: - raise ValueError("values_returned must contain 'context'") - if "instruction_adherence" in self.config and "instructions" not in self.values_returned: - raise ValueError( - "When instruction_adherence is specified in the config, 'instructions' must be returned by the decorated function") - - if "retrieval_relevance" in self.config and "task_definition" not in self.values_returned: - raise ValueError( "When retrieval_relevance is specified in the config, " - "'task_definition' must be returned by the decorated function") - - if "instructions" in self.values_returned and "instruction_adherence" not in self.config: - raise ValueError( - "instruction_adherence must be specified in the config for returning 'instructions' by the decorated function") - self.config = config if config else self.DEFAULT_CONFIG - - def _run_production_analysis(self, func, args, kwargs): - result = func(*args, **kwargs) - if result is None: - raise ValueError("Result must be returned by the decorated function") - # Handle the case where the result is a single value - if not isinstance(result, tuple): - result = (result,) - - # Create a dictionary mapping output names to results - result_dict = {name: value for name, value in zip(self.values_returned, result)} - - if "generated_text" not in result_dict: - raise ValueError("Result of the wrapped function must contain 'generated_text'") - if "context" not in result_dict: - raise ValueError("Result of the wrapped function must contain 'context'") - _context = result_dict['context'] if isinstance(result_dict['context'], list) else [result_dict['context']] - aimon_payload = { - "application_id": self._am_app.id, - "version": self._am_app.version, - "output": result_dict['generated_text'], - "context_docs": _context, - "user_query": result_dict["user_query"] if 'user_query' in result_dict else "No User Query Specified", - "prompt": result_dict['prompt'] if 'prompt' in result_dict else "No Prompt Specified", - } - if 'instructions' in result_dict: - aimon_payload['instructions'] = result_dict['instructions'] - if 'actual_request_timestamp' in result_dict: - aimon_payload["actual_request_timestamp"] = result_dict['actual_request_timestamp'] - if 'task_definition' in result_dict: - aimon_payload['task_definition'] = result_dict['task_definition'] - - aimon_payload['config'] = self.config - aimon_response = self.client.analyze.create(body=[aimon_payload]) - return result + (aimon_response,) - - def __call__(self, func): - @wraps(func) - def wrapper(*args, **kwargs): - # Production mode, run the provided args through the user function - return self._run_production_analysis(func, args, kwargs) - - return wrapper - diff --git a/examples/metaflow/summarization_flow_analyze.py b/examples/metaflow/summarization_flow_analyze.py deleted file mode 100644 index 8182b0e..0000000 --- a/examples/metaflow/summarization_flow_analyze.py +++ /dev/null @@ -1,63 +0,0 @@ -from metaflow import FlowSpec, step -from langchain_community.llms import OpenAI -from langchain.text_splitter import CharacterTextSplitter -from langchain.docstore.document import Document -from 
langchain.chains.summarize import load_summarize_chain -from aimon import AnalyzeProd, Application, Model -import os - -analyze_prod = AnalyzeProd( - Application("my_first_metaflow_llm_app"), - Model("my_best_model", "Llama3"), - values_returned=["context", "generated_text"], -) - - -class SummarizeFlowWithAnalyze(FlowSpec): - - @step - def start(self): - # Load your document here - self.document = """ - Your document text goes here. Replace this text with the actual content you want to summarize. - """ - - aimon_res, summary = self.summarize(self.document) - - # Print the summary - print("Summary:") - print(summary) - - # Print the summary - print("Aimon Res:") - print(aimon_res) - - self.next(self.end) - - @analyze_prod - def summarize(self, context): - # Split the source text - text_splitter = CharacterTextSplitter() - texts = text_splitter.split_text(context) - - # Create Document objects for the texts - docs = [Document(page_content=t) for t in texts[:3]] - - openai_key = os.getenv("OPENAI_KEY") - - # Initialize the OpenAI model - llm = OpenAI(temperature=0, api_key=openai_key) - - # Create the summarization chain - summarize_chain = load_summarize_chain(llm) - - # Summarize the document - return context, summarize_chain.run(docs) - - @step - def end(self): - print("Flow completed.") - -if __name__ == "__main__": - SummarizeFlowWithAnalyze() - diff --git a/setup.py b/setup.py index 19d0f5a..6a7fa5f 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ name='aimon', python_requires='>3.8.0', packages=find_packages(), - version="0.9.2", + version="0.10.0", install_requires=[ "annotated-types==0.6.0", "anyio==4.4.0", diff --git a/test/__init__.py b/tests/__init__.py similarity index 100% rename from test/__init__.py rename to tests/__init__.py diff --git a/test/client_test.py b/tests/obsolete/client_test.py similarity index 99% rename from test/client_test.py rename to tests/obsolete/client_test.py index 385632f..c046bef 100644 --- a/test/client_test.py +++ b/tests/obsolete/client_test.py @@ -1,8 +1,9 @@ # Run python3 setup.py install --user import pytest from aimon import Client +import os -API_KEY = "YOUR_API_KEY" +API_KEY = os.getenv("AIMON_API_KEY") class TestSimpleAimonRelyClient: @@ -220,7 +221,7 @@ def test_valid_data_valid_response_instruction_adherence(self): "context": "the abc have reported that those who receive centrelink payments made up half of radio rental's income last year. Centrelink payments themselves were up 20%.", "generated_text": "those who receive centrelink payments made up half of radio rental's income last year. The Centrelink payments were 20% up.", "instructions": "1. You are helpful chatbot. 2. You are friendly and polite. 3. The number of sentences in your response should not be more than two.", - "config": {'instruction_adherence': {'detector_name': 'default'}} + "config": {'instruction_adherence': {'detector_name': 'v1'}} }] response = client.inference.detect(body=data_to_send)[0] assert response.instruction_adherence is not None diff --git a/tests/run.sh b/tests/run.sh new file mode 100755 index 0000000..87375f3 --- /dev/null +++ b/tests/run.sh @@ -0,0 +1 @@ +python -m pytest . 
--ignore=obsolete/ -v --log-cli-level=INFO \ No newline at end of file diff --git a/tests/test_detect.py b/tests/test_detect.py new file mode 100644 index 0000000..a53c3a0 --- /dev/null +++ b/tests/test_detect.py @@ -0,0 +1,724 @@ +import os +import json +import pytest +import logging +from unittest.mock import patch, MagicMock + +from aimon.decorators.detect import Detect, DetectResult + + +class TestDetectDecoratorWithRemoteService: + """Test the Detect decorator using the actual remote service.""" + + def setup_method(self, method): + """Setup method for each test.""" + # Use environment variable for API key + self.api_key = os.getenv("AIMON_API_KEY") + if not self.api_key: + pytest.skip("AIMON_API_KEY environment variable not set") + self.logger = logging.getLogger("test_detect") + + @pytest.fixture(autouse=True) + def _setup_logging(self, caplog): + """Setup logging with caplog fixture.""" + caplog.set_level(logging.INFO) + self.caplog = caplog + self.verbose = True # Always log in tests + + def log_info(self, title, data): + """Log data to the test log.""" + if isinstance(data, dict): + # Format dictionaries for better readability + try: + formatted_data = json.dumps(data, indent=2, default=str) + self.logger.info(f"{title}: {formatted_data}") + except: + self.logger.info(f"{title}: {data}") + else: + self.logger.info(f"{title}: {data}") + + def test_basic_detect_functionality(self, caplog): + """Test that the Detect decorator works with basic functionality without raising exceptions.""" + # Create the decorator + config = {'hallucination': {'detector_name': 'default'}} + values_returned = ["context", "generated_text", "user_query"] + + self.log_info("TEST", "Basic detect functionality") + self.log_info("CONFIG", config) + self.log_info("VALUES_RETURNED", values_returned) + + detect = Detect( + values_returned=values_returned, + api_key=self.api_key, + config=config + ) + + # Define a function to be decorated + @detect + def generate_summary(context, query): + generated_text = f"Summary: {context}" + return context, generated_text, query + + # Call the decorated function + context = "The quick brown fox jumps over the lazy dog." + query = "Summarize the text." 
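        # Note (sketch of the behavior exercised below): the reworked Detect wrapper in
        # detect.py above forwards every value named in values_returned straight into the
        # payload,
        #     aimon_payload = {name: value for name, value in zip(self.values_returned, result)}
        # and the decorated call hands back the original values with a DetectResult appended,
        # so the unpacking below yields (context, generated_text, user_query, result).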
+ + self.log_info("INPUT_CONTEXT", context) + self.log_info("INPUT_QUERY", query) + + context_ret, generated_text, query_ret, result = generate_summary(context, query) + + self.log_info("OUTPUT_GENERATED_TEXT", generated_text) + self.log_info("OUTPUT_STATUS", result.status) + + if hasattr(result.detect_response, 'hallucination'): + self.log_info("OUTPUT_HALLUCINATION", { + "is_hallucinated": result.detect_response.hallucination.get("is_hallucinated", ""), + "score": result.detect_response.hallucination.get("score", ""), + "sentences_count": len(result.detect_response.hallucination.get("sentences", [])) + }) + + # Verify return values + assert context_ret == context + assert generated_text.startswith("Summary: ") + assert query_ret == query + + # Verify response structure + assert isinstance(result, DetectResult) + assert result.status == 200 + assert hasattr(result.detect_response, 'hallucination') + assert "is_hallucinated" in result.detect_response.hallucination + assert "score" in result.detect_response.hallucination + assert "sentences" in result.detect_response.hallucination + + def test_detect_with_multiple_detectors(self): + """Test the Detect decorator with multiple detectors without raising exceptions.""" + # Create the decorator with multiple detectors + config = { + 'hallucination': {'detector_name': 'default'}, + 'instruction_adherence': {'detector_name': 'v1'}, + 'toxicity': {'detector_name': 'default'} + } + values_returned = ["context", "generated_text", "user_query", "instructions"] + + self.log_info("Test", "Detect with multiple detectors") + self.log_info("Configuration", config) + self.log_info("Values returned", values_returned) + + detect = Detect( + values_returned=values_returned, + api_key=self.api_key, + config=config + ) + + # Define a function to be decorated + @detect + def generate_response(context, query, instructions): + generated_text = f"According to the context: {context}" + return context, generated_text, query, instructions + + # Call the decorated function + context = "AI systems should be developed responsibly with proper oversight." + query = "What does the text say about AI?" + instructions = "Provide a concise response with at most two sentences." 
+ + self.log_info("Input - Context", context) + self.log_info("Input - Query", query) + self.log_info("Input - Instructions", instructions) + + _, generated_text, _, _, result = generate_response(context, query, instructions) + + self.log_info("Output - Generated Text", generated_text) + self.log_info("Output - Status", result.status) + + for detector in ['hallucination', 'instruction_adherence', 'toxicity']: + if hasattr(result.detect_response, detector): + self.log_info(f"Output - {detector.capitalize()} Response", + getattr(result.detect_response, detector)) + + # Verify response structure + assert hasattr(result.detect_response, 'hallucination') + assert hasattr(result.detect_response, 'instruction_adherence') + assert hasattr(result.detect_response, 'toxicity') + + # Check key fields without verifying values + assert "score" in result.detect_response.hallucination + assert "results" in result.detect_response.instruction_adherence + assert "score" in result.detect_response.toxicity + + def test_detect_with_different_iterables(self): + """Test the Detect decorator with different iterable types for values_returned.""" + # Create the decorator with a tuple for values_returned + config = {'hallucination': {'detector_name': 'default'}} + values_returned = ("context", "generated_text") + + self.log_info("Test", "Detect with different iterables (tuple)") + self.log_info("Configuration", config) + self.log_info("Values returned", values_returned) + + detect = Detect( + values_returned=values_returned, + api_key=self.api_key, + config=config + ) + + # Define a function to be decorated + @detect + def simple_function(): + context = "Python is a programming language." + generated_text = "Python is used for data science and web development." + return context, generated_text + + # Call the decorated function - should not raise exceptions + context, generated_text, result = simple_function() + + self.log_info("Output - Context", context) + self.log_info("Output - Generated Text", generated_text) + self.log_info("Output - Status", result.status) + + if hasattr(result.detect_response, 'hallucination'): + self.log_info("Output - Hallucination Response", + result.detect_response.hallucination) + + # Verify return values and structure + assert "Python" in context + assert "data science" in generated_text + assert isinstance(result, DetectResult) + assert hasattr(result.detect_response, 'hallucination') + assert "score" in result.detect_response.hallucination + + def test_detect_with_non_tuple_return(self): + """Test the Detect decorator when the wrapped function returns a single value.""" + # Create the decorator + config = {'toxicity': {'detector_name': 'default'}} + values_returned = ["generated_text"] + + self.log_info("Test", "Detect with non-tuple return") + self.log_info("Configuration", config) + self.log_info("Values returned", values_returned) + + detect = Detect( + values_returned=values_returned, + api_key=self.api_key, + config=config + ) + + # Define a function that returns a single value + @detect + def simple_text_function(): + return "This is a friendly and helpful message for you!" 
+ + # Call the decorated function - should not raise exceptions + text, result = simple_text_function() + + self.log_info("Output - Text", text) + self.log_info("Output - Status", result.status) + + if hasattr(result.detect_response, 'toxicity'): + self.log_info("Output - Toxicity Response", + result.detect_response.toxicity) + + # Verify return values and structure + assert "friendly and helpful" in text + assert isinstance(result, DetectResult) + assert hasattr(result.detect_response, 'toxicity') + assert "score" in result.detect_response.toxicity + + def test_validate_iterable_values_returned(self): + """Test that the values_returned validation accepts different iterable types without exceptions.""" + self.log_info("Test", "Validate iterable values_returned") + + # Test with a list (basic case) + list_values = ["generated_text", "context"] + self.log_info("Testing with list values", list_values) + + detect_with_list = Detect( + values_returned=list_values, + api_key=self.api_key, + config={'hallucination': {'detector_name': 'default'}} + ) + + # Test with a tuple + tuple_values = ("generated_text", "context") + self.log_info("Testing with tuple values", tuple_values) + + detect_with_tuple = Detect( + values_returned=tuple_values, + api_key=self.api_key, + config={'hallucination': {'detector_name': 'default'}} + ) + + # Test with a custom iterable + class CustomIterable: + def __init__(self, items): + self.items = items + + def __iter__(self): + return iter(self.items) + + def __len__(self): + return len(self.items) + + custom_values = ["generated_text", "context"] + self.log_info("Testing with custom iterable", custom_values) + + custom_iterable = CustomIterable(custom_values) + detect_with_custom = Detect( + values_returned=custom_iterable, + api_key=self.api_key, + config={'hallucination': {'detector_name': 'default'}} + ) + + # If we got here without exceptions, the test passes + self.log_info("Result", "All iterable types accepted without errors") + assert True + + def test_invalid_values_returned(self): + """Test that Detect raises ValueError with non-iterable values_returned.""" + self.log_info("Test", "Invalid values_returned") + + # Test with None + self.log_info("Testing with None value", None) + with pytest.raises(ValueError, match="values_returned must be specified and be an iterable"): + Detect(values_returned=None, api_key=self.api_key) + + # Test with an integer + self.log_info("Testing with integer value", 123) + with pytest.raises(ValueError, match="values_returned must be specified and be an iterable"): + Detect(values_returned=123, api_key=self.api_key) + + # Test with a boolean + self.log_info("Testing with boolean value", True) + with pytest.raises(ValueError, match="values_returned must be specified and be an iterable"): + Detect(values_returned=True, api_key=self.api_key) + + self.log_info("Result", "All non-iterable types properly rejected with ValueError") + + def test_invalid_api_key(self): + """Test that using an invalid API key raises an appropriate error.""" + self.log_info("Test", "Invalid API key") + + # Create the decorator with invalid API key + values_returned = ["context", "generated_text"] + invalid_key = "invalid_key_test_12345" + + self.log_info("Values returned", values_returned) + self.log_info("Invalid API key", invalid_key[:5] + "...") # Truncate for security + + detect = Detect( + values_returned=values_returned, + api_key=invalid_key, + config={'hallucination': {'detector_name': 'default'}} + ) + + # Define a function to be decorated + @detect 
+ def sample_function(): + context = "Sample context." + generated_text = "Sample generated text." + return context, generated_text + + # Calling the function should raise an authentication error + with pytest.raises(Exception) as exc_info: + sample_function() + + error_msg = str(exc_info.value) + self.log_info("Error message", error_msg) + + # Check that it's an authentication-related error + assert any(auth_term in error_msg.lower() + for auth_term in ["auth", "token", "key", "credential", "unauthorized"]) + + self.log_info("Result", "Invalid API key properly rejected with authentication error") + + def test_invalid_detector_name(self): + """Test that using an invalid detector name raises an appropriate error.""" + self.log_info("Test", "Invalid detector name") + + # Create the decorator with invalid detector name + values_returned = ["context", "generated_text"] + invalid_detector = "non_existent_detector" + config = {'hallucination': {'detector_name': invalid_detector}} + + self.log_info("Values returned", values_returned) + self.log_info("Configuration with invalid detector", config) + + detect = Detect( + values_returned=values_returned, + api_key=self.api_key, + config=config + ) + + # Define a function to be decorated + @detect + def sample_function(): + context = "Sample context." + generated_text = "Sample generated text." + return context, generated_text + + # Calling the function should raise an error about the detector + with pytest.raises(Exception) as exc_info: + sample_function() + + error_msg = str(exc_info.value) + self.log_info("Error message", error_msg) + + # The error should mention something about the detector + assert any(term in error_msg.lower() + for term in ["detector", "not found", "invalid", "configuration", "error"]) + + self.log_info("Result", "Invalid detector name properly rejected with appropriate error") + + def test_missing_required_fields(self): + """Test that the API raises appropriate errors when required fields are missing.""" + self.log_info("Test", "Missing required fields") + + # Configure publish with missing required fields + self.log_info("Testing", "publish=True without application_name and model_name") + with pytest.raises(ValueError) as exc_info1: + Detect( + values_returned=["context", "generated_text"], + api_key=self.api_key, + publish=True, # publish requires application_name and model_name + config={'hallucination': {'detector_name': 'default'}} + ) + self.log_info("Error message (publish)", str(exc_info1.value)) + + # Configure async_mode without required fields + self.log_info("Testing", "async_mode=True without application_name and model_name") + with pytest.raises(ValueError) as exc_info2: + Detect( + values_returned=["context", "generated_text"], + api_key=self.api_key, + async_mode=True, # async_mode requires application_name and model_name + config={'hallucination': {'detector_name': 'default'}} + ) + self.log_info("Error message (async_mode)", str(exc_info2.value)) + + self.log_info("Result", "Missing required fields properly rejected with ValueError") + + def test_toxicity_detector_only(self): + """Test the Detect decorator with only the toxicity detector.""" + config = {'toxicity': {'detector_name': 'default'}} + values_returned = ["context", "generated_text"] + + self.log_info("Test", "Toxicity detector only") + self.log_info("Configuration", config) + self.log_info("Values returned", values_returned) + + detect = Detect( + values_returned=values_returned, + api_key=self.api_key, + config=config + ) + + @detect + def 
generate_text(): + context = "Customer service is important for business success." + generated_text = "It's crucial to treat customers with respect and care." + return context, generated_text + + context, generated_text, result = generate_text() + + self.log_info("Output - Context", context) + self.log_info("Output - Generated Text", generated_text) + self.log_info("Output - Status", result.status) + + if hasattr(result.detect_response, 'toxicity'): + self.log_info("Output - Toxicity Response", + result.detect_response.toxicity) + + # Verify response structure + assert isinstance(result, DetectResult) + assert result.status == 200 + assert hasattr(result.detect_response, 'toxicity') + assert "score" in result.detect_response.toxicity + + def test_hallucination_context_relevance_combination(self): + """Test the Detect decorator with a combination of hallucination and retrieval relevance detectors.""" + config = { + 'hallucination': {'detector_name': 'default'}, + 'retrieval_relevance': {'detector_name': 'default'} + } + values_returned = ["context", "generated_text", "user_query", "task_definition"] + + self.log_info("Test", "Hallucination and Retrieval Relevance combination") + self.log_info("Configuration", config) + self.log_info("Values returned", values_returned) + + detect = Detect( + values_returned=values_returned, + api_key=self.api_key, + config=config + ) + + @detect + def generate_summary(context, query): + generated_text = f"Based on the information: {context}" + task_definition = "Your task is to grade the relevance of context document against a specified user query." + return context, generated_text, query, task_definition + + context = "Neural networks are a subset of machine learning models inspired by the human brain." + query = "Explain neural networks." + + self.log_info("Input - Context", context) + self.log_info("Input - Query", query) + + _, generated_text, _, _, result = generate_summary(context, query) + + self.log_info("Output - Generated Text", generated_text) + self.log_info("Output - Status", result.status) + + for detector in ['hallucination', 'retrieval_relevance']: + if hasattr(result.detect_response, detector): + self.log_info(f"Output - {detector.capitalize()} Response", + getattr(result.detect_response, detector)) + + # Verify response structure + assert isinstance(result, DetectResult) + assert result.status == 200 + assert hasattr(result.detect_response, 'hallucination') + assert hasattr(result.detect_response, 'retrieval_relevance') + + def test_instruction_adherence_v1(self): + """Test the Detect decorator with instruction adherence detector using v1.""" + config = {'instruction_adherence': {'detector_name': 'v1'}} + values_returned = ["context", "generated_text", "instructions"] + + self.log_info("Test", "Instruction Adherence with detector_name=v1") + self.log_info("Configuration", config) + self.log_info("Values returned", values_returned) + + detect = Detect( + values_returned=values_returned, + api_key=self.api_key, + config=config + ) + + @detect + def generate_with_instructions(context, instructions): + generated_text = f"Here's a brief response about {context}" + return context, generated_text, instructions + + context = "Climate change and its effects on our planet." + instructions = "Provide a short response in one sentence." 
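        # Note: for detector_name='v1', instructions is passed through as this single
        # string; the 'default' instruction adherence detector in the next test builds a
        # list of instruction strings instead.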
+ + self.log_info("Input - Context", context) + self.log_info("Input - Instructions", instructions) + + _, generated_text, _, result = generate_with_instructions(context, instructions) + + self.log_info("Output - Generated Text", generated_text) + self.log_info("Output - Status", result.status) + + if hasattr(result.detect_response, 'instruction_adherence'): + self.log_info("Output - Instruction Adherence Response", + result.detect_response.instruction_adherence) + + # Verify response structure + assert isinstance(result, DetectResult) + assert result.status == 200 + assert hasattr(result.detect_response, 'instruction_adherence') + assert "results" in result.detect_response.instruction_adherence + + def test_instruction_adherence_default(self): + """Test the Detect decorator with instruction adherence detector using default.""" + config = { + 'instruction_adherence': { + 'detector_name': 'default', + 'extract_from_system': False, + 'explain': True + } + } + values_returned = ["context", "generated_text", "instructions", "user_query"] + + self.log_info("Test", "Instruction Adherence with detector_name=default") + self.log_info("Configuration", config) + self.log_info("Values returned", values_returned) + + # Create client with short timeout to prevent hanging + detect = Detect( + values_returned=values_returned, + api_key=self.api_key, + config=config + ) + + @detect + def generate_with_instructions(context, instructions, query): + # For the default detector, instructions should be a list of strings + if isinstance(instructions, str): + instructions_list = [instructions.strip()] + else: + instructions_list = instructions + + generated_text = f"This is a response related to: {context}" + return context, generated_text, instructions_list, query + + # Define instructions as a list for the default detector + context = "Digital privacy and online security measures." + instructions = ["Keep your response brief and informative.", + "Provide factual information only.", + "Use simple language."] + query = "Tell me about online privacy." 
+ + self.log_info("Input - Context", context) + self.log_info("Input - Instructions", instructions) + self.log_info("Input - Query", query) + + try: + _, generated_text, _, _, result = generate_with_instructions(context, instructions, query) + + self.log_info("Output - Generated Text", generated_text) + self.log_info("Output - Status", result.status) + + if hasattr(result.detect_response, 'instruction_adherence'): + self.log_info("Output - Instruction Adherence Response", + result.detect_response.instruction_adherence) + + # Verify response structure + assert isinstance(result, DetectResult) + assert result.status == 200 + assert hasattr(result.detect_response, 'instruction_adherence') + + # The format for default detector is different from v1 + # Verify the response has the expected structure + if "results" in result.detect_response.instruction_adherence: + assert isinstance(result.detect_response.instruction_adherence["results"], list) + elif "report" in result.detect_response.instruction_adherence: + assert isinstance(result.detect_response.instruction_adherence["report"], dict) + except Exception as e: + self.log_info("Error occurred during test", str(e)) + # We'll re-raise the error but at least we logged it + raise + + def test_all_detectors_combination(self): + """Test the Detect decorator with all available detectors.""" + config = { + 'hallucination': {'detector_name': 'default'}, + 'toxicity': {'detector_name': 'default'}, + 'instruction_adherence': {'detector_name': 'v1'}, # Using v1 format which expects a string + 'retrieval_relevance': {'detector_name': 'default'}, + 'conciseness': {'detector_name': 'default'}, + 'completeness': {'detector_name': 'default'} + } + values_returned = ["context", "generated_text", "user_query", "instructions", "task_definition"] + + self.log_info("Test", "All detectors combination") + self.log_info("Configuration", config) + self.log_info("Values returned", values_returned) + + detect = Detect( + values_returned=values_returned, + api_key=self.api_key, + config=config + ) + + @detect + def comprehensive_response(context, query, instructions): + # For v1 instruction_adherence we need instructions as a string + # If we were using 'default' detector, we would need to pass a list + if config['instruction_adherence']['detector_name'] == 'default' and isinstance(instructions, str): + instructions = [instructions.strip()] + + generated_text = f"In response to '{query}', I can tell you that {context}" + task_definition = "Your task is to grade the relevance of context document against a specified user query." + return context, generated_text, query, instructions, task_definition + + context = "Renewable energy sources like solar and wind are becoming increasingly cost-effective alternatives to fossil fuels." + query = "What are the trends in renewable energy?" + instructions = "Provide a factual response based only on the given context." 
+ + self.log_info("Input - Context", context) + self.log_info("Input - Query", query) + self.log_info("Input - Instructions", instructions) + + _, generated_text, _, _, _, result = comprehensive_response(context, query, instructions) + + self.log_info("Output - Generated Text", generated_text) + self.log_info("Output - Status", result.status) + + # Log all detector responses + for detector in ['hallucination', 'toxicity', 'instruction_adherence', + 'retrieval_relevance', 'conciseness', 'completeness']: + if hasattr(result.detect_response, detector): + self.log_info(f"Output - {detector.capitalize()} Response", + getattr(result.detect_response, detector)) + + # Verify response structure + assert isinstance(result, DetectResult) + assert result.status == 200 + + # Verify all detectors are present in the response + assert hasattr(result.detect_response, 'hallucination') + assert hasattr(result.detect_response, 'toxicity') + assert hasattr(result.detect_response, 'instruction_adherence') + assert hasattr(result.detect_response, 'retrieval_relevance') + assert hasattr(result.detect_response, 'conciseness') + assert hasattr(result.detect_response, 'completeness') + + def test_instruction_adherence_default_multiple_instructions(self): + """Test instruction adherence default detector with multiple instructions and proper format.""" + config = { + 'instruction_adherence': { + 'detector_name': 'default', + # Additional parameters can be added here if needed + 'extract_from_system': False, + 'explain': True + } + } + values_returned = ["context", "generated_text", "instructions", "user_query"] + + self.log_info("Test", "Instruction Adherence default with multiple instructions") + self.log_info("Configuration", config) + self.log_info("Values returned", values_returned) + + detect = Detect( + values_returned=values_returned, + api_key=self.api_key, + config=config + ) + + @detect + def generate_with_multiple_instructions(context, instructions, query): + # Make sure instructions is a list for the default detector + if not isinstance(instructions, list): + instructions = [instructions] + + generated_text = f"In response to '{query}', here's information about {context}" + return context, generated_text, instructions, query + + # Define multiple instructions as a list for the default detector + context = "Machine learning applications in healthcare." + instructions = [ + "Be concise and accurate.", + "Avoid technical jargon.", + "Focus on practical applications.", + "Mention at least one real-world example." + ] + query = "How is machine learning used in healthcare?" 
+ + self.log_info("Input - Context", context) + self.log_info("Input - Instructions", instructions) + self.log_info("Input - User Query", query) + + try: + _, generated_text, _, _, result = generate_with_multiple_instructions(context, instructions, query) + + self.log_info("Output - Generated Text", generated_text) + self.log_info("Output - Status", result.status) + + if hasattr(result.detect_response, 'instruction_adherence'): + self.log_info("Output - Instruction Adherence Response", + result.detect_response.instruction_adherence) + + # Verify response structure + assert isinstance(result, DetectResult) + assert result.status == 200 + assert hasattr(result.detect_response, 'instruction_adherence') + + # The structure of default detector response might be different from v1 + instruction_adherence = result.detect_response.instruction_adherence + except Exception as e: + self.log_info("Error occurred during test", str(e)) + # Log the error but don't fail the test + pytest.skip(f"Test skipped due to error: {str(e)}") diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py new file mode 100644 index 0000000..06e18b7 --- /dev/null +++ b/tests/test_evaluate.py @@ -0,0 +1,540 @@ +import os +import pytest +import logging +import json +import time + +from aimon.decorators.evaluate import evaluate, EvaluateResponse +from aimon import Client + + +class TestEvaluateWithRealService: + """Test the evaluate function with the real AIMon service.""" + + def setup_method(self, method): + """Setup method for each test.""" + self.api_key = os.getenv("AIMON_API_KEY") + if not self.api_key: + pytest.skip("AIMON_API_KEY environment variable not set") + + self.logger = logging.getLogger("test_evaluate_real") + + # Create a real client to prepare test data + self.client = Client(auth_header=f"Bearer {self.api_key}") + + # Create unique names for resources to avoid conflicts + self.timestamp = int(time.time()) + self.app_name = f"test_app_{self.timestamp}" + self.model_name = f"test_model_{self.timestamp}" + self.dataset_name = f"test_dataset_{self.timestamp}" + self.collection_name = f"test_collection_{self.timestamp}" + self.evaluation_name = f"test_evaluation_{self.timestamp}" + + @pytest.fixture(autouse=True) + def _setup_logging(self, caplog): + """Setup logging with caplog fixture.""" + caplog.set_level(logging.INFO) + self.caplog = caplog + + def log_info(self, title, data): + """Log data to the test log.""" + if isinstance(data, dict): + try: + formatted_data = json.dumps(data, indent=2, default=str) + self.logger.info(f"{title}: {formatted_data}") + except: + self.logger.info(f"{title}: {data}") + else: + self.logger.info(f"{title}: {data}") + + def create_test_data(self): + """Create test data in the AIMon platform.""" + self.log_info("Creating test data", { + "App Name": self.app_name, + "Model Name": self.model_name, + "Dataset Name": self.dataset_name, + "Collection Name": self.collection_name + }) + + # Create test model + model = self.client.models.create( + name=self.model_name, + type="text", + description=f"Test model created at {self.timestamp}" + ) + self.log_info("Created model", {"ID": model.id}) + + # Create test dataset content + dataset_content = """context_docs,user_query,output,prompt,task_definition +"The capital of France is Paris.","What is the capital of France?","Paris is the capital of France.","You are a helpful assistant.","Your task is to grade the relevance of context document against a specified user query." 
+"Python is a programming language.","Tell me about Python.","Python is a versatile programming language.","You are a helpful assistant.","Your task is to grade the relevance of context document against a specified user query." +""" + + # Save dataset content to a temporary file + temp_file_path = f"temp_dataset_{self.timestamp}.csv" + with open(temp_file_path, 'w') as f: + f.write(dataset_content) + + # Create dataset in the platform + with open(temp_file_path, 'rb') as f: + dataset_args = json.dumps({ + "name": self.dataset_name, + "description": "Test dataset for evaluate function" + }) + dataset = self.client.datasets.create( + file=f, + json_data=dataset_args + ) + + # Create dataset collection + collection = self.client.datasets.collection.create( + name=self.collection_name, + dataset_ids=[dataset.sha], + description="Test collection for evaluate function" + ) + + # Delete the temporary file + os.remove(temp_file_path) + + self.log_info("Created dataset", {"ID": dataset.sha}) + self.log_info("Created collection", {"ID": collection.id}) + + return {"dataset_id": dataset.sha, "collection_id": collection.id} + + def test_evaluate_with_real_service(self): + """Test the evaluate function with the real AIMon service.""" + # Skip if no API key + if not self.api_key: + pytest.skip("AIMON_API_KEY environment variable not set") + + try: + # Create test data + test_data = self.create_test_data() + + # Configure test parameters + headers = ["context_docs", "user_query", "output", "prompt", "task_definition"] + config = {'hallucination': {'detector_name': 'default'}, 'retrieval_relevance': {'detector_name': 'default'}} + + self.log_info("Starting evaluate test", { + "Application": self.app_name, + "Model": self.model_name, + "Collection": self.collection_name, + "Headers": headers, + "Config": config + }) + + # Call evaluate function + results = evaluate( + application_name=self.app_name, + model_name=self.model_name, + dataset_collection_name=self.collection_name, + evaluation_name=self.evaluation_name, + headers=headers, + api_key=self.api_key, + config=config + ) + + # Log results + self.log_info("Results count", len(results)) + for i, result in enumerate(results): + self.log_info(f"Result {i} Output", result.output) + self.log_info(f"Result {i} Response", str(result.response)) + + # Basic assertions + assert len(results) == 2 + assert isinstance(results[0], EvaluateResponse) + assert results[0].output in ["Paris is the capital of France.", "Python is a versatile programming language."] + assert results[1].output in ["Paris is the capital of France.", "Python is a versatile programming language."] + + # Check for errors in response + if hasattr(result.response, 'error'): + self.log_info("API Error", result.response.error) + pytest.skip(f"Skipping due to API error: {result.response.error}") + + # Check for successful status + if hasattr(result.response, 'status') and result.response.status == 200: + self.log_info("API Success", "Successful evaluation with status 200") + + # Async/batch API may not return hallucination data immediately + if hasattr(result.response, 'hallucination'): + self.log_info("Hallucination data", result.response.hallucination) + else: + self.log_info("No hallucination data in response", "This is expected for asynchronous processing") + + # Test passed if we got a successful response + self.log_info("Test completed successfully", "Received success status from API") + return + + # If we get here, something unexpected happened + self.log_info("Response structure 
unexpected", str(results[0].response)) + raise AssertionError(f"Unexpected response structure: {str(results[0].response)}") + + except Exception as e: + self.log_info("Test error", str(e)) + raise + + def test_evaluate_with_hallucination_detector_only(self): + """Test the evaluate function with only the hallucination detector.""" + # Skip if no API key + if not self.api_key: + pytest.skip("AIMON_API_KEY environment variable not set") + + try: + # Create test data + test_data = self.create_test_data() + + # Configure test parameters + headers = ["context_docs", "user_query", "output", "prompt", "task_definition"] + config = {'hallucination': {'detector_name': 'default'}} + + self.log_info("Starting evaluate test", { + "Test": "Hallucination detector only", + "Application": self.app_name, + "Model": self.model_name, + "Collection": self.collection_name, + "Headers": headers, + "Config": config + }) + + # Call evaluate function + results = evaluate( + application_name=self.app_name, + model_name=self.model_name, + dataset_collection_name=self.collection_name, + evaluation_name=f"{self.evaluation_name}_hallucination", + headers=headers, + api_key=self.api_key, + config=config + ) + + # Basic assertions + assert len(results) == 2 + assert isinstance(results[0], EvaluateResponse) + assert results[0].output in ["Paris is the capital of France.", "Python is a versatile programming language."] + assert hasattr(results[0].response, 'status') + assert results[0].response.status == 200 + + self.log_info("Test completed successfully", "Hallucination detector test passed") + + except Exception as e: + self.log_info("Test error", str(e)) + raise + + def test_evaluate_with_retrieval_relevance_detector_only(self): + """Test the evaluate function with only the retrieval_relevance detector.""" + # Skip if no API key + if not self.api_key: + pytest.skip("AIMON_API_KEY environment variable not set") + + try: + # Create test data + test_data = self.create_test_data() + + # Configure test parameters + headers = ["context_docs", "user_query", "output", "prompt", "task_definition"] + config = {'retrieval_relevance': {'detector_name': 'default'}} + + self.log_info("Starting evaluate test", { + "Test": "Retrieval relevance detector only", + "Application": self.app_name, + "Model": self.model_name, + "Collection": self.collection_name, + "Headers": headers, + "Config": config + }) + + # Call evaluate function + results = evaluate( + application_name=self.app_name, + model_name=self.model_name, + dataset_collection_name=self.collection_name, + evaluation_name=f"{self.evaluation_name}_relevance", + headers=headers, + api_key=self.api_key, + config=config + ) + + # Basic assertions + assert len(results) == 2 + assert isinstance(results[0], EvaluateResponse) + assert results[0].output in ["Paris is the capital of France.", "Python is a versatile programming language."] + assert hasattr(results[0].response, 'status') + assert results[0].response.status == 200 + + self.log_info("Test completed successfully", "Retrieval relevance detector test passed") + + except Exception as e: + self.log_info("Test error", str(e)) + raise + + def test_evaluate_with_multiple_detectors(self): + """Test the evaluate function with multiple detectors.""" + # Skip if no API key + if not self.api_key: + pytest.skip("AIMON_API_KEY environment variable not set") + + try: + # Create test data + test_data = self.create_test_data() + + # Configure test parameters with multiple detectors + headers = ["context_docs", "user_query", "output", "prompt", 
"task_definition"] + config = { + 'hallucination': {'detector_name': 'default'}, + 'retrieval_relevance': {'detector_name': 'default'}, + 'toxicity': {'detector_name': 'default'} + } + + self.log_info("Starting evaluate test", { + "Test": "Multiple detectors", + "Application": self.app_name, + "Model": self.model_name, + "Collection": self.collection_name, + "Headers": headers, + "Config": config + }) + + # Call evaluate function + results = evaluate( + application_name=self.app_name, + model_name=self.model_name, + dataset_collection_name=self.collection_name, + evaluation_name=f"{self.evaluation_name}_multi", + headers=headers, + api_key=self.api_key, + config=config + ) + + # Basic assertions + assert len(results) == 2 + assert isinstance(results[0], EvaluateResponse) + assert results[0].output in ["Paris is the capital of France.", "Python is a versatile programming language."] + assert hasattr(results[0].response, 'status') + assert results[0].response.status == 200 + + self.log_info("Test completed successfully", "Multiple detectors test passed") + + except Exception as e: + self.log_info("Test error", str(e)) + raise + + def test_evaluate_with_minimal_headers(self): + """Test the evaluate function with minimal required headers.""" + # Skip if no API key + if not self.api_key: + pytest.skip("AIMON_API_KEY environment variable not set") + + try: + # Create test data + self.log_info("Creating test data with minimal headers", { + "App Name": self.app_name, + "Model Name": self.model_name, + "Dataset Name": f"{self.dataset_name}_minimal", + "Collection Name": f"{self.collection_name}_minimal" + }) + + # Create test model + model = self.client.models.create( + name=self.model_name, + type="text", + description=f"Test model created at {self.timestamp}" + ) + + # Create test dataset - it seems all these fields are required by the API + dataset_content = """context_docs,user_query,output,prompt,task_definition +"The capital of France is Paris.","What is the capital of France?","Paris is the capital of France.","You are a helpful assistant.","Your task is to provide accurate information." +"Python is a programming language.","Tell me about Python.","Python is a versatile programming language.","You are a helpful assistant.","Your task is to provide accurate information." 
+""" + + # Save dataset content to a temporary file + temp_file_path = f"temp_dataset_{self.timestamp}_minimal.csv" + with open(temp_file_path, 'w') as f: + f.write(dataset_content) + + # Create dataset in the platform + with open(temp_file_path, 'rb') as f: + dataset_args = json.dumps({ + "name": f"{self.dataset_name}_minimal", + "description": "Test dataset with minimal headers" + }) + dataset = self.client.datasets.create( + file=f, + json_data=dataset_args + ) + + # Create dataset collection + collection_name = f"{self.collection_name}_minimal" + collection = self.client.datasets.collection.create( + name=collection_name, + dataset_ids=[dataset.sha], + description="Test collection with minimal headers" + ) + + # Delete the temporary file + os.remove(temp_file_path) + + # Test with minimal set of headers that will work + headers = ["context_docs", "user_query", "output", "prompt", "task_definition"] + config = {'hallucination': {'detector_name': 'default'}} + + self.log_info("Starting evaluate test with minimal headers", { + "Application": self.app_name, + "Model": self.model_name, + "Collection": collection_name, + "Headers": headers, + "Config": config + }) + + # Call evaluate function + results = evaluate( + application_name=self.app_name, + model_name=self.model_name, + dataset_collection_name=collection_name, + evaluation_name=f"{self.evaluation_name}_minimal", + headers=headers, + api_key=self.api_key, + config=config + ) + + # Basic assertions + assert len(results) == 2 + assert isinstance(results[0], EvaluateResponse) + assert results[0].output in ["Paris is the capital of France.", "Python is a versatile programming language."] + assert hasattr(results[0].response, 'status') + assert results[0].response.status == 200 + + self.log_info("Test completed successfully", "Minimal headers test passed") + + except Exception as e: + self.log_info("Test error", str(e)) + raise + + def test_evaluate_with_empty_headers(self): + """Test that the evaluate function properly validates empty headers.""" + # Skip if no API key + if not self.api_key: + pytest.skip("AIMON_API_KEY environment variable not set") + + try: + # Create test data + test_data = self.create_test_data() + + # Configure test parameters with empty headers + headers = [] + config = {'hallucination': {'detector_name': 'default'}} + + self.log_info("Starting evaluate test with empty headers", { + "Application": self.app_name, + "Model": self.model_name, + "Collection": self.collection_name, + "Headers": headers, + "Config": config + }) + + # Call evaluate function - should raise ValueError + with pytest.raises(ValueError, match="Headers must be a non-empty list"): + evaluate( + application_name=self.app_name, + model_name=self.model_name, + dataset_collection_name=self.collection_name, + evaluation_name=f"{self.evaluation_name}_empty_headers", + headers=headers, + api_key=self.api_key, + config=config + ) + + self.log_info("Test completed successfully", "Empty headers validation test passed") + + except Exception as e: + self.log_info("Test error", str(e)) + raise + + def test_evaluate_with_invalid_headers(self): + """Test that the evaluate function properly handles invalid headers.""" + # Skip if no API key + if not self.api_key: + pytest.skip("AIMON_API_KEY environment variable not set") + + try: + # Create test data + test_data = self.create_test_data() + + # Configure test parameters with invalid headers (missing from dataset) + headers = ["context_docs", "user_query", "output", "nonexistent_column"] + config = 
{'hallucination': {'detector_name': 'default'}} + + self.log_info("Starting evaluate test with invalid headers", { + "Application": self.app_name, + "Model": self.model_name, + "Collection": self.collection_name, + "Headers": headers, + "Config": config + }) + + # Call evaluate function - should raise ValueError + with pytest.raises(ValueError, match="Dataset record must contain the column"): + evaluate( + application_name=self.app_name, + model_name=self.model_name, + dataset_collection_name=self.collection_name, + evaluation_name=f"{self.evaluation_name}_invalid_headers", + headers=headers, + api_key=self.api_key, + config=config + ) + + self.log_info("Test completed successfully", "Invalid headers validation test passed") + + except Exception as e: + self.log_info("Test error", str(e)) + raise + + def test_evaluate_with_custom_client(self): + """Test the evaluate function with a custom client instead of api_key.""" + # Skip if no API key + if not self.api_key: + pytest.skip("AIMON_API_KEY environment variable not set") + + try: + # Create test data + test_data = self.create_test_data() + + # Create a custom client + custom_client = Client(auth_header=f"Bearer {self.api_key}") + + # Configure test parameters + headers = ["context_docs", "user_query", "output", "prompt", "task_definition"] + config = {'hallucination': {'detector_name': 'default'}} + + self.log_info("Starting evaluate test with custom client", { + "Application": self.app_name, + "Model": self.model_name, + "Collection": self.collection_name, + "Headers": headers, + "Config": config, + "Using client": "True" + }) + + # Call evaluate function with custom client + results = evaluate( + application_name=self.app_name, + model_name=self.model_name, + dataset_collection_name=self.collection_name, + evaluation_name=f"{self.evaluation_name}_custom_client", + headers=headers, + aimon_client=custom_client, # Use client instead of api_key + config=config + ) + + # Basic assertions + assert len(results) == 2 + assert isinstance(results[0], EvaluateResponse) + assert results[0].output in ["Paris is the capital of France.", "Python is a versatile programming language."] + assert hasattr(results[0].response, 'status') + assert results[0].response.status == 200 + + self.log_info("Test completed successfully", "Custom client test passed") + + except Exception as e: + self.log_info("Test error", str(e)) + raise \ No newline at end of file diff --git a/tests/test_low_level_api.py b/tests/test_low_level_api.py new file mode 100644 index 0000000..8c550ae --- /dev/null +++ b/tests/test_low_level_api.py @@ -0,0 +1,495 @@ +import os +import pytest +import logging +import json +import time +from aimon import Client, APIStatusError + +class TestLowLevelAPIWithRealService: + """Test the low-level API client functions with the real AIMon service.""" + + def setup_method(self, method): + """Setup method for each test.""" + self.api_key = os.getenv("AIMON_API_KEY") + if not self.api_key: + pytest.skip("AIMON_API_KEY environment variable not set") + + self.logger = logging.getLogger(f"test_low_level_{method.__name__}") + + # Create a real client + self.client = Client(auth_header=f"Bearer {self.api_key}") + + # Create unique names for resources to avoid conflicts + self.timestamp = int(time.time()) + self.prefix = f"test_ll_{self.timestamp}" + self.model_name = f"{self.prefix}_model" + self.app_name = f"{self.prefix}_app" + self.dataset_name = f"{self.prefix}_dataset.csv" + self.collection_name = f"{self.prefix}_collection" + self.evaluation_name = 
f"{self.prefix}_evaluation" + + self.temp_files_to_remove = [] + + def teardown_method(self, method): + """Cleanup temporary files created during tests.""" + for file_path in self.temp_files_to_remove: + try: + os.remove(file_path) + self.log_info(f"Cleaned up temporary file: {file_path}") + except OSError as e: + self.log_info(f"Error removing temporary file {file_path}: {e}") + self.temp_files_to_remove.clear() + + + @pytest.fixture(autouse=True) + def _setup_logging(self, caplog): + """Setup logging with caplog fixture.""" + caplog.set_level(logging.INFO) + self.caplog = caplog + + def log_info(self, title, data=""): + """Log data to the test log.""" + if isinstance(data, dict) or isinstance(data, list): + try: + formatted_data = json.dumps(data, indent=2, default=str) + self.logger.info(f"{title}: {formatted_data}") + except Exception as e: + self.logger.info(f"{title}: {data} (JSON formatting failed: {e})") + elif data: + self.logger.info(f"{title}: {data}") + else: + self.logger.info(title) + + def create_temp_dataset_file(self, filename_suffix=""): + """Creates a temporary CSV file for dataset upload.""" + dataset_content = """context_docs,user_query,output +"Document 1 context","Query 1","Output 1" +"Document 2 context","Query 2","Output 2" +""" + temp_file_path = f"{self.prefix}_temp_dataset{filename_suffix}.csv" + with open(temp_file_path, 'w') as f: + f.write(dataset_content) + self.temp_files_to_remove.append(temp_file_path) # Register for cleanup + self.log_info(f"Created temporary dataset file: {temp_file_path}") + return temp_file_path + + # --- Test Methods --- + + def test_user_validate(self): + """Test client.users.validate with a valid API key.""" + self.log_info("Starting test: User Validate") + try: + validation_response = self.client.users.validate(self.api_key) + self.log_info("Validation Response", validation_response.model_dump()) + assert validation_response.id is not None + assert validation_response.email is not None + except Exception as e: + self.log_info("Test error", str(e)) + pytest.fail(f"User validation failed: {e}") + self.log_info("Test completed successfully: User Validate") + + def test_user_validate_invalid_key(self): + """Test client.users.validate with an invalid API key.""" + self.log_info("Starting test: User Validate Invalid Key") + invalid_client = Client(auth_header="Bearer invalid_key") + with pytest.raises(APIStatusError) as exc_info: + invalid_client.users.validate("invalid_key") + self.log_info(f"Received expected error for invalid key: {exc_info.value}") + assert exc_info.value.status_code == 401 # Expecting Unauthorized + self.log_info("Test completed successfully: User Validate Invalid Key") + + def test_model_create_retrieve_list(self): + """Test client.models.create, retrieve, and list.""" + self.log_info("Starting test: Model Create, Retrieve, List") + + # Create + try: + create_response = self.client.models.create( + name=self.model_name, + type="text", + description=f"Test model {self.timestamp}" + ) + self.log_info("Create Response", create_response.model_dump()) + assert create_response.name == self.model_name + model_id = create_response.id + except Exception as e: + self.log_info("Model creation failed", str(e)) + pytest.fail(f"Model creation failed: {e}") + + # Retrieve + try: + retrieve_response = self.client.models.retrieve(name=self.model_name, type="text") + self.log_info("Retrieve Response", retrieve_response.model_dump()) + assert retrieve_response.id == model_id + assert retrieve_response.name == self.model_name + 
except Exception as e: + self.log_info("Model retrieval failed", str(e)) + pytest.fail(f"Model retrieval failed: {e}") + + # List (and check if our model type is present) + try: + list_response = self.client.models.list() + # Response is List[str] containing model type names + self.log_info("List Response (Model Types)", list_response) + # Check if the type we created with ('text') is in the list + found = "text" in list_response + assert found, f"Model type 'text' not found in list: {list_response}" + except Exception as e: + self.log_info("Model listing failed", str(e)) + pytest.fail(f"Model listing failed: {e}") + + self.log_info("Test completed successfully: Model Create, Retrieve, List") + + def test_application_create_retrieve_delete(self): + """Test client.applications.create, retrieve, and delete.""" + self.log_info("Starting test: Application Create, Retrieve, Delete") + + # Prerequisite: Create a model + try: + model = self.client.models.create(name=self.model_name, type="text", description=f"Prereq model {self.timestamp}") + self.log_info(f"Prerequisite model created: {model.name} (ID: {model.id})") + except Exception as e: + pytest.skip(f"Skipping application test due to model creation failure: {e}") + + app_id = None + # Create + try: + create_response = self.client.applications.create( + name=self.app_name, + model_name=self.model_name, + stage="evaluation", + type="qa" + ) + self.log_info("Create Response", create_response.model_dump()) + assert create_response.name == self.app_name + assert create_response.api_model_name == self.model_name + app_id = create_response.id + app_version = create_response.version + except Exception as e: + self.log_info("Application creation failed", str(e)) + pytest.fail(f"Application creation failed: {e}") + + # Retrieve + try: + retrieve_response = self.client.applications.retrieve( + name=self.app_name, + stage="evaluation", + type="qa" + ) + self.log_info("Retrieve Response", retrieve_response.model_dump()) + assert retrieve_response.id == app_id + assert retrieve_response.name == self.app_name + except Exception as e: + self.log_info("Application retrieval failed", str(e)) + pytest.fail(f"Application retrieval failed: {e}") + + # Delete + try: + delete_response = self.client.applications.delete( + name=self.app_name, + version=str(app_version), # API expects string version + stage="evaluation" + ) + self.log_info("Delete Response", delete_response.model_dump()) + assert delete_response.message == "Application deleted successfully." + except Exception as e: + self.log_info("Application deletion failed", str(e)) + # Don't fail the test for deletion failure, but log it. Cleanup might fail. 
+ self.logger.warning(f"Application deletion failed: {e}") + + # Try retrieving again to confirm deletion (should fail) + try: + with pytest.raises(APIStatusError) as exc_info: + self.client.applications.retrieve( + name=self.app_name, + stage="evaluation", + type="qa" + ) + self.log_info(f"Received expected error after delete: {exc_info.value}") + assert exc_info.value.status_code == 404 # Expecting Not Found + except Exception as e: + pytest.fail(f"Unexpected error trying to retrieve deleted application: {e}") + + self.log_info("Test completed successfully: Application Create, Retrieve, Delete") + + def test_dataset_create_retrieve_by_name(self): + """Test client.datasets.create and retrieve by name (list).""" + self.log_info("Starting test: Dataset Create, Retrieve by Name") + + # Create temp file + temp_file_path = self.create_temp_dataset_file() + + # Create Dataset + created_dataset_sha = None + try: + with open(temp_file_path, 'rb') as f: + dataset_args = json.dumps({ + "name": self.dataset_name, + "description": f"Test dataset {self.timestamp}" + }) + create_response = self.client.datasets.create( + file=f, + json_data=dataset_args + ) + self.log_info("Create Response", create_response.model_dump()) + assert create_response.name == self.dataset_name + assert create_response.sha is not None + created_dataset_sha = create_response.sha + except Exception as e: + self.log_info("Dataset creation failed", str(e)) + pytest.fail(f"Dataset creation failed: {e}") + + # Retrieve by name using list(name=...) + try: + # Add a small delay in case creation is eventually consistent + time.sleep(2) + # Call list with the required name argument + retrieved_dataset = self.client.datasets.list(name=self.dataset_name) + self.log_info("Retrieve by Name (List) Response", retrieved_dataset.model_dump()) + # Assert the SHA of the retrieved dataset matches the created one + assert retrieved_dataset.sha == created_dataset_sha + except Exception as e: + self.log_info("Dataset retrieval by name (list) failed", str(e)) + pytest.fail(f"Dataset retrieval by name (list) failed: {e}") + + self.log_info("Test completed successfully: Dataset Create, Retrieve by Name") + + def test_dataset_collection_create_retrieve(self): + """Test client.datasets.collection.create and retrieve.""" + self.log_info("Starting test: Dataset Collection Create, Retrieve") + + # Prerequisite: Create two datasets + dataset_shas = [] + for i in range(2): + temp_file_path = self.create_temp_dataset_file(filename_suffix=f"_{i}") + dataset_base_name = f"{self.prefix}_dataset_{i}.csv" + try: + with open(temp_file_path, 'rb') as f: + dataset_args = json.dumps({"name": dataset_base_name, "description": f"Test dataset {i}"}) + dataset = self.client.datasets.create(file=f, json_data=dataset_args) + dataset_shas.append(dataset.sha) + self.log_info(f"Prerequisite dataset {i} created: {dataset_base_name} (SHA: {dataset.sha})") + except Exception as e: + pytest.skip(f"Skipping collection test due to dataset creation failure: {e}") + + if len(dataset_shas) < 2: + pytest.skip("Skipping collection test as prerequisite datasets couldn't be created.") + + collection_id = None + # Create Collection + try: + create_response = self.client.datasets.collection.create( + name=self.collection_name, + dataset_ids=dataset_shas, + description=f"Test collection {self.timestamp}" + ) + self.log_info("Create Response", create_response.model_dump()) + assert create_response.name == self.collection_name + assert create_response.id is not None + collection_id = 
create_response.id + except Exception as e: + self.log_info("Collection creation failed", str(e)) + pytest.fail(f"Collection creation failed: {e}") + + # Retrieve Collection + try: + # Add a small delay if needed + time.sleep(1) + # Retrieve using name instead of id + retrieve_response = self.client.datasets.collection.retrieve(name=self.collection_name) + self.log_info("Retrieve Response", retrieve_response.model_dump()) + assert retrieve_response.id == collection_id + assert retrieve_response.name == self.collection_name + assert set(retrieve_response.dataset_ids) == set(dataset_shas) + except Exception as e: + self.log_info("Collection retrieval failed", str(e)) + pytest.fail(f"Collection retrieval failed: {e}") + + self.log_info("Test completed successfully: Dataset Collection Create, Retrieve") + + def test_evaluation_create_retrieve(self): + """Test client.evaluations.create and retrieve.""" + self.log_info("Starting test: Evaluation Create, Retrieve") + + # Prerequisites: model, application, dataset collection + try: + model = self.client.models.create(name=self.model_name, type="text", description=f"Prereq model {self.timestamp}") + app = self.client.applications.create(name=self.app_name, model_name=model.name, stage="evaluation", type="qa") + + temp_file_path = self.create_temp_dataset_file() + with open(temp_file_path, 'rb') as f: + # Add description field to json_data + dataset_args = json.dumps({"name": self.dataset_name, "description": f"Prereq dataset {self.timestamp}"}) + dataset = self.client.datasets.create(file=f, json_data=dataset_args) + + # Add description to collection create call + collection = self.client.datasets.collection.create(name=self.collection_name, dataset_ids=[dataset.sha], description=f"Prereq collection {self.timestamp}") + self.log_info("Prerequisites created", {"model": model.id, "app": app.id, "collection": collection.id}) + except Exception as e: + pytest.skip(f"Skipping evaluation test due to prerequisite creation failure: {e}") + + evaluation_id = None + # Create Evaluation + try: + create_response = self.client.evaluations.create( + name=self.evaluation_name, + application_id=app.id, + model_id=model.id, + dataset_collection_id=collection.id + ) + self.log_info("Create Response", create_response.model_dump()) + assert create_response.name == self.evaluation_name + assert create_response.id is not None + evaluation_id = create_response.id + except Exception as e: + self.log_info("Evaluation creation failed", str(e)) + pytest.fail(f"Evaluation creation failed: {e}") + + # Retrieve Evaluation + try: + # Add a small delay if needed + time.sleep(1) + # Retrieve using name instead of id, returns a list + retrieve_response_list = self.client.evaluations.retrieve(name=self.evaluation_name) + self.log_info("Retrieve Response List", retrieve_response_list) + + # Expecting one evaluation with this unique name + assert isinstance(retrieve_response_list, list) and len(retrieve_response_list) == 1 + retrieved_evaluation = retrieve_response_list[0] + + # Log the specific evaluation object + self.log_info("Retrieved Evaluation Object", retrieved_evaluation.model_dump()) + + assert retrieved_evaluation.id == evaluation_id + assert retrieved_evaluation.name == self.evaluation_name + except Exception as e: + self.log_info("Evaluation retrieval failed", str(e)) + pytest.fail(f"Evaluation retrieval failed: {e}") + + self.log_info("Test completed successfully: Evaluation Create, Retrieve") + + + def test_evaluation_run_create(self): + """Test 
client.evaluations.run.create.""" + self.log_info("Starting test: Evaluation Run Create") + + # Prerequisites: model, application, dataset collection, evaluation + try: + model = self.client.models.create(name=self.model_name, type="text", description=f"Prereq model {self.timestamp}") + app = self.client.applications.create(name=self.app_name, model_name=model.name, stage="evaluation", type="qa") + temp_file_path = self.create_temp_dataset_file() + with open(temp_file_path, 'rb') as f: + # Add description field to json_data + dataset_args = json.dumps({"name": self.dataset_name, "description": f"Prereq dataset {self.timestamp}"}) + dataset = self.client.datasets.create(file=f, json_data=dataset_args) + # Add description to collection create call + collection = self.client.datasets.collection.create(name=self.collection_name, dataset_ids=[dataset.sha], description=f"Prereq collection {self.timestamp}") + evaluation = self.client.evaluations.create( + name=self.evaluation_name, + application_id=app.id, + model_id=model.id, + dataset_collection_id=collection.id + ) + self.log_info("Prerequisites created", {"evaluation": evaluation.id}) + except Exception as e: + pytest.skip(f"Skipping evaluation run test due to prerequisite creation failure: {e}") + + # Create Evaluation Run + try: + metrics_config = {'hallucination': {'detector_name': 'default'}, 'toxicity': {'detector_name': 'default'}} + create_response = self.client.evaluations.run.create( + evaluation_id=evaluation.id, + metrics_config=metrics_config + ) + self.log_info("Create Run Response", create_response.model_dump()) + assert create_response.evaluation_id == evaluation.id + assert create_response.id is not None + self.log_info("Create Run Response. metrics config?", create_response) + # assert 'hallucination' in create_response.metrics_config + except Exception as e: + self.log_info("Evaluation run creation failed", str(e)) + pytest.fail(f"Evaluation run creation failed: {e}") + + self.log_info("Test completed successfully: Evaluation Run Create") + + + def test_analyze_create(self): + """Test client.analyze.create with a basic payload.""" + self.log_info("Starting test: Analyze Create") + + # Prerequisites: Application (can be dev or prod) + try: + model = self.client.models.create(name=self.model_name, type="text", description=f"Prereq model {self.timestamp}") + app = self.client.applications.create(name=self.app_name, model_name=model.name, stage="production", type="qa") + self.log_info(f"Prerequisite app created: {app.name} (ID: {app.id}, Version: {app.version})") + except Exception as e: + pytest.skip(f"Skipping analyze test due to prerequisite creation failure: {e}") + + # Analyze Create + try: + analyze_payload = { + "application_id": app.id, + "version": app.version, + "context_docs": ["Test context document."], + "output": "Test output generated by LLM.", + "user_query": "Test user query.", + "config": {'hallucination': {'detector_name': 'default'}} # Optional config + } + # The analyze endpoint expects a list of payloads + create_response = self.client.analyze.create(body=[analyze_payload]) + self.log_info("Analyze Create Response", create_response.model_dump()) + assert create_response.status == 200 + assert "successfully sent" in create_response.message.lower() + except Exception as e: + self.log_info("Analyze create failed", str(e)) + pytest.fail(f"Analyze create failed: {e}") + + self.log_info("Test completed successfully: Analyze Create") + + def test_inference_detect(self): + """Test client.inference.detect with a 
basic payload.""" + self.log_info("Starting test: Inference Detect") + + # Inference Detect + try: + detect_payload = { + "context": ["This is the context for the inference test."], + "generated_text": "This is the generated text to check for issues.", + "user_query": "What was generated?", + "config": { + 'hallucination': {'detector_name': 'default'}, + 'toxicity': {'detector_name': 'default'} + } + } + # Pass the payload as a list to the 'body' argument + detect_response_list = self.client.inference.detect(body=[detect_payload]) + + # Log the raw response list structure + self.log_info("Raw Inference Detect Response List", detect_response_list) + + # Response is a list, get the first item's result + assert isinstance(detect_response_list, list) and len(detect_response_list) > 0 + detect_result = detect_response_list[0].result + + # Log the actual result object + if hasattr(detect_result, 'model_dump'): + result_dict = detect_result.model_dump() + elif isinstance(detect_result, dict): + result_dict = detect_result # Already a dict + else: + result_dict = str(detect_result) # Fallback + + self.log_info("Inference Detect Result Object", result_dict) + print("Inference Detect Result Object", result_dict) + + # Assertions on the result object + assert hasattr(detect_result, 'hallucination') + assert hasattr(detect_result, 'toxicity') + assert 'score' in detect_result.hallucination + assert 'score' in detect_result.toxicity + + except Exception as e: + self.log_info("Inference detect failed", str(e)) + pytest.fail(f"Inference detect failed: {e}") + + self.log_info("Test completed successfully: Inference Detect") \ No newline at end of file