diff --git a/modules/image/text_recognition/ch_pp-ocrv3_det/module.py b/modules/image/text_recognition/ch_pp-ocrv3_det/module.py index 413b44c65..d0b8363c9 100644 --- a/modules/image/text_recognition/ch_pp-ocrv3_det/module.py +++ b/modules/image/text_recognition/ch_pp-ocrv3_det/module.py @@ -21,22 +21,36 @@ import base64 import os import time +from io import BytesIO import cv2 import numpy as np import paddle.inference as paddle_infer from PIL import Image -from paddlehub.utils.utils import logger from paddlehub.module.module import moduleinfo from paddlehub.module.module import runnable from paddlehub.module.module import serving +from paddlehub.utils.utils import logger def base64_to_cv2(b64str): data = base64.b64decode(b64str.encode('utf8')) data = np.fromstring(data, np.uint8) data = cv2.imdecode(data, cv2.IMREAD_COLOR) + if data is None: + buf = BytesIO() + image_decode = base64.b64decode(b64str.encode('utf8')) + image = BytesIO(image_decode) + im = Image.open(image) + rgb = im.convert('RGB') + rgb.save(buf, 'jpeg') + buf.seek(0) + image_bytes = buf.read() + data_base64 = str(base64.b64encode(image_bytes), encoding="utf-8") + image_decode = base64.b64decode(data_base64) + img_array = np.frombuffer(image_decode, np.uint8) + data = cv2.imdecode(img_array, cv2.IMREAD_COLOR) return data @@ -49,6 +63,7 @@ def base64_to_cv2(b64str): author_email="paddle-dev@baidu.com", type="cv/text_recognition") class ChPPOCRv3Det: + def __init__(self, enable_mkldnn=False): """ initialize with the necessary elements diff --git a/modules/image/text_recognition/ch_pp-ocrv3_det/test.py b/modules/image/text_recognition/ch_pp-ocrv3_det/test.py index db09b25bc..43f56d1f1 100644 --- a/modules/image/text_recognition/ch_pp-ocrv3_det/test.py +++ b/modules/image/text_recognition/ch_pp-ocrv3_det/test.py @@ -4,13 +4,14 @@ import cv2 import requests -import paddlehub as hub +import paddlehub as hub os.environ['CUDA_VISIBLE_DEVICES'] = '0' class TestHubModule(unittest.TestCase): + @classmethod def setUpClass(cls) -> None: img_url = 'https://unsplash.com/photos/KTzZVDjUsXw/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MzM3fHx0ZXh0fGVufDB8fHx8MTY2MzUxMTExMQ&force=true&w=640' @@ -34,8 +35,9 @@ def test_detect_text1(self): use_gpu=False, visualization=False, ) - self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [ - 261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]]) + self.assertEqual( + results[0]['data'], + [[[261, 202], [376, 202], [376, 239], [261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]]) def test_detect_text2(self): results = self.module.detect_text( @@ -43,8 +45,9 @@ def test_detect_text2(self): use_gpu=False, visualization=False, ) - self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [ - 261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]]) + self.assertEqual( + results[0]['data'], + [[[261, 202], [376, 202], [376, 239], [261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]]) def test_detect_text3(self): results = self.module.detect_text( @@ -52,8 +55,9 @@ def test_detect_text3(self): use_gpu=True, visualization=False, ) - self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [ - 261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]]) + self.assertEqual( + results[0]['data'], + [[[261, 202], [376, 202], [376, 239], [261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]]) def test_detect_text4(self): results = self.module.detect_text( @@ -61,22 +65,15 @@ def test_detect_text4(self): use_gpu=False, visualization=True, ) - self.assertEqual(results[0]['data'], [[[261, 202], [376, 202], [376, 239], [ - 261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]]) + self.assertEqual( + results[0]['data'], + [[[261, 202], [376, 202], [376, 239], [261, 239]], [[283, 162], [352, 162], [352, 202], [283, 202]]]) def test_detect_text5(self): - self.assertRaises( - AttributeError, - self.module.detect_text, - images=['tests/test.jpg'] - ) + self.assertRaises(AttributeError, self.module.detect_text, images=['tests/test.jpg']) def test_detect_text6(self): - self.assertRaises( - AssertionError, - self.module.detect_text, - paths=['no.jpg'] - ) + self.assertRaises(AssertionError, self.module.detect_text, paths=['no.jpg']) def test_save_inference_model(self): self.module.save_inference_model('./inference/model') @@ -87,4 +84,3 @@ def test_save_inference_model(self): if __name__ == "__main__": unittest.main() -