This repository was archived by the owner on Jul 6, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathtest.py
executable file
·96 lines (77 loc) · 3.05 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
from numpy import argmax
import tensorflow as tf
import CONFIG
from utils import BatchGenerator, load_data, load_dictionary
# Vocabulary lookup tables and the trained Keras model. All paths are resolved
# relative to the current working directory (os.getcwd()), not relative to
# this script's own location.
_DATA_DIR = os.path.join(os.getcwd(), 'data')
# Maps index -> word; keys are looked up as strings by indexes_to_string.
indexToString = load_dictionary(os.path.join(_DATA_DIR, 'indexToString.json'))
# Maps word -> index; used by string_to_indexes.
stringToIndex = load_dictionary(os.path.join(_DATA_DIR, 'stringToIndex.json'))
model = tf.keras.models.load_model(
    os.path.join(os.getcwd(), 'model', 'model.h5'))
def predict_next_word(string, verbose=True, NUMBER_OF_PREDICTIONS=1):
    """Predict the most likely next word(s) for a whitespace-separated phrase.

    Only the last ``CONFIG.number_of_words`` tokens are fed to the model.

    Args:
        string: Input phrase; it is split on whitespace into tokens.
        verbose: When true, print the context indexes, each chosen
            candidate's score, and the chosen prediction indexes.
        NUMBER_OF_PREDICTIONS: How many distinct candidate words to return,
            best first. Clamped to the size of the model's output vocabulary.

    Returns:
        A list of strings, each ``string + ' ' + candidate_word``, or ``None``
        (after printing a hint) when the phrase has fewer than
        ``CONFIG.number_of_words`` tokens.
    """
    ques_bool = False
    idx, ques_bool = string_to_indexes(string.split(), ques_bool)
    window = CONFIG.number_of_words
    if len(idx) < window:
        print('\n\nPlease enter atleast', window, ' words.\n')
        return None

    context = idx[-window:]
    if verbose:
        print('\nindexes of last ', window, 'words\t:', context)
    prediction = model.predict([[context]])
    # The output is indexed as rank 3; presumably (batch, time, vocab), so
    # this selects the scores for the word that follows the last context
    # position. This is a view, so masking below writes into `prediction`.
    scores = prediction[:, window - 1, :]

    best_predictions = []
    for _ in range(min(NUMBER_OF_PREDICTIONS, scores.shape[-1])):
        argmax_idx = argmax(scores)
        if verbose:
            print(scores[:, argmax_idx])
        best_predictions.append(argmax_idx)
        # Mask with -inf (not 0.0) so the same word cannot be re-selected
        # even if the model's scores can be zero or negative (e.g. logits).
        scores[:, argmax_idx] = float('-inf')
    if verbose:
        print('\nprediction indexes\t:', best_predictions)

    converted_string = indexes_to_string(best_predictions, ques_bool)
    return [string + ' ' + word for word in converted_string]
# Interrogative words whose presence makes the sentence a question, so the
# end-of-sentence token is rendered as '?' instead of '.'.
_QUESTION_WORDS = frozenset(
    ('what', 'why', 'who', 'how', 'whose', 'when', 'which', 'where'))


def string_to_indexes(array_of_string, ques_bool, word_to_index=None):
    """Convert a sequence of words into vocabulary indexes.

    Normalisation applied to each word before lookup:
      * ``'<rare word>'`` becomes ``'<unk>'``. This can only occur if the
        caller passes it as one element; a whitespace split never yields it.
      * ``'.'`` and ``'?'`` become ``'<eos>'``.
      * Words missing from the vocabulary are reported and replaced by the
        index of ``'<unk>'``.

    Args:
        array_of_string: Iterable of word strings.
        ques_bool: Initial "is a question" flag; it is only ever set to True.
        word_to_index: Optional word -> index mapping. Defaults to the
            module-level ``stringToIndex``.

    Returns:
        A tuple ``(indexes, ques_bool)``. The flag is True if the initial
        flag was True or any word is in ``_QUESTION_WORDS``.

    Raises:
        KeyError: If ``'<unk>'`` itself is missing from the vocabulary.
    """
    if word_to_index is None:
        word_to_index = stringToIndex
    array_of_indexes = []
    for word in array_of_string:
        if word == '<rare word>':
            word = '<unk>'
        if word in ('.', '?'):
            word = '<eos>'
        if word in _QUESTION_WORDS:
            ques_bool = True
        try:
            array_of_indexes.append(word_to_index[word])
        except KeyError:  # assumes a dict-like vocabulary
            print("Word ", word,
                  " does not exist in the vocabulary!\nReplacing it with '<unk>'")
            array_of_indexes.append(word_to_index['<unk>'])
    return array_of_indexes, ques_bool
def indexes_to_string(array_of_indexes, ques_bool, index_to_word=None):
    """Convert vocabulary indexes back into words.

    The ``'<eos>'`` token is rendered as ``'?'`` when ``ques_bool`` is truthy
    and as ``'.'`` otherwise. Every other word is returned unchanged.

    Args:
        array_of_indexes: Iterable of integer-like indexes.
        ques_bool: Whether the sentence is a question.
        index_to_word: Optional mapping keyed by ``str(index)``. Defaults to
            the module-level ``indexToString``.

    Returns:
        A list of words, in the same order as ``array_of_indexes``.

    Raises:
        KeyError: If an index is not in the mapping.
    """
    if index_to_word is None:
        index_to_word = indexToString
    end_token = '?' if ques_bool else '.'
    array_of_strings = []
    for index in array_of_indexes:
        # The mapping is keyed by strings (it is loaded from a .json file),
        # so integer / numpy indexes are converted with str().
        word = index_to_word[str(index)]
        if word == '<eos>':
            word = end_token
        # TODO: 'N' looks like a number placeholder; it is passed through
        # unchanged for now.
        array_of_strings.append(word)
    return array_of_strings
if __name__ == '__main__':
    # Interactive loop: read a phrase, print the predicted continuation(s).
    # Guarded so that importing this module does not block on input().
    # Ctrl-D (EOF) or Ctrl-C ends the session cleanly instead of raising.
    while True:
        try:
            user_input = input('\n\nEnter atleast ' +
                               str(CONFIG.number_of_words) + ' words: \n')
        except (EOFError, KeyboardInterrupt):
            print()
            break
        sentences = predict_next_word(string=user_input,
                                      NUMBER_OF_PREDICTIONS=1)
        print('\n')
        # predict_next_word returns None when the input is too short.
        if sentences is not None:
            for count, sentence in enumerate(sentences, start=1):
                print(count, '\t-', sentence)