Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions haystack/preview/components/file_converters/azure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path
from typing import List, Union, Dict, Any
from typing import List, Union, Dict, Any, Optional
import os

from haystack.preview.lazy_imports import LazyImport
from haystack.preview import component, Document, default_to_dict
Expand All @@ -22,22 +23,33 @@ class AzureOCRDocumentConverter:
to set up your resource.
"""

def __init__(self, endpoint: str, api_key: Optional[str] = None, model_id: str = "prebuilt-read"):
    """
    Create an AzureOCRDocumentConverter component.

    :param endpoint: The endpoint of your Azure resource.
    :param api_key: The key of your Azure resource. It can be
        explicitly provided or automatically read from the
        environment variable AZURE_AI_API_KEY (recommended).
    :param model_id: The model ID of the model you want to use. Please refer to [Azure documentation](https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/choose-model-feature)
        for a list of available models. Default: `"prebuilt-read"`.
    :raises ValueError: If `api_key` is not passed and the AZURE_AI_API_KEY
        environment variable is not set.
    """
    azure_import.check()

    if api_key is None:
        # Fall back to the environment variable when no key is passed explicitly.
        try:
            api_key = os.environ["AZURE_AI_API_KEY"]
        except KeyError as e:
            # Re-raise as ValueError with an actionable message, keeping the
            # original KeyError as the cause.
            raise ValueError(
                "AzureOCRDocumentConverter expects an Azure Credential key. "
                "Set the AZURE_AI_API_KEY environment variable (recommended) or pass it explicitly."
            ) from e

    self.api_key = api_key
    self.document_analysis_client = DocumentAnalysisClient(
        endpoint=endpoint, credential=AzureKeyCredential(api_key)
    )
    self.endpoint = endpoint
    self.model_id = model_id

@component.output_types(documents=List[Document], azure=List[Dict])
Expand Down Expand Up @@ -70,7 +82,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, endpoint=self.endpoint, api_key=self.api_key, model_id=self.model_id)
return default_to_dict(self, endpoint=self.endpoint, model_id=self.model_id)

@staticmethod
def _convert_azure_result_to_document(result: "AnalyzeResult", file_suffix: str) -> Document:
Expand Down
20 changes: 13 additions & 7 deletions haystack/preview/components/websearch/serper_dev.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
import logging
from typing import Dict, List, Optional, Any

Expand Down Expand Up @@ -26,20 +27,29 @@ class SerperDevWebSearch:

def __init__(
self,
api_key: str,
api_key: Optional[str] = None,
top_k: Optional[int] = 10,
allowed_domains: Optional[List[str]] = None,
search_params: Optional[Dict[str, Any]] = None,
):
"""
:param api_key: API key for the SerperDev API.
:param api_key: API key for the SerperDev API. It can be
explicitly provided or automatically read from the
environment variable SERPERDEV_API_KEY (recommended).
:param top_k: Number of documents to return.
:param allowed_domains: List of domains to limit the search to.
:param search_params: Additional parameters passed to the SerperDev API.
For example, you can set 'num' to 20 to increase the number of search results.
See the [Serper Dev website](https://serper.dev/) for more details.
"""
if api_key is None:
try:
api_key = os.environ["SERPERDEV_API_KEY"]
except KeyError as e:
raise ValueError(
"SerperDevWebSearch expects an API key. "
"Set the SERPERDEV_API_KEY environment variable (recommended) or pass it explicitly."
) from e
raise ValueError("API key for SerperDev API must be set.")
self.api_key = api_key
self.top_k = top_k
Expand All @@ -51,11 +61,7 @@ def to_dict(self) -> Dict[str, Any]:
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
api_key=self.api_key,
top_k=self.top_k,
allowed_domains=self.allowed_domains,
search_params=self.search_params,
self, top_k=self.top_k, allowed_domains=self.allowed_domains, search_params=self.search_params
)

@component.output_types(documents=List[Document], links=List[str])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
preview:
- |
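Remove "api_key" from the serialization of AzureOCRDocumentConverter and SerperDevWebSearch.
The "api_key" init parameter of both components is now optional: when it is not passed,
the key is read from the AZURE_AI_API_KEY or SERPERDEV_API_KEY environment variable, respectively.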
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@


class TestAzureOCRDocumentConverter:
@pytest.mark.unit
def test_init_fail_wo_api_key(self, monkeypatch):
    """Building the converter with no explicit key and AZURE_AI_API_KEY unset raises ValueError."""
    # raising=False: do not fail if the variable is already absent from the environment.
    monkeypatch.delenv("AZURE_AI_API_KEY", raising=False)
    expected_message = "AzureOCRDocumentConverter expects an Azure Credential key"
    with pytest.raises(ValueError, match=expected_message):
        AzureOCRDocumentConverter(endpoint="test_endpoint")

@pytest.mark.unit
def test_to_dict(self):
    """Serialization keeps `endpoint` and `model_id` and must not include the API key."""
    component = AzureOCRDocumentConverter(endpoint="test_endpoint", api_key="test_credential_key")
    data = component.to_dict()
    # Only one "init_parameters" entry: a duplicate key in a dict literal is
    # silently overridden by the last one, which would hide a stale expectation.
    assert data == {
        "type": "AzureOCRDocumentConverter",
        "init_parameters": {"endpoint": "test_endpoint", "model_id": "prebuilt-read"},
    }

@pytest.mark.unit
Expand Down
13 changes: 7 additions & 6 deletions test/preview/components/websearch/test_serperdev.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def mock_serper_dev_search_result():


class TestSerperDevSearchAPI:
@pytest.mark.unit
def test_init_fail_wo_api_key(self, monkeypatch):
    """Building the component with no explicit key and SERPERDEV_API_KEY unset raises ValueError."""
    # raising=False: do not fail if the variable is already absent from the environment.
    monkeypatch.delenv("SERPERDEV_API_KEY", raising=False)
    expected_message = "SerperDevWebSearch expects an API key"
    with pytest.raises(ValueError, match=expected_message):
        SerperDevWebSearch()

@pytest.mark.unit
def test_to_dict(self):
component = SerperDevWebSearch(
Expand All @@ -116,12 +122,7 @@ def test_to_dict(self):
data = component.to_dict()
assert data == {
"type": "SerperDevWebSearch",
"init_parameters": {
"api_key": "test_key",
"top_k": 10,
"allowed_domains": ["test.com"],
"search_params": {"param": "test"},
},
"init_parameters": {"top_k": 10, "allowed_domains": ["test.com"], "search_params": {"param": "test"}},
}

@pytest.mark.unit
Expand Down