"""
Modified version of
https://github.com/Tixierae/OrangeSum/blob/main/compute_overlap.py
# Original copyright is appended below.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
import argparse
import string

import numpy as np
from nltk import ngrams
def pct_novel_ngrams_in_y(x, y, nmax):
    """Return, for each order n in 1..nmax, the percentage of n-grams of y
    that do not appear in x ('NA' when y is too short to contain any n-gram)."""
    # remove punctuation and lowercase
    x = x.translate(str.maketrans('', '', string.punctuation)).lower()
    y = y.translate(str.maketrans('', '', string.punctuation)).lower()
    percs = dict()
    for n in range(1, nmax + 1):
        ngrams_x = set(ngrams(x.split(), n))
        ngrams_y = set(ngrams(y.split(), n))
        if len(ngrams_y) == 0:
            # y is too short to contain a single n-gram of this order
            percs[n] = 'NA'
        else:
            percs[n] = round(100 * len(ngrams_y.difference(ngrams_x)) / len(ngrams_y), 1)
    return percs
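# Illustrative example (hypothetical inputs, not part of the pipeline):
#   pct_novel_ngrams_in_y("the cat sat on the mat", "a cat sat quietly", 2)
#   -> {1: 50.0, 2: 66.7}, since "a" and "quietly" are novel unigrams and
#      ("a", "cat"), ("sat", "quietly") are novel bigrams.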
# = = = = =
parser = argparse.ArgumentParser()
parser.add_argument('--path_predictions', '-pred', required=True, type=str, help='Path to predictions file')
parser.add_argument('--path_article', '-art', required=True, type=str, help='Path to articles file')
# parser.add_argument('--dataset_type', '-dt', default=None, type=str, help='The type of GreekSum dataset. Accepted values are: Abstract or Title')

nmax = 4       # greatest n-gram order to consider
min_size = 20  # minimum article length (in words) to be included in the statistics

args = parser.parse_args()
# path_pred = './title/generated_output.txt'
# path_ref = './title/summarization_data_title/test-article.txt'
path_pred = args.path_predictions
path_art = args.path_article

lens = []
results = dict()
counter = 0
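# The two input files are assumed to be line-aligned: line i of the predictions
# file is the generated summary of line i of the articles file.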
with open(path_pred, 'r') as fr1, open(path_art, 'r') as fr2:
    for line1, line2 in zip(fr1, fr2):
        article = line2.strip()
        head = line1.strip()
        lens.append(len(head.split()))
        # only keep articles long enough to yield meaningful statistics
        if len(article.split()) > min_size:
            to_save = dict()
            # whenever the field is too short to contain at least one n-gram, 'NA' is returned
            to_save['pred'] = pct_novel_ngrams_in_y(article, head, nmax)
            results[counter] = to_save
            counter += 1
print('= = = size (in nb of words) of predictions = = =')
print('min: %s, max: %s, average: %s, median: %s' % (min(lens), max(lens), round(np.mean(lens), 2), np.median(lens)))

print('= = = = percentage of novel ngrams in: predictions = = = =')
for n in range(1, nmax + 1):
    print('* * * * order:', n, '* * * *')
    scores = [v['pred'][n] for v in results.values() if v['pred'][n] != 'NA']
    print(round(np.mean(scores), 1))
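# Example invocation (file names are placeholders; any pair of line-aligned
# prediction/article text files will do):
#   python novel_ngrams_predictions.py -pred generated_output.txt -art test-article.txt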