diff --git a/.gitignore b/.gitignore
index 9d6dc80..75a9bb2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,6 +160,7 @@ Thumbs.db
 slurm_scripts/slurm_logs*
+slurm_scripts/experiments*
 
 # other temp
 .vscode
diff --git a/config/RTC_configs/roberta-mt5-zero-shot.yaml b/config/RTC_configs/roberta-mt5-zero-shot.yaml
index 506a6e8..85a2d79 100644
--- a/config/RTC_configs/roberta-mt5-zero-shot.yaml
+++ b/config/RTC_configs/roberta-mt5-zero-shot.yaml
@@ -1,4 +1,4 @@
-OCR:
+ocr:
   specific_task: "image-to-text"
   model: "microsoft/trocr-base-handwritten"
diff --git a/config/data_configs/l1_fr_to_en.yaml b/config/data_configs/l1_fr_to_en.yaml
index 1a3a373..58e12f1 100644
--- a/config/data_configs/l1_fr_to_en.yaml
+++ b/config/data_configs/l1_fr_to_en.yaml
@@ -5,3 +5,5 @@ level: 1
 lang_pair:
   source: "fr"
   target: "en"
+
+drop_length: 1000
diff --git a/config/experiment/baskerville_pipeline_inference_test.yaml b/config/experiment/baskerville_pipeline_inference_test.yaml
new file mode 100644
index 0000000..281f1d7
--- /dev/null
+++ b/config/experiment/baskerville_pipeline_inference_test.yaml
@@ -0,0 +1,13 @@
+data_config: l1_fr_to_en
+
+pipeline_config: roberta-mt5-zero-shot
+
+seed:
+  - 42
+
+bask:
+  jobname: "shortened_input_test"
+  walltime: '0-12:0:0'
+  gpu_number: 1
+  node_number: 1
+  hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
diff --git a/config/experiment/full_experiment_zero_shot.yaml b/config/experiment/full_experiment_zero_shot.yaml
new file mode 100644
index 0000000..db95e53
--- /dev/null
+++ b/config/experiment/full_experiment_zero_shot.yaml
@@ -0,0 +1,15 @@
+data_config: l1_fr_to_en
+
+pipeline_config: roberta-mt5-zero-shot
+
+seed:
+  - 42
+  - 43
+  - 44
+
+bask:
+  jobname: "full_experiment_with_zero_shot"
+  walltime: '0-24:0:0'
+  gpu_number: 1
+  node_number: 1
+  hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
diff --git a/scripts/README.md b/scripts/README.md
index 2992405..d5e8e74 100644
--- a/scripts/README.md
+++ b/scripts/README.md
@@ -106,3 +106,30 @@ It's called like so e.g. from project root:
 ```bash
 python scripts/pipeline_inference.py [pipeline_config_path] [data_config_path] translator
 ```
+
+## gen_jobscripts.py
+
+Creates jobscript `.sh` files for an experiment, which in this case means a `data_config` and `pipeline_config` combination.
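+For each seed in the config it renders the jinja template into one jobscript per pipeline component, plus a `full_pipeline.sh`, under `slurm_scripts/experiments/<experiment_name>/run_<n>/`.
+
+It's called like so e.g. from project root:
+
+```bash
+python scripts/gen_jobscripts.py [experiment_config_path]
+```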
+
+It takes a single argument, `experiment_config_path`: the path to a `.yaml` file structured as below.
+
+### e.g. Experiment config:
+
+```yaml
+data_config: l1_fr_to_en
+
+pipeline_config: roberta-mt5-zero-shot
+
+seed:
+  - 42
+  - 43
+  - 44
+
+bask:
+  jobname: "full_experiment_with_zero_shot"
+  walltime: '0-24:0:0'
+  gpu_number: 1
+  node_number: 1
+  hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
+```
diff --git a/scripts/gen_jobscripts.py b/scripts/gen_jobscripts.py
new file mode 100644
index 0000000..7420743
--- /dev/null
+++ b/scripts/gen_jobscripts.py
@@ -0,0 +1,82 @@
+import os
+from pathlib import Path
+
+from jinja2 import Environment, FileSystemLoader
+from jsonargparse import CLI
+
+from arc_spice.utils import open_yaml_path
+
+PROJECT_DIR = Path(__file__, "..", "..").resolve()
+
+
+def main(experiment_config_path: str):
+    """
+    Generate slurm jobscripts for each seed and pipeline component of an experiment.
+
+    Args:
+        experiment_config_path: path to the experiment config yaml file
+    """
+    experiment_name = experiment_config_path.split("/")[-1].split(".")[0]
+    experiment_config = open_yaml_path(experiment_config_path)
+    pipeline_conf_dir = (
+        f"{PROJECT_DIR}/config/RTC_configs/{experiment_config['pipeline_config']}.yaml"
+    )
+    data_conf_dir = (
+        f"{PROJECT_DIR}/config/data_configs/{experiment_config['data_config']}.yaml"
+    )
+    pipeline_config = open_yaml_path(pipeline_conf_dir)
+    # Get jinja template
+    environment = Environment(
+        loader=FileSystemLoader(PROJECT_DIR / "src" / "arc_spice" / "config")
+    )
+    template = environment.get_template("jobscript_template.sh")
+    # We don't want to overwrite results, so fail if a run directory already exists
+
+    for index, seed in enumerate(experiment_config["seed"]):
+        os.makedirs(
+            f"slurm_scripts/experiments/{experiment_name}/run_{index}", exist_ok=False
+        )
+        for model in pipeline_config:
+            model_script_dict: dict = experiment_config["bask"]
+            model_script_dict.update(
+                {
+                    "script_name": (
+                        "scripts/single_component_inference.py "
+                        f"{pipeline_conf_dir} {data_conf_dir} {seed}"
+                        f" {experiment_name} {model}"
+                    ),
+                    "job_name": f"{experiment_name}_{model}",
+                    "seed": seed,
+                }
+            )
+            model_train_script = template.render(model_script_dict)
+
+            with open(
+                f"slurm_scripts/experiments/{experiment_name}/run_{index}/{model}.sh",
+                "w",
+            ) as f:
+                f.write(model_train_script)
+
+        pipeline_script_dict: dict = experiment_config["bask"]
+        pipeline_script_dict.update(
+            {
+                "script_name": (
+                    "scripts/pipeline_inference.py "
+                    f"{pipeline_conf_dir} {data_conf_dir} {seed}"
+                    f" {experiment_name}"
+                ),
+                "job_name": f"{experiment_name}_full_pipeline",
+                "seed": seed,
+            }
+        )
+        pipeline_train_script = template.render(pipeline_script_dict)
+
+        with open(
+            f"slurm_scripts/experiments/{experiment_name}/run_{index}/full_pipeline.sh",
+            "w",
+        ) as f:
+            f.write(pipeline_train_script)
+
+
+if __name__ == "__main__":
+    CLI(main)
diff --git a/scripts/pipeline_inference.py b/scripts/pipeline_inference.py
index 7242704..834cdb6 100644
--- a/scripts/pipeline_inference.py
+++ b/scripts/pipeline_inference.py
@@ -3,9 +3,9 @@
 from jsonargparse import CLI
 
-from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
+from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline
 from arc_spice.eval.inference_utils import ResultsGetter, run_inference
-from arc_spice.utils import open_yaml_path
+from arc_spice.utils import open_yaml_path, seed_everything
 from arc_spice.variational_pipelines.RTC_variational_pipeline import (
     RTCVariationalPipeline,
 )
@@ -13,18 +13,33 @@ OUTPUT_DIR = "outputs"
 
 
-def main(pipeline_config_pth: str, data_config_pth: str):
+def main(
+    pipeline_config_pth: str, data_config_pth: str, seed: int, experiment_name: str
+):
     """
     Run inference on a given pipeline with provided data config
 
     Args:
         pipeline_config_pth: path to pipeline config yaml file
         data_config_pth: path to data config yaml file
+        seed: seed for the inference pass
+        experiment_name: name of experiment for saving purposes
     """
+    # create save directory
+    data_name = data_config_pth.split("/")[-1].split(".")[0]
+    pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
+    save_loc = (
+        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
+        f"{experiment_name}/seed_{seed}/"
+    )
+    # This directory needs to exist for all 4 experiments
+    os.makedirs(save_loc, exist_ok=True)
+    # seed experiment
+    seed_everything(seed=seed)
     # initialise pipeline
     data_config = open_yaml_path(data_config_pth)
     pipeline_config = open_yaml_path(pipeline_config_pth)
-    data_sets, meta_data = load_multieurlex_for_translation(**data_config)
+    data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
     test_loader = data_sets["test"]
     rtc_variational_pipeline = RTCVariationalPipeline(
         model_pars=pipeline_config, data_pars=meta_data
     )
@@ -37,11 +52,6 @@ def main(pipeline_config_pth: str, data_config_pth: str):
         results_getter=results_getter,
     )
 
-    data_name = data_config_pth.split("/")[-1].split(".")[0]
-    pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
-    save_loc = f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}"
-    os.makedirs(save_loc, exist_ok=True)
-
     with open(f"{save_loc}/full_pipeline.json", "w") as save_file:
         json.dump(test_results, save_file)
diff --git a/scripts/single_component_inference.py b/scripts/single_component_inference.py
index a2c4bfc..fb98747 100644
--- a/scripts/single_component_inference.py
+++ b/scripts/single_component_inference.py
@@ -16,9 +16,9 @@
 from jsonargparse import CLI
 
-from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
+from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline
 from arc_spice.eval.inference_utils import ResultsGetter, run_inference
-from arc_spice.utils import open_yaml_path
+from arc_spice.utils import open_yaml_path, seed_everything
 from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
     ClassificationVariationalPipeline,
     RecognitionVariationalPipeline,
@@ -28,19 +28,38 @@ OUTPUT_DIR = "outputs"
 
 
-def main(pipeline_config_pth: str, data_config_pth: str, model_key: str):
+def main(
+    pipeline_config_pth: str,
+    data_config_pth: str,
+    seed: int,
+    experiment_name: str,
+    model_key: str,
+):
     """
     Run inference on a given pipeline component with provided data config and model key.
 
     Args:
         pipeline_config_pth: path to pipeline config yaml file
         data_config_pth: path to data config yaml file
+        seed: seed for the inference pass
+        experiment_name: name of experiment for saving purposes
         model_key: name of model on which to run inference
     """
+    # create save directory
+    data_name = data_config_pth.split("/")[-1].split(".")[0]
+    pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
+    save_loc = (
+        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
+        f"{experiment_name}/seed_{seed}/"
+    )
+    # This directory needs to exist for all 4 experiments
+    os.makedirs(save_loc, exist_ok=True)
+    # seed experiment
+    seed_everything(seed=seed)
     # initialise pipeline
     data_config = open_yaml_path(data_config_pth)
     pipeline_config = open_yaml_path(pipeline_config_pth)
-    data_sets, meta_data = load_multieurlex_for_translation(**data_config)
+    data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
     test_loader = data_sets["test"]
     if model_key == "ocr":
         rtc_single_component_pipeline = RecognitionVariationalPipeline(
@@ -69,14 +88,6 @@ def main(pipeline_config_pth: str, data_config_pth: str, model_key: str):
         results_getter=results_getter,
     )
 
-    data_name = data_config_pth.split("/")[-1].split(".")[0]
-    pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
-    save_loc = (
-        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
-        f"single_component"
-    )
-    os.makedirs(save_loc, exist_ok=True)
-
     with open(f"{save_loc}/{model_key}.json", "w") as save_file:
         json.dump(test_results, save_file)
diff --git a/src/arc_spice/config/jobscript_template.sh b/src/arc_spice/config/jobscript_template.sh
new file mode 100644
index 0000000..9df076e
--- /dev/null
+++ b/src/arc_spice/config/jobscript_template.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+#SBATCH --account vjgo8416-spice
+#SBATCH --qos turing
+#SBATCH --job-name {{ job_name }}
+#SBATCH --time {{ walltime }}
+#SBATCH --nodes {{ node_number }}
+#SBATCH --gpus {{ gpu_number }}
+#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/{{ job_name }}-%j.out
+#SBATCH --cpus-per-gpu 18
+
+
+# Load required modules here
+module purge
+module load baskerville
+module load bask-apps/live/live
+module load Python/3.10.8-GCCcore-12.2.0
+
+
+# change working directory
+cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/
+
+source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate
+
+# change huggingface cache to be in project dir rather than user home
+export HF_HOME="{{ hf_cache_dir }}"
+
+# TODO: the script uses paths relative to the project home, so it must be run from there; fix this
+python {{ script_name }}
diff --git a/src/arc_spice/data/multieurlex_utils.py b/src/arc_spice/data/multieurlex_utils.py
index d46da5e..871217d 100644
--- a/src/arc_spice/data/multieurlex_utils.py
+++ b/src/arc_spice/data/multieurlex_utils.py
@@ -133,6 +133,7 @@ def load_multieurlex(
     level: int,
     languages: list[str],
     drop_empty: bool = True,
+    drop_length: int | None = None,
     split: str | None = None,
 ) -> tuple[datasets.DatasetDict, dict[str, Any]]:
     """
@@ -188,6 +189,11 @@ def load_multieurlex(
             lambda x: all(x is not None for x in x["text"].values())
         )
 
+    if drop_length:
+        dataset_dict = dataset_dict.filter(
+            lambda x: len(x["text"][languages[0]]) <= drop_length
+        )
+
     # return datasets and metadata
     return dataset_dict, metadata
@@ -197,11 +203,16 @@ def load_multieurlex_for_pipeline(
     level: int,
     lang_pair: dict[str, str],
     drop_empty: bool = True,
+    drop_length: int | None = None,
     load_ocr_data: bool = False,
 ) -> tuple[datasets.DatasetDict, dict[str, Any]]:
     langs = [lang_pair["source"], lang_pair["target"]]
     dataset_dict, meta_data = load_multieurlex(
-        data_dir=data_dir, level=level, languages=langs, drop_empty=drop_empty
+        data_dir=data_dir,
+        level=level,
+        languages=langs,
+        drop_empty=drop_empty,
+        drop_length=drop_length,
     )
     # instantiate the preprocessor
     preprocesser = TranslationPreProcesser(lang_pair)
diff --git a/src/arc_spice/eval/inference_utils.py b/src/arc_spice/eval/inference_utils.py
index 0794fa2..bc8c774 100644
--- a/src/arc_spice/eval/inference_utils.py
+++ b/src/arc_spice/eval/inference_utils.py
@@ -3,12 +3,12 @@
 from typing import Any
 
 import torch
-from sklearn.metrics import hamming_loss, zero_one_loss
+from sklearn.metrics import hamming_loss
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 from arc_spice.data.multieurlex_utils import MultiHot
-from arc_spice.eval.translation_error import get_comet_model
+from arc_spice.eval.translation_error import conditional_probability, get_comet_model
 from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
     RTCSingleComponentPipeline,
 )
@@ -17,21 +17,26 @@
 )
 
 RecognitionResults = namedtuple("RecognitionResults", ["confidence", "accuracy"])
-ClassificationResults = namedtuple(
-    "ClassificationResults",
-    [
-        "mean_scores",
-        "hamming_accuracy",
-        "zero_one_accuracy",
-        "mean_predicted_entropy",
-    ],
-)
+
 TranslationResults = namedtuple(
     "TranslationResults",
     [
         "full_output",
+        "clean_conditional_probability",
         "comet_score",
         "weighted_semantic_density",
+        "mean_entropy",
+        "sequence_lengths",
     ],
 )
+
+ClassificationResults = namedtuple(
+    "ClassificationResults",
+    [
+        "clean_scores",
+        "mean_scores",
+        "hamming_accuracy",
+        "mean_predicted_entropy",
+    ],
+)
@@ -78,6 +83,12 @@ def translation_results(
         source_text = test_row["target_text"]
         target_text = test_row["target_text"]
         clean_translation = clean_output["translation"]["full_output"]
+        clean_entropy: torch.Tensor = clean_output["translation"]["mean_entropy"]
+        seq_lens: torch.Tensor = var_output["translation"]["sequence_length"]
+        probs: list[torch.Tensor] = clean_output["translation"]["probs"]
+        clean_cond_prob = [
+            conditional_probability(prob.squeeze()).detach().tolist() for prob in probs
+        ]
 
         # define error model inputs
         comet_inp = [
@@ -96,6 +107,9 @@ def translation_results(
         return TranslationResults(
             comet_score=comet_output["scores"][0],
             full_output=clean_translation,
+            clean_conditional_probability=clean_cond_prob,
+            mean_entropy=clean_entropy,
+            sequence_lengths=seq_lens,
             weighted_semantic_density=var_output["translation"][
                 "weighted_semantic_density"
             ],
@@ -105,19 +119,19 @@ def classification_results(
         self,
         test_row: dict[str, Any],
         var_output: dict[str, dict],
-        **kwargs,
+        clean_output: dict[str, dict],
     ):
         # ### CLASSIFICATION ###
         mean_scores: torch.Tensor = var_output["classification"]["mean_scores"]
+        clean_scores: torch.Tensor = clean_output["classification"]["scores"]
         preds = torch.round(mean_scores).tolist()
         labels = self.multihot(test_row["labels"])
         hamming_acc = hamming_loss(y_pred=preds, y_true=labels)
-        zero_one_acc = zero_one_loss(y_pred=preds, y_true=labels)
 
         return ClassificationResults(
             mean_scores=mean_scores.detach().tolist(),
+            clean_scores=clean_scores,
             hamming_accuracy=hamming_acc,
-            zero_one_accuracy=zero_one_acc,
             mean_predicted_entropy=torch.mean(
                 var_output["classification"]["predicted_entropy"]
             ).item(),
@@ -138,4 +152,5 @@ def run_inference(
                 test_row=inp,
             )
             results.append({inp["celex_id"]: row_results_dict})
+        break
     return results
diff --git a/src/arc_spice/eval/translation_error.py b/src/arc_spice/eval/translation_error.py
index 510b157..e15a5f4 100644
--- a/src/arc_spice/eval/translation_error.py
+++ b/src/arc_spice/eval/translation_error.py
@@ -1,3 +1,4 @@
+import torch
 from comet import download_model, load_from_checkpoint
 from torcheval.metrics.functional import bleu_score
 
@@ -10,3 +11,7 @@ def get_comet_model(model_path="Unbabel/wmt22-comet-da"):
     # Load the model checkpoint:
     comet_model_pth = download_model(model=model_path)
     return load_from_checkpoint(comet_model_pth)
+
+
+def conditional_probability(prob_scores: torch.Tensor):
+    return torch.prod(torch.pow(prob_scores, 1 / len(prob_scores)), dim=-1)
diff --git a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
index 9a13646..a0788d1 100644
--- a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
+++ b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
@@ -4,10 +4,14 @@
 from transformers import pipeline
 
 from arc_spice.variational_pipelines.RTC_variational_pipeline import (
-    CustomTranslationPipeline,
     RTCVariationalPipelineBase,
 )
-from arc_spice.variational_pipelines.utils import dropout_off, dropout_on, set_dropout
+from arc_spice.variational_pipelines.utils import (
+    CustomTranslationPipeline,
+    dropout_off,
+    dropout_on,
+    set_dropout,
+)
 
 
 class RTCSingleComponentPipeline(RTCVariationalPipelineBase):
@@ -34,19 +38,6 @@ def __init__(
         # define objects that are needed and nothing else
         # naive outputs can remain the same, though only the appropriate outputs will
         # be outputted
-        self.naive_outputs = {
-            "recognition": [
-                "outputs",
-            ],
-            "translation": [
-                "full_output",
-                "outputs",
-                "probs",
-            ],
-            "classification": [
-                "scores",
-            ],
-        }
         self.step_name = step_name
         self.input_key = input_key
         self.forward_function = forward_function
@@ -98,8 +89,8 @@ def __init__(
     ):
         self.set_device()
         self.ocr = pipeline(
-            task=model_pars["OCR"]["specific_task"],
-            model=model_pars["OCR"]["model"],
+            task=model_pars["ocr"]["specific_task"],
+            model=model_pars["ocr"]["model"],
             device=self.device,
             **kwargs,
         )
@@ -120,7 +111,7 @@ def __init__(
         self,
         model_pars: dict[str, dict[str, str]],
         n_variational_runs=5,
-        translation_batch_size=8,
+        translation_batch_size=4,
         **kwargs,
     ):
         self.set_device()
diff --git a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
index eba0c16..2d9f1ea 100644
--- a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
+++ b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
@@ -1,10 +1,10 @@
-import copy
 from typing import Any
 
 import torch
-from transformers import TranslationPipeline, pipeline
+from transformers import pipeline
 
 from arc_spice.variational_pipelines.utils import (
+    CustomTranslationPipeline,
     RTCVariationalPipelineBase,
     dropout_off,
     dropout_on,
@@ -36,13 +36,13 @@ def __init__(
         model_pars: dict[str, dict[str, str]],
         data_pars: dict[str, Any],
         n_variational_runs=5,
-        translation_batch_size=8,
+        translation_batch_size=16,
     ) -> None:
         super().__init__(n_variational_runs, translation_batch_size)
         # defining the pipeline objects
         self.ocr = pipeline(
-            task=model_pars["OCR"]["specific_task"],
-            model=model_pars["OCR"]["model"],
+            task=model_pars["ocr"]["specific_task"],
+            model=model_pars["ocr"]["model"],
             device=self.device,
         )
         self.translator = pipeline(
@@ -133,52 +133,3 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
     # on standard call return the clean output
     def __call__(self, x):
         return self.clean_inference(x)
-
-
-# Translation pipeline with additional functionality to save logits from fwd pass
-class CustomTranslationPipeline(TranslationPipeline):
-    """
-    custom translation pipeline to return the logits with the generated text. Largely
-    the same as the pytorch version with some additional arguments passed to the
-    `generate` method.
-    """
-
-    def postprocess(
-        self,
-        model_outputs: dict,
-        **postprocess_params,
-    ):
-        # model_outputs gets overwritten in the super().postprocess call
-        # make a copy here so we retain the information we want
-        raw_out = copy.deepcopy(model_outputs)
-        processed = super().postprocess(model_outputs, **postprocess_params)
-
-        return {
-            "translation_text": processed[0]["translation_text"],
-            "raw_outputs": raw_out,
-        }
-
-    def _forward(self, model_inputs, **generate_kwargs):
-        if self.framework == "pt":
-            in_b, input_length = model_inputs["input_ids"].shape
-        elif self.framework == "tf":
-            raise NotImplementedError
-
-        self.check_inputs(
-            input_length,
-            generate_kwargs.get("min_length", self.model.config.min_length),
-            generate_kwargs.get("max_length", self.model.config.max_length),
-        )
-        out = self.model.generate(**model_inputs, **generate_kwargs)
-        output_ids = out["sequences"]
-        out_b = output_ids.shape[0]
-        if self.framework == "pt":
-            output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
-        elif self.framework == "tf":
-            raise NotImplementedError
-
-        # logits are a tuple of length output_ids[-1]-1
-        # each element is a tensor of shape (batch_size, vocab_size)
-        logits = torch.stack(out["logits"], dim=1)
-
-        return {"output_ids": output_ids, "logits": logits}
diff --git a/src/arc_spice/variational_pipelines/utils.py b/src/arc_spice/variational_pipelines/utils.py
index c9e2692..f37ca30 100644
--- a/src/arc_spice/variational_pipelines/utils.py
+++ b/src/arc_spice/variational_pipelines/utils.py
@@ -1,11 +1,18 @@
+import copy
 import logging
+import math
 from abc import ABC, abstractmethod
 from functools import partial
 from typing import Any
 
 import torch
 from torch.nn.functional import softmax
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, Pipeline
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    Pipeline,
+    TranslationPipeline,
+)
 
 logger = logging.Logger("RTC_variational_pipeline")
@@ -117,6 +124,7 @@ def __init__(self, n_variational_runs=5, translation_batch_size=8):
                 "full_output",
                 "outputs",
                 "probs",
+                "mean_entropy",
             ],
             "classification": [
                 "scores",
@@ -259,14 +267,14 @@ def translate(self, text: str) -> dict[str, torch.Tensor | str]:
         ]
         # join these to create the full translation
         full_translation = ("").join(sentence_translations)
-        # get softmax of the logits to get token probabilities
-        softmax_logits = softmax(translator_outputs[0]["raw_outputs"]["logits"], dim=-1)
-        max_token_scores = torch.max(softmax_logits, dim=-1).values.squeeze(dim=0)
         # record the output and token probabilities
         confidence_metrics = [
             {
                 "outputs": translator_output["translation_text"],
-                "probs": max_token_scores,
+                "probs": translator_output["raw_outputs"]["scores"],
+                "mean_entropy": torch.mean(translator_output["raw_outputs"]["entropy"])
+                .detach()
+                .tolist(),
             }
             for translator_output in translator_outputs
         ]
@@ -373,7 +381,8 @@ def sentence_density(
 
         # TODO vectorize
         # calculate conditional probabilities take power first to avoid NaN
-        for var_index, var_score in enumerate(var_scores):
+        for var_index, var_score_out in enumerate(var_scores):
+            var_score = var_score_out.squeeze()
             cond_probs[var_index] = torch.prod(
                 torch.pow(var_score, 1 / len(var_score)), dim=-1
             )
@@ -381,7 +390,6 @@ def sentence_density(
         semantic_density = (1 / torch.sum(cond_probs)) * torch.sum(
             torch.mul(cond_probs, kernel_funcs)
         )
-
         return semantic_density.item(), sequence_length
 
     def translation_semantic_density(
@@ -433,6 +441,7 @@ def translation_semantic_density(
             {
                 "semantic_densities": densities,
                 "weighted_semantic_density": weighted_average.item(),
+                "sequence_length": sequence_lengths,
             }
         )
 
@@ -483,3 +492,63 @@ def get_classification_confidence(
             }
         )
         return var_output
+
+
+# Translation pipeline with additional functionality to save logits from fwd pass
+class CustomTranslationPipeline(TranslationPipeline):
+    """
+    custom translation pipeline to return the logits with the generated text. Largely
+    the same as the pytorch version with some additional arguments passed to the
+    `generate` method.
+    """
+
+    def postprocess(
+        self,
+        model_outputs: dict,
+        **postprocess_params,
+    ):
+        # model_outputs gets overwritten in the super().postprocess call
+        # make a copy here so we retain the information we want
+        raw_out = copy.deepcopy(model_outputs)
+        processed = super().postprocess(model_outputs, **postprocess_params)
+
+        return {
+            "translation_text": processed[0]["translation_text"],
+            "raw_outputs": raw_out,
+        }
+
+    def _forward(self, model_inputs, **generate_kwargs):
+        if self.framework == "pt":
+            in_b, input_length = model_inputs["input_ids"].shape
+        elif self.framework == "tf":
+            raise NotImplementedError
+
+        self.check_inputs(
+            input_length,
+            generate_kwargs.get("min_length", self.model.config.min_length),
+            generate_kwargs.get("max_length", self.model.config.max_length),
+        )
+        out = self.model.generate(**model_inputs, **generate_kwargs)
+        output_ids = out["sequences"]
+        out_b = output_ids.shape[0]
+        if self.framework == "pt":
+            output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
+        elif self.framework == "tf":
+            raise NotImplementedError
+
+        # logits are a tuple of length output_ids[-1]-1
+        # each element is a tensor of shape (batch_size, vocab_size)
+        logits = torch.stack(out["logits"], dim=1)
+        # get softmax of the logits to get token probabilities
+        softmax_logits = softmax(logits, dim=-1)
+        vocab_size = softmax_logits.shape[-1]
+        normalised_entropy = torch.distributions.Categorical(
+            probs=softmax_logits
+        ).entropy() / math.log(vocab_size)
+        max_token_scores = torch.max(softmax_logits, dim=-1).values
+
+        return {
+            "output_ids": output_ids,
+            "scores": max_token_scores,
+            "entropy": normalised_entropy,
+        }