import base64
import json
import time
import webbrowser

import argilla as rg
import pandas as pd
from argilla.client.singleton import active_client

from dataset.base_dataset import DatasetBase
from utils.config import Color

class ArgillaEstimator:
    """
    The ArgillaEstimator class is responsible for generating the ground truth (GT) for the dataset
    through the Argilla annotation interface, specifically its text classification mode.
    """
    def __init__(self, opt):
        """
        Initialize a new instance of the ArgillaEstimator class.
        """
        try:
            self.opt = opt
            rg.init(
                api_url=opt.api_url,
                api_key=opt.api_key,
                workspace=opt.workspace
            )
            self.time_interval = opt.time_interval
        except Exception as e:
            raise Exception("Failed to connect to Argilla, check the connection details") from e
    @staticmethod
    def initialize_dataset(dataset_name: str, label_schema: set[str]):
        """
        Initialize a new dataset in the Argilla system
        :param dataset_name: The name of the dataset
        :param label_schema: The set of label classes
        """
        try:
            settings = rg.TextClassificationSettings(label_schema=label_schema)
            rg.configure_dataset_settings(name=dataset_name, settings=settings)
        except Exception as e:
            raise Exception("Failed to create dataset") from e
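
    # For illustration only (hypothetical dataset name and labels):
    #   ArgillaEstimator.initialize_dataset("my_dataset", {"Yes", "No"})
    # would create a text classification dataset with those two labels in the active workspace.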
    @staticmethod
    def upload_missing_records(dataset_name: str, batch_id: int, batch_records: pd.DataFrame):
        """
        Update the Argilla dataset by uploading the records of batch_id that appear in batch_records
        but are not yet stored in Argilla
        :param dataset_name: The dataset name
        :param batch_id: The batch id
        :param batch_records: A dataframe of the batch records
        """
        # TODO: sort visualization according to batch_id descending
        query = "metadata.batch_id:{}".format(batch_id)
        result = rg.load(name=dataset_name, query=query)
        df = result.to_pandas()
        if len(df) == len(batch_records):
            return
        if df.empty:
            upload_df = batch_records
        else:
            # Anti-join on 'text': keep only the batch records that are not already in Argilla
            merged_df = pd.merge(batch_records, df['text'], on='text', how='left', indicator=True)
            upload_df = merged_df[merged_df['_merge'] == 'left_only'].drop(columns=['_merge'])
        record_list = []
        for index, row in upload_df.iterrows():
            config = {'text': row['text'], 'metadata': {"batch_id": row['batch_id'], 'id': row['id']}, "id": row['id']}
            # if not (row[['prediction']].isnull().any()):
            #     config['prediction'] = row['prediction']  # TODO: fix it, incorrect type!!!
            if not (row[['annotation']].isnull().any()):  # TODO: fix it, incorrect type!!!
                config['annotation'] = row['annotation']
            record_list.append(rg.TextClassificationRecord(**config))
        rg.log(records=record_list, name=dataset_name)
    def calc_usage(self):
        """
        Dummy function to calculate the usage of the estimator
        """
        return 0
    def apply(self, dataset: DatasetBase, batch_id: int):
        """
        Apply the estimator on the dataset. The function enters an infinite loop until all the records
        in the batch are annotated, then it returns a dataframe with all the annotations
        :param dataset: DatasetBase object, contains all the processed records
        :param batch_id: The batch id to annotate
        """
        current_api = active_client()
        try:
            rg_dataset = current_api.datasets.find_by_name(dataset.name)
        except Exception:
            self.initialize_dataset(dataset.name, dataset.label_schema)
            rg_dataset = current_api.datasets.find_by_name(dataset.name)
        batch_records = dataset[batch_id]
        if batch_records.empty:
            return []
        self.upload_missing_records(dataset.name, batch_id, batch_records)
        # Build a link to the Argilla UI that pre-filters the dataset to the current batch;
        # the filter is passed as a base64-encoded JSON query string
        data = {'metadata': {'batch_id': [str(batch_id)]}}
        json_data = json.dumps(data)
        encoded_bytes = base64.b64encode(json_data.encode('utf-8'))
        encoded_string = str(encoded_bytes, "utf-8")
        url_link = self.opt.api_url + '/datasets/' + self.opt.workspace + '/' \
                   + dataset.name + '?query=' + encoded_string
        print(f"{Color.GREEN}Waiting for annotations from batch {batch_id}:\n{url_link}{Color.END}")
        webbrowser.open(url_link)
        while True:
            # Poll until every record in the batch is either Validated or Discarded
            query = "(status:Validated OR status:Discarded) AND metadata.batch_id:{}".format(batch_id)
            search_results = current_api.search.search_records(
                name=dataset.name,
                task=rg_dataset.task,
                size=0,
                query_text=query,
            )
            if search_results.total == len(batch_records):
                result = rg.load(name=dataset.name, query=query)
                df = result.to_pandas()[['text', 'annotation', 'metadata', 'status']]
                # Records discarded in the UI are marked with the literal 'Discarded' annotation
                df["annotation"] = df.apply(lambda x: 'Discarded' if x['status'] == 'Discarded' else x['annotation'], axis=1)
                df = df.drop(columns=['status'])
                df['id'] = df.apply(lambda x: x['metadata']['id'], axis=1)
                return df
            time.sleep(self.time_interval)
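

# A minimal usage sketch (for illustration, not part of the pipeline): the attribute names below
# (api_url, api_key, workspace, time_interval) are the ones this class reads from `opt`;
# the concrete values and the SimpleNamespace wrapper are placeholder assumptions. In the real
# pipeline, `opt` and the DatasetBase instance passed to `apply` are constructed by the surrounding code.
if __name__ == "__main__":
    from types import SimpleNamespace

    opt = SimpleNamespace(
        api_url="http://localhost:6900",  # placeholder Argilla server URL
        api_key="admin.apikey",           # placeholder API key
        workspace="admin",                # placeholder workspace name
        time_interval=5,                  # seconds between annotation polls
    )
    estimator = ArgillaEstimator(opt)
    # annotated_df = estimator.apply(dataset, batch_id=0)  # `dataset` is a DatasetBase built elsewhere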