Commit 7df93c3
changes made from baskerville to be merged into main
J-Dymond committed Dec 7, 2024
1 parent 3f6c10d commit 7df93c3
Showing 10 changed files with 72 additions and 64 deletions.
2 changes: 1 addition & 1 deletion config/RTC_configs/roberta-mt5-trained.yaml
@@ -1,6 +1,6 @@
 ocr:
   specific_task: "image-to-text"
-  model: "microsoft/trocr-base-handwritten"
+  model: "microsoft/trocr-small-printed"
 
 translator:
   specific_task: "translation_fr_to_en"
3 changes: 1 addition & 2 deletions config/RTC_configs/roberta-mt5-zero-shot.yaml
@@ -1,6 +1,5 @@
 ocr:
   specific_task: "image-to-text"
-  model: "microsoft/trocr-base-handwritten"
+  model: "microsoft/trocr-small-printed"
-
 translator:
   specific_task: "translation_fr_to_en"
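Both RTC configs swap the OCR checkpoint from the handwritten-text model to the smaller printed-text TrOCR variant, presumably to better match the synthetically rendered word images produced by GeneratorFromStrings. A minimal, hedged sketch of how such a config block maps onto a Hugging Face pipeline call; the repo's actual wiring goes through CustomOCRPipeline, shown in the later diffs:

from transformers import pipeline

# Sketch only: load the checkpoint named in the updated configs.
ocr = pipeline(
    task="image-to-text",                   # mirrors specific_task above
    model="microsoft/trocr-small-printed",  # smaller, printed-text TrOCR variant
)
# ocr("word.png") -> [{"generated_text": "..."}]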
2 changes: 2 additions & 0 deletions config/data_configs/l1_fr_to_en.yaml
@@ -7,3 +7,5 @@ lang_pair:
   target: "en"
 
 drop_length: 1000
+
+load_ocr_data: True
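The new load_ocr_data flag lets callers skip building the synthetic OCR images entirely. A hypothetical, runnable sketch of how such a flag could gate the dataset build; the loader below is a stand-in, not the repo's real load_multieurlex_for_pipeline plumbing:

# Hypothetical gating sketch; the loader stands in for the real one.
def load_for_pipeline(load_ocr_data: bool = False) -> list[dict]:
    rows = [{"source_text": "le chat est noir"}]  # stand-in dataset
    if load_ocr_data:
        for row in rows:
            # stand-in for rendering word images with GeneratorFromStrings
            row["ocr_data"] = {"ocr_targets": row["source_text"].split()}
    return rows

print(load_for_pipeline(load_ocr_data=True)[0].keys())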
8 changes: 6 additions & 2 deletions scripts/single_component_inference.py
@@ -46,15 +46,19 @@ def main(
     # initialise pipeline
     data_config = open_yaml_path(data_config_pth)
     pipeline_config = open_yaml_path(pipeline_config_pth)
+
+    if model_key != "ocr":
+        data_config["load_ocr_data"] = False
+
     data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
     test_loader = data_sets["test"]
     if model_key == "ocr":
         rtc_single_component_pipeline = RecognitionVariationalPipeline(
-            model_pars=pipeline_config, data_pars=meta_data
+            model_pars=pipeline_config
         )
     elif model_key == "translator":
         rtc_single_component_pipeline = TranslationVariationalPipeline(
-            model_pars=pipeline_config, data_pars=meta_data
+            model_pars=pipeline_config
         )
     elif model_key == "classifier":
         rtc_single_component_pipeline = ClassificationVariationalPipeline(
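Only the OCR component consumes the rendered images, so single-component runs for the translator and classifier now switch the flag off before loading data. A toy, runnable illustration of the override (config values echo the YAML above):

data_config = {"drop_length": 1000, "load_ocr_data": True}
model_key = "translator"

if model_key != "ocr":
    # text-only components never touch the OCR images, so skip rendering them
    data_config["load_ocr_data"] = False

assert data_config["load_ocr_data"] is False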
20 changes: 13 additions & 7 deletions src/arc_spice/data/multieurlex_utils.py
@@ -67,14 +67,17 @@ def extract_articles(

 def _make_ocr_data(text: str) -> list[tuple[Image.Image, str]]:
     text_split = text.split()
-    text_split = [text for text in text_split if text not in ("", " ", None)]
+    text_split = [text for text in text_split if text not in ("", " ")]
     generator = GeneratorFromStrings(text_split, count=len(text_split))
     return list(generator)
 
 
-def make_ocr_data(item: LazyRow) -> dict[str, tuple[Image.Image] | tuple[str]]:
-    images, targets = zip(*_make_ocr_data(item["source_text"]), strict=True)
-    return {"ocr_images": images, "ocr_targets": targets}
+def make_ocr_data(item: LazyRow) -> dict:
+    try:
+        images, targets = zip(*_make_ocr_data(item["source_text"]), strict=True)
+    except ValueError:
+        return {"ocr_data": {"ocr_images": None, "ocr_targets": None}}
+    return {"ocr_data": {"ocr_images": images, "ocr_targets": targets}}
 
 
 class TranslationPreProcesser:
@@ -229,11 +232,14 @@ def load_multieurlex_for_pipeline(
             make_ocr_data,
             features=datasets.Features(
                 {
-                    "ocr_images": datasets.Sequence(datasets.Image(decode=True)),
-                    "ocr_targets": datasets.Sequence(datasets.Value("string")),
+                    "ocr_data": {
+                        "ocr_images": datasets.Sequence(
+                            datasets.Image(decode=True)
+                        ),
+                        "ocr_targets": datasets.Sequence(datasets.Value("string")),
+                    },
                     **feats,
                 }
             ),
         )
 
     return dataset_dict, meta_data
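Two things change here: make_ocr_data now nests its outputs under a single ocr_data key (mirrored in the datasets.Features schema), and it catches the ValueError that zip raises when a row's source text yields no words, which previously crashed the map. A self-contained illustration of that failure mode and the guard:

def pair_up(pairs: list) -> dict:
    try:
        images, targets = zip(*pairs, strict=True)
    except ValueError:  # zip(*[]) unpacks to nothing -> ValueError
        return {"ocr_data": {"ocr_images": None, "ocr_targets": None}}
    return {"ocr_data": {"ocr_images": images, "ocr_targets": targets}}

print(pair_up([]))  # {'ocr_data': {'ocr_images': None, 'ocr_targets': None}}
print(pair_up([("<img>", "le"), ("<img>", "chat")]))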
6 changes: 3 additions & 3 deletions src/arc_spice/eval/inference_utils.py
@@ -71,12 +71,12 @@ def get_results(

     def recognition_results(
         self,
-        clean_output: dict[str, str | list[dict[str, str | torch.Tensor]]],
-        var_output: dict[str, dict],
+        clean_output: dict,
+        var_output: dict,
         **kwargs,
     ):
         # ### RECOGNITION ###
-        charerror = ocr_error(clean_output)
+        charerror = ocr_error(clean_output["recognition"])
         confidence = var_output["recognition"]["mean_entropy"]
         return RecognitionResults(confidence=confidence, accuracy=charerror)

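recognition_results loosens its type hints and now expects the full clean-inference dictionary, indexing into its recognition entry before computing the character error rate. A toy example of the shape it consumes (values illustrative, keys from the diffs in this commit):

clean_output = {
    "recognition": {
        "outputs": [{"generated_text": "le chat", "target": "le chat"}],
        "full_output": "le chat",
    }
}
recognition = clean_output["recognition"]  # what ocr_error(...) now receives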
4 changes: 2 additions & 2 deletions src/arc_spice/eval/ocr_error.py
@@ -30,7 +30,7 @@ def ocr_error(ocr_output: dict[Any, Any]) -> float:
     Returns:
         Character error rate across entire output of OCR (float)
     """
-    preds = [itm["generated_text"].lower() for itm in ocr_output["full_output"]]
-    targs = [itm["target"].lower() for itm in ocr_output["full_output"]]
+    preds = [itm["generated_text"].lower() for itm in ocr_output["outputs"]]
+    targs = [itm["target"].lower() for itm in ocr_output["outputs"]]
     cer = CharErrorRate()
     return cer(preds, targs).item()
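The metric itself is unchanged; it simply follows the rename from full_output to outputs for the per-item list. A self-contained sketch of the computation using torchmetrics (sample strings are illustrative):

from torchmetrics.text import CharErrorRate

ocr_output = {"outputs": [{"generated_text": "Le Chat", "target": "le chat"}]}
preds = [itm["generated_text"].lower() for itm in ocr_output["outputs"]]
targs = [itm["target"].lower() for itm in ocr_output["outputs"]]
print(CharErrorRate()(preds, targs).item())  # 0.0 after lower-casing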
@@ -89,26 +89,23 @@ def __init__(
         model_pars: dict[str, dict[str, str]],
         n_variational_runs=5,
         ocr_batch_size=64,
-        **kwargs,
     ):
         self.set_device()
+        super().__init__(
+            step_name="recognition",
+            input_key="ocr_data",
+            forward_function=self.recognise,
+            confidence_function=self.get_ocr_confidence,
+            n_variational_runs=n_variational_runs,
+        )
         self.ocr: transformers.Pipeline = pipeline(
             model=model_pars["ocr"]["model"],
             device=self.device,
             pipeline_class=CustomOCRPipeline,
             max_new_tokens=20,
             batch_size=ocr_batch_size,
-            **kwargs,
         )
         self.model = self.ocr.model
-        super().__init__(
-            step_name="recognition",
-            input_key="ocr_data",
-            forward_function=self.recognise,
-            confidence_function=self.get_ocr_confidence,
-            n_variational_runs=n_variational_runs,
-            **kwargs,
-        )
         self._init_pipeline_map()


@@ -118,7 +115,6 @@ def __init__(
         model_pars: dict[str, dict[str, str]],
         n_variational_runs=5,
         translation_batch_size=4,
-        **kwargs,
     ):
         self.set_device()
         # need to initialise the NLI models in this case
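Within the recognition pipeline, the base-class initialisation now runs before the Hugging Face pipeline is constructed, and the **kwargs pass-through is dropped from both variational pipelines' constructors. A schematic of the resulting init order, using stand-in classes rather than the repo's real base class:

class StepBase:  # stand-in for the repo's variational step base class
    def __init__(self, step_name: str, forward_function):
        self.step_name = step_name
        self.forward_function = forward_function

class Recognizer(StepBase):
    def __init__(self, model_name: str):
        super().__init__("recognition", self.recognise)  # wire up the step first
        self.ocr = f"<pipeline for {model_name}>"        # heavy construction after
        self.model = self.ocr

    def recognise(self, images):
        return {"outputs": [], "full_output": ""}

r = Recognizer("microsoft/trocr-small-printed")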
@@ -79,7 +79,7 @@ def clean_inference(self, x: torch.Tensor) -> dict[str, dict]:

         # run the functions
         # UNTIL THE OCR DATA IS AVAILABLE
-        clean_output["recognition"] = self.recognise(x)
+        clean_output["recognition"] = self.recognise(x["ocr_data"])
 
         clean_output["translation"] = self.translate(
             clean_output["recognition"]["outputs"]
@@ -109,8 +109,8 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
         }
         # define the input map for brevity in forward pass
         input_map = {
-            "recognition": x,
-            "translation": clean_output["recognition"]["outputs"],
+            "recognition": x["ocr_data"],
+            "translation": clean_output["recognition"]["full_output"],
             "classification": clean_output["translation"]["full_output"],
         }

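Taken together with the dataset changes, the stages now chain over the nested ocr_data payload, and the variational pass translates the stitched full_output string rather than the per-item outputs list. A hedged, minimal trace of the key flow (values illustrative, keys from the diffs above):

x = {"ocr_data": {"ocr_images": ["<img>", "<img>"], "ocr_targets": ["le", "chat"]}}

# what self.recognise(x["ocr_data"]) returns under the renamed keys
recognition = {
    "outputs": [
        {"generated_text": "le", "target": "le"},
        {"generated_text": "chat", "target": "chat"},
    ],
    "full_output": "le chat",
}

input_map = {
    "recognition": x["ocr_data"],
    "translation": recognition["full_output"],  # the re-joined source string
}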
67 changes: 34 additions & 33 deletions src/arc_spice/variational_pipelines/utils.py
@@ -283,29 +283,27 @@ def recognise(self, inp) -> dict[str, str | list[dict[str, str | torch.Tensor]]]
         Returns:
             dictionary of outputs:
                 {
-                    'full_output': [
+                    'outputs': [
                         {
                             'generated_text': generated text from ocr model (str),
                             'target': original target text (str)
                         }
                     ],
-                    'output': pieced back together string (str)
+                    'full_output': pieced back together string (str)
                 }
         """
-        out = self.ocr(inp["ocr_data"]["ocr_images"])  # type: ignore[misc]
-        text = " ".join([itm[0]["generated_text"] for itm in out])
+        out = self.ocr(inp["ocr_images"])  # type: ignore[misc]
+        text = " ".join([itm["generated_text"] for itm in out])
         return {
-            "full_output": [
+            "outputs": [
                 {
                     "target": target,
                     "generated_text": gen_text["generated_text"],
-                    "entropies": gen_text["entropies"],
+                    "entropies": gen_text["raw_output"]["entropies"],
                 }
-                for target, gen_text in zip(
-                    inp["ocr_data"]["ocr_targets"], out, strict=True
-                )
+                for target, gen_text in zip(inp["ocr_targets"], out, strict=True)
             ],
-            "output": text,
+            "full_output": text,
         }
 
     def translate(self, text: str) -> dict[str, torch.Tensor | str]:
@@ -429,6 +427,31 @@ def stack_variational_outputs(
         # overwrite the existing output dictionary
         return new_var_dict
 
+    def get_ocr_confidence(self, var_output: dict, **kwargs) -> dict[str, float]:
+        """Generate the ocr confidence score.
+
+        Args:
+            var_output: variational run outputs
+
+        Returns:
+            dictionary with metrics
+        """
+        # Adapted for variational methods from: https://arxiv.org/pdf/2412.01221
+        entropies = []
+        recognition_batches = var_output["recognition"]["outputs"]
+        for batch in recognition_batches:
+            for sequence in batch:
+                ent = sequence["entropies"]
+                if ent.dim() == 1:
+                    entropies.append(ent)
+                else:
+                    entropies.append(ent.squeeze())
+        all_entropies = torch.cat(entropies)
+        # mean entropy
+        mean = torch.mean(all_entropies)
+        var_output["recognition"].update({"mean_entropy": mean})
+        return var_output
+
     def sentence_density(
         self,
         clean_sentence: str,
@@ -584,28 +607,6 @@ def get_classification_confidence(
         )
         return var_output
 
-    def get_ocr_confidence(self, var_output: dict) -> dict[str, float]:
-        """Generate the ocr confidence score.
-
-        Args:
-            var_output: variational run outputs
-
-        Returns:
-            dictionary with metrics
-        """
-        # Adapted for variational methods from: https://arxiv.org/pdf/2412.01221
-        stacked_entropies = torch.stack(
-            [
-                [data["entropies"] for data in output["full_output"]]
-                for output in var_output["recognition"]
-            ],
-            dim=1,
-        )
-        # mean entropy
-        mean = torch.mean(stacked_entropies)
-        var_output["recognition"].update({"mean_entropy": mean})
-        return var_output
-
 
 # Translation pipeline with additional functionality to save logits from fwd pass
 class CustomTranslationPipeline(TranslationPipeline):
@@ -699,4 +700,4 @@ def _forward(self, model_inputs, **generate_kwargs):

         logits = torch.stack(out.logits, dim=1)
         entropy = Categorical(logits=logits).entropy() / np.log(logits[0].size()[1])
-        return {"model_output": out.sequences, "entropies": entropy}
+        return {"model_output": out.sequences, "entropies": entropy.squeeze()}
