
Commit 62ea40a

Fix path compatibility across different operating systems
1 parent a921f8b commit 62ea40a

5 files changed: +82 -36 lines

actuator.py  (+5 -5)

@@ -20,7 +20,7 @@
 import sys
 from argparse import ArgumentParser
-from dialogue.pytorch.seq2seq.actuator import torch_seq2seq
+# from dialogue.pytorch.seq2seq.actuator import torch_seq2seq
 from dialogue.tensorflow.seq2seq.actuator import tf_seq2seq
 from dialogue.tensorflow.smn.actuator import tf_smn
 from dialogue.tensorflow.transformer.actuator import tf_transformer
@@ -37,10 +37,10 @@ def main() -> None:
             "seq2seq": lambda: tf_seq2seq(),
             "smn": lambda: tf_smn(),
         },
-        "torch": {
-            "transformer": lambda: None,
-            "seq2seq": lambda: torch_seq2seq(),
-        }
+        # "torch": {
+        #     "transformer": lambda: None,
+        #     "seq2seq": lambda: torch_seq2seq(),
+        # }
     }

     options = parser.parse_args(sys.argv[1:5])
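For context, the nested dict in the hunk above maps a framework key to a table of model launchers; with the "torch" branch commented out, only the TensorFlow actuators remain reachable. Below is a minimal, self-contained sketch of that dispatch pattern. The option names --version and --model and the print-only stand-in functions are assumptions, since the rest of main() is not shown in this diff.

import sys
from argparse import ArgumentParser


def tf_transformer() -> None:
    print("launching the TensorFlow transformer actuator")  # stand-in for the real actuator


def tf_seq2seq() -> None:
    print("launching the TensorFlow seq2seq actuator")  # stand-in


def tf_smn() -> None:
    print("launching the TensorFlow SMN actuator")  # stand-in


def main() -> None:
    parser = ArgumentParser(description="dispatch sketch")
    parser.add_argument("--version", default="tf", type=str, help="framework key, e.g. tf")
    parser.add_argument("--model", default="seq2seq", type=str, help="model key, e.g. seq2seq")

    models = {
        "tf": {
            "transformer": lambda: tf_transformer(),
            "seq2seq": lambda: tf_seq2seq(),
            "smn": lambda: tf_smn(),
        },
        # The "torch" branch is commented out in this commit, so only the
        # TensorFlow launchers are registered here.
    }

    options = parser.parse_args(sys.argv[1:5])
    models[options.version][options.model]()  # look up the launcher and run it


if __name__ == "__main__":
    main()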

dialogue/tensorflow/seq2seq/actuator.py  (+9 -9)

@@ -48,18 +48,18 @@ def tf_seq2seq() -> NoReturn:
     parser.add_argument("--max_train_data_size", default=0, type=int, required=False, help="用于训练的最大数据大小")
     parser.add_argument("--max_valid_data_size", default=0, type=int, required=False, help="用于验证的最大数据大小")
     parser.add_argument("--max_sentence", default=40, type=int, required=False, help="单个序列的最大长度")
-    parser.add_argument("--dict_path", default="data\\preprocess\\seq2seq_dict.json",
+    parser.add_argument("--dict_path", default="data/preprocess/seq2seq_dict.json",
                         type=str, required=False, help="字典路径")
-    parser.add_argument("--checkpoint_dir", default="checkpoints\\tensorflow\\seq2seq",
+    parser.add_argument("--checkpoint_dir", default="checkpoints/tensorflow/seq2seq",
                         type=str, required=False, help="检查点路径")
-    parser.add_argument("--resource_data_path", default="data\\LCCC.json", type=str, required=False, help="原始数据集路径")
-    parser.add_argument("--tokenized_data_path", default="data\\preprocess\\lccc_tokenized.txt",
+    parser.add_argument("--resource_data_path", default="data/LCCC.json", type=str, required=False, help="原始数据集路径")
+    parser.add_argument("--tokenized_data_path", default="data/preprocess/lccc_tokenized.txt",
                         type=str, required=False, help="处理好的多轮分词数据集路径")
-    parser.add_argument("--preprocess_data_path", default="data\\preprocess\\single_tokenized.txt",
+    parser.add_argument("--preprocess_data_path", default="data/preprocess/single_tokenized.txt",
                         type=str, required=False, help="处理好的单轮分词数据集路径")
-    parser.add_argument("--valid_data_path", default="data\\preprocess\\single_tokenized.txt", type=str,
+    parser.add_argument("--valid_data_path", default="data/preprocess/single_tokenized.txt", type=str,
                         required=False, help="处理好的单轮分词验证评估用数据集路径")
-    parser.add_argument("--history_image_dir", default="data\\history\\seq2seq\\", type=str, required=False,
+    parser.add_argument("--history_image_dir", default="data/history/seq2seq/", type=str, required=False,
                         help="数据指标图表保存路径")
     parser.add_argument("--valid_freq", default=5, type=int, required=False, help="验证频率")
     parser.add_argument("--checkpoint_save_freq", default=2, type=int, required=False, help="检查点保存频率")
@@ -72,9 +72,9 @@ def tf_seq2seq() -> NoReturn:
     parser.add_argument("--start_sign", default="<start>", type=str, required=False, help="序列开始标记")
     parser.add_argument("--end_sign", default="<end>", type=str, required=False, help="序列结束标记")
     parser.add_argument("--unk_sign", default="<unk>", type=str, required=False, help="未登录词")
-    parser.add_argument("--encoder_save_path", default="models\\tensorflow\\seq2seq\\encoder", type=str,
+    parser.add_argument("--encoder_save_path", default="models/tensorflow/seq2seq/encoder", type=str,
                         required=False, help="Encoder的SaveModel格式保存路径")
-    parser.add_argument("--decoder_save_path", default="models\\tensorflow\\seq2seq\\decoder", type=str,
+    parser.add_argument("--decoder_save_path", default="models/tensorflow/seq2seq/decoder", type=str,
                         required=False, help="Decoder的SaveModel格式保存路径")

     options = parser.parse_args().__dict__
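The hunks above replace the Windows-only backslash separators (e.g. "data\\preprocess\\seq2seq_dict.json") with forward slashes, which Python's file APIs accept on Windows as well as on Linux and macOS. An alternative, not used by this commit, is to build the defaults with pathlib so the native separator is chosen at runtime; a minimal sketch reusing only the directory names from the diff:

# Sketch of an alternative to hard-coding the separator: let pathlib
# (or os.path.join) pick the host OS's separator at runtime.
# The --dict_path argument mirrors the one above; "data/preprocess" is
# taken from the defaults in this diff.
from argparse import ArgumentParser
from pathlib import Path

parser = ArgumentParser()
parser.add_argument("--dict_path",
                    default=str(Path("data", "preprocess", "seq2seq_dict.json")),
                    type=str, required=False, help="dictionary path")

options = parser.parse_args([])
# Prints data/preprocess/seq2seq_dict.json on Linux/macOS and
# data\preprocess\seq2seq_dict.json on Windows.
print(options.dict_path)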

dialogue/tensorflow/smn/actuator.py  (+6 -6)

@@ -48,24 +48,24 @@ def tf_smn() -> NoReturn:
     parser.add_argument("--valid_data_split", default=0.0, type=float, required=False, help="从训练数据集中划分验证数据的比例")
     parser.add_argument("--learning_rate", default=0.001, type=float, required=False, help="学习率")
     parser.add_argument("--max_database_size", default=0, type=int, required=False, help="最大数据候选数量")
-    parser.add_argument("--dict_path", default="data\\preprocess\\smn_dict.json", type=str, required=False, help="字典路径")
-    parser.add_argument("--checkpoint_dir", default="checkpoints\\tensorflow\\smn", type=str, required=False,
+    parser.add_argument("--dict_path", default="data/preprocess/smn_dict.json", type=str, required=False, help="字典路径")
+    parser.add_argument("--checkpoint_dir", default="checkpoints/tensorflow/smn", type=str, required=False,
                         help="检查点路径")
-    parser.add_argument("--train_data_path", default="data\\ubuntu_train.txt", type=str, required=False,
+    parser.add_argument("--train_data_path", default="data/ubuntu_train.txt", type=str, required=False,
                         help="处理好的多轮分词训练数据集路径")
-    parser.add_argument("--valid_data_path", default="data\\ubuntu_valid.txt", type=str, required=False,
+    parser.add_argument("--valid_data_path", default="data/ubuntu_valid.txt", type=str, required=False,
                         help="处理好的多轮分词验证数据集路径")
     parser.add_argument("--solr_server", default="http://49.235.33.100:8983/solr/smn/", type=str, required=False,
                         help="solr服务地址")
-    parser.add_argument("--candidate_database", default="data\\preprocess\\candidate.json", type=str, required=False,
+    parser.add_argument("--candidate_database", default="data/preprocess/candidate.json", type=str, required=False,
                         help="候选回复数据库")
     parser.add_argument("--epochs", default=5, type=int, required=False, help="训练步数")
     parser.add_argument("--batch_size", default=64, type=int, required=False, help="batch大小")
     parser.add_argument("--buffer_size", default=20000, type=int, required=False, help="Dataset加载缓冲大小")
     parser.add_argument("--start_sign", default="<start>", type=str, required=False, help="序列开始标记")
     parser.add_argument("--end_sign", default="<end>", type=str, required=False, help="序列结束标记")
     parser.add_argument("--unk_sign", default="<unk>", type=str, required=False, help="未登录词")
-    parser.add_argument("--model_save_path", default="models\\tensorflow\\smn", type=str,
+    parser.add_argument("--model_save_path", default="models/tensorflow/smn", type=str,
                         required=False, help="SaveModel格式保存路径")

     options = parser.parse_args().__dict__
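Alongside the path fixes, this file keeps its Solr-backed candidate retrieval settings (--solr_server, --candidate_database). For reference, a minimal sketch of querying such a core with pysolr, which is pinned in requirements.txt below; the query text and the "utterance" field name are assumptions, not taken from this repository:

# Minimal sketch of querying the candidate-retrieval Solr core referenced
# by --solr_server. The "utterance" field name and the query are
# hypothetical; the actual schema of the smn core is not part of this diff.
import pysolr

solr = pysolr.Solr("http://49.235.33.100:8983/solr/smn/", timeout=10)
results = solr.search("hello", rows=5)  # fetch the top-5 candidate documents

for doc in results:
    # Each result is a dict of stored fields.
    print(doc.get("utterance"))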

dialogue/tensorflow/transformer/actuator.py  (+9 -9)

@@ -63,22 +63,22 @@ def tf_transformer() -> NoReturn:
     parser.add_argument("--start_sign", default="<start>", type=str, required=False, help="序列开始标记")
     parser.add_argument("--end_sign", default="<end>", type=str, required=False, help="序列结束标记")
     parser.add_argument("--unk_sign", default="<unk>", type=str, required=False, help="未登录词")
-    parser.add_argument("--dict_path", default="data\\preprocess\\transformer_dict.json", type=str, required=False,
+    parser.add_argument("--dict_path", default="data/preprocess/transformer_dict.json", type=str, required=False,
                         help="字典路径")
-    parser.add_argument("--checkpoint_dir", default="checkpoints\\tensorflow\\transformer", type=str, required=False,
+    parser.add_argument("--checkpoint_dir", default="checkpoints/tensorflow/transformer", type=str, required=False,
                         help="检查点路径")
-    parser.add_argument("--raw_data_path", default="data\\LCCC.json", type=str, required=False, help="原始数据集路径")
-    parser.add_argument("--tokenized_data_path", default="data\\preprocess\\lccc_tokenized.txt", type=str,
+    parser.add_argument("--raw_data_path", default="data/LCCC.json", type=str, required=False, help="原始数据集路径")
+    parser.add_argument("--tokenized_data_path", default="data/preprocess/lccc_tokenized.txt", type=str,
                         required=False, help="处理好的多轮分词数据集路径")
-    parser.add_argument("--preprocess_data_path", default="data\\preprocess\\single_tokenized.txt", type=str,
+    parser.add_argument("--preprocess_data_path", default="data/preprocess/single_tokenized.txt", type=str,
                         required=False, help="处理好的单轮分词训练用数据集路径")
-    parser.add_argument("--valid_data_path", default="data\\preprocess\\single_tokenized.txt", type=str,
+    parser.add_argument("--valid_data_path", default="data/preprocess/single_tokenized.txt", type=str,
                         required=False, help="处理好的单轮分词验证评估用数据集路径")
-    parser.add_argument("--history_image_dir", default="data\\history\\transformer\\", type=str, required=False,
+    parser.add_argument("--history_image_dir", default="data/history/transformer/", type=str, required=False,
                         help="数据指标图表保存路径")
-    parser.add_argument("--encoder_save_path", default="models\\tensorflow\\transformer\\encoder", type=str,
+    parser.add_argument("--encoder_save_path", default="models/tensorflow/transformer/encoder", type=str,
                         required=False, help="Encoder的SaveModel格式保存路径")
-    parser.add_argument("--decoder_save_path", default="models\\tensorflow\\transformer\\decoder", type=str,
+    parser.add_argument("--decoder_save_path", default="models/tensorflow/transformer/decoder", type=str,
                         required=False, help="Decoder的SaveModel格式保存路径")

     options = parser.parse_args().__dict__

requirements.txt  (+53 -7)

@@ -1,9 +1,55 @@
-tensorflow==2.3.1
-pytorch==1.7.1
+absl-py==0.12.0
+astunparse==1.6.3
+cachetools==4.2.2
+certifi==2020.12.5
+chardet==4.0.0
+click==7.1.2
+cycler==0.10.0
+Flask==1.1.2
+flatbuffers==1.12
+gast==0.3.3
+google-auth==1.30.0
+google-auth-oauthlib==0.4.4
+google-pasta==0.2.0
+grpcio==1.32.0
+h5py==2.10.0
+idna==2.10
+importlib-metadata==4.0.1
+inflect==5.3.0
+itsdangerous==1.1.0
 jieba==0.42.1
-inflect==5.0.2
-numpy==1.19.0
+Jinja2==2.11.3
+joblib==1.0.1
+Keras-Preprocessing==1.1.2
+kiwisolver==1.3.1
+Markdown==3.3.4
+MarkupSafe==1.1.1
+matplotlib==3.4.2
+numpy==1.19.5
+oauthlib==3.1.0
+opt-einsum==3.3.0
+Pillow==8.2.0
+protobuf==3.16.0
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pyparsing==2.4.7
 pysolr==3.9.0
-flask==1.1.2
-scikit-learn==0.23.2
-flask-cors==3.0.9
+python-dateutil==2.8.1
+requests==2.25.1
+requests-oauthlib==1.3.0
+rsa==4.7.2
+scikit-learn==0.24.2
+scipy==1.6.3
+six==1.15.0
+tensorboard==2.5.0
+tensorboard-data-server==0.6.1
+tensorboard-plugin-wit==1.8.0
+tensorflow==2.4.1
+tensorflow-estimator==2.4.0
+termcolor==1.1.0
+threadpoolctl==2.1.0
+typing-extensions==3.7.4.3
+urllib3==1.26.4
+Werkzeug==1.0.1
+wrapt==1.12.1
+zipp==3.4.1
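The pinned list above reads like a pip freeze of the updated environment: TensorFlow moves to 2.4.1 and the PyTorch pin is dropped, matching the torch code path commented out in actuator.py. The set can be installed into a fresh virtual environment with pip install -r requirements.txt.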
