-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
00d2e72
commit 97b8341
Showing
18 changed files
with
463 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import argparse | ||
import json | ||
import time | ||
|
||
import datasets | ||
from surya.input.pdflines import get_table_blocks | ||
from tabulate import tabulate | ||
from tqdm import tqdm | ||
from scoring import score_table | ||
from tabled.assignment import assign_rows_columns | ||
|
||
from tabled.formats import formatter | ||
from tabled.inference.models import load_recognition_models | ||
from tabled.inference.recognition import recognize_tables | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description="Benchmark table conversion.") | ||
parser.add_argument("out_file", help="Output filename for results") | ||
parser.add_argument("--dataset", type=str, help="Dataset to use", default="vikp/table_bench2") | ||
args = parser.parse_args() | ||
|
||
ds = datasets.load_dataset(args.dataset, split="train") | ||
|
||
rec_models = load_recognition_models() | ||
|
||
results = [] | ||
table_imgs = [] | ||
table_blocks = [] | ||
for i in range(len(ds)): | ||
row = ds[i] | ||
line_data = json.loads(row["text_lines"]) | ||
table_bbox = row["table_bbox"] | ||
image_size = row["page_size"] | ||
table_img = row["table_image"] | ||
|
||
table_block = get_table_blocks([table_bbox], line_data, image_size)[0] | ||
table_imgs.append(table_img) | ||
table_blocks.append(table_block) | ||
|
||
start = time.time() | ||
table_rec = recognize_tables(table_imgs, table_blocks, [False] * len(table_imgs), rec_models) | ||
total_time = time.time() - start | ||
cells = [assign_rows_columns(tr) for tr in table_rec] | ||
|
||
for i in range(len(ds)): | ||
row = ds[i] | ||
table_cells = cells[i] | ||
table_bbox = row["table_bbox"] | ||
gpt4_table = json.loads(row["gpt_4_table"])["markdown_table"] | ||
|
||
table_markdown, _ = formatter("markdown", table_cells) | ||
|
||
results.append({ | ||
"score": score_table(table_markdown, gpt4_table), | ||
"arxiv_id": row["arxiv_id"], | ||
"page_idx": row["page_idx"], | ||
"marker_table": table_markdown, | ||
"gpt4_table": gpt4_table, | ||
"table_bbox": table_bbox | ||
}) | ||
|
||
avg_score = sum([r["score"] for r in results]) / len(results) | ||
headers = ["Avg score", "Time per table", "Total tables"] | ||
data = [f"{avg_score:.3f}", f"{total_time / len(ds):.3f}", len(ds)] | ||
|
||
table = tabulate([data], headers=headers, tablefmt="github") | ||
print(table) | ||
print("Avg score computed by aligning table cell text with GPT-4 table cell text.") | ||
|
||
with open(args.out_file, "w+") as f: | ||
json.dump(results, f, indent=2) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from rapidfuzz import fuzz | ||
import re | ||
|
||
|
||
def split_to_cells(table): | ||
table = table.strip() | ||
table = re.sub(r" {2,}", "", table) | ||
table_rows = table.split("\n") | ||
table_rows = [t for t in table_rows if t.strip()] | ||
table_cells = [[c.strip() for c in r.split("|")] for r in table_rows] | ||
return table_cells | ||
|
||
|
||
def align_rows(hypothesis, ref_row): | ||
best_alignment = [] | ||
best_alignment_score = 0 | ||
for j in range(0, len(hypothesis)): | ||
hyp_row = hypothesis[j] | ||
alignments = [] | ||
for i in range(len(ref_row)): | ||
if i >= len(hypothesis[j]): | ||
alignments.append(0) | ||
continue | ||
max_cell_align = 0 | ||
for k in range(0, len(hyp_row)): | ||
cell_align = fuzz.ratio(hyp_row[k], ref_row[i], score_cutoff=30) / 100 | ||
if cell_align > max_cell_align: | ||
max_cell_align = cell_align | ||
alignments.append(max_cell_align) | ||
if len(alignments) == 0: | ||
continue | ||
alignment_score = sum(alignments) / len(alignments) | ||
if alignment_score >= best_alignment_score: | ||
best_alignment = alignments | ||
best_alignment_score = alignment_score | ||
return best_alignment | ||
|
||
|
||
def score_table(hypothesis, reference): | ||
hypothesis = split_to_cells(hypothesis) | ||
reference = split_to_cells(reference) | ||
|
||
alignments = [] | ||
for i in range(0, len(reference)): | ||
alignments.extend(align_rows(hypothesis, reference[i])) | ||
return sum(alignments) / max(len(alignments), 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.