diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/PaddleOCR_VL_SFT.md b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/PaddleOCR_VL_SFT.md new file mode 100644 index 000000000..32ecbed63 --- /dev/null +++ b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/PaddleOCR_VL_SFT.md @@ -0,0 +1,909 @@ +# Fine-tuning PaddleOCR-VL with New Approaches -- Prompt and Information Extraction + +> AI Studio Project Address: [Fine-tuning PaddleOCR-VL with New Approaches -- Prompt and Information Extraction](https://aistudio.baidu.com/projectdetail/9857242), which can be run directly in AI Studio's A100 environment (V100 environment can only perform model inference, not fine-tuning) + +## Introduction + +When using PaddleOCR-VL, you would use code like the following: + +```python +CHOSEN_TASK = "ocr" # Options: 'ocr' | 'table' | 'chart' | 'formula' +PROMPTS = { + "ocr": "OCR:", + "table": "Table Recognition:", + "formula": "Formula Recognition:", + "chart": "Chart Recognition:", +} +``` + +Therefore, PaddleOCR-VL can recognize text, formulas, tables, and chart elements. + +PaddleOCR-VL, as a Vision-Language Model (VLM) specifically designed for document understanding, accomplishes different tasks through `prompts`. + +Currently, the fine-tuning of PaddleOCR-VL [PaddleOCR-VL-0.9B SFT](https://github.com/PaddlePaddle/ERNIE/blob/release/v1.4/docs/paddleocr_vl_sft_zh.md) also revolves around these four types of tasks. + +This article starts with fine-tuning the `prompts` of PaddleOCR-VL and introduces how to use fine-tuned PaddleOCR-VL for `information extraction`. + +### Fine-tuning Results Comparison + +Here's an example of recognizing and extracting information from an invoice: + +**Before Fine-tuning** + +![raw](images/raw.png) + +
+ +Click to view the raw data + +```json + +{ + "res": { + "input_path": "/home/aistudio/paddleocr_vl/data/test.jpg", + "page_index": None, + "model_settings": { + "use_doc_preprocessor": False, + "use_layout_detection": False, + "use_chart_recognition": False, + "format_block_content": False + }, + "parsing_res_list": [{ + "block_label": "ocr", + "block_content": "购买方信息 | 名称 | 中青旅联科 | 杭州 | 公关顾问有限公司 | 销售方信息 | 名称 | 杭州万力酒店管理有限公司 | 统一社会信用代码/纳税人识别号 | 纳税人识别号 | 统一社会信用代码/纳税人识别号 | 税额 | 税额/征收率 | 税额/征收率\n**项目名称** | 规格型号 | | | | | | | | | | | | \n**住宿服务** | 住宿费 | | | | | | | | | | | | \n**合计** | | | | | | | | | | | | | \n**价税合计(大写)** | | 壹仟叁佰玖拾柒圆整 | | | | | | | | | | | \n备注 | 销售方地址:浙江省杭州市西湖区转塘街道霞鸣街199号万美商务中心3号楼;电话:0571-85220222;销方开户银行:农行上泗支行;入住人:柳顺;入住日期:9月23日入住-9月26日退房;入住天数:3天;金额:1397元 | | | | | | | | | | | | | \n开票人:祝营营", + "block_bbox": [0, 0, 1260, 838] + }] + } +} + +``` + +
+ +**After Fine-tuning** + +![after](images/sft.png) + +
+ +Click to view the fine-tuning data + +```json +{ + "res": { + "input_path": "/home/aistudio/paddleocr_vl/data/test.jpg", + "page_index": None, + "model_settings": { + "use_doc_preprocessor": False, + "use_layout_detection": False, + "use_chart_recognition": False, + "format_block_content": False + }, + "parsing_res_list": [{ + "block_label": "OCR:{}", + "block_content": "{"发票信息": {"发票名称": "电子发票", "发票号码": "25332000000426443187", "开票日期": "2025年09月26日"}, "销售方信息": {"名称": "杭州万力酒店管理有限公司", "统一社会信用代码": "91330105MA2H2DUJ92", "纳税人识别号": "91330106MA2B1C4UXN"}, "项目名称": "规格型号", "单位": "个", "数量": "3 461.056105610561", "单价": "1383.17", "金额": "税率/征收率", "税额": "13.83"}, "合计": {"金额": "1383.17", "税额": "13.83"}, "价税合计(大写)": "壹仟叁佰玖拾柒圆整", "价税合计(小写)": "1397.00"}, "销售方地址": "浙江省杭州市西湖区转塘街道霞鸣街199号万美商务中心3号楼", "电话": "0571-85220222", "销方开户银行": "农行上泗支行", "入住人": "柳顺", "入住日期": "9月23日 入住-9月26日 退房", "入住天数": "3天", "金额": "1397元"},", + "block_bbox": [0, 0, 1260, 838] + }] + } +} + +``` + +Or specify fields to extract specific information: + +``` json +{ + "res": { + "input_path": "/home/aistudio/paddleocr_vl/data/test.jpg", + "page_index": None, + "model_settings": { + "use_doc_preprocessor": False, + "use_layout_detection": False, + "use_chart_recognition": False, + "format_block_content": False + }, + "parsing_res_list": [{ + "block_label": "OCR:{"发票号码": "", "开票日期": ""}", + "block_content": "{"发票号码": "25332000000426443187", "开票日期": "2025年09月26日"}", + "block_bbox": [0, 0, 1260, 838] + }] + } +} + +``` + +
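The `block_content` strings in the outputs above are JSON-like but not always strictly valid JSON (note the unescaped inner quotes and the stray trailing comma). A minimal parsing sketch — the `parse_block_content` helper is illustrative, and the third-party `json_repair` package used as a fallback is an assumption, not part of PaddleOCR-VL:

```python
import json

def parse_block_content(block_content: str):
    """Parse a model-emitted block_content string into a Python object.

    The fine-tuned model's output is JSON-like but may contain unescaped
    quotes or trailing fragments, so fall back to a tolerant parser when
    strict parsing fails. json_repair is a third-party package
    (pip install json-repair) assumed here for illustration only.
    """
    try:
        return json.loads(block_content)
    except json.JSONDecodeError:
        # Tolerant fallback for malformed model output.
        from json_repair import loads as tolerant_loads
        return tolerant_loads(block_content)

fields = parse_block_content('{"发票号码": "25332000000426443187", "开票日期": "2025年09月26日"}')
print(fields["发票号码"])  # 25332000000426443187
```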
+ +After fine-tuning, it can output data in `JSON` format and can output corresponding information based on different `prompts` (here `block_label`). + +> Due to the small amount of data used in this fine-tuning, the fine-tuning results are not very good. This is for reference only. + +Regarding the fine-tuning of PaddleOCR-VL, [PaddleOCR-VL-0.9B SFT](https://github.com/PaddlePaddle/ERNIE/blob/release/v1.4/docs/paddleocr_vl_sft_zh.md) already has a very detailed introduction. Since this article's fine-tuning targets `prompts`, the following two parts are slightly different from the original text: + +- Data preparation +- Model inference + +## Data Preparation + +When fine-tuning PaddleOCR-VL with ERNIE, you need to prepare data in `JSON` format along with corresponding image data: + +```json +{ + "image_info": [ + {"matched_text_index": 0, "image_url": "./assets/table_example.jps"}, + ], + "text_info": [ + {"text": "OCR:", "tag": "mask"}, + {"text": "Some text content here", "tag": "no_mask"}, + ] +} +``` + +Where: + +- `image_url` is the image path +- `text_info` with `tag` as `mask` corresponds to the `prompt` part, which is the `TASK` type of PaddleOCR-VL +- `text_info` with `tag` as `no_mask` corresponds to the `completion` part, which is the model's output + +The original model only has these four types of `prompt`: + +```json +{ + "ocr": "OCR:", + "table": "Table Recognition:", + "formula": "Formula Recognition:", + "chart": "Chart Recognition:", +} +``` + +However, we want the model to extract information according to our custom instructions, so we need to define custom `prompt`: + +```json +{ + "image_info": [ + { + "matched_text_index": 0, + "image_url": "/home/aistudio/paddleocr_vl/data/zzsptfp/zzsptfp/b175.jpg" + } + ], + "text_info": [ + { + "text": "OCR:{\"invoice_number\": \"\"}", + "tag": "mask" + }, + { + "text": "{\"invoice_number\": \"25332000000426443187\"}", + "tag": "no_mask" + } + ] +} +``` + +Here, the `text` in `mask` is not just `OCR:` but 
`OCR:{\"invoice_number\": \"\"}`, meaning we want the model to extract and output only the `invoice_number` field. + +We retain the original `OCR:` part to ensure the model can recognize it, while fine-tuning only the `{\"invoice_number\": \"\"}` part. + +The `text` part in `no_mask` directly outputs data in `JSON` format, corresponding to the `prompt`. + +Finally, we design the `prompt` as follows: + +``` text +# When specific value is a string, e.g., `{"invoice_code":"123456"}` +"OCR:{\"xxx\":\"\"}" + +# When specific value is a dictionary, e.g., `{"buyer":{"name":"Company A"}}` +"OCR:{\"xxx\":{}}" + +# When specific value is a list, e.g., `{"items":[{"name":"Product A"},{"name":"Product B"}]}` +"OCR:{\"xxx\":[]}" +``` + +For details on how to construct the dataset, refer to the appendix section below. + +## Model Fine-tuning + +The fine-tuning process is similar to this. First, install ERNIE: + +```bash +cd paddleocr_vl +git clone https://gitee.com/PaddlePaddle/ERNIE -b release/v1.4 +cd ERNIE +python -m pip install -r requirements/gpu/requirements.txt +python -m pip install -e . +python -m pip install tensorboard +python -m pip install opencv-python-headless +python -m pip install numpy==1.26.4 +``` + +Then, modify the configuration file and copy it to overwrite the original configuration file: + +```bash +cp paddleocr_vl/sft_config/run_ocr_vl_sft_16k.yaml \ + paddleocr_vl/ERNIE/examples/configs/PaddleOCR-VL/sft/run_ocr_vl_sft_16k.yaml +``` + +Download the PaddleOCR-VL model, here using modelscope's SDK: + +```bash +pip install modelscope +``` + +```python +from modelscope import snapshot_download +model_dir = snapshot_download('PaddlePaddle/PaddleOCR-VL', local_dir='paddleocr_vl/paddleocr_vl_model') +``` + +Finally, execute the fine-tuning command. Fine-tuning in AI Studio's A100 environment takes less than 1.5 hours. 
+ +> V100 environment cannot perform fine-tuning but can perform model inference + +```bash +cd paddleocr_vl/ERNIE; CUDA_VISIBLE_DEVICES=0 \ + erniekit train examples/configs/PaddleOCR-VL/sft/run_ocr_vl_sft_16k.yaml +``` + +Here are the training logs: + +![logs](images/logs.png) + +As you can see, `loss` is steadily decreasing, indicating that the fine-tuning should be effective. + +## Model Inference + +After fine-tuning is completed, you can use the fine-tuned model for inference. The model can: + +1. Output complete information in `JSON` format +2. Output corresponding `JSON` format information based on different input fields + +This provides a flexible interface for information extraction tasks. + +Follow [PaddleOCR-VL-0.9B SFT](https://github.com/PaddlePaddle/ERNIE/blob/release/v1.4/docs/paddleocr_vl_sft_zh.md) for inference. First, you need to install the necessary environment + +```bash +python -m pip install -U "paddleocr[doc-parser]" +python -m pip install https://paddle-whl.bj.bcebos.com/nightly/cu126/safetensors/safetensors-0.6.2.dev0-cp38-abi3-linux_x86_64.whl +python -m pip install --force-reinstall opencv-python-headless +python -m pip install numpy==1.26.4 +``` + +At this point, you still cannot directly perform model inference because, in PaddleX, which PaddleOCR depends on, PaddleOCR-VL currently only supports these four types of `prompt_label`: `['ocr', 'formula', 'table', 'chart']`, and our `prompt` obviously cannot pass the code validation: + +Refer to the `paddlex/inference/pipelines/paddleocr_vl/pipeline.py` file + +``` python +assert prompt_label.lower() in [ + "ocr", + "formula", + "table", + "chart", +], f"Layout detection is disabled (use_layout_detection=False). 'prompt_label' must be one of ['ocr', 'formula', 'table', 'chart'], but got '{prompt_label}'." 
+ +``` + +Here is a patch script that can bypass the above restriction: + +```bash +python paddleocr_vl/patch/patch_assert_to_warning.py +``` + +Then, copy the following files to the PaddleOCR-VL-SFT directory, and you can happily perform inference verification. + +```bash +cp paddleocr_vl/paddleocr_vl_model/chat_template.jinja paddleocr_vl/PaddleOCR-VL-SFT +cp paddleocr_vl/paddleocr_vl_model/inference.yml paddleocr_vl/PaddleOCR-VL-SFT +``` + +Here, a new invoice data is used to verify the model. + +```bash +python -m paddleocr doc_parser -i paddleocr_vl/data/test.jpg \ + --vl_rec_model_name "PaddleOCR-VL-0.9B" \ + --vl_rec_model_dir "paddleocr_vl/PaddleOCR-VL-SFT" \ + --save_path="paddleocr_vl/PaddleOCR-VL-SFT_response" \ + --use_layout_detection=False \ + --prompt_label="OCR:{}" +``` + +Output complete information: + +```json +{ + "res": { + "input_path": "/home/aistudio/paddleocr_vl/data/test.jpg", + "page_index": None, + "model_settings": { + "use_doc_preprocessor": False, + "use_layout_detection": False, + "use_chart_recognition": False, + "format_block_content": False + }, + "parsing_res_list": [{ + "block_label": "OCR:{}", + "block_content": "{ + "发票信息": { + "发票名称": "电子发票", + "发票号码": "25332000000426443187", + "开票日期": "2025年09月26日" + }, + "销售方信息": { + "名称": "杭州万力酒店管理有限公司", + "统一社会信用代码": "91330105MA2H2DUJ92", + "纳税人识别号": "91330106MA2B1C4UXN" + }, + "项目名称": "规格型号", + "单位": "个", + "数量": "3 461.056105610561", + "单价": "1383.17", + "金额": "税率/征收率", + "税额": "13.83" + }, "合计": { + "金额": "1383.17", + "税额": "13.83" + }, "价税合计(大写)": "壹仟叁佰玖拾柒圆整", "价税合计(小写)": "1397.00" + }, "销售方地址": "浙江省杭州市西湖区转塘街道霞鸣街199号万美商务中心3号楼", "电话": "0571-85220222", "销方开户银行": "农行上泗支行", "入住人": "柳顺", "入住日期": "9月23日 入住-9月26日 退房", "入住天数": "3天", "金额": "1397元" + }", + "block_bbox": [0, 0, 1260, 838] + }] + } +} +``` + +Note two points: + +- `use_layout_detection=False`, not through the layout model, but directly sending the image to `PaddleOCR-VL-0.9B` +- `prompt_label="OCR:{}"`, here we use our fine-tuned 
`prompt`, hoping the model outputs complete json format information + +> Note, the data finally output by the model is actually incomplete, for example, missing `购买方` (Buyer) information, which should be caused by the small amount of fine-tuning data. + +Now let's look at the model before fine-tuning, which can only output table-style data: + +```bash +python -m paddleocr doc_parser -i /home/aistudio/paddleocr_vl/data/test.jpg \ + --vl_rec_model_name "PaddleOCR-VL-0.9B" \ + --vl_rec_model_dir "/home/aistudio/paddleocr_vl/paddleocr_vl_model" \ + --save_path="/home/aistudio/paddleocr_vl/paddleocr_vl_model_response" \ + --use_layout_detection=False \ + --prompt_label="ocr" +``` + +Output: + +```json +{ + "res": { + "input_path": "/home/aistudio/paddleocr_vl/data/test.jpg", + "page_index": None, + "model_settings": { + "use_doc_preprocessor": False, + "use_layout_detection": False, + "use_chart_recognition": False, + "format_block_content": False + }, + "parsing_res_list": [{ + "block_label": "ocr", + "block_content": "购买方信息 | 名称 | 中青旅联科 | 杭州 | 公关顾问有限公司 | 销售方信息 | 名称 | 杭州万力酒店管理有限公司 | 统一社会信用代码/纳税人识别号 | 纳税人识别号 | 统一社会信用代码/纳税人识别号 | 税额 | 税额/征收率 | 税额/征收率\n**项目名称** | 规格型号 | | | | | | | | | | | | \n**住宿服务** | 住宿费 | | | | | | | | | | | | \n**合计** | | | | | | | | | | | | | \n**价税合计(大写)** | | 壹仟叁佰玖拾柒圆整 | | | | | | | | | | | \n备注 | 销售方地址:浙江省杭州市西湖区转塘街道霞鸣街199号万美商务中心3号楼;电话:0571-85220222;销方开户银行:农行上泗支行;入住人:柳顺;入住日期:9月23日入住-9月26日退房;入住天数:3天;金额:1397元 | | | | | | | | | | | | | \n开票人:祝营营", + "block_bbox": [0, 0, 1260, 838] + }] + } +} + + +``` + +Then, let's test extracting only partial information: + +```bash +python -m paddleocr doc_parser -i /home/aistudio/paddleocr_vl/data/test.jpg \ + --vl_rec_model_name "PaddleOCR-VL-0.9B" \ + --vl_rec_model_dir "/home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT" \ + --save_path="/home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT_response" \ + --use_layout_detection=False \ + --prompt_label="OCR:{\"购买方名称\": {}, \"销售方名称\": {}}" +``` + +Output: + +```json +{ + 
"res": { + "input_path": "/home/aistudio/paddleocr_vl/data/test.jpg", + "page_index": None, + "model_settings": { + "use_doc_preprocessor": False, + "use_layout_detection": False, + "use_chart_recognition": False, + "format_block_content": False + }, + "parsing_res_list": [{ + "block_label": "OCR:{"购买方名称": {}, "销售方名称": {}}", + "block_content": "{ + "购买方名称": { + "名称": "中青旅联科(杭州)公关顾问有限公司", + "统一社会信用代码": "91330105MA2H2DUJ92" + }, + "销售方名称": { + "名称": "杭州万力酒店管理有限公司", + "统一社会信用代码": "91330106MA2B1C4UXN" + } + }", + "block_bbox": [0, 0, 1260, 838] + }] + } +} +``` + +As you can see, the model can basically follow our instructions to extract corresponding information. + +## Using transformers Library for Information Extraction + +You can use the transformers library for information extraction, referring to [[Model] Add PaddleOCR-VL Model Support by zhang-prog](https://github.com/huggingface/transformers/pull/42178) + +> Note: Currently, the generated model directory after fine-tuning has not been synchronized for updates. When using the transformers library for information extraction, you need to first download the latest model from [huggingface](https://huggingface.co/PaddlePaddle/PaddleOCR-VL/tree/main), then rename the fine-tuned model file `model-00001-of-00001.safetensors` to `model.safetensors` and place it (overwriting) in the downloaded model directory. 
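The weight-swap step described in the note can be sketched as follows. The directory paths are placeholders for your local layout, and the single-shard filename `model-00001-of-00001.safetensors` comes from the note above:

```python
import shutil
from pathlib import Path

def install_finetuned_weights(sft_output_dir: str, hf_model_dir: str) -> Path:
    """Overwrite the downloaded model.safetensors with the fine-tuned shard.

    Assumes the SFT output contains a single shard named
    model-00001-of-00001.safetensors, as described in the note above;
    both directory arguments are illustrative placeholders.
    """
    src = Path(sft_output_dir) / "model-00001-of-00001.safetensors"
    dst = Path(hf_model_dir) / "model.safetensors"
    shutil.copyfile(src, dst)  # replace the stock weights in place
    return dst
```

If the fine-tuned checkpoint were sharded into multiple files, the index JSON would also need updating; the single-file case above matches the note.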

```python
from transformers import pipeline

pipe = pipeline(
    "image-text-to-text",
    model="./PaddleOCR_VL_SFT/PaddleOCR-VL",  # downloaded model directory
    dtype="bfloat16")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://ai-studio-static-online.cdn.bcebos.com/dc31c334d4664ca4955aa47d8e202a53a276fd0aab0840b09abe953fe51207d0"},
            {"type": "text", "text": "OCR:{}"},
        ]
    }
]
result = pipe(text=messages)
print(result)
```

If GPU memory is insufficient, you can try the following quantization method:

```python
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
import torch

path = "./PaddleOCR_VL_SFT/PaddleOCR-VL"  # downloaded model directory
processor = AutoProcessor.from_pretrained(path, local_files_only=True, use_fast=True)

# 4-bit quantization configuration to significantly reduce GPU memory usage
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)
model = AutoModelForImageTextToText.from_pretrained(
    path,
    quantization_config=quantization_config,
    # device_map="auto",
    local_files_only=True
)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://ai-studio-static-online.cdn.bcebos.com/dc31c334d4664ca4955aa47d8e202a53a276fd0aab0840b09abe953fe51207d0"},
            {"type": "text", "text": "OCR:{\"Invoice Date\": \"\"}"},
        ]
    }
]
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=100)
result = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:-1])
print(result)
```

## Using PaddleOCR-VL-REC for Information Extraction

You can use [PaddleOCR-VL-REC](https://github.com/megemini/PaddleOCR-VL-REC) for information extraction:

```python

from paddleocr_vl_rec import PaddleOCRVLRec

# Initialize the recognizer
recognizer = PaddleOCRVLRec(
    model_dir="path/to/your/model"
)

# Use dict as query (will be converted to JSON string)
# Returns JSON format (parse results using json_repair)
result_json = recognizer.predict(
    image="/path/to/your/image.jpg",
    query={"NAME":"", "ITEMS":[]},
    return_json=True
)
# result_json is a dictionary object
print(type(result_json))  # <class 'dict'>
print(result_json)

# Use list as query (will be converted to {"item1":"", "item2":""} format)
result_json = recognizer.predict(
    image="/path/to/your/image.jpg",
    query=["item1", "item2"],
    return_json=True
)
print(result_json)

recognizer.close()

```

## Summary

This article introduces how to implement information extraction tasks by fine-tuning the prompts of PaddleOCR-VL. The main methods include:

1. **Data Preparation**: Using VLM models to generate structured training data, which is more efficient than traditional manual annotation.
2. **Prompt Design**: Through carefully designed prompt templates, the model can flexibly output `JSON` format information for different fields.
3. **Model Fine-tuning**: Utilizing PaddleOCR-VL's fine-tuning capability to make it learn to generate corresponding outputs based on different prompts.

Compared to traditional information extraction methods (such as NER + relation extraction), this method has better integration and flexibility.

## Appendix

### 1. Dataset

There are many application scenarios for information extraction. Here, we use [VAT Ordinary Invoice](https://aistudio.baidu.com/datasetdetail/125158) data as an example.

> You can refer to the article [Invoice Key Information Extraction Based on VI-LayoutXLM](https://bbs.huaweicloud.com/blogs/383854), which provides a relatively complete explanation of fine-tuning PaddleOCR models for information extraction.

However, the dataset's annotation for `Relation Extraction` is quite crude.
For example: + +![VAT Ordinary Invoice](images/re.jpg) + +Here only `名称` (Name) is annotated, without specifying whether it's `购买方名称` (Buyer Name) or `销售方名称` (Seller Name). + +As mentioned earlier, we can use PaddleOCR-VL as a VLM model. Therefore, we can let a more capable VLM model `teach` PaddleOCR-VL to recognize `购买方名称` (Buyer Name) and `销售方名称` (Seller Name). + +Data can be generated through the `ernie-4.5-turbo-vl-preview` model, referring to the script `paddleocr_vl/tools/extract_ner/extract_ner.py`. + +``` python + +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Multimodal Image Recognition Script +Recognizes image information by calling OpenAI API and returns JSON format data +Supports local images and multimodal large model processing +""" +... + +class MultimodalImageRecognizer: + """Multimodal Image Recognizer""" + ... + + def recognize_image( + self, + image_input: Union[str, bytes], + prompt: str, + system_prompt: str, + max_tokens: int = 2048 + ) -> Dict[str, Any]: + """ + Recognize image information + + Args: + image_input: Image path, URL, or base64 encoding + prompt: User prompt + system_prompt: System prompt + max_tokens: Maximum number of tokens + + Returns: + JSON format data of recognition results + """ + try: + # Create multimodal message + content = self.create_multimodal_message(prompt, image_input) + + # Build message list + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": content} + ] + + logger.info(f"Starting API call to recognize image, model: {self.model}") + + # Call API + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + max_tokens=max_tokens, + temperature=0.2 + ) + + ... 
+ + def analyze_image( + self, + image_input: Union[str, bytes], + analysis_type: str = "document" + ) -> Dict[str, Any]: + """ + Analyze image content (simplified version) + + Args: + image_input: Image path, URL, or base64 encoding + analysis_type: Analysis type, fixed as "document" + + Returns: + JSON format data of analysis results + """ + # Use document analysis prompt uniformly + prompt = "请分析这张文档图片中的所有信息,并返回完整的JSON格式数据。如果有的字段没有值,那么保留此字段,值为空。注意:所有的值都以string的形式返回,不要使用数字类型等。" + system_prompt = ''' +你是一个专业的文档分析助手,能够准确分析文档内容并返回结构化的JSON数据。 + +注意:数据的语言与文档的语言保持一致。 +注意:需要保留完整的字段层级关系,不要把所有字段都放到一级字段中。 +注意:JSON数据中不要包含注释,也不需要任何解释或说明。 +注意:对于特殊字符需要进行转义。 + +注意:对于选项字段,只保留所选择的字段值,如果没有选择,则置为空。 +比如,`业务类型` 包括 `账户开户、账户登记` 等选项,文档中`账户登记`是选中状态,则,返回 `{"业务类型":"账户登记"}`,不返回`账户开户`等其他选项。 +再比如,`业务类型` 包括 `账户开户、账户登记` 等选项,文档中没有标记选中的选项,则,返回 `{"业务类型":""}`,也就是说,只保留键,不需要有值。 +... +''' + + return self.recognize_image( + image_input=image_input, + prompt=prompt, + system_prompt=system_prompt + ) + +... +``` + +Use the `paddleocr_vl/tools/extract_ner/batch_extract_ner.py` script to batch generate data. 
The final generated data is as follows: + +``` json + +{ + "image": "/media/shun/bigdata/Dataset/增值税普通发票/zzsptfp/b0.jpg", + "data": { + "发票名称": "广东增值税专用发票", + "发票编号": "12271524", + "发票代码": "4400154130", + "开票日期": "2016年06月12日", + "购买方": { + "名称": "深圳市购机汇网络有限公司", + "纳税人识别号": "440300083885931", + "地址、电话": "深圳市龙华新区民治街道民治大道展滔科技大厦A12070755-23806606", + "开户行及账号": "中国工商银行股份有限公司深圳园岭支行4000024709200172809" + }, + "密码区": "<<139 -<5//81>84974<00+7>2*0*53-+ +125*++9+-///5-7+/-0>8<9815 5<3/8*+//81/84+>6>4*36>4538", + "货物或应税劳务、服务名称": [ + { + "名称": "小米 红米3 全网通版 时尚金色", + "规格型号": "红米3", + "单位": "个", + "数量": "5", + "单价": "597.43589744", + "金额": "2987.18", + "税率": "17%", + "税额": "507.82" + }, + { + "名称": "移动联通电信4G手机 双卡双待", + "规格型号": "", + "单位": "", + "数量": "", + "单价": "", + "金额": "", + "税率": "", + "税额": "" + } + ], + "合计": { + "金额": "¥2987.18", + "税额": "¥507.82" + }, + "价税合计(大写)": "叁仟肆佰玖拾伍圆整", + "价税合计(小写)": "¥3495.00", + "销售方": { + "名称": "广州晶东贸易有限公司", + "纳税人识别号": "91440101664041243T", + "地址、电话": "广州市黄埔区九龙镇九龙工业园凤凰三横路99号 66215500", + "开户行及账号": "工行北京路支行3602000919200384952" + }, + "备注": "dd42982413947(00001,1952)7996有限", + "收款人": "王梅", + "复核": "张雪", + "开票人": "陈秋燕", + "销售方(章)": "广州晶东贸易有限公司 发票专用章" + } +} + +``` + +The data information generated here is much richer than the original annotation information. Although there are some flaws (for example, `货物或应税劳务、服务名称` should only have one record), it does not hinder the fine-tuning experiment. + +> The processed data has been uploaded to [VAT Ordinary Invoice and JSON Format Information](https://aistudio.baidu.com/dataset/detail/363136/intro). + +### 2. 
Prompts + +The goal of the `information extraction` task here is: + +- The model can output complete information in `JSON` format +- The model can output corresponding `JSON` format information based on different input fields + +For the above goals, corresponding prompts are designed here: + +**Complete Information** + +``` +"OCR:{}" +``` + +**Specific Information** + +``` +# Specific value is a string, such as `{"发票编码":"123456"}` +"OCR:{\"xxx\":\"\"}" + +# Specific value is a dictionary, such as `{"购买方":{"名称":"A公司"}}` +"OCR:{\"xxx\":{}}" + +# Specific value is a list, such as `{"货物或应税劳务、服务名称":[{"名称":"A产品"},{"名称":"B产品"}]}` +"OCR:{\"xxx\":[]}" +``` + +You can use `paddleocr_vl/tools/process_ner_dataset.py` to generate complete training data, including randomly generated prompts: + +```bash +python paddleocr_vl/tools/process_ner_dataset.py paddleocr_vl/data/zzsptfp \ + -o paddleocr_vl/output.jsonl \ + -n 10 \ + -p /media/shun/bigdata/Dataset/增值税普通发票 \ + -u /home/aistudio/paddleocr_vl/data/zzsptfp +``` + +Then, split the training dataset and validation dataset: + +```bash +python paddleocr_vl/tools/split_jsonl.py paddleocr_vl/output.jsonl \ + paddleocr_vl/output \ + --train_ratio 0.9 \ + --seed 123 +``` + +The final generated data is as follows: + +```json +{ + "image_info": [ + { + "matched_text_index": 0, + "image_url": "/home/aistudio/paddleocr_vl/data/zzsptfp/zzsptfp/b175.jpg" + } + ], + "text_info": [ + { + "text": "OCR:{\"发票名称\": \"\"}", + "tag": "mask" + }, + { + "text": "{\"发票名称\": \"广东增值税专用发票\"}", + "tag": "no_mask" + } + ] +} +``` + +The differences between the generated training data and [PaddleOCR-VL-0.9B SFT](https://github.com/PaddlePaddle/ERNIE/blob/release/v1.4/docs/paddleocr_vl_sft_zh.md) are: + +- The `text` of `mask` is not just `OCR:`, but also includes the field information to be extracted later +- The `text` of `no_mask` is complete `JSON` format information, not a plain text + +### 3. 
Configuration File Example + +```yaml +### data +train_dataset_type: "erniekit" +eval_dataset_type: "erniekit" +train_dataset_path: "/home/aistudio/paddleocr_vl/output_train.jsonl" +train_dataset_prob: "1.0" +eval_dataset_path: "/home/aistudio/paddleocr_vl/output_val.jsonl" +eval_dataset_prob: "1.0" +max_seq_len: 16384 +num_samples_each_epoch: 6000000 +use_pic_id: False +sft_replace_ids: True +sft_image_normalize: True +sft_image_rescale: True +image_dtype: "float32" + +### model +model_name_or_path: "/home/aistudio/paddleocr_vl/paddleocr_vl_model" +fine_tuning: Full +multimodal: True +use_flash_attention: True +use_sparse_flash_attn: True + +### finetuning +# base +stage: OCR-VL-SFT +seed: 23 +do_train: True +# do_eval: True +distributed_dataloader: False +dataloader_num_workers: 8 +prefetch_factor: 10 +batch_size: 1 +packing_size: 8 +packing: True +padding: False +num_train_epochs: 2 +max_steps: 80 +# eval_batch_size: 1 +# eval_iters: 50 +# eval_steps: 100 +# evaluation_strategy: steps +save_steps: 20 +save_total_limit: 5 +save_strategy: steps +logging_steps: 1 +release_grads: True +gradient_accumulation_steps: 8 +logging_dir: /home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT/tensorboard_logs/ +output_dir: /home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT +disable_tqdm: True + +# train +warmup_steps: 1 +learning_rate: 5.0e-6 +lr_scheduler_type: cosine +min_lr: 5.0e-7 +layerwise_lr_decay_bound: 1.0 +from_scratch: 0 + +# optimizer +weight_decay: 0.1 +adam_epsilon: 1.0e-8 +adam_beta1: 0.9 +adam_beta2: 0.95 + +# performance +tensor_parallel_degree: 1 +pipeline_parallel_degree: 1 +sharding_parallel_degree: 1 +sharding: stage1 +sequence_parallel: False +pipeline_parallel_config: enable_delay_scale_loss enable_release_grads disable_partial_send_recv +recompute: True +recompute_granularity: "full" +recompute_use_reentrant: True +compute_type: bf16 +fp16_opt_level: O2 +disable_ckpt_quant: True +# amp_master_grad: True +amp_custom_white_list: + - lookup_table + - lookup_table_v2 + - 
flash_attn + - matmul + - matmul_v2 + - fused_gemm_epilogue +amp_custom_black_list: + - reduce_sum + - softmax_with_cross_entropy + - c_softmax_with_cross_entropy + - elementwise_div + - sin + - cos +unified_checkpoint: True +# unified_checkpoint_config: async_save +convert_from_hf: True +save_to_hf: True +``` diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/PaddleOCR_VL_SFT_zh.md b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/PaddleOCR_VL_SFT_zh.md new file mode 100644 index 000000000..46054f647 --- /dev/null +++ b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/PaddleOCR_VL_SFT_zh.md @@ -0,0 +1,911 @@ +# 微调 PaddleOCR-VL 新姿势 -- Prompt 与 信息抽取 + +> AI Studio 项目地址:[微调 PaddleOCR-VL 新姿势 -- Prompt 与 信息抽取](https://aistudio.baidu.com/projectdetail/9857242) ,可在 AI Studio 的 A100 环境中直接运行(V100 环境只能进行模型推理,无法进行微调) + +## 引言 + +当使用 PaddleOCR-VL 时,会使用到如下的代码: + +```python +CHOSEN_TASK = "ocr" # Options: 'ocr' | 'table' | 'chart' | 'formula' +PROMPTS = { + "ocr": "OCR:", + "table": "Table Recognition:", + "formula": "Formula Recognition:", + "chart": "Chart Recognition:", +} +``` + +因此,PaddleOCR-VL 可以识别文本、公式、表格和图表元素。 + +PaddleOCR-VL 作为一款专为文档理解设计的视觉-语言模型(Vision-Language Model, VLM),是通过 `提示词` 完成不同任务的。 + +目前对于 PaddleOCR-VL 的微调 [PaddleOCR-VL-0.9B SFT](https://github.com/PaddlePaddle/ERNIE/blob/release/v1.4/docs/paddleocr_vl_sft_zh.md) 也是围绕这四类任务展开的。 + +本文从微调 PaddleOCR-VL 的 `提示词` 入手,介绍如何通过微调 PaddleOCR-VL 用于 `信息抽取`。 + +### 微调结果对比 + +这里以识别与抽取一张发票内的信息为例: + +**微调之前** + +![raw](images/raw.png) + +
+ + 点击查看原始输出 + + +```json + +{ + "res": { + "input_path": "/home/aistudio/paddleocr_vl/data/test.jpg", + "page_index": None, + "model_settings": { + "use_doc_preprocessor": False, + "use_layout_detection": False, + "use_chart_recognition": False, + "format_block_content": False + }, + "parsing_res_list": [{ + "block_label": "ocr", + "block_content": "购买方信息 | 名称 | 中青旅联科 | 杭州 | 公关顾问有限公司 | 销售方信息 | 名称 | 杭州万力酒店管理有限公司 | 统一社会信用代码/纳税人识别号 | 纳税人识别号 | 统一社会信用代码/纳税人识别号 | 税额 | 税额/征收率 | 税额/征收率\n**项目名称** | 规格型号 | | | | | | | | | | | | \n**住宿服务** | 住宿费 | | | | | | | | | | | | \n**合计** | | | | | | | | | | | | | \n**价税合计(大写)** | | 壹仟叁佰玖拾柒圆整 | | | | | | | | | | | \n备注 | 销售方地址:浙江省杭州市西湖区转塘街道霞鸣街199号万美商务中心3号楼;电话:0571-85220222;销方开户银行:农行上泗支行;入住人:柳顺;入住日期:9月23日入住-9月26日退房;入住天数:3天;金额:1397元 | | | | | | | | | | | | | \n开票人:祝营营", + "block_bbox": [0, 0, 1260, 838] + }] + } +} + +``` + +
+ +**微调之后** + +![after](images/sft.png) + +
+ + 点击查看微调之后的输出 + +```json +{ + "res": { + "input_path": "/home/aistudio/paddleocr_vl/data/test.jpg", + "page_index": None, + "model_settings": { + "use_doc_preprocessor": False, + "use_layout_detection": False, + "use_chart_recognition": False, + "format_block_content": False + }, + "parsing_res_list": [{ + "block_label": "OCR:{}", + "block_content": "{"发票信息": {"发票名称": "电子发票", "发票号码": "25332000000426443187", "开票日期": "2025年09月26日"}, "销售方信息": {"名称": "杭州万力酒店管理有限公司", "统一社会信用代码": "91330105MA2H2DUJ92", "纳税人识别号": "91330106MA2B1C4UXN"}, "项目名称": "规格型号", "单位": "个", "数量": "3 461.056105610561", "单价": "1383.17", "金额": "税率/征收率", "税额": "13.83"}, "合计": {"金额": "1383.17", "税额": "13.83"}, "价税合计(大写)": "壹仟叁佰玖拾柒圆整", "价税合计(小写)": "1397.00"}, "销售方地址": "浙江省杭州市西湖区转塘街道霞鸣街199号万美商务中心3号楼", "电话": "0571-85220222", "销方开户银行": "农行上泗支行", "入住人": "柳顺", "入住日期": "9月23日 入住-9月26日 退房", "入住天数": "3天", "金额": "1397元"},", + "block_bbox": [0, 0, 1260, 838] + }] + } +} + +``` + +或者指定字段抽取特定信息: + +``` json +{ + "res": { + "input_path": "/home/aistudio/paddleocr_vl/data/test.jpg", + "page_index": None, + "model_settings": { + "use_doc_preprocessor": False, + "use_layout_detection": False, + "use_chart_recognition": False, + "format_block_content": False + }, + "parsing_res_list": [{ + "block_label": "OCR:{"发票号码": "", "开票日期": ""}", + "block_content": "{"发票号码": "25332000000426443187", "开票日期": "2025年09月26日"}", + "block_bbox": [0, 0, 1260, 838] + }] + } +} + +``` + +
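上面输出中的 `block_content` 是类 JSON 字符串,但并不总是严格合法的 JSON(注意其中未转义的引号和末尾多余的逗号)。下面是一个容错解析的示例草稿,其中 `parse_block_content` 为示意用的假设函数,`json_repair` 为第三方包,并非 PaddleOCR-VL 自带:

```python
import json

def parse_block_content(block_content: str):
    """把模型输出的 block_content 字符串解析为 Python 对象。

    微调后模型的输出是类 JSON 文本,可能包含未转义的引号或多余的
    尾部片段;严格解析失败时退回到容错解析。json_repair 为第三方包
    (pip install json-repair),此处仅作示意。
    """
    try:
        return json.loads(block_content)
    except json.JSONDecodeError:
        # 容错解析,用于处理不严格合法的模型输出
        from json_repair import loads as tolerant_loads
        return tolerant_loads(block_content)

fields = parse_block_content('{"发票号码": "25332000000426443187", "开票日期": "2025年09月26日"}')
print(fields["发票号码"])  # 25332000000426443187
```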
After fine-tuning, the model can output data in `JSON` format, and it can output the corresponding information for different `prompt`s (the `block_label` here).

> Since the amount of fine-tuning data used here is very small, the results are not great; they are for reference only.

[PaddleOCR-VL-0.9B SFT](https://github.com/PaddlePaddle/ERNIE/blob/release/v1.4/docs/paddleocr_vl_sft_zh.md) already covers PaddleOCR-VL fine-tuning in detail. Since this article fine-tunes the `prompt`, the following two parts differ slightly from that guide:

- Data preparation
- Model inference

## Data Preparation

To fine-tune PaddleOCR-VL with ERNIE, you need to prepare data in `JSON` format together with the corresponding images:

```json
{
  "image_info": [
    {"matched_text_index": 0, "image_url": "./assets/table_example.jps"}
  ],
  "text_info": [
    {"text": "OCR:", "tag": "mask"},
    {"text": "দডর মথ বধ বকসট একনজর দখই চনত পরল তর অনমন\nঠক পনতই লকয রখছ\nর নচ থকই চচয বলল কশর, “এইই; পযছ! পযছ!'\nওপর", "tag": "no_mask"}
  ]
}
```

Here,

- `image_url` is the path to the image
- the `text_info` entry whose `tag` is `mask` corresponds to the `prompt` part, i.e. the PaddleOCR-VL `TASK` type
- the `text_info` entry whose `tag` is `no_mask` corresponds to the `completion` part, i.e. the model's output

The original model has only these four kinds of `prompt`:

```json
{
  "ocr": "OCR:",
  "table": "Table Recognition:",
  "formula": "Formula Recognition:",
  "chart": "Chart Recognition:"
}
```

We want the model to extract the corresponding information according to our instructions, so a custom `prompt` is needed:

```json
{
  "image_info": [
    {
      "matched_text_index": 0,
      "image_url": "/home/aistudio/paddleocr_vl/data/zzsptfp/zzsptfp/b175.jpg"
    }
  ],
  "text_info": [
    {
      "text": "OCR:{\"发票名称\": \"\"}",
      "tag": "mask"
    },
    {
      "text": "{\"发票名称\": \"广东增值税专用发票\"}",
      "tag": "no_mask"
    }
  ]
}
```

Here the `text` whose `tag` is `mask` is not `OCR:` but `OCR:{\"发票名称\": \"\"}`; that is, we want the model to extract, and output only, the `发票名称` (invoice name) field.

The original `OCR:` part is kept so that the model still recognizes the `OCR:` task, while only the `{\"发票名称\": \"\"}` part is fine-tuned.

The `text` whose `tag` is `no_mask` is the `JSON`-formatted output, matching the `prompt`.

Finally, the `prompt`s are designed as:

```text
# The value is a string, e.g. `{"发票编码":"123456"}`
"OCR:{\"xxx\":\"\"}"

# The value is a dict, e.g. `{"购买方":{"名称":"A公司"}}`
"OCR:{\"xxx\":{}}"

# The value is a list, e.g. `{"货物或应税劳务、服务名称":[{"名称":"A产品"},{"名称":"B产品"}]}`
"OCR:{\"xxx\":[]}"
```

For details on how the dataset is built, see the appendix.

## Model Fine-tuning

The fine-tuning process is similar to the one in [PaddleOCR-VL-0.9B SFT](https://github.com/PaddlePaddle/ERNIE/blob/release/v1.4/docs/paddleocr_vl_sft_zh.md). First, install ERNIE:

```bash
cd paddleocr_vl
git clone https://gitee.com/PaddlePaddle/ERNIE -b release/v1.4
cd ERNIE
python -m pip install -r requirements/gpu/requirements.txt
python -m pip install -e .
python -m pip install tensorboard
python -m pip install opencv-python-headless
python -m pip install numpy==1.26.4
```

Then, modify the configuration file and copy it over the original one:

```bash
cp paddleocr_vl/sft_config/run_ocr_vl_sft_16k.yaml \
  paddleocr_vl/ERNIE/examples/configs/PaddleOCR-VL/sft/run_ocr_vl_sft_16k.yaml
```

Download the PaddleOCR-VL model, here using the modelscope SDK:

```bash
pip install modelscope
```

```python
from modelscope import snapshot_download
model_dir = snapshot_download('PaddlePaddle/PaddleOCR-VL', local_dir='paddleocr_vl/paddleocr_vl_model')
```

Finally, run the fine-tuning command. In AI Studio's A100 environment, fine-tuning takes a bit under 1.5 hours.

> The V100 environment cannot run the fine-tuning, but it can run model inference.

```bash
cd paddleocr_vl/ERNIE; CUDA_VISIBLE_DEVICES=0 \
  erniekit train examples/configs/PaddleOCR-VL/sft/run_ocr_vl_sft_16k.yaml
```

Here are the training logs:

![logs](images/logs.png)

The `loss` decreases steadily, which suggests the fine-tuning is taking effect.


## Model Inference

After fine-tuning, the model can be used for inference. It can:

1. Output the complete information in `JSON` format
2. Output the `JSON`-formatted information for the specific input fields requested

This provides a flexible interface for information extraction tasks.

To run inference following [PaddleOCR-VL-0.9B SFT](https://github.com/PaddlePaddle/ERNIE/blob/release/v1.4/docs/paddleocr_vl_sft_zh.md), first install the required environment:

```bash
python -m pip install -U "paddleocr[doc-parser]"
python -m pip install https://paddle-whl.bj.bcebos.com/nightly/cu126/safetensors/safetensors-0.6.2.dev0-cp38-abi3-linux_x86_64.whl
python -m pip install --force-reinstall opencv-python-headless
python -m pip install numpy==1.26.4
```

At this point, inference still cannot be run directly: PaddleX, which PaddleOCR depends on, currently only supports the four `prompt_label`s `['ocr', 'formula', 'table', 'chart']` for PaddleOCR-VL, so our custom `prompt` cannot pass the validation code.

See the file `paddlex/inference/pipelines/paddleocr_vl/pipeline.py`:

```python
assert prompt_label.lower() in [
    "ocr",
    "formula",
    "table",
    "chart",
], f"Layout detection is disabled (use_layout_detection=False). 'prompt_label' must be one of ['ocr', 'formula', 'table', 'chart'], but got '{prompt_label}'."
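# After running the patch script introduced below, this assert is rewritten
# into a non-fatal warning, roughly:
#
#   if prompt_label.lower() not in ["ocr", "formula", "table", "chart"]:
#       import warnings
#       warnings.warn(
#           f"Layout detection is disabled (use_layout_detection=False). "
#           f"'prompt_label' must be one of ['ocr', 'formula', 'table', 'chart'], "
#           f"but got '{prompt_label}'. Program will continue anyway.",
#           UserWarning,
#       )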
```

Here is a patch script that bypasses the restriction above:

```bash
python paddleocr_vl/patch/patch_assert_to_warning.py
```

Then copy the following files into the PaddleOCR-VL-SFT directory, and inference can be run against the fine-tuned model:

```bash
cp paddleocr_vl/paddleocr_vl_model/chat_template.jinja paddleocr_vl/PaddleOCR-VL-SFT
cp paddleocr_vl/paddleocr_vl_model/inference.yml paddleocr_vl/PaddleOCR-VL-SFT
```

A new invoice image is used here to validate the model:

```bash
python -m paddleocr doc_parser -i paddleocr_vl/data/test.jpg \
    --vl_rec_model_name "PaddleOCR-VL-0.9B" \
    --vl_rec_model_dir "paddleocr_vl/PaddleOCR-VL-SFT" \
    --save_path="paddleocr_vl/PaddleOCR-VL-SFT_response" \
    --use_layout_detection=False \
    --prompt_label="OCR:{}"
```

The complete information is output:

```json
{
  "res": {
    "input_path": "/home/aistudio/paddleocr_vl/data/test.jpg",
    "page_index": None,
    "model_settings": {
      "use_doc_preprocessor": False,
      "use_layout_detection": False,
      "use_chart_recognition": False,
      "format_block_content": False
    },
    "parsing_res_list": [{
      "block_label": "OCR:{}",
      "block_content": "{
        "发票信息": {
          "发票名称": "电子发票",
          "发票号码": "25332000000426443187",
          "开票日期": "2025年09月26日"
        },
        "销售方信息": {
          "名称": "杭州万力酒店管理有限公司",
          "统一社会信用代码": "91330105MA2H2DUJ92",
          "纳税人识别号": "91330106MA2B1C4UXN"
        },
        "项目名称": "规格型号",
        "单位": "个",
        "数量": "3 461.056105610561",
        "单价": "1383.17",
        "金额": "税率/征收率",
        "税额": "13.83"
      }, "合计": {
        "金额": "1383.17",
        "税额": "13.83"
      }, "价税合计(大写)": "壹仟叁佰玖拾柒圆整", "价税合计(小写)": "1397.00"
      }, "销售方地址": "浙江省杭州市西湖区转塘街道霞鸣街199号万美商务中心3号楼", "电话": "0571-85220222", "销方开户银行": "农行上泗支行", "入住人": "柳顺", "入住日期": "9月23日 入住-9月26日 退房", "入住天数": "3天", "金额": "1397元"
      }",
      "block_bbox": [0, 0, 1260, 838]
    }]
  }
}
```

Note two things:

- `use_layout_detection=False`: the image is fed directly to `PaddleOCR-VL-0.9B` instead of going through the layout model
- `prompt_label="OCR:{}"`: this is our fine-tuned `prompt`, asking the model to output the complete information in JSON format

> Note that the data the model outputs here is actually incomplete; for example, the `购买方` (buyer) information is missing, presumably because the amount of fine-tuning data was small.

For comparison, the model before fine-tuning can only output table-style data:

```bash
python -m paddleocr doc_parser -i /home/aistudio/paddleocr_vl/data/test.jpg \
    --vl_rec_model_name "PaddleOCR-VL-0.9B" \
    --vl_rec_model_dir "/home/aistudio/paddleocr_vl/paddleocr_vl_model" \
    --save_path="/home/aistudio/paddleocr_vl/paddleocr_vl_model_response" \
    --use_layout_detection=False \
    --prompt_label="ocr"
```

Output:

```json
{
  "res": {
    "input_path": "/home/aistudio/paddleocr_vl/data/test.jpg",
    "page_index": None,
    "model_settings": {
      "use_doc_preprocessor": False,
      "use_layout_detection": False,
      "use_chart_recognition": False,
      "format_block_content": False
    },
    "parsing_res_list": [{
      "block_label": "ocr",
      "block_content": "购买方信息 | 名称 | 中青旅联科 | 杭州 | 公关顾问有限公司 | 销售方信息 | 名称 | 杭州万力酒店管理有限公司 | 统一社会信用代码/纳税人识别号 | 纳税人识别号 | 统一社会信用代码/纳税人识别号 | 税额 | 税额/征收率 | 税额/征收率\n**项目名称** | 规格型号 | | | | | | | | | | | | \n**住宿服务** | 住宿费 | | | | | | | | | | | | \n**合计** | | | | | | | | | | | | | \n**价税合计(大写)** | | 壹仟叁佰玖拾柒圆整 | | | | | | | | | | | \n备注 | 销售方地址:浙江省杭州市西湖区转塘街道霞鸣街199号万美商务中心3号楼;电话:0571-85220222;销方开户银行:农行上泗支行;入住人:柳顺;入住日期:9月23日入住-9月26日退房;入住天数:3天;金额:1397元 | | | | | | | | | | | | | \n开票人:祝营营",
      "block_bbox": [0, 0, 1260, 838]
    }]
  }
}
```

Then, test extracting only part of the information:

```bash
python -m paddleocr doc_parser -i /home/aistudio/paddleocr_vl/data/test.jpg \
    --vl_rec_model_name "PaddleOCR-VL-0.9B" \
    --vl_rec_model_dir "/home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT" \
    --save_path="/home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT_response" \
    --use_layout_detection=False \
    --prompt_label="OCR:{\"购买方名称\": {}, \"销售方名称\": {}}"
```

Output:

```json
{
  "res": {
    "input_path": "/home/aistudio/paddleocr_vl/data/test.jpg",
    "page_index": None,
    "model_settings": {
      "use_doc_preprocessor": False,
      "use_layout_detection": False,
      "use_chart_recognition": False,
      "format_block_content": False
    },
    "parsing_res_list": [{
      "block_label": "OCR:{"购买方名称": {}, "销售方名称": {}}",
      "block_content": "{
        "购买方名称": {
          "名称": "中青旅联科(杭州)公关顾问有限公司",
          "统一社会信用代码": "91330105MA2H2DUJ92"
        },
        "销售方名称": {
          "名称": "杭州万力酒店管理有限公司",
          "统一社会信用代码":
"91330106MA2B1C4UXN" + } + }", + "block_bbox": [0, 0, 1260, 838] + }] + } +} +``` + +可以看到,模型基本上可以跟随我们的指令抽取对应的信息。 + +## 使用 transformers 库进行信息抽取 + +可以使用 transformers 库进行信息抽取,参考 [[Model] Add PaddleOCR-VL Model Support by zhang-prog](https://github.com/huggingface/transformers/pull/42178) + +> 注意,目前微调后生成的模型目录还没有同步更新,在使用 transformers 库进行信息抽取时,需要先下载 [huggingface](https://huggingface.co/PaddlePaddle/PaddleOCR-VL/tree/main) 中最新的模型,然后,将微调后的模型文件 `model-00001-of-00001.safetensors` 重命名为 `model.safetensors`,并放到(并覆盖)下载的模型目录下。 + +```python +from transformers import pipeline + +pipe = pipeline( + "image-text-to-text", + model="./PaddleOCR_VL_SFT/PaddleOCR-VL", # 下载的模型目录 + dtype="bfloat16") +messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://ai-studio-static-online.cdn.bcebos.com/dc31c334d4664ca4955aa47d8e202a53a276fd0aab0840b09abe953fe51207d0"}, + {"type": "text", "text": "OCR:{}"}, + ] + } +] +result = pipe(text=messages) +print(result) + +``` + +如果显存不足,可以尝试以下量化方法: + +```python +from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig +import torch + +path = "./PaddleOCR_VL_SFT/PaddleOCR-VL", # 下载的模型目录 +processor = AutoProcessor.from_pretrained(path, local_files_only=True, use_fast=True) + +# 4-bit 量化配置,大幅减少显存占用 +quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4" +) +model = AutoModelForImageTextToText.from_pretrained( + path, + quantization_config=quantization_config, + # device_map="auto", + local_files_only=True +) +messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://ai-studio-static-online.cdn.bcebos.com/dc31c334d4664ca4955aa47d8e202a53a276fd0aab0840b09abe953fe51207d0"}, + {"type": "text", "text": "OCR:{\"发票日期\": \"\"}"}, + ] + } +] +inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + 
    return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=100)
result = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:-1])
print(result)
```

## Information Extraction with PaddleOCR-VL-REC

[PaddleOCR-VL-REC](https://github.com/megemini/PaddleOCR-VL-REC) can also be used for information extraction:

```python
from paddleocr_vl_rec import PaddleOCRVLRec

# Initialize the recognizer
recognizer = PaddleOCRVLRec(
    model_dir="path/to/your/model"
)

# Use a dict as the query (it will be converted to a JSON string)
# Return JSON (the result is parsed with json_repair)
result_json = recognizer.predict(
    image="/path/to/your/image.jpg",
    query={"NAME": "", "ITEMS": []},
    return_json=True
)
# result_json is a dict object
print(type(result_json))  # <class 'dict'>
print(result_json)

# Use a list as the query (it will be converted to the form {"item1":"", "item2":""})
result_json = recognizer.predict(
    image="/path/to/your/image.jpg",
    query=["item1", "item2"],
    return_json=True
)
print(result_json)

recognizer.close()
```

## Summary

This article showed how to implement information extraction by fine-tuning PaddleOCR-VL's prompts. The main steps are:

1. **Data preparation**: use a VLM to generate structured training data, which is more efficient than traditional annotation.
2. **Prompt design**: carefully designed prompt templates let the model flexibly output `JSON`-formatted information for different fields.
3. **Model fine-tuning**: leverage PaddleOCR-VL's fine-tuning capability so that it learns to generate the corresponding output for each prompt.

Compared with traditional information extraction pipelines (such as NER plus relation extraction), this approach is better integrated and more flexible.

## Appendix

### 1. Dataset

There are many application scenarios for information extraction; here the [VAT invoice](https://aistudio.baidu.com/datasetdetail/125158) dataset is used as an example.

> See [Key information extraction from invoices with VI-LayoutXLM](https://bbs.huaweicloud.com/blogs/383854) for a fairly complete walkthrough of fine-tuning a PaddleOCR model for information extraction.

However, the dataset's `Relation Extraction` annotations are rather crude, for example:

![VAT invoice](images/re.jpg)

Only `名称` (name) is annotated, without specifying whether it is the buyer's name or the seller's name.

As mentioned earlier, PaddleOCR-VL can be treated as a VLM, so we can let a more capable VLM "teach" PaddleOCR-VL to distinguish the buyer's name (`购买方名称`) from the seller's name (`销售方名称`).

The data can be generated with the `ernie-4.5-turbo-vl-preview` model; see the script `paddleocr_vl/tools/extract_ner/extract_ner.py`.

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Multimodal image recognition script.
Recognizes image content via an OpenAI-compatible API and returns JSON data.
Supports local images and multimodal LLMs.
"""
...

class MultimodalImageRecognizer:
    """Multimodal image recognizer"""
    ...

    def recognize_image(
        self,
        image_input: Union[str, bytes],
        prompt: str,
        system_prompt: str,
        max_tokens: int = 2048
    ) -> Dict[str, Any]:
        """
        Recognize the information in an image.

        Args:
            image_input: image path, URL, or base64 encoding
            prompt: user prompt
            system_prompt: system prompt
            max_tokens: maximum number of tokens

        Returns:
            The recognition result as JSON data
        """
        try:
            # Build the multimodal message
            content = self.create_multimodal_message(prompt, image_input)

            # Assemble the message list
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content}
            ]

            logger.info(f"Calling the API to recognize the image, model: {self.model}")

            # Call the API
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=0.2
            )

            ...

    def analyze_image(
        self,
        image_input: Union[str, bytes],
        analysis_type: str = "document"
    ) -> Dict[str, Any]:
        """
        Analyze image content (simplified version).

        Args:
            image_input: image path, URL, or base64 encoding
            analysis_type: analysis type, fixed to "document"

        Returns:
            The analysis result as JSON data
        """
        # Use the document-analysis prompts throughout; they are kept in
        # Chinese to match the language of the invoice documents.
        prompt = "请分析这张文档图片中的所有信息,并返回完整的JSON格式数据。如果有的字段没有值,那么保留此字段,值为空。注意:所有的值都以string的形式返回,不要使用数字类型等。"
        system_prompt = '''
你是一个专业的文档分析助手,能够准确分析文档内容并返回结构化的JSON数据。

注意:数据的语言与文档的语言保持一致。
注意:需要保留完整的字段层级关系,不要把所有字段都放到一级字段中。
注意:JSON数据中不要包含注释,也不需要任何解释或说明。
注意:对于特殊字符需要进行转义。

注意:对于选项字段,只保留所选择的字段值,如果没有选择,则置为空。
比如,`业务类型` 包括 `账户开户、账户登记` 等选项,文档中`账户登记`是选中状态,则,返回 `{"业务类型":"账户登记"}`,不返回`账户开户`等其他选项。
再比如,`业务类型` 包括 `账户开户、账户登记` 等选项,文档中没有标记选中的选项,则,返回 `{"业务类型":""}`,也就是说,只保留键,不需要有值。
...
'''

        return self.recognize_image(
            image_input=image_input,
            prompt=prompt,
            system_prompt=system_prompt
        )

...
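# A hypothetical usage sketch (not part of the original script); the result
# dict carries "success" and "data" keys, as consumed by batch_extract_ner.py:
#
#   recognizer = MultimodalImageRecognizer()
#   result = recognizer.analyze_image("paddleocr_vl/data/zzsptfp/zzsptfp/b0.jpg")
#   if result.get("success"):
#       print(json.dumps(result["data"], ensure_ascii=False, indent=2))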
```

The script `paddleocr_vl/tools/extract_ner/batch_extract_ner.py` generates data in batches. The generated data looks like this:

```json
{
  "image": "/media/shun/bigdata/Dataset/增值税普通发票/zzsptfp/b0.jpg",
  "data": {
    "发票名称": "广东增值税专用发票",
    "发票编号": "12271524",
    "发票代码": "4400154130",
    "开票日期": "2016年06月12日",
    "购买方": {
      "名称": "深圳市购机汇网络有限公司",
      "纳税人识别号": "440300083885931",
      "地址、电话": "深圳市龙华新区民治街道民治大道展滔科技大厦A12070755-23806606",
      "开户行及账号": "中国工商银行股份有限公司深圳园岭支行4000024709200172809"
    },
    "密码区": "<<139 -<5//81>84974<00+7>2*0*53-+ +125*++9+-///5-7+/-0>8<9815 5<3/8*+//81/84+>6>4*36>4538",
    "货物或应税劳务、服务名称": [
      {
        "名称": "小米 红米3 全网通版 时尚金色",
        "规格型号": "红米3",
        "单位": "个",
        "数量": "5",
        "单价": "597.43589744",
        "金额": "2987.18",
        "税率": "17%",
        "税额": "507.82"
      },
      {
        "名称": "移动联通电信4G手机 双卡双待",
        "规格型号": "",
        "单位": "",
        "数量": "",
        "单价": "",
        "金额": "",
        "税率": "",
        "税额": ""
      }
    ],
    "合计": {
      "金额": "¥2987.18",
      "税额": "¥507.82"
    },
    "价税合计(大写)": "叁仟肆佰玖拾伍圆整",
    "价税合计(小写)": "¥3495.00",
    "销售方": {
      "名称": "广州晶东贸易有限公司",
      "纳税人识别号": "91440101664041243T",
      "地址、电话": "广州市黄埔区九龙镇九龙工业园凤凰三横路99号 66215500",
      "开户行及账号": "工行北京路支行3602000919200384952"
    },
    "备注": "dd42982413947(00001,1952)7996有限",
    "收款人": "王梅",
    "复核": "张雪",
    "开票人": "陈秋燕",
    "销售方(章)": "广州晶东贸易有限公司 发票专用章"
  }
}
```

The generated data is much richer than the original annotations. It has some flaws (for example, `货物或应税劳务、服务名称` should contain only one record), but they do not prevent the fine-tuning experiment from going ahead.

> The processed data has been uploaded to [VAT invoices with JSON-format information](https://aistudio.baidu.com/dataset/detail/363136/intro).

### 2. Prompts

The goals of this `information extraction` task are:

- The model can output the complete information in `JSON` format
- The model can output the `JSON`-formatted information for the specific input fields requested

The prompts are designed accordingly:

**Complete information**

```
"OCR:{}"
```

**Specific information**

```
# The value is a string, e.g. `{"发票编码":"123456"}`
"OCR:{\"xxx\":\"\"}"

# The value is a dict, e.g. `{"购买方":{"名称":"A公司"}}`
"OCR:{\"xxx\":{}}"

# The value is a list, e.g. `{"货物或应税劳务、服务名称":[{"名称":"A产品"},{"名称":"B产品"}]}`
"OCR:{\"xxx\":[]}"
```

`paddleocr_vl/tools/process_ner_dataset.py` generates the full training data, including randomly generated prompts:

```bash
python paddleocr_vl/tools/process_ner_dataset.py paddleocr_vl/data/zzsptfp \
    -o paddleocr_vl/output.jsonl \
    -n 10 \
    -p /media/shun/bigdata/Dataset/增值税普通发票 \
    -u /home/aistudio/paddleocr_vl/data/zzsptfp
```

Then split the data into training and validation sets:

```bash
python paddleocr_vl/tools/split_jsonl.py paddleocr_vl/output.jsonl \
    paddleocr_vl/output \
    --train_ratio 0.9 \
    --seed 123
```

The final data looks like this:

```json
{
  "image_info": [
    {
      "matched_text_index": 0,
      "image_url": "/home/aistudio/paddleocr_vl/data/zzsptfp/zzsptfp/b175.jpg"
    }
  ],
  "text_info": [
    {
      "text": "OCR:{\"发票名称\": \"\"}",
      "tag": "mask"
    },
    {
      "text": "{\"发票名称\": \"广东增值税专用发票\"}",
      "tag": "no_mask"
    }
  ]
}
```

The generated training data differs from [PaddleOCR-VL-0.9B SFT](https://github.com/PaddlePaddle/ERNIE/blob/release/v1.4/docs/paddleocr_vl_sft_zh.md) in that:

- the `mask` `text` is not just `OCR:`; it also contains the fields to be extracted
- the `no_mask` `text` is complete `JSON`-formatted information rather than plain text

### 3. Configuration File Example

```yaml
### data
train_dataset_type: "erniekit"
eval_dataset_type: "erniekit"
train_dataset_path: "/home/aistudio/paddleocr_vl/output_train.jsonl"
train_dataset_prob: "1.0"
eval_dataset_path: "/home/aistudio/paddleocr_vl/output_val.jsonl"
eval_dataset_prob: "1.0"
max_seq_len: 16384
num_samples_each_epoch: 6000000
use_pic_id: False
sft_replace_ids: True
sft_image_normalize: True
sft_image_rescale: True
image_dtype: "float32"

### model
model_name_or_path: "/home/aistudio/paddleocr_vl/paddleocr_vl_model"
fine_tuning: Full
multimodal: True
use_flash_attention: True
use_sparse_flash_attn: True

### finetuning
# base
stage: OCR-VL-SFT
seed: 23
do_train: True
# do_eval: True
distributed_dataloader: False
dataloader_num_workers: 8
prefetch_factor: 10
batch_size: 1
packing_size: 8
packing: True
padding: False
num_train_epochs: 2
max_steps: 80
# eval_batch_size: 1
# eval_iters: 50
# eval_steps: 100
# evaluation_strategy: steps
save_steps: 20
save_total_limit: 5
save_strategy: steps
logging_steps: 1
release_grads: True
gradient_accumulation_steps: 8
logging_dir: /home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT/tensorboard_logs/
output_dir: /home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT
disable_tqdm: True

# train
warmup_steps: 1
learning_rate: 5.0e-6
lr_scheduler_type: cosine
min_lr: 5.0e-7
layerwise_lr_decay_bound: 1.0
from_scratch: 0

# optimizer
weight_decay: 0.1
adam_epsilon: 1.0e-8
adam_beta1: 0.9
adam_beta2: 0.95

# performance
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
sharding_parallel_degree: 1
sharding: stage1
sequence_parallel: False
pipeline_parallel_config: enable_delay_scale_loss enable_release_grads disable_partial_send_recv
recompute: True
recompute_granularity: "full"
recompute_use_reentrant: True
compute_type: bf16
fp16_opt_level: O2
disable_ckpt_quant: True
# amp_master_grad: True
amp_custom_white_list:
  - lookup_table
  - lookup_table_v2
  - flash_attn
  - matmul
  - matmul_v2
  - fused_gemm_epilogue
amp_custom_black_list:
  - reduce_sum
  - softmax_with_cross_entropy
  - c_softmax_with_cross_entropy
  - elementwise_div
  - sin
  - cos
unified_checkpoint: True
# unified_checkpoint_config: async_save
convert_from_hf: True
save_to_hf: True
```
diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/images/logs.png b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/images/logs.png
new file mode 100644
index 000000000..ce30ff7a7
Binary files /dev/null and b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/images/logs.png differ
diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/images/raw.png b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/images/raw.png
new file mode 100644
index 000000000..fb31d9d4d
Binary files /dev/null and b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/images/raw.png differ
diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/images/re.jpg b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/images/re.jpg
new file mode 100644
index 000000000..a80ea6c57
Binary files /dev/null and b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/images/re.jpg differ
diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/images/sft.png b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/images/sft.png
new file mode 100644
index 000000000..996f3ff76
Binary files /dev/null and b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/images/sft.png differ
diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/test.jpg b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/test.jpg
new file mode 100644
index 000000000..72300f91d
Binary files /dev/null and b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/test.jpg differ
diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp/b0.jpg b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp/b0.jpg
new file mode 100644
index 
000000000..733970832 Binary files /dev/null and b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp/b0.jpg differ diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp/b1.jpg b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp/b1.jpg new file mode 100644 index 000000000..d659bfce9 Binary files /dev/null and b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp/b1.jpg differ diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp/b2.jpg b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp/b2.jpg new file mode 100644 index 000000000..fb7295956 Binary files /dev/null and b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp/b2.jpg differ diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp_ner/b0_ner.json b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp_ner/b0_ner.json new file mode 100644 index 000000000..e16953bd8 --- /dev/null +++ b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp_ner/b0_ner.json @@ -0,0 +1,55 @@ +{ + "image": "/media/shun/bigdata/Dataset/增值税普通发票/zzsptfp/b0.jpg", + "data": { + "发票名称": "广东增值税专用发票", + "发票编号": "12271524", + "发票代码": "4400154130", + "开票日期": "2016年06月12日", + "购买方": { + "名称": "深圳市购机汇网络有限公司", + "纳税人识别号": "440300083885931", + "地址、电话": "深圳市龙华新区民治街道民治大道展滔科技大厦A12070755-23806606", + "开户行及账号": "中国工商银行股份有限公司深圳园岭支行4000024709200172809" + }, + "密码区": "<<139 -<5//81>84974<00+7>2*0*53-+ +125*++9+-///5-7+/-0>8<9815 5<3/8*+//81/84+>6>4*36>4538", + "货物或应税劳务、服务名称": [ + { + "名称": "小米 红米3 全网通版 时尚金色", + "规格型号": "红米3", + "单位": "个", + "数量": "5", + "单价": "597.43589744", + "金额": "2987.18", + "税率": "17%", + "税额": "507.82" + }, + { + "名称": "移动联通电信4G手机 双卡双待", + "规格型号": "", + "单位": "", + "数量": "", + "单价": "", + "金额": "", + 
"税率": "", + "税额": "" + } + ], + "合计": { + "金额": "¥2987.18", + "税额": "¥507.82" + }, + "价税合计(大写)": "叁仟肆佰玖拾伍圆整", + "价税合计(小写)": "¥3495.00", + "销售方": { + "名称": "广州晶东贸易有限公司", + "纳税人识别号": "91440101664041243T", + "地址、电话": "广州市黄埔区九龙镇九龙工业园凤凰三横路99号 66215500", + "开户行及账号": "工行北京路支行3602000919200384952" + }, + "备注": "dd42982413947(00001,1952)7996有限", + "收款人": "王梅", + "复核": "张雪", + "开票人": "陈秋燕", + "销售方(章)": "广州晶东贸易有限公司 发票专用章" + } +} diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp_ner/b1_ner.json b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp_ner/b1_ner.json new file mode 100644 index 000000000..55844ceeb --- /dev/null +++ b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp_ner/b1_ner.json @@ -0,0 +1,41 @@ +{ + "image": "/media/shun/bigdata/Dataset/增值税普通发票/zzsptfp/b1.jpg", + "data": { + "发票名称": "广东增值税专用发票", + "发票编号": "4400154130", + "发票代码": "12270242", + "开票日期": "2016年06月12日", + "购买方": { + "名称": "深圳市购机汇网络有限公司", + "纳税人识别号": "440300083885931", + "地址电话": "深圳市龙华新区民治街道民治大道展滔科技大厦A12070755-23806606", + "开户行及账号": "中国工商银行股份有限公司深圳园岭支行4000024709200172809" + }, + "密码区": "6/**3-02848*6>7137+ -<332/4845/*-2714*895**9768 />*0497-4/<377816+5+761/--5 127<8**32/4+45<4933///8>*48", + "货物或应税劳务服务名称": "小米 红米3 全网通版 时尚金色 移动联通电信4G手机 双卡双待 折扣(59.456%)", + "规格型号": "红米3", + "单位": "个", + "数量": "", + "单价": "597.43569744", + "金额": "2987.18", + "税率": "17%", + "税额": "507.82", + "合计": { + "金额": "¥1211.11", + "税额": "¥205.89", + "价税合计(大写)": "壹仟肆佰壹拾柒圆整", + "价税合计(小写)": "¥1417.00" + }, + "销售方": { + "名称": "广州晶东贸易有限公司", + "纳税人识别号": "91440101664041243T", + "地址电话": "广州市黄埔区九龙镇九龙工业园凤凰三横路99号 56215500", + "开户行及账号": "工行北京路支行3602000919200384952" + }, + "备注": "dd42981320128(00001,1956912801)", + "收款人": "王梅", + "复核": "张雪", + "开票人": "陈秋燕", + "销售方(章)": "广州晶东贸易有限公司 发票专用章" + } +} diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp_ner/b2_ner.json 
b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp_ner/b2_ner.json new file mode 100644 index 000000000..2a0e8ab8b --- /dev/null +++ b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/data/zzsptfp/zzsptfp_ner/b2_ner.json @@ -0,0 +1,46 @@ +{ + "image": "/media/shun/bigdata/Dataset/增值税普通发票/zzsptfp/b2.jpg", + "data": { + "发票信息": { + "发票名称": "广东增值税专用发票", + "发票代码": "4400154130", + "发票号码": "12269558", + "开票日期": "2016年06月12日", + "密码区": "+6><3>439<-311-/*62*0-28+41*12*067>-+*-7-*655<29449<4/5<9**6682-38<73312*047>82<+3757537//", + "机打号码": "", + "机器编号": "" + }, + "购买方信息": { + "名称": "深圳市购机汇网络有限公司", + "纳税人识别号": "440300083885931", + "地址、电话": "深圳市龙华新区民治街道民治大道展滔科技大厦A12070756-23806606", + "开户行及账号": "中国工商银行股份有限公司深圳园岭支行4000024709200172809" + }, + "销售方信息": { + "名称": "广州晶东贸易有限公司", + "纳税人识别号": "91440101664041243T", + "地址、电话": "广州市黄埔区九龙镇九龙工业园凤凰三横路99号 66215500", + "开户行及账号": "工行北京路支行3602000919200384952" + }, + "货物或应税劳务、服务信息": { + "名称": "小米 红米3 全网通版 时尚金色 移动联通电信4G手机 双卡双待", + "规格型号": "红米3", + "单位": "个", + "数量": "5", + "单价": "597.43589744", + "金额": "2987.18", + "税率": "17%", + "税额": "507.82" + }, + "金额合计": { + "金额合计(小写)": "¥3495.00", + "金额合计(大写)": "叁仟肆佰玖拾伍圆整" + }, + "其他信息": { + "收款人": "王梅", + "复核": "张雪", + "开票人": "陈秋燕", + "销售方(章)": "广州晶东贸易有限公司 发票专用章" + } + } +} diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/patch/patch_assert_to_warning.py b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/patch/patch_assert_to_warning.py new file mode 100644 index 000000000..a9d8b9963 --- /dev/null +++ b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/patch/patch_assert_to_warning.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +""" +找到 paddlex 中的 pipeline.py 文件,将 assert prompt_label 改为 warning 提示。 +""" + +import re +import sys +from pathlib import Path + +try: + import paddlex +except ImportError: + print("错误:未能导入 paddlex,请确保已安装") + sys.exit(1) + + +def main(): + # 根据 paddlex 模块的位置确定目标文件 + paddlex_file = 
Path(paddlex.__file__) + paddlex_root = paddlex_file.parent + target_file = paddlex_root / "inference/pipelines/paddleocr_vl/pipeline.py" + + # 检查文件是否存在 + if not target_file.exists(): + print(f"错误:文件不存在: {target_file}") + sys.exit(1) + + # 读取文件内容 + with open(target_file, 'r', encoding='utf-8') as f: + content = f.read() + + # 定义要替换的模式 + old_pattern = r''' assert prompt_label\.lower\(\) in \[ + "ocr", + "formula", + "table", + "chart", + \], f"Layout detection is disabled \(use_layout_detection=False\)\. 'prompt_label' must be one of \['ocr', 'formula', 'table', 'chart'\], but got '\{prompt_label\}'\."''' + + # 新的代码(使用 warning 代替 assert) + new_code = ''' if prompt_label.lower() not in [ + "ocr", + "formula", + "table", + "chart", + ]: + import warnings + warnings.warn( + f"Layout detection is disabled (use_layout_detection=False). " + f"'prompt_label' must be one of ['ocr', 'formula', 'table', 'chart'], " + f"but got '{prompt_label}'. Program will continue anyway.", + UserWarning + )''' + + # 执行第一个替换:assert 改为 warning + new_content = re.sub(old_pattern, new_code, content) + + # 检查是否进行了第一个替换 + if new_content == content: + print("警告:未找到匹配的 assert 代码模式,请检查文件内容") + else: + print("成功:已将 assert 改为 warning") + + # 执行第二个替换:text_prompt 改为条件表达式 + old_text_prompt = 'text_prompt = "OCR:"' + new_text_prompt = 'text_prompt = "OCR:" if block_label.lower() == "ocr" else block_label; block_label = block_label.lower()' + + if old_text_prompt in new_content: + new_content = new_content.replace(old_text_prompt, new_text_prompt) + print("成功:已将 text_prompt 改为条件表达式") + else: + print("警告:未找到 text_prompt 的代码") + + # 执行第三个替换:label 移除 .lower() + old_label = '"label": prompt_label.lower(),' + new_label = '"label": prompt_label,' + + if old_label in new_content: + new_content = new_content.replace(old_label, new_label) + print("成功:已将 label 改为 prompt_label") + else: + print("警告:未找到 label 的代码") + + # 写回文件 + try: + with open(target_file, 'w', encoding='utf-8') as f: + f.write(new_content) + 
print(f"成功:已完成所有 patch,文件位置: {target_file}") + except Exception as e: + print(f"错误:写入文件时出现异常: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/stf_config/run_ocr_vl_sft_16k.yaml b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/stf_config/run_ocr_vl_sft_16k.yaml new file mode 100644 index 000000000..e148a8147 --- /dev/null +++ b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/stf_config/run_ocr_vl_sft_16k.yaml @@ -0,0 +1,97 @@ +### data +train_dataset_type: "erniekit" +eval_dataset_type: "erniekit" +train_dataset_path: "/home/aistudio/paddleocr_vl/output_train.jsonl" +train_dataset_prob: "1.0" +eval_dataset_path: "/home/aistudio/paddleocr_vl/output_val.jsonl" +eval_dataset_prob: "1.0" +max_seq_len: 16384 +num_samples_each_epoch: 6000000 +use_pic_id: False +sft_replace_ids: True +sft_image_normalize: True +sft_image_rescale: True +image_dtype: "float32" + +### model +model_name_or_path: "/home/aistudio/paddleocr_vl/paddleocr_vl_model" +fine_tuning: Full +multimodal: True +use_flash_attention: True +use_sparse_flash_attn: True + +### finetuning +# base +stage: OCR-VL-SFT +seed: 23 +do_train: True +# do_eval: True +distributed_dataloader: False +dataloader_num_workers: 8 +prefetch_factor: 10 +batch_size: 1 +packing_size: 8 +packing: True +padding: False +num_train_epochs: 2 +max_steps: 80 +# eval_batch_size: 1 +# eval_iters: 50 +# eval_steps: 100 +# evaluation_strategy: steps +save_steps: 20 +save_total_limit: 5 +save_strategy: steps +logging_steps: 1 +release_grads: True +gradient_accumulation_steps: 8 +logging_dir: /home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT/tensorboard_logs/ +output_dir: /home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT +disable_tqdm: True + +# train +warmup_steps: 1 +learning_rate: 5.0e-6 +lr_scheduler_type: cosine +min_lr: 5.0e-7 +layerwise_lr_decay_bound: 1.0 +from_scratch: 0 + +# optimizer +weight_decay: 0.1 +adam_epsilon: 1.0e-8 
+adam_beta1: 0.9 +adam_beta2: 0.95 + +# performance +tensor_parallel_degree: 1 +pipeline_parallel_degree: 1 +sharding_parallel_degree: 1 +sharding: stage1 +sequence_parallel: False +pipeline_parallel_config: enable_delay_scale_loss enable_release_grads disable_partial_send_recv +recompute: True +recompute_granularity: "full" +recompute_use_reentrant: True +compute_type: bf16 +fp16_opt_level: O2 +disable_ckpt_quant: True +# amp_master_grad: True +amp_custom_white_list: + - lookup_table + - lookup_table_v2 + - flash_attn + - matmul + - matmul_v2 + - fused_gemm_epilogue +amp_custom_black_list: + - reduce_sum + - softmax_with_cross_entropy + - c_softmax_with_cross_entropy + - elementwise_div + - sin + - cos +unified_checkpoint: True +# unified_checkpoint_config: async_save +convert_from_hf: True +save_to_hf: True diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/stf_config/run_ocr_vl_sft_16k.yaml.backup b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/stf_config/run_ocr_vl_sft_16k.yaml.backup new file mode 100644 index 000000000..fa1b51551 --- /dev/null +++ b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/stf_config/run_ocr_vl_sft_16k.yaml.backup @@ -0,0 +1,97 @@ +### data +train_dataset_type: "erniekit" +eval_dataset_type: "erniekit" +train_dataset_path: "./examples/data/ocr_vl_sft-train_Bengali.jsonl" +train_dataset_prob: "1.0" +eval_dataset_path: "./examples/data/ocr_vl_sft-test_Bengali.jsonl" +eval_dataset_prob: "1.0" +max_seq_len: 16384 +num_samples_each_epoch: 6000000 +use_pic_id: False +sft_replace_ids: True +sft_image_normalize: True +sft_image_rescale: True +image_dtype: "float32" + +### model +model_name_or_path: PaddlePaddle/PaddleOCR-VL +fine_tuning: Full +multimodal: True +use_flash_attention: True +use_sparse_flash_attn: True + +### finetuning +# base +stage: OCR-VL-SFT +seed: 23 +do_train: True +# do_eval: True +distributed_dataloader: False +dataloader_num_workers: 8 +prefetch_factor: 10 +batch_size: 1 
+packing_size: 8 +packing: True +padding: False +num_train_epochs: 2 +max_steps: 926 +# eval_batch_size: 1 +# eval_iters: 50 +# eval_steps: 100 +# evaluation_strategy: steps +save_steps: 200 +save_total_limit: 5 +save_strategy: steps +logging_steps: 1 +release_grads: True +gradient_accumulation_steps: 8 +logging_dir: ./PaddleOCR-VL-SFT-Bengali/tensorboard_logs/ +output_dir: ./PaddleOCR-VL-SFT-Bengali +disable_tqdm: True + +# train +warmup_steps: 10 +learning_rate: 5.0e-6 +lr_scheduler_type: cosine +min_lr: 5.0e-7 +layerwise_lr_decay_bound: 1.0 +from_scratch: 0 + +# optimizer +weight_decay: 0.1 +adam_epsilon: 1.0e-8 +adam_beta1: 0.9 +adam_beta2: 0.95 + +# performance +tensor_parallel_degree: 1 +pipeline_parallel_degree: 1 +sharding_parallel_degree: 1 +sharding: stage1 +sequence_parallel: False +pipeline_parallel_config: enable_delay_scale_loss enable_release_grads disable_partial_send_recv +recompute: True +recompute_granularity: "full" +recompute_use_reentrant: True +compute_type: bf16 +fp16_opt_level: O2 +disable_ckpt_quant: True +# amp_master_grad: True +amp_custom_white_list: + - lookup_table + - lookup_table_v2 + - flash_attn + - matmul + - matmul_v2 + - fused_gemm_epilogue +amp_custom_black_list: + - reduce_sum + - softmax_with_cross_entropy + - c_softmax_with_cross_entropy + - elementwise_div + - sin + - cos +unified_checkpoint: True +# unified_checkpoint_config: async_save +convert_from_hf: True +save_to_hf: True diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/extract_ner/batch_extract_ner.py b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/extract_ner/batch_extract_ner.py new file mode 100644 index 000000000..3fd019c31 --- /dev/null +++ b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/extract_ner/batch_extract_ner.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +批量图片NER提取脚本 +遍历指定目录中的图片文件,使用extract_ner.py提取JSON格式数据并保存 +""" + +import json +import os +import sys 
+from pathlib import Path +from typing import List, Dict, Any +import argparse + +from extract_ner import MultimodalImageRecognizer + + +def get_image_files(directory: str, extensions: List[str] = None) -> List[str]: + """ + 获取目录中的所有图片文件 + + Args: + directory: 要搜索的目录路径 + extensions: 支持的图片扩展名列表,默认为常见格式 + + Returns: + 图片文件路径列表 + """ + if extensions is None: + extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'] + + image_files = [] + directory_path = Path(directory) + + if not directory_path.exists(): + print(f"错误:目录不存在 - {directory}") + return image_files + + # 递归搜索所有图片文件 + for ext in extensions: + pattern = f"**/*{ext}" + image_files.extend(directory_path.glob(pattern)) + # 同时搜索大写扩展名 + pattern_upper = f"**/*{ext.upper()}" + image_files.extend(directory_path.glob(pattern_upper)) + + # 转换为字符串路径并去重 + image_paths = list(set(str(path) for path in image_files)) + image_paths.sort() + + return image_paths + + +def process_single_image(recognizer: MultimodalImageRecognizer, image_path: str, output_dir: str) -> bool: + """ + 处理单个图片文件 + + Args: + recognizer: 图片识别器实例 + image_path: 图片文件路径 + output_dir: 输出目录 + + Returns: + 处理成功返回True,失败返回False + """ + try: + # 生成输出文件名 + image_name = Path(image_path).stem + output_file = os.path.join(output_dir, f"{image_name}_ner.json") + + # 检查输出文件是否已存在,如果存在则跳过处理 + if os.path.exists(output_file): + print(f"跳过已处理的图片: {image_path} (输出文件已存在: {output_file})") + return True + + print(f"正在处理: {image_path}") + + # 调用extract_ner的功能进行图片识别 + result = recognizer.analyze_image(image_path, "document") + + # 检查提取是否成功 + if not result.get("success", False): + print(f"图片 {image_path} 提取失败,跳过生成文件") + return False + + # 构建输出数据结构,只保留data字段 + output_data = { + "image": image_path, + "data": result.get("data", {}) + } + + # 保存结果 + with open(output_file, 'w', encoding='utf-8') as f: + json.dump(output_data, f, ensure_ascii=False, indent=2) + + print(f"保存结果到: {output_file}") + return True + + except Exception as e: + print(f"处理图片 {image_path} 
时发生错误: {e}") + return False + + +def batch_process_images(input_dir: str, output_dir: str = None) -> None: + """ + 批量处理图片 + + Args: + input_dir: 输入目录路径 + output_dir: 输出目录路径,如果为None则在输入目录下创建output子目录 + """ + # 设置输出目录 + if output_dir is None: + output_dir = os.path.join(input_dir, "output") + + # 创建输出目录 + os.makedirs(output_dir, exist_ok=True) + + # 获取所有图片文件 + print(f"正在扫描目录: {input_dir}") + image_files = get_image_files(input_dir) + + if not image_files: + print("未找到任何图片文件") + return + + print(f"找到 {len(image_files)} 个图片文件") + + # 初始化识别器 + try: + recognizer = MultimodalImageRecognizer() + except Exception as e: + print(f"初始化识别器失败: {e}") + return + + # 处理每个图片文件 + success_count = 0 + total_count = len(image_files) + failed_images = [] # 记录失败的图片路径 + + for image_path in image_files: + if process_single_image(recognizer, image_path, output_dir): + success_count += 1 + else: + failed_images.append(image_path) + + # 输出处理结果统计 + print(f"\n处理完成!") + print(f"总文件数: {total_count}") + print(f"成功处理: {success_count}") + print(f"失败数量: {total_count - success_count}") + print(f"结果保存在: {output_dir}") + + # 打印所有失败的图片路径 + if failed_images: + print(f"\n失败的图片路径列表:") + for failed_path in failed_images: + print(f" - {failed_path}") + else: + print("\n没有失败的图片") + + +def find_deepest_folders(root_dir: str) -> List[str]: + """ + 递归查找最底层的文件夹(不包含任何子文件夹的文件夹) + + Args: + root_dir: 根目录路径 + + Returns: + 最底层文件夹路径列表 + """ + deepest_folders = [] + + for dirpath, dirnames, filenames in os.walk(root_dir): + # 如果当前目录没有子文件夹,则认为是最底层文件夹 + if not dirnames: + # 排除以 _ner 结尾的文件夹 + if not dirpath.endswith('_ner'): + deepest_folders.append(dirpath) + + return deepest_folders + + +def main(): + """主函数""" + # parser = argparse.ArgumentParser(description='批量图片NER提取工具') + # parser.add_argument('input_dir', help='输入图片目录路径') + # parser.add_argument('-o', '--output', help='输出目录路径(可选)') + + # args = parser.parse_args() + + # input_dir = args.input_dir + # output_dir = args.output + # input_dir = 
'/media/shun/bigdata/Dataset/机动车发票/train' + # output_dir = '/media/shun/bigdata/Dataset/机动车发票/train_ner' + + # input_dir = '/media/shun/bigdata/Dataset/增值税普通发票/zzsptfp' + # output_dir = '/media/shun/bigdata/Dataset/增值税普通发票/zzsptfp_ner' + + # input_dir = '/media/shun/bigdata/Dataset/ocr/基于OCR的表单识别数据集/XFUND_ori/zh.train' + # output_dir = '/media/shun/bigdata/Dataset/ocr/基于OCR的表单识别数据集/XFUND_ori/zh.train_ner' + + # input_dir = '/media/shun/bigdata/Dataset/ocr/基于OCR的表单识别数据集/XFUND_ori/zh.val' + # output_dir = '/media/shun/bigdata/Dataset/ocr/基于OCR的表单识别数据集/XFUND_ori/zh.val_ner' + + # 设置根目录 + root_dir = '/media/shun/bigdata/Dataset/ocr/DkbRrByl/wildreceipt/' + + # 验证根目录 + if not os.path.exists(root_dir): + print(f"错误:根目录不存在 - {root_dir}") + sys.exit(1) + + # 查找所有最底层文件夹 + print(f"正在扫描目录结构,查找最底层文件夹...") + deepest_folders = find_deepest_folders(root_dir) + + if not deepest_folders: + print("未找到任何最底层文件夹") + sys.exit(1) + + print(f"找到 {len(deepest_folders)} 个最底层文件夹") + + # 处理每个最底层文件夹 + for i, input_dir in enumerate(deepest_folders, 1): + print(f"\n[{i}/{len(deepest_folders)}] 处理文件夹: {input_dir}") + + # 生成对应的输出文件夹路径 + output_dir = input_dir + "_ner" + + # 验证输入目录 + if not os.path.exists(input_dir): + print(f"警告:输入目录不存在 - {input_dir},跳过处理") + continue + + # 执行批量处理 + batch_process_images(input_dir, output_dir) + + print(f"\n所有文件夹处理完成!共处理了 {len(deepest_folders)} 个最底层文件夹") + + +if __name__ == "__main__": + main() diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/extract_ner/config.py b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/extract_ner/config.py new file mode 100644 index 000000000..909057c0a --- /dev/null +++ b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/extract_ner/config.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +配置文件 +管理API密钥和相关配置 +""" + +import os +from dotenv import load_dotenv + +# 加载环境变量 +load_dotenv() + +class Config: + """配置类""" + + # OpenAI API配置 + 
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")  # 请通过 .env 或环境变量配置,切勿在代码中硬编码密钥
+    OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://aistudio.baidu.com/llm/lmapi/v3")
+    OPENAI_MODEL = os.getenv("OPENAI_MODEL", "ernie-5.0-thinking-preview")
+
+    # 日志配置
+    LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
+
+    @classmethod
+    def validate_config(cls):
+        """验证配置"""
+        # OpenAI配置验证
+        if not cls.OPENAI_API_KEY:
+            print("缺少 OPENAI_API_KEY 配置")
+            return False
+
+        if not cls.OPENAI_BASE_URL:
+            print("缺少 OPENAI_BASE_URL 配置")
+            return False
+
+        if not cls.OPENAI_MODEL:
+            print("缺少 OPENAI_MODEL 配置")
+            return False
+
+        return True
+
+    @classmethod
+    def get_openai_config(cls):
+        """获取OpenAI配置"""
+        return {
+            "api_key": cls.OPENAI_API_KEY,
+            "base_url": cls.OPENAI_BASE_URL,
+            "model": cls.OPENAI_MODEL
+        }
+
+# 使用示例
+if __name__ == "__main__":
+    if Config.validate_config():
+        print("配置验证通过")
+        print("OpenAI配置:", Config.get_openai_config())
+    else:
+        print("配置验证失败")
diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/extract_ner/extract_ner.py b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/extract_ner/extract_ner.py
new file mode 100644
index 000000000..3c59b8294
--- /dev/null
+++ b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/extract_ner/extract_ner.py
@@ -0,0 +1,424 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+多模态图像识别脚本
+通过调用OpenAI接口识别图片信息并返回JSON格式数据
+支持本地图片和多模态大模型处理
+"""
+
+import base64
+import json
+import logging
+import os
+from pathlib import Path
+from typing import Optional, Dict, Any, Union
+from openai import OpenAI
+from io import BytesIO
+
+try:
+    from PIL import Image
+    HAS_PIL = True
+except ImportError:
+    HAS_PIL = False
+
+# 导入配置
+from config import Config
+
+# 配置日志
+logging.basicConfig(level=getattr(logging, Config.LOG_LEVEL))
+logger = logging.getLogger(__name__)
+
+class MultimodalImageRecognizer:
+    """多模态图像识别器"""
+
+    def __init__(self, api_key: Optional[str] = None, 
base_url: Optional[str] = None, model: Optional[str] = None): + """ + 初始化多模态图像识别器 + + Args: + api_key: OpenAI API密钥,如果不提供则从配置文件获取 + base_url: API基础URL,如果不提供则从配置文件获取 + model: 模型名称,如果不提供则从配置文件获取 + """ + # 使用配置文件中的默认值 + self.api_key = api_key or Config.OPENAI_API_KEY + self.base_url = base_url or Config.OPENAI_BASE_URL + self.model = model or Config.OPENAI_MODEL + + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + + logger.info(f"初始化多模态图像识别器,模型: {self.model}") + + def convert_image_to_base64(self, image_path: str) -> Optional[str]: + """ + 将图片文件转换为base64编码,统一转换为JPEG格式 + + Args: + image_path: 图片文件路径 + + Returns: + base64编码字符串,失败时返回None + """ + try: + if not os.path.exists(image_path): + logger.error(f"图片文件不存在: {image_path}") + return None + + # 如果安装了PIL,使用PIL转换图片格式 + if HAS_PIL: + try: + # 打开图片并转换为RGB(处理RGBA或其他格式) + img = Image.open(image_path) + + # 如果图片有RGBA通道,转换为RGB + if img.mode in ('RGBA', 'LA', 'P'): + # 创建白色背景 + rgb_img = Image.new('RGB', img.size, (255, 255, 255)) + # 粘贴原图片到背景上 + if img.mode == 'P': + img = img.convert('RGBA') + rgb_img.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None) + img = rgb_img + elif img.mode != 'RGB': + img = img.convert('RGB') + + # 将图片转换为JPEG字节 + img_buffer = BytesIO() + img.save(img_buffer, format='JPEG', quality=95) + image_data = img_buffer.getvalue() + + logger.info(f"图片已转换为JPEG格式,长度: {len(image_data)}") + except Exception as e: + logger.warning(f"使用PIL转换图片失败: {e},尝试直接读取") + # 如果PIL转换失败,直接读取原始数据 + with open(image_path, 'rb') as image_file: + image_data = image_file.read() + else: + # 如果没有PIL,直接读取原始文件 + logger.warning("未安装PIL库,将直接读取图片文件,建议安装Pillow库以支持格式转换") + with open(image_path, 'rb') as image_file: + image_data = image_file.read() + + # 转换为base64 + base64_encoded = base64.b64encode(image_data).decode('utf-8') + logger.info(f"图片base64转换成功,长度: {len(base64_encoded)}") + return base64_encoded + + except Exception as e: + logger.error(f"转换图片为base64时发生错误: {e}") + return None + + def 
_get_mime_type(self) -> str: + """ + 获取MIME类型,因为所有图片都转换为JPEG格式 + + Returns: + MIME类型字符串 + """ + return 'image/jpeg' + + def create_multimodal_message(self, text: str, image_input: Union[str, bytes]) -> list: + """ + 创建多模态消息 + + Args: + text: 文本内容 + image_input: 图片路径或base64编码 + + Returns: + 多模态消息列表 + """ + content = [ + {"type": "text", "text": text} + ] + + # 判断输入类型 + if isinstance(image_input, bytes): + # 如果是bytes,转换为base64 + base64_image = base64.b64encode(image_input).decode('utf-8') + image_url = f"data:image/jpeg;base64,{base64_image}" + elif os.path.exists(image_input): + # 如果是文件路径,转换为base64 + base64_image = self.convert_image_to_base64(image_input) + if base64_image: + mime_type = self._get_mime_type() + image_url = f"data:{mime_type};base64,{base64_image}" + else: + logger.error("图片转换失败") + return [{"type": "text", "text": text}] + else: + # 如果是base64或URL,直接使用 + if image_input.startswith('data:'): + image_url = image_input + else: + logger.error("不支持的图片输入格式") + return [{"type": "text", "text": text}] + + content.append({ + "type": "image_url", + "image_url": {"url": image_url} + }) + + return content + + def recognize_image( + self, + image_input: Union[str, bytes], + prompt: str, + system_prompt: str, + max_tokens: int = 2048 + ) -> Dict[str, Any]: + """ + 识别图片信息 + + Args: + image_input: 图片路径、URL或base64编码 + prompt: 用户提示词 + system_prompt: 系统提示词 + max_tokens: 最大令牌数 + + Returns: + 识别结果的JSON格式数据 + """ + try: + # 创建多模态消息 + content = self.create_multimodal_message(prompt, image_input) + + # 构建消息列表 + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": content} + ] + + logger.info(f"开始调用API识别图片,模型: {self.model}") + + # 调用API + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + max_tokens=max_tokens, + temperature=0.2 + ) + + # 获取响应内容 + response_content = response.choices[0].message.content + logger.info(f"API响应: {response_content}") + + # 尝试解析JSON响应 + try: + # 预处理响应内容,移除代码块标记 + 
cleaned_content = self._preprocess_response(response_content) + + # 尝试直接解析JSON + if cleaned_content.strip().startswith('{'): + result = json.loads(cleaned_content) + return { + "success": True, + "data": result, + "raw_response": response_content + } + else: + # 尝试从文本中提取JSON + import re + json_match = re.search(r'\{[\s\S]*\}', cleaned_content) + if json_match: + json_str = json_match.group() + # 尝试修复常见的JSON格式问题 + json_str = self._fix_json_format(json_str) + result = json.loads(json_str) + return { + "success": True, + "data": result, + "raw_response": response_content + } + else: + # 如果无法解析JSON,返回原始文本 + return { + "success": True, + "data": {"content": response_content}, + "raw_response": response_content + } + except json.JSONDecodeError as e: + logger.error(f"JSON解析失败: {e}") + logger.error(f"原始响应内容: {response_content}") + return { + "success": False, + "error": f"JSON解析失败: {e}", + "raw_response": response_content + } + + except Exception as e: + logger.error(f"调用API时发生错误: {e}") + return { + "success": False, + "error": str(e) + } + + def _preprocess_response(self, response_content: str) -> str: + """ + 预处理API响应内容,移除代码块标记等 + + Args: + response_content: 原始响应内容 + + Returns: + 预处理后的内容 + """ + try: + import re + + # 移除代码块标记 + content = re.sub(r'```(?:json)?\s*', '', response_content) + content = re.sub(r'\s*```', '', content) + + # 移除多余的空白行 + content = re.sub(r'\n\s*\n', '\n', content) + + return content.strip() + except Exception as e: + logger.warning(f"响应预处理失败: {e}") + return response_content + + def _fix_json_format(self, json_str: str) -> str: + """ + 尝试修复常见的JSON格式问题 + + Args: + json_str: 可能格式有问题的JSON字符串 + + Returns: + 修复后的JSON字符串 + """ + try: + import re + + # 移除可能的注释 + json_str = re.sub(r'//.*?\n', '', json_str) + json_str = re.sub(r'/\*.*?\*/', '', json_str, flags=re.DOTALL) + + # 修复控制字符问题(在字符串值中)- 改进版本 + def escape_control_chars(match): + # 获取引号内的内容 + content = match.group(1) + # 先将已有的反斜杠进行转义,然后处理其他控制字符 + content = content.replace('\\', '\\\\') # 转义反斜杠 + 
content = content.replace('"', '\\"') # 转义引号 + content = content.replace('\n', '\\n') # 转义换行符 + content = content.replace('\r', '\\r') # 转义回车符 + content = content.replace('\t', '\\t') # 转义制表符 + return f'"{content}"' + + # 匹配所有包含反斜杠或控制字符的字符串 + json_str = re.sub(r'"([^"\\]*(?:\\.[^"\\]*)*)"', escape_control_chars, json_str) + + # 修复以0开头的数字(在JSON中必须用字符串表示) + json_str = re.sub(r':\s*(0\d+)', r': "\1"', json_str) + + # 尝试修复缺少引号的问题(简单情况) + json_str = re.sub(r'(\w+):', r'"\1":', json_str) + + # 尝试修复单引号问题 + json_str = re.sub(r"'([^']*)'", r'"\1"', json_str) + + # 尝试修复尾部逗号问题 + json_str = re.sub(r',(\s*[}\]])', r'\1', json_str) + + return json_str + except Exception as e: + logger.warning(f"JSON修复失败: {e}") + return json_str + + def analyze_image( + self, + image_input: Union[str, bytes], + analysis_type: str = "document" + ) -> Dict[str, Any]: + """ + 分析图片内容(简化版本) + + Args: + image_input: 图片路径、URL或base64编码 + analysis_type: 分析类型,固定为 "document" + + Returns: + 分析结果的JSON格式数据 + """ + # 统一使用文档分析提示词 + prompt = "请分析这张文档图片中的所有信息,并返回完整的JSON格式数据。如果有的字段没有值,那么保留此字段,值为空。注意:所有的值都以string的形式返回,不要使用数字类型等。" + system_prompt = ''' +你是一个专业的文档分析助手,能够准确分析文档内容并返回结构化的JSON数据。 + +注意:数据的语言与文档的语言保持一致。 +注意:需要保留完整的字段层级关系,不要把所有字段都放到一级字段中。 +注意:JSON数据中不要包含注释,也不需要任何解释或说明。 +注意:对于特殊字符需要进行转义。 + +注意:对于选项字段,只保留所选择的字段值,如果没有选择,则置为空。 +比如,`业务类型` 包括 `账户开户、账户登记` 等选项,文档中`账户登记`是选中状态,则,返回 `{"业务类型":"账户登记"}`,不返回`账户开户`等其他选项。 +再比如,`业务类型` 包括 `账户开户、账户登记` 等选项,文档中没有标记选中的选项,则,返回 `{"业务类型":""}`,也就是说,只保留键,不需要有值。 +选中的样式包括但不限于打勾等情况。 +举例,如果你识别到的是: +``` +{ + "志愿捐献": { + "器官": "☑", # 这里可以是 `√` `X` 等情况 + "眼角膜": "☐", + "其他组织": "☐", + } +} +``` +正确的返回应该是: +``` +{"志愿捐献":"器官"} +``` +而不是 +``` +{ + "志愿捐献": { + "器官": "☑" + } +} +``` +''' + + return self.recognize_image( + image_input=image_input, + prompt=prompt, + system_prompt=system_prompt + ) + + +def main(): + """主函数,演示用法""" + # 验证配置 + if not Config.validate_config(): + print("配置验证失败,请检查配置文件") + return + + # 创建识别器实例 + recognizer = MultimodalImageRecognizer() + + # 示例:识别本地图片 + image_path = 
"/media/shun/bigdata/Dataset/机动车发票/train/3.jpg" # 替换为实际图片路径 + + if os.path.exists(image_path): + print(f"正在识别图片: {image_path}") + + # 文档分析 + result = recognizer.analyze_image(image_path, "document") + print("文档分析结果:") + print(json.dumps(result, ensure_ascii=False, indent=2)) + + else: + print(f"图片文件不存在: {image_path}") + print("请将图片文件放置在当前目录下,或修改image_path变量") + print("\n配置说明:") + print("1. 复制 .env.example 为 .env") + print("2. 在 .env 文件中设置您的配置") + + +if __name__ == "__main__": + main() diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/process_ner_dataset.py b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/process_ner_dataset.py new file mode 100644 index 000000000..7b9e18135 --- /dev/null +++ b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/process_ner_dataset.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Script to process NER dataset and convert it to JSONL format. + +This script traverses a parent directory to find subdirectories ending with '_ner', +then processes JSON files ending with '_ner.json' in those subdirectories. +Each input JSON file is converted to a line in the output JSONL file. +""" + +import os +import json +import argparse +import random +from pathlib import Path + + +def _flatten_dict(d, parent_key='', sep='.'): + """ + Flatten a nested dictionary. + + Args: + d (dict): Dictionary to flatten + parent_key (str): Parent key for nested dictionaries + sep (str): Separator for nested keys + + Returns: + dict: Flattened dictionary + """ + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(_flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def _randomly_select_fields(data_field): + """ + Randomly select fields from the data. 
+ + Args: + data_field (dict): Data field to select from + + Returns: + list: List of selected top-level keys + """ + # Get all top-level keys + all_keys = list(data_field.keys()) + + # Randomly select number of fields (at least 1, at most all fields) + num_fields = random.randint(1, len(all_keys)) + + # Randomly select fields + selected_keys = random.sample(all_keys, num_fields) + + return selected_keys + + +def _generate_mask_text(selected_keys, data_field=None): + """ + Generate mask text from selected keys in JSON format. + + Uses different placeholders based on the value type: + - Single values (str, int, float, bool, etc.) → "" + - list → [] + - dict (nested JSON) → {} + + Args: + selected_keys (list): List of selected top-level keys + data_field (dict, optional): Original data field to inspect value types + + Returns: + str: Mask text in JSON format + """ + # Build a dict with selected keys and appropriate placeholders based on value types + mask_dict = {} + + for key in selected_keys: + if data_field and key in data_field: + value = data_field[key] + # Determine placeholder based on value type + if isinstance(value, dict): + mask_dict[key] = {} + elif isinstance(value, list): + mask_dict[key] = [] + else: + # For single values (str, int, float, bool, None, etc.) + mask_dict[key] = "" + else: + # Default to empty string if no data_field provided + mask_dict[key] = "" + + # Convert to JSON string with proper formatting + json_str = json.dumps(mask_dict, ensure_ascii=False) + + return f"OCR:{json_str}" + + +def _generate_no_mask_text(selected_keys, data_field): + """ + Generate no-mask text from selected keys and data in JSON format. 
+ + Args: + selected_keys (list): List of selected top-level keys + data_field (dict): Original data field + + Returns: + str: No-mask text (JSON string with actual values) + """ + # Build a dict with selected keys and their actual values + result_dict = {} + for key in selected_keys: + if key in data_field: + result_dict[key] = data_field[key] + + # Return as formatted JSON string + return json.dumps(result_dict, ensure_ascii=False) + + +def process_ner_dataset(parent_dir, output_file, image_root=None, n_entries=1, common_prefix=None, url_root=None): + """ + Process NER dataset and generate JSONL output. + + Args: + parent_dir (str): Parent directory to traverse + output_file (str): Output JSONL file path + image_root (str, optional): Root directory for image URLs. If provided, + it will be prepended to the relative path. + If not provided, relative path will be used. + n_entries (int, optional): Number of output entries to generate per input file. + Default is 1 (original behavior). + common_prefix (str, optional): Common prefix to strip from image paths. + Default is "" + url_root (str, optional): Root directory to prepend to image_url. + This is applied after processing common_prefix. 
+ """ + # List to store all output data + output_data = [] + + # Set default common_prefix if not provided + if common_prefix is None: + common_prefix = "" + + # Traverse parent directory + parent_path = Path(parent_dir) + + # Find all subdirectories ending with '_ner' recursively + for subdir in parent_path.rglob('*_ner'): + if subdir.is_dir(): + print(f"Processing directory: {subdir.relative_to(parent_path)}") + + # Find JSON files ending with '_ner.json' + for json_file in subdir.glob('*_ner.json'): + print(f"Processing file: {json_file.name}") + + # Read input JSON file + with open(json_file, 'r', encoding='utf-8') as f: + try: + input_data = json.load(f) + + # Extract image URL + image_path = input_data.get('image', '') + # Process image path according to image_root parameter + if image_root: + # If image_root is provided, use it as the root directory + # Remove the common prefix from image_path and prepend image_root + if image_path.startswith(common_prefix): + relative_path = image_path[len(common_prefix):] + image_url = os.path.join(image_root, relative_path).replace(os.sep, '/') + else: + # If the path doesn't match the expected prefix, use it as is + image_url = image_path + else: + # If no image_root is provided, use relative path + if image_path.startswith(common_prefix): + image_url = image_path[len(common_prefix):].lstrip('/') + else: + # Fallback to original relative path processing + if image_path.startswith('/'): + # Find the position of the second '/' to remove the first directory + first_slash = image_path.find('/', 1) + if first_slash != -1: + image_url = image_path[first_slash:].lstrip('/') # Convert to relative path + else: + image_url = image_path # Fallback to original if path format is unexpected + else: + image_url = image_path + + # Prepend url_root or './' based on whether url_root is provided + if url_root: + image_url = os.path.join(url_root, image_url).replace(os.sep, '/') + else: + # If no url_root, prepend './' to make it a relative 
path + image_url = os.path.join('.', image_url).replace(os.sep, '/') + + # Extract data field + data_field = input_data.get('data', {}) + + # Generate N output entries + for i in range(n_entries): + if i == 0: + # First entry: Keep original behavior (all data fields) + output_item = { + "image_info": [ + { + "matched_text_index": 0, + "image_url": image_url + } + ], + "text_info": [ + { + "text": "OCR:{}", + "tag": "mask" + }, + { + "text": json.dumps(data_field, ensure_ascii=False), + "tag": "no_mask" + } + ] + } + else: + # Subsequent entries: Randomly select fields + selected_fields = _randomly_select_fields(data_field) + mask_text = _generate_mask_text(selected_fields, data_field) + no_mask_text = _generate_no_mask_text(selected_fields, data_field) + + output_item = { + "image_info": [ + { + "matched_text_index": 0, + "image_url": image_url + } + ], + "text_info": [ + { + "text": mask_text, + "tag": "mask" + }, + { + "text": no_mask_text, + "tag": "no_mask" + } + ] + } + + # Add to output data list + output_data.append(output_item) + + except json.JSONDecodeError as e: + print(f"Error decoding JSON in file {json_file}: {e}") + continue + + # Write all output data to JSONL file + with open(output_file, 'w', encoding='utf-8') as f: + for item in output_data: + f.write(json.dumps(item, ensure_ascii=False) + '\n') + + print(f"Processed {len(output_data)} files. 
Output saved to {output_file}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Process NER dataset and convert to JSONL format')
+    parser.add_argument('parent_dir', help='Parent directory to traverse')
+    parser.add_argument('-o', '--output', default='output.jsonl', help='Output JSONL file path')
+    parser.add_argument('-r', '--image-root', help='Root directory for image URLs')
+    parser.add_argument('-n', '--n-entries', type=int, default=1, help='Number of output entries to generate per input file')
+    parser.add_argument('-p', '--prefix', help='Common prefix to strip from image paths')
+    parser.add_argument('-u', '--url-root', help='Root directory to prepend to image URLs')
+
+    args = parser.parse_args()
+
+    if not os.path.exists(args.parent_dir):
+        print(f"Error: Parent directory '{args.parent_dir}' does not exist.")
+        return
+
+    process_ner_dataset(args.parent_dir, args.output, args.image_root, args.n_entries, args.prefix, args.url_root)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/split_jsonl.py b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/split_jsonl.py
new file mode 100644
index 000000000..a9f77e167
--- /dev/null
+++ b/reports/docs/ernie_tutorial/paddleocr_vl_prompt/paddleocr_vl/tools/split_jsonl.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+"""
+Split a JSONL file into train and validation sets.
+
+Usage:
+    python split_jsonl.py <input_file> <output_prefix> [--train_ratio 0.8] [--seed 42]
+
+Example:
+    python split_jsonl.py data.jsonl output/data --train_ratio 0.8 --seed 42
+
+    This will create:
+    - output/data_train.jsonl (80% of data)
+    - output/data_val.jsonl (20% of data)
+"""
+
+import json
+import random
+import argparse
+import sys
+from pathlib import Path
+
+
+def split_jsonl(input_file, output_prefix, train_ratio=0.8, seed=42):
+    """
+    Split a JSONL file into train and validation sets. 
+ + Args: + input_file: Path to input JSONL file + output_prefix: Output file prefix (will create {prefix}_train.jsonl and {prefix}_val.jsonl) + train_ratio: Ratio of training data (default: 0.8) + seed: Random seed for reproducibility (default: 42) + """ + + # Set random seed for reproducibility + random.seed(seed) + + # Read all lines from JSONL file + data = [] + try: + with open(input_file, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + data.append(json.loads(line)) + except json.JSONDecodeError as e: + print(f"Warning: Failed to parse JSON at line {line_num}: {e}", file=sys.stderr) + except FileNotFoundError: + print(f"Error: Input file '{input_file}' not found.", file=sys.stderr) + sys.exit(1) + except Exception as e: + print(f"Error reading file: {e}", file=sys.stderr) + sys.exit(1) + + if not data: + print("Error: No valid data found in input file.", file=sys.stderr) + sys.exit(1) + + # Shuffle data + random.shuffle(data) + + # Split data + split_idx = int(len(data) * train_ratio) + train_data = data[:split_idx] + val_data = data[split_idx:] + + # Create output directory if needed + output_path = Path(output_prefix) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Write train set + train_file = f"{output_prefix}_train.jsonl" + try: + with open(train_file, 'w', encoding='utf-8') as f: + for item in train_data: + f.write(json.dumps(item, ensure_ascii=False) + '\n') + print(f"Train set: {train_file} ({len(train_data)} samples)") + except Exception as e: + print(f"Error writing train file: {e}", file=sys.stderr) + sys.exit(1) + + # Write validation set + val_file = f"{output_prefix}_val.jsonl" + try: + with open(val_file, 'w', encoding='utf-8') as f: + for item in val_data: + f.write(json.dumps(item, ensure_ascii=False) + '\n') + print(f"Val set: {val_file} ({len(val_data)} samples)") + except Exception as e: + print(f"Error writing val file: {e}", file=sys.stderr) + 
sys.exit(1) + + print(f"\nTotal samples: {len(data)}") + print(f"Train ratio: {train_ratio:.1%}") + print(f"Val ratio: {1 - train_ratio:.1%}") + + +def main(): + parser = argparse.ArgumentParser( + description="Split a JSONL file into train and validation sets.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python split_jsonl.py data.jsonl output/data + python split_jsonl.py data.jsonl output/data --train_ratio 0.9 + python split_jsonl.py data.jsonl output/data --train_ratio 0.7 --seed 123 + """ + ) + + parser.add_argument('input_file', help='Input JSONL file path') + parser.add_argument('output_prefix', help='Output file prefix (will create {prefix}_train.jsonl and {prefix}_val.jsonl)') + parser.add_argument('--train_ratio', type=float, default=0.8, + help='Ratio of training data (default: 0.8)') + parser.add_argument('--seed', type=int, default=42, + help='Random seed for reproducibility (default: 42)') + + args = parser.parse_args() + + # Validate train_ratio + if not 0 < args.train_ratio < 1: + print("Error: train_ratio must be between 0 and 1 (exclusive).", file=sys.stderr) + sys.exit(1) + + split_jsonl(args.input_file, args.output_prefix, args.train_ratio, args.seed) + + +if __name__ == '__main__': + main()
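
Taken together, the tools in this diff form a pipeline: `batch_extract_ner.py` calls the ERNIE API to turn document images into NER JSON files, `process_ner_dataset.py` converts those files into SFT records, and `split_jsonl.py` splits the result into train/val sets. The record format that `process_ner_dataset.py` emits can be summarized in a minimal, self-contained sketch (the image path and invoice fields below are illustrative, not from a real dataset):

```python
import json


def build_sft_sample(image_url: str, data: dict, mask_keys=None) -> dict:
    """Build one SFT record in the JSONL schema produced by process_ner_dataset.py.

    The "mask" entry is the prompt (an ``OCR:`` instruction whose JSON skeleton
    lists the requested keys); the "no_mask" entry is the target answer with
    the actual values.
    """
    if mask_keys is None:
        # Full-document variant: prompt is "OCR:{}", answer is the whole dict.
        prompt = "OCR:{}"
        answer = json.dumps(data, ensure_ascii=False)
    else:
        # Field-subset variant: placeholders mirror each value's type,
        # as in _generate_mask_text ({} for dict, [] for list, "" otherwise).
        skeleton = {
            k: ({} if isinstance(data.get(k), dict)
                else [] if isinstance(data.get(k), list)
                else "")
            for k in mask_keys
        }
        prompt = "OCR:" + json.dumps(skeleton, ensure_ascii=False)
        answer = json.dumps({k: data[k] for k in mask_keys if k in data},
                            ensure_ascii=False)
    return {
        "image_info": [{"matched_text_index": 0, "image_url": image_url}],
        "text_info": [
            {"text": prompt, "tag": "mask"},
            {"text": answer, "tag": "no_mask"},
        ],
    }


if __name__ == "__main__":
    data = {"名称": "杭州万力酒店管理有限公司", "开票人": "祝营营"}
    sample = build_sft_sample("./invoices/test.jpg", data, mask_keys=["开票人"])
    print(json.dumps(sample, ensure_ascii=False))
```

Each record pairs a `mask` prompt (what the model sees) with a `no_mask` answer (what it must generate); this is what teaches the fine-tuned model to answer `OCR:{"字段": ""}`-style information-extraction prompts.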