diff --git a/paddle2.0_docs/image_ocr/OCR.ipynb b/paddle2.0_docs/image_ocr/OCR.ipynb index 370ce582..db68b55e 100644 --- a/paddle2.0_docs/image_ocr/OCR.ipynb +++ b/paddle2.0_docs/image_ocr/OCR.ipynb @@ -22,7 +22,9 @@ "**数据展示**\n", "

\n", "
\n", - "

" + "

\n", + "\n", + "点此[快速获取本节数据集](https://aistudio.baidu.com/aistudio/datasetdetail/57285),待数据集下载完毕后可使用`!unzip OCR_Dataset.zip -d data/`命令或熟悉的解压软件进行解压,待数据准备工作完成后修改本文“训练准备”中的`DATA_PATH = 解压后数据集路径`。" ], "cell_type": "markdown", "metadata": {} @@ -126,16 +128,13 @@ "\n", "CTC相关论文:[Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neu](http://people.idsia.ch/~santiago/papers/icml2006.pdf) \n", "

\n", - "
\n", + "
\n", "

\n", "\n", - "网络部分,因本篇采用数据集较为简单且图像尺寸较小并不适合较深层次网络。若在对尺寸较大的图像进行模型构建,可以考虑使用更深层次网络/注意力机制来完成。当然也可以通过目标检测形式先检出文本位置,然后进行OCR部分模型构建。\n", + "网络部分,因本篇采用数据集较为简单且图像尺寸较小并不适合较深层次网络。若在对尺寸较大的图像进行模型构建,可以考虑使用更深层次网络/注意力机制来完成。当然也可以通过目标检测形式先检出文本位置,然后进行OCR部分模型构建。(下方样例来源[PaddleOCR](v))\n", "\n", "

\n", - "
\n", - " \n", - " \n", - "PaddleOCR效果图\n", + "
\n", "

" ] }, @@ -164,16 +163,16 @@ " self.is_infer = is_infer\n", "\n", " # 定义一层3x3卷积+BatchNorm\n", - " self.conv1 = paddle.nn.Conv2d(in_channels=IMAGE_SHAPE_C,\n", + " self.conv1 = paddle.nn.Conv2D(in_channels=IMAGE_SHAPE_C,\n", " out_channels=32,\n", " kernel_size=3)\n", - " self.bn1 = paddle.nn.BatchNorm2d(32)\n", + " self.bn1 = paddle.nn.BatchNorm2D(32)\n", " # 定义一层步长为2的3x3卷积进行下采样+BatchNorm\n", - " self.conv2 = paddle.nn.Conv2d(in_channels=32,\n", + " self.conv2 = paddle.nn.Conv2D(in_channels=32,\n", " out_channels=64,\n", " kernel_size=3,\n", " stride=2)\n", - " self.bn2 = paddle.nn.BatchNorm2d(64)\n", + " self.bn2 = paddle.nn.BatchNorm2D(64)\n", " # 定义一层1x1卷积压缩通道数,输出通道数设置为比LABEL_MAX_LEN稍大的定值可获取更优效果,当然也可设置为LABEL_MAX_LEN\n", " self.conv3 = paddle.nn.Conv2d(in_channels=64,\n", " out_channels=LABEL_MAX_LEN + 4,\n", @@ -215,6 +214,8 @@ " if self.is_infer:\n", " # 输出层 - Shape = (Batch Size, Max label len, Prob) \n", " x = paddle.nn.functional.softmax(x)\n", + " # 转换为标签\n", + " x = paddle.tensor.argmax(x, axis=-1)\n", " return x" ] }, @@ -286,8 +287,8 @@ " super().__init__()\n", "\n", " def forward(self, ipt, label):\n", - " input_lengths = paddle.tensor.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN + 4)\n", - " label_lengths = paddle.tensor.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN)\n", + " input_lengths = paddle.tensor.creation.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN + 4)\n", + " label_lengths = paddle.tensor.creation.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN)\n", " # 按文档要求进行转换dim顺序\n", " ipt = paddle.tensor.transpose(ipt, [1, 0, 2])\n", " # 计算loss\n", @@ -452,7 +453,7 @@ }, "outputs": [], "source": [ - "# 待预测目录\n", + "# 待预测目录 - 可在测试数据集中挑出\b3张图像放在该目录中进行推理\n", "INFER_DATA_PATH = \"./sample_img\"\n", "# 训练后存档点路径 - 10代表使用第10个存档点\n", "CHECKPOINT_PATH = \"./output/10\"\n", @@ -505,7 +506,7 @@ { "source": [ "## 开始预测\n", - "> 飞桨2.0 CTC Decoder 相关API正在迁移中,本节暂时使用[第三方解码器](https://github.com/awni/speech/blob/072bcf9ff510d814fbfcaad43b2883ecf8f60806/speech/models/ctc_decoder.py)进行解码。" + "> 飞桨2.0 CTC Decoder 相关API正在迁移中,本节暂时使用简易版解码器。" ], "cell_type": "markdown", "metadata": { @@ -533,7 +534,22 @@ } ], "source": [ - "from ctc import decode\n", + "# 编写简易版解码器\n", + "def ctc_decode(text, blank=10):\n", + " \"\"\"\n", + " 简易CTC解码器\n", + " :param text: 待解码数据\n", + " :param blank: 分隔符索引值\n", + " :return: 解码后数据\n", + " \"\"\"\n", + " result = []\n", + " cache_idx = -1\n", + " for char in text:\n", + " if char != blank and char != cache_idx:\n", + " result.append(char)\n", + " cache_idx = char\n", + " return result\n", + "\n", "\n", "# 实例化预测模型\n", "model = paddle.Model(Net(is_infer=True), inputs=input_define)\n", @@ -547,10 +563,10 @@ "img_names = infer_reader.get_names()\n", "results = model.predict(infer_reader, batch_size=BATCH_SIZE)\n", "index = 0\n", - "for result in results[0]:\n", - " for prob in result:\n", - " out, _ = decode(prob, blank=10)\n", - " print(f\"文件名:{img_names[index]},预测结果为:{out}\")\n", + "for text_batch in results[0]:\n", + " for prob in text_batch:\n", + " out = ctc_decode(prob, blank=10)\n", + " print(f\"文件名:{img_names[index]},推理结果为:{out}\")\n", " index += 1" ] } diff --git a/paddle2.0_docs/image_ocr/sample_img/9450.jpg b/paddle2.0_docs/image_ocr/sample_img/9450.jpg new file mode 100644 index 00000000..028273be Binary files /dev/null and b/paddle2.0_docs/image_ocr/sample_img/9450.jpg differ diff --git a/paddle2.0_docs/image_ocr/sample_img/9451.jpg b/paddle2.0_docs/image_ocr/sample_img/9451.jpg new file mode 100644 index 00000000..1fbea8ae Binary files /dev/null and b/paddle2.0_docs/image_ocr/sample_img/9451.jpg differ diff --git a/paddle2.0_docs/image_ocr/sample_img/9452.jpg b/paddle2.0_docs/image_ocr/sample_img/9452.jpg new file mode 100644 index 00000000..ff4fdd93 Binary files /dev/null and b/paddle2.0_docs/image_ocr/sample_img/9452.jpg differ