Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overlap detection postprocessing #218

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "surya-ocr"
version = "0.6.6"
version = "0.6.7"
description = "OCR, layout, reading order, and table recognition in 90+ languages"
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
Expand Down
56 changes: 46 additions & 10 deletions surya/detection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import threading
from queue import Queue
from typing import List, Tuple, Generator

import torch
Expand Down Expand Up @@ -126,18 +128,52 @@ def batch_text_detection(images: List, model, processor, batch_size=None) -> Lis
detection_generator = batch_detection(images, model, processor, batch_size=batch_size)

results = []
result_lock = threading.Lock()
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

if parallelize:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
for preds, orig_sizes in detection_generator:
batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
results.extend(batch_results)
else:
for preds, orig_sizes in detection_generator:
for pred, orig_size in zip(preds, orig_sizes):
results.append(parallel_get_lines(pred, orig_size))
batch_queue = Queue(maxsize=4)

def inference_producer():
for batch in detection_generator:
batch_queue.put(batch)
batch_queue.put(None) # Signal end of batches

def postprocessing_consumer():
if parallelize:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
while True:
batch = batch_queue.get()
if batch is None:
break

preds, orig_sizes = batch
batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))

with result_lock:
results.extend(batch_results)
else:
while True:
batch = batch_queue.get()
if batch is None:
break

preds, orig_sizes = batch
batch_results = [parallel_get_lines(pred, orig_size)
for pred, orig_size in zip(preds, orig_sizes)]

with result_lock:
results.extend(batch_results)

# Start producer and consumer threads
producer = threading.Thread(target=inference_producer)
consumer = threading.Thread(target=postprocessing_consumer)

producer.start()
consumer.start()

# Wait for both threads to complete
producer.join()
consumer.join()

return results

Expand Down
74 changes: 51 additions & 23 deletions surya/layout.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import threading
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from queue import Queue
from typing import List, Optional
from PIL import Image
import numpy as np
Expand Down Expand Up @@ -191,39 +193,65 @@ def batch_layout_detection(images: List, model, processor, detection_results: Op
id2label = model.config.id2label

results = []
result_lock = threading.Lock()
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

if parallelize:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
batch_queue = Queue(maxsize=4)

def inference_producer():
for batch in layout_generator:
batch_queue.put(batch)
batch_queue.put(None) # Signal end of batches

def postprocessing_consumer():
if parallelize:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
img_idx = 0
while True:
batch = batch_queue.get()
if batch is None:
break

preds, orig_sizes = batch
img_idxs = [img_idx + i for i in range(len(preds))]
batch_results = list(executor.map(
parallel_get_regions,
preds,
orig_sizes,
[id2label] * len(preds),
[detection_results[idx] for idx in img_idxs] if detection_results else [None] * len(preds)
))

with result_lock:
results.extend(batch_results)
img_idx += len(preds)
else:
img_idx = 0
for preds, orig_sizes in layout_generator:
futures = []
while True:
batch = batch_queue.get()
if batch is None:
break

preds, orig_sizes = batch
for pred, orig_size in zip(preds, orig_sizes):
future = executor.submit(
parallel_get_regions,
results.append(parallel_get_regions(
pred,
orig_size,
id2label,
detection_results[img_idx] if detection_results else None
)
))

futures.append(future)
img_idx += 1

for future in futures:
results.append(future.result())
else:
img_idx = 0
for preds, orig_sizes in layout_generator:
for pred, orig_size in zip(preds, orig_sizes):
results.append(parallel_get_regions(
pred,
orig_size,
id2label,
detection_results[img_idx] if detection_results else None
))

img_idx += 1
# Start producer and consumer threads
producer = threading.Thread(target=inference_producer)
consumer = threading.Thread(target=postprocessing_consumer)

producer.start()
consumer.start()

# Wait for both threads to complete
producer.join()
consumer.join()

return results
Loading