Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update chinese_text_detection_db_server #2170

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@
```
- 通过命令行方式实现文字识别模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst)

- ### 2、代码示例
- ### 2、预测代码示例

- ```python
import paddlehub as hub
import cv2

text_detector = hub.Module(name="chinese_text_detection_db_server"), enable_mkldnn=True) # mkldnn加速仅在CPU下有效
text_detector = hub.Module(name="chinese_text_detection_db_server", enable_mkldnn=True) # mkldnn加速仅在CPU下有效
result = text_detector.detect_text(images=[cv2.imread('/PATH/TO/IMAGE')])

# or
Expand Down Expand Up @@ -175,6 +175,10 @@

移除 fluid api

* 1.1.0

适配 PaddleHub 2.x 版本

- ```shell
$ hub install chinese_text_detection_db_server==1.0.3
$ hub install chinese_text_detection_db_server==1.1.0
```
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Expand All @@ -9,43 +8,55 @@
import math
import os
import time
from io import BytesIO

import cv2
import numpy as np
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from PIL import Image

import paddlehub as hub
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
from paddlehub.utils.log import logger


def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
if data is None:
buf = BytesIO()
image_decode = base64.b64decode(b64str.encode('utf8'))
image = BytesIO(image_decode)
im = Image.open(image)
rgb = im.convert('RGB')
rgb.save(buf, 'jpeg')
buf.seek(0)
image_bytes = buf.read()
data_base64 = str(base64.b64encode(image_bytes), encoding="utf-8")
image_decode = base64.b64decode(data_base64)
img_array = np.frombuffer(image_decode, np.uint8)
data = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
return data


@moduleinfo(
name="chinese_text_detection_db_server",
version="1.0.3",
version="1.1.0",
summary=
"The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
author="paddle-dev",
author_email="[email protected]",
type="cv/text_recognition")
class ChineseTextDetectionDBServer(hub.Module):
class ChineseTextDetectionDBServer:

def _initialize(self, enable_mkldnn=False):
def __init__(self, enable_mkldnn=False):
"""
initialize with the necessary elements
"""
self.pretrained_model_path = os.path.join(self.directory, 'inference_model')
self.pretrained_model_path = os.path.join(self.directory, 'inference_model', 'model')
self.enable_mkldnn = enable_mkldnn

self._set_config()
Expand All @@ -62,8 +73,8 @@ def _set_config(self):
"""
predictor config setting
"""
model_file_path = os.path.join(self.pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params')
model_file_path = self.pretrained_model_path + '.pdmodel'
params_file_path = self.pretrained_model_path + '.pdiparams'

config = Config(model_file_path, params_file_path)
try:
Expand Down Expand Up @@ -211,7 +222,7 @@ def detect_text(self,
data_out = self.output_tensors[0].copy_to_cpu()
dt_boxes_list = postprocessor(data_out, [ratio_list])
boxes = self.filter_tag_det_res(dt_boxes_list[0], original_image.shape)
res['data'] = boxes.astype(np.int).tolist()
res['data'] = boxes.astype(np.int64).tolist()

all_imgs.append(im)
all_ratios.append(ratio_list)
Expand All @@ -230,28 +241,6 @@ def detect_text(self,

return all_results

def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)

model_file_path = os.path.join(self.pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params')
program, feeded_var_names, target_vars = paddle.static.load_inference_model(dirname=self.pretrained_model_path,
model_filename=model_file_path,
params_filename=params_file_path,
executor=exe)

paddle.static.save_inference_model(dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)

@serving
def serving_method(self, images, **kwargs):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

from PIL import Image, ImageDraw, ImageFont
from shapely.geometry import Polygon
import cv2
import numpy as np
import pyclipper
from PIL import ImageDraw
from shapely.geometry import Polygon


class DBPreProcess(object):

def __init__(self, max_side_len=960):
self.max_side_len = max_side_len

Expand Down Expand Up @@ -103,7 +103,7 @@ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
contours, _ = outs[0], outs[1]

num_contours = min(len(contours), self.max_candidates)
boxes = np.zeros((num_contours, 4, 2), dtype=np.int16)
boxes = np.zeros((num_contours, 4, 2), dtype=np.int64)
scores = np.zeros((num_contours, ), dtype=np.float32)

for index in range(num_contours):
Expand All @@ -127,7 +127,7 @@ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):

box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes[index, :, :] = box.astype(np.int16)
boxes[index, :, :] = box.astype(np.int64)
scores[index] = score
return boxes, scores

Expand Down Expand Up @@ -163,15 +163,15 @@ def get_mini_boxes(self, contour):
def box_score_fast(self, bitmap, _box):
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int64), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int64), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int64), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int64), 0, h - 1)

mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int64), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

def __call__(self, predictions, ratio_list):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
shapely
pyclipper
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import os
import shutil
import unittest

import cv2
import requests

import paddlehub as hub

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


class TestHubModule(unittest.TestCase):

@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/KTzZVDjUsXw/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MzM3fHx0ZXh0fGVufDB8fHx8MTY2MzUxMTExMQ&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="chinese_text_detection_db_server")

@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('detection_result')

def test_detect_text1(self):
results = self.module.detect_text(
paths=['tests/test.jpg'],
use_gpu=False,
visualization=False,
)
self.assertEqual(
results[0]['data'],
[[[258, 199], [382, 199], [382, 240], [258, 240]], [[281, 159], [359, 159], [359, 202], [281, 202]]])

def test_detect_text2(self):
results = self.module.detect_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=False,
)
self.assertEqual(
results[0]['data'],
[[[258, 199], [382, 199], [382, 240], [258, 240]], [[281, 159], [359, 159], [359, 202], [281, 202]]])

def test_detect_text3(self):
results = self.module.detect_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=True,
visualization=False,
)
self.assertEqual(
results[0]['data'],
[[[258, 199], [382, 199], [382, 240], [258, 240]], [[281, 159], [359, 159], [359, 202], [281, 202]]])

def test_detect_text4(self):
results = self.module.detect_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=True,
)
self.assertEqual(
results[0]['data'],
[[[258, 199], [382, 199], [382, 240], [258, 240]], [[281, 159], [359, 159], [359, 202], [281, 202]]])

def test_detect_text5(self):
self.assertRaises(AttributeError, self.module.detect_text, images=['tests/test.jpg'])

def test_detect_text6(self):
self.assertRaises(AssertionError, self.module.detect_text, paths=['no.jpg'])

def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')

self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))


if __name__ == "__main__":
unittest.main()