
Commit 6ac7ba5

Merge pull request #13 from VikParuchuri/dev

Enable configuring batch sizes

2 parents 4445446 + a2d5163

8 files changed: +76 -29 lines

README.md

+20-10
@@ -67,16 +67,26 @@ tabled DATA_PATH
 - `--detect_cell_boxes` by default, tabled will attempt to pull cell information out of the pdf. If you instead want cells to be detected by a detection model, specify this (usually you only need this with pdfs that have bad embedded text).
 - `--save_images` specifies that images of detected rows/columns and cells should be saved.
 
-The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:
-
-- `text_lines` - the detected text and bounding boxes for each line
-  - `text` - the text in the line
-  - `confidence` - the confidence of the model in the detected text (0-1)
-  - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
-  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
-- `languages` - the languages specified for the page
-- `page` - the page number in the file
-- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
+After running the script, the output directory will contain folders with the same basenames as the input filenames. Inside those folders will be the markdown files for each table in the source documents. There will also optionally be images of the tables.
+
+There will also be a `results.json` file in the root of the output directory. The file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per table in the document. Each table dictionary contains:
+
+- `cells` - the detected text and bounding boxes for each table cell.
+  - `bbox` - bbox of the cell within the table bbox
+  - `text` - the text of the cell
+  - `row_ids` - ids of rows the cell belongs to
+  - `col_ids` - ids of columns the cell belongs to
+  - `order` - order of this cell within its assigned row/column (sort by row, then column, then order)
+- `rows` - bboxes of the detected rows
+  - `bbox` - bbox of the row in (x1, y1, x2, y2) format
+  - `row_id` - unique id of the row
+- `cols` - bboxes of detected columns
+  - `bbox` - bbox of the column in (x1, y1, x2, y2) format
+  - `col_id` - unique id of the column
+- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. The table bbox is relative to this.
+- `bbox` - the bounding box of the table within the image bbox.
+- `pnum` - page number within the document
+- `tnum` - table index on the page
 
 ## Interactive App
 

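For consumers of the new format, here is a minimal sketch of reading `results.json` back. It is not part of the commit; the output path is illustrative, and sorting by the minimum row/column id is an assumption based on the field descriptions above:

```python
import json

# Illustrative path: the root of the output directory passed to the script.
with open("output/results.json", encoding="utf-8") as f:
    results = json.load(f)

for doc_name, tables in results.items():
    for table in tables:
        print(f"{doc_name}: page {table['pnum']}, table {table['tnum']}, bbox {table['bbox']}")
        # Sort by row, then column, then order, as the README suggests.
        # row_ids/col_ids are lists, so min() is an assumed sort key.
        for cell in sorted(
            table["cells"],
            key=lambda c: (min(c["row_ids"]), min(c["col_ids"]), c["order"]),
        ):
            print("  ", cell["text"])
```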
extract.py

+16-8
@@ -1,3 +1,4 @@
+import json
 from collections import defaultdict
 
 import pypdfium2
@@ -51,18 +52,21 @@ def main(in_path, out_folder, save_json, save_debug_images, skip_detection, dete
 
            formatted_result, ext = formatter(format, page_cells)
            base_name = f"page{pnum}_table{i}"
-           with open(os.path.join(base_path, f"{base_name}.{ext}"), "w") as f:
+           with open(os.path.join(base_path, f"{base_name}.{ext}"), "w+", encoding="utf-8") as f:
                f.write(formatted_result)
 
            img.save(os.path.join(base_path, f"{base_name}.png"))
 
-           if save_json:
-               result = {
-                   "cells": [c.model_dump() for c in page_cells],
-                   "rows": [r.model_dump() for r in page_rc.rows],
-                   "cols": [c.model_dump() for c in page_rc.cols]
-               }
-               out_json[name].append(result)
+           res = {
+               "cells": [c.model_dump() for c in page_cells],
+               "rows": [r.model_dump() for r in page_rc.rows],
+               "cols": [c.model_dump() for c in page_rc.cols],
+               "bbox": result.bboxes[i].bbox,
+               "image_bbox": result.image_bboxes[i].bbox,
+               "pnum": pnum,
+               "tnum": i
+           }
+           out_json[name].append(res)
 
            if save_debug_images:
                boxes = [l.bbox for l in page_cells]
@@ -80,6 +84,10 @@ def main(in_path, out_folder, save_json, save_debug_images, skip_detection, dete
            rc_image = draw_bboxes_on_image(cols, rc_image, labels=col_labels, label_font_size=20, color="red")
            rc_image.save(os.path.join(base_path, f"{base_name}_rc.png"))
 
+   if save_json:
+       with open(os.path.join(out_folder, "result.json"), "w+", encoding="utf-8") as f:
+           json.dump(out_json, f, ensure_ascii=False)
+
 
 if __name__ == "__main__":
     main()

pyproject.toml

+1-1
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tabled-pdf"
-version = "0.1.1"
+version = "0.1.2"
 description = "Detect and recognize tables in PDFs and images."
 authors = ["Vik Paruchuri <[email protected]>"]
 readme = "README.md"

tabled/extract.py

+14-2
@@ -1,12 +1,22 @@
 from typing import List
 
+from surya.schema import Bbox
+
 from tabled.assignment import assign_rows_columns
 from tabled.inference.detection import detect_tables
 from tabled.inference.recognition import get_cells, recognize_tables
 from tabled.schema import ExtractPageResult
 
 
-def extract_tables(images, highres_images, text_lines, det_models, rec_models, skip_detection=False, detect_boxes=False) -> List[ExtractPageResult]:
+def extract_tables(
+    images,
+    highres_images,
+    text_lines,
+    det_models,
+    rec_models,
+    skip_detection=False,
+    detect_boxes=False
+) -> List[ExtractPageResult]:
     if not skip_detection:
         table_imgs, table_bboxes, table_counts = detect_tables(images, highres_images, det_models)
     else:
@@ -33,7 +43,9 @@ def extract_tables(images, highres_images, text_lines, det_models, rec_models, s
         results.append(ExtractPageResult(
             table_imgs=table_imgs[page_start:page_end],
             cells=cells[page_start:page_end],
-            rows_cols=table_rec[page_start:page_end]
+            rows_cols=table_rec[page_start:page_end],
+            bboxes=[Bbox(bbox=b) for b in table_bboxes[page_start:page_end]],
+            image_bboxes=[Bbox(bbox=[0, 0, size[0], size[1]]) for size in highres_image_sizes[page_start:page_end]]
         ))
         counter += count
 
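The reworked signature and the two new result fields can be seen together in a short usage sketch. This is not from the commit: the wrapper name is hypothetical, and the images, text lines, and model tuples are assumed to be prepared by the caller (model loading is outside this diff):

```python
from typing import List

from tabled.extract import extract_tables
from tabled.schema import ExtractPageResult

def run_extraction(images, highres_images, text_lines, det_models, rec_models) -> List[ExtractPageResult]:
    # Hypothetical wrapper showing the updated call; inputs are assumed
    # to be prepared by the caller.
    page_results = extract_tables(images, highres_images, text_lines, det_models, rec_models)
    for page in page_results:
        # New in this commit: per-table bboxes plus the page-image bbox
        # they are relative to.
        for table_bbox, image_bbox in zip(page.bboxes, page.image_bboxes):
            print(table_bbox.bbox, image_bbox.bbox)
    return page_results
```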
tabled/inference/detection.py

+5-3
@@ -3,6 +3,8 @@
 from surya.postprocessing.util import rescale_bbox
 from surya.schema import Bbox
 
+from tabled.settings import settings
+
 
 def merge_boxes(box1, box2):
     return [min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3])]
@@ -30,10 +32,10 @@ def merge_tables(page_table_boxes):
     return [b for i, b in enumerate(page_table_boxes) if i not in ignore_boxes]
 
 
-def detect_tables(images, highres_images, models):
+def detect_tables(images, highres_images, models, detector_batch_size=settings.DETECTOR_BATCH_SIZE, layout_batch_size=settings.LAYOUT_BATCH_SIZE):
     det_model, det_processor, layout_model, layout_processor = models
-    line_predictions = batch_text_detection(images, det_model, det_processor)
-    layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
+    line_predictions = batch_text_detection(images, det_model, det_processor, batch_size=detector_batch_size)
+    layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions, batch_size=layout_batch_size)
 
     table_imgs = []
     table_counts = []

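With the new keyword arguments, callers can override the detection batch sizes per call instead of relying on the settings defaults. A sketch, not from the commit; the wrapper name and values are illustrative:

```python
from tabled.inference.detection import detect_tables

def detect_with_small_batches(images, highres_images, models):
    # Hypothetical wrapper; models is the four-tuple detect_tables unpacks:
    # (det_model, det_processor, layout_model, layout_processor).
    return detect_tables(
        images,
        highres_images,
        models,
        detector_batch_size=4,  # illustrative; None falls back to surya's default
        layout_batch_size=4,
    )
```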
tabled/inference/recognition.py

+7-5
@@ -6,8 +6,10 @@
 from surya.schema import TableResult
 from surya.tables import batch_table_recognition
 
+from tabled.settings import settings
 
-def get_cells(table_imgs, table_bboxes, image_sizes, text_lines, models, detect_boxes=False):
+
+def get_cells(table_imgs, table_bboxes, image_sizes, text_lines, models, detect_boxes=False, detector_batch_size=settings.DETECTOR_BATCH_SIZE):
     det_model, det_processor = models
     table_cells = []
     needs_ocr = []
@@ -27,15 +29,15 @@ def get_cells(table_imgs, table_bboxes, image_sizes, text_lines, models, detect_
 
     # Inference tables that need it
     if len(to_inference_idxs) > 0:
-        det_results = batch_text_detection([table_imgs[i] for i in to_inference_idxs], det_model, det_processor)
+        det_results = batch_text_detection([table_imgs[i] for i in to_inference_idxs], det_model, det_processor, batch_size=detector_batch_size)
         for idx, det_result in zip(to_inference_idxs, det_results):
            cell_bboxes = [{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes]
            table_cells[idx] = cell_bboxes
 
     return table_cells, needs_ocr
 
 
-def recognize_tables(table_imgs, table_cells, needs_ocr: List[bool], models) -> List[TableResult]:
+def recognize_tables(table_imgs, table_cells, needs_ocr: List[bool], models, table_rec_batch_size=settings.TABLE_REC_BATCH_SIZE, ocr_batch_size=settings.RECOGNITION_BATCH_SIZE) -> List[TableResult]:
     table_rec_model, table_rec_processor, ocr_model, ocr_processor = models
 
     if sum(needs_ocr) > 0:
@@ -44,14 +46,14 @@ def recognize_tables(table_imgs, table_cells, needs_ocr: List[bool], models) ->
         ocr_cells = [[c["bbox"] for c in cells] for cells, needs in zip(table_cells, needs_ocr) if needs]
         ocr_langs = [None] * len(ocr_images)
 
-        ocr_predictions = run_recognition(ocr_images, ocr_langs, ocr_model, ocr_processor, bboxes=ocr_cells)
+        ocr_predictions = run_recognition(ocr_images, ocr_langs, ocr_model, ocr_processor, bboxes=ocr_cells, batch_size=ocr_batch_size)
 
         # Assign text to correct spot
         for orig_idx, ocr_pred in zip(needs_ocr_idx, ocr_predictions):
            for ocr_line, cell in zip(ocr_pred.text_lines, table_cells[orig_idx]):
                cell["text"] = ocr_line.text
 
-    table_preds = batch_table_recognition(table_imgs, table_cells, table_rec_model, table_rec_processor)
+    table_preds = batch_table_recognition(table_imgs, table_cells, table_rec_model, table_rec_processor, batch_size=table_rec_batch_size)
     return table_preds
 
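The recognition-side batch sizes can be overridden the same way. A sketch, not from the commit; the wrapper name and values are illustrative, and the inputs are assumed to come from the detection step above:

```python
from tabled.inference.recognition import get_cells, recognize_tables

def recognize_with_custom_batches(table_imgs, table_bboxes, image_sizes, text_lines,
                                  det_models, rec_models):
    # Hypothetical wrapper; det_models is (det_model, det_processor) and
    # rec_models is (table_rec_model, table_rec_processor, ocr_model, ocr_processor).
    table_cells, needs_ocr = get_cells(
        table_imgs, table_bboxes, image_sizes, text_lines, det_models,
        detector_batch_size=4,  # illustrative override
    )
    return recognize_tables(
        table_imgs, table_cells, needs_ocr, rec_models,
        table_rec_batch_size=8,  # illustrative overrides
        ocr_batch_size=32,
    )
```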
tabled/schema.py

+2
@@ -44,6 +44,8 @@ class ExtractPageResult(BaseModel):
     cells: List[List[SpanTableCell]]
     rows_cols: List[TableResult]
     table_imgs: List[Any]
+    bboxes: List[Bbox]  # Bbox of the table
+    image_bboxes: List[Bbox]  # Bbox of the image/page table is inside
 
     @model_validator(mode="after")
     def check_cells(self):
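Since each table bbox is expressed relative to its `image_bbox`, downstream code can derive page-relative coordinates straight from the model. The helper below is hypothetical, not part of the commit:

```python
from surya.schema import Bbox

def normalized_table_bbox(table_bbox: Bbox, image_bbox: Bbox) -> list:
    # Hypothetical helper: scale a table bbox to 0-1 page coordinates,
    # using the image bbox the table is positioned within.
    x1, y1, x2, y2 = table_bbox.bbox
    width = image_bbox.bbox[2] - image_bbox.bbox[0]
    height = image_bbox.bbox[3] - image_bbox.bbox[1]
    return [x1 / width, y1 / height, x2 / width, y2 / height]

# Usage on an ExtractPageResult `page`, where bboxes[i] pairs with image_bboxes[i]:
# norm = [normalized_table_bbox(b, ib) for b, ib in zip(page.bboxes, page.image_bboxes)]
```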

tabled/settings.py

+11
@@ -9,6 +9,17 @@ class Settings(BaseSettings):
     IN_STREAMLIT: bool = False
     TORCH_DEVICE: Optional[str] = None
 
+    # Batch sizes
+    # See https://github.com/VikParuchuri/surya for default values
+    ## Table recognition
+    TABLE_REC_BATCH_SIZE: Optional[int] = None
+    ## OCR
+    RECOGNITION_BATCH_SIZE: Optional[int] = None
+    ## Text detector
+    DETECTOR_BATCH_SIZE: Optional[int] = None
+    ## Layout
+    LAYOUT_BATCH_SIZE: Optional[int] = None
+
     class Config:
         env_file = find_dotenv("local.env")
         extra = "ignore"
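Because these are pydantic settings backed by an env file, the new batch sizes can be configured without touching code. A sketch, not from the commit; the values are illustrative, and any field left as `None` falls back to surya's own defaults:

```python
import os

# Equivalent lines could go in local.env instead of the process environment.
os.environ["DETECTOR_BATCH_SIZE"] = "4"
os.environ["LAYOUT_BATCH_SIZE"] = "4"
os.environ["TABLE_REC_BATCH_SIZE"] = "8"
os.environ["RECOGNITION_BATCH_SIZE"] = "32"

# Import after setting the environment so pydantic picks up the overrides.
from tabled.settings import settings
print(settings.DETECTOR_BATCH_SIZE)  # 4
```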
