-
Notifications
You must be signed in to change notification settings - Fork 1
/
apply_body_subject_classification.py
127 lines (108 loc) · 4.6 KB
/
apply_body_subject_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Tracer code to figure out how to parse out DocX files."""
import os
from transformers import pipeline
import re
from database_model_definitions import Article, AIResultBody, SEAWEED_LABEL, OTHER_AQUACULTURE_LABEL
from database_model_definitions import RELEVANT_LABEL, IRRELEVANT_LABEL
from database_model_definitions import RELEVANT_TAG, IRRELEVANT_TAG, SEAWEED_TAG, OTHER_AQUACULTURE_TAG
from database import SessionLocal, init_db
# Map raw HuggingFace classifier outputs ('LABEL_<n>') to the project's
# human-readable subject tags for the relevant/irrelevant model.
RELEVANT_LABEL_TO_SUBJECT = {
    f'LABEL_{RELEVANT_LABEL}': RELEVANT_TAG,
    f'LABEL_{IRRELEVANT_LABEL}': IRRELEVANT_TAG
}
# Same mapping for the aquaculture-type model (seaweed vs other aquaculture).
AQUACULTURE_LABEL_TO_SUBJECT = {
    f'LABEL_{SEAWEED_LABEL}': SEAWEED_TAG,
    f'LABEL_{OTHER_AQUACULTURE_LABEL}': OTHER_AQUACULTURE_TAG,
}
# Local paths to the fine-tuned longformer checkpoints; main() verifies they
# exist before loading (they must be downloaded separately).
RELEVANT_SUBJECT_MODEL_PATH = "wwf-seaweed-body-subject-relevant-irrelevant/allenai-longformer-base-4096_19"
AQUACULTURE_SUBJECT_MODEL_PATH = "wwf-seaweed-body-subject-aquaculture-type/allenai-longformer-base-4096_36"
def standardize_text(text):
    """Normalize raw article text before classification.

    Converts '!' and '?' to '.', strips quote/punctuation characters,
    flattens newlines, lowercases, and collapses runs of whitespace.

    Args:
        text (str): raw article body text.

    Returns:
        str: normalized text suitable for model input / deduplication.
    """
    # Treat exclamations and questions as plain sentence terminators.
    cleaned = re.sub(r'[!?]', '.', text)
    # Drop punctuation that carries no signal for classification.
    cleaned = re.sub(r'[\'",;:-]', '', cleaned)
    cleaned = cleaned.replace('\n', ' ').replace('\r', '')
    # Lowercase, then collapse all whitespace runs to single spaces.
    return re.sub(r'\s+', ' ', cleaned.lower()).strip()
def make_pipeline(model_type, model_path):
    """Create a classification callable that dedupes its inputs.

    Loads a HuggingFace pipeline for ``model_path`` on CUDA, and returns a
    function that standardizes a list of raw texts, runs the model once per
    unique standardized string, and fans the results back out so the output
    list lines up with the input list.

    Args:
        model_type (str): pipeline task name, e.g. 'text-classification'.
        model_path (str): local path to the model checkpoint.

    Returns:
        callable: list[str] -> list of pipeline result dicts, one per input.
    """
    classifier = pipeline(
        model_type, model=model_path,
        device='cuda', truncation=True)

    def _classify(raw_text_list):
        standardized = [standardize_text(raw) for raw in raw_text_list]
        # Classify each distinct string only once -- many articles share
        # identical bodies after normalization.
        deduped = list(set(standardized))
        index_of = {
            text_val: idx for idx, text_val in enumerate(deduped)}
        deduped_results = classifier(deduped)
        # Re-expand to one result per original input, in input order.
        return [
            deduped_results[index_of[text_val]]
            for text_val in standardized]

    return _classify
def main():
    """Classify un-labeled article bodies and store results in the database.

    Pipeline:
      1. Verify both model checkpoints exist locally (abort with a hint if not).
      2. Query articles that have a non-empty body but no AIResultBody yet.
      3. Label each body relevant/irrelevant; irrelevant articles are tagged
         and done.
      4. Label the relevant bodies by aquaculture type; a regex match on
         seaweed terms overrides the model's 'other aquaculture' label.
      5. Commit all AIResultBody rows.
    """
    missing_model = False
    for model_path in [RELEVANT_SUBJECT_MODEL_PATH, AQUACULTURE_SUBJECT_MODEL_PATH]:
        if not os.path.exists(model_path):
            missing_model = True
            print(f'{model_path} not found, you need to download it from wherever Sam uploaded it to, ask her!')
    if missing_model:
        return
    seaweed_re = re.compile('(seaweed|kelp|sea moss)', re.IGNORECASE)
    relevant_subject_model = make_pipeline(
        'text-classification', RELEVANT_SUBJECT_MODEL_PATH)
    aquaculture_subject_model = make_pipeline(
        'text-classification', AQUACULTURE_SUBJECT_MODEL_PATH)
    print('loaded models...')
    init_db()
    session = SessionLocal()
    # Ensure the session is released even if classification or commit fails.
    try:
        # `== None` / `!= None` are intentional: SQLAlchemy overloads them
        # into SQL IS NULL / IS NOT NULL expressions.
        articles_without_ai = (
            session.query(Article).outerjoin(
                AIResultBody, Article.id_key == AIResultBody.article_id)
            .filter(
                AIResultBody.id_key == None,
                Article.body != None,
                Article.body != '')
            .all())
        bodies_without_ai = [article.body for article in articles_without_ai]
        if not bodies_without_ai:
            # Nothing to do; skip invoking the pipelines on empty input.
            print('no article bodies need classification')
            return
        print(f'doing sentiment-analysis on {len(bodies_without_ai)} article bodies')
        relevant_subject_result_list = [
            {
                'label': RELEVANT_LABEL_TO_SUBJECT[val['label']],
                'score': val['score']
            } for val in relevant_subject_model(bodies_without_ai)]
        relevant_articles = [
            article for article, classification in
            zip(articles_without_ai, relevant_subject_result_list)
            if classification['label'] == RELEVANT_TAG]
        relevant_bodies = [article.body for article in relevant_articles]
        aquaculture_type_result_list = [
            {
                'label': AQUACULTURE_LABEL_TO_SUBJECT[val['label']],
                'score': val['score']
            } for val in aquaculture_subject_model(relevant_bodies)]
        print('updating database')
        # Tag irrelevant articles; relevant ones get an aquaculture-type tag
        # in the next loop instead.
        for article, relevant_subject_result in zip(
                articles_without_ai, relevant_subject_result_list):
            if relevant_subject_result['label'] == RELEVANT_TAG:
                continue
            article.body_subject_ai = [
                AIResultBody(
                    value=relevant_subject_result['label'],
                    score=relevant_subject_result['score'])]
        for article, aquaculture_type_result in zip(
                relevant_articles, aquaculture_type_result_list):
            working_label = aquaculture_type_result['label']
            if working_label == OTHER_AQUACULTURE_TAG and seaweed_re.search(
                    article.body):
                # Override the model if seaweed terms literally appear.
                working_label = SEAWEED_TAG
            article.body_subject_ai = [
                AIResultBody(
                    value=working_label,
                    score=aquaculture_type_result['score'])]
        session.commit()
    finally:
        session.close()
# Script entry point: run classification when executed directly.
if __name__ == '__main__':
    main()