
Commit 8ca4413

Author: yanwii
Message: update
1 parent 1c29713 · commit 8ca4413

18 files changed: +588 −2 lines

.vscode/settings.json (+3)

@@ -0,0 +1,3 @@
{
    "python.linting.pylintEnabled": false
}

LICENSE

File mode changed: 100644 → 100755

README.md (+32 −2)

File mode changed: 100644 → 100755

@@ -4,7 +4,37 @@
 So I switched to Pytorch, and it opened the door to a new world.

 ---
-Requires
-**Python3** **Pytorch** **Jieba**
+Requirements:
+[**Python3**](https://www.python.org/)
+[**Pytorch**](https://github.com/pytorch/pytorch)
+[**Jieba**](https://github.com/fxsjy/jieba) (Chinese word segmentation)

 ---
+
+### About the BeamSearch algorithm
+A classic greedy algorithm, applied in many fields.
+
+![](./img/beamsearch.png)
+
+In this implementation we introduce a penalty factor:
+![](./img/beamsearch2.jpeg)
+
+![](./img/1.png)
+
+---
+
+### Usage
+
+    # prepare the data
+    python3 preprocessing.py
+    # train
+    python3 seq2seq.py train
+    # predict
+    python3 seq2seq.py predict
+    # retrain
+    python3 seq2seq.py retrain
+

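For reference, BeamSearch keeps the `beam_width` highest-scoring partial sequences at each decoding step and rescales every hypothesis by a length penalty, so short replies are not unfairly favored. Below is a minimal, self-contained sketch of this idea in Python; `step_fn`, `alpha`, and the GNMT-style penalty formula are illustrative assumptions, not code from this commit:

```python
import heapq

def length_penalty(length, alpha):
    # GNMT-style penalty: ((5 + length) / 6) ** alpha (one common choice)
    return ((5.0 + length) / 6.0) ** alpha

def beam_search(step_fn, start_token, eos_token, beam_width=3, max_len=20, alpha=0.7):
    """step_fn(prefix) -> list of (next_token, log_prob) candidates."""
    # each hypothesis: (penalized score, raw log-prob, token sequence)
    beams = [(0.0, 0.0, [start_token])]
    finished = []
    for _ in range(max_len):
        candidates = []
        for score, logp, seq in beams:
            if seq[-1] == eos_token:          # completed hypothesis: set it aside
                finished.append((score, logp, seq))
                continue
            for tok, tok_logp in step_fn(seq):
                new_logp = logp + tok_logp
                new_seq = seq + [tok]
                candidates.append(
                    (new_logp / length_penalty(len(new_seq), alpha), new_logp, new_seq))
        if not candidates:                    # every beam has finished
            break
        # keep only the beam_width best partial hypotheses (the "greedy" step)
        beams = heapq.nlargest(beam_width, candidates, key=lambda c: c[0])
    finished.extend(beams)
    return max(finished, key=lambda c: c[0])[2]
```

With `beam_width=1` this reduces to plain greedy decoding; larger widths trade compute for a broader search.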
data/answer.txt (+1)

@@ -0,0 +1 @@
我是你

data/dec.segement (+1)

@@ -0,0 +1 @@
我 是 你

data/dec.vec (+3)

@@ -0,0 +1,3 @@
3
3
4 5 6

data/dec.vocab (+7)

@@ -0,0 +1,7 @@
__PAD__
__GO__
__EOS__
__UNK__
我
是
你

data/enc.segement (+1)

@@ -0,0 +1 @@
你 是 谁

data/enc.vec (+3)

@@ -0,0 +1,3 @@
3 3 3 3
3
4 5 6

data/enc.vocab (+7)

@@ -0,0 +1,7 @@
__PAD__
__GO__
__EOS__
__UNK__
你
是
谁

data/question.txt (+1)

@@ -0,0 +1 @@
你是谁

data/supplementvocab.txt

Whitespace-only changes.
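Taken together, these data files show the preprocessing contract: each `*.segement` file holds one space-separated sentence per line, and each `*.vec` file holds the same sentences as indices into the corresponding `*.vocab` file (indices 0 through 3 are reserved for the special tokens, so 你/是/谁 become 4 5 6). A small sketch of the reverse mapping, assuming only the file formats shown above:

```python
def load_vocab(path):
    # one token per line; the line number is the token's index
    with open(path, encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]

def decode(vec_line, vocab):
    return [vocab[int(i)] for i in vec_line.split()]

enc_vocab = load_vocab("./data/enc.vocab")
with open("./data/enc.vec", encoding="utf-8") as f:
    for line in f:
        print(decode(line, enc_vocab))   # the last line prints ['你', '是', '谁']
```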

img/1.png (17.7 KB)

img/beamsearch.png (91.7 KB)

img/beamsearch2.jpeg (74.2 KB)

model/params.pkl (2.68 MB)

Binary file not shown.

preprocessing.py (+87)

@@ -0,0 +1,87 @@
import jieba
import re


class preprocessing():
    __PAD__ = 0
    __GO__ = 1
    __EOS__ = 2
    __UNK__ = 3
    vocab = ['__PAD__', '__GO__', '__EOS__', '__UNK__']

    def __init__(self):
        self.encoderFile = "./data/question.txt"
        self.decoderFile = "./data/answer.txt"
        self.savePath = './data/'
        # user dictionary with extra words for the segmenter
        jieba.load_userdict("./data/supplementvocab.txt")

    def wordToVocabulary(self, originFile, vocabFile, segementFile):
        vocabulary = []
        sege = open(segementFile, "w")
        with open(originFile, 'r') as en:
            for sent in en.readlines():
                # strip punctuation (currently disabled)
                if "enc" in segementFile:
                    #sentence = re.sub("[\s+\.\!\/_,$%^*(+\"\']+|[+——!,。“”’‘??、~@#¥%……&*()]+", "", sent.strip())
                    sentence = sent.strip()
                    words = jieba.lcut(sentence)
                    print(words)
                else:
                    words = jieba.lcut(sent.strip())
                vocabulary.extend(words)
                for word in words:
                    sege.write(word + " ")
                sege.write("\n")
        sege.close()

        # de-duplicate, preserving first-seen order, and prepend the special tokens
        vocab_file = open(vocabFile, "w")
        _vocabulary = list(set(vocabulary))
        _vocabulary.sort(key=vocabulary.index)
        _vocabulary = self.vocab + _vocabulary
        for index, word in enumerate(_vocabulary):
            vocab_file.write(word + "\n")
        vocab_file.close()

    def toVec(self, segementFile, vocabFile, doneFile):
        word_dicts = {}
        vec = []
        with open(vocabFile, "r") as dict_f:
            for index, word in enumerate(dict_f.readlines()):
                word_dicts[word.strip()] = index

        f = open(doneFile, "w")
        if "enc.vec" in doneFile:
            f.write("3 3 3 3\n")
            f.write("3\n")
        elif "dec.vec" in doneFile:
            f.write(str(word_dicts.get("other", 3)) + "\n")
            f.write(str(word_dicts.get("other", 3)) + "\n")
        with open(segementFile, "r") as sege_f:
            for sent in sege_f.readlines():
                sents = [i.strip() for i in sent.split(" ")[:-1]]
                vec.extend(sents)
                for word in sents:
                    # fall back to __UNK__ (index 3) for out-of-vocabulary words
                    f.write(str(word_dicts.get(word, self.__UNK__)) + " ")
                f.write("\n")
        f.close()

    def main(self):
        # build the vocabularies and the segmented files
        self.wordToVocabulary(self.encoderFile, self.savePath + 'enc.vocab', self.savePath + 'enc.segement')
        self.wordToVocabulary(self.decoderFile, self.savePath + 'dec.vocab', self.savePath + 'dec.segement')
        # convert the segmented files to index vectors
        self.toVec(self.savePath + "enc.segement",
                   self.savePath + "enc.vocab",
                   self.savePath + "enc.vec")
        self.toVec(self.savePath + "dec.segement",
                   self.savePath + "dec.vocab",
                   self.savePath + "dec.vec")


if __name__ == '__main__':
    pre = preprocessing()
    pre.main()

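A quick way to exercise the script end to end, assuming the `./data` files from this commit are already in place (the import name follows the file name; nothing outside this commit is assumed):

```python
# Hypothetical smoke test; not part of the commit.
from preprocessing import preprocessing

pre = preprocessing()
pre.main()

# With question.txt containing "你是谁", the last line of enc.vec
# should be the index sequence "4 5 6".
with open("./data/enc.vec", encoding="utf-8") as f:
    print(f.read())
```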