-
Notifications
You must be signed in to change notification settings - Fork 822
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: per-process ingest connections #1058
Changes from 31 commits
34137ca
ffdc081
97caff9
17420fc
5edd0a0
05e824d
69ad5d3
d74e74f
71e5201
1fbe77c
27e6c00
d57eebc
e767e65
0be0182
48c17dc
058a85d
352acb2
fd2f27b
c4074d3
d11d4ea
626a659
4071a35
7391ea1
cdb50e2
dbebe08
aa7de05
7e6164d
3d7f7e1
b38433a
cf3de88
9b097bb
6513520
1f64a06
3ffeaae
29b97f5
18433e3
cc7149b
171f4a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from dataclasses import dataclass | ||
|
||
import pytest | ||
|
||
from unstructured.ingest.doc_processor.generalized import ( | ||
process_document, | ||
session_handle_var, | ||
) | ||
from unstructured.ingest.interfaces import BaseIngestDoc, IngestDocSessionHandleMixin | ||
|
||
|
||
@dataclass | ||
class IngestDocWithSessionHandle(IngestDocSessionHandleMixin, BaseIngestDoc): | ||
pass | ||
|
||
@pytest.fixture(autouse=True) | ||
def _reset_session_handle(): | ||
session_handle_var.set(None) | ||
|
||
def test_process_document_with_session_handle(mocker): | ||
"""Test that the process_document function calls the doc_processor_fn with the correct | ||
arguments, assigns the session handle, and returns the correct results.""" | ||
mock_session_handle = mocker.MagicMock() | ||
session_handle_var.set(mock_session_handle) | ||
mock_doc = mocker.MagicMock(spec=(IngestDocWithSessionHandle)) | ||
|
||
result = process_document(mock_doc) | ||
|
||
mock_doc.get_file.assert_called_once_with() | ||
mock_doc.write_result.assert_called_with() | ||
mock_doc.cleanup_file.assert_called_once_with() | ||
assert result == mock_doc.process_file.return_value | ||
assert mock_doc.session_handle == mock_session_handle | ||
|
||
|
||
def test_process_document_no_session_handle(mocker): | ||
"""Test that the process_document function calls does not assign session handle the IngestDoc | ||
does not have the session handle mixin.""" | ||
session_handle_var.set(mocker.MagicMock()) | ||
mock_doc = mocker.MagicMock(spec=(BaseIngestDoc)) | ||
|
||
process_document(mock_doc) | ||
|
||
assert not hasattr(mock_doc, "session_handle") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import pytest | ||
|
||
from unstructured.ingest.doc_processor.generalized import session_handle_var | ||
from unstructured.ingest.processor import Processor | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def _reset_session_handle(): | ||
session_handle_var.set(None) | ||
|
||
@pytest.mark.parametrize("test_verbose", [True, False]) | ||
def test_processor_init_with_session_handle(mocker, test_verbose): | ||
"""Test that the init function calls to ingest_log_streaming_init and assigns the session handle | ||
when the a function is passed in.""" | ||
mock_ingest_log_streaming_init = mocker.patch( | ||
"unstructured.ingest.processor.ingest_log_streaming_init", | ||
) | ||
mock_create_session_handle_fn = mocker.MagicMock() | ||
Processor.process_init(test_verbose, mock_create_session_handle_fn) | ||
mock_ingest_log_streaming_init.assert_called_once_with(test_verbose) | ||
mock_create_session_handle_fn.assert_called_once_with() | ||
assert ( | ||
session_handle_var.get() == mock_create_session_handle_fn.return_value | ||
) | ||
|
||
def test_processor_init_no_session_handle(mocker): | ||
"""Test that the init function calls to ingest_log_streaming_init and does not assign the session handle | ||
when the a function is not passed in.""" | ||
mock_ingest_log_streaming_init = mocker.patch( | ||
"unstructured.ingest.processor.ingest_log_streaming_init", | ||
) | ||
Processor.process_init(True) | ||
mock_ingest_log_streaming_init.assert_called_once_with(True) | ||
assert session_handle_var.get() is None |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "0.9.3-dev0" # pragma: no cover | ||
__version__ = "0.9.3-dev1" # pragma: no cover |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,25 +4,36 @@ | |
from dataclasses import dataclass | ||
from mimetypes import guess_extension | ||
from pathlib import Path | ||
from typing import Dict, Optional | ||
from typing import TYPE_CHECKING, Dict, Optional, cast | ||
|
||
from unstructured.file_utils.filetype import EXT_TO_FILETYPE | ||
from unstructured.file_utils.google_filetype import GOOGLE_DRIVE_EXPORT_TYPES | ||
from unstructured.ingest.interfaces import ( | ||
BaseConnector, | ||
BaseConnectorConfig, | ||
BaseIngestDoc, | ||
BaseSessionHandle, | ||
ConnectorCleanupMixin, | ||
ConnectorSessionHandleMixin, | ||
IngestDocCleanupMixin, | ||
IngestDocSessionHandleMixin, | ||
StandardConnectorConfig, | ||
) | ||
from unstructured.ingest.logger import logger | ||
from unstructured.utils import requires_dependencies | ||
|
||
if TYPE_CHECKING: | ||
from googleapiclient.discovery import Resource as GoogleAPIResource | ||
|
||
FILE_FORMAT = "{id}-{name}{ext}" | ||
DIRECTORY_FORMAT = "{id}-{name}" | ||
|
||
|
||
@dataclass | ||
class GoogleDriveSessionHandle(BaseSessionHandle): | ||
service: "GoogleAPIResource" | ||
|
||
|
||
@requires_dependencies(["googleapiclient"], extras="google-drive") | ||
def create_service_account_object(key_path, id=None): | ||
""" | ||
|
@@ -81,13 +92,13 @@ def __post_init__(self): | |
f"Extension not supported. " | ||
f"Value MUST be one of {', '.join([k for k in EXT_TO_FILETYPE if k is not None])}.", | ||
) | ||
self.service = create_service_account_object(self.service_account_key, self.drive_id) | ||
|
||
|
||
@dataclass | ||
class GoogleDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc): | ||
class GoogleDriveIngestDoc(IngestDocSessionHandleMixin, IngestDocCleanupMixin, BaseIngestDoc): | ||
config: SimpleGoogleDriveConfig | ||
file_meta: Dict | ||
session_handle: Optional[GoogleDriveSessionHandle] = None | ||
|
||
@property | ||
def filename(self): | ||
|
@@ -103,7 +114,9 @@ def get_file(self): | |
from googleapiclient.errors import HttpError | ||
from googleapiclient.http import MediaIoBaseDownload | ||
|
||
self.config.service = create_service_account_object(self.config.service_account_key) | ||
if self.session_handle is None: | ||
raise ValueError("Google Drive session handle was not set.") | ||
self.session_handle = cast(GoogleDriveSessionHandle, self.session_handle) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this is already set above: session_handle: Optional[GoogleDriveSessionHandle] = None I don't think the cast is necessary. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. iirc linting gets mad if you don't explicitly cast since it was optional |
||
|
||
if self.file_meta.get("mimeType", "").startswith("application/vnd.google-apps"): | ||
export_mime = GOOGLE_DRIVE_EXPORT_TYPES.get( | ||
|
@@ -117,12 +130,12 @@ def get_file(self): | |
) | ||
return | ||
|
||
request = self.config.service.files().export_media( | ||
request = self.session_handle.service.files().export_media( | ||
fileId=self.file_meta.get("id"), | ||
mimeType=export_mime, | ||
) | ||
else: | ||
request = self.config.service.files().get_media(fileId=self.file_meta.get("id")) | ||
request = self.session_handle.service.files().get_media(fileId=self.file_meta.get("id")) | ||
file = io.BytesIO() | ||
downloader = MediaIoBaseDownload(file, request) | ||
downloaded = False | ||
|
@@ -160,22 +173,32 @@ def write_result(self): | |
logger.info(f"Wrote {self._output_filename}") | ||
|
||
|
||
class GoogleDriveConnector(ConnectorCleanupMixin, BaseConnector): | ||
class GoogleDriveConnector(ConnectorSessionHandleMixin, ConnectorCleanupMixin, BaseConnector): | ||
"""Objects of this class support fetching documents from Google Drive""" | ||
|
||
config: SimpleGoogleDriveConfig | ||
|
||
def __init__(self, standard_config: StandardConnectorConfig, config: SimpleGoogleDriveConfig): | ||
super().__init__(standard_config, config) | ||
|
||
@classmethod | ||
def create_session_handle( | ||
cls, | ||
config: BaseConnectorConfig, | ||
) -> GoogleDriveSessionHandle: | ||
config = cast(SimpleGoogleDriveConfig, config) | ||
service = create_service_account_object(config.service_account_key) | ||
return GoogleDriveSessionHandle(service=service) | ||
|
||
def _list_objects(self, drive_id, recursive=False): | ||
files = [] | ||
service = create_service_account_object(self.config.service_account_key) | ||
|
||
def traverse(drive_id, download_dir, output_dir, recursive=False): | ||
page_token = None | ||
while True: | ||
response = ( | ||
self.config.service.files() | ||
service.files() | ||
.list( | ||
spaces="drive", | ||
fields="nextPageToken, files(id, name, mimeType)", | ||
|
@@ -244,6 +267,4 @@ def initialize(self): | |
|
||
def get_ingest_docs(self): | ||
files = self._list_objects(self.config.drive_id, self.config.recursive) | ||
# Setting to None because service object can't be pickled for multiprocessing. | ||
self.config.service = None | ||
return [GoogleDriveIngestDoc(self.standard_config, self.config, file) for file in files] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,10 +2,16 @@ | |
import multiprocessing as mp | ||
from contextlib import suppress | ||
from functools import partial | ||
from typing import cast | ||
|
||
from unstructured.ingest.doc_processor.generalized import initialize, process_document | ||
from unstructured.ingest.doc_processor.generalized import ( | ||
initialize, | ||
process_document, | ||
session_handle_var, | ||
) | ||
from unstructured.ingest.interfaces import ( | ||
BaseConnector, | ||
ConnectorSessionHandleMixin, | ||
ProcessorConfigs, | ||
) | ||
from unstructured.ingest.logger import ingest_log_streaming_init, logger | ||
|
@@ -41,6 +47,13 @@ def initialize(self): | |
def cleanup(self): | ||
self.doc_connector.cleanup() | ||
|
||
@classmethod | ||
def process_init(cls, verbose, create_session_handle_fn=None): | ||
ingest_log_streaming_init(verbose) | ||
# set the session handle for the doc processor if the connector supports it | ||
if create_session_handle_fn is not None: | ||
session_handle_var.set(create_session_handle_fn()) | ||
|
||
def _filter_docs_with_outputs(self, docs): | ||
num_docs_all = len(docs) | ||
docs = [doc for doc in docs if not doc.has_output()] | ||
|
@@ -74,15 +87,28 @@ def run(self): | |
if not docs: | ||
return | ||
|
||
# get a create_session_handle function if the connector supports it | ||
create_session_handle_fn = ( | ||
partial( | ||
cast(ConnectorSessionHandleMixin, self.doc_connector).create_session_handle, | ||
cast(BaseConnector, self.doc_connector).config, | ||
) | ||
if isinstance(self.doc_connector, ConnectorSessionHandleMixin) | ||
else None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i don't think this handle should actually exist in the parent process. but i think i see the issue that the init process needs connection info, at the very least, following this approach. i'm starting to think the cleanest way to do this is for the subprocess itself to create the SessionHandle lazily for the first IngestDoc it processes, since it will have the connector config at that time. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. or, it could pass the connector config of the first IngestDoc in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. coming back to this fresh, I think I was over-engineering in case we ever needed to handle subprocess work async, but starting to feel like it's safe to just assume this is always serial within that subprocess. I like the idea of just lazily creating it, will head with that. also, I'm now thinking the config should just own the definition for how to create the session handle. this would also be cleaner if we move that logic inside the subprocess since we don't pass the connector itself through (I don't believe). |
||
) | ||
|
||
# Debugging tip: use the below line and comment out the mp.Pool loop | ||
# block to remain in single process | ||
# self.doc_processor_fn(docs[0]) | ||
logger.info(f"Processing {len(docs)} docs") | ||
try: | ||
with mp.Pool( | ||
processes=self.num_processes, | ||
initializer=ingest_log_streaming_init, | ||
initargs=(logging.DEBUG if self.verbose else logging.INFO,), | ||
initializer=self.process_init, | ||
initargs=( | ||
logging.DEBUG if self.verbose else logging.INFO, | ||
create_session_handle_fn, | ||
), | ||
) as pool: | ||
pool.map(self.doc_processor_fn, docs) | ||
finally: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed this. it wasn't critical to the intent of this test, but also needing to freeze time here in combination with other tests touching generalized (and by extension calling
get_model
) was triggering a bizarre failure with importing transformers.models.open_llama.tokenization_open_llama? More info here.