Skip to content

Commit

Permalink
update ch_pp-ocrv3_det (#2173)
Browse files Browse the repository at this point in the history
* update ch_pp-ocrv3_det

* update
  • Loading branch information
jm12138 authored Dec 29, 2022
1 parent 3d13232 commit 4382eee
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 21 deletions.
17 changes: 16 additions & 1 deletion modules/image/text_recognition/ch_pp-ocrv3_det/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,36 @@
import base64
import os
import time
from io import BytesIO

import cv2
import numpy as np
import paddle.inference as paddle_infer
from PIL import Image

from paddlehub.utils.utils import logger
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
from paddlehub.utils.utils import logger


def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
if data is None:
buf = BytesIO()
image_decode = base64.b64decode(b64str.encode('utf8'))
image = BytesIO(image_decode)
im = Image.open(image)
rgb = im.convert('RGB')
rgb.save(buf, 'jpeg')
buf.seek(0)
image_bytes = buf.read()
data_base64 = str(base64.b64encode(image_bytes), encoding="utf-8")
image_decode = base64.b64decode(data_base64)
img_array = np.frombuffer(image_decode, np.uint8)
data = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
return data


Expand All @@ -49,6 +63,7 @@ def base64_to_cv2(b64str):
author_email="[email protected]",
type="cv/text_recognition")
class ChPPOCRv3Det:

def __init__(self, enable_mkldnn=False):
"""
initialize with the necessary elements
Expand Down
36 changes: 16 additions & 20 deletions modules/image/text_recognition/ch_pp-ocrv3_det/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

import cv2
import requests
import paddlehub as hub

import paddlehub as hub

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


class TestHubModule(unittest.TestCase):

@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/KTzZVDjUsXw/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MzM3fHx0ZXh0fGVufDB8fHx8MTY2MzUxMTExMQ&force=true&w=640'
Expand All @@ -34,49 +35,45 @@ def test_detect_text1(self):
use_gpu=False,
visualization=False,
)
self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [
261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])
self.assertEqual(
results[0]['data'],
[[[261, 202], [376, 202], [376, 239], [261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])

def test_detect_text2(self):
results = self.module.detect_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=False,
)
self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [
261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])
self.assertEqual(
results[0]['data'],
[[[261, 202], [376, 202], [376, 239], [261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])

def test_detect_text3(self):
results = self.module.detect_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=True,
visualization=False,
)
self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [
261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])
self.assertEqual(
results[0]['data'],
[[[261, 202], [376, 202], [376, 239], [261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])

def test_detect_text4(self):
results = self.module.detect_text(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=True,
)
self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [
261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])
self.assertEqual(
results[0]['data'],
[[[261, 202], [376, 202], [376, 239], [261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]])

def test_detect_text5(self):
self.assertRaises(
AttributeError,
self.module.detect_text,
images=['tests/test.jpg']
)
self.assertRaises(AttributeError, self.module.detect_text, images=['tests/test.jpg'])

def test_detect_text6(self):
self.assertRaises(
AssertionError,
self.module.detect_text,
paths=['no.jpg']
)
self.assertRaises(AssertionError, self.module.detect_text, paths=['no.jpg'])

def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
Expand All @@ -87,4 +84,3 @@ def test_save_inference_model(self):

if __name__ == "__main__":
unittest.main()

0 comments on commit 4382eee

Please sign in to comment.