From 6544a3cb0f9185ad1f82b1563a515b39588eb871 Mon Sep 17 00:00:00 2001 From: Puneet Saraswat Date: Wed, 20 Sep 2023 10:09:41 -0500 Subject: [PATCH 1/4] add azure --- querent/collectors/azure/__init__.py | 0 querent/collectors/azure/azure_collector.py | 66 +++++++++++++++++++++ 2 files changed, 66 insertions(+) create mode 100644 querent/collectors/azure/__init__.py create mode 100644 querent/collectors/azure/azure_collector.py diff --git a/querent/collectors/azure/__init__.py b/querent/collectors/azure/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/querent/collectors/azure/azure_collector.py b/querent/collectors/azure/azure_collector.py new file mode 100644 index 00000000..3c5f210b --- /dev/null +++ b/querent/collectors/azure/azure_collector.py @@ -0,0 +1,66 @@ +import asyncio +from typing import AsyncGenerator +import io + +from azure.storage.blob import BlobServiceClient +from querent.config.collector_config import CollectorBackend, AzureCollectConfig +from querent.collectors.collector_base import Collector +from querent.collectors.collector_factory import CollectorFactory +from querent.collectors.collector_result import CollectorResult +from querent.common.uri import Uri + + +class AzureCollector(Collector): + def __init__(self, config: AzureCollectConfig, container_name: str, prefix: str): + self.account_url = config["account_url"] + self.blob_service_client = BlobServiceClient( + account_url=self.account_url, credential=config["credential"] + ) + self.container_name = container_name + self.chunk_size = 1024 + self.prefix = prefix + + async def connect(self): + pass # No asynchronous connection needed for the Azure Blob Storage client + + async def disconnect(self): + pass # No asynchronous disconnect needed for the Azure Blob Storage client + + async def poll(self) -> AsyncGenerator[CollectorResult, None]: + container_client = self.blob_service_client.get_container_client( + self.container_name + ) + + async for blob in 
container_client.list_blobs(name_starts_with=self.prefix): + file = self.download_blob_as_byte_stream(container_client, blob.name) + async for chunk in self.read_chunks(file): + yield CollectorResult({"object_key": blob.name, "chunk": chunk}) + + async def read_chunks(self, file): + while True: + chunk = file.read(self.chunk_size) + if not chunk: + break + yield chunk + + def download_blob_as_byte_stream(self, container_client, blob_name): + blob_client = container_client.get_blob_client(blob_name) + blob_properties = blob_client.get_blob_properties() + byte_stream = io.BytesIO() + + if blob_properties["size"] > 0: + stream = blob_client.download_blob() + byte_stream.write(stream.readall()) + byte_stream.seek(0) # Rewind the stream to the beginning + + return byte_stream + + +class AzureCollectorFactory(CollectorFactory): + def backend(self) -> CollectorBackend: + return CollectorBackend.AzureBlobStorage + + def resolve(self, uri: Uri, config: AzureCollectConfig) -> Collector: + container_name = uri.path.strip("/") + prefix = uri.query.get("prefix", "") + return AzureCollector(config, container_name, prefix) From c64701a6f508c770f5f4c7182389ba6960eae304 Mon Sep 17 00:00:00 2001 From: Puneet Saraswat Date: Wed, 20 Sep 2023 10:15:16 -0500 Subject: [PATCH 2/4] add dependency --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index df9b1fc9..95763836 100644 --- a/requirements.txt +++ b/requirements.txt @@ -170,3 +170,4 @@ tika openpyxl coverage pytest-cov +azure-storage-blob From 42b19098dabdf6e52bbd77b35ac154e87340088a Mon Sep 17 00:00:00 2001 From: Puneet Saraswat Date: Wed, 20 Sep 2023 10:17:22 -0500 Subject: [PATCH 3/4] add config --- querent/collectors/azure/azure_collector.py | 4 ++-- querent/config/collector_config.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/querent/collectors/azure/azure_collector.py b/querent/collectors/azure/azure_collector.py index 3c5f210b..2335cf69 100644 
--- a/querent/collectors/azure/azure_collector.py +++ b/querent/collectors/azure/azure_collector.py @@ -12,9 +12,9 @@ class AzureCollector(Collector): def __init__(self, config: AzureCollectConfig, container_name: str, prefix: str): - self.account_url = config["account_url"] + self.account_url = config.account_url self.blob_service_client = BlobServiceClient( - account_url=self.account_url, credential=config["credential"] + account_url=self.account_url, credential=config.credential ) self.container_name = container_name self.chunk_size = 1024 diff --git a/querent/config/collector_config.py b/querent/config/collector_config.py index d53d6116..d69c6762 100644 --- a/querent/config/collector_config.py +++ b/querent/config/collector_config.py @@ -21,6 +21,10 @@ class FSCollectorConfig(BaseModel): root_path: str chunk_size: int = 1024 +class AzureCollectConfig(BaseModel): + account_url: str + credential: str + chunk_size: int = 1024 class S3CollectConfig(BaseModel): bucket: str From 0d7150ba5a24a2afb1a29fe16975098d1ed7e9bf Mon Sep 17 00:00:00 2001 From: Puneet Saraswat Date: Wed, 20 Sep 2023 13:19:16 -0500 Subject: [PATCH 4/4] add tests and update protocol/factories --- .github/workflows/pytest.yml | 3 +- querent/collectors/azure/azure_collector.py | 55 ++++++++++++------ querent/collectors/collector_resolver.py | 6 +- querent/common/uri.py | 3 + querent/config/collector_config.py | 10 +++- tests/test_azure_collector.py | 62 +++++++++++++++++++++ 6 files changed, 118 insertions(+), 21 deletions(-) create mode 100644 tests/test_azure_collector.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 2dfeaa9b..092d1f28 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -17,7 +17,8 @@ env: GOOGLE_APPLICATION_CREDENTIALS: ${{ secrets.GOOGLE_APPLICATION_CREDENTIALS }} AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - + AZURE_STORAGE_ACCOUNT_URL: ${{ 
secrets.AZURE_STORAGE_ACCOUNT_URL }} + AZURE_STORAGE_ACCOUNT_KEY: ${{ secrets.AZURE_STORAGE_ACCOUNT_KEY }} jobs: pytest: runs-on: ubuntu-latest diff --git a/querent/collectors/azure/azure_collector.py b/querent/collectors/azure/azure_collector.py index 2335cf69..31a330ef 100644 --- a/querent/collectors/azure/azure_collector.py +++ b/querent/collectors/azure/azure_collector.py @@ -1,4 +1,3 @@ -import asyncio from typing import AsyncGenerator import io @@ -11,30 +10,52 @@ class AzureCollector(Collector): - def __init__(self, config: AzureCollectConfig, container_name: str, prefix: str): + def __init__(self, config: AzureCollectConfig): + self.connection_string = config.connection_string self.account_url = config.account_url - self.blob_service_client = BlobServiceClient( - account_url=self.account_url, credential=config.credential - ) - self.container_name = container_name + self.credentials = config.credentials + self.container_name = config.container self.chunk_size = 1024 - self.prefix = prefix + self.prefix = config.prefix + self.blob_service_client = None + self.container_client = None async def connect(self): - pass # No asynchronous connection needed for the Azure Blob Storage client + if self.connection_string: + self.blob_service_client = BlobServiceClient.from_connection_string( + conn_str=self.connection_string, + credential=self.credentials, + ) + elif self.account_url: + self.blob_service_client = BlobServiceClient( + account_url=self.account_url, + credential=self.credentials, + ) + self.container_client = self.blob_service_client.get_container_client( + self.container_name + ) async def disconnect(self): pass # No asynchronous disconnect needed for the Azure Blob Storage client async def poll(self) -> AsyncGenerator[CollectorResult, None]: - container_client = self.blob_service_client.get_container_client( - self.container_name - ) + try: + if not self.container_client: + await self.connect() - async for blob in 
container_client.list_blobs(name_starts_with=self.prefix): - file = self.download_blob_as_byte_stream(container_client, blob.name) - async for chunk in self.read_chunks(file): - yield CollectorResult({"object_key": blob.name, "chunk": chunk}) + blob_list = self.container_client.list_blobs(name_starts_with=self.prefix) + for blob in blob_list: + file = self.download_blob_as_byte_stream( + self.container_client, blob.name + ) + async for chunk in self.read_chunks(file): + yield CollectorResult({"object_key": blob.name, "chunk": chunk}) + except Exception as e: + # Handle exceptions gracefully, e.g., log the error + print(f"An error occurred: {e}") + finally: + # Disconnect the client when done + await self.disconnect() async def read_chunks(self, file): while True: @@ -61,6 +82,4 @@ def backend(self) -> CollectorBackend: return CollectorBackend.AzureBlobStorage def resolve(self, uri: Uri, config: AzureCollectConfig) -> Collector: - container_name = uri.path.strip("/") - prefix = uri.query.get("prefix", "") - return AzureCollector(config, container_name, prefix) + return AzureCollector(config) diff --git a/querent/collectors/collector_resolver.py b/querent/collectors/collector_resolver.py index 98fc28cd..d8e83e9d 100644 --- a/querent/collectors/collector_resolver.py +++ b/querent/collectors/collector_resolver.py @@ -1,4 +1,5 @@ from typing import Optional +from querent.collectors.azure.azure_collector import AzureCollectorFactory from querent.collectors.gcs.gcs_collector import GCSCollectorFactory from querent.collectors.aws.aws_collector import AWSCollectorFactory from querent.collectors.fs.fs_collector import FSCollectorFactory @@ -18,7 +19,8 @@ def __init__(self): CollectorBackend.LocalFile: FSCollectorFactory(), CollectorBackend.S3: AWSCollectorFactory(), CollectorBackend.WebScraper: WebScraperFactory(), - CollectorBackend.Gcs: GCSCollectorFactory() + CollectorBackend.Gcs: GCSCollectorFactory(), + CollectorBackend.AzureBlobStorage: AzureCollectorFactory(), # Add 
other collector factories as needed } @@ -44,6 +46,8 @@ def _determine_backend(self, protocol: Protocol) -> CollectorBackend: return CollectorBackend.Gcs elif protocol.is_webscraper(): return CollectorBackend.WebScraper + elif protocol.is_azure_blob_storage(): + return CollectorBackend.AzureBlobStorage else: raise CollectorResolverError( CollectorErrorKind.NotSupported, "Unknown backend" diff --git a/querent/common/uri.py b/querent/common/uri.py index 08b39ef7..ffecf8e0 100644 --- a/querent/common/uri.py +++ b/querent/common/uri.py @@ -43,6 +43,9 @@ def is_database(self) -> bool: def is_webscraper(self) -> bool: return self == Protocol.Webscraper + def is_azure_blob_storage(self) -> bool: + return self == Protocol.Azure + class Uri: PROTOCOL_SEPARATOR = "://" diff --git a/querent/config/collector_config.py b/querent/config/collector_config.py index d69c6762..68fb9a4d 100644 --- a/querent/config/collector_config.py +++ b/querent/config/collector_config.py @@ -8,6 +8,7 @@ class CollectorBackend(str, Enum): WebScraper = "webscraper" S3 = "s3" Gcs = "gs" + AzureBlobStorage = "azure" class CollectConfig(BaseModel): @@ -22,8 +23,11 @@ class FSCollectorConfig(BaseModel): chunk_size: int = 1024 class AzureCollectConfig(BaseModel): + connection_string: str account_url: str - credential: str + credentials: str + container: str + prefix: str chunk_size: int = 1024 class S3CollectConfig(BaseModel): @@ -64,6 +68,10 @@ def from_collect_config(cls, collect_config: CollectConfig): return cls( backend=CollectorBackend.WebScraper, config=WebScraperConfig() ) + elif collect_config.backend == CollectorBackend.Azure: + return cls( + backend=CollectorBackend.Azure, config=AzureCollectConfig() + ) else: raise ValueError( f"Unsupported collector backend: {collect_config.backend}") diff --git a/tests/test_azure_collector.py b/tests/test_azure_collector.py new file mode 100644 index 00000000..8d75848c --- /dev/null +++ b/tests/test_azure_collector.py @@ -0,0 +1,62 @@ +import asyncio + +from 
querent.config.collector_config import AzureCollectConfig +from querent.collectors.collector_resolver import CollectorResolver +from querent.collectors.azure.azure_collector import AzureCollectorFactory +from querent.common.uri import Uri +from querent.config.collector_config import CollectorBackend +import pytest +import os +from dotenv import load_dotenv + +load_dotenv() + +azure_connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") +azure_account_url = os.getenv("AZURE_STORAGE_ACCOUNT_URL") +azure_account_key = os.getenv("AZURE_STORAGE_ACCOUNT_KEY") + + +@pytest.fixture +def azure_config(): + config = AzureCollectConfig( + connection_string="", + account_url=azure_account_url, + credentials=azure_account_key, + chunk=1024, + container="testfiles", + prefix="", + ) + return config + + +def test_azure_collector_factory(): + factory = AzureCollectorFactory() + assert factory.backend() == CollectorBackend.AzureBlobStorage + + +# Modify this function to test the Azure collector + + +@pytest.mark.asyncio +async def test_azure_collector(azure_config): + config = azure_config + uri = Uri("azure://" + config.container) + resolver = CollectorResolver() + collector = resolver.resolve(uri, config) + assert collector is not None + + async def poll_and_print(): + counter = 0 + async for result in collector.poll(): + assert not result.is_error() + chunk = result.unwrap() + assert chunk is not None + if chunk: + counter += 1 + assert counter == 1433 + + await poll_and_print() + + +if __name__ == "__main__": + asyncio.run(test_azure_collector())