diff --git a/haystack/preview/components/file_converters/azure.py b/haystack/preview/components/file_converters/azure.py index be699ea04d..fb1ab76905 100644 --- a/haystack/preview/components/file_converters/azure.py +++ b/haystack/preview/components/file_converters/azure.py @@ -1,5 +1,6 @@ from pathlib import Path -from typing import List, Union, Dict, Any +from typing import List, Union, Dict, Any, Optional +import os from haystack.preview.lazy_imports import LazyImport from haystack.preview import component, Document, default_to_dict @@ -22,22 +23,33 @@ class AzureOCRDocumentConverter: to set up your resource. """ - def __init__(self, endpoint: str, api_key: str, model_id: str = "prebuilt-read"): + def __init__(self, endpoint: str, api_key: Optional[str] = None, model_id: str = "prebuilt-read"): """ Create an AzureOCRDocumentConverter component. :param endpoint: The endpoint of your Azure resource. - :param api_key: The key of your Azure resource. + :param api_key: The key of your Azure resource. It can be + explicitly provided or automatically read from the + environment variable AZURE_AI_API_KEY (recommended). :param model_id: The model ID of the model you want to use. Please refer to [Azure documentation](https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/choose-model-feature) for a list of available models. Default: `"prebuilt-read"`. """ azure_import.check() + if api_key is None: + try: + api_key = os.environ["AZURE_AI_API_KEY"] + except KeyError as e: + raise ValueError( + "AzureOCRDocumentConverter expects an Azure Credential key. " + "Set the AZURE_AI_API_KEY environment variable (recommended) or pass it explicitly." + ) from e + + self.api_key = api_key self.document_analysis_client = DocumentAnalysisClient( endpoint=endpoint, credential=AzureKeyCredential(api_key) ) self.endpoint = endpoint - self.api_key = api_key self.model_id = model_id @component.output_types(documents=List[Document], azure=List[Dict]) @@ -70,7 +82,7 @@ def to_dict(self) -> Dict[str, Any]: """ Serialize this component to a dictionary. """ - return default_to_dict(self, endpoint=self.endpoint, api_key=self.api_key, model_id=self.model_id) + return default_to_dict(self, endpoint=self.endpoint, model_id=self.model_id) @staticmethod def _convert_azure_result_to_document(result: "AnalyzeResult", file_suffix: str) -> Document: diff --git a/haystack/preview/components/websearch/serper_dev.py b/haystack/preview/components/websearch/serper_dev.py index 3023603e26..e034092f98 100644 --- a/haystack/preview/components/websearch/serper_dev.py +++ b/haystack/preview/components/websearch/serper_dev.py @@ -1,4 +1,5 @@ import json +import os import logging from typing import Dict, List, Optional, Any @@ -26,13 +27,15 @@ class SerperDevWebSearch: def __init__( self, - api_key: str, + api_key: Optional[str] = None, top_k: Optional[int] = 10, allowed_domains: Optional[List[str]] = None, search_params: Optional[Dict[str, Any]] = None, ): """ - :param api_key: API key for the SerperDev API. + :param api_key: API key for the SerperDev API. It can be + explicitly provided or automatically read from the + environment variable SERPERDEV_API_KEY (recommended). :param top_k: Number of documents to return. :param allowed_domains: List of domains to limit the search to. :param search_params: Additional parameters passed to the SerperDev API. @@ -40,6 +43,13 @@ def __init__( See the [Serper Dev website](https://serper.dev/) for more details. """ if api_key is None: + try: + api_key = os.environ["SERPERDEV_API_KEY"] + except KeyError as e: + raise ValueError( + "SerperDevWebSearch expects an API key. " + "Set the SERPERDEV_API_KEY environment variable (recommended) or pass it explicitly." + ) from e raise ValueError("API key for SerperDev API must be set.") self.api_key = api_key self.top_k = top_k @@ -51,11 +61,7 @@ def to_dict(self) -> Dict[str, Any]: Serialize this component to a dictionary. """ return default_to_dict( - self, - api_key=self.api_key, - top_k=self.top_k, - allowed_domains=self.allowed_domains, - search_params=self.search_params, + self, top_k=self.top_k, allowed_domains=self.allowed_domains, search_params=self.search_params ) @component.output_types(documents=List[Document], links=List[str]) diff --git a/releasenotes/notes/remove-api-key-from-serialization-2474a1539b86e233.yaml b/releasenotes/notes/remove-api-key-from-serialization-2474a1539b86e233.yaml new file mode 100644 index 0000000000..e1a879816e --- /dev/null +++ b/releasenotes/notes/remove-api-key-from-serialization-2474a1539b86e233.yaml @@ -0,0 +1,4 @@ +--- +preview: + - | + Remove "api_key" from serialization of AzureOCRDocumentConverter and SerperDevWebSearch. diff --git a/test/preview/components/file_converters/test_azure_ocr_doc_converter.py b/test/preview/components/file_converters/test_azure_ocr_doc_converter.py index f0707f0912..2369a32f32 100644 --- a/test/preview/components/file_converters/test_azure_ocr_doc_converter.py +++ b/test/preview/components/file_converters/test_azure_ocr_doc_converter.py @@ -7,17 +7,19 @@ class TestAzureOCRDocumentConverter: + @pytest.mark.unit + def test_init_fail_wo_api_key(self, monkeypatch): + monkeypatch.delenv("AZURE_AI_API_KEY", raising=False) + with pytest.raises(ValueError, match="AzureOCRDocumentConverter expects an Azure Credential key"): + AzureOCRDocumentConverter(endpoint="test_endpoint") + @pytest.mark.unit def test_to_dict(self): component = AzureOCRDocumentConverter(endpoint="test_endpoint", api_key="test_credential_key") data = component.to_dict() assert data == { "type": "AzureOCRDocumentConverter", - "init_parameters": { - "api_key": "test_credential_key", - "endpoint": "test_endpoint", - "model_id": "prebuilt-read", - }, + "init_parameters": {"endpoint": "test_endpoint", "model_id": "prebuilt-read"}, } @pytest.mark.unit diff --git a/test/preview/components/websearch/test_serperdev.py b/test/preview/components/websearch/test_serperdev.py index 87a1738731..e94b7fd726 100644 --- a/test/preview/components/websearch/test_serperdev.py +++ b/test/preview/components/websearch/test_serperdev.py @@ -108,6 +108,12 @@ def mock_serper_dev_search_result(): class TestSerperDevSearchAPI: + @pytest.mark.unit + def test_init_fail_wo_api_key(self, monkeypatch): + monkeypatch.delenv("SERPERDEV_API_KEY", raising=False) + with pytest.raises(ValueError, match="SerperDevWebSearch expects an API key"): + SerperDevWebSearch() + @pytest.mark.unit def test_to_dict(self): component = SerperDevWebSearch( @@ -116,12 +122,7 @@ def test_to_dict(self): data = component.to_dict() assert data == { "type": "SerperDevWebSearch", - "init_parameters": { - "api_key": "test_key", - "top_k": 10, - "allowed_domains": ["test.com"], - "search_params": {"param": "test"}, - }, + "init_parameters": {"top_k": 10, "allowed_domains": ["test.com"], "search_params": {"param": "test"}}, } @pytest.mark.unit