This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

OCR example update #927

Merged
merged 2 commits on Nov 10, 2020
56 changes: 36 additions & 20 deletions paddle2.0_docs/image_ocr/OCR.ipynb
@@ -22,7 +22,9 @@
"**数据展示**\n",
"<p align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/865dd55861e24cfaa601d9f87655776c1458d099b487449f946da9b3138fc700\" width=\"400\"><br/>\n",
"</p>"
"</p> \n",
"\n",
"点此[快速获取本节数据集](https://aistudio.baidu.com/aistudio/datasetdetail/57285),待数据集下载完毕后可使用`!unzip OCR_Dataset.zip -d data/`命令或熟悉的解压软件进行解压,待数据准备工作完成后修改本文“训练准备”中的`DATA_PATH = 解压后数据集路径`。"
],
"cell_type": "markdown",
"metadata": {}
@@ -126,16 +128,13 @@
"\n",
"CTC相关论文:[Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neu](http://people.idsia.ch/~santiago/papers/icml2006.pdf) \n",
"<p align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/50cf1fc38f6b40e596acf71dc43333ff49dcaafb5a9f484b8aeee2db2c08ca67\" width=\"800\"><br/>\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/f9458cedbb4441d682f15fefd3f3cae5e49d499bcf0a4bbdb976dfdff5a2e656\" width=\"800\"><br/>\n",
"</p>\n",
"\n",
"网络部分,因本篇采用数据集较为简单且图像尺寸较小并不适合较深层次网络。若在对尺寸较大的图像进行模型构建,可以考虑使用更深层次网络/注意力机制来完成。当然也可以通过目标检测形式先检出文本位置,然后进行OCR部分模型构建。\n",
"网络部分,因本篇采用数据集较为简单且图像尺寸较小并不适合较深层次网络。若在对尺寸较大的图像进行模型构建,可以考虑使用更深层次网络/注意力机制来完成。当然也可以通过目标检测形式先检出文本位置,然后进行OCR部分模型构建。(下方样例来源[PaddleOCR](v))\n",
"\n",
"<p align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/6e0665ddfe6a46e1b658da870cd4043f1d50e5f4dc2746018c710c58d2e0c18c\" width=\"400\"><br/>\n",
" \n",
" \n",
"<a href=\"https://github.com/PaddlePaddle/PaddleOCR\">PaddleOCR效果图</a>\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/6e0665ddfe6a46e1b658da870cd4043f1d50e5f4dc2746018c710c58d2e0c18c\" width=\"400\"></br>\n",
"</p>"
]
},
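The many-to-one CTC mapping sketched above can be made concrete with a toy pure-Python illustration: repeated symbols are collapsed first, then the blank is dropped, so the alignment 5 5 - 5 2 2 (with - as the blank) maps to the label 5 5 2. The simple decoder defined later in this notebook applies exactly this rule.

    def collapse(alignment, blank="-"):
        """Greedy CTC decoding rule: collapse repeats, then drop blanks."""
        out, prev = [], None
        for s in alignment:
            if s != blank and s != prev:
                out.append(s)
            prev = s
        return out

    print(collapse(["5", "5", "-", "5", "2", "2"]))  # ['5', '5', '2']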
@@ -164,16 +163,16 @@
" self.is_infer = is_infer\n",
"\n",
" # 定义一层3x3卷积+BatchNorm\n",
" self.conv1 = paddle.nn.Conv2d(in_channels=IMAGE_SHAPE_C,\n",
" self.conv1 = paddle.nn.Conv2D(in_channels=IMAGE_SHAPE_C,\n",
" out_channels=32,\n",
" kernel_size=3)\n",
" self.bn1 = paddle.nn.BatchNorm2d(32)\n",
" self.bn1 = paddle.nn.BatchNorm2D(32)\n",
" # 定义一层步长为2的3x3卷积进行下采样+BatchNorm\n",
" self.conv2 = paddle.nn.Conv2d(in_channels=32,\n",
" self.conv2 = paddle.nn.Conv2D(in_channels=32,\n",
" out_channels=64,\n",
" kernel_size=3,\n",
" stride=2)\n",
" self.bn2 = paddle.nn.BatchNorm2d(64)\n",
" self.bn2 = paddle.nn.BatchNorm2D(64)\n",
" # 定义一层1x1卷积压缩通道数,输出通道数设置为比LABEL_MAX_LEN稍大的定值可获取更优效果,当然也可设置为LABEL_MAX_LEN\n",
" self.conv3 = paddle.nn.Conv2d(in_channels=64,\n",
" out_channels=LABEL_MAX_LEN + 4,\n",
@@ -215,6 +214,8 @@
" if self.is_infer:\n",
" # 输出层 - Shape = (Batch Size, Max label len, Prob) \n",
" x = paddle.nn.functional.softmax(x)\n",
" # 转换为标签\n",
" x = paddle.tensor.argmax(x, axis=-1)\n",
" return x"
]
},
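A quick standalone shape check of the inference branch above (the batch size of 8, 14 time steps and 11 classes are illustrative assumptions, roughly LABEL_MAX_LEN + 4 steps over ten digits plus one blank):

    import paddle

    x = paddle.rand([8, 14, 11])             # (batch, time steps, classes incl. blank)
    probs = paddle.nn.functional.softmax(x)  # per-step distribution over the classes
    labels = paddle.argmax(probs, axis=-1)   # (batch, time steps) integer predictions
    print(labels.shape)                      # [8, 14]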
@@ -286,8 +287,8 @@
" super().__init__()\n",
"\n",
" def forward(self, ipt, label):\n",
" input_lengths = paddle.tensor.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN + 4)\n",
" label_lengths = paddle.tensor.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN)\n",
" input_lengths = paddle.tensor.creation.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN + 4)\n",
" label_lengths = paddle.tensor.creation.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN)\n",
" # 按文档要求进行转换dim顺序\n",
" ipt = paddle.tensor.transpose(ipt, [1, 0, 2])\n",
" # 计算loss\n",
@@ -452,7 +453,7 @@
},
"outputs": [],
"source": [
"# 待预测目录\n",
"# 待预测目录 - 可在测试数据集中挑出\b3张图像放在该目录中进行推理\n",
"INFER_DATA_PATH = \"./sample_img\"\n",
"# 训练后存档点路径 - 10代表使用第10个存档点\n",
"CHECKPOINT_PATH = \"./output/10\"\n",
@@ -505,7 +506,7 @@
{
"source": [
"## 开始预测\n",
"> 飞桨2.0 CTC Decoder 相关API正在迁移中,本节暂时使用[第三方解码器](https://github.com/awni/speech/blob/072bcf9ff510d814fbfcaad43b2883ecf8f60806/speech/models/ctc_decoder.py)进行解码。"
"> 飞桨2.0 CTC Decoder 相关API正在迁移中,本节暂时使用简易版解码器。"
],
"cell_type": "markdown",
"metadata": {
@@ -533,7 +534,22 @@
}
],
"source": [
"from ctc import decode\n",
"# 编写简易版解码器\n",
"def ctc_decode(text, blank=10):\n",
" \"\"\"\n",
" 简易CTC解码器\n",
" :param text: 待解码数据\n",
" :param blank: 分隔符索引值\n",
" :return: 解码后数据\n",
" \"\"\"\n",
" result = []\n",
" cache_idx = -1\n",
" for char in text:\n",
" if char != blank and char != cache_idx:\n",
" result.append(char)\n",
" cache_idx = char\n",
" return result\n",
"\n",
"\n",
"# 实例化预测模型\n",
"model = paddle.Model(Net(is_infer=True), inputs=input_define)\n",
@@ -547,10 +563,10 @@
"img_names = infer_reader.get_names()\n",
"results = model.predict(infer_reader, batch_size=BATCH_SIZE)\n",
"index = 0\n",
"for result in results[0]:\n",
" for prob in result:\n",
" out, _ = decode(prob, blank=10)\n",
" print(f\"文件名:{img_names[index]},预测结果为:{out}\")\n",
"for text_batch in results[0]:\n",
" for prob in text_batch:\n",
" out = ctc_decode(prob, blank=10)\n",
" print(f\"文件名:{img_names[index]},推理结果为:{out}\")\n",
" index += 1"
]
}
Binary file added paddle2.0_docs/image_ocr/sample_img/9450.jpg
Binary file added paddle2.0_docs/image_ocr/sample_img/9451.jpg
Binary file added paddle2.0_docs/image_ocr/sample_img/9452.jpg