Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
909 changes: 909 additions & 0 deletions reports/docs/ernie_tutorial/paddleocr_vl_prompt/PaddleOCR_VL_SFT.md

Large diffs are not rendered by default.

911 changes: 911 additions & 0 deletions reports/docs/ernie_tutorial/paddleocr_vl_prompt/PaddleOCR_VL_SFT_zh.md

Large diffs are not rendered by default.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
{
"image": "/media/shun/bigdata/Dataset/增值税普通发票/zzsptfp/b0.jpg",
"data": {
"发票名称": "广东增值税专用发票",
"发票编号": "12271524",
"发票代码": "4400154130",
"开票日期": "2016年06月12日",
"购买方": {
"名称": "深圳市购机汇网络有限公司",
"纳税人识别号": "440300083885931",
"地址、电话": "深圳市龙华新区民治街道民治大道展滔科技大厦A12070755-23806606",
"开户行及账号": "中国工商银行股份有限公司深圳园岭支行4000024709200172809"
},
"密码区": "<<1<//3*26-++936-9<9*575>39 -<5//81>84974<00+7>2*0*53-+ +125*++9+-///5-7+/-0>8<9815 5<3/8*+//81/84+>6>4*36>4538",
"货物或应税劳务、服务名称": [
{
"名称": "小米 红米3 全网通版 时尚金色",
"规格型号": "红米3",
"单位": "个",
"数量": "5",
"单价": "597.43589744",
"金额": "2987.18",
"税率": "17%",
"税额": "507.82"
},
{
"名称": "移动联通电信4G手机 双卡双待",
"规格型号": "",
"单位": "",
"数量": "",
"单价": "",
"金额": "",
"税率": "",
"税额": ""
}
],
"合计": {
"金额": "¥2987.18",
"税额": "¥507.82"
},
"价税合计(大写)": "叁仟肆佰玖拾伍圆整",
"价税合计(小写)": "¥3495.00",
"销售方": {
"名称": "广州晶东贸易有限公司",
"纳税人识别号": "91440101664041243T",
"地址、电话": "广州市黄埔区九龙镇九龙工业园凤凰三横路99号 66215500",
"开户行及账号": "工行北京路支行3602000919200384952"
},
"备注": "dd42982413947(00001,1952)7996有限",
"收款人": "王梅",
"复核": "张雪",
"开票人": "陈秋燕",
"销售方(章)": "广州晶东贸易有限公司 发票专用章"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"image": "/media/shun/bigdata/Dataset/增值税普通发票/zzsptfp/b1.jpg",
"data": {
"发票名称": "广东增值税专用发票",
"发票编号": "4400154130",
"发票代码": "12270242",
"开票日期": "2016年06月12日",
"购买方": {
"名称": "深圳市购机汇网络有限公司",
"纳税人识别号": "440300083885931",
"地址电话": "深圳市龙华新区民治街道民治大道展滔科技大厦A12070755-23806606",
"开户行及账号": "中国工商银行股份有限公司深圳园岭支行4000024709200172809"
},
"密码区": "6/**3-02848*6</*371/>>7137+ -<332/4845/*-2714*895**9768 />*0497-4/<377816+5+761/--5 127<8**32/4+45<4933///8>*48",
"货物或应税劳务服务名称": "小米 红米3 全网通版 时尚金色 移动联通电信4G手机 双卡双待 折扣(59.456%)",
"规格型号": "红米3",
"单位": "个",
"数量": "",
"单价": "597.43569744",
"金额": "2987.18",
"税率": "17%",
"税额": "507.82",
"合计": {
"金额": "¥1211.11",
"税额": "¥205.89",
"价税合计(大写)": "壹仟肆佰壹拾柒圆整",
"价税合计(小写)": "¥1417.00"
},
"销售方": {
"名称": "广州晶东贸易有限公司",
"纳税人识别号": "91440101664041243T",
"地址电话": "广州市黄埔区九龙镇九龙工业园凤凰三横路99号 56215500",
"开户行及账号": "工行北京路支行3602000919200384952"
},
"备注": "dd42981320128(00001,1956912801)",
"收款人": "王梅",
"复核": "张雪",
"开票人": "陈秋燕",
"销售方(章)": "广州晶东贸易有限公司 发票专用章"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
{
"image": "/media/shun/bigdata/Dataset/增值税普通发票/zzsptfp/b2.jpg",
"data": {
"发票信息": {
"发票名称": "广东增值税专用发票",
"发票代码": "4400154130",
"发票号码": "12269558",
"开票日期": "2016年06月12日",
"密码区": "+6><3>439<-311-/*62*0-28+41*12*067>-+*-7-*655<29449<4/5<9</888/<-093-797>**6682-38<73312*047>82<+3757537//",
"机打号码": "",
"机器编号": ""
},
"购买方信息": {
"名称": "深圳市购机汇网络有限公司",
"纳税人识别号": "440300083885931",
"地址、电话": "深圳市龙华新区民治街道民治大道展滔科技大厦A12070756-23806606",
"开户行及账号": "中国工商银行股份有限公司深圳园岭支行4000024709200172809"
},
"销售方信息": {
"名称": "广州晶东贸易有限公司",
"纳税人识别号": "91440101664041243T",
"地址、电话": "广州市黄埔区九龙镇九龙工业园凤凰三横路99号 66215500",
"开户行及账号": "工行北京路支行3602000919200384952"
},
"货物或应税劳务、服务信息": {
"名称": "小米 红米3 全网通版 时尚金色 移动联通电信4G手机 双卡双待",
"规格型号": "红米3",
"单位": "个",
"数量": "5",
"单价": "597.43589744",
"金额": "2987.18",
"税率": "17%",
"税额": "507.82"
},
"金额合计": {
"金额合计(小写)": "¥3495.00",
"金额合计(大写)": "叁仟肆佰玖拾伍圆整"
},
"其他信息": {
"收款人": "王梅",
"复核": "张雪",
"开票人": "陈秋燕",
"销售方(章)": "广州晶东贸易有限公司 发票专用章"
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#!/usr/bin/env python3
"""
找到 paddlex 中的 pipeline.py 文件,将 assert prompt_label 改为 warning 提示。
"""

import re
import sys
from pathlib import Path

try:
import paddlex
except ImportError:
print("错误:未能导入 paddlex,请确保已安装")
sys.exit(1)


def main():
# 根据 paddlex 模块的位置确定目标文件
paddlex_file = Path(paddlex.__file__)
paddlex_root = paddlex_file.parent
target_file = paddlex_root / "inference/pipelines/paddleocr_vl/pipeline.py"

# 检查文件是否存在
if not target_file.exists():
print(f"错误:文件不存在: {target_file}")
sys.exit(1)

# 读取文件内容
with open(target_file, 'r', encoding='utf-8') as f:
content = f.read()

# 定义要替换的模式
old_pattern = r''' assert prompt_label\.lower\(\) in \[
"ocr",
"formula",
"table",
"chart",
\], f"Layout detection is disabled \(use_layout_detection=False\)\. 'prompt_label' must be one of \['ocr', 'formula', 'table', 'chart'\], but got '\{prompt_label\}'\."'''

# 新的代码(使用 warning 代替 assert)
new_code = ''' if prompt_label.lower() not in [
"ocr",
"formula",
"table",
"chart",
]:
import warnings
warnings.warn(
f"Layout detection is disabled (use_layout_detection=False). "
f"'prompt_label' must be one of ['ocr', 'formula', 'table', 'chart'], "
f"but got '{prompt_label}'. Program will continue anyway.",
UserWarning
)'''

# 执行第一个替换:assert 改为 warning
new_content = re.sub(old_pattern, new_code, content)

# 检查是否进行了第一个替换
if new_content == content:
print("警告:未找到匹配的 assert 代码模式,请检查文件内容")
else:
print("成功:已将 assert 改为 warning")

# 执行第二个替换:text_prompt 改为条件表达式
old_text_prompt = 'text_prompt = "OCR:"'
new_text_prompt = 'text_prompt = "OCR:" if block_label.lower() == "ocr" else block_label; block_label = block_label.lower()'

if old_text_prompt in new_content:
new_content = new_content.replace(old_text_prompt, new_text_prompt)
print("成功:已将 text_prompt 改为条件表达式")
else:
print("警告:未找到 text_prompt 的代码")

# 执行第三个替换:label 移除 .lower()
old_label = '"label": prompt_label.lower(),'
new_label = '"label": prompt_label,'

if old_label in new_content:
new_content = new_content.replace(old_label, new_label)
print("成功:已将 label 改为 prompt_label")
else:
print("警告:未找到 label 的代码")

# 写回文件
try:
with open(target_file, 'w', encoding='utf-8') as f:
f.write(new_content)
print(f"成功:已完成所有 patch,文件位置: {target_file}")
except Exception as e:
print(f"错误:写入文件时出现异常: {e}")
sys.exit(1)


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
### data
train_dataset_type: "erniekit"
eval_dataset_type: "erniekit"
train_dataset_path: "/home/aistudio/paddleocr_vl/output_train.jsonl"
train_dataset_prob: "1.0"
eval_dataset_path: "/home/aistudio/paddleocr_vl/output_val.jsonl"
eval_dataset_prob: "1.0"
max_seq_len: 16384
num_samples_each_epoch: 6000000
use_pic_id: False
sft_replace_ids: True
sft_image_normalize: True
sft_image_rescale: True
image_dtype: "float32"

### model
model_name_or_path: "/home/aistudio/paddleocr_vl/paddleocr_vl_model"
fine_tuning: Full
multimodal: True
use_flash_attention: True
use_sparse_flash_attn: True

### finetuning
# base
stage: OCR-VL-SFT
seed: 23
do_train: True
# do_eval: True
distributed_dataloader: False
dataloader_num_workers: 8
prefetch_factor: 10
batch_size: 1
packing_size: 8
packing: True
padding: False
num_train_epochs: 2
max_steps: 80
# eval_batch_size: 1
# eval_iters: 50
# eval_steps: 100
# evaluation_strategy: steps
save_steps: 20
save_total_limit: 5
save_strategy: steps
logging_steps: 1
release_grads: True
gradient_accumulation_steps: 8
logging_dir: /home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT/tensorboard_logs/
output_dir: /home/aistudio/paddleocr_vl/PaddleOCR-VL-SFT
disable_tqdm: True

# train
warmup_steps: 1
learning_rate: 5.0e-6
lr_scheduler_type: cosine
min_lr: 5.0e-7
layerwise_lr_decay_bound: 1.0
from_scratch: 0

# optimizer
weight_decay: 0.1
adam_epsilon: 1.0e-8
adam_beta1: 0.9
adam_beta2: 0.95

# performance
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
sharding_parallel_degree: 1
sharding: stage1
sequence_parallel: False
pipeline_parallel_config: enable_delay_scale_loss enable_release_grads disable_partial_send_recv
recompute: True
recompute_granularity: "full"
recompute_use_reentrant: True
compute_type: bf16
fp16_opt_level: O2
disable_ckpt_quant: True
# amp_master_grad: True
amp_custom_white_list:
- lookup_table
- lookup_table_v2
- flash_attn
- matmul
- matmul_v2
- fused_gemm_epilogue
amp_custom_black_list:
- reduce_sum
- softmax_with_cross_entropy
- c_softmax_with_cross_entropy
- elementwise_div
- sin
- cos
unified_checkpoint: True
# unified_checkpoint_config: async_save
convert_from_hf: True
save_to_hf: True
Loading