Skip to content

Commit

Permalink
fix: warn and fix when topk parameter setting is wrong
Browse files Browse the repository at this point in the history
  • Loading branch information
TingquanGao committed Jun 6, 2022
1 parent 80358ef commit a43f853
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions ppcls/metric/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, topk=(1, 5)):

def reset(self):
self.avg_meters = {
"top{}".format(k): AverageMeter("top{}".format(k))
f"top{k}": AverageMeter(f"top{k}")
for k in self.topk
}

Expand All @@ -55,11 +55,14 @@ def forward(self, x, label):
if output_dims < k:
msg = f"The output dims({output_dims}) is less than k({k}), and the argument {k} of Topk has been removed."
logger.warning(msg)
self.topk.pop(idx)
self.avg_meters.pop(f"top{k}")
continue
metric_dict["top{}".format(k)] = paddle.metric.accuracy(
x, label, k=k)
self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)], x.shape[0])
metric_dict[f"top{k}"] = paddle.metric.accuracy(x, label, k=k)
self.avg_meters[f"top{k}"].update(metric_dict[f"top{k}"],
x.shape[0])

self.topk = filter(lambda k: k <= output_dims, self.topk)

return metric_dict


Expand Down

0 comments on commit a43f853

Please sign in to comment.