diff --git a/README.md b/README.md
index b87d111..4eaf4a7 100644
--- a/README.md
+++ b/README.md
@@ -67,16 +67,26 @@ tabled DATA_PATH
 - `--detect_cell_boxes` by default, tabled will attempt to pull cell information out of the pdf. If you instead want cells to be detected by a detection model, specify this (usually you only need this with pdfs that have bad embedded text).
 - `--save_images` specifies that images of detected rows/columns and cells should be saved.
 
-The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:
-
-- `text_lines` - the detected text and bounding boxes for each line
-  - `text` - the text in the line
-  - `confidence` - the confidence of the model in the detected text (0-1)
-  - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
-  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
-- `languages` - the languages specified for the page
-- `page` - the page number in the file
-- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
+After running the script, the output directory will contain folders with the same basenames as the input filenames. Inside those folders will be the markdown files for each table in the source documents. Optionally, images of the tables will be saved there as well.
+
+There will also be a `results.json` file in the root of the output directory. The file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per table in the document. Each table dictionary contains:
+
+- `cells` - the detected text and bounding boxes for each table cell
+  - `bbox` - bbox of the cell within the table bbox
+  - `text` - the text of the cell
+  - `row_ids` - ids of the rows the cell belongs to
+  - `col_ids` - ids of the columns the cell belongs to
+  - `order` - order of this cell within its assigned row/column (sort by row, then column, then order to recover reading order)
+- `rows` - bboxes of the detected rows
+  - `bbox` - bbox of the row in (x1, y1, x2, y2) format
+  - `row_id` - unique id of the row
+- `cols` - bboxes of the detected columns
+  - `bbox` - bbox of the column in (x1, y1, x2, y2) format
+  - `col_id` - unique id of the column
+- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. The table bbox is relative to this.
+- `bbox` - the bounding box of the table within the image bbox
+- `pnum` - page number within the document
+- `tnum` - table index on the page
 
 ## Interactive App
 
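The `results.json` layout described above can be consumed directly with the standard library. A minimal sketch, assuming the output directory was named `out` and the input file was `mydoc.pdf` (both names are illustrative):

```python
import json

with open("out/results.json") as f:
    results = json.load(f)

# Keys are input filenames without extensions; each value is a list of tables.
for table in results["mydoc"]:
    print(f"page {table['pnum']}, table {table['tnum']}, bbox {table['bbox']}")
    # Recover reading order: sort by row, then column, then order within the group.
    cells = sorted(
        table["cells"],
        key=lambda c: (min(c["row_ids"]), min(c["col_ids"]), c["order"]),
    )
    for cell in cells:
        print(cell["text"])
```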
diff --git a/extract.py b/extract.py
index 6038aa5..f8693bb 100644
--- a/extract.py
+++ b/extract.py
@@ -1,3 +1,4 @@
+import json
 from collections import defaultdict
 
 import pypdfium2
@@ -56,13 +57,16 @@ def main(in_path, out_folder, save_json, save_debug_images, skip_detection, dete
             img.save(os.path.join(base_path, f"{base_name}.png"))
 
-            if save_json:
-                result = {
-                    "cells": [c.model_dump() for c in page_cells],
-                    "rows": [r.model_dump() for r in page_rc.rows],
-                    "cols": [c.model_dump() for c in page_rc.cols]
-                }
-                out_json[name].append(result)
+            table_result = {
+                "cells": [c.model_dump() for c in page_cells],
+                "rows": [r.model_dump() for r in page_rc.rows],
+                "cols": [c.model_dump() for c in page_rc.cols],
+                "bbox": result.bboxes[i].bbox,
+                "image_bbox": result.image_bboxes[i].bbox,
+                "pnum": pnum,
+                "tnum": i
+            }
+            out_json[name].append(table_result)
 
             if save_debug_images:
                 boxes = [l.bbox for l in page_cells]
@@ -80,6 +84,10 @@ def main(in_path, out_folder, save_json, save_debug_images, skip_detection, dete
                 rc_image = draw_bboxes_on_image(cols, rc_image, labels=col_labels, label_font_size=20, color="red")
                 rc_image.save(os.path.join(base_path, f"{base_name}_rc.png"))
 
+    if save_json:
+        with open(os.path.join(out_folder, "results.json"), "w+", encoding="utf-8") as f:
+            json.dump(out_json, f, ensure_ascii=False)
+
 
 if __name__ == "__main__":
     main()
diff --git a/pyproject.toml b/pyproject.toml
index 9feb0ca..0696b17 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tabled-pdf"
-version = "0.1.1"
+version = "0.1.2"
 description = "Detect and recognize tables in PDFs and images."
 authors = ["Vik Paruchuri "]
 readme = "README.md"
diff --git a/tabled/extract.py b/tabled/extract.py
index 3501e3e..7284f3b 100644
--- a/tabled/extract.py
+++ b/tabled/extract.py
@@ -1,12 +1,22 @@
 from typing import List
 
+from surya.schema import Bbox
+
 from tabled.assignment import assign_rows_columns
 from tabled.inference.detection import detect_tables
 from tabled.inference.recognition import get_cells, recognize_tables
 from tabled.schema import ExtractPageResult
 
 
-def extract_tables(images, highres_images, text_lines, det_models, rec_models, skip_detection=False, detect_boxes=False) -> List[ExtractPageResult]:
+def extract_tables(
+    images,
+    highres_images,
+    text_lines,
+    det_models,
+    rec_models,
+    skip_detection=False,
+    detect_boxes=False
+) -> List[ExtractPageResult]:
     if not skip_detection:
         table_imgs, table_bboxes, table_counts = detect_tables(images, highres_images, det_models)
     else:
@@ -33,7 +43,9 @@ def extract_tables(images, highres_images, text_lines, det_models, rec_models, s
         results.append(ExtractPageResult(
             table_imgs=table_imgs[page_start:page_end],
             cells=cells[page_start:page_end],
-            rows_cols=table_rec[page_start:page_end]
+            rows_cols=table_rec[page_start:page_end],
+            bboxes=[Bbox(bbox=b) for b in table_bboxes[page_start:page_end]],
+            image_bboxes=[Bbox(bbox=[0, 0, size[0], size[1]]) for size in highres_image_sizes[page_start:page_end]]
         ))
 
         counter += count
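The new `bboxes` and `image_bboxes` fields on `ExtractPageResult` let callers relate each table back to the page image it was extracted from. A minimal sketch, assuming `page_result` is one of the `ExtractPageResult` objects returned by `extract_tables` (the variable name is illustrative):

```python
for i, table_img in enumerate(page_result.table_imgs):
    table_bbox = page_result.bboxes[i].bbox        # table position within the page image
    page_bbox = page_result.image_bboxes[i].bbox   # [0, 0, width, height] of the page image
    print(f"table {i}: bbox {table_bbox} on a {page_bbox[2]}x{page_bbox[3]} page")
```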
diff --git a/tabled/inference/detection.py b/tabled/inference/detection.py
index 2b0f0db..c1da422 100644
--- a/tabled/inference/detection.py
+++ b/tabled/inference/detection.py
@@ -3,6 +3,8 @@
 from surya.postprocessing.util import rescale_bbox
 from surya.schema import Bbox
 
+from tabled.settings import settings
+
 
 def merge_boxes(box1, box2):
     return [min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3])]
@@ -30,10 +32,10 @@ def merge_tables(page_table_boxes):
     return [b for i, b in enumerate(page_table_boxes) if i not in ignore_boxes]
 
 
-def detect_tables(images, highres_images, models):
+def detect_tables(images, highres_images, models, detector_batch_size=settings.DETECTOR_BATCH_SIZE, layout_batch_size=settings.LAYOUT_BATCH_SIZE):
     det_model, det_processor, layout_model, layout_processor = models
-    line_predictions = batch_text_detection(images, det_model, det_processor)
-    layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
+    line_predictions = batch_text_detection(images, det_model, det_processor, batch_size=detector_batch_size)
+    layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions, batch_size=layout_batch_size)
 
     table_imgs = []
     table_counts = []
diff --git a/tabled/inference/recognition.py b/tabled/inference/recognition.py
index de32bc7..45859ae 100644
--- a/tabled/inference/recognition.py
+++ b/tabled/inference/recognition.py
@@ -6,8 +6,10 @@
 from surya.schema import TableResult
 from surya.tables import batch_table_recognition
 
+from tabled.settings import settings
 
-def get_cells(table_imgs, table_bboxes, image_sizes, text_lines, models, detect_boxes=False):
+
+def get_cells(table_imgs, table_bboxes, image_sizes, text_lines, models, detect_boxes=False, detector_batch_size=settings.DETECTOR_BATCH_SIZE):
     det_model, det_processor = models
     table_cells = []
     needs_ocr = []
@@ -27,7 +29,7 @@ def get_cells(table_imgs, table_bboxes, image_sizes, text_lines, models, detect_
 
     # Inference tables that need it
     if len(to_inference_idxs) > 0:
-        det_results = batch_text_detection([table_imgs[i] for i in to_inference_idxs], det_model, det_processor)
+        det_results = batch_text_detection([table_imgs[i] for i in to_inference_idxs], det_model, det_processor, batch_size=detector_batch_size)
         for idx, det_result in zip(to_inference_idxs, det_results):
             cell_bboxes = [{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes]
             table_cells[idx] = cell_bboxes
@@ -35,7 +37,7 @@ def get_cells(table_imgs, table_bboxes, image_sizes, text_lines, models, detect_
     return table_cells, needs_ocr
 
 
-def recognize_tables(table_imgs, table_cells, needs_ocr: List[bool], models) -> List[TableResult]:
+def recognize_tables(table_imgs, table_cells, needs_ocr: List[bool], models, table_rec_batch_size=settings.TABLE_REC_BATCH_SIZE, ocr_batch_size=settings.RECOGNITION_BATCH_SIZE) -> List[TableResult]:
     table_rec_model, table_rec_processor, ocr_model, ocr_processor = models
 
     if sum(needs_ocr) > 0:
@@ -44,14 +46,14 @@ def recognize_tables(table_imgs, table_cells, needs_ocr: List[bool], models) ->
         ocr_cells = [[c["bbox"] for c in cells] for cells, needs in zip(table_cells, needs_ocr) if needs]
         ocr_langs = [None] * len(ocr_images)
 
-        ocr_predictions = run_recognition(ocr_images, ocr_langs, ocr_model, ocr_processor, bboxes=ocr_cells)
+        ocr_predictions = run_recognition(ocr_images, ocr_langs, ocr_model, ocr_processor, bboxes=ocr_cells, batch_size=ocr_batch_size)
 
         # Assign text to correct spot
         for orig_idx, ocr_pred in zip(needs_ocr_idx, ocr_predictions):
             for ocr_line, cell in zip(ocr_pred.text_lines, table_cells[orig_idx]):
                 cell["text"] = ocr_line.text
 
-    table_preds = batch_table_recognition(table_imgs, table_cells, table_rec_model, table_rec_processor)
+    table_preds = batch_table_recognition(table_imgs, table_cells, table_rec_model, table_rec_processor, batch_size=table_rec_batch_size)
 
     return table_preds
 
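The batch-size parameters threaded through `detect_tables`, `get_cells`, and `recognize_tables` default to the `tabled.settings` values but can also be overridden per call. A minimal sketch of calling `recognize_tables` with explicit batch sizes, assuming the model tuples have already been loaded and `table_imgs`, `table_cells`, and `needs_ocr` came from `get_cells` (all names are illustrative):

```python
table_preds = recognize_tables(
    table_imgs,
    table_cells,
    needs_ocr,
    (table_rec_model, table_rec_processor, ocr_model, ocr_processor),
    table_rec_batch_size=8,   # smaller batches lower peak VRAM usage
    ocr_batch_size=32,
)
```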
diff --git a/tabled/schema.py b/tabled/schema.py
index 6e6b10a..5f5a349 100644
--- a/tabled/schema.py
+++ b/tabled/schema.py
@@ -44,6 +44,8 @@ class ExtractPageResult(BaseModel):
     cells: List[List[SpanTableCell]]
     rows_cols: List[TableResult]
     table_imgs: List[Any]
+    bboxes: List[Bbox]  # Bbox of the table
+    image_bboxes: List[Bbox]  # Bbox of the image/page table is inside
 
     @model_validator(mode="after")
     def check_cells(self):
diff --git a/tabled/settings.py b/tabled/settings.py
index b7d5671..338b450 100644
--- a/tabled/settings.py
+++ b/tabled/settings.py
@@ -9,6 +9,17 @@ class Settings(BaseSettings):
     IN_STREAMLIT: bool = False
     TORCH_DEVICE: Optional[str] = None
 
+    # Batch sizes
+    # See https://github.com/VikParuchuri/surya for default values
+    ## Table recognition
+    TABLE_REC_BATCH_SIZE: Optional[int] = None
+    ## OCR
+    RECOGNITION_BATCH_SIZE: Optional[int] = None
+    ## Text detector
+    DETECTOR_BATCH_SIZE: Optional[int] = None
+    ## Layout
+    LAYOUT_BATCH_SIZE: Optional[int] = None
+
     class Config:
         env_file = find_dotenv("local.env")
         extra = "ignore"
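Since `Settings` is a pydantic `BaseSettings` class that reads `local.env`, the new batch sizes can also be overridden through environment variables rather than code changes. A minimal sketch, assuming the variables are set before `tabled.settings` is first imported (the values are illustrative):

```python
import os

# Override batch sizes via environment variables; pydantic BaseSettings picks
# them up when tabled.settings instantiates its settings object on import.
os.environ["DETECTOR_BATCH_SIZE"] = "4"
os.environ["TABLE_REC_BATCH_SIZE"] = "8"
os.environ["RECOGNITION_BATCH_SIZE"] = "32"

from tabled.settings import settings

print(settings.DETECTOR_BATCH_SIZE)  # 4
```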