Skip to content

Commit

Permalink
Overlap detection postprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Oct 22, 2024
1 parent 893f586 commit 46d47ac
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 34 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "surya-ocr"
version = "0.6.6"
version = "0.6.7"
description = "OCR, layout, reading order, and table recognition in 90+ languages"
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
Expand Down
56 changes: 46 additions & 10 deletions surya/detection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import threading
from queue import Queue
from typing import List, Tuple, Generator

import torch
Expand Down Expand Up @@ -126,18 +128,52 @@ def batch_text_detection(images: List, model, processor, batch_size=None) -> Lis
detection_generator = batch_detection(images, model, processor, batch_size=batch_size)

results = []
result_lock = threading.Lock()
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

if parallelize:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
for preds, orig_sizes in detection_generator:
batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
results.extend(batch_results)
else:
for preds, orig_sizes in detection_generator:
for pred, orig_size in zip(preds, orig_sizes):
results.append(parallel_get_lines(pred, orig_size))
batch_queue = Queue(maxsize=4)

def inference_producer():
    """Feed inference batches from ``detection_generator`` into ``batch_queue``.

    Each batch yielded by the generator is put on the bounded queue; ``put``
    blocks while the queue is full, which throttles inference to the pace of
    postprocessing. A ``None`` sentinel is enqueued last to tell the consumer
    that no more batches will arrive.

    The sentinel is sent from a ``finally`` clause so that it is enqueued even
    when the generator raises part-way through. Without it the consumer would
    block forever on ``batch_queue.get()`` and the caller's ``join`` would
    never return. The exception itself still propagates out of this thread.
    """
    try:
        for batch in detection_generator:
            batch_queue.put(batch)
    finally:
        batch_queue.put(None)  # Signal end of batches

def postprocessing_consumer():
    """Postprocess queued inference batches and append them to ``results``.

    Batches are taken from ``batch_queue`` until the ``None`` sentinel is
    received. Every batch is a ``(preds, orig_sizes)`` pair; each prediction
    is passed to ``parallel_get_lines`` together with its original size,
    either through a ``ProcessPoolExecutor`` (when ``parallelize`` is true)
    or inline in this thread. Results are appended in batch order.

    If postprocessing raises, the queue is still drained up to the sentinel
    before the exception propagates. Otherwise the producer would block on
    ``put`` as soon as the bounded queue filled up, and the caller's
    ``producer.join()`` would deadlock.
    """
    sentinel_seen = False

    def process_stream(postprocess):
        # Shared receive loop for both execution modes; only the way a single
        # batch is postprocessed differs between them.
        nonlocal sentinel_seen
        while True:
            batch = batch_queue.get()
            if batch is None:
                sentinel_seen = True
                return

            preds, orig_sizes = batch
            batch_results = postprocess(preds, orig_sizes)

            with result_lock:
                results.extend(batch_results)

    def postprocess_inline(preds, orig_sizes):
        return [parallel_get_lines(pred, orig_size)
                for pred, orig_size in zip(preds, orig_sizes)]

    try:
        if parallelize:
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                def postprocess_in_pool(preds, orig_sizes):
                    return list(executor.map(parallel_get_lines, preds, orig_sizes))

                process_stream(postprocess_in_pool)
        else:
            process_stream(postprocess_inline)
    finally:
        # On failure, keep discarding batches until the sentinel arrives so
        # the producer is never left blocked on a full queue. This is a no-op
        # on the success path because the sentinel has already been consumed.
        while not sentinel_seen:
            sentinel_seen = batch_queue.get() is None

# Start producer and consumer threads
producer = threading.Thread(target=inference_producer)
consumer = threading.Thread(target=postprocessing_consumer)

producer.start()
consumer.start()

# Wait for both threads to complete
producer.join()
consumer.join()

return results

Expand Down
74 changes: 51 additions & 23 deletions surya/layout.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import threading
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from queue import Queue
from typing import List, Optional
from PIL import Image
import numpy as np
Expand Down Expand Up @@ -191,39 +193,65 @@ def batch_layout_detection(images: List, model, processor, detection_results: Op
id2label = model.config.id2label

results = []
result_lock = threading.Lock()
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

if parallelize:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
batch_queue = Queue(maxsize=4)

def inference_producer():
    """Feed inference batches from ``layout_generator`` into ``batch_queue``.

    Each batch yielded by the generator is put on the bounded queue; ``put``
    blocks while the queue is full, which throttles inference to the pace of
    postprocessing. A ``None`` sentinel is enqueued last to tell the consumer
    that no more batches will arrive.

    The sentinel is sent from a ``finally`` clause so that it is enqueued even
    when the generator raises part-way through. Without it the consumer would
    block forever on ``batch_queue.get()`` and the caller's ``join`` would
    never return. The exception itself still propagates out of this thread.
    """
    try:
        for batch in layout_generator:
            batch_queue.put(batch)
    finally:
        batch_queue.put(None)  # Signal end of batches

def postprocessing_consumer():
if parallelize:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
img_idx = 0
while True:
batch = batch_queue.get()
if batch is None:
break

preds, orig_sizes = batch
img_idxs = [img_idx + i for i in range(len(preds))]
batch_results = list(executor.map(
parallel_get_regions,
preds,
orig_sizes,
[id2label] * len(preds),
[detection_results[idx] for idx in img_idxs] if detection_results else [None] * len(preds)
))

with result_lock:
results.extend(batch_results)
img_idx += len(preds)
else:
img_idx = 0
for preds, orig_sizes in layout_generator:
futures = []
while True:
batch = batch_queue.get()
if batch is None:
break

preds, orig_sizes = batch
for pred, orig_size in zip(preds, orig_sizes):
future = executor.submit(
parallel_get_regions,
results.append(parallel_get_regions(
pred,
orig_size,
id2label,
detection_results[img_idx] if detection_results else None
)
))

futures.append(future)
img_idx += 1

for future in futures:
results.append(future.result())
else:
img_idx = 0
for preds, orig_sizes in layout_generator:
for pred, orig_size in zip(preds, orig_sizes):
results.append(parallel_get_regions(
pred,
orig_size,
id2label,
detection_results[img_idx] if detection_results else None
))

img_idx += 1
# Start producer and consumer threads
producer = threading.Thread(target=inference_producer)
consumer = threading.Thread(target=postprocessing_consumer)

producer.start()
consumer.start()

# Wait for both threads to complete
producer.join()
consumer.join()

return results

0 comments on commit 46d47ac

Please sign in to comment.