Skip to content

Commit

Permalink
added pdf transform and tests (#502)
Browse files Browse the repository at this point in the history
* added pdf transform and tests

* moved pdf transform to schema and dependencies

* removed unused pytest import

* removed the pypdf dependency

* added extract text OCR method

* default non-ocr, moved imports to optional

* moved imports to constructor

* split pdf transform into pdf and pdf_ocr

* added dependencies to all

* both ocr and regular transforms back to one file

* updated typing and some metadata

* added tests and moved import

* can't test pdf ocr without tesseract engine

* removed pdf_ocr from schema
  • Loading branch information
Tyrest authored Aug 7, 2023
1 parent 7c447aa commit 06893f8
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 1 deletion.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ all = [
"google-cloud-aiplatform>=1.25.0",
"cohere>=4.11.2",
"sentence_transformers",
"pdfplumber >= 0.10.2",
"pdf2image >= 1.16.3",
"pytesseract >= 0.3.10",
"bs4",
"httpx",
"fake_useragent"
Expand Down
1 change: 1 addition & 0 deletions src/autolabel/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,4 @@ class TransformType(str, Enum):
"""Enum containing all Transforms supported by autolabel"""

WEBPAGE_TRANSFORM = "webpage_transform"
PDF = "pdf"
2 changes: 2 additions & 0 deletions src/autolabel/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import logging

from .base import BaseTransform
from .pdf import PDFTransform
from .webpage_transform import WebpageTransform
from typing import Dict, List
from autolabel.schema import TransformType

logger = logging.getLogger(__name__)

# Maps each TransformType enum member to the transform class registered for it.
TRANSFORM_REGISTRY = {
TransformType.PDF: PDFTransform,
TransformType.WEBPAGE_TRANSFORM: WebpageTransform,
}

Expand Down
3 changes: 2 additions & 1 deletion src/autolabel/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ def __init__(self, output_columns: Dict[str, Any]) -> None:
super().__init__()
self._output_columns = output_columns

@staticmethod
@abstractmethod
def name(self) -> str:
def name() -> str:
pass

@property
Expand Down
98 changes: 98 additions & 0 deletions src/autolabel/transforms/pdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from typing import List, Dict, Any

from autolabel.schema import TransformType
from autolabel.transforms import BaseTransform


class PDFTransform(BaseTransform):
    """Turn the PDF file referenced by a row into one formatted text string.

    Page text is read through langchain's ``PDFPlumberLoader`` by default.
    When ``ocr_enabled`` is set, each page is rendered with pdf2image and
    converted to text with pytesseract instead.
    """

    def __init__(
        self,
        output_columns: Dict[str, Any],
        file_path_column: str,
        ocr_enabled: bool = False,
        page_header: str = "Page {page_num}: {page_content}",
        page_sep: str = "\n\n",
    ) -> None:
        """Store the configuration and import the backend the chosen mode needs.

        Args:
            output_columns (Dict[str, Any]): Mapping from the keys
                ``content_column`` and ``metadata_column`` to the names the
                outputs should be stored under.
            file_path_column (str): Key of the row that holds the PDF path.
            ocr_enabled (bool): Use pdf2image + pytesseract instead of
                pdfplumber to read the pages.
            page_header (str): Format string applied to every page; it is
                formatted with ``page_num`` and ``page_content``.
            page_sep (str): Separator placed between formatted pages.
        Raises:
            ImportError: If the python packages required by the selected mode
                cannot be imported.
            EnvironmentError: If OCR is enabled and the tesseract version
                check fails.
        """
        super().__init__(output_columns)
        self.file_path_column = file_path_column
        self.ocr_enabled = ocr_enabled
        self.page_format = page_header
        self.page_sep = page_sep

        # The backends are optional dependencies, so they are imported here
        # rather than at module import time.
        if self.ocr_enabled:
            try:
                from pdf2image import convert_from_path
                import pytesseract

                self.convert_from_path = convert_from_path
                self.pytesseract = pytesseract
                # Fail at construction time if the tesseract engine is unusable.
                self.pytesseract.get_tesseract_version()
            except ImportError as e:
                # Chain the original error so the failing import stays visible.
                raise ImportError(
                    "pdf2image and pytesseract are required to use the pdf transform with ocr. Please install pdf2image and pytesseract with the following command: pip install pdf2image pytesseract"
                ) from e
            except EnvironmentError as e:
                raise EnvironmentError(
                    "The tesseract engine is required to use the pdf transform with ocr. Please see https://tesseract-ocr.github.io/tessdoc/Installation.html for installation instructions."
                ) from e
        else:
            try:
                from langchain.document_loaders import PDFPlumberLoader

                self.PDFPlumberLoader = PDFPlumberLoader
            except ImportError as e:
                raise ImportError(
                    "pdfplumber is required to use the pdf transform. Please install pdfplumber with the following command: pip install pdfplumber"
                ) from e

    @staticmethod
    def name() -> str:
        """Return the identifier of this transform (``TransformType.PDF``)."""
        return TransformType.PDF

    @property
    def output_columns(self) -> Dict[str, Any]:
        """Map each supported output key to its configured column name.

        A key that was not configured falls back to the key itself.
        """
        COLUMN_NAMES = [
            "content_column",
            "metadata_column",
        ]
        return {k: self._output_columns.get(k, k) for k in COLUMN_NAMES}

    def get_page_texts(self, row: Dict[str, Any]) -> List[str]:
        """This function gets the text from each page of a PDF file.
        If OCR is enabled, it uses the pdf2image library to convert the PDF into images and then uses
        pytesseract to convert the images into text. Otherwise, it uses pdfplumber to extract the text.
        Args:
            row (Dict[str, Any]): The row of data to be transformed.
        Returns:
            List[str]: A list of strings containing the text from each page of the PDF.
        """
        file_path = row[self.file_path_column]
        if self.ocr_enabled:
            pages = self.convert_from_path(file_path)
            return [self.pytesseract.image_to_string(page) for page in pages]
        loader = self.PDFPlumberLoader(file_path)
        return [page.page_content for page in loader.load()]

    async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """This function transforms a PDF file into a string of text.
        The text is formatted according to the page_format and
        page_sep parameters and returned as a string.
        Args:
            row (Dict[str, Any]): The row of data to be transformed.
        Returns:
            Dict[str, Any]: The dict of output columns.
        """
        # Pages are numbered from 1 in the formatted output.
        texts = [
            self.page_format.format(page_num=page_num, page_content=text)
            for page_num, text in enumerate(self.get_page_texts(row), start=1)
        ]
        columns = self.output_columns
        return {
            columns["content_column"]: self.page_sep.join(texts),
            columns["metadata_column"]: {"num_pages": len(texts)},
        }
Binary file added tests/assets/data_loading/Resume.pdf
Binary file not shown.
27 changes: 27 additions & 0 deletions tests/unit/transforms/test_pdf_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from autolabel.transforms.pdf import PDFTransform
import pytest

pytest_plugins = ("pytest_asyncio",)


@pytest.mark.asyncio
async def test_pdf_transform():
    """Run the default (non-OCR) PDF transform on the sample resume and check its output."""
    # Map the transform's logical outputs onto the column names asserted below.
    column_mapping = {
        "content_column": "content",
        "metadata_column": "metadata",
    }
    pdf_transform = PDFTransform(
        output_columns=column_mapping,
        file_path_column="file_path",
    )

    # A row that only carries the path to the PDF fixture.
    sample_row = {"file_path": "tests/assets/data_loading/Resume.pdf"}
    result = await pdf_transform.apply(sample_row)

    # Exactly the two configured columns are produced.
    assert set(result) == {"content", "metadata"}

    content = result["content"]
    metadata = result["metadata"]
    assert isinstance(content, str)
    assert isinstance(metadata, dict)
    assert len(content) > 0
    assert metadata["num_pages"] == 1

0 comments on commit 06893f8

Please sign in to comment.