diff --git a/docs/advanced_guide/prompt.md b/docs/advanced_guide/prompt.md index b64aed60fe54..20da981169cc 100644 --- a/docs/advanced_guide/prompt.md +++ b/docs/advanced_guide/prompt.md @@ -35,6 +35,7 @@ Prompt API 提供了这类算法实现的基本模块,支持[PET](https://arxi * [实践教程](#实践教程) * [文本分类示例](#文本分类示例) * 其他任务示例(待更新) +* [Reference](#Reference) ## 如何定义模板 @@ -434,39 +435,39 @@ data_args, training_args = parser.parse_args_into_dataclasses() ```python - import paddle - from paddle.metric import Accuracy - - # 损失函数 - criterion = paddle.nn.CrossEntropyLoss() - - # 评估函数 - def compute_metrics(eval_preds): - metric = Accuracy() - correct = metric.compute(paddle.to_tensor(eval_preds.predictions), - paddle.to_tensor(eval_preds.label_ids)) - metric.update(correct) - acc = metric.accumulate() - return {"accuracy": acc} - - # 初始化 - trainer = PromptTrainer(model=prompt_model, - tokenizer=tokenizer, - args=training_args, - criterion=criterion, - train_dataset=data_ds, - eval_dataset=data_ds, - callbacks=None, - compute_metrics=compute_metrics) - - # 训练模型 - if training_args.do_train: - train_result = trainer.train(resume_from_checkpoint=None) - metrics = train_result.metrics - trainer.save_model() - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() +import paddle +from paddle.metric import Accuracy + +# 损失函数 +criterion = paddle.nn.CrossEntropyLoss() + +# 评估函数 +def compute_metrics(eval_preds): + metric = Accuracy() + correct = metric.compute(paddle.to_tensor(eval_preds.predictions), + paddle.to_tensor(eval_preds.label_ids)) + metric.update(correct) + acc = metric.accumulate() + return {"accuracy": acc} + +# 初始化 +trainer = PromptTrainer(model=prompt_model, + tokenizer=tokenizer, + args=training_args, + criterion=criterion, + train_dataset=data_ds, + eval_dataset=data_ds, + callbacks=None, + compute_metrics=compute_metrics) + +# 训练模型 +if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=None) + metrics = train_result.metrics + trainer.save_model() + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() ``` ## 实践教程 @@ -481,6 +482,14 @@ data_args, training_args = parser.parse_args_into_dataclasses() - [多层次文本分类示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications/text_classification/hierarchical/few-shot) +## Reference + +- Exploiting Cloze-Questions for Few-Shot Text Classification and Natural Language Inference. [[PDF]](https://arxiv.org/abs/2001.07676) +- GPT Understands, Too. [[PDF]](https://arxiv.org/abs/2103.10385) +- WARP: Word-level Adversarial ReProgramming. [[PDF]](https://aclanthology.org/2021.acl-long.381/) +- RGL: A Simple yet Effective Relation Graph Augmented Prompt-based Tuning Approach for Few-Shot Learning. [[PDF]](https://aclanthology.org/2022.findings-naacl.81/) +- R-Drop: Regularized Dropout for Neural Networks. [[PDF]](https://arxiv.org/abs/2106.14448) + ### 附录