diff --git a/libs/azure-storage/langchain_azure_storage/document_loaders.py b/libs/azure-storage/langchain_azure_storage/document_loaders.py index 7691cd97..777dd3b9 100644 --- a/libs/azure-storage/langchain_azure_storage/document_loaders.py +++ b/libs/azure-storage/langchain_azure_storage/document_loaders.py @@ -17,7 +17,7 @@ import azure.core.credentials_async import azure.identity import azure.identity.aio -from azure.storage.blob import BlobClient, ContainerClient +from azure.storage.blob import BlobClient, BlobProperties, ContainerClient from azure.storage.blob.aio import BlobClient as AsyncBlobClient from azure.storage.blob.aio import ContainerClient as AsyncContainerClient from langchain_core.document_loaders import BaseLoader @@ -256,7 +256,11 @@ def _yield_blob_names(self, container_client: ContainerClient) -> Iterator[str]: if self._blob_names is not None: yield from self._blob_names else: - yield from container_client.list_blob_names(name_starts_with=self._prefix) + for blob in container_client.list_blobs( + name_starts_with=self._prefix, include="metadata" + ): + if not self._is_adls_directory(blob): + yield blob.name async def _ayield_blob_names( self, async_container_client: AsyncContainerClient @@ -265,10 +269,11 @@ async def _ayield_blob_names( for blob_name in self._blob_names: yield blob_name else: - async for blob_name in async_container_client.list_blob_names( - name_starts_with=self._prefix + async for blob in async_container_client.list_blobs( + name_starts_with=self._prefix, include="metadata" ): - yield blob_name + if not self._is_adls_directory(blob): + yield blob.name def _get_default_document( self, blob_content: bytes, blob_client: Union[BlobClient, AsyncBlobClient] @@ -276,3 +281,10 @@ def _get_default_document( return Document( blob_content.decode("utf-8"), metadata={"source": blob_client.url} ) + + def _is_adls_directory(self, blob: BlobProperties) -> bool: + return ( + blob.size == 0 + and blob.metadata is not None + and 
blob.metadata.get("hdi_isfolder") == "true" + ) diff --git a/libs/azure-storage/tests/integration_tests/test_document_loaders.py b/libs/azure-storage/tests/integration_tests/test_document_loaders.py index c82098e8..48ed0d2f 100644 --- a/libs/azure-storage/tests/integration_tests/test_document_loaders.py +++ b/libs/azure-storage/tests/integration_tests/test_document_loaders.py @@ -9,6 +9,7 @@ from langchain_azure_storage.document_loaders import AzureBlobStorageLoader from tests.utils import ( CustomCSVLoader, + get_datalake_test_blobs, get_expected_documents, get_first_column_csv_loader, get_test_blobs, @@ -63,8 +64,8 @@ def upload_blobs_to_container( ([], None), (None, None), (None, "text"), - ("text_file.txt", None), - (["text_file.txt", "json_file.json", "csv_file.csv"], None), + ("directory/test_file.txt", None), + (["directory/test_file.txt", "json_file.json", "csv_file.csv"], None), ], ) def test_lazy_load( @@ -109,8 +110,8 @@ def test_lazy_load_with_loader_factory_configurations( ([], None), (None, None), (None, "text"), - ("text_file.txt", None), - (["text_file.txt", "json_file.json", "csv_file.csv"], None), + ("directory/test_file.txt", None), + (["directory/test_file.txt", "json_file.json", "csv_file.csv"], None), ], ) async def test_alazy_load( @@ -149,3 +150,73 @@ async def test_alazy_load_with_loader_factory_configurations( assert [ doc async for doc in loader.alazy_load() ] == expected_custom_csv_documents_with_columns + + +class TestDataLakeDirectoryFiltering: + @pytest.fixture(scope="class") + def datalake_account_url(self) -> str: + datalake_account_url = os.getenv("AZURE_DATALAKE_ACCOUNT_URL") + if datalake_account_url is None: + raise ValueError( + "AZURE_DATALAKE_ACCOUNT_URL environment variable must be set for " + "this test." 
+            )
+        return datalake_account_url
+
+    @pytest.fixture(scope="class")
+    def datalake_container_name(self) -> str:
+        return "document-loader-tests"
+
+    @pytest.fixture(scope="class")
+    def datalake_blob_service_client(
+        self, datalake_account_url: str
+    ) -> BlobServiceClient:
+        return BlobServiceClient(
+            account_url=datalake_account_url, credential=DefaultAzureCredential()
+        )
+
+    @pytest.fixture(scope="class")
+    def datalake_container_setup(
+        self,
+        datalake_blob_service_client: BlobServiceClient,
+    ) -> Iterator[None]:
+        container_client = datalake_blob_service_client.get_container_client(
+            "document-loader-tests"
+        )
+        container_client.create_container()
+        for blob in get_datalake_test_blobs():
+            blob_client = container_client.get_blob_client(blob["blob_name"])
+            blob_client.upload_blob(blob["blob_content"], overwrite=True)
+
+        yield
+        container_client.delete_container()
+
+    def test_datalake_excludes_directories(
+        self,
+        datalake_container_name: str,
+        datalake_account_url: str,
+        datalake_container_setup: Iterator[None],
+    ) -> None:
+        loader = AzureBlobStorageLoader(
+            account_url=datalake_account_url,
+            container_name=datalake_container_name,
+        )
+        expected_documents_list = get_expected_documents(
+            get_datalake_test_blobs(), datalake_account_url, datalake_container_name
+        )
+        assert list(loader.lazy_load()) == expected_documents_list
+
+    async def test_async_datalake_excludes_directories(
+        self,
+        datalake_container_name: str,
+        datalake_account_url: str,
+        datalake_container_setup: Iterator[None],
+    ) -> None:
+        loader = AzureBlobStorageLoader(
+            account_url=datalake_account_url,
+            container_name=datalake_container_name,
+        )
+        expected_documents_list = get_expected_documents(
+            get_datalake_test_blobs(), datalake_account_url, datalake_container_name
+        )
+        assert [doc async for doc in loader.alazy_load()] == expected_documents_list
diff --git a/libs/azure-storage/tests/integration_tests/test_document_loaders.py b/libs/azure-storage/tests/unit_tests/test_document_loaders.py
index b785470d..cbd50a8d 100644
--- 
a/libs/azure-storage/tests/unit_tests/test_document_loaders.py +++ b/libs/azure-storage/tests/unit_tests/test_document_loaders.py @@ -17,9 +17,11 @@ from langchain_azure_storage.document_loaders import AzureBlobStorageLoader from tests.utils import ( CustomCSVLoader, + get_datalake_test_blobs, get_expected_documents, get_first_column_csv_loader, get_test_blobs, + get_test_mock_blobs, ) @@ -62,14 +64,49 @@ def mock_container_client( "langchain_azure_storage.document_loaders.ContainerClient" ) as mock_container_client_cls: mock_client = MagicMock(spec=ContainerClient) - mock_client.list_blob_names.return_value = [ - blob["blob_name"] for blob in get_test_blobs() - ] + mock_client.list_blobs.return_value = get_test_mock_blobs(get_test_blobs()) mock_client.get_blob_client.side_effect = get_mock_blob_client mock_container_client_cls.return_value = mock_client yield mock_container_client_cls, mock_client +@pytest.fixture +def get_mock_datalake_blob_client( + account_url: str, container_name: str +) -> Callable[[str], MagicMock]: + def _get_blob_client(blob_name: str) -> MagicMock: + mock_blob_client = MagicMock(spec=BlobClient) + mock_blob_client.url = f"{account_url}/{container_name}/{blob_name}" + mock_blob_client.blob_name = blob_name + mock_blob_data = MagicMock(spec=StorageStreamDownloader) + content = next( + blob["blob_content"] + for blob in get_datalake_test_blobs(include_directories=True) + if blob["blob_name"] == blob_name + ) + mock_blob_data.readall.return_value = content.encode("utf-8") + mock_blob_client.download_blob.return_value = mock_blob_data + return mock_blob_client + + return _get_blob_client + + +@pytest.fixture +def mock_datalake_container_client( + get_mock_datalake_blob_client: Callable[[str], MagicMock], +) -> Iterator[Tuple[MagicMock, MagicMock]]: + with patch( + "langchain_azure_storage.document_loaders.ContainerClient" + ) as mock_container_client_cls: + mock_client = MagicMock(spec=ContainerClient) + mock_client.list_blobs.return_value = 
get_test_mock_blobs( + get_datalake_test_blobs(include_directories=True) + ) + mock_client.get_blob_client.side_effect = get_mock_datalake_blob_client + mock_container_client_cls.return_value = mock_client + yield mock_container_client_cls, mock_client + + @pytest.fixture() def get_async_mock_blob_client( account_url: str, container_name: str @@ -99,20 +136,63 @@ def async_mock_container_client( "langchain_azure_storage.document_loaders.AsyncContainerClient" ) as async_mock_container_client_cls: - async def get_async_blob_names(**kwargs: Any) -> AsyncIterator[str]: + async def get_async_blobs(**kwargs: Any) -> AsyncIterator[MagicMock]: prefix = kwargs.get("name_starts_with") - for blob_name in [ - blob["blob_name"] for blob in get_test_blobs(prefix=prefix) - ]: - yield blob_name + for mock_blob in get_test_mock_blobs(get_test_blobs(prefix=prefix)): + yield mock_blob async_mock_client = AsyncMock(spec=AsyncContainerClient) - async_mock_client.list_blob_names.side_effect = get_async_blob_names + async_mock_client.list_blobs.side_effect = get_async_blobs async_mock_client.get_blob_client.side_effect = get_async_mock_blob_client async_mock_container_client_cls.return_value = async_mock_client yield async_mock_container_client_cls, async_mock_client +@pytest.fixture +def get_async_mock_datalake_blob_client( + account_url: str, container_name: str +) -> Callable[[str], AsyncMock]: + def _get_async_blob_client(blob_name: str) -> AsyncMock: + async_mock_blob_client = AsyncMock(spec=AsyncBlobClient) + async_mock_blob_client.url = f"{account_url}/{container_name}/{blob_name}" + async_mock_blob_client.blob_name = blob_name + mock_blob_data = AsyncMock(spec=AsyncStorageStreamDownloader) + content = next( + blob["blob_content"] + for blob in get_datalake_test_blobs(include_directories=True) + if blob["blob_name"] == blob_name + ) + mock_blob_data.readall.return_value = content.encode("utf-8") + async_mock_blob_client.download_blob.return_value = mock_blob_data + return 
async_mock_blob_client + + return _get_async_blob_client + + +@pytest.fixture +def async_mock_datalake_container_client( + get_async_mock_datalake_blob_client: Callable[[str], AsyncMock], +) -> Iterator[Tuple[AsyncMock, AsyncMock]]: + with patch( + "langchain_azure_storage.document_loaders.AsyncContainerClient" + ) as async_mock_container_client_cls: + + async def get_async_blobs(**kwargs: Any) -> AsyncIterator[MagicMock]: + prefix = kwargs.get("name_starts_with") + for mock_blob in get_test_mock_blobs( + get_datalake_test_blobs(prefix=prefix, include_directories=True) + ): + yield mock_blob + + async_mock_client = AsyncMock(spec=AsyncContainerClient) + async_mock_client.list_blobs.side_effect = get_async_blobs + async_mock_client.get_blob_client.side_effect = ( + get_async_mock_datalake_blob_client + ) + async_mock_container_client_cls.return_value = async_mock_client + yield async_mock_container_client_cls, async_mock_client + + def test_lazy_load( account_url: str, container_name: str, @@ -128,8 +208,8 @@ def test_lazy_load( @pytest.mark.parametrize( "blob_names", [ - "text_file.txt", - ["text_file.txt", "json_file.json"], + "directory/test_file.txt", + ["directory/test_file.txt", "json_file.json"], ], ) def test_lazy_load_with_blob_names( @@ -145,7 +225,7 @@ def test_lazy_load_with_blob_names( get_test_blobs(blob_names), account_url, container_name ) assert list(loader.lazy_load()) == expected_documents_list - assert mock_client.list_blob_names.call_count == 0 + assert mock_client.list_blobs.call_count == 0 def test_get_blob_client( @@ -153,12 +233,15 @@ def test_get_blob_client( mock_container_client: Tuple[MagicMock, MagicMock], ) -> None: _, mock_client = mock_container_client - mock_client.list_blob_names.return_value = ["text_file.txt"] - - loader = create_azure_blob_storage_loader(prefix="text") + mock_client.list_blobs.return_value = get_test_mock_blobs( + get_test_blobs(blob_names=["json_file.json"]) + ) + loader = 
create_azure_blob_storage_loader(prefix="json") list(loader.lazy_load()) - mock_client.get_blob_client.assert_called_once_with("text_file.txt") - mock_client.list_blob_names.assert_called_once_with(name_starts_with="text") + mock_client.get_blob_client.assert_called_once_with("json_file.json") + mock_client.list_blobs.assert_called_once_with( + name_starts_with="json", include="metadata" + ) def test_default_credential( @@ -166,7 +249,7 @@ def test_default_credential( create_azure_blob_storage_loader: Callable[..., AzureBlobStorageLoader], ) -> None: mock_container_client_cls, _ = mock_container_client - loader = create_azure_blob_storage_loader(blob_names="text_file.txt") + loader = create_azure_blob_storage_loader(blob_names="directory/test_file.txt") list(loader.lazy_load()) cred = mock_container_client_cls.call_args[1]["credential"] assert isinstance(cred, azure.identity.DefaultAzureCredential) @@ -181,7 +264,7 @@ def test_override_credential( mock_container_client_cls, _ = mock_container_client mock_credential = AzureSasCredential("test_sas_token") loader = create_azure_blob_storage_loader( - blob_names="text_file.txt", credential=mock_credential + blob_names="directory/test_file.txt", credential=mock_credential ) list(loader.lazy_load()) assert mock_container_client_cls.call_args[1]["credential"] is mock_credential @@ -194,7 +277,7 @@ def test_async_credential_provided_to_sync( mock_credential = DefaultAzureCredential() loader = create_azure_blob_storage_loader( - blob_names="text_file.txt", credential=mock_credential + blob_names="directory/test_file.txt", credential=mock_credential ) with pytest.raises(ValueError, match="Cannot use synchronous load"): list(loader.lazy_load()) @@ -206,7 +289,7 @@ def test_invalid_credential_type( mock_credential = "account-key" with pytest.raises(TypeError, match="Invalid credential type provided."): create_azure_blob_storage_loader( - blob_names="text_file.txt", credential=mock_credential + 
blob_names="directory/test_file.txt", credential=mock_credential ) @@ -215,7 +298,7 @@ def test_both_blob_names_and_prefix_set( ) -> None: with pytest.raises(ValueError, match="Cannot specify both blob_names and prefix."): create_azure_blob_storage_loader( - blob_names=[blob["blob_name"] for blob in get_test_blobs()], prefix="text" + blob_names=[blob["blob_name"] for blob in get_test_blobs()], prefix="json" ) @@ -254,8 +337,8 @@ async def test_alazy_load( @pytest.mark.parametrize( "blob_names", [ - "text_file.txt", - ["text_file.txt", "json_file.json"], + "directory/test_file.txt", + ["directory/test_file.txt", "json_file.json"], ], ) async def test_alazy_load_with_blob_names( @@ -271,7 +354,7 @@ async def test_alazy_load_with_blob_names( get_test_blobs(blob_names), account_url, container_name ) assert [doc async for doc in loader.alazy_load()] == expected_documents_list - assert async_mock_client.list_blob_names.call_count == 0 + assert async_mock_client.list_blobs.call_count == 0 async def test_get_async_blob_client( @@ -279,10 +362,12 @@ async def test_get_async_blob_client( async_mock_container_client: Tuple[AsyncMock, AsyncMock], ) -> None: _, async_mock_client = async_mock_container_client - loader = create_azure_blob_storage_loader(prefix="text") + loader = create_azure_blob_storage_loader(prefix="json") [doc async for doc in loader.alazy_load()] - async_mock_client.get_blob_client.assert_called_once_with("text_file.txt") - async_mock_client.list_blob_names.assert_called_once_with(name_starts_with="text") + async_mock_client.get_blob_client.assert_called_once_with("json_file.json") + async_mock_client.list_blobs.assert_called_once_with( + name_starts_with="json", include="metadata" + ) async def test_async_token_credential( @@ -294,7 +379,7 @@ async def test_async_token_credential( async_mock_container_client_cls, _ = async_mock_container_client mock_credential = AsyncMock(spec=AsyncTokenCredential) loader = create_azure_blob_storage_loader( - 
blob_names="text_file.txt", credential=mock_credential + blob_names="json_file.json", credential=mock_credential ) [doc async for doc in loader.alazy_load()] assert async_mock_container_client_cls.call_args[1]["credential"] is mock_credential @@ -305,7 +390,7 @@ async def test_default_async_credential( create_azure_blob_storage_loader: Callable[..., AzureBlobStorageLoader], ) -> None: async_mock_container_client_cls, _ = async_mock_container_client - loader = create_azure_blob_storage_loader(blob_names="text_file.txt") + loader = create_azure_blob_storage_loader(blob_names="json_file.json") [doc async for doc in loader.alazy_load()] cred = async_mock_container_client_cls.call_args[1]["credential"] assert isinstance(cred, azure.identity.aio.DefaultAzureCredential) @@ -317,7 +402,7 @@ async def test_sync_credential_provided_to_async( from azure.identity import DefaultAzureCredential loader = create_azure_blob_storage_loader( - blob_names="text_file.txt", credential=DefaultAzureCredential() + blob_names="json_file.json", credential=DefaultAzureCredential() ) with pytest.raises(ValueError, match="Cannot use asynchronous load"): [doc async for doc in loader.alazy_load()] @@ -351,7 +436,7 @@ def test_user_agent( ) -> None: mock_container_client_cls, _ = mock_container_client user_agent = f"azpartner-langchain/{__version__}" - loader = create_azure_blob_storage_loader(blob_names="text_file.txt") + loader = create_azure_blob_storage_loader(blob_names="json_file.json") list(loader.lazy_load()) client_kwargs = mock_container_client_cls.call_args[1] assert client_kwargs["user_agent"] == user_agent @@ -363,7 +448,33 @@ async def test_async_user_agent( ) -> None: async_mock_container_client_cls, _ = async_mock_container_client user_agent = f"azpartner-langchain/{__version__}" - loader = create_azure_blob_storage_loader(blob_names="text_file.txt") + loader = create_azure_blob_storage_loader(blob_names="json_file.json") [doc async for doc in loader.alazy_load()] client_kwargs = 
async_mock_container_client_cls.call_args[1] assert client_kwargs["user_agent"] == user_agent + + +def test_datalake_excludes_directories( + account_url: str, + container_name: str, + create_azure_blob_storage_loader: Callable[..., AzureBlobStorageLoader], + mock_datalake_container_client: Tuple[MagicMock, MagicMock], +) -> None: + loader = create_azure_blob_storage_loader() + expected_documents = get_expected_documents( + get_datalake_test_blobs(), account_url, container_name + ) + assert list(loader.lazy_load()) == expected_documents + + +async def test_async_datalake_excludes_directories( + account_url: str, + container_name: str, + create_azure_blob_storage_loader: Callable[..., AzureBlobStorageLoader], + async_mock_datalake_container_client: Tuple[AsyncMock, AsyncMock], +) -> None: + loader = create_azure_blob_storage_loader() + expected_documents = get_expected_documents( + get_datalake_test_blobs(), account_url, container_name + ) + assert [doc async for doc in loader.alazy_load()] == expected_documents diff --git a/libs/azure-storage/tests/utils.py b/libs/azure-storage/tests/utils.py index 440daa25..769872e5 100644 --- a/libs/azure-storage/tests/utils.py +++ b/libs/azure-storage/tests/utils.py @@ -1,16 +1,33 @@ import csv -from typing import Iterable, Iterator, Optional, Union +from typing import Any, Iterable, Iterator, Optional, Union +from unittest.mock import MagicMock +from azure.storage.blob import BlobProperties from langchain_core.document_loaders import BaseLoader from langchain_core.documents.base import Document -_TEST_BLOBS = [ +_TEST_BLOBS: list[dict[str, Any]] = [ { "blob_name": "csv_file.csv", "blob_content": "col1,col2\nval1,val2\nval3,val4", }, - {"blob_name": "json_file.json", "blob_content": "{'test': 'test content'}"}, - {"blob_name": "text_file.txt", "blob_content": "test content"}, + { + "blob_name": "directory/test_file.txt", + "blob_content": "test content", + }, + { + "blob_name": "json_file.json", + "blob_content": "{'test': 'test 
content'}", + }, +] + +_TEST_DATALAKE_BLOBS: list[dict[str, Any]] = [ + *_TEST_BLOBS, + { + "blob_name": "directory", + "blob_content": "", + "metadata": {"hdi_isfolder": "true"}, + }, ] @@ -51,7 +68,7 @@ def get_first_column_csv_loader(file_path: str) -> CustomCSVLoader: def get_expected_documents( - blobs: list[dict[str, str]], account_url: str, container_name: str + blobs: list[dict[str, Any]], account_url: str, container_name: str ) -> list[Document]: expected_documents_list = [] for blob in blobs: @@ -69,18 +86,63 @@ def get_expected_documents( def get_test_blobs( blob_names: Optional[Union[str, Iterable[str]]] = None, prefix: Optional[str] = None ) -> list[dict[str, str]]: + return _get_test_blobs(_TEST_BLOBS, blob_names, prefix) + + +def get_datalake_test_blobs( + blob_names: Optional[Union[str, Iterable[str]]] = None, + prefix: Optional[str] = None, + include_directories: Optional[bool] = False, +) -> list[dict[str, Any]]: + if not include_directories: + return get_test_blobs(blob_names, prefix) + return _get_test_blobs(_TEST_DATALAKE_BLOBS, blob_names, prefix) + + +def get_test_mock_blobs(blob_list: list[dict[str, Any]]) -> list[MagicMock]: + mock_blobs = [] + for blob in blob_list: + mock_blob = MagicMock(spec=BlobProperties) + mock_blob.name = blob["blob_name"] + mock_blob.size = len(blob["blob_content"]) + mock_blob.metadata = blob.get("metadata", None) + mock_blobs.append(mock_blob) + + return mock_blobs + + +def _get_test_blobs( + blob_list: list[dict[str, Any]], + blob_names: Optional[Union[str, Iterable[str]]] = None, + prefix: Optional[str] = None, +) -> list[dict[str, Any]]: if blob_names is not None: if isinstance(blob_names, str): blob_names = [blob_names] updated_list = [] for name in blob_names: - for blob in _TEST_BLOBS: + for blob in blob_list: if blob["blob_name"] == name: - updated_list.append(blob) + updated_list.append( + { + **blob, + "size": len(blob["blob_content"]), + "metadata": blob.get("metadata", None), + } + ) break return 
updated_list - if prefix is not None: - return [blob for blob in _TEST_BLOBS if blob["blob_name"].startswith(prefix)] - - return _TEST_BLOBS + prefix_blob_list = ( + [blob for blob in blob_list if blob["blob_name"].startswith(prefix)] + if prefix + else blob_list + ) + return [ + { + **blob, + "size": len(blob["blob_content"]), + "metadata": blob.get("metadata", None), + } + for blob in prefix_blob_list + ]