This repository has been archived by the owner on Jul 26, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathquery.py
149 lines (120 loc) · 4.83 KB
/
query.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import argparse
import torch
import transformers
from typing import Dict, List, Literal, Tuple, cast
from datasets import load_dataset, DatasetDict
from dotenv import load_dotenv
from src.readers.base_reader import Reader
from src.readers.longformer_reader import LongformerReader
from src.readers.dpr_reader import DprReader
from src.retrievers.base_retriever import Retriever
from src.retrievers.es_retriever import ESRetriever
from src.retrievers.faiss_retriever import (
FaissRetriever,
FaissRetrieverOptions
)
from src.utils.preprocessing import context_to_reader_input
from src.utils.log import logger
# Setup environment
# Both calls below run as import-time side effects of this module.
# load_dotenv comes from the `dotenv` package; presumably it loads variables
# from a local .env file into the process environment -- TODO confirm which
# settings the retrievers/readers actually rely on.
load_dotenv()
# Restrict the transformers library's logging to error-level messages
# (presumably so model-loading chatter does not clutter the printed answers).
transformers.logging.set_verbosity_error()
def get_retriever(paragraphs: DatasetDict,
                  r: Literal["es", "faiss"],
                  lm: Literal["dpr", "longformer"]) -> Retriever:
    """Build the retriever selected by the ``r`` / ``lm`` combination.

    Args:
        paragraphs: Dataset handed unchanged to the retriever constructor.
        r: Retrieval method; ``"es"`` yields an ``ESRetriever`` regardless
            of ``lm``, ``"faiss"`` yields a ``FaissRetriever``.
        lm: Language model; only consulted for ``"faiss"``, where it picks
            the ``FaissRetrieverOptions`` preset and the ``.faiss`` path.

    Returns:
        The constructed retriever.

    Raises:
        ValueError: If the combination of ``r`` and ``lm`` is not supported.
    """
    # "es" ignores the language model entirely.
    if r == "es":
        return ESRetriever(paragraphs)

    if r == "faiss" and lm == "dpr":
        faiss_options = FaissRetrieverOptions.dpr("./src/models/dpr.faiss")
        return FaissRetriever(paragraphs, faiss_options)

    if r == "faiss" and lm == "longformer":
        faiss_options = FaissRetrieverOptions.longformer(
            "./src/models/longformer.faiss")
        return FaissRetriever(paragraphs, faiss_options)

    raise ValueError("Retriever options not recognized")
def get_reader(lm: Literal["dpr", "longformer"]) -> Reader:
    """Instantiate the reader that belongs to the given language model.

    Args:
        lm: ``"dpr"`` for a ``DprReader`` or ``"longformer"`` for a
            ``LongformerReader``.

    Returns:
        A freshly constructed reader instance.

    Raises:
        ValueError: If ``lm`` is neither of the supported names.
    """
    if lm == "dpr":
        return DprReader()
    if lm == "longformer":
        return LongformerReader()
    raise ValueError("Language model not recognized")
def print_name(contexts: dict, section: str, id: int):
    """Print one heading of a retrieved document, unless it is ``'nan'``.

    Args:
        contexts: Mapping from a heading level name to an indexable
            collection of values.
        section: Key to look up in ``contexts``; also used as the label.
        id: Index of the entry to print within ``contexts[section]``.
    """
    heading = contexts[section][id]
    # The literal string 'nan' is skipped (presumably a stringified missing
    # value coming from the dataset -- TODO confirm).
    if heading == 'nan':
        return
    print(f" {section}: {heading}")
def get_retrieval_span_scores(
        answers: List[tuple]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Turn raw reader scores into probabilities over the answer list.

    Each answer must expose ``relevance_score`` and ``span_score``
    attributes. Both score lists are normalised independently with a
    softmax across the answers.

    Args:
        answers: Predictions to score; iterated twice, so it must be a
            re-iterable collection rather than a one-shot iterator.

    Returns:
        ``(d_scores, s_scores)``: two 1-D float32 tensors with one entry per
        answer, holding the softmax of the relevance scores and of the span
        scores respectively.
    """
    def _softmax(raw_scores) -> torch.Tensor:
        # torch.tensor is the documented way to build a tensor from existing
        # data; float32 matches what the legacy torch.Tensor(...) produced.
        values = torch.tensor(raw_scores, dtype=torch.float32)
        return torch.softmax(values, dim=0)

    d_scores = _softmax([pred.relevance_score for pred in answers])
    s_scores = _softmax([pred.span_score for pred in answers])
    return d_scores, s_scores
def print_answers(answers: List[tuple], scores: List[float], contexts: dict):
    """Print a ranked, annotated listing of the given answers.

    For every answer this prints its 1-based position and text, an underline
    of matching length, the chapter/section/subsection headings of its
    document (via ``print_name``), and three scores.

    Args:
        answers: Predictions exposing ``text``, ``doc_id``,
            ``relevance_score`` and ``span_score``.
        scores: Retrieval scores, indexed by an answer's ``doc_id``.
        contexts: Retrieved documents' metadata, keyed by heading level.
    """
    d_scores, s_scores = get_retrieval_span_scores(answers)
    ranked = enumerate(zip(answers, d_scores, s_scores), start=1)
    for rank, (answer, doc_score, span_score) in ranked:
        underline = '-' * len(answer.text)
        print(f"{rank:>4}. {answer.text}")
        print(f" {underline}")
        for level in ('chapter', 'section', 'subsection'):
            print_name(contexts, level, answer.doc_id)
        # The retrieval score is printed as stored; only the softmaxed
        # document/span scores are scaled to percentages.
        print(f" retrieval score: {scores[answer.doc_id]:6.02f}%")
        print(f" document score: {doc_score * 100:6.02f}%")
        print(f" span score: {span_score * 100:6.02f}%")
        print()
def probe(query: str,
          retriever: Retriever,
          reader: Reader,
          num_answers: int = 5) \
        -> Tuple[List[tuple], List[float], Dict[str, List[str]]]:
    """Answer a question by retrieving contexts and reading them.

    Args:
        query: The question to answer.
        retriever: Object whose ``retrieve(query)`` returns a
            ``(scores, contexts)`` pair.
        reader: Object whose ``read(query, reader_input, num_answers)``
            returns the answers.
        num_answers: Number of answers requested from the reader.

    Returns:
        ``(answers, scores, contexts)``: the reader's answers followed by
        the retriever's scores and contexts, unchanged.
    """
    retrieval_scores, retrieved = retriever.retrieve(query)
    # The reader does not consume the raw contexts; they are converted first.
    found = reader.read(query, context_to_reader_input(retrieved), num_answers)
    return found, retrieval_scores, retrieved
def default_probe(query: str):
    """Run ``probe`` with the default setup: FAISS retrieval and DPR reading.

    Loads the ``paragraphs`` configuration of ``GroNLP/ik-nlp-22_slp``,
    builds the ``"faiss"`` / ``"dpr"`` retriever and a ``DprReader``, and
    calls ``probe`` with its default number of answers. Nothing is printed;
    the result of ``probe`` is returned to the caller.

    Args:
        query: The question to answer.

    Returns:
        Whatever ``probe`` returns: ``(answers, scores, contexts)``.
    """
    dataset = cast(
        DatasetDict,
        load_dataset("GroNLP/ik-nlp-22_slp", "paragraphs"),
    )
    faiss_retriever = get_retriever(dataset, "faiss", "dpr")
    dpr_reader = DprReader()
    return probe(query, faiss_retriever, dpr_reader)
def main(args: argparse.Namespace):
    """Answer the question given on the command line and print the result.

    Args:
        args: Parsed CLI arguments; the attributes ``query``, ``top``,
            ``retriever`` and ``lm`` are read.
    """
    # Initialize dataset
    dataset = cast(
        DatasetDict,
        load_dataset("GroNLP/ik-nlp-22_slp", "paragraphs"),
    )

    # Build the pipeline and run the query through it
    retriever = get_retriever(dataset, args.retriever, args.lm)
    reader = get_reader(args.lm)
    answers, scores, contexts = probe(
        args.query, retriever, reader, args.top)

    # Print output
    print("Question: " + args.query)
    print("Answer(s):")
    if args.lm == "dpr":
        # DPR answers get the detailed listing with scores.
        print_answers(answers, scores, contexts)
        return

    # Other readers: print the first element of each answer, skipping
    # answers that are empty after stripping whitespace.
    for answer in answers:
        text = answer[0].strip()
        if len(text) > 0:
            print(f" - {text}")
if __name__ == "__main__":
    # Command-line interface: one positional question plus three options.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.MetavarTypeHelpFormatter)

    cli.add_argument(
        "query",
        type=str,
        help="The question to feed to the QA system",
    )
    cli.add_argument(
        "--top", "-t",
        type=int,
        default=1,
        help="The number of answers to retrieve",
    )
    # str.lower as the type makes the choice options case-insensitive.
    cli.add_argument(
        "--retriever", "-r",
        type=str.lower,
        choices=["faiss", "es"],
        default="faiss",
        help="The retrieval method to use",
    )
    cli.add_argument(
        "--lm", "-l",
        type=str.lower,
        choices=["dpr", "longformer"],
        default="dpr",
        help="The language model to use for the FAISS retriever",
    )

    args = cli.parse_args()
    main(args)