Add ERNIE Tiny in model zoo #4011

Merged
merged 16 commits into from
Jan 11, 2023

LiuChiachi (Contributor) commented Dec 5, 2022

PR types

New features

PR changes

Models & Plans

Description

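Open-source the ERNIE 3.0 Tiny models and the on-device semantic compression solution.

    • model zoo unit tests
    • Link: compression API documentation
    • Mobile
      • Mobile deployment verification and performance testing of the vocabulary-quantized model
    • Server
      • The model without vocabulary quantization already runs on all backends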

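Appendix:

The example addresses intent recognition and slot filling.
The dataset is the NLPCC 2018 in-vehicle task-oriented dialogue dataset (source).
Dataset format:
[image]
The labels are summarized below: the left column is intent recognition (text classification) and the right column is slot filling (sequence labeling).
[image]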

Preprocessing (input_preprocess) takes a list of texts as input and outputs a dict {"query": query_list}.
Postprocessing is split into intent_cls_postprocess and slot_cls_postprocess, which handle the outputs of the intent recognition and slot filling tasks respectively:
{'intent': 'navigation.cancel_navigation', 'confidence': array([0.12247936], dtype=float32)}
{'value': [{'slot': 'singer', 'entity': '周华健', 'pos': [3, 4, 5]}, {'slot': 'song', 'entity': '花心', 'pos': [7, 8]}]}
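
The pre- and post-processing described above can be sketched roughly as follows. This is only an illustration under stated assumptions: the three function names come from the description, but their signatures, the BIO tag scheme, the softmax-based confidence, and the sample query are hypothetical and are not taken from the PR's code.

```python
import math


def input_preprocess(texts):
    """Wrap a list of query strings in the dict shape named in the description."""
    return {"query": list(texts)}


def intent_cls_postprocess(logits, id2intent):
    """Pick the top intent for one query and report its softmax probability.

    Assumption: `logits` is a plain list of floats, one per intent label.
    """
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # shifted for numerical stability
    best = max(range(len(logits)), key=lambda i: logits[i])
    return {"intent": id2intent[best], "confidence": exps[best] / sum(exps)}


def slot_cls_postprocess(text, tags):
    """Group per-character BIO tags (an assumed scheme) into slot records."""
    slots, current = [], None
    for pos, tag in enumerate(tags):
        if tag.startswith("B-"):
            # A new entity starts at this character.
            current = {"slot": tag[2:], "entity": text[pos], "pos": [pos]}
            slots.append(current)
        elif tag.startswith("I-") and current is not None and current["slot"] == tag[2:]:
            # Continue the entity that is currently open.
            current["entity"] += text[pos]
            current["pos"].append(pos)
        else:
            current = None
    return {"value": slots}
```

With the hypothetical query 来一首周华健的花心 and the tags O O O B-singer I-singer I-singer O B-song I-song, slot_cls_postprocess returns the singer and song records with positions [3, 4, 5] and [7, 8], which is the shape of the sample output above.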

paddle-bot bot commented Dec 5, 2022

Thanks for your contribution!

@LiuChiachi LiuChiachi marked this pull request as ready for review January 3, 2023 07:06

codecov bot commented Jan 3, 2023

Codecov Report

Merging #4011 (ac4b3f0) into develop (5710bbe) will increase coverage by 0.90%.
The diff coverage is 13.23%.

@@             Coverage Diff             @@
##           develop    #4011      +/-   ##
===========================================
+ Coverage    38.71%   39.62%   +0.90%     
===========================================
  Files          421      433      +12     
  Lines        59567    60982    +1415     
===========================================
+ Hits         23064    24165    +1101     
- Misses       36503    36817     +314     

| Impacted Files | Coverage Δ |
|---|---|
| paddlenlp/transformers/ofa_utils.py | 7.97% <0.00%> (-0.26%) ⬇️ |
| paddlenlp/trainer/trainer.py | 59.73% <8.33%> (-0.62%) ⬇️ |
| paddlenlp/trainer/trainer_compress.py | 8.94% <10.86%> (+<0.01%) ⬆️ |
| paddlenlp/trainer/compression_args.py | 52.72% <100.00%> (+1.78%) ⬆️ |
| paddlenlp/prompt/prompt_utils.py | 79.36% <0.00%> (-9.10%) ⬇️ |
| paddlenlp/transformers/ernie/modeling.py | 84.90% <0.00%> (-7.41%) ⬇️ |
| paddlenlp/prompt/template.py | 76.62% <0.00%> (-3.83%) ⬇️ |
| paddlenlp/transformers/clip/procesing.py | 38.77% <0.00%> (-1.97%) ⬇️ |
| paddlenlp/transformers/ernie_vil/procesing.py | 38.77% <0.00%> (-1.97%) ⬇️ |
| paddlenlp/prompt/prompt_model.py | 71.60% <0.00%> (-1.64%) ⬇️ |

... and 43 more


* [References](#参考文献)


This project open-sources the **ERNIE 3.0 Tiny v2** pretrained models and an **on-device semantic understanding compression solution**.
Collaborator

I'd suggest not leading with the Tiny v2 models. Instead:
1) First give an overall introduction to the work behind the ERNIE Tiny models, for example its focus on model distillation. You can take this straight from the ERNIE Tiny press-release copy.
2) Then describe what ERNIE Tiny v2 improves.

Contributor Author

Got it. I have revised the text following the suggestions above, though I'm not sure whether it is now a bit too long.


## Pretrained model performance

This project open-sources six Chinese models: **ERNIE 3.0 Tiny _Base_ v2**, **ERNIE 3.0 Tiny _Medium_ v2**, **ERNIE 3.0 Tiny _Mini_ v2**, **ERNIE 3.0 Tiny _Micro_ v2**, **ERNIE 3.0 Tiny _Nano_ v2**, and **ERNIE 3.0 Tiny _Pico_ v2**:
Collaborator

Don't highlight these directly here; just present the v2 vs. v1 model performance comparison.

LR=5e-5
EPOCHS=30

export finetuned_model=./output/BS${BS}_LR${LR}_${EPOCHS}EPOCHS
Collaborator

With Windows users in mind, try to avoid shell-style scripting here; just hard-code the parameter values.

Contributor Author

Thanks, noted and fixed. I'll avoid this style in the future.
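
For context on the portability concern, the directory name that the shell snippet builds from BS, LR, and EPOCHS can also be produced without shell variable expansion. The sketch below is a hypothetical helper, not code from the PR: the function name is made up, and the batch size of 64 in the usage note is an assumed value because BS is not shown in the excerpt.

```python
from pathlib import Path


def make_output_dir(batch_size, learning_rate, epochs, root="./output"):
    """Create and return <root>/BS{batch_size}_LR{learning_rate}_{epochs}EPOCHS.

    `learning_rate` is passed as a string (for example "5e-5") so the directory
    name keeps the spelling used on the command line.
    """
    path = Path(root) / f"BS{batch_size}_LR{learning_rate}_{epochs}EPOCHS"
    path.mkdir(parents=True, exist_ok=True)  # works the same on Windows and Linux
    return path
```

Calling make_output_dir(64, "5e-5", 30) would create output/BS64_LR5e-5_30EPOCHS relative to the current directory.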

export finetuned_model=./output/BS${BS}_LR${LR}_${EPOCHS}EPOCHS
mkdir $finetuned_model

python train.py \
Collaborator

Rename the script to run_train.py.

Contributor Author

Thanks, fixed.


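Although ERNIE 3.0 Tiny v2 already provides lightweight models that perform well and can be used directly after fine-tuning, you may need to deploy the model to production and want to further reduce its size and inference latency. In that case, you can use this project's **on-device semantic understanding compression solution** to compress the model fine-tuned in the previous step. For convenience, the [model compression API](../../../docs/compression.md) provides the following compression features.

The on-device model compression workflow is shown in the figure below: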
Collaborator

Should QAT be highlighted here? The --do_compress flag that turns on compression probably deserves emphasis too, because the scripts above and below are both run_train.py.

- [**ERNIE 3.0-Tiny-_Pico_-v2**](https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_tiny_pico_v2.pdparams) (_4-layer, 312-hidden, 2-heads_)

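ERNIE 3.0 Tiny models can be used for a wide range of NLU tasks such as text classification, textual inference, entity extraction, and question answering. The table below shows how the ERNIE 3.0 Tiny models perform on three types of datasets: in-domain, out-domain, and low-resourced. The CLUE metrics can be reproduced with the [PaddleNLP CLUE Benchmark](../../../examples/benchmark/clue).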

Collaborator

One question: this release distinguishes between Chinese and English models. Should that distinction be made here as well?

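| Original model | 82.34 | TBD | TBD | 69.0 |
| Original model + pruning (vocabulary + model width) | 82.11 | TBD | TBD | 64.0 |
| Original model + pruning (vocabulary + model width) + quantization (matmul) | 82.21 | TBD | TBD | 11.0 |
| Original model + pruning (vocabulary + model width) + quantization (matmul + Embedding) | TBD | TBD | TBD | 5.4 |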
Collaborator

What does "matmul + Embedding" mean here?

Contributor Author

The intent is to distinguish ordinary quantization from vocabulary quantization, and for now that is expressed by naming the ops that get quantized. QAT quantizes only the matmul ops, whereas what is usually called vocabulary quantization quantizes the three embedding ops (not just the word embedding).
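
To illustrate why quantizing embedding tables shrinks the model, here is a minimal sketch of symmetric int8 quantization applied to one embedding row. It is a generic textbook scheme chosen for illustration; the PR does not show which scheme the compression API actually uses, and both function names are hypothetical.

```python
def quantize_row(row, num_bits=8):
    """Symmetric quantization of one embedding row to signed integers.

    Returns the integer codes and the float scale needed to recover the values.
    """
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8
    max_abs = max(abs(v) for v in row)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    codes = [max(-qmax, min(qmax, round(v / scale))) for v in row]
    return codes, scale


def dequantize_row(codes, scale):
    """Map integer codes back to approximate float values."""
    return [c * scale for c in codes]
```

Each value is then stored in one byte instead of the four bytes of a float32, at the cost of a rounding error of at most half a scale step per value.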

@@ -0,0 +1,285 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Collaborator

2022->2023

Contributor Author

Thanks, fixed. This also matches the ERNIE 3.0 Tiny v2 paper, which was posted in 2023.

padding_mask |= (input_ids == 2) | (input_ids == 1)
return intent_logits, slot_logits, padding_mask

return intent_logits, slot_logits
Collaborator

Why do the dynamic-graph and static-graph models produce different outputs?

Contributor Author

Dynamic-graph training has to compute metrics, and evaluating slot (sequence labeling) accuracy needs information such as padding_mask. The exported model does not need the padding_mask computation.
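
A rough sketch of how such a mask can feed the slot metric: positions whose token id marks padding or a special token are skipped when counting correct predictions. The excerpt above only shows ids 1 and 2 being OR-ed into the mask. Treating id 0 as padding, the function name, and the plain-list inputs are assumptions made for illustration.

```python
def slot_accuracy(pred_ids, label_ids, input_ids, ignored_token_ids=(0, 1, 2)):
    """Token-level accuracy over positions that are not padding or special tokens.

    Assumption: ids 0, 1 and 2 are the padding and special-token ids.
    """
    correct = total = 0
    for pred, label, token in zip(pred_ids, label_ids, input_ids):
        if token in ignored_token_ids:
            continue  # masked position, excluded from the metric
        total += 1
        correct += int(pred == label)
    return correct / total if total else 0.0
```

Because this filtering matters only for evaluation, an exported inference graph can return just the two logits tensors.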


import paddle
from model import JointErnie, NLULoss
from utils import compute_metrics, get_label_name, read_example
Collaborator

The imports of built-in modules can be moved to the end.

Contributor Author

That is how I wrote it, but the pre-commit tool changed it back.



def main():
    parser = PdArgumentParser((ModelArguments, DataArguments, CompressionArguments))
Collaborator

Some of the config settings here are a bit messy and need discussion: the ordinary fine-tuning arguments and the compression-related arguments are mixed together.

intent_dim=len(intent2id),
slot_dim=len(slot2id),
dropout=model_args.dropout,
)
Collaborator

As I understand it, this kind of reload builds two models. Would calling load_dict directly be more appropriate?

Contributor Author

LiuChiachi, Jan 10, 2023

This model inherits from ErniePretrainedModel, so its parameters can be loaded with from_pretrained.

@@ -0,0 +1,181 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Collaborator

Same as above: 2022->2023

Contributor Author

Thanks. Done

@@ -1257,8 +1257,11 @@ def compute_loss(self, model, inputs, return_outputs=False):
        """
        if self.criterion is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        elif self.criterion is not None and "start_positions" in inputs and "end_positions" in inputs:
Collaborator

Will this change affect UIE?
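
The label-selection logic in the excerpt can be read as the standalone sketch below. The first branch mirrors the diff. The body of the second branch is not visible in the excerpt, so returning the (start_positions, end_positions) pair is an assumption, and pop_labels is a hypothetical helper rather than a Trainer method.

```python
def pop_labels(inputs, has_criterion=True):
    """Remove and return the label entries from a batch dict, if present."""
    if has_criterion and "labels" in inputs:
        return inputs.pop("labels")
    if has_criterion and "start_positions" in inputs and "end_positions" in inputs:
        # Assumed behaviour: span-style tasks carry two label tensors.
        return (inputs.pop("start_positions"), inputs.pop("end_positions"))
    return None
```

The second branch fires only when both keys are present and "labels" is absent, so a batch that carries "labels" follows the first path either way.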

_dynabert(self, self.model, args.output_dir)

del self.original_model
Collaborator

This way of destroying the model does not look like it will take effect.

Contributor Author

The GPU memory will not be released, but later allocations should be able to reuse the memory this model occupied instead of requesting new memory.
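
One Python-level reason a del like this may appear to do nothing: del removes a single reference, and the object is reclaimed only once no references remain. The sketch below shows this with dummy classes (all names are hypothetical). It says nothing about how the framework's GPU allocator caches or reuses memory.

```python
import gc
import weakref


class DummyModel:
    """Stand-in for a large model object."""


class DummyTrainer:
    def __init__(self):
        self.original_model = DummyModel()


def demo():
    trainer = DummyTrainer()
    other_reference = trainer.original_model  # e.g. held by another component
    probe = weakref.ref(other_reference)      # lets us observe the object's lifetime

    del trainer.original_model                # removes only the attribute
    alive_after_del = probe() is not None

    del other_reference                       # the last strong reference goes away
    gc.collect()
    alive_after_last_ref = probe() is not None
    return alive_after_del, alive_after_last_ref
```

Under CPython, demo() returns (True, False): the object survives the attribute deletion and is reclaimed only after the last reference is dropped.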


# from paddlenlp.transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained(output_dir)
# import pdb; pdb.set_trace()
Collaborator

Remove the debug code.

Contributor Author

Thanks, removed.

if "vocab_file" in tokenizer.init_config:
    tokenizer.init_config.pop("vocab_file")
f = open(os.path.join(output_dir, tokenizer.tokenizer_config_file), "w")
import json
Collaborator

Just import this at the top of the file.

Contributor Author

Thanks, fixed.

wawltor (Collaborator) left a comment

LGTM

@LiuChiachi LiuChiachi merged commit 091888b into PaddlePaddle:develop Jan 11, 2023