diff --git a/pyproject.toml b/pyproject.toml
index b8e0e64..ad3810c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "surya-ocr"
-version = "0.6.6"
+version = "0.6.7"
 description = "OCR, layout, reading order, and table recognition in 90+ languages"
 authors = ["Vik Paruchuri <vik.paruchuri@gmail.com>"]
 readme = "README.md"
diff --git a/surya/detection.py b/surya/detection.py
index a5b4a28..1ed7116 100644
--- a/surya/detection.py
+++ b/surya/detection.py
@@ -1,3 +1,5 @@
+import threading
+from queue import Queue
 from typing import List, Tuple, Generator
 
 import torch
@@ -126,18 +128,52 @@ def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
     detection_generator = batch_detection(images, model, processor, batch_size=batch_size)
 
     results = []
+    result_lock = threading.Lock()
     max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
     parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
-
-    if parallelize:
-        with ProcessPoolExecutor(max_workers=max_workers) as executor:
-            for preds, orig_sizes in detection_generator:
-                batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
-                results.extend(batch_results)
-    else:
-        for preds, orig_sizes in detection_generator:
-            for pred, orig_size in zip(preds, orig_sizes):
-                results.append(parallel_get_lines(pred, orig_size))
+    batch_queue = Queue(maxsize=4)
+
+    def inference_producer():
+        for batch in detection_generator:
+            batch_queue.put(batch)
+        batch_queue.put(None)  # Signal end of batches
+
+    def postprocessing_consumer():
+        if parallelize:
+            with ProcessPoolExecutor(max_workers=max_workers) as executor:
+                while True:
+                    batch = batch_queue.get()
+                    if batch is None:
+                        break
+
+                    preds, orig_sizes = batch
+                    batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
+
+                    with result_lock:
+                        results.extend(batch_results)
+        else:
+            while True:
+                batch = batch_queue.get()
+                if batch is None:
+                    break
+
+                preds, orig_sizes = batch
+                batch_results = [parallel_get_lines(pred, orig_size)
+                                 for pred, orig_size in zip(preds, orig_sizes)]
+
+                with result_lock:
+                    results.extend(batch_results)
+
+    # Start producer and consumer threads
+    producer = threading.Thread(target=inference_producer)
+    consumer = threading.Thread(target=postprocessing_consumer)
+
+    producer.start()
+    consumer.start()
+
+    # Wait for both threads to complete
+    producer.join()
+    consumer.join()
 
     return results
diff --git a/surya/layout.py b/surya/layout.py
index 9f7dd4d..1ee4d13 100644
--- a/surya/layout.py
+++ b/surya/layout.py
@@ -1,5 +1,7 @@
+import threading
 from collections import defaultdict
 from concurrent.futures import ProcessPoolExecutor
+from queue import Queue
 from typing import List, Optional
 from PIL import Image
 import numpy as np
@@ -191,39 +193,65 @@ def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]:
     id2label = model.config.id2label
 
     results = []
+    result_lock = threading.Lock()
     max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
     parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
-
-    if parallelize:
-        with ProcessPoolExecutor(max_workers=max_workers) as executor:
+    batch_queue = Queue(maxsize=4)
+
+    def inference_producer():
+        for batch in layout_generator:
+            batch_queue.put(batch)
+        batch_queue.put(None)  # Signal end of batches
+
+    def postprocessing_consumer():
+        if parallelize:
+            with ProcessPoolExecutor(max_workers=max_workers) as executor:
+                img_idx = 0
+                while True:
+                    batch = batch_queue.get()
+                    if batch is None:
+                        break
+
+                    preds, orig_sizes = batch
+                    img_idxs = [img_idx + i for i in range(len(preds))]
+                    batch_results = list(executor.map(
+                        parallel_get_regions,
+                        preds,
+                        orig_sizes,
+                        [id2label] * len(preds),
+                        [detection_results[idx] for idx in img_idxs] if detection_results else [None] * len(preds)
+                    ))
+
+                    with result_lock:
+                        results.extend(batch_results)
+                    img_idx += len(preds)
+        else:
             img_idx = 0
-            for preds, orig_sizes in layout_generator:
-                futures = []
+            while True:
+                batch = batch_queue.get()
+                if batch is None:
+                    break
+
+                preds, orig_sizes = batch
                 for pred, orig_size in zip(preds, orig_sizes):
-                    future = executor.submit(
-                        parallel_get_regions,
+                    results.append(parallel_get_regions(
                         pred,
                         orig_size,
                         id2label,
                         detection_results[img_idx] if detection_results else None
-                    )
+                    ))
-
-                    futures.append(future)
                     img_idx += 1
 
-                for future in futures:
-                    results.append(future.result())
-    else:
-        img_idx = 0
-        for preds, orig_sizes in layout_generator:
-            for pred, orig_size in zip(preds, orig_sizes):
-                results.append(parallel_get_regions(
-                    pred,
-                    orig_size,
-                    id2label,
-                    detection_results[img_idx] if detection_results else None
-                ))
-
-                img_idx += 1
+
+    # Start producer and consumer threads
+    producer = threading.Thread(target=inference_producer)
+    consumer = threading.Thread(target=postprocessing_consumer)
+
+    producer.start()
+    consumer.start()
+
+    # Wait for both threads to complete
+    producer.join()
+    consumer.join()
 
     return results
\ No newline at end of file
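
For reviewers new to the pattern: both files replace the strictly sequential infer-batch-then-postprocess loop with a pipelined producer/consumer, so the model's forward pass on the next batch can overlap CPU postprocessing of the previous one. Below is a minimal, self-contained sketch of that structure; `run_pipeline`, `infer`, and `postprocess` are hypothetical stand-ins for illustration, not surya APIs.

```python
import threading
import time
from queue import Queue

def run_pipeline(batches):
    """Toy pipeline: overlap 'inference' with 'postprocessing' via a bounded queue."""
    results = []
    batch_queue = Queue(maxsize=4)  # Bounded, so inference can't race far ahead of postprocessing

    def infer(batch):
        time.sleep(0.01)            # Stand-in for the GPU forward pass
        return [x * 2 for x in batch]

    def postprocess(pred):
        return pred + 1             # Stand-in for parallel_get_lines / parallel_get_regions

    def inference_producer():
        for batch in batches:
            batch_queue.put(infer(batch))
        batch_queue.put(None)       # Sentinel: tells the consumer no more batches are coming

    def postprocessing_consumer():
        while True:
            preds = batch_queue.get()
            if preds is None:
                break
            results.extend(postprocess(p) for p in preds)

    producer = threading.Thread(target=inference_producer)
    consumer = threading.Thread(target=postprocessing_consumer)
    producer.start()
    consumer.start()
    producer.join()
    consumer.join()
    return results

print(run_pipeline([[1, 2], [3, 4]]))  # -> [3, 5, 7, 9]
```

The bounded `Queue(maxsize=4)` applies backpressure: at most a few batches of predictions sit in memory if postprocessing falls behind, and `None` is the usual end-of-stream sentinel. Since the single consumer thread is the only writer to `results` before the joins, the `result_lock` in the diff appears not strictly required today, though it keeps the structure safe if multiple consumers are added later.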