diff --git a/examples/nlp/language_modeling/conf/megatron_t5_finetune.yaml b/examples/nlp/language_modeling/conf/megatron_t5_finetune.yaml index ab9939af518f..2ba68cbc5979 100644 --- a/examples/nlp/language_modeling/conf/megatron_t5_finetune.yaml +++ b/examples/nlp/language_modeling/conf/megatron_t5_finetune.yaml @@ -87,7 +87,7 @@ model: add_bos_to_input: ${data.train_ds.add_bos_to_input} add_eos_to_input: ${data.train_ds.add_eos_to_input} metric: - name: "exact_string_match" # Name of the evaluation metric to use. + name: "exact_string_match" # Name of the evaluation metric to use. Supported metrics: [`exact_string_match`, `rouge`, `pearson_corr_coef`, `spearman_corr_coef`, `f1`, `accuracy`, `average_precision`] average: micro # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. num_classes: null # Number of classes for the metric. Works only for 'F1', 'accuracy' and 'average_precision' etc. Refer to torchmetrics for metrics where this is supported. class_labels: null # If the targets in your dataset are strings and not integers/float, you need to provide a list of class labels (size = num_classes) so we can convert from strings to integer categories to compute the metric. diff --git a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py index fb1fe83ee68e..9fce0d52c4a1 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py @@ -106,24 +106,36 @@ def setup_metric(self, data_cfg): ) metric_name = data_cfg.metric.name - metric = MetricStringToTorchMetric[metric_name] + metric_class = MetricStringToTorchMetric[metric_name] + # GLUE will not have a "src_file_name" attribute and will always have only a single metric. 
if hasattr(data_cfg, "src_file_name") or hasattr(data_cfg, "file_names"): - if hasattr(data_cfg, "src_file_name") and isinstance(data_cfg.src_file_name, ListConfig): - # We pass average and num_classes to the metric constructor via kwargs even if they don't exist for each metric. + if ( + hasattr(data_cfg, "src_file_name") + and isinstance(data_cfg.src_file_name, ListConfig) + and metric_name != 'rouge' + ): metric = [ - metric(average=data_cfg.metric.average, num_classes=data_cfg.metric.num_classes) + metric_class(average=data_cfg.metric.average, num_classes=data_cfg.metric.num_classes) for _ in range(len(data_cfg.src_file_name)) ] - elif hasattr(data_cfg, "file_names") and isinstance(data_cfg.file_names, ListConfig): + elif ( + hasattr(data_cfg, "file_names") + and isinstance(data_cfg.file_names, ListConfig) + and metric_name != 'rouge' + ): metric = [ - metric(average=data_cfg.metric.average, num_classes=data_cfg.metric.num_classes) + metric_class(average=data_cfg.metric.average, num_classes=data_cfg.metric.num_classes) for _ in range(len(data_cfg.file_names)) ] + elif hasattr(data_cfg, "src_file_name") and isinstance(data_cfg.src_file_name, ListConfig): + metric = [metric_class() for _ in range(len(data_cfg.src_file_name))] + elif hasattr(data_cfg, "file_names") and isinstance(data_cfg.file_names, ListConfig): + metric = [metric_class() for _ in range(len(data_cfg.file_names))] else: - metric = [metric(average=data_cfg.metric.average, num_classes=data_cfg.metric.num_classes)] + metric = [metric_class(average=data_cfg.metric.average, num_classes=data_cfg.metric.num_classes)] else: - metric = [metric()] # GLUE does need to specify average or num_classes. + metric = [metric_class()] # GLUE does not need to specify average or num_classes. 
return metric, metric_name @@ -221,7 +233,7 @@ def cast_for_metric(self, pred, label, metric_name, class_labels=None, labels_ar else: pred = class_labels.index(pred) if label not in class_labels: - raise ValueError(f"Ground truth labe; {label} is not in the class labels list : {class_labels}") + raise ValueError(f"Ground truth label {label} is not in the class labels list : {class_labels}") label = class_labels.index(label) pred = torch.LongTensor([pred]).to(self.device) label = torch.LongTensor([label]).to(self.device)