Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 14, 2023
1 parent 2179e6b commit 91ea1a7
Showing 1 changed file with 4 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,26 +109,20 @@ def setup_metric(self, data_cfg):

kwargs = {}
if hasattr(data_cfg.metric, 'average'):
if metric_name != 'rouge':
if metric_name != 'rouge':
kwargs['average'] = data_cfg.metric.average

if hasattr(data_cfg.metric, 'num_classes'):
if metric_name != 'rouge':
if metric_name != 'rouge':
kwargs['num_classes'] = data_cfg.metric.num_classes

# GLUE will not have a "src_file_name" attribute and will always have only a single metric.
if hasattr(data_cfg, "src_file_name") or hasattr(data_cfg, "file_names"):
if hasattr(data_cfg, "src_file_name") and isinstance(data_cfg.src_file_name, ListConfig):
# We pass average and num_classes to the metric constructor via kwargs.
metric = [
metric_class(**kwargs)
for _ in range(len(data_cfg.src_file_name))
]
metric = [metric_class(**kwargs) for _ in range(len(data_cfg.src_file_name))]
elif hasattr(data_cfg, "file_names") and isinstance(data_cfg.file_names, ListConfig):
metric = [
metric_class(**kwargs)
for _ in range(len(data_cfg.file_names))
]
metric = [metric_class(**kwargs) for _ in range(len(data_cfg.file_names))]
else:
metric = [metric_class(**kwargs)]
else:
Expand Down

0 comments on commit 91ea1a7

Please sign in to comment.