diff --git a/src/sparseml/transformers/text_classification.py b/src/sparseml/transformers/text_classification.py index 0e001bcfcb9..f21d28553e7 100644 --- a/src/sparseml/transformers/text_classification.py +++ b/src/sparseml/transformers/text_classification.py @@ -666,6 +666,7 @@ def compute_metrics(p: EvalPrediction): if data_args.task_name == "mnli": tasks.append("mnli-mm") eval_datasets.append(raw_datasets["validation_mismatched"]) + combined = {} for eval_dataset, task in zip(eval_datasets, tasks): metrics = trainer.evaluate(eval_dataset=eval_dataset) @@ -677,8 +678,13 @@ def compute_metrics(p: EvalPrediction): ) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + if task == "mnli-mm": + metrics = {k + "_mm": v for k, v in metrics.items()} + if "mnli" in task: + combined.update(metrics) + trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) + trainer.save_metrics("eval", combined if "mnli" in task else metrics) if training_args.do_predict: _LOGGER.info("*** Predict ***")