-
Notifications
You must be signed in to change notification settings - Fork 41
Document loaders datalake support #180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,8 @@ | |
| from langchain_azure_storage.document_loaders import AzureBlobStorageLoader | ||
| from tests.utils import ( | ||
| CustomCSVLoader, | ||
| get_datalake_test_blobs, | ||
| get_expected_datalake_blobs, | ||
| get_expected_documents, | ||
| get_first_column_csv_loader, | ||
| get_test_blobs, | ||
|
|
@@ -62,14 +64,61 @@ def mock_container_client( | |
| "langchain_azure_storage.document_loaders.ContainerClient" | ||
| ) as mock_container_client_cls: | ||
| mock_client = MagicMock(spec=ContainerClient) | ||
| mock_client.list_blob_names.return_value = [ | ||
| blob["blob_name"] for blob in get_test_blobs() | ||
| ] | ||
| mock_blobs = [] | ||
| for blob in get_test_blobs(): | ||
| mock_blob = MagicMock() | ||
| mock_blob.name = blob["blob_name"] | ||
| mock_blobs.append(mock_blob) | ||
|
|
||
| mock_client.list_blobs.return_value = mock_blobs | ||
kyleknap marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| mock_client.get_blob_client.side_effect = get_mock_blob_client | ||
| mock_container_client_cls.return_value = mock_client | ||
| yield mock_container_client_cls, mock_client | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def get_mock_datalake_blob_client( | ||
| account_url: str, container_name: str | ||
| ) -> Callable[[str], MagicMock]: | ||
| def _get_blob_client(blob_name: str) -> MagicMock: | ||
| mock_blob_client = MagicMock(spec=BlobClient) | ||
| mock_blob_client.url = f"{account_url}/{container_name}/{blob_name}" | ||
| mock_blob_client.blob_name = blob_name | ||
| mock_blob_data = MagicMock(spec=StorageStreamDownloader) | ||
| content = next( | ||
| blob["blob_content"] | ||
| for blob in get_datalake_test_blobs() | ||
| if blob["blob_name"] == blob_name | ||
| ) | ||
| mock_blob_data.readall.return_value = content.encode("utf-8") | ||
| mock_blob_client.download_blob.return_value = mock_blob_data | ||
| return mock_blob_client | ||
|
|
||
| return _get_blob_client | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def mock_datalake_container_client( | ||
| get_mock_datalake_blob_client: Callable[[str], MagicMock], | ||
| ) -> Iterator[Tuple[MagicMock, MagicMock]]: | ||
| with patch( | ||
| "langchain_azure_storage.document_loaders.ContainerClient" | ||
| ) as mock_container_client_cls: | ||
| mock_client = MagicMock(spec=ContainerClient) | ||
| mock_blobs = [] | ||
| for blob in get_datalake_test_blobs(): | ||
| mock_blob = MagicMock() | ||
| mock_blob.name = blob["blob_name"] | ||
| mock_blob.size = blob["size"] | ||
| mock_blob.metadata = blob["metadata"] | ||
| mock_blobs.append(mock_blob) | ||
|
|
||
| mock_client.list_blobs.return_value = mock_blobs | ||
| mock_client.get_blob_client.side_effect = get_mock_datalake_blob_client | ||
| mock_container_client_cls.return_value = mock_client | ||
| yield mock_container_client_cls, mock_client | ||
|
|
||
|
|
||
| @pytest.fixture() | ||
| def get_async_mock_blob_client( | ||
| account_url: str, container_name: str | ||
|
|
@@ -99,20 +148,67 @@ def async_mock_container_client( | |
| "langchain_azure_storage.document_loaders.AsyncContainerClient" | ||
| ) as async_mock_container_client_cls: | ||
|
|
||
| async def get_async_blob_names(**kwargs: Any) -> AsyncIterator[str]: | ||
| async def get_async_blobs(**kwargs: Any) -> AsyncIterator[MagicMock]: | ||
| prefix = kwargs.get("name_starts_with") | ||
| for blob_name in [ | ||
| blob["blob_name"] for blob in get_test_blobs(prefix=prefix) | ||
| ]: | ||
| yield blob_name | ||
| for blob in get_test_blobs(prefix=prefix): | ||
| mock_blob = MagicMock() | ||
|
||
| mock_blob.name = blob["blob_name"] | ||
| yield mock_blob | ||
|
|
||
| async_mock_client = AsyncMock(spec=AsyncContainerClient) | ||
| async_mock_client.list_blob_names.side_effect = get_async_blob_names | ||
| async_mock_client.list_blobs.side_effect = get_async_blobs | ||
| async_mock_client.get_blob_client.side_effect = get_async_mock_blob_client | ||
| async_mock_container_client_cls.return_value = async_mock_client | ||
| yield async_mock_container_client_cls, async_mock_client | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def get_async_mock_datalake_blob_client( | ||
| account_url: str, container_name: str | ||
| ) -> Callable[[str], AsyncMock]: | ||
| def _get_async_blob_client(blob_name: str) -> AsyncMock: | ||
| async_mock_blob_client = AsyncMock(spec=AsyncBlobClient) | ||
| async_mock_blob_client.url = f"{account_url}/{container_name}/{blob_name}" | ||
| async_mock_blob_client.blob_name = blob_name | ||
| mock_blob_data = AsyncMock(spec=AsyncStorageStreamDownloader) | ||
| content = next( | ||
| blob["blob_content"] | ||
| for blob in get_datalake_test_blobs() | ||
| if blob["blob_name"] == blob_name | ||
| ) | ||
| mock_blob_data.readall.return_value = content.encode("utf-8") | ||
| async_mock_blob_client.download_blob.return_value = mock_blob_data | ||
| return async_mock_blob_client | ||
|
|
||
| return _get_async_blob_client | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def async_mock_datalake_container_client( | ||
| get_async_mock_datalake_blob_client: Callable[[str], AsyncMock], | ||
| ) -> Iterator[Tuple[AsyncMock, AsyncMock]]: | ||
| with patch( | ||
| "langchain_azure_storage.document_loaders.AsyncContainerClient" | ||
| ) as async_mock_container_client_cls: | ||
|
|
||
| async def get_async_blobs(**kwargs: Any) -> AsyncIterator[MagicMock]: | ||
| prefix = kwargs.get("name_starts_with") | ||
| for blob in get_datalake_test_blobs(prefix=prefix): | ||
| mock_blob = MagicMock() | ||
| mock_blob.name = blob["blob_name"] | ||
| mock_blob.size = int(blob["size"]) | ||
| mock_blob.metadata = blob["metadata"] | ||
| yield mock_blob | ||
|
||
|
|
||
| async_mock_client = AsyncMock(spec=AsyncContainerClient) | ||
| async_mock_client.list_blobs.side_effect = get_async_blobs | ||
| async_mock_client.get_blob_client.side_effect = ( | ||
| get_async_mock_datalake_blob_client | ||
| ) | ||
| async_mock_container_client_cls.return_value = async_mock_client | ||
| yield async_mock_container_client_cls, async_mock_client | ||
|
|
||
|
|
||
| def test_lazy_load( | ||
| account_url: str, | ||
| container_name: str, | ||
|
|
@@ -145,20 +241,25 @@ def test_lazy_load_with_blob_names( | |
| get_test_blobs(blob_names), account_url, container_name | ||
| ) | ||
| assert list(loader.lazy_load()) == expected_documents_list | ||
| assert mock_client.list_blob_names.call_count == 0 | ||
| assert mock_client.list_blobs.call_count == 0 | ||
|
|
||
|
|
||
| def test_get_blob_client( | ||
| create_azure_blob_storage_loader: Callable[..., AzureBlobStorageLoader], | ||
| mock_container_client: Tuple[MagicMock, MagicMock], | ||
| ) -> None: | ||
| _, mock_client = mock_container_client | ||
| mock_client.list_blob_names.return_value = ["text_file.txt"] | ||
| mock_blob = MagicMock() | ||
| mock_blob.name = "text_file.txt" | ||
| mock_blob.size = 12 | ||
| mock_client.list_blobs.return_value = [mock_blob] | ||
|
|
||
| loader = create_azure_blob_storage_loader(prefix="text") | ||
| list(loader.lazy_load()) | ||
| mock_client.get_blob_client.assert_called_once_with("text_file.txt") | ||
| mock_client.list_blob_names.assert_called_once_with(name_starts_with="text") | ||
| mock_client.list_blobs.assert_called_once_with( | ||
| name_starts_with="text", include="metadata" | ||
| ) | ||
|
|
||
|
|
||
| def test_default_credential( | ||
|
|
@@ -271,7 +372,7 @@ async def test_alazy_load_with_blob_names( | |
| get_test_blobs(blob_names), account_url, container_name | ||
| ) | ||
| assert [doc async for doc in loader.alazy_load()] == expected_documents_list | ||
| assert async_mock_client.list_blob_names.call_count == 0 | ||
| assert async_mock_client.list_blobs.call_count == 0 | ||
|
|
||
|
|
||
| async def test_get_async_blob_client( | ||
|
|
@@ -282,7 +383,9 @@ async def test_get_async_blob_client( | |
| loader = create_azure_blob_storage_loader(prefix="text") | ||
| [doc async for doc in loader.alazy_load()] | ||
| async_mock_client.get_blob_client.assert_called_once_with("text_file.txt") | ||
| async_mock_client.list_blob_names.assert_called_once_with(name_starts_with="text") | ||
| async_mock_client.list_blobs.assert_called_once_with( | ||
| name_starts_with="text", include="metadata" | ||
| ) | ||
|
|
||
|
|
||
| async def test_async_token_credential( | ||
|
|
@@ -367,3 +470,29 @@ async def test_async_user_agent( | |
| [doc async for doc in loader.alazy_load()] | ||
| client_kwargs = async_mock_container_client_cls.call_args[1] | ||
| assert client_kwargs["user_agent"] == user_agent | ||
|
|
||
|
|
||
| def test_datalake_excludes_directories( | ||
| account_url: str, | ||
| container_name: str, | ||
| create_azure_blob_storage_loader: Callable[..., AzureBlobStorageLoader], | ||
| mock_datalake_container_client: Tuple[MagicMock, MagicMock], | ||
| ) -> None: | ||
| loader = create_azure_blob_storage_loader() | ||
| expected_documents = get_expected_documents( | ||
| get_expected_datalake_blobs(), account_url, container_name | ||
| ) | ||
| assert list(loader.lazy_load()) == expected_documents | ||
|
|
||
|
|
||
| async def test_async_datalake_excludes_directories( | ||
| account_url: str, | ||
| container_name: str, | ||
| create_azure_blob_storage_loader: Callable[..., AzureBlobStorageLoader], | ||
| async_mock_datalake_container_client: Tuple[AsyncMock, AsyncMock], | ||
| ) -> None: | ||
| loader = create_azure_blob_storage_loader() | ||
| expected_documents = get_expected_documents( | ||
| get_expected_datalake_blobs(), account_url, container_name | ||
| ) | ||
| assert [doc async for doc in loader.alazy_load()] == expected_documents | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In between the
blob.sizeand thehdi_isfoldercheck, let's add ablob.metadata. So:Mainly blobs could be empty and not be directories and have
metadatathat equates toNone. This additional clause will make sure we short-circuit early to avoid trying to call.get()on aNonevalue and cause the entire listing to error out.