Add basic eval table logging for WandbCallback #31050
base: main

Changes from all commits: 27d7393, 4a68300, bf1f801, e4dec9e, c84d0c4, a80db7a, b8d5c6e, b176f29, 2d9b503
Changes to `WandbCallback` (integrations):

```diff
@@ -692,10 +692,19 @@ def print_to_file(s):
 class WandbCallback(TrainerCallback):
     """
-    A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
+    A [`TrainerCallback`] that logs metrics, media, evals, and model checkpoints to [Weight and Biases](https://www.wandb.com/).
     """

-    def __init__(self):
+    def __init__(
+        self,
+        *,
+        trainer=None,
+        tokenizer=None,
+        dataset=None,
+        num_samples: int = 10,
+        freq: int = 1,
+        ignore_tokens: Optional[list] = None,
+    ):
         has_wandb = is_wandb_available()
         if not has_wandb:
             raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
```
```diff
@@ -704,6 +713,48 @@ def __init__(self):
         self._wandb = wandb
         self._initialized = False

+        # Setup for evals if user requests it
+        if os.getenv("WANDB_LOG_EVALS"):
+            if trainer is not None:
+                self.trainer = trainer
+
+            if tokenizer is None:
+                tokenizer = self.trainer.tokenizer
```
> **Contributor:** This assumes `self.trainer` is not None.
```diff
+            self.tokenizer = tokenizer
```
> **Contributor:** Do we want to set this even if `tokenizer` is `None`?
```diff
+
+            if dataset is None:
+                dataset = self.trainer.eval_dataset
```
> **Contributor:** Same here - assumes `self.trainer` and `self.trainer.dataset` are not None.
```diff
+
+            try:
+                sampled_dataset = dataset.select(range(num_samples))
```
> **Contributor:** This assumes `dataset` is a `datasets.Dataset` with a `select` method.
```diff
+            except IndexError as e:
+                print(f"WARNING: Could not get those indices: {e=}")
```
> **Contributor:** We should log rather than print.

> **Contributor:** Could we make this a bit clearer - the user never specifies indices, so it's a bit weird to refer to them as "those indices". Maybe something along the lines of "Could not select {num_samples=} rows from the dataset".
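A minimal sketch combining both suggestions, assuming a module-level `logging` logger as used elsewhere in transformers (the values here are illustrative so the snippet runs on its own):

```python
import logging

logger = logging.getLogger(__name__)

num_samples = 10  # illustrative; in the PR this is the __init__ argument
try:
    raise IndexError("illustrative failure")
except IndexError as e:
    # Suggested replacement for the print() call above.
    logger.warning("Could not select num_samples=%d rows from the dataset: %s", num_samples, e)
```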
```diff
+                sampled_dataset = dataset
```
```diff
+
+            self.sample_dataset = sampled_dataset
```
> **Contributor:** Do we want to store both the full and sampled dataset?
```diff
+            self.freq = freq
+
+            if ignore_tokens is None:
+                ignore_tokens = [-100]
```
```diff
+
+            padding_token_id = self.tokenizer.pad_token_id
```
> **Contributor:** Assumes tokenizer is not None. Confusingly, the …
```diff
+
+            def replace_ignored_tokens(a):
+                if isinstance(a, np.ndarray):
+                    mask = np.isin(a, ignore_tokens)
+                elif isinstance(a, torch.Tensor):
+                    mask = torch.isin(a, torch.tensor(ignore_tokens, dtype=a.dtype))
+                else:
+                    raise TypeError(f"Unsupported type replace token type {type(a)}")
+
+                a[mask] = padding_token_id
+                return a
+
+            self._replace_ignored_tokens_func = replace_ignored_tokens
```
> **Contributor:** Why do we need this functionality in the callback?
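For context: labels padded with `-100` cannot be passed to `tokenizer.batch_decode`, so ignored positions must first be mapped to a real token id. A standalone sketch of the same masking logic (the array contents and pad id are made up for illustration):

```python
import numpy as np

labels = np.array([[15496, 995, -100, -100]])  # -100 marks ignored positions
pad_token_id = 0  # assumed pad id, purely illustrative

mask = np.isin(labels, [-100])
labels[mask] = pad_token_id
print(labels)  # [[15496   995     0     0]] -- now safe to batch_decode
```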
```diff
+
+            self._collected_eval_rows = []
+
         # log model
         if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
             DeprecationWarning(
```
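For reference, a hedged sketch of how this opt-in path might be used: set `WANDB_LOG_EVALS`, construct the callback with the new keyword arguments, and let `Trainer.__init__` (see the trainer changes below) attach the trainer reference. The model/tokenizer/dataset objects are placeholders, not from the PR:

```python
import os

from transformers import Trainer, TrainingArguments
from transformers.integrations import WandbCallback

os.environ["WANDB_LOG_EVALS"] = "1"  # checked in __init__ and in on_evaluate

# model, tokenizer, eval_ds = ...  # any model, tokenizer, and eval dataset

eval_callback = WandbCallback(
    tokenizer=tokenizer,   # falls back to trainer.tokenizer if omitted
    dataset=eval_ds,       # falls back to trainer.eval_dataset if omitted
    num_samples=10,        # rows sampled from the dataset for the table
    freq=1,                # logging frequency (undocumented; see the review note at the end)
    ignore_tokens=[-100],  # label ids replaced with the pad id before decoding
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out"),
    eval_dataset=eval_ds,
    tokenizer=tokenizer,
    callbacks=[eval_callback],  # Trainer.__init__ sets eval_callback.trainer = trainer
)
```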
```diff
@@ -933,6 +984,46 @@ def on_predict(self, args, state, control, metrics, **kwargs):
         metrics = rewrite_logs(metrics)
         self._wandb.log(metrics)

+    def on_evaluate(self, args, state, control, **kwargs):
+        if os.getenv("WANDB_LOG_EVALS"):
+            eval_loop_output = self.trainer.eval_loop_output
+
+            inputs = eval_loop_output.inputs
+            decoded_inputs = None
+            if inputs is not None:
+                decoded_inputs = self.tokenizer.batch_decode(inputs, skip_special_tokens=True)
+
+            preds = eval_loop_output.predictions
+            outputs = preds.argmax(axis=-1)
+            decoded_outputs = None
+            if outputs is not None:
+                decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+            expected = eval_loop_output.label_ids
+            decoded_expected = None
+            if expected is not None:
+                expected = self._replace_ignored_tokens_func(expected)
+                decoded_expected = self.tokenizer.batch_decode(expected, skip_special_tokens=True)
+
+            # Determine which fields are available
+            available_fields = [
+                ("decoded_inputs", decoded_inputs),
+                ("decoded_outputs", decoded_outputs),
+                ("decoded_expected", decoded_expected),
+            ]
+            available_fields = [(name, value) for name, value in available_fields if value is not None]
+
+            # Create rows using only available fields
+            for items in zip(*(value for _, value in available_fields)):
+                row = {name: item for (name, _), item in zip(available_fields, items)}
+                self._collected_eval_rows.append(row)
+
+            if self._collected_eval_rows:
+                table = self._wandb.Table(columns=list(row.keys()))
+                for row in self._collected_eval_rows:
+                    table.add_data(*row.values())
+                self._wandb.log({"evaluation_table": table})
+
+
 class CometCallback(TrainerCallback):
     """
```
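To make the end result concrete, here is a minimal sketch of the table `on_evaluate` assembles, using only `wandb.Table` API calls that already appear in the diff (the rows and run setup are invented for illustration):

```python
import wandb

# Invented rows mimicking what the callback collects after one evaluation.
rows = [
    {"decoded_inputs": "Translate: Hallo Welt", "decoded_outputs": "Hello world", "decoded_expected": "Hello world"},
    {"decoded_inputs": "Translate: Guten Tag", "decoded_outputs": "Good day", "decoded_expected": "Good day"},
]

table = wandb.Table(columns=list(rows[0].keys()))
for row in rows:
    table.add_data(*row.values())

run = wandb.init(project="eval-table-demo", mode="offline")  # offline keeps the demo self-contained
run.log({"evaluation_table": table})
run.finish()
```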
Changes to `Trainer`:

```diff
@@ -570,6 +570,14 @@ def __init__(
         )
         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
         callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks

+        # Add a reference to the trainer in case callbacks need it
+        def init_callback(cb):
+            cb.trainer = self
+            return cb
+
+        callbacks = [init_callback(cb) for cb in callbacks]
+
         self.callback_handler = CallbackHandler(
             callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
         )
```
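This wiring means any callback, not just `WandbCallback`, can reach the owning trainer via `self.trainer`. A hypothetical callback illustrating the pattern (not part of the PR):

```python
from transformers import TrainerCallback

class EvalInspectorCallback(TrainerCallback):
    """Hypothetical callback relying on the trainer reference injected above."""

    def on_evaluate(self, args, state, control, **kwargs):
        # self.trainer is attached by Trainer.__init__ via init_callback;
        # eval_loop_output is the attribute this PR adds in Trainer.evaluate.
        output = getattr(self.trainer, "eval_loop_output", None)
        if output is not None:
            print(f"Evaluation ran on {output.num_samples} samples")
```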
```diff
@@ -3660,6 +3668,7 @@ def evaluate(
             )
         )

+        self.eval_loop_output = output
         self.log(output.metrics)

         if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
```
```diff
@@ -3939,7 +3948,9 @@ def evaluation_loop(
             if not key.startswith(f"{metric_key_prefix}_"):
                 metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

-        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
+        return EvalLoopOutput(
+            predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples, inputs=all_inputs
+        )

     def _nested_gather(self, tensors, name=None):
         """
```
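Worth noting: if I read `evaluation_loop` correctly, `all_inputs` is only populated when `TrainingArguments.include_inputs_for_metrics` is enabled; otherwise `inputs` arrives as `None` and the callback's table silently drops the `decoded_inputs` column. A hedged config sketch under that assumption:

```python
from transformers import TrainingArguments

# Assumption: without this flag, evaluation_loop never gathers inputs, so
# eval_loop_output.inputs is None and the W&B table only has outputs/labels.
args = TrainingArguments(
    output_dir="out",
    include_inputs_for_metrics=True,
)
```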
> **Contributor:** There should be a docstring for all these args. In particular `num_samples` and `freq`, which aren't obvious.