Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

33 fix bugs for inference on baskerville #34

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config/RTC_configs/roberta-mt5-trained.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
ocr:
specific_task: "image-to-text"
model: "microsoft/trocr-base-handwritten"
model: "microsoft/trocr-small-printed"

translator:
specific_task: "translation_fr_to_en"
Expand Down
3 changes: 1 addition & 2 deletions config/RTC_configs/roberta-mt5-zero-shot.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
ocr:
specific_task: "image-to-text"
model: "microsoft/trocr-base-handwritten"
model: "microsoft/trocr-small-printed"

translator:
specific_task: "translation_fr_to_en"
Expand Down
2 changes: 2 additions & 0 deletions config/data_configs/l1_fr_to_en.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ lang_pair:
target: "en"

drop_length: 1000

load_ocr_data: True
15 changes: 15 additions & 0 deletions config/experiment/finalised_pipeline_zs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
data_config: l1_fr_to_en

pipeline_config: roberta-mt5-zero-shot

seed:
- 42
- 43
- 44

bask:
jobname: "full_experiment_with_zero_shot"
walltime: '0-24:0:0'
gpu_number: 1
node_number: 1
hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
8 changes: 6 additions & 2 deletions scripts/single_component_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,19 @@ def main(
# initialise pipeline
data_config = open_yaml_path(data_config_pth)
pipeline_config = open_yaml_path(pipeline_config_pth)

if model_key != "ocr":
data_config["load_ocr_data"] = False

data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
test_loader = data_sets["test"]
if model_key == "ocr":
rtc_single_component_pipeline = RecognitionVariationalPipeline(
model_pars=pipeline_config, data_pars=meta_data
model_pars=pipeline_config
)
elif model_key == "translator":
rtc_single_component_pipeline = TranslationVariationalPipeline(
model_pars=pipeline_config, data_pars=meta_data
model_pars=pipeline_config
)
elif model_key == "classifier":
rtc_single_component_pipeline = ClassificationVariationalPipeline(
Expand Down
20 changes: 13 additions & 7 deletions src/arc_spice/data/multieurlex_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,17 @@ def extract_articles(

def _make_ocr_data(text: str) -> list[tuple[Image.Image, str]]:
text_split = text.split()
text_split = [text for text in text_split if text not in ("", " ", None)]
text_split = [text for text in text_split if text not in ("", " ")]
generator = GeneratorFromStrings(text_split, count=len(text_split))
return list(generator)


def make_ocr_data(item: LazyRow) -> dict[str, tuple[Image.Image] | tuple[str]]:
images, targets = zip(*_make_ocr_data(item["source_text"]), strict=True)
return {"ocr_images": images, "ocr_targets": targets}
def make_ocr_data(item: LazyRow) -> dict:
try:
images, targets = zip(*_make_ocr_data(item["source_text"]), strict=True)
except ValueError:
return {"ocr_data": {"ocr_images": None, "ocr_targets": None}}
return {"ocr_data": {"ocr_images": images, "ocr_targets": targets}}


class TranslationPreProcesser:
Expand Down Expand Up @@ -229,11 +232,14 @@ def load_multieurlex_for_pipeline(
make_ocr_data,
features=datasets.Features(
{
"ocr_images": datasets.Sequence(datasets.Image(decode=True)),
"ocr_targets": datasets.Sequence(datasets.Value("string")),
"ocr_data": {
"ocr_images": datasets.Sequence(
datasets.Image(decode=True)
),
"ocr_targets": datasets.Sequence(datasets.Value("string")),
},
**feats,
}
),
)

return dataset_dict, meta_data
55 changes: 43 additions & 12 deletions src/arc_spice/eval/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
RTCVariationalPipeline,
)

RecognitionResults = namedtuple("RecognitionResults", ["confidence", "accuracy"])
RecognitionResults = namedtuple(
"RecognitionResults", ["mean_entropy", "character_error_rate"]
)

TranslationResults = namedtuple(
"TranslationResults",
Expand Down Expand Up @@ -71,14 +73,16 @@ def get_results(

def recognition_results(
self,
clean_output: dict[str, str | list[dict[str, str | torch.Tensor]]],
var_output: dict[str, dict],
clean_output: dict,
var_output: dict,
**kwargs,
):
# ### RECOGNITION ###
charerror = ocr_error(clean_output)
charerror = ocr_error(clean_output["recognition"])
confidence = var_output["recognition"]["mean_entropy"]
return RecognitionResults(confidence=confidence, accuracy=charerror)
return RecognitionResults(
mean_entropy=confidence, character_error_rate=charerror
)

def translation_results(
self,
Expand Down Expand Up @@ -150,13 +154,40 @@ def run_inference(
pipeline: RTCVariationalPipeline | RTCSingleComponentPipeline,
results_getter: ResultsGetter,
):
type_errors = []
oom_errors = []
results = []
for _, inp in enumerate(tqdm(dataloader)):
clean_out, var_out = pipeline.variational_inference(inp)
row_results_dict = results_getter.get_results(
clean_output=clean_out,
var_output=var_out,
test_row=inp,
)
results.append({inp["celex_id"]: row_results_dict})
# TEMPORARY FIX
try:
clean_out, var_out = pipeline.variational_inference(inp)
row_results_dict = results_getter.get_results(
clean_output=clean_out,
var_output=var_out,
test_row=inp,
)
results.append({inp["celex_id"]: row_results_dict})
# TEMPORARY FIX ->
except TypeError:
type_errors.append(inp["celex_id"])
continue

except torch.cuda.OutOfMemoryError:
oom_errors.append(inp["celex_id"])
continue

except torch.OutOfMemoryError:
oom_errors.append(inp["celex_id"])
continue

print("Skipped following CELEX IDs due to TypeError:")
print(
'"TypeError: Incorrect format used for image. Should be an url linking to'
' an image, a base64 string, a local path, or a PIL image."'
)
print(type_errors)

print("Skipped following CELEX IDs due to torch.cuda.OutOfMemoryError:")
print(oom_errors)

return results
9 changes: 5 additions & 4 deletions src/arc_spice/eval/ocr_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from torchmetrics.text import CharErrorRate

cer = CharErrorRate()


def ocr_error(ocr_output: dict[Any, Any]) -> float:
"""
Expand All @@ -30,7 +32,6 @@ def ocr_error(ocr_output: dict[Any, Any]) -> float:
Returns:
Character error rate across entire output of OCR (float)
"""
preds = [itm["generated_text"].lower() for itm in ocr_output["full_output"]]
targs = [itm["target"].lower() for itm in ocr_output["full_output"]]
cer = CharErrorRate()
return cer(preds, targs).item()
preds = [itm["generated_text"].lower() for itm in ocr_output["outputs"]]
targs = [itm["target"].lower() for itm in ocr_output["outputs"]]
return cer(preds, targs).detach().item()
Original file line number Diff line number Diff line change
Expand Up @@ -89,26 +89,23 @@ def __init__(
model_pars: dict[str, dict[str, str]],
n_variational_runs=5,
ocr_batch_size=64,
**kwargs,
):
self.set_device()
super().__init__(
step_name="recognition",
input_key="ocr_data",
forward_function=self.recognise,
confidence_function=self.get_ocr_confidence,
n_variational_runs=n_variational_runs,
)
self.ocr: transformers.Pipeline = pipeline(
model=model_pars["ocr"]["model"],
device=self.device,
pipeline_class=CustomOCRPipeline,
max_new_tokens=20,
batch_size=ocr_batch_size,
**kwargs,
)
self.model = self.ocr.model
super().__init__(
step_name="recognition",
input_key="ocr_data",
forward_function=self.recognise,
confidence_function=self.get_ocr_confidence,
n_variational_runs=n_variational_runs,
**kwargs,
)
self._init_pipeline_map()


Expand All @@ -118,7 +115,6 @@ def __init__(
model_pars: dict[str, dict[str, str]],
n_variational_runs=5,
translation_batch_size=4,
**kwargs,
):
self.set_device()
# need to initialise the NLI models in this case
Expand Down
13 changes: 7 additions & 6 deletions src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,19 @@ def clean_inference(self, x: torch.Tensor) -> dict[str, dict]:

# run the functions
# UNTIL THE OCR DATA IS AVAILABLE
clean_output["recognition"] = self.recognise(x)
clean_output["recognition"] = self.recognise(x["ocr_data"])

clean_output["translation"] = self.translate(
clean_output["recognition"]["outputs"]
clean_output["recognition"]["full_output"]
)
# we now need to pass the input correct to the correct forward method
if self.zero_shot:
clean_output["classification"] = self.classify_topic_zero_shot(
clean_output["translation"]["outputs"][0]
clean_output["translation"]["full_output"]
)
else:
clean_output["classification"] = self.classify_topic(
clean_output["translation"]["outputs"][0]
clean_output["translation"]["full_output"]
)
return clean_output

Expand All @@ -109,8 +109,8 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
}
# define the input map for brevity in forward pass
input_map = {
"recognition": x,
"translation": clean_output["recognition"]["outputs"],
"recognition": x["ocr_data"],
"translation": clean_output["recognition"]["full_output"],
"classification": clean_output["translation"]["full_output"],
}

Expand All @@ -130,6 +130,7 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:

# run metric helper functions
var_output = self.stack_variational_outputs(var_output)
var_output = self.get_ocr_confidence(var_output)
var_output = self.translation_semantic_density(
clean_output=clean_output, var_output=var_output
)
Expand Down
67 changes: 34 additions & 33 deletions src/arc_spice/variational_pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,29 +283,27 @@ def recognise(self, inp) -> dict[str, str | list[dict[str, str | torch.Tensor]]]
Returns:
dictionary of outputs:
{
'full_output': [
'outputs': [
{
'generated_text': generated text from ocr model (str),
'target': original target text (str)
}
],
'output': pieced back together string (str)
'full_output': pieced back together string (str)
}
"""
out = self.ocr(inp["ocr_data"]["ocr_images"]) # type: ignore[misc]
text = " ".join([itm[0]["generated_text"] for itm in out])
out = self.ocr(inp["ocr_images"]) # type: ignore[misc]
text = " ".join([itm["generated_text"] for itm in out])
return {
"full_output": [
"outputs": [
{
"target": target,
"generated_text": gen_text["generated_text"],
"entropies": gen_text["entropies"],
"entropies": gen_text["raw_output"]["entropies"],
}
for target, gen_text in zip(
inp["ocr_data"]["ocr_targets"], out, strict=True
)
for target, gen_text in zip(inp["ocr_targets"], out, strict=True)
],
"output": text,
"full_output": text,
}

def translate(self, text: str) -> dict[str, torch.Tensor | str]:
Expand Down Expand Up @@ -429,6 +427,31 @@ def stack_variational_outputs(
# overwrite the existing output dictionary
return new_var_dict

def get_ocr_confidence(self, var_output: dict, **kwargs) -> dict[str, float]:
"""Generate the ocr confidence score.

Args:
var_output: variational run outputs

Returns:
dictionary with metrics
"""
# Adapted for variational methods from: https://arxiv.org/pdf/2412.01221
entropies = []
recognition_batches = var_output["recognition"]["outputs"]
for batch in recognition_batches:
for sequence in batch:
ent = sequence["entropies"]
if ent.dim() == 1:
entropies.append(ent)
else:
entropies.append(ent.squeeze())
all_entropies = torch.cat(entropies)
# mean entropy
mean = torch.mean(all_entropies).item()
var_output["recognition"].update({"mean_entropy": mean})
return var_output

def sentence_density(
self,
clean_sentence: str,
Expand Down Expand Up @@ -584,28 +607,6 @@ def get_classification_confidence(
)
return var_output

def get_ocr_confidence(self, var_output: dict) -> dict[str, float]:
"""Generate the ocr confidence score.

Args:
var_output: variational run outputs

Returns:
dictionary with metrics
"""
# Adapted for variational methods from: https://arxiv.org/pdf/2412.01221
stacked_entropies = torch.stack(
[
[data["entropies"] for data in output["full_output"]]
for output in var_output["recognition"]
],
dim=1,
)
# mean entropy
mean = torch.mean(stacked_entropies)
var_output["recognition"].update({"mean_entropy": mean})
return var_output


# Translation pipeline with additional functionality to save logits from fwd pass
class CustomTranslationPipeline(TranslationPipeline):
Expand Down Expand Up @@ -699,4 +700,4 @@ def _forward(self, model_inputs, **generate_kwargs):

logits = torch.stack(out.logits, dim=1)
entropy = Categorical(logits=logits).entropy() / np.log(logits[0].size()[1])
return {"model_output": out.sequences, "entropies": entropy}
return {"model_output": out.sequences, "entropies": entropy.squeeze()}
Loading
Loading