efficiency_test.py
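"""Efficiency test for REL.

Times mention detection (MD) and entity disambiguation (ED) on long
documents from the AIDA-B test set, either by querying a running REL
API server (server = True) or by running the pipeline locally
(server = False).
"""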
import numpy as np
import requests

from REL.training_datasets import TrainingEvaluationDatasets

np.random.seed(seed=42)

base_url = "C:/Users/mickv/desktop/data_back/"
wiki_version = "wiki_2019"

datasets = TrainingEvaluationDatasets(base_url, wiki_version).load()['aida_testB']
# random_docs = np.random.choice(list(datasets.keys()), 50)
server = False
docs = {}
# Collect up to 50 documents of more than 200 tokens from AIDA-B.
for i, doc in enumerate(datasets):
    sentences = []
    for x in datasets[doc]:
        if x['sentence'] not in sentences:
            sentences.append(x['sentence'])
    text = '. '.join([x for x in sentences])

    if len(docs) == 50:
        print('length docs is 50.')
        print('====================')
        break

    if len(text.split()) > 200:
        docs[doc] = [text, []]

        # Demo script that can be used to query the API.
        if server:
            myjson = {
                "text": text,
                "spans": [
                    # {"start": 41, "length": 16}
                ],
            }
            print('----------------------------')
            print(i, 'Input API:')
            print(myjson)
            print('Output API:')
            print(requests.post("http://localhost:5555", json=myjson).json())
            print('----------------------------')
# --------------------- Now total --------------------------------
# ------------- RUN SEPARATELY TO BALANCE LOAD--------------------
if not server:
    from time import time

    import flair
    import torch
    from flair.models import SequenceTagger

    from REL.mention_detection import MentionDetection
    from REL.entity_disambiguation import EntityDisambiguation

    base_url = "C:/Users/mickv/desktop/data_back/"

    flair.device = torch.device('cuda:0')

    # 2. Mention detection.
    mention_detection = MentionDetection(base_url, wiki_version)

    # Alternatively use Flair NER tagger.
    tagger_ner = SequenceTagger.load("ner-fast")

    start = time()
    mentions_dataset, n_mentions = mention_detection.find_mentions(docs, tagger_ner)
    print('MD took: {}'.format(time() - start))

    # 3. Load model.
    config = {
        "mode": "eval",
        "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
    }
    model = EntityDisambiguation(base_url, wiki_version, config)

    # 4. Entity disambiguation.
    start = time()
    predictions, timing = model.predict(mentions_dataset)
    print('ED took: {}'.format(time() - start))
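
    # Optionally map the raw ED predictions back to per-document
    # (start, length, mention, entity, ...) tuples, as REL's example
    # scripts do. The process_results call below is assumed from those
    # examples and is not part of the original timing test.
    from REL.utils import process_results

    results = process_results(mentions_dataset, predictions, docs)
    print('Number of linked documents: {}'.format(len(results)))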