From 27d7393bf508e75eff882d06492a6cb4982821a1 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Sun, 26 May 2024 22:02:35 -0400 Subject: [PATCH 1/9] add inputs to EvalLoopOutputs --- src/transformers/trainer.py | 4 +++- src/transformers/trainer_utils.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 71c3ee43af2c..86f8397afd6f 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3939,7 +3939,9 @@ def evaluation_loop( if not key.startswith(f"{metric_key_prefix}_"): metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) - return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples) + return EvalLoopOutput( + predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples, inputs=all_inputs + ) def _nested_gather(self, tensors, name=None): """ diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 5f6900658840..1eaa79d9609a 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -188,6 +188,7 @@ class EvalLoopOutput(NamedTuple): label_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]] metrics: Optional[Dict[str, float]] num_samples: Optional[int] + inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None class PredictionOutput(NamedTuple): From 4a68300f5de69836ffa6204e5dd5f4378e3bb3d3 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Sun, 26 May 2024 22:02:58 -0400 Subject: [PATCH 2/9] add a ref to the trainer that callbacks are attached to --- src/transformers/trainer.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 86f8397afd6f..ef09a7c7dd38 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -570,6 +570,14 @@ def __init__( ) default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) callbacks 
= default_callbacks if callbacks is None else default_callbacks + callbacks + + # Add a reference to the trainer in case callbacks need it + def init_callback(cb): + cb.trainer = self + return cb + + callbacks = [init_callback(cb) for cb in callbacks] + self.callback_handler = CallbackHandler( callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler ) From bf1f801063069396b19cffc8d9a14b7e2ed6b87e Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Sun, 26 May 2024 22:03:25 -0400 Subject: [PATCH 3/9] add basic wandb tables evals logging of inputs, outputs, and expected --- .../integrations/integration_utils.py | 83 ++++++++++++++++++- 1 file changed, 82 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 29528feb515c..8eeb5a3ad9db 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -695,7 +695,16 @@ class WandbCallback(TrainerCallback): A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/). """ - def __init__(self): + def __init__( + self, + *, + trainer: Optional["Trainer"] = None, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, + dataset: Optional["Dataset"] = None, + num_samples: int = 10, + freq: int = 1, + ignore_tokens: Optional[list] = None, + ): has_wandb = is_wandb_available() if not has_wandb: raise RuntimeError("WandbCallback requires wandb to be installed. 
Run `pip install wandb`.") @@ -704,6 +713,48 @@ def __init__(self): self._wandb = wandb self._initialized = False + + # Setup for evals if user requests it + if os.getenv("WANDB_LOG_EVALS"): + if trainer is not None: + self.trainer = trainer + + if tokenizer is None: + tokenizer = self.trainer.tokenizer + self.tokenizer = tokenizer + + if dataset is None: + dataset = self.trainer.eval_dataset + + try: + sampled_dataset = dataset.select(range(num_samples)) + except IndexError as e: + print(f"WARNING: Could not get those indices: {e=}") + sampled_dataset = dataset + + self.sample_dataset = sampled_dataset + self.freq = freq + + if ignore_tokens is None: + ignore_tokens = [-100] + + padding_token_id = self.tokenizer.pad_token_id + + def replace_ignored_tokens(a): + if isinstance(a, np.ndarray): + mask = np.isin(a, ignore_tokens) + elif isinstance(a, torch.Tensor): + mask = torch.isin(a, torch.tensor(ignore_tokens, dtype=a.dtype)) + else: + raise TypeError(f"Unsupported type replace token type {type(a)}") + + a[mask] = padding_token_id + return a + + self._replace_ignored_tokens_func = replace_ignored_tokens + + self._collected_eval_rows = [] + # log model if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}): DeprecationWarning( @@ -933,6 +984,36 @@ def on_predict(self, args, state, control, metrics, **kwargs): metrics = rewrite_logs(metrics) self._wandb.log(metrics) + def on_evaluate(self, args, state, control, **kwargs): + if os.getenv("WANDB_LOG_EVALS"): + eval_loop_output = self.trainer.eval_loop_output + + inputs = eval_loop_output.inputs + decoded_inputs = self.tokenizer.batch_decode(inputs, skip_special_tokens=True) + + preds = eval_loop_output.predictions + outputs = preds.argmax(axis=-1) + decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + + expected = eval_loop_output.label_ids + expected = self._replace_ignored_tokens_func(expected) + decoded_expected = self.tokenizer.batch_decode(expected, 
skip_special_tokens=True) + + # un-batch and log rows + for dec_inp, dec_out, dec_exp in zip(decoded_inputs, decoded_outputs, decoded_expected): + row = { + "decoded_inputs": dec_inp, + "decoded_outputs": dec_out, + "decoded_expected": dec_exp, + } + self._collected_eval_rows.append(row) + + table = self._wandb.Table(columns=list(row.keys())) + for row in self._collected_eval_rows: + table.add_data(*row.values()) + + self._wandb.log({"evaluation_table": table}) + class CometCallback(TrainerCallback): """ From e4dec9ead5645bc52d11badde00cba319c15c15f Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Sun, 26 May 2024 22:06:42 -0400 Subject: [PATCH 4/9] types --- src/transformers/integrations/integration_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 8eeb5a3ad9db..3fcc5800a92f 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -698,9 +698,9 @@ class WandbCallback(TrainerCallback): def __init__( self, *, - trainer: Optional["Trainer"] = None, - tokenizer: Optional["PreTrainedTokenizerBase"] = None, - dataset: Optional["Dataset"] = None, + trainer = None, + tokenizer = None, + dataset = None, num_samples: int = 10, freq: int = 1, ignore_tokens: Optional[list] = None, From c84d0c41c961b549280639e0ed8fe53dc3c52751 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Mon, 27 May 2024 12:33:53 -0400 Subject: [PATCH 5/9] lint --- src/transformers/integrations/integration_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 3fcc5800a92f..1b36075c550b 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -698,9 +698,9 @@ class WandbCallback(TrainerCallback): 
def __init__( self, *, - trainer = None, - tokenizer = None, - dataset = None, + trainer=None, + tokenizer=None, + dataset=None, num_samples: int = 10, freq: int = 1, ignore_tokens: Optional[list] = None, From a80db7a914bff5c93b7428316ad9722ae22a6749 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Mon, 27 May 2024 13:00:27 -0400 Subject: [PATCH 6/9] update docstring to mention evals --- src/transformers/integrations/integration_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 1b36075c550b..ee89f4f08f0f 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -692,7 +692,7 @@ def print_to_file(s): class WandbCallback(TrainerCallback): """ - A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/). + A [`TrainerCallback`] that logs metrics, media, evals, and model checkpoints to [Weight and Biases](https://www.wandb.com/). 
""" def __init__( From b8d5c6ee12939d80d99f40c022b5c781ea10c4d9 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Mon, 24 Jun 2024 16:24:44 -0400 Subject: [PATCH 7/9] add missing eval_loop_output --- src/transformers/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index ef09a7c7dd38..88b2ce32a6c7 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3668,6 +3668,7 @@ def evaluate( ) ) + self.eval_loop_output = output self.log(output.metrics) if DebugOption.TPU_METRICS_DEBUG in self.args.debug: From b176f29cdad0f60307a82fe0054f0b2efd909b9a Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Mon, 24 Jun 2024 16:25:31 -0400 Subject: [PATCH 8/9] add guards for empty fields --- .../integrations/integration_utils.py | 44 ++++++++++++------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index ee89f4f08f0f..72ac650eebdc 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -989,30 +989,40 @@ def on_evaluate(self, args, state, control, **kwargs): eval_loop_output = self.trainer.eval_loop_output inputs = eval_loop_output.inputs - decoded_inputs = self.tokenizer.batch_decode(inputs, skip_special_tokens=True) + decoded_inputs = None + if inputs is not None: + decoded_inputs = self.tokenizer.batch_decode(inputs, skip_special_tokens=True) preds = eval_loop_output.predictions outputs = preds.argmax(axis=-1) - decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + decoded_outputs = None + if outputs is not None: + decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) expected = eval_loop_output.label_ids - expected = self._replace_ignored_tokens_func(expected) - decoded_expected = self.tokenizer.batch_decode(expected, skip_special_tokens=True) - 
- # un-batch and log rows - for dec_inp, dec_out, dec_exp in zip(decoded_inputs, decoded_outputs, decoded_expected): - row = { - "decoded_inputs": dec_inp, - "decoded_outputs": dec_out, - "decoded_expected": dec_exp, - } - self._collected_eval_rows.append(row) + decoded_expected = None + if expected is not None: + expected = self._replace_ignored_tokens_func(expected) + decoded_expected = self.tokenizer.batch_decode(expected, skip_special_tokens=True) + + # Determine which fields are available + available_fields = [ + ("decoded_inputs", decoded_inputs), + ("decoded_outputs", decoded_outputs), + ("decoded_expected", decoded_expected) + ] + available_fields = [(name, value) for name, value in available_fields if value is not None] - table = self._wandb.Table(columns=list(row.keys())) - for row in self._collected_eval_rows: - table.add_data(*row.values()) + # Create rows using only available fields + for items in zip(*(value for _, value in available_fields)): + row = {name: item for (name, _), item in zip(available_fields, items)} + self._collected_eval_rows.append(row) - self._wandb.log({"evaluation_table": table}) + if self._collected_eval_rows: + table = self._wandb.Table(columns=list(self._collected_eval_rows[0].keys())) + for row in self._collected_eval_rows: + table.add_data(*row.values()) + self._wandb.log({"evaluation_table": table}) class CometCallback(TrainerCallback): From 2d9b503912cdcb11f6ba9b07895774ad66e1bc42 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Mon, 24 Jun 2024 16:30:22 -0400 Subject: [PATCH 9/9] fmt --- src/transformers/integrations/integration_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 72ac650eebdc..72f667385c5d 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -1009,7 +1009,7 @@ def on_evaluate(self, args, state, control, **kwargs): 
available_fields = [ ("decoded_inputs", decoded_inputs), ("decoded_outputs", decoded_outputs), - ("decoded_expected", decoded_expected) + ("decoded_expected", decoded_expected), ] available_fields = [(name, value) for name, value in available_fields if value is not None]