Skip to content

Commit 060fb54

Browse files
authored
cuDF engine
cuDF engine for GPU acceleration support. Signed-off-by: Sasha Meister <[email protected]>
1 parent dcba064 commit 060fb54

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import cudf.pandas
2+
cudf.pandas.install()
3+
import pandas as pd
4+
5+
6+
class cuDF:
    """Data-table engine that builds sample and vocabulary tables through ``pd``.

    Every operation goes through the module-level ``pd`` name. This module
    calls ``cudf.pandas.install()`` before importing pandas, so those calls
    are routed through the cuDF accelerator when it is available.
    The engine itself keeps no state.
    """

    def __init__(self):
        """Create the engine; there is nothing to configure."""

    def load_samples_chunk(self, samples: list[dict]) -> pd.DataFrame:
        """Turn one chunk of sample records into a data frame.

        Args:
            samples: Records of one chunk, one dict per sample.

        Returns:
            ``pd.DataFrame(samples)``: one row per record.
        """
        return pd.DataFrame(samples)

    def concat_samples_chunks(self, samples_chunks: list) -> pd.DataFrame:
        """Stack previously loaded chunks into a single table.

        Args:
            samples_chunks: Frames to concatenate, in order.

        Returns:
            The concatenated frame with a fresh 0..n-1 index.
        """
        # ignore_index=True renumbers the rows during the concat itself, which
        # is what a follow-up reset_index(drop=True) would otherwise do.
        return pd.concat(samples_chunks, ignore_index=True)

    def process_vocabulary(self, words_frequencies: dict, hypotheses_metrics: list[object]) -> pd.DataFrame:
        """Build the per-word vocabulary table with one accuracy column per hypothesis.

        Each object in ``hypotheses_metrics`` is read for ``hypothesis_label``
        and ``match_words_frequencies`` and receives an ``mwa`` attribute: the
        mean of its accuracy column rounded to two decimals (presumably
        "mean word accuracy" -- confirm against the metrics class).

        Args:
            words_frequencies: Mapping word -> number of occurrences; becomes
                the ``Amount`` column.
            hypotheses_metrics: Per-hypothesis metrics objects (see above).
                They are mutated in place (``mwa`` is set).

        Returns:
            A frame with the word column, ``Amount`` and, per hypothesis, an
            ``Accuracy_<label>`` column (plain ``Accuracy`` when the label is
            the empty string) holding ``matches / Amount * 100`` rounded to
            two decimals.
        """
        frames = [
            pd.DataFrame(words_frequencies.items(), columns=["Word", "Amount"]).set_index("Word")
        ]
        for metrics in hypotheses_metrics:
            match_column = f"Match_{metrics.hypothesis_label}"
            frames.append(
                pd.DataFrame(
                    metrics.match_words_frequencies.items(),
                    columns=["Word", match_column],
                ).set_index("Word")
            )

        # Outer join keeps every word seen in any frame; words missing from a
        # frame get 0 there instead of NaN.
        table = pd.concat(frames, axis=1, join="outer").reset_index().fillna(0)

        for metrics in hypotheses_metrics:
            label = metrics.hypothesis_label
            match_column = f"Match_{label}"
            accuracy_column = f"Accuracy_{label}" if label != "" else "Accuracy"

            # NOTE(review): a word that has matches but Amount == 0 (i.e. it is
            # absent from words_frequencies) divides by zero here and is not
            # guarded against -- confirm callers never produce that case.
            accuracy = (table[match_column] / table["Amount"] * 100).round(2)
            table[accuracy_column] = accuracy
            # The raw match counts are only an intermediate; drop them once
            # the accuracy column has been derived.
            table = table.drop(match_column, axis=1)
            metrics.mwa = round(accuracy.mean(), 2)

        return table

0 commit comments

Comments
 (0)