Add benchmark, csv out
VikParuchuri committed Oct 14, 2024
1 parent 00d2e72 commit 97b8341
Showing 18 changed files with 463 additions and 53 deletions.
49 changes: 45 additions & 4 deletions README.md
@@ -1,8 +1,22 @@
# Tabled

-Tabled is a small library for detecting and extracting tables. It uses [surya](https://www.github.com/VikParuchuri/surya) to first find all the tables in a PDF, then identifies the rows/columns, and turns the cells into either markdown or html.
+Tabled is a small library for detecting and extracting tables. It uses [surya](https://www.github.com/VikParuchuri/surya) to find all the tables in a PDF, then identifies the rows/columns and formats the cells as markdown, csv, or html.

-## Examples
+## Example

![Table image 0](static/images/table_example.png)


| Characteristic | | | Population | | | | Change from 2016 to 2060 | |
|--------------------|-------|-------|--------------|-------|-------|-------|------------------------------|---------|
| | 2016 | 2020 | 2030 | 2040 | 2050 | 2060 | Number | Percent |
| Total population | 323.1 | 332.6 | 355.1 | 373.5 | 388.9 | 404.5 | 81.4 | 25.2 |
| Under 18 years | 73.6 | 74.0 | 75.7 | 77.1 | 78.2 | 80.1 | 6.5 | 8.8 |
| 18 to 44 years | 116.0 | 119.2 | 125.0 | 126.4 | 129.6 | 132.7 | 16.7 | 14.4 |
| 45 to 64 years | 84.3 | 83.4 | 81.3 | 89.1 | 95.4 | 97.0 | 12.7 | 15.1 |
| 65 years and over | 49.2 | 56.1 | 73.1 | 80.8 | 85.7 | 94.7 | 45.4 | 92.3 |
| 85 years and over | 6.4 | 6.7 | 9.1 | 14.4 | 18.6 | 19.0 | 12.6 | 198.1 |
| 100 years and over | 0.1 | 0.1 | 0.1 | 0.2 | 0.4 | 0.6 | 0.5 | 618.3 |


## Community
@@ -46,8 +60,11 @@ tabled DATA_PATH
```

- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
-- `--skip_detection` means that the images you pass in are all cropped tables and don't need any detection.
-- `--detect_boxes` by default, tabled will attempt to pull cell information out of the pdf. If you instead want cells to be detected by a detection model, specify this (usually you only need this with pdfs that have bad embedded text).
+- `--format` specifies the output format for each table (`markdown`, `html`, or `csv`)
+- `--save_json` saves additional row and column information in a json file
+- `--save_debug_images` saves images showing the detected rows and columns
+- `--skip_detection` means that the images you pass in are all cropped tables and don't need any table detection.
+- `--detect_cell_boxes` detects cells with a detection model instead of pulling cell information out of the pdf (the default); usually you only need this with pdfs that have bad embedded text.
+- `--save_images` specifies that images of detected rows/columns and cells should be saved.

The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:
@@ -68,4 +85,28 @@ I've included a streamlit app that lets you interactively try tabled on images or PDFs
```shell
pip install streamlit
tabled_gui
```

# Benchmarks

| Avg score | Time per table (s) | Total tables |
|-------------|--------------------|----------------|
| 0.91 | 0.03 | 688 |

## Quality

Getting good ground truth data for tables is hard, since you're either constrained to simple layouts that can be heuristically parsed and rendered, or you need to use LLMs, which make mistakes. I chose to use GPT-4 table predictions as a pseudo-ground-truth.

Tabled gets a `.91` alignment score when compared to GPT-4, where the score measures how well the text in corresponding rows/cells lines up. Some of the misalignments are due to GPT-4 mistakes, or to small inconsistencies in what GPT-4 considered the borders of the table. In general, extraction quality is quite high.

## Performance

Running on an A10G with 10GB of VRAM usage and a batch size of `64`, tabled takes `.03` seconds per table.

## Running your own

Run the benchmark with:

```shell
python benchmarks/benchmark.py out.json
```
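
Tying the new options together, a typical invocation might look like this (a sketch only: the flags are the ones documented above, and the output folder positional argument is assumed from `extract.py`'s `main(in_path, out_folder, ...)` signature further down):

```shell
tabled input.pdf output_dir --format csv --save_json
```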
76 changes: 76 additions & 0 deletions benchmarks/benchmark.py
@@ -0,0 +1,76 @@
import argparse
import json
import time

import datasets
from surya.input.pdflines import get_table_blocks
from tabulate import tabulate
from tqdm import tqdm
from scoring import score_table
from tabled.assignment import assign_rows_columns

from tabled.formats import formatter
from tabled.inference.models import load_recognition_models
from tabled.inference.recognition import recognize_tables


def main():
parser = argparse.ArgumentParser(description="Benchmark table conversion.")
parser.add_argument("out_file", help="Output filename for results")
parser.add_argument("--dataset", type=str, help="Dataset to use", default="vikp/table_bench2")
args = parser.parse_args()

ds = datasets.load_dataset(args.dataset, split="train")

rec_models = load_recognition_models()

results = []
table_imgs = []
table_blocks = []
for i in range(len(ds)):
row = ds[i]
line_data = json.loads(row["text_lines"])
table_bbox = row["table_bbox"]
image_size = row["page_size"]
table_img = row["table_image"]

table_block = get_table_blocks([table_bbox], line_data, image_size)[0]
table_imgs.append(table_img)
table_blocks.append(table_block)

start = time.time()
table_rec = recognize_tables(table_imgs, table_blocks, [False] * len(table_imgs), rec_models)
total_time = time.time() - start
cells = [assign_rows_columns(tr) for tr in table_rec]

for i in range(len(ds)):
row = ds[i]
table_cells = cells[i]
table_bbox = row["table_bbox"]
gpt4_table = json.loads(row["gpt_4_table"])["markdown_table"]

table_markdown, _ = formatter("markdown", table_cells)

results.append({
"score": score_table(table_markdown, gpt4_table),
"arxiv_id": row["arxiv_id"],
"page_idx": row["page_idx"],
"marker_table": table_markdown,
"gpt4_table": gpt4_table,
"table_bbox": table_bbox
})

avg_score = sum([r["score"] for r in results]) / len(results)
headers = ["Avg score", "Time per table", "Total tables"]
data = [f"{avg_score:.3f}", f"{total_time / len(ds):.3f}", len(ds)]

table = tabulate([data], headers=headers, tablefmt="github")
print(table)
print("Avg score computed by aligning table cell text with GPT-4 table cell text.")

with open(args.out_file, "w+") as f:
json.dump(results, f, indent=2)


if __name__ == "__main__":
main()
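
The script's only options are the ones in the argparse block above; for example, to point it at a different dataset with the same schema (hypothetical dataset name):

```shell
python benchmarks/benchmark.py out.json --dataset your_user/your_table_bench
```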
46 changes: 46 additions & 0 deletions benchmarks/scoring.py
@@ -0,0 +1,46 @@
from rapidfuzz import fuzz
import re


def split_to_cells(table):
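    """Normalize a markdown table string into rows of stripped cell strings."""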
table = table.strip()
table = re.sub(r" {2,}", "", table)
table_rows = table.split("\n")
table_rows = [t for t in table_rows if t.strip()]
table_cells = [[c.strip() for c in r.split("|")] for r in table_rows]
return table_cells


def align_rows(hypothesis, ref_row):
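    """Score ref_row against every hypothesis row; return the per-cell scores of the best-matching row."""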
best_alignment = []
best_alignment_score = 0
for j in range(0, len(hypothesis)):
hyp_row = hypothesis[j]
alignments = []
for i in range(len(ref_row)):
if i >= len(hypothesis[j]):
alignments.append(0)
continue
max_cell_align = 0
for k in range(0, len(hyp_row)):
cell_align = fuzz.ratio(hyp_row[k], ref_row[i], score_cutoff=30) / 100
if cell_align > max_cell_align:
max_cell_align = cell_align
alignments.append(max_cell_align)
if len(alignments) == 0:
continue
alignment_score = sum(alignments) / len(alignments)
if alignment_score >= best_alignment_score:
best_alignment = alignments
best_alignment_score = alignment_score
return best_alignment


def score_table(hypothesis, reference):
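    """Average per-cell fuzzy alignment of a hypothesis table against a reference table, from 0 to 1."""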
hypothesis = split_to_cells(hypothesis)
reference = split_to_cells(reference)

alignments = []
for i in range(0, len(reference)):
alignments.extend(align_rows(hypothesis, reference[i]))
return sum(alignments) / max(len(alignments), 1)
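
A quick sanity check of the scorer (a minimal sketch; run from the `benchmarks/` directory with `rapidfuzz` installed; the table strings are made-up examples):

```python
from scoring import score_table

reference = "Name | Value\n--- | ---\na | 1"
hypothesis = "Nane | Value\n--- | ---\na | 1"  # one typo in the header cell

print(score_table(reference, reference))   # identical tables -> 1.0
print(score_table(hypothesis, reference))  # the typo drops the score slightly below 1.0
```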
15 changes: 8 additions & 7 deletions extract.py
@@ -8,6 +8,7 @@
from surya.postprocessing.heatmap import draw_bboxes_on_image

from tabled.extract import extract_tables
+from tabled.formats import formatter
from tabled.formats.markdown import markdown_format
from tabled.inference.detection import detect_tables

@@ -23,8 +24,9 @@
@click.option("--save_json", is_flag=True, help="Save row/column/cell information in json format")
@click.option("--save_debug_images", is_flag=True, help="Save images for debugging")
@click.option("--skip_detection", is_flag=True, help="Skip table detection")
-@click.option("--detect_boxes", is_flag=True, help="Detect table cell boxes vs extract from PDF. Will also run OCR.")
-def main(in_path, out_folder, save_json, save_debug_images, skip_detection, detect_boxes):
+@click.option("--detect_cell_boxes", is_flag=True, help="Detect table cell boxes vs extract from PDF. Will also run OCR.")
+@click.option("--format", type=click.Choice(["markdown", "csv", "html"]), default="markdown")
+def main(in_path, out_folder, save_json, save_debug_images, skip_detection, detect_cell_boxes, format):
os.makedirs(out_folder, exist_ok=True)
images, highres_images, names, text_lines = load_pdfs_images(in_path)
pnums = []
@@ -37,26 +39,25 @@ def main(in_path, out_folder, save_json, save_debug_images, skip_detection, detect_boxes):

prev_name = name


det_models = load_detection_models()
rec_models = load_recognition_models()

-    page_results = extract_tables(images, highres_images, text_lines, det_models, rec_models, skip_detection=skip_detection, detect_boxes=detect_boxes)
+    page_results = extract_tables(images, highres_images, text_lines, det_models, rec_models, skip_detection=skip_detection, detect_boxes=detect_cell_boxes)

out_json = defaultdict(list)
for name, pnum, result in zip(names, pnums, page_results):
for i in range(result.total):
page_cells = result.cells[i]
page_rc = result.rows_cols[i]
-            md = markdown_format(page_cells)
img = result.table_imgs[i]

base_path = os.path.join(out_folder, name)
os.makedirs(base_path, exist_ok=True)

+            formatted_result, ext = formatter(format, page_cells)
base_name = f"page{pnum}_table{i}"
-            with open(os.path.join(base_path, f"{base_name}.md"), "w") as f:
-                f.write(md)
+            with open(os.path.join(base_path, f"{base_name}.{ext}"), "w") as f:
+                f.write(formatted_result)

img.save(os.path.join(base_path, f"{base_name}.png"))
