-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
134 lines (104 loc) · 3.95 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
Title: Utils
"""
import bz2
import gzip
from os import path
import tarfile
import io
from itertools import islice, chain
from six import string_types, text_type
import numpy as np
from gensim.models import Word2Vec
def any2utf8(text, errors='strict', encoding='utf8'):
    """
    Convert `text` to a UTF-8 encoded bytestring.

    Parameters
    ----------
    text: str, or bytes in `encoding`
    errors: error-handling scheme passed to the decode step
    encoding: encoding of `text` when it is already a bytestring

    Returns
    -------
    bytes: `text` encoded as UTF-8
    """
    # The file is Python 3-only (annotations, extended unpacking), so plain
    # str replaces the six.text_type shim.
    if isinstance(text, str):
        return text.encode('utf8')
    # bytes -> str -> utf8 round-trip ensures the result is valid UTF-8
    return str(text, encoding, errors=errors).encode('utf8')


to_utf8 = any2utf8
# Translation table used by standardize_string: maps every non-alphanumeric
# character in the 8-bit range to None (i.e. deletes it), keeping tab, space,
# '-' and '_' ('_' is how word2vec joins phrase tokens).
# Works just as well with unicode chars.
_keep = ('\t', ' ', '-', '_')
_delchars = ''.join(
    ch for ch in map(chr, range(256))
    if not ch.isalnum() and ch not in _keep
)
_delchars_table = {ord(ch): None for ch in _delchars}
def standardize_string(s, clean_words=True, lower=True, language="english"):
    """
    Ensure a common string convention across the code base: convert to
    unicode and optionally lowercase / strip non-alphanumeric characters.

    Parameters
    ----------
    s: str, or UTF-8 encoded bytes
    clean_words: if True remove non-alphanumeric characters (for instance
        '$', '#' or 'ł'); tab, space, '-' and '_' are kept ('_' is how
        word2vec joins phrase tokens)
    lower: if True lowercase the string
    language: only "english" is supported

    Returns
    -------
    string: processed unicode string

    Raises
    ------
    TypeError: if `s` is not str or bytes
    NotImplementedError: for languages other than "english"
    """
    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would let arbitrary objects through silently.
    if isinstance(s, bytes):
        s = s.decode("utf-8")
    elif not isinstance(s, str):
        raise TypeError("expected str or bytes, got %r" % type(s))
    if language != "english":
        raise NotImplementedError("Not implemented standardization for other languages")
    if lower:
        s = s.lower()
    if clean_words:
        s = s.translate(_delchars_table)
    return s
def batched(iterable, size):
    """Lazily yield iterators over consecutive chunks of at most `size` items.

    Each yielded chunk must be consumed before advancing to the next one,
    since all chunks share the same underlying iterator.
    """
    source = iter(iterable)
    while True:
        chunk = islice(source, size)
        try:
            head = next(chunk)
        except StopIteration:
            # Source exhausted (or size == 0): no more chunks.
            return
        yield chain([head], chunk)
def _open(file_, mode='r'):
"""Open file object given filenames, open files or even archives."""
if isinstance(file_, string_types):
_, ext = path.splitext(file_)
if ext in {'.gz'}:
if mode == "r" or mode == "rb":
# gzip is extremely slow
return io.BufferedReader(gzip.GzipFile(file_, mode=mode))
else:
return gzip.GzipFile(file_, mode=mode)
if ext in {'.bz2'}:
return bz2.BZ2File(file_, mode=mode)
else:
return io.open(file_, mode, **({"encoding": "utf-8"} if "b" not in mode else {}))
return file_
def embedding_info(embeddings: Word2Vec):
    """Print basic diagnostics about a trained gensim Word2Vec model."""
    rows = (
        ("Vector size:", embeddings.vector_size),
        ("Dictionary size", len(embeddings.wv.index_to_key)),
        ("Window size", embeddings.window),
        ("Total training time", embeddings.total_train_time),
    )
    for label, value in rows:
        print(label, value)
def load_word_vectors(file_path):
    """
    Load embeddings stored as whitespace-separated text, one entry per line:
    <word> <component_1> ... <component_d>

    Parameters
    ----------
    file_path: path to the embeddings file (read as UTF-8)

    Returns
    -------
    dict mapping each word to its vector as a 1-D float32 numpy array
    """
    word_vectors = {}
    # Explicit encoding: results must not depend on the platform locale
    # (matches the UTF-8 convention used by _open in this module).
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                # Skip blank/trailing lines instead of crashing on unpack.
                continue
            word, *vector = line.split()
            word_vectors[word] = np.array(vector, dtype=np.float32)
    return word_vectors
def most_similar_words(embeddings, word, k):
    """
    Return the k words most similar to `word` by cosine similarity.

    Parameters
    ----------
    embeddings: dict of the form {word1: vector1, word2: vector2, ...}
    word: query word; standardized before lookup. The query word itself is
        included in the candidates (similarity 1.0), as before.
    k: number of results to return

    Returns
    -------
    list of (word, similarity) pairs sorted by decreasing similarity

    Raises
    ------
    KeyError: if the standardized word is not in `embeddings`
    """
    word = standardize_string(word)
    query = embeddings[word]
    # Hoisted out of the loop: the query norm is invariant across entries.
    query_norm = np.linalg.norm(query)
    similarity = {}
    for other, vec in embeddings.items():
        denom = query_norm * np.linalg.norm(vec)
        if denom == 0:
            # Zero vectors have undefined cosine similarity; previously this
            # produced NaN (numpy divide warning), corrupting the sort order.
            continue
        try:
            similarity[other] = np.dot(query, vec) / denom
        except (TypeError, ValueError):
            # Best-effort: skip malformed entries (e.g. mismatched shapes)
            # instead of aborting, but no longer swallow every exception.
            continue
    return sorted(similarity.items(), key=lambda kv: kv[1], reverse=True)[:k]