updated get_results to return results for each row and append them to a list. Each of these is a dictionary identified using the celex_id from the dataset
J-Dymond committed Nov 28, 2024
1 parent d6dbc04 commit c1fba2a
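
As an illustration (not part of the commit), the list that run_inference now builds per row might look like the sketch below; the celex_id and all values are hypothetical, each namedtuple is converted to a plain dict via _asdict(), and which steps appear depends on the keys present in the pipeline's clean_output:

results = [
    {
        "31989R1234": {  # hypothetical celex_id taken from the test row
            "recognition": {"confidence": None, "accuracy": None},  # placeholder step
            "translation": {
                "full_output": "An example translated sentence.",
                "comet_score": 0.82,
                "weighted_semantic_density": 0.15,
            },
            "classification": {
                "mean_scores": [0.10, 0.92, 0.33],
                "hamming_accuracy": 0.97,
                "zero_one_accuracy": 0.0,
                "mean_predicted_entropy": 0.42,
            },
        }
    },
    # ... one entry per dataset row
]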
Showing 1 changed file with 54 additions and 58 deletions.
src/arc_spice/eval/inference_utils.py
@@ -1,3 +1,4 @@
+from collections import namedtuple
 from collections.abc import Callable
 from typing import Any
 
@@ -15,10 +16,31 @@
     RTCVariationalPipeline,
 )
 
+RecognitionResults = namedtuple("RecognitionResults", ["confidence", "accuracy"])
+ClassificationResults = namedtuple(
+    "ClassificationResults",
+    [
+        "mean_scores",
+        "hamming_accuracy",
+        "zero_one_accuracy",
+        "mean_predicted_entropy",
+    ],
+)
+TranslationResults = namedtuple(
+    "TranslationResults",
+    [
+        "full_output",
+        "comet_score",
+        "weighted_semantic_density",
+    ],
+)
 
 
 class ResultsGetter:
     def __init__(self, n_classes: int):
-        self.results_func_map: dict[str, Callable] = {
+        self.results_func_map: dict[
+            str, Callable[..., ClassificationResults | TranslationResults]
+        ] = {
             "recognition": self.recognition_results,
             "translation": self.translation_results,
             "classification": self.classification_results,
@@ -31,26 +53,26 @@ def get_results(
         clean_output: dict[str, dict],
         var_output: dict[str, dict],
         test_row: dict[str, list[int]],
-        results_dict: dict[str, dict[str, list]],
     ):
+        results_dict = {}
         for step_name in clean_output:
-            results_dict = self.results_func_map[step_name](
-                test_row, clean_output, var_output, results_dict
-            )
-        results_dict["input_data"]["celex_ids"].append(test_row["celex_id"])
+            results_dict[step_name] = self.results_func_map[step_name](
+                test_row=test_row,
+                clean_output=clean_output,
+                var_output=var_output,
+            )._asdict()
         return results_dict
 
-    def recognition_results(self):
+    def recognition_results(self, *args, **kwargs):
         # ### RECOGNITION ###
-        # TODO: add this into results_getter issue #14
-        raise NotImplementedError()
+        # TODO: add this into results_getter : issue #14
+        return RecognitionResults(confidence=None, accuracy=None)
 
     def translation_results(
         self,
         test_row: dict[str, Any],
         clean_output: dict[str, dict],
         var_output: dict[str, dict],
-        results_dict: dict[str, dict],
     ):
         # ### TRANSLATION ###
         source_text = test_row["target_text"]
@@ -70,77 +92,51 @@ def translation_results(
         comet_output = self.comet_model.predict(
             comet_inp, batch_size=8, accelerator=comet_device
         )
-        comet_scores = comet_output["scores"]
-        results_dict["translation"]["full_output"].append(clean_translation)
-        results_dict["translation"]["comet_score"].append(comet_scores[0])
-        results_dict["translation"]["weighted_semantic_density"].append(
-            var_output["translation"]["weighted_semantic_density"]
+        return TranslationResults(
+            comet_score=comet_output["scores"][0],
+            full_output=clean_translation,
+            weighted_semantic_density=var_output["translation"][
+                "weighted_semantic_density"
+            ],
         )
-        return results_dict
 
     def classification_results(
         self,
         test_row: dict[str, Any],
         _: dict[str, dict],
         var_output: dict[str, dict],
-        results_dict: dict[str, dict],
+        **kwargs,
     ):
         # ### CLASSIFICATION ###
-        mean_scores = var_output["classification"]["mean_scores"]
+        mean_scores: torch.Tensor = var_output["classification"]["mean_scores"]
         preds = torch.round(mean_scores).tolist()
         labels = self.multihot(test_row["labels"])
         hamming_acc = hamming_loss(y_pred=preds, y_true=labels)
         zero_one_acc = zero_one_loss(y_pred=preds, y_true=labels)
-        results_dict["classification"]["mean_scores"].append(
-            mean_scores.detach().tolist()
-        )
-        results_dict["classification"]["hamming_accuracy"].append(hamming_acc)
-        results_dict["classification"]["zero_one_accuracy"].append(zero_one_acc)
-        results_dict["classification"]["mean_predicted_entropy"].append(
-            torch.mean(var_output["classification"]["predicted_entropy"]).item()
-        )
-
-        return results_dict
+        return ClassificationResults(
+            mean_scores=mean_scores.detach().tolist(),
+            hamming_accuracy=hamming_acc,
+            zero_one_accuracy=zero_one_acc,
+            mean_predicted_entropy=torch.mean(
+                var_output["classification"]["predicted_entropy"]
+            ).item(),
+        )
 
 
 def run_inference(
     dataloader: DataLoader,
     pipeline: RTCVariationalPipeline | RTCSingleComponentPipeline,
     results_getter: ResultsGetter,
 ):
-    # Get_results updates the results_dict. So it needs to be initialised before being
-    # run. It is overwritten if a RTCSingleComponentPipeline is used, since some entries
-    # are not needed.
-    results_dict = {
-        "input_data": {"celex_ids": []},
-        # Placeholder
-        "ocr": {"outputs": [], "confidence": [], "accuracy": []},  # PLACEHOLDER
-        "translation": {
-            "full_output": [],
-            "weighted_semantic_density": [],
-            "comet_score": [],
-        },
-        "classification": {
-            "mean_scores": [],
-            "mean_predicted_entropy": [],
-            "hamming_accuracy": [],
-            "zero_one_accuracy": [],
-        },
-    }
-    if isinstance(pipeline, RTCSingleComponentPipeline):
-        # only need appropriate result dict when evaluating individual component
-        results_dict = {
-            "input_data": {"celex_ids": []},
-            pipeline.step_name: results_dict[pipeline.step_name],
-        }
-
+    results = []
     for _, inp in enumerate(tqdm(dataloader)):
         clean_out, var_out = pipeline.variational_inference(inp)
-        results_dict = results_getter.get_results(
+        row_results_dict = results_getter.get_results(
             clean_output=clean_out,
             var_output=var_out,
             test_row=inp,
-            results_dict=results_dict,
         )
-
-    return results_dict
+        results.append({inp["celex_id"]: row_results_dict})
+        break
+    return results
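
A note on the pattern introduced above: each results function now returns a namedtuple, and get_results calls _asdict() to turn it into a plain dict, which keeps per-row results easy to serialise. A minimal, self-contained sketch of that mechanism (the names here are hypothetical, not from the repository):

from collections import namedtuple

# Define a tiny result type and convert an instance to a dict,
# mirroring how get_results handles each step's namedtuple.
ExampleResults = namedtuple("ExampleResults", ["score", "label"])
row = ExampleResults(score=0.9, label="en")
print(row._asdict())  # {'score': 0.9, 'label': 'en'}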

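How the pieces might be wired together, sketched under the assumption of an existing dataloader and pipeline (test_loader, rtc_pipeline, and the class count are placeholders, not defined in this file):

# Hypothetical driver code; only ResultsGetter and run_inference
# come from src/arc_spice/eval/inference_utils.py.
results_getter = ResultsGetter(n_classes=8)  # assumed number of labels
results = run_inference(
    dataloader=test_loader,   # assumed DataLoader over the evaluation set
    pipeline=rtc_pipeline,    # RTCVariationalPipeline or RTCSingleComponentPipeline
    results_getter=results_getter,
)
# results is a list of {celex_id: row_results_dict} entries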