-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathexample.py
55 lines (46 loc) · 2.09 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#! -*- coding: utf-8 -*-
# bert4keras加载CDial-GPT
# https://github.com/bojone/CDial-GPT-tf
import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import AutoRegressiveDecoder
from bert4keras.snippets import uniout
# Files of the pretrained GPT_LCCC-base model: config, checkpoint, vocabulary.
config_path = '/root/kg/bert/GPT_LCCC-base-tf/gpt_config.json'
checkpoint_path = '/root/kg/bert/GPT_LCCC-base-tf/gpt_model.ckpt'
dict_path = '/root/kg/bert/GPT_LCCC-base-tf/vocab.txt'

# Build the tokenizer.
tokenizer = Tokenizer(dict_path, do_lower_case=True)

# Token ids of the two speaker tags; index 0 is '[speaker1]', index 1 is
# '[speaker2]'.
speakers = list(map(tokenizer.token_to_id, ['[speaker1]', '[speaker2]']))

# Build the model and load the weights.
model = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    model='gpt',
)
class ChatBot(AutoRegressiveDecoder):
    """Dialogue bot that produces replies by random sampling.

    Uses the module-level ``model``, ``tokenizer`` and ``speakers``.
    """

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        """Run the model on context + generated tokens; keep the last step.

        Args:
            inputs: Pair ``(token_ids, segment_ids)`` of 2-D arrays holding
                the encoded dialogue context.
            output_ids: 2-D array of the tokens generated so far.
            states: Not used by this implementation.

        Returns:
            The model output at the final sequence position of every row
            (declared to the decoder as ``'probas'``).
        """
        context_tokens, context_segments = inputs
        # All generated tokens share one segment id: the last token of the
        # first context row (response() always ends the context with the
        # speaker tag of whoever talks next).
        reply_speaker = context_tokens[0, -1]
        reply_segments = np.zeros_like(output_ids) + reply_speaker
        full_tokens = np.concatenate([context_tokens, output_ids], axis=1)
        full_segments = np.concatenate(
            [context_segments, reply_segments], axis=1
        )
        predictions = model.predict([full_tokens, full_segments])
        return predictions[:, -1]

    def response(self, texts, topk=5):
        """Sample a reply to a dialogue history.

        Args:
            texts: Utterances in conversation order; even positions are
                attributed to ``speakers[0]``, odd ones to ``speakers[1]``.
            topk: Forwarded to ``random_sample`` as its third argument.

        Returns:
            The decoded text of the first sampled result.
        """
        start_id = tokenizer._token_start_id
        token_ids = [start_id, speakers[0]]
        segment_ids = [start_id, speakers[0]]
        for turn, utterance in enumerate(texts):
            current_speaker = speakers[turn % 2]
            next_speaker = speakers[(turn + 1) % 2]
            # Strip the first and last ids the tokenizer emits (presumably
            # its start/end markers -- confirm against the tokenizer).
            utterance_ids = tokenizer.encode(utterance)[0][1:-1]
            # Each turn is followed by the tag of the next speaker; that tag
            # carries the next speaker's id as its own segment id.
            token_ids += utterance_ids + [next_speaker]
            segment_ids += (
                [current_speaker] * len(utterance_ids) + [next_speaker]
            )
        samples = self.random_sample([token_ids, segment_ids], 1, topk)
        return tokenizer.decode(samples[0])
# Decoder instance: no start token, stop on the tokenizer's end token,
# maxlen set to 32.
chatbot = ChatBot(
    start_id=None,
    end_id=tokenizer._token_end_id,
    maxlen=32,
)

# Demo: a three-utterance dialogue history, oldest first.
dialogue = [u'别爱我没结果', u'你这样会失去我的', u'失去了又能怎样']
print(chatbot.response(dialogue))

# The string below is a no-op note: replies are random, and it lists a few
# example outputs.
"""
回复是随机的,例如:你还有我 | 那就不要爱我 | 你是不是傻 | 等等。
"""