# bot.py
import re
import logging

from slackeventsapi import SlackEventAdapter
import slack

from subword_nmt import apply_bpe
from sacremoses import MosesTokenizer, MosesDetokenizer

from utils import load_line_as_data
from joeynmt.helpers import load_config, get_latest_checkpoint, \
    load_checkpoint
from joeynmt.vocabulary import build_vocab
from joeynmt.model import build_model
from joeynmt.prediction import validate_on_data


def read_first_line(path):
    """Return the first line of a small credential file, stripped."""
    with open(path, "r") as f:
        return f.readline().strip()


TOKEN = read_first_line("bot.token")
BOT_CHANNEL = read_first_line("bot.channel")
BOT_NAME = read_first_line("bot.name")
BOT_SIGNIN = read_first_line("bot.signin")
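# Note: TOKEN is the OAuth token used by slack.WebClient below, and
# BOT_SIGNIN is passed to SlackEventAdapter, so it should contain the
# app's Slack signing secret.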


def translate(message_text, model, src_vocab, trg_vocab, preprocess,
              postprocess, logger, beam_size, beam_alpha, level, lowercase,
              max_output_length, use_cuda):
    """
    Translate a text message with the loaded Joey NMT model.

    :param message_text: Slack message text to translate.
    :param model: The Joey NMT model.
    :param src_vocab: Source vocabulary.
    :param trg_vocab: Target vocabulary.
    :param preprocess: List of preprocessing functions applied to the input.
    :param postprocess: List of postprocessing functions applied to the output.
    :param logger: Logger passed on to validate_on_data.
    :param beam_size: Beam size for decoding.
    :param beam_alpha: Length penalty alpha for beam search.
    :param level: Segmentation level ("word", "bpe" or "char").
    :param lowercase: Whether to lowercase the input.
    :param max_output_length: Maximum output length.
    :param use_cuda: Whether to run the model on GPU.
    :return: The translated, post-processed string.
    """
    sentence = message_text.strip()
    # remove Slack emoji codes such as ":smile:" or ":+1:"
    emoji_pattern = re.compile(r":[a-zA-Z0-9_+-]+:")
    sentence = emoji_pattern.sub("", sentence).strip()
    if lowercase:
        sentence = sentence.lower()
    for p in preprocess:
        sentence = p(sentence)
    # wrap the single sentence in a dataset
    test_data, src_vocab, trg_vocab = load_line_as_data(
        lowercase=lowercase, line=sentence, src_vocab=src_vocab,
        trg_vocab=trg_vocab, level=level)
    # generate the translation
    score, loss, ppl, sources, sources_raw, references, hypotheses, \
        hypotheses_raw, attention_scores = validate_on_data(
            model, data=test_data, batch_size=1, level=level,
            max_output_length=max_output_length, eval_metric=None,
            use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
            beam_alpha=beam_alpha, logger=logger)
    # join the hypothesis tokens according to the segmentation level
    if level == "char":
        response = "".join(hypotheses)
    else:
        response = " ".join(hypotheses)
    # post-process (e.g. detokenize)
    for p in postprocess:
        response = p(response)
    return response
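
# Example (hypothetical values): with a loaded model and vocabularies,
#   translate("Hello team :smile:", model, src_vocab, trg_vocab,
#             preprocess=[tokenizer, segmenter], postprocess=[detokenizer],
#             logger=logger, beam_size=5, beam_alpha=1.0, level="bpe",
#             lowercase=True, max_output_length=100, use_cuda=False)
# returns the detokenized translation as a plain string.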


def run_bot(model_dir, bpe_src_code=None, tokenize=None):
    """
    Start the bot: load the Joey NMT model according to its config file,
    then listen for Slack events and answer with translations.

    :param model_dir: Model directory of the trained Joey NMT model.
    :param bpe_src_code: BPE merge file for source-side segmentation
        (optional, only used when the segmentation level is "bpe").
    :param tokenize: If set, tokenize inputs with the Moses tokenizer
        and detokenize outputs accordingly.
    :return:
    """
    cfg_file = model_dir + "/config.yaml"
    logger = logging.getLogger(__name__)

    # load the Joey NMT configuration
    cfg = load_config(cfg_file)

    # load the checkpoint
    if "load_model" in cfg['training'].keys():
        ckpt = cfg['training']["load_model"]
    else:
        ckpt = get_latest_checkpoint(model_dir)
        if ckpt is None:
            raise FileNotFoundError("No checkpoint found in directory {}."
                                    .format(model_dir))

    # prediction parameters from config
    use_cuda = cfg["training"].get("use_cuda", False)
    level = cfg["data"]["level"]
    max_output_length = cfg["training"].get("max_output_length", None)
    lowercase = cfg["data"].get("lowercase", False)

    # load the vocabularies
    src_vocab_file = cfg["training"]["model_dir"] + "/src_vocab.txt"
    trg_vocab_file = cfg["training"]["model_dir"] + "/trg_vocab.txt"
    src_vocab = build_vocab(field="src", vocab_file=src_vocab_file,
                            dataset=None, max_size=-1, min_freq=0)
    trg_vocab = build_vocab(field="trg", vocab_file=trg_vocab_file,
                            dataset=None, max_size=-1, min_freq=0)

    # decoding settings: beam_size 0 means greedy decoding
    if "testing" in cfg.keys():
        beam_size = cfg["testing"].get("beam_size", 0)
        beam_alpha = cfg["testing"].get("alpha", -1)
    else:
        beam_size = 1
        beam_alpha = -1
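    # For reference, a "testing" section in config.yaml might look like
    # (illustrative values, not from this repository):
    #   testing:
    #       beam_size: 5
    #       alpha: 1.0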
    # pre-processing: Moses tokenization (optional)
    if tokenize is not None:
        src_tokenizer = MosesTokenizer(lang=cfg["data"]["src"])
        trg_tokenizer = MosesDetokenizer(lang=cfg["data"]["trg"])
        # tokenize input
        tokenizer = lambda x: src_tokenizer.tokenize(x, return_str=True)
        detokenizer = lambda x: trg_tokenizer.detokenize(
            x.split(), return_str=True)
    else:
        tokenizer = lambda x: x
        detokenizer = lambda x: x
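    # Example of what Moses tokenization does (assuming an English source):
    #   MosesTokenizer(lang="en").tokenize("Hello, world!", return_str=True)
    #   -> "Hello , world !"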
    # pre-processing: segmentation according to the model's level
    if bpe_src_code is not None and level == "bpe":
        # load BPE merge file
        merge_file = open(bpe_src_code, "r")
        bpe = apply_bpe.BPE(codes=merge_file)
        segmenter = lambda x: bpe.process_line(x.strip())
    elif level == "char":
        # split into characters
        segmenter = lambda x: list(x.strip())
    else:
        segmenter = lambda x: x.strip()
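    # Example of what BPE segmentation produces (the actual merges depend
    # on the learned codes, so this is only illustrative):
    #   "unbelievable" -> "un@@ believ@@ able"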
    # build model and load parameters into it
    model_checkpoint = load_checkpoint(ckpt, use_cuda)
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
    model.load_state_dict(model_checkpoint["model_state"])
    if use_cuda:
        model.cuda()
    print("Joey NMT model loaded successfully.")

    web_client = slack.WebClient(TOKEN, timeout=30)

    # get bot id
    bot_id = web_client.api_call("auth.test")["user_id"].upper()

    # find bot channel id
    bot_channel_id = None
    all_channels = web_client.api_call("conversations.list")["channels"]
    for c in all_channels:
        if c["name"] == BOT_CHANNEL:
            bot_channel_id = c["id"]
            break
    if bot_channel_id is None:
        raise ValueError("Channel '{}' not found.".format(BOT_CHANNEL))
    slack_events_adapter = SlackEventAdapter(BOT_SIGNIN,
                                             endpoint="/slack/events")

    @slack_events_adapter.on("message")
    def handle_message(event_data):
        message = event_data["event"]
        # ignore message subtypes (edits, joins, bot messages, ...)
        if message.get("subtype") is not None:
            return
        channel = message["channel"]
        user = message["user"]
        text = message["text"].strip()
        # don't answer the bot's own messages
        if user == bot_id:
            return
        # translate all messages in the bot's channel and all mentions
        if channel == bot_channel_id or bot_id in text:
            mention = "<@{}>".format(bot_id)
            # remove every mention of the bot from the text
            if mention in text:
                text = " ".join(text.replace(mention, " ").split())
            translation = translate(text,
                                    beam_size=beam_size,
                                    beam_alpha=beam_alpha,
                                    level=level,
                                    lowercase=lowercase,
                                    max_output_length=max_output_length,
                                    model=model,
                                    postprocess=[detokenizer],
                                    preprocess=[tokenizer, segmenter],
                                    src_vocab=src_vocab,
                                    trg_vocab=trg_vocab,
                                    use_cuda=use_cuda,
                                    logger=logger)
            web_client.chat_postMessage(text=translation,
                                        token=TOKEN, channel=channel)

    # Error events
    @slack_events_adapter.on("error")
    def error_handler(err):
        print("ERROR: " + str(err))

    slack_events_adapter.start(port=3000)
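

# Hypothetical entry point: the file as shown never calls run_bot itself,
# so this is only a sketch of how it could be launched; the CLI flags below
# are assumptions, not part of the original code.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Run the Joey NMT Slack bot.")
    parser.add_argument("model_dir", help="Directory of the trained model.")
    parser.add_argument("--bpe_src_code", default=None,
                        help="BPE merge file for source-side segmentation.")
    parser.add_argument("--tokenize", action="store_true", default=None,
                        help="Tokenize inputs with the Moses tokenizer.")
    args = parser.parse_args()
    run_bot(model_dir=args.model_dir, bpe_src_code=args.bpe_src_code,
            tokenize=args.tokenize)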