Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added ocr image transform #515

Merged
merged 14 commits into from
Aug 17, 2023
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,7 @@ examples/*/*.csv
*.pkl

# DB Files
*.db
*.db

# Image Files
examples/**/*.png
50 changes: 50 additions & 0 deletions docs/guide/transforms/image_transform.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
The image transform allows users to extract text from image files. Autolabel uses optical character recognition (OCR) to read the images. To use this transform, follow these steps:

## Installation

Use the following command to download all dependencies for the image transform.

```bash
pip install pillow pytesseract
```

The tesseract engine is also required for OCR text extraction. See the [tesseract docs](https://tesseract-ocr.github.io/tessdoc/Installation.html) for installation instructions.

## Parameters for this transform

1. file_path_column: the name of the column containing the file paths of the pdf files to extract text from
2. lang: a string indicating the language of the text in the pdf file. See the [tesseract docs](https://tesseract-ocr.github.io/tessdoc/Data-Files-in-different-versions.html) for a full list of supported languages

## Using the transform

Below is an example of an image transform to extract text from an image file:

```json
{
..., # other config parameters
"transforms": [
..., # other transforms
{
"name": "image",
"params": {
"file_path_column": "file_path",
"lang": "eng"
},
"output_columns": {
"content_column": "content",
"metadata_column": "metadata"
}
}
]
}
```

## Run the transform

```python
from autolabel import LabelingAgent, AutolabelDataset
agent = LabelingAgent(config)
ds = agent.transform(ds)
```

This runs the transformation. We will see the content in the correct column. Access this using `ds.df` in the AutolabelDataset.
9 changes: 6 additions & 3 deletions docs/guide/transforms/pdf_transform.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ For OCR text extraction, install the <code>pdf2image</code> and <code>pytesserac
pip install pdf2image pytesseract
```

The tesseract engine is also required for OCR text extraction. See the [tesseract docs](https://tesseract-ocr.github.io/tessdoc/Installation.html) for installation instructions.

## Parameters for this transform

<ol>
Expand All @@ -22,9 +24,10 @@ pip install pdf2image pytesseract
<li>page_format: a string containing the format to use for each page of the pdf file. The following fields can be used in the format string:
<ul>
<li>page_num: the page number of the page</li>
<li>page_content: the content of the page</li></li>
</ul>
<li>page_sep: a string containing the separator to use between each page of the pdf file
<li>page_content: the content of the page</li>
</ul></li>
<li>page_sep: a string containing the separator to use between each page of the pdf file</li>
<li>lang: a string indicating the language of the text in the pdf file. See the [tesseract docs](https://tesseract-ocr.github.io/tessdoc/Data-Files-in-different-versions.html) for a full list of supported languages</li>
</ol>

### Output Format
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ nav:
- Introduction: guide/transforms/introduction.md
- Webpage Transform: guide/transforms/webpage_transform.md
- PDF Transform: guide/transforms/pdf_transform.md
- Image Transform: guide/transforms/image_transform.md
- Improving Labeling Accuracy:
- Prompting Better: guide/accuracy/prompting-better.md
- Few-shot Prompting: guide/accuracy/few-shot.md
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ all = [
"pytesseract >= 0.3.10",
"beautifulsoup4 >= 4.12.2",
"httpx",
"fake_useragent"
"fake_useragent",
"pillow >= 9.5.0"
]

[project.urls]
Expand Down
2 changes: 2 additions & 0 deletions src/autolabel/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .base import BaseTransform
from .pdf import PDFTransform
from .webpage_transform import WebpageTransform
from .image import ImageTransform
from typing import Dict
from autolabel.transforms.schema import TransformType
from autolabel.cache import BaseCache
Expand All @@ -12,6 +13,7 @@
TRANSFORM_REGISTRY = {
TransformType.PDF: PDFTransform,
TransformType.WEBPAGE_TRANSFORM: WebpageTransform,
TransformType.IMAGE: ImageTransform,
}


Expand Down
89 changes: 89 additions & 0 deletions src/autolabel/transforms/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from typing import Dict, Any

from autolabel.transforms.schema import TransformType
from autolabel.transforms import BaseTransform
from autolabel.cache import BaseCache


class ImageTransform(BaseTransform):
"""This class is used to extract text from images using OCR. The output columns dictionary for this class should include the keys 'content_column' and 'metadata_column'

This transform supports the following image formats: PNG, JPEG, TIFF, JPEG 2000, GIF, WebP, BMP, and PNM
"""

COLUMN_NAMES = [
"content_column",
"metadata_column",
]

def __init__(
self,
cache: BaseCache,
output_columns: Dict[str, Any],
file_path_column: str,
lang: str = None,
) -> None:
super().__init__(cache, output_columns)
self.file_path_column = file_path_column
self.lang = lang

try:
from PIL import Image
import pytesseract

self.Image = Image
self.pytesseract = pytesseract
self.pytesseract.get_tesseract_version()
except ImportError:
raise ImportError(
"pillow and pytesseract are required to use the image transform with ocr. Please install pillow and pytesseract with the following command: pip install pillow pytesseract"
)
except EnvironmentError:
raise EnvironmentError(
"The tesseract engine is required to use the image transform with ocr. Please see https://tesseract-ocr.github.io/tessdoc/Installation.html for installation instructions."
)

@staticmethod
def name() -> str:
return TransformType.IMAGE

def get_image_metadata(self, file_path: str):
try:
image = self.Image.open(file_path)
metadata = {
"format": image.format,
"mode": image.mode,
"size": image.size,
"width": image.width,
"height": image.height,
"exif": image._getexif(), # Exif metadata
}
return metadata
except Exception as e:
return {"error": str(e)}

async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:
"""This function transforms an image into text using OCR.

Args:
row (Dict[str, Any]): The row of data to be transformed.

Returns:
Dict[str, Any]: The dict of output columns.
"""
content = self.pytesseract.image_to_string(
row[self.file_path_column], lang=self.lang
)
metadata = self.get_image_metadata(row[self.file_path_column])
transformed_row = {
self.output_columns["content_column"]: content,
self.output_columns["metadata_column"]: metadata,
}
return transformed_row

def params(self) -> Dict[str, Any]:
return {
"output_columns": self.output_columns,
"file_path_column": self.file_path_column,
"lang": self.lang,
}
14 changes: 10 additions & 4 deletions src/autolabel/transforms/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@


class PDFTransform(BaseTransform):
"""This class is used to extract text from PDFs. The output columns dictionary for this class should include the keys 'content_column' and 'metadata_column'"""

COLUMN_NAMES = [
"content_column",
"metadata_column",
Expand All @@ -17,15 +19,16 @@ def __init__(
output_columns: Dict[str, Any],
file_path_column: str,
ocr_enabled: bool = False,
page_header: str = "Page {page_num}: {page_content}",
page_format: str = "Page {page_num}: {page_content}",
page_sep: str = "\n\n",
lang: str = None,
) -> None:
"""The output columns for this class should be in the order: [content_column, num_pages_column]"""
super().__init__(cache, output_columns)
self.file_path_column = file_path_column
self.ocr_enabled = ocr_enabled
self.page_format = page_header
self.page_format = page_format
self.page_sep = page_sep
self.lang = lang

if self.ocr_enabled:
try:
Expand Down Expand Up @@ -70,7 +73,9 @@ def get_page_texts(self, row: Dict[str, Any]) -> List[str]:
"""
if self.ocr_enabled:
pages = self.convert_from_path(row[self.file_path_column])
return [self.pytesseract.image_to_string(page) for page in pages]
return [
self.pytesseract.image_to_string(page, lang=self.lang) for page in pages
]
else:
loader = self.PDFPlumberLoader(row[self.file_path_column])
return [page.page_content for page in loader.load()]
Expand Down Expand Up @@ -103,4 +108,5 @@ def params(self):
"page_header": self.page_format,
"page_sep": self.page_sep,
"output_columns": self.output_columns,
"lang": self.lang,
}
1 change: 1 addition & 0 deletions src/autolabel/transforms/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class TransformType(str, Enum):

WEBPAGE_TRANSFORM = "webpage_transform"
PDF = "pdf"
IMAGE = "image"


class TransformCacheEntry(BaseModel):
Expand Down
Binary file added tests/assets/transforms/budget.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
67 changes: 67 additions & 0 deletions tests/unit/transforms/test_image_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from autolabel.transforms.image import ImageTransform
import pytest

pytest_plugins = ("pytest_asyncio",)


@pytest.mark.asyncio
async def test_image_transform(mocker):
mocker.patch(
"subprocess.check_output",
return_value="5.3.2".encode("utf-8"),
)
mocker.patch(
"pytesseract.pytesseract.run_and_get_output",
return_value="This is a test",
)

# Initialize the ImageTransform class
transform = ImageTransform(
output_columns={
"content_column": "content",
"metadata_column": "metadata",
},
file_path_column="file_path",
cache=None,
)

# Create a mock row
row = {"file_path": "tests/assets/transforms/budget.png"}
# Transform the row
transformed_row = await transform.apply(row)
# Check the output
assert set(transformed_row.keys()) == set(["content", "metadata"])
assert transformed_row["content"] == "This is a test"
assert isinstance(transformed_row["metadata"], dict)
assert len(transformed_row["content"]) > 0
metadata = transformed_row["metadata"]
assert metadata["format"] == row["file_path"].split(".")[-1].upper()
assert metadata["mode"] == "L"
assert metadata["size"] == (1766, 2257)
assert metadata["width"] == 1766
assert metadata["height"] == 2257
assert metadata["exif"] is None


@pytest.mark.asyncio
async def test_error_handling():
# Initialize the PDFTransform class
transform = ImageTransform(
output_columns={
"content_column": "content",
},
file_path_column="file_path",
cache=None,
)

# Create a mock row
row = {"file_path": "invalid_file.png"}
# Transform the row
transformed_row = await transform.apply(row)
# Check the output
assert set(transformed_row.keys()) == set(["content", "image_error"])
assert transformed_row["content"] == "NO_TRANSFORM"
assert (
transformed_row["image_error"]
== "tesseract is not installed or it's not in your PATH. See README file for more information."
)
34 changes: 33 additions & 1 deletion tests/unit/transforms/test_pdf_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def test_pdf_transform():
)

# Create a mock row
row = {"file_path": "tests/assets/data_loading/Resume.pdf"}
row = {"file_path": "tests/assets/transforms/Resume.pdf"}
# Transform the row
transformed_row = await transform.apply(row)
# Check the output
Expand All @@ -28,6 +28,38 @@ async def test_pdf_transform():
assert transformed_row["metadata"]["num_pages"] == 1


@pytest.mark.asyncio
async def test_pdf_transform_ocr(mocker):
mocker.patch(
"subprocess.check_output",
return_value="5.3.2".encode("utf-8"),
)
mocker.patch(
"autolabel.transforms.pdf.PDFTransform.get_page_texts",
return_value=["This is a test"],
)
transform = PDFTransform(
output_columns={
"content_column": "content",
"metadata_column": "metadata",
},
file_path_column="file_path",
ocr_enabled=True,
cache=None,
)

# Create a mock row
row = {"file_path": "tests/assets/transforms/Resume.pdf"}
# Transform the row
transformed_row = await transform.apply(row)
# Check the output
assert set(transformed_row.keys()) == set(["content", "metadata"])
assert transformed_row["content"] == "Page 1: This is a test"
assert isinstance(transformed_row["metadata"], dict)
assert len(transformed_row["content"]) > 0
assert transformed_row["metadata"]["num_pages"] == 1


@pytest.mark.asyncio
async def test_error_handling():
# Initialize the PDFTransform class
Expand Down
Loading