Overlap detection postprocessing
VikParuchuri committed Oct 22, 2024
1 parent 893f586 commit 4a45145
Showing 3 changed files with 103 additions and 32 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "surya-ocr"
-version = "0.6.6"
+version = "0.6.7"
 description = "OCR, layout, reading order, and table recognition in 90+ languages"
 authors = ["Vik Paruchuri <[email protected]>"]
 readme = "README.md"
58 changes: 49 additions & 9 deletions surya/detection.py
@@ -1,3 +1,5 @@
+import threading
+from queue import Queue
 from typing import List, Tuple, Generator

 import torch
@@ -126,18 +128,56 @@ def batch_text_detection(images: List, model, processor, batch_size=None) -> Lis
     detection_generator = batch_detection(images, model, processor, batch_size=batch_size)

     results = []
+    result_lock = threading.Lock()
     max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
     parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

-    if parallelize:
-        with ProcessPoolExecutor(max_workers=max_workers) as executor:
-            for preds, orig_sizes in detection_generator:
-                batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
-                results.extend(batch_results)
-    else:
-        for preds, orig_sizes in detection_generator:
-            for pred, orig_size in zip(preds, orig_sizes):
-                results.append(parallel_get_lines(pred, orig_size))
+    # Queue to hold batches waiting for postprocessing
+    batch_queue = Queue(maxsize=10)
+
+    def inference_producer():
+        """Produces batches from the model and adds them to the queue"""
+        for batch in detection_generator:
+            batch_queue.put(batch)
+        batch_queue.put(None)  # Signal end of batches
+
+    def postprocessing_consumer():
+        """Consumes batches from the queue and processes them"""
+        if parallelize:
+            with ProcessPoolExecutor(max_workers=max_workers) as executor:
+                while True:
+                    batch = batch_queue.get()
+                    if batch is None:  # Check for end signal
+                        break
+
+                    preds, orig_sizes = batch
+                    batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
+
+                    with result_lock:
+                        results.extend(batch_results)
+        else:
+            while True:
+                batch = batch_queue.get()
+                if batch is None:  # Check for end signal
+                    break
+
+                preds, orig_sizes = batch
+                batch_results = [parallel_get_lines(pred, orig_size)
+                                 for pred, orig_size in zip(preds, orig_sizes)]
+
+                with result_lock:
+                    results.extend(batch_results)
+
+    # Start producer and consumer threads
+    producer = threading.Thread(target=inference_producer)
+    consumer = threading.Thread(target=postprocessing_consumer)
+
+    producer.start()
+    consumer.start()
+
+    # Wait for both threads to complete
+    producer.join()
+    consumer.join()

     return results
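
The new code overlaps GPU inference with CPU postprocessing: a producer thread drains the model generator into a bounded queue while a consumer thread postprocesses each batch as it arrives. Below is a minimal, self-contained sketch of that producer-consumer pattern; fake_inference and fake_postprocess are hypothetical placeholders for surya's batch_detection generator and parallel_get_lines, not part of the library.

import threading
import time
from queue import Queue

def fake_inference(n_batches):
    # Hypothetical stand-in for the GPU inference generator (batch_detection).
    for i in range(n_batches):
        time.sleep(0.05)  # simulate a model forward pass
        yield [f"pred-{i}-{j}" for j in range(4)]

def fake_postprocess(pred):
    # Hypothetical stand-in for CPU-bound postprocessing (parallel_get_lines).
    time.sleep(0.02)
    return pred.upper()

def run_pipeline(n_batches=5):
    results = []
    batch_queue = Queue(maxsize=10)  # bounded, so inference cannot run unboundedly ahead

    def producer():
        for batch in fake_inference(n_batches):
            batch_queue.put(batch)  # blocks when the queue is full (backpressure)
        batch_queue.put(None)       # sentinel: no more batches

    def consumer():
        while True:
            batch = batch_queue.get()
            if batch is None:  # sentinel received, stop consuming
                break
            results.extend(fake_postprocess(p) for p in batch)

    producer_thread = threading.Thread(target=producer)
    consumer_thread = threading.Thread(target=consumer)
    producer_thread.start()
    consumer_thread.start()
    producer_thread.join()
    consumer_thread.join()
    return results

if __name__ == "__main__":
    print(run_pipeline())

The bounded queue provides backpressure: inference can run at most maxsize batches ahead of postprocessing, which caps the memory held in flight while still letting the two stages overlap.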

75 changes: 53 additions & 22 deletions surya/layout.py
@@ -1,5 +1,7 @@
+import threading
 from collections import defaultdict
 from concurrent.futures import ProcessPoolExecutor
+from queue import Queue
 from typing import List, Optional
 from PIL import Image
 import numpy as np
@@ -191,39 +193,68 @@ def batch_layout_detection(images: List, model, processor, detection_results: Op
     id2label = model.config.id2label

     results = []
+    result_lock = threading.Lock()
     max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
     parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

-    if parallelize:
-        with ProcessPoolExecutor(max_workers=max_workers) as executor:
-            img_idx = 0
-            for preds, orig_sizes in layout_generator:
-                futures = []
-                for pred, orig_size in zip(preds, orig_sizes):
-                    future = executor.submit(
-                        parallel_get_regions,
-                        pred,
-                        orig_size,
-                        id2label,
-                        detection_results[img_idx] if detection_results else None
-                    )
-                    futures.append(future)
-                    img_idx += 1
-
-                for future in futures:
-                    results.append(future.result())
-    else:
-        img_idx = 0
-        for preds, orig_sizes in layout_generator:
-            for pred, orig_size in zip(preds, orig_sizes):
-                results.append(parallel_get_regions(
-                    pred,
-                    orig_size,
-                    id2label,
-                    detection_results[img_idx] if detection_results else None
-                ))
-
-                img_idx += 1
+    # Queue to hold batches waiting for postprocessing
+    batch_queue = Queue(maxsize=10)
+
+    def inference_producer():
+        """Produces batches from the model and adds them to the queue"""
+        for batch in layout_generator:
+            batch_queue.put(batch)
+        batch_queue.put(None)  # Signal end of batches
+
+    def postprocessing_consumer():
+        if parallelize:
+            with ProcessPoolExecutor(max_workers=max_workers) as executor:
+                img_idx = 0
+                while True:
+                    batch = batch_queue.get()
+                    if batch is None:
+                        break
+
+                    preds, orig_sizes = batch
+                    img_idxs = [img_idx + i for i in range(len(preds))]
+                    batch_results = list(executor.map(
+                        parallel_get_regions,
+                        preds,
+                        orig_sizes,
+                        [id2label] * len(preds),
+                        [detection_results[idx] for idx in img_idxs] if detection_results else [None] * len(preds)
+                    ))
+
+                    with result_lock:
+                        results.extend(batch_results)
+                    img_idx += len(preds)
+        else:
+            img_idx = 0
+            while True:
+                batch = batch_queue.get()
+                if batch is None:
+                    break
+
+                preds, orig_sizes = batch
+                for pred, orig_size in zip(preds, orig_sizes):
+                    results.append(parallel_get_regions(
+                        pred,
+                        orig_size,
+                        id2label,
+                        detection_results[img_idx] if detection_results else None
+                    ))
+                    img_idx += 1
+
+    # Start producer and consumer threads
+    producer = threading.Thread(target=inference_producer)
+    consumer = threading.Thread(target=postprocessing_consumer)
+
+    producer.start()
+    consumer.start()
+
+    # Wait for both threads to complete
+    producer.join()
+    consumer.join()

     return results
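
The parallel branch above replaces per-image executor.submit calls with a single executor.map over aligned iterables, broadcasting the shared id2label mapping and slicing detection_results by batch offset. Here is a minimal sketch of that calling convention under toy data; get_regions and the sample values are hypothetical stand-ins, not surya's parallel_get_regions API.

from concurrent.futures import ProcessPoolExecutor

def get_regions(pred, orig_size, id2label, detection_result):
    # Hypothetical stand-in: one positional parameter per mapped iterable.
    return {
        "pred": pred,
        "size": orig_size,
        "n_labels": len(id2label),
        "has_detection": detection_result is not None,
    }

if __name__ == "__main__":
    preds = ["p0", "p1", "p2"]
    orig_sizes = [(100, 200), (300, 400), (500, 600)]
    id2label = {0: "Text", 1: "Table"}
    detection_results = None  # optional per-image inputs; None here

    img_idx = 0  # offset of this batch within the full image list
    img_idxs = [img_idx + i for i in range(len(preds))]
    with ProcessPoolExecutor(max_workers=2) as executor:
        batch_results = list(executor.map(
            get_regions,
            preds,
            orig_sizes,
            [id2label] * len(preds),  # broadcast the shared label mapping
            [detection_results[idx] for idx in img_idxs] if detection_results else [None] * len(preds)
        ))
    print(batch_results)

Mapping a whole batch at once also keeps outputs in input order, since executor.map yields results in the order of its arguments, so no per-future bookkeeping is needed.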
