diff --git a/README.md b/README.md index 8690f83..b87d111 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ Tabled is a small library for detecting and extracting tables. It uses [surya]( ## Community -[Discord](https://discord.gg//KuZwXNGnfH) is where we discuss future development.` +[Discord](https://discord.gg//KuZwXNGnfH) is where we discuss future development. # Hosted API @@ -87,26 +87,45 @@ pip install streamlit tabled_gui ``` +## From python + +```python +from tabled.extract import extract_tables +from tabled.fileinput import load_pdfs_images +from tabled.inference.models import load_detection_models, load_recognition_models + +det_models, rec_models = load_detection_models(), load_recognition_models() +images, highres_images, names, text_lines = load_pdfs_images(IN_PATH) + +page_results = extract_tables(images, highres_images, text_lines, det_models, rec_models) +``` + # Benchmarks -| Avg score | Time per table (s) | Total tables | -|-------------|--------------------|----------------| -| 0.91 | 0.03 | 688 | +| Avg score | Time per table | Total tables | +|-------------|------------------|----------------| +| 0.847 | 0.029 | 688 | ## Quality Getting good ground truth data for tables is hard, since you're either constrained to simple layouts that can be heuristically parsed and rendered, or you need to use LLMs, which make mistakes. I chose to use GPT-4 table predictions as a pseudo-ground-truth. -Tabled gets a `.91` alignment score when compared to GPT-4, which indicates alignment between the text in table rows/cells. Some of the misalignments are due to GPT-4 mistakes, or small inconsistencies in what GPT-4 considered the borders of the table. In general, extraction quality is quite high. +Tabled gets a `.847` alignment score when compared to GPT-4, which indicates alignment between the text in table rows/cells. Some of the misalignments are due to GPT-4 mistakes, or small inconsistencies in what GPT-4 considered the borders of the table. 
In general, extraction quality is quite high. ## Performance -Running on an A10G with 10GB of VRAM usage and batch size `64`, tabled takes `.03` seconds per table. +Running on an A10G with 10GB of VRAM usage and batch size `64`, tabled takes `.029` seconds per table. -## Running your own +## Running the benchmark Run the benchmark with: ```shell python benchmarks/benchmark.py out.json -``` \ No newline at end of file +``` + +# Acknowledgements + +- Thank you to [Peter Jansen](https://cognitiveai.org/) for the benchmarking dataset, and for discussion about table parsing. +- Huggingface for inference code and model hosting +- PyTorch for training/inference \ No newline at end of file diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index a8e5b1c..3254f27 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -27,6 +27,7 @@ def main(): results = [] table_imgs = [] table_blocks = [] + image_sizes = [] for i in range(len(ds)): row = ds[i] line_data = json.loads(row["text_lines"]) @@ -37,11 +38,12 @@ def main(): table_block = get_table_blocks([table_bbox], line_data, image_size)[0] table_imgs.append(table_img) table_blocks.append(table_block) + image_sizes.append(image_size) start = time.time() table_rec = recognize_tables(table_imgs, table_blocks, [False] * len(table_imgs), rec_models) total_time = time.time() - start - cells = [assign_rows_columns(tr) for tr in table_rec] + cells = [assign_rows_columns(tr, im_size) for tr, im_size in zip(table_rec, image_sizes)] for i in range(len(ds)): row = ds[i] diff --git a/benchmarks/scoring.py b/benchmarks/scoring.py index 2279ad6..d148062 100644 --- a/benchmarks/scoring.py +++ b/benchmarks/scoring.py @@ -15,18 +15,13 @@ def align_rows(hypothesis, ref_row): best_alignment = [] best_alignment_score = 0 for j in range(0, len(hypothesis)): - hyp_row = hypothesis[j] alignments = [] for i in range(len(ref_row)): if i >= len(hypothesis[j]): alignments.append(0) continue - max_cell_align = 0 - for k in 
range(0, len(hyp_row)): - cell_align = fuzz.ratio(hyp_row[k], ref_row[i], score_cutoff=30) / 100 - if cell_align > max_cell_align: - max_cell_align = cell_align - alignments.append(max_cell_align) + alignment = fuzz.ratio(hypothesis[j][i], ref_row[i], score_cutoff=30) / 100 + alignments.append(alignment) if len(alignments) == 0: continue alignment_score = sum(alignments) / len(alignments) diff --git a/extract.py b/extract.py index 4226dba..f49c97b 100644 --- a/extract.py +++ b/extract.py @@ -9,13 +9,8 @@ from tabled.extract import extract_tables from tabled.formats import formatter -from tabled.formats.markdown import markdown_format -from tabled.inference.detection import detect_tables - -from tabled.assignment import assign_rows_columns from tabled.fileinput import load_pdfs_images from tabled.inference.models import load_detection_models, load_recognition_models -from tabled.inference.recognition import get_cells, recognize_tables @click.command() diff --git a/poetry.lock b/poetry.lock index d5276f5..37117d6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1316,6 +1316,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "joblib" +version = "1.4.2" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.8" +files = [ + {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"}, + {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, +] + [[package]] name = "json5" version = "0.9.25" @@ -3920,6 +3931,101 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"] torch = ["safetensors[numpy]", "torch (>=1.10)"] +[[package]] +name = "scikit-learn" +version = "1.5.2" 
+description = "A set of python modules for machine learning and data mining" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:299406827fb9a4f862626d0fe6c122f5f87f8910b86fe5daa4c32dcd742139b6"}, + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2d4cad1119c77930b235579ad0dc25e65c917e756fe80cab96aa3b9428bd3fb0"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c412ccc2ad9bf3755915e3908e677b367ebc8d010acbb3f182814524f2e5540"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a686885a4b3818d9e62904d91b57fa757fc2bed3e465c8b177be652f4dd37c8"}, + {file = "scikit_learn-1.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:c15b1ca23d7c5f33cc2cb0a0d6aaacf893792271cddff0edbd6a40e8319bc113"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03b6158efa3faaf1feea3faa884c840ebd61b6484167c711548fce208ea09445"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1ff45e26928d3b4eb767a8f14a9a6efbf1cbff7c05d1fb0f95f211a89fd4f5de"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f763897fe92d0e903aa4847b0aec0e68cadfff77e8a0687cabd946c89d17e675"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8b0ccd4a902836493e026c03256e8b206656f91fbcc4fde28c57a5b752561f1"}, + {file = "scikit_learn-1.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6c16d84a0d45e4894832b3c4d0bf73050939e21b99b01b6fd59cbb0cf39163b6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = 
"sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca64b3089a6d9b9363cd3546f8978229dcbb737aceb2c12144ee3f70f95684b7"}, + {file = "scikit_learn-1.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:3bed4909ba187aca80580fe2ef370d9180dcf18e621a27c4cf2ef10d279a7efe"}, + {file = "scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d"}, +] + +[package.dependencies] +joblib = ">=1.2.0" +numpy = ">=1.19.5" +scipy = ">=1.6.0" +threadpoolctl = ">=3.1.0" + +[package.extras] +benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] +build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme 
(>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] +examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] +install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] +maintenance = ["conda-lock (==2.5.6)"] +tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] + +[[package]] +name = "scipy" +version = "1.14.1" +description = "Fundamental algorithms for scientific computing in Python" +optional = false +python-versions = ">=3.10" +files = [ + {file = "scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69"}, + {file 
= "scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad"}, + {file = "scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2"}, + {file = "scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2"}, + {file = "scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = 
"sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066"}, + {file = "scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1"}, + {file = "scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e"}, + {file = "scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06"}, + {file = "scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84"}, + {file = "scipy-1.14.1.tar.gz", hash = 
"sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417"}, +] + +[package.dependencies] +numpy = ">=1.23.5,<2.3" + +[package.extras] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"] +test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] + [[package]] name = "send2trash" version = "1.8.3" @@ -4146,6 +4252,17 @@ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"] typing = ["mypy (>=1.6,<2.0)", "traitlets (>=5.11.1)"] +[[package]] +name = "threadpoolctl" +version = "3.5.0" +description = "threadpoolctl" +optional = false +python-versions = ">=3.8" +files = [ + {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"}, + {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, +] + [[package]] name = "tinycss2" version = "1.3.0" @@ -4914,4 +5031,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "e48dcc3c07b66fc1754fbfb1e3eb5fa432c5f64cd54483f01b0f55de54bf591c" +content-hash = "5c2c3b78959ae5bfda93349d9a0f2f659541d6021e268a55103359579806200e" diff --git a/pyproject.toml b/pyproject.toml index ce5e6d1..5ac355b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ pydantic-settings = "^2.5.2" pydantic = "^2.9.2" python-dotenv = "^1.0.1" tabulate = "^0.9.0" +scikit-learn = "^1.5.2" 
[tool.poetry.group.dev.dependencies] jupyter = "^1.1.1" diff --git a/table_app.py b/table_app.py index 0f62064..01313fe 100644 --- a/table_app.py +++ b/table_app.py @@ -35,7 +35,7 @@ def run_table_rec(image, highres_image, text_line, models, skip_detection=False, cells, needs_ocr = get_cells(table_imgs, table_bboxes, highres_image_sizes, table_text_lines, models[0][:2], detect_boxes=detect_boxes) table_rec = recognize_tables(table_imgs, cells, needs_ocr, models[1]) - cells = [assign_rows_columns(tr) for tr in table_rec] + cells = [assign_rows_columns(tr, im_size) for tr, im_size in zip(table_rec, highres_image_sizes)] out_data = [] for idx, (cell, pred, table_img) in enumerate(zip(cells, table_rec, table_imgs)): diff --git a/tabled/assignment.py b/tabled/assignment.py index d2cd677..d5c9213 100644 --- a/tabled/assignment.py +++ b/tabled/assignment.py @@ -3,6 +3,7 @@ import numpy as np from surya.schema import TableResult, Bbox +from tabled.heuristics import heuristic_layout from tabled.schema import SpanTableCell @@ -227,11 +228,18 @@ def find_row_gap(r1, r2): detection_result.rows = new_rows -def assign_rows_columns(detection_result: TableResult) -> List[SpanTableCell]: +def assign_rows_columns(detection_result: TableResult, image_size: list, heuristic_thresh=.6) -> List[SpanTableCell]: table_cells = initial_assignment(detection_result) merge_multiline_rows(detection_result, table_cells) table_cells = initial_assignment(detection_result) assign_overlappers(table_cells, detection_result) + total_unassigned = len([tc for tc in table_cells if tc.row_ids[0] is None or tc.col_ids[0] is None]) + unassigned_frac = total_unassigned / max(len(table_cells), 1) + + if unassigned_frac > heuristic_thresh: + table_cells = heuristic_layout(table_cells, image_size) + return table_cells + assign_unassigned(table_cells, detection_result) handle_rowcol_spans(table_cells, detection_result) return table_cells diff --git a/tabled/extract.py b/tabled/extract.py index cde5146..3501e3e 
100644 --- a/tabled/extract.py +++ b/tabled/extract.py @@ -23,7 +23,7 @@ def extract_tables(images, highres_images, text_lines, det_models, rec_models, s cells, needs_ocr = get_cells(table_imgs, table_bboxes, highres_image_sizes, table_text_lines, det_models[:2], detect_boxes=detect_boxes) table_rec = recognize_tables(table_imgs, cells, needs_ocr, rec_models) - cells = [assign_rows_columns(tr) for tr in table_rec] + cells = [assign_rows_columns(tr, im_size) for tr, im_size in zip(table_rec, highres_image_sizes)] results = [] counter = 0 diff --git a/tabled/heuristics/__init__.py b/tabled/heuristics/__init__.py new file mode 100644 index 0000000..8d1a2f3 --- /dev/null +++ b/tabled/heuristics/__init__.py @@ -0,0 +1,34 @@ +from typing import List + +from tabled.heuristics.cells import assign_cells_to_columns +from tabled.schema import SpanTableCell + + +def heuristic_layout(table_cells: List[SpanTableCell], page_size, row_tol=.01) -> List[SpanTableCell]: + table_rows = [] + table_row = [] + y_top = None + y_bottom = None + for cell in table_cells: + normed_y_start = cell.bbox[1] / page_size[1] + normed_y_end = cell.bbox[3] / page_size[1] + + if y_top is None: + y_top = normed_y_start + if y_bottom is None: + y_bottom = normed_y_end + + y_dist = min(abs(normed_y_start - y_bottom), abs(normed_y_end - y_bottom)) + if y_dist < row_tol: + table_row.append(cell) + else: + # New row + if len(table_row) > 0: + table_rows.append(table_row) + table_row = [cell] + y_top = normed_y_start + y_bottom = normed_y_end + if len(table_row) > 0: + table_rows.append(table_row) + + return assign_cells_to_columns(table_rows, page_size) \ No newline at end of file diff --git a/tabled/heuristics/cells.py b/tabled/heuristics/cells.py new file mode 100644 index 0000000..4569c28 --- /dev/null +++ b/tabled/heuristics/cells.py @@ -0,0 +1,83 @@ +import numpy as np +from sklearn.cluster import DBSCAN + + +def cluster_coords(coords, row_count): + if len(coords) == 0: + return [] + coords = 
np.array(sorted(set(coords))).reshape(-1, 1) + + clustering = DBSCAN(eps=.01, min_samples=max(2, row_count // 4)).fit(coords) + clusters = clustering.labels_ + + separators = [] + for label in set(clusters): + clustered_points = coords[clusters == label] + separators.append(np.mean(clustered_points)) + + separators = sorted(separators) + return separators + + +def find_column_separators(rows, page_size, round_factor=.002, min_count=1): + left_edges = [] + right_edges = [] + centers = [] + + boxes = [c.bbox for r in rows for c in r] + + for cell in boxes: + ncell = [cell[0] / page_size[0], cell[1] / page_size[1], cell[2] / page_size[0], cell[3] / page_size[1]] + left_edges.append(ncell[0] / round_factor * round_factor) + right_edges.append(ncell[2] / round_factor * round_factor) + centers.append((ncell[0] + ncell[2]) / 2 * round_factor / round_factor) + + left_edges = [l for l in left_edges if left_edges.count(l) > min_count] + right_edges = [r for r in right_edges if right_edges.count(r) > min_count] + centers = [c for c in centers if centers.count(c) > min_count] + + sorted_left = cluster_coords(left_edges, len(rows)) + sorted_right = cluster_coords(right_edges, len(rows)) + sorted_center = cluster_coords(centers, len(rows)) + + # Find list with maximum length + separators = max([sorted_left, sorted_right, sorted_center], key=len) + separators.append(1) + separators.insert(0, 0) + return separators + + +def assign_cells_to_columns(rows, page_size, round_factor=.002, tolerance=.01): + separators = find_column_separators(rows, page_size, round_factor=round_factor) + additional_column_index = 0 + row_dicts = [] + + for row in rows: + new_row = {} + last_col_index = -1 + for cell in row: + left_edge = cell.bbox[0] / page_size[0] + column_index = -1 + for i, separator in enumerate(separators): + if left_edge - tolerance < separator and last_col_index < i: + column_index = i + break + if column_index == -1: + column_index = len(separators) + additional_column_index + 
additional_column_index += 1 + new_row[column_index] = cell + last_col_index = column_index + additional_column_index = 0 + row_dicts.append(new_row) + + cells = [] + for row_idx, row in enumerate(row_dicts): + column = 0 + for col_idx in sorted(row.keys()): + cell = row[col_idx] + cell.row_ids = [row_idx] + cell.col_ids = [column] + cells.append(cell) + column += 1 + + return cells \ No newline at end of file