18 changes: 13 additions & 5 deletions libs/azure-storage/langchain_azure_storage/document_loaders.py
@@ -17,7 +17,7 @@
import azure.core.credentials_async
import azure.identity
import azure.identity.aio
from azure.storage.blob import BlobClient, ContainerClient
from azure.storage.blob import BlobClient, BlobProperties, ContainerClient
from azure.storage.blob.aio import BlobClient as AsyncBlobClient
from azure.storage.blob.aio import ContainerClient as AsyncContainerClient
from langchain_core.document_loaders import BaseLoader
@@ -256,7 +256,11 @@ def _yield_blob_names(self, container_client: ContainerClient) -> Iterator[str]:
if self._blob_names is not None:
yield from self._blob_names
else:
yield from container_client.list_blob_names(name_starts_with=self._prefix)
for blob in container_client.list_blobs(
name_starts_with=self._prefix, include="metadata"
):
if not self._is_adls_directory(blob):
yield blob.name

async def _ayield_blob_names(
self, async_container_client: AsyncContainerClient
@@ -265,14 +269,18 @@ async def _ayield_blob_names
for blob_name in self._blob_names:
yield blob_name
else:
async for blob_name in async_container_client.list_blob_names(
name_starts_with=self._prefix
async for blob in async_container_client.list_blobs(
name_starts_with=self._prefix, include="metadata"
):
yield blob_name
if not self._is_adls_directory(blob):
yield blob.name

def _get_default_document(
self, blob_content: bytes, blob_client: Union[BlobClient, AsyncBlobClient]
) -> Document:
return Document(
blob_content.decode("utf-8"), metadata={"source": blob_client.url}
)

def _is_adls_directory(self, blob: BlobProperties) -> bool:
return blob.size == 0 and blob.metadata.get("hdi_isfolder") == "true"
Collaborator:
In between the blob.size and the hdi_isfolder check, let's add a blob.metadata check. So:

blob.size == 0 and blob.metadata and blob.metadata.get("hdi_isfolder") == "true"

Blobs can be empty without being directories, and their metadata can equate to None. The additional clause makes sure we short-circuit early instead of calling .get() on a None value and erroring out the entire listing.
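A minimal sketch of the guarded check, assuming the method keeps its current name. Note that blob.metadata is only populated because the listing passes include="metadata"; blobs that simply have no metadata still come back with metadata set to None, which is exactly the case the guard covers:

def _is_adls_directory(self, blob: BlobProperties) -> bool:
    # ADLS Gen2 represents a directory as a zero-byte blob whose metadata
    # carries {"hdi_isfolder": "true"}. Checking blob.metadata before
    # calling .get() short-circuits when metadata is None, so an empty
    # non-directory blob cannot crash the listing. bool() keeps the return
    # type honest when the chain ends on None or an empty dict.
    return bool(
        blob.size == 0
        and blob.metadata
        and blob.metadata.get("hdi_isfolder") == "true"
    )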

@@ -9,6 +9,8 @@
from langchain_azure_storage.document_loaders import AzureBlobStorageLoader
from tests.utils import (
CustomCSVLoader,
get_datalake_test_blobs,
get_expected_datalake_blobs,
get_expected_documents,
get_first_column_csv_loader,
get_test_blobs,
@@ -149,3 +151,75 @@ async def test_alazy_load_with_loader_factory_configurations(
assert [
doc async for doc in loader.alazy_load()
] == expected_custom_csv_documents_with_columns


class TestDataLakeDirectoryFiltering:
@pytest.fixture(scope="class")
def datalake_account_url(self) -> str:
datalake_account_url = os.getenv("AZURE_DATALAKE_ACCOUNT_URL")
if datalake_account_url is None:
raise ValueError(
"AZURE_DATALAKE_ACCOUNT_URL environment variable must be set for "
"this test."
)
return datalake_account_url

@pytest.fixture(scope="class")
def datalake_container_name(self) -> str:
return "document-loader-tests"

@pytest.fixture(scope="class")
def datalake_blob_service_client(
self, datalake_account_url: str
) -> BlobServiceClient:
return BlobServiceClient(
account_url=datalake_account_url, credential=DefaultAzureCredential()
)

@pytest.fixture(scope="class")
def datalake_container_setup(
self,
datalake_blob_service_client: BlobServiceClient,
) -> Iterator[None]:
container_client = datalake_blob_service_client.get_container_client(
"document-loader-tests"
)
container_client.create_container()
for blob in get_datalake_test_blobs():
blob_client = container_client.get_blob_client(blob["blob_name"])
blob_client.upload_blob(
blob["blob_content"], metadata=blob["metadata"], overwrite=True
)

yield
container_client.delete_container()

def test_datalake_excludes_directories(
self,
datalake_container_name: str,
datalake_account_url: str,
datalake_container_setup: Iterator[None],
) -> None:
loader = AzureBlobStorageLoader(
account_url=datalake_account_url,
container_name=datalake_container_name,
)
expected_documents_list = get_expected_documents(
get_expected_datalake_blobs(), datalake_account_url, datalake_container_name
)
assert list(loader.lazy_load()) == expected_documents_list

async def test_async_datalake_excludes_directories(
self,
datalake_container_name: str,
datalake_account_url: str,
datalake_container_setup: Iterator[None],
) -> None:
loader = AzureBlobStorageLoader(
account_url=datalake_account_url,
container_name=datalake_container_name,
)
expected_documents_list = get_expected_documents(
get_expected_datalake_blobs(), datalake_account_url, datalake_container_name
)
assert [doc async for doc in loader.alazy_load()] == expected_documents_list
157 changes: 143 additions & 14 deletions libs/azure-storage/tests/unit_tests/test_document_loaders.py
@@ -17,6 +17,8 @@
from langchain_azure_storage.document_loaders import AzureBlobStorageLoader
from tests.utils import (
CustomCSVLoader,
get_datalake_test_blobs,
get_expected_datalake_blobs,
get_expected_documents,
get_first_column_csv_loader,
get_test_blobs,
@@ -62,14 +64,61 @@ def mock_container_client(
"langchain_azure_storage.document_loaders.ContainerClient"
) as mock_container_client_cls:
mock_client = MagicMock(spec=ContainerClient)
mock_client.list_blob_names.return_value = [
blob["blob_name"] for blob in get_test_blobs()
]
mock_blobs = []
for blob in get_test_blobs():
mock_blob = MagicMock()
mock_blob.name = blob["blob_name"]
mock_blobs.append(mock_blob)

mock_client.list_blobs.return_value = mock_blobs
mock_client.get_blob_client.side_effect = get_mock_blob_client
mock_container_client_cls.return_value = mock_client
yield mock_container_client_cls, mock_client


@pytest.fixture
def get_mock_datalake_blob_client(
account_url: str, container_name: str
) -> Callable[[str], MagicMock]:
def _get_blob_client(blob_name: str) -> MagicMock:
mock_blob_client = MagicMock(spec=BlobClient)
mock_blob_client.url = f"{account_url}/{container_name}/{blob_name}"
mock_blob_client.blob_name = blob_name
mock_blob_data = MagicMock(spec=StorageStreamDownloader)
content = next(
blob["blob_content"]
for blob in get_datalake_test_blobs()
if blob["blob_name"] == blob_name
)
mock_blob_data.readall.return_value = content.encode("utf-8")
mock_blob_client.download_blob.return_value = mock_blob_data
return mock_blob_client

return _get_blob_client


@pytest.fixture
def mock_datalake_container_client(
get_mock_datalake_blob_client: Callable[[str], MagicMock],
) -> Iterator[Tuple[MagicMock, MagicMock]]:
with patch(
"langchain_azure_storage.document_loaders.ContainerClient"
) as mock_container_client_cls:
mock_client = MagicMock(spec=ContainerClient)
mock_blobs = []
for blob in get_datalake_test_blobs():
mock_blob = MagicMock()
mock_blob.name = blob["blob_name"]
mock_blob.size = blob["size"]
mock_blob.metadata = blob["metadata"]
mock_blobs.append(mock_blob)

mock_client.list_blobs.return_value = mock_blobs
mock_client.get_blob_client.side_effect = get_mock_datalake_blob_client
mock_container_client_cls.return_value = mock_client
yield mock_container_client_cls, mock_client


@pytest.fixture()
def get_async_mock_blob_client(
account_url: str, container_name: str
@@ -99,20 +148,67 @@ def async_mock_container_client(
"langchain_azure_storage.document_loaders.AsyncContainerClient"
) as async_mock_container_client_cls:

async def get_async_blob_names(**kwargs: Any) -> AsyncIterator[str]:
async def get_async_blobs(**kwargs: Any) -> AsyncIterator[MagicMock]:
prefix = kwargs.get("name_starts_with")
for blob_name in [
blob["blob_name"] for blob in get_test_blobs(prefix=prefix)
]:
yield blob_name
for blob in get_test_blobs(prefix=prefix):
mock_blob = MagicMock()
Collaborator:

When mocking the blob properties, we should be passing in a BlobProperties spec to get a little more safety on public interfaces (see the sketch after this fixture).

mock_blob.name = blob["blob_name"]
yield mock_blob

async_mock_client = AsyncMock(spec=AsyncContainerClient)
async_mock_client.list_blob_names.side_effect = get_async_blob_names
async_mock_client.list_blobs.side_effect = get_async_blobs
async_mock_client.get_blob_client.side_effect = get_async_mock_blob_client
async_mock_container_client_cls.return_value = async_mock_client
yield async_mock_container_client_cls, async_mock_client

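A minimal sketch of the spec'd mock that comment asks for; the make_mock_blob helper name is hypothetical, not part of the PR:

from typing import Dict, Optional
from unittest.mock import MagicMock

from azure.storage.blob import BlobProperties


def make_mock_blob(
    name: str, size: int = 0, metadata: Optional[Dict[str, str]] = None
) -> MagicMock:
    # With spec=BlobProperties, reading an attribute that BlobProperties
    # does not define (and that was never explicitly set here) raises
    # AttributeError instead of silently returning a child mock, so typos
    # against the public interface fail loudly in tests.
    mock_blob = MagicMock(spec=BlobProperties)
    mock_blob.name = name
    mock_blob.size = size
    mock_blob.metadata = metadata
    return mock_blob

The fixtures above could then replace each hand-built MagicMock() with make_mock_blob(blob["blob_name"], ...).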

@pytest.fixture
def get_async_mock_datalake_blob_client(
account_url: str, container_name: str
) -> Callable[[str], AsyncMock]:
def _get_async_blob_client(blob_name: str) -> AsyncMock:
async_mock_blob_client = AsyncMock(spec=AsyncBlobClient)
async_mock_blob_client.url = f"{account_url}/{container_name}/{blob_name}"
async_mock_blob_client.blob_name = blob_name
mock_blob_data = AsyncMock(spec=AsyncStorageStreamDownloader)
content = next(
blob["blob_content"]
for blob in get_datalake_test_blobs()
if blob["blob_name"] == blob_name
)
mock_blob_data.readall.return_value = content.encode("utf-8")
async_mock_blob_client.download_blob.return_value = mock_blob_data
return async_mock_blob_client

return _get_async_blob_client


@pytest.fixture
def async_mock_datalake_container_client(
get_async_mock_datalake_blob_client: Callable[[str], AsyncMock],
) -> Iterator[Tuple[AsyncMock, AsyncMock]]:
with patch(
"langchain_azure_storage.document_loaders.AsyncContainerClient"
) as async_mock_container_client_cls:

async def get_async_blobs(**kwargs: Any) -> AsyncIterator[MagicMock]:
prefix = kwargs.get("name_starts_with")
for blob in get_datalake_test_blobs(prefix=prefix):
mock_blob = MagicMock()
mock_blob.name = blob["blob_name"]
mock_blob.size = int(blob["size"])
mock_blob.metadata = blob["metadata"]
yield mock_blob
Collaborator:

I'm wondering if it makes sense to take the BlobProperties scaffolding logic and roll it up into the get_test_blobs() and get_datalake_test_blobs() utilities. To get mocks back, we could either expose an as_mock boolean or add new utilities such as get_test_mock_blobs(). I don't have a strong preference on either option, but it should help consolidate the mock-blob setup logic in all of the places we build these up, which has expanded quite a bit now. (A sketch follows this fixture.)


async_mock_client = AsyncMock(spec=AsyncContainerClient)
async_mock_client.list_blobs.side_effect = get_async_blobs
async_mock_client.get_blob_client.side_effect = (
get_async_mock_datalake_blob_client
)
async_mock_container_client_cls.return_value = async_mock_client
yield async_mock_container_client_cls, async_mock_client

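A minimal sketch of that consolidation, assuming it lives in tests/utils.py; get_test_mock_blobs is one of the two names floated above, and the dict keys match what the fixtures already read:

from typing import Any, Dict, List
from unittest.mock import MagicMock

from azure.storage.blob import BlobProperties


def get_test_mock_blobs(blobs: List[Dict[str, Any]]) -> List[MagicMock]:
    # Builds the BlobProperties-spec'd mocks in one place so the sync and
    # async container-client fixtures stop hand-rolling them.
    mock_blobs = []
    for blob in blobs:
        mock_blob = MagicMock(spec=BlobProperties)
        mock_blob.name = blob["blob_name"]
        mock_blob.size = int(blob.get("size", 0))
        mock_blob.metadata = blob.get("metadata")
        mock_blobs.append(mock_blob)
    return mock_blobs

The sync fixtures could then set mock_client.list_blobs.return_value = get_test_mock_blobs(get_datalake_test_blobs()), and the async generators could iterate the same list.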

def test_lazy_load(
account_url: str,
container_name: str,
@@ -145,20 +241,25 @@ def test_lazy_load_with_blob_names(
get_test_blobs(blob_names), account_url, container_name
)
assert list(loader.lazy_load()) == expected_documents_list
assert mock_client.list_blob_names.call_count == 0
assert mock_client.list_blobs.call_count == 0


def test_get_blob_client(
create_azure_blob_storage_loader: Callable[..., AzureBlobStorageLoader],
mock_container_client: Tuple[MagicMock, MagicMock],
) -> None:
_, mock_client = mock_container_client
mock_client.list_blob_names.return_value = ["text_file.txt"]
mock_blob = MagicMock()
mock_blob.name = "text_file.txt"
mock_blob.size = 12
mock_client.list_blobs.return_value = [mock_blob]

loader = create_azure_blob_storage_loader(prefix="text")
list(loader.lazy_load())
mock_client.get_blob_client.assert_called_once_with("text_file.txt")
mock_client.list_blob_names.assert_called_once_with(name_starts_with="text")
mock_client.list_blobs.assert_called_once_with(
name_starts_with="text", include="metadata"
)


def test_default_credential(
@@ -271,7 +372,7 @@ async def test_alazy_load_with_blob_names(
get_test_blobs(blob_names), account_url, container_name
)
assert [doc async for doc in loader.alazy_load()] == expected_documents_list
assert async_mock_client.list_blob_names.call_count == 0
assert async_mock_client.list_blobs.call_count == 0


async def test_get_async_blob_client(
@@ -282,7 +383,9 @@ async def test_get_async_blob_client(
loader = create_azure_blob_storage_loader(prefix="text")
[doc async for doc in loader.alazy_load()]
async_mock_client.get_blob_client.assert_called_once_with("text_file.txt")
async_mock_client.list_blob_names.assert_called_once_with(name_starts_with="text")
async_mock_client.list_blobs.assert_called_once_with(
name_starts_with="text", include="metadata"
)


async def test_async_token_credential(
@@ -367,3 +470,29 @@ async def test_async_user_agent(
[doc async for doc in loader.alazy_load()]
client_kwargs = async_mock_container_client_cls.call_args[1]
assert client_kwargs["user_agent"] == user_agent


def test_datalake_excludes_directories(
account_url: str,
container_name: str,
create_azure_blob_storage_loader: Callable[..., AzureBlobStorageLoader],
mock_datalake_container_client: Tuple[MagicMock, MagicMock],
) -> None:
loader = create_azure_blob_storage_loader()
expected_documents = get_expected_documents(
get_expected_datalake_blobs(), account_url, container_name
)
assert list(loader.lazy_load()) == expected_documents


async def test_async_datalake_excludes_directories(
account_url: str,
container_name: str,
create_azure_blob_storage_loader: Callable[..., AzureBlobStorageLoader],
async_mock_datalake_container_client: Tuple[AsyncMock, AsyncMock],
) -> None:
loader = create_azure_blob_storage_loader()
expected_documents = get_expected_documents(
get_expected_datalake_blobs(), account_url, container_name
)
assert [doc async for doc in loader.alazy_load()] == expected_documents