Enable configuring batch sizes
VikParuchuri committed Oct 21, 2024
1 parent c1d8b4d commit a2d5163
Showing 8 changed files with 76 additions and 29 deletions.
30 changes: 20 additions & 10 deletions README.md
@@ -67,16 +67,26 @@ tabled DATA_PATH
- `--detect_cell_boxes` by default, tabled will attempt to pull cell information out of the PDF. Specify this flag to have a detection model find cells instead (usually only needed for PDFs with bad embedded text).
- `--save_images` specifies that images of detected rows/columns and cells should be saved.

The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:

- `text_lines` - the detected text and bounding boxes for each line
- `text` - the text in the line
- `confidence` - the confidence of the model in the detected text (0-1)
- `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
- `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
- `languages` - the languages specified for the page
- `page` - the page number in the file
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
After running the script, the output directory will contain folders with the same basenames as the input filenames. Each folder holds a markdown file for every table detected in that document, plus (optionally) images of the tables.

There will also be a `results.json` file in the root of the output directory. It contains a JSON dictionary whose keys are the input filenames without extensions. Each value is a list of dictionaries, one per table in the document (a minimal loading sketch follows this list). Each table dictionary contains:

- `cells` - the detected text and bounding boxes for each table cell.
- `bbox` - bbox of the cell within the table bbox
- `text` - the text of the cell
- `row_ids` - ids of rows the cell belongs to
- `col_ids` - ids of columns the cell belongs to
- `order` - order of this cell within its assigned row/column (sort by row, then column, then order to recover reading order)
- `rows` - bboxes of the detected rows
- `bbox` - bbox of the row in (x1, y1, x2, y2) format
- `row_id` - unique id of the row
- `cols` - bboxes of detected columns
- `bbox` - bbox of the column in (x1, y1, x2, y2) format
- `col_id` - unique id of the column
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. The table bbox is relative to this.
- `bbox` - the bounding box of the table within the image bbox.
- `pnum` - page number within the document
- `tnum` - table index on the page
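
A minimal sketch of consuming `results.json` under the structure described above (the `output/` path is illustrative, and the sort key assumes `row_ids`/`col_ids` are lists of integer ids):

```python
import json

with open("output/results.json", encoding="utf-8") as f:
    results = json.load(f)

for doc_name, tables in results.items():
    for table in tables:
        print(f"{doc_name}: page {table['pnum']}, table {table['tnum']}, bbox {table['bbox']}")
        # Recover reading order: by row, then column, then order within the row/column.
        cells = sorted(
            table["cells"],
            key=lambda c: (min(c["row_ids"]), min(c["col_ids"]), c["order"]),
        )
        texts = [c["text"] for c in cells]
```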

## Interactive App

24 changes: 16 additions & 8 deletions extract.py
@@ -1,3 +1,4 @@
import json
from collections import defaultdict

import pypdfium2
@@ -51,18 +52,21 @@ def main(in_path, out_folder, save_json, save_debug_images, skip_detection, dete

formatted_result, ext = formatter(format, page_cells)
base_name = f"page{pnum}_table{i}"
with open(os.path.join(base_path, f"{base_name}.{ext}"), "w") as f:
with open(os.path.join(base_path, f"{base_name}.{ext}"), "w+", encoding="utf-8") as f:
f.write(formatted_result)

img.save(os.path.join(base_path, f"{base_name}.png"))

if save_json:
result = {
"cells": [c.model_dump() for c in page_cells],
"rows": [r.model_dump() for r in page_rc.rows],
"cols": [c.model_dump() for c in page_rc.cols]
}
out_json[name].append(result)
res = {
"cells": [c.model_dump() for c in page_cells],
"rows": [r.model_dump() for r in page_rc.rows],
"cols": [c.model_dump() for c in page_rc.cols],
"bbox": result.bboxes[i].bbox,
"image_bbox": result.image_bboxes[i].bbox,
"pnum": pnum,
"tnum": i
}
out_json[name].append(res)

if save_debug_images:
boxes = [l.bbox for l in page_cells]
@@ -80,6 +84,10 @@ def main(in_path, out_folder, save_json, save_debug_images, skip_detection, dete
rc_image = draw_bboxes_on_image(cols, rc_image, labels=col_labels, label_font_size=20, color="red")
rc_image.save(os.path.join(base_path, f"{base_name}_rc.png"))

if save_json:
with open(os.path.join(out_folder, "result.json"), "w+", encoding="utf-8") as f:
json.dump(out_json, f, ensure_ascii=False)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tabled-pdf"
version = "0.1.1"
version = "0.1.2"
description = "Detect and recognize tables in PDFs and images."
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
16 changes: 14 additions & 2 deletions tabled/extract.py
@@ -1,12 +1,22 @@
from typing import List

from surya.schema import Bbox

from tabled.assignment import assign_rows_columns
from tabled.inference.detection import detect_tables
from tabled.inference.recognition import get_cells, recognize_tables
from tabled.schema import ExtractPageResult


def extract_tables(images, highres_images, text_lines, det_models, rec_models, skip_detection=False, detect_boxes=False) -> List[ExtractPageResult]:
def extract_tables(
images,
highres_images,
text_lines,
det_models,
rec_models,
skip_detection=False,
detect_boxes=False
) -> List[ExtractPageResult]:
if not skip_detection:
table_imgs, table_bboxes, table_counts = detect_tables(images, highres_images, det_models)
else:
@@ -33,7 +43,9 @@ def extract_tables(images, highres_images, text_lines, det_models, rec_models, s
results.append(ExtractPageResult(
table_imgs=table_imgs[page_start:page_end],
cells=cells[page_start:page_end],
rows_cols=table_rec[page_start:page_end]
rows_cols=table_rec[page_start:page_end],
bboxes=[Bbox(bbox=b) for b in table_bboxes[page_start:page_end]],
image_bboxes=[Bbox(bbox=[0, 0, size[0], size[1]]) for size in highres_image_sizes[page_start:page_end]]
))
counter += count

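The new `bboxes`/`image_bboxes` fields pair each table with the page image it was cropped from. A minimal sketch of using them, assuming `page` is one `ExtractPageResult` returned by `extract_tables`:

```python
from typing import Iterator, Tuple

def table_regions(page) -> Iterator[Tuple[Tuple[float, float, float, float], float]]:
    """Yield each table's (x1, y1, x2, y2) and its share of the page area."""
    for table_bbox, image_bbox in zip(page.bboxes, page.image_bboxes):
        x1, y1, x2, y2 = table_bbox.bbox
        _, _, width, height = image_bbox.bbox  # image bbox is [0, 0, w, h] per the code above
        yield (x1, y1, x2, y2), ((x2 - x1) * (y2 - y1)) / (width * height)
```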
8 changes: 5 additions & 3 deletions tabled/inference/detection.py
@@ -3,6 +3,8 @@
from surya.postprocessing.util import rescale_bbox
from surya.schema import Bbox

from tabled.settings import settings


def merge_boxes(box1, box2):
return [min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3])]
@@ -30,10 +32,10 @@ def merge_tables(page_table_boxes):
return [b for i, b in enumerate(page_table_boxes) if i not in ignore_boxes]


def detect_tables(images, highres_images, models):
def detect_tables(images, highres_images, models, detector_batch_size=settings.DETECTOR_BATCH_SIZE, layout_batch_size=settings.LAYOUT_BATCH_SIZE):
det_model, det_processor, layout_model, layout_processor = models
line_predictions = batch_text_detection(images, det_model, det_processor)
layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
line_predictions = batch_text_detection(images, det_model, det_processor, batch_size=detector_batch_size)
layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions, batch_size=layout_batch_size)

table_imgs = []
table_counts = []
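With this change the batch sizes can also be overridden per call rather than only through settings. A sketch with illustrative values, assuming `images`/`highres_images` are PIL page images and `models` is the four-tuple unpacked inside `detect_tables`:

```python
# None (the settings default) falls through to surya's own batch-size defaults.
table_imgs, table_bboxes, table_counts = detect_tables(
    images,
    highres_images,
    models,                  # (det_model, det_processor, layout_model, layout_processor)
    detector_batch_size=4,   # text-line detection batches; illustrative
    layout_batch_size=2,     # layout detection batches; illustrative
)
```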
12 changes: 7 additions & 5 deletions tabled/inference/recognition.py
@@ -6,8 +6,10 @@
from surya.schema import TableResult
from surya.tables import batch_table_recognition

from tabled.settings import settings

def get_cells(table_imgs, table_bboxes, image_sizes, text_lines, models, detect_boxes=False):

def get_cells(table_imgs, table_bboxes, image_sizes, text_lines, models, detect_boxes=False, detector_batch_size=settings.DETECTOR_BATCH_SIZE):
det_model, det_processor = models
table_cells = []
needs_ocr = []
@@ -27,15 +29,15 @@ def get_cells(table_imgs, table_bboxes, image_sizes, text_lines, models, detect_

# Inference tables that need it
if len(to_inference_idxs) > 0:
det_results = batch_text_detection([table_imgs[i] for i in to_inference_idxs], det_model, det_processor)
det_results = batch_text_detection([table_imgs[i] for i in to_inference_idxs], det_model, det_processor, batch_size=detector_batch_size)
for idx, det_result in zip(to_inference_idxs, det_results):
cell_bboxes = [{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes]
table_cells[idx] = cell_bboxes

return table_cells, needs_ocr


def recognize_tables(table_imgs, table_cells, needs_ocr: List[bool], models) -> List[TableResult]:
def recognize_tables(table_imgs, table_cells, needs_ocr: List[bool], models, table_rec_batch_size=settings.TABLE_REC_BATCH_SIZE, ocr_batch_size=settings.RECOGNITION_BATCH_SIZE) -> List[TableResult]:
table_rec_model, table_rec_processor, ocr_model, ocr_processor = models

if sum(needs_ocr) > 0:
@@ -44,14 +46,14 @@ def recognize_tables(table_imgs, table_cells, needs_ocr: List[bool], models) ->
ocr_cells = [[c["bbox"] for c in cells] for cells, needs in zip(table_cells, needs_ocr) if needs]
ocr_langs = [None] * len(ocr_images)

ocr_predictions = run_recognition(ocr_images, ocr_langs, ocr_model, ocr_processor, bboxes=ocr_cells)
ocr_predictions = run_recognition(ocr_images, ocr_langs, ocr_model, ocr_processor, bboxes=ocr_cells, batch_size=ocr_batch_size)

# Assign text to correct spot
for orig_idx, ocr_pred in zip(needs_ocr_idx, ocr_predictions):
for ocr_line, cell in zip(ocr_pred.text_lines, table_cells[orig_idx]):
cell["text"] = ocr_line.text

table_preds = batch_table_recognition(table_imgs, table_cells, table_rec_model, table_rec_processor)
table_preds = batch_table_recognition(table_imgs, table_cells, table_rec_model, table_rec_processor, batch_size=table_rec_batch_size)
return table_preds
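
The recognition stage takes the same style of per-call overrides; a sketch with illustrative values, assuming the model tuples are already loaded and the table crops come from `detect_tables`:

```python
table_cells, needs_ocr = get_cells(
    table_imgs, table_bboxes, image_sizes, text_lines,
    (det_model, det_processor),
    detector_batch_size=4,    # illustrative; None uses surya's default
)
table_preds = recognize_tables(
    table_imgs, table_cells, needs_ocr,
    (table_rec_model, table_rec_processor, ocr_model, ocr_processor),
    table_rec_batch_size=8,   # illustrative
    ocr_batch_size=32,        # illustrative
)
```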


2 changes: 2 additions & 0 deletions tabled/schema.py
@@ -44,6 +44,8 @@ class ExtractPageResult(BaseModel):
cells: List[List[SpanTableCell]]
rows_cols: List[TableResult]
table_imgs: List[Any]
bboxes: List[Bbox] # Bbox of the table
image_bboxes: List[Bbox] # Bbox of the image/page table is inside

@model_validator(mode="after")
def check_cells(self):
11 changes: 11 additions & 0 deletions tabled/settings.py
@@ -9,6 +9,17 @@ class Settings(BaseSettings):
IN_STREAMLIT: bool = False
TORCH_DEVICE: Optional[str] = None

# Batch sizes
# See https://github.com/VikParuchuri/surya for default values
## Table recognition
TABLE_REC_BATCH_SIZE: Optional[int] = None
## OCR
RECOGNITION_BATCH_SIZE: Optional[int] = None
## Text detector
DETECTOR_BATCH_SIZE: Optional[int] = None
## Layout
LAYOUT_BATCH_SIZE: Optional[int] = None

class Config:
env_file = find_dotenv("local.env")
extra = "ignore"
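
Because `Settings` is a pydantic `BaseSettings`, the new batch sizes can be supplied through environment variables (or `local.env`) as long as they are set before `tabled.settings` is imported. A self-contained sketch with illustrative values:

```python
import os

# Must happen before importing tabled.settings; the values are read at import time.
os.environ["DETECTOR_BATCH_SIZE"] = "4"
os.environ["LAYOUT_BATCH_SIZE"] = "2"
os.environ["TABLE_REC_BATCH_SIZE"] = "8"
os.environ["RECOGNITION_BATCH_SIZE"] = "32"

from tabled.settings import settings

assert settings.DETECTOR_BATCH_SIZE == 4  # pydantic coerces the env strings to int
```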
