diff --git a/unstructured/partition/pdf.py b/unstructured/partition/pdf.py
index 29265c327f..c449464db8 100644
--- a/unstructured/partition/pdf.py
+++ b/unstructured/partition/pdf.py
@@ -5,6 +5,7 @@
 import io
 import os
 import re
+import tempfile
 import warnings
 from pathlib import Path
 from typing import IO, TYPE_CHECKING, Any, Optional, cast
@@ -80,13 +81,14 @@
 from unstructured.partition.text import element_from_text
 from unstructured.partition.utils.config import env_config
 from unstructured.partition.utils.constants import (
+    LAYOUT_DEFAULT_CLARIFAI_MODEL,
+    OCR_AGENT_CLARIFAI,
     OCR_AGENT_PADDLE,
     SORT_MODE_BASIC,
     SORT_MODE_DONT,
     SORT_MODE_XY_CUT,
     OCRMode,
     PartitionStrategy,
-    OCR_AGENT_CLARIFAI,
 )
 from unstructured.partition.utils.sorting import coord_has_valid_points, sort_page_elements
 from unstructured.patches.pdfminer import parse_keyword
@@ -102,15 +104,12 @@
 RE_MULTISPACE_INCLUDING_NEWLINES = re.compile(pattern=r"\s+", flags=re.DOTALL)
 
 
-@requires_dependencies("unstructured_inference")
 def default_hi_res_model() -> str:
     # a light config for the hi res model; this is not defined as a constant so that no setting of
     # the default hi res model name is done on importing of this submodule; this allows (if user
     # prefers) for setting env after importing the sub module and changing the default model name
-    from unstructured_inference.models.base import DEFAULT_MODEL
-
-    return os.environ.get("UNSTRUCTURED_HI_RES_MODEL_NAME", DEFAULT_MODEL)
+    return os.environ.get("CLARIFAI_HI_RES_MODEL_NAME", LAYOUT_DEFAULT_CLARIFAI_MODEL)
 
 
 @process_metadata()
@@ -138,7 +137,6 @@ def partition_pdf(
     extract_forms: bool = False,
     form_extraction_skip_tables: bool = True,
     clarifai_ocr_model: Optional[str] = None,
-    **kwargs: Any,
 ) -> list[Element]:
     """Parses a pdf document into a list of interpreted elements.
@@ -223,7 +221,7 @@ def partition_pdf(
         starting_page_number=starting_page_number,
         extract_forms=extract_forms,
         form_extraction_skip_tables=form_extraction_skip_tables,
-        clarifai_ocr_model = clarifai_ocr_model,
+        clarifai_ocr_model=clarifai_ocr_model,
         **kwargs,
     )
 
@@ -567,6 +565,7 @@ def _partition_pdf_or_image_local(
 ) -> list[Element]:
     """Partition using package installed locally"""
     from unstructured_inference.inference.layout import (
+        DocumentLayout,
         process_data_with_model,
         process_file_with_model,
     )
@@ -576,6 +575,7 @@ def _partition_pdf_or_image_local(
         process_data_with_pdfminer,
         process_file_with_pdfminer,
     )
+    from unstructured.partition.utils.clarifai_yolox import ClarifaiYoloXModel
 
     if not is_image:
         check_pdf_hi_res_max_pages_exceeded(
@@ -598,13 +598,25 @@ def _partition_pdf_or_image_local(
     skip_analysis_dump = env_config.ANALYSIS_DUMP_OD_SKIP
 
+    if hi_res_model_name == LAYOUT_DEFAULT_CLARIFAI_MODEL:
+        layout_detection_model = ClarifaiYoloXModel()
+
     if file is None:
-        inferred_document_layout = process_file_with_model(
-            filename,
-            is_image=is_image,
-            model_name=hi_res_model_name,
-            pdf_image_dpi=pdf_image_dpi,
-        )
+        if hi_res_model_name == LAYOUT_DEFAULT_CLARIFAI_MODEL:
+            inferred_document_layout = (
+                DocumentLayout.from_image_file(filename, detection_model=layout_detection_model)
+                if is_image
+                else DocumentLayout.from_file(
+                    filename, detection_model=layout_detection_model, pdf_image_dpi=pdf_image_dpi
+                )
+            )
+        else:
+            inferred_document_layout = process_file_with_model(
+                filename,
+                is_image=is_image,
+                model_name=hi_res_model_name,
+                pdf_image_dpi=pdf_image_dpi,
+            )
 
     if hi_res_model_name.startswith("chipper"):
         # NOTE(alan): We shouldn't do OCR with chipper
@@ -654,12 +666,28 @@ def _partition_pdf_or_image_local(
             ocr_layout_dumper=ocr_layout_dumper,
         )
     else:
-        inferred_document_layout = process_data_with_model(
-            file,
-            is_image=is_image,
-            model_name=hi_res_model_name,
-            pdf_image_dpi=pdf_image_dpi,
-        )
+        if hi_res_model_name == LAYOUT_DEFAULT_CLARIFAI_MODEL:
+            with tempfile.TemporaryDirectory() as tmp_dir_path:
+                file_path = os.path.join(tmp_dir_path, "document.pdf")
+                with open(file_path, "wb") as f:
+                    f.write(file.read())
+                    f.flush()
+                inferred_document_layout = (
+                    DocumentLayout.from_image_file(file_path, detection_model=layout_detection_model)
+                    if is_image
+                    else DocumentLayout.from_file(
+                        file_path,
+                        detection_model=layout_detection_model,
+                        pdf_image_dpi=pdf_image_dpi,
+                    )
+                )
+        else:
+            inferred_document_layout = process_data_with_model(
+                file,
+                is_image=is_image,
+                model_name=hi_res_model_name,
+                pdf_image_dpi=pdf_image_dpi,
+            )
 
     if hi_res_model_name.startswith("chipper"):
         # NOTE(alan): We shouldn't do OCR with chipper
@@ -921,15 +949,17 @@ def _partition_pdf_or_image_with_ocr_from_image(
     """Extract `unstructured` elements from an image using OCR and perform partitioning."""
     from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
-
-    os.environ['OCR_AGENT'] = OCR_AGENT_CLARIFAI
+
+    os.environ["OCR_AGENT"] = OCR_AGENT_CLARIFAI
     ocr_agent = OCRAgent.get_agent(language=ocr_languages)
 
     # NOTE(christine): `pytesseract.image_to_string()` returns sorted text
     if ocr_agent.is_text_sorted():
         sort_mode = SORT_MODE_DONT
 
-    ocr_data = ocr_agent.get_layout_elements_from_image(image=image,clarifai_ocr_model=clarifai_ocr_model)
+    ocr_data = ocr_agent.get_layout_elements_from_image(
+        image=image, clarifai_ocr_model=clarifai_ocr_model
+    )
 
     metadata = ElementMetadata(
         last_modified=metadata_last_modified,
diff --git a/unstructured/partition/strategies.py b/unstructured/partition/strategies.py
index 23513b9ccb..22a07f9d5c 100644
--- a/unstructured/partition/strategies.py
+++ b/unstructured/partition/strategies.py
@@ -79,10 +79,10 @@ def determine_pdf_or_image_strategy(
             logger.warning("Falling back to partitioning with fast.")
             return PartitionStrategy.FAST
         else:
-            #clarifai is not installed and the text of the PDF is not extractable
+            # clarifai is not installed and the text of the PDF is not extractable
             raise ImportError(
-                f"clarifai is not installed. "
-                f"""Please install using `pip install clarifai`.""")
+                "clarifai is not installed. " """Please install using `pip install clarifai`."""
+            )
 
     return strategy
diff --git a/unstructured/partition/utils/clarifai_yolox.py b/unstructured/partition/utils/clarifai_yolox.py
new file mode 100644
index 0000000000..4587057873
--- /dev/null
+++ b/unstructured/partition/utils/clarifai_yolox.py
@@ -0,0 +1,64 @@
+from clarifai.client.model import Model
+from PIL import Image as PILImage
+from unstructured_inference.constants import Source
+from unstructured_inference.inference.layoutelement import LayoutElement, LayoutElements
+from unstructured_inference.models.unstructuredmodel import UnstructuredObjectDetectionModel
+
+from unstructured.partition.utils.constants import LAYOUT_DEFAULT_CLARIFAI_MODEL_URL
+
+
+class ClarifaiYoloXModel(UnstructuredObjectDetectionModel):
+    """Clarifai YoloX model for layout segmentation."""
+
+    def __init__(self):
+        self.model = Model(LAYOUT_DEFAULT_CLARIFAI_MODEL_URL)
+        self.confidence_threshold = 0.1
+
+    def predict(self, x: PILImage.Image) -> LayoutElements:
+        """Predict using Clarifai YoloX model."""
+        image_bytes = self.pil_image_to_bytes(x)
+        model_prediction = self.model.predict_by_bytes(image_bytes, input_type="image")
+        return self.parse_data(model_prediction, x)
+
+    def initialize(self):
+        pass
+
+    def parse_data(
+        self,
+        model_prediction,
+        image: PILImage.Image,
+    ) -> LayoutElements:
+        """Process model prediction output into Unstructured class. Bounding box coordinates
+        are converted to original image resolution. Layouts are filtered based on confidence
+        threshold.
+ """ + regions_data = model_prediction.outputs[0].data.regions + regions = [] + input_w, input_h = image.size + for region in regions_data: + bboxes = region.region_info.bounding_box + y1, x1, y2, x2 = bboxes.top_row, bboxes.left_col, bboxes.bottom_row, bboxes.right_col + detected_class = region.data.concepts[0].name + confidence = region.value + if confidence >= self.confidence_threshold: + region = LayoutElement.from_coords( + x1 * input_w, + y1 * input_h, + x2 * input_w, + y2 * input_h, + text=None, + type=detected_class, + prob=confidence, + source=Source.YOLOX, + ) + regions.append(region) + + regions.sort(key=lambda element: element.bbox.y1) + return LayoutElements.from_list(regions) + + def pil_image_to_bytes(self, image: PILImage) -> bytes: + from io import BytesIO + + with BytesIO() as output: + image.save(output, format="PNG") + return output.getvalue() diff --git a/unstructured/partition/utils/constants.py b/unstructured/partition/utils/constants.py index 19ac289815..30207d85f4 100644 --- a/unstructured/partition/utils/constants.py +++ b/unstructured/partition/utils/constants.py @@ -29,7 +29,9 @@ class PartitionStrategy: OCR_AGENT_TESSERACT_OLD = "tesseract" OCR_AGENT_PADDLE_OLD = "paddle" -OCR_DEFAULT_CLARIFAI_MODEL_URL = 'https://clarifai.com/clarifai/main/models/ocr-scene-english-paddleocr' +OCR_DEFAULT_CLARIFAI_MODEL_URL = ( + "https://clarifai.com/clarifai/main/models/ocr-scene-english-paddleocr" +) OCR_AGENT_TESSERACT = "unstructured.partition.utils.ocr_models.tesseract_ocr.OCRAgentTesseract" OCR_AGENT_PADDLE = "unstructured.partition.utils.ocr_models.paddle_ocr.OCRAgentPaddle" @@ -45,6 +47,11 @@ class PartitionStrategy: "unstructured.partition.utils.ocr_models.paddle_ocr", ).split(",") +LAYOUT_DEFAULT_CLARIFAI_MODEL_URL = ( + "https://clarifai.com/mogith-p-n/YOLOx-model-training/models/model-hps-yolox-detection" +) +LAYOUT_DEFAULT_CLARIFAI_MODEL = "clarifai_yolox" + UNSTRUCTURED_INCLUDE_DEBUG_METADATA = os.getenv("UNSTRUCTURED_INCLUDE_DEBUG_METADATA", False) # this field is defined by unstructured_pytesseract diff --git a/unstructured/partition/utils/ocr_models/clarifai_ocr.py b/unstructured/partition/utils/ocr_models/clarifai_ocr.py index 94ea7ac18b..7c9592281f 100644 --- a/unstructured/partition/utils/ocr_models/clarifai_ocr.py +++ b/unstructured/partition/utils/ocr_models/clarifai_ocr.py @@ -1,25 +1,22 @@ from typing import TYPE_CHECKING, List, Optional -import numpy as np -from PIL import Image as PILImage from clarifai.client.model import Model +from PIL import Image as PILImage + from unstructured.documents.elements import ElementType from unstructured.logger import logger -from unstructured.partition.utils.constants import ( - Source, - OCR_DEFAULT_CLARIFAI_MODEL_URL -) +from unstructured.partition.utils.constants import OCR_DEFAULT_CLARIFAI_MODEL_URL, Source from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent from unstructured.utils import requires_dependencies if TYPE_CHECKING: from unstructured_inference.inference.elements import TextRegion - from unstructured_inference.inference.layoutelement import LayoutElemen - + from unstructured_inference.inference.layoutelement import LayoutElement class OCRAgentClarifai(OCRAgent): """OCR service implementation for Clarifai.""" + def __init__(self, language: str = "en"): self.agent = self.load_agent(language) @@ -34,24 +31,30 @@ def get_text_from_image(self, image: PILImage, ocr_languages: str = "eng") -> st return "\n\n".join([r.text for r in ocr_regions]) def get_layout_from_image( - self, image: 
PILImage, clarifai_model_url: Optional[str] =None, ocr_languages: str = "eng", - ) -> List["TextRegion"]: + self, + image: PILImage, + clarifai_model_url: Optional[str] = None, + ocr_languages: str = "eng", + ) -> List[TextRegion]: """Get the OCR regions from image as a list of text regions with paddle.""" - import base64 + logger.info("Processing entire page OCR with paddle...") image_bytes = self.pil_image_to_bytes(image) - ocr_data = Model(clarifai_model_url).predict_by_bytes(image_bytes , input_type="image") + ocr_data = Model(clarifai_model_url).predict_by_bytes(image_bytes, input_type="image") ocr_regions = self.parse_data(ocr_data) return ocr_regions @requires_dependencies("unstructured_inference") def get_layout_elements_from_image( - self, image: PILImage, ocr_languages: str = "eng", + self, + image: PILImage, + ocr_languages: str = "eng", clarifai_ocr_model: Optional[str] = None, - ) -> List["LayoutElement"]: + ) -> List[LayoutElement]: from unstructured.partition.pdf_image.inference_utils import build_layout_element + if not clarifai_ocr_model: clarifai_ocr_model = OCR_DEFAULT_CLARIFAI_MODEL_URL ocr_regions = self.get_layout_from_image( @@ -74,7 +77,7 @@ def get_layout_elements_from_image( ] @requires_dependencies("unstructured_inference") - def parse_data(self, ocr_data, zoom: float = 1) -> List["TextRegion"]: + def parse_data(self, ocr_data, zoom: float = 1) -> List[TextRegion]: """ Parse the OCR result data to extract a list of TextRegion objects from tesseract. @@ -105,7 +108,7 @@ def parse_data(self, ocr_data, zoom: float = 1) -> List["TextRegion"]: from unstructured.partition.pdf_image.inference_utils import build_text_region_from_coords text_regions = [] - #add try catch block + # add try catch block for data in ocr_data.outputs[0].data.regions: x1 = data.region_info.bounding_box.top_row y1 = data.region_info.bounding_box.left_col @@ -125,6 +128,7 @@ def parse_data(self, ocr_data, zoom: float = 1) -> List["TextRegion"]: def image_to_byte_array(self, image: PILImage) -> bytes: import io + # BytesIO is a file-like buffer stored in memory imgByteArr = io.BytesIO() # image.save expects a file-like as a argument @@ -135,6 +139,7 @@ def image_to_byte_array(self, image: PILImage) -> bytes: def pil_image_to_bytes(self, image: PILImage) -> bytes: from io import BytesIO + with BytesIO() as output: image.save(output, format="PNG") - return output.getvalue() \ No newline at end of file + return output.getvalue() diff --git a/unstructured/partition/utils/ocr_models/ocr_interface.py b/unstructured/partition/utils/ocr_models/ocr_interface.py index e62a803f4c..7e59a66312 100644 --- a/unstructured/partition/utils/ocr_models/ocr_interface.py +++ b/unstructured/partition/utils/ocr_models/ocr_interface.py @@ -13,7 +13,6 @@ OCR_AGENT_PADDLE_OLD, OCR_AGENT_TESSERACT, OCR_AGENT_TESSERACT_OLD, - OCR_AGENT_CLARIFAI ) if TYPE_CHECKING:
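A minimal usage sketch of the Clarifai-backed hi-res path introduced in this diff, not part of the patch itself. It assumes `clarifai` is installed (`pip install clarifai`) and that a valid Clarifai access token (e.g. the `CLARIFAI_PAT` environment variable) is available to the Clarifai SDK; the input path is hypothetical, while `"clarifai_yolox"`, `CLARIFAI_HI_RES_MODEL_NAME`, and `clarifai_ocr_model` come from the constants and parameters added above.

# Illustrative sketch only -- not part of the diff above.
import os

from unstructured.partition.pdf import partition_pdf

# default_hi_res_model() now reads this env var and falls back to "clarifai_yolox".
os.environ.setdefault("CLARIFAI_HI_RES_MODEL_NAME", "clarifai_yolox")

elements = partition_pdf(
    filename="example.pdf",  # hypothetical input document
    strategy="hi_res",
    hi_res_model_name="clarifai_yolox",  # routes layout detection to ClarifaiYoloXModel
    # Optional: override the Clarifai OCR model URL used by OCRAgentClarifai.
    clarifai_ocr_model="https://clarifai.com/clarifai/main/models/ocr-scene-english-paddleocr",
)
print([element.category for element in elements])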