[Hackathon 5th No.73] ToT (PaddlePaddle#7660)

* Hackathon TASK73 ToT: finish meta/llama2 version
* update readme tutorial
* modify according to Lint
* modify according to Lint: resolve one unused variable
* Delete LICENSE
* Update LICENSE
* black format
* isort format
* Update search_crosswords-dfs.ipynb
* update file formats
* Update LICENSE (×4)
* delete test data
* delete some unnecessary files according to comments
* add paddlenlp-llama2: add llama2 in paddlenlp
* fix one bug
* fix outputs bug: format data structure
* delete meta/llama2
* modify according to comments: add acknowledgements to readme; change png into url in readme; add all the models supported by paddlenlp
* change according to comments
* Delete .gitignore
* Create .gitignore
* Move directory
* Add tree of thoughts scripts
* add first dir
* add note
* Update README.md: add test results of facebook/llama-2-7b-chat and llama-2-13b-chat
* Update requirements.txt: delete unnecessary packages
* Update demo.py: add Ernie
* Update .gitignore: delete pyproject.toml
* Update run.py: add Ernie
* Update __init__.py: add Ernie
* chat templates
* add Ernie
* Update llama.py: make compatible with Ernie
* Update bfs.py: make compatible with Ernie
* Update models.py: make compatible with Ernie
* Update run.py
* format style (×8)
* remove the duplicated "Test results" section
* remove the Ernie token; read it from an environment variable instead
* format style (×2)
* remove commented-out code

Co-authored-by: root <[email protected]>
1 parent 3a42280 · commit cdfa861 · 20 changed files with 1,989 additions and 1 deletion
`.gitignore`:

```diff
@@ -123,4 +123,4 @@ FETCH_HEAD
 # vscode
 .vscode
-./ppdiffusers/ppdiffusers/version.py
+./ppdiffusers/ppdiffusers/version.py
```
`README.md` (new file):
# Tree of Thoughts (ToT)



An implementation of the code, prompts, and model outputs from the paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601).

## Setup
1. Installation:
```bash
git clone git@github.com:PaddlePaddle/PaddleNLP.git
cd PaddleNLP/pipelines/examples/tree-of-thought/
pip install -r requirements.txt
```

2. Download the test data from https://github.com/ErnestinaQiu/tree-of-thought-llm/tree/master/src/tot/data and place it under `pipelines/examples/tree-of-thought/src/tot/data`.
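For example, assuming you work from the PaddleNLP repo root (the temporary clone path is illustrative):

```bash
# Fetch the upstream ToT repo that ships the test data, then copy
# src/tot/data into this example (destination per the step above).
git clone https://github.com/ErnestinaQiu/tree-of-thought-llm.git /tmp/tree-of-thought-llm
cp -r /tmp/tree-of-thought-llm/src/tot/data pipelines/examples/tree-of-thought/src/tot/data
```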
## Quick Start
The following script tries to solve the Game of 24 with the inputs 4 5 6 10 (it may be a bit slow because it uses llama-2-7b-chat).

Run it under the directory pipelines/examples/tree-of-thought/:

```bash
python demo.py
```

In code, the demo does the following:

```python
import argparse

from tot.methods.bfs import solve
from tot.tasks.game24 import Game24Task

args = argparse.Namespace(
    backend="llama-2-7b-chat",
    temperature=0.6,
    task="game24",
    naive_run=False,
    prompt_sample=None,
    method_generate="propose",
    method_evaluate="value",
    method_select="greedy",
    n_generate_sample=1,
    n_evaluate_sample=3,
    n_select_sample=5,
)

task = Game24Task()
ys, infos = solve(args, task, 900)
print(ys[0])
```
The output may look like the following (note that it is not deterministic and can sometimes be wrong):
```
10 - 4 = 6 (left: 5 6 6)
5 * 6 = 30 (left: 6 30)
30 - 6 = 24 (left: 24)
Answer: (5 * (10 - 4)) - 6 = 24
```

## Paper experiments

Run the experiments via ``sh scripts/{game24, text, crosswords}/{standard_sampling, cot_sampling, bfs}.sh``.
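For instance, to run the ToT + BFS experiment on the Game of 24:

```bash
sh scripts/game24/bfs.sh
```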
The very simple ``run.py`` implements the ToT + BFS algorithm as well as naive IO/CoT sampling. Some key arguments (a minimal sketch of the loop they control follows the list):

- ``--naive_run``: if True, run naive IO/CoT sampling instead of ToT + BFS.
- ``--prompt_sample`` (choices=[``standard``, ``cot``]): sampling prompt
- ``--method_generate`` (choices=[``sample``, ``propose``]): thought generator, whether to sample independent thoughts (used in creative writing) or propose sequential thoughts (used in Game of 24)
- ``--method_evaluate`` (choices=[``value``, ``vote``]): state evaluator, whether to value states independently (used in Game of 24) or vote on states together (used in creative writing)
- ``--n_generate_sample``: number of times to prompt for thought generation
- ``--n_evaluate_sample``: number of times to prompt for state evaluation
- ``--n_select_sample``: number of states to keep at each step (i.e. ``b`` in the paper's ToT + BFS algorithm)
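A minimal, self-contained sketch of the ToT + BFS loop these flags control; ``propose_thoughts`` and ``value_state`` are hypothetical stand-ins for the prompt-driven generator and evaluator, not functions from this repo:

```python
from typing import Callable, List


def tot_bfs(
    x: str,                                             # problem input
    steps: int,                                         # depth of the thought tree
    propose_thoughts: Callable[[str, str], List[str]],  # cf. --method_generate
    value_state: Callable[[str, str], float],           # cf. --method_evaluate
    n_select_sample: int,                               # b in the paper
) -> List[str]:
    """Breadth-first search over partial solutions ("thoughts")."""
    frontier = [""]  # start from an empty partial solution
    for _ in range(steps):
        # generation: extend every kept state with candidate next thoughts
        candidates = [y + t for y in frontier for t in propose_thoughts(x, y)]
        # evaluation: score each candidate state on its own
        scored = sorted(candidates, key=lambda y: value_state(x, y), reverse=True)
        # selection: "greedy" keeps the b best states (--method_select greedy)
        frontier = scored[:n_select_sample]
    return frontier
```

With ``--method_generate propose``, ``--method_evaluate value`` and ``--method_select greedy``, ``run.py`` drives this same pattern with LLM-backed prompts.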
## Paper trajectories

``logs/`` contains all the trajectories from the paper's experiments, except for ``logs/game24/gpt-4_0.7_propose1_value3_greedy5_start900_end1000.json``, which was regenerated after the paper (as the original experiment was run in a notebook) and scored 69\% instead of the original 74\% due to randomness in GPT decoding. We hope to aggregate multiple runs in the future to account for sampling randomness and update the paper, but this should not affect the paper's main conclusions.
## Task scripts for the paper experiments
### crosswords
```bash
# --task crosswords: task name (crosswords)
# --task_start_index 0: first index in the crosswords dataset
# --task_end_index 20: last index (exclusive) in the crosswords dataset
# --prompt_sample cot: sampling prompt style (CoT)
# --n_generate_sample 10: prompt for thought generation 10 times
python run.py \
    --task crosswords \
    --task_start_index 0 \
    --task_end_index 20 \
    --naive_run \
    --prompt_sample cot \
    --n_generate_sample 10
```

```bash
# --naive_run: run naive IO/CoT sampling
# --prompt_sample standard: sampling prompt style (standard)
python run.py \
    --task crosswords \
    --task_start_index 0 \
    --task_end_index 20 \
    --naive_run \
    --prompt_sample standard \
    --n_generate_sample 10
```
### game24 (Game of 24)
```bash
# --task game24: task name (Game of 24)
# --task_start_index 900: first index in the game24 dataset
# --task_end_index 1000: last index (exclusive) in the game24 dataset
# --method_generate propose: thought generator; propose sequential thoughts (used for Game of 24)
# --method_evaluate value: state evaluator; value states independently (used for Game of 24)
# --method_select greedy: selection strategy (greedy)
# --n_evaluate_sample 3: prompt for state evaluation 3 times
# --n_select_sample 5: number of states kept per step (b in the paper's ToT + BFS)
python run.py \
    --task game24 \
    --task_start_index 900 \
    --task_end_index 1000 \
    --method_generate propose \
    --method_evaluate value \
    --method_select greedy \
    --n_evaluate_sample 3 \
    --n_select_sample 5
```

```bash
# --naive_run: run naive IO/CoT sampling
# --prompt_sample cot: sampling prompt style (CoT)
python run.py \
    --task game24 \
    --task_start_index 900 \
    --task_end_index 1000 \
    --naive_run \
    --prompt_sample cot \
    --n_generate_sample 100
```

```bash
python run.py \
    --task game24 \
    --task_start_index 900 \
    --task_end_index 1000 \
    --naive_run \
    --prompt_sample standard \
    --n_generate_sample 100
```
### text (creative writing)
```bash
# --task text: task name (creative writing)
# --task_start_index 0: first index in the text dataset
# --task_end_index 100: last index (exclusive) in the text dataset
# --method_generate sample: thought generator; sample independent thoughts (used for creative writing)
# --method_evaluate vote: state evaluator; vote on states together (used for creative writing)
# --method_select greedy: selection strategy (greedy)
# --n_generate_sample 5: prompt for thought generation 5 times
# --n_evaluate_sample 5: prompt for state evaluation 5 times
# --n_select_sample 1: number of states kept per step (b in the paper's ToT + BFS)
python run.py \
    --task text \
    --task_start_index 0 \
    --task_end_index 100 \
    --method_generate sample \
    --method_evaluate vote \
    --method_select greedy \
    --n_generate_sample 5 \
    --n_evaluate_sample 5 \
    --n_select_sample 1 \
    --prompt_sample cot \
    --temperature 1.0
```

```bash
# --naive_run: run naive IO/CoT sampling
# --prompt_sample cot: sampling prompt style (CoT)
python run.py \
    --task text \
    --task_start_index 0 \
    --task_end_index 100 \
    --naive_run \
    --prompt_sample cot \
    --n_generate_sample 10 \
    --temperature 1.0
```

```bash
# --naive_run: run naive IO/CoT sampling
# --prompt_sample standard: sampling prompt style (standard)
python run.py \
    --task text \
    --task_start_index 0 \
    --task_end_index 100 \
    --naive_run \
    --prompt_sample standard \
    --n_generate_sample 10 \
    --temperature 1.0
```
## Test results
The tests used facebook/llama-2-7b-chat and facebook/llama-2-13b-chat from PaddleNLP, with temperature=0.6, decode_strategy="greedy_search", and max_new_tokens=512 (a sketch of these settings follows the table). The results are as follows:

|model|method|acc|
|----|----|----|
|llama-2-7b-chat|CoT|0|
|llama-2-7b-chat|standard sampling|0|
|llama-2-7b-chat|ToT|3%|
|llama-2-13b-chat|CoT|0|
|llama-2-13b-chat|standard sampling|0|
|llama-2-13b-chat|ToT|2%|
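For reference, a sketch of how these generation settings map onto PaddleNLP's ``generate`` API, assuming the standard ``AutoModelForCausalLM``/``AutoTokenizer`` interface (the prompt is illustrative, not the actual task prompt):

```python
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

name = "facebook/llama-2-7b-chat"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

inputs = tokenizer("Use 4 5 6 10 and the operators + - * / to reach 24.", return_tensors="pd")
ids, _ = model.generate(
    **inputs,
    decode_strategy="greedy_search",  # decoding strategy used for the table above
    max_new_tokens=512,
    temperature=0.6,  # only affects sampling strategies; listed for completeness
)
print(tokenizer.batch_decode(ids, skip_special_tokens=True)[0])
```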
## How to add a new task

Setting up a new task is easy, and mainly involves two steps (a sketch of a task class follows the list).
* Set up a new task class in ``tot/tasks/`` and a task file in ``tot/data/``. See ``tot/tasks/game24.py`` for an example. Add the task to ``tot/tasks/__init__.py``.
* Set up task-specific prompts in ``tot/prompts/``. See ``tot/prompts/game24.py`` for an example. Depending on the nature of the task, choose ``--method_generate`` (choices=[``sample``, ``propose``]) and ``--method_evaluate`` (choices=[``value``, ``vote``]) and their corresponding prompts.
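A minimal sketch of what such a task class could look like; it only loosely mirrors ``tot/tasks/game24.py``, and the method set is an assumption rather than the exact interface:

```python
# src/tot/tasks/mytask.py -- a hypothetical new task
class MyTask:
    def __init__(self):
        # a real task would load its instances from a file under tot/data/
        self.data = [{"question": "2 + 2", "answer": "4"}]

    def __len__(self) -> int:
        return len(self.data)

    def get_input(self, idx: int) -> str:
        return self.data[idx]["question"]

    def test_output(self, idx: int, output: str) -> dict:
        # run.py reads the "r" field of this dict as the per-sample score
        return {"r": int(output.strip() == self.data[idx]["answer"])}
```

Register the class in ``tot/tasks/__init__.py`` so ``get_task`` can find it, and add matching prompts in ``tot/prompts/``.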
## Acknowledgements

We learned from the excellent framework design of Shunyu Yao et al., and we would like to thank the authors of Tree of Thoughts and their open-source community.
```bibtex
@misc{yao2023tree,
      title={{Tree of Thoughts}: Deliberate Problem Solving with Large Language Models},
      author={Shunyu Yao and Dian Yu and Jeffrey Zhao and Izhak Shafran and Thomas L. Griffiths and Yuan Cao and Karthik Narasimhan},
      year={2023},
      eprint={2305.10601},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```
`demo.py` (new file):
```python
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from src.llm import Ernie, Ernie_llm_list, llamaChatCompletion, llm_config
from src.tot.methods.bfs import solve
from src.tot.tasks.game24 import Game24Task

args = argparse.Namespace(
    backend="llama-2-7b-chat",
    temperature=0.6,
    task="game24",
    naive_run=False,
    prompt_sample=None,
    method_generate="propose",
    method_evaluate="value",
    method_select="greedy",
    n_generate_sample=1,
    n_evaluate_sample=3,
    n_select_sample=5,
    log_fp="log.txt",
)

task = Game24Task()
# pick the chat backend: a PaddleNLP llama checkpoint or an Ernie model
if args.backend in llm_config.keys():
    chatter = llamaChatCompletion(args.backend)
elif args.backend in Ernie_llm_list:
    chatter = Ernie(model=args.backend)
else:
    raise ValueError(f"Unsupported backend: {args.backend}")
ys, infos = solve(args, task, 900, chatter=chatter)
print(ys[0])
print(infos)
```
`requirements.txt` (new file):
```
aiohttp==3.8.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==23.1.0
certifi==2023.5.7
charset-normalizer==3.1.0
frozenlist==1.3.3
idna==3.4
mpmath==1.3.0
multidict==6.0.4
numpy==1.24.3
requests==2.31.0
sympy==1.12
tqdm==4.65.0
urllib3==2.0.2
yarl==1.9.2
pandas==2.0.3
erniebot==0.5.0
paddlenlp==2.7.1
paddlepaddle-gpu==2.6.0
```
`run.py` (new file):
```python
# coding=utf8, ErnestinaQiu

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
import time

from src.llm.llama import Ernie, Ernie_llm_list, llamaChatCompletion, llm_config
from src.tot.methods.bfs import naive_solve, solve
from src.tot.models import gpt_usage
from src.tot.tasks import get_task


def run(args, chatter):
    task = get_task(args.task)
    logs, cnt_avg, cnt_any = [], 0, 0
    if args.naive_run:
        file = f"./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json"
        metric_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_select}_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}_metric.txt"
    else:
        file = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json"
        metric_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}_metric.txt"
    os.makedirs(os.path.dirname(file), exist_ok=True)

    for i in range(args.task_start_index, args.task_end_index):
        args.log_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.log"
        args.query_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}_query.log"
        with open(args.log_fp, "a", encoding="utf8") as f:
            f.write(f"------ index: {i}")
        with open(args.query_fp, "a", encoding="utf8") as f:
            f.write(f"------ index: {i}")

        # reset the per-problem query log and reload the chat template
        chatter.query = []
        chatter.tokenizer.init_chat_template(
            os.path.join(os.getcwd(), "pipelines", "examples", "tree-of-thought", "src", "llm", "chat_template.json")
        )

        # solve
        if args.naive_run:
            ys, info = naive_solve(args, task, i, chatter=chatter)
        else:
            ys, info = solve(args, task, i, chatter=chatter)

        # log
        infos = [task.test_output(i, y) for y in ys]
        info.update({"idx": i, "ys": ys, "infos": infos, "usage_so_far": gpt_usage(args.backend)})
        logs.append(info)
        with open(file, "w") as f:
            json.dump(logs, f, indent=4)

        # log main metric
        accs = [info["r"] for info in infos]
        cnt_avg += sum(accs) / len(accs)
        cnt_any += any(accs)
        mes = f"{i}, 'sum(accs)', {sum(accs)}, 'cnt_avg', {cnt_avg}, 'cnt_any', {cnt_any}, '\n'"
        with open(metric_fp, "a", encoding="utf8") as f:
            f.write(mes)

        with open(args.query_fp, "a", encoding="utf8") as f:
            f.write(json.dumps(chatter.query))

    n = args.task_end_index - args.task_start_index
    mes2 = f"cnt_avg / n: {cnt_avg / n}, cnt_any / n: {cnt_any / n}"
    mes3 = f"'usage_so_far', {gpt_usage(args.backend)}"
    with open(metric_fp, "a", encoding="utf8") as f:
        f.write(mes2)
        f.write(mes3)


llm_backend_choices = list(llm_config.keys())


def parse_args():
    parser = argparse.ArgumentParser()
    # allow both the PaddleNLP llama checkpoints and the Ernie models as backends
    parser.add_argument("--backend", type=str, choices=llm_backend_choices + Ernie_llm_list, default="llama-2-7b-chat")
    parser.add_argument("--temperature", type=float, default=0.6)

    parser.add_argument("--task", type=str, required=True, choices=["game24", "text", "crosswords"])
    parser.add_argument("--task_start_index", type=int, default=900)
    parser.add_argument("--task_end_index", type=int, default=1000)

    parser.add_argument("--naive_run", action="store_true")
    parser.add_argument(
        "--prompt_sample", type=str, choices=["standard", "cot"]
    )  # only used when method_generate = sample, or naive_run

    parser.add_argument("--method_generate", type=str, choices=["sample", "propose"])
    parser.add_argument("--method_evaluate", type=str, choices=["value", "vote"])
    parser.add_argument("--method_select", type=str, choices=["sample", "greedy"], default="greedy")
    parser.add_argument("--n_generate_sample", type=int, default=1)  # only thing needed if naive_run
    parser.add_argument("--n_evaluate_sample", type=int, default=1)
    parser.add_argument("--n_select_sample", type=int, default=1)

    parser.add_argument("--query_fp", type=str, default=f"./logs/default/query_{int(time.time())}.log")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    if args.backend in llm_backend_choices:
        chatter = llamaChatCompletion(args.backend)
    elif args.backend in Ernie_llm_list:
        chatter = Ernie(model=args.backend)
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")
    run(args, chatter=chatter)
```