diff --git a/paddle2.0_docs/image_ocr/OCR.ipynb b/paddle2.0_docs/image_ocr/OCR.ipynb
index 370ce582..db68b55e 100644
--- a/paddle2.0_docs/image_ocr/OCR.ipynb
+++ b/paddle2.0_docs/image_ocr/OCR.ipynb
@@ -22,7 +22,9 @@
"**数据展示**\n",
"
\n",
"
\n",
- "
"
+ " \n",
+ "\n",
+ "点此[快速获取本节数据集](https://aistudio.baidu.com/aistudio/datasetdetail/57285),待数据集下载完毕后可使用`!unzip OCR_Dataset.zip -d data/`命令或熟悉的解压软件进行解压,待数据准备工作完成后修改本文“训练准备”中的`DATA_PATH = 解压后数据集路径`。"
],
"cell_type": "markdown",
"metadata": {}
@@ -126,16 +128,13 @@
"\n",
"CTC相关论文:[Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neu](http://people.idsia.ch/~santiago/papers/icml2006.pdf) \n",
"\n",
- "
\n",
+ "
\n",
"
\n",
"\n",
- "网络部分,因本篇采用数据集较为简单且图像尺寸较小并不适合较深层次网络。若在对尺寸较大的图像进行模型构建,可以考虑使用更深层次网络/注意力机制来完成。当然也可以通过目标检测形式先检出文本位置,然后进行OCR部分模型构建。\n",
+ "网络部分,因本篇采用数据集较为简单且图像尺寸较小并不适合较深层次网络。若在对尺寸较大的图像进行模型构建,可以考虑使用更深层次网络/注意力机制来完成。当然也可以通过目标检测形式先检出文本位置,然后进行OCR部分模型构建。(下方样例来源[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR))\n",
"\n",
"\n",
- "
\n",
- " \n",
- " \n",
- "PaddleOCR效果图\n",
+ "
\n",
"
"
]
},
@@ -164,16 +163,16 @@
" self.is_infer = is_infer\n",
"\n",
" # 定义一层3x3卷积+BatchNorm\n",
- " self.conv1 = paddle.nn.Conv2d(in_channels=IMAGE_SHAPE_C,\n",
+ " self.conv1 = paddle.nn.Conv2D(in_channels=IMAGE_SHAPE_C,\n",
" out_channels=32,\n",
" kernel_size=3)\n",
- " self.bn1 = paddle.nn.BatchNorm2d(32)\n",
+ " self.bn1 = paddle.nn.BatchNorm2D(32)\n",
" # 定义一层步长为2的3x3卷积进行下采样+BatchNorm\n",
- " self.conv2 = paddle.nn.Conv2d(in_channels=32,\n",
+ " self.conv2 = paddle.nn.Conv2D(in_channels=32,\n",
" out_channels=64,\n",
" kernel_size=3,\n",
" stride=2)\n",
- " self.bn2 = paddle.nn.BatchNorm2d(64)\n",
+ " self.bn2 = paddle.nn.BatchNorm2D(64)\n",
" # 定义一层1x1卷积压缩通道数,输出通道数设置为比LABEL_MAX_LEN稍大的定值可获取更优效果,当然也可设置为LABEL_MAX_LEN\n",
" self.conv3 = paddle.nn.Conv2d(in_channels=64,\n",
" out_channels=LABEL_MAX_LEN + 4,\n",
@@ -215,6 +214,8 @@
" if self.is_infer:\n",
" # 输出层 - Shape = (Batch Size, Max label len, Prob) \n",
" x = paddle.nn.functional.softmax(x)\n",
+ " # 转换为标签\n",
+ " x = paddle.tensor.argmax(x, axis=-1)\n",
" return x"
]
},
@@ -286,8 +287,8 @@
" super().__init__()\n",
"\n",
" def forward(self, ipt, label):\n",
- " input_lengths = paddle.tensor.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN + 4)\n",
- " label_lengths = paddle.tensor.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN)\n",
+ " input_lengths = paddle.tensor.creation.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN + 4)\n",
+ " label_lengths = paddle.tensor.creation.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN)\n",
" # 按文档要求进行转换dim顺序\n",
" ipt = paddle.tensor.transpose(ipt, [1, 0, 2])\n",
" # 计算loss\n",
@@ -452,7 +453,7 @@
},
"outputs": [],
"source": [
- "# 待预测目录\n",
+ "# 待预测目录 - 可在测试数据集中挑出3张图像放在该目录中进行推理\n",
"INFER_DATA_PATH = \"./sample_img\"\n",
"# 训练后存档点路径 - 10代表使用第10个存档点\n",
"CHECKPOINT_PATH = \"./output/10\"\n",
@@ -505,7 +506,7 @@
{
"source": [
"## 开始预测\n",
- "> 飞桨2.0 CTC Decoder 相关API正在迁移中,本节暂时使用[第三方解码器](https://github.com/awni/speech/blob/072bcf9ff510d814fbfcaad43b2883ecf8f60806/speech/models/ctc_decoder.py)进行解码。"
+ "> 飞桨2.0 CTC Decoder 相关API正在迁移中,本节暂时使用简易版解码器。"
],
"cell_type": "markdown",
"metadata": {
@@ -533,7 +534,22 @@
}
],
"source": [
- "from ctc import decode\n",
+ "# 编写简易版解码器\n",
+ "def ctc_decode(text, blank=10):\n",
+ "    \"\"\"\n",
+ "    Simple greedy CTC decoder: collapse consecutive repeats, then drop blanks.\n",
+ "    :param text: iterable of predicted class indices, one per time step\n",
+ "    :param blank: index of the CTC blank (separator) class\n",
+ "    :return: list of decoded indices, in order\n",
+ "    \"\"\"\n",
+ "    result = []\n",
+ "    cache_idx = -1  # previous step's index; sentinel assumes real indices are non-negative\n",
+ "    for char in text:\n",
+ "        if char != blank and char != cache_idx:\n",
+ "            result.append(char)\n",
+ "        cache_idx = char  # updated on every step (blanks too), so 1, blank, 1 decodes to [1, 1]\n",
+ "    return result\n",
+ "\n",
"\n",
"# 实例化预测模型\n",
"model = paddle.Model(Net(is_infer=True), inputs=input_define)\n",
@@ -547,10 +563,10 @@
"img_names = infer_reader.get_names()\n",
"results = model.predict(infer_reader, batch_size=BATCH_SIZE)\n",
"index = 0\n",
- "for result in results[0]:\n",
- " for prob in result:\n",
- " out, _ = decode(prob, blank=10)\n",
- " print(f\"文件名:{img_names[index]},预测结果为:{out}\")\n",
+ "for text_batch in results[0]:\n",
+ " for prob in text_batch:\n",
+ " out = ctc_decode(prob, blank=10)\n",
+ " print(f\"文件名:{img_names[index]},推理结果为:{out}\")\n",
" index += 1"
]
}
diff --git a/paddle2.0_docs/image_ocr/sample_img/9450.jpg b/paddle2.0_docs/image_ocr/sample_img/9450.jpg
new file mode 100644
index 00000000..028273be
Binary files /dev/null and b/paddle2.0_docs/image_ocr/sample_img/9450.jpg differ
diff --git a/paddle2.0_docs/image_ocr/sample_img/9451.jpg b/paddle2.0_docs/image_ocr/sample_img/9451.jpg
new file mode 100644
index 00000000..1fbea8ae
Binary files /dev/null and b/paddle2.0_docs/image_ocr/sample_img/9451.jpg differ
diff --git a/paddle2.0_docs/image_ocr/sample_img/9452.jpg b/paddle2.0_docs/image_ocr/sample_img/9452.jpg
new file mode 100644
index 00000000..ff4fdd93
Binary files /dev/null and b/paddle2.0_docs/image_ocr/sample_img/9452.jpg differ