Merge pull request #765 from roboflow/stitch_ocr_detections_workflow_block

Stitch OCR detections workflow block
PawelPeczek-Roboflow authored Nov 1, 2024
2 parents 9602ddc + 1e5774e commit 727ebd0
Showing 13 changed files with 709 additions and 80 deletions.
92 changes: 46 additions & 46 deletions docs/workflows/blocks.md

Large diffs are not rendered by default.

49 changes: 25 additions & 24 deletions docs/workflows/kinds.md
@@ -37,36 +37,37 @@ for the presence of a mask in the input.

## Kinds declared in Roboflow plugins
<!--- AUTOGENERATED_KINDS_LIST -->
* [`image_metadata`](/workflows/kinds/image_metadata): Dictionary with image metadata required by supervision
* [`integer`](/workflows/kinds/integer): Integer value
* [`roboflow_model_id`](/workflows/kinds/roboflow_model_id): Roboflow model id
* [`object_detection_prediction`](/workflows/kinds/object_detection_prediction): Prediction with detected bounding boxes in form of sv.Detections(...) object
* [`video_metadata`](/workflows/kinds/video_metadata): Video image metadata
* [`string`](/workflows/kinds/string): String value
* [`numpy_array`](/workflows/kinds/numpy_array): Numpy array
* [`parent_id`](/workflows/kinds/parent_id): Identifier of parent for step output
* [`qr_code_detection`](/workflows/kinds/qr_code_detection): Prediction with QR code detection
* [`float`](/workflows/kinds/float): Float value
* [`dictionary`](/workflows/kinds/dictionary): Dictionary
* [`roboflow_api_key`](/workflows/kinds/roboflow_api_key): Roboflow API key
* [`detection`](/workflows/kinds/detection): Single element of detections-based prediction (like `object_detection_prediction`)
* [`list_of_values`](/workflows/kinds/list_of_values): List of values of any type
* [`instance_segmentation_prediction`](/workflows/kinds/instance_segmentation_prediction): Prediction with detected bounding boxes and segmentation masks in form of sv.Detections(...) object
* [`float_zero_to_one`](/workflows/kinds/float_zero_to_one): `float` value in range `[0.0, 1.0]`
* [`*`](/workflows/kinds/*): Equivalent of any element
* [`image`](/workflows/kinds/image): Image in workflows
* [`image_keypoints`](/workflows/kinds/image_keypoints): Image keypoints detected by classical Computer Vision method
* [`bar_code_detection`](/workflows/kinds/bar_code_detection): Prediction with barcode detection
* [`bytes`](/workflows/kinds/bytes): This kind represents bytes
* [`roboflow_project`](/workflows/kinds/roboflow_project): Roboflow project name
* [`classification_prediction`](/workflows/kinds/classification_prediction): Predictions from classifier
* [`contours`](/workflows/kinds/contours): List of numpy arrays where each array represents contour points
* [`serialised_payloads`](/workflows/kinds/serialised_payloads): Serialised element that is usually accepted by a sink
* [`prediction_type`](/workflows/kinds/prediction_type): String value with type of prediction
* [`zone`](/workflows/kinds/zone): Definition of polygon zone
* [`keypoint_detection_prediction`](/workflows/kinds/keypoint_detection_prediction): Prediction with detected bounding boxes and detected keypoints in form of sv.Detections(...) object
* [`boolean`](/workflows/kinds/boolean): Boolean flag
* [`point`](/workflows/kinds/point): Single point in 2D
* [`top_class`](/workflows/kinds/top_class): String value representing top class predicted by classification model
* [`language_model_output`](/workflows/kinds/language_model_output): LLM / VLM output
* [`rgb_color`](/workflows/kinds/rgb_color): RGB color
<!--- AUTOGENERATED_KINDS_LIST -->
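For orientation, these kinds are what a block manifest references when declaring which upstream outputs a field may select. A minimal sketch of that pattern (the class and `type` string below are illustrative; the selector usage mirrors the block added in this PR):

```python
from typing import Literal

from pydantic import Field

from inference.core.workflows.execution_engine.entities.types import (
    OBJECT_DETECTION_PREDICTION_KIND,
    StepOutputSelector,
)
from inference.core.workflows.prototypes.block import WorkflowBlockManifest


class ExampleManifest(WorkflowBlockManifest):
    # Only step outputs declared with the object_detection_prediction kind
    # can be wired into this field.
    type: Literal["my_plugin/example@v1"]
    predictions: StepOutputSelector(
        kind=[OBJECT_DETECTION_PREDICTION_KIND]
    ) = Field(description="Detections produced by an upstream step.")
```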
2 changes: 1 addition & 1 deletion inference/core/version.py
@@ -1,4 +1,4 @@
__version__ = "0.24.0"
__version__ = "0.25.0"


if __name__ == "__main__":
4 changes: 4 additions & 0 deletions inference/core/workflows/core_steps/loader.py
@@ -227,6 +227,9 @@
from inference.core.workflows.core_steps.transformations.stitch_images.v1 import (
StitchImagesBlockV1,
)
from inference.core.workflows.core_steps.transformations.stitch_ocr_detections.v1 import (
StitchOCRDetectionsBlockV1,
)

# Visualizers
from inference.core.workflows.core_steps.visualizations.background_color.v1 import (
@@ -425,6 +428,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
StabilityAIInpaintingBlockV1,
StabilizeTrackedDetectionsBlockV1,
StitchImagesBlockV1,
StitchOCRDetectionsBlockV1,
TemplateMatchingBlockV1,
TimeInZoneBlockV1,
TimeInZoneBlockV2,
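Registration is the whole integration surface here: the loader imports the new block class and appends it to the list returned by `load_blocks()`. A custom plugin could expose blocks through the same shape (a sketch; the module below is hypothetical, only the Roboflow import is real):

```python
# my_plugin/__init__.py - illustrative plugin module following the same
# convention as loader.py: expose a load_blocks() returning block classes.
from typing import List, Type

from inference.core.workflows.core_steps.transformations.stitch_ocr_detections.v1 import (
    StitchOCRDetectionsBlockV1,
)
from inference.core.workflows.prototypes.block import WorkflowBlock


def load_blocks() -> List[Type[WorkflowBlock]]:
    return [StitchOCRDetectionsBlockV1]
```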
Empty file.
@@ -0,0 +1,294 @@
from enum import Enum
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

import numpy as np
import supervision as sv
from pydantic import AliasChoices, ConfigDict, Field, field_validator

from inference.core.workflows.execution_engine.entities.base import (
Batch,
OutputDefinition,
)
from inference.core.workflows.execution_engine.entities.types import (
INTEGER_KIND,
OBJECT_DETECTION_PREDICTION_KIND,
STRING_KIND,
StepOutputSelector,
WorkflowParameterSelector,
)
from inference.core.workflows.prototypes.block import (
BlockResult,
WorkflowBlock,
WorkflowBlockManifest,
)

LONG_DESCRIPTION = """
Combines OCR detection results into a coherent text string by organizing detections spatially.
This transformation is perfect for turning individual OCR results into structured, readable text!
#### How It Works
This transformation reconstructs the original text from OCR detection results by:
1. 📐 **Grouping** text detections into rows based on their vertical (`y`) positions
2. 📏 **Sorting** detections within each row by horizontal (`x`) position
3. 📜 **Concatenating** the text in reading order (left-to-right, top-to-bottom)
#### Parameters
- **`tolerance`**: Controls how close detections need to be vertically to be considered part of the same line of text.
A higher tolerance will group detections that are further apart vertically.
- **`reading_direction`**: Determines the order in which text is read. Available options:
* **"left_to_right"**: Standard left-to-right reading (e.g., English) ➡️
* **"right_to_left"**: Right-to-left reading (e.g., Arabic) ⬅️
* **"vertical_top_to_bottom"**: Vertical reading from top to bottom ⬇️
* **"vertical_bottom_to_top"**: Vertical reading from bottom to top ⬆️
#### Why Use This Transformation?
This is especially useful for:
- 📖 Converting individual character/word detections into a readable text block
- 📝 Reconstructing multi-line text from OCR results
- 🔀 Maintaining proper reading order for detected text elements
- 🌏 Supporting different writing systems and text orientations
#### Example Usage
Use this transformation after an OCR model that outputs individual words or characters, so you can reconstruct the
original text layout in its intended format.
"""

SHORT_DESCRIPTION = "Combines OCR detection results into a coherent text string by organizing detections spatially."


class ReadingDirection(str, Enum):
LEFT_TO_RIGHT = "left_to_right"
RIGHT_TO_LEFT = "right_to_left"
VERTICAL_TOP_TO_BOTTOM = "vertical_top_to_bottom"
VERTICAL_BOTTOM_TO_TOP = "vertical_bottom_to_top"


class BlockManifest(WorkflowBlockManifest):
model_config = ConfigDict(
json_schema_extra={
"name": "Stitch OCR Detections",
"version": "v1",
"short_description": SHORT_DESCRIPTION,
"long_description": LONG_DESCRIPTION,
"license": "Apache-2.0",
"block_type": "transformation",
"ui_manifest": {
"section": "advanced",
"icon": "fal fa-reel",
"blockPriority": 2,
},
}
)
type: Literal["roboflow_core/stitch_ocr_detections@v1"]
predictions: StepOutputSelector(
kind=[
OBJECT_DETECTION_PREDICTION_KIND,
]
) = Field(
title="OCR Detections",
description="The output of an OCR detection model.",
examples=["$steps.my_ocr_detection_model.predictions"],
)
reading_direction: Literal[
"left_to_right",
"right_to_left",
"vertical_top_to_bottom",
"vertical_bottom_to_top",
] = Field(
title="Reading Direction",
description="The direction of the text in the image.",
examples=["right_to_left"],
json_schema_extra={
"values_metadata": {
"left_to_right": {
"name": "Left To Right",
"description": "Standard left-to-right reading (e.g., English language)",
},
"right_to_left": {
"name": "Right To Left",
"description": "Right-to-left reading (e.g., Arabic)",
},
"vertical_top_to_bottom": {
"name": "Top To Bottom (Vertical)",
"description": "Vertical reading from top to bottom",
},
"vertical_bottom_to_top": {
"name": "Bottom To Top (Vertical)",
"description": "Vertical reading from bottom to top",
},
}
},
)
tolerance: Union[int, WorkflowParameterSelector(kind=[INTEGER_KIND])] = Field(
title="Tolerance",
description="The tolerance for grouping detections into the same line of text.",
default=10,
examples=[10, "$inputs.tolerance"],
)

@field_validator("tolerance")
@classmethod
def ensure_tolerance_greater_than_zero(
cls, value: Union[int, str]
) -> Union[int, str]:
if isinstance(value, int) and value <= 0:
raise ValueError(
"Stitch OCR detections block expects `tollerance` to be greater than zero."
)
return value

@classmethod
def accepts_batch_input(cls) -> bool:
return True

@classmethod
def describe_outputs(cls) -> List[OutputDefinition]:
return [
OutputDefinition(name="ocr_text", kind=[STRING_KIND]),
]

@classmethod
def get_execution_engine_compatibility(cls) -> Optional[str]:
return ">=1.0.0,<2.0.0"


class StitchOCRDetectionsBlockV1(WorkflowBlock):
@classmethod
def get_manifest(cls) -> Type[WorkflowBlockManifest]:
return BlockManifest

def run(
self,
predictions: Batch[sv.Detections],
reading_direction: str,
tolerance: int,
) -> BlockResult:
return [
stitch_ocr_detections(
detections=detections,
reading_direction=reading_direction,
tolerance=tolerance,
)
for detections in predictions
]
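# Illustrative note (not part of the shipped file): because
# accepts_batch_input() returns True, `predictions` arrives as a Batch of
# sv.Detections and run() returns one {"ocr_text": ...} dict per element.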


def stitch_ocr_detections(
detections: sv.Detections,
reading_direction: str = "left_to_right",
tolerance: int = 10,
) -> Dict[str, str]:
"""
Stitch OCR detections into coherent text based on spatial arrangement.
Args:
detections: Supervision Detections object containing OCR results
reading_direction: Direction to read text ("left_to_right", "right_to_left",
"vertical_top_to_bottom", "vertical_bottom_to_top")
tolerance: Vertical tolerance for grouping text into lines
Returns:
Dict containing stitched OCR text under 'ocr_text' key
"""
if len(detections) == 0:
return {"ocr_text": ""}

xyxy = detections.xyxy.round().astype(dtype=int)
class_names = detections.data["class_name"]

# Prepare coordinates based on reading direction
xyxy = prepare_coordinates(xyxy, reading_direction)

# Group detections into lines
boxes_by_line = group_detections_by_line(xyxy, reading_direction, tolerance)
# Sort lines based on reading direction
lines = sorted(
boxes_by_line.keys(), reverse=reading_direction in ["vertical_bottom_to_top"]
)

# Build final text
ordered_class_names = []
for i, key in enumerate(lines):
line_data = boxes_by_line[key]
line_xyxy = np.array(line_data["xyxy"])
line_idx = np.array(line_data["idx"])

# Sort detections within line
sort_idx = sort_line_detections(line_xyxy, reading_direction)

# Add sorted class names for this line
ordered_class_names.extend(class_names[line_idx[sort_idx]])

# Add line separator if not last line
if i < len(lines) - 1:
ordered_class_names.append(get_line_separator(reading_direction))

return {"ocr_text": "".join(ordered_class_names)}


def prepare_coordinates(
xyxy: np.ndarray,
reading_direction: str,
) -> np.ndarray:
"""Prepare coordinates based on reading direction."""
if reading_direction in ["vertical_top_to_bottom", "vertical_bottom_to_top"]:
# Swap x and y coordinates: [x1,y1,x2,y2] -> [y1,x1,y2,x2]
return xyxy[:, [1, 0, 3, 2]]
return xyxy
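# Illustrative note (not part of the shipped file): for vertical reading, a
# box [5, 100, 25, 140] becomes [100, 5, 140, 25], so the helpers below can
# keep using column 1 as the line axis and column 0 as the within-line axis
# for every reading direction.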


def group_detections_by_line(
xyxy: np.ndarray,
reading_direction: str,
tolerance: int,
) -> Dict[float, Dict[str, List]]:
"""Group detections into lines based on primary coordinate."""
# After prepare_coordinates swap, we always group by y ([:, 1])
primary_coord = xyxy[:, 1] # This is y for horizontal, swapped x for vertical

# Round primary coordinate to group into lines
rounded_primary = np.round(primary_coord / tolerance) * tolerance

boxes_by_line = {}
# Group bounding boxes and associated indices by line
for i, (bbox, line_pos) in enumerate(zip(xyxy, rounded_primary)):
if line_pos not in boxes_by_line:
boxes_by_line[line_pos] = {"xyxy": [bbox], "idx": [i]}
else:
boxes_by_line[line_pos]["xyxy"].append(bbox)
boxes_by_line[line_pos]["idx"].append(i)

return boxes_by_line
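# Illustrative note (not part of the shipped file): with tolerance=10, y1
# values 98, 102 and 131 round to keys 100, 100 and 130, so the first two
# boxes land on the same line while the third starts a new one.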


def sort_line_detections(
line_xyxy: np.ndarray,
reading_direction: str,
) -> np.ndarray:
"""Sort detections within a line based on reading direction."""
# After prepare_coordinates swap, we always sort by x ([:, 0])
if reading_direction in ["left_to_right", "vertical_top_to_bottom"]:
return line_xyxy[:, 0].argsort() # Sort by x1 (original x or swapped y)
else: # right_to_left or vertical_bottom_to_top
return (-line_xyxy[:, 0]).argsort() # Sort by -x1 (original -x or swapped -y)


def get_line_separator(reading_direction: str) -> str:
"""Get the appropriate separator based on reading direction."""
return "\n" if reading_direction in ["left_to_right", "right_to_left"] else " "
