|
| 1 | +import cudf.pandas |
| 2 | +cudf.pandas.install() |
| 3 | +import pandas as pd |
| 4 | + |
| 5 | + |
| 6 | +class cuDF: |
| 7 | + def __init__(self): |
| 8 | + pass |
| 9 | + |
| 10 | + def load_samples_chunk(self, samples: list[dict]): |
| 11 | + chunk = pd.DataFrame(samples) |
| 12 | + return chunk |
| 13 | + |
| 14 | + def concat_samples_chunks(self, samples_chunks: list): |
| 15 | + samples_datatable = pd.concat(samples_chunks).reset_index(drop=True) |
| 16 | + return samples_datatable |
| 17 | + |
| 18 | + def process_vocabulary(self, words_frequencies: dict, hypotheses_metrics: list[object]): |
| 19 | + vocabulary_dfs = [] |
| 20 | + |
| 21 | + words_frequencies_df = pd.DataFrame(words_frequencies.items(), columns=["Word", "Amount"]).set_index("Word") |
| 22 | + vocabulary_dfs.append(words_frequencies_df) |
| 23 | + |
| 24 | + for hypothesis_metrics_obj in hypotheses_metrics: |
| 25 | + label = hypothesis_metrics_obj.hypothesis_label |
| 26 | + match_words_frequencies = hypothesis_metrics_obj.match_words_frequencies |
| 27 | + match_words_frequencies_df = pd.DataFrame(match_words_frequencies.items(), columns=["Word", f"Match_{hypothesis_metrics_obj.hypothesis_label}"]).set_index("Word") |
| 28 | + vocabulary_dfs.append(match_words_frequencies_df) |
| 29 | + |
| 30 | + vocabulary_datatable = pd.concat(vocabulary_dfs, axis = 1, join = "outer").reset_index().fillna(0) |
| 31 | + |
| 32 | + for hypothesis_metrics_obj in hypotheses_metrics: |
| 33 | + label = hypothesis_metrics_obj.hypothesis_label |
| 34 | + postfix = "" |
| 35 | + if label != "": |
| 36 | + postfix = f"_{label}" |
| 37 | + |
| 38 | + vocabulary_datatable[f"Accuracy{postfix}"] = vocabulary_datatable[f"Match_{label}"] / vocabulary_datatable["Amount"] * 100 |
| 39 | + vocabulary_datatable[f"Accuracy{postfix}"] = vocabulary_datatable[f"Accuracy{postfix}"].round(2) |
| 40 | + vocabulary_datatable = vocabulary_datatable.drop(f"Match_{label}", axis = 1) |
| 41 | + hypothesis_metrics_obj.mwa = round(vocabulary_datatable[f"Accuracy{postfix}"].mean(), 2) |
| 42 | + |
| 43 | + return vocabulary_datatable |
0 commit comments