[DEVX-594]: Clarifai YOLOx Model usage as Layout Detection Model #2

Open · wants to merge 2 commits into base: support_clarifai_model
74 changes: 52 additions & 22 deletions unstructured/partition/pdf.py
@@ -5,6 +5,7 @@
import io
import os
import re
import tempfile
import warnings
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Optional, cast
@@ -80,13 +81,14 @@
from unstructured.partition.text import element_from_text
from unstructured.partition.utils.config import env_config
from unstructured.partition.utils.constants import (
LAYOUT_DEFAULT_CLARIFAI_MODEL,
OCR_AGENT_CLARIFAI,
OCR_AGENT_PADDLE,
SORT_MODE_BASIC,
SORT_MODE_DONT,
SORT_MODE_XY_CUT,
OCRMode,
PartitionStrategy,
OCR_AGENT_CLARIFAI,
)
from unstructured.partition.utils.sorting import coord_has_valid_points, sort_page_elements
from unstructured.patches.pdfminer import parse_keyword
@@ -102,15 +104,12 @@
RE_MULTISPACE_INCLUDING_NEWLINES = re.compile(pattern=r"\s+", flags=re.DOTALL)


@requires_dependencies("unstructured_inference")
def default_hi_res_model() -> str:
# a light config for the hi res model; this is not defined as a constant so that no setting of
# the default hi res model name is done on importing of this submodule; this allows (if user
# prefers) for setting env after importing the sub module and changing the default model name

from unstructured_inference.models.base import DEFAULT_MODEL

return os.environ.get("UNSTRUCTURED_HI_RES_MODEL_NAME", DEFAULT_MODEL)
return os.environ.get("CLARIFAI_HI_RES_MODEL_NAME", LAYOUT_DEFAULT_CLARIFAI_MODEL)


@process_metadata()
@@ -138,7 +137,6 @@ def partition_pdf(
extract_forms: bool = False,
form_extraction_skip_tables: bool = True,
clarifai_ocr_model: Optional[str] = None,

**kwargs: Any,
) -> list[Element]:
"""Parses a pdf document into a list of interpreted elements.
@@ -223,7 +221,7 @@ def partition_pdf(
starting_page_number=starting_page_number,
extract_forms=extract_forms,
form_extraction_skip_tables=form_extraction_skip_tables,
clarifai_ocr_model = clarifai_ocr_model,
clarifai_ocr_model=clarifai_ocr_model,
**kwargs,
)

@@ -567,6 +565,7 @@ def _partition_pdf_or_image_local(
) -> list[Element]:
"""Partition using package installed locally"""
from unstructured_inference.inference.layout import (
DocumentLayout,
process_data_with_model,
process_file_with_model,
)
@@ -576,6 +575,7 @@
process_data_with_pdfminer,
process_file_with_pdfminer,
)
from unstructured.partition.utils.clarifai_yolox import ClarifaiYoloXModel

if not is_image:
check_pdf_hi_res_max_pages_exceeded(
@@ -598,13 +598,25 @@

skip_analysis_dump = env_config.ANALYSIS_DUMP_OD_SKIP

if hi_res_model_name == LAYOUT_DEFAULT_CLARIFAI_MODEL:
layout_detection_model = ClarifaiYoloXModel()

if file is None:
inferred_document_layout = process_file_with_model(
filename,
is_image=is_image,
model_name=hi_res_model_name,
pdf_image_dpi=pdf_image_dpi,
)
if hi_res_model_name == LAYOUT_DEFAULT_CLARIFAI_MODEL:
inferred_document_layout = (
DocumentLayout.from_image_file(filename, detection_model=layout_detection_model)
if is_image
else DocumentLayout.from_file(
filename, detection_model=layout_detection_model, pdf_image_dpi=pdf_image_dpi
)
)
else:
inferred_document_layout = process_file_with_model(
filename,
is_image=is_image,
model_name=hi_res_model_name,
pdf_image_dpi=pdf_image_dpi,
)

if hi_res_model_name.startswith("chipper"):
# NOTE(alan): We shouldn't do OCR with chipper
@@ -654,12 +666,28 @@ def _partition_pdf_or_image_local(
ocr_layout_dumper=ocr_layout_dumper,
)
else:
inferred_document_layout = process_data_with_model(
file,
is_image=is_image,
model_name=hi_res_model_name,
pdf_image_dpi=pdf_image_dpi,
)
if hi_res_model_name == LAYOUT_DEFAULT_CLARIFAI_MODEL:
with tempfile.TemporaryDirectory() as tmp_dir_path:
file_path = os.path.join(tmp_dir_path, "document.pdf")
with open(file_path, "wb") as f:
f.write(file.read())
f.flush()
                inferred_document_layout = (
                    DocumentLayout.from_image_file(
                        file_path, detection_model=layout_detection_model
                    )
                    if is_image
                    else DocumentLayout.from_file(
                        file_path,
                        detection_model=layout_detection_model,
                        pdf_image_dpi=pdf_image_dpi,
                    )
                )
else:
inferred_document_layout = process_data_with_model(
file,
is_image=is_image,
model_name=hi_res_model_name,
pdf_image_dpi=pdf_image_dpi,
)

if hi_res_model_name.startswith("chipper"):
# NOTE(alan): We shouldn't do OCR with chipper
@@ -921,15 +949,17 @@ def _partition_pdf_or_image_with_ocr_from_image(
"""Extract `unstructured` elements from an image using OCR and perform partitioning."""

from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
os.environ['OCR_AGENT'] = OCR_AGENT_CLARIFAI

os.environ["OCR_AGENT"] = OCR_AGENT_CLARIFAI
ocr_agent = OCRAgent.get_agent(language=ocr_languages)

# NOTE(christine): `pytesseract.image_to_string()` returns sorted text
if ocr_agent.is_text_sorted():
sort_mode = SORT_MODE_DONT

ocr_data = ocr_agent.get_layout_elements_from_image(image=image,clarifai_ocr_model=clarifai_ocr_model)
ocr_data = ocr_agent.get_layout_elements_from_image(
image=image, clarifai_ocr_model=clarifai_ocr_model
)

metadata = ElementMetadata(
last_modified=metadata_last_modified,
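For reviewers trying this branch locally, a minimal sketch of how the new layout path is selected follows. It assumes `partition_pdf` forwards `hi_res_model_name` down to `_partition_pdf_or_image_local` as in the current `unstructured` API; `example.pdf` is a placeholder, not part of the diff.

# Sketch only, not part of the diff. Assumes partition_pdf forwards
# hi_res_model_name as in the current unstructured API.
import os

from unstructured.partition.pdf import partition_pdf

# Option 1: name the Clarifai YOLOX layout model explicitly.
elements = partition_pdf(
    filename="example.pdf",  # placeholder path
    strategy="hi_res",
    hi_res_model_name="clarifai_yolox",  # LAYOUT_DEFAULT_CLARIFAI_MODEL
)

# Option 2: rely on default_hi_res_model(), which now reads
# CLARIFAI_HI_RES_MODEL_NAME and falls back to LAYOUT_DEFAULT_CLARIFAI_MODEL.
os.environ["CLARIFAI_HI_RES_MODEL_NAME"] = "clarifai_yolox"
elements = partition_pdf(filename="example.pdf", strategy="hi_res")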
6 changes: 3 additions & 3 deletions unstructured/partition/strategies.py
@@ -79,10 +79,10 @@ def determine_pdf_or_image_strategy(
logger.warning("Falling back to partitioning with fast.")
return PartitionStrategy.FAST
else:
#clarifai is not installed and the text of the PDF is not extractable
# clarifai is not installed and the text of the PDF is not extractable
raise ImportError(
f"clarifai is not installed. "
f"""Please install using `pip install clarifai`.""")
"clarifai is not installed. " """Please install using `pip install clarifai`."""
)

return strategy

64 changes: 64 additions & 0 deletions unstructured/partition/utils/clarifai_yolox.py
@@ -0,0 +1,64 @@
from clarifai.client.model import Model
from PIL import Image as PILImage
from unstructured_inference.constants import Source
from unstructured_inference.inference.layoutelement import LayoutElement, LayoutElements
from unstructured_inference.models.unstructuredmodel import UnstructuredObjectDetectionModel

from unstructured.partition.utils.constants import LAYOUT_DEFAULT_CLARIFAI_MODEL_URL


class ClarifaiYoloXModel(UnstructuredObjectDetectionModel):
"""Clarifai YoloX model for layout segmentation."""

def __init__(self):
self.model = Model(LAYOUT_DEFAULT_CLARIFAI_MODEL_URL)
self.confidence_threshold = 0.1

def predict(self, x: PILImage.Image) -> LayoutElements:
"""Predict using Clarifai YoloX model."""
image_bytes = self.pil_image_to_bytes(x)
model_prediction = self.model.predict_by_bytes(image_bytes, input_type="image")
return self.parse_data(model_prediction, x)

def initialize(self):
pass

def parse_data(
self,
model_prediction,
image: PILImage.Image,
) -> LayoutElements:
"""Process model prediction output into Unstructured class. Bounding box coordinates
are converted to original image resolution. Layouts are filtered based on confidence
threshold.
"""
regions_data = model_prediction.outputs[0].data.regions
regions = []
input_w, input_h = image.size
for region in regions_data:
bboxes = region.region_info.bounding_box
y1, x1, y2, x2 = bboxes.top_row, bboxes.left_col, bboxes.bottom_row, bboxes.right_col
detected_class = region.data.concepts[0].name
confidence = region.value
if confidence >= self.confidence_threshold:
region = LayoutElement.from_coords(
x1 * input_w,
y1 * input_h,
x2 * input_w,
y2 * input_h,
text=None,
type=detected_class,
prob=confidence,
source=Source.YOLOX,
)
regions.append(region)

regions.sort(key=lambda element: element.bbox.y1)
return LayoutElements.from_list(regions)

def pil_image_to_bytes(self, image: PILImage) -> bytes:
from io import BytesIO

with BytesIO() as output:
image.save(output, format="PNG")
return output.getvalue()
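A quick worked example of the coordinate handling in `parse_data` above: Clarifai returns region boxes as fractions of the image (`top_row`, `left_col`, `bottom_row`, `right_col` in [0, 1]), which the code scales by the PIL image's width and height. The numbers below are made up for illustration.

# Worked example of the bbox normalization in parse_data; values are illustrative.
input_w, input_h = 1700, 2200  # PILImage.size returns (width, height)
top_row, left_col, bottom_row, right_col = 0.10, 0.05, 0.30, 0.60  # hypothetical region

x1, y1 = left_col * input_w, top_row * input_h      # (85.0, 220.0)
x2, y2 = right_col * input_w, bottom_row * input_h  # (1020.0, 660.0)
assert (x1, y1, x2, y2) == (85.0, 220.0, 1020.0, 660.0)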
9 changes: 8 additions & 1 deletion unstructured/partition/utils/constants.py
@@ -29,7 +29,9 @@ class PartitionStrategy:
OCR_AGENT_TESSERACT_OLD = "tesseract"
OCR_AGENT_PADDLE_OLD = "paddle"

OCR_DEFAULT_CLARIFAI_MODEL_URL = 'https://clarifai.com/clarifai/main/models/ocr-scene-english-paddleocr'
OCR_DEFAULT_CLARIFAI_MODEL_URL = (
"https://clarifai.com/clarifai/main/models/ocr-scene-english-paddleocr"
)

OCR_AGENT_TESSERACT = "unstructured.partition.utils.ocr_models.tesseract_ocr.OCRAgentTesseract"
OCR_AGENT_PADDLE = "unstructured.partition.utils.ocr_models.paddle_ocr.OCRAgentPaddle"
@@ -45,6 +47,11 @@
"unstructured.partition.utils.ocr_models.paddle_ocr",
).split(",")

LAYOUT_DEFAULT_CLARIFAI_MODEL_URL = (
"https://clarifai.com/mogith-p-n/YOLOx-model-training/models/model-hps-yolox-detection"
)
LAYOUT_DEFAULT_CLARIFAI_MODEL = "clarifai_yolox"

UNSTRUCTURED_INCLUDE_DEBUG_METADATA = os.getenv("UNSTRUCTURED_INCLUDE_DEBUG_METADATA", False)

# this field is defined by unstructured_pytesseract
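For orientation, the two Clarifai URL constants are consumed in different places: the layout URL feeds `ClarifaiYoloXModel.__init__`, while the OCR URL is the fallback in `OCRAgentClarifai.get_layout_elements_from_image`. A sketch, assuming the `clarifai` SDK's `Model(url)` constructor used elsewhere in this PR:

# Sketch only; mirrors how the diff constructs Clarifai models from these URLs.
from clarifai.client.model import Model

from unstructured.partition.utils.constants import (
    LAYOUT_DEFAULT_CLARIFAI_MODEL_URL,
    OCR_DEFAULT_CLARIFAI_MODEL_URL,
)

layout_model = Model(LAYOUT_DEFAULT_CLARIFAI_MODEL_URL)  # used by ClarifaiYoloXModel
ocr_model = Model(OCR_DEFAULT_CLARIFAI_MODEL_URL)  # OCRAgentClarifai fallback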
39 changes: 22 additions & 17 deletions unstructured/partition/utils/ocr_models/clarifai_ocr.py
@@ -1,25 +1,22 @@
from typing import TYPE_CHECKING, List, Optional

import numpy as np
from PIL import Image as PILImage
from clarifai.client.model import Model
from PIL import Image as PILImage

from unstructured.documents.elements import ElementType
from unstructured.logger import logger
from unstructured.partition.utils.constants import (
Source,
OCR_DEFAULT_CLARIFAI_MODEL_URL
)
from unstructured.partition.utils.constants import OCR_DEFAULT_CLARIFAI_MODEL_URL, Source
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
from unstructured.utils import requires_dependencies

if TYPE_CHECKING:
from unstructured_inference.inference.elements import TextRegion
from unstructured_inference.inference.layoutelement import LayoutElemen

from unstructured_inference.inference.layoutelement import LayoutElement


class OCRAgentClarifai(OCRAgent):
"""OCR service implementation for Clarifai."""

def __init__(self, language: str = "en"):
self.agent = self.load_agent(language)

@@ -34,24 +31,30 @@ def get_text_from_image(self, image: PILImage, ocr_languages: str = "eng") -> str
return "\n\n".join([r.text for r in ocr_regions])

def get_layout_from_image(
self, image: PILImage, clarifai_model_url: Optional[str] =None, ocr_languages: str = "eng",
) -> List["TextRegion"]:
self,
image: PILImage,
clarifai_model_url: Optional[str] = None,
ocr_languages: str = "eng",
) -> List[TextRegion]:
"""Get the OCR regions from image as a list of text regions with paddle."""
import base64

logger.info("Processing entire page OCR with Clarifai...")

image_bytes = self.pil_image_to_bytes(image)
ocr_data = Model(clarifai_model_url).predict_by_bytes(image_bytes , input_type="image")
ocr_data = Model(clarifai_model_url).predict_by_bytes(image_bytes, input_type="image")
ocr_regions = self.parse_data(ocr_data)

return ocr_regions

@requires_dependencies("unstructured_inference")
def get_layout_elements_from_image(
self, image: PILImage, ocr_languages: str = "eng",
self,
image: PILImage,
ocr_languages: str = "eng",
clarifai_ocr_model: Optional[str] = None,
) -> List["LayoutElement"]:
) -> List[LayoutElement]:
from unstructured.partition.pdf_image.inference_utils import build_layout_element

if not clarifai_ocr_model:
clarifai_ocr_model = OCR_DEFAULT_CLARIFAI_MODEL_URL
ocr_regions = self.get_layout_from_image(
@@ -74,7 +77,7 @@ def get_layout_elements_from_image(
]

@requires_dependencies("unstructured_inference")
def parse_data(self, ocr_data, zoom: float = 1) -> List["TextRegion"]:
def parse_data(self, ocr_data, zoom: float = 1) -> List[TextRegion]:
"""
Parse the OCR result data to extract a list of TextRegion objects from
Clarifai.
@@ -105,7 +108,7 @@ def parse_data(self, ocr_data, zoom: float = 1) -> List["TextRegion"]:
from unstructured.partition.pdf_image.inference_utils import build_text_region_from_coords

text_regions = []
#add try catch block
# add try catch block
for data in ocr_data.outputs[0].data.regions:
x1 = data.region_info.bounding_box.top_row
y1 = data.region_info.bounding_box.left_col
@@ -125,6 +128,7 @@ def parse_data(self, ocr_data, zoom: float = 1) -> List["TextRegion"]:

def image_to_byte_array(self, image: PILImage) -> bytes:
import io

# BytesIO is a file-like buffer stored in memory
imgByteArr = io.BytesIO()
# image.save expects a file-like as a argument
@@ -135,6 +139,7 @@ def image_to_byte_array(self, image: PILImage) -> bytes:

def pil_image_to_bytes(self, image: PILImage) -> bytes:
from io import BytesIO

with BytesIO() as output:
image.save(output, format="PNG")
return output.getvalue()
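To exercise the agent directly, something like the following should work. It assumes `OCR_AGENT_CLARIFAI` holds the dotted path to `OCRAgentClarifai` (mirroring the tesseract/paddle constants in constants.py) and that Clarifai credentials (e.g. `CLARIFAI_PAT`) are configured in the environment; `page.png` is a placeholder.

# Sketch only. Assumes OCR_AGENT_CLARIFAI is the dotted path to
# OCRAgentClarifai and that Clarifai credentials are set in the environment.
import os

from PIL import Image

from unstructured.partition.utils.constants import OCR_AGENT_CLARIFAI
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent

os.environ["OCR_AGENT"] = OCR_AGENT_CLARIFAI
agent = OCRAgent.get_agent(language="en")

image = Image.open("page.png")  # placeholder input
elements = agent.get_layout_elements_from_image(
    image=image,
    clarifai_ocr_model=None,  # None falls back to OCR_DEFAULT_CLARIFAI_MODEL_URL
)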
1 change: 0 additions & 1 deletion unstructured/partition/utils/ocr_models/ocr_interface.py
@@ -13,7 +13,6 @@
OCR_AGENT_PADDLE_OLD,
OCR_AGENT_TESSERACT,
OCR_AGENT_TESSERACT_OLD,
OCR_AGENT_CLARIFAI
)

if TYPE_CHECKING: