22 changes: 17 additions & 5 deletions libs/azure-storage/langchain_azure_storage/document_loaders.py
@@ -17,7 +17,7 @@
 import azure.core.credentials_async
 import azure.identity
 import azure.identity.aio
-from azure.storage.blob import BlobClient, ContainerClient
+from azure.storage.blob import BlobClient, BlobProperties, ContainerClient
 from azure.storage.blob.aio import BlobClient as AsyncBlobClient
 from azure.storage.blob.aio import ContainerClient as AsyncContainerClient
 from langchain_core.document_loaders import BaseLoader
@@ -256,7 +256,11 @@ def _yield_blob_names(self, container_client: ContainerClient) -> Iterator[str]:
         if self._blob_names is not None:
             yield from self._blob_names
         else:
-            yield from container_client.list_blob_names(name_starts_with=self._prefix)
+            for blob in container_client.list_blobs(
+                name_starts_with=self._prefix, include="metadata"
+            ):
+                if not self._is_adls_directory(blob):
+                    yield blob.name
 
     async def _ayield_blob_names(
         self, async_container_client: AsyncContainerClient
@@ -265,14 +269,22 @@ async def _ayield_blob_names
             for blob_name in self._blob_names:
                 yield blob_name
         else:
-            async for blob_name in async_container_client.list_blob_names(
-                name_starts_with=self._prefix
+            async for blob in async_container_client.list_blobs(
+                name_starts_with=self._prefix, include="metadata"
             ):
-                yield blob_name
+                if not self._is_adls_directory(blob):
+                    yield blob.name
 
     def _get_default_document(
         self, blob_content: bytes, blob_client: Union[BlobClient, AsyncBlobClient]
     ) -> Document:
         return Document(
             blob_content.decode("utf-8"), metadata={"source": blob_client.url}
         )
+
+    def _is_adls_directory(self, blob: BlobProperties) -> bool:
+        return (
+            blob.size == 0
+            and blob.metadata is not None
+            and blob.metadata.get("hdi_isfolder") == "true"
+        )
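Context for the `_is_adls_directory` check above: on storage accounts with a hierarchical namespace (ADLS Gen2), directories surface through the Blob API as zero-byte blobs whose metadata carries `hdi_isfolder: "true"`. `list_blob_names()` returns only names, which is why the loader now calls `list_blobs(include="metadata")` and filters the results. Below is a minimal standalone sketch of the same filtering, assuming a reachable container; the account URL and container name are placeholders, not values from this PR:

```python
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobProperties, ContainerClient


def is_adls_directory(blob: BlobProperties) -> bool:
    # ADLS Gen2 directory markers are zero-byte blobs tagged hdi_isfolder=true.
    return (
        blob.size == 0
        and blob.metadata is not None
        and blob.metadata.get("hdi_isfolder") == "true"
    )


# Placeholder account and container, for illustration only.
container = ContainerClient(
    account_url="https://myaccount.blob.core.windows.net",
    container_name="my-container",
    credential=DefaultAzureCredential(),
)

# include="metadata" is required: without it, blob.metadata is None and a
# directory marker would be indistinguishable from an ordinary empty blob.
for blob in container.list_blobs(include="metadata"):
    if not is_adls_directory(blob):
        print(blob.name)
```

Checking `size == 0` in addition to the metadata flag is a cheap extra guard against a regular blob that happens to carry an `hdi_isfolder` metadata key.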
@@ -9,6 +9,7 @@
 from langchain_azure_storage.document_loaders import AzureBlobStorageLoader
 from tests.utils import (
     CustomCSVLoader,
+    get_datalake_test_blobs,
     get_expected_documents,
     get_first_column_csv_loader,
     get_test_blobs,
@@ -63,8 +64,8 @@ def upload_blobs_to_container(
         ([], None),
         (None, None),
         (None, "text"),
-        ("text_file.txt", None),
-        (["text_file.txt", "json_file.json", "csv_file.csv"], None),
+        ("directory/test_file.txt", None),
+        (["directory/test_file.txt", "json_file.json", "csv_file.csv"], None),
     ],
 )
 def test_lazy_load(
@@ -109,8 +110,8 @@ def test_lazy_load_with_loader_factory_configurations(
         ([], None),
         (None, None),
         (None, "text"),
-        ("text_file.txt", None),
-        (["text_file.txt", "json_file.json", "csv_file.csv"], None),
+        ("directory/test_file.txt", None),
+        (["directory/test_file.txt", "json_file.json", "csv_file.csv"], None),
     ],
 )
 async def test_alazy_load(
@@ -149,3 +150,75 @@ async def test_alazy_load_with_loader_factory_configurations(
     assert [
         doc async for doc in loader.alazy_load()
     ] == expected_custom_csv_documents_with_columns
+
+
+class TestDataLakeDirectoryFiltering:
+    @pytest.fixture(scope="class")
+    def datalake_account_url(self) -> str:
+        datalake_account_url = os.getenv("AZURE_DATALAKE_ACCOUNT_URL")
+        if datalake_account_url is None:
+            raise ValueError(
+                "AZURE_DATALAKE_ACCOUNT_URL environment variable must be set for "
+                "this test."
+            )
+        return datalake_account_url
+
+    @pytest.fixture(scope="class")
+    def datalake_container_name(self) -> str:
+        return "document-loader-tests"
+
+    @pytest.fixture(scope="class")
+    def datalake_blob_service_client(
+        self, datalake_account_url: str
+    ) -> BlobServiceClient:
+        return BlobServiceClient(
+            account_url=datalake_account_url, credential=DefaultAzureCredential()
+        )
+
+    @pytest.fixture(scope="class")
+    def datalake_container_setup(
+        self,
+        datalake_blob_service_client: BlobServiceClient,
+    ) -> Iterator[None]:
+        container_client = datalake_blob_service_client.get_container_client(
+            "document-loader-tests"
+        )
+        container_client.create_container()
+        for blob in get_datalake_test_blobs(include_directories=True):
> Collaborator (inline review comment on the line above): For the datalake tests, should we be including the directories in what we upload? To my understanding, ADLS will automatically create the directory and that is more representative of how customers will have ADLS directories in the first place.
+            blob_client = container_client.get_blob_client(blob["blob_name"])
+            blob_client.upload_blob(
+                blob["blob_content"], metadata=blob["metadata"], overwrite=True
+            )
+
+        yield
+        container_client.delete_container()
+
+    def test_datalake_excludes_directories(
+        self,
+        container_name: str,
+        datalake_account_url: str,
+        datalake_container_setup: Iterator[None],
+    ) -> None:
+        loader = AzureBlobStorageLoader(
+            account_url=datalake_account_url,
+            container_name=container_name,
+        )
+        expected_documents_list = get_expected_documents(
+            get_datalake_test_blobs(), datalake_account_url, container_name
+        )
+        assert list(loader.lazy_load()) == expected_documents_list
+
+    async def test_async_datalake_excludes_directories(
+        self,
+        container_name: str,
+        datalake_account_url: str,
+        datalake_container_setup: Iterator[None],
+    ) -> None:
+        loader = AzureBlobStorageLoader(
+            account_url=datalake_account_url,
+            container_name=container_name,
+        )
+        expected_documents_list = get_expected_documents(
+            get_datalake_test_blobs(), datalake_account_url, container_name
+        )
+        assert [doc async for doc in loader.alazy_load()] == expected_documents_list
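On the inline review comment above: on a hierarchical-namespace account, writing a blob to a nested path is generally enough to materialize the parent directory, so the fixture may not need to upload directory placeholder blobs via `include_directories=True` at all. A hedged sketch of that alternative setup, using only the Blob API; the account URL and names are placeholders, and the behavior assumes the target account has hierarchical namespace enabled:

```python
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient

# Placeholder endpoint; AZURE_DATALAKE_ACCOUNT_URL plays this role in the tests.
service = BlobServiceClient(
    account_url="https://myhnsaccount.blob.core.windows.net",
    credential=DefaultAzureCredential(),
)
container = service.get_container_client("document-loader-tests")

# On a hierarchical-namespace account, uploading to a nested path should be
# enough: the "directory/" parent is materialized as a real directory entry
# (seen through the Blob API as a zero-byte marker with hdi_isfolder=true),
# with no explicit placeholder upload.
blob = container.get_blob_client("directory/test_file.txt")
blob.upload_blob(b"hello from ADLS", overwrite=True)

# list_blobs(include="metadata") should now surface both the directory marker
# and the file; the loader's _is_adls_directory check filters out the former.
for props in container.list_blobs(include="metadata"):
    print(props.name, props.size, (props.metadata or {}).get("hdi_isfolder"))
```

This setup would exercise the same filtering path while mirroring how ADLS directories typically come to exist in customer accounts.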